112 lines
2.5 KiB
Python
112 lines
2.5 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
|
|
|
from chainlit.types import (
|
|
Feedback,
|
|
PaginatedResponse,
|
|
Pagination,
|
|
ThreadDict,
|
|
ThreadFilter,
|
|
)
|
|
|
|
from .utils import queue_until_user_message
|
|
|
|
if TYPE_CHECKING:
|
|
from chainlit.element import Element, ElementDict
|
|
from chainlit.step import StepDict
|
|
from chainlit.user import PersistedUser, User
|
|
|
|
|
|
class BaseDataLayer(ABC):
|
|
"""Base class for data persistence."""
|
|
|
|
@abstractmethod
|
|
async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def create_user(self, user: "User") -> Optional["PersistedUser"]:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def delete_feedback(
|
|
self,
|
|
feedback_id: str,
|
|
) -> bool:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def upsert_feedback(
|
|
self,
|
|
feedback: Feedback,
|
|
) -> str:
|
|
pass
|
|
|
|
@queue_until_user_message()
|
|
@abstractmethod
|
|
async def create_element(self, element: "Element"):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_element(
|
|
self, thread_id: str, element_id: str
|
|
) -> Optional["ElementDict"]:
|
|
pass
|
|
|
|
@queue_until_user_message()
|
|
@abstractmethod
|
|
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
|
|
pass
|
|
|
|
@queue_until_user_message()
|
|
@abstractmethod
|
|
async def create_step(self, step_dict: "StepDict"):
|
|
pass
|
|
|
|
@queue_until_user_message()
|
|
@abstractmethod
|
|
async def update_step(self, step_dict: "StepDict"):
|
|
pass
|
|
|
|
@queue_until_user_message()
|
|
@abstractmethod
|
|
async def delete_step(self, step_id: str):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_thread_author(self, thread_id: str) -> str:
|
|
return ""
|
|
|
|
@abstractmethod
|
|
async def delete_thread(self, thread_id: str):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def list_threads(
|
|
self, pagination: "Pagination", filters: "ThreadFilter"
|
|
) -> "PaginatedResponse[ThreadDict]":
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def update_thread(
|
|
self,
|
|
thread_id: str,
|
|
name: Optional[str] = None,
|
|
user_id: Optional[str] = None,
|
|
metadata: Optional[Dict] = None,
|
|
tags: Optional[List[str]] = None,
|
|
):
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def build_debug_url(self) -> str:
|
|
pass
|
|
|
|
@abstractmethod
|
|
async def close(self) -> None:
|
|
pass
|