# Socket.IO event handlers that create, restore and tear down Chainlit
# websocket sessions and forward client events to the user's callbacks.

import asyncio
import json
from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union
from urllib.parse import unquote

from starlette.requests import cookie_parser
from typing_extensions import TypeAlias

from chainlit.auth import (
    get_current_user,
    get_token_from_cookies,
    require_login,
)
from chainlit.chat_context import chat_context
from chainlit.config import ChainlitConfig, config
from chainlit.context import init_ws_context
from chainlit.data import get_data_layer
from chainlit.logger import logger
from chainlit.message import ErrorMessage, Message
from chainlit.server import sio
from chainlit.session import ClientType, WebsocketSession
from chainlit.types import (
    InputAudioChunk,
    InputAudioChunkPayload,
    MessagePayload,
)
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_sessions

WSGIEnvironment: TypeAlias = dict[str, Any]

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task nobody references may be garbage collected
# before it finishes; each task removes itself from this set when done.
_background_tasks: set[asyncio.Task] = set()


def _spawn_background(coro: Any) -> asyncio.Task:
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# Auth payload the client sends along with the Socket.IO "connect" event.
class WebSocketSessionAuth(TypedDict):
    sessionId: str
    userEnv: str | None
    clientType: ClientType
    chatProfile: str | None
    threadId: str | None


def restore_existing_session(sid, session_id, emit_fn, emit_call_fn, environ):
    """Restore a session from the sessionId provided by the client.

    If a WebsocketSession with ``session_id`` exists, it is rebound to the new
    socket id ``sid`` and given emit callbacks scoped to that socket.

    Returns:
        True if an existing session was restored, False otherwise.
    """
    if session := WebsocketSession.get_by_id(session_id):
        session.restore(new_socket_id=sid)
        session.emit = emit_fn
        session.emit_call = emit_call_fn
        session.environ = environ
        return True
    return False


async def persist_user_session(thread_id: str, metadata: Dict):
    """Store ``metadata`` on the thread via the data layer, if one is configured."""
    if data_layer := get_data_layer():
        await data_layer.update_thread(thread_id=thread_id, metadata=metadata)


async def resume_thread(session: WebsocketSession):
    """Load the thread the session asked to resume and restore its state.

    The thread is only returned when the session's user is its author. In
    that case the thread metadata is copied into ``user_sessions``, and the
    stored chat profile / chat settings (if any) are applied to the session.

    Returns:
        The thread as returned by the data layer, or None when there is no
        data layer, no user, no thread to resume, no such thread, or the user
        is not the thread's author.
    """
    data_layer = get_data_layer()
    if not data_layer or not session.user or not session.thread_id_to_resume:
        return None

    thread = await data_layer.get_thread(thread_id=session.thread_id_to_resume)
    if not thread:
        return None

    if thread.get("userIdentifier") != session.user.identifier:
        return None

    metadata = thread.get("metadata") or {}
    # Some data layers store metadata as a JSON string.
    if isinstance(metadata, str):
        metadata = json.loads(metadata)

    user_sessions[session.id] = metadata.copy()
    if chat_profile := metadata.get("chat_profile"):
        session.chat_profile = chat_profile
    if chat_settings := metadata.get("chat_settings"):
        session.chat_settings = chat_settings

    return thread


def load_user_env(user_env):
    """Decode and validate the user environment sent by the client.

    Args:
        user_env: JSON-encoded object of environment variables, or a falsy
            value when the client sent none.

    Returns:
        The decoded dict, or None when ``user_env`` is falsy (and no
        variables are required by ``config.project.user_env``).

    Raises:
        ConnectionRefusedError: If the payload is not a JSON object, or if a
            variable listed in ``config.project.user_env`` is missing.
    """
    user_env_dict = None
    if user_env:
        try:
            user_env_dict = json.loads(user_env)
        except (TypeError, ValueError) as e:
            raise ConnectionRefusedError("Invalid user environment variables") from e
        if not isinstance(user_env_dict, dict):
            raise ConnectionRefusedError("Invalid user environment variables")

    # Validate even when the client sent nothing, so required variables
    # cannot be bypassed by omitting the payload.
    if config.project.user_env:
        if not user_env_dict:
            raise ConnectionRefusedError("Missing user environment variables")
        # Check if requested user environment variables are provided
        for key in config.project.user_env:
            if key not in user_env_dict:
                raise ConnectionRefusedError(
                    "Missing user environment variable: " + key
                )

    return user_env_dict


