ai-station/.venv/lib/python3.12/site-packages/chainlit/step.py

492 lines
14 KiB
Python
Raw Permalink Normal View History

import asyncio
import inspect
import json
import time
import uuid
from copy import deepcopy
from functools import wraps
from typing import Callable, Dict, List, Optional, TypedDict, Union
from literalai import BaseGeneration
from literalai.observability.step import StepType, TrueStepType
from chainlit.config import config
from chainlit.context import CL_RUN_NAMES, context, local_steps
from chainlit.data import get_data_layer
from chainlit.element import Element
from chainlit.logger import logger
from chainlit.types import FeedbackDict
from chainlit.utils import utc_now
def check_add_step_in_cot(step: "Step"):
    """Return whether *step*'s full content may be shown in the chain of thought.

    When ``config.ui.cot`` is ``"hidden"``, only user/assistant messages and
    Chainlit's own run steps (type ``"run"`` with a name in ``CL_RUN_NAMES``)
    qualify. With any other ``cot`` setting, every step qualifies.
    Callers send a content-free stub instead of the full step when this
    returns False.
    """
    is_message = step.type in ("user_message", "assistant_message")
    is_cl_run = step.name in CL_RUN_NAMES and step.type == "run"
    return config.ui.cot != "hidden" or is_message or is_cl_run
def stub_step(step: "Step") -> "StepDict":
    """Return a minimal StepDict describing *step* with its content blanked.

    Only the identifying fields (type, name, id, parent and thread ids) are
    copied. ``input`` and ``output`` are always empty strings, so the UI learns
    that the step exists without receiving its content.
    """
    identity: "StepDict" = {
        "type": step.type,
        "name": step.name,
        "id": step.id,
        "parentId": step.parent_id,
        "threadId": step.thread_id,
    }
    return {**identity, "input": "", "output": ""}
# Serialized form of a Step as produced by Step.to_dict() and stub_step().
# total=False: every key is optional (stub_step only fills a subset).
StepDict = TypedDict(
    "StepDict",
    {
        # Identity / placement
        "name": str,
        "type": StepType,
        "id": str,
        "threadId": str,
        "parentId": Optional[str],
        "command": Optional[str],
        "modes": Optional[Dict[str, str]],
        # Streaming / status flags
        "streaming": bool,
        "waitForAnswer": Optional[bool],
        "isError": Optional[bool],
        # Free-form annotations
        "metadata": Dict,
        "tags": Optional[List[str]],
        # Content (already rendered to strings)
        "input": str,
        "output": str,
        # Timestamps
        "createdAt": Optional[str],
        "start": Optional[str],
        "end": Optional[str],
        # Presentation and extras
        "generation": Optional[Dict],
        "showInput": Optional[Union[bool, str]],
        "defaultOpen": Optional[bool],
        "language": Optional[str],
        "feedback": Optional[FeedbackDict],
    },
    total=False,
)
def flatten_args_kwargs(func, args, kwargs):
    """Map each parameter of *func* to a deep copy of its value for one call.

    *args* and *kwargs* are bound to ``func``'s signature. Parameters the
    caller did not supply are filled with their defaults; ``*args`` and
    ``**kwargs`` parameters map to ``()`` / ``{}`` when empty. Values are
    deep-copied so later mutation by ``func`` does not alter the snapshot.

    Raises:
        TypeError: if *args*/*kwargs* do not match ``func``'s signature.
    """
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    snapshot = {}
    for param_name, value in bound.arguments.items():
        snapshot[param_name] = deepcopy(value)
    return snapshot
def _record_step_input(current_step: "Step", func: Callable, args, kwargs) -> None:
    """Best-effort: store *func*'s bound call arguments as the step input.

    Failures (e.g. arguments that do not bind or cannot be deep-copied) are
    logged and otherwise ignored, so tracing never breaks the wrapped call.
    """
    try:
        current_step.input = flatten_args_kwargs(func, args, kwargs)
    except Exception as e:
        logger.exception(e)


def _record_step_output(current_step: "Step", result) -> None:
    """Store *result* as the step output unless the body already set one.

    Only truthy results are recorded. If evaluating ``bool(result)`` or the
    output conversion raises, the step is flagged as an error and the
    exception text becomes its output.
    """
    try:
        if result and not current_step.output:
            current_step.output = result
    except Exception as e:
        current_step.is_error = True
        current_step.output = str(e)


