Files
aiagent/backend/app/api/tools.py

375 lines
13 KiB
Python
Raw Normal View History

2026-01-23 09:49:45 +08:00
"""
Tool marketplace API.

Lists built-in tools (from ``tool_registry``) together with user-defined tools
stored in the database (HTTP / code snippet / workflow), and provides create,
update, delete, dry-run test and usage-counting endpoints for them.
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from pydantic import BaseModel
2026-01-23 09:49:45 +08:00
from sqlalchemy.orm import Session
from app.api.auth import get_current_user
2026-01-23 09:49:45 +08:00
from app.core.database import get_db
from app.models.tool import Tool
from app.models.user import User
from app.services.tool_registry import tool_registry
2026-01-23 09:49:45 +08:00
# Module-level logger; used for tool lifecycle events such as creation.
logger = logging.getLogger(__name__)

# Every endpoint below is mounted under this prefix and grouped as "tools".
router = APIRouter(tags=["tools"], prefix="/api/v1/tools")
# ─── Schema ──────────────────────────────────────────────────────
2026-01-23 09:49:45 +08:00
class ToolCreate(BaseModel):
    """Request body shared by tool creation (POST) and tool update (PUT)."""

    # Tool name; the handlers reject a name already used by another DB tool.
    name: str
    description: str
    category: Optional[str] = None
    # Function-calling schema; stored as-is and published to tool_registry.
    function_schema: dict
    # Expected to be one of: builtin / http / code / workflow.
    implementation_type: str
    # Type-specific settings, stored as-is and merged into the registry config.
    implementation_config: Optional[dict] = None
    is_public: bool = False
class ToolResponse(BaseModel):
    """A tool as returned by the API, for both DB tools and built-in tools.

    Built-in entries are synthesized with ``id = "builtin_<name>"``,
    ``implementation_type = "builtin"`` and empty timestamp strings.
    """

    id: str
    name: str
    description: str
    category: Optional[str] = None
    function_schema: dict
    implementation_type: str
    implementation_config: Optional[dict] = None
    is_public: bool = False
    use_count: int = 0
    user_id: Optional[str] = None  # None for built-in tools
    created_at: str = ""  # ISO-8601 string, or "" when unknown
    updated_at: str = ""  # ISO-8601 string, or "" when unknown
class TestHTTPRequest(BaseModel):
    """Ad-hoc HTTP tool definition to dry-run via ``POST /test/http`` (not persisted)."""

    url: str
    method: str = "GET"
    # pydantic copies mutable defaults per instance, so the shared {} is safe here.
    headers: Dict[str, str] = {}
    body: Optional[Dict[str, Any]] = None
    # Arguments handed to the tool under test.
    args: Dict[str, Any] = {}
    # Forwarded to tool_registry.test_http_tool; presumably seconds — confirm.
    timeout: int = 30
class TestCodeRequest(BaseModel):
    """Ad-hoc code-snippet tool to dry-run via ``POST /test/code`` (not persisted)."""

    # Source passed verbatim to tool_registry.test_code_tool.
    source: str
    # Arguments handed to the snippet under test.
    args: Dict[str, Any] = {}
class TestResponse(BaseModel):
    """Outcome of a tool dry-run, built from the dict tool_registry returns."""

    success: bool
    elapsed_ms: Optional[int] = None
    result: Optional[Any] = None
    # Presumably only populated by HTTP tests — confirm in tool_registry.
    status_code: Optional[int] = None
    body: Optional[str] = None
    error: Optional[str] = None
# ─── 工具函数 ──────────────────────────────────────────────────
def _tool_to_dict(tool: Tool) -> dict:
return {
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
}
# ─── 工具市场浏览 ──────────────────────────────────────────────
2026-01-23 09:49:45 +08:00
def _builtin_schema_to_tool_dict(schema: dict) -> dict:
"""将 tool_registry 中的 schema 转为与 DB Tool 一致的字典格式。"""
func = schema.get("function", schema)
name = func.get("name", "")
desc = func.get("description", "")
params = func.get("parameters", {})
# 根据工具名自动归类
cat = "系统工具"
if name in ("image_ocr", "image_vision"):
cat = "多模态"
elif name in ("speech_to_text", "text_to_speech"):
cat = "多模态"
elif name.startswith("file_"):
cat = "文件操作"
elif name.startswith("http") or name.startswith("url"):
cat = "网络请求"
elif name.startswith("database") or name.startswith("sql"):
cat = "数据库"
elif name.startswith("agent_"):
cat = "AI Agent"
elif name in ("web_search", "send_email", "browser_use"):
cat = "网络请求"
return {
"id": f"builtin_{name}",
"name": name,
"description": desc,
"category": cat,
"function_schema": schema,
"implementation_type": "builtin",
"implementation_config": None,
"is_public": True,
"use_count": 0,
"user_id": None,
"created_at": "",
"updated_at": "",
}
2026-01-23 09:49:45 +08:00
@router.get("", response_model=List[ToolResponse])
async def list_tools(
    category: Optional[str] = Query(None, description="按分类筛选"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    scope: Optional[str] = Query("public", description="public / mine / all"),
    db: Session = Depends(get_db),
    current_user: Optional[User] = Depends(get_current_user),
):
    """Browse the tool marketplace (built-in tools merged with DB tools).

    Scopes:
        - ``public``: public DB tools, plus built-in tools.
        - ``mine``: only the caller's own DB tools (login required); no built-ins.
        - ``all``: public DB tools and the caller's own private tools, plus
          built-in tools. Other users' private tools are never returned; they
          carry ``implementation_config``, which may hold credentials
          (assumption — e.g. HTTP headers; confirm).

    Built-in tools are only added when no listed DB tool has the same name, and
    are filtered by ``category`` / ``search`` (case-insensitive) the same way.

    Raises:
        HTTPException 400: unknown ``scope`` value.
        HTTPException 401: ``scope=mine`` without an authenticated user.
    """
    if scope not in ("public", "mine", "all"):
        raise HTTPException(status_code=400, detail=f"无效的 scope: {scope}")

    query = db.query(Tool)
    if scope == "public":
        query = query.filter(Tool.is_public == True)  # noqa: E712 (SQL expression)
    elif scope == "mine":
        if not current_user:
            raise HTTPException(status_code=401, detail="需登录")
        query = query.filter(Tool.user_id == current_user.id)
    else:  # "all": public tools plus the caller's own private ones
        visible = Tool.is_public == True  # noqa: E712 (SQL expression)
        if current_user:
            visible = visible | (Tool.user_id == current_user.id)
        query = query.filter(visible)

    if category:
        query = query.filter(Tool.category == category)
    if search:
        query = query.filter(
            Tool.name.contains(search) | Tool.description.contains(search)
        )

    tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
    result = [_tool_to_dict(t) for t in tools]
    if scope == "mine":
        return result

    # Merge built-in tools that are not overridden by a listed DB tool.
    db_names = {t["name"] for t in result}
    keyword = search.lower() if search else None
    for schema in tool_registry.get_all_tool_schemas():
        entry = _builtin_schema_to_tool_dict(schema)
        if entry["name"] in db_names:
            continue
        if category and entry["category"] != category:
            continue
        if keyword:
            haystacks = (entry["name"] or "", entry["description"] or "")
            if not any(keyword in text.lower() for text in haystacks):
                continue
        result.append(entry)
    return result
@router.get("/categories", response_model=List[str])
async def list_categories(db: Session = Depends(get_db)):
    """Return tool categories: those used in the DB (sorted), then common defaults.

    A default already used in the DB keeps its sorted position and is not
    repeated; the remaining defaults follow in their fixed order.
    """
    category_rows = (
        db.query(Tool.category)
        .filter(Tool.category.isnot(None))
        .distinct()
        .all()
    )
    used = sorted({row[0] for row in category_rows if row[0]})
    # Common categories offered even when no DB tool uses them yet.
    common = ("数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义", "多模态", "系统工具")
    missing = [label for label in common if label not in used]
    return used + missing
2026-01-23 09:49:45 +08:00
@router.get("/builtin")
async def list_builtin_tools():
    """Return the raw built-in tool schemas (OpenAI function format) from tool_registry."""
    schemas = tool_registry.get_all_tool_schemas()
    return schemas
2026-01-23 09:49:45 +08:00
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db)
):
    """Fetch a single DB tool by id.

    The ``tool_id`` pattern only admits lowercase-hex/UUID-like ids, so the
    ``builtin_<name>`` ids that ``list_tools`` emits for built-in tools fail
    path validation before reaching this handler.

    NOTE(review): there is no ownership / ``is_public`` check here, so anyone
    who knows the id can read a private tool, including its
    ``implementation_config``. Confirm this is intended.

    Raises:
        HTTPException 404: no tool with that id.
    """
    found = db.query(Tool).filter(Tool.id == tool_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="工具不存在")
    return _tool_to_dict(found)
# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
2026-01-23 09:49:45 +08:00
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a custom tool owned by the caller and register it in tool_registry.

    Raises:
        HTTPException 400: the name is already used by a DB tool, or
            ``implementation_type`` is not builtin / http / code / workflow.
    """
    name_taken = db.query(Tool).filter(Tool.name == tool_data.name).first()
    if name_taken:
        raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")

    impl_type = tool_data.implementation_type
    if impl_type not in {"builtin", "http", "code", "workflow"}:
        raise HTTPException(status_code=400, detail=f"无效的实现类型: {impl_type}")

    new_tool = Tool(
        name=tool_data.name,
        description=tool_data.description,
        category=tool_data.category,
        function_schema=tool_data.function_schema,
        implementation_type=impl_type,
        implementation_config=tool_data.implementation_config,
        is_public=tool_data.is_public,
        user_id=current_user.id,
    )
    db.add(new_tool)
    db.commit()
    db.refresh(new_tool)
    logger.info("工具已创建: %s (type=%s)", new_tool.name, new_tool.implementation_type)

    # Publish to the in-process registry so the tool is usable without a reload.
    registry_config = dict(new_tool.implementation_config or {})
    registry_config.update(_type=new_tool.implementation_type, _db_id=new_tool.id)
    tool_registry._custom_tool_configs[new_tool.name] = registry_config
    tool_registry._tool_schemas[new_tool.name] = new_tool.function_schema
    return _tool_to_dict(new_tool)
