463 lines
16 KiB
Python
463 lines
16 KiB
Python
import asyncio
|
|
import uuid
|
|
from typing import Any, Dict, List, Literal, Optional, Union, cast, get_args
|
|
|
|
from socketio.exceptions import TimeoutError
|
|
|
|
from chainlit.chat_context import chat_context
|
|
from chainlit.config import config
|
|
from chainlit.data import get_data_layer
|
|
from chainlit.element import Element, ElementDict, File
|
|
from chainlit.logger import logger
|
|
from chainlit.message import Message
|
|
from chainlit.mode import Mode
|
|
from chainlit.session import BaseSession, WebsocketSession
|
|
from chainlit.step import StepDict
|
|
from chainlit.types import (
|
|
AskActionResponse,
|
|
AskElementResponse,
|
|
AskFileSpec,
|
|
AskSpec,
|
|
CommandDict,
|
|
FileDict,
|
|
FileReference,
|
|
MessagePayload,
|
|
OutputAudioChunk,
|
|
ThreadDict,
|
|
ToastType,
|
|
)
|
|
from chainlit.user import PersistedUser
|
|
from chainlit.utils import utc_now
|
|
|
|
|
|
class BaseChainlitEmitter:
    """
    Chainlit Emitter Stub class. This class is used for testing purposes.
    It stubs the ChainlitEmitter class and does nothing on function calls.

    Each stub follows the calling convention of the corresponding
    ChainlitEmitter method: where the concrete method returns an awaitable
    (callers write ``await emitter.method(...)``) the stub is ``async``; where
    the concrete method is a plain synchronous call, so is the stub.
    """

    session: BaseSession
    enabled: bool = True

    def __init__(self, session: BaseSession) -> None:
        """Initialize with the user session."""
        self.session = session

    async def emit(self, event: str, data: Any):
        """Stub method to get the 'emit' property from the session."""
        pass

    async def emit_call(self, *args: Any, **kwargs: Any):
        """Stub method to get the 'emit_call' property from the session.

        Accepts and ignores any arguments so that calls shaped like the
        concrete emitter's ``emit_call(event, data, timeout)`` do not raise
        ``TypeError``. Always returns None.
        """
        pass

    async def resume_thread(self, thread_dict: ThreadDict):
        """Stub method to resume a thread."""
        pass

    async def send_resume_thread_error(self, error: str):
        """Stub method to send a resume thread error."""
        pass

    async def send_element(self, element_dict: ElementDict):
        """Stub method to send an element to the UI."""
        pass

    async def update_audio_connection(self, state: Literal["on", "off"]):
        """Audio connection signaling."""
        pass

    async def send_audio_chunk(self, chunk: OutputAudioChunk):
        """Stub method to send an audio chunk to the UI."""
        pass

    async def send_audio_interrupt(self):
        """Stub method to interrupt the current audio response."""
        pass

    async def send_step(self, step_dict: StepDict):
        """Stub method to send a message to the UI."""
        pass

    async def update_step(self, step_dict: StepDict):
        """Stub method to update a message in the UI."""
        pass

    async def delete_step(self, step_dict: StepDict):
        """Stub method to delete a message in the UI."""
        pass

    async def send_timeout(self, event: Literal["ask_timeout", "call_fn_timeout"]):
        """Stub method to send a timeout to the UI.

        ``async`` because the concrete emitter returns an awaitable here and
        callers ``await`` it.
        """
        pass

    async def clear(self, event: Literal["clear_ask", "clear_call_fn"]):
        """Stub method to clear a pending ask / call_fn prompt in the UI.

        ``async`` because the concrete emitter returns an awaitable here and
        callers ``await`` it.
        """
        pass

    async def init_thread(self, interaction: str):
        """Stub method to initialise the thread on the first interaction."""
        pass

    async def process_message(self, payload: MessagePayload) -> Message:
        """Stub method to process user message."""
        return Message(content="")

    async def send_ask_user(
        self, step_dict: StepDict, spec: AskSpec, raise_on_timeout=False
    ) -> Optional[
        Union["StepDict", "AskActionResponse", "AskElementResponse", List["FileDict"]]
    ]:
        """Stub method to send a prompt to the UI and wait for a response."""
        pass

    async def send_call_fn(
        self, name: str, args: Dict[str, Any], timeout=300, raise_on_timeout=False
    ) -> Optional[Dict[str, Any]]:
        """Stub method to send a call function event to the copilot and wait for a response."""
        pass

    async def update_token_count(self, count: int):
        """Stub method to update the token count for the UI."""
        pass

    async def task_start(self):
        """Stub method to send a task start signal to the UI."""
        pass

    async def task_end(self):
        """Stub method to send a task end signal to the UI."""
        pass

    async def stream_start(self, step_dict: StepDict):
        """Stub method to send a stream start signal to the UI."""
        pass

    async def send_token(self, id: str, token: str, is_sequence=False, is_input=False):
        """Stub method to send a message token to the UI."""
        pass

    def set_chat_settings(self, settings: dict):
        """Stub method to set chat settings.

        Synchronous, like the concrete emitter's version (a plain attribute
        assignment), so calling it without ``await`` leaves no stray coroutine.
        """
        pass

    async def set_commands(self, commands: List[CommandDict]):
        """Stub method to send the available commands to the UI."""
        pass

    async def set_modes(self, modes: List[Mode]):
        """Stub method to send the available modes to the UI."""
        pass

    async def send_window_message(self, data: Any):
        """Stub method to send custom data to the host window."""
        pass

    async def send_toast(self, message: str, type: Optional[ToastType] = "info"):
        """Stub method to send a toast message to the UI.

        ``async`` because the concrete emitter returns an awaitable here and
        callers ``await`` it.
        """
        pass
