feat: implement MCP specification 2025-06-18 (#25766)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -217,3 +217,16 @@ class Tool(ABC):
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
|
||||
)
|
||||
|
||||
def create_variable_message(
|
||||
self, variable_name: str, variable_value: Any, stream: bool = False
|
||||
) -> ToolInvokeMessage:
|
||||
"""
|
||||
create a variable message
|
||||
"""
|
||||
return ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.VARIABLE,
|
||||
message=ToolInvokeMessage.VariableMessage(
|
||||
variable_name=variable_name, variable_value=variable_value, stream=stream
|
||||
),
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@@ -44,10 +45,14 @@ class ToolProviderApiEntity(BaseModel):
|
||||
server_url: str | None = Field(default="", description="The server url of the tool")
|
||||
updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
server_identifier: str | None = Field(default="", description="The server identifier of the MCP tool")
|
||||
timeout: float | None = Field(default=30.0, description="The timeout of the MCP tool")
|
||||
sse_read_timeout: float | None = Field(default=300.0, description="The SSE read timeout of the MCP tool")
|
||||
|
||||
masked_headers: dict[str, str] | None = Field(default=None, description="The masked headers of the MCP tool")
|
||||
original_headers: dict[str, str] | None = Field(default=None, description="The original headers of the MCP tool")
|
||||
authentication: MCPAuthentication | None = Field(default=None, description="The OAuth config of the MCP tool")
|
||||
is_dynamic_registration: bool = Field(default=True, description="Whether the MCP tool is dynamically registered")
|
||||
configuration: MCPConfiguration | None = Field(
|
||||
default=None, description="The timeout and sse_read_timeout of the MCP tool"
|
||||
)
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
@@ -70,8 +75,15 @@ class ToolProviderApiEntity(BaseModel):
|
||||
if self.type == ToolProviderType.MCP:
|
||||
optional_fields.update(self.optional_field("updated_at", self.updated_at))
|
||||
optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
|
||||
optional_fields.update(self.optional_field("timeout", self.timeout))
|
||||
optional_fields.update(self.optional_field("sse_read_timeout", self.sse_read_timeout))
|
||||
optional_fields.update(
|
||||
self.optional_field(
|
||||
"configuration", self.configuration.model_dump() if self.configuration else MCPConfiguration()
|
||||
)
|
||||
)
|
||||
optional_fields.update(
|
||||
self.optional_field("authentication", self.authentication.model_dump() if self.authentication else None)
|
||||
)
|
||||
optional_fields.update(self.optional_field("is_dynamic_registration", self.is_dynamic_registration))
|
||||
optional_fields.update(self.optional_field("masked_headers", self.masked_headers))
|
||||
optional_fields.update(self.optional_field("original_headers", self.original_headers))
|
||||
return {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from typing import Any, Self
|
||||
|
||||
from core.entities.mcp_provider import MCPProviderEntity
|
||||
from core.mcp.types import Tool as RemoteMCPTool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
@@ -52,18 +52,25 @@ class MCPToolProviderController(ToolProviderController):
|
||||
"""
|
||||
from db provider
|
||||
"""
|
||||
tools = []
|
||||
tools_data = json.loads(db_provider.tools)
|
||||
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
|
||||
user = db_provider.load_user()
|
||||
# Convert to entity first
|
||||
provider_entity = db_provider.to_entity()
|
||||
return cls.from_entity(provider_entity)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, entity: MCPProviderEntity) -> Self:
|
||||
"""
|
||||
create a MCPToolProviderController from a MCPProviderEntity
|
||||
"""
|
||||
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in entity.tools]
|
||||
|
||||
tools = [
|
||||
ToolEntity(
|
||||
identity=ToolIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
author="Anonymous", # Tool level author is not stored
|
||||
name=remote_mcp_tool.name,
|
||||
label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
|
||||
provider=db_provider.server_identifier,
|
||||
icon=db_provider.icon,
|
||||
provider=entity.provider_id,
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
|
||||
description=ToolDescription(
|
||||
@@ -72,31 +79,32 @@ class MCPToolProviderController(ToolProviderController):
|
||||
),
|
||||
llm=remote_mcp_tool.description or "",
|
||||
),
|
||||
output_schema=remote_mcp_tool.outputSchema or {},
|
||||
has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
|
||||
)
|
||||
for remote_mcp_tool in remote_mcp_tools
|
||||
]
|
||||
if not db_provider.icon:
|
||||
if not entity.icon:
|
||||
raise ValueError("Database provider icon is required")
|
||||
return cls(
|
||||
entity=ToolProviderEntityWithPlugin(
|
||||
identity=ToolProviderIdentity(
|
||||
author=user.name if user else "Anonymous",
|
||||
name=db_provider.name,
|
||||
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
|
||||
author="Anonymous", # Provider level author is not stored in entity
|
||||
name=entity.name,
|
||||
label=I18nObject(en_US=entity.name, zh_Hans=entity.name),
|
||||
description=I18nObject(en_US="", zh_Hans=""),
|
||||
icon=db_provider.icon,
|
||||
icon=entity.icon if isinstance(entity.icon, str) else "",
|
||||
),
|
||||
plugin_id=None,
|
||||
credentials_schema=[],
|
||||
tools=tools,
|
||||
),
|
||||
provider_id=db_provider.server_identifier or "",
|
||||
tenant_id=db_provider.tenant_id or "",
|
||||
server_url=db_provider.decrypted_server_url,
|
||||
headers=db_provider.decrypted_headers or {},
|
||||
timeout=db_provider.timeout,
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
provider_id=entity.provider_id,
|
||||
tenant_id=entity.tenant_id,
|
||||
server_url=entity.server_url,
|
||||
headers=entity.headers,
|
||||
timeout=entity.timeout,
|
||||
sse_read_timeout=entity.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
|
||||
@@ -3,12 +3,13 @@ import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from core.mcp.error import MCPAuthError, MCPConnectionError
|
||||
from core.mcp.mcp_client import MCPClient
|
||||
from core.mcp.types import ImageContent, TextContent
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
from core.mcp.types import CallToolResult, ImageContent, TextContent
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
@@ -44,40 +45,32 @@ class MCPTool(Tool):
|
||||
app_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
from core.tools.errors import ToolInvokeError
|
||||
|
||||
try:
|
||||
with MCPClient(
|
||||
self.server_url,
|
||||
self.provider_id,
|
||||
self.tenant_id,
|
||||
authed=True,
|
||||
headers=self.headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPAuthError as e:
|
||||
raise ToolInvokeError("Please auth the tool first") from e
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
yield from self._process_text_content(content)
|
||||
elif isinstance(content, ImageContent):
|
||||
yield self._process_image_content(content)
|
||||
# handle MCP structured output
|
||||
if self.entity.output_schema and result.structuredContent:
|
||||
for k, v in result.structuredContent.items():
|
||||
yield self.create_variable_message(k, v)
|
||||
|
||||
def _process_text_content(self, content: TextContent) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process text content and yield appropriate messages."""
|
||||
try:
|
||||
content_json = json.loads(content.text)
|
||||
yield from self._process_json_content(content_json)
|
||||
except json.JSONDecodeError:
|
||||
yield self.create_text_message(content.text)
|
||||
# Check if content looks like JSON before attempting to parse
|
||||
text = content.text.strip()
|
||||
if text and text[0] in ("{", "[") and text[-1] in ("}", "]"):
|
||||
try:
|
||||
content_json = json.loads(text)
|
||||
yield from self._process_json_content(content_json)
|
||||
return
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# If not JSON or parsing failed, treat as plain text
|
||||
yield self.create_text_message(content.text)
|
||||
|
||||
def _process_json_content(self, content_json: Any) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""Process JSON content based on its type."""
|
||||
@@ -126,3 +119,44 @@ class MCPTool(Tool):
|
||||
for key, value in parameter.items()
|
||||
if value is not None and not (isinstance(value, str) and value.strip() == "")
|
||||
}
|
||||
|
||||
def invoke_remote_mcp_tool(self, tool_parameters: dict[str, Any]) -> CallToolResult:
|
||||
headers = self.headers.copy() if self.headers else {}
|
||||
tool_parameters = self._handle_none_parameter(tool_parameters)
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
# Step 1: Load provider entity and credentials in a short-lived session
|
||||
# This minimizes database connection hold time
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
provider_entity = mcp_service.get_provider_entity(self.provider_id, self.tenant_id, by_server_id=True)
|
||||
|
||||
# Decrypt and prepare all credentials before closing session
|
||||
server_url = provider_entity.decrypt_server_url()
|
||||
headers = provider_entity.decrypt_headers()
|
||||
|
||||
# Try to get existing token and add to headers
|
||||
if not headers:
|
||||
tokens = provider_entity.retrieve_tokens()
|
||||
if tokens and tokens.access_token:
|
||||
headers["Authorization"] = f"{tokens.token_type.capitalize()} {tokens.access_token}"
|
||||
|
||||
# Step 2: Session is now closed, perform network operations without holding database connection
|
||||
# MCPClientWithAuthRetry will create a new session lazily only if auth retry is needed
|
||||
try:
|
||||
with MCPClientWithAuthRetry(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=self.timeout,
|
||||
sse_read_timeout=self.sse_read_timeout,
|
||||
provider_entity=provider_entity,
|
||||
) as mcp_client:
|
||||
return mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
|
||||
except MCPConnectionError as e:
|
||||
raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
|
||||
except Exception as e:
|
||||
raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
|
||||
|
||||
@@ -14,17 +14,32 @@ from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
|
||||
from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
@@ -40,21 +55,11 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.mcp_tool.provider import MCPToolProviderController
|
||||
from core.tools.mcp_tool.tool import MCPTool
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.plugin_tool.tool import PluginTool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
|
||||
from core.tools.utils.uuid_utils import is_valid_uuid
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import ToolProviderID
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.enterprise.plugin_manager_service import PluginCredentialType
|
||||
from services.tools.mcp_tools_manage_service import MCPToolManageService
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -719,7 +724,9 @@ class ToolManager:
|
||||
)
|
||||
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
|
||||
if "mcp" in filters:
|
||||
mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
mcp_providers = mcp_service.list_providers(tenant_id=tenant_id, for_list=True)
|
||||
for mcp_provider in mcp_providers:
|
||||
result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
|
||||
|
||||
@@ -774,17 +781,12 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(
|
||||
MCPToolProvider.server_identifier == provider_id,
|
||||
MCPToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
try:
|
||||
provider = mcp_service.get_provider(server_identifier=provider_id, tenant_id=tenant_id)
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
controller = MCPToolProviderController.from_db(provider)
|
||||
|
||||
@@ -922,16 +924,15 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str:
|
||||
try:
|
||||
mcp_provider: MCPToolProvider | None = (
|
||||
db.session.query(MCPToolProvider)
|
||||
.where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mcp_provider is None:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
|
||||
return mcp_provider.provider_icon
|
||||
with Session(db.engine) as session:
|
||||
mcp_service = MCPToolManageService(session=session)
|
||||
try:
|
||||
mcp_provider = mcp_service.get_provider_entity(
|
||||
provider_id=provider_id, tenant_id=tenant_id, by_server_id=True
|
||||
)
|
||||
return mcp_provider.provider_icon
|
||||
except ValueError:
|
||||
raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
|
||||
except Exception:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user