remove bare list, dict, Sequence, None, Any (#25058)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -20,7 +20,7 @@ class Tool(ABC):
|
||||
The base class of a tool
|
||||
"""
|
||||
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime):
|
||||
self.entity = entity
|
||||
self.runtime = runtime
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ToolProviderController(ABC):
|
||||
def __init__(self, entity: ToolProviderEntity) -> None:
|
||||
def __init__(self, entity: ToolProviderEntity):
|
||||
self.entity = entity
|
||||
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
@@ -41,7 +41,7 @@ class ToolProviderController(ABC):
|
||||
"""
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||
def validate_credentials_format(self, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the format of the credentials of the provider and set the default value if needed
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
def __init__(self, **data: Any):
|
||||
self.tools = []
|
||||
|
||||
# load provider yaml
|
||||
@@ -197,7 +197,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@@ -211,7 +211,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
self._validate_credentials(user_id, credentials)
|
||||
|
||||
@abstractmethod
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
|
||||
@@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class AudioToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class CodeToolProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@@ -4,5 +4,5 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
pass
|
||||
|
||||
@@ -4,7 +4,7 @@ from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
|
||||
|
||||
class WebscraperProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
Validate credentials
|
||||
"""
|
||||
|
||||
@@ -24,7 +24,7 @@ class ApiToolProviderController(ToolProviderController):
|
||||
tenant_id: str
|
||||
tools: list[ApiTool] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str) -> None:
|
||||
def __init__(self, entity: ToolProviderEntity, provider_id: str, tenant_id: str):
|
||||
super().__init__(entity)
|
||||
self.provider_id = provider_id
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@@ -302,7 +302,7 @@ class ApiTool(Tool):
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
) -> Any:
|
||||
):
|
||||
if max_recursive <= 0:
|
||||
raise Exception("Max recursion depth reached")
|
||||
for option in any_of or []:
|
||||
@@ -337,7 +337,7 @@ class ApiTool(Tool):
|
||||
# If no option succeeded, you might want to return the value as is or raise an error
|
||||
return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
|
||||
|
||||
def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
|
||||
def _convert_body_property_type(self, property: dict[str, Any], value: Any):
|
||||
try:
|
||||
if "type" in property:
|
||||
if property["type"] == "integer" or property["type"] == "int":
|
||||
|
||||
@@ -49,7 +49,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self):
|
||||
# -------------
|
||||
# overwrite tool parameter types for temp fix
|
||||
tools = jsonable_encoder(self.tools)
|
||||
@@ -84,7 +84,7 @@ class ToolProviderApiEntity(BaseModel):
|
||||
**optional_fields,
|
||||
}
|
||||
|
||||
def optional_field(self, key: str, value: Any) -> dict:
|
||||
def optional_field(self, key: str, value: Any):
|
||||
"""Return dict with key-value if value is truthy, empty dict otherwise."""
|
||||
return {key: value} if value else {}
|
||||
|
||||
|
||||
@@ -19,5 +19,5 @@ class I18nObject(BaseModel):
|
||||
self.pt_BR = self.pt_BR or self.en_US
|
||||
self.ja_JP = self.ja_JP or self.en_US
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self):
|
||||
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||
|
||||
@@ -150,7 +150,7 @@ class ToolInvokeMessage(BaseModel):
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def transform_variable_value(cls, values) -> Any:
|
||||
def transform_variable_value(cls, values):
|
||||
"""
|
||||
Only basic types and lists are allowed.
|
||||
"""
|
||||
@@ -428,7 +428,7 @@ class ToolInvokeMeta(BaseModel):
|
||||
"""
|
||||
return cls(time_cost=0.0, error=error, tool_config={})
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
def to_dict(self):
|
||||
return {
|
||||
"time_cost": self.time_cost,
|
||||
"error": self.error,
|
||||
|
||||
@@ -28,7 +28,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(entity)
|
||||
self.entity: ToolProviderEntityWithPlugin = entity
|
||||
self.tenant_id = tenant_id
|
||||
@@ -99,7 +99,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||
sse_read_timeout=db_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
|
||||
@@ -23,7 +23,7 @@ class MCPTool(Tool):
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
sse_read_timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
|
||||
@@ -16,7 +16,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
|
||||
def __init__(
|
||||
self, entity: ToolProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||
) -> None:
|
||||
):
|
||||
self.entity = entity
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_id = plugin_id
|
||||
@@ -31,7 +31,7 @@ class PluginToolProviderController(BuiltinToolProviderController):
|
||||
"""
|
||||
return ToolProviderType.PLUGIN
|
||||
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]):
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
|
||||
class PluginTool(Tool):
|
||||
def __init__(
|
||||
self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, plugin_unique_identifier: str
|
||||
) -> None:
|
||||
):
|
||||
super().__init__(entity, runtime)
|
||||
self.tenant_id = tenant_id
|
||||
self.icon = icon
|
||||
|
||||
@@ -778,7 +778,7 @@ class ToolManager:
|
||||
return controller
|
||||
|
||||
@classmethod
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
|
||||
def user_get_api_provider(cls, provider: str, tenant_id: str):
|
||||
"""
|
||||
get api provider
|
||||
"""
|
||||
@@ -873,7 +873,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
@@ -890,7 +890,7 @@ class ToolManager:
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str):
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
|
||||
@@ -24,7 +24,7 @@ class ToolParameterConfigurationManager:
|
||||
|
||||
def __init__(
|
||||
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||
) -> None:
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.tool_runtime = tool_runtime
|
||||
self.provider_name = provider_name
|
||||
|
||||
@@ -20,7 +20,7 @@ from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import Datas
|
||||
|
||||
|
||||
class DatasetRetrieverTool(Tool):
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool) -> None:
|
||||
def __init__(self, entity: ToolEntity, runtime: ToolRuntime, retrieval_tool: DatasetRetrieverBaseTool):
|
||||
super().__init__(entity, runtime)
|
||||
self.retrieval_tool = retrieval_tool
|
||||
|
||||
|
||||
@@ -17,11 +17,11 @@ class ProviderConfigCache(Protocol):
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
def set(self, config: dict[str, Any]):
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
def delete(self):
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
@@ -242,7 +242,7 @@ class ApiBasedToolSchemaParser:
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None):
|
||||
warning = warning or {}
|
||||
"""
|
||||
parse swagger to openapi
|
||||
|
||||
@@ -8,7 +8,7 @@ from yaml import YAMLError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}):
|
||||
"""
|
||||
Safe loading a YAML file
|
||||
:param file_path: the path of the YAML file
|
||||
|
||||
@@ -223,7 +223,7 @@ class WorkflowTool(Tool):
|
||||
|
||||
return result, files
|
||||
|
||||
def _update_file_mapping(self, file_dict: dict) -> dict:
|
||||
def _update_file_mapping(self, file_dict: dict):
|
||||
transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method"))
|
||||
if transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
file_dict["tool_file_id"] = file_dict.get("related_id")
|
||||
|
||||
Reference in New Issue
Block a user