276 lines
10 KiB
Python
276 lines
10 KiB
Python
|
|
"""
|
||
|
|
LiteLLM Proxy uses this MCP Client to connect to other MCP servers.
|
||
|
|
"""
|
||
|
|
import asyncio
|
||
|
|
import base64
|
||
|
|
from datetime import timedelta
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
from mcp import ClientSession, StdioServerParameters
|
||
|
|
from mcp.client.sse import sse_client
|
||
|
|
from mcp.client.stdio import stdio_client
|
||
|
|
from mcp.client.streamable_http import streamablehttp_client
|
||
|
|
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||
|
|
from mcp.types import CallToolResult as MCPCallToolResult
|
||
|
|
from mcp.types import TextContent
|
||
|
|
from mcp.types import Tool as MCPTool
|
||
|
|
|
||
|
|
from litellm._logging import verbose_logger
|
||
|
|
from litellm.types.mcp import (
|
||
|
|
MCPAuth,
|
||
|
|
MCPAuthType,
|
||
|
|
MCPSpecVersion,
|
||
|
|
MCPSpecVersionType,
|
||
|
|
MCPStdioConfig,
|
||
|
|
MCPTransport,
|
||
|
|
MCPTransportType,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def to_basic_auth(auth_value: str) -> str:
    """
    Base64-encode a credential string for use in a Basic Auth header.

    The input is encoded as UTF-8 before base64 encoding. The returned string
    is only the encoded credential; it does not include the "Basic " prefix.
    """
    raw_credentials = auth_value.encode("utf-8")
    encoded_credentials = base64.b64encode(raw_credentials)
    return encoded_credentials.decode()
|
||
|
|
|
||
|
|
|
||
|
|
class MCPClient:
    """
    MCP Client supporting:
    - stdio, SSE and streamable HTTP transports
    - Authentication via Bearer token, Basic Auth, or API Key
    - Tool calling with error handling and result parsing

    The client can be used as an async context manager, or driven explicitly
    through ``connect()`` / ``disconnect()``. ``list_tools()`` and
    ``call_tool()`` connect lazily when no session exists yet.
    """

    def __init__(
        self,
        server_url: str = "",
        transport_type: MCPTransportType = MCPTransport.http,
        auth_type: MCPAuthType = None,
        auth_value: Optional[str] = None,
        timeout: float = 60.0,
        stdio_config: Optional[MCPStdioConfig] = None,
        protocol_version: MCPSpecVersionType = MCPSpecVersion.jun_2025,
    ):
        """
        Store the connection settings; no connection is opened here.

        Args:
            server_url: URL of the MCP server (used by the SSE and HTTP transports).
            transport_type: Which transport ``connect()`` should use.
            auth_type: Selects which auth header ``_get_auth_headers()`` emits.
            auth_value: Raw credential; base64-encoded on storage for Basic auth.
            timeout: Timeout in seconds passed to the SSE / HTTP transport.
            stdio_config: ``command`` / ``args`` / ``env`` for the stdio transport.
            protocol_version: Value sent in the ``MCP-Protocol-Version`` header.
        """
        self.server_url: str = server_url
        self.transport_type: MCPTransport = transport_type
        self.auth_type: MCPAuthType = auth_type
        self.timeout: float = timeout
        self._mcp_auth_value: Optional[str] = None
        self._session: Optional[ClientSession] = None
        self._context = None
        self._transport_ctx = None
        self._transport = None
        self._session_ctx = None
        self._task: Optional[asyncio.Task] = None
        self.stdio_config: Optional[MCPStdioConfig] = stdio_config
        self.protocol_version: MCPSpecVersionType = protocol_version

        # Must run after self.auth_type is set: update_auth_value() consults it
        # to decide whether the value needs Basic-auth encoding.
        if auth_value:
            self.update_auth_value(auth_value)

    async def __aenter__(self):
        """
        Enable async context manager support.
        Initializes the transport and session.
        """
        try:
            await self.connect()
            return self
        except Exception:
            await self.disconnect()
            raise

    async def _open_session(self, transport_ctx) -> None:
        """Enter *transport_ctx*, wrap its streams in a ClientSession and initialize it."""
        # Each handle is stored on self before the next await so that
        # disconnect() can unwind a partially completed setup.
        self._transport_ctx = transport_ctx
        self._transport = await self._transport_ctx.__aenter__()
        self._session_ctx = ClientSession(self._transport[0], self._transport[1])
        self._session = await self._session_ctx.__aenter__()
        await self._session.initialize()

    async def connect(self):
        """
        Initialize the transport and session.

        Does nothing if a session already exists.

        Raises:
            ValueError: re-raised after cleanup (e.g. stdio transport selected
                without ``stdio_config``).
            asyncio.CancelledError: re-raised after cleanup.

        Any other exception is logged and swallowed after cleanup, leaving
        ``self._session`` as ``None``; callers must check for that.
        """
        if self._session:
            return  # Already connected

        try:
            if self.transport_type == MCPTransport.stdio:
                # For stdio transport, use stdio_client with command-line parameters
                if not self.stdio_config:
                    raise ValueError("stdio_config is required for stdio transport")

                server_params = StdioServerParameters(
                    command=self.stdio_config.get("command", ""),
                    args=self.stdio_config.get("args", []),
                    env=self.stdio_config.get("env", {}),
                )
                await self._open_session(stdio_client(server_params))
            elif self.transport_type == MCPTransport.sse:
                headers = self._get_auth_headers()
                await self._open_session(
                    sse_client(
                        url=self.server_url,
                        timeout=self.timeout,
                        headers=headers,
                    )
                )
            else:  # http
                headers = self._get_auth_headers()
                await self._open_session(
                    streamablehttp_client(
                        url=self.server_url,
                        timeout=timedelta(seconds=self.timeout),
                        headers=headers,
                    )
                )
        except asyncio.CancelledError:
            # Mirror list_tools()/call_tool(): release whatever was opened
            # before propagating the cancellation.
            await self.disconnect()
            raise
        except ValueError as e:
            # Re-raise ValueError exceptions (like missing stdio_config)
            verbose_logger.warning(f"MCP client connection failed: {str(e)}")
            await self.disconnect()
            raise
        except Exception as e:
            verbose_logger.warning(f"MCP client connection failed: {str(e)}")
            await self.disconnect()
            # Deliberately not re-raised: this allows the server manager to
            # continue with other servers. Callers detect the failure by
            # finding self._session still unset.

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Cleanup when exiting context manager."""
        await self.disconnect()

    async def disconnect(self):
        """
        Clean up session and connections.

        Cleanup is best-effort: errors raised while exiting the session or
        transport contexts are logged at debug level and suppressed.
        """
        if self._task and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

        if self._session:
            try:
                await self._session_ctx.__aexit__(None, None, None)  # type: ignore
            except Exception as e:
                verbose_logger.debug(f"MCP client session cleanup failed: {str(e)}")
        # Reset unconditionally so a session context whose entry failed
        # does not linger as stale state.
        self._session = None
        self._session_ctx = None

        if self._transport_ctx:
            try:
                await self._transport_ctx.__aexit__(None, None, None)
            except Exception as e:
                verbose_logger.debug(f"MCP client transport cleanup failed: {str(e)}")
            self._transport_ctx = None
            self._transport = None

        if self._context:
            try:
                await self._context.__aexit__(None, None, None)  # type: ignore
            except Exception as e:
                verbose_logger.debug(f"MCP client context cleanup failed: {str(e)}")
            self._context = None

    def update_auth_value(self, mcp_auth_value: str):
        """
        Set the authentication value used to build the auth header.

        For Basic auth the value is base64-encoded before it is stored.
        """
        if self.auth_type == MCPAuth.basic:
            # Assuming mcp_auth_value is in format "username:password", convert it when updating
            mcp_auth_value = to_basic_auth(mcp_auth_value)
        self._mcp_auth_value = mcp_auth_value

    def _get_auth_headers(self) -> dict:
        """
        Generate request headers based on auth type.

        Always includes ``MCP-Protocol-Version``. An auth header is added only
        when an auth value is stored and the auth type is one handled below.
        """
        headers = {}

        if self._mcp_auth_value:
            if self.auth_type == MCPAuth.bearer_token:
                headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
            elif self.auth_type == MCPAuth.basic:
                headers["Authorization"] = f"Basic {self._mcp_auth_value}"
            elif self.auth_type == MCPAuth.api_key:
                headers["X-API-Key"] = self._mcp_auth_value

        # Handle protocol version - it might be a string or enum
        if hasattr(self.protocol_version, "value"):
            # It's an enum
            protocol_version_str = self.protocol_version.value
        else:
            # It's a string
            protocol_version_str = str(self.protocol_version)

        headers["MCP-Protocol-Version"] = protocol_version_str
        return headers

    async def list_tools(self) -> List[MCPTool]:
        """
        List available tools from the server.

        Connects first if there is no session. Returns an empty list when the
        connection or the request fails; only cancellation is propagated.
        """
        if not self._session:
            try:
                await self.connect()
            except Exception as e:
                verbose_logger.warning(f"MCP client connection failed: {str(e)}")
                return []

        # connect() swallows most failures, so the session may still be unset.
        if self._session is None:
            verbose_logger.warning("MCP client session is not initialized")
            return []

        try:
            result = await self._session.list_tools()
            return result.tools
        except asyncio.CancelledError:
            await self.disconnect()
            raise
        except Exception as e:
            verbose_logger.warning(f"MCP client list_tools failed: {str(e)}")
            await self.disconnect()
            # Return empty list instead of raising to allow graceful degradation
            return []

    async def call_tool(
        self, call_tool_request_params: MCPCallToolRequestParams
    ) -> MCPCallToolResult:
        """
        Call an MCP Tool.

        Connects first if there is no session. Failures are reported as a
        result with ``isError=True`` whose text content is the error message;
        only cancellation is propagated.
        """
        if not self._session:
            try:
                await self.connect()
            except Exception as e:
                verbose_logger.warning(f"MCP client connection failed: {str(e)}")
                return MCPCallToolResult(
                    content=[TextContent(type="text", text=f"{str(e)}")],
                    isError=True,
                )

        # connect() swallows most failures, so the session may still be unset.
        if self._session is None:
            verbose_logger.warning("MCP client session is not initialized")
            return MCPCallToolResult(
                content=[
                    TextContent(type="text", text="MCP client session is not initialized")
                ],
                isError=True,
            )

        try:
            tool_result = await self._session.call_tool(
                name=call_tool_request_params.name,
                arguments=call_tool_request_params.arguments,
            )
            return tool_result
        except asyncio.CancelledError:
            await self.disconnect()
            raise
        except Exception as e:
            verbose_logger.warning(f"MCP client call_tool failed: {str(e)}")
            await self.disconnect()
            # Return an error result carrying the exception text instead of raising
            return MCPCallToolResult(
                content=[TextContent(type="text", text=f"{str(e)}")],
                isError=True,
            )
|
||
|
|
|
||
|
|
|