ai-station/.venv/lib/python3.12/site-packages/chainlit/session.py

359 lines
11 KiB
Python
Raw Normal View History

import asyncio
import json
import mimetypes
import re
import shutil
import uuid
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Literal, Optional, Union
import aiofiles
from chainlit.logger import logger
from chainlit.types import AskFileSpec, FileReference
if TYPE_CHECKING:
from mcp import ClientSession
from chainlit.config import ChainlitConfig
from chainlit.types import FileDict
from chainlit.user import PersistedUser, User
# Closed set of client identifiers accepted wherever a session's
# ``client_type`` is annotated in this module.
ClientType = Literal["webapp", "copilot", "teams", "slack", "discord"]
class JSONEncoderIgnoreNonSerializable(json.JSONEncoder):
    """JSON encoder that emits ``null`` for values the stock encoder rejects."""

    def default(self, o):
        """Return the parent's encoding of *o*, or ``None`` when it raises ``TypeError``."""
        try:
            encoded = super().default(o)
        except TypeError:
            # Not serializable: degrade to None instead of aborting the dump.
            encoded = None
        return encoded
def clean_metadata(metadata: Dict, max_size: int = 1048576) -> Dict:
    """Return a JSON-safe copy of *metadata*, redacted when it is too large.

    The metadata is round-tripped through JSON with
    ``JSONEncoderIgnoreNonSerializable``, so every value the encoder cannot
    serialize becomes ``None`` in the result.

    Args:
        metadata: Mapping to sanitise.
        max_size: Maximum allowed size, in bytes, of the UTF-8 encoded JSON
            representation of the cleaned metadata.

    Returns:
        The cleaned metadata, or a single-key ``{"message": ...}`` dict when
        the cleaned metadata is larger than *max_size*.
    """
    cleaned_metadata = json.loads(
        json.dumps(metadata, cls=JSONEncoderIgnoreNonSerializable, ensure_ascii=False)
    )

    # Measure the real UTF-8 footprint. With the default ensure_ascii=True every
    # non-ASCII character is counted as a 6-byte "\uXXXX" escape (12 bytes for
    # astral characters), which over-reports the size and triggers redaction of
    # payloads that are in fact below the limit.
    # "surrogatepass" keeps lone surrogates (which JSON allows inside strings)
    # from raising UnicodeEncodeError now that they are no longer escaped.
    serialized = json.dumps(cleaned_metadata, ensure_ascii=False)
    metadata_size = len(serialized.encode("utf-8", errors="surrogatepass"))

    if metadata_size > max_size:
        # Redact the metadata if it exceeds the maximum size
        cleaned_metadata = {
            "message": f"Metadata size exceeds the limit of {max_size} bytes. Redacted."
        }
    return cleaned_metadata
