API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -1,22 +1,63 @@
from .main import BaseSettings, SettingsConfigDict
from .exceptions import SettingsError
from .main import BaseSettings, CliApp, SettingsConfigDict
from .sources import (
CLI_SUPPRESS,
AWSSecretsManagerSettingsSource,
AzureKeyVaultSettingsSource,
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
CliSuppress,
CliUnknownArgs,
DotEnvSettingsSource,
EnvSettingsSource,
ForceDecode,
GoogleSecretManagerSettingsSource,
InitSettingsSource,
JsonConfigSettingsSource,
NoDecode,
PydanticBaseSettingsSource,
PyprojectTomlConfigSettingsSource,
SecretsSettingsSource,
TomlConfigSettingsSource,
YamlConfigSettingsSource,
get_subcommand,
)
from .version import VERSION
__all__ = (
'CLI_SUPPRESS',
'AWSSecretsManagerSettingsSource',
'AzureKeyVaultSettingsSource',
'BaseSettings',
'CliApp',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'CliPositionalArg',
'CliSettingsSource',
'CliSubCommand',
'CliSuppress',
'CliUnknownArgs',
'DotEnvSettingsSource',
'EnvSettingsSource',
'ForceDecode',
'GoogleSecretManagerSettingsSource',
'InitSettingsSource',
'JsonConfigSettingsSource',
'NoDecode',
'PydanticBaseSettingsSource',
'PyprojectTomlConfigSettingsSource',
'SecretsSettingsSource',
'SettingsConfigDict',
'SettingsError',
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
'__version__',
'get_subcommand',
)
__version__ = VERSION

View File

@@ -0,0 +1,4 @@
class SettingsError(ValueError):
"""Base exception for settings-related errors."""
pass

View File

