import json
import mimetypes
import uuid
from enum import Enum
from io import BytesIO
from typing import (
    Any,
    ClassVar,
    Dict,
    List,
    Literal,
    Optional,
    Set,
    TypedDict,
    TypeVar,
    Union,
)

import filetype
from pydantic import Field
from pydantic.dataclasses import dataclass
from syncer import asyncio

from chainlit.context import context
from chainlit.data import get_data_layer
from chainlit.logger import logger

# Mime type assigned by `Element.send` to these element types when the
# element was created without an explicit mime.
mime_types = {
    "text": "text/plain",
    "tasklist": "application/json",
    "plotly": "application/json",
}

ElementType = Literal[
    "image",
    "text",
    "pdf",
    "tasklist",
    "audio",
    "video",
    "file",
    "plotly",
    "dataframe",
    "custom",
]
ElementDisplay = Literal["inline", "side", "page"]
ElementSize = Literal["small", "medium", "large"]

# Strong references to the fire-and-forget tasks scheduled by
# `Element._create`. The event loop only keeps weak references to tasks, so
# an otherwise unreferenced task may be garbage-collected before it finishes.
_background_tasks: Set["asyncio.Task[Any]"] = set()


def _on_create_element_done(task: "asyncio.Task[Any]") -> None:
    """Drop the reference to a finished data-layer task and log its failure."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error(f"Failed to create element: {exc!s}")


# Serialized (camelCase) form of an element, as built by `Element.to_dict`.
class ElementDict(TypedDict, total=False):
    id: str
    threadId: Optional[str]
    type: ElementType
    chainlitKey: Optional[str]
    path: Optional[str]
    url: Optional[str]
    objectKey: Optional[str]
    name: str
    display: ElementDisplay
    size: Optional[ElementSize]
    language: Optional[str]
    page: Optional[int]
    props: Optional[Dict]
    autoPlay: Optional[bool]
    playerConfig: Optional[dict]
    forId: Optional[str]
    mime: Optional[str]


@dataclass
class Element:
    """Base class for elements sent to the UI alongside messages.

    Subclasses set the ``type`` class variable; it is included in
    ``to_dict()`` and tells the frontend how to render the element.
    """

    # Thread id
    thread_id: str = Field(default_factory=lambda: context.session.thread_id)
    # The type of the element. This will be used to determine how to display the element in the UI.
    type: ClassVar[ElementType]
    # Name of the element, this will be used to reference the element in the UI.
    name: str = ""
    # The ID of the element. This is set automatically when the element is sent to the UI.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # The key of the element hosted on Chainlit.
    chainlit_key: Optional[str] = None
    # The URL of the element if already hosted somewhere else.
    url: Optional[str] = None
    # The S3 object key.
    object_key: Optional[str] = None
    # The local path of the element.
    path: Optional[str] = None
    # The byte content of the element.
    content: Optional[Union[bytes, str]] = None
    # Controls how the element should be displayed in the UI.
    # Choices are "inline" (default), "side", or "page".
    display: ElementDisplay = Field(default="inline")
    # Controls element size
    size: Optional[ElementSize] = None
    # The ID of the message this element is associated with.
    for_id: Optional[str] = None
    # The language, if relevant
    language: Optional[str] = None
    # Mime type, inferred based on content if not provided
    mime: Optional[str] = None

    def __post_init__(self) -> None:
        # Runtime bookkeeping (not dataclass fields): `_create` sets
        # `persisted`; `updatable` lets an already-persisted element be
        # persisted again on later sends.
        self.persisted = False
        self.updatable = False
        if not self.url and not self.path and not self.content:
            raise ValueError("Must provide url, path or content to instantiate element")

    def to_dict(self) -> ElementDict:
        """Serialize the element into the camelCase dict emitted to the UI."""
        _dict = ElementDict(
            {
                "id": self.id,
                "threadId": self.thread_id,
                "type": self.type,
                "url": self.url,
                "chainlitKey": self.chainlit_key,
                "name": self.name,
                "display": self.display,
                "objectKey": getattr(self, "object_key", None),
                "size": getattr(self, "size", None),
                "props": getattr(self, "props", None),
                "page": getattr(self, "page", None),
                "autoPlay": getattr(self, "auto_play", None),
                "playerConfig": getattr(self, "player_config", None),
                "language": getattr(self, "language", None),
                "forId": getattr(self, "for_id", None),
                "mime": getattr(self, "mime", None),
            }
        )
        return _dict

    @classmethod
    def from_dict(cls, e_dict: ElementDict):
        """
        Create an Element instance from a dictionary representation.

        Args:
            e_dict (ElementDict): Dictionary containing element data (e.g. the
                output of ``to_dict``). A ``content`` key is also honoured.

        Returns:
            Element: An instance of the appropriate Element subclass. Types
            without a dedicated branch below are rebuilt as ``File``.
        """
        element_id = e_dict.get("id", str(uuid.uuid4()))
        for_id = e_dict.get("forId")
        name = e_dict.get("name", "")
        element_type = e_dict.get("type", "file")
        path = str(e_dict.get("path")) if e_dict.get("path") else None
        url = str(e_dict.get("url")) if e_dict.get("url") else None
        content = str(e_dict.get("content")) if e_dict.get("content") else None
        object_key = e_dict.get("objectKey")
        chainlit_key = e_dict.get("chainlitKey")
        display = e_dict.get("display", "inline")
        mime_type = e_dict.get("mime", "")
        # `to_dict` emits `size: None` for unsized elements; fall back to the
        # subclasses' "medium" default rather than failing validation on None.
        size = e_dict.get("size") or "medium"

        # Common parameters for all element types
        common_params = {
            "id": element_id,
            "for_id": for_id,
            "name": name,
            "content": content,
            "path": path,
            "url": url,
            "object_key": object_key,
            "chainlit_key": chainlit_key,
            "display": display,
            "mime": mime_type,
            "language": e_dict.get("language"),
        }

        if element_type == "image":
            return Image(size=size, **common_params)  # type: ignore[arg-type]
        elif element_type == "audio":
            return Audio(auto_play=e_dict.get("autoPlay", False), **common_params)  # type: ignore[arg-type]
        elif element_type == "video":
            return Video(
                player_config=e_dict.get("playerConfig"),
                **common_params,  # type: ignore[arg-type]
            )
        elif element_type == "plotly":
            return Plotly(size=size, **common_params)  # type: ignore[arg-type]
        elif element_type == "custom":
            return CustomElement(props=e_dict.get("props", {}), **common_params)  # type: ignore[arg-type]
        else:
            # Default to File for any other type
            return File(**common_params)  # type: ignore[arg-type]

    @classmethod
    def infer_type_from_mime(cls, mime_type: str):
        """Infer the element type from a mime type.

        Useful to know which element to instantiate from a file upload."""
        if "image" in mime_type:
            return "image"
        elif mime_type == "application/pdf":
            return "pdf"
        elif "audio" in mime_type:
            return "audio"
        elif "video" in mime_type:
            return "video"
        else:
            return "file"

    async def _create(self, persist=True) -> bool:
        """Record the element in the data layer and upload its file if needed.

        Skipped for elements already persisted unless they are updatable.
        Always returns True.
        """
        if self.persisted and not self.updatable:
            return True

        if (data_layer := get_data_layer()) and persist:
            try:
                # Fire-and-forget: the data-layer write is not awaited. Keep a
                # strong reference until it finishes and log its failure,
                # since this try/except cannot see errors raised inside it.
                task = asyncio.create_task(data_layer.create_element(self))
                _background_tasks.add(task)
                task.add_done_callback(_on_create_element_done)
            except Exception as e:
                logger.error(f"Failed to create element: {e!s}")
        if not self.url and (not self.chainlit_key or self.updatable):
            file_dict = await context.session.persist_file(
                name=self.name,
                path=self.path,
                content=self.content,
                mime=self.mime or "",
            )
            self.chainlit_key = file_dict["id"]

        self.persisted = True
        return True

    async def remove(self):
        """Delete the element from the data layer (if any) and from the UI."""
        data_layer = get_data_layer()
        if data_layer:
            await data_layer.delete_element(self.id, self.thread_id)
        await context.emitter.emit("remove_element", {"id": self.id})

    async def send(self, for_id: str, persist=True):
        """Attach the element to ``for_id``, persist it and emit it to the UI.

        If no mime was set, it is taken from ``mime_types``, guessed from the
        path/bytes content, or guessed from the URL, in that order.

        Raises:
            ValueError: if after persisting there is neither a url nor a
                chainlit key to reference the element by.
        """
        self.for_id = for_id

        if not self.mime:
            if self.type in mime_types:
                self.mime = mime_types[self.type]
            elif self.path or isinstance(self.content, (bytes, bytearray)):
                file_type = filetype.guess(self.path or self.content)
                if file_type:
                    self.mime = file_type.mime
            elif self.url:
                self.mime = mimetypes.guess_type(self.url)[0]

        await self._create(persist=persist)

        if not self.url and not self.chainlit_key:
            raise ValueError("Must provide url or chainlit key to send element")

        await context.emitter.send_element(self.to_dict())


ElementBased = TypeVar("ElementBased", bound=Element)


@dataclass
class Image(Element):
    type: ClassVar[ElementType] = "image"

    size: ElementSize = "medium"


@dataclass
class Text(Element):
    """Useful to send a text (not a message) to the UI."""

    type: ClassVar[ElementType] = "text"

    language: Optional[str] = None


@dataclass
class Pdf(Element):
    """Useful to send a pdf to the UI."""

    mime: str = "application/pdf"
    page: Optional[int] = None
    type: ClassVar[ElementType] = "pdf"


@dataclass
class Pyplot(Element):
    """Useful to send a pyplot to the UI."""

    # We reuse the frontend image element to display the chart
    type: ClassVar[ElementType] = "image"

    size: ElementSize = "medium"
    # The type is set to Any because the figure is not serializable
    # and its actual type is checked in __post_init__.
    figure: Any = None

    def __post_init__(self) -> None:
        """Render ``figure`` to PNG bytes stored in ``content``.

        Raises:
            TypeError: if ``figure`` is not a ``matplotlib.figure.Figure``.
        """
        from matplotlib.figure import Figure

        if not isinstance(self.figure, Figure):
            raise TypeError("figure must be a matplotlib.figure.Figure")

        with BytesIO() as image:
            self.figure.savefig(
                image, dpi=200, bbox_inches="tight", backend="Agg", format="png"
            )
            self.content = image.getvalue()

        super().__post_init__()


class TaskStatus(Enum):
    READY = "ready"
    RUNNING = "running"
    FAILED = "failed"
    DONE = "done"


@dataclass
class Task:
    title: str
    status: TaskStatus = TaskStatus.READY
    forId: Optional[str] = None

    def __init__(
        self,
        title: str,
        status: TaskStatus = TaskStatus.READY,
        forId: Optional[str] = None,
    ):
        self.title = title
        self.status = status
        self.forId = forId


@dataclass
class TaskList(Element):
    """Element displaying a list of tasks together with an overall status."""

    type: ClassVar[ElementType] = "tasklist"
    tasks: List[Task] = Field(default_factory=list, exclude=True)
    status: str = "Ready"
    name: str = "tasklist"
    content: str = "dummy content to pass validation"

    def __post_init__(self) -> None:
        super().__post_init__()
        # The list is re-sent as tasks change, so allow re-persisting it.
        self.updatable = True

    async def add_task(self, task: Task):
        """Append a task locally; call ``update()`` to push it to the UI."""
        self.tasks.append(task)

    async def update(self):
        """Re-send the task list with its current tasks and status."""
        await self.send()

    async def send(self, for_id: str = "", persist: bool = True):
        """Serialize the tasks into ``content`` and send the list to the UI.

        Both parameters are optional: ``send()`` keeps its historical
        behaviour (empty ``for_id``), and callers using the base
        ``Element.send(for_id=..., persist=...)`` signature no longer get a
        TypeError.
        """
        await self.preprocess_content()
        await super().send(for_id=for_id, persist=persist)

    async def preprocess_content(self):
        """Store the status and tasks as JSON in ``content``."""
        # serialize enum
        tasks = [
            {"title": task.title, "status": task.status.value, "forId": task.forId}
            for task in self.tasks
        ]

        # store stringified json in content so that it's correctly stored in the database
        self.content = json.dumps(
            {
                "status": self.status,
                "tasks": tasks,
            },
            indent=4,
            ensure_ascii=False,
        )


@dataclass
class Audio(Element):
    type: ClassVar[ElementType] = "audio"
    auto_play: bool = False
@dataclass
class Video(Element):
    """Useful to send a video to the UI."""

    type: ClassVar[ElementType] = "video"

    size: ElementSize = "medium"
    # Override settings for each type of player in ReactPlayer
    # https://github.com/cookpete/react-player?tab=readme-ov-file#config-prop
    player_config: Optional[dict] = None


@dataclass
class File(Element):
    """Generic file element; also the fallback type of ``Element.from_dict``."""

    type: ClassVar[ElementType] = "file"


@dataclass
class Plotly(Element):
    """Useful to send a plotly to the UI."""

    type: ClassVar[ElementType] = "plotly"

    size: ElementSize = "medium"
    # The type is set to Any because the figure is not serializable
    # and its actual type is checked in __post_init__.
    figure: Any = None
    content: str = ""

    def __post_init__(self) -> None:
        """Serialize ``figure`` to JSON and store it in ``content``.

        If no figure is given but ``content`` already holds the serialized
        JSON, that content is used as-is. ``Element.from_dict`` rebuilds plotly
        elements this way: it passes ``content`` and never ``figure``.

        Raises:
            TypeError: if ``figure`` is not a ``plotly.graph_objects.Figure``
                and no pre-serialized ``content`` was provided.
        """
        if self.figure is None and self.content:
            self.mime = "application/json"
            super().__post_init__()
            return

        from plotly import graph_objects as go, io as pio

        if not isinstance(self.figure, go.Figure):
            raise TypeError("figure must be a plotly.graph_objects.Figure")

        # Drop any fixed dimensions and enable autosize, presumably so the
        # chart adapts to its container in the UI.
        self.figure.layout.autosize = True
        self.figure.layout.width = None
        self.figure.layout.height = None
        self.content = pio.to_json(self.figure, validate=True)
        self.mime = "application/json"

        super().__post_init__()


@dataclass
class Dataframe(Element):
    """Useful to send a pandas DataFrame to the UI."""

    type: ClassVar[ElementType] = "dataframe"
    size: ElementSize = "large"
    data: Any = None  # The type is Any because it is checked in __post_init__.

    def __post_init__(self) -> None:
        """Ensures the data is a pandas DataFrame and converts it to JSON."""
        from pandas import DataFrame

        if not isinstance(self.data, DataFrame):
            raise TypeError("data must be a pandas.DataFrame")

        self.content = self.data.to_json(orient="split", date_format="iso")
        super().__post_init__()


@dataclass
class CustomElement(Element):
    """Useful to send a custom element to the UI."""

    type: ClassVar[ElementType] = "custom"
    mime: str = "application/json"
    props: Dict = Field(default_factory=dict)

    def __post_init__(self) -> None:
        self.content = json.dumps(self.props)
        super().__post_init__()
        # Custom elements can be re-sent via `update()`, so allow re-persisting.
        self.updatable = True

    async def update(self):
        """Re-send the element with its current ``props``.

        ``content`` is re-serialized first so the persisted copy reflects props
        mutated after construction, not the values captured in __post_init__.
        """
        self.content = json.dumps(self.props)
        await super().send(self.for_id)