API refactor
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
2025-10-07 16:25:52 +09:00
parent 76d0d86211
commit 91c7e04474
1171 changed files with 81940 additions and 44117 deletions

View File

@@ -0,0 +1,41 @@
"""Package containing individual source implementations."""
from .aws import AWSSecretsManagerSettingsSource
from .azure import AzureKeyVaultSettingsSource
from .cli import (
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
CliSuppress,
)
from .dotenv import DotEnvSettingsSource
from .env import EnvSettingsSource
from .gcp import GoogleSecretManagerSettingsSource
from .json import JsonConfigSettingsSource
from .pyproject import PyprojectTomlConfigSettingsSource
from .secrets import SecretsSettingsSource
from .toml import TomlConfigSettingsSource
from .yaml import YamlConfigSettingsSource
__all__ = [
'AWSSecretsManagerSettingsSource',
'AzureKeyVaultSettingsSource',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'CliPositionalArg',
'CliSettingsSource',
'CliSubCommand',
'CliSuppress',
'DotEnvSettingsSource',
'EnvSettingsSource',
'GoogleSecretManagerSettingsSource',
'JsonConfigSettingsSource',
'PyprojectTomlConfigSettingsSource',
'SecretsSettingsSource',
'TomlConfigSettingsSource',
'YamlConfigSettingsSource',
]

View File

@@ -0,0 +1,79 @@
from __future__ import annotations as _annotations # important for BaseSettings import to work
import json
from collections.abc import Mapping
from typing import TYPE_CHECKING, Optional
from ..utils import parse_env_vars
from .env import EnvSettingsSource
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
boto3_client = None
SecretsManagerClient = None
def import_aws_secrets_manager() -> None:
    """Lazily import the AWS dependencies into module globals.

    Raises:
        ImportError: With an install hint when `boto3` (or the type stubs)
            are not available.
    """
    global boto3_client, SecretsManagerClient
    try:
        from boto3 import client as boto3_client
        from mypy_boto3_secretsmanager.client import SecretsManagerClient
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
        ) from e