@@ -1,31 +1,104 @@
from __future__ import annotations as _annotations
from pathlib import Path
from typing import Any, ClassVar
import asyncio
import inspect
import threading
import warnings
from argparse import Namespace
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, ClassVar, TypeVar
from pydantic import ConfigDict
from pydantic._internal._config import config_keys
from pydantic._internal._utils import deep_update
from pydantic._internal._signature import _field_name_for_signature
from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.main import BaseModel
from .exceptions import SettingsError
from .sources import (
ENV_FILE_SENTINEL,
CliSettingsSource,
DefaultSettingsSource,
DotEnvSettingsSource,
DotenvType,
EnvSettingsSource,
InitSettingsSource,
JsonConfigSettingsSource,
PathType,
PydanticBaseSettingsSource,
PydanticModel,
PyprojectTomlConfigSettingsSource,
SecretsSettingsSource,
TomlConfigSettingsSource,
YamlConfigSettingsSource,
get_subcommand,
)
T = TypeVar('T')
class SettingsConfigDict(ConfigDict, total=False):
case_sensitive: bool
nested_model_default_partial_update: bool | None
env_prefix: str
env_file: DotenvType | None
env_file_encoding: str | None
env_ignore_empty: bool
env_nested_delimiter: str | None
secrets_dir: str | Path | None
env_nested_max_split: int | None
env_parse_none_str: str | None
env_parse_enums: bool | None
cli_prog_name: str | None
cli_parse_args: bool | list[str] | tuple[str, ...] | None
cli_parse_none_str: str | None
cli_hide_none_type: bool
cli_avoid_json: bool
cli_enforce_required: bool
cli_use_class_docs_for_groups: bool
cli_exit_on_error: bool
cli_prefix: str
cli_flag_prefix_char: str
cli_implicit_flags: bool | None
cli_ignore_unknown_args: bool | None
cli_kebab_case: bool | None
cli_shortcuts: Mapping[str, str | list[str]] | None
secrets_dir: PathType | None
json_file: PathType | None
json_file_encoding: str | None
yaml_file: PathType | None
yaml_file_encoding: str | None
yaml_config_section: str | None
"""
Specifies the top-level key in a YAML file from which to load the settings.
If provided, the settings will be loaded from the nested section under this key.
This is useful when the YAML file contains multiple configuration sections
and you only want to load a specific subset into your settings model.
"""
pyproject_toml_depth: int
"""
Number of levels **up** from the current working directory to attempt to find a pyproject.toml
file.
This is only used when a pyproject.toml file is not found in the current working directory.
"""
pyproject_toml_table_header: tuple[str, ...]
"""
Header of the TOML table within a pyproject.toml file to use when filling variables.
This is supplied as a `tuple[str, ...]` instead of a `str` to accommodate for headers
containing a `.`.
For example, `toml_table_header = ("tool", "my.tool", "foo")` can be used to fill variable
values from a table with header `[tool."my.tool".foo]`.
To use the root table, exclude this config setting or provide an empty tuple.
"""
toml_file: PathType | None
enable_decoding: bool
# Extend `config_keys` by pydantic settings config keys to
@@ -47,35 +120,104 @@ class BaseSettings(BaseModel):
All the below attributes can be set via `model_config`.
Args:
_case_sensitive: Whether environment variables names should be read with case-sensitivity. Defaults to `None`.
_case_sensitive: Whether environment and CLI variable names should be read with case-sensitivity.
Defaults to `None`.
_nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
Defaults to `False`.
_env_prefix: Prefix for all environment variables. Defaults to `None`.
_env_file: The env file(s) to load settings values from. Defaults to `Path('')`, which
means that the value from `model_config['env_file']` should be used. You can also pass
`None` to indicate that environment variables should not be loaded from an env file.
_env_file_encoding: The env file encoding, e.g. `'latin-1'`. Defaults to `None`.
_env_ignore_empty: Ignore environment variables where the value is an empty string. Default to `False`.
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
_secrets_dir: The secret files directory. Defaults to `None`.
_env_nested_max_split: The nested env values maximum nesting. Defaults to `None`, which means no limit.
_env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.)
into `None` type(None). Defaults to `None` type(None), which means no parsing should occur.
_env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur.
_cli_prog_name: The CLI program name to display in help text. Defaults to `None` if _cli_parse_args is `None`.
Otherwise, defaults to sys.argv[0].
_cli_parse_args: The list of CLI arguments to parse. Defaults to None.
If set to `True`, defaults to sys.argv[1:].
_cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to None.
_cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into
`None` type(None). Defaults to _env_parse_none_str value if set. Otherwise, defaults to "null" if
_cli_avoid_json is `False`, and "None" if _cli_avoid_json is `True`.
_cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`.
_cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`.
_cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`.
_cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions.
Defaults to `False`.
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
Defaults to `True`.
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
_cli_flag_prefix_char: The flag prefix character to use for CLI optional arguments. Defaults to '-'.
_cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
(e.g. --flag, --no-flag). Defaults to `False`.
_cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`.
_cli_kebab_case: CLI args use kebab case. Defaults to `False`.
_cli_shortcuts: Mapping of target field name to alias names. Defaults to `None`.
_secrets_dir: The secret files directory or a sequence of directories. Defaults to `None`.
"""
def __init__(
__pydantic_self__,
_case_sensitive: bool | None = None,
_nested_model_default_partial_update: bool | None = None,
_env_prefix: str | None = None,
_env_file: DotenvType | None = ENV_FILE_SENTINEL,
_env_file_encoding: str | None = None,
_env_ignore_empty: bool | None = None,
_env_nested_delimiter: str | None = None,
_secrets_dir: str | Path | None = None,
_env_nested_max_split: int | None = None,
_env_parse_none_str: str | None = None,
_env_parse_enums: bool | None = None,
_cli_prog_name: str | None = None,
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
_cli_settings_source: CliSettingsSource[Any] | None = None,
_cli_parse_none_str: str | None = None,
_cli_hide_none_type: bool | None = None,
_cli_avoid_json: bool | None = None,
_cli_enforce_required: bool | None = None,
_cli_use_class_docs_for_groups: bool | None = None,
_cli_exit_on_error: bool | None = None,
_cli_prefix: str | None = None,
_cli_flag_prefix_char: str | None = None,
_cli_implicit_flags: bool | None = None,
_cli_ignore_unknown_args: bool | None = None,
_cli_kebab_case: bool | None = None,
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
_secrets_dir: PathType | None = None,
**values: Any,
) -> None:
# Uses something other than `self` the first arg to allow "self" as a settable attribute
super().__init__(
**__pydantic_self__._settings_build_values(
values,
_case_sensitive=_case_sensitive,
_nested_model_default_partial_update=_nested_model_default_partial_update,
_env_prefix=_env_prefix,
_env_file=_env_file,
_env_file_encoding=_env_file_encoding,
_env_ignore_empty=_env_ignore_empty,
_env_nested_delimiter=_env_nested_delimiter,
_env_nested_max_split=_env_nested_max_split,
_env_parse_none_str=_env_parse_none_str,
_env_parse_enums=_env_parse_enums,
_cli_prog_name=_cli_prog_name,
_cli_parse_args=_cli_parse_args,
_cli_settings_source=_cli_settings_source,
_cli_parse_none_str=_cli_parse_none_str,
_cli_hide_none_type=_cli_hide_none_type,
_cli_avoid_json=_cli_avoid_json,
_cli_enforce_required=_cli_enforce_required,
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
_cli_exit_on_error=_cli_exit_on_error,
_cli_prefix=_cli_prefix,
_cli_flag_prefix_char=_cli_flag_prefix_char,
_cli_implicit_flags=_cli_implicit_flags,
_cli_ignore_unknown_args=_cli_ignore_unknown_args,
_cli_kebab_case=_cli_kebab_case,
_cli_shortcuts=_cli_shortcuts,
_secrets_dir=_secrets_dir,
)
)
@@ -108,33 +250,125 @@ class BaseSettings(BaseModel):
self,
init_kwargs: dict[str, Any],
_case_sensitive: bool | None = None,
_nested_model_default_partial_update: bool | None = None,
_env_prefix: str | None = None,
_env_file: DotenvType | None = None,
_env_file_encoding: str | None = None,
_env_ignore_empty: bool | None = None,
_env_nested_delimiter: str | None = None,
_secrets_dir: str | Path | None = None,
_env_nested_max_split: int | None = None,
_env_parse_none_str: str | None = None,
_env_parse_enums: bool | None = None,
_cli_prog_name: str | None = None,
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
_cli_settings_source: CliSettingsSource[Any] | None = None,
_cli_parse_none_str: str | None = None,
_cli_hide_none_type: bool | None = None,
_cli_avoid_json: bool | None = None,
_cli_enforce_required: bool | None = None,
_cli_use_class_docs_for_groups: bool | None = None,
_cli_exit_on_error: bool | None = None,
_cli_prefix: str | None = None,
_cli_flag_prefix_char: str | None = None,
_cli_implicit_flags: bool | None = None,
_cli_ignore_unknown_args: bool | None = None,
_cli_kebab_case: bool | None = None,
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
_secrets_dir: PathType | None = None,
) -> dict[str, Any]:
# Determine settings config values
case_sensitive = _case_sensitive if _case_sensitive is not None else self.model_config.get('case_sensitive')
env_prefix = _env_prefix if _env_prefix is not None else self.model_config.get('env_prefix')
nested_model_default_partial_update = (
_nested_model_default_partial_update
if _nested_model_default_partial_update is not None
else self.model_config.get('nested_model_default_partial_update')
)
env_file = _env_file if _env_file != ENV_FILE_SENTINEL else self.model_config.get('env_file')
env_file_encoding = (
_env_file_encoding if _env_file_encoding is not None else self.model_config.get('env_file_encoding')
)
env_ignore_empty = (
_env_ignore_empty if _env_ignore_empty is not None else self.model_config.get('env_ignore_empty')
)
env_nested_delimiter = (
_env_nested_delimiter
if _env_nested_delimiter is not None
else self.model_config.get('env_nested_delimiter')
)
env_nested_max_split = (
_env_nested_max_split
if _env_nested_max_split is not None
else self.model_config.get('env_nested_max_split')
)
env_parse_none_str = (
_env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str')
)
env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums')
cli_prog_name = _cli_prog_name if _cli_prog_name is not None else self.model_config.get('cli_prog_name')
cli_parse_args = _cli_parse_args if _cli_parse_args is not None else self.model_config.get('cli_parse_args')
cli_settings_source = (
_cli_settings_source if _cli_settings_source is not None else self.model_config.get('cli_settings_source')
)
cli_parse_none_str = (
_cli_parse_none_str if _cli_parse_none_str is not None else self.model_config.get('cli_parse_none_str')
)
cli_parse_none_str = cli_parse_none_str if not env_parse_none_str else env_parse_none_str
cli_hide_none_type = (
_cli_hide_none_type if _cli_hide_none_type is not None else self.model_config.get('cli_hide_none_type')
)
cli_avoid_json = _cli_avoid_json if _cli_avoid_json is not None else self.model_config.get('cli_avoid_json')
cli_enforce_required = (
_cli_enforce_required
if _cli_enforce_required is not None
else self.model_config.get('cli_enforce_required')
)
cli_use_class_docs_for_groups = (
_cli_use_class_docs_for_groups
if _cli_use_class_docs_for_groups is not None
else self.model_config.get('cli_use_class_docs_for_groups')
)
cli_exit_on_error = (
_cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error')
)
cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix')
cli_flag_prefix_char = (
_cli_flag_prefix_char
if _cli_flag_prefix_char is not None
else self.model_config.get('cli_flag_prefix_char')
)
cli_implicit_flags = (
_cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags')
)
cli_ignore_unknown_args = (
_cli_ignore_unknown_args
if _cli_ignore_unknown_args is not None
else self.model_config.get('cli_ignore_unknown_args')
)
cli_kebab_case = _cli_kebab_case if _cli_kebab_case is not None else self.model_config.get('cli_kebab_case')
cli_shortcuts = _cli_shortcuts if _cli_shortcuts is not None else self.model_config.get('cli_shortcuts')
secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')
# Configure built-in sources
init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs)
default_settings = DefaultSettingsSource(
self.__class__, nested_model_default_partial_update=nested_model_default_partial_update
)
init_settings = InitSettingsSource(
self.__class__,
init_kwargs=init_kwargs,
nested_model_default_partial_update=nested_model_default_partial_update,
)
env_settings = EnvSettingsSource(
self.__class__,
case_sensitive=case_sensitive,
env_prefix=env_prefix,
env_nested_delimiter=env_nested_delimiter,
env_nested_max_split=env_nested_max_split,
env_ignore_empty=env_ignore_empty,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)
dotenv_settings = DotEnvSettingsSource(
self.__class__,
@@ -143,6 +377,10 @@ class BaseSettings(BaseModel):
case_sensitive=case_sensitive,
env_prefix=env_prefix,
env_nested_delimiter=env_nested_delimiter,
env_nested_max_split=env_nested_max_split,
env_ignore_empty=env_ignore_empty,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)
file_secret_settings = SecretsSettingsSource(
@@ -155,23 +393,294 @@ class BaseSettings(BaseModel):
env_settings=env_settings,
dotenv_settings=dotenv_settings,
file_secret_settings=file_secret_settings,
)
) + (default_settings,)
custom_cli_sources = [source for source in sources if isinstance(source, CliSettingsSource)]
if not any(custom_cli_sources):
if isinstance(cli_settings_source, CliSettingsSource):
sources = (cli_settings_source,) + sources
elif cli_parse_args is not None:
cli_settings = CliSettingsSource[Any](
self.__class__,
cli_prog_name=cli_prog_name,
cli_parse_args=cli_parse_args,
cli_parse_none_str=cli_parse_none_str,
cli_hide_none_type=cli_hide_none_type,
cli_avoid_json=cli_avoid_json,
cli_enforce_required=cli_enforce_required,
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
cli_exit_on_error=cli_exit_on_error,
cli_prefix=cli_prefix,
cli_flag_prefix_char=cli_flag_prefix_char,
cli_implicit_flags=cli_implicit_flags,
cli_ignore_unknown_args=cli_ignore_unknown_args,
cli_kebab_case=cli_kebab_case,
cli_shortcuts=cli_shortcuts,
case_sensitive=case_sensitive,
)
sources = (cli_settings,) + sources
# We ensure that if command line arguments haven't been parsed yet, we do so.
elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars:
custom_cli_sources[0](args=cli_parse_args) # type: ignore
self._settings_warn_unused_config_keys(sources, self.model_config)
if sources:
return deep_update(*reversed([source() for source in sources]))
state: dict[str, Any] = {}
states: dict[str, dict[str, Any]] = {}
for source in sources:
if isinstance(source, PydanticBaseSettingsSource):
source._set_current_state(state)
source._set_settings_sources_data(states)
source_name = source.__name__ if hasattr(source, '__name__') else type(source).__name__
source_state = source()
states[source_name] = source_state
state = deep_update(source_state, state)
return state
else:
# no one should mean to do this, but I think returning an empty dict is marginally preferable
# to an informative error and much better than a confusing error
return {}
@staticmethod
def _settings_warn_unused_config_keys(sources: tuple[object, ...], model_config: SettingsConfigDict) -> None:
"""
Warns if any values in model_config were set but the corresponding settings source has not been initialised.
The list alternative sources and their config keys can be found here:
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source
Args:
sources: The tuple of configured sources
model_config: The model config to check for unused config keys
"""
def warn_if_not_used(source_type: type[PydanticBaseSettingsSource], keys: tuple[str, ...]) -> None:
if not any(isinstance(source, source_type) for source in sources):
for key in keys:
if model_config.get(key) is not None:
warnings.warn(
f'Config key `{key}` is set in model_config but will be ignored because no '
f'{source_type.__name__} source is configured. To use this config key, add a '
f'{source_type.__name__} source to the settings sources via the '
'settings_customise_sources hook.',
UserWarning,
stacklevel=3,
)
warn_if_not_used(JsonConfigSettingsSource, ('json_file', 'json_file_encoding'))
warn_if_not_used(PyprojectTomlConfigSettingsSource, ('pyproject_toml_depth', 'pyproject_toml_table_header'))
warn_if_not_used(TomlConfigSettingsSource, ('toml_file',))
warn_if_not_used(YamlConfigSettingsSource, ('yaml_file', 'yaml_file_encoding', 'yaml_config_section'))
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
extra='forbid',
arbitrary_types_allowed=True,
validate_default=True,
case_sensitive=False,
env_prefix='',
nested_model_default_partial_update=False,
env_file=None,
env_file_encoding=None,
env_ignore_empty=False,
env_nested_delimiter=None,
env_nested_max_split=None,
env_parse_none_str=None,
env_parse_enums=None,
cli_prog_name=None,
cli_parse_args=None,
cli_parse_none_str=None,
cli_hide_none_type=False,
cli_avoid_json=False,
cli_enforce_required=False,
cli_use_class_docs_for_groups=False,
cli_exit_on_error=True,
cli_prefix='',
cli_flag_prefix_char='-',
cli_implicit_flags=False,
cli_ignore_unknown_args=False,
cli_kebab_case=False,
cli_shortcuts=None,
json_file=None,
json_file_encoding=None,
yaml_file=None,
yaml_file_encoding=None,
yaml_config_section=None,
toml_file=None,
secrets_dir=None,
protected_namespaces=('model_', 'settings_'),
protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'),
enable_decoding=True,
)
class CliApp:
"""
A utility class for running Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as
CLI applications.
"""
@staticmethod
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
if issubclass(model_cls, BaseSettings):
return model_cls
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
__doc__ = model_cls.__doc__
model_config = SettingsConfigDict(
nested_model_default_partial_update=True,
case_sensitive=True,
cli_hide_none_type=True,
cli_avoid_json=True,
cli_enforce_required=True,
cli_implicit_flags=True,
cli_kebab_case=True,
)
return CliAppBaseSettings
@staticmethod
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
command = getattr(type(model), cli_cmd_method_name, None)
if command is None:
if is_required:
raise SettingsError(f'Error: {type(model).__name__} class is missing {cli_cmd_method_name} entrypoint')
return model
# If the method is asynchronous, we handle its execution based on the current event loop status.
if inspect.iscoroutinefunction(command):
# For asynchronous methods, we have two execution scenarios:
# 1. If no event loop is running in the current thread, run the coroutine directly with asyncio.run().
# 2. If an event loop is already running in the current thread, run the coroutine in a separate thread to avoid conflicts.
try:
# Check if an event loop is currently running in this thread.
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
# We're in a context with an active event loop (e.g., Jupyter Notebook).
# Running asyncio.run() here would cause conflicts, so we use a separate thread.
exception_container = []
def run_coro() -> None:
try:
# Execute the coroutine in a new event loop in this separate thread.
asyncio.run(command(model))
except Exception as e:
exception_container.append(e)
thread = threading.Thread(target=run_coro)
thread.start()
thread.join()
if exception_container:
# Propagate exceptions from the separate thread.
raise exception_container[0]
else:
# No event loop is running; safe to run the coroutine directly.
asyncio.run(command(model))
else:
# For synchronous methods, call them directly.
command(model)
return model
@staticmethod
def run(
model_cls: type[T],
cli_args: list[str] | Namespace | SimpleNamespace | dict[str, Any] | None = None,
cli_settings_source: CliSettingsSource[Any] | None = None,
cli_exit_on_error: bool | None = None,
cli_cmd_method_name: str = 'cli_cmd',
**model_init_data: Any,
) -> T:
"""
Runs a Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as a CLI application.
Running a model as a CLI application requires the `cli_cmd` method to be defined in the model class.
Args:
model_cls: The model class to run as a CLI application.
cli_args: The list of CLI arguments to parse. If `cli_settings_source` is specified, this may
also be a namespace or dictionary of pre-parsed CLI arguments. Defaults to `sys.argv[1:]`.
cli_settings_source: Override the default CLI settings source with a user defined instance.
Defaults to `None`.
cli_exit_on_error: Determines whether this function exits on error. If model is subclass of
`BaseSettings`, defaults to BaseSettings `cli_exit_on_error` value. Otherwise, defaults to
`True`.
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
model_init_data: The model init data.
Returns:
The ran instance of model.
Raises:
SettingsError: If model_cls is not subclass of `BaseModel` or `pydantic.dataclasses.dataclass`.
SettingsError: If model_cls does not have a `cli_cmd` entrypoint defined.
"""
if not (is_pydantic_dataclass(model_cls) or is_model_class(model_cls)):
raise SettingsError(
f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass'
)
cli_settings = None
cli_parse_args = True if cli_args is None else cli_args
if cli_settings_source is not None:
if isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
cli_settings = cli_settings_source(parsed_args=cli_parse_args)
else:
cli_settings = cli_settings_source(args=cli_parse_args)
elif isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
raise SettingsError('Error: `cli_args` must be list[str] or None when `cli_settings_source` is not used')
model_init_data['_cli_parse_args'] = cli_parse_args
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
model_init_data['_cli_settings_source'] = cli_settings
if not issubclass(model_cls, BaseSettings):
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
model = base_settings_cls(**model_init_data)
model_init_data = {}
for field_name, field_info in base_settings_cls.model_fields.items():
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)
return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False)
@staticmethod
def run_subcommand(
model: PydanticModel, cli_exit_on_error: bool | None = None, cli_cmd_method_name: str = 'cli_cmd'
) -> PydanticModel:
"""
Runs the model subcommand. Running a model subcommand requires the `cli_cmd` method to be defined in
the nested model subcommand class.
Args:
model: The model to run the subcommand from.
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
Returns:
The ran subcommand model.
Raises:
SystemExit: When no subcommand is found and cli_exit_on_error=`True` (the default).
SettingsError: When no subcommand is found and cli_exit_on_error=`False`.
"""
subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error)
return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)
@staticmethod
def serialize(model: PydanticModel) -> list[str]:
"""
Serializes the CLI arguments for a Pydantic data model.
Args:
model: The data model to serialize.
Returns:
The serialized CLI arguments for the data model.
"""
base_settings_cls = CliApp._get_base_settings_cls(type(model))
return CliSettingsSource[Any](base_settings_cls)._serialized_args(model)

View File

@@ -1,653 +0,0 @@
from __future__ import annotations as _annotations
import json
import os
import warnings
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast
from pydantic import AliasChoices, AliasPath, BaseModel, Json
from pydantic._internal._typing_extra import origin_is_union
from pydantic._internal._utils import deep_update, lenient_issubclass
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin
from pydantic_settings.utils import path_type_label
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
DotenvType = Union[Path, str, List[Union[Path, str]], Tuple[Union[Path, str], ...]]
# This is used as default value for `_env_file` in the `BaseSettings` class and
# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`.
# See the docstring of `BaseSettings` for more details.
ENV_FILE_SENTINEL: DotenvType = Path('')
class SettingsError(ValueError):
pass
class PydanticBaseSettingsSource(ABC):
"""
Abstract base class for settings sources, every settings source classes should inherit from it.
"""
def __init__(self, settings_cls: type[BaseSettings]):
self.settings_cls = settings_cls
self.config = settings_cls.model_config
@abstractmethod
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value, the key for model creation, and a flag to determine whether value is complex.
This is an abstract method that should be overridden in every settings source classes.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple contains the key, value and a flag to determine whether value is complex.
"""
pass
def field_is_complex(self, field: FieldInfo) -> bool:
"""
Checks whether a field is complex, in which case it will attempt to be parsed as JSON.
Args:
field: The field.
Returns:
Whether the field is complex.
"""
return _annotation_is_complex(field.annotation, field.metadata)
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Prepares the value of a field.
Args:
field_name: The field name.
field: The field.
value: The value of the field that has to be prepared.
value_is_complex: A flag to determine whether value is complex.
Returns:
The prepared value.
"""
if value is not None and (self.field_is_complex(field) or value_is_complex):
return self.decode_complex_value(field_name, field, value)
return value
def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
"""
Decode the value for a complex field
Args:
field_name: The field name.
field: The field.
value: The value of the field that has to be prepared.
Returns:
The decoded value for further preparation
"""
return json.loads(value)
@abstractmethod
def __call__(self) -> dict[str, Any]:
pass
class InitSettingsSource(PydanticBaseSettingsSource):
"""
Source class for loading values provided during settings class initialization.
"""
def __init__(self, settings_cls: type[BaseSettings], init_kwargs: dict[str, Any]):
self.init_kwargs = init_kwargs
super().__init__(settings_cls)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
# Nothing to do here. Only implement the return statement to make mypy happy
return None, '', False
def __call__(self) -> dict[str, Any]:
return self.init_kwargs
def __repr__(self) -> str:
return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
def __init__(
self, settings_cls: type[BaseSettings], case_sensitive: bool | None = None, env_prefix: str | None = None
) -> None:
super().__init__(settings_cls)
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
def _apply_case_sensitive(self, value: str) -> str:
return value.lower() if not self.case_sensitive else value
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
"""
Extracts field info. This info is used to get the value of field from environment variables.
It returns a list of tuples, each tuple contains:
* field_key: The key of field that has to be used in model creation.
* env_name: The environment variable name of the field.
* value_is_complex: A flag to determine whether the value from environment variable
is complex and has to be parsed.
Args:
field (FieldInfo): The field.
field_name (str): The field name.
Returns:
list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
"""
field_info: list[tuple[str, str, bool]] = []
if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
else:
v_alias = field.validation_alias
if v_alias:
if isinstance(v_alias, list): # AliasChoices, AliasPath
for alias in v_alias:
if isinstance(alias, str): # AliasPath
field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False))
elif isinstance(alias, list): # AliasChoices
first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str
field_info.append(
(first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False)
)
else: # string validation alias
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
else:
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
return field_info
def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
"""
Replace field names in values dict by looking in models fields insensitively.
By having the following models:
```py
class SubSubSub(BaseModel):
VaL3: str
class SubSub(BaseModel):
Val2: str
SUB_sub_SuB: SubSubSub
class Sub(BaseModel):
VAL1: str
SUB_sub: SubSub
class Settings(BaseSettings):
nested: Sub
model_config = SettingsConfigDict(env_nested_delimiter='__')
```
Then:
_replace_field_names_case_insensitively(
field,
{"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
)
Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
"""
values: dict[str, Any] = {}
for name, value in field_values.items():
sub_model_field: FieldInfo | None = None
# This is here to make mypy happy
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
if not field.annotation or not hasattr(field.annotation, 'model_fields'):
values[name] = value
continue
# Find field in sub model by looking in fields case insensitively
for sub_model_field_name, f in field.annotation.model_fields.items():
if not f.validation_alias and sub_model_field_name.lower() == name.lower():
sub_model_field = f
break
if not sub_model_field:
values[name] = value
continue
if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value)
else:
values[sub_model_field_name] = value
return values
def __call__(self) -> dict[str, Any]:
data: dict[str, Any] = {}
for field_name, field in self.settings_cls.model_fields.items():
try:
field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
except Exception as e:
raise SettingsError(
f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
) from e
try:
field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
except ValueError as e:
raise SettingsError(
f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
) from e
if field_value is not None:
if (
not self.case_sensitive
and lenient_issubclass(field.annotation, BaseModel)
and isinstance(field_value, dict)
):
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
else:
data[field_key] = field_value
return data
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
"""
Source class for loading settings values from secret files.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
secrets_dir: str | Path | None = None,
case_sensitive: bool | None = None,
env_prefix: str | None = None,
) -> None:
super().__init__(settings_cls, case_sensitive, env_prefix)
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')
def __call__(self) -> dict[str, Any]:
"""
Build fields from "secrets" files.
"""
secrets: dict[str, str | None] = {}
if self.secrets_dir is None:
return secrets
self.secrets_path = Path(self.secrets_dir).expanduser()
if not self.secrets_path.exists():
warnings.warn(f'directory "{self.secrets_path}" does not exist')
return secrets
if not self.secrets_path.is_dir():
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(self.secrets_path)}')
return super().__call__()
@classmethod
def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
"""
Find a file within path's directory matching filename, optionally ignoring case.
Args:
dir_path: Directory path.
file_name: File name.
case_sensitive: Whether to search for file name case sensitively.
Returns:
Whether file path or `None` if file does not exist in directory.
"""
for f in dir_path.iterdir():
if f.name == file_name:
return f
elif not case_sensitive and f.name.lower() == file_name.lower():
return f
return None
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value for field from secret file and a flag to determine whether value is complex.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple contains the key, value if the file exists otherwise `None`, and
a flag to determine whether value is complex.
"""
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
path = self.find_case_path(self.secrets_path, env_name, self.case_sensitive)
if not path:
# path does not exist, we currently don't return a warning for this
continue
if path.is_file():
return path.read_text().strip(), field_key, value_is_complex
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
stacklevel=4,
)
return None, field_key, value_is_complex
def __repr__(self) -> str:
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
"""
Source class for loading settings values from environment variables.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str | None = None,
) -> None:
super().__init__(settings_cls, case_sensitive, env_prefix)
self.env_nested_delimiter = (
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
)
self.env_prefix_len = len(self.env_prefix)
self.env_vars = self._load_env_vars()
def _load_env_vars(self) -> Mapping[str, str | None]:
if self.case_sensitive:
return os.environ
return {k.lower(): v for k, v in os.environ.items()}
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value for field from environment variables and a flag to determine whether value is complex.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple contains the key, value if the file exists otherwise `None`, and
a flag to determine whether value is complex.
"""
env_val: str | None = None
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
env_val = self.env_vars.get(env_name)
if env_val is not None:
break
return env_val, field_key, value_is_complex
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Prepare value for the field.
* Extract value for nested field.
* Deserialize value to python object for complex field.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple contains prepared value for the field.
Raises:
ValuesError: When There is an error in deserializing value for complex field.
"""
is_complex, allow_parse_failure = self._field_is_complex(field)
if is_complex or value_is_complex:
if value is None:
# field is complex but no value found so far, try explode_env_vars
env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
if env_val_built:
return env_val_built
else:
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
try:
value = self.decode_complex_value(field_name, field, value)
except ValueError as e:
if not allow_parse_failure:
raise e
if isinstance(value, dict):
return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
else:
return value
elif value is not None:
# simplest case, field is not complex, we only need to add the value if it was found
return value
def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool:
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if self.field_is_complex(field):
allow_parse_failure = False
elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata):
allow_parse_failure = True
else:
return False, False
return True, allow_parse_failure
@staticmethod
def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None:
"""
Find the field in a sub model by key(env name)
By having the following models:
```py
class SubSubModel(BaseSettings):
dvals: Dict
class SubModel(BaseSettings):
vals: list[str]
sub_sub_model: SubSubModel
class Cfg(BaseSettings):
sub_model: SubModel
```
Then:
next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class
Args:
field: The field.
key: The key (env name).
Returns:
Field if it finds the next field otherwise `None`.
"""
if not field or origin_is_union(get_origin(field.annotation)):
# no support for Unions of complex BaseSettings fields
return None
elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key):
return field.annotation.model_fields[key]
return None
def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
This is applied to a single field, hence filtering by env_var prefix.
Args:
field_name: The field name.
field: The field.
env_vars: Environment variables.
Returns:
A dictionaty contains extracted values from nested env values.
"""
prefixes = [
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
]
result: dict[str, Any] = {}
for env_name, env_val in env_vars.items():
if not any(env_name.startswith(prefix) for prefix in prefixes):
continue
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
env_name_without_prefix = env_name[self.env_prefix_len :]
_, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
env_var = result
target_field: FieldInfo | None = field
for key in keys:
target_field = self.next_field(target_field, key)
env_var = env_var.setdefault(key, {})
# get proper field with last_key
target_field = self.next_field(target_field, last_key)
# check if env_val maps to a complex field and if so, parse the env_val
if target_field and env_val:
is_complex, allow_json_failure = self._field_is_complex(target_field)
if is_complex:
try:
env_val = self.decode_complex_value(last_key, target_field, env_val)
except ValueError as e:
if not allow_json_failure:
raise e
env_var[last_key] = env_val
return result
def __repr__(self) -> str:
return (
f'EnvSettingsSource(env_nested_delimiter={self.env_nested_delimiter!r}, '
f'env_prefix_len={self.env_prefix_len!r})'
)
class DotEnvSettingsSource(EnvSettingsSource):
"""
Source class for loading settings values from env files.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
env_file: DotenvType | None = ENV_FILE_SENTINEL,
env_file_encoding: str | None = None,
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str | None = None,
) -> None:
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
self.env_file_encoding = (
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
)
super().__init__(settings_cls, case_sensitive, env_prefix, env_nested_delimiter)
def _load_env_vars(self) -> Mapping[str, str | None]:
return self._read_env_files(self.case_sensitive)
def _read_env_files(self, case_sensitive: bool) -> Mapping[str, str | None]:
env_files = self.env_file
if env_files is None:
return {}
if isinstance(env_files, (str, os.PathLike)):
env_files = [env_files]
dotenv_vars: dict[str, str | None] = {}
for env_file in env_files:
env_path = Path(env_file).expanduser()
if env_path.is_file():
dotenv_vars.update(
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
)
return dotenv_vars
def __call__(self) -> dict[str, Any]:
data: dict[str, Any] = super().__call__()
data_lower_keys: list[str] = []
if not self.case_sensitive:
data_lower_keys = [x.lower() for x in data.keys()]
# As `extra` config is allowed in dotenv settings source, We have to
# update data with extra env variabels from dotenv file.
for env_name, env_value in self.env_vars.items():
if env_name.startswith(self.env_prefix) and env_value is not None:
env_name_without_prefix = env_name[self.env_prefix_len :]
first_key, *_ = env_name_without_prefix.split(self.env_nested_delimiter)
if (data_lower_keys and first_key not in data_lower_keys) or (
not data_lower_keys and first_key not in data
):
data[first_key] = env_value
return data
def __repr__(self) -> str:
return (
f'DotEnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
)
def read_env_file(
file_path: Path, *, encoding: str | None = None, case_sensitive: bool = False
) -> Mapping[str, str | None]:
try:
from dotenv import dotenv_values
except ImportError as e:
raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e
file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
if not case_sensitive:
return {k.lower(): v for k, v in file_vars.items()}
else:
return file_vars
def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
return False
origin = get_origin(annotation)
return (
_annotation_is_complex_inner(annotation)
or _annotation_is_complex_inner(origin)
or hasattr(origin, '__pydantic_core_schema__')
or hasattr(origin, '__get_pydantic_core_schema__')
)
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
annotation
)

