# Source: ai-station/.venv/lib/python3.12/site-packages/chainlit/auth/cookie.py
# (198 lines, 6.0 KiB, Python)

import os
from typing import Literal, Optional, cast
from fastapi import Request, Response
from fastapi.exceptions import HTTPException
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.status import HTTP_401_UNAUTHORIZED
from chainlit.config import config
""" Module level cookie settings. """
# SameSite policy applied to every cookie written by this module. Read from
# the CHAINLIT_COOKIE_SAMESITE environment variable, defaulting to "lax".
_cookie_samesite = cast(
    Literal["lax", "strict", "none"],
    os.environ.get("CHAINLIT_COOKIE_SAMESITE", "lax"),
)
assert _cookie_samesite in ("lax", "strict", "none"), (
    "Invalid value for CHAINLIT_COOKIE_SAMESITE. Must be one of 'lax', 'strict' or 'none'."
)

# The Secure flag is switched on only when SameSite is "none".
_cookie_secure = _cookie_samesite == "none"

# Cookie path attribute used by the auth-cookie helpers in this module.
_cookie_root_path = os.environ.get("CHAINLIT_ROOT_PATH", None)
if _cookie_root_path:
    # NOTE(review): this reads an environment variable whose *name* is the
    # value of CHAINLIT_ROOT_PATH, so it normally falls back to "/".
    # Confirm whether the root path itself was meant to be the cookie path.
    _cookie_path = os.environ.get(_cookie_root_path, "/")
else:
    _cookie_path = os.environ.get("CHAINLIT_AUTH_COOKIE_PATH", "/")

# Lifetime of the OAuth state cookie, passed as max_age (3 minutes).
_state_cookie_lifetime = 3 * 60

# Cookie names: the auth cookie name is configurable, the state cookie is fixed.
_auth_cookie_name = os.environ.get("CHAINLIT_AUTH_COOKIE_NAME", "access_token")
_state_cookie_name = "oauth_state"
class OAuth2PasswordBearerWithCookie(SecurityBase):
    """
    Cookie-first OAuth2 password-flow security scheme.

    The token is read from the auth cookie(s); when no cookie token is
    present, the ``Authorization`` header is consulted as a bearer-token
    fallback.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        auto_error: bool = True,
    ):
        """
        Store the scheme configuration.

        Args:
            tokenUrl: Token endpoint URL; stored as-is on the instance.
            scheme_name: Name of the scheme; defaults to the class name.
            auto_error: When true, authentication failures raise a 401
                ``HTTPException``; when false, ``__call__`` returns ``None``.
        """
        self.tokenUrl = tokenUrl
        self.auto_error = auto_error
        self.scheme_name = scheme_name if scheme_name else self.__class__.__name__

    def _unauthorized(self, detail: str) -> None:
        """Raise a 401 with ``detail`` if ``auto_error`` is set, else return None."""
        if not self.auto_error:
            return None
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Extract the access token from ``request``.

        Returns:
            The cookie token if present, otherwise the bearer token from the
            ``Authorization`` header, or ``None`` when authentication fails
            and ``auto_error`` is false.

        Raises:
            HTTPException: 401 when ``auto_error`` is true and either no
                credentials were sent or the header scheme is not "bearer".
        """
        # The cookie takes precedence over the header.
        cookie_token = get_token_from_cookies(request.cookies)
        if cookie_token:
            return cookie_token

        # TODO: Only bother to check if cookie auth is explicitly disabled.
        header_value = request.headers.get("Authorization")
        if not header_value:
            return self._unauthorized("Not authenticated")

        scheme, header_token = get_authorization_scheme_param(header_value)
        if scheme.lower() != "bearer":
            return self._unauthorized("Invalid authentication credentials")
        return header_token
