239 lines
7.7 KiB
Python
239 lines
7.7 KiB
Python
"""
|
|
工具管理API
|
|
"""
|
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
from sqlalchemy.orm import Session
|
|
from typing import List, Optional
|
|
from app.core.database import get_db
|
|
from app.models.tool import Tool
|
|
from app.services.tool_registry import tool_registry
|
|
from app.api.auth import get_current_user
|
|
from app.models.user import User
|
|
from pydantic import BaseModel
|
|
|
|
# All endpoints below are mounted under /api/v1/tools and grouped under the
# "tools" tag in the generated OpenAPI documentation.
router = APIRouter(
    tags=["tools"],
    prefix="/api/v1/tools",
)
|
|
|
|
|
|
class ToolCreate(BaseModel):
    """Request body used to create a tool or fully replace an existing one."""

    # Tool name; the create/update endpoints reject names already used by
    # any other tool.
    name: str
    # Free-text description of the tool.
    description: str
    # Optional category; the list endpoint can filter on it exactly.
    category: Optional[str] = None
    # Function schema for the tool, accepted as an arbitrary dict (its
    # structure is not validated by this model).
    function_schema: dict
    # Implementation kind. NOTE(review): any string is accepted, including
    # "builtin", which the delete endpoint then refuses to delete — consider
    # restricting the allowed values.
    implementation_type: str
    # Optional implementation-specific settings, accepted as an arbitrary dict.
    implementation_config: Optional[dict] = None
    # Whether the tool is included in the public listing.
    is_public: bool = False
|
|
|
|
|
|
class ToolResponse(BaseModel):
    """Serialized form of a single tool as returned by the tool endpoints."""

    id: str
    name: str
    description: str
    category: Optional[str]
    function_schema: dict
    implementation_type: str
    implementation_config: Optional[dict]
    is_public: bool
    # Usage counter; the list endpoint orders results by it, highest first.
    use_count: int
    # Owner's user id; presumably None for tools without an owner — verify.
    user_id: Optional[str]
    # Timestamps are sent as ISO-8601 strings; the endpoints use "" when the
    # underlying column is unset.
    created_at: str
    updated_at: str

    class Config:
        # Permit building this model from ORM objects via attribute access
        # (pydantic v2 option name).
        from_attributes = True
|
|
|
|
|
|
@router.get("", response_model=List[ToolResponse])
async def list_tools(
    category: Optional[str] = Query(None, description="工具分类"),
    search: Optional[str] = Query(None, description="搜索关键词"),
    db: Session = Depends(get_db)
):
    """List public tools, optionally filtered by category and keyword.\f

    Args:
        category: Exact category to filter on; ignored when empty or absent.
        search: Substring matched against the tool name OR the description;
            ignored when empty or absent. LIKE wildcards in the input
            (``%``, ``_``) are escaped so they match literally.
        db: Database session injected by ``get_db``.

    Returns:
        A list of dicts shaped like ``ToolResponse``, ordered by ``use_count``
        descending, then ``created_at`` descending. Timestamps are ISO-8601
        strings, or "" when the column is unset.
    """
    # NOTE(review): this is an ``async def`` performing blocking Session
    # calls, so the query runs on the event-loop thread. A plain ``def``
    # would let FastAPI run it in its threadpool.

    # ``== True`` is intentional: it builds a SQL expression, not a Python test.
    query = db.query(Tool).filter(Tool.is_public == True)  # noqa: E712

    if category:
        query = query.filter(Tool.category == category)

    if search:
        # autoescape=True escapes "%", "_" and the escape character in the
        # user's input, so a search for e.g. "100%" is matched literally
        # instead of being interpreted as a LIKE pattern.
        query = query.filter(
            Tool.name.contains(search, autoescape=True)
            | Tool.description.contains(search, autoescape=True)
        )

    tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()

    # Build plain dicts so the datetime columns are emitted as strings.
    return [
        {
            "id": tool.id,
            "name": tool.name,
            "description": tool.description,
            "category": tool.category,
            "function_schema": tool.function_schema,
            "implementation_type": tool.implementation_type,
            "implementation_config": tool.implementation_config,
            "is_public": tool.is_public,
            "use_count": tool.use_count,
            "user_id": tool.user_id,
            "created_at": tool.created_at.isoformat() if tool.created_at else "",
            "updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
        }
        for tool in tools
    ]
|
|
|
|
|
|
# Declared before the "/{tool_id}" route so that "/builtin" is matched here
# rather than being captured as a tool id.
@router.get("/builtin")
async def list_builtin_tools():
    """Return the schemas reported by the in-process tool registry.\f

    The value of ``tool_registry.get_all_tool_schemas()`` is returned
    unchanged; this endpoint does not touch the database or require auth.
    """
    return tool_registry.get_all_tool_schemas()
|
|
|
|
|
|
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
    tool_id: str,
    db: Session = Depends(get_db)
):
    """Fetch a single tool by its id.\f

    Args:
        tool_id: Primary key of the tool to fetch.
        db: Database session injected by ``get_db``.

    Returns:
        A dict shaped like ``ToolResponse``; timestamps are ISO-8601 strings,
        or "" when the column is unset.

    Raises:
        HTTPException: 404 when no tool has the given id.
    """
    # NOTE(review): there is no authentication and no ``is_public`` check, so
    # any caller who knows the id can read a private tool, including its
    # implementation_config. Confirm this is intended.
    found = db.query(Tool).filter(Tool.id == tool_id).first()
    if not found:
        raise HTTPException(status_code=404, detail="工具不存在")

    created = found.created_at.isoformat() if found.created_at else ""
    updated = found.updated_at.isoformat() if found.updated_at else ""

    return {
        "id": found.id,
        "name": found.name,
        "description": found.description,
        "category": found.category,
        "function_schema": found.function_schema,
        "implementation_type": found.implementation_type,
        "implementation_config": found.implementation_config,
        "is_public": found.is_public,
        "use_count": found.use_count,
        "user_id": found.user_id,
        "created_at": created,
        "updated_at": updated,
    }