View File

@@ -0,0 +1,68 @@
"""Package for handling configuration sources in pydantic-settings."""
from .base import (
ConfigFileSourceMixin,
DefaultSettingsSource,
InitSettingsSource,
PydanticBaseEnvSettingsSource,
PydanticBaseSettingsSource,
get_subcommand,
)
from .providers.aws import AWSSecretsManagerSettingsSource
from .providers.azure import AzureKeyVaultSettingsSource
from .providers.cli import (
CLI_SUPPRESS,
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
CliSuppress,
CliUnknownArgs,
)
from .providers.dotenv import DotEnvSettingsSource, read_env_file
from .providers.env import EnvSettingsSource
from .providers.gcp import GoogleSecretManagerSettingsSource
from .providers.json import JsonConfigSettingsSource
from .providers.pyproject import PyprojectTomlConfigSettingsSource
from .providers.secrets import SecretsSettingsSource
from .providers.toml import TomlConfigSettingsSource
from .providers.yaml import YamlConfigSettingsSource
from .types import DEFAULT_PATH, ENV_FILE_SENTINEL, DotenvType, ForceDecode, NoDecode, PathType, PydanticModel
__all__ = [
'CLI_SUPPRESS',
'ENV_FILE_SENTINEL',
'DEFAULT_PATH',
'AWSSecretsManagerSettingsSource',
'AzureKeyVaultSettingsSource',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'CliPositionalArg',
'CliSettingsSource',
'CliSubCommand',
'CliSuppress',
'CliUnknownArgs',
'DefaultSettingsSource',
'DotEnvSettingsSource',
'DotenvType',
'EnvSettingsSource',
'ForceDecode',
'GoogleSecretManagerSettingsSource',
'InitSettingsSource',
'JsonConfigSettingsSource',
'NoDecode',
'PathType',
'PydanticBaseEnvSettingsSource',
'PydanticBaseSettingsSource',
'ConfigFileSourceMixin',
'PydanticModel',
'PyprojectTomlConfigSettingsSource',
'SecretsSettingsSource',
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
'get_subcommand',
'read_env_file',
]

