ai-station/.venv/lib/python3.12/site-packages/traceloop/sdk/prompts/client.py

149 lines
5.8 KiB
Python
Raw Normal View History

from typing import Optional
from jinja2 import Environment, meta
from traceloop.sdk.prompts.model import Prompt, PromptVersion, TemplateEngine
from traceloop.sdk.prompts.registry import PromptRegistry
from traceloop.sdk.tracing.tracing import set_managed_prompt_tracing_context
def get_effective_version(prompt: Prompt) -> PromptVersion:
    """Return the version of *prompt* that its target currently points at.

    The match is made by comparing each version's ``id`` with
    ``prompt.target.version``; the first match wins.

    Raises:
        Exception: if the prompt has no versions at all.
        StopIteration: if no version has the targeted id (callers are
            expected to handle this).
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    for candidate in candidates:
        if candidate.id == prompt.target.version:
            return candidate

    # Mirror the exhausted-iterator signal that callers catch.
    raise StopIteration
def get_version_by_name(prompt: Prompt, name: str) -> PromptVersion:
    """Return the first version of *prompt* whose ``name`` equals *name*.

    Raises:
        Exception: if the prompt has no versions at all.
        StopIteration: if no version carries that name (callers are
            expected to handle this).
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    for candidate in candidates:
        if candidate.name == name:
            return candidate

    # Mirror the exhausted-iterator signal that callers catch.
    raise StopIteration
def get_version_by_hash(prompt: Prompt, hash: str) -> PromptVersion:
    """Return the first version of *prompt* whose ``hash`` equals *hash*.

    Raises:
        Exception: if the prompt has no versions at all.
        StopIteration: if no version carries that hash (callers are
            expected to handle this).
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    for candidate in candidates:
        if candidate.hash == hash:
            return candidate

    # Mirror the exhausted-iterator signal that callers catch.
    raise StopIteration
def get_specific_version(prompt: Prompt, version: int) -> PromptVersion:
    """Return the first version of *prompt* whose ``version`` equals *version*.

    Raises:
        Exception: if the prompt has no versions at all.
        StopIteration: if no version carries that number (callers are
            expected to handle this).
    """
    candidates = prompt.versions
    if len(candidates) == 0:
        raise Exception(f"No versions exist for {prompt.key} prompt")

    for candidate in candidates:
        if candidate.version == version:
            return candidate

    # Mirror the exhausted-iterator signal that callers catch.
    raise StopIteration
class PromptRegistryClient:
    """Singleton client that renders prompts stored in the prompt registry.

    The first construction creates the shared instance together with its
    ``PromptRegistry`` and Jinja2 ``Environment``; every later construction
    returns that same instance.
    """

    _registry: PromptRegistry
    _jinja_env: Environment

    def __new__(cls) -> "PromptRegistryClient":
        # Create and cache the shared instance on first use only.
        if not hasattr(cls, "instance"):
            obj = cls.instance = super(PromptRegistryClient, cls).__new__(cls)
            obj._registry = PromptRegistry()
            obj._jinja_env = Environment()
        return cls.instance

    def render_prompt(
        self,
        key: str,
        version: Optional[int] = None,
        version_name: Optional[str] = None,
        version_hash: Optional[str] = None,
        variables: Optional[dict] = None,
    ):
        """Render the prompt registered under *key* into call parameters.

        The version is selected by the first selector that is not ``None``,
        checked in this order: *version*, *version_name*, *version_hash*.
        When none is given, the prompt's targeted (effective) version is used.

        Args:
            key: Registry key of the prompt.
            version: Version number to render.
            version_name: Version name to render.
            version_hash: Version hash to render.
            variables: Values for the template variables; defaults to none.

        Returns:
            A dict holding the rendered ``"messages"`` plus every
            ``llm_config`` entry whose value is neither ``None`` nor an empty
            list, with the ``"mode"`` entry removed.

        Raises:
            Exception: if the prompt does not exist, has no versions, has no
                version matching the selector, or a template variable is
                missing.
        """
        # A fresh dict per call: a shared mutable default would be handed to
        # the tracing context and could leak state between calls.
        if variables is None:
            variables = {}

        prompt = self._registry.get_prompt_by_key(key)
        if prompt is None:
            raise Exception(f"Prompt {key} does not exist")

        try:
            if version is not None:
                prompt_version = get_specific_version(prompt, version)
            elif version_name is not None:
                prompt_version = get_version_by_name(prompt, version_name)
            elif version_hash is not None:
                prompt_version = get_version_by_hash(prompt, version_hash)
            else:
                prompt_version = get_effective_version(prompt)
        except StopIteration:
            # The lookup helpers signal "no match" with StopIteration; it
            # carries no information, so drop it from the traceback.
            raise Exception(
                f"Prompt {key} does not have an available version to render"
            ) from None

        # By default, OpenAI will set tool_choice to "auto"
        # if tools not provided and there is tool_choice set it throws an error
        llm_config = prompt_version.llm_config
        if not llm_config.tools and llm_config.tool_choice is not None:
            llm_config.tool_choice = None

        params_dict = {"messages": self.render_messages(prompt_version, **variables)}
        params_dict.update(
            (k, v) for k, v in iter(llm_config) if v not in [None, []]
        )
        # "mode" may already have been filtered out above, so tolerate its
        # absence instead of raising KeyError.
        params_dict.pop("mode", None)

        set_managed_prompt_tracing_context(
            prompt.key,
            prompt_version.version,
            prompt_version.name,
            prompt_version.hash,
            variables,
        )

        return params_dict

    def render_messages(self, prompt_version: PromptVersion, **args):
        """Render every message template of *prompt_version* with *args*.

        A message whose template is a string renders to a string. Otherwise
        the template is treated as a sequence of content parts: ``"text"``
        parts are rendered, ``"image_url"`` parts are passed through via
        ``content.dict()``, and parts of any other type are skipped.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts, one per message.

        Raises:
            Exception: if the templating engine is not Jinja2 or a template
                references a variable missing from *args*.
        """
        if prompt_version.templating_engine != TemplateEngine.JINJA2:
            raise Exception(
                f"Templating engine {prompt_version.templating_engine} is not supported"
            )

        rendered_messages = []
        for msg in prompt_version.messages:
            if isinstance(msg.template, str):
                rendered_msg = self._render_template(msg.template, args)
            else:
                rendered_msg = []
                for content in msg.template:
                    if content.type == "text":
                        # Check and render each part's own text, not the
                        # enclosing list of parts.
                        rendered_msg.append(
                            {
                                "type": "text",
                                "text": self._render_template(content.text, args),
                            }
                        )
                    elif content.type == "image_url":
                        rendered_msg.append(content.dict())

            rendered_messages.append({"role": msg.role, "content": rendered_msg})

        return rendered_messages

    def _render_template(self, source: str, args: dict) -> str:
        """Render one Jinja2 template string, rejecting missing variables."""
        undeclared = meta.find_undeclared_variables(self._jinja_env.parse(source))
        missing_variables = undeclared.difference(args)
        if missing_variables:
            # Sorted so the error message is deterministic.
            raise Exception(
                f"Input variables: {','.join(str(var) for var in sorted(missing_variables))} are missing"
            )
        return self._jinja_env.from_string(source).render(args)