701 lines
25 KiB
Python
701 lines
25 KiB
Python
import json
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
|
|
import aiofiles
|
|
import asyncpg # type: ignore
|
|
|
|
from chainlit.data.base import BaseDataLayer
|
|
from chainlit.data.storage_clients.base import BaseStorageClient
|
|
from chainlit.data.utils import queue_until_user_message
|
|
from chainlit.element import ElementDict
|
|
from chainlit.logger import logger
|
|
from chainlit.step import StepDict
|
|
from chainlit.types import (
|
|
Feedback,
|
|
FeedbackDict,
|
|
PageInfo,
|
|
PaginatedResponse,
|
|
Pagination,
|
|
ThreadDict,
|
|
ThreadFilter,
|
|
)
|
|
from chainlit.user import PersistedUser, User
|
|
|
|
# Import for runtime usage (isinstance checks)
|
|
try:
|
|
from chainlit.data.storage_clients.gcs import GCSStorageClient
|
|
except ImportError:
|
|
GCSStorageClient = None # type: ignore[assignment,misc]
|
|
|
|
if TYPE_CHECKING:
|
|
from chainlit.data.storage_clients.gcs import GCSStorageClient
|
|
from chainlit.element import Element, ElementDict
|
|
from chainlit.step import StepDict
|
|
|
|
ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
|
|
|
|
|
|
class ChainlitDataLayer(BaseDataLayer):
|
|
def __init__(
    self,
    database_url: str,
    storage_client: Optional[BaseStorageClient] = None,
    show_logger: bool = False,
):
    """Store connection settings; the pool itself is created lazily.

    Args:
        database_url: Postgres DSN handed to asyncpg.create_pool().
        storage_client: Optional blob storage backend for element files.
        show_logger: When True, some operations log an info line.
    """
    self.database_url = database_url
    self.storage_client = storage_client
    self.show_logger = show_logger
    # Created on first use by connect().
    self.pool: Optional[asyncpg.Pool] = None
|
|
|
|
async def connect(self):
    """Create the asyncpg connection pool if one does not exist yet."""
    if self.pool:
        return
    self.pool = await asyncpg.create_pool(self.database_url)
|
|
|
|
async def get_current_timestamp(self) -> datetime:
    """Return the current local time used for DB timestamps.

    NOTE(review): this is a *naive* local datetime — presumably the
    schema stores timezone-naive timestamps; confirm before switching
    to an aware/UTC value.
    """
    return datetime.now()
|
|
|
|
async def execute_query(
    self, query: str, params: Union[Dict, None] = None
) -> List[Dict[str, Any]]:
    """Run a query on the pool and return the rows as plain dicts.

    Args:
        query: SQL text using $1..$n positional placeholders.
        params: Optional mapping whose *values* are bound positionally
            in insertion order — the key names are documentation only.

    Returns:
        One dict per fetched record (empty list when nothing is returned).

    Raises:
        Exception: database errors are logged and re-raised; pool-level
            connection errors additionally close the pool via cleanup().
    """
    if not self.pool:
        await self.connect()

    try:
        async with self.pool.acquire() as connection:  # type: ignore
            try:
                if params:
                    # dict insertion order must line up with $1..$n.
                    records = await connection.fetch(query, *params.values())
                else:
                    records = await connection.fetch(query)
                return [dict(record) for record in records]
            except Exception as e:
                logger.error(f"Database error: {e!s}")
                raise
    except (
        asyncpg.exceptions.ConnectionDoesNotExistError,
        asyncpg.exceptions.InterfaceError,
    ) as e:
        # Handle connection issues by cleaning up and rethrowing
        logger.error(f"Connection error: {e!s}")
        await self.cleanup()
        raise
