feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -8,8 +8,10 @@ POST /api/v1/agent-chat/bare
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.database import get_db
|
||||
@@ -23,6 +25,7 @@ from app.agent_runtime import (
|
||||
AgentConfig,
|
||||
AgentLLMConfig,
|
||||
AgentToolConfig,
|
||||
AgentBudgetConfig,
|
||||
AgentStep,
|
||||
AgentOrchestrator,
|
||||
OrchestratorAgentConfig,
|
||||
@@ -64,6 +67,14 @@ def _make_llm_logger(
|
||||
return _log
|
||||
|
||||
|
||||
async def _sse_stream(gen: AsyncGenerator[dict, None]) -> AsyncGenerator[str, None]:
|
||||
"""将 run_stream 生成的 dict 事件格式化为 SSE 文本流。"""
|
||||
async for event in gen:
|
||||
event_type = event.get("type", "message")
|
||||
data = {k: v for k, v in event.items() if k != "type"}
|
||||
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
session_id: Optional[str] = None
|
||||
@@ -205,6 +216,39 @@ async def chat_bare(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bare/stream")
|
||||
async def chat_bare_stream(
|
||||
req: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""无需 Agent 配置,使用默认设置直接对话(流式 SSE)。"""
|
||||
config = AgentConfig(
|
||||
name="bare_agent",
|
||||
system_prompt="你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。",
|
||||
llm=AgentLLMConfig(
|
||||
model=req.model or (
|
||||
"gpt-4o-mini" if settings.OPENAI_API_KEY and settings.OPENAI_API_KEY != "your-openai-api-key"
|
||||
else "deepseek-v4-flash"
|
||||
),
|
||||
temperature=req.temperature or 0.7,
|
||||
max_iterations=req.max_iterations or 10,
|
||||
),
|
||||
user_id=current_user.id,
|
||||
)
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}", response_model=ChatResponse)
|
||||
async def chat_with_agent(
|
||||
agent_id: str,
|
||||
@@ -225,9 +269,25 @@ async def chat_with_agent(
|
||||
# 查找 agent 节点的配置(或第一个 llm 节点的配置)
|
||||
agent_node_cfg = _find_agent_node_config(nodes)
|
||||
|
||||
# 构建 system prompt,并自动注入智能体名称
|
||||
system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
|
||||
if agent.name:
|
||||
name_prefix = f"你的名字是{agent.name}"
|
||||
if name_prefix not in system_prompt:
|
||||
system_prompt = f"{name_prefix}。\n\n{system_prompt}"
|
||||
|
||||
# 合并执行预算:Agent.budget_config 覆盖默认值
|
||||
budget = AgentBudgetConfig()
|
||||
if agent.budget_config and isinstance(agent.budget_config, dict):
|
||||
bc = agent.budget_config
|
||||
if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
|
||||
budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
|
||||
if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
|
||||
budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
|
||||
|
||||
config = AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
|
||||
system_prompt=system_prompt,
|
||||
llm=AgentLLMConfig(
|
||||
provider=agent_node_cfg.get("provider", "openai"),
|
||||
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
|
||||
@@ -238,6 +298,7 @@ async def chat_with_agent(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
exclude_tools=agent_node_cfg.get("exclude_tools", []),
|
||||
),
|
||||
budget=budget,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
@@ -256,6 +317,68 @@ async def chat_with_agent(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/stream")
|
||||
async def chat_with_agent_stream(
|
||||
agent_id: str,
|
||||
req: ChatRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""与指定的 Agent 对话(流式 SSE)。"""
|
||||
agent = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent 不存在")
|
||||
if agent.user_id and agent.user_id != current_user.id and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="无权访问该 Agent")
|
||||
|
||||
wc = agent.workflow_config or {}
|
||||
nodes = wc.get("nodes", [])
|
||||
agent_node_cfg = _find_agent_node_config(nodes)
|
||||
|
||||
system_prompt = agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。"
|
||||
if agent.name:
|
||||
name_prefix = f"你的名字是{agent.name}"
|
||||
if name_prefix not in system_prompt:
|
||||
system_prompt = f"{name_prefix}。\n\n{system_prompt}"
|
||||
|
||||
budget = AgentBudgetConfig()
|
||||
if agent.budget_config and isinstance(agent.budget_config, dict):
|
||||
bc = agent.budget_config
|
||||
if "max_llm_invocations" in bc and bc["max_llm_invocations"] is not None:
|
||||
budget.max_llm_invocations = max(1, int(bc["max_llm_invocations"]))
|
||||
if "max_tool_calls" in bc and bc["max_tool_calls"] is not None:
|
||||
budget.max_tool_calls = max(1, int(bc["max_tool_calls"]))
|
||||
|
||||
config = AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
llm=AgentLLMConfig(
|
||||
provider=agent_node_cfg.get("provider", "openai"),
|
||||
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
|
||||
temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
|
||||
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
exclude_tools=agent_node_cfg.get("exclude_tools", []),
|
||||
),
|
||||
budget=budget,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
|
||||
"""从工作流节点列表中查找第一个 agent 类型或 llm 类型的节点配置。"""
|
||||
if not nodes:
|
||||
|
||||
251
backend/app/api/knowledge_base.py
Normal file
251
backend/app/api/knowledge_base.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
知识库 RAG API。
|
||||
|
||||
提供知识库管理、文档上传、语义搜索和 RAG 查询接口。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.knowledge_service import (
|
||||
create_knowledge_base,
|
||||
delete_document,
|
||||
delete_knowledge_base,
|
||||
get_knowledge_base,
|
||||
list_documents,
|
||||
list_knowledge_bases,
|
||||
rag_query,
|
||||
search,
|
||||
upload_document,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/knowledge-bases", tags=["knowledge-base"])
|
||||
|
||||
|
||||
# ─── Schema ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class KBCreateRequest(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
chunk_size: int = 500
|
||||
chunk_overlap: int = 50
|
||||
|
||||
|
||||
class KBResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = ""
|
||||
user_id: Optional[str] = ""
|
||||
chunk_size: int = 500
|
||||
chunk_overlap: int = 50
|
||||
doc_count: int = 0
|
||||
created_at: Optional[str] = ""
|
||||
updated_at: Optional[str] = ""
|
||||
|
||||
|
||||
class DocumentResponse(BaseModel):
|
||||
id: str
|
||||
kb_id: str
|
||||
filename: str
|
||||
file_type: str
|
||||
file_size: int = 0
|
||||
status: str = "pending"
|
||||
error_message: Optional[str] = None
|
||||
chunk_count: int = 0
|
||||
created_at: Optional[str] = None
|
||||
|
||||
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
metadata: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
|
||||
results: List[SearchResult] = []
|
||||
|
||||
|
||||
class RAGRequest(BaseModel):
|
||||
query: str
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
|
||||
|
||||
class RAGResponse(BaseModel):
|
||||
query: str
|
||||
context: str = ""
|
||||
sources: List[Dict[str, Any]] = []
|
||||
found: bool = False
|
||||
|
||||
|
||||
# ─── 知识库 CRUD ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("", response_model=KBResponse)
|
||||
async def api_create_kb(
|
||||
req: KBCreateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""创建知识库。"""
|
||||
kb = create_knowledge_base(
|
||||
db=db,
|
||||
name=req.name,
|
||||
user_id=current_user.id,
|
||||
description=req.description,
|
||||
chunk_size=req.chunk_size,
|
||||
chunk_overlap=req.chunk_overlap,
|
||||
)
|
||||
return KBResponse(**kb.to_dict())
|
||||
|
||||
|
||||
@router.get("", response_model=List[KBResponse])
|
||||
async def api_list_kb(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出知识库。"""
|
||||
kbs = list_knowledge_bases(db, user_id=current_user.id)
|
||||
return [KBResponse(**kb.to_dict()) for kb in kbs]
|
||||
|
||||
|
||||
@router.get("/{kb_id}", response_model=KBResponse)
|
||||
async def api_get_kb(
|
||||
kb_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取知识库详情。"""
|
||||
kb = get_knowledge_base(db, kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return KBResponse(**kb.to_dict())
|
||||
|
||||
|
||||
@router.delete("/{kb_id}")
|
||||
async def api_delete_kb(
|
||||
kb_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除知识库。"""
|
||||
ok = delete_knowledge_base(db, kb_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="知识库不存在")
|
||||
return {"message": "知识库已删除"}
|
||||
|
||||
|
||||
# ─── 文档管理 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/{kb_id}/documents", response_model=DocumentResponse)
|
||||
async def api_upload_document(
|
||||
kb_id: str,
|
||||
file: UploadFile = File(...),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""上传文档到知识库(自动解析、分块、生成 Embedding)。"""
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise HTTPException(status_code=400, detail="文件内容为空")
|
||||
|
||||
try:
|
||||
doc = await upload_document(
|
||||
db=db,
|
||||
kb_id=kb_id,
|
||||
filename=file.filename,
|
||||
file_content=content,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if doc.status == "failed":
|
||||
# 返回成功但告知处理失败
|
||||
return DocumentResponse(**doc.to_dict())
|
||||
|
||||
return DocumentResponse(**doc.to_dict())
|
||||
|
||||
|
||||
@router.get("/{kb_id}/documents", response_model=List[DocumentResponse])
|
||||
async def api_list_documents(
|
||||
kb_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出知识库中的文档。"""
|
||||
docs = list_documents(db, kb_id)
|
||||
return [DocumentResponse(**d.to_dict()) for d in docs]
|
||||
|
||||
|
||||
@router.delete("/{kb_id}/documents/{doc_id}")
|
||||
async def api_delete_document(
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除文档。"""
|
||||
ok = delete_document(db, doc_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="文档不存在")
|
||||
return {"message": "文档已删除"}
|
||||
|
||||
|
||||
# ─── 搜索 & RAG ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/{kb_id}/search", response_model=SearchResponse)
|
||||
async def api_search(
|
||||
kb_id: str,
|
||||
req: SearchRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""语义搜索知识库。"""
|
||||
results = await search(
|
||||
db=db,
|
||||
kb_id=kb_id,
|
||||
query=req.query,
|
||||
top_k=req.top_k,
|
||||
min_score=req.min_score,
|
||||
)
|
||||
return SearchResponse(results=[SearchResult(**r) for r in results])
|
||||
|
||||
|
||||
@router.post("/{kb_id}/rag", response_model=RAGResponse)
|
||||
async def api_rag(
|
||||
kb_id: str,
|
||||
req: RAGRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""RAG 查询:搜索相关片段并格式化为上下文。"""
|
||||
result = await rag_query(
|
||||
db=db,
|
||||
kb_id=kb_id,
|
||||
query=req.query,
|
||||
top_k=req.top_k,
|
||||
min_score=req.min_score,
|
||||
)
|
||||
return RAGResponse(**result)
|
||||
@@ -1,21 +1,42 @@
|
||||
"""
|
||||
工具管理API
|
||||
工具市场 API — 管理、测试、发现和安装工具。
|
||||
|
||||
提供内置工具和用户自定义工具(HTTP / 代码段)的 CRUD、测试和执行。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
|
||||
from app.api.auth import get_current_user
|
||||
from app.core.database import get_db
|
||||
from app.models.tool import Tool
|
||||
from app.services.tool_registry import tool_registry
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from pydantic import BaseModel
|
||||
from app.services.tool_registry import tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/tools", tags=["tools"])
|
||||
|
||||
|
||||
# ─── Schema ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ToolCreate(BaseModel):
|
||||
"""创建工具请求"""
|
||||
name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
function_schema: dict
|
||||
implementation_type: str # builtin / http / code / workflow
|
||||
implementation_config: Optional[dict] = None
|
||||
is_public: bool = False
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
category: Optional[str] = None
|
||||
@@ -23,86 +44,39 @@ class ToolCreate(BaseModel):
|
||||
implementation_type: str
|
||||
implementation_config: Optional[dict] = None
|
||||
is_public: bool = False
|
||||
use_count: int = 0
|
||||
user_id: Optional[str] = None
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
"""工具响应"""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
category: Optional[str]
|
||||
function_schema: dict
|
||||
implementation_type: str
|
||||
implementation_config: Optional[dict]
|
||||
is_public: bool
|
||||
use_count: int
|
||||
user_id: Optional[str]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
class TestHTTPRequest(BaseModel):
|
||||
url: str
|
||||
method: str = "GET"
|
||||
headers: Dict[str, str] = {}
|
||||
body: Optional[Dict[str, Any]] = None
|
||||
args: Dict[str, Any] = {}
|
||||
timeout: int = 30
|
||||
|
||||
|
||||
@router.get("", response_model=List[ToolResponse])
|
||||
async def list_tools(
|
||||
category: Optional[str] = Query(None, description="工具分类"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工具列表"""
|
||||
query = db.query(Tool).filter(Tool.is_public == True)
|
||||
|
||||
if category:
|
||||
query = query.filter(Tool.category == category)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
Tool.name.contains(search) |
|
||||
Tool.description.contains(search)
|
||||
)
|
||||
|
||||
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append({
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"category": tool.category,
|
||||
"function_schema": tool.function_schema,
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
})
|
||||
|
||||
return result
|
||||
class TestCodeRequest(BaseModel):
|
||||
source: str
|
||||
args: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_tools():
|
||||
"""获取内置工具列表"""
|
||||
schemas = tool_registry.get_all_tool_schemas()
|
||||
return schemas
|
||||
class TestResponse(BaseModel):
|
||||
success: bool
|
||||
elapsed_ms: Optional[int] = None
|
||||
result: Optional[Any] = None
|
||||
status_code: Optional[int] = None
|
||||
body: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ToolResponse)
|
||||
async def get_tool(
|
||||
tool_id: str,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取工具详情"""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
# ─── 工具函数 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _tool_to_dict(tool: Tool) -> dict:
|
||||
return {
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
@@ -115,22 +89,89 @@ async def get_tool(
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else "",
|
||||
}
|
||||
|
||||
|
||||
# ─── 工具市场浏览 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("", response_model=List[ToolResponse])
|
||||
async def list_tools(
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
scope: Optional[str] = Query("public", description="public / mine / all"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user),
|
||||
):
|
||||
"""浏览工具市场。"""
|
||||
query = db.query(Tool)
|
||||
|
||||
if scope == "public":
|
||||
query = query.filter(Tool.is_public == True)
|
||||
elif scope == "mine":
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="需登录")
|
||||
query = query.filter(Tool.user_id == current_user.id)
|
||||
|
||||
if category:
|
||||
query = query.filter(Tool.category == category)
|
||||
if search:
|
||||
query = query.filter(
|
||||
Tool.name.contains(search) | Tool.description.contains(search)
|
||||
)
|
||||
|
||||
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
|
||||
return [_tool_to_dict(t) for t in tools]
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[str])
|
||||
async def list_categories(db: Session = Depends(get_db)):
|
||||
"""列出所有工具分类。"""
|
||||
rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
|
||||
cats = sorted(set(r[0] for r in rows if r[0]))
|
||||
# 加上常用分类
|
||||
defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]
|
||||
for d in defaults:
|
||||
if d not in cats:
|
||||
cats.append(d)
|
||||
return cats
|
||||
|
||||
|
||||
@router.get("/builtin")
|
||||
async def list_builtin_tools():
|
||||
"""列出所有内置工具(OpenAI Function 格式)。"""
|
||||
return tool_registry.get_all_tool_schemas()
|
||||
|
||||
|
||||
@router.get("/{tool_id}", response_model=ToolResponse)
|
||||
async def get_tool(tool_id: str, db: Session = Depends(get_db)):
|
||||
"""获取工具详情。"""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
|
||||
# ─── 工具创建 / 更新 / 删除 ──────────────────────────────────
|
||||
|
||||
|
||||
@router.post("", response_model=ToolResponse, status_code=201)
|
||||
async def create_tool(
|
||||
tool_data: ToolCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建工具"""
|
||||
# 检查工具名称是否已存在
|
||||
"""创建自定义工具。"""
|
||||
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
|
||||
|
||||
valid_types = {"builtin", "http", "code", "workflow"}
|
||||
if tool_data.implementation_type not in valid_types:
|
||||
raise HTTPException(status_code=400,
|
||||
detail=f"无效的实现类型: {tool_data.implementation_type}")
|
||||
|
||||
tool = Tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
@@ -139,28 +180,22 @@ async def create_tool(
|
||||
implementation_type=tool_data.implementation_type,
|
||||
implementation_config=tool_data.implementation_config,
|
||||
is_public=tool_data.is_public,
|
||||
user_id=current_user.id
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
db.add(tool)
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
return {
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"category": tool.category,
|
||||
"function_schema": tool.function_schema,
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
|
||||
|
||||
# 刷新注册表
|
||||
tool_registry._custom_tool_configs[tool.name] = {
|
||||
**(tool.implementation_config or {}),
|
||||
"_type": tool.implementation_type,
|
||||
"_db_id": tool.id,
|
||||
}
|
||||
tool_registry._tool_schemas[tool.name] = tool.function_schema
|
||||
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
|
||||
@router.put("/{tool_id}", response_model=ToolResponse)
|
||||
@@ -168,23 +203,20 @@ async def update_tool(
|
||||
tool_id: str,
|
||||
tool_data: ToolCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新工具"""
|
||||
"""更新工具。"""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 检查权限(只有创建者可以更新)
|
||||
if tool.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权更新此工具")
|
||||
|
||||
# 检查名称冲突
|
||||
|
||||
if tool_data.name != tool.name:
|
||||
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"工具名称 '{tool_data.name}' 已存在")
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
|
||||
|
||||
tool.name = tool_data.name
|
||||
tool.description = tool_data.description
|
||||
tool.category = tool_data.category
|
||||
@@ -192,47 +224,94 @@ async def update_tool(
|
||||
tool.implementation_type = tool_data.implementation_type
|
||||
tool.implementation_config = tool_data.implementation_config
|
||||
tool.is_public = tool_data.is_public
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
|
||||
# 转换为响应格式,确保日期时间字段转换为字符串
|
||||
return {
|
||||
"id": tool.id,
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"category": tool.category,
|
||||
"function_schema": tool.function_schema,
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
"updated_at": tool.updated_at.isoformat() if tool.updated_at else ""
|
||||
}
|
||||
|
||||
# 刷新注册表
|
||||
if tool.name in tool_registry._custom_tool_configs:
|
||||
tool_registry._custom_tool_configs[tool.name] = {
|
||||
**(tool.implementation_config or {}),
|
||||
"_type": tool.implementation_type,
|
||||
"_db_id": tool.id,
|
||||
}
|
||||
if tool.name in tool_registry._tool_schemas:
|
||||
tool_registry._tool_schemas[tool.name] = tool.function_schema
|
||||
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
|
||||
@router.delete("/{tool_id}", status_code=200)
|
||||
@router.delete("/{tool_id}")
|
||||
async def delete_tool(
|
||||
tool_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""删除工具"""
|
||||
"""删除工具。"""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
|
||||
# 检查权限(只有创建者可以删除)
|
||||
if tool.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="无权删除此工具")
|
||||
|
||||
# 内置工具不允许删除
|
||||
if tool.implementation_type == "builtin":
|
||||
raise HTTPException(status_code=400, detail="内置工具不允许删除")
|
||||
|
||||
|
||||
db.delete(tool)
|
||||
db.commit()
|
||||
|
||||
|
||||
# 清理注册表
|
||||
tool_registry._custom_tool_configs.pop(tool.name, None)
|
||||
tool_registry._tool_schemas.pop(tool.name, None)
|
||||
|
||||
return {"message": "工具已删除"}
|
||||
|
||||
|
||||
# ─── 工具测试 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/test/http", response_model=TestResponse)
|
||||
async def test_http_tool(
|
||||
req: TestHTTPRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""测试 HTTP 工具(不保存到数据库)。"""
|
||||
result = await tool_registry.test_http_tool(
|
||||
url=req.url,
|
||||
method=req.method,
|
||||
headers=req.headers,
|
||||
body=req.body,
|
||||
args=req.args,
|
||||
timeout=req.timeout,
|
||||
)
|
||||
return TestResponse(**result)
|
||||
|
||||
|
||||
@router.post("/test/code", response_model=TestResponse)
|
||||
async def test_code_tool(
|
||||
req: TestCodeRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""测试代码工具(不保存到数据库)。"""
|
||||
result = await tool_registry.test_code_tool(
|
||||
source=req.source,
|
||||
args=req.args,
|
||||
)
|
||||
return TestResponse(**result)
|
||||
|
||||
|
||||
# ─── 使用计数 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/{tool_id}/use")
|
||||
async def record_tool_use(
|
||||
tool_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""记录工具使用次数(Agent 执行时自动调用)。"""
|
||||
tool = db.query(Tool).filter(Tool.id == tool_id).first()
|
||||
if not tool:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
tool.use_count = (tool.use_count or 0) + 1
|
||||
db.commit()
|
||||
return {"use_count": tool.use_count}
|
||||
|
||||
Reference in New Issue
Block a user