337 lines
12 KiB
Python
337 lines
12 KiB
Python
|
|
import logging
|
||
|
|
import re
|
||
|
|
from urllib.parse import urljoin, urlparse
|
||
|
|
|
||
|
|
from httpx import Request, Response
|
||
|
|
from pydantic import AnyUrl, ValidationError
|
||
|
|
|
||
|
|
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
|
||
|
|
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||
|
|
from mcp.shared.auth import (
|
||
|
|
OAuthClientInformationFull,
|
||
|
|
OAuthClientMetadata,
|
||
|
|
OAuthMetadata,
|
||
|
|
OAuthToken,
|
||
|
|
ProtectedResourceMetadata,
|
||
|
|
)
|
||
|
|
from mcp.types import LATEST_PROTOCOL_VERSION
|
||
|
|
|
||
|
|
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
    """
    Extract an auth-param value from the WWW-Authenticate header.

    The header is scanned left to right as a sequence of ``name=value`` pairs.
    A field name is therefore not mistaken for the requested field when it
    appears only inside another parameter's quoted value (for example
    ``error_description="need scope=admin"``) or as the tail of a longer name
    (``uri`` inside ``error_uri``). ``field_name`` is compared literally and
    case-insensitively, because auth-param names are case-insensitive
    (RFC 7235, section 2.1).

    Args:
        response: HTTP response whose WWW-Authenticate header is inspected
        field_name: auth-param name to look up, e.g. "scope"

    Returns:
        The first value found for ``field_name``. Quoted values have their
        quotes removed and backslash escapes resolved. Returns None when the
        header or the field is absent, or when the value is empty.
    """
    www_auth_header = response.headers.get("WWW-Authenticate")
    if not www_auth_header:
        return None

    # name = RFC 7230 token; value = quoted-string (with backslash escapes) or bare token.
    # The lookbehind only lets a name start at a token boundary, which keeps the scan
    # linear. Because finditer consumes each pair whole, text inside a quoted value is
    # never re-read as a parameter name.
    tchar = r"[\w!#$%&'*+.^`|~-]"
    param_pattern = re.compile(rf"(?<!{tchar})({tchar}+)\s*=\s*" + r'(?:"((?:[^"\\]|\\.)*)"|([^\s,]*))')
    wanted = field_name.lower()

    for match in param_pattern.finditer(www_auth_header):
        name, quoted_value, token_value = match.groups()
        if name.lower() != wanted:
            continue
        if quoted_value is not None:
            # Resolve quoted-pair escapes such as \" inside the quoted string.
            value = re.sub(r"\\(.)", r"\1", quoted_value)
        else:
            value = token_value
        # An empty value (field="" or field=) carries no information; treat it as absent.
        return value or None

    return None
|
||
|
|
|
||
|
|
|
||
|
|
def extract_scope_from_www_auth(response: Response) -> str | None:
    """
    Look up the ``scope`` auth-param of the WWW-Authenticate header (RFC 6750).

    Returns:
        The scope string when the header carries one, otherwise None
    """
    scope_value = extract_field_from_www_auth(response, field_name="scope")
    return scope_value
|
||
|
|
|
||
|
|
|
||
|
|
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
    """
    Read the RFC 9728 ``resource_metadata`` URL from the WWW-Authenticate header.

    Only 401 responses are inspected. A missing response or any other status
    yields None without looking at the headers.

    Returns:
        The protected resource metadata URL, or None if absent or not applicable
    """
    if response and response.status_code == 401:
        return extract_field_from_www_auth(response, "resource_metadata")
    return None  # pragma: no cover