class AWSSecretsManagerSettingsSource(EnvSettingsSource):
    """Settings source that loads values from an AWS Secrets Manager secret.

    The secret payload is expected to be a JSON object whose keys/values are
    treated like environment variables (nested keys separated by
    ``env_nested_delimiter``, ``--`` by default).
    """

    _secret_id: str
    _secretsmanager_client: SecretsManagerClient  # type: ignore

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        secret_id: str,
        region_name: str | None = None,
        endpoint_url: str | None = None,
        case_sensitive: bool | None = True,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = '--',
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        # Fail fast with a helpful install hint if boto3 is missing.
        import_aws_secrets_manager()
        self._secret_id = secret_id
        self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url)  # type: ignore
        super().__init__(
            settings_cls,
            case_sensitive=case_sensitive,
            env_prefix=env_prefix,
            env_nested_delimiter=env_nested_delimiter,
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, Optional[str]]:
        """Fetch the secret and expose its JSON payload as env-style pairs."""
        response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id)  # type: ignore
        payload = json.loads(response['SecretString'])
        return parse_env_vars(payload, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(secret_id={self._secret_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
__all__ = [
'AWSSecretsManagerSettingsSource',
]

View File

@@ -0,0 +1,145 @@
"""Azure Key Vault settings source."""
from __future__ import annotations as _annotations
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Optional
from pydantic.alias_generators import to_snake
from pydantic.fields import FieldInfo
from .env import EnvSettingsSource
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.keyvault.secrets import SecretClient
from pydantic_settings.main import BaseSettings
else:
TokenCredential = None
ResourceNotFoundError = None
SecretClient = None
def import_azure_key_vault() -> None:
    """Lazily import the Azure SDK pieces into module globals.

    Raises:
        ImportError: With an install hint when the Azure packages are missing.
    """
    global TokenCredential, SecretClient, ResourceNotFoundError
    try:
        from azure.core.credentials import TokenCredential
        from azure.core.exceptions import ResourceNotFoundError
        from azure.keyvault.secrets import SecretClient
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
        ) from e
class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
    """Lazy, read-only mapping over the secrets of an Azure Key Vault.

    Secret *names* are listed eagerly at construction time; secret *values*
    are fetched from the vault only on first access and cached in
    ``_loaded_secrets``.

    Key normalization:
      * ``snake_case_conversion`` — keys are compared after ``to_snake``;
      * otherwise ``case_sensitive`` decides whether keys are lower-cased.
    """

    # NOTE: the stale `_secret_names: list[str]` annotation was removed — no
    # code ever assigned it; the real store is `_secret_map` (normalized key
    # -> actual vault secret name), built once in `_load_remote`.
    _loaded_secrets: dict[str, str | None]
    _secret_client: SecretClient

    def __init__(
        self,
        secret_client: SecretClient,
        case_sensitive: bool,
        snake_case_conversion: bool,
    ) -> None:
        self._loaded_secrets = {}
        self._secret_client = secret_client
        self._case_sensitive = case_sensitive
        self._snake_case_conversion = snake_case_conversion
        self._secret_map: dict[str, str] = self._load_remote()

    def _load_remote(self) -> dict[str, str]:
        """List enabled secrets and build the normalized-key -> name map."""
        secret_names: Iterator[str] = (
            secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
        )
        if self._snake_case_conversion:
            return {to_snake(name): name for name in secret_names}
        if self._case_sensitive:
            return {name: name for name in secret_names}
        return {name.lower(): name for name in secret_names}

    def _normalize_key(self, key: str) -> str:
        """Apply the same normalization to a lookup key as `_load_remote` does to names."""
        if self._snake_case_conversion:
            return to_snake(key)
        if not self._case_sensitive:
            return key.lower()
        return key

    def __getitem__(self, key: str) -> str | None:
        new_key = self._normalize_key(key)
        if new_key not in self._loaded_secrets:
            if new_key not in self._secret_map:
                # Report the key exactly as the caller spelled it.
                raise KeyError(key)
            # First access: fetch from the vault and cache the value.
            self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
        return self._loaded_secrets[new_key]

    def __len__(self) -> int:
        return len(self._secret_map)

    def __iter__(self) -> Iterator[str]:
        return iter(self._secret_map.keys())
class AzureKeyVaultSettingsSource(EnvSettingsSource):
    """Settings source that loads values from an Azure Key Vault.

    Secrets are exposed through `AzureKeyVaultMapping`, which lists names
    eagerly and fetches values lazily.
    """

    _url: str
    _credential: TokenCredential

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        url: str,
        credential: TokenCredential,
        dash_to_underscore: bool = False,
        case_sensitive: bool | None = None,
        snake_case_conversion: bool = False,
        env_prefix: str | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        # Fail fast with a helpful install hint if the Azure SDK is missing.
        import_azure_key_vault()
        self._url = url
        self._credential = credential
        self._dash_to_underscore = dash_to_underscore
        self._snake_case_conversion = snake_case_conversion
        # snake_case_conversion forces case-insensitive matching and a `__`
        # nested delimiter; otherwise the delimiter defaults to `--`.
        super().__init__(
            settings_cls,
            case_sensitive=False if snake_case_conversion else case_sensitive,
            env_prefix=env_prefix,
            env_nested_delimiter='__' if snake_case_conversion else '--',
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, Optional[str]]:
        secret_client = SecretClient(vault_url=self._url, credential=self._credential)
        return AzureKeyVaultMapping(
            secret_client=secret_client,
            case_sensitive=self.case_sensitive,
            snake_case_conversion=self._snake_case_conversion,
        )

    def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
        """Adjust the (field_key, env_name, is_complex) triples for vault naming."""
        if self._snake_case_conversion:
            # Look fields up by their (snake_case) field key rather than the env name.
            return [(x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name)]
        if self._dash_to_underscore:
            # Look up dash variants of the env names (vault secret names use dashes).
            return [(x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name)]
        return super()._extract_field_info(field, field_name)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']

File diff suppressed because it is too large. (Load Diff)

View File

@@ -0,0 +1,168 @@
"""Dotenv file settings source."""
from __future__ import annotations as _annotations
import os
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any
from dotenv import dotenv_values
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
get_origin,
)
from typing_inspection.introspection import is_union_origin
from ..types import ENV_FILE_SENTINEL, DotenvType
from ..utils import (
_annotation_is_complex,
_union_is_complex,
parse_env_vars,
)
from .env import EnvSettingsSource
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class DotEnvSettingsSource(EnvSettingsSource):
    """
    Source class for loading settings values from env files.

    Behaves like `EnvSettingsSource`, except the "environment" is the merged
    content of one or more dotenv files instead of `os.environ`.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        env_file: DotenvType | None = ENV_FILE_SENTINEL,
        env_file_encoding: str | None = None,
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
        env_nested_max_split: int | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        # ENV_FILE_SENTINEL distinguishes "argument not passed" (fall back to
        # model_config) from an explicit `env_file=None`.
        self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
        self.env_file_encoding = (
            env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
        )
        super().__init__(
            settings_cls,
            case_sensitive,
            env_prefix,
            env_nested_delimiter,
            env_nested_max_split,
            env_ignore_empty,
            env_parse_none_str,
            env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, str | None]:
        # Unlike the parent class, "env vars" come from the dotenv file(s).
        return self._read_env_files()

    @staticmethod
    def _static_read_env_file(
        file_path: Path,
        *,
        encoding: str | None = None,
        case_sensitive: bool = False,
        ignore_empty: bool = False,
        parse_none_str: str | None = None,
    ) -> Mapping[str, str | None]:
        """Read one dotenv file and normalize its variables."""
        file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
        return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str)

    def _read_env_file(
        self,
        file_path: Path,
    ) -> Mapping[str, str | None]:
        """Read one dotenv file using this source's configured options."""
        return self._static_read_env_file(
            file_path,
            encoding=self.env_file_encoding,
            case_sensitive=self.case_sensitive,
            ignore_empty=self.env_ignore_empty,
            parse_none_str=self.env_parse_none_str,
        )

    def _read_env_files(self) -> Mapping[str, str | None]:
        """Merge all configured dotenv files; later files win on duplicate keys."""
        env_files = self.env_file
        if env_files is None:
            return {}
        if isinstance(env_files, (str, os.PathLike)):
            env_files = [env_files]
        dotenv_vars: dict[str, str | None] = {}
        for env_file in env_files:
            env_path = Path(env_file).expanduser()
            # Missing files are silently skipped.
            if env_path.is_file():
                dotenv_vars.update(self._read_env_file(env_path))
        return dotenv_vars

    def __call__(self) -> dict[str, Any]:
        data: dict[str, Any] = super().__call__()
        is_extra_allowed = self.config.get('extra') != 'forbid'
        # As `extra` config is allowed in dotenv settings source, we have to
        # update data with extra env variables from the dotenv file.
        for env_name, env_value in self.env_vars.items():
            # Skip empty values, values already collected, and (when a prefix
            # is set) names that collide with un-prefixed model fields.
            if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields):
                continue
            env_used = False
            # A var counts as "used" when it matches a field's env name
            # exactly, or is a nested key (prefix match) of a complex field.
            for field_name, field in self.settings_cls.model_fields.items():
                for _, field_env_name, _ in self._extract_field_info(field, field_name):
                    if env_name == field_env_name or (
                        (
                            _annotation_is_complex(field.annotation, field.metadata)
                            or (
                                is_union_origin(get_origin(field.annotation))
                                and _union_is_complex(field.annotation, field.metadata)
                            )
                        )
                        and env_name.startswith(field_env_name)
                    ):
                        env_used = True
                        break
                if env_used:
                    break
            if not env_used:
                # Unmatched vars are surfaced so `extra` validation can see
                # them (and reject them under `extra='forbid'`).
                if is_extra_allowed and env_name.startswith(self.env_prefix):
                    # env_prefix should be respected and removed from the env_name
                    normalized_env_name = env_name[len(self.env_prefix) :]
                    data[normalized_env_name] = env_value
                else:
                    data[env_name] = env_value
        return data

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
            f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
        )
