Files
aiagent/backend/app/agent_runtime/tool_manager.py

143 lines
4.9 KiB
Python
Raw Normal View History

"""
Agent 工具管理器包装已有 ToolRegistry提供 Agent 需要的工具格式转换和执行
"""
from __future__ import annotations

import hashlib
import json
import logging
import time
from typing import Any, Dict, List, Optional

from app.services.tool_registry import tool_registry
logger = logging.getLogger(__name__)

# Tools treated as deterministic by default, i.e. whose results may be cached.
# Only consulted when no explicit cache whitelist is configured.
# NOTE(review): file_read / database_query can return different results when
# the underlying data changes; caching them relies on the TTL — confirm intended.
_DETERMINISTIC_TOOLS = {
    # data access
    "file_read",
    "database_query",
    # computation
    "math_calculate",
    # format-specific tools
    "csv",
    "excel",
    "image",
    "json",
    "pdf",
    "text",
}
class AgentToolManager:
    """
    Tool management for the Agent runtime.

    - Exposes ToolRegistry tool schemas for function calling. Schemas are
      returned exactly as the registry provides them (no conversion happens
      here); tool names are read from either an OpenAI-style
      ``{"function": {"name": ...}}`` wrapper or a flat ``{"name": ...}`` dict.
    - Filters the exposed tools by an include list (whitelist) and an exclude
      list (blacklist).
    - Executes tool calls through the registry and returns the result.
    - Caches results of cacheable tools in Redis when a client is available,
      otherwise in a per-instance memory cache that honours the same TTL and
      is bounded in size.
    """

    # Maximum number of entries kept in the per-instance memory fallback cache.
    # When full, expired entries are dropped first, then the oldest writes.
    _MEMORY_CACHE_MAX_ENTRIES = 1024

    def __init__(self, include_tools: Optional[List[str]] = None,
                 exclude_tools: Optional[List[str]] = None,
                 cache_enabled: bool = True,
                 cache_tool_whitelist: Optional[List[str]] = None,
                 cache_ttl_ms: int = 3600000):
        """
        Args:
            include_tools: If non-empty, only these tool names are exposed by
                ``get_tool_schemas``.
            exclude_tools: Tool names never exposed by ``get_tool_schemas``.
            cache_enabled: Master switch for result caching.
            cache_tool_whitelist: If non-empty, only these tools are cached;
                otherwise the module-level ``_DETERMINISTIC_TOOLS`` set is used.
            cache_ttl_ms: Cache TTL in milliseconds, truncated to whole
                seconds with a minimum of 1 second.
        """
        self._include_tools: set = set(include_tools or [])
        self._exclude_tools: set = set(exclude_tools or [])
        self._cache_enabled = cache_enabled
        self._cache_whitelist: set = set(cache_tool_whitelist or [])
        self._cache_ttl_s = max(1, int(cache_ttl_ms / 1000))
        # Memory fallback: key -> (expiry on the time.monotonic() clock, value).
        self._cache_store: Dict[str, tuple[float, str]] = {}

    def _is_cacheable(self, tool_name: str) -> bool:
        """
        Return True if results of ``tool_name`` may be cached.

        Caching must be enabled; an explicit cache whitelist takes precedence
        over the module-level ``_DETERMINISTIC_TOOLS`` default.
        """
        if not self._cache_enabled:
            return False
        if self._cache_whitelist:
            return tool_name in self._cache_whitelist
        return tool_name in _DETERMINISTIC_TOOLS

    @staticmethod
    def _cache_key(name: str, args: Dict[str, Any]) -> Optional[str]:
        """
        Build a deterministic cache key for a tool call.

        Arguments are serialized with sorted keys, so dicts that differ only in
        key order share a key. Returns ``None`` when the call cannot be
        serialized (non-JSON values, mixed-type dict keys, circular references,
        lone surrogates); such calls are executed but not cached.
        """
        try:
            raw = json.dumps([name, args], sort_keys=True, ensure_ascii=False)
            # UnicodeEncodeError (a ValueError) is possible for lone surrogates.
            digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]
        except (TypeError, ValueError):
            return None
        return f"tool:{name}:{digest}"

    @staticmethod
    def _get_redis():
        """
        Return the client from ``app.core.redis_client``, or None.

        Redis is optional: any failure importing or creating the client is
        treated as "no Redis" so callers fall back to the memory cache.
        """
        try:
            from app.core.redis_client import get_redis_client
            return get_redis_client()
        except Exception:
            return None

    async def _cache_get(self, key: str) -> Optional[str]:
        """Look up ``key`` in Redis when available, otherwise in memory."""
        redis = self._get_redis()
        if redis:
            try:
                value = await redis.get(key)
                # A client not configured to decode responses may return bytes;
                # callers of execute() expect str.
                if isinstance(value, bytes):
                    value = value.decode("utf-8")
                return value
            except Exception:
                logger.debug("Agent 工具缓存读取失败(Redis)，回退内存缓存: %s",
                             key, exc_info=True)
        return self._memory_get(key)

    async def _cache_set(self, key: str, value: str):
        """Store ``value`` under ``key`` with the configured TTL."""
        redis = self._get_redis()
        if redis:
            try:
                await redis.setex(key, self._cache_ttl_s, value)
                return
            except Exception:
                logger.debug("Agent 工具缓存写入失败(Redis)，回退内存缓存: %s",
                             key, exc_info=True)
        self._memory_set(key, value)

    def _memory_get(self, key: str) -> Optional[str]:
        """Return the unexpired memory-cached value for ``key``, or None."""
        entry = self._cache_store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if expires_at <= time.monotonic():
            # Lazily evict the stale entry so it is neither served nor retained.
            del self._cache_store[key]
            return None
        return value

    def _memory_set(self, key: str, value: str) -> None:
        """Store ``value`` in the memory cache, expiring after the TTL."""
        now = time.monotonic()
        store = self._cache_store
        # Re-inserting moves an existing key to the newest end of the dict.
        store.pop(key, None)
        if len(store) >= self._MEMORY_CACHE_MAX_ENTRIES:
            self._purge_memory_cache(now)
        store[key] = (now + self._cache_ttl_s, value)

    def _purge_memory_cache(self, now: float) -> None:
        """Drop expired entries, then the oldest until one more fits."""
        store = self._cache_store
        for stale in [k for k, (exp, _) in store.items() if exp <= now]:
            del store[stale]
        # Dicts preserve insertion order, so the first key is the oldest write.
        while store and len(store) >= self._MEMORY_CACHE_MAX_ENTRIES:
            del store[next(iter(store))]

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """
        Return the tool schemas available to the agent.

        With no include/exclude configuration the registry's list is returned
        as-is. Otherwise schemas without a resolvable name are dropped, and
        the include list (if non-empty) and exclude list are applied.
        """
        all_schemas = tool_registry.get_all_tool_schemas()
        if not self._include_tools and not self._exclude_tools:
            return all_schemas
        filtered = []
        for schema in all_schemas:
            name = self._extract_tool_name(schema)
            if not name:
                continue
            if self._include_tools and name not in self._include_tools:
                continue
            if name in self._exclude_tools:
                continue
            filtered.append(schema)
        return filtered

    def has_tools(self) -> bool:
        """Return True if at least one tool schema is available."""
        return len(self.get_tool_schemas()) > 0

    def tool_names(self) -> List[str]:
        """Return available tool names ("?" for schemas without a name)."""
        return [
            self._extract_tool_name(s) or "?"
            for s in self.get_tool_schemas()
        ]

    async def execute(self, name: str, args: Dict[str, Any]) -> str:
        """
        Execute a tool call, serving and storing results through the cache.

        Tool lookup is delegated to ``tool_registry.execute_tool`` (per the
        original docstring: built-in tools first, then custom HTTP / Code
        tools from the database — not verifiable from this file).

        Args:
            name: Tool name.
            args: Tool argument dict.

        Returns:
            The tool execution result as returned by the registry.
        """
        # NOTE(review): include/exclude filtering only affects the schemas
        # offered to the model; a call for a tool outside that set is still
        # executed here. Confirm whether execute() should reject such calls.
        # NOTE(review): results are cached regardless of content; if the
        # registry reports failures as strings, those are cached too.

        # Compute the key once, before execution, so a tool that mutates
        # ``args`` cannot make the write key differ from the lookup key.
        ck = self._cache_key(name, args) if self._is_cacheable(name) else None
        if ck is not None:
            cached = await self._cache_get(ck)
            if cached is not None:
                logger.info("Agent 工具命中缓存: %s", name)
                return cached

        logger.info("Agent 执行工具: %s", name)
        result = await tool_registry.execute_tool(name, args)

        if ck is not None:
            await self._cache_set(ck, result)
            logger.debug("Agent 工具结果已缓存: %s", name)
        return result

    @staticmethod
    def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:
        """
        Return the tool name from a schema, or None if it has none.

        Reads ``schema["function"]["name"]`` when a truthy ``"function"``
        entry exists, otherwise ``schema["name"]``.
        """
        fn = schema.get("function") or schema
        return fn.get("name") if isinstance(fn, dict) else None