feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖

- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
renjianbo
2026-05-01 22:30:46 +08:00
parent 036f533881
commit 7b9e0826de
35 changed files with 4353 additions and 365 deletions

View File

@@ -8,8 +8,10 @@ POST /api/v1/agent-chat/bare
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
import json
from typing import Any, AsyncGenerator, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from app.core.database import get_db
@@ -23,6 +25,7 @@ from app.agent_runtime import (
AgentConfig,
AgentLLMConfig,
AgentToolConfig,
AgentBudgetConfig,
AgentStep,
AgentOrchestrator,
OrchestratorAgentConfig,
@@ -64,6 +67,14 @@ def _make_llm_logger(
return _log
async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
"""将 run_stream 生成的 dict 事件格式化为 SSE 文本流。"""
async for event in gen:
event_type = event.get("type", "message")
data = {k: v for k, v in event.items() if k != "type"}
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
class ChatRequest(BaseModel):
message: str
session_id: Optional[str] = None
@@ -205,6 +216,39 @@ async def chat_bare(
)
@router.post("/bare/stream")
async def chat_bare_stream(
req: ChatRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""无需 Agent 配置,使用默认设置直接对话(流式 SSE"""
config = AgentConfig(
name="bare_agent",
system_prompt="你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。",
llm=AgentLLMConfig(
model=req.model or (
"gpt-4o-mini" if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key"
else "deepseek-v4-flash"
),
temperature=req.temperature or 0.7,
max_iterations=req.max_iterations or 10,
),
user_id=current_user.id,
)
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
return StreamingResponse(
_sse_stream(runtime.run_stream(req.message)),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.post("/{agent_id}", response_model=ChatResponse)
async def chat_with_agent(
agent_id: str,
@@ -225,9 +269,25 @@ async def chat_with_agent(
# 查找 agent 节点的配置(或第一个 llm 节点的配置)
agent_node_cfg = _find_agent_node_config(nodes)
# 构建 system prompt并自动注入智能体名称
system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
if agent.name:
name_prefix = f"你的名字是{agent.name}"
if name_prefix not in system_prompt:
system_prompt = f"{name_prefix}\n\n{system_prompt}"
# 合并执行预算Agent.budget_config 覆盖默认值
budget = AgentBudgetConfig()
if agent.budget_config and isinstance(agent.budget_config, dict):
bc = agent.budget_config
if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
config = AgentConfig(
name=agent.name,
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
system_prompt=system_prompt,
llm=AgentLLMConfig(
provider=agent_node_cfg.get("provider", "openai"),
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
@@ -238,6 +298,7 @@ async def chat_with_agent(
include_tools=agent_node_cfg.get("tools", []),
exclude_tools=agent_node_cfg.get("exclude_tools", []),
),
budget=budget,
user_id=current_user.id,
)
@@ -256,6 +317,68 @@ async def chat_with_agent(
)
@router.post("/{agent_id}/stream")
async def chat_with_agent_stream(
agent_id: str,
req: ChatRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""与指定的 Agent 对话(流式 SSE"""
agent = db.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
raise HTTPException(status_code=404, detail="Agent 不存在")
if agent.user_id and agent.user_id != current_user.id and current_user.role != "admin":
raise HTTPException(status_code=403, detail="无权访问该 Agent")
wc = agent.workflow_config or {}
nodes = wc.get("nodes", [])
agent_node_cfg = _find_agent_node_config(nodes)
system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
if agent.name:
name_prefix = f"你的名字是{agent.name}"
if name_prefix not in system_prompt:
system_prompt = f"{name_prefix}\n\n{system_prompt}"
budget = AgentBudgetConfig()
if agent.budget_config and isinstance(agent.budget_config, dict):
bc = agent.budget_config
if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
config = AgentConfig(
name=agent.name,
system_prompt=system_prompt,
llm=AgentLLMConfig(
provider=agent_node_cfg.get("provider", "openai"),
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
),
tools=AgentToolConfig(
include_tools=agent_node_cfg.get("tools", []),
exclude_tools=agent_node_cfg.get("exclude_tools", []),
),
budget=budget,
user_id=current_user.id,
)
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
return StreamingResponse(
_sse_stream(runtime.run_stream(req.message)),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
"""从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。"""
if not nodes:

View File

@@ -0,0 +1,251 @@
"""
知识库 RAG API。
提供知识库管理、文档上传、语义搜索和 RAG 查询接口。
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.services.knowledge_service import (
create_knowledge_base,
delete_document,
delete_knowledge_base,
get_knowledge_base,
list_documents,
list_knowledge_bases,
rag_query,
search,
upload_document,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"])
# ─── Schema ──────────────────────────────────────────────────────
class KBCreateRequest(BaseModel):
name: str
description: str = ""
chunk_size: int = 500
chunk_overlap: int = 50
class KBResponse(BaseModel):
id: str
name: str
description: Optional[str] = ""
user_id: Optional[str] = ""
chunk_size: int = 500
chunk_overlap: int = 50
doc_count: int = 0
created_at: Optional[str] = ""
updated_at: Optional[str] = ""
class DocumentResponse(BaseModel):
id: str
kb_id: str
filename: str
file_type: str
file_size: int = 0
status: str = "pending"
error_message: Optional[str] = None
chunk_count: int = 0
created_at: Optional[str] = None
class SearchRequest(BaseModel):
query: str
top_k: int = 5
min_score: float = 0.3
class SearchResult(BaseModel):
chunk_id: str
content: str
score: float
metadata: Dict[str, Any] = {}
class SearchResponse(BaseModel):
results: List[SearchResult] = []
class RAGRequest(BaseModel):
query: str
top_k: int = 5
min_score: float = 0.3
class RAGResponse(BaseModel):
query: str
context: str = ""
sources: List[Dict[str, Any]] = []
found: bool = False
# ─── 知识库 CRUD ──────────────────────────────────────────────
@router.post("", response_model=KBResponse)
async def api_create_kb(
req: KBCreateRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""创建知识库。"""
kb = create_knowledge_base(
db=db,
name=req.name,
user_id=current_user.id,
description=req.description,
chunk_size=req.chunk_size,
chunk_overlap=req.chunk_overlap,
)
return KBResponse(**kb.to_dict())
@router.get("", response_model=List[KBResponse])
async def api_list_kb(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""列出知识库。"""
kbs = list_knowledge_bases(db, user_id=current_user.id)
return [KBResponse(**kb.to_dict()) for kb in kbs]
@router.get("/{kb_id}", response_model=KBResponse)
async def api_get_kb(
kb_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取知识库详情。"""
kb = get_knowledge_base(db, kb_id)
if not kb:
raise HTTPException(status_code=404, detail="知识库不存在")
return KBResponse(**kb.to_dict())
@router.delete("/{kb_id}")
async def api_delete_kb(
kb_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除知识库。"""
ok = delete_knowledge_base(db, kb_id)
if not ok:
raise HTTPException(status_code=404, detail="知识库不存在")
return {"message": "知识库已删除"}
# ─── 文档管理 ──────────────────────────────────────────────────
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
async def api_upload_document(
kb_id: str,
file: UploadFile = File(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""上传文档到知识库(自动解析、分块、生成 Embedding"""
if not file.filename:
raise HTTPException(status_code=400, detail="文件名不能为空")
content = await file.read()
if not content:
raise HTTPException(status_code=400, detail="文件内容为空")
try:
doc = await upload_document(
db=db,
kb_id=kb_id,
filename=file.filename,
file_content=content,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
if doc.status == "failed":
# 返回成功但告知处理失败
return DocumentResponse(**doc.to_dict())
return DocumentResponse(**doc.to_dict())
@router.get("/{kb_id}/documents", response_model=List[DocumentResponse])
async def api_list_documents(
kb_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""列出知识库中的文档。"""
docs = list_documents(db, kb_id)
return [DocumentResponse(**d.to_dict()) for d in docs]
@router.delete("/{kb_id}/documents/{doc_id}")
async def api_delete_document(
kb_id: str,
doc_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除文档。"""
ok = delete_document(db, doc_id)
if not ok:
raise HTTPException(status_code=404, detail="文档不存在")
return {"message": "文档已删除"}
# ─── 搜索 & RAG ────────────────────────────────────────────────
@router.post("/{kb_id}/search", response_model=SearchResponse)
async def api_search(
kb_id: str,
req: SearchRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""语义搜索知识库。"""
results = await search(
db=db,
kb_id=kb_id,
query=req.query,
top_k=req.top_k,
min_score=req.min_score,
)
return SearchResponse(results=[SearchResult(**r) for r in results])
@router.post("/{kb_id}/rag", response_model=RAGResponse)
async def api_rag(
kb_id: str,
req: RAGRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""RAG 查询:搜索相关片段并格式化为上下文。"""
result = await rag_query(
db=db,
kb_id=kb_id,
query=req.query,
top_k=req.top_k,
min_score=req.min_score,
)
return RAGResponse(**result)

View File

@@ -1,21 +1,42 @@
"""
工具管理API
工具市场 API — 管理、测试、发现和安装工具。
提供内置工具和用户自定义工具HTTP / 代码段)的 CRUD、测试和执行。
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from typing import List, Optional
from app.api.auth import get_current_user
from app.core.database import get_db
from app.models.tool import Tool
from app.services.tool_registry import tool_registry
from app.api.auth import get_current_user
from app.models.user import User
from pydantic import BaseModel
from app.services.tool_registry import tool_registry
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/tools", tags=["tools"])
# ─── Schema ──────────────────────────────────────────────────────
class ToolCreate(BaseModel):
"""创建工具请求"""
name: str
description: str
category: Optional[str] = None
function_schema: dict
implementation_type: str # builtin / http / code / workflow
implementation_config: Optional[dict] = None
is_public: bool = False
class ToolResponse(BaseModel):
id: str
name: str
description: str
category: Optional[str] = None
@@ -23,86 +44,39 @@ class ToolCreate(BaseModel):
implementation_type: str
implementation_config: Optional[dict] = None
is_public: bool = False
use_count: int = 0
user_id: Optional[str] = None
created_at: str = ""
updated_at: str = ""
class ToolResponse(BaseModel):
"""工具响应"""
id: str
name: str
description: str
category: Optional[str]
function_schema: dict
implementation_type: str
implementation_config: Optional[dict]
is_public: bool
use_count: int
user_id: Optional[str]
created_at: str
updated_at: str
class Config:
from_attributes = True
class TestHTTPRequest(BaseModel):
url: str
method: str = "GET"
headers: Dict[str, str] = {}
body: Optional[Dict[str, Any]] = None
args: Dict[str, Any] = {}
timeout: int = 30
@router.get("", response_model=List[ToolResponse])
async def list_tools(
category: Optional[str] = Query(None, description="工具分类"),
search: Optional[str] = Query(None, description="搜索关键词"),
db: Session = Depends(get_db)
):
"""获取工具列表"""
query = db.query(Tool).filter(Tool.is_public == True)
if category:
query = query.filter(Tool.category == category)
if search:
query = query.filter(
Tool.name.contains(search) |
Tool.description.contains(search)
)
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
# 转换为响应格式,确保日期时间字段转换为字符串
result = []
for tool in tools:
result.append({
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
})
return result
class TestCodeRequest(BaseModel):
source: str
args: Dict[str, Any] = {}
@router.get("/builtin")
async def list_builtin_tools():
"""获取内置工具列表"""
schemas = tool_registry.get_all_tool_schemas()
return schemas
class TestResponse(BaseModel):
success: bool
elapsed_ms: Optional[int] = None
result: Optional[Any] = None
status_code: Optional[int] = None
body: Optional[str] = None
error: Optional[str] = None
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(
tool_id: str,
db: Session = Depends(get_db)
):
"""获取工具详情"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 转换为响应格式,确保日期时间字段转换为字符串
# ─── 工具函数 ──────────────────────────────────────────────────
def _tool_to_dict(tool: Tool) -> dict:
return {
"id": tool.id,
"name": tool.name,
@@ -115,22 +89,89 @@ async def get_tool(
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
}
# ─── 工具市场浏览 ──────────────────────────────────────────────
@router.get("", response_model=List[ToolResponse])
async def list_tools(
category: Optional[str] = Query(None, description="按分类筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
scope: Optional[str] = Query("public", description="public / mine / all"),
db: Session = Depends(get_db),
current_user: Optional[User] = Depends(get_current_user),
):
"""浏览工具市场。"""
query = db.query(Tool)
if scope == "public":
query = query.filter(Tool.is_public == True)
elif scope == "mine":
if not current_user:
raise HTTPException(status_code=401, detail="需登录")
query = query.filter(Tool.user_id == current_user.id)
if category:
query = query.filter(Tool.category == category)
if search:
query = query.filter(
Tool.name.contains(search) | Tool.description.contains(search)
)
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
return [_tool_to_dict(t) for t in tools]
@router.get("/categories", response_model=List[str])
async def list_categories(db: Session = Depends(get_db)):
"""列出所有工具分类。"""
rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
cats = sorted(set(r[0] for r in rows if r[0]))
# 加上常用分类
defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]
for d in defaults:
if d not in cats:
cats.append(d)
return cats
@router.get("/builtin")
async def list_builtin_tools():
"""列出所有内置工具OpenAI Function 格式)。"""
return tool_registry.get_all_tool_schemas()
@router.get("/{tool_id}", response_model=ToolResponse)
async def get_tool(tool_id: str, db: Session = Depends(get_db)):
"""获取工具详情。"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
return _tool_to_dict(tool)
# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
tool_data: ToolCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""创建工具"""
# 检查工具名称是否已存在
"""创建自定义工具"""
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
valid_types = {"builtin", "http", "code", "workflow"}
if tool_data.implementation_type not in valid_types:
raise HTTPException(status_code=400,
detail=f"无效的实现类型: {tool_data.implementation_type}")
tool = Tool(
name=tool_data.name,
description=tool_data.description,
@@ -139,28 +180,22 @@ async def create_tool(
implementation_type=tool_data.implementation_type,
implementation_config=tool_data.implementation_config,
is_public=tool_data.is_public,
user_id=current_user.id
user_id=current_user.id,
)
db.add(tool)
db.commit()
db.refresh(tool)
# 转换为响应格式,确保日期时间字段转换为字符串
return {
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
# 刷新注册表
tool_registry._custom_tool_configs[tool.name] = {
**(tool.implementation_config or {}),
"_type": tool.implementation_type,
"_db_id": tool.id,
}
tool_registry._tool_schemas[tool.name] = tool.function_schema
return _tool_to_dict(tool)
@router.put("/{tool_id}", response_model=ToolResponse)
@@ -168,23 +203,20 @@ async def update_tool(
tool_id: str,
tool_data: ToolCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""更新工具"""
"""更新工具"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 检查权限(只有创建者可以更新)
if tool.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权更新此工具")
# 检查名称冲突
if tool_data.name != tool.name:
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
tool.name = tool_data.name
tool.description = tool_data.description
tool.category = tool_data.category
@@ -192,47 +224,94 @@ async def update_tool(
tool.implementation_type = tool_data.implementation_type
tool.implementation_config = tool_data.implementation_config
tool.is_public = tool_data.is_public
db.commit()
db.refresh(tool)
# 转换为响应格式,确保日期时间字段转换为字符串
return {
"id": tool.id,
"name": tool.name,
"description": tool.description,
"category": tool.category,
"function_schema": tool.function_schema,
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
}
# 刷新注册表
if tool.name in tool_registry._custom_tool_configs:
tool_registry._custom_tool_configs[tool.name] = {
**(tool.implementation_config or {}),
"_type": tool.implementation_type,
"_db_id": tool.id,
}
if tool.name in tool_registry._tool_schemas:
tool_registry._tool_schemas[tool.name] = tool.function_schema
return _tool_to_dict(tool)
@router.delete("/{tool_id}", status_code=200)
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
):
"""删除工具"""
"""删除工具"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 检查权限(只有创建者可以删除)
if tool.user_id != current_user.id:
raise HTTPException(status_code=403, detail="无权删除此工具")
# 内置工具不允许删除
if tool.implementation_type == "builtin":
raise HTTPException(status_code=400, detail="内置工具不允许删除")
db.delete(tool)
db.commit()
# 清理注册表
tool_registry._custom_tool_configs.pop(tool.name, None)
tool_registry._tool_schemas.pop(tool.name, None)
return {"message": "工具已删除"}
# ─── 工具测试 ──────────────────────────────────────────────────
@router.post("/test/http", response_model=TestResponse)
async def test_http_tool(
req: TestHTTPRequest,
current_user: User = Depends(get_current_user),
):
"""测试 HTTP 工具(不保存到数据库)。"""
result = await tool_registry.test_http_tool(
url=req.url,
method=req.method,
headers=req.headers,
body=req.body,
args=req.args,
timeout=req.timeout,
)
return TestResponse(**result)
@router.post("/test/code", response_model=TestResponse)
async def test_code_tool(
req: TestCodeRequest,
current_user: User = Depends(get_current_user),
):
"""测试代码工具(不保存到数据库)。"""
result = await tool_registry.test_code_tool(
source=req.source,
args=req.args,
)
return TestResponse(**result)
# ─── 使用计数 ──────────────────────────────────────────────────
@router.post("/{tool_id}/use")
async def record_tool_use(
tool_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""记录工具使用次数Agent 执行时自动调用)。"""
tool = db.query(Tool).filter(Tool.id == tool_id).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
tool.use_count = (tool.use_count or 0) + 1
db.commit()
return {"use_count": tool.use_count}