ai-station/.venv/lib/python3.12/site-packages/mcp/server/lowlevel/experimental.py

289 lines
10 KiB
Python
Raw Normal View History

2025-12-25 14:54:33 +00:00
"""Experimental handlers for the low-level MCP server.
WARNING: These APIs are experimental and may change without notice.
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING
from mcp.server.experimental.task_support import TaskSupport
from mcp.server.lowlevel.func_inspection import create_call_wrapper
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.helpers import cancel_task
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import (
INVALID_PARAMS,
CancelTaskRequest,
CancelTaskResult,
ErrorData,
GetTaskPayloadRequest,
GetTaskPayloadResult,
GetTaskRequest,
GetTaskResult,
ListTasksRequest,
ListTasksResult,
ServerCapabilities,
ServerResult,
ServerTasksCapability,
ServerTasksRequestsCapability,
TasksCancelCapability,
TasksListCapability,
TasksToolsCapability,
)
if TYPE_CHECKING:
from mcp.server.lowlevel.server import Server
logger = logging.getLogger(__name__)
class ExperimentalHandlers:
"""Experimental request/notification handlers.
WARNING: These APIs are experimental and may change without notice.
"""
def __init__(
self,
server: Server,
request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]],
notification_handlers: dict[type, Callable[..., Awaitable[None]]],
):
self._server = server
self._request_handlers = request_handlers
self._notification_handlers = notification_handlers
self._task_support: TaskSupport | None = None
@property
def task_support(self) -> TaskSupport | None:
"""Get the task support configuration, if enabled."""
return self._task_support
def update_capabilities(self, capabilities: ServerCapabilities) -> None:
# Only add tasks capability if handlers are registered
if not any(
req_type in self._request_handlers
for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest]
):
return
capabilities.tasks = ServerTasksCapability()
if ListTasksRequest in self._request_handlers:
capabilities.tasks.list = TasksListCapability()
if CancelTaskRequest in self._request_handlers:
capabilities.tasks.cancel = TasksCancelCapability()
capabilities.tasks.requests = ServerTasksRequestsCapability(
tools=TasksToolsCapability()
) # assuming always supported for now
def enable_tasks(
self,
store: TaskStore | None = None,
queue: TaskMessageQueue | None = None,
) -> TaskSupport:
"""
Enable experimental task support.
This sets up the task infrastructure and auto-registers default handlers
for tasks/get, tasks/result, tasks/list, and tasks/cancel.
Args:
store: Custom TaskStore implementation (defaults to InMemoryTaskStore)
queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue)
Returns:
The TaskSupport configuration object
Example:
# Simple in-memory setup
server.experimental.enable_tasks()
# Custom store/queue for distributed systems
server.experimental.enable_tasks(
store=RedisTaskStore(redis_url),
queue=RedisTaskMessageQueue(redis_url),
)
WARNING: This API is experimental and may change without notice.
"""
if store is None:
store = InMemoryTaskStore()
if queue is None:
queue = InMemoryTaskMessageQueue()
self._task_support = TaskSupport(store=store, queue=queue)
# Auto-register default handlers
self._register_default_task_handlers()
return self._task_support
def _register_default_task_handlers(self) -> None:
"""Register default handlers for task operations."""
assert self._task_support is not None
support = self._task_support
# Register get_task handler if not already registered
if GetTaskRequest not in self._request_handlers:
async def _default_get_task(req: GetTaskRequest) -> ServerResult:
task = await support.store.get_task(req.params.taskId)
if task is None:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Task not found: {req.params.taskId}",
)
)
return ServerResult(
GetTaskResult(
taskId=task.taskId,
status=task.status,
statusMessage=task.statusMessage,
createdAt=task.createdAt,
lastUpdatedAt=task.lastUpdatedAt,
ttl=task.ttl,
pollInterval=task.pollInterval,
)
)
self._request_handlers[GetTaskRequest] = _default_get_task
# Register get_task_result handler if not already registered
if GetTaskPayloadRequest not in self._request_handlers:
async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult:
ctx = self._server.request_context
result = await support.handler.handle(req, ctx.session, ctx.request_id)
return ServerResult(result)
self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result
# Register list_tasks handler if not already registered
if ListTasksRequest not in self._request_handlers:
async def _default_list_tasks(req: ListTasksRequest) -> ServerResult:
cursor = req.params.cursor if req.params else None
tasks, next_cursor = await support.store.list_tasks(cursor)
return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor))
self._request_handlers[ListTasksRequest] = _default_list_tasks
# Register cancel_task handler if not already registered
if CancelTaskRequest not in self._request_handlers:
async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult:
result = await cancel_task(support.store, req.params.taskId)
return ServerResult(result)
self._request_handlers[CancelTaskRequest] = _default_cancel_task
def list_tasks(
self,
) -> Callable[
[Callable[[ListTasksRequest], Awaitable[ListTasksResult]]],
Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
]:
"""Register a handler for listing tasks.
WARNING: This API is experimental and may change without notice.
"""
def decorator(
func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]:
logger.debug("Registering handler for ListTasksRequest")
wrapper = create_call_wrapper(func, ListTasksRequest)
async def handler(req: ListTasksRequest) -> ServerResult:
result = await wrapper(req)
return ServerResult(result)
self._request_handlers[ListTasksRequest] = handler
return func
return decorator
def get_task(
self,
) -> Callable[
[Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]]
]:
"""Register a handler for getting task status.
WARNING: This API is experimental and may change without notice.
"""
def decorator(
func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]],
) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]:
logger.debug("Registering handler for GetTaskRequest")
wrapper = create_call_wrapper(func, GetTaskRequest)
async def handler(req: GetTaskRequest) -> ServerResult:
result = await wrapper(req)
return ServerResult(result)
self._request_handlers[GetTaskRequest] = handler
return func
return decorator
def get_task_result(
self,
) -> Callable[
[Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]],
Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
]:
"""Register a handler for getting task results/payload.
WARNING: This API is experimental and may change without notice.
"""
def decorator(
func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]:
logger.debug("Registering handler for GetTaskPayloadRequest")
wrapper = create_call_wrapper(func, GetTaskPayloadRequest)
async def handler(req: GetTaskPayloadRequest) -> ServerResult:
result = await wrapper(req)
return ServerResult(result)
self._request_handlers[GetTaskPayloadRequest] = handler
return func
return decorator
def cancel_task(
self,
) -> Callable[
[Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]],
Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
]:
"""Register a handler for cancelling tasks.
WARNING: This API is experimental and may change without notice.
"""
def decorator(
func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]:
logger.debug("Registering handler for CancelTaskRequest")
wrapper = create_call_wrapper(func, CancelTaskRequest)
async def handler(req: CancelTaskRequest) -> ServerResult:
result = await wrapper(req)
return ServerResult(result)
self._request_handlers[CancelTaskRequest] = handler
return func
return decorator