417 lines
13 KiB
Python
417 lines
13 KiB
Python
import asyncio
|
|
import json
|
|
from typing import Any, Dict, Literal, Optional, Tuple, TypedDict, Union
|
|
from urllib.parse import unquote
|
|
|
|
from starlette.requests import cookie_parser
|
|
from typing_extensions import TypeAlias
|
|
|
|
from chainlit.auth import (
|
|
get_current_user,
|
|
get_token_from_cookies,
|
|
require_login,
|
|
)
|
|
from chainlit.chat_context import chat_context
|
|
from chainlit.config import ChainlitConfig, config
|
|
from chainlit.context import init_ws_context
|
|
from chainlit.data import get_data_layer
|
|
from chainlit.logger import logger
|
|
from chainlit.message import ErrorMessage, Message
|
|
from chainlit.server import sio
|
|
from chainlit.session import ClientType, WebsocketSession
|
|
from chainlit.types import (
|
|
InputAudioChunk,
|
|
InputAudioChunkPayload,
|
|
MessagePayload,
|
|
)
|
|
from chainlit.user import PersistedUser, User
|
|
from chainlit.user_session import user_sessions
|
|
|
|
# Type of the request environ dict handed to the socket "connect" handler.
# It is read like a WSGI environ, e.g. via its "HTTP_COOKIE" key.
WSGIEnvironment: TypeAlias = dict[str, Any]
|
|
|
|
|
|
class WebSocketSessionAuth(TypedDict):
    """Shape of the ``auth`` payload a client sends with the socket "connect" event."""

    # Id of the session to restore (if it still exists) or to create.
    sessionId: str
    # JSON-encoded object of user-provided environment variables, if any.
    userEnv: str | None
    # Kind of client connecting; passed through to the new WebsocketSession.
    clientType: ClientType
    # URL-encoded chat profile name, if any (decoded with ``unquote``).
    chatProfile: str | None
    # Id of a thread the client wants to resume, if any.
    threadId: str | None
|
|
|
|
|
|
def restore_existing_session(sid, session_id, emit_fn, emit_call_fn, environ):
    """Re-bind an existing session (looked up by ``session_id``) to socket ``sid``.

    The session's emit callbacks and request environ are replaced with the
    ones belonging to the new socket.

    Returns:
        True if a session with that id existed and was restored, else False.
    """
    existing = WebsocketSession.get_by_id(session_id)
    if not existing:
        return False

    existing.restore(new_socket_id=sid)
    existing.emit = emit_fn
    existing.emit_call = emit_call_fn
    existing.environ = environ
    return True
|
|
|
|
|
|
async def persist_user_session(thread_id: str, metadata: Dict):
    """Store ``metadata`` on thread ``thread_id`` via the configured data layer.

    Does nothing when no data layer is configured.
    """
    layer = get_data_layer()
    if not layer:
        return
    await layer.update_thread(thread_id=thread_id, metadata=metadata)
|
|
|
|
|
|
async def resume_thread(session: WebsocketSession):
    """Load the thread the session wants to resume and restore its state.

    Only the thread's author may resume it. For the author, the thread's
    metadata (a dict, or a JSON string of one) is copied into
    ``user_sessions[session.id]``, and any stored ``chat_profile`` /
    ``chat_settings`` are applied to the session.

    Returns:
        The thread dict, or None when there is no data layer, no user, no
        thread id to resume, the thread does not exist, or the user is not
        its author.
    """
    data_layer = get_data_layer()
    if not (data_layer and session.user and session.thread_id_to_resume):
        return None

    thread = await data_layer.get_thread(thread_id=session.thread_id_to_resume)
    if not thread:
        return None

    # Refuse to resume threads owned by someone else.
    if thread.get("userIdentifier") != session.user.identifier:
        return None

    raw_metadata = thread.get("metadata") or {}
    metadata = (
        json.loads(raw_metadata) if isinstance(raw_metadata, str) else raw_metadata
    )
    user_sessions[session.id] = metadata.copy()

    chat_profile = metadata.get("chat_profile")
    if chat_profile:
        session.chat_profile = chat_profile

    chat_settings = metadata.get("chat_settings")
    if chat_settings:
        session.chat_settings = chat_settings

    return thread