View File

@@ -0,0 +1,527 @@
"""Base classes and core functionality for pydantic-settings sources."""
from __future__ import annotations as _annotations
import json
import os
from abc import ABC, abstractmethod
from dataclasses import asdict, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, cast
from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
get_origin,
)
from pydantic._internal._utils import is_model_class
from pydantic.fields import FieldInfo
from typing_extensions import get_args
from typing_inspection import typing_objects
from typing_inspection.introspection import is_union_origin
from ..exceptions import SettingsError
from ..utils import _lenient_issubclass
from .types import EnvNoneType, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand
from .utils import (
_annotation_is_complex,
_get_alias_names,
_get_model_fields,
_strip_annotated,
_union_is_complex,
)
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
def get_subcommand(
model: PydanticModel, is_required: bool = True, cli_exit_on_error: bool | None = None
) -> Optional[PydanticModel]:
"""
Get the subcommand from a model.
Args:
model: The model to get the subcommand from.
is_required: Determines whether a model must have subcommand set and raises error if not
found. Defaults to `True`.
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
Returns:
The subcommand model if found, otherwise `None`.
Raises:
SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
(the default).
SettingsError: When no subcommand is found and is_required=`True` and
cli_exit_on_error=`False`.
"""
model_cls = type(model)
if cli_exit_on_error is None and is_model_class(model_cls):
model_default = model_cls.model_config.get('cli_exit_on_error')
if isinstance(model_default, bool):
cli_exit_on_error = model_default
if cli_exit_on_error is None:
cli_exit_on_error = True
subcommands: list[str] = []
for field_name, field_info in _get_model_fields(model_cls).items():
if _CliSubCommand in field_info.metadata:
if getattr(model, field_name) is not None:
return getattr(model, field_name)
subcommands.append(field_name)
if is_required:
error_message = (
f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
if subcommands
else 'Error: CLI subcommand is required but no subcommands were found.'
)
raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)
return None
class PydanticBaseSettingsSource(ABC):
"""
Abstract base class for settings sources, every settings source classes should inherit from it.
"""
def __init__(self, settings_cls: type[BaseSettings]):
self.settings_cls = settings_cls
self.config = settings_cls.model_config
self._current_state: dict[str, Any] = {}
self._settings_sources_data: dict[str, dict[str, Any]] = {}
def _set_current_state(self, state: dict[str, Any]) -> None:
"""
Record the state of settings from the previous settings sources. This should
be called right before __call__.
"""
self._current_state = state
def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None:
"""
Record the state of settings from all previous settings sources. This should
be called right before __call__.
"""
self._settings_sources_data = states
@property
def current_state(self) -> dict[str, Any]:
"""
The current state of the settings, populated by the previous settings sources.
"""
return self._current_state
@property
def settings_sources_data(self) -> dict[str, dict[str, Any]]:
"""
The state of all previous settings sources.
"""
return self._settings_sources_data
@abstractmethod
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value, the key for model creation, and a flag to determine whether value is complex.
This is an abstract method that should be overridden in every settings source classes.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple that contains the value, key and a flag to determine whether value is complex.
"""
pass
def field_is_complex(self, field: FieldInfo) -> bool:
"""
Checks whether a field is complex, in which case it will attempt to be parsed as JSON.
Args:
field: The field.
Returns:
Whether the field is complex.
"""
return _annotation_is_complex(field.annotation, field.metadata)
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Prepares the value of a field.
Args:
field_name: The field name.
field: The field.
value: The value of the field that has to be prepared.
value_is_complex: A flag to determine whether value is complex.
Returns:
The prepared value.
"""
if value is not None and (self.field_is_complex(field) or value_is_complex):
return self.decode_complex_value(field_name, field, value)
return value
def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
"""
Decode the value for a complex field
Args:
field_name: The field name.
field: The field.
value: The value of the field that has to be prepared.
Returns:
The decoded value for further preparation
"""
if field and (
NoDecode in field.metadata
or (self.config.get('enable_decoding') is False and ForceDecode not in field.metadata)
):
return value
return json.loads(value)
@abstractmethod
def __call__(self) -> dict[str, Any]:
pass
class ConfigFileSourceMixin(ABC):
def _read_files(self, files: PathType | None) -> dict[str, Any]:
if files is None:
return {}
if isinstance(files, (str, os.PathLike)):
files = [files]
vars: dict[str, Any] = {}
for file in files:
file_path = Path(file).expanduser()
if file_path.is_file():
vars.update(self._read_file(file_path))
return vars
@abstractmethod
def _read_file(self, path: Path) -> dict[str, Any]:
pass
class DefaultSettingsSource(PydanticBaseSettingsSource):
"""
Source class for loading default object values.
Args:
settings_cls: The Settings class.
nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
Defaults to `False`.
"""
def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None):
super().__init__(settings_cls)
self.defaults: dict[str, Any] = {}
self.nested_model_default_partial_update = (
nested_model_default_partial_update
if nested_model_default_partial_update is not None
else self.config.get('nested_model_default_partial_update', False)
)
if self.nested_model_default_partial_update:
for field_name, field_info in settings_cls.model_fields.items():
alias_names, *_ = _get_alias_names(field_name, field_info)
preferred_alias = alias_names[0]
if is_dataclass(type(field_info.default)):
self.defaults[preferred_alias] = asdict(field_info.default)
elif is_model_class(type(field_info.default)):
self.defaults[preferred_alias] = field_info.default.model_dump()
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
# Nothing to do here. Only implement the return statement to make mypy happy
return None, '', False
def __call__(self) -> dict[str, Any]:
return self.defaults
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})'
)
class InitSettingsSource(PydanticBaseSettingsSource):
"""
Source class for loading values provided during settings class initialization.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
init_kwargs: dict[str, Any],
nested_model_default_partial_update: bool | None = None,
):
self.init_kwargs = {}
init_kwarg_names = set(init_kwargs.keys())
for field_name, field_info in settings_cls.model_fields.items():
alias_names, *_ = _get_alias_names(field_name, field_info)
init_kwarg_name = init_kwarg_names & set(alias_names)
if init_kwarg_name:
preferred_alias = alias_names[0]
preferred_set_alias = next(alias for alias in alias_names if alias in init_kwarg_name)
init_kwarg_names -= init_kwarg_name
self.init_kwargs[preferred_alias] = init_kwargs[preferred_set_alias]
self.init_kwargs.update({key: val for key, val in init_kwargs.items() if key in init_kwarg_names})
super().__init__(settings_cls)
self.nested_model_default_partial_update = (
nested_model_default_partial_update
if nested_model_default_partial_update is not None
else self.config.get('nested_model_default_partial_update', False)
)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
# Nothing to do here. Only implement the return statement to make mypy happy
return None, '', False
def __call__(self) -> dict[str, Any]:
return (
TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs)
if self.nested_model_default_partial_update
else self.init_kwargs
)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})'
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
def __init__(
self,
settings_cls: type[BaseSettings],
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
super().__init__(settings_cls)
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
self.env_ignore_empty = (
env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
)
self.env_parse_none_str = (
env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
)
self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')
def _apply_case_sensitive(self, value: str) -> str:
return value.lower() if not self.case_sensitive else value
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
"""
Extracts field info. This info is used to get the value of field from environment variables.
It returns a list of tuples, each tuple contains:
* field_key: The key of field that has to be used in model creation.
* env_name: The environment variable name of the field.
* value_is_complex: A flag to determine whether the value from environment variable
is complex and has to be parsed.
Args:
field (FieldInfo): The field.
field_name (str): The field name.
Returns:
list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
"""
field_info: list[tuple[str, str, bool]] = []
if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
else:
v_alias = field.validation_alias
if v_alias:
if isinstance(v_alias, list): # AliasChoices, AliasPath
for alias in v_alias:
if isinstance(alias, str): # AliasPath
field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False))
elif isinstance(alias, list): # AliasChoices
first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str
field_info.append(
(first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False)
)
else: # string validation alias
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
if not v_alias or self.config.get('populate_by_name', False):
annotation = field.annotation
if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
annotation = _strip_annotated(annotation.__value__) # type: ignore[union-attr]
if is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata):
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
else:
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
return field_info
def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
"""
Replace field names in values dict by looking in models fields insensitively.
By having the following models:
```py
class SubSubSub(BaseModel):
VaL3: str
class SubSub(BaseModel):
Val2: str
SUB_sub_SuB: SubSubSub
class Sub(BaseModel):
VAL1: str
SUB_sub: SubSub
class Settings(BaseSettings):
nested: Sub
model_config = SettingsConfigDict(env_nested_delimiter='__')
```
Then:
_replace_field_names_case_insensitively(
field,
{"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
)
Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
"""
values: dict[str, Any] = {}
for name, value in field_values.items():
sub_model_field: FieldInfo | None = None
annotation = field.annotation
# If field is Optional, we need to find the actual type
if is_union_origin(get_origin(field.annotation)):
args = get_args(annotation)
if len(args) == 2 and type(None) in args:
for arg in args:
if arg is not None:
annotation = arg
break
# This is here to make mypy happy
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
if not annotation or not hasattr(annotation, 'model_fields'):
values[name] = value
continue
else:
model_fields: dict[str, FieldInfo] = annotation.model_fields
# Find field in sub model by looking in fields case insensitively
field_key: str | None = None
for sub_model_field_name, sub_model_field in model_fields.items():
aliases, _ = _get_alias_names(sub_model_field_name, sub_model_field)
_search = (alias for alias in aliases if alias.lower() == name.lower())
if field_key := next(_search, None):
break
if not field_key:
values[name] = value
continue
if (
sub_model_field is not None
and _lenient_issubclass(sub_model_field.annotation, BaseModel)
and isinstance(value, dict)
):
values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value)
else:
values[field_key] = value
return values
def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]:
"""
Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None).
"""
values: dict[str, Any] = {}
for key, value in field_value.items():
if not isinstance(value, EnvNoneType):
values[key] = value if not isinstance(value, dict) else self._replace_env_none_type_values(value)
else:
values[key] = None
return values
def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value, the preferred alias key for model creation, and a flag to determine whether value
is complex.
Note:
In V3, this method should either be made public, or, this method should be removed and the
abstract method get_field_value should be updated to include a "use_preferred_alias" flag.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple that contains the value, preferred key and a flag to determine whether value is complex.
"""
field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
if not (value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name))):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
return field_value, preferred_key, value_is_complex
return field_value, field_key, value_is_complex
def __call__(self) -> dict[str, Any]:
data: dict[str, Any] = {}
for field_name, field in self.settings_cls.model_fields.items():
try:
field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name)
except Exception as e:
raise SettingsError(
f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
) from e
try:
field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
except ValueError as e:
raise SettingsError(
f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
) from e
if field_value is not None:
if self.env_parse_none_str is not None:
if isinstance(field_value, dict):
field_value = self._replace_env_none_type_values(field_value)
elif isinstance(field_value, EnvNoneType):
field_value = None
if (
not self.case_sensitive
# and _lenient_issubclass(field.annotation, BaseModel)
and isinstance(field_value, dict)
):
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
else:
data[field_key] = field_value
return data
__all__ = [
'ConfigFileSourceMixin',
'DefaultSettingsSource',
'InitSettingsSource',
'PydanticBaseEnvSettingsSource',
'PydanticBaseSettingsSource',
'SettingsError',
]

