ai-station/.venv/lib/python3.12/site-packages/literalai/evaluation/dataset_experiment.py

133 lines
4.1 KiB
Python

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
from literalai.context import active_experiment_item_run_id_var
from literalai.my_types import Utils
from literalai.observability.step import ScoreDict
if TYPE_CHECKING:
from literalai.api import LiteralAPI
class DatasetExperimentItemDict(TypedDict, total=False):
    """
    Dictionary form of a `DatasetExperimentItem`, keyed in camelCase.

    The type is declared with `total=False`, so any key may be absent.
    """

    # Maps to `DatasetExperimentItem.id`.
    id: str
    # Maps to `DatasetExperimentItem.dataset_experiment_id`.
    datasetExperimentId: str
    # Maps to `DatasetExperimentItem.dataset_item_id`.
    datasetItemId: Optional[str]
    # Maps to `DatasetExperimentItem.scores`.
    scores: List[ScoreDict]
    # Maps to `DatasetExperimentItem.input`.
    input: Optional[Dict]
    # Maps to `DatasetExperimentItem.output`.
    output: Optional[Dict]
    # Maps to `DatasetExperimentItem.experiment_run_id`.
    experimentRunId: Optional[str]
@dataclass(repr=False)
class DatasetExperimentItem(Utils):
    """
    A single entry of a `DatasetExperiment`; linking it to a `DatasetItem`
    is optional.
    """

    id: str
    dataset_experiment_id: str
    dataset_item_id: Optional[str]
    scores: List[ScoreDict]
    input: Optional[Dict]
    output: Optional[Dict]
    experiment_run_id: Optional[str]

    def to_dict(self):
        """
        Serialize this item into a plain dictionary with camelCase keys.

        The attribute values are stored as-is (no copy is made).
        """
        serialized = dict(
            id=self.id,
            datasetExperimentId=self.dataset_experiment_id,
            datasetItemId=self.dataset_item_id,
            experimentRunId=self.experiment_run_id,
            scores=self.scores,
            input=self.input,
            output=self.output,
        )
        return serialized

    @classmethod
    def from_dict(cls, item: DatasetExperimentItemDict) -> "DatasetExperimentItem":
        """
        Build an item from its dictionary form.

        Missing keys fall back to defaults: an empty string for `id` and
        `datasetExperimentId`, an empty list for `scores`, and `None` for
        every other key.
        """
        # Required-looking string fields default to "" rather than None.
        item_id = item.get("id", "")
        experiment_id = item.get("datasetExperimentId", "")
        score_list = item.get("scores", [])

        return cls(
            id=item_id,
            dataset_experiment_id=experiment_id,
            dataset_item_id=item.get("datasetItemId"),
            scores=score_list,
            input=item.get("input"),
            output=item.get("output"),
            experiment_run_id=item.get("experimentRunId"),
        )
class DatasetExperimentDict(TypedDict, total=False):
    """
    Dictionary form of a `DatasetExperiment`, keyed in camelCase.

    The type is declared with `total=False`, so any key may be absent.
    """

    # Maps to `DatasetExperiment.id`.
    id: str
    # Maps to `DatasetExperiment.created_at`.
    createdAt: str
    # Maps to `DatasetExperiment.name`.
    name: str
    # Maps to `DatasetExperiment.dataset_id`.
    datasetId: str
    # Maps to `DatasetExperiment.params`.
    params: Dict
    # Maps to `DatasetExperiment.prompt_variant_id` (note the differing name).
    promptExperimentId: Optional[str]
    # Dictionary forms of the experiment's items; may be `None`.
    items: Optional[List[DatasetExperimentItemDict]]
@dataclass(repr=False)
class DatasetExperiment(Utils):
    """
    An experiment, linked or not to a `Dataset`.

    Holds the API client used to persist new items, plus the items
    that were loaded with or logged to the experiment.
    """

    api: "LiteralAPI"
    id: str
    created_at: str
    name: str
    dataset_id: Optional[str]
    params: Optional[Dict]
    prompt_variant_id: Optional[str] = None
    items: List[DatasetExperimentItem] = field(default_factory=list)

    def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem:
        """
        Logs an item to the dataset experiment.

        Only `datasetItemId`, `input`, `output` and `scores` are read from
        `item_dict`. The experiment id is always taken from this experiment
        and the run id from `active_experiment_item_run_id_var`.

        Args:
            item_dict: Dictionary form of the item to log.

        Returns:
            The item returned by `api.create_experiment_item`, which is also
            appended to `self.items`.
        """
        experiment_run_id = active_experiment_item_run_id_var.get()
        dataset_experiment_item = DatasetExperimentItem.from_dict(
            {
                "experimentRunId": experiment_run_id,
                "datasetExperimentId": self.id,
                "datasetItemId": item_dict.get("datasetItemId"),
                "input": item_dict.get("input", {}),
                "output": item_dict.get("output", {}),
                "scores": item_dict.get("scores", []),
            }
        )

        item = self.api.create_experiment_item(dataset_experiment_item)
        self.items.append(item)
        return item

    def to_dict(self):
        """
        Serialize this experiment into a plain dictionary with camelCase keys.

        The `api` client is not included. `prompt_variant_id` is emitted
        under the `promptExperimentId` key, and each item is serialized
        through its own `to_dict`.
        """
        return {
            "id": self.id,
            "createdAt": self.created_at,
            "name": self.name,
            "datasetId": self.dataset_id,
            "promptExperimentId": self.prompt_variant_id,
            "params": self.params,
            "items": [item.to_dict() for item in self.items],
        }

    @classmethod
    def from_dict(
        cls, api: "LiteralAPI", dataset_experiment: DatasetExperimentDict
    ) -> "DatasetExperiment":
        """
        Build an experiment from its dictionary form.

        Args:
            api: Client stored on the experiment and used by `log`.
            dataset_experiment: Dictionary form of the experiment.

        Returns:
            A new `DatasetExperiment` whose items are built with
            `DatasetExperimentItem.from_dict`.

        Raises:
            Exception: If `items` is present and is neither `None` nor a list.
        """
        items = dataset_experiment.get("items")
        # `items` is declared Optional in `DatasetExperimentDict`: an explicit
        # `None` is treated like a missing key instead of being rejected.
        if items is None:
            items = []
        if not isinstance(items, list):
            raise Exception("Dataset items should be a list.")

        return cls(
            api=api,
            id=dataset_experiment.get("id", ""),
            created_at=dataset_experiment.get("createdAt", ""),
            name=dataset_experiment.get("name", ""),
            dataset_id=dataset_experiment.get("datasetId", ""),
            params=dataset_experiment.get("params"),
            prompt_variant_id=dataset_experiment.get("promptExperimentId"),
            items=[DatasetExperimentItem.from_dict(item) for item in items],
        )