class BaseSession:
    """Base object holding the state common to every kind of session.

    It stores the identifiers, the user/auth information and the per-session
    file registry, and knows how to persist files under the session's own
    directory and how to produce the metadata dict that gets persisted.
    """

    # Set only when the session was created with an existing thread id.
    thread_id_to_resume: Optional[str] = None
    client_type: ClientType
    current_task: Optional[asyncio.Task] = None

    def __init__(
        self,
        # Id of the session
        id: str,
        client_type: ClientType,
        # Thread id
        thread_id: Optional[str],
        # Logged-in user information
        user: Optional[Union["User", "PersistedUser"]],
        # Logged-in user token
        token: Optional[str],
        # User specific environment variables. Empty if no user environment variables are required.
        user_env: Optional[Dict[str, str]],
        # WSGI environment variables for the connection request
        environ: Optional[dict[str, Any]] = None,
        # Chat profile selected before the session was created
        chat_profile: Optional[str] = None,
    ):
        """Initialise the session state.

        When *thread_id* is falsy a fresh UUID4 string is generated for
        ``self.thread_id`` and ``thread_id_to_resume`` keeps its class default
        of ``None``. ``user_env`` and ``environ`` default to empty dicts.
        """
        if thread_id:
            self.thread_id_to_resume = thread_id
        self.thread_id = thread_id or str(uuid.uuid4())
        self.user = user
        self.client_type = client_type
        self.token = token
        self.has_first_interaction = False
        self.user_env = user_env or {}
        self.environ = environ or {}
        self.chat_profile = chat_profile

        # Registry of files persisted through persist_file(), keyed by file id.
        self.files: Dict[str, "FileDict"] = {}
        self.files_spec: Dict[str, AskFileSpec] = {}

        self.id = id

        self.chat_settings: Dict[str, Any] = {}

    @property
    def files_dir(self):
        """Directory holding this session's files: ``FILES_DIRECTORY / self.id``."""
        # Imported lazily, as everywhere else in this class that needs chainlit.config.
        from chainlit.config import FILES_DIRECTORY

        return FILES_DIRECTORY / self.id

    async def persist_file(
        self,
        name: str,
        mime: str,
        path: Optional[str] = None,
        content: Optional[Union[bytes, str]] = None,
    ) -> FileReference:
        """Store a file in the session directory and register it in ``self.files``.

        The file is written under a freshly generated UUID4 name, with the
        extension ``mimetypes`` guesses for *mime* (if any). When both *path*
        and *content* are given, *path* takes precedence.

        Args:
            name: Display name recorded in the registry entry.
            mime: MIME type recorded in the entry and used to pick the extension.
            path: Path of an existing file to copy.
            content: Raw content to write; ``str`` is encoded as UTF-8. Empty
                content is accepted and produces an empty file.

        Returns:
            A reference dict ``{"id": <file id>}``.

        Raises:
            ValueError: If no path is given and *content* is ``None``.
        """
        # Only a *missing* content is an error; explicitly supplied empty
        # content (b"" / "") is a legitimate empty file.
        if not path and content is None:
            raise ValueError(
                "Either path or content must be provided to persist a file"
            )

        # parents=True so a missing parent directory does not abort the upload.
        self.files_dir.mkdir(parents=True, exist_ok=True)

        file_id = str(uuid.uuid4())

        file_path = self.files_dir / file_id

        file_extension = mimetypes.guess_extension(mime)

        if file_extension:
            file_path = file_path.with_suffix(file_extension)

        if path:
            # Copy the file from the given path in fixed-size chunks so that a
            # large file is never held in memory in one piece.
            chunk_size = 64 * 1024
            async with (
                aiofiles.open(path, "rb") as src,
                aiofiles.open(file_path, "wb") as dst,
            ):
                while chunk := await src.read(chunk_size):
                    await dst.write(chunk)
        elif content is not None:
            # Write the provided content to the file
            data = content.encode("utf-8") if isinstance(content, str) else content
            async with aiofiles.open(file_path, "wb") as buffer:
                await buffer.write(data)

        # Get the file size
        file_size = file_path.stat().st_size
        # Register the file's metadata (not its content) in the session
        self.files[file_id] = {
            "id": file_id,
            "path": file_path,
            "name": name,
            "type": mime,
            "size": file_size,
        }

        return {"id": file_id}

    def to_persistable(self) -> Dict:
        """Build the metadata dict to persist for this session.

        Starts from the entry registered for this session id in
        ``user_sessions`` (or an empty dict), records the current chat
        settings, chat profile and client type on it, and returns a version
        cleaned by ``clean_metadata``. Note that those three keys are written
        onto the dict obtained from ``user_sessions`` itself, not onto a copy.
        """
        from chainlit.config import config
        from chainlit.user_session import user_sessions

        user_session: Dict = user_sessions.get(self.id) or {}
        user_session["chat_settings"] = self.chat_settings
        user_session["chat_profile"] = self.chat_profile
        user_session["client_type"] = self.client_type

        # Check config setting for whether to persist user environment variables.
        # The copy is shallow: it only protects the top-level "env" key below.
        user_session_copy = user_session.copy()
        if not config.project.persist_user_env:
            # Remove user environment variables (API keys) before persisting to database
            user_session_copy["env"] = {}

        metadata = clean_metadata(user_session_copy)
        return metadata
