- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
318 lines
10 KiB
Python
318 lines
10 KiB
Python
"""
|
||
工具市场 API — 管理、测试、发现和安装工具。
|
||
|
||
提供内置工具和用户自定义工具(HTTP / 代码段)的 CRUD、测试和执行。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.api.auth import get_current_user
|
||
from app.core.database import get_db
|
||
from app.models.tool import Tool
|
||
from app.models.user import User
|
||
from app.services.tool_registry import tool_registry
|
||
|
||
# Module logger plus the router that every endpoint below attaches to; all
# routes are mounted under /api/v1/tools and grouped as "tools" in OpenAPI.
logger = logging.getLogger(__name__)
router = APIRouter(tags=["tools"], prefix="/api/v1/tools")
|
||
|
||
|
||
# ─── Schema ──────────────────────────────────────────────────────
|
||
|
||
|
||
class ToolCreate(BaseModel):
    # Request body shared by POST (create) and PUT (full replace) of a tool.
    # Field order is kept as-is: it defines the generated OpenAPI schema.

    name: str  # the endpoints reject a name already used by another tool
    description: str
    category: Optional[str] = None
    # Schema published to the tool registry under ``name``; presumably the
    # same OpenAI function format as the built-in schemas — confirm.
    function_schema: dict
    # Expected: builtin / http / code / workflow (checked by the endpoints,
    # not by this model).
    implementation_type: str
    # Type-specific settings, stored verbatim and copied into the registry.
    implementation_config: Optional[dict] = None
    is_public: bool = False
|
||
|
||
|
||
class ToolResponse(BaseModel):
    # Response shape for a single tool; built from the dict that
    # ``_tool_to_dict`` produces, so the keys must stay in sync with it.

    id: str
    name: str
    description: str
    category: Optional[str] = None
    function_schema: dict
    implementation_type: str
    implementation_config: Optional[dict] = None
    is_public: bool = False
    use_count: int = 0
    user_id: Optional[str] = None  # owner id; may be None for unowned rows
    # ISO-8601 timestamps, or "" when the column is unset.
    created_at: str = ""
    updated_at: str = ""
|
||
|
||
|
||
class TestHTTPRequest(BaseModel):
    # Body for POST /test/http: an HTTP tool definition to dry-run without
    # saving it. Every field is forwarded to tool_registry.test_http_tool.

    url: str
    method: str = "GET"
    headers: Dict[str, str] = {}  # pydantic copies this default per instance
    body: Optional[Dict[str, Any]] = None
    args: Dict[str, Any] = {}  # tool-call arguments for the dry run
    # NOTE(review): caller-controlled and unbounded; consider capping it.
    timeout: int = 30
|
||
|
||
|
||
class TestCodeRequest(BaseModel):
    # Body for POST /test/code: a code-tool snippet to dry-run without
    # saving it; forwarded to tool_registry.test_code_tool.

    source: str  # the code snippet itself
    args: Dict[str, Any] = {}  # tool-call arguments for the dry run
|
||
|
||
|
||
class TestResponse(BaseModel):
    # Result of a dry run. The test endpoints build it with
    # ``TestResponse(**result)`` from whatever dict the registry returns,
    # so every field except ``success`` is optional.

    success: bool
    elapsed_ms: Optional[int] = None
    result: Optional[Any] = None
    status_code: Optional[int] = None  # presumably only set for HTTP tests
    body: Optional[str] = None
    error: Optional[str] = None
|
||
|
||
|
||
# ─── 工具函数 ──────────────────────────────────────────────────
|
||
|
||
|
||
def _tool_to_dict(tool: Tool) -> dict:
|
||
return {
|
||
"id": tool.id,
|
||
"name": tool.name,
|
||
"description": tool.description,
|
||
"category": tool.category,
|
||
"function_schema": tool.function_schema,
|
||
"implementation_type": tool.implementation_type,
|
||
"implementation_config": tool.implementation_config,
|
||
"is_public": tool.is_public,
|
||
"use_count": tool.use_count,
|
||
"user_id": tool.user_id,
|
||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
|
||
}
|
||
|
||
|
||
# ─── 工具市场浏览 ──────────────────────────────────────────────
|
||
|
||
|
||
@router.get("", response_model=List[ToolResponse])
async def list_tools(
    category: Optional[str] = Query(None, description="按分类筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    scope: Optional[str] = Query("public", description="public / mine / all"),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Browse the tool marketplace.

    Args:
        category: If given, only tools in exactly this category.
        search: If given, a literal substring matched against name or
            description; LIKE wildcards (``%``, ``_``) in it are escaped.
        scope: ``"public"`` (default): public tools only.
            ``"mine"``: tools owned by the caller (401 when anonymous).
            ``"all"``: public tools plus the caller's own private tools.
            A missing or empty value is treated as ``"public"``.

    Returns:
        Matching tools, most-used first, then newest first.

    Raises:
        HTTPException: 401 for ``scope="mine"`` without a user; 400 for an
            unrecognised ``scope``.

    NOTE(review): public tools are returned with their full
    ``implementation_config``; confirm it never holds secrets (e.g. API keys
    in HTTP headers).
    """
    scope = scope or "public"
    query = db.query(Tool)

    # ``==`` (not ``is``) is required so SQLAlchemy builds a SQL expression.
    is_public = Tool.is_public == True  # noqa: E712

    if scope == "public":
        query = query.filter(is_public)
    elif scope == "mine":
        if not current_user:
            raise HTTPException(status_code=401, detail="需登录")
        query = query.filter(Tool.user_id == current_user.id)
    elif scope == "all":
        # Without a visibility filter, "all" would expose every user's
        # private tools to any caller.
        visible = is_public
        if current_user:
            visible = is_public | (Tool.user_id == current_user.id)
        query = query.filter(visible)
    else:
        # Unknown scopes must not fall through unfiltered either.
        raise HTTPException(status_code=400, detail=f"无效的 scope: {scope}")

    if category:
        query = query.filter(Tool.category == category)
    if search:
        query = query.filter(
            Tool.name.contains(search, autoescape=True)
            | Tool.description.contains(search, autoescape=True)
        )

    tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
    return [_tool_to_dict(t) for t in tools]
