# File: ai-station/.venv/lib/python3.12/site-packages/literalai/observability/generation.py
#
# (source-viewer metadata: 217 lines, 8.3 KiB, Python)

from enum import Enum, unique
from typing import Dict, List, Literal, Optional, Union
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypedDict
from literalai.my_types import ImageUrlContent, TextContent, Utils
# Allowed values for the ``role`` key of a ``GenerationMessage``.
GenerationMessageRole = Literal[
    "user",
    "assistant",
    "tool",
    "function",
    "system",
]
@unique
class GenerationType(str, Enum):
    """Kind of generation: a chat exchange or a plain text completion.

    Members mix in ``str``, so each one compares equal to its raw value
    (``GenerationType.CHAT == "CHAT"``).
    """

    CHAT = "CHAT"
    COMPLETION = "COMPLETION"

    def to_json(self):
        """Return the raw string value used when serialising this member."""
        return self.value

    def __str__(self):
        # Render as the bare value ("CHAT"), not the default "GenerationType.CHAT".
        return self.to_json()

    def __repr__(self):
        return "GenerationType." + self.name
# Shape of a single chat message stored on a generation. Declared with
# ``total=False``, so every key is optional and may be omitted entirely.
GenerationMessage = TypedDict(
    "GenerationMessage",
    {
        "uuid": Optional[str],
        "templated": Optional[bool],
        "name": Optional[str],
        "role": Optional[GenerationMessageRole],
        # Either a plain string or a list of TextContent / ImageUrlContent parts.
        "content": Optional[Union[str, List[Union[TextContent, ImageUrlContent]]]],
        "function_call": Optional[Dict],
        "tool_calls": Optional[List[Dict]],
        "tool_call_id": Optional[str],
    },
    total=False,
)
@dataclass(repr=False)
class BaseGeneration(Utils):
    """
    Base class for generation objects, containing common attributes and methods.

    Attributes:
        id (Optional[str]): The unique identifier of the generation.
        prompt_id (Optional[str]): The unique identifier of the prompt associated with the generation.
        provider (Optional[str]): The provider of the generation.
        model (Optional[str]): The model used for the generation.
        error (Optional[str]): Any error message associated with the generation.
        settings (Optional[Dict]): Settings used for the generation.
        variables (Optional[Dict]): Variables used in the generation.
        tags (Optional[List[str]]): Tags associated with the generation.
        metadata (Optional[Dict]): Metadata associated with the generation.
        tools (Optional[List[Dict]]): Tools used in the generation.
        token_count (Optional[int]): The total number of tokens in the generation.
        input_token_count (Optional[int]): The number of input tokens in the generation.
        output_token_count (Optional[int]): The number of output tokens in the generation.
        tt_first_token (Optional[float]): Time to first token in the generation.
        token_throughput_in_s (Optional[float]): Token throughput in seconds.
        duration (Optional[float]): Duration of the generation.

    Methods:
        from_dict(cls, generation_dict: Dict) -> Union["ChatGeneration", "CompletionGeneration"]:
            Creates a generation object from a dictionary, choosing the
            subclass from its "type" key.
        to_dict(self) -> Dict:
            Converts the generation object to a camelCase-keyed dictionary.
    """

    id: Optional[str] = None
    prompt_id: Optional[str] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    error: Optional[str] = None
    # The factories give every instance its own fresh container.
    settings: Optional[Dict] = Field(default_factory=dict)
    variables: Optional[Dict] = Field(default_factory=dict)
    tags: Optional[List[str]] = Field(default_factory=list)
    metadata: Optional[Dict] = Field(default_factory=dict)
    tools: Optional[List[Dict]] = None
    token_count: Optional[int] = None
    input_token_count: Optional[int] = None
    output_token_count: Optional[int] = None
    tt_first_token: Optional[float] = None
    token_throughput_in_s: Optional[float] = None
    duration: Optional[float] = None

    @classmethod
    def from_dict(
        cls, generation_dict: Dict
    ) -> Union["ChatGeneration", "CompletionGeneration"]:
        """Build the concrete generation described by ``generation_dict``.

        Args:
            generation_dict: camelCase-keyed mapping (the shape ``to_dict``
                produces). Its ``"type"`` key selects the subclass.

        Returns:
            A ``ChatGeneration`` when ``"type"`` is ``"CHAT"``, or a
            ``CompletionGeneration`` when it is ``"COMPLETION"``.

        Raises:
            ValueError: if ``"type"`` is missing or is not a ``GenerationType`` value.
        """
        raw_type = generation_dict.get("type")
        try:
            generation_type = GenerationType(raw_type)
        except ValueError as err:
            # Name the offending value instead of surfacing the enum's generic
            # "... is not a valid GenerationType" (raw_type is None when absent).
            raise ValueError(f"Unknown generation type: {raw_type!r}") from err

        if generation_type == GenerationType.CHAT:
            return ChatGeneration.from_dict(generation_dict)
        if generation_type == GenerationType.COMPLETION:
            return CompletionGeneration.from_dict(generation_dict)
        # Only reachable if GenerationType gains a member not handled above.
        raise ValueError(f"Unknown generation type: {generation_type}")

    def to_dict(self):
        """Serialise the shared fields to a camelCase-keyed dict.

        ``"id"`` is only included when ``self.id`` is truthy. Container values
        (settings, variables, tags, metadata, tools) are the instance's own
        objects, not copies.
        """
        _dict = {
            "promptId": self.prompt_id,
            "provider": self.provider,
            "model": self.model,
            "error": self.error,
            "settings": self.settings,
            "variables": self.variables,
            "tags": self.tags,
            "metadata": self.metadata,
            "tools": self.tools,
            "tokenCount": self.token_count,
            "inputTokenCount": self.input_token_count,
            "outputTokenCount": self.output_token_count,
            "ttFirstToken": self.tt_first_token,
            "tokenThroughputInSeconds": self.token_throughput_in_s,
            "duration": self.duration,
        }
        if self.id:
            _dict["id"] = self.id
        return _dict
