# Source file: ai-station/.venv/lib/python3.12/site-packages/opentelemetry/instrumentation/mcp/instrumentation.py
# (file viewer listing: 637 lines, 25 KiB, Python)

from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Callable, Collection, Tuple, Union, cast
import json
import logging
from opentelemetry import context, propagate
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.trace import get_tracer, Tracer
from wrapt import ObjectProxy, register_post_import_hook, wrap_function_wrapper
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.semconv_ai import SpanAttributes, TraceloopSpanKindValues
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
from opentelemetry.instrumentation.mcp.version import __version__
from opentelemetry.instrumentation.mcp.utils import dont_throw, Config
from opentelemetry.instrumentation.mcp.fastmcp_instrumentation import (
FastMCPInstrumentor,
)
_instruments = ("mcp >= 1.6.0",)
class McpInstrumentor(BaseInstrumentor):
def __init__(self, exception_logger=None):
    """Set up the instrumentor and its FastMCP companion instrumentor.

    Args:
        exception_logger: Value stored on ``Config.exception_logger``.
            The default ``None`` overwrites any previously stored value.
    """
    super().__init__()
    # Stored on the imported Config object rather than on this instance.
    Config.exception_logger = exception_logger
    # Driven from _instrument() / _uninstrument().
    self._fastmcp_instrumentor = FastMCPInstrumentor()
def instrumentation_dependencies(self) -> Collection[str]:
    """Return the package requirement specifiers declared in ``_instruments``."""
    return _instruments
def _instrument(self, **kwargs):
    """Install the FastMCP and MCP wrappers.

    Most targets are wrapped through ``register_post_import_hook`` keyed on
    the name of the module that defines them; ``BaseSession.send_request``
    is wrapped immediately.

    Args:
        **kwargs: Only ``tracer_provider`` is read; it is passed to
            ``get_tracer`` (``None`` when absent).
    """
    tracer = get_tracer(__name__, __version__, kwargs.get("tracer_provider"))

    # Instrument FastMCP
    self._fastmcp_instrumentor.instrument(tracer)

    def wrap_on_import(module_name, target, make_wrapper):
        # make_wrapper(tracer) is evaluated inside the hook, so the wrapper
        # object is only built when the hook actually runs.
        register_post_import_hook(
            lambda _: wrap_function_wrapper(
                module_name, target, make_wrapper(tracer)
            ),
            module_name,
        )

    # (module, attribute path, wrapper factory) - registration order matters
    # only in that it mirrors the order hooks are registered with wrapt.
    deferred_targets = (
        # FastMCP Client: session-level span opened on enter, closed on exit
        ("fastmcp.client", "Client.__aenter__", self._fastmcp_client_enter_wrapper),
        ("fastmcp.client", "Client.__aexit__", self._fastmcp_client_exit_wrapper),
        # Transports: the streams they yield are replaced by instrumented proxies
        ("mcp.client.sse", "sse_client", self._transport_wrapper),
        ("mcp.server.sse", "SseServerTransport.connect_sse", self._transport_wrapper),
        ("mcp.client.stdio", "stdio_client", self._transport_wrapper),
        ("mcp.server.stdio", "stdio_server", self._transport_wrapper),
        # Server session: incoming message streams carry the context across
        ("mcp.server.session", "ServerSession.__init__", self._base_session_init_wrapper),
        ("mcp.client.streamable_http", "streamablehttp_client", self._transport_wrapper),
        (
            "mcp.server.streamable_http",
            "StreamableHTTPServerTransport.connect",
            self._transport_wrapper,
        ),
    )
    for module_name, target, make_wrapper in deferred_targets:
        wrap_on_import(module_name, target, make_wrapper)

    # Not deferred: wrapping happens right now.
    wrap_function_wrapper(
        "mcp.shared.session",
        "BaseSession.send_request",
        self.patch_mcp_client(tracer),
    )
