import json
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import aiofiles
import asyncpg  # type: ignore

from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.logger import logger
from chainlit.step import StepDict
from chainlit.types import (
    Feedback,
    FeedbackDict,
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User

# Import for runtime usage (isinstance checks)
try:
    from chainlit.data.storage_clients.gcs import GCSStorageClient
except ImportError:
    GCSStorageClient = None  # type: ignore[assignment,misc]

if TYPE_CHECKING:
    from chainlit.data.storage_clients.gcs import GCSStorageClient
    from chainlit.element import Element, ElementDict
    from chainlit.step import StepDict

ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


class ChainlitDataLayer(BaseDataLayer):
    """Chainlit data layer backed by PostgreSQL (asyncpg), with optional
    blob storage for element files via a BaseStorageClient."""

    def __init__(
        self,
        database_url: str,
        storage_client: Optional[BaseStorageClient] = None,
        show_logger: bool = False,
    ):
        self.database_url = database_url
        self.pool: Optional[asyncpg.Pool] = None
        self.storage_client = storage_client
        self.show_logger = show_logger

    async def connect(self):
        if not self.pool:
            self.pool = await asyncpg.create_pool(self.database_url)

    async def get_current_timestamp(self) -> datetime:
        return datetime.now()

    async def execute_query(
        self, query: str, params: Union[Dict, None] = None
    ) -> List[Dict[str, Any]]:
        if not self.pool:
            await self.connect()

        # asyncpg uses positional $n parameters; dict values are passed in
        # insertion order, so params must be ordered to match the query.
        try:
            async with self.pool.acquire() as connection:  # type: ignore
                try:
                    if params:
                        records = await connection.fetch(query, *params.values())
                    else:
                        records = await connection.fetch(query)
                    return [dict(record) for record in records]
                except Exception as e:
                    logger.error(f"Database error: {e!s}")
                    raise
        except (
            asyncpg.exceptions.ConnectionDoesNotExistError,
            asyncpg.exceptions.InterfaceError,
        ) as e:
            # Handle connection issues by cleaning up and rethrowing
            logger.error(f"Connection error: {e!s}")
            await self.cleanup()
            raise

    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        query = """
        SELECT * FROM "User" WHERE identifier = $1
        """
        result = await self.execute_query(query, {"identifier": identifier})
        if not result or len(result) == 0:
            return None

        row = result[0]
        return PersistedUser(
            id=str(row.get("id")),
            identifier=str(row.get("identifier")),
            createdAt=row.get("createdAt").isoformat(),  # type: ignore
            metadata=json.loads(row.get("metadata", "{}")),
        )

    async def create_user(self, user: User) -> Optional[PersistedUser]:
        query = """
        INSERT INTO "User" (id, identifier, metadata, "createdAt", "updatedAt")
        VALUES ($1, $2, $3, $4, $5)
        ON CONFLICT (identifier) DO UPDATE
        SET metadata = $3
        RETURNING *
        """
        now = await self.get_current_timestamp()
        params = {
            "id": str(uuid.uuid4()),
            "identifier": user.identifier,
            "metadata": json.dumps(user.metadata),
            "created_at": now,
            "updated_at": now,
        }
        result = await self.execute_query(query, params)
        row = result[0]
        return PersistedUser(
            id=str(row.get("id")),
            identifier=str(row.get("identifier")),
            createdAt=row.get("createdAt").isoformat(),  # type: ignore
            metadata=json.loads(row.get("metadata", "{}")),
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        query = """
        DELETE FROM "Feedback" WHERE id = $1
        """
        await self.execute_query(query, {"feedback_id": feedback_id})
        return True

    async def upsert_feedback(self, feedback: Feedback) -> str:
        query = """
INSERT INTO "Feedback" (id, "stepId", name, value, comment) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (id) DO UPDATE SET value = $4, comment = $5 RETURNING id """ feedback_id = feedback.id or str(uuid.uuid4()) params = { "id": feedback_id, "step_id": feedback.forId, "name": "user_feedback", "value": float(feedback.value), "comment": feedback.comment, } results = await self.execute_query(query, params) return str(results[0]["id"]) @queue_until_user_message() async def create_element(self, element: "Element"): if not element.for_id: return if element.thread_id: query = 'SELECT id FROM "Thread" WHERE id = $1' results = await self.execute_query(query, {"thread_id": element.thread_id}) if not results: await self.update_thread(thread_id=element.thread_id) if element.for_id: query = 'SELECT id FROM "Step" WHERE id = $1' results = await self.execute_query(query, {"step_id": element.for_id}) if not results: await self.create_step( { "id": element.for_id, "metadata": {}, "type": "run", "start_time": await self.get_current_timestamp(), "end_time": await self.get_current_timestamp(), } ) # Handle file uploads only if storage_client is configured path = None if self.storage_client: content: Optional[Union[bytes, str]] = None if element.path: async with aiofiles.open(element.path, "rb") as f: content = await f.read() elif element.content: content = element.content elif not element.url: raise ValueError("Element url, path or content must be provided") if content is not None: if element.thread_id: path = f"threads/{element.thread_id}/files/{element.id}" else: path = f"files/{element.id}" content_disposition = ( f'attachment; filename="{element.name}"' if not ( GCSStorageClient is not None and isinstance(self.storage_client, GCSStorageClient) ) else None ) await self.storage_client.upload_file( object_key=path, data=content, mime=element.mime or "application/octet-stream", overwrite=True, content_disposition=content_disposition, ) else: # Log warning only if element has file content that needs uploading if element.path or element.url or element.content: logger.warning( "Data Layer: No storage client configured. " "File will not be uploaded." 
                )

        # Always persist element metadata to database
        query = """
        INSERT INTO "Element" (
            id, "threadId", "stepId", metadata, mime, name, "objectKey",
            url, "chainlitKey", display, size, language, page, props
        )
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
        ON CONFLICT (id) DO UPDATE
        SET props = EXCLUDED.props
        """
        params = {
            "id": element.id,
            "thread_id": element.thread_id,
            "step_id": element.for_id,
            "metadata": json.dumps(
                {
                    "size": element.size,
                    "language": element.language,
                    "display": element.display,
                    "type": element.type,
                    "page": getattr(element, "page", None),
                }
            ),
            "mime": element.mime,
            "name": element.name,
            "object_key": path,
            "url": element.url,
            "chainlit_key": element.chainlit_key,
            "display": element.display,
            "size": element.size,
            "language": element.language,
            "page": getattr(element, "page", None),
            "props": json.dumps(getattr(element, "props", {})),
        }
        await self.execute_query(query, params)

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional[ElementDict]:
        query = """
        SELECT * FROM "Element"
        WHERE id = $1 AND "threadId" = $2
        """
        results = await self.execute_query(
            query, {"element_id": element_id, "thread_id": thread_id}
        )
        if not results:
            return None

        row = results[0]
        metadata = json.loads(row.get("metadata", "{}"))
        return ElementDict(
            id=str(row["id"]),
            threadId=str(row["threadId"]),
            type=metadata.get("type", "file"),
            url=str(row["url"]),
            name=str(row["name"]),
            mime=str(row["mime"]),
            objectKey=str(row["objectKey"]),
            forId=str(row["stepId"]),
            chainlitKey=row.get("chainlitKey"),
            display=row["display"],
            size=row["size"],
            language=row["language"],
            page=row["page"],
            autoPlay=row.get("autoPlay"),
            playerConfig=row.get("playerConfig"),
            props=json.loads(row.get("props", "{}")),
        )

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        query = """
        SELECT * FROM "Element" WHERE id = $1
        """
        elements = await self.execute_query(query, {"id": element_id})
        if self.storage_client is not None and len(elements) > 0:
            if elements[0]["objectKey"]:
                await self.storage_client.delete_file(
                    object_key=elements[0]["objectKey"]
                )

        query = """
        DELETE FROM "Element" WHERE id = $1
        """
        params = {"id": element_id}
        if thread_id:
            query += ' AND "threadId" = $2'
            params["thread_id"] = thread_id
        await self.execute_query(query, params)

    @queue_until_user_message()
    async def create_step(self, step_dict: StepDict):
        if step_dict.get("threadId"):
            thread_query = 'SELECT id FROM "Thread" WHERE id = $1'
            thread_results = await self.execute_query(
                thread_query, {"thread_id": step_dict["threadId"]}
            )
            if not thread_results:
                await self.update_thread(thread_id=step_dict["threadId"])

        if step_dict.get("parentId"):
            parent_query = 'SELECT id FROM "Step" WHERE id = $1'
            parent_results = await self.execute_query(
                parent_query, {"parent_id": step_dict["parentId"]}
            )
            if not parent_results:
                await self.create_step(
                    {
                        "id": step_dict["parentId"],
                        "metadata": {},
                        "type": "run",
                        "createdAt": step_dict.get("createdAt"),
                    }
                )

        # Upsert the step: COALESCE/CASE clauses let partial updates fill in
        # fields without overwriting values that were already persisted.
        query = """
        INSERT INTO "Step" (
            id, "threadId", "parentId", input, metadata, name, output,
            type, "startTime", "endTime", "showInput", "isError"
        )
        VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
        ON CONFLICT (id) DO UPDATE
        SET
            "parentId" = COALESCE(EXCLUDED."parentId", "Step"."parentId"),
            input = COALESCE(EXCLUDED.input, "Step".input),
            metadata = CASE
                WHEN EXCLUDED.metadata <> '{}' THEN EXCLUDED.metadata
                ELSE "Step".metadata
            END,
            name = COALESCE(EXCLUDED.name, "Step".name),
            output = COALESCE(EXCLUDED.output, "Step".output),
            type = CASE
                WHEN EXCLUDED.type = 'run' THEN "Step".type
                ELSE EXCLUDED.type
            END,
            "threadId" = COALESCE(EXCLUDED."threadId", "Step"."threadId"),
            "endTime" = COALESCE(EXCLUDED."endTime", "Step"."endTime"),
            "startTime" = LEAST(EXCLUDED."startTime", "Step"."startTime"),
            "showInput" = COALESCE(EXCLUDED."showInput", "Step"."showInput"),
            "isError" = COALESCE(EXCLUDED."isError", "Step"."isError")
        """
        timestamp = await self.get_current_timestamp()
        created_at = step_dict.get("createdAt")
        if created_at:
            timestamp = datetime.strptime(created_at, ISO_FORMAT)

        params = {
            "id": step_dict["id"],
            "thread_id": step_dict.get("threadId"),
            "parent_id": step_dict.get("parentId"),
            "input": step_dict.get("input"),
            "metadata": json.dumps(step_dict.get("metadata", {})),
            "name": step_dict.get("name"),
            "output": step_dict.get("output"),
            "type": step_dict["type"],
            "start_time": timestamp,
            "end_time": timestamp,
            "show_input": str(step_dict.get("showInput", "json")),
            "is_error": step_dict.get("isError", False),
        }
        await self.execute_query(query, params)

    @queue_until_user_message()
    async def update_step(self, step_dict: StepDict):
        await self.create_step(step_dict)

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        # Delete associated elements and feedbacks first
        await self.execute_query(
            'DELETE FROM "Element" WHERE "stepId" = $1', {"step_id": step_id}
        )
        await self.execute_query(
            'DELETE FROM "Feedback" WHERE "stepId" = $1', {"step_id": step_id}
        )
        # Delete the step
        await self.execute_query(
            'DELETE FROM "Step" WHERE id = $1', {"step_id": step_id}
        )

    async def get_thread_author(self, thread_id: str) -> str:
        query = """
        SELECT u.identifier
        FROM "Thread" t
        JOIN "User" u ON t."userId" = u.id
        WHERE t.id = $1
        """
        results = await self.execute_query(query, {"thread_id": thread_id})
        if not results:
            raise ValueError(f"Thread {thread_id} not found")
        return results[0]["identifier"]

    async def delete_thread(self, thread_id: str):
        elements_query = """
        SELECT * FROM "Element" WHERE "threadId" = $1
        """
        elements_results = await self.execute_query(
            elements_query, {"thread_id": thread_id}
        )
        if self.storage_client is not None:
            for elem in elements_results:
                if elem["objectKey"]:
                    await self.storage_client.delete_file(object_key=elem["objectKey"])

        await self.execute_query(
            'DELETE FROM "Thread" WHERE id = $1', {"thread_id": thread_id}
        )

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse[ThreadDict]:
        query = """
        SELECT
            t.*,
            u.identifier as user_identifier,
            (SELECT COUNT(*) FROM "Thread" WHERE "userId" = t."userId") as total
        FROM "Thread" t
        LEFT JOIN "User" u ON t."userId" = u.id
        WHERE t."deletedAt" IS NULL
        """
        params: Dict[str, Any] = {}
        param_count = 1

        if filters.search:
            query += f" AND t.name ILIKE ${param_count}"
            params["name"] = f"%{filters.search}%"
            param_count += 1

        if filters.userId:
            query += f' AND t."userId" = ${param_count}'
            params["user_id"] = filters.userId
            param_count += 1

        if pagination.cursor:
            query += f' AND t."updatedAt" < (SELECT "updatedAt" FROM "Thread" WHERE id = ${param_count})'
            params["cursor"] = pagination.cursor
            param_count += 1

        query += f' ORDER BY t."updatedAt" DESC LIMIT ${param_count}'
        params["limit"] = pagination.first + 1

        results = await self.execute_query(query, params)

        threads = results
        has_next_page = len(threads) > pagination.first
        if has_next_page:
            threads = threads[:-1]

        thread_dicts = []
        for thread in threads:
            thread_dict = ThreadDict(
                id=str(thread["id"]),
                createdAt=thread["updatedAt"].isoformat(),
                name=thread["name"],
                userId=str(thread["userId"]) if thread["userId"] else None,
                userIdentifier=thread["user_identifier"],
                metadata=json.loads(thread["metadata"]),
                steps=[],
                elements=[],
                tags=[],
            )
            thread_dicts.append(thread_dict)

        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=has_next_page,
                startCursor=thread_dicts[0]["id"] if thread_dicts else None,
                endCursor=thread_dicts[-1]["id"] if thread_dicts else None,
            ),
            data=thread_dicts,
        )

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        query = """
        SELECT t.*, u.identifier as user_identifier
        FROM "Thread" t
        LEFT JOIN "User" u ON t."userId" = u.id
        WHERE t.id = $1 AND t."deletedAt" IS NULL
        """
        results = await self.execute_query(query, {"thread_id": thread_id})
        if not results:
            return None

        thread = results[0]

        # Get steps and related feedback
        steps_query = """
        SELECT
            s.*,
            f.id feedback_id,
            f.value feedback_value,
            f."comment" feedback_comment
        FROM "Step" s
        LEFT JOIN "Feedback" f ON s.id = f."stepId"
        WHERE s."threadId" = $1
        ORDER BY "startTime"
        """
        steps_results = await self.execute_query(steps_query, {"thread_id": thread_id})

        # Get elements
        elements_query = """
        SELECT * FROM "Element" WHERE "threadId" = $1
        """
        elements_results = await self.execute_query(
            elements_query, {"thread_id": thread_id}
        )

        if self.storage_client is not None:
            for elem in elements_results:
                if not elem["url"] and elem["objectKey"]:
                    elem["url"] = await self.storage_client.get_read_url(
                        object_key=elem["objectKey"],
                    )

        return ThreadDict(
            id=str(thread["id"]),
            createdAt=thread["createdAt"].isoformat(),
            name=thread["name"],
            userId=str(thread["userId"]) if thread["userId"] else None,
            userIdentifier=thread["user_identifier"],
            metadata=json.loads(thread["metadata"]),
            steps=[self._convert_step_row_to_dict(step) for step in steps_results],
            elements=[
                self._convert_element_row_to_dict(elem) for elem in elements_results
            ],
            tags=[],
        )

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        if self.show_logger:
            logger.info(f"asyncpg: update_thread, thread_id={thread_id}")

        thread_name = truncate(
            name
            if name is not None
            else (metadata.get("name") if metadata and "name" in metadata else None)
        )

        # Merge incoming metadata with existing metadata, deleting incoming keys with None values
        if metadata is not None:
            existing = await self.execute_query(
                'SELECT "metadata" FROM "Thread" WHERE id = $1',
                {"thread_id": thread_id},
            )
            base = {}
            if isinstance(existing, list) and existing:
                raw = existing[0].get("metadata") or {}
                if isinstance(raw, str):
                    try:
                        base = json.loads(raw)
                    except json.JSONDecodeError:
                        base = {}
                elif isinstance(raw, dict):
                    base = raw
            to_delete = {k for k, v in metadata.items() if v is None}
            incoming = {k: v for k, v in metadata.items() if v is not None}
            base = {k: v for k, v in base.items() if k not in to_delete}
            metadata = {**base, **incoming}

        data = {
            "id": thread_id,
            "name": thread_name,
            "userId": user_id,
            "tags": tags,
            "metadata": json.dumps(metadata or {}),
            "updatedAt": datetime.now(),
        }

        # Remove None values
        data = {k: v for k, v in data.items() if v is not None}

        # Build the query dynamically based on available fields
        columns = [f'"{k}"' for k in data.keys()]
        placeholders = [f"${i + 1}" for i in range(len(data))]
        values = list(data.values())
        update_sets = [f'"{k}" = EXCLUDED."{k}"' for k in data.keys() if k != "id"]

        if update_sets:
            query = f"""
            INSERT INTO "Thread" ({", ".join(columns)})
            VALUES ({", ".join(placeholders)})
            ON CONFLICT (id) DO UPDATE
            SET {", ".join(update_sets)};
            """
