fix: delete agent 500 error + dynamic personality + deployment guide
- Fix delete agent 500: clean up FK records (agent_llm_logs, permissions, schedules, executions, team_members) and unbind goals/tasks before delete - Remove hardcoded personality templates in Android, replace with dynamic system prompt generation from name + description - Set promptSectionsEnabled=false to bypass PromptComposer for personality - Add Tencent Cloud Linux deployment guide (Docker Compose) - Accumulated backend service updates, frontend UI fixes, Android app changes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
319
backend/app/api/agent_branches.py
Normal file
319
backend/app/api/agent_branches.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
对话分支 API — 支持从任意历史点分叉对话
|
||||
|
||||
参考 Claude Code src/commands/branch/branch.ts
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.agent import Agent
|
||||
from app.agent_runtime import (
|
||||
AgentRuntime,
|
||||
AgentConfig,
|
||||
AgentLLMConfig,
|
||||
AgentToolConfig,
|
||||
AgentMemoryConfig,
|
||||
AgentStep,
|
||||
)
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/agent-chat", tags=["agent-branches"])
|
||||
|
||||
|
||||
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
|
||||
if not nodes:
|
||||
return {}
|
||||
for node in nodes:
|
||||
typ = node.get("type", "")
|
||||
if typ in ("agent", "llm", "template"):
|
||||
return node.get("data") or {}
|
||||
return {}
|
||||
|
||||
|
||||
def _build_memory_config_from_node(agent_node_cfg: dict) -> AgentMemoryConfig:
|
||||
from app.core.compaction_config import CompactionConfig
|
||||
compaction_raw = agent_node_cfg.get("compaction")
|
||||
if isinstance(compaction_raw, dict):
|
||||
compaction = CompactionConfig(**compaction_raw)
|
||||
else:
|
||||
compaction = CompactionConfig()
|
||||
return AgentMemoryConfig(
|
||||
max_history_messages=int(agent_node_cfg.get("memory_max_history", 20)),
|
||||
vector_memory_top_k=int(agent_node_cfg.get("memory_vector_top_k", 5)),
|
||||
persist_to_db=bool(agent_node_cfg.get("memory_persist", True)),
|
||||
vector_memory_enabled=bool(agent_node_cfg.get("memory_vector_enabled", True)),
|
||||
learning_enabled=bool(agent_node_cfg.get("memory_learning", True)),
|
||||
memory_dir_enabled=bool(agent_node_cfg.get("memory_dir_enabled", False)),
|
||||
memory_dir_path=str(agent_node_cfg.get("memory_dir_path", "")),
|
||||
compaction=compaction,
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────── 请求/响应模型 ────────────────────────────
|
||||
|
||||
class BranchCreateRequest(BaseModel):
|
||||
session_id: str = Field(..., description="要分叉的原会话 ID")
|
||||
title: Optional[str] = Field(default=None, description="自定义分支标题")
|
||||
agent_id: Optional[str] = Field(default=None, description="关联 Agent ID")
|
||||
|
||||
|
||||
class BranchItem(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
agent_name: Optional[str] = None
|
||||
parent_session_id: str
|
||||
message_count: int
|
||||
first_user_message: Optional[str] = None
|
||||
created_at: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BranchListResponse(BaseModel):
|
||||
branches: List[BranchItem]
|
||||
total: int
|
||||
|
||||
|
||||
class BranchDetailResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
agent_name: Optional[str] = None
|
||||
parent_session_id: str
|
||||
branch_session_id: str
|
||||
message_count: int
|
||||
first_user_message: Optional[str] = None
|
||||
messages: Optional[List[Dict[str, Any]]] = None
|
||||
created_at: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BranchResumeRequest(BaseModel):
|
||||
message: str = Field(..., description="在分支基础上继续的新消息")
|
||||
|
||||
|
||||
class BranchResumeResponse(BaseModel):
|
||||
content: str
|
||||
iterations_used: int
|
||||
tool_calls_made: int
|
||||
truncated: bool
|
||||
session_id: str
|
||||
agent_id: Optional[str] = None
|
||||
steps: List[AgentStep] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ──────────────────────────── API 端点 ────────────────────────────
|
||||
|
||||
@router.post("/branches", response_model=BranchDetailResponse)
|
||||
async def create_conversation_branch(
|
||||
req: BranchCreateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""从指定会话创建对话分支(快照完整消息列表)。"""
|
||||
from app.services.conversation_branch_service import (
|
||||
create_branch,
|
||||
get_messages_from_session,
|
||||
)
|
||||
|
||||
messages = get_messages_from_session(db, req.session_id)
|
||||
if not messages:
|
||||
raise HTTPException(status_code=404, detail="找不到该会话的消息记录(请确保会话已产生对话)")
|
||||
|
||||
agent_name = None
|
||||
if req.agent_id:
|
||||
agent = db.query(Agent).filter(Agent.id == req.agent_id).first()
|
||||
if agent:
|
||||
agent_name = agent.name
|
||||
|
||||
branch = create_branch(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
parent_session_id=req.session_id,
|
||||
messages=messages,
|
||||
agent_id=req.agent_id,
|
||||
agent_name=agent_name,
|
||||
custom_title=req.title,
|
||||
)
|
||||
|
||||
return BranchDetailResponse(
|
||||
id=branch.id,
|
||||
title=branch.title,
|
||||
agent_name=branch.agent_name,
|
||||
parent_session_id=branch.parent_session_id,
|
||||
branch_session_id=branch.branch_session_id,
|
||||
message_count=branch.message_count,
|
||||
first_user_message=branch.first_user_message,
|
||||
messages=branch.messages,
|
||||
created_at=branch.created_at.isoformat() if branch.created_at else "",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/branches", response_model=BranchListResponse)
|
||||
async def list_conversation_branches(
|
||||
agent_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""列出当前用户的所有对话分支。"""
|
||||
from app.services.conversation_branch_service import list_branches
|
||||
|
||||
branches = list_branches(
|
||||
db=db, user_id=current_user.id, agent_id=agent_id, limit=limit, offset=offset,
|
||||
)
|
||||
|
||||
return BranchListResponse(
|
||||
branches=[
|
||||
BranchItem(
|
||||
id=b.id, title=b.title, agent_name=b.agent_name,
|
||||
parent_session_id=b.parent_session_id,
|
||||
message_count=b.message_count,
|
||||
first_user_message=b.first_user_message,
|
||||
created_at=b.created_at.isoformat() if b.created_at else "",
|
||||
)
|
||||
for b in branches
|
||||
],
|
||||
total=len(branches),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/branches/{branch_id}", response_model=BranchDetailResponse)
|
||||
async def get_conversation_branch(
|
||||
branch_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取分支详情(含完整消息列表)。"""
|
||||
from app.services.conversation_branch_service import get_branch
|
||||
|
||||
branch = get_branch(db, branch_id, current_user.id)
|
||||
if not branch:
|
||||
raise HTTPException(status_code=404, detail="分支不存在")
|
||||
|
||||
return BranchDetailResponse(
|
||||
id=branch.id, title=branch.title, agent_name=branch.agent_name,
|
||||
parent_session_id=branch.parent_session_id,
|
||||
branch_session_id=branch.branch_session_id,
|
||||
message_count=branch.message_count,
|
||||
first_user_message=branch.first_user_message,
|
||||
messages=branch.messages,
|
||||
created_at=branch.created_at.isoformat() if branch.created_at else "",
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/branches/{branch_id}")
|
||||
async def delete_conversation_branch(
|
||||
branch_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""删除(软删除)对话分支。"""
|
||||
from app.services.conversation_branch_service import delete_branch
|
||||
|
||||
success = delete_branch(db, branch_id, current_user.id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="分支不存在或无权操作")
|
||||
return {"message": "分支已删除", "branch_id": branch_id}
|
||||
|
||||
|
||||
@router.post("/branches/{branch_id}/resume", response_model=BranchResumeResponse)
|
||||
async def resume_from_branch(
|
||||
branch_id: str,
|
||||
req: BranchResumeRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""从分支恢复对话 — 使用分支保存的消息作为上下文继续对话。"""
|
||||
from app.services.conversation_branch_service import get_branch
|
||||
|
||||
branch = get_branch(db, branch_id, current_user.id)
|
||||
if not branch:
|
||||
raise HTTPException(status_code=404, detail="分支不存在")
|
||||
if not branch.messages:
|
||||
raise HTTPException(status_code=400, detail="分支没有保存的消息")
|
||||
|
||||
messages = branch.messages
|
||||
system_prompt = "你是一个有用的AI助手。"
|
||||
history_messages = messages
|
||||
if messages and messages[0].get("role") == "system":
|
||||
system_prompt = messages[0].get("content", system_prompt)
|
||||
history_messages = messages[1:]
|
||||
|
||||
# 构建 Agent 配置
|
||||
agent = None
|
||||
if branch.agent_id:
|
||||
agent = db.query(Agent).filter(Agent.id == branch.agent_id).first()
|
||||
|
||||
if agent:
|
||||
wc = agent.workflow_config or {}
|
||||
nodes = wc.get("nodes", [])
|
||||
agent_node_cfg = _find_agent_node_config(nodes)
|
||||
llm_config = AgentLLMConfig(
|
||||
provider=agent_node_cfg.get("provider", "deepseek"),
|
||||
model=agent_node_cfg.get("model", "deepseek-v4-flash"),
|
||||
temperature=float(agent_node_cfg.get("temperature", 0.7)),
|
||||
max_iterations=int(agent_node_cfg.get("max_iterations", 10)),
|
||||
)
|
||||
config = AgentConfig(
|
||||
name=agent.name, system_prompt=system_prompt, llm=llm_config,
|
||||
tools=AgentToolConfig(include_tools=agent_node_cfg.get("tools", [])),
|
||||
memory=_build_memory_config_from_node(agent_node_cfg),
|
||||
user_id=current_user.id,
|
||||
memory_scope_id=f"{current_user.id}:{branch.agent_id}" if current_user.id else str(branch.agent_id),
|
||||
)
|
||||
else:
|
||||
config = AgentConfig(
|
||||
name="branch_resume", system_prompt=system_prompt,
|
||||
llm=AgentLLMConfig(model="deepseek-v4-flash", temperature=0.7, max_iterations=10),
|
||||
user_id=current_user.id,
|
||||
memory_scope_id=f"{current_user.id}:__bare__" if current_user.id else "__bare__",
|
||||
)
|
||||
|
||||
# 创建 Runtime 并预加载分支消息
|
||||
context = AgentContext(
|
||||
system_prompt=system_prompt,
|
||||
user_id=current_user.id,
|
||||
session_id=branch.branch_session_id,
|
||||
)
|
||||
for msg in history_messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
if role == "user":
|
||||
context.add_user_message(content)
|
||||
elif role == "assistant":
|
||||
tool_calls = msg.get("tool_calls")
|
||||
context.add_assistant_message(content, tool_calls)
|
||||
elif role == "tool":
|
||||
context.add_tool_result(
|
||||
msg.get("tool_call_id", ""),
|
||||
msg.get("name", "unknown"),
|
||||
content,
|
||||
)
|
||||
|
||||
# 不需要 on_llm_logger 回调(避免与主 logger 重复),仅做执行
|
||||
runtime = AgentRuntime(config=config, context=context)
|
||||
result = await runtime.run(req.message)
|
||||
|
||||
return BranchResumeResponse(
|
||||
content=result.content,
|
||||
iterations_used=result.iterations_used,
|
||||
tool_calls_made=result.tool_calls_made,
|
||||
truncated=result.truncated,
|
||||
session_id=branch.branch_session_id,
|
||||
agent_id=branch.agent_id,
|
||||
steps=result.steps,
|
||||
)
|
||||
@@ -82,6 +82,9 @@ class ChatRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
max_iterations: Optional[int] = None
|
||||
streamlined: bool = Field(default=False, description="启用工具结果流式美化")
|
||||
prompt_sections_enabled: bool = Field(default=True, description="启用系统提示词分层装配")
|
||||
system_prompt_override: Optional[str] = Field(default=None, description="覆盖 Agent 的 System Prompt")
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
@@ -92,6 +95,8 @@ class ChatResponse(BaseModel):
|
||||
session_id: str
|
||||
agent_id: Optional[str] = None
|
||||
steps: List[AgentStep] = Field(default_factory=list, description="执行追踪步骤")
|
||||
streamlined_summary: Optional[str] = Field(default=None, description="流式美化摘要(streamlined 模式)")
|
||||
token_usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 预算摘要")
|
||||
|
||||
|
||||
class OrchestrateAgentItem(BaseModel):
|
||||
@@ -249,11 +254,35 @@ async def chat_bare(
|
||||
),
|
||||
user_id=uid,
|
||||
memory_scope_id=bare_scope,
|
||||
memory=AgentMemoryConfig(
|
||||
memory_dir_enabled=True,
|
||||
memory_dir_path="",
|
||||
persist_to_db=True,
|
||||
vector_memory_enabled=True,
|
||||
learning_enabled=True,
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
permission_level="acceptEdits",
|
||||
),
|
||||
)
|
||||
if not req.prompt_sections_enabled:
|
||||
config.prompt_sections.enabled = False
|
||||
if req.system_prompt_override:
|
||||
config.system_prompt = req.system_prompt_override
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
result = await runtime.run(req.message)
|
||||
|
||||
# 流式美化:为 steps 生成累计摘要
|
||||
streamlined_summary = None
|
||||
if req.streamlined and result.steps:
|
||||
from app.core.streamlined_output import ToolCounts, categorize_tool, get_tool_summary_text
|
||||
counts = ToolCounts()
|
||||
for s in result.steps:
|
||||
if s.type == "tool_result" and s.tool_name:
|
||||
counts.add(categorize_tool(s.tool_name))
|
||||
streamlined_summary = get_tool_summary_text(counts)
|
||||
|
||||
return ChatResponse(
|
||||
content=result.content,
|
||||
iterations_used=result.iterations_used,
|
||||
@@ -261,6 +290,8 @@ async def chat_bare(
|
||||
truncated=result.truncated,
|
||||
session_id=runtime.context.session_id,
|
||||
steps=result.steps,
|
||||
streamlined_summary=streamlined_summary,
|
||||
token_usage=result.token_usage.model_dump() if result.token_usage else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -286,9 +317,23 @@ async def chat_bare_stream(
|
||||
),
|
||||
user_id=uid,
|
||||
memory_scope_id=bare_scope,
|
||||
memory=AgentMemoryConfig(
|
||||
memory_dir_enabled=True,
|
||||
memory_dir_path="",
|
||||
persist_to_db=True,
|
||||
vector_memory_enabled=True,
|
||||
learning_enabled=True,
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
permission_level="acceptEdits",
|
||||
),
|
||||
)
|
||||
if not req.prompt_sections_enabled:
|
||||
config.prompt_sections.enabled = False
|
||||
if req.system_prompt_override:
|
||||
config.system_prompt = req.system_prompt_override
|
||||
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
@@ -339,6 +384,8 @@ async def chat_with_agent(
|
||||
uid = current_user.id
|
||||
mem_scope = f"{uid}:{agent_id}" if uid else str(agent_id)
|
||||
memory_cfg = _build_memory_config_from_node(agent_node_cfg)
|
||||
if getattr(agent, "parent_agent_id", None):
|
||||
memory_cfg.parent_agent_id = agent.parent_agent_id
|
||||
config = AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
@@ -347,21 +394,42 @@ async def chat_with_agent(
|
||||
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
|
||||
temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
|
||||
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
|
||||
# 计划模式 (P2)
|
||||
plan_mode_enabled=bool(agent_node_cfg.get("plan_mode_enabled", False)),
|
||||
plan_approval_required=bool(agent_node_cfg.get("plan_approval_required", True)),
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
exclude_tools=agent_node_cfg.get("exclude_tools", []),
|
||||
# 工具安全分级 (P3)
|
||||
permission_level=str(agent_node_cfg.get("permission_level", "default")),
|
||||
deny_tools=agent_node_cfg.get("deny_tools", []),
|
||||
auto_approve_rules=agent_node_cfg.get("auto_approve_rules", []),
|
||||
),
|
||||
memory=memory_cfg,
|
||||
budget=budget,
|
||||
user_id=uid,
|
||||
memory_scope_id=mem_scope,
|
||||
)
|
||||
if not req.prompt_sections_enabled:
|
||||
config.prompt_sections.enabled = False
|
||||
if req.system_prompt_override:
|
||||
config.system_prompt = req.system_prompt_override
|
||||
|
||||
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
result = await runtime.run(req.message)
|
||||
|
||||
# 流式美化:为 steps 生成累计摘要
|
||||
streamlined_summary = None
|
||||
if req.streamlined and result.steps:
|
||||
from app.core.streamlined_output import ToolCounts, categorize_tool, get_tool_summary_text
|
||||
counts = ToolCounts()
|
||||
for s in result.steps:
|
||||
if s.type == "tool_result" and s.tool_name:
|
||||
counts.add(categorize_tool(s.tool_name))
|
||||
streamlined_summary = get_tool_summary_text(counts)
|
||||
|
||||
return ChatResponse(
|
||||
content=result.content,
|
||||
iterations_used=result.iterations_used,
|
||||
@@ -370,6 +438,8 @@ async def chat_with_agent(
|
||||
session_id=runtime.context.session_id,
|
||||
agent_id=agent_id,
|
||||
steps=result.steps,
|
||||
streamlined_summary=streamlined_summary,
|
||||
token_usage=result.token_usage.model_dump() if result.token_usage else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -408,6 +478,8 @@ async def chat_with_agent_stream(
|
||||
uid = current_user.id
|
||||
mem_scope = f"{uid}:{agent_id}" if uid else str(agent_id)
|
||||
memory_cfg = _build_memory_config_from_node(agent_node_cfg)
|
||||
if getattr(agent, "parent_agent_id", None):
|
||||
memory_cfg.parent_agent_id = agent.parent_agent_id
|
||||
config = AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
@@ -416,19 +488,30 @@ async def chat_with_agent_stream(
|
||||
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
|
||||
temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
|
||||
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
|
||||
# 计划模式 (P2)
|
||||
plan_mode_enabled=bool(agent_node_cfg.get("plan_mode_enabled", False)),
|
||||
plan_approval_required=bool(agent_node_cfg.get("plan_approval_required", True)),
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
exclude_tools=agent_node_cfg.get("exclude_tools", []),
|
||||
# 工具安全分级 (P3)
|
||||
permission_level=str(agent_node_cfg.get("permission_level", "default")),
|
||||
deny_tools=agent_node_cfg.get("deny_tools", []),
|
||||
auto_approve_rules=agent_node_cfg.get("auto_approve_rules", []),
|
||||
),
|
||||
memory=memory_cfg,
|
||||
budget=budget,
|
||||
user_id=uid,
|
||||
memory_scope_id=mem_scope,
|
||||
)
|
||||
if not req.prompt_sections_enabled:
|
||||
config.prompt_sections.enabled = False
|
||||
if req.system_prompt_override:
|
||||
config.system_prompt = req.system_prompt_override
|
||||
|
||||
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
|
||||
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
|
||||
return StreamingResponse(
|
||||
_sse_stream(runtime.run_stream(req.message)),
|
||||
media_type="text/event-stream",
|
||||
@@ -453,10 +536,24 @@ def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
|
||||
|
||||
def _build_memory_config_from_node(agent_node_cfg: dict) -> AgentMemoryConfig:
|
||||
"""从 Agent 工作流节点配置中提取记忆配置。"""
|
||||
from app.core.compaction_config import CompactionConfig
|
||||
|
||||
# 压缩配置
|
||||
compaction_raw = agent_node_cfg.get("compaction")
|
||||
if isinstance(compaction_raw, dict):
|
||||
compaction = CompactionConfig(**compaction_raw)
|
||||
else:
|
||||
compaction = CompactionConfig() # 默认配置
|
||||
|
||||
return AgentMemoryConfig(
|
||||
max_history_messages=int(agent_node_cfg.get("memory_max_history", 20)),
|
||||
vector_memory_top_k=int(agent_node_cfg.get("memory_vector_top_k", 5)),
|
||||
persist_to_db=bool(agent_node_cfg.get("memory_persist", True)),
|
||||
vector_memory_enabled=bool(agent_node_cfg.get("memory_vector_enabled", True)),
|
||||
learning_enabled=bool(agent_node_cfg.get("memory_learning", True)),
|
||||
# 文件式记忆 (MEMORY.md)
|
||||
memory_dir_enabled=bool(agent_node_cfg.get("memory_dir_enabled", False)),
|
||||
memory_dir_path=str(agent_node_cfg.get("memory_dir_path", "")),
|
||||
# 对话压缩
|
||||
compaction=compaction,
|
||||
)
|
||||
|
||||
@@ -45,7 +45,7 @@ class AgentMarketItem(BaseModel):
|
||||
use_count: int = 0
|
||||
view_count: int = 0
|
||||
version: int = 1
|
||||
user_id: str
|
||||
user_id: Optional[str] = None
|
||||
creator_username: Optional[str] = None
|
||||
is_favorited: Optional[bool] = None
|
||||
user_rating: Optional[int] = None
|
||||
@@ -70,7 +70,7 @@ class AgentMarketDetail(BaseModel):
|
||||
use_count: int = 0
|
||||
view_count: int = 0
|
||||
version: int = 1
|
||||
user_id: str
|
||||
user_id: Optional[str] = None
|
||||
creator_username: Optional[str] = None
|
||||
is_favorited: Optional[bool] = None
|
||||
user_rating: Optional[int] = None
|
||||
@@ -100,7 +100,7 @@ class RatingResponse(BaseModel):
|
||||
"""评分响应"""
|
||||
id: str
|
||||
agent_id: str
|
||||
user_id: str
|
||||
user_id: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
rating: int
|
||||
comment: Optional[str] = None
|
||||
@@ -137,7 +137,7 @@ def _build_agent_item(agent: Agent, current_user: Optional[User], db: Session) -
|
||||
"use_count": agent.use_count or 0,
|
||||
"view_count": agent.view_count or 0,
|
||||
"version": agent.version or 1,
|
||||
"user_id": agent.user_id,
|
||||
"user_id": agent.user_id or "",
|
||||
"creator_username": agent.user.username if agent.user else None,
|
||||
"is_favorited": None,
|
||||
"user_rating": None,
|
||||
|
||||
280
backend/app/api/agent_swarm.py
Normal file
280
backend/app/api/agent_swarm.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Agent 蜂群 API — Leader/Teammate 并行协作
|
||||
|
||||
POST /api/v1/swarm/run
|
||||
{"message": "帮我做三件事: ...", "mode": "parallel", "max_teammates": 5}
|
||||
→ Leader 分解任务 → Teammates 并行执行 → Leader 汇总
|
||||
|
||||
参考 Claude Code src/tools/AgentTool/ + forkSubagent.ts
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.agent import Agent
|
||||
from app.agent_runtime.swarm import (
|
||||
SwarmRuntime,
|
||||
SwarmConfig,
|
||||
SwarmMode,
|
||||
SwarmResult,
|
||||
SwarmTask,
|
||||
create_swarm,
|
||||
)
|
||||
from app.agent_runtime.schemas import AgentConfig, AgentLLMConfig, AgentToolConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/swarm", tags=["swarm"])
|
||||
|
||||
|
||||
# ──────────────────────────── 请求/响应模型 ────────────────────────────
|
||||
|
||||
|
||||
class SwarmRunRequest(BaseModel):
|
||||
message: str = Field(..., description="用户输入")
|
||||
mode: str = Field(default="parallel", description="蜂群模式: parallel | pipeline | debate")
|
||||
max_teammates: int = Field(default=5, ge=1, le=20)
|
||||
leader_model: Optional[str] = Field(default=None, description="Leader 模型")
|
||||
teammate_model: Optional[str] = Field(default=None, description="Teammate 模型")
|
||||
mailbox_enabled: bool = Field(default=True, description="启用 Agent 间消息传递")
|
||||
agent_ids: Optional[List[str]] = Field(default=None, description="指定的 Agent ID 列表(作为 Teammates)")
|
||||
retry_failed: bool = Field(default=True, description="失败任务是否重试")
|
||||
|
||||
|
||||
class SwarmTaskItem(BaseModel):
|
||||
id: str
|
||||
description: str
|
||||
assigned_agent_id: Optional[str] = None
|
||||
status: str
|
||||
result: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
iterations_used: int = 0
|
||||
tool_calls_made: int = 0
|
||||
duration_ms: int = 0
|
||||
|
||||
|
||||
class SwarmTeammateItem(BaseModel):
|
||||
agent_id: str
|
||||
agent_name: str
|
||||
task_id: str
|
||||
success: bool
|
||||
output: str
|
||||
duration_ms: int
|
||||
iterations_used: int
|
||||
tool_calls_made: int
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class MailboxMessageItem(BaseModel):
|
||||
id: str
|
||||
from_: str = Field(alias="from")
|
||||
to: str
|
||||
content: str
|
||||
timestamp: float
|
||||
|
||||
|
||||
class SwarmRunResponse(BaseModel):
|
||||
success: bool
|
||||
final_answer: str
|
||||
mode: str
|
||||
tasks: List[SwarmTaskItem] = Field(default_factory=list)
|
||||
teammate_results: List[SwarmTeammateItem] = Field(default_factory=list)
|
||||
mailbox_messages: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
total_duration_ms: int = 0
|
||||
total_iterations: int = 0
|
||||
total_tool_calls: int = 0
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ──────────────────────────── 端点 ────────────────────────────
|
||||
|
||||
|
||||
@router.post("/run", response_model=SwarmRunResponse)
|
||||
async def swarm_run(
|
||||
req: SwarmRunRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""运行 Agent 蜂群 — Leader 分解任务 → Teammates 并行执行 → 汇总。
|
||||
|
||||
支持三种模式:
|
||||
- parallel: 所有子任务并发执行(无依赖)
|
||||
- pipeline: 按依赖顺序执行
|
||||
- debate: 多个 Agent 独立回答后汇总
|
||||
|
||||
Teammates 来源(优先级):
|
||||
1. agent_ids 参数指定 → 从数据库加载 Agent 配置
|
||||
2. 自动生成 → 使用 teammate_model 创建轻量 Teammate
|
||||
"""
|
||||
uid = current_user.id
|
||||
|
||||
# 解析模式
|
||||
mode = SwarmMode.PARALLEL
|
||||
if req.mode == "pipeline":
|
||||
mode = SwarmMode.PIPELINE
|
||||
elif req.mode == "debate":
|
||||
mode = SwarmMode.DEBATE
|
||||
|
||||
# 构建 SwarmConfig
|
||||
config = SwarmConfig(
|
||||
mode=mode,
|
||||
max_teammates=req.max_teammates,
|
||||
leader_model=req.leader_model or "deepseek-v4-pro",
|
||||
teammate_model=req.teammate_model or "deepseek-v4-flash",
|
||||
mailbox_enabled=req.mailbox_enabled,
|
||||
retry_failed=req.retry_failed,
|
||||
)
|
||||
|
||||
# 加载指定的 Agent 作为 Teammates
|
||||
teammate_configs: List[AgentConfig] = []
|
||||
if req.agent_ids:
|
||||
for aid in req.agent_ids:
|
||||
agent = db.query(Agent).filter(Agent.id == aid).first()
|
||||
if agent:
|
||||
wc = agent.workflow_config or {}
|
||||
nodes = wc.get("nodes", [])
|
||||
agent_node_cfg = {}
|
||||
for node in nodes:
|
||||
if node.get("type") in ("agent", "llm", "template"):
|
||||
agent_node_cfg = node.get("data") or {}
|
||||
break
|
||||
|
||||
teammate_configs.append(AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
|
||||
llm=AgentLLMConfig(
|
||||
model=agent_node_cfg.get("model", req.teammate_model or "deepseek-v4-flash"),
|
||||
provider=agent_node_cfg.get("provider", "deepseek"),
|
||||
temperature=float(agent_node_cfg.get("temperature", 0.7)),
|
||||
max_iterations=int(agent_node_cfg.get("max_iterations", 10)),
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
include_tools=agent_node_cfg.get("tools", []),
|
||||
),
|
||||
user_id=uid,
|
||||
))
|
||||
|
||||
# 构建 Leader 配置
|
||||
leader_config = AgentConfig(
|
||||
name="SwarmLeader",
|
||||
system_prompt="你是一个AI任务协调者。将复杂问题分解为子任务,协调多个AI Agent并行处理,并汇总结果。",
|
||||
llm=AgentLLMConfig(model=config.leader_model, temperature=0.3, max_iterations=10),
|
||||
user_id=uid,
|
||||
)
|
||||
|
||||
# 创建并运行 Swarm
|
||||
swarm = SwarmRuntime(
|
||||
config=config,
|
||||
leader_config=leader_config,
|
||||
teammate_configs=teammate_configs,
|
||||
)
|
||||
|
||||
result = await swarm.run(req.message)
|
||||
|
||||
return SwarmRunResponse(
|
||||
success=result.success,
|
||||
final_answer=result.final_answer,
|
||||
mode=result.mode.value,
|
||||
tasks=[
|
||||
SwarmTaskItem(
|
||||
id=t.id, description=t.description,
|
||||
assigned_agent_id=t.assigned_agent_id,
|
||||
status=t.status.value, result=t.result[:500] if t.result else None,
|
||||
error=t.error, iterations_used=t.iterations_used,
|
||||
tool_calls_made=t.tool_calls_made, duration_ms=t.duration_ms,
|
||||
)
|
||||
for t in result.tasks
|
||||
],
|
||||
teammate_results=[
|
||||
SwarmTeammateItem(
|
||||
agent_id=tr["agent_id"], agent_name=tr["agent_name"],
|
||||
task_id=tr["task_id"], success=tr["success"],
|
||||
output=tr["output"][:500], duration_ms=tr["duration_ms"],
|
||||
iterations_used=tr["iterations_used"], tool_calls_made=tr["tool_calls_made"],
|
||||
error=tr.get("error"),
|
||||
)
|
||||
for tr in result.teammate_results
|
||||
],
|
||||
mailbox_messages=result.mailbox_messages,
|
||||
total_duration_ms=result.total_duration_ms,
|
||||
total_iterations=result.total_iterations,
|
||||
total_tool_calls=result.total_tool_calls,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/run/stream")
|
||||
async def swarm_run_stream(
|
||||
req: SwarmRunRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""运行 Agent 蜂群(流式 SSE)— 实时推送任务分解、执行进度、汇总结果。"""
|
||||
import json as _json
|
||||
|
||||
async def _stream():
|
||||
uid = current_user.id
|
||||
mode = {"parallel": SwarmMode.PARALLEL, "pipeline": SwarmMode.PIPELINE,
|
||||
"debate": SwarmMode.DEBATE}.get(req.mode, SwarmMode.PARALLEL)
|
||||
|
||||
config = SwarmConfig(
|
||||
mode=mode, max_teammates=req.max_teammates,
|
||||
leader_model=req.leader_model or "deepseek-v4-pro",
|
||||
teammate_model=req.teammate_model or "deepseek-v4-flash",
|
||||
mailbox_enabled=req.mailbox_enabled, retry_failed=req.retry_failed,
|
||||
)
|
||||
|
||||
# Load specified agents
|
||||
teammate_configs = []
|
||||
if req.agent_ids:
|
||||
for aid in req.agent_ids:
|
||||
agent = db.query(Agent).filter(Agent.id == aid).first()
|
||||
if agent:
|
||||
wc = agent.workflow_config or {}
|
||||
agent_node_cfg = {}
|
||||
for node in wc.get("nodes", []):
|
||||
if node.get("type") in ("agent", "llm", "template"):
|
||||
agent_node_cfg = node.get("data") or {}
|
||||
break
|
||||
teammate_configs.append(AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
|
||||
llm=AgentLLMConfig(
|
||||
model=agent_node_cfg.get("model", req.teammate_model or "deepseek-v4-flash"),
|
||||
provider=agent_node_cfg.get("provider", "deepseek"),
|
||||
temperature=float(agent_node_cfg.get("temperature", 0.7)),
|
||||
max_iterations=int(agent_node_cfg.get("max_iterations", 10)),
|
||||
),
|
||||
tools=AgentToolConfig(include_tools=agent_node_cfg.get("tools", [])),
|
||||
user_id=uid,
|
||||
))
|
||||
|
||||
leader_config = AgentConfig(
|
||||
name="SwarmLeader",
|
||||
system_prompt="你是一个AI任务协调者。",
|
||||
llm=AgentLLMConfig(model=config.leader_model, temperature=0.3, max_iterations=10),
|
||||
user_id=uid,
|
||||
)
|
||||
|
||||
swarm = SwarmRuntime(config=config, leader_config=leader_config,
|
||||
teammate_configs=teammate_configs)
|
||||
|
||||
yield f"data: {_json.dumps({'type': 'swarm_start', 'mode': req.mode, 'max_teammates': req.max_teammates}, ensure_ascii=False)}\n\n"
|
||||
|
||||
result = await swarm.run(req.message)
|
||||
|
||||
yield f"data: {_json.dumps({'type': 'swarm_done', 'success': result.success, 'final_answer': result.final_answer, 'total_duration_ms': result.total_duration_ms, 'total_iterations': result.total_iterations, 'total_tool_calls': result.total_tool_calls}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
@@ -1,8 +1,9 @@
|
||||
"""
|
||||
Agent管理API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response, Request
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
@@ -61,16 +62,18 @@ class SceneTemplateItem(BaseModel):
|
||||
category: Optional[str] = None
|
||||
default_temperature: Optional[float] = None
|
||||
parameter_hints: List[str] = Field(default_factory=list)
|
||||
contract_id: Optional[str] = None
|
||||
|
||||
|
||||
class AgentFromSceneTemplateCreate(BaseModel):
|
||||
"""从场景模板创建 Agent"""
|
||||
"""从场景模板创建 Agent(支持 DSL 契约参数)"""
|
||||
|
||||
template_id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Dict[str, Any] = Field(default_factory=dict)
|
||||
budget_config: Optional[Dict[str, Any]] = None
|
||||
contract_id: Optional[str] = None
|
||||
|
||||
|
||||
class PreviewChatTurnResponse(BaseModel):
|
||||
@@ -96,6 +99,9 @@ class AgentResponse(BaseModel):
|
||||
version: int
|
||||
status: str
|
||||
user_id: Optional[str] # 允许为None
|
||||
parent_agent_id: Optional[str] = None
|
||||
agent_type: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@@ -110,13 +116,14 @@ async def get_agents(
|
||||
limit: int = Query(100, ge=1, le=100, description="每页记录数"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词(按名称或描述)"),
|
||||
status: Optional[str] = Query(None, description="状态筛选"),
|
||||
workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取Agent列表
|
||||
|
||||
支持分页、搜索、状态筛选
|
||||
|
||||
支持分页、搜索、状态筛选、工作区筛选
|
||||
"""
|
||||
# 管理员可以看到所有Agent,普通用户只能看到自己拥有的或有read权限的
|
||||
if current_user.role == "admin":
|
||||
@@ -125,10 +132,10 @@ async def get_agents(
|
||||
# 获取用户拥有或有read权限的Agent
|
||||
from sqlalchemy import or_
|
||||
from app.models.permission import AgentPermission
|
||||
|
||||
|
||||
# 用户拥有的Agent
|
||||
owned_agents = db.query(Agent.id).filter(Agent.user_id == current_user.id).subquery()
|
||||
|
||||
|
||||
# 用户有read权限的Agent(通过用户ID或角色)
|
||||
user_permissions = db.query(AgentPermission.agent_id).filter(
|
||||
AgentPermission.permission_type == "read",
|
||||
@@ -137,14 +144,18 @@ async def get_agents(
|
||||
AgentPermission.role_id.in_([r.id for r in current_user.roles])
|
||||
)
|
||||
).subquery()
|
||||
|
||||
|
||||
query = db.query(Agent).filter(
|
||||
or_(
|
||||
Agent.id.in_(db.query(owned_agents.c.id)),
|
||||
Agent.id.in_(db.query(user_permissions.c.agent_id))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 工作区筛选
|
||||
if workspace_id:
|
||||
query = query.filter(Agent.workspace_id == workspace_id)
|
||||
|
||||
# 搜索:按名称或描述搜索
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
@@ -187,26 +198,36 @@ async def get_agents(
|
||||
@router.post("", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_agent(
|
||||
agent_data: AgentCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建Agent
|
||||
|
||||
|
||||
创建时会验证工作流配置的有效性
|
||||
"""
|
||||
# 从 JWT 提取当前工作区 ID
|
||||
from app.core.security import decode_access_token
|
||||
ws_id = None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
payload = decode_access_token(auth_header[7:])
|
||||
if payload:
|
||||
ws_id = payload.get("ws") or None
|
||||
|
||||
# 验证工作流配置
|
||||
if "nodes" not in agent_data.workflow_config or "edges" not in agent_data.workflow_config:
|
||||
raise ValidationError("工作流配置必须包含nodes和edges")
|
||||
|
||||
|
||||
nodes = agent_data.workflow_config.get("nodes", [])
|
||||
edges = agent_data.workflow_config.get("edges", [])
|
||||
|
||||
|
||||
# 验证工作流
|
||||
validation_result = validate_workflow(nodes, edges)
|
||||
if not validation_result["valid"]:
|
||||
raise ValidationError(f"工作流配置验证失败: {', '.join(validation_result['errors'])}")
|
||||
|
||||
|
||||
# 检查名称是否重复
|
||||
existing_agent = db.query(Agent).filter(
|
||||
Agent.name == agent_data.name,
|
||||
@@ -214,7 +235,7 @@ async def create_agent(
|
||||
).first()
|
||||
if existing_agent:
|
||||
raise ConflictError(f"Agent名称 '{agent_data.name}' 已存在")
|
||||
|
||||
|
||||
# 创建Agent
|
||||
agent = Agent(
|
||||
name=agent_data.name,
|
||||
@@ -222,6 +243,7 @@ async def create_agent(
|
||||
workflow_config=agent_data.workflow_config,
|
||||
budget_config=agent_data.budget_config,
|
||||
user_id=current_user.id,
|
||||
workspace_id=ws_id,
|
||||
status="draft",
|
||||
category=agent_data.category,
|
||||
tags=agent_data.tags,
|
||||
@@ -377,9 +399,33 @@ async def delete_agent(
|
||||
raise HTTPException(status_code=403, detail="无权删除此Agent")
|
||||
|
||||
agent_name = agent.name
|
||||
|
||||
# 1. 清理直接关联的记录(级联删除)
|
||||
related_tables = [
|
||||
"agent_llm_logs",
|
||||
"agent_permissions",
|
||||
"agent_schedules",
|
||||
"executions",
|
||||
"team_members",
|
||||
]
|
||||
for table in related_tables:
|
||||
db.execute(text(f"DELETE FROM {table} WHERE agent_id = :aid"), {"aid": agent_id})
|
||||
logger.debug(f"已清理 {table} 中 agent_id={agent_id} 的记录")
|
||||
|
||||
# 2. 解除分配关系(设为 NULL)
|
||||
db.execute(
|
||||
text("UPDATE goals SET main_agent_id = NULL WHERE main_agent_id = :aid"),
|
||||
{"aid": agent_id},
|
||||
)
|
||||
db.execute(
|
||||
text("UPDATE tasks SET assigned_agent_id = NULL WHERE assigned_agent_id = :aid"),
|
||||
{"aid": agent_id},
|
||||
)
|
||||
|
||||
# 3. 删除 Agent 自身
|
||||
db.delete(agent)
|
||||
db.commit()
|
||||
|
||||
|
||||
logger.info(f"用户 {current_user.username} 删除了Agent: {agent_name} ({agent_id})")
|
||||
return {"message": "Agent已删除"}
|
||||
|
||||
@@ -773,3 +819,67 @@ async def import_agent(
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
return agent
|
||||
|
||||
|
||||
# ——— Agent 间知识共享 API ———
|
||||
|
||||
class InheritKnowledgeResponse(BaseModel):
|
||||
child_agent_id: str
|
||||
parent_agent_id: str
|
||||
knowledge_entries_copied: int
|
||||
global_knowledge_shared: int
|
||||
learning_patterns_copied: int
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/{agent_id}/inherit-knowledge", response_model=InheritKnowledgeResponse)
|
||||
def inherit_agent_knowledge(
|
||||
agent_id: str,
|
||||
parent_agent_id: str = Query(..., description="父 Agent ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""让子 Agent 从父 Agent 继承知识条目、全局知识和学习模式。"""
|
||||
child = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not child:
|
||||
raise NotFoundError("Agent", agent_id)
|
||||
parent = db.query(Agent).filter(Agent.id == parent_agent_id).first()
|
||||
if not parent:
|
||||
raise NotFoundError("父 Agent", parent_agent_id)
|
||||
|
||||
# 记录父子关系
|
||||
child.parent_agent_id = parent_agent_id
|
||||
db.commit()
|
||||
|
||||
try:
|
||||
from app.services.knowledge_sharing import inherit_knowledge_from_parent
|
||||
result = inherit_knowledge_from_parent(agent_id, parent_agent_id, db)
|
||||
result["message"] = f"已从父 Agent 继承 {result['knowledge_entries_copied']} 条知识、" \
|
||||
f"{result['global_knowledge_shared']} 条全局知识、" \
|
||||
f"{result['learning_patterns_copied']} 条学习模式"
|
||||
return InheritKnowledgeResponse(**result)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"知识继承失败: {e}")
|
||||
|
||||
|
||||
@router.get("/{agent_id}/parent-knowledge")
|
||||
def get_parent_knowledge_preview(
|
||||
agent_id: str,
|
||||
limit: int = Query(5, ge=1, le=20),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""预览 Agent 可用的父级知识(不实际继承)。"""
|
||||
agent = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
raise NotFoundError("Agent", agent_id)
|
||||
|
||||
from app.services.knowledge_sharing import get_parent_knowledge_context
|
||||
context = get_parent_knowledge_context(agent_id, max_entries=limit, db=db)
|
||||
|
||||
return {
|
||||
"agent_id": agent_id,
|
||||
"parent_agent_id": getattr(agent, "parent_agent_id", None),
|
||||
"has_parent_knowledge": bool(context),
|
||||
"context_preview": context[:2000] if context else "",
|
||||
}
|
||||
|
||||
125
backend/app/api/audit_logs.py
Normal file
125
backend/app/api/audit_logs.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
操作审计日志 API
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, desc
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.audit_log import AuditLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/audit-logs",
|
||||
tags=["audit-logs"],
|
||||
responses={
|
||||
401: {"description": "未授权"},
|
||||
403: {"description": "无权访问"},
|
||||
500: {"description": "服务器内部错误"}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic Schemas ─────────────────────────────────────────────
|
||||
|
||||
class AuditLogItem(BaseModel):
|
||||
id: str
|
||||
username: str
|
||||
action: str
|
||||
resource_type: str
|
||||
resource_id: Optional[str] = None
|
||||
resource_name: Optional[str] = None
|
||||
detail: Optional[dict] = None
|
||||
ip_address: Optional[str] = None
|
||||
status: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AuditLogStats(BaseModel):
|
||||
total: int
|
||||
by_action: dict # {"CREATE": 10, "UPDATE": 5, ...}
|
||||
by_resource: dict # {"agent": 8, "workflow": 7, ...}
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _check_admin(current_user: User):
|
||||
if getattr(current_user, "role", None) != "admin":
|
||||
from app.core.exceptions import ForbiddenError
|
||||
raise ForbiddenError("仅管理员可访问审计日志")
|
||||
|
||||
|
||||
# ── Endpoints ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=List[AuditLogItem])
|
||||
async def get_audit_logs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
user_id: Optional[str] = Query(None, description="按用户ID过滤"),
|
||||
action: Optional[str] = Query(None, description="操作类型: CREATE/UPDATE/DELETE/EXECUTE/LOGIN"),
|
||||
resource_type: Optional[str] = Query(None, description="资源类型: agent/workflow/user"),
|
||||
status: Optional[str] = Query(None, description="操作状态: success/failure"),
|
||||
start_date: Optional[str] = Query(None, description="开始时间 ISO格式"),
|
||||
end_date: Optional[str] = Query(None, description="结束时间 ISO格式"),
|
||||
keyword: Optional[str] = Query(None, description="资源名称关键词搜索"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=1000),
|
||||
):
|
||||
"""查询操作审计日志(仅管理员)"""
|
||||
_check_admin(current_user)
|
||||
|
||||
query = db.query(AuditLog)
|
||||
|
||||
if user_id:
|
||||
query = query.filter(AuditLog.user_id == user_id)
|
||||
if action:
|
||||
query = query.filter(AuditLog.action == action.upper())
|
||||
if resource_type:
|
||||
query = query.filter(AuditLog.resource_type == resource_type)
|
||||
if status:
|
||||
query = query.filter(AuditLog.status == status)
|
||||
if keyword:
|
||||
query = query.filter(AuditLog.resource_name.contains(keyword))
|
||||
if start_date:
|
||||
query = query.filter(AuditLog.created_at >= datetime.fromisoformat(start_date))
|
||||
if end_date:
|
||||
query = query.filter(AuditLog.created_at <= datetime.fromisoformat(end_date))
|
||||
|
||||
total = query.count()
|
||||
logs = query.order_by(desc(AuditLog.created_at)).offset(skip).limit(limit).all()
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
@router.get("/stats", response_model=AuditLogStats)
|
||||
async def get_audit_logs_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取审计日志统计(仅管理员)"""
|
||||
_check_admin(current_user)
|
||||
|
||||
total = db.query(func.count(AuditLog.id)).scalar() or 0
|
||||
|
||||
action_rows = db.query(
|
||||
AuditLog.action, func.count(AuditLog.id)
|
||||
).group_by(AuditLog.action).all()
|
||||
by_action = {row[0]: row[1] for row in action_rows}
|
||||
|
||||
resource_rows = db.query(
|
||||
AuditLog.resource_type, func.count(AuditLog.id)
|
||||
).group_by(AuditLog.resource_type).all()
|
||||
by_resource = {row[0]: row[1] for row in resource_rows}
|
||||
|
||||
return {"total": total, "by_action": by_action, "by_resource": by_resource}
|
||||
@@ -6,13 +6,15 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, field_validator
|
||||
import re
|
||||
import secrets
|
||||
import logging
|
||||
from app.core.database import get_db
|
||||
from app.core.security import verify_password, get_password_hash, create_access_token
|
||||
from app.models.user import User
|
||||
from datetime import timedelta
|
||||
from datetime import datetime, timedelta
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import ConflictError, UnauthorizedError, NotFoundError
|
||||
from app.core.redis_client import get_redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,6 +29,9 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
oauth2_scheme_optional = OAuth2PasswordBearer(
|
||||
tokenUrl="/api/v1/auth/login", auto_error=False
|
||||
)
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
@@ -54,6 +59,16 @@ class UserResponse(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class MeResponse(BaseModel):
|
||||
"""当前用户完整信息(含工作区列表)"""
|
||||
id: str
|
||||
username: str
|
||||
email: str
|
||||
role: str
|
||||
workspaces: list = []
|
||||
current_workspace_id: str | None = None
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""令牌响应模型"""
|
||||
access_token: str
|
||||
@@ -85,40 +100,343 @@ async def register(user_data: UserCreate, db: Session = Depends(get_db)):
|
||||
return user
|
||||
|
||||
|
||||
def _get_user_default_workspace_id(db: Session, user: User) -> str | None:
|
||||
"""获取用户的默认工作区 ID。优先使用默认工作区,其次第一个 membership。"""
|
||||
from app.models.workspace import Workspace, WorkspaceMembership
|
||||
|
||||
# 优先使用系统默认工作区
|
||||
default_ws = db.query(Workspace).filter(Workspace.is_default == 1, Workspace.status == "active").first()
|
||||
if default_ws:
|
||||
membership = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(
|
||||
WorkspaceMembership.workspace_id == default_ws.id,
|
||||
WorkspaceMembership.user_id == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if membership:
|
||||
return default_ws.id
|
||||
|
||||
# 没有默认工作区,使用第一个 membership
|
||||
first_membership = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(WorkspaceMembership.user_id == user.id)
|
||||
.first()
|
||||
)
|
||||
if first_membership:
|
||||
return first_membership.workspace_id
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
"""用户登录"""
|
||||
async def login(
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
db: Session = Depends(get_db),
|
||||
client_type: str = "web"
|
||||
):
|
||||
"""用户登录。client_type=android/ios 时签发 7 天 token,web 默认 30 分钟。"""
|
||||
user = db.query(User).filter(User.username == form_data.username).first()
|
||||
|
||||
|
||||
if not user or not verify_password(form_data.password, user.password_hash):
|
||||
logger.warning(f"登录失败: 用户名 {form_data.username}")
|
||||
raise UnauthorizedError("用户名或密码错误")
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
if client_type in ("android", "ios"):
|
||||
expires = timedelta(minutes=settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES)
|
||||
else:
|
||||
expires = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
ws_id = _get_user_default_workspace_id(db, user)
|
||||
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.id, "username": user.username}
|
||||
data={"sub": user.id, "username": user.username, "ws": ws_id or ""},
|
||||
expires_delta=expires,
|
||||
)
|
||||
|
||||
|
||||
return {"access_token": access_token, "token_type": "bearer"}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
) -> User:
|
||||
"""FastAPI 依赖 — 从 JWT 提取当前用户,返回 User 模型。"""
|
||||
from app.core.security import decode_access_token
|
||||
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise NotFoundError("用户", user_id)
|
||||
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/me", response_model=MeResponse)
|
||||
async def get_me(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取当前用户信息(含工作区列表)。"""
|
||||
from app.core.security import decode_access_token
|
||||
from app.services.workspace_service import get_user_workspaces
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise NotFoundError("用户", user_id)
|
||||
|
||||
workspaces = get_user_workspaces(db, user)
|
||||
current_ws_id = payload.get("ws", "")
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
"workspaces": workspaces,
|
||||
"current_workspace_id": current_ws_id if current_ws_id else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/switch-workspace/{workspace_id}")
|
||||
async def switch_workspace(
|
||||
workspace_id: str,
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""切换当前工作区,重新签发 JWT(包含新的 ws 字段)。"""
|
||||
from app.core.security import decode_access_token
|
||||
from app.services.workspace_service import check_workspace_access
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise UnauthorizedError("无效的访问令牌")
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if user is None:
|
||||
raise NotFoundError("用户", user_id)
|
||||
|
||||
if not check_workspace_access(db, user, workspace_id):
|
||||
raise HTTPException(status_code=403, detail="无权访问此工作区")
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
expires = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
new_token = create_access_token(
|
||||
data={"sub": user.id, "username": user.username, "ws": workspace_id},
|
||||
expires_delta=expires,
|
||||
)
|
||||
|
||||
return {"access_token": new_token, "token_type": "bearer", "workspace_id": workspace_id}
|
||||
|
||||
|
||||
# ─── 密码重置 ───────────────────────────────────────────────
|
||||
|
||||
RESET_CODE_TTL_SEC = 600 # 验证码 10 分钟有效
|
||||
RESET_RATE_LIMIT_SEC = 60 # 同一邮箱 60 秒内只能发一次
|
||||
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def email_format(cls, v: str) -> str:
|
||||
if not v or not re.match(r"^[^@]+@[^@]+\.[^@]+$", v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
email: str
|
||||
code: str
|
||||
new_password: str
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def email_format(cls, v: str) -> str:
|
||||
if not v or not re.match(r"^[^@]+@[^@]+\.[^@]+$", v):
|
||||
raise ValueError("邮箱格式无效")
|
||||
return v.lower()
|
||||
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def password_length(cls, v: str) -> str:
|
||||
if len(v) < 6:
|
||||
raise ValueError("密码不少于 6 个字符")
|
||||
if len(v) > 32:
|
||||
raise ValueError("密码不超过 32 个字符")
|
||||
return v
|
||||
|
||||
|
||||
async def _send_reset_email(email: str, code: str) -> bool:
|
||||
"""发送密码重置邮件。SMTP 不可用时记日志。"""
|
||||
try:
|
||||
import aiosmtplib
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
smtp_host = getattr(settings, 'SMTP_HOST', '') or 'smtp.qq.com'
|
||||
smtp_port = int(getattr(settings, 'SMTP_PORT', 0) or 587)
|
||||
smtp_user = getattr(settings, 'SMTP_USER', '') or ''
|
||||
smtp_password = getattr(settings, 'SMTP_PASSWORD', '') or ''
|
||||
|
||||
if not smtp_user or not smtp_password:
|
||||
logger.warning("SMTP 未配置,无法发送邮件。重置码: %s", code)
|
||||
return False
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = smtp_user
|
||||
msg['To'] = email
|
||||
msg['Subject'] = '天工智能体 - 密码重置验证码'
|
||||
msg.attach(MIMEText(
|
||||
f'您的密码重置验证码是:<b>{code}</b><br><br>'
|
||||
f'验证码 10 分钟内有效。如非本人操作请忽略此邮件。',
|
||||
'html', 'utf-8'
|
||||
))
|
||||
|
||||
await aiosmtplib.send(
|
||||
msg, hostname=smtp_host, port=smtp_port,
|
||||
username=smtp_user, password=smtp_password,
|
||||
use_tls=smtp_port == 587,
|
||||
)
|
||||
logger.info("密码重置邮件已发送至 %s", email)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("邮件发送失败: %s,重置码: %s", e, code)
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/forgot-password")
|
||||
async def forgot_password(body: ForgotPasswordRequest, db: Session = Depends(get_db)):
|
||||
"""发送密码重置验证码。"""
|
||||
user = db.query(User).filter(User.email == body.email).first()
|
||||
if not user:
|
||||
# 不泄露邮箱是否注册,统一返回成功
|
||||
return {"message": "如果邮箱已注册,验证码已发送"}
|
||||
|
||||
redis = get_redis_client()
|
||||
|
||||
# 频率限制
|
||||
rate_key = f"pwd_reset_rate:{body.email}"
|
||||
if redis:
|
||||
if redis.exists(rate_key):
|
||||
ttl = redis.ttl(rate_key)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"操作过于频繁,请 {ttl} 秒后重试"
|
||||
)
|
||||
|
||||
code = secrets.randbelow(900000) + 100000 # 6 位数字
|
||||
code_str = str(code)
|
||||
|
||||
# 存储到 Redis
|
||||
code_key = f"pwd_reset_code:{body.email}"
|
||||
if redis:
|
||||
redis.setex(code_key, RESET_CODE_TTL_SEC, code_str)
|
||||
redis.setex(rate_key, RESET_RATE_LIMIT_SEC, "1")
|
||||
else:
|
||||
# 无 Redis 时用内存存储(重启失效)
|
||||
if not hasattr(forgot_password, '_memory_store'):
|
||||
forgot_password._memory_store = {}
|
||||
forgot_password._memory_rate = {}
|
||||
forgot_password._memory_store[body.email] = {
|
||||
"code": code_str,
|
||||
"expires_at": datetime.utcnow() + timedelta(seconds=RESET_CODE_TTL_SEC),
|
||||
}
|
||||
forgot_password._memory_rate[body.email] = \
|
||||
datetime.utcnow() + timedelta(seconds=RESET_RATE_LIMIT_SEC)
|
||||
|
||||
# 尝试发送邮件
|
||||
sent = await _send_reset_email(body.email, code_str)
|
||||
|
||||
if not sent:
|
||||
# SMTP 未配置时记录验证码并返回(开发/测试环境)
|
||||
logger.info("开发模式:%s 的密码重置验证码为 %s", body.email, code_str)
|
||||
return {
|
||||
"message": "验证码已生成",
|
||||
"dev_code": code_str,
|
||||
}
|
||||
|
||||
return {"message": "验证码已发送至邮箱"}
|
||||
|
||||
|
||||
@router.post("/reset-password")
|
||||
async def reset_password(body: ResetPasswordRequest, db: Session = Depends(get_db)):
|
||||
"""使用验证码重置密码。"""
|
||||
user = db.query(User).filter(User.email == body.email).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="邮箱未注册")
|
||||
|
||||
redis = get_redis_client()
|
||||
code_key = f"pwd_reset_code:{body.email}"
|
||||
stored_code = None
|
||||
|
||||
if redis:
|
||||
stored_code = redis.get(code_key)
|
||||
elif hasattr(forgot_password, '_memory_store'):
|
||||
entry = forgot_password._memory_store.get(body.email, {})
|
||||
if entry and entry.get("expires_at", datetime.min) > datetime.utcnow():
|
||||
stored_code = entry.get("code")
|
||||
|
||||
if not stored_code:
|
||||
raise HTTPException(status_code=400, detail="验证码已过期或未请求")
|
||||
|
||||
if stored_code != body.code.strip():
|
||||
raise HTTPException(status_code=400, detail="验证码错误")
|
||||
|
||||
# 更新密码
|
||||
user.password_hash = get_password_hash(body.new_password)
|
||||
db.commit()
|
||||
|
||||
# 清除验证码
|
||||
if redis:
|
||||
redis.delete(code_key)
|
||||
elif hasattr(forgot_password, '_memory_store'):
|
||||
forgot_password._memory_store.pop(body.email, None)
|
||||
|
||||
logger.info("用户 %s 密码重置成功", user.username)
|
||||
return {"message": "密码重置成功,请使用新密码登录"}
|
||||
|
||||
|
||||
async def get_optional_user(
|
||||
token: str | None = Depends(oauth2_scheme_optional),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User | None:
|
||||
"""获取当前用户(可选登录)。未提供 token 或 token 无效时返回 None。"""
|
||||
if not token:
|
||||
return None
|
||||
from app.core.security import decode_access_token
|
||||
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
user_id = payload.get("sub")
|
||||
if user_id is None:
|
||||
return None
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
return user
|
||||
|
||||
64
backend/app/api/deps.py
Normal file
64
backend/app/api/deps.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
FastAPI 依赖注入 — Workspace 上下文、权限检查
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.security import decode_access_token
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceContext:
|
||||
"""工作区上下文 — 包装当前用户 + 当前工作区 ID"""
|
||||
user: User
|
||||
workspace_id: str
|
||||
|
||||
|
||||
def get_current_workspace_id(request: Request) -> str:
|
||||
"""从 JWT 的 `ws` 字段提取当前工作区 ID。"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证令牌")
|
||||
|
||||
token = auth_header[7:]
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise HTTPException(status_code=401, detail="无效的访问令牌")
|
||||
|
||||
ws_id = payload.get("ws")
|
||||
if not ws_id:
|
||||
raise HTTPException(status_code=400, detail="令牌中未包含工作区信息,请重新登录")
|
||||
|
||||
return ws_id
|
||||
|
||||
|
||||
def get_workspace_context(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> WorkspaceContext:
|
||||
"""获取完整的 Workspace 上下文(用户 + 工作区),用于需要 workspace 过滤的接口。"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证令牌")
|
||||
|
||||
token = auth_header[7:]
|
||||
payload = decode_access_token(token)
|
||||
if payload is None:
|
||||
raise HTTPException(status_code=401, detail="无效的访问令牌")
|
||||
|
||||
user_id = payload.get("sub")
|
||||
ws_id = payload.get("ws")
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="无效的访问令牌")
|
||||
|
||||
user = db.query(User).filter(User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="用户不存在")
|
||||
|
||||
return WorkspaceContext(user=user, workspace_id=ws_id or "")
|
||||
80
backend/app/api/fcm.py
Normal file
80
backend/app/api/fcm.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
FCM 推送 API — Android 设备 Token 管理
|
||||
|
||||
提供设备推送令牌的注册和注销接口。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.auth import get_current_user
|
||||
from app.core.database import SessionLocal
|
||||
from app.models.fcm_token import FcmToken
|
||||
from app.models.user import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/fcm", tags=["fcm"])
|
||||
|
||||
|
||||
class FcmRegisterRequest(BaseModel):
|
||||
token: str = Field(..., min_length=32, max_length=512, description="FCM/APNs 设备令牌")
|
||||
platform: str = Field(default="android", description="android / ios / web")
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register_fcm_token(
|
||||
req: FcmRegisterRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""注册设备推送 Token。同一 Token 多次注册会更新绑定用户和时间。"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
existing = db.query(FcmToken).filter(FcmToken.token == req.token).first()
|
||||
|
||||
if existing:
|
||||
existing.user_id = current_user.id
|
||||
existing.platform = req.platform
|
||||
db.commit()
|
||||
logger.info("FCM Token 已更新: user=%s platform=%s", current_user.id, req.platform)
|
||||
return {"status": "updated", "id": str(existing.id)}
|
||||
|
||||
fcm = FcmToken(
|
||||
user_id=current_user.id,
|
||||
token=req.token,
|
||||
platform=req.platform,
|
||||
)
|
||||
db.add(fcm)
|
||||
db.commit()
|
||||
db.refresh(fcm)
|
||||
logger.info("FCM Token 已注册: user=%s platform=%s", current_user.id, req.platform)
|
||||
return {"status": "registered", "id": str(fcm.id)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/unregister")
|
||||
async def unregister_fcm_token(
|
||||
token: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""注销设备推送 Token。"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
fcm = (
|
||||
db.query(FcmToken)
|
||||
.filter(FcmToken.token == token, FcmToken.user_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if fcm:
|
||||
db.delete(fcm)
|
||||
db.commit()
|
||||
logger.info("FCM Token 已注销: user=%s", current_user.id)
|
||||
return {"status": "deleted"}
|
||||
return {"status": "not_found"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
93
backend/app/api/feedback.py
Normal file
93
backend/app/api/feedback.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
反馈数据 API — 用户反馈分析、记录查询、反例生成
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services.feedback_learner import feedback_learner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/feedback", tags=["feedback"])
|
||||
|
||||
|
||||
@router.get("/analysis")
|
||||
def get_feedback_analysis(
|
||||
days: int = Query(7, ge=1, le=90, description="统计天数"),
|
||||
agent_name: Optional[str] = Query(None, description="按 Agent 筛选"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取反馈分析报告 — 信号分布、负面率、策略建议。"""
|
||||
result = feedback_learner.analyze_feedback_patterns(agent_name=agent_name, days=days)
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/records")
|
||||
def list_feedback_records(
|
||||
agent_name: Optional[str] = Query(None, description="按 Agent 筛选"),
|
||||
signal_type: Optional[str] = Query(None, description="按信号类型筛选"),
|
||||
days: int = Query(30, ge=1, le=365, description="统计天数"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数"),
|
||||
offset: int = Query(0, ge=0, description="偏移量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""查询反馈记录列表。"""
|
||||
from datetime import datetime, timedelta
|
||||
from app.models.feedback_record import FeedbackRecord
|
||||
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
q = db.query(FeedbackRecord).filter(FeedbackRecord.created_at >= since)
|
||||
|
||||
if agent_name:
|
||||
q = q.filter(FeedbackRecord.agent_name == agent_name)
|
||||
if signal_type:
|
||||
q = q.filter(FeedbackRecord.signal_type == signal_type)
|
||||
|
||||
total = q.count()
|
||||
records = (
|
||||
q.order_by(FeedbackRecord.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"items": [
|
||||
{
|
||||
"id": r.id,
|
||||
"user_id": r.user_id,
|
||||
"signal_type": r.signal_type,
|
||||
"severity": r.severity,
|
||||
"execution_log_id": r.execution_log_id,
|
||||
"agent_name": r.agent_name,
|
||||
"original_output": (r.original_output or "")[:200],
|
||||
"user_correction": (r.user_correction or "")[:200],
|
||||
"learned": r.learned,
|
||||
"lesson_summary": r.lesson_summary,
|
||||
"created_at": r.created_at.isoformat() if r.created_at else None,
|
||||
}
|
||||
for r in records
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@router.get("/negative-examples/{agent_name}")
|
||||
def get_negative_examples(
|
||||
agent_name: str,
|
||||
limit: int = Query(5, ge=1, le=20, description="返回条数"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取指定 Agent 的反例(用于改进 system prompt)。"""
|
||||
examples = feedback_learner.generate_negative_examples(agent_name=agent_name, limit=limit)
|
||||
return {"agent_name": agent_name, "count": len(examples), "examples": examples}
|
||||
109
backend/app/api/knowledge_dashboard.py
Normal file
109
backend/app/api/knowledge_dashboard.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
知识仪表盘 API — 瓶颈分析、知识条目查询、趋势统计
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.knowledge_entry import KnowledgeEntry
|
||||
from app.services.bottleneck_detector import bottleneck_detector
|
||||
from app.services.optimization_engine import optimization_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/knowledge-dashboard", tags=["knowledge-dashboard"])
|
||||
|
||||
|
||||
@router.get("/bottlenecks")
|
||||
def get_bottlenecks(
|
||||
hours: int = Query(168, ge=1, le=2160, description="分析时长(小时)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""瓶颈分析 — 检测工作流性能瓶颈并生成优化建议。"""
|
||||
analysis = bottleneck_detector.run_full_analysis(hours=hours)
|
||||
optimizations = optimization_engine.generate_optimizations(analysis.get("bottlenecks", []))
|
||||
# Build recommendations dict keyed by node_type for frontend lookup
|
||||
recommendations_map = {}
|
||||
for opt in optimizations:
|
||||
recommendations_map[opt["node_type"]] = {
|
||||
"node_type": opt["node_type"],
|
||||
"severity": opt["severity"],
|
||||
"current_state": opt.get("current_metrics", {}),
|
||||
"changes": opt.get("changes", []),
|
||||
}
|
||||
return {
|
||||
**analysis,
|
||||
"recommendations": list(optimizations),
|
||||
"optimizations": recommendations_map,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/entries")
|
||||
def get_knowledge_entries(
|
||||
days: int = Query(7, ge=1, le=365, description="统计天数"),
|
||||
limit: int = Query(50, ge=1, le=200, description="返回条数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取知识条目列表,按创建时间倒序。"""
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
entries = (
|
||||
db.query(KnowledgeEntry)
|
||||
.filter(
|
||||
KnowledgeEntry.created_at >= since,
|
||||
KnowledgeEntry.is_active == True,
|
||||
)
|
||||
.order_by(KnowledgeEntry.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
return [e.to_dict() for e in entries]
|
||||
|
||||
|
||||
@router.get("/trend")
|
||||
def get_knowledge_trend(
|
||||
days: int = Query(7, ge=1, le=365, description="统计天数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""知识条目增长趋势 — 按天统计新增数量。"""
|
||||
since = datetime.now() - timedelta(days=days)
|
||||
|
||||
# GROUP BY date(created_at)
|
||||
rows = (
|
||||
db.query(
|
||||
func.date(KnowledgeEntry.created_at).label("date"),
|
||||
func.count(KnowledgeEntry.id).label("count"),
|
||||
)
|
||||
.filter(
|
||||
KnowledgeEntry.created_at >= since,
|
||||
KnowledgeEntry.is_active == True,
|
||||
)
|
||||
.group_by(func.date(KnowledgeEntry.created_at))
|
||||
.order_by(func.date(KnowledgeEntry.created_at).asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
# Fill in missing dates with 0 count
|
||||
trend = []
|
||||
current_date = since.date()
|
||||
end_date = datetime.now().date()
|
||||
date_counts = {row.date: row.count for row in rows}
|
||||
|
||||
while current_date <= end_date:
|
||||
trend.append({
|
||||
"date": current_date.isoformat(),
|
||||
"count": date_counts.get(current_date, 0),
|
||||
})
|
||||
current_date += timedelta(days=1)
|
||||
|
||||
return trend
|
||||
@@ -58,9 +58,10 @@ async def create_agent_from_template_v1(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
workflow_config = build_workflow_for_template(
|
||||
body.template_id, body.parameters or {}
|
||||
)
|
||||
params = dict(body.parameters or {})
|
||||
if body.contract_id:
|
||||
params["contract_id"] = body.contract_id
|
||||
workflow_config = build_workflow_for_template(body.template_id, params)
|
||||
except ValueError as e:
|
||||
raise ValidationError(str(e))
|
||||
|
||||
|
||||
80
backend/app/api/push.py
Normal file
80
backend/app/api/push.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""浏览器推送订阅 API"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.auth import get_current_user, get_optional_user
|
||||
from app.models.user import User
|
||||
from app.core.database import SessionLocal
|
||||
from app.models.push_subscription import PushSubscription
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/push", tags=["push"])
|
||||
|
||||
|
||||
class PushSubscriptionRequest(BaseModel):
|
||||
endpoint: str = Field(..., description="Push 订阅端点 URL")
|
||||
keys: dict = Field(..., description="包含 p256dh 和 auth")
|
||||
user_agent: str = Field(default="", description="设备 UA")
|
||||
|
||||
|
||||
@router.post("/subscribe")
|
||||
async def subscribe_push(
|
||||
req: PushSubscriptionRequest,
|
||||
current_user: User = Depends(get_optional_user),
|
||||
):
|
||||
"""保存浏览器推送订阅。用户可选登录。"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# 检查是否已存在相同 endpoint
|
||||
existing = db.query(PushSubscription).filter(
|
||||
PushSubscription.endpoint == req.endpoint
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# 更新已有记录
|
||||
existing.p256dh = req.keys.get("p256dh", "")
|
||||
existing.auth = req.keys.get("auth", "")
|
||||
existing.user_id = str(current_user.id) if current_user else existing.user_id
|
||||
existing.user_agent = req.user_agent
|
||||
db.commit()
|
||||
return {"status": "updated", "subscription_id": str(existing.id)}
|
||||
|
||||
sub = PushSubscription(
|
||||
user_id=str(current_user.id) if current_user else None,
|
||||
endpoint=req.endpoint,
|
||||
p256dh=req.keys.get("p256dh", ""),
|
||||
auth=req.keys.get("auth", ""),
|
||||
user_agent=req.user_agent,
|
||||
)
|
||||
db.add(sub)
|
||||
db.commit()
|
||||
db.refresh(sub)
|
||||
return {"status": "created", "subscription_id": str(sub.id)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/unsubscribe")
|
||||
async def unsubscribe_push(
|
||||
endpoint: str,
|
||||
current_user: User = Depends(get_optional_user),
|
||||
):
|
||||
"""取消推送订阅。"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(PushSubscription).filter(PushSubscription.endpoint == endpoint)
|
||||
if current_user:
|
||||
q = q.filter(PushSubscription.user_id == str(current_user.id))
|
||||
sub = q.first()
|
||||
if sub:
|
||||
db.delete(sub)
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
return {"status": "not_found"}
|
||||
finally:
|
||||
db.close()
|
||||
366
backend/app/api/scene_contracts.py
Normal file
366
backend/app/api/scene_contracts.py
Normal file
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
场景契约 API — 统一 DSL 输入契约的 CRUD 与预置查询
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.scene_contract import SceneContract
|
||||
from app.services.scene_contract_service import (
|
||||
build_system_prompt_from_contract,
|
||||
build_acceptance_prompt,
|
||||
validate_input_against_contract,
|
||||
get_preset_contract,
|
||||
list_preset_contracts_meta,
|
||||
PRESET_CONTRACTS,
|
||||
ContractPromptConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/scene-contracts", tags=["scene-contracts"])
|
||||
|
||||
|
||||
# ── Pydantic schemas ──
|
||||
|
||||
class DeliverableItem(BaseModel):
|
||||
name: str
|
||||
format: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ExampleItem(BaseModel):
|
||||
input: str
|
||||
output: str
|
||||
|
||||
|
||||
class ContractCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
goal: str
|
||||
role: Optional[str] = None
|
||||
input_description: Optional[str] = None
|
||||
input_schema: Optional[Dict[str, Any]] = None
|
||||
constraints: List[str] = Field(default_factory=list)
|
||||
forbidden_actions: List[str] = Field(default_factory=list)
|
||||
required_tools: List[str] = Field(default_factory=list)
|
||||
deliverables: List[DeliverableItem] = Field(default_factory=list)
|
||||
acceptance_criteria: List[str] = Field(default_factory=list)
|
||||
output_schema: Optional[Dict[str, Any]] = None
|
||||
examples: List[ExampleItem] = Field(default_factory=list)
|
||||
category: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
is_public: bool = False
|
||||
template_binding: Optional[str] = None
|
||||
|
||||
|
||||
class ContractUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
goal: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
input_description: Optional[str] = None
|
||||
input_schema: Optional[Dict[str, Any]] = None
|
||||
constraints: Optional[List[str]] = None
|
||||
forbidden_actions: Optional[List[str]] = None
|
||||
required_tools: Optional[List[str]] = None
|
||||
deliverables: Optional[List[DeliverableItem]] = None
|
||||
acceptance_criteria: Optional[List[str]] = None
|
||||
output_schema: Optional[Dict[str, Any]] = None
|
||||
examples: Optional[List[ExampleItem]] = None
|
||||
category: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
is_public: Optional[bool] = None
|
||||
template_binding: Optional[str] = None
|
||||
|
||||
|
||||
class ContractResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
goal: str
|
||||
role: Optional[str] = None
|
||||
input_description: Optional[str] = None
|
||||
input_schema: Optional[Dict[str, Any]] = None
|
||||
constraints: List[str] = []
|
||||
forbidden_actions: List[str] = []
|
||||
required_tools: List[str] = []
|
||||
deliverables: List[Dict[str, Any]] = []
|
||||
acceptance_criteria: List[str] = []
|
||||
output_schema: Optional[Dict[str, Any]] = None
|
||||
examples: List[Dict[str, Any]] = []
|
||||
category: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
version: int = 1
|
||||
is_public: int = 0
|
||||
use_count: int = 0
|
||||
user_id: Optional[str] = None
|
||||
template_binding: Optional[str] = None
|
||||
created_at: Optional[str] = None
|
||||
updated_at: Optional[str] = None
|
||||
|
||||
|
||||
class ContractMetaItem(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
category: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
goal: str = ""
|
||||
deliverable_count: int = 0
|
||||
constraint_count: int = 0
|
||||
|
||||
|
||||
class PromptGenerateRequest(BaseModel):
|
||||
contract_id: Optional[str] = None
|
||||
contract_data: Optional[Dict[str, Any]] = None
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class PromptGenerateResponse(BaseModel):
|
||||
system_prompt: str
|
||||
|
||||
|
||||
class AcceptanceEvaluateRequest(BaseModel):
|
||||
contract_id: Optional[str] = None
|
||||
contract_data: Optional[Dict[str, Any]] = None
|
||||
agent_output: str
|
||||
|
||||
|
||||
class AcceptanceEvaluateResponse(BaseModel):
|
||||
evaluation_prompt: str
|
||||
|
||||
|
||||
class InputValidateRequest(BaseModel):
|
||||
contract_id: Optional[str] = None
|
||||
contract_data: Optional[Dict[str, Any]] = None
|
||||
user_input: Dict[str, Any]
|
||||
|
||||
|
||||
class InputValidateResponse(BaseModel):
|
||||
valid: bool
|
||||
errors: List[str] = []
|
||||
warnings: List[str] = []
|
||||
|
||||
|
||||
# ── Predefined preset contracts ──
|
||||
|
||||
@router.get("/presets", response_model=List[ContractMetaItem])
|
||||
async def list_preset_contracts(current_user: User = Depends(get_current_user)):
|
||||
"""列出所有预置场景契约"""
|
||||
_ = current_user
|
||||
return list_preset_contracts_meta()
|
||||
|
||||
|
||||
@router.get("/presets/{contract_id}")
|
||||
async def get_preset_contract_detail(
|
||||
contract_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取单个预置契约的完整内容"""
|
||||
_ = current_user
|
||||
contract = get_preset_contract(contract_id)
|
||||
if not contract:
|
||||
raise HTTPException(status_code=404, detail=f"预置契约不存在: {contract_id}")
|
||||
return {"id": contract_id, **contract}
|
||||
|
||||
|
||||
# ── CRUD for user contracts ──
|
||||
|
||||
@router.get("/", response_model=List[ContractResponse])
|
||||
async def list_contracts(
|
||||
category: Optional[str] = Query(None),
|
||||
is_public: Optional[bool] = Query(None),
|
||||
template_binding: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""列出用户的场景契约"""
|
||||
query = db.query(SceneContract).filter(
|
||||
(SceneContract.user_id == current_user.id) | (SceneContract.is_public == 1)
|
||||
)
|
||||
if category:
|
||||
query = query.filter(SceneContract.category == category)
|
||||
if is_public is not None:
|
||||
query = query.filter(SceneContract.is_public == (1 if is_public else 0))
|
||||
if template_binding:
|
||||
query = query.filter(SceneContract.template_binding == template_binding)
|
||||
|
||||
contracts = query.order_by(SceneContract.updated_at.desc()).all()
|
||||
return [c.to_dict() for c in contracts]
|
||||
|
||||
|
||||
@router.post("/", response_model=ContractResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_contract(
|
||||
body: ContractCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建自定义场景契约"""
|
||||
contract = SceneContract(
|
||||
id=str(uuid.uuid4()),
|
||||
name=body.name,
|
||||
description=body.description,
|
||||
goal=body.goal,
|
||||
role=body.role,
|
||||
input_description=body.input_description,
|
||||
input_schema=body.input_schema,
|
||||
constraints=body.constraints,
|
||||
forbidden_actions=body.forbidden_actions,
|
||||
required_tools=body.required_tools,
|
||||
deliverables=[d.model_dump() for d in body.deliverables],
|
||||
acceptance_criteria=body.acceptance_criteria,
|
||||
output_schema=body.output_schema,
|
||||
examples=[e.model_dump() for e in body.examples],
|
||||
category=body.category,
|
||||
tags=body.tags,
|
||||
is_public=1 if body.is_public else 0,
|
||||
template_binding=body.template_binding,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
db.add(contract)
|
||||
db.commit()
|
||||
db.refresh(contract)
|
||||
return contract.to_dict()
|
||||
|
||||
|
||||
@router.get("/{contract_id}", response_model=ContractResponse)
|
||||
async def get_contract(
|
||||
contract_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取场景契约详情"""
|
||||
# 先查预置
|
||||
preset = get_preset_contract(contract_id)
|
||||
if preset:
|
||||
return {"id": contract_id, "version": 1, "is_public": 1, "use_count": 0,
|
||||
"user_id": None, "template_binding": None,
|
||||
"created_at": None, "updated_at": None, **preset}
|
||||
|
||||
contract = db.query(SceneContract).filter(SceneContract.id == contract_id).first()
|
||||
if not contract:
|
||||
raise HTTPException(status_code=404, detail="契约不存在")
|
||||
if contract.user_id != current_user.id and contract.is_public == 0:
|
||||
raise HTTPException(status_code=403, detail="无权访问此契约")
|
||||
return contract.to_dict()
|
||||
|
||||
|
||||
@router.put("/{contract_id}", response_model=ContractResponse)
|
||||
async def update_contract(
|
||||
contract_id: str,
|
||||
body: ContractUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新场景契约"""
|
||||
contract = db.query(SceneContract).filter(SceneContract.id == contract_id).first()
|
||||
if not contract:
|
||||
raise HTTPException(status_code=404, detail="契约不存在")
|
||||
if contract.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="只能修改自己的契约")
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
# Handle list/dict fields
|
||||
for field in ["deliverables", "examples"]:
|
||||
if field in update_data and update_data[field] is not None:
|
||||
update_data[field] = [item if isinstance(item, dict) else item.model_dump()
|
||||
for item in update_data[field]]
|
||||
if "is_public" in update_data and update_data["is_public"] is not None:
|
||||
update_data["is_public"] = 1 if update_data["is_public"] else 0
|
||||
|
||||
for key, value in update_data.items():
|
||||
if value is not None:
|
||||
setattr(contract, key, value)
|
||||
|
||||
contract.version = (contract.version or 1) + 1
|
||||
db.commit()
|
||||
db.refresh(contract)
|
||||
return contract.to_dict()
|
||||
|
||||
|
||||
@router.delete("/{contract_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_contract(
|
||||
contract_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""删除场景契约"""
|
||||
contract = db.query(SceneContract).filter(SceneContract.id == contract_id).first()
|
||||
if not contract:
|
||||
raise HTTPException(status_code=404, detail="契约不存在")
|
||||
if contract.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="只能删除自己的契约")
|
||||
db.delete(contract)
|
||||
db.commit()
|
||||
return Response(status_code=status.HTTP_204_NO_CONTENT)
|
||||
|
||||
|
||||
# ── DSL 操作端点 ──
|
||||
|
||||
@router.post("/generate-prompt", response_model=PromptGenerateResponse)
|
||||
async def generate_prompt_from_contract(
|
||||
body: PromptGenerateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""从契约生成 system prompt"""
|
||||
_ = current_user
|
||||
|
||||
# 获取契约数据
|
||||
contract_data = body.contract_data
|
||||
if not contract_data and body.contract_id:
|
||||
preset = get_preset_contract(body.contract_id)
|
||||
if preset:
|
||||
contract_data = preset
|
||||
if not contract_data:
|
||||
raise HTTPException(status_code=400, detail="请提供 contract_id 或 contract_data")
|
||||
|
||||
config = ContractPromptConfig(**(body.config or {}))
|
||||
prompt = build_system_prompt_from_contract(contract_data, config)
|
||||
return {"system_prompt": prompt}
|
||||
|
||||
|
||||
@router.post("/evaluate", response_model=AcceptanceEvaluateResponse)
|
||||
async def evaluate_output_against_contract(
|
||||
body: AcceptanceEvaluateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成验收评估 prompt(用于评估 Agent 输出是否满足契约)"""
|
||||
_ = current_user
|
||||
|
||||
contract_data = body.contract_data
|
||||
if not contract_data and body.contract_id:
|
||||
preset = get_preset_contract(body.contract_id)
|
||||
if preset:
|
||||
contract_data = preset
|
||||
if not contract_data:
|
||||
raise HTTPException(status_code=400, detail="请提供 contract_id 或 contract_data")
|
||||
|
||||
prompt = build_acceptance_prompt(contract_data, body.agent_output)
|
||||
return {"evaluation_prompt": prompt}
|
||||
|
||||
|
||||
@router.post("/validate-input", response_model=InputValidateResponse)
|
||||
async def validate_input_against_contract_endpoint(
|
||||
body: InputValidateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""根据契约的 input_schema 验证用户输入"""
|
||||
_ = current_user
|
||||
|
||||
contract_data = body.contract_data
|
||||
if not contract_data and body.contract_id:
|
||||
preset = get_preset_contract(body.contract_id)
|
||||
if preset:
|
||||
contract_data = preset
|
||||
if not contract_data:
|
||||
raise HTTPException(status_code=400, detail="请提供 contract_id 或 contract_data")
|
||||
|
||||
result = validate_input_against_contract(contract_data, body.user_input)
|
||||
return result
|
||||
350
backend/app/api/system_logs.py
Normal file
350
backend/app/api/system_logs.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
系统日志统一查询 API
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, text, case
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.config import settings
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.execution_log import ExecutionLog
|
||||
from app.models.agent_execution_log import AgentExecutionLog
|
||||
from app.models.agent_llm_log import AgentLLMLog
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/system-logs",
|
||||
tags=["system-logs"],
|
||||
responses={
|
||||
401: {"description": "未授权"},
|
||||
403: {"description": "无权访问"},
|
||||
500: {"description": "服务器内部错误"}
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic Schemas ─────────────────────────────────────────────
|
||||
|
||||
class UnifiedLogItem(BaseModel):
|
||||
id: str
|
||||
source: str # "execution" | "agent" | "llm"
|
||||
level: Optional[str] = None
|
||||
message: str
|
||||
timestamp: Optional[datetime] = None
|
||||
resource_type: Optional[str] = None
|
||||
resource_id: Optional[str] = None
|
||||
duration_ms: Optional[int] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class LogStatsResponse(BaseModel):
|
||||
total_count: int
|
||||
error_count: int
|
||||
warn_count: int
|
||||
info_count: int
|
||||
source_breakdown: Dict[str, int] # {"execution": 123, "agent": 45, "llm": 67}
|
||||
hourly_trend: List[Dict[str, Any]] # [{"hour": "2026-05-10T14", "count": 5}, ...]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AppLogItem(BaseModel):
|
||||
line_number: int
|
||||
content: str
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
def _build_union_query(
|
||||
db: Session,
|
||||
source: Optional[str] = None,
|
||||
level: Optional[str] = None,
|
||||
keyword: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
user_id: Optional[str] = None,
|
||||
):
|
||||
"""构建跨表 UNION ALL 查询"""
|
||||
queries: list = []
|
||||
params: dict = {}
|
||||
|
||||
# 1) 工作流执行日志
|
||||
if source in (None, "all", "execution"):
|
||||
q = db.query(
|
||||
ExecutionLog.id.label("id"),
|
||||
text("'execution'").label("source"),
|
||||
ExecutionLog.level.label("level"),
|
||||
ExecutionLog.message.label("message"),
|
||||
ExecutionLog.timestamp.label("timestamp"),
|
||||
ExecutionLog.node_type.label("resource_type"),
|
||||
ExecutionLog.execution_id.label("resource_id"),
|
||||
ExecutionLog.duration.label("duration_ms"),
|
||||
text("NULL").label("username"),
|
||||
)
|
||||
if level:
|
||||
q = q.filter(ExecutionLog.level == level.upper())
|
||||
if keyword:
|
||||
q = q.filter(ExecutionLog.message.contains(keyword))
|
||||
if start_date:
|
||||
q = q.filter(ExecutionLog.timestamp >= start_date)
|
||||
if end_date:
|
||||
q = q.filter(ExecutionLog.timestamp <= end_date)
|
||||
queries.append(q)
|
||||
|
||||
# 2) Agent 执行日志
|
||||
if source in (None, "all", "agent"):
|
||||
q = db.query(
|
||||
AgentExecutionLog.id.label("id"),
|
||||
text("'agent'").label("source"),
|
||||
case(
|
||||
(AgentExecutionLog.status == "failed", "ERROR"),
|
||||
(AgentExecutionLog.status == "completed", "INFO"),
|
||||
else_="INFO"
|
||||
).label("level"),
|
||||
AgentExecutionLog.user_message.label("message"),
|
||||
AgentExecutionLog.created_at.label("timestamp"),
|
||||
text("'agent_chat'").label("resource_type"),
|
||||
AgentExecutionLog.agent_id.label("resource_id"),
|
||||
AgentExecutionLog.total_latency_ms.label("duration_ms"),
|
||||
text("NULL").label("username"),
|
||||
)
|
||||
if level:
|
||||
if level.upper() == "ERROR":
|
||||
q = q.filter(AgentExecutionLog.status == "failed")
|
||||
elif level.upper() == "WARN":
|
||||
q = q.filter(AgentExecutionLog.status == "failed")
|
||||
else:
|
||||
q = q.filter(AgentExecutionLog.status != "failed")
|
||||
if keyword:
|
||||
q = q.filter(AgentExecutionLog.user_message.contains(keyword))
|
||||
if start_date:
|
||||
q = q.filter(AgentExecutionLog.created_at >= start_date)
|
||||
if end_date:
|
||||
q = q.filter(AgentExecutionLog.created_at <= end_date)
|
||||
if user_id:
|
||||
q = q.filter(AgentExecutionLog.user_id == user_id)
|
||||
queries.append(q)
|
||||
|
||||
# 3) LLM 调用日志
|
||||
if source in (None, "all", "llm"):
|
||||
q = db.query(
|
||||
AgentLLMLog.id.label("id"),
|
||||
text("'llm'").label("source"),
|
||||
case(
|
||||
(AgentLLMLog.status == "error", "ERROR"),
|
||||
(AgentLLMLog.status == "rate_limited", "WARN"),
|
||||
else_="INFO"
|
||||
).label("level"),
|
||||
AgentLLMLog.assistant_content.label("message"),
|
||||
AgentLLMLog.created_at.label("timestamp"),
|
||||
text("'llm_call'").label("resource_type"),
|
||||
AgentLLMLog.agent_id.label("resource_id"),
|
||||
AgentLLMLog.latency_ms.label("duration_ms"),
|
||||
text("NULL").label("username"),
|
||||
)
|
||||
if level:
|
||||
if level.upper() == "ERROR":
|
||||
q = q.filter(AgentLLMLog.status == "error")
|
||||
elif level.upper() == "WARN":
|
||||
q = q.filter(AgentLLMLog.status == "rate_limited")
|
||||
else:
|
||||
q = q.filter(AgentLLMLog.status == "success")
|
||||
if keyword:
|
||||
q = q.filter(AgentLLMLog.assistant_content.contains(keyword))
|
||||
if start_date:
|
||||
q = q.filter(AgentLLMLog.created_at >= start_date)
|
||||
if end_date:
|
||||
q = q.filter(AgentLLMLog.created_at <= end_date)
|
||||
queries.append(q)
|
||||
|
||||
return queries
|
||||
|
||||
|
||||
def _check_admin(current_user: User):
|
||||
if getattr(current_user, "role", None) != "admin":
|
||||
from app.core.exceptions import ForbiddenError
|
||||
raise ForbiddenError("仅管理员可访问系统日志")
|
||||
|
||||
|
||||
# ── Endpoints ────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=List[UnifiedLogItem])
|
||||
async def get_system_logs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
source: Optional[str] = Query(None, description="日志来源: execution/agent/llm/all"),
|
||||
level: Optional[str] = Query(None, description="日志级别: INFO/WARN/ERROR"),
|
||||
keyword: Optional[str] = Query(None, description="关键词搜索"),
|
||||
start_date: Optional[str] = Query(None, description="开始时间 ISO格式"),
|
||||
end_date: Optional[str] = Query(None, description="结束时间 ISO格式"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=1000),
|
||||
):
|
||||
"""
|
||||
统一日志查询:跨 execution_logs / agent_execution_logs / agent_llm_logs 联合查询。
|
||||
|
||||
管理员可查全部,普通用户只能查看自己相关的 Agent 执行日志和 LLM 日志。
|
||||
"""
|
||||
_check_admin(current_user)
|
||||
|
||||
# 解析时间
|
||||
sd = datetime.fromisoformat(start_date) if start_date else None
|
||||
ed = datetime.fromisoformat(end_date) if end_date else None
|
||||
|
||||
# 非 admin 限制用户范围
|
||||
user_id = None
|
||||
if getattr(current_user, "role", None) != "admin":
|
||||
user_id = current_user.id
|
||||
|
||||
source_val = source if source else "all"
|
||||
queries = _build_union_query(db, source_val, level, keyword, sd, ed, user_id)
|
||||
|
||||
if not queries:
|
||||
return []
|
||||
|
||||
# UNION ALL
|
||||
union = queries[0]
|
||||
for q in queries[1:]:
|
||||
union = union.union_all(q)
|
||||
|
||||
# 排序 + 分页
|
||||
total = union.count() if hasattr(union, 'count') else 0
|
||||
rows = union.order_by(text("timestamp DESC")).offset(skip).limit(limit).all()
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
@router.get("/stats", response_model=LogStatsResponse)
|
||||
async def get_system_logs_stats(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取系统日志统计:各级别计数、各来源计数、24小时趋势"""
|
||||
_check_admin(current_user)
|
||||
|
||||
now = datetime.utcnow()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
hours_24 = now - timedelta(hours=24)
|
||||
|
||||
# 工作流执行日志统计
|
||||
exec_total = db.query(func.count(ExecutionLog.id)).filter(
|
||||
ExecutionLog.timestamp >= today_start
|
||||
).scalar() or 0
|
||||
exec_errors = db.query(func.count(ExecutionLog.id)).filter(
|
||||
ExecutionLog.timestamp >= today_start,
|
||||
ExecutionLog.level == "ERROR"
|
||||
).scalar() or 0
|
||||
exec_warns = db.query(func.count(ExecutionLog.id)).filter(
|
||||
ExecutionLog.timestamp >= today_start,
|
||||
ExecutionLog.level == "WARN"
|
||||
).scalar() or 0
|
||||
|
||||
# Agent 执行日志统计
|
||||
agent_total = db.query(func.count(AgentExecutionLog.id)).filter(
|
||||
AgentExecutionLog.created_at >= today_start
|
||||
).scalar() or 0
|
||||
agent_errors = db.query(func.count(AgentExecutionLog.id)).filter(
|
||||
AgentExecutionLog.created_at >= today_start,
|
||||
AgentExecutionLog.status == "failed"
|
||||
).scalar() or 0
|
||||
|
||||
# LLM 调用日志统计
|
||||
llm_total = db.query(func.count(AgentLLMLog.id)).filter(
|
||||
AgentLLMLog.created_at >= today_start
|
||||
).scalar() or 0
|
||||
llm_errors = db.query(func.count(AgentLLMLog.id)).filter(
|
||||
AgentLLMLog.created_at >= today_start,
|
||||
AgentLLMLog.status == "error"
|
||||
).scalar() or 0
|
||||
|
||||
total_count = exec_total + agent_total + llm_total
|
||||
error_count = exec_errors + agent_errors + llm_errors
|
||||
warn_count = exec_warns
|
||||
info_count = total_count - error_count - warn_count
|
||||
|
||||
# 24小时趋势(按小时聚合)
|
||||
hourly: list[dict[str, Any]] = []
|
||||
for h in range(23, -1, -1):
|
||||
slot_start = now.replace(minute=0, second=0, microsecond=0) - timedelta(hours=h)
|
||||
slot_end = slot_start + timedelta(hours=1)
|
||||
|
||||
exec_h = db.query(func.count(ExecutionLog.id)).filter(
|
||||
ExecutionLog.timestamp >= slot_start,
|
||||
ExecutionLog.timestamp < slot_end
|
||||
).scalar() or 0
|
||||
agent_h = db.query(func.count(AgentExecutionLog.id)).filter(
|
||||
AgentExecutionLog.created_at >= slot_start,
|
||||
AgentExecutionLog.created_at < slot_end
|
||||
).scalar() or 0
|
||||
llm_h = db.query(func.count(AgentLLMLog.id)).filter(
|
||||
AgentLLMLog.created_at >= slot_start,
|
||||
AgentLLMLog.created_at < slot_end
|
||||
).scalar() or 0
|
||||
|
||||
hourly.append({
|
||||
"hour": slot_start.isoformat(),
|
||||
"count": exec_h + agent_h + llm_h
|
||||
})
|
||||
|
||||
return {
|
||||
"total_count": total_count,
|
||||
"error_count": error_count,
|
||||
"warn_count": warn_count,
|
||||
"info_count": info_count,
|
||||
"source_breakdown": {
|
||||
"execution": exec_total,
|
||||
"agent": agent_total,
|
||||
"llm": llm_total
|
||||
},
|
||||
"hourly_trend": hourly
|
||||
}
|
||||
|
||||
|
||||
@router.get("/app-logs", response_model=List[AppLogItem])
|
||||
async def get_app_logs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
lines: int = Query(200, ge=10, le=2000, description="读取行数"),
|
||||
level: Optional[str] = Query(None, description="按级别过滤: INFO/WARNING/ERROR"),
|
||||
):
|
||||
"""读取应用程序文件日志的尾部行"""
|
||||
_check_admin(current_user)
|
||||
|
||||
log_file = Path(settings.LOG_DIR) / "app.log"
|
||||
if not log_file.exists():
|
||||
return []
|
||||
|
||||
level_filter = level.upper() if level else None
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
|
||||
all_lines = f.readlines()
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# 取尾部 N 行,倒序输出
|
||||
tail = all_lines[-lines:]
|
||||
|
||||
result: list[dict] = []
|
||||
for i, raw in enumerate(tail):
|
||||
content = raw.rstrip("\n").rstrip("\r")
|
||||
if level_filter and level_filter not in content.upper():
|
||||
continue
|
||||
result.append({
|
||||
"line_number": len(all_lines) - len(tail) + i + 1,
|
||||
"content": content
|
||||
})
|
||||
|
||||
return result
|
||||
@@ -1,5 +1,7 @@
|
||||
"""
|
||||
Task API — 任务管理接口
|
||||
Task API — 任务管理接口 (含原子认领 + 依赖图 + Agent 状态)
|
||||
|
||||
参考 Claude Code src/utils/tasks.ts 的 API 设计
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -9,6 +11,7 @@ from datetime import datetime
|
||||
import logging
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.task_system import TaskSystem, ClaimResult, AgentState, AgentStatus
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.services import goal_service
|
||||
@@ -92,6 +95,42 @@ class TaskDependencyCheck(BaseModel):
|
||||
pending_dependencies: List[str] = []
|
||||
|
||||
|
||||
class ClaimTaskRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="认领任务的 Agent 标识")
|
||||
check_busy: bool = Field(default=True, description="是否检查 Agent 忙碌状态")
|
||||
|
||||
|
||||
class ClaimTaskResponse(BaseModel):
|
||||
success: bool
|
||||
reason: Optional[str] = None
|
||||
task: Optional[TaskResponse] = None
|
||||
busy_with_tasks: List[str] = []
|
||||
blocked_by_tasks: List[str] = []
|
||||
|
||||
|
||||
class BlockTaskRequest(BaseModel):
|
||||
from_task_id: str = Field(..., description="阻塞方任务ID")
|
||||
to_task_id: str = Field(..., description="被阻塞方任务ID")
|
||||
|
||||
|
||||
class AgentStatusResponse(BaseModel):
|
||||
agent_id: str
|
||||
status: str # idle / busy
|
||||
current_tasks: List[str] = []
|
||||
|
||||
|
||||
class ReleaseTaskRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="释放任务的 Agent 标识")
|
||||
|
||||
|
||||
class TaskCompleteRequest(BaseModel):
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskFailRequest(BaseModel):
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
# ──────────────────────────── Endpoints ────────────────────────────
|
||||
|
||||
@router.post("", response_model=TaskResponse, status_code=201)
|
||||
@@ -307,3 +346,160 @@ async def retry_task(
|
||||
error_message=str(e),
|
||||
)
|
||||
return goal_service.get_task(db, task_id)
|
||||
|
||||
|
||||
# ════════════════════ 任务系统增强 (参考 Claude Code task_system) ════════════════════
|
||||
|
||||
|
||||
def _get_task_system(db: Session) -> TaskSystem:
|
||||
return TaskSystem(db)
|
||||
|
||||
|
||||
@router.post("/{task_id}/claim", response_model=ClaimTaskResponse)
|
||||
def claim_task(
|
||||
task_id: str,
|
||||
data: ClaimTaskRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""原子认领任务 (SELECT FOR UPDATE),检查依赖+Agent忙碌"""
|
||||
ts = _get_task_system(db)
|
||||
result = ts.claim_task(task_id=task_id, agent_id=data.agent_id, check_busy=data.check_busy)
|
||||
return ClaimTaskResponse(
|
||||
success=result.success,
|
||||
reason=result.reason,
|
||||
task=TaskResponse.model_validate(result.task) if result.task else None,
|
||||
busy_with_tasks=result.busy_with_tasks,
|
||||
blocked_by_tasks=result.blocked_by_tasks,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/block", status_code=200)
|
||||
def block_task(
|
||||
data: BlockTaskRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""设置任务依赖: from_task 阻塞 to_task"""
|
||||
ts = _get_task_system(db)
|
||||
ok = ts.block_task(data.from_task_id, data.to_task_id)
|
||||
if not ok:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(404, "任务不存在")
|
||||
return {"message": "ok", "from": data.from_task_id, "to": data.to_task_id}
|
||||
|
||||
|
||||
@router.delete("/block", status_code=200)
|
||||
def unblock_task(
|
||||
data: BlockTaskRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""移除任务依赖"""
|
||||
ts = _get_task_system(db)
|
||||
ok = ts.unblock_task(data.from_task_id, data.to_task_id)
|
||||
if not ok:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(404, "任务不存在")
|
||||
return {"message": "ok"}
|
||||
|
||||
|
||||
@router.get("/agent/{agent_id}/status", response_model=AgentStatusResponse)
|
||||
def get_agent_status(
|
||||
agent_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取 Agent 忙闲状态 (idle/busy)"""
|
||||
ts = _get_task_system(db)
|
||||
state = ts.get_agent_status(agent_id)
|
||||
return AgentStatusResponse(
|
||||
agent_id=state.agent_id,
|
||||
status=state.status.value,
|
||||
current_tasks=state.current_tasks,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{task_id}/release", status_code=200)
|
||||
def release_task(
|
||||
task_id: str,
|
||||
data: ReleaseTaskRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Agent 主动释放单个任务"""
|
||||
ts = _get_task_system(db)
|
||||
ok = ts.release_task(task_id, data.agent_id)
|
||||
if not ok:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(404, "任务不存在或不属于该 Agent")
|
||||
return {"message": "ok", "task_id": task_id}
|
||||
|
||||
|
||||
@router.post("/agent/{agent_id}/unassign", status_code=200)
|
||||
def unassign_agent_tasks(
|
||||
agent_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""释放 Agent 所有未完成任务 (Agent 下线时调用)"""
|
||||
ts = _get_task_system(db)
|
||||
tasks = ts.unassign_agent_tasks(agent_id)
|
||||
return {"message": "ok", "unassigned_count": len(tasks), "task_ids": [t.id for t in tasks]}
|
||||
|
||||
|
||||
@router.post("/{task_id}/complete", response_model=TaskResponse)
|
||||
def complete_task(
|
||||
task_id: str,
|
||||
data: TaskCompleteRequest = TaskCompleteRequest(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""标记任务完成 (自动检查被阻塞任务)"""
|
||||
ts = _get_task_system(db)
|
||||
task = ts.complete_task(task_id, result=data.result)
|
||||
if not task:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(404, "任务不存在")
|
||||
return TaskResponse.model_validate(task)
|
||||
|
||||
|
||||
@router.post("/{task_id}/fail", response_model=TaskResponse)
|
||||
def fail_task(
|
||||
task_id: str,
|
||||
data: TaskFailRequest = TaskFailRequest(),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""标记任务失败"""
|
||||
ts = _get_task_system(db)
|
||||
task = ts.fail_task(task_id, error_message=data.error_message)
|
||||
if not task:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(404, "任务不存在")
|
||||
return TaskResponse.model_validate(task)
|
||||
|
||||
|
||||
@router.get("/available/{goal_id}", response_model=List[TaskResponse])
|
||||
def get_available_tasks(
|
||||
goal_id: str,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取目标下所有可执行任务 (依赖满足 + 未被认领)"""
|
||||
ts = _get_task_system(db)
|
||||
tasks = ts.get_next_available_tasks(goal_id, limit=limit)
|
||||
return [TaskResponse.model_validate(t) for t in tasks]
|
||||
|
||||
|
||||
@router.get("/{task_id}/blockers", response_model=List[TaskResponse])
|
||||
def get_task_blockers(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取任务尚未完成的阻塞依赖"""
|
||||
ts = _get_task_system(db)
|
||||
blockers = ts.get_unresolved_blockers(task_id)
|
||||
return [TaskResponse.model_validate(t) for t in blockers]
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Request
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -33,6 +33,8 @@ class ToolCreate(BaseModel):
|
||||
implementation_type: str # builtin / http / code / workflow
|
||||
implementation_config: Optional[dict] = None
|
||||
is_public: bool = False
|
||||
status: str = "active" # active / draft / deprecated
|
||||
tags: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
@@ -44,6 +46,8 @@ class ToolResponse(BaseModel):
|
||||
implementation_type: str
|
||||
implementation_config: Optional[dict] = None
|
||||
is_public: bool = False
|
||||
status: str = "active"
|
||||
tags: Optional[List[str]] = None
|
||||
use_count: int = 0
|
||||
user_id: Optional[str] = None
|
||||
created_at: str = ""
|
||||
@@ -86,6 +90,8 @@ def _tool_to_dict(tool: Tool) -> dict:
|
||||
"implementation_type": tool.implementation_type,
|
||||
"implementation_config": tool.implementation_config,
|
||||
"is_public": tool.is_public,
|
||||
"status": tool.status if hasattr(tool, "status") else "active",
|
||||
"tags": tool.tags if hasattr(tool, "tags") else None,
|
||||
"use_count": tool.use_count,
|
||||
"user_id": tool.user_id,
|
||||
"created_at": tool.created_at.isoformat() if tool.created_at else "",
|
||||
@@ -139,12 +145,18 @@ async def list_tools(
|
||||
category: Optional[str] = Query(None, description="按分类筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
scope: Optional[str] = Query("public", description="public / mine / all"),
|
||||
status: Optional[str] = Query(None, description="状态筛选: active/draft/deprecated"),
|
||||
workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Optional[User] = Depends(get_current_user),
|
||||
):
|
||||
"""浏览工具市场(含内置工具 + 数据库工具)。"""
|
||||
query = db.query(Tool)
|
||||
|
||||
# 工作区筛选
|
||||
if workspace_id:
|
||||
query = query.filter(Tool.workspace_id == workspace_id)
|
||||
|
||||
if scope == "public":
|
||||
query = query.filter(Tool.is_public == True)
|
||||
elif scope == "mine":
|
||||
@@ -152,6 +164,12 @@ async def list_tools(
|
||||
raise HTTPException(status_code=401, detail="需登录")
|
||||
query = query.filter(Tool.user_id == current_user.id)
|
||||
|
||||
if status:
|
||||
query = query.filter(Tool.status == status)
|
||||
else:
|
||||
# 默认只显示 active 状态的工具(排除 deprecated)
|
||||
query = query.filter((Tool.status == "active") | (Tool.status == None))
|
||||
|
||||
if category:
|
||||
query = query.filter(Tool.category == category)
|
||||
if search:
|
||||
@@ -216,10 +234,20 @@ async def get_tool(
|
||||
@router.post("", response_model=ToolResponse, status_code=201)
|
||||
async def create_tool(
|
||||
tool_data: ToolCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建自定义工具。"""
|
||||
# 从 JWT 提取当前工作区 ID
|
||||
from app.core.security import decode_access_token
|
||||
ws_id = None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
payload = decode_access_token(auth_header[7:])
|
||||
if payload:
|
||||
ws_id = payload.get("ws") or None
|
||||
|
||||
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
|
||||
@@ -237,20 +265,24 @@ async def create_tool(
|
||||
implementation_type=tool_data.implementation_type,
|
||||
implementation_config=tool_data.implementation_config,
|
||||
is_public=tool_data.is_public,
|
||||
status=tool_data.status,
|
||||
tags=tool_data.tags,
|
||||
user_id=current_user.id,
|
||||
workspace_id=ws_id,
|
||||
)
|
||||
db.add(tool)
|
||||
db.commit()
|
||||
db.refresh(tool)
|
||||
logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
|
||||
logger.info("工具已创建: %s (type=%s, status=%s)", tool.name, tool.implementation_type, tool.status)
|
||||
|
||||
# 刷新注册表
|
||||
tool_registry._custom_tool_configs[tool.name] = {
|
||||
**(tool.implementation_config or {}),
|
||||
"_type": tool.implementation_type,
|
||||
"_db_id": tool.id,
|
||||
}
|
||||
tool_registry._tool_schemas[tool.name] = tool.function_schema
|
||||
# 仅 active 状态注入注册表
|
||||
if tool.status == "active" or tool.status is None:
|
||||
tool_registry._custom_tool_configs[tool.name] = {
|
||||
**(tool.implementation_config or {}),
|
||||
"_type": tool.implementation_type,
|
||||
"_db_id": tool.id,
|
||||
}
|
||||
tool_registry._tool_schemas[tool.name] = tool.function_schema
|
||||
|
||||
return _tool_to_dict(tool)
|
||||
|
||||
@@ -317,12 +349,27 @@ async def delete_tool(
|
||||
db.commit()
|
||||
|
||||
# 清理注册表
|
||||
tool_registry._custom_tool_configs.pop(tool.name, None)
|
||||
tool_registry._tool_schemas.pop(tool.name, None)
|
||||
tool_registry.unregister_tool(tool.name)
|
||||
|
||||
return {"message": "工具已删除"}
|
||||
|
||||
|
||||
@router.post("/reload")
|
||||
async def reload_tools(db: Session = Depends(get_db)):
|
||||
"""从数据库重新加载所有自定义工具到注册表(热更新)。"""
|
||||
# 清除旧的自定义工具
|
||||
for name in list(tool_registry._custom_tool_configs.keys()):
|
||||
if name not in tool_registry._builtin_tools:
|
||||
tool_registry._custom_tool_configs.pop(name, None)
|
||||
tool_registry._tool_schemas.pop(name, None)
|
||||
|
||||
# 重新加载
|
||||
tool_registry.load_tools_from_db(db)
|
||||
count = len(tool_registry._custom_tool_configs)
|
||||
logger.info("工具热更新完成,已加载 %d 个自定义工具", count)
|
||||
return {"message": f"工具注册表已刷新", "custom_tool_count": count}
|
||||
|
||||
|
||||
# ─── 工具测试 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
261
backend/app/api/voice.py
Normal file
261
backend/app/api/voice.py
Normal file
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
语音 API — Android 应用语音交互接口
|
||||
|
||||
提供语音转文字 (ASR) 和文字转语音 (TTS) 的 HTTP API。
|
||||
TTS 优先使用 OpenAI TTS(需配置 OPENAI_API_KEY),否则使用免费的 Edge TTS。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/v1/voice", tags=["voice"])
|
||||
|
||||
# Edge TTS 中文语音映射 (OpenAI voice name -> Edge TTS Chinese voice)
|
||||
_EDGE_VOICE_MAP = {
|
||||
"alloy": "zh-CN-YunxiNeural",
|
||||
"echo": "zh-CN-YunyangNeural",
|
||||
"fable": "zh-CN-XiaoxiaoNeural",
|
||||
"onyx": "zh-CN-YunjianNeural",
|
||||
"nova": "zh-CN-XiaoyiNeural",
|
||||
"shimmer": "zh-CN-XiaoxiaoNeural",
|
||||
}
|
||||
|
||||
|
||||
class AsrResponse(BaseModel):
|
||||
"""语音识别响应"""
|
||||
text: str = Field(..., description="识别出的文字")
|
||||
language: str = Field(default="zh", description="语言代码")
|
||||
|
||||
|
||||
class TtsRequest(BaseModel):
|
||||
"""文字转语音请求"""
|
||||
text: str = Field(..., min_length=1, max_length=4000, description="要合成的文字")
|
||||
voice: str = Field(
|
||||
default="alloy",
|
||||
description="语音风格:alloy / echo / fable / onyx / nova / shimmer",
|
||||
)
|
||||
|
||||
|
||||
class TtsResponse(BaseModel):
|
||||
"""文字转语音响应"""
|
||||
audio_url: str = Field(..., description="音频文件下载 URL")
|
||||
text_length: int = Field(..., description="文字长度")
|
||||
voice: str = Field(..., description="使用的语音风格")
|
||||
|
||||
|
||||
# TTS 输出缓存目录
|
||||
_TTS_DIR = Path(settings.LOCAL_FILE_TOOLS_ROOT) / "tts_outputs"
|
||||
|
||||
|
||||
@router.post("/asr", response_model=AsrResponse)
|
||||
async def voice_to_text(
|
||||
file: UploadFile = File(..., description="音频文件(AAC/WAV/MP3/WebM/M4A)"),
|
||||
language: str = Query("zh", description="语言代码"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
语音转文字 (ASR)。
|
||||
|
||||
接收 Android 端录音上传的 AAC 音频文件,调用 Whisper API 返回识别文字。
|
||||
采样率建议 16000 Hz。
|
||||
"""
|
||||
# 验证文件
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名为空")
|
||||
|
||||
ext = (file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "aac")
|
||||
allowed = {"aac", "wav", "mp3", "webm", "m4a", "ogg", "flac", "mpeg"}
|
||||
if ext not in allowed:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的音频格式: {ext}")
|
||||
|
||||
# 检查文件大小
|
||||
content = await file.read()
|
||||
max_bytes = 25 * 1024 * 1024 # 25 MB
|
||||
if len(content) > max_bytes:
|
||||
raise HTTPException(status_code=400, detail=f"文件过大 ({len(content) / 1024 / 1024:.1f} MB)")
|
||||
|
||||
# 写入临时文件
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False)
|
||||
try:
|
||||
tmp.write(content)
|
||||
tmp.close()
|
||||
|
||||
api_key = (getattr(settings, "OPENAI_API_KEY", "") or "").strip()
|
||||
base_url = (
|
||||
getattr(settings, "OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
or "https://api.openai.com/v1"
|
||||
).strip()
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=503, detail="ASR 服务未配置 (OPENAI_API_KEY)")
|
||||
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient(timeout=120) as client:
|
||||
with open(tmp.name, "rb") as f:
|
||||
files_payload = {
|
||||
"file": (Path(tmp.name).name, f, f"audio/{ext}"),
|
||||
"model": (None, "whisper-1"),
|
||||
"language": (None, language),
|
||||
}
|
||||
resp = await client.post(
|
||||
f"{base_url.rstrip('/')}/audio/transcriptions",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
files=files_payload,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
logger.error("Whisper API 错误 %d: %s", resp.status_code, resp.text[:500])
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"语音识别服务返回错误 (HTTP {resp.status_code})",
|
||||
)
|
||||
|
||||
data = resp.json()
|
||||
text = data.get("text", "")
|
||||
logger.info("ASR 完成: user=%s file=%s len=%d", current_user.id, file.filename, len(text))
|
||||
|
||||
return AsrResponse(text=text, language=language)
|
||||
|
||||
finally:
|
||||
Path(tmp.name).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@router.post("/tts", response_model=TtsResponse)
|
||||
async def text_to_voice(
|
||||
req: TtsRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
文字转语音 (TTS)。
|
||||
|
||||
优先使用 OpenAI TTS(需配置有效 OPENAI_API_KEY),否则使用免费 Edge TTS。
|
||||
Android 端使用 ExoPlayer 播放返回的 audio_url。
|
||||
"""
|
||||
valid_voices = {"alloy", "echo", "fable", "onyx", "nova", "shimmer"}
|
||||
voice = req.voice if req.voice in valid_voices else "alloy"
|
||||
|
||||
text = req.text.strip()
|
||||
if len(text) > 4000:
|
||||
text = text[:4000] + "..."
|
||||
|
||||
_TTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
filename = f"tts_{current_user.id}_{int(time.time())}.mp3"
|
||||
filepath = _TTS_DIR / filename
|
||||
|
||||
# 尝试 OpenAI TTS
|
||||
api_key = (getattr(settings, "OPENAI_API_KEY", "") or "").strip()
|
||||
base_url = (
|
||||
getattr(settings, "OPENAI_BASE_URL", "https://api.openai.com/v1")
|
||||
or "https://api.openai.com/v1"
|
||||
).strip()
|
||||
|
||||
use_edge = False
|
||||
if api_key and api_key not in ("your-openai-api-key", "sk-your-"):
|
||||
import httpx
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
resp = await client.post(
|
||||
f"{base_url.rstrip('/')}/audio/speech",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"model": "tts-1", "voice": voice, "input": text},
|
||||
)
|
||||
|
||||
if resp.status_code == 200:
|
||||
filepath.write_bytes(resp.content)
|
||||
logger.info("OpenAI TTS 完成: user=%s text_len=%d voice=%s", current_user.id, len(req.text), voice)
|
||||
return TtsResponse(
|
||||
audio_url=f"/api/v1/voice/audio/{filename}",
|
||||
text_length=len(req.text),
|
||||
voice=voice,
|
||||
)
|
||||
else:
|
||||
logger.warning("OpenAI TTS 失败 (%d), 回退到 Edge TTS", resp.status_code)
|
||||
use_edge = True
|
||||
except Exception as exc:
|
||||
logger.warning("OpenAI TTS 异常: %s, 回退到 Edge TTS", exc)
|
||||
use_edge = True
|
||||
else:
|
||||
use_edge = True
|
||||
|
||||
# 回退:Edge TTS(免费,无需 API KEY)
|
||||
if use_edge:
|
||||
edge_voice = _EDGE_VOICE_MAP.get(voice, "zh-CN-YunxiNeural")
|
||||
try:
|
||||
await _edge_tts_synthesize(text, edge_voice, str(filepath))
|
||||
logger.info("Edge TTS 完成: user=%s text_len=%d edge_voice=%s", current_user.id, len(req.text), edge_voice)
|
||||
return TtsResponse(
|
||||
audio_url=f"/api/v1/voice/audio/{filename}",
|
||||
text_length=len(req.text),
|
||||
voice=voice,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Edge TTS 失败: %s", exc)
|
||||
raise HTTPException(status_code=502, detail=f"TTS 服务不可用: {exc}")
|
||||
|
||||
|
||||
async def _edge_tts_synthesize(text: str, voice: str, output_path: str) -> None:
|
||||
"""使用 edge-tts 命令行工具合成语音。"""
|
||||
# 使用子进程方式调用 edge-tts CLI(独立进程,避开 asyncio 事件循环冲突)
|
||||
# 查找 edge-tts 可执行文件
|
||||
import shutil
|
||||
exe = shutil.which("edge-tts") or shutil.which("edge-tts", path=(
|
||||
os.environ.get("PATH", "") + os.pathsep +
|
||||
os.path.join(os.path.dirname(sys.executable), "Scripts") + os.pathsep +
|
||||
os.path.join(os.path.dirname(sys.executable), "..", "Scripts")
|
||||
))
|
||||
if not exe:
|
||||
exe = "edge-tts" # fallback, let subprocess try PATH
|
||||
|
||||
logger.debug("edge-tts exe: %s", exe)
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
exe, "--text", text, "--voice", voice, "--write-media", output_path,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
out_str = stdout.decode(errors="replace") if stdout else ""
|
||||
err_str = stderr.decode(errors="replace") if stderr else ""
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"edge-tts CLI 失败 (exit={proc.returncode}): {err_str or out_str}")
|
||||
|
||||
if not Path(output_path).is_file():
|
||||
raise RuntimeError(f"edge-tts 未生成输出文件: {out_str} {err_str}")
|
||||
|
||||
logger.info("edge-tts CLI 成功: %d bytes", Path(output_path).stat().st_size)
|
||||
|
||||
|
||||
@router.get("/audio/{filename}")
|
||||
async def get_tts_audio(filename: str):
|
||||
"""获取 TTS 生成的音频文件(无需认证,URL 包含随机名)。"""
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
raise HTTPException(status_code=400, detail="非法文件名")
|
||||
|
||||
filepath = _TTS_DIR / filename
|
||||
if not filepath.is_file():
|
||||
raise HTTPException(status_code=404, detail="音频文件不存在或已过期")
|
||||
|
||||
return FileResponse(filepath, media_type="audio/mpeg")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
工作流API
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
@@ -142,12 +142,13 @@ async def get_workflows(
|
||||
limit: int = 100,
|
||||
search: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
|
||||
sort_by: Optional[str] = "created_at",
|
||||
sort_order: Optional[str] = "desc",
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取工作流列表(支持搜索、筛选、排序)"""
|
||||
"""获取工作流列表(支持搜索、筛选、排序、工作区筛选)"""
|
||||
# 管理员可以看到所有工作流,普通用户只能看到自己拥有的或有read权限的
|
||||
if current_user.role == "admin":
|
||||
query = db.query(Workflow)
|
||||
@@ -155,10 +156,10 @@ async def get_workflows(
|
||||
# 获取用户拥有或有read权限的工作流
|
||||
from sqlalchemy import or_
|
||||
from app.models.permission import WorkflowPermission
|
||||
|
||||
|
||||
# 用户拥有的工作流
|
||||
owned_workflows = db.query(Workflow.id).filter(Workflow.user_id == current_user.id).subquery()
|
||||
|
||||
|
||||
# 用户有read权限的工作流(通过用户ID或角色)
|
||||
user_permissions = db.query(WorkflowPermission.workflow_id).filter(
|
||||
WorkflowPermission.permission_type == "read",
|
||||
@@ -167,14 +168,18 @@ async def get_workflows(
|
||||
WorkflowPermission.role_id.in_([r.id for r in current_user.roles])
|
||||
)
|
||||
).subquery()
|
||||
|
||||
|
||||
query = db.query(Workflow).filter(
|
||||
or_(
|
||||
Workflow.id.in_(db.query(owned_workflows.c.id)),
|
||||
Workflow.id.in_(db.query(user_permissions.c.workflow_id))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# 工作区筛选
|
||||
if workspace_id:
|
||||
query = query.filter(Workflow.workspace_id == workspace_id)
|
||||
|
||||
# 搜索:按名称或描述搜索
|
||||
if search:
|
||||
search_pattern = f"%{search}%"
|
||||
@@ -211,25 +216,36 @@ async def get_workflows(
|
||||
@router.post("", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_workflow(
|
||||
workflow_data: WorkflowCreate,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建工作流"""
|
||||
# 从 JWT 提取当前工作区 ID
|
||||
from app.core.security import decode_access_token
|
||||
ws_id = None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
payload = decode_access_token(auth_header[7:])
|
||||
if payload:
|
||||
ws_id = payload.get("ws") or None
|
||||
|
||||
# 验证工作流
|
||||
validation_result = validate_workflow(workflow_data.nodes, workflow_data.edges)
|
||||
if not validation_result["valid"]:
|
||||
raise ValidationError(f"工作流验证失败: {', '.join(validation_result['errors'])}")
|
||||
|
||||
|
||||
# 如果有警告,记录日志
|
||||
if validation_result["warnings"]:
|
||||
logger.warning(f"工作流创建警告: {', '.join(validation_result['warnings'])}")
|
||||
|
||||
|
||||
workflow = Workflow(
|
||||
name=workflow_data.name,
|
||||
description=workflow_data.description,
|
||||
nodes=workflow_data.nodes,
|
||||
edges=workflow_data.edges,
|
||||
user_id=current_user.id
|
||||
user_id=current_user.id,
|
||||
workspace_id=ws_id,
|
||||
)
|
||||
db.add(workflow)
|
||||
db.commit()
|
||||
|
||||
393
backend/app/api/workspaces.py
Normal file
393
backend/app/api/workspaces.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
工作区 (Workspace) API — 多租户管理
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.api.auth import get_current_user
|
||||
from app.models.user import User
|
||||
from app.models.workspace import Workspace, WorkspaceMembership
|
||||
from app.services.workspace_service import check_workspace_access, get_user_workspaces
|
||||
from app.core.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/v1/workspaces",
|
||||
tags=["workspaces"],
|
||||
responses={
|
||||
401: {"description": "未授权"},
|
||||
403: {"description": "无权访问"},
|
||||
404: {"description": "资源不存在"},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic Schemas ──
|
||||
|
||||
class WorkspaceCreate(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=100, description="工作区名称")
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
max_members: int = Field(default=50, ge=1, le=500)
|
||||
|
||||
|
||||
class WorkspaceUpdate(BaseModel):
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=1000)
|
||||
max_members: Optional[int] = Field(None, ge=1, le=500)
|
||||
status: Optional[str] = None # active/disabled
|
||||
|
||||
|
||||
class MemberAddRequest(BaseModel):
|
||||
user_id: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
role: str = Field(default="member", pattern="^(admin|member)$")
|
||||
|
||||
|
||||
class MemberUpdateRequest(BaseModel):
|
||||
role: str = Field(..., pattern="^(admin|member)$")
|
||||
|
||||
|
||||
class WorkspaceResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
is_default: bool
|
||||
owner_id: str
|
||||
max_members: int
|
||||
settings: Optional[Dict[str, Any]]
|
||||
member_count: int = 0
|
||||
status: str
|
||||
created_at: Optional[str]
|
||||
updated_at: Optional[str]
|
||||
|
||||
|
||||
class MemberResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
username: str
|
||||
email: str
|
||||
role: str
|
||||
joined_at: Optional[str]
|
||||
|
||||
|
||||
# ── Endpoints ──
|
||||
|
||||
@router.get("", response_model=List[Dict[str, Any]])
|
||||
def list_workspaces(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取当前用户的工作区列表。"""
|
||||
return get_user_workspaces(db, current_user)
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED)
|
||||
def create_workspace(
|
||||
data: WorkspaceCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""创建新工作区。"""
|
||||
ws = Workspace(
|
||||
id=str(uuid.uuid4()),
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
owner_id=current_user.id,
|
||||
max_members=data.max_members,
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(ws)
|
||||
db.flush()
|
||||
|
||||
# 创建者自动成为工作区管理员
|
||||
membership = WorkspaceMembership(
|
||||
id=str(uuid.uuid4()),
|
||||
workspace_id=ws.id,
|
||||
user_id=current_user.id,
|
||||
role="admin",
|
||||
joined_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(membership)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"id": ws.id,
|
||||
"name": ws.name,
|
||||
"description": ws.description,
|
||||
"is_default": bool(ws.is_default),
|
||||
"owner_id": ws.owner_id,
|
||||
"max_members": ws.max_members,
|
||||
"role": "admin",
|
||||
"member_count": 1,
|
||||
"status": ws.status,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{workspace_id}")
|
||||
def get_workspace(
|
||||
workspace_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取工作区详情。"""
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if not ws:
|
||||
raise NotFoundError("工作区", workspace_id)
|
||||
|
||||
if not check_workspace_access(db, current_user, workspace_id):
|
||||
raise HTTPException(status_code=403, detail="无权访问此工作区")
|
||||
|
||||
member_count = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(WorkspaceMembership.workspace_id == workspace_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
user_role = "admin" if current_user.role == "admin" else None
|
||||
if not user_role:
|
||||
membership = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(
|
||||
WorkspaceMembership.workspace_id == workspace_id,
|
||||
WorkspaceMembership.user_id == current_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if membership:
|
||||
user_role = membership.role
|
||||
|
||||
return {
|
||||
**ws.to_dict(),
|
||||
"member_count": member_count,
|
||||
"role": user_role,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{workspace_id}")
|
||||
def update_workspace(
|
||||
workspace_id: str,
|
||||
data: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""更新工作区设置(需工作区管理员权限)。"""
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if not ws:
|
||||
raise NotFoundError("工作区", workspace_id)
|
||||
|
||||
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
|
||||
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
|
||||
|
||||
if data.name is not None:
|
||||
ws.name = data.name
|
||||
if data.description is not None:
|
||||
ws.description = data.description
|
||||
if data.max_members is not None:
|
||||
ws.max_members = data.max_members
|
||||
if data.status is not None:
|
||||
ws.status = data.status
|
||||
ws.updated_at = datetime.utcnow()
|
||||
|
||||
db.commit()
|
||||
return {**ws.to_dict()}
|
||||
|
||||
|
||||
@router.delete("/{workspace_id}")
|
||||
def delete_workspace(
|
||||
workspace_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""删除工作区(软删除,仅平台管理员或工作区管理员可操作)。"""
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if not ws:
|
||||
raise NotFoundError("工作区", workspace_id)
|
||||
|
||||
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
|
||||
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
|
||||
|
||||
if ws.is_default and current_user.role != "admin":
|
||||
raise HTTPException(status_code=403, detail="默认工作区不可删除")
|
||||
|
||||
ws.status = "deleted"
|
||||
ws.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
return {"message": "工作区已删除"}
|
||||
|
||||
|
||||
# ── 成员管理 ──
|
||||
|
||||
@router.get("/{workspace_id}/members")
|
||||
def list_members(
|
||||
workspace_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取工作区成员列表。"""
|
||||
if not check_workspace_access(db, current_user, workspace_id):
|
||||
raise HTTPException(status_code=403, detail="无权访问此工作区")
|
||||
|
||||
memberships = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(WorkspaceMembership.workspace_id == workspace_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
result = []
|
||||
for m in memberships:
|
||||
user = m.user
|
||||
result.append({
|
||||
"id": m.id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"role": m.role,
|
||||
"joined_at": m.joined_at.isoformat() if m.joined_at else None,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/{workspace_id}/members", status_code=status.HTTP_201_CREATED)
|
||||
def add_member(
|
||||
workspace_id: str,
|
||||
data: MemberAddRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""添加工作区成员(需工作区管理员权限)。"""
|
||||
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
|
||||
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
|
||||
|
||||
# 查找目标用户
|
||||
target_user = None
|
||||
if data.user_id:
|
||||
target_user = db.query(User).filter(User.id == data.user_id).first()
|
||||
elif data.username:
|
||||
target_user = db.query(User).filter(User.username == data.username).first()
|
||||
|
||||
if not target_user:
|
||||
raise NotFoundError("用户")
|
||||
|
||||
# 检查是否已存在
|
||||
existing = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(
|
||||
WorkspaceMembership.workspace_id == workspace_id,
|
||||
WorkspaceMembership.user_id == target_user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该用户已是工作区成员")
|
||||
|
||||
# 检查成员数上限
|
||||
member_count = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(WorkspaceMembership.workspace_id == workspace_id)
|
||||
.count()
|
||||
)
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if member_count >= ws.max_members:
|
||||
raise HTTPException(status_code=400, detail="工作区成员数已达上限")
|
||||
|
||||
membership = WorkspaceMembership(
|
||||
id=str(uuid.uuid4()),
|
||||
workspace_id=workspace_id,
|
||||
user_id=target_user.id,
|
||||
role=data.role,
|
||||
joined_at=datetime.utcnow(),
|
||||
)
|
||||
db.add(membership)
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"id": membership.id,
|
||||
"user_id": target_user.id,
|
||||
"username": target_user.username,
|
||||
"email": target_user.email,
|
||||
"role": membership.role,
|
||||
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/{workspace_id}/members/{user_id}")
|
||||
def update_member_role(
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
data: MemberUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""修改成员角色(需工作区管理员权限)。"""
|
||||
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
|
||||
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
|
||||
|
||||
membership = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(
|
||||
WorkspaceMembership.workspace_id == workspace_id,
|
||||
WorkspaceMembership.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not membership:
|
||||
raise NotFoundError("成员", user_id)
|
||||
|
||||
# 不能修改自己的工作区所有者
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if ws.owner_id == user_id and data.role != "admin":
|
||||
raise HTTPException(status_code=400, detail="工作区所有者必须保持admin角色")
|
||||
|
||||
membership.role = data.role
|
||||
db.commit()
|
||||
|
||||
return {"message": "角色已更新"}
|
||||
|
||||
|
||||
@router.delete("/{workspace_id}/members/{user_id}")
|
||||
def remove_member(
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""移除工作区成员(需工作区管理员权限,或自行退出)。"""
|
||||
is_self = user_id == current_user.id
|
||||
is_admin = check_workspace_access(db, current_user, workspace_id, required_role="admin")
|
||||
|
||||
if not is_self and not is_admin:
|
||||
raise HTTPException(status_code=403, detail="无权移除此成员")
|
||||
|
||||
# 不能移除工作区所有者
|
||||
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||
if not ws:
|
||||
raise NotFoundError("工作区", workspace_id)
|
||||
if ws.owner_id == user_id and not is_self:
|
||||
raise HTTPException(status_code=400, detail="不能移除工作区所有者")
|
||||
|
||||
membership = (
|
||||
db.query(WorkspaceMembership)
|
||||
.filter(
|
||||
WorkspaceMembership.workspace_id == workspace_id,
|
||||
WorkspaceMembership.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not membership:
|
||||
raise NotFoundError("成员", user_id)
|
||||
|
||||
db.delete(membership)
|
||||
db.commit()
|
||||
|
||||
return {"message": "成员已移除"}
|
||||
Reference in New Issue
Block a user