359 lines
11 KiB
Python
359 lines
11 KiB
Python
import asyncio
|
|
import json
|
|
import mimetypes
|
|
import re
|
|
import shutil
|
|
import uuid
|
|
from contextlib import AsyncExitStack
|
|
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Literal, Optional, Union
|
|
|
|
import aiofiles
|
|
|
|
from chainlit.logger import logger
|
|
from chainlit.types import AskFileSpec, FileReference
|
|
|
|
if TYPE_CHECKING:
|
|
from mcp import ClientSession
|
|
|
|
from chainlit.config import ChainlitConfig
|
|
from chainlit.types import FileDict
|
|
from chainlit.user import PersistedUser, User
|
|
|
|
# Allowed values for a session's `client_type` attribute.
ClientType = Literal["webapp", "copilot", "teams", "slack", "discord"]
|
|
|
|
|
|
class JSONEncoderIgnoreNonSerializable(json.JSONEncoder):
    """JSON encoder that emits ``null`` for objects the stock encoder rejects."""

    def default(self, o):
        """Return the parent's encoding of ``o``, or ``None`` on ``TypeError``."""
        try:
            encoded = super().default(o)
        except TypeError:
            # The parent encoder cannot serialize this object: encode it as null.
            encoded = None
        return encoded
|
|
|
|
|
|
def clean_metadata(metadata: Dict, max_size: int = 1048576):
    """Return a JSON-safe copy of ``metadata``, redacted if it is too large.

    Args:
        metadata: Mapping to sanitize. Values the JSON encoder cannot handle
            are replaced by ``None`` (via ``JSONEncoderIgnoreNonSerializable``).
        max_size: Maximum allowed size, in bytes, of the re-serialized metadata.

    Returns:
        The sanitized metadata, or a single-key ``{"message": ...}`` dict when
        the serialized size exceeds ``max_size``.
    """
    # Round-trip through JSON so that only plain JSON types remain.
    serialized = json.dumps(
        metadata, cls=JSONEncoderIgnoreNonSerializable, ensure_ascii=False
    )
    cleaned_metadata = json.loads(serialized)

    # The size is measured on a default (ASCII-escaped) dump of the cleaned data.
    encoded = json.dumps(cleaned_metadata).encode("utf-8")
    if len(encoded) <= max_size:
        return cleaned_metadata

    # Too large: hand back a short notice instead of the payload.
    return {
        "message": f"Metadata size exceeds the limit of {max_size} bytes. Redacted."
    }
|
|
|
|
|
|
class BaseSession:
    """Base object."""

    thread_id_to_resume: Optional[str] = None
    client_type: ClientType
    current_task: Optional[asyncio.Task] = None

    def __init__(
        self,
        id: str,
        client_type: ClientType,
        thread_id: Optional[str],
        user: Optional[Union["User", "PersistedUser"]],
        token: Optional[str],
        user_env: Optional[Dict[str, str]],
        environ: Optional[dict[str, Any]] = None,
        chat_profile: Optional[str] = None,
    ):
        """Initialize the state shared by every session kind.

        Args:
            id: Id of the session.
            client_type: Kind of client the session belongs to.
            thread_id: Thread id. When truthy it is also recorded as the
                thread to resume; otherwise a fresh UUID4 is generated.
            user: Logged-in user information.
            token: Logged-in user token.
            user_env: User specific environment variables. Empty if no user
                environment variables are required.
            environ: WSGI environment variables for the connection request.
            chat_profile: Chat profile selected before the session was created.
        """
        self.id = id
        self.client_type = client_type
        self.user = user
        self.token = token
        self.chat_profile = chat_profile

        # Only an explicitly supplied thread id marks a thread to resume.
        if thread_id:
            self.thread_id_to_resume = thread_id
        self.thread_id = thread_id or str(uuid.uuid4())

        self.has_first_interaction = False

        # Missing mappings are normalized to empty dicts.
        self.user_env = user_env or {}
        self.environ = environ or {}

        # Metadata of persisted files, keyed by file id, plus file specs.
        self.files: Dict[str, FileDict] = {}
        self.files_spec: Dict[str, AskFileSpec] = {}

        self.chat_settings: Dict[str, Any] = {}

    @property
    def files_dir(self):
        """Directory under ``FILES_DIRECTORY`` named after the session id."""
        from chainlit.config import FILES_DIRECTORY

        return FILES_DIRECTORY / self.id

    async def persist_file(
        self,
        name: str,
        mime: str,
        path: Optional[str] = None,
        content: Optional[Union[bytes, str]] = None,
    ) -> FileReference:
        """Write a file into the session directory and register its metadata.

        The data is copied from ``path`` when it is given; otherwise
        ``content`` is written (``str`` content is encoded as UTF-8).

        Args:
            name: Name recorded for the file.
            mime: MIME type recorded for the file; also used to guess the
                extension of the stored file.
            path: Path of an existing file to copy.
            content: Raw content to write when no ``path`` is given.

        Returns:
            A dict holding the generated file id under ``"id"``.

        Raises:
            ValueError: If both ``path`` and ``content`` are missing or empty.
        """
        if not (path or content):
            raise ValueError(
                "Either path or content must be provided to persist a file"
            )

        session_dir = self.files_dir
        session_dir.mkdir(exist_ok=True)

        file_id = str(uuid.uuid4())
        destination = session_dir / file_id

        # Append an extension when one can be guessed from the MIME type.
        suffix = mimetypes.guess_extension(mime)
        if suffix:
            destination = destination.with_suffix(suffix)

        if path:
            # Copy the file from the given path.
            async with aiofiles.open(path, "rb") as src:
                async with aiofiles.open(destination, "wb") as dst:
                    await dst.write(await src.read())
        elif content:
            # Write the provided content to the file.
            async with aiofiles.open(destination, "wb") as buffer:
                if isinstance(content, str):
                    content = content.encode("utf-8")
                await buffer.write(content)

        # Keep the file's metadata (not its content) on the session.
        self.files[file_id] = {
            "id": file_id,
            "path": destination,
            "name": name,
            "type": mime,
            "size": destination.stat().st_size,
        }

        return {"id": file_id}

    def to_persistable(self) -> Dict:
        """Build the session metadata to persist.

        The user session entry fetched from ``user_sessions`` (when it exists
        and is non-empty) is updated in place with the chat settings, chat
        profile and client type. A shallow copy is then cleaned with
        ``clean_metadata``; its ``"env"`` entry is blanked unless
        ``config.project.persist_user_env`` is enabled.

        Returns:
            The cleaned metadata dict.
        """
        from chainlit.config import config
        from chainlit.user_session import user_sessions

        user_session = user_sessions.get(self.id) or {}  # type: Dict
        user_session.update(
            chat_settings=self.chat_settings,
            chat_profile=self.chat_profile,
            client_type=self.client_type,
        )

        persisted = user_session.copy()
        if not config.project.persist_user_env:
            # Remove user environment variables (API keys) before persisting
            # to the database.
            persisted["env"] = {}

        return clean_metadata(persisted)