class HTTPSession(BaseSession):
    """Session flavour for consuming Chainlit through its HTTP API (no websocket)."""

    def __init__(
        self,
        # Id of the session
        id: str,
        client_type: ClientType,
        # Thread id
        thread_id: Optional[str] = None,
        # Logged-in user information
        user: Optional[Union["User", "PersistedUser"]] = None,
        # Logged-in user token
        token: Optional[str] = None,
        user_env: Optional[Dict[str, str]] = None,
        # WSGI environment variables for the connection request
        environ: Optional[dict[str, Any]] = None,
    ):
        """Forward every argument unchanged to the base session initialiser."""
        base_kwargs = {
            "id": id,
            "client_type": client_type,
            "thread_id": thread_id,
            "user": user,
            "token": token,
            "user_env": user_env,
            "environ": environ,
        }
        super().__init__(**base_kwargs)

    async def delete(self):
        """Remove the session's files directory, if it exists."""
        session_dir = self.files_dir
        if not session_dir.is_dir():
            return
        shutil.rmtree(session_dir)
# Queue of deferred calls. WebsocketSession.flush_method_queue unpacks each
# entry as (method, target, args, kwargs) and awaits
# ``method(target, *args, **kwargs)``.
ThreadQueue = Deque[tuple[Callable, object, tuple, Dict]]
class WebsocketSession(BaseSession):
    """Internal web socket session object.

    A socket id is an ephemeral id that can't be used as a session id
    (as it is for instance regenerated after each reconnection).

    The session object is registered in two module-level mappings, one keyed
    by socket id and one keyed by the server-generated session id. This lets
    a session persist across socket reconnections while still being
    retrievable by socket id for convenience.
    """

    to_clear: bool = False

    # MCP connections of this session: name -> (client session, exit stack
    # that is closed when the session is deleted).
    mcp_sessions: dict[str, tuple["ClientSession", AsyncExitStack]]

    def __init__(
        self,
        # Id from the session cookie
        id: str,
        # Associated socket id
        socket_id: str,
        # Function to emit to the client
        emit: Callable[[str, Any], None],
        # Function to emit to the client and wait for a response
        emit_call: Callable[[Literal["ask", "call_fn"], Any, Optional[int]], Any],
        # User specific environment variables. Empty if no user environment variables are required.
        user_env: Dict[str, str],
        client_type: ClientType,
        # WSGI environment variables for the connection request
        environ: Optional[dict[str, Any]] = None,
        # Thread id
        thread_id: Optional[str] = None,
        # Logged-in user information
        user: Optional[Union["User", "PersistedUser"]] = None,
        # Logged-in user token
        token: Optional[str] = None,
        # Chat profile selected before the session was created
        chat_profile: Optional[str] = None,
    ):
        """Initialise the session and register it in both module-level registries."""
        super().__init__(
            id=id,
            thread_id=thread_id,
            user=user,
            token=token,
            user_env=user_env,
            client_type=client_type,
            chat_profile=chat_profile,
            environ=environ,
        )

        self.socket_id = socket_id
        self.emit_call = emit_call
        self.emit = emit

        self.restored = False

        self.thread_queues: Dict[str, ThreadQueue] = {}
        self.mcp_sessions = {}

        # Keep only the first language tag of the Accept-Language header;
        # fall back to "en-US" when the header is absent or does not match.
        match = (
            re.match(
                r"^\s*([a-zA-Z0-9-]+)", environ.get("HTTP_ACCEPT_LANGUAGE", "en-US")
            )
            if environ
            else None
        )
        self.language = match.group(1) if match else "en-US"

        # Must come after self.language: get_config() reads it.
        self.config: "ChainlitConfig" = self.get_config()

        ws_sessions_id[self.id] = self
        ws_sessions_sid[socket_id] = self

    def get_config(self) -> "ChainlitConfig":
        """Return the config for this session.

        The global config is returned when no chat profile is selected.
        Otherwise the previously computed ``self.config`` is reused if there
        is one; failing that, the selected profile's ``config_overrides`` (if
        any) are applied on top of the global config and the result is cached
        on ``self.config``.

        Resolving the profiles is best effort: any failure is logged and the
        global config is used instead.
        """
        from chainlit.config import config as global_config

        # If no chat profile, always fallback to global config
        if not self.chat_profile:
            return global_config

        # If already computed, use self.config. __init__ calls this method
        # before the attribute exists, hence getattr with a default.
        cached = getattr(self, "config", None)
        if cached:
            return cached

        # Try to compute overrides
        cfg = global_config
        set_chat_profiles = global_config.code.set_chat_profiles
        if set_chat_profiles:
            pending = None
            try:
                pending = set_chat_profiles(self.user, self.language)
                profiles = asyncio.get_event_loop().run_until_complete(pending)
                current_profile = next(
                    (p for p in profiles if p.name == self.chat_profile), None
                )
                overrides = (
                    getattr(current_profile, "config_overrides", None)
                    if current_profile
                    else None
                )
                if overrides:
                    cfg = global_config.with_overrides(overrides)
            except Exception as e:
                # run_until_complete() refuses to run when the loop is already
                # running; in that case the coroutine was never scheduled, so
                # close it to avoid a "never awaited" warning.
                if asyncio.iscoroutine(pending):
                    pending.close()
                # Surface the failure instead of silently dropping the overrides.
                logger.warning(
                    f"Could not resolve config overrides for chat profile "
                    f"{self.chat_profile!r}, using global config: {e}"
                )

        self.config = cfg
        return cfg

    def restore(self, new_socket_id: str):
        """Associate a new socket id to the session."""
        ws_sessions_sid.pop(self.socket_id, None)
        ws_sessions_sid[new_socket_id] = self
        self.socket_id = new_socket_id
        self.restored = True

    async def delete(self):
        """Delete the session.

        Removes the session's files directory if it exists, unregisters the
        session from both registries and closes the exit stack of every MCP
        session. Errors raised while closing an exit stack are logged and do
        not interrupt the remaining clean-up.
        """
        if self.files_dir.is_dir():
            shutil.rmtree(self.files_dir)
        ws_sessions_sid.pop(self.socket_id, None)
        ws_sessions_id.pop(self.id, None)

        # Iterate over a snapshot: aclose() yields control, and the mapping
        # must not change size under the iterator.
        for name, (_, exit_stack) in list(self.mcp_sessions.items()):
            try:
                await exit_stack.aclose()
            except Exception as e:
                # Best-effort clean-up: record the failure and keep going.
                logger.debug(f"Error while closing MCP session {name}: {e}")

    async def flush_method_queue(self):
        """Run every queued call of every thread queue, in queue order.

        Each entry is awaited as ``method(target, *args, **kwargs)``. A failing
        call is logged and does not stop the remaining entries.
        """
        for method_name, queue in self.thread_queues.items():
            while queue:
                # ``target`` is the object stored in the queue entry; it is
                # deliberately not bound to ``self``.
                method, target, args, kwargs = queue.popleft()
                try:
                    await method(target, *args, **kwargs)
                except Exception as e:
                    logger.error(f"Error while flushing {method_name}: {e}")

    @classmethod
    def get(cls, socket_id: str):
        """Get session by socket id."""
        return ws_sessions_sid.get(socket_id)

    @classmethod
    def get_by_id(cls, session_id: str):
        """Get session by session id."""
        return ws_sessions_id.get(session_id)

    @classmethod
    def require(cls, socket_id: str):
        """Return the session for *socket_id*; raise ``ValueError`` if there is none."""
        if session := cls.get(socket_id):
            return session
        raise ValueError("Session not found")
# Registries of live websocket sessions, filled by WebsocketSession.__init__
# and emptied by WebsocketSession.delete.
# Keyed by the current socket id (re-keyed by WebsocketSession.restore).
ws_sessions_sid: Dict[str, WebsocketSession] = {}
# Keyed by the session id.
ws_sessions_id: Dict[str, WebsocketSession] = {}