|
|
|
|
async def get_user(self, identifier: str) -> Optional[PersistedUser]:
    """Fetch a persisted user by identifier.

    Args:
        identifier: External identifier of the user.

    Returns:
        The matching PersistedUser, or None when no row exists.
    """
    query = """
    SELECT * FROM "User"
    WHERE identifier = $1
    """
    result = await self.execute_query(query, {"identifier": identifier})
    # `not result` already covers the empty list; the old extra
    # `len(result) == 0` check was redundant.
    if not result:
        return None
    row = result[0]

    return PersistedUser(
        id=str(row.get("id")),
        identifier=str(row.get("identifier")),
        createdAt=row.get("createdAt").isoformat(),  # type: ignore
        # Fix: dict.get's default only covers a *missing* key; a NULL
        # metadata column comes back as None and would crash json.loads.
        metadata=json.loads(row.get("metadata") or "{}"),
    )
|
|
|
|
async def create_user(self, user: User) -> Optional[PersistedUser]:
    """Insert a user, or refresh metadata when the identifier exists.

    Fixes: the ON CONFLICT branch previously updated only `metadata`,
    leaving "updatedAt" stale on re-login; it now refreshes both.
    Metadata parsing also guards against a NULL column value.

    Args:
        user: The in-memory user to persist.

    Returns:
        The persisted user row mapped to a PersistedUser.
    """
    query = """
    INSERT INTO "User" (id, identifier, metadata, "createdAt", "updatedAt")
    VALUES ($1, $2, $3, $4, $5)
    ON CONFLICT (identifier) DO UPDATE
    SET metadata = $3, "updatedAt" = $5
    RETURNING *
    """
    now = await self.get_current_timestamp()
    params = {
        "id": str(uuid.uuid4()),
        "identifier": user.identifier,
        "metadata": json.dumps(user.metadata),
        "created_at": now,
        "updated_at": now,
    }
    result = await self.execute_query(query, params)
    row = result[0]

    return PersistedUser(
        id=str(row.get("id")),
        identifier=str(row.get("identifier")),
        createdAt=row.get("createdAt").isoformat(),  # type: ignore
        metadata=json.loads(row.get("metadata") or "{}"),
    )
|
|
|
|
async def delete_feedback(self, feedback_id: str) -> bool:
    """Delete a feedback row by id; always reports success."""
    await self.execute_query(
        """
        DELETE FROM "Feedback" WHERE id = $1
        """,
        {"feedback_id": feedback_id},
    )
    return True
|
|
|
|
async def upsert_feedback(self, feedback: Feedback) -> str:
    """Insert or update a feedback entry and return its id."""
    query = """
    INSERT INTO "Feedback" (id, "stepId", name, value, comment)
    VALUES ($1, $2, $3, $4, $5)
    ON CONFLICT (id) DO UPDATE
    SET value = $4, comment = $5
    RETURNING id
    """
    # Mint an id when the caller did not supply one.
    fid = feedback.id if feedback.id else str(uuid.uuid4())
    rows = await self.execute_query(
        query,
        {
            "id": fid,
            "step_id": feedback.forId,
            "name": "user_feedback",
            "value": float(feedback.value),
            "comment": feedback.comment,
        },
    )
    return str(rows[0]["id"])