|
||
|
||
|
||
@router.get("/categories", response_model=List[str])
async def list_categories(db: Session = Depends(get_db)):
    """List the tool categories offered to clients.

    The distinct, non-empty categories stored on tools come first in sorted
    order. The common default categories that are not already present are
    appended after them, in their fixed order.
    """
    # Always offered, even before any tool uses them.
    defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]

    rows = (
        db.query(Tool.category)
        .filter(Tool.category.isnot(None))
        .distinct()
        .all()
    )
    categories = sorted({row[0] for row in rows if row[0]})

    already_listed = set(categories)
    categories += [name for name in defaults if name not in already_listed]
    return categories
|
||
|
||
|
||
@router.get("/builtin")
async def list_builtin_tools():
    """Return the tool schemas held by the in-process tool registry.

    Schemas are in OpenAI function-calling format.

    NOTE(review): ``create_tool`` also writes custom tools into the registry's
    schema table, so ``get_all_tool_schemas`` may include user-defined tools,
    not only built-ins. This endpoint has no auth dependency — confirm that
    is intended.
    """
    schemas = tool_registry.get_all_tool_schemas()
    return schemas
|
||
|
||
|
||
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
    tool_id: str,
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Return one tool's details.

    A private tool is visible only to its owner. Anyone else gets the same
    404 as for a missing id, so neither the tool's existence nor its
    ``implementation_config`` leaks. Visibility matches ``list_tools``.

    NOTE(review): this uses the same optional-user dependency as
    ``list_tools``. If ``get_current_user`` rejects anonymous requests, this
    endpoint now requires authentication, as ``list_tools`` already does.

    Raises:
        HTTPException: 404 if the tool does not exist or is not visible.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")

    is_owner = current_user is not None and tool.user_id == current_user.id
    if not tool.is_public and not is_owner:
        raise HTTPException(status_code=404, detail="工具不存在")

    return _tool_to_dict(tool)