View File

@@ -0,0 +1,41 @@
"""Package containing individual source implementations."""
from .aws import AWSSecretsManagerSettingsSource
from .azure import AzureKeyVaultSettingsSource
from .cli import (
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
CliSuppress,
)
from .dotenv import DotEnvSettingsSource
from .env import EnvSettingsSource
from .gcp import GoogleSecretManagerSettingsSource
from .json import JsonConfigSettingsSource
from .pyproject import PyprojectTomlConfigSettingsSource
from .secrets import SecretsSettingsSource
from .toml import TomlConfigSettingsSource
from .yaml import YamlConfigSettingsSource
__all__ = [
'AWSSecretsManagerSettingsSource',
'AzureKeyVaultSettingsSource',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'CliPositionalArg',
'CliSettingsSource',
'CliSubCommand',
'CliSuppress',
'DotEnvSettingsSource',
'EnvSettingsSource',
'GoogleSecretManagerSettingsSource',
'JsonConfigSettingsSource',
'PyprojectTomlConfigSettingsSource',
'SecretsSettingsSource',
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
]

View File

@@ -0,0 +1,79 @@
from __future__ import annotations as _annotations # important for BaseSettings import to work
import json
from collections.abc import Mapping
from typing import TYPE_CHECKING, Optional
from ..utils import parse_env_vars
from .env import EnvSettingsSource
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
boto3_client = None
SecretsManagerClient = None
def import_aws_secrets_manager() -> None:
global boto3_client
global SecretsManagerClient
try:
from boto3 import client as boto3_client
from mypy_boto3_secretsmanager.client import SecretsManagerClient
except ImportError as e: # pragma: no cover
raise ImportError(
'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
) from e
class AWSSecretsManagerSettingsSource(EnvSettingsSource):
_secret_id: str
_secretsmanager_client: SecretsManagerClient # type: ignore
def __init__(
self,
settings_cls: type[BaseSettings],
secret_id: str,
region_name: str | None = None,
endpoint_url: str | None = None,
case_sensitive: bool | None = True,
env_prefix: str | None = None,
env_nested_delimiter: str | None = '--',
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
import_aws_secrets_manager()
self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url) # type: ignore
self._secret_id = secret_id
super().__init__(
settings_cls,
case_sensitive=case_sensitive,
env_prefix=env_prefix,
env_nested_delimiter=env_nested_delimiter,
env_ignore_empty=False,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id) # type: ignore
return parse_env_vars(
json.loads(response['SecretString']),
self.case_sensitive,
self.env_ignore_empty,
self.env_parse_none_str,
)
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(secret_id={self._secret_id!r}, '
f'env_nested_delimiter={self.env_nested_delimiter!r})'
)
__all__ = [
'AWSSecretsManagerSettingsSource',
]

