feat: Phase 3 - parallel execution, progress reporting, result caching + AgentChat bug fixes
Phase 3 能力: - DAG 并行执行 (workflow_engine): asyncio.gather 并行执行就绪节点 - Debate 并行 (orchestrator): for 循环改为 asyncio.gather - 粒度进度上报 (workflow_engine + tasks + websocket): Redis 推送 + DB 降级 - 工具结果缓存 (tool_manager): 确定性工具默认开启缓存 - LLM 响应缓存 (core): messages[-4:] + model 哈希,5min TTL AgentChat bug 修复 (Gitea #1-#5): - #1 SSE 降级重复空消息: fallback POST 前移除占位消息 - #2 streamTimeout 泄漏: while 正常退出后 clearTimeout - #3 loading 闪烁: final/error 事件中提前设 loading=false - #4 SSE 事件类型对齐: 确认匹配,未知类型加 console.warn - #5 retryMessage 流式残留: 重试时清理占位消息 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ Agent Runtime 核心 —— 自主 ReAct 循环。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
@@ -98,6 +99,9 @@ class AgentRuntime:
|
||||
self.tool_manager = tool_manager or AgentToolManager(
|
||||
include_tools=self.config.tools.include_tools,
|
||||
exclude_tools=self.config.tools.exclude_tools,
|
||||
cache_enabled=self.config.tools.cache_enabled,
|
||||
cache_tool_whitelist=self.config.tools.cache_tool_whitelist,
|
||||
cache_ttl_ms=self.config.tools.cache_ttl_ms,
|
||||
)
|
||||
self.execution_logger = execution_logger
|
||||
self.on_tool_executed = on_tool_executed
|
||||
@@ -912,6 +916,32 @@ class AgentRuntime:
|
||||
return any(kw in err_lower for kw in _RETRYABLE_ERRORS)
|
||||
|
||||
|
||||
# LLM 缓存辅助
|
||||
def _llm_cache_key(messages: list, model: str) -> str:
    """Build the Redis key under which an LLM response is cached.

    Only the model name and the last four entries of ``messages`` feed the
    hash; earlier messages do not influence the key.

    Args:
        messages: Chat messages of the request.
        model: Model identifier; embedded verbatim in the key and also hashed.

    Returns:
        ``"llm:<model>:<first 16 hex chars of the SHA-256 digest>"``.
    """
    # default=str keeps key construction from raising TypeError when a message
    # carries a value json cannot encode. For JSON-serializable input the
    # resulting key is identical to the one produced without it.
    raw = json.dumps(
        {"msgs": messages[-4:], "model": model},
        sort_keys=True,
        ensure_ascii=False,
        default=str,
    )
    digest = hashlib.sha256(raw.encode()).hexdigest()[:16]
    return f"llm:{model}:{digest}"
|
||||
|
||||
async def _llm_cache_get(key: str) -> Optional[str]:
    """Look up a cached LLM response in Redis (best effort).

    Args:
        key: Cache key, as produced by ``_llm_cache_key``.

    Returns:
        The cached text, or ``None`` when there is no entry, no Redis client
        is available, or the lookup fails for any reason.
    """
    try:
        import inspect

        from app.core.redis_client import get_redis_client

        redis = get_redis_client()
        if not redis:
            return None
        value = redis.get(key)
        # Work with either a synchronous or an asyncio Redis client: only
        # await when the call actually handed back an awaitable.
        if inspect.isawaitable(value):
            value = await value
        # Clients created without decode_responses return bytes; callers
        # expect text.
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return value
    except Exception:
        # The cache is optional: record the failure and behave like a miss.
        logging.getLogger(__name__).debug(
            "LLM cache read failed: key=%s", key, exc_info=True
        )
    return None
|
||||
|
||||
async def _llm_cache_set(key: str, value: str, ttl_ms: int):
    """Store an LLM response in Redis with an expiry (best effort).

    Args:
        key: Cache key, as produced by ``_llm_cache_key``.
        value: Response text to cache.
        ttl_ms: Time to live in milliseconds; converted to whole seconds with
            a floor of one second.

    Returns:
        None. Failures are logged at debug level and otherwise ignored.
    """
    try:
        import inspect

        from app.core.redis_client import get_redis_client

        redis = get_redis_client()
        if redis:
            ttl_s = max(1, int(ttl_ms / 1000))
            pending = redis.setex(key, ttl_s, value)
            # Work with either a synchronous or an asyncio Redis client.
            if inspect.isawaitable(pending):
                await pending
    except Exception:
        # The cache is optional: a failed write must not fail the LLM call.
        logging.getLogger(__name__).debug(
            "LLM cache write failed: key=%s", key, exc_info=True
        )
|
||||
|
||||
|
||||
class _LLMClient:
|
||||
"""轻量 LLM 客户端包装,复用已有 LLMService 能力。"""
|
||||
|
||||
@@ -984,12 +1014,29 @@ class _LLMClient:
|
||||
kwargs["tools"] = normalized
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
# LLM 响应缓存(仅不用工具时缓存,避免复杂序列化)
|
||||
if self._config.cache_enabled and not tools:
|
||||
cache_key = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
|
||||
cached = await _llm_cache_get(cache_key)
|
||||
if cached is not None:
|
||||
logger.info("LLM 响应命中缓存: model=%s", kwargs.get("model"))
|
||||
# 构造简易 message 对象(含 content 字段即可)
|
||||
class _CachedMsg:
|
||||
content = cached
|
||||
tool_calls = None
|
||||
return _CachedMsg()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
response = await client.chat.completions.create(**kwargs)
|
||||
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
message = response.choices[0].message
|
||||
|
||||
# 缓存写入(仅不用工具时)
|
||||
if self._config.cache_enabled and not tools and message.content:
|
||||
ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
|
||||
await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms)
|
||||
|
||||
# 提取 token 用量
|
||||
usage = getattr(response, "usage", None)
|
||||
prompt_tokens = usage.prompt_tokens if usage else 0
|
||||
|
||||
@@ -9,6 +9,7 @@ Agent Orchestrator — 多 Agent 编排引擎。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@@ -347,9 +348,10 @@ class AgentOrchestrator:
|
||||
steps: List[OrchestratorStep] = []
|
||||
agent_outputs: List[Dict[str, Any]] = []
|
||||
|
||||
# 第一阶段:所有 Agent 独立回答
|
||||
# 第一阶段:所有 Agent 并行独立回答
|
||||
runtimes = []
|
||||
for agent_cfg in agents:
|
||||
runtime = AgentRuntime(
|
||||
runtimes.append(AgentRuntime(
|
||||
AgentConfig(
|
||||
name=agent_cfg.name,
|
||||
system_prompt=agent_cfg.system_prompt,
|
||||
@@ -364,8 +366,30 @@ class AgentOrchestrator:
|
||||
),
|
||||
),
|
||||
on_llm_call=on_llm_call,
|
||||
)
|
||||
result = await runtime.run(question)
|
||||
))
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[rt.run(question) for rt in runtimes],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
for i, agent_cfg in enumerate(agents):
|
||||
result = results[i]
|
||||
if isinstance(result, BaseException):
|
||||
step = OrchestratorStep(
|
||||
agent_id=agent_cfg.id,
|
||||
agent_name=agent_cfg.name,
|
||||
input=question,
|
||||
output="",
|
||||
error=str(result),
|
||||
)
|
||||
steps.append(step)
|
||||
agent_outputs.append({
|
||||
"agent_id": agent_cfg.id,
|
||||
"agent_name": agent_cfg.name,
|
||||
"output": f"[错误] {result}",
|
||||
})
|
||||
continue
|
||||
|
||||
step = OrchestratorStep(
|
||||
agent_id=agent_cfg.id,
|
||||
|
||||
@@ -15,6 +15,10 @@ class AgentToolConfig(BaseModel):
|
||||
require_approval: List[str] = Field(default_factory=list, description="需要人工审批的工具名列表")
|
||||
approval_timeout_ms: int = Field(default=60000, description="审批超时(毫秒),超时使用默认策略")
|
||||
approval_default: str = Field(default="deny", description="超时默认策略: approve | deny | skip")
|
||||
# 结果缓存
|
||||
cache_enabled: bool = Field(default=True, description="是否启用工具结果缓存(确定性工具默认开启)")
|
||||
cache_tool_whitelist: List[str] = Field(default_factory=list, description="启用缓存的工具名(空=确定性工具默认)")
|
||||
cache_ttl_ms: int = Field(default=3600000, description="缓存 TTL(毫秒),默认 1 小时")
|
||||
|
||||
|
||||
class AgentMemoryConfig(BaseModel):
|
||||
@@ -40,6 +44,8 @@ class AgentLLMConfig(BaseModel):
|
||||
request_timeout: float = 120.0
|
||||
extra_body: Optional[Dict[str, Any]] = None
|
||||
self_review_threshold: float = 0.6 # self-review 通过阈值(0-1)
|
||||
cache_enabled: bool = False # LLM 响应缓存(默认关闭,语义缓存有风险)
|
||||
cache_ttl_ms: int = 300000 # LLM 缓存 TTL,默认 5 分钟
|
||||
|
||||
|
||||
class AgentBudgetConfig(BaseModel):
|
||||
|
||||
@@ -3,6 +3,8 @@ Agent 工具管理器:包装已有 ToolRegistry,提供 Agent 需要的工具
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -10,6 +12,12 @@ from app.services.tool_registry import tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认确定性工具(结果可缓存)
|
||||
_DETERMINISTIC_TOOLS = {
|
||||
"file_read", "math_calculate", "database_query",
|
||||
"json", "text", "csv", "excel", "pdf", "image",
|
||||
}
|
||||
|
||||
|
||||
class AgentToolManager:
|
||||
"""
|
||||
@@ -17,12 +25,54 @@ class AgentToolManager:
|
||||
- 将 ToolRegistry 的工具 schema 转为 OpenAI Function Calling 格式
|
||||
- 按 Agent 配置过滤(白名单/黑名单)
|
||||
- 执行工具调用并返回结果字符串
|
||||
- 工具结果缓存(Redis / 内存 fallback)
|
||||
"""
|
||||
|
||||
def __init__(self, include_tools: Optional[List[str]] = None,
|
||||
exclude_tools: Optional[List[str]] = None):
|
||||
exclude_tools: Optional[List[str]] = None,
|
||||
cache_enabled: bool = True,
|
||||
cache_tool_whitelist: Optional[List[str]] = None,
|
||||
cache_ttl_ms: int = 3600000):
|
||||
self._include_tools: set = set(include_tools or [])
|
||||
self._exclude_tools: set = set(exclude_tools or [])
|
||||
self._cache_enabled = cache_enabled
|
||||
self._cache_whitelist: set = set(cache_tool_whitelist or [])
|
||||
self._cache_ttl_s = max(1, int(cache_ttl_ms / 1000))
|
||||
self._cache_store: Dict[str, str] = {} # 内存 fallback
|
||||
|
||||
def _is_cacheable(self, tool_name: str) -> bool:
    """Tell whether results of ``tool_name`` may be served from the cache.

    Caching must be enabled on this manager. A non-empty whitelist is then
    the only thing consulted; with an empty whitelist the module-level set of
    default deterministic tools decides.
    """
    if self._cache_enabled:
        eligible = self._cache_whitelist or _DETERMINISTIC_TOOLS
        return tool_name in eligible
    return False
