main commit
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-16 16:30:25 +09:00
parent 91c7e04474
commit 537e7b363f
1146 changed files with 45926 additions and 77196 deletions

View File

@@ -1,73 +1,59 @@
import typing
from importlib import import_module
from warnings import warn
import pydantic_core
from pydantic_core.core_schema import (
FieldSerializationInfo,
SerializationInfo,
SerializerFunctionWrapHandler,
ValidationInfo,
ValidatorFunctionWrapHandler,
)
from . import dataclasses
from ._internal._generate_schema import GenerateSchema as GenerateSchema
from ._migration import getattr_migration
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from .config import ConfigDict
from .errors import *
from .fields import AliasChoices, AliasPath, Field, PrivateAttr, computed_field
from .functional_serializers import PlainSerializer, SerializeAsAny, WrapSerializer, field_serializer, model_serializer
from .functional_validators import (
AfterValidator,
BeforeValidator,
InstanceOf,
PlainValidator,
SkipValidation,
WrapValidator,
field_validator,
model_validator,
)
from .json_schema import WithJsonSchema
from .main import *
from .networks import *
from .type_adapter import TypeAdapter
from .types import *
from .validate_call import validate_call
from .version import VERSION
from .warnings import *
__version__ = VERSION
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
ValidationError = pydantic_core.ValidationError
if typing.TYPE_CHECKING:
# import of virtually everything is supported via `__getattr__` below,
# but we need them here for type checking and IDE support
import pydantic_core
from pydantic_core.core_schema import (
FieldSerializationInfo,
SerializationInfo,
SerializerFunctionWrapHandler,
ValidationInfo,
ValidatorFunctionWrapHandler,
)
from . import dataclasses
from .aliases import AliasChoices, AliasGenerator, AliasPath
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from .config import ConfigDict, with_config
from .errors import *
from .fields import Field, PrivateAttr, computed_field
from .functional_serializers import (
PlainSerializer,
SerializeAsAny,
WrapSerializer,
field_serializer,
model_serializer,
)
from .functional_validators import (
AfterValidator,
BeforeValidator,
InstanceOf,
ModelWrapValidatorHandler,
PlainValidator,
SkipValidation,
WrapValidator,
field_validator,
model_validator,
)
from .json_schema import WithJsonSchema
from .main import *
from .networks import *
from .type_adapter import TypeAdapter
from .types import *
from .validate_call_decorator import validate_call
from .warnings import (
PydanticDeprecatedSince20,
PydanticDeprecatedSince26,
PydanticDeprecatedSince29,
PydanticDeprecatedSince210,
PydanticDeprecatedSince211,
PydanticDeprecationWarning,
PydanticExperimentalWarning,
)
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
ValidationError = pydantic_core.ValidationError
# these are imported via `__getattr__` below, but we need them here for type checking and IDE support
from .deprecated.class_validators import root_validator, validator
from .deprecated.config import BaseConfig, Extra
from .deprecated.tools import *
from .root_model import RootModel
__version__ = VERSION
__all__ = (
__all__ = [
# dataclasses
'dataclasses',
# pydantic_core.core_schema
'ValidationInfo',
'ValidatorFunctionWrapHandler',
# functional validators
'field_validator',
'model_validator',
@@ -77,8 +63,6 @@ __all__ = (
'WrapValidator',
'SkipValidation',
'InstanceOf',
'ModelWrapValidatorHandler',
# JSON Schema
'WithJsonSchema',
# deprecated V1 functional validators, these are imported via `__getattr__` below
'root_validator',
@@ -89,14 +73,18 @@ __all__ = (
'PlainSerializer',
'SerializeAsAny',
'WrapSerializer',
'FieldSerializationInfo',
'SerializationInfo',
'SerializerFunctionWrapHandler',
# config
'ConfigDict',
'with_config',
# deprecated V1 config, these are imported via `__getattr__` below
'BaseConfig',
'Extra',
# validate_call
'validate_call',
# pydantic_core errors
'ValidationError',
# errors
'PydanticErrorCodes',
'PydanticUserError',
@@ -104,15 +92,11 @@ __all__ = (
'PydanticImportError',
'PydanticUndefinedAnnotation',
'PydanticInvalidForJsonSchema',
'PydanticForbiddenQualifier',
# fields
'AliasPath',
'AliasChoices',
'Field',
'computed_field',
'PrivateAttr',
# alias
'AliasChoices',
'AliasGenerator',
'AliasPath',
# main
'BaseModel',
'create_model',
@@ -121,9 +105,6 @@ __all__ = (
'AnyHttpUrl',
'FileUrl',
'HttpUrl',
'FtpUrl',
'WebsocketUrl',
'AnyWebsocketUrl',
'UrlConstraints',
'EmailStr',
'NameEmail',
@@ -136,11 +117,8 @@ __all__ = (
'RedisDsn',
'MongoDsn',
'KafkaDsn',
'NatsDsn',
'MySQLDsn',
'MariaDBDsn',
'ClickHouseDsn',
'SnowflakeDsn',
'validate_email',
# root_model
'RootModel',
@@ -175,22 +153,18 @@ __all__ = (
'UUID3',
'UUID4',
'UUID5',
'UUID6',
'UUID7',
'UUID8',
'FilePath',
'DirectoryPath',
'NewPath',
'Json',
'Secret',
'SecretStr',
'SecretBytes',
'SocketPath',
'StrictBool',
'StrictBytes',
'StrictInt',
'StrictFloat',
'PaymentCardNumber',
'PrivateAttr',
'ByteSize',
'PastDate',
'FutureDate',
@@ -208,238 +182,44 @@ __all__ = (
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'Discriminator',
'JsonValue',
'FailFast',
# type_adapter
'TypeAdapter',
# version
'__version__',
'VERSION',
# warnings
'PydanticDeprecatedSince20',
'PydanticDeprecatedSince26',
'PydanticDeprecatedSince29',
'PydanticDeprecatedSince210',
'PydanticDeprecatedSince211',
'PydanticDeprecationWarning',
'PydanticExperimentalWarning',
# annotated handlers
'GetCoreSchemaHandler',
'GetJsonSchemaHandler',
# pydantic_core
'ValidationError',
'ValidationInfo',
'SerializationInfo',
'ValidatorFunctionWrapHandler',
'FieldSerializationInfo',
'SerializerFunctionWrapHandler',
'OnErrorOmit',
)
'GenerateSchema',
]
# A mapping of {<member name>: (package, <module name>)} defining dynamic imports
_dynamic_imports: 'dict[str, tuple[str, str]]' = {
'dataclasses': (__spec__.parent, '__module__'),
# functional validators
'field_validator': (__spec__.parent, '.functional_validators'),
'model_validator': (__spec__.parent, '.functional_validators'),
'AfterValidator': (__spec__.parent, '.functional_validators'),
'BeforeValidator': (__spec__.parent, '.functional_validators'),
'PlainValidator': (__spec__.parent, '.functional_validators'),
'WrapValidator': (__spec__.parent, '.functional_validators'),
'SkipValidation': (__spec__.parent, '.functional_validators'),
'InstanceOf': (__spec__.parent, '.functional_validators'),
'ModelWrapValidatorHandler': (__spec__.parent, '.functional_validators'),
# JSON Schema
'WithJsonSchema': (__spec__.parent, '.json_schema'),
# functional serializers
'field_serializer': (__spec__.parent, '.functional_serializers'),
'model_serializer': (__spec__.parent, '.functional_serializers'),
'PlainSerializer': (__spec__.parent, '.functional_serializers'),
'SerializeAsAny': (__spec__.parent, '.functional_serializers'),
'WrapSerializer': (__spec__.parent, '.functional_serializers'),
# config
'ConfigDict': (__spec__.parent, '.config'),
'with_config': (__spec__.parent, '.config'),
# validate call
'validate_call': (__spec__.parent, '.validate_call_decorator'),
# errors
'PydanticErrorCodes': (__spec__.parent, '.errors'),
'PydanticUserError': (__spec__.parent, '.errors'),
'PydanticSchemaGenerationError': (__spec__.parent, '.errors'),
'PydanticImportError': (__spec__.parent, '.errors'),
'PydanticUndefinedAnnotation': (__spec__.parent, '.errors'),
'PydanticInvalidForJsonSchema': (__spec__.parent, '.errors'),
'PydanticForbiddenQualifier': (__spec__.parent, '.errors'),
# fields
'Field': (__spec__.parent, '.fields'),
'computed_field': (__spec__.parent, '.fields'),
'PrivateAttr': (__spec__.parent, '.fields'),
# alias
'AliasChoices': (__spec__.parent, '.aliases'),
'AliasGenerator': (__spec__.parent, '.aliases'),
'AliasPath': (__spec__.parent, '.aliases'),
# main
'BaseModel': (__spec__.parent, '.main'),
'create_model': (__spec__.parent, '.main'),
# network
'AnyUrl': (__spec__.parent, '.networks'),
'AnyHttpUrl': (__spec__.parent, '.networks'),
'FileUrl': (__spec__.parent, '.networks'),
'HttpUrl': (__spec__.parent, '.networks'),
'FtpUrl': (__spec__.parent, '.networks'),
'WebsocketUrl': (__spec__.parent, '.networks'),
'AnyWebsocketUrl': (__spec__.parent, '.networks'),
'UrlConstraints': (__spec__.parent, '.networks'),
'EmailStr': (__spec__.parent, '.networks'),
'NameEmail': (__spec__.parent, '.networks'),
'IPvAnyAddress': (__spec__.parent, '.networks'),
'IPvAnyInterface': (__spec__.parent, '.networks'),
'IPvAnyNetwork': (__spec__.parent, '.networks'),
'PostgresDsn': (__spec__.parent, '.networks'),
'CockroachDsn': (__spec__.parent, '.networks'),
'AmqpDsn': (__spec__.parent, '.networks'),
'RedisDsn': (__spec__.parent, '.networks'),
'MongoDsn': (__spec__.parent, '.networks'),
'KafkaDsn': (__spec__.parent, '.networks'),
'NatsDsn': (__spec__.parent, '.networks'),
'MySQLDsn': (__spec__.parent, '.networks'),
'MariaDBDsn': (__spec__.parent, '.networks'),
'ClickHouseDsn': (__spec__.parent, '.networks'),
'SnowflakeDsn': (__spec__.parent, '.networks'),
'validate_email': (__spec__.parent, '.networks'),
# root_model
'RootModel': (__spec__.parent, '.root_model'),
# types
'Strict': (__spec__.parent, '.types'),
'StrictStr': (__spec__.parent, '.types'),
'conbytes': (__spec__.parent, '.types'),
'conlist': (__spec__.parent, '.types'),
'conset': (__spec__.parent, '.types'),
'confrozenset': (__spec__.parent, '.types'),
'constr': (__spec__.parent, '.types'),
'StringConstraints': (__spec__.parent, '.types'),
'ImportString': (__spec__.parent, '.types'),
'conint': (__spec__.parent, '.types'),
'PositiveInt': (__spec__.parent, '.types'),
'NegativeInt': (__spec__.parent, '.types'),
'NonNegativeInt': (__spec__.parent, '.types'),
'NonPositiveInt': (__spec__.parent, '.types'),
'confloat': (__spec__.parent, '.types'),
'PositiveFloat': (__spec__.parent, '.types'),
'NegativeFloat': (__spec__.parent, '.types'),
'NonNegativeFloat': (__spec__.parent, '.types'),
'NonPositiveFloat': (__spec__.parent, '.types'),
'FiniteFloat': (__spec__.parent, '.types'),
'condecimal': (__spec__.parent, '.types'),
'condate': (__spec__.parent, '.types'),
'UUID1': (__spec__.parent, '.types'),
'UUID3': (__spec__.parent, '.types'),
'UUID4': (__spec__.parent, '.types'),
'UUID5': (__spec__.parent, '.types'),
'UUID6': (__spec__.parent, '.types'),
'UUID7': (__spec__.parent, '.types'),
'UUID8': (__spec__.parent, '.types'),
'FilePath': (__spec__.parent, '.types'),
'DirectoryPath': (__spec__.parent, '.types'),
'NewPath': (__spec__.parent, '.types'),
'Json': (__spec__.parent, '.types'),
'Secret': (__spec__.parent, '.types'),
'SecretStr': (__spec__.parent, '.types'),
'SecretBytes': (__spec__.parent, '.types'),
'StrictBool': (__spec__.parent, '.types'),
'StrictBytes': (__spec__.parent, '.types'),
'StrictInt': (__spec__.parent, '.types'),
'StrictFloat': (__spec__.parent, '.types'),
'PaymentCardNumber': (__spec__.parent, '.types'),
'ByteSize': (__spec__.parent, '.types'),
'PastDate': (__spec__.parent, '.types'),
'SocketPath': (__spec__.parent, '.types'),
'FutureDate': (__spec__.parent, '.types'),
'PastDatetime': (__spec__.parent, '.types'),
'FutureDatetime': (__spec__.parent, '.types'),
'AwareDatetime': (__spec__.parent, '.types'),
'NaiveDatetime': (__spec__.parent, '.types'),
'AllowInfNan': (__spec__.parent, '.types'),
'EncoderProtocol': (__spec__.parent, '.types'),
'EncodedBytes': (__spec__.parent, '.types'),
'EncodedStr': (__spec__.parent, '.types'),
'Base64Encoder': (__spec__.parent, '.types'),
'Base64Bytes': (__spec__.parent, '.types'),
'Base64Str': (__spec__.parent, '.types'),
'Base64UrlBytes': (__spec__.parent, '.types'),
'Base64UrlStr': (__spec__.parent, '.types'),
'GetPydanticSchema': (__spec__.parent, '.types'),
'Tag': (__spec__.parent, '.types'),
'Discriminator': (__spec__.parent, '.types'),
'JsonValue': (__spec__.parent, '.types'),
'OnErrorOmit': (__spec__.parent, '.types'),
'FailFast': (__spec__.parent, '.types'),
# type_adapter
'TypeAdapter': (__spec__.parent, '.type_adapter'),
# warnings
'PydanticDeprecatedSince20': (__spec__.parent, '.warnings'),
'PydanticDeprecatedSince26': (__spec__.parent, '.warnings'),
'PydanticDeprecatedSince29': (__spec__.parent, '.warnings'),
'PydanticDeprecatedSince210': (__spec__.parent, '.warnings'),
'PydanticDeprecatedSince211': (__spec__.parent, '.warnings'),
'PydanticDeprecationWarning': (__spec__.parent, '.warnings'),
'PydanticExperimentalWarning': (__spec__.parent, '.warnings'),
# annotated handlers
'GetCoreSchemaHandler': (__spec__.parent, '.annotated_handlers'),
'GetJsonSchemaHandler': (__spec__.parent, '.annotated_handlers'),
# pydantic_core stuff
'ValidationError': ('pydantic_core', '.'),
'ValidationInfo': ('pydantic_core', '.core_schema'),
'SerializationInfo': ('pydantic_core', '.core_schema'),
'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'),
'FieldSerializationInfo': ('pydantic_core', '.core_schema'),
'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'),
# deprecated, mostly not included in __all__
'root_validator': (__spec__.parent, '.deprecated.class_validators'),
'validator': (__spec__.parent, '.deprecated.class_validators'),
'BaseConfig': (__spec__.parent, '.deprecated.config'),
'Extra': (__spec__.parent, '.deprecated.config'),
'parse_obj_as': (__spec__.parent, '.deprecated.tools'),
'schema_of': (__spec__.parent, '.deprecated.tools'),
'schema_json_of': (__spec__.parent, '.deprecated.tools'),
# deprecated dynamic imports
'RootModel': (__package__, '.root_model'),
'root_validator': (__package__, '.deprecated.class_validators'),
'validator': (__package__, '.deprecated.class_validators'),
'BaseConfig': (__package__, '.deprecated.config'),
'Extra': (__package__, '.deprecated.config'),
'parse_obj_as': (__package__, '.deprecated.tools'),
'schema_of': (__package__, '.deprecated.tools'),
'schema_json_of': (__package__, '.deprecated.tools'),
# FieldValidationInfo is deprecated, and hidden behind module a `__getattr__`
'FieldValidationInfo': ('pydantic_core', '.core_schema'),
'GenerateSchema': (__spec__.parent, '._internal._generate_schema'),
}
_deprecated_dynamic_imports = {'FieldValidationInfo', 'GenerateSchema'}
_getattr_migration = getattr_migration(__name__)
def __getattr__(attr_name: str) -> object:
if attr_name in _deprecated_dynamic_imports:
warn(
f'Importing {attr_name} from `pydantic` is deprecated. This feature is either no longer supported, or is not public.',
DeprecationWarning,
stacklevel=2,
)
dynamic_attr = _dynamic_imports.get(attr_name)
if dynamic_attr is None:
return _getattr_migration(attr_name)
package, module_name = dynamic_attr
if module_name == '__module__':
result = import_module(f'.{attr_name}', package=package)
globals()[attr_name] = result
return result
else:
module = import_module(module_name, package=package)
result = getattr(module, attr_name)
g = globals()
for k, (_, v_module_name) in _dynamic_imports.items():
if v_module_name == module_name and k not in _deprecated_dynamic_imports:
g[k] = getattr(module, k)
return result
from importlib import import_module
def __dir__() -> 'list[str]':
return list(__all__)
module = import_module(module_name, package=package)
return getattr(module, attr_name)

View File

@@ -1,23 +1,25 @@
from __future__ import annotations as _annotations
import warnings
from contextlib import contextmanager
from re import Pattern
from contextlib import contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
ContextManager,
Iterator,
cast,
)
from pydantic_core import core_schema
from typing_extensions import Self
from typing_extensions import (
Literal,
Self,
)
from ..aliases import AliasGenerator
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
@@ -26,7 +28,6 @@ if not TYPE_CHECKING:
if TYPE_CHECKING:
from .._internal._schema_generation_shared import GenerateSchema
from ..fields import ComputedFieldInfo, FieldInfo
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
@@ -56,12 +57,10 @@ class ConfigWrapper:
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
# to construct error `loc`s, default `True`
loc_by_alias: bool
alias_generator: Callable[[str], str] | AliasGenerator | None
model_title_generator: Callable[[type], str] | None
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
alias_generator: Callable[[str], str] | None
ignored_types: tuple[type, ...]
allow_inf_nan: bool
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None
json_encoders: dict[type[object], JsonEncoder] | None
# new in V2
@@ -69,13 +68,11 @@ class ConfigWrapper:
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
revalidate_instances: Literal['always', 'never', 'subclass-instances']
ser_json_timedelta: Literal['iso8601', 'float']
ser_json_bytes: Literal['utf8', 'base64', 'hex']
val_json_bytes: Literal['utf8', 'base64', 'hex']
ser_json_inf_nan: Literal['null', 'constants', 'strings']
ser_json_bytes: Literal['utf8', 'base64']
# whether to validate default values during validation, default False
validate_default: bool
validate_return: bool
protected_namespaces: tuple[str | Pattern[str], ...]
protected_namespaces: tuple[str, ...]
hide_input_in_errors: bool
defer_build: bool
plugin_settings: dict[str, object] | None
@@ -83,13 +80,6 @@ class ConfigWrapper:
json_schema_serialization_defaults_required: bool
json_schema_mode_override: Literal['validation', 'serialization', None]
coerce_numbers_to_str: bool
regex_engine: Literal['rust-regex', 'python-re']
validation_error_cause: bool
use_attribute_docstrings: bool
cache_strings: bool | Literal['all', 'keys', 'none']
validate_by_alias: bool
validate_by_name: bool
serialize_by_alias: bool
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
if check:
@@ -123,19 +113,13 @@ class ConfigWrapper:
config_class_from_namespace = namespace.get('Config')
config_dict_from_namespace = namespace.get('model_config')
raw_annotations = namespace.get('__annotations__', {})
if raw_annotations.get('model_config') and config_dict_from_namespace is None:
raise PydanticUserError(
'`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
code='model-config-invalid-field-name',
)
if config_class_from_namespace and config_dict_from_namespace:
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
config_new.update(config_from_namespace)
if config_from_namespace is not None:
config_new.update(config_from_namespace)
for k in list(kwargs.keys()):
if k in config_keys:
@@ -144,7 +128,7 @@ class ConfigWrapper:
return cls(config_new)
# we don't show `__getattr__` to type checkers so missing attributes cause errors
if not TYPE_CHECKING: # pragma: no branch
if not TYPE_CHECKING:
def __getattr__(self, name: str) -> Any:
try:
@@ -155,77 +139,46 @@ class ConfigWrapper:
except KeyError:
raise AttributeError(f'Config has no attribute {name!r}') from None
def core_config(self, title: str | None) -> core_schema.CoreConfig:
"""Create a pydantic-core config.
def core_config(self, obj: Any) -> core_schema.CoreConfig:
"""Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.
Pass `obj=None` if you do not want to attempt to infer the `title`.
We don't use getattr here since we don't want to populate with defaults.
Args:
title: The title to use if not set in config.
obj: An object used to populate `title` if not set in config.
Returns:
A `CoreConfig` object created from config.
"""
config = self.config_dict
if config.get('schema_generator') is not None:
warnings.warn(
'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.',
PydanticDeprecatedSince210,
stacklevel=2,
def dict_not_none(**kwargs: Any) -> Any:
return {k: v for k, v in kwargs.items() if v is not None}
core_config = core_schema.CoreConfig(
**dict_not_none(
title=self.config_dict.get('title') or (obj and obj.__name__),
extra_fields_behavior=self.config_dict.get('extra'),
allow_inf_nan=self.config_dict.get('allow_inf_nan'),
populate_by_name=self.config_dict.get('populate_by_name'),
str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
str_to_lower=self.config_dict.get('str_to_lower'),
str_to_upper=self.config_dict.get('str_to_upper'),
strict=self.config_dict.get('strict'),
ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
ser_json_bytes=self.config_dict.get('ser_json_bytes'),
from_attributes=self.config_dict.get('from_attributes'),
loc_by_alias=self.config_dict.get('loc_by_alias'),
revalidate_instances=self.config_dict.get('revalidate_instances'),
validate_default=self.config_dict.get('validate_default'),
str_max_length=self.config_dict.get('str_max_length'),
str_min_length=self.config_dict.get('str_min_length'),
hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
)
if (populate_by_name := config.get('populate_by_name')) is not None:
# We include this patch for backwards compatibility purposes, but this config setting will be deprecated in v3.0, and likely removed in v4.0.
# Thus, the above warning and this patch can be removed then as well.
if config.get('validate_by_name') is None:
config['validate_by_alias'] = True
config['validate_by_name'] = populate_by_name
# We dynamically patch validate_by_name to be True if validate_by_alias is set to False
# and validate_by_name is not explicitly set.
if config.get('validate_by_alias') is False and config.get('validate_by_name') is None:
config['validate_by_name'] = True
if (not config.get('validate_by_alias', True)) and (not config.get('validate_by_name', False)):
raise PydanticUserError(
'At least one of `validate_by_alias` or `validate_by_name` must be set to True.',
code='validate-by-alias-and-name-false',
)
return core_schema.CoreConfig(
**{ # pyright: ignore[reportArgumentType]
k: v
for k, v in (
('title', config.get('title') or title or None),
('extra_fields_behavior', config.get('extra')),
('allow_inf_nan', config.get('allow_inf_nan')),
('str_strip_whitespace', config.get('str_strip_whitespace')),
('str_to_lower', config.get('str_to_lower')),
('str_to_upper', config.get('str_to_upper')),
('strict', config.get('strict')),
('ser_json_timedelta', config.get('ser_json_timedelta')),
('ser_json_bytes', config.get('ser_json_bytes')),
('val_json_bytes', config.get('val_json_bytes')),
('ser_json_inf_nan', config.get('ser_json_inf_nan')),
('from_attributes', config.get('from_attributes')),
('loc_by_alias', config.get('loc_by_alias')),
('revalidate_instances', config.get('revalidate_instances')),
('validate_default', config.get('validate_default')),
('str_max_length', config.get('str_max_length')),
('str_min_length', config.get('str_min_length')),
('hide_input_in_errors', config.get('hide_input_in_errors')),
('coerce_numbers_to_str', config.get('coerce_numbers_to_str')),
('regex_engine', config.get('regex_engine')),
('validation_error_cause', config.get('validation_error_cause')),
('cache_strings', config.get('cache_strings')),
('validate_by_alias', config.get('validate_by_alias')),
('validate_by_name', config.get('validate_by_name')),
('serialize_by_alias', config.get('serialize_by_alias')),
)
if v is not None
}
)
return core_config
def __repr__(self):
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
@@ -242,20 +195,22 @@ class ConfigWrapperStack:
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]
@contextmanager
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
if config_wrapper is None:
yield
return
return nullcontext()
if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()
@contextmanager
def _context_manager() -> Iterator[None]:
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()
return _context_manager()
config_defaults = ConfigDict(
@@ -275,8 +230,6 @@ config_defaults = ConfigDict(
from_attributes=False,
loc_by_alias=True,
alias_generator=None,
model_title_generator=None,
field_title_generator=None,
ignored_types=(),
allow_inf_nan=True,
json_schema_extra=None,
@@ -284,26 +237,17 @@ config_defaults = ConfigDict(
revalidate_instances='never',
ser_json_timedelta='iso8601',
ser_json_bytes='utf8',
val_json_bytes='utf8',
ser_json_inf_nan='null',
validate_default=False,
validate_return=False,
protected_namespaces=('model_validate', 'model_dump'),
protected_namespaces=('model_',),
hide_input_in_errors=False,
json_encoders=None,
defer_build=False,
schema_generator=None,
plugin_settings=None,
schema_generator=None,
json_schema_serialization_defaults_required=False,
json_schema_mode_override=None,
coerce_numbers_to_str=False,
regex_engine='rust-regex',
validation_error_cause=False,
use_attribute_docstrings=False,
cache_strings=True,
validate_by_alias=True,
validate_by_name=False,
serialize_by_alias=False,
)
@@ -344,7 +288,7 @@ V2_REMOVED_KEYS = {
'post_init_call',
}
V2_RENAMED_KEYS = {
'allow_population_by_field_name': 'validate_by_name',
'allow_population_by_field_name': 'populate_by_name',
'anystr_lower': 'str_to_lower',
'anystr_strip_whitespace': 'str_strip_whitespace',
'anystr_upper': 'str_to_upper',

View File

@@ -1,97 +1,92 @@
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, TypedDict, cast
from warnings import warn
import typing
from typing import Any
if TYPE_CHECKING:
from ..config import JsonDict, JsonSchemaExtraCallable
import typing_extensions
if typing.TYPE_CHECKING:
from ._schema_generation_shared import (
CoreSchemaOrField as CoreSchemaOrField,
)
from ._schema_generation_shared import (
GetJsonSchemaFunction,
)
class CoreMetadata(TypedDict, total=False):
class CoreMetadata(typing_extensions.TypedDict, total=False):
"""A `TypedDict` for holding the metadata dict of the schema.
Attributes:
pydantic_js_functions: List of JSON schema functions that resolve refs during application.
pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
pydantic_js_functions: List of JSON schema functions.
pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
prefer positional over keyword arguments for an 'arguments' schema.
custom validation function. Only applies to before, plain, and wrap validators.
pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type.
pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.
pydantic_internal_union_tag_key: Used internally by the `Tag` metadata to specify the tag used for a discriminated union.
pydantic_internal_union_discriminator: Used internally to specify the discriminator value for a discriminated union
when the discriminator was applied to a `'definition-ref'` schema, and that reference was missing at the time
of the annotation application.
TODO: Perhaps we should move this structure to pydantic-core. At the moment, though,
it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API.
TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during
the core schema generation process. It's inevitable that we need to store some json schema related information
on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
"""
pydantic_js_functions: list[GetJsonSchemaFunction]
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
pydantic_js_prefer_positional_arguments: bool
pydantic_js_updates: JsonDict
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable
pydantic_internal_union_tag_key: str
pydantic_internal_union_discriminator: str
# If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
# prefer positional over keyword arguments for an 'arguments' schema.
pydantic_js_prefer_positional_arguments: bool | None
pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema
def update_core_metadata(
core_metadata: Any,
/,
*,
pydantic_js_functions: list[GetJsonSchemaFunction] | None = None,
pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
pydantic_js_updates: JsonDict | None = None,
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None,
) -> None:
from ..json_schema import PydanticJsonSchemaWarning
class CoreMetadataHandler:
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
"""Update CoreMetadata instance in place. When we make modifications in this function, they
take effect on the `core_metadata` reference passed in as the first (and only) positional argument.
First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility.
We do this here, instead of before / after each call to this function so that this typing hack
can be easily removed if/when we move `CoreMetadata` to `pydantic-core`.
For parameter descriptions, see `CoreMetadata` above.
This class is used to interact with the metadata field on a CoreSchema object in a consistent
way throughout pydantic.
"""
core_metadata = cast(CoreMetadata, core_metadata)
if pydantic_js_functions:
core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions)
__slots__ = ('_schema',)
if pydantic_js_annotation_functions:
core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions)
def __init__(self, schema: CoreSchemaOrField):
self._schema = schema
if pydantic_js_updates:
if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None:
core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates}
else:
core_metadata['pydantic_js_updates'] = pydantic_js_updates
metadata = schema.get('metadata')
if metadata is None:
schema['metadata'] = CoreMetadata()
elif not isinstance(metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
if pydantic_js_extra is not None:
existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra')
if existing_pydantic_js_extra is None:
core_metadata['pydantic_js_extra'] = pydantic_js_extra
if isinstance(existing_pydantic_js_extra, dict):
if isinstance(pydantic_js_extra, dict):
core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra}
if callable(pydantic_js_extra):
warn(
'Composing `dict` and `callable` type `json_schema_extra` is not supported.'
'The `callable` type is being ignored.'
"If you'd like support for this behavior, please open an issue on pydantic.",
PydanticJsonSchemaWarning,
)
if callable(existing_pydantic_js_extra):
# if ever there's a case of a callable, we'll just keep the last json schema extra spec
core_metadata['pydantic_js_extra'] = pydantic_js_extra
@property
def metadata(self) -> CoreMetadata:
"""Retrieves the metadata dict from the schema, initializing it to a dict if it is None
and raises an error if it is not a dict.
"""
metadata = self._schema.get('metadata')
if metadata is None:
self._schema['metadata'] = metadata = CoreMetadata()
if not isinstance(metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
return metadata
def build_metadata_dict(
*, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
js_functions: list[GetJsonSchemaFunction] | None = None,
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
js_prefer_positional_arguments: bool | None = None,
typed_dict_cls: type[Any] | None = None,
initial_metadata: Any | None = None,
) -> Any:
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
with the CoreMetadataHandler class.
"""
if initial_metadata is not None and not isinstance(initial_metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.')
metadata = CoreMetadata(
pydantic_js_functions=js_functions or [],
pydantic_js_annotation_functions=js_annotation_functions or [],
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
pydantic_typed_dict_cls=typed_dict_cls,
)
metadata = {k: v for k, v in metadata.items() if v is not None}
if initial_metadata is not None:
metadata = {**initial_metadata, **metadata}
return metadata

View File

@@ -1,20 +1,22 @@
from __future__ import annotations
import inspect
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union
from collections import defaultdict
from typing import (
Any,
Callable,
Hashable,
TypeVar,
Union,
_GenericAlias, # type: ignore
cast,
)
from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeGuard, get_args, get_origin
from typing_inspection import typing_objects
from typing_extensions import TypeAliasType, TypeGuard, get_args
from . import _repr
from ._typing_extra import is_generic_alias
if TYPE_CHECKING:
from rich.console import Console
AnyFunctionSchema = Union[
core_schema.AfterValidatorFunctionSchema,
@@ -37,7 +39,19 @@ CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
schema building because one of it's members refers to a definition that was not yet defined when the union
was first encountered.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
"""
def is_core_schema(
@@ -60,27 +74,28 @@ def is_function_with_inner_schema(
def is_list_like_schema_with_items_schema(
schema: CoreSchema,
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
) -> TypeGuard[
core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
]:
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
def get_type_ref(type_: Any, args_override: tuple[type[Any], ...] | None = None) -> str:
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
"""Produces the ref to be used for this type by pydantic_core's core schemas.
This `args_override` argument was added for the purpose of creating valid recursive references
when creating generic models without needing to create a concrete class.
"""
origin = get_origin(type_) or type_
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
origin = type_
args = get_args(type_) if isinstance(type_, _GenericAlias) else (args_override or ())
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
if generic_metadata:
origin = generic_metadata['origin'] or origin
args = generic_metadata['args'] or args
module_name = getattr(origin, '__module__', '<No __module__>')
if typing_objects.is_typealiastype(origin):
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
if isinstance(origin, TypeAliasType):
type_ref = f'{module_name}.{origin.__name__}'
else:
try:
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
@@ -109,74 +124,457 @@ def get_ref(s: core_schema.CoreSchema) -> None | str:
return s.get('ref', None)
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
if os.getenv('PYDANTIC_VALIDATE_CORE_SCHEMAS'):
return _validate_core_schema(schema)
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
defs: dict[str, CoreSchema] = {}
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
ref = get_ref(s)
if ref:
defs[ref] = s
return recurse(s, _record_valid_refs)
walk_core_schema(schema, _record_valid_refs)
return defs
def define_expected_missing_refs(
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema | None:
if not allowed_missing_refs:
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
return None
refs = collect_definitions(schema).keys()
expected_missing_refs = allowed_missing_refs.difference(refs)
if expected_missing_refs:
definitions: list[core_schema.CoreSchema] = [
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
# Issue: https://github.com/pydantic/pydantic-core/issues/619
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
for ref in expected_missing_refs
]
return core_schema.definitions_schema(schema, definitions)
return None
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
invalid = False
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
nonlocal invalid
if 'metadata' in s:
metadata = s['metadata']
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
return s
return recurse(s, _is_schema_valid)
walk_core_schema(schema, _is_schema_valid)
return invalid
T = TypeVar('T')
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
# Issue: https://github.com/pydantic/pydantic-core/issues/615
class _WalkCoreSchema:
def __init__(self):
self._schema_type_to_method = self._build_schema_type_to_method()
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
key: core_schema.CoreSchemaType
for key in get_args(core_schema.CoreSchemaType):
method_name = f"handle_{key.replace('-', '_')}_schema"
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
return mapping
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
return f(schema, self._walk)
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
return schema
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
sub_schema = schema.get('schema', None)
if sub_schema is not None:
schema['schema'] = self.walk(sub_schema, f) # type: ignore
return schema
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
if schema is not None:
ser_schema['schema'] = self.walk(schema, f) # type: ignore
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
if return_schema is not None:
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
return ser_schema
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
new_definitions: list[core_schema.CoreSchema] = []
for definition in schema['definitions']:
updated_definition = self.walk(definition, f)
if 'ref' in updated_definition:
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
# This is most likely to happen due to replacing something with a definition reference, in
# which case it should certainly not go in the definitions list
new_definitions.append(updated_definition)
new_inner_schema = self.walk(schema['schema'], f)
if not new_definitions and len(schema) == 3:
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
return new_inner_schema
new_schema = schema.copy()
new_schema['schema'] = new_inner_schema
new_schema['definitions'] = new_definitions
return new_schema
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_tuple_variable_schema(
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
) -> core_schema.CoreSchema:
schema = cast(core_schema.TupleVariableSchema, schema)
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_tuple_positional_schema(
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
) -> core_schema.CoreSchema:
schema = cast(core_schema.TuplePositionalSchema, schema)
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
return schema
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
keys_schema = schema.get('keys_schema')
if keys_schema is not None:
schema['keys_schema'] = self.walk(keys_schema, f)
values_schema = schema.get('values_schema')
if values_schema:
schema['values_schema'] = self.walk(values_schema, f)
return schema
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
if not is_function_with_inner_schema(schema):
return schema
schema['schema'] = self.walk(schema['schema'], f)
return schema
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
for v in schema['choices']:
if isinstance(v, tuple):
new_choices.append((self.walk(v[0], f), v[1]))
else:
new_choices.append(self.walk(v, f))
schema['choices'] = new_choices
return schema
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
for k, v in schema['choices'].items():
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
schema['choices'] = new_choices
return schema
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
return schema
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
return schema
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
schema['json_schema'] = self.walk(schema['json_schema'], f)
schema['python_schema'] = self.walk(schema['python_schema'], f)
return schema
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
replaced_fields: dict[str, core_schema.ModelField] = {}
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
for k, v in schema['fields'].items():
replaced_field = v.copy()
replaced_field['schema'] = self.walk(v['schema'], f)
replaced_fields[k] = replaced_field
schema['fields'] = replaced_fields
return schema
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
replaced_fields: dict[str, core_schema.TypedDictField] = {}
for k, v in schema['fields'].items():
replaced_field = v.copy()
replaced_field['schema'] = self.walk(v['schema'], f)
replaced_fields[k] = replaced_field
schema['fields'] = replaced_fields
return schema
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_fields: list[core_schema.DataclassField] = []
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
for field in schema['fields']:
replaced_field = field.copy()
replaced_field['schema'] = self.walk(field['schema'], f)
replaced_fields.append(replaced_field)
schema['fields'] = replaced_fields
return schema
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
for param in schema['arguments_schema']:
replaced_param = param.copy()
replaced_param['schema'] = self.walk(param['schema'], f)
replaced_arguments_schema.append(replaced_param)
schema['arguments_schema'] = replaced_arguments_schema
if 'var_args_schema' in schema:
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
if 'var_kwargs_schema' in schema:
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
return schema
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
if 'return_schema' in schema:
schema['return_schema'] = self.walk(schema['return_schema'], f)
return schema
_dispatch = _WalkCoreSchema().walk
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
"""Recursively traverse a CoreSchema.
Args:
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
f (Walk): A function to apply. This function takes two arguments:
1. The current CoreSchema that is being processed
(not the same one you passed into this function, one level down).
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
to pass data down the recursive calls without using globals or other mutable state.
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema.copy(), _dispatch)
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
definitions: dict[str, core_schema.CoreSchema] = {}
ref_counts: dict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
if ref not in definitions:
definitions[ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
new = recurse(s, collect_refs)
new_ref = get_ref(new)
if new_ref:
definitions[new_ref] = new
return core_schema.definition_reference_schema(schema_ref=ref)
else:
return recurse(s, collect_refs)
schema = walk_core_schema(schema, collect_refs)
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
if ref_counts[ref] >= 2:
# If this model is involved in a recursion this should be detected
# on its second encounter, we can safely stop the walk here.
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
return s
current_recursion_ref_count[ref] += 1
recurse(definitions[ref], count_refs)
current_recursion_ref_count[ref] -= 1
return s
schema = walk_core_schema(schema, count_refs)
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
if ref_counts[ref] > 1:
return False
if involved_in_recursion.get(ref, False):
return False
if 'serialization' in s:
return False
if 'metadata' in s:
metadata = s['metadata']
for k in (
'pydantic_js_functions',
'pydantic_js_annotation_functions',
'pydantic.internal.union_discriminator',
):
if k in metadata:
# we need to keep this as a ref
return False
return True
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definition-ref':
ref = s['schema_ref']
# Check if the reference is only used once, not involved in recursion and does not have
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = definitions.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
new['serialization'] = s['serialization']
s = recurse(new, inline_refs)
return s
else:
return recurse(s, inline_refs)
else:
return recurse(s, inline_refs)
schema = walk_core_schema(schema, inline_refs)
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
if def_values:
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
return schema
def _clean_schema_for_pretty_print(obj: Any, strip_metadata: bool = True) -> Any: # pragma: no cover
"""A utility function to remove irrelevant information from a core schema."""
if isinstance(obj, Mapping):
new_dct = {}
for k, v in obj.items():
if k == 'metadata' and strip_metadata:
new_metadata = {}
for meta_k, meta_v in v.items():
if meta_k in ('pydantic_js_functions', 'pydantic_js_annotation_functions'):
new_metadata['js_metadata'] = '<stripped>'
else:
new_metadata[meta_k] = _clean_schema_for_pretty_print(meta_v, strip_metadata=strip_metadata)
if list(new_metadata.keys()) == ['js_metadata']:
new_metadata = {'<stripped>'}
new_dct[k] = new_metadata
# Remove some defaults:
elif k in ('custom_init', 'root_model') and not v:
continue
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
s = s.copy()
s.pop('metadata', None)
if s['type'] == 'model-fields':
s = s.copy()
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
for field_name, field_schema in s['fields'].items():
field_schema.pop('metadata', None)
s['fields'][field_name] = field_schema
computed_fields = s.get('computed_fields', None)
if computed_fields:
s['computed_fields'] = [cf.copy() for cf in computed_fields]
for cf in computed_fields:
cf.pop('metadata', None)
else:
new_dct[k] = _clean_schema_for_pretty_print(v, strip_metadata=strip_metadata)
s.pop('computed_fields', None)
elif s['type'] == 'model':
# remove some defaults
if s.get('custom_init', True) is False:
s.pop('custom_init')
if s.get('root_model', True) is False:
s.pop('root_model')
if {'title'}.issuperset(s.get('config', {}).keys()):
s.pop('config', None)
return new_dct
elif isinstance(obj, Sequence) and not isinstance(obj, str):
return [_clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) for v in obj]
else:
return obj
return recurse(s, strip_metadata)
return walk_core_schema(schema, strip_metadata)
def pretty_print_core_schema(
val: Any,
*,
console: Console | None = None,
max_depth: int | None = None,
strip_metadata: bool = True,
) -> None: # pragma: no cover
"""Pretty-print a core schema using the `rich` library.
schema: CoreSchema,
include_metadata: bool = False,
) -> None:
"""Pretty print a CoreSchema using rich.
This is intended for debugging purposes.
Args:
val: The core schema to print, or a Pydantic model/dataclass/type adapter
(in which case the cached core schema is fetched and printed).
console: A rich console to use when printing. Defaults to the global rich console instance.
max_depth: The number of nesting levels which may be printed.
strip_metadata: Whether to strip metadata in the output. If `True` any known core metadata
attributes will be stripped (but custom attributes are kept). Defaults to `True`.
schema: The CoreSchema to print.
include_metadata: Whether to include metadata in the output. Defaults to `False`.
"""
# lazy import:
from rich.pretty import pprint
from rich import print # type: ignore # install it manually in your dev env
# circ. imports:
from pydantic import BaseModel, TypeAdapter
from pydantic.dataclasses import is_pydantic_dataclass
if not include_metadata:
schema = _strip_metadata(schema)
if (inspect.isclass(val) and issubclass(val, BaseModel)) or is_pydantic_dataclass(val):
val = val.__pydantic_core_schema__
if isinstance(val, TypeAdapter):
val = val.core_schema
cleaned_schema = _clean_schema_for_pretty_print(val, strip_metadata=strip_metadata)
pprint(cleaned_schema, console=console, max_depth=max_depth)
return print(schema)
pps = pretty_print_core_schema
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
return schema
return _validate_core_schema(schema)

View File

@@ -1,15 +1,17 @@
"""Private logic for creating pydantic dataclasses."""
from __future__ import annotations as _annotations
import dataclasses
import inspect
import typing
import warnings
from functools import partial, wraps
from typing import Any, ClassVar
from inspect import Parameter, Signature, signature
from typing import Any, Callable, ClassVar
from pydantic_core import (
ArgsKwargs,
PydanticUndefined,
SchemaSerializer,
SchemaValidator,
core_schema,
@@ -17,22 +19,28 @@ from pydantic_core import (
from typing_extensions import TypeGuard
from ..errors import PydanticUndefinedAnnotation
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
from ..fields import FieldInfo
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators
from . import _config, _decorators, _discriminated_union, _typing_extra
from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema, InvalidSchemaError
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._namespace_utils import NsResolver
from ._signature import generate_pydantic_signature
from ._utils import LazyClassAttribute
from ._mock_val_ser import set_dataclass_mock_validator
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._utils import is_valid_identifier
if typing.TYPE_CHECKING:
from _typeshed import DataclassInstance as StandardDataclass
from ..config import ConfigDict
from ..fields import FieldInfo
class StandardDataclass(typing.Protocol):
__dataclass_fields__: ClassVar[dict[str, Any]]
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
__post_init__: ClassVar[Callable[..., None]]
def __init__(self, *args: object, **kwargs: object) -> None:
pass
class PydanticDataclass(StandardDataclass, typing.Protocol):
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
@@ -53,10 +61,7 @@ if typing.TYPE_CHECKING:
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
@classmethod
def __pydantic_fields_complete__(cls) -> bool: ...
__pydantic_validator__: ClassVar[SchemaValidator]
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
@@ -64,22 +69,15 @@ else:
DeprecationWarning = PydanticDeprecatedSince20
def set_dataclass_fields(
cls: type[StandardDataclass],
ns_resolver: NsResolver | None = None,
config_wrapper: _config.ConfigWrapper | None = None,
) -> None:
def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
"""Collect and set `cls.__pydantic_fields__`.
Args:
cls: The class.
ns_resolver: Namespace resolver to use when getting dataclass annotations.
config_wrapper: The config wrapper instance, defaults to `None`.
types_namespace: The types namespace, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(
cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
cls.__pydantic_fields__ = fields # type: ignore
@@ -89,8 +87,7 @@ def complete_dataclass(
config_wrapper: _config.ConfigWrapper,
*,
raise_errors: bool = True,
ns_resolver: NsResolver | None = None,
_force_build: bool = False,
types_namespace: dict[str, Any] | None,
) -> bool:
"""Finish building a pydantic dataclass.
@@ -102,10 +99,7 @@ def complete_dataclass(
cls: The class.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors, defaults to `True`.
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
and during schema building.
_force_build: Whether to force building the dataclass, no matter if
[`defer_build`][pydantic.config.ConfigDict.defer_build] is set.
types_namespace: The types namespace.
Returns:
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
@@ -113,94 +107,136 @@ def complete_dataclass(
Raises:
PydanticUndefinedAnnotation: If `raise_error` is `True` and there is an undefined annotations.
"""
original_init = cls.__init__
if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)
if types_namespace is None:
types_namespace = _typing_extra.get_cls_types_namespace(cls)
set_dataclass_fields(cls, types_namespace)
typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
)
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied,
# and so that the mock validator is used if building was deferred:
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
sig = generate_dataclass_signature(cls)
cls.__init__ = __init__ # type: ignore
cls.__signature__ = sig # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper)
if not _force_build and config_wrapper.defer_build:
set_dataclass_mocks(cls)
return False
if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)
typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
ns_resolver=ns_resolver,
typevars_map=typevars_map,
)
# set __signature__ attr only for the class, but not for its instances
# (because instances can define `__call__`, and `inspect.signature` shouldn't
# use the `__signature__` attribute and instead generate from `__call__`).
cls.__signature__ = LazyClassAttribute(
'__signature__',
partial(
generate_pydantic_signature,
# It's important that we reference the `original_init` here,
# as it is the one synthesized by the stdlib `dataclass` module:
init=original_init,
fields=cls.__pydantic_fields__, # type: ignore
validate_by_name=config_wrapper.validate_by_name,
extra=config_wrapper.extra,
is_dataclass=True,
),
)
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
try:
schema = gen_schema.generate_schema(cls)
if get_core_schema:
schema = get_core_schema(
cls,
CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
),
)
else:
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_dataclass_mocks(cls, f'`{e.name}`')
set_dataclass_mock_validator(cls, cls.__name__, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(title=cls.__name__)
core_config = config_wrapper.core_config(cls)
try:
schema = gen_schema.clean_schema(schema)
except InvalidSchemaError:
set_dataclass_mocks(cls)
schema = gen_schema.collect_definitions(schema)
if collect_invalid_schemas(schema):
set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types')
return False
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_core_schema__ = schema = validate_core_schema(schema)
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
schema, core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
if config_wrapper.validate_assignment:
@wraps(cls.__setattr__)
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
validator.validate_assignment(instance, field, value)
def validated_setattr(instance: Any, __field: str, __value: str) -> None:
validator.validate_assignment(instance, __field, __value)
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
cls.__pydantic_complete__ = True
return True
def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature:
"""Generate signature for a pydantic dataclass.
This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses.
If we change this eventually, we should make this function's logic more closely mirror that from
`pydantic._internal._model_construction.generate_model_signature`.
Args:
cls: The dataclass.
Returns:
The signature.
"""
sig = signature(cls)
final_params: dict[str, Parameter] = {}
for param in sig.parameters.values():
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any
# Replace the field name with the alias if present
name = param.name
alias = param_default.alias
validation_alias = param_default.validation_alias
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
name = alias
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
name = validation_alias
# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = inspect.Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
param = param.replace(annotation=annotation, name=name, default=default)
final_params[param.name] = param
return Signature(parameters=list(final_params.values()), return_annotation=None)
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
@@ -209,7 +245,7 @@ def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
- `_cls` does not have any annotations that are not dataclass fields
e.g.
```python
```py
import dataclasses
import pydantic.dataclasses

View File

@@ -1,30 +1,31 @@
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
from __future__ import annotations as _annotations
import types
from collections import deque
from collections.abc import Iterable
from dataclasses import dataclass, field
from functools import cached_property, partial, partialmethod
from functools import partial, partialmethod
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
from typing_extensions import TypeAlias, is_typeddict
from pydantic_core import PydanticUndefined, core_schema
from typing_extensions import Literal, TypeAlias, is_typeddict
from ..errors import PydanticUserError
from ..fields import ComputedFieldInfo
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._namespace_utils import GlobalsNamespace, MappingNamespace
from ._typing_extra import get_function_type_hints
from ._utils import can_be_positional
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo
from ..functional_validators import FieldValidatorModes
try:
from functools import cached_property # type: ignore
except ImportError:
# python 3.7
cached_property = None
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
@@ -60,9 +61,6 @@ class FieldValidatorDecoratorInfo:
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
check_fields: Whether to check that the fields actually exist on the model.
json_schema_input_type: The input type of the function. This is only used to generate
the appropriate JSON Schema (in validation mode) and can only specified
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
"""
decorator_repr: ClassVar[str] = '@field_validator'
@@ -70,7 +68,6 @@ class FieldValidatorDecoratorInfo:
fields: tuple[str, ...]
mode: FieldValidatorModes
check_fields: bool | None
json_schema_input_type: Any
@dataclass(**slots_true)
@@ -135,7 +132,7 @@ class ModelValidatorDecoratorInfo:
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@model_validator'.
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
mode: The proposed serializer mode.
"""
@@ -143,7 +140,7 @@ class ModelValidatorDecoratorInfo:
mode: Literal['wrap', 'before', 'after']
DecoratorInfo: TypeAlias = """Union[
DecoratorInfo = Union[
ValidatorDecoratorInfo,
FieldValidatorDecoratorInfo,
RootValidatorDecoratorInfo,
@@ -151,7 +148,7 @@ DecoratorInfo: TypeAlias = """Union[
ModelSerializerDecoratorInfo,
ModelValidatorDecoratorInfo,
ComputedFieldInfo,
]"""
]
ReturnType = TypeVar('ReturnType')
DecoratedType: TypeAlias = (
@@ -186,12 +183,6 @@ class PydanticDescriptorProxy(Generic[ReturnType]):
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
self.wrapped = getattr(self.wrapped, name)(func)
if isinstance(self.wrapped, property):
# update ComputedFieldInfo.wrapped_property
from ..fields import ComputedFieldInfo
if isinstance(self.decorator_info, ComputedFieldInfo):
self.decorator_info.wrapped_property = self.wrapped
return self
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
@@ -203,11 +194,11 @@ class PydanticDescriptorProxy(Generic[ReturnType]):
def __set_name__(self, instance: Any, name: str) -> None:
if hasattr(self.wrapped, '__set_name__'):
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
self.wrapped.__set_name__(instance, name)
def __getattr__(self, name: str, /) -> Any:
def __getattr__(self, __name: str) -> Any:
"""Forward checks for __isabstractmethod__ and such."""
return getattr(self.wrapped, name)
return getattr(self.wrapped, __name)
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
@@ -497,8 +488,6 @@ class DecoratorInfos:
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
else:
from ..fields import ComputedFieldInfo
isinstance(var_value, ComputedFieldInfo)
res.computed_fields[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=None, info=info
@@ -509,7 +498,7 @@ class DecoratorInfos:
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
# and replace them with the thing they are wrapping (see the other setattr call below)
# which allows validator class methods to also function as regular class methods
model_dc.__pydantic_decorators__ = res
setattr(model_dc, '__pydantic_decorators__', res)
for name, value in to_replace:
setattr(model_dc, name, value)
return res
@@ -529,11 +518,12 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
"""
try:
sig = signature(validator)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present:
except ValueError:
# builtins and some C extensions don't have signatures
# assume that they don't take an info argument and only take a single argument
# e.g. `str.strip` or `datetime.datetime`
return False
n_positional = count_positional_required_params(sig)
n_positional = count_positional_params(sig)
if mode == 'wrap':
if n_positional == 3:
return True
@@ -552,7 +542,9 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
)
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
def inspect_field_serializer(
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.
@@ -561,21 +553,18 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
computed_field: When serializer is applied on computed_field. It doesn't require
info signature.
Returns:
Tuple of (is_field_serializer, info_arg).
"""
try:
sig = signature(serializer)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present and this is not a method:
return (False, False)
sig = signature(serializer)
first = next(iter(sig.parameters.values()), None)
is_field_serializer = first is not None and first.name == 'self'
n_positional = count_positional_required_params(sig)
n_positional = count_positional_params(sig)
if is_field_serializer:
# -1 to correct for self parameter
info_arg = _serializer_info_arg(mode, n_positional - 1)
@@ -587,8 +576,13 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
if info_arg and computed_field:
raise PydanticUserError(
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
)
return is_field_serializer, info_arg
else:
return is_field_serializer, info_arg
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
@@ -603,13 +597,8 @@ def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['
Returns:
info_arg
"""
try:
sig = signature(serializer)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no info argument is present:
return False
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
@@ -637,7 +626,7 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
)
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
@@ -650,18 +639,18 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
if mode == 'plain':
if n_positional == 1:
# (input_value: Any, /) -> Any
# (__input_value: Any) -> Any
return False
elif n_positional == 2:
# (model: Any, input_value: Any, /) -> Any
# (__model: Any, __input_value: Any) -> Any
return True
else:
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
if n_positional == 2:
# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
return False
elif n_positional == 3:
# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
return True
return None
@@ -722,25 +711,34 @@ def unwrap_wrapped_function(
unwrap_class_static_method: bool = True,
) -> Any:
"""Recursively unwraps a wrapped function until the underlying function is reached.
This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.
This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod.
Args:
func: The function to unwrap.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators, otherwise don't.
decorators.
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
decorators. If False, only unwrap partial and partialmethod decorators.
Returns:
The underlying function of the wrapped function.
"""
# Define the types we want to check against as a single tuple.
unwrap_types = (
(property, cached_property)
+ ((partial, partialmethod) if unwrap_partial else ())
+ ((staticmethod, classmethod) if unwrap_class_static_method else ())
)
all: set[Any] = {property}
while isinstance(func, unwrap_types):
if unwrap_partial:
all.update({partial, partialmethod})
try:
from functools import cached_property # type: ignore
except ImportError:
cached_property = type('', (), {})
else:
all.add(cached_property)
if unwrap_class_static_method:
all.update({staticmethod, classmethod})
while isinstance(func, tuple(all)):
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
func = func.__func__
elif isinstance(func, (partial, partialmethod)):
@@ -755,72 +753,38 @@ def unwrap_wrapped_function(
return func
_function_like = (
partial,
partialmethod,
types.FunctionType,
types.BuiltinFunctionType,
types.MethodType,
types.WrapperDescriptorType,
types.MethodWrapperType,
types.MemberDescriptorType,
)
def get_function_return_type(
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
) -> Any:
"""Get the function return type.
def get_callable_return_type(
callable_obj: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any | PydanticUndefinedType:
"""Get the callable return type.
It gets the return type from the type annotation if `explicit_return_type` is `None`.
Otherwise, it returns `explicit_return_type`.
Args:
callable_obj: The callable to analyze.
globalns: The globals namespace to use during type annotation evaluation.
localns: The locals namespace to use during type annotation evaluation.
func: The function to get its return type.
explicit_return_type: The explicit return type.
types_namespace: The types namespace, defaults to `None`.
Returns:
The function return type.
"""
if isinstance(callable_obj, type):
# types are callables, and we assume the return type
# is the type itself (e.g. `int()` results in an instance of `int`).
return callable_obj
if not isinstance(callable_obj, _function_like):
call_func = getattr(type(callable_obj), '__call__', None) # noqa: B004
if call_func is not None:
callable_obj = call_func
hints = get_function_type_hints(
unwrap_wrapped_function(callable_obj),
include_keys={'return'},
globalns=globalns,
localns=localns,
)
return hints.get('return', PydanticUndefined)
if explicit_return_type is PydanticUndefined:
# try to get it from the type annotation
hints = get_function_type_hints(
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
)
return hints.get('return', PydanticUndefined)
else:
return explicit_return_type
def count_positional_required_params(sig: Signature) -> int:
"""Get the number of positional (required) arguments of a signature.
def count_positional_params(sig: Signature) -> int:
return sum(1 for param in sig.parameters.values() if can_be_positional(param))
This function should only be used to inspect signatures of validation and serialization functions.
The first argument (the value being serialized or validated) is counted as a required argument
even if a default value exists.
Returns:
The number of positional arguments of a signature.
"""
parameters = list(sig.parameters.values())
return sum(
1
for param in parameters
if can_be_positional(param)
# First argument is the value being validated/serialized, and can have a default value
# (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
# for instance) should be required, and thus without any default value.
and (param.default is Parameter.empty or param is parameters[0])
)
def can_be_positional(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
def ensure_property(f: Any) -> Any:

View File

@@ -1,45 +1,49 @@
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
from __future__ import annotations as _annotations
from inspect import Parameter, signature
from typing import Any, Union, cast
from typing import Any, Dict, Tuple, Union, cast
from pydantic_core import core_schema
from typing_extensions import Protocol
from ..errors import PydanticUserError
from ._utils import can_be_positional
from ._decorators import can_be_positional
class V1OnlyValueValidator(Protocol):
"""A simple validator, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any) -> Any: ...
def __call__(self, __value: Any) -> Any:
...
class V1ValidatorWithValues(Protocol):
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...
def __call__(self, __value: Any, values: dict[str, Any]) -> Any:
...
class V1ValidatorWithValuesKwOnly(Protocol):
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any:
...
class V1ValidatorWithKwargs(Protocol):
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, **kwargs: Any) -> Any: ...
def __call__(self, __value: Any, **kwargs: Any) -> Any:
...
class V1ValidatorWithValuesAndKwargs(Protocol):
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any:
...
V1Validator = Union[
@@ -105,21 +109,23 @@ def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithI
return wrapper2
RootValidatorValues = dict[str, Any]
RootValidatorValues = Dict[str, Any]
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
RootValidatorFieldsTuple = tuple[Any, ...]
RootValidatorFieldsTuple = Tuple[Any, ...]
class V1RootValidatorFunction(Protocol):
"""A simple root validator, supported for V1 validators and V2 validators."""
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues:
...
class V2CoreBeforeRootValidator(Protocol):
"""V2 validator with mode='before'."""
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues:
...
class V2CoreAfterRootValidator(Protocol):
@@ -127,7 +133,8 @@ class V2CoreAfterRootValidator(Protocol):
def __call__(
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
) -> RootValidatorFieldsTuple: ...
) -> RootValidatorFieldsTuple:
...
def make_v1_generic_root_validator(

View File

@@ -1,19 +1,19 @@
from __future__ import annotations as _annotations
from collections.abc import Hashable, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import Any, Hashable, Sequence
from pydantic_core import CoreSchema, core_schema
from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaField,
collect_definitions,
simplify_schema_references,
)
if TYPE_CHECKING:
from ..types import Discriminator
from ._core_metadata import CoreMetadata
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
class MissingDefinitionForUnionRef(Exception):
@@ -26,15 +26,39 @@ class MissingDefinitionForUnionRef(Exception):
super().__init__(f'Missing definition for ref {self.ref!r}')
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
metadata = cast('CoreMetadata', schema.setdefault('metadata', {}))
metadata['pydantic_internal_union_discriminator'] = discriminator
def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
assert metadata is not None
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions: dict[str, CoreSchema] | None = None
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
if 'metadata' in s:
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
return s
s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s
metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s
return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
def apply_discriminator(
schema: core_schema.CoreSchema,
discriminator: str | Discriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
@@ -59,14 +83,6 @@ def apply_discriminator(
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import Discriminator
if isinstance(discriminator, Discriminator):
if isinstance(discriminator.discriminator, str):
discriminator = discriminator.discriminator
else:
return discriminator._convert_schema(schema)
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
@@ -134,7 +150,7 @@ class _ApplyInferredDiscriminator:
# in the output TaggedUnionSchema that will replace the union from the input schema
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
# `_used` is changed to True after applying the discriminator to prevent accidental reuse
# `_used` is changed to True after applying the discriminator to prevent accidental re-use
self._used = False
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
@@ -160,11 +176,16 @@ class _ApplyInferredDiscriminator:
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
self.definitions.update(collect_definitions(schema))
assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
return schema
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
@@ -234,10 +255,6 @@ class _ApplyInferredDiscriminator:
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
@@ -249,6 +266,10 @@ class _ApplyInferredDiscriminator:
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
elif choice['type'] not in {
'model',
'typed-dict',
@@ -256,16 +277,12 @@ class _ApplyInferredDiscriminator:
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
if choice['type'] == 'list':
err_str += (
' If you are making use of a list of union types, make sure the discriminator is applied to the '
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
)
raise TypeError(err_str)
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
else:
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
# In this case, this inner tagged-union is compatible with the outer tagged-union,
@@ -299,10 +316,13 @@ class _ApplyInferredDiscriminator:
"""
if choice['type'] == 'definitions':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'function-plain':
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
elif _core_utils.is_function_with_inner_schema(choice):
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'lax-or-strict':
return sorted(
set(
@@ -353,13 +373,10 @@ class _ApplyInferredDiscriminator:
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
else:
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
if choice['type'] == 'list':
err_str += (
' If you are making use of a list of union types, make sure the discriminator is applied to the '
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
)
raise TypeError(err_str)
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
def _infer_discriminator_values_for_typed_dict_choice(
self, choice: core_schema.TypedDictSchema, source_name: str | None = None

View File

@@ -1,108 +0,0 @@
"""Utilities related to attribute docstring extraction."""
from __future__ import annotations
import ast
import inspect
import textwrap
from typing import Any
class DocstringVisitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.target: str | None = None
self.attrs: dict[str, str] = {}
self.previous_node_type: type[ast.AST] | None = None
def visit(self, node: ast.AST) -> Any:
node_result = super().visit(node)
self.previous_node_type = type(node)
return node_result
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
if isinstance(node.target, ast.Name):
self.target = node.target.id
def visit_Expr(self, node: ast.Expr) -> Any:
if (
isinstance(node.value, ast.Constant)
and isinstance(node.value.value, str)
and self.previous_node_type is ast.AnnAssign
):
docstring = inspect.cleandoc(node.value.value)
if self.target:
self.attrs[self.target] = docstring
self.target = None
def _dedent_source_lines(source: list[str]) -> str:
# Required for nested class definitions, e.g. in a function block
dedent_source = textwrap.dedent(''.join(source))
if dedent_source.startswith((' ', '\t')):
# We are in the case where there's a dedented (usually multiline) string
# at a lower indentation level than the class itself. We wrap our class
# in a function as a workaround.
dedent_source = f'def dedent_workaround():\n{dedent_source}'
return dedent_source
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
frame = inspect.currentframe()
while frame:
if inspect.getmodule(frame) is inspect.getmodule(cls):
lnum = frame.f_lineno
try:
lines, _ = inspect.findsource(frame)
except OSError: # pragma: no cover
# Source can't be retrieved (maybe because running in an interactive terminal),
# we don't want to error here.
pass
else:
block_lines = inspect.getblock(lines[lnum - 1 :])
dedent_source = _dedent_source_lines(block_lines)
try:
block_tree = ast.parse(dedent_source)
except SyntaxError:
pass
else:
stmt = block_tree.body[0]
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
# `_dedent_source_lines` wrapped the class around the workaround function
stmt = stmt.body[0]
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
return block_lines
frame = frame.f_back
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
"""Map model attributes and their corresponding docstring.
Args:
cls: The class of the Pydantic model to inspect.
use_inspect: Whether to skip usage of frames to find the object and use
the `inspect` module instead.
Returns:
A mapping containing attribute names and their corresponding docstring.
"""
if use_inspect:
# Might not work as expected if two classes have the same name in the same source file.
try:
source, _ = inspect.getsourcelines(cls)
except OSError: # pragma: no cover
return {}
else:
source = _extract_source_from_frame(cls)
if not source:
return {}
dedent_source = _dedent_source_lines(source)
visitor = DocstringVisitor()
visitor.visit(ast.parse(dedent_source))
return visitor.attrs

View File

@@ -1,104 +1,92 @@
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
from __future__ import annotations as _annotations
import dataclasses
import sys
import warnings
from collections.abc import Mapping
from copy import copy
from functools import cache
from inspect import Parameter, ismethoddescriptor, signature
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from typing import TYPE_CHECKING, Any
from annotated_types import BaseMetadata
from pydantic_core import PydanticUndefined
from typing_extensions import TypeIs, get_origin
from typing_inspection import typing_objects
from typing_inspection.introspection import AnnotationSource
from pydantic import PydanticDeprecatedSince211
from pydantic.errors import PydanticUserError
from . import _generics, _typing_extra
from . import _typing_extra
from ._config import ConfigWrapper
from ._docs_extraction import extract_docstrings_from_cls
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._namespace_utils import NsResolver
from ._repr import Representation
from ._utils import can_be_positional
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
if TYPE_CHECKING:
from annotated_types import BaseMetadata
from ..fields import FieldInfo
from ..main import BaseModel
from ._dataclasses import PydanticDataclass, StandardDataclass
from ._dataclasses import StandardDataclass
from ._decorators import DecoratorInfos
def get_type_hints_infer_globalns(
obj: Any,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
"""Gets type hints for an object by inferring the global namespace.
It uses the `typing.get_type_hints`, The only thing that we do here is fetching
global namespace from `obj.__module__` if it is not `None`.
Args:
obj: The object to get its type hints.
localns: The local namespaces.
include_extras: Whether to recursively include annotation metadata.
Returns:
The object type hints.
"""
module_name = getattr(obj, '__module__', None)
globalns: dict[str, Any] | None = None
if module_name:
try:
globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
class PydanticMetadata(Representation):
"""Base class for annotation markers like `Strict`."""
__slots__ = ()
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
class PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
"""Pydantic general metada like `max_digits`."""
Args:
**metadata: The metadata to add.
Returns:
The new `_PydanticGeneralMetadata` class.
"""
return _general_metadata_cls()(metadata) # type: ignore
@cache
def _general_metadata_cls() -> type[BaseMetadata]:
"""Do it this way to avoid importing `annotated_types` at import time."""
from annotated_types import BaseMetadata
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
"""Pydantic general metadata like `max_digits`."""
def __init__(self, metadata: Any):
self.__dict__ = metadata
return _PydanticGeneralMetadata # type: ignore
def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], use_inspect: bool = False) -> None:
fields_docs = extract_docstrings_from_cls(cls, use_inspect=use_inspect)
for ann_name, field_info in fields.items():
if field_info.description is None and ann_name in fields_docs:
field_info.description = fields_docs[ann_name]
def __init__(self, **metadata: Any):
self.__dict__ = metadata
def collect_model_fields( # noqa: C901
cls: type[BaseModel],
bases: tuple[type[Any], ...],
config_wrapper: ConfigWrapper,
ns_resolver: NsResolver | None,
types_namespace: dict[str, Any] | None,
*,
typevars_map: Mapping[TypeVar, Any] | None = None,
typevars_map: dict[Any, Any] | None = None,
) -> tuple[dict[str, FieldInfo], set[str]]:
"""Collect the fields and class variables names of a nascent Pydantic model.
"""Collect the fields of a nascent pydantic model.
The fields collection process is *lenient*, meaning it won't error if string annotations
fail to evaluate. If this happens, the original annotation (and assigned value, if any)
is stored on the created `FieldInfo` instance.
Also collect the names of any ClassVars present in the type hints.
The `rebuild_model_fields()` should be called at a later point (e.g. when rebuilding the model),
and will make use of these stored attributes.
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
Args:
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
ns_resolver: Namespace resolver to use when getting model annotations.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
Returns:
A two-tuple containing model fields and class variables names.
A tuple contains fields and class variables.
Raises:
NameError:
@@ -106,16 +94,9 @@ def collect_model_fields( # noqa: C901
- If there is a field other than `root` in `RootModel`.
- If a field shadows an attribute in the parent model.
"""
BaseModel = import_cached_base_model()
FieldInfo_ = import_cached_field_info()
from ..fields import FieldInfo
bases = cls.__bases__
parent_fields_lookup: dict[str, FieldInfo] = {}
for base in reversed(bases):
if model_fields := getattr(base, '__pydantic_fields__', None):
parent_fields_lookup.update(model_fields)
type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
# annotations is only used for finding fields in parent classes
@@ -123,50 +104,39 @@ def collect_model_fields( # noqa: C901
fields: dict[str, FieldInfo] = {}
class_vars: set[str] = set()
for ann_name, (ann_type, evaluated) in type_hints.items():
for ann_name, ann_type in type_hints.items():
if ann_name == 'model_config':
# We never want to treat `model_config` as a field
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
# protected namespaces (where `model_config` might be allowed as a field name)
continue
for protected_namespace in config_wrapper.protected_namespaces:
ns_violation: bool = False
if isinstance(protected_namespace, Pattern):
ns_violation = protected_namespace.match(ann_name) is not None
elif isinstance(protected_namespace, str):
ns_violation = ann_name.startswith(protected_namespace)
if ns_violation:
if ann_name.startswith(protected_namespace):
for b in bases:
if hasattr(b, ann_name):
if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
from ..main import BaseModel
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
raise NameError(
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
f' of protected namespace "{protected_namespace}".'
)
else:
valid_namespaces = ()
for pn in config_wrapper.protected_namespaces:
if isinstance(pn, Pattern):
if not pn.match(ann_name):
valid_namespaces += (f're.compile({pn.pattern})',)
else:
if not ann_name.startswith(pn):
valid_namespaces += (pn,)
valid_namespaces = tuple(
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
)
warnings.warn(
f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
'\n\nYou may be able to resolve this warning by setting'
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
UserWarning,
)
if _typing_extra.is_classvar_annotation(ann_type):
if is_classvar(ann_type):
class_vars.add(ann_name)
continue
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
class_vars.add(ann_name)
continue
assigned_value = getattr(cls, ann_name, PydanticUndefined)
if not is_valid_field_name(ann_name):
continue
if cls.__pydantic_root_model__ and ann_name != 'root':
@@ -175,7 +145,7 @@ def collect_model_fields( # noqa: C901
)
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
# "... shadows an attribute" warnings
# "... shadows an attribute" errors
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
for base in bases:
dataclass_fields = {
@@ -183,77 +153,42 @@ def collect_model_fields( # noqa: C901
}
if hasattr(base, ann_name):
if base is generic_origin:
# Don't warn when "shadowing" of attributes in parametrized generics
# Don't error when "shadowing" of attributes in parametrized generics
continue
if ann_name in dataclass_fields:
# Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
# Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
# on the class instance.
continue
if ann_name not in annotations:
# Don't warn when a field exists in a parent class but has not been defined in the current class
continue
warnings.warn(
f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
f'"{base.__qualname__}"',
f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ',
UserWarning,
)
if assigned_value is PydanticUndefined: # no assignment, just a plain annotation
if ann_name in annotations or ann_name not in parent_fields_lookup:
# field is either:
# - present in the current model's annotations (and *not* from parent classes)
# - not found on any base classes; this seems to be caused by fields bot getting
# generated due to models not being fully defined while initializing recursive models.
# Nothing stops us from just creating a `FieldInfo` for this type hint, so we do this.
field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS)
if not evaluated:
field_info._complete = False
# Store the original annotation that should be used to rebuild
# the field info later:
field_info._original_annotation = ann_type
try:
default = getattr(cls, ann_name, PydanticUndefined)
if default is PydanticUndefined:
raise AttributeError
except AttributeError:
if ann_name in annotations:
field_info = FieldInfo.from_annotation(ann_type)
else:
# The field was present on one of the (possibly multiple) base classes
# copy the field to make sure typevar substitutions don't cause issues with the base classes
field_info = copy(parent_fields_lookup[ann_name])
else: # An assigned value is present (either the default value, or a `Field()` function)
_warn_on_nested_alias_in_annotation(ann_type, ann_name)
if isinstance(assigned_value, FieldInfo_) and ismethoddescriptor(assigned_value.default):
# `assigned_value` was fetched using `getattr`, which triggers a call to `__get__`
# for descriptors, so we do the same if the `= field(default=...)` form is used.
# Note that we only do this for method descriptors for now, we might want to
# extend this to any descriptor in the future (by simply checking for
# `hasattr(assigned_value.default, '__get__')`).
assigned_value.default = assigned_value.default.__get__(None, cls)
# The `from_annotated_attribute()` call below mutates the assigned `Field()`, so make a copy:
original_assignment = (
assigned_value._copy() if not evaluated and isinstance(assigned_value, FieldInfo_) else assigned_value
)
field_info = FieldInfo_.from_annotated_attribute(ann_type, assigned_value, _source=AnnotationSource.CLASS)
# Store the original annotation and assignment value that should be used to rebuild the field info later.
# Note that the assignment is always stored as the annotation might contain a type var that is later
# parameterized with an unknown forward reference (and we'll need it to rebuild the field info):
field_info._original_assignment = original_assignment
if not evaluated:
field_info._complete = False
field_info._original_annotation = ann_type
elif 'final' in field_info._qualifiers and not field_info.is_required():
warnings.warn(
f'Annotation {ann_name!r} is marked as final and has a default value. Pydantic treats {ann_name!r} as a '
'class variable, but it will be considered as a normal field in V3 to be aligned with dataclasses. If you '
f'still want {ann_name!r} to be considered as a class variable, annotate it as: `ClassVar[<type>] = <default>.`',
category=PydanticDeprecatedSince211,
# Incorrect when `create_model` is used, but the chance that final with a default is used is low in that case:
stacklevel=4,
)
class_vars.add(ann_name)
continue
# if field has no default value and is not in __annotations__ this means that it is
# defined in a base class and we can take it from there
model_fields_lookup: dict[str, FieldInfo] = {}
for x in cls.__bases__[::-1]:
model_fields_lookup.update(getattr(x, 'model_fields', {}))
if ann_name in model_fields_lookup:
# The field was present on one of the (possibly multiple) base classes
# copy the field to make sure typevar substitutions don't cause issues with the base classes
field_info = copy(model_fields_lookup[ann_name])
else:
# The field was not found on any base classes; this seems to be caused by fields not getting
# generated thanks to models not being fully defined while initializing recursive models.
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
field_info = FieldInfo.from_annotation(ann_type)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, default)
# attributes which are fields are removed from the class namespace:
# 1. To match the behaviour of annotation-only fields
# 2. To avoid false positives in the NameError check above
@@ -266,250 +201,81 @@ def collect_model_fields( # noqa: C901
# to make sure the decorators have already been built for this exact class
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
if ann_name in decorators.computed_fields:
raise TypeError(
f'Field {ann_name!r} of class {cls.__name__!r} overrides symbol of same name in a parent class. '
'This override with a computed_field is incompatible.'
)
raise ValueError("you can't override a field with a computed field")
fields[ann_name] = field_info
if typevars_map:
for field in fields.values():
if field._complete:
field.apply_typevars_map(typevars_map)
field.apply_typevars_map(typevars_map, types_namespace)
if config_wrapper.use_attribute_docstrings:
_update_fields_from_docstrings(cls, fields)
return fields, class_vars
def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
FieldInfo = import_cached_field_info()
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
from ..fields import FieldInfo
args = getattr(ann_type, '__args__', None)
if args:
for anno_arg in args:
if typing_objects.is_annotated(get_origin(anno_arg)):
for anno_type_arg in _typing_extra.get_args(anno_arg):
if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
warnings.warn(
f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
UserWarning,
)
return
def rebuild_model_fields(
cls: type[BaseModel],
*,
ns_resolver: NsResolver,
typevars_map: Mapping[TypeVar, Any],
) -> dict[str, FieldInfo]:
"""Rebuild the (already present) model fields by trying to reevaluate annotations.
This function should be called whenever a model with incomplete fields is encountered.
Raises:
NameError: If one of the annotations failed to evaluate.
Note:
This function *doesn't* mutate the model fields in place, as it can be called during
schema generation, where you don't want to mutate other model's fields.
"""
FieldInfo_ = import_cached_field_info()
rebuilt_fields: dict[str, FieldInfo] = {}
with ns_resolver.push(cls):
for f_name, field_info in cls.__pydantic_fields__.items():
if field_info._complete:
rebuilt_fields[f_name] = field_info
else:
existing_desc = field_info.description
ann = _typing_extra.eval_type(
field_info._original_annotation,
*ns_resolver.types_namespace,
)
ann = _generics.replace_types(ann, typevars_map)
if (assign := field_info._original_assignment) is PydanticUndefined:
new_field = FieldInfo_.from_annotation(ann, _source=AnnotationSource.CLASS)
else:
new_field = FieldInfo_.from_annotated_attribute(ann, assign, _source=AnnotationSource.CLASS)
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
new_field.description = new_field.description if new_field.description is not None else existing_desc
rebuilt_fields[f_name] = new_field
return rebuilt_fields
if not is_finalvar(type_):
return False
elif val is PydanticUndefined:
return False
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
return False
else:
return True
def collect_dataclass_fields(
cls: type[StandardDataclass],
*,
ns_resolver: NsResolver | None = None,
typevars_map: dict[Any, Any] | None = None,
config_wrapper: ConfigWrapper | None = None,
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.
Args:
cls: dataclass.
ns_resolver: Namespace resolver to use when getting dataclass annotations.
Defaults to an empty instance.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
config_wrapper: The config wrapper instance.
Returns:
The dataclass fields.
"""
FieldInfo_ = import_cached_field_info()
from ..fields import FieldInfo
fields: dict[str, FieldInfo] = {}
ns_resolver = ns_resolver or NsResolver()
dataclass_fields = cls.__dataclass_fields__
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
# The logic here is similar to `_typing_extra.get_cls_type_hints`,
# although we do it manually as stdlib dataclasses already have annotations
# collected in each class:
for base in reversed(cls.__mro__):
if not dataclasses.is_dataclass(base):
for ann_name, dataclass_field in dataclass_fields.items():
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
if is_classvar(ann_type):
continue
with ns_resolver.push(base):
for ann_name, dataclass_field in dataclass_fields.items():
if ann_name not in base.__dict__.get('__annotations__', {}):
# `__dataclass_fields__`contains every field, even the ones from base classes.
# Only collect the ones defined on `base`.
continue
if not dataclass_field.init and dataclass_field.default_factory == dataclasses.MISSING:
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue
globalns, localns = ns_resolver.types_namespace
ann_type, evaluated = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
if isinstance(dataclass_field.default, FieldInfo):
if dataclass_field.default.init_var:
# TODO: same note as above
continue
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)
fields[ann_name] = field_info
if _typing_extra.is_classvar_annotation(ann_type):
continue
if (
not dataclass_field.init
and dataclass_field.default is dataclasses.MISSING
and dataclass_field.default_factory is dataclasses.MISSING
):
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue
if isinstance(dataclass_field.default, FieldInfo_):
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)
# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo_.from_annotated_attribute(
ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS
)
field_info._original_assignment = dataclass_field.default
else:
field_info = FieldInfo_.from_annotated_attribute(
ann_type, dataclass_field, _source=AnnotationSource.DATACLASS
)
field_info._original_assignment = dataclass_field
if not evaluated:
field_info._complete = False
field_info._original_annotation = ann_type
fields[ann_name] = field_info
if field_info.default is not PydanticUndefined and isinstance(
getattr(cls, ann_name, field_info), FieldInfo_
):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)
if typevars_map:
for field in fields.values():
# We don't pass any ns, as `field.annotation`
# was already evaluated. TODO: is this method relevant?
# Can't we juste use `_generics.replace_types`?
field.apply_typevars_map(typevars_map)
if config_wrapper is not None and config_wrapper.use_attribute_docstrings:
_update_fields_from_docstrings(
cls,
fields,
# We can't rely on the (more reliable) frame inspection method
# for stdlib dataclasses:
use_inspect=not hasattr(cls, '__is_pydantic_dataclass__'),
)
field.apply_typevars_map(typevars_map, types_namespace)
return fields
def rebuild_dataclass_fields(
cls: type[PydanticDataclass],
*,
config_wrapper: ConfigWrapper,
ns_resolver: NsResolver,
typevars_map: Mapping[TypeVar, Any],
) -> dict[str, FieldInfo]:
"""Rebuild the (already present) dataclass fields by trying to reevaluate annotations.
This function should be called whenever a dataclass with incomplete fields is encountered.
Raises:
NameError: If one of the annotations failed to evaluate.
Note:
This function *doesn't* mutate the dataclass fields in place, as it can be called during
schema generation, where you don't want to mutate other dataclass's fields.
"""
FieldInfo_ = import_cached_field_info()
rebuilt_fields: dict[str, FieldInfo] = {}
with ns_resolver.push(cls):
for f_name, field_info in cls.__pydantic_fields__.items():
if field_info._complete:
rebuilt_fields[f_name] = field_info
else:
existing_desc = field_info.description
ann = _typing_extra.eval_type(
field_info._original_annotation,
*ns_resolver.types_namespace,
)
ann = _generics.replace_types(ann, typevars_map)
new_field = FieldInfo_.from_annotated_attribute(
ann,
field_info._original_assignment,
_source=AnnotationSource.DATACLASS,
)
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
new_field.description = new_field.description if new_field.description is not None else existing_desc
rebuilt_fields[f_name] = new_field
return rebuilt_fields
def is_valid_field_name(name: str) -> bool:
return not name.startswith('_')
def is_valid_privateattr_name(name: str) -> bool:
return name.startswith('_') and not name.startswith('__')
def takes_validated_data_argument(
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
) -> TypeIs[Callable[[dict[str, Any]], Any]]:
"""Whether the provided default factory callable has a validated data parameter."""
try:
sig = signature(default_factory)
except (ValueError, TypeError):
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
# In this case, we assume no data argument is present:
return False
parameters = list(sig.parameters.values())
return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty

View File

@@ -1,7 +1,6 @@
from __future__ import annotations as _annotations
from dataclasses import dataclass
from typing import Union
@dataclass
@@ -15,9 +14,3 @@ class PydanticRecursiveRef:
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
this class as the result of resolving a standard ForwardRef.
"""
def __or__(self, other):
return Union[self, other] # type: ignore
def __ror__(self, other):
return Union[other, self] # type: ignore

View File

@@ -4,21 +4,17 @@ import sys
import types
import typing
from collections import ChainMap
from collections.abc import Iterator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from itertools import zip_longest
from types import prepare_class
from typing import TYPE_CHECKING, Annotated, Any, TypeVar
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
from weakref import WeakValueDictionary
import typing_extensions
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from . import _typing_extra
from ._core_utils import get_type_ref
from ._forward_ref import PydanticRecursiveRef
from ._typing_extra import TypeVarType, typing_base
from ._utils import all_identical, is_model_class
if sys.version_info >= (3, 10):
@@ -27,7 +23,7 @@ if sys.version_info >= (3, 10):
if TYPE_CHECKING:
from ..main import BaseModel
GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
# Right now, to handle recursive generics, we some types must remain cached for brief periods without references.
@@ -38,25 +34,43 @@ GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
KT = TypeVar('KT')
VT = TypeVar('VT')
_LIMITED_DICT_SIZE = 100
if TYPE_CHECKING:
class LimitedDict(dict, MutableMapping[KT, VT]):
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
...
class LimitedDict(dict[KT, VT]):
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None:
self.size_limit = size_limit
super().__init__()
else:
def __setitem__(self, key: KT, value: VT, /) -> None:
super().__setitem__(key, value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for k in to_remove:
del self[k]
class LimitedDict(dict):
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
"""
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
self.size_limit = size_limit
super().__init__()
def __setitem__(self, __key: Any, __value: Any) -> None:
super().__setitem__(__key, __value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for key in to_remove:
del self[key]
def __class_getitem__(cls, *args: Any) -> Any:
# to avoid errors with 3.7
return cls
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
else:
GenericTypesCache = WeakValueDictionary
if TYPE_CHECKING:
@@ -94,13 +108,13 @@ else:
# and discover later on that we need to re-add all this infrastructure...
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
_GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None)
_GENERIC_TYPES_CACHE = GenericTypesCache()
class PydanticGenericMetadata(typing_extensions.TypedDict):
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
def create_generic_submodel(
@@ -157,7 +171,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
depth: The depth to get the frame.
Returns:
A tuple contains `module_name` and `called_globally`.
A tuple contains `module_nam` and `called_globally`.
Raises:
RuntimeError: If the function is not called inside a function.
@@ -175,7 +189,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
DictValues: type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
@@ -208,7 +222,7 @@ def get_origin(v: Any) -> Any:
return typing_extensions.get_origin(v)
def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
"""
@@ -221,11 +235,11 @@ def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
# So it is safe to access cls.__args__ and origin.__parameters__
args: tuple[Any, ...] = cls.__args__ # type: ignore
parameters: tuple[TypeVar, ...] = origin.__parameters__
parameters: tuple[TypeVarType, ...] = origin.__parameters__
return dict(zip(parameters, args))
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
with the `replace_types` function.
@@ -237,13 +251,10 @@ def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
generic_metadata = cls.__pydantic_generic_metadata__
origin = generic_metadata['origin']
args = generic_metadata['args']
if not args:
# No need to go into `iter_contained_typevars`:
return {}
return dict(zip(iter_contained_typevars(origin), args))
def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Args:
@@ -255,13 +266,13 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
`typevar_map` keys recursively replaced.
Example:
```python
from typing import List, Union
```py
from typing import List, Tuple, Union
from pydantic._internal._generics import replace_types
replace_types(tuple[str, Union[List[str], float]], {str: int})
#> tuple[int, Union[List[int], float]]
replace_types(Tuple[str, Union[List[str], float]], {str: int})
#> Tuple[int, Union[List[int], float]]
```
"""
if not type_map:
@@ -270,25 +281,25 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
type_args = get_args(type_)
origin_type = get_origin(type_)
if typing_objects.is_annotated(origin_type):
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
annotated_type = replace_types(annotated_type, type_map)
# TODO remove parentheses when we drop support for Python 3.10:
return Annotated[(annotated_type, *annotations)]
annotated = replace_types(annotated_type, type_map)
for annotation in annotations:
annotated = typing_extensions.Annotated[annotated, annotation]
return annotated
# Having type args is a good indicator that this is a typing special form
# instance or a generic alias of some sort.
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, _typing_extra.typing_base)
and not isinstance(origin_type, _typing_extra.typing_base)
and isinstance(type_, typing_base)
and not isinstance(origin_type, typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
@@ -296,24 +307,11 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
if is_union_origin(origin_type):
if any(typing_objects.is_any(arg) for arg in resolved_type_args):
# `Any | T` ~ `Any`:
resolved_type_args = (Any,)
# `Never | T` ~ `T`:
resolved_type_args = tuple(
arg
for arg in resolved_type_args
if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg))
)
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
# We also cannot use isinstance() since we have to compare types.
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
return _UnionGenericAlias(origin_type, resolved_type_args)
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
return origin_type[resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
@@ -329,8 +327,8 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, list):
resolved_list = [replace_types(element, type_map) for element in type_]
if isinstance(type_, (List, list)):
resolved_list = list(replace_types(element, type_map) for element in type_)
if all_identical(type_, resolved_list):
return type_
return resolved_list
@@ -340,57 +338,49 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
return type_map.get(type_, type_)
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
"""Return a mapping between the parameters of a generic model and the provided arguments during parameterization.
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
"""Checks if the type, or any of its arbitrary nested args, satisfy
`isinstance(<type>, isinstance_target)`.
"""
if isinstance(type_, isinstance_target):
return True
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
return has_instance_in_type(annotated_type, isinstance_target)
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
return True
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
if any(has_instance_in_type(element, isinstance_target) for element in type_):
return True
return False
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
"""Check the generic model parameters count is equal.
Args:
cls: The generic model.
parameters: A tuple of passed parameters to the generic model.
Raises:
TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments).
Example:
```python {test="skip" lint="skip"}
class Model[T, U, V = int](BaseModel): ...
map_generic_model_arguments(Model, (str, bytes))
#> {T: str, U: bytes, V: int}
map_generic_model_arguments(Model, (str,))
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
map_generic_model_arguments(Model, (str, bytes, int, complex))
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
```
Note:
This function is analogous to the private `typing._check_generic_specialization` function.
TypeError: If the passed parameters count is not equal to generic model parameters count.
"""
parameters = cls.__pydantic_generic_metadata__['parameters']
expected_len = len(parameters)
typevars_map: dict[TypeVar, Any] = {}
_missing = object()
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
if parameter is _missing:
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
if argument is _missing:
param = typing.cast(TypeVar, parameter)
try:
has_default = param.has_default()
except AttributeError:
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
has_default = False
if has_default:
# The default might refer to other type parameters. For an example, see:
# https://typing.readthedocs.io/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics
typevars_map[param] = replace_types(param.__default__, typevars_map)
else:
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters)
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
else:
param = typing.cast(TypeVar, parameter)
typevars_map[param] = argument
return typevars_map
actual = len(parameters)
expected = len(cls.__pydantic_generic_metadata__['parameters'])
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
@@ -421,8 +411,7 @@ def generic_recursion_self_type(
yield self_type
else:
previously_seen_type_refs.add(type_ref)
yield
previously_seen_type_refs.remove(type_ref)
yield None
finally:
if token:
_generic_recursion_cache.reset(token)
@@ -453,24 +442,14 @@ def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any)
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
equal.
"""
generic_types_cache = _GENERIC_TYPES_CACHE.get()
if generic_types_cache is None:
generic_types_cache = GenericTypesCache()
_GENERIC_TYPES_CACHE.set(generic_types_cache)
return generic_types_cache.get(_early_cache_key(parent, typevar_values))
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
def get_cached_generic_type_late(
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
) -> type[BaseModel] | None:
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
generic_types_cache = _GENERIC_TYPES_CACHE.get()
if (
generic_types_cache is None
): # pragma: no cover (early cache is guaranteed to run first and initialize the cache)
generic_types_cache = GenericTypesCache()
_GENERIC_TYPES_CACHE.set(generic_types_cache)
cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values))
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
if cached is not None:
set_cached_generic_type(parent, typevar_values, cached, origin, args)
return cached
@@ -486,17 +465,11 @@ def set_cached_generic_type(
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
two different keys.
"""
generic_types_cache = _GENERIC_TYPES_CACHE.get()
if (
generic_types_cache is None
): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache)
generic_types_cache = GenericTypesCache()
_GENERIC_TYPES_CACHE.set(generic_types_cache)
generic_types_cache[_early_cache_key(parent, typevar_values)] = type_
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
if len(typevar_values) == 1:
generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
if origin and args:
generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
def _union_orderings_key(typevar_values: Any) -> Any:
@@ -517,7 +490,7 @@ def _union_orderings_key(typevar_values: Any) -> Any:
for value in typevar_values:
args_data.append(_union_orderings_key(value))
return tuple(args_data)
elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)):
elif typing_extensions.get_origin(typevar_values) is typing.Union:
return get_args(typevar_values)
else:
return ()

View File

@@ -1,27 +0,0 @@
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
from __future__ import annotations
import subprocess
from pathlib import Path
def is_git_repo(dir: Path) -> bool:
"""Is the given directory version-controlled with git?"""
return dir.joinpath('.git').exists()
def have_git() -> bool: # pragma: no cover
"""Can we run the git executable?"""
try:
subprocess.check_output(['git', '--help'])
return True
except subprocess.CalledProcessError:
return False
except OSError:
return False
def git_revision(dir: Path) -> str:
"""Get the SHA-1 of the HEAD of a git repository."""
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()

View File

@@ -1,20 +0,0 @@
from functools import cache
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pydantic import BaseModel
from pydantic.fields import FieldInfo
@cache
def import_cached_base_model() -> type['BaseModel']:
from pydantic import BaseModel
return BaseModel
@cache
def import_cached_field_info() -> type['FieldInfo']:
from pydantic.fields import FieldInfo
return FieldInfo

View File

@@ -1,4 +1,7 @@
import sys
from typing import Any, Dict
dataclass_kwargs: Dict[str, Any]
# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):

View File

@@ -1,57 +1,42 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
from copy import copy
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable
from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
import annotated_types as at
from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python
from pydantic_core import core_schema as cs
from ._fields import PydanticMetadata
from ._import_utils import import_cached_field_info
from . import _validators
from ._fields import PydanticGeneralMetadata, PydanticMetadata
if TYPE_CHECKING:
pass
from ..annotated_handlers import GetJsonSchemaHandler
STRICT = {'strict'}
FAIL_FAST = {'fail_fast'}
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'}
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
ALLOW_INF_NAN = {'allow_inf_nan'}
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}
STR_CONSTRAINTS = {
*LENGTH_CONSTRAINTS,
*STRICT,
'strip_whitespace',
'to_lower',
'to_upper',
'pattern',
'coerce_numbers_to_str',
}
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}
LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
BOOL_CONSTRAINTS = STRICT
UUID_CONSTRAINTS = STRICT
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
LAX_OR_STRICT_CONSTRAINTS = STRICT
ENUM_CONSTRAINTS = STRICT
COMPLEX_CONSTRAINTS = STRICT
UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
@@ -68,33 +53,54 @@ SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
for constraint in STR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
for constraint in BYTES_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
for constraint in LIST_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
for constraint in TUPLE_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
for constraint in SET_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
for constraint in DICT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
for constraint in GENERATOR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
for constraint in FLOAT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
for constraint in INT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
for constraint in DATE_TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
for constraint in TIMEDELTA_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
for constraint in TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
for constraint in UNION_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
for constraint in URL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
for constraint in BOOL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))
constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
(STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
(BYTES_CONSTRAINTS, ('bytes',)),
(LIST_CONSTRAINTS, ('list',)),
(TUPLE_CONSTRAINTS, ('tuple',)),
(SET_CONSTRAINTS, ('set', 'frozenset')),
(DICT_CONSTRAINTS, ('dict',)),
(GENERATOR_CONSTRAINTS, ('generator',)),
(FLOAT_CONSTRAINTS, ('float',)),
(INT_CONSTRAINTS, ('int',)),
(DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
# TODO: this is a bit redundant, we could probably avoid some of these
(STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
(UNION_CONSTRAINTS, ('union',)),
(URL_CONSTRAINTS, ('url', 'multi-host-url')),
(BOOL_CONSTRAINTS, ('bool',)),
(UUID_CONSTRAINTS, ('uuid',)),
(LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
(ENUM_CONSTRAINTS, ('enum',)),
(DECIMAL_CONSTRAINTS, ('decimal',)),
(COMPLEX_CONSTRAINTS, ('complex',)),
]
for constraints, schemas in constraint_schema_pairings:
for c in constraints:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
js_schema = handler(s)
js_schema.update(f())
return js_schema
if 'metadata' in s:
metadata = s['metadata']
if 'pydantic_js_functions' in s:
metadata['pydantic_js_functions'].append(update_js_schema)
else:
metadata['pydantic_js_functions'] = [update_js_schema]
else:
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
def as_jsonable_value(v: Any) -> Any:
@@ -113,7 +119,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
An iterable of expanded annotations.
Example:
```python
```py
from annotated_types import Ge, Len
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
@@ -122,9 +128,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
#> [Ge(ge=4), MinLen(min_length=5)]
```
"""
import annotated_types as at
FieldInfo = import_cached_field_info()
from pydantic.fields import FieldInfo # circular import
for annotation in annotations:
if isinstance(annotation, at.GroupedMetadata):
@@ -143,28 +147,6 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
yield annotation
@lru_cache
def _get_at_to_constraint_map() -> dict[type, str]:
"""Return a mapping of annotated types to constraints.
Normally, we would define a mapping like this in the module scope, but we can't do that
because we don't permit module level imports of `annotated_types`, in an attempt to speed up
the import time of `pydantic`. We still only want to have this dictionary defined in one place,
so we use this function to cache the result.
"""
import annotated_types as at
return {
at.Gt: 'gt',
at.Ge: 'ge',
at.Lt: 'lt',
at.Le: 'le',
at.MultipleOf: 'multiple_of',
at.MinLen: 'min_length',
at.MaxLen: 'max_length',
}
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
Otherwise return `None`.
@@ -184,37 +166,14 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
Raises:
PydanticCustomError: If `Predicate` fails.
"""
import annotated_types as at
from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check
schema = schema.copy()
schema_update, other_metadata = collect_known_metadata([annotation])
schema_type = schema['type']
chain_schema_constraints: set[str] = {
'pattern',
'strip_whitespace',
'to_lower',
'to_upper',
'coerce_numbers_to_str',
}
chain_schema_steps: list[CoreSchema] = []
for constraint, value in schema_update.items():
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
# if it becomes necessary to handle more than one constraint
# in this recursive case with function-after or function-wrap, we should refactor
# this is a bit challenging because we sometimes want to apply constraints to the inner schema,
# whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function schema
return schema
# if we're allowed to apply constraint directly to the schema, like le to int, do that
if schema_type in allowed_schemas:
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
@@ -222,109 +181,145 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
schema[constraint] = value
continue
# else, apply a function after validator to the schema to enforce the corresponding constraint
if constraint in chain_schema_constraints:
def _apply_constraint_with_incompatibility_info(
value: Any, handler: cs.ValidatorFunctionWrapHandler
) -> Any:
try:
x = handler(value)
except ValidationError as ve:
# if the error is about the type, it's likely that the constraint is incompatible the type of the field
# for example, the following invalid schema wouldn't be caught during schema build, but rather at this point
# with a cryptic 'string_type' error coming from the string validator,
# that we'd rather express as a constraint incompatibility error (TypeError)
# Annotated[list[int], Field(pattern='abc')]
if 'type' in ve.errors()[0]['type']:
raise TypeError(
f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'" # noqa: B023
)
raise ve
return x
chain_schema_steps.append(
cs.no_info_wrap_validator_function(
_apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
)
if constraint == 'allow_inf_nan' and value is False:
return cs.no_info_after_validator_function(
_validators.forbid_inf_nan_check,
schema,
)
elif constraint in NUMERIC_VALIDATOR_LOOKUP:
if constraint in LENGTH_CONSTRAINTS:
inner_schema = schema
while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
inner_schema = inner_schema['schema'] # type: ignore
inner_schema_type = inner_schema['type']
if inner_schema_type == 'list' or (
inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list' # type: ignore
):
js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems'
else:
js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength'
else:
js_constraint_key = constraint
schema = cs.no_info_after_validator_function(
partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema
elif constraint == 'pattern':
# insert a str schema to make sure the regex engine matches
return cs.chain_schema(
[
schema,
cs.str_schema(pattern=value),
]
)
metadata = schema.get('metadata', {})
if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None:
metadata['pydantic_js_updates'] = {
**existing_json_schema_updates,
**{js_constraint_key: as_jsonable_value(value)},
}
else:
metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)}
schema['metadata'] = metadata
elif constraint == 'allow_inf_nan' and value is False:
schema = cs.no_info_after_validator_function(
forbid_inf_nan_check,
elif constraint == 'gt':
s = cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=value),
schema,
)
add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)})
return s
elif constraint == 'ge':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=value),
schema,
)
elif constraint == 'lt':
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=value),
schema,
)
elif constraint == 'le':
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=value),
schema,
)
elif constraint == 'multiple_of':
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=value),
schema,
)
elif constraint == 'min_length':
s = cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=value),
schema,
)
add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))})
return s
elif constraint == 'max_length':
s = cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=value),
schema,
)
add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))})
return s
elif constraint == 'strip_whitespace':
return cs.chain_schema(
[
schema,
cs.str_schema(strip_whitespace=True),
]
)
elif constraint == 'to_lower':
return cs.chain_schema(
[
schema,
cs.str_schema(to_lower=True),
]
)
elif constraint == 'to_upper':
return cs.chain_schema(
[
schema,
cs.str_schema(to_upper=True),
]
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
else:
# It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
# Most constraint errors are caught at runtime during attempted application
raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')
for annotation in other_metadata:
if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
constraint = at_to_constraint_map[annotation_type]
validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint)
if validator is None:
raise ValueError(f'Unknown constraint {constraint}')
schema = cs.no_info_after_validator_function(
partial(validator, {constraint: getattr(annotation, constraint)}), schema
if isinstance(annotation, at.Gt):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=annotation.gt),
schema,
)
continue
elif isinstance(annotation, (at.Predicate, at.Not)):
predicate_name = f'{annotation.func.__qualname__}' if hasattr(annotation.func, '__qualname__') else ''
elif isinstance(annotation, at.Ge):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
schema,
)
elif isinstance(annotation, at.Lt):
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=annotation.lt),
schema,
)
elif isinstance(annotation, at.Le):
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=annotation.le),
schema,
)
elif isinstance(annotation, at.MultipleOf):
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
schema,
)
elif isinstance(annotation, at.MinLen):
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif isinstance(annotation, at.MaxLen):
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''
def val_func(v: Any) -> Any:
predicate_satisfied = annotation.func(v) # noqa: B023
# annotation.func may also raise an exception, let it pass through
if isinstance(annotation, at.Predicate): # noqa: B023
if not predicate_satisfied:
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name} failed', # type: ignore # noqa: B023
)
else:
if predicate_satisfied:
raise PydanticCustomError(
'not_operation_failed',
f'Not of {predicate_name} failed', # type: ignore # noqa: B023
)
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name}failed', # type: ignore
)
return v
schema = cs.no_info_after_validator_function(val_func, schema)
else:
# ignore any other unknown metadata
return None
if chain_schema_steps:
chain_schema_steps = [schema] + chain_schema_steps
return cs.chain_schema(chain_schema_steps)
return cs.no_info_after_validator_function(val_func, schema)
# ignore any other unknown metadata
return None
return schema
@@ -339,7 +334,7 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
A tuple contains a dict of known metadata and a list of unknown annotations.
Example:
```python
```py
from annotated_types import Gt, Len
from pydantic._internal._known_annotated_metadata import collect_known_metadata
@@ -352,15 +347,29 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
res: dict[str, Any] = {}
remaining: list[Any] = []
for annotation in annotations:
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
if isinstance(annotation, PydanticMetadata):
# Do we really want to consume any `BaseMetadata`?
# It does let us give a better error when there is an annotation that doesn't apply
# But it seems dangerous!
if isinstance(annotation, PydanticGeneralMetadata):
res.update(annotation.__dict__)
elif isinstance(annotation, PydanticMetadata):
res.update(annotation.__dict__)
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
constraint = at_to_constraint_map[annotation_type]
res[constraint] = getattr(annotation, constraint)
elif isinstance(annotation, at.MinLen):
res.update({'min_length': annotation.min_length})
elif isinstance(annotation, at.MaxLen):
res.update({'max_length': annotation.max_length})
elif isinstance(annotation, at.Gt):
res.update({'gt': annotation.gt})
elif isinstance(annotation, at.Ge):
res.update({'ge': annotation.ge})
elif isinstance(annotation, at.Lt):
res.update({'lt': annotation.lt})
elif isinstance(annotation, at.Le):
res.update({'le': annotation.le})
elif isinstance(annotation, at.MultipleOf):
res.update({'multiple_of': annotation.multiple_of})
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
# also support PydanticMetadata classes being used without initialisation,
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`

View File

@@ -1,71 +1,18 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
from pydantic_core import SchemaSerializer, SchemaValidator
from typing_extensions import Literal
from ..errors import PydanticErrorCodes, PydanticUserError
from ..plugin._schema_validator import PluggableSchemaValidator
if TYPE_CHECKING:
from ..dataclasses import PydanticDataclass
from ..main import BaseModel
from ..type_adapter import TypeAdapter
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
T = TypeVar('T')
class MockCoreSchema(Mapping[str, Any]):
"""Mocker for `pydantic_core.CoreSchema` which optionally attempts to
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
"""
__slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'
def __init__(
self,
error_message: str,
*,
code: PydanticErrorCodes,
attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
) -> None:
self._error_message = error_message
self._code: PydanticErrorCodes = code
self._attempt_rebuild = attempt_rebuild
self._built_memo: CoreSchema | None = None
def __getitem__(self, key: str) -> Any:
return self._get_built().__getitem__(key)
def __len__(self) -> int:
return self._get_built().__len__()
def __iter__(self) -> Iterator[str]:
return self._get_built().__iter__()
def _get_built(self) -> CoreSchema:
if self._built_memo is not None:
return self._built_memo
if self._attempt_rebuild:
schema = self._attempt_rebuild()
if schema is not None:
self._built_memo = schema
return schema
raise PydanticUserError(self._error_message, code=self._code)
def rebuild(self) -> CoreSchema | None:
self._built_memo = None
if self._attempt_rebuild:
schema = self._attempt_rebuild()
if schema is not None:
return schema
else:
raise PydanticUserError(self._error_message, code=self._code)
return None
ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer)
class MockValSer(Generic[ValSer]):
@@ -109,120 +56,63 @@ class MockValSer(Generic[ValSer]):
return None
def set_type_adapter_mocks(adapter: TypeAdapter) -> None:
"""Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance.
Args:
adapter: The type adapter instance to set the mocks on
"""
type_repr = str(adapter._type)
undefined_type_error_message = (
f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,'
f' then call `.rebuild()` on the instance.'
)
def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]:
def handler() -> T | None:
if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return attr_fn(adapter)
return None
return handler
adapter.core_schema = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema),
)
adapter.validator = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator),
)
adapter.serializer = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer),
)
def set_model_mocks(cls: type[BaseModel], undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model.
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
Args:
cls: The model class to set the mocks on
cls_name: Name of the model class, used in error messages
undefined_name: Name of the undefined thing, used in error messages
"""
undefined_type_error_message = (
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
f' then call `{cls.__name__}.model_rebuild()`.'
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `{cls_name}.model_rebuild()`.'
)
def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
def handler() -> T | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return attr_fn(cls)
def attempt_rebuild_validator() -> SchemaValidator | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
return cls.__pydantic_validator__
else:
return None
return handler
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
)
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
attempt_rebuild=attempt_rebuild_validator,
)
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
def attempt_rebuild_serializer() -> SchemaSerializer | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
return cls.__pydantic_serializer__
else:
return None
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
attempt_rebuild=attempt_rebuild_serializer,
)
def set_dataclass_mocks(cls: type[PydanticDataclass], undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
Args:
cls: The model class to set the mocks on
undefined_name: Name of the undefined thing, used in error messages
"""
from ..dataclasses import rebuild_dataclass
def set_dataclass_mock_validator(cls: type[PydanticDataclass], cls_name: str, undefined_name: str) -> None:
undefined_type_error_message = (
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
f' then call `pydantic.dataclasses.rebuild_dataclass({cls.__name__})`.'
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
)
def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
def handler() -> T | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return attr_fn(cls)
def attempt_rebuild() -> SchemaValidator | None:
from ..dataclasses import rebuild_dataclass
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5):
return cls.__pydantic_validator__
else:
return None
return handler
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
)
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
)
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
attempt_rebuild=attempt_rebuild,
)

View File

@@ -1,54 +1,58 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations
import builtins
import operator
import sys
import typing
import warnings
import weakref
from abc import ABCMeta
from functools import cache, partial, wraps
from functools import partial
from types import FunctionType
from typing import Any, Callable, Generic, Literal, NoReturn, cast
from typing import Any, Callable, Generic, Mapping
from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args, get_origin
from typing_inspection import typing_objects
from typing_extensions import dataclass_transform, deprecated
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ..warnings import PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema, InvalidSchemaError
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._import_utils import import_cached_base_model, import_cached_field_info
from ._mock_val_ser import set_model_mocks
from ._namespace_utils import NsResolver
from ._signature import generate_pydantic_signature
from ._typing_extra import (
_make_forward_ref,
eval_type_backport,
is_classvar_annotation,
parent_frame_namespace,
from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema
from ._decorators import (
ComputedFieldInfo,
DecoratorInfos,
PydanticDescriptorProxy,
get_attribute_from_bases,
)
from ._utils import LazyClassAttribute, SafeGetItemProxy
from ._discriminated_union import apply_discriminators
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute, is_valid_identifier
from ._validate_call import ValidateCallWrapper
if typing.TYPE_CHECKING:
from ..fields import Field as PydanticModelField
from ..fields import FieldInfo, ModelPrivateAttr
from ..fields import PrivateAttr as PydanticModelPrivateAttr
from inspect import Signature
from ..main import BaseModel
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
PydanticModelField = object()
PydanticModelPrivateAttr = object()
IGNORED_TYPES: tuple[Any, ...] = (
FunctionType,
property,
classmethod,
staticmethod,
PydanticDescriptorProxy,
ComputedFieldInfo,
ValidateCallWrapper,
)
object_setattr = object.__setattr__
@@ -65,17 +69,7 @@ class _ModelNamespaceDict(dict):
return super().__setitem__(k, v)
def NoInitField(
*,
init: Literal[False] = False,
) -> Any:
"""Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
`__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
synthesizing the `__init__` signature.
"""
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class ModelMetaclass(ABCMeta):
def __new__(
mcs,
@@ -84,7 +78,6 @@ class ModelMetaclass(ABCMeta):
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
@@ -95,7 +88,6 @@ class ModelMetaclass(ABCMeta):
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
_create_model_module: The module of the class to be created, if created by `create_model`.
**kwargs: Catch-all for any other keyword arguments.
Returns:
@@ -112,18 +104,17 @@ class ModelMetaclass(ABCMeta):
private_attributes = inspect_namespace(
namespace, config_wrapper.ignored_types, class_vars, base_field_names
)
if private_attributes or base_private_attributes:
if private_attributes:
original_model_post_init = get_model_post_init(namespace, bases)
if original_model_post_init is not None:
# if there are private_attributes and a model_post_init function, we handle both
@wraps(original_model_post_init)
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
"""We need to both initialize private attributes and call the user-defined model_post_init
method.
"""
init_private_attributes(self, context)
original_model_post_init(self, context)
init_private_attributes(self, __context)
original_model_post_init(self, __context)
namespace['model_post_init'] = wrapped_model_post_init
else:
@@ -132,25 +123,15 @@ class ModelMetaclass(ABCMeta):
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs))
BaseModel_ = import_cached_base_model()
if config_wrapper.frozen:
set_default_hash_func(namespace, bases)
mro = cls.__mro__
if Generic in mro and mro.index(Generic) < mro.index(BaseModel_):
warnings.warn(
GenericBeforeBaseModelWarning(
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
'for pydantic generics to work properly.'
),
stacklevel=2,
)
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
from ..main import BaseModel
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
cls.__pydantic_post_init__ = (
None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init'
)
cls.__pydantic_setattr_handlers__ = {}
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
@@ -161,40 +142,22 @@ class ModelMetaclass(ABCMeta):
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
parameters = getattr(cls, '__parameters__', None) or parent_parameters
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
from ..root_model import RootModelRootType
missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
# RootModel with the generic type identifiers being used. Ex:
# class MyModel(RootModel, Generic[T]):
# root: T
# Should instead just be:
# class MyModel(RootModel[T]):
# root: T
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
error_message = (
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
f'{parameters_str} in its parameters. '
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
else:
combined_parameters = parent_parameters + missing_parameters
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
raise TypeError(error_message)
cls.__pydantic_generic_metadata__ = {
@@ -212,55 +175,29 @@ class ModelMetaclass(ABCMeta):
if __pydantic_reset_parent_namespace__:
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None)
parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
if isinstance(parent_namespace, dict):
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
ns_resolver = NsResolver(parent_namespace=parent_namespace)
set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)
# This is also set in `complete_model_class()`, after schema gen because they are recreated.
# We set them here as well for backwards compatibility:
cls.__pydantic_computed_fields__ = {
k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()
}
if config_wrapper.defer_build:
# TODO we can also stop there if `__pydantic_fields_complete__` is False.
# However, `set_model_fields()` is currently lenient and we don't have access to the `NameError`.
# (which is useful as we can provide the name in the error message: `set_model_mock(cls, e.name)`)
set_model_mocks(cls)
else:
# Any operation that requires accessing the field infos instances should be put inside
# `complete_model_class()`:
complete_model_class(
cls,
config_wrapper,
raise_errors=False,
ns_resolver=ns_resolver,
create_model_module=_create_model_module,
)
if config_wrapper.frozen and '__hash__' not in namespace:
set_default_hash_func(cls, bases)
types_namespace = get_cls_types_namespace(cls, parent_namespace)
set_model_fields(cls, bases, config_wrapper, types_namespace)
complete_model_class(
cls,
cls_name,
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
)
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
# only hit for _proper_ subclasses of BaseModel
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
return cls
else:
# These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
namespace.pop(
instance_slot,
None, # In case the metaclass is used with a class other than `BaseModel`.
)
namespace.get('__annotations__', {}).clear()
# this is the BaseModel class itself being created, no logic required
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
if not typing.TYPE_CHECKING: # pragma: no branch
if not typing.TYPE_CHECKING:
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
def __getattr__(self, item: str) -> Any:
@@ -268,29 +205,30 @@ class ModelMetaclass(ABCMeta):
private_attributes = self.__dict__.get('__private_attributes__')
if private_attributes and item in private_attributes:
return private_attributes[item]
if item == '__pydantic_core_schema__':
# This means the class didn't get a schema generated for it, likely because there was an undefined reference
maybe_mock_validator = getattr(self, '__pydantic_validator__', None)
if isinstance(maybe_mock_validator, MockValSer):
rebuilt_validator = maybe_mock_validator.rebuild()
if rebuilt_validator is not None:
# In this case, a validator was built, and so `__pydantic_core_schema__` should now be set
return getattr(self, '__pydantic_core_schema__')
raise AttributeError(item)
@classmethod
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
def __prepare__(cls, *args: Any, **kwargs: Any) -> Mapping[str, object]:
return _ModelNamespaceDict()
def __instancecheck__(self, instance: Any) -> bool:
"""Avoid calling ABC _abc_instancecheck unless we're pretty sure.
See #3829 and python/cpython#92810
"""
return hasattr(instance, '__pydantic_decorators__') and super().__instancecheck__(instance)
def __subclasscheck__(self, subclass: type[Any]) -> bool:
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
See #3829 and python/cpython#92810
"""
return hasattr(subclass, '__pydantic_decorators__') and super().__subclasscheck__(subclass)
return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
@staticmethod
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
BaseModel = import_cached_base_model()
from ..main import BaseModel
field_names: set[str] = set()
class_vars: set[str] = set()
@@ -298,57 +236,35 @@ class ModelMetaclass(ABCMeta):
for base in bases:
if issubclass(base, BaseModel) and base is not BaseModel:
# model_fields might not be defined yet in the case of generics, so we use getattr here:
field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
field_names.update(getattr(base, 'model_fields', {}).keys())
class_vars.update(base.__class_vars__)
private_attributes.update(base.__private_attributes__)
return field_names, class_vars, private_attributes
@property
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
@deprecated(
'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
)
def __fields__(self) -> dict[str, FieldInfo]:
warnings.warn(
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
PydanticDeprecatedSince20,
stacklevel=2,
)
return getattr(self, '__pydantic_fields__', {})
@property
def __pydantic_fields_complete__(self) -> bool:
"""Whether the fields where successfully collected (i.e. type hints were successfully resolves).
This is a private attribute, not meant to be used outside Pydantic.
"""
if not hasattr(self, '__pydantic_fields__'):
return False
field_infos = cast('dict[str, FieldInfo]', self.__pydantic_fields__) # pyright: ignore[reportAttributeAccessIssue]
return all(field_info._complete for field_info in field_infos.values())
def __dir__(self) -> list[str]:
attributes = list(super().__dir__())
if '__fields__' in attributes:
attributes.remove('__fields__')
return attributes
warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning)
return self.model_fields # type: ignore
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
def init_private_attributes(self: BaseModel, __context: Any) -> None:
"""This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
self: The BaseModel instance.
context: The context.
__context: The context.
"""
if getattr(self, '__pydantic_private__', None) is None:
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
@@ -356,7 +272,7 @@ def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...])
if 'model_post_init' in namespace:
return namespace['model_post_init']
BaseModel = import_cached_base_model()
from ..main import BaseModel
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
if model_post_init is not BaseModel.model_post_init:
@@ -389,11 +305,7 @@ def inspect_namespace( # noqa C901
- If a field does not have a type annotation.
- If a field on base class was overridden by a non-annotated attribute.
"""
from ..fields import ModelPrivateAttr, PrivateAttr
FieldInfo = import_cached_field_info()
all_ignored_types = ignored_types + default_ignored_types()
all_ignored_types = ignored_types + IGNORED_TYPES
private_attributes: dict[str, ModelPrivateAttr] = {}
raw_annotations = namespace.get('__annotations__', {})
@@ -403,12 +315,11 @@ def inspect_namespace( # noqa C901
ignored_names: set[str] = set()
for var_name, value in list(namespace.items()):
if var_name == 'model_config' or var_name == '__pydantic_extra__':
if var_name == 'model_config':
continue
elif (
isinstance(value, type)
and value.__module__ == namespace['__module__']
and '__qualname__' in namespace
and value.__qualname__.startswith(namespace['__qualname__'])
):
# `value` is a nested type defined in this namespace; don't error
@@ -439,8 +350,8 @@ def inspect_namespace( # noqa C901
elif var_name.startswith('__'):
continue
elif is_valid_privateattr_name(var_name):
if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]):
private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value))
if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
private_attributes[var_name] = PrivateAttr(default=value)
del namespace[var_name]
elif var_name in base_class_vars:
continue
@@ -457,8 +368,8 @@ def inspect_namespace( # noqa C901
)
else:
raise PydanticUserError(
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
f"A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a "
f"type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this "
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
code='model-field-missing-annotation',
)
@@ -468,82 +379,45 @@ def inspect_namespace( # noqa C901
is_valid_privateattr_name(ann_name)
and ann_name not in private_attributes
and ann_name not in ignored_names
# This condition can be a false negative when `ann_type` is stringified,
# but it is handled in most cases in `set_model_fields`:
and not is_classvar_annotation(ann_type)
and not is_classvar(ann_type)
and ann_type not in all_ignored_types
and getattr(ann_type, '__module__', None) != 'functools'
):
if isinstance(ann_type, str):
# Walking up the frames to get the module namespace where the model is defined
# (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
frame = sys._getframe(2)
if frame is not None:
try:
ann_type = eval_type_backport(
_make_forward_ref(ann_type, is_argument=False, is_class=True),
globalns=frame.f_globals,
localns=frame.f_locals,
)
except (NameError, TypeError):
pass
if typing_objects.is_annotated(get_origin(ann_type)):
_, *metadata = get_args(ann_type)
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
if private_attr is not None:
private_attributes[ann_name] = private_attr
continue
private_attributes[ann_name] = PrivateAttr()
return private_attributes
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
if '__hash__' in namespace:
return
base_hash_func = get_attribute_from_bases(bases, '__hash__')
new_hash_func = make_hash_func(cls)
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
# If `__hash__` is some default, we generate a hash function.
# It will be `None` if not overridden from BaseModel.
# It may be `object.__hash__` if there is another
if base_hash_func in {None, object.__hash__}:
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
# In the last case we still need a new hash function to account for new `model_fields`.
cls.__hash__ = new_hash_func
def hash_func(self: Any) -> int:
return hash(self.__class__) + hash(tuple(self.__dict__.values()))
def make_hash_func(cls: type[BaseModel]) -> Any:
getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0
def hash_func(self: Any) -> int:
try:
return hash(getter(self.__dict__))
except KeyError:
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(SafeGetItemProxy(self.__dict__)))
return hash_func
namespace['__hash__'] = hash_func
def set_model_fields(
cls: type[BaseModel],
config_wrapper: ConfigWrapper,
ns_resolver: NsResolver | None,
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
"""Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`.
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
Args:
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
ns_resolver: Namespace resolver to use when getting model annotations.
types_namespace: Optional extra namespace to look for types in.
"""
typevars_map = get_model_typevars_map(cls)
fields, class_vars = collect_model_fields(cls, config_wrapper, ns_resolver, typevars_map=typevars_map)
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
cls.__pydantic_fields__ = fields
cls.model_fields = fields
cls.__class_vars__.update(class_vars)
for k in class_vars:
@@ -561,11 +435,11 @@ def set_model_fields(
def complete_model_class(
cls: type[BaseModel],
cls_name: str,
config_wrapper: ConfigWrapper,
*,
raise_errors: bool = True,
ns_resolver: NsResolver | None = None,
create_model_module: str | None = None,
types_namespace: dict[str, Any] | None,
) -> bool:
"""Finish building a model class.
@@ -574,10 +448,10 @@ def complete_model_class(
Args:
cls: BaseModel or dataclass.
cls_name: The model or dataclass name.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
ns_resolver: The namespace resolver instance to use during schema building.
create_model_module: The module of the class to be created, if created by `create_model`.
types_namespace: Optional extra namespace to look for types in.
Returns:
`True` if the model is successfully completed, else `False`.
@@ -589,151 +463,132 @@ def complete_model_class(
typevars_map = get_model_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
ns_resolver,
types_namespace,
typevars_map,
)
handler = CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
)
if config_wrapper.defer_build:
set_model_mocks(cls, cls_name)
return False
try:
schema = gen_schema.generate_schema(cls)
schema = cls.__get_pydantic_core_schema__(cls, handler)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_model_mocks(cls, f'`{e.name}`')
set_model_mocks(cls, cls_name, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(title=cls.__name__)
core_config = config_wrapper.core_config(cls)
try:
schema = gen_schema.clean_schema(schema)
except InvalidSchemaError:
set_model_mocks(cls)
schema = gen_schema.collect_definitions(schema)
schema = apply_discriminators(simplify_schema_references(schema))
if collect_invalid_schemas(schema):
set_model_mocks(cls, cls_name)
return False
# This needs to happen *after* model schema generation, as the return type
# of the properties are evaluated and the `ComputedFieldInfo` are recreated:
cls.__pydantic_computed_fields__ = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
set_deprecated_descriptors(cls)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
create_model_module or cls.__module__,
cls.__qualname__,
'create_model' if create_model_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
# debug(schema)
cls.__pydantic_core_schema__ = schema = validate_core_schema(schema)
cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True
# set __signature__ attr only for model class, but not for its instances
# (because instances can define `__call__`, and `inspect.signature` shouldn't
# use the `__signature__` attribute and instead generate from `__call__`).
cls.__signature__ = LazyClassAttribute(
'__signature__',
partial(
generate_pydantic_signature,
init=cls.__init__,
fields=cls.__pydantic_fields__,
validate_by_name=config_wrapper.validate_by_name,
extra=config_wrapper.extra,
),
cls.__signature__ = ClassAttribute(
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
)
return True
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
"""Set data descriptors on the class for deprecated fields."""
for field, field_info in cls.__pydantic_fields__.items():
if (msg := field_info.deprecation_message) is not None:
desc = _DeprecatedFieldDescriptor(msg)
desc.__set_name__(cls, field)
setattr(cls, field, desc)
def generate_model_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
) -> Signature:
"""Generate signature for model based on its fields.
for field, computed_field_info in cls.__pydantic_computed_fields__.items():
if (
(msg := computed_field_info.deprecation_message) is not None
# Avoid having two warnings emitted:
and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
):
desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
desc.__set_name__(cls, field)
setattr(cls, field, desc)
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
class _DeprecatedFieldDescriptor:
"""Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
Attributes:
msg: The deprecation message to be emitted.
wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
field_name: The name of the field being deprecated.
Returns:
The model signature.
"""
from inspect import Parameter, Signature, signature
from itertools import islice
field_name: str
present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
self.msg = msg
self.wrapped_property = wrapped_property
for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
def __set_name__(self, cls: type[BaseModel], name: str) -> None:
self.field_name = name
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
if isinstance(field.alias, str):
param_name = field.alias
else:
param_name = field_name
def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
if obj is None:
if self.wrapped_property is not None:
return self.wrapped_property.__get__(None, obj_type)
raise AttributeError(self.field_name)
if field_name in merged_params or param_name in merged_params:
continue
warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2)
if not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue
if self.wrapped_property is not None:
return self.wrapped_property.__get__(obj, obj_type)
return obj.__dict__[self.field_name]
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
)
# Defined to make it a data descriptor and take precedence over the instance's dictionary.
# Note that it will not be called when setting a value on a model instance
# as `BaseModel.__setattr__` is defined and takes priority.
def __set__(self, obj: Any, value: Any) -> NoReturn:
raise AttributeError(self.field_name)
if config_wrapper.extra == 'allow':
use_var_kw = True
class _PydanticWeakRef:
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
`weakref.ref` instead of subclassing it.
See https://github.com/pydantic/pydantic/issues/6763 for context.
Semantics:
- If not pickled, behaves the same as a `weakref.ref`.
- If pickled along with the referenced object, the same `weakref.ref` behavior
will be maintained between them after unpickling.
- If pickled without the referenced object, after unpickling the underlying
reference will be cleared (`__call__` will always return `None`).
"""
def __init__(self, obj: Any):
if obj is None:
# The object will be `None` upon deserialization if the serialized weakref
# had lost its underlying object.
self._wr = None
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
self._wr = weakref.ref(obj)
# else start from var_kw
var_kw_name = var_kw.name
def __call__(self) -> Any:
if self._wr is None:
return None
else:
return self._wr()
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
return _PydanticWeakRef, (self(),)
return Signature(parameters=list(merged_params.values()), return_annotation=None)
class _PydanticWeakRef(weakref.ReferenceType):
pass
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
@@ -770,23 +625,3 @@ def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | N
else:
result[k] = v
return result
@cache
def default_ignored_types() -> tuple[type[Any], ...]:
from ..fields import ComputedFieldInfo
ignored_types = [
FunctionType,
property,
classmethod,
staticmethod,
PydanticDescriptorProxy,
ComputedFieldInfo,
TypeAliasType, # from `typing_extensions`
]
if sys.version_info >= (3, 12):
ignored_types.append(typing.TypeAliasType)
return tuple(ignored_types)

View File

@@ -1,293 +0,0 @@
from __future__ import annotations
import sys
from collections.abc import Generator, Iterator, Mapping
from contextlib import contextmanager
from functools import cached_property
from typing import Any, Callable, NamedTuple, TypeVar
from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple
GlobalsNamespace: TypeAlias = 'dict[str, Any]'
"""A global namespace.
In most cases, this is a reference to the `__dict__` attribute of a module.
This namespace type is expected as the `globals` argument during annotations evaluation.
"""
MappingNamespace: TypeAlias = Mapping[str, Any]
"""Any kind of namespace.
In most cases, this is a local namespace (e.g. the `__dict__` attribute of a class,
the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types
defined inside functions).
This namespace type is expected as the `locals` argument during annotations evaluation.
"""
_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple'
class NamespacesTuple(NamedTuple):
"""A tuple of globals and locals to be used during annotations evaluation.
This datastructure is defined as a named tuple so that it can easily be unpacked:
```python {lint="skip" test="skip"}
def eval_type(typ: type[Any], ns: NamespacesTuple) -> None:
return eval(typ, *ns)
```
"""
globals: GlobalsNamespace
"""The namespace to be used as the `globals` argument during annotations evaluation."""
locals: MappingNamespace
"""The namespace to be used as the `locals` argument during annotations evaluation."""
def get_module_ns_of(obj: Any) -> dict[str, Any]:
"""Get the namespace of the module where the object is defined.
Caution: this function does not return a copy of the module namespace, so the result
should not be mutated. The burden of enforcing this is on the caller.
"""
module_name = getattr(obj, '__module__', None)
if module_name:
try:
return sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
return {}
return {}
# Note that this class is almost identical to `collections.ChainMap`, but need to enforce
# immutable mappings here:
class LazyLocalNamespace(Mapping[str, Any]):
"""A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation.
While the [`eval`][eval] function expects a mapping as the `locals` argument, it only
performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class
is fully implemented only for type checking purposes.
Args:
*namespaces: The namespaces to consider, in ascending order of priority.
Example:
```python {lint="skip" test="skip"}
ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3})
ns['a']
#> 3
ns['b']
#> 2
```
"""
def __init__(self, *namespaces: MappingNamespace) -> None:
self._namespaces = namespaces
@cached_property
def data(self) -> dict[str, Any]:
return {k: v for ns in self._namespaces for k, v in ns.items()}
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, key: str) -> Any:
return self.data[key]
def __contains__(self, key: object) -> bool:
return key in self.data
def __iter__(self) -> Iterator[str]:
return iter(self.data)
def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple:
"""Return the global and local namespaces to be used when evaluating annotations for the provided function.
The global namespace will be the `__dict__` attribute of the module the function was defined in.
The local namespace will contain the `__type_params__` introduced by PEP 695.
Args:
obj: The object to use when building namespaces.
parent_namespace: Optional namespace to be added with the lowest priority in the local namespace.
If the passed function is a method, the `parent_namespace` will be the namespace of the class
the method is defined in. Thus, we also fetch type `__type_params__` from there (i.e. the
class-scoped type variables).
"""
locals_list: list[MappingNamespace] = []
if parent_namespace is not None:
locals_list.append(parent_namespace)
# Get the `__type_params__` attribute introduced by PEP 695.
# Note that the `typing._eval_type` function expects type params to be
# passed as a separate argument. However, internally, `_eval_type` calls
# `ForwardRef._evaluate` which will merge type params with the localns,
# essentially mimicking what we do here.
type_params: tuple[_TypeVarLike, ...] = getattr(obj, '__type_params__', ())
if parent_namespace is not None:
# We also fetch type params from the parent namespace. If present, it probably
# means the function was defined in a class. This is to support the following:
# https://github.com/python/cpython/issues/124089.
type_params += parent_namespace.get('__type_params__', ())
locals_list.append({t.__name__: t for t in type_params})
# What about short-cirtuiting to `obj.__globals__`?
globalns = get_module_ns_of(obj)
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
class NsResolver:
"""A class responsible for the namespaces resolving logic for annotations evaluation.
This class handles the namespace logic when evaluating annotations mainly for class objects.
It holds a stack of classes that are being inspected during the core schema building,
and the `types_namespace` property exposes the globals and locals to be used for
type annotation evaluation. Additionally -- if no class is present in the stack -- a
fallback globals and locals can be provided using the `namespaces_tuple` argument
(this is useful when generating a schema for a simple annotation, e.g. when using
`TypeAdapter`).
The namespace creation logic is unfortunately flawed in some cases, for backwards
compatibility reasons and to better support valid edge cases. See the description
for the `parent_namespace` argument and the example for more details.
Args:
namespaces_tuple: The default globals and locals to use if no class is present
on the stack. This can be useful when using the `GenerateSchema` class
with `TypeAdapter`, where the "type" being analyzed is a simple annotation.
parent_namespace: An optional parent namespace that will be added to the locals
with the lowest priority. For a given class defined in a function, the locals
of this function are usually used as the parent namespace:
```python {lint="skip" test="skip"}
from pydantic import BaseModel
def func() -> None:
SomeType = int
class Model(BaseModel):
f: 'SomeType'
# when collecting fields, an namespace resolver instance will be created
# this way:
# ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
```
For backwards compatibility reasons and to support valid edge cases, this parent
namespace will be used for *every* type being pushed to the stack. In the future,
we might want to be smarter by only doing so when the type being pushed is defined
in the same module as the parent namespace.
Example:
```python {lint="skip" test="skip"}
ns_resolver = NsResolver(
parent_namespace={'fallback': 1},
)
class Sub:
m: 'Model'
class Model:
some_local = 1
sub: Sub
ns_resolver = NsResolver()
# This is roughly what happens when we build a core schema for `Model`:
with ns_resolver.push(Model):
ns_resolver.types_namespace
#> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1})
# First thing to notice here, the model being pushed is added to the locals.
# Because `NsResolver` is being used during the model definition, it is not
# yet added to the globals. This is useful when resolving self-referencing annotations.
with ns_resolver.push(Sub):
ns_resolver.types_namespace
#> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model})
# Second thing to notice: `Sub` is present in both the globals and locals.
# This is not an issue, just that as described above, the model being pushed
# is added to the locals, but it happens to be present in the globals as well
# because it is already defined.
# Third thing to notice: `Model` is also added in locals. This is a backwards
# compatibility workaround that allows for `Sub` to be able to resolve `'Model'`
# correctly (as otherwise models would have to be rebuilt even though this
# doesn't look necessary).
```
"""
def __init__(
self,
namespaces_tuple: NamespacesTuple | None = None,
parent_namespace: MappingNamespace | None = None,
) -> None:
self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {})
self._parent_ns = parent_namespace
self._types_stack: list[type[Any] | TypeAliasType] = []
@cached_property
def types_namespace(self) -> NamespacesTuple:
"""The current global and local namespaces to be used for annotations evaluation."""
if not self._types_stack:
# TODO: should we merge the parent namespace here?
# This is relevant for TypeAdapter, where there are no types on the stack, and we might
# need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing
# locals to both parent_ns and the base_ns_tuple, but this is a bit hacky.
# we might consider something like:
# if self._parent_ns is not None:
# # Hacky workarounds, see class docstring:
# # An optional parent namespace that will be added to the locals with the lowest priority
# locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals]
# return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list))
return self._base_ns_tuple
typ = self._types_stack[-1]
globalns = get_module_ns_of(typ)
locals_list: list[MappingNamespace] = []
# Hacky workarounds, see class docstring:
# An optional parent namespace that will be added to the locals with the lowest priority
if self._parent_ns is not None:
locals_list.append(self._parent_ns)
if len(self._types_stack) > 1:
first_type = self._types_stack[0]
locals_list.append({first_type.__name__: first_type})
# Adding `__type_params__` *before* `vars(typ)`, as the latter takes priority
# (see https://github.com/python/cpython/pull/120272).
# TODO `typ.__type_params__` when we drop support for Python 3.11:
type_params: tuple[_TypeVarLike, ...] = getattr(typ, '__type_params__', ())
if type_params:
# Adding `__type_params__` is mostly useful for generic classes defined using
# PEP 695 syntax *and* using forward annotations (see the example in
# https://github.com/python/cpython/issues/114053). For TypeAliasType instances,
# it is way less common, but still required if using a string annotation in the alias
# value, e.g. `type A[T] = 'T'` (which is not necessary in most cases).
locals_list.append({t.__name__: t for t in type_params})
# TypeAliasType instances don't have a `__dict__` attribute, so the check
# is necessary:
if hasattr(typ, '__dict__'):
locals_list.append(vars(typ))
# The `len(self._types_stack) > 1` check above prevents this from being added twice:
locals_list.append({typ.__name__: typ})
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
@contextmanager
def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]:
"""Push a type to the stack."""
self._types_stack.append(typ)
# Reset the cached property:
self.__dict__.pop('types_namespace', None)
try:
yield
finally:
self._types_stack.pop()
self.__dict__.pop('types_namespace', None)

View File

@@ -1,5 +1,4 @@
"""Tools to provide pretty/human-readable display of objects."""
from __future__ import annotations as _annotations
import types
@@ -7,8 +6,6 @@ import typing
from typing import Any
import typing_extensions
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from . import _typing_extra
@@ -35,7 +32,7 @@ class Representation:
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
# we don't want to use a type annotation here as it can break get_type_hints
__slots__ = () # type: typing.Collection[str]
__slots__ = tuple() # type: typing.Collection[str]
def __repr_args__(self) -> ReprArgs:
"""Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
@@ -48,17 +45,12 @@ class Representation:
if not attrs_names and hasattr(self, '__dict__'):
attrs_names = self.__dict__.keys()
attrs = ((s, getattr(self, s)) for s in attrs_names)
return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None]
return [(a, v) for a, v in attrs if v is not None]
def __repr_name__(self) -> str:
"""Name of the instance's class, used in __repr__."""
return self.__class__.__name__
def __repr_recursion__(self, object: Any) -> str:
"""Returns the string representation of a recursive object."""
# This is copied over from the stdlib `pprint` module:
return f'<Recursion on {type(object).__name__} with id={id(object)}>'
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
@@ -95,30 +87,25 @@ def display_as_type(obj: Any) -> str:
Takes some logic from `typing._type_repr`.
"""
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
if isinstance(obj, types.FunctionType):
return obj.__name__
elif obj is ...:
return '...'
elif isinstance(obj, Representation):
return repr(obj)
elif isinstance(obj, typing.ForwardRef) or typing_objects.is_typealiastype(obj):
return str(obj)
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
obj = obj.__class__
if is_union_origin(typing_extensions.get_origin(obj)):
if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
return f'Union[{args}]'
elif isinstance(obj, _typing_extra.WithArgsTypes):
if typing_objects.is_literal(typing_extensions.get_origin(obj)):
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
else:
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
try:
return f'{obj.__qualname__}[{args}]'
except AttributeError:
return str(obj).replace('typing.', '').replace('typing_extensions.', '') # handles TypeAliasType in 3.12
return f'{obj.__qualname__}[{args}]'
elif isinstance(obj, type):
return obj.__qualname__
else:

View File

@@ -1,209 +0,0 @@
# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict
from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
from typing_extensions import TypeAlias
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
class GatherResult(TypedDict):
"""Schema traversing result."""
collected_references: dict[str, DefinitionReferenceSchema | None]
"""The collected definition references.
If a definition reference schema can be inlined, it means that there is
only one in the whole core schema. As such, it is stored as the value.
Otherwise, the value is set to `None`.
"""
deferred_discriminator_schemas: list[CoreSchema]
"""The list of core schemas having the discriminator application deferred."""
class MissingDefinitionError(LookupError):
"""A reference was pointing to a non-existing core schema."""
def __init__(self, schema_reference: str, /) -> None:
self.schema_reference = schema_reference
@dataclass
class GatherContext:
"""The current context used during core schema traversing.
Context instances should only be used during schema traversing.
"""
definitions: dict[str, CoreSchema]
"""The available definitions."""
deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list)
"""The list of core schemas having the discriminator application deferred.
Internally, these core schemas have a specific key set in the core metadata dict.
"""
collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict)
"""The collected definition references.
If a definition reference schema can be inlined, it means that there is
only one in the whole core schema. As such, it is stored as the value.
Otherwise, the value is set to `None`.
During schema traversing, definition reference schemas can be added as candidates, or removed
(by setting the value to `None`).
"""
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
meta = schema.get('metadata')
if meta is not None and 'pydantic_internal_union_discriminator' in meta:
ctx.deferred_discriminator_schemas.append(schema) # pyright: ignore[reportArgumentType]
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
schema_ref = def_ref_schema['schema_ref']
if schema_ref not in ctx.collected_references:
definition = ctx.definitions.get(schema_ref)
if definition is None:
raise MissingDefinitionError(schema_ref)
# The `'definition-ref'` schema was only encountered once, make it
# a candidate to be inlined:
ctx.collected_references[schema_ref] = def_ref_schema
traverse_schema(definition, ctx)
if 'serialization' in def_ref_schema:
traverse_schema(def_ref_schema['serialization'], ctx)
traverse_metadata(def_ref_schema, ctx)
else:
# The `'definition-ref'` schema was already encountered, meaning
# the previously encountered schema (and this one) can't be inlined:
ctx.collected_references[schema_ref] = None
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
# TODO When we drop 3.9, use a match statement to get better type checking and remove
# file-level type ignore.
# (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
schema_type = schema['type']
if schema_type == 'definition-ref':
traverse_definition_ref(schema, context)
# `traverse_definition_ref` handles the possible serialization and metadata schemas:
return
elif schema_type == 'definitions':
traverse_schema(schema['schema'], context)
for definition in schema['definitions']:
traverse_schema(definition, context)
elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
if 'items_schema' in schema:
traverse_schema(schema['items_schema'], context)
elif schema_type == 'tuple':
if 'items_schema' in schema:
for s in schema['items_schema']:
traverse_schema(s, context)
elif schema_type == 'dict':
if 'keys_schema' in schema:
traverse_schema(schema['keys_schema'], context)
if 'values_schema' in schema:
traverse_schema(schema['values_schema'], context)
elif schema_type == 'union':
for choice in schema['choices']:
if isinstance(choice, tuple):
traverse_schema(choice[0], context)
else:
traverse_schema(choice, context)
elif schema_type == 'tagged-union':
for v in schema['choices'].values():
traverse_schema(v, context)
elif schema_type == 'chain':
for step in schema['steps']:
traverse_schema(step, context)
elif schema_type == 'lax-or-strict':
traverse_schema(schema['lax_schema'], context)
traverse_schema(schema['strict_schema'], context)
elif schema_type == 'json-or-python':
traverse_schema(schema['json_schema'], context)
traverse_schema(schema['python_schema'], context)
elif schema_type in {'model-fields', 'typed-dict'}:
if 'extras_schema' in schema:
traverse_schema(schema['extras_schema'], context)
if 'computed_fields' in schema:
for s in schema['computed_fields']:
traverse_schema(s, context)
for s in schema['fields'].values():
traverse_schema(s, context)
elif schema_type == 'dataclass-args':
if 'computed_fields' in schema:
for s in schema['computed_fields']:
traverse_schema(s, context)
for s in schema['fields']:
traverse_schema(s, context)
elif schema_type == 'arguments':
for s in schema['arguments_schema']:
traverse_schema(s['schema'], context)
if 'var_args_schema' in schema:
traverse_schema(schema['var_args_schema'], context)
if 'var_kwargs_schema' in schema:
traverse_schema(schema['var_kwargs_schema'], context)
elif schema_type == 'arguments-v3':
for s in schema['arguments_schema']:
traverse_schema(s['schema'], context)
elif schema_type == 'call':
traverse_schema(schema['arguments_schema'], context)
if 'return_schema' in schema:
traverse_schema(schema['return_schema'], context)
elif schema_type == 'computed-field':
traverse_schema(schema['return_schema'], context)
elif schema_type == 'function-before':
if 'schema' in schema:
traverse_schema(schema['schema'], context)
if 'json_schema_input_schema' in schema:
traverse_schema(schema['json_schema_input_schema'], context)
elif schema_type == 'function-plain':
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
if 'return_schema' in schema:
traverse_schema(schema['return_schema'], context)
if 'json_schema_input_schema' in schema:
traverse_schema(schema['json_schema_input_schema'], context)
elif schema_type == 'function-wrap':
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
if 'return_schema' in schema:
traverse_schema(schema['return_schema'], context)
if 'schema' in schema:
traverse_schema(schema['schema'], context)
if 'json_schema_input_schema' in schema:
traverse_schema(schema['json_schema_input_schema'], context)
else:
if 'schema' in schema:
traverse_schema(schema['schema'], context)
if 'serialization' in schema:
traverse_schema(schema['serialization'], context)
traverse_metadata(schema, context)
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
"""Traverse the core schema and definitions and return the necessary information for schema cleaning.
During the core schema traversing, any `'definition-ref'` schema is:
- Validated: the reference must point to an existing definition. If this is not the case, a
`MissingDefinitionError` exception is raised.
- Stored in the context: the actual reference is stored in the context. Depending on whether
the `'definition-ref'` schema is encountered more that once, the schema itself is also
saved in the context to be inlined (i.e. replaced by the definition it points to).
"""
context = GatherContext(definitions)
traverse_schema(schema, context)
return {
'collected_references': context.collected_references,
'deferred_discriminator_schemas': context.deferred_discriminator_schemas,
}

View File

@@ -1,10 +1,10 @@
"""Types and utility functions used by various other internal tools."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import core_schema
from typing_extensions import Literal
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
from ._core_utils import CoreSchemaOrField
from ._generate_schema import GenerateSchema
from ._namespace_utils import NamespacesTuple
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
@@ -33,8 +32,8 @@ class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
self.handler = handler_override or generate_json_schema.generate_inner
self.mode = generate_json_schema.mode
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
return self.handler(core_schema)
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
return self.handler(__core_schema)
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Resolves `$ref` in the json schema.
@@ -79,21 +78,22 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
self._generate_schema = generate_schema
self._ref_mode = ref_mode
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
schema = self._handler(source_type)
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
schema = self._handler(__source_type)
ref = schema.get('ref')
if self._ref_mode == 'to-def':
ref = schema.get('ref')
if ref is not None:
return self._generate_schema.defs.create_definition_reference_schema(schema)
self._generate_schema.defs.definitions[ref] = schema
return core_schema.definition_reference_schema(ref)
return schema
else: # ref_mode = 'unpack'
else: # ref_mode = 'unpack
return self.resolve_ref_schema(schema)
def _get_types_namespace(self) -> NamespacesTuple:
def _get_types_namespace(self) -> dict[str, Any] | None:
return self._generate_schema._types_namespace
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
return self._generate_schema.generate_schema(source_type)
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
return self._generate_schema.generate_schema(__source_type)
@property
def field_name(self) -> str | None:
@@ -113,13 +113,12 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
"""
if maybe_ref_schema['type'] == 'definition-ref':
ref = maybe_ref_schema['schema_ref']
definition = self._generate_schema.defs.get_schema_from_ref(ref)
if definition is None:
if ref not in self._generate_schema.defs.definitions:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return definition
return self._generate_schema.defs.definitions[ref]
elif maybe_ref_schema['type'] == 'definitions':
return self.resolve_ref_schema(maybe_ref_schema['schema'])
return maybe_ref_schema

View File

@@ -1,53 +0,0 @@
from __future__ import annotations
import collections
import collections.abc
import typing
from typing import Any
from pydantic_core import PydanticOmit, core_schema
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
typing.Deque: collections.deque, # noqa: UP006
collections.deque: collections.deque,
list: list,
typing.List: list, # noqa: UP006
tuple: tuple,
typing.Tuple: tuple, # noqa: UP006
set: set,
typing.AbstractSet: set,
typing.Set: set, # noqa: UP006
frozenset: frozenset,
typing.FrozenSet: frozenset, # noqa: UP006
typing.Sequence: list,
typing.MutableSequence: list,
typing.MutableSet: set,
# this doesn't handle subclasses of these
# parametrized typing.Set creates one of these
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
}
def serialize_sequence_via_list(
v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []
mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
if mapped_origin is None:
# we shouldn't hit this branch, should probably add a serialization error or something
return v
for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)
if info.mode_is_json():
return items
else:
return mapped_origin(items)

View File

@@ -1,188 +0,0 @@
from __future__ import annotations
import dataclasses
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import PydanticUndefined
from ._utils import is_valid_identifier
if TYPE_CHECKING:
from ..config import ExtraValues
from ..fields import FieldInfo
# Copied over from stdlib dataclasses
class _HAS_DEFAULT_FACTORY_CLASS:
def __repr__(self):
return '<factory>'
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
"""Extract the correct name to use for the field when generating a signature.
Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
First priority is given to the alias, then the validation_alias, then the field name.
Args:
field_name: The name of the field
field_info: The corresponding FieldInfo object.
Returns:
The correct name to use when generating a signature.
"""
if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
return field_info.alias
if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
return field_info.validation_alias
return field_name
def _process_param_defaults(param: Parameter) -> Parameter:
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
Args:
param (Parameter): The parameter
Returns:
Parameter: The custom processed parameter
"""
from ..fields import FieldInfo
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any
# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
)
return param
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
init: Callable[..., None],
fields: dict[str, FieldInfo],
validate_by_name: bool,
extra: ExtraValues | None,
) -> dict[str, Parameter]:
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
from itertools import islice
present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if fields.get(param.name):
# exclude params with init=False
if getattr(fields[param.name], 'init', True) is False:
continue
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = validate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
param_name = _field_name_for_signature(field_name, field)
if field_name in merged_params or param_name in merged_params:
continue
if not is_valid_identifier(param_name):
if allow_names:
param_name = field_name
else:
use_var_kw = True
continue
if field.is_required():
default = Parameter.empty
elif field.default_factory is not None:
# Mimics stdlib dataclasses:
default = _HAS_DEFAULT_FACTORY
else:
default = field.default
merged_params[param_name] = Parameter(
param_name,
Parameter.KEYWORD_ONLY,
annotation=field.rebuild_annotation(),
default=default,
)
if extra == 'allow':
use_var_kw = True
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
return merged_params
def generate_pydantic_signature(
init: Callable[..., None],
fields: dict[str, FieldInfo],
validate_by_name: bool,
extra: ExtraValues | None,
is_dataclass: bool = False,
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.
Args:
init: The class init.
fields: The model fields.
validate_by_name: The `validate_by_name` value of the config.
extra: The `extra` value of the config.
is_dataclass: Whether the model is a dataclass.
Returns:
The dataclass/BaseModel subclass signature.
"""
merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra)
if is_dataclass:
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
return Signature(parameters=list(merged_params.values()), return_annotation=None)

View File

@@ -0,0 +1,713 @@
"""Logic for generating pydantic-core schemas for standard library types.
Import of this module is deferred since it contains imports of many standard library modules.
"""
from __future__ import annotations as _annotations
import collections
import collections.abc
import dataclasses
import decimal
import inspect
import os
import typing
from enum import Enum
from functools import partial
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, Callable, Iterable, TypeVar
import typing_extensions
from pydantic_core import (
CoreSchema,
MultiHostUrl,
PydanticCustomError,
PydanticOmit,
Url,
core_schema,
)
from typing_extensions import get_args, get_origin
from pydantic.errors import PydanticSchemaGenerationError
from pydantic.fields import FieldInfo
from pydantic.types import Strict
from ..config import ConfigDict
from ..json_schema import JsonSchemaValue, update_json_schema
from . import _known_annotated_metadata, _typing_extra, _validators
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
if typing.TYPE_CHECKING:
from ._generate_schema import GenerateSchema
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
@dataclasses.dataclass(**slots_true)
class SchemaTransformer:
get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema]
get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue]
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return self.get_core_schema(source_type, handler)
def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
return self.get_json_schema(schema, handler)
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
cases: list[Any] = list(enum_type.__members__.values())
if not cases:
# Use an isinstance check for enums with no cases.
# This won't work with serialization or JSON schema, but that's okay -- the most important
# use case for this is creating typevar bounds for generics that should be restricted to enums.
# This is more consistent than it might seem at first, since you can only subclass enum.Enum
# (or subclasses of enum.Enum) if all parent classes have no cases.
return core_schema.is_instance_schema(enum_type)
use_enum_values = config.get('use_enum_values', False)
if len(cases) == 1:
expected = repr(cases[0].value)
else:
expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}'
def to_enum(__input_value: Any) -> Enum:
try:
enum_field = enum_type(__input_value)
if use_enum_values:
return enum_field.value
return enum_field
except ValueError:
# The type: ignore on the next line is to ignore the requirement of LiteralString
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}
def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
return json_schema
strict_python_schema = core_schema.is_instance_schema(enum_type)
if use_enum_values:
strict_python_schema = core_schema.chain_schema(
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
)
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
if issubclass(enum_type, int):
# this handles `IntEnum`, and also `Foobar(int, Enum)`
updates['type'] = 'integer'
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
# Disallow float from JSON due to strict mode
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
updates['type'] = 'string'
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, float):
updates['type'] = 'numeric'
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
python_schema=strict_python_schema,
)
else:
lax = to_enum_validator
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
return core_schema.lax_or_strict_schema(
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
)
@dataclasses.dataclass(**slots_true)
class InnerSchemaValidator:
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
core_schema: CoreSchema
js_schema: JsonSchemaValue | None = None
js_core_schema: CoreSchema | None = None
js_schema_update: JsonSchemaValue | None = None
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
if self.js_schema is not None:
return self.js_schema
js_schema = handler(self.js_core_schema or self.core_schema)
if self.js_schema_update is not None:
js_schema.update(self.js_schema_update)
return js_schema
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
return self.core_schema
def decimal_prepare_pydantic_annotations(
source: Any, annotations: Iterable[Any], config: ConfigDict
) -> tuple[Any, list[Any]] | None:
if source is not decimal.Decimal:
return None
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
config_allow_inf_nan = config.get('allow_inf_nan')
if config_allow_inf_nan is not None:
metadata.setdefault('allow_inf_nan', config_allow_inf_nan)
_known_annotated_metadata.check_metadata(
metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal
)
return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations]
def datetime_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
import datetime
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
if source_type is datetime.date:
sv = InnerSchemaValidator(core_schema.date_schema(**metadata))
elif source_type is datetime.datetime:
sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata))
elif source_type is datetime.time:
sv = InnerSchemaValidator(core_schema.time_schema(**metadata))
elif source_type is datetime.timedelta:
sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata))
else:
return None
# check now that we know the source type is correct
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type)
return (source_type, [sv, *remaining_annotations])
def uuid_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
# UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length
from uuid import UUID
if source_type is not UUID:
return None
return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations])
def path_schema_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
import pathlib
if source_type not in {
os.PathLike,
pathlib.Path,
pathlib.PurePath,
pathlib.PosixPath,
pathlib.PurePosixPath,
pathlib.PureWindowsPath,
}:
return None
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type)
construct_path = pathlib.PurePath if source_type is os.PathLike else source_type
def path_validator(input_value: str) -> os.PathLike[Any]:
try:
return construct_path(input_value)
except TypeError as e:
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
constrained_str_schema = core_schema.str_schema(**metadata)
instance_schema = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
python_schema=core_schema.is_instance_schema(source_type),
)
strict: bool | None = None
for annotation in annotations:
if isinstance(annotation, Strict):
strict = annotation.strict
schema = core_schema.lax_or_strict_schema(
lax_schema=core_schema.union_schema(
[
instance_schema,
core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
],
custom_error_type='path_type',
custom_error_message='Input is not a valid path',
strict=True,
),
strict_schema=instance_schema,
serialization=core_schema.to_string_ser_schema(),
strict=strict,
)
return (
source_type,
[
InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}),
*remaining_annotations,
],
)
def dequeue_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
) -> collections.deque[Any]:
if isinstance(input_value, collections.deque):
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
if maxlens:
maxlen = min(maxlens)
return collections.deque(handler(input_value), maxlen=maxlen)
else:
return collections.deque(handler(input_value), maxlen=maxlen)
@dataclasses.dataclass(**slots_true)
class SequenceValidator:
mapped_origin: type[Any]
item_source_type: type[Any]
min_length: int | None = None
max_length: int | None = None
strict: bool = False
def serialize_sequence_via_list(
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []
for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)
if info.mode_is_json():
return items
else:
return self.mapped_origin(items)
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.item_source_type is Any:
items_schema = None
else:
items_schema = handler.generate_schema(self.item_source_type)
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
if self.mapped_origin in (list, set, frozenset):
if self.mapped_origin is list:
constrained_schema = core_schema.list_schema(items_schema, **metadata)
elif self.mapped_origin is set:
constrained_schema = core_schema.set_schema(items_schema, **metadata)
else:
assert self.mapped_origin is frozenset # safety check in case we forget to add a case
constrained_schema = core_schema.frozenset_schema(items_schema, **metadata)
schema = constrained_schema
else:
# safety check in case we forget to add a case
assert self.mapped_origin in (collections.deque, collections.Counter)
if self.mapped_origin is collections.deque:
# if we have a MaxLen annotation might as well set that as the default maxlen on the deque
# this lets us re-use existing metadata annotations to let users set the maxlen on a dequeue
# that e.g. comes from JSON
coerce_instance_wrap = partial(
core_schema.no_info_wrap_validator_function,
partial(dequeue_validator, maxlen=metadata.get('max_length', None)),
)
else:
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
constrained_schema = core_schema.list_schema(items_schema, **metadata)
check_instance = core_schema.json_or_python_schema(
json_schema=core_schema.list_schema(),
python_schema=core_schema.is_instance_schema(self.mapped_origin),
)
serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
)
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
if metadata.get('strict', False):
schema = strict
else:
lax = coerce_instance_wrap(constrained_schema)
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
schema['serialization'] = serialization
return schema
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
typing.Deque: collections.deque,
collections.deque: collections.deque,
list: list,
typing.List: list,
set: set,
typing.AbstractSet: set,
typing.Set: set,
frozenset: frozenset,
typing.FrozenSet: frozenset,
typing.Sequence: list,
typing.MutableSequence: list,
typing.MutableSet: set,
# this doesn't handle subclasses of these
# parametrized typing.Set creates one of these
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
}
def identity(s: CoreSchema) -> CoreSchema:
return s
def sequence_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)
mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None)
if mapped_origin is None:
return None
args = get_args(source_type)
if not args:
args = (Any,)
elif len(args) != 1:
raise ValueError('Expected sequence to have exactly 1 generic parameter')
item_source_type = args[0]
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations])
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
typing.DefaultDict: collections.defaultdict,
collections.defaultdict: collections.defaultdict,
collections.OrderedDict: collections.OrderedDict,
typing_extensions.OrderedDict: collections.OrderedDict,
dict: dict,
typing.Dict: dict,
collections.Counter: collections.Counter,
typing.Counter: collections.Counter,
# this doesn't handle subclasses of these
typing.Mapping: dict,
typing.MutableMapping: dict,
# parametrized typing.{Mutable}Mapping creates one of these
collections.abc.MutableMapping: dict,
collections.abc.Mapping: dict,
}
def defaultdict_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
if isinstance(input_value, collections.defaultdict):
default_factory = input_value.default_factory
return collections.defaultdict(default_factory, handler(input_value))
else:
return collections.defaultdict(default_default_factory, handler(input_value))
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
def infer_default() -> Callable[[], Any]:
allowed_default_types: dict[Any, Any] = {
typing.Tuple: tuple,
tuple: tuple,
collections.abc.Sequence: tuple,
collections.abc.MutableSequence: list,
typing.List: list,
list: list,
typing.Sequence: list,
typing.Set: set,
set: set,
typing.MutableSet: set,
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
typing.MutableMapping: dict,
typing.Mapping: dict,
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
float: float,
int: int,
str: str,
bool: bool,
}
values_type_origin = get_origin(values_source_type) or values_source_type
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
if isinstance(values_type_origin, TypeVar):
def type_var_default_factory() -> None:
raise RuntimeError(
'Generic defaultdict cannot be used without a concrete value type or an'
' explicit default factory, ' + instructions
)
return type_var_default_factory
elif values_type_origin not in allowed_default_types:
# a somewhat subjective set of types that have reasonable default values
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
raise PydanticSchemaGenerationError(
f'Unable to infer a default factory for keys of type {values_source_type}.'
f' Only {allowed_msg} are supported, other types require an explicit default factory'
' ' + instructions
)
return allowed_default_types[values_type_origin]
# Assume Annotated[..., Field(...)]
if _typing_extra.is_annotated(values_source_type):
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
else:
field_info = None
if field_info and field_info.default_factory:
default_default_factory = field_info.default_factory
else:
default_default_factory = infer_default()
return default_default_factory
@dataclasses.dataclass(**slots_true)
class MappingValidator:
mapped_origin: type[Any]
keys_source_type: type[Any]
values_source_type: type[Any]
min_length: int | None = None
max_length: int | None = None
strict: bool = False
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
return handler(v)
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.keys_source_type is Any:
keys_schema = None
else:
keys_schema = handler.generate_schema(self.keys_source_type)
if self.values_source_type is Any:
values_schema = None
else:
values_schema = handler.generate_schema(self.values_source_type)
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
if self.mapped_origin is dict:
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
else:
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
check_instance = core_schema.json_or_python_schema(
json_schema=core_schema.dict_schema(),
python_schema=core_schema.is_instance_schema(self.mapped_origin),
)
if self.mapped_origin is collections.defaultdict:
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
coerce_instance_wrap = partial(
core_schema.no_info_wrap_validator_function,
partial(defaultdict_validator, default_default_factory=default_default_factory),
)
else:
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_mapping_via_dict,
schema=core_schema.dict_schema(
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
),
info_arg=False,
)
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
if metadata.get('strict', False):
schema = strict
else:
lax = coerce_instance_wrap(constrained_schema)
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
schema['serialization'] = serialization
return schema
def mapping_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
if mapped_origin is None:
return None
args = get_args(source_type)
if not args:
args = (Any, Any)
elif mapped_origin is collections.Counter:
# a single generic
if len(args) != 1:
raise ValueError('Expected Counter to have exactly 1 generic parameter')
args = (args[0], int) # keys are always an int
elif len(args) != 2:
raise ValueError('Expected mapping to have exactly 2 generic parameters')
keys_source_type, values_source_type = args
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
return (
source_type,
[
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
*remaining_annotations,
],
)
def ip_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
def make_strict_ip_schema(tp: type[Any]) -> CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()),
python_schema=core_schema.is_instance_schema(tp),
)
if source_type is IPv4Address:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator),
strict_schema=make_strict_ip_schema(IPv4Address),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4'},
),
*annotations,
]
if source_type is IPv4Network:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator),
strict_schema=make_strict_ip_schema(IPv4Network),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4network'},
),
*annotations,
]
if source_type is IPv4Interface:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator),
strict_schema=make_strict_ip_schema(IPv4Interface),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'},
),
*annotations,
]
if source_type is IPv6Address:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator),
strict_schema=make_strict_ip_schema(IPv6Address),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6'},
),
*annotations,
]
if source_type is IPv6Network:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator),
strict_schema=make_strict_ip_schema(IPv6Network),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6network'},
),
*annotations,
]
if source_type is IPv6Interface:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator),
strict_schema=make_strict_ip_schema(IPv6Interface),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'},
),
*annotations,
]
return None
def url_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
if source_type is Url:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.url_schema(),
lambda cs, handler: handler(cs),
),
*annotations,
]
if source_type is MultiHostUrl:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.multi_host_url_schema(),
lambda cs, handler: handler(cs),
),
*annotations,
]
PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
decimal_prepare_pydantic_annotations,
sequence_like_prepare_pydantic_annotations,
datetime_prepare_pydantic_annotations,
uuid_prepare_pydantic_annotations,
path_schema_prepare_pydantic_annotations,
mapping_like_prepare_pydantic_annotations,
ip_prepare_pydantic_annotations,
url_prepare_pydantic_annotations,
)

View File

@@ -1,544 +1,244 @@
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module."""
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
from __future__ import annotations as _annotations
from __future__ import annotations
import collections.abc
import re
import dataclasses
import sys
import types
import typing
from collections.abc import Callable
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, cast
from types import GetSetDescriptorType
from typing import TYPE_CHECKING, Any, ForwardRef
import typing_extensions
from typing_extensions import deprecated, get_args, get_origin
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from typing_extensions import Annotated, Final, Literal, TypeAliasType, TypeGuard, get_args, get_origin
from pydantic.version import version_short
if TYPE_CHECKING:
from ._dataclasses import StandardDataclass
try:
from typing import _TypingBase # type: ignore[attr-defined]
except ImportError:
from typing import _Final as _TypingBase # type: ignore[attr-defined]
typing_base = _TypingBase
if sys.version_info < (3, 9):
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
TypingGenericAlias = ()
else:
from typing import GenericAlias as TypingGenericAlias # type: ignore
if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Required
else:
from typing import NotRequired, Required # noqa: F401
if sys.version_info < (3, 10):
def origin_is_union(tp: type[Any] | None) -> bool:
return tp is typing.Union
WithArgsTypes = (TypingGenericAlias,)
else:
def origin_is_union(tp: type[Any] | None) -> bool:
return tp is typing.Union or tp is types.UnionType
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of
if sys.version_info < (3, 10):
NoneType = type(None)
EllipsisType = type(Ellipsis)
else:
from types import EllipsisType as EllipsisType
from types import NoneType as NoneType
if TYPE_CHECKING:
from pydantic import BaseModel
# As per https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types,
# always check for both `typing` and `typing_extensions` variants of a typing construct.
# (this is implemented differently than the suggested approach in the `typing_extensions`
# docs for performance).
LITERAL_TYPES: set[Any] = {Literal}
if hasattr(typing, 'Literal'):
LITERAL_TYPES.add(typing.Literal) # type: ignore
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
_t_annotated = typing.Annotated
_te_annotated = typing_extensions.Annotated
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
def is_annotated(tp: Any, /) -> bool:
"""Return whether the provided argument is a `Annotated` special form.
def is_none_type(type_: Any) -> bool:
return type_ in NONE_TYPES
```python {test="skip" lint="skip"}
is_annotated(Annotated[int, ...])
#> True
```
def is_callable_type(type_: type[Any]) -> bool:
return type_ is Callable or get_origin(type_) is Callable
def is_literal_type(type_: type[Any]) -> bool:
return Literal is not None and get_origin(type_) in LITERAL_TYPES
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
return get_args(type_)
def all_literal_values(type_: type[Any]) -> list[Any]:
"""This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
"""
origin = get_origin(tp)
return origin is _t_annotated or origin is _te_annotated
if not is_literal_type(type_):
return [type_]
values = literal_values(type_)
return list(x for value in values for x in all_literal_values(value))
def annotated_type(tp: Any, /) -> Any | None:
"""Return the type of the `Annotated` special form, or `None`."""
return tp.__origin__ if typing_objects.is_annotated(get_origin(tp)) else None
def is_annotated(ann_type: Any) -> bool:
from ._utils import lenient_issubclass
origin = get_origin(ann_type)
return origin is not None and lenient_issubclass(origin, Annotated)
def unpack_type(tp: Any, /) -> Any | None:
"""Return the type wrapped by the `Unpack` special form, or `None`."""
return get_args(tp)[0] if typing_objects.is_unpack(get_origin(tp)) else None
def is_hashable(tp: Any, /) -> bool:
"""Return whether the provided argument is the `Hashable` class.
```python {test="skip" lint="skip"}
is_hashable(Hashable)
#> True
```
def is_namedtuple(type_: type[Any]) -> bool:
"""Check if a given class is a named tuple.
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
"""
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
# hence the second check:
return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable
from ._utils import lenient_issubclass
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
def is_callable(tp: Any, /) -> bool:
"""Return whether the provided argument is a `Callable`, parametrized or not.
test_new_type = typing.NewType('test_new_type', str)
```python {test="skip" lint="skip"}
is_callable(Callable[[int], str])
#> True
is_callable(typing.Callable)
#> True
is_callable(collections.abc.Callable)
#> True
```
def is_new_type(type_: type[Any]) -> bool:
"""Check whether type_ was created using typing.NewType.
Can't use isinstance because it fails <3.10.
"""
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
# hence the second check:
return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')
def _check_classvar(v: type[Any] | None) -> bool:
if v is None:
return False
return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
def is_classvar_annotation(tp: Any, /) -> bool:
"""Return whether the provided argument represents a class variable annotation.
Although not explicitly stated by the typing specification, `ClassVar` can be used
inside `Annotated` and as such, this function checks for this specific scenario.
Because this function is used to detect class variables before evaluating forward references
(or because evaluation failed), we also implement a naive regex match implementation. This is
required because class variables are inspected before fields are collected, so we try to be
as accurate as possible.
"""
if typing_objects.is_classvar(tp):
def is_classvar(ann_type: type[Any]) -> bool:
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
return True
origin = get_origin(tp)
if typing_objects.is_classvar(origin):
return True
if typing_objects.is_annotated(origin):
annotated_type = tp.__origin__
if typing_objects.is_classvar(annotated_type) or typing_objects.is_classvar(get_origin(annotated_type)):
return True
str_ann: str | None = None
if isinstance(tp, typing.ForwardRef):
str_ann = tp.__forward_arg__
if isinstance(tp, str):
str_ann = tp
if str_ann is not None and _classvar_re.match(str_ann):
# stdlib dataclasses do something similar, although a bit more advanced
# (see `dataclass._is_type`).
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
# forward references, see #3679
if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore
return True
return False
_t_final = typing.Final
_te_final = typing_extensions.Final
def _check_finalvar(v: type[Any] | None) -> bool:
"""Check if a given type is a `typing.Final` type."""
if v is None:
return False
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms:
def is_finalvar(tp: Any, /) -> bool:
"""Return whether the provided argument is a `Final` special form, parametrized or not.
```python {test="skip" lint="skip"}
is_finalvar(Final[int])
#> True
is_finalvar(Final)
#> True
"""
# Final is not necessarily parametrized:
if tp is _t_final or tp is _te_final:
return True
origin = get_origin(tp)
return origin is _t_final or origin is _te_final
def is_finalvar(ann_type: Any) -> bool:
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None])
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
and suggestion at the end of the next comment by @gvanrossum.
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
parent of where it is called.
def is_none_type(tp: Any, /) -> bool:
"""Return whether the argument represents the `None` type as part of an annotation.
```python {test="skip" lint="skip"}
is_none_type(None)
#> True
is_none_type(NoneType)
#> True
is_none_type(Literal[None])
#> True
is_none_type(type[None])
#> False
"""
return tp in _NONE_TYPES
def is_namedtuple(tp: Any, /) -> bool:
"""Return whether the provided argument is a named tuple class.
The class can be created using `typing.NamedTuple` or `collections.namedtuple`.
Parametrized generic classes are *not* assumed to be named tuples.
"""
from ._utils import lenient_issubclass # circ. import
return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields')
# TODO In 2.12, delete this export. It is currently defined only to not break
# pydantic-settings which relies on it:
origin_is_union = is_union_origin
def is_generic_alias(tp: Any, /) -> bool:
return isinstance(tp, (types.GenericAlias, typing._GenericAlias)) # pyright: ignore[reportAttributeAccessIssue]
# TODO: Ideally, we should avoid relying on the private `typing` constructs:
if sys.version_info < (3, 10):
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # pyright: ignore[reportAttributeAccessIssue]
else:
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias, types.UnionType) # pyright: ignore[reportAttributeAccessIssue]
# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`:
typing_base: Any = typing._Final # pyright: ignore[reportAttributeAccessIssue]
### Annotation evaluations functions:
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
"""Fetch the local namespace of the parent frame where this function is called.
Using this function is mostly useful to resolve forward annotations pointing to members defined in a local namespace,
such as assignments inside a function. Using the standard library tools, it is currently not possible to resolve
such annotations:
```python {lint="skip" test="skip"}
from typing import get_type_hints
def func() -> None:
Alias = int
class C:
a: 'Alias'
# Raises a `NameError: 'Alias' is not defined`
get_type_hints(C)
```
Pydantic uses this function when a Pydantic model is being defined to fetch the parent frame locals. However,
this only allows us to fetch the parent frame namespace and not other parents (e.g. a model defined in a function,
itself defined in another function). Inspecting the next outer frames (using `f_back`) is not reliable enough
(see https://discuss.python.org/t/20659).
Because this function is mostly used to better resolve forward annotations, nothing is returned if the parent frame's
code object is defined at the module level. In this case, the locals of the frame will be the same as the module
globals where the class is defined (see `_namespace_utils.get_module_ns_of`). However, if you still want to fetch
the module globals (e.g. when rebuilding a model, where the frame where the rebuild call is performed might contain
members that you want to use for forward annotations evaluation), you can use the `force` parameter.
Args:
parent_depth: The depth at which to get the frame. Defaults to 2, meaning the parent frame where this function
is called will be used.
force: Whether to always return the frame locals, even if the frame's code object is defined at the module level.
Returns:
The locals of the namespace, or `None` if it was skipped as per the described logic.
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
"""
frame = sys._getframe(parent_depth)
if frame.f_code.co_name.startswith('<generic parameters of'):
# As `parent_frame_namespace` is mostly called in `ModelMetaclass.__new__`,
# the parent frame can be the annotation scope if the PEP 695 generic syntax is used.
# (see https://docs.python.org/3/reference/executionmodel.html#annotation-scopes,
# https://docs.python.org/3/reference/compound_stmts.html#generic-classes).
# In this case, the code name is set to `<generic parameters of MyClass>`,
# and we need to skip this frame as it is irrelevant.
frame = cast(types.FrameType, frame.f_back) # guaranteed to not be `None`
# note, we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be
# modified down the line if this becomes a problem, we could implement some sort of frozen mapping structure to enforce this.
if force:
# if f_back is None, it's the global module namespace and we don't need to include it here
if frame.f_back is None:
return None
else:
return frame.f_locals
# If either of the following conditions are true, the class is defined at the top module level.
# To better understand why we need both of these checks, see
# https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531.
if frame.f_back is None or frame.f_code.co_name == '<module>':
return None
return frame.f_locals
def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
module_name = getattr(obj, '__module__', None)
if module_name:
try:
module_globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
else:
if globalns:
return {**module_globalns, **globalns}
else:
# copy module globals to make sure it can't be updated later
return module_globalns.copy()
return globalns or {}
def _type_convert(arg: Any) -> Any:
"""Convert `None` to `NoneType` and strings to `ForwardRef` instances.
This is a backport of the private `typing._type_convert` function. When
evaluating a type, `ForwardRef._evaluate` ends up being called, and is
responsible for making this conversion. However, we still have to apply
it for the first argument passed to our type evaluation functions, similarly
to the `typing.get_type_hints` function.
"""
if arg is None:
return NoneType
if isinstance(arg, str):
# Like `typing.get_type_hints`, assume the arg can be in any context,
# hence the proper `is_argument` and `is_class` args:
return _make_forward_ref(arg, is_argument=False, is_class=True)
return arg
def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
ns = add_module_globals(cls, parent_namespace)
ns[cls.__name__] = cls
return ns
def get_model_type_hints(
obj: type[BaseModel],
*,
ns_resolver: NsResolver | None = None,
) -> dict[str, tuple[Any, bool]]:
"""Collect annotations from a Pydantic model class, including those from parent classes.
Args:
obj: The Pydantic model to inspect.
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
Returns:
A dictionary mapping annotation names to a two-tuple: the first element is the evaluated
type or the original annotation if a `NameError` occurred, the second element is a boolean
indicating if whether the evaluation succeeded.
"""
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
ns_resolver = ns_resolver or NsResolver()
for base in reversed(obj.__mro__):
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
if not ann or isinstance(ann, types.GetSetDescriptorType):
continue
with ns_resolver.push(base):
globalns, localns = ns_resolver.types_namespace
for name, value in ann.items():
if name.startswith('_'):
# For private attributes, we only need the annotation to detect the `ClassVar` special form.
# For this reason, we still try to evaluate it, but we also catch any possible exception (on
# top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free
# to use any kind of forward annotation for private fields (e.g. circular imports, new typing
# syntax, etc).
try:
hints[name] = try_eval_type(value, globalns, localns)
except Exception:
hints[name] = (value, False)
else:
hints[name] = try_eval_type(value, globalns, localns)
return hints
def get_cls_type_hints(
obj: type[Any],
*,
ns_resolver: NsResolver | None = None,
) -> dict[str, Any]:
def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
"""Collect annotations from a class, including those from parent classes.
Args:
obj: The class to inspect.
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
"""
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
ns_resolver = ns_resolver or NsResolver()
hints = {}
for base in reversed(obj.__mro__):
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
if not ann or isinstance(ann, types.GetSetDescriptorType):
continue
with ns_resolver.push(base):
globalns, localns = ns_resolver.types_namespace
ann = base.__dict__.get('__annotations__')
localns = dict(vars(base))
if ann is not None and ann is not GetSetDescriptorType:
for name, value in ann.items():
hints[name] = eval_type(value, globalns, localns)
hints[name] = eval_type_lenient(value, globalns, localns)
return hints
def try_eval_type(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> tuple[Any, bool]:
"""Try evaluating the annotation using the provided namespaces.
Args:
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
of `str`, it will be converted to a `ForwardRef`.
localns: The global namespace to use during annotation evaluation.
globalns: The local namespace to use during annotation evaluation.
Returns:
A two-tuple containing the possibly evaluated type and a boolean indicating
whether the evaluation succeeded or not.
"""
value = _type_convert(value)
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any:
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
try:
return eval_type_backport(value, globalns, localns), True
return typing._eval_type(value, globalns, localns) # type: ignore
except NameError:
return value, False
def eval_type(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any:
"""Evaluate the annotation using the provided namespaces.
Args:
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
of `str`, it will be converted to a `ForwardRef`.
localns: The global namespace to use during annotation evaluation.
globalns: The local namespace to use during annotation evaluation.
"""
value = _type_convert(value)
return eval_type_backport(value, globalns, localns)
@deprecated(
'`eval_type_lenient` is deprecated, use `try_eval_type` instead.',
category=None,
)
def eval_type_lenient(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
) -> Any:
ev, _ = try_eval_type(value, globalns, localns)
return ev
def eval_type_backport(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
type_params: tuple[Any, ...] | None = None,
) -> Any:
"""An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
package if it's installed to let older Python versions use newer typing constructs.
Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
(as well as all the types made generic in PEP 585) if the original syntax is not supported in the
current Python version.
This function will also display a helpful error if the value passed fails to evaluate.
"""
try:
return _eval_type_backport(value, globalns, localns, type_params)
except TypeError as e:
if 'Unable to evaluate type annotation' in str(e):
raise
# If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition.
# Thus we assert here for type checking purposes:
assert isinstance(value, typing.ForwardRef)
message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
if sys.version_info >= (3, 11):
e.add_note(message)
raise
else:
raise TypeError(message) from e
except RecursionError as e:
# TODO ideally recursion errors should be checked in `eval_type` above, but `eval_type_backport`
# is used directly in some places.
message = (
"If you made use of an implicit recursive type alias (e.g. `MyType = list['MyType']), "
'consider using PEP 695 type aliases instead. For more details, refer to the documentation: '
f'https://docs.pydantic.dev/{version_short()}/concepts/types/#named-recursive-types'
)
if sys.version_info >= (3, 11):
e.add_note(message)
raise
else:
raise RecursionError(f'{e.args[0]}\n{message}')
def _eval_type_backport(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
type_params: tuple[Any, ...] | None = None,
) -> Any:
try:
return _eval_type(value, globalns, localns, type_params)
except TypeError as e:
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
raise
try:
from eval_type_backport import eval_type_backport
except ImportError:
raise TypeError(
f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
'since Python 3.9), you should either replace the use of new syntax with the existing '
'`typing` constructs or install the `eval_type_backport` package.'
) from e
return eval_type_backport(
value,
globalns,
localns, # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release.
try_default=False,
)
def _eval_type(
value: Any,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
type_params: tuple[Any, ...] | None = None,
) -> Any:
if sys.version_info >= (3, 13):
return typing._eval_type( # type: ignore
value, globalns, localns, type_params=type_params
)
else:
return typing._eval_type( # type: ignore
value, globalns, localns
)
def is_backport_fixable_error(e: TypeError) -> bool:
msg = str(e)
return sys.version_info < (3, 10) and msg.startswith('unsupported operand type(s) for |: ')
# the point of this function is to be tolerant to this case
return value
def get_function_type_hints(
function: Callable[..., Any],
*,
include_keys: set[str] | None = None,
globalns: GlobalsNamespace | None = None,
localns: MappingNamespace | None = None,
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Return type hints for a function.
This is similar to the `typing.get_type_hints` function, with a few differences:
- Support `functools.partial` by using the underlying `func` attribute.
- Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None`
(related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+).
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
copes with `partial`.
"""
try:
if isinstance(function, partial):
annotations = function.func.__annotations__
else:
annotations = function.__annotations__
except AttributeError:
# Some functions (e.g. builtins) don't have annotations:
return {}
if globalns is None:
globalns = get_module_ns_of(function)
type_params: tuple[Any, ...] | None = None
if localns is None:
# If localns was specified, it is assumed to already contain type params. This is because
# Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`).
type_params = getattr(function, '__type_params__', ())
if isinstance(function, partial):
annotations = function.func.__annotations__
else:
annotations = function.__annotations__
globalns = add_module_globals(function)
type_hints = {}
for name, value in annotations.items():
if include_keys is not None and name not in include_keys:
@@ -548,7 +248,7 @@ def get_function_type_hints(
elif isinstance(value, str):
value = _make_forward_ref(value)
type_hints[name] = eval_type_backport(value, globalns, localns, type_params)
type_hints[name] = typing._eval_type(value, globalns, types_namespace) # type: ignore
return type_hints
@@ -663,15 +363,11 @@ else:
if isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
value = eval_type_backport(value, base_globals, base_locals)
value = typing._eval_type(value, base_globals, base_locals) # type: ignore
hints[name] = value
if not include_extras and hasattr(typing, '_strip_annotations'):
return {
k: typing._strip_annotations(t) # type: ignore
for k, t in hints.items()
}
else:
return hints
return (
hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
)
if globalns is None:
if isinstance(obj, types.ModuleType):
@@ -692,7 +388,7 @@ else:
if isinstance(obj, typing._allowed_types): # type: ignore
return {}
else:
raise TypeError(f'{obj!r} is not a module, class, method, or function.')
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
defaults = typing._get_defaults(obj) # type: ignore
hints = dict(hints)
for name, value in hints.items():
@@ -707,8 +403,33 @@ else:
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
value = eval_type_backport(value, globalns, localns)
value = typing._eval_type(value, globalns, localns) # type: ignore
if name in defaults and defaults[name] is None:
value = typing.Optional[value]
hints[name] = value
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
if sys.version_info < (3, 9):
def evaluate_fwd_ref(
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
return ref._evaluate(globalns=globalns, localns=localns)
else:
def evaluate_fwd_ref(
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
return ref._evaluate(globalns=globalns, localns=localns, recursive_guard=frozenset())
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
# so I created this convenience function
return dataclasses.is_dataclass(_cls)
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
return isinstance(origin, TypeAliasType)

View File

@@ -2,30 +2,20 @@
This should be reduced as much as possible with functions only used in one place, moved to that place.
"""
from __future__ import annotations as _annotations
import dataclasses
import keyword
import sys
import typing
import warnings
import weakref
from collections import OrderedDict, defaultdict, deque
from collections.abc import Mapping
from copy import deepcopy
from functools import cached_property
from inspect import Parameter
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, Callable, Generic, TypeVar, overload
from typing import Any, TypeVar
from typing_extensions import TypeAlias, TypeGuard, deprecated
from pydantic import PydanticDeprecatedSince211
from typing_extensions import TypeAlias, TypeGuard
from . import _repr, _typing_extra
from ._import_utils import import_cached_base_model
if typing.TYPE_CHECKING:
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
@@ -69,25 +59,6 @@ BUILTIN_COLLECTIONS: set[type[Any]] = {
}
def can_be_positional(param: Parameter) -> bool:
"""Return whether the parameter accepts a positional argument.
```python {test="skip" lint="skip"}
def func(a, /, b, *, c):
pass
params = inspect.signature(func).parameters
can_be_positional(params['a'])
#> True
can_be_positional(params['b'])
#> True
can_be_positional(params['c'])
#> False
```
"""
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
@@ -112,7 +83,7 @@ def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
unlike raw calls to lenient_issubclass.
"""
BaseModel = import_cached_base_model()
from ..main import BaseModel
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
@@ -304,23 +275,19 @@ class ValueItems(_repr.Representation):
if typing.TYPE_CHECKING:
def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...
def ClassAttribute(name: str, value: T) -> T:
...
else:
class LazyClassAttribute:
"""A descriptor exposing an attribute only accessible on a class (hidden from instances).
class ClassAttribute:
"""Hide class attribute from its instances."""
The attribute is lazily computed and cached during the first access.
"""
__slots__ = 'name', 'value'
def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.get_value = get_value
@cached_property
def value(self) -> Any:
return self.get_value()
self.value = value
def __get__(self, instance: Any, owner: type[Any]) -> None:
if instance is None:
@@ -342,7 +309,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
try:
if not obj and obj_type in BUILTIN_COLLECTIONS:
# faster way for empty collections, no need to copy its members
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method
except (TypeError, ValueError, RuntimeError):
# do we really dare to catch ALL errors? Seems a bit risky
pass
@@ -350,7 +317,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
return deepcopy(obj) # slowest way when we actually might need a deepcopy
_SENTINEL = object()
_EMPTY = object()
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
@@ -362,70 +329,7 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
if left_item is not right_item:
return False
return True
@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes is safe to use in `operator.itemgetter` when some keys may be missing
"""
# Define __slots__manually for performances
# @dataclasses.dataclass() only support slots=True in python>=3.10
__slots__ = ('wrapped',)
wrapped: Mapping[str, Any]
def __getitem__(self, key: str, /) -> Any:
return self.wrapped.get(key, _SENTINEL)
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:
def __contains__(self, key: str, /) -> bool:
return self.wrapped.__contains__(key)
_ModelT = TypeVar('_ModelT', bound='BaseModel')
_RT = TypeVar('_RT')
class deprecated_instance_property(Generic[_ModelT, _RT]):
"""A decorator exposing the decorated class method as a property, with a warning on instance access.
This decorator takes a class method defined on the `BaseModel` class and transforms it into
an attribute. The attribute can be accessed on both the class and instances of the class. If accessed
via an instance, a deprecation warning is emitted stating that instance access will be removed in V3.
"""
def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None:
# Note: fget should be a classmethod:
self.fget = fget
@overload
def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ...
@overload
@deprecated(
'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
'Instead, you should access this attribute from the model class.',
category=None,
)
def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ...
def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT:
if instance is not None:
attr_name = self.fget.__name__ if sys.version_info >= (3, 10) else self.fget.__func__.__name__
warnings.warn(
f'Accessing the {attr_name!r} attribute on the instance is deprecated. '
'Instead, you should access this attribute from the model class.',
category=PydanticDeprecatedSince211,
stacklevel=2,
)
return self.fget.__get__(instance, objtype)()

View File

@@ -1,122 +1,88 @@
from __future__ import annotations as _annotations
import functools
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
from typing import Any, Awaitable, Callable
import pydantic_core
from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from . import _discriminated_union, _generate_schema, _typing_extra
from ._config import ConfigWrapper
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
from ._core_utils import simplify_schema_references, validate_core_schema
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
"""Extract the name of a `ValidateCallSupportedTypes` object."""
return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
"""Extract the qualname of a `ValidateCallSupportedTypes` object."""
return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
"""Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
if inspect.iscoroutinefunction(wrapped):
@functools.wraps(wrapped)
async def wrapper_function(*args, **kwargs): # type: ignore
return await wrapper(*args, **kwargs)
else:
@functools.wraps(wrapped)
def wrapper_function(*args, **kwargs):
return wrapper(*args, **kwargs)
# We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
wrapper_function.__name__ = extract_function_name(wrapped)
wrapper_function.__qualname__ = extract_function_qualname(wrapped)
wrapper_function.raw_function = wrapped # type: ignore
return wrapper_function
@dataclass
class CallMarker:
function: Callable[..., Any]
validate_return: bool
class ValidateCallWrapper:
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.
It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so
these functions can be applied to instance methods, class methods, static methods, as well as normal functions.
"""
__slots__ = (
'function',
'validate_return',
'schema_type',
'module',
'qualname',
'ns_resolver',
'config_wrapper',
'__pydantic_complete__',
'raw_function',
'_config',
'_validate_return',
'__pydantic_core_schema__',
'__pydantic_validator__',
'__return_pydantic_validator__',
'__signature__',
'__name__',
'__qualname__',
'__annotations__',
'__dict__', # required for __module__
)
def __init__(
self,
function: ValidateCallSupportedTypes,
config: ConfigDict | None,
validate_return: bool,
parent_namespace: MappingNamespace | None,
) -> None:
self.function = function
self.validate_return = validate_return
def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
self.raw_function = function
self._config = config
self._validate_return = validate_return
self.__signature__ = inspect.signature(function)
if isinstance(function, partial):
self.schema_type = function.func
self.module = function.func.__module__
func = function.func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__annotations__ = func.__annotations__
self.__module__ = func.__module__
self.__doc__ = func.__doc__
else:
self.schema_type = function
self.module = function.__module__
self.qualname = extract_function_qualname(function)
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__annotations__ = function.__annotations__
self.__module__ = function.__module__
self.__doc__ = function.__doc__
self.ns_resolver = NsResolver(
namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace)
)
self.config_wrapper = ConfigWrapper(config)
if not self.config_wrapper.defer_build:
self._create_validators()
else:
self.__pydantic_complete__ = False
namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
schema = simplify_schema_references(schema)
self.__pydantic_core_schema__ = schema = schema
core_config = config_wrapper.core_config(self)
schema = _discriminated_union.apply_discriminators(schema)
self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
def _create_validators(self) -> None:
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function))
core_config = self.config_wrapper.core_config(title=self.qualname)
self.__pydantic_validator__ = create_schema_validator(
schema,
self.schema_type,
self.module,
self.qualname,
'validate_call',
core_config,
self.config_wrapper.plugin_settings,
)
if self.validate_return:
signature = inspect.signature(self.function)
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
validator = create_schema_validator(
schema,
self.schema_type,
self.module,
self.qualname,
'validate_call',
core_config,
self.config_wrapper.plugin_settings,
if self._validate_return:
return_type = (
self.__signature__.return_annotation
if self.__signature__.return_annotation is not self.__signature__.empty
else Any
)
if inspect.iscoroutinefunction(self.function):
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type))
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
self.__return_pydantic_core_schema__ = schema
core_config = config_wrapper.core_config(self)
schema = validate_core_schema(schema)
validator = pydantic_core.SchemaValidator(schema, core_config)
if inspect.iscoroutinefunction(self.raw_function):
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)
@@ -125,16 +91,38 @@ class ValidateCallWrapper:
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_core_schema__ = None
self.__return_pydantic_validator__ = None
self.__pydantic_complete__ = True
self._name: str | None = None # set by __get__, used to set the instance attribute when decorating methods
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if not self.__pydantic_complete__:
self._create_validators()
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__(res)
else:
return res
return res
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
if obj is None:
try:
# Handle the case where a method is accessed as a class attribute
return objtype.__getattribute__(objtype, self._name) # type: ignore
except AttributeError:
# This will happen the first time the attribute is accessed
pass
bound_function = self.raw_function.__get__(obj, objtype)
result = self.__class__(bound_function, self._config, self._validate_return)
if self._name is not None:
if obj is not None:
object.__setattr__(obj, self._name, result)
else:
object.__setattr__(objtype, self._name, result)
return result
def __set_name__(self, owner: Any, name: str) -> None:
self._name = name
def __repr__(self) -> str:
return f'ValidateCallWrapper({self.raw_function})'

View File

@@ -5,32 +5,22 @@ Import of this module is deferred since it contains imports of many standard lib
from __future__ import annotations as _annotations
import collections.abc
import math
import re
import typing
from decimal import Decimal
from fractions import Fraction
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, Callable, Union, cast, get_origin
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from typing import Any
import typing_extensions
from pydantic_core import PydanticCustomError, core_schema
from pydantic_core._pydantic_core import PydanticKnownError
from typing_inspection import typing_objects
from pydantic._internal._import_utils import import_cached_field_info
from pydantic.errors import PydanticSchemaGenerationError
def sequence_validator(
input_value: typing.Sequence[Any],
/,
__input_value: typing.Sequence[Any],
validator: core_schema.ValidatorFunctionWrapHandler,
) -> typing.Sequence[Any]:
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
value_type = type(input_value)
value_type = type(__input_value)
# We don't accept any plain string as a sequence
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
@@ -41,24 +31,14 @@ def sequence_validator(
{'type_name': value_type.__name__},
)
# TODO: refactor sequence validation to validate with either a list or a tuple
# schema, depending on the type of the value.
# Additionally, we should be able to remove one of either this validator or the
# SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
# Effectively, a refactor for sequence validation is needed.
if value_type is tuple:
input_value = list(input_value)
v_list = validator(input_value)
v_list = validator(__input_value)
# the rest of the logic is just re-creating the original type from `v_list`
if value_type is list:
if value_type == list:
return v_list
elif issubclass(value_type, range):
# return the list as we probably can't re-create the range
return v_list
elif value_type is tuple:
return tuple(v_list)
else:
# best guess at how to re-create the original type, more custom construction logic might be required
return value_type(v_list) # type: ignore[call-arg]
@@ -69,7 +49,7 @@ def import_string(value: Any) -> Any:
try:
return _import_string_logic(value)
except ImportError as e:
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)})
else:
# otherwise we just return the value and let the next validator do the rest of the work
return value
@@ -126,39 +106,39 @@ def _import_string_logic(dotted_path: str) -> Any:
return module
def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]:
if isinstance(input_value, typing.Pattern):
return input_value
elif isinstance(input_value, (str, bytes)):
def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]:
if isinstance(__input_value, typing.Pattern):
return __input_value
elif isinstance(__input_value, (str, bytes)):
# todo strict mode
return compile_pattern(input_value) # type: ignore
return compile_pattern(__input_value) # type: ignore
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]:
if isinstance(input_value, typing.Pattern):
if isinstance(input_value.pattern, str):
return input_value
def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]:
if isinstance(__input_value, typing.Pattern):
if isinstance(__input_value.pattern, str):
return __input_value
else:
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
elif isinstance(input_value, str):
return compile_pattern(input_value)
elif isinstance(input_value, bytes):
elif isinstance(__input_value, str):
return compile_pattern(__input_value)
elif isinstance(__input_value, bytes):
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]:
if isinstance(input_value, typing.Pattern):
if isinstance(input_value.pattern, bytes):
return input_value
def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]:
if isinstance(__input_value, typing.Pattern):
if isinstance(__input_value.pattern, bytes):
return __input_value
else:
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
elif isinstance(input_value, bytes):
return compile_pattern(input_value)
elif isinstance(input_value, str):
elif isinstance(__input_value, bytes):
return compile_pattern(__input_value)
elif isinstance(__input_value, str):
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
@@ -174,359 +154,125 @@ def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
if isinstance(input_value, IPv4Address):
return input_value
def ip_v4_address_validator(__input_value: Any) -> IPv4Address:
if isinstance(__input_value, IPv4Address):
return __input_value
try:
return IPv4Address(input_value)
return IPv4Address(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
if isinstance(input_value, IPv6Address):
return input_value
def ip_v6_address_validator(__input_value: Any) -> IPv6Address:
if isinstance(__input_value, IPv6Address):
return __input_value
try:
return IPv6Address(input_value)
return IPv6Address(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
def ip_v4_network_validator(__input_value: Any) -> IPv4Network:
"""Assume IPv4Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
"""
if isinstance(input_value, IPv4Network):
return input_value
if isinstance(__input_value, IPv4Network):
return __input_value
try:
return IPv4Network(input_value)
return IPv4Network(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
def ip_v6_network_validator(__input_value: Any) -> IPv6Network:
"""Assume IPv6Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
"""
if isinstance(input_value, IPv6Network):
return input_value
if isinstance(__input_value, IPv6Network):
return __input_value
try:
return IPv6Network(input_value)
return IPv6Network(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
if isinstance(input_value, IPv4Interface):
return input_value
def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface:
if isinstance(__input_value, IPv4Interface):
return __input_value
try:
return IPv4Interface(input_value)
return IPv4Interface(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
if isinstance(input_value, IPv6Interface):
return input_value
def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface:
if isinstance(__input_value, IPv6Interface):
return __input_value
try:
return IPv6Interface(input_value)
return IPv6Interface(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
def fraction_validator(input_value: Any, /) -> Fraction:
if isinstance(input_value, Fraction):
return input_value
def greater_than_validator(x: Any, gt: Any) -> Any:
if not (x > gt):
raise PydanticKnownError('greater_than', {'gt': gt})
return x
try:
return Fraction(input_value)
except ValueError:
raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
if not (x >= ge):
raise PydanticKnownError('greater_than_equal', {'ge': ge})
return x
def less_than_validator(x: Any, lt: Any) -> Any:
if not (x < lt):
raise PydanticKnownError('less_than', {'lt': lt})
return x
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
if not (x <= le):
raise PydanticKnownError('less_than_equal', {'le': le})
return x
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
if not (x % multiple_of == 0):
raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of})
return x
def min_length_validator(x: Any, min_length: Any) -> Any:
if not (len(x) >= min_length):
raise PydanticKnownError(
'too_short',
{'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
)
return x
def max_length_validator(x: Any, max_length: Any) -> Any:
if len(x) > max_length:
raise PydanticKnownError(
'too_long',
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x
def forbid_inf_nan_check(x: Any) -> Any:
if not math.isfinite(x):
raise PydanticKnownError('finite_number')
return x
def _safe_repr(v: Any) -> int | float | str:
"""The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.
See tests/test_types.py::test_annotated_metadata_any_order for some context.
"""
if isinstance(v, (int, float, str)):
return v
return repr(v)
def greater_than_validator(x: Any, gt: Any) -> Any:
try:
if not (x > gt):
raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
try:
if not (x >= ge):
raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
def less_than_validator(x: Any, lt: Any) -> Any:
try:
if not (x < lt):
raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
try:
if not (x <= le):
raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
try:
if x % multiple_of:
raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
def min_length_validator(x: Any, min_length: Any) -> Any:
try:
if not (len(x) >= min_length):
raise PydanticKnownError(
'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
def max_length_validator(x: Any, max_length: Any) -> Any:
try:
if len(x) > max_length:
raise PydanticKnownError(
'too_long',
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
"""Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.
This function handles both normalized and non-normalized Decimal instances.
Example: Decimal('1.230') -> 4 digits, 3 decimal places
Args:
decimal (Decimal): The decimal number to analyze.
Returns:
tuple[int, int]: A tuple containing the number of decimal places and total digits.
Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
of the number of decimals and digits together.
"""
try:
decimal_tuple = decimal.as_tuple()
assert isinstance(decimal_tuple.exponent, int)
exponent = decimal_tuple.exponent
num_digits = len(decimal_tuple.digits)
if exponent >= 0:
# A positive exponent adds that many trailing zeros
# Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
num_digits += exponent
decimal_places = 0
else:
# If the absolute value of the negative exponent is larger than the
# number of digits, then it's the same as the number of digits,
# because it'll consume all the digits in digit_tuple and then
# add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
# Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
# Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
decimal_places = abs(exponent)
num_digits = max(num_digits, decimal_places)
return decimal_places, num_digits
except (AssertionError, AttributeError):
raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
def max_digits_validator(x: Any, max_digits: Any) -> Any:
try:
_, num_digits = _extract_decimal_digits_info(x)
_, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
if (num_digits > max_digits) and (normalized_num_digits > max_digits):
raise PydanticKnownError(
'decimal_max_digits',
{'max_digits': max_digits},
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
try:
decimal_places_, _ = _extract_decimal_digits_info(x)
if decimal_places_ > decimal_places:
normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
if normalized_decimal_places > decimal_places:
raise PydanticKnownError(
'decimal_max_places',
{'decimal_places': decimal_places},
)
return x
except TypeError:
raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))
def defaultdict_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
if isinstance(input_value, collections.defaultdict):
default_factory = input_value.default_factory
return collections.defaultdict(default_factory, handler(input_value))
else:
return collections.defaultdict(default_default_factory, handler(input_value))
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
FieldInfo = import_cached_field_info()
values_type_origin = get_origin(values_source_type)
def infer_default() -> Callable[[], Any]:
allowed_default_types: dict[Any, Any] = {
tuple: tuple,
collections.abc.Sequence: tuple,
collections.abc.MutableSequence: list,
list: list,
typing.Sequence: list,
set: set,
typing.MutableSet: set,
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
typing.MutableMapping: dict,
typing.Mapping: dict,
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
float: float,
int: int,
str: str,
bool: bool,
}
values_type = values_type_origin or values_source_type
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
if typing_objects.is_typevar(values_type):
def type_var_default_factory() -> None:
raise RuntimeError(
'Generic defaultdict cannot be used without a concrete value type or an'
' explicit default factory, ' + instructions
)
return type_var_default_factory
elif values_type not in allowed_default_types:
# a somewhat subjective set of types that have reasonable default values
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
raise PydanticSchemaGenerationError(
f'Unable to infer a default factory for keys of type {values_source_type}.'
f' Only {allowed_msg} are supported, other types require an explicit default factory'
' ' + instructions
)
return allowed_default_types[values_type]
# Assume Annotated[..., Field(...)]
if typing_objects.is_annotated(values_type_origin):
field_info = next((v for v in typing_extensions.get_args(values_source_type) if isinstance(v, FieldInfo)), None)
else:
field_info = None
if field_info and field_info.default_factory:
# Assume the default factory does not take any argument:
default_default_factory = cast(Callable[[], Any], field_info.default_factory)
else:
default_default_factory = infer_default()
return default_default_factory
def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
if isinstance(value, ZoneInfo):
return value
try:
return ZoneInfo(value)
except (ZoneInfoNotFoundError, ValueError, TypeError):
raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
'gt': greater_than_validator,
'ge': greater_than_or_equal_validator,
'lt': less_than_validator,
'le': less_than_or_equal_validator,
'multiple_of': multiple_of_validator,
'min_length': min_length_validator,
'max_length': max_length_validator,
'max_digits': max_digits_validator,
'decimal_places': decimal_places_validator,
}
IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]
IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
IPv4Address: ip_v4_address_validator,
IPv6Address: ip_v6_address_validator,
IPv4Network: ip_v4_network_validator,
IPv6Network: ip_v6_network_validator,
IPv4Interface: ip_v4_interface_validator,
IPv6Interface: ip_v6_interface_validator,
}
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
typing.DefaultDict: collections.defaultdict, # noqa: UP006
collections.defaultdict: collections.defaultdict,
typing.OrderedDict: collections.OrderedDict, # noqa: UP006
collections.OrderedDict: collections.OrderedDict,
typing_extensions.OrderedDict: collections.OrderedDict,
typing.Counter: collections.Counter,
collections.Counter: collections.Counter,
# this doesn't handle subclasses of these
typing.Mapping: dict,
typing.MutableMapping: dict,
# parametrized typing.{Mutable}Mapping creates one of these
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
}

View File

@@ -1,6 +1,8 @@
import sys
from typing import Any, Callable
import warnings
from typing import Any, Callable, Dict
from ._internal._validators import import_string
from .version import version_short
MOVED_IN_V2 = {
@@ -271,11 +273,7 @@ def getattr_migration(module: str) -> Callable[[str], Any]:
The object.
"""
if name == '__path__':
raise AttributeError(f'module {module!r} has no attribute {name!r}')
import warnings
from ._internal._validators import import_string
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
import_path = f'{module}:{name}'
if import_path in MOVED_IN_V2.keys():
@@ -300,9 +298,9 @@ def getattr_migration(module: str) -> Callable[[str], Any]:
)
if import_path in REMOVED_IN_V2:
raise PydanticImportError(f'`{import_path}` has been removed in V2.')
globals: dict[str, Any] = sys.modules[module].__dict__
globals: Dict[str, Any] = sys.modules[module].__dict__
if name in globals:
return globals[name]
raise AttributeError(f'module {module!r} has no attribute {name!r}')
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
return wrapper

View File

@@ -1,13 +1,8 @@
"""Alias generators for converting between different capitalization conventions."""
import re
__all__ = ('to_pascal', 'to_camel', 'to_snake')
# TODO: in V3, change the argument names to be more descriptive
# Generally, don't only convert from snake_case, or name the functions
# more specifically like snake_to_camel.
def to_pascal(snake: str) -> str:
"""Convert a snake_case string to PascalCase.
@@ -31,17 +26,12 @@ def to_camel(snake: str) -> str:
Returns:
The converted camelCase string.
"""
# If the string is already in camelCase and does not contain a digit followed
# by a lowercase letter, return it as it is
if re.match('^[a-z]+[A-Za-z0-9]*$', snake) and not re.search(r'\d[a-z]', snake):
return snake
camel = to_pascal(snake)
return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel)
def to_snake(camel: str) -> str:
"""Convert a PascalCase, camelCase, or kebab-case string to snake_case.
"""Convert a PascalCase or camelCase string to snake_case.
Args:
camel: The string to convert.
@@ -49,14 +39,6 @@ def to_snake(camel: str) -> str:
Returns:
The converted string in snake_case.
"""
# Handle the sequence of uppercase letters followed by a lowercase letter
snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
# Insert an underscore between a lowercase letter and an uppercase letter
snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Insert an underscore between a digit and an uppercase letter
snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Insert an underscore between a lowercase letter and a digit
snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Replace hyphens with underscores to handle kebab-case
snake = snake.replace('-', '_')
snake = re.sub(r'([a-zA-Z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
snake = re.sub(r'([a-z0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
return snake.lower()

View File

@@ -1,135 +0,0 @@
"""Support for alias configurations."""
from __future__ import annotations
import dataclasses
from typing import Any, Callable, Literal
from pydantic_core import PydanticUndefined
from ._internal import _internal_dataclass
__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices')
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasPath:
"""!!! abstract "Usage Documentation"
[`AliasPath` and `AliasChoices`](../concepts/alias.md#aliaspath-and-aliaschoices)
A data class used by `validation_alias` as a convenience to create aliases.
Attributes:
path: A list of string or integer aliases.
"""
path: list[int | str]
def __init__(self, first_arg: str, *args: str | int) -> None:
self.path = [first_arg] + list(args)
def convert_to_aliases(self) -> list[str | int]:
"""Converts arguments to a list of string or integer aliases.
Returns:
The list of aliases.
"""
return self.path
def search_dict_for_path(self, d: dict) -> Any:
"""Searches a dictionary for the path specified by the alias.
Returns:
The value at the specified path, or `PydanticUndefined` if the path is not found.
"""
v = d
for k in self.path:
if isinstance(v, str):
# disallow indexing into a str, like for AliasPath('x', 0) and x='abc'
return PydanticUndefined
try:
v = v[k]
except (KeyError, IndexError, TypeError):
return PydanticUndefined
return v
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasChoices:
"""!!! abstract "Usage Documentation"
[`AliasPath` and `AliasChoices`](../concepts/alias.md#aliaspath-and-aliaschoices)
A data class used by `validation_alias` as a convenience to create aliases.
Attributes:
choices: A list containing a string or `AliasPath`.
"""
choices: list[str | AliasPath]
def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None:
self.choices = [first_choice] + list(choices)
def convert_to_aliases(self) -> list[list[str | int]]:
"""Converts arguments to a list of lists containing string or integer aliases.
Returns:
The list of aliases.
"""
aliases: list[list[str | int]] = []
for c in self.choices:
if isinstance(c, AliasPath):
aliases.append(c.convert_to_aliases())
else:
aliases.append([c])
return aliases
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasGenerator:
"""!!! abstract "Usage Documentation"
[Using an `AliasGenerator`](../concepts/alias.md#using-an-aliasgenerator)
A data class used by `alias_generator` as a convenience to create various aliases.
Attributes:
alias: A callable that takes a field name and returns an alias for it.
validation_alias: A callable that takes a field name and returns a validation alias for it.
serialization_alias: A callable that takes a field name and returns a serialization alias for it.
"""
alias: Callable[[str], str] | None = None
validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None
serialization_alias: Callable[[str], str] | None = None
def _generate_alias(
self,
alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'],
allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...],
field_name: str,
) -> str | AliasPath | AliasChoices | None:
"""Generate an alias of the specified kind. Returns None if the alias generator is None.
Raises:
TypeError: If the alias generator produces an invalid type.
"""
alias = None
if alias_generator := getattr(self, alias_kind):
alias = alias_generator(field_name)
if alias and not isinstance(alias, allowed_types):
raise TypeError(
f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`'
)
return alias
def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]:
"""Generate `alias`, `validation_alias`, and `serialization_alias` for a field.
Returns:
A tuple of three aliases - validation, alias, and serialization.
"""
alias = self._generate_alias('alias', (str,), field_name)
validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name)
serialization_alias = self._generate_alias('serialization_alias', (str,), field_name)
return alias, validation_alias, serialization_alias # type: ignore

View File

@@ -1,5 +1,4 @@
"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`."""
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, Union
@@ -7,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Union
from pydantic_core import core_schema
if TYPE_CHECKING:
from ._internal._namespace_utils import NamespacesTuple
from .json_schema import JsonSchemaMode, JsonSchemaValue
CoreSchemaOrField = Union[
@@ -30,7 +28,7 @@ class GetJsonSchemaHandler:
mode: JsonSchemaMode
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
"""Call the inner handler and get the JsonSchemaValue it returns.
This will call the next JSON schema modifying function up until it calls
into `pydantic.json_schema.GenerateJsonSchema`, which will raise a
@@ -38,7 +36,7 @@ class GetJsonSchemaHandler:
a JSON schema.
Args:
core_schema: A `pydantic_core.core_schema.CoreSchema`.
__core_schema: A `pydantic_core.core_schema.CoreSchema`.
Returns:
JsonSchemaValue: The JSON schema generated by the inner JSON schema modify
@@ -46,13 +44,13 @@ class GetJsonSchemaHandler:
"""
raise NotImplementedError
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue, /) -> JsonSchemaValue:
def resolve_ref_schema(self, __maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Get the real schema for a `{"$ref": ...}` schema.
If the schema given is not a `$ref` schema, it will be returned as is.
This means you don't have to check before calling this function.
Args:
maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema.
__maybe_ref_json_schema: A JsonSchemaValue, ref based or not.
Raises:
LookupError: If the ref is not found.
@@ -66,7 +64,7 @@ class GetJsonSchemaHandler:
class GetCoreSchemaHandler:
"""Handler to call into the next CoreSchema schema generation function."""
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
"""Call the inner handler and get the CoreSchema it returns.
This will call the next CoreSchema modifying function up until it calls
into Pydantic's internal schema generation machinery, which will raise a
@@ -74,14 +72,14 @@ class GetCoreSchemaHandler:
a CoreSchema for the given source type.
Args:
source_type: The input type.
__source_type: The input type.
Returns:
CoreSchema: The `pydantic-core` CoreSchema generated.
"""
raise NotImplementedError
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
"""Generate a schema unrelated to the current context.
Use this function if e.g. you are handling schema generation for a sequence
and want to generate a schema for its items.
@@ -89,20 +87,20 @@ class GetCoreSchemaHandler:
that was intended for the sequence itself to its items!
Args:
source_type: The input type.
__source_type: The input type.
Returns:
CoreSchema: The `pydantic-core` CoreSchema generated.
"""
raise NotImplementedError
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema, /) -> core_schema.CoreSchema:
def resolve_ref_schema(self, __maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Get the real schema for a `definition-ref` schema.
If the schema given is not a `definition-ref` schema, it will be returned as is.
This means you don't have to check before calling this function.
Args:
maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
__maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
Raises:
LookupError: If the `ref` is not found.
@@ -117,6 +115,6 @@ class GetCoreSchemaHandler:
"""Get the name of the closest field to this validator."""
raise NotImplementedError
def _get_types_namespace(self) -> NamespacesTuple:
def _get_types_namespace(self) -> dict[str, Any] | None:
"""Internal method used during type resolution for serializer annotations."""
raise NotImplementedError

View File

@@ -1,5 +1,4 @@
"""`class_validators` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -11,11 +11,10 @@ Warning: Deprecated
See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md)
for more information.
"""
import math
import re
from colorsys import hls_to_rgb, rgb_to_hls
from typing import Any, Callable, Optional, Union, cast
from typing import Any, Callable, Optional, Tuple, Type, Union, cast
from pydantic_core import CoreSchema, PydanticCustomError, core_schema
from typing_extensions import deprecated
@@ -25,9 +24,9 @@ from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJso
from .json_schema import JsonSchemaValue
from .warnings import PydanticDeprecatedSince20
ColorTuple = Union[tuple[int, int, int], tuple[int, int, int, float]]
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
ColorType = Union[ColorTuple, str]
HslColorTuple = Union[tuple[float, float, float], tuple[float, float, float, float]]
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
class RGBA:
@@ -41,7 +40,7 @@ class RGBA:
self.b = b
self.alpha = alpha
self._tuple: tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
def __getitem__(self, item: Any) -> Any:
return self._tuple[item]
@@ -56,13 +55,13 @@ _r_sl = r'(\d{1,3}(?:\.\d+)?)%'
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%)
r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
r_rgb = fr'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%)
r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
r_hsl = fr'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%)
r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
r_rgb_v4_style = fr'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%)
r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
r_hsl_v4_style = fr'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
@@ -124,7 +123,7 @@ class Color(_repr.Representation):
ValueError: When no named color is found and fallback is `False`.
"""
if self._rgba.alpha is None:
rgb = cast(tuple[int, int, int], self.as_rgb_tuple())
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
try:
return COLORS_BY_VALUE[rgb]
except KeyError as e:
@@ -232,7 +231,7 @@ class Color(_repr.Representation):
@classmethod
def __get_pydantic_core_schema__(
cls, source: type[Any], handler: Callable[[Any], CoreSchema]
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
) -> core_schema.CoreSchema:
return core_schema.with_info_plain_validator_function(
cls._validate, serialization=core_schema.to_string_ser_schema()
@@ -255,7 +254,7 @@ class Color(_repr.Representation):
return hash(self.as_rgb_tuple())
def parse_tuple(value: tuple[Any, ...]) -> RGBA:
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
"""Parse a tuple or list to get RGBA values.
Args:

View File

@@ -1,33 +1,23 @@
"""Configuration for Pydantic models."""
from __future__ import annotations as _annotations
import warnings
from re import Pattern
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
from typing_extensions import TypeAlias, TypedDict, Unpack, deprecated
from typing_extensions import Literal, TypeAlias, TypedDict
from ._migration import getattr_migration
from .aliases import AliasGenerator
from .errors import PydanticUserError
from .warnings import PydanticDeprecatedSince211
if TYPE_CHECKING:
from ._internal._generate_schema import GenerateSchema as _GenerateSchema
from .fields import ComputedFieldInfo, FieldInfo
__all__ = ('ConfigDict', 'with_config')
__all__ = ('ConfigDict',)
JsonValue: TypeAlias = Union[int, float, str, bool, None, list['JsonValue'], 'JsonDict']
JsonDict: TypeAlias = dict[str, JsonValue]
JsonEncoder = Callable[[Any], Any]
JsonSchemaExtraCallable: TypeAlias = Union[
Callable[[JsonDict], None],
Callable[[JsonDict, type[Any]], None],
Callable[[Dict[str, Any]], None],
Callable[[Dict[str, Any], Type[Any]], None],
]
ExtraValues = Literal['allow', 'ignore', 'forbid']
@@ -39,18 +29,11 @@ class ConfigDict(TypedDict, total=False):
title: str | None
"""The title for the generated JSON schema, defaults to the model's name"""
model_title_generator: Callable[[type], str] | None
"""A callable that takes a model class and returns the title for it. Defaults to `None`."""
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
"""A callable that takes a field's name and info and returns title for it. Defaults to `None`."""
str_to_lower: bool
"""Whether to convert all characters to lowercase for str types. Defaults to `False`."""
str_to_upper: bool
"""Whether to convert all characters to uppercase for str types. Defaults to `False`."""
str_strip_whitespace: bool
"""Whether to strip leading and trailing whitespace for str types."""
@@ -61,108 +44,84 @@ class ConfigDict(TypedDict, total=False):
"""The maximum length for str types. Defaults to `None`."""
extra: ExtraValues | None
'''
Whether to ignore, allow, or forbid extra data during model initialization. Defaults to `'ignore'`.
"""
Whether to ignore, allow, or forbid extra attributes during model initialization. Defaults to `'ignore'`.
Three configuration values are available:
You can configure how pydantic handles the attributes that are not defined in the model:
- `'ignore'`: Providing extra data is ignored (the default):
```python
from pydantic import BaseModel, ConfigDict
* `allow` - Allow any extra attributes.
* `forbid` - Forbid any extra attributes.
* `ignore` - Ignore any extra attributes.
class User(BaseModel):
model_config = ConfigDict(extra='ignore') # (1)!
name: str
user = User(name='John Doe', age=20) # (2)!
print(user)
#> name='John Doe'
```
1. This is the default behaviour.
2. The `age` argument is ignored.
- `'forbid'`: Providing extra data is not permitted, and a [`ValidationError`][pydantic_core.ValidationError]
will be raised if this is the case:
```python
from pydantic import BaseModel, ConfigDict, ValidationError
```py
from pydantic import BaseModel, ConfigDict
class Model(BaseModel):
x: int
class User(BaseModel):
model_config = ConfigDict(extra='ignore') # (1)!
model_config = ConfigDict(extra='forbid')
name: str
try:
Model(x=1, y='a')
except ValidationError as exc:
print(exc)
"""
1 validation error for Model
y
Extra inputs are not permitted [type=extra_forbidden, input_value='a', input_type=str]
"""
```
user = User(name='John Doe', age=20) # (2)!
print(user)
#> name='John Doe'
```
- `'allow'`: Providing extra data is allowed and stored in the `__pydantic_extra__` dictionary attribute:
```python
from pydantic import BaseModel, ConfigDict
1. This is the default behaviour.
2. The `age` argument is ignored.
Instead, with `extra='allow'`, the `age` argument is included:
```py
from pydantic import BaseModel, ConfigDict
class Model(BaseModel):
x: int
class User(BaseModel):
model_config = ConfigDict(extra='allow')
model_config = ConfigDict(extra='allow')
name: str
m = Model(x=1, y='a')
assert m.__pydantic_extra__ == {'y': 'a'}
```
By default, no validation will be applied to these extra items, but you can set a type for the values by overriding
the type annotation for `__pydantic_extra__`:
```python
from pydantic import BaseModel, ConfigDict, Field, ValidationError
user = User(name='John Doe', age=20) # (1)!
print(user)
#> name='John Doe' age=20
```
1. The `age` argument is included.
With `extra='forbid'`, an error is raised:
```py
from pydantic import BaseModel, ConfigDict, ValidationError
class Model(BaseModel):
__pydantic_extra__: dict[str, int] = Field(init=False) # (1)!
class User(BaseModel):
model_config = ConfigDict(extra='forbid')
x: int
model_config = ConfigDict(extra='allow')
name: str
try:
Model(x=1, y='a')
except ValidationError as exc:
print(exc)
"""
1 validation error for Model
y
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
"""
m = Model(x=1, y='2')
assert m.x == 1
assert m.y == 2
assert m.model_dump() == {'x': 1, 'y': 2}
assert m.__pydantic_extra__ == {'y': 2}
```
1. The `= Field(init=False)` does not have any effect at runtime, but prevents the `__pydantic_extra__` field from
being included as a parameter to the model's `__init__` method by type checkers.
'''
try:
User(name='John Doe', age=20)
except ValidationError as e:
print(e)
'''
1 validation error for User
age
Extra inputs are not permitted [type=extra_forbidden, input_value=20, input_type=int]
'''
```
"""
frozen: bool
"""
Whether models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates
Whether or not models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates
a `__hash__()` method for the model. This makes instances of the model potentially hashable if all the
attributes are hashable. Defaults to `False`.
Note:
On V1, the inverse of this setting was called `allow_mutation`, and was `True` by default.
On V1, this setting was called `allow_mutation`, and was `True` by default.
"""
populate_by_name: bool
@@ -170,77 +129,38 @@ class ConfigDict(TypedDict, total=False):
Whether an aliased field may be populated by its name as given by the model
attribute, as well as the alias. Defaults to `False`.
!!! warning
`populate_by_name` usage is not recommended in v2.11+ and will be deprecated in v3.
Instead, you should use the [`validate_by_name`][pydantic.config.ConfigDict.validate_by_name] configuration setting.
Note:
The name of this configuration setting was changed in **v2.0** from
`allow_population_by_alias` to `populate_by_name`.
When `validate_by_name=True` and `validate_by_alias=True`, this is strictly equivalent to the
previous behavior of `populate_by_name=True`.
```py
from pydantic import BaseModel, ConfigDict, Field
In v2.11, we also introduced a [`validate_by_alias`][pydantic.config.ConfigDict.validate_by_alias] setting that introduces more fine grained
control for validation behavior.
Here's how you might go about using the new settings to achieve the same behavior:
class User(BaseModel):
model_config = ConfigDict(populate_by_name=True)
```python
from pydantic import BaseModel, ConfigDict, Field
name: str = Field(alias='full_name') # (1)!
age: int
class Model(BaseModel):
model_config = ConfigDict(validate_by_name=True, validate_by_alias=True)
my_field: str = Field(alias='my_alias') # (1)!
user = User(full_name='John Doe', age=20) # (2)!
print(user)
#> name='John Doe' age=20
user = User(name='John Doe', age=20) # (3)!
print(user)
#> name='John Doe' age=20
```
m = Model(my_alias='foo') # (2)!
print(m)
#> my_field='foo'
m = Model(my_alias='foo') # (3)!
print(m)
#> my_field='foo'
```
1. The field `'my_field'` has an alias `'my_alias'`.
2. The model is populated by the alias `'my_alias'`.
3. The model is populated by the attribute name `'my_field'`.
1. The field `'name'` has an alias `'full_name'`.
2. The model is populated by the alias `'full_name'`.
3. The model is populated by the field name `'name'`.
"""
use_enum_values: bool
"""
Whether to populate models with the `value` property of enums, rather than the raw enum.
This may be useful if you want to serialize `model.model_dump()` later. Defaults to `False`.
!!! note
If you have an `Optional[Enum]` value that you set a default for, you need to use `validate_default=True`
for said Field to ensure that the `use_enum_values` flag takes effect on the default, as extracting an
enum's value occurs during validation, not serialization.
```python
from enum import Enum
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
class SomeEnum(Enum):
FOO = 'foo'
BAR = 'bar'
BAZ = 'baz'
class SomeModel(BaseModel):
model_config = ConfigDict(use_enum_values=True)
some_enum: SomeEnum
another_enum: Optional[SomeEnum] = Field(
default=SomeEnum.FOO, validate_default=True
)
model1 = SomeModel(some_enum=SomeEnum.BAR)
print(model1.model_dump())
#> {'some_enum': 'bar', 'another_enum': 'foo'}
model2 = SomeModel(some_enum=SomeEnum.BAR, another_enum=SomeEnum.BAZ)
print(model2.model_dump())
#> {'some_enum': 'bar', 'another_enum': 'baz'}
```
"""
validate_assignment: bool
@@ -251,7 +171,7 @@ class ConfigDict(TypedDict, total=False):
In case the user changes the data after the model is created, the model is _not_ revalidated.
```python
```py
from pydantic import BaseModel
class User(BaseModel):
@@ -270,7 +190,7 @@ class ConfigDict(TypedDict, total=False):
In case you want to revalidate the model when the data is changed, you can use `validate_assignment=True`:
```python
```py
from pydantic import BaseModel, ValidationError
class User(BaseModel, validate_assignment=True): # (1)!
@@ -299,7 +219,7 @@ class ConfigDict(TypedDict, total=False):
"""
Whether arbitrary types are allowed for field types. Defaults to `False`.
```python
```py
from pydantic import BaseModel, ConfigDict, ValidationError
# This is not a pydantic model, it's an arbitrary class
@@ -358,20 +278,14 @@ class ConfigDict(TypedDict, total=False):
loc_by_alias: bool
"""Whether to use the actual key provided in the data (e.g. alias) for error `loc`s rather than the field's name. Defaults to `True`."""
alias_generator: Callable[[str], str] | AliasGenerator | None
alias_generator: Callable[[str], str] | None
"""
A callable that takes a field name and returns an alias for it
or an instance of [`AliasGenerator`][pydantic.aliases.AliasGenerator]. Defaults to `None`.
When using a callable, the alias generator is used for both validation and serialization.
If you want to use different alias generators for validation and serialization, you can use
[`AliasGenerator`][pydantic.aliases.AliasGenerator] instead.
A callable that takes a field name and returns an alias for it.
If data source field names do not match your code style (e. g. CamelCase fields),
you can automatically generate aliases using `alias_generator`. Here's an example with
a basic callable:
you can automatically generate aliases using `alias_generator`:
```python
```py
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_pascal
@@ -388,30 +302,6 @@ class ConfigDict(TypedDict, total=False):
#> {'Name': 'Filiz', 'LanguageCode': 'tr-TR'}
```
If you want to use different alias generators for validation and serialization, you can use
[`AliasGenerator`][pydantic.aliases.AliasGenerator].
```python
from pydantic import AliasGenerator, BaseModel, ConfigDict
from pydantic.alias_generators import to_camel, to_pascal
class Athlete(BaseModel):
first_name: str
last_name: str
sport: str
model_config = ConfigDict(
alias_generator=AliasGenerator(
validation_alias=to_camel,
serialization_alias=to_pascal,
)
)
athlete = Athlete(firstName='John', lastName='Doe', sport='track')
print(athlete.model_dump(by_alias=True))
#> {'FirstName': 'John', 'LastName': 'Doe', 'Sport': 'track'}
```
Note:
Pydantic offers three built-in alias generators: [`to_pascal`][pydantic.alias_generators.to_pascal],
[`to_camel`][pydantic.alias_generators.to_camel], and [`to_snake`][pydantic.alias_generators.to_snake].
@@ -425,9 +315,9 @@ class ConfigDict(TypedDict, total=False):
"""
allow_inf_nan: bool
"""Whether to allow infinity (`+inf` an `-inf`) and NaN values to float and decimal fields. Defaults to `True`."""
"""Whether to allow infinity (`+inf` an `-inf`) and NaN values to float fields. Defaults to `True`."""
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None
"""A dict or callable to provide extra JSON schema properties. Defaults to `None`."""
json_encoders: dict[type[object], JsonEncoder] | None
@@ -452,7 +342,7 @@ class ConfigDict(TypedDict, total=False):
To configure strict mode for all fields on a model, you can set `strict=True` on the model.
```python
```py
from pydantic import BaseModel, ConfigDict
class Model(BaseModel):
@@ -480,14 +370,16 @@ class ConfigDict(TypedDict, total=False):
By default, model and dataclass instances are not revalidated during validation.
```python
```py
from typing import List
from pydantic import BaseModel
class User(BaseModel, revalidate_instances='never'): # (1)!
hobbies: list[str]
hobbies: List[str]
class SubUser(User):
sins: list[str]
sins: List[str]
class Transaction(BaseModel):
user: User
@@ -515,14 +407,16 @@ class ConfigDict(TypedDict, total=False):
If you want to revalidate instances during validation, you can set `revalidate_instances` to `'always'`
in the model's config.
```python
```py
from typing import List
from pydantic import BaseModel, ValidationError
class User(BaseModel, revalidate_instances='always'): # (1)!
hobbies: list[str]
hobbies: List[str]
class SubUser(User):
sins: list[str]
sins: List[str]
class Transaction(BaseModel):
user: User
@@ -556,14 +450,16 @@ class ConfigDict(TypedDict, total=False):
It's also possible to set `revalidate_instances` to `'subclass-instances'` to only revalidate instances
of subclasses of the model.
```python
```py
from typing import List
from pydantic import BaseModel
class User(BaseModel, revalidate_instances='subclass-instances'): # (1)!
hobbies: list[str]
hobbies: List[str]
class SubUser(User):
sins: list[str]
sins: List[str]
class Transaction(BaseModel):
user: User
@@ -598,33 +494,13 @@ class ConfigDict(TypedDict, total=False):
- `'float'` will serialize timedeltas to the total number of seconds.
"""
ser_json_bytes: Literal['utf8', 'base64', 'hex']
ser_json_bytes: Literal['utf8', 'base64']
"""
The encoding of JSON serialized bytes. Defaults to `'utf8'`.
Set equal to `val_json_bytes` to get back an equal value after serialization round trip.
The encoding of JSON serialized bytes. Accepts the string values of `'utf8'` and `'base64'`.
Defaults to `'utf8'`.
- `'utf8'` will serialize bytes to UTF-8 strings.
- `'base64'` will serialize bytes to URL safe base64 strings.
- `'hex'` will serialize bytes to hexadecimal strings.
"""
val_json_bytes: Literal['utf8', 'base64', 'hex']
"""
The encoding of JSON serialized bytes to decode. Defaults to `'utf8'`.
Set equal to `ser_json_bytes` to get back an equal value after serialization round trip.
- `'utf8'` will deserialize UTF-8 strings to bytes.
- `'base64'` will deserialize URL safe base64 strings to bytes.
- `'hex'` will deserialize hexadecimal strings to bytes.
"""
ser_json_inf_nan: Literal['null', 'constants', 'strings']
"""
The encoding of JSON serialized infinity and NaN float values. Defaults to `'null'`.
- `'null'` will serialize infinity and NaN values as `null`.
- `'constants'` will serialize infinity and NaN values as `Infinity` and `NaN`.
- `'strings'` will serialize infinity as string `"Infinity"` and NaN as string `"NaN"`.
"""
# whether to validate default values during validation, default False
@@ -632,26 +508,17 @@ class ConfigDict(TypedDict, total=False):
"""Whether to validate default values during validation. Defaults to `False`."""
validate_return: bool
"""Whether to validate the return value from call validators. Defaults to `False`."""
"""whether to validate the return value from call validators. Defaults to `False`."""
protected_namespaces: tuple[str | Pattern[str], ...]
protected_namespaces: tuple[str, ...]
"""
A `tuple` of strings and/or patterns that prevent models from having fields with names that conflict with them.
For strings, we match on a prefix basis. Ex, if 'dog' is in the protected namespace, 'dog_name' will be protected.
For patterns, we match on the entire field name. Ex, if `re.compile(r'^dog$')` is in the protected namespace, 'dog' will be protected, but 'dog_name' will not be.
Defaults to `('model_validate', 'model_dump',)`.
A `tuple` of strings that prevent model to have field which conflict with them.
Defaults to `('model_', )`).
The reason we've selected these is to prevent collisions with other validation / dumping formats
in the future - ex, `model_validate_{some_newly_supported_format}`.
Pydantic prevents collisions between model attributes and `BaseModel`'s own methods by
namespacing them with the prefix `model_`.
Before v2.10, Pydantic used `('model_',)` as the default value for this setting to
prevent collisions between model attributes and `BaseModel`'s own methods. This was changed
in v2.10 given feedback that this restriction was limiting in AI and data science contexts,
where it is common to have fields with names like `model_id`, `model_input`, `model_output`, etc.
For more details, see https://github.com/pydantic/pydantic/issues/10315.
```python
```py
import warnings
from pydantic import BaseModel
@@ -661,65 +528,56 @@ class ConfigDict(TypedDict, total=False):
try:
class Model(BaseModel):
model_dump_something: str
model_prefixed_field: str
except UserWarning as e:
print(e)
'''
Field "model_dump_something" in Model has conflict with protected namespace "model_dump".
Field "model_prefixed_field" has conflict with protected namespace "model_".
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('model_validate',)`.
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
'''
```
You can customize this behavior using the `protected_namespaces` setting:
```python {test="skip"}
import re
```py
import warnings
from pydantic import BaseModel, ConfigDict
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter('always') # Catch all warnings
warnings.filterwarnings('error') # Raise warnings as errors
try:
class Model(BaseModel):
safe_field: str
model_prefixed_field: str
also_protect_field: str
protect_this: str
model_config = ConfigDict(
protected_namespaces=(
'protect_me_',
'also_protect_',
re.compile('^protect_this$'),
)
protected_namespaces=('protect_me_', 'also_protect_')
)
for warning in caught_warnings:
print(f'{warning.message}')
except UserWarning as e:
print(e)
'''
Field "also_protect_field" in Model has conflict with protected namespace "also_protect_".
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_', re.compile('^protect_this$'))`.
Field "also_protect_field" has conflict with protected namespace "also_protect_".
Field "protect_this" in Model has conflict with protected namespace "re.compile('^protect_this$')".
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_', 'also_protect_')`.
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_',)`.
'''
```
While Pydantic will only emit a warning when an item is in a protected namespace but does not actually have a collision,
an error _is_ raised if there is an actual collision with an existing attribute:
```python
from pydantic import BaseModel, ConfigDict
```py
from pydantic import BaseModel
try:
class Model(BaseModel):
model_validate: str
model_config = ConfigDict(protected_namespaces=('model_',))
except NameError as e:
print(e)
'''
@@ -734,7 +592,7 @@ class ConfigDict(TypedDict, total=False):
Pydantic shows the input value and type when it raises `ValidationError` during the validation.
```python
```py
from pydantic import BaseModel, ValidationError
class Model(BaseModel):
@@ -753,7 +611,7 @@ class ConfigDict(TypedDict, total=False):
You can hide the input value and type by setting the `hide_input_in_errors` config to `True`.
```python
```py
from pydantic import BaseModel, ConfigDict, ValidationError
class Model(BaseModel):
@@ -774,26 +632,27 @@ class ConfigDict(TypedDict, total=False):
defer_build: bool
"""
Whether to defer model validator and serializer construction until the first model validation. Defaults to False.
Whether to defer model validator and serializer construction until the first model validation.
This can be useful to avoid the overhead of building models which are only
used nested within other models, or when you want to manually define type namespace via
[`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild].
Since v2.10, this setting also applies to pydantic dataclasses and TypeAdapter instances.
[`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild]. Defaults to False.
"""
plugin_settings: dict[str, object] | None
"""A `dict` of settings for plugins. Defaults to `None`."""
"""A `dict` of settings for plugins. Defaults to `None`.
See [Pydantic Plugins](../concepts/plugins.md) for details.
"""
schema_generator: type[_GenerateSchema] | None
"""
!!! warning
`schema_generator` is deprecated in v2.10.
A custom core schema generator class to use when generating JSON schemas.
Useful if you want to change the way types are validated across an entire model/schema. Defaults to `None`.
Prior to v2.10, this setting was advertised as highly subject to change.
It's possible that this interface may once again become public once the internal core schema generation
API is more stable, but that will likely come after significant performance improvements have been made.
The `GenerateSchema` interface is subject to change, currently only the `string_schema` method is public.
See [#6737](https://github.com/pydantic/pydantic/pull/6737) for details.
"""
json_schema_serialization_defaults_required: bool
@@ -807,7 +666,7 @@ class ConfigDict(TypedDict, total=False):
between validation and serialization, and don't mind fields with defaults being marked as not required during
serialization. See [#7209](https://github.com/pydantic/pydantic/issues/7209) for more details.
```python
```py
from pydantic import BaseModel, ConfigDict
class Model(BaseModel):
@@ -850,7 +709,7 @@ class ConfigDict(TypedDict, total=False):
the validation and serialization schemas (since both will use the specified schema), and so prevents the suffixes
from being added to the definition references.
```python
```py
from pydantic import BaseModel, ConfigDict, Json
class Model(BaseModel):
@@ -896,7 +755,7 @@ class ConfigDict(TypedDict, total=False):
Pydantic doesn't allow number types (`int`, `float`, `Decimal`) to be coerced as type `str` by default.
```python
```py
from decimal import Decimal
from pydantic import BaseModel, ConfigDict, ValidationError
@@ -928,286 +787,5 @@ class ConfigDict(TypedDict, total=False):
```
"""
regex_engine: Literal['rust-regex', 'python-re']
"""
The regex engine to be used for pattern validation.
Defaults to `'rust-regex'`.
- `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust crate,
which is non-backtracking and therefore more DDoS resistant, but does not support all regex features.
- `python-re` use the [`re`](https://docs.python.org/3/library/re.html) module,
which supports all regex features, but may be slower.
!!! note
If you use a compiled regex pattern, the python-re engine will be used regardless of this setting.
This is so that flags such as `re.IGNORECASE` are respected.
```python
from pydantic import BaseModel, ConfigDict, Field, ValidationError
class Model(BaseModel):
model_config = ConfigDict(regex_engine='python-re')
value: str = Field(pattern=r'^abc(?=def)')
print(Model(value='abcdef').value)
#> abcdef
try:
print(Model(value='abxyzcdef'))
except ValidationError as e:
print(e)
'''
1 validation error for Model
value
String should match pattern '^abc(?=def)' [type=string_pattern_mismatch, input_value='abxyzcdef', input_type=str]
'''
```
"""
validation_error_cause: bool
"""
If `True`, Python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`.
Note:
Python 3.10 and older don't support exception groups natively. <=3.10, backport must be installed: `pip install exceptiongroup`.
Note:
The structure of validation errors are likely to change in future Pydantic versions. Pydantic offers no guarantees about their structure. Should be used for visual traceback debugging only.
"""
use_attribute_docstrings: bool
'''
Whether docstrings of attributes (bare string literals immediately following the attribute declaration)
should be used for field descriptions. Defaults to `False`.
Available in Pydantic v2.7+.
```python
from pydantic import BaseModel, ConfigDict, Field
class Model(BaseModel):
model_config = ConfigDict(use_attribute_docstrings=True)
x: str
"""
Example of an attribute docstring
"""
y: int = Field(description="Description in Field")
"""
Description in Field overrides attribute docstring
"""
print(Model.model_fields["x"].description)
# > Example of an attribute docstring
print(Model.model_fields["y"].description)
# > Description in Field
```
This requires the source code of the class to be available at runtime.
!!! warning "Usage with `TypedDict` and stdlib dataclasses"
Due to current limitations, attribute docstrings detection may not work as expected when using
[`TypedDict`][typing.TypedDict] and stdlib dataclasses, in particular when:
- inheritance is being used.
- multiple classes have the same name in the same source file.
'''
cache_strings: bool | Literal['all', 'keys', 'none']
"""
Whether to cache strings to avoid constructing new Python objects. Defaults to True.
Enabling this setting should significantly improve validation performance while increasing memory usage slightly.
- `True` or `'all'` (the default): cache all strings
- `'keys'`: cache only dictionary keys
- `False` or `'none'`: no caching
!!! note
`True` or `'all'` is required to cache strings during general validation because
validators don't know if they're in a key or a value.
!!! tip
If repeated strings are rare, it's recommended to use `'keys'` or `'none'` to reduce memory usage,
as the performance difference is minimal if repeated strings are rare.
"""
validate_by_alias: bool
"""
Whether an aliased field may be populated by its alias. Defaults to `True`.
!!! note
In v2.11, `validate_by_alias` was introduced in conjunction with [`validate_by_name`][pydantic.ConfigDict.validate_by_name]
to empower users with more fine grained validation control. In <v2.11, disabling validation by alias was not possible.
Here's an example of disabling validation by alias:
```py
from pydantic import BaseModel, ConfigDict, Field
class Model(BaseModel):
model_config = ConfigDict(validate_by_name=True, validate_by_alias=False)
my_field: str = Field(validation_alias='my_alias') # (1)!
m = Model(my_field='foo') # (2)!
print(m)
#> my_field='foo'
```
1. The field `'my_field'` has an alias `'my_alias'`.
2. The model can only be populated by the attribute name `'my_field'`.
!!! warning
You cannot set both `validate_by_alias` and `validate_by_name` to `False`.
This would make it impossible to populate an attribute.
See [usage errors](../errors/usage_errors.md#validate-by-alias-and-name-false) for an example.
If you set `validate_by_alias` to `False`, under the hood, Pydantic dynamically sets
`validate_by_name` to `True` to ensure that validation can still occur.
"""
validate_by_name: bool
"""
Whether an aliased field may be populated by its name as given by the model
attribute. Defaults to `False`.
!!! note
In v2.0-v2.10, the `populate_by_name` configuration setting was used to specify
whether or not a field could be populated by its name **and** alias.
In v2.11, `validate_by_name` was introduced in conjunction with [`validate_by_alias`][pydantic.ConfigDict.validate_by_alias]
to empower users with more fine grained validation behavior control.
```python
from pydantic import BaseModel, ConfigDict, Field
class Model(BaseModel):
model_config = ConfigDict(validate_by_name=True, validate_by_alias=True)
my_field: str = Field(validation_alias='my_alias') # (1)!
m = Model(my_alias='foo') # (2)!
print(m)
#> my_field='foo'
m = Model(my_field='foo') # (3)!
print(m)
#> my_field='foo'
```
1. The field `'my_field'` has an alias `'my_alias'`.
2. The model is populated by the alias `'my_alias'`.
3. The model is populated by the attribute name `'my_field'`.
!!! warning
You cannot set both `validate_by_alias` and `validate_by_name` to `False`.
This would make it impossible to populate an attribute.
See [usage errors](../errors/usage_errors.md#validate-by-alias-and-name-false) for an example.
"""
serialize_by_alias: bool
"""
Whether an aliased field should be serialized by its alias. Defaults to `False`.
Note: In v2.11, `serialize_by_alias` was introduced to address the
[popular request](https://github.com/pydantic/pydantic/issues/8379)
for consistency with alias behavior for validation and serialization settings.
In v3, the default value is expected to change to `True` for consistency with the validation default.
```python
from pydantic import BaseModel, ConfigDict, Field
class Model(BaseModel):
model_config = ConfigDict(serialize_by_alias=True)
my_field: str = Field(serialization_alias='my_alias') # (1)!
m = Model(my_field='foo')
print(m.model_dump()) # (2)!
#> {'my_alias': 'foo'}
```
1. The field `'my_field'` has an alias `'my_alias'`.
2. The model is serialized using the alias `'my_alias'` for the `'my_field'` attribute.
"""
_TypeT = TypeVar('_TypeT', bound=type)
@overload
@deprecated('Passing `config` as a keyword argument is deprecated. Pass `config` as a positional argument instead.')
def with_config(*, config: ConfigDict) -> Callable[[_TypeT], _TypeT]: ...
@overload
def with_config(config: ConfigDict, /) -> Callable[[_TypeT], _TypeT]: ...
@overload
def with_config(**config: Unpack[ConfigDict]) -> Callable[[_TypeT], _TypeT]: ...
def with_config(config: ConfigDict | None = None, /, **kwargs: Any) -> Callable[[_TypeT], _TypeT]:
"""!!! abstract "Usage Documentation"
[Configuration with other types](../concepts/config.md#configuration-on-other-supported-types)
A convenience decorator to set a [Pydantic configuration](config.md) on a `TypedDict` or a `dataclass` from the standard library.
Although the configuration can be set using the `__pydantic_config__` attribute, it does not play well with type checkers,
especially with `TypedDict`.
!!! example "Usage"
```python
from typing_extensions import TypedDict
from pydantic import ConfigDict, TypeAdapter, with_config
@with_config(ConfigDict(str_to_lower=True))
class TD(TypedDict):
x: str
ta = TypeAdapter(TD)
print(ta.validate_python({'x': 'ABC'}))
#> {'x': 'abc'}
```
"""
if config is not None and kwargs:
raise ValueError('Cannot specify both `config` and keyword arguments')
if len(kwargs) == 1 and (kwargs_conf := kwargs.get('config')) is not None:
warnings.warn(
'Passing `config` as a keyword argument is deprecated. Pass `config` as a positional argument instead',
category=PydanticDeprecatedSince211,
stacklevel=2,
)
final_config = cast(ConfigDict, kwargs_conf)
else:
final_config = config if config is not None else cast(ConfigDict, kwargs)
def inner(class_: _TypeT, /) -> _TypeT:
# Ideally, we would check for `class_` to either be a `TypedDict` or a stdlib dataclass.
# However, the `@with_config` decorator can be applied *after* `@dataclass`. To avoid
# common mistakes, we at least check for `class_` to not be a Pydantic model.
from ._internal._utils import is_model_class
if is_model_class(class_):
raise PydanticUserError(
f'Cannot use `with_config` on {class_.__name__} as it is a Pydantic model',
code='with-config-on-model',
)
class_.__pydantic_config__ = final_config
return class_
return inner
__getattr__ = getattr_migration(__name__)

View File

@@ -1,25 +1,21 @@
"""Provide an enhanced dataclass that performs validation."""
from __future__ import annotations as _annotations
import dataclasses
import sys
import types
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, TypeVar, overload
from warnings import warn
from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload
from typing_extensions import TypeGuard, dataclass_transform
from typing_extensions import Literal, TypeGuard, dataclass_transform
from ._internal import _config, _decorators, _namespace_utils, _typing_extra
from ._internal import _config, _decorators, _typing_extra
from ._internal import _dataclasses as _pydantic_dataclasses
from ._migration import getattr_migration
from .config import ConfigDict
from .errors import PydanticUserError
from .fields import Field, FieldInfo, PrivateAttr
from .fields import Field
if TYPE_CHECKING:
from ._internal._dataclasses import PydanticDataclass
from ._internal._namespace_utils import MappingNamespace
__all__ = 'dataclass', 'rebuild_dataclass'
@@ -27,7 +23,7 @@ _T = TypeVar('_T')
if sys.version_info >= (3, 10):
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
*,
@@ -44,7 +40,7 @@ if sys.version_info >= (3, 10):
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
...
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
_cls: type[_T], # type: ignore
@@ -54,16 +50,17 @@ if sys.version_info >= (3, 10):
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool | None = None,
frozen: bool = False,
config: ConfigDict | type[object] | None = None,
validate_on_init: bool | None = None,
kw_only: bool = ...,
slots: bool = ...,
) -> type[PydanticDataclass]: ...
) -> type[PydanticDataclass]:
...
else:
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
*,
@@ -72,13 +69,13 @@ else:
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool | None = None,
frozen: bool = False,
config: ConfigDict | type[object] | None = None,
validate_on_init: bool | None = None,
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
...
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload
def dataclass(
_cls: type[_T], # type: ignore
@@ -88,13 +85,14 @@ else:
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool | None = None,
frozen: bool = False,
config: ConfigDict | type[object] | None = None,
validate_on_init: bool | None = None,
) -> type[PydanticDataclass]: ...
) -> type[PydanticDataclass]:
...
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass(
_cls: type[_T] | None = None,
*,
@@ -103,14 +101,13 @@ def dataclass(
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool | None = None,
frozen: bool = False,
config: ConfigDict | type[object] | None = None,
validate_on_init: bool | None = None,
kw_only: bool = False,
slots: bool = False,
) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]:
"""!!! abstract "Usage Documentation"
[`dataclasses`](../concepts/dataclasses.md)
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/dataclasses/
A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`,
but with added validation.
@@ -122,13 +119,13 @@ def dataclass(
init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to
`dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its
own `__init__` function.
repr: A boolean indicating whether to include the field in the `__repr__` output.
eq: Determines if a `__eq__` method should be generated for the class.
repr: A boolean indicating whether or not to include the field in the `__repr__` output.
eq: Determines if a `__eq__` should be generated for the class.
order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`.
unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`.
unsafe_hash: Determines if an unsafe hashing function should be included in the class.
frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its
attributes to be modified after it has been initialized. If not set, the value from the provided `config` argument will be used (and will default to `False` otherwise).
config: The Pydantic config to use for the `dataclass`.
attributes to be modified from its constructor.
config: A configuration for the `dataclass` generation.
validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses
are validated on init.
kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`.
@@ -145,43 +142,10 @@ def dataclass(
assert validate_on_init is not False, 'validate_on_init=False is no longer supported'
if sys.version_info >= (3, 10):
kwargs = {'kw_only': kw_only, 'slots': slots}
kwargs = dict(kw_only=kw_only, slots=slots)
else:
kwargs = {}
def make_pydantic_fields_compatible(cls: type[Any]) -> None:
"""Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only`
To do that, we simply change
`x: int = pydantic.Field(..., kw_only=True)`
into
`x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
"""
for annotation_cls in cls.__mro__:
annotations: dict[str, Any] = getattr(annotation_cls, '__annotations__', {})
for field_name in annotations:
field_value = getattr(cls, field_name, None)
# Process only if this is an instance of `FieldInfo`.
if not isinstance(field_value, FieldInfo):
continue
# Initialize arguments for the standard `dataclasses.field`.
field_args: dict = {'default': field_value}
# Handle `kw_only` for Python 3.10+
if sys.version_info >= (3, 10) and field_value.kw_only:
field_args['kw_only'] = True
# Set `repr` attribute if it's explicitly specified to be not `True`.
if field_value.repr is not True:
field_args['repr'] = field_value.repr
setattr(cls, field_name, dataclasses.field(**field_args))
# In Python 3.9, when subclassing, information is pulled from cls.__dict__['__annotations__']
# for annotations, so we must make sure it's initialized before we add to it.
if cls.__dict__.get('__annotations__') is None:
cls.__annotations__ = {}
cls.__annotations__[field_name] = annotations[field_name]
def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
"""Create a Pydantic dataclass from a regular dataclass.
@@ -191,29 +155,14 @@ def dataclass(
Returns:
A Pydantic dataclass.
"""
from ._internal._utils import is_model_class
if is_model_class(cls):
raise PydanticUserError(
f'Cannot create a Pydantic dataclass from {cls.__name__} as it is already a Pydantic model',
code='dataclass-on-model',
)
original_cls = cls
# we warn on conflicting config specifications, but only if the class doesn't have a dataclass base
# because a dataclass base might provide a __pydantic_config__ attribute that we don't want to warn about
has_dataclass_base = any(dataclasses.is_dataclass(base) for base in cls.__bases__)
if not has_dataclass_base and config is not None and hasattr(cls, '__pydantic_config__'):
warn(
f'`config` is set via both the `dataclass` decorator and `__pydantic_config__` for dataclass {cls.__name__}. '
f'The `config` specification from `dataclass` decorator will take priority.',
category=UserWarning,
stacklevel=2,
)
# if config is not explicitly provided, try to read it from the type
config_dict = config if config is not None else getattr(cls, '__pydantic_config__', None)
config_dict = config
if config_dict is None:
# if not explicitly provided, read from the type
cls_config = getattr(cls, '__pydantic_config__', None)
if cls_config is not None:
config_dict = cls_config
config_wrapper = _config.ConfigWrapper(config_dict)
decorators = _decorators.DecoratorInfos.build(cls)
@@ -236,22 +185,6 @@ def dataclass(
bases = bases + (generic_base,)
cls = types.new_class(cls.__name__, bases)
make_pydantic_fields_compatible(cls)
# Respect frozen setting from dataclass constructor and fallback to config setting if not provided
if frozen is not None:
frozen_ = frozen
if config_wrapper.frozen:
# It's not recommended to define both, as the setting from the dataclass decorator will take priority.
warn(
f'`frozen` is set via both the `dataclass` decorator and `config` for dataclass {cls.__name__!r}.'
'This is not recommended. The `frozen` specification on `dataclass` will take priority.',
category=UserWarning,
stacklevel=2,
)
else:
frozen_ = config_wrapper.frozen or False
cls = dataclasses.dataclass( # type: ignore[call-overload]
cls,
# the value of init here doesn't affect anything except that it makes it easier to generate a signature
@@ -260,40 +193,29 @@ def dataclass(
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen_,
frozen=frozen,
**kwargs,
)
# This is an undocumented attribute to distinguish stdlib/Pydantic dataclasses.
# It should be set as early as possible:
cls.__is_pydantic_dataclass__ = True
cls.__pydantic_decorators__ = decorators # type: ignore
cls.__doc__ = original_doc
cls.__module__ = original_cls.__module__
cls.__qualname__ = original_cls.__qualname__
cls.__pydantic_fields_complete__ = classmethod(_pydantic_fields_complete)
cls.__pydantic_complete__ = False # `complete_dataclass` will set it to `True` if successful.
# TODO `parent_namespace` is currently None, but we could do the same thing as Pydantic models:
# fetch the parent ns using `parent_frame_namespace` (if the dataclass was defined in a function),
# and possibly cache it (see the `__pydantic_parent_namespace__` logic for models).
_pydantic_dataclasses.complete_dataclass(cls, config_wrapper, raise_errors=False)
pydantic_complete = _pydantic_dataclasses.complete_dataclass(
cls, config_wrapper, raise_errors=False, types_namespace=None
)
cls.__pydantic_complete__ = pydantic_complete # type: ignore
return cls
return create_dataclass if _cls is None else create_dataclass(_cls)
if _cls is None:
return create_dataclass
def _pydantic_fields_complete(cls: type[PydanticDataclass]) -> bool:
"""Return whether the fields where successfully collected (i.e. type hints were successfully resolves).
This is a private property, not meant to be used outside Pydantic.
"""
return all(field_info._complete for field_info in cls.__pydantic_fields__.values())
return create_dataclass(_cls)
__getattr__ = getattr_migration(__name__)
if sys.version_info < (3, 11):
if (3, 8) <= sys.version_info < (3, 11):
# Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints
# Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable.
@@ -313,7 +235,7 @@ def rebuild_dataclass(
force: bool = False,
raise_errors: bool = True,
_parent_namespace_depth: int = 2,
_types_namespace: MappingNamespace | None = None,
_types_namespace: dict[str, Any] | None = None,
) -> bool | None:
"""Try to rebuild the pydantic-core schema for the dataclass.
@@ -323,8 +245,8 @@ def rebuild_dataclass(
This is analogous to `BaseModel.model_rebuild`.
Args:
cls: The class to rebuild the pydantic-core schema for.
force: Whether to force the rebuilding of the schema, defaults to `False`.
cls: The class to build the dataclass core schema for.
force: Whether to force the rebuilding of the model schema, defaults to `False`.
raise_errors: Whether to raise errors, defaults to `True`.
_parent_namespace_depth: The depth level of the parent namespace, defaults to 2.
_types_namespace: The types namespace, defaults to `None`.
@@ -335,49 +257,34 @@ def rebuild_dataclass(
"""
if not force and cls.__pydantic_complete__:
return None
for attr in ('__pydantic_core_schema__', '__pydantic_validator__', '__pydantic_serializer__'):
if attr in cls.__dict__:
# Deleting the validator/serializer is necessary as otherwise they can get reused in
# pydantic-core. Same applies for the core schema that can be reused in schema generation.
delattr(cls, attr)
cls.__pydantic_complete__ = False
if _types_namespace is not None:
rebuild_ns = _types_namespace
elif _parent_namespace_depth > 0:
rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
else:
rebuild_ns = {}
if _types_namespace is not None:
types_namespace: dict[str, Any] | None = _types_namespace.copy()
else:
if _parent_namespace_depth > 0:
frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {}
# Note: we may need to add something similar to cls.__pydantic_parent_namespace__ from BaseModel
# here when implementing handling of recursive generics. See BaseModel.model_rebuild for reference.
types_namespace = frame_parent_ns
else:
types_namespace = {}
ns_resolver = _namespace_utils.NsResolver(
parent_namespace=rebuild_ns,
)
return _pydantic_dataclasses.complete_dataclass(
cls,
_config.ConfigWrapper(cls.__pydantic_config__, check=False),
raise_errors=raise_errors,
ns_resolver=ns_resolver,
# We could provide a different config instead (with `'defer_build'` set to `True`)
# of this explicit `_force_build` argument, but because config can come from the
# decorator parameter or the `__pydantic_config__` attribute, `complete_dataclass`
# will overwrite `__pydantic_config__` with the provided config above:
_force_build=True,
)
types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace)
return _pydantic_dataclasses.complete_dataclass(
cls,
_config.ConfigWrapper(cls.__pydantic_config__, check=False),
raise_errors=raise_errors,
types_namespace=types_namespace,
)
def is_pydantic_dataclass(class_: type[Any], /) -> TypeGuard[type[PydanticDataclass]]:
def is_pydantic_dataclass(__cls: type[Any]) -> TypeGuard[type[PydanticDataclass]]:
"""Whether a class is a pydantic dataclass.
Args:
class_: The class.
__cls: The class.
Returns:
`True` if the class is a pydantic dataclass, `False` otherwise.
"""
try:
return '__is_pydantic_dataclass__' in class_.__dict__ and dataclasses.is_dataclass(class_)
except AttributeError:
return False
return dataclasses.is_dataclass(__cls) and '__pydantic_validator__' in __cls.__dict__

View File

@@ -1,5 +1,4 @@
"""The `datetime_parse` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,5 +1,4 @@
"""The `decorator` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -4,10 +4,10 @@ from __future__ import annotations as _annotations
from functools import partial, partialmethod
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
from warnings import warn
from typing_extensions import Protocol, TypeAlias, deprecated
from typing_extensions import Literal, Protocol, TypeAlias
from .._internal import _decorators, _decorators_v1
from ..errors import PydanticUserError
@@ -19,24 +19,30 @@ _ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored;
if TYPE_CHECKING:
class _OnlyValueValidatorClsMethod(Protocol):
def __call__(self, __cls: Any, __value: Any) -> Any: ...
def __call__(self, __cls: Any, __value: Any) -> Any:
...
class _V1ValidatorWithValuesClsMethod(Protocol):
def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: ...
def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any:
...
class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol):
def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: ...
def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any:
...
class _V1ValidatorWithKwargsClsMethod(Protocol):
def __call__(self, __cls: Any, **kwargs: Any) -> Any: ...
def __call__(self, __cls: Any, **kwargs: Any) -> Any:
...
class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol):
def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any:
...
class _V1RootValidatorClsMethod(Protocol):
def __call__(
self, __cls: Any, __values: _decorators_v1.RootValidatorValues
) -> _decorators_v1.RootValidatorValues: ...
) -> _decorators_v1.RootValidatorValues:
...
V1Validator = Union[
_OnlyValueValidatorClsMethod,
@@ -73,12 +79,6 @@ else:
DeprecationWarning = PydanticDeprecatedSince20
@deprecated(
'Pydantic V1 style `@validator` validators are deprecated.'
' You should migrate to Pydantic V2 style `@field_validator` validators,'
' see the migration guide for more details',
category=None,
)
def validator(
__field: str,
*fields: str,
@@ -94,7 +94,7 @@ def validator(
__field (str): The first field the validator should be called on; this is separate
from `fields` to ensure an error is raised if you don't pass at least one.
*fields (str): Additional field(s) the validator should be called on.
pre (bool, optional): Whether this validator should be called before the standard
pre (bool, optional): Whether or not this validator should be called before the standard
validators (else after). Defaults to False.
each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate
individual elements rather than the whole object. Defaults to False.
@@ -109,6 +109,22 @@ def validator(
Callable: A decorator that can be used to decorate a
function to be used as a validator.
"""
if allow_reuse is True: # pragma: no cover
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
fields = tuple((__field, *fields))
if isinstance(fields[0], FunctionType):
raise PydanticUserError(
"`@validator` should be used with fields and keyword arguments, not bare. "
"E.g. usage should be `@validator('<field_name>', ...)`",
code='validator-no-fields',
)
elif not all(isinstance(field, str) for field in fields):
raise PydanticUserError(
"`@validator` fields should be passed as separate string args. "
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
code='validator-invalid-fields',
)
warn(
'Pydantic V1 style `@validator` validators are deprecated.'
' You should migrate to Pydantic V2 style `@field_validator` validators,'
@@ -117,22 +133,6 @@ def validator(
stacklevel=2,
)
if allow_reuse is True: # pragma: no cover
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
fields = __field, *fields
if isinstance(fields[0], FunctionType):
raise PydanticUserError(
'`@validator` should be used with fields and keyword arguments, not bare. '
"E.g. usage should be `@validator('<field_name>', ...)`",
code='validator-no-fields',
)
elif not all(isinstance(field, str) for field in fields):
raise PydanticUserError(
'`@validator` fields should be passed as separate string args. '
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
code='validator-invalid-fields',
)
mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
@@ -162,10 +162,8 @@ def root_validator(
# which means you need to specify `skip_on_failure=True`
skip_on_failure: Literal[True],
allow_reuse: bool = ...,
) -> Callable[
[_V1RootValidatorFunctionType],
_V1RootValidatorFunctionType,
]: ...
) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]:
...
@overload
@@ -175,10 +173,8 @@ def root_validator(
# `skip_on_failure`, in fact it is not allowed as an argument!
pre: Literal[True],
allow_reuse: bool = ...,
) -> Callable[
[_V1RootValidatorFunctionType],
_V1RootValidatorFunctionType,
]: ...
) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]:
...
@overload
@@ -189,18 +185,10 @@ def root_validator(
pre: Literal[False],
skip_on_failure: Literal[True],
allow_reuse: bool = ...,
) -> Callable[
[_V1RootValidatorFunctionType],
_V1RootValidatorFunctionType,
]: ...
) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]:
...
@deprecated(
'Pydantic V1 style `@root_validator` validators are deprecated.'
' You should migrate to Pydantic V2 style `@model_validator` validators,'
' see the migration guide for more details',
category=None,
)
def root_validator(
*__args,
pre: bool = False,

View File

@@ -1,9 +1,9 @@
from __future__ import annotations as _annotations
import warnings
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any
from typing_extensions import deprecated
from typing_extensions import Literal, deprecated
from .._internal import _config
from ..warnings import PydanticDeprecatedSince20
@@ -18,10 +18,10 @@ __all__ = 'BaseConfig', 'Extra'
class _ConfigMetaclass(type):
def __getattr__(self, item: str) -> Any:
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
try:
obj = _config.config_defaults[item]
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
return obj
return _config.config_defaults[item]
except KeyError as exc:
raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc
@@ -35,10 +35,9 @@ class BaseConfig(metaclass=_ConfigMetaclass):
"""
def __getattr__(self, item: str) -> Any:
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
try:
obj = super().__getattribute__(item)
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
return obj
return super().__getattribute__(item)
except AttributeError as exc:
try:
return getattr(type(self), item)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
import typing
from copy import deepcopy
from enum import Enum
from typing import Any
from typing import Any, Tuple
import typing_extensions
@@ -18,7 +18,7 @@ if typing.TYPE_CHECKING:
from .._internal._utils import AbstractSetIntStr, MappingIntStrAny
AnyClassMethod = classmethod[Any, Any, Any]
TupleGenerator = typing.Generator[tuple[str, Any], None, None]
TupleGenerator = typing.Generator[Tuple[str, Any], None, None]
Model = typing.TypeVar('Model', bound='BaseModel')
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
@@ -40,11 +40,11 @@ def _iter(
# The extra "is not None" guards are not logically necessary but optimizes performance for the simple case.
if exclude is not None:
exclude = _utils.ValueItems.merge(
{k: v.exclude for k, v in self.__pydantic_fields__.items() if v.exclude is not None}, exclude
{k: v.exclude for k, v in self.model_fields.items() if v.exclude is not None}, exclude
)
if include is not None:
include = _utils.ValueItems.merge({k: True for k in self.__pydantic_fields__}, include, intersect=True)
include = _utils.ValueItems.merge({k: True for k in self.model_fields}, include, intersect=True)
allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset) # type: ignore
if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none):
@@ -68,15 +68,15 @@ def _iter(
if exclude_defaults:
try:
field = self.__pydantic_fields__[field_key]
field = self.model_fields[field_key]
except KeyError:
pass
else:
if not field.is_required() and field.default == v:
continue
if by_alias and field_key in self.__pydantic_fields__:
dict_key = self.__pydantic_fields__[field_key].alias or field_key
if by_alias and field_key in self.model_fields:
dict_key = self.model_fields[field_key].alias or field_key
else:
dict_key = field_key
@@ -200,7 +200,7 @@ def _calculate_keys(
include: MappingIntStrAny | None,
exclude: MappingIntStrAny | None,
exclude_unset: bool,
update: dict[str, Any] | None = None, # noqa UP006
update: typing.Dict[str, Any] | None = None, # noqa UP006
) -> typing.AbstractSet[str] | None:
if include is None and exclude is None and exclude_unset is False:
return None

View File

@@ -1,7 +1,6 @@
import warnings
from collections.abc import Mapping
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
from typing_extensions import deprecated
@@ -23,29 +22,29 @@ if TYPE_CHECKING:
AnyCallable = Callable[..., Any]
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
ConfigType = Union[None, type[Any], dict[str, Any]]
ConfigType = Union[None, Type[Any], Dict[str, Any]]
@overload
def validate_arguments(
func: None = None, *, config: 'ConfigType' = None
) -> Callable[['AnyCallableT'], 'AnyCallableT']: ...
@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': ...
@deprecated(
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
category=None,
'The `validate_arguments` method is deprecated; use `validate_call` instead.', category=PydanticDeprecatedSince20
)
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
...
@overload
@deprecated(
'The `validate_arguments` method is deprecated; use `validate_call` instead.', category=PydanticDeprecatedSince20
)
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
...
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
"""Decorator to validate the arguments passed to a function."""
warnings.warn(
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
PydanticDeprecatedSince20,
stacklevel=2,
'The `validate_arguments` method is deprecated; use `validate_call` instead.', DeprecationWarning, stacklevel=2
)
def validate(_func: 'AnyCallable') -> 'AnyCallable':
@@ -87,7 +86,7 @@ class ValidatedFunction:
)
self.raw_function = function
self.arg_mapping: dict[int, str] = {}
self.arg_mapping: Dict[int, str] = {}
self.positional_only_args: set[str] = set()
self.v_args_name = 'args'
self.v_kwargs_name = 'kwargs'
@@ -95,7 +94,7 @@ class ValidatedFunction:
type_hints = _typing_extra.get_type_hints(function, include_extras=True)
takes_args = False
takes_kwargs = False
fields: dict[str, tuple[Any, Any]] = {}
fields: Dict[str, Tuple[Any, Any]] = {}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
@@ -106,22 +105,22 @@ class ValidatedFunction:
if p.kind == Parameter.POSITIONAL_ONLY:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_POSITIONAL_ONLY_NAME] = list[str], None
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
self.positional_only_args.add(name)
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_DUPLICATE_KWARGS] = list[str], None
fields[V_DUPLICATE_KWARGS] = List[str], None
elif p.kind == Parameter.KEYWORD_ONLY:
fields[name] = annotation, default
elif p.kind == Parameter.VAR_POSITIONAL:
self.v_args_name = name
fields[name] = tuple[annotation, ...], None
fields[name] = Tuple[annotation, ...], None
takes_args = True
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
self.v_kwargs_name = name
fields[name] = dict[str, annotation], None
fields[name] = Dict[str, annotation], None
takes_kwargs = True
# these checks avoid a clash between "args" and a field with that name
@@ -134,11 +133,11 @@ class ValidatedFunction:
if not takes_args:
# we add the field so validation below can raise the correct exception
fields[self.v_args_name] = list[Any], None
fields[self.v_args_name] = List[Any], None
if not takes_kwargs:
# same with kwargs
fields[self.v_kwargs_name] = dict[Any, Any], None
fields[self.v_kwargs_name] = Dict[Any, Any], None
self.create_model(fields, takes_args, takes_kwargs, config)
@@ -150,8 +149,8 @@ class ValidatedFunction:
m = self.init_model_instance(*args, **kwargs)
return self.execute(m)
def build_values(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
values: dict[str, Any] = {}
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
values: Dict[str, Any] = {}
if args:
arg_iter = enumerate(args)
while True:
@@ -166,15 +165,15 @@ class ValidatedFunction:
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
break
var_kwargs: dict[str, Any] = {}
var_kwargs: Dict[str, Any] = {}
wrong_positional_args = []
duplicate_kwargs = []
fields_alias = [
field.alias
for name, field in self.model.__pydantic_fields__.items()
for name, field in self.model.model_fields.items()
if name not in (self.v_args_name, self.v_kwargs_name)
]
non_var_fields = set(self.model.__pydantic_fields__) - {self.v_args_name, self.v_kwargs_name}
non_var_fields = set(self.model.model_fields) - {self.v_args_name, self.v_kwargs_name}
for k, v in kwargs.items():
if k in non_var_fields or k in fields_alias:
if k in self.positional_only_args:
@@ -194,15 +193,11 @@ class ValidatedFunction:
return values
def execute(self, m: BaseModel) -> Any:
d = {
k: v
for k, v in m.__dict__.items()
if k in m.__pydantic_fields_set__ or m.__pydantic_fields__[k].default_factory
}
d = {k: v for k, v in m.__dict__.items() if k in m.__pydantic_fields_set__ or m.model_fields[k].default_factory}
var_kwargs = d.pop(self.v_kwargs_name, {})
if self.v_args_name in d:
args_: list[Any] = []
args_: List[Any] = []
in_kwargs = False
kwargs = {}
for name, value in d.items():
@@ -226,7 +221,7 @@ class ValidatedFunction:
else:
return self.raw_function(**d, **var_kwargs)
def create_model(self, fields: dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
pos_args = len(self.arg_mapping)
config_wrapper = _config.ConfigWrapper(config)
@@ -243,7 +238,7 @@ class ValidatedFunction:
class DecoratorBaseModel(BaseModel):
@field_validator(self.v_args_name, check_fields=False)
@classmethod
def check_args(cls, v: Optional[list[Any]]) -> Optional[list[Any]]:
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
if takes_args or v is None:
return v
@@ -251,7 +246,7 @@ class ValidatedFunction:
@field_validator(self.v_kwargs_name, check_fields=False)
@classmethod
def check_kwargs(cls, v: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if takes_kwargs or v is None:
return v
@@ -261,7 +256,7 @@ class ValidatedFunction:
@field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
@classmethod
def check_positional_only(cls, v: Optional[list[str]]) -> None:
def check_positional_only(cls, v: Optional[List[str]]) -> None:
if v is None:
return
@@ -271,7 +266,7 @@ class ValidatedFunction:
@field_validator(V_DUPLICATE_KWARGS, check_fields=False)
@classmethod
def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None:
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
if v is None:
return

View File

@@ -7,12 +7,11 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6
from pathlib import Path
from re import Pattern
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Callable, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
from uuid import UUID
from typing_extensions import deprecated
from .._internal._import_utils import import_cached_base_model
from ..color import Color
from ..networks import NameEmail
from ..types import SecretBytes, SecretStr
@@ -51,7 +50,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
return float(dec_value)
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
@@ -80,23 +79,18 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
@deprecated(
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
category=None,
'pydantic_encoder is deprecated, use pydantic_core.to_jsonable_python instead.', category=PydanticDeprecatedSince20
)
def pydantic_encoder(obj: Any) -> Any:
warnings.warn(
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
from dataclasses import asdict, is_dataclass
BaseModel = import_cached_base_model()
from ..main import BaseModel
warnings.warn('pydantic_encoder is deprecated, use BaseModel.model_dump instead.', DeprecationWarning, stacklevel=2)
if isinstance(obj, BaseModel):
return obj.model_dump()
elif is_dataclass(obj):
return asdict(obj) # type: ignore
return asdict(obj)
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
@@ -110,17 +104,12 @@ def pydantic_encoder(obj: Any) -> Any:
# TODO: Add a suggested migration path once there is a way to use custom encoders
@deprecated(
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
category=None,
)
def custom_pydantic_encoder(type_encoders: dict[Any, Callable[[type[Any]], Any]], obj: Any) -> Any:
warnings.warn(
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
@deprecated('custom_pydantic_encoder is deprecated.', category=PydanticDeprecatedSince20)
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
# Check the class type and its superclasses for a matching encoder
warnings.warn(
'custom_pydantic_encoder is deprecated, use BaseModel.model_dump instead.', DeprecationWarning, stacklevel=2
)
for base in obj.__class__.__mro__[:-1]:
try:
encoder = type_encoders[base]
@@ -132,10 +121,10 @@ def custom_pydantic_encoder(type_encoders: dict[Any, Callable[[type[Any]], Any]]
return pydantic_encoder(obj)
@deprecated('`timedelta_isoformat` is deprecated.', category=None)
@deprecated('timedelta_isoformat is deprecated.', category=PydanticDeprecatedSince20)
def timedelta_isoformat(td: datetime.timedelta) -> str:
"""ISO 8601 encoding for Python timedelta object."""
warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
warnings.warn('timedelta_isoformat is deprecated.', DeprecationWarning, stacklevel=2)
minutes, seconds = divmod(td.seconds, 60)
hours, minutes = divmod(minutes, 60)
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'

View File

@@ -22,7 +22,7 @@ class Protocol(str, Enum):
pickle = 'pickle'
@deprecated('`load_str_bytes` is deprecated.', category=None)
@deprecated('load_str_bytes is deprecated.', category=PydanticDeprecatedSince20)
def load_str_bytes(
b: str | bytes,
*,
@@ -32,7 +32,7 @@ def load_str_bytes(
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
warnings.warn('load_str_bytes is deprecated.', DeprecationWarning, stacklevel=2)
if proto is None and content_type:
if content_type.endswith(('json', 'javascript')):
pass
@@ -46,17 +46,17 @@ def load_str_bytes(
if proto == Protocol.json:
if isinstance(b, bytes):
b = b.decode(encoding)
return json_loads(b) # type: ignore
return json_loads(b)
elif proto == Protocol.pickle:
if not allow_pickle:
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
bb = b if isinstance(b, bytes) else b.encode() # type: ignore
bb = b if isinstance(b, bytes) else b.encode()
return pickle.loads(bb)
else:
raise TypeError(f'Unknown protocol: {proto}')
@deprecated('`load_file` is deprecated.', category=None)
@deprecated('load_file is deprecated.', category=PydanticDeprecatedSince20)
def load_file(
path: str | Path,
*,
@@ -66,7 +66,7 @@ def load_file(
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
warnings.warn('load_file is deprecated.', DeprecationWarning, stacklevel=2)
path = Path(path)
b = path.read_bytes()
if content_type is None:

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
import warnings
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union
from typing_extensions import deprecated
@@ -17,20 +17,19 @@ if not TYPE_CHECKING:
__all__ = 'parse_obj_as', 'schema_of', 'schema_json_of'
NameFactory = Union[str, Callable[[type[Any]], str]]
NameFactory = Union[str, Callable[[Type[Any]], str]]
T = TypeVar('T')
@deprecated(
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
category=None,
'parse_obj_as is deprecated. Use pydantic.TypeAdapter.validate_python instead.', category=PydanticDeprecatedSince20
)
def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T:
warnings.warn(
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
category=PydanticDeprecatedSince20,
'parse_obj_as is deprecated. Use pydantic.TypeAdapter.validate_python instead.',
DeprecationWarning,
stacklevel=2,
)
if type_name is not None: # pragma: no cover
@@ -43,8 +42,7 @@ def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None)
@deprecated(
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=None,
'schema_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', category=PydanticDeprecatedSince20
)
def schema_of(
type_: Any,
@@ -56,9 +54,7 @@ def schema_of(
) -> dict[str, Any]:
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one."""
warnings.warn(
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
'schema_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', DeprecationWarning, stacklevel=2
)
res = TypeAdapter(type_).json_schema(
by_alias=by_alias,
@@ -79,8 +75,7 @@ def schema_of(
@deprecated(
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=None,
'schema_json_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', category=PydanticDeprecatedSince20
)
def schema_json_of(
type_: Any,
@@ -93,9 +88,7 @@ def schema_json_of(
) -> str:
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one."""
warnings.warn(
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
'schema_json_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', DeprecationWarning, stacklevel=2
)
return json.dumps(
schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator),

View File

@@ -1,5 +1,4 @@
"""The `env_settings` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,5 +1,4 @@
"""The `error_wrappers` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,14 +1,9 @@
"""Pydantic-specific errors."""
from __future__ import annotations as _annotations
import re
from typing import Any, ClassVar, Literal
from typing_extensions import Self
from typing_inspection.introspection import Qualifier
from pydantic._internal import _repr
from typing_extensions import Literal, Self
from ._migration import getattr_migration
from .version import version_short
@@ -19,7 +14,6 @@ __all__ = (
'PydanticImportError',
'PydanticSchemaGenerationError',
'PydanticInvalidForJsonSchema',
'PydanticForbiddenQualifier',
'PydanticErrorCodes',
)
@@ -36,13 +30,11 @@ PydanticErrorCodes = Literal[
'discriminator-needs-literal',
'discriminator-alias',
'discriminator-validator',
'callable-discriminator-no-tag',
'typed-dict-version',
'model-field-overridden',
'model-field-missing-annotation',
'config-both',
'removed-kwargs',
'circular-reference-schema',
'invalid-for-json-schema',
'json-schema-already-used',
'base-model-instantiated',
@@ -50,10 +42,10 @@ PydanticErrorCodes = Literal[
'schema-for-unknown-type',
'import-error',
'create-model-field-definitions',
'create-model-config-base',
'validator-no-fields',
'validator-invalid-fields',
'validator-instance-method',
'validator-input-type',
'root-validator-pre-skip',
'model-serializer-instance-method',
'validator-field-config-info',
@@ -62,20 +54,9 @@ PydanticErrorCodes = Literal[
'field-serializer-signature',
'model-serializer-signature',
'multiple-field-serializers',
'invalid-annotated-type',
'invalid_annotated_type',
'type-adapter-config-unused',
'root-model-extra',
'unevaluable-type-annotation',
'dataclass-init-false-extra-allow',
'clashing-init-and-init-var',
'model-config-invalid-field-name',
'with-config-on-model',
'dataclass-on-model',
'validate-call-type',
'unpack-typed-dict',
'overlapping-unpack-typed-dict',
'invalid-self-type',
'validate-by-alias-and-name-false',
]
@@ -164,26 +145,4 @@ class PydanticInvalidForJsonSchema(PydanticUserError):
super().__init__(message, code='invalid-for-json-schema')
class PydanticForbiddenQualifier(PydanticUserError):
"""An error raised if a forbidden type qualifier is found in a type annotation."""
_qualifier_repr_map: ClassVar[dict[Qualifier, str]] = {
'required': 'typing.Required',
'not_required': 'typing.NotRequired',
'read_only': 'typing.ReadOnly',
'class_var': 'typing.ClassVar',
'init_var': 'dataclasses.InitVar',
'final': 'typing.Final',
}
def __init__(self, qualifier: Qualifier, annotation: Any) -> None:
super().__init__(
message=(
f'The annotation {_repr.display_as_type(annotation)!r} contains the {self._qualifier_repr_map[qualifier]!r} '
f'type qualifier, which is invalid in the context it is defined.'
),
code=None,
)
__getattr__ = getattr_migration(__name__)

View File

@@ -1,10 +0,0 @@
"""The "experimental" module of pydantic contains potential new features that are subject to change."""
import warnings
from pydantic.warnings import PydanticExperimentalWarning
warnings.warn(
'This module is experimental, its contents are subject to change and deprecation.',
category=PydanticExperimentalWarning,
)

View File

@@ -1,44 +0,0 @@
"""Experimental module exposing a function to generate a core schema that validates callable arguments."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any, Literal
from pydantic_core import CoreSchema
from pydantic import ConfigDict
from pydantic._internal import _config, _generate_schema, _namespace_utils
def generate_arguments_schema(
func: Callable[..., Any],
schema_type: Literal['arguments', 'arguments-v3'] = 'arguments-v3',
parameters_callback: Callable[[int, str, Any], Literal['skip'] | None] | None = None,
config: ConfigDict | None = None,
) -> CoreSchema:
"""Generate the schema for the arguments of a function.
Args:
func: The function to generate the schema for.
schema_type: The type of schema to generate.
parameters_callback: A callable that will be invoked for each parameter. The callback
should take three required arguments: the index, the name and the type annotation
(or [`Parameter.empty`][inspect.Parameter.empty] if not annotated) of the parameter.
The callback can optionally return `'skip'`, so that the parameter gets excluded
from the resulting schema.
config: The configuration to use.
Returns:
The generated schema.
"""
generate_schema = _generate_schema.GenerateSchema(
_config.ConfigWrapper(config),
ns_resolver=_namespace_utils.NsResolver(namespaces_tuple=_namespace_utils.ns_for_function(func)),
)
if schema_type == 'arguments':
schema = generate_schema._arguments_schema(func, parameters_callback) # pyright: ignore[reportArgumentType]
else:
schema = generate_schema._arguments_v3_schema(func, parameters_callback) # pyright: ignore[reportArgumentType]
return generate_schema.clean_schema(schema)

View File

@@ -1,667 +0,0 @@
"""Experimental pipeline API functionality. Be careful with this API, it's subject to change."""
from __future__ import annotations
import datetime
import operator
import re
import sys
from collections import deque
from collections.abc import Container
from dataclasses import dataclass
from decimal import Decimal
from functools import cached_property, partial
from re import Pattern
from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, Protocol, TypeVar, Union, overload
import annotated_types
if TYPE_CHECKING:
from pydantic_core import core_schema as cs
from pydantic import GetCoreSchemaHandler
from pydantic._internal._internal_dataclass import slots_true as _slots_true
if sys.version_info < (3, 10):
EllipsisType = type(Ellipsis)
else:
from types import EllipsisType
__all__ = ['validate_as', 'validate_as_deferred', 'transform']
_slots_frozen = {**_slots_true, 'frozen': True}
@dataclass(**_slots_frozen)
class _ValidateAs:
tp: type[Any]
strict: bool = False
@dataclass
class _ValidateAsDefer:
func: Callable[[], type[Any]]
@cached_property
def tp(self) -> type[Any]:
return self.func()
@dataclass(**_slots_frozen)
class _Transform:
func: Callable[[Any], Any]
@dataclass(**_slots_frozen)
class _PipelineOr:
left: _Pipeline[Any, Any]
right: _Pipeline[Any, Any]
@dataclass(**_slots_frozen)
class _PipelineAnd:
left: _Pipeline[Any, Any]
right: _Pipeline[Any, Any]
@dataclass(**_slots_frozen)
class _Eq:
value: Any
@dataclass(**_slots_frozen)
class _NotEq:
value: Any
@dataclass(**_slots_frozen)
class _In:
values: Container[Any]
@dataclass(**_slots_frozen)
class _NotIn:
values: Container[Any]
_ConstraintAnnotation = Union[
annotated_types.Le,
annotated_types.Ge,
annotated_types.Lt,
annotated_types.Gt,
annotated_types.Len,
annotated_types.MultipleOf,
annotated_types.Timezone,
annotated_types.Interval,
annotated_types.Predicate,
# common predicates not included in annotated_types
_Eq,
_NotEq,
_In,
_NotIn,
# regular expressions
Pattern[str],
]
@dataclass(**_slots_frozen)
class _Constraint:
constraint: _ConstraintAnnotation
_Step = Union[_ValidateAs, _ValidateAsDefer, _Transform, _PipelineOr, _PipelineAnd, _Constraint]
_InT = TypeVar('_InT')
_OutT = TypeVar('_OutT')
_NewOutT = TypeVar('_NewOutT')
class _FieldTypeMarker:
pass
# TODO: ultimately, make this public, see https://github.com/pydantic/pydantic/pull/9459#discussion_r1628197626
# Also, make this frozen eventually, but that doesn't work right now because of the generic base
# Which attempts to modify __orig_base__ and such.
# We could go with a manual freeze, but that seems overkill for now.
@dataclass(**_slots_true)
class _Pipeline(Generic[_InT, _OutT]):
"""Abstract representation of a chain of validation, transformation, and parsing steps."""
_steps: tuple[_Step, ...]
def transform(
self,
func: Callable[[_OutT], _NewOutT],
) -> _Pipeline[_InT, _NewOutT]:
"""Transform the output of the previous step.
If used as the first step in a pipeline, the type of the field is used.
That is, the transformation is applied to after the value is parsed to the field's type.
"""
return _Pipeline[_InT, _NewOutT](self._steps + (_Transform(func),))
@overload
def validate_as(self, tp: type[_NewOutT], *, strict: bool = ...) -> _Pipeline[_InT, _NewOutT]: ...
@overload
def validate_as(self, tp: EllipsisType, *, strict: bool = ...) -> _Pipeline[_InT, Any]: # type: ignore
...
def validate_as(self, tp: type[_NewOutT] | EllipsisType, *, strict: bool = False) -> _Pipeline[_InT, Any]: # type: ignore
"""Validate / parse the input into a new type.
If no type is provided, the type of the field is used.
Types are parsed in Pydantic's `lax` mode by default,
but you can enable `strict` mode by passing `strict=True`.
"""
if isinstance(tp, EllipsisType):
return _Pipeline[_InT, Any](self._steps + (_ValidateAs(_FieldTypeMarker, strict=strict),))
return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAs(tp, strict=strict),))
def validate_as_deferred(self, func: Callable[[], type[_NewOutT]]) -> _Pipeline[_InT, _NewOutT]:
"""Parse the input into a new type, deferring resolution of the type until the current class
is fully defined.
This is useful when you need to reference the class in it's own type annotations.
"""
return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAsDefer(func),))
# constraints
@overload
def constrain(self: _Pipeline[_InT, _NewOutGe], constraint: annotated_types.Ge) -> _Pipeline[_InT, _NewOutGe]: ...
@overload
def constrain(self: _Pipeline[_InT, _NewOutGt], constraint: annotated_types.Gt) -> _Pipeline[_InT, _NewOutGt]: ...
@overload
def constrain(self: _Pipeline[_InT, _NewOutLe], constraint: annotated_types.Le) -> _Pipeline[_InT, _NewOutLe]: ...
@overload
def constrain(self: _Pipeline[_InT, _NewOutLt], constraint: annotated_types.Lt) -> _Pipeline[_InT, _NewOutLt]: ...
@overload
def constrain(
self: _Pipeline[_InT, _NewOutLen], constraint: annotated_types.Len
) -> _Pipeline[_InT, _NewOutLen]: ...
@overload
def constrain(
self: _Pipeline[_InT, _NewOutT], constraint: annotated_types.MultipleOf
) -> _Pipeline[_InT, _NewOutT]: ...
@overload
def constrain(
self: _Pipeline[_InT, _NewOutDatetime], constraint: annotated_types.Timezone
) -> _Pipeline[_InT, _NewOutDatetime]: ...
@overload
def constrain(self: _Pipeline[_InT, _OutT], constraint: annotated_types.Predicate) -> _Pipeline[_InT, _OutT]: ...
@overload
def constrain(
self: _Pipeline[_InT, _NewOutInterval], constraint: annotated_types.Interval
) -> _Pipeline[_InT, _NewOutInterval]: ...
@overload
def constrain(self: _Pipeline[_InT, _OutT], constraint: _Eq) -> _Pipeline[_InT, _OutT]: ...
@overload
def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotEq) -> _Pipeline[_InT, _OutT]: ...
@overload
def constrain(self: _Pipeline[_InT, _OutT], constraint: _In) -> _Pipeline[_InT, _OutT]: ...
@overload
def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotIn) -> _Pipeline[_InT, _OutT]: ...
@overload
def constrain(self: _Pipeline[_InT, _NewOutT], constraint: Pattern[str]) -> _Pipeline[_InT, _NewOutT]: ...
def constrain(self, constraint: _ConstraintAnnotation) -> Any:
"""Constrain a value to meet a certain condition.
We support most conditions from `annotated_types`, as well as regular expressions.
Most of the time you'll be calling a shortcut method like `gt`, `lt`, `len`, etc
so you don't need to call this directly.
"""
return _Pipeline[_InT, _OutT](self._steps + (_Constraint(constraint),))
def predicate(self: _Pipeline[_InT, _NewOutT], func: Callable[[_NewOutT], bool]) -> _Pipeline[_InT, _NewOutT]:
"""Constrain a value to meet a certain predicate."""
return self.constrain(annotated_types.Predicate(func))
def gt(self: _Pipeline[_InT, _NewOutGt], gt: _NewOutGt) -> _Pipeline[_InT, _NewOutGt]:
"""Constrain a value to be greater than a certain value."""
return self.constrain(annotated_types.Gt(gt))
def lt(self: _Pipeline[_InT, _NewOutLt], lt: _NewOutLt) -> _Pipeline[_InT, _NewOutLt]:
"""Constrain a value to be less than a certain value."""
return self.constrain(annotated_types.Lt(lt))
def ge(self: _Pipeline[_InT, _NewOutGe], ge: _NewOutGe) -> _Pipeline[_InT, _NewOutGe]:
"""Constrain a value to be greater than or equal to a certain value."""
return self.constrain(annotated_types.Ge(ge))
def le(self: _Pipeline[_InT, _NewOutLe], le: _NewOutLe) -> _Pipeline[_InT, _NewOutLe]:
"""Constrain a value to be less than or equal to a certain value."""
return self.constrain(annotated_types.Le(le))
def len(self: _Pipeline[_InT, _NewOutLen], min_len: int, max_len: int | None = None) -> _Pipeline[_InT, _NewOutLen]:
"""Constrain a value to have a certain length."""
return self.constrain(annotated_types.Len(min_len, max_len))
@overload
def multiple_of(self: _Pipeline[_InT, _NewOutDiv], multiple_of: _NewOutDiv) -> _Pipeline[_InT, _NewOutDiv]: ...
@overload
def multiple_of(self: _Pipeline[_InT, _NewOutMod], multiple_of: _NewOutMod) -> _Pipeline[_InT, _NewOutMod]: ...
def multiple_of(self: _Pipeline[_InT, Any], multiple_of: Any) -> _Pipeline[_InT, Any]:
"""Constrain a value to be a multiple of a certain number."""
return self.constrain(annotated_types.MultipleOf(multiple_of))
def eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
"""Constrain a value to be equal to a certain value."""
return self.constrain(_Eq(value))
def not_eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
"""Constrain a value to not be equal to a certain value."""
return self.constrain(_NotEq(value))
def in_(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
"""Constrain a value to be in a certain set."""
return self.constrain(_In(values))
def not_in(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
"""Constrain a value to not be in a certain set."""
return self.constrain(_NotIn(values))
# timezone methods
def datetime_tz_naive(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
return self.constrain(annotated_types.Timezone(None))
def datetime_tz_aware(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
return self.constrain(annotated_types.Timezone(...))
def datetime_tz(
self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo
) -> _Pipeline[_InT, datetime.datetime]:
return self.constrain(annotated_types.Timezone(tz)) # type: ignore
def datetime_with_tz(
self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo | None
) -> _Pipeline[_InT, datetime.datetime]:
return self.transform(partial(datetime.datetime.replace, tzinfo=tz))
# string methods
def str_lower(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
return self.transform(str.lower)
def str_upper(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
return self.transform(str.upper)
def str_title(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
return self.transform(str.title)
def str_strip(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
return self.transform(str.strip)
def str_pattern(self: _Pipeline[_InT, str], pattern: str) -> _Pipeline[_InT, str]:
return self.constrain(re.compile(pattern))
def str_contains(self: _Pipeline[_InT, str], substring: str) -> _Pipeline[_InT, str]:
return self.predicate(lambda v: substring in v)
def str_starts_with(self: _Pipeline[_InT, str], prefix: str) -> _Pipeline[_InT, str]:
return self.predicate(lambda v: v.startswith(prefix))
def str_ends_with(self: _Pipeline[_InT, str], suffix: str) -> _Pipeline[_InT, str]:
return self.predicate(lambda v: v.endswith(suffix))
# operators
def otherwise(self, other: _Pipeline[_OtherIn, _OtherOut]) -> _Pipeline[_InT | _OtherIn, _OutT | _OtherOut]:
"""Combine two validation chains, returning the result of the first chain if it succeeds, and the second chain if it fails."""
return _Pipeline((_PipelineOr(self, other),))
__or__ = otherwise
def then(self, other: _Pipeline[_OutT, _OtherOut]) -> _Pipeline[_InT, _OtherOut]:
"""Pipe the result of one validation chain into another."""
return _Pipeline((_PipelineAnd(self, other),))
__and__ = then
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> cs.CoreSchema:
from pydantic_core import core_schema as cs
queue = deque(self._steps)
s = None
while queue:
step = queue.popleft()
s = _apply_step(step, s, handler, source_type)
s = s or cs.any_schema()
return s
def __supports_type__(self, _: _OutT) -> bool:
raise NotImplementedError
validate_as = _Pipeline[Any, Any](()).validate_as
validate_as_deferred = _Pipeline[Any, Any](()).validate_as_deferred
transform = _Pipeline[Any, Any]((_ValidateAs(_FieldTypeMarker),)).transform
def _check_func(
func: Callable[[Any], bool], predicate_err: str | Callable[[], str], s: cs.CoreSchema | None
) -> cs.CoreSchema:
from pydantic_core import core_schema as cs
def handler(v: Any) -> Any:
if func(v):
return v
raise ValueError(f'Expected {predicate_err if isinstance(predicate_err, str) else predicate_err()}')
if s is None:
return cs.no_info_plain_validator_function(handler)
else:
return cs.no_info_after_validator_function(handler, s)
def _apply_step(step: _Step, s: cs.CoreSchema | None, handler: GetCoreSchemaHandler, source_type: Any) -> cs.CoreSchema:
from pydantic_core import core_schema as cs
if isinstance(step, _ValidateAs):
s = _apply_parse(s, step.tp, step.strict, handler, source_type)
elif isinstance(step, _ValidateAsDefer):
s = _apply_parse(s, step.tp, False, handler, source_type)
elif isinstance(step, _Transform):
s = _apply_transform(s, step.func, handler)
elif isinstance(step, _Constraint):
s = _apply_constraint(s, step.constraint)
elif isinstance(step, _PipelineOr):
s = cs.union_schema([handler(step.left), handler(step.right)])
else:
assert isinstance(step, _PipelineAnd)
s = cs.chain_schema([handler(step.left), handler(step.right)])
return s
def _apply_parse(
s: cs.CoreSchema | None,
tp: type[Any],
strict: bool,
handler: GetCoreSchemaHandler,
source_type: Any,
) -> cs.CoreSchema:
from pydantic_core import core_schema as cs
from pydantic import Strict
if tp is _FieldTypeMarker:
return cs.chain_schema([s, handler(source_type)]) if s else handler(source_type)
if strict:
tp = Annotated[tp, Strict()] # type: ignore
if s and s['type'] == 'any':
return handler(tp)
else:
return cs.chain_schema([s, handler(tp)]) if s else handler(tp)
def _apply_transform(
s: cs.CoreSchema | None, func: Callable[[Any], Any], handler: GetCoreSchemaHandler
) -> cs.CoreSchema:
from pydantic_core import core_schema as cs
if s is None:
return cs.no_info_plain_validator_function(func)
if s['type'] == 'str':
if func is str.strip:
s = s.copy()
s['strip_whitespace'] = True
return s
elif func is str.lower:
s = s.copy()
s['to_lower'] = True
return s
elif func is str.upper:
s = s.copy()
s['to_upper'] = True
return s
return cs.no_info_after_validator_function(func, s)
def _apply_constraint( # noqa: C901
s: cs.CoreSchema | None, constraint: _ConstraintAnnotation
) -> cs.CoreSchema:
"""Apply a single constraint to a schema."""
if isinstance(constraint, annotated_types.Gt):
gt = constraint.gt
if s and s['type'] in {'int', 'float', 'decimal'}:
s = s.copy()
if s['type'] == 'int' and isinstance(gt, int):
s['gt'] = gt
elif s['type'] == 'float' and isinstance(gt, float):
s['gt'] = gt
elif s['type'] == 'decimal' and isinstance(gt, Decimal):
s['gt'] = gt
else:
def check_gt(v: Any) -> bool:
return v > gt
s = _check_func(check_gt, f'> {gt}', s)
elif isinstance(constraint, annotated_types.Ge):
ge = constraint.ge
if s and s['type'] in {'int', 'float', 'decimal'}:
s = s.copy()
if s['type'] == 'int' and isinstance(ge, int):
s['ge'] = ge
elif s['type'] == 'float' and isinstance(ge, float):
s['ge'] = ge
elif s['type'] == 'decimal' and isinstance(ge, Decimal):
s['ge'] = ge
def check_ge(v: Any) -> bool:
return v >= ge
s = _check_func(check_ge, f'>= {ge}', s)
elif isinstance(constraint, annotated_types.Lt):
lt = constraint.lt
if s and s['type'] in {'int', 'float', 'decimal'}:
s = s.copy()
if s['type'] == 'int' and isinstance(lt, int):
s['lt'] = lt
elif s['type'] == 'float' and isinstance(lt, float):
s['lt'] = lt
elif s['type'] == 'decimal' and isinstance(lt, Decimal):
s['lt'] = lt
def check_lt(v: Any) -> bool:
return v < lt
s = _check_func(check_lt, f'< {lt}', s)
elif isinstance(constraint, annotated_types.Le):
le = constraint.le
if s and s['type'] in {'int', 'float', 'decimal'}:
s = s.copy()
if s['type'] == 'int' and isinstance(le, int):
s['le'] = le
elif s['type'] == 'float' and isinstance(le, float):
s['le'] = le
elif s['type'] == 'decimal' and isinstance(le, Decimal):
s['le'] = le
def check_le(v: Any) -> bool:
return v <= le
s = _check_func(check_le, f'<= {le}', s)
elif isinstance(constraint, annotated_types.Len):
min_len = constraint.min_length
max_len = constraint.max_length
if s and s['type'] in {'str', 'list', 'tuple', 'set', 'frozenset', 'dict'}:
assert (
s['type'] == 'str'
or s['type'] == 'list'
or s['type'] == 'tuple'
or s['type'] == 'set'
or s['type'] == 'dict'
or s['type'] == 'frozenset'
)
s = s.copy()
if min_len != 0:
s['min_length'] = min_len
if max_len is not None:
s['max_length'] = max_len
def check_len(v: Any) -> bool:
if max_len is not None:
return (min_len <= len(v)) and (len(v) <= max_len)
return min_len <= len(v)
s = _check_func(check_len, f'length >= {min_len} and length <= {max_len}', s)
elif isinstance(constraint, annotated_types.MultipleOf):
multiple_of = constraint.multiple_of
if s and s['type'] in {'int', 'float', 'decimal'}:
s = s.copy()
if s['type'] == 'int' and isinstance(multiple_of, int):
s['multiple_of'] = multiple_of
elif s['type'] == 'float' and isinstance(multiple_of, float):
s['multiple_of'] = multiple_of
elif s['type'] == 'decimal' and isinstance(multiple_of, Decimal):
s['multiple_of'] = multiple_of
def check_multiple_of(v: Any) -> bool:
return v % multiple_of == 0
s = _check_func(check_multiple_of, f'% {multiple_of} == 0', s)
elif isinstance(constraint, annotated_types.Timezone):
tz = constraint.tz
if tz is ...:
if s and s['type'] == 'datetime':
s = s.copy()
s['tz_constraint'] = 'aware'
else:
def check_tz_aware(v: object) -> bool:
assert isinstance(v, datetime.datetime)
return v.tzinfo is not None
s = _check_func(check_tz_aware, 'timezone aware', s)
elif tz is None:
if s and s['type'] == 'datetime':
s = s.copy()
s['tz_constraint'] = 'naive'
else:
def check_tz_naive(v: object) -> bool:
assert isinstance(v, datetime.datetime)
return v.tzinfo is None
s = _check_func(check_tz_naive, 'timezone naive', s)
else:
raise NotImplementedError('Constraining to a specific timezone is not yet supported')
elif isinstance(constraint, annotated_types.Interval):
if constraint.ge:
s = _apply_constraint(s, annotated_types.Ge(constraint.ge))
if constraint.gt:
s = _apply_constraint(s, annotated_types.Gt(constraint.gt))
if constraint.le:
s = _apply_constraint(s, annotated_types.Le(constraint.le))
if constraint.lt:
s = _apply_constraint(s, annotated_types.Lt(constraint.lt))
assert s is not None
elif isinstance(constraint, annotated_types.Predicate):
func = constraint.func
if func.__name__ == '<lambda>':
# attempt to extract the source code for a lambda function
# to use as the function name in error messages
# TODO: is there a better way? should we just not do this?
import inspect
try:
source = inspect.getsource(func).strip()
source = source.removesuffix(')')
lambda_source_code = '`' + ''.join(''.join(source.split('lambda ')[1:]).split(':')[1:]).strip() + '`'
except OSError:
# stringified annotations
lambda_source_code = 'lambda'
s = _check_func(func, lambda_source_code, s)
else:
s = _check_func(func, func.__name__, s)
elif isinstance(constraint, _NotEq):
value = constraint.value
def check_not_eq(v: Any) -> bool:
return operator.__ne__(v, value)
s = _check_func(check_not_eq, f'!= {value}', s)
elif isinstance(constraint, _Eq):
value = constraint.value
def check_eq(v: Any) -> bool:
return operator.__eq__(v, value)
s = _check_func(check_eq, f'== {value}', s)
elif isinstance(constraint, _In):
values = constraint.values
def check_in(v: Any) -> bool:
return operator.__contains__(values, v)
s = _check_func(check_in, f'in {values}', s)
elif isinstance(constraint, _NotIn):
values = constraint.values
def check_not_in(v: Any) -> bool:
return operator.__not__(operator.__contains__(values, v))
s = _check_func(check_not_in, f'not in {values}', s)
else:
assert isinstance(constraint, Pattern)
if s and s['type'] == 'str':
s = s.copy()
s['pattern'] = constraint.pattern
else:
def check_pattern(v: object) -> bool:
assert isinstance(v, str)
return constraint.match(v) is not None
s = _check_func(check_pattern, f'~ {constraint.pattern}', s)
return s
class _SupportsRange(annotated_types.SupportsLe, annotated_types.SupportsGe, Protocol):
pass
class _SupportsLen(Protocol):
def __len__(self) -> int: ...
_NewOutGt = TypeVar('_NewOutGt', bound=annotated_types.SupportsGt)
_NewOutGe = TypeVar('_NewOutGe', bound=annotated_types.SupportsGe)
_NewOutLt = TypeVar('_NewOutLt', bound=annotated_types.SupportsLt)
_NewOutLe = TypeVar('_NewOutLe', bound=annotated_types.SupportsLe)
_NewOutLen = TypeVar('_NewOutLen', bound=_SupportsLen)
_NewOutDiv = TypeVar('_NewOutDiv', bound=annotated_types.SupportsDiv)
_NewOutMod = TypeVar('_NewOutMod', bound=annotated_types.SupportsMod)
_NewOutDatetime = TypeVar('_NewOutDatetime', bound=datetime.datetime)
_NewOutInterval = TypeVar('_NewOutInterval', bound=_SupportsRange)
_OtherIn = TypeVar('_OtherIn')
_OtherOut = TypeVar('_OtherOut')

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +1,13 @@
"""This module contains related classes and functions for serialization."""
from __future__ import annotations
import dataclasses
from functools import partial, partialmethod
from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, TypeVar, overload
from functools import partialmethod
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
from pydantic_core import PydanticUndefined, core_schema
from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler, WhenUsed
from typing_extensions import TypeAlias
from pydantic_core import core_schema as _core_schema
from typing_extensions import Annotated, Literal, TypeAlias
from . import PydanticUndefinedAnnotation
from ._internal import _decorators, _internal_dataclass
@@ -19,26 +18,6 @@ from .annotated_handlers import GetCoreSchemaHandler
class PlainSerializer:
"""Plain serializers use a function to modify the output of serialization.
This is particularly helpful when you want to customize the serialization for annotated types.
Consider an input of `list`, which will be serialized into a space-delimited string.
```python
from typing import Annotated
from pydantic import BaseModel, PlainSerializer
CustomStr = Annotated[
list, PlainSerializer(lambda x: ' '.join(x), return_type=str)
]
class StudentModel(BaseModel):
courses: CustomStr
student = StudentModel(courses=['Math', 'Chemistry', 'English'])
print(student.model_dump())
#> {'courses': 'Math Chemistry English'}
```
Attributes:
func: The serializer function.
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
@@ -48,7 +27,7 @@ class PlainSerializer:
func: core_schema.SerializerFunction
return_type: Any = PydanticUndefined
when_used: WhenUsed = 'always'
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always'
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
"""Gets the Pydantic core schema.
@@ -61,20 +40,12 @@ class PlainSerializer:
The Pydantic core schema.
"""
schema = handler(source_type)
if self.return_type is not PydanticUndefined:
return_type = self.return_type
else:
try:
# Do not pass in globals as the function could be defined in a different module.
# Instead, let `get_callable_return_type` infer the globals to use, but still pass
# in locals that may contain a parent/rebuild namespace:
return_type = _decorators.get_callable_return_type(
self.func,
localns=handler._get_types_namespace().locals,
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
try:
return_type = _decorators.get_function_return_type(
self.func, self.return_type, handler._get_types_namespace()
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
schema['serialization'] = core_schema.plain_serializer_function_ser_schema(
function=self.func,
@@ -90,58 +61,6 @@ class WrapSerializer:
"""Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization
logic, and can modify the resulting value before returning it as the final output of serialization.
For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic.
```python
from datetime import datetime, timezone
from typing import Annotated, Any
from pydantic import BaseModel, WrapSerializer
class EventDatetime(BaseModel):
start: datetime
end: datetime
def convert_to_utc(value: Any, handler, info) -> dict[str, datetime]:
# Note that `handler` can actually help serialize the `value` for
# further custom serialization in case it's a subclass.
partial_result = handler(value, info)
if info.mode == 'json':
return {
k: datetime.fromisoformat(v).astimezone(timezone.utc)
for k, v in partial_result.items()
}
return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()}
UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)]
class EventModel(BaseModel):
event_datetime: UTCEventDatetime
dt = EventDatetime(
start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00'
)
event = EventModel(event_datetime=dt)
print(event.model_dump())
'''
{
'event_datetime': {
'start': datetime.datetime(
2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc
),
'end': datetime.datetime(
2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc
),
}
}
'''
print(event.model_dump_json())
'''
{"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}}
'''
```
Attributes:
func: The serializer function to be wrapped.
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
@@ -151,7 +70,7 @@ class WrapSerializer:
func: core_schema.WrapSerializerFunction
return_type: Any = PydanticUndefined
when_used: WhenUsed = 'always'
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always'
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
"""This method is used to get the Pydantic core schema of the class.
@@ -164,20 +83,12 @@ class WrapSerializer:
The generated core schema of the class.
"""
schema = handler(source_type)
if self.return_type is not PydanticUndefined:
return_type = self.return_type
else:
try:
# Do not pass in globals as the function could be defined in a different module.
# Instead, let `get_callable_return_type` infer the globals to use, but still pass
# in locals that may contain a parent/rebuild namespace:
return_type = _decorators.get_callable_return_type(
self.func,
localns=handler._get_types_namespace().locals,
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
try:
return_type = _decorators.get_function_return_type(
self.func, self.return_type, handler._get_types_namespace()
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
function=self.func,
@@ -189,77 +100,57 @@ class WrapSerializer:
if TYPE_CHECKING:
_Partial: TypeAlias = 'partial[Any] | partialmethod[Any]'
FieldPlainSerializer: TypeAlias = 'core_schema.SerializerFunction | _Partial'
"""A field serializer method or function in `plain` mode."""
FieldWrapSerializer: TypeAlias = 'core_schema.WrapSerializerFunction | _Partial'
"""A field serializer method or function in `wrap` mode."""
FieldSerializer: TypeAlias = 'FieldPlainSerializer | FieldWrapSerializer'
"""A field serializer method or function."""
_FieldPlainSerializerT = TypeVar('_FieldPlainSerializerT', bound=FieldPlainSerializer)
_FieldWrapSerializerT = TypeVar('_FieldWrapSerializerT', bound=FieldWrapSerializer)
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
_PlainSerializationFunction = Union[_core_schema.SerializerFunction, _PartialClsOrStaticMethod]
_WrapSerializationFunction = Union[_core_schema.WrapSerializerFunction, _PartialClsOrStaticMethod]
_PlainSerializeMethodType = TypeVar('_PlainSerializeMethodType', bound=_PlainSerializationFunction)
_WrapSerializeMethodType = TypeVar('_WrapSerializeMethodType', bound=_WrapSerializationFunction)
@overload
def field_serializer(
field: str,
/,
__field: str,
*fields: str,
return_type: Any = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]:
...
@overload
def field_serializer(
__field: str,
*fields: str,
mode: Literal['plain'],
return_type: Any = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]:
...
@overload
def field_serializer(
__field: str,
*fields: str,
mode: Literal['wrap'],
return_type: Any = ...,
when_used: WhenUsed = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]: ...
@overload
def field_serializer(
field: str,
/,
*fields: str,
mode: Literal['plain'] = ...,
return_type: Any = ...,
when_used: WhenUsed = ...,
check_fields: bool | None = ...,
) -> Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]: ...
) -> Callable[[_WrapSerializeMethodType], _WrapSerializeMethodType]:
...
def field_serializer(
*fields: str,
mode: Literal['plain', 'wrap'] = 'plain',
return_type: Any = PydanticUndefined,
when_used: WhenUsed = 'always',
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
check_fields: bool | None = None,
) -> (
Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]
| Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]
):
) -> Callable[[Any], Any]:
"""Decorator that enables custom field serialization.
In the below example, a field of type `set` is used to mitigate duplication. A `field_serializer` is used to serialize the data as a sorted list.
```python
from typing import Set
from pydantic import BaseModel, field_serializer
class StudentModel(BaseModel):
name: str = 'Jane'
courses: Set[str]
@field_serializer('courses', when_used='json')
def serialize_courses_in_order(self, courses: Set[str]):
return sorted(courses)
student = StudentModel(courses={'Math', 'Chemistry', 'English'})
print(student.model_dump_json())
#> {"name":"Jane","courses":["Chemistry","English","Math"]}
```
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
Four signatures are supported:
@@ -284,7 +175,9 @@ def field_serializer(
The decorator function.
"""
def dec(f: FieldSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
def dec(
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any]
) -> _decorators.PydanticDescriptorProxy[Any]:
dec_info = _decorators.FieldSerializerDecoratorInfo(
fields=fields,
mode=mode,
@@ -292,109 +185,42 @@ def field_serializer(
when_used=when_used,
check_fields=check_fields,
)
return _decorators.PydanticDescriptorProxy(f, dec_info) # pyright: ignore[reportArgumentType]
return _decorators.PydanticDescriptorProxy(f, dec_info)
return dec # pyright: ignore[reportReturnType]
return dec
if TYPE_CHECKING:
# The first argument in the following callables represent the `self` type:
ModelPlainSerializerWithInfo: TypeAlias = Callable[[Any, SerializationInfo], Any]
"""A model serializer method with the `info` argument, in `plain` mode."""
ModelPlainSerializerWithoutInfo: TypeAlias = Callable[[Any], Any]
"""A model serializer method without the `info` argument, in `plain` mode."""
ModelPlainSerializer: TypeAlias = 'ModelPlainSerializerWithInfo | ModelPlainSerializerWithoutInfo'
"""A model serializer method in `plain` mode."""
ModelWrapSerializerWithInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any]
"""A model serializer method with the `info` argument, in `wrap` mode."""
ModelWrapSerializerWithoutInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler], Any]
"""A model serializer method without the `info` argument, in `wrap` mode."""
ModelWrapSerializer: TypeAlias = 'ModelWrapSerializerWithInfo | ModelWrapSerializerWithoutInfo'
"""A model serializer method in `wrap` mode."""
ModelSerializer: TypeAlias = 'ModelPlainSerializer | ModelWrapSerializer'
_ModelPlainSerializerT = TypeVar('_ModelPlainSerializerT', bound=ModelPlainSerializer)
_ModelWrapSerializerT = TypeVar('_ModelWrapSerializerT', bound=ModelWrapSerializer)
FuncType = TypeVar('FuncType', bound=Callable[..., Any])
@overload
def model_serializer(f: _ModelPlainSerializerT, /) -> _ModelPlainSerializerT: ...
@overload
def model_serializer(
*, mode: Literal['wrap'], when_used: WhenUsed = 'always', return_type: Any = ...
) -> Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]: ...
def model_serializer(__f: FuncType) -> FuncType:
...
@overload
def model_serializer(
*,
mode: Literal['plain'] = ...,
when_used: WhenUsed = 'always',
mode: Literal['plain', 'wrap'] = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
return_type: Any = ...,
) -> Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]: ...
) -> Callable[[FuncType], FuncType]:
...
def model_serializer(
f: _ModelPlainSerializerT | _ModelWrapSerializerT | None = None,
/,
__f: Callable[..., Any] | None = None,
*,
mode: Literal['plain', 'wrap'] = 'plain',
when_used: WhenUsed = 'always',
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
return_type: Any = PydanticUndefined,
) -> (
_ModelPlainSerializerT
| Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]
| Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]
):
) -> Callable[[Any], Any]:
"""Decorator that enables custom model serialization.
This is useful when a model need to be serialized in a customized manner, allowing for flexibility beyond just specific fields.
An example would be to serialize temperature to the same temperature scale, such as degrees Celsius.
```python
from typing import Literal
from pydantic import BaseModel, model_serializer
class TemperatureModel(BaseModel):
unit: Literal['C', 'F']
value: int
@model_serializer()
def serialize_model(self):
if self.unit == 'F':
return {'unit': 'C', 'value': int((self.value - 32) / 1.8)}
return {'unit': self.unit, 'value': self.value}
temperature = TemperatureModel(unit='F', value=212)
print(temperature.model_dump())
#> {'unit': 'C', 'value': 100}
```
Two signatures are supported for `mode='plain'`, which is the default:
- `(self)`
- `(self, info: SerializationInfo)`
And two other signatures for `mode='wrap'`:
- `(self, nxt: SerializerFunctionWrapHandler)`
- `(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
Args:
f: The function to be decorated.
__f: The function to be decorated.
mode: The serialization mode.
- `'plain'` means the function will be called instead of the default serialization logic
@@ -407,14 +233,14 @@ def model_serializer(
The decorator function.
"""
def dec(f: ModelSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
def dec(f: Callable[..., Any]) -> _decorators.PydanticDescriptorProxy[Any]:
dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used)
return _decorators.PydanticDescriptorProxy(f, dec_info)
if f is None:
return dec # pyright: ignore[reportReturnType]
if __f is None:
return dec
else:
return dec(f) # pyright: ignore[reportReturnType]
return dec(__f) # type: ignore
AnyType = TypeVar('AnyType')

View File

@@ -6,13 +6,14 @@ import dataclasses
import sys
from functools import partialmethod
from types import FunctionType
from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, TypeVar, Union, cast, overload
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload
from pydantic_core import PydanticUndefined, core_schema
from pydantic_core import core_schema
from pydantic_core import core_schema as _core_schema
from typing_extensions import Self, TypeAlias
from typing_extensions import Annotated, Literal, TypeAlias
from ._internal import _decorators, _generics, _internal_dataclass
from . import GetCoreSchemaHandler as _GetCoreSchemaHandler
from ._internal import _core_metadata, _decorators, _generics, _internal_dataclass
from .annotated_handlers import GetCoreSchemaHandler
from .errors import PydanticUserError
@@ -26,8 +27,7 @@ _inspect_validator = _decorators.inspect_validator
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class AfterValidator:
"""!!! abstract "Usage Documentation"
[field *after* validators](../concepts/validators.md#field-after-validator)
'''Usage docs: https://docs.pydantic.dev/2.2/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **after** the inner validation logic.
@@ -35,10 +35,11 @@ class AfterValidator:
func: The validator function.
Example:
```python
```py
from typing import Annotated
from pydantic import AfterValidator, BaseModel, ValidationError
from pydantic import BaseModel, AfterValidator, ValidationError
MyInt = Annotated[int, AfterValidator(lambda v: v + 1)]
@@ -46,31 +47,31 @@ class AfterValidator:
a: MyInt
print(Model(a=1).a)
#> 2
# > 2
try:
Model(a='a')
except ValidationError as e:
print(e.json(indent=2))
'''
[
{
"""
[
{
"type": "int_parsing",
"loc": [
"a"
"a"
],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "a",
"url": "https://errors.pydantic.dev/2/v/int_parsing"
}
]
'''
"url": "https://errors.pydantic.dev/0.38.0/v/int_parsing"
}
]
"""
```
"""
'''
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema = handler(source_type)
info_arg = _inspect_validator(self.func, 'after')
if info_arg:
@@ -80,26 +81,19 @@ class AfterValidator:
func = cast(core_schema.NoInfoValidatorFunction, self.func)
return core_schema.no_info_after_validator_function(func, schema=schema)
@classmethod
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
return cls(func=decorator.func)
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class BeforeValidator:
"""!!! abstract "Usage Documentation"
[field *before* validators](../concepts/validators.md#field-before-validator)
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **before** the inner validation logic.
Attributes:
func: The validator function.
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
JSON Schema (in validation mode).
Example:
```python
from typing import Annotated
```py
from typing_extensions import Annotated
from pydantic import BaseModel, BeforeValidator
@@ -120,151 +114,68 @@ class BeforeValidator:
"""
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
json_schema_input_type: Any = PydanticUndefined
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema = handler(source_type)
input_schema = (
None
if self.json_schema_input_type is PydanticUndefined
else handler.generate_schema(self.json_schema_input_type)
)
info_arg = _inspect_validator(self.func, 'before')
if info_arg:
func = cast(core_schema.WithInfoValidatorFunction, self.func)
return core_schema.with_info_before_validator_function(
func,
schema=schema,
field_name=handler.field_name,
json_schema_input_schema=input_schema,
)
return core_schema.with_info_before_validator_function(func, schema=schema, field_name=handler.field_name)
else:
func = cast(core_schema.NoInfoValidatorFunction, self.func)
return core_schema.no_info_before_validator_function(
func, schema=schema, json_schema_input_schema=input_schema
)
@classmethod
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
return cls(
func=decorator.func,
json_schema_input_type=decorator.info.json_schema_input_type,
)
return core_schema.no_info_before_validator_function(func, schema=schema)
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class PlainValidator:
"""!!! abstract "Usage Documentation"
[field *plain* validators](../concepts/validators.md#field-plain-validator)
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **instead** of the inner validation logic.
!!! note
Before v2.9, `PlainValidator` wasn't always compatible with JSON Schema generation for `mode='validation'`.
You can now use the `json_schema_input_type` argument to specify the input type of the function
to be used in the JSON schema when `mode='validation'` (the default). See the example below for more details.
Attributes:
func: The validator function.
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
JSON Schema (in validation mode). If not provided, will default to `Any`.
Example:
```python
from typing import Annotated, Union
```py
from typing_extensions import Annotated
from pydantic import BaseModel, PlainValidator
MyInt = Annotated[
int,
PlainValidator(
lambda v: int(v) + 1, json_schema_input_type=Union[str, int] # (1)!
),
]
MyInt = Annotated[int, PlainValidator(lambda v: int(v) + 1)]
class Model(BaseModel):
a: MyInt
print(Model(a='1').a)
#> 2
print(Model(a=1).a)
#> 2
```
1. In this example, we've specified the `json_schema_input_type` as `Union[str, int]` which indicates to the JSON schema
generator that in validation mode, the input type for the `a` field can be either a `str` or an `int`.
"""
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
json_schema_input_type: Any = Any
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
# Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the
# source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper
# serialization schema. To work around this for use cases that will not involve serialization, we simply
# catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema
# and abort any attempts to handle special serialization.
from pydantic import PydanticSchemaGenerationError
try:
schema = handler(source_type)
# TODO if `schema['serialization']` is one of `'include-exclude-dict/sequence',
# schema validation will fail. That's why we use 'type ignore' comments below.
serialization = schema.get(
'serialization',
core_schema.wrap_serializer_function_ser_schema(
function=lambda v, h: h(v),
schema=schema,
return_schema=handler.generate_schema(source_type),
),
)
except PydanticSchemaGenerationError:
serialization = None
input_schema = handler.generate_schema(self.json_schema_input_type)
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
info_arg = _inspect_validator(self.func, 'plain')
if info_arg:
func = cast(core_schema.WithInfoValidatorFunction, self.func)
return core_schema.with_info_plain_validator_function(
func,
field_name=handler.field_name,
serialization=serialization, # pyright: ignore[reportArgumentType]
json_schema_input_schema=input_schema,
)
return core_schema.with_info_plain_validator_function(func, field_name=handler.field_name)
else:
func = cast(core_schema.NoInfoValidatorFunction, self.func)
return core_schema.no_info_plain_validator_function(
func,
serialization=serialization, # pyright: ignore[reportArgumentType]
json_schema_input_schema=input_schema,
)
@classmethod
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
return cls(
func=decorator.func,
json_schema_input_type=decorator.info.json_schema_input_type,
)
return core_schema.no_info_plain_validator_function(func)
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class WrapValidator:
"""!!! abstract "Usage Documentation"
[field *wrap* validators](../concepts/validators.md#field-wrap-validator)
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **around** the inner validation logic.
Attributes:
func: The validator function.
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
JSON Schema (in validation mode).
```python
```py
from datetime import datetime
from typing import Annotated
from typing_extensions import Annotated
from pydantic import BaseModel, ValidationError, WrapValidator
@@ -291,61 +202,37 @@ class WrapValidator:
"""
func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction
json_schema_input_type: Any = PydanticUndefined
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema = handler(source_type)
input_schema = (
None
if self.json_schema_input_type is PydanticUndefined
else handler.generate_schema(self.json_schema_input_type)
)
info_arg = _inspect_validator(self.func, 'wrap')
if info_arg:
func = cast(core_schema.WithInfoWrapValidatorFunction, self.func)
return core_schema.with_info_wrap_validator_function(
func,
schema=schema,
field_name=handler.field_name,
json_schema_input_schema=input_schema,
)
return core_schema.with_info_wrap_validator_function(func, schema=schema, field_name=handler.field_name)
else:
func = cast(core_schema.NoInfoWrapValidatorFunction, self.func)
return core_schema.no_info_wrap_validator_function(
func,
schema=schema,
json_schema_input_schema=input_schema,
)
@classmethod
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
return cls(
func=decorator.func,
json_schema_input_type=decorator.info.json_schema_input_type,
)
return core_schema.no_info_wrap_validator_function(func, schema=schema)
if TYPE_CHECKING:
class _OnlyValueValidatorClsMethod(Protocol):
def __call__(self, cls: Any, value: Any, /) -> Any: ...
def __call__(self, __cls: Any, __value: Any) -> Any:
...
class _V2ValidatorClsMethod(Protocol):
def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any: ...
class _OnlyValueWrapValidatorClsMethod(Protocol):
def __call__(self, cls: Any, value: Any, handler: _core_schema.ValidatorFunctionWrapHandler, /) -> Any: ...
def __call__(self, __cls: Any, __input_value: Any, __info: _core_schema.ValidationInfo) -> Any:
...
class _V2WrapValidatorClsMethod(Protocol):
def __call__(
self,
cls: Any,
value: Any,
handler: _core_schema.ValidatorFunctionWrapHandler,
info: _core_schema.ValidationInfo,
/,
) -> Any: ...
__cls: Any,
__input_value: Any,
__validator: _core_schema.ValidatorFunctionWrapHandler,
__info: _core_schema.ValidationInfo,
) -> Any:
...
_V2Validator = Union[
_V2ValidatorClsMethod,
@@ -357,111 +244,57 @@ if TYPE_CHECKING:
_V2WrapValidator = Union[
_V2WrapValidatorClsMethod,
_core_schema.WithInfoWrapValidatorFunction,
_OnlyValueWrapValidatorClsMethod,
_core_schema.NoInfoWrapValidatorFunction,
]
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
_V2BeforeAfterOrPlainValidatorType = TypeVar(
'_V2BeforeAfterOrPlainValidatorType',
bound=Union[_V2Validator, _PartialClsOrStaticMethod],
_V2Validator,
_PartialClsOrStaticMethod,
)
_V2WrapValidatorType = TypeVar('_V2WrapValidatorType', bound=Union[_V2WrapValidator, _PartialClsOrStaticMethod])
_V2WrapValidatorType = TypeVar('_V2WrapValidatorType', _V2WrapValidator, _PartialClsOrStaticMethod)
@overload
def field_validator(
__field: str,
*fields: str,
mode: Literal['before', 'after', 'plain'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]:
...
@overload
def field_validator(
__field: str,
*fields: str,
mode: Literal['wrap'],
check_fields: bool | None = ...,
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]:
...
FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain']
@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal['wrap'],
check_fields: bool | None = ...,
json_schema_input_type: Any = ...,
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: ...
@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal['before', 'plain'],
check_fields: bool | None = ...,
json_schema_input_type: Any = ...,
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...
@overload
def field_validator(
field: str,
/,
*fields: str,
mode: Literal['after'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...
def field_validator(
field: str,
/,
__field: str,
*fields: str,
mode: FieldValidatorModes = 'after',
check_fields: bool | None = None,
json_schema_input_type: Any = PydanticUndefined,
) -> Callable[[Any], Any]:
"""!!! abstract "Usage Documentation"
[field validators](../concepts/validators.md#field-validators)
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#field-validators
Decorate methods on the class indicating that they should be used to validate fields.
Example usage:
```python
from typing import Any
from pydantic import (
BaseModel,
ValidationError,
field_validator,
)
class Model(BaseModel):
a: str
@field_validator('a')
@classmethod
def ensure_foobar(cls, v: Any):
if 'foobar' not in v:
raise ValueError('"foobar" not found in a')
return v
print(repr(Model(a='this is foobar good')))
#> Model(a='this is foobar good')
try:
Model(a='snap')
except ValidationError as exc_info:
print(exc_info)
'''
1 validation error for Model
a
Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str]
'''
```
For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators).
Args:
field: The first field the `field_validator` should be called on; this is separate
__field: The first field the `field_validator` should be called on; this is separate
from `fields` to ensure an error is raised if you don't pass at least one.
*fields: Additional field(s) the `field_validator` should be called on.
mode: Specifies whether to validate the fields before or after validation.
check_fields: Whether to check that the fields actually exist on the model.
json_schema_input_type: The input type of the function. This is only used to generate
the appropriate JSON Schema (in validation mode) and can only specified
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
Returns:
A decorator that can be used to decorate a function to be used as a field_validator.
@@ -472,23 +305,13 @@ def field_validator(
- If the args passed to `@field_validator` as fields are not strings.
- If `@field_validator` applied to instance methods.
"""
if isinstance(field, FunctionType):
if isinstance(__field, FunctionType):
raise PydanticUserError(
'`@field_validator` should be used with fields and keyword arguments, not bare. '
"E.g. usage should be `@validator('<field_name>', ...)`",
code='validator-no-fields',
)
if mode not in ('before', 'plain', 'wrap') and json_schema_input_type is not PydanticUndefined:
raise PydanticUserError(
f"`json_schema_input_type` can't be used when mode is set to {mode!r}",
code='validator-input-type',
)
if json_schema_input_type is PydanticUndefined and mode == 'plain':
json_schema_input_type = Any
fields = field, *fields
fields = __field, *fields
if not all(isinstance(field, str) for field in fields):
raise PydanticUserError(
'`@field_validator` fields should be passed as separate string args. '
@@ -497,7 +320,7 @@ def field_validator(
)
def dec(
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any],
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any]
) -> _decorators.PydanticDescriptorProxy[Any]:
if _decorators.is_instance_method_from_sig(f):
raise PydanticUserError(
@@ -507,9 +330,7 @@ def field_validator(
# auto apply the @classmethod decorator
f = _decorators.ensure_classmethod_based_on_signature(f)
dec_info = _decorators.FieldValidatorDecoratorInfo(
fields=fields, mode=mode, check_fields=check_fields, json_schema_input_type=json_schema_input_type
)
dec_info = _decorators.FieldValidatorDecoratorInfo(fields=fields, mode=mode, check_fields=check_fields)
return _decorators.PydanticDescriptorProxy(f, dec_info)
return dec
@@ -520,19 +341,16 @@ _ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True)
class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]):
"""`@model_validator` decorated function handler argument type. This is used when `mode='wrap'`."""
"""@model_validator decorated function handler argument type. This is used when `mode='wrap'`."""
def __call__( # noqa: D102
self,
value: Any,
outer_location: str | int | None = None,
/,
self, input_value: Any, outer_location: str | int | None = None
) -> _ModelTypeCo: # pragma: no cover
...
class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
"""A `@model_validator` decorated function signature.
"""A @model_validator decorated function signature.
This is used when `mode='wrap'` and the function does not have info argument.
"""
@@ -542,14 +360,14 @@ class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
handler: ModelWrapValidatorHandler[_ModelType],
/,
) -> _ModelType: ...
__value: Any,
__handler: ModelWrapValidatorHandler[_ModelType],
) -> _ModelType:
...
class ModelWrapValidator(Protocol[_ModelType]):
"""A `@model_validator` decorated function signature. This is used when `mode='wrap'`."""
"""A @model_validator decorated function signature. This is used when `mode='wrap'`."""
def __call__( # noqa: D102
self,
@@ -557,30 +375,15 @@ class ModelWrapValidator(Protocol[_ModelType]):
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
handler: ModelWrapValidatorHandler[_ModelType],
info: _core_schema.ValidationInfo,
/,
) -> _ModelType: ...
class FreeModelBeforeValidatorWithoutInfo(Protocol):
"""A `@model_validator` decorated function signature.
This is used when `mode='before'` and the function does not have info argument.
"""
def __call__( # noqa: D102
self,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
/,
) -> Any: ...
__value: Any,
__handler: ModelWrapValidatorHandler[_ModelType],
__info: _core_schema.ValidationInfo,
) -> _ModelType:
...
class ModelBeforeValidatorWithoutInfo(Protocol):
"""A `@model_validator` decorated function signature.
"""A @model_validator decorated function signature.
This is used when `mode='before'` and the function does not have info argument.
"""
@@ -590,23 +393,9 @@ class ModelBeforeValidatorWithoutInfo(Protocol):
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
/,
) -> Any: ...
class FreeModelBeforeValidator(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""
def __call__( # noqa: D102
self,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
info: _core_schema.ValidationInfo,
/,
) -> Any: ...
__value: Any,
) -> Any:
...
class ModelBeforeValidator(Protocol):
@@ -618,10 +407,10 @@ class ModelBeforeValidator(Protocol):
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
info: _core_schema.ValidationInfo,
/,
) -> Any: ...
__value: Any,
__info: _core_schema.ValidationInfo,
) -> Any:
...
ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType]
@@ -633,9 +422,7 @@ ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _Model
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""
_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]]
_AnyModelBeforeValidator = Union[
FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo
]
_AnyModeBeforeValidator = Union[ModelBeforeValidator, ModelBeforeValidatorWithoutInfo]
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]
@@ -645,16 +432,16 @@ def model_validator(
mode: Literal['wrap'],
) -> Callable[
[_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]: ...
]:
...
@overload
def model_validator(
*,
mode: Literal['before'],
) -> Callable[
[_AnyModelBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]: ...
) -> Callable[[_AnyModeBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]]:
...
@overload
@@ -663,49 +450,15 @@ def model_validator(
mode: Literal['after'],
) -> Callable[
[_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]: ...
]:
...
def model_validator(
*,
mode: Literal['wrap', 'before', 'after'],
) -> Any:
"""!!! abstract "Usage Documentation"
[Model Validators](../concepts/validators.md#model-validators)
Decorate model methods for validation purposes.
Example usage:
```python
from typing_extensions import Self
from pydantic import BaseModel, ValidationError, model_validator
class Square(BaseModel):
width: float
height: float
@model_validator(mode='after')
def verify_square(self) -> Self:
if self.width != self.height:
raise ValueError('width and height do not match')
return self
s = Square(width=1, height=1)
print(repr(s))
#> Square(width=1.0, height=1.0)
try:
Square(width=1, height=2)
except ValidationError as e:
print(e)
'''
1 validation error for Square
Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict]
'''
```
For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators).
"""Decorate model methods for validation purposes.
Args:
mode: A required string literal that specifies the validation mode.
@@ -738,7 +491,7 @@ else:
'''Generic type for annotating a type that is an instance of a given class.
Example:
```python
```py
from pydantic import BaseModel, InstanceOf
class Foo:
@@ -817,7 +570,7 @@ else:
@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
original_schema = handler(source)
metadata = {'pydantic_js_annotation_functions': [lambda _c, h: h(original_schema)]}
metadata = _core_metadata.build_metadata_dict(js_annotation_functions=[lambda _c, h: h(original_schema)])
return core_schema.any_schema(
metadata=metadata,
serialization=core_schema.wrap_serializer_function_ser_schema(

View File

@@ -1,5 +1,4 @@
"""The `generics` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,5 +1,4 @@
"""The `json` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -3,9 +3,8 @@
from __future__ import annotations
import sys
from collections.abc import Iterator
from configparser import ConfigParser
from typing import Any, Callable
from typing import Any, Callable, Iterator
from mypy.errorcodes import ErrorCode
from mypy.expandtype import expand_type, expand_type_by_instance
@@ -15,7 +14,6 @@ from mypy.nodes import (
ARG_OPT,
ARG_POS,
ARG_STAR2,
INVARIANT,
MDEF,
Argument,
AssignmentStmt,
@@ -47,24 +45,26 @@ from mypy.options import Options
from mypy.plugin import (
CheckerPluginInterface,
ClassDefContext,
FunctionContext,
MethodContext,
Plugin,
ReportConfigContext,
SemanticAnalyzerPluginInterface,
)
from mypy.plugins import dataclasses
from mypy.plugins.common import (
deserialize_and_fixup_type,
)
from mypy.semanal import set_callable_name
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.type_visitor import TypeTranslator
from mypy.typeops import map_type_from_supertype
from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
Overloaded,
Type,
TypeOfAny,
TypeType,
@@ -79,11 +79,16 @@ from mypy.version import __version__ as mypy_version
from pydantic._internal import _fields
from pydantic.version import parse_mypy_version
try:
from mypy.types import TypeVarDef # type: ignore[attr-defined]
except ImportError: # pragma: no cover
# Backward-compatible with TypeVarDef from Mypy 0.930.
from mypy.types import TypeVarType as TypeVarDef
CONFIGFILE_KEY = 'pydantic-mypy'
METADATA_KEY = 'pydantic-mypy-metadata'
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings'
ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel'
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass'
FIELD_FULLNAME = 'pydantic.fields.Field'
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
@@ -96,11 +101,10 @@ DECORATOR_FULLNAMES = {
'pydantic.deprecated.class_validators.validator',
'pydantic.deprecated.class_validators.root_validator',
}
IMPLICIT_CLASSMETHOD_DECORATOR_FULLNAMES = DECORATOR_FULLNAMES - {'pydantic.functional_serializers.model_serializer'}
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
BUILTINS_NAME = 'builtins'
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'
# Increment version if plugin changes and mypy caches should be invalidated
__version__ = 2
@@ -129,12 +133,12 @@ class PydanticPlugin(Plugin):
self._plugin_data = self.plugin_config.to_data()
super().__init__(options)
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None:
"""Update Pydantic model class."""
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo): # pragma: no branch
# No branching may occur if the mypy cache has not been cleared
if sym.node.has_base(BASEMODEL_FULLNAME):
if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro):
return self._pydantic_model_class_maker_callback
return None
@@ -144,12 +148,28 @@ class PydanticPlugin(Plugin):
return self._pydantic_model_metaclass_marker_callback
return None
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
"""Adjust the return type of the `Field` function."""
sym = self.lookup_fully_qualified(fullname)
if sym and sym.fullname == FIELD_FULLNAME:
return self._pydantic_field_callback
return None
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
"""Adjust return type of `from_orm` method call."""
if fullname.endswith('.from_orm'):
return from_attributes_callback
return None
def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
"""Mark pydantic.dataclasses as dataclass.
Mypy version 1.1.1 added support for `@dataclass_transform` decorator.
"""
if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1):
return dataclasses.dataclass_class_maker_callback # type: ignore[return-value]
return None
def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]:
"""Return all plugin config data.
@@ -157,9 +177,9 @@ class PydanticPlugin(Plugin):
"""
return self._plugin_data
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
transformer.transform()
return transformer.transform()
def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
"""Reset dataclass_transform_spec attribute of ModelMetaclass.
@@ -174,6 +194,54 @@ class PydanticPlugin(Plugin):
if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
info_metaclass.type.dataclass_transform_spec = None
def _pydantic_field_callback(self, ctx: FunctionContext) -> Type:
"""Extract the type of the `default` argument from the Field function, and use it as the return type.
In particular:
* Check whether the default and default_factory argument is specified.
* Output an error if both are specified.
* Retrieve the type of the argument which is specified, and use it as return type for the function.
"""
default_any_type = ctx.default_return_type
assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
default_args = ctx.args[0]
default_factory_args = ctx.args[1]
if default_args and default_factory_args:
error_default_and_default_factory_specified(ctx.api, ctx.context)
return default_any_type
if default_args:
default_type = ctx.arg_types[0][0]
default_arg = default_args[0]
# Fallback to default Any type if the field is required
if not isinstance(default_arg, EllipsisExpr):
return default_type
elif default_factory_args:
default_factory_type = ctx.arg_types[1][0]
# Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
# Pydantic calls the default factory without any argument, so we retrieve the first item
if isinstance(default_factory_type, Overloaded):
default_factory_type = default_factory_type.items[0]
if isinstance(default_factory_type, CallableType):
ret_type = default_factory_type.ret_type
# mypy doesn't think `ret_type` has `args`, you'd think mypy should know,
# add this check in case it varies by version
args = getattr(ret_type, 'args', None)
if args:
if all(isinstance(arg, TypeVarType) for arg in args):
# Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined]
return ret_type
return default_any_type
class PydanticPluginConfig:
"""A Pydantic mypy plugin config holder.
@@ -238,9 +306,6 @@ def from_attributes_callback(ctx: MethodContext) -> Type:
pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
if pydantic_metadata is None:
return ctx.default_return_type
if not model_type.type.has_base(BASEMODEL_FULLNAME):
# not a Pydantic v2 model
return ctx.default_return_type
from_attributes = pydantic_metadata.get('config', {}).get('from_attributes')
if from_attributes is not True:
error_from_attributes(model_type.type.name, ctx.api, ctx.context)
@@ -254,10 +319,8 @@ class PydanticModelField:
self,
name: str,
alias: str | None,
is_frozen: bool,
has_dynamic_alias: bool,
has_default: bool,
strict: bool | None,
line: int,
column: int,
type: Type | None,
@@ -265,103 +328,40 @@ class PydanticModelField:
):
self.name = name
self.alias = alias
self.is_frozen = is_frozen
self.has_dynamic_alias = has_dynamic_alias
self.has_default = has_default
self.strict = strict
self.line = line
self.column = column
self.type = type
self.info = info
def to_argument(
self,
current_info: TypeInfo,
typed: bool,
model_strict: bool,
force_optional: bool,
use_alias: bool,
api: SemanticAnalyzerPluginInterface,
force_typevars_invariant: bool,
is_root_model_root: bool,
) -> Argument:
def to_argument(self, current_info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument."""
variable = self.to_var(current_info, api, use_alias, force_typevars_invariant)
strict = model_strict if self.strict is None else self.strict
if typed or strict:
type_annotation = self.expand_type(current_info, api, include_root_type=True)
else:
type_annotation = AnyType(TypeOfAny.explicit)
return Argument(
variable=variable,
type_annotation=type_annotation,
variable=self.to_var(current_info, use_alias),
type_annotation=self.expand_type(current_info) if typed else AnyType(TypeOfAny.explicit),
initializer=None,
kind=ARG_OPT
if is_root_model_root
else (ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED),
kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED,
)
def expand_type(
self,
current_info: TypeInfo,
api: SemanticAnalyzerPluginInterface,
force_typevars_invariant: bool = False,
include_root_type: bool = False,
) -> Type | None:
def expand_type(self, current_info: TypeInfo) -> Type | None:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type."""
if force_typevars_invariant:
# In some cases, mypy will emit an error "Cannot use a covariant type variable as a parameter"
# To prevent that, we add an option to replace typevars with invariant ones while building certain
# method signatures (in particular, `__init__`). There may be a better way to do this, if this causes
# us problems in the future, we should look into why the dataclasses plugin doesn't have this issue.
if isinstance(self.type, TypeVarType):
modified_type = self.type.copy_modified()
modified_type.variance = INVARIANT
self.type = modified_type
if self.type is not None and self.info.self_type is not None:
# In general, it is not safe to call `expand_type()` during semantic analysis,
# In general, it is not safe to call `expand_type()` during semantic analyzis,
# however this plugin is called very late, so all types should be fully ready.
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
# we serialize attributes).
with state.strict_optional_set(api.options.strict_optional):
filled_with_typevars = fill_typevars(current_info)
# Cannot be TupleType as current_info represents a Pydantic model:
assert isinstance(filled_with_typevars, Instance)
if force_typevars_invariant:
for arg in filled_with_typevars.args:
if isinstance(arg, TypeVarType):
arg.variance = INVARIANT
expanded_type = expand_type(self.type, {self.info.self_type.id: filled_with_typevars})
if include_root_type and isinstance(expanded_type, Instance) and is_root_model(expanded_type.type):
# When a root model is used as a field, Pydantic allows both an instance of the root model
# as well as instances of the `root` field type:
root_type = expanded_type.type['root'].type
if root_type is None:
# Happens if the hint for 'root' has unsolved forward references
return expanded_type
expanded_root_type = expand_type_by_instance(root_type, expanded_type)
expanded_type = UnionType([expanded_type, expanded_root_type])
return expanded_type
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
return self.type
def to_var(
self,
current_info: TypeInfo,
api: SemanticAnalyzerPluginInterface,
use_alias: bool,
force_typevars_invariant: bool = False,
) -> Var:
def to_var(self, current_info: TypeInfo, use_alias: bool) -> Var:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_var."""
if use_alias and self.alias is not None:
name = self.alias
else:
name = self.name
return Var(name, self.expand_type(current_info, api, force_typevars_invariant))
return Var(name, self.expand_type(current_info))
def serialize(self) -> JsonDict:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
@@ -369,10 +369,8 @@ class PydanticModelField:
return {
'name': self.name,
'alias': self.alias,
'is_frozen': self.is_frozen,
'has_dynamic_alias': self.has_dynamic_alias,
'has_default': self.has_default,
'strict': self.strict,
'line': self.line,
'column': self.column,
'type': self.type.serialize(),
@@ -385,38 +383,12 @@ class PydanticModelField:
typ = deserialize_and_fixup_type(data.pop('type'), api)
return cls(type=typ, info=info, **data)
def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None:
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
"""Expands type vars in the context of a subtype when an attribute is inherited
from a generic super type.
"""
if self.type is not None:
with state.strict_optional_set(api.options.strict_optional):
self.type = map_type_from_supertype(self.type, sub_type, self.info)
class PydanticModelClassVar:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.
ClassVars are ignored by subclasses.
Attributes:
name: the ClassVar name
"""
def __init__(self, name):
self.name = name
@classmethod
def deserialize(cls, data: JsonDict) -> PydanticModelClassVar:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize."""
data = data.copy()
return cls(**data)
def serialize(self) -> JsonDict:
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
return {
'name': self.name,
}
self.type = map_type_from_supertype(self.type, sub_type, self.info)
class PydanticModelTransformer:
@@ -431,10 +403,7 @@ class PydanticModelTransformer:
'frozen',
'from_attributes',
'populate_by_name',
'validate_by_alias',
'validate_by_name',
'alias_generator',
'strict',
}
def __init__(
@@ -461,26 +430,24 @@ class PydanticModelTransformer:
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
"""
info = self._cls.info
is_a_root_model = is_root_model(info)
config = self.collect_config()
fields, class_vars = self.collect_fields_and_class_vars(config, is_a_root_model)
if fields is None or class_vars is None:
fields = self.collect_fields(config)
if fields is None:
# Some definitions are not ready. We need another pass.
return False
for field in fields:
if field.type is None:
return False
is_settings = info.has_base(BASESETTINGS_FULLNAME)
self.add_initializer(fields, config, is_settings, is_a_root_model)
self.add_model_construct_method(fields, config, is_settings, is_a_root_model)
self.set_frozen(fields, self._api, frozen=config.frozen is True)
is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings)
self.add_model_construct_method(fields, config, is_settings)
self.set_frozen(fields, frozen=config.frozen is True)
self.adjust_decorator_signatures()
info.metadata[METADATA_KEY] = {
'fields': {field.name: field.serialize() for field in fields},
'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars},
'config': config.get_values_dict(),
}
@@ -494,13 +461,13 @@ class PydanticModelTransformer:
Teach mypy this by marking any function whose outermost decorator is a `validator()`,
`field_validator()` or `serializer()` call as a `classmethod`.
"""
for sym in self._cls.info.names.values():
for name, sym in self._cls.info.names.items():
if isinstance(sym.node, Decorator):
first_dec = sym.node.original_decorators[0]
if (
isinstance(first_dec, CallExpr)
and isinstance(first_dec.callee, NameExpr)
and first_dec.callee.fullname in IMPLICIT_CLASSMETHOD_DECORATOR_FULLNAMES
and first_dec.callee.fullname in DECORATOR_FULLNAMES
# @model_validator(mode="after") is an exception, it expects a regular method
and not (
first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME
@@ -543,7 +510,7 @@ class PydanticModelTransformer:
for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args):
if arg_name is None:
continue
config.update(self.get_config_update(arg_name, arg, lax_extra=True))
config.update(self.get_config_update(arg_name, arg))
elif isinstance(stmt.rvalue, DictExpr): # dict literals
for key_expr, value_expr in stmt.rvalue.items:
if not isinstance(key_expr, StrExpr):
@@ -574,7 +541,7 @@ class PydanticModelTransformer:
if (
stmt
and config.has_alias_generator
and not (config.validate_by_name or config.populate_by_name)
and not config.populate_by_name
and self.plugin_config.warn_required_dynamic_aliases
):
error_required_dynamic_aliases(self._api, stmt)
@@ -589,13 +556,11 @@ class PydanticModelTransformer:
config.setdefault(name, value)
return config
def collect_fields_and_class_vars(
self, model_config: ModelConfigData, is_root_model: bool
) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]:
def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None:
"""Collects the fields for the model, accounting for parent classes."""
cls = self._cls
# First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates.
# First, collect fields belonging to any class in the MRO, ignoring duplicates.
#
# We iterate through the MRO in reverse because attrs defined in the parent must appear
# earlier in the attributes list than attrs defined in the child. See:
@@ -605,11 +570,10 @@ class PydanticModelTransformer:
# in the parent. We can implement this via a dict without disrupting the attr order
# because dicts preserve insertion order in Python 3.7+.
found_fields: dict[str, PydanticModelField] = {}
found_class_vars: dict[str, PydanticModelClassVar] = {}
for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object
# if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata:
# # We haven't processed the base class yet. Need another pass.
# return None, None
# return None
if METADATA_KEY not in info.metadata:
continue
@@ -622,7 +586,8 @@ class PydanticModelTransformer:
# TODO: We shouldn't be performing type operations during the main
# semantic analysis pass, since some TypeInfo attributes might
# still be in flux. This should be performed in a later phase.
field.expand_typevar_from_subtype(cls.info, self._api)
with state.strict_optional_set(self._api.options.strict_optional):
field.expand_typevar_from_subtype(cls.info)
found_fields[name] = field
sym_node = cls.info.names.get(name)
@@ -631,31 +596,17 @@ class PydanticModelTransformer:
'BaseModel field may only be overridden by another field',
sym_node.node,
)
# Collect ClassVars
for name, data in info.metadata[METADATA_KEY]['class_vars'].items():
found_class_vars[name] = PydanticModelClassVar.deserialize(data)
# Second, collect fields and ClassVars belonging to the current class.
# Second, collect fields belonging to the current class.
current_field_names: set[str] = set()
current_class_vars_names: set[str] = set()
for stmt in self._get_assignment_statements_from_block(cls.defs):
maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars)
if maybe_field is None:
continue
maybe_field = self.collect_field_from_stmt(stmt, model_config)
if maybe_field is not None:
lhs = stmt.lvalues[0]
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field
lhs = stmt.lvalues[0]
assert isinstance(lhs, NameExpr) # collect_field_or_class_var_from_stmt guarantees this
if isinstance(maybe_field, PydanticModelField):
if is_root_model and lhs.name != 'root':
error_extra_fields_on_root_model(self._api, stmt)
else:
current_field_names.add(lhs.name)
found_fields[lhs.name] = maybe_field
elif isinstance(maybe_field, PydanticModelClassVar):
current_class_vars_names.add(lhs.name)
found_class_vars[lhs.name] = maybe_field
return list(found_fields.values()), list(found_class_vars.values())
return list(found_fields.values())
def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]:
for body in stmt.body:
@@ -671,15 +622,14 @@ class PydanticModelTransformer:
elif isinstance(stmt, IfStmt):
yield from self._get_assignment_statements_from_if_statement(stmt)
def collect_field_or_class_var_from_stmt( # noqa C901
self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar]
) -> PydanticModelField | PydanticModelClassVar | None:
def collect_field_from_stmt( # noqa C901
self, stmt: AssignmentStmt, model_config: ModelConfigData
) -> PydanticModelField | None:
"""Get pydantic model field from statement.
Args:
stmt: The statement.
model_config: Configuration settings for the model.
class_vars: ClassVars already known to be defined on the model.
Returns:
A pydantic model field if it could find the field in statement. Otherwise, `None`.
@@ -702,10 +652,6 @@ class PydanticModelTransformer:
# Eventually, we may want to attempt to respect model_config['ignored_types']
return None
if lhs.name in class_vars:
# Class vars are not fields and are not required to be annotated
return None
# The assignment does not have an annotation, and it's not anything else we recognize
error_untyped_fields(self._api, stmt)
return None
@@ -750,7 +696,7 @@ class PydanticModelTransformer:
# x: ClassVar[int] is not a field
if node.is_classvar:
return PydanticModelClassVar(lhs.name)
return None
# x: InitVar[int] is not supported in BaseModel
node_type = get_proper_type(node.type)
@@ -761,7 +707,6 @@ class PydanticModelTransformer:
)
has_default = self.get_has_default(stmt)
strict = self.get_strict(stmt)
if sym.type is None and node.is_final and node.is_inferred:
# This follows the logic from the dataclasses plugin. The following comment is taken verbatim:
@@ -781,27 +726,16 @@ class PydanticModelTransformer:
)
node.type = AnyType(TypeOfAny.from_error)
if node.is_final and has_default:
# TODO this path should be removed (see https://github.com/pydantic/pydantic/issues/11119)
return PydanticModelClassVar(lhs.name)
alias, has_dynamic_alias = self.get_alias_info(stmt)
if (
has_dynamic_alias
and not (model_config.validate_by_name or model_config.populate_by_name)
and self.plugin_config.warn_required_dynamic_aliases
):
if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases:
error_required_dynamic_aliases(self._api, stmt)
is_frozen = self.is_field_frozen(stmt)
init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt)
return PydanticModelField(
name=lhs.name,
has_dynamic_alias=has_dynamic_alias,
has_default=has_default,
strict=strict,
alias=alias,
is_frozen=is_frozen,
line=stmt.line,
column=stmt.column,
type=init_type,
@@ -846,9 +780,7 @@ class PydanticModelTransformer:
return default
def add_initializer(
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool
) -> None:
def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None:
"""Adds a fields-aware `__init__` method to the class.
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
@@ -857,42 +789,28 @@ class PydanticModelTransformer:
return # Don't generate an __init__ if one already exists
typed = self.plugin_config.init_typed
model_strict = bool(config.strict)
use_alias = not (config.validate_by_name or config.populate_by_name) and config.validate_by_alias is not False
requires_dynamic_aliases = bool(config.has_alias_generator and not config.validate_by_name)
args = self.get_field_arguments(
fields,
typed=typed,
model_strict=model_strict,
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
is_root_model=is_root_model,
force_typevars_invariant=True,
)
if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
assert isinstance(base_settings_node, TypeInfo)
if '__init__' in base_settings_node.names:
base_settings_init_node = base_settings_node.names['__init__'].node
assert isinstance(base_settings_init_node, FuncDef)
if base_settings_init_node is not None and base_settings_init_node.type is not None:
func_type = base_settings_init_node.type
assert isinstance(func_type, CallableType)
for arg_idx, arg_name in enumerate(func_type.arg_names):
if arg_name is None or arg_name.startswith('__') or not arg_name.startswith('_'):
continue
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
if analyzed_variable_type is not None and arg_name == '_cli_settings_source':
# _cli_settings_source is defined as CliSettingsSource[Any], and as such
# the Any causes issues with --disallow-any-explicit. As a workaround, change
# the Any type (as if CliSettingsSource was left unparameterized):
analyzed_variable_type = analyzed_variable_type.accept(
ChangeExplicitTypeOfAny(TypeOfAny.from_omitted_generics)
)
variable = Var(arg_name, analyzed_variable_type)
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))
use_alias = config.populate_by_name is not True
requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name)
with state.strict_optional_set(self._api.options.strict_optional):
args = self.get_field_arguments(
fields,
typed=typed,
requires_dynamic_aliases=requires_dynamic_aliases,
use_alias=use_alias,
is_settings=is_settings,
)
if is_settings:
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
if '__init__' in base_settings_node.names:
base_settings_init_node = base_settings_node.names['__init__'].node
if base_settings_init_node is not None and base_settings_init_node.type is not None:
func_type = base_settings_init_node.type
for arg_idx, arg_name in enumerate(func_type.arg_names):
if arg_name.startswith('__') or not arg_name.startswith('_'):
continue
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
variable = Var(arg_name, analyzed_variable_type)
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))
if not self.should_init_forbid_extra(fields, config):
var = Var('kwargs')
@@ -901,11 +819,7 @@ class PydanticModelTransformer:
add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType())
def add_model_construct_method(
self,
fields: list[PydanticModelField],
config: ModelConfigData,
is_settings: bool,
is_root_model: bool,
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool
) -> None:
"""Adds a fully typed `model_construct` classmethod to the class.
@@ -917,19 +831,13 @@ class PydanticModelTransformer:
fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
with state.strict_optional_set(self._api.options.strict_optional):
args = self.get_field_arguments(
fields,
typed=True,
model_strict=bool(config.strict),
requires_dynamic_aliases=False,
use_alias=False,
is_settings=is_settings,
is_root_model=is_root_model,
fields, typed=True, requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings
)
if not self.should_init_forbid_extra(fields, config):
var = Var('kwargs')
args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))
args = args + [fields_set_argument] if is_root_model else [fields_set_argument] + args
args = [fields_set_argument] + args
add_method(
self._api,
@@ -940,7 +848,7 @@ class PydanticModelTransformer:
is_classmethod=True,
)
def set_frozen(self, fields: list[PydanticModelField], api: SemanticAnalyzerPluginInterface, frozen: bool) -> None:
def set_frozen(self, fields: list[PydanticModelField], frozen: bool) -> None:
"""Marks all fields as properties so that attempts to set them trigger mypy errors.
This is the same approach used by the attrs and dataclasses plugins.
@@ -951,7 +859,7 @@ class PydanticModelTransformer:
if sym_node is not None:
var = sym_node.node
if isinstance(var, Var):
var.is_property = frozen or field.is_frozen
var.is_property = frozen
elif isinstance(var, PlaceholderNode) and not self._api.final_iteration:
# See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage
self._api.defer()
@@ -965,13 +873,13 @@ class PydanticModelTransformer:
detail = f'sym_node.node: {var_str} (of type {var.__class__})'
error_unexpected_behavior(detail, self._api, self._cls)
else:
var = field.to_var(info, api, use_alias=False)
var = field.to_var(info, use_alias=False)
var.info = info
var.is_property = frozen
var._fullname = info.fullname + '.' + var.name
info.names[var.name] = SymbolTableNode(MDEF, var)
def get_config_update(self, name: str, arg: Expression, lax_extra: bool = False) -> ModelConfigData | None:
def get_config_update(self, name: str, arg: Expression) -> ModelConfigData | None:
"""Determines the config update due to a single kwarg in the ConfigDict definition.
Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int)
@@ -984,16 +892,7 @@ class PydanticModelTransformer:
elif isinstance(arg, MemberExpr):
forbid_extra = arg.name == 'forbid'
else:
if not lax_extra:
# Only emit an error for other types of `arg` (e.g., `NameExpr`, `ConditionalExpr`, etc.) when
# reading from a config class, etc. If a ConfigDict is used, then we don't want to emit an error
# because you'll get type checking from the ConfigDict itself.
#
# It would be nice if we could introspect the types better otherwise, but I don't know what the API
# is to evaluate an expr into its type and then check if that type is compatible with the expected
# type. Note that you can still get proper type checking via: `model_config = ConfigDict(...)`, just
# if you don't use an explicit string, the plugin won't be able to infer whether extra is forbidden.
error_invalid_config_value(name, self._api, arg)
error_invalid_config_value(name, self._api, arg)
return None
return ModelConfigData(forbid_extra=forbid_extra)
if name == 'alias_generator':
@@ -1028,22 +927,6 @@ class PydanticModelTransformer:
# Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
return not isinstance(expr, EllipsisExpr)
@staticmethod
def get_strict(stmt: AssignmentStmt) -> bool | None:
"""Returns a the `strict` value of a field if defined, otherwise `None`."""
expr = stmt.rvalue
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
for arg, name in zip(expr.args, expr.arg_names):
if name != 'strict':
continue
if isinstance(arg, NameExpr):
if arg.fullname == 'builtins.True':
return True
elif arg.fullname == 'builtins.False':
return False
return None
return None
@staticmethod
def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]:
"""Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.
@@ -1062,53 +945,23 @@ class PydanticModelTransformer:
# Assigned value is not a call to pydantic.fields.Field
return None, False
if 'validation_alias' in expr.arg_names:
arg = expr.args[expr.arg_names.index('validation_alias')]
elif 'alias' in expr.arg_names:
arg = expr.args[expr.arg_names.index('alias')]
else:
return None, False
if isinstance(arg, StrExpr):
return arg.value, False
else:
return None, True
@staticmethod
def is_field_frozen(stmt: AssignmentStmt) -> bool:
"""Returns whether the field is frozen, extracted from the declaration of the field defined in `stmt`.
Note that this is only whether the field was declared to be frozen in a `<field_name> = Field(frozen=True)`
sense; this does not determine whether the field is frozen because the entire model is frozen; that is
handled separately.
"""
expr = stmt.rvalue
if isinstance(expr, TempNode):
# TempNode means annotation-only
return False
if not (
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
):
# Assigned value is not a call to pydantic.fields.Field
return False
for i, arg_name in enumerate(expr.arg_names):
if arg_name == 'frozen':
arg = expr.args[i]
return isinstance(arg, NameExpr) and arg.fullname == 'builtins.True'
return False
if arg_name != 'alias':
continue
arg = expr.args[i]
if isinstance(arg, StrExpr):
return arg.value, False
else:
return None, True
return None, False
def get_field_arguments(
self,
fields: list[PydanticModelField],
typed: bool,
model_strict: bool,
use_alias: bool,
requires_dynamic_aliases: bool,
is_settings: bool,
is_root_model: bool,
force_typevars_invariant: bool = False,
) -> list[Argument]:
"""Helper function used during the construction of the `__init__` and `model_construct` method signatures.
@@ -1117,14 +970,7 @@ class PydanticModelTransformer:
info = self._cls.info
arguments = [
field.to_argument(
info,
typed=typed,
model_strict=model_strict,
force_optional=requires_dynamic_aliases or is_settings,
use_alias=use_alias,
api=self._api,
force_typevars_invariant=force_typevars_invariant,
is_root_model_root=is_root_model and field.name == 'root',
info, typed=typed, force_optional=requires_dynamic_aliases or is_settings, use_alias=use_alias
)
for field in fields
if not (use_alias and field.has_dynamic_alias)
@@ -1137,7 +983,7 @@ class PydanticModelTransformer:
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
*unless* a required dynamic alias is present (since then we can't determine a valid signature).
"""
if not (config.validate_by_name or config.populate_by_name):
if not config.populate_by_name:
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
return False
if config.forbid_extra:
@@ -1159,20 +1005,6 @@ class PydanticModelTransformer:
return False
class ChangeExplicitTypeOfAny(TypeTranslator):
"""A type translator used to change type of Any's, if explicit."""
def __init__(self, type_of_any: int) -> None:
self._type_of_any = type_of_any
super().__init__()
def visit_any(self, t: AnyType) -> Type: # noqa: D102
if t.type_of_any == TypeOfAny.explicit:
return t.copy_modified(type_of_any=self._type_of_any)
else:
return t
class ModelConfigData:
"""Pydantic mypy plugin model config class."""
@@ -1182,19 +1014,13 @@ class ModelConfigData:
frozen: bool | None = None,
from_attributes: bool | None = None,
populate_by_name: bool | None = None,
validate_by_alias: bool | None = None,
validate_by_name: bool | None = None,
has_alias_generator: bool | None = None,
strict: bool | None = None,
):
self.forbid_extra = forbid_extra
self.frozen = frozen
self.from_attributes = from_attributes
self.populate_by_name = populate_by_name
self.validate_by_alias = validate_by_alias
self.validate_by_name = validate_by_name
self.has_alias_generator = has_alias_generator
self.strict = strict
def get_values_dict(self) -> dict[str, Any]:
"""Returns a dict of Pydantic model config names to their values.
@@ -1216,18 +1042,12 @@ class ModelConfigData:
setattr(self, key, value)
def is_root_model(info: TypeInfo) -> bool:
"""Return whether the type info is a root model subclass (or the `RootModel` class itself)."""
return info.has_base(ROOT_MODEL_FULLNAME)
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic')
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic')
def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
@@ -1264,9 +1084,9 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context)
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)
def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when there is more than just a root field defined for a subclass of RootModel."""
api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL)
def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
"""Emits an error when `Field` has both `default` and `default_factory` together."""
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
def add_method(
@@ -1276,7 +1096,7 @@ def add_method(
args: list[Argument],
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
tvar_def: TypeVarDef | None = None,
is_classmethod: bool = False,
) -> None:
"""Very closely related to `mypy.plugins.common.add_method_to_class`, with a few pydantic-specific changes."""
@@ -1299,16 +1119,6 @@ def add_method(
first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)]
else:
self_type = self_type or fill_typevars(info)
# `self` is positional *ONLY* here, but this can't be expressed
# fully in the mypy internal API. ARG_POS is the closest we can get.
# Using ARG_POS will, however, give mypy errors if a `self` field
# is present on a model:
#
# Name "self" already defined (possibly by an import) [no-redef]
#
# As a workaround, we give this argument a name that will
# never conflict. By its positional nature, this name will not
# be used or exposed to users.
first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
args = first + args
@@ -1319,9 +1129,9 @@ def add_method(
arg_names.append(arg.variable.name)
arg_kinds.append(arg.kind)
signature = CallableType(
arg_types, arg_kinds, arg_names, return_type, function_type, variables=[tvar_def] if tvar_def else None
)
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
func = FuncDef(name, args, Block([PassStmt()]))
func.info = info

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,4 @@
"""The `parse` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,12 +1,10 @@
"""!!! abstract "Usage Documentation"
[Build a Plugin](../concepts/plugins.md#build-a-plugin)
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/plugins#build-a-plugin
Plugin interface for Pydantic plugins, and related types.
"""
from __future__ import annotations
from typing import Any, Callable, Literal, NamedTuple
from typing import Any, Callable
from pydantic_core import CoreConfig, CoreSchema, ValidationError
from typing_extensions import Protocol, TypeAlias
@@ -18,32 +16,17 @@ __all__ = (
'ValidateJsonHandlerProtocol',
'ValidateStringsHandlerProtocol',
'NewSchemaReturns',
'SchemaTypePath',
'SchemaKind',
)
NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'
class SchemaTypePath(NamedTuple):
"""Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""
module: str
name: str
SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']
class PydanticPluginProtocol(Protocol):
"""Protocol defining the interface for Pydantic plugins."""
def new_schema_validator(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
@@ -57,9 +40,6 @@ class PydanticPluginProtocol(Protocol):
Args:
schema: The schema to validate against.
schema_type: The original type which the schema was created from, e.g. the model class.
schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
schema_kind: The kind of schema to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.
@@ -96,14 +76,6 @@ class BaseValidateHandlerProtocol(Protocol):
"""
return
def on_exception(self, exception: Exception) -> None:
"""Callback to be notified of validation exceptions.
Args:
exception: The exception raised during validation.
"""
return
class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
"""Event handler for `SchemaValidator.validate_python`."""
@@ -116,8 +88,6 @@ class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
by_alias: bool | None = None,
by_name: bool | None = None,
) -> None:
"""Callback to be notified of validation start, and create an instance of the event handler.
@@ -128,8 +98,6 @@ class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
context: The context to use for validation, this is passed to functional validators.
self_instance: An instance of a model to set attributes on from validation, this is used when running
validation from the `__init__` method of a model.
by_alias: Whether to use the field's alias to match the input data to an attribute.
by_name: Whether to use the field's name to match the input data to an attribute.
"""
pass
@@ -144,8 +112,6 @@ class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
strict: bool | None = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
by_alias: bool | None = None,
by_name: bool | None = None,
) -> None:
"""Callback to be notified of validation start, and create an instance of the event handler.
@@ -155,8 +121,6 @@ class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
context: The context to use for validation, this is passed to functional validators.
self_instance: An instance of a model to set attributes on from validation, this is used when running
validation from the `__init__` method of a model.
by_alias: Whether to use the field's alias to match the input data to an attribute.
by_name: Whether to use the field's name to match the input data to an attribute.
"""
pass
@@ -168,13 +132,7 @@ class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
"""Event handler for `SchemaValidator.validate_strings`."""
def on_enter(
self,
input: StringInput,
*,
strict: bool | None = None,
context: dict[str, Any] | None = None,
by_alias: bool | None = None,
by_name: bool | None = None,
self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
) -> None:
"""Callback to be notified of validation start, and create an instance of the event handler.
@@ -182,7 +140,5 @@ class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
input: The string data to be validated.
strict: Whether to validate the object in strict mode.
context: The context to use for validation, this is passed to functional validators.
by_alias: Whether to use the field's alias to match the input data to an attribute.
by_name: Whether to use the field's name to match the input data to an attribute.
"""
pass

View File

@@ -1,10 +1,16 @@
from __future__ import annotations
import importlib.metadata as importlib_metadata
import os
import sys
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Final
from typing import TYPE_CHECKING, Iterable
from typing_extensions import Final
if sys.version_info >= (3, 8):
import importlib.metadata as importlib_metadata
else:
import importlib_metadata
if TYPE_CHECKING:
from . import PydanticPluginProtocol
@@ -24,13 +30,10 @@ def get_plugins() -> Iterable[PydanticPluginProtocol]:
Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402
"""
disabled_plugins = os.getenv('PYDANTIC_DISABLE_PLUGINS')
global _plugins, _loading_plugins
if _loading_plugins:
# this happens when plugins themselves use pydantic, we return no plugins
return ()
elif disabled_plugins in ('__all__', '1', 'true'):
return ()
elif _plugins is None:
_plugins = {}
# set _loading_plugins so any plugins that use pydantic don't themselves use plugins
@@ -42,8 +45,6 @@ def get_plugins() -> Iterable[PydanticPluginProtocol]:
continue
if entry_point.value in _plugins:
continue
if disabled_plugins is not None and entry_point.name in disabled_plugins.split(','):
continue
try:
_plugins[entry_point.value] = entry_point.load()
except (ImportError, AttributeError) as e:

View File

@@ -1,16 +1,14 @@
"""Pluggable schema validator for pydantic."""
from __future__ import annotations
import functools
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar
from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
from typing_extensions import ParamSpec
from typing_extensions import Literal, ParamSpec
if TYPE_CHECKING:
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol
P = ParamSpec('P')
@@ -20,33 +18,18 @@ events: list[Event] = list(Event.__args__) # type: ignore
def create_schema_validator(
schema: CoreSchema,
schema_type: Any,
schema_type_module: str,
schema_type_name: str,
schema_kind: SchemaKind,
config: CoreConfig | None = None,
plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator | PluggableSchemaValidator:
schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
) -> SchemaValidator:
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
Returns:
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
"""
from . import SchemaTypePath
from ._loader import get_plugins
plugins = get_plugins()
if plugins:
return PluggableSchemaValidator(
schema,
schema_type,
SchemaTypePath(schema_type_module, schema_type_name),
schema_kind,
config,
plugins,
plugin_settings or {},
)
return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore
else:
return SchemaValidator(schema, config)
@@ -59,9 +42,6 @@ class PluggableSchemaValidator:
def __init__(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
@@ -72,12 +52,7 @@ class PluggableSchemaValidator:
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
try:
p, j, s = plugin.new_schema_validator(
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
)
except TypeError as e: # pragma: no cover
raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
p, j, s = plugin.new_schema_validator(schema, config, plugin_settings)
if p is not None:
python_event_handlers.append(p)
if j is not None:
@@ -100,7 +75,6 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler
on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -113,10 +87,6 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler
for on_error_handler in on_errors:
on_error_handler(error)
raise
except Exception as exception:
for on_exception_handler in on_exceptions:
on_exception_handler(exception)
raise
else:
for on_success_handler in on_successes:
on_success_handler(result)

View File

@@ -8,33 +8,25 @@ from copy import copy, deepcopy
from pydantic_core import PydanticUndefined
from . import PydanticUserError
from ._internal import _model_construction, _repr
from ._internal import _repr
from .main import BaseModel, _object_setattr
if typing.TYPE_CHECKING:
from typing import Any, Literal
from typing import Any
from typing_extensions import Self, dataclass_transform
from typing_extensions import Literal
from .fields import Field as PydanticModelField
from .fields import PrivateAttr as PydanticModelPrivateAttr
Model = typing.TypeVar('Model', bound='BaseModel')
# dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
# takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
# on a new metaclass.
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr))
class _RootModelMetaclass(_model_construction.ModelMetaclass): ...
else:
_RootModelMetaclass = _model_construction.ModelMetaclass
__all__ = ('RootModel',)
RootModelRootType = typing.TypeVar('RootModelRootType')
class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
"""!!! abstract "Usage Documentation"
[`RootModel` and Custom Root Types](../concepts/models.md#rootmodel-and-custom-root-types)
class RootModel(BaseModel, typing.Generic[RootModelRootType]):
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/models/#rootmodel-and-custom-root-types
A Pydantic `BaseModel` for the root object of the model.
@@ -60,7 +52,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
)
super().__init_subclass__(**kwargs)
def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
def __init__(__pydantic_self__, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
__tracebackhide__ = True
if data:
if root is not PydanticUndefined:
@@ -68,12 +60,12 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
'"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
)
root = data # type: ignore
self.__pydantic_validator__.validate_python(root, self_instance=self)
__pydantic_self__.__pydantic_validator__.validate_python(root, self_instance=__pydantic_self__)
__init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess]
__init__.__pydantic_base_init__ = True
@classmethod
def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self: # type: ignore
def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model:
"""Create a new model using the provided root object and update fields set.
Args:
@@ -98,7 +90,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
_object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
_object_setattr(self, '__dict__', state['__dict__'])
def __copy__(self) -> Self:
def __copy__(self: Model) -> Model:
"""Returns a shallow copy of the model."""
cls = type(self)
m = cls.__new__(cls)
@@ -106,7 +98,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
return m
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
"""Returns a deep copy of the model."""
cls = type(self)
m = cls.__new__(cls)
@@ -118,40 +110,30 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
if typing.TYPE_CHECKING:
def model_dump( # type: ignore
def model_dump(
self,
*,
mode: Literal['json', 'python'] | str = 'python',
include: Any = None,
exclude: Any = None,
context: dict[str, Any] | None = None,
by_alias: bool | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal['none', 'warn', 'error'] = True,
serialize_as_any: bool = False,
) -> Any:
warnings: bool = True,
) -> RootModelRootType:
"""This method is included just to get a more accurate return type for type checkers.
It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.
See the documentation of `BaseModel.model_dump` for more details about the arguments.
Generally, this method will have a return type of `RootModelRootType`, assuming that `RootModelRootType` is
not a `BaseModel` subclass. If `RootModelRootType` is a `BaseModel` subclass, then the return
type will likely be `dict[str, Any]`, as `model_dump` calls are recursive. The return type could
even be something different, in the case of a custom serializer.
Thus, `Any` is used here to catch all of these cases.
"""
...
def __eq__(self, other: Any) -> bool:
if not isinstance(other, RootModel):
return NotImplemented
return self.__pydantic_fields__['root'].annotation == other.__pydantic_fields__[
'root'
].annotation and super().__eq__(other)
return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)
def __repr_args__(self) -> _repr.ReprArgs:
yield 'root', self.root

View File

@@ -1,5 +1,4 @@
"""The `schema` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,5 +1,4 @@
"""The `tools` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,30 +1,93 @@
"""Type adapter specification."""
"""
You may have types that are not `BaseModel`s that you want to validate data against.
Or you may want to validate a `List[SomeModel]`, or dump it to JSON.
For use cases like this, Pydantic provides [`TypeAdapter`][pydantic.type_adapter.TypeAdapter],
which can be used for type validation, serialization, and JSON schema generation without creating a
[`BaseModel`][pydantic.main.BaseModel].
A [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] instance exposes some of the functionality from
[`BaseModel`][pydantic.main.BaseModel] instance methods for types that do not have such methods
(such as dataclasses, primitive types, and more):
```py
from typing import List
from typing_extensions import TypedDict
from pydantic import TypeAdapter, ValidationError
class User(TypedDict):
name: str
id: int
UserListValidator = TypeAdapter(List[User])
print(repr(UserListValidator.validate_python([{'name': 'Fred', 'id': '3'}])))
#> [{'name': 'Fred', 'id': 3}]
try:
UserListValidator.validate_python(
[{'name': 'Fred', 'id': 'wrong', 'other': 'no'}]
)
except ValidationError as e:
print(e)
'''
1 validation error for list[typed-dict]
0.id
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='wrong', input_type=str]
'''
```
Note:
Despite some overlap in use cases with [`RootModel`][pydantic.root_model.RootModel],
[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] should not be used as a type annotation for
specifying fields of a `BaseModel`, etc.
## Parsing data into a specified type
[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] can be used to apply the parsing logic to populate Pydantic models
in a more ad-hoc way. This function behaves similarly to
[`BaseModel.model_validate`][pydantic.main.BaseModel.model_validate],
but works with arbitrary Pydantic-compatible types.
This is especially useful when you want to parse results into a type that is not a direct subclass of
[`BaseModel`][pydantic.main.BaseModel]. For example:
```py
from typing import List
from pydantic import BaseModel, TypeAdapter
class Item(BaseModel):
id: int
name: str
# `item_data` could come from an API call, eg., via something like:
# item_data = requests.get('https://my-api.com/items').json()
item_data = [{'id': 1, 'name': 'My Item'}]
items = TypeAdapter(List[Item]).validate_python(item_data)
print(items)
#> [Item(id=1, name='My Item')]
```
[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] is capable of parsing data into any of the types Pydantic can
handle as fields of a [`BaseModel`][pydantic.main.BaseModel].
""" # noqa: D212
from __future__ import annotations as _annotations
import sys
from collections.abc import Callable, Iterable
from dataclasses import is_dataclass
from types import FrameType
from typing import (
Any,
Generic,
Literal,
TypeVar,
cast,
final,
overload,
)
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, overload
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
from typing_extensions import ParamSpec, is_typeddict
from typing_extensions import Literal, is_typeddict
from pydantic.errors import PydanticUserError
from pydantic.main import BaseModel, IncEx
from pydantic.main import BaseModel
from ._internal import _config, _generate_schema, _mock_val_ser, _namespace_utils, _repr, _typing_extra, _utils
from ._internal import _config, _core_utils, _discriminated_union, _generate_schema, _typing_extra
from .config import ConfigDict
from .errors import PydanticUndefinedAnnotation
from .json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
@@ -32,12 +95,67 @@ from .json_schema import (
JsonSchemaMode,
JsonSchemaValue,
)
from .plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
from .plugin._schema_validator import create_schema_validator
T = TypeVar('T')
R = TypeVar('R')
P = ParamSpec('P')
TypeAdapterT = TypeVar('TypeAdapterT', bound='TypeAdapter')
if TYPE_CHECKING:
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
def _get_schema(type_: Any, config_wrapper: _config.ConfigWrapper, parent_depth: int) -> CoreSchema:
"""`BaseModel` uses its own `__module__` to find out where it was defined
and then look for symbols to resolve forward references in those globals.
On the other hand this function can be called with arbitrary objects,
including type aliases where `__module__` (always `typing.py`) is not useful.
So instead we look at the globals in our parent stack frame.
This works for the case where this function is called in a module that
has the target of forward references in its scope, but
does not work for more complex cases.
For example, take the following:
a.py
```python
from typing import Dict, List
IntList = List[int]
OuterDict = Dict[str, 'IntList']
```
b.py
```python test="skip"
from a import OuterDict
from pydantic import TypeAdapter
IntList = int # replaces the symbol the forward reference is looking for
v = TypeAdapter(OuterDict)
v({'x': 1}) # should fail but doesn't
```
If OuterDict were a `BaseModel`, this would work because it would resolve
the forward reference within the `a.py` namespace.
But `TypeAdapter(OuterDict)`
can't know what module OuterDict came from.
In other words, the assumption that _all_ forward references exist in the
module we are being called from is not technically always true.
Although most of the time it is and it works fine for recursive models and such,
`BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
so there is no right or wrong between the two.
But at the very least this behavior is _subtly_ different from `BaseModel`'s.
"""
local_ns = _typing_extra.parent_frame_namespace(parent_depth=parent_depth)
global_ns = sys._getframe(max(parent_depth - 1, 1)).f_globals.copy()
global_ns.update(local_ns or {})
gen = _generate_schema.GenerateSchema(config_wrapper, types_namespace=global_ns, typevars_map={})
schema = gen.generate_schema(type_)
schema = gen.collect_definitions(schema)
return schema
def _getattr_no_parents(obj: Any, attribute: str) -> Any:
@@ -55,152 +173,59 @@ def _getattr_no_parents(obj: Any, attribute: str) -> Any:
raise AttributeError(attribute)
def _type_has_config(type_: Any) -> bool:
"""Returns whether the type has config."""
type_ = _typing_extra.annotated_type(type_) or type_
try:
return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
except TypeError:
# type is not a class
return False
@final
class TypeAdapter(Generic[T]):
"""!!! abstract "Usage Documentation"
[`TypeAdapter`](../concepts/type_adapter.md)
Type adapters provide a flexible way to perform validation and serialization based on a Python type.
"""Type adapters provide a flexible way to perform validation and serialization based on a Python type.
A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods
for types that do not have such methods (such as dataclasses, primitive types, and more).
**Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields.
Args:
type: The type associated with the `TypeAdapter`.
config: Configuration for the `TypeAdapter`, should be a dictionary conforming to
[`ConfigDict`][pydantic.config.ConfigDict].
!!! note
You cannot provide a configuration when instantiating a `TypeAdapter` if the type you're using
has its own config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A
[`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will
be raised in this case.
_parent_depth: Depth at which to search for the [parent frame][frame-objects]. This frame is used when
resolving forward annotations during schema building, by looking for the globals and locals of this
frame. Defaults to 2, which will result in the frame where the `TypeAdapter` was instantiated.
!!! note
This parameter is named with an underscore to suggest its private nature and discourage use.
It may be deprecated in a minor version, so we only recommend using it if you're comfortable
with potential change in behavior/support. It's default value is 2 because internally,
the `TypeAdapter` class makes another call to fetch the frame.
module: The module that passes to plugin if provided.
Note that `TypeAdapter` is not an actual type, so you cannot use it in type annotations.
Attributes:
core_schema: The core schema for the type.
validator: The schema validator for the type.
validator (SchemaValidator): The schema validator for the type.
serializer: The schema serializer for the type.
pydantic_complete: Whether the core schema for the type is successfully built.
??? tip "Compatibility with `mypy`"
Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly
annotate your variable:
```py
from typing import Union
from pydantic import TypeAdapter
ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int]) # type: ignore[arg-type]
```
??? info "Namespace management nuances and implementation details"
Here, we collect some notes on namespace management, and subtle differences from `BaseModel`:
`BaseModel` uses its own `__module__` to find out where it was defined
and then looks for symbols to resolve forward references in those globals.
On the other hand, `TypeAdapter` can be initialized with arbitrary objects,
which may not be types and thus do not have a `__module__` available.
So instead we look at the globals in our parent stack frame.
It is expected that the `ns_resolver` passed to this function will have the correct
namespace for the type we're adapting. See the source code for `TypeAdapter.__init__`
and `TypeAdapter.rebuild` for various ways to construct this namespace.
This works for the case where this function is called in a module that
has the target of forward references in its scope, but
does not always work for more complex cases.
For example, take the following:
```python {title="a.py"}
IntList = list[int]
OuterDict = dict[str, 'IntList']
```
```python {test="skip" title="b.py"}
from a import OuterDict
from pydantic import TypeAdapter
IntList = int # replaces the symbol the forward reference is looking for
v = TypeAdapter(OuterDict)
v({'x': 1}) # should fail but doesn't
```
If `OuterDict` were a `BaseModel`, this would work because it would resolve
the forward reference within the `a.py` namespace.
But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from.
In other words, the assumption that _all_ forward references exist in the
module we are being called from is not technically always true.
Although most of the time it is and it works fine for recursive models and such,
`BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
so there is no right or wrong between the two.
But at the very least this behavior is _subtly_ different from `BaseModel`'s.
"""
core_schema: CoreSchema
validator: SchemaValidator | PluggableSchemaValidator
serializer: SchemaSerializer
pydantic_complete: bool
if TYPE_CHECKING:
@overload
def __init__(
self,
type: type[T],
*,
config: ConfigDict | None = ...,
_parent_depth: int = ...,
module: str | None = ...,
) -> None: ...
@overload
def __new__(cls, __type: type[T], *, config: ConfigDict | None = ...) -> TypeAdapter[T]:
...
# This second overload is for unsupported special forms (such as Annotated, Union, etc.)
# Currently there is no way to type this correctly
# See https://github.com/python/typing/pull/1618
@overload
def __init__(
self,
type: Any,
*,
config: ConfigDict | None = ...,
_parent_depth: int = ...,
module: str | None = ...,
) -> None: ...
# this overload is for non-type things like Union[int, str]
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
# so an explicit type cast is needed
@overload
def __new__(cls, __type: T, *, config: ConfigDict | None = ...) -> TypeAdapter[T]:
...
def __init__(
self,
type: Any,
*,
config: ConfigDict | None = None,
_parent_depth: int = 2,
module: str | None = None,
) -> None:
if _type_has_config(type) and config is not None:
def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) -> TypeAdapter[T]:
"""A class representing the type adapter."""
raise NotImplementedError
@overload
def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
...
# this overload is for non-type things like Union[int, str]
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
# so an explicit type cast is needed
@overload
def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
...
def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
"""Initializes the TypeAdapter object."""
config_wrapper = _config.ConfigWrapper(config)
try:
type_has_config = issubclass(type, BaseModel) or is_dataclass(type) or is_typeddict(type)
except TypeError:
# type is not a class
type_has_config = False
if type_has_config and config is not None:
raise PydanticUserError(
'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.'
' These types can have their own config and setting the config via the `config`'
@@ -209,313 +234,81 @@ class TypeAdapter(Generic[T]):
code='type-adapter-config-unused',
)
self._type = type
self._config = config
self._parent_depth = _parent_depth
self.pydantic_complete = False
parent_frame = self._fetch_parent_frame()
if parent_frame is not None:
globalns = parent_frame.f_globals
# Do not provide a local ns if the type adapter happens to be instantiated at the module level:
localns = parent_frame.f_locals if parent_frame.f_locals is not globalns else {}
else:
globalns = {}
localns = {}
self._module_name = module or cast(str, globalns.get('__name__', ''))
self._init_core_attrs(
ns_resolver=_namespace_utils.NsResolver(
namespaces_tuple=_namespace_utils.NamespacesTuple(locals=localns, globals=globalns),
parent_namespace=localns,
),
force=False,
)
def _fetch_parent_frame(self) -> FrameType | None:
frame = sys._getframe(self._parent_depth)
if frame.f_globals.get('__name__') == 'typing':
# Because `TypeAdapter` is generic, explicitly parametrizing the class results
# in a `typing._GenericAlias` instance, which proxies instantiation calls to the
# "real" `TypeAdapter` class and thus adding an extra frame to the call. To avoid
# pulling anything from the `typing` module, use the correct frame (the one before):
return frame.f_back
return frame
def _init_core_attrs(
self, ns_resolver: _namespace_utils.NsResolver, force: bool, raise_errors: bool = False
) -> bool:
"""Initialize the core schema, validator, and serializer for the type.
Args:
ns_resolver: The namespace resolver to use when building the core schema for the adapted type.
force: Whether to force the construction of the core schema, validator, and serializer.
If `force` is set to `False` and `_defer_build` is `True`, the core schema, validator, and serializer will be set to mocks.
raise_errors: Whether to raise errors if initializing any of the core attrs fails.
Returns:
`True` if the core schema, validator, and serializer were successfully initialized, otherwise `False`.
Raises:
PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in`__get_pydantic_core_schema__`
and `raise_errors=True`.
"""
if not force and self._defer_build:
_mock_val_ser.set_type_adapter_mocks(self)
self.pydantic_complete = False
return False
core_schema: CoreSchema
try:
self.core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__')
self.validator = _getattr_no_parents(self._type, '__pydantic_validator__')
self.serializer = _getattr_no_parents(self._type, '__pydantic_serializer__')
# TODO: we don't go through the rebuild logic here directly because we don't want
# to repeat all of the namespace fetching logic that we've already done
# so we simply skip to the block below that does the actual schema generation
if (
isinstance(self.core_schema, _mock_val_ser.MockCoreSchema)
or isinstance(self.validator, _mock_val_ser.MockValSer)
or isinstance(self.serializer, _mock_val_ser.MockValSer)
):
raise AttributeError()
core_schema = _getattr_no_parents(type, '__pydantic_core_schema__')
except AttributeError:
config_wrapper = _config.ConfigWrapper(self._config)
core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1)
schema_generator = _generate_schema.GenerateSchema(config_wrapper, ns_resolver=ns_resolver)
core_schema = _discriminated_union.apply_discriminators(_core_utils.simplify_schema_references(core_schema))
try:
core_schema = schema_generator.generate_schema(self._type)
except PydanticUndefinedAnnotation:
if raise_errors:
raise
_mock_val_ser.set_type_adapter_mocks(self)
return False
core_schema = _core_utils.validate_core_schema(core_schema)
try:
self.core_schema = schema_generator.clean_schema(core_schema)
except _generate_schema.InvalidSchemaError:
_mock_val_ser.set_type_adapter_mocks(self)
return False
core_config = config_wrapper.core_config(None)
validator: SchemaValidator
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings)
core_config = config_wrapper.core_config(None)
serializer: SchemaSerializer
try:
serializer = _getattr_no_parents(type, '__pydantic_serializer__')
except AttributeError:
serializer = SchemaSerializer(core_schema, core_config)
self.validator = create_schema_validator(
schema=self.core_schema,
schema_type=self._type,
schema_type_module=self._module_name,
schema_type_name=str(self._type),
schema_kind='TypeAdapter',
config=core_config,
plugin_settings=config_wrapper.plugin_settings,
)
self.serializer = SchemaSerializer(self.core_schema, core_config)
self.pydantic_complete = True
return True
@property
def _defer_build(self) -> bool:
config = self._config if self._config is not None else self._model_config
if config:
return config.get('defer_build') is True
return False
@property
def _model_config(self) -> ConfigDict | None:
type_: Any = _typing_extra.annotated_type(self._type) or self._type # Eg FastAPI heavily uses Annotated
if _utils.lenient_issubclass(type_, BaseModel):
return type_.model_config
return getattr(type_, '__pydantic_config__', None)
def __repr__(self) -> str:
return f'TypeAdapter({_repr.display_as_type(self._type)})'
def rebuild(
self,
*,
force: bool = False,
raise_errors: bool = True,
_parent_namespace_depth: int = 2,
_types_namespace: _namespace_utils.MappingNamespace | None = None,
) -> bool | None:
"""Try to rebuild the pydantic-core schema for the adapter's type.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
Args:
force: Whether to force the rebuilding of the type adapter's schema, defaults to `False`.
raise_errors: Whether to raise errors, defaults to `True`.
_parent_namespace_depth: Depth at which to search for the [parent frame][frame-objects]. This
frame is used when resolving forward annotations during schema rebuilding, by looking for
the locals of this frame. Defaults to 2, which will result in the frame where the method
was called.
_types_namespace: An explicit types namespace to use, instead of using the local namespace
from the parent frame. Defaults to `None`.
Returns:
Returns `None` if the schema is already "complete" and rebuilding was not required.
If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
"""
if not force and self.pydantic_complete:
return None
if _types_namespace is not None:
rebuild_ns = _types_namespace
elif _parent_namespace_depth > 0:
rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
else:
rebuild_ns = {}
# we have to manually fetch globals here because there's no type on the stack of the NsResolver
# and so we skip the globalns = get_module_ns_of(typ) call that would normally happen
globalns = sys._getframe(max(_parent_namespace_depth - 1, 1)).f_globals
ns_resolver = _namespace_utils.NsResolver(
namespaces_tuple=_namespace_utils.NamespacesTuple(locals=rebuild_ns, globals=globalns),
parent_namespace=rebuild_ns,
)
return self._init_core_attrs(ns_resolver=ns_resolver, force=True, raise_errors=raise_errors)
self.core_schema = core_schema
self.validator = validator
self.serializer = serializer
def validate_python(
self,
object: Any,
/,
__object: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
by_alias: bool | None = None,
by_name: bool | None = None,
) -> T:
"""Validate a Python object against the model.
Args:
object: The Python object to validate against the model.
__object: The Python object to validate against the model.
strict: Whether to strictly check types.
from_attributes: Whether to extract data from object attributes.
context: Additional context to pass to the validator.
experimental_allow_partial: **Experimental** whether to enable
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
* False / 'off': Default behavior, no partial validation.
* True / 'on': Enable partial validation.
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
by_alias: Whether to use the field's alias when validating against the provided input data.
by_name: Whether to use the field's name when validating against the provided input data.
!!! note
When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes`
argument is not supported.
Returns:
The validated object.
"""
if by_alias is False and by_name is not True:
raise PydanticUserError(
'At least one of `by_alias` or `by_name` must be set to True.',
code='validate-by-alias-and-name-false',
)
return self.validator.validate_python(
object,
strict=strict,
from_attributes=from_attributes,
context=context,
allow_partial=experimental_allow_partial,
by_alias=by_alias,
by_name=by_name,
)
return self.validator.validate_python(__object, strict=strict, from_attributes=from_attributes, context=context)
def validate_json(
self,
data: str | bytes | bytearray,
/,
*,
strict: bool | None = None,
context: dict[str, Any] | None = None,
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
by_alias: bool | None = None,
by_name: bool | None = None,
self, __data: str | bytes, *, strict: bool | None = None, context: dict[str, Any] | None = None
) -> T:
"""!!! abstract "Usage Documentation"
[JSON Parsing](../concepts/json.md#json-parsing)
Validate a JSON string or bytes against the model.
"""Validate a JSON string or bytes against the model.
Args:
data: The JSON data to validate against the model.
__data: The JSON data to validate against the model.
strict: Whether to strictly check types.
context: Additional context to use during validation.
experimental_allow_partial: **Experimental** whether to enable
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
* False / 'off': Default behavior, no partial validation.
* True / 'on': Enable partial validation.
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
by_alias: Whether to use the field's alias when validating against the provided input data.
by_name: Whether to use the field's name when validating against the provided input data.
Returns:
The validated object.
"""
if by_alias is False and by_name is not True:
raise PydanticUserError(
'At least one of `by_alias` or `by_name` must be set to True.',
code='validate-by-alias-and-name-false',
)
return self.validator.validate_json(__data, strict=strict, context=context)
return self.validator.validate_json(
data,
strict=strict,
context=context,
allow_partial=experimental_allow_partial,
by_alias=by_alias,
by_name=by_name,
)
def validate_strings(
self,
obj: Any,
/,
*,
strict: bool | None = None,
context: dict[str, Any] | None = None,
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
by_alias: bool | None = None,
by_name: bool | None = None,
) -> T:
def validate_strings(self, __obj: Any, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T:
"""Validate object contains string data against the model.
Args:
obj: The object contains string data to validate.
__obj: The object contains string data to validate.
strict: Whether to strictly check types.
context: Additional context to use during validation.
experimental_allow_partial: **Experimental** whether to enable
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
* False / 'off': Default behavior, no partial validation.
* True / 'on': Enable partial validation.
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
by_alias: Whether to use the field's alias when validating against the provided input data.
by_name: Whether to use the field's name when validating against the provided input data.
Returns:
The validated object.
"""
if by_alias is False and by_name is not True:
raise PydanticUserError(
'At least one of `by_alias` or `by_name` must be set to True.',
code='validate-by-alias-and-name-false',
)
return self.validator.validate_strings(
obj,
strict=strict,
context=context,
allow_partial=experimental_allow_partial,
by_alias=by_alias,
by_name=by_name,
)
return self.validator.validate_strings(__obj, strict=strict, context=context)
def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None:
"""Get the default value for the wrapped type.
@@ -531,26 +324,22 @@ class TypeAdapter(Generic[T]):
def dump_python(
self,
instance: T,
/,
__instance: T,
*,
mode: Literal['json', 'python'] = 'python',
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal['none', 'warn', 'error'] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
context: dict[str, Any] | None = None,
warnings: bool = True,
) -> Any:
"""Dump an instance of the adapted type to a Python object.
Args:
instance: The Python object to serialize.
__instance: The Python object to serialize.
mode: The output format.
include: Fields to include in the output.
exclude: Fields to exclude from the output.
@@ -559,18 +348,13 @@ class TypeAdapter(Generic[T]):
exclude_defaults: Whether to exclude fields with default values.
exclude_none: Whether to exclude fields with None values.
round_trip: Whether to output the serialized data in a way that is compatible with deserialization.
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
fallback: A function to call when an unknown value is encountered. If not provided,
a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
context: Additional context to pass to the serializer.
warnings: Whether to display serialization warnings.
Returns:
The serialized object.
"""
return self.serializer.to_python(
instance,
__instance,
mode=mode,
by_alias=by_alias,
include=include,
@@ -580,36 +364,26 @@ class TypeAdapter(Generic[T]):
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
fallback=fallback,
serialize_as_any=serialize_as_any,
context=context,
)
def dump_json(
self,
instance: T,
/,
__instance: T,
*,
indent: int | None = None,
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool | Literal['none', 'warn', 'error'] = True,
fallback: Callable[[Any], Any] | None = None,
serialize_as_any: bool = False,
context: dict[str, Any] | None = None,
warnings: bool = True,
) -> bytes:
"""!!! abstract "Usage Documentation"
[JSON Serialization](../concepts/json.md#json-serialization)
Serialize an instance of the adapted type to JSON.
"""Serialize an instance of the adapted type to JSON.
Args:
instance: The instance to be serialized.
__instance: The instance to be serialized.
indent: Number of spaces for JSON indentation.
include: Fields to include.
exclude: Fields to exclude.
@@ -618,18 +392,13 @@ class TypeAdapter(Generic[T]):
exclude_defaults: Whether to exclude fields with default values.
exclude_none: Whether to exclude fields with a value of `None`.
round_trip: Whether to serialize and deserialize the instance to ensure round-tripping.
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
fallback: A function to call when an unknown value is encountered. If not provided,
a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
context: Additional context to pass to the serializer.
warnings: Whether to emit serialization warnings.
Returns:
The JSON representation of the given instance as bytes.
"""
return self.serializer.to_json(
instance,
__instance,
indent=indent,
include=include,
exclude=exclude,
@@ -639,9 +408,6 @@ class TypeAdapter(Generic[T]):
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
fallback=fallback,
serialize_as_any=serialize_as_any,
context=context,
)
def json_schema(
@@ -664,15 +430,11 @@ class TypeAdapter(Generic[T]):
The JSON schema for the model as a dictionary.
"""
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
if isinstance(self.core_schema, _mock_val_ser.MockCoreSchema):
self.core_schema.rebuild()
assert not isinstance(self.core_schema, _mock_val_ser.MockCoreSchema), 'this is a bug! please report it'
return schema_generator_instance.generate(self.core_schema, mode=mode)
@staticmethod
def json_schemas(
inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
/,
__inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
*,
by_alias: bool = True,
title: str | None = None,
@@ -683,7 +445,7 @@ class TypeAdapter(Generic[T]):
"""Generate a JSON schema including definitions from multiple type adapters.
Args:
inputs: Inputs to schema generation. The first two items will form the keys of the (first)
__inputs: Inputs to schema generation. The first two items will form the keys of the (first)
output mapping; the type adapters will provide the core schemas that get converted into
definitions in the output JSON schema.
by_alias: Whether to use alias names.
@@ -704,17 +466,9 @@ class TypeAdapter(Generic[T]):
"""
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
inputs_ = []
for key, mode, adapter in inputs:
# This is the same pattern we follow for model json schemas - we attempt a core schema rebuild if we detect a mock
if isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema):
adapter.core_schema.rebuild()
assert not isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema), (
'this is a bug! please report it'
)
inputs_.append((key, mode, adapter.core_schema))
inputs = [(key, mode, adapter.core_schema) for key, mode, adapter in __inputs]
json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs_)
json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs)
json_schema: dict[str, Any] = {}
if definitions:

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,4 @@
"""`typing` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,5 +1,4 @@
"""The `utils` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)

View File

@@ -1,24 +1,24 @@
# flake8: noqa
from pydantic.v1 import dataclasses
from pydantic.v1.annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from pydantic.v1.class_validators import root_validator, validator
from pydantic.v1.config import BaseConfig, ConfigDict, Extra
from pydantic.v1.decorator import validate_arguments
from pydantic.v1.env_settings import BaseSettings
from pydantic.v1.error_wrappers import ValidationError
from pydantic.v1.errors import *
from pydantic.v1.fields import Field, PrivateAttr, Required
from pydantic.v1.main import *
from pydantic.v1.networks import *
from pydantic.v1.parse import Protocol
from pydantic.v1.tools import *
from pydantic.v1.types import *
from pydantic.v1.version import VERSION, compiled
from . import dataclasses
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from .class_validators import root_validator, validator
from .config import BaseConfig, ConfigDict, Extra
from .decorator import validate_arguments
from .env_settings import BaseSettings
from .error_wrappers import ValidationError
from .errors import *
from .fields import Field, PrivateAttr, Required
from .main import *
from .networks import *
from .parse import Protocol
from .tools import *
from .types import *
from .version import VERSION, compiled
__version__ = VERSION
# WARNING __all__ from pydantic.errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.v1.errors import ..." instead
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.errors import ..." instead
__all__ = [
# annotated types utils
'create_model_from_namedtuple',

View File

@@ -35,7 +35,7 @@ import hypothesis.strategies as st
import pydantic
import pydantic.color
import pydantic.types
from pydantic.v1.utils import lenient_issubclass
from pydantic.utils import lenient_issubclass
# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
# them on-disk, and that's unsafe in general without being told *where* to do so.

View File

@@ -1,9 +1,9 @@
import sys
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type
from pydantic.v1.fields import Required
from pydantic.v1.main import BaseModel, create_model
from pydantic.v1.typing import is_typeddict, is_typeddict_special
from .fields import Required
from .main import BaseModel, create_model
from .typing import is_typeddict, is_typeddict_special
if TYPE_CHECKING:
from typing_extensions import TypedDict

View File

@@ -5,12 +5,12 @@ from itertools import chain
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
from pydantic.v1.errors import ConfigError
from pydantic.v1.typing import AnyCallable
from pydantic.v1.utils import ROOT_KEY, in_ipython
from .errors import ConfigError
from .typing import AnyCallable
from .utils import ROOT_KEY, in_ipython
if TYPE_CHECKING:
from pydantic.v1.typing import AnyClassMethod
from .typing import AnyClassMethod
class Validator:
@@ -36,9 +36,9 @@ class Validator:
if TYPE_CHECKING:
from inspect import Signature
from pydantic.v1.config import BaseConfig
from pydantic.v1.fields import ModelField
from pydantic.v1.types import ModelOrDc
from .config import BaseConfig
from .fields import ModelField
from .types import ModelOrDc
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
ValidatorsList = List[ValidatorCallable]

View File

@@ -12,11 +12,11 @@ import re
from colorsys import hls_to_rgb, rgb_to_hls
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
from pydantic.v1.errors import ColorError
from pydantic.v1.utils import Representation, almost_equal_floats
from .errors import ColorError
from .utils import Representation, almost_equal_floats
if TYPE_CHECKING:
from pydantic.v1.typing import CallableGenerator, ReprArgs
from .typing import CallableGenerator, ReprArgs
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
ColorType = Union[ColorTuple, str]

View File

@@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tup
from typing_extensions import Literal, Protocol
from pydantic.v1.typing import AnyArgTCallable, AnyCallable
from pydantic.v1.utils import GetterDict
from pydantic.v1.version import compiled
from .typing import AnyArgTCallable, AnyCallable
from .utils import GetterDict
from .version import compiled
if TYPE_CHECKING:
from typing import overload
from pydantic.v1.fields import ModelField
from pydantic.v1.main import BaseModel
from .fields import ModelField
from .main import BaseModel
ConfigType = Type['BaseConfig']

View File

@@ -36,28 +36,21 @@ import dataclasses
import sys
from contextlib import contextmanager
from functools import wraps
try:
from functools import cached_property
except ImportError:
# cached_property available only for python3.8+
pass
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
from typing_extensions import dataclass_transform
from pydantic.v1.class_validators import gather_all_validators
from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
from pydantic.v1.error_wrappers import ValidationError
from pydantic.v1.errors import DataclassTypeError
from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
from pydantic.v1.main import create_model, validate_model
from pydantic.v1.utils import ClassAttribute
from .class_validators import gather_all_validators
from .config import BaseConfig, ConfigDict, Extra, get_config
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .utils import ClassAttribute
if TYPE_CHECKING:
from pydantic.v1.main import BaseModel
from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable
from .main import BaseModel
from .typing import CallableGenerator, NoArgAnyCallable
DataclassT = TypeVar('DataclassT', bound='Dataclass')
@@ -416,17 +409,6 @@ def create_pydantic_model_from_dataclass(
return model
if sys.version_info >= (3, 8):
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
return isinstance(getattr(type(obj), k, None), cached_property)
else:
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
return False
def _dataclass_validate_values(self: 'Dataclass') -> None:
# validation errors can occur if this function is called twice on an already initialised dataclass.
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
@@ -435,13 +417,9 @@ def _dataclass_validate_values(self: 'Dataclass') -> None:
if getattr(self, '__pydantic_has_field_info_default__', False):
# We need to remove `FieldInfo` values since they are not valid as input
# It's ok to do that because they are obviously the default values!
input_data = {
k: v
for k, v in self.__dict__.items()
if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
}
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
else:
input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
input_data = self.__dict__
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error

View File

@@ -18,7 +18,7 @@ import re
from datetime import date, datetime, time, timedelta, timezone
from typing import Dict, Optional, Type, Union
from pydantic.v1 import errors
from . import errors
date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
time_expr = (

View File

@@ -1,17 +1,17 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
from pydantic.v1 import validator
from pydantic.v1.config import Extra
from pydantic.v1.errors import ConfigError
from pydantic.v1.main import BaseModel, create_model
from pydantic.v1.typing import get_all_type_hints
from pydantic.v1.utils import to_camel
from . import validator
from .config import Extra
from .errors import ConfigError
from .main import BaseModel, create_model
from .typing import get_all_type_hints
from .utils import to_camel
__all__ = ('validate_arguments',)
if TYPE_CHECKING:
from pydantic.v1.typing import AnyCallable
from .typing import AnyCallable
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
ConfigType = Union[None, Type[Any], Dict[str, Any]]

View File

@@ -3,12 +3,12 @@ import warnings
from pathlib import Path
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
from pydantic.v1.config import BaseConfig, Extra
from pydantic.v1.fields import ModelField
from pydantic.v1.main import BaseModel
from pydantic.v1.types import JsonWrapper
from pydantic.v1.typing import StrPath, display_as_type, get_origin, is_union
from pydantic.v1.utils import deep_update, lenient_issubclass, path_type, sequence_like
from .config import BaseConfig, Extra
from .fields import ModelField
from .main import BaseModel
from .types import JsonWrapper
from .typing import StrPath, display_as_type, get_origin, is_union
from .utils import deep_update, lenient_issubclass, path_type, sequence_like
env_file_sentinel = str(object())

View File

@@ -1,15 +1,15 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union
from pydantic.v1.json import pydantic_encoder
from pydantic.v1.utils import Representation
from .json import pydantic_encoder
from .utils import Representation
if TYPE_CHECKING:
from typing_extensions import TypedDict
from pydantic.v1.config import BaseConfig
from pydantic.v1.types import ModelOrDc
from pydantic.v1.typing import ReprArgs
from .config import BaseConfig
from .types import ModelOrDc
from .typing import ReprArgs
Loc = Tuple[Union[int, str], ...]
@@ -101,6 +101,7 @@ def flatten_errors(
) -> Generator['ErrorDict', None, None]:
for error in errors:
if isinstance(error, ErrorWrapper):
if loc:
error_loc = loc + error.loc_tuple()
else:

View File

@@ -2,12 +2,12 @@ from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union
from pydantic.v1.typing import display_as_type
from .typing import display_as_type
if TYPE_CHECKING:
from pydantic.v1.typing import DictStrAny
from .typing import DictStrAny
# explicitly state exports to avoid "from pydantic.v1.errors import *" also importing Decimal, Path etc.
# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc.
__all__ = (
'PydanticTypeError',
'PydanticValueError',

View File

@@ -28,12 +28,12 @@ from typing import (
from typing_extensions import Annotated, Final
from pydantic.v1 import errors as errors_
from pydantic.v1.class_validators import Validator, make_generic_validator, prep_validators
from pydantic.v1.error_wrappers import ErrorWrapper
from pydantic.v1.errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError
from pydantic.v1.types import Json, JsonWrapper
from pydantic.v1.typing import (
from . import errors as errors_
from .class_validators import Validator, make_generic_validator, prep_validators
from .error_wrappers import ErrorWrapper
from .errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError
from .types import Json, JsonWrapper
from .typing import (
NoArgAnyCallable,
convert_generics,
display_as_type,
@@ -48,7 +48,7 @@ from pydantic.v1.typing import (
is_union,
new_type_supertype,
)
from pydantic.v1.utils import (
from .utils import (
PyObjectStr,
Representation,
ValueItems,
@@ -59,7 +59,7 @@ from pydantic.v1.utils import (
sequence_like,
smart_deepcopy,
)
from pydantic.v1.validators import constant_validator, dict_validator, find_validators, validate_json
from .validators import constant_validator, dict_validator, find_validators, validate_json
Required: Any = Ellipsis
@@ -83,11 +83,11 @@ class UndefinedType:
Undefined = UndefinedType()
if TYPE_CHECKING:
from pydantic.v1.class_validators import ValidatorsList
from pydantic.v1.config import BaseConfig
from pydantic.v1.error_wrappers import ErrorList
from pydantic.v1.types import ModelOrDc
from pydantic.v1.typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs
from .class_validators import ValidatorsList
from .config import BaseConfig
from .error_wrappers import ErrorList
from .types import ModelOrDc
from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs
ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]]
LocStr = Union[Tuple[Union[int, str], ...], str]
@@ -178,6 +178,7 @@ class FieldInfo(Representation):
self.extra = kwargs
def __repr_args__(self) -> 'ReprArgs':
field_defaults_to_hide: Dict[str, Any] = {
'repr': True,
**self.__field_constraints__,
@@ -404,6 +405,7 @@ class ModelField(Representation):
alias: Optional[str] = None,
field_info: Optional[FieldInfo] = None,
) -> None:
self.name: str = name
self.has_alias: bool = alias is not None
self.alias: str = alias if alias is not None else name
@@ -490,7 +492,7 @@ class ModelField(Representation):
class_validators: Optional[Dict[str, Validator]],
config: Type['BaseConfig'],
) -> 'ModelField':
from pydantic.v1.schema import get_annotation_from_field_info
from .schema import get_annotation_from_field_info
field_info, value = cls._get_field_info(name, annotation, value, config)
required: 'BoolUndefined' = Undefined
@@ -850,6 +852,7 @@ class ModelField(Representation):
def validate(
self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None
) -> 'ValidateReturn':
assert self.type_.__class__ is not DeferredType
if self.type_.__class__ is ForwardRef:
@@ -1160,7 +1163,7 @@ class ModelField(Representation):
"""
Whether the field is "complex" eg. env variables should be parsed as JSON.
"""
from pydantic.v1.main import BaseModel
from .main import BaseModel
return (
self.shape != SHAPE_SINGLETON

View File

@@ -22,12 +22,12 @@ from weakref import WeakKeyDictionary, WeakValueDictionary
from typing_extensions import Annotated, Literal as ExtLiteral
from pydantic.v1.class_validators import gather_all_validators
from pydantic.v1.fields import DeferredType
from pydantic.v1.main import BaseModel, create_model
from pydantic.v1.types import JsonWrapper
from pydantic.v1.typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from pydantic.v1.utils import all_identical, lenient_issubclass
from .class_validators import gather_all_validators
from .fields import DeferredType
from .main import BaseModel, create_model
from .types import JsonWrapper
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import all_identical, lenient_issubclass
if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias

View File

@@ -9,9 +9,9 @@ from types import GeneratorType
from typing import Any, Callable, Dict, Type, Union
from uuid import UUID
from pydantic.v1.color import Color
from pydantic.v1.networks import NameEmail
from pydantic.v1.types import SecretBytes, SecretStr
from .color import Color
from .networks import NameEmail
from .types import SecretBytes, SecretStr
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
@@ -72,7 +72,7 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
def pydantic_encoder(obj: Any) -> Any:
from dataclasses import asdict, is_dataclass
from pydantic.v1.main import BaseModel
from .main import BaseModel
if isinstance(obj, BaseModel):
return obj.dict()

View File

@@ -26,11 +26,11 @@ from typing import (
from typing_extensions import dataclass_transform
from pydantic.v1.class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
from pydantic.v1.config import BaseConfig, Extra, inherit_config, prepare_config
from pydantic.v1.error_wrappers import ErrorWrapper, ValidationError
from pydantic.v1.errors import ConfigError, DictError, ExtraError, MissingError
from pydantic.v1.fields import (
from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
from .config import BaseConfig, Extra, inherit_config, prepare_config
from .error_wrappers import ErrorWrapper, ValidationError
from .errors import ConfigError, DictError, ExtraError, MissingError
from .fields import (
MAPPING_LIKE_SHAPES,
Field,
ModelField,
@@ -39,11 +39,11 @@ from pydantic.v1.fields import (
Undefined,
is_finalvar_with_default_val,
)
from pydantic.v1.json import custom_pydantic_encoder, pydantic_encoder
from pydantic.v1.parse import Protocol, load_file, load_str_bytes
from pydantic.v1.schema import default_ref_template, model_schema
from pydantic.v1.types import PyObject, StrBytes
from pydantic.v1.typing import (
from .json import custom_pydantic_encoder, pydantic_encoder
from .parse import Protocol, load_file, load_str_bytes
from .schema import default_ref_template, model_schema
from .types import PyObject, StrBytes
from .typing import (
AnyCallable,
get_args,
get_origin,
@@ -53,7 +53,7 @@ from pydantic.v1.typing import (
resolve_annotations,
update_model_forward_refs,
)
from pydantic.v1.utils import (
from .utils import (
DUNDER_ATTRIBUTES,
ROOT_KEY,
ClassAttribute,
@@ -73,9 +73,9 @@ from pydantic.v1.utils import (
if TYPE_CHECKING:
from inspect import Signature
from pydantic.v1.class_validators import ValidatorListDict
from pydantic.v1.types import ModelOrDc
from pydantic.v1.typing import (
from .class_validators import ValidatorListDict
from .types import ModelOrDc
from .typing import (
AbstractSetIntStr,
AnyClassMethod,
CallableGenerator,
@@ -282,12 +282,6 @@ class ModelMetaclass(ABCMeta):
cls = super().__new__(mcs, name, bases, new_namespace, **kwargs)
# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config))
if not _is_base_model_class_defined:
# Cython does not understand the `if TYPE_CHECKING:` condition in the
# BaseModel's body (where annotations are set), so clear them manually:
getattr(cls, '__annotations__', {}).clear()
if resolve_forward_refs:
cls.__try_update_forward_refs__()
@@ -307,7 +301,7 @@ class ModelMetaclass(ABCMeta):
See #3829 and python/cpython#92810
"""
return hasattr(instance, '__post_root_validators__') and super().__instancecheck__(instance)
return hasattr(instance, '__fields__') and super().__instancecheck__(instance)
object_setattr = object.__setattr__
@@ -675,7 +669,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
def schema_json(
cls, *, by_alias: bool = True, ref_template: str = default_ref_template, **dumps_kwargs: Any
) -> str:
from pydantic.v1.json import pydantic_encoder
from .json import pydantic_encoder
return cls.__config__.json_dumps(
cls.schema(by_alias=by_alias, ref_template=ref_template), default=pydantic_encoder, **dumps_kwargs
@@ -743,6 +737,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
exclude_defaults: bool,
exclude_none: bool,
) -> Any:
if isinstance(v, BaseModel):
if to_dict:
v_dict = v.dict(
@@ -835,6 +830,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> 'TupleGenerator':
# Merge field set excludes with explicit exclude parameter with explicit overriding field set options.
# The extra "is not None" guards are not logically necessary but optimizes performance for the simple case.
if exclude is not None or self.__exclude_fields__ is not None:

View File

@@ -57,7 +57,6 @@ from mypy.types import (
Type,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarType,
UnionType,
get_proper_type,
@@ -66,7 +65,7 @@ from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.version import __version__ as mypy_version
from pydantic.v1.utils import is_valid_field
from pydantic.utils import is_valid_field
try:
from mypy.types import TypeVarDef # type: ignore[attr-defined]
@@ -499,11 +498,7 @@ class PydanticModelTransformer:
tvd = TypeVarType(
self_tvar_name,
tvar_fullname,
(
TypeVarId(-1, namespace=ctx.cls.fullname + '.construct')
if MYPY_VERSION_TUPLE >= (1, 11)
else TypeVarId(-1)
),
-1,
[],
obj_type,
AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type]
@@ -863,9 +858,9 @@ def add_method(
arg_kinds.append(arg.kind)
function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
signature = CallableType(
arg_types, arg_kinds, arg_names, return_type, function_type, variables=[tvar_def] if tvar_def else None
)
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
func = FuncDef(name, args, Block([PassStmt()]))
func.info = info

View File

@@ -27,17 +27,17 @@ from typing import (
no_type_check,
)
from pydantic.v1 import errors
from pydantic.v1.utils import Representation, update_not_none
from pydantic.v1.validators import constr_length_validator, str_validator
from . import errors
from .utils import Representation, update_not_none
from .validators import constr_length_validator, str_validator
if TYPE_CHECKING:
import email_validator
from typing_extensions import TypedDict
from pydantic.v1.config import BaseConfig
from pydantic.v1.fields import ModelField
from pydantic.v1.typing import AnyCallable
from .config import BaseConfig
from .fields import ModelField
from .typing import AnyCallable
CallableGenerator = Generator[AnyCallable, None, None]

View File

@@ -4,7 +4,7 @@ from enum import Enum
from pathlib import Path
from typing import Any, Callable, Union
from pydantic.v1.types import StrBytes
from .types import StrBytes
class Protocol(str, Enum):

View File

@@ -31,7 +31,7 @@ from uuid import UUID
from typing_extensions import Annotated, Literal
from pydantic.v1.fields import (
from .fields import (
MAPPING_LIKE_SHAPES,
SHAPE_DEQUE,
SHAPE_FROZENSET,
@@ -46,9 +46,9 @@ from pydantic.v1.fields import (
FieldInfo,
ModelField,
)
from pydantic.v1.json import pydantic_encoder
from pydantic.v1.networks import AnyUrl, EmailStr
from pydantic.v1.types import (
from .json import pydantic_encoder
from .networks import AnyUrl, EmailStr
from .types import (
ConstrainedDecimal,
ConstrainedFloat,
ConstrainedFrozenSet,
@@ -69,7 +69,7 @@ from pydantic.v1.types import (
conset,
constr,
)
from pydantic.v1.typing import (
from .typing import (
all_literal_values,
get_args,
get_origin,
@@ -80,11 +80,11 @@ from pydantic.v1.typing import (
is_none_type,
is_union,
)
from pydantic.v1.utils import ROOT_KEY, get_model, lenient_issubclass
from .utils import ROOT_KEY, get_model, lenient_issubclass
if TYPE_CHECKING:
from pydantic.v1.dataclasses import Dataclass
from pydantic.v1.main import BaseModel
from .dataclasses import Dataclass
from .main import BaseModel
default_prefix = '#/definitions/'
default_ref_template = '#/definitions/{model}'
@@ -198,6 +198,7 @@ def model_schema(
def get_field_info_schema(field: ModelField, schema_overrides: bool = False) -> Tuple[Dict[str, Any], bool]:
# If no title is explicitly set, we don't set title in the schema for enums.
# The behaviour is the same as `BaseModel` reference, where the default title
# is in the definitions part of the schema.
@@ -378,7 +379,7 @@ def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) ->
:param known_models: used to solve circular references
:return: a set with the model used in the declaration for this field, if any, and all its sub-models
"""
from pydantic.v1.main import BaseModel
from .main import BaseModel
flat_models: TypeModelSet = set()
@@ -445,7 +446,7 @@ def field_type_schema(
Take a single ``field`` and generate the schema for its type only, not including additional
information as title, etc. Also return additional schema definitions, from sub-models.
"""
from pydantic.v1.main import BaseModel # noqa: F811
from .main import BaseModel # noqa: F811
definitions = {}
nested_models: Set[str] = set()
@@ -738,7 +739,7 @@ def field_singleton_sub_fields_schema(
discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref']
s['discriminator'] = {
'propertyName': field.discriminator_alias if by_alias else field.discriminator_key,
'propertyName': field.discriminator_alias,
'mapping': discriminator_models_refs,
}
@@ -838,7 +839,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models.
"""
from pydantic.v1.main import BaseModel
from .main import BaseModel
definitions: Dict[str, Any] = {}
nested_models: Set[str] = set()
@@ -974,7 +975,7 @@ def multitypes_literal_field_for_schema(values: Tuple[Any, ...], field: ModelFie
def encode_default(dft: Any) -> Any:
from pydantic.v1.main import BaseModel
from .main import BaseModel
if isinstance(dft, BaseModel) or is_dataclass(dft):
dft = cast('dict[str, Any]', pydantic_encoder(dft))
@@ -1090,7 +1091,7 @@ def get_annotation_with_constraints(annotation: Any, field_info: FieldInfo) -> T
if issubclass(type_, (SecretStr, SecretBytes)):
attrs = ('max_length', 'min_length')
def constraint_func(**kw: Any) -> Type[Any]: # noqa: F811
def constraint_func(**kw: Any) -> Type[Any]:
return type(type_.__name__, (type_,), kw)
elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl)):

View File

@@ -3,16 +3,16 @@ from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
from pydantic.v1.parse import Protocol, load_file, load_str_bytes
from pydantic.v1.types import StrBytes
from pydantic.v1.typing import display_as_type
from .parse import Protocol, load_file, load_str_bytes
from .types import StrBytes
from .typing import display_as_type
__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')
NameFactory = Union[str, Callable[[Type[Any]], str]]
if TYPE_CHECKING:
from pydantic.v1.typing import DictStrAny
from .typing import DictStrAny
def _generate_parsing_type_name(type_: Any) -> str:
@@ -21,7 +21,7 @@ def _generate_parsing_type_name(type_: Any) -> str:
@lru_cache(maxsize=2048)
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
from pydantic.v1.main import create_model
from .main import create_model
if type_name is None:
type_name = _generate_parsing_type_name

View File

@@ -28,10 +28,10 @@ from typing import (
from uuid import UUID
from weakref import WeakSet
from pydantic.v1 import errors
from pydantic.v1.datetime_parse import parse_date
from pydantic.v1.utils import import_string, update_not_none
from pydantic.v1.validators import (
from . import errors
from .datetime_parse import parse_date
from .utils import import_string, update_not_none
from .validators import (
bytes_validator,
constr_length_validator,
constr_lower,
@@ -123,9 +123,9 @@ StrIntFloat = Union[str, int, float]
if TYPE_CHECKING:
from typing_extensions import Annotated
from pydantic.v1.dataclasses import Dataclass
from pydantic.v1.main import BaseModel
from pydantic.v1.typing import CallableGenerator
from .dataclasses import Dataclass
from .main import BaseModel
from .typing import CallableGenerator
ModelOrDc = Type[Union[BaseModel, Dataclass]]
@@ -481,7 +481,6 @@ else:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This types superclass should be Set[T], but cython chokes on that...
class ConstrainedSet(set): # type: ignore
# Needed for pydantic to detect that this is a set
@@ -570,7 +569,6 @@ def confrozenset(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LIST TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This types superclass should be List[T], but cython chokes on that...
class ConstrainedList(list): # type: ignore
# Needed for pydantic to detect that this is a list
@@ -1096,6 +1094,7 @@ class ByteSize(int):
@classmethod
def validate(cls, v: StrIntFloat) -> 'ByteSize':
try:
return cls(int(v))
except ValueError:
@@ -1117,6 +1116,7 @@ class ByteSize(int):
return cls(int(float(scalar) * unit_mult))
def human_readable(self, decimal: bool = False) -> str:
if decimal:
divisor = 1000
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
@@ -1135,6 +1135,7 @@ class ByteSize(int):
return f'{num:0.1f}{final_unit}'
def to(self, unit: str) -> float:
try:
unit_div = BYTE_SIZES[unit.lower()]
except KeyError:

View File

@@ -58,21 +58,12 @@ if sys.version_info < (3, 9):
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return type_._evaluate(globalns, localns)
elif sys.version_info < (3, 12, 4):
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Even though it is the right signature for python 3.9, mypy complains with
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
# warnings:
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
return cast(Any, type_)._evaluate(globalns, localns, set())
if sys.version_info < (3, 9):
@@ -265,7 +256,7 @@ StrPath = Union[str, PathLike]
if TYPE_CHECKING:
from pydantic.v1.fields import ModelField
from .fields import ModelField
TupleGenerator = Generator[Tuple[str, Any], None, None]
DictStrAny = Dict[str, Any]
@@ -406,10 +397,7 @@ def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Opti
else:
value = ForwardRef(value, is_argument=False)
try:
if sys.version_info >= (3, 13):
value = _eval_type(value, base_globals, None, type_params=())
else:
value = _eval_type(value, base_globals, None)
value = _eval_type(value, base_globals, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
@@ -447,7 +435,7 @@ def is_namedtuple(type_: Type[Any]) -> bool:
Check if a given class is a named tuple.
It can be either a `typing.NamedTuple` or `collections.namedtuple`
"""
from pydantic.v1.utils import lenient_issubclass
from .utils import lenient_issubclass
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
@@ -457,7 +445,7 @@ def is_typeddict(type_: Type[Any]) -> bool:
Check if a given class is a typed dict (from `typing` or `typing_extensions`)
In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
"""
from pydantic.v1.utils import lenient_issubclass
from .utils import lenient_issubclass
return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')

View File

@@ -28,8 +28,8 @@ from typing import (
from typing_extensions import Annotated
from pydantic.v1.errors import ConfigError
from pydantic.v1.typing import (
from .errors import ConfigError
from .typing import (
NoneType,
WithArgsTypes,
all_literal_values,
@@ -39,17 +39,17 @@ from pydantic.v1.typing import (
is_literal_type,
is_union,
)
from pydantic.v1.version import version_info
from .version import version_info
if TYPE_CHECKING:
from inspect import Signature
from pathlib import Path
from pydantic.v1.config import BaseConfig
from pydantic.v1.dataclasses import Dataclass
from pydantic.v1.fields import ModelField
from pydantic.v1.main import BaseModel
from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
from .config import BaseConfig
from .dataclasses import Dataclass
from .fields import ModelField
from .main import BaseModel
from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
@@ -66,7 +66,6 @@ __all__ = (
'almost_equal_floats',
'get_model',
'to_camel',
'to_lower_camel',
'is_valid_field',
'smart_deepcopy',
'PyObjectStr',
@@ -159,7 +158,7 @@ def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
def validate_field_name(bases: Iterable[Type[Any]], field_name: str) -> None:
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
"""
Ensure that the field's name does not shadow an existing attribute of the model.
"""
@@ -241,7 +240,7 @@ def generate_model_signature(
"""
from inspect import Parameter, Signature, signature
from pydantic.v1.config import Extra
from .config import Extra
present_params = signature(init).parameters.values()
merged_params: Dict[str, Parameter] = {}
@@ -299,7 +298,7 @@ def generate_model_signature(
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
from pydantic.v1.main import BaseModel
from .main import BaseModel
try:
model_cls = obj.__pydantic_model__ # type: ignore
@@ -708,8 +707,6 @@ DUNDER_ATTRIBUTES = {
'__orig_bases__',
'__orig_class__',
'__qualname__',
'__firstlineno__',
'__static_attributes__',
}

View File

@@ -27,11 +27,10 @@ from typing import (
Union,
)
from uuid import UUID
from warnings import warn
from pydantic.v1 import errors
from pydantic.v1.datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from pydantic.v1.typing import (
from . import errors
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from .typing import (
AnyCallable,
all_literal_values,
display_as_type,
@@ -42,14 +41,14 @@ from pydantic.v1.typing import (
is_none_type,
is_typeddict,
)
from pydantic.v1.utils import almost_equal_floats, lenient_issubclass, sequence_like
from .utils import almost_equal_floats, lenient_issubclass, sequence_like
if TYPE_CHECKING:
from typing_extensions import Literal, TypedDict
from pydantic.v1.config import BaseConfig
from pydantic.v1.fields import ModelField
from pydantic.v1.types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
from .config import BaseConfig
from .fields import ModelField
from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt]
AnyOrderedDict = OrderedDict[Any, Any]
@@ -595,7 +594,7 @@ NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)
def make_namedtuple_validator(
namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig']
) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
from pydantic.v1.annotated_types import create_model_from_namedtuple
from .annotated_types import create_model_from_namedtuple
NamedTupleModel = create_model_from_namedtuple(
namedtuple_cls,
@@ -620,7 +619,7 @@ def make_namedtuple_validator(
def make_typeddict_validator(
typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type]
) -> Callable[[Any], Dict[str, Any]]:
from pydantic.v1.annotated_types import create_model_from_typeddict
from .annotated_types import create_model_from_typeddict
TypedDictModel = create_model_from_typeddict(
typeddict_cls,
@@ -699,7 +698,7 @@ _VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [
def find_validators( # noqa: C901 (ignore complexity)
type_: Type[Any], config: Type['BaseConfig']
) -> Generator[AnyCallable, None, None]:
from pydantic.v1.dataclasses import is_builtin_dataclass, make_dataclass_validator
from .dataclasses import is_builtin_dataclass, make_dataclass_validator
if type_ is Any or type_ is object:
return
@@ -763,6 +762,4 @@ def find_validators( # noqa: C901 (ignore complexity)
if config.arbitrary_types_allowed:
yield make_arbitrary_type_validator(type_)
else:
if hasattr(type_, '__pydantic_core_schema__'):
warn(f'Mixing V1 and V2 models is not supported. `{type_.__name__}` is a V2 model.', UserWarning)
raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config')

View File

@@ -1,6 +1,6 @@
__all__ = 'compiled', 'VERSION', 'version_info'
VERSION = '1.10.21'
VERSION = '1.10.13'
try:
import cython # type: ignore

View File

@@ -0,0 +1,58 @@
"""Decorator for validating function calls."""
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from ._internal import _validate_call
__all__ = ('validate_call',)
if TYPE_CHECKING:
from .config import ConfigDict
AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any])
@overload
def validate_call(
*, config: ConfigDict | None = None, validate_return: bool = False
) -> Callable[[AnyCallableT], AnyCallableT]:
...
@overload
def validate_call(__func: AnyCallableT) -> AnyCallableT:
...
def validate_call(
__func: AnyCallableT | None = None,
*,
config: ConfigDict | None = None,
validate_return: bool = False,
) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT]:
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validation_decorator/
Returns a decorated wrapper around the function that validates the arguments and, optionally, the return value.
Usage may be either as a plain decorator `@validate_call` or with arguments `@validate_call(...)`.
Args:
__func: The function to be decorated.
config: The configuration dictionary.
validate_return: Whether to validate the return value.
Returns:
The decorated function.
"""
def validate(function: AnyCallableT) -> AnyCallableT:
if isinstance(function, (classmethod, staticmethod)):
name = type(function).__name__
raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)')
return _validate_call.ValidateCallWrapper(function, config, validate_return) # type: ignore
if __func:
return validate(__func)
else:
return validate

Some files were not shown because too many files have changed in this diff Show More