def _uninstrument(self, **kwargs):
    """Remove every wrapper installed by ``_instrument``.

    Each target is resolved through ``sys.modules`` so that modules which
    were never imported (and therefore never wrapped) are skipped rather
    than imported just to be unwrapped.

    Post-import hooks registered by ``_instrument`` are not removed here,
    so a target module imported for the first time afterwards is still
    wrapped by its hook.

    Args:
        **kwargs: Accepted for interface compatibility; not used.
    """
    import sys  # local import keeps this method self-contained

    # Same (module, attribute path) pairs that _instrument wraps.
    wrapped_targets = (
        ("fastmcp.client", "Client.__aenter__"),
        ("fastmcp.client", "Client.__aexit__"),
        ("mcp.client.sse", "sse_client"),
        ("mcp.server.sse", "SseServerTransport.connect_sse"),
        ("mcp.client.stdio", "stdio_client"),
        ("mcp.server.stdio", "stdio_server"),
        ("mcp.server.session", "ServerSession.__init__"),
        ("mcp.client.streamable_http", "streamablehttp_client"),
        ("mcp.server.streamable_http", "StreamableHTTPServerTransport.connect"),
        ("mcp.shared.session", "BaseSession.send_request"),
    )
    for module_name, dotted_name in wrapped_targets:
        owner = sys.modules.get(module_name)
        *parents, attribute = dotted_name.split(".")
        # Walk from the module to the object that owns the wrapped attribute.
        for parent in parents:
            owner = getattr(owner, parent, None)
        if owner is not None:
            # Passing the owning object (not a dotted string) avoids any
            # import attempt inside unwrap().
            unwrap(owner, attribute)

    self._fastmcp_instrumentor.uninstrument()
def _transport_wrapper(self, tracer):
@asynccontextmanager
async def traced_method(
wrapped: Callable[..., Any], instance: Any, args: Any, kwargs: Any
) -> AsyncGenerator[
Union[
Tuple[InstrumentedStreamReader, InstrumentedStreamWriter],
Tuple[InstrumentedStreamReader, InstrumentedStreamWriter, Any],
],
None,
]:
async with wrapped(*args, **kwargs) as result:
try:
read_stream, write_stream = result
yield InstrumentedStreamReader(
read_stream, tracer
), InstrumentedStreamWriter(write_stream, tracer)
except ValueError:
try:
read_stream, write_stream, get_session_id_callback = result
yield InstrumentedStreamReader(
read_stream, tracer
), InstrumentedStreamWriter(
write_stream, tracer
), get_session_id_callback
except Exception as e:
logging.warning(
f"mcp instrumentation _transport_wrapper exception: {e}"
)
yield result
except Exception as e:
logging.warning(
f"mcp instrumentation transport_wrapper exception: {e}"
)
yield result
return traced_method
def _base_session_init_wrapper(self, tracer):
def traced_method(
wrapped: Callable[..., None], instance: Any, args: Any, kwargs: Any
) -> None:
wrapped(*args, **kwargs)
reader = getattr(instance, "_incoming_message_stream_reader", None)
writer = getattr(instance, "_incoming_message_stream_writer", None)
if reader and writer:
setattr(
instance,
"_incoming_message_stream_reader",
ContextAttachingStreamReader(reader, tracer),
)
setattr(
instance,
"_incoming_message_stream_writer",
ContextSavingStreamWriter(writer, tracer),
)
return traced_method
def patch_mcp_client(self, tracer: Tracer):
    """Build the wrapper installed on ``BaseSession.send_request``.

    The wrapper reads ``method`` and ``params`` from the request's ``root``,
    injects the current W3C ``traceparent`` into ``params.meta`` when a meta
    object is present, and then dispatches to ``_handle_tool_call`` for
    ``"tools/call"`` or to ``_handle_mcp_method`` for everything else.

    Attribute lookups use ``getattr`` with defaults and the ``traceparent``
    is only copied when the propagator actually produced one, so a request
    of an unexpected shape or a missing active span no longer raises before
    the wrapped call is made.

    Args:
        tracer: Tracer used for the spans created by the handlers.

    Returns:
        An async function with the wrapt wrapper signature.
    """

    @dont_throw
    async def traced_method(wrapped, instance, args, kwargs):
        # NOTE(review): the keyword fallback assumes the parameter is named
        # "request", as in mcp's BaseSession.send_request - confirm.
        request = args[0] if args else kwargs.get("request")
        root = getattr(request, "root", None)
        method = getattr(root, "method", None)
        params = getattr(root, "params", None)
        meta = getattr(params, "meta", None) if params else None

        # Handle trace context propagation
        if meta:
            carrier = {}
            TraceContextTextMapPropagator().inject(carrier)
            # The carrier stays empty when nothing was injected.
            traceparent = carrier.get("traceparent")
            if traceparent is not None:
                meta.traceparent = traceparent
                params.meta = meta

        # Create different span types based on method
        if method == "tools/call":
            return await self._handle_tool_call(
                tracer, method, params, args, kwargs, wrapped
            )
        return await self._handle_mcp_method(tracer, method, args, kwargs, wrapped)

    return traced_method
