# ai-station/.venv/lib/python3.12/site-packages/chainlit/data/dynamodb.py
#
# 622 lines
# 20 KiB
# Python

import asyncio
import json
import logging
import os
import random
from dataclasses import asdict
from datetime import datetime
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import aiofiles
import aiohttp
import boto3 # type: ignore
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from chainlit.context import context
from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.logger import logger
from chainlit.step import StepDict
from chainlit.types import (
Feedback,
PageInfo,
PaginatedResponse,
Pagination,
ThreadDict,
ThreadFilter,
)
from chainlit.user import PersistedUser, User
if TYPE_CHECKING:
    # Imported for static type checking only; in the code visible here both
    # names are used solely inside string annotations.
    from mypy_boto3_dynamodb import DynamoDBClient

    from chainlit.element import Element

# Child of the chainlit logger dedicated to this data layer. With the level
# at WARNING, the info/debug calls below are dropped unless it is lowered.
_logger = logger.getChild("DynamoDB")
_logger.setLevel(logging.WARNING)
class DynamoDBDataLayer(BaseDataLayer):
    """Chainlit data layer storing users, threads, steps and elements in one DynamoDB table.

    Items are addressed with a composite ``PK`` / ``SK`` key:

    * user:    ``PK = USER#<identifier>``, ``SK = USER``
    * thread:  ``PK = THREAD#<thread_id>``, ``SK = THREAD``
    * step:    ``PK = THREAD#<thread_id>``, ``SK = STEP#<step_id>``
    * element: ``PK = THREAD#<thread_id>``, ``SK = ELEMENT#<element_id>``

    Feedback is stored as a ``feedback`` attribute on the step item. Threads
    are listed per user through the ``UserThread`` index, keyed by
    ``UserThreadPK`` (``USER#<user_id>``) and ``UserThreadSK`` (``TS#<timestamp>``).

    NOTE(review): the boto3 client is invoked directly inside the coroutines
    below (the calls are neither awaited nor offloaded to a worker thread).
    """

    def __init__(
        self,
        table_name: str,
        client: Optional["DynamoDBClient"] = None,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: int = 10,
    ):
        """Configure the data layer.

        Args:
            table_name: Name of the DynamoDB table holding every item.
            client: Pre-built DynamoDB client. When omitted, one is created
                with ``boto3`` for the region in ``AWS_REGION``
                (``us-east-1`` if the variable is unset).
            storage_provider: Blob storage used for element content. Without
                it, ``create_element`` logs a warning and stores nothing.
            user_thread_limit: ``Limit`` passed to the query in ``list_threads``.
        """
        if client:
            self.client = client
        else:
            region_name = os.environ.get("AWS_REGION", "us-east-1")
            self.client = boto3.client("dynamodb", region_name=region_name)  # type: ignore

        self.table_name = table_name
        self.storage_provider = storage_provider
        self.user_thread_limit = user_thread_limit

        self._type_deserializer = TypeDeserializer()
        self._type_serializer = TypeSerializer()

    def _get_current_timestamp(self) -> str:
        """Return the current UTC time as an ISO-8601 string with a ``Z`` suffix."""
        # Local import: ``timezone`` is not part of the module-level imports.
        from datetime import timezone

        # The "Z" suffix designates UTC, so the clock must be read in UTC.
        # tzinfo is dropped before formatting so the output keeps the
        # "<naive isoformat>Z" shape instead of "...+00:00Z".
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return now_utc.isoformat() + "Z"

    def _serialize_item(self, item: dict[str, Any]) -> dict[str, Any]:
        """Convert a plain dict into DynamoDB attribute-value format.

        Floats (including those nested in dicts and lists) are converted to
        ``Decimal`` via their string form before being handed to the boto3
        ``TypeSerializer``.
        """

        def convert_floats(obj):
            if isinstance(obj, float):
                return Decimal(str(obj))
            if isinstance(obj, dict):
                return {k: convert_floats(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [convert_floats(v) for v in obj]
            return obj

        return {
            key: self._type_serializer.serialize(convert_floats(value))
            for key, value in item.items()
        }

    def _deserialize_item(self, item: dict[str, Any]) -> dict[str, Any]:
        """Convert a DynamoDB attribute-value item into a plain dict.

        Every ``Decimal`` (including those nested in dicts and lists) is
        returned as ``float``, so integral numbers come back as floats too.
        """

        def convert_decimals(obj):
            if isinstance(obj, Decimal):
                return float(obj)
            if isinstance(obj, dict):
                return {k: convert_decimals(v) for k, v in obj.items()}
            if isinstance(obj, list):
                return [convert_decimals(v) for v in obj]
            return obj

        return {
            key: convert_decimals(self._type_deserializer.deserialize(value))
            for key, value in item.items()
        }

    def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
        """Issue a ``SET`` update for every truthy entry of ``updates``.

        Args:
            key: Plain (unserialized) primary key of the item.
            updates: Attribute name -> new value. Falsy values (``None``,
                ``0``, ``False``, ``""``, empty containers) are skipped and
                therefore never overwrite what is stored.

        If nothing is left to write, no request is sent.
        """
        update_expr: List[str] = []
        expression_attribute_names = {}
        expression_attribute_values = {}

        for index, (attr, value) in enumerate(updates.items()):
            if not value:
                continue

            # Placeholders avoid clashes with DynamoDB reserved words.
            k, v = f"#{index}", f":{index}"
            update_expr.append(f"{k} = {v}")
            expression_attribute_names[k] = attr
            expression_attribute_values[v] = value

        if not update_expr:
            # A bare "SET " is not a valid update expression.
            return

        self.client.update_item(
            TableName=self.table_name,
            Key=self._serialize_item(key),
            UpdateExpression="SET " + ", ".join(update_expr),
            ExpressionAttributeNames=expression_attribute_names,
            ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
        )

    @property
    def context(self):
        """The chainlit ``context`` object imported by this module."""
        return context

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        """Fetch the user stored under ``USER#<identifier>``; ``None`` if absent."""
        _logger.info("DynamoDB: get_user identifier=%s", identifier)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"USER#{identifier}"},
                "SK": {"S": "USER"},
            },
        )

        if "Item" not in response:
            return None

        user = self._deserialize_item(response["Item"])

        return PersistedUser(
            id=user["id"],
            identifier=user["identifier"],
            createdAt=user["createdAt"],
            metadata=user["metadata"],
        )

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        """Write (or overwrite) the user item and return it as a ``PersistedUser``.

        The user's ``identifier`` is used as its ``id``.
        """
        _logger.info("DynamoDB: create_user user.identifier=%s", user.identifier)

        ts = self._get_current_timestamp()
        metadata: Dict[Any, Any] = user.metadata  # type: ignore

        item = {
            "PK": f"USER#{user.identifier}",
            "SK": "USER",
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": metadata,
            "createdAt": ts,
        }

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Remove the ``feedback`` attribute from the step a feedback id points to.

        Args:
            feedback_id: ``THREAD#{thread_id}::STEP#{step_id}``, as produced by
                ``upsert_feedback``.

        Returns:
            Always ``True``.

        Raises:
            ValueError: If ``feedback_id`` does not split into exactly two
                ``::``-separated parts.
        """
        _logger.info("DynamoDB: delete_feedback feedback_id=%s", feedback_id)

        # feedback id = THREAD#{thread_id}::STEP#{step_id}
        thread_id, step_id = feedback_id.split("::")
        # removeprefix, not strip: strip() removes a *set of characters* from
        # both ends and would also eat id characters such as "A", "D" or "E".
        thread_id = thread_id.removeprefix("THREAD#")
        step_id = step_id.removeprefix("STEP#")

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
            UpdateExpression="REMOVE #feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
        )

        return True

    async def upsert_feedback(self, feedback: Feedback) -> str:
        """Store ``feedback`` on its step item and return the feedback id.

        The id is ``THREAD#{threadId}::STEP#{forId}`` and is also assigned to
        ``feedback.id``.

        Raises:
            ValueError: If ``feedback.threadId`` or ``feedback.forId`` is
                missing; both are needed to build the step key.
        """
        _logger.info(
            "DynamoDB: upsert_feedback thread=%s step=%s value=%s",
            feedback.threadId,
            feedback.forId,
            feedback.value,
        )

        if not feedback.threadId:
            raise ValueError(
                "DynamoDB data layer expects value for feedback.threadId got None"
            )
        if not feedback.forId:
            raise ValueError(
                "DynamoDB data layer expects value for feedback.forId got None"
            )

        feedback.id = f"THREAD#{feedback.threadId}::STEP#{feedback.forId}"
        serialized_feedback = self._type_serializer.serialize(asdict(feedback))

        self.client.update_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{feedback.threadId}"},
                "SK": {"S": f"STEP#{feedback.forId}"},
            },
            UpdateExpression="SET #feedback = :feedback",
            ExpressionAttributeNames={"#feedback": "feedback"},
            ExpressionAttributeValues={":feedback": serialized_feedback},
        )

        return feedback.id

    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        """Upload an element's content to the storage provider and record it.

        Content is taken from ``element.content``, else read from
        ``element.path``, else downloaded from ``element.url``. Nothing is
        stored when the element has no ``for_id`` or when no storage provider
        is configured.

        Raises:
            ValueError: If no content source is set, the download does not
                return HTTP 200, the content is ``None``, or the upload
                returns a falsy result.
        """
        _logger.info(
            "DynamoDB: create_element thread=%s step=%s type=%s",
            element.thread_id,
            element.for_id,
            element.type,
        )
        _logger.debug("DynamoDB: create_element: %s", element.to_dict())

        if not element.for_id:
            return

        if not self.storage_provider:
            _logger.warning(
                "DynamoDB: create_element error. No storage_provider is configured!"
            )
            return

        content: Optional[Union[bytes, str]] = None

        if element.content:
            content = element.content
        elif element.path:
            _logger.debug("DynamoDB: create_element reading file %s", element.path)
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()
        elif element.url:
            _logger.debug("DynamoDB: create_element http %s", element.url)
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        raise ValueError(
                            f"Failed to read content from {element.url} status {response.status}",
                        )
        else:
            raise ValueError("Element url, path or content must be provided")

        if content is None:
            raise ValueError("Content is None, cannot upload file")

        if not element.mime:
            element.mime = "application/octet-stream"

        # Objects are grouped per user, then per thread.
        context_user = self.context.session.user
        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key,
            data=content,
            mime=element.mime,
            overwrite=True,
        )
        if not uploaded_file:
            raise ValueError(
                "DynamoDB Error: create_element, Failed to persist data in storage_provider",
            )

        element_dict: Dict[str, Any] = element.to_dict()  # type: ignore
        element_dict.update(
            {
                "PK": f"THREAD#{element.thread_id}",
                "SK": f"ELEMENT#{element.id}",
                "url": uploaded_file.get("url"),
                "objectKey": uploaded_file.get("object_key"),
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(element_dict),
        )

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        """Return the deserialized element item, or ``None`` if it does not exist."""
        _logger.info(
            "DynamoDB: get_element thread=%s element=%s", thread_id, element_id
        )

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

        if "Item" not in response:
            return None

        return self._deserialize_item(response["Item"])  # type: ignore

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        """Delete an element item.

        Args:
            element_id: Id of the element to delete.
            thread_id: Thread owning the element. Falls back to the current
                session's thread when not given.
        """
        # Honour the caller-supplied thread; only fall back to the session.
        thread_id = thread_id or self.context.session.thread_id
        _logger.info(
            "DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
        )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"ELEMENT#{element_id}"},
            },
        )

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        """Write the step as a new item (replacing any item with the same key).

        Raises:
            KeyError: If ``step_dict`` lacks ``threadId`` or ``id``.
        """
        _logger.info(
            "DynamoDB: create_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: create_step: %s", step_dict)

        item = dict(step_dict)
        item.update(
            {
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            }
        )

        self.client.put_item(
            TableName=self.table_name,
            Item=self._serialize_item(item),
        )

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        """Update the step item with every truthy field of ``step_dict``.

        Raises:
            KeyError: If ``step_dict`` lacks ``threadId`` or ``id``.
        """
        _logger.info(
            "DynamoDB: update_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )
        _logger.debug("DynamoDB: update_step: %s", step_dict)

        self._update_item(
            key={
                # ignore type, dynamo needs these so we want to fail if not set
                "PK": f"THREAD#{step_dict['threadId']}",  # type: ignore
                "SK": f"STEP#{step_dict['id']}",  # type: ignore
            },
            updates=step_dict,  # type: ignore
        )

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        """Delete the step ``step_id`` from the current session's thread."""
        thread_id = self.context.session.thread_id
        _logger.info("DynamoDB: delete_step thread=%s step=%s", thread_id, step_id)

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": f"STEP#{step_id}"},
            },
        )

    async def get_thread_author(self, thread_id: str) -> str:
        """Return the ``userId`` stored on the thread record.

        Raises:
            ValueError: If the thread record does not exist or carries no
                ``userId``.
        """
        _logger.info("DynamoDB: get_thread_author thread=%s", thread_id)

        response = self.client.get_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
            ProjectionExpression="userId",
        )

        if "Item" not in response:
            raise ValueError(f"Author not found for thread_id {thread_id}")

        item = self._deserialize_item(response["Item"])
        author = item.get("userId")
        if author is None:
            # The record exists but userId was never written: report it the
            # same way as a missing thread instead of leaking a KeyError.
            raise ValueError(f"Author not found for thread_id {thread_id}")

        return author

    async def delete_thread(self, thread_id: str):
        """Delete a thread record together with its steps and elements.

        Steps and elements are removed with ``batch_write_item`` in chunks of
        25; unprocessed items are retried with a doubling delay (capped at
        32 seconds, plus up to one second of jitter). The thread record is
        deleted last. Does nothing when ``get_thread`` returns no thread.
        """
        _logger.info("DynamoDB: delete_thread thread=%s", thread_id)

        thread = await self.get_thread(thread_id)
        if not thread:
            return

        # Copy so the thread's own step list is not extended in place.
        items: List[Any] = list(thread["steps"])
        if thread["elements"]:
            items.extend(thread["elements"])

        delete_requests = [
            {
                "DeleteRequest": {
                    "Key": self._serialize_item({"PK": item["PK"], "SK": item["SK"]})
                }
            }
            for item in items
        ]

        BATCH_ITEM_SIZE = 25  # pylint: disable=invalid-name
        for i in range(0, len(delete_requests), BATCH_ITEM_SIZE):
            chunk = delete_requests[i : i + BATCH_ITEM_SIZE]
            response = self.client.batch_write_item(
                RequestItems={
                    self.table_name: chunk,  # type: ignore
                }
            )

            backoff_time = 1
            while response.get("UnprocessedItems"):
                backoff_time *= 2
                # Cap the backoff time at 32 seconds & add jitter
                delay = min(backoff_time, 32) + random.uniform(0, 1)
                await asyncio.sleep(delay)

                response = self.client.batch_write_item(
                    RequestItems=response["UnprocessedItems"]
                )

        self.client.delete_item(
            TableName=self.table_name,
            Key={
                "PK": {"S": f"THREAD#{thread_id}"},
                "SK": {"S": "THREAD"},
            },
        )

    async def list_threads(
        self, pagination: "Pagination", filters: "ThreadFilter"
    ) -> "PaginatedResponse[ThreadDict]":
        """List a user's threads through the ``UserThread`` index.

        Results are read in descending ``UserThreadSK`` order, at most
        ``user_thread_limit`` items per query. ``pagination.cursor`` is the
        JSON-encoded ``LastEvaluatedKey`` of the previous page.
        ``filters.search`` filters on thread names containing the text;
        ``filters.feedback`` is not supported and only triggers a warning.
        """
        _logger.info("DynamoDB: list_threads filters.userId=%s", filters.userId)

        if filters.feedback:
            _logger.warning("DynamoDB: filters on feedback not supported")

        paginated_response: PaginatedResponse[ThreadDict] = PaginatedResponse(
            data=[],
            pageInfo=PageInfo(
                hasNextPage=False, startCursor=pagination.cursor, endCursor=None
            ),
        )

        query_args: Dict[str, Any] = {
            "TableName": self.table_name,
            "IndexName": "UserThread",
            "ScanIndexForward": False,
            "Limit": self.user_thread_limit,
            "KeyConditionExpression": "#UserThreadPK = :pk",
            "ExpressionAttributeNames": {
                "#UserThreadPK": "UserThreadPK",
            },
            "ExpressionAttributeValues": {
                ":pk": {"S": f"USER#{filters.userId}"},
            },
        }

        if pagination.cursor:
            query_args["ExclusiveStartKey"] = json.loads(pagination.cursor)

        if filters.search:
            query_args["FilterExpression"] = "contains(#name, :search)"
            query_args["ExpressionAttributeNames"]["#name"] = "name"
            query_args["ExpressionAttributeValues"][":search"] = {"S": filters.search}

        response = self.client.query(**query_args)  # type: ignore

        if "LastEvaluatedKey" in response:
            paginated_response.pageInfo.hasNextPage = True
            paginated_response.pageInfo.endCursor = json.dumps(
                response["LastEvaluatedKey"]
            )

        for item in response["Items"]:
            deserialized_item: Dict[str, Any] = self._deserialize_item(item)
            thread = ThreadDict(  # type: ignore
                # removeprefix, not strip: strip() treats its argument as a
                # character set and could truncate the id itself.
                id=deserialized_item["PK"].removeprefix("THREAD#"),
                createdAt=deserialized_item["UserThreadSK"].removeprefix("TS#"),
                # name is only stored once a truthy value has been supplied.
                name=deserialized_item.get("name"),
            )
            paginated_response.data.append(thread)

        return paginated_response

    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        """Load a thread record with its steps and elements.

        All items under ``PK = THREAD#<thread_id>`` are read, following
        ``LastEvaluatedKey`` until exhausted. Steps are sorted by
        ``createdAt``. When a storage provider is configured, each element's
        ``url`` is replaced by a read URL obtained from it.

        Returns:
            The thread record with ``steps`` and ``elements`` added, or
            ``None`` when no thread record exists (a warning is logged if
            other items were found for that thread).
        """
        _logger.info("DynamoDB: get_thread thread=%s", thread_id)

        # Get all thread records
        thread_items: List[Any] = []

        cursor: Dict[str, Any] = {}
        while True:
            response = self.client.query(
                TableName=self.table_name,
                KeyConditionExpression="#pk = :pk",
                ExpressionAttributeNames={"#pk": "PK"},
                ExpressionAttributeValues={":pk": {"S": f"THREAD#{thread_id}"}},
                **cursor,
            )

            thread_items.extend(map(self._deserialize_item, response["Items"]))

            if "LastEvaluatedKey" not in response:
                break
            cursor["ExclusiveStartKey"] = response["LastEvaluatedKey"]

        if not thread_items:
            return None

        # Dispatch every item on its sort key.
        thread_dict: Optional[ThreadDict] = None
        steps = []
        elements = []

        for item in thread_items:
            sort_key = item["SK"]

            if sort_key == "THREAD":
                thread_dict = item

            elif sort_key.startswith("ELEMENT"):
                if self.storage_provider is not None:
                    item["url"] = await self.storage_provider.get_read_url(
                        object_key=item["objectKey"],
                    )
                elements.append(item)

            elif sort_key.startswith("STEP"):
                if "feedback" in item:
                    # Deserialization yields a float; restore the integer value.
                    item["feedback"]["value"] = int(item["feedback"]["value"])
                steps.append(item)

        if not thread_dict:
            # thread_items is non-empty here, so these are items without a
            # thread record.
            _logger.warning("DynamoDB: found orphaned items for thread=%s", thread_id)
            return None

        # Steps without a createdAt sort first instead of raising.
        steps.sort(key=lambda step: step.get("createdAt") or "")
        thread_dict.update(
            {
                "steps": steps,
                "elements": elements,
            }
        )

        return thread_dict

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        """Create or update the thread record.

        Every call writes the current timestamp to ``createdAt`` and
        ``UserThreadSK``. ``name``, ``user_id``, ``metadata`` and ``tags``
        are only written when truthy, so omitted arguments leave stored
        values untouched.
        """
        _logger.info("DynamoDB: update_thread thread=%s userId=%s", thread_id, user_id)
        _logger.debug(
            "DynamoDB: update_thread name=%s tags=%s metadata=%s", name, tags, metadata
        )

        ts = self._get_current_timestamp()

        item = {
            # GSI: UserThread
            "UserThreadSK": f"TS#{ts}",
            #
            "id": thread_id,
            "createdAt": ts,
            "name": name,
            "userId": user_id,
            "userIdentifier": user_id,
            "tags": tags,
            "metadata": metadata,
        }

        if user_id:
            # user_id may be None on subsequent calls, don't update UserThreadPK to "USER#{None}"
            item["UserThreadPK"] = f"USER#{user_id}"

        self._update_item(
            key={
                "PK": f"THREAD#{thread_id}",
                "SK": "THREAD",
            },
            updates=item,
        )

    async def build_debug_url(self) -> str:
        """Return an empty string; this data layer provides no debug URL."""
        return ""

    async def close(self) -> None:
        """Close the storage provider (when configured) and the DynamoDB client."""
        if self.storage_provider:
            await self.storage_provider.close()
        self.client.close()