ai-station/.venv/lib/python3.12/site-packages/literalai/observability/generation.py

217 lines
8.3 KiB
Python
Raw Normal View History

from enum import Enum, unique
from typing import Dict, List, Literal, Optional, Union
from pydantic import Field
from pydantic.dataclasses import dataclass
from typing_extensions import TypedDict
from literalai.my_types import ImageUrlContent, TextContent, Utils
# Closed set of role strings accepted for the "role" key of a GenerationMessage.
GenerationMessageRole = Literal["user", "assistant", "tool", "function", "system"]
@unique
class GenerationType(str, Enum):
    """Kind of generation: chat-style or plain completion.

    Members are also ``str`` instances; each member's value equals its name.
    """

    CHAT = "CHAT"
    COMPLETION = "COMPLETION"

    def __repr__(self):
        """Return the qualified member name, e.g. ``GenerationType.CHAT``."""
        return "GenerationType." + self.name

    def __str__(self):
        """Return the member's raw string value."""
        return self.value

    def to_json(self):
        """Return the JSON-friendly form of the member (its raw string value)."""
        return self.value
class GenerationMessage(TypedDict, total=False):
    """Dictionary shape of a single message in a chat generation.

    Declared with ``total=False``, so every key may be omitted; each value
    that is present may also be ``None``.
    """

    uuid: Optional[str]
    templated: Optional[bool]
    name: Optional[str]
    role: Optional[GenerationMessageRole]
    # Either a plain string or a list of text / image-URL content parts.
    content: Union[str, List[Union[TextContent, ImageUrlContent]], None]
    function_call: Optional[Dict]
    tool_calls: Optional[List[Dict]]
    tool_call_id: Optional[str]
@dataclass(repr=False)
class BaseGeneration(Utils):
    """
    Base class for generation objects, containing common attributes and methods.

    Attributes:
        id (Optional[str]): The unique identifier of the generation.
        prompt_id (Optional[str]): The unique identifier of the prompt associated with the generation.
        provider (Optional[str]): The provider of the generation.
        model (Optional[str]): The model used for the generation.
        error (Optional[str]): Any error message associated with the generation.
        settings (Optional[Dict]): Settings used for the generation.
        variables (Optional[Dict]): Variables used in the generation.
        tags (Optional[List[str]]): Tags associated with the generation.
        metadata (Optional[Dict]): Metadata associated with the generation.
        tools (Optional[List[Dict]]): Tools used in the generation.
        token_count (Optional[int]): The total number of tokens in the generation.
        input_token_count (Optional[int]): The number of input tokens in the generation.
        output_token_count (Optional[int]): The number of output tokens in the generation.
        tt_first_token (Optional[float]): Time to first token in the generation.
        token_throughput_in_s (Optional[float]): Token throughput in seconds.
        duration (Optional[float]): Duration of the generation.

    Methods:
        from_dict(cls, generation_dict: Dict) -> Union["ChatGeneration", "CompletionGeneration"]:
            Creates a generation object from a dictionary.
        to_dict(self) -> Dict:
            Converts the generation object to a dictionary.
    """

    id: Optional[str] = None
    prompt_id: Optional[str] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    error: Optional[str] = None
    # Container defaults use factories so instances never share one object.
    settings: Optional[Dict] = Field(default_factory=dict)
    variables: Optional[Dict] = Field(default_factory=dict)
    tags: Optional[List[str]] = Field(default_factory=list)
    metadata: Optional[Dict] = Field(default_factory=dict)
    tools: Optional[List[Dict]] = None
    token_count: Optional[int] = None
    input_token_count: Optional[int] = None
    output_token_count: Optional[int] = None
    tt_first_token: Optional[float] = None
    token_throughput_in_s: Optional[float] = None
    duration: Optional[float] = None

    @classmethod
    def from_dict(
        cls, generation_dict: Dict
    ) -> Union["ChatGeneration", "CompletionGeneration"]:
        """Build the concrete generation object described by ``generation_dict``.

        The ``"type"`` key selects the class: ``"CHAT"`` delegates to
        ``ChatGeneration.from_dict`` and ``"COMPLETION"`` to
        ``CompletionGeneration.from_dict``.

        Args:
            generation_dict: Dictionary holding a ``"type"`` key plus the
                camelCase keys read by the selected class.

        Returns:
            A ``ChatGeneration`` or ``CompletionGeneration`` instance.

        Raises:
            ValueError: If ``"type"`` is missing or is not a valid
                ``GenerationType`` value.
        """
        raw_type = generation_dict.get("type")
        try:
            generation_type = GenerationType(raw_type)
        except ValueError as err:
            # The enum constructor rejects unknown/missing values itself, so
            # translate its error into the message this method intends to give.
            raise ValueError(f"Unknown generation type: {raw_type}") from err

        if generation_type is GenerationType.CHAT:
            return ChatGeneration.from_dict(generation_dict)
        if generation_type is GenerationType.COMPLETION:
            return CompletionGeneration.from_dict(generation_dict)
        # Only reachable if a GenerationType member is added without a branch above.
        raise ValueError(f"Unknown generation type: {generation_type}")

    def to_dict(self):
        """Serialize the common generation attributes to a camelCase-keyed dict.

        Returns:
            A dictionary of the shared attributes. The ``"id"`` key is only
            included when ``self.id`` is truthy.
        """
        _dict = {
            "promptId": self.prompt_id,
            "provider": self.provider,
            "model": self.model,
            "error": self.error,
            "settings": self.settings,
            "variables": self.variables,
            "tags": self.tags,
            "metadata": self.metadata,
            "tools": self.tools,
            "tokenCount": self.token_count,
            "inputTokenCount": self.input_token_count,
            "outputTokenCount": self.output_token_count,
            "ttFirstToken": self.tt_first_token,
            "tokenThroughputInSeconds": self.token_throughput_in_s,
            "duration": self.duration,
        }
        if self.id:
            _dict["id"] = self.id
        return _dict
