637 lines
25 KiB
Python
637 lines
25 KiB
Python
|
|
from contextlib import asynccontextmanager
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, Union, cast
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
|
||
|
|
from opentelemetry import context, propagate
|
||
|
|
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
|
||
|
|
from opentelemetry.instrumentation.utils import unwrap
|
||
|
|
from opentelemetry.trace import get_tracer, Tracer
|
||
|
|
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
|
||
|
|
from opentelemetry.trace.status import Status, StatusCode
|
||
|
|
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
||
|
|
from opentelemetry.semconv_ai import SpanAttributes, TraceloopSpanKindValues
|
||
|
|
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
|
||
|
|
|
||
|
|
from opentelemetry.instrumentation.mcp.version import __version__
|
||
|
|
from opentelemetry.instrumentation.mcp.utils import dont_throw, Config
|
||
|
|
from opentelemetry.instrumentation.mcp.fastmcp_instrumentation import (
|
||
|
|
FastMCPInstrumentor,
|
||
|
|
)
|
||
|
|
|
||
|
|
_instruments = ("mcp >= 1.6.0",)
|
||
|
|
|
||
|
|
|
||
|
|
class McpInstrumentor(BaseInstrumentor):
    """OpenTelemetry instrumentor for the ``mcp`` SDK and the FastMCP client.

    ``_instrument`` patches MCP client/server transports, ``ServerSession``
    construction, ``BaseSession.send_request`` and the FastMCP ``Client``
    async context manager; ``_uninstrument`` removes those same patches.

    Work that only produces telemetry runs in helpers decorated with
    ``dont_throw`` (presumably log-and-suppress; see utils.dont_throw), while
    the wrapped MCP calls themselves always run and their exceptions always
    propagate to the caller unchanged.
    """

    # Functions patched lazily through wrapt post-import hooks, as
    # (module name, dotted attribute path, wrapper-factory method name).
    # _instrument and _uninstrument both walk this table, so every patch that
    # is applied is also removed.
    _IMPORT_HOOK_PATCHES = (
        ("fastmcp.client", "Client.__aenter__", "_fastmcp_client_enter_wrapper"),
        ("fastmcp.client", "Client.__aexit__", "_fastmcp_client_exit_wrapper"),
        ("mcp.client.sse", "sse_client", "_transport_wrapper"),
        ("mcp.server.sse", "SseServerTransport.connect_sse", "_transport_wrapper"),
        ("mcp.client.stdio", "stdio_client", "_transport_wrapper"),
        ("mcp.server.stdio", "stdio_server", "_transport_wrapper"),
        (
            "mcp.server.session",
            "ServerSession.__init__",
            "_base_session_init_wrapper",
        ),
        (
            "mcp.client.streamable_http",
            "streamablehttp_client",
            "_transport_wrapper",
        ),
        (
            "mcp.server.streamable_http",
            "StreamableHTTPServerTransport.connect",
            "_transport_wrapper",
        ),
    )
    # Patched eagerly in _instrument (the module is imported right away).
    _SEND_REQUEST_PATCH = ("mcp.shared.session", "BaseSession.send_request")

    def __init__(self, exception_logger=None):
        super().__init__()
        Config.exception_logger = exception_logger
        self._fastmcp_instrumentor = FastMCPInstrumentor()

    def instrumentation_dependencies(self) -> Collection[str]:
        return _instruments

    def _instrument(self, **kwargs):
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, __version__, tracer_provider)

        # Instrument FastMCP
        self._fastmcp_instrumentor.instrument(tracer)

        for module_name, attr_path, factory_name in self._IMPORT_HOOK_PATCHES:
            wrapper = getattr(self, factory_name)(tracer)
            register_post_import_hook(
                self._wrapping_hook(module_name, attr_path, wrapper), module_name
            )

        module_name, attr_path = self._SEND_REQUEST_PATCH
        wrap_function_wrapper(module_name, attr_path, self.patch_mcp_client(tracer))

    def _uninstrument(self, **kwargs):
        # Remove every patch applied by _instrument, not just a subset.
        for module_name, attr_path, _factory_name in self._IMPORT_HOOK_PATCHES:
            self._unwrap_if_imported(module_name, attr_path)
        self._unwrap_if_imported(*self._SEND_REQUEST_PATCH)
        self._fastmcp_instrumentor.uninstrument()

    @staticmethod
    def _wrapping_hook(module_name, attr_path, wrapper):
        """Return a post-import hook that patches ``module_name.attr_path``.

        A factory (rather than a lambda in a loop) so each hook binds its own
        target and wrapper.
        """

        def hook(_module):
            wrap_function_wrapper(module_name, attr_path, wrapper)

        return hook

    @staticmethod
    def _unwrap_if_imported(module_name, attr_path):
        """Undo a wrapt patch on ``module_name.attr_path`` if the module is loaded.

        Modules that are not in ``sys.modules`` are skipped instead of being
        imported only to be unwrapped.
        """
        import sys

        target = sys.modules.get(module_name)
        if target is None:
            return
        *owner_path, attr = attr_path.split(".")
        for part in owner_path:
            target = getattr(target, part, None)
            if target is None:
                return
        unwrap(target, attr)

    def _transport_wrapper(self, tracer):
        """Return a wrapper that instruments the streams a transport yields.

        Transports yield ``(read_stream, write_stream)`` or
        ``(read_stream, write_stream, extra)`` (e.g. a session-id callback);
        the two streams are wrapped and any third element passes through.
        Any other shape is yielded unchanged with a warning.
        """

        def instrument_streams(result):
            try:
                read_stream, write_stream, *extra = result
                if len(extra) > 1:
                    raise ValueError(
                        f"unexpected transport result of length {len(extra) + 2}"
                    )
                return (
                    InstrumentedStreamReader(read_stream, tracer),
                    InstrumentedStreamWriter(write_stream, tracer),
                    *extra,
                )
            except Exception as e:
                logging.warning(f"mcp instrumentation transport_wrapper exception: {e}")
                return result

        @asynccontextmanager
        async def traced_method(
            wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
        ) -> AsyncGenerator[
            Union[
                Tuple[InstrumentedStreamReader, InstrumentedStreamWriter],
                Tuple[InstrumentedStreamReader, InstrumentedStreamWriter, Any],
            ],
            None,
        ]:
            async with wrapped(*args, **kwargs) as result:
                # Wrap *before* yielding and yield exactly once: an exception
                # from the caller's ``async with`` body is thrown in at this
                # yield and must propagate unchanged. Catching it here and
                # yielding again makes asynccontextmanager raise RuntimeError.
                yield instrument_streams(result)

        return traced_method

    def _base_session_init_wrapper(self, tracer):
        """Return a wrapper for ``ServerSession.__init__``.

        After the original constructor runs, the session's incoming-message
        reader/writer are replaced by proxies that save the OpenTelemetry
        context on send and re-attach it on receive.
        """

        def traced_method(
            wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any
        ) -> None:
            wrapped(*args, **kwargs)
            reader = getattr(instance, "_incoming_message_stream_reader", None)
            writer = getattr(instance, "_incoming_message_stream_writer", None)
            if reader and writer:
                setattr(
                    instance,
                    "_incoming_message_stream_reader",
                    ContextAttachingStreamReader(reader, tracer),
                )
                setattr(
                    instance,
                    "_incoming_message_stream_writer",
                    ContextSavingStreamWriter(writer, tracer),
                )

        return traced_method

    def patch_mcp_client(self, tracer: Tracer):
        """Return a wrapper for ``BaseSession.send_request``.

        ``tools/call`` requests get a TOOL span; every other method gets a
        ``<method>.mcp`` span. If request inspection fails, the request is
        still sent, just without a span. Errors from the real request always
        propagate to the caller.
        """

        async def traced_method(wrapped, instance, args, kwargs):
            prepared = self._prepare_request(self._request_from_call(args, kwargs))
            if not prepared:
                # Instrumentation failed; never let that stop the request.
                return await wrapped(*args, **kwargs)
            method, params = prepared

            # Create different span types based on method
            if method == "tools/call":
                return await self._handle_tool_call(
                    tracer, method, params, args, kwargs, wrapped
                )
            return await self._handle_mcp_method(tracer, method, args, kwargs, wrapped)

        return traced_method

    @staticmethod
    def _request_from_call(args, kwargs):
        """Return the ``request`` argument of ``send_request``.

        It may be passed positionally or by keyword.
        """
        if args:
            return args[0]
        return kwargs.get("request")

    @dont_throw
    def _prepare_request(self, request):
        """Read ``method``/``params`` from the request and propagate trace context.

        If the request already carries ``params.meta``, the current W3C
        ``traceparent`` is written onto it (only when one is available).

        Returns:
            ``(method, params)``; either element may be ``None``.
        """
        root = getattr(request, "root", None)
        method = getattr(root, "method", None)
        params = getattr(root, "params", None)
        meta = getattr(params, "meta", None) if params else None

        # Handle trace context propagation
        if meta:
            carrier = {}
            TraceContextTextMapPropagator().inject(carrier)
            # inject() writes nothing when no valid span context is active.
            traceparent = carrier.get("traceparent")
            if traceparent:
                meta.traceparent = traceparent
        return method, params

    def _fastmcp_client_enter_wrapper(self, tracer):
        """Wrapper for FastMCP Client.__aenter__ to start a session trace"""

        async def traced_method(wrapped, instance, args, kwargs):
            session = self._start_session_span(tracer, instance)
            try:
                # Call the original method
                return await wrapped(*args, **kwargs)
            except BaseException as e:
                # __aexit__ is not called when __aenter__ fails, so the span
                # must be ended (and its context detached) here.
                if session:
                    self._fail_session_span(instance, *session, e)
                raise

        return traced_method

    @dont_throw
    def _start_session_span(self, tracer, instance):
        """Start the ``mcp.client.session`` span and make it current.

        The span's context manager is stored on ``instance`` so that the
        ``__aexit__`` wrapper can end it.

        Returns:
            ``(span_context_manager, span)``.
        """
        span_context_manager = tracer.start_as_current_span("mcp.client.session")
        span = span_context_manager.__enter__()
        span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, "session")
        span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, "mcp.client.session")
        setattr(instance, "_tracing_session_context_manager", span_context_manager)
        return span_context_manager, span

    @dont_throw
    def _fail_session_span(self, instance, span_context_manager, span, error):
        """Record ``error`` on the session span and end it."""
        span.set_attribute(ERROR_TYPE, type(error).__name__)
        span.record_exception(error)
        span.set_status(Status(StatusCode.ERROR, str(error)))
        if (
            getattr(instance, "_tracing_session_context_manager", None)
            is span_context_manager
        ):
            setattr(instance, "_tracing_session_context_manager", None)
        # The error is already recorded above; exit cleanly to avoid recording it twice.
        span_context_manager.__exit__(None, None, None)

    def _fastmcp_client_exit_wrapper(self, tracer):
        """Wrapper for FastMCP Client.__aexit__ to end the session trace"""

        async def traced_method(wrapped, instance, args, kwargs):
            try:
                # Call the original method first
                result = await wrapped(*args, **kwargs)
            except BaseException as e:
                self._end_session_span(instance, e)
                raise
            self._end_session_span(instance, None)
            return result

        return traced_method

    @dont_throw
    def _end_session_span(self, instance, error):
        """End the session span stored on ``instance``, if any.

        If ``error`` is not ``None`` it is passed to the span's context
        manager as exception info.
        """
        context_manager = getattr(instance, "_tracing_session_context_manager", None)
        if not context_manager:
            return
        # Clear first so the span is never exited twice.
        setattr(instance, "_tracing_session_context_manager", None)
        if error is None:
            context_manager.__exit__(None, None, None)
        else:
            context_manager.__exit__(type(error), error, error.__traceback__)

    async def _handle_tool_call(self, tracer, method, params, args, kwargs, wrapped):
        """Handle tools/call with tool semantics"""
        # Extract the actual tool name
        entity_name = method
        span_name = f"{method}.tool"
        if params:
            try:
                if hasattr(params, "name"):
                    entity_name = params.name
                    span_name = f"{params.name}.tool"
                elif hasattr(params, "__dict__") and "name" in params.__dict__:
                    entity_name = params.__dict__["name"]
                    span_name = f"{params.__dict__['name']}.tool"
            except Exception:
                pass  # best effort: keep the generic method-based names

        with tracer.start_as_current_span(span_name) as span:
            # Set tool-specific attributes
            span.set_attribute(
                SpanAttributes.TRACELOOP_SPAN_KIND, TraceloopSpanKindValues.TOOL.value
            )
            span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name)

            # Add input
            clean_input = self._extract_clean_input(method, params)
            if clean_input:
                try:
                    span.set_attribute(
                        SpanAttributes.TRACELOOP_ENTITY_INPUT, json.dumps(clean_input)
                    )
                except (TypeError, ValueError):
                    span.set_attribute(
                        SpanAttributes.TRACELOOP_ENTITY_INPUT, str(clean_input)
                    )

            return await self._execute_and_handle_result(
                span, method, args, kwargs, wrapped, clean_output=True
            )

    async def _handle_mcp_method(self, tracer, method, args, kwargs, wrapped):
        """Handle non-tool MCP methods with simple serialization"""
        with tracer.start_as_current_span(f"{method}.mcp") as span:
            self._record_mcp_input(span, self._request_from_call(args, kwargs))
            return await self._execute_and_handle_result(
                span, method, args, kwargs, wrapped, clean_output=False
            )

    @dont_throw
    def _record_mcp_input(self, span, request):
        """Record the serialized request as the span's entity input."""
        span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_INPUT, f"{serialize(request)}")

    async def _execute_and_handle_result(
        self, span, method, args, kwargs, wrapped, clean_output=False
    ):
        """Execute the wrapped function and handle the result.

        Exceptions from the wrapped call are recorded on ``span`` and re-raised.
        Failures while *recording* the result never affect the returned value.
        """
        try:
            result = await wrapped(*args, **kwargs)
        except Exception as e:
            span.set_attribute(ERROR_TYPE, type(e).__name__)
            span.record_exception(e)
            span.set_status(Status(StatusCode.ERROR, str(e)))
            raise
        self._record_result(span, method, result, clean_output)
        return result

    @dont_throw
    def _record_result(self, span, method, result, clean_output):
        """Set output attributes and span status from a completed request."""
        # Add output
        if clean_output:
            clean_output_data = self._extract_clean_output(method, result)
            if clean_output_data:
                try:
                    span.set_attribute(
                        SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
                        json.dumps(clean_output_data),
                    )
                except (TypeError, ValueError):
                    span.set_attribute(
                        SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
                        str(clean_output_data),
                    )
        else:
            span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_OUTPUT, serialize(result))

        # Handle errors
        if getattr(result, "isError", False):
            content = getattr(result, "content", None) or []
            # Not every content item has ``.text`` (e.g. non-text content).
            text = getattr(content[0], "text", None) if content else None
            if text is None:
                span.set_status(Status(StatusCode.ERROR))
            else:
                span.set_status(Status(StatusCode.ERROR, f"{text}"))
        else:
            span.set_status(Status(StatusCode.OK))

    def _extract_clean_input(self, method: str, params: Any) -> dict:
        """Extract clean input parameters for different MCP method types.

        Returns ``{}`` for empty params or on any extraction error.
        """
        if not params:
            return {}

        try:
            if method == "tools/call":
                # For tool calls, extract name and arguments
                result = {}
                if hasattr(params, "name"):
                    result["tool_name"] = params.name
                if hasattr(params, "arguments"):
                    arguments = params.arguments
                    if hasattr(arguments, "__dict__"):
                        result["arguments"] = arguments.__dict__
                    else:
                        result["arguments"] = arguments
                elif hasattr(params, "__dict__") and "arguments" in params.__dict__:
                    result["arguments"] = params.__dict__["arguments"]
                return result
            if method == "tools/list":
                # For list_tools, there are usually no parameters
                return {}
            # For other methods, try to serialize params cleanly
            if hasattr(params, "__dict__"):
                # Skip private fields; replace non-JSON values with their type name.
                clean_params = {}
                for key, value in params.__dict__.items():
                    if key.startswith("_"):
                        continue
                    try:
                        json.dumps(value)
                        clean_params[key] = value
                    except (TypeError, ValueError):
                        clean_params[key] = str(type(value).__name__)
                return clean_params
            return {"params": str(params)}
        except Exception:
            return {}

    def _extract_clean_output(self, method: str, result: Any) -> dict:
        """Extract clean output for different MCP method types.

        Returns ``{}`` for an empty result or on any extraction error.
        """
        if not result:
            return {}

        try:
            if method == "tools/call":
                # For tool calls, extract the first content item as the result
                output = {}
                content = getattr(result, "content", None)
                if content:
                    content_item = content[0]
                    if hasattr(content_item, "text"):
                        output["result"] = content_item.text
                    elif hasattr(content_item, "__dict__"):
                        output["result"] = content_item.__dict__
                    else:
                        output["result"] = str(content_item)

                # Check if this is an error response
                if hasattr(result, "isError") and result.isError:
                    output["is_error"] = True
                return output
            if method == "tools/list":
                # For list_tools, extract tool names and descriptions
                output = {"tools": []}
                if hasattr(result, "tools") and result.tools:
                    for tool in result.tools:
                        tool_info = {}
                        if hasattr(tool, "name"):
                            tool_info["name"] = tool.name
                        if hasattr(tool, "description"):
                            tool_info["description"] = tool.description
                        output["tools"].append(tool_info)
                return output
            # For other methods, keep the public attributes of the result
            if hasattr(result, "__dict__"):
                return {
                    k: v for k, v in result.__dict__.items() if not k.startswith("_")
                }
            return {"result": str(result)}
        except Exception:
            return {}