def _fastmcp_client_enter_wrapper(self, tracer):
    """Wrapper for FastMCP Client.__aenter__ to start a session trace.

    A ``mcp.client.session`` span is started and made current, and its
    context manager is stored on the client instance as
    ``_tracing_session_context_manager`` so the ``__aexit__`` wrapper can
    close it.

    If the wrapped ``__aenter__`` fails, the span is closed here and the
    stored reference is cleared: Python does not call ``__aexit__`` when
    ``__aenter__`` raises, so nothing else would end the span.

    Args:
        tracer: Tracer used to start the session span.

    Returns:
        An async function with the wrapt wrapper signature.
    """

    @dont_throw
    async def traced_method(wrapped, instance, args, kwargs):
        # Start a root span for the MCP client session and make it current
        span_context_manager = tracer.start_as_current_span("mcp.client.session")
        span = span_context_manager.__enter__()
        try:
            span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, "session")
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_NAME, "mcp.client.session"
            )
            # Store the span context manager on the instance to properly exit it later
            setattr(instance, "_tracing_session_context_manager", span_context_manager)
            return await wrapped(*args, **kwargs)
        except BaseException as e:
            # Only ordinary errors are recorded as failures; BaseException
            # is caught solely so the span is closed on cancellation too.
            if isinstance(e, Exception):
                span.set_attribute(ERROR_TYPE, type(e).__name__)
                span.record_exception(e)
                span.set_status(Status(StatusCode.ERROR, str(e)))
            if (
                getattr(instance, "_tracing_session_context_manager", None)
                is span_context_manager
            ):
                setattr(instance, "_tracing_session_context_manager", None)
            # The error was already recorded above, so exit without exception
            # info to avoid recording it a second time.
            span_context_manager.__exit__(None, None, None)
            raise

    return traced_method
def _fastmcp_client_exit_wrapper(self, tracer):
    """Wrapper for FastMCP Client.__aexit__ to end the session trace.

    The wrapped ``__aexit__`` runs first; afterwards the span context manager
    stored on the instance as ``_tracing_session_context_manager`` is exited,
    with the exception details when the wrapped call raised.

    The stored reference is cleared before exiting so the same span context
    manager is never exited twice, and the span is also closed when the
    wrapped call raises a ``BaseException`` such as a cancellation.

    Args:
        tracer: Unused; kept so all wrapper factories share one signature.

    Returns:
        An async function with the wrapt wrapper signature.
    """

    def end_session_span(instance, exc_type, exc_value, exc_traceback):
        # Close the session span at most once per stored context manager.
        context_manager = getattr(
            instance, "_tracing_session_context_manager", None
        )
        if not context_manager:
            return
        setattr(instance, "_tracing_session_context_manager", None)
        context_manager.__exit__(exc_type, exc_value, exc_traceback)

    @dont_throw
    async def traced_method(wrapped, instance, args, kwargs):
        try:
            # Call the original method first
            result = await wrapped(*args, **kwargs)
        except BaseException as e:
            # End the session span context manager with exception info
            end_session_span(instance, type(e), e, e.__traceback__)
            raise
        # End the session span context manager
        end_session_span(instance, None, None, None)
        return result

    return traced_method