def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]:
    """Extract the access token from the request's Cookie header, if present."""
    if cookie_header := environ.get("HTTP_COOKIE", None):
        cookies = cookie_parser(cookie_header)
        return get_token_from_cookies(cookies)
    return None


def _get_token(environ: WSGIEnvironment) -> Optional[str]:
    """Take WSGI environ, return access token (currently read from cookies only)."""
    return _get_token_from_cookie(environ)


async def _authenticate_connection(
    environ: WSGIEnvironment,
) -> Union[Tuple[Union[User, PersistedUser], str], Tuple[None, None]]:
    """Resolve the connecting user from the token in ``environ``.

    Returns:
        ``(user, token)`` when a token is present and maps to a user,
        otherwise ``(None, None)``.
    """
    if token := _get_token(environ):
        user = await get_current_user(token=token)
        if user:
            return user, token
    return None, None


@sio.on("connect")  # pyright: ignore [reportOptionalCall]
async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth):
    """Authenticate a new socket and create (or restore) its WebsocketSession.

    When login is required, the user is resolved from the request token, and
    a requested ``threadId`` must belong to that user. If the client's
    ``sessionId`` matches an existing session, that session is rebound to
    this socket; otherwise a new WebsocketSession is created.

    Raises:
        ConnectionRefusedError: On a missing auth payload, failed
            authentication, failed thread authorization, or an invalid user
            environment.
    """
    # python-socketio passes None when the client sent no auth data.
    if not auth:
        raise ConnectionRefusedError("missing auth payload")

    user: User | PersistedUser | None = None
    token: str | None = None
    thread_id = auth.get("threadId", None)

    if require_login():
        try:
            user, token = await _authenticate_connection(environ)
        except Exception as e:
            logger.exception("Exception authenticating connection: %s", e)
        if not user:
            logger.error("Authentication failed in websocket connect.")
            raise ConnectionRefusedError("authentication failed")

        # Only the thread's author may connect to it.
        if thread_id:
            if data_layer := get_data_layer():
                thread = await data_layer.get_thread(thread_id)
                if thread and thread["userIdentifier"] != user.identifier:
                    logger.error("Authorization for the thread failed.")
                    raise ConnectionRefusedError("authorization failed")

    # Session scoped function to emit to the client
    def emit_fn(event, data):
        return sio.emit(event, data, to=sid)

    # Session scoped function to emit to the client and wait for a response
    def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
        return sio.call(event, data, timeout=timeout, to=sid)

    session_id = auth["sessionId"]
    if restore_existing_session(sid, session_id, emit_fn, emit_call_fn, environ):
        return True

    user_env = load_user_env(auth.get("userEnv", None))
    client_type = auth["clientType"]

    # The chat profile name arrives URL-encoded.
    url_encoded_chat_profile = auth.get("chatProfile", None)
    chat_profile = (
        unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
    )

    WebsocketSession(
        id=session_id,
        socket_id=sid,
        emit=emit_fn,
        emit_call=emit_call_fn,
        client_type=client_type,
        user_env=user_env,
        user=user,
        token=token,
        chat_profile=chat_profile,
        thread_id=thread_id,
        environ=environ,
    )
    return True


