fix: delete agent 500 error + dynamic personality + deployment guide

- Fix delete agent 500: clean up FK records (agent_llm_logs, permissions,
  schedules, executions, team_members) and unbind goals/tasks before delete
- Remove hardcoded personality templates in Android, replace with dynamic
  system prompt generation from name + description
- Set promptSectionsEnabled=false to bypass PromptComposer for personality
- Add Tencent Cloud Linux deployment guide (Docker Compose)
- Accumulated backend service updates, frontend UI fixes, Android app changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-06-29 01:17:21 +08:00
parent 86b98865e3
commit beff3fac8d
1084 changed files with 117315 additions and 1281 deletions

View File

@@ -0,0 +1,319 @@
"""
对话分支 API — 支持从任意历史点分叉对话
参考 Claude Code src/commands/branch/branch.ts
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.agent import Agent
from app.agent_runtime import (
AgentRuntime,
AgentConfig,
AgentLLMConfig,
AgentToolConfig,
AgentMemoryConfig,
AgentStep,
)
from app.agent_runtime.context import AgentContext
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/agent-chat", tags=["agent-branches"])
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
if not nodes:
return {}
for node in nodes:
typ = node.get("type", "")
if typ in ("agent", "llm", "template"):
return node.get("data") or {}
return {}
def _build_memory_config_from_node(agent_node_cfg: dict) -> AgentMemoryConfig:
from app.core.compaction_config import CompactionConfig
compaction_raw = agent_node_cfg.get("compaction")
if isinstance(compaction_raw, dict):
compaction = CompactionConfig(**compaction_raw)
else:
compaction = CompactionConfig()
return AgentMemoryConfig(
max_history_messages=int(agent_node_cfg.get("memory_max_history", 20)),
vector_memory_top_k=int(agent_node_cfg.get("memory_vector_top_k", 5)),
persist_to_db=bool(agent_node_cfg.get("memory_persist", True)),
vector_memory_enabled=bool(agent_node_cfg.get("memory_vector_enabled", True)),
learning_enabled=bool(agent_node_cfg.get("memory_learning", True)),
memory_dir_enabled=bool(agent_node_cfg.get("memory_dir_enabled", False)),
memory_dir_path=str(agent_node_cfg.get("memory_dir_path", "")),
compaction=compaction,
)
# ──────────────────────────── 请求/响应模型 ────────────────────────────
class BranchCreateRequest(BaseModel):
session_id: str = Field(..., description="要分叉的原会话 ID")
title: Optional[str] = Field(default=None, description="自定义分支标题")
agent_id: Optional[str] = Field(default=None, description="关联 Agent ID")
class BranchItem(BaseModel):
id: str
title: str
agent_name: Optional[str] = None
parent_session_id: str
message_count: int
first_user_message: Optional[str] = None
created_at: str
class Config:
from_attributes = True
class BranchListResponse(BaseModel):
branches: List[BranchItem]
total: int
class BranchDetailResponse(BaseModel):
id: str
title: str
agent_name: Optional[str] = None
parent_session_id: str
branch_session_id: str
message_count: int
first_user_message: Optional[str] = None
messages: Optional[List[Dict[str, Any]]] = None
created_at: str
class Config:
from_attributes = True
class BranchResumeRequest(BaseModel):
message: str = Field(..., description="在分支基础上继续的新消息")
class BranchResumeResponse(BaseModel):
content: str
iterations_used: int
tool_calls_made: int
truncated: bool
session_id: str
agent_id: Optional[str] = None
steps: List[AgentStep] = Field(default_factory=list)
# ──────────────────────────── API 端点 ────────────────────────────
@router.post("/branches", response_model=BranchDetailResponse)
async def create_conversation_branch(
req: BranchCreateRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""从指定会话创建对话分支(快照完整消息列表)。"""
from app.services.conversation_branch_service import (
create_branch,
get_messages_from_session,
)
messages = get_messages_from_session(db, req.session_id)
if not messages:
raise HTTPException(status_code=404, detail="找不到该会话的消息记录(请确保会话已产生对话)")
agent_name = None
if req.agent_id:
agent = db.query(Agent).filter(Agent.id == req.agent_id).first()
if agent:
agent_name = agent.name
branch = create_branch(
db=db,
user_id=current_user.id,
parent_session_id=req.session_id,
messages=messages,
agent_id=req.agent_id,
agent_name=agent_name,
custom_title=req.title,
)
return BranchDetailResponse(
id=branch.id,
title=branch.title,
agent_name=branch.agent_name,
parent_session_id=branch.parent_session_id,
branch_session_id=branch.branch_session_id,
message_count=branch.message_count,
first_user_message=branch.first_user_message,
messages=branch.messages,
created_at=branch.created_at.isoformat() if branch.created_at else "",
)
@router.get("/branches", response_model=BranchListResponse)
async def list_conversation_branches(
agent_id: Optional[str] = None,
limit: int = 50,
offset: int = 0,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""列出当前用户的所有对话分支。"""
from app.services.conversation_branch_service import list_branches
branches = list_branches(
db=db, user_id=current_user.id, agent_id=agent_id, limit=limit, offset=offset,
)
return BranchListResponse(
branches=[
BranchItem(
id=b.id, title=b.title, agent_name=b.agent_name,
parent_session_id=b.parent_session_id,
message_count=b.message_count,
first_user_message=b.first_user_message,
created_at=b.created_at.isoformat() if b.created_at else "",
)
for b in branches
],
total=len(branches),
)
@router.get("/branches/{branch_id}", response_model=BranchDetailResponse)
async def get_conversation_branch(
branch_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取分支详情(含完整消息列表)。"""
from app.services.conversation_branch_service import get_branch
branch = get_branch(db, branch_id, current_user.id)
if not branch:
raise HTTPException(status_code=404, detail="分支不存在")
return BranchDetailResponse(
id=branch.id, title=branch.title, agent_name=branch.agent_name,
parent_session_id=branch.parent_session_id,
branch_session_id=branch.branch_session_id,
message_count=branch.message_count,
first_user_message=branch.first_user_message,
messages=branch.messages,
created_at=branch.created_at.isoformat() if branch.created_at else "",
)
@router.delete("/branches/{branch_id}")
async def delete_conversation_branch(
branch_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""删除(软删除)对话分支。"""
from app.services.conversation_branch_service import delete_branch
success = delete_branch(db, branch_id, current_user.id)
if not success:
raise HTTPException(status_code=404, detail="分支不存在或无权操作")
return {"message": "分支已删除", "branch_id": branch_id}
@router.post("/branches/{branch_id}/resume", response_model=BranchResumeResponse)
async def resume_from_branch(
branch_id: str,
req: BranchResumeRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""从分支恢复对话 — 使用分支保存的消息作为上下文继续对话。"""
from app.services.conversation_branch_service import get_branch
branch = get_branch(db, branch_id, current_user.id)
if not branch:
raise HTTPException(status_code=404, detail="分支不存在")
if not branch.messages:
raise HTTPException(status_code=400, detail="分支没有保存的消息")
messages = branch.messages
system_prompt = "你是一个有用的AI助手。"
history_messages = messages
if messages and messages[0].get("role") == "system":
system_prompt = messages[0].get("content", system_prompt)
history_messages = messages[1:]
# 构建 Agent 配置
agent = None
if branch.agent_id:
agent = db.query(Agent).filter(Agent.id == branch.agent_id).first()
if agent:
wc = agent.workflow_config or {}
nodes = wc.get("nodes", [])
agent_node_cfg = _find_agent_node_config(nodes)
llm_config = AgentLLMConfig(
provider=agent_node_cfg.get("provider", "deepseek"),
model=agent_node_cfg.get("model", "deepseek-v4-flash"),
temperature=float(agent_node_cfg.get("temperature", 0.7)),
max_iterations=int(agent_node_cfg.get("max_iterations", 10)),
)
config = AgentConfig(
name=agent.name, system_prompt=system_prompt, llm=llm_config,
tools=AgentToolConfig(include_tools=agent_node_cfg.get("tools", [])),
memory=_build_memory_config_from_node(agent_node_cfg),
user_id=current_user.id,
memory_scope_id=f"{current_user.id}:{branch.agent_id}" if current_user.id else str(branch.agent_id),
)
else:
config = AgentConfig(
name="branch_resume", system_prompt=system_prompt,
llm=AgentLLMConfig(model="deepseek-v4-flash", temperature=0.7, max_iterations=10),
user_id=current_user.id,
memory_scope_id=f"{current_user.id}:__bare__" if current_user.id else "__bare__",
)
# 创建 Runtime 并预加载分支消息
context = AgentContext(
system_prompt=system_prompt,
user_id=current_user.id,
session_id=branch.branch_session_id,
)
for msg in history_messages:
role = msg.get("role", "")
content = msg.get("content", "")
if role == "user":
context.add_user_message(content)
elif role == "assistant":
tool_calls = msg.get("tool_calls")
context.add_assistant_message(content, tool_calls)
elif role == "tool":
context.add_tool_result(
msg.get("tool_call_id", ""),
msg.get("name", "unknown"),
content,
)
# 不需要 on_llm_logger 回调(避免与主 logger 重复),仅做执行
runtime = AgentRuntime(config=config, context=context)
result = await runtime.run(req.message)
return BranchResumeResponse(
content=result.content,
iterations_used=result.iterations_used,
tool_calls_made=result.tool_calls_made,
truncated=result.truncated,
session_id=branch.branch_session_id,
agent_id=branch.agent_id,
steps=result.steps,
)

View File

@@ -82,6 +82,9 @@ class ChatRequest(BaseModel):
model: Optional[str] = None
temperature: Optional[float] = None
max_iterations: Optional[int] = None
streamlined: bool = Field(default=False, description="启用工具结果流式美化")
prompt_sections_enabled: bool = Field(default=True, description="启用系统提示词分层装配")
system_prompt_override: Optional[str] = Field(default=None, description="覆盖 Agent 的 System Prompt")
class ChatResponse(BaseModel):
@@ -92,6 +95,8 @@ class ChatResponse(BaseModel):
session_id: str
agent_id: Optional[str] = None
steps: List[AgentStep] = Field(default_factory=list, description="执行追踪步骤")
streamlined_summary: Optional[str] = Field(default=None, description="流式美化摘要streamlined 模式)")
token_usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 预算摘要")
class OrchestrateAgentItem(BaseModel):
@@ -249,11 +254,35 @@ async def chat_bare(
),
user_id=uid,
memory_scope_id=bare_scope,
memory=AgentMemoryConfig(
memory_dir_enabled=True,
memory_dir_path="",
persist_to_db=True,
vector_memory_enabled=True,
learning_enabled=True,
),
tools=AgentToolConfig(
permission_level="acceptEdits",
),
)
if not req.prompt_sections_enabled:
config.prompt_sections.enabled = False
if req.system_prompt_override:
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
result = await runtime.run(req.message)
# 流式美化:为 steps 生成累计摘要
streamlined_summary = None
if req.streamlined and result.steps:
from app.core.streamlined_output import ToolCounts, categorize_tool, get_tool_summary_text
counts = ToolCounts()
for s in result.steps:
if s.type == "tool_result" and s.tool_name:
counts.add(categorize_tool(s.tool_name))
streamlined_summary = get_tool_summary_text(counts)
return ChatResponse(
content=result.content,
iterations_used=result.iterations_used,
@@ -261,6 +290,8 @@ async def chat_bare(
truncated=result.truncated,
session_id=runtime.context.session_id,
steps=result.steps,
streamlined_summary=streamlined_summary,
token_usage=result.token_usage.model_dump() if result.token_usage else None,
)
@@ -286,9 +317,23 @@ async def chat_bare_stream(
),
user_id=uid,
memory_scope_id=bare_scope,
memory=AgentMemoryConfig(
memory_dir_enabled=True,
memory_dir_path="",
persist_to_db=True,
vector_memory_enabled=True,
learning_enabled=True,
),
tools=AgentToolConfig(
permission_level="acceptEdits",
),
)
if not req.prompt_sections_enabled:
config.prompt_sections.enabled = False
if req.system_prompt_override:
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=None, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
return StreamingResponse(
_sse_stream(runtime.run_stream(req.message)),
media_type="text/event-stream",
@@ -339,6 +384,8 @@ async def chat_with_agent(
uid = current_user.id
mem_scope = f"{uid}:{agent_id}" if uid else str(agent_id)
memory_cfg = _build_memory_config_from_node(agent_node_cfg)
if getattr(agent, "parent_agent_id", None):
memory_cfg.parent_agent_id = agent.parent_agent_id
config = AgentConfig(
name=agent.name,
system_prompt=system_prompt,
@@ -347,21 +394,42 @@ async def chat_with_agent(
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
# 计划模式 (P2)
plan_mode_enabled=bool(agent_node_cfg.get("plan_mode_enabled", False)),
plan_approval_required=bool(agent_node_cfg.get("plan_approval_required", True)),
),
tools=AgentToolConfig(
include_tools=agent_node_cfg.get("tools", []),
exclude_tools=agent_node_cfg.get("exclude_tools", []),
# 工具安全分级 (P3)
permission_level=str(agent_node_cfg.get("permission_level", "default")),
deny_tools=agent_node_cfg.get("deny_tools", []),
auto_approve_rules=agent_node_cfg.get("auto_approve_rules", []),
),
memory=memory_cfg,
budget=budget,
user_id=uid,
memory_scope_id=mem_scope,
)
if not req.prompt_sections_enabled:
config.prompt_sections.enabled = False
if req.system_prompt_override:
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
result = await runtime.run(req.message)
# 流式美化:为 steps 生成累计摘要
streamlined_summary = None
if req.streamlined and result.steps:
from app.core.streamlined_output import ToolCounts, categorize_tool, get_tool_summary_text
counts = ToolCounts()
for s in result.steps:
if s.type == "tool_result" and s.tool_name:
counts.add(categorize_tool(s.tool_name))
streamlined_summary = get_tool_summary_text(counts)
return ChatResponse(
content=result.content,
iterations_used=result.iterations_used,
@@ -370,6 +438,8 @@ async def chat_with_agent(
session_id=runtime.context.session_id,
agent_id=agent_id,
steps=result.steps,
streamlined_summary=streamlined_summary,
token_usage=result.token_usage.model_dump() if result.token_usage else None,
)
@@ -408,6 +478,8 @@ async def chat_with_agent_stream(
uid = current_user.id
mem_scope = f"{uid}:{agent_id}" if uid else str(agent_id)
memory_cfg = _build_memory_config_from_node(agent_node_cfg)
if getattr(agent, "parent_agent_id", None):
memory_cfg.parent_agent_id = agent.parent_agent_id
config = AgentConfig(
name=agent.name,
system_prompt=system_prompt,
@@ -416,19 +488,30 @@ async def chat_with_agent_stream(
model=req.model or agent_node_cfg.get("model", "gpt-4o-mini"),
temperature=req.temperature or float(agent_node_cfg.get("temperature", 0.7)),
max_iterations=req.max_iterations or int(agent_node_cfg.get("max_iterations", 10)),
# 计划模式 (P2)
plan_mode_enabled=bool(agent_node_cfg.get("plan_mode_enabled", False)),
plan_approval_required=bool(agent_node_cfg.get("plan_approval_required", True)),
),
tools=AgentToolConfig(
include_tools=agent_node_cfg.get("tools", []),
exclude_tools=agent_node_cfg.get("exclude_tools", []),
# 工具安全分级 (P3)
permission_level=str(agent_node_cfg.get("permission_level", "default")),
deny_tools=agent_node_cfg.get("deny_tools", []),
auto_approve_rules=agent_node_cfg.get("auto_approve_rules", []),
),
memory=memory_cfg,
budget=budget,
user_id=uid,
memory_scope_id=mem_scope,
)
if not req.prompt_sections_enabled:
config.prompt_sections.enabled = False
if req.system_prompt_override:
config.system_prompt = req.system_prompt_override
on_llm_call = _make_llm_logger(db, agent_id=agent_id, user_id=current_user.id)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call)
runtime = AgentRuntime(config=config, on_llm_call=on_llm_call, streamlined=req.streamlined)
return StreamingResponse(
_sse_stream(runtime.run_stream(req.message)),
media_type="text/event-stream",
@@ -453,10 +536,24 @@ def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
def _build_memory_config_from_node(agent_node_cfg: dict) -> AgentMemoryConfig:
"""从 Agent 工作流节点配置中提取记忆配置。"""
from app.core.compaction_config import CompactionConfig
# 压缩配置
compaction_raw = agent_node_cfg.get("compaction")
if isinstance(compaction_raw, dict):
compaction = CompactionConfig(**compaction_raw)
else:
compaction = CompactionConfig() # 默认配置
return AgentMemoryConfig(
max_history_messages=int(agent_node_cfg.get("memory_max_history", 20)),
vector_memory_top_k=int(agent_node_cfg.get("memory_vector_top_k", 5)),
persist_to_db=bool(agent_node_cfg.get("memory_persist", True)),
vector_memory_enabled=bool(agent_node_cfg.get("memory_vector_enabled", True)),
learning_enabled=bool(agent_node_cfg.get("memory_learning", True)),
# 文件式记忆 (MEMORY.md)
memory_dir_enabled=bool(agent_node_cfg.get("memory_dir_enabled", False)),
memory_dir_path=str(agent_node_cfg.get("memory_dir_path", "")),
# 对话压缩
compaction=compaction,
)

View File

@@ -45,7 +45,7 @@ class AgentMarketItem(BaseModel):
use_count: int = 0
view_count: int = 0
version: int = 1
user_id: str
user_id: Optional[str] = None
creator_username: Optional[str] = None
is_favorited: Optional[bool] = None
user_rating: Optional[int] = None
@@ -70,7 +70,7 @@ class AgentMarketDetail(BaseModel):
use_count: int = 0
view_count: int = 0
version: int = 1
user_id: str
user_id: Optional[str] = None
creator_username: Optional[str] = None
is_favorited: Optional[bool] = None
user_rating: Optional[int] = None
@@ -100,7 +100,7 @@ class RatingResponse(BaseModel):
"""评分响应"""
id: str
agent_id: str
user_id: str
user_id: Optional[str] = None
username: Optional[str] = None
rating: int
comment: Optional[str] = None
@@ -137,7 +137,7 @@ def _build_agent_item(agent: Agent, current_user: Optional[User], db: Session) -
"use_count": agent.use_count or 0,
"view_count": agent.view_count or 0,
"version": agent.version or 1,
"user_id": agent.user_id,
"user_id": agent.user_id or "",
"creator_username": agent.user.username if agent.user else None,
"is_favorited": None,
"user_rating": None,

View File

@@ -0,0 +1,280 @@
"""
Agent 蜂群 API — Leader/Teammate 并行协作
POST /api/v1/swarm/run
{"message": "帮我做三件事: ...", "mode": "parallel", "max_teammates": 5}
→ Leader 分解任务 → Teammates 并行执行 → Leader 汇总
参考 Claude Code src/tools/AgentTool/ + forkSubagent.ts
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from app.core.database import get_db
from sqlalchemy.orm import Session
from app.api.auth import get_current_user
from app.models.user import User
from app.models.agent import Agent
from app.agent_runtime.swarm import (
SwarmRuntime,
SwarmConfig,
SwarmMode,
SwarmResult,
SwarmTask,
create_swarm,
)
from app.agent_runtime.schemas import AgentConfig, AgentLLMConfig, AgentToolConfig
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/swarm", tags=["swarm"])
# ──────────────────────────── 请求/响应模型 ────────────────────────────
class SwarmRunRequest(BaseModel):
message: str = Field(..., description="用户输入")
mode: str = Field(default="parallel", description="蜂群模式: parallel | pipeline | debate")
max_teammates: int = Field(default=5, ge=1, le=20)
leader_model: Optional[str] = Field(default=None, description="Leader 模型")
teammate_model: Optional[str] = Field(default=None, description="Teammate 模型")
mailbox_enabled: bool = Field(default=True, description="启用 Agent 间消息传递")
agent_ids: Optional[List[str]] = Field(default=None, description="指定的 Agent ID 列表(作为 Teammates")
retry_failed: bool = Field(default=True, description="失败任务是否重试")
class SwarmTaskItem(BaseModel):
id: str
description: str
assigned_agent_id: Optional[str] = None
status: str
result: Optional[str] = None
error: Optional[str] = None
iterations_used: int = 0
tool_calls_made: int = 0
duration_ms: int = 0
class SwarmTeammateItem(BaseModel):
agent_id: str
agent_name: str
task_id: str
success: bool
output: str
duration_ms: int
iterations_used: int
tool_calls_made: int
error: Optional[str] = None
class MailboxMessageItem(BaseModel):
id: str
from_: str = Field(alias="from")
to: str
content: str
timestamp: float
class SwarmRunResponse(BaseModel):
success: bool
final_answer: str
mode: str
tasks: List[SwarmTaskItem] = Field(default_factory=list)
teammate_results: List[SwarmTeammateItem] = Field(default_factory=list)
mailbox_messages: List[Dict[str, Any]] = Field(default_factory=list)
total_duration_ms: int = 0
total_iterations: int = 0
total_tool_calls: int = 0
error: Optional[str] = None
# ──────────────────────────── 端点 ────────────────────────────
@router.post("/run", response_model=SwarmRunResponse)
async def swarm_run(
req: SwarmRunRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""运行 Agent 蜂群 — Leader 分解任务 → Teammates 并行执行 → 汇总。
支持三种模式:
- parallel: 所有子任务并发执行(无依赖)
- pipeline: 按依赖顺序执行
- debate: 多个 Agent 独立回答后汇总
Teammates 来源(优先级):
1. agent_ids 参数指定 → 从数据库加载 Agent 配置
2. 自动生成 → 使用 teammate_model 创建轻量 Teammate
"""
uid = current_user.id
# 解析模式
mode = SwarmMode.PARALLEL
if req.mode == "pipeline":
mode = SwarmMode.PIPELINE
elif req.mode == "debate":
mode = SwarmMode.DEBATE
# 构建 SwarmConfig
config = SwarmConfig(
mode=mode,
max_teammates=req.max_teammates,
leader_model=req.leader_model or "deepseek-v4-pro",
teammate_model=req.teammate_model or "deepseek-v4-flash",
mailbox_enabled=req.mailbox_enabled,
retry_failed=req.retry_failed,
)
# 加载指定的 Agent 作为 Teammates
teammate_configs: List[AgentConfig] = []
if req.agent_ids:
for aid in req.agent_ids:
agent = db.query(Agent).filter(Agent.id == aid).first()
if agent:
wc = agent.workflow_config or {}
nodes = wc.get("nodes", [])
agent_node_cfg = {}
for node in nodes:
if node.get("type") in ("agent", "llm", "template"):
agent_node_cfg = node.get("data") or {}
break
teammate_configs.append(AgentConfig(
name=agent.name,
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
llm=AgentLLMConfig(
model=agent_node_cfg.get("model", req.teammate_model or "deepseek-v4-flash"),
provider=agent_node_cfg.get("provider", "deepseek"),
temperature=float(agent_node_cfg.get("temperature", 0.7)),
max_iterations=int(agent_node_cfg.get("max_iterations", 10)),
),
tools=AgentToolConfig(
include_tools=agent_node_cfg.get("tools", []),
),
user_id=uid,
))
# 构建 Leader 配置
leader_config = AgentConfig(
name="SwarmLeader",
system_prompt="你是一个AI任务协调者。将复杂问题分解为子任务协调多个AI Agent并行处理并汇总结果。",
llm=AgentLLMConfig(model=config.leader_model, temperature=0.3, max_iterations=10),
user_id=uid,
)
# 创建并运行 Swarm
swarm = SwarmRuntime(
config=config,
leader_config=leader_config,
teammate_configs=teammate_configs,
)
result = await swarm.run(req.message)
return SwarmRunResponse(
success=result.success,
final_answer=result.final_answer,
mode=result.mode.value,
tasks=[
SwarmTaskItem(
id=t.id, description=t.description,
assigned_agent_id=t.assigned_agent_id,
status=t.status.value, result=t.result[:500] if t.result else None,
error=t.error, iterations_used=t.iterations_used,
tool_calls_made=t.tool_calls_made, duration_ms=t.duration_ms,
)
for t in result.tasks
],
teammate_results=[
SwarmTeammateItem(
agent_id=tr["agent_id"], agent_name=tr["agent_name"],
task_id=tr["task_id"], success=tr["success"],
output=tr["output"][:500], duration_ms=tr["duration_ms"],
iterations_used=tr["iterations_used"], tool_calls_made=tr["tool_calls_made"],
error=tr.get("error"),
)
for tr in result.teammate_results
],
mailbox_messages=result.mailbox_messages,
total_duration_ms=result.total_duration_ms,
total_iterations=result.total_iterations,
total_tool_calls=result.total_tool_calls,
error=result.error,
)
@router.post("/run/stream")
async def swarm_run_stream(
req: SwarmRunRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""运行 Agent 蜂群(流式 SSE— 实时推送任务分解、执行进度、汇总结果。"""
import json as _json
async def _stream():
uid = current_user.id
mode = {"parallel": SwarmMode.PARALLEL, "pipeline": SwarmMode.PIPELINE,
"debate": SwarmMode.DEBATE}.get(req.mode, SwarmMode.PARALLEL)
config = SwarmConfig(
mode=mode, max_teammates=req.max_teammates,
leader_model=req.leader_model or "deepseek-v4-pro",
teammate_model=req.teammate_model or "deepseek-v4-flash",
mailbox_enabled=req.mailbox_enabled, retry_failed=req.retry_failed,
)
# Load specified agents
teammate_configs = []
if req.agent_ids:
for aid in req.agent_ids:
agent = db.query(Agent).filter(Agent.id == aid).first()
if agent:
wc = agent.workflow_config or {}
agent_node_cfg = {}
for node in wc.get("nodes", []):
if node.get("type") in ("agent", "llm", "template"):
agent_node_cfg = node.get("data") or {}
break
teammate_configs.append(AgentConfig(
name=agent.name,
system_prompt=agent_node_cfg.get("system_prompt") or agent.description or "你是一个有用的AI助手。",
llm=AgentLLMConfig(
model=agent_node_cfg.get("model", req.teammate_model or "deepseek-v4-flash"),
provider=agent_node_cfg.get("provider", "deepseek"),
temperature=float(agent_node_cfg.get("temperature", 0.7)),
max_iterations=int(agent_node_cfg.get("max_iterations", 10)),
),
tools=AgentToolConfig(include_tools=agent_node_cfg.get("tools", [])),
user_id=uid,
))
leader_config = AgentConfig(
name="SwarmLeader",
system_prompt="你是一个AI任务协调者。",
llm=AgentLLMConfig(model=config.leader_model, temperature=0.3, max_iterations=10),
user_id=uid,
)
swarm = SwarmRuntime(config=config, leader_config=leader_config,
teammate_configs=teammate_configs)
yield f"data: {_json.dumps({'type': 'swarm_start', 'mode': req.mode, 'max_teammates': req.max_teammates}, ensure_ascii=False)}\n\n"
result = await swarm.run(req.message)
yield f"data: {_json.dumps({'type': 'swarm_done', 'success': result.success, 'final_answer': result.final_answer, 'total_duration_ms': result.total_duration_ms, 'total_iterations': result.total_iterations, 'total_tool_calls': result.total_tool_calls}, ensure_ascii=False)}\n\n"
return StreamingResponse(
_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)

View File

@@ -1,8 +1,9 @@
"""
Agent管理API
"""
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response
from fastapi import APIRouter, Depends, HTTPException, status, Query, Response, Request
from sqlalchemy.orm import Session
from sqlalchemy import text
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
@@ -61,16 +62,18 @@ class SceneTemplateItem(BaseModel):
category: Optional[str] = None
default_temperature: Optional[float] = None
parameter_hints: List[str] = Field(default_factory=list)
contract_id: Optional[str] = None
class AgentFromSceneTemplateCreate(BaseModel):
"""从场景模板创建 Agent"""
"""从场景模板创建 Agent(支持 DSL 契约参数)"""
template_id: str
name: str
description: Optional[str] = None
parameters: Dict[str, Any] = Field(default_factory=dict)
budget_config: Optional[Dict[str, Any]] = None
contract_id: Optional[str] = None
class PreviewChatTurnResponse(BaseModel):
@@ -96,6 +99,9 @@ class AgentResponse(BaseModel):
version: int
status: str
user_id: Optional[str] # 允许为None
parent_agent_id: Optional[str] = None
agent_type: Optional[str] = None
category: Optional[str] = None
created_at: datetime
updated_at: datetime
@@ -110,13 +116,14 @@ async def get_agents(
limit: int = Query(100, ge=1, le=100, description="每页记录数"),
search: Optional[str] = Query(None, description="搜索关键词(按名称或描述)"),
status: Optional[str] = Query(None, description="状态筛选"),
workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取Agent列表
支持分页、搜索、状态筛选
支持分页、搜索、状态筛选、工作区筛选
"""
# 管理员可以看到所有Agent普通用户只能看到自己拥有的或有read权限的
if current_user.role == "admin":
@@ -125,10 +132,10 @@ async def get_agents(
# 获取用户拥有或有read权限的Agent
from sqlalchemy import or_
from app.models.permission import AgentPermission
# 用户拥有的Agent
owned_agents = db.query(Agent.id).filter(Agent.user_id == current_user.id).subquery()
# 用户有read权限的Agent通过用户ID或角色
user_permissions = db.query(AgentPermission.agent_id).filter(
AgentPermission.permission_type == "read",
@@ -137,14 +144,18 @@ async def get_agents(
AgentPermission.role_id.in_([r.id for r in current_user.roles])
)
).subquery()
query = db.query(Agent).filter(
or_(
Agent.id.in_(db.query(owned_agents.c.id)),
Agent.id.in_(db.query(user_permissions.c.agent_id))
)
)
# 工作区筛选
if workspace_id:
query = query.filter(Agent.workspace_id == workspace_id)
# 搜索:按名称或描述搜索
if search:
search_pattern = f"%{search}%"
@@ -187,26 +198,36 @@ async def get_agents(
@router.post("", response_model=AgentResponse, status_code=status.HTTP_201_CREATED)
async def create_agent(
agent_data: AgentCreate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
创建Agent
创建时会验证工作流配置的有效性
"""
# 从 JWT 提取当前工作区 ID
from app.core.security import decode_access_token
ws_id = None
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
payload = decode_access_token(auth_header[7:])
if payload:
ws_id = payload.get("ws") or None
# 验证工作流配置
if "nodes" not in agent_data.workflow_config or "edges" not in agent_data.workflow_config:
raise ValidationError("工作流配置必须包含nodes和edges")
nodes = agent_data.workflow_config.get("nodes", [])
edges = agent_data.workflow_config.get("edges", [])
# 验证工作流
validation_result = validate_workflow(nodes, edges)
if not validation_result["valid"]:
raise ValidationError(f"工作流配置验证失败: {', '.join(validation_result['errors'])}")
# 检查名称是否重复
existing_agent = db.query(Agent).filter(
Agent.name == agent_data.name,
@@ -214,7 +235,7 @@ async def create_agent(
).first()
if existing_agent:
raise ConflictError(f"Agent名称 '{agent_data.name}' 已存在")
# 创建Agent
agent = Agent(
name=agent_data.name,
@@ -222,6 +243,7 @@ async def create_agent(
workflow_config=agent_data.workflow_config,
budget_config=agent_data.budget_config,
user_id=current_user.id,
workspace_id=ws_id,
status="draft",
category=agent_data.category,
tags=agent_data.tags,
@@ -377,9 +399,33 @@ async def delete_agent(
raise HTTPException(status_code=403, detail="无权删除此Agent")
agent_name = agent.name
# 1. 清理直接关联的记录(级联删除)
related_tables = [
"agent_llm_logs",
"agent_permissions",
"agent_schedules",
"executions",
"team_members",
]
for table in related_tables:
db.execute(text(f"DELETE FROM {table} WHERE agent_id = :aid"), {"aid": agent_id})
logger.debug(f"已清理 {table} 中 agent_id={agent_id} 的记录")
# 2. 解除分配关系(设为 NULL
db.execute(
text("UPDATE goals SET main_agent_id = NULL WHERE main_agent_id = :aid"),
{"aid": agent_id},
)
db.execute(
text("UPDATE tasks SET assigned_agent_id = NULL WHERE assigned_agent_id = :aid"),
{"aid": agent_id},
)
# 3. 删除 Agent 自身
db.delete(agent)
db.commit()
logger.info(f"用户 {current_user.username} 删除了Agent: {agent_name} ({agent_id})")
return {"message": "Agent已删除"}
@@ -773,3 +819,67 @@ async def import_agent(
db.commit()
db.refresh(agent)
return agent
# ——— Agent 间知识共享 API ———
class InheritKnowledgeResponse(BaseModel):
child_agent_id: str
parent_agent_id: str
knowledge_entries_copied: int
global_knowledge_shared: int
learning_patterns_copied: int
message: str
@router.post("/{agent_id}/inherit-knowledge", response_model=InheritKnowledgeResponse)
def inherit_agent_knowledge(
agent_id: str,
parent_agent_id: str = Query(..., description="父 Agent ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""让子 Agent 从父 Agent 继承知识条目、全局知识和学习模式。"""
child = db.query(Agent).filter(Agent.id == agent_id).first()
if not child:
raise NotFoundError("Agent", agent_id)
parent = db.query(Agent).filter(Agent.id == parent_agent_id).first()
if not parent:
raise NotFoundError("父 Agent", parent_agent_id)
# 记录父子关系
child.parent_agent_id = parent_agent_id
db.commit()
try:
from app.services.knowledge_sharing import inherit_knowledge_from_parent
result = inherit_knowledge_from_parent(agent_id, parent_agent_id, db)
result["message"] = f"已从父 Agent 继承 {result['knowledge_entries_copied']} 条知识、" \
f"{result['global_knowledge_shared']} 条全局知识、" \
f"{result['learning_patterns_copied']} 条学习模式"
return InheritKnowledgeResponse(**result)
except Exception as e:
raise HTTPException(status_code=500, detail=f"知识继承失败: {e}")
@router.get("/{agent_id}/parent-knowledge")
def get_parent_knowledge_preview(
agent_id: str,
limit: int = Query(5, ge=1, le=20),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""预览 Agent 可用的父级知识(不实际继承)。"""
agent = db.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
raise NotFoundError("Agent", agent_id)
from app.services.knowledge_sharing import get_parent_knowledge_context
context = get_parent_knowledge_context(agent_id, max_entries=limit, db=db)
return {
"agent_id": agent_id,
"parent_agent_id": getattr(agent, "parent_agent_id", None),
"has_parent_knowledge": bool(context),
"context_preview": context[:2000] if context else "",
}

View File

@@ -0,0 +1,125 @@
"""
操作审计日志 API
"""
import logging
from datetime import datetime
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from pydantic import BaseModel
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.audit_log import AuditLog
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/api/v1/audit-logs",
tags=["audit-logs"],
responses={
401: {"description": "未授权"},
403: {"description": "无权访问"},
500: {"description": "服务器内部错误"}
}
)
# ── Pydantic Schemas ─────────────────────────────────────────────
class AuditLogItem(BaseModel):
id: str
username: str
action: str
resource_type: str
resource_id: Optional[str] = None
resource_name: Optional[str] = None
detail: Optional[dict] = None
ip_address: Optional[str] = None
status: str
created_at: datetime
class Config:
from_attributes = True
class AuditLogStats(BaseModel):
total: int
by_action: dict # {"CREATE": 10, "UPDATE": 5, ...}
by_resource: dict # {"agent": 8, "workflow": 7, ...}
# ── Helpers ──────────────────────────────────────────────────────
def _check_admin(current_user: User):
if getattr(current_user, "role", None) != "admin":
from app.core.exceptions import ForbiddenError
raise ForbiddenError("仅管理员可访问审计日志")
# ── Endpoints ────────────────────────────────────────────────────
@router.get("", response_model=List[AuditLogItem])
async def get_audit_logs(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
user_id: Optional[str] = Query(None, description="按用户ID过滤"),
action: Optional[str] = Query(None, description="操作类型: CREATE/UPDATE/DELETE/EXECUTE/LOGIN"),
resource_type: Optional[str] = Query(None, description="资源类型: agent/workflow/user"),
status: Optional[str] = Query(None, description="操作状态: success/failure"),
start_date: Optional[str] = Query(None, description="开始时间 ISO格式"),
end_date: Optional[str] = Query(None, description="结束时间 ISO格式"),
keyword: Optional[str] = Query(None, description="资源名称关键词搜索"),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=1000),
):
"""查询操作审计日志(仅管理员)"""
_check_admin(current_user)
query = db.query(AuditLog)
if user_id:
query = query.filter(AuditLog.user_id == user_id)
if action:
query = query.filter(AuditLog.action == action.upper())
if resource_type:
query = query.filter(AuditLog.resource_type == resource_type)
if status:
query = query.filter(AuditLog.status == status)
if keyword:
query = query.filter(AuditLog.resource_name.contains(keyword))
if start_date:
query = query.filter(AuditLog.created_at >= datetime.fromisoformat(start_date))
if end_date:
query = query.filter(AuditLog.created_at <= datetime.fromisoformat(end_date))
total = query.count()
logs = query.order_by(desc(AuditLog.created_at)).offset(skip).limit(limit).all()
return logs
@router.get("/stats", response_model=AuditLogStats)
async def get_audit_logs_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取审计日志统计(仅管理员)"""
_check_admin(current_user)
total = db.query(func.count(AuditLog.id)).scalar() or 0
action_rows = db.query(
AuditLog.action, func.count(AuditLog.id)
).group_by(AuditLog.action).all()
by_action = {row[0]: row[1] for row in action_rows}
resource_rows = db.query(
AuditLog.resource_type, func.count(AuditLog.id)
).group_by(AuditLog.resource_type).all()
by_resource = {row[0]: row[1] for row in resource_rows}
return {"total": total, "by_action": by_action, "by_resource": by_resource}

View File

@@ -6,13 +6,15 @@ from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from pydantic import BaseModel, field_validator
import re
import secrets
import logging
from app.core.database import get_db
from app.core.security import verify_password, get_password_hash, create_access_token
from app.models.user import User
from datetime import timedelta
from datetime import datetime, timedelta
from app.core.config import settings
from app.core.exceptions import ConflictError, UnauthorizedError, NotFoundError
from app.core.redis_client import get_redis_client
logger = logging.getLogger(__name__)
@@ -27,6 +29,9 @@ router = APIRouter(
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
oauth2_scheme_optional = OAuth2PasswordBearer(
tokenUrl="/api/v1/auth/login", auto_error=False
)
class UserCreate(BaseModel):
@@ -54,6 +59,16 @@ class UserResponse(BaseModel):
from_attributes = True
class MeResponse(BaseModel):
"""当前用户完整信息(含工作区列表)"""
id: str
username: str
email: str
role: str
workspaces: list = []
current_workspace_id: str | None = None
class Token(BaseModel):
"""令牌响应模型"""
access_token: str
@@ -85,40 +100,343 @@ async def register(user_data: UserCreate, db: Session = Depends(get_db)):
return user
def _get_user_default_workspace_id(db: Session, user: User) -> str | None:
"""获取用户的默认工作区 ID。优先使用默认工作区其次第一个 membership。"""
from app.models.workspace import Workspace, WorkspaceMembership
# 优先使用系统默认工作区
default_ws = db.query(Workspace).filter(Workspace.is_default == 1, Workspace.status == "active").first()
if default_ws:
membership = (
db.query(WorkspaceMembership)
.filter(
WorkspaceMembership.workspace_id == default_ws.id,
WorkspaceMembership.user_id == user.id,
)
.first()
)
if membership:
return default_ws.id
# 没有默认工作区,使用第一个 membership
first_membership = (
db.query(WorkspaceMembership)
.filter(WorkspaceMembership.user_id == user.id)
.first()
)
if first_membership:
return first_membership.workspace_id
return None
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
"""用户登录"""
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db),
client_type: str = "web"
):
"""用户登录。client_type=android/ios 时签发 7 天 tokenweb 默认 30 分钟。"""
user = db.query(User).filter(User.username == form_data.username).first()
if not user or not verify_password(form_data.password, user.password_hash):
logger.warning(f"登录失败: 用户名 {form_data.username}")
raise UnauthorizedError("用户名或密码错误")
from datetime import timedelta
if client_type in ("android", "ios"):
expires = timedelta(minutes=settings.JWT_MOBILE_TOKEN_EXPIRE_MINUTES)
else:
expires = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
ws_id = _get_user_default_workspace_id(db, user)
access_token = create_access_token(
data={"sub": user.id, "username": user.username}
data={"sub": user.id, "username": user.username, "ws": ws_id or ""},
expires_delta=expires,
)
return {"access_token": access_token, "token_type": "bearer"}
@router.get("/me", response_model=UserResponse)
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
):
"""获取当前用户信息"""
) -> User:
"""FastAPI 依赖 — 从 JWT 提取当前用户,返回 User 模型。"""
from app.core.security import decode_access_token
payload = decode_access_token(token)
if payload is None:
raise UnauthorizedError("无效的访问令牌")
user_id = payload.get("sub")
if user_id is None:
raise UnauthorizedError("无效的访问令牌")
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise NotFoundError("用户", user_id)
return user
@router.get("/me", response_model=MeResponse)
async def get_me(
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db)
):
"""获取当前用户信息(含工作区列表)。"""
from app.core.security import decode_access_token
from app.services.workspace_service import get_user_workspaces
payload = decode_access_token(token)
if payload is None:
raise UnauthorizedError("无效的访问令牌")
user_id = payload.get("sub")
if user_id is None:
raise UnauthorizedError("无效的访问令牌")
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise NotFoundError("用户", user_id)
workspaces = get_user_workspaces(db, user)
current_ws_id = payload.get("ws", "")
return {
"id": user.id,
"username": user.username,
"email": user.email,
"role": user.role,
"workspaces": workspaces,
"current_workspace_id": current_ws_id if current_ws_id else None,
}
@router.post("/switch-workspace/{workspace_id}")
async def switch_workspace(
workspace_id: str,
token: str = Depends(oauth2_scheme),
db: Session = Depends(get_db),
):
"""切换当前工作区,重新签发 JWT包含新的 ws 字段)。"""
from app.core.security import decode_access_token
from app.services.workspace_service import check_workspace_access
payload = decode_access_token(token)
if payload is None:
raise UnauthorizedError("无效的访问令牌")
user_id = payload.get("sub")
if user_id is None:
raise UnauthorizedError("无效的访问令牌")
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise NotFoundError("用户", user_id)
if not check_workspace_access(db, user, workspace_id):
raise HTTPException(status_code=403, detail="无权访问此工作区")
from datetime import timedelta
expires = timedelta(minutes=settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES)
new_token = create_access_token(
data={"sub": user.id, "username": user.username, "ws": workspace_id},
expires_delta=expires,
)
return {"access_token": new_token, "token_type": "bearer", "workspace_id": workspace_id}
# ─── 密码重置 ───────────────────────────────────────────────
RESET_CODE_TTL_SEC = 600 # 验证码 10 分钟有效
RESET_RATE_LIMIT_SEC = 60 # 同一邮箱 60 秒内只能发一次
class ForgotPasswordRequest(BaseModel):
email: str
@field_validator("email")
@classmethod
def email_format(cls, v: str) -> str:
if not v or not re.match(r"^[^@]+@[^@]+\.[^@]+$", v):
raise ValueError("邮箱格式无效")
return v.lower()
class ResetPasswordRequest(BaseModel):
email: str
code: str
new_password: str
@field_validator("email")
@classmethod
def email_format(cls, v: str) -> str:
if not v or not re.match(r"^[^@]+@[^@]+\.[^@]+$", v):
raise ValueError("邮箱格式无效")
return v.lower()
@field_validator("new_password")
@classmethod
def password_length(cls, v: str) -> str:
if len(v) < 6:
raise ValueError("密码不少于 6 个字符")
if len(v) > 32:
raise ValueError("密码不超过 32 个字符")
return v
async def _send_reset_email(email: str, code: str) -> bool:
"""发送密码重置邮件。SMTP 不可用时记日志。"""
try:
import aiosmtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
smtp_host = getattr(settings, 'SMTP_HOST', '') or 'smtp.qq.com'
smtp_port = int(getattr(settings, 'SMTP_PORT', 0) or 587)
smtp_user = getattr(settings, 'SMTP_USER', '') or ''
smtp_password = getattr(settings, 'SMTP_PASSWORD', '') or ''
if not smtp_user or not smtp_password:
logger.warning("SMTP 未配置,无法发送邮件。重置码: %s", code)
return False
msg = MIMEMultipart()
msg['From'] = smtp_user
msg['To'] = email
msg['Subject'] = '天工智能体 - 密码重置验证码'
msg.attach(MIMEText(
f'您的密码重置验证码是:<b>{code}</b><br><br>'
f'验证码 10 分钟内有效。如非本人操作请忽略此邮件。',
'html', 'utf-8'
))
await aiosmtplib.send(
msg, hostname=smtp_host, port=smtp_port,
username=smtp_user, password=smtp_password,
use_tls=smtp_port == 587,
)
logger.info("密码重置邮件已发送至 %s", email)
return True
except Exception as e:
logger.warning("邮件发送失败: %s,重置码: %s", e, code)
return False
@router.post("/forgot-password")
async def forgot_password(body: ForgotPasswordRequest, db: Session = Depends(get_db)):
"""发送密码重置验证码。"""
user = db.query(User).filter(User.email == body.email).first()
if not user:
# 不泄露邮箱是否注册,统一返回成功
return {"message": "如果邮箱已注册,验证码已发送"}
redis = get_redis_client()
# 频率限制
rate_key = f"pwd_reset_rate:{body.email}"
if redis:
if redis.exists(rate_key):
ttl = redis.ttl(rate_key)
raise HTTPException(
status_code=429,
detail=f"操作过于频繁,请 {ttl} 秒后重试"
)
code = secrets.randbelow(900000) + 100000 # 6 位数字
code_str = str(code)
# 存储到 Redis
code_key = f"pwd_reset_code:{body.email}"
if redis:
redis.setex(code_key, RESET_CODE_TTL_SEC, code_str)
redis.setex(rate_key, RESET_RATE_LIMIT_SEC, "1")
else:
# 无 Redis 时用内存存储(重启失效)
if not hasattr(forgot_password, '_memory_store'):
forgot_password._memory_store = {}
forgot_password._memory_rate = {}
forgot_password._memory_store[body.email] = {
"code": code_str,
"expires_at": datetime.utcnow() + timedelta(seconds=RESET_CODE_TTL_SEC),
}
forgot_password._memory_rate[body.email] = \
datetime.utcnow() + timedelta(seconds=RESET_RATE_LIMIT_SEC)
# 尝试发送邮件
sent = await _send_reset_email(body.email, code_str)
if not sent:
# SMTP 未配置时记录验证码并返回(开发/测试环境)
logger.info("开发模式:%s 的密码重置验证码为 %s", body.email, code_str)
return {
"message": "验证码已生成",
"dev_code": code_str,
}
return {"message": "验证码已发送至邮箱"}
@router.post("/reset-password")
async def reset_password(body: ResetPasswordRequest, db: Session = Depends(get_db)):
"""使用验证码重置密码。"""
user = db.query(User).filter(User.email == body.email).first()
if not user:
raise HTTPException(status_code=400, detail="邮箱未注册")
redis = get_redis_client()
code_key = f"pwd_reset_code:{body.email}"
stored_code = None
if redis:
stored_code = redis.get(code_key)
elif hasattr(forgot_password, '_memory_store'):
entry = forgot_password._memory_store.get(body.email, {})
if entry and entry.get("expires_at", datetime.min) > datetime.utcnow():
stored_code = entry.get("code")
if not stored_code:
raise HTTPException(status_code=400, detail="验证码已过期或未请求")
if stored_code != body.code.strip():
raise HTTPException(status_code=400, detail="验证码错误")
# 更新密码
user.password_hash = get_password_hash(body.new_password)
db.commit()
# 清除验证码
if redis:
redis.delete(code_key)
elif hasattr(forgot_password, '_memory_store'):
forgot_password._memory_store.pop(body.email, None)
logger.info("用户 %s 密码重置成功", user.username)
return {"message": "密码重置成功,请使用新密码登录"}
async def get_optional_user(
token: str | None = Depends(oauth2_scheme_optional),
db: Session = Depends(get_db)
) -> User | None:
"""获取当前用户(可选登录)。未提供 token 或 token 无效时返回 None。"""
if not token:
return None
from app.core.security import decode_access_token
payload = decode_access_token(token)
if payload is None:
return None
user_id = payload.get("sub")
if user_id is None:
return None
user = db.query(User).filter(User.id == user_id).first()
return user

64
backend/app/api/deps.py Normal file
View File

@@ -0,0 +1,64 @@
"""
FastAPI 依赖注入 — Workspace 上下文、权限检查
"""
from dataclasses import dataclass
from typing import Optional
from fastapi import Depends, HTTPException, Request
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.core.security import decode_access_token
from app.models.user import User
@dataclass
class WorkspaceContext:
"""工作区上下文 — 包装当前用户 + 当前工作区 ID"""
user: User
workspace_id: str
def get_current_workspace_id(request: Request) -> str:
"""从 JWT 的 `ws` 字段提取当前工作区 ID。"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证令牌")
token = auth_header[7:]
payload = decode_access_token(token)
if payload is None:
raise HTTPException(status_code=401, detail="无效的访问令牌")
ws_id = payload.get("ws")
if not ws_id:
raise HTTPException(status_code=400, detail="令牌中未包含工作区信息,请重新登录")
return ws_id
def get_workspace_context(
request: Request,
db: Session = Depends(get_db),
) -> WorkspaceContext:
"""获取完整的 Workspace 上下文(用户 + 工作区),用于需要 workspace 过滤的接口。"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证令牌")
token = auth_header[7:]
payload = decode_access_token(token)
if payload is None:
raise HTTPException(status_code=401, detail="无效的访问令牌")
user_id = payload.get("sub")
ws_id = payload.get("ws")
if not user_id:
raise HTTPException(status_code=401, detail="无效的访问令牌")
user = db.query(User).filter(User.id == user_id).first()
if not user:
raise HTTPException(status_code=401, detail="用户不存在")
return WorkspaceContext(user=user, workspace_id=ws_id or "")

80
backend/app/api/fcm.py Normal file
View File

@@ -0,0 +1,80 @@
"""
FCM 推送 API — Android 设备 Token 管理
提供设备推送令牌的注册和注销接口。
"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.api.auth import get_current_user
from app.core.database import SessionLocal
from app.models.fcm_token import FcmToken
from app.models.user import User
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/fcm", tags=["fcm"])
class FcmRegisterRequest(BaseModel):
token: str = Field(..., min_length=32, max_length=512, description="FCM/APNs 设备令牌")
platform: str = Field(default="android", description="android / ios / web")
@router.post("/register")
async def register_fcm_token(
req: FcmRegisterRequest,
current_user: User = Depends(get_current_user),
):
"""注册设备推送 Token。同一 Token 多次注册会更新绑定用户和时间。"""
db = SessionLocal()
try:
existing = db.query(FcmToken).filter(FcmToken.token == req.token).first()
if existing:
existing.user_id = current_user.id
existing.platform = req.platform
db.commit()
logger.info("FCM Token 已更新: user=%s platform=%s", current_user.id, req.platform)
return {"status": "updated", "id": str(existing.id)}
fcm = FcmToken(
user_id=current_user.id,
token=req.token,
platform=req.platform,
)
db.add(fcm)
db.commit()
db.refresh(fcm)
logger.info("FCM Token 已注册: user=%s platform=%s", current_user.id, req.platform)
return {"status": "registered", "id": str(fcm.id)}
finally:
db.close()
@router.delete("/unregister")
async def unregister_fcm_token(
token: str,
current_user: User = Depends(get_current_user),
):
"""注销设备推送 Token。"""
db = SessionLocal()
try:
fcm = (
db.query(FcmToken)
.filter(FcmToken.token == token, FcmToken.user_id == current_user.id)
.first()
)
if fcm:
db.delete(fcm)
db.commit()
logger.info("FCM Token 已注销: user=%s", current_user.id)
return {"status": "deleted"}
return {"status": "not_found"}
finally:
db.close()

View File

@@ -0,0 +1,93 @@
"""
反馈数据 API — 用户反馈分析、记录查询、反例生成
"""
from __future__ import annotations
import logging
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.services.feedback_learner import feedback_learner
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/feedback", tags=["feedback"])
@router.get("/analysis")
def get_feedback_analysis(
days: int = Query(7, ge=1, le=90, description="统计天数"),
agent_name: Optional[str] = Query(None, description="按 Agent 筛选"),
current_user: User = Depends(get_current_user),
):
"""获取反馈分析报告 — 信号分布、负面率、策略建议。"""
result = feedback_learner.analyze_feedback_patterns(agent_name=agent_name, days=days)
return result
@router.get("/records")
def list_feedback_records(
agent_name: Optional[str] = Query(None, description="按 Agent 筛选"),
signal_type: Optional[str] = Query(None, description="按信号类型筛选"),
days: int = Query(30, ge=1, le=365, description="统计天数"),
limit: int = Query(50, ge=1, le=200, description="返回条数"),
offset: int = Query(0, ge=0, description="偏移量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""查询反馈记录列表。"""
from datetime import datetime, timedelta
from app.models.feedback_record import FeedbackRecord
since = datetime.now() - timedelta(days=days)
q = db.query(FeedbackRecord).filter(FeedbackRecord.created_at >= since)
if agent_name:
q = q.filter(FeedbackRecord.agent_name == agent_name)
if signal_type:
q = q.filter(FeedbackRecord.signal_type == signal_type)
total = q.count()
records = (
q.order_by(FeedbackRecord.created_at.desc())
.offset(offset)
.limit(limit)
.all()
)
return {
"total": total,
"limit": limit,
"offset": offset,
"items": [
{
"id": r.id,
"user_id": r.user_id,
"signal_type": r.signal_type,
"severity": r.severity,
"execution_log_id": r.execution_log_id,
"agent_name": r.agent_name,
"original_output": (r.original_output or "")[:200],
"user_correction": (r.user_correction or "")[:200],
"learned": r.learned,
"lesson_summary": r.lesson_summary,
"created_at": r.created_at.isoformat() if r.created_at else None,
}
for r in records
],
}
@router.get("/negative-examples/{agent_name}")
def get_negative_examples(
agent_name: str,
limit: int = Query(5, ge=1, le=20, description="返回条数"),
current_user: User = Depends(get_current_user),
):
"""获取指定 Agent 的反例(用于改进 system prompt"""
examples = feedback_learner.generate_negative_examples(agent_name=agent_name, limit=limit)
return {"agent_name": agent_name, "count": len(examples), "examples": examples}

View File

@@ -0,0 +1,109 @@
"""
知识仪表盘 API — 瓶颈分析、知识条目查询、趋势统计
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.knowledge_entry import KnowledgeEntry
from app.services.bottleneck_detector import bottleneck_detector
from app.services.optimization_engine import optimization_engine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/knowledge-dashboard", tags=["knowledge-dashboard"])
@router.get("/bottlenecks")
def get_bottlenecks(
hours: int = Query(168, ge=1, le=2160, description="分析时长(小时)"),
current_user: User = Depends(get_current_user),
):
"""瓶颈分析 — 检测工作流性能瓶颈并生成优化建议。"""
analysis = bottleneck_detector.run_full_analysis(hours=hours)
optimizations = optimization_engine.generate_optimizations(analysis.get("bottlenecks", []))
# Build recommendations dict keyed by node_type for frontend lookup
recommendations_map = {}
for opt in optimizations:
recommendations_map[opt["node_type"]] = {
"node_type": opt["node_type"],
"severity": opt["severity"],
"current_state": opt.get("current_metrics", {}),
"changes": opt.get("changes", []),
}
return {
**analysis,
"recommendations": list(optimizations),
"optimizations": recommendations_map,
}
@router.get("/entries")
def get_knowledge_entries(
days: int = Query(7, ge=1, le=365, description="统计天数"),
limit: int = Query(50, ge=1, le=200, description="返回条数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取知识条目列表,按创建时间倒序。"""
since = datetime.now() - timedelta(days=days)
entries = (
db.query(KnowledgeEntry)
.filter(
KnowledgeEntry.created_at >= since,
KnowledgeEntry.is_active == True,
)
.order_by(KnowledgeEntry.created_at.desc())
.limit(limit)
.all()
)
return [e.to_dict() for e in entries]
@router.get("/trend")
def get_knowledge_trend(
days: int = Query(7, ge=1, le=365, description="统计天数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""知识条目增长趋势 — 按天统计新增数量。"""
since = datetime.now() - timedelta(days=days)
# GROUP BY date(created_at)
rows = (
db.query(
func.date(KnowledgeEntry.created_at).label("date"),
func.count(KnowledgeEntry.id).label("count"),
)
.filter(
KnowledgeEntry.created_at >= since,
KnowledgeEntry.is_active == True,
)
.group_by(func.date(KnowledgeEntry.created_at))
.order_by(func.date(KnowledgeEntry.created_at).asc())
.all()
)
# Fill in missing dates with 0 count
trend = []
current_date = since.date()
end_date = datetime.now().date()
date_counts = {row.date: row.count for row in rows}
while current_date <= end_date:
trend.append({
"date": current_date.isoformat(),
"count": date_counts.get(current_date, 0),
})
current_date += timedelta(days=1)
return trend

View File

@@ -58,9 +58,10 @@ async def create_agent_from_template_v1(
current_user: User = Depends(get_current_user),
):
try:
workflow_config = build_workflow_for_template(
body.template_id, body.parameters or {}
)
params = dict(body.parameters or {})
if body.contract_id:
params["contract_id"] = body.contract_id
workflow_config = build_workflow_for_template(body.template_id, params)
except ValueError as e:
raise ValidationError(str(e))

80
backend/app/api/push.py Normal file
View File

@@ -0,0 +1,80 @@
"""浏览器推送订阅 API"""
from __future__ import annotations
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.api.auth import get_current_user, get_optional_user
from app.models.user import User
from app.core.database import SessionLocal
from app.models.push_subscription import PushSubscription
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/push", tags=["push"])
class PushSubscriptionRequest(BaseModel):
endpoint: str = Field(..., description="Push 订阅端点 URL")
keys: dict = Field(..., description="包含 p256dh 和 auth")
user_agent: str = Field(default="", description="设备 UA")
@router.post("/subscribe")
async def subscribe_push(
req: PushSubscriptionRequest,
current_user: User = Depends(get_optional_user),
):
"""保存浏览器推送订阅。用户可选登录。"""
db = SessionLocal()
try:
# 检查是否已存在相同 endpoint
existing = db.query(PushSubscription).filter(
PushSubscription.endpoint == req.endpoint
).first()
if existing:
# 更新已有记录
existing.p256dh = req.keys.get("p256dh", "")
existing.auth = req.keys.get("auth", "")
existing.user_id = str(current_user.id) if current_user else existing.user_id
existing.user_agent = req.user_agent
db.commit()
return {"status": "updated", "subscription_id": str(existing.id)}
sub = PushSubscription(
user_id=str(current_user.id) if current_user else None,
endpoint=req.endpoint,
p256dh=req.keys.get("p256dh", ""),
auth=req.keys.get("auth", ""),
user_agent=req.user_agent,
)
db.add(sub)
db.commit()
db.refresh(sub)
return {"status": "created", "subscription_id": str(sub.id)}
finally:
db.close()
@router.delete("/unsubscribe")
async def unsubscribe_push(
endpoint: str,
current_user: User = Depends(get_optional_user),
):
"""取消推送订阅。"""
db = SessionLocal()
try:
q = db.query(PushSubscription).filter(PushSubscription.endpoint == endpoint)
if current_user:
q = q.filter(PushSubscription.user_id == str(current_user.id))
sub = q.first()
if sub:
db.delete(sub)
db.commit()
return {"status": "deleted"}
return {"status": "not_found"}
finally:
db.close()

View File

@@ -0,0 +1,366 @@
"""
场景契约 API — 统一 DSL 输入契约的 CRUD 与预置查询
"""
from fastapi import APIRouter, Depends, HTTPException, Response, status, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional
import logging
import uuid
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.scene_contract import SceneContract
from app.services.scene_contract_service import (
build_system_prompt_from_contract,
build_acceptance_prompt,
validate_input_against_contract,
get_preset_contract,
list_preset_contracts_meta,
PRESET_CONTRACTS,
ContractPromptConfig,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/scene-contracts", tags=["scene-contracts"])
# ── Pydantic schemas ──
class DeliverableItem(BaseModel):
name: str
format: Optional[str] = None
description: Optional[str] = None
class ExampleItem(BaseModel):
input: str
output: str
class ContractCreate(BaseModel):
name: str
description: Optional[str] = None
goal: str
role: Optional[str] = None
input_description: Optional[str] = None
input_schema: Optional[Dict[str, Any]] = None
constraints: List[str] = Field(default_factory=list)
forbidden_actions: List[str] = Field(default_factory=list)
required_tools: List[str] = Field(default_factory=list)
deliverables: List[DeliverableItem] = Field(default_factory=list)
acceptance_criteria: List[str] = Field(default_factory=list)
output_schema: Optional[Dict[str, Any]] = None
examples: List[ExampleItem] = Field(default_factory=list)
category: Optional[str] = None
tags: List[str] = Field(default_factory=list)
is_public: bool = False
template_binding: Optional[str] = None
class ContractUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
goal: Optional[str] = None
role: Optional[str] = None
input_description: Optional[str] = None
input_schema: Optional[Dict[str, Any]] = None
constraints: Optional[List[str]] = None
forbidden_actions: Optional[List[str]] = None
required_tools: Optional[List[str]] = None
deliverables: Optional[List[DeliverableItem]] = None
acceptance_criteria: Optional[List[str]] = None
output_schema: Optional[Dict[str, Any]] = None
examples: Optional[List[ExampleItem]] = None
category: Optional[str] = None
tags: Optional[List[str]] = None
is_public: Optional[bool] = None
template_binding: Optional[str] = None
class ContractResponse(BaseModel):
id: str
name: str
description: Optional[str] = None
goal: str
role: Optional[str] = None
input_description: Optional[str] = None
input_schema: Optional[Dict[str, Any]] = None
constraints: List[str] = []
forbidden_actions: List[str] = []
required_tools: List[str] = []
deliverables: List[Dict[str, Any]] = []
acceptance_criteria: List[str] = []
output_schema: Optional[Dict[str, Any]] = None
examples: List[Dict[str, Any]] = []
category: Optional[str] = None
tags: List[str] = []
version: int = 1
is_public: int = 0
use_count: int = 0
user_id: Optional[str] = None
template_binding: Optional[str] = None
created_at: Optional[str] = None
updated_at: Optional[str] = None
class ContractMetaItem(BaseModel):
id: str
name: str
description: Optional[str] = None
category: Optional[str] = None
tags: List[str] = []
goal: str = ""
deliverable_count: int = 0
constraint_count: int = 0
class PromptGenerateRequest(BaseModel):
contract_id: Optional[str] = None
contract_data: Optional[Dict[str, Any]] = None
config: Optional[Dict[str, Any]] = None
class PromptGenerateResponse(BaseModel):
system_prompt: str
class AcceptanceEvaluateRequest(BaseModel):
contract_id: Optional[str] = None
contract_data: Optional[Dict[str, Any]] = None
agent_output: str
class AcceptanceEvaluateResponse(BaseModel):
evaluation_prompt: str
class InputValidateRequest(BaseModel):
contract_id: Optional[str] = None
contract_data: Optional[Dict[str, Any]] = None
user_input: Dict[str, Any]
class InputValidateResponse(BaseModel):
valid: bool
errors: List[str] = []
warnings: List[str] = []
# ── Predefined preset contracts ──
@router.get("/presets", response_model=List[ContractMetaItem])
async def list_preset_contracts(current_user: User = Depends(get_current_user)):
"""列出所有预置场景契约"""
_ = current_user
return list_preset_contracts_meta()
@router.get("/presets/{contract_id}")
async def get_preset_contract_detail(
contract_id: str,
current_user: User = Depends(get_current_user),
):
"""获取单个预置契约的完整内容"""
_ = current_user
contract = get_preset_contract(contract_id)
if not contract:
raise HTTPException(status_code=404, detail=f"预置契约不存在: {contract_id}")
return {"id": contract_id, **contract}
# ── CRUD for user contracts ──
@router.get("/", response_model=List[ContractResponse])
async def list_contracts(
category: Optional[str] = Query(None),
is_public: Optional[bool] = Query(None),
template_binding: Optional[str] = Query(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""列出用户的场景契约"""
query = db.query(SceneContract).filter(
(SceneContract.user_id == current_user.id) | (SceneContract.is_public == 1)
)
if category:
query = query.filter(SceneContract.category == category)
if is_public is not None:
query = query.filter(SceneContract.is_public == (1 if is_public else 0))
if template_binding:
query = query.filter(SceneContract.template_binding == template_binding)
contracts = query.order_by(SceneContract.updated_at.desc()).all()
return [c.to_dict() for c in contracts]
@router.post("/", response_model=ContractResponse, status_code=status.HTTP_201_CREATED)
async def create_contract(
body: ContractCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""创建自定义场景契约"""
contract = SceneContract(
id=str(uuid.uuid4()),
name=body.name,
description=body.description,
goal=body.goal,
role=body.role,
input_description=body.input_description,
input_schema=body.input_schema,
constraints=body.constraints,
forbidden_actions=body.forbidden_actions,
required_tools=body.required_tools,
deliverables=[d.model_dump() for d in body.deliverables],
acceptance_criteria=body.acceptance_criteria,
output_schema=body.output_schema,
examples=[e.model_dump() for e in body.examples],
category=body.category,
tags=body.tags,
is_public=1 if body.is_public else 0,
template_binding=body.template_binding,
user_id=current_user.id,
)
db.add(contract)
db.commit()
db.refresh(contract)
return contract.to_dict()
@router.get("/{contract_id}", response_model=ContractResponse)
async def get_contract(
contract_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取场景契约详情"""
# 先查预置
preset = get_preset_contract(contract_id)
if preset:
return {"id": contract_id, "version": 1, "is_public": 1, "use_count": 0,
"user_id": None, "template_binding": None,
"created_at": None, "updated_at": None, **preset}
contract = db.query(SceneContract).filter(SceneContract.id == contract_id).first()
if not contract:
raise HTTPException(status_code=404, detail="契约不存在")
if contract.user_id != current_user.id and contract.is_public == 0:
raise HTTPException(status_code=403, detail="无权访问此契约")
return contract.to_dict()
@router.put("/{contract_id}", response_model=ContractResponse)
async def update_contract(
contract_id: str,
body: ContractUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""更新场景契约"""
contract = db.query(SceneContract).filter(SceneContract.id == contract_id).first()
if not contract:
raise HTTPException(status_code=404, detail="契约不存在")
if contract.user_id != current_user.id:
raise HTTPException(status_code=403, detail="只能修改自己的契约")
update_data = body.model_dump(exclude_unset=True)
# Handle list/dict fields
for field in ["deliverables", "examples"]:
if field in update_data and update_data[field] is not None:
update_data[field] = [item if isinstance(item, dict) else item.model_dump()
for item in update_data[field]]
if "is_public" in update_data and update_data["is_public"] is not None:
update_data["is_public"] = 1 if update_data["is_public"] else 0
for key, value in update_data.items():
if value is not None:
setattr(contract, key, value)
contract.version = (contract.version or 1) + 1
db.commit()
db.refresh(contract)
return contract.to_dict()
@router.delete("/{contract_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_contract(
contract_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""删除场景契约"""
contract = db.query(SceneContract).filter(SceneContract.id == contract_id).first()
if not contract:
raise HTTPException(status_code=404, detail="契约不存在")
if contract.user_id != current_user.id:
raise HTTPException(status_code=403, detail="只能删除自己的契约")
db.delete(contract)
db.commit()
return Response(status_code=status.HTTP_204_NO_CONTENT)
# ── DSL 操作端点 ──
@router.post("/generate-prompt", response_model=PromptGenerateResponse)
async def generate_prompt_from_contract(
body: PromptGenerateRequest,
current_user: User = Depends(get_current_user),
):
"""从契约生成 system prompt"""
_ = current_user
# 获取契约数据
contract_data = body.contract_data
if not contract_data and body.contract_id:
preset = get_preset_contract(body.contract_id)
if preset:
contract_data = preset
if not contract_data:
raise HTTPException(status_code=400, detail="请提供 contract_id 或 contract_data")
config = ContractPromptConfig(**(body.config or {}))
prompt = build_system_prompt_from_contract(contract_data, config)
return {"system_prompt": prompt}
@router.post("/evaluate", response_model=AcceptanceEvaluateResponse)
async def evaluate_output_against_contract(
body: AcceptanceEvaluateRequest,
current_user: User = Depends(get_current_user),
):
"""生成验收评估 prompt用于评估 Agent 输出是否满足契约)"""
_ = current_user
contract_data = body.contract_data
if not contract_data and body.contract_id:
preset = get_preset_contract(body.contract_id)
if preset:
contract_data = preset
if not contract_data:
raise HTTPException(status_code=400, detail="请提供 contract_id 或 contract_data")
prompt = build_acceptance_prompt(contract_data, body.agent_output)
return {"evaluation_prompt": prompt}
@router.post("/validate-input", response_model=InputValidateResponse)
async def validate_input_against_contract_endpoint(
body: InputValidateRequest,
current_user: User = Depends(get_current_user),
):
"""根据契约的 input_schema 验证用户输入"""
_ = current_user
contract_data = body.contract_data
if not contract_data and body.contract_id:
preset = get_preset_contract(body.contract_id)
if preset:
contract_data = preset
if not contract_data:
raise HTTPException(status_code=400, detail="请提供 contract_id 或 contract_data")
result = validate_input_against_contract(contract_data, body.user_input)
return result

View File

@@ -0,0 +1,350 @@
"""
系统日志统一查询 API
"""
import logging
from pathlib import Path
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, Depends, Query, HTTPException, status
from sqlalchemy.orm import Session
from sqlalchemy import func, text, case
from pydantic import BaseModel
from app.core.database import get_db
from app.core.config import settings
from app.api.auth import get_current_user
from app.models.user import User
from app.models.execution_log import ExecutionLog
from app.models.agent_execution_log import AgentExecutionLog
from app.models.agent_llm_log import AgentLLMLog
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/api/v1/system-logs",
tags=["system-logs"],
responses={
401: {"description": "未授权"},
403: {"description": "无权访问"},
500: {"description": "服务器内部错误"}
}
)
# ── Pydantic Schemas ─────────────────────────────────────────────
class UnifiedLogItem(BaseModel):
id: str
source: str # "execution" | "agent" | "llm"
level: Optional[str] = None
message: str
timestamp: Optional[datetime] = None
resource_type: Optional[str] = None
resource_id: Optional[str] = None
duration_ms: Optional[int] = None
username: Optional[str] = None
class Config:
from_attributes = True
class LogStatsResponse(BaseModel):
total_count: int
error_count: int
warn_count: int
info_count: int
source_breakdown: Dict[str, int] # {"execution": 123, "agent": 45, "llm": 67}
hourly_trend: List[Dict[str, Any]] # [{"hour": "2026-05-10T14", "count": 5}, ...]
class Config:
from_attributes = True
class AppLogItem(BaseModel):
line_number: int
content: str
# ── Helpers ──────────────────────────────────────────────────────
def _build_union_query(
db: Session,
source: Optional[str] = None,
level: Optional[str] = None,
keyword: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
user_id: Optional[str] = None,
):
"""构建跨表 UNION ALL 查询"""
queries: list = []
params: dict = {}
# 1) 工作流执行日志
if source in (None, "all", "execution"):
q = db.query(
ExecutionLog.id.label("id"),
text("'execution'").label("source"),
ExecutionLog.level.label("level"),
ExecutionLog.message.label("message"),
ExecutionLog.timestamp.label("timestamp"),
ExecutionLog.node_type.label("resource_type"),
ExecutionLog.execution_id.label("resource_id"),
ExecutionLog.duration.label("duration_ms"),
text("NULL").label("username"),
)
if level:
q = q.filter(ExecutionLog.level == level.upper())
if keyword:
q = q.filter(ExecutionLog.message.contains(keyword))
if start_date:
q = q.filter(ExecutionLog.timestamp >= start_date)
if end_date:
q = q.filter(ExecutionLog.timestamp <= end_date)
queries.append(q)
# 2) Agent 执行日志
if source in (None, "all", "agent"):
q = db.query(
AgentExecutionLog.id.label("id"),
text("'agent'").label("source"),
case(
(AgentExecutionLog.status == "failed", "ERROR"),
(AgentExecutionLog.status == "completed", "INFO"),
else_="INFO"
).label("level"),
AgentExecutionLog.user_message.label("message"),
AgentExecutionLog.created_at.label("timestamp"),
text("'agent_chat'").label("resource_type"),
AgentExecutionLog.agent_id.label("resource_id"),
AgentExecutionLog.total_latency_ms.label("duration_ms"),
text("NULL").label("username"),
)
if level:
if level.upper() == "ERROR":
q = q.filter(AgentExecutionLog.status == "failed")
elif level.upper() == "WARN":
q = q.filter(AgentExecutionLog.status == "failed")
else:
q = q.filter(AgentExecutionLog.status != "failed")
if keyword:
q = q.filter(AgentExecutionLog.user_message.contains(keyword))
if start_date:
q = q.filter(AgentExecutionLog.created_at >= start_date)
if end_date:
q = q.filter(AgentExecutionLog.created_at <= end_date)
if user_id:
q = q.filter(AgentExecutionLog.user_id == user_id)
queries.append(q)
# 3) LLM 调用日志
if source in (None, "all", "llm"):
q = db.query(
AgentLLMLog.id.label("id"),
text("'llm'").label("source"),
case(
(AgentLLMLog.status == "error", "ERROR"),
(AgentLLMLog.status == "rate_limited", "WARN"),
else_="INFO"
).label("level"),
AgentLLMLog.assistant_content.label("message"),
AgentLLMLog.created_at.label("timestamp"),
text("'llm_call'").label("resource_type"),
AgentLLMLog.agent_id.label("resource_id"),
AgentLLMLog.latency_ms.label("duration_ms"),
text("NULL").label("username"),
)
if level:
if level.upper() == "ERROR":
q = q.filter(AgentLLMLog.status == "error")
elif level.upper() == "WARN":
q = q.filter(AgentLLMLog.status == "rate_limited")
else:
q = q.filter(AgentLLMLog.status == "success")
if keyword:
q = q.filter(AgentLLMLog.assistant_content.contains(keyword))
if start_date:
q = q.filter(AgentLLMLog.created_at >= start_date)
if end_date:
q = q.filter(AgentLLMLog.created_at <= end_date)
queries.append(q)
return queries
def _check_admin(current_user: User):
if getattr(current_user, "role", None) != "admin":
from app.core.exceptions import ForbiddenError
raise ForbiddenError("仅管理员可访问系统日志")
# ── Endpoints ────────────────────────────────────────────────────
@router.get("", response_model=List[UnifiedLogItem])
async def get_system_logs(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
source: Optional[str] = Query(None, description="日志来源: execution/agent/llm/all"),
level: Optional[str] = Query(None, description="日志级别: INFO/WARN/ERROR"),
keyword: Optional[str] = Query(None, description="关键词搜索"),
start_date: Optional[str] = Query(None, description="开始时间 ISO格式"),
end_date: Optional[str] = Query(None, description="结束时间 ISO格式"),
skip: int = Query(0, ge=0),
limit: int = Query(50, ge=1, le=1000),
):
"""
统一日志查询:跨 execution_logs / agent_execution_logs / agent_llm_logs 联合查询。
管理员可查全部,普通用户只能查看自己相关的 Agent 执行日志和 LLM 日志。
"""
_check_admin(current_user)
# 解析时间
sd = datetime.fromisoformat(start_date) if start_date else None
ed = datetime.fromisoformat(end_date) if end_date else None
# 非 admin 限制用户范围
user_id = None
if getattr(current_user, "role", None) != "admin":
user_id = current_user.id
source_val = source if source else "all"
queries = _build_union_query(db, source_val, level, keyword, sd, ed, user_id)
if not queries:
return []
# UNION ALL
union = queries[0]
for q in queries[1:]:
union = union.union_all(q)
# 排序 + 分页
total = union.count() if hasattr(union, 'count') else 0
rows = union.order_by(text("timestamp DESC")).offset(skip).limit(limit).all()
return rows
@router.get("/stats", response_model=LogStatsResponse)
async def get_system_logs_stats(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取系统日志统计各级别计数、各来源计数、24小时趋势"""
_check_admin(current_user)
now = datetime.utcnow()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
hours_24 = now - timedelta(hours=24)
# 工作流执行日志统计
exec_total = db.query(func.count(ExecutionLog.id)).filter(
ExecutionLog.timestamp >= today_start
).scalar() or 0
exec_errors = db.query(func.count(ExecutionLog.id)).filter(
ExecutionLog.timestamp >= today_start,
ExecutionLog.level == "ERROR"
).scalar() or 0
exec_warns = db.query(func.count(ExecutionLog.id)).filter(
ExecutionLog.timestamp >= today_start,
ExecutionLog.level == "WARN"
).scalar() or 0
# Agent 执行日志统计
agent_total = db.query(func.count(AgentExecutionLog.id)).filter(
AgentExecutionLog.created_at >= today_start
).scalar() or 0
agent_errors = db.query(func.count(AgentExecutionLog.id)).filter(
AgentExecutionLog.created_at >= today_start,
AgentExecutionLog.status == "failed"
).scalar() or 0
# LLM 调用日志统计
llm_total = db.query(func.count(AgentLLMLog.id)).filter(
AgentLLMLog.created_at >= today_start
).scalar() or 0
llm_errors = db.query(func.count(AgentLLMLog.id)).filter(
AgentLLMLog.created_at >= today_start,
AgentLLMLog.status == "error"
).scalar() or 0
total_count = exec_total + agent_total + llm_total
error_count = exec_errors + agent_errors + llm_errors
warn_count = exec_warns
info_count = total_count - error_count - warn_count
# 24小时趋势按小时聚合
hourly: list[dict[str, Any]] = []
for h in range(23, -1, -1):
slot_start = now.replace(minute=0, second=0, microsecond=0) - timedelta(hours=h)
slot_end = slot_start + timedelta(hours=1)
exec_h = db.query(func.count(ExecutionLog.id)).filter(
ExecutionLog.timestamp >= slot_start,
ExecutionLog.timestamp < slot_end
).scalar() or 0
agent_h = db.query(func.count(AgentExecutionLog.id)).filter(
AgentExecutionLog.created_at >= slot_start,
AgentExecutionLog.created_at < slot_end
).scalar() or 0
llm_h = db.query(func.count(AgentLLMLog.id)).filter(
AgentLLMLog.created_at >= slot_start,
AgentLLMLog.created_at < slot_end
).scalar() or 0
hourly.append({
"hour": slot_start.isoformat(),
"count": exec_h + agent_h + llm_h
})
return {
"total_count": total_count,
"error_count": error_count,
"warn_count": warn_count,
"info_count": info_count,
"source_breakdown": {
"execution": exec_total,
"agent": agent_total,
"llm": llm_total
},
"hourly_trend": hourly
}
@router.get("/app-logs", response_model=List[AppLogItem])
async def get_app_logs(
current_user: User = Depends(get_current_user),
lines: int = Query(200, ge=10, le=2000, description="读取行数"),
level: Optional[str] = Query(None, description="按级别过滤: INFO/WARNING/ERROR"),
):
"""读取应用程序文件日志的尾部行"""
_check_admin(current_user)
log_file = Path(settings.LOG_DIR) / "app.log"
if not log_file.exists():
return []
level_filter = level.upper() if level else None
try:
with open(log_file, "r", encoding="utf-8", errors="replace") as f:
all_lines = f.readlines()
except Exception:
return []
# 取尾部 N 行,倒序输出
tail = all_lines[-lines:]
result: list[dict] = []
for i, raw in enumerate(tail):
content = raw.rstrip("\n").rstrip("\r")
if level_filter and level_filter not in content.upper():
continue
result.append({
"line_number": len(all_lines) - len(tail) + i + 1,
"content": content
})
return result

View File

@@ -1,5 +1,7 @@
"""
Task API — 任务管理接口
Task API — 任务管理接口 (含原子认领 + 依赖图 + Agent 状态)
参考 Claude Code src/utils/tasks.ts 的 API 设计
"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
@@ -9,6 +11,7 @@ from datetime import datetime
import logging
from app.core.database import get_db
from app.core.task_system import TaskSystem, ClaimResult, AgentState, AgentStatus
from app.api.auth import get_current_user
from app.models.user import User
from app.services import goal_service
@@ -92,6 +95,42 @@ class TaskDependencyCheck(BaseModel):
pending_dependencies: List[str] = []
class ClaimTaskRequest(BaseModel):
agent_id: str = Field(..., description="认领任务的 Agent 标识")
check_busy: bool = Field(default=True, description="是否检查 Agent 忙碌状态")
class ClaimTaskResponse(BaseModel):
success: bool
reason: Optional[str] = None
task: Optional[TaskResponse] = None
busy_with_tasks: List[str] = []
blocked_by_tasks: List[str] = []
class BlockTaskRequest(BaseModel):
from_task_id: str = Field(..., description="阻塞方任务ID")
to_task_id: str = Field(..., description="被阻塞方任务ID")
class AgentStatusResponse(BaseModel):
agent_id: str
status: str # idle / busy
current_tasks: List[str] = []
class ReleaseTaskRequest(BaseModel):
agent_id: str = Field(..., description="释放任务的 Agent 标识")
class TaskCompleteRequest(BaseModel):
result: Optional[Dict[str, Any]] = None
class TaskFailRequest(BaseModel):
error_message: str = ""
# ──────────────────────────── Endpoints ────────────────────────────
@router.post("", response_model=TaskResponse, status_code=201)
@@ -307,3 +346,160 @@ async def retry_task(
error_message=str(e),
)
return goal_service.get_task(db, task_id)
# ════════════════════ 任务系统增强 (参考 Claude Code task_system) ════════════════════
def _get_task_system(db: Session) -> TaskSystem:
return TaskSystem(db)
@router.post("/{task_id}/claim", response_model=ClaimTaskResponse)
def claim_task(
task_id: str,
data: ClaimTaskRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""原子认领任务 (SELECT FOR UPDATE),检查依赖+Agent忙碌"""
ts = _get_task_system(db)
result = ts.claim_task(task_id=task_id, agent_id=data.agent_id, check_busy=data.check_busy)
return ClaimTaskResponse(
success=result.success,
reason=result.reason,
task=TaskResponse.model_validate(result.task) if result.task else None,
busy_with_tasks=result.busy_with_tasks,
blocked_by_tasks=result.blocked_by_tasks,
)
@router.post("/block", status_code=200)
def block_task(
data: BlockTaskRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""设置任务依赖: from_task 阻塞 to_task"""
ts = _get_task_system(db)
ok = ts.block_task(data.from_task_id, data.to_task_id)
if not ok:
from fastapi import HTTPException
raise HTTPException(404, "任务不存在")
return {"message": "ok", "from": data.from_task_id, "to": data.to_task_id}
@router.delete("/block", status_code=200)
def unblock_task(
data: BlockTaskRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""移除任务依赖"""
ts = _get_task_system(db)
ok = ts.unblock_task(data.from_task_id, data.to_task_id)
if not ok:
from fastapi import HTTPException
raise HTTPException(404, "任务不存在")
return {"message": "ok"}
@router.get("/agent/{agent_id}/status", response_model=AgentStatusResponse)
def get_agent_status(
agent_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取 Agent 忙闲状态 (idle/busy)"""
ts = _get_task_system(db)
state = ts.get_agent_status(agent_id)
return AgentStatusResponse(
agent_id=state.agent_id,
status=state.status.value,
current_tasks=state.current_tasks,
)
@router.post("/{task_id}/release", status_code=200)
def release_task(
task_id: str,
data: ReleaseTaskRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Agent 主动释放单个任务"""
ts = _get_task_system(db)
ok = ts.release_task(task_id, data.agent_id)
if not ok:
from fastapi import HTTPException
raise HTTPException(404, "任务不存在或不属于该 Agent")
return {"message": "ok", "task_id": task_id}
@router.post("/agent/{agent_id}/unassign", status_code=200)
def unassign_agent_tasks(
agent_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""释放 Agent 所有未完成任务 (Agent 下线时调用)"""
ts = _get_task_system(db)
tasks = ts.unassign_agent_tasks(agent_id)
return {"message": "ok", "unassigned_count": len(tasks), "task_ids": [t.id for t in tasks]}
@router.post("/{task_id}/complete", response_model=TaskResponse)
def complete_task(
task_id: str,
data: TaskCompleteRequest = TaskCompleteRequest(),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""标记任务完成 (自动检查被阻塞任务)"""
ts = _get_task_system(db)
task = ts.complete_task(task_id, result=data.result)
if not task:
from fastapi import HTTPException
raise HTTPException(404, "任务不存在")
return TaskResponse.model_validate(task)
@router.post("/{task_id}/fail", response_model=TaskResponse)
def fail_task(
task_id: str,
data: TaskFailRequest = TaskFailRequest(),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""标记任务失败"""
ts = _get_task_system(db)
task = ts.fail_task(task_id, error_message=data.error_message)
if not task:
from fastapi import HTTPException
raise HTTPException(404, "任务不存在")
return TaskResponse.model_validate(task)
@router.get("/available/{goal_id}", response_model=List[TaskResponse])
def get_available_tasks(
goal_id: str,
limit: int = Query(default=10, ge=1, le=100),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取目标下所有可执行任务 (依赖满足 + 未被认领)"""
ts = _get_task_system(db)
tasks = ts.get_next_available_tasks(goal_id, limit=limit)
return [TaskResponse.model_validate(t) for t in tasks]
@router.get("/{task_id}/blockers", response_model=List[TaskResponse])
def get_task_blockers(
task_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""获取任务尚未完成的阻塞依赖"""
ts = _get_task_system(db)
blockers = ts.get_unresolved_blockers(task_id)
return [TaskResponse.model_validate(t) for t in blockers]

View File

@@ -8,7 +8,7 @@ from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, Path
from fastapi import APIRouter, Depends, HTTPException, Query, Path, Request
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -33,6 +33,8 @@ class ToolCreate(BaseModel):
implementation_type: str # builtin / http / code / workflow
implementation_config: Optional[dict] = None
is_public: bool = False
status: str = "active" # active / draft / deprecated
tags: Optional[List[str]] = None
class ToolResponse(BaseModel):
@@ -44,6 +46,8 @@ class ToolResponse(BaseModel):
implementation_type: str
implementation_config: Optional[dict] = None
is_public: bool = False
status: str = "active"
tags: Optional[List[str]] = None
use_count: int = 0
user_id: Optional[str] = None
created_at: str = ""
@@ -86,6 +90,8 @@ def _tool_to_dict(tool: Tool) -> dict:
"implementation_type": tool.implementation_type,
"implementation_config": tool.implementation_config,
"is_public": tool.is_public,
"status": tool.status if hasattr(tool, "status") else "active",
"tags": tool.tags if hasattr(tool, "tags") else None,
"use_count": tool.use_count,
"user_id": tool.user_id,
"created_at": tool.created_at.isoformat() if tool.created_at else "",
@@ -139,12 +145,18 @@ async def list_tools(
category: Optional[str] = Query(None, description="按分类筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
scope: Optional[str] = Query("public", description="public / mine / all"),
status: Optional[str] = Query(None, description="状态筛选: active/draft/deprecated"),
workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
db: Session = Depends(get_db),
current_user: Optional[User] = Depends(get_current_user),
):
"""浏览工具市场(含内置工具 + 数据库工具)。"""
query = db.query(Tool)
# 工作区筛选
if workspace_id:
query = query.filter(Tool.workspace_id == workspace_id)
if scope == "public":
query = query.filter(Tool.is_public == True)
elif scope == "mine":
@@ -152,6 +164,12 @@ async def list_tools(
raise HTTPException(status_code=401, detail="需登录")
query = query.filter(Tool.user_id == current_user.id)
if status:
query = query.filter(Tool.status == status)
else:
# 默认只显示 active 状态的工具(排除 deprecated
query = query.filter((Tool.status == "active") | (Tool.status == None))
if category:
query = query.filter(Tool.category == category)
if search:
@@ -216,10 +234,20 @@ async def get_tool(
@router.post("", response_model=ToolResponse, status_code=201)
async def create_tool(
tool_data: ToolCreate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""创建自定义工具。"""
# 从 JWT 提取当前工作区 ID
from app.core.security import decode_access_token
ws_id = None
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
payload = decode_access_token(auth_header[7:])
if payload:
ws_id = payload.get("ws") or None
existing = db.query(Tool).filter(Tool.name == tool_data.name).first()
if existing:
raise HTTPException(status_code=400, detail=f"工具名 '{tool_data.name}' 已存在")
@@ -237,20 +265,24 @@ async def create_tool(
implementation_type=tool_data.implementation_type,
implementation_config=tool_data.implementation_config,
is_public=tool_data.is_public,
status=tool_data.status,
tags=tool_data.tags,
user_id=current_user.id,
workspace_id=ws_id,
)
db.add(tool)
db.commit()
db.refresh(tool)
logger.info("工具已创建: %s (type=%s)", tool.name, tool.implementation_type)
logger.info("工具已创建: %s (type=%s, status=%s)", tool.name, tool.implementation_type, tool.status)
# 刷新注册表
tool_registry._custom_tool_configs[tool.name] = {
**(tool.implementation_config or {}),
"_type": tool.implementation_type,
"_db_id": tool.id,
}
tool_registry._tool_schemas[tool.name] = tool.function_schema
# 仅 active 状态注入注册表
if tool.status == "active" or tool.status is None:
tool_registry._custom_tool_configs[tool.name] = {
**(tool.implementation_config or {}),
"_type": tool.implementation_type,
"_db_id": tool.id,
}
tool_registry._tool_schemas[tool.name] = tool.function_schema
return _tool_to_dict(tool)
@@ -317,12 +349,27 @@ async def delete_tool(
db.commit()
# 清理注册表
tool_registry._custom_tool_configs.pop(tool.name, None)
tool_registry._tool_schemas.pop(tool.name, None)
tool_registry.unregister_tool(tool.name)
return {"message": "工具已删除"}
@router.post("/reload")
async def reload_tools(db: Session = Depends(get_db)):
"""从数据库重新加载所有自定义工具到注册表(热更新)。"""
# 清除旧的自定义工具
for name in list(tool_registry._custom_tool_configs.keys()):
if name not in tool_registry._builtin_tools:
tool_registry._custom_tool_configs.pop(name, None)
tool_registry._tool_schemas.pop(name, None)
# 重新加载
tool_registry.load_tools_from_db(db)
count = len(tool_registry._custom_tool_configs)
logger.info("工具热更新完成,已加载 %d 个自定义工具", count)
return {"message": f"工具注册表已刷新", "custom_tool_count": count}
# ─── 工具测试 ──────────────────────────────────────────────────

261
backend/app/api/voice.py Normal file
View File

@@ -0,0 +1,261 @@
"""
语音 API — Android 应用语音交互接口
提供语音转文字 (ASR) 和文字转语音 (TTS) 的 HTTP API。
TTS 优先使用 OpenAI TTS需配置 OPENAI_API_KEY否则使用免费的 Edge TTS。
"""
from __future__ import annotations
import asyncio
import logging
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from app.api.auth import get_current_user
from app.models.user import User
from app.core.config import settings
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/voice", tags=["voice"])
# Edge TTS 中文语音映射 (OpenAI voice name -> Edge TTS Chinese voice)
_EDGE_VOICE_MAP = {
"alloy": "zh-CN-YunxiNeural",
"echo": "zh-CN-YunyangNeural",
"fable": "zh-CN-XiaoxiaoNeural",
"onyx": "zh-CN-YunjianNeural",
"nova": "zh-CN-XiaoyiNeural",
"shimmer": "zh-CN-XiaoxiaoNeural",
}
class AsrResponse(BaseModel):
"""语音识别响应"""
text: str = Field(..., description="识别出的文字")
language: str = Field(default="zh", description="语言代码")
class TtsRequest(BaseModel):
"""文字转语音请求"""
text: str = Field(..., min_length=1, max_length=4000, description="要合成的文字")
voice: str = Field(
default="alloy",
description="语音风格alloy / echo / fable / onyx / nova / shimmer",
)
class TtsResponse(BaseModel):
"""文字转语音响应"""
audio_url: str = Field(..., description="音频文件下载 URL")
text_length: int = Field(..., description="文字长度")
voice: str = Field(..., description="使用的语音风格")
# TTS 输出缓存目录
_TTS_DIR = Path(settings.LOCAL_FILE_TOOLS_ROOT) / "tts_outputs"
@router.post("/asr", response_model=AsrResponse)
async def voice_to_text(
file: UploadFile = File(..., description="音频文件AAC/WAV/MP3/WebM/M4A"),
language: str = Query("zh", description="语言代码"),
current_user: User = Depends(get_current_user),
):
"""
语音转文字 (ASR)。
接收 Android 端录音上传的 AAC 音频文件,调用 Whisper API 返回识别文字。
采样率建议 16000 Hz。
"""
# 验证文件
if not file.filename:
raise HTTPException(status_code=400, detail="文件名为空")
ext = (file.filename.rsplit(".", 1)[-1].lower() if "." in file.filename else "aac")
allowed = {"aac", "wav", "mp3", "webm", "m4a", "ogg", "flac", "mpeg"}
if ext not in allowed:
raise HTTPException(status_code=400, detail=f"不支持的音频格式: {ext}")
# 检查文件大小
content = await file.read()
max_bytes = 25 * 1024 * 1024 # 25 MB
if len(content) > max_bytes:
raise HTTPException(status_code=400, detail=f"文件过大 ({len(content) / 1024 / 1024:.1f} MB)")
# 写入临时文件
tmp = tempfile.NamedTemporaryFile(suffix=f".{ext}", delete=False)
try:
tmp.write(content)
tmp.close()
api_key = (getattr(settings, "OPENAI_API_KEY", "") or "").strip()
base_url = (
getattr(settings, "OPENAI_BASE_URL", "https://api.openai.com/v1")
or "https://api.openai.com/v1"
).strip()
if not api_key:
raise HTTPException(status_code=503, detail="ASR 服务未配置 (OPENAI_API_KEY)")
import httpx
async with httpx.AsyncClient(timeout=120) as client:
with open(tmp.name, "rb") as f:
files_payload = {
"file": (Path(tmp.name).name, f, f"audio/{ext}"),
"model": (None, "whisper-1"),
"language": (None, language),
}
resp = await client.post(
f"{base_url.rstrip('/')}/audio/transcriptions",
headers={"Authorization": f"Bearer {api_key}"},
files=files_payload,
)
if resp.status_code != 200:
logger.error("Whisper API 错误 %d: %s", resp.status_code, resp.text[:500])
raise HTTPException(
status_code=502,
detail=f"语音识别服务返回错误 (HTTP {resp.status_code})",
)
data = resp.json()
text = data.get("text", "")
logger.info("ASR 完成: user=%s file=%s len=%d", current_user.id, file.filename, len(text))
return AsrResponse(text=text, language=language)
finally:
Path(tmp.name).unlink(missing_ok=True)
@router.post("/tts", response_model=TtsResponse)
async def text_to_voice(
req: TtsRequest,
current_user: User = Depends(get_current_user),
):
"""
文字转语音 (TTS)。
优先使用 OpenAI TTS需配置有效 OPENAI_API_KEY否则使用免费 Edge TTS。
Android 端使用 ExoPlayer 播放返回的 audio_url。
"""
valid_voices = {"alloy", "echo", "fable", "onyx", "nova", "shimmer"}
voice = req.voice if req.voice in valid_voices else "alloy"
text = req.text.strip()
if len(text) > 4000:
text = text[:4000] + "..."
_TTS_DIR.mkdir(parents=True, exist_ok=True)
filename = f"tts_{current_user.id}_{int(time.time())}.mp3"
filepath = _TTS_DIR / filename
# 尝试 OpenAI TTS
api_key = (getattr(settings, "OPENAI_API_KEY", "") or "").strip()
base_url = (
getattr(settings, "OPENAI_BASE_URL", "https://api.openai.com/v1")
or "https://api.openai.com/v1"
).strip()
use_edge = False
if api_key and api_key not in ("your-openai-api-key", "sk-your-"):
import httpx
try:
async with httpx.AsyncClient(timeout=60) as client:
resp = await client.post(
f"{base_url.rstrip('/')}/audio/speech",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
json={"model": "tts-1", "voice": voice, "input": text},
)
if resp.status_code == 200:
filepath.write_bytes(resp.content)
logger.info("OpenAI TTS 完成: user=%s text_len=%d voice=%s", current_user.id, len(req.text), voice)
return TtsResponse(
audio_url=f"/api/v1/voice/audio/{filename}",
text_length=len(req.text),
voice=voice,
)
else:
logger.warning("OpenAI TTS 失败 (%d), 回退到 Edge TTS", resp.status_code)
use_edge = True
except Exception as exc:
logger.warning("OpenAI TTS 异常: %s, 回退到 Edge TTS", exc)
use_edge = True
else:
use_edge = True
# 回退Edge TTS免费无需 API KEY
if use_edge:
edge_voice = _EDGE_VOICE_MAP.get(voice, "zh-CN-YunxiNeural")
try:
await _edge_tts_synthesize(text, edge_voice, str(filepath))
logger.info("Edge TTS 完成: user=%s text_len=%d edge_voice=%s", current_user.id, len(req.text), edge_voice)
return TtsResponse(
audio_url=f"/api/v1/voice/audio/{filename}",
text_length=len(req.text),
voice=voice,
)
except Exception as exc:
logger.error("Edge TTS 失败: %s", exc)
raise HTTPException(status_code=502, detail=f"TTS 服务不可用: {exc}")
async def _edge_tts_synthesize(text: str, voice: str, output_path: str) -> None:
"""使用 edge-tts 命令行工具合成语音。"""
# 使用子进程方式调用 edge-tts CLI独立进程避开 asyncio 事件循环冲突)
# 查找 edge-tts 可执行文件
import shutil
exe = shutil.which("edge-tts") or shutil.which("edge-tts", path=(
os.environ.get("PATH", "") + os.pathsep +
os.path.join(os.path.dirname(sys.executable), "Scripts") + os.pathsep +
os.path.join(os.path.dirname(sys.executable), "..", "Scripts")
))
if not exe:
exe = "edge-tts" # fallback, let subprocess try PATH
logger.debug("edge-tts exe: %s", exe)
proc = await asyncio.create_subprocess_exec(
exe, "--text", text, "--voice", voice, "--write-media", output_path,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
out_str = stdout.decode(errors="replace") if stdout else ""
err_str = stderr.decode(errors="replace") if stderr else ""
if proc.returncode != 0:
raise RuntimeError(f"edge-tts CLI 失败 (exit={proc.returncode}): {err_str or out_str}")
if not Path(output_path).is_file():
raise RuntimeError(f"edge-tts 未生成输出文件: {out_str} {err_str}")
logger.info("edge-tts CLI 成功: %d bytes", Path(output_path).stat().st_size)
@router.get("/audio/{filename}")
async def get_tts_audio(filename: str):
"""获取 TTS 生成的音频文件无需认证URL 包含随机名)。"""
if ".." in filename or "/" in filename or "\\" in filename:
raise HTTPException(status_code=400, detail="非法文件名")
filepath = _TTS_DIR / filename
if not filepath.is_file():
raise HTTPException(status_code=404, detail="音频文件不存在或已过期")
return FileResponse(filepath, media_type="audio/mpeg")

View File

@@ -1,7 +1,7 @@
"""
工作流API
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Request, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
@@ -142,12 +142,13 @@ async def get_workflows(
limit: int = 100,
search: Optional[str] = None,
status: Optional[str] = None,
workspace_id: Optional[str] = Query(None, description="工作区ID筛选"),
sort_by: Optional[str] = "created_at",
sort_order: Optional[str] = "desc",
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取工作流列表(支持搜索、筛选、排序)"""
"""获取工作流列表(支持搜索、筛选、排序、工作区筛选"""
# 管理员可以看到所有工作流普通用户只能看到自己拥有的或有read权限的
if current_user.role == "admin":
query = db.query(Workflow)
@@ -155,10 +156,10 @@ async def get_workflows(
# 获取用户拥有或有read权限的工作流
from sqlalchemy import or_
from app.models.permission import WorkflowPermission
# 用户拥有的工作流
owned_workflows = db.query(Workflow.id).filter(Workflow.user_id == current_user.id).subquery()
# 用户有read权限的工作流通过用户ID或角色
user_permissions = db.query(WorkflowPermission.workflow_id).filter(
WorkflowPermission.permission_type == "read",
@@ -167,14 +168,18 @@ async def get_workflows(
WorkflowPermission.role_id.in_([r.id for r in current_user.roles])
)
).subquery()
query = db.query(Workflow).filter(
or_(
Workflow.id.in_(db.query(owned_workflows.c.id)),
Workflow.id.in_(db.query(user_permissions.c.workflow_id))
)
)
# 工作区筛选
if workspace_id:
query = query.filter(Workflow.workspace_id == workspace_id)
# 搜索:按名称或描述搜索
if search:
search_pattern = f"%{search}%"
@@ -211,25 +216,36 @@ async def get_workflows(
@router.post("", response_model=WorkflowResponse, status_code=status.HTTP_201_CREATED)
async def create_workflow(
workflow_data: WorkflowCreate,
request: Request,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建工作流"""
# 从 JWT 提取当前工作区 ID
from app.core.security import decode_access_token
ws_id = None
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
payload = decode_access_token(auth_header[7:])
if payload:
ws_id = payload.get("ws") or None
# 验证工作流
validation_result = validate_workflow(workflow_data.nodes, workflow_data.edges)
if not validation_result["valid"]:
raise ValidationError(f"工作流验证失败: {', '.join(validation_result['errors'])}")
# 如果有警告,记录日志
if validation_result["warnings"]:
logger.warning(f"工作流创建警告: {', '.join(validation_result['warnings'])}")
workflow = Workflow(
name=workflow_data.name,
description=workflow_data.description,
nodes=workflow_data.nodes,
edges=workflow_data.edges,
user_id=current_user.id
user_id=current_user.id,
workspace_id=ws_id,
)
db.add(workflow)
db.commit()

View File

@@ -0,0 +1,393 @@
"""
工作区 (Workspace) API — 多租户管理
"""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime
import uuid
import logging
from app.core.database import get_db
from app.api.auth import get_current_user
from app.models.user import User
from app.models.workspace import Workspace, WorkspaceMembership
from app.services.workspace_service import check_workspace_access, get_user_workspaces
from app.core.exceptions import NotFoundError
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/api/v1/workspaces",
tags=["workspaces"],
responses={
401: {"description": "未授权"},
403: {"description": "无权访问"},
404: {"description": "资源不存在"},
}
)
# ── Pydantic Schemas ──
class WorkspaceCreate(BaseModel):
name: str = Field(..., min_length=1, max_length=100, description="工作区名称")
description: Optional[str] = Field(None, max_length=1000)
max_members: int = Field(default=50, ge=1, le=500)
class WorkspaceUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=1000)
max_members: Optional[int] = Field(None, ge=1, le=500)
status: Optional[str] = None # active/disabled
class MemberAddRequest(BaseModel):
user_id: Optional[str] = None
username: Optional[str] = None
role: str = Field(default="member", pattern="^(admin|member)$")
class MemberUpdateRequest(BaseModel):
role: str = Field(..., pattern="^(admin|member)$")
class WorkspaceResponse(BaseModel):
id: str
name: str
description: Optional[str]
is_default: bool
owner_id: str
max_members: int
settings: Optional[Dict[str, Any]]
member_count: int = 0
status: str
created_at: Optional[str]
updated_at: Optional[str]
class MemberResponse(BaseModel):
id: str
user_id: str
username: str
email: str
role: str
joined_at: Optional[str]
# ── Endpoints ──
@router.get("", response_model=List[Dict[str, Any]])
def list_workspaces(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取当前用户的工作区列表。"""
return get_user_workspaces(db, current_user)
@router.post("", status_code=status.HTTP_201_CREATED)
def create_workspace(
data: WorkspaceCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""创建新工作区。"""
ws = Workspace(
id=str(uuid.uuid4()),
name=data.name,
description=data.description,
owner_id=current_user.id,
max_members=data.max_members,
status="active",
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
db.add(ws)
db.flush()
# 创建者自动成为工作区管理员
membership = WorkspaceMembership(
id=str(uuid.uuid4()),
workspace_id=ws.id,
user_id=current_user.id,
role="admin",
joined_at=datetime.utcnow(),
)
db.add(membership)
db.commit()
return {
"id": ws.id,
"name": ws.name,
"description": ws.description,
"is_default": bool(ws.is_default),
"owner_id": ws.owner_id,
"max_members": ws.max_members,
"role": "admin",
"member_count": 1,
"status": ws.status,
}
@router.get("/{workspace_id}")
def get_workspace(
workspace_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作区详情。"""
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if not ws:
raise NotFoundError("工作区", workspace_id)
if not check_workspace_access(db, current_user, workspace_id):
raise HTTPException(status_code=403, detail="无权访问此工作区")
member_count = (
db.query(WorkspaceMembership)
.filter(WorkspaceMembership.workspace_id == workspace_id)
.count()
)
user_role = "admin" if current_user.role == "admin" else None
if not user_role:
membership = (
db.query(WorkspaceMembership)
.filter(
WorkspaceMembership.workspace_id == workspace_id,
WorkspaceMembership.user_id == current_user.id,
)
.first()
)
if membership:
user_role = membership.role
return {
**ws.to_dict(),
"member_count": member_count,
"role": user_role,
}
@router.put("/{workspace_id}")
def update_workspace(
workspace_id: str,
data: WorkspaceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""更新工作区设置(需工作区管理员权限)。"""
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if not ws:
raise NotFoundError("工作区", workspace_id)
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
if data.name is not None:
ws.name = data.name
if data.description is not None:
ws.description = data.description
if data.max_members is not None:
ws.max_members = data.max_members
if data.status is not None:
ws.status = data.status
ws.updated_at = datetime.utcnow()
db.commit()
return {**ws.to_dict()}
@router.delete("/{workspace_id}")
def delete_workspace(
workspace_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""删除工作区(软删除,仅平台管理员或工作区管理员可操作)。"""
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if not ws:
raise NotFoundError("工作区", workspace_id)
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
if ws.is_default and current_user.role != "admin":
raise HTTPException(status_code=403, detail="默认工作区不可删除")
ws.status = "deleted"
ws.updated_at = datetime.utcnow()
db.commit()
return {"message": "工作区已删除"}
# ── 成员管理 ──
@router.get("/{workspace_id}/members")
def list_members(
workspace_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取工作区成员列表。"""
if not check_workspace_access(db, current_user, workspace_id):
raise HTTPException(status_code=403, detail="无权访问此工作区")
memberships = (
db.query(WorkspaceMembership)
.filter(WorkspaceMembership.workspace_id == workspace_id)
.all()
)
result = []
for m in memberships:
user = m.user
result.append({
"id": m.id,
"user_id": user.id,
"username": user.username,
"email": user.email,
"role": m.role,
"joined_at": m.joined_at.isoformat() if m.joined_at else None,
})
return result
@router.post("/{workspace_id}/members", status_code=status.HTTP_201_CREATED)
def add_member(
workspace_id: str,
data: MemberAddRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""添加工作区成员(需工作区管理员权限)。"""
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
# 查找目标用户
target_user = None
if data.user_id:
target_user = db.query(User).filter(User.id == data.user_id).first()
elif data.username:
target_user = db.query(User).filter(User.username == data.username).first()
if not target_user:
raise NotFoundError("用户")
# 检查是否已存在
existing = (
db.query(WorkspaceMembership)
.filter(
WorkspaceMembership.workspace_id == workspace_id,
WorkspaceMembership.user_id == target_user.id,
)
.first()
)
if existing:
raise HTTPException(status_code=400, detail="该用户已是工作区成员")
# 检查成员数上限
member_count = (
db.query(WorkspaceMembership)
.filter(WorkspaceMembership.workspace_id == workspace_id)
.count()
)
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if member_count >= ws.max_members:
raise HTTPException(status_code=400, detail="工作区成员数已达上限")
membership = WorkspaceMembership(
id=str(uuid.uuid4()),
workspace_id=workspace_id,
user_id=target_user.id,
role=data.role,
joined_at=datetime.utcnow(),
)
db.add(membership)
db.commit()
return {
"id": membership.id,
"user_id": target_user.id,
"username": target_user.username,
"email": target_user.email,
"role": membership.role,
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None,
}
@router.put("/{workspace_id}/members/{user_id}")
def update_member_role(
workspace_id: str,
user_id: str,
data: MemberUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""修改成员角色(需工作区管理员权限)。"""
if not check_workspace_access(db, current_user, workspace_id, required_role="admin"):
raise HTTPException(status_code=403, detail="需要工作区管理员权限")
membership = (
db.query(WorkspaceMembership)
.filter(
WorkspaceMembership.workspace_id == workspace_id,
WorkspaceMembership.user_id == user_id,
)
.first()
)
if not membership:
raise NotFoundError("成员", user_id)
# 不能修改自己的工作区所有者
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if ws.owner_id == user_id and data.role != "admin":
raise HTTPException(status_code=400, detail="工作区所有者必须保持admin角色")
membership.role = data.role
db.commit()
return {"message": "角色已更新"}
@router.delete("/{workspace_id}/members/{user_id}")
def remove_member(
workspace_id: str,
user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""移除工作区成员(需工作区管理员权限,或自行退出)。"""
is_self = user_id == current_user.id
is_admin = check_workspace_access(db, current_user, workspace_id, required_role="admin")
if not is_self and not is_admin:
raise HTTPException(status_code=403, detail="无权移除此成员")
# 不能移除工作区所有者
ws = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if not ws:
raise NotFoundError("工作区", workspace_id)
if ws.owner_id == user_id and not is_self:
raise HTTPException(status_code=400, detail="不能移除工作区所有者")
membership = (
db.query(WorkspaceMembership)
.filter(
WorkspaceMembership.workspace_id == workspace_id,
WorkspaceMembership.user_id == user_id,
)
.first()
)
if not membership:
raise NotFoundError("成员", user_id)
db.delete(membership)
db.commit()
return {"message": "成员已移除"}