ai-station/.venv/lib/python3.12/site-packages/chainlit/socket.py

417 lines
13 KiB
Python
Raw Normal View History

import asyncio
import json
from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union
from urllib.parse import unquote
from starlette.requests import cookie_parser
from typing_extensions import TypeAlias
from chainlit.auth import (
get_current_user,
get_token_from_cookies,
require_login,
)
from chainlit.chat_context import chat_context
from chainlit.config import ChainlitConfig, config
from chainlit.context import init_ws_context
from chainlit.data import get_data_layer
from chainlit.logger import logger
from chainlit.message import ErrorMessage, Message
from chainlit.server import sio
from chainlit.session import ClientType, WebsocketSession
from chainlit.types import (
InputAudioChunk,
InputAudioChunkPayload,
MessagePayload,
)
from chainlit.user import PersistedUser, User
from chainlit.user_session import user_sessions
# Alias for the `environ` mapping handed to the socket handlers in this module
# (the token lookup below reads its "HTTP_COOKIE" entry).
WSGIEnvironment: TypeAlias = dict[str, Any]
class WebSocketSessionAuth(TypedDict):
    """Shape of the `auth` payload read by the `connect` handler."""

    # Session id chosen by the client; used to restore or create a session.
    sessionId: str
    # JSON-encoded user environment variables (decoded by `load_user_env`).
    userEnv: str | None
    # Forwarded as-is to `WebsocketSession(client_type=...)`.
    clientType: ClientType
    # URL-encoded chat profile name; `connect` unquotes it before use.
    chatProfile: str | None
    # Thread the client wants to attach to; checked against the user on connect.
    threadId: str | None
def restore_existing_session(sid, session_id, emit_fn, emit_call_fn, environ):
    """Re-attach an already known session to a new socket.

    Looks the session up by the client-provided `session_id`. When found, the
    session is restored onto socket `sid` and its emit helpers and environ are
    replaced with the ones of the new connection.

    Returns:
        True when a session was found and restored, False otherwise.
    """
    existing = WebsocketSession.get_by_id(session_id)
    if not existing:
        return False

    existing.restore(new_socket_id=sid)
    existing.emit = emit_fn
    existing.emit_call = emit_call_fn
    existing.environ = environ
    return True
async def persist_user_session(thread_id: str, metadata: Dict):
    """Store `metadata` on the thread `thread_id` through the data layer.

    Does nothing when no data layer is configured.
    """
    storage = get_data_layer()
    if not storage:
        return
    await storage.update_thread(thread_id=thread_id, metadata=metadata)
async def resume_thread(session: WebsocketSession):
    """Load the thread the session asked to resume and rehydrate its state.

    Returns:
        The thread when it exists and its "userIdentifier" matches the
        session user; None in every other case (no data layer, no user, no
        thread id to resume, unknown thread, or a thread owned by someone
        else).
    """
    storage = get_data_layer()
    if not (storage and session.user and session.thread_id_to_resume):
        return None

    thread = await storage.get_thread(thread_id=session.thread_id_to_resume)
    if not thread:
        return None

    owned_by_user = thread.get("userIdentifier") == session.user.identifier
    if not owned_by_user:
        return None

    # Metadata may come back either as a mapping or as a JSON string.
    stored = thread.get("metadata") or {}
    if isinstance(stored, str):
        stored = json.loads(stored)

    # Seed the per-session user store with a copy of the persisted metadata.
    user_sessions[session.id] = stored.copy()

    profile = stored.get("chat_profile")
    if profile:
        session.chat_profile = profile

    settings = stored.get("chat_settings")
    if settings:
        session.chat_settings = settings

    return thread
def load_user_env(user_env):
    """Decode and validate the user environment sent by the client.

    Args:
        user_env: JSON string with the user-provided environment variables,
            or a falsy value when the client sent none.

    Returns:
        The decoded JSON value, or None when `user_env` is falsy.

    Raises:
        ConnectionRefusedError: if the project config lists required user
            environment variables and the payload is missing, empty, or
            lacks one of them.
    """
    # Always bind the result: previously the name was only assigned when
    # `user_env` was truthy, so a client sending no user env triggered an
    # UnboundLocalError instead of a clean refusal / None result.
    user_env_dict = json.loads(user_env) if user_env else None

    required_keys = config.project.user_env
    if required_keys:
        if not user_env_dict:
            raise ConnectionRefusedError("Missing user environment variables")
        # Every variable requested by the project config must be provided.
        for key in required_keys:
            if key not in user_env_dict:
                raise ConnectionRefusedError(
                    "Missing user environment variable: " + key
                )
    return user_env_dict