|
|
|
|
@queue_until_user_message()
async def create_element(self, element: "Element"):
    """Persist an element: upload its file (when a storage client is
    configured) and upsert its metadata row, creating placeholder
    thread/step rows first so foreign keys hold.
    """
    # Elements are only persisted when attached to a step.
    if not element.for_id:
        return

    if element.thread_id:
        query = 'SELECT id FROM "Thread" WHERE id = $1'
        results = await self.execute_query(query, {"thread_id": element.thread_id})
        if not results:
            # Create a placeholder thread row.
            await self.update_thread(thread_id=element.thread_id)

    if element.for_id:
        query = 'SELECT id FROM "Step" WHERE id = $1'
        results = await self.execute_query(query, {"step_id": element.for_id})
        if not results:
            # Create a placeholder "run" step for the same reason.
            await self.create_step(
                {
                    "id": element.for_id,
                    "metadata": {},
                    "type": "run",
                    "start_time": await self.get_current_timestamp(),
                    "end_time": await self.get_current_timestamp(),
                }
            )

    # Handle file uploads only if storage_client is configured
    path = None
    if self.storage_client:
        content: Optional[Union[bytes, str]] = None

        # Prefer a local file path, then inline content; a bare URL
        # needs no upload at all.
        if element.path:
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()
        elif element.content:
            content = element.content
        elif not element.url:
            raise ValueError("Element url, path or content must be provided")

        if content is not None:
            if element.thread_id:
                path = f"threads/{element.thread_id}/files/{element.id}"
            else:
                path = f"files/{element.id}"

            # content_disposition is skipped for GCS — presumably that
            # client does not accept it; confirm against its upload API.
            content_disposition = (
                f'attachment; filename="{element.name}"'
                if not (
                    GCSStorageClient is not None
                    and isinstance(self.storage_client, GCSStorageClient)
                )
                else None
            )
            await self.storage_client.upload_file(
                object_key=path,
                data=content,
                mime=element.mime or "application/octet-stream",
                overwrite=True,
                content_disposition=content_disposition,
            )

    else:
        # Log warning only if element has file content that needs uploading
        if element.path or element.url or element.content:
            logger.warning(
                "Data Layer: No storage client configured. "
                "File will not be uploaded."
            )

    # Always persist element metadata to database
    query = """
    INSERT INTO "Element" (
        id, "threadId", "stepId", metadata, mime, name, "objectKey", url,
        "chainlitKey", display, size, language, page, props
    ) VALUES (
        $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14
    )
    ON CONFLICT (id) DO UPDATE SET
        props = EXCLUDED.props
    """
    params = {
        "id": element.id,
        "thread_id": element.thread_id,
        "step_id": element.for_id,
        # Some fields are duplicated into this JSON blob alongside their
        # dedicated columns; readers use the blob's "type" key.
        "metadata": json.dumps(
            {
                "size": element.size,
                "language": element.language,
                "display": element.display,
                "type": element.type,
                "page": getattr(element, "page", None),
            }
        ),
        "mime": element.mime,
        "name": element.name,
        "object_key": path,
        "url": element.url,
        "chainlit_key": element.chainlit_key,
        "display": element.display,
        "size": element.size,
        "language": element.language,
        "page": getattr(element, "page", None),
        "props": json.dumps(getattr(element, "props", {})),
    }
    await self.execute_query(query, params)
|
|
|
|
async def get_element(
    self, thread_id: str, element_id: str
) -> Optional[ElementDict]:
    """Fetch one element by id within a thread.

    Args:
        thread_id: Thread the element must belong to.
        element_id: Primary key of the element.

    Returns:
        The element as an ElementDict, or None when not found.
    """
    query = """
    SELECT * FROM "Element"
    WHERE id = $1 AND "threadId" = $2
    """
    results = await self.execute_query(
        query, {"element_id": element_id, "thread_id": thread_id}
    )

    if not results:
        return None

    # Fix + consistency: reuse the shared row converter so NULL columns
    # are handled the same way as in get_thread(). The previous inline
    # mapping str()-cast nullable columns (url, objectKey), turning a
    # NULL into the literal string "None", and json.loads'd metadata
    # without guarding against a NULL value.
    return self._convert_element_row_to_dict(results[0])
|
|
|
|
@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
    """Delete an element row (optionally scoped to a thread) and, when a
    storage client is configured, its uploaded file as well."""
    rows = await self.execute_query(
        """
        SELECT * FROM "Element"
        WHERE id = $1
        """,
        {"id": element_id},
    )

    # Remove the stored object first, while we still know its key.
    if self.storage_client is not None and rows:
        object_key = rows[0]["objectKey"]
        if object_key:
            await self.storage_client.delete_file(object_key=object_key)

    delete_query = """
    DELETE FROM "Element"
    WHERE id = $1
    """
    delete_params = {"id": element_id}
    if thread_id:
        delete_query += ' AND "threadId" = $2'
        delete_params["thread_id"] = thread_id

    await self.execute_query(delete_query, delete_params)
