# ai-station/.venv/lib/python3.12/site-packages/literalai/api/synchronous.py
#
# Viewer metadata captured with this copy: 865 lines, 29 KiB, Python
# (viewer links: Raw | Normal View | History)

import logging
import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast
import httpx
from typing_extensions import deprecated
from literalai.api.base import BaseLiteralAPI, prepare_variables
from literalai.api.helpers.attachment_helpers import (
AttachmentUpload,
create_attachment_helper,
delete_attachment_helper,
get_attachment_helper,
update_attachment_helper,
)
from literalai.api.helpers.dataset_helpers import (
add_generation_to_dataset_helper,
add_step_to_dataset_helper,
create_dataset_helper,
create_dataset_item_helper,
create_experiment_helper,
create_experiment_item_helper,
delete_dataset_helper,
delete_dataset_item_helper,
get_dataset_helper,
get_dataset_item_helper,
update_dataset_helper,
)
from literalai.api.helpers.generation_helpers import (
create_generation_helper,
get_generations_helper,
)
from literalai.api.helpers.prompt_helpers import (
PromptRollout,
create_prompt_helper,
create_prompt_lineage_helper,
create_prompt_variant_helper,
get_prompt_ab_testing_helper,
get_prompt_helper,
get_prompt_lineage_helper,
update_prompt_ab_testing_helper,
)
from literalai.api.helpers.score_helpers import (
ScoreUpdate,
check_scores_finite,
create_score_helper,
create_scores_query_builder,
delete_score_helper,
get_scores_helper,
update_score_helper,
)
from literalai.api.helpers.step_helpers import (
create_step_helper,
delete_step_helper,
get_step_helper,
get_steps_helper,
send_steps_helper,
update_step_helper,
)
from literalai.api.helpers.thread_helpers import (
create_thread_helper,
delete_thread_helper,
get_thread_helper,
get_threads_helper,
list_threads_helper,
update_thread_helper,
upsert_thread_helper,
)
from literalai.api.helpers.user_helpers import (
create_user_helper,
delete_user_helper,
get_user_helper,
get_users_helper,
update_user_helper,
)
from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.evaluation.dataset_item import DatasetItem
from literalai.my_types import PaginatedResponse, User
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.observability.generation import (
BaseGeneration,
ChatGeneration,
CompletionGeneration,
GenerationMessage,
)
from literalai.observability.step import (
Attachment,
Score,
ScoreDict,
ScoreType,
Step,
StepDict,
StepType,
)
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class LiteralAPI(BaseLiteralAPI):
    """
    Synchronous client for the Literal AI GraphQL and REST APIs.

    Most methods build a request with one of the ``literalai.api.helpers``
    functions and execute it through :meth:`gql_helper`; dataset lookup and
    project lookup go through :meth:`make_rest_call`. Each request opens its
    own short-lived ``httpx.Client``.

    ```python
    from literalai import LiteralClient
    # Initialize the client
    literalai_client = LiteralClient(api_key="your_api_key_here")
    # Access the API's methods
    print(literalai_client.api)
    ```
    """

    # Result type produced by a ``process_response`` callback in gql_helper.
    R = TypeVar("R")

    def make_gql_call(
        self,
        description: str,
        query: str,
        variables: dict[str, Any],
        timeout: Optional[int] = 10,
    ) -> dict:
        """
        POST a GraphQL request to ``self.graphql_endpoint`` and return the JSON body.

        Args:
            description: Short action phrase used in error messages
                (``"Failed to <description>: ..."``).
            query: GraphQL query or mutation text.
            variables: Query variables; normalized with ``prepare_variables``.
            timeout: Forwarded to httpx; ``None`` disables the timeout.

        Returns:
            The full decoded JSON response (including its ``data`` key).

        Raises:
            Exception: on an HTTP error status, a non-JSON body, a non-empty
                ``errors`` entry, or any dict in ``data`` with ``ok`` set to False.
        """

        def raise_error(detail: Any, cause: Optional[BaseException] = None):
            # Build the message once so the log line and the exception agree;
            # callers pass only the detail, never a pre-prefixed message.
            message = f"Failed to {description}: {detail}"
            logger.error(message)
            raise Exception(message) from cause

        variables = prepare_variables(variables)

        with httpx.Client(follow_redirects=True) as client:
            response = client.post(
                self.graphql_endpoint,
                json={"query": query, "variables": variables},
                headers=self.headers,
                timeout=timeout,
            )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise_error(response.text, e)

        try:
            payload = response.json()
        except ValueError as e:
            raise_error(
                f"invalid JSON response: {e}, content: {response.content!r}", e
            )

        if payload.get("errors"):
            raise_error(payload["errors"])

        data = payload.get("data")
        if isinstance(data, dict):
            for value in data.values():
                # Only dict results can carry an ``ok`` flag; skip lists and
                # scalars instead of failing on ``.get``.
                if isinstance(value, dict) and value.get("ok") is False:
                    raise_error(value.get("message"))

        return payload

    def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
        """
        POST ``body`` as JSON to ``self.rest_endpoint + subpath`` (20 s timeout).

        Returns:
            The decoded JSON response.

        Raises:
            Exception: if the server answers with an HTTP error status.
            ValueError: if the response body is not valid JSON.
        """
        with httpx.Client(follow_redirects=True) as client:
            response = client.post(
                self.rest_endpoint + subpath,
                json=body,
                headers=self.headers,
                timeout=20,
            )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            message = f"Failed to call {subpath}: {response.text}"
            logger.error(message)
            raise Exception(message) from e

        try:
            return response.json()
        except ValueError as e:
            raise ValueError(
                f"Failed to parse JSON response: {e}, content: {response.content!r}"
            ) from e

    def gql_helper(
        self,
        query: str,
        description: str,
        variables: Dict,
        process_response: Callable[..., R],
        timeout: Optional[int] = None,
    ) -> R:
        """
        Run a GraphQL call and convert its JSON with ``process_response``.

        The positional order matches the tuples returned by the ``*_helper``
        builders, which is why callers write ``self.gql_helper(*x_helper(...))``.

        NOTE(review): ``timeout`` defaults to ``None`` and is forwarded as-is, so
        it overrides make_gql_call's 10 s default and httpx then applies no
        timeout at all. Left unchanged to avoid cutting off long requests
        (e.g. large ``send_steps`` batches); confirm this is intended.
        """
        response = self.make_gql_call(description, query, variables, timeout)
        return process_response(response)

    ##################################################################################
    #                                   User APIs                                    #
    ##################################################################################

    def get_users(
        self,
        first: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        filters: Optional[users_filters] = None,
    ) -> PaginatedResponse["User"]:
        """Return a page of users (query built by ``get_users_helper``)."""
        return self.gql_helper(*get_users_helper(first, after, before, filters))

    def get_user(
        self, id: Optional[str] = None, identifier: Optional[str] = None
    ) -> "User":
        """Look up one user by ``id`` or ``identifier`` via ``get_user_helper``."""
        return self.gql_helper(*get_user_helper(id, identifier))

    def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User":
        """Create a user with ``identifier`` and optional ``metadata``."""
        return self.gql_helper(*create_user_helper(identifier, metadata))

    def update_user(
        self, id: str, identifier: Optional[str] = None, metadata: Optional[Dict] = None
    ) -> "User":
        """Update the ``identifier`` and/or ``metadata`` of user ``id``."""
        return self.gql_helper(*update_user_helper(id, identifier, metadata))

    def delete_user(self, id: str) -> Dict:
        """Delete user ``id``."""
        return self.gql_helper(*delete_user_helper(id))

    def get_or_create_user(
        self, identifier: str, metadata: Optional[Dict] = None
    ) -> "User":
        """
        Return the user with ``identifier``, creating it when the lookup is falsy.

        ``metadata`` is only used if a new user has to be created.
        """
        user = self.get_user(identifier=identifier)
        if user:
            return user
        return self.create_user(identifier, metadata)

    ##################################################################################
    #                                  Thread APIs                                   #
    ##################################################################################

    def get_threads(
        self,
        first: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        filters: Optional[threads_filters] = None,
        order_by: Optional[threads_order_by] = None,
        step_types_to_keep: Optional[List[StepType]] = None,
    ) -> PaginatedResponse["Thread"]:
        """Return a page of threads; ``step_types_to_keep`` goes to ``get_threads_helper``."""
        return self.gql_helper(
            *get_threads_helper(
                first, after, before, filters, order_by, step_types_to_keep
            )
        )

    def list_threads(
        self,
        first: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        filters: Optional[threads_filters] = None,
        order_by: Optional[threads_order_by] = None,
    ) -> PaginatedResponse["Thread"]:
        """Return a page of threads via ``list_threads_helper``."""
        return self.gql_helper(
            *list_threads_helper(first, after, before, filters, order_by)
        )

    def get_thread(self, id: str) -> "Thread":
        """Fetch thread ``id``."""
        return self.gql_helper(*get_thread_helper(id))

    def create_thread(
        self,
        name: Optional[str] = None,
        metadata: Optional[Dict] = None,
        participant_id: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ) -> "Thread":
        """Create a new thread."""
        return self.gql_helper(
            *create_thread_helper(name, metadata, participant_id, tags)
        )

    def upsert_thread(
        self,
        id: str,
        name: Optional[str] = None,
        metadata: Optional[Dict] = None,
        participant_id: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ) -> "Thread":
        """Create or update thread ``id`` via ``upsert_thread_helper``."""
        return self.gql_helper(
            *upsert_thread_helper(id, name, metadata, participant_id, tags)
        )

    def update_thread(
        self,
        id: str,
        name: Optional[str] = None,
        metadata: Optional[Dict] = None,
        participant_id: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ) -> "Thread":
        """Update fields of thread ``id``."""
        return self.gql_helper(
            *update_thread_helper(id, name, metadata, participant_id, tags)
        )

    def delete_thread(self, id: str) -> bool:
        """Delete thread ``id``."""
        return self.gql_helper(*delete_thread_helper(id))

    ##################################################################################
    #                                  Score APIs                                    #
    ##################################################################################

    def get_scores(
        self,
        first: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        filters: Optional[scores_filters] = None,
        order_by: Optional[scores_order_by] = None,
    ) -> PaginatedResponse["Score"]:
        """Return a page of scores (query built by ``get_scores_helper``)."""
        return self.gql_helper(
            *get_scores_helper(first, after, before, filters, order_by)
        )

    def create_scores(self, scores: List["ScoreDict"]):
        """
        Create several scores in a single GraphQL request.

        Each score's fields are sent as variables named ``<field>_<index>``.

        Returns:
            The raw ``data`` values of the response, one per created score, or
            ``[]`` without any request when ``scores`` is empty.
        """
        # An empty batch would produce a mutation with no fields.
        if not scores:
            return []

        check_scores_finite(scores)
        query = create_scores_query_builder(scores)
        variables: Dict[str, Any] = {
            f"{key}_{index}": value
            for index, score in enumerate(scores)
            for key, value in score.items()
        }

        def process_response(response):
            return list(response["data"].values())

        return self.gql_helper(query, "create scores", variables, process_response)

    def create_score(
        self,
        name: str,
        value: float,
        type: ScoreType,
        step_id: Optional[str] = None,
        generation_id: Optional[str] = None,
        dataset_experiment_item_id: Optional[str] = None,
        comment: Optional[str] = None,
        tags: Optional[List[str]] = None,
    ) -> "Score":
        """
        Create one score attached to a step or dataset experiment item.

        Raises if ``value`` is not finite (``check_scores_finite``).

        NOTE(review): ``generation_id`` only triggers a deprecation warning and
        is NOT forwarded, so a score given only ``generation_id`` has no target.
        """
        if generation_id:
            logger.warning(
                "generation_id is deprecated and will be removed in a future version, please use step_id instead"
            )
        check_scores_finite([{"name": name, "value": value}])

        return self.gql_helper(
            *create_score_helper(
                name,
                value,
                type,
                step_id,
                dataset_experiment_item_id,
                comment,
                tags,
            )
        )

    def update_score(
        self,
        id: str,
        update_params: ScoreUpdate,
    ) -> "Score":
        """Apply ``update_params`` to score ``id``."""
        return self.gql_helper(*update_score_helper(id, update_params))

    def delete_score(self, id: str) -> Dict:
        """Delete score ``id``."""
        return self.gql_helper(*delete_score_helper(id))

    ##################################################################################
    #                                Attachment APIs                                 #
    ##################################################################################

    def upload_file(
        self,
        content: Union[bytes, str],
        thread_id: Optional[str] = None,
        mime: Optional[str] = "application/octet-stream",
    ) -> Dict:
        """
        Upload ``content`` using an upload request signed by the server.

        First asks ``{self.url}/api/upload/file`` for upload instructions, then
        sends the content as a raw body or as multipart form data, depending on
        the ``uploadType`` the server returns (multipart by default).

        Returns:
            ``{"object_key": ..., "url": ...}``. Both are ``None`` when signing or
            uploading fails. ``object_key`` is also ``None`` when the signed
            request has no ``fields["key"]``.

        Raises:
            Exception: if the signing response contains no upload URL.
        """
        file_id = str(uuid.uuid4())
        body = {"fileName": file_id, "contentType": mime}
        if thread_id:
            body["threadId"] = thread_id

        path = "/api/upload/file"

        with httpx.Client(follow_redirects=True) as client:
            response = client.post(
                f"{self.url}{path}",
                json=body,
                headers=self.headers,
            )
        if response.status_code >= 400:
            reason = response.text
            logger.error(f"Failed to sign upload url: {reason}")
            return {"object_key": None, "url": None}

        json_res = response.json()
        method = "put" if "put" in json_res else "post"
        request_dict: Dict[str, Any] = json_res.get(method, {})
        url: Optional[str] = request_dict.get("url")
        if not url:
            raise Exception("Invalid server response")

        headers: Optional[Dict] = request_dict.get("headers")
        fields: Dict = request_dict.get("fields", {})
        object_key: Optional[str] = fields.get("key")
        upload_type: Literal["raw", "multipart"] = cast(
            Literal["raw", "multipart"], request_dict.get("uploadType", "multipart")
        )
        signed_url: Optional[str] = json_res.get("signedUrl")

        with httpx.Client(follow_redirects=True) as client:
            if upload_type == "raw":
                # httpx expects raw bytes/text via ``content=`` (``data=`` is
                # deprecated for non-mapping payloads).
                upload_response = client.request(
                    url=url,
                    headers=headers,
                    method=method,
                    content=content,
                )
            else:
                # Signed form fields are sent as plain (filename-less) parts,
                # followed by the file part itself.
                form_data: Dict[
                    str,
                    Union[
                        tuple[Union[str, None], Any],
                        tuple[Union[str, None], Any, Any],
                    ],
                ] = {field_name: (None, field_value) for field_name, field_value in fields.items()}
                form_data["file"] = (file_id, content, mime)
                upload_response = client.request(
                    url=url,
                    headers=headers,
                    method=method,
                    files=form_data,  # type: ignore
                )

        try:
            upload_response.raise_for_status()
        except httpx.HTTPStatusError as e:
            logger.error(f"Failed to upload file: {str(e)}")
            return {"object_key": None, "url": None}
        return {"object_key": object_key, "url": signed_url}

    def create_attachment(
        self,
        thread_id: Optional[str] = None,
        step_id: Optional[str] = None,
        id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        mime: Optional[str] = None,
        name: Optional[str] = None,
        object_key: Optional[str] = None,
        url: Optional[str] = None,
        content: Optional[Union[bytes, str]] = None,
        path: Optional[str] = None,
    ) -> "Attachment":
        """
        Create an attachment, uploading its content first when there is any.

        A missing ``thread_id`` is taken from ``active_thread_var``, and a missing
        ``step_id`` from the last entry of ``active_steps_var``. The content
        that is uploaded is the one returned by ``create_attachment_helper``,
        which also receives ``path``.

        Raises:
            Exception: if no ``step_id`` is given and no step is active, or if
                the upload yields neither an object key nor a URL.
        """
        if not thread_id:
            if active_thread := active_thread_var.get(None):
                thread_id = active_thread.id

        if not step_id:
            if active_steps := active_steps_var.get():
                step_id = active_steps[-1].id
            else:
                raise Exception("No step_id provided and no active step found.")

        (
            query,
            description,
            variables,
            content,
            process_response,
        ) = create_attachment_helper(
            thread_id=thread_id,
            step_id=step_id,
            id=id,
            metadata=metadata,
            mime=mime,
            name=name,
            object_key=object_key,
            url=url,
            content=content,
            path=path,
        )

        if content:
            uploaded = self.upload_file(content=content, thread_id=thread_id, mime=mime)
            object_key = uploaded["object_key"]
            uploaded_url = uploaded["url"]

            # upload_file reports failure as (None, None). A successful upload
            # may still lack an object key (no ``fields["key"]`` in the signed
            # request); the attachment is then referenced by its URL instead.
            if object_key is None and uploaded_url is None:
                raise Exception("Failed to upload file")

            if object_key:
                variables["objectKey"] = object_key
            else:
                variables["url"] = uploaded_url

        response = self.make_gql_call(description, query, variables)
        return process_response(response)

    def update_attachment(
        self, id: str, update_params: AttachmentUpload
    ) -> "Attachment":
        """Apply ``update_params`` to attachment ``id``."""
        return self.gql_helper(*update_attachment_helper(id, update_params))

    def get_attachment(self, id: str) -> Optional["Attachment"]:
        """Fetch attachment ``id``; may return ``None``."""
        return self.gql_helper(*get_attachment_helper(id))

    def delete_attachment(self, id: str) -> Dict:
        """Delete attachment ``id``."""
        return self.gql_helper(*delete_attachment_helper(id))

    ##################################################################################
    #                                   Step APIs                                    #
    ##################################################################################

    def create_step(
        self,
        thread_id: Optional[str] = None,
        type: Optional[StepType] = "undefined",
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        input: Optional[Dict] = None,
        output: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
        parent_id: Optional[str] = None,
        name: Optional[str] = None,
        tags: Optional[List[str]] = None,
        root_run_id: Optional[str] = None,
    ) -> Step:
        """Create a step (type defaults to ``"undefined"``)."""
        return self.gql_helper(
            *create_step_helper(
                thread_id=thread_id,
                type=type,
                start_time=start_time,
                end_time=end_time,
                input=input,
                output=output,
                metadata=metadata,
                parent_id=parent_id,
                name=name,
                tags=tags,
                root_run_id=root_run_id,
            )
        )

    def update_step(
        self,
        id: str,
        type: Optional[StepType] = None,
        input: Optional[str] = None,
        output: Optional[str] = None,
        metadata: Optional[Dict] = None,
        name: Optional[str] = None,
        tags: Optional[List[str]] = None,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        parent_id: Optional[str] = None,
    ) -> "Step":
        """Update fields of step ``id``."""
        return self.gql_helper(
            *update_step_helper(
                id=id,
                type=type,
                input=input,
                output=output,
                metadata=metadata,
                name=name,
                tags=tags,
                start_time=start_time,
                end_time=end_time,
                parent_id=parent_id,
            )
        )

    def get_steps(
        self,
        first: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        filters: Optional[steps_filters] = None,
        order_by: Optional[steps_order_by] = None,
    ) -> PaginatedResponse["Step"]:
        """Return a page of steps (query built by ``get_steps_helper``)."""
        return self.gql_helper(
            *get_steps_helper(first, after, before, filters, order_by)
        )

    def get_step(
        self,
        id: str,
    ) -> Optional["Step"]:
        """Fetch step ``id``; may return ``None``."""
        return self.gql_helper(*get_step_helper(id=id))

    def delete_step(
        self,
        id: str,
    ) -> bool:
        """Delete step ``id``."""
        return self.gql_helper(*delete_step_helper(id=id))

    def send_steps(self, steps: List[Union["StepDict", "Step"]]):
        """Send a batch of steps (dicts or ``Step`` objects) in one request."""
        return self.gql_helper(*send_steps_helper(steps=steps))

    ##################################################################################
    #                                Generation APIs                                 #
    ##################################################################################

    def get_generations(
        self,
        first: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        filters: Optional[generations_filters] = None,
        order_by: Optional[generations_order_by] = None,
    ) -> PaginatedResponse["BaseGeneration"]:
        """Return a page of generations (query built by ``get_generations_helper``)."""
        return self.gql_helper(
            *get_generations_helper(first, after, before, filters, order_by)
        )

    def create_generation(
        self, generation: Union["ChatGeneration", "CompletionGeneration"]
    ) -> Union["ChatGeneration", "CompletionGeneration"]:
        """Create ``generation`` and return the processed result."""
        return self.gql_helper(*create_generation_helper(generation))

    ##################################################################################
    #                                 Dataset APIs                                   #
    ##################################################################################

    def create_dataset(
        self,
        name: str,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
        type: DatasetType = "key_value",
    ) -> "Dataset":
        """Create a dataset (type defaults to ``"key_value"``)."""
        return self.gql_helper(
            *create_dataset_helper(self, name, description, metadata, type)
        )

    def get_dataset(
        self, id: Optional[str] = None, name: Optional[str] = None
    ) -> Optional["Dataset"]:
        """Fetch a dataset by ``id`` or ``name`` via the REST API; may return ``None``."""
        subpath, _, variables, process_response = get_dataset_helper(
            self, id=id, name=name
        )
        response = self.make_rest_call(subpath, variables)
        return process_response(response)

    def update_dataset(
        self,
        id: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[Dict] = None,
    ) -> "Dataset":
        """Update fields of dataset ``id``."""
        return self.gql_helper(
            *update_dataset_helper(self, id, name, description, metadata)
        )

    def delete_dataset(self, id: str) -> "Dataset":
        """Delete dataset ``id``."""
        return self.gql_helper(*delete_dataset_helper(self, id))

    ##################################################################################
    #                                Experiment APIs                                 #
    ##################################################################################

    def create_experiment(
        self,
        name: str,
        dataset_id: Optional[str] = None,
        prompt_variant_id: Optional[str] = None,
        params: Optional[Dict] = None,
    ) -> "DatasetExperiment":
        """Create a dataset experiment."""
        return self.gql_helper(
            *create_experiment_helper(
                api=self,
                name=name,
                dataset_id=dataset_id,
                prompt_variant_id=prompt_variant_id,
                params=params,
            )
        )

    def create_experiment_item(
        self, experiment_item: DatasetExperimentItem
    ) -> "DatasetExperimentItem":
        """
        Create an experiment item, then create its scores linked to it.

        Note: each dict in ``experiment_item.scores`` is mutated in place to
        gain a ``datasetExperimentItemId`` key.
        """
        result = self.gql_helper(
            *create_experiment_item_helper(
                dataset_experiment_id=experiment_item.dataset_experiment_id,
                dataset_item_id=experiment_item.dataset_item_id,
                experiment_run_id=experiment_item.experiment_run_id,
                input=experiment_item.input,
                output=experiment_item.output,
            )
        )

        for score in experiment_item.scores:
            score["datasetExperimentItemId"] = result.id

        # create_scores returns [] for an empty list, so score-less items are fine.
        result.scores = self.create_scores(experiment_item.scores)

        return result

    ##################################################################################
    #                               Dataset Item APIs                                #
    ##################################################################################

    def create_dataset_item(
        self,
        dataset_id: str,
        input: Dict,
        expected_output: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
    ) -> "DatasetItem":
        """Add an item to dataset ``dataset_id``."""
        return self.gql_helper(
            *create_dataset_item_helper(dataset_id, input, expected_output, metadata)
        )

    def get_dataset_item(self, id: str) -> Optional["DatasetItem"]:
        """Fetch dataset item ``id``; may return ``None``."""
        return self.gql_helper(*get_dataset_item_helper(id))

    def delete_dataset_item(self, id: str) -> "DatasetItem":
        """Delete dataset item ``id``."""
        return self.gql_helper(*delete_dataset_item_helper(id))

    def add_step_to_dataset(
        self, dataset_id: str, step_id: str, metadata: Optional[Dict] = None
    ) -> "DatasetItem":
        """Add step ``step_id`` to dataset ``dataset_id``."""
        return self.gql_helper(
            *add_step_to_dataset_helper(dataset_id, step_id, metadata)
        )

    def add_generation_to_dataset(
        self, dataset_id: str, generation_id: str, metadata: Optional[Dict] = None
    ) -> "DatasetItem":
        """Add generation ``generation_id`` to dataset ``dataset_id``."""
        return self.gql_helper(
            *add_generation_to_dataset_helper(dataset_id, generation_id, metadata)
        )

    ##################################################################################
    #                                 Prompt APIs                                    #
    ##################################################################################

    def get_or_create_prompt_lineage(
        self, name: str, description: Optional[str] = None
    ) -> Dict:
        """Get or create the prompt lineage ``name`` via ``create_prompt_lineage_helper``."""
        return self.gql_helper(*create_prompt_lineage_helper(name, description))

    @deprecated("Use get_or_create_prompt_lineage instead")
    def create_prompt_lineage(
        self, name: str, description: Optional[str] = None
    ) -> Dict:
        """Deprecated alias of :meth:`get_or_create_prompt_lineage`."""
        return self.get_or_create_prompt_lineage(name, description)

    def get_or_create_prompt(
        self,
        name: str,
        template_messages: List[GenerationMessage],
        settings: Optional[ProviderSettings] = None,
        tools: Optional[List[Dict]] = None,
    ) -> "Prompt":
        """Ensure lineage ``name`` exists, then get or create a prompt version in it."""
        lineage = self.get_or_create_prompt_lineage(name)
        lineage_id = lineage["id"]
        return self.gql_helper(
            *create_prompt_helper(self, lineage_id, template_messages, settings, tools)
        )

    @deprecated("Please use `get_or_create_prompt` instead.")
    def create_prompt(
        self,
        name: str,
        template_messages: List[GenerationMessage],
        settings: Optional[ProviderSettings] = None,
    ) -> "Prompt":
        """Deprecated alias of :meth:`get_or_create_prompt` (without ``tools``)."""
        return self.get_or_create_prompt(name, template_messages, settings)

    def get_prompt(
        self,
        id: Optional[str] = None,
        name: Optional[str] = None,
        version: Optional[int] = None,
    ) -> "Prompt":
        """
        Fetch a prompt by ``id`` or ``name`` (optionally a specific ``version``).

        If the API call fails and ``get_prompt_helper`` found a cached prompt in
        ``self.cache``, the cached prompt is returned instead.

        Raises:
            ValueError: if neither ``id`` nor ``name`` is given.
            Exception: the original API error when no cached prompt exists.
        """
        if not (id or name):
            raise ValueError("At least the `id` or the `name` must be provided.")

        (
            get_prompt_query,
            description,
            variables,
            process_response,
            timeout,
            cached_prompt,
        ) = get_prompt_helper(
            api=self, id=id, name=name, version=version, cache=self.cache
        )

        try:
            # The helper already chose the id- or name-based query.
            return self.gql_helper(
                get_prompt_query, description, variables, process_response, timeout
            )
        except Exception:
            if cached_prompt:
                logger.warning("Failed to get prompt from API, returning cached prompt")
                return cached_prompt
            raise

    def create_prompt_variant(
        self,
        name: str,
        template_messages: List[GenerationMessage],
        settings: Optional[ProviderSettings] = None,
        tools: Optional[List[Dict]] = None,
    ) -> Optional[str]:
        """Create a prompt variant for lineage ``name`` (lineage id is ``None`` if not found)."""
        lineage = self.gql_helper(*get_prompt_lineage_helper(name))
        lineage_id = lineage["id"] if lineage else None
        return self.gql_helper(
            *create_prompt_variant_helper(
                lineage_id, template_messages, settings, tools
            )
        )

    def get_prompt_ab_testing(self, name: str) -> List["PromptRollout"]:
        """Return the A/B-testing rollouts of prompt ``name``."""
        return self.gql_helper(*get_prompt_ab_testing_helper(name=name))

    def update_prompt_ab_testing(
        self, name: str, rollouts: List["PromptRollout"]
    ) -> Dict:
        """Replace the A/B-testing rollouts of prompt ``name``."""
        return self.gql_helper(
            *update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
        )

    ##################################################################################
    #                                  Misc APIs                                     #
    ##################################################################################

    def get_my_project_id(self) -> str:
        """Return the ``projectId`` reported by the ``/my-project`` REST endpoint."""
        response = self.make_rest_call("/my-project", {})
        return response["projectId"]