This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
"""Package containing individual source implementations."""
|
||||
|
||||
from .aws import AWSSecretsManagerSettingsSource
|
||||
from .azure import AzureKeyVaultSettingsSource
|
||||
from .cli import (
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
)
|
||||
from .dotenv import DotEnvSettingsSource
|
||||
from .env import EnvSettingsSource
|
||||
from .gcp import GoogleSecretManagerSettingsSource
|
||||
from .json import JsonConfigSettingsSource
|
||||
from .pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .secrets import SecretsSettingsSource
|
||||
from .toml import TomlConfigSettingsSource
|
||||
from .yaml import YamlConfigSettingsSource
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
]
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations as _annotations # important for BaseSettings import to work
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..utils import parse_env_vars
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
boto3_client = None
|
||||
SecretsManagerClient = None
|
||||
|
||||
|
||||
def import_aws_secrets_manager() -> None:
|
||||
global boto3_client
|
||||
global SecretsManagerClient
|
||||
|
||||
try:
|
||||
from boto3 import client as boto3_client
|
||||
from mypy_boto3_secretsmanager.client import SecretsManagerClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
|
||||
) from e
|
||||
|
||||
|
||||
class AWSSecretsManagerSettingsSource(EnvSettingsSource):
|
||||
_secret_id: str
|
||||
_secretsmanager_client: SecretsManagerClient # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
secret_id: str,
|
||||
region_name: str | None = None,
|
||||
endpoint_url: str | None = None,
|
||||
case_sensitive: bool | None = True,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = '--',
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
import_aws_secrets_manager()
|
||||
self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url) # type: ignore
|
||||
self._secret_id = secret_id
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
|
||||
response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id) # type: ignore
|
||||
|
||||
return parse_env_vars(
|
||||
json.loads(response['SecretString']),
|
||||
self.case_sensitive,
|
||||
self.env_ignore_empty,
|
||||
self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(secret_id={self._secret_id!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
]
|
||||
@@ -0,0 +1,145 @@
|
||||
"""Azure Key Vault settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from pydantic.alias_generators import to_snake
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
TokenCredential = None
|
||||
ResourceNotFoundError = None
|
||||
SecretClient = None
|
||||
|
||||
|
||||
def import_azure_key_vault() -> None:
|
||||
global TokenCredential
|
||||
global SecretClient
|
||||
global ResourceNotFoundError
|
||||
|
||||
try:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
|
||||
) from e
|
||||
|
||||
|
||||
class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
|
||||
_loaded_secrets: dict[str, str | None]
|
||||
_secret_client: SecretClient
|
||||
_secret_names: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret_client: SecretClient,
|
||||
case_sensitive: bool,
|
||||
snake_case_conversion: bool,
|
||||
) -> None:
|
||||
self._loaded_secrets = {}
|
||||
self._secret_client = secret_client
|
||||
self._case_sensitive = case_sensitive
|
||||
self._snake_case_conversion = snake_case_conversion
|
||||
self._secret_map: dict[str, str] = self._load_remote()
|
||||
|
||||
def _load_remote(self) -> dict[str, str]:
|
||||
secret_names: Iterator[str] = (
|
||||
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
|
||||
)
|
||||
|
||||
if self._snake_case_conversion:
|
||||
return {to_snake(name): name for name in secret_names}
|
||||
|
||||
if self._case_sensitive:
|
||||
return {name: name for name in secret_names}
|
||||
|
||||
return {name.lower(): name for name in secret_names}
|
||||
|
||||
def __getitem__(self, key: str) -> str | None:
|
||||
new_key = key
|
||||
|
||||
if self._snake_case_conversion:
|
||||
new_key = to_snake(key)
|
||||
elif not self._case_sensitive:
|
||||
new_key = key.lower()
|
||||
|
||||
if new_key not in self._loaded_secrets:
|
||||
if new_key in self._secret_map:
|
||||
self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._loaded_secrets[new_key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._secret_map)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._secret_map.keys())
|
||||
|
||||
|
||||
class AzureKeyVaultSettingsSource(EnvSettingsSource):
|
||||
_url: str
|
||||
_credential: TokenCredential
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
url: str,
|
||||
credential: TokenCredential,
|
||||
dash_to_underscore: bool = False,
|
||||
case_sensitive: bool | None = None,
|
||||
snake_case_conversion: bool = False,
|
||||
env_prefix: str | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
import_azure_key_vault()
|
||||
self._url = url
|
||||
self._credential = credential
|
||||
self._dash_to_underscore = dash_to_underscore
|
||||
self._snake_case_conversion = snake_case_conversion
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=False if snake_case_conversion else case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter='__' if snake_case_conversion else '--',
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
|
||||
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
|
||||
return AzureKeyVaultMapping(
|
||||
secret_client=secret_client,
|
||||
case_sensitive=self.case_sensitive,
|
||||
snake_case_conversion=self._snake_case_conversion,
|
||||
)
|
||||
|
||||
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
|
||||
if self._snake_case_conversion:
|
||||
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
|
||||
|
||||
if self._dash_to_underscore:
|
||||
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
|
||||
|
||||
return super()._extract_field_info(field, field_name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,168 @@
|
||||
"""Dotenv file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dotenv import dotenv_values
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..types import ENV_FILE_SENTINEL, DotenvType
|
||||
from ..utils import (
|
||||
_annotation_is_complex,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class DotEnvSettingsSource(EnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from env files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
env_file: DotenvType | None = ENV_FILE_SENTINEL,
|
||||
env_file_encoding: str | None = None,
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = None,
|
||||
env_nested_max_split: int | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
|
||||
self.env_file_encoding = (
|
||||
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
|
||||
)
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive,
|
||||
env_prefix,
|
||||
env_nested_delimiter,
|
||||
env_nested_max_split,
|
||||
env_ignore_empty,
|
||||
env_parse_none_str,
|
||||
env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return self._read_env_files()
|
||||
|
||||
@staticmethod
|
||||
def _static_read_env_file(
|
||||
file_path: Path,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
|
||||
return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str)
|
||||
|
||||
def _read_env_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
) -> Mapping[str, str | None]:
|
||||
return self._static_read_env_file(
|
||||
file_path,
|
||||
encoding=self.env_file_encoding,
|
||||
case_sensitive=self.case_sensitive,
|
||||
ignore_empty=self.env_ignore_empty,
|
||||
parse_none_str=self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def _read_env_files(self) -> Mapping[str, str | None]:
|
||||
env_files = self.env_file
|
||||
if env_files is None:
|
||||
return {}
|
||||
|
||||
if isinstance(env_files, (str, os.PathLike)):
|
||||
env_files = [env_files]
|
||||
|
||||
dotenv_vars: dict[str, str | None] = {}
|
||||
for env_file in env_files:
|
||||
env_path = Path(env_file).expanduser()
|
||||
if env_path.is_file():
|
||||
dotenv_vars.update(self._read_env_file(env_path))
|
||||
|
||||
return dotenv_vars
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
data: dict[str, Any] = super().__call__()
|
||||
is_extra_allowed = self.config.get('extra') != 'forbid'
|
||||
|
||||
# As `extra` config is allowed in dotenv settings source, We have to
|
||||
# update data with extra env variables from dotenv file.
|
||||
for env_name, env_value in self.env_vars.items():
|
||||
if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields):
|
||||
continue
|
||||
env_used = False
|
||||
for field_name, field in self.settings_cls.model_fields.items():
|
||||
for _, field_env_name, _ in self._extract_field_info(field, field_name):
|
||||
if env_name == field_env_name or (
|
||||
(
|
||||
_annotation_is_complex(field.annotation, field.metadata)
|
||||
or (
|
||||
is_union_origin(get_origin(field.annotation))
|
||||
and _union_is_complex(field.annotation, field.metadata)
|
||||
)
|
||||
)
|
||||
and env_name.startswith(field_env_name)
|
||||
):
|
||||
env_used = True
|
||||
break
|
||||
if env_used:
|
||||
break
|
||||
if not env_used:
|
||||
if is_extra_allowed and env_name.startswith(self.env_prefix):
|
||||
# env_prefix should be respected and removed from the env_name
|
||||
normalized_env_name = env_name[len(self.env_prefix) :]
|
||||
data[normalized_env_name] = env_value
|
||||
else:
|
||||
data[env_name] = env_value
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
|
||||
)
|
||||
|
||||
|
||||
def read_env_file(
|
||||
file_path: Path,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
warnings.warn(
|
||||
'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return DotEnvSettingsSource._static_read_env_file(
|
||||
file_path,
|
||||
encoding=encoding,
|
||||
case_sensitive=case_sensitive,
|
||||
ignore_empty=ignore_empty,
|
||||
parse_none_str=parse_none_str,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['DotEnvSettingsSource', 'read_env_file']
|
||||
@@ -0,0 +1,270 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import get_args, get_origin
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ...utils import _lenient_issubclass
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import EnvNoneType
|
||||
from ..utils import (
|
||||
_annotation_enum_name_to_val,
|
||||
_get_model_fields,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from environment variables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = None,
|
||||
env_nested_max_split: int | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
|
||||
)
|
||||
self.env_nested_delimiter = (
|
||||
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
|
||||
)
|
||||
self.env_nested_max_split = (
|
||||
env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
|
||||
)
|
||||
self.maxsplit = (self.env_nested_max_split or 0) - 1
|
||||
self.env_prefix_len = len(self.env_prefix)
|
||||
|
||||
self.env_vars = self._load_env_vars()
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value for field from environment variables and a flag to determine whether value is complex.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value (`None` if not found), key, and
|
||||
a flag to determine whether value is complex.
|
||||
"""
|
||||
|
||||
env_val: str | None = None
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
env_val = self.env_vars.get(env_name)
|
||||
if env_val is not None:
|
||||
break
|
||||
|
||||
return env_val, field_key, value_is_complex
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
"""
|
||||
Prepare value for the field.
|
||||
|
||||
* Extract value for nested field.
|
||||
* Deserialize value to python object for complex field.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple contains prepared value for the field.
|
||||
|
||||
Raises:
|
||||
ValuesError: When There is an error in deserializing value for complex field.
|
||||
"""
|
||||
is_complex, allow_parse_failure = self._field_is_complex(field)
|
||||
if self.env_parse_enums:
|
||||
enum_val = _annotation_enum_name_to_val(field.annotation, value)
|
||||
value = value if enum_val is None else enum_val
|
||||
|
||||
if is_complex or value_is_complex:
|
||||
if isinstance(value, EnvNoneType):
|
||||
return value
|
||||
elif value is None:
|
||||
# field is complex but no value found so far, try explode_env_vars
|
||||
env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
|
||||
if env_val_built:
|
||||
return env_val_built
|
||||
else:
|
||||
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
|
||||
try:
|
||||
value = self.decode_complex_value(field_name, field, value)
|
||||
except ValueError as e:
|
||||
if not allow_parse_failure:
|
||||
raise e
|
||||
|
||||
if isinstance(value, dict):
|
||||
return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
|
||||
else:
|
||||
return value
|
||||
elif value is not None:
|
||||
# simplest case, field is not complex, we only need to add the value if it was found
|
||||
return value
|
||||
|
||||
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
|
||||
"""
|
||||
Find out if a field is complex, and if so whether JSON errors should be ignored
|
||||
"""
|
||||
if self.field_is_complex(field):
|
||||
allow_parse_failure = False
|
||||
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
|
||||
allow_parse_failure = True
|
||||
else:
|
||||
return False, False
|
||||
|
||||
return True, allow_parse_failure
|
||||
|
||||
# Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
|
||||
# We have to change the method to a non-static method and use
|
||||
# `self.case_sensitive` instead in V3.
|
||||
def next_field(
|
||||
self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
|
||||
) -> FieldInfo | None:
|
||||
"""
|
||||
Find the field in a sub model by key(env name)
|
||||
|
||||
By having the following models:
|
||||
|
||||
```py
|
||||
class SubSubModel(BaseSettings):
|
||||
dvals: Dict
|
||||
|
||||
class SubModel(BaseSettings):
|
||||
vals: list[str]
|
||||
sub_sub_model: SubSubModel
|
||||
|
||||
class Cfg(BaseSettings):
|
||||
sub_model: SubModel
|
||||
```
|
||||
|
||||
Then:
|
||||
next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
|
||||
next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
key: The key (env name).
|
||||
case_sensitive: Whether to search for key case sensitively.
|
||||
|
||||
Returns:
|
||||
Field if it finds the next field otherwise `None`.
|
||||
"""
|
||||
if not field:
|
||||
return None
|
||||
|
||||
annotation = field.annotation if isinstance(field, FieldInfo) else field
|
||||
for type_ in get_args(annotation):
|
||||
type_has_key = self.next_field(type_, key, case_sensitive)
|
||||
if type_has_key:
|
||||
return type_has_key
|
||||
if is_model_class(annotation) or is_pydantic_dataclass(annotation): # type: ignore[arg-type]
|
||||
fields = _get_model_fields(annotation)
|
||||
# `case_sensitive is None` is here to be compatible with the old behavior.
|
||||
# Has to be removed in V3.
|
||||
for field_name, f in fields.items():
|
||||
for _, env_name, _ in self._extract_field_info(f, field_name):
|
||||
if case_sensitive is None or case_sensitive:
|
||||
if field_name == key or env_name == key:
|
||||
return f
|
||||
elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
|
||||
return f
|
||||
return None
|
||||
|
||||
def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
|
||||
"""
|
||||
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
|
||||
|
||||
This is applied to a single field, hence filtering by env_var prefix.
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
env_vars: Environment variables.
|
||||
|
||||
Returns:
|
||||
A dictionary contains extracted values from nested env values.
|
||||
"""
|
||||
if not self.env_nested_delimiter:
|
||||
return {}
|
||||
|
||||
ann = field.annotation
|
||||
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
|
||||
|
||||
prefixes = [
|
||||
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
|
||||
]
|
||||
result: dict[str, Any] = {}
|
||||
for env_name, env_val in env_vars.items():
|
||||
try:
|
||||
prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
|
||||
except StopIteration:
|
||||
continue
|
||||
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
|
||||
env_name_without_prefix = env_name[len(prefix) :]
|
||||
*keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
|
||||
env_var = result
|
||||
target_field: FieldInfo | None = field
|
||||
for key in keys:
|
||||
target_field = self.next_field(target_field, key, self.case_sensitive)
|
||||
if isinstance(env_var, dict):
|
||||
env_var = env_var.setdefault(key, {})
|
||||
|
||||
# get proper field with last_key
|
||||
target_field = self.next_field(target_field, last_key, self.case_sensitive)
|
||||
|
||||
# check if env_val maps to a complex field and if so, parse the env_val
|
||||
if (target_field or is_dict) and env_val:
|
||||
if target_field:
|
||||
is_complex, allow_json_failure = self._field_is_complex(target_field)
|
||||
if self.env_parse_enums:
|
||||
enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
|
||||
env_val = env_val if enum_val is None else enum_val
|
||||
else:
|
||||
# nested field type is dict
|
||||
is_complex, allow_json_failure = True, True
|
||||
if is_complex:
|
||||
try:
|
||||
env_val = self.decode_complex_value(last_key, target_field, env_val) # type: ignore
|
||||
except ValueError as e:
|
||||
if not allow_json_failure:
|
||||
raise e
|
||||
if isinstance(env_var, dict):
|
||||
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
|
||||
env_var[last_key] = env_val
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
|
||||
f'env_prefix_len={self.env_prefix_len!r})'
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['EnvSettingsSource']
|
||||
@@ -0,0 +1,152 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
Credentials = None
|
||||
SecretManagerServiceClient = None
|
||||
google_auth_default = None
|
||||
|
||||
|
||||
def import_gcp_secret_manager() -> None:
|
||||
global Credentials
|
||||
global SecretManagerServiceClient
|
||||
global google_auth_default
|
||||
|
||||
try:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'GCP Secret Manager dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`'
|
||||
) from e
|
||||
|
||||
|
||||
class GoogleSecretManagerMapping(Mapping[str, Optional[str]]):
|
||||
_loaded_secrets: dict[str, str | None]
|
||||
_secret_client: SecretManagerServiceClient
|
||||
|
||||
def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None:
|
||||
self._loaded_secrets = {}
|
||||
self._secret_client = secret_client
|
||||
self._project_id = project_id
|
||||
self._case_sensitive = case_sensitive
|
||||
|
||||
@property
|
||||
def _gcp_project_path(self) -> str:
|
||||
return self._secret_client.common_project_path(self._project_id)
|
||||
|
||||
@cached_property
|
||||
def _secret_names(self) -> list[str]:
|
||||
rv: list[str] = []
|
||||
|
||||
secrets = self._secret_client.list_secrets(parent=self._gcp_project_path)
|
||||
for secret in secrets:
|
||||
name = self._secret_client.parse_secret_path(secret.name).get('secret', '')
|
||||
if not self._case_sensitive:
|
||||
name = name.lower()
|
||||
rv.append(name)
|
||||
return rv
|
||||
|
||||
def _secret_version_path(self, key: str, version: str = 'latest') -> str:
|
||||
return self._secret_client.secret_version_path(self._project_id, key, version)
|
||||
|
||||
def __getitem__(self, key: str) -> str | None:
|
||||
if not self._case_sensitive:
|
||||
key = key.lower()
|
||||
if key not in self._loaded_secrets:
|
||||
# If we know the key isn't available in secret manager, raise a key error
|
||||
if key not in self._secret_names:
|
||||
raise KeyError(key)
|
||||
|
||||
try:
|
||||
self._loaded_secrets[key] = self._secret_client.access_secret_version(
|
||||
name=self._secret_version_path(key)
|
||||
).payload.data.decode('UTF-8')
|
||||
except Exception:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._loaded_secrets[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._secret_names)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._secret_names)
|
||||
|
||||
|
||||
class GoogleSecretManagerSettingsSource(EnvSettingsSource):
|
||||
_credentials: Credentials
|
||||
_secret_client: SecretManagerServiceClient
|
||||
_project_id: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
credentials: Credentials | None = None,
|
||||
project_id: str | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
secret_client: SecretManagerServiceClient | None = None,
|
||||
case_sensitive: bool | None = True,
|
||||
) -> None:
|
||||
# Import Google Packages if they haven't already been imported
|
||||
if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None:
|
||||
import_gcp_secret_manager()
|
||||
|
||||
# If credentials or project_id are not passed, then
|
||||
# try to get them from the default function
|
||||
if not credentials or not project_id:
|
||||
_creds, _project_id = google_auth_default() # type: ignore[no-untyped-call]
|
||||
|
||||
# Set the credentials and/or project id if they weren't specified
|
||||
if credentials is None:
|
||||
credentials = _creds
|
||||
|
||||
if project_id is None:
|
||||
if isinstance(_project_id, str):
|
||||
project_id = _project_id
|
||||
else:
|
||||
raise AttributeError(
|
||||
'project_id is required to be specified either as an argument or from the google.auth.default. See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default'
|
||||
)
|
||||
|
||||
self._credentials: Credentials = credentials
|
||||
self._project_id: str = project_id
|
||||
|
||||
if secret_client:
|
||||
self._secret_client = secret_client
|
||||
else:
|
||||
self._secret_client = SecretManagerServiceClient(credentials=self._credentials)
|
||||
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
|
||||
return GoogleSecretManagerMapping(
|
||||
self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping']
|
||||
@@ -0,0 +1,47 @@
|
||||
"""JSON file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a JSON file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
json_file: PathType | None = DEFAULT_PATH,
|
||||
json_file_encoding: str | None = None,
|
||||
):
|
||||
self.json_file_path = json_file if json_file != DEFAULT_PATH else settings_cls.model_config.get('json_file')
|
||||
self.json_file_encoding = (
|
||||
json_file_encoding
|
||||
if json_file_encoding is not None
|
||||
else settings_cls.model_config.get('json_file_encoding')
|
||||
)
|
||||
self.json_data = self._read_files(self.json_file_path)
|
||||
super().__init__(settings_cls, self.json_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
with open(file_path, encoding=self.json_file_encoding) as json_file:
|
||||
return json.load(json_file)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(json_file={self.json_file_path})'
|
||||
|
||||
|
||||
__all__ = ['JsonConfigSettingsSource']
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Pyproject TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .toml import TomlConfigSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource):
|
||||
"""
|
||||
A source class that loads variables from a `pyproject.toml` file.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
toml_file: Path | None = None,
|
||||
) -> None:
|
||||
self.toml_file_path = self._pick_pyproject_toml_file(
|
||||
toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0)
|
||||
)
|
||||
self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
|
||||
'pyproject_toml_table_header', ('tool', 'pydantic-settings')
|
||||
)
|
||||
self.toml_data = self._read_files(self.toml_file_path)
|
||||
for key in self.toml_table_header:
|
||||
self.toml_data = self.toml_data.get(key, {})
|
||||
super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data)
|
||||
|
||||
@staticmethod
|
||||
def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path:
|
||||
"""Pick a `pyproject.toml` file path to use.
|
||||
|
||||
Args:
|
||||
provided: Explicit path provided when instantiating this class.
|
||||
depth: Number of directories up the tree to check of a pyproject.toml.
|
||||
|
||||
"""
|
||||
if provided:
|
||||
return provided.resolve()
|
||||
rv = Path.cwd() / 'pyproject.toml'
|
||||
count = 0
|
||||
if not rv.is_file():
|
||||
child = rv.parent.parent / 'pyproject.toml'
|
||||
while count < depth:
|
||||
if child.is_file():
|
||||
return child
|
||||
if str(child.parent) == rv.root:
|
||||
break # end discovery after checking system root once
|
||||
child = child.parent.parent / 'pyproject.toml'
|
||||
count += 1
|
||||
return rv
|
||||
|
||||
|
||||
__all__ = ['PyprojectTomlConfigSettingsSource']
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Secrets file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from pydantic_settings.utils import path_type_label
|
||||
|
||||
from ...exceptions import SettingsError
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from secret files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
secrets_dir: PathType | None = None,
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
|
||||
)
|
||||
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
"""
|
||||
Build fields from "secrets" files.
|
||||
"""
|
||||
secrets: dict[str, str | None] = {}
|
||||
|
||||
if self.secrets_dir is None:
|
||||
return secrets
|
||||
|
||||
secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir
|
||||
secrets_paths = [Path(p).expanduser() for p in secrets_dirs]
|
||||
self.secrets_paths = []
|
||||
|
||||
for path in secrets_paths:
|
||||
if not path.exists():
|
||||
warnings.warn(f'directory "{path}" does not exist')
|
||||
else:
|
||||
self.secrets_paths.append(path)
|
||||
|
||||
if not len(self.secrets_paths):
|
||||
return secrets
|
||||
|
||||
for path in self.secrets_paths:
|
||||
if not path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')
|
||||
|
||||
return super().__call__()
|
||||
|
||||
@classmethod
|
||||
def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
|
||||
"""
|
||||
Find a file within path's directory matching filename, optionally ignoring case.
|
||||
|
||||
Args:
|
||||
dir_path: Directory path.
|
||||
file_name: File name.
|
||||
case_sensitive: Whether to search for file name case sensitively.
|
||||
|
||||
Returns:
|
||||
Whether file path or `None` if file does not exist in directory.
|
||||
"""
|
||||
for f in dir_path.iterdir():
|
||||
if f.name == file_name:
|
||||
return f
|
||||
elif not case_sensitive and f.name.lower() == file_name.lower():
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value for field from secret file and a flag to determine whether value is complex.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value (`None` if the file does not exist), key, and
|
||||
a flag to determine whether value is complex.
|
||||
"""
|
||||
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
# paths reversed to match the last-wins behaviour of `env_file`
|
||||
for secrets_path in reversed(self.secrets_paths):
|
||||
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
|
||||
if not path:
|
||||
# path does not exist, we currently don't return a warning for this
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
return path.read_text().strip(), field_key, value_is_complex
|
||||
else:
|
||||
warnings.warn(
|
||||
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
return None, field_key, value_is_complex
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
|
||||
@@ -0,0 +1,66 @@
|
||||
"""TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
tomllib = None
|
||||
import tomli
|
||||
else:
|
||||
tomllib = None
|
||||
tomli = None
|
||||
|
||||
|
||||
def import_toml() -> None:
|
||||
global tomli
|
||||
global tomllib
|
||||
if sys.version_info < (3, 11):
|
||||
if tomli is not None:
|
||||
return
|
||||
try:
|
||||
import tomli
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from e
|
||||
else:
|
||||
if tomllib is not None:
|
||||
return
|
||||
import tomllib
|
||||
|
||||
|
||||
class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a TOML file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
toml_file: PathType | None = DEFAULT_PATH,
|
||||
):
|
||||
self.toml_file_path = toml_file if toml_file != DEFAULT_PATH else settings_cls.model_config.get('toml_file')
|
||||
self.toml_data = self._read_files(self.toml_file_path)
|
||||
super().__init__(settings_cls, self.toml_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
import_toml()
|
||||
with open(file_path, mode='rb') as toml_file:
|
||||
if sys.version_info < (3, 11):
|
||||
return tomli.load(toml_file)
|
||||
return tomllib.load(toml_file)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(toml_file={self.toml_file_path})'
|
||||
@@ -0,0 +1,75 @@
|
||||
"""YAML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import yaml
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
yaml = None
|
||||
|
||||
|
||||
def import_yaml() -> None:
|
||||
global yaml
|
||||
if yaml is not None:
|
||||
return
|
||||
try:
|
||||
import yaml
|
||||
except ImportError as e:
|
||||
raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e
|
||||
|
||||
|
||||
class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a yaml file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
yaml_file: PathType | None = DEFAULT_PATH,
|
||||
yaml_file_encoding: str | None = None,
|
||||
yaml_config_section: str | None = None,
|
||||
):
|
||||
self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file')
|
||||
self.yaml_file_encoding = (
|
||||
yaml_file_encoding
|
||||
if yaml_file_encoding is not None
|
||||
else settings_cls.model_config.get('yaml_file_encoding')
|
||||
)
|
||||
self.yaml_config_section = (
|
||||
yaml_config_section
|
||||
if yaml_config_section is not None
|
||||
else settings_cls.model_config.get('yaml_config_section')
|
||||
)
|
||||
self.yaml_data = self._read_files(self.yaml_file_path)
|
||||
|
||||
if self.yaml_config_section:
|
||||
try:
|
||||
self.yaml_data = self.yaml_data[self.yaml_config_section]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
|
||||
)
|
||||
super().__init__(settings_cls, self.yaml_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
import_yaml()
|
||||
with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
|
||||
return yaml.safe_load(yaml_file) or {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
|
||||
|
||||
|
||||
__all__ = ['YamlConfigSettingsSource']
|
||||
Reference in New Issue
Block a user