This commit is contained in:
@@ -1,22 +1,63 @@
|
||||
from .main import BaseSettings, SettingsConfigDict
|
||||
from .exceptions import SettingsError
|
||||
from .main import BaseSettings, CliApp, SettingsConfigDict
|
||||
from .sources import (
|
||||
CLI_SUPPRESS,
|
||||
AWSSecretsManagerSettingsSource,
|
||||
AzureKeyVaultSettingsSource,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliUnknownArgs,
|
||||
DotEnvSettingsSource,
|
||||
EnvSettingsSource,
|
||||
ForceDecode,
|
||||
GoogleSecretManagerSettingsSource,
|
||||
InitSettingsSource,
|
||||
JsonConfigSettingsSource,
|
||||
NoDecode,
|
||||
PydanticBaseSettingsSource,
|
||||
PyprojectTomlConfigSettingsSource,
|
||||
SecretsSettingsSource,
|
||||
TomlConfigSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .version import VERSION
|
||||
|
||||
__all__ = (
|
||||
'CLI_SUPPRESS',
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'BaseSettings',
|
||||
'CliApp',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'CliUnknownArgs',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'ForceDecode',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'NoDecode',
|
||||
'PydanticBaseSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'SettingsConfigDict',
|
||||
'SettingsError',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
'__version__',
|
||||
'get_subcommand',
|
||||
)
|
||||
|
||||
__version__ = VERSION
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
class SettingsError(ValueError):
|
||||
"""Base exception for settings-related errors."""
|
||||
|
||||
pass
|
||||
@@ -1,31 +1,104 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar
|
||||
import asyncio
|
||||
import inspect
|
||||
import threading
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, ClassVar, TypeVar
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic._internal._config import config_keys
|
||||
from pydantic._internal._utils import deep_update
|
||||
from pydantic._internal._signature import _field_name_for_signature
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from .exceptions import SettingsError
|
||||
from .sources import (
|
||||
ENV_FILE_SENTINEL,
|
||||
CliSettingsSource,
|
||||
DefaultSettingsSource,
|
||||
DotEnvSettingsSource,
|
||||
DotenvType,
|
||||
EnvSettingsSource,
|
||||
InitSettingsSource,
|
||||
JsonConfigSettingsSource,
|
||||
PathType,
|
||||
PydanticBaseSettingsSource,
|
||||
PydanticModel,
|
||||
PyprojectTomlConfigSettingsSource,
|
||||
SecretsSettingsSource,
|
||||
TomlConfigSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class SettingsConfigDict(ConfigDict, total=False):
|
||||
case_sensitive: bool
|
||||
nested_model_default_partial_update: bool | None
|
||||
env_prefix: str
|
||||
env_file: DotenvType | None
|
||||
env_file_encoding: str | None
|
||||
env_ignore_empty: bool
|
||||
env_nested_delimiter: str | None
|
||||
secrets_dir: str | Path | None
|
||||
env_nested_max_split: int | None
|
||||
env_parse_none_str: str | None
|
||||
env_parse_enums: bool | None
|
||||
cli_prog_name: str | None
|
||||
cli_parse_args: bool | list[str] | tuple[str, ...] | None
|
||||
cli_parse_none_str: str | None
|
||||
cli_hide_none_type: bool
|
||||
cli_avoid_json: bool
|
||||
cli_enforce_required: bool
|
||||
cli_use_class_docs_for_groups: bool
|
||||
cli_exit_on_error: bool
|
||||
cli_prefix: str
|
||||
cli_flag_prefix_char: str
|
||||
cli_implicit_flags: bool | None
|
||||
cli_ignore_unknown_args: bool | None
|
||||
cli_kebab_case: bool | None
|
||||
cli_shortcuts: Mapping[str, str | list[str]] | None
|
||||
secrets_dir: PathType | None
|
||||
json_file: PathType | None
|
||||
json_file_encoding: str | None
|
||||
yaml_file: PathType | None
|
||||
yaml_file_encoding: str | None
|
||||
yaml_config_section: str | None
|
||||
"""
|
||||
Specifies the top-level key in a YAML file from which to load the settings.
|
||||
If provided, the settings will be loaded from the nested section under this key.
|
||||
This is useful when the YAML file contains multiple configuration sections
|
||||
and you only want to load a specific subset into your settings model.
|
||||
"""
|
||||
|
||||
pyproject_toml_depth: int
|
||||
"""
|
||||
Number of levels **up** from the current working directory to attempt to find a pyproject.toml
|
||||
file.
|
||||
|
||||
This is only used when a pyproject.toml file is not found in the current working directory.
|
||||
"""
|
||||
|
||||
pyproject_toml_table_header: tuple[str, ...]
|
||||
"""
|
||||
Header of the TOML table within a pyproject.toml file to use when filling variables.
|
||||
This is supplied as a `tuple[str, ...]` instead of a `str` to accommodate for headers
|
||||
containing a `.`.
|
||||
|
||||
For example, `toml_table_header = ("tool", "my.tool", "foo")` can be used to fill variable
|
||||
values from a table with header `[tool."my.tool".foo]`.
|
||||
|
||||
To use the root table, exclude this config setting or provide an empty tuple.
|
||||
"""
|
||||
|
||||
toml_file: PathType | None
|
||||
enable_decoding: bool
|
||||
|
||||
|
||||
# Extend `config_keys` by pydantic settings config keys to
|
||||
@@ -47,35 +120,104 @@ class BaseSettings(BaseModel):
|
||||
All the below attributes can be set via `model_config`.
|
||||
|
||||
Args:
|
||||
_case_sensitive: Whether environment variables names should be read with case-sensitivity. Defaults to `None`.
|
||||
_case_sensitive: Whether environment and CLI variable names should be read with case-sensitivity.
|
||||
Defaults to `None`.
|
||||
_nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
|
||||
Defaults to `False`.
|
||||
_env_prefix: Prefix for all environment variables. Defaults to `None`.
|
||||
_env_file: The env file(s) to load settings values from. Defaults to `Path('')`, which
|
||||
means that the value from `model_config['env_file']` should be used. You can also pass
|
||||
`None` to indicate that environment variables should not be loaded from an env file.
|
||||
_env_file_encoding: The env file encoding, e.g. `'latin-1'`. Defaults to `None`.
|
||||
_env_ignore_empty: Ignore environment variables where the value is an empty string. Default to `False`.
|
||||
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
|
||||
_secrets_dir: The secret files directory. Defaults to `None`.
|
||||
_env_nested_max_split: The nested env values maximum nesting. Defaults to `None`, which means no limit.
|
||||
_env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.)
|
||||
into `None` type(None). Defaults to `None` type(None), which means no parsing should occur.
|
||||
_env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur.
|
||||
_cli_prog_name: The CLI program name to display in help text. Defaults to `None` if _cli_parse_args is `None`.
|
||||
Otherwise, defaults to sys.argv[0].
|
||||
_cli_parse_args: The list of CLI arguments to parse. Defaults to None.
|
||||
If set to `True`, defaults to sys.argv[1:].
|
||||
_cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to None.
|
||||
_cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into
|
||||
`None` type(None). Defaults to _env_parse_none_str value if set. Otherwise, defaults to "null" if
|
||||
_cli_avoid_json is `False`, and "None" if _cli_avoid_json is `True`.
|
||||
_cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`.
|
||||
_cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`.
|
||||
_cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`.
|
||||
_cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions.
|
||||
Defaults to `False`.
|
||||
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
|
||||
Defaults to `True`.
|
||||
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
|
||||
_cli_flag_prefix_char: The flag prefix character to use for CLI optional arguments. Defaults to '-'.
|
||||
_cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags.
|
||||
(e.g. --flag, --no-flag). Defaults to `False`.
|
||||
_cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`.
|
||||
_cli_kebab_case: CLI args use kebab case. Defaults to `False`.
|
||||
_cli_shortcuts: Mapping of target field name to alias names. Defaults to `None`.
|
||||
_secrets_dir: The secret files directory or a sequence of directories. Defaults to `None`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
__pydantic_self__,
|
||||
_case_sensitive: bool | None = None,
|
||||
_nested_model_default_partial_update: bool | None = None,
|
||||
_env_prefix: str | None = None,
|
||||
_env_file: DotenvType | None = ENV_FILE_SENTINEL,
|
||||
_env_file_encoding: str | None = None,
|
||||
_env_ignore_empty: bool | None = None,
|
||||
_env_nested_delimiter: str | None = None,
|
||||
_secrets_dir: str | Path | None = None,
|
||||
_env_nested_max_split: int | None = None,
|
||||
_env_parse_none_str: str | None = None,
|
||||
_env_parse_enums: bool | None = None,
|
||||
_cli_prog_name: str | None = None,
|
||||
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
|
||||
_cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
_cli_parse_none_str: str | None = None,
|
||||
_cli_hide_none_type: bool | None = None,
|
||||
_cli_avoid_json: bool | None = None,
|
||||
_cli_enforce_required: bool | None = None,
|
||||
_cli_use_class_docs_for_groups: bool | None = None,
|
||||
_cli_exit_on_error: bool | None = None,
|
||||
_cli_prefix: str | None = None,
|
||||
_cli_flag_prefix_char: str | None = None,
|
||||
_cli_implicit_flags: bool | None = None,
|
||||
_cli_ignore_unknown_args: bool | None = None,
|
||||
_cli_kebab_case: bool | None = None,
|
||||
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
|
||||
_secrets_dir: PathType | None = None,
|
||||
**values: Any,
|
||||
) -> None:
|
||||
# Uses something other than `self` the first arg to allow "self" as a settable attribute
|
||||
super().__init__(
|
||||
**__pydantic_self__._settings_build_values(
|
||||
values,
|
||||
_case_sensitive=_case_sensitive,
|
||||
_nested_model_default_partial_update=_nested_model_default_partial_update,
|
||||
_env_prefix=_env_prefix,
|
||||
_env_file=_env_file,
|
||||
_env_file_encoding=_env_file_encoding,
|
||||
_env_ignore_empty=_env_ignore_empty,
|
||||
_env_nested_delimiter=_env_nested_delimiter,
|
||||
_env_nested_max_split=_env_nested_max_split,
|
||||
_env_parse_none_str=_env_parse_none_str,
|
||||
_env_parse_enums=_env_parse_enums,
|
||||
_cli_prog_name=_cli_prog_name,
|
||||
_cli_parse_args=_cli_parse_args,
|
||||
_cli_settings_source=_cli_settings_source,
|
||||
_cli_parse_none_str=_cli_parse_none_str,
|
||||
_cli_hide_none_type=_cli_hide_none_type,
|
||||
_cli_avoid_json=_cli_avoid_json,
|
||||
_cli_enforce_required=_cli_enforce_required,
|
||||
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
|
||||
_cli_exit_on_error=_cli_exit_on_error,
|
||||
_cli_prefix=_cli_prefix,
|
||||
_cli_flag_prefix_char=_cli_flag_prefix_char,
|
||||
_cli_implicit_flags=_cli_implicit_flags,
|
||||
_cli_ignore_unknown_args=_cli_ignore_unknown_args,
|
||||
_cli_kebab_case=_cli_kebab_case,
|
||||
_cli_shortcuts=_cli_shortcuts,
|
||||
_secrets_dir=_secrets_dir,
|
||||
)
|
||||
)
|
||||
@@ -108,33 +250,125 @@ class BaseSettings(BaseModel):
|
||||
self,
|
||||
init_kwargs: dict[str, Any],
|
||||
_case_sensitive: bool | None = None,
|
||||
_nested_model_default_partial_update: bool | None = None,
|
||||
_env_prefix: str | None = None,
|
||||
_env_file: DotenvType | None = None,
|
||||
_env_file_encoding: str | None = None,
|
||||
_env_ignore_empty: bool | None = None,
|
||||
_env_nested_delimiter: str | None = None,
|
||||
_secrets_dir: str | Path | None = None,
|
||||
_env_nested_max_split: int | None = None,
|
||||
_env_parse_none_str: str | None = None,
|
||||
_env_parse_enums: bool | None = None,
|
||||
_cli_prog_name: str | None = None,
|
||||
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
|
||||
_cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
_cli_parse_none_str: str | None = None,
|
||||
_cli_hide_none_type: bool | None = None,
|
||||
_cli_avoid_json: bool | None = None,
|
||||
_cli_enforce_required: bool | None = None,
|
||||
_cli_use_class_docs_for_groups: bool | None = None,
|
||||
_cli_exit_on_error: bool | None = None,
|
||||
_cli_prefix: str | None = None,
|
||||
_cli_flag_prefix_char: str | None = None,
|
||||
_cli_implicit_flags: bool | None = None,
|
||||
_cli_ignore_unknown_args: bool | None = None,
|
||||
_cli_kebab_case: bool | None = None,
|
||||
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
|
||||
_secrets_dir: PathType | None = None,
|
||||
) -> dict[str, Any]:
|
||||
# Determine settings config values
|
||||
case_sensitive = _case_sensitive if _case_sensitive is not None else self.model_config.get('case_sensitive')
|
||||
env_prefix = _env_prefix if _env_prefix is not None else self.model_config.get('env_prefix')
|
||||
nested_model_default_partial_update = (
|
||||
_nested_model_default_partial_update
|
||||
if _nested_model_default_partial_update is not None
|
||||
else self.model_config.get('nested_model_default_partial_update')
|
||||
)
|
||||
env_file = _env_file if _env_file != ENV_FILE_SENTINEL else self.model_config.get('env_file')
|
||||
env_file_encoding = (
|
||||
_env_file_encoding if _env_file_encoding is not None else self.model_config.get('env_file_encoding')
|
||||
)
|
||||
env_ignore_empty = (
|
||||
_env_ignore_empty if _env_ignore_empty is not None else self.model_config.get('env_ignore_empty')
|
||||
)
|
||||
env_nested_delimiter = (
|
||||
_env_nested_delimiter
|
||||
if _env_nested_delimiter is not None
|
||||
else self.model_config.get('env_nested_delimiter')
|
||||
)
|
||||
env_nested_max_split = (
|
||||
_env_nested_max_split
|
||||
if _env_nested_max_split is not None
|
||||
else self.model_config.get('env_nested_max_split')
|
||||
)
|
||||
env_parse_none_str = (
|
||||
_env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str')
|
||||
)
|
||||
env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums')
|
||||
|
||||
cli_prog_name = _cli_prog_name if _cli_prog_name is not None else self.model_config.get('cli_prog_name')
|
||||
cli_parse_args = _cli_parse_args if _cli_parse_args is not None else self.model_config.get('cli_parse_args')
|
||||
cli_settings_source = (
|
||||
_cli_settings_source if _cli_settings_source is not None else self.model_config.get('cli_settings_source')
|
||||
)
|
||||
cli_parse_none_str = (
|
||||
_cli_parse_none_str if _cli_parse_none_str is not None else self.model_config.get('cli_parse_none_str')
|
||||
)
|
||||
cli_parse_none_str = cli_parse_none_str if not env_parse_none_str else env_parse_none_str
|
||||
cli_hide_none_type = (
|
||||
_cli_hide_none_type if _cli_hide_none_type is not None else self.model_config.get('cli_hide_none_type')
|
||||
)
|
||||
cli_avoid_json = _cli_avoid_json if _cli_avoid_json is not None else self.model_config.get('cli_avoid_json')
|
||||
cli_enforce_required = (
|
||||
_cli_enforce_required
|
||||
if _cli_enforce_required is not None
|
||||
else self.model_config.get('cli_enforce_required')
|
||||
)
|
||||
cli_use_class_docs_for_groups = (
|
||||
_cli_use_class_docs_for_groups
|
||||
if _cli_use_class_docs_for_groups is not None
|
||||
else self.model_config.get('cli_use_class_docs_for_groups')
|
||||
)
|
||||
cli_exit_on_error = (
|
||||
_cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error')
|
||||
)
|
||||
cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix')
|
||||
cli_flag_prefix_char = (
|
||||
_cli_flag_prefix_char
|
||||
if _cli_flag_prefix_char is not None
|
||||
else self.model_config.get('cli_flag_prefix_char')
|
||||
)
|
||||
cli_implicit_flags = (
|
||||
_cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags')
|
||||
)
|
||||
cli_ignore_unknown_args = (
|
||||
_cli_ignore_unknown_args
|
||||
if _cli_ignore_unknown_args is not None
|
||||
else self.model_config.get('cli_ignore_unknown_args')
|
||||
)
|
||||
cli_kebab_case = _cli_kebab_case if _cli_kebab_case is not None else self.model_config.get('cli_kebab_case')
|
||||
cli_shortcuts = _cli_shortcuts if _cli_shortcuts is not None else self.model_config.get('cli_shortcuts')
|
||||
|
||||
secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir')
|
||||
|
||||
# Configure built-in sources
|
||||
init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs)
|
||||
default_settings = DefaultSettingsSource(
|
||||
self.__class__, nested_model_default_partial_update=nested_model_default_partial_update
|
||||
)
|
||||
init_settings = InitSettingsSource(
|
||||
self.__class__,
|
||||
init_kwargs=init_kwargs,
|
||||
nested_model_default_partial_update=nested_model_default_partial_update,
|
||||
)
|
||||
env_settings = EnvSettingsSource(
|
||||
self.__class__,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_nested_max_split=env_nested_max_split,
|
||||
env_ignore_empty=env_ignore_empty,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
dotenv_settings = DotEnvSettingsSource(
|
||||
self.__class__,
|
||||
@@ -143,6 +377,10 @@ class BaseSettings(BaseModel):
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_nested_max_split=env_nested_max_split,
|
||||
env_ignore_empty=env_ignore_empty,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
file_secret_settings = SecretsSettingsSource(
|
||||
@@ -155,23 +393,294 @@ class BaseSettings(BaseModel):
|
||||
env_settings=env_settings,
|
||||
dotenv_settings=dotenv_settings,
|
||||
file_secret_settings=file_secret_settings,
|
||||
)
|
||||
) + (default_settings,)
|
||||
custom_cli_sources = [source for source in sources if isinstance(source, CliSettingsSource)]
|
||||
if not any(custom_cli_sources):
|
||||
if isinstance(cli_settings_source, CliSettingsSource):
|
||||
sources = (cli_settings_source,) + sources
|
||||
elif cli_parse_args is not None:
|
||||
cli_settings = CliSettingsSource[Any](
|
||||
self.__class__,
|
||||
cli_prog_name=cli_prog_name,
|
||||
cli_parse_args=cli_parse_args,
|
||||
cli_parse_none_str=cli_parse_none_str,
|
||||
cli_hide_none_type=cli_hide_none_type,
|
||||
cli_avoid_json=cli_avoid_json,
|
||||
cli_enforce_required=cli_enforce_required,
|
||||
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
|
||||
cli_exit_on_error=cli_exit_on_error,
|
||||
cli_prefix=cli_prefix,
|
||||
cli_flag_prefix_char=cli_flag_prefix_char,
|
||||
cli_implicit_flags=cli_implicit_flags,
|
||||
cli_ignore_unknown_args=cli_ignore_unknown_args,
|
||||
cli_kebab_case=cli_kebab_case,
|
||||
cli_shortcuts=cli_shortcuts,
|
||||
case_sensitive=case_sensitive,
|
||||
)
|
||||
sources = (cli_settings,) + sources
|
||||
# We ensure that if command line arguments haven't been parsed yet, we do so.
|
||||
elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars:
|
||||
custom_cli_sources[0](args=cli_parse_args) # type: ignore
|
||||
|
||||
self._settings_warn_unused_config_keys(sources, self.model_config)
|
||||
|
||||
if sources:
|
||||
return deep_update(*reversed([source() for source in sources]))
|
||||
state: dict[str, Any] = {}
|
||||
states: dict[str, dict[str, Any]] = {}
|
||||
for source in sources:
|
||||
if isinstance(source, PydanticBaseSettingsSource):
|
||||
source._set_current_state(state)
|
||||
source._set_settings_sources_data(states)
|
||||
|
||||
source_name = source.__name__ if hasattr(source, '__name__') else type(source).__name__
|
||||
source_state = source()
|
||||
|
||||
states[source_name] = source_state
|
||||
state = deep_update(source_state, state)
|
||||
return state
|
||||
else:
|
||||
# no one should mean to do this, but I think returning an empty dict is marginally preferable
|
||||
# to an informative error and much better than a confusing error
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _settings_warn_unused_config_keys(sources: tuple[object, ...], model_config: SettingsConfigDict) -> None:
|
||||
"""
|
||||
Warns if any values in model_config were set but the corresponding settings source has not been initialised.
|
||||
|
||||
The list alternative sources and their config keys can be found here:
|
||||
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source
|
||||
|
||||
Args:
|
||||
sources: The tuple of configured sources
|
||||
model_config: The model config to check for unused config keys
|
||||
"""
|
||||
|
||||
def warn_if_not_used(source_type: type[PydanticBaseSettingsSource], keys: tuple[str, ...]) -> None:
|
||||
if not any(isinstance(source, source_type) for source in sources):
|
||||
for key in keys:
|
||||
if model_config.get(key) is not None:
|
||||
warnings.warn(
|
||||
f'Config key `{key}` is set in model_config but will be ignored because no '
|
||||
f'{source_type.__name__} source is configured. To use this config key, add a '
|
||||
f'{source_type.__name__} source to the settings sources via the '
|
||||
'settings_customise_sources hook.',
|
||||
UserWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
warn_if_not_used(JsonConfigSettingsSource, ('json_file', 'json_file_encoding'))
|
||||
warn_if_not_used(PyprojectTomlConfigSettingsSource, ('pyproject_toml_depth', 'pyproject_toml_table_header'))
|
||||
warn_if_not_used(TomlConfigSettingsSource, ('toml_file',))
|
||||
warn_if_not_used(YamlConfigSettingsSource, ('yaml_file', 'yaml_file_encoding', 'yaml_config_section'))
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
extra='forbid',
|
||||
arbitrary_types_allowed=True,
|
||||
validate_default=True,
|
||||
case_sensitive=False,
|
||||
env_prefix='',
|
||||
nested_model_default_partial_update=False,
|
||||
env_file=None,
|
||||
env_file_encoding=None,
|
||||
env_ignore_empty=False,
|
||||
env_nested_delimiter=None,
|
||||
env_nested_max_split=None,
|
||||
env_parse_none_str=None,
|
||||
env_parse_enums=None,
|
||||
cli_prog_name=None,
|
||||
cli_parse_args=None,
|
||||
cli_parse_none_str=None,
|
||||
cli_hide_none_type=False,
|
||||
cli_avoid_json=False,
|
||||
cli_enforce_required=False,
|
||||
cli_use_class_docs_for_groups=False,
|
||||
cli_exit_on_error=True,
|
||||
cli_prefix='',
|
||||
cli_flag_prefix_char='-',
|
||||
cli_implicit_flags=False,
|
||||
cli_ignore_unknown_args=False,
|
||||
cli_kebab_case=False,
|
||||
cli_shortcuts=None,
|
||||
json_file=None,
|
||||
json_file_encoding=None,
|
||||
yaml_file=None,
|
||||
yaml_file_encoding=None,
|
||||
yaml_config_section=None,
|
||||
toml_file=None,
|
||||
secrets_dir=None,
|
||||
protected_namespaces=('model_', 'settings_'),
|
||||
protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'),
|
||||
enable_decoding=True,
|
||||
)
|
||||
|
||||
|
||||
class CliApp:
|
||||
"""
|
||||
A utility class for running Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as
|
||||
CLI applications.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
|
||||
if issubclass(model_cls, BaseSettings):
|
||||
return model_cls
|
||||
|
||||
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
|
||||
__doc__ = model_cls.__doc__
|
||||
model_config = SettingsConfigDict(
|
||||
nested_model_default_partial_update=True,
|
||||
case_sensitive=True,
|
||||
cli_hide_none_type=True,
|
||||
cli_avoid_json=True,
|
||||
cli_enforce_required=True,
|
||||
cli_implicit_flags=True,
|
||||
cli_kebab_case=True,
|
||||
)
|
||||
|
||||
return CliAppBaseSettings
|
||||
|
||||
@staticmethod
|
||||
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
|
||||
command = getattr(type(model), cli_cmd_method_name, None)
|
||||
if command is None:
|
||||
if is_required:
|
||||
raise SettingsError(f'Error: {type(model).__name__} class is missing {cli_cmd_method_name} entrypoint')
|
||||
return model
|
||||
|
||||
# If the method is asynchronous, we handle its execution based on the current event loop status.
|
||||
if inspect.iscoroutinefunction(command):
|
||||
# For asynchronous methods, we have two execution scenarios:
|
||||
# 1. If no event loop is running in the current thread, run the coroutine directly with asyncio.run().
|
||||
# 2. If an event loop is already running in the current thread, run the coroutine in a separate thread to avoid conflicts.
|
||||
try:
|
||||
# Check if an event loop is currently running in this thread.
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# We're in a context with an active event loop (e.g., Jupyter Notebook).
|
||||
# Running asyncio.run() here would cause conflicts, so we use a separate thread.
|
||||
exception_container = []
|
||||
|
||||
def run_coro() -> None:
|
||||
try:
|
||||
# Execute the coroutine in a new event loop in this separate thread.
|
||||
asyncio.run(command(model))
|
||||
except Exception as e:
|
||||
exception_container.append(e)
|
||||
|
||||
thread = threading.Thread(target=run_coro)
|
||||
thread.start()
|
||||
thread.join()
|
||||
if exception_container:
|
||||
# Propagate exceptions from the separate thread.
|
||||
raise exception_container[0]
|
||||
else:
|
||||
# No event loop is running; safe to run the coroutine directly.
|
||||
asyncio.run(command(model))
|
||||
else:
|
||||
# For synchronous methods, call them directly.
|
||||
command(model)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
model_cls: type[T],
|
||||
cli_args: list[str] | Namespace | SimpleNamespace | dict[str, Any] | None = None,
|
||||
cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
cli_exit_on_error: bool | None = None,
|
||||
cli_cmd_method_name: str = 'cli_cmd',
|
||||
**model_init_data: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Runs a Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as a CLI application.
|
||||
Running a model as a CLI application requires the `cli_cmd` method to be defined in the model class.
|
||||
|
||||
Args:
|
||||
model_cls: The model class to run as a CLI application.
|
||||
cli_args: The list of CLI arguments to parse. If `cli_settings_source` is specified, this may
|
||||
also be a namespace or dictionary of pre-parsed CLI arguments. Defaults to `sys.argv[1:]`.
|
||||
cli_settings_source: Override the default CLI settings source with a user defined instance.
|
||||
Defaults to `None`.
|
||||
cli_exit_on_error: Determines whether this function exits on error. If model is subclass of
|
||||
`BaseSettings`, defaults to BaseSettings `cli_exit_on_error` value. Otherwise, defaults to
|
||||
`True`.
|
||||
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
|
||||
model_init_data: The model init data.
|
||||
|
||||
Returns:
|
||||
The ran instance of model.
|
||||
|
||||
Raises:
|
||||
SettingsError: If model_cls is not subclass of `BaseModel` or `pydantic.dataclasses.dataclass`.
|
||||
SettingsError: If model_cls does not have a `cli_cmd` entrypoint defined.
|
||||
"""
|
||||
|
||||
if not (is_pydantic_dataclass(model_cls) or is_model_class(model_cls)):
|
||||
raise SettingsError(
|
||||
f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass'
|
||||
)
|
||||
|
||||
cli_settings = None
|
||||
cli_parse_args = True if cli_args is None else cli_args
|
||||
if cli_settings_source is not None:
|
||||
if isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
|
||||
cli_settings = cli_settings_source(parsed_args=cli_parse_args)
|
||||
else:
|
||||
cli_settings = cli_settings_source(args=cli_parse_args)
|
||||
elif isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
|
||||
raise SettingsError('Error: `cli_args` must be list[str] or None when `cli_settings_source` is not used')
|
||||
|
||||
model_init_data['_cli_parse_args'] = cli_parse_args
|
||||
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
|
||||
model_init_data['_cli_settings_source'] = cli_settings
|
||||
if not issubclass(model_cls, BaseSettings):
|
||||
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
|
||||
model = base_settings_cls(**model_init_data)
|
||||
model_init_data = {}
|
||||
for field_name, field_info in base_settings_cls.model_fields.items():
|
||||
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)
|
||||
|
||||
return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False)
|
||||
|
||||
@staticmethod
|
||||
def run_subcommand(
|
||||
model: PydanticModel, cli_exit_on_error: bool | None = None, cli_cmd_method_name: str = 'cli_cmd'
|
||||
) -> PydanticModel:
|
||||
"""
|
||||
Runs the model subcommand. Running a model subcommand requires the `cli_cmd` method to be defined in
|
||||
the nested model subcommand class.
|
||||
|
||||
Args:
|
||||
model: The model to run the subcommand from.
|
||||
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
|
||||
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
|
||||
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
|
||||
|
||||
Returns:
|
||||
The ran subcommand model.
|
||||
|
||||
Raises:
|
||||
SystemExit: When no subcommand is found and cli_exit_on_error=`True` (the default).
|
||||
SettingsError: When no subcommand is found and cli_exit_on_error=`False`.
|
||||
"""
|
||||
|
||||
subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error)
|
||||
return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)
|
||||
|
||||
@staticmethod
|
||||
def serialize(model: PydanticModel) -> list[str]:
|
||||
"""
|
||||
Serializes the CLI arguments for a Pydantic data model.
|
||||
|
||||
Args:
|
||||
model: The data model to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized CLI arguments for the data model.
|
||||
"""
|
||||
|
||||
base_settings_cls = CliApp._get_base_settings_cls(type(model))
|
||||
return CliSettingsSource[Any](base_settings_cls)._serialized_args(model)
|
||||
|
||||
@@ -1,653 +0,0 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from dataclasses import is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast
|
||||
|
||||
from pydantic import AliasChoices, AliasPath, BaseModel, Json
|
||||
from pydantic._internal._typing_extra import origin_is_union
|
||||
from pydantic._internal._utils import deep_update, lenient_issubclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import get_args, get_origin
|
||||
|
||||
from pydantic_settings.utils import path_type_label
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
DotenvType = Union[Path, str, List[Union[Path, str]], Tuple[Union[Path, str], ...]]
|
||||
|
||||
# This is used as default value for `_env_file` in the `BaseSettings` class and
|
||||
# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`.
|
||||
# See the docstring of `BaseSettings` for more details.
|
||||
ENV_FILE_SENTINEL: DotenvType = Path('')
|
||||
|
||||
|
||||
class SettingsError(ValueError):
    """Raised when a settings source fails to load or parse a value.

    Used by the source ``__call__`` machinery when getting or parsing a field
    value fails, and by ``SecretsSettingsSource`` when ``secrets_dir`` is not a
    directory.
    """

    pass