|
|
|
|
@queue_until_user_message()
async def create_step(self, step_dict: StepDict):
    """Upsert a step row, first ensuring its thread and parent exist.

    A missing thread gets a placeholder row via update_thread(); a
    missing parent gets a placeholder step of type "run" (the CASE
    clause in the upsert lets a later real insert overwrite that
    placeholder type).
    """
    if step_dict.get("threadId"):
        thread_query = 'SELECT id FROM "Thread" WHERE id = $1'
        thread_results = await self.execute_query(
            thread_query, {"thread_id": step_dict["threadId"]}
        )
        if not thread_results:
            await self.update_thread(thread_id=step_dict["threadId"])

    if step_dict.get("parentId"):
        parent_query = 'SELECT id FROM "Step" WHERE id = $1'
        parent_results = await self.execute_query(
            parent_query, {"parent_id": step_dict["parentId"]}
        )
        if not parent_results:
            # Recursive call: create the placeholder parent first.
            await self.create_step(
                {
                    "id": step_dict["parentId"],
                    "metadata": {},
                    "type": "run",
                    "createdAt": step_dict.get("createdAt"),
                }
            )

    # On conflict, incoming NULLs never clobber stored values; metadata
    # only replaces a non-empty stored blob, and startTime keeps the
    # earliest of the two values.
    query = """
    INSERT INTO "Step" (
        id, "threadId", "parentId", input, metadata, name, output,
        type, "startTime", "endTime", "showInput", "isError"
    ) VALUES (
        $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12
    )
    ON CONFLICT (id) DO UPDATE SET
        "parentId" = COALESCE(EXCLUDED."parentId", "Step"."parentId"),
        input = COALESCE(EXCLUDED.input, "Step".input),
        metadata = CASE
            WHEN EXCLUDED.metadata <> '{}' THEN EXCLUDED.metadata
            ELSE "Step".metadata
        END,
        name = COALESCE(EXCLUDED.name, "Step".name),
        output = COALESCE(EXCLUDED.output, "Step".output),
        type = CASE
            WHEN EXCLUDED.type = 'run' THEN "Step".type
            ELSE EXCLUDED.type
        END,
        "threadId" = COALESCE(EXCLUDED."threadId", "Step"."threadId"),
        "endTime" = COALESCE(EXCLUDED."endTime", "Step"."endTime"),
        "startTime" = LEAST(EXCLUDED."startTime", "Step"."startTime"),
        "showInput" = COALESCE(EXCLUDED."showInput", "Step"."showInput"),
        "isError" = COALESCE(EXCLUDED."isError", "Step"."isError")
    """

    # Use the step's own createdAt (ISO string) when present, otherwise
    # "now", for both startTime and endTime.
    timestamp = await self.get_current_timestamp()
    created_at = step_dict.get("createdAt")
    if created_at:
        timestamp = datetime.strptime(created_at, ISO_FORMAT)

    params = {
        "id": step_dict["id"],
        "thread_id": step_dict.get("threadId"),
        "parent_id": step_dict.get("parentId"),
        "input": step_dict.get("input"),
        "metadata": json.dumps(step_dict.get("metadata", {})),
        "name": step_dict.get("name"),
        "output": step_dict.get("output"),
        "type": step_dict["type"],
        "start_time": timestamp,
        "end_time": timestamp,
        "show_input": str(step_dict.get("showInput", "json")),
        "is_error": step_dict.get("isError", False),
    }
    await self.execute_query(query, params)
|
|
|
|
@queue_until_user_message()
async def update_step(self, step_dict: StepDict):
    """Update a step; create_step's upsert handles insert and update alike."""
    await self.create_step(step_dict)
|
|
|
|
@queue_until_user_message()
async def delete_step(self, step_id: str):
    """Delete a step along with its dependent element and feedback rows."""
    params = {"step_id": step_id}
    # Children first, then the step itself, so no dangling references remain.
    for statement in (
        'DELETE FROM "Element" WHERE "stepId" = $1',
        'DELETE FROM "Feedback" WHERE "stepId" = $1',
        'DELETE FROM "Step" WHERE id = $1',
    ):
        await self.execute_query(statement, params)
