Files
aiagent/backend/app/services/tool_registry.py

124 lines
4.4 KiB
Python
Raw Normal View History

2026-01-23 09:49:45 +08:00
"""
工具注册表 - 管理所有可用工具
"""
from typing import Dict, Any, Callable, Optional, List
import json
import logging
from app.models.tool import Tool
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class ToolRegistry:
"""工具注册表 - 管理所有可用工具"""
def __init__(self):
self._builtin_tools: Dict[str, Callable] = {}
self._tool_schemas: Dict[str, Dict[str, Any]] = {}
def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
"""
注册内置工具
Args:
name: 工具名称
func: 工具函数可以是同步或异步函数
schema: 工具定义OpenAI Function格式
"""
self._builtin_tools[name] = func
self._tool_schemas[name] = schema
logger.debug("注册内置工具: %s", name)
2026-01-23 09:49:45 +08:00
def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
"""获取工具定义"""
return self._tool_schemas.get(name)
def get_tool_function(self, name: str) -> Optional[Callable]:
"""获取工具函数"""
return self._builtin_tools.get(name)
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
"""获取所有工具定义用于LLM"""
return list(self._tool_schemas.values())
def builtin_tool_count(self) -> int:
"""已注册的内置工具数量(用于健康检查 / 启动自检)。"""
return len(self._builtin_tools)
def builtin_tool_names(self) -> List[str]:
"""已注册的内置工具名称,有序列表。"""
return sorted(self._builtin_tools.keys())
2026-01-23 09:49:45 +08:00
def load_tools_from_db(self, db: Session, tool_names: List[str] = None):
"""
从数据库加载工具
Args:
db: 数据库会话
tool_names: 工具名称列表可选如果为None则加载所有公开工具
"""
query = db.query(Tool).filter(Tool.is_public == True)
if tool_names:
query = query.filter(Tool.name.in_(tool_names))
tools = query.all()
for tool in tools:
self._tool_schemas[tool.name] = tool.function_schema
# 根据implementation_type加载工具实现
if tool.implementation_type == 'builtin':
# 从内置工具中查找
if tool.name in self._builtin_tools:
logger.debug(f"工具 {tool.name} 已在内置工具中注册")
else:
logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到")
elif tool.implementation_type == 'http':
# HTTP工具需要特殊处理
self._register_http_tool(tool)
elif tool.implementation_type == 'workflow':
# 工作流工具
self._register_workflow_tool(tool)
elif tool.implementation_type == 'code':
# 代码执行工具
self._register_code_tool(tool)
logger.info(f"从数据库加载了 {len(tools)} 个工具")
def _register_http_tool(self, tool: Tool):
"""注册HTTP工具"""
# TODO: 实现HTTP工具的动态注册
logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现")
def _register_workflow_tool(self, tool: Tool):
"""注册工作流工具"""
# TODO: 实现工作流工具的动态注册
logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现")
def _register_code_tool(self, tool: Tool):
"""注册代码执行工具"""
# TODO: 实现代码执行工具的动态注册
logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现")
def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
"""
根据工具名称列表获取工具定义
Args:
tool_names: 工具名称列表
Returns:
工具定义列表OpenAI Function格式
"""
tools = []
for name in tool_names:
schema = self.get_tool_schema(name)
if schema:
tools.append(schema)
else:
logger.warning(f"工具 {name} 未找到")
return tools
# 全局工具注册表实例
tool_registry = ToolRegistry()