View File

@@ -0,0 +1,145 @@
"""Azure Key Vault settings source."""
from __future__ import annotations as _annotations
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Optional
from pydantic.alias_generators import to_snake
from pydantic.fields import FieldInfo
from .env import EnvSettingsSource
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.keyvault.secrets import SecretClient
from pydantic_settings.main import BaseSettings
else:
TokenCredential = None
ResourceNotFoundError = None
SecretClient = None
def import_azure_key_vault() -> None:
global TokenCredential
global SecretClient
global ResourceNotFoundError
try:
from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.keyvault.secrets import SecretClient
except ImportError as e: # pragma: no cover
raise ImportError(
'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
) from e
class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
_loaded_secrets: dict[str, str | None]
_secret_client: SecretClient
_secret_names: list[str]
def __init__(
self,
secret_client: SecretClient,
case_sensitive: bool,
snake_case_conversion: bool,
) -> None:
self._loaded_secrets = {}
self._secret_client = secret_client
self._case_sensitive = case_sensitive
self._snake_case_conversion = snake_case_conversion
self._secret_map: dict[str, str] = self._load_remote()
def _load_remote(self) -> dict[str, str]:
secret_names: Iterator[str] = (
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
)
if self._snake_case_conversion:
return {to_snake(name): name for name in secret_names}
if self._case_sensitive:
return {name: name for name in secret_names}
return {name.lower(): name for name in secret_names}
def __getitem__(self, key: str) -> str | None:
new_key = key
if self._snake_case_conversion:
new_key = to_snake(key)
elif not self._case_sensitive:
new_key = key.lower()
if new_key not in self._loaded_secrets:
if new_key in self._secret_map:
self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
else:
raise KeyError(key)
return self._loaded_secrets[new_key]
def __len__(self) -> int:
return len(self._secret_map)
def __iter__(self) -> Iterator[str]:
return iter(self._secret_map.keys())
class AzureKeyVaultSettingsSource(EnvSettingsSource):
_url: str
_credential: TokenCredential
def __init__(
self,
settings_cls: type[BaseSettings],
url: str,
credential: TokenCredential,
dash_to_underscore: bool = False,
case_sensitive: bool | None = None,
snake_case_conversion: bool = False,
env_prefix: str | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
import_azure_key_vault()
self._url = url
self._credential = credential
self._dash_to_underscore = dash_to_underscore
self._snake_case_conversion = snake_case_conversion
super().__init__(
settings_cls,
case_sensitive=False if snake_case_conversion else case_sensitive,
env_prefix=env_prefix,
env_nested_delimiter='__' if snake_case_conversion else '--',
env_ignore_empty=False,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
return AzureKeyVaultMapping(
secret_client=secret_client,
case_sensitive=self.case_sensitive,
snake_case_conversion=self._snake_case_conversion,
)
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
if self._snake_case_conversion:
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
if self._dash_to_underscore:
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
return super()._extract_field_info(field, field_name)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,168 @@
"""Dotenv file settings source."""
from __future__ import annotations as _annotations
import os
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any
from dotenv import dotenv_values
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
get_origin,
)
from typing_inspection.introspection import is_union_origin
from ..types import ENV_FILE_SENTINEL, DotenvType
from ..utils import (
_annotation_is_complex,
_union_is_complex,
parse_env_vars,
)
from .env import EnvSettingsSource
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class DotEnvSettingsSource(EnvSettingsSource):
"""
Source class for loading settings values from env files.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
env_file: DotenvType | None = ENV_FILE_SENTINEL,
env_file_encoding: str | None = None,
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str | None = None,
env_nested_max_split: int | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
self.env_file_encoding = (
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
)
super().__init__(
settings_cls,
case_sensitive,
env_prefix,
env_nested_delimiter,
env_nested_max_split,
env_ignore_empty,
env_parse_none_str,
env_parse_enums,
)
def _load_env_vars(self) -> Mapping[str, str | None]:
return self._read_env_files()
@staticmethod
def _static_read_env_file(
file_path: Path,
*,
encoding: str | None = None,
case_sensitive: bool = False,
ignore_empty: bool = False,
parse_none_str: str | None = None,
) -> Mapping[str, str | None]:
file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str)
def _read_env_file(
self,
file_path: Path,
) -> Mapping[str, str | None]:
return self._static_read_env_file(
file_path,
encoding=self.env_file_encoding,
case_sensitive=self.case_sensitive,
ignore_empty=self.env_ignore_empty,
parse_none_str=self.env_parse_none_str,
)
def _read_env_files(self) -> Mapping[str, str | None]:
env_files = self.env_file
if env_files is None:
return {}
if isinstance(env_files, (str, os.PathLike)):
env_files = [env_files]
dotenv_vars: dict[str, str | None] = {}
for env_file in env_files:
env_path = Path(env_file).expanduser()
if env_path.is_file():
dotenv_vars.update(self._read_env_file(env_path))
return dotenv_vars
def __call__(self) -> dict[str, Any]:
data: dict[str, Any] = super().__call__()
is_extra_allowed = self.config.get('extra') != 'forbid'
# As `extra` config is allowed in dotenv settings source, We have to
# update data with extra env variables from dotenv file.
for env_name, env_value in self.env_vars.items():
if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields):
continue
env_used = False
for field_name, field in self.settings_cls.model_fields.items():
for _, field_env_name, _ in self._extract_field_info(field, field_name):
if env_name == field_env_name or (
(
_annotation_is_complex(field.annotation, field.metadata)
or (
is_union_origin(get_origin(field.annotation))
and _union_is_complex(field.annotation, field.metadata)
)
)
and env_name.startswith(field_env_name)
):
env_used = True
break
if env_used:
break
if not env_used:
if is_extra_allowed and env_name.startswith(self.env_prefix):
# env_prefix should be respected and removed from the env_name
normalized_env_name = env_name[len(self.env_prefix) :]
data[normalized_env_name] = env_value
else:
data[env_name] = env_value
return data
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
)
def read_env_file(
file_path: Path,
*,
encoding: str | None = None,
case_sensitive: bool = False,
ignore_empty: bool = False,
parse_none_str: str | None = None,
) -> Mapping[str, str | None]:
warnings.warn(
'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must',
DeprecationWarning,
)
return DotEnvSettingsSource._static_read_env_file(
file_path,
encoding=encoding,
case_sensitive=case_sensitive,
ignore_empty=ignore_empty,
parse_none_str=parse_none_str,
)
__all__ = ['DotEnvSettingsSource', 'read_env_file']

View File

@@ -0,0 +1,270 @@
from __future__ import annotations as _annotations
import os
from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Any,
)
from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin
from typing_inspection.introspection import is_union_origin
from ...utils import _lenient_issubclass
from ..base import PydanticBaseEnvSettingsSource
from ..types import EnvNoneType
from ..utils import (
_annotation_enum_name_to_val,
_get_model_fields,
_union_is_complex,
parse_env_vars,
)
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
"""
Source class for loading settings values from environment variables.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_nested_delimiter: str | None = None,
env_nested_max_split: int | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
super().__init__(
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
)
self.env_nested_delimiter = (
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
)
self.env_nested_max_split = (
env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
)
self.maxsplit = (self.env_nested_max_split or 0) - 1
self.env_prefix_len = len(self.env_prefix)
self.env_vars = self._load_env_vars()
def _load_env_vars(self) -> Mapping[str, str | None]:
return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value for field from environment variables and a flag to determine whether value is complex.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple that contains the value (`None` if not found), key, and
a flag to determine whether value is complex.
"""
env_val: str | None = None
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
env_val = self.env_vars.get(env_name)
if env_val is not None:
break
return env_val, field_key, value_is_complex
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Prepare value for the field.
* Extract value for nested field.
* Deserialize value to python object for complex field.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple contains prepared value for the field.
Raises:
ValuesError: When There is an error in deserializing value for complex field.
"""
is_complex, allow_parse_failure = self._field_is_complex(field)
if self.env_parse_enums:
enum_val = _annotation_enum_name_to_val(field.annotation, value)
value = value if enum_val is None else enum_val
if is_complex or value_is_complex:
if isinstance(value, EnvNoneType):
return value
elif value is None:
# field is complex but no value found so far, try explode_env_vars
env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
if env_val_built:
return env_val_built
else:
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
try:
value = self.decode_complex_value(field_name, field, value)
except ValueError as e:
if not allow_parse_failure:
raise e
if isinstance(value, dict):
return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
else:
return value
elif value is not None:
# simplest case, field is not complex, we only need to add the value if it was found
return value
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if self.field_is_complex(field):
allow_parse_failure = False
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
allow_parse_failure = True
else:
return False, False
return True, allow_parse_failure
# Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
# We have to change the method to a non-static method and use
# `self.case_sensitive` instead in V3.
def next_field(
self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
) -> FieldInfo | None:
"""
Find the field in a sub model by key(env name)
By having the following models:
```py
class SubSubModel(BaseSettings):
dvals: Dict
class SubModel(BaseSettings):
vals: list[str]
sub_sub_model: SubSubModel
class Cfg(BaseSettings):
sub_model: SubModel
```
Then:
next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class
Args:
field: The field.
key: The key (env name).
case_sensitive: Whether to search for key case sensitively.
Returns:
Field if it finds the next field otherwise `None`.
"""
if not field:
return None
annotation = field.annotation if isinstance(field, FieldInfo) else field
for type_ in get_args(annotation):
type_has_key = self.next_field(type_, key, case_sensitive)
if type_has_key:
return type_has_key
if is_model_class(annotation) or is_pydantic_dataclass(annotation): # type: ignore[arg-type]
fields = _get_model_fields(annotation)
# `case_sensitive is None` is here to be compatible with the old behavior.
# Has to be removed in V3.
for field_name, f in fields.items():
for _, env_name, _ in self._extract_field_info(f, field_name):
if case_sensitive is None or case_sensitive:
if field_name == key or env_name == key:
return f
elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
return f
return None
def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
This is applied to a single field, hence filtering by env_var prefix.
Args:
field_name: The field name.
field: The field.
env_vars: Environment variables.
Returns:
A dictionary contains extracted values from nested env values.
"""
if not self.env_nested_delimiter:
return {}
ann = field.annotation
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
prefixes = [
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
]
result: dict[str, Any] = {}
for env_name, env_val in env_vars.items():
try:
prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
except StopIteration:
continue
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
env_name_without_prefix = env_name[len(prefix) :]
*keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
env_var = result
target_field: FieldInfo | None = field
for key in keys:
target_field = self.next_field(target_field, key, self.case_sensitive)
if isinstance(env_var, dict):
env_var = env_var.setdefault(key, {})
# get proper field with last_key
target_field = self.next_field(target_field, last_key, self.case_sensitive)
# check if env_val maps to a complex field and if so, parse the env_val
if (target_field or is_dict) and env_val:
if target_field:
is_complex, allow_json_failure = self._field_is_complex(target_field)
if self.env_parse_enums:
enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
env_val = env_val if enum_val is None else enum_val
else:
# nested field type is dict
is_complex, allow_json_failure = True, True
if is_complex:
try:
env_val = self.decode_complex_value(last_key, target_field, env_val) # type: ignore
except ValueError as e:
if not allow_json_failure:
raise e
if isinstance(env_var, dict):
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
env_var[last_key] = env_val
return result
def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
f'env_prefix_len={self.env_prefix_len!r})'
)
__all__ = ['EnvSettingsSource']

View File

@@ -0,0 +1,152 @@
from __future__ import annotations as _annotations
from collections.abc import Iterator, Mapping
from functools import cached_property
from typing import TYPE_CHECKING, Optional
from .env import EnvSettingsSource
if TYPE_CHECKING:
from google.auth import default as google_auth_default
from google.auth.credentials import Credentials
from google.cloud.secretmanager import SecretManagerServiceClient
from pydantic_settings.main import BaseSettings
else:
Credentials = None
SecretManagerServiceClient = None
google_auth_default = None
def import_gcp_secret_manager() -> None:
global Credentials
global SecretManagerServiceClient
global google_auth_default
try:
from google.auth import default as google_auth_default
from google.auth.credentials import Credentials
from google.cloud.secretmanager import SecretManagerServiceClient
except ImportError as e: # pragma: no cover
raise ImportError(
'GCP Secret Manager dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`'
) from e
class GoogleSecretManagerMapping(Mapping[str, Optional[str]]):
_loaded_secrets: dict[str, str | None]
_secret_client: SecretManagerServiceClient
def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None:
self._loaded_secrets = {}
self._secret_client = secret_client
self._project_id = project_id
self._case_sensitive = case_sensitive
@property
def _gcp_project_path(self) -> str:
return self._secret_client.common_project_path(self._project_id)
@cached_property
def _secret_names(self) -> list[str]:
rv: list[str] = []
secrets = self._secret_client.list_secrets(parent=self._gcp_project_path)
for secret in secrets:
name = self._secret_client.parse_secret_path(secret.name).get('secret', '')
if not self._case_sensitive:
name = name.lower()
rv.append(name)
return rv
def _secret_version_path(self, key: str, version: str = 'latest') -> str:
return self._secret_client.secret_version_path(self._project_id, key, version)
def __getitem__(self, key: str) -> str | None:
if not self._case_sensitive:
key = key.lower()
if key not in self._loaded_secrets:
# If we know the key isn't available in secret manager, raise a key error
if key not in self._secret_names:
raise KeyError(key)
try:
self._loaded_secrets[key] = self._secret_client.access_secret_version(
name=self._secret_version_path(key)
).payload.data.decode('UTF-8')
except Exception:
raise KeyError(key)
return self._loaded_secrets[key]
def __len__(self) -> int:
return len(self._secret_names)
def __iter__(self) -> Iterator[str]:
return iter(self._secret_names)
class GoogleSecretManagerSettingsSource(EnvSettingsSource):
_credentials: Credentials
_secret_client: SecretManagerServiceClient
_project_id: str
def __init__(
self,
settings_cls: type[BaseSettings],
credentials: Credentials | None = None,
project_id: str | None = None,
env_prefix: str | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
secret_client: SecretManagerServiceClient | None = None,
case_sensitive: bool | None = True,
) -> None:
# Import Google Packages if they haven't already been imported
if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None:
import_gcp_secret_manager()
# If credentials or project_id are not passed, then
# try to get them from the default function
if not credentials or not project_id:
_creds, _project_id = google_auth_default() # type: ignore[no-untyped-call]
# Set the credentials and/or project id if they weren't specified
if credentials is None:
credentials = _creds
if project_id is None:
if isinstance(_project_id, str):
project_id = _project_id
else:
raise AttributeError(
'project_id is required to be specified either as an argument or from the google.auth.default. See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default'
)
self._credentials: Credentials = credentials
self._project_id: str = project_id
if secret_client:
self._secret_client = secret_client
else:
self._secret_client = SecretManagerServiceClient(credentials=self._credentials)
super().__init__(
settings_cls,
case_sensitive=case_sensitive,
env_prefix=env_prefix,
env_ignore_empty=False,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
return GoogleSecretManagerMapping(
self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive
)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping']

View File

@@ -0,0 +1,47 @@
"""JSON file settings source."""
from __future__ import annotations as _annotations
import json
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from ..base import ConfigFileSourceMixin, InitSettingsSource
from ..types import DEFAULT_PATH, PathType
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
"""
A source class that loads variables from a JSON file
"""
def __init__(
self,
settings_cls: type[BaseSettings],
json_file: PathType | None = DEFAULT_PATH,
json_file_encoding: str | None = None,
):
self.json_file_path = json_file if json_file != DEFAULT_PATH else settings_cls.model_config.get('json_file')
self.json_file_encoding = (
json_file_encoding
if json_file_encoding is not None
else settings_cls.model_config.get('json_file_encoding')
)
self.json_data = self._read_files(self.json_file_path)
super().__init__(settings_cls, self.json_data)
def _read_file(self, file_path: Path) -> dict[str, Any]:
with open(file_path, encoding=self.json_file_encoding) as json_file:
return json.load(json_file)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(json_file={self.json_file_path})'
__all__ = ['JsonConfigSettingsSource']

View File

@@ -0,0 +1,62 @@
"""Pyproject TOML file settings source."""
from __future__ import annotations as _annotations
from pathlib import Path
from typing import (
TYPE_CHECKING,
)
from .toml import TomlConfigSettingsSource
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource):
"""
A source class that loads variables from a `pyproject.toml` file.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
toml_file: Path | None = None,
) -> None:
self.toml_file_path = self._pick_pyproject_toml_file(
toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0)
)
self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
'pyproject_toml_table_header', ('tool', 'pydantic-settings')
)
self.toml_data = self._read_files(self.toml_file_path)
for key in self.toml_table_header:
self.toml_data = self.toml_data.get(key, {})
super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data)
@staticmethod
def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path:
"""Pick a `pyproject.toml` file path to use.
Args:
provided: Explicit path provided when instantiating this class.
depth: Number of directories up the tree to check of a pyproject.toml.
"""
if provided:
return provided.resolve()
rv = Path.cwd() / 'pyproject.toml'
count = 0
if not rv.is_file():
child = rv.parent.parent / 'pyproject.toml'
while count < depth:
if child.is_file():
return child
if str(child.parent) == rv.root:
break # end discovery after checking system root once
child = child.parent.parent / 'pyproject.toml'
count += 1
return rv
__all__ = ['PyprojectTomlConfigSettingsSource']