def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]:
    """Extract the access token from the "HTTP_COOKIE" entry of `environ`.

    Returns None when the entry is absent or empty.
    """
    cookie_header = environ.get("HTTP_COOKIE", None)
    if not cookie_header:
        return None
    return get_token_from_cookies(cookie_parser(cookie_header))
def _get_token(environ: WSGIEnvironment) -> Optional[str]:
    """Return the access token found in the WSGI environ, if any.

    The cookie header is currently the only source consulted.
    """
    token = _get_token_from_cookie(environ)
    return token
async def _authenticate_connection(
    environ: WSGIEnvironment,
) -> Union[Tuple[Union[User, PersistedUser], str], Tuple[None, None]]:
    """Resolve the user behind a connection from its access token.

    Returns:
        `(user, token)` when a token is present and maps to a user,
        `(None, None)` otherwise.
    """
    token = _get_token(environ)
    if not token:
        return None, None

    user = await get_current_user(token=token)
    if not user:
        return None, None

    return user, token
@sio.on("connect")  # pyright: ignore [reportOptionalCall]
async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth):
    """Handle a new socket connection.

    When login is required the connection is authenticated from `environ`
    and, if a thread id was sent, the thread must belong to the user. An
    existing session with the same session id is restored; otherwise a new
    `WebsocketSession` is created.

    Returns:
        True once the session has been restored or created.

    Raises:
        ConnectionRefusedError: on failed authentication, failed thread
            authorization, or (via `load_user_env`) missing user env values.
    """
    user: User | PersistedUser | None = None
    token: str | None = None
    thread_id = auth.get("threadId", None)

    if require_login():
        try:
            user, token = await _authenticate_connection(environ)
        except Exception as e:
            # Logged only; the missing user below turns it into a refusal.
            logger.exception("Exception authenticating connection: %s", e)

        if not user:
            logger.error("Authentication failed in websocket connect.")
            raise ConnectionRefusedError("authentication failed")

        # Only consult the data layer when the client targets a thread.
        if thread_id and (data_layer := get_data_layer()):
            thread = await data_layer.get_thread(thread_id)
            if thread and not (thread["userIdentifier"] == user.identifier):
                logger.error("Authorization for the thread failed.")
                raise ConnectionRefusedError("authorization failed")

    def emit_fn(event, data):
        """Emit `event` to this connection's socket only."""
        return sio.emit(event, data, to=sid)

    def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
        """Emit `event` to this socket and wait for the client's response."""
        return sio.call(event, data, timeout=timeout, to=sid)

    session_id = auth["sessionId"]

    restored = restore_existing_session(
        sid, session_id, emit_fn, emit_call_fn, environ
    )
    if restored:
        return True

    user_env = load_user_env(auth.get("userEnv", None))
    client_type = auth["clientType"]

    # The chat profile arrives URL-encoded.
    encoded_profile = auth.get("chatProfile", None)
    chat_profile = None
    if encoded_profile:
        chat_profile = unquote(encoded_profile)

    # NOTE(review): the instance is not kept here; presumably the constructor
    # registers the session so later lookups by sid/id find it — confirm.
    WebsocketSession(
        id=session_id,
        socket_id=sid,
        emit=emit_fn,
        emit_call=emit_call_fn,
        client_type=client_type,
        user_env=user_env,
        user=user,
        token=token,
        chat_profile=chat_profile,
        thread_id=thread_id,
        environ=environ,
    )

    return True
@sio.on("connection_successful")  # pyright: ignore [reportOptionalCall]
async def connection_successful(sid):
    """Finish the handshake once the client reports a successful connection.

    Resets pending UI state, then either restarts the chat for a restored
    session without interaction, resumes a persisted thread, or starts a
    fresh chat through the `on_chat_start` hook.
    """
    context = init_ws_context(sid)
    session = context.session

    def launch_chat_start():
        """Run `on_chat_start` as the session's current task, if configured."""
        if config.code.on_chat_start:
            session.current_task = asyncio.create_task(config.code.on_chat_start())

    await context.emitter.task_end()
    await context.emitter.clear("clear_ask")
    await context.emitter.clear("clear_call_fn")

    # A restored session that never got a first interaction starts over.
    if session.restored and not session.has_first_interaction:
        launch_chat_start()
        return

    if session.thread_id_to_resume and config.code.on_chat_resume:
        thread = await resume_thread(session)
        if thread:
            session.has_first_interaction = True
            await context.emitter.emit(
                "first_interaction",
                {"interaction": "resume", "thread_id": thread.get("id")},
            )
            await config.code.on_chat_resume(thread)

            # Rebuild the chat context from the message steps of the thread.
            for step in thread.get("steps", []):
                if "message" in step["type"]:
                    chat_context.add(Message.from_dict(step))

            await context.emitter.resume_thread(thread)
            return

        # Resume failed: report it, then fall through to a fresh chat start.
        await context.emitter.send_resume_thread_error("Thread not found.")

    launch_chat_start()