@dataclass(repr=False)
class CompletionGeneration(BaseGeneration, Utils):
    """
    Represents a completion generation with a prompt and its corresponding completion.

    Attributes:
        prompt (Optional[str]): The prompt text for the generation.
        completion (Optional[str]): The generated completion text.
        type (GenerationType): The type of generation, which is set to GenerationType.COMPLETION.
    """

    prompt: Optional[str] = None
    completion: Optional[str] = None
    # No annotation, so the dataclass treats this as a class attribute rather
    # than a constructor field.
    type = GenerationType.COMPLETION

    def to_dict(self):
        """Serialise to a camelCase-keyed dict, including prompt, completion and type."""
        _dict = super().to_dict()
        _dict.update(
            {
                "prompt": self.prompt,
                "completion": self.completion,
                "type": self.type.value,
            }
        )
        return _dict

    @classmethod
    def from_dict(cls, generation_dict: Dict):
        """Build a ``CompletionGeneration`` from a camelCase-keyed mapping.

        Accepts the shape produced by ``to_dict``. Missing keys become ``None``,
        except ``"metadata"``, which defaults to an empty dict.
        """
        return CompletionGeneration(
            id=generation_dict.get("id"),
            prompt_id=generation_dict.get("promptId"),
            error=generation_dict.get("error"),
            tags=generation_dict.get("tags"),
            provider=generation_dict.get("provider"),
            model=generation_dict.get("model"),
            variables=generation_dict.get("variables"),
            tools=generation_dict.get("tools"),
            settings=generation_dict.get("settings"),
            # to_dict emits "metadata"; read it back so a round trip keeps it.
            metadata=generation_dict.get("metadata", {}),
            token_count=generation_dict.get("tokenCount"),
            input_token_count=generation_dict.get("inputTokenCount"),
            output_token_count=generation_dict.get("outputTokenCount"),
            tt_first_token=generation_dict.get("ttFirstToken"),
            token_throughput_in_s=generation_dict.get("tokenThroughputInSeconds"),
            duration=generation_dict.get("duration"),
            prompt=generation_dict.get("prompt"),
            completion=generation_dict.get("completion"),
        )
@dataclass(repr=False)
class ChatGeneration(BaseGeneration, Utils):
    """
    Represents a chat generation with a list of messages and a message completion.

    Attributes:
        messages (Optional[List[GenerationMessage]]): The list of messages in the chat generation.
        message_completion (Optional[GenerationMessage]): The completion message of the chat generation.
        type (GenerationType): The type of generation, which is set to GenerationType.CHAT.
    """

    # No annotation, so the dataclass treats this as a class attribute rather
    # than a constructor field.
    type = GenerationType.CHAT
    messages: Optional[List[GenerationMessage]] = Field(default_factory=list)
    message_completion: Optional[GenerationMessage] = None

    def to_dict(self):
        """Serialise to a camelCase-keyed dict, including messages, completion and type."""
        _dict = super().to_dict()
        _dict.update(
            {
                "messages": self.messages,
                "messageCompletion": self.message_completion,
                "type": self.type.value,
            }
        )
        return _dict

    @classmethod
    def from_dict(cls, generation_dict: Dict):
        """Build a ``ChatGeneration`` from a camelCase-keyed mapping.

        Accepts the shape produced by ``to_dict``. Missing keys become ``None``,
        except ``"messages"`` (empty list) and ``"metadata"`` (empty dict).
        """
        return ChatGeneration(
            id=generation_dict.get("id"),
            prompt_id=generation_dict.get("promptId"),
            error=generation_dict.get("error"),
            tags=generation_dict.get("tags"),
            provider=generation_dict.get("provider"),
            model=generation_dict.get("model"),
            variables=generation_dict.get("variables"),
            tools=generation_dict.get("tools"),
            settings=generation_dict.get("settings"),
            # to_dict emits "metadata"; read it back so a round trip keeps it.
            metadata=generation_dict.get("metadata", {}),
            token_count=generation_dict.get("tokenCount"),
            input_token_count=generation_dict.get("inputTokenCount"),
            output_token_count=generation_dict.get("outputTokenCount"),
            tt_first_token=generation_dict.get("ttFirstToken"),
            token_throughput_in_s=generation_dict.get("tokenThroughputInSeconds"),
            duration=generation_dict.get("duration"),
            messages=generation_dict.get("messages", []),
            message_completion=generation_dict.get("messageCompletion"),
        )