Files
aiagent/backend/app/api/tools.py
renjianbo 7b9e0826de feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 22:30:46 +08:00

318 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Tool marketplace API — manage, test, discover and install tools.
Provides CRUD, testing and execution for built-in tools and user-defined
tools (HTTP / code snippet).
"""
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session

from app.api.auth import get_current_user
from app.core.database import get_db
from app.models.tool import Tool
from app.models.user import User
from app.services.tool_registry import tool_registry
# Module logger, plus the router that all tool-marketplace endpoints in this
# module are registered on (mounted under /api/v1/tools, tagged "tools").
logger = logging.getLogger(name=__name__)
router = APIRouter(tags=["tools"], prefix="/api/v1/tools")
# ─── Schema ──────────────────────────────────────────────────────
class ToolCreate(BaseModel):
    # Request body for creating a tool (POST) and for fully replacing one (PUT).
    # Comments are used instead of a docstring so the OpenAPI schema is unchanged.
    name: str  # must not collide with another tool's name (checked by the create/update endpoints)
    description: str
    category: Optional[str] = None
    function_schema: dict  # function-calling schema; stored on the row and in tool_registry._tool_schemas
    implementation_type: str # builtin / http / code / workflow
    implementation_config: Optional[dict] = None  # type-specific settings, merged into the registry entry
    is_public: bool = False  # public tools are what list_tools returns for scope="public"
class ToolResponse(BaseModel):
    # Response shape for a single tool; endpoints build it from _tool_to_dict().
    id: str
    name: str
    description: str
    category: Optional[str] = None
    function_schema: dict
    implementation_type: str
    implementation_config: Optional[dict] = None
    is_public: bool = False
    use_count: int = 0  # incremented by POST /{tool_id}/use
    user_id: Optional[str] = None  # owner; set to the creating user by create_tool
    created_at: str = ""  # ISO-8601 string, "" when the column is unset
    updated_at: str = ""  # ISO-8601 string, "" when the column is unset
class TestHTTPRequest(BaseModel):
    # Ad-hoc HTTP tool definition for POST /test/http. Every field is forwarded
    # unchanged to tool_registry.test_http_tool; nothing is persisted.
    url: str
    method: str = "GET"
    headers: Dict[str, str] = {}
    body: Optional[Dict[str, Any]] = None
    args: Dict[str, Any] = {}  # tool-call arguments; how they are applied is up to the registry
    timeout: int = 30  # presumably seconds — TODO confirm against tool_registry.test_http_tool
class TestCodeRequest(BaseModel):
    # Ad-hoc code tool for POST /test/code; forwarded to tool_registry.test_code_tool.
    source: str  # tool source code to dry-run
    args: Dict[str, Any] = {}  # arguments passed to the code under test
class TestResponse(BaseModel):
    # Result of a tool dry-run. The test endpoints build it by unpacking the dict
    # returned by tool_registry.test_*_tool, so its keys must match these fields.
    success: bool
    elapsed_ms: Optional[int] = None
    result: Optional[Any] = None  # presumably filled by code tests — confirm in tool_registry
    status_code: Optional[int] = None  # presumably filled by HTTP tests only
    body: Optional[str] = None  # presumably the raw HTTP response body
    error: Optional[str] = None
# ─── 工具函数 ──────────────────────────────────────────────────
def _tool_to_dict(tool: Tool) -> dict:
    """Serialize a ``Tool`` row into a plain dict matching ``ToolResponse``.

    Timestamps become ISO-8601 strings, or ``""`` when unset.

    ``use_count`` falls back to 0 when the column is NULL. ``ToolResponse``
    declares it as ``int``, so passing ``None`` through would fail response
    validation and turn a read into a server error.
    """
    return {
        "id": tool.id,
        "name": tool.name,
        "description": tool.description,
        "category": tool.category,
        "function_schema": tool.function_schema,
        "implementation_type": tool.implementation_type,
        "implementation_config": tool.implementation_config,
        "is_public": tool.is_public,
        "use_count": tool.use_count or 0,
        "user_id": tool.user_id,
        "created_at": tool.created_at.isoformat() if tool.created_at else "",
        "updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
    }
# ─── 工具市场浏览 ──────────────────────────────────────────────
@router.get("", response_model=List[ToolResponse])
async def list_tools(
    category: Optional[str] = Query(None, description="按分类筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    scope: Optional[str] = Query("public", description="public / mine / all"),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Browse the tool marketplace.

    Scopes:
        * ``public`` (default, also used for an empty value) — public tools only.
        * ``mine`` — tools owned by the current user; login required.
        * ``all`` — public tools plus the current user's own tools. Other
          users' private tools are never returned. Previously "all" and any
          unknown scope applied no filter and exposed every private tool.

    ``category`` is an exact match. ``search`` is a substring match on name or
    description; LIKE wildcards (``%``, ``_``) in it are escaped, so they match
    literally. Results are ordered by use count, then newest first.

    Raises:
        HTTPException 401: ``scope=mine`` without a logged-in user.
        HTTPException 400: unknown ``scope`` value.
    """
    scope = scope or "public"
    query = db.query(Tool)
    # SQL expression, not a Python bool comparison.
    public_only = Tool.is_public == True  # noqa: E712
    if scope == "public":
        query = query.filter(public_only)
    elif scope == "mine":
        if not current_user:
            raise HTTPException(status_code=401, detail="需登录")
        query = query.filter(Tool.user_id == current_user.id)
    elif scope == "all":
        if current_user:
            # Parentheses matter: `|` binds tighter than `==`.
            query = query.filter(public_only | (Tool.user_id == current_user.id))
        else:
            query = query.filter(public_only)
    else:
        raise HTTPException(status_code=400, detail=f"无效的 scope: {scope}")
    if category:
        query = query.filter(Tool.category == category)
    if search:
        query = query.filter(
            Tool.name.contains(search, autoescape=True)
            | Tool.description.contains(search, autoescape=True)
        )
    tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
    return [_tool_to_dict(t) for t in tools]
