Files
aiagent/backend/app/agent_runtime/tool_manager.py
renjianbo 7e00b027d4 feat: Phase 3 - parallel execution, progress reporting, result caching + AgentChat bug fixes
Phase 3 能力:
- DAG 并行执行 (workflow_engine): asyncio.gather 并行执行就绪节点
- Debate 并行 (orchestrator): for 循环改为 asyncio.gather
- 粒度进度上报 (workflow_engine + tasks + websocket): Redis 推送 + DB 降级
- 工具结果缓存 (tool_manager): 确定性工具默认开启缓存
- LLM 响应缓存 (core): messages[-4:] + model 哈希,5min TTL

AgentChat bug 修复 (Gitea #1-#5):
- #1 SSE 降级重复空消息: fallback POST 前移除占位消息
- #2 streamTimeout 泄漏: while 正常退出后 clearTimeout
- #3 loading 闪烁: final/error 事件中提前设 loading=false
- #4 SSE 事件类型对齐: 确认匹配,未知类型加 console.warn
- #5 retryMessage 流式残留: 重试时清理占位消息

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-05 00:00:51 +08:00

143 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Agent 工具管理器: 包装已有 ToolRegistry, 提供 Agent 需要的工具格式转换和执行。
"""
from __future__ import annotations

import hashlib
import json
import logging
import time
from typing import Any, Dict, List, Optional, Tuple

from app.services.tool_registry import tool_registry
logger = logging.getLogger(__name__)
# 默认确定性工具(结果可缓存)
_DETERMINISTIC_TOOLS = {
"file_read", "math_calculate", "database_query",
"json", "text", "csv", "excel", "pdf", "image",
}
class AgentToolManager:
    """
    Tool manager for the Agent Runtime.

    - Exposes the tool schemas returned by ``tool_registry`` (passed through
      unchanged; presumably already in OpenAI Function Calling format -- TODO
      confirm).
    - Filters tools by the agent's include (whitelist) / exclude (blacklist)
      lists.
    - Executes tool calls and returns the result.
    - Caches results of cacheable tools in Redis, falling back to an
      in-process dict that honours the same TTL when Redis is unavailable.
    """

    # Maximum number of entries kept in the in-process fallback cache. Once
    # full, expired entries are purged first, then the oldest are evicted.
    _MEMORY_CACHE_MAX_ENTRIES = 1024

    def __init__(self, include_tools: Optional[List[str]] = None,
                 exclude_tools: Optional[List[str]] = None,
                 cache_enabled: bool = True,
                 cache_tool_whitelist: Optional[List[str]] = None,
                 cache_ttl_ms: int = 3600000):
        """
        Args:
            include_tools: If non-empty, only these tool names are exposed.
            exclude_tools: Tool names that are never exposed.
            cache_enabled: Master switch for result caching.
            cache_tool_whitelist: If non-empty, only these tools are cached;
                otherwise the module-level ``_DETERMINISTIC_TOOLS`` is used.
            cache_ttl_ms: Cache entry lifetime in milliseconds, truncated to
                whole seconds with a minimum of 1 second.
        """
        self._include_tools: set = set(include_tools or [])
        self._exclude_tools: set = set(exclude_tools or [])
        self._cache_enabled = cache_enabled
        self._cache_whitelist: set = set(cache_tool_whitelist or [])
        self._cache_ttl_s = max(1, int(cache_ttl_ms / 1000))
        # In-process fallback cache: key -> (time.monotonic() expiry, value).
        self._cache_store: Dict[str, Tuple[float, str]] = {}

    # ------------------------------------------------------------------
    # Caching
    # ------------------------------------------------------------------

    def _is_cacheable(self, tool_name: str) -> bool:
        """Return True if results of ``tool_name`` may be cached."""
        if not self._cache_enabled:
            return False
        if self._cache_whitelist:
            return tool_name in self._cache_whitelist
        return tool_name in _DETERMINISTIC_TOOLS

    @staticmethod
    def _cache_key(name: str, args: Dict[str, Any]) -> str:
        """
        Build a stable cache key from the tool name and its arguments.

        Arguments are serialized with sorted keys, so dict ordering does not
        affect the key.

        Raises:
            TypeError / ValueError: if ``args`` is not JSON-serializable.
        """
        raw = json.dumps([name, args], sort_keys=True, ensure_ascii=False)
        return f"tool:{name}:{hashlib.sha256(raw.encode()).hexdigest()[:16]}"

    @staticmethod
    def _redis_client():
        """Return the shared Redis client, or None if it cannot be obtained."""
        try:
            from app.core.redis_client import get_redis_client
        except ImportError:
            # Redis support is not available in this deployment: use memory.
            return None
        try:
            return get_redis_client()
        except Exception:
            logger.debug("获取 Redis 客户端失败, 使用内存缓存", exc_info=True)
            return None

    async def _cache_get(self, key: str) -> Optional[str]:
        """Look up ``key`` in Redis, or in the in-process cache without Redis."""
        redis = self._redis_client()
        if redis:
            try:
                value = await redis.get(key)
            except Exception:
                logger.debug("Redis 读取工具缓存失败, 使用内存缓存: %s", key,
                             exc_info=True)
            else:
                # Clients not configured with decode_responses return bytes;
                # normalise so callers always get str (or None on a miss).
                if isinstance(value, bytes):
                    value = value.decode("utf-8")
                return value
        return self._memory_get(key)

    async def _cache_set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key`` with the configured TTL."""
        redis = self._redis_client()
        if redis:
            try:
                await redis.setex(key, self._cache_ttl_s, value)
                return
            except Exception:
                logger.debug("Redis 写入工具缓存失败, 使用内存缓存: %s", key,
                             exc_info=True)
        self._memory_set(key, value)

    def _memory_get(self, key: str) -> Optional[str]:
        """Return a live in-process entry, dropping it if it has expired."""
        entry = self._cache_store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if expires_at <= time.monotonic():
            del self._cache_store[key]
            return None
        return value

    def _memory_set(self, key: str, value: str) -> None:
        """Insert into the in-process cache, keeping it within its size cap."""
        store = self._cache_store
        # Re-inserting moves the key to the end (newest) of the dict order.
        store.pop(key, None)
        limit = self._MEMORY_CACHE_MAX_ENTRIES
        if len(store) >= limit:
            now = time.monotonic()
            for stale in [k for k, (exp, _) in store.items() if exp <= now]:
                del store[stale]
            # dicts preserve insertion order, so the first key is the oldest.
            while store and len(store) >= limit:
                store.pop(next(iter(store)))
        store[key] = (time.monotonic() + self._cache_ttl_s, value)

    # ------------------------------------------------------------------
    # Tool schemas
    # ------------------------------------------------------------------

    def _is_exposed(self, name: Optional[str]) -> bool:
        """Return True if a tool named ``name`` passes the include/exclude filters."""
        if not name:
            return False
        if self._include_tools and name not in self._include_tools:
            return False
        return name not in self._exclude_tools

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the tool definitions available to this agent, after filtering."""
        all_schemas = tool_registry.get_all_tool_schemas()
        if not self._include_tools and not self._exclude_tools:
            return all_schemas
        return [
            schema for schema in all_schemas
            if self._is_exposed(self._extract_tool_name(schema))
        ]

    def has_tools(self) -> bool:
        """Return True if at least one tool is available."""
        return bool(self.get_tool_schemas())

    def tool_names(self) -> List[str]:
        """Return the names of the available tools ("?" for unnamed schemas)."""
        return [
            self._extract_tool_name(s) or "?"
            for s in self.get_tool_schemas()
        ]

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    async def execute(self, name: str, args: Dict[str, Any]) -> str:
        """
        Execute a tool call, serving and storing cached results when allowed.

        Lookup and execution are delegated to
        ``tool_registry.execute_tool(name, args)``.

        Args:
            name: Tool name.
            args: Tool arguments.

        Returns:
            The tool result. Cached hits are returned without executing the
            tool.
        """
        cache_key: Optional[str] = None
        if self._is_cacheable(name):
            try:
                cache_key = self._cache_key(name, args)
            except (TypeError, ValueError):
                # Arguments cannot be serialized into a key: run uncached.
                logger.debug("Agent 工具参数无法序列化, 跳过缓存: %s", name)

        if cache_key is not None:
            cached = await self._cache_get(cache_key)
            if cached is not None:
                logger.info("Agent 工具命中缓存: %s", name)
                return cached

        logger.info("Agent 执行工具: %s", name)
        result = await tool_registry.execute_tool(name, args)

        if cache_key is not None:
            await self._cache_set(cache_key, result)
            logger.debug("Agent 工具结果已缓存: %s", name)
        return result

    @staticmethod
    def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:
        """
        Return the tool name from a schema.

        Accepts both the nested ``{"function": {"name": ...}}`` shape and a
        flat ``{"name": ...}`` dict. Returns None when no dict is found.
        """
        fn = schema.get("function") or schema
        return fn.get("name") if isinstance(fn, dict) else None