|
||||
|
||||
@staticmethod
def _cache_key(name: str, args: Dict[str, Any]) -> str:
    """Build the cache key for one tool invocation.

    Args:
        name: Tool name; embedded verbatim in the key and also hashed.
        args: Tool arguments; serialized with sorted keys so that key order
            does not affect the result.

    Returns:
        ``"tool:<name>:<first 16 hex chars of the SHA-256 digest>"``.
    """
    # default=str keeps key construction from raising TypeError on argument
    # values json cannot encode. For JSON-serializable arguments the
    # resulting key is identical to the one produced without it.
    raw = json.dumps([name, args], sort_keys=True, ensure_ascii=False, default=str)
    digest = hashlib.sha256(raw.encode()).hexdigest()[:16]
    return f"tool:{name}:{digest}"
|
||||
|
||||
async def _cache_get(self, key: str) -> Optional[str]:
    """Fetch a cached tool result, preferring Redis over process memory.

    When a Redis client is available its answer is final, including a miss.
    The in-memory store is consulted only when no client is available or the
    Redis lookup raised.

    Args:
        key: Cache key, as produced by ``_cache_key``.

    Returns:
        The cached result string, or ``None`` on a miss or an expired entry.
    """
    import inspect
    import time

    try:
        from app.core.redis_client import get_redis_client

        redis = get_redis_client()
        if redis:
            value = redis.get(key)
            # Work with either a synchronous or an asyncio Redis client.
            if inspect.isawaitable(value):
                value = await value
            # Clients created without decode_responses return bytes.
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            return value
    except Exception:
        logging.getLogger(__name__).debug(
            "Tool cache read from Redis failed: key=%s", key, exc_info=True
        )

    # Memory fallback: honour the deadline recorded by _cache_set so entries
    # expire here just as they would in Redis.
    deadlines = getattr(self, "_cache_expiry", None) or {}
    deadline = deadlines.get(key)
    if deadline is not None and time.monotonic() >= deadline:
        self._cache_store.pop(key, None)
        deadlines.pop(key, None)
        return None
    return self._cache_store.get(key)