View File

@@ -0,0 +1,125 @@
"""Secrets file settings source."""
from __future__ import annotations as _annotations
import os
import warnings
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from pydantic.fields import FieldInfo
from pydantic_settings.utils import path_type_label
from ...exceptions import SettingsError
from ..base import PydanticBaseEnvSettingsSource
from ..types import PathType
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
"""
Source class for loading settings values from secret files.
"""
def __init__(
self,
settings_cls: type[BaseSettings],
secrets_dir: PathType | None = None,
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_ignore_empty: bool | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
super().__init__(
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
)
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')
def __call__(self) -> dict[str, Any]:
"""
Build fields from "secrets" files.
"""
secrets: dict[str, str | None] = {}
if self.secrets_dir is None:
return secrets
secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir
secrets_paths = [Path(p).expanduser() for p in secrets_dirs]
self.secrets_paths = []
for path in secrets_paths:
if not path.exists():
warnings.warn(f'directory "{path}" does not exist')
else:
self.secrets_paths.append(path)
if not len(self.secrets_paths):
return secrets
for path in self.secrets_paths:
if not path.is_dir():
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')
return super().__call__()
@classmethod
def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
"""
Find a file within path's directory matching filename, optionally ignoring case.
Args:
dir_path: Directory path.
file_name: File name.
case_sensitive: Whether to search for file name case sensitively.
Returns:
Whether file path or `None` if file does not exist in directory.
"""
for f in dir_path.iterdir():
if f.name == file_name:
return f
elif not case_sensitive and f.name.lower() == file_name.lower():
return f
return None
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
"""
Gets the value for field from secret file and a flag to determine whether value is complex.
Args:
field: The field.
field_name: The field name.
Returns:
A tuple that contains the value (`None` if the file does not exist), key, and
a flag to determine whether value is complex.
"""
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
# paths reversed to match the last-wins behaviour of `env_file`
for secrets_path in reversed(self.secrets_paths):
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
if not path:
# path does not exist, we currently don't return a warning for this
continue
if path.is_file():
return path.read_text().strip(), field_key, value_is_complex
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
stacklevel=4,
)
return None, field_key, value_is_complex
def __repr__(self) -> str:
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'

View File

