Files
aiagent/backend/app/api/tools.py
renjianbo 5b5eb84dfb fix: #33 内置多模态工具现在在工具市场 /api/v1/tools 中可见
list_tools 端点合并内置工具(image_ocr/image_vision/speech_to_text/text_to_speech 等),
按 scope=public/all 时自动包含,无需额外种子到 DB。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-06 22:13:41 +08:00

375 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工具市场 API — 管理、测试、发现和安装工具。
提供内置工具和用户自定义工具(HTTP / 代码段)的 CRUD、测试和执行。
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.api.auth import get_current_user
from app.core.database import get_db
from app.models.tool import Tool
from app.models.user import User
from app.services.tool_registry import tool_registry
logger = logging.getLogger(name=__name__)
# Every endpoint in this module is mounted under /api/v1/tools.
router = APIRouter(tags=["tools"], prefix="/api/v1/tools")
# ─── Schema ──────────────────────────────────────────────────────
# Request body for POST /api/v1/tools (create) and PUT /api/v1/tools/{tool_id}
# (update); every field is written onto the Tool row as-is.
class ToolCreate(BaseModel):
    name: str
    description: str
    category: Optional[str] = None
    function_schema: dict  # stored and registered in tool_registry unchanged
    implementation_type: str  # builtin / http / code / workflow
    # Type-specific settings; spread into the tool_registry config entry.
    implementation_config: Optional[dict] = None
    is_public: bool = False
# Response shape for one tool. Built by _tool_to_dict (DB tools) or by
# _builtin_schema_to_tool_dict (tools that only exist in tool_registry).
class ToolResponse(BaseModel):
    id: str  # DB id, or "builtin_<name>" for registry-only tools
    name: str
    description: str
    category: Optional[str] = None
    function_schema: dict
    implementation_type: str
    implementation_config: Optional[dict] = None
    is_public: bool = False
    use_count: int = 0
    user_id: Optional[str] = None
    created_at: str = ""  # ISO-8601 string, or "" when unset / registry tool
    updated_at: str = ""  # ISO-8601 string, or "" when unset / registry tool
# Body for POST /test/http: an ad-hoc HTTP tool definition that is executed
# once through tool_registry.test_http_tool and never saved.
class TestHTTPRequest(BaseModel):
    url: str
    method: str = "GET"
    # NOTE(review): pydantic copies field defaults per model instance, so the
    # `{}` defaults below are not shared between requests — confirm for the
    # pydantic version in use.
    headers: Dict[str, str] = {}
    body: Optional[Dict[str, Any]] = None
    args: Dict[str, Any] = {}
    # Forwarded to tool_registry.test_http_tool; unit presumably seconds
    # (TODO confirm). NOTE(review): no upper bound is enforced here.
    timeout: int = 30
# Body for POST /test/code: source of an ad-hoc code tool plus the arguments
# to call it with; forwarded to tool_registry.test_code_tool, never saved.
class TestCodeRequest(BaseModel):
    source: str
    args: Dict[str, Any] = {}
# Result of a tool dry-run, built from the dict that
# tool_registry.test_http_tool / test_code_tool return.
class TestResponse(BaseModel):
    success: bool
    elapsed_ms: Optional[int] = None
    result: Optional[Any] = None
    status_code: Optional[int] = None  # presumably only set by HTTP tests
    body: Optional[str] = None  # presumably the raw HTTP response body
    error: Optional[str] = None
