feat: Phase 3 - parallel execution, progress reporting, result caching + AgentChat bug fixes

Phase 3 能力:
- DAG 并行执行 (workflow_engine): asyncio.gather 并行执行就绪节点
- Debate 并行 (orchestrator): for 循环改为 asyncio.gather
- 粒度进度上报 (workflow_engine + tasks + websocket): Redis 推送 + DB 降级
- 工具结果缓存 (tool_manager): 确定性工具默认开启缓存
- LLM 响应缓存 (core): messages[-4:] + model 哈希,5min TTL

AgentChat bug 修复 (Gitea #1-#5):
- #1 SSE 降级重复空消息: fallback POST 前移除占位消息
- #2 streamTimeout 泄漏: while 正常退出后 clearTimeout
- #3 loading 闪烁: final/error 事件中提前设 loading=false
- #4 SSE 事件类型对齐: 确认匹配,未知类型加 console.warn
- #5 retryMessage 流式残留: 重试时清理占位消息

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
renjianbo
2026-05-05 00:00:51 +08:00
parent f3cb35c460
commit 7e00b027d4
8 changed files with 605 additions and 345 deletions

View File

@@ -10,6 +10,7 @@ Agent Runtime 核心 —— 自主 ReAct 循环。
"""
from __future__ import annotations
import hashlib
import json
import logging
import time
@@ -98,6 +99,9 @@ class AgentRuntime:
self.tool_manager = tool_manager or AgentToolManager(
include_tools=self.config.tools.include_tools,
exclude_tools=self.config.tools.exclude_tools,
cache_enabled=self.config.tools.cache_enabled,
cache_tool_whitelist=self.config.tools.cache_tool_whitelist,
cache_ttl_ms=self.config.tools.cache_ttl_ms,
)
self.execution_logger = execution_logger
self.on_tool_executed = on_tool_executed
@@ -912,6 +916,32 @@ class AgentRuntime:
return any(kw in err_lower for kw in _RETRYABLE_ERRORS)
# LLM 缓存辅助
def _llm_cache_key(messages: list, model: str) -> str:
import hashlib
raw = json.dumps({"msgs": messages[-4:], "model": model}, sort_keys=True, ensure_ascii=False)
return f"llm:{model}:{hashlib.sha256(raw.encode()).hexdigest()[:16]}"
async def _llm_cache_get(key: str) -> Optional[str]:
try:
from app.core.redis_client import get_redis_client
redis = get_redis_client()
if redis:
return await redis.get(key)
except Exception:
pass
return None
async def _llm_cache_set(key: str, value: str, ttl_ms: int):
try:
from app.core.redis_client import get_redis_client
redis = get_redis_client()
if redis:
await redis.setex(key, max(1, int(ttl_ms / 1000)), value)
except Exception:
pass
class _LLMClient:
"""轻量 LLM 客户端包装,复用已有 LLMService 能力。"""
@@ -984,12 +1014,29 @@ class _LLMClient:
kwargs["tools"] = normalized
kwargs["tool_choice"] = "auto"
# LLM 响应缓存(仅不用工具时缓存,避免复杂序列化)
if self._config.cache_enabled and not tools:
cache_key = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
cached = await _llm_cache_get(cache_key)
if cached is not None:
logger.info("LLM 响应命中缓存: model=%s", kwargs.get("model"))
# 构造简易 message 对象(含 content 字段即可)
class _CachedMsg:
content = cached
tool_calls = None
return _CachedMsg()
start_time = time.perf_counter()
try:
response = await client.chat.completions.create(**kwargs)
latency_ms = int((time.perf_counter() - start_time) * 1000)
message = response.choices[0].message
# 缓存写入(仅不用工具时)
if self._config.cache_enabled and not tools and message.content:
ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms)
# 提取 token 用量
usage = getattr(response, "usage", None)
prompt_tokens = usage.prompt_tokens if usage else 0

View File

@@ -9,6 +9,7 @@ Agent Orchestrator — 多 Agent 编排引擎。
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
@@ -347,9 +348,10 @@ class AgentOrchestrator:
steps: List[OrchestratorStep] = []
agent_outputs: List[Dict[str, Any]] = []
# 第一阶段:所有 Agent 独立回答
# 第一阶段:所有 Agent 并行独立回答
runtimes = []
for agent_cfg in agents:
runtime = AgentRuntime(
runtimes.append(AgentRuntime(
AgentConfig(
name=agent_cfg.name,
system_prompt=agent_cfg.system_prompt,
@@ -364,8 +366,30 @@ class AgentOrchestrator:
),
),
on_llm_call=on_llm_call,
)
result = await runtime.run(question)
))
results = await asyncio.gather(
*[rt.run(question) for rt in runtimes],
return_exceptions=True,
)
for i, agent_cfg in enumerate(agents):
result = results[i]
if isinstance(result, BaseException):
step = OrchestratorStep(
agent_id=agent_cfg.id,
agent_name=agent_cfg.name,
input=question,
output="",
error=str(result),
)
steps.append(step)
agent_outputs.append({
"agent_id": agent_cfg.id,
"agent_name": agent_cfg.name,
"output": f"[错误] {result}",
})
continue
step = OrchestratorStep(
agent_id=agent_cfg.id,

View File

@@ -15,6 +15,10 @@ class AgentToolConfig(BaseModel):
require_approval: List[str] = Field(default_factory=list, description="需要人工审批的工具名列表")
approval_timeout_ms: int = Field(default=60000, description="审批超时(毫秒),超时使用默认策略")
approval_default: str = Field(default="deny", description="超时默认策略: approve | deny | skip")
# 结果缓存
cache_enabled: bool = Field(default=True, description="是否启用工具结果缓存(确定性工具默认开启)")
cache_tool_whitelist: List[str] = Field(default_factory=list, description="启用缓存的工具名(空=确定性工具默认)")
cache_ttl_ms: int = Field(default=3600000, description="缓存 TTL毫秒默认 1 小时")
class AgentMemoryConfig(BaseModel):
@@ -40,6 +44,8 @@ class AgentLLMConfig(BaseModel):
request_timeout: float = 120.0
extra_body: Optional[Dict[str, Any]] = None
self_review_threshold: float = 0.6 # self-review 通过阈值0-1
cache_enabled: bool = False # LLM 响应缓存(默认关闭,语义缓存有风险)
cache_ttl_ms: int = 300000 # LLM 缓存 TTL默认 5 分钟
class AgentBudgetConfig(BaseModel):

View File

@@ -3,6 +3,8 @@ Agent 工具管理器:包装已有 ToolRegistry提供 Agent 需要的工具
"""
from __future__ import annotations
import hashlib
import json
import logging
from typing import Any, Dict, List, Optional
@@ -10,6 +12,12 @@ from app.services.tool_registry import tool_registry
logger = logging.getLogger(__name__)
# Default deterministic tools (results are cacheable).
# NOTE(review): these tool names are assumed to produce output that depends
# only on their arguments — confirm e.g. "database_query" is acceptable to
# cache for the configured TTL, since underlying data may change.
_DETERMINISTIC_TOOLS = {
    "file_read", "math_calculate", "database_query",
    "json", "text", "csv", "excel", "pdf", "image",
}
class AgentToolManager:
"""
@@ -17,12 +25,54 @@ class AgentToolManager:
- 将 ToolRegistry 的工具 schema 转为 OpenAI Function Calling 格式
- 按 Agent 配置过滤(白名单/黑名单)
- 执行工具调用并返回结果字符串
- 工具结果缓存Redis / 内存 fallback
"""
def __init__(self, include_tools: Optional[List[str]] = None,
             exclude_tools: Optional[List[str]] = None,
             cache_enabled: bool = True,
             cache_tool_whitelist: Optional[List[str]] = None,
             cache_ttl_ms: int = 3600000):
    """Configure tool filtering and result caching.

    The rendered diff left both the old and new signature terminators in
    place, which is invalid syntax; this is the intended post-change form.

    Args:
        include_tools: whitelist of tool names (empty = all tools allowed).
        exclude_tools: blacklist of tool names.
        cache_enabled: master switch for tool-result caching.
        cache_tool_whitelist: tools allowed to cache; empty means the
            built-in deterministic-tool set is used instead.
        cache_ttl_ms: cache TTL in milliseconds (default 1 hour).
    """
    self._include_tools: set = set(include_tools or [])
    self._exclude_tools: set = set(exclude_tools or [])
    self._cache_enabled = cache_enabled
    self._cache_whitelist: set = set(cache_tool_whitelist or [])
    # Redis SETEX takes whole seconds; clamp to at least 1s.
    self._cache_ttl_s = max(1, int(cache_ttl_ms / 1000))
    self._cache_store: Dict[str, str] = {}  # in-memory fallback store
def _is_cacheable(self, tool_name: str) -> bool:
    """Decide whether results for ``tool_name`` may be served from cache."""
    if not self._cache_enabled:
        return False
    # An explicit whitelist overrides the built-in deterministic-tool set.
    allowed = self._cache_whitelist if self._cache_whitelist else _DETERMINISTIC_TOOLS
    return tool_name in allowed
@staticmethod
def _cache_key(name: str, args: Dict[str, Any]) -> str:
    """Derive a stable cache key from the tool name and its arguments."""
    payload = json.dumps([name, args], sort_keys=True, ensure_ascii=False).encode()
    digest = hashlib.sha256(payload).hexdigest()
    return f"tool:{name}:{digest[:16]}"
async def _cache_get(self, key: str) -> Optional[str]:
    """Look up a cached tool result: Redis first, then the in-memory fallback.

    Any Redis failure (missing client, connection error) silently falls
    back to the in-memory store.
    """
    try:
        from app.core.redis_client import get_redis_client
    except Exception:
        return self._cache_store.get(key)
    try:
        client = get_redis_client()
        if client:
            return await client.get(key)
    except Exception:
        pass
    return self._cache_store.get(key)
async def _cache_set(self, key: str, value: str):
    """Store a tool result: Redis (with TTL) when available, else in-memory.

    Redis failures fall through to the in-memory store so caching keeps
    working in degraded environments.
    """
    try:
        from app.core.redis_client import get_redis_client
        redis = get_redis_client()
        if redis:
            await redis.setex(key, self._cache_ttl_s, value)
            return
    except Exception:
        pass
    # The in-memory fallback has no TTL; cap its size so a long-lived
    # process cannot grow it without bound (the original dict leaked).
    if len(self._cache_store) >= 1024:
        # Evict the oldest entry — dicts preserve insertion order.
        self._cache_store.pop(next(iter(self._cache_store)))
    self._cache_store[key] = value
def get_tool_schemas(self) -> List[Dict[str, Any]]:
"""获取 Agent 可用的工具定义列表OpenAI Function Calling 格式)。"""
@@ -55,7 +105,7 @@ class AgentToolManager:
async def execute(self, name: str, args: Dict[str, Any]) -> str:
    """Execute a tool call (with result caching).

    Looks up built-in tools first, then database-defined custom tools
    (HTTP / Code). The rendered diff left the removed unconditional
    ``return await tool_registry.execute_tool(...)`` line next to the new
    code, which made everything after it unreachable; this is the intended
    post-change logic.

    Args:
        name: tool name.
        args: tool arguments.

    Returns:
        String representation of the tool execution result.
    """
    # Serve deterministic / whitelisted tools from cache when possible.
    if self._is_cacheable(name):
        ck = self._cache_key(name, args)
        cached = await self._cache_get(ck)
        if cached is not None:
            logger.info("Agent 工具命中缓存: %s", name)
            return cached
    logger.info("Agent 执行工具: %s", name)
    result = await tool_registry.execute_tool(name, args)
    # Write-back for cacheable tools (best-effort; Redis or memory).
    if self._is_cacheable(name):
        ck = self._cache_key(name, args)
        await self._cache_set(ck, result)
        logger.debug("Agent 工具结果已缓存: %s", name)
    return result
@staticmethod
def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:

View File

@@ -13,6 +13,20 @@ import asyncio
router = APIRouter()
def _get_progress_from_redis(execution_id: str) -> Optional[dict]:
    """Read workflow progress data for an execution from Redis.

    Returns the decoded progress dict, or None when Redis is unavailable,
    the key is missing, or the payload cannot be parsed — callers fall back
    to DB polling in that case.
    """
    try:
        from app.core.redis_client import get_redis_client
        redis_client = get_redis_client()
        if redis_client:
            # NOTE(review): other call sites await the client's .get()
            # (async Redis), but here it is called synchronously. Confirm
            # get_redis_client() returns a sync client in this context —
            # otherwise `raw` is a truthy coroutine, json.loads always
            # raises, and this helper silently returns None.
            raw = redis_client.get(f"workflow:progress:{execution_id}")
            if raw:
                return json.loads(raw)
    except Exception:
        # Best-effort read: any failure degrades to "no progress available".
        pass
    return None
@router.websocket("/api/v1/ws/executions/{execution_id}")
async def websocket_execution_status(
websocket: WebSocket,
@@ -21,31 +35,36 @@ async def websocket_execution_status(
):
"""
WebSocket实时推送执行状态
Args:
websocket: WebSocket连接
execution_id: 执行记录ID
token: JWT Token可选通过query参数传递
"""
# 验证token可选如果需要认证
# user = await get_current_user_optional(token)
# 建立连接
await websocket_manager.connect(websocket, execution_id)
db = SessionLocal()
try:
# 发送初始状态
execution = db.query(Execution).filter(Execution.id == execution_id).first()
if execution:
await websocket_manager.send_personal_message({
"type": "status",
"execution_id": execution_id,
"status": execution.status,
"progress": 0,
"message": "连接已建立"
}, websocket)
# 尝试从 Redis 读取进度
redis_progress = _get_progress_from_redis(execution_id)
if redis_progress:
await websocket_manager.send_personal_message({
"type": "progress",
"execution_id": execution_id,
**redis_progress,
}, websocket)
else:
await websocket_manager.send_personal_message({
"type": "status",
"execution_id": execution_id,
"status": execution.status,
"progress": 0,
"message": "连接已建立"
}, websocket)
else:
await websocket_manager.send_personal_message({
"type": "error",
@@ -53,30 +72,56 @@ async def websocket_execution_status(
}, websocket)
await websocket.close()
return
last_progress = -1
# 持续监听并推送状态更新
while True:
try:
# 接收客户端消息(心跳等)
data = await websocket.receive_text()
# 处理客户端消息
# 接收客户端消息(心跳等),超时 1 秒以便轮询进度
try:
message = json.loads(data)
if message.get("type") == "ping":
await websocket_manager.send_personal_message({
"type": "pong"
}, websocket)
except:
pass
data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0)
try:
message = json.loads(data)
if message.get("type") == "ping":
await websocket_manager.send_personal_message({
"type": "pong"
}, websocket)
except Exception:
pass
except asyncio.TimeoutError:
pass # 超时后轮询进度
except WebSocketDisconnect:
break
# 检查执行状态
db.refresh(execution)
# 优先从 Redis 读取进度(推送模式)
redis_progress = _get_progress_from_redis(execution_id)
if redis_progress:
pct = redis_progress.get("progress", -1)
if pct != last_progress:
last_progress = pct
await websocket_manager.send_personal_message({
"type": "progress",
"execution_id": execution_id,
**redis_progress,
}, websocket)
else:
# Redis 不可用时回退到 DB 轮询
db.refresh(execution)
pct = 100 if execution.status in ["completed", "failed"] else (50 if execution.status == "running" else 0)
if pct != last_progress:
last_progress = pct
await websocket_manager.send_personal_message({
"type": "status",
"execution_id": execution_id,
"status": execution.status,
"progress": pct,
"message": f"执行中..." if execution.status == "running" else "等待执行"
}, websocket)
# 如果执行完成或失败,发送最终状态并断开
db.refresh(execution)
if execution.status in ["completed", "failed"]:
await websocket_manager.send_personal_message({
"type": "status",
@@ -87,24 +132,10 @@ async def websocket_execution_status(
"error": execution.error_message if execution.status == "failed" else None,
"execution_time": execution.execution_time
}, websocket)
# 等待一下再断开,确保客户端收到消息
await asyncio.sleep(1)
break
# 定期发送状态更新每2秒
await asyncio.sleep(2)
# 重新查询执行状态
db.refresh(execution)
await websocket_manager.send_personal_message({
"type": "status",
"execution_id": execution_id,
"status": execution.status,
"progress": 50 if execution.status == "running" else 0,
"message": f"执行中..." if execution.status == "running" else "等待执行"
}, websocket)
except WebSocketDisconnect:
pass
except Exception as e:
@@ -114,7 +145,7 @@ async def websocket_execution_status(
"type": "error",
"message": f"发生错误: {str(e)}"
}, websocket)
except:
except Exception:
pass
finally:
websocket_manager.disconnect(websocket, execution_id)

View File

@@ -5589,14 +5589,18 @@ class WorkflowEngine:
self,
input_data: Dict[str, Any],
resume_snapshot: Optional[Dict[str, Any]] = None,
execution_id: Optional[str] = None,
on_progress: Optional[callable] = None,
) -> Dict[str, Any]:
"""
执行完整工作流
Args:
input_data: 初始输入数据(恢复执行时须包含 __hil_decision 等)
resume_snapshot: 从挂起快照恢复(与 pause_state 一致)
execution_id: 执行记录 ID用于进度上报
on_progress: 进度回调 async def(execution_id, progress_data)
Returns:
执行结果
"""
@@ -5641,6 +5645,8 @@ class WorkflowEngine:
self._tool_calls_used = 0
results = {}
total_nodes = len(self.nodes)
# 按拓扑顺序执行节点(动态构建执行图)
while True:
# 构建当前活跃的执行图
@@ -5656,12 +5662,12 @@ class WorkflowEngine:
f"[rjb] 当前执行图: {execution_order}, 活跃边数: {len(active_edges)}, 已执行节点: {executed_nodes}"
)
next_node_id = None
# 收集所有就绪节点DAG 并行)
ready_nodes: list = []
for node_id in pending_ids:
can_execute = True
incoming_edges = [e for e in active_edges if e["target"] == node_id]
if not incoming_edges:
# 没有入边:仅允许 Start孤立节点跳过
if node_id not in [n["id"] for n in self.nodes.values() if n.get("type") == "start"]:
logger.debug(f"[rjb] 节点 {node_id} 没有入边,跳过执行")
continue
@@ -5669,7 +5675,6 @@ class WorkflowEngine:
for edge in incoming_edges:
src = edge["source"]
if src not in forward_reachable:
# 条件分支裁剪后不可达的前驱,不参与 gateOR-join
continue
if src not in executed_nodes:
can_execute = False
@@ -5678,298 +5683,327 @@ class WorkflowEngine:
)
break
if can_execute:
next_node_id = node_id
ready_nodes.append(node_id)
logger.info(
f"[rjb] 选择执行节点: {next_node_id}, 类型: {self.nodes[next_node_id].get('type')}, 入边数: {len(incoming_edges)}"
f"[rjb] 就绪节点: {node_id}, 类型: {self.nodes[node_id].get('type')}"
)
break
if not next_node_id:
if not ready_nodes:
break # 没有更多节点可执行
node = self.nodes[next_node_id]
is_approval = node.get("type") == "approval"
if not is_approval:
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
# 调试:检查节点数据结构
if node.get('type') == 'llm':
logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")
# 获取节点输入(使用活跃的边)
node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
# 如果是起始节点,使用初始输入
if node.get('type') == 'start' and not node_input:
node_input = input_data
logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}")
# 调试:记录节点输入数据
if node.get('type') == 'llm':
logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
if 'start-1' in self.node_outputs:
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}")
# 标记所有就绪节点为已执行approval 节点除外)
for nid in ready_nodes:
node = self.nodes[nid]
if node.get("type") != "approval":
executed_nodes.add(nid)
execution_sequence.append(nid)
# 单执行步数预算(每执行一个节点计 1 步)
self._steps_used += 1
# 获取所有就绪节点的输入
node_inputs: dict = {}
for nid in ready_nodes:
node = self.nodes[nid]
node_input = self.get_node_input(nid, self.node_outputs, active_edges)
if node.get('type') == 'start' and not node_input:
node_input = input_data
node_inputs[nid] = node_input
# 预算检查(整批节点)
self._steps_used += len(ready_nodes)
if self._steps_used > self._cap_steps:
raise WorkflowExecutionError(
detail=f"已超过单执行预算上限({self._cap_steps} 步),已熔断",
node_id=next_node_id,
node_id=ready_nodes[0],
)
# 执行节点
result = await self.execute_node(node, node_input)
# 并行执行所有就绪节点
async def _exec_one(nid: str):
try:
return await self.execute_node(self.nodes[nid], node_inputs[nid])
except Exception as exc:
return {"status": "failed", "error": str(exc)}
if result.get("status") == "awaiting_approval":
self._steps_used -= 1
snap = self._build_pause_snapshot(
next_node_id, active_edges, executed_nodes, execution_sequence, results
)
raise WorkflowPaused(snap)
if is_approval:
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
# 保存节点输出
if result.get('status') == 'success':
output_value = result.get('output', {})
self.node_outputs[next_node_id] = output_value
if node.get('type') == 'start':
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
# 如果是条件节点或Switch节点根据分支结果过滤边
if node.get('type') == 'condition':
branch = result.get('branch', 'false')
logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}")
# 移除不符合条件的边
# 只保留1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边
edges_to_remove = []
edges_to_keep = []
for edge in active_edges:
if edge['source'] == next_node_id:
# 这是从条件节点出发的边
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
# sourceHandle匹配分支保留
edges_to_keep.append(edge)
logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
else:
# sourceHandle不匹配或为None移除
edges_to_remove.append(edge)
logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})")
else:
# 不是从条件节点出发的边,保留
edges_to_keep.append(edge)
active_edges = edges_to_keep
elif node.get('type') == 'switch':
branch = result.get('branch', 'default')
logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}")
# 记录过滤前的边信息
edges_before = [e for e in active_edges if e['source'] == next_node_id]
logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}")
for edge in edges_before:
logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}")
# 移除不匹配的边
edges_to_keep = []
edges_removed_count = 0
removed_source_nodes = set() # 记录被移除边的源节点
for edge in active_edges:
if edge['source'] == next_node_id:
# 这是从Switch节点出发的边
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
# sourceHandle匹配分支保留
edges_to_keep.append(edge)
logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
else:
# sourceHandle不匹配移除
edges_removed_count += 1
target_id = edge.get('target')
removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达)
logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})")
else:
# 不是从Switch节点出发的边保留
edges_to_keep.append(edge)
# 重要移除那些指向被过滤节点的边这些边来自被过滤的LLM节点
# 例如如果llm-question被过滤了那么llm-question → merge-response的边也应该被移除
additional_removed = 0
for edge in list(edges_to_keep): # 使用list副本因为我们要修改原列表
if edge['source'] in removed_source_nodes:
# 这条边来自被过滤的节点,也应该被移除
edges_to_keep.remove(edge)
additional_removed += 1
logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')}{edge.get('target')})")
edges_removed_count += additional_removed
active_edges = edges_to_keep
filter_info = {
'branch': branch,
'edges_before': len(edges_before),
'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]),
'edges_removed': edges_removed_count
}
logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边")
# 记录过滤后的活跃边
remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id]
logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}")
# 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接)
removed_targets = set()
for edge in edges_before:
if edge not in edges_to_keep:
target_id = edge.get('target')
removed_targets.add(target_id)
logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行")
# 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中
# 这样在下次循环时,这些节点就不会被选择执行
logger.info(f"[rjb] Switch节点过滤后重新构建执行图排除 {len(removed_targets)} 个不再可达的节点)")
# 同时记录到数据库
if self.logger:
self.logger.info(
f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边",
node_id=next_node_id,
node_type='switch',
data=filter_info
)
elif node.get('type') == 'approval':
branch = result.get('branch', 'approved')
logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}")
edges_to_keep = []
removed_source_nodes = set()
for edge in active_edges:
if edge['source'] == next_node_id:
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
edges_to_keep.append(edge)
else:
removed_source_nodes.add(edge.get('target'))
else:
edges_to_keep.append(edge)
for edge in list(edges_to_keep):
if edge['source'] in removed_source_nodes:
edges_to_keep.remove(edge)
active_edges = edges_to_keep
if self.logger:
self.logger.info(
f"Approval节点分支过滤: branch={branch}",
node_id=next_node_id,
node_type='approval',
)
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
if node.get('type') in ['loop', 'foreach']:
# 标记循环体的节点为已执行(简化处理)
for edge in active_edges[:]: # 使用切片复制列表
if edge.get('source') == next_node_id:
target_id = edge.get('target')
if target_id in self.nodes:
# 检查是否是循环结束节点
target_node = self.nodes[target_id]
if target_node.get('type') not in ['loop_end', 'end']:
# 标记为已执行(循环体已在循环节点内部执行)
executed_nodes.add(target_id)
# 继续查找循环体内的节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
if len(ready_nodes) == 1:
batch_results = {ready_nodes[0]: await _exec_one(ready_nodes[0])}
else:
# 执行失败或质量不达标 — 支持重试
failed_status = result.get('status', 'failed')
error_msg = result.get('error', '未知错误')
node_type = node.get('type', 'unknown')
logger.info(f"[rjb] 并行执行 {len(ready_nodes)} 个节点: {ready_nodes}")
tasks = [_exec_one(nid) for nid in ready_nodes]
gathered = await asyncio.gather(*tasks, return_exceptions=True)
batch_results = {}
for i, nid in enumerate(ready_nodes):
r = gathered[i]
if isinstance(r, BaseException):
batch_results[nid] = {"status": "failed", "error": str(r)}
else:
batch_results[nid] = r
# 处理 error_handler 返回的 retry_predecessor
if failed_status == 'retry_predecessor':
pred_id = result.get('predecessor_id')
eh_retry_count = result.get('retry_count', 1)
eh_retry_delay_ms = result.get('retry_delay_ms', 1000)
# 逐一处理各节点结果
for next_node_id in ready_nodes:
node = self.nodes[next_node_id]
result = batch_results[next_node_id]
is_approval = node.get("type") == "approval"
# 检查是否已有重试计数(多次经过 error_handler
prev_counter = self.node_outputs.get(f"_eh_retry_{pred_id}", {})
remaining = prev_counter.get("remaining", eh_retry_count)
if pred_id and pred_id in self.nodes and remaining > 0:
remaining -= 1
logger.warning(
f"error_handler 请求重试前驱节点 {pred_id},剩余 {remaining}"
)
self.executed_nodes.discard(pred_id)
self.node_outputs.pop(pred_id, None)
self.node_outputs[f"_eh_retry_{pred_id}"] = {
"remaining": remaining,
"delay_ms": eh_retry_delay_ms,
"error_handler_id": next_node_id,
}
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
await asyncio.sleep(eh_retry_delay_ms / 1000.0)
continue
elif remaining <= 0:
logger.error(f"error_handler 重试次数耗尽,前驱节点 {pred_id} 停止")
raise WorkflowExecutionError(
detail=f"error_handler 重试次数耗尽: {error_msg}",
node_id=pred_id,
)
# 检查节点级 retry_config
node_data = node.get('data', {}) or {}
retry_cfg = node_data.get('retry_config', {})
max_retries = retry_cfg.get('max_retries', 0) if isinstance(retry_cfg, dict) else 0
retry_delay_ms = retry_cfg.get('retry_delay_ms', 1000) if isinstance(retry_cfg, dict) else 1000
on_exhausted = retry_cfg.get('on_exhausted', 'stop') if isinstance(retry_cfg, dict) else 'stop'
retry_key = f"_retry_{next_node_id}"
retries_done = self.node_outputs.get(retry_key, 0)
if max_retries > 0 and retries_done < max_retries:
self.node_outputs[retry_key] = retries_done + 1
logger.warning(
f"节点 {next_node_id} ({node_type}) 执行失败,重试 {retries_done + 1}/{max_retries}: {error_msg}"
if result.get("status") == "awaiting_approval":
self._steps_used -= 1
snap = self._build_pause_snapshot(
next_node_id, active_edges, executed_nodes, execution_sequence, results
)
await asyncio.sleep(retry_delay_ms / 1000.0)
continue # 不标记已执行,下次循环重新执行
raise WorkflowPaused(snap)
# 重试耗尽或未配置重试
if retries_done >= max_retries and max_retries > 0:
if on_exhausted == 'skip':
logger.warning(f"节点 {next_node_id} 重试耗尽,跳过: {error_msg}")
self.node_outputs[next_node_id] = {
'status': 'skipped', 'error': error_msg
}
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
continue
elif on_exhausted == 'notify':
logger.error(f"节点 {next_node_id} 重试耗尽,已通知: {error_msg}")
self.node_outputs[next_node_id] = {
'status': 'error_notified', 'error': error_msg
}
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
continue
if is_approval:
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
# 保存节点输出
if result.get('status') == 'success':
output_value = result.get('output', {})
self.node_outputs[next_node_id] = output_value
if node.get('type') == 'start':
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
# 如果是条件节点或Switch节点根据分支结果过滤边
if node.get('type') == 'condition':
branch = result.get('branch', 'false')
logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}")
# 移除不符合条件的边
# 只保留1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边
edges_to_remove = []
edges_to_keep = []
for edge in active_edges:
if edge['source'] == next_node_id:
# 这是从条件节点出发的边
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
# sourceHandle匹配分支保留
edges_to_keep.append(edge)
logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
else:
# sourceHandle不匹配或为None移除
edges_to_remove.append(edge)
logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})")
else:
# 不是从条件节点出发的边,保留
edges_to_keep.append(edge)
active_edges = edges_to_keep
elif node.get('type') == 'switch':
branch = result.get('branch', 'default')
logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}")
# 记录过滤前的边信息
edges_before = [e for e in active_edges if e['source'] == next_node_id]
logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}")
for edge in edges_before:
logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}")
# 移除不匹配的边
edges_to_keep = []
edges_removed_count = 0
removed_source_nodes = set() # 记录被移除边的源节点
for edge in active_edges:
if edge['source'] == next_node_id:
# 这是从Switch节点出发的边
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
# sourceHandle匹配分支保留
edges_to_keep.append(edge)
logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
else:
# sourceHandle不匹配移除
edges_removed_count += 1
target_id = edge.get('target')
removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达)
logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})")
else:
# 不是从Switch节点出发的边保留
edges_to_keep.append(edge)
# 重要移除那些指向被过滤节点的边这些边来自被过滤的LLM节点
# 例如如果llm-question被过滤了那么llm-question → merge-response的边也应该被移除
additional_removed = 0
for edge in list(edges_to_keep): # 使用list副本因为我们要修改原列表
if edge['source'] in removed_source_nodes:
# 这条边来自被过滤的节点,也应该被移除
edges_to_keep.remove(edge)
additional_removed += 1
logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')}{edge.get('target')})")
edges_removed_count += additional_removed
active_edges = edges_to_keep
filter_info = {
'branch': branch,
'edges_before': len(edges_before),
'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]),
'edges_removed': edges_removed_count
}
logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边")
# 记录过滤后的活跃边
remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id]
logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}")
# 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接)
removed_targets = set()
for edge in edges_before:
if edge not in edges_to_keep:
target_id = edge.get('target')
removed_targets.add(target_id)
logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行")
# 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中
# 这样在下次循环时,这些节点就不会被选择执行
logger.info(f"[rjb] Switch节点过滤后重新构建执行图排除 {len(removed_targets)} 个不再可达的节点)")
# 同时记录到数据库
if self.logger:
self.logger.info(
f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边",
node_id=next_node_id,
node_type='switch',
data=filter_info
)
elif node.get('type') == 'approval':
branch = result.get('branch', 'approved')
logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}")
edges_to_keep = []
removed_source_nodes = set()
for edge in active_edges:
if edge['source'] == next_node_id:
edge_handle = edge.get('sourceHandle')
if edge_handle == branch:
edges_to_keep.append(edge)
else:
removed_source_nodes.add(edge.get('target'))
else:
edges_to_keep.append(edge)
for edge in list(edges_to_keep):
if edge['source'] in removed_source_nodes:
edges_to_keep.remove(edge)
active_edges = edges_to_keep
if self.logger:
self.logger.info(
f"Approval节点分支过滤: branch={branch}",
node_id=next_node_id,
node_type='approval',
)
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
if node.get('type') in ['loop', 'foreach']:
# 标记循环体的节点为已执行(简化处理)
for edge in active_edges[:]: # 使用切片复制列表
if edge.get('source') == next_node_id:
target_id = edge.get('target')
if target_id in self.nodes:
# 检查是否是循环结束节点
target_node = self.nodes[target_id]
if target_node.get('type') not in ['loop_end', 'end']:
# 标记为已执行(循环体已在循环节点内部执行)
executed_nodes.add(target_id)
# 继续查找循环体内的节点
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
else:
# 执行失败或质量不达标 — 支持重试
failed_status = result.get('status', 'failed')
error_msg = result.get('error', '未知错误')
node_type = node.get('type', 'unknown')
# 处理 error_handler 返回的 retry_predecessor
if failed_status == 'retry_predecessor':
pred_id = result.get('predecessor_id')
eh_retry_count = result.get('retry_count', 1)
eh_retry_delay_ms = result.get('retry_delay_ms', 1000)
# 检查是否已有重试计数(多次经过 error_handler
prev_counter = self.node_outputs.get(f"_eh_retry_{pred_id}", {})
remaining = prev_counter.get("remaining", eh_retry_count)
if pred_id and pred_id in self.nodes and remaining > 0:
remaining -= 1
logger.warning(
f"error_handler 请求重试前驱节点 {pred_id},剩余 {remaining}"
)
executed_nodes.discard(pred_id)
self.node_outputs.pop(pred_id, None)
self.node_outputs[f"_eh_retry_{pred_id}"] = {
"remaining": remaining,
"delay_ms": eh_retry_delay_ms,
"error_handler_id": next_node_id,
}
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
await asyncio.sleep(eh_retry_delay_ms / 1000.0)
break # 退出 for 循环while 循环下次迭代处理重试
elif remaining <= 0:
logger.error(f"error_handler 重试次数耗尽,前驱节点 {pred_id} 停止")
raise WorkflowExecutionError(
detail=f"error_handler 重试次数耗尽: {error_msg}",
node_id=pred_id,
)
# 检查节点级 retry_config
node_data = node.get('data', {}) or {}
retry_cfg = node_data.get('retry_config', {})
max_retries = retry_cfg.get('max_retries', 0) if isinstance(retry_cfg, dict) else 0
retry_delay_ms = retry_cfg.get('retry_delay_ms', 1000) if isinstance(retry_cfg, dict) else 1000
on_exhausted = retry_cfg.get('on_exhausted', 'stop') if isinstance(retry_cfg, dict) else 'stop'
retry_key = f"_retry_{next_node_id}"
retries_done = self.node_outputs.get(retry_key, 0)
if max_retries > 0 and retries_done < max_retries:
self.node_outputs[retry_key] = retries_done + 1
logger.warning(
f"节点 {next_node_id} ({node_type}) 执行失败,重试 {retries_done + 1}/{max_retries}: {error_msg}"
)
await asyncio.sleep(retry_delay_ms / 1000.0)
executed_nodes.discard(next_node_id)
break # 退出 for 循环while 循环下次迭代重新执行
# 重试耗尽或未配置重试
if retries_done >= max_retries and max_retries > 0:
if on_exhausted == 'skip':
logger.warning(f"节点 {next_node_id} 重试耗尽,跳过: {error_msg}")
self.node_outputs[next_node_id] = {
'status': 'skipped', 'error': error_msg
}
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
continue
elif on_exhausted == 'notify':
logger.error(f"节点 {next_node_id} 重试耗尽,已通知: {error_msg}")
self.node_outputs[next_node_id] = {
'status': 'error_notified', 'error': error_msg
}
executed_nodes.add(next_node_id)
execution_sequence.append(next_node_id)
results[next_node_id] = result
continue
# 默认:停止工作流
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
raise WorkflowExecutionError(
detail=error_msg,
node_id=next_node_id
)
# 上报批次进度
if on_progress and execution_id:
try:
pct = int(len(executed_nodes) / total_nodes * 100) if total_nodes else 100
await on_progress(execution_id, {
"current": len(executed_nodes),
"total": total_nodes,
"progress": min(pct, 99),
"status": "running",
})
except Exception:
pass
# 默认:停止工作流
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
raise WorkflowExecutionError(
detail=error_msg,
node_id=next_node_id
)
# 返回最终结果:优先取 End 类型且无出边的节点,避免向量写入等侧链与 End 同为 sink 时
# 因 executed_nodes 为 set 迭代顺序不确定而错误返回 upsert 元数据。
if executed_nodes:
@@ -6052,7 +6086,19 @@ class WorkflowEngine:
# 记录工作流执行完成
if self.logger:
self.logger.info("工作流执行完成", data={"result": final_result.get('result')})
# 上报 100% 完成进度
if on_progress and execution_id:
try:
await on_progress(execution_id, {
"current": len(executed_nodes),
"total": total_nodes,
"progress": 100,
"status": "completed",
})
except Exception:
pass
return final_result
if self.logger:

