ai-station/.venv/lib/python3.12/site-packages/opentelemetry/instrumentation/mcp/fastmcp_instrumentation.py

177 lines
7.8 KiB
Python

"""FastMCP-specific instrumentation logic."""
import json
import logging
import os

from opentelemetry.trace import Tracer
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.semconv_ai import SpanAttributes, TraceloopSpanKindValues
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
from wrapt import register_post_import_hook, wrap_function_wrapper

from .utils import dont_throw
class FastMCPInstrumentor:
    """Handles FastMCP-specific instrumentation logic.

    Wraps ``FastMCP.__init__`` to remember the server name and
    ``ToolManager.call_tool`` to trace each tool invocation as an
    ``mcp.server`` span with a nested ``<tool>.tool`` span.

    NOTE: ``_server_name`` is one value per instrumentor and is overwritten
    every time a ``FastMCP`` instance is constructed. With several servers in
    one process, every tool span is labelled with the most recently created
    server's name.
    """

    def __init__(self):
        self._tracer = None  # set by instrument(); tool calls are untraced until then
        self._server_name = None  # "<name>.mcp" of the last FastMCP constructed

    def instrument(self, tracer: "Tracer"):
        """Apply FastMCP-specific instrumentation.

        The wrappers are installed through wrapt post-import hooks, which fire
        when the target ``fastmcp`` modules are imported.

        Args:
            tracer: Tracer used to create the ``mcp.server`` and tool spans.
        """
        self._tracer = tracer
        # Instrument FastMCP server-side tool execution
        register_post_import_hook(
            lambda _: wrap_function_wrapper(
                "fastmcp.tools.tool_manager", "ToolManager.call_tool", self._fastmcp_tool_wrapper()
            ),
            "fastmcp.tools.tool_manager",
        )
        # Instrument FastMCP __init__ to capture server name
        register_post_import_hook(
            lambda _: wrap_function_wrapper(
                "fastmcp", "FastMCP.__init__", self._fastmcp_init_wrapper()
            ),
            "fastmcp",
        )

    def uninstrument(self):
        """Remove FastMCP-specific instrumentation.

        Currently a no-op: wrapt doesn't provide a clean way to undo wrappers
        installed from post-import hooks, so they stay in place. This is a
        known limitation.
        """

    def _fastmcp_init_wrapper(self):
        """Create the wrapt wrapper for ``FastMCP.__init__`` that records the server name.

        Returns:
            A ``(wrapped, instance, args, kwargs)`` wrapper function.
        """

        @dont_throw
        def capture_server_name(args, kwargs):
            # The server name is the first positional argument or ``name=``.
            name = args[0] if args else kwargs.get("name")
            # An explicit None would otherwise become the bogus name "None.mcp".
            if name is not None:
                self._server_name = f"{name}.mcp"

        def traced_method(wrapped, instance, args, kwargs):
            # The real constructor runs outside ``dont_throw``: that guard is for
            # instrumentation failures only. It must not hide errors raised by
            # FastMCP itself.
            result = wrapped(*args, **kwargs)
            capture_server_name(args, kwargs)
            return result

        return traced_method

    @staticmethod
    def _extract_tool_call(args, kwargs):
        """Return ``(tool_key, tool_arguments)`` for a ``ToolManager.call_tool`` call.

        FastMCP has different call patterns:
        keyword ``key=``/``arguments=``, or positional ``(key, arguments)``.
        The positional pattern may also pass ``arguments=`` by keyword.
        Returns ``(None, {})`` when no key can be found.
        """
        kwargs = kwargs or {}
        # Pattern 1: kwargs with 'key' parameter
        if "key" in kwargs:
            return kwargs.get("key"), kwargs.get("arguments", {})
        # Pattern 2: positional key, arguments positional or by keyword
        if args:
            arguments = args[1] if len(args) > 1 else kwargs.get("arguments", {})
            return args[0], arguments
        return None, {}

    @staticmethod
    def _serialize_tool_result(result):
        """Convert a FastMCP tool result into a JSON-serializable list.

        Text items become ``{"type": "text", "content": ...}``. Other objects
        are represented by their ``__dict__``, and anything else by ``str()``.
        """
        # Note: result.content for fastmcp 2.12.2+, fallback to result for older versions
        items = result.content if hasattr(result, "content") else result
        output = []
        for item in items:
            if hasattr(item, "text"):
                output.append({"type": "text", "content": item.text})
            elif hasattr(item, "__dict__"):
                output.append(item.__dict__)
            else:
                output.append(str(item))
        return output

    def _record_tool_input(self, tool_span, entity_name, tool_arguments):
        """Best-effort: attach the tool name and arguments as the span's entity input."""
        try:
            json_input = json.dumps(
                {"tool_name": entity_name, "arguments": tool_arguments},
                cls=self._get_json_encoder(),
            )
        except (TypeError, ValueError):
            return  # Skip input logging if serialization fails
        tool_span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT, self._truncate_json_if_needed(json_input)
        )

    def _record_tool_output(self, tool_span, mcp_span, result):
        """Best-effort: attach the serialized tool result to both spans."""
        try:
            json_output = json.dumps(
                self._serialize_tool_result(result), cls=self._get_json_encoder()
            )
        except (TypeError, ValueError):
            return  # Skip output logging if serialization fails
        truncated_output = self._truncate_json_if_needed(json_output)
        tool_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_OUTPUT, truncated_output)
        # Also add response to MCP span
        mcp_span.set_attribute(SpanAttributes.MCP_RESPONSE_VALUE, truncated_output)

    def _fastmcp_tool_wrapper(self):
        """Create the async wrapt wrapper for ``ToolManager.call_tool``.

        Each call is traced as a parent ``mcp.server`` span with a nested
        ``<tool name>.tool`` span. Exceptions raised by the tool are recorded on
        both spans and re-raised unchanged.
        """

        async def traced_method(wrapped, instance, args, kwargs):
            if not self._tracer:
                return await wrapped(*args, **kwargs)

            tool_key, tool_arguments = self._extract_tool_call(args, kwargs)
            entity_name = tool_key or "unknown_tool"

            # Create parent server.mcp span
            with self._tracer.start_as_current_span("mcp.server") as mcp_span:
                mcp_span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, "server")
                mcp_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, "mcp.server")
                if self._server_name:
                    mcp_span.set_attribute(
                        SpanAttributes.TRACELOOP_WORKFLOW_NAME, self._server_name
                    )

                # Create nested tool span
                with self._tracer.start_as_current_span(f"{entity_name}.tool") as tool_span:
                    tool_span.set_attribute(
                        SpanAttributes.TRACELOOP_SPAN_KIND, TraceloopSpanKindValues.TOOL.value
                    )
                    tool_span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name)
                    if self._server_name:
                        tool_span.set_attribute(
                            SpanAttributes.TRACELOOP_WORKFLOW_NAME, self._server_name
                        )
                    if self._should_send_prompts():
                        self._record_tool_input(tool_span, entity_name, tool_arguments)

                    try:
                        result = await wrapped(*args, **kwargs)
                    except Exception as e:
                        # Mark the tool span first, then its parent server span.
                        for span in (tool_span, mcp_span):
                            span.set_attribute(ERROR_TYPE, type(e).__name__)
                            span.record_exception(e)
                            span.set_status(Status(StatusCode.ERROR, str(e)))
                        raise

                    # Output capture is best-effort: an instrumentation failure
                    # must neither fail the tool call nor leave the spans without
                    # a status.
                    try:
                        if self._should_send_prompts() and result:
                            self._record_tool_output(tool_span, mcp_span, result)
                    except Exception:
                        logging.getLogger(__name__).debug(
                            "Failed to record FastMCP tool output", exc_info=True
                        )
                    tool_span.set_status(Status(StatusCode.OK))
                    mcp_span.set_status(Status(StatusCode.OK))
            return result

        return traced_method

    def _should_send_prompts(self):
        """Check if content tracing is enabled (matches traceloop SDK).

        Enabled unless ``TRACELOOP_TRACE_CONTENT`` is set to something other
        than ``"true"`` (case-insensitive); unset or empty means enabled.
        """
        return (
            os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
        ).lower() == "true"

    def _get_json_encoder(self):
        """Get JSON encoder class (simplified - traceloop SDK uses custom JSONEncoder).

        Returns ``None`` so ``json.dumps`` uses its default encoder.
        """
        return None  # Use default JSON encoder

    def _truncate_json_if_needed(self, json_str: str) -> str:
        """Truncate JSON if it exceeds OTEL limits (matches traceloop SDK).

        Reads ``OTEL_SPAN_ATTRIBUTE_VALUE_LENGTH_LIMIT``. A missing,
        non-integer, or non-positive limit leaves the string unchanged. The
        truncated result may no longer be valid JSON.
        """
        limit_str = os.getenv("OTEL_SPAN_ATTRIBUTE_VALUE_LENGTH_LIMIT")
        if limit_str:
            try:
                limit = int(limit_str)
                if limit > 0 and len(json_str) > limit:
                    return json_str[:limit]
            except ValueError:
                pass
        return json_str