|
|
|
|
|
|
def load_user_env(
    user_env: Optional[str], *, required_keys=None
) -> Dict[str, Any]:
    """Decode the client's user environment and check required variables.

    Args:
        user_env: JSON-encoded object sent by the client (``auth["userEnv"]``),
            or a falsy value when the client sent none.
        required_keys: Variable names that must be present. Defaults to
            ``config.project.user_env``; a falsy value disables the check.

    Returns:
        The decoded environment dict; ``{}`` when ``user_env`` is empty.

    Raises:
        ConnectionRefusedError: ``user_env`` is not a JSON object, or a
            required variable is missing.
    """
    if required_keys is None:
        required_keys = config.project.user_env

    if user_env:
        try:
            user_env_dict = json.loads(user_env)
        except (ValueError, TypeError) as err:
            # Malformed client input: refuse the connection rather than
            # leaking a decoding error out of the connect handler.
            raise ConnectionRefusedError(
                "Invalid user environment variables"
            ) from err
        if not isinstance(user_env_dict, dict):
            raise ConnectionRefusedError("Invalid user environment variables")
    else:
        # No environment sent: treat it as "no variables provided" so the
        # checks below and the return value always have a bound dict.
        user_env_dict = {}

    if required_keys:
        if not user_env_dict:
            raise ConnectionRefusedError("Missing user environment variables")
        # Every configured variable must be supplied by the client.
        for key in required_keys:
            if key not in user_env_dict:
                raise ConnectionRefusedError(
                    "Missing user environment variable: " + key
                )

    return user_env_dict
|
|
|
|
|
|
def _get_token_from_cookie(environ: WSGIEnvironment) -> Optional[str]:
    """Return the access token found in the request's Cookie header, if any."""
    raw_cookies = environ.get("HTTP_COOKIE", None)
    if not raw_cookies:
        return None
    return get_token_from_cookies(cookie_parser(raw_cookies))
|
|
|
|
|
|
def _get_token(environ: WSGIEnvironment) -> Optional[str]:
    """Extract the access token from a request environ.

    Only the Cookie header is consulted.
    """
    token = _get_token_from_cookie(environ)
    return token
|
|
|
|
|
|
async def _authenticate_connection(
    environ: WSGIEnvironment,
) -> Union[Tuple[Union[User, PersistedUser], str], Tuple[None, None]]:
    """Resolve the user behind a connection from its access token.

    Returns:
        ``(user, token)`` when a token is present and maps to a user,
        otherwise ``(None, None)``.
    """
    token = _get_token(environ)
    if not token:
        return None, None

    user = await get_current_user(token=token)
    return (user, token) if user else (None, None)
