fix: add RFC 9728 compliant well-known URL discovery with path insertion fallback (#29960)
This commit is contained in:
@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPAuthError, MCPError
|
||||
@@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
|
||||
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
|
||||
|
||||
|
||||
class ProviderUrlValidationData(BaseModel):
|
||||
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
|
||||
|
||||
current_server_url_hash: str
|
||||
headers: dict[str, str]
|
||||
timeout: float | None
|
||||
sse_read_timeout: float | None
|
||||
|
||||
|
||||
class MCPToolManageService:
|
||||
"""Service class for managing MCP tools and providers."""
|
||||
|
||||
@@ -166,9 +174,6 @@ class MCPToolManageService:
|
||||
self._session.add(mcp_tool)
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
|
||||
return mcp_providers
|
||||
|
||||
@@ -192,7 +197,7 @@ class MCPToolManageService:
|
||||
Update an MCP provider.
|
||||
|
||||
Args:
|
||||
validation_result: Pre-validation result from validate_server_url_change.
|
||||
validation_result: Pre-validation result from validate_server_url_standalone.
|
||||
If provided and contains reconnect_result, it will be used
|
||||
instead of performing network operations.
|
||||
"""
|
||||
@@ -251,8 +256,6 @@ class MCPToolManageService:
|
||||
# Flush changes to database
|
||||
self._session.flush()
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
except IntegrityError as e:
|
||||
self._handle_integrity_error(e, name, server_url, server_identifier)
|
||||
|
||||
@@ -261,9 +264,6 @@ class MCPToolManageService:
|
||||
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
self._session.delete(mcp_tool)
|
||||
|
||||
# Invalidate tool providers cache
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
def list_providers(
|
||||
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
|
||||
) -> list[ToolProviderApiEntity]:
|
||||
@@ -546,30 +546,39 @@ class MCPToolManageService:
|
||||
)
|
||||
return self.execute_auth_actions(auth_result)
|
||||
|
||||
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
|
||||
"""Attempt to reconnect to MCP provider with new server URL."""
|
||||
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
|
||||
"""
|
||||
Get provider data required for URL validation.
|
||||
This method performs database read and should be called within a session.
|
||||
|
||||
Returns:
|
||||
ProviderUrlValidationData: Data needed for standalone URL validation
|
||||
"""
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
provider_entity = provider.to_entity()
|
||||
headers = provider_entity.headers
|
||||
return ProviderUrlValidationData(
|
||||
current_server_url_hash=provider.server_url_hash,
|
||||
headers=provider_entity.headers,
|
||||
timeout=provider_entity.timeout,
|
||||
sse_read_timeout=provider_entity.sse_read_timeout,
|
||||
)
|
||||
|
||||
try:
|
||||
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def validate_server_url_change(
|
||||
self, *, tenant_id: str, provider_id: str, new_server_url: str
|
||||
@staticmethod
|
||||
def validate_server_url_standalone(
|
||||
*,
|
||||
tenant_id: str,
|
||||
new_server_url: str,
|
||||
validation_data: ProviderUrlValidationData,
|
||||
) -> ServerUrlValidationResult:
|
||||
"""
|
||||
Validate server URL change by attempting to connect to the new server.
|
||||
This method should be called BEFORE update_provider to perform network operations
|
||||
outside of the database transaction.
|
||||
This method performs network operations and MUST be called OUTSIDE of any database session
|
||||
to avoid holding locks during network I/O.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for encryption
|
||||
new_server_url: The new server URL to validate
|
||||
validation_data: Provider data obtained from get_provider_for_url_validation
|
||||
|
||||
Returns:
|
||||
ServerUrlValidationResult: Validation result with connection status and tools if successful
|
||||
@@ -579,25 +588,30 @@ class MCPToolManageService:
|
||||
return ServerUrlValidationResult(needs_validation=False)
|
||||
|
||||
# Validate URL format
|
||||
if not self._is_valid_url(new_server_url):
|
||||
parsed = urlparse(new_server_url)
|
||||
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
|
||||
raise ValueError("Server URL is not valid.")
|
||||
|
||||
# Always encrypt and hash the URL
|
||||
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
|
||||
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
|
||||
|
||||
# Get current provider
|
||||
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
|
||||
|
||||
# Check if URL is actually different
|
||||
if new_server_url_hash == provider.server_url_hash:
|
||||
if new_server_url_hash == validation_data.current_server_url_hash:
|
||||
# URL hasn't changed, but still return the encrypted data
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
|
||||
needs_validation=False,
|
||||
encrypted_server_url=encrypted_server_url,
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
# Perform validation by attempting to connect
|
||||
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
|
||||
# Perform network validation - this is the expensive operation that should be outside session
|
||||
reconnect_result = MCPToolManageService._reconnect_with_url(
|
||||
server_url=new_server_url,
|
||||
headers=validation_data.headers,
|
||||
timeout=validation_data.timeout,
|
||||
sse_read_timeout=validation_data.sse_read_timeout,
|
||||
)
|
||||
return ServerUrlValidationResult(
|
||||
needs_validation=True,
|
||||
validation_passed=True,
|
||||
@@ -606,6 +620,38 @@ class MCPToolManageService:
|
||||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reconnect_with_url(
|
||||
*,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
timeout: float | None,
|
||||
sse_read_timeout: float | None,
|
||||
) -> ReconnectResult:
|
||||
"""
|
||||
Attempt to connect to MCP server with given URL.
|
||||
This is a static method that performs network I/O without database access.
|
||||
"""
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
|
||||
except MCPError as e:
|
||||
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
|
||||
|
||||
def _build_tool_provider_response(
|
||||
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
|
||||
) -> ToolProviderApiEntity:
|
||||
|
||||
Reference in New Issue
Block a user