import logging from collections.abc import Callable from contextlib import asynccontextmanager from typing import Any from urllib.parse import parse_qs, urljoin, urlparse import anyio import httpx from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import aconnect_sse from httpx_sse._exceptions import SSEError import mcp.types as types from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage logger = logging.getLogger(__name__) def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None: query_params = parse_qs(urlparse(endpoint_url).query) return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0] @asynccontextmanager async def sse_client( url: str, headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, auth: httpx.Auth | None = None, on_session_created: Callable[[str], None] | None = None, ): """ Client transport for SSE. `sse_read_timeout` determines how long (in seconds) the client will wait for a new event before disconnecting. All other HTTP operations are controlled by `timeout`. Args: url: The SSE endpoint URL. headers: Optional headers to include in requests. timeout: HTTP timeout for regular operations. sse_read_timeout: Timeout for SSE read operations. auth: Optional HTTPX authentication handler. on_session_created: Optional callback invoked with the session ID when received. """ read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) async with anyio.create_task_group() as tg: try: logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") async with httpx_client_factory( headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) ) as client: async with aconnect_sse( client, "GET", url, ) as event_source: event_source.response.raise_for_status() logger.debug("SSE connection established") async def sse_reader( task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, ): try: async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": endpoint_url = urljoin(url, sse.data) logger.debug(f"Received endpoint URL: {endpoint_url}") url_parsed = urlparse(url) endpoint_parsed = urlparse(endpoint_url) if ( # pragma: no cover url_parsed.netloc != endpoint_parsed.netloc or url_parsed.scheme != endpoint_parsed.scheme ): error_msg = ( # pragma: no cover f"Endpoint origin does not match connection origin: {endpoint_url}" ) logger.error(error_msg) # pragma: no cover raise ValueError(error_msg) # pragma: no cover if on_session_created: session_id = _extract_session_id_from_endpoint(endpoint_url) if session_id: on_session_created(session_id) task_status.started(endpoint_url) case "message": # Skip empty data (keep-alive pings) if not sse.data: continue try: message = types.JSONRPCMessage.model_validate_json( # noqa: E501 sse.data ) logger.debug(f"Received server message: {message}") except Exception as exc: # pragma: no cover logger.exception("Error parsing server message") # pragma: no cover await read_stream_writer.send(exc) # pragma: no cover continue # pragma: no cover session_message = SessionMessage(message) await read_stream_writer.send(session_message) case _: # pragma: no cover logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover except SSEError as sse_exc: # pragma: no cover logger.exception("Encountered SSE exception") # pragma: no cover raise sse_exc # pragma: no cover except Exception as exc: # pragma: no cover logger.exception("Error in sse_reader") # pragma: no cover await read_stream_writer.send(exc) # pragma: no cover finally: await read_stream_writer.aclose() async def post_writer(endpoint_url: str): try: async with write_stream_reader: async for session_message in write_stream_reader: logger.debug(f"Sending client message: {session_message}") response = await client.post( endpoint_url, json=session_message.message.model_dump( by_alias=True, mode="json", exclude_none=True, ), ) response.raise_for_status() logger.debug(f"Client message sent successfully: {response.status_code}") except Exception: # pragma: no cover logger.exception("Error in post_writer") # pragma: no cover finally: await write_stream.aclose() endpoint_url = await tg.start(sse_reader) logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") tg.start_soon(post_writer, endpoint_url) try: yield read_stream, write_stream finally: tg.cancel_scope.cancel() finally: await read_stream_writer.aclose() await write_stream.aclose()