import logging
from typing import TYPE_CHECKING, Callable, Optional, TypedDict

from literalai.cache.prompt_helpers import put_prompt
from literalai.observability.generation import GenerationMessage
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
    from literalai.api import LiteralAPI
    from literalai.cache.shared_cache import SharedCache

from literalai.api.helpers import gql

logger = logging.getLogger(__name__)


def create_prompt_lineage_helper(name: str, description: Optional[str] = None):
    variables = {"name": name, "description": description}

    def process_response(response):
        prompt = response["data"]["createPromptLineage"]
        if prompt and prompt.get("deletedAt"):
            logger.warning(
                f"Prompt {name} was deleted - please update any references to use an active prompt in production"
            )
        return prompt

    description = "create prompt lineage"

    return gql.CREATE_PROMPT_LINEAGE, description, variables, process_response
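
# Illustrative usage of the helper-tuple pattern (not part of this module):
# each helper returns a (GraphQL document, operation description, variables,
# response processor) tuple that the LiteralAPI client is expected to unpack
# and execute. `execute_graphql` below is a hypothetical transport function
# used only for this sketch.
#
#     query, desc, variables, process = create_prompt_lineage_helper("welcome-prompt")
#     raw_response = execute_graphql(query, variables)  # hypothetical call
#     lineage = process(raw_response)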


def get_prompt_lineage_helper(name: str):
    variables = {"name": name}

    def process_response(response):
        prompt = response["data"]["promptLineage"]
        if prompt and prompt.get("deletedAt"):
            logger.warning(
                f"Prompt {name} was deleted - please update any references to use an active prompt in production"
            )
        return prompt

    description = "get prompt lineage"

    return gql.GET_PROMPT_LINEAGE, description, variables, process_response


def create_prompt_helper(
    api: "LiteralAPI",
    lineage_id: str,
    template_messages: list[GenerationMessage],
    settings: Optional[ProviderSettings] = None,
    tools: Optional[list[dict]] = None,
):
    variables = {
        "lineageId": lineage_id,
        "templateMessages": template_messages,
        "settings": settings,
        "tools": tools,
    }

    def process_response(response):
        prompt = response["data"]["createPromptVersion"]
        prompt_lineage = prompt.get("lineage")

        if prompt_lineage and prompt_lineage.get("deletedAt"):
            logger.warning(
                f"Prompt {prompt_lineage.get('name')} was deleted - please update any references to use an active prompt in production"
            )
        return Prompt.from_dict(api, prompt) if prompt else None

    description = "create prompt version"

    return gql.CREATE_PROMPT_VERSION, description, variables, process_response


def get_prompt_cache_key(
    id: Optional[str], name: Optional[str], version: Optional[int]
) -> str:
    if id:
        return id
    elif name and version:
        return f"{name}-{version}"
    elif name:
        return name
    else:
        raise ValueError("Either the `id` or the `name` must be provided.")
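
# Illustrative cache keys produced by get_prompt_cache_key (argument values are
# made up). Note that a falsy version (None or 0) falls through to the bare
# name key:
#
#     get_prompt_cache_key(id="prompt_123", name=None, version=None)  -> "prompt_123"
#     get_prompt_cache_key(id=None, name="welcome", version=2)        -> "welcome-2"
#     get_prompt_cache_key(id=None, name="welcome", version=0)        -> "welcome"
#     get_prompt_cache_key(id=None, name=None, version=None)          -> raises ValueError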


def get_prompt_helper(
    api: "LiteralAPI",
    id: Optional[str] = None,
    name: Optional[str] = None,
    version: Optional[int] = 0,
    cache: Optional["SharedCache"] = None,
) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
    """Helper function for getting prompts with caching logic"""

    cached_prompt = None
    timeout = 10

    if cache:
        cached_prompt = cache.get(get_prompt_cache_key(id, name, version))
        # A cached prompt is available to the caller as a fallback, so a much
        # shorter network timeout is used.
        timeout = 1 if cached_prompt else timeout

    variables = {"id": id, "name": name, "version": version}

    def process_response(response):
        prompt_version = response["data"]["promptVersion"]
        prompt_lineage = prompt_version.get("lineage")

        if prompt_lineage and prompt_lineage.get("deletedAt"):
            logger.warning(
                f"Prompt {name} was deleted - please update any references to use an active prompt in production"
            )
        prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None
        if cache and prompt:
            # Refresh the shared cache with the freshly fetched prompt.
            put_prompt(cache, prompt)
        return prompt

    description = "get prompt"

    return (
        gql.GET_PROMPT_VERSION,
        description,
        variables,
        process_response,
        timeout,
        cached_prompt,
    )
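
# Illustrative consumption of the tuple returned above (not part of this
# module): the caller runs the GraphQL query with the returned timeout and can
# fall back to `cached_prompt` if the request fails. `execute_graphql` is a
# hypothetical transport function used only for this sketch.
#
#     query, desc, variables, process, timeout, cached = get_prompt_helper(
#         api, name="welcome", version=1, cache=shared_cache
#     )
#     try:
#         prompt = process(execute_graphql(query, variables, timeout=timeout))
#     except Exception:
#         if cached is None:
#             raise
#         prompt = cached  # stale-but-usable fallback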


def create_prompt_variant_helper(
    from_lineage_id: Optional[str] = None,
    template_messages: list[GenerationMessage] = [],
    settings: Optional[ProviderSettings] = None,
    tools: Optional[list[dict]] = None,
):
    variables = {
        "fromLineageId": from_lineage_id,
        "templateMessages": template_messages,
        "settings": settings,
        "tools": tools,
    }

    def process_response(response) -> Optional[str]:
        variant = response["data"]["createPromptExperiment"]
        return variant["id"] if variant else None

    description = "create prompt variant"

    return gql.CREATE_PROMPT_VARIANT, description, variables, process_response


class PromptRollout(TypedDict):
    version: int
    rollout: int
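
# Illustrative rollout payload: each entry pairs a prompt version with the
# share of traffic it receives (the exact semantics of `rollout`, e.g. a
# percentage, are assumed here and defined by the backend).
#
#     rollouts: list[PromptRollout] = [
#         {"version": 0, "rollout": 80},
#         {"version": 1, "rollout": 20},
#     ]
#     update_prompt_ab_testing_helper("welcome", rollouts)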


def get_prompt_ab_testing_helper(
    name: Optional[str] = None,
):
    variables = {"lineageName": name}

    def process_response(response) -> list[PromptRollout]:
        response_data = response["data"]["promptLineageRollout"]
        return list(map(lambda x: x["node"], response_data["edges"]))

    description = "get prompt A/B testing"

    return gql.GET_PROMPT_AB_TESTING, description, variables, process_response
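
# Illustrative response shape handled by process_response above, assuming a
# Relay-style connection (edges wrapping nodes) as implied by the mapping over
# "edges":
#
#     {"data": {"promptLineageRollout": {"edges": [
#         {"node": {"version": 0, "rollout": 80}},
#         {"node": {"version": 1, "rollout": 20}},
#     ]}}}
#
# which is flattened to [{"version": 0, "rollout": 80}, {"version": 1, "rollout": 20}].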


def update_prompt_ab_testing_helper(name: str, rollouts: list[PromptRollout]):
    variables = {"name": name, "rollouts": rollouts}

    def process_response(response) -> dict:
        return response["data"]["updatePromptLineageRollout"]

    description = "update prompt A/B testing"

    return gql.UPDATE_PROMPT_AB_TESTING, description, variables, process_response