@@ -0,0 +1,66 @@
"""TOML file settings source."""
from __future__ import annotations as _annotations
import sys
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from ..base import ConfigFileSourceMixin, InitSettingsSource
from ..types import DEFAULT_PATH, PathType
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
if sys.version_info >= (3, 11):
import tomllib
else:
tomllib = None
import tomli
else:
tomllib = None
tomli = None
def import_toml() -> None:
global tomli
global tomllib
if sys.version_info < (3, 11):
if tomli is not None:
return
try:
import tomli
except ImportError as e: # pragma: no cover
raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from e
else:
if tomllib is not None:
return
import tomllib
class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
"""
A source class that loads variables from a TOML file
"""
def __init__(
self,
settings_cls: type[BaseSettings],
toml_file: PathType | None = DEFAULT_PATH,
):
self.toml_file_path = toml_file if toml_file != DEFAULT_PATH else settings_cls.model_config.get('toml_file')
self.toml_data = self._read_files(self.toml_file_path)
super().__init__(settings_cls, self.toml_data)
def _read_file(self, file_path: Path) -> dict[str, Any]:
import_toml()
with open(file_path, mode='rb') as toml_file:
if sys.version_info < (3, 11):
return tomli.load(toml_file)
return tomllib.load(toml_file)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(toml_file={self.toml_file_path})'

View File

@@ -0,0 +1,75 @@
"""YAML file settings source."""
from __future__ import annotations as _annotations
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from ..base import ConfigFileSourceMixin, InitSettingsSource
from ..types import DEFAULT_PATH, PathType
if TYPE_CHECKING:
import yaml
from pydantic_settings.main import BaseSettings
else:
yaml = None
def import_yaml() -> None:
global yaml
if yaml is not None:
return
try:
import yaml
except ImportError as e:
raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e
class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
"""
A source class that loads variables from a yaml file
"""
def __init__(
self,
settings_cls: type[BaseSettings],
yaml_file: PathType | None = DEFAULT_PATH,
yaml_file_encoding: str | None = None,
yaml_config_section: str | None = None,
):
self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file')
self.yaml_file_encoding = (
yaml_file_encoding
if yaml_file_encoding is not None
else settings_cls.model_config.get('yaml_file_encoding')
)
self.yaml_config_section = (
yaml_config_section
if yaml_config_section is not None
else settings_cls.model_config.get('yaml_config_section')
)
self.yaml_data = self._read_files(self.yaml_file_path)
if self.yaml_config_section:
try:
self.yaml_data = self.yaml_data[self.yaml_config_section]
except KeyError:
raise KeyError(
f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
)
super().__init__(settings_cls, self.yaml_data)
def _read_file(self, file_path: Path) -> dict[str, Any]:
import_yaml()
with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
return yaml.safe_load(yaml_file) or {}
def __repr__(self) -> str:
return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
__all__ = ['YamlConfigSettingsSource']

View File

@@ -0,0 +1,78 @@
"""Type definitions for pydantic-settings sources."""
from __future__ import annotations as _annotations
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union
if TYPE_CHECKING:
from pydantic._internal._dataclasses import PydanticDataclass
from pydantic.main import BaseModel
PydanticModel = Union[PydanticDataclass, BaseModel]
else:
PydanticModel = Any
class EnvNoneType(str):
pass
class NoDecode:
"""Annotation to prevent decoding of a field value."""
pass
class ForceDecode:
"""Annotation to force decoding of a field value."""
pass
DotenvType = Union[Path, str, Sequence[Union[Path, str]]]
PathType = Union[Path, str, Sequence[Union[Path, str]]]
DEFAULT_PATH: PathType = Path('')
# This is used as default value for `_env_file` in the `BaseSettings` class and
# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`.
# See the docstring of `BaseSettings` for more details.
ENV_FILE_SENTINEL: DotenvType = Path('')
class _CliSubCommand:
pass
class _CliPositionalArg:
pass
class _CliImplicitFlag:
pass
class _CliExplicitFlag:
pass
class _CliUnknownArgs:
pass
__all__ = [
'DEFAULT_PATH',
'ENV_FILE_SENTINEL',
'DotenvType',
'EnvNoneType',
'ForceDecode',
'NoDecode',
'PathType',
'PydanticModel',
'_CliExplicitFlag',
'_CliImplicitFlag',
'_CliPositionalArg',
'_CliSubCommand',
'_CliUnknownArgs',
]

View File

@@ -0,0 +1,206 @@
"""Utility functions for pydantic-settings sources."""
from __future__ import annotations as _annotations
from collections import deque
from collections.abc import Mapping, Sequence
from dataclasses import is_dataclass
from enum import Enum
from typing import Any, Optional, cast
from pydantic import BaseModel, Json, RootModel, Secret
from pydantic._internal._utils import is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from typing_extensions import get_args, get_origin
from typing_inspection import typing_objects
from ..exceptions import SettingsError
from ..utils import _lenient_issubclass
from .types import EnvNoneType
def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
return key if case_sensitive else key.lower()
def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType:
return value if not (value == parse_none_str and parse_none_str is not None) else EnvNoneType(value)
def parse_env_vars(
env_vars: Mapping[str, str | None],
case_sensitive: bool = False,
ignore_empty: bool = False,
parse_none_str: str | None = None,
) -> Mapping[str, str | None]:
return {
_get_env_var_key(k, case_sensitive): _parse_env_none_str(v, parse_none_str)
for k, v in env_vars.items()
if not (ignore_empty and v == '')
}
def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool:
# If the model is a root model, the root annotation should be used to
# evaluate the complexity.
if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
annotation = annotation.__value__
if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
annotation = cast('type[RootModel[Any]]', annotation)
root_annotation = annotation.model_fields['root'].annotation
if root_annotation is not None: # pragma: no branch
annotation = root_annotation
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
return False
origin = get_origin(annotation)
# Check if annotation is of the form Annotated[type, metadata].
if typing_objects.is_annotated(origin):
# Return result of recursive call on inner type.
inner, *meta = get_args(annotation)
return _annotation_is_complex(inner, meta)
if origin is Secret:
return False
return (
_annotation_is_complex_inner(annotation)
or _annotation_is_complex_inner(origin)
or hasattr(origin, '__pydantic_core_schema__')
or hasattr(origin, '__get_pydantic_core_schema__')
)
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
if _lenient_issubclass(annotation, (str, bytes)):
return False
return _lenient_issubclass(
annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
) or is_dataclass(annotation)
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
"""Check if a union type contains any complex types."""
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
def _annotation_contains_types(
annotation: type[Any] | None,
types: tuple[Any, ...],
is_include_origin: bool = True,
is_strip_annotated: bool = False,
) -> bool:
"""Check if a type annotation contains any of the specified types."""
if is_strip_annotated:
annotation = _strip_annotated(annotation)
if is_include_origin is True and get_origin(annotation) in types:
return True
for type_ in get_args(annotation):
if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated):
return True
return annotation in types
def _strip_annotated(annotation: Any) -> Any:
if typing_objects.is_annotated(get_origin(annotation)):
return annotation.__origin__
else:
return annotation
def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]:
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
if _lenient_issubclass(type_, Enum):
if value in tuple(val.value for val in type_):
return type_(value).name
return None
def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
if _lenient_issubclass(type_, Enum):
if name in tuple(val.name for val in type_):
return type_[name]
return None
def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]:
"""Get fields from a pydantic model or dataclass."""
if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
return model_cls.__pydantic_fields__
if is_model_class(model_cls):
return model_cls.model_fields
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
def _get_alias_names(
field_name: str,
field_info: Any,
alias_path_args: Optional[dict[str, Optional[int]]] = None,
case_sensitive: bool = True,
) -> tuple[tuple[str, ...], bool]:
"""Get alias names for a field, handling alias paths and case sensitivity."""
from pydantic import AliasChoices, AliasPath
alias_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
alias_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
alias_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
alias_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
else:
new_alias_paths.append(alias)
for alias_path in new_alias_paths:
name = cast(str, alias_path.path[0])
name = name.lower() if not case_sensitive else name
if alias_path_args is not None:
alias_path_args[name] = (
alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None
)
if not alias_names and is_alias_path_only:
alias_names.append(name)
if not case_sensitive:
alias_names = [alias_name.lower() for alias_name in alias_names]
return tuple(dict.fromkeys(alias_names)), is_alias_path_only
def _is_function(obj: Any) -> bool:
"""Check if an object is a function."""
from types import BuiltinFunctionType, FunctionType
return isinstance(obj, (FunctionType, BuiltinFunctionType))
__all__ = [
'_annotation_contains_types',
'_annotation_enum_name_to_val',
'_annotation_enum_val_to_name',
'_annotation_is_complex',
'_annotation_is_complex_inner',
'_get_alias_names',
'_get_env_var_key',
'_get_model_fields',
'_is_function',
'_parse_env_none_str',
'_strip_annotated',
'_union_is_complex',
'parse_env_vars',
]

View File

@@ -1,14 +1,19 @@
import sys
import types
from pathlib import Path
from typing import Any, _GenericAlias # type: ignore [attr-defined]
path_type_labels = {
'is_dir': 'directory',
'is_file': 'file',
'is_mount': 'mount point',
'is_symlink': 'symlink',
'is_block_device': 'block device',
'is_char_device': 'char device',
'is_fifo': 'FIFO',
'is_socket': 'socket',
from typing_extensions import get_origin
_PATH_TYPE_LABELS = {
Path.is_dir: 'directory',
Path.is_file: 'file',
Path.is_mount: 'mount point',
Path.is_symlink: 'symlink',
Path.is_block_device: 'block device',
Path.is_char_device: 'char device',
Path.is_fifo: 'FIFO',
Path.is_socket: 'socket',
}
@@ -17,8 +22,27 @@ def path_type_label(p: Path) -> str:
Find out what sort of thing a path is.
"""
assert p.exists(), 'path does not exist'
for method, name in path_type_labels.items():
if getattr(p, method)():
for method, name in _PATH_TYPE_LABELS.items():
if method(p):
return name
return 'unknown'
return 'unknown' # pragma: no cover
# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
# once we drop support for Python 3.10.
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if get_origin(cls) is not None:
# Up until Python 3.10, isinstance(<generic_alias>, type) is True
# (e.g. list[int])
return False
raise
if sys.version_info < (3, 10):
_WithArgsTypes = tuple()
else:
_WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType)

View File

@@ -1 +1 @@
VERSION = '2.0.3'
VERSION = '2.11.0'