|
|
|
|
async def get_thread_author(self, thread_id: str) -> str:
    """Return the identifier of the user owning *thread_id*.

    Raises:
        ValueError: when the thread does not exist (or has no joined user).
    """
    rows = await self.execute_query(
        """
        SELECT u.identifier
        FROM "Thread" t
        JOIN "User" u ON t."userId" = u.id
        WHERE t.id = $1
        """,
        {"thread_id": thread_id},
    )
    if rows:
        return rows[0]["identifier"]
    raise ValueError(f"Thread {thread_id} not found")
|
|
|
|
async def delete_thread(self, thread_id: str):
    """Delete a thread row after removing its elements' stored files.

    NOTE(review): only the Thread row is deleted here — presumably the
    schema cascades to Step/Element/Feedback rows; verify.
    """
    element_rows = await self.execute_query(
        """
        SELECT * FROM "Element"
        WHERE "threadId" = $1
        """,
        {"thread_id": thread_id},
    )

    storage = self.storage_client
    if storage is not None:
        # Drop each uploaded file before its DB row disappears.
        for row in element_rows:
            key = row["objectKey"]
            if key:
                await storage.delete_file(object_key=key)

    await self.execute_query(
        'DELETE FROM "Thread" WHERE id = $1', {"thread_id": thread_id}
    )
|
|
|
|
async def list_threads(
    self, pagination: Pagination, filters: ThreadFilter
) -> PaginatedResponse[ThreadDict]:
    """Return one page of non-deleted threads, newest "updatedAt" first.

    Cursor pagination: the cursor is a thread id; only rows strictly
    older than that thread's "updatedAt" are returned. One extra row is
    fetched to detect whether a next page exists.
    """
    # NOTE(review): the `total` subquery is selected but never read below.
    query = """
    SELECT
        t.*,
        u.identifier as user_identifier,
        (SELECT COUNT(*) FROM "Thread" WHERE "userId" = t."userId") as total
    FROM "Thread" t
    LEFT JOIN "User" u ON t."userId" = u.id
    WHERE t."deletedAt" IS NULL
    """
    params: Dict[str, Any] = {}
    param_count = 1  # next $n placeholder index

    if filters.search:
        query += f" AND t.name ILIKE ${param_count}"
        params["name"] = f"%{filters.search}%"
        param_count += 1

    if filters.userId:
        query += f' AND t."userId" = ${param_count}'
        params["user_id"] = filters.userId
        param_count += 1

    if pagination.cursor:
        query += f' AND t."updatedAt" < (SELECT "updatedAt" FROM "Thread" WHERE id = ${param_count})'
        params["cursor"] = pagination.cursor
        param_count += 1

    query += f' ORDER BY t."updatedAt" DESC LIMIT ${param_count}'
    # Fetch one row beyond the page size purely to detect a next page.
    params["limit"] = pagination.first + 1

    results = await self.execute_query(query, params)
    threads = results

    has_next_page = len(threads) > pagination.first
    if has_next_page:
        threads = threads[:-1]

    thread_dicts = []
    for thread in threads:
        thread_dict = ThreadDict(
            id=str(thread["id"]),
            # NOTE(review): createdAt is populated from "updatedAt" here,
            # while get_thread() uses "createdAt" — confirm intentional.
            createdAt=thread["updatedAt"].isoformat(),
            name=thread["name"],
            userId=str(thread["userId"]) if thread["userId"] else None,
            userIdentifier=thread["user_identifier"],
            # NOTE(review): raises if metadata is NULL — presumably the
            # column is NOT NULL; verify against the schema.
            metadata=json.loads(thread["metadata"]),
            steps=[],
            elements=[],
            tags=[],
        )
        thread_dicts.append(thread_dict)

    return PaginatedResponse(
        pageInfo=PageInfo(
            hasNextPage=has_next_page,
            startCursor=thread_dicts[0]["id"] if thread_dicts else None,
            endCursor=thread_dicts[-1]["id"] if thread_dicts else None,
        ),
        data=thread_dicts,
    )
