133 lines
4.1 KiB
Python
133 lines
4.1 KiB
Python
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
|
|
|
|
from literalai.context import active_experiment_item_run_id_var
|
|
from literalai.my_types import Utils
|
|
from literalai.observability.step import ScoreDict
|
|
|
|
if TYPE_CHECKING:
|
|
from literalai.api import LiteralAPI
|
|
|
|
|
|
class DatasetExperimentItemDict(TypedDict, total=False):
    """
    Dict form of a `DatasetExperimentItem`, using camelCase keys.

    Declared with `total=False`, so any of the keys may be omitted.
    """

    id: str
    datasetExperimentId: str
    # May be None: an experiment item is not necessarily linked to a dataset item.
    datasetItemId: Optional[str]
    scores: List[ScoreDict]
    input: Optional[Dict]
    output: Optional[Dict]
    experimentRunId: Optional[str]
|
|
|
|
|
|
@dataclass(repr=False)
class DatasetExperimentItem(Utils):
    """
    An item of a `DatasetExperiment`: it may be linked to a `DatasetItem`.

    The attributes are the snake_case counterparts of the camelCase keys of
    `DatasetExperimentItemDict`; `to_dict` and `from_dict` convert between
    the two representations.
    """

    id: str
    dataset_experiment_id: str
    dataset_item_id: Optional[str]
    scores: List[ScoreDict]
    input: Optional[Dict]
    output: Optional[Dict]
    experiment_run_id: Optional[str]

    def to_dict(self):
        """Return this item as a plain dict with camelCase keys."""
        payload = {}
        # Identification and linkage.
        payload["id"] = self.id
        payload["datasetExperimentId"] = self.dataset_experiment_id
        payload["datasetItemId"] = self.dataset_item_id
        payload["experimentRunId"] = self.experiment_run_id
        # Content of the item.
        payload["scores"] = self.scores
        payload["input"] = self.input
        payload["output"] = self.output
        return payload

    @classmethod
    def from_dict(cls, item: DatasetExperimentItemDict) -> "DatasetExperimentItem":
        """
        Build an item from its dict form.

        Missing keys fall back to an empty string for `id` and
        `datasetExperimentId`, to an empty list for `scores`, and to `None`
        for every other key.
        """
        kwargs = {
            "id": item.get("id", ""),
            "experiment_run_id": item.get("experimentRunId"),
            "dataset_experiment_id": item.get("datasetExperimentId", ""),
            "dataset_item_id": item.get("datasetItemId"),
            "scores": item.get("scores", []),
            "input": item.get("input"),
            "output": item.get("output"),
        }
        return cls(**kwargs)
|
|
|
|
|
|
class DatasetExperimentDict(TypedDict, total=False):
    """
    Dict form of a `DatasetExperiment`, using camelCase keys.

    Declared with `total=False`, so any of the keys may be omitted.
    """

    id: str
    createdAt: str
    name: str
    datasetId: str
    params: Dict
    # Stored on `DatasetExperiment` under the attribute `prompt_variant_id`.
    promptExperimentId: Optional[str]
    items: Optional[List[DatasetExperimentItemDict]]
|
|
|
|
|
|
@dataclass(repr=False)
class DatasetExperiment(Utils):
    """
    An experiment, linked or not to a `Dataset`.

    Holds the `LiteralAPI` client used to persist items, the experiment's
    own fields, and the list of `DatasetExperimentItem` logged so far.
    """

    api: "LiteralAPI"
    id: str
    created_at: str
    name: str
    dataset_id: Optional[str]
    params: Optional[Dict]
    prompt_variant_id: Optional[str] = None
    items: List[DatasetExperimentItem] = field(default_factory=list)

    def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem:
        """
        Logs an item to the dataset experiment.

        The item is built from `item_dict` (only `datasetItemId`, `input`,
        `output` and `scores` are read from it), attached to this experiment
        and to the run id currently held by
        `active_experiment_item_run_id_var`, then created through
        `self.api.create_experiment_item`.

        Returns:
            The item returned by the API, which is also appended to
            `self.items`.
        """
        experiment_run_id = active_experiment_item_run_id_var.get()
        dataset_experiment_item = DatasetExperimentItem.from_dict(
            {
                "experimentRunId": experiment_run_id,
                "datasetExperimentId": self.id,
                "datasetItemId": item_dict.get("datasetItemId"),
                "input": item_dict.get("input", {}),
                "output": item_dict.get("output", {}),
                "scores": item_dict.get("scores", []),
            }
        )

        item = self.api.create_experiment_item(dataset_experiment_item)
        self.items.append(item)
        return item

    def to_dict(self):
        """
        Return the experiment as a plain dict with camelCase keys.

        The `api` client is not part of the result; each item is converted
        with its own `to_dict`.
        """
        return {
            "id": self.id,
            "createdAt": self.created_at,
            "name": self.name,
            "datasetId": self.dataset_id,
            # The dict key is `promptExperimentId`; the attribute is
            # `prompt_variant_id` (same mapping as in `from_dict`).
            "promptExperimentId": self.prompt_variant_id,
            "params": self.params,
            "items": [item.to_dict() for item in self.items],
        }

    @classmethod
    def from_dict(
        cls, api: "LiteralAPI", dataset_experiment: DatasetExperimentDict
    ) -> "DatasetExperiment":
        """
        Build an experiment from its dict form.

        Args:
            api: Client stored on the instance and used by `log`.
            dataset_experiment: Dict form of the experiment. Missing `id`,
                `createdAt`, `name` and `datasetId` default to an empty
                string; a missing or `None` `items` yields an empty list.

        Raises:
            Exception: If `items` is present, not `None`, and not a list.
        """
        items = dataset_experiment.get("items")
        if items is None:
            # `items` is declared Optional in `DatasetExperimentDict`, so an
            # explicit None is a valid payload and means "no items".
            items = []
        if not isinstance(items, list):
            raise Exception("Dataset items should be a list.")
        return cls(
            api=api,
            id=dataset_experiment.get("id", ""),
            created_at=dataset_experiment.get("createdAt", ""),
            name=dataset_experiment.get("name", ""),
            dataset_id=dataset_experiment.get("datasetId", ""),
            params=dataset_experiment.get("params"),
            prompt_variant_id=dataset_experiment.get("promptExperimentId"),
            items=[DatasetExperimentItem.from_dict(item) for item in items],
        )
|