|
||
|
||
|
||
# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
|
||
|
||
|
||
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a user-defined tool owned by the caller.

    Order of operations:
    1. reject a name that any existing tool already uses;
    2. reject an implementation type other than builtin / http / code /
       workflow;
    3. persist the row;
    4. publish its config and schema to the in-process ``tool_registry``
       under the tool's name.

    Raises:
        HTTPException: 400 for a duplicate name or an invalid type.
    """
    name_clash = db.query(Tool).filter(Tool.name == tool_data.name).first()
    if name_clash:
        raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")

    allowed_types = {"builtin", "http", "code", "workflow"}
    if tool_data.implementation_type not in allowed_types:
        raise HTTPException(
            status_code=400,
            detail=f"无效的实现类型: {tool_data.implementation_type}",
        )

    columns = dict(
        name=tool_data.name,
        description=tool_data.description,
        category=tool_data.category,
        function_schema=tool_data.function_schema,
        implementation_type=tool_data.implementation_type,
        implementation_config=tool_data.implementation_config,
        is_public=tool_data.is_public,
    )
    tool = Tool(**columns, user_id=current_user.id)
    db.add(tool)
    db.commit()
    db.refresh(tool)
    logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)

    # Registry entry = the stored config plus bookkeeping keys. This writes
    # the registry's private tables directly, keyed by tool name.
    registry_entry = dict(tool.implementation_config or {})
    registry_entry["_type"] = tool.implementation_type
    registry_entry["_db_id"] = tool.id
    tool_registry._custom_tool_configs[tool.name] = registry_entry
    tool_registry._tool_schemas[tool.name] = tool.function_schema

    return _tool_to_dict(tool)
|
||
|
||
|
||
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
    tool_id: str,
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Replace every editable field of a tool owned by the caller.

    After committing, any in-process registry entries for the tool are
    refreshed. The registry is keyed by tool name, so a rename moves those
    entries from the old name to the new one.

    Raises:
        HTTPException: 404 if the tool does not exist, 403 if the caller
            does not own it, 400 if the implementation type is invalid or
            the new name is already taken.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    if tool.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权更新此工具")

    # Same check create_tool performs; without it an update could store an
    # arbitrary implementation type.
    valid_types = {"builtin", "http", "code", "workflow"}
    if tool_data.implementation_type not in valid_types:
        raise HTTPException(status_code=400,
                            detail=f"无效的实现类型: {tool_data.implementation_type}")

    if tool_data.name != tool.name:
        existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
        if existing:
            raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")

    old_name = tool.name  # registry key before the (possible) rename

    tool.name = tool_data.name
    tool.description = tool_data.description
    tool.category = tool_data.category
    tool.function_schema = tool_data.function_schema
    tool.implementation_type = tool_data.implementation_type
    tool.implementation_config = tool_data.implementation_config
    tool.is_public = tool_data.is_public

    db.commit()
    db.refresh(tool)

    # Refresh only the registry tables the tool was already present in.
    # Keying by the new name alone would leave a stale entry under the old
    # name and never register the new one.
    had_config = old_name in tool_registry._custom_tool_configs
    had_schema = old_name in tool_registry._tool_schemas
    if old_name != tool.name:
        if had_config:
            tool_registry._custom_tool_configs.pop(old_name, None)
        if had_schema:
            tool_registry._tool_schemas.pop(old_name, None)
    if had_config:
        tool_registry._custom_tool_configs[tool.name] = {
            **(tool.implementation_config or {}),
            "_type": tool.implementation_type,
            "_db_id": tool.id,
        }
    if had_schema:
        tool_registry._tool_schemas[tool.name] = tool.function_schema

    return _tool_to_dict(tool)
|
||
|
||
|
||
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a tool owned by the caller and drop it from the registry.

    NOTE(review): ownership is checked before the builtin guard. A builtin
    tool not owned by the caller therefore gets 403, not 400.

    Raises:
        HTTPException: 404 if the tool does not exist, 403 if the caller
            does not own it, 400 if it is a builtin tool.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    if tool.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此工具")
    if tool.implementation_type == "builtin":
        raise HTTPException(status_code=400, detail="内置工具不允许删除")

    # Read the registry key while the instance is still persistent. After
    # delete + commit it is a deleted/detached object, and attribute access
    # then depends on session expiry settings.
    tool_name = tool.name

    db.delete(tool)
    db.commit()

    tool_registry._custom_tool_configs.pop(tool_name, None)
    tool_registry._tool_schemas.pop(tool_name, None)

    return {"message": "工具已删除"}
|
||
|
||
|
||
# ─── 工具测试 ──────────────────────────────────────────────────
|
||
|
||
|
||
@router.post("/test/http", response_model=TestResponse)
async def test_http_tool(
    req: TestHTTPRequest,
    current_user: User = Depends(get_current_user),
):
    """Dry-run an HTTP tool definition without saving it to the database.

    ``current_user`` is not used in the body; the dependency presumably only
    enforces that the caller is authenticated.

    NOTE(review): the URL and timeout are caller-controlled. Unless
    ``tool_registry.test_http_tool`` restricts destinations, any
    authenticated user can make the server request arbitrary hosts,
    including internal ones (SSRF). Confirm.
    """
    request_kwargs = {
        "url": req.url,
        "method": req.method,
        "headers": req.headers,
        "body": req.body,
        "args": req.args,
        "timeout": req.timeout,
    }
    outcome = await tool_registry.test_http_tool(**request_kwargs)
    return TestResponse(**outcome)
|
||
|
||
|
||
@router.post("/test/code", response_model=TestResponse)
async def test_code_tool(
    req: TestCodeRequest,
    current_user: User = Depends(get_current_user),
):
    """Dry-run a code-tool snippet without saving it to the database.

    ``current_user`` is not used in the body; the dependency presumably only
    enforces that the caller is authenticated.

    NOTE(review): this hands caller-supplied source to
    ``tool_registry.test_code_tool``. Confirm it runs in a real sandbox,
    otherwise any authenticated user can execute code on the server.
    """
    outcome = await tool_registry.test_code_tool(source=req.source, args=req.args)
    return TestResponse(**outcome)
|
||
|
||
|
||
# ─── 使用计数 ──────────────────────────────────────────────────
|
||
|
||
|
||
@router.post("/{tool_id}/use")
async def record_tool_use(
    tool_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Increment a tool's use counter (called automatically during Agent runs).

    The increment happens in SQL as ``use_count = COALESCE(use_count, 0) + 1``
    rather than as a Python read-modify-write. Concurrent calls therefore
    cannot overwrite each other's increments, and a NULL count still counts
    from zero.

    Returns:
        ``{"use_count": <value after the increment>}``

    Raises:
        HTTPException: 404 if no tool has ``tool_id``.
    """
    from sqlalchemy import func  # only needed for the atomic increment

    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")

    db.query(Tool).filter(Tool.id == tool_id).update(
        {Tool.use_count: func.coalesce(Tool.use_count, 0) + 1},
        synchronize_session=False,
    )
    db.commit()
    db.refresh(tool)  # reload the value the database actually holds
    return {"use_count": tool.use_count}
|