2026-01-23 09:49:45 +08:00
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
    tool_data: ToolCreate,
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Overwrite every editable field of a tool owned by the caller.

    Validation mirrors ``create_tool``: the new name must not belong to another
    DB tool, and ``implementation_type`` must be builtin / http / code /
    workflow. After committing, any tool_registry entries held under the
    tool's previous name are moved to its current name. A rename therefore
    neither leaves a stale entry nor drops the registration. Tools that had no
    registry entry before the update stay unregistered.

    Raises:
        HTTPException 404: no tool with that id.
        HTTPException 403: the caller does not own the tool.
        HTTPException 400: the new name is taken, or the implementation type is invalid.
    """
    tool = db.query(Tool).filter(Tool.id == tool_id).first()
    if not tool:
        raise HTTPException(status_code=404, detail="工具不存在")
    if tool.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权更新此工具")

    if tool_data.name != tool.name:
        existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
        if existing:
            raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")

    # Same whitelist as create_tool, so an update cannot store a type creation rejects.
    valid_types = {"builtin", "http", "code", "workflow"}
    if tool_data.implementation_type not in valid_types:
        raise HTTPException(status_code=400,
                            detail=f"无效的实现类型: {tool_data.implementation_type}")

    # Registry keys are tool names: remember the current one before a rename.
    old_name = tool.name

    tool.name = tool_data.name
    tool.description = tool_data.description
    tool.category = tool_data.category
    tool.function_schema = tool_data.function_schema
    tool.implementation_type = tool_data.implementation_type
    tool.implementation_config = tool_data.implementation_config
    tool.is_public = tool_data.is_public

    db.commit()
    db.refresh(tool)

    # Re-key registry entries from the old name to the (possibly new) name.
    configs = tool_registry._custom_tool_configs
    schemas = tool_registry._tool_schemas
    had_config = old_name in configs
    had_schema = old_name in schemas
    configs.pop(old_name, None)
    schemas.pop(old_name, None)
    if had_config:
        configs[tool.name] = {
            **(tool.implementation_config or {}),
            "_type": tool.implementation_type,
            "_db_id": tool.id,
        }
    if had_schema:
        schemas[tool.name] = tool.function_schema

    return _tool_to_dict(tool)
@router.delete("/{tool_id}")
async def delete_tool(
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Delete a tool owned by the caller and drop it from tool_registry.

    Raises:
        HTTPException 404: no tool with that id.
        HTTPException 403: the caller does not own the tool.
        HTTPException 400: the tool's ``implementation_type`` is ``builtin``.
    """
    target = db.query(Tool).filter(Tool.id == tool_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="工具不存在")
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此工具")
    if target.implementation_type == "builtin":
        raise HTTPException(status_code=400, detail="内置工具不允许删除")

    registry_key = target.name  # read while the row is still live
    db.delete(target)
    db.commit()

    for registry_map in (tool_registry._custom_tool_configs, tool_registry._tool_schemas):
        registry_map.pop(registry_key, None)
    return {"message": "工具已删除"}
