feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖

- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
renjianbo
2026-05-01 22:30:46 +08:00
parent 036f533881
commit 7b9e0826de
35 changed files with 4353 additions and 365 deletions

View File

@@ -1,21 +1,42 @@
"""
工具管理API
工具市场 API — 管理、测试、发现和安装工具。
提供内置工具和用户自定义工具HTTP / 代码段)的 CRUD、测试和执行。
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import List, Optional
from app.api.auth import get_current_user
from app.core.database import get_db
from app.models.tool import Tool
from app.services.tool_registry import tool_registry
from app.api.auth import get_current_user
from app.models.user import User
from pydantic import BaseModel
from app.services.tool_registry import tool_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/tools", tags=["tools"])
# ─── Schema ──────────────────────────────────────────────────────
class ToolCreate(BaseModel):
"""创建工具请求"""
name: str
description: str
category: Optional[str] = None
function_schema: dict
implementation_type: str # builtin / http / code / workflow
implementation_config: Optional[dict] = None
is_public: bool = False
class ToolResponse(BaseModel):
id: str
name: str
description: str
category: Optional[str] = None
@@ -23,86 +44,39 @@ class ToolCreate(BaseModel):
implementation_type: str
implementation_config: Optional[dict] = None
is_public: bool = False
use_count: int = 0
user_id: Optional[str] = None
created_at: str = ""
updated_at: str = ""
class ToolResponse(BaseModel):
"""工具响应"""
id: str
name: str
description: str
category: Optional[str]
function_schema: dict
implementation_type: str
implementation_config: Optional[dict]
is_public: bool
use_count: int
user_id: Optional[str]
created_at: str
updated_at: str
class Config:
from_attributes = True
class TestHTTPRequest(BaseModel):
url: str
method: str = "GET"
headers: Dict[str, str] = {}
body: Optional[Dict[str, Any]] = None
args: Dict[str, Any] = {}
timeout: int = 30
@router.get("", response_model=List[ToolResponse])
async def list_tools(
category: Optional[str] = Query(None, description="工具分类"),
search: Optional[str] = Query(None, description="搜索关键词"),
db: Session = Depends(get_db)
):
"""获取工具列表"""
query = db.query(Tool).filter(Tool.is_public == True)
if category:
query = query.filter(Tool.category == category)
if search:
query = query.filter(
Tool.name.contains(search) |
Tool.description.contains(search)
)
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
# 转换为响应格式,确保日期时间字段转换为字符串
result = []
for tool in tools:
result.append({
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
})
return result
class TestCodeRequest(BaseModel):
source: str
args: Dict[str, Any] = {}
@router.get("/builtin")
async def list_builtin_tools():
"""获取内置工具列表"""
schemas = tool_registry.get_all_tool_schemas()
return schemas
class TestResponse(BaseModel):
success: bool
elapsed_ms: Optional[int] = None
result: Optional[Any] = None
status_code: Optional[int] = None
body: Optional[str] = None
error: Optional[str] = None
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
tool_id: str,
db: Session = Depends(get_db)
):
"""获取工具详情"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 转换为响应格式,确保日期时间字段转换为字符串
# ─── 工具函数 ──────────────────────────────────────────────────
def _tool_to_dict(tool: Tool) -> dict:
return {
"id": tool.id,
"name": tool.name,
@@ -115,22 +89,89 @@ async def get_tool(
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
}
# ─── 工具市场浏览 ──────────────────────────────────────────────
@router.get("", response_model=List[ToolResponse])
async def list_tools(
category: Optional[str] = Query(None, description="按分类筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
scope: Optional[str] = Query("public", description="public / mine / all"),
db: Session = Depends(get_db),
current_user: Optional[User] = Depends(get_current_user),
):
"""浏览工具市场。"""
query = db.query(Tool)
if scope == "public":
query = query.filter(Tool.is_public == True)
elif scope == "mine":
if not current_user:
raise HTTPException(status_code=401, detail="需登录")
query = query.filter(Tool.user_id == current_user.id)
if category:
query = query.filter(Tool.category == category)
if search:
query = query.filter(
Tool.name.contains(search) | Tool.description.contains(search)
)
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
return [_tool_to_dict(t) for t in tools]
@router.get("/categories", response_model=List[str])
async def list_categories(db: Session = Depends(get_db)):
"""列出所有工具分类。"""
rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
cats = sorted(set(r[0] for r in rows if r[0]))
# 加上常用分类
defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]
for d in defaults:
if d not in cats:
cats.append(d)
return cats
@router.get("/builtin")
async def list_builtin_tools():
"""列出所有内置工具OpenAI Function 格式)。"""
return tool_registry.get_all_tool_schemas()
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(tool_id: str, db: Session = Depends(get_db)):
"""获取工具详情。"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
return _tool_to_dict(tool)
# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
tool_data: ToolCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""创建工具"""
# 检查工具名称是否已存在
"""创建自定义工具"""
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
valid_types = {"builtin", "http", "code", "workflow"}
if tool_data.implementation_type not in valid_types:
raise HTTPException(status_code=400,
detail=f"无效的实现类型: {tool_data.implementation_type}")
tool = Tool(
name=tool_data.name,
description=tool_data.description,
@@ -139,28 +180,22 @@ async def create_tool(
implementation_type=tool_data.implementation_type,
implementation_config=tool_data.implementation_config,
is_public=tool_data.is_public,
user_id=current_user.id
user_id=current_user.id,
)
db.add(tool)
db.commit()
db.refresh(tool)
# 转换为响应格式,确保日期时间字段转换为字符串
return {
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
# 刷新注册表
tool_registry._custom_tool_configs[tool.name] = {
**(tool.implementation_config or {}),
"_type": tool.implementation_type,
"_db_id": tool.id,
}
tool_registry._tool_schemas[tool.name] = tool.function_schema
return _tool_to_dict(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
@@ -168,23 +203,20 @@ async def update_tool(
tool_id: str,
tool_data: ToolCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""更新工具"""
"""更新工具"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 检查权限(只有创建者可以更新)
if tool.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权更新此工具")
# 检查名称冲突
if tool_data.name != tool.name:
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
tool.name = tool_data.name
tool.description = tool_data.description
tool.category = tool_data.category
@@ -192,47 +224,94 @@ async def update_tool(
tool.implementation_type = tool_data.implementation_type
tool.implementation_config = tool_data.implementation_config
tool.is_public = tool_data.is_public
db.commit()
db.refresh(tool)
# 转换为响应格式,确保日期时间字段转换为字符串
return {
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
}
# 刷新注册表
if tool.name in tool_registry._custom_tool_configs:
tool_registry._custom_tool_configs[tool.name] = {
**(tool.implementation_config or {}),
"_type": tool.implementation_type,
"_db_id": tool.id,
}
if tool.name in tool_registry._tool_schemas:
tool_registry._tool_schemas[tool.name] = tool.function_schema
return _tool_to_dict(tool)
@router.delete("/{tool_id}", status_code=200)
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""删除工具"""
"""删除工具"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 检查权限(只有创建者可以删除)
if tool.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权删除此工具")
# 内置工具不允许删除
if tool.implementation_type == "builtin":
raise HTTPException(status_code=400, detail="内置工具不允许删除")
db.delete(tool)
db.commit()
# 清理注册表
tool_registry._custom_tool_configs.pop(tool.name, None)
tool_registry._tool_schemas.pop(tool.name, None)
return {"message": "工具已删除"}
# ─── 工具测试 ──────────────────────────────────────────────────
@router.post("/test/http", response_model=TestResponse)
async def test_http_tool(
req: TestHTTPRequest,
current_user: User = Depends(get_current_user),
):
"""测试 HTTP 工具(不保存到数据库)。"""
result = await tool_registry.test_http_tool(
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
args=req.args,
timeout=req.timeout,
)
return TestResponse(**result)
@router.post("/test/code", response_model=TestResponse)
async def test_code_tool(
req: TestCodeRequest,
current_user: User = Depends(get_current_user),
):
"""测试代码工具(不保存到数据库)。"""
result = await tool_registry.test_code_tool(
source=req.source,
args=req.args,
)
return TestResponse(**result)
# ─── 使用计数 ──────────────────────────────────────────────────
@router.post("/{tool_id}/use")
async def record_tool_use(
tool_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""记录工具使用次数Agent 执行时自动调用)。"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
tool.use_count = (tool.use_count or 0) + 1
db.commit()
return {"use_count": tool.use_count}