ai-station/.venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/azure.py

146 lines
4.9 KiB
Python
Raw Normal View History

2025-12-25 14:54:33 +00:00
"""Azure Key Vault settings source."""
from __future__ import annotations as _annotations
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING
from pydantic.alias_generators import to_snake
from pydantic.fields import FieldInfo
from .env import EnvSettingsSource
if TYPE_CHECKING:
    # Only static type checkers see the real Azure / pydantic-settings types.
    # Annotations in this module are lazy (``from __future__ import annotations``),
    # so none of these imports is needed at runtime.
    from azure.core.credentials import TokenCredential
    from azure.core.exceptions import ResourceNotFoundError
    from azure.keyvault.secrets import SecretClient

    from pydantic_settings.main import BaseSettings
else:
    # Runtime placeholders; import_azure_key_vault() rebinds these module
    # globals to the real Azure classes when it is called.
    TokenCredential = ResourceNotFoundError = SecretClient = None
def import_azure_key_vault() -> None:
    """Bind the Azure SDK classes used by this module as module globals.

    Replaces the ``None`` placeholders for ``TokenCredential``,
    ``ResourceNotFoundError`` and ``SecretClient`` with the real classes.

    Raises:
        ImportError: if any of the Azure packages cannot be imported; the
            original import error is chained as ``__cause__``.
    """
    global TokenCredential, SecretClient, ResourceNotFoundError

    try:
        # ``global`` above makes these import statements rebind the module-level names.
        from azure.core.credentials import TokenCredential
        from azure.core.exceptions import ResourceNotFoundError
        from azure.keyvault.secrets import SecretClient
    except ImportError as exc:  # pragma: no cover
        message = (
            'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
        )
        raise ImportError(message) from exc
class AzureKeyVaultMapping(Mapping[str, str | None]):
_loaded_secrets: dict[str, str | None]
_secret_client: SecretClient
_secret_names: list[str]
def __init__(
self,
secret_client: SecretClient,
case_sensitive: bool,
snake_case_conversion: bool,
) -> None:
self._loaded_secrets = {}
self._secret_client = secret_client
self._case_sensitive = case_sensitive
self._snake_case_conversion = snake_case_conversion
self._secret_map: dict[str, str] = self._load_remote()
def _load_remote(self) -> dict[str, str]:
secret_names: Iterator[str] = (
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
)
if self._snake_case_conversion:
return {to_snake(name): name for name in secret_names}
if self._case_sensitive:
return {name: name for name in secret_names}
return {name.lower(): name for name in secret_names}
def __getitem__(self, key: str) -> str | None:
new_key = key
if self._snake_case_conversion:
new_key = to_snake(key)
elif not self._case_sensitive:
new_key = key.lower()
if new_key not in self._loaded_secrets:
if new_key in self._secret_map:
self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
else:
raise KeyError(key)
return self._loaded_secrets[new_key]
def __len__(self) -> int:
return len(self._secret_map)
def __iter__(self) -> Iterator[str]:
return iter(self._secret_map.keys())
class AzureKeyVaultSettingsSource(EnvSettingsSource):
    """Settings source that reads field values from an Azure Key Vault.

    Extends ``EnvSettingsSource``. The override of ``_load_env_vars`` makes the
    vault's secrets stand in for environment variables.
    """

    _url: str
    _credential: TokenCredential

    def __init__(
        self,
        settings_cls: type[BaseSettings],
        url: str,
        credential: TokenCredential,
        dash_to_underscore: bool = False,
        case_sensitive: bool | None = None,
        snake_case_conversion: bool = False,
        env_prefix: str | None = None,
        env_parse_none_str: str | None = None,
        env_parse_enums: bool | None = None,
    ) -> None:
        """Configure the vault connection and the underlying env-style source.

        Args:
            settings_cls: the settings class being populated.
            url: vault URL handed to ``SecretClient(vault_url=...)``.
            credential: credential handed to ``SecretClient(credential=...)``.
            dash_to_underscore: despite the name, when set, ``_`` is replaced
                by ``-`` in each looked-up secret name.
            case_sensitive: forwarded to the base class; forced to ``False``
                when ``snake_case_conversion`` is set.
            snake_case_conversion: when set, secret names and lookup keys are
                snake-cased and nested values use the ``'__'`` delimiter.
            env_prefix, env_parse_none_str, env_parse_enums: forwarded unchanged
                to ``EnvSettingsSource``.

        Raises:
            ImportError: if the Azure SDK packages are not installed.
        """
        import_azure_key_vault()

        # Assigned before super().__init__(), presumably because the base class
        # may call _load_env_vars(), which reads these attributes.
        self._url = url
        self._credential = credential
        self._dash_to_underscore = dash_to_underscore
        self._snake_case_conversion = snake_case_conversion

        if snake_case_conversion:
            effective_case_sensitive = False
            nested_delimiter = '__'
        else:
            effective_case_sensitive = case_sensitive
            # NOTE(review): presumably '--' because Key Vault secret names allow
            # dashes but not underscores — confirm.
            nested_delimiter = '--'

        super().__init__(
            settings_cls,
            case_sensitive=effective_case_sensitive,
            env_prefix=env_prefix,
            env_nested_delimiter=nested_delimiter,
            env_ignore_empty=False,
            env_parse_none_str=env_parse_none_str,
            env_parse_enums=env_parse_enums,
        )

    def _load_env_vars(self) -> Mapping[str, str | None]:
        """Return a mapping over the vault's secrets (values fetched on demand)."""
        return AzureKeyVaultMapping(
            secret_client=SecretClient(vault_url=self._url, credential=self._credential),
            case_sensitive=self.case_sensitive,
            snake_case_conversion=self._snake_case_conversion,
        )

    def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
        """Adjust the base class's lookup entries to Key Vault naming.

        Each entry is presumably ``(field_key, env_name, value_is_complex)`` as
        produced by ``EnvSettingsSource``.
        - With ``snake_case_conversion``, the second element is replaced by the
          first.
        - With ``dash_to_underscore``, ``_`` becomes ``-`` in the second element.
        - Otherwise the base result is returned unchanged.
        """
        inherited = super()._extract_field_info(field, field_name)
        if self._snake_case_conversion:
            return [(entry[0], entry[0], entry[2]) for entry in inherited]
        if self._dash_to_underscore:
            return [(entry[0], entry[1].replace('_', '-'), entry[2]) for entry in inherited]
        return inherited

    def __repr__(self) -> str:
        """Show the vault URL and the nested-key delimiter in use."""
        cls_name = self.__class__.__name__
        return f'{cls_name}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
# Public names of this module; this is also what ``from <module> import *`` exposes.
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']