Files
aiagent/backend/app/services/tool_registry.py
renjianbo ab1589921a fix: 修复35个安全与功能缺陷,补全知识进化/数字孪生/行为采集模块
## 安全修复 (12项)
- Webhook接口添加全局Token认证,过滤敏感请求头
- 修复JWT Base64 padding公式,防止签名验证绕过
- 数据库密码/飞书Token从源码移除,改为环境变量
- 工作流引擎添加路径遍历防护 (_resolve_safe_path)
- eval()添加模板长度上限检查
- 审批API添加认证依赖
- 前端v-html增强XSS转义,console.log仅开发模式输出
- 500错误不再暴露内部异常详情

## Agent运行时修复 (7项)
- 删除_inject_knowledge_context中未定义db变量的finally块
- 工具执行添加try/except保护,异常不崩溃Agent
- LLM重试计入budget计数器
- self_review异常时passed=False
- max_iterations截断标记success=False
- 工具参数JSON解析失败时记录警告日志
- run()开始时重置_llm_invocations计数器

## 配置与基础设施
- DEBUG默认False,SQL_ECHO独立配置项
- init_db()补全13个缺失模型导入
- 新增WEBHOOK_AUTH_TOKEN/SQL_ECHO配置项
- 新增.env.example模板文件

## 前端修复 (12项)
- 登录改用URLSearchParams替代FormData
- 401拦截器通过Pinia store统一清理状态
- SSE流超时从60s延长至300s
- final/error事件时清除streamTimeout
- localStorage聊天记录添加24h TTL
- safeParseArgCount替代模板中裸JSON.parse
- fetchUser 401时同时清除user对象

## 新增模块
- 知识进化: knowledge_extractor/retriever/tasks
- 数字孪生: shadow_executor/comparison模型
- 行为采集: behavior_middleware/collector/fingerprint_engine
- 代码审查: code_review_agent/document_review_agent
- 反馈学习: feedback_learner
- 瓶颈检测/优化引擎/成本估算/需求估算
- 速率限制器 (rate_limiter)
- Alembic迁移 015-020

## 文档
- 商业化落地计划
- 8篇docs文档 (架构/API/部署/开发/贡献等)
- Docker Compose生产配置

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-10 19:50:20 +08:00

380 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Tool registry — manages every available tool (built-in / HTTP / code snippet).

