"""Azure Key Vault settings source.""" from __future__ import annotations as _annotations from collections.abc import Iterator, Mapping from typing import TYPE_CHECKING from pydantic.alias_generators import to_snake from pydantic.fields import FieldInfo from .env import EnvSettingsSource if TYPE_CHECKING: from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.keyvault.secrets import SecretClient from pydantic_settings.main import BaseSettings else: TokenCredential = None ResourceNotFoundError = None SecretClient = None def import_azure_key_vault() -> None: global TokenCredential global SecretClient global ResourceNotFoundError try: from azure.core.credentials import TokenCredential from azure.core.exceptions import ResourceNotFoundError from azure.keyvault.secrets import SecretClient except ImportError as e: # pragma: no cover raise ImportError( 'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`' ) from e class AzureKeyVaultMapping(Mapping[str, str | None]): _loaded_secrets: dict[str, str | None] _secret_client: SecretClient _secret_names: list[str] def __init__( self, secret_client: SecretClient, case_sensitive: bool, snake_case_conversion: bool, ) -> None: self._loaded_secrets = {} self._secret_client = secret_client self._case_sensitive = case_sensitive self._snake_case_conversion = snake_case_conversion self._secret_map: dict[str, str] = self._load_remote() def _load_remote(self) -> dict[str, str]: secret_names: Iterator[str] = ( secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled ) if self._snake_case_conversion: return {to_snake(name): name for name in secret_names} if self._case_sensitive: return {name: name for name in secret_names} return {name.lower(): name for name in secret_names} def __getitem__(self, key: str) -> str | None: new_key = key if self._snake_case_conversion: new_key = to_snake(key) elif not self._case_sensitive: new_key = key.lower() if new_key not in self._loaded_secrets: if new_key in self._secret_map: self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value else: raise KeyError(key) return self._loaded_secrets[new_key] def __len__(self) -> int: return len(self._secret_map) def __iter__(self) -> Iterator[str]: return iter(self._secret_map.keys()) class AzureKeyVaultSettingsSource(EnvSettingsSource): _url: str _credential: TokenCredential def __init__( self, settings_cls: type[BaseSettings], url: str, credential: TokenCredential, dash_to_underscore: bool = False, case_sensitive: bool | None = None, snake_case_conversion: bool = False, env_prefix: str | None = None, env_parse_none_str: str | None = None, env_parse_enums: bool | None = None, ) -> None: import_azure_key_vault() self._url = url self._credential = credential self._dash_to_underscore = dash_to_underscore self._snake_case_conversion = snake_case_conversion super().__init__( settings_cls, case_sensitive=False if snake_case_conversion else case_sensitive, env_prefix=env_prefix, env_nested_delimiter='__' if snake_case_conversion else '--', env_ignore_empty=False, env_parse_none_str=env_parse_none_str, env_parse_enums=env_parse_enums, ) def _load_env_vars(self) -> Mapping[str, str | None]: secret_client = SecretClient(vault_url=self._url, credential=self._credential) return AzureKeyVaultMapping( secret_client=secret_client, case_sensitive=self.case_sensitive, snake_case_conversion=self._snake_case_conversion, ) def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]: if self._snake_case_conversion: return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name)) if self._dash_to_underscore: return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name)) return super()._extract_field_info(field, field_name) def __repr__(self) -> str: return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})' __all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']