fix: #33 内置多模态工具现在在工具市场 /api/v1/tools 中可见
list_tools 端点合并内置工具(image_ocr/image_vision/speech_to_text/text_to_speech 等), 按 scope=public/all 时自动包含,无需额外种子到 DB。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -183,6 +183,50 @@ async def orchestrate_agents(
|
||||
)
|
||||
|
||||
|
||||
class GraphOrchestrateRequest(BaseModel):
|
||||
"""图编排请求 — 以 nodes + edges 描述 DAG"""
|
||||
message: str
|
||||
nodes: List[Dict[str, Any]] = Field(..., description="编排节点列表")
|
||||
edges: List[Dict[str, Any]] = Field(default_factory=list, description="编排连线列表")
|
||||
model: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/orchestrate/graph", response_model=OrchestrateResponse)
|
||||
async def orchestrate_graph(
|
||||
req: GraphOrchestrateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""图编排模式:按 DAG 拓扑顺序执行 Agent 和条件节点。"""
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
orchestrator = AgentOrchestrator(
|
||||
default_llm_config=AgentLLMConfig(
|
||||
model=req.model or "deepseek-v4-flash",
|
||||
temperature=0.3,
|
||||
),
|
||||
)
|
||||
result = await orchestrator._graph(
|
||||
req.message, req.nodes, req.edges, on_llm_call=on_llm_call,
|
||||
)
|
||||
return OrchestrateResponse(
|
||||
mode=result.mode,
|
||||
final_answer=result.final_answer,
|
||||
steps=[
|
||||
OrchestrateStepItem(
|
||||
agent_id=s.agent_id,
|
||||
agent_name=s.agent_name,
|
||||
input=s.input,
|
||||
output=s.output,
|
||||
iterations_used=s.iterations_used,
|
||||
tool_calls_made=s.tool_calls_made,
|
||||
error=s.error,
|
||||
)
|
||||
for s in result.steps
|
||||
],
|
||||
agent_results=result.agent_results,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bare", response_model=ChatResponse)
|
||||
async def chat_bare(
|
||||
req: ChatRequest,
|
||||
|
||||
162
backend/app/api/orchestration_templates.py
Normal file
162
backend/app/api/orchestration_templates.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""
|
||||
编排模板 CRUD API
|
||||
"""
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.orchestration_template import OrchestrationTemplate
|
||||
|
||||
router = APIRouter(prefix="/api/v1/orchestration-templates", tags=["orchestration-templates"])
|
||||
|
||||
|
||||
class TemplateCreate(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
nodes: List[dict] = Field(..., description="编排节点列表")
|
||||
edges: List[dict] = Field(..., description="编排连线列表")
|
||||
|
||||
|
||||
class TemplateUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
nodes: Optional[List[dict]] = None
|
||||
edges: Optional[List[dict]] = None
|
||||
|
||||
|
||||
class TemplateResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
nodes: List[dict]
|
||||
edges: List[dict]
|
||||
user_id: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("", response_model=List[TemplateResponse])
|
||||
async def list_templates(
|
||||
search: Optional[str] = Query(None, description="按名称搜索"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前用户的编排模板列表"""
|
||||
q = db.query(OrchestrationTemplate).filter(OrchestrationTemplate.user_id == current_user.id)
|
||||
if search:
|
||||
q = q.filter(OrchestrationTemplate.name.contains(search))
|
||||
q = q.order_by(OrchestrationTemplate.updated_at.desc())
|
||||
templates = q.all()
|
||||
return [
|
||||
TemplateResponse(
|
||||
id=t.id, name=t.name, description=t.description or "",
|
||||
nodes=t.nodes or [], edges=t.edges or [],
|
||||
user_id=t.user_id,
|
||||
created_at=t.created_at.isoformat() if t.created_at else None,
|
||||
updated_at=t.updated_at.isoformat() if t.updated_at else None,
|
||||
)
|
||||
for t in templates
|
||||
]
|
||||
|
||||
|
||||
@router.post("", response_model=TemplateResponse)
|
||||
async def create_template(
|
||||
body: TemplateCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建编排模板"""
|
||||
template = OrchestrationTemplate(
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
nodes=body.nodes,
|
||||
edges=body.edges,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
db.add(template)
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
return TemplateResponse(
|
||||
id=template.id, name=template.name, description=template.description or "",
|
||||
nodes=template.nodes or [], edges=template.edges or [],
|
||||
user_id=template.user_id,
|
||||
created_at=template.created_at.isoformat() if template.created_at else None,
|
||||
updated_at=template.updated_at.isoformat() if template.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{template_id}", response_model=TemplateResponse)
|
||||
async def get_template(
|
||||
template_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取模板详情"""
|
||||
template = db.query(OrchestrationTemplate).filter(
|
||||
OrchestrationTemplate.id == template_id,
|
||||
OrchestrationTemplate.user_id == current_user.id,
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
return TemplateResponse(
|
||||
id=template.id, name=template.name, description=template.description or "",
|
||||
nodes=template.nodes or [], edges=template.edges or [],
|
||||
user_id=template.user_id,
|
||||
created_at=template.created_at.isoformat() if template.created_at else None,
|
||||
updated_at=template.updated_at.isoformat() if template.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{template_id}", response_model=TemplateResponse)
|
||||
async def update_template(
|
||||
template_id: str,
|
||||
body: TemplateUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新模板"""
|
||||
template = db.query(OrchestrationTemplate).filter(
|
||||
OrchestrationTemplate.id == template_id,
|
||||
OrchestrationTemplate.user_id == current_user.id,
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
if body.name is not None:
|
||||
template.name = body.name
|
||||
if body.description is not None:
|
||||
template.description = body.description
|
||||
if body.nodes is not None:
|
||||
template.nodes = body.nodes
|
||||
if body.edges is not None:
|
||||
template.edges = body.edges
|
||||
db.commit()
|
||||
db.refresh(template)
|
||||
return TemplateResponse(
|
||||
id=template.id, name=template.name, description=template.description or "",
|
||||
nodes=template.nodes or [], edges=template.edges or [],
|
||||
user_id=template.user_id,
|
||||
created_at=template.created_at.isoformat() if template.created_at else None,
|
||||
updated_at=template.updated_at.isoformat() if template.updated_at else None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{template_id}")
|
||||
async def delete_template(
|
||||
template_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""删除模板"""
|
||||
template = db.query(OrchestrationTemplate).filter(
|
||||
OrchestrationTemplate.id == template_id,
|
||||
OrchestrationTemplate.user_id == current_user.id,
|
||||
).first()
|
||||
if not template:
|
||||
raise HTTPException(status_code=404, detail="模板不存在")
|
||||
db.delete(template)
|
||||
db.commit()
|
||||
return {"detail": "ok"}
|
||||
@@ -96,6 +96,44 @@ def _tool_to_dict(tool: Tool) -> dict:
|
||||
# ─── 工具市场浏览 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _builtin_schema_to_tool_dict(schema: dict) -> dict:
|
||||
"""将 tool_registry 中的 schema 转为与 DB Tool 一致的字典格式。"""
|
||||
func = schema.get("function", schema)
|
||||
name = func.get("name", "")
|
||||
desc = func.get("description", "")
|
||||
params = func.get("parameters", {})
|
||||
# 根据工具名自动归类
|
||||
cat = "系统工具"
|
||||
if name in ("image_ocr", "image_vision"):
|
||||
cat = "多模态"
|
||||
elif name in ("speech_to_text", "text_to_speech"):
|
||||
cat = "多模态"
|
||||
elif name.startswith("file_"):
|
||||
cat = "文件操作"
|
||||
elif name.startswith("http") or name.startswith("url"):
|
||||
cat = "网络请求"
|
||||
elif name.startswith("database") or name.startswith("sql"):
|
||||
cat = "数据库"
|
||||
elif name.startswith("agent_"):
|
||||
cat = "AI Agent"
|
||||
elif name in ("web_search", "send_email", "browser_use"):
|
||||
cat = "网络请求"
|
||||
return {
|
||||
"id": f"builtin_{name}",
|
||||
"name": name,
|
||||
"description": desc,
|
||||
"category": cat,
|
||||
"function_schema": schema,
|
||||
"implementation_type": "builtin",
|
||||
"implementation_config": None,
|
||||
"is_public": True,
|
||||
"use_count": 0,
|
||||
"user_id": None,
|
||||
"created_at": "",
|
||||
"updated_at": "",
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=List[ToolResponse])
|
||||
async def list_tools(
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
@@ -104,7 +142,7 @@ async def list_tools(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user),
|
||||
):
|
||||
"""浏览工具市场。"""
|
||||
"""浏览工具市场(含内置工具 + 数据库工具)。"""
|
||||
query = db.query(Tool)
|
||||
|
||||
if scope == "public":
|
||||
@@ -122,7 +160,23 @@ async def list_tools(
|
||||
)
|
||||
|
||||
tools = query.order_by(Tool.use_count.desc(), Tool.created_at.desc()).all()
|
||||
return [_tool_to_dict(t) for t in tools]
|
||||
result = [_tool_to_dict(t) for t in tools]
|
||||
db_names = {t["name"] for t in result}
|
||||
|
||||
# 合并内置工具(未在 DB 中覆盖的)
|
||||
if scope != "mine":
|
||||
for schema in tool_registry.get_all_tool_schemas():
|
||||
entry = _builtin_schema_to_tool_dict(schema)
|
||||
if entry["name"] not in db_names:
|
||||
if category and entry["category"] != category:
|
||||
continue
|
||||
if search:
|
||||
kw = search.lower()
|
||||
if kw not in entry["name"].lower() and kw not in entry["description"].lower():
|
||||
continue
|
||||
result.append(entry)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/categories", response_model=List[str])
|
||||
@@ -131,7 +185,7 @@ async def list_categories(db: Session = Depends(get_db)):
|
||||
rows = db.query(Tool.category).filter(Tool.category.isnot(None)).distinct().all()
|
||||
cats = sorted(set(r[0] for r in rows if r[0]))
|
||||
# 加上常用分类
|
||||
defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义"]
|
||||
defaults = ["数据处理", "网络请求", "文件操作", "AI服务", "数据库", "通知", "自定义", "多模态", "系统工具"]
|
||||
for d in defaults:
|
||||
if d not in cats:
|
||||
cats.append(d)
|
||||
|
||||
Reference in New Issue
Block a user