diff --git a/backend/app/agent_runtime/core.py b/backend/app/agent_runtime/core.py index e8f3b0e..dbbaeee 100644 --- a/backend/app/agent_runtime/core.py +++ b/backend/app/agent_runtime/core.py @@ -10,6 +10,7 @@ Agent Runtime 核心 —— 自主 ReAct 循环。 """ from __future__ import annotations +import hashlib import json import logging import time @@ -98,6 +99,9 @@ class AgentRuntime: self.tool_manager = tool_manager or AgentToolManager( include_tools=self.config.tools.include_tools, exclude_tools=self.config.tools.exclude_tools, + cache_enabled=self.config.tools.cache_enabled, + cache_tool_whitelist=self.config.tools.cache_tool_whitelist, + cache_ttl_ms=self.config.tools.cache_ttl_ms, ) self.execution_logger = execution_logger self.on_tool_executed = on_tool_executed @@ -912,6 +916,32 @@ class AgentRuntime: return any(kw in err_lower for kw in _RETRYABLE_ERRORS) +# LLM 缓存辅助 +def _llm_cache_key(messages: list, model: str) -> str: + import hashlib + raw = json.dumps({"msgs": messages[-4:], "model": model}, sort_keys=True, ensure_ascii=False) + return f"llm:{model}:{hashlib.sha256(raw.encode()).hexdigest()[:16]}" + +async def _llm_cache_get(key: str) -> Optional[str]: + try: + from app.core.redis_client import get_redis_client + redis = get_redis_client() + if redis: + return await redis.get(key) + except Exception: + pass + return None + +async def _llm_cache_set(key: str, value: str, ttl_ms: int): + try: + from app.core.redis_client import get_redis_client + redis = get_redis_client() + if redis: + await redis.setex(key, max(1, int(ttl_ms / 1000)), value) + except Exception: + pass + + class _LLMClient: """轻量 LLM 客户端包装,复用已有 LLMService 能力。""" @@ -984,12 +1014,29 @@ class _LLMClient: kwargs["tools"] = normalized kwargs["tool_choice"] = "auto" + # LLM 响应缓存(仅不用工具时缓存,避免复杂序列化) + if self._config.cache_enabled and not tools: + cache_key = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", "")) + cached = await _llm_cache_get(cache_key) + if cached is not None: + 
logger.info("LLM 响应命中缓存: model=%s", kwargs.get("model")) + # 构造简易 message 对象(含 content 字段即可) + class _CachedMsg: + content = cached + tool_calls = None + return _CachedMsg() + start_time = time.perf_counter() try: response = await client.chat.completions.create(**kwargs) latency_ms = int((time.perf_counter() - start_time) * 1000) message = response.choices[0].message + # 缓存写入(仅不用工具时) + if self._config.cache_enabled and not tools and message.content: + ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", "")) + await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms) + # 提取 token 用量 usage = getattr(response, "usage", None) prompt_tokens = usage.prompt_tokens if usage else 0 diff --git a/backend/app/agent_runtime/orchestrator.py b/backend/app/agent_runtime/orchestrator.py index 6f73c1d..28ef8d7 100644 --- a/backend/app/agent_runtime/orchestrator.py +++ b/backend/app/agent_runtime/orchestrator.py @@ -9,6 +9,7 @@ Agent Orchestrator — 多 Agent 编排引擎。 """ from __future__ import annotations +import asyncio import json import logging import uuid @@ -347,9 +348,10 @@ class AgentOrchestrator: steps: List[OrchestratorStep] = [] agent_outputs: List[Dict[str, Any]] = [] - # 第一阶段:所有 Agent 独立回答 + # 第一阶段:所有 Agent 并行独立回答 + runtimes = [] for agent_cfg in agents: - runtime = AgentRuntime( + runtimes.append(AgentRuntime( AgentConfig( name=agent_cfg.name, system_prompt=agent_cfg.system_prompt, @@ -364,8 +366,30 @@ class AgentOrchestrator: ), ), on_llm_call=on_llm_call, - ) - result = await runtime.run(question) + )) + + results = await asyncio.gather( + *[rt.run(question) for rt in runtimes], + return_exceptions=True, + ) + + for i, agent_cfg in enumerate(agents): + result = results[i] + if isinstance(result, BaseException): + step = OrchestratorStep( + agent_id=agent_cfg.id, + agent_name=agent_cfg.name, + input=question, + output="", + error=str(result), + ) + steps.append(step) + agent_outputs.append({ + "agent_id": agent_cfg.id, + "agent_name": 
agent_cfg.name, + "output": f"[错误] {result}", + }) + continue step = OrchestratorStep( agent_id=agent_cfg.id, diff --git a/backend/app/agent_runtime/schemas.py b/backend/app/agent_runtime/schemas.py index bca2bc7..85e51a7 100644 --- a/backend/app/agent_runtime/schemas.py +++ b/backend/app/agent_runtime/schemas.py @@ -15,6 +15,10 @@ class AgentToolConfig(BaseModel): require_approval: List[str] = Field(default_factory=list, description="需要人工审批的工具名列表") approval_timeout_ms: int = Field(default=60000, description="审批超时(毫秒),超时使用默认策略") approval_default: str = Field(default="deny", description="超时默认策略: approve | deny | skip") + # 结果缓存 + cache_enabled: bool = Field(default=True, description="是否启用工具结果缓存(确定性工具默认开启)") + cache_tool_whitelist: List[str] = Field(default_factory=list, description="启用缓存的工具名(空=确定性工具默认)") + cache_ttl_ms: int = Field(default=3600000, description="缓存 TTL(毫秒),默认 1 小时") class AgentMemoryConfig(BaseModel): @@ -40,6 +44,8 @@ class AgentLLMConfig(BaseModel): request_timeout: float = 120.0 extra_body: Optional[Dict[str, Any]] = None self_review_threshold: float = 0.6 # self-review 通过阈值(0-1) + cache_enabled: bool = False # LLM 响应缓存(默认关闭,语义缓存有风险) + cache_ttl_ms: int = 300000 # LLM 缓存 TTL,默认 5 分钟 class AgentBudgetConfig(BaseModel): diff --git a/backend/app/agent_runtime/tool_manager.py b/backend/app/agent_runtime/tool_manager.py index 80dd682..ada6a7d 100644 --- a/backend/app/agent_runtime/tool_manager.py +++ b/backend/app/agent_runtime/tool_manager.py @@ -3,6 +3,8 @@ Agent 工具管理器:包装已有 ToolRegistry,提供 Agent 需要的工具 """ from __future__ import annotations +import hashlib +import json import logging from typing import Any, Dict, List, Optional @@ -10,6 +12,12 @@ from app.services.tool_registry import tool_registry logger = logging.getLogger(__name__) +# 默认确定性工具(结果可缓存) +_DETERMINISTIC_TOOLS = { + "file_read", "math_calculate", "database_query", + "json", "text", "csv", "excel", "pdf", "image", +} + class AgentToolManager: """ @@ -17,12 +25,54 @@ class 
AgentToolManager: - 将 ToolRegistry 的工具 schema 转为 OpenAI Function Calling 格式 - 按 Agent 配置过滤(白名单/黑名单) - 执行工具调用并返回结果字符串 + - 工具结果缓存(Redis / 内存 fallback) """ def __init__(self, include_tools: Optional[List[str]] = None, - exclude_tools: Optional[List[str]] = None): + exclude_tools: Optional[List[str]] = None, + cache_enabled: bool = True, + cache_tool_whitelist: Optional[List[str]] = None, + cache_ttl_ms: int = 3600000): self._include_tools: set = set(include_tools or []) self._exclude_tools: set = set(exclude_tools or []) + self._cache_enabled = cache_enabled + self._cache_whitelist: set = set(cache_tool_whitelist or []) + self._cache_ttl_s = max(1, int(cache_ttl_ms / 1000)) + self._cache_store: Dict[str, str] = {} # 内存 fallback + + def _is_cacheable(self, tool_name: str) -> bool: + """判断工具结果是否可缓存。""" + if not self._cache_enabled: + return False + if self._cache_whitelist: + return tool_name in self._cache_whitelist + return tool_name in _DETERMINISTIC_TOOLS + + @staticmethod + def _cache_key(name: str, args: Dict[str, Any]) -> str: + raw = json.dumps([name, args], sort_keys=True, ensure_ascii=False) + return f"tool:{name}:{hashlib.sha256(raw.encode()).hexdigest()[:16]}" + + async def _cache_get(self, key: str) -> Optional[str]: + try: + from app.core.redis_client import get_redis_client + redis = get_redis_client() + if redis: + return await redis.get(key) + except Exception: + pass + return self._cache_store.get(key) + + async def _cache_set(self, key: str, value: str): + try: + from app.core.redis_client import get_redis_client + redis = get_redis_client() + if redis: + await redis.setex(key, self._cache_ttl_s, value) + return + except Exception: + pass + self._cache_store[key] = value def get_tool_schemas(self) -> List[Dict[str, Any]]: """获取 Agent 可用的工具定义列表(OpenAI Function Calling 格式)。""" @@ -55,7 +105,7 @@ class AgentToolManager: async def execute(self, name: str, args: Dict[str, Any]) -> str: """ - 执行工具调用。 + 执行工具调用(带缓存)。 优先查找内置工具,其次查找数据库自定义工具(HTTP / Code)。 @@ 
-66,8 +116,24 @@ class AgentToolManager: Returns: 工具执行结果的字符串表示 """ + # 缓存检查 + if self._is_cacheable(name): + ck = self._cache_key(name, args) + cached = await self._cache_get(ck) + if cached is not None: + logger.info("Agent 工具命中缓存: %s", name) + return cached + logger.info("Agent 执行工具: %s", name) - return await tool_registry.execute_tool(name, args) + result = await tool_registry.execute_tool(name, args) + + # 缓存写入 + if self._is_cacheable(name): + ck = self._cache_key(name, args) + await self._cache_set(ck, result) + logger.debug("Agent 工具结果已缓存: %s", name) + + return result @staticmethod def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]: diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index 349872d..71debfd 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -13,6 +13,20 @@ import asyncio router = APIRouter() +def _get_progress_from_redis(execution_id: str) -> Optional[dict]: + """从 Redis 读取进度数据。""" + try: + from app.core.redis_client import get_redis_client + redis_client = get_redis_client() + if redis_client: + raw = redis_client.get(f"workflow:progress:{execution_id}") + if raw: + return json.loads(raw) + except Exception: + pass + return None + + @router.websocket("/api/v1/ws/executions/{execution_id}") async def websocket_execution_status( websocket: WebSocket, @@ -21,31 +35,36 @@ async def websocket_execution_status( ): """ WebSocket实时推送执行状态 - + Args: websocket: WebSocket连接 execution_id: 执行记录ID token: JWT Token(可选,通过query参数传递) """ - # 验证token(可选,如果需要认证) - # user = await get_current_user_optional(token) - - # 建立连接 await websocket_manager.connect(websocket, execution_id) - + db = SessionLocal() - + try: # 发送初始状态 execution = db.query(Execution).filter(Execution.id == execution_id).first() if execution: - await websocket_manager.send_personal_message({ - "type": "status", - "execution_id": execution_id, - "status": execution.status, - "progress": 0, - "message": "连接已建立" - }, websocket) + # 尝试从 
Redis 读取进度 + redis_progress = _get_progress_from_redis(execution_id) + if redis_progress: + await websocket_manager.send_personal_message({ + "type": "progress", + "execution_id": execution_id, + **redis_progress, + }, websocket) + else: + await websocket_manager.send_personal_message({ + "type": "status", + "execution_id": execution_id, + "status": execution.status, + "progress": 0, + "message": "连接已建立" + }, websocket) else: await websocket_manager.send_personal_message({ "type": "error", @@ -53,30 +72,56 @@ async def websocket_execution_status( }, websocket) await websocket.close() return - + + last_progress = -1 + # 持续监听并推送状态更新 while True: try: - # 接收客户端消息(心跳等) - data = await websocket.receive_text() - - # 处理客户端消息 + # 接收客户端消息(心跳等),超时 1 秒以便轮询进度 try: - message = json.loads(data) - if message.get("type") == "ping": - await websocket_manager.send_personal_message({ - "type": "pong" - }, websocket) - except: - pass - + data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0) + try: + message = json.loads(data) + if message.get("type") == "ping": + await websocket_manager.send_personal_message({ + "type": "pong" + }, websocket) + except Exception: + pass + except asyncio.TimeoutError: + pass # 超时后轮询进度 + except WebSocketDisconnect: break - - # 检查执行状态 - db.refresh(execution) - + + # 优先从 Redis 读取进度(推送模式) + redis_progress = _get_progress_from_redis(execution_id) + if redis_progress: + pct = redis_progress.get("progress", -1) + if pct != last_progress: + last_progress = pct + await websocket_manager.send_personal_message({ + "type": "progress", + "execution_id": execution_id, + **redis_progress, + }, websocket) + else: + # Redis 不可用时回退到 DB 轮询 + db.refresh(execution) + pct = 100 if execution.status in ["completed", "failed"] else (50 if execution.status == "running" else 0) + if pct != last_progress: + last_progress = pct + await websocket_manager.send_personal_message({ + "type": "status", + "execution_id": execution_id, + "status": execution.status, + 
"progress": pct, + "message": f"执行中..." if execution.status == "running" else "等待执行" + }, websocket) + # 如果执行完成或失败,发送最终状态并断开 + db.refresh(execution) if execution.status in ["completed", "failed"]: await websocket_manager.send_personal_message({ "type": "status", @@ -87,24 +132,10 @@ async def websocket_execution_status( "error": execution.error_message if execution.status == "failed" else None, "execution_time": execution.execution_time }, websocket) - - # 等待一下再断开,确保客户端收到消息 + await asyncio.sleep(1) break - - # 定期发送状态更新(每2秒) - await asyncio.sleep(2) - - # 重新查询执行状态 - db.refresh(execution) - await websocket_manager.send_personal_message({ - "type": "status", - "execution_id": execution_id, - "status": execution.status, - "progress": 50 if execution.status == "running" else 0, - "message": f"执行中..." if execution.status == "running" else "等待执行" - }, websocket) - + except WebSocketDisconnect: pass except Exception as e: @@ -114,7 +145,7 @@ async def websocket_execution_status( "type": "error", "message": f"发生错误: {str(e)}" }, websocket) - except: + except Exception: pass finally: websocket_manager.disconnect(websocket, execution_id) diff --git a/backend/app/services/workflow_engine.py b/backend/app/services/workflow_engine.py index 698e832..7f7d70e 100644 --- a/backend/app/services/workflow_engine.py +++ b/backend/app/services/workflow_engine.py @@ -5589,14 +5589,18 @@ class WorkflowEngine: self, input_data: Dict[str, Any], resume_snapshot: Optional[Dict[str, Any]] = None, + execution_id: Optional[str] = None, + on_progress: Optional[callable] = None, ) -> Dict[str, Any]: """ 执行完整工作流 - + Args: input_data: 初始输入数据(恢复执行时须包含 __hil_decision 等) resume_snapshot: 从挂起快照恢复(与 pause_state 一致) - + execution_id: 执行记录 ID(用于进度上报) + on_progress: 进度回调 async def(execution_id, progress_data) + Returns: 执行结果 """ @@ -5641,6 +5645,8 @@ class WorkflowEngine: self._tool_calls_used = 0 results = {} + total_nodes = len(self.nodes) + # 按拓扑顺序执行节点(动态构建执行图) while True: # 构建当前活跃的执行图 @@ -5656,12 +5662,12 
@@ class WorkflowEngine: f"[rjb] 当前执行图: {execution_order}, 活跃边数: {len(active_edges)}, 已执行节点: {executed_nodes}" ) - next_node_id = None + # 收集所有就绪节点(DAG 并行) + ready_nodes: list = [] for node_id in pending_ids: can_execute = True incoming_edges = [e for e in active_edges if e["target"] == node_id] if not incoming_edges: - # 没有入边:仅允许 Start;孤立节点跳过 if node_id not in [n["id"] for n in self.nodes.values() if n.get("type") == "start"]: logger.debug(f"[rjb] 节点 {node_id} 没有入边,跳过执行") continue @@ -5669,7 +5675,6 @@ class WorkflowEngine: for edge in incoming_edges: src = edge["source"] if src not in forward_reachable: - # 条件分支裁剪后不可达的前驱,不参与 gate(OR-join) continue if src not in executed_nodes: can_execute = False @@ -5678,298 +5683,327 @@ class WorkflowEngine: ) break if can_execute: - next_node_id = node_id + ready_nodes.append(node_id) logger.info( - f"[rjb] 选择执行节点: {next_node_id}, 类型: {self.nodes[next_node_id].get('type')}, 入边数: {len(incoming_edges)}" + f"[rjb] 就绪节点: {node_id}, 类型: {self.nodes[node_id].get('type')}" ) - break - if not next_node_id: + if not ready_nodes: break # 没有更多节点可执行 - - node = self.nodes[next_node_id] - is_approval = node.get("type") == "approval" - if not is_approval: - executed_nodes.add(next_node_id) - execution_sequence.append(next_node_id) - # 调试:检查节点数据结构 - if node.get('type') == 'llm': - logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}") - - # 获取节点输入(使用活跃的边) - node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges) - - # 如果是起始节点,使用初始输入 - if node.get('type') == 'start' and not node_input: - node_input = input_data - logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}") - - # 调试:记录节点输入数据 - if node.get('type') == 'llm': - logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}") - if 'start-1' in self.node_outputs: - 
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}") + # 标记所有就绪节点为已执行(approval 节点除外) + for nid in ready_nodes: + node = self.nodes[nid] + if node.get("type") != "approval": + executed_nodes.add(nid) + execution_sequence.append(nid) - # 单执行步数预算(每执行一个节点计 1 步) - self._steps_used += 1 + # 获取所有就绪节点的输入 + node_inputs: dict = {} + for nid in ready_nodes: + node = self.nodes[nid] + node_input = self.get_node_input(nid, self.node_outputs, active_edges) + if node.get('type') == 'start' and not node_input: + node_input = input_data + node_inputs[nid] = node_input + + # 预算检查(整批节点) + self._steps_used += len(ready_nodes) if self._steps_used > self._cap_steps: raise WorkflowExecutionError( detail=f"已超过单执行预算上限({self._cap_steps} 步),已熔断", - node_id=next_node_id, + node_id=ready_nodes[0], ) - # 执行节点 - result = await self.execute_node(node, node_input) + # 并行执行所有就绪节点 + async def _exec_one(nid: str): + try: + return await self.execute_node(self.nodes[nid], node_inputs[nid]) + except Exception as exc: + return {"status": "failed", "error": str(exc)} - if result.get("status") == "awaiting_approval": - self._steps_used -= 1 - snap = self._build_pause_snapshot( - next_node_id, active_edges, executed_nodes, execution_sequence, results - ) - raise WorkflowPaused(snap) - - if is_approval: - executed_nodes.add(next_node_id) - execution_sequence.append(next_node_id) - - results[next_node_id] = result - - # 保存节点输出 - if result.get('status') == 'success': - output_value = result.get('output', {}) - self.node_outputs[next_node_id] = output_value - if node.get('type') == 'start': - logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}") - - # 如果是条件节点或Switch节点,根据分支结果过滤边 - if node.get('type') == 'condition': - branch = result.get('branch', 'false') - logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}") - # 移除不符合条件的边 - # 只保留:1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边 - edges_to_remove = [] - edges_to_keep = 
[] - for edge in active_edges: - if edge['source'] == next_node_id: - # 这是从条件节点出发的边 - edge_handle = edge.get('sourceHandle') - if edge_handle == branch: - # sourceHandle匹配分支,保留 - edges_to_keep.append(edge) - logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})") - else: - # sourceHandle不匹配或为None,移除 - edges_to_remove.append(edge) - logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})") - else: - # 不是从条件节点出发的边,保留 - edges_to_keep.append(edge) - - active_edges = edges_to_keep - - elif node.get('type') == 'switch': - branch = result.get('branch', 'default') - logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}") - - # 记录过滤前的边信息 - edges_before = [e for e in active_edges if e['source'] == next_node_id] - logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}条") - for edge in edges_before: - logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}") - - # 移除不匹配的边 - edges_to_keep = [] - edges_removed_count = 0 - removed_source_nodes = set() # 记录被移除边的源节点 - - for edge in active_edges: - if edge['source'] == next_node_id: - # 这是从Switch节点出发的边 - edge_handle = edge.get('sourceHandle') - if edge_handle == branch: - # sourceHandle匹配分支,保留 - edges_to_keep.append(edge) - logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})") - else: - # sourceHandle不匹配,移除 - edges_removed_count += 1 - target_id = edge.get('target') - removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达) - logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})") - else: - # 不是从Switch节点出发的边,保留 - edges_to_keep.append(edge) - - # 重要:移除那些指向被过滤节点的边(这些边来自被过滤的LLM节点) - # 例如:如果llm-question被过滤了,那么llm-question → merge-response的边也应该被移除 - additional_removed = 0 - for edge in list(edges_to_keep): # 使用list副本,因为我们要修改原列表 - if edge['source'] in removed_source_nodes: - # 这条边来自被过滤的节点,也应该被移除 - 
edges_to_keep.remove(edge) - additional_removed += 1 - logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')} → {edge.get('target')})") - - edges_removed_count += additional_removed - - active_edges = edges_to_keep - filter_info = { - 'branch': branch, - 'edges_before': len(edges_before), - 'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]), - 'edges_removed': edges_removed_count - } - logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边(其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边") - # 记录过滤后的活跃边 - remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id] - logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}") - - # 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接) - removed_targets = set() - for edge in edges_before: - if edge not in edges_to_keep: - target_id = edge.get('target') - removed_targets.add(target_id) - logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行") - - # 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中 - # 这样在下次循环时,这些节点就不会被选择执行 - logger.info(f"[rjb] Switch节点过滤后,重新构建执行图(排除 {len(removed_targets)} 个不再可达的节点)") - - # 同时记录到数据库 - if self.logger: - self.logger.info( - f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边", - node_id=next_node_id, - node_type='switch', - data=filter_info - ) - - elif node.get('type') == 'approval': - branch = result.get('branch', 'approved') - logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}") - edges_to_keep = [] - removed_source_nodes = set() - for edge in active_edges: - if edge['source'] == next_node_id: - edge_handle = edge.get('sourceHandle') - if edge_handle == branch: - edges_to_keep.append(edge) - else: - removed_source_nodes.add(edge.get('target')) - else: - edges_to_keep.append(edge) - for edge in list(edges_to_keep): - if edge['source'] in removed_source_nodes: - edges_to_keep.remove(edge) - active_edges = edges_to_keep - if 
self.logger: - self.logger.info( - f"Approval节点分支过滤: branch={branch}", - node_id=next_node_id, - node_type='approval', - ) - - # 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行) - if node.get('type') in ['loop', 'foreach']: - # 标记循环体的节点为已执行(简化处理) - for edge in active_edges[:]: # 使用切片复制列表 - if edge.get('source') == next_node_id: - target_id = edge.get('target') - if target_id in self.nodes: - # 检查是否是循环结束节点 - target_node = self.nodes[target_id] - if target_node.get('type') not in ['loop_end', 'end']: - # 标记为已执行(循环体已在循环节点内部执行) - executed_nodes.add(target_id) - # 继续查找循环体内的节点 - self._mark_loop_body_executed(target_id, executed_nodes, active_edges) + if len(ready_nodes) == 1: + batch_results = {ready_nodes[0]: await _exec_one(ready_nodes[0])} else: - # 执行失败或质量不达标 — 支持重试 - failed_status = result.get('status', 'failed') - error_msg = result.get('error', '未知错误') - node_type = node.get('type', 'unknown') + logger.info(f"[rjb] 并行执行 {len(ready_nodes)} 个节点: {ready_nodes}") + tasks = [_exec_one(nid) for nid in ready_nodes] + gathered = await asyncio.gather(*tasks, return_exceptions=True) + batch_results = {} + for i, nid in enumerate(ready_nodes): + r = gathered[i] + if isinstance(r, BaseException): + batch_results[nid] = {"status": "failed", "error": str(r)} + else: + batch_results[nid] = r - # 处理 error_handler 返回的 retry_predecessor - if failed_status == 'retry_predecessor': - pred_id = result.get('predecessor_id') - eh_retry_count = result.get('retry_count', 1) - eh_retry_delay_ms = result.get('retry_delay_ms', 1000) + # 逐一处理各节点结果 + for next_node_id in ready_nodes: + node = self.nodes[next_node_id] + result = batch_results[next_node_id] + is_approval = node.get("type") == "approval" - # 检查是否已有重试计数(多次经过 error_handler) - prev_counter = self.node_outputs.get(f"_eh_retry_{pred_id}", {}) - remaining = prev_counter.get("remaining", eh_retry_count) - if pred_id and pred_id in self.nodes and remaining > 0: - remaining -= 1 - logger.warning( - f"error_handler 请求重试前驱节点 {pred_id},剩余 {remaining} 次" - ) - 
self.executed_nodes.discard(pred_id) - self.node_outputs.pop(pred_id, None) - self.node_outputs[f"_eh_retry_{pred_id}"] = { - "remaining": remaining, - "delay_ms": eh_retry_delay_ms, - "error_handler_id": next_node_id, - } - executed_nodes.add(next_node_id) - execution_sequence.append(next_node_id) - results[next_node_id] = result - await asyncio.sleep(eh_retry_delay_ms / 1000.0) - continue - elif remaining <= 0: - logger.error(f"error_handler 重试次数耗尽,前驱节点 {pred_id} 停止") - raise WorkflowExecutionError( - detail=f"error_handler 重试次数耗尽: {error_msg}", - node_id=pred_id, - ) - - # 检查节点级 retry_config - node_data = node.get('data', {}) or {} - retry_cfg = node_data.get('retry_config', {}) - max_retries = retry_cfg.get('max_retries', 0) if isinstance(retry_cfg, dict) else 0 - retry_delay_ms = retry_cfg.get('retry_delay_ms', 1000) if isinstance(retry_cfg, dict) else 1000 - on_exhausted = retry_cfg.get('on_exhausted', 'stop') if isinstance(retry_cfg, dict) else 'stop' - - retry_key = f"_retry_{next_node_id}" - retries_done = self.node_outputs.get(retry_key, 0) - - if max_retries > 0 and retries_done < max_retries: - self.node_outputs[retry_key] = retries_done + 1 - logger.warning( - f"节点 {next_node_id} ({node_type}) 执行失败,重试 {retries_done + 1}/{max_retries}: {error_msg}" + if result.get("status") == "awaiting_approval": + self._steps_used -= 1 + snap = self._build_pause_snapshot( + next_node_id, active_edges, executed_nodes, execution_sequence, results ) - await asyncio.sleep(retry_delay_ms / 1000.0) - continue # 不标记已执行,下次循环重新执行 + raise WorkflowPaused(snap) - # 重试耗尽或未配置重试 - if retries_done >= max_retries and max_retries > 0: - if on_exhausted == 'skip': - logger.warning(f"节点 {next_node_id} 重试耗尽,跳过: {error_msg}") - self.node_outputs[next_node_id] = { - 'status': 'skipped', 'error': error_msg - } - executed_nodes.add(next_node_id) - execution_sequence.append(next_node_id) - results[next_node_id] = result - continue - elif on_exhausted == 'notify': - logger.error(f"节点 
{next_node_id} 重试耗尽,已通知: {error_msg}") - self.node_outputs[next_node_id] = { - 'status': 'error_notified', 'error': error_msg - } - executed_nodes.add(next_node_id) - execution_sequence.append(next_node_id) - results[next_node_id] = result - continue + if is_approval: + executed_nodes.add(next_node_id) + execution_sequence.append(next_node_id) + + results[next_node_id] = result + + # 保存节点输出 + if result.get('status') == 'success': + output_value = result.get('output', {}) + self.node_outputs[next_node_id] = output_value + if node.get('type') == 'start': + logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}") + + # 如果是条件节点或Switch节点,根据分支结果过滤边 + if node.get('type') == 'condition': + branch = result.get('branch', 'false') + logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}") + # 移除不符合条件的边 + # 只保留:1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边 + edges_to_remove = [] + edges_to_keep = [] + for edge in active_edges: + if edge['source'] == next_node_id: + # 这是从条件节点出发的边 + edge_handle = edge.get('sourceHandle') + if edge_handle == branch: + # sourceHandle匹配分支,保留 + edges_to_keep.append(edge) + logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})") + else: + # sourceHandle不匹配或为None,移除 + edges_to_remove.append(edge) + logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})") + else: + # 不是从条件节点出发的边,保留 + edges_to_keep.append(edge) + + active_edges = edges_to_keep + + elif node.get('type') == 'switch': + branch = result.get('branch', 'default') + logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}") + + # 记录过滤前的边信息 + edges_before = [e for e in active_edges if e['source'] == next_node_id] + logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}条") + for edge in edges_before: + logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}") + + # 移除不匹配的边 + 
edges_to_keep = [] + edges_removed_count = 0 + removed_source_nodes = set() # 记录被移除边的源节点 + + for edge in active_edges: + if edge['source'] == next_node_id: + # 这是从Switch节点出发的边 + edge_handle = edge.get('sourceHandle') + if edge_handle == branch: + # sourceHandle匹配分支,保留 + edges_to_keep.append(edge) + logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})") + else: + # sourceHandle不匹配,移除 + edges_removed_count += 1 + target_id = edge.get('target') + removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达) + logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})") + else: + # 不是从Switch节点出发的边,保留 + edges_to_keep.append(edge) + + # 重要:移除那些指向被过滤节点的边(这些边来自被过滤的LLM节点) + # 例如:如果llm-question被过滤了,那么llm-question → merge-response的边也应该被移除 + additional_removed = 0 + for edge in list(edges_to_keep): # 使用list副本,因为我们要修改原列表 + if edge['source'] in removed_source_nodes: + # 这条边来自被过滤的节点,也应该被移除 + edges_to_keep.remove(edge) + additional_removed += 1 + logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')} → {edge.get('target')})") + + edges_removed_count += additional_removed + + active_edges = edges_to_keep + filter_info = { + 'branch': branch, + 'edges_before': len(edges_before), + 'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]), + 'edges_removed': edges_removed_count + } + logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边(其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边") + # 记录过滤后的活跃边 + remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id] + logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}") + + # 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接) + removed_targets = set() + for edge in edges_before: + if edge not in edges_to_keep: + target_id = edge.get('target') + removed_targets.add(target_id) + logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行") + 
+ # 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中 + # 这样在下次循环时,这些节点就不会被选择执行 + logger.info(f"[rjb] Switch节点过滤后,重新构建执行图(排除 {len(removed_targets)} 个不再可达的节点)") + + # 同时记录到数据库 + if self.logger: + self.logger.info( + f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边", + node_id=next_node_id, + node_type='switch', + data=filter_info + ) + + elif node.get('type') == 'approval': + branch = result.get('branch', 'approved') + logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}") + edges_to_keep = [] + removed_source_nodes = set() + for edge in active_edges: + if edge['source'] == next_node_id: + edge_handle = edge.get('sourceHandle') + if edge_handle == branch: + edges_to_keep.append(edge) + else: + removed_source_nodes.add(edge.get('target')) + else: + edges_to_keep.append(edge) + for edge in list(edges_to_keep): + if edge['source'] in removed_source_nodes: + edges_to_keep.remove(edge) + active_edges = edges_to_keep + if self.logger: + self.logger.info( + f"Approval节点分支过滤: branch={branch}", + node_id=next_node_id, + node_type='approval', + ) + + # 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行) + if node.get('type') in ['loop', 'foreach']: + # 标记循环体的节点为已执行(简化处理) + for edge in active_edges[:]: # 使用切片复制列表 + if edge.get('source') == next_node_id: + target_id = edge.get('target') + if target_id in self.nodes: + # 检查是否是循环结束节点 + target_node = self.nodes[target_id] + if target_node.get('type') not in ['loop_end', 'end']: + # 标记为已执行(循环体已在循环节点内部执行) + executed_nodes.add(target_id) + # 继续查找循环体内的节点 + self._mark_loop_body_executed(target_id, executed_nodes, active_edges) + else: + # 执行失败或质量不达标 — 支持重试 + failed_status = result.get('status', 'failed') + error_msg = result.get('error', '未知错误') + node_type = node.get('type', 'unknown') + + # 处理 error_handler 返回的 retry_predecessor + if failed_status == 'retry_predecessor': + pred_id = result.get('predecessor_id') + eh_retry_count = result.get('retry_count', 1) + eh_retry_delay_ms = result.get('retry_delay_ms', 
1000) + + # 检查是否已有重试计数(多次经过 error_handler) + prev_counter = self.node_outputs.get(f"_eh_retry_{pred_id}", {}) + remaining = prev_counter.get("remaining", eh_retry_count) + if pred_id and pred_id in self.nodes and remaining > 0: + remaining -= 1 + logger.warning( + f"error_handler 请求重试前驱节点 {pred_id},剩余 {remaining} 次" + ) + executed_nodes.discard(pred_id) + self.node_outputs.pop(pred_id, None) + self.node_outputs[f"_eh_retry_{pred_id}"] = { + "remaining": remaining, + "delay_ms": eh_retry_delay_ms, + "error_handler_id": next_node_id, + } + executed_nodes.add(next_node_id) + execution_sequence.append(next_node_id) + results[next_node_id] = result + await asyncio.sleep(eh_retry_delay_ms / 1000.0) + break # 退出 for 循环,while 循环下次迭代处理重试 + elif remaining <= 0: + logger.error(f"error_handler 重试次数耗尽,前驱节点 {pred_id} 停止") + raise WorkflowExecutionError( + detail=f"error_handler 重试次数耗尽: {error_msg}", + node_id=pred_id, + ) + + # 检查节点级 retry_config + node_data = node.get('data', {}) or {} + retry_cfg = node_data.get('retry_config', {}) + max_retries = retry_cfg.get('max_retries', 0) if isinstance(retry_cfg, dict) else 0 + retry_delay_ms = retry_cfg.get('retry_delay_ms', 1000) if isinstance(retry_cfg, dict) else 1000 + on_exhausted = retry_cfg.get('on_exhausted', 'stop') if isinstance(retry_cfg, dict) else 'stop' + + retry_key = f"_retry_{next_node_id}" + retries_done = self.node_outputs.get(retry_key, 0) + + if max_retries > 0 and retries_done < max_retries: + self.node_outputs[retry_key] = retries_done + 1 + logger.warning( + f"节点 {next_node_id} ({node_type}) 执行失败,重试 {retries_done + 1}/{max_retries}: {error_msg}" + ) + await asyncio.sleep(retry_delay_ms / 1000.0) + executed_nodes.discard(next_node_id) + break # 退出 for 循环,while 循环下次迭代重新执行 + + # 重试耗尽或未配置重试 + if retries_done >= max_retries and max_retries > 0: + if on_exhausted == 'skip': + logger.warning(f"节点 {next_node_id} 重试耗尽,跳过: {error_msg}") + self.node_outputs[next_node_id] = { + 'status': 'skipped', 'error': error_msg + } + 
executed_nodes.add(next_node_id) + execution_sequence.append(next_node_id) + results[next_node_id] = result + continue + elif on_exhausted == 'notify': + logger.error(f"节点 {next_node_id} 重试耗尽,已通知: {error_msg}") + self.node_outputs[next_node_id] = { + 'status': 'error_notified', 'error': error_msg + } + executed_nodes.add(next_node_id) + execution_sequence.append(next_node_id) + results[next_node_id] = result + continue + + # 默认:停止工作流 + logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}") + raise WorkflowExecutionError( + detail=error_msg, + node_id=next_node_id + ) + + # 上报批次进度 + if on_progress and execution_id: + try: + pct = int(len(executed_nodes) / total_nodes * 100) if total_nodes else 100 + await on_progress(execution_id, { + "current": len(executed_nodes), + "total": total_nodes, + "progress": min(pct, 99), + "status": "running", + }) + except Exception: + pass - # 默认:停止工作流 - logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}") - raise WorkflowExecutionError( - detail=error_msg, - node_id=next_node_id - ) - # 返回最终结果:优先取 End 类型且无出边的节点,避免向量写入等侧链与 End 同为 sink 时 # 因 executed_nodes 为 set 迭代顺序不确定而错误返回 upsert 元数据。 if executed_nodes: @@ -6052,7 +6086,19 @@ class WorkflowEngine: # 记录工作流执行完成 if self.logger: self.logger.info("工作流执行完成", data={"result": final_result.get('result')}) - + + # 上报 100% 完成进度 + if on_progress and execution_id: + try: + await on_progress(execution_id, { + "current": len(executed_nodes), + "total": total_nodes, + "progress": 100, + "status": "completed", + }) + except Exception: + pass + return final_result if self.logger: diff --git a/backend/app/tasks/workflow_tasks.py b/backend/app/tasks/workflow_tasks.py index a9eddac..6cdb9b6 100644 --- a/backend/app/tasks/workflow_tasks.py +++ b/backend/app/tasks/workflow_tasks.py @@ -20,7 +20,9 @@ from app.models.workflow import Workflow from app.services.execution_budget import merge_budget_for_execution from app.services.agent_workspace_chat_log import 
try_append_agent_dialogue_after_success from app.services.notification_service import create_notification +from app.websocket.manager import websocket_manager import asyncio +import json import time from typing import Any, Dict, Optional @@ -42,6 +44,29 @@ def _snapshot_to_jsonable(snapshot: dict) -> dict: return _json.loads(_json.dumps(snapshot, default=str)) +async def _on_workflow_progress(execution_id: str, progress_data: dict): + """工作流进度回调:写入 Redis + WebSocket 广播。""" + try: + from app.core.redis_client import get_redis_client + redis_client = get_redis_client() + if redis_client: + await redis_client.setex( + f"workflow:progress:{execution_id}", + 300, + json.dumps(progress_data, ensure_ascii=False), + ) + except Exception: + pass + try: + await websocket_manager.broadcast_to_execution(execution_id, { + "type": "progress", + "execution_id": execution_id, + **progress_data, + }) + except Exception: + pass + + def _trusted_user_for_execution(db, execution: Optional[Execution]) -> Optional[str]: """用于校验 LLM 节点引用的 model_configs 归属(与 Workflow / Agent 所有者一致)。""" if not execution: @@ -195,7 +220,8 @@ def execute_workflow_task( for attempt in range(max_retries + 1): try: result = asyncio.run( - engine.execute(input_data, resume_snapshot=resume_snapshot) + engine.execute(input_data, resume_snapshot=resume_snapshot, + execution_id=execution_id, on_progress=_on_workflow_progress) ) break except WorkflowPaused as paused: @@ -391,7 +417,8 @@ def resume_workflow_task( for attempt in range(max_retries + 1): try: result = asyncio.run( - engine.execute(base_input, resume_snapshot=snapshot) + engine.execute(base_input, resume_snapshot=snapshot, + execution_id=execution_id, on_progress=_on_workflow_progress) ) break except WorkflowPaused as paused: diff --git a/frontend/src/views/AgentChat.vue b/frontend/src/views/AgentChat.vue index 973e676..15c9d52 100644 --- a/frontend/src/views/AgentChat.vue +++ b/frontend/src/views/AgentChat.vue @@ -439,6 +439,7 @@ async function sendMessage() {
// 尝试 SSE 流式(带超时控制) let usedStreaming = false + let placeholderIdx = -1 streamingActive.value = false const abortController = new AbortController() const streamTimeout = setTimeout(() => abortController.abort(), 60000) @@ -462,8 +463,8 @@ async function sendMessage() { role: 'assistant', content: '', timestamp: Date.now(), steps: [], _traceOpen: true, iterations: 0, tool_calls_made: 0, } - const idx = messages.value[key].push(msg) - 1 - const currentMsg = messages.value[key][idx] + placeholderIdx = messages.value[key].push(msg) - 1 + const currentMsg = messages.value[key][placeholderIdx] const reader = resp.body.getReader() const decoder = new TextDecoder() @@ -532,14 +533,19 @@ async function sendMessage() { sessionId.value[key] = data.session_id } streamingActive.value = false + loading.value = false } else if (eventType === 'error') { currentMsg.content = data.content || '' currentMsg.status = 'error' streamingActive.value = false + loading.value = false + } else { + console.warn('[AgentChat] 未知 SSE 事件类型:', eventType, data) } } catch { /* 跳过畸形事件 */ } } } + clearTimeout(streamTimeout) } } catch { clearTimeout(streamTimeout) @@ -549,6 +555,11 @@ async function sendMessage() { } if (!usedStreaming) { + // 移除 SSE 阶段残留的占位消息(Issue #1) + if (placeholderIdx >= 0 && placeholderIdx < messages.value[key].length) { + messages.value[key].splice(placeholderIdx, 1) + } + // 降级:标准 POST 请求 const fallbackEndpoint = currentAgentId.value ? 
`/api/v1/agent-chat/${currentAgentId.value}` @@ -613,9 +624,11 @@ function retryMessage(idx: number) { // 查找错误消息之前的最后一条用户消息 let userMsg = '' + let userIdx = -1 for (let i = idx - 1; i >= 0; i--) { if (msgs[i].role === 'user') { userMsg = msgs[i].content + userIdx = i break } } @@ -624,15 +637,15 @@ function retryMessage(idx: number) { return } - // 移除该错误消息及关联的用户消息 - const removeIndices: number[] = [] - for (let i = idx - 1; i >= 0; i--) { - if (msgs[i].role === 'user' && msgs[i].content === userMsg) { + // 收集需要移除的消息:用户消息 + 错误消息 + 中间的流式占位消息 + const removeIndices: number[] = [userIdx, idx] + for (let i = userIdx + 1; i < idx; i++) { + const m = msgs[i] + // 流式占位消息:assistant 角色,内容为空或有 steps 但无实质内容 + if (m.role === 'assistant' && (!m.content || m.content === '') && (m.steps?.length || 0) > 0) { removeIndices.push(i) - break } } - removeIndices.push(idx) // 从后往前删除,避免 index 错乱 removeIndices.sort((a, b) => b - a)