""" 对话分支 API — 支持从任意历史点分叉对话 参考 Claude Code src/commands/branch/branch.ts """ from __future__ import annotations import logging from typing import Any, Dict, List, Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field from sqlalchemy.orm import Session from app.core.database import get_db from app.api.auth import get_current_user from app.models.user import User from app.models.agent import Agent from app.agent_runtime import ( AgentRuntime, AgentConfig, AgentLLMConfig, AgentToolConfig, AgentMemoryConfig, AgentStep, ) from app.agent_runtime.context import AgentContext logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1/agent-chat", tags=["agent-branches"]) def _find_agent_node_config(nodes: list) -> Dict[str, Any]: if not nodes: return {} for node in nodes: typ = node.get("type", "") if typ in ("agent", "llm", "template"): return node.get("data") or {} return {} def _build_memory_config_from_node(agent_node_cfg: dict) -> AgentMemoryConfig: from app.core.compaction_config import CompactionConfig compaction_raw = agent_node_cfg.get("compaction") if isinstance(compaction_raw, dict): compaction = CompactionConfig(**compaction_raw) else: compaction = CompactionConfig() return AgentMemoryConfig( max_history_messages=int(agent_node_cfg.get("memory_max_history", 20)), vector_memory_top_k=int(agent_node_cfg.get("memory_vector_top_k", 5)), persist_to_db=bool(agent_node_cfg.get("memory_persist", True)), vector_memory_enabled=bool(agent_node_cfg.get("memory_vector_enabled", True)), learning_enabled=bool(agent_node_cfg.get("memory_learning", True)), memory_dir_enabled=bool(agent_node_cfg.get("memory_dir_enabled", False)), memory_dir_path=str(agent_node_cfg.get("memory_dir_path", "")), compaction=compaction, ) # ──────────────────────────── 请求/响应模型 ──────────────────────────── class BranchCreateRequest(BaseModel): session_id: str = Field(..., description="要分叉的原会话 ID") title: Optional[str] = Field(default=None, description="自定义分支标题") agent_id: Optional[str] = Field(default=None, description="关联 Agent ID") class BranchItem(BaseModel): id: str title: str agent_name: Optional[str] = None parent_session_id: str message_count: int first_user_message: Optional[str] = None created_at: str class Config: from_attributes = True class BranchListResponse(BaseModel): branches: List[BranchItem] total: int class BranchDetailResponse(BaseModel): id: str title: str agent_name: Optional[str] = None parent_session_id: str branch_session_id: str message_count: int first_user_message: Optional[str] = None messages: Optional[List[Dict[str, Any]]] = None created_at: str class Config: from_attributes = True class BranchResumeRequest(BaseModel): message: str = Field(..., description="在分支基础上继续的新消息") class BranchResumeResponse(BaseModel): content: str iterations_used: int tool_calls_made: int truncated: bool session_id: str agent_id: Optional[str] = None steps: List[AgentStep] = Field(default_factory=list) # ──────────────────────────── API 端点 ──────────────────────────── @router.post("/branches", response_model=BranchDetailResponse) async def create_conversation_branch( req: BranchCreateRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """从指定会话创建对话分支(快照完整消息列表)。""" from app.services.conversation_branch_service import ( create_branch, get_messages_from_session, ) messages = get_messages_from_session(db, req.session_id) if not messages: raise HTTPException(status_code=404, detail="找不到该会话的消息记录(请确保会话已产生对话)") agent_name = None if req.agent_id: agent = db.query(Agent).filter(Agent.id == req.agent_id).first() if agent: agent_name = agent.name branch = create_branch( db=db, user_id=current_user.id, parent_session_id=req.session_id, messages=messages, agent_id=req.agent_id, agent_name=agent_name, custom_title=req.title, ) return BranchDetailResponse( id=branch.id, title=branch.title, agent_name=branch.agent_name, parent_session_id=branch.parent_session_id, branch_session_id=branch.branch_session_id, message_count=branch.message_count, first_user_message=branch.first_user_message, messages=branch.messages, created_at=branch.created_at.isoformat() if branch.created_at else "", ) @router.get("/branches", response_model=BranchListResponse) async def list_conversation_branches( agent_id: Optional[str] = None, limit: int = 50, offset: int = 0, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """列出当前用户的所有对话分支。""" from app.services.conversation_branch_service import list_branches branches = list_branches( db=db, user_id=current_user.id, agent_id=agent_id, limit=limit, offset=offset, ) return BranchListResponse( branches=[ BranchItem( id=b.id, title=b.title, agent_name=b.agent_name, parent_session_id=b.parent_session_id, message_count=b.message_count, first_user_message=b.first_user_message, created_at=b.created_at.isoformat() if b.created_at else "", ) for b in branches ], total=len(branches), ) @router.get("/branches/{branch_id}", response_model=BranchDetailResponse) async def get_conversation_branch( branch_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """获取分支详情(含完整消息列表)。""" from app.services.conversation_branch_service import get_branch branch = get_branch(db, branch_id, current_user.id) if not branch: raise HTTPException(status_code=404, detail="分支不存在") return BranchDetailResponse( id=branch.id, title=branch.title, agent_name=branch.agent_name, parent_session_id=branch.parent_session_id, branch_session_id=branch.branch_session_id, message_count=branch.message_count, first_user_message=branch.first_user_message, messages=branch.messages, created_at=branch.created_at.isoformat() if branch.created_at else "", ) @router.delete("/branches/{branch_id}") async def delete_conversation_branch( branch_id: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """删除(软删除)对话分支。""" from app.services.conversation_branch_service import delete_branch success = delete_branch(db, branch_id, current_user.id) if not success: raise HTTPException(status_code=404, detail="分支不存在或无权操作") return {"message": "分支已删除", "branch_id": branch_id} @router.post("/branches/{branch_id}/resume", response_model=BranchResumeResponse) async def resume_from_branch( branch_id: str, req: BranchResumeRequest, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """从分支恢复对话 — 使用分支保存的消息作为上下文继续对话。""" from app.services.conversation_branch_service import get_branch branch = get_branch(db, branch_id, current_user.id) if not branch: raise HTTPException(status_code=404, detail="分支不存在") if not branch.messages: raise HTTPException(status_code=400, detail="分支没有保存的消息") messages = branch.messages system_prompt = "你是一个有用的AI助手。" history_messages = messages if messages and messages[0].get("role") == "system": system_prompt = messages[0].get("content", system_prompt) history_messages = messages[1:] # 构建 Agent 配置 agent = None if branch.agent_id: agent = db.query(Agent).filter(Agent.id == branch.agent_id).first() if agent: wc = agent.workflow_config or {} nodes = wc.get("nodes", []) agent_node_cfg = _find_agent_node_config(nodes) llm_config = AgentLLMConfig( provider=agent_node_cfg.get("provider", "deepseek"), model=agent_node_cfg.get("model", "deepseek-v4-flash"), temperature=float(agent_node_cfg.get("temperature", 0.7)), max_iterations=int(agent_node_cfg.get("max_iterations", 10)), ) config = AgentConfig( name=agent.name, system_prompt=system_prompt, llm=llm_config, tools=AgentToolConfig(include_tools=agent_node_cfg.get("tools", [])), memory=_build_memory_config_from_node(agent_node_cfg), user_id=current_user.id, memory_scope_id=f"{current_user.id}:{branch.agent_id}" if current_user.id else str(branch.agent_id), ) else: config = AgentConfig( name="branch_resume", system_prompt=system_prompt, llm=AgentLLMConfig(model="deepseek-v4-flash", temperature=0.7, max_iterations=10), user_id=current_user.id, memory_scope_id=f"{current_user.id}:__bare__" if current_user.id else "__bare__", ) # 创建 Runtime 并预加载分支消息 context = AgentContext( system_prompt=system_prompt, user_id=current_user.id, session_id=branch.branch_session_id, ) for msg in history_messages: role = msg.get("role", "") content = msg.get("content", "") if role == "user": context.add_user_message(content) elif role == "assistant": tool_calls = msg.get("tool_calls") context.add_assistant_message(content, tool_calls) elif role == "tool": context.add_tool_result( msg.get("tool_call_id", ""), msg.get("name", "unknown"), content, ) # 不需要 on_llm_logger 回调(避免与主 logger 重复),仅做执行 runtime = AgentRuntime(config=config, context=context) result = await runtime.run(req.message) return BranchResumeResponse( content=result.content, iterations_used=result.iterations_used, tool_calls_made=result.tool_calls_made, truncated=result.truncated, session_id=branch.branch_session_id, agent_id=branch.agent_id, steps=result.steps, )