@sio.on("clear_session")  # pyright: ignore [reportOptionalCall]
async def clean_session(sid):
    """Flag the session bound to `sid` so it is cleared on disconnect."""
    session = WebsocketSession.get(sid)
    if not session:
        return
    session.to_clear = True
@sio.on("disconnect")  # pyright: ignore [reportOptionalCall]
async def disconnect(sid):
    """Handle a socket disconnection.

    Runs the `on_chat_end` hook, persists the session on its thread when there
    was a first interaction, then clears the session, either immediately (if
    flagged with `to_clear`) or after the configured session timeout.
    """
    session = WebsocketSession.get(sid)
    if not session:
        return

    init_ws_context(session)

    if config.code.on_chat_end:
        await config.code.on_chat_end()

    if session.thread_id and session.has_first_interaction:
        await persist_user_session(session.thread_id, session.to_persistable())

    async def clear(_sid):
        """Drop the user store entry and delete the session bound to `_sid`."""
        # Looked up again because this may run long after the disconnect.
        stale = WebsocketSession.get(_sid)
        if not stale:
            return
        user_sessions.pop(stale.id, None)
        await stale.delete()

    async def clear_on_timeout(_sid):
        """Wait for the session timeout, then clear whatever is left."""
        await asyncio.sleep(config.project.session_timeout)
        await clear(_sid)

    if session.to_clear:
        await clear(sid)
    else:
        # NOTE(review): the future is not referenced anywhere after this call.
        asyncio.ensure_future(clear_on_timeout(sid))
@sio.on("stop")  # pyright: ignore [reportOptionalCall]
async def stop(sid):
    """Stop the task currently running for the session bound to `sid`.

    Sends a notice message, cancels the session's current task if any, and
    runs the `on_stop` hook when configured. Unknown sids are ignored.
    """
    session = WebsocketSession.get(sid)
    if not session:
        return

    init_ws_context(session)
    await Message(content="Task manually stopped.").send()

    running = session.current_task
    if running:
        running.cancel()

    if config.code.on_stop:
        await config.code.on_stop()
async def process_message(session: WebsocketSession, payload: MessagePayload):
    """Process a message from the user.

    Marks the task as started, lets the emitter turn `payload` into a
    message, and hands it to the `on_message` hook when one is configured.
    Cancellation is swallowed; any other error is logged and reported to the
    client as an `ErrorMessage`. The task is always marked as ended.

    Args:
        session: Session the message belongs to.
        payload: Raw message payload received from the client.
    """
    # Bound before the `try` so the `finally` clause can always reference it;
    # previously a failure in `init_ws_context` surfaced as an
    # UnboundLocalError from `finally`, hiding the real error.
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()
        message = await context.emitter.process_message(payload)

        if config.code.on_message:
            # Brief pause before invoking the hook.
            # NOTE(review): presumably there to order step starts — confirm.
            await asyncio.sleep(0.001)
            await config.code.on_message(message)
    except asyncio.CancelledError:
        # Deliberate: the task was cancelled (e.g. by the "stop" handler).
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()
@sio.on("edit_message")  # pyright: ignore [reportOptionalCall]
async def edit_message(sid, payload: MessagePayload):
    """Handle the edit of a previously sent message.

    The message whose id matches the payload gets its content replaced and is
    updated; every message that follows it in the chat context is removed.
    The `on_message` hook is then re-run with the edited message.

    Args:
        sid: Socket id of the client.
        payload: Payload whose "message" entry carries the "id" of the edited
            message and its new "output".
    """
    session = WebsocketSession.require(sid)
    context = init_ws_context(session)

    messages = chat_context.get()
    target_id = payload["message"]["id"]

    orig_message = None

    for message in messages:
        # Everything after the edited message is dropped from the thread.
        if orig_message:
            await message.remove()

        if message.id == target_id:
            message.content = payload["message"]["output"]
            await message.update()
            orig_message = message

    await context.emitter.task_start()
    # task_end is now always paired with task_start; before, it was skipped
    # when no `on_message` hook was configured.
    try:
        if orig_message is None:
            # Do not call the hook with None when the id is unknown.
            logger.warning(
                "edit_message: no message with id %s in chat context", target_id
            )
        elif config.code.on_message:
            await config.code.on_message(orig_message)
    except asyncio.CancelledError:
        pass
    finally:
        await context.emitter.task_end()
