This commit is contained in:
@@ -1,73 +1,59 @@
|
||||
import typing
|
||||
from importlib import import_module
|
||||
from warnings import warn
|
||||
|
||||
import pydantic_core
|
||||
from pydantic_core.core_schema import (
|
||||
FieldSerializationInfo,
|
||||
SerializationInfo,
|
||||
SerializerFunctionWrapHandler,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
)
|
||||
|
||||
from . import dataclasses
|
||||
from ._internal._generate_schema import GenerateSchema as GenerateSchema
|
||||
from ._migration import getattr_migration
|
||||
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
from .config import ConfigDict
|
||||
from .errors import *
|
||||
from .fields import AliasChoices, AliasPath, Field, PrivateAttr, computed_field
|
||||
from .functional_serializers import PlainSerializer, SerializeAsAny, WrapSerializer, field_serializer, model_serializer
|
||||
from .functional_validators import (
|
||||
AfterValidator,
|
||||
BeforeValidator,
|
||||
InstanceOf,
|
||||
PlainValidator,
|
||||
SkipValidation,
|
||||
WrapValidator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from .json_schema import WithJsonSchema
|
||||
from .main import *
|
||||
from .networks import *
|
||||
from .type_adapter import TypeAdapter
|
||||
from .types import *
|
||||
from .validate_call import validate_call
|
||||
from .version import VERSION
|
||||
from .warnings import *
|
||||
|
||||
__version__ = VERSION
|
||||
|
||||
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
|
||||
ValidationError = pydantic_core.ValidationError
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# import of virtually everything is supported via `__getattr__` below,
|
||||
# but we need them here for type checking and IDE support
|
||||
import pydantic_core
|
||||
from pydantic_core.core_schema import (
|
||||
FieldSerializationInfo,
|
||||
SerializationInfo,
|
||||
SerializerFunctionWrapHandler,
|
||||
ValidationInfo,
|
||||
ValidatorFunctionWrapHandler,
|
||||
)
|
||||
|
||||
from . import dataclasses
|
||||
from .aliases import AliasChoices, AliasGenerator, AliasPath
|
||||
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
from .config import ConfigDict, with_config
|
||||
from .errors import *
|
||||
from .fields import Field, PrivateAttr, computed_field
|
||||
from .functional_serializers import (
|
||||
PlainSerializer,
|
||||
SerializeAsAny,
|
||||
WrapSerializer,
|
||||
field_serializer,
|
||||
model_serializer,
|
||||
)
|
||||
from .functional_validators import (
|
||||
AfterValidator,
|
||||
BeforeValidator,
|
||||
InstanceOf,
|
||||
ModelWrapValidatorHandler,
|
||||
PlainValidator,
|
||||
SkipValidation,
|
||||
WrapValidator,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from .json_schema import WithJsonSchema
|
||||
from .main import *
|
||||
from .networks import *
|
||||
from .type_adapter import TypeAdapter
|
||||
from .types import *
|
||||
from .validate_call_decorator import validate_call
|
||||
from .warnings import (
|
||||
PydanticDeprecatedSince20,
|
||||
PydanticDeprecatedSince26,
|
||||
PydanticDeprecatedSince29,
|
||||
PydanticDeprecatedSince210,
|
||||
PydanticDeprecatedSince211,
|
||||
PydanticDeprecationWarning,
|
||||
PydanticExperimentalWarning,
|
||||
)
|
||||
|
||||
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
|
||||
ValidationError = pydantic_core.ValidationError
|
||||
# these are imported via `__getattr__` below, but we need them here for type checking and IDE support
|
||||
from .deprecated.class_validators import root_validator, validator
|
||||
from .deprecated.config import BaseConfig, Extra
|
||||
from .deprecated.tools import *
|
||||
from .root_model import RootModel
|
||||
|
||||
__version__ = VERSION
|
||||
__all__ = (
|
||||
__all__ = [
|
||||
# dataclasses
|
||||
'dataclasses',
|
||||
# pydantic_core.core_schema
|
||||
'ValidationInfo',
|
||||
'ValidatorFunctionWrapHandler',
|
||||
# functional validators
|
||||
'field_validator',
|
||||
'model_validator',
|
||||
@@ -77,8 +63,6 @@ __all__ = (
|
||||
'WrapValidator',
|
||||
'SkipValidation',
|
||||
'InstanceOf',
|
||||
'ModelWrapValidatorHandler',
|
||||
# JSON Schema
|
||||
'WithJsonSchema',
|
||||
# deprecated V1 functional validators, these are imported via `__getattr__` below
|
||||
'root_validator',
|
||||
@@ -89,14 +73,18 @@ __all__ = (
|
||||
'PlainSerializer',
|
||||
'SerializeAsAny',
|
||||
'WrapSerializer',
|
||||
'FieldSerializationInfo',
|
||||
'SerializationInfo',
|
||||
'SerializerFunctionWrapHandler',
|
||||
# config
|
||||
'ConfigDict',
|
||||
'with_config',
|
||||
# deprecated V1 config, these are imported via `__getattr__` below
|
||||
'BaseConfig',
|
||||
'Extra',
|
||||
# validate_call
|
||||
'validate_call',
|
||||
# pydantic_core errors
|
||||
'ValidationError',
|
||||
# errors
|
||||
'PydanticErrorCodes',
|
||||
'PydanticUserError',
|
||||
@@ -104,15 +92,11 @@ __all__ = (
|
||||
'PydanticImportError',
|
||||
'PydanticUndefinedAnnotation',
|
||||
'PydanticInvalidForJsonSchema',
|
||||
'PydanticForbiddenQualifier',
|
||||
# fields
|
||||
'AliasPath',
|
||||
'AliasChoices',
|
||||
'Field',
|
||||
'computed_field',
|
||||
'PrivateAttr',
|
||||
# alias
|
||||
'AliasChoices',
|
||||
'AliasGenerator',
|
||||
'AliasPath',
|
||||
# main
|
||||
'BaseModel',
|
||||
'create_model',
|
||||
@@ -121,9 +105,6 @@ __all__ = (
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'HttpUrl',
|
||||
'FtpUrl',
|
||||
'WebsocketUrl',
|
||||
'AnyWebsocketUrl',
|
||||
'UrlConstraints',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
@@ -136,11 +117,8 @@ __all__ = (
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'NatsDsn',
|
||||
'MySQLDsn',
|
||||
'MariaDBDsn',
|
||||
'ClickHouseDsn',
|
||||
'SnowflakeDsn',
|
||||
'validate_email',
|
||||
# root_model
|
||||
'RootModel',
|
||||
@@ -175,22 +153,18 @@ __all__ = (
|
||||
'UUID3',
|
||||
'UUID4',
|
||||
'UUID5',
|
||||
'UUID6',
|
||||
'UUID7',
|
||||
'UUID8',
|
||||
'FilePath',
|
||||
'DirectoryPath',
|
||||
'NewPath',
|
||||
'Json',
|
||||
'Secret',
|
||||
'SecretStr',
|
||||
'SecretBytes',
|
||||
'SocketPath',
|
||||
'StrictBool',
|
||||
'StrictBytes',
|
||||
'StrictInt',
|
||||
'StrictFloat',
|
||||
'PaymentCardNumber',
|
||||
'PrivateAttr',
|
||||
'ByteSize',
|
||||
'PastDate',
|
||||
'FutureDate',
|
||||
@@ -208,238 +182,44 @@ __all__ = (
|
||||
'Base64UrlBytes',
|
||||
'Base64UrlStr',
|
||||
'GetPydanticSchema',
|
||||
'Tag',
|
||||
'Discriminator',
|
||||
'JsonValue',
|
||||
'FailFast',
|
||||
# type_adapter
|
||||
'TypeAdapter',
|
||||
# version
|
||||
'__version__',
|
||||
'VERSION',
|
||||
# warnings
|
||||
'PydanticDeprecatedSince20',
|
||||
'PydanticDeprecatedSince26',
|
||||
'PydanticDeprecatedSince29',
|
||||
'PydanticDeprecatedSince210',
|
||||
'PydanticDeprecatedSince211',
|
||||
'PydanticDeprecationWarning',
|
||||
'PydanticExperimentalWarning',
|
||||
# annotated handlers
|
||||
'GetCoreSchemaHandler',
|
||||
'GetJsonSchemaHandler',
|
||||
# pydantic_core
|
||||
'ValidationError',
|
||||
'ValidationInfo',
|
||||
'SerializationInfo',
|
||||
'ValidatorFunctionWrapHandler',
|
||||
'FieldSerializationInfo',
|
||||
'SerializerFunctionWrapHandler',
|
||||
'OnErrorOmit',
|
||||
)
|
||||
'GenerateSchema',
|
||||
]
|
||||
|
||||
# A mapping of {<member name>: (package, <module name>)} defining dynamic imports
|
||||
_dynamic_imports: 'dict[str, tuple[str, str]]' = {
|
||||
'dataclasses': (__spec__.parent, '__module__'),
|
||||
# functional validators
|
||||
'field_validator': (__spec__.parent, '.functional_validators'),
|
||||
'model_validator': (__spec__.parent, '.functional_validators'),
|
||||
'AfterValidator': (__spec__.parent, '.functional_validators'),
|
||||
'BeforeValidator': (__spec__.parent, '.functional_validators'),
|
||||
'PlainValidator': (__spec__.parent, '.functional_validators'),
|
||||
'WrapValidator': (__spec__.parent, '.functional_validators'),
|
||||
'SkipValidation': (__spec__.parent, '.functional_validators'),
|
||||
'InstanceOf': (__spec__.parent, '.functional_validators'),
|
||||
'ModelWrapValidatorHandler': (__spec__.parent, '.functional_validators'),
|
||||
# JSON Schema
|
||||
'WithJsonSchema': (__spec__.parent, '.json_schema'),
|
||||
# functional serializers
|
||||
'field_serializer': (__spec__.parent, '.functional_serializers'),
|
||||
'model_serializer': (__spec__.parent, '.functional_serializers'),
|
||||
'PlainSerializer': (__spec__.parent, '.functional_serializers'),
|
||||
'SerializeAsAny': (__spec__.parent, '.functional_serializers'),
|
||||
'WrapSerializer': (__spec__.parent, '.functional_serializers'),
|
||||
# config
|
||||
'ConfigDict': (__spec__.parent, '.config'),
|
||||
'with_config': (__spec__.parent, '.config'),
|
||||
# validate call
|
||||
'validate_call': (__spec__.parent, '.validate_call_decorator'),
|
||||
# errors
|
||||
'PydanticErrorCodes': (__spec__.parent, '.errors'),
|
||||
'PydanticUserError': (__spec__.parent, '.errors'),
|
||||
'PydanticSchemaGenerationError': (__spec__.parent, '.errors'),
|
||||
'PydanticImportError': (__spec__.parent, '.errors'),
|
||||
'PydanticUndefinedAnnotation': (__spec__.parent, '.errors'),
|
||||
'PydanticInvalidForJsonSchema': (__spec__.parent, '.errors'),
|
||||
'PydanticForbiddenQualifier': (__spec__.parent, '.errors'),
|
||||
# fields
|
||||
'Field': (__spec__.parent, '.fields'),
|
||||
'computed_field': (__spec__.parent, '.fields'),
|
||||
'PrivateAttr': (__spec__.parent, '.fields'),
|
||||
# alias
|
||||
'AliasChoices': (__spec__.parent, '.aliases'),
|
||||
'AliasGenerator': (__spec__.parent, '.aliases'),
|
||||
'AliasPath': (__spec__.parent, '.aliases'),
|
||||
# main
|
||||
'BaseModel': (__spec__.parent, '.main'),
|
||||
'create_model': (__spec__.parent, '.main'),
|
||||
# network
|
||||
'AnyUrl': (__spec__.parent, '.networks'),
|
||||
'AnyHttpUrl': (__spec__.parent, '.networks'),
|
||||
'FileUrl': (__spec__.parent, '.networks'),
|
||||
'HttpUrl': (__spec__.parent, '.networks'),
|
||||
'FtpUrl': (__spec__.parent, '.networks'),
|
||||
'WebsocketUrl': (__spec__.parent, '.networks'),
|
||||
'AnyWebsocketUrl': (__spec__.parent, '.networks'),
|
||||
'UrlConstraints': (__spec__.parent, '.networks'),
|
||||
'EmailStr': (__spec__.parent, '.networks'),
|
||||
'NameEmail': (__spec__.parent, '.networks'),
|
||||
'IPvAnyAddress': (__spec__.parent, '.networks'),
|
||||
'IPvAnyInterface': (__spec__.parent, '.networks'),
|
||||
'IPvAnyNetwork': (__spec__.parent, '.networks'),
|
||||
'PostgresDsn': (__spec__.parent, '.networks'),
|
||||
'CockroachDsn': (__spec__.parent, '.networks'),
|
||||
'AmqpDsn': (__spec__.parent, '.networks'),
|
||||
'RedisDsn': (__spec__.parent, '.networks'),
|
||||
'MongoDsn': (__spec__.parent, '.networks'),
|
||||
'KafkaDsn': (__spec__.parent, '.networks'),
|
||||
'NatsDsn': (__spec__.parent, '.networks'),
|
||||
'MySQLDsn': (__spec__.parent, '.networks'),
|
||||
'MariaDBDsn': (__spec__.parent, '.networks'),
|
||||
'ClickHouseDsn': (__spec__.parent, '.networks'),
|
||||
'SnowflakeDsn': (__spec__.parent, '.networks'),
|
||||
'validate_email': (__spec__.parent, '.networks'),
|
||||
# root_model
|
||||
'RootModel': (__spec__.parent, '.root_model'),
|
||||
# types
|
||||
'Strict': (__spec__.parent, '.types'),
|
||||
'StrictStr': (__spec__.parent, '.types'),
|
||||
'conbytes': (__spec__.parent, '.types'),
|
||||
'conlist': (__spec__.parent, '.types'),
|
||||
'conset': (__spec__.parent, '.types'),
|
||||
'confrozenset': (__spec__.parent, '.types'),
|
||||
'constr': (__spec__.parent, '.types'),
|
||||
'StringConstraints': (__spec__.parent, '.types'),
|
||||
'ImportString': (__spec__.parent, '.types'),
|
||||
'conint': (__spec__.parent, '.types'),
|
||||
'PositiveInt': (__spec__.parent, '.types'),
|
||||
'NegativeInt': (__spec__.parent, '.types'),
|
||||
'NonNegativeInt': (__spec__.parent, '.types'),
|
||||
'NonPositiveInt': (__spec__.parent, '.types'),
|
||||
'confloat': (__spec__.parent, '.types'),
|
||||
'PositiveFloat': (__spec__.parent, '.types'),
|
||||
'NegativeFloat': (__spec__.parent, '.types'),
|
||||
'NonNegativeFloat': (__spec__.parent, '.types'),
|
||||
'NonPositiveFloat': (__spec__.parent, '.types'),
|
||||
'FiniteFloat': (__spec__.parent, '.types'),
|
||||
'condecimal': (__spec__.parent, '.types'),
|
||||
'condate': (__spec__.parent, '.types'),
|
||||
'UUID1': (__spec__.parent, '.types'),
|
||||
'UUID3': (__spec__.parent, '.types'),
|
||||
'UUID4': (__spec__.parent, '.types'),
|
||||
'UUID5': (__spec__.parent, '.types'),
|
||||
'UUID6': (__spec__.parent, '.types'),
|
||||
'UUID7': (__spec__.parent, '.types'),
|
||||
'UUID8': (__spec__.parent, '.types'),
|
||||
'FilePath': (__spec__.parent, '.types'),
|
||||
'DirectoryPath': (__spec__.parent, '.types'),
|
||||
'NewPath': (__spec__.parent, '.types'),
|
||||
'Json': (__spec__.parent, '.types'),
|
||||
'Secret': (__spec__.parent, '.types'),
|
||||
'SecretStr': (__spec__.parent, '.types'),
|
||||
'SecretBytes': (__spec__.parent, '.types'),
|
||||
'StrictBool': (__spec__.parent, '.types'),
|
||||
'StrictBytes': (__spec__.parent, '.types'),
|
||||
'StrictInt': (__spec__.parent, '.types'),
|
||||
'StrictFloat': (__spec__.parent, '.types'),
|
||||
'PaymentCardNumber': (__spec__.parent, '.types'),
|
||||
'ByteSize': (__spec__.parent, '.types'),
|
||||
'PastDate': (__spec__.parent, '.types'),
|
||||
'SocketPath': (__spec__.parent, '.types'),
|
||||
'FutureDate': (__spec__.parent, '.types'),
|
||||
'PastDatetime': (__spec__.parent, '.types'),
|
||||
'FutureDatetime': (__spec__.parent, '.types'),
|
||||
'AwareDatetime': (__spec__.parent, '.types'),
|
||||
'NaiveDatetime': (__spec__.parent, '.types'),
|
||||
'AllowInfNan': (__spec__.parent, '.types'),
|
||||
'EncoderProtocol': (__spec__.parent, '.types'),
|
||||
'EncodedBytes': (__spec__.parent, '.types'),
|
||||
'EncodedStr': (__spec__.parent, '.types'),
|
||||
'Base64Encoder': (__spec__.parent, '.types'),
|
||||
'Base64Bytes': (__spec__.parent, '.types'),
|
||||
'Base64Str': (__spec__.parent, '.types'),
|
||||
'Base64UrlBytes': (__spec__.parent, '.types'),
|
||||
'Base64UrlStr': (__spec__.parent, '.types'),
|
||||
'GetPydanticSchema': (__spec__.parent, '.types'),
|
||||
'Tag': (__spec__.parent, '.types'),
|
||||
'Discriminator': (__spec__.parent, '.types'),
|
||||
'JsonValue': (__spec__.parent, '.types'),
|
||||
'OnErrorOmit': (__spec__.parent, '.types'),
|
||||
'FailFast': (__spec__.parent, '.types'),
|
||||
# type_adapter
|
||||
'TypeAdapter': (__spec__.parent, '.type_adapter'),
|
||||
# warnings
|
||||
'PydanticDeprecatedSince20': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecatedSince26': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecatedSince29': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecatedSince210': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecatedSince211': (__spec__.parent, '.warnings'),
|
||||
'PydanticDeprecationWarning': (__spec__.parent, '.warnings'),
|
||||
'PydanticExperimentalWarning': (__spec__.parent, '.warnings'),
|
||||
# annotated handlers
|
||||
'GetCoreSchemaHandler': (__spec__.parent, '.annotated_handlers'),
|
||||
'GetJsonSchemaHandler': (__spec__.parent, '.annotated_handlers'),
|
||||
# pydantic_core stuff
|
||||
'ValidationError': ('pydantic_core', '.'),
|
||||
'ValidationInfo': ('pydantic_core', '.core_schema'),
|
||||
'SerializationInfo': ('pydantic_core', '.core_schema'),
|
||||
'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'),
|
||||
'FieldSerializationInfo': ('pydantic_core', '.core_schema'),
|
||||
'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'),
|
||||
# deprecated, mostly not included in __all__
|
||||
'root_validator': (__spec__.parent, '.deprecated.class_validators'),
|
||||
'validator': (__spec__.parent, '.deprecated.class_validators'),
|
||||
'BaseConfig': (__spec__.parent, '.deprecated.config'),
|
||||
'Extra': (__spec__.parent, '.deprecated.config'),
|
||||
'parse_obj_as': (__spec__.parent, '.deprecated.tools'),
|
||||
'schema_of': (__spec__.parent, '.deprecated.tools'),
|
||||
'schema_json_of': (__spec__.parent, '.deprecated.tools'),
|
||||
# deprecated dynamic imports
|
||||
'RootModel': (__package__, '.root_model'),
|
||||
'root_validator': (__package__, '.deprecated.class_validators'),
|
||||
'validator': (__package__, '.deprecated.class_validators'),
|
||||
'BaseConfig': (__package__, '.deprecated.config'),
|
||||
'Extra': (__package__, '.deprecated.config'),
|
||||
'parse_obj_as': (__package__, '.deprecated.tools'),
|
||||
'schema_of': (__package__, '.deprecated.tools'),
|
||||
'schema_json_of': (__package__, '.deprecated.tools'),
|
||||
# FieldValidationInfo is deprecated, and hidden behind module a `__getattr__`
|
||||
'FieldValidationInfo': ('pydantic_core', '.core_schema'),
|
||||
'GenerateSchema': (__spec__.parent, '._internal._generate_schema'),
|
||||
}
|
||||
_deprecated_dynamic_imports = {'FieldValidationInfo', 'GenerateSchema'}
|
||||
|
||||
_getattr_migration = getattr_migration(__name__)
|
||||
|
||||
|
||||
def __getattr__(attr_name: str) -> object:
|
||||
if attr_name in _deprecated_dynamic_imports:
|
||||
warn(
|
||||
f'Importing {attr_name} from `pydantic` is deprecated. This feature is either no longer supported, or is not public.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
dynamic_attr = _dynamic_imports.get(attr_name)
|
||||
if dynamic_attr is None:
|
||||
return _getattr_migration(attr_name)
|
||||
|
||||
package, module_name = dynamic_attr
|
||||
|
||||
if module_name == '__module__':
|
||||
result = import_module(f'.{attr_name}', package=package)
|
||||
globals()[attr_name] = result
|
||||
return result
|
||||
else:
|
||||
module = import_module(module_name, package=package)
|
||||
result = getattr(module, attr_name)
|
||||
g = globals()
|
||||
for k, (_, v_module_name) in _dynamic_imports.items():
|
||||
if v_module_name == module_name and k not in _deprecated_dynamic_imports:
|
||||
g[k] = getattr(module, k)
|
||||
return result
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
def __dir__() -> 'list[str]':
|
||||
return list(__all__)
|
||||
module = import_module(module_name, package=package)
|
||||
return getattr(module, attr_name)
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from re import Pattern
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
ContextManager,
|
||||
Iterator,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import (
|
||||
Literal,
|
||||
Self,
|
||||
)
|
||||
|
||||
from ..aliases import AliasGenerator
|
||||
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
|
||||
from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable
|
||||
from ..errors import PydanticUserError
|
||||
from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
@@ -26,7 +28,6 @@ if not TYPE_CHECKING:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._internal._schema_generation_shared import GenerateSchema
|
||||
from ..fields import ComputedFieldInfo, FieldInfo
|
||||
|
||||
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
|
||||
|
||||
@@ -56,12 +57,10 @@ class ConfigWrapper:
|
||||
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
|
||||
# to construct error `loc`s, default `True`
|
||||
loc_by_alias: bool
|
||||
alias_generator: Callable[[str], str] | AliasGenerator | None
|
||||
model_title_generator: Callable[[type], str] | None
|
||||
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
|
||||
alias_generator: Callable[[str], str] | None
|
||||
ignored_types: tuple[type, ...]
|
||||
allow_inf_nan: bool
|
||||
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
|
||||
json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None
|
||||
json_encoders: dict[type[object], JsonEncoder] | None
|
||||
|
||||
# new in V2
|
||||
@@ -69,13 +68,11 @@ class ConfigWrapper:
|
||||
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
|
||||
revalidate_instances: Literal['always', 'never', 'subclass-instances']
|
||||
ser_json_timedelta: Literal['iso8601', 'float']
|
||||
ser_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
val_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
ser_json_inf_nan: Literal['null', 'constants', 'strings']
|
||||
ser_json_bytes: Literal['utf8', 'base64']
|
||||
# whether to validate default values during validation, default False
|
||||
validate_default: bool
|
||||
validate_return: bool
|
||||
protected_namespaces: tuple[str | Pattern[str], ...]
|
||||
protected_namespaces: tuple[str, ...]
|
||||
hide_input_in_errors: bool
|
||||
defer_build: bool
|
||||
plugin_settings: dict[str, object] | None
|
||||
@@ -83,13 +80,6 @@ class ConfigWrapper:
|
||||
json_schema_serialization_defaults_required: bool
|
||||
json_schema_mode_override: Literal['validation', 'serialization', None]
|
||||
coerce_numbers_to_str: bool
|
||||
regex_engine: Literal['rust-regex', 'python-re']
|
||||
validation_error_cause: bool
|
||||
use_attribute_docstrings: bool
|
||||
cache_strings: bool | Literal['all', 'keys', 'none']
|
||||
validate_by_alias: bool
|
||||
validate_by_name: bool
|
||||
serialize_by_alias: bool
|
||||
|
||||
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
|
||||
if check:
|
||||
@@ -123,19 +113,13 @@ class ConfigWrapper:
|
||||
config_class_from_namespace = namespace.get('Config')
|
||||
config_dict_from_namespace = namespace.get('model_config')
|
||||
|
||||
raw_annotations = namespace.get('__annotations__', {})
|
||||
if raw_annotations.get('model_config') and config_dict_from_namespace is None:
|
||||
raise PydanticUserError(
|
||||
'`model_config` cannot be used as a model field name. Use `model_config` for model configuration.',
|
||||
code='model-config-invalid-field-name',
|
||||
)
|
||||
|
||||
if config_class_from_namespace and config_dict_from_namespace:
|
||||
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
|
||||
|
||||
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
|
||||
|
||||
config_new.update(config_from_namespace)
|
||||
if config_from_namespace is not None:
|
||||
config_new.update(config_from_namespace)
|
||||
|
||||
for k in list(kwargs.keys()):
|
||||
if k in config_keys:
|
||||
@@ -144,7 +128,7 @@ class ConfigWrapper:
|
||||
return cls(config_new)
|
||||
|
||||
# we don't show `__getattr__` to type checkers so missing attributes cause errors
|
||||
if not TYPE_CHECKING: # pragma: no branch
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
@@ -155,77 +139,46 @@ class ConfigWrapper:
|
||||
except KeyError:
|
||||
raise AttributeError(f'Config has no attribute {name!r}') from None
|
||||
|
||||
def core_config(self, title: str | None) -> core_schema.CoreConfig:
|
||||
"""Create a pydantic-core config.
|
||||
def core_config(self, obj: Any) -> core_schema.CoreConfig:
|
||||
"""Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.
|
||||
|
||||
Pass `obj=None` if you do not want to attempt to infer the `title`.
|
||||
|
||||
We don't use getattr here since we don't want to populate with defaults.
|
||||
|
||||
Args:
|
||||
title: The title to use if not set in config.
|
||||
obj: An object used to populate `title` if not set in config.
|
||||
|
||||
Returns:
|
||||
A `CoreConfig` object created from config.
|
||||
"""
|
||||
config = self.config_dict
|
||||
|
||||
if config.get('schema_generator') is not None:
|
||||
warnings.warn(
|
||||
'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.',
|
||||
PydanticDeprecatedSince210,
|
||||
stacklevel=2,
|
||||
def dict_not_none(**kwargs: Any) -> Any:
|
||||
return {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
core_config = core_schema.CoreConfig(
|
||||
**dict_not_none(
|
||||
title=self.config_dict.get('title') or (obj and obj.__name__),
|
||||
extra_fields_behavior=self.config_dict.get('extra'),
|
||||
allow_inf_nan=self.config_dict.get('allow_inf_nan'),
|
||||
populate_by_name=self.config_dict.get('populate_by_name'),
|
||||
str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
|
||||
str_to_lower=self.config_dict.get('str_to_lower'),
|
||||
str_to_upper=self.config_dict.get('str_to_upper'),
|
||||
strict=self.config_dict.get('strict'),
|
||||
ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
|
||||
ser_json_bytes=self.config_dict.get('ser_json_bytes'),
|
||||
from_attributes=self.config_dict.get('from_attributes'),
|
||||
loc_by_alias=self.config_dict.get('loc_by_alias'),
|
||||
revalidate_instances=self.config_dict.get('revalidate_instances'),
|
||||
validate_default=self.config_dict.get('validate_default'),
|
||||
str_max_length=self.config_dict.get('str_max_length'),
|
||||
str_min_length=self.config_dict.get('str_min_length'),
|
||||
hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
|
||||
coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
|
||||
)
|
||||
|
||||
if (populate_by_name := config.get('populate_by_name')) is not None:
|
||||
# We include this patch for backwards compatibility purposes, but this config setting will be deprecated in v3.0, and likely removed in v4.0.
|
||||
# Thus, the above warning and this patch can be removed then as well.
|
||||
if config.get('validate_by_name') is None:
|
||||
config['validate_by_alias'] = True
|
||||
config['validate_by_name'] = populate_by_name
|
||||
|
||||
# We dynamically patch validate_by_name to be True if validate_by_alias is set to False
|
||||
# and validate_by_name is not explicitly set.
|
||||
if config.get('validate_by_alias') is False and config.get('validate_by_name') is None:
|
||||
config['validate_by_name'] = True
|
||||
|
||||
if (not config.get('validate_by_alias', True)) and (not config.get('validate_by_name', False)):
|
||||
raise PydanticUserError(
|
||||
'At least one of `validate_by_alias` or `validate_by_name` must be set to True.',
|
||||
code='validate-by-alias-and-name-false',
|
||||
)
|
||||
|
||||
return core_schema.CoreConfig(
|
||||
**{ # pyright: ignore[reportArgumentType]
|
||||
k: v
|
||||
for k, v in (
|
||||
('title', config.get('title') or title or None),
|
||||
('extra_fields_behavior', config.get('extra')),
|
||||
('allow_inf_nan', config.get('allow_inf_nan')),
|
||||
('str_strip_whitespace', config.get('str_strip_whitespace')),
|
||||
('str_to_lower', config.get('str_to_lower')),
|
||||
('str_to_upper', config.get('str_to_upper')),
|
||||
('strict', config.get('strict')),
|
||||
('ser_json_timedelta', config.get('ser_json_timedelta')),
|
||||
('ser_json_bytes', config.get('ser_json_bytes')),
|
||||
('val_json_bytes', config.get('val_json_bytes')),
|
||||
('ser_json_inf_nan', config.get('ser_json_inf_nan')),
|
||||
('from_attributes', config.get('from_attributes')),
|
||||
('loc_by_alias', config.get('loc_by_alias')),
|
||||
('revalidate_instances', config.get('revalidate_instances')),
|
||||
('validate_default', config.get('validate_default')),
|
||||
('str_max_length', config.get('str_max_length')),
|
||||
('str_min_length', config.get('str_min_length')),
|
||||
('hide_input_in_errors', config.get('hide_input_in_errors')),
|
||||
('coerce_numbers_to_str', config.get('coerce_numbers_to_str')),
|
||||
('regex_engine', config.get('regex_engine')),
|
||||
('validation_error_cause', config.get('validation_error_cause')),
|
||||
('cache_strings', config.get('cache_strings')),
|
||||
('validate_by_alias', config.get('validate_by_alias')),
|
||||
('validate_by_name', config.get('validate_by_name')),
|
||||
('serialize_by_alias', config.get('serialize_by_alias')),
|
||||
)
|
||||
if v is not None
|
||||
}
|
||||
)
|
||||
return core_config
|
||||
|
||||
def __repr__(self):
|
||||
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
|
||||
@@ -242,20 +195,22 @@ class ConfigWrapperStack:
|
||||
def tail(self) -> ConfigWrapper:
|
||||
return self._config_wrapper_stack[-1]
|
||||
|
||||
@contextmanager
|
||||
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
|
||||
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]:
|
||||
if config_wrapper is None:
|
||||
yield
|
||||
return
|
||||
return nullcontext()
|
||||
|
||||
if not isinstance(config_wrapper, ConfigWrapper):
|
||||
config_wrapper = ConfigWrapper(config_wrapper, check=False)
|
||||
|
||||
self._config_wrapper_stack.append(config_wrapper)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._config_wrapper_stack.pop()
|
||||
@contextmanager
|
||||
def _context_manager() -> Iterator[None]:
|
||||
self._config_wrapper_stack.append(config_wrapper)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._config_wrapper_stack.pop()
|
||||
|
||||
return _context_manager()
|
||||
|
||||
|
||||
config_defaults = ConfigDict(
|
||||
@@ -275,8 +230,6 @@ config_defaults = ConfigDict(
|
||||
from_attributes=False,
|
||||
loc_by_alias=True,
|
||||
alias_generator=None,
|
||||
model_title_generator=None,
|
||||
field_title_generator=None,
|
||||
ignored_types=(),
|
||||
allow_inf_nan=True,
|
||||
json_schema_extra=None,
|
||||
@@ -284,26 +237,17 @@ config_defaults = ConfigDict(
|
||||
revalidate_instances='never',
|
||||
ser_json_timedelta='iso8601',
|
||||
ser_json_bytes='utf8',
|
||||
val_json_bytes='utf8',
|
||||
ser_json_inf_nan='null',
|
||||
validate_default=False,
|
||||
validate_return=False,
|
||||
protected_namespaces=('model_validate', 'model_dump'),
|
||||
protected_namespaces=('model_',),
|
||||
hide_input_in_errors=False,
|
||||
json_encoders=None,
|
||||
defer_build=False,
|
||||
schema_generator=None,
|
||||
plugin_settings=None,
|
||||
schema_generator=None,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
json_schema_mode_override=None,
|
||||
coerce_numbers_to_str=False,
|
||||
regex_engine='rust-regex',
|
||||
validation_error_cause=False,
|
||||
use_attribute_docstrings=False,
|
||||
cache_strings=True,
|
||||
validate_by_alias=True,
|
||||
validate_by_name=False,
|
||||
serialize_by_alias=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -344,7 +288,7 @@ V2_REMOVED_KEYS = {
|
||||
'post_init_call',
|
||||
}
|
||||
V2_RENAMED_KEYS = {
|
||||
'allow_population_by_field_name': 'validate_by_name',
|
||||
'allow_population_by_field_name': 'populate_by_name',
|
||||
'anystr_lower': 'str_to_lower',
|
||||
'anystr_strip_whitespace': 'str_strip_whitespace',
|
||||
'anystr_upper': 'str_to_upper',
|
||||
|
||||
@@ -1,97 +1,92 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
from warnings import warn
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config import JsonDict, JsonSchemaExtraCallable
|
||||
import typing_extensions
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._schema_generation_shared import (
|
||||
CoreSchemaOrField as CoreSchemaOrField,
|
||||
)
|
||||
from ._schema_generation_shared import (
|
||||
GetJsonSchemaFunction,
|
||||
)
|
||||
|
||||
|
||||
class CoreMetadata(TypedDict, total=False):
|
||||
class CoreMetadata(typing_extensions.TypedDict, total=False):
|
||||
"""A `TypedDict` for holding the metadata dict of the schema.
|
||||
|
||||
Attributes:
|
||||
pydantic_js_functions: List of JSON schema functions that resolve refs during application.
|
||||
pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application.
|
||||
pydantic_js_functions: List of JSON schema functions.
|
||||
pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
|
||||
prefer positional over keyword arguments for an 'arguments' schema.
|
||||
custom validation function. Only applies to before, plain, and wrap validators.
|
||||
pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type.
|
||||
pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable.
|
||||
pydantic_internal_union_tag_key: Used internally by the `Tag` metadata to specify the tag used for a discriminated union.
|
||||
pydantic_internal_union_discriminator: Used internally to specify the discriminator value for a discriminated union
|
||||
when the discriminator was applied to a `'definition-ref'` schema, and that reference was missing at the time
|
||||
of the annotation application.
|
||||
|
||||
TODO: Perhaps we should move this structure to pydantic-core. At the moment, though,
|
||||
it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API.
|
||||
|
||||
TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during
|
||||
the core schema generation process. It's inevitable that we need to store some json schema related information
|
||||
on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related
|
||||
issues is quite difficult when JSON schema information is disguised via dynamically defined functions.
|
||||
"""
|
||||
|
||||
pydantic_js_functions: list[GetJsonSchemaFunction]
|
||||
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
|
||||
pydantic_js_prefer_positional_arguments: bool
|
||||
pydantic_js_updates: JsonDict
|
||||
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable
|
||||
pydantic_internal_union_tag_key: str
|
||||
pydantic_internal_union_discriminator: str
|
||||
|
||||
# If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
|
||||
# prefer positional over keyword arguments for an 'arguments' schema.
|
||||
pydantic_js_prefer_positional_arguments: bool | None
|
||||
|
||||
pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema
|
||||
|
||||
|
||||
def update_core_metadata(
|
||||
core_metadata: Any,
|
||||
/,
|
||||
*,
|
||||
pydantic_js_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
pydantic_js_updates: JsonDict | None = None,
|
||||
pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None,
|
||||
) -> None:
|
||||
from ..json_schema import PydanticJsonSchemaWarning
|
||||
class CoreMetadataHandler:
|
||||
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
|
||||
|
||||
"""Update CoreMetadata instance in place. When we make modifications in this function, they
|
||||
take effect on the `core_metadata` reference passed in as the first (and only) positional argument.
|
||||
|
||||
First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility.
|
||||
We do this here, instead of before / after each call to this function so that this typing hack
|
||||
can be easily removed if/when we move `CoreMetadata` to `pydantic-core`.
|
||||
|
||||
For parameter descriptions, see `CoreMetadata` above.
|
||||
This class is used to interact with the metadata field on a CoreSchema object in a consistent
|
||||
way throughout pydantic.
|
||||
"""
|
||||
core_metadata = cast(CoreMetadata, core_metadata)
|
||||
|
||||
if pydantic_js_functions:
|
||||
core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions)
|
||||
__slots__ = ('_schema',)
|
||||
|
||||
if pydantic_js_annotation_functions:
|
||||
core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions)
|
||||
def __init__(self, schema: CoreSchemaOrField):
|
||||
self._schema = schema
|
||||
|
||||
if pydantic_js_updates:
|
||||
if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None:
|
||||
core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates}
|
||||
else:
|
||||
core_metadata['pydantic_js_updates'] = pydantic_js_updates
|
||||
metadata = schema.get('metadata')
|
||||
if metadata is None:
|
||||
schema['metadata'] = CoreMetadata()
|
||||
elif not isinstance(metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
|
||||
|
||||
if pydantic_js_extra is not None:
|
||||
existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra')
|
||||
if existing_pydantic_js_extra is None:
|
||||
core_metadata['pydantic_js_extra'] = pydantic_js_extra
|
||||
if isinstance(existing_pydantic_js_extra, dict):
|
||||
if isinstance(pydantic_js_extra, dict):
|
||||
core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra}
|
||||
if callable(pydantic_js_extra):
|
||||
warn(
|
||||
'Composing `dict` and `callable` type `json_schema_extra` is not supported.'
|
||||
'The `callable` type is being ignored.'
|
||||
"If you'd like support for this behavior, please open an issue on pydantic.",
|
||||
PydanticJsonSchemaWarning,
|
||||
)
|
||||
if callable(existing_pydantic_js_extra):
|
||||
# if ever there's a case of a callable, we'll just keep the last json schema extra spec
|
||||
core_metadata['pydantic_js_extra'] = pydantic_js_extra
|
||||
@property
|
||||
def metadata(self) -> CoreMetadata:
|
||||
"""Retrieves the metadata dict from the schema, initializing it to a dict if it is None
|
||||
and raises an error if it is not a dict.
|
||||
"""
|
||||
metadata = self._schema.get('metadata')
|
||||
if metadata is None:
|
||||
self._schema['metadata'] = metadata = CoreMetadata()
|
||||
if not isinstance(metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
|
||||
return metadata
|
||||
|
||||
|
||||
def build_metadata_dict(
|
||||
*, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
|
||||
js_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
|
||||
js_prefer_positional_arguments: bool | None = None,
|
||||
typed_dict_cls: type[Any] | None = None,
|
||||
initial_metadata: Any | None = None,
|
||||
) -> Any:
|
||||
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
|
||||
with the CoreMetadataHandler class.
|
||||
"""
|
||||
if initial_metadata is not None and not isinstance(initial_metadata, dict):
|
||||
raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.')
|
||||
|
||||
metadata = CoreMetadata(
|
||||
pydantic_js_functions=js_functions or [],
|
||||
pydantic_js_annotation_functions=js_annotation_functions or [],
|
||||
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
|
||||
pydantic_typed_dict_cls=typed_dict_cls,
|
||||
)
|
||||
metadata = {k: v for k, v in metadata.items() if v is not None}
|
||||
|
||||
if initial_metadata is not None:
|
||||
metadata = {**initial_metadata, **metadata}
|
||||
|
||||
return metadata
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Hashable,
|
||||
TypeVar,
|
||||
Union,
|
||||
_GenericAlias, # type: ignore
|
||||
cast,
|
||||
)
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic_core import validate_core_schema as _validate_core_schema
|
||||
from typing_extensions import TypeGuard, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_extensions import TypeAliasType, TypeGuard, get_args
|
||||
|
||||
from . import _repr
|
||||
from ._typing_extra import is_generic_alias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from rich.console import Console
|
||||
|
||||
AnyFunctionSchema = Union[
|
||||
core_schema.AfterValidatorFunctionSchema,
|
||||
@@ -37,7 +39,19 @@ CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
|
||||
|
||||
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
|
||||
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
|
||||
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'}
|
||||
|
||||
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
|
||||
|
||||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
|
||||
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
|
||||
schema building because one of it's members refers to a definition that was not yet defined when the union
|
||||
was first encountered.
|
||||
"""
|
||||
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
|
||||
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
|
||||
schema was first encountered.
|
||||
"""
|
||||
|
||||
|
||||
def is_core_schema(
|
||||
@@ -60,27 +74,28 @@ def is_function_with_inner_schema(
|
||||
|
||||
def is_list_like_schema_with_items_schema(
|
||||
schema: CoreSchema,
|
||||
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
|
||||
) -> TypeGuard[
|
||||
core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
|
||||
]:
|
||||
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
|
||||
|
||||
|
||||
def get_type_ref(type_: Any, args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
|
||||
"""Produces the ref to be used for this type by pydantic_core's core schemas.
|
||||
|
||||
This `args_override` argument was added for the purpose of creating valid recursive references
|
||||
when creating generic models without needing to create a concrete class.
|
||||
"""
|
||||
origin = get_origin(type_) or type_
|
||||
|
||||
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
|
||||
origin = type_
|
||||
args = get_args(type_) if isinstance(type_, _GenericAlias) else (args_override or ())
|
||||
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
|
||||
if generic_metadata:
|
||||
origin = generic_metadata['origin'] or origin
|
||||
args = generic_metadata['args'] or args
|
||||
|
||||
module_name = getattr(origin, '__module__', '<No __module__>')
|
||||
if typing_objects.is_typealiastype(origin):
|
||||
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
|
||||
if isinstance(origin, TypeAliasType):
|
||||
type_ref = f'{module_name}.{origin.__name__}'
|
||||
else:
|
||||
try:
|
||||
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
|
||||
@@ -109,74 +124,457 @@ def get_ref(s: core_schema.CoreSchema) -> None | str:
|
||||
return s.get('ref', None)
|
||||
|
||||
|
||||
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
|
||||
if os.getenv('PYDANTIC_VALIDATE_CORE_SCHEMAS'):
|
||||
return _validate_core_schema(schema)
|
||||
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
|
||||
defs: dict[str, CoreSchema] = {}
|
||||
|
||||
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
ref = get_ref(s)
|
||||
if ref:
|
||||
defs[ref] = s
|
||||
return recurse(s, _record_valid_refs)
|
||||
|
||||
walk_core_schema(schema, _record_valid_refs)
|
||||
|
||||
return defs
|
||||
|
||||
|
||||
def define_expected_missing_refs(
|
||||
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
|
||||
) -> core_schema.CoreSchema | None:
|
||||
if not allowed_missing_refs:
|
||||
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
|
||||
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
|
||||
return None
|
||||
|
||||
refs = collect_definitions(schema).keys()
|
||||
|
||||
expected_missing_refs = allowed_missing_refs.difference(refs)
|
||||
if expected_missing_refs:
|
||||
definitions: list[core_schema.CoreSchema] = [
|
||||
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/619
|
||||
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
|
||||
for ref in expected_missing_refs
|
||||
]
|
||||
return core_schema.definitions_schema(schema, definitions)
|
||||
return None
|
||||
|
||||
|
||||
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
|
||||
invalid = False
|
||||
|
||||
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal invalid
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
|
||||
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
|
||||
return s
|
||||
return recurse(s, _is_schema_valid)
|
||||
|
||||
walk_core_schema(schema, _is_schema_valid)
|
||||
return invalid
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
|
||||
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
|
||||
|
||||
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
|
||||
# Issue: https://github.com/pydantic/pydantic-core/issues/615
|
||||
|
||||
|
||||
class _WalkCoreSchema:
|
||||
def __init__(self):
|
||||
self._schema_type_to_method = self._build_schema_type_to_method()
|
||||
|
||||
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
|
||||
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
|
||||
key: core_schema.CoreSchemaType
|
||||
for key in get_args(core_schema.CoreSchemaType):
|
||||
method_name = f"handle_{key.replace('-', '_')}_schema"
|
||||
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
|
||||
return mapping
|
||||
|
||||
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
return f(schema, self._walk)
|
||||
|
||||
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
|
||||
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
|
||||
if ser_schema:
|
||||
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
|
||||
return schema
|
||||
|
||||
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
sub_schema = schema.get('schema', None)
|
||||
if sub_schema is not None:
|
||||
schema['schema'] = self.walk(sub_schema, f) # type: ignore
|
||||
return schema
|
||||
|
||||
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
|
||||
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
|
||||
if schema is not None:
|
||||
ser_schema['schema'] = self.walk(schema, f) # type: ignore
|
||||
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
|
||||
if return_schema is not None:
|
||||
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
|
||||
return ser_schema
|
||||
|
||||
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_definitions: list[core_schema.CoreSchema] = []
|
||||
for definition in schema['definitions']:
|
||||
updated_definition = self.walk(definition, f)
|
||||
if 'ref' in updated_definition:
|
||||
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
|
||||
# This is most likely to happen due to replacing something with a definition reference, in
|
||||
# which case it should certainly not go in the definitions list
|
||||
new_definitions.append(updated_definition)
|
||||
new_inner_schema = self.walk(schema['schema'], f)
|
||||
|
||||
if not new_definitions and len(schema) == 3:
|
||||
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
|
||||
return new_inner_schema
|
||||
|
||||
new_schema = schema.copy()
|
||||
new_schema['schema'] = new_inner_schema
|
||||
new_schema['definitions'] = new_definitions
|
||||
return new_schema
|
||||
|
||||
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_variable_schema(
|
||||
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = cast(core_schema.TupleVariableSchema, schema)
|
||||
items_schema = schema.get('items_schema')
|
||||
if items_schema is not None:
|
||||
schema['items_schema'] = self.walk(items_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_tuple_positional_schema(
|
||||
self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = cast(core_schema.TuplePositionalSchema, schema)
|
||||
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
keys_schema = schema.get('keys_schema')
|
||||
if keys_schema is not None:
|
||||
schema['keys_schema'] = self.walk(keys_schema, f)
|
||||
values_schema = schema.get('values_schema')
|
||||
if values_schema:
|
||||
schema['values_schema'] = self.walk(values_schema, f)
|
||||
return schema
|
||||
|
||||
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
if not is_function_with_inner_schema(schema):
|
||||
return schema
|
||||
schema['schema'] = self.walk(schema['schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
|
||||
for v in schema['choices']:
|
||||
if isinstance(v, tuple):
|
||||
new_choices.append((self.walk(v[0], f), v[1]))
|
||||
else:
|
||||
new_choices.append(self.walk(v, f))
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
for k, v in schema['choices'].items():
|
||||
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
|
||||
schema['choices'] = new_choices
|
||||
return schema
|
||||
|
||||
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
|
||||
return schema
|
||||
|
||||
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
|
||||
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['json_schema'] = self.walk(schema['json_schema'], f)
|
||||
schema['python_schema'] = self.walk(schema['python_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_fields: dict[str, core_schema.ModelField] = {}
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
extras_schema = schema.get('extras_schema')
|
||||
if extras_schema is not None:
|
||||
schema['extras_schema'] = self.walk(extras_schema, f)
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
replaced_fields: dict[str, core_schema.TypedDictField] = {}
|
||||
for k, v in schema['fields'].items():
|
||||
replaced_field = v.copy()
|
||||
replaced_field['schema'] = self.walk(v['schema'], f)
|
||||
replaced_fields[k] = replaced_field
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_fields: list[core_schema.DataclassField] = []
|
||||
replaced_computed_fields: list[core_schema.ComputedField] = []
|
||||
for computed_field in schema.get('computed_fields', ()):
|
||||
replaced_field = computed_field.copy()
|
||||
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
|
||||
replaced_computed_fields.append(replaced_field)
|
||||
if replaced_computed_fields:
|
||||
schema['computed_fields'] = replaced_computed_fields
|
||||
for field in schema['fields']:
|
||||
replaced_field = field.copy()
|
||||
replaced_field['schema'] = self.walk(field['schema'], f)
|
||||
replaced_fields.append(replaced_field)
|
||||
schema['fields'] = replaced_fields
|
||||
return schema
|
||||
|
||||
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
|
||||
for param in schema['arguments_schema']:
|
||||
replaced_param = param.copy()
|
||||
replaced_param['schema'] = self.walk(param['schema'], f)
|
||||
replaced_arguments_schema.append(replaced_param)
|
||||
schema['arguments_schema'] = replaced_arguments_schema
|
||||
if 'var_args_schema' in schema:
|
||||
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
|
||||
if 'var_kwargs_schema' in schema:
|
||||
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
|
||||
return schema
|
||||
|
||||
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
|
||||
if 'return_schema' in schema:
|
||||
schema['return_schema'] = self.walk(schema['return_schema'], f)
|
||||
return schema
|
||||
|
||||
|
||||
_dispatch = _WalkCoreSchema().walk
|
||||
|
||||
|
||||
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
|
||||
"""Recursively traverse a CoreSchema.
|
||||
|
||||
Args:
|
||||
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
|
||||
f (Walk): A function to apply. This function takes two arguments:
|
||||
1. The current CoreSchema that is being processed
|
||||
(not the same one you passed into this function, one level down).
|
||||
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
|
||||
to pass data down the recursive calls without using globals or other mutable state.
|
||||
|
||||
Returns:
|
||||
core_schema.CoreSchema: A processed CoreSchema.
|
||||
"""
|
||||
return f(schema.copy(), _dispatch)
|
||||
|
||||
|
||||
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
|
||||
definitions: dict[str, core_schema.CoreSchema] = {}
|
||||
ref_counts: dict[str, int] = defaultdict(int)
|
||||
involved_in_recursion: dict[str, bool] = {}
|
||||
current_recursion_ref_count: dict[str, int] = defaultdict(int)
|
||||
|
||||
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definitions':
|
||||
for definition in s['definitions']:
|
||||
ref = get_ref(definition)
|
||||
assert ref is not None
|
||||
if ref not in definitions:
|
||||
definitions[ref] = definition
|
||||
recurse(definition, collect_refs)
|
||||
return recurse(s['schema'], collect_refs)
|
||||
else:
|
||||
ref = get_ref(s)
|
||||
if ref is not None:
|
||||
new = recurse(s, collect_refs)
|
||||
new_ref = get_ref(new)
|
||||
if new_ref:
|
||||
definitions[new_ref] = new
|
||||
return core_schema.definition_reference_schema(schema_ref=ref)
|
||||
else:
|
||||
return recurse(s, collect_refs)
|
||||
|
||||
schema = walk_core_schema(schema, collect_refs)
|
||||
|
||||
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] != 'definition-ref':
|
||||
return recurse(s, count_refs)
|
||||
ref = s['schema_ref']
|
||||
ref_counts[ref] += 1
|
||||
|
||||
if ref_counts[ref] >= 2:
|
||||
# If this model is involved in a recursion this should be detected
|
||||
# on its second encounter, we can safely stop the walk here.
|
||||
if current_recursion_ref_count[ref] != 0:
|
||||
involved_in_recursion[ref] = True
|
||||
return s
|
||||
|
||||
current_recursion_ref_count[ref] += 1
|
||||
recurse(definitions[ref], count_refs)
|
||||
current_recursion_ref_count[ref] -= 1
|
||||
return s
|
||||
|
||||
schema = walk_core_schema(schema, count_refs)
|
||||
|
||||
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
|
||||
|
||||
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
|
||||
if ref_counts[ref] > 1:
|
||||
return False
|
||||
if involved_in_recursion.get(ref, False):
|
||||
return False
|
||||
if 'serialization' in s:
|
||||
return False
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
for k in (
|
||||
'pydantic_js_functions',
|
||||
'pydantic_js_annotation_functions',
|
||||
'pydantic.internal.union_discriminator',
|
||||
):
|
||||
if k in metadata:
|
||||
# we need to keep this as a ref
|
||||
return False
|
||||
return True
|
||||
|
||||
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
|
||||
if s['type'] == 'definition-ref':
|
||||
ref = s['schema_ref']
|
||||
# Check if the reference is only used once, not involved in recursion and does not have
|
||||
# any extra keys (like 'serialization')
|
||||
if can_be_inlined(s, ref):
|
||||
# Inline the reference by replacing the reference with the actual schema
|
||||
new = definitions.pop(ref)
|
||||
ref_counts[ref] -= 1 # because we just replaced it!
|
||||
# put all other keys that were on the def-ref schema into the inlined version
|
||||
# in particular this is needed for `serialization`
|
||||
if 'serialization' in s:
|
||||
new['serialization'] = s['serialization']
|
||||
s = recurse(new, inline_refs)
|
||||
return s
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
else:
|
||||
return recurse(s, inline_refs)
|
||||
|
||||
schema = walk_core_schema(schema, inline_refs)
|
||||
|
||||
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
|
||||
|
||||
if def_values:
|
||||
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
|
||||
return schema
|
||||
|
||||
|
||||
def _clean_schema_for_pretty_print(obj: Any, strip_metadata: bool = True) -> Any: # pragma: no cover
|
||||
"""A utility function to remove irrelevant information from a core schema."""
|
||||
if isinstance(obj, Mapping):
|
||||
new_dct = {}
|
||||
for k, v in obj.items():
|
||||
if k == 'metadata' and strip_metadata:
|
||||
new_metadata = {}
|
||||
|
||||
for meta_k, meta_v in v.items():
|
||||
if meta_k in ('pydantic_js_functions', 'pydantic_js_annotation_functions'):
|
||||
new_metadata['js_metadata'] = '<stripped>'
|
||||
else:
|
||||
new_metadata[meta_k] = _clean_schema_for_pretty_print(meta_v, strip_metadata=strip_metadata)
|
||||
|
||||
if list(new_metadata.keys()) == ['js_metadata']:
|
||||
new_metadata = {'<stripped>'}
|
||||
|
||||
new_dct[k] = new_metadata
|
||||
# Remove some defaults:
|
||||
elif k in ('custom_init', 'root_model') and not v:
|
||||
continue
|
||||
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
|
||||
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
|
||||
s = s.copy()
|
||||
s.pop('metadata', None)
|
||||
if s['type'] == 'model-fields':
|
||||
s = s.copy()
|
||||
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
|
||||
for field_name, field_schema in s['fields'].items():
|
||||
field_schema.pop('metadata', None)
|
||||
s['fields'][field_name] = field_schema
|
||||
computed_fields = s.get('computed_fields', None)
|
||||
if computed_fields:
|
||||
s['computed_fields'] = [cf.copy() for cf in computed_fields]
|
||||
for cf in computed_fields:
|
||||
cf.pop('metadata', None)
|
||||
else:
|
||||
new_dct[k] = _clean_schema_for_pretty_print(v, strip_metadata=strip_metadata)
|
||||
s.pop('computed_fields', None)
|
||||
elif s['type'] == 'model':
|
||||
# remove some defaults
|
||||
if s.get('custom_init', True) is False:
|
||||
s.pop('custom_init')
|
||||
if s.get('root_model', True) is False:
|
||||
s.pop('root_model')
|
||||
if {'title'}.issuperset(s.get('config', {}).keys()):
|
||||
s.pop('config', None)
|
||||
|
||||
return new_dct
|
||||
elif isinstance(obj, Sequence) and not isinstance(obj, str):
|
||||
return [_clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) for v in obj]
|
||||
else:
|
||||
return obj
|
||||
return recurse(s, strip_metadata)
|
||||
|
||||
return walk_core_schema(schema, strip_metadata)
|
||||
|
||||
|
||||
def pretty_print_core_schema(
|
||||
val: Any,
|
||||
*,
|
||||
console: Console | None = None,
|
||||
max_depth: int | None = None,
|
||||
strip_metadata: bool = True,
|
||||
) -> None: # pragma: no cover
|
||||
"""Pretty-print a core schema using the `rich` library.
|
||||
schema: CoreSchema,
|
||||
include_metadata: bool = False,
|
||||
) -> None:
|
||||
"""Pretty print a CoreSchema using rich.
|
||||
This is intended for debugging purposes.
|
||||
|
||||
Args:
|
||||
val: The core schema to print, or a Pydantic model/dataclass/type adapter
|
||||
(in which case the cached core schema is fetched and printed).
|
||||
console: A rich console to use when printing. Defaults to the global rich console instance.
|
||||
max_depth: The number of nesting levels which may be printed.
|
||||
strip_metadata: Whether to strip metadata in the output. If `True` any known core metadata
|
||||
attributes will be stripped (but custom attributes are kept). Defaults to `True`.
|
||||
schema: The CoreSchema to print.
|
||||
include_metadata: Whether to include metadata in the output. Defaults to `False`.
|
||||
"""
|
||||
# lazy import:
|
||||
from rich.pretty import pprint
|
||||
from rich import print # type: ignore # install it manually in your dev env
|
||||
|
||||
# circ. imports:
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
if not include_metadata:
|
||||
schema = _strip_metadata(schema)
|
||||
|
||||
if (inspect.isclass(val) and issubclass(val, BaseModel)) or is_pydantic_dataclass(val):
|
||||
val = val.__pydantic_core_schema__
|
||||
if isinstance(val, TypeAdapter):
|
||||
val = val.core_schema
|
||||
cleaned_schema = _clean_schema_for_pretty_print(val, strip_metadata=strip_metadata)
|
||||
|
||||
pprint(cleaned_schema, console=console, max_depth=max_depth)
|
||||
return print(schema)
|
||||
|
||||
|
||||
pps = pretty_print_core_schema
|
||||
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
|
||||
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
|
||||
return schema
|
||||
return _validate_core_schema(schema)
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
"""Private logic for creating pydantic dataclasses."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import typing
|
||||
import warnings
|
||||
from functools import partial, wraps
|
||||
from typing import Any, ClassVar
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import Any, Callable, ClassVar
|
||||
|
||||
from pydantic_core import (
|
||||
ArgsKwargs,
|
||||
PydanticUndefined,
|
||||
SchemaSerializer,
|
||||
SchemaValidator,
|
||||
core_schema,
|
||||
@@ -17,22 +19,28 @@ from pydantic_core import (
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
|
||||
from ..fields import FieldInfo
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from . import _config, _decorators
|
||||
from . import _config, _decorators, _discriminated_union, _typing_extra
|
||||
from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema
|
||||
from ._fields import collect_dataclass_fields
|
||||
from ._generate_schema import GenerateSchema, InvalidSchemaError
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._generics import get_standard_typevars_map
|
||||
from ._mock_val_ser import set_dataclass_mocks
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._signature import generate_pydantic_signature
|
||||
from ._utils import LazyClassAttribute
|
||||
from ._mock_val_ser import set_dataclass_mock_validator
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._utils import is_valid_identifier
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance as StandardDataclass
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..fields import FieldInfo
|
||||
|
||||
class StandardDataclass(typing.Protocol):
|
||||
__dataclass_fields__: ClassVar[dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
class PydanticDataclass(StandardDataclass, typing.Protocol):
|
||||
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
|
||||
@@ -53,10 +61,7 @@ if typing.TYPE_CHECKING:
|
||||
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
|
||||
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
|
||||
__pydantic_serializer__: ClassVar[SchemaSerializer]
|
||||
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
|
||||
|
||||
@classmethod
|
||||
def __pydantic_fields_complete__(cls) -> bool: ...
|
||||
__pydantic_validator__: ClassVar[SchemaValidator]
|
||||
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
@@ -64,22 +69,15 @@ else:
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
|
||||
def set_dataclass_fields(
|
||||
cls: type[StandardDataclass],
|
||||
ns_resolver: NsResolver | None = None,
|
||||
config_wrapper: _config.ConfigWrapper | None = None,
|
||||
) -> None:
|
||||
def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
|
||||
"""Collect and set `cls.__pydantic_fields__`.
|
||||
|
||||
Args:
|
||||
cls: The class.
|
||||
ns_resolver: Namespace resolver to use when getting dataclass annotations.
|
||||
config_wrapper: The config wrapper instance, defaults to `None`.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
"""
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
fields = collect_dataclass_fields(
|
||||
cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper
|
||||
)
|
||||
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
|
||||
|
||||
cls.__pydantic_fields__ = fields # type: ignore
|
||||
|
||||
@@ -89,8 +87,7 @@ def complete_dataclass(
|
||||
config_wrapper: _config.ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
_force_build: bool = False,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
"""Finish building a pydantic dataclass.
|
||||
|
||||
@@ -102,10 +99,7 @@ def complete_dataclass(
|
||||
cls: The class.
|
||||
config_wrapper: The config wrapper instance.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
ns_resolver: The namespace resolver instance to use when collecting dataclass fields
|
||||
and during schema building.
|
||||
_force_build: Whether to force building the dataclass, no matter if
|
||||
[`defer_build`][pydantic.config.ConfigDict.defer_build] is set.
|
||||
types_namespace: The types namespace.
|
||||
|
||||
Returns:
|
||||
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
|
||||
@@ -113,94 +107,136 @@ def complete_dataclass(
|
||||
Raises:
|
||||
PydanticUndefinedAnnotation: If `raise_error` is `True` and there is an undefined annotations.
|
||||
"""
|
||||
original_init = cls.__init__
|
||||
if hasattr(cls, '__post_init_post_parse__'):
|
||||
warnings.warn(
|
||||
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
|
||||
)
|
||||
|
||||
if types_namespace is None:
|
||||
types_namespace = _typing_extra.get_cls_types_namespace(cls)
|
||||
|
||||
set_dataclass_fields(cls, types_namespace)
|
||||
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
types_namespace,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
|
||||
|
||||
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied,
|
||||
# and so that the mock validator is used if building was deferred:
|
||||
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
|
||||
__tracebackhide__ = True
|
||||
s = __dataclass_self__
|
||||
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
|
||||
|
||||
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
|
||||
|
||||
sig = generate_dataclass_signature(cls)
|
||||
cls.__init__ = __init__ # type: ignore
|
||||
cls.__signature__ = sig # type: ignore
|
||||
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
|
||||
|
||||
set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper)
|
||||
|
||||
if not _force_build and config_wrapper.defer_build:
|
||||
set_dataclass_mocks(cls)
|
||||
return False
|
||||
|
||||
if hasattr(cls, '__post_init_post_parse__'):
|
||||
warnings.warn(
|
||||
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
|
||||
)
|
||||
|
||||
typevars_map = get_standard_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
ns_resolver=ns_resolver,
|
||||
typevars_map=typevars_map,
|
||||
)
|
||||
|
||||
# set __signature__ attr only for the class, but not for its instances
|
||||
# (because instances can define `__call__`, and `inspect.signature` shouldn't
|
||||
# use the `__signature__` attribute and instead generate from `__call__`).
|
||||
cls.__signature__ = LazyClassAttribute(
|
||||
'__signature__',
|
||||
partial(
|
||||
generate_pydantic_signature,
|
||||
# It's important that we reference the `original_init` here,
|
||||
# as it is the one synthesized by the stdlib `dataclass` module:
|
||||
init=original_init,
|
||||
fields=cls.__pydantic_fields__, # type: ignore
|
||||
validate_by_name=config_wrapper.validate_by_name,
|
||||
extra=config_wrapper.extra,
|
||||
is_dataclass=True,
|
||||
),
|
||||
)
|
||||
|
||||
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
|
||||
try:
|
||||
schema = gen_schema.generate_schema(cls)
|
||||
if get_core_schema:
|
||||
schema = get_core_schema(
|
||||
cls,
|
||||
CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
),
|
||||
)
|
||||
else:
|
||||
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_dataclass_mocks(cls, f'`{e.name}`')
|
||||
set_dataclass_mock_validator(cls, cls.__name__, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(title=cls.__name__)
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except InvalidSchemaError:
|
||||
set_dataclass_mocks(cls)
|
||||
schema = gen_schema.collect_definitions(schema)
|
||||
if collect_invalid_schemas(schema):
|
||||
set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types')
|
||||
return False
|
||||
|
||||
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
|
||||
|
||||
# We are about to set all the remaining required properties expected for this cast;
|
||||
# __pydantic_decorators__ and __pydantic_fields__ should already be set
|
||||
cls = typing.cast('type[PydanticDataclass]', cls)
|
||||
# debug(schema)
|
||||
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
cls.__pydantic_core_schema__ = schema = validate_core_schema(schema)
|
||||
cls.__pydantic_validator__ = validator = create_schema_validator(
|
||||
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
|
||||
schema, core_config, config_wrapper.plugin_settings
|
||||
)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
|
||||
if config_wrapper.validate_assignment:
|
||||
|
||||
@wraps(cls.__setattr__)
|
||||
def validated_setattr(instance: Any, field: str, value: str, /) -> None:
|
||||
validator.validate_assignment(instance, field, value)
|
||||
def validated_setattr(instance: Any, __field: str, __value: str) -> None:
|
||||
validator.validate_assignment(instance, __field, __value)
|
||||
|
||||
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
|
||||
|
||||
cls.__pydantic_complete__ = True
|
||||
return True
|
||||
|
||||
|
||||
def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature:
|
||||
"""Generate signature for a pydantic dataclass.
|
||||
|
||||
This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses.
|
||||
If we change this eventually, we should make this function's logic more closely mirror that from
|
||||
`pydantic._internal._model_construction.generate_model_signature`.
|
||||
|
||||
Args:
|
||||
cls: The dataclass.
|
||||
|
||||
Returns:
|
||||
The signature.
|
||||
"""
|
||||
sig = signature(cls)
|
||||
final_params: dict[str, Parameter] = {}
|
||||
|
||||
for param in sig.parameters.values():
|
||||
param_default = param.default
|
||||
if isinstance(param_default, FieldInfo):
|
||||
annotation = param.annotation
|
||||
# Replace the annotation if appropriate
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if annotation == 'Any':
|
||||
annotation = Any
|
||||
|
||||
# Replace the field name with the alias if present
|
||||
name = param.name
|
||||
alias = param_default.alias
|
||||
validation_alias = param_default.validation_alias
|
||||
if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias):
|
||||
name = alias
|
||||
elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias):
|
||||
name = validation_alias
|
||||
|
||||
# Replace the field default
|
||||
default = param_default.default
|
||||
if default is PydanticUndefined:
|
||||
if param_default.default_factory is PydanticUndefined:
|
||||
default = inspect.Signature.empty
|
||||
else:
|
||||
# this is used by dataclasses to indicate a factory exists:
|
||||
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
|
||||
|
||||
param = param.replace(annotation=annotation, name=name, default=default)
|
||||
final_params[param.name] = param
|
||||
|
||||
return Signature(parameters=list(final_params.values()), return_annotation=None)
|
||||
|
||||
|
||||
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
|
||||
|
||||
@@ -209,7 +245,7 @@ def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
|
||||
- `_cls` does not have any annotations that are not dataclass fields
|
||||
e.g.
|
||||
```python
|
||||
```py
|
||||
import dataclasses
|
||||
|
||||
import pydantic.dataclasses
|
||||
|
||||
@@ -1,30 +1,31 @@
|
||||
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import types
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, partial, partialmethod
|
||||
from functools import partial, partialmethod
|
||||
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
|
||||
|
||||
from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema
|
||||
from typing_extensions import TypeAlias, is_typeddict
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from typing_extensions import Literal, TypeAlias, is_typeddict
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from ..fields import ComputedFieldInfo
|
||||
from ._core_utils import get_type_ref
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._namespace_utils import GlobalsNamespace, MappingNamespace
|
||||
from ._typing_extra import get_function_type_hints
|
||||
from ._utils import can_be_positional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..fields import ComputedFieldInfo
|
||||
from ..functional_validators import FieldValidatorModes
|
||||
|
||||
try:
|
||||
from functools import cached_property # type: ignore
|
||||
except ImportError:
|
||||
# python 3.7
|
||||
cached_property = None
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
class ValidatorDecoratorInfo:
|
||||
@@ -60,9 +61,6 @@ class FieldValidatorDecoratorInfo:
|
||||
fields: A tuple of field names the validator should be called on.
|
||||
mode: The proposed validator mode.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate
|
||||
the appropriate JSON Schema (in validation mode) and can only specified
|
||||
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
|
||||
"""
|
||||
|
||||
decorator_repr: ClassVar[str] = '@field_validator'
|
||||
@@ -70,7 +68,6 @@ class FieldValidatorDecoratorInfo:
|
||||
fields: tuple[str, ...]
|
||||
mode: FieldValidatorModes
|
||||
check_fields: bool | None
|
||||
json_schema_input_type: Any
|
||||
|
||||
|
||||
@dataclass(**slots_true)
|
||||
@@ -135,7 +132,7 @@ class ModelValidatorDecoratorInfo:
|
||||
while building the pydantic-core schema.
|
||||
|
||||
Attributes:
|
||||
decorator_repr: A class variable representing the decorator string, '@model_validator'.
|
||||
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
|
||||
mode: The proposed serializer mode.
|
||||
"""
|
||||
|
||||
@@ -143,7 +140,7 @@ class ModelValidatorDecoratorInfo:
|
||||
mode: Literal['wrap', 'before', 'after']
|
||||
|
||||
|
||||
DecoratorInfo: TypeAlias = """Union[
|
||||
DecoratorInfo = Union[
|
||||
ValidatorDecoratorInfo,
|
||||
FieldValidatorDecoratorInfo,
|
||||
RootValidatorDecoratorInfo,
|
||||
@@ -151,7 +148,7 @@ DecoratorInfo: TypeAlias = """Union[
|
||||
ModelSerializerDecoratorInfo,
|
||||
ModelValidatorDecoratorInfo,
|
||||
ComputedFieldInfo,
|
||||
]"""
|
||||
]
|
||||
|
||||
ReturnType = TypeVar('ReturnType')
|
||||
DecoratedType: TypeAlias = (
|
||||
@@ -186,12 +183,6 @@ class PydanticDescriptorProxy(Generic[ReturnType]):
|
||||
|
||||
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
|
||||
self.wrapped = getattr(self.wrapped, name)(func)
|
||||
if isinstance(self.wrapped, property):
|
||||
# update ComputedFieldInfo.wrapped_property
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
if isinstance(self.decorator_info, ComputedFieldInfo):
|
||||
self.decorator_info.wrapped_property = self.wrapped
|
||||
return self
|
||||
|
||||
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
|
||||
@@ -203,11 +194,11 @@ class PydanticDescriptorProxy(Generic[ReturnType]):
|
||||
|
||||
def __set_name__(self, instance: Any, name: str) -> None:
|
||||
if hasattr(self.wrapped, '__set_name__'):
|
||||
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
|
||||
self.wrapped.__set_name__(instance, name)
|
||||
|
||||
def __getattr__(self, name: str, /) -> Any:
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
"""Forward checks for __isabstractmethod__ and such."""
|
||||
return getattr(self.wrapped, name)
|
||||
return getattr(self.wrapped, __name)
|
||||
|
||||
|
||||
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
|
||||
@@ -497,8 +488,6 @@ class DecoratorInfos:
|
||||
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
|
||||
)
|
||||
else:
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
isinstance(var_value, ComputedFieldInfo)
|
||||
res.computed_fields[var_name] = Decorator.build(
|
||||
model_dc, cls_var_name=var_name, shim=None, info=info
|
||||
@@ -509,7 +498,7 @@ class DecoratorInfos:
|
||||
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
|
||||
# and replace them with the thing they are wrapping (see the other setattr call below)
|
||||
# which allows validator class methods to also function as regular class methods
|
||||
model_dc.__pydantic_decorators__ = res
|
||||
setattr(model_dc, '__pydantic_decorators__', res)
|
||||
for name, value in to_replace:
|
||||
setattr(model_dc, name, value)
|
||||
return res
|
||||
@@ -529,11 +518,12 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
|
||||
"""
|
||||
try:
|
||||
sig = signature(validator)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present:
|
||||
except ValueError:
|
||||
# builtins and some C extensions don't have signatures
|
||||
# assume that they don't take an info argument and only take a single argument
|
||||
# e.g. `str.strip` or `datetime.datetime`
|
||||
return False
|
||||
n_positional = count_positional_required_params(sig)
|
||||
n_positional = count_positional_params(sig)
|
||||
if mode == 'wrap':
|
||||
if n_positional == 3:
|
||||
return True
|
||||
@@ -552,7 +542,9 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes)
|
||||
)
|
||||
|
||||
|
||||
def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]:
|
||||
def inspect_field_serializer(
|
||||
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
|
||||
) -> tuple[bool, bool]:
|
||||
"""Look at a field serializer function and determine if it is a field serializer,
|
||||
and whether it takes an info argument.
|
||||
|
||||
@@ -561,21 +553,18 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
|
||||
Args:
|
||||
serializer: The serializer function to inspect.
|
||||
mode: The serializer mode, either 'plain' or 'wrap'.
|
||||
computed_field: When serializer is applied on computed_field. It doesn't require
|
||||
info signature.
|
||||
|
||||
Returns:
|
||||
Tuple of (is_field_serializer, info_arg).
|
||||
"""
|
||||
try:
|
||||
sig = signature(serializer)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present and this is not a method:
|
||||
return (False, False)
|
||||
sig = signature(serializer)
|
||||
|
||||
first = next(iter(sig.parameters.values()), None)
|
||||
is_field_serializer = first is not None and first.name == 'self'
|
||||
|
||||
n_positional = count_positional_required_params(sig)
|
||||
n_positional = count_positional_params(sig)
|
||||
if is_field_serializer:
|
||||
# -1 to correct for self parameter
|
||||
info_arg = _serializer_info_arg(mode, n_positional - 1)
|
||||
@@ -587,8 +576,13 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai
|
||||
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
code='field-serializer-signature',
|
||||
)
|
||||
if info_arg and computed_field:
|
||||
raise PydanticUserError(
|
||||
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
|
||||
)
|
||||
|
||||
return is_field_serializer, info_arg
|
||||
else:
|
||||
return is_field_serializer, info_arg
|
||||
|
||||
|
||||
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
|
||||
@@ -603,13 +597,8 @@ def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['
|
||||
Returns:
|
||||
info_arg
|
||||
"""
|
||||
try:
|
||||
sig = signature(serializer)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no info argument is present:
|
||||
return False
|
||||
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
|
||||
sig = signature(serializer)
|
||||
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
@@ -637,7 +626,7 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
|
||||
)
|
||||
|
||||
sig = signature(serializer)
|
||||
info_arg = _serializer_info_arg(mode, count_positional_required_params(sig))
|
||||
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
|
||||
if info_arg is None:
|
||||
raise PydanticUserError(
|
||||
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
|
||||
@@ -650,18 +639,18 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai
|
||||
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
|
||||
if mode == 'plain':
|
||||
if n_positional == 1:
|
||||
# (input_value: Any, /) -> Any
|
||||
# (__input_value: Any) -> Any
|
||||
return False
|
||||
elif n_positional == 2:
|
||||
# (model: Any, input_value: Any, /) -> Any
|
||||
# (__model: Any, __input_value: Any) -> Any
|
||||
return True
|
||||
else:
|
||||
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
|
||||
if n_positional == 2:
|
||||
# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any
|
||||
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
|
||||
return False
|
||||
elif n_positional == 3:
|
||||
# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any
|
||||
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
|
||||
return True
|
||||
|
||||
return None
|
||||
@@ -722,25 +711,34 @@ def unwrap_wrapped_function(
|
||||
unwrap_class_static_method: bool = True,
|
||||
) -> Any:
|
||||
"""Recursively unwraps a wrapped function until the underlying function is reached.
|
||||
This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod.
|
||||
This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod.
|
||||
|
||||
Args:
|
||||
func: The function to unwrap.
|
||||
unwrap_partial: If True (default), unwrap partial and partialmethod decorators.
|
||||
unwrap_partial: If True (default), unwrap partial and partialmethod decorators, otherwise don't.
|
||||
decorators.
|
||||
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
|
||||
decorators. If False, only unwrap partial and partialmethod decorators.
|
||||
|
||||
Returns:
|
||||
The underlying function of the wrapped function.
|
||||
"""
|
||||
# Define the types we want to check against as a single tuple.
|
||||
unwrap_types = (
|
||||
(property, cached_property)
|
||||
+ ((partial, partialmethod) if unwrap_partial else ())
|
||||
+ ((staticmethod, classmethod) if unwrap_class_static_method else ())
|
||||
)
|
||||
all: set[Any] = {property}
|
||||
|
||||
while isinstance(func, unwrap_types):
|
||||
if unwrap_partial:
|
||||
all.update({partial, partialmethod})
|
||||
|
||||
try:
|
||||
from functools import cached_property # type: ignore
|
||||
except ImportError:
|
||||
cached_property = type('', (), {})
|
||||
else:
|
||||
all.add(cached_property)
|
||||
|
||||
if unwrap_class_static_method:
|
||||
all.update({staticmethod, classmethod})
|
||||
|
||||
while isinstance(func, tuple(all)):
|
||||
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
|
||||
func = func.__func__
|
||||
elif isinstance(func, (partial, partialmethod)):
|
||||
@@ -755,72 +753,38 @@ def unwrap_wrapped_function(
|
||||
return func
|
||||
|
||||
|
||||
_function_like = (
|
||||
partial,
|
||||
partialmethod,
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
types.MethodType,
|
||||
types.WrapperDescriptorType,
|
||||
types.MethodWrapperType,
|
||||
types.MemberDescriptorType,
|
||||
)
|
||||
def get_function_return_type(
|
||||
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Get the function return type.
|
||||
|
||||
|
||||
def get_callable_return_type(
|
||||
callable_obj: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any | PydanticUndefinedType:
|
||||
"""Get the callable return type.
|
||||
It gets the return type from the type annotation if `explicit_return_type` is `None`.
|
||||
Otherwise, it returns `explicit_return_type`.
|
||||
|
||||
Args:
|
||||
callable_obj: The callable to analyze.
|
||||
globalns: The globals namespace to use during type annotation evaluation.
|
||||
localns: The locals namespace to use during type annotation evaluation.
|
||||
func: The function to get its return type.
|
||||
explicit_return_type: The explicit return type.
|
||||
types_namespace: The types namespace, defaults to `None`.
|
||||
|
||||
Returns:
|
||||
The function return type.
|
||||
"""
|
||||
if isinstance(callable_obj, type):
|
||||
# types are callables, and we assume the return type
|
||||
# is the type itself (e.g. `int()` results in an instance of `int`).
|
||||
return callable_obj
|
||||
|
||||
if not isinstance(callable_obj, _function_like):
|
||||
call_func = getattr(type(callable_obj), '__call__', None) # noqa: B004
|
||||
if call_func is not None:
|
||||
callable_obj = call_func
|
||||
|
||||
hints = get_function_type_hints(
|
||||
unwrap_wrapped_function(callable_obj),
|
||||
include_keys={'return'},
|
||||
globalns=globalns,
|
||||
localns=localns,
|
||||
)
|
||||
return hints.get('return', PydanticUndefined)
|
||||
if explicit_return_type is PydanticUndefined:
|
||||
# try to get it from the type annotation
|
||||
hints = get_function_type_hints(
|
||||
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
|
||||
)
|
||||
return hints.get('return', PydanticUndefined)
|
||||
else:
|
||||
return explicit_return_type
|
||||
|
||||
|
||||
def count_positional_required_params(sig: Signature) -> int:
|
||||
"""Get the number of positional (required) arguments of a signature.
|
||||
def count_positional_params(sig: Signature) -> int:
|
||||
return sum(1 for param in sig.parameters.values() if can_be_positional(param))
|
||||
|
||||
This function should only be used to inspect signatures of validation and serialization functions.
|
||||
The first argument (the value being serialized or validated) is counted as a required argument
|
||||
even if a default value exists.
|
||||
|
||||
Returns:
|
||||
The number of positional arguments of a signature.
|
||||
"""
|
||||
parameters = list(sig.parameters.values())
|
||||
return sum(
|
||||
1
|
||||
for param in parameters
|
||||
if can_be_positional(param)
|
||||
# First argument is the value being validated/serialized, and can have a default value
|
||||
# (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg
|
||||
# for instance) should be required, and thus without any default value.
|
||||
and (param.default is Parameter.empty or param is parameters[0])
|
||||
)
|
||||
def can_be_positional(param: Parameter) -> bool:
|
||||
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
||||
|
||||
|
||||
def ensure_property(f: Any) -> Any:
|
||||
|
||||
@@ -1,45 +1,49 @@
|
||||
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from inspect import Parameter, signature
|
||||
from typing import Any, Union, cast
|
||||
from typing import Any, Dict, Tuple, Union, cast
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from ._utils import can_be_positional
|
||||
from ._decorators import can_be_positional
|
||||
|
||||
|
||||
class V1OnlyValueValidator(Protocol):
|
||||
"""A simple validator, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any) -> Any: ...
|
||||
def __call__(self, __value: Any) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class V1ValidatorWithValues(Protocol):
|
||||
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ...
|
||||
def __call__(self, __value: Any, values: dict[str, Any]) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class V1ValidatorWithValuesKwOnly(Protocol):
|
||||
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ...
|
||||
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class V1ValidatorWithKwargs(Protocol):
|
||||
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, **kwargs: Any) -> Any: ...
|
||||
def __call__(self, __value: Any, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class V1ValidatorWithValuesAndKwargs(Protocol):
|
||||
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
|
||||
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
|
||||
V1Validator = Union[
|
||||
@@ -105,21 +109,23 @@ def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithI
|
||||
return wrapper2
|
||||
|
||||
|
||||
RootValidatorValues = dict[str, Any]
|
||||
RootValidatorValues = Dict[str, Any]
|
||||
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
|
||||
RootValidatorFieldsTuple = tuple[Any, ...]
|
||||
RootValidatorFieldsTuple = Tuple[Any, ...]
|
||||
|
||||
|
||||
class V1RootValidatorFunction(Protocol):
|
||||
"""A simple root validator, supported for V1 validators and V2 validators."""
|
||||
|
||||
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ...
|
||||
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues:
|
||||
...
|
||||
|
||||
|
||||
class V2CoreBeforeRootValidator(Protocol):
|
||||
"""V2 validator with mode='before'."""
|
||||
|
||||
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ...
|
||||
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues:
|
||||
...
|
||||
|
||||
|
||||
class V2CoreAfterRootValidator(Protocol):
|
||||
@@ -127,7 +133,8 @@ class V2CoreAfterRootValidator(Protocol):
|
||||
|
||||
def __call__(
|
||||
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
|
||||
) -> RootValidatorFieldsTuple: ...
|
||||
) -> RootValidatorFieldsTuple:
|
||||
...
|
||||
|
||||
|
||||
def make_v1_generic_root_validator(
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Hashable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import Any, Hashable, Sequence
|
||||
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from ..errors import PydanticUserError
|
||||
from . import _core_utils
|
||||
from ._core_utils import (
|
||||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
|
||||
CoreSchemaField,
|
||||
collect_definitions,
|
||||
simplify_schema_references,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..types import Discriminator
|
||||
from ._core_metadata import CoreMetadata
|
||||
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
|
||||
|
||||
|
||||
class MissingDefinitionForUnionRef(Exception):
|
||||
@@ -26,15 +26,39 @@ class MissingDefinitionForUnionRef(Exception):
|
||||
super().__init__(f'Missing definition for ref {self.ref!r}')
|
||||
|
||||
|
||||
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
|
||||
metadata = cast('CoreMetadata', schema.setdefault('metadata', {}))
|
||||
metadata['pydantic_internal_union_discriminator'] = discriminator
|
||||
def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
|
||||
schema.setdefault('metadata', {})
|
||||
metadata = schema.get('metadata')
|
||||
assert metadata is not None
|
||||
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
|
||||
|
||||
|
||||
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
definitions: dict[str, CoreSchema] | None = None
|
||||
|
||||
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
|
||||
nonlocal definitions
|
||||
if 'metadata' in s:
|
||||
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
|
||||
return s
|
||||
|
||||
s = recurse(s, inner)
|
||||
if s['type'] == 'tagged-union':
|
||||
return s
|
||||
|
||||
metadata = s.get('metadata', {})
|
||||
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
|
||||
if discriminator is not None:
|
||||
if definitions is None:
|
||||
definitions = collect_definitions(schema)
|
||||
s = apply_discriminator(s, discriminator, definitions)
|
||||
return s
|
||||
|
||||
return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
|
||||
|
||||
|
||||
def apply_discriminator(
|
||||
schema: core_schema.CoreSchema,
|
||||
discriminator: str | Discriminator,
|
||||
definitions: dict[str, core_schema.CoreSchema] | None = None,
|
||||
schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None
|
||||
) -> core_schema.CoreSchema:
|
||||
"""Applies the discriminator and returns a new core schema.
|
||||
|
||||
@@ -59,14 +83,6 @@ def apply_discriminator(
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
from ..types import Discriminator
|
||||
|
||||
if isinstance(discriminator, Discriminator):
|
||||
if isinstance(discriminator.discriminator, str):
|
||||
discriminator = discriminator.discriminator
|
||||
else:
|
||||
return discriminator._convert_schema(schema)
|
||||
|
||||
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
|
||||
|
||||
|
||||
@@ -134,7 +150,7 @@ class _ApplyInferredDiscriminator:
|
||||
# in the output TaggedUnionSchema that will replace the union from the input schema
|
||||
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
|
||||
|
||||
# `_used` is changed to True after applying the discriminator to prevent accidental reuse
|
||||
# `_used` is changed to True after applying the discriminator to prevent accidental re-use
|
||||
self._used = False
|
||||
|
||||
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
@@ -160,11 +176,16 @@ class _ApplyInferredDiscriminator:
|
||||
- If discriminator fields have different aliases.
|
||||
- If discriminator field not of type `Literal`.
|
||||
"""
|
||||
self.definitions.update(collect_definitions(schema))
|
||||
assert not self._used
|
||||
schema = self._apply_to_root(schema)
|
||||
if self._should_be_nullable and not self._is_nullable:
|
||||
schema = core_schema.nullable_schema(schema)
|
||||
self._used = True
|
||||
new_defs = collect_definitions(schema)
|
||||
missing_defs = self.definitions.keys() - new_defs.keys()
|
||||
if missing_defs:
|
||||
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
|
||||
return schema
|
||||
|
||||
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
@@ -234,10 +255,6 @@ class _ApplyInferredDiscriminator:
|
||||
* Validating that each allowed discriminator value maps to a unique choice
|
||||
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
|
||||
"""
|
||||
if choice['type'] == 'definition-ref':
|
||||
if choice['schema_ref'] not in self.definitions:
|
||||
raise MissingDefinitionForUnionRef(choice['schema_ref'])
|
||||
|
||||
if choice['type'] == 'none':
|
||||
self._should_be_nullable = True
|
||||
elif choice['type'] == 'definitions':
|
||||
@@ -249,6 +266,10 @@ class _ApplyInferredDiscriminator:
|
||||
# Reverse the choices list before extending the stack so that they get handled in the order they occur
|
||||
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
|
||||
self._choices_to_handle.extend(choices_schemas)
|
||||
elif choice['type'] == 'definition-ref':
|
||||
if choice['schema_ref'] not in self.definitions:
|
||||
raise MissingDefinitionForUnionRef(choice['schema_ref'])
|
||||
self._handle_choice(self.definitions[choice['schema_ref']])
|
||||
elif choice['type'] not in {
|
||||
'model',
|
||||
'typed-dict',
|
||||
@@ -256,16 +277,12 @@ class _ApplyInferredDiscriminator:
|
||||
'lax-or-strict',
|
||||
'dataclass',
|
||||
'dataclass-args',
|
||||
'definition-ref',
|
||||
} and not _core_utils.is_function_with_inner_schema(choice):
|
||||
# We should eventually handle 'definition-ref' as well
|
||||
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
|
||||
if choice['type'] == 'list':
|
||||
err_str += (
|
||||
' If you are making use of a list of union types, make sure the discriminator is applied to the '
|
||||
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
|
||||
)
|
||||
raise TypeError(err_str)
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
else:
|
||||
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
|
||||
# In this case, this inner tagged-union is compatible with the outer tagged-union,
|
||||
@@ -299,10 +316,13 @@ class _ApplyInferredDiscriminator:
|
||||
"""
|
||||
if choice['type'] == 'definitions':
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
|
||||
|
||||
elif choice['type'] == 'function-plain':
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
elif _core_utils.is_function_with_inner_schema(choice):
|
||||
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
|
||||
|
||||
elif choice['type'] == 'lax-or-strict':
|
||||
return sorted(
|
||||
set(
|
||||
@@ -353,13 +373,10 @@ class _ApplyInferredDiscriminator:
|
||||
raise MissingDefinitionForUnionRef(schema_ref)
|
||||
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
|
||||
else:
|
||||
err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.'
|
||||
if choice['type'] == 'list':
|
||||
err_str += (
|
||||
' If you are making use of a list of union types, make sure the discriminator is applied to the '
|
||||
'union type and not the list (e.g. `list[Annotated[<T> | <U>, Field(discriminator=...)]]`).'
|
||||
)
|
||||
raise TypeError(err_str)
|
||||
raise TypeError(
|
||||
f'{choice["type"]!r} is not a valid discriminated union variant;'
|
||||
' should be a `BaseModel` or `dataclass`'
|
||||
)
|
||||
|
||||
def _infer_discriminator_values_for_typed_dict_choice(
|
||||
self, choice: core_schema.TypedDictSchema, source_name: str | None = None
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""Utilities related to attribute docstring extraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DocstringVisitor(ast.NodeVisitor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.target: str | None = None
|
||||
self.attrs: dict[str, str] = {}
|
||||
self.previous_node_type: type[ast.AST] | None = None
|
||||
|
||||
def visit(self, node: ast.AST) -> Any:
|
||||
node_result = super().visit(node)
|
||||
self.previous_node_type = type(node)
|
||||
return node_result
|
||||
|
||||
def visit_AnnAssign(self, node: ast.AnnAssign) -> Any:
|
||||
if isinstance(node.target, ast.Name):
|
||||
self.target = node.target.id
|
||||
|
||||
def visit_Expr(self, node: ast.Expr) -> Any:
|
||||
if (
|
||||
isinstance(node.value, ast.Constant)
|
||||
and isinstance(node.value.value, str)
|
||||
and self.previous_node_type is ast.AnnAssign
|
||||
):
|
||||
docstring = inspect.cleandoc(node.value.value)
|
||||
if self.target:
|
||||
self.attrs[self.target] = docstring
|
||||
self.target = None
|
||||
|
||||
|
||||
def _dedent_source_lines(source: list[str]) -> str:
|
||||
# Required for nested class definitions, e.g. in a function block
|
||||
dedent_source = textwrap.dedent(''.join(source))
|
||||
if dedent_source.startswith((' ', '\t')):
|
||||
# We are in the case where there's a dedented (usually multiline) string
|
||||
# at a lower indentation level than the class itself. We wrap our class
|
||||
# in a function as a workaround.
|
||||
dedent_source = f'def dedent_workaround():\n{dedent_source}'
|
||||
return dedent_source
|
||||
|
||||
|
||||
def _extract_source_from_frame(cls: type[Any]) -> list[str] | None:
|
||||
frame = inspect.currentframe()
|
||||
|
||||
while frame:
|
||||
if inspect.getmodule(frame) is inspect.getmodule(cls):
|
||||
lnum = frame.f_lineno
|
||||
try:
|
||||
lines, _ = inspect.findsource(frame)
|
||||
except OSError: # pragma: no cover
|
||||
# Source can't be retrieved (maybe because running in an interactive terminal),
|
||||
# we don't want to error here.
|
||||
pass
|
||||
else:
|
||||
block_lines = inspect.getblock(lines[lnum - 1 :])
|
||||
dedent_source = _dedent_source_lines(block_lines)
|
||||
try:
|
||||
block_tree = ast.parse(dedent_source)
|
||||
except SyntaxError:
|
||||
pass
|
||||
else:
|
||||
stmt = block_tree.body[0]
|
||||
if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround':
|
||||
# `_dedent_source_lines` wrapped the class around the workaround function
|
||||
stmt = stmt.body[0]
|
||||
if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__:
|
||||
return block_lines
|
||||
|
||||
frame = frame.f_back
|
||||
|
||||
|
||||
def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]:
|
||||
"""Map model attributes and their corresponding docstring.
|
||||
|
||||
Args:
|
||||
cls: The class of the Pydantic model to inspect.
|
||||
use_inspect: Whether to skip usage of frames to find the object and use
|
||||
the `inspect` module instead.
|
||||
|
||||
Returns:
|
||||
A mapping containing attribute names and their corresponding docstring.
|
||||
"""
|
||||
if use_inspect:
|
||||
# Might not work as expected if two classes have the same name in the same source file.
|
||||
try:
|
||||
source, _ = inspect.getsourcelines(cls)
|
||||
except OSError: # pragma: no cover
|
||||
return {}
|
||||
else:
|
||||
source = _extract_source_from_frame(cls)
|
||||
|
||||
if not source:
|
||||
return {}
|
||||
|
||||
dedent_source = _dedent_source_lines(source)
|
||||
|
||||
visitor = DocstringVisitor()
|
||||
visitor.visit(ast.parse(dedent_source))
|
||||
return visitor.attrs
|
||||
@@ -1,104 +1,92 @@
|
||||
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
from inspect import Parameter, ismethoddescriptor, signature
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from annotated_types import BaseMetadata
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeIs, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import AnnotationSource
|
||||
|
||||
from pydantic import PydanticDeprecatedSince211
|
||||
from pydantic.errors import PydanticUserError
|
||||
|
||||
from . import _generics, _typing_extra
|
||||
from . import _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
from ._docs_extraction import extract_docstrings_from_cls
|
||||
from ._import_utils import import_cached_base_model, import_cached_field_info
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._repr import Representation
|
||||
from ._utils import can_be_positional
|
||||
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from annotated_types import BaseMetadata
|
||||
|
||||
from ..fields import FieldInfo
|
||||
from ..main import BaseModel
|
||||
from ._dataclasses import PydanticDataclass, StandardDataclass
|
||||
from ._dataclasses import StandardDataclass
|
||||
from ._decorators import DecoratorInfos
|
||||
|
||||
|
||||
def get_type_hints_infer_globalns(
|
||||
obj: Any,
|
||||
localns: dict[str, Any] | None = None,
|
||||
include_extras: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Gets type hints for an object by inferring the global namespace.
|
||||
|
||||
It uses the `typing.get_type_hints`, The only thing that we do here is fetching
|
||||
global namespace from `obj.__module__` if it is not `None`.
|
||||
|
||||
Args:
|
||||
obj: The object to get its type hints.
|
||||
localns: The local namespaces.
|
||||
include_extras: Whether to recursively include annotation metadata.
|
||||
|
||||
Returns:
|
||||
The object type hints.
|
||||
"""
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
globalns: dict[str, Any] | None = None
|
||||
if module_name:
|
||||
try:
|
||||
globalns = sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
|
||||
|
||||
|
||||
class PydanticMetadata(Representation):
|
||||
"""Base class for annotation markers like `Strict`."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
|
||||
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
|
||||
class PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
|
||||
"""Pydantic general metada like `max_digits`."""
|
||||
|
||||
Args:
|
||||
**metadata: The metadata to add.
|
||||
|
||||
Returns:
|
||||
The new `_PydanticGeneralMetadata` class.
|
||||
"""
|
||||
return _general_metadata_cls()(metadata) # type: ignore
|
||||
|
||||
|
||||
@cache
|
||||
def _general_metadata_cls() -> type[BaseMetadata]:
|
||||
"""Do it this way to avoid importing `annotated_types` at import time."""
|
||||
from annotated_types import BaseMetadata
|
||||
|
||||
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
|
||||
"""Pydantic general metadata like `max_digits`."""
|
||||
|
||||
def __init__(self, metadata: Any):
|
||||
self.__dict__ = metadata
|
||||
|
||||
return _PydanticGeneralMetadata # type: ignore
|
||||
|
||||
|
||||
def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], use_inspect: bool = False) -> None:
|
||||
fields_docs = extract_docstrings_from_cls(cls, use_inspect=use_inspect)
|
||||
for ann_name, field_info in fields.items():
|
||||
if field_info.description is None and ann_name in fields_docs:
|
||||
field_info.description = fields_docs[ann_name]
|
||||
def __init__(self, **metadata: Any):
|
||||
self.__dict__ = metadata
|
||||
|
||||
|
||||
def collect_model_fields( # noqa: C901
|
||||
cls: type[BaseModel],
|
||||
bases: tuple[type[Any], ...],
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver | None,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
*,
|
||||
typevars_map: Mapping[TypeVar, Any] | None = None,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
) -> tuple[dict[str, FieldInfo], set[str]]:
|
||||
"""Collect the fields and class variables names of a nascent Pydantic model.
|
||||
"""Collect the fields of a nascent pydantic model.
|
||||
|
||||
The fields collection process is *lenient*, meaning it won't error if string annotations
|
||||
fail to evaluate. If this happens, the original annotation (and assigned value, if any)
|
||||
is stored on the created `FieldInfo` instance.
|
||||
Also collect the names of any ClassVars present in the type hints.
|
||||
|
||||
The `rebuild_model_fields()` should be called at a later point (e.g. when rebuilding the model),
|
||||
and will make use of these stored attributes.
|
||||
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
ns_resolver: Namespace resolver to use when getting model annotations.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
|
||||
Returns:
|
||||
A two-tuple containing model fields and class variables names.
|
||||
A tuple contains fields and class variables.
|
||||
|
||||
Raises:
|
||||
NameError:
|
||||
@@ -106,16 +94,9 @@ def collect_model_fields( # noqa: C901
|
||||
- If there is a field other than `root` in `RootModel`.
|
||||
- If a field shadows an attribute in the parent model.
|
||||
"""
|
||||
BaseModel = import_cached_base_model()
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
from ..fields import FieldInfo
|
||||
|
||||
bases = cls.__bases__
|
||||
parent_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for base in reversed(bases):
|
||||
if model_fields := getattr(base, '__pydantic_fields__', None):
|
||||
parent_fields_lookup.update(model_fields)
|
||||
|
||||
type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)
|
||||
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
|
||||
|
||||
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
|
||||
# annotations is only used for finding fields in parent classes
|
||||
@@ -123,50 +104,39 @@ def collect_model_fields( # noqa: C901
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
|
||||
class_vars: set[str] = set()
|
||||
for ann_name, (ann_type, evaluated) in type_hints.items():
|
||||
for ann_name, ann_type in type_hints.items():
|
||||
if ann_name == 'model_config':
|
||||
# We never want to treat `model_config` as a field
|
||||
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
|
||||
# protected namespaces (where `model_config` might be allowed as a field name)
|
||||
continue
|
||||
|
||||
for protected_namespace in config_wrapper.protected_namespaces:
|
||||
ns_violation: bool = False
|
||||
if isinstance(protected_namespace, Pattern):
|
||||
ns_violation = protected_namespace.match(ann_name) is not None
|
||||
elif isinstance(protected_namespace, str):
|
||||
ns_violation = ann_name.startswith(protected_namespace)
|
||||
|
||||
if ns_violation:
|
||||
if ann_name.startswith(protected_namespace):
|
||||
for b in bases:
|
||||
if hasattr(b, ann_name):
|
||||
if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
|
||||
from ..main import BaseModel
|
||||
|
||||
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
|
||||
raise NameError(
|
||||
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
|
||||
f' of protected namespace "{protected_namespace}".'
|
||||
)
|
||||
else:
|
||||
valid_namespaces = ()
|
||||
for pn in config_wrapper.protected_namespaces:
|
||||
if isinstance(pn, Pattern):
|
||||
if not pn.match(ann_name):
|
||||
valid_namespaces += (f're.compile({pn.pattern})',)
|
||||
else:
|
||||
if not ann_name.startswith(pn):
|
||||
valid_namespaces += (pn,)
|
||||
|
||||
valid_namespaces = tuple(
|
||||
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
|
||||
)
|
||||
warnings.warn(
|
||||
f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
|
||||
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
|
||||
'\n\nYou may be able to resolve this warning by setting'
|
||||
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
|
||||
UserWarning,
|
||||
)
|
||||
if _typing_extra.is_classvar_annotation(ann_type):
|
||||
if is_classvar(ann_type):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
|
||||
assigned_value = getattr(cls, ann_name, PydanticUndefined)
|
||||
|
||||
if not is_valid_field_name(ann_name):
|
||||
continue
|
||||
if cls.__pydantic_root_model__ and ann_name != 'root':
|
||||
@@ -175,7 +145,7 @@ def collect_model_fields( # noqa: C901
|
||||
)
|
||||
|
||||
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
|
||||
# "... shadows an attribute" warnings
|
||||
# "... shadows an attribute" errors
|
||||
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
|
||||
for base in bases:
|
||||
dataclass_fields = {
|
||||
@@ -183,77 +153,42 @@ def collect_model_fields( # noqa: C901
|
||||
}
|
||||
if hasattr(base, ann_name):
|
||||
if base is generic_origin:
|
||||
# Don't warn when "shadowing" of attributes in parametrized generics
|
||||
# Don't error when "shadowing" of attributes in parametrized generics
|
||||
continue
|
||||
|
||||
if ann_name in dataclass_fields:
|
||||
# Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
|
||||
# on the class instance.
|
||||
continue
|
||||
|
||||
if ann_name not in annotations:
|
||||
# Don't warn when a field exists in a parent class but has not been defined in the current class
|
||||
continue
|
||||
|
||||
warnings.warn(
|
||||
f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
|
||||
f'"{base.__qualname__}"',
|
||||
f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ',
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if assigned_value is PydanticUndefined: # no assignment, just a plain annotation
|
||||
if ann_name in annotations or ann_name not in parent_fields_lookup:
|
||||
# field is either:
|
||||
# - present in the current model's annotations (and *not* from parent classes)
|
||||
# - not found on any base classes; this seems to be caused by fields bot getting
|
||||
# generated due to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a `FieldInfo` for this type hint, so we do this.
|
||||
field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS)
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
# Store the original annotation that should be used to rebuild
|
||||
# the field info later:
|
||||
field_info._original_annotation = ann_type
|
||||
try:
|
||||
default = getattr(cls, ann_name, PydanticUndefined)
|
||||
if default is PydanticUndefined:
|
||||
raise AttributeError
|
||||
except AttributeError:
|
||||
if ann_name in annotations:
|
||||
field_info = FieldInfo.from_annotation(ann_type)
|
||||
else:
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = copy(parent_fields_lookup[ann_name])
|
||||
|
||||
else: # An assigned value is present (either the default value, or a `Field()` function)
|
||||
_warn_on_nested_alias_in_annotation(ann_type, ann_name)
|
||||
if isinstance(assigned_value, FieldInfo_) and ismethoddescriptor(assigned_value.default):
|
||||
# `assigned_value` was fetched using `getattr`, which triggers a call to `__get__`
|
||||
# for descriptors, so we do the same if the `= field(default=...)` form is used.
|
||||
# Note that we only do this for method descriptors for now, we might want to
|
||||
# extend this to any descriptor in the future (by simply checking for
|
||||
# `hasattr(assigned_value.default, '__get__')`).
|
||||
assigned_value.default = assigned_value.default.__get__(None, cls)
|
||||
|
||||
# The `from_annotated_attribute()` call below mutates the assigned `Field()`, so make a copy:
|
||||
original_assignment = (
|
||||
assigned_value._copy() if not evaluated and isinstance(assigned_value, FieldInfo_) else assigned_value
|
||||
)
|
||||
|
||||
field_info = FieldInfo_.from_annotated_attribute(ann_type, assigned_value, _source=AnnotationSource.CLASS)
|
||||
# Store the original annotation and assignment value that should be used to rebuild the field info later.
|
||||
# Note that the assignment is always stored as the annotation might contain a type var that is later
|
||||
# parameterized with an unknown forward reference (and we'll need it to rebuild the field info):
|
||||
field_info._original_assignment = original_assignment
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
field_info._original_annotation = ann_type
|
||||
elif 'final' in field_info._qualifiers and not field_info.is_required():
|
||||
warnings.warn(
|
||||
f'Annotation {ann_name!r} is marked as final and has a default value. Pydantic treats {ann_name!r} as a '
|
||||
'class variable, but it will be considered as a normal field in V3 to be aligned with dataclasses. If you '
|
||||
f'still want {ann_name!r} to be considered as a class variable, annotate it as: `ClassVar[<type>] = <default>.`',
|
||||
category=PydanticDeprecatedSince211,
|
||||
# Incorrect when `create_model` is used, but the chance that final with a default is used is low in that case:
|
||||
stacklevel=4,
|
||||
)
|
||||
class_vars.add(ann_name)
|
||||
continue
|
||||
|
||||
# if field has no default value and is not in __annotations__ this means that it is
|
||||
# defined in a base class and we can take it from there
|
||||
model_fields_lookup: dict[str, FieldInfo] = {}
|
||||
for x in cls.__bases__[::-1]:
|
||||
model_fields_lookup.update(getattr(x, 'model_fields', {}))
|
||||
if ann_name in model_fields_lookup:
|
||||
# The field was present on one of the (possibly multiple) base classes
|
||||
# copy the field to make sure typevar substitutions don't cause issues with the base classes
|
||||
field_info = copy(model_fields_lookup[ann_name])
|
||||
else:
|
||||
# The field was not found on any base classes; this seems to be caused by fields not getting
|
||||
# generated thanks to models not being fully defined while initializing recursive models.
|
||||
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
|
||||
field_info = FieldInfo.from_annotation(ann_type)
|
||||
else:
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, default)
|
||||
# attributes which are fields are removed from the class namespace:
|
||||
# 1. To match the behaviour of annotation-only fields
|
||||
# 2. To avoid false positives in the NameError check above
|
||||
@@ -266,250 +201,81 @@ def collect_model_fields( # noqa: C901
|
||||
# to make sure the decorators have already been built for this exact class
|
||||
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
|
||||
if ann_name in decorators.computed_fields:
|
||||
raise TypeError(
|
||||
f'Field {ann_name!r} of class {cls.__name__!r} overrides symbol of same name in a parent class. '
|
||||
'This override with a computed_field is incompatible.'
|
||||
)
|
||||
raise ValueError("you can't override a field with a computed field")
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
if field._complete:
|
||||
field.apply_typevars_map(typevars_map)
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
|
||||
if config_wrapper.use_attribute_docstrings:
|
||||
_update_fields_from_docstrings(cls, fields)
|
||||
return fields, class_vars
|
||||
|
||||
|
||||
def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
|
||||
FieldInfo = import_cached_field_info()
|
||||
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
|
||||
from ..fields import FieldInfo
|
||||
|
||||
args = getattr(ann_type, '__args__', None)
|
||||
if args:
|
||||
for anno_arg in args:
|
||||
if typing_objects.is_annotated(get_origin(anno_arg)):
|
||||
for anno_type_arg in _typing_extra.get_args(anno_arg):
|
||||
if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
|
||||
warnings.warn(
|
||||
f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
|
||||
UserWarning,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def rebuild_model_fields(
|
||||
cls: type[BaseModel],
|
||||
*,
|
||||
ns_resolver: NsResolver,
|
||||
typevars_map: Mapping[TypeVar, Any],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Rebuild the (already present) model fields by trying to reevaluate annotations.
|
||||
|
||||
This function should be called whenever a model with incomplete fields is encountered.
|
||||
|
||||
Raises:
|
||||
NameError: If one of the annotations failed to evaluate.
|
||||
|
||||
Note:
|
||||
This function *doesn't* mutate the model fields in place, as it can be called during
|
||||
schema generation, where you don't want to mutate other model's fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
rebuilt_fields: dict[str, FieldInfo] = {}
|
||||
with ns_resolver.push(cls):
|
||||
for f_name, field_info in cls.__pydantic_fields__.items():
|
||||
if field_info._complete:
|
||||
rebuilt_fields[f_name] = field_info
|
||||
else:
|
||||
existing_desc = field_info.description
|
||||
ann = _typing_extra.eval_type(
|
||||
field_info._original_annotation,
|
||||
*ns_resolver.types_namespace,
|
||||
)
|
||||
ann = _generics.replace_types(ann, typevars_map)
|
||||
|
||||
if (assign := field_info._original_assignment) is PydanticUndefined:
|
||||
new_field = FieldInfo_.from_annotation(ann, _source=AnnotationSource.CLASS)
|
||||
else:
|
||||
new_field = FieldInfo_.from_annotated_attribute(ann, assign, _source=AnnotationSource.CLASS)
|
||||
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
|
||||
new_field.description = new_field.description if new_field.description is not None else existing_desc
|
||||
rebuilt_fields[f_name] = new_field
|
||||
|
||||
return rebuilt_fields
|
||||
if not is_finalvar(type_):
|
||||
return False
|
||||
elif val is PydanticUndefined:
|
||||
return False
|
||||
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def collect_dataclass_fields(
|
||||
cls: type[StandardDataclass],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
typevars_map: dict[Any, Any] | None = None,
|
||||
config_wrapper: ConfigWrapper | None = None,
|
||||
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Collect the fields of a dataclass.
|
||||
|
||||
Args:
|
||||
cls: dataclass.
|
||||
ns_resolver: Namespace resolver to use when getting dataclass annotations.
|
||||
Defaults to an empty instance.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
typevars_map: A dictionary mapping type variables to their concrete types.
|
||||
config_wrapper: The config wrapper instance.
|
||||
|
||||
Returns:
|
||||
The dataclass fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
from ..fields import FieldInfo
|
||||
|
||||
fields: dict[str, FieldInfo] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
dataclass_fields = cls.__dataclass_fields__
|
||||
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
|
||||
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
|
||||
|
||||
# The logic here is similar to `_typing_extra.get_cls_type_hints`,
|
||||
# although we do it manually as stdlib dataclasses already have annotations
|
||||
# collected in each class:
|
||||
for base in reversed(cls.__mro__):
|
||||
if not dataclasses.is_dataclass(base):
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
|
||||
if is_classvar(ann_type):
|
||||
continue
|
||||
|
||||
with ns_resolver.push(base):
|
||||
for ann_name, dataclass_field in dataclass_fields.items():
|
||||
if ann_name not in base.__dict__.get('__annotations__', {}):
|
||||
# `__dataclass_fields__`contains every field, even the ones from base classes.
|
||||
# Only collect the ones defined on `base`.
|
||||
continue
|
||||
if not dataclass_field.init and dataclass_field.default_factory == dataclasses.MISSING:
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
ann_type, evaluated = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
|
||||
if isinstance(dataclass_field.default, FieldInfo):
|
||||
if dataclass_field.default.init_var:
|
||||
# TODO: same note as above
|
||||
continue
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
|
||||
else:
|
||||
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if _typing_extra.is_classvar_annotation(ann_type):
|
||||
continue
|
||||
|
||||
if (
|
||||
not dataclass_field.init
|
||||
and dataclass_field.default is dataclasses.MISSING
|
||||
and dataclass_field.default_factory is dataclasses.MISSING
|
||||
):
|
||||
# TODO: We should probably do something with this so that validate_assignment behaves properly
|
||||
# Issue: https://github.com/pydantic/pydantic/issues/5470
|
||||
continue
|
||||
|
||||
if isinstance(dataclass_field.default, FieldInfo_):
|
||||
if dataclass_field.default.init_var:
|
||||
if dataclass_field.default.init is False:
|
||||
raise PydanticUserError(
|
||||
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
|
||||
code='clashing-init-and-init-var',
|
||||
)
|
||||
|
||||
# TODO: same note as above re validate_assignment
|
||||
continue
|
||||
field_info = FieldInfo_.from_annotated_attribute(
|
||||
ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS
|
||||
)
|
||||
field_info._original_assignment = dataclass_field.default
|
||||
else:
|
||||
field_info = FieldInfo_.from_annotated_attribute(
|
||||
ann_type, dataclass_field, _source=AnnotationSource.DATACLASS
|
||||
)
|
||||
field_info._original_assignment = dataclass_field
|
||||
|
||||
if not evaluated:
|
||||
field_info._complete = False
|
||||
field_info._original_annotation = ann_type
|
||||
|
||||
fields[ann_name] = field_info
|
||||
|
||||
if field_info.default is not PydanticUndefined and isinstance(
|
||||
getattr(cls, ann_name, field_info), FieldInfo_
|
||||
):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
|
||||
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
|
||||
setattr(cls, ann_name, field_info.default)
|
||||
|
||||
if typevars_map:
|
||||
for field in fields.values():
|
||||
# We don't pass any ns, as `field.annotation`
|
||||
# was already evaluated. TODO: is this method relevant?
|
||||
# Can't we juste use `_generics.replace_types`?
|
||||
field.apply_typevars_map(typevars_map)
|
||||
|
||||
if config_wrapper is not None and config_wrapper.use_attribute_docstrings:
|
||||
_update_fields_from_docstrings(
|
||||
cls,
|
||||
fields,
|
||||
# We can't rely on the (more reliable) frame inspection method
|
||||
# for stdlib dataclasses:
|
||||
use_inspect=not hasattr(cls, '__is_pydantic_dataclass__'),
|
||||
)
|
||||
field.apply_typevars_map(typevars_map, types_namespace)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def rebuild_dataclass_fields(
|
||||
cls: type[PydanticDataclass],
|
||||
*,
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver,
|
||||
typevars_map: Mapping[TypeVar, Any],
|
||||
) -> dict[str, FieldInfo]:
|
||||
"""Rebuild the (already present) dataclass fields by trying to reevaluate annotations.
|
||||
|
||||
This function should be called whenever a dataclass with incomplete fields is encountered.
|
||||
|
||||
Raises:
|
||||
NameError: If one of the annotations failed to evaluate.
|
||||
|
||||
Note:
|
||||
This function *doesn't* mutate the dataclass fields in place, as it can be called during
|
||||
schema generation, where you don't want to mutate other dataclass's fields.
|
||||
"""
|
||||
FieldInfo_ = import_cached_field_info()
|
||||
|
||||
rebuilt_fields: dict[str, FieldInfo] = {}
|
||||
with ns_resolver.push(cls):
|
||||
for f_name, field_info in cls.__pydantic_fields__.items():
|
||||
if field_info._complete:
|
||||
rebuilt_fields[f_name] = field_info
|
||||
else:
|
||||
existing_desc = field_info.description
|
||||
ann = _typing_extra.eval_type(
|
||||
field_info._original_annotation,
|
||||
*ns_resolver.types_namespace,
|
||||
)
|
||||
ann = _generics.replace_types(ann, typevars_map)
|
||||
new_field = FieldInfo_.from_annotated_attribute(
|
||||
ann,
|
||||
field_info._original_assignment,
|
||||
_source=AnnotationSource.DATACLASS,
|
||||
)
|
||||
|
||||
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
|
||||
new_field.description = new_field.description if new_field.description is not None else existing_desc
|
||||
rebuilt_fields[f_name] = new_field
|
||||
|
||||
return rebuilt_fields
|
||||
|
||||
|
||||
def is_valid_field_name(name: str) -> bool:
|
||||
return not name.startswith('_')
|
||||
|
||||
|
||||
def is_valid_privateattr_name(name: str) -> bool:
|
||||
return name.startswith('_') and not name.startswith('__')
|
||||
|
||||
|
||||
def takes_validated_data_argument(
|
||||
default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
|
||||
) -> TypeIs[Callable[[dict[str, Any]], Any]]:
|
||||
"""Whether the provided default factory callable has a validated data parameter."""
|
||||
try:
|
||||
sig = signature(default_factory)
|
||||
except (ValueError, TypeError):
|
||||
# `inspect.signature` might not be able to infer a signature, e.g. with C objects.
|
||||
# In this case, we assume no data argument is present:
|
||||
return False
|
||||
|
||||
parameters = list(sig.parameters.values())
|
||||
|
||||
return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,9 +14,3 @@ class PydanticRecursiveRef:
|
||||
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
|
||||
this class as the result of resolving a standard ForwardRef.
|
||||
"""
|
||||
|
||||
def __or__(self, other):
|
||||
return Union[self, other] # type: ignore
|
||||
|
||||
def __ror__(self, other):
|
||||
return Union[other, self] # type: ignore
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,21 +4,17 @@ import sys
|
||||
import types
|
||||
import typing
|
||||
from collections import ChainMap
|
||||
from collections.abc import Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from itertools import zip_longest
|
||||
from types import prepare_class
|
||||
from typing import TYPE_CHECKING, Annotated, Any, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
import typing_extensions
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from . import _typing_extra
|
||||
from ._core_utils import get_type_ref
|
||||
from ._forward_ref import PydanticRecursiveRef
|
||||
from ._typing_extra import TypeVarType, typing_base
|
||||
from ._utils import all_identical, is_model_class
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
@@ -27,7 +23,7 @@ if sys.version_info >= (3, 10):
|
||||
if TYPE_CHECKING:
|
||||
from ..main import BaseModel
|
||||
|
||||
GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
|
||||
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
|
||||
|
||||
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
|
||||
# Right now, to handle recursive generics, we some types must remain cached for brief periods without references.
|
||||
@@ -38,25 +34,43 @@ GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]]
|
||||
KT = TypeVar('KT')
|
||||
VT = TypeVar('VT')
|
||||
_LIMITED_DICT_SIZE = 100
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class LimitedDict(dict, MutableMapping[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
...
|
||||
|
||||
class LimitedDict(dict[KT, VT]):
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None:
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
else:
|
||||
|
||||
def __setitem__(self, key: KT, value: VT, /) -> None:
|
||||
super().__setitem__(key, value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for k in to_remove:
|
||||
del self[k]
|
||||
class LimitedDict(dict):
|
||||
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
|
||||
|
||||
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
|
||||
"""
|
||||
|
||||
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
def __setitem__(self, __key: Any, __value: Any) -> None:
|
||||
super().__setitem__(__key, __value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for key in to_remove:
|
||||
del self[key]
|
||||
|
||||
def __class_getitem__(cls, *args: Any) -> Any:
|
||||
# to avoid errors with 3.7
|
||||
return cls
|
||||
|
||||
|
||||
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
|
||||
# once they are no longer referenced by the caller.
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
|
||||
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
|
||||
else:
|
||||
GenericTypesCache = WeakValueDictionary
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -94,13 +108,13 @@ else:
|
||||
# and discover later on that we need to re-add all this infrastructure...
|
||||
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
|
||||
|
||||
_GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None)
|
||||
_GENERIC_TYPES_CACHE = GenericTypesCache()
|
||||
|
||||
|
||||
class PydanticGenericMetadata(typing_extensions.TypedDict):
|
||||
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
|
||||
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
|
||||
parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__
|
||||
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
|
||||
|
||||
|
||||
def create_generic_submodel(
|
||||
@@ -157,7 +171,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
depth: The depth to get the frame.
|
||||
|
||||
Returns:
|
||||
A tuple contains `module_name` and `called_globally`.
|
||||
A tuple contains `module_nam` and `called_globally`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the function is not called inside a function.
|
||||
@@ -175,7 +189,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
|
||||
DictValues: type[Any] = {}.values().__class__
|
||||
|
||||
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
|
||||
|
||||
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
|
||||
@@ -208,7 +222,7 @@ def get_origin(v: Any) -> Any:
|
||||
return typing_extensions.get_origin(v)
|
||||
|
||||
|
||||
def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
|
||||
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
|
||||
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
|
||||
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
|
||||
"""
|
||||
@@ -221,11 +235,11 @@ def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
|
||||
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
|
||||
# So it is safe to access cls.__args__ and origin.__parameters__
|
||||
args: tuple[Any, ...] = cls.__args__ # type: ignore
|
||||
parameters: tuple[TypeVar, ...] = origin.__parameters__
|
||||
parameters: tuple[TypeVarType, ...] = origin.__parameters__
|
||||
return dict(zip(parameters, args))
|
||||
|
||||
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
|
||||
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
|
||||
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
|
||||
with the `replace_types` function.
|
||||
|
||||
@@ -237,13 +251,10 @@ def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]:
|
||||
generic_metadata = cls.__pydantic_generic_metadata__
|
||||
origin = generic_metadata['origin']
|
||||
args = generic_metadata['args']
|
||||
if not args:
|
||||
# No need to go into `iter_contained_typevars`:
|
||||
return {}
|
||||
return dict(zip(iter_contained_typevars(origin), args))
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
|
||||
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
|
||||
|
||||
Args:
|
||||
@@ -255,13 +266,13 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from typing import List, Union
|
||||
```py
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from pydantic._internal._generics import replace_types
|
||||
|
||||
replace_types(tuple[str, Union[List[str], float]], {str: int})
|
||||
#> tuple[int, Union[List[int], float]]
|
||||
replace_types(Tuple[str, Union[List[str], float]], {str: int})
|
||||
#> Tuple[int, Union[List[int], float]]
|
||||
```
|
||||
"""
|
||||
if not type_map:
|
||||
@@ -270,25 +281,25 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if typing_objects.is_annotated(origin_type):
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
annotated_type = replace_types(annotated_type, type_map)
|
||||
# TODO remove parentheses when we drop support for Python 3.10:
|
||||
return Annotated[(annotated_type, *annotations)]
|
||||
annotated = replace_types(annotated_type, type_map)
|
||||
for annotation in annotations:
|
||||
annotated = typing_extensions.Annotated[annotated, annotation]
|
||||
return annotated
|
||||
|
||||
# Having type args is a good indicator that this is a typing special form
|
||||
# instance or a generic alias of some sort.
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if type_args:
|
||||
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
# If all arguments are the same, there is no need to modify the
|
||||
# type or create a new object at all
|
||||
return type_
|
||||
|
||||
if (
|
||||
origin_type is not None
|
||||
and isinstance(type_, _typing_extra.typing_base)
|
||||
and not isinstance(origin_type, _typing_extra.typing_base)
|
||||
and isinstance(type_, typing_base)
|
||||
and not isinstance(origin_type, typing_base)
|
||||
and getattr(type_, '_name', None) is not None
|
||||
):
|
||||
# In python < 3.9 generic aliases don't exist so any of these like `list`,
|
||||
@@ -296,24 +307,11 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
# See: https://www.python.org/dev/peps/pep-0585
|
||||
origin_type = getattr(typing, type_._name)
|
||||
assert origin_type is not None
|
||||
|
||||
if is_union_origin(origin_type):
|
||||
if any(typing_objects.is_any(arg) for arg in resolved_type_args):
|
||||
# `Any | T` ~ `Any`:
|
||||
resolved_type_args = (Any,)
|
||||
# `Never | T` ~ `T`:
|
||||
resolved_type_args = tuple(
|
||||
arg
|
||||
for arg in resolved_type_args
|
||||
if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg))
|
||||
)
|
||||
|
||||
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
|
||||
# We also cannot use isinstance() since we have to compare types.
|
||||
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
|
||||
return _UnionGenericAlias(origin_type, resolved_type_args)
|
||||
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
|
||||
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
|
||||
return origin_type[resolved_type_args]
|
||||
|
||||
# We handle pydantic generic models separately as they don't have the same
|
||||
# semantics as "typing" classes or generic aliases
|
||||
@@ -329,8 +327,8 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, list):
|
||||
resolved_list = [replace_types(element, type_map) for element in type_]
|
||||
if isinstance(type_, (List, list)):
|
||||
resolved_list = list(replace_types(element, type_map) for element in type_)
|
||||
if all_identical(type_, resolved_list):
|
||||
return type_
|
||||
return resolved_list
|
||||
@@ -340,57 +338,49 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any:
|
||||
return type_map.get(type_, type_)
|
||||
|
||||
|
||||
def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]:
|
||||
"""Return a mapping between the parameters of a generic model and the provided arguments during parameterization.
|
||||
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
|
||||
"""Checks if the type, or any of its arbitrary nested args, satisfy
|
||||
`isinstance(<type>, isinstance_target)`.
|
||||
"""
|
||||
if isinstance(type_, isinstance_target):
|
||||
return True
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is typing_extensions.Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return has_instance_in_type(annotated_type, isinstance_target)
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
|
||||
return True
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
|
||||
if any(has_instance_in_type(element, isinstance_target) for element in type_):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
|
||||
"""Check the generic model parameters count is equal.
|
||||
|
||||
Args:
|
||||
cls: The generic model.
|
||||
parameters: A tuple of passed parameters to the generic model.
|
||||
|
||||
Raises:
|
||||
TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments).
|
||||
|
||||
Example:
|
||||
```python {test="skip" lint="skip"}
|
||||
class Model[T, U, V = int](BaseModel): ...
|
||||
|
||||
map_generic_model_arguments(Model, (str, bytes))
|
||||
#> {T: str, U: bytes, V: int}
|
||||
|
||||
map_generic_model_arguments(Model, (str,))
|
||||
#> TypeError: Too few arguments for <class '__main__.Model'>; actual 1, expected at least 2
|
||||
|
||||
map_generic_model_arguments(Model, (str, bytes, int, complex))
|
||||
#> TypeError: Too many arguments for <class '__main__.Model'>; actual 4, expected 3
|
||||
```
|
||||
|
||||
Note:
|
||||
This function is analogous to the private `typing._check_generic_specialization` function.
|
||||
TypeError: If the passed parameters count is not equal to generic model parameters count.
|
||||
"""
|
||||
parameters = cls.__pydantic_generic_metadata__['parameters']
|
||||
expected_len = len(parameters)
|
||||
typevars_map: dict[TypeVar, Any] = {}
|
||||
|
||||
_missing = object()
|
||||
for parameter, argument in zip_longest(parameters, args, fillvalue=_missing):
|
||||
if parameter is _missing:
|
||||
raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}')
|
||||
|
||||
if argument is _missing:
|
||||
param = typing.cast(TypeVar, parameter)
|
||||
try:
|
||||
has_default = param.has_default()
|
||||
except AttributeError:
|
||||
# Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13.
|
||||
has_default = False
|
||||
if has_default:
|
||||
# The default might refer to other type parameters. For an example, see:
|
||||
# https://typing.readthedocs.io/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics
|
||||
typevars_map[param] = replace_types(param.__default__, typevars_map)
|
||||
else:
|
||||
expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters)
|
||||
raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}')
|
||||
else:
|
||||
param = typing.cast(TypeVar, parameter)
|
||||
typevars_map[param] = argument
|
||||
|
||||
return typevars_map
|
||||
actual = len(parameters)
|
||||
expected = len(cls.__pydantic_generic_metadata__['parameters'])
|
||||
if actual != expected:
|
||||
description = 'many' if actual > expected else 'few'
|
||||
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
|
||||
|
||||
|
||||
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
|
||||
@@ -421,8 +411,7 @@ def generic_recursion_self_type(
|
||||
yield self_type
|
||||
else:
|
||||
previously_seen_type_refs.add(type_ref)
|
||||
yield
|
||||
previously_seen_type_refs.remove(type_ref)
|
||||
yield None
|
||||
finally:
|
||||
if token:
|
||||
_generic_recursion_cache.reset(token)
|
||||
@@ -453,24 +442,14 @@ def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any)
|
||||
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
|
||||
equal.
|
||||
"""
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if generic_types_cache is None:
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
return generic_types_cache.get(_early_cache_key(parent, typevar_values))
|
||||
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
|
||||
|
||||
|
||||
def get_cached_generic_type_late(
|
||||
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
|
||||
) -> type[BaseModel] | None:
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if (
|
||||
generic_types_cache is None
|
||||
): # pragma: no cover (early cache is guaranteed to run first and initialize the cache)
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values))
|
||||
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
|
||||
if cached is not None:
|
||||
set_cached_generic_type(parent, typevar_values, cached, origin, args)
|
||||
return cached
|
||||
@@ -486,17 +465,11 @@ def set_cached_generic_type(
|
||||
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
|
||||
two different keys.
|
||||
"""
|
||||
generic_types_cache = _GENERIC_TYPES_CACHE.get()
|
||||
if (
|
||||
generic_types_cache is None
|
||||
): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache)
|
||||
generic_types_cache = GenericTypesCache()
|
||||
_GENERIC_TYPES_CACHE.set(generic_types_cache)
|
||||
generic_types_cache[_early_cache_key(parent, typevar_values)] = type_
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
|
||||
if len(typevar_values) == 1:
|
||||
generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
|
||||
if origin and args:
|
||||
generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_
|
||||
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
|
||||
|
||||
|
||||
def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
@@ -517,7 +490,7 @@ def _union_orderings_key(typevar_values: Any) -> Any:
|
||||
for value in typevar_values:
|
||||
args_data.append(_union_orderings_key(value))
|
||||
return tuple(args_data)
|
||||
elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)):
|
||||
elif typing_extensions.get_origin(typevar_values) is typing.Union:
|
||||
return get_args(typevar_values)
|
||||
else:
|
||||
return ()
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def is_git_repo(dir: Path) -> bool:
|
||||
"""Is the given directory version-controlled with git?"""
|
||||
return dir.joinpath('.git').exists()
|
||||
|
||||
|
||||
def have_git() -> bool: # pragma: no cover
|
||||
"""Can we run the git executable?"""
|
||||
try:
|
||||
subprocess.check_output(['git', '--help'])
|
||||
return True
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def git_revision(dir: Path) -> str:
|
||||
"""Get the SHA-1 of the HEAD of a git repository."""
|
||||
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
|
||||
@@ -1,20 +0,0 @@
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
|
||||
@cache
|
||||
def import_cached_base_model() -> type['BaseModel']:
|
||||
from pydantic import BaseModel
|
||||
|
||||
return BaseModel
|
||||
|
||||
|
||||
@cache
|
||||
def import_cached_field_info() -> type['FieldInfo']:
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
return FieldInfo
|
||||
@@ -1,4 +1,7 @@
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
dataclass_kwargs: Dict[str, Any]
|
||||
|
||||
# `slots` is available on Python >= 3.10
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@@ -1,57 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from copy import copy
|
||||
from functools import lru_cache, partial
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable
|
||||
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python
|
||||
import annotated_types as at
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from ._fields import PydanticMetadata
|
||||
from ._import_utils import import_cached_field_info
|
||||
from . import _validators
|
||||
from ._fields import PydanticGeneralMetadata, PydanticMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from ..annotated_handlers import GetJsonSchemaHandler
|
||||
|
||||
|
||||
STRICT = {'strict'}
|
||||
FAIL_FAST = {'fail_fast'}
|
||||
LENGTH_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'}
|
||||
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY}
|
||||
ALLOW_INF_NAN = {'allow_inf_nan'}
|
||||
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}
|
||||
|
||||
STR_CONSTRAINTS = {
|
||||
*LENGTH_CONSTRAINTS,
|
||||
*STRICT,
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'pattern',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
|
||||
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
|
||||
LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST}
|
||||
DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT}
|
||||
SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST}
|
||||
LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
|
||||
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT}
|
||||
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
BOOL_CONSTRAINTS = STRICT
|
||||
UUID_CONSTRAINTS = STRICT
|
||||
|
||||
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
|
||||
LAX_OR_STRICT_CONSTRAINTS = STRICT
|
||||
ENUM_CONSTRAINTS = STRICT
|
||||
COMPLEX_CONSTRAINTS = STRICT
|
||||
|
||||
UNION_CONSTRAINTS = {'union_mode'}
|
||||
URL_CONSTRAINTS = {
|
||||
@@ -68,33 +53,54 @@ SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT
|
||||
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
|
||||
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
|
||||
for constraint in STR_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
|
||||
for constraint in BYTES_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
|
||||
for constraint in LIST_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
|
||||
for constraint in TUPLE_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
|
||||
for constraint in SET_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
|
||||
for constraint in DICT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
|
||||
for constraint in GENERATOR_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
|
||||
for constraint in FLOAT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
|
||||
for constraint in INT_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
|
||||
for constraint in DATE_TIME_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
|
||||
for constraint in TIMEDELTA_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
|
||||
for constraint in TIME_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
|
||||
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
|
||||
for constraint in UNION_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
|
||||
for constraint in URL_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
|
||||
for constraint in BOOL_CONSTRAINTS:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))
|
||||
|
||||
constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [
|
||||
(STR_CONSTRAINTS, TEXT_SCHEMA_TYPES),
|
||||
(BYTES_CONSTRAINTS, ('bytes',)),
|
||||
(LIST_CONSTRAINTS, ('list',)),
|
||||
(TUPLE_CONSTRAINTS, ('tuple',)),
|
||||
(SET_CONSTRAINTS, ('set', 'frozenset')),
|
||||
(DICT_CONSTRAINTS, ('dict',)),
|
||||
(GENERATOR_CONSTRAINTS, ('generator',)),
|
||||
(FLOAT_CONSTRAINTS, ('float',)),
|
||||
(INT_CONSTRAINTS, ('int',)),
|
||||
(DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')),
|
||||
# TODO: this is a bit redundant, we could probably avoid some of these
|
||||
(STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')),
|
||||
(UNION_CONSTRAINTS, ('union',)),
|
||||
(URL_CONSTRAINTS, ('url', 'multi-host-url')),
|
||||
(BOOL_CONSTRAINTS, ('bool',)),
|
||||
(UUID_CONSTRAINTS, ('uuid',)),
|
||||
(LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)),
|
||||
(ENUM_CONSTRAINTS, ('enum',)),
|
||||
(DECIMAL_CONSTRAINTS, ('decimal',)),
|
||||
(COMPLEX_CONSTRAINTS, ('complex',)),
|
||||
]
|
||||
|
||||
for constraints, schemas in constraint_schema_pairings:
|
||||
for c in constraints:
|
||||
CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas)
|
||||
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
|
||||
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
|
||||
js_schema = handler(s)
|
||||
js_schema.update(f())
|
||||
return js_schema
|
||||
|
||||
if 'metadata' in s:
|
||||
metadata = s['metadata']
|
||||
if 'pydantic_js_functions' in s:
|
||||
metadata['pydantic_js_functions'].append(update_js_schema)
|
||||
else:
|
||||
metadata['pydantic_js_functions'] = [update_js_schema]
|
||||
else:
|
||||
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
|
||||
|
||||
|
||||
def as_jsonable_value(v: Any) -> Any:
|
||||
@@ -113,7 +119,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
An iterable of expanded annotations.
|
||||
|
||||
Example:
|
||||
```python
|
||||
```py
|
||||
from annotated_types import Ge, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
|
||||
@@ -122,9 +128,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
#> [Ge(ge=4), MinLen(min_length=5)]
|
||||
```
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
from pydantic.fields import FieldInfo # circular import
|
||||
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, at.GroupedMetadata):
|
||||
@@ -143,28 +147,6 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
|
||||
yield annotation
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_at_to_constraint_map() -> dict[type, str]:
|
||||
"""Return a mapping of annotated types to constraints.
|
||||
|
||||
Normally, we would define a mapping like this in the module scope, but we can't do that
|
||||
because we don't permit module level imports of `annotated_types`, in an attempt to speed up
|
||||
the import time of `pydantic`. We still only want to have this dictionary defined in one place,
|
||||
so we use this function to cache the result.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
return {
|
||||
at.Gt: 'gt',
|
||||
at.Ge: 'ge',
|
||||
at.Lt: 'lt',
|
||||
at.Le: 'le',
|
||||
at.MultipleOf: 'multiple_of',
|
||||
at.MinLen: 'min_length',
|
||||
at.MaxLen: 'max_length',
|
||||
}
|
||||
|
||||
|
||||
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
|
||||
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
|
||||
Otherwise return `None`.
|
||||
@@ -184,37 +166,14 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
|
||||
Raises:
|
||||
PydanticCustomError: If `Predicate` fails.
|
||||
"""
|
||||
import annotated_types as at
|
||||
|
||||
from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check
|
||||
|
||||
schema = schema.copy()
|
||||
schema_update, other_metadata = collect_known_metadata([annotation])
|
||||
schema_type = schema['type']
|
||||
|
||||
chain_schema_constraints: set[str] = {
|
||||
'pattern',
|
||||
'strip_whitespace',
|
||||
'to_lower',
|
||||
'to_upper',
|
||||
'coerce_numbers_to_str',
|
||||
}
|
||||
chain_schema_steps: list[CoreSchema] = []
|
||||
|
||||
for constraint, value in schema_update.items():
|
||||
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
|
||||
raise ValueError(f'Unknown constraint {constraint}')
|
||||
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
|
||||
|
||||
# if it becomes necessary to handle more than one constraint
|
||||
# in this recursive case with function-after or function-wrap, we should refactor
|
||||
# this is a bit challenging because we sometimes want to apply constraints to the inner schema,
|
||||
# whereas other times we want to wrap the existing schema with a new one that enforces a new constraint.
|
||||
if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict':
|
||||
schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function schema
|
||||
return schema
|
||||
|
||||
# if we're allowed to apply constraint directly to the schema, like le to int, do that
|
||||
if schema_type in allowed_schemas:
|
||||
if constraint == 'union_mode' and schema_type == 'union':
|
||||
schema['mode'] = value # type: ignore # schema is UnionSchema
|
||||
@@ -222,109 +181,145 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No
|
||||
schema[constraint] = value
|
||||
continue
|
||||
|
||||
# else, apply a function after validator to the schema to enforce the corresponding constraint
|
||||
if constraint in chain_schema_constraints:
|
||||
|
||||
def _apply_constraint_with_incompatibility_info(
|
||||
value: Any, handler: cs.ValidatorFunctionWrapHandler
|
||||
) -> Any:
|
||||
try:
|
||||
x = handler(value)
|
||||
except ValidationError as ve:
|
||||
# if the error is about the type, it's likely that the constraint is incompatible the type of the field
|
||||
# for example, the following invalid schema wouldn't be caught during schema build, but rather at this point
|
||||
# with a cryptic 'string_type' error coming from the string validator,
|
||||
# that we'd rather express as a constraint incompatibility error (TypeError)
|
||||
# Annotated[list[int], Field(pattern='abc')]
|
||||
if 'type' in ve.errors()[0]['type']:
|
||||
raise TypeError(
|
||||
f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'" # noqa: B023
|
||||
)
|
||||
raise ve
|
||||
return x
|
||||
|
||||
chain_schema_steps.append(
|
||||
cs.no_info_wrap_validator_function(
|
||||
_apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value})
|
||||
)
|
||||
if constraint == 'allow_inf_nan' and value is False:
|
||||
return cs.no_info_after_validator_function(
|
||||
_validators.forbid_inf_nan_check,
|
||||
schema,
|
||||
)
|
||||
elif constraint in NUMERIC_VALIDATOR_LOOKUP:
|
||||
if constraint in LENGTH_CONSTRAINTS:
|
||||
inner_schema = schema
|
||||
while inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}:
|
||||
inner_schema = inner_schema['schema'] # type: ignore
|
||||
inner_schema_type = inner_schema['type']
|
||||
if inner_schema_type == 'list' or (
|
||||
inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list' # type: ignore
|
||||
):
|
||||
js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems'
|
||||
else:
|
||||
js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength'
|
||||
else:
|
||||
js_constraint_key = constraint
|
||||
|
||||
schema = cs.no_info_after_validator_function(
|
||||
partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema
|
||||
elif constraint == 'pattern':
|
||||
# insert a str schema to make sure the regex engine matches
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(pattern=value),
|
||||
]
|
||||
)
|
||||
metadata = schema.get('metadata', {})
|
||||
if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None:
|
||||
metadata['pydantic_js_updates'] = {
|
||||
**existing_json_schema_updates,
|
||||
**{js_constraint_key: as_jsonable_value(value)},
|
||||
}
|
||||
else:
|
||||
metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)}
|
||||
schema['metadata'] = metadata
|
||||
elif constraint == 'allow_inf_nan' and value is False:
|
||||
schema = cs.no_info_after_validator_function(
|
||||
forbid_inf_nan_check,
|
||||
elif constraint == 'gt':
|
||||
s = cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_validator, gt=value),
|
||||
schema,
|
||||
)
|
||||
add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)})
|
||||
return s
|
||||
elif constraint == 'ge':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_or_equal_validator, ge=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'lt':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_validator, lt=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'le':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_or_equal_validator, le=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'multiple_of':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.multiple_of_validator, multiple_of=value),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'min_length':
|
||||
s = cs.no_info_after_validator_function(
|
||||
partial(_validators.min_length_validator, min_length=value),
|
||||
schema,
|
||||
)
|
||||
add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))})
|
||||
return s
|
||||
elif constraint == 'max_length':
|
||||
s = cs.no_info_after_validator_function(
|
||||
partial(_validators.max_length_validator, max_length=value),
|
||||
schema,
|
||||
)
|
||||
add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))})
|
||||
return s
|
||||
elif constraint == 'strip_whitespace':
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(strip_whitespace=True),
|
||||
]
|
||||
)
|
||||
elif constraint == 'to_lower':
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(to_lower=True),
|
||||
]
|
||||
)
|
||||
elif constraint == 'to_upper':
|
||||
return cs.chain_schema(
|
||||
[
|
||||
schema,
|
||||
cs.str_schema(to_upper=True),
|
||||
]
|
||||
)
|
||||
elif constraint == 'min_length':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.min_length_validator, min_length=annotation.min_length),
|
||||
schema,
|
||||
)
|
||||
elif constraint == 'max_length':
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.max_length_validator, max_length=annotation.max_length),
|
||||
schema,
|
||||
)
|
||||
else:
|
||||
# It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it
|
||||
# Most constraint errors are caught at runtime during attempted application
|
||||
raise RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'")
|
||||
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')
|
||||
|
||||
for annotation in other_metadata:
|
||||
if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
|
||||
constraint = at_to_constraint_map[annotation_type]
|
||||
validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint)
|
||||
if validator is None:
|
||||
raise ValueError(f'Unknown constraint {constraint}')
|
||||
schema = cs.no_info_after_validator_function(
|
||||
partial(validator, {constraint: getattr(annotation, constraint)}), schema
|
||||
if isinstance(annotation, at.Gt):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_validator, gt=annotation.gt),
|
||||
schema,
|
||||
)
|
||||
continue
|
||||
elif isinstance(annotation, (at.Predicate, at.Not)):
|
||||
predicate_name = f'{annotation.func.__qualname__}' if hasattr(annotation.func, '__qualname__') else ''
|
||||
elif isinstance(annotation, at.Ge):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.Lt):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_validator, lt=annotation.lt),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.Le):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.less_than_or_equal_validator, le=annotation.le),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.MultipleOf):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.MinLen):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.min_length_validator, min_length=annotation.min_length),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.MaxLen):
|
||||
return cs.no_info_after_validator_function(
|
||||
partial(_validators.max_length_validator, max_length=annotation.max_length),
|
||||
schema,
|
||||
)
|
||||
elif isinstance(annotation, at.Predicate):
|
||||
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''
|
||||
|
||||
def val_func(v: Any) -> Any:
|
||||
predicate_satisfied = annotation.func(v) # noqa: B023
|
||||
|
||||
# annotation.func may also raise an exception, let it pass through
|
||||
if isinstance(annotation, at.Predicate): # noqa: B023
|
||||
if not predicate_satisfied:
|
||||
raise PydanticCustomError(
|
||||
'predicate_failed',
|
||||
f'Predicate {predicate_name} failed', # type: ignore # noqa: B023
|
||||
)
|
||||
else:
|
||||
if predicate_satisfied:
|
||||
raise PydanticCustomError(
|
||||
'not_operation_failed',
|
||||
f'Not of {predicate_name} failed', # type: ignore # noqa: B023
|
||||
)
|
||||
|
||||
if not annotation.func(v):
|
||||
raise PydanticCustomError(
|
||||
'predicate_failed',
|
||||
f'Predicate {predicate_name}failed', # type: ignore
|
||||
)
|
||||
return v
|
||||
|
||||
schema = cs.no_info_after_validator_function(val_func, schema)
|
||||
else:
|
||||
# ignore any other unknown metadata
|
||||
return None
|
||||
|
||||
if chain_schema_steps:
|
||||
chain_schema_steps = [schema] + chain_schema_steps
|
||||
return cs.chain_schema(chain_schema_steps)
|
||||
return cs.no_info_after_validator_function(val_func, schema)
|
||||
# ignore any other unknown metadata
|
||||
return None
|
||||
|
||||
return schema
|
||||
|
||||
@@ -339,7 +334,7 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
|
||||
A tuple contains a dict of known metadata and a list of unknown annotations.
|
||||
|
||||
Example:
|
||||
```python
|
||||
```py
|
||||
from annotated_types import Gt, Len
|
||||
|
||||
from pydantic._internal._known_annotated_metadata import collect_known_metadata
|
||||
@@ -352,15 +347,29 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any],
|
||||
|
||||
res: dict[str, Any] = {}
|
||||
remaining: list[Any] = []
|
||||
|
||||
for annotation in annotations:
|
||||
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
|
||||
if isinstance(annotation, PydanticMetadata):
|
||||
# Do we really want to consume any `BaseMetadata`?
|
||||
# It does let us give a better error when there is an annotation that doesn't apply
|
||||
# But it seems dangerous!
|
||||
if isinstance(annotation, PydanticGeneralMetadata):
|
||||
res.update(annotation.__dict__)
|
||||
elif isinstance(annotation, PydanticMetadata):
|
||||
res.update(annotation.__dict__)
|
||||
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
|
||||
elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()):
|
||||
constraint = at_to_constraint_map[annotation_type]
|
||||
res[constraint] = getattr(annotation, constraint)
|
||||
elif isinstance(annotation, at.MinLen):
|
||||
res.update({'min_length': annotation.min_length})
|
||||
elif isinstance(annotation, at.MaxLen):
|
||||
res.update({'max_length': annotation.max_length})
|
||||
elif isinstance(annotation, at.Gt):
|
||||
res.update({'gt': annotation.gt})
|
||||
elif isinstance(annotation, at.Ge):
|
||||
res.update({'ge': annotation.ge})
|
||||
elif isinstance(annotation, at.Lt):
|
||||
res.update({'lt': annotation.lt})
|
||||
elif isinstance(annotation, at.Le):
|
||||
res.update({'le': annotation.le})
|
||||
elif isinstance(annotation, at.MultipleOf):
|
||||
res.update({'multiple_of': annotation.multiple_of})
|
||||
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
|
||||
# also support PydanticMetadata classes being used without initialisation,
|
||||
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
|
||||
|
||||
@@ -1,71 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
|
||||
|
||||
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator
|
||||
from pydantic_core import SchemaSerializer, SchemaValidator
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..errors import PydanticErrorCodes, PydanticUserError
|
||||
from ..plugin._schema_validator import PluggableSchemaValidator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..dataclasses import PydanticDataclass
|
||||
from ..main import BaseModel
|
||||
from ..type_adapter import TypeAdapter
|
||||
|
||||
|
||||
ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer])
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class MockCoreSchema(Mapping[str, Any]):
|
||||
"""Mocker for `pydantic_core.CoreSchema` which optionally attempts to
|
||||
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
|
||||
"""
|
||||
|
||||
__slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
error_message: str,
|
||||
*,
|
||||
code: PydanticErrorCodes,
|
||||
attempt_rebuild: Callable[[], CoreSchema | None] | None = None,
|
||||
) -> None:
|
||||
self._error_message = error_message
|
||||
self._code: PydanticErrorCodes = code
|
||||
self._attempt_rebuild = attempt_rebuild
|
||||
self._built_memo: CoreSchema | None = None
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._get_built().__getitem__(key)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._get_built().__len__()
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return self._get_built().__iter__()
|
||||
|
||||
def _get_built(self) -> CoreSchema:
|
||||
if self._built_memo is not None:
|
||||
return self._built_memo
|
||||
|
||||
if self._attempt_rebuild:
|
||||
schema = self._attempt_rebuild()
|
||||
if schema is not None:
|
||||
self._built_memo = schema
|
||||
return schema
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
|
||||
def rebuild(self) -> CoreSchema | None:
|
||||
self._built_memo = None
|
||||
if self._attempt_rebuild:
|
||||
schema = self._attempt_rebuild()
|
||||
if schema is not None:
|
||||
return schema
|
||||
else:
|
||||
raise PydanticUserError(self._error_message, code=self._code)
|
||||
return None
|
||||
ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer)
|
||||
|
||||
|
||||
class MockValSer(Generic[ValSer]):
|
||||
@@ -109,120 +56,63 @@ class MockValSer(Generic[ValSer]):
|
||||
return None
|
||||
|
||||
|
||||
def set_type_adapter_mocks(adapter: TypeAdapter) -> None:
|
||||
"""Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance.
|
||||
|
||||
Args:
|
||||
adapter: The type adapter instance to set the mocks on
|
||||
"""
|
||||
type_repr = str(adapter._type)
|
||||
undefined_type_error_message = (
|
||||
f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,'
|
||||
f' then call `.rebuild()` on the instance.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(adapter)
|
||||
return None
|
||||
|
||||
return handler
|
||||
|
||||
adapter.core_schema = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema),
|
||||
)
|
||||
adapter.validator = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator),
|
||||
)
|
||||
adapter.serializer = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer),
|
||||
)
|
||||
|
||||
|
||||
def set_model_mocks(cls: type[BaseModel], undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model.
|
||||
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
cls_name: Name of the model class, used in error messages
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
"""
|
||||
undefined_type_error_message = (
|
||||
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `{cls.__name__}.model_rebuild()`.'
|
||||
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `{cls_name}.model_rebuild()`.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(cls)
|
||||
def attempt_rebuild_validator() -> SchemaValidator | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
|
||||
return cls.__pydantic_validator__
|
||||
else:
|
||||
return None
|
||||
|
||||
return handler
|
||||
|
||||
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
|
||||
)
|
||||
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
|
||||
attempt_rebuild=attempt_rebuild_validator,
|
||||
)
|
||||
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
def attempt_rebuild_serializer() -> SchemaSerializer | None:
|
||||
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5):
|
||||
return cls.__pydantic_serializer__
|
||||
else:
|
||||
return None
|
||||
|
||||
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
|
||||
attempt_rebuild=attempt_rebuild_serializer,
|
||||
)
|
||||
|
||||
|
||||
def set_dataclass_mocks(cls: type[PydanticDataclass], undefined_name: str = 'all referenced types') -> None:
|
||||
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
|
||||
|
||||
Args:
|
||||
cls: The model class to set the mocks on
|
||||
undefined_name: Name of the undefined thing, used in error messages
|
||||
"""
|
||||
from ..dataclasses import rebuild_dataclass
|
||||
|
||||
def set_dataclass_mock_validator(cls: type[PydanticDataclass], cls_name: str, undefined_name: str) -> None:
|
||||
undefined_type_error_message = (
|
||||
f'`{cls.__name__}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `pydantic.dataclasses.rebuild_dataclass({cls.__name__})`.'
|
||||
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
|
||||
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
|
||||
)
|
||||
|
||||
def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]:
|
||||
def handler() -> T | None:
|
||||
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
|
||||
return attr_fn(cls)
|
||||
def attempt_rebuild() -> SchemaValidator | None:
|
||||
from ..dataclasses import rebuild_dataclass
|
||||
|
||||
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5):
|
||||
return cls.__pydantic_validator__
|
||||
else:
|
||||
return None
|
||||
|
||||
return handler
|
||||
|
||||
cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__),
|
||||
)
|
||||
cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='validator',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__),
|
||||
)
|
||||
cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue]
|
||||
undefined_type_error_message,
|
||||
code='class-not-fully-defined',
|
||||
val_or_ser='serializer',
|
||||
attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__),
|
||||
attempt_rebuild=attempt_rebuild,
|
||||
)
|
||||
|
||||
@@ -1,54 +1,58 @@
|
||||
"""Private logic for creating models."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import builtins
|
||||
import operator
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import ABCMeta
|
||||
from functools import cache, partial, wraps
|
||||
from functools import partial
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable, Generic, Literal, NoReturn, cast
|
||||
from typing import Any, Callable, Generic, Mapping
|
||||
|
||||
from pydantic_core import PydanticUndefined, SchemaSerializer
|
||||
from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_extensions import dataclass_transform, deprecated
|
||||
|
||||
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
|
||||
from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
from ._config import ConfigWrapper
|
||||
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function
|
||||
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
|
||||
from ._generate_schema import GenerateSchema, InvalidSchemaError
|
||||
from ._generics import PydanticGenericMetadata, get_model_typevars_map
|
||||
from ._import_utils import import_cached_base_model, import_cached_field_info
|
||||
from ._mock_val_ser import set_model_mocks
|
||||
from ._namespace_utils import NsResolver
|
||||
from ._signature import generate_pydantic_signature
|
||||
from ._typing_extra import (
|
||||
_make_forward_ref,
|
||||
eval_type_backport,
|
||||
is_classvar_annotation,
|
||||
parent_frame_namespace,
|
||||
from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema
|
||||
from ._decorators import (
|
||||
ComputedFieldInfo,
|
||||
DecoratorInfos,
|
||||
PydanticDescriptorProxy,
|
||||
get_attribute_from_bases,
|
||||
)
|
||||
from ._utils import LazyClassAttribute, SafeGetItemProxy
|
||||
from ._discriminated_union import apply_discriminators
|
||||
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._generics import PydanticGenericMetadata, get_model_typevars_map
|
||||
from ._mock_val_ser import MockValSer, set_model_mocks
|
||||
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
|
||||
from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace
|
||||
from ._utils import ClassAttribute, is_valid_identifier
|
||||
from ._validate_call import ValidateCallWrapper
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..fields import Field as PydanticModelField
|
||||
from ..fields import FieldInfo, ModelPrivateAttr
|
||||
from ..fields import PrivateAttr as PydanticModelPrivateAttr
|
||||
from inspect import Signature
|
||||
|
||||
from ..main import BaseModel
|
||||
else:
|
||||
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
|
||||
# and https://youtrack.jetbrains.com/issue/PY-51428
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
PydanticModelField = object()
|
||||
PydanticModelPrivateAttr = object()
|
||||
|
||||
|
||||
IGNORED_TYPES: tuple[Any, ...] = (
|
||||
FunctionType,
|
||||
property,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
PydanticDescriptorProxy,
|
||||
ComputedFieldInfo,
|
||||
ValidateCallWrapper,
|
||||
)
|
||||
object_setattr = object.__setattr__
|
||||
|
||||
|
||||
@@ -65,17 +69,7 @@ class _ModelNamespaceDict(dict):
|
||||
return super().__setitem__(k, v)
|
||||
|
||||
|
||||
def NoInitField(
|
||||
*,
|
||||
init: Literal[False] = False,
|
||||
) -> Any:
|
||||
"""Only for typing purposes. Used as default value of `__pydantic_fields_set__`,
|
||||
`__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when
|
||||
synthesizing the `__init__` signature.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField))
|
||||
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
|
||||
class ModelMetaclass(ABCMeta):
|
||||
def __new__(
|
||||
mcs,
|
||||
@@ -84,7 +78,6 @@ class ModelMetaclass(ABCMeta):
|
||||
namespace: dict[str, Any],
|
||||
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
|
||||
__pydantic_reset_parent_namespace__: bool = True,
|
||||
_create_model_module: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> type:
|
||||
"""Metaclass for creating Pydantic models.
|
||||
@@ -95,7 +88,6 @@ class ModelMetaclass(ABCMeta):
|
||||
namespace: The attribute dictionary of the class to be created.
|
||||
__pydantic_generic_metadata__: Metadata for generic models.
|
||||
__pydantic_reset_parent_namespace__: Reset parent namespace.
|
||||
_create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
**kwargs: Catch-all for any other keyword arguments.
|
||||
|
||||
Returns:
|
||||
@@ -112,18 +104,17 @@ class ModelMetaclass(ABCMeta):
|
||||
private_attributes = inspect_namespace(
|
||||
namespace, config_wrapper.ignored_types, class_vars, base_field_names
|
||||
)
|
||||
if private_attributes or base_private_attributes:
|
||||
if private_attributes:
|
||||
original_model_post_init = get_model_post_init(namespace, bases)
|
||||
if original_model_post_init is not None:
|
||||
# if there are private_attributes and a model_post_init function, we handle both
|
||||
|
||||
@wraps(original_model_post_init)
|
||||
def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None:
|
||||
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
|
||||
"""We need to both initialize private attributes and call the user-defined model_post_init
|
||||
method.
|
||||
"""
|
||||
init_private_attributes(self, context)
|
||||
original_model_post_init(self, context)
|
||||
init_private_attributes(self, __context)
|
||||
original_model_post_init(self, __context)
|
||||
|
||||
namespace['model_post_init'] = wrapped_model_post_init
|
||||
else:
|
||||
@@ -132,25 +123,15 @@ class ModelMetaclass(ABCMeta):
|
||||
namespace['__class_vars__'] = class_vars
|
||||
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
|
||||
|
||||
cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs))
|
||||
BaseModel_ = import_cached_base_model()
|
||||
if config_wrapper.frozen:
|
||||
set_default_hash_func(namespace, bases)
|
||||
|
||||
mro = cls.__mro__
|
||||
if Generic in mro and mro.index(Generic) < mro.index(BaseModel_):
|
||||
warnings.warn(
|
||||
GenericBeforeBaseModelWarning(
|
||||
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
|
||||
'for pydantic generics to work properly.'
|
||||
),
|
||||
stacklevel=2,
|
||||
)
|
||||
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
|
||||
|
||||
from ..main import BaseModel
|
||||
|
||||
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
|
||||
cls.__pydantic_post_init__ = (
|
||||
None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init'
|
||||
)
|
||||
|
||||
cls.__pydantic_setattr_handlers__ = {}
|
||||
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
|
||||
|
||||
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
|
||||
|
||||
@@ -161,40 +142,22 @@ class ModelMetaclass(ABCMeta):
|
||||
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
|
||||
parameters = getattr(cls, '__parameters__', None) or parent_parameters
|
||||
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
|
||||
from ..root_model import RootModelRootType
|
||||
|
||||
missing_parameters = tuple(x for x in parameters if x not in parent_parameters)
|
||||
if RootModelRootType in parent_parameters and RootModelRootType not in parameters:
|
||||
# This is a special case where the user has subclassed `RootModel`, but has not parametrized
|
||||
# RootModel with the generic type identifiers being used. Ex:
|
||||
# class MyModel(RootModel, Generic[T]):
|
||||
# root: T
|
||||
# Should instead just be:
|
||||
# class MyModel(RootModel[T]):
|
||||
# root: T
|
||||
parameters_str = ', '.join([x.__name__ for x in missing_parameters])
|
||||
error_message = (
|
||||
f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) '
|
||||
f'{parameters_str} in its parameters. '
|
||||
f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.'
|
||||
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
|
||||
parameters_str = ', '.join([str(x) for x in combined_parameters])
|
||||
generic_type_label = f'typing.Generic[{parameters_str}]'
|
||||
error_message = (
|
||||
f'All parameters must be present on typing.Generic;'
|
||||
f' you should inherit from {generic_type_label}.'
|
||||
)
|
||||
if Generic not in bases: # pragma: no cover
|
||||
# We raise an error here not because it is desirable, but because some cases are mishandled.
|
||||
# It would be nice to remove this error and still have things behave as expected, it's just
|
||||
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
|
||||
# and not returning a typing._GenericAlias from it.
|
||||
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
|
||||
error_message += (
|
||||
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
|
||||
)
|
||||
else:
|
||||
combined_parameters = parent_parameters + missing_parameters
|
||||
parameters_str = ', '.join([str(x) for x in combined_parameters])
|
||||
generic_type_label = f'typing.Generic[{parameters_str}]'
|
||||
error_message = (
|
||||
f'All parameters must be present on typing.Generic;'
|
||||
f' you should inherit from {generic_type_label}.'
|
||||
)
|
||||
if Generic not in bases: # pragma: no cover
|
||||
# We raise an error here not because it is desirable, but because some cases are mishandled.
|
||||
# It would be nice to remove this error and still have things behave as expected, it's just
|
||||
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
|
||||
# and not returning a typing._GenericAlias from it.
|
||||
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
|
||||
error_message += (
|
||||
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
|
||||
)
|
||||
raise TypeError(error_message)
|
||||
|
||||
cls.__pydantic_generic_metadata__ = {
|
||||
@@ -212,55 +175,29 @@ class ModelMetaclass(ABCMeta):
|
||||
|
||||
if __pydantic_reset_parent_namespace__:
|
||||
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
|
||||
parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None)
|
||||
parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
|
||||
if isinstance(parent_namespace, dict):
|
||||
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
|
||||
|
||||
ns_resolver = NsResolver(parent_namespace=parent_namespace)
|
||||
|
||||
set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver)
|
||||
|
||||
# This is also set in `complete_model_class()`, after schema gen because they are recreated.
|
||||
# We set them here as well for backwards compatibility:
|
||||
cls.__pydantic_computed_fields__ = {
|
||||
k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()
|
||||
}
|
||||
|
||||
if config_wrapper.defer_build:
|
||||
# TODO we can also stop there if `__pydantic_fields_complete__` is False.
|
||||
# However, `set_model_fields()` is currently lenient and we don't have access to the `NameError`.
|
||||
# (which is useful as we can provide the name in the error message: `set_model_mock(cls, e.name)`)
|
||||
set_model_mocks(cls)
|
||||
else:
|
||||
# Any operation that requires accessing the field infos instances should be put inside
|
||||
# `complete_model_class()`:
|
||||
complete_model_class(
|
||||
cls,
|
||||
config_wrapper,
|
||||
raise_errors=False,
|
||||
ns_resolver=ns_resolver,
|
||||
create_model_module=_create_model_module,
|
||||
)
|
||||
|
||||
if config_wrapper.frozen and '__hash__' not in namespace:
|
||||
set_default_hash_func(cls, bases)
|
||||
|
||||
types_namespace = get_cls_types_namespace(cls, parent_namespace)
|
||||
set_model_fields(cls, bases, config_wrapper, types_namespace)
|
||||
complete_model_class(
|
||||
cls,
|
||||
cls_name,
|
||||
config_wrapper,
|
||||
raise_errors=False,
|
||||
types_namespace=types_namespace,
|
||||
)
|
||||
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
|
||||
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
|
||||
# only hit for _proper_ subclasses of BaseModel
|
||||
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
|
||||
return cls
|
||||
else:
|
||||
# These are instance variables, but have been assigned to `NoInitField` to trick the type checker.
|
||||
for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__':
|
||||
namespace.pop(
|
||||
instance_slot,
|
||||
None, # In case the metaclass is used with a class other than `BaseModel`.
|
||||
)
|
||||
namespace.get('__annotations__', {}).clear()
|
||||
# this is the BaseModel class itself being created, no logic required
|
||||
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
|
||||
|
||||
if not typing.TYPE_CHECKING: # pragma: no branch
|
||||
if not typing.TYPE_CHECKING:
|
||||
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
@@ -268,29 +205,30 @@ class ModelMetaclass(ABCMeta):
|
||||
private_attributes = self.__dict__.get('__private_attributes__')
|
||||
if private_attributes and item in private_attributes:
|
||||
return private_attributes[item]
|
||||
if item == '__pydantic_core_schema__':
|
||||
# This means the class didn't get a schema generated for it, likely because there was an undefined reference
|
||||
maybe_mock_validator = getattr(self, '__pydantic_validator__', None)
|
||||
if isinstance(maybe_mock_validator, MockValSer):
|
||||
rebuilt_validator = maybe_mock_validator.rebuild()
|
||||
if rebuilt_validator is not None:
|
||||
# In this case, a validator was built, and so `__pydantic_core_schema__` should now be set
|
||||
return getattr(self, '__pydantic_core_schema__')
|
||||
raise AttributeError(item)
|
||||
|
||||
@classmethod
|
||||
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
|
||||
def __prepare__(cls, *args: Any, **kwargs: Any) -> Mapping[str, object]:
|
||||
return _ModelNamespaceDict()
|
||||
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
"""Avoid calling ABC _abc_instancecheck unless we're pretty sure.
|
||||
|
||||
See #3829 and python/cpython#92810
|
||||
"""
|
||||
return hasattr(instance, '__pydantic_decorators__') and super().__instancecheck__(instance)
|
||||
|
||||
def __subclasscheck__(self, subclass: type[Any]) -> bool:
|
||||
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
|
||||
|
||||
See #3829 and python/cpython#92810
|
||||
"""
|
||||
return hasattr(subclass, '__pydantic_decorators__') and super().__subclasscheck__(subclass)
|
||||
return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
|
||||
|
||||
@staticmethod
|
||||
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
|
||||
BaseModel = import_cached_base_model()
|
||||
from ..main import BaseModel
|
||||
|
||||
field_names: set[str] = set()
|
||||
class_vars: set[str] = set()
|
||||
@@ -298,57 +236,35 @@ class ModelMetaclass(ABCMeta):
|
||||
for base in bases:
|
||||
if issubclass(base, BaseModel) and base is not BaseModel:
|
||||
# model_fields might not be defined yet in the case of generics, so we use getattr here:
|
||||
field_names.update(getattr(base, '__pydantic_fields__', {}).keys())
|
||||
field_names.update(getattr(base, 'model_fields', {}).keys())
|
||||
class_vars.update(base.__class_vars__)
|
||||
private_attributes.update(base.__private_attributes__)
|
||||
return field_names, class_vars, private_attributes
|
||||
|
||||
@property
|
||||
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
|
||||
@deprecated(
|
||||
'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def __fields__(self) -> dict[str, FieldInfo]:
|
||||
warnings.warn(
|
||||
'The `__fields__` attribute is deprecated, use `model_fields` instead.',
|
||||
PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
return getattr(self, '__pydantic_fields__', {})
|
||||
|
||||
@property
|
||||
def __pydantic_fields_complete__(self) -> bool:
|
||||
"""Whether the fields where successfully collected (i.e. type hints were successfully resolves).
|
||||
|
||||
This is a private attribute, not meant to be used outside Pydantic.
|
||||
"""
|
||||
if not hasattr(self, '__pydantic_fields__'):
|
||||
return False
|
||||
|
||||
field_infos = cast('dict[str, FieldInfo]', self.__pydantic_fields__) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return all(field_info._complete for field_info in field_infos.values())
|
||||
|
||||
def __dir__(self) -> list[str]:
|
||||
attributes = list(super().__dir__())
|
||||
if '__fields__' in attributes:
|
||||
attributes.remove('__fields__')
|
||||
return attributes
|
||||
warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning)
|
||||
return self.model_fields # type: ignore
|
||||
|
||||
|
||||
def init_private_attributes(self: BaseModel, context: Any, /) -> None:
|
||||
def init_private_attributes(self: BaseModel, __context: Any) -> None:
|
||||
"""This function is meant to behave like a BaseModel method to initialise private attributes.
|
||||
|
||||
It takes context as an argument since that's what pydantic-core passes when calling it.
|
||||
|
||||
Args:
|
||||
self: The BaseModel instance.
|
||||
context: The context.
|
||||
__context: The context.
|
||||
"""
|
||||
if getattr(self, '__pydantic_private__', None) is None:
|
||||
pydantic_private = {}
|
||||
for name, private_attr in self.__private_attributes__.items():
|
||||
default = private_attr.get_default()
|
||||
if default is not PydanticUndefined:
|
||||
pydantic_private[name] = default
|
||||
object_setattr(self, '__pydantic_private__', pydantic_private)
|
||||
pydantic_private = {}
|
||||
for name, private_attr in self.__private_attributes__.items():
|
||||
default = private_attr.get_default()
|
||||
if default is not PydanticUndefined:
|
||||
pydantic_private[name] = default
|
||||
object_setattr(self, '__pydantic_private__', pydantic_private)
|
||||
|
||||
|
||||
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
|
||||
@@ -356,7 +272,7 @@ def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...])
|
||||
if 'model_post_init' in namespace:
|
||||
return namespace['model_post_init']
|
||||
|
||||
BaseModel = import_cached_base_model()
|
||||
from ..main import BaseModel
|
||||
|
||||
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
|
||||
if model_post_init is not BaseModel.model_post_init:
|
||||
@@ -389,11 +305,7 @@ def inspect_namespace( # noqa C901
|
||||
- If a field does not have a type annotation.
|
||||
- If a field on base class was overridden by a non-annotated attribute.
|
||||
"""
|
||||
from ..fields import ModelPrivateAttr, PrivateAttr
|
||||
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
all_ignored_types = ignored_types + default_ignored_types()
|
||||
all_ignored_types = ignored_types + IGNORED_TYPES
|
||||
|
||||
private_attributes: dict[str, ModelPrivateAttr] = {}
|
||||
raw_annotations = namespace.get('__annotations__', {})
|
||||
@@ -403,12 +315,11 @@ def inspect_namespace( # noqa C901
|
||||
|
||||
ignored_names: set[str] = set()
|
||||
for var_name, value in list(namespace.items()):
|
||||
if var_name == 'model_config' or var_name == '__pydantic_extra__':
|
||||
if var_name == 'model_config':
|
||||
continue
|
||||
elif (
|
||||
isinstance(value, type)
|
||||
and value.__module__ == namespace['__module__']
|
||||
and '__qualname__' in namespace
|
||||
and value.__qualname__.startswith(namespace['__qualname__'])
|
||||
):
|
||||
# `value` is a nested type defined in this namespace; don't error
|
||||
@@ -439,8 +350,8 @@ def inspect_namespace( # noqa C901
|
||||
elif var_name.startswith('__'):
|
||||
continue
|
||||
elif is_valid_privateattr_name(var_name):
|
||||
if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]):
|
||||
private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value))
|
||||
if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
|
||||
private_attributes[var_name] = PrivateAttr(default=value)
|
||||
del namespace[var_name]
|
||||
elif var_name in base_class_vars:
|
||||
continue
|
||||
@@ -457,8 +368,8 @@ def inspect_namespace( # noqa C901
|
||||
)
|
||||
else:
|
||||
raise PydanticUserError(
|
||||
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
|
||||
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
|
||||
f"A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a "
|
||||
f"type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this "
|
||||
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
|
||||
code='model-field-missing-annotation',
|
||||
)
|
||||
@@ -468,82 +379,45 @@ def inspect_namespace( # noqa C901
|
||||
is_valid_privateattr_name(ann_name)
|
||||
and ann_name not in private_attributes
|
||||
and ann_name not in ignored_names
|
||||
# This condition can be a false negative when `ann_type` is stringified,
|
||||
# but it is handled in most cases in `set_model_fields`:
|
||||
and not is_classvar_annotation(ann_type)
|
||||
and not is_classvar(ann_type)
|
||||
and ann_type not in all_ignored_types
|
||||
and getattr(ann_type, '__module__', None) != 'functools'
|
||||
):
|
||||
if isinstance(ann_type, str):
|
||||
# Walking up the frames to get the module namespace where the model is defined
|
||||
# (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`):
|
||||
frame = sys._getframe(2)
|
||||
if frame is not None:
|
||||
try:
|
||||
ann_type = eval_type_backport(
|
||||
_make_forward_ref(ann_type, is_argument=False, is_class=True),
|
||||
globalns=frame.f_globals,
|
||||
localns=frame.f_locals,
|
||||
)
|
||||
except (NameError, TypeError):
|
||||
pass
|
||||
|
||||
if typing_objects.is_annotated(get_origin(ann_type)):
|
||||
_, *metadata = get_args(ann_type)
|
||||
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
|
||||
if private_attr is not None:
|
||||
private_attributes[ann_name] = private_attr
|
||||
continue
|
||||
private_attributes[ann_name] = PrivateAttr()
|
||||
|
||||
return private_attributes
|
||||
|
||||
|
||||
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
|
||||
def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None:
|
||||
if '__hash__' in namespace:
|
||||
return
|
||||
|
||||
base_hash_func = get_attribute_from_bases(bases, '__hash__')
|
||||
new_hash_func = make_hash_func(cls)
|
||||
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
|
||||
# If `__hash__` is some default, we generate a hash function.
|
||||
# It will be `None` if not overridden from BaseModel.
|
||||
# It may be `object.__hash__` if there is another
|
||||
if base_hash_func in {None, object.__hash__}:
|
||||
# If `__hash__` is None _or_ `object.__hash__`, we generate a hash function.
|
||||
# It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another
|
||||
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
|
||||
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
|
||||
# In the last case we still need a new hash function to account for new `model_fields`.
|
||||
cls.__hash__ = new_hash_func
|
||||
def hash_func(self: Any) -> int:
|
||||
return hash(self.__class__) + hash(tuple(self.__dict__.values()))
|
||||
|
||||
|
||||
def make_hash_func(cls: type[BaseModel]) -> Any:
|
||||
getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0
|
||||
|
||||
def hash_func(self: Any) -> int:
|
||||
try:
|
||||
return hash(getter(self.__dict__))
|
||||
except KeyError:
|
||||
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
|
||||
# all model fields, which is how we can get here.
|
||||
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
|
||||
# and wrapping it in a `try` doesn't slow things down much in the common case.
|
||||
return hash(getter(SafeGetItemProxy(self.__dict__)))
|
||||
|
||||
return hash_func
|
||||
namespace['__hash__'] = hash_func
|
||||
|
||||
|
||||
def set_model_fields(
|
||||
cls: type[BaseModel],
|
||||
config_wrapper: ConfigWrapper,
|
||||
ns_resolver: NsResolver | None,
|
||||
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
|
||||
) -> None:
|
||||
"""Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`.
|
||||
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
bases: Parents of the class, generally `cls.__bases__`.
|
||||
config_wrapper: The config wrapper instance.
|
||||
ns_resolver: Namespace resolver to use when getting model annotations.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
"""
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
fields, class_vars = collect_model_fields(cls, config_wrapper, ns_resolver, typevars_map=typevars_map)
|
||||
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
|
||||
|
||||
cls.__pydantic_fields__ = fields
|
||||
cls.model_fields = fields
|
||||
cls.__class_vars__.update(class_vars)
|
||||
|
||||
for k in class_vars:
|
||||
@@ -561,11 +435,11 @@ def set_model_fields(
|
||||
|
||||
def complete_model_class(
|
||||
cls: type[BaseModel],
|
||||
cls_name: str,
|
||||
config_wrapper: ConfigWrapper,
|
||||
*,
|
||||
raise_errors: bool = True,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
create_model_module: str | None = None,
|
||||
types_namespace: dict[str, Any] | None,
|
||||
) -> bool:
|
||||
"""Finish building a model class.
|
||||
|
||||
@@ -574,10 +448,10 @@ def complete_model_class(
|
||||
|
||||
Args:
|
||||
cls: BaseModel or dataclass.
|
||||
cls_name: The model or dataclass name.
|
||||
config_wrapper: The config wrapper instance.
|
||||
raise_errors: Whether to raise errors.
|
||||
ns_resolver: The namespace resolver instance to use during schema building.
|
||||
create_model_module: The module of the class to be created, if created by `create_model`.
|
||||
types_namespace: Optional extra namespace to look for types in.
|
||||
|
||||
Returns:
|
||||
`True` if the model is successfully completed, else `False`.
|
||||
@@ -589,151 +463,132 @@ def complete_model_class(
|
||||
typevars_map = get_model_typevars_map(cls)
|
||||
gen_schema = GenerateSchema(
|
||||
config_wrapper,
|
||||
ns_resolver,
|
||||
types_namespace,
|
||||
typevars_map,
|
||||
)
|
||||
|
||||
handler = CallbackGetCoreSchemaHandler(
|
||||
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
|
||||
gen_schema,
|
||||
ref_mode='unpack',
|
||||
)
|
||||
|
||||
if config_wrapper.defer_build:
|
||||
set_model_mocks(cls, cls_name)
|
||||
return False
|
||||
|
||||
try:
|
||||
schema = gen_schema.generate_schema(cls)
|
||||
schema = cls.__get_pydantic_core_schema__(cls, handler)
|
||||
except PydanticUndefinedAnnotation as e:
|
||||
if raise_errors:
|
||||
raise
|
||||
set_model_mocks(cls, f'`{e.name}`')
|
||||
set_model_mocks(cls, cls_name, f'`{e.name}`')
|
||||
return False
|
||||
|
||||
core_config = config_wrapper.core_config(title=cls.__name__)
|
||||
core_config = config_wrapper.core_config(cls)
|
||||
|
||||
try:
|
||||
schema = gen_schema.clean_schema(schema)
|
||||
except InvalidSchemaError:
|
||||
set_model_mocks(cls)
|
||||
schema = gen_schema.collect_definitions(schema)
|
||||
|
||||
schema = apply_discriminators(simplify_schema_references(schema))
|
||||
if collect_invalid_schemas(schema):
|
||||
set_model_mocks(cls, cls_name)
|
||||
return False
|
||||
|
||||
# This needs to happen *after* model schema generation, as the return type
|
||||
# of the properties are evaluated and the `ComputedFieldInfo` are recreated:
|
||||
cls.__pydantic_computed_fields__ = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
|
||||
|
||||
set_deprecated_descriptors(cls)
|
||||
|
||||
cls.__pydantic_core_schema__ = schema
|
||||
|
||||
cls.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
cls,
|
||||
create_model_module or cls.__module__,
|
||||
cls.__qualname__,
|
||||
'create_model' if create_model_module else 'BaseModel',
|
||||
core_config,
|
||||
config_wrapper.plugin_settings,
|
||||
)
|
||||
# debug(schema)
|
||||
cls.__pydantic_core_schema__ = schema = validate_core_schema(schema)
|
||||
cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
|
||||
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
|
||||
cls.__pydantic_complete__ = True
|
||||
|
||||
# set __signature__ attr only for model class, but not for its instances
|
||||
# (because instances can define `__call__`, and `inspect.signature` shouldn't
|
||||
# use the `__signature__` attribute and instead generate from `__call__`).
|
||||
cls.__signature__ = LazyClassAttribute(
|
||||
'__signature__',
|
||||
partial(
|
||||
generate_pydantic_signature,
|
||||
init=cls.__init__,
|
||||
fields=cls.__pydantic_fields__,
|
||||
validate_by_name=config_wrapper.validate_by_name,
|
||||
extra=config_wrapper.extra,
|
||||
),
|
||||
cls.__signature__ = ClassAttribute(
|
||||
'__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper)
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def set_deprecated_descriptors(cls: type[BaseModel]) -> None:
|
||||
"""Set data descriptors on the class for deprecated fields."""
|
||||
for field, field_info in cls.__pydantic_fields__.items():
|
||||
if (msg := field_info.deprecation_message) is not None:
|
||||
desc = _DeprecatedFieldDescriptor(msg)
|
||||
desc.__set_name__(cls, field)
|
||||
setattr(cls, field, desc)
|
||||
def generate_model_signature(
|
||||
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper
|
||||
) -> Signature:
|
||||
"""Generate signature for model based on its fields.
|
||||
|
||||
for field, computed_field_info in cls.__pydantic_computed_fields__.items():
|
||||
if (
|
||||
(msg := computed_field_info.deprecation_message) is not None
|
||||
# Avoid having two warnings emitted:
|
||||
and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__')
|
||||
):
|
||||
desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property)
|
||||
desc.__set_name__(cls, field)
|
||||
setattr(cls, field, desc)
|
||||
Args:
|
||||
init: The class init.
|
||||
fields: The model fields.
|
||||
config_wrapper: The config wrapper instance.
|
||||
|
||||
|
||||
class _DeprecatedFieldDescriptor:
|
||||
"""Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field.
|
||||
|
||||
Attributes:
|
||||
msg: The deprecation message to be emitted.
|
||||
wrapped_property: The property instance if the deprecated field is a computed field, or `None`.
|
||||
field_name: The name of the field being deprecated.
|
||||
Returns:
|
||||
The model signature.
|
||||
"""
|
||||
from inspect import Parameter, Signature, signature
|
||||
from itertools import islice
|
||||
|
||||
field_name: str
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
def __init__(self, msg: str, wrapped_property: property | None = None) -> None:
|
||||
self.msg = msg
|
||||
self.wrapped_property = wrapped_property
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if param.annotation == 'Any':
|
||||
param = param.replace(annotation=Any)
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
def __set_name__(self, cls: type[BaseModel], name: str) -> None:
|
||||
self.field_name = name
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = config_wrapper.populate_by_name
|
||||
for field_name, field in fields.items():
|
||||
# when alias is a str it should be used for signature generation
|
||||
if isinstance(field.alias, str):
|
||||
param_name = field.alias
|
||||
else:
|
||||
param_name = field_name
|
||||
|
||||
def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any:
|
||||
if obj is None:
|
||||
if self.wrapped_property is not None:
|
||||
return self.wrapped_property.__get__(None, obj_type)
|
||||
raise AttributeError(self.field_name)
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
|
||||
warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2)
|
||||
if not is_valid_identifier(param_name):
|
||||
if allow_names and is_valid_identifier(field_name):
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
if self.wrapped_property is not None:
|
||||
return self.wrapped_property.__get__(obj, obj_type)
|
||||
return obj.__dict__[self.field_name]
|
||||
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
|
||||
)
|
||||
|
||||
# Defined to make it a data descriptor and take precedence over the instance's dictionary.
|
||||
# Note that it will not be called when setting a value on a model instance
|
||||
# as `BaseModel.__setattr__` is defined and takes priority.
|
||||
def __set__(self, obj: Any, value: Any) -> NoReturn:
|
||||
raise AttributeError(self.field_name)
|
||||
if config_wrapper.extra == 'allow':
|
||||
use_var_kw = True
|
||||
|
||||
|
||||
class _PydanticWeakRef:
|
||||
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
|
||||
|
||||
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
|
||||
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
|
||||
`weakref.ref` instead of subclassing it.
|
||||
|
||||
See https://github.com/pydantic/pydantic/issues/6763 for context.
|
||||
|
||||
Semantics:
|
||||
- If not pickled, behaves the same as a `weakref.ref`.
|
||||
- If pickled along with the referenced object, the same `weakref.ref` behavior
|
||||
will be maintained between them after unpickling.
|
||||
- If pickled without the referenced object, after unpickling the underlying
|
||||
reference will be cleared (`__call__` will always return `None`).
|
||||
"""
|
||||
|
||||
def __init__(self, obj: Any):
|
||||
if obj is None:
|
||||
# The object will be `None` upon deserialization if the serialized weakref
|
||||
# had lost its underlying object.
|
||||
self._wr = None
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
self._wr = weakref.ref(obj)
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
def __call__(self) -> Any:
|
||||
if self._wr is None:
|
||||
return None
|
||||
else:
|
||||
return self._wr()
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
|
||||
return _PydanticWeakRef, (self(),)
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
|
||||
|
||||
class _PydanticWeakRef(weakref.ReferenceType):
|
||||
pass
|
||||
|
||||
|
||||
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
@@ -770,23 +625,3 @@ def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | N
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
@cache
|
||||
def default_ignored_types() -> tuple[type[Any], ...]:
|
||||
from ..fields import ComputedFieldInfo
|
||||
|
||||
ignored_types = [
|
||||
FunctionType,
|
||||
property,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
PydanticDescriptorProxy,
|
||||
ComputedFieldInfo,
|
||||
TypeAliasType, # from `typing_extensions`
|
||||
]
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
ignored_types.append(typing.TypeAliasType)
|
||||
|
||||
return tuple(ignored_types)
|
||||
|
||||
@@ -1,293 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Generator, Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, NamedTuple, TypeVar
|
||||
|
||||
from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple
|
||||
|
||||
GlobalsNamespace: TypeAlias = 'dict[str, Any]'
|
||||
"""A global namespace.
|
||||
|
||||
In most cases, this is a reference to the `__dict__` attribute of a module.
|
||||
This namespace type is expected as the `globals` argument during annotations evaluation.
|
||||
"""
|
||||
|
||||
MappingNamespace: TypeAlias = Mapping[str, Any]
|
||||
"""Any kind of namespace.
|
||||
|
||||
In most cases, this is a local namespace (e.g. the `__dict__` attribute of a class,
|
||||
the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types
|
||||
defined inside functions).
|
||||
This namespace type is expected as the `locals` argument during annotations evaluation.
|
||||
"""
|
||||
|
||||
_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple'
|
||||
|
||||
|
||||
class NamespacesTuple(NamedTuple):
|
||||
"""A tuple of globals and locals to be used during annotations evaluation.
|
||||
|
||||
This datastructure is defined as a named tuple so that it can easily be unpacked:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
def eval_type(typ: type[Any], ns: NamespacesTuple) -> None:
|
||||
return eval(typ, *ns)
|
||||
```
|
||||
"""
|
||||
|
||||
globals: GlobalsNamespace
|
||||
"""The namespace to be used as the `globals` argument during annotations evaluation."""
|
||||
|
||||
locals: MappingNamespace
|
||||
"""The namespace to be used as the `locals` argument during annotations evaluation."""
|
||||
|
||||
|
||||
def get_module_ns_of(obj: Any) -> dict[str, Any]:
|
||||
"""Get the namespace of the module where the object is defined.
|
||||
|
||||
Caution: this function does not return a copy of the module namespace, so the result
|
||||
should not be mutated. The burden of enforcing this is on the caller.
|
||||
"""
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
if module_name:
|
||||
try:
|
||||
return sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
return {}
|
||||
return {}
|
||||
|
||||
|
||||
# Note that this class is almost identical to `collections.ChainMap`, but need to enforce
|
||||
# immutable mappings here:
|
||||
class LazyLocalNamespace(Mapping[str, Any]):
|
||||
"""A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation.
|
||||
|
||||
While the [`eval`][eval] function expects a mapping as the `locals` argument, it only
|
||||
performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class
|
||||
is fully implemented only for type checking purposes.
|
||||
|
||||
Args:
|
||||
*namespaces: The namespaces to consider, in ascending order of priority.
|
||||
|
||||
Example:
|
||||
```python {lint="skip" test="skip"}
|
||||
ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3})
|
||||
ns['a']
|
||||
#> 3
|
||||
ns['b']
|
||||
#> 2
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, *namespaces: MappingNamespace) -> None:
|
||||
self._namespaces = namespaces
|
||||
|
||||
@cached_property
|
||||
def data(self) -> dict[str, Any]:
|
||||
return {k: v for ns in self._namespaces for k, v in ns.items()}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self.data[key]
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return key in self.data
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.data)
|
||||
|
||||
|
||||
def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple:
|
||||
"""Return the global and local namespaces to be used when evaluating annotations for the provided function.
|
||||
|
||||
The global namespace will be the `__dict__` attribute of the module the function was defined in.
|
||||
The local namespace will contain the `__type_params__` introduced by PEP 695.
|
||||
|
||||
Args:
|
||||
obj: The object to use when building namespaces.
|
||||
parent_namespace: Optional namespace to be added with the lowest priority in the local namespace.
|
||||
If the passed function is a method, the `parent_namespace` will be the namespace of the class
|
||||
the method is defined in. Thus, we also fetch type `__type_params__` from there (i.e. the
|
||||
class-scoped type variables).
|
||||
"""
|
||||
locals_list: list[MappingNamespace] = []
|
||||
if parent_namespace is not None:
|
||||
locals_list.append(parent_namespace)
|
||||
|
||||
# Get the `__type_params__` attribute introduced by PEP 695.
|
||||
# Note that the `typing._eval_type` function expects type params to be
|
||||
# passed as a separate argument. However, internally, `_eval_type` calls
|
||||
# `ForwardRef._evaluate` which will merge type params with the localns,
|
||||
# essentially mimicking what we do here.
|
||||
type_params: tuple[_TypeVarLike, ...] = getattr(obj, '__type_params__', ())
|
||||
if parent_namespace is not None:
|
||||
# We also fetch type params from the parent namespace. If present, it probably
|
||||
# means the function was defined in a class. This is to support the following:
|
||||
# https://github.com/python/cpython/issues/124089.
|
||||
type_params += parent_namespace.get('__type_params__', ())
|
||||
|
||||
locals_list.append({t.__name__: t for t in type_params})
|
||||
|
||||
# What about short-cirtuiting to `obj.__globals__`?
|
||||
globalns = get_module_ns_of(obj)
|
||||
|
||||
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
|
||||
|
||||
|
||||
class NsResolver:
|
||||
"""A class responsible for the namespaces resolving logic for annotations evaluation.
|
||||
|
||||
This class handles the namespace logic when evaluating annotations mainly for class objects.
|
||||
|
||||
It holds a stack of classes that are being inspected during the core schema building,
|
||||
and the `types_namespace` property exposes the globals and locals to be used for
|
||||
type annotation evaluation. Additionally -- if no class is present in the stack -- a
|
||||
fallback globals and locals can be provided using the `namespaces_tuple` argument
|
||||
(this is useful when generating a schema for a simple annotation, e.g. when using
|
||||
`TypeAdapter`).
|
||||
|
||||
The namespace creation logic is unfortunately flawed in some cases, for backwards
|
||||
compatibility reasons and to better support valid edge cases. See the description
|
||||
for the `parent_namespace` argument and the example for more details.
|
||||
|
||||
Args:
|
||||
namespaces_tuple: The default globals and locals to use if no class is present
|
||||
on the stack. This can be useful when using the `GenerateSchema` class
|
||||
with `TypeAdapter`, where the "type" being analyzed is a simple annotation.
|
||||
parent_namespace: An optional parent namespace that will be added to the locals
|
||||
with the lowest priority. For a given class defined in a function, the locals
|
||||
of this function are usually used as the parent namespace:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
from pydantic import BaseModel
|
||||
|
||||
def func() -> None:
|
||||
SomeType = int
|
||||
|
||||
class Model(BaseModel):
|
||||
f: 'SomeType'
|
||||
|
||||
# when collecting fields, an namespace resolver instance will be created
|
||||
# this way:
|
||||
# ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType})
|
||||
```
|
||||
|
||||
For backwards compatibility reasons and to support valid edge cases, this parent
|
||||
namespace will be used for *every* type being pushed to the stack. In the future,
|
||||
we might want to be smarter by only doing so when the type being pushed is defined
|
||||
in the same module as the parent namespace.
|
||||
|
||||
Example:
|
||||
```python {lint="skip" test="skip"}
|
||||
ns_resolver = NsResolver(
|
||||
parent_namespace={'fallback': 1},
|
||||
)
|
||||
|
||||
class Sub:
|
||||
m: 'Model'
|
||||
|
||||
class Model:
|
||||
some_local = 1
|
||||
sub: Sub
|
||||
|
||||
ns_resolver = NsResolver()
|
||||
|
||||
# This is roughly what happens when we build a core schema for `Model`:
|
||||
with ns_resolver.push(Model):
|
||||
ns_resolver.types_namespace
|
||||
#> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1})
|
||||
# First thing to notice here, the model being pushed is added to the locals.
|
||||
# Because `NsResolver` is being used during the model definition, it is not
|
||||
# yet added to the globals. This is useful when resolving self-referencing annotations.
|
||||
|
||||
with ns_resolver.push(Sub):
|
||||
ns_resolver.types_namespace
|
||||
#> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model})
|
||||
# Second thing to notice: `Sub` is present in both the globals and locals.
|
||||
# This is not an issue, just that as described above, the model being pushed
|
||||
# is added to the locals, but it happens to be present in the globals as well
|
||||
# because it is already defined.
|
||||
# Third thing to notice: `Model` is also added in locals. This is a backwards
|
||||
# compatibility workaround that allows for `Sub` to be able to resolve `'Model'`
|
||||
# correctly (as otherwise models would have to be rebuilt even though this
|
||||
# doesn't look necessary).
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
namespaces_tuple: NamespacesTuple | None = None,
|
||||
parent_namespace: MappingNamespace | None = None,
|
||||
) -> None:
|
||||
self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {})
|
||||
self._parent_ns = parent_namespace
|
||||
self._types_stack: list[type[Any] | TypeAliasType] = []
|
||||
|
||||
@cached_property
|
||||
def types_namespace(self) -> NamespacesTuple:
|
||||
"""The current global and local namespaces to be used for annotations evaluation."""
|
||||
if not self._types_stack:
|
||||
# TODO: should we merge the parent namespace here?
|
||||
# This is relevant for TypeAdapter, where there are no types on the stack, and we might
|
||||
# need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing
|
||||
# locals to both parent_ns and the base_ns_tuple, but this is a bit hacky.
|
||||
# we might consider something like:
|
||||
# if self._parent_ns is not None:
|
||||
# # Hacky workarounds, see class docstring:
|
||||
# # An optional parent namespace that will be added to the locals with the lowest priority
|
||||
# locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals]
|
||||
# return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list))
|
||||
return self._base_ns_tuple
|
||||
|
||||
typ = self._types_stack[-1]
|
||||
|
||||
globalns = get_module_ns_of(typ)
|
||||
|
||||
locals_list: list[MappingNamespace] = []
|
||||
# Hacky workarounds, see class docstring:
|
||||
# An optional parent namespace that will be added to the locals with the lowest priority
|
||||
if self._parent_ns is not None:
|
||||
locals_list.append(self._parent_ns)
|
||||
if len(self._types_stack) > 1:
|
||||
first_type = self._types_stack[0]
|
||||
locals_list.append({first_type.__name__: first_type})
|
||||
|
||||
# Adding `__type_params__` *before* `vars(typ)`, as the latter takes priority
|
||||
# (see https://github.com/python/cpython/pull/120272).
|
||||
# TODO `typ.__type_params__` when we drop support for Python 3.11:
|
||||
type_params: tuple[_TypeVarLike, ...] = getattr(typ, '__type_params__', ())
|
||||
if type_params:
|
||||
# Adding `__type_params__` is mostly useful for generic classes defined using
|
||||
# PEP 695 syntax *and* using forward annotations (see the example in
|
||||
# https://github.com/python/cpython/issues/114053). For TypeAliasType instances,
|
||||
# it is way less common, but still required if using a string annotation in the alias
|
||||
# value, e.g. `type A[T] = 'T'` (which is not necessary in most cases).
|
||||
locals_list.append({t.__name__: t for t in type_params})
|
||||
|
||||
# TypeAliasType instances don't have a `__dict__` attribute, so the check
|
||||
# is necessary:
|
||||
if hasattr(typ, '__dict__'):
|
||||
locals_list.append(vars(typ))
|
||||
|
||||
# The `len(self._types_stack) > 1` check above prevents this from being added twice:
|
||||
locals_list.append({typ.__name__: typ})
|
||||
|
||||
return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list))
|
||||
|
||||
@contextmanager
|
||||
def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]:
|
||||
"""Push a type to the stack."""
|
||||
self._types_stack.append(typ)
|
||||
# Reset the cached property:
|
||||
self.__dict__.pop('types_namespace', None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._types_stack.pop()
|
||||
self.__dict__.pop('types_namespace', None)
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Tools to provide pretty/human-readable display of objects."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import types
|
||||
@@ -7,8 +6,6 @@ import typing
|
||||
from typing import Any
|
||||
|
||||
import typing_extensions
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from . import _typing_extra
|
||||
|
||||
@@ -35,7 +32,7 @@ class Representation:
|
||||
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
|
||||
|
||||
# we don't want to use a type annotation here as it can break get_type_hints
|
||||
__slots__ = () # type: typing.Collection[str]
|
||||
__slots__ = tuple() # type: typing.Collection[str]
|
||||
|
||||
def __repr_args__(self) -> ReprArgs:
|
||||
"""Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
|
||||
@@ -48,17 +45,12 @@ class Representation:
|
||||
if not attrs_names and hasattr(self, '__dict__'):
|
||||
attrs_names = self.__dict__.keys()
|
||||
attrs = ((s, getattr(self, s)) for s in attrs_names)
|
||||
return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None]
|
||||
return [(a, v) for a, v in attrs if v is not None]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""Name of the instance's class, used in __repr__."""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_recursion__(self, object: Any) -> str:
|
||||
"""Returns the string representation of a recursive object."""
|
||||
# This is copied over from the stdlib `pprint` module:
|
||||
return f'<Recursion on {type(object).__name__} with id={id(object)}>'
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
|
||||
|
||||
@@ -95,30 +87,25 @@ def display_as_type(obj: Any) -> str:
|
||||
|
||||
Takes some logic from `typing._type_repr`.
|
||||
"""
|
||||
if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)):
|
||||
if isinstance(obj, types.FunctionType):
|
||||
return obj.__name__
|
||||
elif obj is ...:
|
||||
return '...'
|
||||
elif isinstance(obj, Representation):
|
||||
return repr(obj)
|
||||
elif isinstance(obj, typing.ForwardRef) or typing_objects.is_typealiastype(obj):
|
||||
return str(obj)
|
||||
|
||||
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
|
||||
obj = obj.__class__
|
||||
|
||||
if is_union_origin(typing_extensions.get_origin(obj)):
|
||||
if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
|
||||
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
|
||||
return f'Union[{args}]'
|
||||
elif isinstance(obj, _typing_extra.WithArgsTypes):
|
||||
if typing_objects.is_literal(typing_extensions.get_origin(obj)):
|
||||
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
|
||||
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
|
||||
else:
|
||||
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
|
||||
try:
|
||||
return f'{obj.__qualname__}[{args}]'
|
||||
except AttributeError:
|
||||
return str(obj).replace('typing.', '').replace('typing_extensions.', '') # handles TypeAliasType in 3.12
|
||||
return f'{obj.__qualname__}[{args}]'
|
||||
elif isinstance(obj, type):
|
||||
return obj.__qualname__
|
||||
else:
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField'
|
||||
|
||||
|
||||
class GatherResult(TypedDict):
|
||||
"""Schema traversing result."""
|
||||
|
||||
collected_references: dict[str, DefinitionReferenceSchema | None]
|
||||
"""The collected definition references.
|
||||
|
||||
If a definition reference schema can be inlined, it means that there is
|
||||
only one in the whole core schema. As such, it is stored as the value.
|
||||
Otherwise, the value is set to `None`.
|
||||
"""
|
||||
|
||||
deferred_discriminator_schemas: list[CoreSchema]
|
||||
"""The list of core schemas having the discriminator application deferred."""
|
||||
|
||||
|
||||
class MissingDefinitionError(LookupError):
|
||||
"""A reference was pointing to a non-existing core schema."""
|
||||
|
||||
def __init__(self, schema_reference: str, /) -> None:
|
||||
self.schema_reference = schema_reference
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatherContext:
|
||||
"""The current context used during core schema traversing.
|
||||
|
||||
Context instances should only be used during schema traversing.
|
||||
"""
|
||||
|
||||
definitions: dict[str, CoreSchema]
|
||||
"""The available definitions."""
|
||||
|
||||
deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list)
|
||||
"""The list of core schemas having the discriminator application deferred.
|
||||
|
||||
Internally, these core schemas have a specific key set in the core metadata dict.
|
||||
"""
|
||||
|
||||
collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict)
|
||||
"""The collected definition references.
|
||||
|
||||
If a definition reference schema can be inlined, it means that there is
|
||||
only one in the whole core schema. As such, it is stored as the value.
|
||||
Otherwise, the value is set to `None`.
|
||||
|
||||
During schema traversing, definition reference schemas can be added as candidates, or removed
|
||||
(by setting the value to `None`).
|
||||
"""
|
||||
|
||||
|
||||
def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None:
|
||||
meta = schema.get('metadata')
|
||||
if meta is not None and 'pydantic_internal_union_discriminator' in meta:
|
||||
ctx.deferred_discriminator_schemas.append(schema) # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None:
|
||||
schema_ref = def_ref_schema['schema_ref']
|
||||
|
||||
if schema_ref not in ctx.collected_references:
|
||||
definition = ctx.definitions.get(schema_ref)
|
||||
if definition is None:
|
||||
raise MissingDefinitionError(schema_ref)
|
||||
|
||||
# The `'definition-ref'` schema was only encountered once, make it
|
||||
# a candidate to be inlined:
|
||||
ctx.collected_references[schema_ref] = def_ref_schema
|
||||
traverse_schema(definition, ctx)
|
||||
if 'serialization' in def_ref_schema:
|
||||
traverse_schema(def_ref_schema['serialization'], ctx)
|
||||
traverse_metadata(def_ref_schema, ctx)
|
||||
else:
|
||||
# The `'definition-ref'` schema was already encountered, meaning
|
||||
# the previously encountered schema (and this one) can't be inlined:
|
||||
ctx.collected_references[schema_ref] = None
|
||||
|
||||
|
||||
def traverse_schema(schema: AllSchemas, context: GatherContext) -> None:
|
||||
# TODO When we drop 3.9, use a match statement to get better type checking and remove
|
||||
# file-level type ignore.
|
||||
# (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance).
|
||||
schema_type = schema['type']
|
||||
|
||||
if schema_type == 'definition-ref':
|
||||
traverse_definition_ref(schema, context)
|
||||
# `traverse_definition_ref` handles the possible serialization and metadata schemas:
|
||||
return
|
||||
elif schema_type == 'definitions':
|
||||
traverse_schema(schema['schema'], context)
|
||||
for definition in schema['definitions']:
|
||||
traverse_schema(definition, context)
|
||||
elif schema_type in {'list', 'set', 'frozenset', 'generator'}:
|
||||
if 'items_schema' in schema:
|
||||
traverse_schema(schema['items_schema'], context)
|
||||
elif schema_type == 'tuple':
|
||||
if 'items_schema' in schema:
|
||||
for s in schema['items_schema']:
|
||||
traverse_schema(s, context)
|
||||
elif schema_type == 'dict':
|
||||
if 'keys_schema' in schema:
|
||||
traverse_schema(schema['keys_schema'], context)
|
||||
if 'values_schema' in schema:
|
||||
traverse_schema(schema['values_schema'], context)
|
||||
elif schema_type == 'union':
|
||||
for choice in schema['choices']:
|
||||
if isinstance(choice, tuple):
|
||||
traverse_schema(choice[0], context)
|
||||
else:
|
||||
traverse_schema(choice, context)
|
||||
elif schema_type == 'tagged-union':
|
||||
for v in schema['choices'].values():
|
||||
traverse_schema(v, context)
|
||||
elif schema_type == 'chain':
|
||||
for step in schema['steps']:
|
||||
traverse_schema(step, context)
|
||||
elif schema_type == 'lax-or-strict':
|
||||
traverse_schema(schema['lax_schema'], context)
|
||||
traverse_schema(schema['strict_schema'], context)
|
||||
elif schema_type == 'json-or-python':
|
||||
traverse_schema(schema['json_schema'], context)
|
||||
traverse_schema(schema['python_schema'], context)
|
||||
elif schema_type in {'model-fields', 'typed-dict'}:
|
||||
if 'extras_schema' in schema:
|
||||
traverse_schema(schema['extras_schema'], context)
|
||||
if 'computed_fields' in schema:
|
||||
for s in schema['computed_fields']:
|
||||
traverse_schema(s, context)
|
||||
for s in schema['fields'].values():
|
||||
traverse_schema(s, context)
|
||||
elif schema_type == 'dataclass-args':
|
||||
if 'computed_fields' in schema:
|
||||
for s in schema['computed_fields']:
|
||||
traverse_schema(s, context)
|
||||
for s in schema['fields']:
|
||||
traverse_schema(s, context)
|
||||
elif schema_type == 'arguments':
|
||||
for s in schema['arguments_schema']:
|
||||
traverse_schema(s['schema'], context)
|
||||
if 'var_args_schema' in schema:
|
||||
traverse_schema(schema['var_args_schema'], context)
|
||||
if 'var_kwargs_schema' in schema:
|
||||
traverse_schema(schema['var_kwargs_schema'], context)
|
||||
elif schema_type == 'arguments-v3':
|
||||
for s in schema['arguments_schema']:
|
||||
traverse_schema(s['schema'], context)
|
||||
elif schema_type == 'call':
|
||||
traverse_schema(schema['arguments_schema'], context)
|
||||
if 'return_schema' in schema:
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
elif schema_type == 'computed-field':
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
elif schema_type == 'function-before':
|
||||
if 'schema' in schema:
|
||||
traverse_schema(schema['schema'], context)
|
||||
if 'json_schema_input_schema' in schema:
|
||||
traverse_schema(schema['json_schema_input_schema'], context)
|
||||
elif schema_type == 'function-plain':
|
||||
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
|
||||
if 'return_schema' in schema:
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
if 'json_schema_input_schema' in schema:
|
||||
traverse_schema(schema['json_schema_input_schema'], context)
|
||||
elif schema_type == 'function-wrap':
|
||||
# TODO duplicate schema types for serializers and validators, needs to be deduplicated.
|
||||
if 'return_schema' in schema:
|
||||
traverse_schema(schema['return_schema'], context)
|
||||
if 'schema' in schema:
|
||||
traverse_schema(schema['schema'], context)
|
||||
if 'json_schema_input_schema' in schema:
|
||||
traverse_schema(schema['json_schema_input_schema'], context)
|
||||
else:
|
||||
if 'schema' in schema:
|
||||
traverse_schema(schema['schema'], context)
|
||||
|
||||
if 'serialization' in schema:
|
||||
traverse_schema(schema['serialization'], context)
|
||||
traverse_metadata(schema, context)
|
||||
|
||||
|
||||
def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult:
|
||||
"""Traverse the core schema and definitions and return the necessary information for schema cleaning.
|
||||
|
||||
During the core schema traversing, any `'definition-ref'` schema is:
|
||||
|
||||
- Validated: the reference must point to an existing definition. If this is not the case, a
|
||||
`MissingDefinitionError` exception is raised.
|
||||
- Stored in the context: the actual reference is stored in the context. Depending on whether
|
||||
the `'definition-ref'` schema is encountered more that once, the schema itself is also
|
||||
saved in the context to be inlined (i.e. replaced by the definition it points to).
|
||||
"""
|
||||
context = GatherContext(definitions)
|
||||
traverse_schema(schema, context)
|
||||
|
||||
return {
|
||||
'collected_references': context.collected_references,
|
||||
'deferred_discriminator_schemas': context.deferred_discriminator_schemas,
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Types and utility functions used by various other internal tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import core_schema
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
|
||||
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
|
||||
from ._core_utils import CoreSchemaOrField
|
||||
from ._generate_schema import GenerateSchema
|
||||
from ._namespace_utils import NamespacesTuple
|
||||
|
||||
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
|
||||
@@ -33,8 +32,8 @@ class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
|
||||
self.handler = handler_override or generate_json_schema.generate_inner
|
||||
self.mode = generate_json_schema.mode
|
||||
|
||||
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
|
||||
return self.handler(core_schema)
|
||||
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
|
||||
return self.handler(__core_schema)
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
|
||||
"""Resolves `$ref` in the json schema.
|
||||
@@ -79,21 +78,22 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
|
||||
self._generate_schema = generate_schema
|
||||
self._ref_mode = ref_mode
|
||||
|
||||
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
schema = self._handler(source_type)
|
||||
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
|
||||
schema = self._handler(__source_type)
|
||||
ref = schema.get('ref')
|
||||
if self._ref_mode == 'to-def':
|
||||
ref = schema.get('ref')
|
||||
if ref is not None:
|
||||
return self._generate_schema.defs.create_definition_reference_schema(schema)
|
||||
self._generate_schema.defs.definitions[ref] = schema
|
||||
return core_schema.definition_reference_schema(ref)
|
||||
return schema
|
||||
else: # ref_mode = 'unpack'
|
||||
else: # ref_mode = 'unpack
|
||||
return self.resolve_ref_schema(schema)
|
||||
|
||||
def _get_types_namespace(self) -> NamespacesTuple:
|
||||
def _get_types_namespace(self) -> dict[str, Any] | None:
|
||||
return self._generate_schema._types_namespace
|
||||
|
||||
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
return self._generate_schema.generate_schema(source_type)
|
||||
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
|
||||
return self._generate_schema.generate_schema(__source_type)
|
||||
|
||||
@property
|
||||
def field_name(self) -> str | None:
|
||||
@@ -113,13 +113,12 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
|
||||
"""
|
||||
if maybe_ref_schema['type'] == 'definition-ref':
|
||||
ref = maybe_ref_schema['schema_ref']
|
||||
definition = self._generate_schema.defs.get_schema_from_ref(ref)
|
||||
if definition is None:
|
||||
if ref not in self._generate_schema.defs.definitions:
|
||||
raise LookupError(
|
||||
f'Could not find a ref for {ref}.'
|
||||
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
|
||||
)
|
||||
return definition
|
||||
return self._generate_schema.defs.definitions[ref]
|
||||
elif maybe_ref_schema['type'] == 'definitions':
|
||||
return self.resolve_ref_schema(maybe_ref_schema['schema'])
|
||||
return maybe_ref_schema
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import typing
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticOmit, core_schema
|
||||
|
||||
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.Deque: collections.deque, # noqa: UP006
|
||||
collections.deque: collections.deque,
|
||||
list: list,
|
||||
typing.List: list, # noqa: UP006
|
||||
tuple: tuple,
|
||||
typing.Tuple: tuple, # noqa: UP006
|
||||
set: set,
|
||||
typing.AbstractSet: set,
|
||||
typing.Set: set, # noqa: UP006
|
||||
frozenset: frozenset,
|
||||
typing.FrozenSet: frozenset, # noqa: UP006
|
||||
typing.Sequence: list,
|
||||
typing.MutableSequence: list,
|
||||
typing.MutableSet: set,
|
||||
# this doesn't handle subclasses of these
|
||||
# parametrized typing.Set creates one of these
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
}
|
||||
|
||||
|
||||
def serialize_sequence_via_list(
|
||||
v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
|
||||
) -> Any:
|
||||
items: list[Any] = []
|
||||
|
||||
mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None)
|
||||
if mapped_origin is None:
|
||||
# we shouldn't hit this branch, should probably add a serialization error or something
|
||||
return v
|
||||
|
||||
for index, item in enumerate(v):
|
||||
try:
|
||||
v = handler(item, index)
|
||||
except PydanticOmit:
|
||||
pass
|
||||
else:
|
||||
items.append(v)
|
||||
|
||||
if info.mode_is_json():
|
||||
return items
|
||||
else:
|
||||
return mapped_origin(items)
|
||||
@@ -1,188 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from ._utils import is_valid_identifier
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..config import ExtraValues
|
||||
from ..fields import FieldInfo
|
||||
|
||||
|
||||
# Copied over from stdlib dataclasses
|
||||
class _HAS_DEFAULT_FACTORY_CLASS:
|
||||
def __repr__(self):
|
||||
return '<factory>'
|
||||
|
||||
|
||||
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
|
||||
|
||||
|
||||
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
|
||||
"""Extract the correct name to use for the field when generating a signature.
|
||||
|
||||
Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
|
||||
First priority is given to the alias, then the validation_alias, then the field name.
|
||||
|
||||
Args:
|
||||
field_name: The name of the field
|
||||
field_info: The corresponding FieldInfo object.
|
||||
|
||||
Returns:
|
||||
The correct name to use when generating a signature.
|
||||
"""
|
||||
if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias):
|
||||
return field_info.alias
|
||||
if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias):
|
||||
return field_info.validation_alias
|
||||
|
||||
return field_name
|
||||
|
||||
|
||||
def _process_param_defaults(param: Parameter) -> Parameter:
|
||||
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
|
||||
|
||||
Args:
|
||||
param (Parameter): The parameter
|
||||
|
||||
Returns:
|
||||
Parameter: The custom processed parameter
|
||||
"""
|
||||
from ..fields import FieldInfo
|
||||
|
||||
param_default = param.default
|
||||
if isinstance(param_default, FieldInfo):
|
||||
annotation = param.annotation
|
||||
# Replace the annotation if appropriate
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if annotation == 'Any':
|
||||
annotation = Any
|
||||
|
||||
# Replace the field default
|
||||
default = param_default.default
|
||||
if default is PydanticUndefined:
|
||||
if param_default.default_factory is PydanticUndefined:
|
||||
default = Signature.empty
|
||||
else:
|
||||
# this is used by dataclasses to indicate a factory exists:
|
||||
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
|
||||
return param.replace(
|
||||
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
|
||||
)
|
||||
return param
|
||||
|
||||
|
||||
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
validate_by_name: bool,
|
||||
extra: ExtraValues | None,
|
||||
) -> dict[str, Parameter]:
|
||||
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
|
||||
from itertools import islice
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
# inspect does "clever" things to show annotations as strings because we have
|
||||
# `from __future__ import annotations` in main, we don't want that
|
||||
if fields.get(param.name):
|
||||
# exclude params with init=False
|
||||
if getattr(fields[param.name], 'init', True) is False:
|
||||
continue
|
||||
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
|
||||
if param.annotation == 'Any':
|
||||
param = param.replace(annotation=Any)
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = validate_by_name
|
||||
for field_name, field in fields.items():
|
||||
# when alias is a str it should be used for signature generation
|
||||
param_name = _field_name_for_signature(field_name, field)
|
||||
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
|
||||
if not is_valid_identifier(param_name):
|
||||
if allow_names:
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
if field.is_required():
|
||||
default = Parameter.empty
|
||||
elif field.default_factory is not None:
|
||||
# Mimics stdlib dataclasses:
|
||||
default = _HAS_DEFAULT_FACTORY
|
||||
else:
|
||||
default = field.default
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name,
|
||||
Parameter.KEYWORD_ONLY,
|
||||
annotation=field.rebuild_annotation(),
|
||||
default=default,
|
||||
)
|
||||
|
||||
if extra == 'allow':
|
||||
use_var_kw = True
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('self', Parameter.POSITIONAL_ONLY),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
return merged_params
|
||||
|
||||
|
||||
def generate_pydantic_signature(
|
||||
init: Callable[..., None],
|
||||
fields: dict[str, FieldInfo],
|
||||
validate_by_name: bool,
|
||||
extra: ExtraValues | None,
|
||||
is_dataclass: bool = False,
|
||||
) -> Signature:
|
||||
"""Generate signature for a pydantic BaseModel or dataclass.
|
||||
|
||||
Args:
|
||||
init: The class init.
|
||||
fields: The model fields.
|
||||
validate_by_name: The `validate_by_name` value of the config.
|
||||
extra: The `extra` value of the config.
|
||||
is_dataclass: Whether the model is a dataclass.
|
||||
|
||||
Returns:
|
||||
The dataclass/BaseModel subclass signature.
|
||||
"""
|
||||
merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra)
|
||||
|
||||
if is_dataclass:
|
||||
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
@@ -0,0 +1,713 @@
|
||||
"""Logic for generating pydantic-core schemas for standard library types.
|
||||
|
||||
Import of this module is deferred since it contains imports of many standard library modules.
|
||||
"""
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections
|
||||
import collections.abc
|
||||
import dataclasses
|
||||
import decimal
|
||||
import inspect
|
||||
import os
|
||||
import typing
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import Any, Callable, Iterable, TypeVar
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import (
|
||||
CoreSchema,
|
||||
MultiHostUrl,
|
||||
PydanticCustomError,
|
||||
PydanticOmit,
|
||||
Url,
|
||||
core_schema,
|
||||
)
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic.types import Strict
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..json_schema import JsonSchemaValue, update_json_schema
|
||||
from . import _known_annotated_metadata, _typing_extra, _validators
|
||||
from ._core_utils import get_type_ref
|
||||
from ._internal_dataclass import slots_true
|
||||
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from ._generate_schema import GenerateSchema
|
||||
|
||||
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class SchemaTransformer:
|
||||
get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema]
|
||||
get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue]
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.get_core_schema(source_type, handler)
|
||||
|
||||
def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
return self.get_json_schema(schema, handler)
|
||||
|
||||
|
||||
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
|
||||
cases: list[Any] = list(enum_type.__members__.values())
|
||||
|
||||
if not cases:
|
||||
# Use an isinstance check for enums with no cases.
|
||||
# This won't work with serialization or JSON schema, but that's okay -- the most important
|
||||
# use case for this is creating typevar bounds for generics that should be restricted to enums.
|
||||
# This is more consistent than it might seem at first, since you can only subclass enum.Enum
|
||||
# (or subclasses of enum.Enum) if all parent classes have no cases.
|
||||
return core_schema.is_instance_schema(enum_type)
|
||||
|
||||
use_enum_values = config.get('use_enum_values', False)
|
||||
|
||||
if len(cases) == 1:
|
||||
expected = repr(cases[0].value)
|
||||
else:
|
||||
expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}'
|
||||
|
||||
def to_enum(__input_value: Any) -> Enum:
|
||||
try:
|
||||
enum_field = enum_type(__input_value)
|
||||
if use_enum_values:
|
||||
return enum_field.value
|
||||
return enum_field
|
||||
except ValueError:
|
||||
# The type: ignore on the next line is to ignore the requirement of LiteralString
|
||||
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
|
||||
|
||||
enum_ref = get_type_ref(enum_type)
|
||||
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
|
||||
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
|
||||
description = None
|
||||
updates = {'title': enum_type.__name__, 'description': description}
|
||||
updates = {k: v for k, v in updates.items() if v is not None}
|
||||
|
||||
def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
|
||||
original_schema = handler.resolve_ref_schema(json_schema)
|
||||
update_json_schema(original_schema, updates)
|
||||
return json_schema
|
||||
|
||||
strict_python_schema = core_schema.is_instance_schema(enum_type)
|
||||
if use_enum_values:
|
||||
strict_python_schema = core_schema.chain_schema(
|
||||
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
|
||||
)
|
||||
|
||||
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
|
||||
if issubclass(enum_type, int):
|
||||
# this handles `IntEnum`, and also `Foobar(int, Enum)`
|
||||
updates['type'] = 'integer'
|
||||
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
|
||||
# Disallow float from JSON due to strict mode
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
elif issubclass(enum_type, str):
|
||||
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
|
||||
updates['type'] = 'string'
|
||||
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
elif issubclass(enum_type, float):
|
||||
updates['type'] = 'numeric'
|
||||
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
|
||||
strict = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
|
||||
python_schema=strict_python_schema,
|
||||
)
|
||||
else:
|
||||
lax = to_enum_validator
|
||||
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
|
||||
return core_schema.lax_or_strict_schema(
|
||||
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class InnerSchemaValidator:
|
||||
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
|
||||
|
||||
core_schema: CoreSchema
|
||||
js_schema: JsonSchemaValue | None = None
|
||||
js_core_schema: CoreSchema | None = None
|
||||
js_schema_update: JsonSchemaValue | None = None
|
||||
|
||||
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
if self.js_schema is not None:
|
||||
return self.js_schema
|
||||
js_schema = handler(self.js_core_schema or self.core_schema)
|
||||
if self.js_schema_update is not None:
|
||||
js_schema.update(self.js_schema_update)
|
||||
return js_schema
|
||||
|
||||
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
return self.core_schema
|
||||
|
||||
|
||||
def decimal_prepare_pydantic_annotations(
|
||||
source: Any, annotations: Iterable[Any], config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
if source is not decimal.Decimal:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
|
||||
config_allow_inf_nan = config.get('allow_inf_nan')
|
||||
if config_allow_inf_nan is not None:
|
||||
metadata.setdefault('allow_inf_nan', config_allow_inf_nan)
|
||||
|
||||
_known_annotated_metadata.check_metadata(
|
||||
metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal
|
||||
)
|
||||
return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations]
|
||||
|
||||
|
||||
def datetime_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import datetime
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
if source_type is datetime.date:
|
||||
sv = InnerSchemaValidator(core_schema.date_schema(**metadata))
|
||||
elif source_type is datetime.datetime:
|
||||
sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata))
|
||||
elif source_type is datetime.time:
|
||||
sv = InnerSchemaValidator(core_schema.time_schema(**metadata))
|
||||
elif source_type is datetime.timedelta:
|
||||
sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata))
|
||||
else:
|
||||
return None
|
||||
# check now that we know the source type is correct
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type)
|
||||
return (source_type, [sv, *remaining_annotations])
|
||||
|
||||
|
||||
def uuid_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
# UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
if source_type is not UUID:
|
||||
return None
|
||||
|
||||
return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations])
|
||||
|
||||
|
||||
def path_schema_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
import pathlib
|
||||
|
||||
if source_type not in {
|
||||
os.PathLike,
|
||||
pathlib.Path,
|
||||
pathlib.PurePath,
|
||||
pathlib.PosixPath,
|
||||
pathlib.PurePosixPath,
|
||||
pathlib.PureWindowsPath,
|
||||
}:
|
||||
return None
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type)
|
||||
|
||||
construct_path = pathlib.PurePath if source_type is os.PathLike else source_type
|
||||
|
||||
def path_validator(input_value: str) -> os.PathLike[Any]:
|
||||
try:
|
||||
return construct_path(input_value)
|
||||
except TypeError as e:
|
||||
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
|
||||
|
||||
constrained_str_schema = core_schema.str_schema(**metadata)
|
||||
|
||||
instance_schema = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
|
||||
python_schema=core_schema.is_instance_schema(source_type),
|
||||
)
|
||||
|
||||
strict: bool | None = None
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, Strict):
|
||||
strict = annotation.strict
|
||||
|
||||
schema = core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.union_schema(
|
||||
[
|
||||
instance_schema,
|
||||
core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
|
||||
],
|
||||
custom_error_type='path_type',
|
||||
custom_error_message='Input is not a valid path',
|
||||
strict=True,
|
||||
),
|
||||
strict_schema=instance_schema,
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def dequeue_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
|
||||
) -> collections.deque[Any]:
|
||||
if isinstance(input_value, collections.deque):
|
||||
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
|
||||
if maxlens:
|
||||
maxlen = min(maxlens)
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
else:
|
||||
return collections.deque(handler(input_value), maxlen=maxlen)
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class SequenceValidator:
|
||||
mapped_origin: type[Any]
|
||||
item_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_sequence_via_list(
|
||||
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
|
||||
) -> Any:
|
||||
items: list[Any] = []
|
||||
for index, item in enumerate(v):
|
||||
try:
|
||||
v = handler(item, index)
|
||||
except PydanticOmit:
|
||||
pass
|
||||
else:
|
||||
items.append(v)
|
||||
|
||||
if info.mode_is_json():
|
||||
return items
|
||||
else:
|
||||
return self.mapped_origin(items)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.item_source_type is Any:
|
||||
items_schema = None
|
||||
else:
|
||||
items_schema = handler.generate_schema(self.item_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin in (list, set, frozenset):
|
||||
if self.mapped_origin is list:
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata)
|
||||
elif self.mapped_origin is set:
|
||||
constrained_schema = core_schema.set_schema(items_schema, **metadata)
|
||||
else:
|
||||
assert self.mapped_origin is frozenset # safety check in case we forget to add a case
|
||||
constrained_schema = core_schema.frozenset_schema(items_schema, **metadata)
|
||||
|
||||
schema = constrained_schema
|
||||
else:
|
||||
# safety check in case we forget to add a case
|
||||
assert self.mapped_origin in (collections.deque, collections.Counter)
|
||||
|
||||
if self.mapped_origin is collections.deque:
|
||||
# if we have a MaxLen annotation might as well set that as the default maxlen on the deque
|
||||
# this lets us re-use existing metadata annotations to let users set the maxlen on a dequeue
|
||||
# that e.g. comes from JSON
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(dequeue_validator, maxlen=metadata.get('max_length', None)),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
constrained_schema = core_schema.list_schema(items_schema, **metadata)
|
||||
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.list_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.Deque: collections.deque,
|
||||
collections.deque: collections.deque,
|
||||
list: list,
|
||||
typing.List: list,
|
||||
set: set,
|
||||
typing.AbstractSet: set,
|
||||
typing.Set: set,
|
||||
frozenset: frozenset,
|
||||
typing.FrozenSet: frozenset,
|
||||
typing.Sequence: list,
|
||||
typing.MutableSequence: list,
|
||||
typing.MutableSet: set,
|
||||
# this doesn't handle subclasses of these
|
||||
# parametrized typing.Set creates one of these
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
}
|
||||
|
||||
|
||||
def identity(s: CoreSchema) -> CoreSchema:
|
||||
return s
|
||||
|
||||
|
||||
def sequence_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = (Any,)
|
||||
elif len(args) != 1:
|
||||
raise ValueError('Expected sequence to have exactly 1 generic parameter')
|
||||
|
||||
item_source_type = args[0]
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations])
|
||||
|
||||
|
||||
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.DefaultDict: collections.defaultdict,
|
||||
collections.defaultdict: collections.defaultdict,
|
||||
collections.OrderedDict: collections.OrderedDict,
|
||||
typing_extensions.OrderedDict: collections.OrderedDict,
|
||||
dict: dict,
|
||||
typing.Dict: dict,
|
||||
collections.Counter: collections.Counter,
|
||||
typing.Counter: collections.Counter,
|
||||
# this doesn't handle subclasses of these
|
||||
typing.Mapping: dict,
|
||||
typing.MutableMapping: dict,
|
||||
# parametrized typing.{Mutable}Mapping creates one of these
|
||||
collections.abc.MutableMapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
}
|
||||
|
||||
|
||||
def defaultdict_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
|
||||
) -> collections.defaultdict[Any, Any]:
|
||||
if isinstance(input_value, collections.defaultdict):
|
||||
default_factory = input_value.default_factory
|
||||
return collections.defaultdict(default_factory, handler(input_value))
|
||||
else:
|
||||
return collections.defaultdict(default_default_factory, handler(input_value))
|
||||
|
||||
|
||||
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
|
||||
def infer_default() -> Callable[[], Any]:
|
||||
allowed_default_types: dict[Any, Any] = {
|
||||
typing.Tuple: tuple,
|
||||
tuple: tuple,
|
||||
collections.abc.Sequence: tuple,
|
||||
collections.abc.MutableSequence: list,
|
||||
typing.List: list,
|
||||
list: list,
|
||||
typing.Sequence: list,
|
||||
typing.Set: set,
|
||||
set: set,
|
||||
typing.MutableSet: set,
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
typing.MutableMapping: dict,
|
||||
typing.Mapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
collections.abc.MutableMapping: dict,
|
||||
float: float,
|
||||
int: int,
|
||||
str: str,
|
||||
bool: bool,
|
||||
}
|
||||
values_type_origin = get_origin(values_source_type) or values_source_type
|
||||
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
|
||||
if isinstance(values_type_origin, TypeVar):
|
||||
|
||||
def type_var_default_factory() -> None:
|
||||
raise RuntimeError(
|
||||
'Generic defaultdict cannot be used without a concrete value type or an'
|
||||
' explicit default factory, ' + instructions
|
||||
)
|
||||
|
||||
return type_var_default_factory
|
||||
elif values_type_origin not in allowed_default_types:
|
||||
# a somewhat subjective set of types that have reasonable default values
|
||||
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
|
||||
raise PydanticSchemaGenerationError(
|
||||
f'Unable to infer a default factory for keys of type {values_source_type}.'
|
||||
f' Only {allowed_msg} are supported, other types require an explicit default factory'
|
||||
' ' + instructions
|
||||
)
|
||||
return allowed_default_types[values_type_origin]
|
||||
|
||||
# Assume Annotated[..., Field(...)]
|
||||
if _typing_extra.is_annotated(values_source_type):
|
||||
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
|
||||
else:
|
||||
field_info = None
|
||||
if field_info and field_info.default_factory:
|
||||
default_default_factory = field_info.default_factory
|
||||
else:
|
||||
default_default_factory = infer_default()
|
||||
return default_default_factory
|
||||
|
||||
|
||||
@dataclasses.dataclass(**slots_true)
|
||||
class MappingValidator:
|
||||
mapped_origin: type[Any]
|
||||
keys_source_type: type[Any]
|
||||
values_source_type: type[Any]
|
||||
min_length: int | None = None
|
||||
max_length: int | None = None
|
||||
strict: bool = False
|
||||
|
||||
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
|
||||
return handler(v)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
|
||||
if self.keys_source_type is Any:
|
||||
keys_schema = None
|
||||
else:
|
||||
keys_schema = handler.generate_schema(self.keys_source_type)
|
||||
if self.values_source_type is Any:
|
||||
values_schema = None
|
||||
else:
|
||||
values_schema = handler.generate_schema(self.values_source_type)
|
||||
|
||||
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
|
||||
|
||||
if self.mapped_origin is dict:
|
||||
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
else:
|
||||
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
|
||||
check_instance = core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.dict_schema(),
|
||||
python_schema=core_schema.is_instance_schema(self.mapped_origin),
|
||||
)
|
||||
|
||||
if self.mapped_origin is collections.defaultdict:
|
||||
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
|
||||
coerce_instance_wrap = partial(
|
||||
core_schema.no_info_wrap_validator_function,
|
||||
partial(defaultdict_validator, default_default_factory=default_default_factory),
|
||||
)
|
||||
else:
|
||||
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
|
||||
|
||||
serialization = core_schema.wrap_serializer_function_ser_schema(
|
||||
self.serialize_mapping_via_dict,
|
||||
schema=core_schema.dict_schema(
|
||||
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
|
||||
),
|
||||
info_arg=False,
|
||||
)
|
||||
|
||||
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
|
||||
|
||||
if metadata.get('strict', False):
|
||||
schema = strict
|
||||
else:
|
||||
lax = coerce_instance_wrap(constrained_schema)
|
||||
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
|
||||
schema['serialization'] = serialization
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def mapping_like_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
origin: Any = get_origin(source_type)
|
||||
|
||||
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
|
||||
if mapped_origin is None:
|
||||
return None
|
||||
|
||||
args = get_args(source_type)
|
||||
|
||||
if not args:
|
||||
args = (Any, Any)
|
||||
elif mapped_origin is collections.Counter:
|
||||
# a single generic
|
||||
if len(args) != 1:
|
||||
raise ValueError('Expected Counter to have exactly 1 generic parameter')
|
||||
args = (args[0], int) # keys are always an int
|
||||
elif len(args) != 2:
|
||||
raise ValueError('Expected mapping to have exactly 2 generic parameters')
|
||||
|
||||
keys_source_type, values_source_type = args
|
||||
|
||||
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
|
||||
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
|
||||
|
||||
return (
|
||||
source_type,
|
||||
[
|
||||
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
|
||||
*remaining_annotations,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def ip_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
def make_strict_ip_schema(tp: type[Any]) -> CoreSchema:
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()),
|
||||
python_schema=core_schema.is_instance_schema(tp),
|
||||
)
|
||||
|
||||
if source_type is IPv4Address:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv4Address),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv4'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv4Network:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv4Network),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv4network'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv4Interface:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv4Interface),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
|
||||
if source_type is IPv6Address:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv6Address),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv6'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv6Network:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv6Network),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv6network'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is IPv6Interface:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.lax_or_strict_schema(
|
||||
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator),
|
||||
strict_schema=make_strict_ip_schema(IPv6Interface),
|
||||
serialization=core_schema.to_string_ser_schema(),
|
||||
),
|
||||
lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'},
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def url_prepare_pydantic_annotations(
|
||||
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
|
||||
) -> tuple[Any, list[Any]] | None:
|
||||
if source_type is Url:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.url_schema(),
|
||||
lambda cs, handler: handler(cs),
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
if source_type is MultiHostUrl:
|
||||
return source_type, [
|
||||
SchemaTransformer(
|
||||
lambda _1, _2: core_schema.multi_host_url_schema(),
|
||||
lambda cs, handler: handler(cs),
|
||||
),
|
||||
*annotations,
|
||||
]
|
||||
|
||||
|
||||
PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
|
||||
decimal_prepare_pydantic_annotations,
|
||||
sequence_like_prepare_pydantic_annotations,
|
||||
datetime_prepare_pydantic_annotations,
|
||||
uuid_prepare_pydantic_annotations,
|
||||
path_schema_prepare_pydantic_annotations,
|
||||
mapping_like_prepare_pydantic_annotations,
|
||||
ip_prepare_pydantic_annotations,
|
||||
url_prepare_pydantic_annotations,
|
||||
)
|
||||
@@ -1,544 +1,244 @@
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module."""
|
||||
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc
|
||||
import re
|
||||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
import typing
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, cast
|
||||
from types import GetSetDescriptorType
|
||||
from typing import TYPE_CHECKING, Any, ForwardRef
|
||||
|
||||
import typing_extensions
|
||||
from typing_extensions import deprecated, get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
from typing_extensions import Annotated, Final, Literal, TypeAliasType, TypeGuard, get_args, get_origin
|
||||
|
||||
from pydantic.version import version_short
|
||||
if TYPE_CHECKING:
|
||||
from ._dataclasses import StandardDataclass
|
||||
|
||||
try:
|
||||
from typing import _TypingBase # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
from typing import _Final as _TypingBase # type: ignore[attr-defined]
|
||||
|
||||
typing_base = _TypingBase
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
TypingGenericAlias = ()
|
||||
else:
|
||||
from typing import GenericAlias as TypingGenericAlias # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import NotRequired, Required
|
||||
else:
|
||||
from typing import NotRequired, Required # noqa: F401
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union
|
||||
|
||||
WithArgsTypes = (TypingGenericAlias,)
|
||||
|
||||
else:
|
||||
|
||||
def origin_is_union(tp: type[Any] | None) -> bool:
|
||||
return tp is typing.Union or tp is types.UnionType
|
||||
|
||||
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
|
||||
|
||||
from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
NoneType = type(None)
|
||||
EllipsisType = type(Ellipsis)
|
||||
else:
|
||||
from types import EllipsisType as EllipsisType
|
||||
from types import NoneType as NoneType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
# As per https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types,
|
||||
# always check for both `typing` and `typing_extensions` variants of a typing construct.
|
||||
# (this is implemented differently than the suggested approach in the `typing_extensions`
|
||||
# docs for performance).
|
||||
LITERAL_TYPES: set[Any] = {Literal}
|
||||
if hasattr(typing, 'Literal'):
|
||||
LITERAL_TYPES.add(typing.Literal) # type: ignore
|
||||
|
||||
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
|
||||
|
||||
|
||||
_t_annotated = typing.Annotated
|
||||
_te_annotated = typing_extensions.Annotated
|
||||
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
|
||||
|
||||
|
||||
def is_annotated(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Annotated` special form.
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_annotated(Annotated[int, ...])
|
||||
#> True
|
||||
```
|
||||
|
||||
def is_callable_type(type_: type[Any]) -> bool:
|
||||
return type_ is Callable or get_origin(type_) is Callable
|
||||
|
||||
|
||||
def is_literal_type(type_: type[Any]) -> bool:
|
||||
return Literal is not None and get_origin(type_) in LITERAL_TYPES
|
||||
|
||||
|
||||
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
|
||||
return get_args(type_)
|
||||
|
||||
|
||||
def all_literal_values(type_: type[Any]) -> list[Any]:
|
||||
"""This method is used to retrieve all Literal values as
|
||||
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
|
||||
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
return origin is _t_annotated or origin is _te_annotated
|
||||
if not is_literal_type(type_):
|
||||
return [type_]
|
||||
|
||||
values = literal_values(type_)
|
||||
return list(x for value in values for x in all_literal_values(value))
|
||||
|
||||
|
||||
def annotated_type(tp: Any, /) -> Any | None:
|
||||
"""Return the type of the `Annotated` special form, or `None`."""
|
||||
return tp.__origin__ if typing_objects.is_annotated(get_origin(tp)) else None
|
||||
def is_annotated(ann_type: Any) -> bool:
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
origin = get_origin(ann_type)
|
||||
return origin is not None and lenient_issubclass(origin, Annotated)
|
||||
|
||||
|
||||
def unpack_type(tp: Any, /) -> Any | None:
|
||||
"""Return the type wrapped by the `Unpack` special form, or `None`."""
|
||||
return get_args(tp)[0] if typing_objects.is_unpack(get_origin(tp)) else None
|
||||
|
||||
|
||||
def is_hashable(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is the `Hashable` class.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_hashable(Hashable)
|
||||
#> True
|
||||
```
|
||||
def is_namedtuple(type_: type[Any]) -> bool:
|
||||
"""Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
|
||||
"""
|
||||
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
|
||||
# hence the second check:
|
||||
return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable
|
||||
from ._utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
|
||||
|
||||
def is_callable(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Callable`, parametrized or not.
|
||||
test_new_type = typing.NewType('test_new_type', str)
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_callable(Callable[[int], str])
|
||||
#> True
|
||||
is_callable(typing.Callable)
|
||||
#> True
|
||||
is_callable(collections.abc.Callable)
|
||||
#> True
|
||||
```
|
||||
|
||||
def is_new_type(type_: type[Any]) -> bool:
|
||||
"""Check whether type_ was created using typing.NewType.
|
||||
|
||||
Can't use isinstance because it fails <3.10.
|
||||
"""
|
||||
# `get_origin` is documented as normalizing any typing-module aliases to `collections` classes,
|
||||
# hence the second check:
|
||||
return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable
|
||||
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
|
||||
|
||||
|
||||
_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')
|
||||
def _check_classvar(v: type[Any] | None) -> bool:
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
|
||||
|
||||
|
||||
def is_classvar_annotation(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument represents a class variable annotation.
|
||||
|
||||
Although not explicitly stated by the typing specification, `ClassVar` can be used
|
||||
inside `Annotated` and as such, this function checks for this specific scenario.
|
||||
|
||||
Because this function is used to detect class variables before evaluating forward references
|
||||
(or because evaluation failed), we also implement a naive regex match implementation. This is
|
||||
required because class variables are inspected before fields are collected, so we try to be
|
||||
as accurate as possible.
|
||||
"""
|
||||
if typing_objects.is_classvar(tp):
|
||||
def is_classvar(ann_type: type[Any]) -> bool:
|
||||
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
|
||||
return True
|
||||
|
||||
origin = get_origin(tp)
|
||||
|
||||
if typing_objects.is_classvar(origin):
|
||||
return True
|
||||
|
||||
if typing_objects.is_annotated(origin):
|
||||
annotated_type = tp.__origin__
|
||||
if typing_objects.is_classvar(annotated_type) or typing_objects.is_classvar(get_origin(annotated_type)):
|
||||
return True
|
||||
|
||||
str_ann: str | None = None
|
||||
if isinstance(tp, typing.ForwardRef):
|
||||
str_ann = tp.__forward_arg__
|
||||
if isinstance(tp, str):
|
||||
str_ann = tp
|
||||
|
||||
if str_ann is not None and _classvar_re.match(str_ann):
|
||||
# stdlib dataclasses do something similar, although a bit more advanced
|
||||
# (see `dataclass._is_type`).
|
||||
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
|
||||
# forward references, see #3679
|
||||
if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
_t_final = typing.Final
|
||||
_te_final = typing_extensions.Final
|
||||
def _check_finalvar(v: type[Any] | None) -> bool:
|
||||
"""Check if a given type is a `typing.Final` type."""
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
|
||||
|
||||
|
||||
# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms:
|
||||
def is_finalvar(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a `Final` special form, parametrized or not.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_finalvar(Final[int])
|
||||
#> True
|
||||
is_finalvar(Final)
|
||||
#> True
|
||||
"""
|
||||
# Final is not necessarily parametrized:
|
||||
if tp is _t_final or tp is _te_final:
|
||||
return True
|
||||
origin = get_origin(tp)
|
||||
return origin is _t_final or origin is _te_final
|
||||
def is_finalvar(ann_type: Any) -> bool:
|
||||
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
|
||||
|
||||
|
||||
_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None])
|
||||
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
|
||||
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
|
||||
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
|
||||
and suggestion at the end of the next comment by @gvanrossum.
|
||||
|
||||
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
|
||||
parent of where it is called.
|
||||
|
||||
def is_none_type(tp: Any, /) -> bool:
|
||||
"""Return whether the argument represents the `None` type as part of an annotation.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
is_none_type(None)
|
||||
#> True
|
||||
is_none_type(NoneType)
|
||||
#> True
|
||||
is_none_type(Literal[None])
|
||||
#> True
|
||||
is_none_type(type[None])
|
||||
#> False
|
||||
"""
|
||||
return tp in _NONE_TYPES
|
||||
|
||||
|
||||
def is_namedtuple(tp: Any, /) -> bool:
|
||||
"""Return whether the provided argument is a named tuple class.
|
||||
|
||||
The class can be created using `typing.NamedTuple` or `collections.namedtuple`.
|
||||
Parametrized generic classes are *not* assumed to be named tuples.
|
||||
"""
|
||||
from ._utils import lenient_issubclass # circ. import
|
||||
|
||||
return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields')
|
||||
|
||||
|
||||
# TODO In 2.12, delete this export. It is currently defined only to not break
|
||||
# pydantic-settings which relies on it:
|
||||
origin_is_union = is_union_origin
|
||||
|
||||
|
||||
def is_generic_alias(tp: Any, /) -> bool:
|
||||
return isinstance(tp, (types.GenericAlias, typing._GenericAlias)) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
# TODO: Ideally, we should avoid relying on the private `typing` constructs:
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # pyright: ignore[reportAttributeAccessIssue]
|
||||
else:
|
||||
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias, types.UnionType) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`:
|
||||
typing_base: Any = typing._Final # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
### Annotation evaluations functions:
|
||||
|
||||
|
||||
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
|
||||
"""Fetch the local namespace of the parent frame where this function is called.
|
||||
|
||||
Using this function is mostly useful to resolve forward annotations pointing to members defined in a local namespace,
|
||||
such as assignments inside a function. Using the standard library tools, it is currently not possible to resolve
|
||||
such annotations:
|
||||
|
||||
```python {lint="skip" test="skip"}
|
||||
from typing import get_type_hints
|
||||
|
||||
def func() -> None:
|
||||
Alias = int
|
||||
|
||||
class C:
|
||||
a: 'Alias'
|
||||
|
||||
# Raises a `NameError: 'Alias' is not defined`
|
||||
get_type_hints(C)
|
||||
```
|
||||
|
||||
Pydantic uses this function when a Pydantic model is being defined to fetch the parent frame locals. However,
|
||||
this only allows us to fetch the parent frame namespace and not other parents (e.g. a model defined in a function,
|
||||
itself defined in another function). Inspecting the next outer frames (using `f_back`) is not reliable enough
|
||||
(see https://discuss.python.org/t/20659).
|
||||
|
||||
Because this function is mostly used to better resolve forward annotations, nothing is returned if the parent frame's
|
||||
code object is defined at the module level. In this case, the locals of the frame will be the same as the module
|
||||
globals where the class is defined (see `_namespace_utils.get_module_ns_of`). However, if you still want to fetch
|
||||
the module globals (e.g. when rebuilding a model, where the frame where the rebuild call is performed might contain
|
||||
members that you want to use for forward annotations evaluation), you can use the `force` parameter.
|
||||
|
||||
Args:
|
||||
parent_depth: The depth at which to get the frame. Defaults to 2, meaning the parent frame where this function
|
||||
is called will be used.
|
||||
force: Whether to always return the frame locals, even if the frame's code object is defined at the module level.
|
||||
|
||||
Returns:
|
||||
The locals of the namespace, or `None` if it was skipped as per the described logic.
|
||||
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
|
||||
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
|
||||
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
|
||||
"""
|
||||
frame = sys._getframe(parent_depth)
|
||||
|
||||
if frame.f_code.co_name.startswith('<generic parameters of'):
|
||||
# As `parent_frame_namespace` is mostly called in `ModelMetaclass.__new__`,
|
||||
# the parent frame can be the annotation scope if the PEP 695 generic syntax is used.
|
||||
# (see https://docs.python.org/3/reference/executionmodel.html#annotation-scopes,
|
||||
# https://docs.python.org/3/reference/compound_stmts.html#generic-classes).
|
||||
# In this case, the code name is set to `<generic parameters of MyClass>`,
|
||||
# and we need to skip this frame as it is irrelevant.
|
||||
frame = cast(types.FrameType, frame.f_back) # guaranteed to not be `None`
|
||||
|
||||
# note, we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be
|
||||
# modified down the line if this becomes a problem, we could implement some sort of frozen mapping structure to enforce this.
|
||||
if force:
|
||||
# if f_back is None, it's the global module namespace and we don't need to include it here
|
||||
if frame.f_back is None:
|
||||
return None
|
||||
else:
|
||||
return frame.f_locals
|
||||
|
||||
# If either of the following conditions are true, the class is defined at the top module level.
|
||||
# To better understand why we need both of these checks, see
|
||||
# https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531.
|
||||
if frame.f_back is None or frame.f_code.co_name == '<module>':
|
||||
return None
|
||||
|
||||
return frame.f_locals
|
||||
def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
module_name = getattr(obj, '__module__', None)
|
||||
if module_name:
|
||||
try:
|
||||
module_globalns = sys.modules[module_name].__dict__
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
else:
|
||||
if globalns:
|
||||
return {**module_globalns, **globalns}
|
||||
else:
|
||||
# copy module globals to make sure it can't be updated later
|
||||
return module_globalns.copy()
|
||||
|
||||
return globalns or {}
|
||||
|
||||
|
||||
def _type_convert(arg: Any) -> Any:
|
||||
"""Convert `None` to `NoneType` and strings to `ForwardRef` instances.
|
||||
|
||||
This is a backport of the private `typing._type_convert` function. When
|
||||
evaluating a type, `ForwardRef._evaluate` ends up being called, and is
|
||||
responsible for making this conversion. However, we still have to apply
|
||||
it for the first argument passed to our type evaluation functions, similarly
|
||||
to the `typing.get_type_hints` function.
|
||||
"""
|
||||
if arg is None:
|
||||
return NoneType
|
||||
if isinstance(arg, str):
|
||||
# Like `typing.get_type_hints`, assume the arg can be in any context,
|
||||
# hence the proper `is_argument` and `is_class` args:
|
||||
return _make_forward_ref(arg, is_argument=False, is_class=True)
|
||||
return arg
|
||||
def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
ns = add_module_globals(cls, parent_namespace)
|
||||
ns[cls.__name__] = cls
|
||||
return ns
|
||||
|
||||
|
||||
def get_model_type_hints(
|
||||
obj: type[BaseModel],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> dict[str, tuple[Any, bool]]:
|
||||
"""Collect annotations from a Pydantic model class, including those from parent classes.
|
||||
|
||||
Args:
|
||||
obj: The Pydantic model to inspect.
|
||||
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping annotation names to a two-tuple: the first element is the evaluated
|
||||
type or the original annotation if a `NameError` occurred, the second element is a boolean
|
||||
indicating if whether the evaluation succeeded.
|
||||
"""
|
||||
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
|
||||
for base in reversed(obj.__mro__):
|
||||
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
|
||||
if not ann or isinstance(ann, types.GetSetDescriptorType):
|
||||
continue
|
||||
with ns_resolver.push(base):
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
for name, value in ann.items():
|
||||
if name.startswith('_'):
|
||||
# For private attributes, we only need the annotation to detect the `ClassVar` special form.
|
||||
# For this reason, we still try to evaluate it, but we also catch any possible exception (on
|
||||
# top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free
|
||||
# to use any kind of forward annotation for private fields (e.g. circular imports, new typing
|
||||
# syntax, etc).
|
||||
try:
|
||||
hints[name] = try_eval_type(value, globalns, localns)
|
||||
except Exception:
|
||||
hints[name] = (value, False)
|
||||
else:
|
||||
hints[name] = try_eval_type(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def get_cls_type_hints(
|
||||
obj: type[Any],
|
||||
*,
|
||||
ns_resolver: NsResolver | None = None,
|
||||
) -> dict[str, Any]:
|
||||
def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
|
||||
"""Collect annotations from a class, including those from parent classes.
|
||||
|
||||
Args:
|
||||
obj: The class to inspect.
|
||||
ns_resolver: A namespace resolver instance to use. Defaults to an empty instance.
|
||||
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
|
||||
"""
|
||||
hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {}
|
||||
ns_resolver = ns_resolver or NsResolver()
|
||||
|
||||
hints = {}
|
||||
for base in reversed(obj.__mro__):
|
||||
ann: dict[str, Any] | None = base.__dict__.get('__annotations__')
|
||||
if not ann or isinstance(ann, types.GetSetDescriptorType):
|
||||
continue
|
||||
with ns_resolver.push(base):
|
||||
globalns, localns = ns_resolver.types_namespace
|
||||
ann = base.__dict__.get('__annotations__')
|
||||
localns = dict(vars(base))
|
||||
if ann is not None and ann is not GetSetDescriptorType:
|
||||
for name, value in ann.items():
|
||||
hints[name] = eval_type(value, globalns, localns)
|
||||
hints[name] = eval_type_lenient(value, globalns, localns)
|
||||
return hints
|
||||
|
||||
|
||||
def try_eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> tuple[Any, bool]:
|
||||
"""Try evaluating the annotation using the provided namespaces.
|
||||
|
||||
Args:
|
||||
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
|
||||
of `str`, it will be converted to a `ForwardRef`.
|
||||
localns: The global namespace to use during annotation evaluation.
|
||||
globalns: The local namespace to use during annotation evaluation.
|
||||
|
||||
Returns:
|
||||
A two-tuple containing the possibly evaluated type and a boolean indicating
|
||||
whether the evaluation succeeded or not.
|
||||
"""
|
||||
value = _type_convert(value)
|
||||
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any:
|
||||
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
|
||||
if value is None:
|
||||
value = NoneType
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
|
||||
try:
|
||||
return eval_type_backport(value, globalns, localns), True
|
||||
return typing._eval_type(value, globalns, localns) # type: ignore
|
||||
except NameError:
|
||||
return value, False
|
||||
|
||||
|
||||
def eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any:
|
||||
"""Evaluate the annotation using the provided namespaces.
|
||||
|
||||
Args:
|
||||
value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance
|
||||
of `str`, it will be converted to a `ForwardRef`.
|
||||
localns: The global namespace to use during annotation evaluation.
|
||||
globalns: The local namespace to use during annotation evaluation.
|
||||
"""
|
||||
value = _type_convert(value)
|
||||
return eval_type_backport(value, globalns, localns)
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`eval_type_lenient` is deprecated, use `try_eval_type` instead.',
|
||||
category=None,
|
||||
)
|
||||
def eval_type_lenient(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
) -> Any:
|
||||
ev, _ = try_eval_type(value, globalns, localns)
|
||||
return ev
|
||||
|
||||
|
||||
def eval_type_backport(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
"""An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport`
|
||||
package if it's installed to let older Python versions use newer typing constructs.
|
||||
|
||||
Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]`
|
||||
(as well as all the types made generic in PEP 585) if the original syntax is not supported in the
|
||||
current Python version.
|
||||
|
||||
This function will also display a helpful error if the value passed fails to evaluate.
|
||||
"""
|
||||
try:
|
||||
return _eval_type_backport(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if 'Unable to evaluate type annotation' in str(e):
|
||||
raise
|
||||
|
||||
# If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition.
|
||||
# Thus we assert here for type checking purposes:
|
||||
assert isinstance(value, typing.ForwardRef)
|
||||
|
||||
message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.'
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise TypeError(message) from e
|
||||
except RecursionError as e:
|
||||
# TODO ideally recursion errors should be checked in `eval_type` above, but `eval_type_backport`
|
||||
# is used directly in some places.
|
||||
message = (
|
||||
"If you made use of an implicit recursive type alias (e.g. `MyType = list['MyType']), "
|
||||
'consider using PEP 695 type aliases instead. For more details, refer to the documentation: '
|
||||
f'https://docs.pydantic.dev/{version_short()}/concepts/types/#named-recursive-types'
|
||||
)
|
||||
if sys.version_info >= (3, 11):
|
||||
e.add_note(message)
|
||||
raise
|
||||
else:
|
||||
raise RecursionError(f'{e.args[0]}\n{message}')
|
||||
|
||||
|
||||
def _eval_type_backport(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
try:
|
||||
return _eval_type(value, globalns, localns, type_params)
|
||||
except TypeError as e:
|
||||
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
|
||||
raise
|
||||
|
||||
try:
|
||||
from eval_type_backport import eval_type_backport
|
||||
except ImportError:
|
||||
raise TypeError(
|
||||
f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use '
|
||||
'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting '
|
||||
'since Python 3.9), you should either replace the use of new syntax with the existing '
|
||||
'`typing` constructs or install the `eval_type_backport` package.'
|
||||
) from e
|
||||
|
||||
return eval_type_backport(
|
||||
value,
|
||||
globalns,
|
||||
localns, # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release.
|
||||
try_default=False,
|
||||
)
|
||||
|
||||
|
||||
def _eval_type(
|
||||
value: Any,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
type_params: tuple[Any, ...] | None = None,
|
||||
) -> Any:
|
||||
if sys.version_info >= (3, 13):
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns, type_params=type_params
|
||||
)
|
||||
else:
|
||||
return typing._eval_type( # type: ignore
|
||||
value, globalns, localns
|
||||
)
|
||||
|
||||
|
||||
def is_backport_fixable_error(e: TypeError) -> bool:
|
||||
msg = str(e)
|
||||
|
||||
return sys.version_info < (3, 10) and msg.startswith('unsupported operand type(s) for |: ')
|
||||
# the point of this function is to be tolerant to this case
|
||||
return value
|
||||
|
||||
|
||||
def get_function_type_hints(
|
||||
function: Callable[..., Any],
|
||||
*,
|
||||
include_keys: set[str] | None = None,
|
||||
globalns: GlobalsNamespace | None = None,
|
||||
localns: MappingNamespace | None = None,
|
||||
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Return type hints for a function.
|
||||
|
||||
This is similar to the `typing.get_type_hints` function, with a few differences:
|
||||
- Support `functools.partial` by using the underlying `func` attribute.
|
||||
- Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None`
|
||||
(related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+).
|
||||
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
|
||||
copes with `partial`.
|
||||
"""
|
||||
try:
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
except AttributeError:
|
||||
# Some functions (e.g. builtins) don't have annotations:
|
||||
return {}
|
||||
|
||||
if globalns is None:
|
||||
globalns = get_module_ns_of(function)
|
||||
type_params: tuple[Any, ...] | None = None
|
||||
if localns is None:
|
||||
# If localns was specified, it is assumed to already contain type params. This is because
|
||||
# Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`).
|
||||
type_params = getattr(function, '__type_params__', ())
|
||||
if isinstance(function, partial):
|
||||
annotations = function.func.__annotations__
|
||||
else:
|
||||
annotations = function.__annotations__
|
||||
|
||||
globalns = add_module_globals(function)
|
||||
type_hints = {}
|
||||
for name, value in annotations.items():
|
||||
if include_keys is not None and name not in include_keys:
|
||||
@@ -548,7 +248,7 @@ def get_function_type_hints(
|
||||
elif isinstance(value, str):
|
||||
value = _make_forward_ref(value)
|
||||
|
||||
type_hints[name] = eval_type_backport(value, globalns, localns, type_params)
|
||||
type_hints[name] = typing._eval_type(value, globalns, types_namespace) # type: ignore
|
||||
|
||||
return type_hints
|
||||
|
||||
@@ -663,15 +363,11 @@ else:
|
||||
if isinstance(value, str):
|
||||
value = _make_forward_ref(value, is_argument=False, is_class=True)
|
||||
|
||||
value = eval_type_backport(value, base_globals, base_locals)
|
||||
value = typing._eval_type(value, base_globals, base_locals) # type: ignore
|
||||
hints[name] = value
|
||||
if not include_extras and hasattr(typing, '_strip_annotations'):
|
||||
return {
|
||||
k: typing._strip_annotations(t) # type: ignore
|
||||
for k, t in hints.items()
|
||||
}
|
||||
else:
|
||||
return hints
|
||||
return (
|
||||
hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
)
|
||||
|
||||
if globalns is None:
|
||||
if isinstance(obj, types.ModuleType):
|
||||
@@ -692,7 +388,7 @@ else:
|
||||
if isinstance(obj, typing._allowed_types): # type: ignore
|
||||
return {}
|
||||
else:
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, or function.')
|
||||
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
|
||||
defaults = typing._get_defaults(obj) # type: ignore
|
||||
hints = dict(hints)
|
||||
for name, value in hints.items():
|
||||
@@ -707,8 +403,33 @@ else:
|
||||
is_argument=not isinstance(obj, types.ModuleType),
|
||||
is_class=False,
|
||||
)
|
||||
value = eval_type_backport(value, globalns, localns)
|
||||
value = typing._eval_type(value, globalns, localns) # type: ignore
|
||||
if name in defaults and defaults[name] is None:
|
||||
value = typing.Optional[value]
|
||||
hints[name] = value
|
||||
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def evaluate_fwd_ref(
|
||||
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
return ref._evaluate(globalns=globalns, localns=localns)
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_fwd_ref(
|
||||
ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
return ref._evaluate(globalns=globalns, localns=localns, recursive_guard=frozenset())
|
||||
|
||||
|
||||
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
|
||||
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
|
||||
# so I created this convenience function
|
||||
return dataclasses.is_dataclass(_cls)
|
||||
|
||||
|
||||
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
|
||||
return isinstance(origin, TypeAliasType)
|
||||
|
||||
@@ -2,30 +2,20 @@
|
||||
|
||||
This should be reduced as much as possible with functions only used in one place, moved to that place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import keyword
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from inspect import Parameter
|
||||
from itertools import zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import Any, Callable, Generic, TypeVar, overload
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from typing_extensions import TypeAlias, TypeGuard, deprecated
|
||||
|
||||
from pydantic import PydanticDeprecatedSince211
|
||||
from typing_extensions import TypeAlias, TypeGuard
|
||||
|
||||
from . import _repr, _typing_extra
|
||||
from ._import_utils import import_cached_base_model
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
|
||||
@@ -69,25 +59,6 @@ BUILTIN_COLLECTIONS: set[type[Any]] = {
|
||||
}
|
||||
|
||||
|
||||
def can_be_positional(param: Parameter) -> bool:
|
||||
"""Return whether the parameter accepts a positional argument.
|
||||
|
||||
```python {test="skip" lint="skip"}
|
||||
def func(a, /, b, *, c):
|
||||
pass
|
||||
|
||||
params = inspect.signature(func).parameters
|
||||
can_be_positional(params['a'])
|
||||
#> True
|
||||
can_be_positional(params['b'])
|
||||
#> True
|
||||
can_be_positional(params['c'])
|
||||
#> False
|
||||
```
|
||||
"""
|
||||
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
|
||||
|
||||
|
||||
def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
@@ -112,7 +83,7 @@ def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
|
||||
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
|
||||
unlike raw calls to lenient_issubclass.
|
||||
"""
|
||||
BaseModel = import_cached_base_model()
|
||||
from ..main import BaseModel
|
||||
|
||||
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
|
||||
|
||||
@@ -304,23 +275,19 @@ class ValueItems(_repr.Representation):
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ...
|
||||
def ClassAttribute(name: str, value: T) -> T:
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
class LazyClassAttribute:
|
||||
"""A descriptor exposing an attribute only accessible on a class (hidden from instances).
|
||||
class ClassAttribute:
|
||||
"""Hide class attribute from its instances."""
|
||||
|
||||
The attribute is lazily computed and cached during the first access.
|
||||
"""
|
||||
__slots__ = 'name', 'value'
|
||||
|
||||
def __init__(self, name: str, get_value: Callable[[], Any]) -> None:
|
||||
def __init__(self, name: str, value: Any) -> None:
|
||||
self.name = name
|
||||
self.get_value = get_value
|
||||
|
||||
@cached_property
|
||||
def value(self) -> Any:
|
||||
return self.get_value()
|
||||
self.value = value
|
||||
|
||||
def __get__(self, instance: Any, owner: type[Any]) -> None:
|
||||
if instance is None:
|
||||
@@ -342,7 +309,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
|
||||
try:
|
||||
if not obj and obj_type in BUILTIN_COLLECTIONS:
|
||||
# faster way for empty collections, no need to copy its members
|
||||
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
|
||||
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method
|
||||
except (TypeError, ValueError, RuntimeError):
|
||||
# do we really dare to catch ALL errors? Seems a bit risky
|
||||
pass
|
||||
@@ -350,7 +317,7 @@ def smart_deepcopy(obj: Obj) -> Obj:
|
||||
return deepcopy(obj) # slowest way when we actually might need a deepcopy
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
_EMPTY = object()
|
||||
|
||||
|
||||
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
|
||||
@@ -362,70 +329,7 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo
|
||||
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
|
||||
False
|
||||
"""
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
|
||||
if left_item is not right_item:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SafeGetItemProxy:
|
||||
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
|
||||
|
||||
This makes is safe to use in `operator.itemgetter` when some keys may be missing
|
||||
"""
|
||||
|
||||
# Define __slots__manually for performances
|
||||
# @dataclasses.dataclass() only support slots=True in python>=3.10
|
||||
__slots__ = ('wrapped',)
|
||||
|
||||
wrapped: Mapping[str, Any]
|
||||
|
||||
def __getitem__(self, key: str, /) -> Any:
|
||||
return self.wrapped.get(key, _SENTINEL)
|
||||
|
||||
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
|
||||
# https://github.com/python/mypy/issues/13713
|
||||
# https://github.com/python/typeshed/pull/8785
|
||||
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def __contains__(self, key: str, /) -> bool:
|
||||
return self.wrapped.__contains__(key)
|
||||
|
||||
|
||||
_ModelT = TypeVar('_ModelT', bound='BaseModel')
|
||||
_RT = TypeVar('_RT')
|
||||
|
||||
|
||||
class deprecated_instance_property(Generic[_ModelT, _RT]):
|
||||
"""A decorator exposing the decorated class method as a property, with a warning on instance access.
|
||||
|
||||
This decorator takes a class method defined on the `BaseModel` class and transforms it into
|
||||
an attribute. The attribute can be accessed on both the class and instances of the class. If accessed
|
||||
via an instance, a deprecation warning is emitted stating that instance access will be removed in V3.
|
||||
"""
|
||||
|
||||
def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None:
|
||||
# Note: fget should be a classmethod:
|
||||
self.fget = fget
|
||||
|
||||
@overload
|
||||
def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ...
|
||||
@overload
|
||||
@deprecated(
|
||||
'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. '
|
||||
'Instead, you should access this attribute from the model class.',
|
||||
category=None,
|
||||
)
|
||||
def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ...
|
||||
def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT:
|
||||
if instance is not None:
|
||||
attr_name = self.fget.__name__ if sys.version_info >= (3, 10) else self.fget.__func__.__name__
|
||||
warnings.warn(
|
||||
f'Accessing the {attr_name!r} attribute on the instance is deprecated. '
|
||||
'Instead, you should access this attribute from the model class.',
|
||||
category=PydanticDeprecatedSince211,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.fget.__get__(instance, objtype)()
|
||||
|
||||
@@ -1,122 +1,88 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import pydantic_core
|
||||
|
||||
from ..config import ConfigDict
|
||||
from ..plugin._schema_validator import create_schema_validator
|
||||
from . import _discriminated_union, _generate_schema, _typing_extra
|
||||
from ._config import ConfigWrapper
|
||||
from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
|
||||
from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
|
||||
from ._core_utils import simplify_schema_references, validate_core_schema
|
||||
|
||||
|
||||
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
|
||||
"""Extract the name of a `ValidateCallSupportedTypes` object."""
|
||||
return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
|
||||
|
||||
|
||||
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
|
||||
"""Extract the qualname of a `ValidateCallSupportedTypes` object."""
|
||||
return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
|
||||
|
||||
|
||||
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
|
||||
"""Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
|
||||
if inspect.iscoroutinefunction(wrapped):
|
||||
|
||||
@functools.wraps(wrapped)
|
||||
async def wrapper_function(*args, **kwargs): # type: ignore
|
||||
return await wrapper(*args, **kwargs)
|
||||
else:
|
||||
|
||||
@functools.wraps(wrapped)
|
||||
def wrapper_function(*args, **kwargs):
|
||||
return wrapper(*args, **kwargs)
|
||||
|
||||
# We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
|
||||
wrapper_function.__name__ = extract_function_name(wrapped)
|
||||
wrapper_function.__qualname__ = extract_function_qualname(wrapped)
|
||||
wrapper_function.raw_function = wrapped # type: ignore
|
||||
|
||||
return wrapper_function
|
||||
@dataclass
|
||||
class CallMarker:
|
||||
function: Callable[..., Any]
|
||||
validate_return: bool
|
||||
|
||||
|
||||
class ValidateCallWrapper:
|
||||
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
|
||||
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.
|
||||
|
||||
It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so
|
||||
these functions can be applied to instance methods, class methods, static methods, as well as normal functions.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'function',
|
||||
'validate_return',
|
||||
'schema_type',
|
||||
'module',
|
||||
'qualname',
|
||||
'ns_resolver',
|
||||
'config_wrapper',
|
||||
'__pydantic_complete__',
|
||||
'raw_function',
|
||||
'_config',
|
||||
'_validate_return',
|
||||
'__pydantic_core_schema__',
|
||||
'__pydantic_validator__',
|
||||
'__return_pydantic_validator__',
|
||||
'__signature__',
|
||||
'__name__',
|
||||
'__qualname__',
|
||||
'__annotations__',
|
||||
'__dict__', # required for __module__
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: ValidateCallSupportedTypes,
|
||||
config: ConfigDict | None,
|
||||
validate_return: bool,
|
||||
parent_namespace: MappingNamespace | None,
|
||||
) -> None:
|
||||
self.function = function
|
||||
self.validate_return = validate_return
|
||||
def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
|
||||
self.raw_function = function
|
||||
self._config = config
|
||||
self._validate_return = validate_return
|
||||
self.__signature__ = inspect.signature(function)
|
||||
if isinstance(function, partial):
|
||||
self.schema_type = function.func
|
||||
self.module = function.func.__module__
|
||||
func = function.func
|
||||
self.__name__ = f'partial({func.__name__})'
|
||||
self.__qualname__ = f'partial({func.__qualname__})'
|
||||
self.__annotations__ = func.__annotations__
|
||||
self.__module__ = func.__module__
|
||||
self.__doc__ = func.__doc__
|
||||
else:
|
||||
self.schema_type = function
|
||||
self.module = function.__module__
|
||||
self.qualname = extract_function_qualname(function)
|
||||
self.__name__ = function.__name__
|
||||
self.__qualname__ = function.__qualname__
|
||||
self.__annotations__ = function.__annotations__
|
||||
self.__module__ = function.__module__
|
||||
self.__doc__ = function.__doc__
|
||||
|
||||
self.ns_resolver = NsResolver(
|
||||
namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace)
|
||||
)
|
||||
self.config_wrapper = ConfigWrapper(config)
|
||||
if not self.config_wrapper.defer_build:
|
||||
self._create_validators()
|
||||
else:
|
||||
self.__pydantic_complete__ = False
|
||||
namespace = _typing_extra.add_module_globals(function, None)
|
||||
config_wrapper = ConfigWrapper(config)
|
||||
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
|
||||
schema = gen_schema.collect_definitions(gen_schema.generate_schema(function))
|
||||
schema = simplify_schema_references(schema)
|
||||
self.__pydantic_core_schema__ = schema = schema
|
||||
core_config = config_wrapper.core_config(self)
|
||||
schema = _discriminated_union.apply_discriminators(schema)
|
||||
self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings)
|
||||
|
||||
def _create_validators(self) -> None:
|
||||
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function))
|
||||
core_config = self.config_wrapper.core_config(title=self.qualname)
|
||||
|
||||
self.__pydantic_validator__ = create_schema_validator(
|
||||
schema,
|
||||
self.schema_type,
|
||||
self.module,
|
||||
self.qualname,
|
||||
'validate_call',
|
||||
core_config,
|
||||
self.config_wrapper.plugin_settings,
|
||||
)
|
||||
if self.validate_return:
|
||||
signature = inspect.signature(self.function)
|
||||
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
|
||||
gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver)
|
||||
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
|
||||
validator = create_schema_validator(
|
||||
schema,
|
||||
self.schema_type,
|
||||
self.module,
|
||||
self.qualname,
|
||||
'validate_call',
|
||||
core_config,
|
||||
self.config_wrapper.plugin_settings,
|
||||
if self._validate_return:
|
||||
return_type = (
|
||||
self.__signature__.return_annotation
|
||||
if self.__signature__.return_annotation is not self.__signature__.empty
|
||||
else Any
|
||||
)
|
||||
if inspect.iscoroutinefunction(self.function):
|
||||
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
|
||||
schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type))
|
||||
schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema))
|
||||
self.__return_pydantic_core_schema__ = schema
|
||||
core_config = config_wrapper.core_config(self)
|
||||
schema = validate_core_schema(schema)
|
||||
validator = pydantic_core.SchemaValidator(schema, core_config)
|
||||
if inspect.iscoroutinefunction(self.raw_function):
|
||||
|
||||
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
|
||||
return validator.validate_python(await aw)
|
||||
@@ -125,16 +91,38 @@ class ValidateCallWrapper:
|
||||
else:
|
||||
self.__return_pydantic_validator__ = validator.validate_python
|
||||
else:
|
||||
self.__return_pydantic_core_schema__ = None
|
||||
self.__return_pydantic_validator__ = None
|
||||
|
||||
self.__pydantic_complete__ = True
|
||||
self._name: str | None = None # set by __get__, used to set the instance attribute when decorating methods
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
if not self.__pydantic_complete__:
|
||||
self._create_validators()
|
||||
|
||||
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
|
||||
if self.__return_pydantic_validator__:
|
||||
return self.__return_pydantic_validator__(res)
|
||||
else:
|
||||
return res
|
||||
return res
|
||||
|
||||
def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper:
|
||||
"""Bind the raw function and return another ValidateCallWrapper wrapping that."""
|
||||
if obj is None:
|
||||
try:
|
||||
# Handle the case where a method is accessed as a class attribute
|
||||
return objtype.__getattribute__(objtype, self._name) # type: ignore
|
||||
except AttributeError:
|
||||
# This will happen the first time the attribute is accessed
|
||||
pass
|
||||
|
||||
bound_function = self.raw_function.__get__(obj, objtype)
|
||||
result = self.__class__(bound_function, self._config, self._validate_return)
|
||||
if self._name is not None:
|
||||
if obj is not None:
|
||||
object.__setattr__(obj, self._name, result)
|
||||
else:
|
||||
object.__setattr__(objtype, self._name, result)
|
||||
return result
|
||||
|
||||
def __set_name__(self, owner: Any, name: str) -> None:
|
||||
self._name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'ValidateCallWrapper({self.raw_function})'
|
||||
|
||||
@@ -5,32 +5,22 @@ Import of this module is deferred since it contains imports of many standard lib
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import collections.abc
|
||||
import math
|
||||
import re
|
||||
import typing
|
||||
from decimal import Decimal
|
||||
from fractions import Fraction
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from typing import Any, Callable, Union, cast, get_origin
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
from typing import Any
|
||||
|
||||
import typing_extensions
|
||||
from pydantic_core import PydanticCustomError, core_schema
|
||||
from pydantic_core._pydantic_core import PydanticKnownError
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from pydantic._internal._import_utils import import_cached_field_info
|
||||
from pydantic.errors import PydanticSchemaGenerationError
|
||||
|
||||
|
||||
def sequence_validator(
|
||||
input_value: typing.Sequence[Any],
|
||||
/,
|
||||
__input_value: typing.Sequence[Any],
|
||||
validator: core_schema.ValidatorFunctionWrapHandler,
|
||||
) -> typing.Sequence[Any]:
|
||||
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
|
||||
value_type = type(input_value)
|
||||
value_type = type(__input_value)
|
||||
|
||||
# We don't accept any plain string as a sequence
|
||||
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
|
||||
@@ -41,24 +31,14 @@ def sequence_validator(
|
||||
{'type_name': value_type.__name__},
|
||||
)
|
||||
|
||||
# TODO: refactor sequence validation to validate with either a list or a tuple
|
||||
# schema, depending on the type of the value.
|
||||
# Additionally, we should be able to remove one of either this validator or the
|
||||
# SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic).
|
||||
# Effectively, a refactor for sequence validation is needed.
|
||||
if value_type is tuple:
|
||||
input_value = list(input_value)
|
||||
|
||||
v_list = validator(input_value)
|
||||
v_list = validator(__input_value)
|
||||
|
||||
# the rest of the logic is just re-creating the original type from `v_list`
|
||||
if value_type is list:
|
||||
if value_type == list:
|
||||
return v_list
|
||||
elif issubclass(value_type, range):
|
||||
# return the list as we probably can't re-create the range
|
||||
return v_list
|
||||
elif value_type is tuple:
|
||||
return tuple(v_list)
|
||||
else:
|
||||
# best guess at how to re-create the original type, more custom construction logic might be required
|
||||
return value_type(v_list) # type: ignore[call-arg]
|
||||
@@ -69,7 +49,7 @@ def import_string(value: Any) -> Any:
|
||||
try:
|
||||
return _import_string_logic(value)
|
||||
except ImportError as e:
|
||||
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
|
||||
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)})
|
||||
else:
|
||||
# otherwise we just return the value and let the next validator do the rest of the work
|
||||
return value
|
||||
@@ -126,39 +106,39 @@ def _import_string_logic(dotted_path: str) -> Any:
|
||||
return module
|
||||
|
||||
|
||||
def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]:
|
||||
if isinstance(input_value, typing.Pattern):
|
||||
return input_value
|
||||
elif isinstance(input_value, (str, bytes)):
|
||||
def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]:
|
||||
if isinstance(__input_value, typing.Pattern):
|
||||
return __input_value
|
||||
elif isinstance(__input_value, (str, bytes)):
|
||||
# todo strict mode
|
||||
return compile_pattern(input_value) # type: ignore
|
||||
return compile_pattern(__input_value) # type: ignore
|
||||
else:
|
||||
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
|
||||
|
||||
|
||||
def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]:
|
||||
if isinstance(input_value, typing.Pattern):
|
||||
if isinstance(input_value.pattern, str):
|
||||
return input_value
|
||||
def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]:
|
||||
if isinstance(__input_value, typing.Pattern):
|
||||
if isinstance(__input_value.pattern, str):
|
||||
return __input_value
|
||||
else:
|
||||
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
|
||||
elif isinstance(input_value, str):
|
||||
return compile_pattern(input_value)
|
||||
elif isinstance(input_value, bytes):
|
||||
elif isinstance(__input_value, str):
|
||||
return compile_pattern(__input_value)
|
||||
elif isinstance(__input_value, bytes):
|
||||
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
|
||||
else:
|
||||
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
|
||||
|
||||
|
||||
def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]:
|
||||
if isinstance(input_value, typing.Pattern):
|
||||
if isinstance(input_value.pattern, bytes):
|
||||
return input_value
|
||||
def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]:
|
||||
if isinstance(__input_value, typing.Pattern):
|
||||
if isinstance(__input_value.pattern, bytes):
|
||||
return __input_value
|
||||
else:
|
||||
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
|
||||
elif isinstance(input_value, bytes):
|
||||
return compile_pattern(input_value)
|
||||
elif isinstance(input_value, str):
|
||||
elif isinstance(__input_value, bytes):
|
||||
return compile_pattern(__input_value)
|
||||
elif isinstance(__input_value, str):
|
||||
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
|
||||
else:
|
||||
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
|
||||
@@ -174,359 +154,125 @@ def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
|
||||
raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
|
||||
|
||||
|
||||
def ip_v4_address_validator(input_value: Any, /) -> IPv4Address:
|
||||
if isinstance(input_value, IPv4Address):
|
||||
return input_value
|
||||
def ip_v4_address_validator(__input_value: Any) -> IPv4Address:
|
||||
if isinstance(__input_value, IPv4Address):
|
||||
return __input_value
|
||||
|
||||
try:
|
||||
return IPv4Address(input_value)
|
||||
return IPv4Address(__input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
|
||||
|
||||
|
||||
def ip_v6_address_validator(input_value: Any, /) -> IPv6Address:
|
||||
if isinstance(input_value, IPv6Address):
|
||||
return input_value
|
||||
def ip_v6_address_validator(__input_value: Any) -> IPv6Address:
|
||||
if isinstance(__input_value, IPv6Address):
|
||||
return __input_value
|
||||
|
||||
try:
|
||||
return IPv6Address(input_value)
|
||||
return IPv6Address(__input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
|
||||
|
||||
|
||||
def ip_v4_network_validator(input_value: Any, /) -> IPv4Network:
|
||||
def ip_v4_network_validator(__input_value: Any) -> IPv4Network:
|
||||
"""Assume IPv4Network initialised with a default `strict` argument.
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
|
||||
"""
|
||||
if isinstance(input_value, IPv4Network):
|
||||
return input_value
|
||||
if isinstance(__input_value, IPv4Network):
|
||||
return __input_value
|
||||
|
||||
try:
|
||||
return IPv4Network(input_value)
|
||||
return IPv4Network(__input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
|
||||
|
||||
|
||||
def ip_v6_network_validator(input_value: Any, /) -> IPv6Network:
|
||||
def ip_v6_network_validator(__input_value: Any) -> IPv6Network:
|
||||
"""Assume IPv6Network initialised with a default `strict` argument.
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
|
||||
"""
|
||||
if isinstance(input_value, IPv6Network):
|
||||
return input_value
|
||||
if isinstance(__input_value, IPv6Network):
|
||||
return __input_value
|
||||
|
||||
try:
|
||||
return IPv6Network(input_value)
|
||||
return IPv6Network(__input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
|
||||
|
||||
|
||||
def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface:
|
||||
if isinstance(input_value, IPv4Interface):
|
||||
return input_value
|
||||
def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface:
|
||||
if isinstance(__input_value, IPv4Interface):
|
||||
return __input_value
|
||||
|
||||
try:
|
||||
return IPv4Interface(input_value)
|
||||
return IPv4Interface(__input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
|
||||
|
||||
|
||||
def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface:
|
||||
if isinstance(input_value, IPv6Interface):
|
||||
return input_value
|
||||
def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface:
|
||||
if isinstance(__input_value, IPv6Interface):
|
||||
return __input_value
|
||||
|
||||
try:
|
||||
return IPv6Interface(input_value)
|
||||
return IPv6Interface(__input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
|
||||
|
||||
|
||||
def fraction_validator(input_value: Any, /) -> Fraction:
|
||||
if isinstance(input_value, Fraction):
|
||||
return input_value
|
||||
def greater_than_validator(x: Any, gt: Any) -> Any:
|
||||
if not (x > gt):
|
||||
raise PydanticKnownError('greater_than', {'gt': gt})
|
||||
return x
|
||||
|
||||
try:
|
||||
return Fraction(input_value)
|
||||
except ValueError:
|
||||
raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction')
|
||||
|
||||
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
|
||||
if not (x >= ge):
|
||||
raise PydanticKnownError('greater_than_equal', {'ge': ge})
|
||||
return x
|
||||
|
||||
|
||||
def less_than_validator(x: Any, lt: Any) -> Any:
|
||||
if not (x < lt):
|
||||
raise PydanticKnownError('less_than', {'lt': lt})
|
||||
return x
|
||||
|
||||
|
||||
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
|
||||
if not (x <= le):
|
||||
raise PydanticKnownError('less_than_equal', {'le': le})
|
||||
return x
|
||||
|
||||
|
||||
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
|
||||
if not (x % multiple_of == 0):
|
||||
raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of})
|
||||
return x
|
||||
|
||||
|
||||
def min_length_validator(x: Any, min_length: Any) -> Any:
|
||||
if not (len(x) >= min_length):
|
||||
raise PydanticKnownError(
|
||||
'too_short',
|
||||
{'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def max_length_validator(x: Any, max_length: Any) -> Any:
|
||||
if len(x) > max_length:
|
||||
raise PydanticKnownError(
|
||||
'too_long',
|
||||
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def forbid_inf_nan_check(x: Any) -> Any:
|
||||
if not math.isfinite(x):
|
||||
raise PydanticKnownError('finite_number')
|
||||
return x
|
||||
|
||||
|
||||
def _safe_repr(v: Any) -> int | float | str:
|
||||
"""The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta.
|
||||
|
||||
See tests/test_types.py::test_annotated_metadata_any_order for some context.
|
||||
"""
|
||||
if isinstance(v, (int, float, str)):
|
||||
return v
|
||||
return repr(v)
|
||||
|
||||
|
||||
def greater_than_validator(x: Any, gt: Any) -> Any:
|
||||
try:
|
||||
if not (x > gt):
|
||||
raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)})
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}")
|
||||
|
||||
|
||||
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
|
||||
try:
|
||||
if not (x >= ge):
|
||||
raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)})
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}")
|
||||
|
||||
|
||||
def less_than_validator(x: Any, lt: Any) -> Any:
|
||||
try:
|
||||
if not (x < lt):
|
||||
raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)})
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}")
|
||||
|
||||
|
||||
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
|
||||
try:
|
||||
if not (x <= le):
|
||||
raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)})
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}")
|
||||
|
||||
|
||||
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
|
||||
try:
|
||||
if x % multiple_of:
|
||||
raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)})
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}")
|
||||
|
||||
|
||||
def min_length_validator(x: Any, min_length: Any) -> Any:
|
||||
try:
|
||||
if not (len(x) >= min_length):
|
||||
raise PydanticKnownError(
|
||||
'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}
|
||||
)
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}")
|
||||
|
||||
|
||||
def max_length_validator(x: Any, max_length: Any) -> Any:
|
||||
try:
|
||||
if len(x) > max_length:
|
||||
raise PydanticKnownError(
|
||||
'too_long',
|
||||
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
|
||||
)
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}")
|
||||
|
||||
|
||||
def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]:
|
||||
"""Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance.
|
||||
|
||||
This function handles both normalized and non-normalized Decimal instances.
|
||||
Example: Decimal('1.230') -> 4 digits, 3 decimal places
|
||||
|
||||
Args:
|
||||
decimal (Decimal): The decimal number to analyze.
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: A tuple containing the number of decimal places and total digits.
|
||||
|
||||
Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation
|
||||
of the number of decimals and digits together.
|
||||
"""
|
||||
try:
|
||||
decimal_tuple = decimal.as_tuple()
|
||||
|
||||
assert isinstance(decimal_tuple.exponent, int)
|
||||
|
||||
exponent = decimal_tuple.exponent
|
||||
num_digits = len(decimal_tuple.digits)
|
||||
|
||||
if exponent >= 0:
|
||||
# A positive exponent adds that many trailing zeros
|
||||
# Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits
|
||||
num_digits += exponent
|
||||
decimal_places = 0
|
||||
else:
|
||||
# If the absolute value of the negative exponent is larger than the
|
||||
# number of digits, then it's the same as the number of digits,
|
||||
# because it'll consume all the digits in digit_tuple and then
|
||||
# add abs(exponent) - len(digit_tuple) leading zeros after the decimal point.
|
||||
# Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits
|
||||
# Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits
|
||||
decimal_places = abs(exponent)
|
||||
num_digits = max(num_digits, decimal_places)
|
||||
|
||||
return decimal_places, num_digits
|
||||
except (AssertionError, AttributeError):
|
||||
raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}')
|
||||
|
||||
|
||||
def max_digits_validator(x: Any, max_digits: Any) -> Any:
|
||||
try:
|
||||
_, num_digits = _extract_decimal_digits_info(x)
|
||||
_, normalized_num_digits = _extract_decimal_digits_info(x.normalize())
|
||||
if (num_digits > max_digits) and (normalized_num_digits > max_digits):
|
||||
raise PydanticKnownError(
|
||||
'decimal_max_digits',
|
||||
{'max_digits': max_digits},
|
||||
)
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}")
|
||||
|
||||
|
||||
def decimal_places_validator(x: Any, decimal_places: Any) -> Any:
|
||||
try:
|
||||
decimal_places_, _ = _extract_decimal_digits_info(x)
|
||||
if decimal_places_ > decimal_places:
|
||||
normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize())
|
||||
if normalized_decimal_places > decimal_places:
|
||||
raise PydanticKnownError(
|
||||
'decimal_max_places',
|
||||
{'decimal_places': decimal_places},
|
||||
)
|
||||
return x
|
||||
except TypeError:
|
||||
raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}")
|
||||
|
||||
|
||||
def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]:
|
||||
return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None))
|
||||
|
||||
|
||||
def defaultdict_validator(
|
||||
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
|
||||
) -> collections.defaultdict[Any, Any]:
|
||||
if isinstance(input_value, collections.defaultdict):
|
||||
default_factory = input_value.default_factory
|
||||
return collections.defaultdict(default_factory, handler(input_value))
|
||||
else:
|
||||
return collections.defaultdict(default_default_factory, handler(input_value))
|
||||
|
||||
|
||||
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
|
||||
FieldInfo = import_cached_field_info()
|
||||
|
||||
values_type_origin = get_origin(values_source_type)
|
||||
|
||||
def infer_default() -> Callable[[], Any]:
|
||||
allowed_default_types: dict[Any, Any] = {
|
||||
tuple: tuple,
|
||||
collections.abc.Sequence: tuple,
|
||||
collections.abc.MutableSequence: list,
|
||||
list: list,
|
||||
typing.Sequence: list,
|
||||
set: set,
|
||||
typing.MutableSet: set,
|
||||
collections.abc.MutableSet: set,
|
||||
collections.abc.Set: frozenset,
|
||||
typing.MutableMapping: dict,
|
||||
typing.Mapping: dict,
|
||||
collections.abc.Mapping: dict,
|
||||
collections.abc.MutableMapping: dict,
|
||||
float: float,
|
||||
int: int,
|
||||
str: str,
|
||||
bool: bool,
|
||||
}
|
||||
values_type = values_type_origin or values_source_type
|
||||
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
|
||||
if typing_objects.is_typevar(values_type):
|
||||
|
||||
def type_var_default_factory() -> None:
|
||||
raise RuntimeError(
|
||||
'Generic defaultdict cannot be used without a concrete value type or an'
|
||||
' explicit default factory, ' + instructions
|
||||
)
|
||||
|
||||
return type_var_default_factory
|
||||
elif values_type not in allowed_default_types:
|
||||
# a somewhat subjective set of types that have reasonable default values
|
||||
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
|
||||
raise PydanticSchemaGenerationError(
|
||||
f'Unable to infer a default factory for keys of type {values_source_type}.'
|
||||
f' Only {allowed_msg} are supported, other types require an explicit default factory'
|
||||
' ' + instructions
|
||||
)
|
||||
return allowed_default_types[values_type]
|
||||
|
||||
# Assume Annotated[..., Field(...)]
|
||||
if typing_objects.is_annotated(values_type_origin):
|
||||
field_info = next((v for v in typing_extensions.get_args(values_source_type) if isinstance(v, FieldInfo)), None)
|
||||
else:
|
||||
field_info = None
|
||||
if field_info and field_info.default_factory:
|
||||
# Assume the default factory does not take any argument:
|
||||
default_default_factory = cast(Callable[[], Any], field_info.default_factory)
|
||||
else:
|
||||
default_default_factory = infer_default()
|
||||
return default_default_factory
|
||||
|
||||
|
||||
def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo:
|
||||
if isinstance(value, ZoneInfo):
|
||||
return value
|
||||
try:
|
||||
return ZoneInfo(value)
|
||||
except (ZoneInfoNotFoundError, ValueError, TypeError):
|
||||
raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value})
|
||||
|
||||
|
||||
NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = {
|
||||
'gt': greater_than_validator,
|
||||
'ge': greater_than_or_equal_validator,
|
||||
'lt': less_than_validator,
|
||||
'le': less_than_or_equal_validator,
|
||||
'multiple_of': multiple_of_validator,
|
||||
'min_length': min_length_validator,
|
||||
'max_length': max_length_validator,
|
||||
'max_digits': max_digits_validator,
|
||||
'decimal_places': decimal_places_validator,
|
||||
}
|
||||
|
||||
IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface]
|
||||
|
||||
IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = {
|
||||
IPv4Address: ip_v4_address_validator,
|
||||
IPv6Address: ip_v6_address_validator,
|
||||
IPv4Network: ip_v4_network_validator,
|
||||
IPv6Network: ip_v6_network_validator,
|
||||
IPv4Interface: ip_v4_interface_validator,
|
||||
IPv6Interface: ip_v6_interface_validator,
|
||||
}
|
||||
|
||||
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
|
||||
typing.DefaultDict: collections.defaultdict, # noqa: UP006
|
||||
collections.defaultdict: collections.defaultdict,
|
||||
typing.OrderedDict: collections.OrderedDict, # noqa: UP006
|
||||
collections.OrderedDict: collections.OrderedDict,
|
||||
typing_extensions.OrderedDict: collections.OrderedDict,
|
||||
typing.Counter: collections.Counter,
|
||||
collections.Counter: collections.Counter,
|
||||
# this doesn't handle subclasses of these
|
||||
typing.Mapping: dict,
|
||||
typing.MutableMapping: dict,
|
||||
# parametrized typing.{Mutable}Mapping creates one of these
|
||||
collections.abc.Mapping: dict,
|
||||
collections.abc.MutableMapping: dict,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import sys
|
||||
from typing import Any, Callable
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from ._internal._validators import import_string
|
||||
from .version import version_short
|
||||
|
||||
MOVED_IN_V2 = {
|
||||
@@ -271,11 +273,7 @@ def getattr_migration(module: str) -> Callable[[str], Any]:
|
||||
The object.
|
||||
"""
|
||||
if name == '__path__':
|
||||
raise AttributeError(f'module {module!r} has no attribute {name!r}')
|
||||
|
||||
import warnings
|
||||
|
||||
from ._internal._validators import import_string
|
||||
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
|
||||
|
||||
import_path = f'{module}:{name}'
|
||||
if import_path in MOVED_IN_V2.keys():
|
||||
@@ -300,9 +298,9 @@ def getattr_migration(module: str) -> Callable[[str], Any]:
|
||||
)
|
||||
if import_path in REMOVED_IN_V2:
|
||||
raise PydanticImportError(f'`{import_path}` has been removed in V2.')
|
||||
globals: dict[str, Any] = sys.modules[module].__dict__
|
||||
globals: Dict[str, Any] = sys.modules[module].__dict__
|
||||
if name in globals:
|
||||
return globals[name]
|
||||
raise AttributeError(f'module {module!r} has no attribute {name!r}')
|
||||
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
"""Alias generators for converting between different capitalization conventions."""
|
||||
|
||||
import re
|
||||
|
||||
__all__ = ('to_pascal', 'to_camel', 'to_snake')
|
||||
|
||||
# TODO: in V3, change the argument names to be more descriptive
|
||||
# Generally, don't only convert from snake_case, or name the functions
|
||||
# more specifically like snake_to_camel.
|
||||
|
||||
|
||||
def to_pascal(snake: str) -> str:
|
||||
"""Convert a snake_case string to PascalCase.
|
||||
@@ -31,17 +26,12 @@ def to_camel(snake: str) -> str:
|
||||
Returns:
|
||||
The converted camelCase string.
|
||||
"""
|
||||
# If the string is already in camelCase and does not contain a digit followed
|
||||
# by a lowercase letter, return it as it is
|
||||
if re.match('^[a-z]+[A-Za-z0-9]*$', snake) and not re.search(r'\d[a-z]', snake):
|
||||
return snake
|
||||
|
||||
camel = to_pascal(snake)
|
||||
return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel)
|
||||
|
||||
|
||||
def to_snake(camel: str) -> str:
|
||||
"""Convert a PascalCase, camelCase, or kebab-case string to snake_case.
|
||||
"""Convert a PascalCase or camelCase string to snake_case.
|
||||
|
||||
Args:
|
||||
camel: The string to convert.
|
||||
@@ -49,14 +39,6 @@ def to_snake(camel: str) -> str:
|
||||
Returns:
|
||||
The converted string in snake_case.
|
||||
"""
|
||||
# Handle the sequence of uppercase letters followed by a lowercase letter
|
||||
snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
|
||||
# Insert an underscore between a lowercase letter and an uppercase letter
|
||||
snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
# Insert an underscore between a digit and an uppercase letter
|
||||
snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
# Insert an underscore between a lowercase letter and a digit
|
||||
snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
# Replace hyphens with underscores to handle kebab-case
|
||||
snake = snake.replace('-', '_')
|
||||
snake = re.sub(r'([a-zA-Z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
|
||||
snake = re.sub(r'([a-z0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
|
||||
return snake.lower()
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Support for alias configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from ._internal import _internal_dataclass
|
||||
|
||||
__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices')
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class AliasPath:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[`AliasPath` and `AliasChoices`](../concepts/alias.md#aliaspath-and-aliaschoices)
|
||||
|
||||
A data class used by `validation_alias` as a convenience to create aliases.
|
||||
|
||||
Attributes:
|
||||
path: A list of string or integer aliases.
|
||||
"""
|
||||
|
||||
path: list[int | str]
|
||||
|
||||
def __init__(self, first_arg: str, *args: str | int) -> None:
|
||||
self.path = [first_arg] + list(args)
|
||||
|
||||
def convert_to_aliases(self) -> list[str | int]:
|
||||
"""Converts arguments to a list of string or integer aliases.
|
||||
|
||||
Returns:
|
||||
The list of aliases.
|
||||
"""
|
||||
return self.path
|
||||
|
||||
def search_dict_for_path(self, d: dict) -> Any:
|
||||
"""Searches a dictionary for the path specified by the alias.
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or `PydanticUndefined` if the path is not found.
|
||||
"""
|
||||
v = d
|
||||
for k in self.path:
|
||||
if isinstance(v, str):
|
||||
# disallow indexing into a str, like for AliasPath('x', 0) and x='abc'
|
||||
return PydanticUndefined
|
||||
try:
|
||||
v = v[k]
|
||||
except (KeyError, IndexError, TypeError):
|
||||
return PydanticUndefined
|
||||
return v
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class AliasChoices:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[`AliasPath` and `AliasChoices`](../concepts/alias.md#aliaspath-and-aliaschoices)
|
||||
|
||||
A data class used by `validation_alias` as a convenience to create aliases.
|
||||
|
||||
Attributes:
|
||||
choices: A list containing a string or `AliasPath`.
|
||||
"""
|
||||
|
||||
choices: list[str | AliasPath]
|
||||
|
||||
def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None:
|
||||
self.choices = [first_choice] + list(choices)
|
||||
|
||||
def convert_to_aliases(self) -> list[list[str | int]]:
|
||||
"""Converts arguments to a list of lists containing string or integer aliases.
|
||||
|
||||
Returns:
|
||||
The list of aliases.
|
||||
"""
|
||||
aliases: list[list[str | int]] = []
|
||||
for c in self.choices:
|
||||
if isinstance(c, AliasPath):
|
||||
aliases.append(c.convert_to_aliases())
|
||||
else:
|
||||
aliases.append([c])
|
||||
return aliases
|
||||
|
||||
|
||||
@dataclasses.dataclass(**_internal_dataclass.slots_true)
|
||||
class AliasGenerator:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[Using an `AliasGenerator`](../concepts/alias.md#using-an-aliasgenerator)
|
||||
|
||||
A data class used by `alias_generator` as a convenience to create various aliases.
|
||||
|
||||
Attributes:
|
||||
alias: A callable that takes a field name and returns an alias for it.
|
||||
validation_alias: A callable that takes a field name and returns a validation alias for it.
|
||||
serialization_alias: A callable that takes a field name and returns a serialization alias for it.
|
||||
"""
|
||||
|
||||
alias: Callable[[str], str] | None = None
|
||||
validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None
|
||||
serialization_alias: Callable[[str], str] | None = None
|
||||
|
||||
def _generate_alias(
|
||||
self,
|
||||
alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'],
|
||||
allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...],
|
||||
field_name: str,
|
||||
) -> str | AliasPath | AliasChoices | None:
|
||||
"""Generate an alias of the specified kind. Returns None if the alias generator is None.
|
||||
|
||||
Raises:
|
||||
TypeError: If the alias generator produces an invalid type.
|
||||
"""
|
||||
alias = None
|
||||
if alias_generator := getattr(self, alias_kind):
|
||||
alias = alias_generator(field_name)
|
||||
if alias and not isinstance(alias, allowed_types):
|
||||
raise TypeError(
|
||||
f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`'
|
||||
)
|
||||
return alias
|
||||
|
||||
def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]:
|
||||
"""Generate `alias`, `validation_alias`, and `serialization_alias` for a field.
|
||||
|
||||
Returns:
|
||||
A tuple of three aliases - validation, alias, and serialization.
|
||||
"""
|
||||
alias = self._generate_alias('alias', (str,), field_name)
|
||||
validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name)
|
||||
serialization_alias = self._generate_alias('serialization_alias', (str,), field_name)
|
||||
|
||||
return alias, validation_alias, serialization_alias # type: ignore
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
@@ -7,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Union
|
||||
from pydantic_core import core_schema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._internal._namespace_utils import NamespacesTuple
|
||||
from .json_schema import JsonSchemaMode, JsonSchemaValue
|
||||
|
||||
CoreSchemaOrField = Union[
|
||||
@@ -30,7 +28,7 @@ class GetJsonSchemaHandler:
|
||||
|
||||
mode: JsonSchemaMode
|
||||
|
||||
def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue:
|
||||
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
|
||||
"""Call the inner handler and get the JsonSchemaValue it returns.
|
||||
This will call the next JSON schema modifying function up until it calls
|
||||
into `pydantic.json_schema.GenerateJsonSchema`, which will raise a
|
||||
@@ -38,7 +36,7 @@ class GetJsonSchemaHandler:
|
||||
a JSON schema.
|
||||
|
||||
Args:
|
||||
core_schema: A `pydantic_core.core_schema.CoreSchema`.
|
||||
__core_schema: A `pydantic_core.core_schema.CoreSchema`.
|
||||
|
||||
Returns:
|
||||
JsonSchemaValue: The JSON schema generated by the inner JSON schema modify
|
||||
@@ -46,13 +44,13 @@ class GetJsonSchemaHandler:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue, /) -> JsonSchemaValue:
|
||||
def resolve_ref_schema(self, __maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
|
||||
"""Get the real schema for a `{"$ref": ...}` schema.
|
||||
If the schema given is not a `$ref` schema, it will be returned as is.
|
||||
This means you don't have to check before calling this function.
|
||||
|
||||
Args:
|
||||
maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema.
|
||||
__maybe_ref_json_schema: A JsonSchemaValue, ref based or not.
|
||||
|
||||
Raises:
|
||||
LookupError: If the ref is not found.
|
||||
@@ -66,7 +64,7 @@ class GetJsonSchemaHandler:
|
||||
class GetCoreSchemaHandler:
|
||||
"""Handler to call into the next CoreSchema schema generation function."""
|
||||
|
||||
def __call__(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
|
||||
"""Call the inner handler and get the CoreSchema it returns.
|
||||
This will call the next CoreSchema modifying function up until it calls
|
||||
into Pydantic's internal schema generation machinery, which will raise a
|
||||
@@ -74,14 +72,14 @@ class GetCoreSchemaHandler:
|
||||
a CoreSchema for the given source type.
|
||||
|
||||
Args:
|
||||
source_type: The input type.
|
||||
__source_type: The input type.
|
||||
|
||||
Returns:
|
||||
CoreSchema: The `pydantic-core` CoreSchema generated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema:
|
||||
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
|
||||
"""Generate a schema unrelated to the current context.
|
||||
Use this function if e.g. you are handling schema generation for a sequence
|
||||
and want to generate a schema for its items.
|
||||
@@ -89,20 +87,20 @@ class GetCoreSchemaHandler:
|
||||
that was intended for the sequence itself to its items!
|
||||
|
||||
Args:
|
||||
source_type: The input type.
|
||||
__source_type: The input type.
|
||||
|
||||
Returns:
|
||||
CoreSchema: The `pydantic-core` CoreSchema generated.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema, /) -> core_schema.CoreSchema:
|
||||
def resolve_ref_schema(self, __maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
|
||||
"""Get the real schema for a `definition-ref` schema.
|
||||
If the schema given is not a `definition-ref` schema, it will be returned as is.
|
||||
This means you don't have to check before calling this function.
|
||||
|
||||
Args:
|
||||
maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
|
||||
__maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
|
||||
|
||||
Raises:
|
||||
LookupError: If the `ref` is not found.
|
||||
@@ -117,6 +115,6 @@ class GetCoreSchemaHandler:
|
||||
"""Get the name of the closest field to this validator."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_types_namespace(self) -> NamespacesTuple:
|
||||
def _get_types_namespace(self) -> dict[str, Any] | None:
|
||||
"""Internal method used during type resolution for serializer annotations."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""`class_validators` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -11,11 +11,10 @@ Warning: Deprecated
|
||||
See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md)
|
||||
for more information.
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from colorsys import hls_to_rgb, rgb_to_hls
|
||||
from typing import Any, Callable, Optional, Union, cast
|
||||
from typing import Any, Callable, Optional, Tuple, Type, Union, cast
|
||||
|
||||
from pydantic_core import CoreSchema, PydanticCustomError, core_schema
|
||||
from typing_extensions import deprecated
|
||||
@@ -25,9 +24,9 @@ from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJso
|
||||
from .json_schema import JsonSchemaValue
|
||||
from .warnings import PydanticDeprecatedSince20
|
||||
|
||||
ColorTuple = Union[tuple[int, int, int], tuple[int, int, int, float]]
|
||||
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
|
||||
ColorType = Union[ColorTuple, str]
|
||||
HslColorTuple = Union[tuple[float, float, float], tuple[float, float, float, float]]
|
||||
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
|
||||
|
||||
|
||||
class RGBA:
|
||||
@@ -41,7 +40,7 @@ class RGBA:
|
||||
self.b = b
|
||||
self.alpha = alpha
|
||||
|
||||
self._tuple: tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
|
||||
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
|
||||
|
||||
def __getitem__(self, item: Any) -> Any:
|
||||
return self._tuple[item]
|
||||
@@ -56,13 +55,13 @@ _r_sl = r'(\d{1,3}(?:\.\d+)?)%'
|
||||
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
|
||||
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
|
||||
# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%)
|
||||
r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
|
||||
r_rgb = fr'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
|
||||
# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%)
|
||||
r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
|
||||
r_hsl = fr'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
|
||||
# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%)
|
||||
r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
|
||||
r_rgb_v4_style = fr'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
|
||||
# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%)
|
||||
r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
|
||||
r_hsl_v4_style = fr'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
|
||||
|
||||
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
|
||||
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
|
||||
@@ -124,7 +123,7 @@ class Color(_repr.Representation):
|
||||
ValueError: When no named color is found and fallback is `False`.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
rgb = cast(tuple[int, int, int], self.as_rgb_tuple())
|
||||
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
|
||||
try:
|
||||
return COLORS_BY_VALUE[rgb]
|
||||
except KeyError as e:
|
||||
@@ -232,7 +231,7 @@ class Color(_repr.Representation):
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: type[Any], handler: Callable[[Any], CoreSchema]
|
||||
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
|
||||
) -> core_schema.CoreSchema:
|
||||
return core_schema.with_info_plain_validator_function(
|
||||
cls._validate, serialization=core_schema.to_string_ser_schema()
|
||||
@@ -255,7 +254,7 @@ class Color(_repr.Representation):
|
||||
return hash(self.as_rgb_tuple())
|
||||
|
||||
|
||||
def parse_tuple(value: tuple[Any, ...]) -> RGBA:
|
||||
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
|
||||
"""Parse a tuple or list to get RGBA values.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,33 +1,23 @@
|
||||
"""Configuration for Pydantic models."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
|
||||
|
||||
from typing_extensions import TypeAlias, TypedDict, Unpack, deprecated
|
||||
from typing_extensions import Literal, TypeAlias, TypedDict
|
||||
|
||||
from ._migration import getattr_migration
|
||||
from .aliases import AliasGenerator
|
||||
from .errors import PydanticUserError
|
||||
from .warnings import PydanticDeprecatedSince211
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._internal._generate_schema import GenerateSchema as _GenerateSchema
|
||||
from .fields import ComputedFieldInfo, FieldInfo
|
||||
|
||||
__all__ = ('ConfigDict', 'with_config')
|
||||
__all__ = ('ConfigDict',)
|
||||
|
||||
|
||||
JsonValue: TypeAlias = Union[int, float, str, bool, None, list['JsonValue'], 'JsonDict']
|
||||
JsonDict: TypeAlias = dict[str, JsonValue]
|
||||
|
||||
JsonEncoder = Callable[[Any], Any]
|
||||
|
||||
JsonSchemaExtraCallable: TypeAlias = Union[
|
||||
Callable[[JsonDict], None],
|
||||
Callable[[JsonDict, type[Any]], None],
|
||||
Callable[[Dict[str, Any]], None],
|
||||
Callable[[Dict[str, Any], Type[Any]], None],
|
||||
]
|
||||
|
||||
ExtraValues = Literal['allow', 'ignore', 'forbid']
|
||||
@@ -39,18 +29,11 @@ class ConfigDict(TypedDict, total=False):
|
||||
title: str | None
|
||||
"""The title for the generated JSON schema, defaults to the model's name"""
|
||||
|
||||
model_title_generator: Callable[[type], str] | None
|
||||
"""A callable that takes a model class and returns the title for it. Defaults to `None`."""
|
||||
|
||||
field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None
|
||||
"""A callable that takes a field's name and info and returns title for it. Defaults to `None`."""
|
||||
|
||||
str_to_lower: bool
|
||||
"""Whether to convert all characters to lowercase for str types. Defaults to `False`."""
|
||||
|
||||
str_to_upper: bool
|
||||
"""Whether to convert all characters to uppercase for str types. Defaults to `False`."""
|
||||
|
||||
str_strip_whitespace: bool
|
||||
"""Whether to strip leading and trailing whitespace for str types."""
|
||||
|
||||
@@ -61,108 +44,84 @@ class ConfigDict(TypedDict, total=False):
|
||||
"""The maximum length for str types. Defaults to `None`."""
|
||||
|
||||
extra: ExtraValues | None
|
||||
'''
|
||||
Whether to ignore, allow, or forbid extra data during model initialization. Defaults to `'ignore'`.
|
||||
"""
|
||||
Whether to ignore, allow, or forbid extra attributes during model initialization. Defaults to `'ignore'`.
|
||||
|
||||
Three configuration values are available:
|
||||
You can configure how pydantic handles the attributes that are not defined in the model:
|
||||
|
||||
- `'ignore'`: Providing extra data is ignored (the default):
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
* `allow` - Allow any extra attributes.
|
||||
* `forbid` - Forbid any extra attributes.
|
||||
* `ignore` - Ignore any extra attributes.
|
||||
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(extra='ignore') # (1)!
|
||||
|
||||
name: str
|
||||
|
||||
user = User(name='John Doe', age=20) # (2)!
|
||||
print(user)
|
||||
#> name='John Doe'
|
||||
```
|
||||
|
||||
1. This is the default behaviour.
|
||||
2. The `age` argument is ignored.
|
||||
|
||||
- `'forbid'`: Providing extra data is not permitted, and a [`ValidationError`][pydantic_core.ValidationError]
|
||||
will be raised if this is the case:
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
x: int
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(extra='ignore') # (1)!
|
||||
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
name: str
|
||||
|
||||
|
||||
try:
|
||||
Model(x=1, y='a')
|
||||
except ValidationError as exc:
|
||||
print(exc)
|
||||
"""
|
||||
1 validation error for Model
|
||||
y
|
||||
Extra inputs are not permitted [type=extra_forbidden, input_value='a', input_type=str]
|
||||
"""
|
||||
```
|
||||
user = User(name='John Doe', age=20) # (2)!
|
||||
print(user)
|
||||
#> name='John Doe'
|
||||
```
|
||||
|
||||
- `'allow'`: Providing extra data is allowed and stored in the `__pydantic_extra__` dictionary attribute:
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
1. This is the default behaviour.
|
||||
2. The `age` argument is ignored.
|
||||
|
||||
Instead, with `extra='allow'`, the `age` argument is included:
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
x: int
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
name: str
|
||||
|
||||
|
||||
m = Model(x=1, y='a')
|
||||
assert m.__pydantic_extra__ == {'y': 'a'}
|
||||
```
|
||||
By default, no validation will be applied to these extra items, but you can set a type for the values by overriding
|
||||
the type annotation for `__pydantic_extra__`:
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
user = User(name='John Doe', age=20) # (1)!
|
||||
print(user)
|
||||
#> name='John Doe' age=20
|
||||
```
|
||||
|
||||
1. The `age` argument is included.
|
||||
|
||||
With `extra='forbid'`, an error is raised:
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
__pydantic_extra__: dict[str, int] = Field(init=False) # (1)!
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
|
||||
x: int
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
name: str
|
||||
|
||||
|
||||
try:
|
||||
Model(x=1, y='a')
|
||||
except ValidationError as exc:
|
||||
print(exc)
|
||||
"""
|
||||
1 validation error for Model
|
||||
y
|
||||
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
|
||||
"""
|
||||
|
||||
m = Model(x=1, y='2')
|
||||
assert m.x == 1
|
||||
assert m.y == 2
|
||||
assert m.model_dump() == {'x': 1, 'y': 2}
|
||||
assert m.__pydantic_extra__ == {'y': 2}
|
||||
```
|
||||
|
||||
1. The `= Field(init=False)` does not have any effect at runtime, but prevents the `__pydantic_extra__` field from
|
||||
being included as a parameter to the model's `__init__` method by type checkers.
|
||||
'''
|
||||
try:
|
||||
User(name='John Doe', age=20)
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for User
|
||||
age
|
||||
Extra inputs are not permitted [type=extra_forbidden, input_value=20, input_type=int]
|
||||
'''
|
||||
```
|
||||
"""
|
||||
|
||||
frozen: bool
|
||||
"""
|
||||
Whether models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates
|
||||
Whether or not models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates
|
||||
a `__hash__()` method for the model. This makes instances of the model potentially hashable if all the
|
||||
attributes are hashable. Defaults to `False`.
|
||||
|
||||
Note:
|
||||
On V1, the inverse of this setting was called `allow_mutation`, and was `True` by default.
|
||||
On V1, this setting was called `allow_mutation`, and was `True` by default.
|
||||
"""
|
||||
|
||||
populate_by_name: bool
|
||||
@@ -170,77 +129,38 @@ class ConfigDict(TypedDict, total=False):
|
||||
Whether an aliased field may be populated by its name as given by the model
|
||||
attribute, as well as the alias. Defaults to `False`.
|
||||
|
||||
!!! warning
|
||||
`populate_by_name` usage is not recommended in v2.11+ and will be deprecated in v3.
|
||||
Instead, you should use the [`validate_by_name`][pydantic.config.ConfigDict.validate_by_name] configuration setting.
|
||||
Note:
|
||||
The name of this configuration setting was changed in **v2.0** from
|
||||
`allow_population_by_alias` to `populate_by_name`.
|
||||
|
||||
When `validate_by_name=True` and `validate_by_alias=True`, this is strictly equivalent to the
|
||||
previous behavior of `populate_by_name=True`.
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
In v2.11, we also introduced a [`validate_by_alias`][pydantic.config.ConfigDict.validate_by_alias] setting that introduces more fine grained
|
||||
control for validation behavior.
|
||||
|
||||
Here's how you might go about using the new settings to achieve the same behavior:
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
name: str = Field(alias='full_name') # (1)!
|
||||
age: int
|
||||
|
||||
class Model(BaseModel):
|
||||
model_config = ConfigDict(validate_by_name=True, validate_by_alias=True)
|
||||
|
||||
my_field: str = Field(alias='my_alias') # (1)!
|
||||
user = User(full_name='John Doe', age=20) # (2)!
|
||||
print(user)
|
||||
#> name='John Doe' age=20
|
||||
user = User(name='John Doe', age=20) # (3)!
|
||||
print(user)
|
||||
#> name='John Doe' age=20
|
||||
```
|
||||
|
||||
m = Model(my_alias='foo') # (2)!
|
||||
print(m)
|
||||
#> my_field='foo'
|
||||
|
||||
m = Model(my_alias='foo') # (3)!
|
||||
print(m)
|
||||
#> my_field='foo'
|
||||
```
|
||||
|
||||
1. The field `'my_field'` has an alias `'my_alias'`.
|
||||
2. The model is populated by the alias `'my_alias'`.
|
||||
3. The model is populated by the attribute name `'my_field'`.
|
||||
1. The field `'name'` has an alias `'full_name'`.
|
||||
2. The model is populated by the alias `'full_name'`.
|
||||
3. The model is populated by the field name `'name'`.
|
||||
"""
|
||||
|
||||
use_enum_values: bool
|
||||
"""
|
||||
Whether to populate models with the `value` property of enums, rather than the raw enum.
|
||||
This may be useful if you want to serialize `model.model_dump()` later. Defaults to `False`.
|
||||
|
||||
!!! note
|
||||
If you have an `Optional[Enum]` value that you set a default for, you need to use `validate_default=True`
|
||||
for said Field to ensure that the `use_enum_values` flag takes effect on the default, as extracting an
|
||||
enum's value occurs during validation, not serialization.
|
||||
|
||||
```python
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class SomeEnum(Enum):
|
||||
FOO = 'foo'
|
||||
BAR = 'bar'
|
||||
BAZ = 'baz'
|
||||
|
||||
class SomeModel(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
some_enum: SomeEnum
|
||||
another_enum: Optional[SomeEnum] = Field(
|
||||
default=SomeEnum.FOO, validate_default=True
|
||||
)
|
||||
|
||||
model1 = SomeModel(some_enum=SomeEnum.BAR)
|
||||
print(model1.model_dump())
|
||||
#> {'some_enum': 'bar', 'another_enum': 'foo'}
|
||||
|
||||
model2 = SomeModel(some_enum=SomeEnum.BAR, another_enum=SomeEnum.BAZ)
|
||||
print(model2.model_dump())
|
||||
#> {'some_enum': 'bar', 'another_enum': 'baz'}
|
||||
```
|
||||
"""
|
||||
|
||||
validate_assignment: bool
|
||||
@@ -251,7 +171,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
In case the user changes the data after the model is created, the model is _not_ revalidated.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel
|
||||
|
||||
class User(BaseModel):
|
||||
@@ -270,7 +190,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
In case you want to revalidate the model when the data is changed, you can use `validate_assignment=True`:
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
class User(BaseModel, validate_assignment=True): # (1)!
|
||||
@@ -299,7 +219,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
"""
|
||||
Whether arbitrary types are allowed for field types. Defaults to `False`.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
# This is not a pydantic model, it's an arbitrary class
|
||||
@@ -358,20 +278,14 @@ class ConfigDict(TypedDict, total=False):
|
||||
loc_by_alias: bool
|
||||
"""Whether to use the actual key provided in the data (e.g. alias) for error `loc`s rather than the field's name. Defaults to `True`."""
|
||||
|
||||
alias_generator: Callable[[str], str] | AliasGenerator | None
|
||||
alias_generator: Callable[[str], str] | None
|
||||
"""
|
||||
A callable that takes a field name and returns an alias for it
|
||||
or an instance of [`AliasGenerator`][pydantic.aliases.AliasGenerator]. Defaults to `None`.
|
||||
|
||||
When using a callable, the alias generator is used for both validation and serialization.
|
||||
If you want to use different alias generators for validation and serialization, you can use
|
||||
[`AliasGenerator`][pydantic.aliases.AliasGenerator] instead.
|
||||
A callable that takes a field name and returns an alias for it.
|
||||
|
||||
If data source field names do not match your code style (e. g. CamelCase fields),
|
||||
you can automatically generate aliases using `alias_generator`. Here's an example with
|
||||
a basic callable:
|
||||
you can automatically generate aliases using `alias_generator`:
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.alias_generators import to_pascal
|
||||
|
||||
@@ -388,30 +302,6 @@ class ConfigDict(TypedDict, total=False):
|
||||
#> {'Name': 'Filiz', 'LanguageCode': 'tr-TR'}
|
||||
```
|
||||
|
||||
If you want to use different alias generators for validation and serialization, you can use
|
||||
[`AliasGenerator`][pydantic.aliases.AliasGenerator].
|
||||
|
||||
```python
|
||||
from pydantic import AliasGenerator, BaseModel, ConfigDict
|
||||
from pydantic.alias_generators import to_camel, to_pascal
|
||||
|
||||
class Athlete(BaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
sport: str
|
||||
|
||||
model_config = ConfigDict(
|
||||
alias_generator=AliasGenerator(
|
||||
validation_alias=to_camel,
|
||||
serialization_alias=to_pascal,
|
||||
)
|
||||
)
|
||||
|
||||
athlete = Athlete(firstName='John', lastName='Doe', sport='track')
|
||||
print(athlete.model_dump(by_alias=True))
|
||||
#> {'FirstName': 'John', 'LastName': 'Doe', 'Sport': 'track'}
|
||||
```
|
||||
|
||||
Note:
|
||||
Pydantic offers three built-in alias generators: [`to_pascal`][pydantic.alias_generators.to_pascal],
|
||||
[`to_camel`][pydantic.alias_generators.to_camel], and [`to_snake`][pydantic.alias_generators.to_snake].
|
||||
@@ -425,9 +315,9 @@ class ConfigDict(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
allow_inf_nan: bool
|
||||
"""Whether to allow infinity (`+inf` an `-inf`) and NaN values to float and decimal fields. Defaults to `True`."""
|
||||
"""Whether to allow infinity (`+inf` an `-inf`) and NaN values to float fields. Defaults to `True`."""
|
||||
|
||||
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
|
||||
json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None
|
||||
"""A dict or callable to provide extra JSON schema properties. Defaults to `None`."""
|
||||
|
||||
json_encoders: dict[type[object], JsonEncoder] | None
|
||||
@@ -452,7 +342,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
To configure strict mode for all fields on a model, you can set `strict=True` on the model.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class Model(BaseModel):
|
||||
@@ -480,14 +370,16 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
By default, model and dataclass instances are not revalidated during validation.
|
||||
|
||||
```python
|
||||
```py
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class User(BaseModel, revalidate_instances='never'): # (1)!
|
||||
hobbies: list[str]
|
||||
hobbies: List[str]
|
||||
|
||||
class SubUser(User):
|
||||
sins: list[str]
|
||||
sins: List[str]
|
||||
|
||||
class Transaction(BaseModel):
|
||||
user: User
|
||||
@@ -515,14 +407,16 @@ class ConfigDict(TypedDict, total=False):
|
||||
If you want to revalidate instances during validation, you can set `revalidate_instances` to `'always'`
|
||||
in the model's config.
|
||||
|
||||
```python
|
||||
```py
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
class User(BaseModel, revalidate_instances='always'): # (1)!
|
||||
hobbies: list[str]
|
||||
hobbies: List[str]
|
||||
|
||||
class SubUser(User):
|
||||
sins: list[str]
|
||||
sins: List[str]
|
||||
|
||||
class Transaction(BaseModel):
|
||||
user: User
|
||||
@@ -556,14 +450,16 @@ class ConfigDict(TypedDict, total=False):
|
||||
It's also possible to set `revalidate_instances` to `'subclass-instances'` to only revalidate instances
|
||||
of subclasses of the model.
|
||||
|
||||
```python
|
||||
```py
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class User(BaseModel, revalidate_instances='subclass-instances'): # (1)!
|
||||
hobbies: list[str]
|
||||
hobbies: List[str]
|
||||
|
||||
class SubUser(User):
|
||||
sins: list[str]
|
||||
sins: List[str]
|
||||
|
||||
class Transaction(BaseModel):
|
||||
user: User
|
||||
@@ -598,33 +494,13 @@ class ConfigDict(TypedDict, total=False):
|
||||
- `'float'` will serialize timedeltas to the total number of seconds.
|
||||
"""
|
||||
|
||||
ser_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
ser_json_bytes: Literal['utf8', 'base64']
|
||||
"""
|
||||
The encoding of JSON serialized bytes. Defaults to `'utf8'`.
|
||||
Set equal to `val_json_bytes` to get back an equal value after serialization round trip.
|
||||
The encoding of JSON serialized bytes. Accepts the string values of `'utf8'` and `'base64'`.
|
||||
Defaults to `'utf8'`.
|
||||
|
||||
- `'utf8'` will serialize bytes to UTF-8 strings.
|
||||
- `'base64'` will serialize bytes to URL safe base64 strings.
|
||||
- `'hex'` will serialize bytes to hexadecimal strings.
|
||||
"""
|
||||
|
||||
val_json_bytes: Literal['utf8', 'base64', 'hex']
|
||||
"""
|
||||
The encoding of JSON serialized bytes to decode. Defaults to `'utf8'`.
|
||||
Set equal to `ser_json_bytes` to get back an equal value after serialization round trip.
|
||||
|
||||
- `'utf8'` will deserialize UTF-8 strings to bytes.
|
||||
- `'base64'` will deserialize URL safe base64 strings to bytes.
|
||||
- `'hex'` will deserialize hexadecimal strings to bytes.
|
||||
"""
|
||||
|
||||
ser_json_inf_nan: Literal['null', 'constants', 'strings']
|
||||
"""
|
||||
The encoding of JSON serialized infinity and NaN float values. Defaults to `'null'`.
|
||||
|
||||
- `'null'` will serialize infinity and NaN values as `null`.
|
||||
- `'constants'` will serialize infinity and NaN values as `Infinity` and `NaN`.
|
||||
- `'strings'` will serialize infinity as string `"Infinity"` and NaN as string `"NaN"`.
|
||||
"""
|
||||
|
||||
# whether to validate default values during validation, default False
|
||||
@@ -632,26 +508,17 @@ class ConfigDict(TypedDict, total=False):
|
||||
"""Whether to validate default values during validation. Defaults to `False`."""
|
||||
|
||||
validate_return: bool
|
||||
"""Whether to validate the return value from call validators. Defaults to `False`."""
|
||||
"""whether to validate the return value from call validators. Defaults to `False`."""
|
||||
|
||||
protected_namespaces: tuple[str | Pattern[str], ...]
|
||||
protected_namespaces: tuple[str, ...]
|
||||
"""
|
||||
A `tuple` of strings and/or patterns that prevent models from having fields with names that conflict with them.
|
||||
For strings, we match on a prefix basis. Ex, if 'dog' is in the protected namespace, 'dog_name' will be protected.
|
||||
For patterns, we match on the entire field name. Ex, if `re.compile(r'^dog$')` is in the protected namespace, 'dog' will be protected, but 'dog_name' will not be.
|
||||
Defaults to `('model_validate', 'model_dump',)`.
|
||||
A `tuple` of strings that prevent model to have field which conflict with them.
|
||||
Defaults to `('model_', )`).
|
||||
|
||||
The reason we've selected these is to prevent collisions with other validation / dumping formats
|
||||
in the future - ex, `model_validate_{some_newly_supported_format}`.
|
||||
Pydantic prevents collisions between model attributes and `BaseModel`'s own methods by
|
||||
namespacing them with the prefix `model_`.
|
||||
|
||||
Before v2.10, Pydantic used `('model_',)` as the default value for this setting to
|
||||
prevent collisions between model attributes and `BaseModel`'s own methods. This was changed
|
||||
in v2.10 given feedback that this restriction was limiting in AI and data science contexts,
|
||||
where it is common to have fields with names like `model_id`, `model_input`, `model_output`, etc.
|
||||
|
||||
For more details, see https://github.com/pydantic/pydantic/issues/10315.
|
||||
|
||||
```python
|
||||
```py
|
||||
import warnings
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -661,65 +528,56 @@ class ConfigDict(TypedDict, total=False):
|
||||
try:
|
||||
|
||||
class Model(BaseModel):
|
||||
model_dump_something: str
|
||||
model_prefixed_field: str
|
||||
|
||||
except UserWarning as e:
|
||||
print(e)
|
||||
'''
|
||||
Field "model_dump_something" in Model has conflict with protected namespace "model_dump".
|
||||
Field "model_prefixed_field" has conflict with protected namespace "model_".
|
||||
|
||||
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('model_validate',)`.
|
||||
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
|
||||
'''
|
||||
```
|
||||
|
||||
You can customize this behavior using the `protected_namespaces` setting:
|
||||
|
||||
```python {test="skip"}
|
||||
import re
|
||||
```py
|
||||
import warnings
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter('always') # Catch all warnings
|
||||
warnings.filterwarnings('error') # Raise warnings as errors
|
||||
|
||||
try:
|
||||
|
||||
class Model(BaseModel):
|
||||
safe_field: str
|
||||
model_prefixed_field: str
|
||||
also_protect_field: str
|
||||
protect_this: str
|
||||
|
||||
model_config = ConfigDict(
|
||||
protected_namespaces=(
|
||||
'protect_me_',
|
||||
'also_protect_',
|
||||
re.compile('^protect_this$'),
|
||||
)
|
||||
protected_namespaces=('protect_me_', 'also_protect_')
|
||||
)
|
||||
|
||||
for warning in caught_warnings:
|
||||
print(f'{warning.message}')
|
||||
except UserWarning as e:
|
||||
print(e)
|
||||
'''
|
||||
Field "also_protect_field" in Model has conflict with protected namespace "also_protect_".
|
||||
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_', re.compile('^protect_this$'))`.
|
||||
Field "also_protect_field" has conflict with protected namespace "also_protect_".
|
||||
|
||||
Field "protect_this" in Model has conflict with protected namespace "re.compile('^protect_this$')".
|
||||
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_', 'also_protect_')`.
|
||||
You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_',)`.
|
||||
'''
|
||||
```
|
||||
|
||||
While Pydantic will only emit a warning when an item is in a protected namespace but does not actually have a collision,
|
||||
an error _is_ raised if there is an actual collision with an existing attribute:
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
```py
|
||||
from pydantic import BaseModel
|
||||
|
||||
try:
|
||||
|
||||
class Model(BaseModel):
|
||||
model_validate: str
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=('model_',))
|
||||
|
||||
except NameError as e:
|
||||
print(e)
|
||||
'''
|
||||
@@ -734,7 +592,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
Pydantic shows the input value and type when it raises `ValidationError` during the validation.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
class Model(BaseModel):
|
||||
@@ -753,7 +611,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
You can hide the input value and type by setting the `hide_input_in_errors` config to `True`.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
|
||||
class Model(BaseModel):
|
||||
@@ -774,26 +632,27 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
defer_build: bool
|
||||
"""
|
||||
Whether to defer model validator and serializer construction until the first model validation. Defaults to False.
|
||||
Whether to defer model validator and serializer construction until the first model validation.
|
||||
|
||||
This can be useful to avoid the overhead of building models which are only
|
||||
used nested within other models, or when you want to manually define type namespace via
|
||||
[`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild].
|
||||
|
||||
Since v2.10, this setting also applies to pydantic dataclasses and TypeAdapter instances.
|
||||
[`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild]. Defaults to False.
|
||||
"""
|
||||
|
||||
plugin_settings: dict[str, object] | None
|
||||
"""A `dict` of settings for plugins. Defaults to `None`."""
|
||||
"""A `dict` of settings for plugins. Defaults to `None`.
|
||||
|
||||
See [Pydantic Plugins](../concepts/plugins.md) for details.
|
||||
"""
|
||||
|
||||
schema_generator: type[_GenerateSchema] | None
|
||||
"""
|
||||
!!! warning
|
||||
`schema_generator` is deprecated in v2.10.
|
||||
A custom core schema generator class to use when generating JSON schemas.
|
||||
Useful if you want to change the way types are validated across an entire model/schema. Defaults to `None`.
|
||||
|
||||
Prior to v2.10, this setting was advertised as highly subject to change.
|
||||
It's possible that this interface may once again become public once the internal core schema generation
|
||||
API is more stable, but that will likely come after significant performance improvements have been made.
|
||||
The `GenerateSchema` interface is subject to change, currently only the `string_schema` method is public.
|
||||
|
||||
See [#6737](https://github.com/pydantic/pydantic/pull/6737) for details.
|
||||
"""
|
||||
|
||||
json_schema_serialization_defaults_required: bool
|
||||
@@ -807,7 +666,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
between validation and serialization, and don't mind fields with defaults being marked as not required during
|
||||
serialization. See [#7209](https://github.com/pydantic/pydantic/issues/7209) for more details.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class Model(BaseModel):
|
||||
@@ -850,7 +709,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
the validation and serialization schemas (since both will use the specified schema), and so prevents the suffixes
|
||||
from being added to the definition references.
|
||||
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict, Json
|
||||
|
||||
class Model(BaseModel):
|
||||
@@ -896,7 +755,7 @@ class ConfigDict(TypedDict, total=False):
|
||||
|
||||
Pydantic doesn't allow number types (`int`, `float`, `Decimal`) to be coerced as type `str` by default.
|
||||
|
||||
```python
|
||||
```py
|
||||
from decimal import Decimal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError
|
||||
@@ -928,286 +787,5 @@ class ConfigDict(TypedDict, total=False):
|
||||
```
|
||||
"""
|
||||
|
||||
regex_engine: Literal['rust-regex', 'python-re']
|
||||
"""
|
||||
The regex engine to be used for pattern validation.
|
||||
Defaults to `'rust-regex'`.
|
||||
|
||||
- `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust crate,
|
||||
which is non-backtracking and therefore more DDoS resistant, but does not support all regex features.
|
||||
- `python-re` use the [`re`](https://docs.python.org/3/library/re.html) module,
|
||||
which supports all regex features, but may be slower.
|
||||
|
||||
!!! note
|
||||
If you use a compiled regex pattern, the python-re engine will be used regardless of this setting.
|
||||
This is so that flags such as `re.IGNORECASE` are respected.
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
|
||||
class Model(BaseModel):
|
||||
model_config = ConfigDict(regex_engine='python-re')
|
||||
|
||||
value: str = Field(pattern=r'^abc(?=def)')
|
||||
|
||||
print(Model(value='abcdef').value)
|
||||
#> abcdef
|
||||
|
||||
try:
|
||||
print(Model(value='abxyzcdef'))
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for Model
|
||||
value
|
||||
String should match pattern '^abc(?=def)' [type=string_pattern_mismatch, input_value='abxyzcdef', input_type=str]
|
||||
'''
|
||||
```
|
||||
"""
|
||||
|
||||
validation_error_cause: bool
|
||||
"""
|
||||
If `True`, Python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`.
|
||||
|
||||
Note:
|
||||
Python 3.10 and older don't support exception groups natively. <=3.10, backport must be installed: `pip install exceptiongroup`.
|
||||
|
||||
Note:
|
||||
The structure of validation errors are likely to change in future Pydantic versions. Pydantic offers no guarantees about their structure. Should be used for visual traceback debugging only.
|
||||
"""
|
||||
|
||||
use_attribute_docstrings: bool
|
||||
'''
|
||||
Whether docstrings of attributes (bare string literals immediately following the attribute declaration)
|
||||
should be used for field descriptions. Defaults to `False`.
|
||||
|
||||
Available in Pydantic v2.7+.
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
model_config = ConfigDict(use_attribute_docstrings=True)
|
||||
|
||||
x: str
|
||||
"""
|
||||
Example of an attribute docstring
|
||||
"""
|
||||
|
||||
y: int = Field(description="Description in Field")
|
||||
"""
|
||||
Description in Field overrides attribute docstring
|
||||
"""
|
||||
|
||||
|
||||
print(Model.model_fields["x"].description)
|
||||
# > Example of an attribute docstring
|
||||
print(Model.model_fields["y"].description)
|
||||
# > Description in Field
|
||||
```
|
||||
This requires the source code of the class to be available at runtime.
|
||||
|
||||
!!! warning "Usage with `TypedDict` and stdlib dataclasses"
|
||||
Due to current limitations, attribute docstrings detection may not work as expected when using
|
||||
[`TypedDict`][typing.TypedDict] and stdlib dataclasses, in particular when:
|
||||
|
||||
- inheritance is being used.
|
||||
- multiple classes have the same name in the same source file.
|
||||
'''
|
||||
|
||||
cache_strings: bool | Literal['all', 'keys', 'none']
|
||||
"""
|
||||
Whether to cache strings to avoid constructing new Python objects. Defaults to True.
|
||||
|
||||
Enabling this setting should significantly improve validation performance while increasing memory usage slightly.
|
||||
|
||||
- `True` or `'all'` (the default): cache all strings
|
||||
- `'keys'`: cache only dictionary keys
|
||||
- `False` or `'none'`: no caching
|
||||
|
||||
!!! note
|
||||
`True` or `'all'` is required to cache strings during general validation because
|
||||
validators don't know if they're in a key or a value.
|
||||
|
||||
!!! tip
|
||||
If repeated strings are rare, it's recommended to use `'keys'` or `'none'` to reduce memory usage,
|
||||
as the performance difference is minimal if repeated strings are rare.
|
||||
"""
|
||||
|
||||
validate_by_alias: bool
|
||||
"""
|
||||
Whether an aliased field may be populated by its alias. Defaults to `True`.
|
||||
|
||||
!!! note
|
||||
In v2.11, `validate_by_alias` was introduced in conjunction with [`validate_by_name`][pydantic.ConfigDict.validate_by_name]
|
||||
to empower users with more fine grained validation control. In <v2.11, disabling validation by alias was not possible.
|
||||
|
||||
Here's an example of disabling validation by alias:
|
||||
|
||||
```py
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class Model(BaseModel):
|
||||
model_config = ConfigDict(validate_by_name=True, validate_by_alias=False)
|
||||
|
||||
my_field: str = Field(validation_alias='my_alias') # (1)!
|
||||
|
||||
m = Model(my_field='foo') # (2)!
|
||||
print(m)
|
||||
#> my_field='foo'
|
||||
```
|
||||
|
||||
1. The field `'my_field'` has an alias `'my_alias'`.
|
||||
2. The model can only be populated by the attribute name `'my_field'`.
|
||||
|
||||
!!! warning
|
||||
You cannot set both `validate_by_alias` and `validate_by_name` to `False`.
|
||||
This would make it impossible to populate an attribute.
|
||||
|
||||
See [usage errors](../errors/usage_errors.md#validate-by-alias-and-name-false) for an example.
|
||||
|
||||
If you set `validate_by_alias` to `False`, under the hood, Pydantic dynamically sets
|
||||
`validate_by_name` to `True` to ensure that validation can still occur.
|
||||
"""
|
||||
|
||||
validate_by_name: bool
|
||||
"""
|
||||
Whether an aliased field may be populated by its name as given by the model
|
||||
attribute. Defaults to `False`.
|
||||
|
||||
!!! note
|
||||
In v2.0-v2.10, the `populate_by_name` configuration setting was used to specify
|
||||
whether or not a field could be populated by its name **and** alias.
|
||||
|
||||
In v2.11, `validate_by_name` was introduced in conjunction with [`validate_by_alias`][pydantic.ConfigDict.validate_by_alias]
|
||||
to empower users with more fine grained validation behavior control.
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class Model(BaseModel):
|
||||
model_config = ConfigDict(validate_by_name=True, validate_by_alias=True)
|
||||
|
||||
my_field: str = Field(validation_alias='my_alias') # (1)!
|
||||
|
||||
m = Model(my_alias='foo') # (2)!
|
||||
print(m)
|
||||
#> my_field='foo'
|
||||
|
||||
m = Model(my_field='foo') # (3)!
|
||||
print(m)
|
||||
#> my_field='foo'
|
||||
```
|
||||
|
||||
1. The field `'my_field'` has an alias `'my_alias'`.
|
||||
2. The model is populated by the alias `'my_alias'`.
|
||||
3. The model is populated by the attribute name `'my_field'`.
|
||||
|
||||
!!! warning
|
||||
You cannot set both `validate_by_alias` and `validate_by_name` to `False`.
|
||||
This would make it impossible to populate an attribute.
|
||||
|
||||
See [usage errors](../errors/usage_errors.md#validate-by-alias-and-name-false) for an example.
|
||||
"""
|
||||
|
||||
serialize_by_alias: bool
|
||||
"""
|
||||
Whether an aliased field should be serialized by its alias. Defaults to `False`.
|
||||
|
||||
Note: In v2.11, `serialize_by_alias` was introduced to address the
|
||||
[popular request](https://github.com/pydantic/pydantic/issues/8379)
|
||||
for consistency with alias behavior for validation and serialization settings.
|
||||
In v3, the default value is expected to change to `True` for consistency with the validation default.
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
class Model(BaseModel):
|
||||
model_config = ConfigDict(serialize_by_alias=True)
|
||||
|
||||
my_field: str = Field(serialization_alias='my_alias') # (1)!
|
||||
|
||||
m = Model(my_field='foo')
|
||||
print(m.model_dump()) # (2)!
|
||||
#> {'my_alias': 'foo'}
|
||||
```
|
||||
|
||||
1. The field `'my_field'` has an alias `'my_alias'`.
|
||||
2. The model is serialized using the alias `'my_alias'` for the `'my_field'` attribute.
|
||||
"""
|
||||
|
||||
|
||||
_TypeT = TypeVar('_TypeT', bound=type)
|
||||
|
||||
|
||||
@overload
|
||||
@deprecated('Passing `config` as a keyword argument is deprecated. Pass `config` as a positional argument instead.')
|
||||
def with_config(*, config: ConfigDict) -> Callable[[_TypeT], _TypeT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def with_config(config: ConfigDict, /) -> Callable[[_TypeT], _TypeT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def with_config(**config: Unpack[ConfigDict]) -> Callable[[_TypeT], _TypeT]: ...
|
||||
|
||||
|
||||
def with_config(config: ConfigDict | None = None, /, **kwargs: Any) -> Callable[[_TypeT], _TypeT]:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[Configuration with other types](../concepts/config.md#configuration-on-other-supported-types)
|
||||
|
||||
A convenience decorator to set a [Pydantic configuration](config.md) on a `TypedDict` or a `dataclass` from the standard library.
|
||||
|
||||
Although the configuration can be set using the `__pydantic_config__` attribute, it does not play well with type checkers,
|
||||
especially with `TypedDict`.
|
||||
|
||||
!!! example "Usage"
|
||||
|
||||
```python
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from pydantic import ConfigDict, TypeAdapter, with_config
|
||||
|
||||
@with_config(ConfigDict(str_to_lower=True))
|
||||
class TD(TypedDict):
|
||||
x: str
|
||||
|
||||
ta = TypeAdapter(TD)
|
||||
|
||||
print(ta.validate_python({'x': 'ABC'}))
|
||||
#> {'x': 'abc'}
|
||||
```
|
||||
"""
|
||||
if config is not None and kwargs:
|
||||
raise ValueError('Cannot specify both `config` and keyword arguments')
|
||||
|
||||
if len(kwargs) == 1 and (kwargs_conf := kwargs.get('config')) is not None:
|
||||
warnings.warn(
|
||||
'Passing `config` as a keyword argument is deprecated. Pass `config` as a positional argument instead',
|
||||
category=PydanticDeprecatedSince211,
|
||||
stacklevel=2,
|
||||
)
|
||||
final_config = cast(ConfigDict, kwargs_conf)
|
||||
else:
|
||||
final_config = config if config is not None else cast(ConfigDict, kwargs)
|
||||
|
||||
def inner(class_: _TypeT, /) -> _TypeT:
|
||||
# Ideally, we would check for `class_` to either be a `TypedDict` or a stdlib dataclass.
|
||||
# However, the `@with_config` decorator can be applied *after* `@dataclass`. To avoid
|
||||
# common mistakes, we at least check for `class_` to not be a Pydantic model.
|
||||
from ._internal._utils import is_model_class
|
||||
|
||||
if is_model_class(class_):
|
||||
raise PydanticUserError(
|
||||
f'Cannot use `with_config` on {class_.__name__} as it is a Pydantic model',
|
||||
code='with-config-on-model',
|
||||
)
|
||||
class_.__pydantic_config__ = final_config
|
||||
return class_
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
"""Provide an enhanced dataclass that performs validation."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, TypeVar, overload
|
||||
from warnings import warn
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload
|
||||
|
||||
from typing_extensions import TypeGuard, dataclass_transform
|
||||
from typing_extensions import Literal, TypeGuard, dataclass_transform
|
||||
|
||||
from ._internal import _config, _decorators, _namespace_utils, _typing_extra
|
||||
from ._internal import _config, _decorators, _typing_extra
|
||||
from ._internal import _dataclasses as _pydantic_dataclasses
|
||||
from ._migration import getattr_migration
|
||||
from .config import ConfigDict
|
||||
from .errors import PydanticUserError
|
||||
from .fields import Field, FieldInfo, PrivateAttr
|
||||
from .fields import Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._internal._dataclasses import PydanticDataclass
|
||||
from ._internal._namespace_utils import MappingNamespace
|
||||
|
||||
__all__ = 'dataclass', 'rebuild_dataclass'
|
||||
|
||||
@@ -27,7 +23,7 @@ _T = TypeVar('_T')
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
@@ -44,7 +40,7 @@ if sys.version_info >= (3, 10):
|
||||
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: type[_T], # type: ignore
|
||||
@@ -54,16 +50,17 @@ if sys.version_info >= (3, 10):
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
frozen: bool = False,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
kw_only: bool = ...,
|
||||
slots: bool = ...,
|
||||
) -> type[PydanticDataclass]: ...
|
||||
) -> type[PydanticDataclass]:
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
@@ -72,13 +69,13 @@ else:
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
frozen: bool = False,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
|
||||
...
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: type[_T], # type: ignore
|
||||
@@ -88,13 +85,14 @@ else:
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
frozen: bool = False,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
) -> type[PydanticDataclass]: ...
|
||||
) -> type[PydanticDataclass]:
|
||||
...
|
||||
|
||||
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr))
|
||||
@dataclass_transform(field_specifiers=(dataclasses.field, Field))
|
||||
def dataclass(
|
||||
_cls: type[_T] | None = None,
|
||||
*,
|
||||
@@ -103,14 +101,13 @@ def dataclass(
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool | None = None,
|
||||
frozen: bool = False,
|
||||
config: ConfigDict | type[object] | None = None,
|
||||
validate_on_init: bool | None = None,
|
||||
kw_only: bool = False,
|
||||
slots: bool = False,
|
||||
) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[`dataclasses`](../concepts/dataclasses.md)
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/dataclasses/
|
||||
|
||||
A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`,
|
||||
but with added validation.
|
||||
@@ -122,13 +119,13 @@ def dataclass(
|
||||
init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to
|
||||
`dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its
|
||||
own `__init__` function.
|
||||
repr: A boolean indicating whether to include the field in the `__repr__` output.
|
||||
eq: Determines if a `__eq__` method should be generated for the class.
|
||||
repr: A boolean indicating whether or not to include the field in the `__repr__` output.
|
||||
eq: Determines if a `__eq__` should be generated for the class.
|
||||
order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`.
|
||||
unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`.
|
||||
unsafe_hash: Determines if an unsafe hashing function should be included in the class.
|
||||
frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its
|
||||
attributes to be modified after it has been initialized. If not set, the value from the provided `config` argument will be used (and will default to `False` otherwise).
|
||||
config: The Pydantic config to use for the `dataclass`.
|
||||
attributes to be modified from its constructor.
|
||||
config: A configuration for the `dataclass` generation.
|
||||
validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses
|
||||
are validated on init.
|
||||
kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`.
|
||||
@@ -145,43 +142,10 @@ def dataclass(
|
||||
assert validate_on_init is not False, 'validate_on_init=False is no longer supported'
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
kwargs = {'kw_only': kw_only, 'slots': slots}
|
||||
kwargs = dict(kw_only=kw_only, slots=slots)
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
def make_pydantic_fields_compatible(cls: type[Any]) -> None:
|
||||
"""Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only`
|
||||
To do that, we simply change
|
||||
`x: int = pydantic.Field(..., kw_only=True)`
|
||||
into
|
||||
`x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
|
||||
"""
|
||||
for annotation_cls in cls.__mro__:
|
||||
annotations: dict[str, Any] = getattr(annotation_cls, '__annotations__', {})
|
||||
for field_name in annotations:
|
||||
field_value = getattr(cls, field_name, None)
|
||||
# Process only if this is an instance of `FieldInfo`.
|
||||
if not isinstance(field_value, FieldInfo):
|
||||
continue
|
||||
|
||||
# Initialize arguments for the standard `dataclasses.field`.
|
||||
field_args: dict = {'default': field_value}
|
||||
|
||||
# Handle `kw_only` for Python 3.10+
|
||||
if sys.version_info >= (3, 10) and field_value.kw_only:
|
||||
field_args['kw_only'] = True
|
||||
|
||||
# Set `repr` attribute if it's explicitly specified to be not `True`.
|
||||
if field_value.repr is not True:
|
||||
field_args['repr'] = field_value.repr
|
||||
|
||||
setattr(cls, field_name, dataclasses.field(**field_args))
|
||||
# In Python 3.9, when subclassing, information is pulled from cls.__dict__['__annotations__']
|
||||
# for annotations, so we must make sure it's initialized before we add to it.
|
||||
if cls.__dict__.get('__annotations__') is None:
|
||||
cls.__annotations__ = {}
|
||||
cls.__annotations__[field_name] = annotations[field_name]
|
||||
|
||||
def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
|
||||
"""Create a Pydantic dataclass from a regular dataclass.
|
||||
|
||||
@@ -191,29 +155,14 @@ def dataclass(
|
||||
Returns:
|
||||
A Pydantic dataclass.
|
||||
"""
|
||||
from ._internal._utils import is_model_class
|
||||
|
||||
if is_model_class(cls):
|
||||
raise PydanticUserError(
|
||||
f'Cannot create a Pydantic dataclass from {cls.__name__} as it is already a Pydantic model',
|
||||
code='dataclass-on-model',
|
||||
)
|
||||
|
||||
original_cls = cls
|
||||
|
||||
# we warn on conflicting config specifications, but only if the class doesn't have a dataclass base
|
||||
# because a dataclass base might provide a __pydantic_config__ attribute that we don't want to warn about
|
||||
has_dataclass_base = any(dataclasses.is_dataclass(base) for base in cls.__bases__)
|
||||
if not has_dataclass_base and config is not None and hasattr(cls, '__pydantic_config__'):
|
||||
warn(
|
||||
f'`config` is set via both the `dataclass` decorator and `__pydantic_config__` for dataclass {cls.__name__}. '
|
||||
f'The `config` specification from `dataclass` decorator will take priority.',
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# if config is not explicitly provided, try to read it from the type
|
||||
config_dict = config if config is not None else getattr(cls, '__pydantic_config__', None)
|
||||
config_dict = config
|
||||
if config_dict is None:
|
||||
# if not explicitly provided, read from the type
|
||||
cls_config = getattr(cls, '__pydantic_config__', None)
|
||||
if cls_config is not None:
|
||||
config_dict = cls_config
|
||||
config_wrapper = _config.ConfigWrapper(config_dict)
|
||||
decorators = _decorators.DecoratorInfos.build(cls)
|
||||
|
||||
@@ -236,22 +185,6 @@ def dataclass(
|
||||
bases = bases + (generic_base,)
|
||||
cls = types.new_class(cls.__name__, bases)
|
||||
|
||||
make_pydantic_fields_compatible(cls)
|
||||
|
||||
# Respect frozen setting from dataclass constructor and fallback to config setting if not provided
|
||||
if frozen is not None:
|
||||
frozen_ = frozen
|
||||
if config_wrapper.frozen:
|
||||
# It's not recommended to define both, as the setting from the dataclass decorator will take priority.
|
||||
warn(
|
||||
f'`frozen` is set via both the `dataclass` decorator and `config` for dataclass {cls.__name__!r}.'
|
||||
'This is not recommended. The `frozen` specification on `dataclass` will take priority.',
|
||||
category=UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
else:
|
||||
frozen_ = config_wrapper.frozen or False
|
||||
|
||||
cls = dataclasses.dataclass( # type: ignore[call-overload]
|
||||
cls,
|
||||
# the value of init here doesn't affect anything except that it makes it easier to generate a signature
|
||||
@@ -260,40 +193,29 @@ def dataclass(
|
||||
eq=eq,
|
||||
order=order,
|
||||
unsafe_hash=unsafe_hash,
|
||||
frozen=frozen_,
|
||||
frozen=frozen,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# This is an undocumented attribute to distinguish stdlib/Pydantic dataclasses.
|
||||
# It should be set as early as possible:
|
||||
cls.__is_pydantic_dataclass__ = True
|
||||
|
||||
cls.__pydantic_decorators__ = decorators # type: ignore
|
||||
cls.__doc__ = original_doc
|
||||
cls.__module__ = original_cls.__module__
|
||||
cls.__qualname__ = original_cls.__qualname__
|
||||
cls.__pydantic_fields_complete__ = classmethod(_pydantic_fields_complete)
|
||||
cls.__pydantic_complete__ = False # `complete_dataclass` will set it to `True` if successful.
|
||||
# TODO `parent_namespace` is currently None, but we could do the same thing as Pydantic models:
|
||||
# fetch the parent ns using `parent_frame_namespace` (if the dataclass was defined in a function),
|
||||
# and possibly cache it (see the `__pydantic_parent_namespace__` logic for models).
|
||||
_pydantic_dataclasses.complete_dataclass(cls, config_wrapper, raise_errors=False)
|
||||
pydantic_complete = _pydantic_dataclasses.complete_dataclass(
|
||||
cls, config_wrapper, raise_errors=False, types_namespace=None
|
||||
)
|
||||
cls.__pydantic_complete__ = pydantic_complete # type: ignore
|
||||
return cls
|
||||
|
||||
return create_dataclass if _cls is None else create_dataclass(_cls)
|
||||
if _cls is None:
|
||||
return create_dataclass
|
||||
|
||||
|
||||
def _pydantic_fields_complete(cls: type[PydanticDataclass]) -> bool:
|
||||
"""Return whether the fields where successfully collected (i.e. type hints were successfully resolves).
|
||||
|
||||
This is a private property, not meant to be used outside Pydantic.
|
||||
"""
|
||||
return all(field_info._complete for field_info in cls.__pydantic_fields__.values())
|
||||
return create_dataclass(_cls)
|
||||
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
if (3, 8) <= sys.version_info < (3, 11):
|
||||
# Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints
|
||||
# Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable.
|
||||
|
||||
@@ -313,7 +235,7 @@ def rebuild_dataclass(
|
||||
force: bool = False,
|
||||
raise_errors: bool = True,
|
||||
_parent_namespace_depth: int = 2,
|
||||
_types_namespace: MappingNamespace | None = None,
|
||||
_types_namespace: dict[str, Any] | None = None,
|
||||
) -> bool | None:
|
||||
"""Try to rebuild the pydantic-core schema for the dataclass.
|
||||
|
||||
@@ -323,8 +245,8 @@ def rebuild_dataclass(
|
||||
This is analogous to `BaseModel.model_rebuild`.
|
||||
|
||||
Args:
|
||||
cls: The class to rebuild the pydantic-core schema for.
|
||||
force: Whether to force the rebuilding of the schema, defaults to `False`.
|
||||
cls: The class to build the dataclass core schema for.
|
||||
force: Whether to force the rebuilding of the model schema, defaults to `False`.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
_parent_namespace_depth: The depth level of the parent namespace, defaults to 2.
|
||||
_types_namespace: The types namespace, defaults to `None`.
|
||||
@@ -335,49 +257,34 @@ def rebuild_dataclass(
|
||||
"""
|
||||
if not force and cls.__pydantic_complete__:
|
||||
return None
|
||||
|
||||
for attr in ('__pydantic_core_schema__', '__pydantic_validator__', '__pydantic_serializer__'):
|
||||
if attr in cls.__dict__:
|
||||
# Deleting the validator/serializer is necessary as otherwise they can get reused in
|
||||
# pydantic-core. Same applies for the core schema that can be reused in schema generation.
|
||||
delattr(cls, attr)
|
||||
|
||||
cls.__pydantic_complete__ = False
|
||||
|
||||
if _types_namespace is not None:
|
||||
rebuild_ns = _types_namespace
|
||||
elif _parent_namespace_depth > 0:
|
||||
rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
|
||||
else:
|
||||
rebuild_ns = {}
|
||||
if _types_namespace is not None:
|
||||
types_namespace: dict[str, Any] | None = _types_namespace.copy()
|
||||
else:
|
||||
if _parent_namespace_depth > 0:
|
||||
frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {}
|
||||
# Note: we may need to add something similar to cls.__pydantic_parent_namespace__ from BaseModel
|
||||
# here when implementing handling of recursive generics. See BaseModel.model_rebuild for reference.
|
||||
types_namespace = frame_parent_ns
|
||||
else:
|
||||
types_namespace = {}
|
||||
|
||||
ns_resolver = _namespace_utils.NsResolver(
|
||||
parent_namespace=rebuild_ns,
|
||||
)
|
||||
|
||||
return _pydantic_dataclasses.complete_dataclass(
|
||||
cls,
|
||||
_config.ConfigWrapper(cls.__pydantic_config__, check=False),
|
||||
raise_errors=raise_errors,
|
||||
ns_resolver=ns_resolver,
|
||||
# We could provide a different config instead (with `'defer_build'` set to `True`)
|
||||
# of this explicit `_force_build` argument, but because config can come from the
|
||||
# decorator parameter or the `__pydantic_config__` attribute, `complete_dataclass`
|
||||
# will overwrite `__pydantic_config__` with the provided config above:
|
||||
_force_build=True,
|
||||
)
|
||||
types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace)
|
||||
return _pydantic_dataclasses.complete_dataclass(
|
||||
cls,
|
||||
_config.ConfigWrapper(cls.__pydantic_config__, check=False),
|
||||
raise_errors=raise_errors,
|
||||
types_namespace=types_namespace,
|
||||
)
|
||||
|
||||
|
||||
def is_pydantic_dataclass(class_: type[Any], /) -> TypeGuard[type[PydanticDataclass]]:
|
||||
def is_pydantic_dataclass(__cls: type[Any]) -> TypeGuard[type[PydanticDataclass]]:
|
||||
"""Whether a class is a pydantic dataclass.
|
||||
|
||||
Args:
|
||||
class_: The class.
|
||||
__cls: The class.
|
||||
|
||||
Returns:
|
||||
`True` if the class is a pydantic dataclass, `False` otherwise.
|
||||
"""
|
||||
try:
|
||||
return '__is_pydantic_dataclass__' in class_.__dict__ and dataclasses.is_dataclass(class_)
|
||||
except AttributeError:
|
||||
return False
|
||||
return dataclasses.is_dataclass(__cls) and '__pydantic_validator__' in __cls.__dict__
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `datetime_parse` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `decorator` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -4,10 +4,10 @@ from __future__ import annotations as _annotations
|
||||
|
||||
from functools import partial, partialmethod
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
|
||||
from warnings import warn
|
||||
|
||||
from typing_extensions import Protocol, TypeAlias, deprecated
|
||||
from typing_extensions import Literal, Protocol, TypeAlias
|
||||
|
||||
from .._internal import _decorators, _decorators_v1
|
||||
from ..errors import PydanticUserError
|
||||
@@ -19,24 +19,30 @@ _ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored;
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _OnlyValueValidatorClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, __value: Any) -> Any: ...
|
||||
def __call__(self, __cls: Any, __value: Any) -> Any:
|
||||
...
|
||||
|
||||
class _V1ValidatorWithValuesClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: ...
|
||||
def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any:
|
||||
...
|
||||
|
||||
class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: ...
|
||||
def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any:
|
||||
...
|
||||
|
||||
class _V1ValidatorWithKwargsClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, **kwargs: Any) -> Any: ...
|
||||
def __call__(self, __cls: Any, **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol):
|
||||
def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: ...
|
||||
def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any:
|
||||
...
|
||||
|
||||
class _V1RootValidatorClsMethod(Protocol):
|
||||
def __call__(
|
||||
self, __cls: Any, __values: _decorators_v1.RootValidatorValues
|
||||
) -> _decorators_v1.RootValidatorValues: ...
|
||||
) -> _decorators_v1.RootValidatorValues:
|
||||
...
|
||||
|
||||
V1Validator = Union[
|
||||
_OnlyValueValidatorClsMethod,
|
||||
@@ -73,12 +79,6 @@ else:
|
||||
DeprecationWarning = PydanticDeprecatedSince20
|
||||
|
||||
|
||||
@deprecated(
|
||||
'Pydantic V1 style `@validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@field_validator` validators,'
|
||||
' see the migration guide for more details',
|
||||
category=None,
|
||||
)
|
||||
def validator(
|
||||
__field: str,
|
||||
*fields: str,
|
||||
@@ -94,7 +94,7 @@ def validator(
|
||||
__field (str): The first field the validator should be called on; this is separate
|
||||
from `fields` to ensure an error is raised if you don't pass at least one.
|
||||
*fields (str): Additional field(s) the validator should be called on.
|
||||
pre (bool, optional): Whether this validator should be called before the standard
|
||||
pre (bool, optional): Whether or not this validator should be called before the standard
|
||||
validators (else after). Defaults to False.
|
||||
each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate
|
||||
individual elements rather than the whole object. Defaults to False.
|
||||
@@ -109,6 +109,22 @@ def validator(
|
||||
Callable: A decorator that can be used to decorate a
|
||||
function to be used as a validator.
|
||||
"""
|
||||
if allow_reuse is True: # pragma: no cover
|
||||
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
|
||||
fields = tuple((__field, *fields))
|
||||
if isinstance(fields[0], FunctionType):
|
||||
raise PydanticUserError(
|
||||
"`@validator` should be used with fields and keyword arguments, not bare. "
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`",
|
||||
code='validator-no-fields',
|
||||
)
|
||||
elif not all(isinstance(field, str) for field in fields):
|
||||
raise PydanticUserError(
|
||||
"`@validator` fields should be passed as separate string args. "
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
|
||||
code='validator-invalid-fields',
|
||||
)
|
||||
|
||||
warn(
|
||||
'Pydantic V1 style `@validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@field_validator` validators,'
|
||||
@@ -117,22 +133,6 @@ def validator(
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if allow_reuse is True: # pragma: no cover
|
||||
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
|
||||
fields = __field, *fields
|
||||
if isinstance(fields[0], FunctionType):
|
||||
raise PydanticUserError(
|
||||
'`@validator` should be used with fields and keyword arguments, not bare. '
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`",
|
||||
code='validator-no-fields',
|
||||
)
|
||||
elif not all(isinstance(field, str) for field in fields):
|
||||
raise PydanticUserError(
|
||||
'`@validator` fields should be passed as separate string args. '
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
|
||||
code='validator-invalid-fields',
|
||||
)
|
||||
|
||||
mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
|
||||
|
||||
def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
@@ -162,10 +162,8 @@ def root_validator(
|
||||
# which means you need to specify `skip_on_failure=True`
|
||||
skip_on_failure: Literal[True],
|
||||
allow_reuse: bool = ...,
|
||||
) -> Callable[
|
||||
[_V1RootValidatorFunctionType],
|
||||
_V1RootValidatorFunctionType,
|
||||
]: ...
|
||||
) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -175,10 +173,8 @@ def root_validator(
|
||||
# `skip_on_failure`, in fact it is not allowed as an argument!
|
||||
pre: Literal[True],
|
||||
allow_reuse: bool = ...,
|
||||
) -> Callable[
|
||||
[_V1RootValidatorFunctionType],
|
||||
_V1RootValidatorFunctionType,
|
||||
]: ...
|
||||
) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -189,18 +185,10 @@ def root_validator(
|
||||
pre: Literal[False],
|
||||
skip_on_failure: Literal[True],
|
||||
allow_reuse: bool = ...,
|
||||
) -> Callable[
|
||||
[_V1RootValidatorFunctionType],
|
||||
_V1RootValidatorFunctionType,
|
||||
]: ...
|
||||
) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]:
|
||||
...
|
||||
|
||||
|
||||
@deprecated(
|
||||
'Pydantic V1 style `@root_validator` validators are deprecated.'
|
||||
' You should migrate to Pydantic V2 style `@model_validator` validators,'
|
||||
' see the migration guide for more details',
|
||||
category=None,
|
||||
)
|
||||
def root_validator(
|
||||
*__args,
|
||||
pre: bool = False,
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from typing_extensions import deprecated
|
||||
from typing_extensions import Literal, deprecated
|
||||
|
||||
from .._internal import _config
|
||||
from ..warnings import PydanticDeprecatedSince20
|
||||
@@ -18,10 +18,10 @@ __all__ = 'BaseConfig', 'Extra'
|
||||
|
||||
class _ConfigMetaclass(type):
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
|
||||
try:
|
||||
obj = _config.config_defaults[item]
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
return obj
|
||||
return _config.config_defaults[item]
|
||||
except KeyError as exc:
|
||||
raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc
|
||||
|
||||
@@ -35,10 +35,9 @@ class BaseConfig(metaclass=_ConfigMetaclass):
|
||||
"""
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
try:
|
||||
obj = super().__getattribute__(item)
|
||||
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
|
||||
return obj
|
||||
return super().__getattribute__(item)
|
||||
except AttributeError as exc:
|
||||
try:
|
||||
return getattr(type(self), item)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations as _annotations
|
||||
import typing
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, Tuple
|
||||
|
||||
import typing_extensions
|
||||
|
||||
@@ -18,7 +18,7 @@ if typing.TYPE_CHECKING:
|
||||
from .._internal._utils import AbstractSetIntStr, MappingIntStrAny
|
||||
|
||||
AnyClassMethod = classmethod[Any, Any, Any]
|
||||
TupleGenerator = typing.Generator[tuple[str, Any], None, None]
|
||||
TupleGenerator = typing.Generator[Tuple[str, Any], None, None]
|
||||
Model = typing.TypeVar('Model', bound='BaseModel')
|
||||
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
|
||||
IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
|
||||
@@ -40,11 +40,11 @@ def _iter(
|
||||
# The extra "is not None" guards are not logically necessary but optimizes performance for the simple case.
|
||||
if exclude is not None:
|
||||
exclude = _utils.ValueItems.merge(
|
||||
{k: v.exclude for k, v in self.__pydantic_fields__.items() if v.exclude is not None}, exclude
|
||||
{k: v.exclude for k, v in self.model_fields.items() if v.exclude is not None}, exclude
|
||||
)
|
||||
|
||||
if include is not None:
|
||||
include = _utils.ValueItems.merge({k: True for k in self.__pydantic_fields__}, include, intersect=True)
|
||||
include = _utils.ValueItems.merge({k: True for k in self.model_fields}, include, intersect=True)
|
||||
|
||||
allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset) # type: ignore
|
||||
if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none):
|
||||
@@ -68,15 +68,15 @@ def _iter(
|
||||
|
||||
if exclude_defaults:
|
||||
try:
|
||||
field = self.__pydantic_fields__[field_key]
|
||||
field = self.model_fields[field_key]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
if not field.is_required() and field.default == v:
|
||||
continue
|
||||
|
||||
if by_alias and field_key in self.__pydantic_fields__:
|
||||
dict_key = self.__pydantic_fields__[field_key].alias or field_key
|
||||
if by_alias and field_key in self.model_fields:
|
||||
dict_key = self.model_fields[field_key].alias or field_key
|
||||
else:
|
||||
dict_key = field_key
|
||||
|
||||
@@ -200,7 +200,7 @@ def _calculate_keys(
|
||||
include: MappingIntStrAny | None,
|
||||
exclude: MappingIntStrAny | None,
|
||||
exclude_unset: bool,
|
||||
update: dict[str, Any] | None = None, # noqa UP006
|
||||
update: typing.Dict[str, Any] | None = None, # noqa UP006
|
||||
) -> typing.AbstractSet[str] | None:
|
||||
if include is None and exclude is None and exclude_unset is False:
|
||||
return None
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@@ -23,29 +22,29 @@ if TYPE_CHECKING:
|
||||
AnyCallable = Callable[..., Any]
|
||||
|
||||
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
|
||||
ConfigType = Union[None, type[Any], dict[str, Any]]
|
||||
ConfigType = Union[None, Type[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(
|
||||
func: None = None, *, config: 'ConfigType' = None
|
||||
) -> Callable[['AnyCallableT'], 'AnyCallableT']: ...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': ...
|
||||
|
||||
|
||||
@deprecated(
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
|
||||
category=None,
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@deprecated(
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
|
||||
...
|
||||
|
||||
|
||||
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
|
||||
"""Decorator to validate the arguments passed to a function."""
|
||||
warnings.warn(
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
|
||||
PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
'The `validate_arguments` method is deprecated; use `validate_call` instead.', DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
def validate(_func: 'AnyCallable') -> 'AnyCallable':
|
||||
@@ -87,7 +86,7 @@ class ValidatedFunction:
|
||||
)
|
||||
|
||||
self.raw_function = function
|
||||
self.arg_mapping: dict[int, str] = {}
|
||||
self.arg_mapping: Dict[int, str] = {}
|
||||
self.positional_only_args: set[str] = set()
|
||||
self.v_args_name = 'args'
|
||||
self.v_kwargs_name = 'kwargs'
|
||||
@@ -95,7 +94,7 @@ class ValidatedFunction:
|
||||
type_hints = _typing_extra.get_type_hints(function, include_extras=True)
|
||||
takes_args = False
|
||||
takes_kwargs = False
|
||||
fields: dict[str, tuple[Any, Any]] = {}
|
||||
fields: Dict[str, Tuple[Any, Any]] = {}
|
||||
for i, (name, p) in enumerate(parameters.items()):
|
||||
if p.annotation is p.empty:
|
||||
annotation = Any
|
||||
@@ -106,22 +105,22 @@ class ValidatedFunction:
|
||||
if p.kind == Parameter.POSITIONAL_ONLY:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_POSITIONAL_ONLY_NAME] = list[str], None
|
||||
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
|
||||
self.positional_only_args.add(name)
|
||||
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_DUPLICATE_KWARGS] = list[str], None
|
||||
fields[V_DUPLICATE_KWARGS] = List[str], None
|
||||
elif p.kind == Parameter.KEYWORD_ONLY:
|
||||
fields[name] = annotation, default
|
||||
elif p.kind == Parameter.VAR_POSITIONAL:
|
||||
self.v_args_name = name
|
||||
fields[name] = tuple[annotation, ...], None
|
||||
fields[name] = Tuple[annotation, ...], None
|
||||
takes_args = True
|
||||
else:
|
||||
assert p.kind == Parameter.VAR_KEYWORD, p.kind
|
||||
self.v_kwargs_name = name
|
||||
fields[name] = dict[str, annotation], None
|
||||
fields[name] = Dict[str, annotation], None
|
||||
takes_kwargs = True
|
||||
|
||||
# these checks avoid a clash between "args" and a field with that name
|
||||
@@ -134,11 +133,11 @@ class ValidatedFunction:
|
||||
|
||||
if not takes_args:
|
||||
# we add the field so validation below can raise the correct exception
|
||||
fields[self.v_args_name] = list[Any], None
|
||||
fields[self.v_args_name] = List[Any], None
|
||||
|
||||
if not takes_kwargs:
|
||||
# same with kwargs
|
||||
fields[self.v_kwargs_name] = dict[Any, Any], None
|
||||
fields[self.v_kwargs_name] = Dict[Any, Any], None
|
||||
|
||||
self.create_model(fields, takes_args, takes_kwargs, config)
|
||||
|
||||
@@ -150,8 +149,8 @@ class ValidatedFunction:
|
||||
m = self.init_model_instance(*args, **kwargs)
|
||||
return self.execute(m)
|
||||
|
||||
def build_values(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
values: dict[str, Any] = {}
|
||||
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values: Dict[str, Any] = {}
|
||||
if args:
|
||||
arg_iter = enumerate(args)
|
||||
while True:
|
||||
@@ -166,15 +165,15 @@ class ValidatedFunction:
|
||||
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
|
||||
break
|
||||
|
||||
var_kwargs: dict[str, Any] = {}
|
||||
var_kwargs: Dict[str, Any] = {}
|
||||
wrong_positional_args = []
|
||||
duplicate_kwargs = []
|
||||
fields_alias = [
|
||||
field.alias
|
||||
for name, field in self.model.__pydantic_fields__.items()
|
||||
for name, field in self.model.model_fields.items()
|
||||
if name not in (self.v_args_name, self.v_kwargs_name)
|
||||
]
|
||||
non_var_fields = set(self.model.__pydantic_fields__) - {self.v_args_name, self.v_kwargs_name}
|
||||
non_var_fields = set(self.model.model_fields) - {self.v_args_name, self.v_kwargs_name}
|
||||
for k, v in kwargs.items():
|
||||
if k in non_var_fields or k in fields_alias:
|
||||
if k in self.positional_only_args:
|
||||
@@ -194,15 +193,11 @@ class ValidatedFunction:
|
||||
return values
|
||||
|
||||
def execute(self, m: BaseModel) -> Any:
|
||||
d = {
|
||||
k: v
|
||||
for k, v in m.__dict__.items()
|
||||
if k in m.__pydantic_fields_set__ or m.__pydantic_fields__[k].default_factory
|
||||
}
|
||||
d = {k: v for k, v in m.__dict__.items() if k in m.__pydantic_fields_set__ or m.model_fields[k].default_factory}
|
||||
var_kwargs = d.pop(self.v_kwargs_name, {})
|
||||
|
||||
if self.v_args_name in d:
|
||||
args_: list[Any] = []
|
||||
args_: List[Any] = []
|
||||
in_kwargs = False
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
@@ -226,7 +221,7 @@ class ValidatedFunction:
|
||||
else:
|
||||
return self.raw_function(**d, **var_kwargs)
|
||||
|
||||
def create_model(self, fields: dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
|
||||
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
|
||||
pos_args = len(self.arg_mapping)
|
||||
|
||||
config_wrapper = _config.ConfigWrapper(config)
|
||||
@@ -243,7 +238,7 @@ class ValidatedFunction:
|
||||
class DecoratorBaseModel(BaseModel):
|
||||
@field_validator(self.v_args_name, check_fields=False)
|
||||
@classmethod
|
||||
def check_args(cls, v: Optional[list[Any]]) -> Optional[list[Any]]:
|
||||
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
|
||||
if takes_args or v is None:
|
||||
return v
|
||||
|
||||
@@ -251,7 +246,7 @@ class ValidatedFunction:
|
||||
|
||||
@field_validator(self.v_kwargs_name, check_fields=False)
|
||||
@classmethod
|
||||
def check_kwargs(cls, v: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
|
||||
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if takes_kwargs or v is None:
|
||||
return v
|
||||
|
||||
@@ -261,7 +256,7 @@ class ValidatedFunction:
|
||||
|
||||
@field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
|
||||
@classmethod
|
||||
def check_positional_only(cls, v: Optional[list[str]]) -> None:
|
||||
def check_positional_only(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
@@ -271,7 +266,7 @@ class ValidatedFunction:
|
||||
|
||||
@field_validator(V_DUPLICATE_KWARGS, check_fields=False)
|
||||
@classmethod
|
||||
def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None:
|
||||
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -7,12 +7,11 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from types import GeneratorType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from .._internal._import_utils import import_cached_base_model
|
||||
from ..color import Color
|
||||
from ..networks import NameEmail
|
||||
from ..types import SecretBytes, SecretStr
|
||||
@@ -51,7 +50,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
||||
return float(dec_value)
|
||||
|
||||
|
||||
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
|
||||
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
bytes: lambda o: o.decode(),
|
||||
Color: str,
|
||||
datetime.date: isoformat,
|
||||
@@ -80,23 +79,18 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
|
||||
category=None,
|
||||
'pydantic_encoder is deprecated, use pydantic_core.to_jsonable_python instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def pydantic_encoder(obj: Any) -> Any:
|
||||
warnings.warn(
|
||||
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
from dataclasses import asdict, is_dataclass
|
||||
|
||||
BaseModel = import_cached_base_model()
|
||||
from ..main import BaseModel
|
||||
|
||||
warnings.warn('pydantic_encoder is deprecated, use BaseModel.model_dump instead.', DeprecationWarning, stacklevel=2)
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.model_dump()
|
||||
elif is_dataclass(obj):
|
||||
return asdict(obj) # type: ignore
|
||||
return asdict(obj)
|
||||
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
@@ -110,17 +104,12 @@ def pydantic_encoder(obj: Any) -> Any:
|
||||
|
||||
|
||||
# TODO: Add a suggested migration path once there is a way to use custom encoders
|
||||
@deprecated(
|
||||
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
|
||||
category=None,
|
||||
)
|
||||
def custom_pydantic_encoder(type_encoders: dict[Any, Callable[[type[Any]], Any]], obj: Any) -> Any:
|
||||
warnings.warn(
|
||||
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
)
|
||||
@deprecated('custom_pydantic_encoder is deprecated.', category=PydanticDeprecatedSince20)
|
||||
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
warnings.warn(
|
||||
'custom_pydantic_encoder is deprecated, use BaseModel.model_dump instead.', DeprecationWarning, stacklevel=2
|
||||
)
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = type_encoders[base]
|
||||
@@ -132,10 +121,10 @@ def custom_pydantic_encoder(type_encoders: dict[Any, Callable[[type[Any]], Any]]
|
||||
return pydantic_encoder(obj)
|
||||
|
||||
|
||||
@deprecated('`timedelta_isoformat` is deprecated.', category=None)
|
||||
@deprecated('timedelta_isoformat is deprecated.', category=PydanticDeprecatedSince20)
|
||||
def timedelta_isoformat(td: datetime.timedelta) -> str:
|
||||
"""ISO 8601 encoding for Python timedelta object."""
|
||||
warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
|
||||
warnings.warn('timedelta_isoformat is deprecated.', DeprecationWarning, stacklevel=2)
|
||||
minutes, seconds = divmod(td.seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'
|
||||
|
||||
@@ -22,7 +22,7 @@ class Protocol(str, Enum):
|
||||
pickle = 'pickle'
|
||||
|
||||
|
||||
@deprecated('`load_str_bytes` is deprecated.', category=None)
|
||||
@deprecated('load_str_bytes is deprecated.', category=PydanticDeprecatedSince20)
|
||||
def load_str_bytes(
|
||||
b: str | bytes,
|
||||
*,
|
||||
@@ -32,7 +32,7 @@ def load_str_bytes(
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
|
||||
warnings.warn('load_str_bytes is deprecated.', DeprecationWarning, stacklevel=2)
|
||||
if proto is None and content_type:
|
||||
if content_type.endswith(('json', 'javascript')):
|
||||
pass
|
||||
@@ -46,17 +46,17 @@ def load_str_bytes(
|
||||
if proto == Protocol.json:
|
||||
if isinstance(b, bytes):
|
||||
b = b.decode(encoding)
|
||||
return json_loads(b) # type: ignore
|
||||
return json_loads(b)
|
||||
elif proto == Protocol.pickle:
|
||||
if not allow_pickle:
|
||||
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
|
||||
bb = b if isinstance(b, bytes) else b.encode() # type: ignore
|
||||
bb = b if isinstance(b, bytes) else b.encode()
|
||||
return pickle.loads(bb)
|
||||
else:
|
||||
raise TypeError(f'Unknown protocol: {proto}')
|
||||
|
||||
|
||||
@deprecated('`load_file` is deprecated.', category=None)
|
||||
@deprecated('load_file is deprecated.', category=PydanticDeprecatedSince20)
|
||||
def load_file(
|
||||
path: str | Path,
|
||||
*,
|
||||
@@ -66,7 +66,7 @@ def load_file(
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
|
||||
warnings.warn('load_file is deprecated.', DeprecationWarning, stacklevel=2)
|
||||
path = Path(path)
|
||||
b = path.read_bytes()
|
||||
if content_type is None:
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
@@ -17,20 +17,19 @@ if not TYPE_CHECKING:
|
||||
|
||||
__all__ = 'parse_obj_as', 'schema_of', 'schema_json_of'
|
||||
|
||||
NameFactory = Union[str, Callable[[type[Any]], str]]
|
||||
NameFactory = Union[str, Callable[[Type[Any]], str]]
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
|
||||
category=None,
|
||||
'parse_obj_as is deprecated. Use pydantic.TypeAdapter.validate_python instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T:
|
||||
warnings.warn(
|
||||
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
'parse_obj_as is deprecated. Use pydantic.TypeAdapter.validate_python instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if type_name is not None: # pragma: no cover
|
||||
@@ -43,8 +42,7 @@ def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None)
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=None,
|
||||
'schema_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def schema_of(
|
||||
type_: Any,
|
||||
@@ -56,9 +54,7 @@ def schema_of(
|
||||
) -> dict[str, Any]:
|
||||
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one."""
|
||||
warnings.warn(
|
||||
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
'schema_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', DeprecationWarning, stacklevel=2
|
||||
)
|
||||
res = TypeAdapter(type_).json_schema(
|
||||
by_alias=by_alias,
|
||||
@@ -79,8 +75,7 @@ def schema_of(
|
||||
|
||||
|
||||
@deprecated(
|
||||
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=None,
|
||||
'schema_json_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', category=PydanticDeprecatedSince20
|
||||
)
|
||||
def schema_json_of(
|
||||
type_: Any,
|
||||
@@ -93,9 +88,7 @@ def schema_json_of(
|
||||
) -> str:
|
||||
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one."""
|
||||
warnings.warn(
|
||||
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
|
||||
category=PydanticDeprecatedSince20,
|
||||
stacklevel=2,
|
||||
'schema_json_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', DeprecationWarning, stacklevel=2
|
||||
)
|
||||
return json.dumps(
|
||||
schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator),
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `env_settings` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `error_wrappers` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,14 +1,9 @@
|
||||
"""Pydantic-specific errors."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import re
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
from typing_extensions import Self
|
||||
from typing_inspection.introspection import Qualifier
|
||||
|
||||
from pydantic._internal import _repr
|
||||
from typing_extensions import Literal, Self
|
||||
|
||||
from ._migration import getattr_migration
|
||||
from .version import version_short
|
||||
@@ -19,7 +14,6 @@ __all__ = (
|
||||
'PydanticImportError',
|
||||
'PydanticSchemaGenerationError',
|
||||
'PydanticInvalidForJsonSchema',
|
||||
'PydanticForbiddenQualifier',
|
||||
'PydanticErrorCodes',
|
||||
)
|
||||
|
||||
@@ -36,13 +30,11 @@ PydanticErrorCodes = Literal[
|
||||
'discriminator-needs-literal',
|
||||
'discriminator-alias',
|
||||
'discriminator-validator',
|
||||
'callable-discriminator-no-tag',
|
||||
'typed-dict-version',
|
||||
'model-field-overridden',
|
||||
'model-field-missing-annotation',
|
||||
'config-both',
|
||||
'removed-kwargs',
|
||||
'circular-reference-schema',
|
||||
'invalid-for-json-schema',
|
||||
'json-schema-already-used',
|
||||
'base-model-instantiated',
|
||||
@@ -50,10 +42,10 @@ PydanticErrorCodes = Literal[
|
||||
'schema-for-unknown-type',
|
||||
'import-error',
|
||||
'create-model-field-definitions',
|
||||
'create-model-config-base',
|
||||
'validator-no-fields',
|
||||
'validator-invalid-fields',
|
||||
'validator-instance-method',
|
||||
'validator-input-type',
|
||||
'root-validator-pre-skip',
|
||||
'model-serializer-instance-method',
|
||||
'validator-field-config-info',
|
||||
@@ -62,20 +54,9 @@ PydanticErrorCodes = Literal[
|
||||
'field-serializer-signature',
|
||||
'model-serializer-signature',
|
||||
'multiple-field-serializers',
|
||||
'invalid-annotated-type',
|
||||
'invalid_annotated_type',
|
||||
'type-adapter-config-unused',
|
||||
'root-model-extra',
|
||||
'unevaluable-type-annotation',
|
||||
'dataclass-init-false-extra-allow',
|
||||
'clashing-init-and-init-var',
|
||||
'model-config-invalid-field-name',
|
||||
'with-config-on-model',
|
||||
'dataclass-on-model',
|
||||
'validate-call-type',
|
||||
'unpack-typed-dict',
|
||||
'overlapping-unpack-typed-dict',
|
||||
'invalid-self-type',
|
||||
'validate-by-alias-and-name-false',
|
||||
]
|
||||
|
||||
|
||||
@@ -164,26 +145,4 @@ class PydanticInvalidForJsonSchema(PydanticUserError):
|
||||
super().__init__(message, code='invalid-for-json-schema')
|
||||
|
||||
|
||||
class PydanticForbiddenQualifier(PydanticUserError):
|
||||
"""An error raised if a forbidden type qualifier is found in a type annotation."""
|
||||
|
||||
_qualifier_repr_map: ClassVar[dict[Qualifier, str]] = {
|
||||
'required': 'typing.Required',
|
||||
'not_required': 'typing.NotRequired',
|
||||
'read_only': 'typing.ReadOnly',
|
||||
'class_var': 'typing.ClassVar',
|
||||
'init_var': 'dataclasses.InitVar',
|
||||
'final': 'typing.Final',
|
||||
}
|
||||
|
||||
def __init__(self, qualifier: Qualifier, annotation: Any) -> None:
|
||||
super().__init__(
|
||||
message=(
|
||||
f'The annotation {_repr.display_as_type(annotation)!r} contains the {self._qualifier_repr_map[qualifier]!r} '
|
||||
f'type qualifier, which is invalid in the context it is defined.'
|
||||
),
|
||||
code=None,
|
||||
)
|
||||
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
"""The "experimental" module of pydantic contains potential new features that are subject to change."""
|
||||
|
||||
import warnings
|
||||
|
||||
from pydantic.warnings import PydanticExperimentalWarning
|
||||
|
||||
warnings.warn(
|
||||
'This module is experimental, its contents are subject to change and deprecation.',
|
||||
category=PydanticExperimentalWarning,
|
||||
)
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Experimental module exposing a function to generate a core schema that validates callable arguments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic._internal import _config, _generate_schema, _namespace_utils
|
||||
|
||||
|
||||
def generate_arguments_schema(
|
||||
func: Callable[..., Any],
|
||||
schema_type: Literal['arguments', 'arguments-v3'] = 'arguments-v3',
|
||||
parameters_callback: Callable[[int, str, Any], Literal['skip'] | None] | None = None,
|
||||
config: ConfigDict | None = None,
|
||||
) -> CoreSchema:
|
||||
"""Generate the schema for the arguments of a function.
|
||||
|
||||
Args:
|
||||
func: The function to generate the schema for.
|
||||
schema_type: The type of schema to generate.
|
||||
parameters_callback: A callable that will be invoked for each parameter. The callback
|
||||
should take three required arguments: the index, the name and the type annotation
|
||||
(or [`Parameter.empty`][inspect.Parameter.empty] if not annotated) of the parameter.
|
||||
The callback can optionally return `'skip'`, so that the parameter gets excluded
|
||||
from the resulting schema.
|
||||
config: The configuration to use.
|
||||
|
||||
Returns:
|
||||
The generated schema.
|
||||
"""
|
||||
generate_schema = _generate_schema.GenerateSchema(
|
||||
_config.ConfigWrapper(config),
|
||||
ns_resolver=_namespace_utils.NsResolver(namespaces_tuple=_namespace_utils.ns_for_function(func)),
|
||||
)
|
||||
|
||||
if schema_type == 'arguments':
|
||||
schema = generate_schema._arguments_schema(func, parameters_callback) # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
schema = generate_schema._arguments_v3_schema(func, parameters_callback) # pyright: ignore[reportArgumentType]
|
||||
return generate_schema.clean_schema(schema)
|
||||
@@ -1,667 +0,0 @@
|
||||
"""Experimental pipeline API functionality. Be careful with this API, it's subject to change."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import operator
|
||||
import re
|
||||
import sys
|
||||
from collections import deque
|
||||
from collections.abc import Container
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal
|
||||
from functools import cached_property, partial
|
||||
from re import Pattern
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, Protocol, TypeVar, Union, overload
|
||||
|
||||
import annotated_types
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
|
||||
from pydantic._internal._internal_dataclass import slots_true as _slots_true
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
EllipsisType = type(Ellipsis)
|
||||
else:
|
||||
from types import EllipsisType
|
||||
|
||||
__all__ = ['validate_as', 'validate_as_deferred', 'transform']
|
||||
|
||||
_slots_frozen = {**_slots_true, 'frozen': True}
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _ValidateAs:
|
||||
tp: type[Any]
|
||||
strict: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ValidateAsDefer:
|
||||
func: Callable[[], type[Any]]
|
||||
|
||||
@cached_property
|
||||
def tp(self) -> type[Any]:
|
||||
return self.func()
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _Transform:
|
||||
func: Callable[[Any], Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _PipelineOr:
|
||||
left: _Pipeline[Any, Any]
|
||||
right: _Pipeline[Any, Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _PipelineAnd:
|
||||
left: _Pipeline[Any, Any]
|
||||
right: _Pipeline[Any, Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _Eq:
|
||||
value: Any
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _NotEq:
|
||||
value: Any
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _In:
|
||||
values: Container[Any]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _NotIn:
|
||||
values: Container[Any]
|
||||
|
||||
|
||||
_ConstraintAnnotation = Union[
|
||||
annotated_types.Le,
|
||||
annotated_types.Ge,
|
||||
annotated_types.Lt,
|
||||
annotated_types.Gt,
|
||||
annotated_types.Len,
|
||||
annotated_types.MultipleOf,
|
||||
annotated_types.Timezone,
|
||||
annotated_types.Interval,
|
||||
annotated_types.Predicate,
|
||||
# common predicates not included in annotated_types
|
||||
_Eq,
|
||||
_NotEq,
|
||||
_In,
|
||||
_NotIn,
|
||||
# regular expressions
|
||||
Pattern[str],
|
||||
]
|
||||
|
||||
|
||||
@dataclass(**_slots_frozen)
|
||||
class _Constraint:
|
||||
constraint: _ConstraintAnnotation
|
||||
|
||||
|
||||
_Step = Union[_ValidateAs, _ValidateAsDefer, _Transform, _PipelineOr, _PipelineAnd, _Constraint]
|
||||
|
||||
_InT = TypeVar('_InT')
|
||||
_OutT = TypeVar('_OutT')
|
||||
_NewOutT = TypeVar('_NewOutT')
|
||||
|
||||
|
||||
class _FieldTypeMarker:
|
||||
pass
|
||||
|
||||
|
||||
# TODO: ultimately, make this public, see https://github.com/pydantic/pydantic/pull/9459#discussion_r1628197626
|
||||
# Also, make this frozen eventually, but that doesn't work right now because of the generic base
|
||||
# Which attempts to modify __orig_base__ and such.
|
||||
# We could go with a manual freeze, but that seems overkill for now.
|
||||
@dataclass(**_slots_true)
|
||||
class _Pipeline(Generic[_InT, _OutT]):
|
||||
"""Abstract representation of a chain of validation, transformation, and parsing steps."""
|
||||
|
||||
_steps: tuple[_Step, ...]
|
||||
|
||||
def transform(
|
||||
self,
|
||||
func: Callable[[_OutT], _NewOutT],
|
||||
) -> _Pipeline[_InT, _NewOutT]:
|
||||
"""Transform the output of the previous step.
|
||||
|
||||
If used as the first step in a pipeline, the type of the field is used.
|
||||
That is, the transformation is applied to after the value is parsed to the field's type.
|
||||
"""
|
||||
return _Pipeline[_InT, _NewOutT](self._steps + (_Transform(func),))
|
||||
|
||||
@overload
|
||||
def validate_as(self, tp: type[_NewOutT], *, strict: bool = ...) -> _Pipeline[_InT, _NewOutT]: ...
|
||||
|
||||
@overload
|
||||
def validate_as(self, tp: EllipsisType, *, strict: bool = ...) -> _Pipeline[_InT, Any]: # type: ignore
|
||||
...
|
||||
|
||||
def validate_as(self, tp: type[_NewOutT] | EllipsisType, *, strict: bool = False) -> _Pipeline[_InT, Any]: # type: ignore
|
||||
"""Validate / parse the input into a new type.
|
||||
|
||||
If no type is provided, the type of the field is used.
|
||||
|
||||
Types are parsed in Pydantic's `lax` mode by default,
|
||||
but you can enable `strict` mode by passing `strict=True`.
|
||||
"""
|
||||
if isinstance(tp, EllipsisType):
|
||||
return _Pipeline[_InT, Any](self._steps + (_ValidateAs(_FieldTypeMarker, strict=strict),))
|
||||
return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAs(tp, strict=strict),))
|
||||
|
||||
def validate_as_deferred(self, func: Callable[[], type[_NewOutT]]) -> _Pipeline[_InT, _NewOutT]:
|
||||
"""Parse the input into a new type, deferring resolution of the type until the current class
|
||||
is fully defined.
|
||||
|
||||
This is useful when you need to reference the class in it's own type annotations.
|
||||
"""
|
||||
return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAsDefer(func),))
|
||||
|
||||
# constraints
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutGe], constraint: annotated_types.Ge) -> _Pipeline[_InT, _NewOutGe]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutGt], constraint: annotated_types.Gt) -> _Pipeline[_InT, _NewOutGt]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutLe], constraint: annotated_types.Le) -> _Pipeline[_InT, _NewOutLe]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutLt], constraint: annotated_types.Lt) -> _Pipeline[_InT, _NewOutLt]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutLen], constraint: annotated_types.Len
|
||||
) -> _Pipeline[_InT, _NewOutLen]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutT], constraint: annotated_types.MultipleOf
|
||||
) -> _Pipeline[_InT, _NewOutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutDatetime], constraint: annotated_types.Timezone
|
||||
) -> _Pipeline[_InT, _NewOutDatetime]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: annotated_types.Predicate) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(
|
||||
self: _Pipeline[_InT, _NewOutInterval], constraint: annotated_types.Interval
|
||||
) -> _Pipeline[_InT, _NewOutInterval]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _Eq) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotEq) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _In) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotIn) -> _Pipeline[_InT, _OutT]: ...
|
||||
|
||||
@overload
|
||||
def constrain(self: _Pipeline[_InT, _NewOutT], constraint: Pattern[str]) -> _Pipeline[_InT, _NewOutT]: ...
|
||||
|
||||
def constrain(self, constraint: _ConstraintAnnotation) -> Any:
|
||||
"""Constrain a value to meet a certain condition.
|
||||
|
||||
We support most conditions from `annotated_types`, as well as regular expressions.
|
||||
|
||||
Most of the time you'll be calling a shortcut method like `gt`, `lt`, `len`, etc
|
||||
so you don't need to call this directly.
|
||||
"""
|
||||
return _Pipeline[_InT, _OutT](self._steps + (_Constraint(constraint),))
|
||||
|
||||
def predicate(self: _Pipeline[_InT, _NewOutT], func: Callable[[_NewOutT], bool]) -> _Pipeline[_InT, _NewOutT]:
|
||||
"""Constrain a value to meet a certain predicate."""
|
||||
return self.constrain(annotated_types.Predicate(func))
|
||||
|
||||
def gt(self: _Pipeline[_InT, _NewOutGt], gt: _NewOutGt) -> _Pipeline[_InT, _NewOutGt]:
|
||||
"""Constrain a value to be greater than a certain value."""
|
||||
return self.constrain(annotated_types.Gt(gt))
|
||||
|
||||
def lt(self: _Pipeline[_InT, _NewOutLt], lt: _NewOutLt) -> _Pipeline[_InT, _NewOutLt]:
|
||||
"""Constrain a value to be less than a certain value."""
|
||||
return self.constrain(annotated_types.Lt(lt))
|
||||
|
||||
def ge(self: _Pipeline[_InT, _NewOutGe], ge: _NewOutGe) -> _Pipeline[_InT, _NewOutGe]:
|
||||
"""Constrain a value to be greater than or equal to a certain value."""
|
||||
return self.constrain(annotated_types.Ge(ge))
|
||||
|
||||
def le(self: _Pipeline[_InT, _NewOutLe], le: _NewOutLe) -> _Pipeline[_InT, _NewOutLe]:
|
||||
"""Constrain a value to be less than or equal to a certain value."""
|
||||
return self.constrain(annotated_types.Le(le))
|
||||
|
||||
def len(self: _Pipeline[_InT, _NewOutLen], min_len: int, max_len: int | None = None) -> _Pipeline[_InT, _NewOutLen]:
|
||||
"""Constrain a value to have a certain length."""
|
||||
return self.constrain(annotated_types.Len(min_len, max_len))
|
||||
|
||||
@overload
|
||||
def multiple_of(self: _Pipeline[_InT, _NewOutDiv], multiple_of: _NewOutDiv) -> _Pipeline[_InT, _NewOutDiv]: ...
|
||||
|
||||
@overload
|
||||
def multiple_of(self: _Pipeline[_InT, _NewOutMod], multiple_of: _NewOutMod) -> _Pipeline[_InT, _NewOutMod]: ...
|
||||
|
||||
def multiple_of(self: _Pipeline[_InT, Any], multiple_of: Any) -> _Pipeline[_InT, Any]:
|
||||
"""Constrain a value to be a multiple of a certain number."""
|
||||
return self.constrain(annotated_types.MultipleOf(multiple_of))
|
||||
|
||||
def eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to be equal to a certain value."""
|
||||
return self.constrain(_Eq(value))
|
||||
|
||||
def not_eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to not be equal to a certain value."""
|
||||
return self.constrain(_NotEq(value))
|
||||
|
||||
def in_(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to be in a certain set."""
|
||||
return self.constrain(_In(values))
|
||||
|
||||
def not_in(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]:
|
||||
"""Constrain a value to not be in a certain set."""
|
||||
return self.constrain(_NotIn(values))
|
||||
|
||||
# timezone methods
|
||||
def datetime_tz_naive(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.constrain(annotated_types.Timezone(None))
|
||||
|
||||
def datetime_tz_aware(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.constrain(annotated_types.Timezone(...))
|
||||
|
||||
def datetime_tz(
|
||||
self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo
|
||||
) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.constrain(annotated_types.Timezone(tz)) # type: ignore
|
||||
|
||||
def datetime_with_tz(
|
||||
self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo | None
|
||||
) -> _Pipeline[_InT, datetime.datetime]:
|
||||
return self.transform(partial(datetime.datetime.replace, tzinfo=tz))
|
||||
|
||||
# string methods
|
||||
def str_lower(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.lower)
|
||||
|
||||
def str_upper(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.upper)
|
||||
|
||||
def str_title(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.title)
|
||||
|
||||
def str_strip(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]:
|
||||
return self.transform(str.strip)
|
||||
|
||||
def str_pattern(self: _Pipeline[_InT, str], pattern: str) -> _Pipeline[_InT, str]:
|
||||
return self.constrain(re.compile(pattern))
|
||||
|
||||
def str_contains(self: _Pipeline[_InT, str], substring: str) -> _Pipeline[_InT, str]:
|
||||
return self.predicate(lambda v: substring in v)
|
||||
|
||||
def str_starts_with(self: _Pipeline[_InT, str], prefix: str) -> _Pipeline[_InT, str]:
|
||||
return self.predicate(lambda v: v.startswith(prefix))
|
||||
|
||||
def str_ends_with(self: _Pipeline[_InT, str], suffix: str) -> _Pipeline[_InT, str]:
|
||||
return self.predicate(lambda v: v.endswith(suffix))
|
||||
|
||||
# operators
|
||||
def otherwise(self, other: _Pipeline[_OtherIn, _OtherOut]) -> _Pipeline[_InT | _OtherIn, _OutT | _OtherOut]:
|
||||
"""Combine two validation chains, returning the result of the first chain if it succeeds, and the second chain if it fails."""
|
||||
return _Pipeline((_PipelineOr(self, other),))
|
||||
|
||||
__or__ = otherwise
|
||||
|
||||
def then(self, other: _Pipeline[_OutT, _OtherOut]) -> _Pipeline[_InT, _OtherOut]:
|
||||
"""Pipe the result of one validation chain into another."""
|
||||
return _Pipeline((_PipelineAnd(self, other),))
|
||||
|
||||
__and__ = then
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
queue = deque(self._steps)
|
||||
|
||||
s = None
|
||||
|
||||
while queue:
|
||||
step = queue.popleft()
|
||||
s = _apply_step(step, s, handler, source_type)
|
||||
|
||||
s = s or cs.any_schema()
|
||||
return s
|
||||
|
||||
def __supports_type__(self, _: _OutT) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
validate_as = _Pipeline[Any, Any](()).validate_as
|
||||
validate_as_deferred = _Pipeline[Any, Any](()).validate_as_deferred
|
||||
transform = _Pipeline[Any, Any]((_ValidateAs(_FieldTypeMarker),)).transform
|
||||
|
||||
|
||||
def _check_func(
|
||||
func: Callable[[Any], bool], predicate_err: str | Callable[[], str], s: cs.CoreSchema | None
|
||||
) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
def handler(v: Any) -> Any:
|
||||
if func(v):
|
||||
return v
|
||||
raise ValueError(f'Expected {predicate_err if isinstance(predicate_err, str) else predicate_err()}')
|
||||
|
||||
if s is None:
|
||||
return cs.no_info_plain_validator_function(handler)
|
||||
else:
|
||||
return cs.no_info_after_validator_function(handler, s)
|
||||
|
||||
|
||||
def _apply_step(step: _Step, s: cs.CoreSchema | None, handler: GetCoreSchemaHandler, source_type: Any) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
if isinstance(step, _ValidateAs):
|
||||
s = _apply_parse(s, step.tp, step.strict, handler, source_type)
|
||||
elif isinstance(step, _ValidateAsDefer):
|
||||
s = _apply_parse(s, step.tp, False, handler, source_type)
|
||||
elif isinstance(step, _Transform):
|
||||
s = _apply_transform(s, step.func, handler)
|
||||
elif isinstance(step, _Constraint):
|
||||
s = _apply_constraint(s, step.constraint)
|
||||
elif isinstance(step, _PipelineOr):
|
||||
s = cs.union_schema([handler(step.left), handler(step.right)])
|
||||
else:
|
||||
assert isinstance(step, _PipelineAnd)
|
||||
s = cs.chain_schema([handler(step.left), handler(step.right)])
|
||||
return s
|
||||
|
||||
|
||||
def _apply_parse(
|
||||
s: cs.CoreSchema | None,
|
||||
tp: type[Any],
|
||||
strict: bool,
|
||||
handler: GetCoreSchemaHandler,
|
||||
source_type: Any,
|
||||
) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
from pydantic import Strict
|
||||
|
||||
if tp is _FieldTypeMarker:
|
||||
return cs.chain_schema([s, handler(source_type)]) if s else handler(source_type)
|
||||
|
||||
if strict:
|
||||
tp = Annotated[tp, Strict()] # type: ignore
|
||||
|
||||
if s and s['type'] == 'any':
|
||||
return handler(tp)
|
||||
else:
|
||||
return cs.chain_schema([s, handler(tp)]) if s else handler(tp)
|
||||
|
||||
|
||||
def _apply_transform(
|
||||
s: cs.CoreSchema | None, func: Callable[[Any], Any], handler: GetCoreSchemaHandler
|
||||
) -> cs.CoreSchema:
|
||||
from pydantic_core import core_schema as cs
|
||||
|
||||
if s is None:
|
||||
return cs.no_info_plain_validator_function(func)
|
||||
|
||||
if s['type'] == 'str':
|
||||
if func is str.strip:
|
||||
s = s.copy()
|
||||
s['strip_whitespace'] = True
|
||||
return s
|
||||
elif func is str.lower:
|
||||
s = s.copy()
|
||||
s['to_lower'] = True
|
||||
return s
|
||||
elif func is str.upper:
|
||||
s = s.copy()
|
||||
s['to_upper'] = True
|
||||
return s
|
||||
|
||||
return cs.no_info_after_validator_function(func, s)
|
||||
|
||||
|
||||
def _apply_constraint( # noqa: C901
|
||||
s: cs.CoreSchema | None, constraint: _ConstraintAnnotation
|
||||
) -> cs.CoreSchema:
|
||||
"""Apply a single constraint to a schema."""
|
||||
if isinstance(constraint, annotated_types.Gt):
|
||||
gt = constraint.gt
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(gt, int):
|
||||
s['gt'] = gt
|
||||
elif s['type'] == 'float' and isinstance(gt, float):
|
||||
s['gt'] = gt
|
||||
elif s['type'] == 'decimal' and isinstance(gt, Decimal):
|
||||
s['gt'] = gt
|
||||
else:
|
||||
|
||||
def check_gt(v: Any) -> bool:
|
||||
return v > gt
|
||||
|
||||
s = _check_func(check_gt, f'> {gt}', s)
|
||||
elif isinstance(constraint, annotated_types.Ge):
|
||||
ge = constraint.ge
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(ge, int):
|
||||
s['ge'] = ge
|
||||
elif s['type'] == 'float' and isinstance(ge, float):
|
||||
s['ge'] = ge
|
||||
elif s['type'] == 'decimal' and isinstance(ge, Decimal):
|
||||
s['ge'] = ge
|
||||
|
||||
def check_ge(v: Any) -> bool:
|
||||
return v >= ge
|
||||
|
||||
s = _check_func(check_ge, f'>= {ge}', s)
|
||||
elif isinstance(constraint, annotated_types.Lt):
|
||||
lt = constraint.lt
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(lt, int):
|
||||
s['lt'] = lt
|
||||
elif s['type'] == 'float' and isinstance(lt, float):
|
||||
s['lt'] = lt
|
||||
elif s['type'] == 'decimal' and isinstance(lt, Decimal):
|
||||
s['lt'] = lt
|
||||
|
||||
def check_lt(v: Any) -> bool:
|
||||
return v < lt
|
||||
|
||||
s = _check_func(check_lt, f'< {lt}', s)
|
||||
elif isinstance(constraint, annotated_types.Le):
|
||||
le = constraint.le
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(le, int):
|
||||
s['le'] = le
|
||||
elif s['type'] == 'float' and isinstance(le, float):
|
||||
s['le'] = le
|
||||
elif s['type'] == 'decimal' and isinstance(le, Decimal):
|
||||
s['le'] = le
|
||||
|
||||
def check_le(v: Any) -> bool:
|
||||
return v <= le
|
||||
|
||||
s = _check_func(check_le, f'<= {le}', s)
|
||||
elif isinstance(constraint, annotated_types.Len):
|
||||
min_len = constraint.min_length
|
||||
max_len = constraint.max_length
|
||||
|
||||
if s and s['type'] in {'str', 'list', 'tuple', 'set', 'frozenset', 'dict'}:
|
||||
assert (
|
||||
s['type'] == 'str'
|
||||
or s['type'] == 'list'
|
||||
or s['type'] == 'tuple'
|
||||
or s['type'] == 'set'
|
||||
or s['type'] == 'dict'
|
||||
or s['type'] == 'frozenset'
|
||||
)
|
||||
s = s.copy()
|
||||
if min_len != 0:
|
||||
s['min_length'] = min_len
|
||||
if max_len is not None:
|
||||
s['max_length'] = max_len
|
||||
|
||||
def check_len(v: Any) -> bool:
|
||||
if max_len is not None:
|
||||
return (min_len <= len(v)) and (len(v) <= max_len)
|
||||
return min_len <= len(v)
|
||||
|
||||
s = _check_func(check_len, f'length >= {min_len} and length <= {max_len}', s)
|
||||
elif isinstance(constraint, annotated_types.MultipleOf):
|
||||
multiple_of = constraint.multiple_of
|
||||
if s and s['type'] in {'int', 'float', 'decimal'}:
|
||||
s = s.copy()
|
||||
if s['type'] == 'int' and isinstance(multiple_of, int):
|
||||
s['multiple_of'] = multiple_of
|
||||
elif s['type'] == 'float' and isinstance(multiple_of, float):
|
||||
s['multiple_of'] = multiple_of
|
||||
elif s['type'] == 'decimal' and isinstance(multiple_of, Decimal):
|
||||
s['multiple_of'] = multiple_of
|
||||
|
||||
def check_multiple_of(v: Any) -> bool:
|
||||
return v % multiple_of == 0
|
||||
|
||||
s = _check_func(check_multiple_of, f'% {multiple_of} == 0', s)
|
||||
elif isinstance(constraint, annotated_types.Timezone):
|
||||
tz = constraint.tz
|
||||
|
||||
if tz is ...:
|
||||
if s and s['type'] == 'datetime':
|
||||
s = s.copy()
|
||||
s['tz_constraint'] = 'aware'
|
||||
else:
|
||||
|
||||
def check_tz_aware(v: object) -> bool:
|
||||
assert isinstance(v, datetime.datetime)
|
||||
return v.tzinfo is not None
|
||||
|
||||
s = _check_func(check_tz_aware, 'timezone aware', s)
|
||||
elif tz is None:
|
||||
if s and s['type'] == 'datetime':
|
||||
s = s.copy()
|
||||
s['tz_constraint'] = 'naive'
|
||||
else:
|
||||
|
||||
def check_tz_naive(v: object) -> bool:
|
||||
assert isinstance(v, datetime.datetime)
|
||||
return v.tzinfo is None
|
||||
|
||||
s = _check_func(check_tz_naive, 'timezone naive', s)
|
||||
else:
|
||||
raise NotImplementedError('Constraining to a specific timezone is not yet supported')
|
||||
elif isinstance(constraint, annotated_types.Interval):
|
||||
if constraint.ge:
|
||||
s = _apply_constraint(s, annotated_types.Ge(constraint.ge))
|
||||
if constraint.gt:
|
||||
s = _apply_constraint(s, annotated_types.Gt(constraint.gt))
|
||||
if constraint.le:
|
||||
s = _apply_constraint(s, annotated_types.Le(constraint.le))
|
||||
if constraint.lt:
|
||||
s = _apply_constraint(s, annotated_types.Lt(constraint.lt))
|
||||
assert s is not None
|
||||
elif isinstance(constraint, annotated_types.Predicate):
|
||||
func = constraint.func
|
||||
|
||||
if func.__name__ == '<lambda>':
|
||||
# attempt to extract the source code for a lambda function
|
||||
# to use as the function name in error messages
|
||||
# TODO: is there a better way? should we just not do this?
|
||||
import inspect
|
||||
|
||||
try:
|
||||
source = inspect.getsource(func).strip()
|
||||
source = source.removesuffix(')')
|
||||
lambda_source_code = '`' + ''.join(''.join(source.split('lambda ')[1:]).split(':')[1:]).strip() + '`'
|
||||
except OSError:
|
||||
# stringified annotations
|
||||
lambda_source_code = 'lambda'
|
||||
|
||||
s = _check_func(func, lambda_source_code, s)
|
||||
else:
|
||||
s = _check_func(func, func.__name__, s)
|
||||
elif isinstance(constraint, _NotEq):
|
||||
value = constraint.value
|
||||
|
||||
def check_not_eq(v: Any) -> bool:
|
||||
return operator.__ne__(v, value)
|
||||
|
||||
s = _check_func(check_not_eq, f'!= {value}', s)
|
||||
elif isinstance(constraint, _Eq):
|
||||
value = constraint.value
|
||||
|
||||
def check_eq(v: Any) -> bool:
|
||||
return operator.__eq__(v, value)
|
||||
|
||||
s = _check_func(check_eq, f'== {value}', s)
|
||||
elif isinstance(constraint, _In):
|
||||
values = constraint.values
|
||||
|
||||
def check_in(v: Any) -> bool:
|
||||
return operator.__contains__(values, v)
|
||||
|
||||
s = _check_func(check_in, f'in {values}', s)
|
||||
elif isinstance(constraint, _NotIn):
|
||||
values = constraint.values
|
||||
|
||||
def check_not_in(v: Any) -> bool:
|
||||
return operator.__not__(operator.__contains__(values, v))
|
||||
|
||||
s = _check_func(check_not_in, f'not in {values}', s)
|
||||
else:
|
||||
assert isinstance(constraint, Pattern)
|
||||
if s and s['type'] == 'str':
|
||||
s = s.copy()
|
||||
s['pattern'] = constraint.pattern
|
||||
else:
|
||||
|
||||
def check_pattern(v: object) -> bool:
|
||||
assert isinstance(v, str)
|
||||
return constraint.match(v) is not None
|
||||
|
||||
s = _check_func(check_pattern, f'~ {constraint.pattern}', s)
|
||||
return s
|
||||
|
||||
|
||||
class _SupportsRange(annotated_types.SupportsLe, annotated_types.SupportsGe, Protocol):
|
||||
pass
|
||||
|
||||
|
||||
class _SupportsLen(Protocol):
|
||||
def __len__(self) -> int: ...
|
||||
|
||||
|
||||
_NewOutGt = TypeVar('_NewOutGt', bound=annotated_types.SupportsGt)
|
||||
_NewOutGe = TypeVar('_NewOutGe', bound=annotated_types.SupportsGe)
|
||||
_NewOutLt = TypeVar('_NewOutLt', bound=annotated_types.SupportsLt)
|
||||
_NewOutLe = TypeVar('_NewOutLe', bound=annotated_types.SupportsLe)
|
||||
_NewOutLen = TypeVar('_NewOutLen', bound=_SupportsLen)
|
||||
_NewOutDiv = TypeVar('_NewOutDiv', bound=annotated_types.SupportsDiv)
|
||||
_NewOutMod = TypeVar('_NewOutMod', bound=annotated_types.SupportsMod)
|
||||
_NewOutDatetime = TypeVar('_NewOutDatetime', bound=datetime.datetime)
|
||||
_NewOutInterval = TypeVar('_NewOutInterval', bound=_SupportsRange)
|
||||
_OtherIn = TypeVar('_OtherIn')
|
||||
_OtherOut = TypeVar('_OtherOut')
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,13 @@
|
||||
"""This module contains related classes and functions for serialization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from functools import partial, partialmethod
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, TypeVar, overload
|
||||
from functools import partialmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler, WhenUsed
|
||||
from typing_extensions import TypeAlias
|
||||
from pydantic_core import core_schema as _core_schema
|
||||
from typing_extensions import Annotated, Literal, TypeAlias
|
||||
|
||||
from . import PydanticUndefinedAnnotation
|
||||
from ._internal import _decorators, _internal_dataclass
|
||||
@@ -19,26 +18,6 @@ from .annotated_handlers import GetCoreSchemaHandler
|
||||
class PlainSerializer:
|
||||
"""Plain serializers use a function to modify the output of serialization.
|
||||
|
||||
This is particularly helpful when you want to customize the serialization for annotated types.
|
||||
Consider an input of `list`, which will be serialized into a space-delimited string.
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, PlainSerializer
|
||||
|
||||
CustomStr = Annotated[
|
||||
list, PlainSerializer(lambda x: ' '.join(x), return_type=str)
|
||||
]
|
||||
|
||||
class StudentModel(BaseModel):
|
||||
courses: CustomStr
|
||||
|
||||
student = StudentModel(courses=['Math', 'Chemistry', 'English'])
|
||||
print(student.model_dump())
|
||||
#> {'courses': 'Math Chemistry English'}
|
||||
```
|
||||
|
||||
Attributes:
|
||||
func: The serializer function.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
@@ -48,7 +27,7 @@ class PlainSerializer:
|
||||
|
||||
func: core_schema.SerializerFunction
|
||||
return_type: Any = PydanticUndefined
|
||||
when_used: WhenUsed = 'always'
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always'
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
"""Gets the Pydantic core schema.
|
||||
@@ -61,20 +40,12 @@ class PlainSerializer:
|
||||
The Pydantic core schema.
|
||||
"""
|
||||
schema = handler(source_type)
|
||||
if self.return_type is not PydanticUndefined:
|
||||
return_type = self.return_type
|
||||
else:
|
||||
try:
|
||||
# Do not pass in globals as the function could be defined in a different module.
|
||||
# Instead, let `get_callable_return_type` infer the globals to use, but still pass
|
||||
# in locals that may contain a parent/rebuild namespace:
|
||||
return_type = _decorators.get_callable_return_type(
|
||||
self.func,
|
||||
localns=handler._get_types_namespace().locals,
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
|
||||
try:
|
||||
return_type = _decorators.get_function_return_type(
|
||||
self.func, self.return_type, handler._get_types_namespace()
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
|
||||
schema['serialization'] = core_schema.plain_serializer_function_ser_schema(
|
||||
function=self.func,
|
||||
@@ -90,58 +61,6 @@ class WrapSerializer:
|
||||
"""Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization
|
||||
logic, and can modify the resulting value before returning it as the final output of serialization.
|
||||
|
||||
For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic.
|
||||
|
||||
```python
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import BaseModel, WrapSerializer
|
||||
|
||||
class EventDatetime(BaseModel):
|
||||
start: datetime
|
||||
end: datetime
|
||||
|
||||
def convert_to_utc(value: Any, handler, info) -> dict[str, datetime]:
|
||||
# Note that `handler` can actually help serialize the `value` for
|
||||
# further custom serialization in case it's a subclass.
|
||||
partial_result = handler(value, info)
|
||||
if info.mode == 'json':
|
||||
return {
|
||||
k: datetime.fromisoformat(v).astimezone(timezone.utc)
|
||||
for k, v in partial_result.items()
|
||||
}
|
||||
return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()}
|
||||
|
||||
UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)]
|
||||
|
||||
class EventModel(BaseModel):
|
||||
event_datetime: UTCEventDatetime
|
||||
|
||||
dt = EventDatetime(
|
||||
start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00'
|
||||
)
|
||||
event = EventModel(event_datetime=dt)
|
||||
print(event.model_dump())
|
||||
'''
|
||||
{
|
||||
'event_datetime': {
|
||||
'start': datetime.datetime(
|
||||
2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
'end': datetime.datetime(
|
||||
2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc
|
||||
),
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
print(event.model_dump_json())
|
||||
'''
|
||||
{"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}}
|
||||
'''
|
||||
```
|
||||
|
||||
Attributes:
|
||||
func: The serializer function to be wrapped.
|
||||
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
|
||||
@@ -151,7 +70,7 @@ class WrapSerializer:
|
||||
|
||||
func: core_schema.WrapSerializerFunction
|
||||
return_type: Any = PydanticUndefined
|
||||
when_used: WhenUsed = 'always'
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always'
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
"""This method is used to get the Pydantic core schema of the class.
|
||||
@@ -164,20 +83,12 @@ class WrapSerializer:
|
||||
The generated core schema of the class.
|
||||
"""
|
||||
schema = handler(source_type)
|
||||
if self.return_type is not PydanticUndefined:
|
||||
return_type = self.return_type
|
||||
else:
|
||||
try:
|
||||
# Do not pass in globals as the function could be defined in a different module.
|
||||
# Instead, let `get_callable_return_type` infer the globals to use, but still pass
|
||||
# in locals that may contain a parent/rebuild namespace:
|
||||
return_type = _decorators.get_callable_return_type(
|
||||
self.func,
|
||||
localns=handler._get_types_namespace().locals,
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
|
||||
try:
|
||||
return_type = _decorators.get_function_return_type(
|
||||
self.func, self.return_type, handler._get_types_namespace()
|
||||
)
|
||||
except NameError as e:
|
||||
raise PydanticUndefinedAnnotation.from_name_error(e) from e
|
||||
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
|
||||
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
|
||||
function=self.func,
|
||||
@@ -189,77 +100,57 @@ class WrapSerializer:
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
_Partial: TypeAlias = 'partial[Any] | partialmethod[Any]'
|
||||
|
||||
FieldPlainSerializer: TypeAlias = 'core_schema.SerializerFunction | _Partial'
|
||||
"""A field serializer method or function in `plain` mode."""
|
||||
|
||||
FieldWrapSerializer: TypeAlias = 'core_schema.WrapSerializerFunction | _Partial'
|
||||
"""A field serializer method or function in `wrap` mode."""
|
||||
|
||||
FieldSerializer: TypeAlias = 'FieldPlainSerializer | FieldWrapSerializer'
|
||||
"""A field serializer method or function."""
|
||||
|
||||
_FieldPlainSerializerT = TypeVar('_FieldPlainSerializerT', bound=FieldPlainSerializer)
|
||||
_FieldWrapSerializerT = TypeVar('_FieldWrapSerializerT', bound=FieldWrapSerializer)
|
||||
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
|
||||
_PlainSerializationFunction = Union[_core_schema.SerializerFunction, _PartialClsOrStaticMethod]
|
||||
_WrapSerializationFunction = Union[_core_schema.WrapSerializerFunction, _PartialClsOrStaticMethod]
|
||||
_PlainSerializeMethodType = TypeVar('_PlainSerializeMethodType', bound=_PlainSerializationFunction)
|
||||
_WrapSerializeMethodType = TypeVar('_WrapSerializeMethodType', bound=_WrapSerializationFunction)
|
||||
|
||||
|
||||
@overload
|
||||
def field_serializer(
|
||||
field: str,
|
||||
/,
|
||||
__field: str,
|
||||
*fields: str,
|
||||
return_type: Any = ...,
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def field_serializer(
|
||||
__field: str,
|
||||
*fields: str,
|
||||
mode: Literal['plain'],
|
||||
return_type: Any = ...,
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def field_serializer(
|
||||
__field: str,
|
||||
*fields: str,
|
||||
mode: Literal['wrap'],
|
||||
return_type: Any = ...,
|
||||
when_used: WhenUsed = ...,
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def field_serializer(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['plain'] = ...,
|
||||
return_type: Any = ...,
|
||||
when_used: WhenUsed = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]: ...
|
||||
) -> Callable[[_WrapSerializeMethodType], _WrapSerializeMethodType]:
|
||||
...
|
||||
|
||||
|
||||
def field_serializer(
|
||||
*fields: str,
|
||||
mode: Literal['plain', 'wrap'] = 'plain',
|
||||
return_type: Any = PydanticUndefined,
|
||||
when_used: WhenUsed = 'always',
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
|
||||
check_fields: bool | None = None,
|
||||
) -> (
|
||||
Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]
|
||||
| Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]
|
||||
):
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Decorator that enables custom field serialization.
|
||||
|
||||
In the below example, a field of type `set` is used to mitigate duplication. A `field_serializer` is used to serialize the data as a sorted list.
|
||||
|
||||
```python
|
||||
from typing import Set
|
||||
|
||||
from pydantic import BaseModel, field_serializer
|
||||
|
||||
class StudentModel(BaseModel):
|
||||
name: str = 'Jane'
|
||||
courses: Set[str]
|
||||
|
||||
@field_serializer('courses', when_used='json')
|
||||
def serialize_courses_in_order(self, courses: Set[str]):
|
||||
return sorted(courses)
|
||||
|
||||
student = StudentModel(courses={'Math', 'Chemistry', 'English'})
|
||||
print(student.model_dump_json())
|
||||
#> {"name":"Jane","courses":["Chemistry","English","Math"]}
|
||||
```
|
||||
|
||||
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
|
||||
|
||||
Four signatures are supported:
|
||||
@@ -284,7 +175,9 @@ def field_serializer(
|
||||
The decorator function.
|
||||
"""
|
||||
|
||||
def dec(f: FieldSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
def dec(
|
||||
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any]
|
||||
) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
dec_info = _decorators.FieldSerializerDecoratorInfo(
|
||||
fields=fields,
|
||||
mode=mode,
|
||||
@@ -292,109 +185,42 @@ def field_serializer(
|
||||
when_used=when_used,
|
||||
check_fields=check_fields,
|
||||
)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info) # pyright: ignore[reportArgumentType]
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info)
|
||||
|
||||
return dec # pyright: ignore[reportReturnType]
|
||||
return dec
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# The first argument in the following callables represent the `self` type:
|
||||
|
||||
ModelPlainSerializerWithInfo: TypeAlias = Callable[[Any, SerializationInfo], Any]
|
||||
"""A model serializer method with the `info` argument, in `plain` mode."""
|
||||
|
||||
ModelPlainSerializerWithoutInfo: TypeAlias = Callable[[Any], Any]
|
||||
"""A model serializer method without the `info` argument, in `plain` mode."""
|
||||
|
||||
ModelPlainSerializer: TypeAlias = 'ModelPlainSerializerWithInfo | ModelPlainSerializerWithoutInfo'
|
||||
"""A model serializer method in `plain` mode."""
|
||||
|
||||
ModelWrapSerializerWithInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any]
|
||||
"""A model serializer method with the `info` argument, in `wrap` mode."""
|
||||
|
||||
ModelWrapSerializerWithoutInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler], Any]
|
||||
"""A model serializer method without the `info` argument, in `wrap` mode."""
|
||||
|
||||
ModelWrapSerializer: TypeAlias = 'ModelWrapSerializerWithInfo | ModelWrapSerializerWithoutInfo'
|
||||
"""A model serializer method in `wrap` mode."""
|
||||
|
||||
ModelSerializer: TypeAlias = 'ModelPlainSerializer | ModelWrapSerializer'
|
||||
|
||||
_ModelPlainSerializerT = TypeVar('_ModelPlainSerializerT', bound=ModelPlainSerializer)
|
||||
_ModelWrapSerializerT = TypeVar('_ModelWrapSerializerT', bound=ModelWrapSerializer)
|
||||
FuncType = TypeVar('FuncType', bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
|
||||
def model_serializer(f: _ModelPlainSerializerT, /) -> _ModelPlainSerializerT: ...
|
||||
|
||||
|
||||
@overload
|
||||
def model_serializer(
|
||||
*, mode: Literal['wrap'], when_used: WhenUsed = 'always', return_type: Any = ...
|
||||
) -> Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]: ...
|
||||
def model_serializer(__f: FuncType) -> FuncType:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def model_serializer(
|
||||
*,
|
||||
mode: Literal['plain'] = ...,
|
||||
when_used: WhenUsed = 'always',
|
||||
mode: Literal['plain', 'wrap'] = ...,
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
|
||||
return_type: Any = ...,
|
||||
) -> Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]: ...
|
||||
) -> Callable[[FuncType], FuncType]:
|
||||
...
|
||||
|
||||
|
||||
def model_serializer(
|
||||
f: _ModelPlainSerializerT | _ModelWrapSerializerT | None = None,
|
||||
/,
|
||||
__f: Callable[..., Any] | None = None,
|
||||
*,
|
||||
mode: Literal['plain', 'wrap'] = 'plain',
|
||||
when_used: WhenUsed = 'always',
|
||||
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
|
||||
return_type: Any = PydanticUndefined,
|
||||
) -> (
|
||||
_ModelPlainSerializerT
|
||||
| Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]
|
||||
| Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]
|
||||
):
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Decorator that enables custom model serialization.
|
||||
|
||||
This is useful when a model need to be serialized in a customized manner, allowing for flexibility beyond just specific fields.
|
||||
|
||||
An example would be to serialize temperature to the same temperature scale, such as degrees Celsius.
|
||||
|
||||
```python
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, model_serializer
|
||||
|
||||
class TemperatureModel(BaseModel):
|
||||
unit: Literal['C', 'F']
|
||||
value: int
|
||||
|
||||
@model_serializer()
|
||||
def serialize_model(self):
|
||||
if self.unit == 'F':
|
||||
return {'unit': 'C', 'value': int((self.value - 32) / 1.8)}
|
||||
return {'unit': self.unit, 'value': self.value}
|
||||
|
||||
temperature = TemperatureModel(unit='F', value=212)
|
||||
print(temperature.model_dump())
|
||||
#> {'unit': 'C', 'value': 100}
|
||||
```
|
||||
|
||||
Two signatures are supported for `mode='plain'`, which is the default:
|
||||
|
||||
- `(self)`
|
||||
- `(self, info: SerializationInfo)`
|
||||
|
||||
And two other signatures for `mode='wrap'`:
|
||||
|
||||
- `(self, nxt: SerializerFunctionWrapHandler)`
|
||||
- `(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`
|
||||
|
||||
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
|
||||
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
|
||||
|
||||
Args:
|
||||
f: The function to be decorated.
|
||||
__f: The function to be decorated.
|
||||
mode: The serialization mode.
|
||||
|
||||
- `'plain'` means the function will be called instead of the default serialization logic
|
||||
@@ -407,14 +233,14 @@ def model_serializer(
|
||||
The decorator function.
|
||||
"""
|
||||
|
||||
def dec(f: ModelSerializer) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
def dec(f: Callable[..., Any]) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info)
|
||||
|
||||
if f is None:
|
||||
return dec # pyright: ignore[reportReturnType]
|
||||
if __f is None:
|
||||
return dec
|
||||
else:
|
||||
return dec(f) # pyright: ignore[reportReturnType]
|
||||
return dec(__f) # type: ignore
|
||||
|
||||
|
||||
AnyType = TypeVar('AnyType')
|
||||
|
||||
@@ -6,13 +6,14 @@ import dataclasses
|
||||
import sys
|
||||
from functools import partialmethod
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, TypeVar, Union, cast, overload
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload
|
||||
|
||||
from pydantic_core import PydanticUndefined, core_schema
|
||||
from pydantic_core import core_schema
|
||||
from pydantic_core import core_schema as _core_schema
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from typing_extensions import Annotated, Literal, TypeAlias
|
||||
|
||||
from ._internal import _decorators, _generics, _internal_dataclass
|
||||
from . import GetCoreSchemaHandler as _GetCoreSchemaHandler
|
||||
from ._internal import _core_metadata, _decorators, _generics, _internal_dataclass
|
||||
from .annotated_handlers import GetCoreSchemaHandler
|
||||
from .errors import PydanticUserError
|
||||
|
||||
@@ -26,8 +27,7 @@ _inspect_validator = _decorators.inspect_validator
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class AfterValidator:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[field *after* validators](../concepts/validators.md#field-after-validator)
|
||||
'''Usage docs: https://docs.pydantic.dev/2.2/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **after** the inner validation logic.
|
||||
|
||||
@@ -35,10 +35,11 @@ class AfterValidator:
|
||||
func: The validator function.
|
||||
|
||||
Example:
|
||||
```python
|
||||
```py
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, ValidationError
|
||||
from pydantic import BaseModel, AfterValidator, ValidationError
|
||||
|
||||
|
||||
MyInt = Annotated[int, AfterValidator(lambda v: v + 1)]
|
||||
|
||||
@@ -46,31 +47,31 @@ class AfterValidator:
|
||||
a: MyInt
|
||||
|
||||
print(Model(a=1).a)
|
||||
#> 2
|
||||
# > 2
|
||||
|
||||
try:
|
||||
Model(a='a')
|
||||
except ValidationError as e:
|
||||
print(e.json(indent=2))
|
||||
'''
|
||||
[
|
||||
{
|
||||
"""
|
||||
[
|
||||
{
|
||||
"type": "int_parsing",
|
||||
"loc": [
|
||||
"a"
|
||||
"a"
|
||||
],
|
||||
"msg": "Input should be a valid integer, unable to parse string as an integer",
|
||||
"input": "a",
|
||||
"url": "https://errors.pydantic.dev/2/v/int_parsing"
|
||||
}
|
||||
]
|
||||
'''
|
||||
"url": "https://errors.pydantic.dev/0.38.0/v/int_parsing"
|
||||
}
|
||||
]
|
||||
"""
|
||||
```
|
||||
"""
|
||||
'''
|
||||
|
||||
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
info_arg = _inspect_validator(self.func, 'after')
|
||||
if info_arg:
|
||||
@@ -80,26 +81,19 @@ class AfterValidator:
|
||||
func = cast(core_schema.NoInfoValidatorFunction, self.func)
|
||||
return core_schema.no_info_after_validator_function(func, schema=schema)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(func=decorator.func)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class BeforeValidator:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[field *before* validators](../concepts/validators.md#field-before-validator)
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **before** the inner validation logic.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
|
||||
JSON Schema (in validation mode).
|
||||
|
||||
Example:
|
||||
```python
|
||||
from typing import Annotated
|
||||
```py
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, BeforeValidator
|
||||
|
||||
@@ -120,151 +114,68 @@ class BeforeValidator:
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
|
||||
json_schema_input_type: Any = PydanticUndefined
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
input_schema = (
|
||||
None
|
||||
if self.json_schema_input_type is PydanticUndefined
|
||||
else handler.generate_schema(self.json_schema_input_type)
|
||||
)
|
||||
|
||||
info_arg = _inspect_validator(self.func, 'before')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoValidatorFunction, self.func)
|
||||
return core_schema.with_info_before_validator_function(
|
||||
func,
|
||||
schema=schema,
|
||||
field_name=handler.field_name,
|
||||
json_schema_input_schema=input_schema,
|
||||
)
|
||||
return core_schema.with_info_before_validator_function(func, schema=schema, field_name=handler.field_name)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoValidatorFunction, self.func)
|
||||
return core_schema.no_info_before_validator_function(
|
||||
func, schema=schema, json_schema_input_schema=input_schema
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(
|
||||
func=decorator.func,
|
||||
json_schema_input_type=decorator.info.json_schema_input_type,
|
||||
)
|
||||
return core_schema.no_info_before_validator_function(func, schema=schema)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class PlainValidator:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[field *plain* validators](../concepts/validators.md#field-plain-validator)
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **instead** of the inner validation logic.
|
||||
|
||||
!!! note
|
||||
Before v2.9, `PlainValidator` wasn't always compatible with JSON Schema generation for `mode='validation'`.
|
||||
You can now use the `json_schema_input_type` argument to specify the input type of the function
|
||||
to be used in the JSON schema when `mode='validation'` (the default). See the example below for more details.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
|
||||
JSON Schema (in validation mode). If not provided, will default to `Any`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from typing import Annotated, Union
|
||||
```py
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, PlainValidator
|
||||
|
||||
MyInt = Annotated[
|
||||
int,
|
||||
PlainValidator(
|
||||
lambda v: int(v) + 1, json_schema_input_type=Union[str, int] # (1)!
|
||||
),
|
||||
]
|
||||
MyInt = Annotated[int, PlainValidator(lambda v: int(v) + 1)]
|
||||
|
||||
class Model(BaseModel):
|
||||
a: MyInt
|
||||
|
||||
print(Model(a='1').a)
|
||||
#> 2
|
||||
|
||||
print(Model(a=1).a)
|
||||
#> 2
|
||||
```
|
||||
|
||||
1. In this example, we've specified the `json_schema_input_type` as `Union[str, int]` which indicates to the JSON schema
|
||||
generator that in validation mode, the input type for the `a` field can be either a `str` or an `int`.
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
|
||||
json_schema_input_type: Any = Any
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
# Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the
|
||||
# source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper
|
||||
# serialization schema. To work around this for use cases that will not involve serialization, we simply
|
||||
# catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema
|
||||
# and abort any attempts to handle special serialization.
|
||||
from pydantic import PydanticSchemaGenerationError
|
||||
|
||||
try:
|
||||
schema = handler(source_type)
|
||||
# TODO if `schema['serialization']` is one of `'include-exclude-dict/sequence',
|
||||
# schema validation will fail. That's why we use 'type ignore' comments below.
|
||||
serialization = schema.get(
|
||||
'serialization',
|
||||
core_schema.wrap_serializer_function_ser_schema(
|
||||
function=lambda v, h: h(v),
|
||||
schema=schema,
|
||||
return_schema=handler.generate_schema(source_type),
|
||||
),
|
||||
)
|
||||
except PydanticSchemaGenerationError:
|
||||
serialization = None
|
||||
|
||||
input_schema = handler.generate_schema(self.json_schema_input_type)
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
info_arg = _inspect_validator(self.func, 'plain')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoValidatorFunction, self.func)
|
||||
return core_schema.with_info_plain_validator_function(
|
||||
func,
|
||||
field_name=handler.field_name,
|
||||
serialization=serialization, # pyright: ignore[reportArgumentType]
|
||||
json_schema_input_schema=input_schema,
|
||||
)
|
||||
return core_schema.with_info_plain_validator_function(func, field_name=handler.field_name)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoValidatorFunction, self.func)
|
||||
return core_schema.no_info_plain_validator_function(
|
||||
func,
|
||||
serialization=serialization, # pyright: ignore[reportArgumentType]
|
||||
json_schema_input_schema=input_schema,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(
|
||||
func=decorator.func,
|
||||
json_schema_input_type=decorator.info.json_schema_input_type,
|
||||
)
|
||||
return core_schema.no_info_plain_validator_function(func)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
|
||||
class WrapValidator:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[field *wrap* validators](../concepts/validators.md#field-wrap-validator)
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators
|
||||
|
||||
A metadata class that indicates that a validation should be applied **around** the inner validation logic.
|
||||
|
||||
Attributes:
|
||||
func: The validator function.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate the appropriate
|
||||
JSON Schema (in validation mode).
|
||||
|
||||
```python
|
||||
```py
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic import BaseModel, ValidationError, WrapValidator
|
||||
|
||||
@@ -291,61 +202,37 @@ class WrapValidator:
|
||||
"""
|
||||
|
||||
func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction
|
||||
json_schema_input_type: Any = PydanticUndefined
|
||||
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
schema = handler(source_type)
|
||||
input_schema = (
|
||||
None
|
||||
if self.json_schema_input_type is PydanticUndefined
|
||||
else handler.generate_schema(self.json_schema_input_type)
|
||||
)
|
||||
|
||||
info_arg = _inspect_validator(self.func, 'wrap')
|
||||
if info_arg:
|
||||
func = cast(core_schema.WithInfoWrapValidatorFunction, self.func)
|
||||
return core_schema.with_info_wrap_validator_function(
|
||||
func,
|
||||
schema=schema,
|
||||
field_name=handler.field_name,
|
||||
json_schema_input_schema=input_schema,
|
||||
)
|
||||
return core_schema.with_info_wrap_validator_function(func, schema=schema, field_name=handler.field_name)
|
||||
else:
|
||||
func = cast(core_schema.NoInfoWrapValidatorFunction, self.func)
|
||||
return core_schema.no_info_wrap_validator_function(
|
||||
func,
|
||||
schema=schema,
|
||||
json_schema_input_schema=input_schema,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self:
|
||||
return cls(
|
||||
func=decorator.func,
|
||||
json_schema_input_type=decorator.info.json_schema_input_type,
|
||||
)
|
||||
return core_schema.no_info_wrap_validator_function(func, schema=schema)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _OnlyValueValidatorClsMethod(Protocol):
|
||||
def __call__(self, cls: Any, value: Any, /) -> Any: ...
|
||||
def __call__(self, __cls: Any, __value: Any) -> Any:
|
||||
...
|
||||
|
||||
class _V2ValidatorClsMethod(Protocol):
|
||||
def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any: ...
|
||||
|
||||
class _OnlyValueWrapValidatorClsMethod(Protocol):
|
||||
def __call__(self, cls: Any, value: Any, handler: _core_schema.ValidatorFunctionWrapHandler, /) -> Any: ...
|
||||
def __call__(self, __cls: Any, __input_value: Any, __info: _core_schema.ValidationInfo) -> Any:
|
||||
...
|
||||
|
||||
class _V2WrapValidatorClsMethod(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
cls: Any,
|
||||
value: Any,
|
||||
handler: _core_schema.ValidatorFunctionWrapHandler,
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> Any: ...
|
||||
__cls: Any,
|
||||
__input_value: Any,
|
||||
__validator: _core_schema.ValidatorFunctionWrapHandler,
|
||||
__info: _core_schema.ValidationInfo,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
_V2Validator = Union[
|
||||
_V2ValidatorClsMethod,
|
||||
@@ -357,111 +244,57 @@ if TYPE_CHECKING:
|
||||
_V2WrapValidator = Union[
|
||||
_V2WrapValidatorClsMethod,
|
||||
_core_schema.WithInfoWrapValidatorFunction,
|
||||
_OnlyValueWrapValidatorClsMethod,
|
||||
_core_schema.NoInfoWrapValidatorFunction,
|
||||
]
|
||||
|
||||
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
|
||||
|
||||
_V2BeforeAfterOrPlainValidatorType = TypeVar(
|
||||
'_V2BeforeAfterOrPlainValidatorType',
|
||||
bound=Union[_V2Validator, _PartialClsOrStaticMethod],
|
||||
_V2Validator,
|
||||
_PartialClsOrStaticMethod,
|
||||
)
|
||||
_V2WrapValidatorType = TypeVar('_V2WrapValidatorType', bound=Union[_V2WrapValidator, _PartialClsOrStaticMethod])
|
||||
_V2WrapValidatorType = TypeVar('_V2WrapValidatorType', _V2WrapValidator, _PartialClsOrStaticMethod)
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
__field: str,
|
||||
*fields: str,
|
||||
mode: Literal['before', 'after', 'plain'] = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
__field: str,
|
||||
*fields: str,
|
||||
mode: Literal['wrap'],
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]:
|
||||
...
|
||||
|
||||
|
||||
FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain']
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['wrap'],
|
||||
check_fields: bool | None = ...,
|
||||
json_schema_input_type: Any = ...,
|
||||
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['before', 'plain'],
|
||||
check_fields: bool | None = ...,
|
||||
json_schema_input_type: Any = ...,
|
||||
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
*fields: str,
|
||||
mode: Literal['after'] = ...,
|
||||
check_fields: bool | None = ...,
|
||||
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ...
|
||||
|
||||
|
||||
def field_validator(
|
||||
field: str,
|
||||
/,
|
||||
__field: str,
|
||||
*fields: str,
|
||||
mode: FieldValidatorModes = 'after',
|
||||
check_fields: bool | None = None,
|
||||
json_schema_input_type: Any = PydanticUndefined,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[field validators](../concepts/validators.md#field-validators)
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#field-validators
|
||||
|
||||
Decorate methods on the class indicating that they should be used to validate fields.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
|
||||
class Model(BaseModel):
|
||||
a: str
|
||||
|
||||
@field_validator('a')
|
||||
@classmethod
|
||||
def ensure_foobar(cls, v: Any):
|
||||
if 'foobar' not in v:
|
||||
raise ValueError('"foobar" not found in a')
|
||||
return v
|
||||
|
||||
print(repr(Model(a='this is foobar good')))
|
||||
#> Model(a='this is foobar good')
|
||||
|
||||
try:
|
||||
Model(a='snap')
|
||||
except ValidationError as exc_info:
|
||||
print(exc_info)
|
||||
'''
|
||||
1 validation error for Model
|
||||
a
|
||||
Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str]
|
||||
'''
|
||||
```
|
||||
|
||||
For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators).
|
||||
|
||||
Args:
|
||||
field: The first field the `field_validator` should be called on; this is separate
|
||||
__field: The first field the `field_validator` should be called on; this is separate
|
||||
from `fields` to ensure an error is raised if you don't pass at least one.
|
||||
*fields: Additional field(s) the `field_validator` should be called on.
|
||||
mode: Specifies whether to validate the fields before or after validation.
|
||||
check_fields: Whether to check that the fields actually exist on the model.
|
||||
json_schema_input_type: The input type of the function. This is only used to generate
|
||||
the appropriate JSON Schema (in validation mode) and can only specified
|
||||
when `mode` is either `'before'`, `'plain'` or `'wrap'`.
|
||||
|
||||
Returns:
|
||||
A decorator that can be used to decorate a function to be used as a field_validator.
|
||||
@@ -472,23 +305,13 @@ def field_validator(
|
||||
- If the args passed to `@field_validator` as fields are not strings.
|
||||
- If `@field_validator` applied to instance methods.
|
||||
"""
|
||||
if isinstance(field, FunctionType):
|
||||
if isinstance(__field, FunctionType):
|
||||
raise PydanticUserError(
|
||||
'`@field_validator` should be used with fields and keyword arguments, not bare. '
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`",
|
||||
code='validator-no-fields',
|
||||
)
|
||||
|
||||
if mode not in ('before', 'plain', 'wrap') and json_schema_input_type is not PydanticUndefined:
|
||||
raise PydanticUserError(
|
||||
f"`json_schema_input_type` can't be used when mode is set to {mode!r}",
|
||||
code='validator-input-type',
|
||||
)
|
||||
|
||||
if json_schema_input_type is PydanticUndefined and mode == 'plain':
|
||||
json_schema_input_type = Any
|
||||
|
||||
fields = field, *fields
|
||||
fields = __field, *fields
|
||||
if not all(isinstance(field, str) for field in fields):
|
||||
raise PydanticUserError(
|
||||
'`@field_validator` fields should be passed as separate string args. '
|
||||
@@ -497,7 +320,7 @@ def field_validator(
|
||||
)
|
||||
|
||||
def dec(
|
||||
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any],
|
||||
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any]
|
||||
) -> _decorators.PydanticDescriptorProxy[Any]:
|
||||
if _decorators.is_instance_method_from_sig(f):
|
||||
raise PydanticUserError(
|
||||
@@ -507,9 +330,7 @@ def field_validator(
|
||||
# auto apply the @classmethod decorator
|
||||
f = _decorators.ensure_classmethod_based_on_signature(f)
|
||||
|
||||
dec_info = _decorators.FieldValidatorDecoratorInfo(
|
||||
fields=fields, mode=mode, check_fields=check_fields, json_schema_input_type=json_schema_input_type
|
||||
)
|
||||
dec_info = _decorators.FieldValidatorDecoratorInfo(fields=fields, mode=mode, check_fields=check_fields)
|
||||
return _decorators.PydanticDescriptorProxy(f, dec_info)
|
||||
|
||||
return dec
|
||||
@@ -520,19 +341,16 @@ _ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True)
|
||||
|
||||
|
||||
class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]):
|
||||
"""`@model_validator` decorated function handler argument type. This is used when `mode='wrap'`."""
|
||||
"""@model_validator decorated function handler argument type. This is used when `mode='wrap'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
value: Any,
|
||||
outer_location: str | int | None = None,
|
||||
/,
|
||||
self, input_value: Any, outer_location: str | int | None = None
|
||||
) -> _ModelTypeCo: # pragma: no cover
|
||||
...
|
||||
|
||||
|
||||
class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
|
||||
"""A `@model_validator` decorated function signature.
|
||||
"""A @model_validator decorated function signature.
|
||||
This is used when `mode='wrap'` and the function does not have info argument.
|
||||
"""
|
||||
|
||||
@@ -542,14 +360,14 @@ class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
handler: ModelWrapValidatorHandler[_ModelType],
|
||||
/,
|
||||
) -> _ModelType: ...
|
||||
__value: Any,
|
||||
__handler: ModelWrapValidatorHandler[_ModelType],
|
||||
) -> _ModelType:
|
||||
...
|
||||
|
||||
|
||||
class ModelWrapValidator(Protocol[_ModelType]):
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='wrap'`."""
|
||||
"""A @model_validator decorated function signature. This is used when `mode='wrap'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
@@ -557,30 +375,15 @@ class ModelWrapValidator(Protocol[_ModelType]):
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
handler: ModelWrapValidatorHandler[_ModelType],
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> _ModelType: ...
|
||||
|
||||
|
||||
class FreeModelBeforeValidatorWithoutInfo(Protocol):
|
||||
"""A `@model_validator` decorated function signature.
|
||||
This is used when `mode='before'` and the function does not have info argument.
|
||||
"""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
/,
|
||||
) -> Any: ...
|
||||
__value: Any,
|
||||
__handler: ModelWrapValidatorHandler[_ModelType],
|
||||
__info: _core_schema.ValidationInfo,
|
||||
) -> _ModelType:
|
||||
...
|
||||
|
||||
|
||||
class ModelBeforeValidatorWithoutInfo(Protocol):
|
||||
"""A `@model_validator` decorated function signature.
|
||||
"""A @model_validator decorated function signature.
|
||||
This is used when `mode='before'` and the function does not have info argument.
|
||||
"""
|
||||
|
||||
@@ -590,23 +393,9 @@ class ModelBeforeValidatorWithoutInfo(Protocol):
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
/,
|
||||
) -> Any: ...
|
||||
|
||||
|
||||
class FreeModelBeforeValidator(Protocol):
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""
|
||||
|
||||
def __call__( # noqa: D102
|
||||
self,
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> Any: ...
|
||||
__value: Any,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
|
||||
class ModelBeforeValidator(Protocol):
|
||||
@@ -618,10 +407,10 @@ class ModelBeforeValidator(Protocol):
|
||||
# this can be a dict, a model instance
|
||||
# or anything else that gets passed to validate_python
|
||||
# thus validators _must_ handle all cases
|
||||
value: Any,
|
||||
info: _core_schema.ValidationInfo,
|
||||
/,
|
||||
) -> Any: ...
|
||||
__value: Any,
|
||||
__info: _core_schema.ValidationInfo,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
|
||||
ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType]
|
||||
@@ -633,9 +422,7 @@ ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _Model
|
||||
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""
|
||||
|
||||
_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]]
|
||||
_AnyModelBeforeValidator = Union[
|
||||
FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo
|
||||
]
|
||||
_AnyModeBeforeValidator = Union[ModelBeforeValidator, ModelBeforeValidatorWithoutInfo]
|
||||
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]
|
||||
|
||||
|
||||
@@ -645,16 +432,16 @@ def model_validator(
|
||||
mode: Literal['wrap'],
|
||||
) -> Callable[
|
||||
[_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
|
||||
]: ...
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def model_validator(
|
||||
*,
|
||||
mode: Literal['before'],
|
||||
) -> Callable[
|
||||
[_AnyModelBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
|
||||
]: ...
|
||||
) -> Callable[[_AnyModeBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -663,49 +450,15 @@ def model_validator(
|
||||
mode: Literal['after'],
|
||||
) -> Callable[
|
||||
[_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
|
||||
]: ...
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
def model_validator(
|
||||
*,
|
||||
mode: Literal['wrap', 'before', 'after'],
|
||||
) -> Any:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[Model Validators](../concepts/validators.md#model-validators)
|
||||
|
||||
Decorate model methods for validation purposes.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import BaseModel, ValidationError, model_validator
|
||||
|
||||
class Square(BaseModel):
|
||||
width: float
|
||||
height: float
|
||||
|
||||
@model_validator(mode='after')
|
||||
def verify_square(self) -> Self:
|
||||
if self.width != self.height:
|
||||
raise ValueError('width and height do not match')
|
||||
return self
|
||||
|
||||
s = Square(width=1, height=1)
|
||||
print(repr(s))
|
||||
#> Square(width=1.0, height=1.0)
|
||||
|
||||
try:
|
||||
Square(width=1, height=2)
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for Square
|
||||
Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict]
|
||||
'''
|
||||
```
|
||||
|
||||
For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators).
|
||||
"""Decorate model methods for validation purposes.
|
||||
|
||||
Args:
|
||||
mode: A required string literal that specifies the validation mode.
|
||||
@@ -738,7 +491,7 @@ else:
|
||||
'''Generic type for annotating a type that is an instance of a given class.
|
||||
|
||||
Example:
|
||||
```python
|
||||
```py
|
||||
from pydantic import BaseModel, InstanceOf
|
||||
|
||||
class Foo:
|
||||
@@ -817,7 +570,7 @@ else:
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
original_schema = handler(source)
|
||||
metadata = {'pydantic_js_annotation_functions': [lambda _c, h: h(original_schema)]}
|
||||
metadata = _core_metadata.build_metadata_dict(js_annotation_functions=[lambda _c, h: h(original_schema)])
|
||||
return core_schema.any_schema(
|
||||
metadata=metadata,
|
||||
serialization=core_schema.wrap_serializer_function_ser_schema(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `generics` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `json` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -3,9 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Iterator
|
||||
from configparser import ConfigParser
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
from mypy.errorcodes import ErrorCode
|
||||
from mypy.expandtype import expand_type, expand_type_by_instance
|
||||
@@ -15,7 +14,6 @@ from mypy.nodes import (
|
||||
ARG_OPT,
|
||||
ARG_POS,
|
||||
ARG_STAR2,
|
||||
INVARIANT,
|
||||
MDEF,
|
||||
Argument,
|
||||
AssignmentStmt,
|
||||
@@ -47,24 +45,26 @@ from mypy.options import Options
|
||||
from mypy.plugin import (
|
||||
CheckerPluginInterface,
|
||||
ClassDefContext,
|
||||
FunctionContext,
|
||||
MethodContext,
|
||||
Plugin,
|
||||
ReportConfigContext,
|
||||
SemanticAnalyzerPluginInterface,
|
||||
)
|
||||
from mypy.plugins import dataclasses
|
||||
from mypy.plugins.common import (
|
||||
deserialize_and_fixup_type,
|
||||
)
|
||||
from mypy.semanal import set_callable_name
|
||||
from mypy.server.trigger import make_wildcard_trigger
|
||||
from mypy.state import state
|
||||
from mypy.type_visitor import TypeTranslator
|
||||
from mypy.typeops import map_type_from_supertype
|
||||
from mypy.types import (
|
||||
AnyType,
|
||||
CallableType,
|
||||
Instance,
|
||||
NoneType,
|
||||
Overloaded,
|
||||
Type,
|
||||
TypeOfAny,
|
||||
TypeType,
|
||||
@@ -79,11 +79,16 @@ from mypy.version import __version__ as mypy_version
|
||||
from pydantic._internal import _fields
|
||||
from pydantic.version import parse_mypy_version
|
||||
|
||||
try:
|
||||
from mypy.types import TypeVarDef # type: ignore[attr-defined]
|
||||
except ImportError: # pragma: no cover
|
||||
# Backward-compatible with TypeVarDef from Mypy 0.930.
|
||||
from mypy.types import TypeVarType as TypeVarDef
|
||||
|
||||
CONFIGFILE_KEY = 'pydantic-mypy'
|
||||
METADATA_KEY = 'pydantic-mypy-metadata'
|
||||
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
|
||||
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings'
|
||||
ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel'
|
||||
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass'
|
||||
FIELD_FULLNAME = 'pydantic.fields.Field'
|
||||
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
|
||||
@@ -96,11 +101,10 @@ DECORATOR_FULLNAMES = {
|
||||
'pydantic.deprecated.class_validators.validator',
|
||||
'pydantic.deprecated.class_validators.root_validator',
|
||||
}
|
||||
IMPLICIT_CLASSMETHOD_DECORATOR_FULLNAMES = DECORATOR_FULLNAMES - {'pydantic.functional_serializers.model_serializer'}
|
||||
|
||||
|
||||
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
|
||||
BUILTINS_NAME = 'builtins'
|
||||
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'
|
||||
|
||||
# Increment version if plugin changes and mypy caches should be invalidated
|
||||
__version__ = 2
|
||||
@@ -129,12 +133,12 @@ class PydanticPlugin(Plugin):
|
||||
self._plugin_data = self.plugin_config.to_data()
|
||||
super().__init__(options)
|
||||
|
||||
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
|
||||
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None:
|
||||
"""Update Pydantic model class."""
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and isinstance(sym.node, TypeInfo): # pragma: no branch
|
||||
# No branching may occur if the mypy cache has not been cleared
|
||||
if sym.node.has_base(BASEMODEL_FULLNAME):
|
||||
if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro):
|
||||
return self._pydantic_model_class_maker_callback
|
||||
return None
|
||||
|
||||
@@ -144,12 +148,28 @@ class PydanticPlugin(Plugin):
|
||||
return self._pydantic_model_metaclass_marker_callback
|
||||
return None
|
||||
|
||||
def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
|
||||
"""Adjust the return type of the `Field` function."""
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and sym.fullname == FIELD_FULLNAME:
|
||||
return self._pydantic_field_callback
|
||||
return None
|
||||
|
||||
def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
|
||||
"""Adjust return type of `from_orm` method call."""
|
||||
if fullname.endswith('.from_orm'):
|
||||
return from_attributes_callback
|
||||
return None
|
||||
|
||||
def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
|
||||
"""Mark pydantic.dataclasses as dataclass.
|
||||
|
||||
Mypy version 1.1.1 added support for `@dataclass_transform` decorator.
|
||||
"""
|
||||
if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1):
|
||||
return dataclasses.dataclass_class_maker_callback # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]:
|
||||
"""Return all plugin config data.
|
||||
|
||||
@@ -157,9 +177,9 @@ class PydanticPlugin(Plugin):
|
||||
"""
|
||||
return self._plugin_data
|
||||
|
||||
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
|
||||
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
|
||||
transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
|
||||
transformer.transform()
|
||||
return transformer.transform()
|
||||
|
||||
def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
|
||||
"""Reset dataclass_transform_spec attribute of ModelMetaclass.
|
||||
@@ -174,6 +194,54 @@ class PydanticPlugin(Plugin):
|
||||
if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
|
||||
info_metaclass.type.dataclass_transform_spec = None
|
||||
|
||||
def _pydantic_field_callback(self, ctx: FunctionContext) -> Type:
|
||||
"""Extract the type of the `default` argument from the Field function, and use it as the return type.
|
||||
|
||||
In particular:
|
||||
* Check whether the default and default_factory argument is specified.
|
||||
* Output an error if both are specified.
|
||||
* Retrieve the type of the argument which is specified, and use it as return type for the function.
|
||||
"""
|
||||
default_any_type = ctx.default_return_type
|
||||
|
||||
assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
|
||||
assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
|
||||
default_args = ctx.args[0]
|
||||
default_factory_args = ctx.args[1]
|
||||
|
||||
if default_args and default_factory_args:
|
||||
error_default_and_default_factory_specified(ctx.api, ctx.context)
|
||||
return default_any_type
|
||||
|
||||
if default_args:
|
||||
default_type = ctx.arg_types[0][0]
|
||||
default_arg = default_args[0]
|
||||
|
||||
# Fallback to default Any type if the field is required
|
||||
if not isinstance(default_arg, EllipsisExpr):
|
||||
return default_type
|
||||
|
||||
elif default_factory_args:
|
||||
default_factory_type = ctx.arg_types[1][0]
|
||||
|
||||
# Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
|
||||
# Pydantic calls the default factory without any argument, so we retrieve the first item
|
||||
if isinstance(default_factory_type, Overloaded):
|
||||
default_factory_type = default_factory_type.items[0]
|
||||
|
||||
if isinstance(default_factory_type, CallableType):
|
||||
ret_type = default_factory_type.ret_type
|
||||
# mypy doesn't think `ret_type` has `args`, you'd think mypy should know,
|
||||
# add this check in case it varies by version
|
||||
args = getattr(ret_type, 'args', None)
|
||||
if args:
|
||||
if all(isinstance(arg, TypeVarType) for arg in args):
|
||||
# Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
|
||||
ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined]
|
||||
return ret_type
|
||||
|
||||
return default_any_type
|
||||
|
||||
|
||||
class PydanticPluginConfig:
|
||||
"""A Pydantic mypy plugin config holder.
|
||||
@@ -238,9 +306,6 @@ def from_attributes_callback(ctx: MethodContext) -> Type:
|
||||
pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
|
||||
if pydantic_metadata is None:
|
||||
return ctx.default_return_type
|
||||
if not model_type.type.has_base(BASEMODEL_FULLNAME):
|
||||
# not a Pydantic v2 model
|
||||
return ctx.default_return_type
|
||||
from_attributes = pydantic_metadata.get('config', {}).get('from_attributes')
|
||||
if from_attributes is not True:
|
||||
error_from_attributes(model_type.type.name, ctx.api, ctx.context)
|
||||
@@ -254,10 +319,8 @@ class PydanticModelField:
|
||||
self,
|
||||
name: str,
|
||||
alias: str | None,
|
||||
is_frozen: bool,
|
||||
has_dynamic_alias: bool,
|
||||
has_default: bool,
|
||||
strict: bool | None,
|
||||
line: int,
|
||||
column: int,
|
||||
type: Type | None,
|
||||
@@ -265,103 +328,40 @@ class PydanticModelField:
|
||||
):
|
||||
self.name = name
|
||||
self.alias = alias
|
||||
self.is_frozen = is_frozen
|
||||
self.has_dynamic_alias = has_dynamic_alias
|
||||
self.has_default = has_default
|
||||
self.strict = strict
|
||||
self.line = line
|
||||
self.column = column
|
||||
self.type = type
|
||||
self.info = info
|
||||
|
||||
def to_argument(
|
||||
self,
|
||||
current_info: TypeInfo,
|
||||
typed: bool,
|
||||
model_strict: bool,
|
||||
force_optional: bool,
|
||||
use_alias: bool,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
force_typevars_invariant: bool,
|
||||
is_root_model_root: bool,
|
||||
) -> Argument:
|
||||
def to_argument(self, current_info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument."""
|
||||
variable = self.to_var(current_info, api, use_alias, force_typevars_invariant)
|
||||
|
||||
strict = model_strict if self.strict is None else self.strict
|
||||
if typed or strict:
|
||||
type_annotation = self.expand_type(current_info, api, include_root_type=True)
|
||||
else:
|
||||
type_annotation = AnyType(TypeOfAny.explicit)
|
||||
|
||||
return Argument(
|
||||
variable=variable,
|
||||
type_annotation=type_annotation,
|
||||
variable=self.to_var(current_info, use_alias),
|
||||
type_annotation=self.expand_type(current_info) if typed else AnyType(TypeOfAny.explicit),
|
||||
initializer=None,
|
||||
kind=ARG_OPT
|
||||
if is_root_model_root
|
||||
else (ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED),
|
||||
kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED,
|
||||
)
|
||||
|
||||
def expand_type(
|
||||
self,
|
||||
current_info: TypeInfo,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
force_typevars_invariant: bool = False,
|
||||
include_root_type: bool = False,
|
||||
) -> Type | None:
|
||||
def expand_type(self, current_info: TypeInfo) -> Type | None:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type."""
|
||||
if force_typevars_invariant:
|
||||
# In some cases, mypy will emit an error "Cannot use a covariant type variable as a parameter"
|
||||
# To prevent that, we add an option to replace typevars with invariant ones while building certain
|
||||
# method signatures (in particular, `__init__`). There may be a better way to do this, if this causes
|
||||
# us problems in the future, we should look into why the dataclasses plugin doesn't have this issue.
|
||||
if isinstance(self.type, TypeVarType):
|
||||
modified_type = self.type.copy_modified()
|
||||
modified_type.variance = INVARIANT
|
||||
self.type = modified_type
|
||||
|
||||
if self.type is not None and self.info.self_type is not None:
|
||||
# In general, it is not safe to call `expand_type()` during semantic analysis,
|
||||
# In general, it is not safe to call `expand_type()` during semantic analyzis,
|
||||
# however this plugin is called very late, so all types should be fully ready.
|
||||
# Also, it is tricky to avoid eager expansion of Self types here (e.g. because
|
||||
# we serialize attributes).
|
||||
with state.strict_optional_set(api.options.strict_optional):
|
||||
filled_with_typevars = fill_typevars(current_info)
|
||||
# Cannot be TupleType as current_info represents a Pydantic model:
|
||||
assert isinstance(filled_with_typevars, Instance)
|
||||
if force_typevars_invariant:
|
||||
for arg in filled_with_typevars.args:
|
||||
if isinstance(arg, TypeVarType):
|
||||
arg.variance = INVARIANT
|
||||
|
||||
expanded_type = expand_type(self.type, {self.info.self_type.id: filled_with_typevars})
|
||||
if include_root_type and isinstance(expanded_type, Instance) and is_root_model(expanded_type.type):
|
||||
# When a root model is used as a field, Pydantic allows both an instance of the root model
|
||||
# as well as instances of the `root` field type:
|
||||
root_type = expanded_type.type['root'].type
|
||||
if root_type is None:
|
||||
# Happens if the hint for 'root' has unsolved forward references
|
||||
return expanded_type
|
||||
expanded_root_type = expand_type_by_instance(root_type, expanded_type)
|
||||
expanded_type = UnionType([expanded_type, expanded_root_type])
|
||||
return expanded_type
|
||||
return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
|
||||
return self.type
|
||||
|
||||
def to_var(
|
||||
self,
|
||||
current_info: TypeInfo,
|
||||
api: SemanticAnalyzerPluginInterface,
|
||||
use_alias: bool,
|
||||
force_typevars_invariant: bool = False,
|
||||
) -> Var:
|
||||
def to_var(self, current_info: TypeInfo, use_alias: bool) -> Var:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.to_var."""
|
||||
if use_alias and self.alias is not None:
|
||||
name = self.alias
|
||||
else:
|
||||
name = self.name
|
||||
|
||||
return Var(name, self.expand_type(current_info, api, force_typevars_invariant))
|
||||
return Var(name, self.expand_type(current_info))
|
||||
|
||||
def serialize(self) -> JsonDict:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
|
||||
@@ -369,10 +369,8 @@ class PydanticModelField:
|
||||
return {
|
||||
'name': self.name,
|
||||
'alias': self.alias,
|
||||
'is_frozen': self.is_frozen,
|
||||
'has_dynamic_alias': self.has_dynamic_alias,
|
||||
'has_default': self.has_default,
|
||||
'strict': self.strict,
|
||||
'line': self.line,
|
||||
'column': self.column,
|
||||
'type': self.type.serialize(),
|
||||
@@ -385,38 +383,12 @@ class PydanticModelField:
|
||||
typ = deserialize_and_fixup_type(data.pop('type'), api)
|
||||
return cls(type=typ, info=info, **data)
|
||||
|
||||
def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None:
|
||||
def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
|
||||
"""Expands type vars in the context of a subtype when an attribute is inherited
|
||||
from a generic super type.
|
||||
"""
|
||||
if self.type is not None:
|
||||
with state.strict_optional_set(api.options.strict_optional):
|
||||
self.type = map_type_from_supertype(self.type, sub_type, self.info)
|
||||
|
||||
|
||||
class PydanticModelClassVar:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.
|
||||
|
||||
ClassVars are ignored by subclasses.
|
||||
|
||||
Attributes:
|
||||
name: the ClassVar name
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, data: JsonDict) -> PydanticModelClassVar:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize."""
|
||||
data = data.copy()
|
||||
return cls(**data)
|
||||
|
||||
def serialize(self) -> JsonDict:
|
||||
"""Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
|
||||
return {
|
||||
'name': self.name,
|
||||
}
|
||||
self.type = map_type_from_supertype(self.type, sub_type, self.info)
|
||||
|
||||
|
||||
class PydanticModelTransformer:
|
||||
@@ -431,10 +403,7 @@ class PydanticModelTransformer:
|
||||
'frozen',
|
||||
'from_attributes',
|
||||
'populate_by_name',
|
||||
'validate_by_alias',
|
||||
'validate_by_name',
|
||||
'alias_generator',
|
||||
'strict',
|
||||
}
|
||||
|
||||
def __init__(
|
||||
@@ -461,26 +430,24 @@ class PydanticModelTransformer:
|
||||
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
|
||||
"""
|
||||
info = self._cls.info
|
||||
is_a_root_model = is_root_model(info)
|
||||
config = self.collect_config()
|
||||
fields, class_vars = self.collect_fields_and_class_vars(config, is_a_root_model)
|
||||
if fields is None or class_vars is None:
|
||||
fields = self.collect_fields(config)
|
||||
if fields is None:
|
||||
# Some definitions are not ready. We need another pass.
|
||||
return False
|
||||
for field in fields:
|
||||
if field.type is None:
|
||||
return False
|
||||
|
||||
is_settings = info.has_base(BASESETTINGS_FULLNAME)
|
||||
self.add_initializer(fields, config, is_settings, is_a_root_model)
|
||||
self.add_model_construct_method(fields, config, is_settings, is_a_root_model)
|
||||
self.set_frozen(fields, self._api, frozen=config.frozen is True)
|
||||
is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
|
||||
self.add_initializer(fields, config, is_settings)
|
||||
self.add_model_construct_method(fields, config, is_settings)
|
||||
self.set_frozen(fields, frozen=config.frozen is True)
|
||||
|
||||
self.adjust_decorator_signatures()
|
||||
|
||||
info.metadata[METADATA_KEY] = {
|
||||
'fields': {field.name: field.serialize() for field in fields},
|
||||
'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars},
|
||||
'config': config.get_values_dict(),
|
||||
}
|
||||
|
||||
@@ -494,13 +461,13 @@ class PydanticModelTransformer:
|
||||
Teach mypy this by marking any function whose outermost decorator is a `validator()`,
|
||||
`field_validator()` or `serializer()` call as a `classmethod`.
|
||||
"""
|
||||
for sym in self._cls.info.names.values():
|
||||
for name, sym in self._cls.info.names.items():
|
||||
if isinstance(sym.node, Decorator):
|
||||
first_dec = sym.node.original_decorators[0]
|
||||
if (
|
||||
isinstance(first_dec, CallExpr)
|
||||
and isinstance(first_dec.callee, NameExpr)
|
||||
and first_dec.callee.fullname in IMPLICIT_CLASSMETHOD_DECORATOR_FULLNAMES
|
||||
and first_dec.callee.fullname in DECORATOR_FULLNAMES
|
||||
# @model_validator(mode="after") is an exception, it expects a regular method
|
||||
and not (
|
||||
first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME
|
||||
@@ -543,7 +510,7 @@ class PydanticModelTransformer:
|
||||
for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args):
|
||||
if arg_name is None:
|
||||
continue
|
||||
config.update(self.get_config_update(arg_name, arg, lax_extra=True))
|
||||
config.update(self.get_config_update(arg_name, arg))
|
||||
elif isinstance(stmt.rvalue, DictExpr): # dict literals
|
||||
for key_expr, value_expr in stmt.rvalue.items:
|
||||
if not isinstance(key_expr, StrExpr):
|
||||
@@ -574,7 +541,7 @@ class PydanticModelTransformer:
|
||||
if (
|
||||
stmt
|
||||
and config.has_alias_generator
|
||||
and not (config.validate_by_name or config.populate_by_name)
|
||||
and not config.populate_by_name
|
||||
and self.plugin_config.warn_required_dynamic_aliases
|
||||
):
|
||||
error_required_dynamic_aliases(self._api, stmt)
|
||||
@@ -589,13 +556,11 @@ class PydanticModelTransformer:
|
||||
config.setdefault(name, value)
|
||||
return config
|
||||
|
||||
def collect_fields_and_class_vars(
|
||||
self, model_config: ModelConfigData, is_root_model: bool
|
||||
) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]:
|
||||
def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None:
|
||||
"""Collects the fields for the model, accounting for parent classes."""
|
||||
cls = self._cls
|
||||
|
||||
# First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates.
|
||||
# First, collect fields belonging to any class in the MRO, ignoring duplicates.
|
||||
#
|
||||
# We iterate through the MRO in reverse because attrs defined in the parent must appear
|
||||
# earlier in the attributes list than attrs defined in the child. See:
|
||||
@@ -605,11 +570,10 @@ class PydanticModelTransformer:
|
||||
# in the parent. We can implement this via a dict without disrupting the attr order
|
||||
# because dicts preserve insertion order in Python 3.7+.
|
||||
found_fields: dict[str, PydanticModelField] = {}
|
||||
found_class_vars: dict[str, PydanticModelClassVar] = {}
|
||||
for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object
|
||||
# if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata:
|
||||
# # We haven't processed the base class yet. Need another pass.
|
||||
# return None, None
|
||||
# return None
|
||||
if METADATA_KEY not in info.metadata:
|
||||
continue
|
||||
|
||||
@@ -622,7 +586,8 @@ class PydanticModelTransformer:
|
||||
# TODO: We shouldn't be performing type operations during the main
|
||||
# semantic analysis pass, since some TypeInfo attributes might
|
||||
# still be in flux. This should be performed in a later phase.
|
||||
field.expand_typevar_from_subtype(cls.info, self._api)
|
||||
with state.strict_optional_set(self._api.options.strict_optional):
|
||||
field.expand_typevar_from_subtype(cls.info)
|
||||
found_fields[name] = field
|
||||
|
||||
sym_node = cls.info.names.get(name)
|
||||
@@ -631,31 +596,17 @@ class PydanticModelTransformer:
|
||||
'BaseModel field may only be overridden by another field',
|
||||
sym_node.node,
|
||||
)
|
||||
# Collect ClassVars
|
||||
for name, data in info.metadata[METADATA_KEY]['class_vars'].items():
|
||||
found_class_vars[name] = PydanticModelClassVar.deserialize(data)
|
||||
|
||||
# Second, collect fields and ClassVars belonging to the current class.
|
||||
# Second, collect fields belonging to the current class.
|
||||
current_field_names: set[str] = set()
|
||||
current_class_vars_names: set[str] = set()
|
||||
for stmt in self._get_assignment_statements_from_block(cls.defs):
|
||||
maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars)
|
||||
if maybe_field is None:
|
||||
continue
|
||||
maybe_field = self.collect_field_from_stmt(stmt, model_config)
|
||||
if maybe_field is not None:
|
||||
lhs = stmt.lvalues[0]
|
||||
current_field_names.add(lhs.name)
|
||||
found_fields[lhs.name] = maybe_field
|
||||
|
||||
lhs = stmt.lvalues[0]
|
||||
assert isinstance(lhs, NameExpr) # collect_field_or_class_var_from_stmt guarantees this
|
||||
if isinstance(maybe_field, PydanticModelField):
|
||||
if is_root_model and lhs.name != 'root':
|
||||
error_extra_fields_on_root_model(self._api, stmt)
|
||||
else:
|
||||
current_field_names.add(lhs.name)
|
||||
found_fields[lhs.name] = maybe_field
|
||||
elif isinstance(maybe_field, PydanticModelClassVar):
|
||||
current_class_vars_names.add(lhs.name)
|
||||
found_class_vars[lhs.name] = maybe_field
|
||||
|
||||
return list(found_fields.values()), list(found_class_vars.values())
|
||||
return list(found_fields.values())
|
||||
|
||||
def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]:
|
||||
for body in stmt.body:
|
||||
@@ -671,15 +622,14 @@ class PydanticModelTransformer:
|
||||
elif isinstance(stmt, IfStmt):
|
||||
yield from self._get_assignment_statements_from_if_statement(stmt)
|
||||
|
||||
def collect_field_or_class_var_from_stmt( # noqa C901
|
||||
self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar]
|
||||
) -> PydanticModelField | PydanticModelClassVar | None:
|
||||
def collect_field_from_stmt( # noqa C901
|
||||
self, stmt: AssignmentStmt, model_config: ModelConfigData
|
||||
) -> PydanticModelField | None:
|
||||
"""Get pydantic model field from statement.
|
||||
|
||||
Args:
|
||||
stmt: The statement.
|
||||
model_config: Configuration settings for the model.
|
||||
class_vars: ClassVars already known to be defined on the model.
|
||||
|
||||
Returns:
|
||||
A pydantic model field if it could find the field in statement. Otherwise, `None`.
|
||||
@@ -702,10 +652,6 @@ class PydanticModelTransformer:
|
||||
# Eventually, we may want to attempt to respect model_config['ignored_types']
|
||||
return None
|
||||
|
||||
if lhs.name in class_vars:
|
||||
# Class vars are not fields and are not required to be annotated
|
||||
return None
|
||||
|
||||
# The assignment does not have an annotation, and it's not anything else we recognize
|
||||
error_untyped_fields(self._api, stmt)
|
||||
return None
|
||||
@@ -750,7 +696,7 @@ class PydanticModelTransformer:
|
||||
|
||||
# x: ClassVar[int] is not a field
|
||||
if node.is_classvar:
|
||||
return PydanticModelClassVar(lhs.name)
|
||||
return None
|
||||
|
||||
# x: InitVar[int] is not supported in BaseModel
|
||||
node_type = get_proper_type(node.type)
|
||||
@@ -761,7 +707,6 @@ class PydanticModelTransformer:
|
||||
)
|
||||
|
||||
has_default = self.get_has_default(stmt)
|
||||
strict = self.get_strict(stmt)
|
||||
|
||||
if sym.type is None and node.is_final and node.is_inferred:
|
||||
# This follows the logic from the dataclasses plugin. The following comment is taken verbatim:
|
||||
@@ -781,27 +726,16 @@ class PydanticModelTransformer:
|
||||
)
|
||||
node.type = AnyType(TypeOfAny.from_error)
|
||||
|
||||
if node.is_final and has_default:
|
||||
# TODO this path should be removed (see https://github.com/pydantic/pydantic/issues/11119)
|
||||
return PydanticModelClassVar(lhs.name)
|
||||
|
||||
alias, has_dynamic_alias = self.get_alias_info(stmt)
|
||||
if (
|
||||
has_dynamic_alias
|
||||
and not (model_config.validate_by_name or model_config.populate_by_name)
|
||||
and self.plugin_config.warn_required_dynamic_aliases
|
||||
):
|
||||
if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases:
|
||||
error_required_dynamic_aliases(self._api, stmt)
|
||||
is_frozen = self.is_field_frozen(stmt)
|
||||
|
||||
init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt)
|
||||
return PydanticModelField(
|
||||
name=lhs.name,
|
||||
has_dynamic_alias=has_dynamic_alias,
|
||||
has_default=has_default,
|
||||
strict=strict,
|
||||
alias=alias,
|
||||
is_frozen=is_frozen,
|
||||
line=stmt.line,
|
||||
column=stmt.column,
|
||||
type=init_type,
|
||||
@@ -846,9 +780,7 @@ class PydanticModelTransformer:
|
||||
|
||||
return default
|
||||
|
||||
def add_initializer(
|
||||
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool
|
||||
) -> None:
|
||||
def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None:
|
||||
"""Adds a fields-aware `__init__` method to the class.
|
||||
|
||||
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
|
||||
@@ -857,42 +789,28 @@ class PydanticModelTransformer:
|
||||
return # Don't generate an __init__ if one already exists
|
||||
|
||||
typed = self.plugin_config.init_typed
|
||||
model_strict = bool(config.strict)
|
||||
use_alias = not (config.validate_by_name or config.populate_by_name) and config.validate_by_alias is not False
|
||||
requires_dynamic_aliases = bool(config.has_alias_generator and not config.validate_by_name)
|
||||
args = self.get_field_arguments(
|
||||
fields,
|
||||
typed=typed,
|
||||
model_strict=model_strict,
|
||||
requires_dynamic_aliases=requires_dynamic_aliases,
|
||||
use_alias=use_alias,
|
||||
is_settings=is_settings,
|
||||
is_root_model=is_root_model,
|
||||
force_typevars_invariant=True,
|
||||
)
|
||||
|
||||
if is_settings:
|
||||
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
|
||||
assert isinstance(base_settings_node, TypeInfo)
|
||||
if '__init__' in base_settings_node.names:
|
||||
base_settings_init_node = base_settings_node.names['__init__'].node
|
||||
assert isinstance(base_settings_init_node, FuncDef)
|
||||
if base_settings_init_node is not None and base_settings_init_node.type is not None:
|
||||
func_type = base_settings_init_node.type
|
||||
assert isinstance(func_type, CallableType)
|
||||
for arg_idx, arg_name in enumerate(func_type.arg_names):
|
||||
if arg_name is None or arg_name.startswith('__') or not arg_name.startswith('_'):
|
||||
continue
|
||||
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
|
||||
if analyzed_variable_type is not None and arg_name == '_cli_settings_source':
|
||||
# _cli_settings_source is defined as CliSettingsSource[Any], and as such
|
||||
# the Any causes issues with --disallow-any-explicit. As a workaround, change
|
||||
# the Any type (as if CliSettingsSource was left unparameterized):
|
||||
analyzed_variable_type = analyzed_variable_type.accept(
|
||||
ChangeExplicitTypeOfAny(TypeOfAny.from_omitted_generics)
|
||||
)
|
||||
variable = Var(arg_name, analyzed_variable_type)
|
||||
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))
|
||||
use_alias = config.populate_by_name is not True
|
||||
requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name)
|
||||
with state.strict_optional_set(self._api.options.strict_optional):
|
||||
args = self.get_field_arguments(
|
||||
fields,
|
||||
typed=typed,
|
||||
requires_dynamic_aliases=requires_dynamic_aliases,
|
||||
use_alias=use_alias,
|
||||
is_settings=is_settings,
|
||||
)
|
||||
if is_settings:
|
||||
base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
|
||||
if '__init__' in base_settings_node.names:
|
||||
base_settings_init_node = base_settings_node.names['__init__'].node
|
||||
if base_settings_init_node is not None and base_settings_init_node.type is not None:
|
||||
func_type = base_settings_init_node.type
|
||||
for arg_idx, arg_name in enumerate(func_type.arg_names):
|
||||
if arg_name.startswith('__') or not arg_name.startswith('_'):
|
||||
continue
|
||||
analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
|
||||
variable = Var(arg_name, analyzed_variable_type)
|
||||
args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))
|
||||
|
||||
if not self.should_init_forbid_extra(fields, config):
|
||||
var = Var('kwargs')
|
||||
@@ -901,11 +819,7 @@ class PydanticModelTransformer:
|
||||
add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType())
|
||||
|
||||
def add_model_construct_method(
|
||||
self,
|
||||
fields: list[PydanticModelField],
|
||||
config: ModelConfigData,
|
||||
is_settings: bool,
|
||||
is_root_model: bool,
|
||||
self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool
|
||||
) -> None:
|
||||
"""Adds a fully typed `model_construct` classmethod to the class.
|
||||
|
||||
@@ -917,19 +831,13 @@ class PydanticModelTransformer:
|
||||
fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
|
||||
with state.strict_optional_set(self._api.options.strict_optional):
|
||||
args = self.get_field_arguments(
|
||||
fields,
|
||||
typed=True,
|
||||
model_strict=bool(config.strict),
|
||||
requires_dynamic_aliases=False,
|
||||
use_alias=False,
|
||||
is_settings=is_settings,
|
||||
is_root_model=is_root_model,
|
||||
fields, typed=True, requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings
|
||||
)
|
||||
if not self.should_init_forbid_extra(fields, config):
|
||||
var = Var('kwargs')
|
||||
args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))
|
||||
|
||||
args = args + [fields_set_argument] if is_root_model else [fields_set_argument] + args
|
||||
args = [fields_set_argument] + args
|
||||
|
||||
add_method(
|
||||
self._api,
|
||||
@@ -940,7 +848,7 @@ class PydanticModelTransformer:
|
||||
is_classmethod=True,
|
||||
)
|
||||
|
||||
def set_frozen(self, fields: list[PydanticModelField], api: SemanticAnalyzerPluginInterface, frozen: bool) -> None:
|
||||
def set_frozen(self, fields: list[PydanticModelField], frozen: bool) -> None:
|
||||
"""Marks all fields as properties so that attempts to set them trigger mypy errors.
|
||||
|
||||
This is the same approach used by the attrs and dataclasses plugins.
|
||||
@@ -951,7 +859,7 @@ class PydanticModelTransformer:
|
||||
if sym_node is not None:
|
||||
var = sym_node.node
|
||||
if isinstance(var, Var):
|
||||
var.is_property = frozen or field.is_frozen
|
||||
var.is_property = frozen
|
||||
elif isinstance(var, PlaceholderNode) and not self._api.final_iteration:
|
||||
# See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage
|
||||
self._api.defer()
|
||||
@@ -965,13 +873,13 @@ class PydanticModelTransformer:
|
||||
detail = f'sym_node.node: {var_str} (of type {var.__class__})'
|
||||
error_unexpected_behavior(detail, self._api, self._cls)
|
||||
else:
|
||||
var = field.to_var(info, api, use_alias=False)
|
||||
var = field.to_var(info, use_alias=False)
|
||||
var.info = info
|
||||
var.is_property = frozen
|
||||
var._fullname = info.fullname + '.' + var.name
|
||||
info.names[var.name] = SymbolTableNode(MDEF, var)
|
||||
|
||||
def get_config_update(self, name: str, arg: Expression, lax_extra: bool = False) -> ModelConfigData | None:
|
||||
def get_config_update(self, name: str, arg: Expression) -> ModelConfigData | None:
|
||||
"""Determines the config update due to a single kwarg in the ConfigDict definition.
|
||||
|
||||
Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int)
|
||||
@@ -984,16 +892,7 @@ class PydanticModelTransformer:
|
||||
elif isinstance(arg, MemberExpr):
|
||||
forbid_extra = arg.name == 'forbid'
|
||||
else:
|
||||
if not lax_extra:
|
||||
# Only emit an error for other types of `arg` (e.g., `NameExpr`, `ConditionalExpr`, etc.) when
|
||||
# reading from a config class, etc. If a ConfigDict is used, then we don't want to emit an error
|
||||
# because you'll get type checking from the ConfigDict itself.
|
||||
#
|
||||
# It would be nice if we could introspect the types better otherwise, but I don't know what the API
|
||||
# is to evaluate an expr into its type and then check if that type is compatible with the expected
|
||||
# type. Note that you can still get proper type checking via: `model_config = ConfigDict(...)`, just
|
||||
# if you don't use an explicit string, the plugin won't be able to infer whether extra is forbidden.
|
||||
error_invalid_config_value(name, self._api, arg)
|
||||
error_invalid_config_value(name, self._api, arg)
|
||||
return None
|
||||
return ModelConfigData(forbid_extra=forbid_extra)
|
||||
if name == 'alias_generator':
|
||||
@@ -1028,22 +927,6 @@ class PydanticModelTransformer:
|
||||
# Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
|
||||
return not isinstance(expr, EllipsisExpr)
|
||||
|
||||
@staticmethod
|
||||
def get_strict(stmt: AssignmentStmt) -> bool | None:
|
||||
"""Returns a the `strict` value of a field if defined, otherwise `None`."""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
|
||||
for arg, name in zip(expr.args, expr.arg_names):
|
||||
if name != 'strict':
|
||||
continue
|
||||
if isinstance(arg, NameExpr):
|
||||
if arg.fullname == 'builtins.True':
|
||||
return True
|
||||
elif arg.fullname == 'builtins.False':
|
||||
return False
|
||||
return None
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]:
|
||||
"""Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.
|
||||
@@ -1062,53 +945,23 @@ class PydanticModelTransformer:
|
||||
# Assigned value is not a call to pydantic.fields.Field
|
||||
return None, False
|
||||
|
||||
if 'validation_alias' in expr.arg_names:
|
||||
arg = expr.args[expr.arg_names.index('validation_alias')]
|
||||
elif 'alias' in expr.arg_names:
|
||||
arg = expr.args[expr.arg_names.index('alias')]
|
||||
else:
|
||||
return None, False
|
||||
|
||||
if isinstance(arg, StrExpr):
|
||||
return arg.value, False
|
||||
else:
|
||||
return None, True
|
||||
|
||||
@staticmethod
|
||||
def is_field_frozen(stmt: AssignmentStmt) -> bool:
|
||||
"""Returns whether the field is frozen, extracted from the declaration of the field defined in `stmt`.
|
||||
|
||||
Note that this is only whether the field was declared to be frozen in a `<field_name> = Field(frozen=True)`
|
||||
sense; this does not determine whether the field is frozen because the entire model is frozen; that is
|
||||
handled separately.
|
||||
"""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, TempNode):
|
||||
# TempNode means annotation-only
|
||||
return False
|
||||
|
||||
if not (
|
||||
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
|
||||
):
|
||||
# Assigned value is not a call to pydantic.fields.Field
|
||||
return False
|
||||
|
||||
for i, arg_name in enumerate(expr.arg_names):
|
||||
if arg_name == 'frozen':
|
||||
arg = expr.args[i]
|
||||
return isinstance(arg, NameExpr) and arg.fullname == 'builtins.True'
|
||||
return False
|
||||
if arg_name != 'alias':
|
||||
continue
|
||||
arg = expr.args[i]
|
||||
if isinstance(arg, StrExpr):
|
||||
return arg.value, False
|
||||
else:
|
||||
return None, True
|
||||
return None, False
|
||||
|
||||
def get_field_arguments(
|
||||
self,
|
||||
fields: list[PydanticModelField],
|
||||
typed: bool,
|
||||
model_strict: bool,
|
||||
use_alias: bool,
|
||||
requires_dynamic_aliases: bool,
|
||||
is_settings: bool,
|
||||
is_root_model: bool,
|
||||
force_typevars_invariant: bool = False,
|
||||
) -> list[Argument]:
|
||||
"""Helper function used during the construction of the `__init__` and `model_construct` method signatures.
|
||||
|
||||
@@ -1117,14 +970,7 @@ class PydanticModelTransformer:
|
||||
info = self._cls.info
|
||||
arguments = [
|
||||
field.to_argument(
|
||||
info,
|
||||
typed=typed,
|
||||
model_strict=model_strict,
|
||||
force_optional=requires_dynamic_aliases or is_settings,
|
||||
use_alias=use_alias,
|
||||
api=self._api,
|
||||
force_typevars_invariant=force_typevars_invariant,
|
||||
is_root_model_root=is_root_model and field.name == 'root',
|
||||
info, typed=typed, force_optional=requires_dynamic_aliases or is_settings, use_alias=use_alias
|
||||
)
|
||||
for field in fields
|
||||
if not (use_alias and field.has_dynamic_alias)
|
||||
@@ -1137,7 +983,7 @@ class PydanticModelTransformer:
|
||||
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
|
||||
*unless* a required dynamic alias is present (since then we can't determine a valid signature).
|
||||
"""
|
||||
if not (config.validate_by_name or config.populate_by_name):
|
||||
if not config.populate_by_name:
|
||||
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
|
||||
return False
|
||||
if config.forbid_extra:
|
||||
@@ -1159,20 +1005,6 @@ class PydanticModelTransformer:
|
||||
return False
|
||||
|
||||
|
||||
class ChangeExplicitTypeOfAny(TypeTranslator):
|
||||
"""A type translator used to change type of Any's, if explicit."""
|
||||
|
||||
def __init__(self, type_of_any: int) -> None:
|
||||
self._type_of_any = type_of_any
|
||||
super().__init__()
|
||||
|
||||
def visit_any(self, t: AnyType) -> Type: # noqa: D102
|
||||
if t.type_of_any == TypeOfAny.explicit:
|
||||
return t.copy_modified(type_of_any=self._type_of_any)
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
class ModelConfigData:
|
||||
"""Pydantic mypy plugin model config class."""
|
||||
|
||||
@@ -1182,19 +1014,13 @@ class ModelConfigData:
|
||||
frozen: bool | None = None,
|
||||
from_attributes: bool | None = None,
|
||||
populate_by_name: bool | None = None,
|
||||
validate_by_alias: bool | None = None,
|
||||
validate_by_name: bool | None = None,
|
||||
has_alias_generator: bool | None = None,
|
||||
strict: bool | None = None,
|
||||
):
|
||||
self.forbid_extra = forbid_extra
|
||||
self.frozen = frozen
|
||||
self.from_attributes = from_attributes
|
||||
self.populate_by_name = populate_by_name
|
||||
self.validate_by_alias = validate_by_alias
|
||||
self.validate_by_name = validate_by_name
|
||||
self.has_alias_generator = has_alias_generator
|
||||
self.strict = strict
|
||||
|
||||
def get_values_dict(self) -> dict[str, Any]:
|
||||
"""Returns a dict of Pydantic model config names to their values.
|
||||
@@ -1216,18 +1042,12 @@ class ModelConfigData:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def is_root_model(info: TypeInfo) -> bool:
|
||||
"""Return whether the type info is a root model subclass (or the `RootModel` class itself)."""
|
||||
return info.has_base(ROOT_MODEL_FULLNAME)
|
||||
|
||||
|
||||
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic')
|
||||
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
|
||||
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
|
||||
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
|
||||
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
|
||||
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
|
||||
ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic')
|
||||
|
||||
|
||||
def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
|
||||
@@ -1264,9 +1084,9 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context)
|
||||
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)
|
||||
|
||||
|
||||
def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None:
|
||||
"""Emits an error when there is more than just a root field defined for a subclass of RootModel."""
|
||||
api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL)
|
||||
def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
|
||||
"""Emits an error when `Field` has both `default` and `default_factory` together."""
|
||||
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
|
||||
|
||||
|
||||
def add_method(
|
||||
@@ -1276,7 +1096,7 @@ def add_method(
|
||||
args: list[Argument],
|
||||
return_type: Type,
|
||||
self_type: Type | None = None,
|
||||
tvar_def: TypeVarType | None = None,
|
||||
tvar_def: TypeVarDef | None = None,
|
||||
is_classmethod: bool = False,
|
||||
) -> None:
|
||||
"""Very closely related to `mypy.plugins.common.add_method_to_class`, with a few pydantic-specific changes."""
|
||||
@@ -1299,16 +1119,6 @@ def add_method(
|
||||
first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)]
|
||||
else:
|
||||
self_type = self_type or fill_typevars(info)
|
||||
# `self` is positional *ONLY* here, but this can't be expressed
|
||||
# fully in the mypy internal API. ARG_POS is the closest we can get.
|
||||
# Using ARG_POS will, however, give mypy errors if a `self` field
|
||||
# is present on a model:
|
||||
#
|
||||
# Name "self" already defined (possibly by an import) [no-redef]
|
||||
#
|
||||
# As a workaround, we give this argument a name that will
|
||||
# never conflict. By its positional nature, this name will not
|
||||
# be used or exposed to users.
|
||||
first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
|
||||
args = first + args
|
||||
|
||||
@@ -1319,9 +1129,9 @@ def add_method(
|
||||
arg_names.append(arg.variable.name)
|
||||
arg_kinds.append(arg.kind)
|
||||
|
||||
signature = CallableType(
|
||||
arg_types, arg_kinds, arg_names, return_type, function_type, variables=[tvar_def] if tvar_def else None
|
||||
)
|
||||
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
|
||||
if tvar_def:
|
||||
signature.variables = [tvar_def]
|
||||
|
||||
func = FuncDef(name, args, Block([PassStmt()]))
|
||||
func.info = info
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,4 @@
|
||||
"""The `parse` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[Build a Plugin](../concepts/plugins.md#build-a-plugin)
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/plugins#build-a-plugin
|
||||
|
||||
Plugin interface for Pydantic plugins, and related types.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Literal, NamedTuple
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic_core import CoreConfig, CoreSchema, ValidationError
|
||||
from typing_extensions import Protocol, TypeAlias
|
||||
@@ -18,32 +16,17 @@ __all__ = (
|
||||
'ValidateJsonHandlerProtocol',
|
||||
'ValidateStringsHandlerProtocol',
|
||||
'NewSchemaReturns',
|
||||
'SchemaTypePath',
|
||||
'SchemaKind',
|
||||
)
|
||||
|
||||
NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'
|
||||
|
||||
|
||||
class SchemaTypePath(NamedTuple):
|
||||
"""Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""
|
||||
|
||||
module: str
|
||||
name: str
|
||||
|
||||
|
||||
SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']
|
||||
|
||||
|
||||
class PydanticPluginProtocol(Protocol):
|
||||
"""Protocol defining the interface for Pydantic plugins."""
|
||||
|
||||
def new_schema_validator(
|
||||
self,
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_path: SchemaTypePath,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None,
|
||||
plugin_settings: dict[str, object],
|
||||
) -> tuple[
|
||||
@@ -57,9 +40,6 @@ class PydanticPluginProtocol(Protocol):
|
||||
|
||||
Args:
|
||||
schema: The schema to validate against.
|
||||
schema_type: The original type which the schema was created from, e.g. the model class.
|
||||
schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
|
||||
schema_kind: The kind of schema to validate against.
|
||||
config: The config to use for validation.
|
||||
plugin_settings: Any plugin settings.
|
||||
|
||||
@@ -96,14 +76,6 @@ class BaseValidateHandlerProtocol(Protocol):
|
||||
"""
|
||||
return
|
||||
|
||||
def on_exception(self, exception: Exception) -> None:
|
||||
"""Callback to be notified of validation exceptions.
|
||||
|
||||
Args:
|
||||
exception: The exception raised during validation.
|
||||
"""
|
||||
return
|
||||
|
||||
|
||||
class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
"""Event handler for `SchemaValidator.validate_python`."""
|
||||
@@ -116,8 +88,6 @@ class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
from_attributes: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
self_instance: Any | None = None,
|
||||
by_alias: bool | None = None,
|
||||
by_name: bool | None = None,
|
||||
) -> None:
|
||||
"""Callback to be notified of validation start, and create an instance of the event handler.
|
||||
|
||||
@@ -128,8 +98,6 @@ class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
context: The context to use for validation, this is passed to functional validators.
|
||||
self_instance: An instance of a model to set attributes on from validation, this is used when running
|
||||
validation from the `__init__` method of a model.
|
||||
by_alias: Whether to use the field's alias to match the input data to an attribute.
|
||||
by_name: Whether to use the field's name to match the input data to an attribute.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -144,8 +112,6 @@ class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
self_instance: Any | None = None,
|
||||
by_alias: bool | None = None,
|
||||
by_name: bool | None = None,
|
||||
) -> None:
|
||||
"""Callback to be notified of validation start, and create an instance of the event handler.
|
||||
|
||||
@@ -155,8 +121,6 @@ class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
context: The context to use for validation, this is passed to functional validators.
|
||||
self_instance: An instance of a model to set attributes on from validation, this is used when running
|
||||
validation from the `__init__` method of a model.
|
||||
by_alias: Whether to use the field's alias to match the input data to an attribute.
|
||||
by_name: Whether to use the field's name to match the input data to an attribute.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -168,13 +132,7 @@ class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
"""Event handler for `SchemaValidator.validate_strings`."""
|
||||
|
||||
def on_enter(
|
||||
self,
|
||||
input: StringInput,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
by_alias: bool | None = None,
|
||||
by_name: bool | None = None,
|
||||
self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Callback to be notified of validation start, and create an instance of the event handler.
|
||||
|
||||
@@ -182,7 +140,5 @@ class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
|
||||
input: The string data to be validated.
|
||||
strict: Whether to validate the object in strict mode.
|
||||
context: The context to use for validation, this is passed to functional validators.
|
||||
by_alias: Whether to use the field's alias to match the input data to an attribute.
|
||||
by_name: Whether to use the field's name to match the input data to an attribute.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.metadata as importlib_metadata
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
import importlib.metadata as importlib_metadata
|
||||
else:
|
||||
import importlib_metadata
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import PydanticPluginProtocol
|
||||
@@ -24,13 +30,10 @@ def get_plugins() -> Iterable[PydanticPluginProtocol]:
|
||||
|
||||
Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402
|
||||
"""
|
||||
disabled_plugins = os.getenv('PYDANTIC_DISABLE_PLUGINS')
|
||||
global _plugins, _loading_plugins
|
||||
if _loading_plugins:
|
||||
# this happens when plugins themselves use pydantic, we return no plugins
|
||||
return ()
|
||||
elif disabled_plugins in ('__all__', '1', 'true'):
|
||||
return ()
|
||||
elif _plugins is None:
|
||||
_plugins = {}
|
||||
# set _loading_plugins so any plugins that use pydantic don't themselves use plugins
|
||||
@@ -42,8 +45,6 @@ def get_plugins() -> Iterable[PydanticPluginProtocol]:
|
||||
continue
|
||||
if entry_point.value in _plugins:
|
||||
continue
|
||||
if disabled_plugins is not None and entry_point.name in disabled_plugins.split(','):
|
||||
continue
|
||||
try:
|
||||
_plugins[entry_point.value] = entry_point.load()
|
||||
except (ImportError, AttributeError) as e:
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
"""Pluggable schema validator for pydantic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar
|
||||
|
||||
from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
|
||||
from typing_extensions import ParamSpec
|
||||
from typing_extensions import Literal, ParamSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
|
||||
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol
|
||||
|
||||
|
||||
P = ParamSpec('P')
|
||||
@@ -20,33 +18,18 @@ events: list[Event] = list(Event.__args__) # type: ignore
|
||||
|
||||
|
||||
def create_schema_validator(
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_module: str,
|
||||
schema_type_name: str,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None = None,
|
||||
plugin_settings: dict[str, Any] | None = None,
|
||||
) -> SchemaValidator | PluggableSchemaValidator:
|
||||
schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None
|
||||
) -> SchemaValidator:
|
||||
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
|
||||
|
||||
Returns:
|
||||
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
|
||||
"""
|
||||
from . import SchemaTypePath
|
||||
from ._loader import get_plugins
|
||||
|
||||
plugins = get_plugins()
|
||||
if plugins:
|
||||
return PluggableSchemaValidator(
|
||||
schema,
|
||||
schema_type,
|
||||
SchemaTypePath(schema_type_module, schema_type_name),
|
||||
schema_kind,
|
||||
config,
|
||||
plugins,
|
||||
plugin_settings or {},
|
||||
)
|
||||
return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore
|
||||
else:
|
||||
return SchemaValidator(schema, config)
|
||||
|
||||
@@ -59,9 +42,6 @@ class PluggableSchemaValidator:
|
||||
def __init__(
|
||||
self,
|
||||
schema: CoreSchema,
|
||||
schema_type: Any,
|
||||
schema_type_path: SchemaTypePath,
|
||||
schema_kind: SchemaKind,
|
||||
config: CoreConfig | None,
|
||||
plugins: Iterable[PydanticPluginProtocol],
|
||||
plugin_settings: dict[str, Any],
|
||||
@@ -72,12 +52,7 @@ class PluggableSchemaValidator:
|
||||
json_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
p, j, s = plugin.new_schema_validator(
|
||||
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
|
||||
)
|
||||
except TypeError as e: # pragma: no cover
|
||||
raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
|
||||
p, j, s = plugin.new_schema_validator(schema, config, plugin_settings)
|
||||
if p is not None:
|
||||
python_event_handlers.append(p)
|
||||
if j is not None:
|
||||
@@ -100,7 +75,6 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler
|
||||
on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
|
||||
on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
|
||||
on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
|
||||
on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
@@ -113,10 +87,6 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler
|
||||
for on_error_handler in on_errors:
|
||||
on_error_handler(error)
|
||||
raise
|
||||
except Exception as exception:
|
||||
for on_exception_handler in on_exceptions:
|
||||
on_exception_handler(exception)
|
||||
raise
|
||||
else:
|
||||
for on_success_handler in on_successes:
|
||||
on_success_handler(result)
|
||||
|
||||
@@ -8,33 +8,25 @@ from copy import copy, deepcopy
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from . import PydanticUserError
|
||||
from ._internal import _model_construction, _repr
|
||||
from ._internal import _repr
|
||||
from .main import BaseModel, _object_setattr
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from typing import Any, Literal
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Self, dataclass_transform
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .fields import Field as PydanticModelField
|
||||
from .fields import PrivateAttr as PydanticModelPrivateAttr
|
||||
Model = typing.TypeVar('Model', bound='BaseModel')
|
||||
|
||||
# dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
|
||||
# takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
|
||||
# on a new metaclass.
|
||||
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr))
|
||||
class _RootModelMetaclass(_model_construction.ModelMetaclass): ...
|
||||
else:
|
||||
_RootModelMetaclass = _model_construction.ModelMetaclass
|
||||
|
||||
__all__ = ('RootModel',)
|
||||
|
||||
|
||||
RootModelRootType = typing.TypeVar('RootModelRootType')
|
||||
|
||||
|
||||
class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[`RootModel` and Custom Root Types](../concepts/models.md#rootmodel-and-custom-root-types)
|
||||
class RootModel(BaseModel, typing.Generic[RootModelRootType]):
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/models/#rootmodel-and-custom-root-types
|
||||
|
||||
A Pydantic `BaseModel` for the root object of the model.
|
||||
|
||||
@@ -60,7 +52,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
|
||||
)
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
|
||||
def __init__(__pydantic_self__, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
|
||||
__tracebackhide__ = True
|
||||
if data:
|
||||
if root is not PydanticUndefined:
|
||||
@@ -68,12 +60,12 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
|
||||
'"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
|
||||
)
|
||||
root = data # type: ignore
|
||||
self.__pydantic_validator__.validate_python(root, self_instance=self)
|
||||
__pydantic_self__.__pydantic_validator__.validate_python(root, self_instance=__pydantic_self__)
|
||||
|
||||
__init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess]
|
||||
__init__.__pydantic_base_init__ = True
|
||||
|
||||
@classmethod
|
||||
def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self: # type: ignore
|
||||
def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model:
|
||||
"""Create a new model using the provided root object and update fields set.
|
||||
|
||||
Args:
|
||||
@@ -98,7 +90,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
|
||||
_object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
|
||||
_object_setattr(self, '__dict__', state['__dict__'])
|
||||
|
||||
def __copy__(self) -> Self:
|
||||
def __copy__(self: Model) -> Model:
|
||||
"""Returns a shallow copy of the model."""
|
||||
cls = type(self)
|
||||
m = cls.__new__(cls)
|
||||
@@ -106,7 +98,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
|
||||
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
|
||||
return m
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self:
|
||||
def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
|
||||
"""Returns a deep copy of the model."""
|
||||
cls = type(self)
|
||||
m = cls.__new__(cls)
|
||||
@@ -118,40 +110,30 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
def model_dump( # type: ignore
|
||||
def model_dump(
|
||||
self,
|
||||
*,
|
||||
mode: Literal['json', 'python'] | str = 'python',
|
||||
include: Any = None,
|
||||
exclude: Any = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
by_alias: bool | None = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal['none', 'warn', 'error'] = True,
|
||||
serialize_as_any: bool = False,
|
||||
) -> Any:
|
||||
warnings: bool = True,
|
||||
) -> RootModelRootType:
|
||||
"""This method is included just to get a more accurate return type for type checkers.
|
||||
It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.
|
||||
|
||||
See the documentation of `BaseModel.model_dump` for more details about the arguments.
|
||||
|
||||
Generally, this method will have a return type of `RootModelRootType`, assuming that `RootModelRootType` is
|
||||
not a `BaseModel` subclass. If `RootModelRootType` is a `BaseModel` subclass, then the return
|
||||
type will likely be `dict[str, Any]`, as `model_dump` calls are recursive. The return type could
|
||||
even be something different, in the case of a custom serializer.
|
||||
Thus, `Any` is used here to catch all of these cases.
|
||||
"""
|
||||
...
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, RootModel):
|
||||
return NotImplemented
|
||||
return self.__pydantic_fields__['root'].annotation == other.__pydantic_fields__[
|
||||
'root'
|
||||
].annotation and super().__eq__(other)
|
||||
return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)
|
||||
|
||||
def __repr_args__(self) -> _repr.ReprArgs:
|
||||
yield 'root', self.root
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `schema` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `tools` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,30 +1,93 @@
|
||||
"""Type adapter specification."""
|
||||
"""
|
||||
You may have types that are not `BaseModel`s that you want to validate data against.
|
||||
Or you may want to validate a `List[SomeModel]`, or dump it to JSON.
|
||||
|
||||
For use cases like this, Pydantic provides [`TypeAdapter`][pydantic.type_adapter.TypeAdapter],
|
||||
which can be used for type validation, serialization, and JSON schema generation without creating a
|
||||
[`BaseModel`][pydantic.main.BaseModel].
|
||||
|
||||
A [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] instance exposes some of the functionality from
|
||||
[`BaseModel`][pydantic.main.BaseModel] instance methods for types that do not have such methods
|
||||
(such as dataclasses, primitive types, and more):
|
||||
|
||||
```py
|
||||
from typing import List
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
|
||||
class User(TypedDict):
|
||||
name: str
|
||||
id: int
|
||||
|
||||
UserListValidator = TypeAdapter(List[User])
|
||||
print(repr(UserListValidator.validate_python([{'name': 'Fred', 'id': '3'}])))
|
||||
#> [{'name': 'Fred', 'id': 3}]
|
||||
|
||||
try:
|
||||
UserListValidator.validate_python(
|
||||
[{'name': 'Fred', 'id': 'wrong', 'other': 'no'}]
|
||||
)
|
||||
except ValidationError as e:
|
||||
print(e)
|
||||
'''
|
||||
1 validation error for list[typed-dict]
|
||||
0.id
|
||||
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='wrong', input_type=str]
|
||||
'''
|
||||
```
|
||||
|
||||
Note:
|
||||
Despite some overlap in use cases with [`RootModel`][pydantic.root_model.RootModel],
|
||||
[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] should not be used as a type annotation for
|
||||
specifying fields of a `BaseModel`, etc.
|
||||
|
||||
## Parsing data into a specified type
|
||||
|
||||
[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] can be used to apply the parsing logic to populate Pydantic models
|
||||
in a more ad-hoc way. This function behaves similarly to
|
||||
[`BaseModel.model_validate`][pydantic.main.BaseModel.model_validate],
|
||||
but works with arbitrary Pydantic-compatible types.
|
||||
|
||||
This is especially useful when you want to parse results into a type that is not a direct subclass of
|
||||
[`BaseModel`][pydantic.main.BaseModel]. For example:
|
||||
|
||||
```py
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
class Item(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
# `item_data` could come from an API call, eg., via something like:
|
||||
# item_data = requests.get('https://my-api.com/items').json()
|
||||
item_data = [{'id': 1, 'name': 'My Item'}]
|
||||
|
||||
items = TypeAdapter(List[Item]).validate_python(item_data)
|
||||
print(items)
|
||||
#> [Item(id=1, name='My Item')]
|
||||
```
|
||||
|
||||
[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] is capable of parsing data into any of the types Pydantic can
|
||||
handle as fields of a [`BaseModel`][pydantic.main.BaseModel].
|
||||
""" # noqa: D212
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import is_dataclass
|
||||
from types import FrameType
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
cast,
|
||||
final,
|
||||
overload,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, overload
|
||||
|
||||
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
|
||||
from typing_extensions import ParamSpec, is_typeddict
|
||||
from typing_extensions import Literal, is_typeddict
|
||||
|
||||
from pydantic.errors import PydanticUserError
|
||||
from pydantic.main import BaseModel, IncEx
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from ._internal import _config, _generate_schema, _mock_val_ser, _namespace_utils, _repr, _typing_extra, _utils
|
||||
from ._internal import _config, _core_utils, _discriminated_union, _generate_schema, _typing_extra
|
||||
from .config import ConfigDict
|
||||
from .errors import PydanticUndefinedAnnotation
|
||||
from .json_schema import (
|
||||
DEFAULT_REF_TEMPLATE,
|
||||
GenerateJsonSchema,
|
||||
@@ -32,12 +95,67 @@ from .json_schema import (
|
||||
JsonSchemaMode,
|
||||
JsonSchemaValue,
|
||||
)
|
||||
from .plugin._schema_validator import PluggableSchemaValidator, create_schema_validator
|
||||
from .plugin._schema_validator import create_schema_validator
|
||||
|
||||
T = TypeVar('T')
|
||||
R = TypeVar('R')
|
||||
P = ParamSpec('P')
|
||||
TypeAdapterT = TypeVar('TypeAdapterT', bound='TypeAdapter')
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
|
||||
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
|
||||
|
||||
|
||||
def _get_schema(type_: Any, config_wrapper: _config.ConfigWrapper, parent_depth: int) -> CoreSchema:
|
||||
"""`BaseModel` uses its own `__module__` to find out where it was defined
|
||||
and then look for symbols to resolve forward references in those globals.
|
||||
On the other hand this function can be called with arbitrary objects,
|
||||
including type aliases where `__module__` (always `typing.py`) is not useful.
|
||||
So instead we look at the globals in our parent stack frame.
|
||||
|
||||
This works for the case where this function is called in a module that
|
||||
has the target of forward references in its scope, but
|
||||
does not work for more complex cases.
|
||||
|
||||
For example, take the following:
|
||||
|
||||
a.py
|
||||
```python
|
||||
from typing import Dict, List
|
||||
|
||||
IntList = List[int]
|
||||
OuterDict = Dict[str, 'IntList']
|
||||
```
|
||||
|
||||
b.py
|
||||
```python test="skip"
|
||||
from a import OuterDict
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
IntList = int # replaces the symbol the forward reference is looking for
|
||||
v = TypeAdapter(OuterDict)
|
||||
v({'x': 1}) # should fail but doesn't
|
||||
```
|
||||
|
||||
If OuterDict were a `BaseModel`, this would work because it would resolve
|
||||
the forward reference within the `a.py` namespace.
|
||||
But `TypeAdapter(OuterDict)`
|
||||
can't know what module OuterDict came from.
|
||||
|
||||
In other words, the assumption that _all_ forward references exist in the
|
||||
module we are being called from is not technically always true.
|
||||
Although most of the time it is and it works fine for recursive models and such,
|
||||
`BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
|
||||
so there is no right or wrong between the two.
|
||||
|
||||
But at the very least this behavior is _subtly_ different from `BaseModel`'s.
|
||||
"""
|
||||
local_ns = _typing_extra.parent_frame_namespace(parent_depth=parent_depth)
|
||||
global_ns = sys._getframe(max(parent_depth - 1, 1)).f_globals.copy()
|
||||
global_ns.update(local_ns or {})
|
||||
gen = _generate_schema.GenerateSchema(config_wrapper, types_namespace=global_ns, typevars_map={})
|
||||
schema = gen.generate_schema(type_)
|
||||
schema = gen.collect_definitions(schema)
|
||||
return schema
|
||||
|
||||
|
||||
def _getattr_no_parents(obj: Any, attribute: str) -> Any:
|
||||
@@ -55,152 +173,59 @@ def _getattr_no_parents(obj: Any, attribute: str) -> Any:
|
||||
raise AttributeError(attribute)
|
||||
|
||||
|
||||
def _type_has_config(type_: Any) -> bool:
|
||||
"""Returns whether the type has config."""
|
||||
type_ = _typing_extra.annotated_type(type_) or type_
|
||||
try:
|
||||
return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
|
||||
except TypeError:
|
||||
# type is not a class
|
||||
return False
|
||||
|
||||
|
||||
@final
|
||||
class TypeAdapter(Generic[T]):
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[`TypeAdapter`](../concepts/type_adapter.md)
|
||||
|
||||
Type adapters provide a flexible way to perform validation and serialization based on a Python type.
|
||||
"""Type adapters provide a flexible way to perform validation and serialization based on a Python type.
|
||||
|
||||
A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods
|
||||
for types that do not have such methods (such as dataclasses, primitive types, and more).
|
||||
|
||||
**Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields.
|
||||
|
||||
Args:
|
||||
type: The type associated with the `TypeAdapter`.
|
||||
config: Configuration for the `TypeAdapter`, should be a dictionary conforming to
|
||||
[`ConfigDict`][pydantic.config.ConfigDict].
|
||||
|
||||
!!! note
|
||||
You cannot provide a configuration when instantiating a `TypeAdapter` if the type you're using
|
||||
has its own config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A
|
||||
[`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will
|
||||
be raised in this case.
|
||||
_parent_depth: Depth at which to search for the [parent frame][frame-objects]. This frame is used when
|
||||
resolving forward annotations during schema building, by looking for the globals and locals of this
|
||||
frame. Defaults to 2, which will result in the frame where the `TypeAdapter` was instantiated.
|
||||
|
||||
!!! note
|
||||
This parameter is named with an underscore to suggest its private nature and discourage use.
|
||||
It may be deprecated in a minor version, so we only recommend using it if you're comfortable
|
||||
with potential change in behavior/support. It's default value is 2 because internally,
|
||||
the `TypeAdapter` class makes another call to fetch the frame.
|
||||
module: The module that passes to plugin if provided.
|
||||
Note that `TypeAdapter` is not an actual type, so you cannot use it in type annotations.
|
||||
|
||||
Attributes:
|
||||
core_schema: The core schema for the type.
|
||||
validator: The schema validator for the type.
|
||||
validator (SchemaValidator): The schema validator for the type.
|
||||
serializer: The schema serializer for the type.
|
||||
pydantic_complete: Whether the core schema for the type is successfully built.
|
||||
|
||||
??? tip "Compatibility with `mypy`"
|
||||
Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly
|
||||
annotate your variable:
|
||||
|
||||
```py
|
||||
from typing import Union
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int]) # type: ignore[arg-type]
|
||||
```
|
||||
|
||||
??? info "Namespace management nuances and implementation details"
|
||||
|
||||
Here, we collect some notes on namespace management, and subtle differences from `BaseModel`:
|
||||
|
||||
`BaseModel` uses its own `__module__` to find out where it was defined
|
||||
and then looks for symbols to resolve forward references in those globals.
|
||||
On the other hand, `TypeAdapter` can be initialized with arbitrary objects,
|
||||
which may not be types and thus do not have a `__module__` available.
|
||||
So instead we look at the globals in our parent stack frame.
|
||||
|
||||
It is expected that the `ns_resolver` passed to this function will have the correct
|
||||
namespace for the type we're adapting. See the source code for `TypeAdapter.__init__`
|
||||
and `TypeAdapter.rebuild` for various ways to construct this namespace.
|
||||
|
||||
This works for the case where this function is called in a module that
|
||||
has the target of forward references in its scope, but
|
||||
does not always work for more complex cases.
|
||||
|
||||
For example, take the following:
|
||||
|
||||
```python {title="a.py"}
|
||||
IntList = list[int]
|
||||
OuterDict = dict[str, 'IntList']
|
||||
```
|
||||
|
||||
```python {test="skip" title="b.py"}
|
||||
from a import OuterDict
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
IntList = int # replaces the symbol the forward reference is looking for
|
||||
v = TypeAdapter(OuterDict)
|
||||
v({'x': 1}) # should fail but doesn't
|
||||
```
|
||||
|
||||
If `OuterDict` were a `BaseModel`, this would work because it would resolve
|
||||
the forward reference within the `a.py` namespace.
|
||||
But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from.
|
||||
|
||||
In other words, the assumption that _all_ forward references exist in the
|
||||
module we are being called from is not technically always true.
|
||||
Although most of the time it is and it works fine for recursive models and such,
|
||||
`BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
|
||||
so there is no right or wrong between the two.
|
||||
|
||||
But at the very least this behavior is _subtly_ different from `BaseModel`'s.
|
||||
"""
|
||||
|
||||
core_schema: CoreSchema
|
||||
validator: SchemaValidator | PluggableSchemaValidator
|
||||
serializer: SchemaSerializer
|
||||
pydantic_complete: bool
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
type: type[T],
|
||||
*,
|
||||
config: ConfigDict | None = ...,
|
||||
_parent_depth: int = ...,
|
||||
module: str | None = ...,
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __new__(cls, __type: type[T], *, config: ConfigDict | None = ...) -> TypeAdapter[T]:
|
||||
...
|
||||
|
||||
# This second overload is for unsupported special forms (such as Annotated, Union, etc.)
|
||||
# Currently there is no way to type this correctly
|
||||
# See https://github.com/python/typing/pull/1618
|
||||
@overload
|
||||
def __init__(
|
||||
self,
|
||||
type: Any,
|
||||
*,
|
||||
config: ConfigDict | None = ...,
|
||||
_parent_depth: int = ...,
|
||||
module: str | None = ...,
|
||||
) -> None: ...
|
||||
# this overload is for non-type things like Union[int, str]
|
||||
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
|
||||
# so an explicit type cast is needed
|
||||
@overload
|
||||
def __new__(cls, __type: T, *, config: ConfigDict | None = ...) -> TypeAdapter[T]:
|
||||
...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Any,
|
||||
*,
|
||||
config: ConfigDict | None = None,
|
||||
_parent_depth: int = 2,
|
||||
module: str | None = None,
|
||||
) -> None:
|
||||
if _type_has_config(type) and config is not None:
|
||||
def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) -> TypeAdapter[T]:
|
||||
"""A class representing the type adapter."""
|
||||
raise NotImplementedError
|
||||
|
||||
@overload
|
||||
def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
|
||||
...
|
||||
|
||||
# this overload is for non-type things like Union[int, str]
|
||||
# Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object]
|
||||
# so an explicit type cast is needed
|
||||
@overload
|
||||
def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
|
||||
...
|
||||
|
||||
def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None:
|
||||
"""Initializes the TypeAdapter object."""
|
||||
config_wrapper = _config.ConfigWrapper(config)
|
||||
|
||||
try:
|
||||
type_has_config = issubclass(type, BaseModel) or is_dataclass(type) or is_typeddict(type)
|
||||
except TypeError:
|
||||
# type is not a class
|
||||
type_has_config = False
|
||||
|
||||
if type_has_config and config is not None:
|
||||
raise PydanticUserError(
|
||||
'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.'
|
||||
' These types can have their own config and setting the config via the `config`'
|
||||
@@ -209,313 +234,81 @@ class TypeAdapter(Generic[T]):
|
||||
code='type-adapter-config-unused',
|
||||
)
|
||||
|
||||
self._type = type
|
||||
self._config = config
|
||||
self._parent_depth = _parent_depth
|
||||
self.pydantic_complete = False
|
||||
|
||||
parent_frame = self._fetch_parent_frame()
|
||||
if parent_frame is not None:
|
||||
globalns = parent_frame.f_globals
|
||||
# Do not provide a local ns if the type adapter happens to be instantiated at the module level:
|
||||
localns = parent_frame.f_locals if parent_frame.f_locals is not globalns else {}
|
||||
else:
|
||||
globalns = {}
|
||||
localns = {}
|
||||
|
||||
self._module_name = module or cast(str, globalns.get('__name__', ''))
|
||||
self._init_core_attrs(
|
||||
ns_resolver=_namespace_utils.NsResolver(
|
||||
namespaces_tuple=_namespace_utils.NamespacesTuple(locals=localns, globals=globalns),
|
||||
parent_namespace=localns,
|
||||
),
|
||||
force=False,
|
||||
)
|
||||
|
||||
def _fetch_parent_frame(self) -> FrameType | None:
|
||||
frame = sys._getframe(self._parent_depth)
|
||||
if frame.f_globals.get('__name__') == 'typing':
|
||||
# Because `TypeAdapter` is generic, explicitly parametrizing the class results
|
||||
# in a `typing._GenericAlias` instance, which proxies instantiation calls to the
|
||||
# "real" `TypeAdapter` class and thus adding an extra frame to the call. To avoid
|
||||
# pulling anything from the `typing` module, use the correct frame (the one before):
|
||||
return frame.f_back
|
||||
|
||||
return frame
|
||||
|
||||
def _init_core_attrs(
|
||||
self, ns_resolver: _namespace_utils.NsResolver, force: bool, raise_errors: bool = False
|
||||
) -> bool:
|
||||
"""Initialize the core schema, validator, and serializer for the type.
|
||||
|
||||
Args:
|
||||
ns_resolver: The namespace resolver to use when building the core schema for the adapted type.
|
||||
force: Whether to force the construction of the core schema, validator, and serializer.
|
||||
If `force` is set to `False` and `_defer_build` is `True`, the core schema, validator, and serializer will be set to mocks.
|
||||
raise_errors: Whether to raise errors if initializing any of the core attrs fails.
|
||||
|
||||
Returns:
|
||||
`True` if the core schema, validator, and serializer were successfully initialized, otherwise `False`.
|
||||
|
||||
Raises:
|
||||
PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in`__get_pydantic_core_schema__`
|
||||
and `raise_errors=True`.
|
||||
"""
|
||||
if not force and self._defer_build:
|
||||
_mock_val_ser.set_type_adapter_mocks(self)
|
||||
self.pydantic_complete = False
|
||||
return False
|
||||
|
||||
core_schema: CoreSchema
|
||||
try:
|
||||
self.core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__')
|
||||
self.validator = _getattr_no_parents(self._type, '__pydantic_validator__')
|
||||
self.serializer = _getattr_no_parents(self._type, '__pydantic_serializer__')
|
||||
|
||||
# TODO: we don't go through the rebuild logic here directly because we don't want
|
||||
# to repeat all of the namespace fetching logic that we've already done
|
||||
# so we simply skip to the block below that does the actual schema generation
|
||||
if (
|
||||
isinstance(self.core_schema, _mock_val_ser.MockCoreSchema)
|
||||
or isinstance(self.validator, _mock_val_ser.MockValSer)
|
||||
or isinstance(self.serializer, _mock_val_ser.MockValSer)
|
||||
):
|
||||
raise AttributeError()
|
||||
core_schema = _getattr_no_parents(type, '__pydantic_core_schema__')
|
||||
except AttributeError:
|
||||
config_wrapper = _config.ConfigWrapper(self._config)
|
||||
core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1)
|
||||
|
||||
schema_generator = _generate_schema.GenerateSchema(config_wrapper, ns_resolver=ns_resolver)
|
||||
core_schema = _discriminated_union.apply_discriminators(_core_utils.simplify_schema_references(core_schema))
|
||||
|
||||
try:
|
||||
core_schema = schema_generator.generate_schema(self._type)
|
||||
except PydanticUndefinedAnnotation:
|
||||
if raise_errors:
|
||||
raise
|
||||
_mock_val_ser.set_type_adapter_mocks(self)
|
||||
return False
|
||||
core_schema = _core_utils.validate_core_schema(core_schema)
|
||||
|
||||
try:
|
||||
self.core_schema = schema_generator.clean_schema(core_schema)
|
||||
except _generate_schema.InvalidSchemaError:
|
||||
_mock_val_ser.set_type_adapter_mocks(self)
|
||||
return False
|
||||
core_config = config_wrapper.core_config(None)
|
||||
validator: SchemaValidator
|
||||
try:
|
||||
validator = _getattr_no_parents(type, '__pydantic_validator__')
|
||||
except AttributeError:
|
||||
validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings)
|
||||
|
||||
core_config = config_wrapper.core_config(None)
|
||||
serializer: SchemaSerializer
|
||||
try:
|
||||
serializer = _getattr_no_parents(type, '__pydantic_serializer__')
|
||||
except AttributeError:
|
||||
serializer = SchemaSerializer(core_schema, core_config)
|
||||
|
||||
self.validator = create_schema_validator(
|
||||
schema=self.core_schema,
|
||||
schema_type=self._type,
|
||||
schema_type_module=self._module_name,
|
||||
schema_type_name=str(self._type),
|
||||
schema_kind='TypeAdapter',
|
||||
config=core_config,
|
||||
plugin_settings=config_wrapper.plugin_settings,
|
||||
)
|
||||
self.serializer = SchemaSerializer(self.core_schema, core_config)
|
||||
|
||||
self.pydantic_complete = True
|
||||
return True
|
||||
|
||||
@property
|
||||
def _defer_build(self) -> bool:
|
||||
config = self._config if self._config is not None else self._model_config
|
||||
if config:
|
||||
return config.get('defer_build') is True
|
||||
return False
|
||||
|
||||
@property
|
||||
def _model_config(self) -> ConfigDict | None:
|
||||
type_: Any = _typing_extra.annotated_type(self._type) or self._type # Eg FastAPI heavily uses Annotated
|
||||
if _utils.lenient_issubclass(type_, BaseModel):
|
||||
return type_.model_config
|
||||
return getattr(type_, '__pydantic_config__', None)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'TypeAdapter({_repr.display_as_type(self._type)})'
|
||||
|
||||
def rebuild(
|
||||
self,
|
||||
*,
|
||||
force: bool = False,
|
||||
raise_errors: bool = True,
|
||||
_parent_namespace_depth: int = 2,
|
||||
_types_namespace: _namespace_utils.MappingNamespace | None = None,
|
||||
) -> bool | None:
|
||||
"""Try to rebuild the pydantic-core schema for the adapter's type.
|
||||
|
||||
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
|
||||
the initial attempt to build the schema, and automatic rebuilding fails.
|
||||
|
||||
Args:
|
||||
force: Whether to force the rebuilding of the type adapter's schema, defaults to `False`.
|
||||
raise_errors: Whether to raise errors, defaults to `True`.
|
||||
_parent_namespace_depth: Depth at which to search for the [parent frame][frame-objects]. This
|
||||
frame is used when resolving forward annotations during schema rebuilding, by looking for
|
||||
the locals of this frame. Defaults to 2, which will result in the frame where the method
|
||||
was called.
|
||||
_types_namespace: An explicit types namespace to use, instead of using the local namespace
|
||||
from the parent frame. Defaults to `None`.
|
||||
|
||||
Returns:
|
||||
Returns `None` if the schema is already "complete" and rebuilding was not required.
|
||||
If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
|
||||
"""
|
||||
if not force and self.pydantic_complete:
|
||||
return None
|
||||
|
||||
if _types_namespace is not None:
|
||||
rebuild_ns = _types_namespace
|
||||
elif _parent_namespace_depth > 0:
|
||||
rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {}
|
||||
else:
|
||||
rebuild_ns = {}
|
||||
|
||||
# we have to manually fetch globals here because there's no type on the stack of the NsResolver
|
||||
# and so we skip the globalns = get_module_ns_of(typ) call that would normally happen
|
||||
globalns = sys._getframe(max(_parent_namespace_depth - 1, 1)).f_globals
|
||||
ns_resolver = _namespace_utils.NsResolver(
|
||||
namespaces_tuple=_namespace_utils.NamespacesTuple(locals=rebuild_ns, globals=globalns),
|
||||
parent_namespace=rebuild_ns,
|
||||
)
|
||||
return self._init_core_attrs(ns_resolver=ns_resolver, force=True, raise_errors=raise_errors)
|
||||
self.core_schema = core_schema
|
||||
self.validator = validator
|
||||
self.serializer = serializer
|
||||
|
||||
def validate_python(
|
||||
self,
|
||||
object: Any,
|
||||
/,
|
||||
__object: Any,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
from_attributes: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
||||
by_alias: bool | None = None,
|
||||
by_name: bool | None = None,
|
||||
) -> T:
|
||||
"""Validate a Python object against the model.
|
||||
|
||||
Args:
|
||||
object: The Python object to validate against the model.
|
||||
__object: The Python object to validate against the model.
|
||||
strict: Whether to strictly check types.
|
||||
from_attributes: Whether to extract data from object attributes.
|
||||
context: Additional context to pass to the validator.
|
||||
experimental_allow_partial: **Experimental** whether to enable
|
||||
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
|
||||
* False / 'off': Default behavior, no partial validation.
|
||||
* True / 'on': Enable partial validation.
|
||||
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
|
||||
by_alias: Whether to use the field's alias when validating against the provided input data.
|
||||
by_name: Whether to use the field's name when validating against the provided input data.
|
||||
|
||||
!!! note
|
||||
When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes`
|
||||
argument is not supported.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
if by_alias is False and by_name is not True:
|
||||
raise PydanticUserError(
|
||||
'At least one of `by_alias` or `by_name` must be set to True.',
|
||||
code='validate-by-alias-and-name-false',
|
||||
)
|
||||
|
||||
return self.validator.validate_python(
|
||||
object,
|
||||
strict=strict,
|
||||
from_attributes=from_attributes,
|
||||
context=context,
|
||||
allow_partial=experimental_allow_partial,
|
||||
by_alias=by_alias,
|
||||
by_name=by_name,
|
||||
)
|
||||
return self.validator.validate_python(__object, strict=strict, from_attributes=from_attributes, context=context)
|
||||
|
||||
def validate_json(
|
||||
self,
|
||||
data: str | bytes | bytearray,
|
||||
/,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
||||
by_alias: bool | None = None,
|
||||
by_name: bool | None = None,
|
||||
self, __data: str | bytes, *, strict: bool | None = None, context: dict[str, Any] | None = None
|
||||
) -> T:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[JSON Parsing](../concepts/json.md#json-parsing)
|
||||
|
||||
Validate a JSON string or bytes against the model.
|
||||
"""Validate a JSON string or bytes against the model.
|
||||
|
||||
Args:
|
||||
data: The JSON data to validate against the model.
|
||||
__data: The JSON data to validate against the model.
|
||||
strict: Whether to strictly check types.
|
||||
context: Additional context to use during validation.
|
||||
experimental_allow_partial: **Experimental** whether to enable
|
||||
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
|
||||
* False / 'off': Default behavior, no partial validation.
|
||||
* True / 'on': Enable partial validation.
|
||||
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
|
||||
by_alias: Whether to use the field's alias when validating against the provided input data.
|
||||
by_name: Whether to use the field's name when validating against the provided input data.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
if by_alias is False and by_name is not True:
|
||||
raise PydanticUserError(
|
||||
'At least one of `by_alias` or `by_name` must be set to True.',
|
||||
code='validate-by-alias-and-name-false',
|
||||
)
|
||||
return self.validator.validate_json(__data, strict=strict, context=context)
|
||||
|
||||
return self.validator.validate_json(
|
||||
data,
|
||||
strict=strict,
|
||||
context=context,
|
||||
allow_partial=experimental_allow_partial,
|
||||
by_alias=by_alias,
|
||||
by_name=by_name,
|
||||
)
|
||||
|
||||
def validate_strings(
|
||||
self,
|
||||
obj: Any,
|
||||
/,
|
||||
*,
|
||||
strict: bool | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False,
|
||||
by_alias: bool | None = None,
|
||||
by_name: bool | None = None,
|
||||
) -> T:
|
||||
def validate_strings(self, __obj: Any, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T:
|
||||
"""Validate object contains string data against the model.
|
||||
|
||||
Args:
|
||||
obj: The object contains string data to validate.
|
||||
__obj: The object contains string data to validate.
|
||||
strict: Whether to strictly check types.
|
||||
context: Additional context to use during validation.
|
||||
experimental_allow_partial: **Experimental** whether to enable
|
||||
[partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams.
|
||||
* False / 'off': Default behavior, no partial validation.
|
||||
* True / 'on': Enable partial validation.
|
||||
* 'trailing-strings': Enable partial validation and allow trailing strings in the input.
|
||||
by_alias: Whether to use the field's alias when validating against the provided input data.
|
||||
by_name: Whether to use the field's name when validating against the provided input data.
|
||||
|
||||
Returns:
|
||||
The validated object.
|
||||
"""
|
||||
if by_alias is False and by_name is not True:
|
||||
raise PydanticUserError(
|
||||
'At least one of `by_alias` or `by_name` must be set to True.',
|
||||
code='validate-by-alias-and-name-false',
|
||||
)
|
||||
|
||||
return self.validator.validate_strings(
|
||||
obj,
|
||||
strict=strict,
|
||||
context=context,
|
||||
allow_partial=experimental_allow_partial,
|
||||
by_alias=by_alias,
|
||||
by_name=by_name,
|
||||
)
|
||||
return self.validator.validate_strings(__obj, strict=strict, context=context)
|
||||
|
||||
def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None:
|
||||
"""Get the default value for the wrapped type.
|
||||
@@ -531,26 +324,22 @@ class TypeAdapter(Generic[T]):
|
||||
|
||||
def dump_python(
|
||||
self,
|
||||
instance: T,
|
||||
/,
|
||||
__instance: T,
|
||||
*,
|
||||
mode: Literal['json', 'python'] = 'python',
|
||||
include: IncEx | None = None,
|
||||
exclude: IncEx | None = None,
|
||||
by_alias: bool | None = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal['none', 'warn', 'error'] = True,
|
||||
fallback: Callable[[Any], Any] | None = None,
|
||||
serialize_as_any: bool = False,
|
||||
context: dict[str, Any] | None = None,
|
||||
warnings: bool = True,
|
||||
) -> Any:
|
||||
"""Dump an instance of the adapted type to a Python object.
|
||||
|
||||
Args:
|
||||
instance: The Python object to serialize.
|
||||
__instance: The Python object to serialize.
|
||||
mode: The output format.
|
||||
include: Fields to include in the output.
|
||||
exclude: Fields to exclude from the output.
|
||||
@@ -559,18 +348,13 @@ class TypeAdapter(Generic[T]):
|
||||
exclude_defaults: Whether to exclude fields with default values.
|
||||
exclude_none: Whether to exclude fields with None values.
|
||||
round_trip: Whether to output the serialized data in a way that is compatible with deserialization.
|
||||
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
|
||||
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
|
||||
fallback: A function to call when an unknown value is encountered. If not provided,
|
||||
a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
|
||||
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
|
||||
context: Additional context to pass to the serializer.
|
||||
warnings: Whether to display serialization warnings.
|
||||
|
||||
Returns:
|
||||
The serialized object.
|
||||
"""
|
||||
return self.serializer.to_python(
|
||||
instance,
|
||||
__instance,
|
||||
mode=mode,
|
||||
by_alias=by_alias,
|
||||
include=include,
|
||||
@@ -580,36 +364,26 @@ class TypeAdapter(Generic[T]):
|
||||
exclude_none=exclude_none,
|
||||
round_trip=round_trip,
|
||||
warnings=warnings,
|
||||
fallback=fallback,
|
||||
serialize_as_any=serialize_as_any,
|
||||
context=context,
|
||||
)
|
||||
|
||||
def dump_json(
|
||||
self,
|
||||
instance: T,
|
||||
/,
|
||||
__instance: T,
|
||||
*,
|
||||
indent: int | None = None,
|
||||
include: IncEx | None = None,
|
||||
exclude: IncEx | None = None,
|
||||
by_alias: bool | None = None,
|
||||
by_alias: bool = False,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
round_trip: bool = False,
|
||||
warnings: bool | Literal['none', 'warn', 'error'] = True,
|
||||
fallback: Callable[[Any], Any] | None = None,
|
||||
serialize_as_any: bool = False,
|
||||
context: dict[str, Any] | None = None,
|
||||
warnings: bool = True,
|
||||
) -> bytes:
|
||||
"""!!! abstract "Usage Documentation"
|
||||
[JSON Serialization](../concepts/json.md#json-serialization)
|
||||
|
||||
Serialize an instance of the adapted type to JSON.
|
||||
"""Serialize an instance of the adapted type to JSON.
|
||||
|
||||
Args:
|
||||
instance: The instance to be serialized.
|
||||
__instance: The instance to be serialized.
|
||||
indent: Number of spaces for JSON indentation.
|
||||
include: Fields to include.
|
||||
exclude: Fields to exclude.
|
||||
@@ -618,18 +392,13 @@ class TypeAdapter(Generic[T]):
|
||||
exclude_defaults: Whether to exclude fields with default values.
|
||||
exclude_none: Whether to exclude fields with a value of `None`.
|
||||
round_trip: Whether to serialize and deserialize the instance to ensure round-tripping.
|
||||
warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors,
|
||||
"error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError].
|
||||
fallback: A function to call when an unknown value is encountered. If not provided,
|
||||
a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised.
|
||||
serialize_as_any: Whether to serialize fields with duck-typing serialization behavior.
|
||||
context: Additional context to pass to the serializer.
|
||||
warnings: Whether to emit serialization warnings.
|
||||
|
||||
Returns:
|
||||
The JSON representation of the given instance as bytes.
|
||||
"""
|
||||
return self.serializer.to_json(
|
||||
instance,
|
||||
__instance,
|
||||
indent=indent,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
@@ -639,9 +408,6 @@ class TypeAdapter(Generic[T]):
|
||||
exclude_none=exclude_none,
|
||||
round_trip=round_trip,
|
||||
warnings=warnings,
|
||||
fallback=fallback,
|
||||
serialize_as_any=serialize_as_any,
|
||||
context=context,
|
||||
)
|
||||
|
||||
def json_schema(
|
||||
@@ -664,15 +430,11 @@ class TypeAdapter(Generic[T]):
|
||||
The JSON schema for the model as a dictionary.
|
||||
"""
|
||||
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
|
||||
if isinstance(self.core_schema, _mock_val_ser.MockCoreSchema):
|
||||
self.core_schema.rebuild()
|
||||
assert not isinstance(self.core_schema, _mock_val_ser.MockCoreSchema), 'this is a bug! please report it'
|
||||
return schema_generator_instance.generate(self.core_schema, mode=mode)
|
||||
|
||||
@staticmethod
|
||||
def json_schemas(
|
||||
inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
|
||||
/,
|
||||
__inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
|
||||
*,
|
||||
by_alias: bool = True,
|
||||
title: str | None = None,
|
||||
@@ -683,7 +445,7 @@ class TypeAdapter(Generic[T]):
|
||||
"""Generate a JSON schema including definitions from multiple type adapters.
|
||||
|
||||
Args:
|
||||
inputs: Inputs to schema generation. The first two items will form the keys of the (first)
|
||||
__inputs: Inputs to schema generation. The first two items will form the keys of the (first)
|
||||
output mapping; the type adapters will provide the core schemas that get converted into
|
||||
definitions in the output JSON schema.
|
||||
by_alias: Whether to use alias names.
|
||||
@@ -704,17 +466,9 @@ class TypeAdapter(Generic[T]):
|
||||
"""
|
||||
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
|
||||
|
||||
inputs_ = []
|
||||
for key, mode, adapter in inputs:
|
||||
# This is the same pattern we follow for model json schemas - we attempt a core schema rebuild if we detect a mock
|
||||
if isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema):
|
||||
adapter.core_schema.rebuild()
|
||||
assert not isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema), (
|
||||
'this is a bug! please report it'
|
||||
)
|
||||
inputs_.append((key, mode, adapter.core_schema))
|
||||
inputs = [(key, mode, adapter.core_schema) for key, mode, adapter in __inputs]
|
||||
|
||||
json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs_)
|
||||
json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs)
|
||||
|
||||
json_schema: dict[str, Any] = {}
|
||||
if definitions:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,4 @@
|
||||
"""`typing` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""The `utils` module is a backport module from V1."""
|
||||
|
||||
from ._migration import getattr_migration
|
||||
|
||||
__getattr__ = getattr_migration(__name__)
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
# flake8: noqa
|
||||
from pydantic.v1 import dataclasses
|
||||
from pydantic.v1.annotated_types import create_model_from_namedtuple, create_model_from_typeddict
|
||||
from pydantic.v1.class_validators import root_validator, validator
|
||||
from pydantic.v1.config import BaseConfig, ConfigDict, Extra
|
||||
from pydantic.v1.decorator import validate_arguments
|
||||
from pydantic.v1.env_settings import BaseSettings
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from pydantic.v1.errors import *
|
||||
from pydantic.v1.fields import Field, PrivateAttr, Required
|
||||
from pydantic.v1.main import *
|
||||
from pydantic.v1.networks import *
|
||||
from pydantic.v1.parse import Protocol
|
||||
from pydantic.v1.tools import *
|
||||
from pydantic.v1.types import *
|
||||
from pydantic.v1.version import VERSION, compiled
|
||||
from . import dataclasses
|
||||
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
|
||||
from .class_validators import root_validator, validator
|
||||
from .config import BaseConfig, ConfigDict, Extra
|
||||
from .decorator import validate_arguments
|
||||
from .env_settings import BaseSettings
|
||||
from .error_wrappers import ValidationError
|
||||
from .errors import *
|
||||
from .fields import Field, PrivateAttr, Required
|
||||
from .main import *
|
||||
from .networks import *
|
||||
from .parse import Protocol
|
||||
from .tools import *
|
||||
from .types import *
|
||||
from .version import VERSION, compiled
|
||||
|
||||
__version__ = VERSION
|
||||
|
||||
# WARNING __all__ from pydantic.errors is not included here, it will be removed as an export here in v2
|
||||
# please use "from pydantic.v1.errors import ..." instead
|
||||
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
|
||||
# please use "from pydantic.errors import ..." instead
|
||||
__all__ = [
|
||||
# annotated types utils
|
||||
'create_model_from_namedtuple',
|
||||
|
||||
@@ -35,7 +35,7 @@ import hypothesis.strategies as st
|
||||
import pydantic
|
||||
import pydantic.color
|
||||
import pydantic.types
|
||||
from pydantic.v1.utils import lenient_issubclass
|
||||
from pydantic.utils import lenient_issubclass
|
||||
|
||||
# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
|
||||
# them on-disk, and that's unsafe in general without being told *where* to do so.
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type
|
||||
|
||||
from pydantic.v1.fields import Required
|
||||
from pydantic.v1.main import BaseModel, create_model
|
||||
from pydantic.v1.typing import is_typeddict, is_typeddict_special
|
||||
from .fields import Required
|
||||
from .main import BaseModel, create_model
|
||||
from .typing import is_typeddict, is_typeddict_special
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -5,12 +5,12 @@ from itertools import chain
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
|
||||
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
from pydantic.v1.utils import ROOT_KEY, in_ipython
|
||||
from .errors import ConfigError
|
||||
from .typing import AnyCallable
|
||||
from .utils import ROOT_KEY, in_ipython
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import AnyClassMethod
|
||||
from .typing import AnyClassMethod
|
||||
|
||||
|
||||
class Validator:
|
||||
@@ -36,9 +36,9 @@ class Validator:
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
from .config import BaseConfig
|
||||
from .fields import ModelField
|
||||
from .types import ModelOrDc
|
||||
|
||||
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
|
||||
ValidatorsList = List[ValidatorCallable]
|
||||
|
||||
@@ -12,11 +12,11 @@ import re
|
||||
from colorsys import hls_to_rgb, rgb_to_hls
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from pydantic.v1.errors import ColorError
|
||||
from pydantic.v1.utils import Representation, almost_equal_floats
|
||||
from .errors import ColorError
|
||||
from .utils import Representation, almost_equal_floats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import CallableGenerator, ReprArgs
|
||||
from .typing import CallableGenerator, ReprArgs
|
||||
|
||||
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
|
||||
ColorType = Union[ColorTuple, str]
|
||||
|
||||
@@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tup
|
||||
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
from pydantic.v1.typing import AnyArgTCallable, AnyCallable
|
||||
from pydantic.v1.utils import GetterDict
|
||||
from pydantic.v1.version import compiled
|
||||
from .typing import AnyArgTCallable, AnyCallable
|
||||
from .utils import GetterDict
|
||||
from .version import compiled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import overload
|
||||
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .fields import ModelField
|
||||
from .main import BaseModel
|
||||
|
||||
ConfigType = Type['BaseConfig']
|
||||
|
||||
|
||||
@@ -36,28 +36,21 @@ import dataclasses
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
try:
|
||||
from functools import cached_property
|
||||
except ImportError:
|
||||
# cached_property available only for python3.8+
|
||||
pass
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
from pydantic.v1.class_validators import gather_all_validators
|
||||
from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config
|
||||
from pydantic.v1.error_wrappers import ValidationError
|
||||
from pydantic.v1.errors import DataclassTypeError
|
||||
from pydantic.v1.fields import Field, FieldInfo, Required, Undefined
|
||||
from pydantic.v1.main import create_model, validate_model
|
||||
from pydantic.v1.utils import ClassAttribute
|
||||
from .class_validators import gather_all_validators
|
||||
from .config import BaseConfig, ConfigDict, Extra, get_config
|
||||
from .error_wrappers import ValidationError
|
||||
from .errors import DataclassTypeError
|
||||
from .fields import Field, FieldInfo, Required, Undefined
|
||||
from .main import create_model, validate_model
|
||||
from .utils import ClassAttribute
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable
|
||||
from .main import BaseModel
|
||||
from .typing import CallableGenerator, NoArgAnyCallable
|
||||
|
||||
DataclassT = TypeVar('DataclassT', bound='Dataclass')
|
||||
|
||||
@@ -416,17 +409,6 @@ def create_pydantic_model_from_dataclass(
|
||||
return model
|
||||
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
|
||||
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
|
||||
return isinstance(getattr(type(obj), k, None), cached_property)
|
||||
|
||||
else:
|
||||
|
||||
def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _dataclass_validate_values(self: 'Dataclass') -> None:
|
||||
# validation errors can occur if this function is called twice on an already initialised dataclass.
|
||||
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
|
||||
@@ -435,13 +417,9 @@ def _dataclass_validate_values(self: 'Dataclass') -> None:
|
||||
if getattr(self, '__pydantic_has_field_info_default__', False):
|
||||
# We need to remove `FieldInfo` values since they are not valid as input
|
||||
# It's ok to do that because they are obviously the default values!
|
||||
input_data = {
|
||||
k: v
|
||||
for k, v in self.__dict__.items()
|
||||
if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k))
|
||||
}
|
||||
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
|
||||
else:
|
||||
input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)}
|
||||
input_data = self.__dict__
|
||||
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
|
||||
if validation_error:
|
||||
raise validation_error
|
||||
|
||||
@@ -18,7 +18,7 @@ import re
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from pydantic.v1 import errors
|
||||
from . import errors
|
||||
|
||||
date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
|
||||
time_expr = (
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from pydantic.v1 import validator
|
||||
from pydantic.v1.config import Extra
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.main import BaseModel, create_model
|
||||
from pydantic.v1.typing import get_all_type_hints
|
||||
from pydantic.v1.utils import to_camel
|
||||
from . import validator
|
||||
from .config import Extra
|
||||
from .errors import ConfigError
|
||||
from .main import BaseModel, create_model
|
||||
from .typing import get_all_type_hints
|
||||
from .utils import to_camel
|
||||
|
||||
__all__ = ('validate_arguments',)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
from .typing import AnyCallable
|
||||
|
||||
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
|
||||
ConfigType = Union[None, Type[Any], Dict[str, Any]]
|
||||
|
||||
@@ -3,12 +3,12 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic.v1.config import BaseConfig, Extra
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.types import JsonWrapper
|
||||
from pydantic.v1.typing import StrPath, display_as_type, get_origin, is_union
|
||||
from pydantic.v1.utils import deep_update, lenient_issubclass, path_type, sequence_like
|
||||
from .config import BaseConfig, Extra
|
||||
from .fields import ModelField
|
||||
from .main import BaseModel
|
||||
from .types import JsonWrapper
|
||||
from .typing import StrPath, display_as_type, get_origin, is_union
|
||||
from .utils import deep_update, lenient_issubclass, path_type, sequence_like
|
||||
|
||||
env_file_sentinel = str(object())
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic.v1.json import pydantic_encoder
|
||||
from pydantic.v1.utils import Representation
|
||||
from .json import pydantic_encoder
|
||||
from .utils import Representation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
from pydantic.v1.typing import ReprArgs
|
||||
from .config import BaseConfig
|
||||
from .types import ModelOrDc
|
||||
from .typing import ReprArgs
|
||||
|
||||
Loc = Tuple[Union[int, str], ...]
|
||||
|
||||
@@ -101,6 +101,7 @@ def flatten_errors(
|
||||
) -> Generator['ErrorDict', None, None]:
|
||||
for error in errors:
|
||||
if isinstance(error, ErrorWrapper):
|
||||
|
||||
if loc:
|
||||
error_loc = loc + error.loc_tuple()
|
||||
else:
|
||||
|
||||
@@ -2,12 +2,12 @@ from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
from pydantic.v1.typing import display_as_type
|
||||
from .typing import display_as_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import DictStrAny
|
||||
from .typing import DictStrAny
|
||||
|
||||
# explicitly state exports to avoid "from pydantic.v1.errors import *" also importing Decimal, Path etc.
|
||||
# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc.
|
||||
__all__ = (
|
||||
'PydanticTypeError',
|
||||
'PydanticValueError',
|
||||
|
||||
@@ -28,12 +28,12 @@ from typing import (
|
||||
|
||||
from typing_extensions import Annotated, Final
|
||||
|
||||
from pydantic.v1 import errors as errors_
|
||||
from pydantic.v1.class_validators import Validator, make_generic_validator, prep_validators
|
||||
from pydantic.v1.error_wrappers import ErrorWrapper
|
||||
from pydantic.v1.errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError
|
||||
from pydantic.v1.types import Json, JsonWrapper
|
||||
from pydantic.v1.typing import (
|
||||
from . import errors as errors_
|
||||
from .class_validators import Validator, make_generic_validator, prep_validators
|
||||
from .error_wrappers import ErrorWrapper
|
||||
from .errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError
|
||||
from .types import Json, JsonWrapper
|
||||
from .typing import (
|
||||
NoArgAnyCallable,
|
||||
convert_generics,
|
||||
display_as_type,
|
||||
@@ -48,7 +48,7 @@ from pydantic.v1.typing import (
|
||||
is_union,
|
||||
new_type_supertype,
|
||||
)
|
||||
from pydantic.v1.utils import (
|
||||
from .utils import (
|
||||
PyObjectStr,
|
||||
Representation,
|
||||
ValueItems,
|
||||
@@ -59,7 +59,7 @@ from pydantic.v1.utils import (
|
||||
sequence_like,
|
||||
smart_deepcopy,
|
||||
)
|
||||
from pydantic.v1.validators import constant_validator, dict_validator, find_validators, validate_json
|
||||
from .validators import constant_validator, dict_validator, find_validators, validate_json
|
||||
|
||||
Required: Any = Ellipsis
|
||||
|
||||
@@ -83,11 +83,11 @@ class UndefinedType:
|
||||
Undefined = UndefinedType()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.class_validators import ValidatorsList
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.error_wrappers import ErrorList
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
from pydantic.v1.typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs
|
||||
from .class_validators import ValidatorsList
|
||||
from .config import BaseConfig
|
||||
from .error_wrappers import ErrorList
|
||||
from .types import ModelOrDc
|
||||
from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs
|
||||
|
||||
ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]]
|
||||
LocStr = Union[Tuple[Union[int, str], ...], str]
|
||||
@@ -178,6 +178,7 @@ class FieldInfo(Representation):
|
||||
self.extra = kwargs
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
|
||||
field_defaults_to_hide: Dict[str, Any] = {
|
||||
'repr': True,
|
||||
**self.__field_constraints__,
|
||||
@@ -404,6 +405,7 @@ class ModelField(Representation):
|
||||
alias: Optional[str] = None,
|
||||
field_info: Optional[FieldInfo] = None,
|
||||
) -> None:
|
||||
|
||||
self.name: str = name
|
||||
self.has_alias: bool = alias is not None
|
||||
self.alias: str = alias if alias is not None else name
|
||||
@@ -490,7 +492,7 @@ class ModelField(Representation):
|
||||
class_validators: Optional[Dict[str, Validator]],
|
||||
config: Type['BaseConfig'],
|
||||
) -> 'ModelField':
|
||||
from pydantic.v1.schema import get_annotation_from_field_info
|
||||
from .schema import get_annotation_from_field_info
|
||||
|
||||
field_info, value = cls._get_field_info(name, annotation, value, config)
|
||||
required: 'BoolUndefined' = Undefined
|
||||
@@ -850,6 +852,7 @@ class ModelField(Representation):
|
||||
def validate(
|
||||
self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None
|
||||
) -> 'ValidateReturn':
|
||||
|
||||
assert self.type_.__class__ is not DeferredType
|
||||
|
||||
if self.type_.__class__ is ForwardRef:
|
||||
@@ -1160,7 +1163,7 @@ class ModelField(Representation):
|
||||
"""
|
||||
Whether the field is "complex" eg. env variables should be parsed as JSON.
|
||||
"""
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .main import BaseModel
|
||||
|
||||
return (
|
||||
self.shape != SHAPE_SINGLETON
|
||||
|
||||
@@ -22,12 +22,12 @@ from weakref import WeakKeyDictionary, WeakValueDictionary
|
||||
|
||||
from typing_extensions import Annotated, Literal as ExtLiteral
|
||||
|
||||
from pydantic.v1.class_validators import gather_all_validators
|
||||
from pydantic.v1.fields import DeferredType
|
||||
from pydantic.v1.main import BaseModel, create_model
|
||||
from pydantic.v1.types import JsonWrapper
|
||||
from pydantic.v1.typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
|
||||
from pydantic.v1.utils import all_identical, lenient_issubclass
|
||||
from .class_validators import gather_all_validators
|
||||
from .fields import DeferredType
|
||||
from .main import BaseModel, create_model
|
||||
from .types import JsonWrapper
|
||||
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
|
||||
from .utils import all_identical, lenient_issubclass
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import _UnionGenericAlias
|
||||
|
||||
@@ -9,9 +9,9 @@ from types import GeneratorType
|
||||
from typing import Any, Callable, Dict, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.v1.color import Color
|
||||
from pydantic.v1.networks import NameEmail
|
||||
from pydantic.v1.types import SecretBytes, SecretStr
|
||||
from .color import Color
|
||||
from .networks import NameEmail
|
||||
from .types import SecretBytes, SecretStr
|
||||
|
||||
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
|
||||
|
||||
@@ -72,7 +72,7 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
def pydantic_encoder(obj: Any) -> Any:
|
||||
from dataclasses import asdict, is_dataclass
|
||||
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .main import BaseModel
|
||||
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.dict()
|
||||
|
||||
@@ -26,11 +26,11 @@ from typing import (
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
from pydantic.v1.class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
|
||||
from pydantic.v1.config import BaseConfig, Extra, inherit_config, prepare_config
|
||||
from pydantic.v1.error_wrappers import ErrorWrapper, ValidationError
|
||||
from pydantic.v1.errors import ConfigError, DictError, ExtraError, MissingError
|
||||
from pydantic.v1.fields import (
|
||||
from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators
|
||||
from .config import BaseConfig, Extra, inherit_config, prepare_config
|
||||
from .error_wrappers import ErrorWrapper, ValidationError
|
||||
from .errors import ConfigError, DictError, ExtraError, MissingError
|
||||
from .fields import (
|
||||
MAPPING_LIKE_SHAPES,
|
||||
Field,
|
||||
ModelField,
|
||||
@@ -39,11 +39,11 @@ from pydantic.v1.fields import (
|
||||
Undefined,
|
||||
is_finalvar_with_default_val,
|
||||
)
|
||||
from pydantic.v1.json import custom_pydantic_encoder, pydantic_encoder
|
||||
from pydantic.v1.parse import Protocol, load_file, load_str_bytes
|
||||
from pydantic.v1.schema import default_ref_template, model_schema
|
||||
from pydantic.v1.types import PyObject, StrBytes
|
||||
from pydantic.v1.typing import (
|
||||
from .json import custom_pydantic_encoder, pydantic_encoder
|
||||
from .parse import Protocol, load_file, load_str_bytes
|
||||
from .schema import default_ref_template, model_schema
|
||||
from .types import PyObject, StrBytes
|
||||
from .typing import (
|
||||
AnyCallable,
|
||||
get_args,
|
||||
get_origin,
|
||||
@@ -53,7 +53,7 @@ from pydantic.v1.typing import (
|
||||
resolve_annotations,
|
||||
update_model_forward_refs,
|
||||
)
|
||||
from pydantic.v1.utils import (
|
||||
from .utils import (
|
||||
DUNDER_ATTRIBUTES,
|
||||
ROOT_KEY,
|
||||
ClassAttribute,
|
||||
@@ -73,9 +73,9 @@ from pydantic.v1.utils import (
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
|
||||
from pydantic.v1.class_validators import ValidatorListDict
|
||||
from pydantic.v1.types import ModelOrDc
|
||||
from pydantic.v1.typing import (
|
||||
from .class_validators import ValidatorListDict
|
||||
from .types import ModelOrDc
|
||||
from .typing import (
|
||||
AbstractSetIntStr,
|
||||
AnyClassMethod,
|
||||
CallableGenerator,
|
||||
@@ -282,12 +282,6 @@ class ModelMetaclass(ABCMeta):
|
||||
cls = super().__new__(mcs, name, bases, new_namespace, **kwargs)
|
||||
# set __signature__ attr only for model class, but not for its instances
|
||||
cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config))
|
||||
|
||||
if not _is_base_model_class_defined:
|
||||
# Cython does not understand the `if TYPE_CHECKING:` condition in the
|
||||
# BaseModel's body (where annotations are set), so clear them manually:
|
||||
getattr(cls, '__annotations__', {}).clear()
|
||||
|
||||
if resolve_forward_refs:
|
||||
cls.__try_update_forward_refs__()
|
||||
|
||||
@@ -307,7 +301,7 @@ class ModelMetaclass(ABCMeta):
|
||||
|
||||
See #3829 and python/cpython#92810
|
||||
"""
|
||||
return hasattr(instance, '__post_root_validators__') and super().__instancecheck__(instance)
|
||||
return hasattr(instance, '__fields__') and super().__instancecheck__(instance)
|
||||
|
||||
|
||||
object_setattr = object.__setattr__
|
||||
@@ -675,7 +669,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
|
||||
def schema_json(
|
||||
cls, *, by_alias: bool = True, ref_template: str = default_ref_template, **dumps_kwargs: Any
|
||||
) -> str:
|
||||
from pydantic.v1.json import pydantic_encoder
|
||||
from .json import pydantic_encoder
|
||||
|
||||
return cls.__config__.json_dumps(
|
||||
cls.schema(by_alias=by_alias, ref_template=ref_template), default=pydantic_encoder, **dumps_kwargs
|
||||
@@ -743,6 +737,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
|
||||
exclude_defaults: bool,
|
||||
exclude_none: bool,
|
||||
) -> Any:
|
||||
|
||||
if isinstance(v, BaseModel):
|
||||
if to_dict:
|
||||
v_dict = v.dict(
|
||||
@@ -835,6 +830,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass):
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> 'TupleGenerator':
|
||||
|
||||
# Merge field set excludes with explicit exclude parameter with explicit overriding field set options.
|
||||
# The extra "is not None" guards are not logically necessary but optimizes performance for the simple case.
|
||||
if exclude is not None or self.__exclude_fields__ is not None:
|
||||
|
||||
@@ -57,7 +57,6 @@ from mypy.types import (
|
||||
Type,
|
||||
TypeOfAny,
|
||||
TypeType,
|
||||
TypeVarId,
|
||||
TypeVarType,
|
||||
UnionType,
|
||||
get_proper_type,
|
||||
@@ -66,7 +65,7 @@ from mypy.typevars import fill_typevars
|
||||
from mypy.util import get_unique_redefinition_name
|
||||
from mypy.version import __version__ as mypy_version
|
||||
|
||||
from pydantic.v1.utils import is_valid_field
|
||||
from pydantic.utils import is_valid_field
|
||||
|
||||
try:
|
||||
from mypy.types import TypeVarDef # type: ignore[attr-defined]
|
||||
@@ -499,11 +498,7 @@ class PydanticModelTransformer:
|
||||
tvd = TypeVarType(
|
||||
self_tvar_name,
|
||||
tvar_fullname,
|
||||
(
|
||||
TypeVarId(-1, namespace=ctx.cls.fullname + '.construct')
|
||||
if MYPY_VERSION_TUPLE >= (1, 11)
|
||||
else TypeVarId(-1)
|
||||
),
|
||||
-1,
|
||||
[],
|
||||
obj_type,
|
||||
AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type]
|
||||
@@ -863,9 +858,9 @@ def add_method(
|
||||
arg_kinds.append(arg.kind)
|
||||
|
||||
function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
|
||||
signature = CallableType(
|
||||
arg_types, arg_kinds, arg_names, return_type, function_type, variables=[tvar_def] if tvar_def else None
|
||||
)
|
||||
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
|
||||
if tvar_def:
|
||||
signature.variables = [tvar_def]
|
||||
|
||||
func = FuncDef(name, args, Block([PassStmt()]))
|
||||
func.info = info
|
||||
|
||||
@@ -27,17 +27,17 @@ from typing import (
|
||||
no_type_check,
|
||||
)
|
||||
|
||||
from pydantic.v1 import errors
|
||||
from pydantic.v1.utils import Representation, update_not_none
|
||||
from pydantic.v1.validators import constr_length_validator, str_validator
|
||||
from . import errors
|
||||
from .utils import Representation, update_not_none
|
||||
from .validators import constr_length_validator, str_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import email_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.typing import AnyCallable
|
||||
from .config import BaseConfig
|
||||
from .fields import ModelField
|
||||
from .typing import AnyCallable
|
||||
|
||||
CallableGenerator = Generator[AnyCallable, None, None]
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic.v1.types import StrBytes
|
||||
from .types import StrBytes
|
||||
|
||||
|
||||
class Protocol(str, Enum):
|
||||
|
||||
@@ -31,7 +31,7 @@ from uuid import UUID
|
||||
|
||||
from typing_extensions import Annotated, Literal
|
||||
|
||||
from pydantic.v1.fields import (
|
||||
from .fields import (
|
||||
MAPPING_LIKE_SHAPES,
|
||||
SHAPE_DEQUE,
|
||||
SHAPE_FROZENSET,
|
||||
@@ -46,9 +46,9 @@ from pydantic.v1.fields import (
|
||||
FieldInfo,
|
||||
ModelField,
|
||||
)
|
||||
from pydantic.v1.json import pydantic_encoder
|
||||
from pydantic.v1.networks import AnyUrl, EmailStr
|
||||
from pydantic.v1.types import (
|
||||
from .json import pydantic_encoder
|
||||
from .networks import AnyUrl, EmailStr
|
||||
from .types import (
|
||||
ConstrainedDecimal,
|
||||
ConstrainedFloat,
|
||||
ConstrainedFrozenSet,
|
||||
@@ -69,7 +69,7 @@ from pydantic.v1.types import (
|
||||
conset,
|
||||
constr,
|
||||
)
|
||||
from pydantic.v1.typing import (
|
||||
from .typing import (
|
||||
all_literal_values,
|
||||
get_args,
|
||||
get_origin,
|
||||
@@ -80,11 +80,11 @@ from pydantic.v1.typing import (
|
||||
is_none_type,
|
||||
is_union,
|
||||
)
|
||||
from pydantic.v1.utils import ROOT_KEY, get_model, lenient_issubclass
|
||||
from .utils import ROOT_KEY, get_model, lenient_issubclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.dataclasses import Dataclass
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .dataclasses import Dataclass
|
||||
from .main import BaseModel
|
||||
|
||||
default_prefix = '#/definitions/'
|
||||
default_ref_template = '#/definitions/{model}'
|
||||
@@ -198,6 +198,7 @@ def model_schema(
|
||||
|
||||
|
||||
def get_field_info_schema(field: ModelField, schema_overrides: bool = False) -> Tuple[Dict[str, Any], bool]:
|
||||
|
||||
# If no title is explicitly set, we don't set title in the schema for enums.
|
||||
# The behaviour is the same as `BaseModel` reference, where the default title
|
||||
# is in the definitions part of the schema.
|
||||
@@ -378,7 +379,7 @@ def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) ->
|
||||
:param known_models: used to solve circular references
|
||||
:return: a set with the model used in the declaration for this field, if any, and all its sub-models
|
||||
"""
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .main import BaseModel
|
||||
|
||||
flat_models: TypeModelSet = set()
|
||||
|
||||
@@ -445,7 +446,7 @@ def field_type_schema(
|
||||
Take a single ``field`` and generate the schema for its type only, not including additional
|
||||
information as title, etc. Also return additional schema definitions, from sub-models.
|
||||
"""
|
||||
from pydantic.v1.main import BaseModel # noqa: F811
|
||||
from .main import BaseModel # noqa: F811
|
||||
|
||||
definitions = {}
|
||||
nested_models: Set[str] = set()
|
||||
@@ -738,7 +739,7 @@ def field_singleton_sub_fields_schema(
|
||||
discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref']
|
||||
|
||||
s['discriminator'] = {
|
||||
'propertyName': field.discriminator_alias if by_alias else field.discriminator_key,
|
||||
'propertyName': field.discriminator_alias,
|
||||
'mapping': discriminator_models_refs,
|
||||
}
|
||||
|
||||
@@ -838,7 +839,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity)
|
||||
|
||||
Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models.
|
||||
"""
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .main import BaseModel
|
||||
|
||||
definitions: Dict[str, Any] = {}
|
||||
nested_models: Set[str] = set()
|
||||
@@ -974,7 +975,7 @@ def multitypes_literal_field_for_schema(values: Tuple[Any, ...], field: ModelFie
|
||||
|
||||
|
||||
def encode_default(dft: Any) -> Any:
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .main import BaseModel
|
||||
|
||||
if isinstance(dft, BaseModel) or is_dataclass(dft):
|
||||
dft = cast('dict[str, Any]', pydantic_encoder(dft))
|
||||
@@ -1090,7 +1091,7 @@ def get_annotation_with_constraints(annotation: Any, field_info: FieldInfo) -> T
|
||||
if issubclass(type_, (SecretStr, SecretBytes)):
|
||||
attrs = ('max_length', 'min_length')
|
||||
|
||||
def constraint_func(**kw: Any) -> Type[Any]: # noqa: F811
|
||||
def constraint_func(**kw: Any) -> Type[Any]:
|
||||
return type(type_.__name__, (type_,), kw)
|
||||
|
||||
elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl)):
|
||||
|
||||
@@ -3,16 +3,16 @@ from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic.v1.parse import Protocol, load_file, load_str_bytes
|
||||
from pydantic.v1.types import StrBytes
|
||||
from pydantic.v1.typing import display_as_type
|
||||
from .parse import Protocol, load_file, load_str_bytes
|
||||
from .types import StrBytes
|
||||
from .typing import display_as_type
|
||||
|
||||
__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')
|
||||
|
||||
NameFactory = Union[str, Callable[[Type[Any]], str]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.typing import DictStrAny
|
||||
from .typing import DictStrAny
|
||||
|
||||
|
||||
def _generate_parsing_type_name(type_: Any) -> str:
|
||||
@@ -21,7 +21,7 @@ def _generate_parsing_type_name(type_: Any) -> str:
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
|
||||
from pydantic.v1.main import create_model
|
||||
from .main import create_model
|
||||
|
||||
if type_name is None:
|
||||
type_name = _generate_parsing_type_name
|
||||
|
||||
@@ -28,10 +28,10 @@ from typing import (
|
||||
from uuid import UUID
|
||||
from weakref import WeakSet
|
||||
|
||||
from pydantic.v1 import errors
|
||||
from pydantic.v1.datetime_parse import parse_date
|
||||
from pydantic.v1.utils import import_string, update_not_none
|
||||
from pydantic.v1.validators import (
|
||||
from . import errors
|
||||
from .datetime_parse import parse_date
|
||||
from .utils import import_string, update_not_none
|
||||
from .validators import (
|
||||
bytes_validator,
|
||||
constr_length_validator,
|
||||
constr_lower,
|
||||
@@ -123,9 +123,9 @@ StrIntFloat = Union[str, int, float]
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic.v1.dataclasses import Dataclass
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import CallableGenerator
|
||||
from .dataclasses import Dataclass
|
||||
from .main import BaseModel
|
||||
from .typing import CallableGenerator
|
||||
|
||||
ModelOrDc = Type[Union[BaseModel, Dataclass]]
|
||||
|
||||
@@ -481,7 +481,6 @@ else:
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
# This types superclass should be Set[T], but cython chokes on that...
|
||||
class ConstrainedSet(set): # type: ignore
|
||||
# Needed for pydantic to detect that this is a set
|
||||
@@ -570,7 +569,6 @@ def confrozenset(
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LIST TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
# This types superclass should be List[T], but cython chokes on that...
|
||||
class ConstrainedList(list): # type: ignore
|
||||
# Needed for pydantic to detect that this is a list
|
||||
@@ -1096,6 +1094,7 @@ class ByteSize(int):
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v: StrIntFloat) -> 'ByteSize':
|
||||
|
||||
try:
|
||||
return cls(int(v))
|
||||
except ValueError:
|
||||
@@ -1117,6 +1116,7 @@ class ByteSize(int):
|
||||
return cls(int(float(scalar) * unit_mult))
|
||||
|
||||
def human_readable(self, decimal: bool = False) -> str:
|
||||
|
||||
if decimal:
|
||||
divisor = 1000
|
||||
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
||||
@@ -1135,6 +1135,7 @@ class ByteSize(int):
|
||||
return f'{num:0.1f}{final_unit}'
|
||||
|
||||
def to(self, unit: str) -> float:
|
||||
|
||||
try:
|
||||
unit_div = BYTE_SIZES[unit.lower()]
|
||||
except KeyError:
|
||||
|
||||
@@ -58,21 +58,12 @@ if sys.version_info < (3, 9):
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return type_._evaluate(globalns, localns)
|
||||
|
||||
elif sys.version_info < (3, 12, 4):
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Even though it is the right signature for python 3.9, mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
|
||||
# Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid:
|
||||
# TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard'
|
||||
return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set())
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid
|
||||
# warnings:
|
||||
return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set())
|
||||
return cast(Any, type_)._evaluate(globalns, localns, set())
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
@@ -265,7 +256,7 @@ StrPath = Union[str, PathLike]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic.v1.fields import ModelField
|
||||
from .fields import ModelField
|
||||
|
||||
TupleGenerator = Generator[Tuple[str, Any], None, None]
|
||||
DictStrAny = Dict[str, Any]
|
||||
@@ -406,10 +397,7 @@ def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Opti
|
||||
else:
|
||||
value = ForwardRef(value, is_argument=False)
|
||||
try:
|
||||
if sys.version_info >= (3, 13):
|
||||
value = _eval_type(value, base_globals, None, type_params=())
|
||||
else:
|
||||
value = _eval_type(value, base_globals, None)
|
||||
value = _eval_type(value, base_globals, None)
|
||||
except NameError:
|
||||
# this is ok, it can be fixed with update_forward_refs
|
||||
pass
|
||||
@@ -447,7 +435,7 @@ def is_namedtuple(type_: Type[Any]) -> bool:
|
||||
Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`
|
||||
"""
|
||||
from pydantic.v1.utils import lenient_issubclass
|
||||
from .utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
|
||||
@@ -457,7 +445,7 @@ def is_typeddict(type_: Type[Any]) -> bool:
|
||||
Check if a given class is a typed dict (from `typing` or `typing_extensions`)
|
||||
In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
|
||||
"""
|
||||
from pydantic.v1.utils import lenient_issubclass
|
||||
from .utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ from typing import (
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from pydantic.v1.errors import ConfigError
|
||||
from pydantic.v1.typing import (
|
||||
from .errors import ConfigError
|
||||
from .typing import (
|
||||
NoneType,
|
||||
WithArgsTypes,
|
||||
all_literal_values,
|
||||
@@ -39,17 +39,17 @@ from pydantic.v1.typing import (
|
||||
is_literal_type,
|
||||
is_union,
|
||||
)
|
||||
from pydantic.v1.version import version_info
|
||||
from .version import version_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.dataclasses import Dataclass
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.main import BaseModel
|
||||
from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
|
||||
from .config import BaseConfig
|
||||
from .dataclasses import Dataclass
|
||||
from .fields import ModelField
|
||||
from .main import BaseModel
|
||||
from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
|
||||
|
||||
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
|
||||
|
||||
@@ -66,7 +66,6 @@ __all__ = (
|
||||
'almost_equal_floats',
|
||||
'get_model',
|
||||
'to_camel',
|
||||
'to_lower_camel',
|
||||
'is_valid_field',
|
||||
'smart_deepcopy',
|
||||
'PyObjectStr',
|
||||
@@ -159,7 +158,7 @@ def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
|
||||
def validate_field_name(bases: Iterable[Type[Any]], field_name: str) -> None:
|
||||
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
|
||||
"""
|
||||
Ensure that the field's name does not shadow an existing attribute of the model.
|
||||
"""
|
||||
@@ -241,7 +240,7 @@ def generate_model_signature(
|
||||
"""
|
||||
from inspect import Parameter, Signature, signature
|
||||
|
||||
from pydantic.v1.config import Extra
|
||||
from .config import Extra
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: Dict[str, Parameter] = {}
|
||||
@@ -299,7 +298,7 @@ def generate_model_signature(
|
||||
|
||||
|
||||
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
|
||||
from pydantic.v1.main import BaseModel
|
||||
from .main import BaseModel
|
||||
|
||||
try:
|
||||
model_cls = obj.__pydantic_model__ # type: ignore
|
||||
@@ -708,8 +707,6 @@ DUNDER_ATTRIBUTES = {
|
||||
'__orig_bases__',
|
||||
'__orig_class__',
|
||||
'__qualname__',
|
||||
'__firstlineno__',
|
||||
'__static_attributes__',
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -27,11 +27,10 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
from warnings import warn
|
||||
|
||||
from pydantic.v1 import errors
|
||||
from pydantic.v1.datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
|
||||
from pydantic.v1.typing import (
|
||||
from . import errors
|
||||
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
|
||||
from .typing import (
|
||||
AnyCallable,
|
||||
all_literal_values,
|
||||
display_as_type,
|
||||
@@ -42,14 +41,14 @@ from pydantic.v1.typing import (
|
||||
is_none_type,
|
||||
is_typeddict,
|
||||
)
|
||||
from pydantic.v1.utils import almost_equal_floats, lenient_issubclass, sequence_like
|
||||
from .utils import almost_equal_floats, lenient_issubclass, sequence_like
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Literal, TypedDict
|
||||
|
||||
from pydantic.v1.config import BaseConfig
|
||||
from pydantic.v1.fields import ModelField
|
||||
from pydantic.v1.types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
|
||||
from .config import BaseConfig
|
||||
from .fields import ModelField
|
||||
from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
|
||||
|
||||
ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt]
|
||||
AnyOrderedDict = OrderedDict[Any, Any]
|
||||
@@ -595,7 +594,7 @@ NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)
|
||||
def make_namedtuple_validator(
|
||||
namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig']
|
||||
) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
|
||||
from pydantic.v1.annotated_types import create_model_from_namedtuple
|
||||
from .annotated_types import create_model_from_namedtuple
|
||||
|
||||
NamedTupleModel = create_model_from_namedtuple(
|
||||
namedtuple_cls,
|
||||
@@ -620,7 +619,7 @@ def make_namedtuple_validator(
|
||||
def make_typeddict_validator(
|
||||
typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type]
|
||||
) -> Callable[[Any], Dict[str, Any]]:
|
||||
from pydantic.v1.annotated_types import create_model_from_typeddict
|
||||
from .annotated_types import create_model_from_typeddict
|
||||
|
||||
TypedDictModel = create_model_from_typeddict(
|
||||
typeddict_cls,
|
||||
@@ -699,7 +698,7 @@ _VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [
|
||||
def find_validators( # noqa: C901 (ignore complexity)
|
||||
type_: Type[Any], config: Type['BaseConfig']
|
||||
) -> Generator[AnyCallable, None, None]:
|
||||
from pydantic.v1.dataclasses import is_builtin_dataclass, make_dataclass_validator
|
||||
from .dataclasses import is_builtin_dataclass, make_dataclass_validator
|
||||
|
||||
if type_ is Any or type_ is object:
|
||||
return
|
||||
@@ -763,6 +762,4 @@ def find_validators( # noqa: C901 (ignore complexity)
|
||||
if config.arbitrary_types_allowed:
|
||||
yield make_arbitrary_type_validator(type_)
|
||||
else:
|
||||
if hasattr(type_, '__pydantic_core_schema__'):
|
||||
warn(f'Mixing V1 and V2 models is not supported. `{type_.__name__}` is a V2 model.', UserWarning)
|
||||
raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
__all__ = 'compiled', 'VERSION', 'version_info'
|
||||
|
||||
VERSION = '1.10.21'
|
||||
VERSION = '1.10.13'
|
||||
|
||||
try:
|
||||
import cython # type: ignore
|
||||
|
||||
58
venv/lib/python3.12/site-packages/pydantic/validate_call.py
Normal file
58
venv/lib/python3.12/site-packages/pydantic/validate_call.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Decorator for validating function calls."""
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
|
||||
|
||||
from ._internal import _validate_call
|
||||
|
||||
__all__ = ('validate_call',)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import ConfigDict
|
||||
|
||||
AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
|
||||
def validate_call(
|
||||
*, config: ConfigDict | None = None, validate_return: bool = False
|
||||
) -> Callable[[AnyCallableT], AnyCallableT]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_call(__func: AnyCallableT) -> AnyCallableT:
|
||||
...
|
||||
|
||||
|
||||
def validate_call(
|
||||
__func: AnyCallableT | None = None,
|
||||
*,
|
||||
config: ConfigDict | None = None,
|
||||
validate_return: bool = False,
|
||||
) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT]:
|
||||
"""Usage docs: https://docs.pydantic.dev/2.4/concepts/validation_decorator/
|
||||
|
||||
Returns a decorated wrapper around the function that validates the arguments and, optionally, the return value.
|
||||
|
||||
Usage may be either as a plain decorator `@validate_call` or with arguments `@validate_call(...)`.
|
||||
|
||||
Args:
|
||||
__func: The function to be decorated.
|
||||
config: The configuration dictionary.
|
||||
validate_return: Whether to validate the return value.
|
||||
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def validate(function: AnyCallableT) -> AnyCallableT:
|
||||
if isinstance(function, (classmethod, staticmethod)):
|
||||
name = type(function).__name__
|
||||
raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)')
|
||||
return _validate_call.ValidateCallWrapper(function, config, validate_return) # type: ignore
|
||||
|
||||
if __func:
|
||||
return validate(__func)
|
||||
else:
|
||||
return validate
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user