async def _handle_tool_call(self, tracer, method, params, args, kwargs, wrapped):
    """Run a ``tools/call`` request inside a tool span.

    The span is named ``"<tool name>.tool"``; when no tool name can be read
    from ``params`` the MCP method name is used instead. The span receives
    the Traceloop span kind, entity name and - when there is any - the
    cleaned input before the wrapped call is executed.

    Args:
        tracer: Tracer used to start the span.
        method: MCP method name of the request.
        params: Request params object, possibly falsy.
        args: Positional arguments for ``wrapped``.
        kwargs: Keyword arguments for ``wrapped``.
        wrapped: The original ``send_request`` callable.

    Returns:
        Whatever ``_execute_and_handle_result`` returns.
    """
    # Fall back to the method name until a tool name is found.
    entity_name, span_name = method, f"{method}.tool"
    if params:
        try:
            if hasattr(params, "name"):
                entity_name = params.name
                span_name = f"{entity_name}.tool"
            elif hasattr(params, "__dict__") and "name" in params.__dict__:
                entity_name = params.__dict__["name"]
                span_name = f"{entity_name}.tool"
        except Exception:
            # Name lookup is best effort; keep whatever was resolved so far.
            pass

    with tracer.start_as_current_span(span_name) as span:
        # Set tool-specific attributes
        span.set_attribute(
            SpanAttributes.TRACELOOP_SPAN_KIND, TraceloopSpanKindValues.TOOL.value
        )
        span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name)

        # Add input, preferring JSON and falling back to str()
        clean_input = self._extract_clean_input(method, params)
        if clean_input:
            try:
                encoded_input = json.dumps(clean_input)
            except (TypeError, ValueError):
                encoded_input = str(clean_input)
            span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_INPUT, encoded_input)

        return await self._execute_and_handle_result(
            span, method, args, kwargs, wrapped, clean_output=True
        )
async def _handle_mcp_method(self, tracer, method, args, kwargs, wrapped):
    """Handle non-tool MCP methods with simple serialization.

    A ``"<method>.mcp"`` span is started, the serialized request is stored as
    the entity input, and the wrapped call is executed through
    ``_execute_and_handle_result``.

    The request is taken from ``args[0]`` or, when no positional argument was
    given, from the ``request`` keyword, so a keyword-only call no longer
    fails with ``IndexError`` before the wrapped call is made.

    Args:
        tracer: Tracer used to start the span.
        method: MCP method name of the request (may be ``None``).
        args: Positional arguments for ``wrapped``.
        kwargs: Keyword arguments for ``wrapped``.
        wrapped: The original ``send_request`` callable.

    Returns:
        Whatever ``_execute_and_handle_result`` returns.
    """
    # NOTE(review): the keyword name "request" is assumed to match
    # mcp's BaseSession.send_request - confirm.
    request = args[0] if args else kwargs.get("request")
    with tracer.start_as_current_span(f"{method}.mcp") as span:
        if request is not None:
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_INPUT, f"{serialize(request)}"
            )
        return await self._execute_and_handle_result(
            span, method, args, kwargs, wrapped, clean_output=False
        )