|
|
|
|
async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
    """Load a full thread: its row, its steps (with joined feedback,
    ordered by start time) and its elements (with read URLs resolved
    when a storage client is configured).

    Returns:
        The assembled ThreadDict, or None when the thread is missing
        or soft-deleted.
    """
    query = """
    SELECT t.*, u.identifier as user_identifier
    FROM "Thread" t
    LEFT JOIN "User" u ON t."userId" = u.id
    WHERE t.id = $1 AND t."deletedAt" IS NULL
    """
    results = await self.execute_query(query, {"thread_id": thread_id})

    if not results:
        return None

    thread = results[0]

    # Get steps and related feedback
    steps_query = """
    SELECT s.*,
        f.id feedback_id,
        f.value feedback_value,
        f."comment" feedback_comment
    FROM "Step" s left join "Feedback" f on s.id = f."stepId"
    WHERE s."threadId" = $1
    ORDER BY "startTime"
    """
    steps_results = await self.execute_query(steps_query, {"thread_id": thread_id})

    # Get elements
    elements_query = """
    SELECT * FROM "Element"
    WHERE "threadId" = $1
    """
    elements_results = await self.execute_query(
        elements_query, {"thread_id": thread_id}
    )

    if self.storage_client is not None:
        # Fill in read URLs for stored files that have no explicit URL.
        for elem in elements_results:
            if not elem["url"] and elem["objectKey"]:
                elem["url"] = await self.storage_client.get_read_url(
                    object_key=elem["objectKey"],
                )

    return ThreadDict(
        id=str(thread["id"]),
        createdAt=thread["createdAt"].isoformat(),
        name=thread["name"],
        userId=str(thread["userId"]) if thread["userId"] else None,
        userIdentifier=thread["user_identifier"],
        metadata=json.loads(thread["metadata"]),
        steps=[self._convert_step_row_to_dict(step) for step in steps_results],
        elements=[
            self._convert_element_row_to_dict(elem) for elem in elements_results
        ],
        tags=[],
    )
|
|
|
|
async def update_thread(
    self,
    thread_id: str,
    name: Optional[str] = None,
    user_id: Optional[str] = None,
    metadata: Optional[Dict] = None,
    tags: Optional[List[str]] = None,
):
    """Upsert a thread row, writing only the fields that were provided.

    Metadata semantics: the incoming dict is merged over the stored
    one; a key whose incoming value is None is *deleted* from the
    stored metadata rather than written as null.
    """
    if self.show_logger:
        logger.info(f"asyncpg: update_thread, thread_id={thread_id}")

    # Prefer the explicit name, falling back to a "name" key carried
    # inside metadata; truncate() clamps it to 255 characters.
    thread_name = truncate(
        name
        if name is not None
        else (metadata.get("name") if metadata and "name" in metadata else None)
    )

    # Merge incoming metadata with existing metadata, deleting incoming keys with None values
    if metadata is not None:
        existing = await self.execute_query(
            'SELECT "metadata" FROM "Thread" WHERE id = $1',
            {"thread_id": thread_id},
        )
        base = {}
        if isinstance(existing, list) and existing:
            raw = existing[0].get("metadata") or {}
            # The column may come back as a JSON string or already decoded.
            if isinstance(raw, str):
                try:
                    base = json.loads(raw)
                except json.JSONDecodeError:
                    base = {}
            elif isinstance(raw, dict):
                base = raw
        to_delete = {k for k, v in metadata.items() if v is None}
        incoming = {k: v for k, v in metadata.items() if v is not None}
        base = {k: v for k, v in base.items() if k not in to_delete}
        metadata = {**base, **incoming}

    data = {
        "id": thread_id,
        "name": thread_name,
        "userId": user_id,
        "tags": tags,
        "metadata": json.dumps(metadata or {}),
        "updatedAt": datetime.now(),
    }

    # Remove None values so absent fields keep their stored value.
    data = {k: v for k, v in data.items() if v is not None}

    # Build the query dynamically based on available fields
    columns = [f'"{k}"' for k in data.keys()]
    placeholders = [f"${i + 1}" for i in range(len(data))]
    values = list(data.values())

    update_sets = [f'"{k}" = EXCLUDED."{k}"' for k in data.keys() if k != "id"]

    if update_sets:
        query = f"""
        INSERT INTO "Thread" ({", ".join(columns)})
        VALUES ({", ".join(placeholders)})
        ON CONFLICT (id) DO UPDATE
        SET {", ".join(update_sets)};
        """
    else:
        # Only the id was provided — nothing to update on conflict.
        query = f"""
        INSERT INTO "Thread" ({", ".join(columns)})
        VALUES ({", ".join(placeholders)})
        ON CONFLICT (id) DO NOTHING
        """

    # execute_query binds the dict's *values* positionally, in order.
    await self.execute_query(query, {str(i + 1): v for i, v in enumerate(values)})