def _get_chunked_cookie(cookies: dict[str, str], name: str) -> Optional[str]:
    """
    Reassemble a value that was split across numbered chunk cookies.

    Chunks are read from ``<name>_0``, ``<name>_1``, ... in order, stopping
    at the first missing index.

    Args:
        cookies: Mapping of cookie names to values.
        name: Base cookie name the chunk suffixes are appended to.

    Returns:
        The concatenated chunk values, or ``None`` when there are no chunks
        or the concatenation is empty.
    """
    chunk_parts = []
    index = 0
    while True:
        # Use the requested base name rather than a hard-coded global, so the
        # ``name`` argument is actually honoured.
        cookie_key = f"{name}_{index}"
        if cookie_key not in cookies:
            break
        chunk_parts.append(cookies[cookie_key])
        index += 1

    # An empty result is reported as "no cookie".
    return "".join(chunk_parts) or None
def get_token_from_cookies(cookies: dict[str, str]) -> Optional[str]:
    """
    Return the auth token stored in ``cookies``, if any.

    A non-empty single (unchunked) auth cookie wins; otherwise the token is
    rebuilt from the numbered chunk cookies.
    """
    return cookies.get(_auth_cookie_name) or _get_chunked_cookie(
        cookies, _auth_cookie_name
    )
def set_auth_cookie(request: Request, response: Response, token: str):
    """
    Helper function to set the authentication cookie with secure parameters
    and remove any leftover chunks from a previously larger token.

    Tokens longer than the chunk size are split across cookies named
    ``<auth cookie name>_<index>``; shorter tokens are stored in a single
    cookie under the auth cookie name.

    Args:
        request: Incoming request; its cookies are scanned for auth cookies
            left over from an earlier token.
        response: Response the Set-Cookie headers are written to.
        token: The token to store.
    """
    _chunk_size = 3000

    # Work out which cookies the new token occupies.
    if len(token) > _chunk_size:
        cookies_to_set = {
            f"{_auth_cookie_name}_{index}": token[start : start + _chunk_size]
            for index, start in enumerate(range(0, len(token), _chunk_size))
        }
    else:
        # Default (shorter cookies)
        cookies_to_set = {_auth_cookie_name: token}

    for cookie_key, cookie_value in cookies_to_set.items():
        response.set_cookie(
            key=cookie_key,
            value=cookie_value,
            # Set with the same path that is used for deletion below; a
            # cookie can only be removed with the path it was set with.
            path=_cookie_path,
            httponly=True,
            secure=_cookie_secure,
            samesite=_cookie_samesite,
            max_age=config.project.user_session_timeout,
        )

    # Delete remaining prior cookies/cookie chunks
    stale_cookies = {
        k
        for k in request.cookies
        if k.startswith(_auth_cookie_name) and k not in cookies_to_set
    }
    for k in stale_cookies:
        response.delete_cookie(
            key=k, path=_cookie_path, secure=_cookie_secure, samesite=_cookie_samesite
        )
def clear_auth_cookie(request: Request, response: Response):
    """
    Expire every auth cookie present on the request.

    Any request cookie whose name starts with the auth cookie name (the
    single cookie as well as numbered chunks) is deleted on the response.
    """
    doomed = [
        cookie_name
        for cookie_name in request.cookies
        if cookie_name.startswith(_auth_cookie_name)
    ]
    for cookie_name in doomed:
        response.delete_cookie(
            key=cookie_name,
            path=_cookie_path,
            secure=_cookie_secure,
            samesite=_cookie_samesite,
        )
def set_oauth_state_cookie(response: Response, token: str):
    """Store the OAuth state ``token`` in a short-lived, HTTP-only cookie."""
    response.set_cookie(
        key=_state_cookie_name,
        value=token,
        max_age=_state_cookie_lifetime,
        httponly=True,
        secure=_cookie_secure,
        samesite=_cookie_samesite,
    )
def validate_oauth_state_cookie(request: Request, state: str):
    """
    Compare the state returned by the OAuth provider with the browser cookie.

    Raises:
        Exception: When the state cookie is missing or differs from ``state``.
    """
    cookie_state = request.cookies.get(_state_cookie_name)
    if cookie_state == state:
        return
    raise Exception("oauth state does not correspond")
def clear_oauth_state_cookie(response: Response):
    """Delete the OAuth state cookie once the OAuth flow has completed."""
    # Open question carried over from the original author: should an explicit
    # path be passed here? None is passed, so the default applies.
    response.delete_cookie(key=_state_cookie_name)