ai-station/.venv/lib/python3.12/site-packages/mcp/client/session.py

616 lines
23 KiB
Python

import logging
from datetime import timedelta
from typing import Any, Protocol, overload
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
from typing_extensions import deprecated
import mcp.types as types
from mcp.client.experimental import ExperimentalClientFeatures
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
logger = logging.getLogger("client")
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
class LoggingFnT(Protocol):
async def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ... # pragma: no branch
class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ... # pragma: no branch
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)
async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
async def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._sampling_capabilities = sampling_capabilities
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
self._experimental_features: ExperimentalClientFeatures | None = None
# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
async def initialize(self) -> types.InitializeResult:
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
if self._sampling_callback is not _default_sampling_callback
else None
)
elicitation = (
types.ElicitationCapability(
form=types.FormElicitationCapability(),
url=types.UrlElicitationCapability(),
)
if self._elicitation_callback is not _default_elicitation_callback
else None
)
roots = (
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
types.RootsCapability(listChanged=True)
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
tasks=self._task_handlers.build_capability(),
),
clientInfo=self._client_info,
),
)
),
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
self._server_capabilities = result.capabilities
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
return result
def get_server_capabilities(self) -> types.ServerCapabilities | None:
"""Return the server capabilities received during initialization.
Returns None if the session has not been initialized yet.
"""
return self._server_capabilities
@property
def experimental(self) -> ExperimentalClientFeatures:
"""Experimental APIs for tasks and other features.
WARNING: These APIs are experimental and may change without notice.
Example:
status = await session.experimental.get_task(task_id)
result = await session.experimental.get_task_result(task_id, CallToolResult)
"""
if self._experimental_features is None:
self._experimental_features = ExperimentalClientFeatures(self)
return self._experimental_features
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
types.ClientRequest(types.PingRequest()),
types.EmptyResult,
)
async def send_progress_notification(
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
types.ClientNotification(
types.ProgressNotification(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
message=message,
),
),
)
)
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SetLevelRequest(
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
@overload
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self) -> types.ListResourcesResult: ...
async def list_resources(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourcesResult:
"""Send a resources/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
types.ListResourcesResult,
)
@overload
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(
self, *, params: types.PaginatedRequestParams | None
) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...
async def list_resource_templates(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)),
types.ListResourceTemplatesResult,
)
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SubscribeRequest(
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.UnsubscribeRequest(
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""
_meta: types.RequestParams.Meta | None = None
if meta is not None:
_meta = types.RequestParams.Meta(**meta)
result = await self.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)
if not result.isError:
await self._validate_tool_result(name, result)
return result
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
"""Validate the structured content of a tool result against its output schema."""
if name not in self._tool_output_schemas:
# refresh output schema cache
await self.list_tools()
output_schema = None
if name in self._tool_output_schemas:
output_schema = self._tool_output_schemas.get(name)
else:
logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
if output_schema is not None:
from jsonschema import SchemaError, ValidationError, validate
if result.structuredContent is None:
raise RuntimeError(
f"Tool {name} has an output schema but did not return structured content"
) # pragma: no cover
try:
validate(result.structuredContent, output_schema)
except ValidationError as e:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover
except SchemaError as e: # pragma: no cover
raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover
@overload
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self) -> types.ListPromptsResult: ...
async def list_prompts(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListPromptsResult:
"""Send a prompts/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListPromptsRequest(params=request_params)),
types.ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return await self.send_request(
types.ClientRequest(
types.GetPromptRequest(
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
async def complete(
self,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
context_arguments: dict[str, str] | None = None,
) -> types.CompleteResult:
"""Send a completion/complete request."""
context = None
if context_arguments is not None:
context = types.CompletionContext(arguments=context_arguments)
return await self.send_request(
types.ClientRequest(
types.CompleteRequest(
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
context=context,
),
)
),
types.CompleteResult,
)
@overload
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self) -> types.ListToolsResult: ...
async def list_tools(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
"""Send a tools/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
result = await self.send_request(
types.ClientRequest(types.ListToolsRequest(params=request_params)),
types.ListToolsResult,
)
# Cache tool output schemas for future validation
# Note: don't clear the cache, as we may be using a cursor
for tool in result.tools:
self._tool_output_schemas[tool.name] = tool.outputSchema
return result
async def send_roots_list_changed(self) -> None: # pragma: no cover
"""Send a roots/list_changed notification."""
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
# Delegate to experimental task handler if applicable
if self._task_handlers.handles_request(responder.request):
with responder:
await self._task_handlers.handle_request(ctx, responder)
return None
# Core request handling
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
# Check if this is a task-augmented request
if params.task is not None:
response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
else:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ElicitRequest(params=params):
with responder:
# Check if this is a task-augmented request
if params.task is not None:
response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
else:
response = await self._elicitation_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.PingRequest(): # pragma: no cover
with responder:
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
case _: # pragma: no cover
pass # Task requests handled above by _task_handlers
return None
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
async def _received_notification(self, notification: types.ServerNotification) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case types.ElicitCompleteNotification(params=params):
# Handle elicitation completion notification
# Clients MAY use this to retry requests or update UI
# The notification contains the elicitationId of the completed elicitation
pass
case _:
pass