|
|
|
|
def _extract_feedback_dict_from_step_row(self, row: Dict) -> Optional[FeedbackDict]:
    """Build a FeedbackDict from the joined feedback columns, if any."""
    if row["feedback_id"] is None:
        # The LEFT JOIN found no feedback for this step.
        return None
    return FeedbackDict(
        forId=row["id"],
        id=row["feedback_id"],
        value=row["feedback_value"],
        comment=row["feedback_comment"],
    )
|
|
|
|
def _convert_step_row_to_dict(self, row: Dict) -> StepDict:
    """Map a Step DB row (plus joined feedback columns) onto a StepDict.

    Args:
        row: A Step row dict including feedback_* columns from the
            LEFT JOIN in get_thread().
    """
    return StepDict(
        id=str(row["id"]),
        threadId=str(row["threadId"]) if row.get("threadId") else "",
        parentId=str(row["parentId"]) if row.get("parentId") else None,
        name=str(row.get("name")),
        type=row["type"],
        input=row.get("input", {}),
        output=row.get("output", {}),
        # Fix: a NULL metadata column returns None (dict.get's default
        # only covers a missing key) and would crash json.loads; use
        # `or "{}"`, matching the props handling in
        # _convert_element_row_to_dict.
        metadata=json.loads(row.get("metadata") or "{}"),
        createdAt=row["createdAt"].isoformat() if row.get("createdAt") else None,
        start=row["startTime"].isoformat() if row.get("startTime") else None,
        showInput=row.get("showInput"),
        isError=row.get("isError"),
        end=row["endTime"].isoformat() if row.get("endTime") else None,
        feedback=self._extract_feedback_dict_from_step_row(row),
    )
|
|
|
|
def _convert_element_row_to_dict(self, row: Dict) -> ElementDict:
    """Map an Element DB row onto an ElementDict.

    The element "type" lives inside the JSON metadata blob written by
    create_element(); it defaults to "file" when absent.
    """
    # Fix: NULL metadata comes back as None and would crash json.loads;
    # guard with `or "{}"` exactly as already done for props below.
    metadata = json.loads(row.get("metadata") or "{}")
    return ElementDict(
        id=str(row["id"]),
        threadId=str(row["threadId"]) if row.get("threadId") else None,
        type=metadata.get("type", "file"),
        url=row["url"],
        name=row["name"],
        mime=row["mime"],
        objectKey=row["objectKey"],
        forId=str(row["stepId"]),
        chainlitKey=row.get("chainlitKey"),
        display=row["display"],
        size=row["size"],
        language=row["language"],
        page=row["page"],
        autoPlay=row.get("autoPlay"),
        playerConfig=row.get("playerConfig"),
        props=json.loads(row.get("props") or "{}"),
    )
|
|
|
|
async def build_debug_url(self) -> str:
    """No debug URL is available for this data layer; return an empty string."""
    return ""
|
|
|
|
async def cleanup(self):
    """Close the connection pool (if any) and forget it."""
    if not self.pool:
        return
    logger.debug("Cleaning up connection pool")
    await self.pool.close()
    self.pool = None
|
|
|
|
async def close(self) -> None:
    """Release the storage client (when configured), then the DB pool."""
    client = self.storage_client
    if client:
        await client.close()
    await self.cleanup()
|
|
|
|
|
|
def truncate(text: Optional[str], max_length: int = 255) -> Optional[str]:
    """Clamp *text* to at most *max_length* characters; None passes through."""
    if text is None:
        return None
    return text[:max_length]
|