- Fix delete agent 500: clean up FK records (agent_llm_logs, permissions, schedules, executions, team_members) and unbind goals/tasks before delete - Remove hardcoded personality templates in Android, replace with dynamic system prompt generation from name + description - Set promptSectionsEnabled=false to bypass PromptComposer for personality - Add Tencent Cloud Linux deployment guide (Docker Compose) - Accumulated backend service updates, frontend UI fixes, Android app changes Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
320 lines
11 KiB
Python
320 lines
11 KiB
Python
"""
|
|
对话分支 API — 支持从任意历史点分叉对话
|
|
|
|
参考 Claude Code src/commands/branch/branch.ts
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.database import get_db
|
|
from app.api.auth import get_current_user
|
|
from app.models.user import User
|
|
from app.models.agent import Agent
|
|
from app.agent_runtime import (
|
|
AgentRuntime,
|
|
AgentConfig,
|
|
AgentLLMConfig,
|
|
AgentToolConfig,
|
|
AgentMemoryConfig,
|
|
AgentStep,
|
|
)
|
|
from app.agent_runtime.context import AgentContext
|
|
|
|
# Module logger, named after this module so its records can be filtered per file.
logger = logging.getLogger(__name__)
# Every branch endpoint below is mounted under /api/v1/agent-chat and tagged
# "agent-branches" in the generated OpenAPI schema.
router = APIRouter(tags=["agent-branches"], prefix="/api/v1/agent-chat")
|
|
|
|
|
|
def _find_agent_node_config(nodes: list) -> Dict[str, Any]:
|
|
if not nodes:
|
|
return {}
|
|
for node in nodes:
|
|
typ = node.get("type", "")
|
|
if typ in ("agent", "llm", "template"):
|
|
return node.get("data") or {}
|
|
return {}
|
|
|
|
|
|
def _build_memory_config_from_node(agent_node_cfg: Dict[str, Any]) -> AgentMemoryConfig:
    """Build an AgentMemoryConfig from an agent node's ``data`` dict.

    Keys read, with the default used when a key is absent or ``None``:

    - ``memory_max_history`` (20)
    - ``memory_vector_top_k`` (5)
    - ``memory_persist`` (True)
    - ``memory_vector_enabled`` (True)
    - ``memory_learning`` (True)
    - ``memory_dir_enabled`` (False)
    - ``memory_dir_path`` ("")

    ``compaction`` is read too: when it is a dict it is passed as keyword
    arguments to ``CompactionConfig``; otherwise ``CompactionConfig()`` is used.

    Args:
        agent_node_cfg: The ``data`` dict of the workflow's agent node.

    Returns:
        The assembled AgentMemoryConfig.

    Raises:
        ValueError / TypeError: if a numeric key holds a value ``int()`` cannot
            convert. Errors raised by ``CompactionConfig`` for bad
            ``compaction`` keys propagate unchanged.
    """
    from app.core.compaction_config import CompactionConfig

    def _get(key: str, default: Any) -> Any:
        # A JSON ``null`` in a saved workflow means "not set". Without this,
        # int(None) raises TypeError and str(None) yields the path "None".
        value = agent_node_cfg.get(key)
        return default if value is None else value

    compaction_raw = agent_node_cfg.get("compaction")
    compaction = (
        CompactionConfig(**compaction_raw)
        if isinstance(compaction_raw, dict)
        else CompactionConfig()
    )

    return AgentMemoryConfig(
        max_history_messages=int(_get("memory_max_history", 20)),
        vector_memory_top_k=int(_get("memory_vector_top_k", 5)),
        persist_to_db=bool(_get("memory_persist", True)),
        vector_memory_enabled=bool(_get("memory_vector_enabled", True)),
        learning_enabled=bool(_get("memory_learning", True)),
        memory_dir_enabled=bool(_get("memory_dir_enabled", False)),
        memory_dir_path=str(_get("memory_dir_path", "")),
        compaction=compaction,
    )
|
|
|
|
|
|
# ──────────────────────────── 请求/响应模型 ────────────────────────────
|
|
|
|
class BranchCreateRequest(BaseModel):
    """Request body for forking an existing conversation into a new branch."""

    # Required: the session whose messages are snapshotted into the branch.
    session_id: str = Field(default=..., description="要分叉的原会话 ID")
    # Optional custom title for the new branch.
    title: Optional[str] = Field(None, description="自定义分支标题")
    # Optional agent to associate with the branch.
    agent_id: Optional[str] = Field(None, description="关联 Agent ID")
|
|
|
|
|
|
class BranchItem(BaseModel):
    """Summary of one conversation branch, as returned by the list endpoint.

    Unlike the detail response, this model carries no message snapshot.
    """

    # Pydantic v2 configuration. ``from_attributes`` allows validating from
    # ORM objects; the class-based ``Config`` form is deprecated in v2.
    model_config = {"from_attributes": True}

    id: str
    title: str
    # Name of the linked agent, if any.
    agent_name: Optional[str] = None
    # Session the branch was forked from.
    parent_session_id: str
    message_count: int
    first_user_message: Optional[str] = None
    # ISO-8601 string; the endpoints fill in "" when no creation time is set.
    created_at: str
|
|
|
|
|
|
class BranchListResponse(BaseModel):
    """Response body of the branch list endpoint."""

    # Branches on the requested page.
    branches: List[BranchItem] = Field(...)
    # Count reported by the list endpoint alongside ``branches``.
    total: int = Field(...)
|
|
|
|
|
|
class BranchDetailResponse(BaseModel):
    """Full description of one conversation branch, including its messages.

    Returned when a branch is created and when a single branch is fetched.
    """

    # Pydantic v2 configuration. ``from_attributes`` allows validating from
    # ORM objects; the class-based ``Config`` form is deprecated in v2.
    model_config = {"from_attributes": True}

    id: str
    title: str
    # Name of the linked agent, if any.
    agent_name: Optional[str] = None
    # Session the branch was forked from.
    parent_session_id: str
    # Session ID under which the branch's own conversation continues.
    branch_session_id: str
    message_count: int
    first_user_message: Optional[str] = None
    # Snapshotted message dicts (role/content/...); optional in the schema.
    messages: Optional[List[Dict[str, Any]]] = None
    # ISO-8601 string; the endpoints fill in "" when no creation time is set.
    created_at: str
|
|
|
|
|
|
class BranchResumeRequest(BaseModel):
    """Request body for continuing a conversation from a saved branch."""

    # Required: the new user message appended after the branch's history.
    message: str = Field(default=..., description="在分支基础上继续的新消息")
|
|
|
|
|
|
class BranchResumeResponse(BaseModel):
    """Result of running the agent once on top of a branch's history."""

    # Final text produced by the agent runtime.
    content: str
    # Counters reported by the agent runtime for this run.
    iterations_used: int
    tool_calls_made: int
    # Truncation flag as reported by the agent runtime.
    truncated: bool
    # The branch's own session ID.
    session_id: str
    # Agent linked to the branch, if any.
    agent_id: Optional[str] = Field(default=None)
    # Per-step trace of the run; a fresh empty list when omitted.
    steps: List[AgentStep] = Field(default_factory=list)
