import asyncio
import json
import logging
import os
import random
from dataclasses import asdict
from datetime import datetime, timezone
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import aiofiles
import aiohttp
import boto3  # type: ignore
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer

from chainlit.context import context
from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.logger import logger
from chainlit.step import StepDict
from chainlit.types import (
    Feedback,
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User

if TYPE_CHECKING:
    from mypy_boto3_dynamodb import DynamoDBClient

    from chainlit.element import Element


_logger = logger.getChild("DynamoDB")
_logger.setLevel(logging.WARNING)


class DynamoDBDataLayer(BaseDataLayer):
    """Chainlit data layer that stores users, threads, steps and elements in one DynamoDB table.

    Every record is addressed by a ``PK``/``SK`` pair:

    * user:    ``PK=USER#{identifier}``, ``SK=USER``
    * thread:  ``PK=THREAD#{thread_id}``, ``SK=THREAD``
    * step:    ``PK=THREAD#{thread_id}``, ``SK=STEP#{step_id}``
    * element: ``PK=THREAD#{thread_id}``, ``SK=ELEMENT#{element_id}``

    Threads are listed per user through the ``UserThread`` index, keyed by
    ``UserThreadPK`` / ``UserThreadSK``.

    NOTE(review): the boto3 client calls below are synchronous and are invoked
    directly from coroutines, so each call blocks the event loop while it runs.
    """

    def __init__(
        self,
        table_name: str,
        client: Optional["DynamoDBClient"] = None,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: int = 10,
    ):
        """Create the data layer.

        Args:
            table_name: Name of the DynamoDB table to read from and write to.
            client: Pre-built DynamoDB client. When omitted, one is created for
                the region in ``AWS_REGION`` (default ``us-east-1``).
            storage_provider: Storage client used to persist element content.
                Without it, ``create_element`` logs a warning and stores nothing.
            user_thread_limit: ``Limit`` passed to the query in ``list_threads``.
        """
        if client:
            self.client = client
        else:
            region_name = os.environ.get("AWS_REGION", "us-east-1")
            self.client = boto3.client("dynamodb", region_name=region_name)  # type: ignore

        self.table_name = table_name
        self.storage_provider = storage_provider
        self.user_thread_limit = user_thread_limit

        self._type_deserializer = TypeDeserializer()
        self._type_serializer = TypeSerializer()

    def _get_current_timestamp(self) -> str:
        """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix.

        The clock is read in UTC explicitly: the ``Z`` suffix denotes UTC, so a
        naive local ``datetime.now()`` would be mislabelled on any host whose
        local timezone is not UTC. The tzinfo is dropped before formatting so
        the output has no ``+00:00`` offset in front of the ``Z``.
        """
        utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
        return utc_now.isoformat() + "Z"

    def _serialize_item(self, item: dict[str, Any]) -> dict[str, Any]:
        """Convert a plain dict into DynamoDB attribute-value format.

        Floats are converted to ``Decimal`` (recursively through dicts and
        lists) before serialization, since boto3's ``TypeSerializer`` does not
        accept ``float`` values.
        """

        def convert_floats(obj):
            if isinstance(obj, float):
                # Go through str() so the Decimal matches the float's repr.
                return Decimal(str(obj))
            if isinstance(obj, dict):
                return {k: convert_floats(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [convert_floats(v) for v in obj]
            return obj

        return {
            key: self._type_serializer.serialize(convert_floats(value))
            for key, value in item.items()
        }

    def _deserialize_item(self, item: dict[str, Any]) -> dict[str, Any]:
        """Convert a DynamoDB attribute-value item into a plain dict.

        ``Decimal`` values are converted to ``float`` (recursively through
        dicts and lists) after deserialization.
        """

        def convert_decimals(obj):
            if isinstance(obj, Decimal):
                return float(obj)
            if isinstance(obj, dict):
                return {k: convert_decimals(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [convert_decimals(v) for v in obj]
            return obj

        return {
            key: convert_decimals(self._type_deserializer.deserialize(value))
            for key, value in item.items()
        }

    def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
        """Issue a ``SET`` update for every truthy entry of ``updates``.

        Args:
            key: Un-serialized primary key (``PK`` / ``SK``) of the item.
            updates: Attribute name to new value. Falsy values are skipped.

        NOTE(review): skipping *falsy* values (not just ``None``) means an
        attribute can never be reset to ``0``, ``False``, ``""`` or an empty
        container through this helper. Kept as-is because callers may rely on it.
        """
        update_expr: List[str] = []
        expression_attribute_names = {}
        expression_attribute_values = {}

        for index, (attr, value) in enumerate(updates.items()):
            if not value:
                continue

            # Placeholders avoid clashes between attribute names and
            # DynamoDB reserved words.
            k, v = f"#{index}", f":{index}"
            update_expr.append(f"{k} = {v}")
            expression_attribute_names[k] = attr
            expression_attribute_values[v] = value

        if not update_expr:
            # Nothing to write; a bare "SET " is not a valid UpdateExpression.
            return

        self.client.update_item(
            TableName=self.table_name,
            Key=self._serialize_item(key),
            UpdateExpression="SET " + ", ".join(update_expr),
            ExpressionAttributeNames=expression_attribute_names,
            ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
        )

    @property
    def context(self):
        """The module-level Chainlit ``context`` object."""
        return context

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        """Fetch the user stored under ``identifier``, or ``None`` if absent."""
        _logger.info("DynamoDB: get_user identifier=%s", identifier)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"USER#{identifier}"},
                "SK": {"S": "USER"},
            },
        )

        if "Item" not in response:
            return None

        user = self._deserialize_item(response["Item"])

        return PersistedUser(
            id=user["id"],
            identifier=user["identifier"],
            createdAt=user["createdAt"],
            metadata=user["metadata"],
        )

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        """Store ``user`` (overwriting any existing record) and return it.

        The user's ``identifier`` is used as its ``id``.
        """
        _logger.info("DynamoDB: create_user user.identifier=%s", user.identifier)

        ts = self._get_current_timestamp()
        metadata: Dict[Any, Any] = user.metadata  # type: ignore

        item = {
            "PK": f"USER#{user.identifier}",
            "SK": "USER",
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": metadata,
            "createdAt": ts,
        }

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Remove the ``feedback`` attribute from the step a feedback id points at.

        Args:
            feedback_id: Id of the form ``THREAD#{thread_id}::STEP#{step_id}``,
                as produced by ``upsert_feedback``.

        Returns:
            Always ``True``.

        Raises:
            ValueError: If ``feedback_id`` does not split into exactly two
                parts on ``"::"``.
        """
        _logger.info("DynamoDB: delete_feedback feedback_id=%s", feedback_id)

        # feedback id = THREAD#{thread_id}::STEP#{step_id}
        thread_part, step_part = feedback_id.split("::")
        # removeprefix, not strip: strip() removes any of the given characters
        # from both ends and would mangle ids that start or end with them.
        thread_id = thread_part.removeprefix("THREAD#")
        step_id = step_part.removeprefix("STEP#")

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
            UpdateExpression="REMOVE #feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
        )

        return True

    async def upsert_feedback(self, feedback: Feedback) -> str:
        """Store ``feedback`` on its step and return the generated feedback id.

        The id is ``THREAD#{threadId}::STEP#{forId}`` and is also written back
        to ``feedback.id``.

        Raises:
            ValueError: If ``feedback.forId`` or ``feedback.threadId`` is empty.
        """
        _logger.info(
            "DynamoDB: upsert_feedback thread=%s step=%s value=%s",
            feedback.threadId,
            feedback.forId,
            feedback.value,
        )

        if not feedback.forId:
            raise ValueError(
                "DynamoDB data layer expects value for feedback.forId got None"
            )

        if not feedback.threadId:
            # Both values are part of the item key; without a thread id the
            # update would be written to a bogus "THREAD#None" partition.
            raise ValueError(
                "DynamoDB data layer expects value for feedback.threadId got None"
            )

        feedback.id = f"THREAD#{feedback.threadId}::STEP#{feedback.forId}"
        serialized_feedback = self._type_serializer.serialize(asdict(feedback))

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{feedback.threadId}"},
                "SK": {"S": f"STEP#{feedback.forId}"},
            },
            UpdateExpression="SET #feedback = :feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
            ExpressionAttributeValues={":feedback": serialized_feedback},
        )

        return feedback.id

    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        """Upload an element's content to the storage provider and record it.

        Content is taken from ``element.content``, else read from
        ``element.path``, else downloaded from ``element.url``. Elements
        without a ``for_id`` are ignored, and nothing is stored when no
        storage provider is configured.

        Raises:
            ValueError: If no content source is set, the download does not
                return HTTP 200, the content is ``None``, or the upload fails.
        """
        _logger.info(
            "DynamoDB: create_element thread=%s step=%s type=%s",
            element.thread_id,
            element.for_id,
            element.type,
        )
        _logger.debug("DynamoDB: create_element: %s", element.to_dict())

        if not element.for_id:
            return

        if not self.storage_provider:
            _logger.warning(
                "DynamoDB: create_element error. No storage_provider is configured!"
            )
            return

        content: Optional[Union[bytes, str]] = None

        if element.content:
            content = element.content

        elif element.path:
            _logger.debug("DynamoDB: create_element reading file %s", element.path)
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()

        elif element.url:
            _logger.debug("DynamoDB: create_element http %s", element.url)
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        raise ValueError(
                            f"Failed to read content from {element.url} status {response.status}",
                        )

        else:
            raise ValueError("Element url, path or content must be provided")

        if content is None:
            raise ValueError("Content is None, cannot upload file")

        if not element.mime:
            element.mime = "application/octet-stream"

        # Objects are namespaced per user and thread; "unknown" is used when
        # the session user has no id attribute.
        context_user = self.context.session.user
        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key,
            data=content,
            mime=element.mime,
            overwrite=True,
        )
        if not uploaded_file:
            raise ValueError(
                "DynamoDB Error: create_element, Failed to persist data in storage_provider",
            )

        element_dict: Dict[str, Any] = element.to_dict()  # type: ignore
        element_dict.update(
            {
                "PK": f"THREAD#{element.thread_id}",
                "SK": f"ELEMENT#{element.id}",
                "url": uploaded_file.get("url"),
                "objectKey": uploaded_file.get("object_key"),
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(element_dict),
        )

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        """Fetch one element of a thread, or ``None`` if it does not exist."""
        _logger.info(
            "DynamoDB: get_element thread=%s element=%s", thread_id, element_id
        )

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

        if "Item" not in response:
            return None

        return self._deserialize_item(response["Item"])  # type: ignore

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        """Delete an element record.

        Args:
            element_id: Id of the element to delete.
            thread_id: Thread owning the element. When omitted, the current
                session's thread id is used.
        """
        if not thread_id:
            # Only fall back to the session thread when the caller gave none;
            # an explicit thread_id must not be overwritten.
            thread_id = self.context.session.thread_id

        _logger.info(
            "DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
        )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        """Store a step record, overwriting any existing one with the same key.

        Raises:
            KeyError: If ``step_dict`` lacks ``threadId`` or ``id``.
        """
        _logger.info(
            "DynamoDB: create_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: create_step: %s", step_dict)

        item = dict(step_dict)
        item.update(
            {
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        """Update the truthy fields of an existing step record.

        Raises:
            KeyError: If ``step_dict`` lacks ``threadId`` or ``id``.
        """
        _logger.info(
            "DynamoDB: update_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: update_step: %s", step_dict)

        self._update_item(
            key={
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            },
            updates=step_dict,  # type: ignore
        )

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        """Delete a step of the current session's thread."""
        thread_id = self.context.session.thread_id
        _logger.info("DynamoDB: delete_step thread=%s step=%s", thread_id, step_id)

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
        )

    async def get_thread_author(self, thread_id: str) -> str:
        """Return the ``userId`` stored on a thread record.

        Raises:
            ValueError: If the thread record does not exist.
            KeyError: If the thread record has no ``userId`` attribute.
        """
        _logger.info("DynamoDB: get_thread_author thread=%s", thread_id)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
            ProjectionExpression="userId",
        )

        if "Item" not in response:
            raise ValueError(f"Author not found for thread_id {thread_id}")

        item = self._deserialize_item(response["Item"])
        return item["userId"]

    async def delete_thread(self, thread_id: str):
        """Delete a thread together with all of its steps and elements.

        Steps and elements are removed with batched writes; unprocessed items
        are retried with capped exponential backoff. The thread record itself
        is deleted last. Does nothing when the thread does not exist.
        """
        _logger.info("DynamoDB: delete_thread thread=%s", thread_id)

        thread = await self.get_thread(thread_id)
        if not thread:
            return

        items: List[Any] = thread["steps"]
        if thread["elements"]:
            items.extend(thread["elements"])

        delete_requests = [
            {
                "DeleteRequest": {
                    "Key": self._serialize_item({"PK": item["PK"], "SK": item["SK"]})
                }
            }
            for item in items
        ]

        # BatchWriteItem accepts at most 25 requests per call.
        BATCH_ITEM_SIZE = 25  # pylint: disable=invalid-name
        for i in range(0, len(delete_requests), BATCH_ITEM_SIZE):
            chunk = delete_requests[i : i + BATCH_ITEM_SIZE]
            response = self.client.batch_write_item(
                RequestItems={
                    self.table_name: chunk,  # type: ignore
                }
            )

            backoff_time = 1
            while response.get("UnprocessedItems"):
                backoff_time *= 2
                # Cap the backoff time at 32 seconds & add jitter
                delay = min(backoff_time, 32) + random.uniform(0, 1)
                await asyncio.sleep(delay)

                response = self.client.batch_write_item(
                    RequestItems=response["UnprocessedItems"]
                )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
        )

    async def list_threads(
        self, pagination: "Pagination", filters: "ThreadFilter"
    ) -> "PaginatedResponse[ThreadDict]":
        """List one page of a user's threads from the ``UserThread`` index.

        Results are read in descending ``UserThreadSK`` order. The page size is
        ``self.user_thread_limit``. ``filters.search`` is applied as a
        ``contains`` filter on the thread name; ``filters.feedback`` is not
        supported and only triggers a warning.

        The cursor is the JSON-encoded ``LastEvaluatedKey`` of the previous page.
        """
        _logger.info("DynamoDB: list_threads filters.userId=%s", filters.userId)

        if filters.feedback:
            _logger.warning("DynamoDB: filters on feedback not supported")

        paginated_response: PaginatedResponse[ThreadDict] = PaginatedResponse(
            data=[],
            pageInfo=PageInfo(
                hasNextPage=False, startCursor=pagination.cursor, endCursor=None
            ),
        )

        query_args: Dict[str, Any] = {
            "TableName": self.table_name,
            "IndexName": "UserThread",
            "ScanIndexForward": False,
            "Limit": self.user_thread_limit,
            "KeyConditionExpression": "#UserThreadPK = :pk",
            "ExpressionAttributeNames": {
                "#UserThreadPK": "UserThreadPK",
            },
            "ExpressionAttributeValues": {
                ":pk": {"S": f"USER#{filters.userId}"},
            },
        }

        if pagination.cursor:
            query_args["ExclusiveStartKey"] = json.loads(pagination.cursor)

        if filters.search:
            query_args["FilterExpression"] = "contains(#name, :search)"
            query_args["ExpressionAttributeNames"]["#name"] = "name"
            query_args["ExpressionAttributeValues"][":search"] = {"S": filters.search}

        response = self.client.query(**query_args)  # type: ignore

        if "LastEvaluatedKey" in response:
            paginated_response.pageInfo.hasNextPage = True
            paginated_response.pageInfo.endCursor = json.dumps(
                response["LastEvaluatedKey"]
            )

        for item in response["Items"]:
            deserialized_item: Dict[str, Any] = self._deserialize_item(item)
            thread = ThreadDict(  # type: ignore
                # removeprefix, not strip: strip() would also eat matching
                # characters at the start/end of the id or timestamp itself.
                id=deserialized_item["PK"].removeprefix("THREAD#"),
                createdAt=deserialized_item["UserThreadSK"].removeprefix("TS#"),
                # A thread updated without a name has no "name" attribute,
                # because _update_item skips falsy values.
                name=deserialized_item.get("name"),
            )
            paginated_response.data.append(thread)

        return paginated_response

    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        """Load a thread with its steps (sorted by ``createdAt``) and elements.

        All records in the thread's partition are queried, following
        ``LastEvaluatedKey`` until exhausted. When a storage provider is
        configured, each element's ``url`` is refreshed via ``get_read_url``.

        Returns:
            The thread dict, or ``None`` when the partition is empty or holds
            steps/elements but no thread record.
        """
        _logger.info("DynamoDB: get_thread thread=%s", thread_id)

        # Get all thread records
        thread_items: List[Any] = []

        cursor: Dict[str, Any] = {}
        while True:
            response = self.client.query(
                TableName=self.table_name,
                KeyConditionExpression="#pk = :pk",
                ExpressionAttributeNames={"#pk": "PK"},
                ExpressionAttributeValues={":pk": {"S": f"THREAD#{thread_id}"}},
                **cursor,
            )

            thread_items.extend(map(self._deserialize_item, response["Items"]))

            if "LastEvaluatedKey" not in response:
                break
            cursor["ExclusiveStartKey"] = response["LastEvaluatedKey"]

        if not thread_items:
            return None

        # process accordingly
        thread_dict: Optional[ThreadDict] = None
        steps = []
        elements = []

        for item in thread_items:
            if item["SK"] == "THREAD":
                thread_dict = item

            elif item["SK"].startswith("ELEMENT"):
                if self.storage_provider is not None:
                    item["url"] = await self.storage_provider.get_read_url(
                        object_key=item["objectKey"],
                    )
                elements.append(item)

            elif item["SK"].startswith("STEP"):
                if "feedback" in item:
                    # Numbers come back as float after deserialization;
                    # the feedback value is exposed as an int.
                    item["feedback"]["value"] = int(item["feedback"]["value"])
                steps.append(item)

        if not thread_dict:
            # thread_items is non-empty here, so the partition has steps or
            # elements but no thread record.
            _logger.warning("DynamoDB: found orphaned items for thread=%s", thread_id)
            return None

        # Steps with a missing or None createdAt sort first instead of
        # raising KeyError / TypeError.
        steps.sort(key=lambda step: step.get("createdAt") or "")
        thread_dict.update(
            {
                "steps": steps,
                "elements": elements,
            }
        )

        return thread_dict

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        """Create or update a thread record.

        Only truthy arguments are written (see ``_update_item``).

        NOTE(review): ``createdAt`` and ``UserThreadSK`` are set to the current
        timestamp on every call, so they reflect the last update rather than
        the original creation time. Confirm this is intended.
        """
        _logger.info("DynamoDB: update_thread thread=%s userId=%s", thread_id, user_id)
        _logger.debug(
            "DynamoDB: update_thread name=%s tags=%s metadata=%s", name, tags, metadata
        )

        ts = self._get_current_timestamp()

        item = {
            # GSI: UserThread
            "UserThreadSK": f"TS#{ts}",
            #
            "id": thread_id,
            "createdAt": ts,
            "name": name,
            "userId": user_id,
            "userIdentifier": user_id,
            "tags": tags,
            "metadata": metadata,
        }

        if user_id:
            # user_id may be None on subsequent calls, don't update UserThreadPK to "USER#{None}"
            item["UserThreadPK"] = f"USER#{user_id}"

        self._update_item(
            key={
                "PK": f"THREAD#{thread_id}",
                "SK": "THREAD",
            },
            updates=item,
        )

    async def build_debug_url(self) -> str:
        """Return an empty string; this data layer has no debug URL."""
        return ""

    async def close(self) -> None:
        """Close the storage provider (if any) and the DynamoDB client."""
        if self.storage_provider:
            await self.storage_provider.close()
        self.client.close()