|
||||
|
||||
|
||||
class PydanticBaseSettingsSource(ABC):
    """Common interface for settings sources.

    Every settings source class should inherit from this base and implement
    `get_field_value` and `__call__`.
    """

    def __init__(self, settings_cls: type[BaseSettings]):
        self.settings_cls = settings_cls
        self.config = settings_cls.model_config

    @abstractmethod
    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """Return the raw value, model-creation key, and complexity flag for one field.

        Must be overridden by every concrete settings source.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple of (value, key, value_is_complex).
        """

    def field_is_complex(self, field: FieldInfo) -> bool:
        """Report whether *field* is complex, in which case its value is parsed as JSON.

        Args:
            field: The field.

        Returns:
            Whether the field is complex.
        """
        return _annotation_is_complex(field.annotation, field.metadata)

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """Decode *value* when it is marked or detected as complex; otherwise pass it through.

        Args:
            field_name: The field name.
            field: The field.
            value: The raw value to prepare.
            value_is_complex: Whether the value was already flagged as complex.

        Returns:
            The prepared value.
        """
        if value is None:
            return value
        if not (value_is_complex or self.field_is_complex(field)):
            return value
        return self.decode_complex_value(field_name, field, value)

    def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
        """Decode the raw value of a complex field (JSON by default).

        Args:
            field_name: The field name.
            field: The field.
            value: The value to decode.

        Returns:
            The decoded value for further preparation.
        """
        return json.loads(value)

    @abstractmethod
    def __call__(self) -> dict[str, Any]:
        """Collect and return all values this source provides."""