@sio.on("connection_successful")  # pyright: ignore [reportOptionalCall]
async def connection_successful(sid):
    """Start or resume the chat once the client confirms the connection.

    - A restored session without a first interaction re-runs
      ``on_chat_start`` (if defined) and stops there.
    - If a thread resume was requested and ``on_chat_resume`` is defined, the
      thread is loaded. On success its message steps are re-added to the
      chat context and sent to the client. Otherwise a resume error is sent
      and execution falls through.
    - Otherwise ``on_chat_start`` (if defined) runs as the session's current
      task.
    """
    context = init_ws_context(sid)

    # Reset any UI state left over from a previous connection.
    await context.emitter.task_end()
    await context.emitter.clear("clear_ask")
    await context.emitter.clear("clear_call_fn")

    if context.session.restored and not context.session.has_first_interaction:
        if config.code.on_chat_start:
            task = asyncio.create_task(config.code.on_chat_start())
            context.session.current_task = task
        return

    if context.session.thread_id_to_resume and config.code.on_chat_resume:
        thread = await resume_thread(context.session)
        if thread:
            context.session.has_first_interaction = True
            await context.emitter.emit(
                "first_interaction",
                {"interaction": "resume", "thread_id": thread.get("id")},
            )
            await config.code.on_chat_resume(thread)

            for step in thread.get("steps", []):
                if "message" in step["type"]:
                    chat_context.add(Message.from_dict(step))

            await context.emitter.resume_thread(thread)
            return

        await context.emitter.send_resume_thread_error("Thread not found.")

    if config.code.on_chat_start:
        task = asyncio.create_task(config.code.on_chat_start())
        context.session.current_task = task


@sio.on("clear_session")  # pyright: ignore [reportOptionalCall]
async def clean_session(sid):
    """Flag the session so it is deleted immediately on disconnect."""
    session = WebsocketSession.get(sid)
    if session:
        session.to_clear = True


@sio.on("disconnect")  # pyright: ignore [reportOptionalCall]
async def disconnect(sid):
    """Run ``on_chat_end``, persist the session and schedule its cleanup.

    Sessions flagged via "clear_session" are deleted immediately; others are
    deleted after ``config.project.session_timeout`` seconds, giving the
    client a window to reconnect and restore them. Errors from
    ``on_chat_end`` or persistence are logged so cleanup always happens.
    """
    session = WebsocketSession.get(sid)
    if not session:
        return

    init_ws_context(session)

    try:
        if config.code.on_chat_end:
            await config.code.on_chat_end()
    except Exception as e:
        logger.exception("Error in on_chat_end: %s", e)

    if session.thread_id and session.has_first_interaction:
        try:
            await persist_user_session(session.thread_id, session.to_persistable())
        except Exception as e:
            logger.exception("Failed to persist user session: %s", e)

    async def clear(_sid):
        # Looked up again: after a restore the session has a new socket id,
        # so a stale sid no longer finds it and nothing is deleted.
        if stale := WebsocketSession.get(_sid):
            # Clean up the user session
            user_sessions.pop(stale.id, None)
            # Clean up the session
            await stale.delete()

    if session.to_clear:
        await clear(sid)
    else:

        async def clear_on_timeout(_sid):
            await asyncio.sleep(config.project.session_timeout)
            await clear(_sid)

        _spawn_background(clear_on_timeout(sid))


@sio.on("stop")  # pyright: ignore [reportOptionalCall]
async def stop(sid):
    """Cancel the session's current task and notify the user."""
    if session := WebsocketSession.get(sid):
        init_ws_context(session)
        await Message(content="Task manually stopped.").send()

        if session.current_task:
            session.current_task.cancel()

        if config.code.on_stop:
            await config.code.on_stop()


async def process_message(session: WebsocketSession, payload: MessagePayload):
    """Process a message from the user.

    Cancellation is swallowed; other exceptions are logged and reported to
    the user as an ErrorMessage. ``task_end`` is always emitted.
    """
    # Bound before the try so the finally clause can always use it.
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()
        message = await context.emitter.process_message(payload)

        if config.code.on_message:
            # NOTE(review): presumably yields to the event loop so pending
            # emits go out before user code runs — confirm.
            await asyncio.sleep(0.001)
            await config.code.on_message(message)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()


