116 lines
4.1 KiB
Python
116 lines
4.1 KiB
Python
"""
|
||
工具注册表 - 管理所有可用工具
|
||
"""
|
||
from typing import Dict, Any, Callable, Optional, List
|
||
import json
|
||
import logging
|
||
from app.models.tool import Tool
|
||
from sqlalchemy.orm import Session
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ToolRegistry:
|
||
"""工具注册表 - 管理所有可用工具"""
|
||
|
||
def __init__(self):
|
||
self._builtin_tools: Dict[str, Callable] = {}
|
||
self._tool_schemas: Dict[str, Dict[str, Any]] = {}
|
||
|
||
def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
|
||
"""
|
||
注册内置工具
|
||
|
||
Args:
|
||
name: 工具名称
|
||
func: 工具函数(可以是同步或异步函数)
|
||
schema: 工具定义(OpenAI Function格式)
|
||
"""
|
||
self._builtin_tools[name] = func
|
||
self._tool_schemas[name] = schema
|
||
logger.info(f"注册内置工具: {name}")
|
||
|
||
def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
|
||
"""获取工具定义"""
|
||
return self._tool_schemas.get(name)
|
||
|
||
def get_tool_function(self, name: str) -> Optional[Callable]:
|
||
"""获取工具函数"""
|
||
return self._builtin_tools.get(name)
|
||
|
||
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
|
||
"""获取所有工具定义(用于LLM)"""
|
||
return list(self._tool_schemas.values())
|
||
|
||
def load_tools_from_db(self, db: Session, tool_names: List[str] = None):
|
||
"""
|
||
从数据库加载工具
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
tool_names: 工具名称列表(可选,如果为None则加载所有公开工具)
|
||
"""
|
||
query = db.query(Tool).filter(Tool.is_public == True)
|
||
if tool_names:
|
||
query = query.filter(Tool.name.in_(tool_names))
|
||
|
||
tools = query.all()
|
||
for tool in tools:
|
||
self._tool_schemas[tool.name] = tool.function_schema
|
||
# 根据implementation_type加载工具实现
|
||
if tool.implementation_type == 'builtin':
|
||
# 从内置工具中查找
|
||
if tool.name in self._builtin_tools:
|
||
logger.debug(f"工具 {tool.name} 已在内置工具中注册")
|
||
else:
|
||
logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到")
|
||
elif tool.implementation_type == 'http':
|
||
# HTTP工具需要特殊处理
|
||
self._register_http_tool(tool)
|
||
elif tool.implementation_type == 'workflow':
|
||
# 工作流工具
|
||
self._register_workflow_tool(tool)
|
||
elif tool.implementation_type == 'code':
|
||
# 代码执行工具
|
||
self._register_code_tool(tool)
|
||
|
||
logger.info(f"从数据库加载了 {len(tools)} 个工具")
|
||
|
||
def _register_http_tool(self, tool: Tool):
|
||
"""注册HTTP工具"""
|
||
# TODO: 实现HTTP工具的动态注册
|
||
logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现")
|
||
|
||
def _register_workflow_tool(self, tool: Tool):
|
||
"""注册工作流工具"""
|
||
# TODO: 实现工作流工具的动态注册
|
||
logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现")
|
||
|
||
def _register_code_tool(self, tool: Tool):
|
||
"""注册代码执行工具"""
|
||
# TODO: 实现代码执行工具的动态注册
|
||
logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现")
|
||
|
||
def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
|
||
"""
|
||
根据工具名称列表获取工具定义
|
||
|
||
Args:
|
||
tool_names: 工具名称列表
|
||
|
||
Returns:
|
||
工具定义列表(OpenAI Function格式)
|
||
"""
|
||
tools = []
|
||
for name in tool_names:
|
||
schema = self.get_tool_schema(name)
|
||
if schema:
|
||
tools.append(schema)
|
||
else:
|
||
logger.warning(f"工具 {name} 未找到")
|
||
return tools
|
||
|
||
|
||
# 全局工具注册表实例
|
||
tool_registry = ToolRegistry()
|