ai-station/.venv/lib/python3.12/site-packages/mcp/shared/progress.py

59 lines
1.7 KiB
Python
Raw Normal View History

2025-12-25 14:54:33 +00:00
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic
from pydantic import BaseModel
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.session import (
BaseSession,
ReceiveNotificationT,
ReceiveRequestT,
SendNotificationT,
SendRequestT,
SendResultT,
)
from mcp.types import ProgressToken
class Progress(BaseModel):
progress: float
total: float | None
@dataclass
class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
progress_token: ProgressToken
total: float | None
current: float = field(default=0.0, init=False)
async def progress(self, amount: float, message: str | None = None) -> None:
self.current += amount
await self.session.send_progress_notification(
self.progress_token, self.current, total=self.total, message=message
)
@contextmanager
def progress(
ctx: RequestContext[
BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
LifespanContextT,
],
total: float | None = None,
) -> Generator[
ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
None,
]:
if ctx.meta is None or ctx.meta.progressToken is None: # pragma: no cover
raise ValueError("No progress token provided")
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
try:
yield progress_ctx
finally:
pass