# ─── 工具函数 ──────────────────────────────────────────────────
def _tool_to_dict(tool: Tool) -> dict:
return {
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
}
# ─── 工具市场浏览 ──────────────────────────────────────────────
def _builtin_schema_to_tool_dict(schema: dict) -> dict:
"""将 tool_registry 中的 schema 转为与 DB Tool 一致的字典格式。"""
func = schema.get("function", schema)
name = func.get("name", "")
desc = func.get("description", "")
params = func.get("parameters", {})
# 根据工具名自动归类
cat = "系统工具"
if name in ("image_ocr", "image_vision"):
cat = "多模态"
elif name in ("speech_to_text", "text_to_speech"):
cat = "多模态"
elif name.startswith("file_"):
cat = "文件操作"
elif name.startswith("http") or name.startswith("url"):
cat = "网络请求"
elif name.startswith("database") or name.startswith("sql"):
cat = "数据库"
elif name.startswith("agent_"):
cat = "AI Agent"
elif name in ("web_search", "send_email", "browser_use"):
cat = "网络请求"
return {
"id": f"builtin_{name}",
"name": name,
"description": desc,
"category": cat,
"function_schema": schema,
"implementation_type": "builtin",
"implementation_config": None,
"is_public": True,
"use_count": 0,
"user_id": None,
"created_at": "",
"updated_at": "",
}
@router.get("", response_model=List[ToolResponse])
async def list_tools(
    category: Optional[str] = Query(None, description="按分类筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    scope: Optional[str] = Query("public", description="public / mine / all"),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Browse the tool marketplace: DB tools merged with registry built-ins.

    Args:
        category: only return tools in this category.
        search: keyword matched against name and description. DB tools use a
            SQL ``LIKE`` (case sensitivity depends on the database); registry
            tools are matched case-insensitively in Python.
        scope: ``"public"`` – public DB tools; ``"mine"`` – the caller's own
            DB tools only (requires login, no built-ins); any other value
            (``"all"``) – public DB tools plus the caller's own.

    Returns:
        DB tools ordered by ``use_count`` desc then ``created_at`` desc,
        followed (unless scope is "mine") by registry tools whose name is not
        already used by a returned DB tool.

    Raises:
        HTTPException(401): scope is "mine" and there is no current user.
    """
    query = db.query(Tool)
    if scope == "public":
        query = query.filter(Tool.is_public == True)  # noqa: E712 (SQL expression)
    elif scope == "mine":
        if not current_user:
            raise HTTPException(status_code=401, detail="需登录")
        query = query.filter(Tool.user_id == current_user.id)
    else:
        # "all": only what this caller may see. Other users' private tools
        # (including their implementation_config) must never be listed.
        visible = Tool.is_public == True  # noqa: E712 (SQL expression)
        if current_user:
            visible = visible | (Tool.user_id == current_user.id)
        query = query.filter(visible)
    if category:
        query = query.filter(Tool.category == category)
    if search:
        # autoescape=True makes "%" and "_" in the keyword match literally.
        query = query.filter(
            Tool.name.contains(search, autoescape=True)
            | Tool.description.contains(search, autoescape=True)
        )
    tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
    result = [_tool_to_dict(t) for t in tools]
    if scope == "mine":
        return result

    # Merge registry built-ins that no returned DB tool overrides by name.
    db_names = {entry["name"] for entry in result}
    # Names in _custom_tool_configs were registered by create_tool for DB-backed
    # custom tools; they are not built-ins. If get_all_tool_schemas() also
    # returns those schemas (TODO confirm), a private tool filtered out above
    # would otherwise reappear here labelled as a public built-in.
    custom_names = tool_registry._custom_tool_configs
    keyword = search.lower() if search else None
    for schema in tool_registry.get_all_tool_schemas():
        entry = _builtin_schema_to_tool_dict(schema)
        if entry["name"] in db_names or entry["name"] in custom_names:
            continue
        if category and entry["category"] != category:
            continue
        if keyword:
            haystacks = (entry["name"] or "", entry["description"] or "")
            if not any(keyword in text.lower() for text in haystacks):
                continue
        result.append(entry)
    return result
@router.get("/categories", response_model=List[str])
async def list_categories(db: Session = Depends(get_db)):
    """Return every tool category.

    Distinct non-empty categories found in the DB come first, sorted; the
    built-in default categories not already present follow in their fixed
    order.
    """
    distinct_rows = (
        db.query(Tool.category)
        .filter(Tool.category.isnot(None))
        .distinct()
        .all()
    )
    used = sorted({row[0] for row in distinct_rows if row[0]})
    # Commonly used categories that are always offered, even before any tool uses them.
    defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义", "多模态", "系统工具"]
    missing = [name for name in defaults if name not in used]
    return used + missing
@router.get("/builtin")
async def list_builtin_tools():
    """Return the tool registry's schemas unchanged (OpenAI function format)."""
    schemas = tool_registry.get_all_tool_schemas()
    return schemas
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
    tool_id: str = Path(..., pattern=r"^(?:[0-9a-f-]{20,}|builtin_[A-Za-z0-9_]+)$"),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Return one tool's details.

    Accepts DB tool ids as well as the ``builtin_<name>`` ids that the
    marketplace listing emits for tools that only exist in ``tool_registry``.

    Raises:
        HTTPException(404): no such tool, or it is another user's private
            tool (reported as missing so private ids are not confirmed).
    """
    if tool_id.startswith("builtin_"):
        for schema in tool_registry.get_all_tool_schemas():
            entry = _builtin_schema_to_tool_dict(schema)
            # Names in _custom_tool_configs are DB-backed custom tools
            # registered by create_tool, not built-ins.
            if (entry["id"] == tool_id
                    and entry["name"] not in tool_registry._custom_tool_configs):
                return entry
        raise HTTPException(status_code=404, detail="工具不存在")

    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    # Private tools are visible only to their owner (the same ownership rule
    # update_tool / delete_tool enforce); their config may hold credentials.
    if not tool.is_public and (current_user is None or tool.user_id != current_user.id):
        raise HTTPException(status_code=404, detail="工具不存在")
    return _tool_to_dict(tool)
# ─── Tool create / update / delete ───────────────────────────────
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a user-defined tool owned by the caller and register it.

    The row is committed, then the tool is added to ``tool_registry`` under
    its name (config in ``_custom_tool_configs``, schema in ``_tool_schemas``).

    Raises:
        HTTPException(400): the name is already used by a DB tool or by a
            registry tool without a DB row (a built-in), or
            ``implementation_type`` is not builtin/http/code/workflow.
    """
    existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
    if existing:
        raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
    # Built-ins live only in the registry (no DB row), so the check above
    # cannot see them. Registering a custom tool under a built-in's name would
    # shadow it for every user through the registry writes below.
    registry_names = {
        schema.get("function", schema).get("name")
        for schema in tool_registry.get_all_tool_schemas()
    }
    if tool_data.name in registry_names:
        raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 与内置工具冲突")
    valid_types = {"builtin", "http", "code", "workflow"}
    if tool_data.implementation_type not in valid_types:
        raise HTTPException(status_code=400,
                            detail=f"无效的实现类型: {tool_data.implementation_type}")
    tool = Tool(
        name=tool_data.name,
        description=tool_data.description,
        category=tool_data.category,
        function_schema=tool_data.function_schema,
        implementation_type=tool_data.implementation_type,
        implementation_config=tool_data.implementation_config,
        is_public=tool_data.is_public,
        user_id=current_user.id,
    )
    db.add(tool)
    db.commit()
    db.refresh(tool)
    logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
    # Make the new tool resolvable by name without a restart.
    tool_registry._custom_tool_configs[tool.name] = {
        **(tool.implementation_config or {}),
        "_type": tool.implementation_type,
        "_db_id": tool.id,
    }
    tool_registry._tool_schemas[tool.name] = tool.function_schema
    return _tool_to_dict(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
    tool_data: ToolCreate,
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Overwrite every editable field of one of the caller's tools.

    If the tool was registered in ``tool_registry`` under its old name, the
    registration is refreshed (and moved to the new name on a rename).

    Raises:
        HTTPException(404): no tool with this id.
        HTTPException(403): the tool belongs to another user.
        HTTPException(400): the new name is taken (DB tool or registry
            built-in), or ``implementation_type`` is invalid.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    if tool.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权更新此工具")
    if tool_data.name != tool.name:
        existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
        if existing:
            raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
        # Registry-only (built-in) names are invisible to the DB check above;
        # renaming onto one would shadow the built-in for every user.
        registry_names = {
            schema.get("function", schema).get("name")
            for schema in tool_registry.get_all_tool_schemas()
        }
        if tool_data.name in registry_names:
            raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 与内置工具冲突")
    # Same type validation as on create, so an update cannot store an
    # arbitrary implementation_type.
    valid_types = {"builtin", "http", "code", "workflow"}
    if tool_data.implementation_type not in valid_types:
        raise HTTPException(status_code=400,
                            detail=f"无效的实现类型: {tool_data.implementation_type}")

    # Capture the registration state under the name the tool had BEFORE the
    # update; after a rename, tool.name no longer matches the registry keys.
    old_name = tool.name
    was_configured = old_name in tool_registry._custom_tool_configs
    had_schema = old_name in tool_registry._tool_schemas

    tool.name = tool_data.name
    tool.description = tool_data.description
    tool.category = tool_data.category
    tool.function_schema = tool_data.function_schema
    tool.implementation_type = tool_data.implementation_type
    tool.implementation_config = tool_data.implementation_config
    tool.is_public = tool_data.is_public
    db.commit()
    db.refresh(tool)

    # Refresh the registry; on a rename drop the stale old-name entries so the
    # tool is registered exactly once, under its new name.
    if old_name != tool.name:
        tool_registry._custom_tool_configs.pop(old_name, None)
        tool_registry._tool_schemas.pop(old_name, None)
    if was_configured:
        tool_registry._custom_tool_configs[tool.name] = {
            **(tool.implementation_config or {}),
            "_type": tool.implementation_type,
            "_db_id": tool.id,
        }
    if had_schema:
        tool_registry._tool_schemas[tool.name] = tool.function_schema
    return _tool_to_dict(tool)
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete one of the caller's own tools and drop its registry entries.

    Raises:
        HTTPException(404): no tool with this id.
        HTTPException(403): the tool belongs to another user.
        HTTPException(400): the tool's implementation_type is "builtin".
    """
    target = db.query(Tool).filter(Tool.id == tool_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="工具不存在")
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此工具")
    if target.implementation_type == "builtin":
        raise HTTPException(status_code=400, detail="内置工具不允许删除")

    registered_name = target.name
    db.delete(target)
    db.commit()

    # Remove the name from both registry maps that create_tool populates.
    for registry_map in (tool_registry._custom_tool_configs, tool_registry._tool_schemas):
        registry_map.pop(registered_name, None)
    return {"message": "工具已删除"}
# ─── Tool dry-runs ─────────────────────────────────────────────
@router.post("/test/http", response_model=TestResponse)
async def test_http_tool(
    req: TestHTTPRequest,
    current_user: User = Depends(get_current_user),
):
    """Run an HTTP tool definition once without saving it to the database.

    ``current_user`` is not read in the body; the dependency is presumably
    there so that only authenticated callers can use this endpoint.

    NOTE(review): the caller chooses ``url`` freely — confirm that
    ``tool_registry.test_http_tool`` blocks requests to internal addresses
    (SSRF) and bounds ``timeout``.
    """
    call_kwargs = {
        "url": req.url,
        "method": req.method,
        "headers": req.headers,
        "body": req.body,
        "args": req.args,
        "timeout": req.timeout,
    }
    outcome = await tool_registry.test_http_tool(**call_kwargs)
    return TestResponse(**outcome)
@router.post("/test/code", response_model=TestResponse)
async def test_code_tool(
    req: TestCodeRequest,
    current_user: User = Depends(get_current_user),
):
    """Run a code tool's source once without saving it to the database.

    ``current_user`` is not read in the body; the dependency is presumably
    there so that only authenticated callers can use this endpoint.

    NOTE(review): user-supplied source is handed to
    ``tool_registry.test_code_tool`` — confirm it runs in a sandbox.
    """
    source, call_args = req.source, req.args
    outcome = await tool_registry.test_code_tool(source=source, args=call_args)
    return TestResponse(**outcome)
# ─── Usage counting ────────────────────────────────────────────
@router.post("/{tool_id}/use")
async def record_tool_use(
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Increment a DB tool's use counter and return the new value.

    Called automatically when an Agent runs the tool. A NULL counter counts
    as 0.

    Raises:
        HTTPException(404): no tool with this id.
    """
    from sqlalchemy import func  # only this endpoint needs SQL functions

    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    # Increment inside the UPDATE statement (coalesce(use_count, 0) + 1) so
    # concurrent calls cannot overwrite each other's increments, which a
    # Python-side read-modify-write would allow.
    tool.use_count = func.coalesce(Tool.use_count, 0) + 1
    db.commit()
    db.refresh(tool)  # load the value the database actually stored
    return {"use_count": tool.use_count}