def step(
    original_function: Optional[Callable] = None,
    *,
    name: Optional[str] = "",
    type: TrueStepType = "undefined",
    id: Optional[str] = None,
    parent_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict] = None,
    language: Optional[str] = None,
    show_input: Union[bool, str] = "json",
    default_open: bool = False,
):
    """Step decorator for async and sync functions.

    Usable bare (``@step``) or with options (``@step(type="tool")``). Every
    call of the decorated function runs inside a fresh ``Step``. The call's
    arguments are recorded as the step input and a truthy return value as
    its output, unless the function already set one.

    Args:
        original_function: The function, when used as ``@step`` without
            parentheses.
        name: Step name. Defaults to the decorated function's ``__name__``.
        type, id, parent_id, tags, metadata, language, show_input,
        default_open: Forwarded unchanged to ``Step``.

    Returns:
        The wrapped function, or a decorator when ``original_function`` is
        not given.
    """

    def wrapper(func: Callable):
        # Resolve the name per decorated function instead of rebinding the
        # enclosing ``name``. One ``step(...)`` decorator object may be applied
        # to several functions, and each must get its own ``__name__``.
        step_name = name or func.__name__

        def new_step() -> "Step":
            # A new Step is built for every invocation of the wrapped function.
            return Step(
                type=type,
                name=step_name,
                id=id,
                parent_id=parent_id,
                tags=tags,
                language=language,
                show_input=show_input,
                default_open=default_open,
                metadata=metadata,
            )

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with new_step() as current_step:
                    _record_step_input(current_step, func, args, kwargs)
                    result = await func(*args, **kwargs)
                    _record_step_output(current_step, result)
                    return result

            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with new_step() as current_step:
                _record_step_input(current_step, func, args, kwargs)
                result = func(*args, **kwargs)
                _record_step_output(current_step, result)
                return result

        return sync_wrapper

    if not original_function:
        return wrapper
    return wrapper(original_function)
