feat: implement RFC-compliant OAuth discovery with dynamic scope selection for MCP providers (#28294)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -23,11 +23,13 @@ from core.mcp.auth.auth_flow import (
|
||||
)
|
||||
from core.mcp.entities import AuthActionType, AuthResult
|
||||
from core.mcp.types import (
|
||||
LATEST_PROTOCOL_VERSION,
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthTokens,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -154,7 +156,7 @@ class TestOAuthDiscovery:
|
||||
assert auth_url == "https://auth.example.com"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-protected-resource",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
|
||||
)
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
@@ -183,59 +185,61 @@ class TestOAuthDiscovery:
|
||||
assert auth_url == "https://auth.example.com"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-protected-resource?query=1#fragment",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26", "User-Agent": "Dify"},
|
||||
headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION, "User-Agent": "Dify"},
|
||||
)
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_with_resource_discovery(self, mock_get):
|
||||
def test_discover_oauth_metadata_with_resource_discovery(self):
|
||||
"""Test OAuth metadata discovery with resource discovery support."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
mock_check.return_value = (True, "https://auth.example.com")
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
# Mock protected resource metadata with auth server URL
|
||||
mock_prm.return_value = ProtectedResourceMetadata(
|
||||
resource="https://api.example.com",
|
||||
authorization_servers=["https://auth.example.com"],
|
||||
)
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.is_success = True
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
# Mock OAuth authorization server metadata
|
||||
mock_asm.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
|
||||
metadata = discover_oauth_metadata("https://api.example.com")
|
||||
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
|
||||
|
||||
assert metadata is not None
|
||||
assert metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||
assert metadata.token_endpoint == "https://auth.example.com/token"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||
)
|
||||
assert oauth_metadata is not None
|
||||
assert oauth_metadata.authorization_endpoint == "https://auth.example.com/authorize"
|
||||
assert oauth_metadata.token_endpoint == "https://auth.example.com/token"
|
||||
assert prm is not None
|
||||
assert prm.authorization_servers == ["https://auth.example.com"]
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_without_resource_discovery(self, mock_get):
|
||||
# Verify the discovery functions were called
|
||||
mock_prm.assert_called_once()
|
||||
mock_asm.assert_called_once()
|
||||
|
||||
def test_discover_oauth_metadata_without_resource_discovery(self):
|
||||
"""Test OAuth metadata discovery without resource discovery."""
|
||||
with patch("core.mcp.auth.auth_flow.check_support_resource_discovery") as mock_check:
|
||||
mock_check.return_value = (False, "")
|
||||
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
|
||||
# Mock no protected resource metadata
|
||||
mock_prm.return_value = None
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.is_success = True
|
||||
mock_response.json.return_value = {
|
||||
"authorization_endpoint": "https://api.example.com/oauth/authorize",
|
||||
"token_endpoint": "https://api.example.com/oauth/token",
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
# Mock OAuth authorization server metadata
|
||||
mock_asm.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://api.example.com/oauth/authorize",
|
||||
token_endpoint="https://api.example.com/oauth/token",
|
||||
response_types_supported=["code"],
|
||||
)
|
||||
|
||||
metadata = discover_oauth_metadata("https://api.example.com")
|
||||
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
|
||||
|
||||
assert metadata is not None
|
||||
assert metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
|
||||
mock_get.assert_called_once_with(
|
||||
"https://api.example.com/.well-known/oauth-authorization-server",
|
||||
headers={"MCP-Protocol-Version": "2025-03-26"},
|
||||
)
|
||||
assert oauth_metadata is not None
|
||||
assert oauth_metadata.authorization_endpoint == "https://api.example.com/oauth/authorize"
|
||||
assert prm is None
|
||||
|
||||
# Verify the discovery functions were called
|
||||
mock_prm.assert_called_once()
|
||||
mock_asm.assert_called_once()
|
||||
|
||||
@patch("core.helper.ssrf_proxy.get")
|
||||
def test_discover_oauth_metadata_not_found(self, mock_get):
|
||||
@@ -247,9 +251,9 @@ class TestOAuthDiscovery:
|
||||
mock_response.status_code = 404
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
metadata = discover_oauth_metadata("https://api.example.com")
|
||||
oauth_metadata, prm, scope = discover_oauth_metadata("https://api.example.com")
|
||||
|
||||
assert metadata is None
|
||||
assert oauth_metadata is None
|
||||
|
||||
|
||||
class TestAuthorizationFlow:
|
||||
@@ -342,6 +346,7 @@ class TestAuthorizationFlow:
|
||||
"""Test successful authorization code exchange."""
|
||||
mock_response = Mock()
|
||||
mock_response.is_success = True
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"token_type": "Bearer",
|
||||
@@ -412,6 +417,7 @@ class TestAuthorizationFlow:
|
||||
"""Test successful token refresh."""
|
||||
mock_response = Mock()
|
||||
mock_response.is_success = True
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
mock_response.json.return_value = {
|
||||
"access_token": "refreshed-access-token",
|
||||
"token_type": "Bearer",
|
||||
@@ -577,11 +583,15 @@ class TestAuthOrchestration:
|
||||
def test_auth_new_registration(self, mock_start_auth, mock_register, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for new client registration."""
|
||||
# Setup
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
mock_register.return_value = OAuthClientInformationFull(
|
||||
client_id="new-client-id",
|
||||
@@ -619,11 +629,15 @@ class TestAuthOrchestration:
|
||||
def test_auth_exchange_code(self, mock_exchange, mock_retrieve_state, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow for exchanging authorization code."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
# Setup existing client
|
||||
@@ -662,11 +676,15 @@ class TestAuthOrchestration:
|
||||
def test_auth_exchange_code_without_state(self, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth flow fails when exchanging code without state."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="existing-client")
|
||||
@@ -698,11 +716,15 @@ class TestAuthOrchestration:
|
||||
mock_refresh.return_value = new_tokens
|
||||
|
||||
with patch("core.mcp.auth.auth_flow.discover_oauth_metadata") as mock_discover:
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
result = auth(mock_provider)
|
||||
@@ -725,11 +747,15 @@ class TestAuthOrchestration:
|
||||
def test_auth_registration_fails_with_code(self, mock_discover, mock_provider, mock_service):
|
||||
"""Test auth fails when no client info exists but code is provided."""
|
||||
# Setup metadata discovery
|
||||
mock_discover.return_value = OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
mock_discover.return_value = (
|
||||
OAuthMetadata(
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
response_types_supported=["code"],
|
||||
grant_types_supported=["authorization_code"],
|
||||
),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
mock_provider.retrieve_client_information.return_value = None
|
||||
|
||||
@@ -139,7 +139,9 @@ def test_sse_client_error_handling():
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock 401 HTTP error
|
||||
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=Mock(status_code=401))
|
||||
mock_response = Mock(status_code=401)
|
||||
mock_response.headers = {"WWW-Authenticate": 'Bearer realm="example"'}
|
||||
mock_error = httpx.HTTPStatusError("Unauthorized", request=Mock(), response=mock_response)
|
||||
mock_sse_connect.side_effect = mock_error
|
||||
|
||||
with pytest.raises(MCPAuthError):
|
||||
@@ -150,7 +152,9 @@ def test_sse_client_error_handling():
|
||||
with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_client_factory:
|
||||
with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sse_connect:
|
||||
# Mock other HTTP error
|
||||
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=Mock(status_code=500))
|
||||
mock_response = Mock(status_code=500)
|
||||
mock_response.headers = {}
|
||||
mock_error = httpx.HTTPStatusError("Server Error", request=Mock(), response=mock_response)
|
||||
mock_sse_connect.side_effect = mock_error
|
||||
|
||||
with pytest.raises(MCPConnectionError):
|
||||
|
||||
@@ -58,7 +58,7 @@ class TestConstants:
|
||||
|
||||
def test_protocol_versions(self):
|
||||
"""Test protocol version constants."""
|
||||
assert LATEST_PROTOCOL_VERSION == "2025-03-26"
|
||||
assert LATEST_PROTOCOL_VERSION == "2025-06-18"
|
||||
assert SERVER_LATEST_PROTOCOL_VERSION == "2024-11-05"
|
||||
|
||||
def test_error_codes(self):
|
||||
|
||||
Reference in New Issue
Block a user