async def _execute_and_handle_result(
    self, span, method, args, kwargs, wrapped, clean_output=False
):
    """Execute the wrapped function and record its outcome on ``span``.

    Only the wrapped call itself is covered by the error-recording ``try``:
    an exception it raises is recorded on the span and re-raised. Recording
    the output and status afterwards is best effort - a failure there is
    logged at debug level and the result is still returned, so telemetry
    problems cannot turn a successful call into an error.

    Args:
        span: Span to annotate.
        method: MCP method name, forwarded to ``_extract_clean_output``.
        args: Positional arguments for ``wrapped``.
        kwargs: Keyword arguments for ``wrapped``.
        wrapped: Awaitable callable to execute.
        clean_output: When true the output attribute is built with
            ``_extract_clean_output``; otherwise with ``serialize``.

    Returns:
        The value returned by ``wrapped``.

    Raises:
        Exception: Whatever ``wrapped`` raises, unchanged.
    """
    try:
        result = await wrapped(*args, **kwargs)
    except Exception as e:
        span.set_attribute(ERROR_TYPE, type(e).__name__)
        span.record_exception(e)
        span.set_status(Status(StatusCode.ERROR, str(e)))
        raise

    try:
        # Add output
        if clean_output:
            clean_output_data = self._extract_clean_output(method, result)
            if clean_output_data:
                try:
                    output_text = json.dumps(clean_output_data)
                except (TypeError, ValueError):
                    output_text = str(clean_output_data)
                span.set_attribute(
                    SpanAttributes.TRACELOOP_ENTITY_OUTPUT, output_text
                )
        else:
            span.set_attribute(
                SpanAttributes.TRACELOOP_ENTITY_OUTPUT, serialize(result)
            )

        # Handle errors reported inside an otherwise successful response
        if getattr(result, "isError", False):
            content = getattr(result, "content", None) or []
            # The first content item may not carry text at all.
            first_text = getattr(content[0], "text", None) if len(content) > 0 else None
            description = None if first_text is None else f"{first_text}"
            # An error result is flagged even when it has no content.
            span.set_status(Status(StatusCode.ERROR, description))
        else:
            span.set_status(Status(StatusCode.OK))
    except Exception as e:
        logging.debug("mcp instrumentation failed to record result: %s", e)

    return result
def _extract_clean_input(self, method: str, params: Any) -> dict:
"""Extract clean input parameters for different MCP method types"""
if not params:
return {}
try:
if method == "tools/call":
# For tool calls, extract name and arguments
result = {}
if hasattr(params, "name"):
result["tool_name"] = params.name
if hasattr(params, "arguments"):
if hasattr(params.arguments, "__dict__"):
result["arguments"] = params.arguments.__dict__
else:
result["arguments"] = params.arguments
elif hasattr(params, "__dict__") and "arguments" in params.__dict__:
result["arguments"] = params.__dict__["arguments"]
return result
elif method == "tools/list":
# For list_tools, there are usually no parameters
return {}
else:
# For other methods, try to serialize params cleanly
if hasattr(params, "__dict__"):
# Remove internal fields starting with _ and non-serializable objects
clean_params = {}
for k, v in params.__dict__.items():
if not k.startswith("_"):
try:
# Test if the value is JSON serializable
json.dumps(v)
clean_params[k] = v
except (TypeError, ValueError):
# If not serializable, store a string representation
clean_params[k] = str(type(v).__name__)
return clean_params
else:
return {"params": str(params)}
except Exception:
return {}
def _extract_clean_output(self, method: str, result: Any) -> dict:
"""Extract clean output for different MCP method types"""
if not result:
return {}
try:
if method == "tools/call":
# For tool calls, extract the actual result content
output = {}
if hasattr(result, "content") and result.content:
if len(result.content) > 0:
content_item = result.content[0]
if hasattr(content_item, "text"):
output["result"] = content_item.text
elif hasattr(content_item, "__dict__"):
output["result"] = content_item.__dict__
else:
output["result"] = str(content_item)
# Check if this is an error response
if hasattr(result, "isError") and result.isError:
output["is_error"] = True
return output
elif method == "tools/list":
# For list_tools, extract tool names and descriptions
output = {"tools": []}
if hasattr(result, "tools") and result.tools:
for tool in result.tools:
tool_info = {}
if hasattr(tool, "name"):
tool_info["name"] = tool.name
if hasattr(tool, "description"):
tool_info["description"] = tool.description
output["tools"].append(tool_info)
return output
else:
# For other methods, try to serialize result cleanly
if hasattr(result, "__dict__"):
clean_result = {
k: v
for k, v in result.__dict__.items()
if not k.startswith("_")
}
return clean_result
else:
return {"result": str(result)}
except Exception:
return {}
def serialize(request, depth=0, max_depth=4):
    """Serialize an MCP request/response object into a JSON string.

    Values that ``json.dumps`` accepts are encoded directly. Objects exposing
    ``model_dump_json`` are encoded with it. Any other object is reduced to
    the public (non-underscore) entries of its ``__dict__``, descending at
    most ``max_depth`` levels so large object graphs stay small; whatever
    lies deeper is replaced by ``{}``.

    Nested values are embedded as JSON structures rather than as JSON text
    inside a string, ``max_depth`` applies to every level of the recursion,
    and the return value is always a string.

    Args:
        request: Object to serialize.
        depth: Nesting level ``request`` is considered to be at.
        max_depth: Deepest level that is still expanded.

    Returns:
        A JSON document as ``str``; ``"{}"`` when ``depth`` already exceeds
        ``max_depth`` or nothing could be extracted.
    """
    primitive_types = (bool, str, int, float, type(None))

    def is_serializable(value):
        try:
            json.dumps(value)
            return True
        except Exception:
            return False

    def public_fields(obj, child_depth):
        # Collect obj's public attributes; children sit at child_depth.
        fields = {}
        try:
            for name, value in getattr(obj, "__dict__", {}).items():
                if name.startswith("_"):
                    continue
                if type(value) in primitive_types:
                    # Primitives are kept regardless of depth.
                    fields[str(name)] = value
                else:
                    fields[str(name)] = to_jsonable(value, child_depth)
        except Exception:
            # Best effort by design: keep whatever was collected so far
            # instead of failing the instrumented call.
            pass
        return fields

    def to_jsonable(value, level):
        # Convert a nested value into something json.dumps accepts.
        if level > max_depth:
            return {}
        if is_serializable(value):
            return value
        try:
            if hasattr(value, "model_dump_json"):
                # Parse back so the parent embeds a structure, not a string.
                return json.loads(value.model_dump_json())
        except Exception:
            return {}
        return public_fields(value, level + 1)

    if depth > max_depth:
        return "{}"
    if is_serializable(request):
        return json.dumps(request)
    try:
        if hasattr(request, "model_dump_json"):
            return request.model_dump_json()
    except Exception:
        return "{}"
    return json.dumps(public_fields(request, depth + 1))