|
||||
|
||||
|
||||
class InitSettingsSource(PydanticBaseSettingsSource):
    """Settings source that echoes back the keyword arguments given at settings-class initialization."""

    def __init__(self, settings_cls: type[BaseSettings], init_kwargs: dict[str, Any]):
        self.init_kwargs = init_kwargs
        super().__init__(settings_cls)

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        # This source never resolves individual fields; the trivial tuple keeps mypy satisfied.
        return None, '', False

    def __call__(self) -> dict[str, Any]:
        # All init kwargs are handed over unchanged.
        return self.init_kwargs

    def __repr__(self) -> str:
        return 'InitSettingsSource(init_kwargs={!r})'.format(self.init_kwargs)
|
||||
|
||||
|
||||
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
    """Shared base for sources that look values up by (optionally prefixed) env-style names."""

    def __init__(
        self, settings_cls: type[BaseSettings], case_sensitive: bool | None = None, env_prefix: str | None = None
    ) -> None:
        super().__init__(settings_cls)
        # Explicit constructor arguments take precedence over model_config values.
        self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
        self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')

    def _apply_case_sensitive(self, value: str) -> str:
        # Lookup names are normalized to lower case unless case sensitivity was requested.
        return value.lower() if not self.case_sensitive else value

    def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
        """
        Extracts field info. This info is used to get the value of field from environment variables.

        It returns a list of tuples, each tuple contains:
            * field_key: The key of field that has to be used in model creation.
            * env_name: The environment variable name of the field.
            * value_is_complex: A flag to determine whether the value from environment variable
              is complex and has to be parsed.

        Args:
            field (FieldInfo): The field.
            field_name (str): The field name.

        Returns:
            list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
        """
        field_info: list[tuple[str, str, bool]] = []
        if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
            v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
        else:
            v_alias = field.validation_alias

        if v_alias:
            if isinstance(v_alias, list):  # AliasChoices, AliasPath
                for alias in v_alias:
                    if isinstance(alias, str):  # AliasPath
                        # NOTE(review): here `alias` is a str, so `len(alias) > 1` tests the *string*
                        # length (any multi-character name is flagged complex) — confirm this is
                        # intended rather than `len(v_alias) > 1`.
                        field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False))
                    elif isinstance(alias, list):  # AliasChoices
                        first_arg = cast(str, alias[0])  # first item of an AliasChoices must be a str
                        field_info.append(
                            (first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False)
                        )
            else:  # string validation alias
                field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
        else:
            # No alias: fall back to env_prefix + field name.
            field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))

        return field_info

    def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
        """
        Replace field names in values dict by looking in models fields insensitively.

        By having the following models:

            ```py
            class SubSubSub(BaseModel):
                VaL3: str

            class SubSub(BaseModel):
                Val2: str
                SUB_sub_SuB: SubSubSub

            class Sub(BaseModel):
                VAL1: str
                SUB_sub: SubSub

            class Settings(BaseSettings):
                nested: Sub

                model_config = SettingsConfigDict(env_nested_delimiter='__')
            ```

        Then:
            _replace_field_names_case_insensitively(
                field,
                {"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
            )
            Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
        """
        values: dict[str, Any] = {}

        for name, value in field_values.items():
            sub_model_field: FieldInfo | None = None

            # This is here to make mypy happy
            # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
            if not field.annotation or not hasattr(field.annotation, 'model_fields'):
                values[name] = value
                continue

            # Find field in sub model by looking in fields case insensitively
            for sub_model_field_name, f in field.annotation.model_fields.items():
                if not f.validation_alias and sub_model_field_name.lower() == name.lower():
                    sub_model_field = f
                    break

            if not sub_model_field:
                # No case-insensitive match: keep the incoming key unchanged.
                values[name] = value
                continue

            if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
                # Recurse into nested models so deeply nested keys are also renamed.
                values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value)
            else:
                values[sub_model_field_name] = value

        return values

    def __call__(self) -> dict[str, Any]:
        """Collect values for every model field, wrapping per-field failures in SettingsError."""
        data: dict[str, Any] = {}

        for field_name, field in self.settings_cls.model_fields.items():
            try:
                field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
            except Exception as e:
                raise SettingsError(
                    f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
                ) from e

            try:
                field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
            except ValueError as e:
                raise SettingsError(
                    f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
                ) from e

            if field_value is not None:
                if (
                    not self.case_sensitive
                    and lenient_issubclass(field.annotation, BaseModel)
                    and isinstance(field_value, dict)
                ):
                    # Case-insensitive mode: rewrite nested dict keys to the model's exact field names.
                    data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
                else:
                    data[field_key] = field_value

        return data
|
||||
|
||||
|
||||
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from secret files.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        secrets_dir: str | Path | None = None,
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
    ) -> None:
        super().__init__(settings_cls, case_sensitive, env_prefix)
        # Explicit argument takes precedence over the model_config 'secrets_dir' value.
        self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')

    def __call__(self) -> dict[str, Any]:
        """
        Build fields from "secrets" files.
        """
        secrets: dict[str, str | None] = {}

        if self.secrets_dir is None:
            return secrets

        # Resolved here rather than in __init__, so `secrets_path` only exists when a dir is configured.
        self.secrets_path = Path(self.secrets_dir).expanduser()

        if not self.secrets_path.exists():
            # A missing directory is tolerated with a warning; no secrets are loaded.
            warnings.warn(f'directory "{self.secrets_path}" does not exist')
            return secrets

        if not self.secrets_path.is_dir():
            raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(self.secrets_path)}')

        return super().__call__()

    @classmethod
    def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
        """
        Find a file within path's directory matching filename, optionally ignoring case.

        Args:
            dir_path: Directory path.
            file_name: File name.
            case_sensitive: Whether to search for file name case sensitively.

        Returns:
            Whether file path or `None` if file does not exist in directory.
        """
        # NOTE(review): returns the first match in directory iteration order, which is
        # not guaranteed to be deterministic when multiple case-insensitive matches exist.
        for f in dir_path.iterdir():
            if f.name == file_name:
                return f
            elif not case_sensitive and f.name.lower() == file_name.lower():
                return f
        return None

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from secret file and a flag to determine whether value is complex.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple contains the key, value if the file exists otherwise `None`, and
                a flag to determine whether value is complex.
        """

        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            path = self.find_case_path(self.secrets_path, env_name, self.case_sensitive)
            if not path:
                # path does not exist, we currently don't return a warning for this
                continue

            if path.is_file():
                # First matching secret file wins; trailing whitespace/newlines are stripped.
                return path.read_text().strip(), field_key, value_is_complex
            else:
                warnings.warn(
                    f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
                    stacklevel=4,
                )

        return None, field_key, value_is_complex

    def __repr__(self) -> str:
        return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
|
||||
|
||||
|
||||
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from environment variables.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
    ) -> None:
        super().__init__(settings_cls, case_sensitive, env_prefix)
        self.env_nested_delimiter = (
            env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
        )
        self.env_prefix_len = len(self.env_prefix)

        # Snapshot of the environment, taken once at construction time.
        self.env_vars = self._load_env_vars()

    def _load_env_vars(self) -> Mapping[str, str | None]:
        if self.case_sensitive:
            return os.environ
        # Case-insensitive mode stores every variable under its lower-cased name.
        return {k.lower(): v for k, v in os.environ.items()}

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from environment variables and a flag to determine whether value is complex.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple contains the key, value if the environment variable exists otherwise `None`, and
                a flag to determine whether value is complex.
        """

        env_val: str | None = None
        # First alias that resolves to a set environment variable wins. Note that
        # field_key/value_is_complex deliberately carry the values of the last
        # iterated alias when nothing matched.
        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            env_val = self.env_vars.get(env_name)
            if env_val is not None:
                break

        return env_val, field_key, value_is_complex

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """
        Prepare value for the field.

        * Extract value for nested field.
        * Deserialize value to python object for complex field.

        Args:
            field_name: The field name.
            field: The field.
            value: The value of the field that has to be prepared.
            value_is_complex: A flag to determine whether value is complex.

        Returns:
            The prepared value for the field.

        Raises:
            ValueError: When there is an error in deserializing value for complex field.
        """
        is_complex, allow_parse_failure = self._field_is_complex(field)
        if is_complex or value_is_complex:
            if value is None:
                # field is complex but no value found so far, try explode_env_vars
                env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
                if env_val_built:
                    return env_val_built
            else:
                # field is complex and there's a value, decode that as JSON, then add explode_env_vars
                try:
                    value = self.decode_complex_value(field_name, field, value)
                except ValueError as e:
                    if not allow_parse_failure:
                        raise e

                if isinstance(value, dict):
                    # Individual nested env vars override keys inside the JSON value.
                    return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
                else:
                    return value
        elif value is not None:
            # simplest case, field is not complex, we only need to add the value if it was found
            return value
        # Implicitly returns None when nothing was found (complex field with no
        # value and an empty explode result, or simple field with value None).

    def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool:
        # A union counts as complex if any of its member types is complex.
        return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))

    def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
        """
        Find out if a field is complex, and if so whether JSON errors should be ignored
        """
        if self.field_is_complex(field):
            allow_parse_failure = False
        elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata):
            # A union with at least one complex member: JSON parse failures are tolerated
            # because the value may still validate against a non-complex member.
            allow_parse_failure = True
        else:
            return False, False

        return True, allow_parse_failure

    @staticmethod
    def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None:
        """
        Find the field in a sub model by key(env name)

        By having the following models:

            ```py
            class SubSubModel(BaseSettings):
                dvals: Dict

            class SubModel(BaseSettings):
                vals: list[str]
                sub_sub_model: SubSubModel

            class Cfg(BaseSettings):
                sub_model: SubModel
            ```

        Then:
            next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
            next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class

        Args:
            field: The field.
            key: The key (env name).

        Returns:
            Field if it finds the next field otherwise `None`.
        """
        if not field or origin_is_union(get_origin(field.annotation)):
            # no support for Unions of complex BaseSettings fields
            return None
        elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key):
            return field.annotation.model_fields[key]

        return None

    def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
        """
        Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

        This is applied to a single field, hence filtering by env_var prefix.

        Args:
            field_name: The field name.
            field: The field.
            env_vars: Environment variables.

        Returns:
            A dictionary containing extracted values from nested env values.
        """
        prefixes = [
            f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
        ]
        result: dict[str, Any] = {}
        for env_name, env_val in env_vars.items():
            if not any(env_name.startswith(prefix) for prefix in prefixes):
                continue
            # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
            env_name_without_prefix = env_name[self.env_prefix_len :]
            _, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
            env_var = result
            target_field: FieldInfo | None = field
            # Walk/create the nested dict structure, tracking the matching model field in parallel.
            for key in keys:
                target_field = self.next_field(target_field, key)
                env_var = env_var.setdefault(key, {})

            # get proper field with last_key
            target_field = self.next_field(target_field, last_key)

            # check if env_val maps to a complex field and if so, parse the env_val
            if target_field and env_val:
                is_complex, allow_json_failure = self._field_is_complex(target_field)
                if is_complex:
                    try:
                        env_val = self.decode_complex_value(last_key, target_field, env_val)
                    except ValueError as e:
                        if not allow_json_failure:
                            raise e
            env_var[last_key] = env_val

        return result

    def __repr__(self) -> str:
        return (
            f'EnvSettingsSource(env_nested_delimiter={self.env_nested_delimiter!r}, '
            f'env_prefix_len={self.env_prefix_len!r})'
        )