Provides a single interface for registering, looking up and executing tools.
"""
from __future__ import annotations
import json
import logging
import traceback
from typing import Any, Callable, Dict, List, Optional
from app.models.tool import Tool
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)

# Globals exposed to code tools executed via ``exec``: the ``json`` module, a
# small whitelist of builtins (each keyed by its own ``__name__``), and the
# three singleton constants. The executing code additionally replaces
# ``__builtins__`` with an empty dict.
#
# SECURITY NOTE(review): a restricted globals dict is not a sandbox in CPython.
# Snippets can still reach arbitrary classes through attribute access
# (e.g. ``().__class__.__base__.__subclasses__()``). Only run code tools
# written by trusted authors, or move execution into an isolated process.
_CODE_SAFE_GLOBALS = {
    "json": json,
    **{
        builtin.__name__: builtin
        for builtin in (
            dict, list, str, int, float, bool,
            len, range, enumerate, zip, map, filter,
            sorted, min, max, sum, abs, round,
            isinstance, type,
        )
    },
    "True": True,
    "False": False,
    "None": None,
}
class ToolRegistry:
    """Registry of every tool an agent can call.

    Supported kinds of tools:

    * built-in tools: Python callables registered with ``register_builtin_tool``;
    * HTTP tools: DB tools with ``implementation_type == "http"``, executed by
      sending an HTTP request described by their ``implementation_config``;
    * code tools: DB tools with ``implementation_type == "code"``, executed by
      ``exec``-ing a Python snippet that defines ``run(args)``;
    * external/plugin tools: schema-only entries for the workflow editor.

    ``execute_tool`` always returns a string (JSON for structured results and
    for errors), so the result can be handed straight back to the LLM.
    """

    def __init__(self):
        # name -> callable for built-in tools.
        self._builtin_tools: Dict[str, Callable] = {}
        # name -> function-calling schema for every known tool.
        self._tool_schemas: Dict[str, Dict[str, Any]] = {}
        # name -> config for non-builtin tools (DB-loaded custom tools and
        # external/plugin schemas). DB configs carry "_type" and "_db_id" keys.
        self._custom_tool_configs: Dict[str, Dict[str, Any]] = {}

    # ─── Built-in tool registration ──────────────────────────────

    def register_builtin_tool(self, name: str, func: Callable, schema: Dict[str, Any]):
        """Register a built-in tool.

        Args:
            name: Tool name.
            func: Tool function (sync or async).
            schema: OpenAI function-calling style schema.
        """
        self._builtin_tools[name] = func
        self._tool_schemas[name] = schema
        logger.debug("注册内置工具: %s", name)

    def register_external_tool(self, name: str, description: str, parameters: dict, category: str = "plugin"):
        """Register an external/plugin tool schema.

        No execution function is attached; the entry exists so the workflow
        editor can display the tool. Executing it through ``execute_tool``
        yields an "unsupported implementation type" error, because the stored
        config has no ``_type``.
        """
        schema = {
            "name": name,
            "description": description,
            "parameters": parameters,
        }
        self._tool_schemas[name] = schema
        self._custom_tool_configs[name] = {"category": category, "description": description}
        logger.debug("注册外部工具 schema: %s (category=%s)", name, category)

    # ─── Tool lookup ─────────────────────────────────────────────

    def get_tool_schema(self, name: str) -> Optional[Dict[str, Any]]:
        """Return the schema registered under *name*, or ``None``."""
        return self._tool_schemas.get(name)

    def get_tool_function(self, name: str) -> Optional[Callable]:
        """Return the built-in callable for *name*, or ``None`` for non-builtin tools."""
        return self._builtin_tools.get(name)

    def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return every registered schema (built-in, DB and external)."""
        return list(self._tool_schemas.values())

    def builtin_tool_count(self) -> int:
        """Return the number of registered built-in tools."""
        return len(self._builtin_tools)

    def builtin_tool_names(self) -> List[str]:
        """Return the names of built-in tools, sorted alphabetically."""
        return sorted(self._builtin_tools.keys())

    def get_tools_by_names(self, tool_names: List[str]) -> List[Dict[str, Any]]:
        """Return schemas for *tool_names* in order, logging and skipping unknown names."""
        tools = []
        for name in tool_names:
            schema = self.get_tool_schema(name)
            if schema:
                tools.append(schema)
            else:
                logger.warning("工具 %s 未找到", name)
        return tools

    # ─── Loading custom tools from the database ──────────────────

    def load_tools_from_db(self, db: Session, tool_names: Optional[List[str]] = None):
        """Load public tool definitions from the database into the registry.

        Args:
            db: Database session.
            tool_names: Only load these names. ``None`` (or an empty list)
                loads every public tool.
        """
        query = db.query(Tool).filter(Tool.is_public == True)  # noqa: E712 (SQL expression)
        if tool_names:
            query = query.filter(Tool.name.in_(tool_names))

        count = 0
        for tool in query.all():
            self._tool_schemas[tool.name] = tool.function_schema
            if tool.implementation_type == "builtin":
                # Built-in tools must be registered in code; the DB only supplies the schema.
                if tool.name not in self._builtin_tools:
                    logger.warning("工具 %s 标记为 builtin 但未注册", tool.name)
                continue
            # Copy before adding bookkeeping keys. Mutating
            # tool.implementation_config in place would leak "_type"/"_db_id"
            # into the ORM object (and possibly into the DB on a later commit).
            config = dict(tool.implementation_config or {})
            config["_type"] = tool.implementation_type
            config["_db_id"] = tool.id
            self._custom_tool_configs[tool.name] = config
            count += 1

        if count:
            logger.info("从数据库加载了 %d 个自定义工具", count)

    # ─── Tool execution ──────────────────────────────────────────

    async def execute_tool(self, name: str, args: Dict[str, Any]) -> str:
        """Execute any tool (built-in / HTTP / code snippet).

        Args:
            name: Tool name.
            args: Argument dict (``None`` is treated as ``{}``).

        Returns:
            The result string. Failures are reported as a JSON ``{"error": ...}``
            string rather than raised.
        """
        args = args or {}

        # 1. Built-in tools.
        func = self._builtin_tools.get(name)
        if func:
            return await self._run_function(func, name, args)

        # 2. Custom tools.
        config = self._custom_tool_configs.get(name)
        if not config:
            return json.dumps({"error": f"工具 '{name}' 不存在"}, ensure_ascii=False)

        impl_type = config.get("_type", "")
        try:
            if impl_type == "http":
                return await self._execute_http_tool(name, config, args)
            if impl_type == "code":
                return await self._execute_code_tool(name, config, args)
            if impl_type == "workflow":
                return json.dumps({"error": "工作流工具暂不支持动态执行"},
                                  ensure_ascii=False)
            return json.dumps({"error": f"不支持的实现类型: {impl_type}"},
                              ensure_ascii=False)
        except Exception as e:
            logger.error("工具 %s 执行失败: %s", name, e, exc_info=True)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
                              ensure_ascii=False)

    @staticmethod
    async def _run_function(func: Callable, name: str, args: Dict[str, Any]) -> str:
        """Call a built-in tool function and stringify its result.

        Arguments the function cannot accept by keyword are dropped (with a
        warning), so mismatched parameter names generated by the LLM do not
        cause a ``TypeError``. A function declaring ``**kwargs`` receives every
        argument. Sync and async functions are both supported: any awaitable
        result is awaited.

        Returns:
            JSON for dict/list results, ``str(result)`` otherwise, or a JSON
            ``{"error": ...}`` string if the call fails.
        """
        import inspect

        args = args or {}
        try:
            params = inspect.signature(func).parameters.values()
            if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                call_args = dict(args)
            else:
                # Only these kinds can be passed as keyword arguments; *args and
                # positional-only names would raise TypeError if passed by name.
                accepted = {
                    p.name
                    for p in params
                    if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                  inspect.Parameter.KEYWORD_ONLY)
                }
                call_args = {k: v for k, v in args.items() if k in accepted}
                skipped = [k for k in args if k not in accepted]
                if skipped:
                    logger.warning("工具 '%s' 忽略未知参数: %s", name, skipped)

            result = func(**call_args)
            # Covers async def functions as well as sync wrappers (e.g.
            # functools.partial) that return a coroutine.
            if inspect.isawaitable(result):
                result = await result

            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            logger.error("工具 '%s' 执行失败: %s", name, e, exc_info=True)
            return json.dumps({"error": f"工具 '{name}' 执行失败: {e}"},
                              ensure_ascii=False)

    # ─── Placeholder templating (shared by HTTP tools) ───────────

    @staticmethod
    def _substitute_placeholders(template: str, args: Dict[str, Any]) -> str:
        """Replace each ``{name}`` in *template* with ``str(args[name])``.

        Placeholders whose name is not in *args* are left as-is, and other
        braces are kept literally, so this never raises. ``str.format`` raised
        on the braces of a JSON document, which meant body templates were
        previously never substituted.
        """
        import re

        return re.sub(
            r"\{(\w+)\}",
            lambda m: str(args[m.group(1)]) if m.group(1) in args else m.group(0),
            template,
        )

    @classmethod
    def _render_body(cls, template: Any, args: Dict[str, Any]) -> Any:
        """Fill ``{name}`` placeholders inside a JSON-like body template.

        Dicts and lists are walked recursively. A string that is exactly one
        placeholder present in *args* is replaced by the raw argument value,
        so numbers, booleans and objects keep their JSON type. Placeholders
        inside longer strings (and in dict keys) are substituted textually.
        Other scalars are returned unchanged.
        """
        import re

        if isinstance(template, str):
            whole = re.fullmatch(r"\{(\w+)\}", template)
            if whole and whole.group(1) in args:
                return args[whole.group(1)]
            return cls._substitute_placeholders(template, args)
        if isinstance(template, dict):
            return {
                (cls._substitute_placeholders(key, args) if isinstance(key, str) else key):
                    cls._render_body(value, args)
                for key, value in template.items()
            }
        if isinstance(template, list):
            return [cls._render_body(item, args) for item in template]
        return template

    # ─── HTTP tools ──────────────────────────────────────────────

    async def _execute_http_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute an HTTP tool.

        ``implementation_config`` format::

            {
                "url": "https://api.example.com/{path_param}",
                "method": "GET",
                "headers": {"Authorization": "Bearer xxx"},
                "body_template": {"key": "{param}"},   # optional
                "timeout": 30
            }

        ``{param}`` placeholders in the URL and body template are filled from
        *args* (see ``_render_body``). The response body is truncated to
        10000 characters.

        NOTE(review): argument values are inserted into the URL without
        percent-encoding, so LLM-supplied values can add path segments or
        query parameters. Consider ``urllib.parse.quote`` if the target API
        allows it.
        """
        import httpx

        url = config.get("url", "")
        if not url:
            return json.dumps({"error": "HTTP 工具未配置 URL"}, ensure_ascii=False)
        method = (config.get("method") or "GET").upper()
        headers = config.get("headers") or {}
        body_template = config.get("body_template")
        timeout = config.get("timeout", 30)

        url = self._substitute_placeholders(url, args)
        body: Any = self._render_body(body_template, args) if body_template else None

        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.request(method, url, headers=headers, json=body)
            result = {
                "status_code": response.status_code,
                "body": response.text[:10000],  # truncate overly long responses
            }
        return json.dumps(result, ensure_ascii=False)

    # ─── Code tools ──────────────────────────────────────────────

    @staticmethod
    def _exec_code_source(source: str) -> Dict[str, Any]:
        """Execute a code tool's *source* in restricted globals and return that namespace.

        Only the names in ``_CODE_SAFE_GLOBALS`` are visible and
        ``__builtins__`` is emptied, so ``import`` statements and builtins such
        as ``print`` are unavailable to the snippet.

        SECURITY NOTE(review): this is NOT a sandbox. Restricting globals does
        not prevent attribute-based escapes such as
        ``().__class__.__base__.__subclasses__()``. The code also runs
        in-process and synchronously with no time limit, so a long or
        non-terminating ``run`` blocks the event loop. Only admit code tools
        from trusted authors, or move execution into an isolated process.

        Raises:
            Whatever exception compiling or executing *source* raises.
        """
        namespace = _CODE_SAFE_GLOBALS.copy()
        namespace["__builtins__"] = {}  # disable all builtins
        exec(source, namespace)
        return namespace

    async def _execute_code_tool(
        self, name: str, config: Dict[str, Any], args: Dict[str, Any]
    ) -> str:
        """Execute a code tool.

        ``implementation_config`` format::

            {
                "source": "def run(args):\\n    return {'sum': args['a'] + args['b']}",
                "language": "python"
            }

        The snippet must define a ``run(args)`` entry point. Dict/list results
        are returned as JSON, anything else as ``str(result)``. See
        ``_exec_code_source`` for the (limited) isolation that is applied.
        """
        source = config.get("source", "")
        if not source:
            return json.dumps({"error": "代码工具未配置 source"}, ensure_ascii=False)

        try:
            namespace = self._exec_code_source(source)
        except Exception as e:
            return json.dumps({
                "error": f"代码编译失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

        run_func = namespace.get("run")
        if not callable(run_func):
            return json.dumps({
                "error": "代码工具必须定义一个 run(args) 函数"
            }, ensure_ascii=False)

        try:
            result = run_func(args)
            # Serialization stays inside the try so non-JSON results are
            # reported as an execution failure.
            if isinstance(result, (dict, list)):
                return json.dumps(result, ensure_ascii=False)
            return str(result)
        except Exception as e:
            return json.dumps({
                "error": f"代码执行失败: {e}",
                "traceback": traceback.format_exc(),
            }, ensure_ascii=False)

    # ─── Dry-run helpers (execute without saving) ────────────────

    async def test_http_tool(
        self, url: str, method: str, headers: Dict[str, str],
        body: Optional[Dict[str, Any]], args: Dict[str, Any],
        timeout: int = 30,
    ) -> Dict[str, Any]:
        """Try an HTTP tool configuration without saving it to the DB.

        Uses the same placeholder rules as ``_execute_http_tool``.

        Returns:
            ``{"success": True, "status_code", "elapsed_ms", "body"}`` on a
            completed request (any status code), or
            ``{"success": False, "error"}`` if the request raised.
        """
        import time

        import httpx

        args = args or {}
        url = self._substitute_placeholders(url, args)
        body_sent = self._render_body(body, args) if body else None

        start = time.perf_counter()
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.request(method.upper(), url,
                                                headers=headers, json=body_sent)
                elapsed_ms = int((time.perf_counter() - start) * 1000)
                return {
                    "success": True,
                    "status_code": response.status_code,
                    "elapsed_ms": elapsed_ms,
                    "body": response.text[:10000],
                }
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def test_code_tool(
        self, source: str, args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Try a code tool without saving it to the DB.

        Returns:
            ``{"success": True, "elapsed_ms", "result"}``, where ``result`` is
            the raw return value of ``run(args)``, or
            ``{"success": False, "error"}`` on compile/definition/execution
            failure.
        """
        import time

        try:
            namespace = self._exec_code_source(source)
        except Exception as e:
            return {"success": False, "error": f"编译失败: {e}"}

        run_func = namespace.get("run")
        if not callable(run_func):
            return {"success": False, "error": "代码须定义 run(args) 函数"}

        start = time.perf_counter()
        try:
            result = run_func(args)
        except Exception as e:
            return {"success": False, "error": f"执行失败: {e}"}
        elapsed_ms = int((time.perf_counter() - start) * 1000)
        return {"success": True, "elapsed_ms": elapsed_ms, "result": result}
# Global singleton: every module importing ``tool_registry`` shares this one registry.
tool_registry = ToolRegistry()