Files
aiagent/backend/app/agent_runtime/tool_manager.py
renjianbo 09467568ec feat: Agent 运行时、对话 API、作业助手与引擎修复及前端执行超时
- agent_runtime 模块与 agent_chat API,前端 AgentChat 视图与路由对接
- workflow_engine: code 节点命名空间与 json 引用修复
- llm_service: 工具调用 extra_body(如 DeepSeek)
- create_homework_manager_agent / _3 脚本与测试脚本扩展
- frontend: WORKFLOW_EXECUTION_HTTP_TIMEOUT_MS、AgentChatPreview/MainLayout 等
- 文档:架构说明与自主 Agent 改造完成情况

Made-with: Cursor
2026-05-01 11:31:48 +08:00

95 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Agent tool manager: wraps the existing ToolRegistry and provides the tool-format
conversion and tool execution that the Agent runtime needs.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Callable, Dict, List, Optional
from app.services.tool_registry import tool_registry
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class AgentToolManager:
    """
    Manages tools for the Agent Runtime:

    - converts ToolRegistry tool schemas into the OpenAI Function Calling format;
    - filters them by the Agent's configuration (whitelist / blacklist);
    - executes tool calls and returns the result as a string.

    The whitelist/blacklist is enforced in ``execute`` as well as in
    ``get_tool_schemas``, so a tool name the model was never offered
    (hallucinated or injected) cannot be run through this manager.
    """

    def __init__(self, include_tools: Optional[List[str]] = None,
                 exclude_tools: Optional[List[str]] = None):
        """
        Args:
            include_tools: Whitelist of tool names. ``None`` or an empty list
                means no whitelist, i.e. every registered tool is a candidate.
            exclude_tools: Blacklist of tool names. A name present in both
                lists is excluded.
        """
        self._include_tools: set = set(include_tools or [])
        self._exclude_tools: set = set(exclude_tools or [])

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Return the tool definitions available to this Agent (OpenAI Function
        Calling format).

        With no whitelist/blacklist configured, every registry schema is
        returned. Otherwise schemas whose name cannot be extracted are
        dropped, since they cannot be matched against the lists.
        """
        all_schemas = tool_registry.get_all_tool_schemas()
        if not self._include_tools and not self._exclude_tools:
            # Shallow copy: if the registry hands back its own list, callers
            # mutating the result must not affect it.
            return list(all_schemas)
        filtered: List[Dict[str, Any]] = []
        for schema in all_schemas:
            name = self._extract_tool_name(schema)
            if name and self._is_allowed(name):
                filtered.append(schema)
        return filtered

    def has_tools(self) -> bool:
        """Return True if at least one tool is available to this Agent."""
        return bool(self.get_tool_schemas())

    def tool_names(self) -> List[str]:
        """
        Return the names of the available tools.

        ``"?"`` stands in for a schema whose name cannot be extracted (this
        can only occur when no whitelist/blacklist is configured).
        """
        return [
            self._extract_tool_name(s) or "?"
            for s in self.get_tool_schemas()
        ]

    async def execute(self, name: str, args: Optional[Dict[str, Any]]) -> str:
        """
        Execute a tool call.

        Args:
            name: Tool name.
            args: Keyword arguments for the tool; ``None`` is treated as
                "no arguments".

        Returns:
            The tool result as a string. Dict/list results are JSON-encoded
            (values json cannot encode are rendered with ``str()``); any other
            result goes through ``str()``. Failures are not raised. A tool not
            allowed for this Agent, an unknown tool, or an exception raised by
            the tool is returned as a JSON object ``{"error": "..."}``.
        """
        import inspect

        # Enforce the same whitelist/blacklist used for the advertised schemas.
        if not self._is_allowed(name):
            err = f"工具 '{name}' 未对当前 Agent 开放"
            logger.warning(err)
            return json.dumps({"error": err}, ensure_ascii=False)

        func: Optional[Callable] = tool_registry.get_tool_function(name)
        if not func:
            err = f"工具 '{name}' 不存在"
            logger.error(err)
            return json.dumps({"error": err}, ensure_ascii=False)

        if args is None:
            args = {}
        logger.info("Agent 执行工具: %s, 参数: %s", name, args)
        try:
            result = func(**args)
            # Await any awaitable result, not only results of ``async def``
            # functions. A callable that returns a coroutine without being
            # declared async (e.g. a sync wrapper) would otherwise be
            # stringified as "<coroutine object ...>" and never run.
            if inspect.isawaitable(result):
                result = await result
        except Exception as e:
            err_msg = f"工具 '{name}' 执行失败: {e}"
            logger.error(err_msg, exc_info=True)
            return json.dumps({"error": err_msg}, ensure_ascii=False)
        # Serialize outside the try: the tool has already run, so an encoding
        # problem must not be reported as an execution failure.
        return self._stringify_result(result)

    def _is_allowed(self, name: str) -> bool:
        """Return True if ``name`` passes this Agent's whitelist and blacklist."""
        if self._include_tools and name not in self._include_tools:
            return False
        return name not in self._exclude_tools

    @staticmethod
    def _stringify_result(result: Any) -> str:
        """Convert a tool result into the string handed back to the model."""
        if not isinstance(result, (dict, list)):
            return str(result)
        try:
            # default=str only affects values json cannot encode natively
            # (e.g. datetime, Decimal, custom objects); JSON-native payloads
            # serialize exactly as without it.
            return json.dumps(result, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # e.g. unsupported dict key types (TypeError) or circular
            # references (ValueError).
            return str(result)

    @staticmethod
    def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:
        """
        Extract the tool name from a schema.

        Accepts both the wrapped form ``{"function": {"name": ...}}`` and a
        bare ``{"name": ...}`` definition. Returns ``None`` when no name can be
        found or the schema is not a dict.
        """
        if not isinstance(schema, dict):
            return None
        fn = schema.get("function") or schema
        return fn.get("name") if isinstance(fn, dict) else None