|
||||
|
||||
|
||||
class DotEnvSettingsSource(EnvSettingsSource):
    """
    Source class for loading settings values from env files.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        env_file: DotenvType | None = ENV_FILE_SENTINEL,
        env_file_encoding: str | None = None,
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
    ) -> None:
        # ENV_FILE_SENTINEL distinguishes "not passed" from an explicit None.
        self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
        self.env_file_encoding = (
            env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
        )
        super().__init__(settings_cls, case_sensitive, env_prefix, env_nested_delimiter)

    def _load_env_vars(self) -> Mapping[str, str | None]:
        # Overrides the parent to read from dotenv file(s) instead of os.environ.
        return self._read_env_files(self.case_sensitive)

    def _read_env_files(self, case_sensitive: bool) -> Mapping[str, str | None]:
        env_files = self.env_file
        if env_files is None:
            return {}

        if isinstance(env_files, (str, os.PathLike)):
            env_files = [env_files]

        dotenv_vars: dict[str, str | None] = {}
        # Later files override earlier ones; missing files are silently skipped.
        for env_file in env_files:
            env_path = Path(env_file).expanduser()
            if env_path.is_file():
                dotenv_vars.update(
                    read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
                )

        return dotenv_vars

    def __call__(self) -> dict[str, Any]:
        data: dict[str, Any] = super().__call__()

        data_lower_keys: list[str] = []
        if not self.case_sensitive:
            data_lower_keys = [x.lower() for x in data.keys()]

        # As `extra` config is allowed in dotenv settings source, we have to
        # update data with extra env variables from dotenv file.
        for env_name, env_value in self.env_vars.items():
            if env_name.startswith(self.env_prefix) and env_value is not None:
                env_name_without_prefix = env_name[self.env_prefix_len :]
                first_key, *_ = env_name_without_prefix.split(self.env_nested_delimiter)

                # Only add the variable when its top-level key is not already covered
                # by a model field collected above.
                if (data_lower_keys and first_key not in data_lower_keys) or (
                    not data_lower_keys and first_key not in data
                ):
                    data[first_key] = env_value

        return data

    def __repr__(self) -> str:
        return (
            f'DotEnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
            f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
        )
|
||||
|
||||
|
||||
def read_env_file(
    file_path: Path, *, encoding: str | None = None, case_sensitive: bool = False
) -> Mapping[str, str | None]:
    """Parse a dotenv file into a name -> value mapping.

    Args:
        file_path: Path of the dotenv file to read.
        encoding: File encoding; defaults to UTF-8 when not given.
        case_sensitive: When `False` (the default), keys are lower-cased.

    Returns:
        Mapping of variable names to their (possibly `None`) values.

    Raises:
        ImportError: When the optional python-dotenv dependency is not installed.
    """
    # python-dotenv is an optional dependency, so the import lives inside the function.
    try:
        from dotenv import dotenv_values
    except ImportError as e:
        raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e

    parsed: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
    if case_sensitive:
        return parsed
    return {name.lower(): val for name, val in parsed.items()}
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
    """Report whether *annotation* should be treated as complex (i.e. JSON-parsed)."""
    # `Json[...]`-annotated fields are decoded by pydantic itself, never here.
    if any(isinstance(md, Json) for md in metadata):  # type: ignore[misc]
        return False
    origin = get_origin(annotation)
    # Complex when either the annotation or its generic origin is a container/model
    # type, or the origin carries a pydantic core schema. All checks are pure, so
    # evaluating them together is equivalent to the short-circuit form.
    checks = (
        _annotation_is_complex_inner(annotation),
        _annotation_is_complex_inner(origin),
        hasattr(origin, '__pydantic_core_schema__'),
        hasattr(origin, '__get_pydantic_core_schema__'),
    )
    return any(checks)
|
||||
|
||||
|
||||
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
    """Report whether a single (non-generic) annotation is a complex container or model type."""
    # str/bytes are Sequences but must stay simple scalars.
    if lenient_issubclass(annotation, (str, bytes)):
        return False
    if is_dataclass(annotation):
        return True
    return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque))
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Package for handling configuration sources in pydantic-settings."""
|
||||
|
||||
from .base import (
|
||||
ConfigFileSourceMixin,
|
||||
DefaultSettingsSource,
|
||||
InitSettingsSource,
|
||||
PydanticBaseEnvSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .providers.aws import AWSSecretsManagerSettingsSource
|
||||
from .providers.azure import AzureKeyVaultSettingsSource
|
||||
from .providers.cli import (
|
||||
CLI_SUPPRESS,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliUnknownArgs,
|
||||
)
|
||||
from .providers.dotenv import DotEnvSettingsSource, read_env_file
|
||||
from .providers.env import EnvSettingsSource
|
||||
from .providers.gcp import GoogleSecretManagerSettingsSource
|
||||
from .providers.json import JsonConfigSettingsSource
|
||||
from .providers.pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .providers.secrets import SecretsSettingsSource
|
||||
from .providers.toml import TomlConfigSettingsSource
|
||||
from .providers.yaml import YamlConfigSettingsSource
|
||||
from .types import DEFAULT_PATH, ENV_FILE_SENTINEL, DotenvType, ForceDecode, NoDecode, PathType, PydanticModel
|
||||
|
||||
__all__ = [
|
||||
'CLI_SUPPRESS',
|
||||
'ENV_FILE_SENTINEL',
|
||||
'DEFAULT_PATH',
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'CliUnknownArgs',
|
||||
'DefaultSettingsSource',
|
||||
'DotEnvSettingsSource',
|
||||
'DotenvType',
|
||||
'EnvSettingsSource',
|
||||
'ForceDecode',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'NoDecode',
|
||||
'PathType',
|
||||
'PydanticBaseEnvSettingsSource',
|
||||
'PydanticBaseSettingsSource',
|
||||
'ConfigFileSourceMixin',
|
||||
'PydanticModel',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
'get_subcommand',
|
||||
'read_env_file',
|
||||
]
|
||||
@@ -0,0 +1,527 @@
|
||||
"""Base classes and core functionality for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from pydantic._internal._utils import is_model_class
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import get_args
|
||||
from typing_inspection import typing_objects
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..exceptions import SettingsError
|
||||
from ..utils import _lenient_issubclass
|
||||
from .types import EnvNoneType, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand
|
||||
from .utils import (
|
||||
_annotation_is_complex,
|
||||
_get_alias_names,
|
||||
_get_model_fields,
|
||||
_strip_annotated,
|
||||
_union_is_complex,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
def get_subcommand(
    model: PydanticModel, is_required: bool = True, cli_exit_on_error: bool | None = None
) -> Optional[PydanticModel]:
    """Return the subcommand model that is set on ``model``.

    Args:
        model: The model to get the subcommand from.
        is_required: Whether a subcommand must be set; when `True` and none is set,
            an error is raised. Defaults to `True`.
        cli_exit_on_error: Whether a missing required subcommand exits the process.
            Falls back to the model_config ``cli_exit_on_error`` value, then `True`.

    Returns:
        The subcommand model if one is set, otherwise `None`.

    Raises:
        SystemExit: No subcommand is set, ``is_required`` is `True`, and
            ``cli_exit_on_error`` resolves to `True` (the default).
        SettingsError: No subcommand is set, ``is_required`` is `True`, and
            ``cli_exit_on_error`` resolves to `False`.
    """
    model_cls = type(model)

    # Resolve the exit-on-error policy: explicit argument > model_config > True.
    if cli_exit_on_error is None and is_model_class(model_cls):
        configured = model_cls.model_config.get('cli_exit_on_error')
        if isinstance(configured, bool):
            cli_exit_on_error = configured
    if cli_exit_on_error is None:
        cli_exit_on_error = True

    # Scan subcommand fields; the first one with a value set is the active subcommand.
    subcommand_names: list[str] = []
    for name, info in _get_model_fields(model_cls).items():
        if _CliSubCommand not in info.metadata:
            continue
        selected = getattr(model, name)
        if selected is not None:
            return selected
        subcommand_names.append(name)

    if is_required:
        if subcommand_names:
            error_message = f'Error: CLI subcommand is required {{{", ".join(subcommand_names)}}}'
        else:
            error_message = 'Error: CLI subcommand is required but no subcommands were found.'
        if cli_exit_on_error:
            raise SystemExit(error_message)
        raise SettingsError(error_message)

    return None
|
||||
|
||||
|
||||
class PydanticBaseSettingsSource(ABC):
    """Common interface for settings sources.

    Every settings source class should inherit from this base and implement
    `get_field_value` and `__call__`.
    """

    def __init__(self, settings_cls: type[BaseSettings]):
        self.settings_cls = settings_cls
        self.config = settings_cls.model_config
        self._current_state: dict[str, Any] = {}
        self._settings_sources_data: dict[str, dict[str, Any]] = {}

    def _set_current_state(self, state: dict[str, Any]) -> None:
        """Store the merged state produced by earlier sources; called right before `__call__`."""
        self._current_state = state

    def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None:
        """Store the per-source states of all earlier sources; called right before `__call__`."""
        self._settings_sources_data = states

    @property
    def current_state(self) -> dict[str, Any]:
        """The merged settings state accumulated from the sources that ran before this one."""
        return self._current_state

    @property
    def settings_sources_data(self) -> dict[str, dict[str, Any]]:
        """The per-source settings states from the sources that ran before this one."""
        return self._settings_sources_data

    @abstractmethod
    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """Return the raw value, model-creation key, and complexity flag for one field.

        Must be overridden by every concrete settings source.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple of (value, key, value_is_complex).
        """

    def field_is_complex(self, field: FieldInfo) -> bool:
        """Report whether *field* is complex, in which case its value is parsed as JSON.

        Args:
            field: The field.

        Returns:
            Whether the field is complex.
        """
        return _annotation_is_complex(field.annotation, field.metadata)

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """Decode *value* when it is marked or detected as complex; otherwise pass it through.

        Args:
            field_name: The field name.
            field: The field.
            value: The raw value to prepare.
            value_is_complex: Whether the value was already flagged as complex.

        Returns:
            The prepared value.
        """
        if value is None:
            return value
        if value_is_complex or self.field_is_complex(field):
            return self.decode_complex_value(field_name, field, value)
        return value

    def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
        """Decode the raw value of a complex field, honoring the decode opt-outs.

        Args:
            field_name: The field name.
            field: The field.
            value: The value to decode.

        Returns:
            The decoded value for further preparation, or the value unchanged when
            decoding is disabled for this field.
        """
        if field:
            # Decoding is skipped for NoDecode-annotated fields, and globally when
            # `enable_decoding` is False unless the field opts back in via ForceDecode.
            decoding_disabled = self.config.get('enable_decoding') is False and ForceDecode not in field.metadata
            if NoDecode in field.metadata or decoding_disabled:
                return value
        return json.loads(value)

    @abstractmethod
    def __call__(self) -> dict[str, Any]:
        """Collect and return all values this source provides."""
|
||||
|
||||
|
||||
class ConfigFileSourceMixin(ABC):
    """Mixin for sources that aggregate settings from one or more config files."""

    def _read_files(self, files: PathType | None) -> dict[str, Any]:
        """Read every existing file in *files* and merge the results, later files winning."""
        if files is None:
            return {}
        # Normalize a single path into a one-element list.
        if isinstance(files, (str, os.PathLike)):
            files = [files]
        merged: dict[str, Any] = {}
        for candidate in files:
            path = Path(candidate).expanduser()
            # Missing files are silently skipped.
            if not path.is_file():
                continue
            merged.update(self._read_file(path))
        return merged

    @abstractmethod
    def _read_file(self, path: Path) -> dict[str, Any]:
        """Parse a single existing file and return its contents as a dict."""
|
||||
|
||||
|
||||
class DefaultSettingsSource(PydanticBaseSettingsSource):
    """
    Source class for loading default object values.

    When `nested_model_default_partial_update` is enabled, default values of
    dataclass / pydantic-model fields are pre-serialized to plain dicts so that
    later sources can partially update them.

    Args:
        settings_cls: The Settings class.
        nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
            Defaults to `False`.
    """

    def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None):
        super().__init__(settings_cls)
        self.defaults: dict[str, Any] = {}
        if nested_model_default_partial_update is None:
            nested_model_default_partial_update = self.config.get('nested_model_default_partial_update', False)
        self.nested_model_default_partial_update = nested_model_default_partial_update
        if not self.nested_model_default_partial_update:
            return
        for field_name, field_info in settings_cls.model_fields.items():
            alias_names, *_ = _get_alias_names(field_name, field_info)
            key = alias_names[0]
            default = field_info.default
            # Serialize object defaults into plain dicts so they can be deep-merged.
            if is_dataclass(type(default)):
                self.defaults[key] = asdict(default)
            elif is_model_class(type(default)):
                self.defaults[key] = default.model_dump()

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        # Nothing to resolve per-field; sentinel return keeps type checkers happy.
        return None, '', False

    def __call__(self) -> dict[str, Any]:
        return self.defaults

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})'
        )
|
||||
|
||||
|
||||
class InitSettingsSource(PydanticBaseSettingsSource):
    """
    Source class for loading values provided during settings class initialization.

    Keyword arguments are re-keyed onto each field's preferred alias, so a value
    supplied under any accepted alias ends up under the canonical key.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        init_kwargs: dict[str, Any],
        nested_model_default_partial_update: bool | None = None,
    ):
        self.init_kwargs = {}
        remaining = set(init_kwargs)
        for field_name, field_info in settings_cls.model_fields.items():
            alias_names, *_ = _get_alias_names(field_name, field_info)
            matched = remaining & set(alias_names)
            if not matched:
                continue
            # Value may have been passed under any accepted alias; store it
            # under the field's preferred (first) alias.
            source_alias = next(alias for alias in alias_names if alias in matched)
            remaining -= matched
            self.init_kwargs[alias_names[0]] = init_kwargs[source_alias]
        # Keep unmatched kwargs as-is (e.g. for `extra` handling downstream).
        for key, val in init_kwargs.items():
            if key in remaining:
                self.init_kwargs[key] = val

        super().__init__(settings_cls)
        if nested_model_default_partial_update is None:
            nested_model_default_partial_update = self.config.get('nested_model_default_partial_update', False)
        self.nested_model_default_partial_update = nested_model_default_partial_update

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        # Nothing to resolve per-field; sentinel return keeps type checkers happy.
        return None, '', False

    def __call__(self) -> dict[str, Any]:
        if self.nested_model_default_partial_update:
            # Serialize nested models to dicts so later merging can be partial.
            return TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs)
        return self.init_kwargs

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})'
|
||||
|
||||
|
||||
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
    """Shared base class for sources that read settings from env-style key/value data.

    Adds the common options (case sensitivity, name prefix, empty-value and
    "None"-string handling, enum parsing); each option falls back to the
    corresponding ``model_config`` entry when not passed explicitly.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(settings_cls)
        # `is not None` (rather than truthiness) keeps explicit False/'' arguments effective.
        self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
        self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
        self.env_ignore_empty = (
            env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
        )
        self.env_parse_none_str = (
            env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
        )
        self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')

    def _apply_case_sensitive(self, value: str) -> str:
        # Case-insensitive matching is implemented by lower-casing the lookup name.
        return value.lower() if not self.case_sensitive else value

    def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
        """
        Extracts field info. This info is used to get the value of field from environment variables.

        It returns a list of tuples, each tuple contains:
            * field_key: The key of field that has to be used in model creation.
            * env_name: The environment variable name of the field.
            * value_is_complex: A flag to determine whether the value from environment variable
              is complex and has to be parsed.

        Args:
            field (FieldInfo): The field.
            field_name (str): The field name.

        Returns:
            list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
        """
        field_info: list[tuple[str, str, bool]] = []
        if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
            v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
        else:
            v_alias = field.validation_alias

        if v_alias:
            if isinstance(v_alias, list):  # AliasChoices, AliasPath
                for alias in v_alias:
                    if isinstance(alias, str):  # AliasPath
                        # NOTE(review): `len(alias) > 1` measures the *string* length here,
                        # so any multi-character alias is flagged complex -- confirm intended.
                        field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False))
                    elif isinstance(alias, list):  # AliasChoices
                        first_arg = cast(str, alias[0])  # first item of an AliasChoices must be a str
                        field_info.append(
                            (first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False)
                        )
            else:  # string validation alias
                field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))

        # Fall back to the plain field name (with env_prefix) when there is no
        # alias, or additionally when populate_by_name is enabled.
        if not v_alias or self.config.get('populate_by_name', False):
            annotation = field.annotation
            # Resolve PEP 695 type aliases before inspecting the annotation.
            if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
                annotation = _strip_annotated(annotation.__value__)  # type: ignore[union-attr]
            if is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata):
                field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
            else:
                field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))

        return field_info

    def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
        """
        Replace field names in values dict by looking in models fields insensitively.

        By having the following models:

            ```py
            class SubSubSub(BaseModel):
                VaL3: str

            class SubSub(BaseModel):
                Val2: str
                SUB_sub_SuB: SubSubSub

            class Sub(BaseModel):
                VAL1: str
                SUB_sub: SubSub

            class Settings(BaseSettings):
                nested: Sub

                model_config = SettingsConfigDict(env_nested_delimiter='__')
            ```

        Then:
            _replace_field_names_case_insensitively(
                field,
                {"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
            )
            Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
        """
        values: dict[str, Any] = {}

        for name, value in field_values.items():
            sub_model_field: FieldInfo | None = None

            annotation = field.annotation

            # If field is Optional, we need to find the actual type
            if is_union_origin(get_origin(field.annotation)):
                args = get_args(annotation)
                if len(args) == 2 and type(None) in args:
                    # NOTE(review): `arg is not None` is always true for type objects
                    # (NoneType is not None), so the first union member is always
                    # selected -- confirm whether `arg is not type(None)` was intended.
                    for arg in args:
                        if arg is not None:
                            annotation = arg
                            break

            # This is here to make mypy happy
            # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
            if not annotation or not hasattr(annotation, 'model_fields'):
                values[name] = value
                continue
            else:
                model_fields: dict[str, FieldInfo] = annotation.model_fields

            # Find field in sub model by looking in fields case insensitively
            field_key: str | None = None
            for sub_model_field_name, sub_model_field in model_fields.items():
                aliases, _ = _get_alias_names(sub_model_field_name, sub_model_field)
                _search = (alias for alias in aliases if alias.lower() == name.lower())
                if field_key := next(_search, None):
                    break

            # No case-insensitive match: keep the incoming key unchanged.
            if not field_key:
                values[name] = value
                continue

            # Recurse into nested BaseModel dict values to fix their keys too.
            if (
                sub_model_field is not None
                and _lenient_issubclass(sub_model_field.annotation, BaseModel)
                and isinstance(value, dict)
            ):
                values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value)
            else:
                values[field_key] = value

        return values

    def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]:
        """
        Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None).
        """
        values: dict[str, Any] = {}

        for key, value in field_value.items():
            if not isinstance(value, EnvNoneType):
                values[key] = value if not isinstance(value, dict) else self._replace_env_none_type_values(value)
            else:
                values[key] = None

        return values

    def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value, the preferred alias key for model creation, and a flag to determine whether value
        is complex.

        Note:
            In V3, this method should either be made public, or, this method should be removed and the
            abstract method get_field_value should be updated to include a "use_preferred_alias" flag.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value, preferred key and a flag to determine whether value is complex.
        """
        field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
        # Re-key onto the preferred (first) alias unless the value is complex or
        # the key already matches the field name under populate_by_name.
        if not (value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name))):
            field_infos = self._extract_field_info(field, field_name)
            preferred_key, *_ = field_infos[0]
            return field_value, preferred_key, value_is_complex
        return field_value, field_key, value_is_complex

    def __call__(self) -> dict[str, Any]:
        """Resolve, prepare and normalize a value for every model field; skip unset fields."""
        data: dict[str, Any] = {}

        for field_name, field in self.settings_cls.model_fields.items():
            try:
                field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name)
            except Exception as e:
                raise SettingsError(
                    f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
                ) from e

            try:
                field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
            except ValueError as e:
                raise SettingsError(
                    f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
                ) from e

            if field_value is not None:
                # Normalize EnvNoneType placeholders back into real None values.
                if self.env_parse_none_str is not None:
                    if isinstance(field_value, dict):
                        field_value = self._replace_env_none_type_values(field_value)
                    elif isinstance(field_value, EnvNoneType):
                        field_value = None
                if (
                    not self.case_sensitive
                    # and _lenient_issubclass(field.annotation, BaseModel)
                    and isinstance(field_value, dict)
                ):
                    data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
                else:
                    data[field_key] = field_value

        return data