|
|
|
|
|
|
@sio.on("connect")  # pyright: ignore [reportOptionalCall]
async def connect(sid: str, environ: WSGIEnvironment, auth: WebSocketSessionAuth):
    """Accept (or refuse) a new socket connection and attach it to a session.

    When login is required, the user is authenticated from the request
    cookies. If the client also asks for a ``threadId``, that thread must
    belong to the authenticated user. The socket is then either re-bound to
    the existing session named by ``auth["sessionId"]``, or a new
    WebsocketSession is created for it.

    Args:
        sid: Socket.IO id of the new connection.
        environ: Request environ of the connection.
        auth: Auth payload sent by the client.

    Returns:
        True when the connection is accepted.

    Raises:
        ConnectionRefusedError: authentication or thread authorization
            failed, or ``load_user_env`` rejected the user environment.
    """
    user: User | PersistedUser | None = None
    token: str | None = None
    thread_id = auth.get("threadId", None)

    if require_login():
        try:
            user, token = await _authenticate_connection(environ)
        except Exception as e:
            # Any auth error is logged and treated as "no user" below.
            logger.exception("Exception authenticating connection: %s", e)

        if not user:
            logger.error("Authentication failed in websocket connect.")
            raise ConnectionRefusedError("authentication failed")

        if thread_id:
            if data_layer := get_data_layer():
                thread = await data_layer.get_thread(thread_id)
                # .get(): a thread record without an author fails the
                # ownership check (refused) instead of raising KeyError.
                if thread and thread.get("userIdentifier") != user.identifier:
                    logger.error("Authorization for the thread failed.")
                    raise ConnectionRefusedError("authorization failed")

    # Session scoped function to emit to the client
    def emit_fn(event, data):
        return sio.emit(event, data, to=sid)

    # Session scoped function to emit to the client and wait for a response
    def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
        return sio.call(event, data, timeout=timeout, to=sid)

    session_id = auth["sessionId"]
    if restore_existing_session(sid, session_id, emit_fn, emit_call_fn, environ):
        return True

    user_env_string = auth.get("userEnv", None)
    user_env = load_user_env(user_env_string)

    client_type = auth["clientType"]
    url_encoded_chat_profile = auth.get("chatProfile", None)
    chat_profile = (
        unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
    )

    WebsocketSession(
        id=session_id,
        socket_id=sid,
        emit=emit_fn,
        emit_call=emit_call_fn,
        client_type=client_type,
        user_env=user_env,
        user=user,
        token=token,
        chat_profile=chat_profile,
        thread_id=thread_id,
        environ=environ,
    )

    return True
|
|
|
|
|
|
@sio.on("connection_successful")  # pyright: ignore [reportOptionalCall]
async def connection_successful(sid):
    """Finish setting up a socket once the client confirms the connection.

    UI state is reset first. A restored session with no first interaction
    yet only gets ``on_chat_start``. Otherwise, if the session has a thread
    to resume and an ``on_chat_resume`` callback exists, the thread is
    resumed. When resumption does not happen, ``on_chat_start`` is launched.
    """
    context = init_ws_context(sid)
    session = context.session

    def launch_chat_start():
        # Run on_chat_start in the background and keep the task on the
        # session as ``current_task``.
        if config.code.on_chat_start:
            session.current_task = asyncio.create_task(config.code.on_chat_start())

    # Reset UI state possibly left over from an earlier connection: end any
    # running-task indicator and clear pending ask / call_fn prompts.
    await context.emitter.task_end()
    await context.emitter.clear("clear_ask")
    await context.emitter.clear("clear_call_fn")

    if session.restored and not session.has_first_interaction:
        launch_chat_start()
        return

    if session.thread_id_to_resume and config.code.on_chat_resume:
        thread = await resume_thread(session)
        if not thread:
            await context.emitter.send_resume_thread_error("Thread not found.")
        else:
            session.has_first_interaction = True
            await context.emitter.emit(
                "first_interaction",
                {"interaction": "resume", "thread_id": thread.get("id")},
            )
            await config.code.on_chat_resume(thread)

            # Rebuild the in-memory chat context from the thread's message steps.
            message_steps = (
                step for step in thread.get("steps", []) if "message" in step["type"]
            )
            for step in message_steps:
                chat_context.add(Message.from_dict(step))

            await context.emitter.resume_thread(thread)
            return

    launch_chat_start()
|
|
|
|
|
|
@sio.on("clear_session")  # pyright: ignore [reportOptionalCall]
async def clean_session(sid):
    """Flag the socket's session so that disconnect clears it right away.

    Without this flag, disconnect only clears the session after the
    configured session timeout.
    """
    target = WebsocketSession.get(sid)
    if not target:
        return
    target.to_clear = True
