feat: implement RFC-compliant OAuth discovery with dynamic scope selection for MCP providers (#28294)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Novice
2025-11-20 11:18:16 +08:00
committed by GitHub
parent 014cbaf387
commit 6be013e072
14 changed files with 442 additions and 141 deletions

View File

@@ -23,11 +23,13 @@ from core.mcp.auth.auth_flow import (
)
from core.mcp.entities import AuthActionType, AuthResult
from core.mcp.types import (
LATEST_PROTOCOL_VERSION,
OAuthClientInformation,
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthTokens,
ProtectedResourceMetadata,
)
@@ -154,7 +156,7 @@ class TestOAuthDiscovery:
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
@@ -183,59 +185,61 @@ class TestOAuthDiscovery:
assert auth_url == "https://auth.example.com"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
)
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
def test_discover_oauth_metadata_with_resource_discovery(self):
"""Test OAuth metadata discovery with resource discovery support."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (True, "https://auth.example.com")
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
# Mock protected resource metadata with auth server URL
mock_prm.return_value = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=["https://auth.example.com"],
)
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
# Mock OAuth authorization server metadata
mock_asm.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
)
metadata = discover_oauth_metadata("https://api.example.com")
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert metadata.token_endpoint == "https://auth.example.com/token"
mock_get.assert_called_once_with(
"https://auth.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
assert oauth_metadata is not None
assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize"
assert oauth_metadata.token_endpoint == "https://auth.example.com/token"
assert prm is not None
assert prm.authorization_servers == ["https://auth.example.com"]
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
# Verify the discovery functions were called
mock_prm.assert_called_once()
mock_asm.assert_called_once()
def test_discover_oauth_metadata_without_resource_discovery(self):
"""Test OAuth metadata discovery without resource discovery."""
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
mock_check.return_value = (False, "")
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
# Mock no protected resource metadata
mock_prm.return_value = None
mock_response = Mock()
mock_response.status_code = 200
mock_response.is_success = True
mock_response.json.return_value = {
"authorization_endpoint": "https://api.example.com/oauth/authorize",
"token_endpoint": "https://api.example.com/oauth/token",
"response_types_supported": ["code"],
}
mock_get.return_value = mock_response
# Mock OAuth authorization server metadata
mock_asm.return_value = OAuthMetadata(
authorization_endpoint="https://api.example.com/oauth/authorize",
token_endpoint="https://api.example.com/oauth/token",
response_types_supported=["code"],
)
metadata = discover_oauth_metadata("https://api.example.com")
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
assert metadata is not None
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
mock_get.assert_called_once_with(
"https://api.example.com/.well-known/oauth-authorization-server",
headers={"MCP-Protocol-Version": "2025-03-26"},
)
assert oauth_metadata is not None
assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
assert prm is None
# Verify the discovery functions were called
mock_prm.assert_called_once()
mock_asm.assert_called_once()
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_metadata_not_found(self, mock_get):
@@ -247,9 +251,9 @@ class TestOAuthDiscovery:
mock_response.status_code = 404
mock_get.return_value = mock_response
metadata = discover_oauth_metadata("https://api.example.com")
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
assert metadata is None
assert oauth_metadata is None
class TestAuthorizationFlow:
@@ -342,6 +346,7 @@ class TestAuthorizationFlow:
"""Test successful authorization code exchange."""
mock_response = Mock()
mock_response.is_success = True
mock_response.headers = {"content-type": "application/json"}
mock_response.json.return_value = {
"access_token": "new-access-token",
"token_type": "Bearer",
@@ -412,6 +417,7 @@ class TestAuthorizationFlow:
"""Test successful token refresh."""
mock_response = Mock()
mock_response.is_success = True
mock_response.headers = {"content-type": "application/json"}
mock_response.json.return_value = {
"access_token": "refreshed-access-token",
"token_type": "Bearer",
@@ -577,11 +583,15 @@ class TestAuthOrchestration:
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
"""Test auth flow for new client registration."""
# Setup
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_register.return_value = OAuthClientInformationFull(
client_id="new-client-id",
@@ -619,11 +629,15 @@ class TestAuthOrchestration:
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
"""Test auth flow for exchanging authorization code."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
# Setup existing client
@@ -662,11 +676,15 @@ class TestAuthOrchestration:
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
"""Test auth flow fails when exchanging code without state."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
@@ -698,11 +716,15 @@ class TestAuthOrchestration:
mock_refresh.return_value = new_tokens
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
result = auth(mock_provider)
@@ -725,11 +747,15 @@ class TestAuthOrchestration:
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
"""Test auth fails when no client info exists but code is provided."""
# Setup metadata discovery
mock_discover.return_value = OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
mock_discover.return_value = (
OAuthMetadata(
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
response_types_supported=["code"],
grant_types_supported=["authorization_code"],
),
None,
None,
)
mock_provider.retrieve_client_information.return_value = None

View File

@@ -139,7 +139,9 @@ def test_sse_client_error_handling():
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
# Mock 401 HTTP error
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401))
mock_response = Mock(status_code=401)
mock_response.headers = {"WWW-Authenticate": 'Bearer realm="example"'}
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=mock_response)
mock_sse_connect.side_effect = mock_error
with pytest.raises(MCPAuthError):
@@ -150,7 +152,9 @@ def test_sse_client_error_handling():
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
# Mock other HTTP error
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500))
mock_response = Mock(status_code=500)
mock_response.headers = {}
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=mock_response)
mock_sse_connect.side_effect = mock_error
with pytest.raises(MCPConnectionError):

View File

@@ -58,7 +58,7 @@ class TestConstants:
def test_protocol_versions(self):
"""Test protocol version constants."""
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
assert LATEST_PROTOCOL_VERSION == "2025-06-18"
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
def test_error_codes(self):