@dataclass(repr=False)
class CompletionGeneration(BaseGeneration, Utils):
    """
    Represents a completion generation with a prompt and its corresponding completion.

    Attributes:
        prompt (Optional[str]): The prompt text for the generation.
        completion (Optional[str]): The generated completion text.
        type (GenerationType): The type of generation, which is set to GenerationType.COMPLETION.
    """

    prompt: Optional[str] = None
    completion: Optional[str] = None
    # Unannotated on purpose: a shared class attribute, not a constructor field.
    type = GenerationType.COMPLETION

    def to_dict(self):
        """Serialize to a dict: the base attributes plus prompt, completion and type.

        Returns:
            The dictionary from the parent ``to_dict`` extended with the
            ``"prompt"``, ``"completion"`` and ``"type"`` keys.
        """
        _dict = super().to_dict()
        _dict.update(
            {
                "prompt": self.prompt,
                "completion": self.completion,
                "type": self.type.value,
            }
        )
        return _dict

    @classmethod
    def from_dict(cls, generation_dict: Dict):
        """Create a completion generation from a camelCase-keyed dictionary.

        Missing keys are passed through as ``None``, except ``"metadata"``,
        which falls back to an empty dict.

        Args:
            generation_dict: Dictionary using the same keys ``to_dict`` emits.

        Returns:
            A new instance populated from ``generation_dict``.
        """
        return cls(
            id=generation_dict.get("id"),
            prompt_id=generation_dict.get("promptId"),
            error=generation_dict.get("error"),
            tags=generation_dict.get("tags"),
            provider=generation_dict.get("provider"),
            model=generation_dict.get("model"),
            variables=generation_dict.get("variables"),
            tools=generation_dict.get("tools"),
            settings=generation_dict.get("settings"),
            # to_dict() emits "metadata"; read it back so a round trip keeps it.
            metadata=generation_dict.get("metadata") or {},
            token_count=generation_dict.get("tokenCount"),
            input_token_count=generation_dict.get("inputTokenCount"),
            output_token_count=generation_dict.get("outputTokenCount"),
            tt_first_token=generation_dict.get("ttFirstToken"),
            token_throughput_in_s=generation_dict.get("tokenThroughputInSeconds"),
            duration=generation_dict.get("duration"),
            prompt=generation_dict.get("prompt"),
            completion=generation_dict.get("completion"),
        )
@dataclass(repr=False)
class ChatGeneration(BaseGeneration, Utils):
    """
    Represents a chat generation with a list of messages and a message completion.

    Attributes:
        messages (Optional[List[GenerationMessage]]): The list of messages in the chat generation.
        message_completion (Optional[GenerationMessage]): The completion message of the chat generation.
        type (GenerationType): The type of generation, which is set to GenerationType.CHAT.
    """

    # Unannotated on purpose: a shared class attribute, not a constructor field.
    type = GenerationType.CHAT
    messages: Optional[List[GenerationMessage]] = Field(default_factory=list)
    message_completion: Optional[GenerationMessage] = None

    def to_dict(self):
        """Serialize to a dict: the base attributes plus messages, completion and type.

        Returns:
            The dictionary from the parent ``to_dict`` extended with the
            ``"messages"``, ``"messageCompletion"`` and ``"type"`` keys.
        """
        _dict = super().to_dict()
        _dict.update(
            {
                "messages": self.messages,
                "messageCompletion": self.message_completion,
                "type": self.type.value,
            }
        )
        return _dict

    @classmethod
    def from_dict(cls, generation_dict: Dict):
        """Create a chat generation from a camelCase-keyed dictionary.

        Missing keys are passed through as ``None``, except ``"messages"``,
        which defaults to an empty list, and ``"metadata"``, which falls back
        to an empty dict.

        Args:
            generation_dict: Dictionary using the same keys ``to_dict`` emits.

        Returns:
            A new instance populated from ``generation_dict``.
        """
        return cls(
            id=generation_dict.get("id"),
            prompt_id=generation_dict.get("promptId"),
            error=generation_dict.get("error"),
            tags=generation_dict.get("tags"),
            provider=generation_dict.get("provider"),
            model=generation_dict.get("model"),
            variables=generation_dict.get("variables"),
            tools=generation_dict.get("tools"),
            settings=generation_dict.get("settings"),
            # to_dict() emits "metadata"; read it back so a round trip keeps it.
            metadata=generation_dict.get("metadata") or {},
            token_count=generation_dict.get("tokenCount"),
            input_token_count=generation_dict.get("inputTokenCount"),
            output_token_count=generation_dict.get("outputTokenCount"),
            tt_first_token=generation_dict.get("ttFirstToken"),
            token_throughput_in_s=generation_dict.get("tokenThroughputInSeconds"),
            duration=generation_dict.get("duration"),
            messages=generation_dict.get("messages", []),
            message_completion=generation_dict.get("messageCompletion"),
        )