|
|
|
|
|
|
@sio.on("disconnect")  # pyright: ignore [reportOptionalCall]
async def disconnect(sid):
    """Handle a socket disconnect.

    Runs ``on_chat_end``, then persists the session onto its thread when the
    user has interacted. Finally the session is cleared. Clearing happens
    immediately if the session was flagged via "clear_session". Otherwise it
    happens after ``config.project.session_timeout`` seconds, presumably
    giving a reconnecting client time to restore it.

    Cleanup is always scheduled, even if ``on_chat_end`` or persistence
    fails, so a failing callback cannot leak the session.
    """
    session = WebsocketSession.get(sid)

    if not session:
        return

    init_ws_context(session)

    async def clear(_sid):
        # Look the session up again: by now the socket id may no longer map
        # to a session (e.g. it was already cleared).
        if target := WebsocketSession.get(_sid):
            # Clean up the user session
            user_sessions.pop(target.id, None)
            # Clean up the session
            await target.delete()

    try:
        if config.code.on_chat_end:
            try:
                await config.code.on_chat_end()
            except Exception:
                # A failing user callback must not skip persistence/cleanup.
                logger.exception("Error in on_chat_end callback")

        if session.thread_id and session.has_first_interaction:
            await persist_user_session(session.thread_id, session.to_persistable())
    finally:
        if session.to_clear:
            await clear(sid)
        else:

            async def clear_on_timeout(_sid):
                await asyncio.sleep(config.project.session_timeout)
                await clear(_sid)

            asyncio.ensure_future(clear_on_timeout(sid))
|
|
|
|
|
|
@sio.on("stop")  # pyright: ignore [reportOptionalCall]
async def stop(sid):
    """React to the user pressing "stop".

    Posts a notice message, cancels the session's current task if there is
    one, then runs the ``on_stop`` callback when one is registered.
    """
    session = WebsocketSession.get(sid)
    if not session:
        return

    init_ws_context(session)
    await Message(content="Task manually stopped.").send()

    running = session.current_task
    if running:
        running.cancel()

    if config.code.on_stop:
        await config.code.on_stop()
|
|
|
|
|
|
async def process_message(session: WebsocketSession, payload: MessagePayload):
    """Process a message from the user.

    Marks a task as started, lets the emitter turn ``payload`` into a
    Message, and passes it to ``on_message`` when that callback is
    registered. Cancellation (e.g. the "stop" handler cancelling
    ``session.current_task``) is swallowed. Any other error is logged and
    shown to the user as an ErrorMessage. The task is always marked as ended.
    """
    # Set up the context before the try block so the ``finally`` clause can
    # always reach ``context``; otherwise a failure here would surface as an
    # UnboundLocalError that masks the real error.
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()
        message = await context.emitter.process_message(payload)

        if config.code.on_message:
            # Briefly yield to the event loop before running user code;
            # presumably lets the user's message reach the client first.
            await asyncio.sleep(0.001)
            await config.code.on_message(message)
    except asyncio.CancelledError:
        # Cancellation is an expected way for a message task to end.
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()
|
|
|
|
|
|
@sio.on("edit_message")  # pyright: ignore [reportOptionalCall]
async def edit_message(sid, payload: MessagePayload):
    """Handle the user editing one of their earlier messages.

    The edited message gets its new content. Every message after it is
    removed from the chat context, and ``on_message`` is re-run with the
    edited message. If the message id is not found, or no ``on_message``
    callback is registered, nothing is re-run and no task is started.
    """
    session = WebsocketSession.require(sid)
    context = init_ws_context(session)

    edited = payload["message"]
    orig_message = None

    # Walk a snapshot of the chat context in case ``Message.remove`` mutates
    # the underlying list while we iterate.
    for message in list(chat_context.get()):
        if orig_message is not None:
            # Messages after the edited one are discarded.
            await message.remove()

        if message.id == edited["id"]:
            message.content = edited["output"]
            await message.update()
            orig_message = message

    # Only start a task when one will actually run; a task_start with no
    # matching task_end would leave the UI stuck in a running state.
    if orig_message is None or not config.code.on_message:
        return

    await context.emitter.task_start()
    try:
        await config.code.on_message(orig_message)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        # Report failures to the user, as process_message does.
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()
|
|
|
|
|
|
@sio.on("client_message")  # pyright: ignore [reportOptionalCall]
async def message(sid, payload: MessagePayload):
    """Queue a user's chat message for processing in a background task.

    The task is stored as the session's ``current_task``.
    """
    session = WebsocketSession.require(sid)
    session.current_task = asyncio.create_task(process_message(session, payload))
