# ai-station/.venv/lib/python3.12/site-packages/traceloop/sdk/prompts/client.py
#
# 149 lines
# 5.8 KiB
# Python

from typing import Optional
from jinja2 import Environment, meta
from traceloop.sdk.prompts.model import Prompt, PromptVersion, TemplateEngine
from traceloop.sdk.prompts.registry import PromptRegistry
from traceloop.sdk.tracing.tracing import set_managed_prompt_tracing_context
def get_effective_version(prompt: Prompt) -> PromptVersion:
    """Return the version of ``prompt`` selected by ``prompt.target.version``.

    The first entry of ``prompt.versions`` whose ``id`` equals
    ``prompt.target.version`` is returned.

    Raises:
        Exception: if ``prompt.versions`` is empty.
        StopIteration: if no version has the targeted ``id``.
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    wanted_id = prompt.target.version
    matching = filter(lambda candidate: candidate.id == wanted_id, candidates)
    # next() without a default lets StopIteration escape when nothing matches.
    return next(matching)
def get_version_by_name(prompt: Prompt, name: str) -> PromptVersion:
    """Return the first version of ``prompt`` whose ``name`` equals ``name``.

    Raises:
        Exception: if ``prompt.versions`` is empty.
        StopIteration: if no version carries that name.
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    matching = filter(lambda candidate: candidate.name == name, candidates)
    # next() without a default lets StopIteration escape when nothing matches.
    return next(matching)
def get_version_by_hash(prompt: Prompt, hash: str) -> PromptVersion:
    """Return the first version of ``prompt`` whose ``hash`` equals ``hash``.

    Raises:
        Exception: if ``prompt.versions`` is empty.
        StopIteration: if no version carries that hash.
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    matching = filter(lambda candidate: candidate.hash == hash, candidates)
    # next() without a default lets StopIteration escape when nothing matches.
    return next(matching)
def get_specific_version(prompt: Prompt, version: int) -> PromptVersion:
    """Return the first version of ``prompt`` whose ``version`` number equals ``version``.

    Raises:
        Exception: if ``prompt.versions`` is empty.
        StopIteration: if no version has that number.
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    matching = filter(lambda candidate: candidate.version == version, candidates)
    # next() without a default lets StopIteration escape when nothing matches.
    return next(matching)
class PromptRegistryClient:
    """Singleton that looks prompts up in a ``PromptRegistry`` and renders them.

    Every construction returns the same instance; the registry and the Jinja2
    environment are created once, on first construction.
    """

    _registry: PromptRegistry
    _jinja_env: Environment

    def __new__(cls) -> "PromptRegistryClient":
        # Create the shared instance (and its collaborators) only once; every
        # later call hands back that same object.
        if not hasattr(cls, "instance"):
            obj = cls.instance = super(PromptRegistryClient, cls).__new__(cls)
            obj._registry = PromptRegistry()
            obj._jinja_env = Environment()
        return cls.instance

    def render_prompt(
        self,
        key: str,
        version: Optional[int] = None,
        version_name: Optional[str] = None,
        version_hash: Optional[str] = None,
        variables: Optional[dict] = None,
    ):
        """Render the prompt stored under ``key`` into a parameter dict.

        The version is chosen by the first selector that is not ``None``, in
        the order ``version``, ``version_name``, ``version_hash``; with no
        selector the prompt's targeted version is used.

        Args:
            key: Registry key of the prompt.
            version: Version number to render.
            version_name: Version name to render.
            version_hash: Version hash to render.
            variables: Template variables; ``None`` means no variables.

        Returns:
            A dict holding the rendered ``"messages"`` plus every
            ``llm_config`` entry whose value is neither ``None`` nor ``[]``,
            with the ``"mode"`` entry removed.

        Raises:
            Exception: if the prompt does not exist, has no versions, has no
                version matching the selector, or a template variable is
                missing.
        """
        # A fresh dict per call: a ``{}`` default would be one object shared
        # by every call that omits ``variables``.
        if variables is None:
            variables = {}

        prompt = self._registry.get_prompt_by_key(key)
        if prompt is None:
            raise Exception(f"Prompt {key} does not exist")

        try:
            if version is not None:
                prompt_version = get_specific_version(prompt, version)
            elif version_name is not None:
                prompt_version = get_version_by_name(prompt, version_name)
            elif version_hash is not None:
                prompt_version = get_version_by_hash(prompt, version_hash)
            else:
                prompt_version = get_effective_version(prompt)
        except StopIteration:
            # The lookup helpers signal "no match" with a bare StopIteration,
            # which adds nothing to the traceback, so the context is dropped.
            raise Exception(
                f"Prompt {key} does not have an available version to render"
            ) from None

        llm_config = prompt_version.llm_config
        # By default, OpenAI will set tool_choice to "auto"
        # if tools not provided and there is tool_choice set it throws an error
        if not llm_config.tools and llm_config.tool_choice is not None:
            llm_config.tool_choice = None

        params_dict = {"messages": self.render_messages(prompt_version, **variables)}
        params_dict.update(
            (name, value) for name, value in llm_config if value not in [None, []]
        )
        # "mode" may already have been filtered out above (or be absent), so
        # removing it must not fail.
        params_dict.pop("mode", None)

        set_managed_prompt_tracing_context(
            prompt.key,
            prompt_version.version,
            prompt_version.name,
            prompt_version.hash,
            variables,
        )
        return params_dict

    def render_messages(self, prompt_version: PromptVersion, **args):
        """Render every message template of ``prompt_version`` with ``args``.

        A message whose template is a string renders to a string. Otherwise
        the template is iterated as a sequence of parts: ``"text"`` parts are
        rendered, ``"image_url"`` parts are copied via ``.dict()``, and parts
        of any other type are skipped.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts, one per message.

        Raises:
            Exception: if the templating engine is not Jinja2 or a template
                references a variable that is not in ``args``.
        """
        if prompt_version.templating_engine != TemplateEngine.JINJA2:
            raise Exception(
                f"Templating engine {prompt_version.templating_engine} is not supported"
            )

        rendered_messages = []
        for msg in prompt_version.messages:
            if isinstance(msg.template, str):
                rendered_msg = self._render_template(msg.template, args)
            else:
                rendered_msg = []
                for content in msg.template:
                    if content.type == "text":
                        # Variables are checked against this part's own text,
                        # not against the whole multi-part template.
                        rendered_msg.append(
                            {
                                "type": "text",
                                "text": self._render_template(content.text, args),
                            }
                        )
                    elif content.type == "image_url":
                        rendered_msg.append(content.dict())
            rendered_messages.append({"role": msg.role, "content": rendered_msg})
        return rendered_messages

    def _render_template(self, source: str, variables: dict) -> str:
        """Render one Jinja2 template string, failing if a variable is missing."""
        template = self._jinja_env.from_string(source)
        required = meta.find_undeclared_variables(self._jinja_env.parse(source))
        missing_variables = required.difference(variables)
        if missing_variables:
            # Sorted so the error message does not depend on set ordering.
            raise Exception(
                f"Input variables: {','.join(sorted(str(var) for var in missing_variables))} are missing"
            )
        return template.render(variables)