# ─── 工具测试 ──────────────────────────────────────────────────
@router.post("/test/http", response_model=TestResponse)
async def test_http_tool(
    req: TestHTTPRequest,
    current_user: User = Depends(get_current_user),
):
    """Dry-run an HTTP tool definition through tool_registry without saving it.

    ``current_user`` is not used in the body. The dependency presumably exists
    only to require an authenticated caller — confirm get_current_user rejects
    anonymous requests.
    """
    call_kwargs = {
        "url": req.url,
        "method": req.method,
        "headers": req.headers,
        "body": req.body,
        "args": req.args,
        "timeout": req.timeout,
    }
    outcome = await tool_registry.test_http_tool(**call_kwargs)
    return TestResponse(**outcome)
@router.post("/test/code", response_model=TestResponse)
async def test_code_tool(
    req: TestCodeRequest,
    current_user: User = Depends(get_current_user),
):
    """Dry-run a code-snippet tool through tool_registry without saving it.

    NOTE(review): ``req.source`` is caller-supplied code handed straight to
    ``tool_registry.test_code_tool``. Verify that the registry sandboxes it;
    otherwise any authenticated user can run arbitrary code on the server.
    """
    outcome = await tool_registry.test_code_tool(source=req.source, args=req.args)
    return TestResponse(**outcome)
# ─── 使用计数 ──────────────────────────────────────────────────
@router.post("/{tool_id}/use")
async def record_tool_use(
    tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Increment a tool's usage counter (called automatically when an Agent runs a tool).

    The increment is a single SQL ``UPDATE ... SET use_count = COALESCE(use_count, 0) + 1``.
    Concurrent calls therefore cannot overwrite each other's increments, as a
    read-then-write in Python could. A NULL counter still counts as 0.

    Returns:
        ``{"use_count": <counter value read after the commit>}``.

    Raises:
        HTTPException 404: no tool with that id.
    """
    from sqlalchemy import func  # local import: only this handler needs it

    matched = (
        db.query(Tool)
        .filter(Tool.id == tool_id)
        .update(
            {Tool.use_count: func.coalesce(Tool.use_count, 0) + 1},
            synchronize_session=False,
        )
    )
    if not matched:
        raise HTTPException(status_code=404, detail="工具不存在")
    db.commit()

    new_count = db.query(Tool.use_count).filter(Tool.id == tool_id).scalar()
    return {"use_count": new_count}