ai-station/.venv/lib/python3.12/site-packages/chainlit/utils.py

174 lines
5.4 KiB
Python

import functools
import importlib
import inspect
import os
from asyncio import CancelledError
from datetime import datetime, timezone
from typing import Callable
import click
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from packaging import version
from starlette.middleware.base import BaseHTTPMiddleware
from chainlit.auth import ensure_jwt_secret
from chainlit.context import context
from chainlit.logger import logger
def utc_now():
dt = datetime.now(timezone.utc).replace(tzinfo=None)
return dt.isoformat() + "Z"
def timestamp_utc(timestamp: float):
dt = datetime.fromtimestamp(timestamp, timezone.utc).replace(tzinfo=None)
return dt.isoformat() + "Z"
def wrap_user_function(user_function: Callable, with_task=False) -> Callable:
"""
Wraps a user-defined function to accept arguments as a dictionary.
Args:
user_function (Callable): The user-defined function to wrap.
Returns:
Callable: The wrapped function.
"""
@functools.wraps(user_function)
async def wrapper(*args):
# Get the parameter names of the user-defined function
user_function_params = list(inspect.signature(user_function).parameters.keys())
# Create a dictionary of parameter names and their corresponding values from *args
params_values = {
param_name: arg for param_name, arg in zip(user_function_params, args)
}
if with_task:
await context.emitter.task_start()
try:
# Call the user-defined function with the arguments
if inspect.iscoroutinefunction(user_function):
return await user_function(**params_values)
else:
return user_function(**params_values)
except CancelledError:
pass
except Exception as e:
logger.exception(e)
if with_task:
from chainlit.message import ErrorMessage
await ErrorMessage(
content=str(e) or e.__class__.__name__, author="Error"
).send()
finally:
if with_task:
await context.emitter.task_end()
return wrapper
def make_module_getattr(registry):
"""Leverage PEP 562 to make imports lazy in an __init__.py
The registry must be a dictionary with the items to import as keys and the
modules they belong to as a value.
"""
def __getattr__(name):
module_path = registry[name]
module = importlib.import_module(module_path, __package__)
return getattr(module, name)
return __getattr__
def check_module_version(name, required_version):
"""
Check the version of a module.
Args:
name (str): A module name.
version (str): Minimum version.
Returns:
(bool): Return True if the module is installed and the version
match the minimum required version.
"""
try:
module = importlib.import_module(name)
except ModuleNotFoundError:
return False
return version.parse(module.__version__) >= version.parse(required_version)
def check_file(target: str):
# Define accepted file extensions for Chainlit
ACCEPTED_FILE_EXTENSIONS = ("py", "py3")
_, extension = os.path.splitext(target)
# Check file extension
if extension[1:] not in ACCEPTED_FILE_EXTENSIONS:
if extension[1:] == "":
raise click.BadArgumentUsage(
"Chainlit requires raw Python (.py) files, but the provided file has no extension."
)
else:
raise click.BadArgumentUsage(
f"Chainlit requires raw Python (.py) files, not {extension}."
)
if not os.path.exists(target):
raise click.BadParameter(f"File does not exist: {target}")
def mount_chainlit(app: FastAPI, target: str, path="/chainlit"):
from chainlit.config import config, load_module
from chainlit.server import app as chainlit_app
config.run.debug = os.environ.get("CHAINLIT_DEBUG", False)
os.environ["CHAINLIT_ROOT_PATH"] = path
api_full_path = path
if app.root_path:
parent_root_path = app.root_path.rstrip("/")
api_full_path = parent_root_path + path
os.environ["CHAINLIT_PARENT_ROOT_PATH"] = parent_root_path
check_file(target)
# Load the module provided by the user
config.run.module_name = target
load_module(config.run.module_name)
ensure_jwt_secret()
class ChainlitMiddleware(BaseHTTPMiddleware):
"""Middleware to handle path routing for submounted Chainlit applications.
When Chainlit is submounted within a larger FastAPI application, its default route
`@router.get("/{full_path:path}")` can conflict with the main app's routing. This
middleware ensures requests are only forwarded to Chainlit if they match the
designated subpath, preventing routing collisions.
If a request's path doesn't start with the configured subpath, the middleware
returns a 404 response instead of forwarding to Chainlit's default route.
"""
async def dispatch(self, request: Request, call_next):
if not request.url.path.startswith(api_full_path):
return JSONResponse(status_code=404, content={"detail": "Not found"})
return await call_next(request)
chainlit_app.add_middleware(ChainlitMiddleware)
app.mount(path, chainlit_app)