|
||
|
|
|
||
|
|
|
||
|
|
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
|
||
|
|
"""
|
||
|
|
Build ordered list of URLs to try for protected resource metadata discovery.
|
||
|
|
|
||
|
|
Per SEP-985, the client MUST:
|
||
|
|
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||
|
|
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||
|
|
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||
|
|
|
||
|
|
Args:
|
||
|
|
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
|
||
|
|
server_url: server url
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
Ordered list of URLs to try for discovery
|
||
|
|
"""
|
||
|
|
urls: list[str] = []
|
||
|
|
|
||
|
|
# Priority 1: WWW-Authenticate header with resource_metadata parameter
|
||
|
|
if www_auth_url:
|
||
|
|
urls.append(www_auth_url)
|
||
|
|
|
||
|
|
# Priority 2-3: Well-known URIs (RFC 9728)
|
||
|
|
parsed = urlparse(server_url)
|
||
|
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||
|
|
|
||
|
|
# Priority 2: Path-based well-known URI (if server has a path component)
|
||
|
|
if parsed.path and parsed.path != "/":
|
||
|
|
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
|
||
|
|
urls.append(path_based_url)
|
||
|
|
|
||
|
|
# Priority 3: Root-based well-known URI
|
||
|
|
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||
|
|
urls.append(root_based_url)
|
||
|
|
|
||
|
|
return urls
|
||
|
|
|
||
|
|
|
||
|
|
def get_client_metadata_scopes(
    www_authenticate_scope: str | None,
    protected_resource_metadata: ProtectedResourceMetadata | None,
    authorization_server_metadata: OAuthMetadata | None = None,
) -> str | None:
    """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.

    Priority order:
        1. The scope from the WWW-Authenticate header, if provided (used verbatim).
        2. All scopes from the protected resource metadata ``scopes_supported``.
        3. All scopes from the authorization server metadata ``scopes_supported``.
        4. None, meaning the scope parameter should be omitted.

    An empty ``scopes_supported`` list is treated like a missing one. Selection
    then falls through to the next source instead of producing an empty scope
    string.

    Args:
        www_authenticate_scope: scope extracted from the WWW-Authenticate header
        protected_resource_metadata: discovered protected resource metadata, if any
        authorization_server_metadata: discovered authorization server metadata, if any

    Returns:
        Space-separated scope string, or None when no source supplies scopes
    """
    if www_authenticate_scope is not None:
        # Priority 1: WWW-Authenticate header scope
        return www_authenticate_scope
    if protected_resource_metadata is not None and protected_resource_metadata.scopes_supported:
        # Priority 2: PRM scopes_supported
        return " ".join(protected_resource_metadata.scopes_supported)
    if authorization_server_metadata is not None and authorization_server_metadata.scopes_supported:
        # Priority 3: authorization server scopes_supported
        return " ".join(authorization_server_metadata.scopes_supported)  # pragma: no cover
    # Priority 4: omit the scope parameter
    return None
|
||
|
|
|
||
|
|
|
||
|
|
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||
|
|
"""
|
||
|
|
Generate ordered list of (url, type) tuples for discovery attempts.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
|
||
|
|
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
|
||
|
|
"""
|
||
|
|
|
||
|
|
if not auth_server_url:
|
||
|
|
# Legacy path using the 2025-03-26 spec:
|
||
|
|
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
|
||
|
|
parsed = urlparse(server_url)
|
||
|
|
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
|
||
|
|
|
||
|
|
urls: list[str] = []
|
||
|
|
parsed = urlparse(auth_server_url)
|
||
|
|
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||
|
|
|
||
|
|
# RFC 8414: Path-aware OAuth discovery
|
||
|
|
if parsed.path and parsed.path != "/":
|
||
|
|
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
|
||
|
|
urls.append(urljoin(base_url, oauth_path))
|
||
|
|
|
||
|
|
# RFC 8414 section 5: Path-aware OIDC discovery
|
||
|
|
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
|
||
|
|
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
|
||
|
|
urls.append(urljoin(base_url, oidc_path))
|
||
|
|
|
||
|
|
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||
|
|
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
|
||
|
|
urls.append(urljoin(base_url, oidc_path))
|
||
|
|
return urls
|
||
|
|
|
||
|
|
# OAuth root
|
||
|
|
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
|
||
|
|
|
||
|
|
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
|
||
|
|
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||
|
|
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
|
||
|
|
|
||
|
|
return urls
|
||
|
|
|
||
|
|
|
||
|
|
async def handle_protected_resource_response(
    response: Response,
) -> ProtectedResourceMetadata | None:
    """
    Handle a protected resource metadata discovery response.

    Per SEP-985, discovery falls back across several URLs. Returning None tells
    the caller to try the next URL in that chain.

    Args:
        response: HTTP response from one protected resource metadata URL

    Returns:
        The parsed ProtectedResourceMetadata for a 200 response with a valid
        body, or None for any other status or for a body that fails validation
    """
    if response.status_code != 200:
        # Not found (or any other non-200 status) - try next URL in fallback chain
        return None

    content = await response.aread()
    try:
        return ProtectedResourceMetadata.model_validate_json(content)
    except ValidationError:  # pragma: no cover
        # Invalid metadata - try next URL
        return None
|
||
|
|
|
||
|
|
|
||
|
|
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
    """Interpret one authorization server metadata discovery response.

    Returns:
        A ``(flag, metadata)`` pair:
          * 200 with a valid body: ``(True, OAuthMetadata)``
          * 200 whose body fails validation: ``(True, None)``
          * any 4XX status: ``(True, None)``
          * any other status (below 400, or 500 and above): ``(False, None)``,
            signalling that discovery should stop
        NOTE(review): presumably the caller moves on to the next candidate URL
        when the flag is True and metadata is None; confirm against the caller.
    """
    status = response.status_code
    if status != 200:
        # Only a 4XX lets discovery continue; any other non-200 status ends it.
        return 400 <= status < 500, None

    try:
        body = await response.aread()
        metadata = OAuthMetadata.model_validate_json(body)
    except ValidationError:  # pragma: no cover
        return True, None
    return True, metadata