".join(update_sets)}; """ else: query = f""" INSERT INTO "Thread" ({", ".join(columns)}) VALUES ({", ".join(placeholders)}) ON CONFLICT (id) DO NOTHING """ await self.execute_query(query, {str(i + 1): v for i, v in enumerate(values)}) def _extract_feedback_dict_from_step_row(self, row: Dict) -> Optional[FeedbackDict]: if row["feedback_id"] is not None: return FeedbackDict( forId=row["id"], id=row["feedback_id"], value=row["feedback_value"], comment=row["feedback_comment"], ) return None def _convert_step_row_to_dict(self, row: Dict) -> StepDict: return StepDict( id=str(row["id"]), threadId=str(row["threadId"]) if row.get("threadId") else "", parentId=str(row["parentId"]) if row.get("parentId") else None, name=str(row.get("name")), type=row["type"], input=row.get("input", {}), output=row.get("output", {}), metadata=json.loads(row.get("metadata", "{}")), createdAt=row["createdAt"].isoformat() if row.get("createdAt") else None, start=row["startTime"].isoformat() if row.get("startTime") else None, showInput=row.get("showInput"), isError=row.get("isError"), end=row["endTime"].isoformat() if row.get("endTime") else None, feedback=self._extract_feedback_dict_from_step_row(row), ) def _convert_element_row_to_dict(self, row: Dict) -> ElementDict: metadata = json.loads(row.get("metadata", "{}")) return ElementDict( id=str(row["id"]), threadId=str(row["threadId"]) if row.get("threadId") else None, type=metadata.get("type", "file"), url=row["url"], name=row["name"], mime=row["mime"], objectKey=row["objectKey"], forId=str(row["stepId"]), chainlitKey=row.get("chainlitKey"), display=row["display"], size=row["size"], language=row["language"], page=row["page"], autoPlay=row.get("autoPlay"), playerConfig=row.get("playerConfig"), props=json.loads(row.get("props") or "{}"), ) async def build_debug_url(self) -> str: return "" async def cleanup(self): """Cleanup database connections""" if self.pool: logger.debug("Cleaning up connection pool") await self.pool.close() self.pool = None async def close(self) -> None: if self.storage_client: await self.storage_client.close() await self.cleanup() def truncate(text: Optional[str], max_length: int = 255) -> Optional[str]: return None if text is None else text[:max_length]