async def _cache_set(self, key: str, value: str):
    """Store a tool result in Redis, or in process memory as a fallback.

    Args:
        key: Cache key, as produced by ``_cache_key``.
        value: Tool result string to cache.

    Returns:
        None. The entry lives for ``self._cache_ttl_s`` seconds in either
        backend.
    """
    import inspect
    import time

    try:
        from app.core.redis_client import get_redis_client

        redis = get_redis_client()
        if redis:
            pending = redis.setex(key, self._cache_ttl_s, value)
            # Work with either a synchronous or an asyncio Redis client.
            if inspect.isawaitable(pending):
                await pending
            return
    except Exception:
        logging.getLogger(__name__).debug(
            "Tool cache write to Redis failed: key=%s", key, exc_info=True
        )

    # Memory fallback. Deadlines live in a side table created on first use so
    # that _cache_store keeps holding plain result strings.
    deadlines = getattr(self, "_cache_expiry", None)
    if deadlines is None:
        deadlines = self._cache_expiry = {}
    now = time.monotonic()
    # Drop expired entries on every write so the store cannot grow without
    # bound with results that will never be served again.
    for stale in [k for k, deadline in deadlines.items() if deadline <= now]:
        deadlines.pop(stale, None)
        self._cache_store.pop(stale, None)
    self._cache_store[key] = value
    deadlines[key] = now + self._cache_ttl_s
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""获取 Agent 可用的工具定义列表(OpenAI Function Calling 格式)。"""
|
||||
@@ -55,7 +105,7 @@ class AgentToolManager:
|
||||
|
||||
async def execute(self, name: str, args: Dict[str, Any]) -> str:
|
||||
"""
|
||||
执行工具调用。
|
||||
执行工具调用(带缓存)。
|
||||
|
||||
优先查找内置工具,其次查找数据库自定义工具(HTTP / Code)。
|
||||
|
||||
@@ -66,8 +116,24 @@ class AgentToolManager:
|
||||
Returns:
|
||||
工具执行结果的字符串表示
|
||||
"""
|
||||
# 缓存检查
|
||||
if self._is_cacheable(name):
|
||||
ck = self._cache_key(name, args)
|
||||
cached = await self._cache_get(ck)
|
||||
if cached is not None:
|
||||
logger.info("Agent 工具命中缓存: %s", name)
|
||||
return cached
|
||||
|
||||
logger.info("Agent 执行工具: %s", name)
|
||||
return await tool_registry.execute_tool(name, args)
|
||||
result = await tool_registry.execute_tool(name, args)
|
||||
|
||||
# 缓存写入
|
||||
if self._is_cacheable(name):
|
||||
ck = self._cache_key(name, args)
|
||||
await self._cache_set(ck, result)
|
||||
logger.debug("Agent 工具结果已缓存: %s", name)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_name(schema: Dict[str, Any]) -> Optional[str]:
|
||||
|
||||
@@ -13,6 +13,20 @@ import asyncio
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _get_progress_from_redis(execution_id: str) -> Optional[dict]:
    """Read the progress snapshot of an execution from Redis (best effort).

    The snapshot is expected under ``workflow:progress:<execution_id>`` as a
    JSON object.

    Args:
        execution_id: ID of the execution record.

    Returns:
        The decoded progress dict, or ``None`` when no client is available,
        nothing is stored, the lookup or decoding fails, or the stored value
        is not a JSON object.
    """
    try:
        from app.core.redis_client import get_redis_client

        redis_client = get_redis_client()
        if not redis_client:
            return None
        # NOTE(review): this is a synchronous call made from an async
        # websocket handler; confirm get_redis_client() returns a sync client.
        raw = redis_client.get(f"workflow:progress:{execution_id}")
        if not raw:
            return None
        progress = json.loads(raw)
    except Exception:
        import logging

        # Progress is optional: record the failure so the caller can fall
        # back to database polling.
        logging.getLogger(__name__).debug(
            "Reading workflow progress from Redis failed: execution_id=%s",
            execution_id,
            exc_info=True,
        )
        return None
    # Callers unpack the result with ** and call .get() on it, so anything
    # other than a dict must be reported as "no progress".
    return progress if isinstance(progress, dict) else None
|
||||
|
||||
|
||||
@router.websocket("/api/v1/ws/executions/{execution_id}")
|
||||
async def websocket_execution_status(
|
||||
websocket: WebSocket,
|
||||
@@ -21,31 +35,36 @@ async def websocket_execution_status(
|
||||
):
|
||||
"""
|
||||
WebSocket实时推送执行状态
|
||||
|
||||
|
||||
Args:
|
||||
websocket: WebSocket连接
|
||||
execution_id: 执行记录ID
|
||||
token: JWT Token(可选,通过query参数传递)
|
||||
"""
|
||||
# 验证token(可选,如果需要认证)
|
||||
# user = await get_current_user_optional(token)
|
||||
|
||||
# 建立连接
|
||||
await websocket_manager.connect(websocket, execution_id)
|
||||
|
||||
|
||||
db = SessionLocal()
|
||||
|
||||
|
||||
try:
|
||||
# 发送初始状态
|
||||
execution = db.query(Execution).filter(Execution.id == execution_id).first()
|
||||
if execution:
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "status",
|
||||
"execution_id": execution_id,
|
||||
"status": execution.status,
|
||||
"progress": 0,
|
||||
"message": "连接已建立"
|
||||
}, websocket)
|
||||
# 尝试从 Redis 读取进度
|
||||
redis_progress = _get_progress_from_redis(execution_id)
|
||||
if redis_progress:
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "progress",
|
||||
"execution_id": execution_id,
|
||||
**redis_progress,
|
||||
}, websocket)
|
||||
else:
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "status",
|
||||
"execution_id": execution_id,
|
||||
"status": execution.status,
|
||||
"progress": 0,
|
||||
"message": "连接已建立"
|
||||
}, websocket)
|
||||
else:
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "error",
|
||||
@@ -53,30 +72,56 @@ async def websocket_execution_status(
|
||||
}, websocket)
|
||||
await websocket.close()
|
||||
return
|
||||
|
||||
|
||||
last_progress = -1
|
||||
|
||||
# 持续监听并推送状态更新
|
||||
while True:
|
||||
try:
|
||||
# 接收客户端消息(心跳等)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 处理客户端消息
|
||||
# 接收客户端消息(心跳等),超时 1 秒以便轮询进度
|
||||
try:
|
||||
message = json.loads(data)
|
||||
if message.get("type") == "ping":
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "pong"
|
||||
}, websocket)
|
||||
except:
|
||||
pass
|
||||
|
||||
data = await asyncio.wait_for(websocket.receive_text(), timeout=1.0)
|
||||
try:
|
||||
message = json.loads(data)
|
||||
if message.get("type") == "ping":
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "pong"
|
||||
}, websocket)
|
||||
except Exception:
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
pass # 超时后轮询进度
|
||||
|
||||
except WebSocketDisconnect:
|
||||
break
|
||||
|
||||
# 检查执行状态
|
||||
db.refresh(execution)
|
||||
|
||||
|
||||
# 优先从 Redis 读取进度(推送模式)
|
||||
redis_progress = _get_progress_from_redis(execution_id)
|
||||
if redis_progress:
|
||||
pct = redis_progress.get("progress", -1)
|
||||
if pct != last_progress:
|
||||
last_progress = pct
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "progress",
|
||||
"execution_id": execution_id,
|
||||
**redis_progress,
|
||||
}, websocket)
|
||||
else:
|
||||
# Redis 不可用时回退到 DB 轮询
|
||||
db.refresh(execution)
|
||||
pct = 100 if execution.status in ["completed", "failed"] else (50 if execution.status == "running" else 0)
|
||||
if pct != last_progress:
|
||||
last_progress = pct
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "status",
|
||||
"execution_id": execution_id,
|
||||
"status": execution.status,
|
||||
"progress": pct,
|
||||
"message": f"执行中..." if execution.status == "running" else "等待执行"
|
||||
}, websocket)
|
||||
|
||||
# 如果执行完成或失败,发送最终状态并断开
|
||||
db.refresh(execution)
|
||||
if execution.status in ["completed", "failed"]:
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "status",
|
||||
@@ -87,24 +132,10 @@ async def websocket_execution_status(
|
||||
"error": execution.error_message if execution.status == "failed" else None,
|
||||
"execution_time": execution.execution_time
|
||||
}, websocket)
|
||||
|
||||
# 等待一下再断开,确保客户端收到消息
|
||||
|
||||
await asyncio.sleep(1)
|
||||
break
|
||||
|
||||
# 定期发送状态更新(每2秒)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# 重新查询执行状态
|
||||
db.refresh(execution)
|
||||
await websocket_manager.send_personal_message({
|
||||
"type": "status",
|
||||
"execution_id": execution_id,
|
||||
"status": execution.status,
|
||||
"progress": 50 if execution.status == "running" else 0,
|
||||
"message": f"执行中..." if execution.status == "running" else "等待执行"
|
||||
}, websocket)
|
||||
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as e:
|
||||
@@ -114,7 +145,7 @@ async def websocket_execution_status(
|
||||
"type": "error",
|
||||
"message": f"发生错误: {str(e)}"
|
||||
}, websocket)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
websocket_manager.disconnect(websocket, execution_id)
|
||||
|
||||
@@ -5589,14 +5589,18 @@ class WorkflowEngine:
|
||||
self,
|
||||
input_data: Dict[str, Any],
|
||||
resume_snapshot: Optional[Dict[str, Any]] = None,
|
||||
execution_id: Optional[str] = None,
|
||||
on_progress: Optional[callable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行完整工作流
|
||||
|
||||
|
||||
Args:
|
||||
input_data: 初始输入数据(恢复执行时须包含 __hil_decision 等)
|
||||
resume_snapshot: 从挂起快照恢复(与 pause_state 一致)
|
||||
|
||||
execution_id: 执行记录 ID(用于进度上报)
|
||||
on_progress: 进度回调 async def(execution_id, progress_data)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
@@ -5641,6 +5645,8 @@ class WorkflowEngine:
|
||||
self._tool_calls_used = 0
|
||||
results = {}
|
||||
|
||||
total_nodes = len(self.nodes)
|
||||
|
||||
# 按拓扑顺序执行节点(动态构建执行图)
|
||||
while True:
|
||||
# 构建当前活跃的执行图
|
||||
@@ -5656,12 +5662,12 @@ class WorkflowEngine:
|
||||
f"[rjb] 当前执行图: {execution_order}, 活跃边数: {len(active_edges)}, 已执行节点: {executed_nodes}"
|
||||
)
|
||||
|
||||
next_node_id = None
|
||||
# 收集所有就绪节点(DAG 并行)
|
||||
ready_nodes: list = []
|
||||
for node_id in pending_ids:
|
||||
can_execute = True
|
||||
incoming_edges = [e for e in active_edges if e["target"] == node_id]
|
||||
if not incoming_edges:
|
||||
# 没有入边:仅允许 Start;孤立节点跳过
|
||||
if node_id not in [n["id"] for n in self.nodes.values() if n.get("type") == "start"]:
|
||||
logger.debug(f"[rjb] 节点 {node_id} 没有入边,跳过执行")
|
||||
continue
|
||||
@@ -5669,7 +5675,6 @@ class WorkflowEngine:
|
||||
for edge in incoming_edges:
|
||||
src = edge["source"]
|
||||
if src not in forward_reachable:
|
||||
# 条件分支裁剪后不可达的前驱,不参与 gate(OR-join)
|
||||
continue
|
||||
if src not in executed_nodes:
|
||||
can_execute = False
|
||||
@@ -5678,298 +5683,327 @@ class WorkflowEngine:
|
||||
)
|
||||
break
|
||||
if can_execute:
|
||||
next_node_id = node_id
|
||||
ready_nodes.append(node_id)
|
||||
logger.info(
|
||||
f"[rjb] 选择执行节点: {next_node_id}, 类型: {self.nodes[next_node_id].get('type')}, 入边数: {len(incoming_edges)}"
|
||||
f"[rjb] 就绪节点: {node_id}, 类型: {self.nodes[node_id].get('type')}"
|
||||
)
|
||||
break
|
||||
|
||||
if not next_node_id:
|
||||
if not ready_nodes:
|
||||
break # 没有更多节点可执行
|
||||
|
||||
node = self.nodes[next_node_id]
|
||||
is_approval = node.get("type") == "approval"
|
||||
if not is_approval:
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
|
||||
# 调试:检查节点数据结构
|
||||
if node.get('type') == 'llm':
|
||||
logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")
|
||||
|
||||
# 获取节点输入(使用活跃的边)
|
||||
node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
|
||||
|
||||
# 如果是起始节点,使用初始输入
|
||||
if node.get('type') == 'start' and not node_input:
|
||||
node_input = input_data
|
||||
logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}")
|
||||
|
||||
# 调试:记录节点输入数据
|
||||
if node.get('type') == 'llm':
|
||||
logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
|
||||
if 'start-1' in self.node_outputs:
|
||||
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}")
|
||||
# 标记所有就绪节点为已执行(approval 节点除外)
|
||||
for nid in ready_nodes:
|
||||
node = self.nodes[nid]
|
||||
if node.get("type") != "approval":
|
||||
executed_nodes.add(nid)
|
||||
execution_sequence.append(nid)
|
||||
|
||||
# 单执行步数预算(每执行一个节点计 1 步)
|
||||
self._steps_used += 1
|
||||
# 获取所有就绪节点的输入
|
||||
node_inputs: dict = {}
|
||||
for nid in ready_nodes:
|
||||
node = self.nodes[nid]
|
||||
node_input = self.get_node_input(nid, self.node_outputs, active_edges)
|
||||
if node.get('type') == 'start' and not node_input:
|
||||
node_input = input_data
|
||||
node_inputs[nid] = node_input
|
||||
|
||||
# 预算检查(整批节点)
|
||||
self._steps_used += len(ready_nodes)
|
||||
if self._steps_used > self._cap_steps:
|
||||
raise WorkflowExecutionError(
|
||||
detail=f"已超过单执行预算上限({self._cap_steps} 步),已熔断",
|
||||
node_id=next_node_id,
|
||||
node_id=ready_nodes[0],
|
||||
)
|
||||
|
||||
# 执行节点
|
||||
result = await self.execute_node(node, node_input)
|
||||
# 并行执行所有就绪节点
|
||||
async def _exec_one(nid: str):
|
||||
try:
|
||||
return await self.execute_node(self.nodes[nid], node_inputs[nid])
|
||||
except Exception as exc:
|
||||
return {"status": "failed", "error": str(exc)}
|
||||
|
||||
if result.get("status") == "awaiting_approval":
|
||||
self._steps_used -= 1
|
||||
snap = self._build_pause_snapshot(
|
||||
next_node_id, active_edges, executed_nodes, execution_sequence, results
|
||||
)
|
||||
raise WorkflowPaused(snap)
|
||||
|
||||
if is_approval:
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
|
||||
results[next_node_id] = result
|
||||
|
||||
# 保存节点输出
|
||||
if result.get('status') == 'success':
|
||||
output_value = result.get('output', {})
|
||||
self.node_outputs[next_node_id] = output_value
|
||||
if node.get('type') == 'start':
|
||||
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
|
||||
|
||||
# 如果是条件节点或Switch节点,根据分支结果过滤边
|
||||
if node.get('type') == 'condition':
|
||||
branch = result.get('branch', 'false')
|
||||
logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||||
# 移除不符合条件的边
|
||||
# 只保留:1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边
|
||||
edges_to_remove = []
|
||||
edges_to_keep = []
|
||||
for edge in active_edges:
|
||||
if edge['source'] == next_node_id:
|
||||
# 这是从条件节点出发的边
|
||||
edge_handle = edge.get('sourceHandle')
|
||||
if edge_handle == branch:
|
||||
# sourceHandle匹配分支,保留
|
||||
edges_to_keep.append(edge)
|
||||
logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
|
||||
else:
|
||||
# sourceHandle不匹配或为None,移除
|
||||
edges_to_remove.append(edge)
|
||||
logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})")
|
||||
else:
|
||||
# 不是从条件节点出发的边,保留
|
||||
edges_to_keep.append(edge)
|
||||
|
||||
active_edges = edges_to_keep
|
||||
|
||||
elif node.get('type') == 'switch':
|
||||
branch = result.get('branch', 'default')
|
||||
logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||||
|
||||
# 记录过滤前的边信息
|
||||
edges_before = [e for e in active_edges if e['source'] == next_node_id]
|
||||
logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}条")
|
||||
for edge in edges_before:
|
||||
logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}")
|
||||
|
||||
# 移除不匹配的边
|
||||
edges_to_keep = []
|
||||
edges_removed_count = 0
|
||||
removed_source_nodes = set() # 记录被移除边的源节点
|
||||
|
||||
for edge in active_edges:
|
||||
if edge['source'] == next_node_id:
|
||||
# 这是从Switch节点出发的边
|
||||
edge_handle = edge.get('sourceHandle')
|
||||
if edge_handle == branch:
|
||||
# sourceHandle匹配分支,保留
|
||||
edges_to_keep.append(edge)
|
||||
logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
|
||||
else:
|
||||
# sourceHandle不匹配,移除
|
||||
edges_removed_count += 1
|
||||
target_id = edge.get('target')
|
||||
removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达)
|
||||
logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})")
|
||||
else:
|
||||
# 不是从Switch节点出发的边,保留
|
||||
edges_to_keep.append(edge)
|
||||
|
||||
# 重要:移除那些指向被过滤节点的边(这些边来自被过滤的LLM节点)
|
||||
# 例如:如果llm-question被过滤了,那么llm-question → merge-response的边也应该被移除
|
||||
additional_removed = 0
|
||||
for edge in list(edges_to_keep): # 使用list副本,因为我们要修改原列表
|
||||
if edge['source'] in removed_source_nodes:
|
||||
# 这条边来自被过滤的节点,也应该被移除
|
||||
edges_to_keep.remove(edge)
|
||||
additional_removed += 1
|
||||
logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')} → {edge.get('target')})")
|
||||
|
||||
edges_removed_count += additional_removed
|
||||
|
||||
active_edges = edges_to_keep
|
||||
filter_info = {
|
||||
'branch': branch,
|
||||
'edges_before': len(edges_before),
|
||||
'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]),
|
||||
'edges_removed': edges_removed_count
|
||||
}
|
||||
logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边(其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边")
|
||||
# 记录过滤后的活跃边
|
||||
remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id]
|
||||
logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}")
|
||||
|
||||
# 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接)
|
||||
removed_targets = set()
|
||||
for edge in edges_before:
|
||||
if edge not in edges_to_keep:
|
||||
target_id = edge.get('target')
|
||||
removed_targets.add(target_id)
|
||||
logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行")
|
||||
|
||||
# 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中
|
||||
# 这样在下次循环时,这些节点就不会被选择执行
|
||||
logger.info(f"[rjb] Switch节点过滤后,重新构建执行图(排除 {len(removed_targets)} 个不再可达的节点)")
|
||||
|
||||
# 同时记录到数据库
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边",
|
||||
node_id=next_node_id,
|
||||
node_type='switch',
|
||||
data=filter_info
|
||||
)
|
||||
|
||||
elif node.get('type') == 'approval':
|
||||
branch = result.get('branch', 'approved')
|
||||
logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||||
edges_to_keep = []
|
||||
removed_source_nodes = set()
|
||||
for edge in active_edges:
|
||||
if edge['source'] == next_node_id:
|
||||
edge_handle = edge.get('sourceHandle')
|
||||
if edge_handle == branch:
|
||||
edges_to_keep.append(edge)
|
||||
else:
|
||||
removed_source_nodes.add(edge.get('target'))
|
||||
else:
|
||||
edges_to_keep.append(edge)
|
||||
for edge in list(edges_to_keep):
|
||||
if edge['source'] in removed_source_nodes:
|
||||
edges_to_keep.remove(edge)
|
||||
active_edges = edges_to_keep
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
f"Approval节点分支过滤: branch={branch}",
|
||||
node_id=next_node_id,
|
||||
node_type='approval',
|
||||
)
|
||||
|
||||
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
|
||||
if node.get('type') in ['loop', 'foreach']:
|
||||
# 标记循环体的节点为已执行(简化处理)
|
||||
for edge in active_edges[:]: # 使用切片复制列表
|
||||
if edge.get('source') == next_node_id:
|
||||
target_id = edge.get('target')
|
||||
if target_id in self.nodes:
|
||||
# 检查是否是循环结束节点
|
||||
target_node = self.nodes[target_id]
|
||||
if target_node.get('type') not in ['loop_end', 'end']:
|
||||
# 标记为已执行(循环体已在循环节点内部执行)
|
||||
executed_nodes.add(target_id)
|
||||
# 继续查找循环体内的节点
|
||||
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
|
||||
if len(ready_nodes) == 1:
|
||||
batch_results = {ready_nodes[0]: await _exec_one(ready_nodes[0])}
|
||||
else:
|
||||
# 执行失败或质量不达标 — 支持重试
|
||||
failed_status = result.get('status', 'failed')
|
||||
error_msg = result.get('error', '未知错误')
|
||||
node_type = node.get('type', 'unknown')
|
||||
logger.info(f"[rjb] 并行执行 {len(ready_nodes)} 个节点: {ready_nodes}")
|
||||
tasks = [_exec_one(nid) for nid in ready_nodes]
|
||||
gathered = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
batch_results = {}
|
||||
for i, nid in enumerate(ready_nodes):
|
||||
r = gathered[i]
|
||||
if isinstance(r, BaseException):
|
||||
batch_results[nid] = {"status": "failed", "error": str(r)}
|
||||
else:
|
||||
batch_results[nid] = r
|
||||
|
||||
# 处理 error_handler 返回的 retry_predecessor
|
||||
if failed_status == 'retry_predecessor':
|
||||
pred_id = result.get('predecessor_id')
|
||||
eh_retry_count = result.get('retry_count', 1)
|
||||
eh_retry_delay_ms = result.get('retry_delay_ms', 1000)
|
||||
# 逐一处理各节点结果
|
||||
for next_node_id in ready_nodes:
|
||||
node = self.nodes[next_node_id]
|
||||
result = batch_results[next_node_id]
|
||||
is_approval = node.get("type") == "approval"
|
||||
|
||||
# 检查是否已有重试计数(多次经过 error_handler)
|
||||
prev_counter = self.node_outputs.get(f"_eh_retry_{pred_id}", {})
|
||||
remaining = prev_counter.get("remaining", eh_retry_count)
|
||||
if pred_id and pred_id in self.nodes and remaining > 0:
|
||||
remaining -= 1
|
||||
logger.warning(
|
||||
f"error_handler 请求重试前驱节点 {pred_id},剩余 {remaining} 次"
|
||||
)
|
||||
self.executed_nodes.discard(pred_id)
|
||||
self.node_outputs.pop(pred_id, None)
|
||||
self.node_outputs[f"_eh_retry_{pred_id}"] = {
|
||||
"remaining": remaining,
|
||||
"delay_ms": eh_retry_delay_ms,
|
||||
"error_handler_id": next_node_id,
|
||||
}
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
results[next_node_id] = result
|
||||
await asyncio.sleep(eh_retry_delay_ms / 1000.0)
|
||||
continue
|
||||
elif remaining <= 0:
|
||||
logger.error(f"error_handler 重试次数耗尽,前驱节点 {pred_id} 停止")
|
||||
raise WorkflowExecutionError(
|
||||
detail=f"error_handler 重试次数耗尽: {error_msg}",
|
||||
node_id=pred_id,
|
||||
)
|
||||
|
||||
# 检查节点级 retry_config
|
||||
node_data = node.get('data', {}) or {}
|
||||
retry_cfg = node_data.get('retry_config', {})
|
||||
max_retries = retry_cfg.get('max_retries', 0) if isinstance(retry_cfg, dict) else 0
|
||||
retry_delay_ms = retry_cfg.get('retry_delay_ms', 1000) if isinstance(retry_cfg, dict) else 1000
|
||||
on_exhausted = retry_cfg.get('on_exhausted', 'stop') if isinstance(retry_cfg, dict) else 'stop'
|
||||
|
||||
retry_key = f"_retry_{next_node_id}"
|
||||
retries_done = self.node_outputs.get(retry_key, 0)
|
||||
|
||||
if max_retries > 0 and retries_done < max_retries:
|
||||
self.node_outputs[retry_key] = retries_done + 1
|
||||
logger.warning(
|
||||
f"节点 {next_node_id} ({node_type}) 执行失败,重试 {retries_done + 1}/{max_retries}: {error_msg}"
|
||||
if result.get("status") == "awaiting_approval":
|
||||
self._steps_used -= 1
|
||||
snap = self._build_pause_snapshot(
|
||||
next_node_id, active_edges, executed_nodes, execution_sequence, results
|
||||
)
|
||||
await asyncio.sleep(retry_delay_ms / 1000.0)
|
||||
continue # 不标记已执行,下次循环重新执行
|
||||
raise WorkflowPaused(snap)
|
||||
|
||||
# 重试耗尽或未配置重试
|
||||
if retries_done >= max_retries and max_retries > 0:
|
||||
if on_exhausted == 'skip':
|
||||
logger.warning(f"节点 {next_node_id} 重试耗尽,跳过: {error_msg}")
|
||||
self.node_outputs[next_node_id] = {
|
||||
'status': 'skipped', 'error': error_msg
|
||||
}
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
results[next_node_id] = result
|
||||
continue
|
||||
elif on_exhausted == 'notify':
|
||||
logger.error(f"节点 {next_node_id} 重试耗尽,已通知: {error_msg}")
|
||||
self.node_outputs[next_node_id] = {
|
||||
'status': 'error_notified', 'error': error_msg
|
||||
}
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
results[next_node_id] = result
|
||||
continue
|
||||
if is_approval:
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
|
||||
results[next_node_id] = result
|
||||
|
||||
# 保存节点输出
|
||||
if result.get('status') == 'success':
|
||||
output_value = result.get('output', {})
|
||||
self.node_outputs[next_node_id] = output_value
|
||||
if node.get('type') == 'start':
|
||||
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
|
||||
|
||||
# 如果是条件节点或Switch节点,根据分支结果过滤边
|
||||
if node.get('type') == 'condition':
|
||||
branch = result.get('branch', 'false')
|
||||
logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||||
# 移除不符合条件的边
|
||||
# 只保留:1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边
|
||||
edges_to_remove = []
|
||||
edges_to_keep = []
|
||||
for edge in active_edges:
|
||||
if edge['source'] == next_node_id:
|
||||
# 这是从条件节点出发的边
|
||||
edge_handle = edge.get('sourceHandle')
|
||||
if edge_handle == branch:
|
||||
# sourceHandle匹配分支,保留
|
||||
edges_to_keep.append(edge)
|
||||
logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
|
||||
else:
|
||||
# sourceHandle不匹配或为None,移除
|
||||
edges_to_remove.append(edge)
|
||||
logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})")
|
||||
else:
|
||||
# 不是从条件节点出发的边,保留
|
||||
edges_to_keep.append(edge)
|
||||
|
||||
active_edges = edges_to_keep
|
||||
|
||||
elif node.get('type') == 'switch':
|
||||
branch = result.get('branch', 'default')
|
||||
logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||||
|
||||
# 记录过滤前的边信息
|
||||
edges_before = [e for e in active_edges if e['source'] == next_node_id]
|
||||
logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}条")
|
||||
for edge in edges_before:
|
||||
logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}")
|
||||
|
||||
# 移除不匹配的边
|
||||
edges_to_keep = []
|
||||
edges_removed_count = 0
|
||||
removed_source_nodes = set() # 记录被移除边的源节点
|
||||
|
||||
for edge in active_edges:
|
||||
if edge['source'] == next_node_id:
|
||||
# 这是从Switch节点出发的边
|
||||
edge_handle = edge.get('sourceHandle')
|
||||
if edge_handle == branch:
|
||||
# sourceHandle匹配分支,保留
|
||||
edges_to_keep.append(edge)
|
||||
logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
|
||||
else:
|
||||
# sourceHandle不匹配,移除
|
||||
edges_removed_count += 1
|
||||
target_id = edge.get('target')
|
||||
removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达)
|
||||
logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})")
|
||||
else:
|
||||
# 不是从Switch节点出发的边,保留
|
||||
edges_to_keep.append(edge)
|
||||
|
||||
# 重要:移除那些指向被过滤节点的边(这些边来自被过滤的LLM节点)
|
||||
# 例如:如果llm-question被过滤了,那么llm-question → merge-response的边也应该被移除
|
||||
additional_removed = 0
|
||||
for edge in list(edges_to_keep): # 使用list副本,因为我们要修改原列表
|
||||
if edge['source'] in removed_source_nodes:
|
||||
# 这条边来自被过滤的节点,也应该被移除
|
||||
edges_to_keep.remove(edge)
|
||||
additional_removed += 1
|
||||
logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')} → {edge.get('target')})")
|
||||
|
||||
edges_removed_count += additional_removed
|
||||
|
||||
active_edges = edges_to_keep
|
||||
filter_info = {
|
||||
'branch': branch,
|
||||
'edges_before': len(edges_before),
|
||||
'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]),
|
||||
'edges_removed': edges_removed_count
|
||||
}
|
||||
logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边(其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边")
|
||||
# 记录过滤后的活跃边
|
||||
remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id]
|
||||
logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}")
|
||||
|
||||
# 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接)
|
||||
removed_targets = set()
|
||||
for edge in edges_before:
|
||||
if edge not in edges_to_keep:
|
||||
target_id = edge.get('target')
|
||||
removed_targets.add(target_id)
|
||||
logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行")
|
||||
|
||||
# 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中
|
||||
# 这样在下次循环时,这些节点就不会被选择执行
|
||||
logger.info(f"[rjb] Switch节点过滤后,重新构建执行图(排除 {len(removed_targets)} 个不再可达的节点)")
|
||||
|
||||
# 同时记录到数据库
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边",
|
||||
node_id=next_node_id,
|
||||
node_type='switch',
|
||||
data=filter_info
|
||||
)
|
||||
|
||||
elif node.get('type') == 'approval':
|
||||
branch = result.get('branch', 'approved')
|
||||
logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||||
edges_to_keep = []
|
||||
removed_source_nodes = set()
|
||||
for edge in active_edges:
|
||||
if edge['source'] == next_node_id:
|
||||
edge_handle = edge.get('sourceHandle')
|
||||
if edge_handle == branch:
|
||||
edges_to_keep.append(edge)
|
||||
else:
|
||||
removed_source_nodes.add(edge.get('target'))
|
||||
else:
|
||||
edges_to_keep.append(edge)
|
||||
for edge in list(edges_to_keep):
|
||||
if edge['source'] in removed_source_nodes:
|
||||
edges_to_keep.remove(edge)
|
||||
active_edges = edges_to_keep
|
||||
if self.logger:
|
||||
self.logger.info(
|
||||
f"Approval节点分支过滤: branch={branch}",
|
||||
node_id=next_node_id,
|
||||
node_type='approval',
|
||||
)
|
||||
|
||||
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
|
||||
if node.get('type') in ['loop', 'foreach']:
|
||||
# 标记循环体的节点为已执行(简化处理)
|
||||
for edge in active_edges[:]: # 使用切片复制列表
|
||||
if edge.get('source') == next_node_id:
|
||||
target_id = edge.get('target')
|
||||
if target_id in self.nodes:
|
||||
# 检查是否是循环结束节点
|
||||
target_node = self.nodes[target_id]
|
||||
if target_node.get('type') not in ['loop_end', 'end']:
|
||||
# 标记为已执行(循环体已在循环节点内部执行)
|
||||
executed_nodes.add(target_id)
|
||||
# 继续查找循环体内的节点
|
||||
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
|
||||
else:
|
||||
# 执行失败或质量不达标 — 支持重试
|
||||
failed_status = result.get('status', 'failed')
|
||||
error_msg = result.get('error', '未知错误')
|
||||
node_type = node.get('type', 'unknown')
|
||||
|
||||
# 处理 error_handler 返回的 retry_predecessor
|
||||
if failed_status == 'retry_predecessor':
|
||||
pred_id = result.get('predecessor_id')
|
||||
eh_retry_count = result.get('retry_count', 1)
|
||||
eh_retry_delay_ms = result.get('retry_delay_ms', 1000)
|
||||
|
||||
# 检查是否已有重试计数(多次经过 error_handler)
|
||||
prev_counter = self.node_outputs.get(f"_eh_retry_{pred_id}", {})
|
||||
remaining = prev_counter.get("remaining", eh_retry_count)
|
||||
if pred_id and pred_id in self.nodes and remaining > 0:
|
||||
remaining -= 1
|
||||
logger.warning(
|
||||
f"error_handler 请求重试前驱节点 {pred_id},剩余 {remaining} 次"
|
||||
)
|
||||
executed_nodes.discard(pred_id)
|
||||
self.node_outputs.pop(pred_id, None)
|
||||
self.node_outputs[f"_eh_retry_{pred_id}"] = {
|
||||
"remaining": remaining,
|
||||
"delay_ms": eh_retry_delay_ms,
|
||||
"error_handler_id": next_node_id,
|
||||
}
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
results[next_node_id] = result
|
||||
await asyncio.sleep(eh_retry_delay_ms / 1000.0)
|
||||
break # 退出 for 循环,while 循环下次迭代处理重试
|
||||
elif remaining <= 0:
|
||||
logger.error(f"error_handler 重试次数耗尽,前驱节点 {pred_id} 停止")
|
||||
raise WorkflowExecutionError(
|
||||
detail=f"error_handler 重试次数耗尽: {error_msg}",
|
||||
node_id=pred_id,
|
||||
)
|
||||
|
||||
# 检查节点级 retry_config
|
||||
node_data = node.get('data', {}) or {}
|
||||
retry_cfg = node_data.get('retry_config', {})
|
||||
max_retries = retry_cfg.get('max_retries', 0) if isinstance(retry_cfg, dict) else 0
|
||||
retry_delay_ms = retry_cfg.get('retry_delay_ms', 1000) if isinstance(retry_cfg, dict) else 1000
|
||||
on_exhausted = retry_cfg.get('on_exhausted', 'stop') if isinstance(retry_cfg, dict) else 'stop'
|
||||
|
||||
retry_key = f"_retry_{next_node_id}"
|
||||
retries_done = self.node_outputs.get(retry_key, 0)
|
||||
|
||||
if max_retries > 0 and retries_done < max_retries:
|
||||
self.node_outputs[retry_key] = retries_done + 1
|
||||
logger.warning(
|
||||
f"节点 {next_node_id} ({node_type}) 执行失败,重试 {retries_done + 1}/{max_retries}: {error_msg}"
|
||||
)
|
||||
await asyncio.sleep(retry_delay_ms / 1000.0)
|
||||
executed_nodes.discard(next_node_id)
|
||||
break # 退出 for 循环,while 循环下次迭代重新执行
|
||||
|
||||
# 重试耗尽或未配置重试
|
||||
if retries_done >= max_retries and max_retries > 0:
|
||||
if on_exhausted == 'skip':
|
||||
logger.warning(f"节点 {next_node_id} 重试耗尽,跳过: {error_msg}")
|
||||
self.node_outputs[next_node_id] = {
|
||||
'status': 'skipped', 'error': error_msg
|
||||
}
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
results[next_node_id] = result
|
||||
continue
|
||||
elif on_exhausted == 'notify':
|
||||
logger.error(f"节点 {next_node_id} 重试耗尽,已通知: {error_msg}")
|
||||
self.node_outputs[next_node_id] = {
|
||||
'status': 'error_notified', 'error': error_msg
|
||||
}
|
||||
executed_nodes.add(next_node_id)
|
||||
execution_sequence.append(next_node_id)
|
||||
results[next_node_id] = result
|
||||
continue
|
||||
|
||||
# 默认:停止工作流
|
||||
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
|
||||
raise WorkflowExecutionError(
|
||||
detail=error_msg,
|
||||
node_id=next_node_id
|
||||
)
|
||||
|
||||
# 上报批次进度
|
||||
if on_progress and execution_id:
|
||||
try:
|
||||
pct = int(len(executed_nodes) / total_nodes * 100) if total_nodes else 100
|
||||
await on_progress(execution_id, {
|
||||
"current": len(executed_nodes),
|
||||
"total": total_nodes,
|
||||
"progress": min(pct, 99),
|
||||
"status": "running",
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 默认:停止工作流
|
||||
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
|
||||
raise WorkflowExecutionError(
|
||||
detail=error_msg,
|
||||
node_id=next_node_id
|
||||
)
|
||||
|
||||
# 返回最终结果:优先取 End 类型且无出边的节点,避免向量写入等侧链与 End 同为 sink 时
|
||||
# 因 executed_nodes 为 set 迭代顺序不确定而错误返回 upsert 元数据。
|
||||
if executed_nodes:
|
||||
@@ -6052,7 +6086,19 @@ class WorkflowEngine:
|
||||
# 记录工作流执行完成
|
||||
if self.logger:
|
||||
self.logger.info("工作流执行完成", data={"result": final_result.get('result')})
|
||||
|
||||
|
||||
# 上报 100% 完成进度
|
||||
if on_progress and execution_id:
|
||||
try:
|
||||
await on_progress(execution_id, {
|
||||
"current": len(executed_nodes),
|
||||
"total": total_nodes,
|
||||
"progress": 100,
|
||||
"status": "completed",
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return final_result
|
||||
|
||||
if self.logger:
|
||||
|
||||
@@ -20,7 +20,9 @@ from app.models.workflow import Workflow
|
||||
from app.services.execution_budget import merge_budget_for_execution
|
||||
from app.services.agent_workspace_chat_log import try_append_agent_dialogue_after_success
|
||||
from app.services.notification_service import create_notification
|
||||
from app.websocket.manager import websocket_manager
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
@@ -42,6 +44,29 @@ def _snapshot_to_jsonable(snapshot: dict) -> dict:
|
||||
return _json.loads(_json.dumps(snapshot, default=str))
|
||||
|
||||
|
||||
async def _on_workflow_progress(execution_id: str, progress_data: dict):
|
||||
"""工作流进度回调:写入 Redis + WebSocket 广播。"""
|
||||
try:
|
||||
from app.core.redis_client import get_redis_client
|
||||
redis_client = get_redis_client()
|
||||
if redis_client:
|
||||
redis_client.setex(
|
||||
f"workflow:progress:{execution_id}",
|
||||
300,
|
||||
json.dumps(progress_data, ensure_ascii=False),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket_manager.broadcast_to_execution(execution_id, {
|
||||
"type": "progress",
|
||||
"execution_id": execution_id,
|
||||
**progress_data,
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _trusted_user_for_execution(db, execution: Optional[Execution]) -> Optional[str]:
|
||||
"""用于校验 LLM 节点引用的 model_configs 归属(与 Workflow / Agent 所有者一致)。"""
|
||||
if not execution:
|
||||
@@ -195,7 +220,8 @@ def execute_workflow_task(
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
result = asyncio.run(
|
||||
engine.execute(input_data, resume_snapshot=resume_snapshot)
|
||||
engine.execute(input_data, resume_snapshot=resume_snapshot,
|
||||
execution_id=execution_id, on_progress=_on_workflow_progress)
|
||||
)
|
||||
break
|
||||
except WorkflowPaused as paused:
|
||||
@@ -391,7 +417,8 @@ def resume_workflow_task(
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
result = asyncio.run(
|
||||
engine.execute(base_input, resume_snapshot=snapshot)
|
||||
engine.execute(base_input, resume_snapshot=snapshot,
|
||||
execution_id=execution_id, on_progress=_on_workflow_progress)
|
||||
)
|
||||
break
|
||||
except WorkflowPaused as paused:
|
||||
|
||||
@@ -439,6 +439,7 @@ async function sendMessage() {
|
||||
|
||||
// 尝试 SSE 流式(带超时控制)
|
||||
let usedStreaming = false
|
||||
let placeholderIdx = -1
|
||||
streamingActive.value = false
|
||||
const abortController = new AbortController()
|
||||
const streamTimeout = setTimeout(() => abortController.abort(), 60000)
|
||||
@@ -462,8 +463,8 @@ async function sendMessage() {
|
||||
role: 'assistant', content: '', timestamp: Date.now(),
|
||||
steps: [], _traceOpen: true, iterations: 0, tool_calls_made: 0,
|
||||
}
|
||||
const idx = messages.value[key].push(msg) - 1
|
||||
const currentMsg = messages.value[key][idx]
|
||||
placeholderIdx = messages.value[key].push(msg) - 1
|
||||
const currentMsg = messages.value[key][placeholderIdx]
|
||||
|
||||
const reader = resp.body.getReader()
|
||||
const decoder = new TextDecoder()
|
||||
@@ -532,14 +533,19 @@ async function sendMessage() {
|
||||
sessionId.value[key] = data.session_id
|
||||
}
|
||||
streamingActive.value = false
|
||||
loading.value = false
|
||||
} else if (eventType === 'error') {
|
||||
currentMsg.content = data.content || ''
|
||||
currentMsg.status = 'error'
|
||||
streamingActive.value = false
|
||||
loading.value = false
|
||||
} else {
|
||||
console.warn('[AgentChat] 未知 SSE 事件类型:', eventType, data)
|
||||
}
|
||||
} catch { /* 跳过畸形事件 */ }
|
||||
}
|
||||
}
|
||||
clearTimeout(streamTimeout)
|
||||
}
|
||||
} catch {
|
||||
clearTimeout(streamTimeout)
|
||||
@@ -549,6 +555,11 @@ async function sendMessage() {
|
||||
}
|
||||
|
||||
if (!usedStreaming) {
|
||||
// 移除 SSE 阶段残留的占位消息(Issue #1)
|
||||
if (placeholderIdx >= 0 && placeholderIdx < messages.value[key].length) {
|
||||
messages.value[key].splice(placeholderIdx, 1)
|
||||
}
|
||||
|
||||
// 降级:标准 POST 请求
|
||||
const fallbackEndpoint = currentAgentId.value
|
||||
? `/api/v1/agent-chat/${currentAgentId.value}`
|
||||
@@ -613,9 +624,11 @@ function retryMessage(idx: number) {
|
||||
|
||||
// 查找错误消息之前的最后一条用户消息
|
||||
let userMsg = ''
|
||||
let userIdx = -1
|
||||
for (let i = idx - 1; i >= 0; i--) {
|
||||
if (msgs[i].role === 'user') {
|
||||
userMsg = msgs[i].content
|
||||
userIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -624,15 +637,15 @@ function retryMessage(idx: number) {
|
||||
return
|
||||
}
|
||||
|
||||
// 移除该错误消息及关联的用户消息
|
||||
const removeIndices: number[] = []
|
||||
for (let i = idx - 1; i >= 0; i--) {
|
||||
if (msgs[i].role === 'user' && msgs[i].content === userMsg) {
|
||||
// 收集需要移除的消息:用户消息 + 错误消息 + 中间的流式占位消息
|
||||
const removeIndices: number[] = [userIdx, idx]
|
||||
for (let i = userIdx + 1; i < idx; i++) {
|
||||
const m = msgs[i]
|
||||
// 流式占位消息:assistant 角色,内容为空或有 steps 但无实质内容
|
||||
if (m.role === 'assistant' && (!m.content || m.content === '') && (m.steps?.length || 0) > 0) {
|
||||
removeIndices.push(i)
|
||||
break
|
||||
}
|
||||
}
|
||||
removeIndices.push(idx)
|
||||
|
||||
// 从后往前删除,避免 index 错乱
|
||||
removeIndices.sort((a, b) => b - a)
|
||||
|
||||
Reference in New Issue
Block a user