|
||
|
|
|
||
|
|
|
||
|
|
def create_oauth_metadata_request(url: str) -> Request:
    """Create an unsent GET request for the metadata document at ``url``.

    The request carries the MCP protocol version header, set to
    LATEST_PROTOCOL_VERSION.
    """
    protocol_header = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}
    return Request("GET", url, headers=protocol_header)
|
||
|
|
|
||
|
|
|
||
|
|
def create_client_registration_request(
    auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
) -> Request:
    """Build a client registration POST request.

    The request is always built. This function does not check whether the
    client is already registered.

    Args:
        auth_server_metadata: authorization server metadata; its
            ``registration_endpoint`` is used when present
        client_metadata: client metadata serialized as the JSON request body
        auth_base_url: base URL from which ``/register`` is derived when no
            registration endpoint is advertised

    Returns:
        An unsent ``POST`` request carrying the serialized client metadata
    """
    if auth_server_metadata and auth_server_metadata.registration_endpoint:
        registration_url = str(auth_server_metadata.registration_endpoint)
    else:
        # urljoin with an absolute path replaces any path already on auth_base_url.
        registration_url = urljoin(auth_base_url, "/register")

    # JSON-mode dump using field aliases, leaving out fields whose value is None.
    registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

    return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
|
||
|
|
|
||
|
|
|
||
|
|
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
    """Parse the authorization server's client registration response.

    Args:
        response: HTTP response from the registration endpoint

    Returns:
        The registered client information parsed from the response body

    Raises:
        OAuthRegistrationError: if the status is not 200/201, or if the body
            does not validate as OAuthClientInformationFull
    """
    if response.status_code not in (200, 201):
        # Load the body before formatting response.text into the error message.
        await response.aread()
        raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")

    content = await response.aread()
    try:
        return OAuthClientInformationFull.model_validate_json(content)
    except ValidationError as e:  # pragma: no cover
        raise OAuthRegistrationError(f"Invalid registration response: {e}") from e
|
||
|
|
|
||
|
|
|
||
|
|
def is_valid_client_metadata_url(url: str | None) -> bool:
|
||
|
|
"""Validate that a URL is suitable for use as a client_id (CIMD).
|
||
|
|
|
||
|
|
The URL must be HTTPS with a non-root pathname.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
url: The URL to validate
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
True if the URL is a valid HTTPS URL with a non-root pathname
|
||
|
|
"""
|
||
|
|
if not url:
|
||
|
|
return False
|
||
|
|
try:
|
||
|
|
parsed = urlparse(url)
|
||
|
|
return parsed.scheme == "https" and parsed.path not in ("", "/")
|
||
|
|
except Exception:
|
||
|
|
return False
|
||
|
|
|
||
|
|
|
||
|
|
def should_use_client_metadata_url(
    oauth_metadata: OAuthMetadata | None,
    client_metadata_url: str | None,
) -> bool:
    """Choose between a URL-based client ID (CIMD) and dynamic client registration (DCR).

    CIMD is chosen only when both of these hold:
      * the client has a client_metadata_url configured (expected to be validated already), and
      * the server metadata sets client_id_metadata_document_supported to exactly True.

    Args:
        oauth_metadata: OAuth authorization server metadata, or None if not discovered
        client_metadata_url: URL-based client ID, or None if not configured

    Returns:
        True to use CIMD, False to fall back to DCR
    """
    if not client_metadata_url or not oauth_metadata:
        return False
    # Identity check: only the literal True opts in, not merely truthy values.
    return oauth_metadata.client_id_metadata_document_supported is True
|
||
|
|
|
||
|
|
|
||
|
|
def create_client_info_from_metadata_url(
    client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None
) -> OAuthClientInformationFull:
    """Build client information for a URL-based client ID (CIMD).

    The metadata URL is used directly as ``client_id``. The token endpoint
    auth method is "none", so no client secret is involved.

    Args:
        client_metadata_url: URL that becomes the client_id
        redirect_uris: redirect URIs carried into the returned model, because
            OAuthClientInformationFull inherits these fields from
            OAuthClientMetadata

    Returns:
        OAuthClientInformationFull whose client_id is the metadata URL
    """
    url_client_fields = {
        "client_id": client_metadata_url,
        "redirect_uris": redirect_uris,
        "token_endpoint_auth_method": "none",
    }
    return OAuthClientInformationFull(**url_client_fields)
|
||
|
|
|
||
|
|
|
||
|
|
async def handle_token_response_scopes(
    response: Response,
) -> OAuthToken:
    """Parse and validate a token endpoint response body.

    Only the JSON body is validated against OAuthToken. Granted scopes are not
    compared with requested scopes here. Callers should check
    response.status_code before calling.

    Args:
        response: HTTP response from token endpoint (status already checked by caller)

    Returns:
        Validated OAuthToken model

    Raises:
        OAuthTokenError: If the response body does not validate as OAuthToken
    """
    content = await response.aread()
    try:
        return OAuthToken.model_validate_json(content)
    except ValidationError as e:  # pragma: no cover
        raise OAuthTokenError(f"Invalid token response: {e}") from e
|