Files
aiagent/backend/app/services/tool_registry.py
2026-01-23 09:49:45 +08:00

116 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工具注册表 - 管理所有可用工具
"""
from typing import Dict, Any, Callable, Optional, List
import json
import logging
from app.models.tool import Tool
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class ToolRegistry:
"""工具注册表 - 管理所有可用工具"""
def __init__(self):
self._builtin_tools: Dict[str, Callable] = {}
self._tool_schemas: Dict[str, Dict[str, Any]] = {}
def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
"""
注册内置工具
Args:
name: 工具名称
func: 工具函数(可以是同步或异步函数)
schema: 工具定义OpenAI Function格式
"""
self._builtin_tools[name] = func
self._tool_schemas[name] = schema
logger.info(f"注册内置工具: {name}")
def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
"""获取工具定义"""
return self._tool_schemas.get(name)
def get_tool_function(self, name: str) -> Optional[Callable]:
"""获取工具函数"""
return self._builtin_tools.get(name)
def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
"""获取所有工具定义用于LLM"""
return list(self._tool_schemas.values())
def load_tools_from_db(self, db: Session, tool_names: List[str] = None):
"""
从数据库加载工具
Args:
db: 数据库会话
tool_names: 工具名称列表可选如果为None则加载所有公开工具
"""
query = db.query(Tool).filter(Tool.is_public == True)
if tool_names:
query = query.filter(Tool.name.in_(tool_names))
tools = query.all()
for tool in tools:
self._tool_schemas[tool.name] = tool.function_schema
# 根据implementation_type加载工具实现
if tool.implementation_type == 'builtin':
# 从内置工具中查找
if tool.name in self._builtin_tools:
logger.debug(f"工具 {tool.name} 已在内置工具中注册")
else:
logger.warning(f"工具 {tool.name} 标记为builtin但未在内置工具中找到")
elif tool.implementation_type == 'http':
# HTTP工具需要特殊处理
self._register_http_tool(tool)
elif tool.implementation_type == 'workflow':
# 工作流工具
self._register_workflow_tool(tool)
elif tool.implementation_type == 'code':
# 代码执行工具
self._register_code_tool(tool)
logger.info(f"从数据库加载了 {len(tools)} 个工具")
def _register_http_tool(self, tool: Tool):
"""注册HTTP工具"""
# TODO: 实现HTTP工具的动态注册
logger.warning(f"HTTP工具 {tool.name} 的动态注册尚未实现")
def _register_workflow_tool(self, tool: Tool):
"""注册工作流工具"""
# TODO: 实现工作流工具的动态注册
logger.warning(f"工作流工具 {tool.name} 的动态注册尚未实现")
def _register_code_tool(self, tool: Tool):
"""注册代码执行工具"""
# TODO: 实现代码执行工具的动态注册
logger.warning(f"代码执行工具 {tool.name} 的动态注册尚未实现")
def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
"""
根据工具名称列表获取工具定义
Args:
tool_names: 工具名称列表
Returns:
工具定义列表OpenAI Function格式
"""
tools = []
for name in tool_names:
schema = self.get_tool_schema(name)
if schema:
tools.append(schema)
else:
logger.warning(f"工具 {name} 未找到")
return tools
# 全局工具注册表实例
tool_registry = ToolRegistry()