|
|
|
|
|
|
# ──────────────────────────── API 端点 ────────────────────────────
|
|
|
|
@router.post("/branches", response_model=BranchDetailResponse)
async def create_conversation_branch(
    req: BranchCreateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Fork a conversation by snapshotting every message of an existing session.

    Args:
        req: Source ``session_id`` plus optional ``title`` and ``agent_id``.
        current_user: Authenticated user who will own the new branch.
        db: Database session.

    Returns:
        BranchDetailResponse for the new branch, including the snapshot.

    Raises:
        HTTPException: 404 when the source session has no stored messages.
    """
    from app.services.conversation_branch_service import (
        create_branch,
        get_messages_from_session,
    )

    # NOTE(review): the source session is looked up by ID only; confirm that
    # get_messages_from_session prevents forking another user's session.
    snapshot = get_messages_from_session(db, req.session_id)
    if not snapshot:
        raise HTTPException(status_code=404, detail="找不到该会话的消息记录(请确保会话已产生对话)")

    # The agent's name is copied onto the branch record. An unknown agent_id
    # is tolerated and simply leaves the name empty.
    linked_agent = (
        db.query(Agent).filter(Agent.id == req.agent_id).first()
        if req.agent_id
        else None
    )
    display_name = linked_agent.name if linked_agent else None

    new_branch = create_branch(
        db=db,
        user_id=current_user.id,
        parent_session_id=req.session_id,
        messages=snapshot,
        agent_id=req.agent_id,
        agent_name=display_name,
        custom_title=req.title,
    )

    created = new_branch.created_at
    return BranchDetailResponse(
        id=new_branch.id,
        title=new_branch.title,
        agent_name=new_branch.agent_name,
        parent_session_id=new_branch.parent_session_id,
        branch_session_id=new_branch.branch_session_id,
        message_count=new_branch.message_count,
        first_user_message=new_branch.first_user_message,
        messages=new_branch.messages,
        created_at=created.isoformat() if created else "",
    )
|
|
|
|
|
|
@router.get("/branches", response_model=BranchListResponse)
async def list_conversation_branches(
    agent_id: Optional[str] = None,
    limit: int = 50,
    offset: int = 0,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """List the current user's conversation branches, one page at a time.

    Args:
        agent_id: Optional agent filter, passed through to ``list_branches``.
        limit: Maximum number of branches to return; must be >= 0.
        offset: Number of branches to skip; must be >= 0.
        current_user: Authenticated user whose branches are listed.
        db: Database session.

    Returns:
        BranchListResponse with this page of branches. ``total`` is the number
        of items on this page.

    Raises:
        HTTPException: 400 when ``limit`` or ``offset`` is negative.
    """
    from app.services.conversation_branch_service import list_branches

    # Reject negative paging values up front. They would presumably reach a
    # SQL LIMIT/OFFSET inside list_branches, which databases such as
    # PostgreSQL reject, turning bad input into a 500.
    if limit < 0 or offset < 0:
        raise HTTPException(status_code=400, detail="limit 和 offset 不能为负数")

    branches = list_branches(
        db=db, user_id=current_user.id, agent_id=agent_id, limit=limit, offset=offset,
    )

    items = [
        BranchItem(
            id=b.id,
            title=b.title,
            agent_name=b.agent_name,
            parent_session_id=b.parent_session_id,
            message_count=b.message_count,
            first_user_message=b.first_user_message,
            created_at=b.created_at.isoformat() if b.created_at else "",
        )
        for b in branches
    ]

    # Count the built list rather than ``branches``, so a non-sized iterable
    # from the service does not break ``len``.
    # NOTE(review): this is the page size, not the number of branches matching
    # the filter. A count query in the service is needed for a true total.
    return BranchListResponse(branches=items, total=len(items))
|
|
|
|
|
|
@router.get("/branches/{branch_id}", response_model=BranchDetailResponse)
async def get_conversation_branch(
    branch_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return one branch of the current user, including its message snapshot.

    Raises:
        HTTPException: 404 when ``get_branch`` finds nothing for this user.
    """
    from app.services.conversation_branch_service import get_branch

    found = get_branch(db, branch_id, current_user.id)
    if not found:
        raise HTTPException(status_code=404, detail="分支不存在")

    stamp = found.created_at
    return BranchDetailResponse(
        id=found.id,
        title=found.title,
        agent_name=found.agent_name,
        parent_session_id=found.parent_session_id,
        branch_session_id=found.branch_session_id,
        message_count=found.message_count,
        first_user_message=found.first_user_message,
        messages=found.messages,
        created_at="" if not stamp else stamp.isoformat(),
    )
|
|
|
|
|
|
@router.delete("/branches/{branch_id}")
async def delete_conversation_branch(
    branch_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Delete one of the current user's branches via ``delete_branch``.

    The original docstring describes this as a soft delete; the deletion
    itself is performed by the service.

    Returns:
        A confirmation message together with the deleted ``branch_id``.

    Raises:
        HTTPException: 404 when ``delete_branch`` reports failure (branch
            missing or not owned by this user).
    """
    from app.services.conversation_branch_service import delete_branch

    if not delete_branch(db, branch_id, current_user.id):
        raise HTTPException(status_code=404, detail="分支不存在或无权操作")
    return {"message": "分支已删除", "branch_id": branch_id}
|
|
|
|
|
|
def _build_branch_resume_config(
    agent: Optional[Agent],
    system_prompt: str,
    user_id: Any,
    agent_id: Optional[str],
) -> AgentConfig:
    """Build the AgentConfig used to resume a branch.

    With an agent, LLM, tool and memory settings are read from the first
    agent-like node of its ``workflow_config``. Missing or ``null`` LLM
    settings fall back to the DeepSeek defaults. Without an agent, a bare
    config named ``branch_resume`` is used. Memory is scoped to
    ``"<user_id>:<agent_id>"`` or ``"<user_id>:__bare__"``.
    """
    if not agent:
        return AgentConfig(
            name="branch_resume",
            system_prompt=system_prompt,
            llm=AgentLLMConfig(model="deepseek-v4-flash", temperature=0.7, max_iterations=10),
            user_id=user_id,
            memory_scope_id=f"{user_id}:__bare__" if user_id else "__bare__",
        )

    workflow = agent.workflow_config or {}
    node_cfg = _find_agent_node_config(workflow.get("nodes", []))

    def _get(key: str, default: Any) -> Any:
        # A JSON ``null`` in the saved workflow means "not set". Without this,
        # float(None) / int(None) raise TypeError and the request fails with a 500.
        value = node_cfg.get(key)
        return default if value is None else value

    llm_config = AgentLLMConfig(
        provider=_get("provider", "deepseek"),
        model=_get("model", "deepseek-v4-flash"),
        temperature=float(_get("temperature", 0.7)),
        max_iterations=int(_get("max_iterations", 10)),
    )
    return AgentConfig(
        name=agent.name,
        system_prompt=system_prompt,
        llm=llm_config,
        tools=AgentToolConfig(include_tools=node_cfg.get("tools", [])),
        memory=_build_memory_config_from_node(node_cfg),
        user_id=user_id,
        memory_scope_id=f"{user_id}:{agent_id}" if user_id else str(agent_id),
    )


def _restore_branch_history(context: AgentContext, history: List[Dict[str, Any]]) -> None:
    """Replay saved branch messages into ``context`` in their stored order.

    ``user``, ``assistant`` (with its ``tool_calls``) and ``tool`` (with
    ``tool_call_id`` and ``name``) messages are re-added. Entries with any
    other role, and entries that are not dicts, are skipped.
    """
    for msg in history:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role == "user":
            context.add_user_message(content)
        elif role == "assistant":
            context.add_assistant_message(content, msg.get("tool_calls"))
        elif role == "tool":
            context.add_tool_result(
                msg.get("tool_call_id", ""),
                msg.get("name", "unknown"),
                content,
            )


@router.post("/branches/{branch_id}/resume", response_model=BranchResumeResponse)
async def resume_from_branch(
    branch_id: str,
    req: BranchResumeRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Continue a conversation from a saved branch with one new user message.

    The branch's snapshotted messages are replayed into a fresh AgentContext;
    a leading ``system`` message becomes the system prompt. An AgentRuntime
    then runs ``req.message`` on top of that history.

    Args:
        branch_id: Branch to resume, looked up for the current user.
        req: Body carrying the new message.
        current_user: Authenticated caller.
        db: Database session.

    Returns:
        BranchResumeResponse with the runtime's reply, counters and steps.

    Raises:
        HTTPException: 404 if the branch is not found for this user; 400 if
            it has no saved messages.
    """
    from app.services.conversation_branch_service import get_branch

    branch = get_branch(db, branch_id, current_user.id)
    if not branch:
        raise HTTPException(status_code=404, detail="分支不存在")
    if not branch.messages:
        raise HTTPException(status_code=400, detail="分支没有保存的消息")

    messages = branch.messages
    system_prompt = "你是一个有用的AI助手。"
    history_messages = messages
    first = messages[0]
    if isinstance(first, dict) and first.get("role") == "system":
        # A null or empty stored prompt falls back to the default instead of
        # passing ``None`` into AgentConfig/AgentContext.
        system_prompt = first.get("content") or system_prompt
        history_messages = messages[1:]

    agent = None
    if branch.agent_id:
        agent = db.query(Agent).filter(Agent.id == branch.agent_id).first()

    config = _build_branch_resume_config(
        agent=agent,
        system_prompt=system_prompt,
        user_id=current_user.id,
        agent_id=branch.agent_id,
    )

    context = AgentContext(
        system_prompt=system_prompt,
        user_id=current_user.id,
        session_id=branch.branch_session_id,
    )
    _restore_branch_history(context, history_messages)

    # No on_llm_logger callback is passed, to avoid duplicating the main
    # logger's records; this only executes the run.
    runtime = AgentRuntime(config=config, context=context)
    result = await runtime.run(req.message)

    return BranchResumeResponse(
        content=result.content,
        iterations_used=result.iterations_used,
        tool_calls_made=result.tool_calls_made,
        truncated=result.truncated,
        session_id=branch.branch_session_id,
        agent_id=branch.agent_id,
        steps=result.steps,
    )
|