|
||
|
|
|
||
|
|
|
||
|
|
def serialize(request, depth=0, max_depth=4):
    """Serialize input args to MCP server into JSON.

    JSON-native values are dumped directly, objects exposing
    ``model_dump_json()`` use it, and other objects are converted attribute by
    attribute (public attributes only). Non-primitive attributes are
    serialized by a recursive call, so they appear in the output as JSON
    *strings*. The recursion depth is bounded to keep the output small.

    Args:
        request: Object to serialize.
        depth: Current recursion depth (used internally).
        max_depth: Maximum recursion depth. It is forwarded to nested calls,
            so a custom limit applies at every level.

    Returns:
        A JSON string, or an empty dict ``{}`` once ``depth > max_depth``.
        On a conversion error, the attributes gathered so far are dumped.
    """
    if depth > max_depth:
        return {}
    depth += 1

    try:
        return json.dumps(request)
    except Exception:
        pass  # not JSON-native; fall back to attribute-wise conversion

    result = {}
    try:
        if hasattr(request, "model_dump_json"):
            return request.model_dump_json()
        if hasattr(request, "__dict__"):
            for attrib, value in request.__dict__.items():
                if attrib.startswith("_"):
                    continue
                if type(value) in (bool, str, int, float, type(None)):
                    result[str(attrib)] = value
                else:
                    result[str(attrib)] = serialize(value, depth, max_depth)
    except Exception:
        pass  # best effort: dump whatever was collected
    return json.dumps(result)