View File

@@ -20,7 +20,9 @@ from app.models.workflow import Workflow
from app.services.execution_budget import merge_budget_for_execution
from app.services.agent_workspace_chat_log import try_append_agent_dialogue_after_success
from app.services.notification_service import create_notification
from app.websocket.manager import websocket_manager
import asyncio
import json
import time
from typing import Any, Dict, Optional
@@ -42,6 +44,29 @@ def _snapshot_to_jsonable(snapshot: dict) -> dict:
return _json.loads(_json.dumps(snapshot, default=str))
async def _on_workflow_progress(execution_id: str, progress_data: dict):
"""工作流进度回调:写入 Redis + WebSocket 广播。"""
try:
from app.core.redis_client import get_redis_client
redis_client = get_redis_client()
if redis_client:
redis_client.setex(
f"workflow:progress:{execution_id}",
300,
json.dumps(progress_data, ensure_ascii=False),
)
except Exception:
pass
try:
await websocket_manager.broadcast_to_execution(execution_id, {
"type": "progress",
"execution_id": execution_id,
**progress_data,
})
except Exception:
pass
def _trusted_user_for_execution(db, execution: Optional[Execution]) -> Optional[str]:
"""用于校验 LLM 节点引用的 model_configs 归属(与 Workflow / Agent 所有者一致)。"""
if not execution:
@@ -195,7 +220,8 @@ def execute_workflow_task(
for attempt in range(max_retries + 1):
try:
result = asyncio.run(
engine.execute(input_data, resume_snapshot=resume_snapshot)
engine.execute(input_data, resume_snapshot=resume_snapshot,
execution_id=execution_id, on_progress=_on_workflow_progress)
)
break
except WorkflowPaused as paused:
@@ -391,7 +417,8 @@ def resume_workflow_task(
for attempt in range(max_retries + 1):
try:
result = asyncio.run(
engine.execute(base_input, resume_snapshot=snapshot)
engine.execute(base_input, resume_snapshot=snapshot,
execution_id=execution_id, on_progress=_on_workflow_progress)
)
break
except WorkflowPaused as paused:

View File

@@ -439,6 +439,7 @@ async function sendMessage() {
// 尝试 SSE 流式(带超时控制)
let usedStreaming = false
let placeholderIdx = -1
streamingActive.value = false
const abortController = new AbortController()
const streamTimeout = setTimeout(() => abortController.abort(), 60000)
@@ -462,8 +463,8 @@ async function sendMessage() {
role: 'assistant', content: '', timestamp: Date.now(),
steps: [], _traceOpen: true, iterations: 0, tool_calls_made: 0,
}
const idx = messages.value[key].push(msg) - 1
const currentMsg = messages.value[key][idx]
placeholderIdx = messages.value[key].push(msg) - 1
const currentMsg = messages.value[key][placeholderIdx]
const reader = resp.body.getReader()
const decoder = new TextDecoder()
@@ -532,14 +533,19 @@ async function sendMessage() {
sessionId.value[key] = data.session_id
}
streamingActive.value = false
loading.value = false
} else if (eventType === 'error') {
currentMsg.content = data.content || ''
currentMsg.status = 'error'
streamingActive.value = false
loading.value = false
} else {
console.warn('[AgentChat] 未知 SSE 事件类型:', eventType, data)
}
} catch { /* 跳过畸形事件 */ }
}
}
clearTimeout(streamTimeout)
}
} catch {
clearTimeout(streamTimeout)
@@ -549,6 +555,11 @@ async function sendMessage() {
}
if (!usedStreaming) {
// 移除 SSE 阶段残留的占位消息Issue #1
if (placeholderIdx >= 0 && placeholderIdx < messages.value[key].length) {
messages.value[key].splice(placeholderIdx, 1)
}
// 降级:标准 POST 请求
const fallbackEndpoint = currentAgentId.value
? `/api/v1/agent-chat/${currentAgentId.value}`
@@ -613,9 +624,11 @@ function retryMessage(idx: number) {
// 查找错误消息之前的最后一条用户消息
let userMsg = ''
let userIdx = -1
for (let i = idx - 1; i >= 0; i--) {
if (msgs[i].role === 'user') {
userMsg = msgs[i].content
userIdx = i
break
}
}
@@ -624,15 +637,15 @@ function retryMessage(idx: number) {
return
}
// 移除该错误消息及关联的用户消息
const removeIndices: number[] = []
for (let i = idx - 1; i >= 0; i--) {
if (msgs[i].role === 'user' && msgs[i].content === userMsg) {
// 收集需要移除的消息:用户消息 + 错误消息 + 中间的流式占位消息
const removeIndices: number[] = [userIdx, idx]
for (let i = userIdx + 1; i < idx; i++) {
const m = msgs[i]
// 流式占位消息assistant 角色,内容为空或有 steps 但无实质内容
if (m.role === 'assistant' && (!m.content || m.content === '') && (m.steps?.length || 0) > 0) {
removeIndices.push(i)
break
}
}
removeIndices.push(idx)
// 从后往前删除,避免 index 错乱
removeIndices.sort((a, b) => b - a)