|
||||
|
||||
|
||||
# Public API of this module (re-exported by the package __init__).
__all__ = [
    'ConfigFileSourceMixin',
    'DefaultSettingsSource',
    'InitSettingsSource',
    'PydanticBaseEnvSettingsSource',
    'PydanticBaseSettingsSource',
    'SettingsError',
]
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Package containing individual source implementations."""
|
||||
|
||||
from .aws import AWSSecretsManagerSettingsSource
|
||||
from .azure import AzureKeyVaultSettingsSource
|
||||
from .cli import (
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
)
|
||||
from .dotenv import DotEnvSettingsSource
|
||||
from .env import EnvSettingsSource
|
||||
from .gcp import GoogleSecretManagerSettingsSource
|
||||
from .json import JsonConfigSettingsSource
|
||||
from .pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .secrets import SecretsSettingsSource
|
||||
from .toml import TomlConfigSettingsSource
|
||||
from .yaml import YamlConfigSettingsSource
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
]
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations as _annotations # important for BaseSettings import to work
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..utils import parse_env_vars
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
boto3_client = None
|
||||
SecretsManagerClient = None
|
||||
|
||||
|
||||
def import_aws_secrets_manager() -> None:
    """Import the optional AWS Secrets Manager dependencies into module globals.

    Raises:
        ImportError: With an install hint, when boto3 / its type stubs are missing.
    """
    global boto3_client, SecretsManagerClient

    try:
        from boto3 import client as boto3_client
        from mypy_boto3_secretsmanager.client import SecretsManagerClient
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
        ) from e
|
||||
|
||||
|
||||
class AWSSecretsManagerSettingsSource(EnvSettingsSource):
    """Settings source that loads values from a single AWS Secrets Manager secret.

    The secret's ``SecretString`` is parsed as JSON and fed through the
    env-style machinery; nested keys use ``env_nested_delimiter`` ('--' by default).
    """

    # ID (name or ARN) of the secret to fetch.
    _secret_id: str
    # boto3 secretsmanager client used for all API calls.
    _secretsmanager_client: SecretsManagerClient  # type: ignore

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        secret_id: str,
        region_name: str | None = None,
        endpoint_url: str | None = None,
        case_sensitive: bool | None = True,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = '--',
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        import_aws_secrets_manager()
        # The client and secret id must be set before super().__init__ runs:
        # the base class loads the env mapping (via _load_env_vars) during init.
        self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url)  # type: ignore
        self._secret_id = secret_id
        super().__init__(
            settings_cls,
            case_sensitive=case_sensitive,
            env_prefix=env_prefix,
            env_nested_delimiter=env_nested_delimiter,
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, Optional[str]]:
        # Fetch the secret payload and normalize it like environment variables.
        # Assumes SecretString is a JSON object -- TODO confirm for binary secrets.
        response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id)  # type: ignore

        return parse_env_vars(
            json.loads(response['SecretString']),
            self.case_sensitive,
            self.env_ignore_empty,
            self.env_parse_none_str,
        )

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(secret_id={self._secret_id!r}, '
            f'env_nested_delimiter={self.env_nested_delimiter!r})'
        )
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = [
    'AWSSecretsManagerSettingsSource',
]
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Azure Key Vault settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic.alias_generators import to_snake
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
TokenCredential = None
|
||||
ResourceNotFoundError = None
|
||||
SecretClient = None
|
||||
|
||||
|
||||
def import_azure_key_vault() -> None:
    """Import the optional Azure Key Vault dependencies into module globals.

    Raises:
        ImportError: With an install hint, when the azure packages are missing.
    """
    global TokenCredential, SecretClient, ResourceNotFoundError

    try:
        from azure.core.credentials import TokenCredential
        from azure.core.exceptions import ResourceNotFoundError
        from azure.keyvault.secrets import SecretClient
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
        ) from e
|
||||
|
||||
|
||||
class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
    """Lazy, read-only mapping over the enabled secrets of an Azure Key Vault.

    Secret names are listed once up front; secret *values* are fetched on first
    access and cached. Keys are normalized according to the configured
    case-sensitivity / snake-case options.
    """

    _loaded_secrets: dict[str, str | None]
    _secret_client: SecretClient
    _secret_names: list[str]

    def __init__(
        self,
        secret_client: SecretClient,
        case_sensitive: bool,
        snake_case_conversion: bool,
    ) -> None:
        self._secret_client = secret_client
        self._case_sensitive = case_sensitive
        self._snake_case_conversion = snake_case_conversion
        self._loaded_secrets = {}
        # normalized lookup key -> real vault secret name
        self._secret_map: dict[str, str] = self._load_remote()

    def _load_remote(self) -> dict[str, str]:
        """List enabled, named secrets and build the normalized-name lookup table."""
        names = (
            secret.name
            for secret in self._secret_client.list_properties_of_secrets()
            if secret.name and secret.enabled
        )
        if self._snake_case_conversion:
            return {to_snake(name): name for name in names}
        if self._case_sensitive:
            return {name: name for name in names}
        return {name.lower(): name for name in names}

    def __getitem__(self, key: str) -> str | None:
        # Normalize the requested key the same way the lookup table was built.
        if self._snake_case_conversion:
            lookup = to_snake(key)
        elif not self._case_sensitive:
            lookup = key.lower()
        else:
            lookup = key

        if lookup not in self._loaded_secrets:
            if lookup not in self._secret_map:
                raise KeyError(key)
            # First access: fetch from the vault and cache the value.
            self._loaded_secrets[lookup] = self._secret_client.get_secret(self._secret_map[lookup]).value

        return self._loaded_secrets[lookup]

    def __len__(self) -> int:
        return len(self._secret_map)

    def __iter__(self) -> Iterator[str]:
        return iter(self._secret_map)
|
||||
|
||||
|
||||
class AzureKeyVaultSettingsSource(EnvSettingsSource):
    """Settings source loading values from an Azure Key Vault via a lazy mapping."""

    # Vault URL the SecretClient connects to.
    _url: str
    # Azure credential used to authenticate the SecretClient.
    _credential: TokenCredential

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        url: str,
        credential: TokenCredential,
        dash_to_underscore: bool = False,
        case_sensitive: bool | None = None,
        snake_case_conversion: bool = False,
        env_prefix: str | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        import_azure_key_vault()
        # These attributes must exist before super().__init__, which loads the
        # env mapping (via _load_env_vars) during initialization.
        self._url = url
        self._credential = credential
        self._dash_to_underscore = dash_to_underscore
        self._snake_case_conversion = snake_case_conversion
        super().__init__(
            settings_cls,
            # Snake-case conversion forces case-insensitive matching and a '__' delimiter.
            case_sensitive=False if snake_case_conversion else case_sensitive,
            env_prefix=env_prefix,
            env_nested_delimiter='__' if snake_case_conversion else '--',
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, Optional[str]]:
        secret_client = SecretClient(vault_url=self._url, credential=self._credential)
        return AzureKeyVaultMapping(
            secret_client=secret_client,
            case_sensitive=self.case_sensitive,
            snake_case_conversion=self._snake_case_conversion,
        )

    def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
        # Snake-case mode: look up by the field key itself (the mapping already
        # exposes snake_cased names).
        if self._snake_case_conversion:
            return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))

        # dash_to_underscore mode: field-name underscores map to dashes in the vault.
        if self._dash_to_underscore:
            return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))

        return super()._extract_field_info(field, field_name)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,168 @@
|
||||
"""Dotenv file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dotenv import dotenv_values
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..types import ENV_FILE_SENTINEL, DotenvType
|
||||
from ..utils import (
|
||||
_annotation_is_complex,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class DotEnvSettingsSource(EnvSettingsSource):
    """
    Source class for loading settings values from env files.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        env_file: DotenvType | None = ENV_FILE_SENTINEL,
        env_file_encoding: str | None = None,
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
        env_nested_max_split: int | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        # ENV_FILE_SENTINEL distinguishes "argument not passed" (fall back to
        # model_config) from an explicit None (disable file loading entirely).
        self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
        self.env_file_encoding = (
            env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
        )
        super().__init__(
            settings_cls,
            case_sensitive,
            env_prefix,
            env_nested_delimiter,
            env_nested_max_split,
            env_ignore_empty,
            env_parse_none_str,
            env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, str | None]:
        # The "environment" for this source is the merged content of the env files.
        return self._read_env_files()

    @staticmethod
    def _static_read_env_file(
        file_path: Path,
        *,
        encoding: str | None = None,
        case_sensitive: bool = False,
        ignore_empty: bool = False,
        parse_none_str: str | None = None,
    ) -> Mapping[str, str | None]:
        """Read one dotenv file and normalize its variables like environment variables."""
        file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
        return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str)

    def _read_env_file(
        self,
        file_path: Path,
    ) -> Mapping[str, str | None]:
        # Instance wrapper binding the configured options to the static reader.
        return self._static_read_env_file(
            file_path,
            encoding=self.env_file_encoding,
            case_sensitive=self.case_sensitive,
            ignore_empty=self.env_ignore_empty,
            parse_none_str=self.env_parse_none_str,
        )

    def _read_env_files(self) -> Mapping[str, str | None]:
        """Read every configured env file in order; later files override earlier ones."""
        env_files = self.env_file
        if env_files is None:
            return {}

        # A single path is normalized into a one-element list.
        if isinstance(env_files, (str, os.PathLike)):
            env_files = [env_files]

        dotenv_vars: dict[str, str | None] = {}
        for env_file in env_files:
            env_path = Path(env_file).expanduser()
            # Missing files are silently skipped.
            if env_path.is_file():
                dotenv_vars.update(self._read_env_file(env_path))

        return dotenv_vars

    def __call__(self) -> dict[str, Any]:
        """Collect field values, then merge in unmatched dotenv variables as extras."""
        data: dict[str, Any] = super().__call__()
        is_extra_allowed = self.config.get('extra') != 'forbid'

        # As `extra` config is allowed in dotenv settings source, We have to
        # update data with extra env variables from dotenv file.
        for env_name, env_value in self.env_vars.items():
            # Skip empty values, vars already consumed, and (when a prefix is
            # configured) bare field names that would bypass the prefix.
            if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields):
                continue
            env_used = False
            for field_name, field in self.settings_cls.model_fields.items():
                for _, field_env_name, _ in self._extract_field_info(field, field_name):
                    # A var counts as "used" when it matches a field's env name
                    # exactly, or is a nested key under a complex field's env name.
                    if env_name == field_env_name or (
                        (
                            _annotation_is_complex(field.annotation, field.metadata)
                            or (
                                is_union_origin(get_origin(field.annotation))
                                and _union_is_complex(field.annotation, field.metadata)
                            )
                        )
                        and env_name.startswith(field_env_name)
                    ):
                        env_used = True
                        break
                if env_used:
                    break
            if not env_used:
                if is_extra_allowed and env_name.startswith(self.env_prefix):
                    # env_prefix should be respected and removed from the env_name
                    normalized_env_name = env_name[len(self.env_prefix) :]
                    data[normalized_env_name] = env_value
                else:
                    data[env_name] = env_value
        return data

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
            f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
        )
|
||||
|
||||
|
||||
def read_env_file(
    file_path: Path,
    *,
    encoding: str | None = None,
    case_sensitive: bool = False,
    ignore_empty: bool = False,
    parse_none_str: str | None = None,
) -> Mapping[str, str | None]:
    """Deprecated: read a single dotenv file into a parsed variable mapping.

    Emits a `DeprecationWarning` and delegates to
    `DotEnvSettingsSource._static_read_env_file`.

    Args:
        file_path: Path of the dotenv file to read.
        encoding: File encoding (downstream default is utf8).
        case_sensitive: Whether to preserve variable-name casing.
        ignore_empty: Whether to drop variables with empty values.
        parse_none_str: String value to interpret as `None`.

    Returns:
        Mapping of variable names to (possibly `None`) values.
    """
    warnings.warn(
        'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must',
        DeprecationWarning,
    )
    return DotEnvSettingsSource._static_read_env_file(
        file_path,
        encoding=encoding,
        case_sensitive=case_sensitive,
        ignore_empty=ignore_empty,
        parse_none_str=parse_none_str,
    )
|
||||
|
||||
|
||||
# Public API of this module (read_env_file kept only for backward compatibility).
__all__ = ['DotEnvSettingsSource', 'read_env_file']
|
||||
@@ -0,0 +1,270 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import get_args, get_origin
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ...utils import _lenient_issubclass
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import EnvNoneType
|
||||
from ..utils import (
|
||||
_annotation_enum_name_to_val,
|
||||
_get_model_fields,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from environment variables.

    Supports nested models via delimited variable names (see `env_nested_delimiter`),
    e.g. `SUB_MODEL__VALS` with delimiter `'__'`.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
        env_nested_max_split: int | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(
            settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
        )
        # Explicit constructor arguments take precedence over `model_config` values.
        self.env_nested_delimiter = (
            env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
        )
        self.env_nested_max_split = (
            env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
        )
        # `str.split` maxsplit argument: -1 (unlimited) when `env_nested_max_split` is unset or 0.
        self.maxsplit = (self.env_nested_max_split or 0) - 1
        self.env_prefix_len = len(self.env_prefix)

        # Snapshot of the environment, normalized once at construction time.
        self.env_vars = self._load_env_vars()

    def _load_env_vars(self) -> Mapping[str, str | None]:
        # Normalize `os.environ` per the case/empty/none-string settings.
        return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from environment variables and a flag to determine whether value is complex.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value (`None` if not found), key, and
            a flag to determine whether value is complex.
        """

        env_val: str | None = None
        # Try each candidate env name (aliases, prefix variants) in order; first hit wins.
        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            env_val = self.env_vars.get(env_name)
            if env_val is not None:
                break

        # When no candidate matched, `field_key`/`value_is_complex` are those of the last candidate tried.
        return env_val, field_key, value_is_complex

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """
        Prepare value for the field.

        * Extract value for nested field.
        * Deserialize value to python object for complex field.

        Args:
            field: The field.
            field_name: The field name.
            value: The raw value found for the field (may be `None`).
            value_is_complex: Whether the source flagged the value as complex.

        Returns:
            The prepared value for the field, or `None` when the field is simple and no value was found.

        Raises:
            ValueError: When there is an error in deserializing value for complex field.
        """
        is_complex, allow_parse_failure = self._field_is_complex(field)
        if self.env_parse_enums:
            # Translate an enum member *name* into the member when the annotation involves an Enum.
            enum_val = _annotation_enum_name_to_val(field.annotation, value)
            value = value if enum_val is None else enum_val

        if is_complex or value_is_complex:
            if isinstance(value, EnvNoneType):
                # The configured "parse as None" sentinel matched; pass it through untouched.
                return value
            elif value is None:
                # field is complex but no value found so far, try explode_env_vars
                env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
                if env_val_built:
                    return env_val_built
            else:
                # field is complex and there's a value, decode that as JSON, then add explode_env_vars
                try:
                    value = self.decode_complex_value(field_name, field, value)
                except ValueError as e:
                    if not allow_parse_failure:
                        raise e

                if isinstance(value, dict):
                    # Nested env vars take precedence over keys decoded from the flat value.
                    return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
                else:
                    return value
        elif value is not None:
            # simplest case, field is not complex, we only need to add the value if it was found
            return value

    def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
        """
        Find out if a field is complex, and if so whether JSON errors should be ignored
        """
        if self.field_is_complex(field):
            allow_parse_failure = False
        elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
            # A union with a complex member may also accept a plain string, so
            # JSON decode failures are tolerated for it.
            allow_parse_failure = True
        else:
            return False, False

        return True, allow_parse_failure

    # Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
    # We have to change the method to a non-static method and use
    # `self.case_sensitive` instead in V3.
    def next_field(
        self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
    ) -> FieldInfo | None:
        """
        Find the field in a sub model by key(env name)

        By having the following models:

        ```py
        class SubSubModel(BaseSettings):
            dvals: Dict

        class SubModel(BaseSettings):
            vals: list[str]
            sub_sub_model: SubSubModel

        class Cfg(BaseSettings):
            sub_model: SubModel
        ```

        Then:
        next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
        next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class

        Args:
            field: The field.
            key: The key (env name).
            case_sensitive: Whether to search for key case sensitively.

        Returns:
            Field if it finds the next field otherwise `None`.
        """
        if not field:
            return None

        annotation = field.annotation if isinstance(field, FieldInfo) else field
        # For unions/generics, recurse into each type argument and return the first match.
        for type_ in get_args(annotation):
            type_has_key = self.next_field(type_, key, case_sensitive)
            if type_has_key:
                return type_has_key
        if is_model_class(annotation) or is_pydantic_dataclass(annotation):  # type: ignore[arg-type]
            fields = _get_model_fields(annotation)
            # `case_sensitive is None` is here to be compatible with the old behavior.
            # Has to be removed in V3.
            for field_name, f in fields.items():
                for _, env_name, _ in self._extract_field_info(f, field_name):
                    if case_sensitive is None or case_sensitive:
                        if field_name == key or env_name == key:
                            return f
                    elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
                        return f
        return None

    def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
        """
        Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

        This is applied to a single field, hence filtering by env_var prefix.

        Args:
            field_name: The field name.
            field: The field.
            env_vars: Environment variables.

        Returns:
            A dictionary contains extracted values from nested env values.
        """
        if not self.env_nested_delimiter:
            return {}

        ann = field.annotation
        # Dict-typed fields accept arbitrary nested keys, not just declared sub-model fields.
        is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)

        prefixes = [
            f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
        ]
        result: dict[str, Any] = {}
        for env_name, env_val in env_vars.items():
            try:
                prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
            except StopIteration:
                continue
            # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
            env_name_without_prefix = env_name[len(prefix) :]
            *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
            # Walk/create the nested dict structure for the intermediate keys.
            env_var = result
            target_field: FieldInfo | None = field
            for key in keys:
                target_field = self.next_field(target_field, key, self.case_sensitive)
                if isinstance(env_var, dict):
                    env_var = env_var.setdefault(key, {})

            # get proper field with last_key
            target_field = self.next_field(target_field, last_key, self.case_sensitive)

            # check if env_val maps to a complex field and if so, parse the env_val
            if (target_field or is_dict) and env_val:
                if target_field:
                    is_complex, allow_json_failure = self._field_is_complex(target_field)
                    if self.env_parse_enums:
                        enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
                        env_val = env_val if enum_val is None else enum_val
                else:
                    # nested field type is dict
                    is_complex, allow_json_failure = True, True
                if is_complex:
                    try:
                        env_val = self.decode_complex_value(last_key, target_field, env_val)  # type: ignore
                    except ValueError as e:
                        if not allow_json_failure:
                            raise e
            if isinstance(env_var, dict):
                # Don't let an explicit "None" sentinel overwrite an already-built non-empty subtree.
                if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
                    env_var[last_key] = env_val

        return result

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
            f'env_prefix_len={self.env_prefix_len!r})'
        )