|
||
|
|
|
||
|
|
|
||
|
|
class InstrumentedStreamReader(ObjectProxy):  # type: ignore
    """Read-side proxy for an MCP transport stream.

    Iteration yields exactly the items of the wrapped stream. When an item is
    a JSON-RPC request whose ``params`` hold a truthy ``_meta`` mapping, the
    trace context extracted from that mapping is attached while the consumer
    handles the item. It is detached when the consumer asks for the next item.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        inner = self.__wrapped__
        return await inner.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        inner = self.__wrapped__
        return await inner.__aexit__(exc_type, exc_value, traceback)

    @dont_throw
    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        from mcp.types import JSONRPCMessage, JSONRPCRequest

        async for incoming in self.__wrapped__:
            # Locate the JSON-RPC root; unknown shapes pass straight through.
            if hasattr(incoming, "message") and hasattr(incoming.message, "root"):
                rpc = incoming.message.root
            elif type(incoming) is JSONRPCMessage:
                rpc = cast(JSONRPCMessage, incoming).root
            elif hasattr(incoming, "root"):
                rpc = incoming.root
            else:
                yield incoming
                continue

            carrier = None
            if isinstance(rpc, JSONRPCRequest) and rpc.params:
                carrier = rpc.params.get("_meta")

            if not carrier:
                yield incoming
                continue

            token = context.attach(propagate.extract(carrier))
            try:
                yield incoming
            finally:
                context.detach(token)
|
||
|
|
|
||
|
|
|
||
|
|
class InstrumentedStreamWriter(ObjectProxy):  # type: ignore
    """Write-side proxy for an MCP transport stream.

    Each JSON-RPC message is sent inside a ``ResponseStreamWriter`` span that
    records the response value, error status and request id. Outgoing
    requests also get the current trace context injected into
    ``params["_meta"]``. The item is always forwarded to the wrapped stream,
    even if recording telemetry fails. Errors from the wrapped stream's
    ``send`` propagate to the caller.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # ``_self_``-prefixed attributes are stored on the proxy itself; other
        # assignments on an ObjectProxy are forwarded to the wrapped stream.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    async def send(self, item: Any) -> Any:
        from mcp.types import JSONRPCMessage

        # Handle different item types based on what's available
        if hasattr(item, "message") and hasattr(item.message, "root"):
            request = item.message.root
        elif type(item) is JSONRPCMessage:
            request = cast(JSONRPCMessage, item).root
        elif hasattr(item, "root"):
            request = item.root
        else:
            return await self.__wrapped__.send(item)

        with self._self_tracer.start_as_current_span("ResponseStreamWriter") as span:
            self._record_message(span, request)
            self._inject_trace_context(request)
            # Outside any dont_throw helper: a failure above can no longer
            # prevent the message from being sent.
            return await self.__wrapped__.send(item)

    @staticmethod
    @dont_throw
    def _record_message(span, request):
        """Record request id, response value and error status on ``span``."""
        if hasattr(request, "id"):
            span.set_attribute(SpanAttributes.MCP_REQUEST_ID, f"{request.id}")
        if not hasattr(request, "result"):
            return
        result = request.result
        span.set_attribute(SpanAttributes.MCP_RESPONSE_VALUE, f"{serialize(result)}")
        if "isError" in result and result["isError"] is True:
            # Tolerate an empty or text-less content list instead of raising.
            content = result.get("content") or []
            first = content[0] if content else None
            text = first.get("text") if isinstance(first, dict) else None
            if text is None:
                span.set_status(Status(StatusCode.ERROR))
            else:
                span.set_status(Status(StatusCode.ERROR, f"{text}"))

    @staticmethod
    @dont_throw
    def _inject_trace_context(request):
        """Inject the current trace context into a request's ``params["_meta"]``."""
        from mcp.types import JSONRPCRequest

        if not isinstance(request, JSONRPCRequest):
            return
        if not request.params:
            request.params = {}
        meta = request.params.setdefault("_meta", {})
        propagate.get_global_textmap().inject(meta)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(frozen=True, slots=True)