|
|
|
|
|
|
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Create a new tool owned by the authenticated user.\f

    Args:
        tool_data: Validated request body describing the tool.
        db: Database session injected by ``get_db``.
        current_user: Authenticated user; stored as the tool's owner.

    Returns:
        The persisted tool as a dict shaped like ``ToolResponse``; timestamps
        are ISO-8601 strings, or "" when the column is unset.

    Raises:
        HTTPException: 400 when any tool (of any owner) already has this name.
    """
    # Name uniqueness is checked across all tools, not per user.
    name_taken = db.query(Tool).filter(Tool.name == tool_data.name).first()
    if name_taken:
        raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")

    # NOTE(review): check-then-insert is not atomic. Two concurrent requests
    # with the same name can both pass the check; if the column has a unique
    # constraint the loser fails at commit instead of getting a 400 — verify.
    # NOTE(review): implementation_type comes straight from the client, so a
    # user can create a "builtin" tool that delete_tool will refuse to delete.
    new_tool = Tool(
        name=tool_data.name,
        description=tool_data.description,
        category=tool_data.category,
        function_schema=tool_data.function_schema,
        implementation_type=tool_data.implementation_type,
        implementation_config=tool_data.implementation_config,
        is_public=tool_data.is_public,
        user_id=current_user.id,
    )

    db.add(new_tool)
    db.commit()
    # Reload so database-generated columns (id, timestamps, counters) are present.
    db.refresh(new_tool)

    created = new_tool.created_at.isoformat() if new_tool.created_at else ""
    updated = new_tool.updated_at.isoformat() if new_tool.updated_at else ""

    return {
        "id": new_tool.id,
        "name": new_tool.name,
        "description": new_tool.description,
        "category": new_tool.category,
        "function_schema": new_tool.function_schema,
        "implementation_type": new_tool.implementation_type,
        "implementation_config": new_tool.implementation_config,
        "is_public": new_tool.is_public,
        "use_count": new_tool.use_count,
        "user_id": new_tool.user_id,
        "created_at": created,
        "updated_at": updated,
    }
|
|
|
|
|
|
@router.put("/{tool_id}", response_model=ToolResponse)
async def update_tool(
    tool_id: str,
    tool_data: ToolCreate,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Fully replace the editable fields of a tool owned by the caller.\f

    Args:
        tool_id: Primary key of the tool to update.
        tool_data: New values; every field overwrites the stored one.
        db: Database session injected by ``get_db``.
        current_user: Authenticated user; must be the tool's owner.

    Returns:
        The updated tool as a dict shaped like ``ToolResponse``; timestamps
        are ISO-8601 strings, or "" when the column is unset.

    Raises:
        HTTPException: 404 when the tool does not exist, 403 when the caller
            is not its owner, 400 when the new name is used by another tool.
    """
    record = db.query(Tool).filter(Tool.id == tool_id).first()
    if not record:
        raise HTTPException(status_code=404, detail="工具不存在")

    # Only the owner may modify the tool.
    if tool_data and record.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权更新此工具")

    # A rename must not collide with another tool's name.
    if tool_data.name != record.name:
        clash = db.query(Tool).filter(Tool.name == tool_data.name).first()
        if clash:
            raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")

    # PUT semantics: copy every editable field from the request body.
    for field_name in (
        "name",
        "description",
        "category",
        "function_schema",
        "implementation_type",
        "implementation_config",
        "is_public",
    ):
        setattr(record, field_name, getattr(tool_data, field_name))

    db.commit()
    # Reload so columns maintained by the database (e.g. timestamps) are fresh.
    db.refresh(record)

    created = record.created_at.isoformat() if record.created_at else ""
    updated = record.updated_at.isoformat() if record.updated_at else ""

    return {
        "id": record.id,
        "name": record.name,
        "description": record.description,
        "category": record.category,
        "function_schema": record.function_schema,
        "implementation_type": record.implementation_type,
        "implementation_config": record.implementation_config,
        "is_public": record.is_public,
        "use_count": record.use_count,
        "user_id": record.user_id,
        "created_at": created,
        "updated_at": updated,
    }
|
|
|
|
|
|
@router.delete("/{tool_id}", status_code=200)
async def delete_tool(
    tool_id: str,
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Delete a tool owned by the caller.\f

    Args:
        tool_id: Primary key of the tool to delete.
        db: Database session injected by ``get_db``.
        current_user: Authenticated user; must be the tool's owner.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 404 when the tool does not exist, 403 when the caller
            is not its owner, 400 when its implementation_type is "builtin".
    """
    target = db.query(Tool).filter(Tool.id == tool_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="工具不存在")

    # Ownership is checked first, so non-owners get 403 even for builtin tools.
    if target.user_id != current_user.id:
        raise HTTPException(status_code=403, detail="无权删除此工具")

    # Builtin tools are protected from deletion, even by their owner.
    if target.implementation_type == "builtin":
        raise HTTPException(status_code=400, detail="内置工具不允许删除")

    db.delete(target)
    db.commit()

    confirmation = {"message": "工具已删除"}
    return confirmation
|