|
|
|
|
|
|
@sio.on("window_message")  # pyright: ignore [reportOptionalCall]
async def window_message(sid, data):
    """Forward a message posted by the host window to ``on_window_message``."""
    session = WebsocketSession.require(sid)
    init_ws_context(session)

    if not config.code.on_window_message:
        return

    try:
        await config.code.on_window_message(data)
    except asyncio.CancelledError:
        # Cancellation is treated as a silent, normal exit.
        pass
|
|
|
|
|
|
@sio.on("audio_start")  # pyright: ignore [reportOptionalCall]
async def audio_start(sid):
    """Handle audio init.

    When audio is enabled in the session's config, calls ``on_audio_start``
    and reports the audio connection state to the client. The state is "on"
    if the callback returned a truthy value, otherwise "off". With audio
    enabled but no callback registered, the state is reported as "off".
    """
    session = WebsocketSession.require(sid)

    context = init_ws_context(session)
    # Per-session config; named so it does not shadow the module-level config.
    session_config: ChainlitConfig = session.get_config()  # type: ignore

    audio = session_config.features.audio
    if not (audio and audio.enabled):
        return

    on_audio_start = session_config.code.on_audio_start
    # Guard like audio_chunk does: calling an unset callback raised TypeError.
    connected = bool(await on_audio_start()) if on_audio_start else False
    await context.emitter.update_audio_connection("on" if connected else "off")
|
|
|
|
|
|
@sio.on("audio_chunk")
async def audio_chunk(sid, payload: InputAudioChunkPayload):
    """Pass an incoming audio chunk to ``on_audio_chunk`` in a background task.

    Chunks are dropped unless audio is enabled in the session's config and
    the callback is registered.
    """
    session = WebsocketSession.require(sid)
    init_ws_context(session)

    session_config: ChainlitConfig = session.get_config()
    features = session_config.features

    if not (features.audio and features.audio.enabled):
        return
    if not session_config.code.on_audio_chunk:
        return

    asyncio.create_task(
        session_config.code.on_audio_chunk(InputAudioChunk(**payload))
    )
|
|
|
|
|
|
@sio.on("audio_end")
async def audio_end(sid):
    """Handle the end of the user's audio stream.

    On the session's first interaction, thread initialisation is started in
    the background. ``on_audio_end`` is then awaited when audio is enabled in
    the session's config and the callback is registered. Cancellation is
    swallowed. Other errors are logged and shown to the user as an
    ErrorMessage. The task is always marked as ended.
    """
    session = WebsocketSession.require(sid)

    # Set up the context before the try block so ``finally`` can always use
    # it; otherwise a failure here would become an UnboundLocalError.
    context = init_ws_context(session)
    try:
        await context.emitter.task_start()

        if not session.has_first_interaction:
            session.has_first_interaction = True
            asyncio.create_task(context.emitter.init_thread("audio"))

        session_config: ChainlitConfig = session.get_config()  # type: ignore

        # Guard the callback like audio_chunk does; calling an unset
        # on_audio_end would raise TypeError.
        if (
            session_config.features.audio
            and session_config.features.audio.enabled
            and session_config.code.on_audio_end
        ):
            await session_config.code.on_audio_end()

    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.exception(e)
        await ErrorMessage(
            author="Error", content=str(e) or e.__class__.__name__
        ).send()
    finally:
        await context.emitter.task_end()
|
|
|
|
|
|
@sio.on("chat_settings_change")
async def change_settings(sid, settings: Dict[str, Any]):
    """Apply chat settings submitted from the UI and notify user code.

    The submitted keys are merged into the session's chat settings. The
    ``on_settings_update`` callback receives only the submitted dict.
    """
    context = init_ws_context(sid)
    context.session.chat_settings.update(settings)

    if config.code.on_settings_update:
        await config.code.on_settings_update(settings)
|