@sio.on("client_message")  # pyright: ignore [reportOptionalCall]
async def message(sid, payload: MessagePayload):
    """Handle a message sent by the User.

    Processing runs as a separate task that is recorded as the session's
    current task.
    """
    session = WebsocketSession.require(sid)
    session.current_task = asyncio.create_task(process_message(session, payload))
@sio.on("window_message")  # pyright: ignore [reportOptionalCall]
async def window_message(sid, data):
    """Handle a message sent by the host window.

    Forwards `data` to the `on_window_message` hook when one is configured;
    cancellation of the hook is swallowed.
    """
    session = WebsocketSession.require(sid)
    init_ws_context(session)

    handler = config.code.on_window_message
    if not handler:
        return

    try:
        await handler(data)
    except asyncio.CancelledError:
        pass
@sio.on("audio_start")  # pyright: ignore [reportOptionalCall]
async def audio_start(sid):
    """Handle audio init.

    When the audio feature is enabled, runs the `on_audio_start` hook and
    tells the client whether the audio connection is "on" or "off" based on
    the truthiness of the hook's result.
    """
    session = WebsocketSession.require(sid)

    context = init_ws_context(session)
    config: ChainlitConfig = session.get_config()  # type: ignore

    if config.features.audio and config.features.audio.enabled:
        # Guard the hook like `audio_chunk` does: with no hook registered the
        # connection is reported as "off" instead of failing on `None()`.
        on_audio_start = config.code.on_audio_start
        connected = bool(await on_audio_start()) if on_audio_start else False
        connection_state = "on" if connected else "off"
        await context.emitter.update_audio_connection(connection_state)
@sio.on("audio_chunk")
async def audio_chunk(sid, payload: InputAudioChunkPayload):
    """Handle an audio chunk sent by the user.

    The chunk is passed to the `on_audio_chunk` hook in a separate task,
    provided the audio feature is enabled and the hook is configured.
    """
    session = WebsocketSession.require(sid)
    init_ws_context(session)

    config: ChainlitConfig = session.get_config()
    audio = config.features.audio
    if not (audio and audio.enabled and config.code.on_audio_chunk):
        return

    # Not awaited: the handler returns without waiting for the hook.
    asyncio.create_task(config.code.on_audio_chunk(InputAudioChunk(**payload)))
@sio.on("audio_end")
async def audio_end(sid):
    """Handle the end of the audio stream.

    Marks the task as started, initialises the thread on first interaction,
    and runs the `on_audio_end` hook when the audio feature is enabled and
    the hook is configured. Cancellation is swallowed; other errors are
    logged and sent to the client as an `ErrorMessage`. The task is always
    marked as ended.
    """
    session = WebsocketSession.require(sid)
    # Bound before the `try` so the `finally` clause can always use it
    # (it used to raise UnboundLocalError if context creation failed).
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()

        if not session.has_first_interaction:
            session.has_first_interaction = True
            # Thread initialisation runs in its own task, not awaited here.
            asyncio.create_task(context.emitter.init_thread("audio"))

        config: ChainlitConfig = session.get_config()  # type: ignore

        # The hook is checked as in `audio_chunk`, so a missing hook is
        # skipped rather than reported as a "NoneType is not callable" error.
        if (
            config.features.audio
            and config.features.audio.enabled
            and config.code.on_audio_end
        ):
            await config.code.on_audio_end()
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()
@sio.on("chat_settings_change")
async def change_settings(sid, settings: Dict[str, Any]):
    """Handle change settings submit from the UI.

    Merges `settings` into the session's chat settings, then notifies the
    `on_settings_update` hook when one is configured.
    """
    context = init_ws_context(sid)
    context.session.chat_settings.update(settings)

    on_update = config.code.on_settings_update
    if on_update:
        await on_update(settings)