""" 工具市场 API — 管理、测试、发现和安装工具。 提供内置工具和用户自定义工具(HTTP / 代码段)的 CRUD、测试和执行。 """ from __future__ import annotations import logging from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException, Query, Path from pydantic import BaseModel from sqlalchemy.orm import Session from app.api.auth import get_current_user from app.core.database import get_db from app.models.tool import Tool from app.models.user import User from app.services.tool_registry import tool_registry logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/tools", tags=["tools"]) # ─── Schema ────────────────────────────────────────────────────── class ToolCreate(BaseModel): name: str description: str category: Optional[str] = None function_schema: dict implementation_type: str # builtin / http / code / workflow implementation_config: Optional[dict] = None is_public: bool = False class ToolResponse(BaseModel): id: str name: str description: str category: Optional[str] = None function_schema: dict implementation_type: str implementation_config: Optional[dict] = None is_public: bool = False use_count: int = 0 user_id: Optional[str] = None created_at: str = "" updated_at: str = "" class TestHTTPRequest(BaseModel): url: str method: str = "GET" headers: Dict[str, str] = {} body: Optional[Dict[str, Any]] = None args: Dict[str, Any] = {} timeout: int = 30 class TestCodeRequest(BaseModel): source: str args: Dict[str, Any] = {} class TestResponse(BaseModel): success: bool elapsed_ms: Optional[int] = None result: Optional[Any] = None status_code: Optional[int] = None body: Optional[str] = None error: Optional[str] = None # ─── 工具函数 ────────────────────────────────────────────────── def _tool_to_dict(tool: Tool) -> dict: return { "id": tool.id, "name": tool.name, "description": tool.description, "category": tool.category, "function_schema": tool.function_schema, "implementation_type": tool.implementation_type, "implementation_config": tool.implementation_config, "is_public": tool.is_public, "use_count": tool.use_count, "user_id": tool.user_id, "created_at": tool.created_at.isoformat() if tool.created_at else "", "updated_at": tool.updated_at.isoformat() if tool.updated_at else "", } # ─── 工具市场浏览 ────────────────────────────────────────────── @router.get("", response_model=List[ToolResponse]) async def list_tools( category: Optional[str] = Query(None, description="按分类筛选"), search: Optional[str] = Query(None, description="搜索关键词"), scope: Optional[str] = Query("public", description="public / mine / all"), db: Session = Depends(get_db), current_user: Optional[User] = Depends(get_current_user), ): """浏览工具市场。""" query = db.query(Tool) if scope == "public": query = query.filter(Tool.is_public == True) elif scope == "mine": if not current_user: raise HTTPException(status_code=401, detail="需登录") query = query.filter(Tool.user_id == current_user.id) if category: query = query.filter(Tool.category == category) if search: query = query.filter( Tool.name.contains(search) | Tool.description.contains(search) ) tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all() return [_tool_to_dict(t) for t in tools] @router.get("/categories", response_model=List[str]) async def list_categories(db: Session = Depends(get_db)): """列出所有工具分类。""" rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all() cats = sorted(set(r[0] for r in rows if r[0])) # 加上常用分类 defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"] for d in defaults: if d not in cats: cats.append(d) return cats @router.get("/builtin") async def list_builtin_tools(): """列出所有内置工具(OpenAI Function 格式)。""" return tool_registry.get_all_tool_schemas() @router.get("/{tool_id}", response_model=ToolResponse) async def get_tool( tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"), db: Session = Depends(get_db) ): """获取工具详情。""" tool = db.query(Tool).filter(Tool.id == tool_id).first() if not tool: raise HTTPException(status_code=404, detail="工具不存在") return _tool_to_dict(tool) # ─── 工具创建 / 更新 / 删除 ────────────────────────────────── @router.post("", response_model=ToolResponse, status_code=201) async def create_tool( tool_data: ToolCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """创建自定义工具。""" existing = db.query(Tool).filter(Tool.name == tool_data.name).first() if existing: raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在") valid_types = {"builtin", "http", "code", "workflow"} if tool_data.implementation_type not in valid_types: raise HTTPException(status_code=400, detail=f"无效的实现类型: {tool_data.implementation_type}") tool = Tool( name=tool_data.name, description=tool_data.description, category=tool_data.category, function_schema=tool_data.function_schema, implementation_type=tool_data.implementation_type, implementation_config=tool_data.implementation_config, is_public=tool_data.is_public, user_id=current_user.id, ) db.add(tool) db.commit() db.refresh(tool) logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type) # 刷新注册表 tool_registry._custom_tool_configs[tool.name] = { **(tool.implementation_config or {}), "_type": tool.implementation_type, "_db_id": tool.id, } tool_registry._tool_schemas[tool.name] = tool.function_schema return _tool_to_dict(tool) @router.put("/{tool_id}", response_model=ToolResponse) async def update_tool( tool_data: ToolCreate, tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """更新工具。""" tool = db.query(Tool).filter(Tool.id == tool_id).first() if not tool: raise HTTPException(status_code=404, detail="工具不存在") if tool.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权更新此工具") if tool_data.name != tool.name: existing = db.query(Tool).filter(Tool.name == tool_data.name).first() if existing: raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在") tool.name = tool_data.name tool.description = tool_data.description tool.category = tool_data.category tool.function_schema = tool_data.function_schema tool.implementation_type = tool_data.implementation_type tool.implementation_config = tool_data.implementation_config tool.is_public = tool_data.is_public db.commit() db.refresh(tool) # 刷新注册表 if tool.name in tool_registry._custom_tool_configs: tool_registry._custom_tool_configs[tool.name] = { **(tool.implementation_config or {}), "_type": tool.implementation_type, "_db_id": tool.id, } if tool.name in tool_registry._tool_schemas: tool_registry._tool_schemas[tool.name] = tool.function_schema return _tool_to_dict(tool) @router.delete("/{tool_id}") async def delete_tool( tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """删除工具。""" tool = db.query(Tool).filter(Tool.id == tool_id).first() if not tool: raise HTTPException(status_code=404, detail="工具不存在") if tool.user_id != current_user.id: raise HTTPException(status_code=403, detail="无权删除此工具") if tool.implementation_type == "builtin": raise HTTPException(status_code=400, detail="内置工具不允许删除") db.delete(tool) db.commit() # 清理注册表 tool_registry._custom_tool_configs.pop(tool.name, None) tool_registry._tool_schemas.pop(tool.name, None) return {"message": "工具已删除"} # ─── 工具测试 ────────────────────────────────────────────────── @router.post("/test/http", response_model=TestResponse) async def test_http_tool( req: TestHTTPRequest, current_user: User = Depends(get_current_user), ): """测试 HTTP 工具(不保存到数据库)。""" result = await tool_registry.test_http_tool( url=req.url, method=req.method, headers=req.headers, body=req.body, args=req.args, timeout=req.timeout, ) return TestResponse(**result) @router.post("/test/code", response_model=TestResponse) async def test_code_tool( req: TestCodeRequest, current_user: User = Depends(get_current_user), ): """测试代码工具(不保存到数据库)。""" result = await tool_registry.test_code_tool( source=req.source, args=req.args, ) return TestResponse(**result) # ─── 使用计数 ────────────────────────────────────────────────── @router.post("/{tool_id}/use") async def record_tool_use( tool_id: str = Path(..., pattern=r"^[0-9a-f-]{20,}$"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): """记录工具使用次数(Agent 执行时自动调用)。""" tool = db.query(Tool).filter(Tool.id == tool_id).first() if not tool: raise HTTPException(status_code=404, detail="工具不存在") tool.use_count = (tool.use_count or 0) + 1 db.commit() return {"use_count": tool.use_count}