def read_env_file(
    file_path: Path,
    *,
    encoding: str | None = None,
    case_sensitive: bool = False,
    ignore_empty: bool = False,
    parse_none_str: str | None = None,
) -> Mapping[str, str | None]:
    """Deprecated: read a single dotenv file into a mapping.

    Use `DotEnvSettingsSource._static_read_env_file` instead.
    """
    warnings.warn(
        'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must',
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not this shim
    )
    return DotEnvSettingsSource._static_read_env_file(
        file_path,
        encoding=encoding,
        case_sensitive=case_sensitive,
        ignore_empty=ignore_empty,
        parse_none_str=parse_none_str,
    )
__all__ = ['DotEnvSettingsSource', 'read_env_file']

View File

@@ -0,0 +1,270 @@
from __future__ import annotations as _annotations
import os
from collections.abc import Mapping
from typing import (
TYPE_CHECKING,
Any,
)
from pydantic._internal._utils import deep_update, is_model_class
from pydantic.dataclasses import is_pydantic_dataclass
from pydantic.fields import FieldInfo
from typing_extensions import get_args, get_origin
from typing_inspection.introspection import is_union_origin
from ...utils import _lenient_issubclass
from ..base import PydanticBaseEnvSettingsSource
from ..types import EnvNoneType
from ..utils import (
_annotation_enum_name_to_val,
_get_model_fields,
_union_is_complex,
parse_env_vars,
)
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from environment variables.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_nested_delimiter: str | None = None,
        env_nested_max_split: int | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(
            settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
        )
        # Explicit arguments win over the corresponding model_config entries.
        self.env_nested_delimiter = (
            env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
        )
        self.env_nested_max_split = (
            env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
        )
        # `str.split` maxsplit argument: -1 (unlimited) when env_nested_max_split is unset or 0.
        self.maxsplit = (self.env_nested_max_split or 0) - 1
        self.env_prefix_len = len(self.env_prefix)
        # Snapshot of the (normalized) environment, taken once at construction.
        self.env_vars = self._load_env_vars()

    def _load_env_vars(self) -> Mapping[str, str | None]:
        # Normalizes case, drops empties, and maps the configured none-string.
        return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from environment variables and a flag to determine whether value is complex.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value (`None` if not found), key, and
            a flag to determine whether value is complex.
        """
        env_val: str | None = None
        # Try each candidate env name (aliases, prefix variants) in order;
        # the first one present in the environment wins.
        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            env_val = self.env_vars.get(env_name)
            if env_val is not None:
                break
        # field_key/value_is_complex come from the last candidate examined
        # (the matching one when a value was found).
        return env_val, field_key, value_is_complex

    def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
        """
        Prepare value for the field.

        * Extract value for nested field.
        * Deserialize value to python object for complex field.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple contains prepared value for the field.

        Raises:
            ValuesError: When There is an error in deserializing value for complex field.
        """
        is_complex, allow_parse_failure = self._field_is_complex(field)
        if self.env_parse_enums:
            # Map an enum *name* string onto its value when requested.
            enum_val = _annotation_enum_name_to_val(field.annotation, value)
            value = value if enum_val is None else enum_val
        if is_complex or value_is_complex:
            if isinstance(value, EnvNoneType):
                # Explicit "none" marker passes through untouched.
                return value
            elif value is None:
                # field is complex but no value found so far, try explode_env_vars
                env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
                if env_val_built:
                    return env_val_built
            else:
                # field is complex and there's a value, decode that as JSON, then add explode_env_vars
                try:
                    value = self.decode_complex_value(field_name, field, value)
                except ValueError as e:
                    if not allow_parse_failure:
                        raise e
                if isinstance(value, dict):
                    # Nested env vars override keys from the decoded JSON value.
                    return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
                else:
                    return value
        elif value is not None:
            # simplest case, field is not complex, we only need to add the value if it was found
            return value

    def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
        """
        Find out if a field is complex, and if so whether JSON errors should be ignored
        """
        if self.field_is_complex(field):
            allow_parse_failure = False
        elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
            # e.g. `list[int] | str`: try JSON first, fall back to the raw string.
            allow_parse_failure = True
        else:
            return False, False
        return True, allow_parse_failure

    # Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
    # We have to change the method to a non-static method and use
    # `self.case_sensitive` instead in V3.
    def next_field(
        self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
    ) -> FieldInfo | None:
        """
        Find the field in a sub model by key(env name)

        By having the following models:

        ```py
        class SubSubModel(BaseSettings):
            dvals: Dict

        class SubModel(BaseSettings):
            vals: list[str]
            sub_sub_model: SubSubModel

        class Cfg(BaseSettings):
            sub_model: SubModel
        ```

        Then:
            next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
            next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class

        Args:
            field: The field.
            key: The key (env name).
            case_sensitive: Whether to search for key case sensitively.

        Returns:
            Field if it finds the next field otherwise `None`.
        """
        if not field:
            return None
        annotation = field.annotation if isinstance(field, FieldInfo) else field
        # Recurse into union/generic type arguments first (e.g. `SubModel | None`).
        for type_ in get_args(annotation):
            type_has_key = self.next_field(type_, key, case_sensitive)
            if type_has_key:
                return type_has_key
        if is_model_class(annotation) or is_pydantic_dataclass(annotation):  # type: ignore[arg-type]
            fields = _get_model_fields(annotation)
            # `case_sensitive is None` is here to be compatible with the old behavior.
            # Has to be removed in V3.
            for field_name, f in fields.items():
                # Match against both the field name and every candidate env name.
                for _, env_name, _ in self._extract_field_info(f, field_name):
                    if case_sensitive is None or case_sensitive:
                        if field_name == key or env_name == key:
                            return f
                    elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
                        return f
        return None

    def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]:
        """
        Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.

        This is applied to a single field, hence filtering by env_var prefix.

        Args:
            field_name: The field name.
            field: The field.
            env_vars: Environment variables.

        Returns:
            A dictionary contains extracted values from nested env values.
        """
        if not self.env_nested_delimiter:
            return {}
        ann = field.annotation
        # Plain `dict` fields accept arbitrary nested keys even when no
        # matching sub-model field exists.
        is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
        prefixes = [
            f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
        ]
        result: dict[str, Any] = {}
        for env_name, env_val in env_vars.items():
            try:
                prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
            except StopIteration:
                # This env var does not belong to the field; skip it.
                continue
            # we remove the prefix before splitting in case the prefix has characters in common with the delimiter
            env_name_without_prefix = env_name[len(prefix) :]
            *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
            env_var = result
            target_field: FieldInfo | None = field
            # Walk down the nested dict, tracking the matching model field as we go.
            for key in keys:
                target_field = self.next_field(target_field, key, self.case_sensitive)
                if isinstance(env_var, dict):
                    env_var = env_var.setdefault(key, {})
            # get proper field with last_key
            target_field = self.next_field(target_field, last_key, self.case_sensitive)
            # check if env_val maps to a complex field and if so, parse the env_val
            if (target_field or is_dict) and env_val:
                if target_field:
                    is_complex, allow_json_failure = self._field_is_complex(target_field)
                    if self.env_parse_enums:
                        enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
                        env_val = env_val if enum_val is None else enum_val
                else:
                    # nested field type is dict
                    is_complex, allow_json_failure = True, True
                if is_complex:
                    try:
                        env_val = self.decode_complex_value(last_key, target_field, env_val)  # type: ignore
                    except ValueError as e:
                        if not allow_json_failure:
                            raise e
            if isinstance(env_var, dict):
                # Don't let an explicit "none" marker overwrite an already-built
                # non-empty nested dict for this key.
                if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
                    env_var[last_key] = env_val
        return result

    def __repr__(self) -> str:
        return (
            f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
            f'env_prefix_len={self.env_prefix_len!r})'
        )
__all__ = ['EnvSettingsSource']

View File

@@ -0,0 +1,152 @@
from __future__ import annotations as _annotations
from collections.abc import Iterator, Mapping
from functools import cached_property
from typing import TYPE_CHECKING, Optional
from .env import EnvSettingsSource
if TYPE_CHECKING:
from google.auth import default as google_auth_default
from google.auth.credentials import Credentials
from google.cloud.secretmanager import SecretManagerServiceClient
from pydantic_settings.main import BaseSettings
else:
Credentials = None
SecretManagerServiceClient = None
google_auth_default = None
def import_gcp_secret_manager() -> None:
    """Lazily import the Google Cloud dependencies into module globals.

    Raises:
        ImportError: With an install hint when the GCP packages are missing.
    """
    global Credentials, SecretManagerServiceClient, google_auth_default
    try:
        from google.auth import default as google_auth_default
        from google.auth.credentials import Credentials
        from google.cloud.secretmanager import SecretManagerServiceClient
    except ImportError as e:  # pragma: no cover
        raise ImportError(
            'GCP Secret Manager dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`'
        ) from e
class GoogleSecretManagerMapping(Mapping[str, Optional[str]]):
    """Lazy, read-only mapping over the secrets of a GCP project.

    Secret names are listed once and cached (`_secret_names`); secret payloads
    are fetched from Secret Manager on first access and memoized in
    ``_loaded_secrets``.
    """

    _loaded_secrets: dict[str, str | None]
    _secret_client: SecretManagerServiceClient

    def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None:
        self._loaded_secrets = {}
        self._secret_client = secret_client
        self._project_id = project_id
        self._case_sensitive = case_sensitive

    @property
    def _gcp_project_path(self) -> str:
        return self._secret_client.common_project_path(self._project_id)

    @cached_property
    def _secret_names(self) -> list[str]:
        """All secret names in the project, lower-cased unless case-sensitive."""
        names: list[str] = []
        for secret in self._secret_client.list_secrets(parent=self._gcp_project_path):
            name = self._secret_client.parse_secret_path(secret.name).get('secret', '')
            names.append(name if self._case_sensitive else name.lower())
        return names

    def _secret_version_path(self, key: str, version: str = 'latest') -> str:
        return self._secret_client.secret_version_path(self._project_id, key, version)

    def __getitem__(self, key: str) -> str | None:
        if not self._case_sensitive:
            key = key.lower()
        if key in self._loaded_secrets:
            return self._loaded_secrets[key]
        # If we know the key isn't available in secret manager, raise a key error.
        if key not in self._secret_names:
            raise KeyError(key)
        try:
            response = self._secret_client.access_secret_version(name=self._secret_version_path(key))
            self._loaded_secrets[key] = response.payload.data.decode('UTF-8')
        except Exception:
            # Any access/decoding failure surfaces as a missing key.
            raise KeyError(key)
        return self._loaded_secrets[key]

    def __len__(self) -> int:
        return len(self._secret_names)

    def __iter__(self) -> Iterator[str]:
        return iter(self._secret_names)
class GoogleSecretManagerSettingsSource(EnvSettingsSource):
    """Settings source that loads values from Google Cloud Secret Manager.

    Missing credentials/project id are resolved via `google.auth.default()`.
    """

    _credentials: Credentials
    _secret_client: SecretManagerServiceClient
    _project_id: str

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        credentials: Credentials | None = None,
        project_id: str | None = None,
        env_prefix: str | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
        secret_client: SecretManagerServiceClient | None = None,
        case_sensitive: bool | None = True,
    ) -> None:
        # Import Google Packages if they haven't already been imported
        if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None:
            import_gcp_secret_manager()
        # If credentials or project_id are not passed, then
        # try to get them from the default function
        if not credentials or not project_id:
            # NOTE: may perform I/O (reads ambient credentials / metadata).
            _creds, _project_id = google_auth_default()  # type: ignore[no-untyped-call]
            # Set the credentials and/or project id if they weren't specified
            if credentials is None:
                credentials = _creds
            if project_id is None:
                if isinstance(_project_id, str):
                    project_id = _project_id
                else:
                    raise AttributeError(
                        'project_id is required to be specified either as an argument or from the google.auth.default. See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default'
                    )
        self._credentials: Credentials = credentials
        self._project_id: str = project_id
        # An explicitly provided client wins over constructing our own.
        if secret_client:
            self._secret_client = secret_client
        else:
            self._secret_client = SecretManagerServiceClient(credentials=self._credentials)
        super().__init__(
            settings_cls,
            case_sensitive=case_sensitive,
            env_prefix=env_prefix,
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, Optional[str]]:
        # Lazy mapping: values are only fetched from Secret Manager on access.
        return GoogleSecretManagerMapping(
            self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive
        )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping']

View File

@@ -0,0 +1,47 @@
"""JSON file settings source."""
from __future__ import annotations as _annotations
import json
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from ..base import ConfigFileSourceMixin, InitSettingsSource
from ..types import DEFAULT_PATH, PathType
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that loads variables from one or more JSON files."""

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        json_file: PathType | None = DEFAULT_PATH,
        json_file_encoding: str | None = None,
    ):
        config = settings_cls.model_config
        # Explicit arguments win over the corresponding model_config entries.
        if json_file == DEFAULT_PATH:
            self.json_file_path = config.get('json_file')
        else:
            self.json_file_path = json_file
        if json_file_encoding is None:
            json_file_encoding = config.get('json_file_encoding')
        self.json_file_encoding = json_file_encoding
        self.json_data = self._read_files(self.json_file_path)
        super().__init__(settings_cls, self.json_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse a single JSON file into a dict."""
        with open(file_path, encoding=self.json_file_encoding) as json_file:
            return json.load(json_file)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(json_file={self.json_file_path})'
__all__ = ['JsonConfigSettingsSource']

View File

@@ -0,0 +1,62 @@
"""Pyproject TOML file settings source."""
from __future__ import annotations as _annotations
from pathlib import Path
from typing import (
TYPE_CHECKING,
)
from .toml import TomlConfigSettingsSource
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource):
    """
    A source class that loads variables from a `pyproject.toml` file.

    Only the sub-table addressed by `pyproject_toml_table_header`
    (default `[tool.pydantic-settings]`) is used.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        toml_file: Path | None = None,
    ) -> None:
        self.toml_file_path = self._pick_pyproject_toml_file(
            toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0)
        )
        self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
            'pyproject_toml_table_header', ('tool', 'pydantic-settings')
        )
        self.toml_data = self._read_files(self.toml_file_path)
        # Drill down into the configured table; missing levels yield {}.
        for key in self.toml_table_header:
            self.toml_data = self.toml_data.get(key, {})
        # Deliberately skip TomlConfigSettingsSource.__init__ (which would
        # re-read the file) and call InitSettingsSource.__init__ directly.
        super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data)

    @staticmethod
    def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path:
        """Pick a `pyproject.toml` file path to use.

        Args:
            provided: Explicit path provided when instantiating this class.
            depth: Number of directories up the tree to check of a pyproject.toml.
        """
        if provided:
            return provided.resolve()
        rv = Path.cwd() / 'pyproject.toml'
        count = 0
        if not rv.is_file():
            # Walk up to `depth` parent directories looking for the file.
            child = rv.parent.parent / 'pyproject.toml'
            while count < depth:
                if child.is_file():
                    return child
                # NOTE(review): compares the full parent path against the
                # filesystem root string — presumably intended to stop the
                # walk at the root; confirm on non-POSIX path layouts.
                if str(child.parent) == rv.root:
                    break  # end discovery after checking system root once
                child = child.parent.parent / 'pyproject.toml'
                count += 1
        # Falls back to CWD/pyproject.toml even if it does not exist.
        return rv