class Step:
    """A unit of work shown in the Chainlit UI and optionally persisted.

    A Step can be used in several ways:

    * as an async (``async with``) or sync (``with``) context manager: it is
      sent on enter and updated on exit;
    * streamed incrementally with ``stream_token``;
    * applied as a decorator via ``__call__``.

    UI traffic goes through ``context.emitter``. Persistence goes through the
    data layer returned by ``get_data_layer()``, when one is configured.
    """

    # Instance attributes (assigned in __init__)
    name: str
    type: TrueStepType
    id: str
    parent_id: Optional[str]
    streaming: bool
    persisted: bool
    show_input: Union[bool, str]
    is_error: Optional[bool]
    metadata: Dict
    tags: Optional[List[str]]
    thread_id: str
    created_at: Union[str, None]
    start: Union[str, None]
    end: Union[str, None]
    generation: Optional[BaseGeneration]
    language: Optional[str]
    default_open: Optional[bool]
    elements: Optional[List[Element]]
    fail_on_persist_error: bool

    # Strong references to fire-and-forget tasks: background persistence
    # writes and the send()/update() scheduled by the sync context manager.
    # The event loop keeps only weak references to tasks, so an unreferenced
    # task can be garbage-collected before it finishes. Shared by all
    # instances on purpose; each task removes itself when done.
    _background_tasks = set()

    def __init__(
        self,
        name: Optional[str] = config.ui.name,
        type: TrueStepType = "undefined",
        id: Optional[str] = None,
        parent_id: Optional[str] = None,
        elements: Optional[List[Element]] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
        language: Optional[str] = None,
        default_open: Optional[bool] = False,
        show_input: Union[bool, str] = "json",
        thread_id: Optional[str] = None,
    ):
        """Create a step. Nothing is sent until ``send()`` or the context manager.

        The ``name`` default is read from ``config.ui.name`` once, when this
        class is defined. ``thread_id`` falls back to the current session's
        thread, and ``id`` to a fresh UUID4 string.
        """
        # NOTE(review): blocking 1 ms pause, presumably so that steps created
        # back to back get distinct created_at timestamps. Confirm it is
        # still needed: it blocks the event loop when called from async code.
        time.sleep(0.001)
        self._input = ""
        self._output = ""
        self.thread_id = thread_id or context.session.thread_id
        self.name = name or ""
        self.type = type
        self.id = id or str(uuid.uuid4())
        self.metadata = metadata or {}
        self.tags = tags
        self.is_error = False
        self.show_input = show_input
        self.parent_id = parent_id
        self.language = language
        self.default_open = default_open
        self.generation = None
        self.elements = elements or []
        self.created_at = utc_now()
        self.start = None
        self.end = None
        self.streaming = False
        self.persisted = False
        self.fail_on_persist_error = False

    def _clean_content(self, content):
        """
        Recursively replace bytes values with the placeholder "STRIPPED_BINARY_DATA".

        Recurses into dict values, lists and tuples (dict keys are untouched).
        Any other value is returned unchanged.
        """

        def handle_bytes(item):
            if isinstance(item, bytes):
                return "STRIPPED_BINARY_DATA"
            elif isinstance(item, dict):
                return {k: handle_bytes(v) for k, v in item.items()}
            elif isinstance(item, list):
                return [handle_bytes(i) for i in item]
            elif isinstance(item, tuple):
                return tuple(handle_bytes(i) for i in item)
            return item

        return handle_bytes(content)

    def _process_content(self, content, set_language=False):
        """Render *content* to the string stored as the step input/output.

        Rendering rules:

        * ``None`` becomes ``""``.
        * A dict, list or tuple is JSON-encoded (indent 4, non-ASCII kept),
          after stripping bytes. If JSON encoding raises TypeError, the value
          is ``str()``-ed instead.
        * A str is used as-is.
        * Anything else is ``str()``-ed.

        Wherever ``str()`` is used, literal ``\\n`` sequences are turned into
        real newlines. When *set_language* is true, ``self.language`` becomes
        ``"json"`` or ``"text"`` to match; it is left alone for plain
        strings.
        """
        if content is None:
            return ""
        content = self._clean_content(content)
        if isinstance(content, (dict, list, tuple)):
            try:
                processed_content = json.dumps(content, indent=4, ensure_ascii=False)
                if set_language:
                    self.language = "json"
            except TypeError:
                processed_content = str(content).replace("\\n", "\n")
                if set_language:
                    self.language = "text"
        elif isinstance(content, str):
            processed_content = content
        else:
            processed_content = str(content).replace("\\n", "\n")
            if set_language:
                self.language = "text"
        return processed_content

    @property
    def input(self):
        """The rendered input string (see ``_process_content``)."""
        return self._input

    @input.setter
    def input(self, content: Union[Dict, str]):
        self._input = self._process_content(content, set_language=False)

    @property
    def output(self):
        """The rendered output string; setting it may also update ``language``."""
        return self._output

    @output.setter
    def output(self, content: Union[Dict, str]):
        self._output = self._process_content(content, set_language=True)

    def to_dict(self) -> StepDict:
        """Serialize this step to the StepDict sent to the UI and data layer."""
        _dict: StepDict = {
            "name": self.name,
            "type": self.type,
            "id": self.id,
            "threadId": self.thread_id,
            "parentId": self.parent_id,
            "streaming": self.streaming,
            "metadata": self.metadata,
            "tags": self.tags,
            "input": self.input,
            "isError": self.is_error,
            "output": self.output,
            "createdAt": self.created_at,
            "start": self.start,
            "end": self.end,
            "language": self.language,
            "defaultOpen": self.default_open,
            "showInput": self.show_input,
            "generation": self.generation.to_dict() if self.generation else None,
        }
        return _dict

    def _spawn(self, coro) -> "asyncio.Task":
        """Schedule *coro* on the running loop, holding a reference until it is done."""
        task = asyncio.create_task(coro)
        Step._background_tasks.add(task)
        task.add_done_callback(Step._background_tasks.discard)
        return task

    async def _persist(self, make_call: Callable, action: str) -> bool:
        """Run one data-layer write for this step.

        Args:
            make_call: Zero-argument callable that returns the data-layer
                coroutine. It is invoked inside the error handling, so
                synchronous failures are covered too.
            action: Used in the log message ``"Failed to persist {action}: ..."``.

        Returns:
            True if the write was awaited successfully or scheduled, False if
            it failed and the failure was logged.

        Raises:
            Exception: any error from the write, when ``fail_on_persist_error``
                is set. In that mode the write is awaited (not run in the
                background) so that its errors can actually reach the caller.
        """
        try:
            coro = make_call()
            if self.fail_on_persist_error:
                await coro
            else:
                task = self._spawn(coro)

                def _report(done: "asyncio.Task") -> None:
                    # Retrieve and log errors raised inside the background
                    # write; they cannot reach the surrounding try/except.
                    if done.cancelled():
                        return
                    exc = done.exception()
                    if exc is not None:
                        logger.error(f"Failed to persist {action}: {exc!s}")

                task.add_done_callback(_report)
            return True
        except Exception as e:
            if self.fail_on_persist_error:
                raise
            logger.error(f"Failed to persist {action}: {e!s}")
            return False

    async def update(self):
        """
        Update a step already sent to the UI.

        Steps performed, in order:

        1. End any in-progress stream.
        2. Persist through the data layer, if any.
        3. Send attached elements.
        4. Emit the full step, or a content-free stub when
           ``check_add_step_in_cot`` rejects it.

        Returns True.
        """
        self.streaming = False
        step_dict = self.to_dict()
        data_layer = get_data_layer()
        if data_layer:
            await self._persist(
                lambda: data_layer.update_step(step_dict.copy()), "step update"
            )
        await asyncio.gather(*(el.send(for_id=self.id) for el in self.elements))
        if not check_add_step_in_cot(self):
            await context.emitter.update_step(stub_step(self))
        else:
            await context.emitter.update_step(step_dict)
        return True

    async def remove(self):
        """
        Remove a step already sent to the UI.

        Deletes it from the data layer (if any), then tells the UI. Returns True.
        """
        step_dict = self.to_dict()
        data_layer = get_data_layer()
        if data_layer:
            await self._persist(lambda: data_layer.delete_step(self.id), "step deletion")
        await context.emitter.delete_step(step_dict)
        return True

    async def send(self):
        """Send this step to the UI and persist it; no-op once persisted.

        Steps performed, in order:

        1. Apply ``config.code.author_rename`` to the name, if configured.
        2. End any in-progress stream.
        3. Create the step in the data layer (if any). ``persisted`` is set
           only when the write succeeded or was scheduled.
        4. Send attached elements.
        5. Emit the full step, or a stub when ``check_add_step_in_cot``
           rejects it.

        Returns self.
        """
        if self.persisted:
            return self
        if config.code.author_rename:
            self.name = await config.code.author_rename(self.name)
        self.streaming = False
        step_dict = self.to_dict()
        data_layer = get_data_layer()
        if data_layer and await self._persist(
            lambda: data_layer.create_step(step_dict.copy()), "step creation"
        ):
            self.persisted = True
        await asyncio.gather(*(el.send(for_id=self.id) for el in self.elements))
        if not check_add_step_in_cot(self):
            await context.emitter.send_step(stub_step(self))
        else:
            await context.emitter.send_step(step_dict)
        return self

    async def stream_token(self, token: str, is_sequence=False, is_input=False):
        """
        Sends a token to the UI.
        Once all tokens have been streamed, call .send() to end the stream and persist the step if persistence is enabled.

        Args:
            token: Text to stream. Empty tokens are ignored.
            is_sequence: Replace the content with *token* instead of appending it.
            is_input: Target the step input instead of its output.

        The first streamed token emits ``stream_start`` with the full step;
        later tokens emit ``send_token``. Steps rejected by
        ``check_add_step_in_cot`` instead re-send a stub on every token.
        """
        if not token:
            return
        if is_sequence:
            if is_input:
                self.input = token
            else:
                self.output = token
        else:
            if is_input:
                self.input += token
            else:
                self.output += token
        assert self.id
        if not check_add_step_in_cot(self):
            await context.emitter.send_step(stub_step(self))
            return
        if not self.streaming:
            self.streaming = True
            step_dict = self.to_dict()
            await context.emitter.stream_start(step_dict)
        else:
            await context.emitter.send_token(
                id=self.id, token=token, is_sequence=is_sequence, is_input=is_input
            )

    # Handle parameter less decorator
    def __call__(self, func):
        """Decorate *func* with the ``step`` decorator, using this Step's settings.

        Each call of *func* gets a new Step built from this instance's type,
        name, id, parent_id, tags, metadata, language, show_input and
        default_open. ``thread_id`` is not forwarded because ``step()`` does
        not accept it; the Steps it creates fall back to
        ``context.session.thread_id``.
        """
        # NOTE(review): every invocation reuses self.id. Confirm that sharing
        # one id across calls is intended.
        return step(
            original_function=func,
            type=self.type,
            name=self.name,
            id=self.id,
            parent_id=self.parent_id,
            tags=self.tags,
            metadata=self.metadata,
            language=self.language,
            show_input=self.show_input,
            default_open=self.default_open,
        )

    def _enter_scope(self) -> None:
        """Set ``start``, adopt the innermost active step as parent, and push self."""
        self.start = utc_now()
        previous_steps = local_steps.get() or []
        if not self.parent_id and previous_steps:
            self.parent_id = previous_steps[-1].id
        local_steps.set(previous_steps + [self])

    def _leave_scope(self, exc_type, exc_val) -> None:
        """Set ``end``, record an exception as error output, and pop self."""
        self.end = utc_now()
        if exc_type:
            self.output = str(exc_val)
            self.is_error = True
        current_steps = local_steps.get()
        if current_steps and self in current_steps:
            # Build a new list rather than mutating in place, mirroring
            # _enter_scope. The current list object may also be referenced
            # by contexts copied from this one (e.g. tasks created inside the
            # step), and those should not see it change.
            remaining = list(current_steps)
            remaining.remove(self)
            local_steps.set(remaining)

    # Handle Context Manager Protocol
    async def __aenter__(self):
        self._enter_scope()
        await self.send()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Exceptions are recorded on the step but not suppressed.
        self._leave_scope(exc_type, exc_val)
        await self.update()

    def __enter__(self):
        # Requires a running event loop: send() is scheduled, not awaited.
        self._enter_scope()
        self._spawn(self.send())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Exceptions are recorded on the step but not suppressed.
        self._leave_scope(exc_type, exc_val)
        self._spawn(self.update())