|
||||
|
||||
|
||||
__all__ = ['EnvSettingsSource']
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
Credentials = None
|
||||
SecretManagerServiceClient = None
|
||||
google_auth_default = None
|
||||
|
||||
|
||||
def import_gcp_secret_manager() -> None:
    """Lazily import the optional GCP Secret Manager dependencies into module globals.

    Raises:
        ImportError: If the `gcp-secret-manager` extra is not installed.
    """
    global Credentials, SecretManagerServiceClient, google_auth_default

    try:
        # Rebind the module-level placeholders with the real objects.
        from google.auth import default as google_auth_default
        from google.auth.credentials import Credentials
        from google.cloud.secretmanager import SecretManagerServiceClient
    except ImportError as exc:  # pragma: no cover
        raise ImportError(
            'GCP Secret Manager dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`'
        ) from exc
|
||||
|
||||
|
||||
class GoogleSecretManagerMapping(Mapping[str, Optional[str]]):
    """Lazy, read-only mapping over the secrets of a GCP project.

    Secret names are listed once and cached (`_secret_names`); secret payloads
    are fetched on first access and cached in `_loaded_secrets`.
    """

    # Cache of secret payloads fetched so far, keyed by (possibly lowercased) secret name.
    _loaded_secrets: dict[str, str | None]
    _secret_client: SecretManagerServiceClient

    def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None:
        self._loaded_secrets = {}
        self._secret_client = secret_client
        self._project_id = project_id
        self._case_sensitive = case_sensitive

    @property
    def _gcp_project_path(self) -> str:
        # Resource path of the project as built by the client library.
        return self._secret_client.common_project_path(self._project_id)

    @cached_property
    def _secret_names(self) -> list[str]:
        # Computed once per mapping instance; names are lowercased when case-insensitive.
        rv: list[str] = []

        secrets = self._secret_client.list_secrets(parent=self._gcp_project_path)
        for secret in secrets:
            name = self._secret_client.parse_secret_path(secret.name).get('secret', '')
            if not self._case_sensitive:
                name = name.lower()
            rv.append(name)
        return rv

    def _secret_version_path(self, key: str, version: str = 'latest') -> str:
        # Full resource path of a secret version; defaults to the latest version.
        return self._secret_client.secret_version_path(self._project_id, key, version)

    def __getitem__(self, key: str) -> str | None:
        if not self._case_sensitive:
            # NOTE(review): the lowercased key is also used to build the version
            # path below — assumes secret ids are lowercase in GCP; confirm.
            key = key.lower()
        if key not in self._loaded_secrets:
            # If we know the key isn't available in secret manager, raise a key error
            if key not in self._secret_names:
                raise KeyError(key)

            try:
                self._loaded_secrets[key] = self._secret_client.access_secret_version(
                    name=self._secret_version_path(key)
                ).payload.data.decode('UTF-8')
            except Exception:
                # Any API failure (permissions, disabled version, ...) surfaces as a missing key.
                raise KeyError(key)

        return self._loaded_secrets[key]

    def __len__(self) -> int:
        return len(self._secret_names)

    def __iter__(self) -> Iterator[str]:
        return iter(self._secret_names)