__all__ = ['PyprojectTomlConfigSettingsSource']

View File

@@ -0,0 +1,125 @@
"""Secrets file settings source."""
from __future__ import annotations as _annotations
import os
import warnings
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from pydantic.fields import FieldInfo
from pydantic_settings.utils import path_type_label
from ...exceptions import SettingsError
from ..base import PydanticBaseEnvSettingsSource
from ..types import PathType
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
    """
    Source class for loading settings values from secret files.

    Each field's value is read from a file named after its env name inside
    one or more `secrets_dir` directories (e.g. Docker/Kubernetes secrets).
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        secrets_dir: PathType | None = None,
        case_sensitive: bool | None = None,
        env_prefix: str | None = None,
        env_ignore_empty: bool | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        super().__init__(
            settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums
        )
        # Explicit argument wins over the model_config entry.
        self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')

    def __call__(self) -> dict[str, Any]:
        """
        Build fields from "secrets" files.
        """
        secrets: dict[str, str | None] = {}

        if self.secrets_dir is None:
            return secrets

        # Accept a single path or a sequence of paths.
        secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir
        secrets_paths = [Path(p).expanduser() for p in secrets_dirs]
        # NOTE: self.secrets_paths is populated here and consumed later by
        # get_field_value (called from super().__call__ below).
        self.secrets_paths = []

        for path in secrets_paths:
            if not path.exists():
                # Missing directories are tolerated with a warning.
                warnings.warn(f'directory "{path}" does not exist')
            else:
                self.secrets_paths.append(path)

        if not len(self.secrets_paths):
            return secrets

        for path in self.secrets_paths:
            # An existing path that is not a directory is a hard error.
            if not path.is_dir():
                raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')

        return super().__call__()

    @classmethod
    def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
        """
        Find a file within path's directory matching filename, optionally ignoring case.

        Args:
            dir_path: Directory path.
            file_name: File name.
            case_sensitive: Whether to search for file name case sensitively.

        Returns:
            Whether file path or `None` if file does not exist in directory.
        """
        for f in dir_path.iterdir():
            if f.name == file_name:
                return f
            elif not case_sensitive and f.name.lower() == file_name.lower():
                return f
        return None

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        """
        Gets the value for field from secret file and a flag to determine whether value is complex.

        Args:
            field: The field.
            field_name: The field name.

        Returns:
            A tuple that contains the value (`None` if the file does not exist), key, and
            a flag to determine whether value is complex.
        """
        for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
            # paths reversed to match the last-wins behaviour of `env_file`
            for secrets_path in reversed(self.secrets_paths):
                path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
                if not path:
                    # path does not exist, we currently don't return a warning for this
                    continue

                if path.is_file():
                    # Trailing whitespace/newlines in secret files are stripped.
                    return path.read_text().strip(), field_key, value_is_complex
                else:
                    warnings.warn(
                        f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
                        stacklevel=4,
                    )

        return None, field_key, value_is_complex

View File

@@ -0,0 +1,66 @@
"""TOML file settings source."""
from __future__ import annotations as _annotations
import sys
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from ..base import ConfigFileSourceMixin, InitSettingsSource
from ..types import DEFAULT_PATH, PathType
if TYPE_CHECKING:
from pydantic_settings.main import BaseSettings
if sys.version_info >= (3, 11):
import tomllib
else:
tomllib = None
import tomli
else:
tomllib = None
tomli = None
def import_toml() -> None:
    """Make a TOML parser available as a module global, importing lazily.

    Python 3.11+ ships `tomllib`; older versions need the third-party
    `tomli` package.
    """
    global tomli
    global tomllib
    if sys.version_info >= (3, 11):
        if tomllib is None:
            import tomllib
        return
    if tomli is not None:
        return
    try:
        import tomli
    except ImportError as e:  # pragma: no cover
        raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from e
class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that loads variables from one or more TOML files."""

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        toml_file: PathType | None = DEFAULT_PATH,
    ):
        # Explicit argument wins over the model_config entry.
        if toml_file == DEFAULT_PATH:
            self.toml_file_path = settings_cls.model_config.get('toml_file')
        else:
            self.toml_file_path = toml_file
        self.toml_data = self._read_files(self.toml_file_path)
        super().__init__(settings_cls, self.toml_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Parse one TOML file, using `tomllib` on 3.11+ and `tomli` otherwise."""
        import_toml()
        with open(file_path, mode='rb') as toml_file:
            if sys.version_info < (3, 11):
                return tomli.load(toml_file)
            return tomllib.load(toml_file)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(toml_file={self.toml_file_path})'

View File

@@ -0,0 +1,75 @@
"""YAML file settings source."""
from __future__ import annotations as _annotations
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from ..base import ConfigFileSourceMixin, InitSettingsSource
from ..types import DEFAULT_PATH, PathType
if TYPE_CHECKING:
import yaml
from pydantic_settings.main import BaseSettings
else:
yaml = None
def import_yaml() -> None:
    """Import PyYAML into the module-global `yaml` name on first use."""
    global yaml
    if yaml is None:
        try:
            import yaml
        except ImportError as e:
            raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e
class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
    """Settings source that loads variables from one or more YAML files.

    When `yaml_config_section` is set, only that top-level section of the
    parsed document is used.
    """

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        yaml_file: PathType | None = DEFAULT_PATH,
        yaml_file_encoding: str | None = None,
        yaml_config_section: str | None = None,
    ):
        config = settings_cls.model_config
        # Explicit arguments win over the corresponding model_config entries.
        self.yaml_file_path = config.get('yaml_file') if yaml_file == DEFAULT_PATH else yaml_file
        self.yaml_file_encoding = (
            config.get('yaml_file_encoding') if yaml_file_encoding is None else yaml_file_encoding
        )
        self.yaml_config_section = (
            config.get('yaml_config_section') if yaml_config_section is None else yaml_config_section
        )
        self.yaml_data = self._read_files(self.yaml_file_path)
        if self.yaml_config_section:
            try:
                self.yaml_data = self.yaml_data[self.yaml_config_section]
            except KeyError:
                raise KeyError(
                    f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
                )
        super().__init__(settings_cls, self.yaml_data)

    def _read_file(self, file_path: Path) -> dict[str, Any]:
        """Load one YAML document; an empty file becomes an empty dict."""
        import_yaml()
        with open(file_path, encoding=self.yaml_file_encoding) as yaml_file:
            return yaml.safe_load(yaml_file) or {}

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
__all__ = ['YamlConfigSettingsSource']