@router.get("/categories", response_model=List[str])
async def list_categories(db: Session = Depends(get_db)):
    """Return the tool categories to offer in the UI.

    Distinct non-empty categories found in the database come first, sorted.
    The fixed default categories that are not already present follow, in
    their fixed order.
    """
    rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
    known = sorted({row[0] for row in rows if row[0]})
    # Common categories offered even before any tool uses them.
    fallback = ("数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义")
    return known + [name for name in fallback if name not in known]
@router.get("/builtin")
async def list_builtin_tools():
    """Return the tool registry's schemas (OpenAI function format).

    NOTE(review): create_tool/update_tool also write custom tools into
    ``tool_registry._tool_schemas``. If get_all_tool_schemas() reads that map,
    custom tools appear here too despite the route name — confirm.
    """
    schemas = tool_registry.get_all_tool_schemas()
    return schemas
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(tool_id: str, db: Session = Depends(get_db)):
    """Return the details of one tool, looked up by id.

    Raises:
        HTTPException 404: no tool has ``tool_id``.
    """
    # NOTE(review): there is no auth or ownership check here. Private tools,
    # including their implementation_config, are readable by anyone who
    # knows the id.
    found = db.query(Tool).filter(Tool.id == tool_id).first()
    if found is None:
        raise HTTPException(status_code=404, detail="工具不存在")
    return _tool_to_dict(found)
# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a custom tool owned by the current user and register it.

    The new row is committed first. Its config and function schema are then
    written into the in-memory ``tool_registry``, keyed by tool name.

    Raises:
        HTTPException 400: the name is already taken, or
            ``implementation_type`` is not a supported type.
    """
    name_taken = db.query(Tool).filter(Tool.name == tool_data.name).first() is not None
    if name_taken:
        raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")

    impl_type = tool_data.implementation_type
    if impl_type not in ("builtin", "http", "code", "workflow"):
        raise HTTPException(status_code=400,
                            detail=f"无效的实现类型: {impl_type}")

    tool = Tool(
        user_id=current_user.id,
        name=tool_data.name,
        description=tool_data.description,
        category=tool_data.category,
        is_public=tool_data.is_public,
        function_schema=tool_data.function_schema,
        implementation_type=impl_type,
        implementation_config=tool_data.implementation_config,
    )
    db.add(tool)
    db.commit()
    db.refresh(tool)
    logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)

    # Make the tool visible to the registry. The copied config gets two
    # bookkeeping keys: its implementation type and its DB id.
    registry_entry = dict(tool.implementation_config or {})
    registry_entry["_type"] = tool.implementation_type
    registry_entry["_db_id"] = tool.id
    tool_registry._custom_tool_configs[tool.name] = registry_entry
    tool_registry._tool_schemas[tool.name] = tool.function_schema
    return _tool_to_dict(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
    tool_id: str,
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Replace every editable field of a tool owned by the current user.

    After the commit, the tool's entries in the in-memory ``tool_registry`` are
    refreshed. Those entries are keyed by tool name, so on a rename they are
    moved from the old name to the new one. Only entries that already existed
    are refreshed; none are created here.

    Raises:
        HTTPException 404: no tool has ``tool_id``.
        HTTPException 403: the tool belongs to another user.
        HTTPException 400: the new name is used by another tool, or
            ``implementation_type`` is not a supported type.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    if tool.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权更新此工具")
    if tool_data.name != tool.name:
        existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
        if existing:
            raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
    # Same whitelist create_tool enforces. Without it, an update could store an
    # implementation type that creation would have rejected.
    valid_types = {"builtin", "http", "code", "workflow"}
    if tool_data.implementation_type not in valid_types:
        raise HTTPException(status_code=400,
                            detail=f"无效的实现类型: {tool_data.implementation_type}")

    # Registry entries are keyed by the pre-update name.
    old_name = tool.name

    tool.name = tool_data.name
    tool.description = tool_data.description
    tool.category = tool_data.category
    tool.function_schema = tool_data.function_schema
    tool.implementation_type = tool_data.implementation_type
    tool.implementation_config = tool_data.implementation_config
    tool.is_public = tool_data.is_public
    db.commit()
    db.refresh(tool)

    # Refresh the registry. On a rename, drop the stale key so the old name
    # no longer resolves to this tool.
    configs = tool_registry._custom_tool_configs
    schemas = tool_registry._tool_schemas
    renamed = old_name != tool.name
    if old_name in configs:
        if renamed:
            del configs[old_name]
        configs[tool.name] = {
            **(tool.implementation_config or {}),
            "_type": tool.implementation_type,
            "_db_id": tool.id,
        }
    if old_name in schemas:
        if renamed:
            del schemas[old_name]
        schemas[tool.name] = tool.function_schema
    return _tool_to_dict(tool)
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a tool owned by the current user and unregister it.

    Raises:
        HTTPException 404: no tool has ``tool_id``.
        HTTPException 403: the tool belongs to another user.
        HTTPException 400: the tool's implementation type is ``builtin``.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if tool is None:
        raise HTTPException(status_code=404, detail="工具不存在")
    if tool.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此工具")
    if tool.implementation_type == "builtin":
        raise HTTPException(status_code=400, detail="内置工具不允许删除")

    # Read the registry key while the row is still attached to the session.
    tool_name = tool.name
    db.delete(tool)
    db.commit()

    for registry_map in (tool_registry._custom_tool_configs, tool_registry._tool_schemas):
        registry_map.pop(tool_name, None)
    return {"message": "工具已删除"}
# ─── 工具测试 ──────────────────────────────────────────────────
@router.post("/test/http", response_model=TestResponse)
async def test_http_tool(
    req: TestHTTPRequest,
    current_user: User = Depends(get_current_user),
):
    """Dry-run an HTTP tool definition; nothing is written to the database.

    ``current_user`` is never read in the body. The dependency exists only to
    require authentication, assuming get_current_user rejects anonymous
    requests — confirm.

    NOTE(review): the server fetches a caller-chosen URL with a caller-chosen
    timeout. Unless tool_registry.test_http_tool restricts targets and caps the
    timeout, this is an SSRF vector — verify.
    """
    outcome = await tool_registry.test_http_tool(
        url=req.url,
        method=req.method,
        headers=req.headers,
        body=req.body,
        args=req.args,
        timeout=req.timeout,
    )
    # The registry returns a dict whose keys match TestResponse's fields.
    return TestResponse(**outcome)
@router.post("/test/code", response_model=TestResponse)
async def test_code_tool(
    req: TestCodeRequest,
    current_user: User = Depends(get_current_user),
):
    """Dry-run a code tool's source with the given args; nothing is persisted.

    ``current_user`` is never read in the body. The dependency exists only to
    require authentication, assuming get_current_user rejects anonymous
    requests — confirm.

    NOTE(review): this executes caller-supplied source on the server. Whether
    it is sandboxed depends entirely on tool_registry.test_code_tool — verify.
    """
    payload = await tool_registry.test_code_tool(source=req.source, args=req.args)
    return TestResponse(**payload)
# ─── 使用计数 ──────────────────────────────────────────────────
@router.post("/{tool_id}/use")
async def record_tool_use(
    tool_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Increment a tool's usage counter (called automatically when an agent runs a tool).

    The increment runs in SQL as ``use_count = COALESCE(use_count, 0) + 1``
    instead of a Python read-modify-write. Concurrent calls therefore cannot
    overwrite each other's increments, and a NULL counter still starts from 0.

    Returns:
        ``{"use_count": <value after the increment>}``.

    Raises:
        HTTPException 404: no tool has ``tool_id``.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    # Assigning a SQL expression makes the flush emit an atomic UPDATE.
    tool.use_count = func.coalesce(Tool.use_count, 0) + 1
    db.commit()
    # Reload so the response carries the integer the database computed.
    db.refresh(tool)
    return {"use_count": tool.use_count}