"""
ServerSession Module
This module provides the ServerSession class, which manages communication between the
server and client in the MCP (Model Context Protocol) framework. It is most commonly
used in MCP servers to interact with the client.
Common usage pattern:
```
server = Server(name)
@server.call_tool()
async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any:
# Check client capabilities before proceeding
if ctx.session.check_client_capability(
types.ClientCapabilities(experimental={"advanced_tools": dict()})
):
# Perform advanced tool operations
result = await perform_advanced_tool_operation(arguments)
else:
# Fall back to basic tool operations
result = await perform_basic_tool_operation(arguments)
return result
@server.list_prompts()
async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
# Access session for any necessary checks or operations
if ctx.session.client_params:
# Customize prompts based on client initialization parameters
return generate_custom_prompts(ctx.session.client_params)
else:
return default_prompts
```
The ServerSession class is typically used internally by the Server class and should not
be instantiated directly by users of the MCP framework.
"""
from enum import Enum
from typing import Any, TypeVar, overload

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

import mcp.types as types
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
from mcp.server.models import InitializationOptions
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.experimental.tasks.capabilities import check_tasks_capability
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
    BaseSession,
    RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS


class InitializationState(Enum):
    NotInitialized = 1
    Initializing = 2
    Initialized = 3


ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")

ServerRequestResponder = (
    RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
)


class ServerSession(
    BaseSession[
        types.ServerRequest,
        types.ServerNotification,
        types.ServerResult,
        types.ClientRequest,
        types.ClientNotification,
    ]
):
    _initialized: InitializationState = InitializationState.NotInitialized
    _client_params: types.InitializeRequestParams | None = None
    _experimental_features: ExperimentalServerSessionFeatures | None = None

    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        init_options: InitializationOptions,
        stateless: bool = False,
    ) -> None:
        super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
        # Stateless sessions skip the initialize handshake entirely.
        self._initialization_state = (
            InitializationState.Initialized if stateless else InitializationState.NotInitialized
        )
        self._init_options = init_options
        # A zero-capacity stream hands each incoming message directly to the consumer.
        self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
            ServerRequestResponder
        ](0)
        self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())

    @property
    def client_params(self) -> types.InitializeRequestParams | None:
        return self._client_params  # pragma: no cover

    @property
    def experimental(self) -> ExperimentalServerSessionFeatures:
        """Experimental APIs for server→client task operations.

        WARNING: These APIs are experimental and may change without notice.
        """
        if self._experimental_features is None:
            self._experimental_features = ExperimentalServerSessionFeatures(self)
        return self._experimental_features

    def check_client_capability(self, capability: types.ClientCapabilities) -> bool:  # pragma: no cover
        """Check if the client supports a specific capability."""
        if self._client_params is None:
            return False

        client_caps = self._client_params.capabilities

        if capability.roots is not None:
            if client_caps.roots is None:
                return False
            if capability.roots.listChanged and not client_caps.roots.listChanged:
                return False

        if capability.sampling is not None:
            if client_caps.sampling is None:
                return False
            if capability.sampling.context is not None and client_caps.sampling.context is None:
                return False
            if capability.sampling.tools is not None and client_caps.sampling.tools is None:
                return False

        if capability.elicitation is not None and client_caps.elicitation is None:
            return False

        if capability.experimental is not None:
            if client_caps.experimental is None:
                return False
            for exp_key, exp_value in capability.experimental.items():
                if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
                    return False

        if capability.tasks is not None:
            if client_caps.tasks is None:
                return False
            if not check_tasks_capability(capability.tasks, client_caps.tasks):
                return False

        return True

    async def _receive_loop(self) -> None:
        async with self._incoming_message_stream_writer:
            await super()._receive_loop()

    async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
        match responder.request.root:
            case types.InitializeRequest(params=params):
                requested_version = params.protocolVersion
                self._initialization_state = InitializationState.Initializing
                self._client_params = params
                with responder:
                    await responder.respond(
                        types.ServerResult(
                            types.InitializeResult(
                                # Version negotiation: echo the client's requested version
                                # if we support it, otherwise offer the latest we support.
                                protocolVersion=requested_version
                                if requested_version in SUPPORTED_PROTOCOL_VERSIONS
                                else types.LATEST_PROTOCOL_VERSION,
                                capabilities=self._init_options.capabilities,
                                serverInfo=types.Implementation(
                                    name=self._init_options.server_name,
                                    version=self._init_options.server_version,
                                    websiteUrl=self._init_options.website_url,
                                    icons=self._init_options.icons,
                                ),
                                instructions=self._init_options.instructions,
                            )
                        )
                    )
                self._initialization_state = InitializationState.Initialized
            case types.PingRequest():
                # Ping requests are allowed at any time
                pass
            case _:
                if self._initialization_state != InitializationState.Initialized:
                    raise RuntimeError("Received request before initialization was complete")

    async def _received_notification(self, notification: types.ClientNotification) -> None:
        # Need this to avoid ASYNC910
        await anyio.lowlevel.checkpoint()
        match notification.root:
            case types.InitializedNotification():
                self._initialization_state = InitializationState.Initialized
            case _:
                if self._initialization_state != InitializationState.Initialized:  # pragma: no cover
                    raise RuntimeError("Received notification before initialization was complete")

    async def send_log_message(
        self,
        level: types.LoggingLevel,
        data: Any,
        logger: str | None = None,
        related_request_id: types.RequestId | None = None,
    ) -> None:
        """Send a log message notification."""
        await self.send_notification(
            types.ServerNotification(
                types.LoggingMessageNotification(
                    params=types.LoggingMessageNotificationParams(
                        level=level,
                        data=data,
                        logger=logger,
                    ),
                )
            ),
            related_request_id,
        )

    async def send_resource_updated(self, uri: AnyUrl) -> None:  # pragma: no cover
        """Send a resource updated notification."""
        await self.send_notification(
            types.ServerNotification(
                types.ResourceUpdatedNotification(
                    params=types.ResourceUpdatedNotificationParams(uri=uri),
                )
            )
        )

    @overload
    async def create_message(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: None = None,
        tool_choice: types.ToolChoice | None = None,
        related_request_id: types.RequestId | None = None,
    ) -> types.CreateMessageResult:
        """Overload: Without tools, returns single content."""
        ...

    @overload
    async def create_message(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: list[types.Tool],
        tool_choice: types.ToolChoice | None = None,
        related_request_id: types.RequestId | None = None,
    ) -> types.CreateMessageResultWithTools:
        """Overload: With tools, returns array-capable content."""
        ...

    async def create_message(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: list[types.Tool] | None = None,
        tool_choice: types.ToolChoice | None = None,
        related_request_id: types.RequestId | None = None,
    ) -> types.CreateMessageResult | types.CreateMessageResultWithTools:
        """Send a sampling/create_message request.

        Args:
            messages: The conversation messages to send.
            max_tokens: Maximum number of tokens to generate.
            system_prompt: Optional system prompt.
            include_context: Optional context inclusion setting.
                Should only be set to "thisServer" or "allServers"
                if the client has sampling.context capability.
            temperature: Optional sampling temperature.
            stop_sequences: Optional stop sequences.
            metadata: Optional metadata to pass through to the LLM provider.
            model_preferences: Optional model selection preferences.
            tools: Optional list of tools the LLM can use during sampling.
                Requires client to have sampling.tools capability.
            tool_choice: Optional control over tool usage behavior.
                Requires client to have sampling.tools capability.
            related_request_id: Optional ID of a related request.

        Returns:
            The sampling result from the client.

        Raises:
            McpError: If tools are provided but client doesn't support them.
            ValueError: If tool_use or tool_result message structure is invalid.
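
        Example (a minimal sketch; the message content is illustrative):
            result = await session.create_message(
                messages=[
                    types.SamplingMessage(
                        role="user",
                        content=types.TextContent(type="text", text="Summarize this file."),
                    )
                ],
                max_tokens=200,
            )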
"""
client_caps = self._client_params.capabilities if self._client_params else None
validate_sampling_tools(client_caps, tools, tool_choice)
validate_tool_use_result_messages(messages)
request = types.ServerRequest(
types.CreateMessageRequest(
params=types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
tools=tools,
toolChoice=tool_choice,
),
)
)
metadata_obj = ServerMessageMetadata(related_request_id=related_request_id)
# Use different result types based on whether tools are provided
if tools is not None:
return await self.send_request(
request=request,
result_type=types.CreateMessageResultWithTools,
metadata=metadata_obj,
)
return await self.send_request(
request=request,
result_type=types.CreateMessageResult,
metadata=metadata_obj,
)

    async def list_roots(self) -> types.ListRootsResult:
        """Send a roots/list request."""
        return await self.send_request(
            types.ServerRequest(types.ListRootsRequest()),
            types.ListRootsResult,
        )

    async def elicit(
        self,
        message: str,
        requestedSchema: types.ElicitRequestedSchema,
        related_request_id: types.RequestId | None = None,
    ) -> types.ElicitResult:
        """Send a form mode elicitation/create request.

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure
            related_request_id: Optional ID of the request that triggered this elicitation

        Returns:
            The client's response

        Note:
            This method is deprecated in favor of elicit_form(). It remains for
            backward compatibility but new code should use elicit_form().
        """
        return await self.elicit_form(message, requestedSchema, related_request_id)

    async def elicit_form(
        self,
        message: str,
        requestedSchema: types.ElicitRequestedSchema,
        related_request_id: types.RequestId | None = None,
    ) -> types.ElicitResult:
        """Send a form mode elicitation/create request.

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure
            related_request_id: Optional ID of the request that triggered this elicitation

        Returns:
            The client's response with form data
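
        Example (a sketch; the schema shown is illustrative):
            result = await session.elicit_form(
                message="Which branch should I deploy?",
                requestedSchema={
                    "type": "object",
                    "properties": {"branch": {"type": "string"}},
                    "required": ["branch"],
                },
            )
            if result.action == "accept" and result.content:
                branch = result.content["branch"]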
"""
return await self.send_request(
types.ServerRequest(
types.ElicitRequest(
params=types.ElicitRequestFormParams(
message=message,
requestedSchema=requestedSchema,
),
)
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
)

    async def elicit_url(
        self,
        message: str,
        url: str,
        elicitation_id: str,
        related_request_id: types.RequestId | None = None,
    ) -> types.ElicitResult:
        """Send a URL mode elicitation/create request.

        This directs the user to an external URL for out-of-band interactions
        like OAuth flows, credential collection, or payment processing.

        Args:
            message: Human-readable explanation of why the interaction is needed
            url: The URL the user should navigate to
            elicitation_id: Unique identifier for tracking this elicitation
            related_request_id: Optional ID of the request that triggered this elicitation

        Returns:
            The client's response indicating acceptance, decline, or cancellation
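
        Example (a sketch; the URL and id are illustrative):
            result = await session.elicit_url(
                message="Authorize access to your calendar",
                url="https://example.com/oauth/start?state=abc123",
                elicitation_id="elicit-abc123",
            )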
"""
return await self.send_request(
types.ServerRequest(
types.ElicitRequest(
params=types.ElicitRequestURLParams(
message=message,
url=url,
elicitationId=elicitation_id,
),
)
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
)

    async def send_ping(self) -> types.EmptyResult:  # pragma: no cover
        """Send a ping request."""
        return await self.send_request(
            types.ServerRequest(types.PingRequest()),
            types.EmptyResult,
        )

    async def send_progress_notification(
        self,
        progress_token: str | int,
        progress: float,
        total: float | None = None,
        message: str | None = None,
        related_request_id: str | None = None,
    ) -> None:
        """Send a progress notification."""
        await self.send_notification(
            types.ServerNotification(
                types.ProgressNotification(
                    params=types.ProgressNotificationParams(
                        progressToken=progress_token,
                        progress=progress,
                        total=total,
                        message=message,
                    ),
                )
            ),
            related_request_id,
        )

    async def send_resource_list_changed(self) -> None:  # pragma: no cover
        """Send a resource list changed notification."""
        await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification()))

    async def send_tool_list_changed(self) -> None:  # pragma: no cover
        """Send a tool list changed notification."""
        await self.send_notification(types.ServerNotification(types.ToolListChangedNotification()))

    async def send_prompt_list_changed(self) -> None:  # pragma: no cover
        """Send a prompt list changed notification."""
        await self.send_notification(types.ServerNotification(types.PromptListChangedNotification()))

    async def send_elicit_complete(
        self,
        elicitation_id: str,
        related_request_id: types.RequestId | None = None,
    ) -> None:
        """Send an elicitation completion notification.

        This should be sent when a URL mode elicitation has been completed
        out-of-band to inform the client that it may retry any requests
        that were waiting for this elicitation.

        Args:
            elicitation_id: The unique identifier of the completed elicitation
            related_request_id: Optional ID of the request that triggered this
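
        Example (a sketch continuing the elicit_url() example above):
            await session.send_elicit_complete(elicitation_id="elicit-abc123")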
"""
await self.send_notification(
types.ServerNotification(
types.ElicitCompleteNotification(
params=types.ElicitCompleteNotificationParams(elicitationId=elicitation_id)
)
),
related_request_id,
)

    def _build_elicit_form_request(
        self,
        message: str,
        requestedSchema: types.ElicitRequestedSchema,
        related_task_id: str | None = None,
        task: types.TaskMetadata | None = None,
    ) -> types.JSONRPCRequest:
        """Build a form mode elicitation request without sending it.

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure
            related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata
            task: If provided, makes this a task-augmented request

        Returns:
            A JSONRPCRequest ready to be sent or queued
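
        Note (behavior summarized from the code below): task-related requests get a
        synthetic id of the form "task-{related_task_id}-{id(params)}" and do not
        consume a slot in the session's monotonic request-id counter.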
"""
params = types.ElicitRequestFormParams(
message=message,
requestedSchema=requestedSchema,
task=task,
)
params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True)
# Add related-task metadata if associated with a parent task
if related_task_id is not None:
# Defensive: model_dump() never includes _meta, but guard against future changes
if "_meta" not in params_data: # pragma: no cover
params_data["_meta"] = {}
params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata(
taskId=related_task_id
).model_dump(by_alias=True)
request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
if related_task_id is None:
self._request_id += 1
return types.JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
method="elicitation/create",
params=params_data,
)

    def _build_elicit_url_request(
        self,
        message: str,
        url: str,
        elicitation_id: str,
        related_task_id: str | None = None,
    ) -> types.JSONRPCRequest:
        """Build a URL mode elicitation request without sending it.

        Args:
            message: Human-readable explanation of why the interaction is needed
            url: The URL the user should navigate to
            elicitation_id: Unique identifier for tracking this elicitation
            related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata

        Returns:
            A JSONRPCRequest ready to be sent or queued
        """
        params = types.ElicitRequestURLParams(
            message=message,
            url=url,
            elicitationId=elicitation_id,
        )
        params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True)

        # Add related-task metadata if associated with a parent task
        if related_task_id is not None:
            # Defensive: model_dump() never includes _meta, but guard against future changes
            if "_meta" not in params_data:  # pragma: no cover
                params_data["_meta"] = {}
            params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata(
                taskId=related_task_id
            ).model_dump(by_alias=True)

        request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
        if related_task_id is None:
            self._request_id += 1

        return types.JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            method="elicitation/create",
            params=params_data,
        )

    def _build_create_message_request(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: list[types.Tool] | None = None,
        tool_choice: types.ToolChoice | None = None,
        related_task_id: str | None = None,
        task: types.TaskMetadata | None = None,
    ) -> types.JSONRPCRequest:
        """Build a sampling/createMessage request without sending it.

        Args:
            messages: The conversation messages to send
            max_tokens: Maximum number of tokens to generate
            system_prompt: Optional system prompt
            include_context: Optional context inclusion setting
            temperature: Optional sampling temperature
            stop_sequences: Optional stop sequences
            metadata: Optional metadata to pass through to the LLM provider
            model_preferences: Optional model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior
            related_task_id: If provided, adds io.modelcontextprotocol/related-task metadata
            task: If provided, makes this a task-augmented request

        Returns:
            A JSONRPCRequest ready to be sent or queued
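
        Note (behavior summarized from the code below): unlike create_message(), this
        builder performs no capability or message-structure validation on its own.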
"""
params = types.CreateMessageRequestParams(
messages=messages,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
maxTokens=max_tokens,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
tools=tools,
toolChoice=tool_choice,
task=task,
)
params_data = params.model_dump(by_alias=True, mode="json", exclude_none=True)
# Add related-task metadata if associated with a parent task
if related_task_id is not None:
# Defensive: model_dump() never includes _meta, but guard against future changes
if "_meta" not in params_data: # pragma: no cover
params_data["_meta"] = {}
params_data["_meta"][RELATED_TASK_METADATA_KEY] = types.RelatedTaskMetadata(
taskId=related_task_id
).model_dump(by_alias=True)
request_id = f"task-{related_task_id}-{id(params)}" if related_task_id else self._request_id
if related_task_id is None:
self._request_id += 1
return types.JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
method="sampling/createMessage",
params=params_data,
)

    async def send_message(self, message: SessionMessage) -> None:
        """Send a raw session message.

        This is primarily used by TaskResultHandler to deliver queued messages
        (elicitation/sampling requests) to the client during task execution.

        WARNING: This is a low-level experimental method that may change without
        notice. Prefer using higher-level methods like send_notification() or
        send_request() for normal operations.

        Args:
            message: The session message to send
        """
        await self._write_stream.send(message)

    async def _handle_incoming(self, req: ServerRequestResponder) -> None:
        await self._incoming_message_stream_writer.send(req)

    @property
    def incoming_messages(
        self,
    ) -> MemoryObjectReceiveStream[ServerRequestResponder]:
        return self._incoming_message_stream_reader