|
||||
|
||||
|
||||
class GoogleSecretManagerSettingsSource(EnvSettingsSource):
    """Settings source that loads values from GCP Secret Manager.

    Reuses the `EnvSettingsSource` machinery by exposing the project's secrets
    as the "environment" mapping (see `_load_env_vars`).
    """

    _credentials: Credentials
    _secret_client: SecretManagerServiceClient
    _project_id: str

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        credentials: Credentials | None = None,
        project_id: str | None = None,
        env_prefix: str | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
        secret_client: SecretManagerServiceClient | None = None,
        case_sensitive: bool | None = True,
    ) -> None:
        # Import Google Packages if they haven't already been imported
        if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None:
            import_gcp_secret_manager()

        # If credentials or project_id are not passed, then
        # try to get them from the default function
        if not credentials or not project_id:
            _creds, _project_id = google_auth_default()  # type: ignore[no-untyped-call]

            # Set the credentials and/or project id if they weren't specified
            if credentials is None:
                credentials = _creds

            if project_id is None:
                if isinstance(_project_id, str):
                    project_id = _project_id
                else:
                    # `google.auth.default` may return `None` for the project; a
                    # usable project id is mandatory for building secret paths.
                    raise AttributeError(
                        'project_id is required to be specified either as an argument or from the google.auth.default. See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default'
                    )

        self._credentials: Credentials = credentials
        self._project_id: str = project_id

        # Prefer a caller-supplied client (e.g. for testing); otherwise build one.
        if secret_client:
            self._secret_client = secret_client
        else:
            self._secret_client = SecretManagerServiceClient(credentials=self._credentials)

        super().__init__(
            settings_cls,
            case_sensitive=case_sensitive,
            env_prefix=env_prefix,
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, Optional[str]]:
        # Secrets are looked up lazily through this mapping instead of `os.environ`.
        return GoogleSecretManagerMapping(
            self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive
        )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping']
|
||||
@@ -0,0 +1,47 @@
|
||||
"""JSON file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that reads its values from a JSON configuration file."""

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        json_file: PathType | None = DEFAULT_PATH,
        json_file_encoding: str | None = None,
    ):
        model_config = settings_cls.model_config
        # The sentinel `DEFAULT_PATH` means "fall back to model_config"; an
        # explicit argument (including `None`) always wins.
        if json_file == DEFAULT_PATH:
            self.json_file_path = model_config.get('json_file')
        else:
            self.json_file_path = json_file
        if json_file_encoding is None:
            json_file_encoding = model_config.get('json_file_encoding')
        self.json_file_encoding = json_file_encoding
        self.json_data = self._read_files(self.json_file_path)
        super().__init__(settings_cls, self.json_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse a single JSON file into a dict."""
        with open(file_path, encoding=self.json_file_encoding) as stream:
            return json.load(stream)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(json_file={self.json_file_path})'
|
||||
|
||||
|
||||
__all__ = ['JsonConfigSettingsSource']
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Pyproject TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .toml import TomlConfigSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource):
    """
    A source class that loads variables from a `pyproject.toml` file.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        toml_file: Path | None = None,
    ) -> None:
        self.toml_file_path = self._pick_pyproject_toml_file(
            toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0)
        )
        # Table to read inside the TOML document, e.g. `[tool.pydantic-settings]`.
        self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
            'pyproject_toml_table_header', ('tool', 'pydantic-settings')
        )
        self.toml_data = self._read_files(self.toml_file_path)
        # Drill down into the configured table; any missing key yields an empty config.
        for key in self.toml_table_header:
            self.toml_data = self.toml_data.get(key, {})
        # Deliberately skip `TomlConfigSettingsSource.__init__` (it would re-resolve
        # the file path); call the grandparent initializer with the extracted table.
        super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data)

    @staticmethod
    def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path:
        """Pick a `pyproject.toml` file path to use.

        Args:
            provided: Explicit path provided when instantiating this class.
            depth: Number of directories up the tree to check for a pyproject.toml.

        Returns:
            The explicit path, the first `pyproject.toml` found walking up at most
            `depth` levels, or the (possibly nonexistent) CWD path as a fallback.
        """
        if provided:
            return provided.resolve()
        rv = Path.cwd() / 'pyproject.toml'
        count = 0
        if not rv.is_file():
            # Walk up from the current directory, at most `depth` levels.
            child = rv.parent.parent / 'pyproject.toml'
            while count < depth:
                if child.is_file():
                    return child
                if str(child.parent) == rv.root:
                    break  # end discovery after checking system root once
                child = child.parent.parent / 'pyproject.toml'
                count += 1
        return rv
|
||||
|
||||
|
||||
__all__ = ['PyprojectTomlConfigSettingsSource']
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Secrets file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from pydantic_settings.utils import path_type_label
|
||||
|
||||
from ...exceptions import SettingsError
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from secret files.

    Each secret is a file whose name is the (prefixed) field env name and whose
    contents are the value, typically mounted under one or more secrets directories.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        secrets_dir: PathType | None = None,
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(
            settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
        )
        # Explicit argument wins over `model_config['secrets_dir']`.
        self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')

    def __call__(self) -> dict[str, Any]:
        """
        Build fields from "secrets" files.

        Validates `secrets_dir` and populates `self.secrets_paths` (consumed by
        `get_field_value`), then delegates to the base implementation.

        Raises:
            SettingsError: If a configured secrets path exists but is not a directory.
        """
        secrets: dict[str, str | None] = {}

        if self.secrets_dir is None:
            return secrets

        # `secrets_dir` may be a single path or a sequence of paths.
        secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir
        secrets_paths = [Path(p).expanduser() for p in secrets_dirs]
        self.secrets_paths = []

        for path in secrets_paths:
            if not path.exists():
                # Missing directories are skipped with a warning, not an error.
                # NOTE(review): no `stacklevel` here, unlike the warning in
                # `get_field_value` — consider aligning.
                warnings.warn(f'directory "{path}" does not exist')
            else:
                self.secrets_paths.append(path)

        if not len(self.secrets_paths):
            return secrets

        for path in self.secrets_paths:
            if not path.is_dir():
                raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')

        return super().__call__()

    @classmethod
    def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
        """
        Find a file within path's directory matching filename, optionally ignoring case.

        Args:
            dir_path: Directory path.
            file_name: File name.
            case_sensitive: Whether to search for file name case sensitively.

        Returns:
            Whether file path or `None` if file does not exist in directory.
        """
        for f in dir_path.iterdir():
            if f.name == file_name:
                return f
            elif not case_sensitive and f.name.lower() == file_name.lower():
                return f
        return None

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from secret file and a flag to determine whether value is complex.

        Relies on `self.secrets_paths` having been populated by `__call__`.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value (`None` if the file does not exist), key, and
            a flag to determine whether value is complex.
        """

        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            # paths reversed to match the last-wins behaviour of `env_file`
            for secrets_path in reversed(self.secrets_paths):
                path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
                if not path:
                    # path does not exist, we currently don't return a warning for this
                    continue

                if path.is_file():
                    # Trailing whitespace/newlines are stripped from the secret payload.
                    return path.read_text().strip(), field_key, value_is_complex
                else:
                    warnings.warn(
                        f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
                        stacklevel=4,
                    )

        return None, field_key, value_is_complex

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
|
||||
@@ -0,0 +1,66 @@
|
||||
"""TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
tomllib = None
|
||||
import tomli
|
||||
else:
|
||||
tomllib = None
|
||||
tomli = None
|
||||
|
||||
|
||||
def import_toml() -> None:
    """Lazily import a TOML parser into module globals on first use.

    On Python >= 3.11 the stdlib `tomllib` is used; earlier versions require the
    third-party `tomli` package (the `toml` extra).

    Raises:
        ImportError: On Python < 3.11 when `tomli` is not installed.
    """
    global tomli
    global tomllib
    if sys.version_info >= (3, 11):
        if tomllib is None:
            import tomllib
        return
    # Python < 3.11: fall back to the third-party `tomli` backport.
    if tomli is not None:
        return
    try:
        import tomli
    except ImportError as exc:  # pragma: no cover
        raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from exc
|
||||
|
||||
|
||||
class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that reads its values from a TOML configuration file."""

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        toml_file: PathType | None = DEFAULT_PATH,
    ):
        # The sentinel `DEFAULT_PATH` means "fall back to model_config"; an
        # explicit argument (including `None`) always wins.
        if toml_file == DEFAULT_PATH:
            self.toml_file_path = settings_cls.model_config.get('toml_file')
        else:
            self.toml_file_path = toml_file
        self.toml_data = self._read_files(self.toml_file_path)
        super().__init__(settings_cls, self.toml_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse a single TOML file into a dict, importing a parser on demand."""
        import_toml()
        with open(file_path, mode='rb') as stream:
            if sys.version_info >= (3, 11):
                return tomllib.load(stream)
            return tomli.load(stream)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(toml_file={self.toml_file_path})'
