289 lines
10 KiB
Python
289 lines
10 KiB
Python
|
|
"""Experimental handlers for the low-level MCP server.
|
||
|
|
|
||
|
|
WARNING: These APIs are experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from collections.abc import Awaitable, Callable
|
||
|
|
from typing import TYPE_CHECKING
|
||
|
|
|
||
|
|
from mcp.server.experimental.task_support import TaskSupport
|
||
|
|
from mcp.server.lowlevel.func_inspection import create_call_wrapper
|
||
|
|
from mcp.shared.exceptions import McpError
|
||
|
|
from mcp.shared.experimental.tasks.helpers import cancel_task
|
||
|
|
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
|
||
|
|
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
|
||
|
|
from mcp.shared.experimental.tasks.store import TaskStore
|
||
|
|
from mcp.types import (
|
||
|
|
INVALID_PARAMS,
|
||
|
|
CancelTaskRequest,
|
||
|
|
CancelTaskResult,
|
||
|
|
ErrorData,
|
||
|
|
GetTaskPayloadRequest,
|
||
|
|
GetTaskPayloadResult,
|
||
|
|
GetTaskRequest,
|
||
|
|
GetTaskResult,
|
||
|
|
ListTasksRequest,
|
||
|
|
ListTasksResult,
|
||
|
|
ServerCapabilities,
|
||
|
|
ServerResult,
|
||
|
|
ServerTasksCapability,
|
||
|
|
ServerTasksRequestsCapability,
|
||
|
|
TasksCancelCapability,
|
||
|
|
TasksListCapability,
|
||
|
|
TasksToolsCapability,
|
||
|
|
)
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from mcp.server.lowlevel.server import Server
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class ExperimentalHandlers:
|
||
|
|
"""Experimental request/notification handlers.
|
||
|
|
|
||
|
|
WARNING: These APIs are experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
server: Server,
|
||
|
|
request_handlers: dict[type, Callable[..., Awaitable[ServerResult]]],
|
||
|
|
notification_handlers: dict[type, Callable[..., Awaitable[None]]],
|
||
|
|
):
|
||
|
|
self._server = server
|
||
|
|
self._request_handlers = request_handlers
|
||
|
|
self._notification_handlers = notification_handlers
|
||
|
|
self._task_support: TaskSupport | None = None
|
||
|
|
|
||
|
|
@property
|
||
|
|
def task_support(self) -> TaskSupport | None:
|
||
|
|
"""Get the task support configuration, if enabled."""
|
||
|
|
return self._task_support
|
||
|
|
|
||
|
|
def update_capabilities(self, capabilities: ServerCapabilities) -> None:
|
||
|
|
# Only add tasks capability if handlers are registered
|
||
|
|
if not any(
|
||
|
|
req_type in self._request_handlers
|
||
|
|
for req_type in [GetTaskRequest, ListTasksRequest, CancelTaskRequest, GetTaskPayloadRequest]
|
||
|
|
):
|
||
|
|
return
|
||
|
|
|
||
|
|
capabilities.tasks = ServerTasksCapability()
|
||
|
|
if ListTasksRequest in self._request_handlers:
|
||
|
|
capabilities.tasks.list = TasksListCapability()
|
||
|
|
if CancelTaskRequest in self._request_handlers:
|
||
|
|
capabilities.tasks.cancel = TasksCancelCapability()
|
||
|
|
|
||
|
|
capabilities.tasks.requests = ServerTasksRequestsCapability(
|
||
|
|
tools=TasksToolsCapability()
|
||
|
|
) # assuming always supported for now
|
||
|
|
|
||
|
|
def enable_tasks(
|
||
|
|
self,
|
||
|
|
store: TaskStore | None = None,
|
||
|
|
queue: TaskMessageQueue | None = None,
|
||
|
|
) -> TaskSupport:
|
||
|
|
"""
|
||
|
|
Enable experimental task support.
|
||
|
|
|
||
|
|
This sets up the task infrastructure and auto-registers default handlers
|
||
|
|
for tasks/get, tasks/result, tasks/list, and tasks/cancel.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
store: Custom TaskStore implementation (defaults to InMemoryTaskStore)
|
||
|
|
queue: Custom TaskMessageQueue implementation (defaults to InMemoryTaskMessageQueue)
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
The TaskSupport configuration object
|
||
|
|
|
||
|
|
Example:
|
||
|
|
# Simple in-memory setup
|
||
|
|
server.experimental.enable_tasks()
|
||
|
|
|
||
|
|
# Custom store/queue for distributed systems
|
||
|
|
server.experimental.enable_tasks(
|
||
|
|
store=RedisTaskStore(redis_url),
|
||
|
|
queue=RedisTaskMessageQueue(redis_url),
|
||
|
|
)
|
||
|
|
|
||
|
|
WARNING: This API is experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
if store is None:
|
||
|
|
store = InMemoryTaskStore()
|
||
|
|
if queue is None:
|
||
|
|
queue = InMemoryTaskMessageQueue()
|
||
|
|
|
||
|
|
self._task_support = TaskSupport(store=store, queue=queue)
|
||
|
|
|
||
|
|
# Auto-register default handlers
|
||
|
|
self._register_default_task_handlers()
|
||
|
|
|
||
|
|
return self._task_support
|
||
|
|
|
||
|
|
def _register_default_task_handlers(self) -> None:
|
||
|
|
"""Register default handlers for task operations."""
|
||
|
|
assert self._task_support is not None
|
||
|
|
support = self._task_support
|
||
|
|
|
||
|
|
# Register get_task handler if not already registered
|
||
|
|
if GetTaskRequest not in self._request_handlers:
|
||
|
|
|
||
|
|
async def _default_get_task(req: GetTaskRequest) -> ServerResult:
|
||
|
|
task = await support.store.get_task(req.params.taskId)
|
||
|
|
if task is None:
|
||
|
|
raise McpError(
|
||
|
|
ErrorData(
|
||
|
|
code=INVALID_PARAMS,
|
||
|
|
message=f"Task not found: {req.params.taskId}",
|
||
|
|
)
|
||
|
|
)
|
||
|
|
return ServerResult(
|
||
|
|
GetTaskResult(
|
||
|
|
taskId=task.taskId,
|
||
|
|
status=task.status,
|
||
|
|
statusMessage=task.statusMessage,
|
||
|
|
createdAt=task.createdAt,
|
||
|
|
lastUpdatedAt=task.lastUpdatedAt,
|
||
|
|
ttl=task.ttl,
|
||
|
|
pollInterval=task.pollInterval,
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
self._request_handlers[GetTaskRequest] = _default_get_task
|
||
|
|
|
||
|
|
# Register get_task_result handler if not already registered
|
||
|
|
if GetTaskPayloadRequest not in self._request_handlers:
|
||
|
|
|
||
|
|
async def _default_get_task_result(req: GetTaskPayloadRequest) -> ServerResult:
|
||
|
|
ctx = self._server.request_context
|
||
|
|
result = await support.handler.handle(req, ctx.session, ctx.request_id)
|
||
|
|
return ServerResult(result)
|
||
|
|
|
||
|
|
self._request_handlers[GetTaskPayloadRequest] = _default_get_task_result
|
||
|
|
|
||
|
|
# Register list_tasks handler if not already registered
|
||
|
|
if ListTasksRequest not in self._request_handlers:
|
||
|
|
|
||
|
|
async def _default_list_tasks(req: ListTasksRequest) -> ServerResult:
|
||
|
|
cursor = req.params.cursor if req.params else None
|
||
|
|
tasks, next_cursor = await support.store.list_tasks(cursor)
|
||
|
|
return ServerResult(ListTasksResult(tasks=tasks, nextCursor=next_cursor))
|
||
|
|
|
||
|
|
self._request_handlers[ListTasksRequest] = _default_list_tasks
|
||
|
|
|
||
|
|
# Register cancel_task handler if not already registered
|
||
|
|
if CancelTaskRequest not in self._request_handlers:
|
||
|
|
|
||
|
|
async def _default_cancel_task(req: CancelTaskRequest) -> ServerResult:
|
||
|
|
result = await cancel_task(support.store, req.params.taskId)
|
||
|
|
return ServerResult(result)
|
||
|
|
|
||
|
|
self._request_handlers[CancelTaskRequest] = _default_cancel_task
|
||
|
|
|
||
|
|
def list_tasks(
|
||
|
|
self,
|
||
|
|
) -> Callable[
|
||
|
|
[Callable[[ListTasksRequest], Awaitable[ListTasksResult]]],
|
||
|
|
Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
|
||
|
|
]:
|
||
|
|
"""Register a handler for listing tasks.
|
||
|
|
|
||
|
|
WARNING: This API is experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def decorator(
|
||
|
|
func: Callable[[ListTasksRequest], Awaitable[ListTasksResult]],
|
||
|
|
) -> Callable[[ListTasksRequest], Awaitable[ListTasksResult]]:
|
||
|
|
logger.debug("Registering handler for ListTasksRequest")
|
||
|
|
wrapper = create_call_wrapper(func, ListTasksRequest)
|
||
|
|
|
||
|
|
async def handler(req: ListTasksRequest) -> ServerResult:
|
||
|
|
result = await wrapper(req)
|
||
|
|
return ServerResult(result)
|
||
|
|
|
||
|
|
self._request_handlers[ListTasksRequest] = handler
|
||
|
|
return func
|
||
|
|
|
||
|
|
return decorator
|
||
|
|
|
||
|
|
def get_task(
|
||
|
|
self,
|
||
|
|
) -> Callable[
|
||
|
|
[Callable[[GetTaskRequest], Awaitable[GetTaskResult]]], Callable[[GetTaskRequest], Awaitable[GetTaskResult]]
|
||
|
|
]:
|
||
|
|
"""Register a handler for getting task status.
|
||
|
|
|
||
|
|
WARNING: This API is experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def decorator(
|
||
|
|
func: Callable[[GetTaskRequest], Awaitable[GetTaskResult]],
|
||
|
|
) -> Callable[[GetTaskRequest], Awaitable[GetTaskResult]]:
|
||
|
|
logger.debug("Registering handler for GetTaskRequest")
|
||
|
|
wrapper = create_call_wrapper(func, GetTaskRequest)
|
||
|
|
|
||
|
|
async def handler(req: GetTaskRequest) -> ServerResult:
|
||
|
|
result = await wrapper(req)
|
||
|
|
return ServerResult(result)
|
||
|
|
|
||
|
|
self._request_handlers[GetTaskRequest] = handler
|
||
|
|
return func
|
||
|
|
|
||
|
|
return decorator
|
||
|
|
|
||
|
|
def get_task_result(
|
||
|
|
self,
|
||
|
|
) -> Callable[
|
||
|
|
[Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]],
|
||
|
|
Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
|
||
|
|
]:
|
||
|
|
"""Register a handler for getting task results/payload.
|
||
|
|
|
||
|
|
WARNING: This API is experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def decorator(
|
||
|
|
func: Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]],
|
||
|
|
) -> Callable[[GetTaskPayloadRequest], Awaitable[GetTaskPayloadResult]]:
|
||
|
|
logger.debug("Registering handler for GetTaskPayloadRequest")
|
||
|
|
wrapper = create_call_wrapper(func, GetTaskPayloadRequest)
|
||
|
|
|
||
|
|
async def handler(req: GetTaskPayloadRequest) -> ServerResult:
|
||
|
|
result = await wrapper(req)
|
||
|
|
return ServerResult(result)
|
||
|
|
|
||
|
|
self._request_handlers[GetTaskPayloadRequest] = handler
|
||
|
|
return func
|
||
|
|
|
||
|
|
return decorator
|
||
|
|
|
||
|
|
def cancel_task(
|
||
|
|
self,
|
||
|
|
) -> Callable[
|
||
|
|
[Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]],
|
||
|
|
Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
|
||
|
|
]:
|
||
|
|
"""Register a handler for cancelling tasks.
|
||
|
|
|
||
|
|
WARNING: This API is experimental and may change without notice.
|
||
|
|
"""
|
||
|
|
|
||
|
|
def decorator(
|
||
|
|
func: Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]],
|
||
|
|
) -> Callable[[CancelTaskRequest], Awaitable[CancelTaskResult]]:
|
||
|
|
logger.debug("Registering handler for CancelTaskRequest")
|
||
|
|
wrapper = create_call_wrapper(func, CancelTaskRequest)
|
||
|
|
|
||
|
|
async def handler(req: CancelTaskRequest) -> ServerResult:
|
||
|
|
result = await wrapper(req)
|
||
|
|
return ServerResult(result)
|
||
|
|
|
||
|
|
self._request_handlers[CancelTaskRequest] = handler
|
||
|
|
return func
|
||
|
|
|
||
|
|
return decorator
|