# Source: ai-station/.venv/lib/python3.12/site-packages/mcp/server/experimental/task_context.py
# (613 lines, 22 KiB, Python; snapshot dated 2025-12-25 14:54:33 +00:00)
"""
ServerTaskContext - Server-integrated task context with elicitation and sampling.
This wraps the pure TaskContext and adds server-specific functionality:
- Elicitation (task.elicit())
- Sampling (task.create_message())
- Status notifications
"""
from typing import Any
import anyio
from mcp.server.experimental.task_result_handler import TaskResultHandler
from mcp.server.session import ServerSession
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.capabilities import (
require_task_augmented_elicitation,
require_task_augmented_sampling,
)
from mcp.shared.experimental.tasks.context import TaskContext
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import (
INVALID_REQUEST,
TASK_STATUS_INPUT_REQUIRED,
TASK_STATUS_WORKING,
ClientCapabilities,
CreateMessageResult,
CreateTaskResult,
ElicitationCapability,
ElicitRequestedSchema,
ElicitResult,
ErrorData,
IncludeContext,
ModelPreferences,
RequestId,
Result,
SamplingCapability,
SamplingMessage,
ServerNotification,
Task,
TaskMetadata,
TaskStatusNotification,
TaskStatusNotificationParams,
Tool,
ToolChoice,
)
class ServerTaskContext:
    """
    Server-integrated task context with elicitation and sampling.

    This wraps a pure TaskContext and adds server-specific functionality:
    - elicit() / elicit_url() for sending elicitation requests to the client
    - create_message() for sampling requests
    - elicit_as_task() / create_message_as_task() for task-augmented variants
    - Status notifications via the session

    Example:
        async def my_task_work(task: ServerTaskContext) -> CallToolResult:
            await task.update_status("Starting...")

            result = await task.elicit(
                message="Continue?",
                requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}}
            )

            if result.content.get("ok"):
                return CallToolResult(content=[TextContent(text="Done!")])
            else:
                return CallToolResult(content=[TextContent(text="Cancelled")])
    """

    def __init__(
        self,
        *,
        task: Task,
        store: TaskStore,
        session: ServerSession,
        queue: TaskMessageQueue,
        handler: TaskResultHandler | None = None,
    ):
        """
        Create a ServerTaskContext.

        Args:
            task: The Task object
            store: The task store
            session: The server session
            queue: The message queue for elicitation/sampling
            handler: The result handler for response routing
                (required for every elicit*/create_message* method)
        """
        self._ctx = TaskContext(task=task, store=store)
        self._session = session
        self._queue = queue
        self._handler = handler
        # Kept separately from the inner context because the request methods
        # flip the task status directly on the store.
        self._store = store

    # ------------------------------------------------------------------
    # Delegate pure properties to inner context
    # ------------------------------------------------------------------

    @property
    def task_id(self) -> str:
        """The task identifier."""
        return self._ctx.task_id

    @property
    def task(self) -> Task:
        """The current task state."""
        return self._ctx.task

    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation has been requested."""
        return self._ctx.is_cancelled

    def request_cancellation(self) -> None:
        """Request cancellation of this task."""
        self._ctx.request_cancellation()

    # ------------------------------------------------------------------
    # Enhanced methods with notifications
    # ------------------------------------------------------------------

    async def update_status(self, message: str, *, notify: bool = True) -> None:
        """
        Update the task's status message.

        Args:
            message: The new status message
            notify: Whether to send a notification to the client
        """
        await self._ctx.update_status(message)
        if notify:
            await self._send_notification()

    async def complete(self, result: Result, *, notify: bool = True) -> None:
        """
        Mark the task as completed with the given result.

        Args:
            result: The task result
            notify: Whether to send a notification to the client
        """
        await self._ctx.complete(result)
        if notify:
            await self._send_notification()

    async def fail(self, error: str, *, notify: bool = True) -> None:
        """
        Mark the task as failed with an error message.

        Args:
            error: The error message
            notify: Whether to send a notification to the client
        """
        await self._ctx.fail(error)
        if notify:
            await self._send_notification()

    async def _send_notification(self) -> None:
        """Send a task status notification built from the current task state."""
        task = self._ctx.task
        await self._session.send_notification(
            ServerNotification(
                TaskStatusNotification(
                    params=TaskStatusNotificationParams(
                        taskId=task.taskId,
                        status=task.status,
                        statusMessage=task.statusMessage,
                        createdAt=task.createdAt,
                        lastUpdatedAt=task.lastUpdatedAt,
                        ttl=task.ttl,
                        pollInterval=task.pollInterval,
                    )
                )
            )
        )

    # ------------------------------------------------------------------
    # Internal helpers shared by the elicitation / sampling methods
    # ------------------------------------------------------------------

    def _check_elicitation_capability(self) -> None:
        """Raise McpError(INVALID_REQUEST) unless the client supports elicitation."""
        if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
            raise McpError(
                ErrorData(
                    code=INVALID_REQUEST,
                    message="Client does not support elicitation capability",
                )
            )

    def _check_sampling_capability(self) -> None:
        """Raise McpError(INVALID_REQUEST) unless the client supports sampling."""
        if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
            raise McpError(
                ErrorData(
                    code=INVALID_REQUEST,
                    message="Client does not support sampling capability",
                )
            )

    async def _queue_request(self, handler: TaskResultHandler, request: Any) -> Resolver[dict[str, Any]]:
        """
        Register a resolver for ``request`` and enqueue it for this task.

        Args:
            handler: The result handler that routes the client's response
            request: The request built by one of the session's ``_build_*`` helpers

        Returns:
            The resolver the caller waits on for the client's response
        """
        request_id: RequestId = request.id
        resolver: Resolver[dict[str, Any]] = Resolver()
        # Register before enqueueing so the resolver is already in place
        # when the request is picked up from the queue.
        # NOTE(review): the entry is not removed here if the caller is later
        # cancelled - confirm the handler cleans up stale resolvers.
        handler._pending_requests[request_id] = resolver  # pyright: ignore[reportPrivateUsage]
        await self._queue.enqueue(
            self.task_id,
            QueuedMessage(
                type="request",
                message=request,
                resolver=resolver,
                original_request_id=request_id,
            ),
        )
        return resolver

    async def _restore_working_status(self) -> None:
        """
        Set the task status back to "working" while handling a cancellation.

        Inside a cancelled scope every checkpoint raises the cancellation
        exception again, so the store update is shielded; otherwise a store
        that really suspends would be interrupted and the task would stay in
        "input_required".
        """
        with anyio.CancelScope(shield=True):
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)

    # ------------------------------------------------------------------
    # Server-specific methods: elicitation and sampling
    # ------------------------------------------------------------------

    async def elicit(
        self,
        message: str,
        requestedSchema: ElicitRequestedSchema,
    ) -> ElicitResult:
        """
        Send an elicitation request via the task message queue.

        This method:
        1. Checks client capability
        2. Updates task status to "input_required"
        3. Queues the elicitation request
        4. Waits for the response (delivered via tasks/result round-trip)
        5. Updates task status back to "working"
        6. Returns the result

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure

        Returns:
            The client's response

        Raises:
            McpError: If client doesn't support elicitation capability
            RuntimeError: If handler is not configured
        """
        self._check_elicitation_capability()
        handler = self._handler
        if handler is None:
            raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")

        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        request = self._session._build_elicit_form_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            requestedSchema=requestedSchema,
            related_task_id=self.task_id,
        )
        resolver = await self._queue_request(handler, request)

        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return ElicitResult.model_validate(response_data)
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            # Coverage can't track async exception handlers reliably.
            # This path is tested in test_elicit_restores_status_on_cancellation
            # which verifies status is restored to "working" after cancellation.
            await self._restore_working_status()
            raise

    async def elicit_url(
        self,
        message: str,
        url: str,
        elicitation_id: str,
    ) -> ElicitResult:
        """
        Send a URL mode elicitation request via the task message queue.

        This directs the user to an external URL for out-of-band interactions
        like OAuth flows, credential collection, or payment processing.

        This method:
        1. Checks client capability
        2. Updates task status to "input_required"
        3. Queues the elicitation request
        4. Waits for the response (delivered via tasks/result round-trip)
        5. Updates task status back to "working"
        6. Returns the result

        Args:
            message: Human-readable explanation of why the interaction is needed
            url: The URL the user should navigate to
            elicitation_id: Unique identifier for tracking this elicitation

        Returns:
            The client's response indicating acceptance, decline, or cancellation

        Raises:
            McpError: If client doesn't support elicitation capability
            RuntimeError: If handler is not configured
        """
        self._check_elicitation_capability()
        handler = self._handler
        if handler is None:
            raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")

        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        request = self._session._build_elicit_url_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            url=url,
            elicitation_id=elicitation_id,
            related_task_id=self.task_id,
        )
        resolver = await self._queue_request(handler, request)

        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return ElicitResult.model_validate(response_data)
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            await self._restore_working_status()
            raise

    async def create_message(
        self,
        messages: list[SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: ModelPreferences | None = None,
        tools: list[Tool] | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> CreateMessageResult:
        """
        Send a sampling request via the task message queue.

        This method:
        1. Checks client capability
        2. Updates task status to "input_required"
        3. Queues the sampling request
        4. Waits for the response (delivered via tasks/result round-trip)
        5. Updates task status back to "working"
        6. Returns the result

        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior

        Returns:
            The sampling result from the client

        Raises:
            McpError: If client doesn't support sampling capability or tools
            ValueError: If tool_use or tool_result message structure is invalid
            RuntimeError: If handler is not configured
        """
        self._check_sampling_capability()
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)

        handler = self._handler
        if handler is None:
            raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")

        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        request = self._session._build_create_message_request(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            include_context=include_context,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=metadata,
            model_preferences=model_preferences,
            tools=tools,
            tool_choice=tool_choice,
            related_task_id=self.task_id,
        )
        resolver = await self._queue_request(handler, request)

        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return CreateMessageResult.model_validate(response_data)
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            # Coverage can't track async exception handlers reliably.
            # This path is tested in test_create_message_restores_status_on_cancellation
            # which verifies status is restored to "working" after cancellation.
            await self._restore_working_status()
            raise

    async def elicit_as_task(
        self,
        message: str,
        requestedSchema: ElicitRequestedSchema,
        *,
        ttl: int = 60000,
    ) -> ElicitResult:
        """
        Send a task-augmented elicitation via the queue, then poll client.

        This is for use inside a task-augmented tool call when you want the client
        to handle the elicitation as its own task. The elicitation request is queued
        and delivered when the client calls tasks/result. After the client responds
        with CreateTaskResult, we poll the client's task until complete.

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure
            ttl: Task time-to-live in milliseconds for the client's task

        Returns:
            The client's elicitation response

        Raises:
            McpError: If client doesn't support task-augmented elicitation
            RuntimeError: If handler is not configured
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_elicitation(client_caps)

        handler = self._handler
        if handler is None:
            raise RuntimeError("handler is required for elicit_as_task()")

        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        # Build request WITH task field so the client runs it as its own task
        request = self._session._build_elicit_form_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            requestedSchema=requestedSchema,
            related_task_id=self.task_id,
            task=TaskMetadata(ttl=ttl),
        )
        resolver = await self._queue_request(handler, request)

        try:
            # Wait for initial response (CreateTaskResult from client)
            response_data = await resolver.wait()
            create_result = CreateTaskResult.model_validate(response_data)
            client_task_id = create_result.task.taskId

            # Drain the poll iterator; the intermediate updates are not used here
            async for _ in self._session.experimental.poll_task(client_task_id):
                pass

            # Get final result from client
            result = await self._session.experimental.get_task_result(
                client_task_id,
                ElicitResult,
            )

            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return result
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            await self._restore_working_status()
            raise

    async def create_message_as_task(
        self,
        messages: list[SamplingMessage],
        *,
        max_tokens: int,
        ttl: int = 60000,
        system_prompt: str | None = None,
        include_context: IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: ModelPreferences | None = None,
        tools: list[Tool] | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> CreateMessageResult:
        """
        Send a task-augmented sampling request via the queue, then poll client.

        This is for use inside a task-augmented tool call when you want the client
        to handle the sampling as its own task. The request is queued and delivered
        when the client calls tasks/result. After the client responds with
        CreateTaskResult, we poll the client's task until complete.

        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            ttl: Task time-to-live in milliseconds for the client's task
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior

        Returns:
            The sampling result from the client

        Raises:
            McpError: If client doesn't support task-augmented sampling or tools
            ValueError: If tool_use or tool_result message structure is invalid
            RuntimeError: If handler is not configured
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_sampling(client_caps)
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)

        handler = self._handler
        if handler is None:
            raise RuntimeError("handler is required for create_message_as_task()")

        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        # Build request WITH task field for task-augmented sampling
        request = self._session._build_create_message_request(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            include_context=include_context,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=metadata,
            model_preferences=model_preferences,
            tools=tools,
            tool_choice=tool_choice,
            related_task_id=self.task_id,
            task=TaskMetadata(ttl=ttl),
        )
        resolver = await self._queue_request(handler, request)

        try:
            # Wait for initial response (CreateTaskResult from client)
            response_data = await resolver.wait()
            create_result = CreateTaskResult.model_validate(response_data)
            client_task_id = create_result.task.taskId

            # Drain the poll iterator; the intermediate updates are not used here
            async for _ in self._session.experimental.poll_task(client_task_id):
                pass

            # Get final result from client
            result = await self._session.experimental.get_task_result(
                client_task_id,
                CreateMessageResult,
            )

            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return result
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            await self._restore_working_status()
            raise