""" 工具注册表 - 管理所有可用工具 """ from typing import Dict, Any, Callable, Optional, List import json import logging from app.models.tool import Tool from sqlalchemy.orm import Session logger = logging.getLogger(__name__) class ToolRegistry: """工具注册表 - 管理所有可用工具""" def __init__(self): self._builtin_tools: Dict[str, Callable] = {} self._tool_schemas: Dict[str, Dict[str, Any]] = {} def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]): """ 注册内置工具 Args: name: 工具名称 func: 工具函数(可以是同步或异步函数) schema: 工具定义(OpenAI Function格式) """ self._builtin_tools[name] = func self._tool_schemas[name] = schema logger.debug("注册内置工具: %s", name) def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]: """获取工具定义""" return self._tool_schemas.get(name) def get_tool_function(self, name: str) -> Optional[Callable]: """获取工具函数""" return self._builtin_tools.get(name) def get_all_tool_schemas(self) -> List[Dict[str, Any]]: """获取所有工具定义(用于LLM)""" return list(self._tool_schemas.values()) def builtin_tool_count(self) -> int: """已注册的内置工具数量(用于健康检查 / 启动自检)。""" return len(self._builtin_tools) def builtin_tool_names(self) -> List[str]: """已注册的内置工具名称,有序列表。""" return sorted(self._builtin_tools.keys()) def load_tools_from_db(self, db: Session, tool_names: List[str] = None): """ 从数据库加载工具 Args: db: 数据库会话 tool_names: 工具名称列表(可选,如果为None则加载所有公开工具) """ query = db.query(Tool).filter(Tool.is_public == True) if tool_names: query = query.filter(Tool.name.in_(tool_names)) tools = query.all() for tool in tools: self._tool_schemas[tool.name] = tool.function_schema # 根据implementation_type加载工具实现 if tool.implementation_type == 'builtin': # 从内置工具中查找 if tool.name in self._builtin_tools: logger.debug(f"工具 {tool.name} 已在内置工具中注册") else: logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到") elif tool.implementation_type == 'http': # HTTP工具需要特殊处理 self._register_http_tool(tool) elif tool.implementation_type == 'workflow': # 工作流工具 self._register_workflow_tool(tool) elif tool.implementation_type == 'code': # 代码执行工具 self._register_code_tool(tool) logger.info(f"从数据库加载了 {len(tools)} 个工具") def _register_http_tool(self, tool: Tool): """注册HTTP工具""" # TODO: 实现HTTP工具的动态注册 logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现") def _register_workflow_tool(self, tool: Tool): """注册工作流工具""" # TODO: 实现工作流工具的动态注册 logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现") def _register_code_tool(self, tool: Tool): """注册代码执行工具""" # TODO: 实现代码执行工具的动态注册 logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现") def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]: """ 根据工具名称列表获取工具定义 Args: tool_names: 工具名称列表 Returns: 工具定义列表(OpenAI Function格式) """ tools = [] for name in tool_names: schema = self.get_tool_schema(name) if schema: tools.append(schema) else: logger.warning(f"工具 {name} 未找到") return tools # 全局工具注册表实例 tool_registry = ToolRegistry()