883 lines
30 KiB
Python
883 lines
30 KiB
Python
import logging
|
|
import uuid
|
|
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast
|
|
|
|
import httpx
|
|
from typing_extensions import deprecated
|
|
|
|
from literalai.api.base import BaseLiteralAPI, prepare_variables
|
|
from literalai.api.helpers.attachment_helpers import (
|
|
AttachmentUpload,
|
|
create_attachment_helper,
|
|
delete_attachment_helper,
|
|
get_attachment_helper,
|
|
update_attachment_helper,
|
|
)
|
|
from literalai.api.helpers.dataset_helpers import (
|
|
add_generation_to_dataset_helper,
|
|
add_step_to_dataset_helper,
|
|
create_dataset_helper,
|
|
create_dataset_item_helper,
|
|
create_experiment_helper,
|
|
create_experiment_item_helper,
|
|
delete_dataset_helper,
|
|
delete_dataset_item_helper,
|
|
get_dataset_helper,
|
|
get_dataset_item_helper,
|
|
update_dataset_helper,
|
|
)
|
|
from literalai.api.helpers.generation_helpers import (
|
|
create_generation_helper,
|
|
get_generations_helper,
|
|
)
|
|
from literalai.api.helpers.prompt_helpers import (
|
|
PromptRollout,
|
|
create_prompt_helper,
|
|
create_prompt_lineage_helper,
|
|
create_prompt_variant_helper,
|
|
get_prompt_ab_testing_helper,
|
|
get_prompt_helper,
|
|
get_prompt_lineage_helper,
|
|
update_prompt_ab_testing_helper,
|
|
)
|
|
from literalai.api.helpers.score_helpers import (
|
|
ScoreUpdate,
|
|
check_scores_finite,
|
|
create_score_helper,
|
|
create_scores_query_builder,
|
|
delete_score_helper,
|
|
get_scores_helper,
|
|
update_score_helper,
|
|
)
|
|
from literalai.api.helpers.step_helpers import (
|
|
create_step_helper,
|
|
delete_step_helper,
|
|
get_step_helper,
|
|
get_steps_helper,
|
|
send_steps_helper,
|
|
update_step_helper,
|
|
)
|
|
from literalai.api.helpers.thread_helpers import (
|
|
create_thread_helper,
|
|
delete_thread_helper,
|
|
get_thread_helper,
|
|
get_threads_helper,
|
|
list_threads_helper,
|
|
update_thread_helper,
|
|
upsert_thread_helper,
|
|
)
|
|
from literalai.api.helpers.user_helpers import (
|
|
create_user_helper,
|
|
delete_user_helper,
|
|
get_user_helper,
|
|
get_users_helper,
|
|
update_user_helper,
|
|
)
|
|
from literalai.api.synchronous import LiteralAPI
|
|
from literalai.context import active_steps_var, active_thread_var
|
|
from literalai.evaluation.dataset import Dataset, DatasetType
|
|
from literalai.evaluation.dataset_experiment import (
|
|
DatasetExperiment,
|
|
DatasetExperimentItem,
|
|
)
|
|
from literalai.evaluation.dataset_item import DatasetItem
|
|
from literalai.my_types import PaginatedResponse, User
|
|
from literalai.observability.filter import (
|
|
generations_filters,
|
|
generations_order_by,
|
|
scores_filters,
|
|
scores_order_by,
|
|
steps_filters,
|
|
steps_order_by,
|
|
threads_filters,
|
|
threads_order_by,
|
|
users_filters,
|
|
)
|
|
from literalai.observability.generation import (
|
|
BaseGeneration,
|
|
ChatGeneration,
|
|
CompletionGeneration,
|
|
GenerationMessage,
|
|
)
|
|
from literalai.observability.step import (
|
|
Attachment,
|
|
Score,
|
|
ScoreDict,
|
|
ScoreType,
|
|
Step,
|
|
StepDict,
|
|
StepType,
|
|
)
|
|
from literalai.observability.thread import Thread
|
|
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AsyncLiteralAPI(BaseLiteralAPI):
|
|
"""
|
|
```python
|
|
from literalai import AsyncLiteralClient
|
|
# Initialize the client
|
|
async_literalai_client = AsyncLiteralClient(api_key="your_api_key_here")
|
|
# Access the API's methods
|
|
print(async_literalai_client.api)
|
|
```
|
|
"""
|
|
|
|
R = TypeVar("R")
|
|
|
|
async def make_gql_call(
|
|
self,
|
|
description: str,
|
|
query: str,
|
|
variables: Dict[str, Any],
|
|
timeout: Optional[int] = 10,
|
|
) -> Dict:
|
|
def raise_error(error):
|
|
logger.error(f"Failed to {description}: {error}")
|
|
raise Exception(error)
|
|
|
|
variables = prepare_variables(variables)
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
response = await client.post(
|
|
self.graphql_endpoint,
|
|
json={"query": query, "variables": variables},
|
|
headers=self.headers,
|
|
timeout=timeout,
|
|
)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError:
|
|
raise_error(f"Failed to {description}: {response.text}")
|
|
|
|
try:
|
|
json = response.json()
|
|
except ValueError as e:
|
|
raise_error(
|
|
f"Failed to parse JSON response: {e}, content: {response.content!r}"
|
|
)
|
|
|
|
if json.get("errors"):
|
|
raise_error(json["errors"])
|
|
|
|
if json.get("data"):
|
|
if isinstance(json["data"], dict):
|
|
for value in json["data"].values():
|
|
if value and value.get("ok") is False:
|
|
raise_error(
|
|
f"""Failed to {description}: {value.get("message")}"""
|
|
)
|
|
return json
|
|
|
|
async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
response = await client.post(
|
|
self.rest_endpoint + subpath,
|
|
json=body,
|
|
headers=self.headers,
|
|
timeout=20,
|
|
)
|
|
|
|
try:
|
|
response.raise_for_status()
|
|
except httpx.HTTPStatusError:
|
|
message = f"Failed to call {subpath}: {response.text}"
|
|
logger.error(message)
|
|
raise Exception(message)
|
|
|
|
try:
|
|
return response.json()
|
|
except ValueError as e:
|
|
raise ValueError(
|
|
f"Failed to parse JSON response: {e}, content: {response.content!r}"
|
|
)
|
|
|
|
async def gql_helper(
|
|
self,
|
|
query: str,
|
|
description: str,
|
|
variables: Dict,
|
|
process_response: Callable[..., R],
|
|
timeout: Optional[int] = 10,
|
|
) -> R:
|
|
response = await self.make_gql_call(description, query, variables, timeout)
|
|
return process_response(response)
|
|
|
|
##################################################################################
|
|
# User APIs #
|
|
##################################################################################
|
|
|
|
async def get_users(
|
|
self,
|
|
first: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
filters: Optional[users_filters] = None,
|
|
) -> PaginatedResponse["User"]:
|
|
return await self.gql_helper(*get_users_helper(first, after, before, filters))
|
|
|
|
async def get_user(
|
|
self, id: Optional[str] = None, identifier: Optional[str] = None
|
|
) -> "User":
|
|
return await self.gql_helper(*get_user_helper(id, identifier))
|
|
|
|
async def create_user(
|
|
self, identifier: str, metadata: Optional[Dict] = None
|
|
) -> "User":
|
|
return await self.gql_helper(*create_user_helper(identifier, metadata))
|
|
|
|
async def update_user(
|
|
self, id: str, identifier: Optional[str] = None, metadata: Optional[Dict] = None
|
|
) -> "User":
|
|
return await self.gql_helper(*update_user_helper(id, identifier, metadata))
|
|
|
|
async def delete_user(self, id: str) -> Dict:
|
|
return await self.gql_helper(*delete_user_helper(id))
|
|
|
|
async def get_or_create_user(
|
|
self, identifier: str, metadata: Optional[Dict] = None
|
|
) -> "User":
|
|
user = await self.get_user(identifier=identifier)
|
|
if user:
|
|
return user
|
|
|
|
return await self.create_user(identifier, metadata)
|
|
|
|
##################################################################################
|
|
# Thread APIs #
|
|
##################################################################################
|
|
|
|
async def get_threads(
|
|
self,
|
|
first: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
filters: Optional[threads_filters] = None,
|
|
order_by: Optional[threads_order_by] = None,
|
|
step_types_to_keep: Optional[List[StepType]] = None,
|
|
) -> PaginatedResponse["Thread"]:
|
|
return await self.gql_helper(
|
|
*get_threads_helper(
|
|
first, after, before, filters, order_by, step_types_to_keep
|
|
)
|
|
)
|
|
|
|
async def list_threads(
|
|
self,
|
|
first: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
filters: Optional[threads_filters] = None,
|
|
order_by: Optional[threads_order_by] = None,
|
|
) -> PaginatedResponse["Thread"]:
|
|
return await self.gql_helper(
|
|
*list_threads_helper(first, after, before, filters, order_by)
|
|
)
|
|
|
|
async def get_thread(self, id: str) -> "Thread":
|
|
return await self.gql_helper(*get_thread_helper(id))
|
|
|
|
async def create_thread(
|
|
self,
|
|
name: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
participant_id: Optional[str] = None,
|
|
tags: Optional[List[str]] = None,
|
|
) -> "Thread":
|
|
return await self.gql_helper(
|
|
*create_thread_helper(name, metadata, participant_id, tags)
|
|
)
|
|
|
|
async def upsert_thread(
|
|
self,
|
|
id: str,
|
|
name: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
participant_id: Optional[str] = None,
|
|
tags: Optional[List[str]] = None,
|
|
) -> "Thread":
|
|
return await self.gql_helper(
|
|
*upsert_thread_helper(id, name, metadata, participant_id, tags)
|
|
)
|
|
|
|
async def update_thread(
|
|
self,
|
|
id: str,
|
|
name: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
participant_id: Optional[str] = None,
|
|
tags: Optional[List[str]] = None,
|
|
) -> "Thread":
|
|
return await self.gql_helper(
|
|
*update_thread_helper(id, name, metadata, participant_id, tags)
|
|
)
|
|
|
|
async def delete_thread(self, id: str) -> bool:
|
|
return await self.gql_helper(*delete_thread_helper(id))
|
|
|
|
##################################################################################
|
|
# Score APIs #
|
|
##################################################################################
|
|
|
|
async def get_scores(
|
|
self,
|
|
first: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
filters: Optional[scores_filters] = None,
|
|
order_by: Optional[scores_order_by] = None,
|
|
) -> PaginatedResponse["Score"]:
|
|
return await self.gql_helper(
|
|
*get_scores_helper(first, after, before, filters, order_by)
|
|
)
|
|
|
|
async def create_scores(self, scores: List[ScoreDict]) -> List["Score"]:
|
|
check_scores_finite(scores)
|
|
|
|
query = create_scores_query_builder(scores)
|
|
variables = {}
|
|
|
|
for id, score in enumerate(scores):
|
|
for k, v in score.items():
|
|
variables[f"{k}_{id}"] = v
|
|
|
|
def process_response(response):
|
|
return [Score.from_dict(x) for x in response["data"].values()]
|
|
|
|
return await self.gql_helper(
|
|
query, "create scores", variables, process_response
|
|
)
|
|
|
|
async def create_score(
|
|
self,
|
|
name: str,
|
|
value: float,
|
|
type: ScoreType,
|
|
step_id: Optional[str] = None,
|
|
generation_id: Optional[str] = None,
|
|
dataset_experiment_item_id: Optional[str] = None,
|
|
comment: Optional[str] = None,
|
|
tags: Optional[List[str]] = None,
|
|
) -> "Score":
|
|
if generation_id:
|
|
logger.warning(
|
|
"generation_id is deprecated and will be removed in a future version, please use step_id instead"
|
|
)
|
|
check_scores_finite([{"name": name, "value": value}])
|
|
|
|
return await self.gql_helper(
|
|
*create_score_helper(
|
|
name,
|
|
value,
|
|
type,
|
|
step_id,
|
|
dataset_experiment_item_id,
|
|
comment,
|
|
tags,
|
|
)
|
|
)
|
|
|
|
async def update_score(
|
|
self,
|
|
id: str,
|
|
update_params: ScoreUpdate,
|
|
) -> "Score":
|
|
return await self.gql_helper(*update_score_helper(id, update_params))
|
|
|
|
async def delete_score(self, id: str) -> Dict:
|
|
return await self.gql_helper(*delete_score_helper(id))
|
|
|
|
##################################################################################
|
|
# Attachment APIs #
|
|
##################################################################################
|
|
|
|
async def upload_file(
|
|
self,
|
|
content: Union[bytes, str],
|
|
thread_id: Optional[str] = None,
|
|
mime: Optional[str] = "application/octet-stream",
|
|
) -> Dict:
|
|
id = str(uuid.uuid4())
|
|
body = {"fileName": id, "contentType": mime, "threadId": thread_id}
|
|
|
|
path = "/api/upload/file"
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
response = await client.post(
|
|
f"{self.url}{path}",
|
|
json=body,
|
|
headers=self.headers,
|
|
)
|
|
if response.status_code >= 400:
|
|
reason = response.text
|
|
logger.error(f"Failed to sign upload url: {reason}")
|
|
return {"object_key": None, "url": None}
|
|
json_res = response.json()
|
|
method = "put" if "put" in json_res else "post"
|
|
request_dict: Dict[str, Any] = json_res.get(method, {})
|
|
url: Optional[str] = request_dict.get("url")
|
|
|
|
if not url:
|
|
raise Exception("Invalid server response")
|
|
headers: Optional[Dict] = request_dict.get("headers")
|
|
fields: Dict = request_dict.get("fields", {})
|
|
object_key: Optional[str] = fields.get("key")
|
|
upload_type: Literal["raw", "multipart"] = cast(
|
|
Literal["raw", "multipart"], request_dict.get("uploadType", "multipart")
|
|
)
|
|
signed_url: Optional[str] = json_res.get("signedUrl")
|
|
|
|
# Prepare form data
|
|
form_data = (
|
|
{}
|
|
) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
|
|
for field_name, field_value in fields.items():
|
|
form_data[field_name] = (None, field_value)
|
|
|
|
# Add file to the form_data
|
|
# Note: The content_type parameter is not needed here, as the correct MIME type should be set
|
|
# in the 'Content-Type' field from upload_details
|
|
form_data["file"] = (id, content, mime)
|
|
|
|
async with httpx.AsyncClient(follow_redirects=True) as client:
|
|
if upload_type == "raw":
|
|
upload_response = await client.request(
|
|
url=url,
|
|
headers=headers,
|
|
method=method,
|
|
data=content, # type: ignore
|
|
)
|
|
else:
|
|
upload_response = await client.request(
|
|
url=url,
|
|
headers=headers,
|
|
method=method,
|
|
files=form_data,
|
|
) # type: ignore
|
|
try:
|
|
upload_response.raise_for_status()
|
|
return {"object_key": object_key, "url": signed_url}
|
|
except Exception as e:
|
|
logger.error(f"Failed to upload file: {str(e)}")
|
|
return {"object_key": None, "url": None}
|
|
|
|
async def create_attachment(
|
|
self,
|
|
thread_id: Optional[str] = None,
|
|
step_id: Optional[str] = None,
|
|
id: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
mime: Optional[str] = None,
|
|
name: Optional[str] = None,
|
|
object_key: Optional[str] = None,
|
|
url: Optional[str] = None,
|
|
content: Optional[Union[bytes, str]] = None,
|
|
path: Optional[str] = None,
|
|
) -> "Attachment":
|
|
if not thread_id:
|
|
if active_thread := active_thread_var.get(None):
|
|
thread_id = active_thread.id
|
|
|
|
if not step_id:
|
|
if active_steps := active_steps_var.get():
|
|
step_id = active_steps[-1].id
|
|
else:
|
|
raise Exception("No step_id provided and no active step found.")
|
|
|
|
(
|
|
query,
|
|
description,
|
|
variables,
|
|
content,
|
|
process_response,
|
|
) = create_attachment_helper(
|
|
thread_id=thread_id,
|
|
step_id=step_id,
|
|
id=id,
|
|
metadata=metadata,
|
|
mime=mime,
|
|
name=name,
|
|
object_key=object_key,
|
|
url=url,
|
|
content=content,
|
|
path=path,
|
|
)
|
|
|
|
if content:
|
|
uploaded = await self.upload_file(
|
|
content=content, thread_id=thread_id, mime=mime
|
|
)
|
|
|
|
if uploaded["object_key"] is None or uploaded["url"] is None:
|
|
raise Exception("Failed to upload file")
|
|
|
|
object_key = uploaded["object_key"]
|
|
if object_key:
|
|
variables["objectKey"] = object_key
|
|
else:
|
|
variables["url"] = uploaded["url"]
|
|
|
|
response = await self.make_gql_call(description, query, variables)
|
|
return process_response(response)
|
|
|
|
async def update_attachment(
|
|
self, id: str, update_params: AttachmentUpload
|
|
) -> "Attachment":
|
|
return await self.gql_helper(*update_attachment_helper(id, update_params))
|
|
|
|
async def get_attachment(self, id: str) -> Optional["Attachment"]:
|
|
return await self.gql_helper(*get_attachment_helper(id))
|
|
|
|
async def delete_attachment(self, id: str) -> Dict:
|
|
return await self.gql_helper(*delete_attachment_helper(id))
|
|
|
|
##################################################################################
|
|
# Step APIs #
|
|
##################################################################################
|
|
|
|
async def create_step(
|
|
self,
|
|
thread_id: Optional[str] = None,
|
|
type: Optional[StepType] = "undefined",
|
|
start_time: Optional[str] = None,
|
|
end_time: Optional[str] = None,
|
|
input: Optional[Dict] = None,
|
|
output: Optional[Dict] = None,
|
|
metadata: Optional[Dict] = None,
|
|
parent_id: Optional[str] = None,
|
|
name: Optional[str] = None,
|
|
tags: Optional[List[str]] = None,
|
|
root_run_id: Optional[str] = None,
|
|
) -> "Step":
|
|
return await self.gql_helper(
|
|
*create_step_helper(
|
|
thread_id=thread_id,
|
|
type=type,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
input=input,
|
|
output=output,
|
|
metadata=metadata,
|
|
parent_id=parent_id,
|
|
name=name,
|
|
tags=tags,
|
|
root_run_id=root_run_id,
|
|
)
|
|
)
|
|
|
|
async def update_step(
|
|
self,
|
|
id: str,
|
|
type: Optional[StepType] = None,
|
|
input: Optional[str] = None,
|
|
output: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
name: Optional[str] = None,
|
|
tags: Optional[List[str]] = None,
|
|
start_time: Optional[str] = None,
|
|
end_time: Optional[str] = None,
|
|
parent_id: Optional[str] = None,
|
|
) -> "Step":
|
|
return await self.gql_helper(
|
|
*update_step_helper(
|
|
id=id,
|
|
type=type,
|
|
input=input,
|
|
output=output,
|
|
metadata=metadata,
|
|
name=name,
|
|
tags=tags,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
parent_id=parent_id,
|
|
)
|
|
)
|
|
|
|
async def get_steps(
|
|
self,
|
|
first: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
filters: Optional[steps_filters] = None,
|
|
order_by: Optional[steps_order_by] = None,
|
|
) -> PaginatedResponse["Step"]:
|
|
return await self.gql_helper(
|
|
*get_steps_helper(first, after, before, filters, order_by)
|
|
)
|
|
|
|
async def get_step(
|
|
self,
|
|
id: str,
|
|
) -> Optional["Step"]:
|
|
return await self.gql_helper(*get_step_helper(id=id))
|
|
|
|
async def delete_step(
|
|
self,
|
|
id: str,
|
|
) -> bool:
|
|
return await self.gql_helper(*delete_step_helper(id=id))
|
|
|
|
async def send_steps(self, steps: List[Union["StepDict", "Step"]]):
|
|
return await self.gql_helper(*send_steps_helper(steps=steps))
|
|
|
|
##################################################################################
|
|
# Generation APIs #
|
|
##################################################################################
|
|
|
|
async def get_generations(
|
|
self,
|
|
first: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
filters: Optional[generations_filters] = None,
|
|
order_by: Optional[generations_order_by] = None,
|
|
) -> PaginatedResponse["BaseGeneration"]:
|
|
return await self.gql_helper(
|
|
*get_generations_helper(first, after, before, filters, order_by)
|
|
)
|
|
|
|
async def create_generation(
|
|
self, generation: Union["ChatGeneration", "CompletionGeneration"]
|
|
) -> Union["ChatGeneration", "CompletionGeneration"]:
|
|
return await self.gql_helper(*create_generation_helper(generation))
|
|
|
|
##################################################################################
|
|
# Dataset APIs #
|
|
##################################################################################
|
|
|
|
async def create_dataset(
|
|
self,
|
|
name: str,
|
|
description: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
type: DatasetType = "key_value",
|
|
) -> "Dataset":
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
return await self.gql_helper(
|
|
*create_dataset_helper(sync_api, name, description, metadata, type)
|
|
)
|
|
|
|
async def get_dataset(
|
|
self, id: Optional[str] = None, name: Optional[str] = None
|
|
) -> "Dataset":
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
subpath, _, variables, process_response = get_dataset_helper(
|
|
sync_api, id=id, name=name
|
|
)
|
|
response = await self.make_rest_call(subpath, variables)
|
|
return process_response(response)
|
|
|
|
async def update_dataset(
|
|
self,
|
|
id: str,
|
|
name: Optional[str] = None,
|
|
description: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
) -> "Dataset":
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
return await self.gql_helper(
|
|
*update_dataset_helper(sync_api, id, name, description, metadata)
|
|
)
|
|
|
|
async def delete_dataset(self, id: str) -> "Dataset":
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
return await self.gql_helper(*delete_dataset_helper(sync_api, id))
|
|
|
|
##################################################################################
|
|
# Experiment APIs #
|
|
##################################################################################
|
|
|
|
async def create_experiment(
|
|
self,
|
|
name: str,
|
|
dataset_id: Optional[str] = None,
|
|
prompt_variant_id: Optional[str] = None,
|
|
params: Optional[Dict] = None,
|
|
) -> "DatasetExperiment":
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
|
|
return await self.gql_helper(
|
|
*create_experiment_helper(
|
|
api=sync_api,
|
|
name=name,
|
|
dataset_id=dataset_id,
|
|
prompt_variant_id=prompt_variant_id,
|
|
params=params,
|
|
)
|
|
)
|
|
|
|
async def create_experiment_item(
|
|
self, experiment_item: "DatasetExperimentItem"
|
|
) -> "DatasetExperimentItem":
|
|
check_scores_finite(experiment_item.scores)
|
|
|
|
# Create the dataset experiment item
|
|
result = await self.gql_helper(
|
|
*create_experiment_item_helper(
|
|
dataset_experiment_id=experiment_item.dataset_experiment_id,
|
|
dataset_item_id=experiment_item.dataset_item_id,
|
|
experiment_run_id=experiment_item.experiment_run_id,
|
|
input=experiment_item.input,
|
|
output=experiment_item.output,
|
|
)
|
|
)
|
|
|
|
for score in experiment_item.scores:
|
|
score["datasetExperimentItemId"] = result.id
|
|
|
|
# Create the scores and add to experiment item.
|
|
result.scores = await self.create_scores(experiment_item.scores)
|
|
|
|
return result
|
|
|
|
##################################################################################
|
|
# DatasetItem APIs #
|
|
##################################################################################
|
|
|
|
async def create_dataset_item(
|
|
self,
|
|
dataset_id: str,
|
|
input: Dict,
|
|
expected_output: Optional[Dict] = None,
|
|
metadata: Optional[Dict] = None,
|
|
) -> "DatasetItem":
|
|
return await self.gql_helper(
|
|
*create_dataset_item_helper(dataset_id, input, expected_output, metadata)
|
|
)
|
|
|
|
async def get_dataset_item(self, id: str) -> "DatasetItem":
|
|
return await self.gql_helper(*get_dataset_item_helper(id))
|
|
|
|
async def delete_dataset_item(self, id: str) -> "DatasetItem":
|
|
return await self.gql_helper(*delete_dataset_item_helper(id))
|
|
|
|
async def add_step_to_dataset(
|
|
self, dataset_id: str, step_id: str, metadata: Optional[Dict] = None
|
|
) -> "DatasetItem":
|
|
return await self.gql_helper(
|
|
*add_step_to_dataset_helper(dataset_id, step_id, metadata)
|
|
)
|
|
|
|
async def add_generation_to_dataset(
|
|
self, dataset_id: str, generation_id: str, metadata: Optional[Dict] = None
|
|
) -> "DatasetItem":
|
|
return await self.gql_helper(
|
|
*add_generation_to_dataset_helper(dataset_id, generation_id, metadata)
|
|
)
|
|
|
|
##################################################################################
|
|
# Prompt APIs #
|
|
##################################################################################
|
|
|
|
async def get_or_create_prompt_lineage(
|
|
self, name: str, description: Optional[str] = None
|
|
) -> Dict:
|
|
return await self.gql_helper(*create_prompt_lineage_helper(name, description))
|
|
|
|
@deprecated('Please use "get_or_create_prompt_lineage" instead.')
|
|
async def create_prompt_lineage(
|
|
self, name: str, description: Optional[str] = None
|
|
) -> Dict:
|
|
return await self.get_or_create_prompt_lineage(name, description)
|
|
|
|
async def get_or_create_prompt(
|
|
self,
|
|
name: str,
|
|
template_messages: List[GenerationMessage],
|
|
settings: Optional[ProviderSettings] = None,
|
|
tools: Optional[List[Dict]] = None,
|
|
) -> "Prompt":
|
|
lineage = await self.get_or_create_prompt_lineage(name)
|
|
lineage_id = lineage["id"]
|
|
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
return await self.gql_helper(
|
|
*create_prompt_helper(
|
|
sync_api, lineage_id, template_messages, settings, tools
|
|
)
|
|
)
|
|
|
|
@deprecated("Please use `get_or_create_prompt` instead.")
|
|
async def create_prompt(
|
|
self,
|
|
name: str,
|
|
template_messages: List[GenerationMessage],
|
|
settings: Optional[ProviderSettings] = None,
|
|
) -> "Prompt":
|
|
return await self.get_or_create_prompt(name, template_messages, settings)
|
|
|
|
async def create_prompt_variant(
|
|
self,
|
|
name: str,
|
|
template_messages: List[GenerationMessage],
|
|
settings: Optional[ProviderSettings] = None,
|
|
tools: Optional[List[Dict]] = None,
|
|
) -> Optional[str]:
|
|
lineage = await self.gql_helper(*get_prompt_lineage_helper(name))
|
|
lineage_id = lineage["id"] if lineage else None
|
|
return await self.gql_helper(
|
|
*create_prompt_variant_helper(
|
|
lineage_id, template_messages, settings, tools
|
|
)
|
|
)
|
|
|
|
async def get_prompt(
|
|
self,
|
|
id: Optional[str] = None,
|
|
name: Optional[str] = None,
|
|
version: Optional[int] = None,
|
|
) -> Prompt:
|
|
if not (id or name):
|
|
raise ValueError("At least the `id` or the `name` must be provided.")
|
|
|
|
sync_api = LiteralAPI(self.api_key, self.url)
|
|
(
|
|
get_prompt_query,
|
|
description,
|
|
variables,
|
|
process_response,
|
|
timeout,
|
|
cached_prompt,
|
|
) = get_prompt_helper(
|
|
api=sync_api, id=id, name=name, version=version, cache=self.cache
|
|
)
|
|
|
|
try:
|
|
if id:
|
|
prompt = await self.gql_helper(
|
|
get_prompt_query, description, variables, process_response, timeout
|
|
)
|
|
elif name:
|
|
prompt = await self.gql_helper(
|
|
get_prompt_query, description, variables, process_response, timeout
|
|
)
|
|
|
|
return prompt
|
|
|
|
except Exception as e:
|
|
if cached_prompt:
|
|
logger.warning("Failed to get prompt from API, returning cached prompt")
|
|
return cached_prompt
|
|
raise e
|
|
|
|
async def update_prompt_ab_testing(
|
|
self, name: str, rollouts: List["PromptRollout"]
|
|
) -> Dict:
|
|
return await self.gql_helper(
|
|
*update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
|
|
)
|
|
|
|
async def get_prompt_ab_testing(self, name: str) -> List["PromptRollout"]:
|
|
return await self.gql_helper(*get_prompt_ab_testing_helper(name=name))
|
|
|
|
##################################################################################
|
|
# Misc APIs #
|
|
##################################################################################
|
|
|
|
async def get_my_project_id(self) -> str:
|
|
response = await self.make_rest_call("/my-project", {})
|
|
return response["projectId"]
|