class InstrumentedStreamReader(ObjectProxy):  # type: ignore
    """Proxy for a transport read stream that restores propagated trace context.

    While a received JSON-RPC request carrying a truthy ``_meta`` entry in
    its params is handed to the consumer, the context extracted from that
    entry is attached; it is detached again once the consumer asks for the
    next item. Every other item is yielded untouched.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # NOTE(review): wrapt documents a `_self_` prefix for attributes that
        # should live on the proxy itself - confirm where `_tracer` ends up.
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        inner = self.__wrapped__
        return await inner.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        inner = self.__wrapped__
        return await inner.__aexit__(exc_type, exc_value, traceback)

    @dont_throw
    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        from mcp.types import JSONRPCMessage, JSONRPCRequest

        async for item in self.__wrapped__:
            # Locate the JSON-RPC payload: either nested under `.message`
            # or directly on the item.
            envelope = getattr(item, "message", None)
            if hasattr(envelope, "root"):
                request = envelope.root
            elif type(item) is JSONRPCMessage or hasattr(item, "root"):
                request = item.root
            else:
                yield item
                continue

            # Only requests with params can carry propagation metadata.
            propagation_meta = None
            if isinstance(request, JSONRPCRequest) and request.params:
                propagation_meta = request.params.get("_meta")
            if not propagation_meta:
                yield item
                continue

            token = context.attach(propagate.extract(propagation_meta))
            try:
                yield item
            finally:
                context.detach(token)
class InstrumentedStreamWriter(ObjectProxy):  # type: ignore
    """Proxy for a transport write stream that traces outgoing JSON-RPC messages.

    Every message with a recognizable JSON-RPC payload is sent inside a
    ``ResponseStreamWriter`` span. Responses get their serialized result and
    error status recorded; requests additionally get the current trace
    context injected into ``params["_meta"]``. Items without a payload are
    forwarded untouched.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    @dont_throw
    async def send(self, item: Any) -> Any:
        """Annotate a span for ``item`` and forward it to the wrapped stream.

        The error text of a failed response is looked up defensively: a
        result that is not a dict, has no content, or whose first content
        item has no ``text`` no longer raises before the message is sent.
        """
        from mcp.types import JSONRPCMessage, JSONRPCRequest

        # Handle different item types based on what's available
        if hasattr(item, "message") and hasattr(item.message, "root"):
            request = item.message.root
        elif type(item) is JSONRPCMessage or hasattr(item, "root"):
            request = item.root
        else:
            return await self.__wrapped__.send(item)

        with self._tracer.start_as_current_span("ResponseStreamWriter") as span:
            if hasattr(request, "result"):
                result = request.result
                span.set_attribute(
                    SpanAttributes.MCP_RESPONSE_VALUE, f"{serialize(result)}"
                )
                if isinstance(result, dict) and result.get("isError") is True:
                    content = result.get("content")
                    first = (
                        content[0]
                        if isinstance(content, (list, tuple)) and content
                        else None
                    )
                    text = first.get("text") if isinstance(first, dict) else None
                    span.set_status(
                        Status(
                            StatusCode.ERROR,
                            None if text is None else f"{text}",
                        )
                    )
            if hasattr(request, "id"):
                span.set_attribute(SpanAttributes.MCP_REQUEST_ID, f"{request.id}")

            # Only requests carry the propagation metadata.
            if isinstance(request, JSONRPCRequest):
                if not request.params:
                    request.params = {}
                meta = request.params.setdefault("_meta", {})
                propagate.get_global_textmap().inject(meta)
            return await self.__wrapped__.send(item)