|
|
|
|
|
|
class ChainlitEmitter(BaseChainlitEmitter):
    """
    Chainlit Emitter class. The Emitter is not directly exposed to the developer.
    Instead, the developer interacts with the Emitter through the methods and classes exposed in the __init__ file.

    Most methods forward an event through the session's ``emit`` callable and
    return its result, so callers are expected to ``await`` them.
    """

    session: WebsocketSession

    def __init__(self, session: WebsocketSession) -> None:
        """Initialize with the user session."""
        self.session = session
        # Strong references to fire-and-forget tasks started by this emitter.
        # asyncio only keeps weak references to tasks, so a task nobody holds
        # on to may be garbage-collected before it completes.
        self._background_tasks = set()

    def _spawn(self, coro: Any) -> "asyncio.Task[Any]":
        """Schedule *coro* as a background task and keep it alive until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)
        return task

    def _on_background_task_done(self, task: "asyncio.Task[Any]") -> None:
        """Release a finished background task and log its exception, if any."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error("Background task failed: %r", exc, exc_info=exc)

    def _get_session_property(self, property_name: str, raise_error=True):
        """Return ``self.session.<property_name>``.

        If there is no session or it lacks the property, raise ValueError when
        ``raise_error`` is true, otherwise return None.
        """
        if not hasattr(self, "session") or not hasattr(self.session, property_name):
            if raise_error:
                raise ValueError(f"Session does not have property '{property_name}'")
            else:
                return None
        return getattr(self.session, property_name)

    @property
    def emit(self):
        """Get the 'emit' property from the session."""
        return self._get_session_property("emit")

    @property
    def emit_call(self):
        """Get the 'emit_call' property from the session."""
        return self._get_session_property("emit_call")

    def resume_thread(self, thread_dict: ThreadDict):
        """Send a thread to the UI to resume it"""
        return self.emit("resume_thread", thread_dict)

    def send_resume_thread_error(self, error: str):
        """Send a thread resume error to the UI"""
        return self.emit("resume_thread_error", error)

    async def update_audio_connection(self, state: Literal["on", "off"]):
        """Audio connection signaling."""
        await self.emit("audio_connection", state)

    async def send_audio_chunk(self, chunk: OutputAudioChunk):
        """Send an audio chunk to the UI."""
        await self.emit("audio_chunk", chunk)

    async def send_audio_interrupt(self):
        """Method to interrupt the current audio response."""
        await self.emit("audio_interrupt", {})

    async def send_element(self, element_dict: ElementDict):
        """Send an element to the UI."""
        await self.emit("element", element_dict)

    def send_step(self, step_dict: StepDict):
        """Send a message to the UI."""
        return self.emit("new_message", step_dict)

    def update_step(self, step_dict: StepDict):
        """Update a message in the UI."""
        return self.emit("update_message", step_dict)

    def delete_step(self, step_dict: StepDict):
        """Delete a message in the UI."""
        return self.emit("delete_message", step_dict)

    def send_timeout(self, event: Literal["ask_timeout", "call_fn_timeout"]):
        """Emit a timeout event (with an empty payload) to the UI."""
        return self.emit(event, {})

    def clear(self, event: Literal["clear_ask", "clear_call_fn"]):
        """Emit a clear event (with an empty payload) to the UI."""
        return self.emit(event, {})

    async def flush_thread_queues(self, interaction: str):
        """Update the thread in the data layer, then flush the session's method queue.

        When a data layer is configured, the thread is updated with
        *interaction* as its name, the user's id when the session user is a
        PersistedUser, and the chat profile as a tag when auto-tagging is
        enabled and a profile is set. Update failures are logged and do not
        stop the method-queue flush, which is scheduled in the background.
        """
        if data_layer := get_data_layer():
            if isinstance(self.session.user, PersistedUser):
                user_id = self.session.user.id
            else:
                user_id = None
            try:
                should_tag_thread = (
                    self.session.chat_profile and config.features.auto_tag_thread
                )
                tags = [self.session.chat_profile] if should_tag_thread else None
                await data_layer.update_thread(
                    thread_id=self.session.thread_id,
                    name=interaction,
                    user_id=user_id,
                    tags=tags,
                )
            except Exception as e:
                logger.error(f"Error updating thread: {e}")
        self._spawn(self.session.flush_method_queue())

    async def init_thread(self, interaction: str):
        """Flush the thread queues, then notify the UI of the first interaction."""
        await self.flush_thread_queues(interaction)
        await self.emit(
            "first_interaction",
            {
                "interaction": interaction,
                "thread_id": self.session.thread_id,
            },
        )

    async def process_message(self, payload: MessagePayload):
        """Build a Message from an incoming user message payload.

        The message is stamped with the current time and added to the chat
        context, and ``message._create()`` is scheduled in the background. On
        the session's first interaction, ``init_thread`` is also scheduled in
        the background with the message content. Referenced files found in
        ``session.files`` are attached as inline elements, which are sent to
        the UI in the background.

        Raises:
            ValueError: if the message id is not a valid version-4 UUID.
        """
        step_dict = payload["message"]
        file_refs = payload.get("fileReferences")

        # UUID generated by the frontend should use v4. Validate explicitly:
        # an `assert` would be stripped when Python runs with -O.
        if uuid.UUID(step_dict["id"]).version != 4:
            raise ValueError(f"Message id is not a UUID v4: {step_dict['id']}")

        message = Message.from_dict(step_dict)
        # Overwrite the created_at timestamp with the current time
        message.created_at = utc_now()
        chat_context.add(message)

        self._spawn(message._create())

        if not self.session.has_first_interaction:
            self.session.has_first_interaction = True
            self._spawn(self.init_thread(message.content))

        if file_refs:
            files = [
                self.session.files[file["id"]]
                for file in file_refs
                if file["id"] in self.session.files
            ]

            elements = [
                Element.from_dict(
                    {
                        "id": file["id"],
                        "name": file["name"],
                        "path": str(file["path"]),
                        "chainlitKey": file["id"],
                        "display": "inline",
                        "type": Element.infer_type_from_mime(file["type"]),
                        "mime": file["type"],
                    }
                )
                for file in files
            ]

            message.elements = elements

            async def send_elements():
                for element in message.elements:
                    await element.send(for_id=message.id)

            self._spawn(send_elements())

        return message

    async def send_ask_user(
        self, step_dict: StepDict, spec: AskSpec, raise_on_timeout=False
    ):
        """Send a prompt to the UI and wait for a response.

        For file prompts, the spec is registered in ``session.files_spec``
        under the step's parent id for the duration of the call. The response
        is interpreted according to ``spec.type`` ("text", "file", "action"
        or "element"). If the session has no first interaction yet, the thread
        is initialised with a textual summary of the answer.

        Returns:
            The processed response, or None when there was no answer or the
            prompt timed out and ``raise_on_timeout`` is false.

        Raises:
            TimeoutError: (socketio) when the prompt times out and
                ``raise_on_timeout`` is true.
        """
        parent_id = str(step_dict["parentId"])
        try:
            if spec.type == "file":
                self.session.files_spec[parent_id] = cast(AskFileSpec, spec)

            # Send the prompt to the UI
            user_res = await self.emit_call(
                "ask", {"msg": step_dict, "spec": spec.to_dict()}, spec.timeout
            )  # type: Optional[Union["StepDict", "AskActionResponse", "AskElementResponse", List["FileReference"]]]

            # End the task temporarily so that the User can answer the prompt
            await self.task_end()

            final_res: Optional[
                Union[StepDict, AskActionResponse, AskElementResponse, List[FileDict]]
            ] = None

            if user_res:
                interaction: Union[str, None] = None
                if spec.type == "text":
                    message_dict_res = cast(StepDict, user_res)
                    await self.process_message(
                        {"message": message_dict_res, "fileReferences": None}
                    )
                    interaction = message_dict_res["output"]
                    final_res = message_dict_res
                elif spec.type == "file":
                    file_refs = cast(List[FileReference], user_res)
                    files = [
                        self.session.files[file["id"]]
                        for file in file_refs
                        if file["id"] in self.session.files
                    ]
                    final_res = files
                    interaction = ",".join(file["name"] for file in files)
                    if get_data_layer():
                        coros = [
                            File(
                                id=file["id"],
                                name=file["name"],
                                path=str(file["path"]),
                                mime=file["type"],
                                chainlit_key=file["id"],
                                for_id=step_dict["id"],
                            )._create()
                            for file in files
                        ]
                        await asyncio.gather(*coros)
                elif spec.type == "action":
                    action_res = cast(AskActionResponse, user_res)
                    final_res = action_res
                    interaction = action_res["name"]
                elif spec.type == "element":
                    final_res = cast(AskElementResponse, user_res)
                    interaction = "custom_element"

                if not self.session.has_first_interaction and interaction:
                    self.session.has_first_interaction = True
                    await self.init_thread(interaction=interaction)

            await self.clear("clear_ask")
            return final_res
        except TimeoutError as e:
            await self.send_timeout("ask_timeout")

            if raise_on_timeout:
                raise e
            return None
        finally:
            if parent_id in self.session.files_spec:
                del self.session.files_spec[parent_id]
            await self.task_start()

    async def send_call_fn(
        self, name: str, args: Dict[str, Any], timeout=300, raise_on_timeout=False
    ) -> Optional[Dict[str, Any]]:
        """Send a call function event to the copilot and wait for a response.

        Returns the response, or None when the call times out and
        ``raise_on_timeout`` is false.

        Raises:
            TimeoutError: (socketio) when the call times out and
                ``raise_on_timeout`` is true.
        """
        try:
            call_fn_res = await self.emit_call(
                "call_fn", {"name": name, "args": args}, timeout
            )  # type: Dict

            await self.clear("clear_call_fn")
            return call_fn_res
        except TimeoutError as e:
            await self.send_timeout("call_fn_timeout")

            if raise_on_timeout:
                raise e
            return None

    def update_token_count(self, count: int):
        """Update the token count for the UI."""
        return self.emit("token_usage", count)

    def task_start(self):
        """Send a task start signal to the UI."""
        return self.emit("task_start", {})

    def task_end(self):
        """Send a task end signal to the UI."""
        return self.emit("task_end", {})

    def stream_start(self, step_dict: StepDict):
        """Send a stream start signal to the UI."""
        return self.emit(
            "stream_start",
            step_dict,
        )

    def send_token(self, id: str, token: str, is_sequence=False, is_input=False):
        """Send a message token to the UI."""
        return self.emit(
            "stream_token",
            {"id": id, "token": token, "isSequence": is_sequence, "isInput": is_input},
        )

    def set_chat_settings(self, settings: Dict[str, Any]):
        """Store the chat settings on the session (synchronous; emits nothing)."""
        self.session.chat_settings = settings

    def set_commands(self, commands: List[CommandDict]):
        """Send the available commands to the UI."""
        return self.emit(
            "set_commands",
            commands,
        )

    def set_modes(self, modes: List[Mode]):
        """Send the available modes to the UI."""
        return self.emit(
            "set_modes",
            [mode.to_dict() for mode in modes],
        )

    def send_window_message(self, data: Any):
        """Send custom data to the host window."""
        return self.emit("window_message", data)

    def send_toast(self, message: str, type: Optional[ToastType] = "info"):
        """Send a toast message to the UI.

        Raises:
            ValueError: if ``type`` is not one of the values of ToastType
                (this includes None).
        """
        # check that the type is valid using ToastType
        if type not in get_args(ToastType):
            raise ValueError(f"Invalid toast type: {type}")
        return self.emit("toast", {"message": message, "type": type})
|