@sio.on("edit_message")  # pyright: ignore [reportOptionalCall]
async def edit_message(sid, payload: MessagePayload):
    """Handle an edit of a previous user message.

    The matching message gets the new content, every later message in the
    chat context is removed, and ``on_message`` is re-run on the edited
    message (if both exist). ``task_end`` is always emitted after
    ``task_start``.
    """
    session = WebsocketSession.require(sid)
    context = init_ws_context(session)

    # Iterate a snapshot, since messages are removed while looping.
    messages = list(chat_context.get())

    orig_message = None
    for msg in messages:
        if orig_message:
            await msg.remove()

        if msg.id == payload["message"]["id"]:
            msg.content = payload["message"]["output"]
            await msg.update()
            orig_message = msg

    await context.emitter.task_start()
    try:
        if config.code.on_message and orig_message:
            await config.code.on_message(orig_message)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()


@sio.on("client_message")  # pyright: ignore [reportOptionalCall]
async def message(sid, payload: MessagePayload):
    """Handle a message sent by the User."""
    session = WebsocketSession.require(sid)

    # Stored on the session so the "stop" event can cancel it.
    task = asyncio.create_task(process_message(session, payload))
    session.current_task = task


@sio.on("window_message")  # pyright: ignore [reportOptionalCall]
async def window_message(sid, data):
    """Handle a message send by the host window."""
    session = WebsocketSession.require(sid)
    init_ws_context(session)

    if config.code.on_window_message:
        try:
            await config.code.on_window_message(data)
        except asyncio.CancelledError:
            pass


@sio.on("audio_start")  # pyright: ignore [reportOptionalCall]
async def audio_start(sid):
    """Handle audio init.

    When audio is enabled, reports the connection as "on" if
    ``on_audio_start`` exists and returns a truthy value, else "off".
    """
    session = WebsocketSession.require(sid)
    context = init_ws_context(session)
    session_config: ChainlitConfig = session.get_config()  # type: ignore

    if session_config.features.audio and session_config.features.audio.enabled:
        on_audio_start = session_config.code.on_audio_start
        connected = bool(await on_audio_start()) if on_audio_start else False
        connection_state = "on" if connected else "off"
        await context.emitter.update_audio_connection(connection_state)


@sio.on("audio_chunk")
async def audio_chunk(sid, payload: InputAudioChunkPayload):
    """Handle an audio chunk sent by the user (dispatched without awaiting)."""
    session = WebsocketSession.require(sid)
    init_ws_context(session)
    session_config: ChainlitConfig = session.get_config()

    if (
        session_config.features.audio
        and session_config.features.audio.enabled
        and session_config.code.on_audio_chunk
    ):
        _spawn_background(
            session_config.code.on_audio_chunk(InputAudioChunk(**payload))
        )


@sio.on("audio_end")
async def audio_end(sid):
    """Handle the end of the audio stream.

    On the session's first interaction, thread initialisation is started in
    the background. Then ``on_audio_end`` runs if audio is enabled and the
    callback exists. Errors are reported like in ``process_message``.
    """
    session = WebsocketSession.require(sid)
    # Bound before the try so the finally clause can always use it.
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()

        if not session.has_first_interaction:
            session.has_first_interaction = True
            _spawn_background(context.emitter.init_thread("audio"))

        session_config: ChainlitConfig = session.get_config()  # type: ignore

        if (
            session_config.features.audio
            and session_config.features.audio.enabled
            and session_config.code.on_audio_end
        ):
            await session_config.code.on_audio_end()
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()


@sio.on("chat_settings_change")
async def change_settings(sid, settings: Dict[str, Any]):
    """Handle change settings submit from the UI.

    Merges ``settings`` into the session's chat settings, then forwards the
    submitted settings to ``on_settings_update`` if defined.
    """
    context = init_ws_context(sid)
    for key, value in settings.items():
        context.session.chat_settings[key] = value

    if config.code.on_settings_update:
        await config.code.on_settings_update(settings)