@dataclass(frozen=True, slots=True)
class ItemWithContext:
    """Immutable pairing of a stream item with an OpenTelemetry context."""

    # Payload originally passed to ContextSavingStreamWriter.send().
    item: Any
    # Context that was current when the item was sent; re-attached by
    # ContextAttachingStreamReader while the item is consumed.
    ctx: context.Context
class ContextSavingStreamWriter(ObjectProxy):  # type: ignore
    """Proxy for a write stream that ships each item together with the current context.

    ``send`` wraps the item in an ``ItemWithContext`` holding the context
    that is current at call time and forwards that to the wrapped stream.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        # NOTE(review): `_tracer` is not read anywhere in this class.
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        inner = self.__wrapped__
        return await inner.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        inner = self.__wrapped__
        return await inner.__aexit__(exc_type, exc_value, traceback)

    @dont_throw
    async def send(self, item: Any) -> Any:
        # No span is created here - only the context is captured.
        envelope = ItemWithContext(item, context.get_current())
        return await self.__wrapped__.send(envelope)
class ContextAttachingStreamReader(ObjectProxy):  # type: ignore
    """Proxy for a read stream that unwraps ``ItemWithContext`` envelopes.

    For each envelope the stored context is attached while the inner item is
    handed to the consumer and detached afterwards. Items that are not
    ``ItemWithContext`` instances are yielded unchanged instead of failing
    with ``AttributeError``.
    """

    # ObjectProxy missing context manager - https://github.com/GrahamDumpleton/wrapt/issues/73
    def __init__(self, wrapped, tracer):
        super().__init__(wrapped)
        self._tracer = tracer

    async def __aenter__(self) -> Any:
        return await self.__wrapped__.__aenter__()

    async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> Any:
        return await self.__wrapped__.__aexit__(exc_type, exc_value, traceback)

    async def __aiter__(self) -> AsyncGenerator[Any, None]:
        async for item in self.__wrapped__:
            if not isinstance(item, ItemWithContext):
                # Not an envelope: there is no context to restore for it.
                yield item
                continue
            restore = context.attach(item.ctx)
            try:
                yield item.item
            finally:
                context.detach(restore)