|
||||
@@ -0,0 +1,75 @@
|
||||
"""YAML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import yaml
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
yaml = None
|
||||
|
||||
|
||||
def import_yaml() -> None:
    """Lazily import PyYAML into the module-global `yaml` on first use.

    Raises:
        ImportError: If PyYAML (the `yaml` extra) is not installed.
    """
    global yaml
    # Already imported on a previous call: nothing to do.
    if yaml is not None:
        return
    try:
        import yaml
    except ImportError as exc:
        raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from exc
|
||||
|
||||
|
||||
class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that reads its values from a YAML configuration file."""

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        yaml_file: PathType | None = DEFAULT_PATH,
        yaml_file_encoding: str | None = None,
        yaml_config_section: str | None = None,
    ):
        model_config = settings_cls.model_config
        # The sentinel `DEFAULT_PATH` means "fall back to model_config"; an
        # explicit argument (including `None`) always wins.
        self.yaml_file_path = model_config.get('yaml_file') if yaml_file == DEFAULT_PATH else yaml_file
        if yaml_file_encoding is None:
            yaml_file_encoding = model_config.get('yaml_file_encoding')
        self.yaml_file_encoding = yaml_file_encoding
        if yaml_config_section is None:
            yaml_config_section = model_config.get('yaml_config_section')
        self.yaml_config_section = yaml_config_section
        self.yaml_data = self._read_files(self.yaml_file_path)

        # Optionally narrow the loaded document to a single top-level section.
        if self.yaml_config_section:
            try:
                self.yaml_data = self.yaml_data[self.yaml_config_section]
            except KeyError:
                raise KeyError(
                    f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
                )
        super().__init__(settings_cls, self.yaml_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse a single YAML file into a dict; an empty document yields `{}`."""
        import_yaml()
        with open(file_path, encoding=self.yaml_file_encoding) as stream:
            return yaml.safe_load(stream) or {}

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
|
||||
|
||||
|
||||
__all__ = ['YamlConfigSettingsSource']
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Type definitions for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic._internal._dataclasses import PydanticDataclass
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
PydanticModel = Union[PydanticDataclass, BaseModel]
|
||||
else:
|
||||
PydanticModel = Any
|
||||
|
||||
|
||||
class EnvNoneType(str):
    """Marker `str` subclass flagging a raw env value that matched the configured "none" string."""
|
||||
|
||||
|
||||
class NoDecode:
    """Annotation to prevent decoding of a field value.

    Marker class (no behavior of its own); used as annotation metadata.
    """

    pass
|
||||
|
||||
|
||||
class ForceDecode:
    """Annotation to force decoding of a field value.

    Marker class (no behavior of its own); used as annotation metadata.
    """

    pass
|
||||
|
||||
|
||||
# A dotenv location: a single path/string or a sequence of them.
DotenvType = Union[Path, str, Sequence[Union[Path, str]]]
# A generic config-file location: a single path/string or a sequence of them.
PathType = Union[Path, str, Sequence[Union[Path, str]]]
# Sentinel path letting file sources distinguish "argument not provided" from an
# explicit `None` or real path.
DEFAULT_PATH: PathType = Path('')

# This is used as default value for `_env_file` in the `BaseSettings` class and
# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`.
# See the docstring of `BaseSettings` for more details.
ENV_FILE_SENTINEL: DotenvType = Path('')
|
||||
|
||||
|
||||
class _CliSubCommand:
    """Internal marker class tagging a field as a CLI subcommand."""

    pass
|
||||
|
||||
|
||||
class _CliPositionalArg:
    """Internal marker class tagging a field as a CLI positional argument."""

    pass
|
||||
|
||||
|
||||
class _CliImplicitFlag:
    """Internal marker class tagging a boolean field as an implicit CLI flag."""

    pass
|
||||
|
||||
|
||||
class _CliExplicitFlag:
    """Internal marker class tagging a boolean field as an explicit CLI flag."""

    pass
|
||||
|
||||
|
||||
class _CliUnknownArgs:
    """Internal marker class tagging a field that collects unknown CLI arguments."""

    pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
'DEFAULT_PATH',
|
||||
'ENV_FILE_SENTINEL',
|
||||
'DotenvType',
|
||||
'EnvNoneType',
|
||||
'ForceDecode',
|
||||
'NoDecode',
|
||||
'PathType',
|
||||
'PydanticModel',
|
||||
'_CliExplicitFlag',
|
||||
'_CliImplicitFlag',
|
||||
'_CliPositionalArg',
|
||||
'_CliSubCommand',
|
||||
'_CliUnknownArgs',
|
||||
]
|
||||
@@ -0,0 +1,206 @@
|
||||
"""Utility functions for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Json, RootModel, Secret
|
||||
from pydantic._internal._utils import is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from typing_extensions import get_args, get_origin
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from ..exceptions import SettingsError
|
||||
from ..utils import _lenient_issubclass
|
||||
from .types import EnvNoneType
|
||||
|
||||
|
||||
def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
|
||||
return key if case_sensitive else key.lower()
|
||||
|
||||
|
||||
def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType:
    """Wrap `value` in `EnvNoneType` when it equals the configured "none" sentinel string."""
    if parse_none_str is not None and value == parse_none_str:
        return EnvNoneType(value)
    return value
|
||||
|
||||
|
||||
def parse_env_vars(
    env_vars: Mapping[str, str | None],
    case_sensitive: bool = False,
    ignore_empty: bool = False,
    parse_none_str: str | None = None,
) -> Mapping[str, str | None]:
    """Normalize a raw env mapping.

    Optionally drops empty-string values, normalizes key case, and wraps values
    matching `parse_none_str` in `EnvNoneType`.
    """
    parsed: dict[str, str | None] = {}
    for name, raw in env_vars.items():
        if ignore_empty and raw == '':
            continue
        parsed[_get_env_var_key(name, case_sensitive)] = _parse_env_none_str(raw, parse_none_str)
    return parsed
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool:
    """Return True if `annotation` denotes a "complex" type whose raw string value must be decoded."""
    # Resolve type aliases (`TypeAliasType`) to their underlying value first.
    if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)):
        annotation = annotation.__value__
    # If the model is a root model, the root annotation should be used to
    # evaluate the complexity.
    if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
        annotation = cast('type[RootModel[Any]]', annotation)
        root_annotation = annotation.model_fields['root'].annotation
        if root_annotation is not None:  # pragma: no branch
            annotation = root_annotation

    # `Json` metadata means pydantic decodes the value itself — not complex here.
    if any(isinstance(md, Json) for md in metadata):  # type: ignore[misc]
        return False

    origin = get_origin(annotation)

    # Check if annotation is of the form Annotated[type, metadata].
    if typing_objects.is_annotated(origin):
        # Return result of recursive call on inner type.
        inner, *meta = get_args(annotation)
        return _annotation_is_complex(inner, meta)

    # `Secret[...]` wrappers are treated as simple values.
    if origin is Secret:
        return False

    return (
        _annotation_is_complex_inner(annotation)
        or _annotation_is_complex_inner(origin)
        or hasattr(origin, '__pydantic_core_schema__')
        or hasattr(origin, '__get_pydantic_core_schema__')
    )
|
||||
|
||||
|
||||
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
    """Return True when `annotation` itself is container-like; `str`/`bytes` are explicitly simple."""
    if _lenient_issubclass(annotation, (str, bytes)):
        return False

    # Models, mappings, sequences, and other containers require structured decoding;
    # dataclasses count as models here.
    return _lenient_issubclass(
        annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
    ) or is_dataclass(annotation)
|
||||
|
||||
|
||||
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
    """Check if a union type contains any complex types."""
    for member in get_args(annotation):
        if _annotation_is_complex(member, metadata):
            return True
    return False
|
||||
|
||||
|
||||
def _annotation_contains_types(
|
||||
annotation: type[Any] | None,
|
||||
types: tuple[Any, ...],
|
||||
is_include_origin: bool = True,
|
||||
is_strip_annotated: bool = False,
|
||||
) -> bool:
|
||||
"""Check if a type annotation contains any of the specified types."""
|
||||
if is_strip_annotated:
|
||||
annotation = _strip_annotated(annotation)
|
||||
if is_include_origin is True and get_origin(annotation) in types:
|
||||
return True
|
||||
for type_ in get_args(annotation):
|
||||
if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated):
|
||||
return True
|
||||
return annotation in types
|
||||
|
||||
|
||||
def _strip_annotated(annotation: Any) -> Any:
    """Unwrap ``Annotated[T, ...]`` to ``T``; return any other annotation unchanged."""
    if not typing_objects.is_annotated(get_origin(annotation)):
        return annotation
    # For Annotated aliases, ``__origin__`` holds the wrapped type.
    return annotation.__origin__
|
||||
|
||||
|
||||
def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]:
    """Translate an enum *value* into its member name.

    Searches the annotation itself, its origin, and its type arguments for an
    ``Enum`` subclass that has a member with the given value.

    Args:
        annotation: The annotation (possibly generic or a union) to search.
        value: A candidate enum member value.

    Returns:
        The matching member's name, or None when no Enum member has that value.
    """
    for candidate in (annotation, get_origin(annotation), *get_args(annotation)):
        if _lenient_issubclass(candidate, Enum):
            # Short-circuit on the first matching member instead of materializing
            # a tuple of every member value before the membership test.
            if any(value == member.value for member in candidate):
                return candidate(value).name
    return None
|
||||
|
||||
|
||||
def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
    """Translate an enum member *name* into the enum member itself.

    Searches the annotation itself, its origin, and its type arguments for an
    ``Enum`` subclass that has a member with the given name.

    Args:
        annotation: The annotation (possibly generic or a union) to search.
        name: A candidate enum member name.

    Returns:
        The matching enum member, or None when no Enum type defines that name.
    """
    for candidate in (annotation, get_origin(annotation), *get_args(annotation)):
        if _lenient_issubclass(candidate, Enum):
            # Short-circuit on the first matching member instead of materializing
            # a tuple of every member name before the membership test.
            if any(name == member.name for member in candidate):
                return candidate[name]
    return None
|
||||
|
||||
|
||||
def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]:
    """Get fields from a pydantic model or dataclass.

    Args:
        model_cls: A pydantic ``BaseModel`` subclass or a pydantic dataclass.

    Returns:
        Mapping of field name to the corresponding field info object.

    Raises:
        SettingsError: If ``model_cls`` is neither a pydantic model class nor a
            pydantic dataclass.
    """
    # Pydantic dataclasses expose their fields on ``__pydantic_fields__``.
    # NOTE(review): the extra hasattr guard presumably covers partially-built
    # dataclasses — confirm against the pydantic dataclass machinery.
    if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
        return model_cls.__pydantic_fields__
    if is_model_class(model_cls):
        return model_cls.model_fields
    raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
|
||||
|
||||
|
||||
def _get_alias_names(
    field_name: str,
    field_info: Any,
    alias_path_args: Optional[dict[str, Optional[int]]] = None,
    case_sensitive: bool = True,
) -> tuple[tuple[str, ...], bool]:
    """Get alias names for a field, handling alias paths and case sensitivity.

    Args:
        field_name: The field's attribute name; used when no alias is configured.
        field_info: The pydantic field info carrying ``alias`` / ``validation_alias``.
        alias_path_args: Optional mapping mutated in place: for each ``AliasPath``
            found, its first (string) path element is mapped to the second path
            element when that is an int, else to None.
        case_sensitive: When False, every returned name is lower-cased.

    Returns:
        ``(names, is_alias_path_only)`` — the de-duplicated alias names in
        first-seen order, and a flag that is True when the only aliases found
        were ``AliasPath`` instances.
    """
    # Local import of the pydantic alias types used below.
    from pydantic import AliasChoices, AliasPath

    alias_names: list[str] = []
    is_alias_path_only: bool = True
    if not any((field_info.alias, field_info.validation_alias)):
        # No alias configured: the field name itself is the only lookup name.
        alias_names += [field_name]
        is_alias_path_only = False
    else:
        new_alias_paths: list[AliasPath] = []
        for alias in (field_info.alias, field_info.validation_alias):
            if alias is None:
                continue
            elif isinstance(alias, str):
                alias_names.append(alias)
                is_alias_path_only = False
            elif isinstance(alias, AliasChoices):
                # AliasChoices may mix plain strings and AliasPath entries.
                for name in alias.choices:
                    if isinstance(name, str):
                        alias_names.append(name)
                        is_alias_path_only = False
                    else:
                        new_alias_paths.append(name)
            else:
                # Remaining case: the alias is an AliasPath.
                new_alias_paths.append(alias)
        for alias_path in new_alias_paths:
            # Only the first path element (the top-level key) is usable as a name.
            name = cast(str, alias_path.path[0])
            name = name.lower() if not case_sensitive else name
            if alias_path_args is not None:
                # Record the list index (second path element, when it is an int)
                # for this top-level key; None when the path has no int index there.
                alias_path_args[name] = (
                    alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None
                )
            # When only AliasPath aliases exist, expose the first path's
            # top-level key as the field's single lookup name.
            if not alias_names and is_alias_path_only:
                alias_names.append(name)
    if not case_sensitive:
        alias_names = [alias_name.lower() for alias_name in alias_names]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return tuple(dict.fromkeys(alias_names)), is_alias_path_only
|
||||
|
||||
|
||||
def _is_function(obj: Any) -> bool:
|
||||
"""Check if an object is a function."""
|
||||
from types import BuiltinFunctionType, FunctionType
|
||||
|
||||
return isinstance(obj, (FunctionType, BuiltinFunctionType))
|
||||
|
||||
|
||||
# Explicit export list for this module.
# NOTE(review): underscore-prefixed helpers are exported deliberately —
# presumably consumed by sibling modules in this package. Some listed names
# (_get_env_var_key, _parse_env_none_str, parse_env_vars) are defined elsewhere
# in this file; verify they exist before trimming this list.
__all__ = [
    '_annotation_contains_types',
    '_annotation_enum_name_to_val',
    '_annotation_enum_val_to_name',
    '_annotation_is_complex',
    '_annotation_is_complex_inner',
    '_get_alias_names',
    '_get_env_var_key',
    '_get_model_fields',
    '_is_function',
    '_parse_env_none_str',
    '_strip_annotated',
    '_union_is_complex',
    'parse_env_vars',
]
|
||||
@@ -1,14 +1,19 @@
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, _GenericAlias # type: ignore [attr-defined]
|
||||
|
||||
path_type_labels = {
|
||||
'is_dir': 'directory',
|
||||
'is_file': 'file',
|
||||
'is_mount': 'mount point',
|
||||
'is_symlink': 'symlink',
|
||||
'is_block_device': 'block device',
|
||||
'is_char_device': 'char device',
|
||||
'is_fifo': 'FIFO',
|
||||
'is_socket': 'socket',
|
||||
from typing_extensions import get_origin
|
||||
|
||||
# ``Path`` inspection predicates mapped to the human-readable label reported
# for a matching path; when iterated, the first predicate that returns True wins.
_PATH_TYPE_LABELS = {
    Path.is_dir: 'directory',
    Path.is_file: 'file',
    Path.is_mount: 'mount point',
    Path.is_symlink: 'symlink',
    Path.is_block_device: 'block device',
    Path.is_char_device: 'char device',
    Path.is_fifo: 'FIFO',
    Path.is_socket: 'socket',
}
|
||||
|
||||
|
||||
@@ -17,8 +22,27 @@ def path_type_label(p: Path) -> str:
|
||||
Find out what sort of thing a path is.
|
||||
"""
|
||||
assert p.exists(), 'path does not exist'
|
||||
for method, name in path_type_labels.items():
|
||||
if getattr(p, method)():
|
||||
for method, name in _PATH_TYPE_LABELS.items():
|
||||
if method(p):
|
||||
return name
|
||||
|
||||
return 'unknown'
|
||||
return 'unknown' # pragma: no cover
|
||||
|
||||
|
||||
# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
|
||||
# once we drop support for Python 3.10.
|
||||
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool:  # pragma: no cover
    """``issubclass`` that tolerates non-class inputs instead of raising.

    Args:
        cls: The object to test; may be a class, a generic alias, or anything else.
        class_or_tuple: A class or tuple of classes, as accepted by ``issubclass``.

    Returns:
        True when ``cls`` is a class and a subclass of ``class_or_tuple``;
        False for non-classes, including parametrized generics like ``list[int]``.
    """
    try:
        return isinstance(cls, type) and issubclass(cls, class_or_tuple)
    except TypeError:
        if get_origin(cls) is not None:
            # Up until Python 3.10, isinstance(<generic_alias>, type) is True
            # (e.g. list[int]), so issubclass raises TypeError for them.
            return False
        # Any other TypeError (e.g. a bad ``class_or_tuple``) is a real error.
        raise
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
_WithArgsTypes = tuple()
|
||||
else:
|
||||
_WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType)
|
||||
|
||||
@@ -1 +1 @@
|
||||
VERSION = '2.0.3'
|
||||
VERSION = '2.11.0'
|
||||
|
||||
Reference in New Issue
Block a user