class ItemWithContext:
    """Immutable pair of a stream item and the OpenTelemetry context that was
    current when the item was sent."""

    # The object originally passed to ``send``.
    item: Any
    # The context captured at send time.
    ctx: context.Context
|
||
|
|
|
||
|
|
|
||
|
|
class ContextSavingStreamWriter(ObjectProxy):  # type: ignore
    """Proxy for a session's incoming-message writer.

    Each item is wrapped in ``ItemWithContext`` together with the
    OpenTelemetry context current at send time.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # ``_self_`` prefix keeps the attribute on the proxy, not the wrapped stream.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    async def send(self, item: Any) -> Any:
        # Nothing here is instrumentation that could fail, so errors raised by
        # the wrapped stream (e.g. a closed stream) reach the caller exactly
        # as they would without instrumentation instead of being swallowed.
        ctx = context.get_current()
        return await self.__wrapped__.send(ItemWithContext(item, ctx))
|
||
|
|
|
||
|
|
|
||
|
|
class ContextAttachingStreamReader(ObjectProxy):  # type: ignore
    """Proxy for a session's incoming-message reader.

    For each ``ItemWithContext`` received, its saved context is attached while
    the consumer handles the unwrapped item, then detached. Any other item is
    yielded unchanged.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # ``_self_`` prefix keeps the attribute on the proxy, not the wrapped stream.
        self._self_tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        async for item in self.__wrapped__:
            if not isinstance(item, ItemWithContext):
                # Sent without going through ContextSavingStreamWriter: pass
                # it through rather than failing the whole stream on ``.ctx``.
                yield item
                continue
            restore = context.attach(item.ctx)
            try:
                yield item.item
            finally:
                context.detach(restore)
|