|
|
|
|
|
|
class HTTPSession(BaseSession):
    """Internal HTTP session object. Used to consume Chainlit through API (no websocket)."""

    def __init__(
        self,
        id: str,
        client_type: ClientType,
        thread_id: Optional[str] = None,
        user: Optional[Union["User", "PersistedUser"]] = None,
        token: Optional[str] = None,
        user_env: Optional[Dict[str, str]] = None,
        environ: Optional[dict[str, Any]] = None,
    ):
        """Create an HTTP session; every argument is forwarded to ``BaseSession``.

        Args:
            id: Id of the session.
            client_type: Kind of client the session belongs to.
            thread_id: Thread id.
            user: Logged-in user information.
            token: Logged-in user token.
            user_env: User specific environment variables.
            environ: WSGI environment variables for the connection request.
        """
        super().__init__(
            id=id,
            client_type=client_type,
            thread_id=thread_id,
            user=user,
            token=token,
            user_env=user_env,
            environ=environ,
        )

    async def delete(self):
        """Delete the session."""
        session_dir = self.files_dir
        if not session_dir.is_dir():
            return
        # Remove every file persisted for this session.
        shutil.rmtree(session_dir)
|
|
|
|
|
|
# Queue of deferred calls. WebsocketSession.flush_method_queue unpacks each
# entry as (method, receiver, args, kwargs) and awaits
# method(receiver, *args, **kwargs).
ThreadQueue = Deque[tuple[Callable, object, tuple, Dict]]
|
|
|
|
|
|
class WebsocketSession(BaseSession):
    """Internal web socket session object.

    A socket id is an ephemeral id that can't be used as a session id
    (as it is for instance regenerated after each reconnection).

    The session object is registered in two module-level mappings, one keyed
    by socket id and one keyed by the server generated session id. This
    allows a session to persist across socket reconnections while still
    being retrievable by socket id for convenience.
    """

    to_clear: bool = False

    mcp_sessions: dict[str, tuple["ClientSession", AsyncExitStack]]

    def __init__(
        self,
        # Id from the session cookie
        id: str,
        # Associated socket id
        socket_id: str,
        # Function to emit to the client
        emit: Callable[[str, Any], None],
        # Function to emit to the client and wait for a response
        emit_call: Callable[[Literal["ask", "call_fn"], Any, Optional[int]], Any],
        # User specific environment variables. Empty if no user environment variables are required.
        user_env: Dict[str, str],
        client_type: ClientType,
        # WSGI environment variables for the connection request
        environ: Optional[dict[str, Any]] = None,
        # Thread id
        thread_id: Optional[str] = None,
        # Logged-in user information
        user: Optional[Union["User", "PersistedUser"]] = None,
        # Logged-in user token
        token: Optional[str] = None,
        # Chat profile selected before the session was created
        chat_profile: Optional[str] = None,
    ):
        """Create the session and register it under its session and socket ids."""
        super().__init__(
            id=id,
            thread_id=thread_id,
            user=user,
            token=token,
            user_env=user_env,
            client_type=client_type,
            chat_profile=chat_profile,
            environ=environ,
        )

        self.socket_id = socket_id
        self.emit_call = emit_call
        self.emit = emit

        self.restored = False

        self.thread_queues: Dict[str, ThreadQueue] = {}
        self.mcp_sessions = {}

        # Use the first language tag of the Accept-Language header, falling
        # back to "en-US" when there is no environ or no usable tag.
        match = (
            re.match(
                r"^\s*([a-zA-Z0-9-]+)", environ.get("HTTP_ACCEPT_LANGUAGE", "en-US")
            )
            if environ
            else None
        )
        self.language = match.group(1) if match else "en-US"

        self.config: ChainlitConfig = self.get_config()

        ws_sessions_id[self.id] = self
        ws_sessions_sid[socket_id] = self

    def get_config(self) -> "ChainlitConfig":
        """Return the config for this session.

        The global config is returned when no chat profile is selected. With a
        chat profile, a previously computed ``self.config`` is reused;
        otherwise the profiles are loaded through the ``set_chat_profiles``
        callback and the matching profile's ``config_overrides`` (if any) are
        applied on top of the global config. The result is cached on
        ``self.config``.

        Resolving overrides is best effort: any failure is logged and the
        global config is used instead.

        NOTE(review): ``run_until_complete`` raises ``RuntimeError`` when the
        event loop is already running, so overrides cannot be resolved from
        inside a running loop - confirm how callers invoke this method.
        """
        from chainlit.config import config as global_config

        # If no chat profile, always fall back to the global config.
        if not self.chat_profile:
            return global_config

        # If already computed, reuse the cached config.
        cached = getattr(self, "config", None)
        if cached:
            return cached

        # Try to compute overrides.
        cfg = global_config
        if global_config.code.set_chat_profiles:
            try:
                profiles = asyncio.get_event_loop().run_until_complete(
                    global_config.code.set_chat_profiles(self.user, self.language)
                )
                current_profile = next(
                    (p for p in profiles if p.name == self.chat_profile), None
                )
                if current_profile and getattr(
                    current_profile, "config_overrides", None
                ):
                    cfg = global_config.with_overrides(current_profile.config_overrides)
            except Exception as e:
                # Keep the fallback to the global config, but do not hide why
                # the overrides were not applied.
                logger.warning(
                    f"Could not resolve config overrides for chat profile "
                    f"{self.chat_profile!r}, using global config: {e}"
                )
        self.config = cfg
        return cfg

    def restore(self, new_socket_id: str):
        """Associate a new socket id to the session."""
        ws_sessions_sid.pop(self.socket_id, None)
        ws_sessions_sid[new_socket_id] = self
        self.socket_id = new_socket_id
        self.restored = True

    async def delete(self):
        """Delete the session.

        Removes the session's files directory (if present), unregisters the
        session from both registries and closes the exit stack of every MCP
        session. A failure to close one exit stack is logged and does not
        prevent the remaining ones from being closed.
        """
        if self.files_dir.is_dir():
            shutil.rmtree(self.files_dir)
        ws_sessions_sid.pop(self.socket_id, None)
        ws_sessions_id.pop(self.id, None)

        for key, (_, exit_stack) in self.mcp_sessions.items():
            try:
                await exit_stack.aclose()
            except Exception as e:
                logger.warning(f"Error while closing MCP session {key!r}: {e}")

    async def flush_method_queue(self):
        """Run every queued call of every thread queue, in FIFO order per queue.

        Each queue entry is a ``(method, receiver, args, kwargs)`` tuple that
        is awaited as ``method(receiver, *args, **kwargs)``. Errors raised by
        a call are logged and do not stop the flush.
        """
        # Iterate over a snapshot: awaiting a queued call hands control back
        # to the event loop, and a queue registered meanwhile would otherwise
        # break iteration over the live dict.
        for method_name, queue in list(self.thread_queues.items()):
            while queue:
                # Do not rebind `self` here; the receiver gets its own name.
                method, receiver, args, kwargs = queue.popleft()
                try:
                    await method(receiver, *args, **kwargs)
                except Exception as e:
                    logger.error(f"Error while flushing {method_name}: {e}")

    @classmethod
    def get(cls, socket_id: str):
        """Get session by socket id."""
        return ws_sessions_sid.get(socket_id)

    @classmethod
    def get_by_id(cls, session_id: str):
        """Get session by session id."""
        return ws_sessions_id.get(session_id)

    @classmethod
    def require(cls, socket_id: str):
        """Return the session for ``socket_id``.

        Raises:
            ValueError: If no session is registered for this socket id.
        """
        if session := cls.get(socket_id):
            return session
        raise ValueError("Session not found")
|
|
|
|
|
|
# Registry of websocket sessions keyed by their current socket id.
ws_sessions_sid: Dict[str, WebsocketSession] = {}
|
|
# Registry of websocket sessions keyed by their session id.
ws_sessions_id: Dict[str, WebsocketSession] = {}
|