# ai-station/.venv/lib/python3.12/site-packages/litellm/proxy/prompts/prompt_registry.py
import importlib
import os
from pathlib import Path
from typing import Callable, Dict, Optional
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.types.prompts.init_prompts import (
PromptInfo,
PromptLiteLLMParams,
PromptSpec,
)
# Placeholder mapping of prompt integration name -> initializer callable;
# re-assigned below once get_prompt_initializer_from_integrations() is defined.
prompt_initializer_registry = {}
def get_prompt_initializer_from_integrations():
    """
    Discover prompt initializers from the litellm integrations directory.

    Scans every package directory under ``litellm/integrations`` (any
    subdirectory containing an ``__init__.py``), imports it, and collects the
    entries of any module-level ``prompt_initializer_registry`` dict it exposes.

    Returns:
        Dict[str, Callable]: mapping of prompt integration name to its
        initializer function. Empty on any discovery failure.
    """
    found: Dict[str, Callable] = {}
    try:
        # litellm/integrations lives three levels above this file
        package_root = Path(__file__).parent.parent.parent
        integrations_dir = os.path.join(package_root, "integrations")
        if not os.path.exists(integrations_dir):
            verbose_proxy_logger.debug("integrations directory not found")
            return found
        for entry in os.listdir(integrations_dir):
            entry_path = os.path.join(integrations_dir, entry)
            # Only importable packages: a directory (skipping __pycache__ and
            # friends) that actually contains an __init__.py
            if not os.path.isdir(entry_path) or entry.startswith("__"):
                continue
            if not os.path.exists(os.path.join(entry_path, "__init__.py")):
                continue
            module_path = f"litellm.integrations.{entry}"
            try:
                verbose_proxy_logger.debug(
                    f"Discovering prompt integrations in: {module_path}"
                )
                module = importlib.import_module(module_path)
                # Collect the module's registry dict, if it exposes one
                registry = getattr(module, "prompt_initializer_registry", None)
                if isinstance(registry, dict):
                    found.update(registry)
                    verbose_proxy_logger.debug(
                        f"Found prompt_initializer_registry in {module_path}: {list(registry.keys())}"
                    )
            except ImportError as e:
                verbose_proxy_logger.error(f"Could not import {module_path}: {e}")
                continue
            except Exception as e:
                verbose_proxy_logger.error(f"Error processing {module_path}: {e}")
                continue
        verbose_proxy_logger.debug(
            f"Discovered {len(found)} prompt initializers: {list(found.keys())}"
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Error discovering prompt initializers: {e}")
    return found
# Populate the registry at import time by scanning litellm/integrations packages.
prompt_initializer_registry = get_prompt_initializer_from_integrations()
class InMemoryPromptRegistry:
    """
    In-memory registry of prompts for the proxy.

    Initializes prompts from their specs, registers the resulting callbacks
    with the litellm callback manager, and serves lookups by prompt_id.
    """

    def __init__(self):
        # prompt_id -> parsed PromptSpec stored after successful initialization
        self.IN_MEMORY_PROMPTS: Dict[str, PromptSpec] = {}
        # prompt_id -> CustomPromptManagement callback produced by the
        # integration's initializer (fixed: previously mis-documented as a
        # "Guardrail id to CustomGuardrail" mapping)
        self.prompt_id_to_custom_prompt: Dict[str, Optional[CustomPromptManagement]] = (
            {}
        )

    def initialize_prompt(
        self,
        prompt: PromptSpec,
        config_file_path: Optional[str] = None,
    ) -> Optional[PromptSpec]:
        """
        Initialize a prompt from its spec and register its callback with the
        litellm callback manager.

        Args:
            prompt: The prompt specification to initialize.
            config_file_path: Unused here; kept for interface compatibility.

        Returns:
            The parsed PromptSpec stored in memory. If the prompt_id was
            already initialized, the existing entry is returned unchanged.

        Raises:
            ValueError: If ``prompt_integration`` is missing or has no
                registered initializer, or if the initializer does not return
                a CustomPromptManagement instance.
        """
        import litellm

        prompt_id = prompt.prompt_id
        # Idempotent: never re-register a callback for an already-known id.
        if prompt_id in self.IN_MEMORY_PROMPTS:
            verbose_proxy_logger.debug("prompt_id already exists in IN_MEMORY_PROMPTS")
            return self.IN_MEMORY_PROMPTS[prompt_id]

        custom_prompt_callback: Optional[CustomPromptManagement] = None
        litellm_params_data = prompt.litellm_params
        verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)

        # Accept either a raw dict (e.g. straight from config) or an
        # already-parsed PromptLiteLLMParams object.
        if isinstance(litellm_params_data, dict):
            litellm_params = PromptLiteLLMParams(**litellm_params_data)
        else:
            litellm_params = litellm_params_data

        prompt_integration = litellm_params.prompt_integration
        if prompt_integration is None:
            raise ValueError("prompt_integration is required")

        initializer = prompt_initializer_registry.get(prompt_integration)
        if not initializer:
            raise ValueError(f"Unsupported prompt: {prompt_integration}")

        custom_prompt_callback = initializer(litellm_params, prompt)
        if not isinstance(custom_prompt_callback, CustomPromptManagement):
            raise ValueError(
                f"CustomPromptManagement is required, got {type(custom_prompt_callback)}"
            )
        litellm.logging_callback_manager.add_litellm_callback(custom_prompt_callback)  # type: ignore

        parsed_prompt = PromptSpec(
            prompt_id=prompt_id,
            litellm_params=litellm_params,
            # Default the prompt_info when the spec did not carry one.
            prompt_info=prompt.prompt_info or PromptInfo(prompt_type="config"),
            created_at=prompt.created_at,
            updated_at=prompt.updated_at,
        )
        # store references to the prompt in memory
        self.IN_MEMORY_PROMPTS[prompt_id] = parsed_prompt
        self.prompt_id_to_custom_prompt[prompt_id] = custom_prompt_callback
        return parsed_prompt

    def get_prompt_by_id(self, prompt_id: str) -> Optional[PromptSpec]:
        """
        Get a prompt by its ID from memory, or None if not initialized.
        """
        return self.IN_MEMORY_PROMPTS.get(prompt_id)

    def get_prompt_callback_by_id(
        self, prompt_id: str
    ) -> Optional[CustomPromptManagement]:
        """
        Get a prompt's callback by its ID from memory, or None if not initialized.
        """
        return self.prompt_id_to_custom_prompt.get(prompt_id)
# Module-level singleton the proxy uses to register and look up prompts.
IN_MEMORY_PROMPT_REGISTRY = InMemoryPromptRegistry()