- 在 `run_stream()` LLM 调用前 yield `think` 事件,前端即时显示"思考中..."
- 修复 tool schema 规范化逻辑:`{"function":{...}}` 格式缺少 `type` 字段导致 LLM API 拒绝
- 启动时从数据库加载自定义工具(`load_tools_from_db`),解决重启后工具丢失
- 前端 SSE 添加 60s 超时保护,任何事件类型均触发 `receivedFirstEvent`
- 流式失败自动降级到非流式 POST
- 添加 `scripts/seed_coding_agent.py` 和 `scripts/test_coding_agent.py`
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
5855 lines · 306 KiB · Python
"""
|
||
工作流执行引擎
|
||
"""
|
||
from typing import Dict, Any, List, Optional, Tuple
|
||
import asyncio
|
||
import hashlib
|
||
from collections import defaultdict, deque
|
||
import json
|
||
import logging
|
||
import math
|
||
import re
|
||
import time
|
||
from datetime import datetime as _datetime_class
|
||
from app.services.llm_service import llm_service
|
||
from app.services.condition_parser import condition_parser
|
||
from app.services.data_transformer import data_transformer
|
||
from app.core.exceptions import WorkflowExecutionError, WorkflowPaused
|
||
from app.core.database import SessionLocal
|
||
from app.models.agent import Agent
|
||
from app.models.execution import Execution
|
||
from app.models.workflow import Workflow
|
||
from app.services.execution_logger import ExecutionLogger
|
||
from app.core.config import settings
|
||
from app.services.scenario_dsl import normalize_scenario_dsl, validate_scenario_dsl
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 代码节点 exec 受限环境:禁止 open/eval/__import__ 等,但提供常用类型与 isinstance 等
|
||
_CODE_NODE_SAFE_BUILTINS = {
|
||
"isinstance": isinstance,
|
||
"issubclass": issubclass,
|
||
"len": len,
|
||
"str": str,
|
||
"int": int,
|
||
"float": float,
|
||
"bool": bool,
|
||
"dict": dict,
|
||
"list": list,
|
||
"tuple": tuple,
|
||
"set": set,
|
||
"frozenset": frozenset,
|
||
"range": range,
|
||
"enumerate": enumerate,
|
||
"zip": zip,
|
||
"map": map,
|
||
"filter": filter,
|
||
"sorted": sorted,
|
||
"reversed": reversed,
|
||
"min": min,
|
||
"max": max,
|
||
"sum": sum,
|
||
"abs": abs,
|
||
"round": round,
|
||
"pow": pow,
|
||
"divmod": divmod,
|
||
"all": all,
|
||
"any": any,
|
||
"chr": chr,
|
||
"ord": ord,
|
||
"repr": repr,
|
||
"hash": hash,
|
||
"slice": slice,
|
||
"object": object,
|
||
"super": super,
|
||
"property": property,
|
||
"staticmethod": staticmethod,
|
||
"classmethod": classmethod,
|
||
"Exception": Exception,
|
||
"ValueError": ValueError,
|
||
"TypeError": TypeError,
|
||
"KeyError": KeyError,
|
||
"AttributeError": AttributeError,
|
||
"IndexError": IndexError,
|
||
"StopIteration": StopIteration,
|
||
"True": True,
|
||
"False": False,
|
||
"None": None,
|
||
"json": json,
|
||
"math": math,
|
||
"hashlib": hashlib,
|
||
"datetime": _datetime_class,
|
||
}
|
||
|
||
|
||
class WorkflowEngine:
|
||
"""工作流执行引擎"""
|
||
|
||
    def __init__(
        self,
        workflow_id: str,
        workflow_data: Dict[str, Any],
        logger=None,
        db=None,
        budget_limits: Optional[Dict[str, Any]] = None,
        trusted_model_config_user_id: Optional[str] = None,
    ):
        """
        Initialize the workflow engine.

        Args:
            workflow_id: Workflow ID.
            workflow_data: Workflow definition containing ``nodes`` and ``edges``.
            logger: Execution logger (optional).
            db: Database session (optional; e.g. used by Agent nodes to load
                agent configuration).
            budget_limits: Optional per-run overrides for the step / LLM /
                tool-call caps (keys: ``max_steps``, ``max_llm_invocations``,
                ``max_tool_calls``).
            trusted_model_config_user_id: User ID allowed to load decryption
                keys for "model config" entries (normally the owner of the
                Workflow/Agent being executed).
        """
        self.workflow_id = workflow_id
        # Index nodes by id for O(1) lookup during execution.
        self.nodes = {node['id']: node for node in workflow_data.get('nodes', [])}
        self.edges = workflow_data.get('edges', [])
        self.execution_graph = None
        self.node_outputs = {}
        self.logger = logger
        self.db = db
        # Lazily-computed (scope_kind, scope_id) used for persistent memory keys.
        self._persist_scope_cache: Optional[Tuple[Optional[str], Optional[str]]] = None
        self._initial_input_data: Optional[Dict[str, Any]] = None
        # Per-run budget counters.
        self._steps_used: int = 0
        self._llm_invocations: int = 0
        self._tool_calls_used: int = 0
        self.budget_limits: Dict[str, Any] = dict(budget_limits or {})
        self.trusted_model_config_user_id: Optional[str] = trusted_model_config_user_id
        # Default caps come from settings; `or N` guards against 0/None values.
        self._cap_steps: int = max(
            1, int(getattr(settings, "WORKFLOW_MAX_STEPS_PER_RUN", 2000) or 2000)
        )
        self._cap_llm: int = max(
            1, int(getattr(settings, "WORKFLOW_MAX_LLM_INVOCATIONS_PER_RUN", 200) or 200)
        )
        self._cap_tool: int = max(
            1, int(getattr(settings, "WORKFLOW_MAX_TOOL_CALLS_PER_RUN", 500) or 500)
        )
        # Caller-supplied budget_limits override the settings-derived caps;
        # non-numeric values are silently ignored.
        for key, attr in (
            ("max_steps", "_cap_steps"),
            ("max_llm_invocations", "_cap_llm"),
            ("max_tool_calls", "_cap_tool"),
        ):
            v = self.budget_limits.get(key)
            if v is None:
                continue
            try:
                setattr(self, attr, max(1, int(v)))
            except (TypeError, ValueError):
                pass
        # Ensure built-in tools are registered no matter which entry point
        # created the engine (Celery / node test / scripts) — do not rely on
        # workflow_tasks having been imported first.
        from app.core.tools_bootstrap import ensure_builtin_tools_registered

        ensure_builtin_tools_registered()
|
||
|
||
def _json_safe_copy(self, obj: Any) -> Any:
|
||
"""将对象转为 JSON 可序列化结构再还原,避免挂起快照中的类型问题。"""
|
||
try:
|
||
return json.loads(json.dumps(obj, default=str))
|
||
except Exception:
|
||
return obj
|
||
|
||
def _build_pause_snapshot(
|
||
self,
|
||
pending_node_id: str,
|
||
active_edges: List[Dict[str, Any]],
|
||
executed_nodes: set,
|
||
execution_sequence: List[str],
|
||
results: Dict[str, Any],
|
||
) -> Dict[str, Any]:
|
||
return {
|
||
"pending_node_id": pending_node_id,
|
||
"node_outputs": self._json_safe_copy(self.node_outputs),
|
||
"active_edges": self._json_safe_copy(active_edges),
|
||
"executed_nodes": list(executed_nodes),
|
||
"execution_sequence": list(execution_sequence),
|
||
"initial_input_data": self._json_safe_copy(self._initial_input_data or {}),
|
||
"steps_used": self._steps_used,
|
||
"llm_invocations": self._llm_invocations,
|
||
"tool_calls_used": self._tool_calls_used,
|
||
"node_results_partial": self._json_safe_copy(results),
|
||
}
|
||
|
||
async def _on_tool_executed_budget(self, tool_name: str) -> None:
|
||
"""LLM function calling 每执行一次工具时回调,计入预算。"""
|
||
_ = tool_name
|
||
self._tool_calls_used += 1
|
||
if self._tool_calls_used > self._cap_tool:
|
||
raise WorkflowExecutionError(
|
||
detail=f"已超过工具调用预算({self._cap_tool} 次)",
|
||
)
|
||
|
||
    def _resolve_llm_credentials_from_model_config(
        self, node_data: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        If the node data carries a ``model_config_id``, load the encrypted API
        key from the database and validate ownership.

        Returns:
            ``{"api_key", "base_url", "provider", "model"}`` on success, or
            None when no config id is set or no trusted user is bound.

        Raises:
            ValueError: config missing, owned by a different user, invalid
                API key / model name, or unsupported provider.
        """
        raw = node_data.get("model_config_id") or node_data.get("modelConfigId")
        if raw is None or raw == "":
            return None
        cfg_id = str(raw).strip()
        if not cfg_id:
            return None

        # Without a trusted user id we must not decrypt anyone's stored key;
        # fall back to node/env-provided credentials.
        if not self.trusted_model_config_user_id:
            logger.warning(
                "LLM 节点配置了 model_config_id=%s,但未绑定 trusted_model_config_user_id,"
                "将跳过模型配置密钥注入(仍使用节点与环境变量)。",
                cfg_id,
            )
            return None

        # Imported lazily to avoid module-load import cycles.
        from app.models.model_config import ModelConfig
        from app.services.encryption_service import EncryptionService

        db = self.db or SessionLocal()
        own_db = self.db is None  # only close sessions we opened ourselves
        try:
            cfg = db.query(ModelConfig).filter(ModelConfig.id == cfg_id).first()
            if not cfg:
                raise ValueError(f"模型配置不存在: {cfg_id}")
            # Ownership check: only the creator may use this config.
            if cfg.user_id != self.trusted_model_config_user_id:
                raise ValueError("无权使用该模型配置(仅创建者可调用)")

            api_key_plain = EncryptionService.decrypt(cfg.api_key)
            if not (api_key_plain or "").strip():
                raise ValueError("模型配置中的 API Key 无效")

            base_url = (cfg.base_url or "").strip() or None
            raw_prov = (cfg.provider or "").strip().lower()
            # "local" deployments speak the OpenAI-compatible protocol.
            if raw_prov == "local":
                llm_prov = "openai"
            elif raw_prov in ("openai", "deepseek"):
                llm_prov = raw_prov
            elif raw_prov == "anthropic":
                raise ValueError(
                    "当前 LLM 节点暂不支持从模型配置加载 Anthropic;请改用 OpenAI 或 DeepSeek 兼容配置"
                )
            else:
                raise ValueError(f"不支持的模型配置提供商: {cfg.provider}")

            model_name = (cfg.model_name or "").strip()
            if not model_name:
                raise ValueError("模型配置中模型名称为空")

            return {
                "api_key": api_key_plain.strip(),
                "base_url": base_url,
                "provider": llm_prov,
                "model": model_name,
            }
        finally:
            if own_db:
                db.close()
|
||
|
||
def _get_persist_scope(self) -> Tuple[Optional[str], Optional[str]]:
|
||
"""(scope_kind, scope_id) 或 (None, None),用于持久化 user_memory_*。"""
|
||
if self._persist_scope_cache is None:
|
||
from app.services.persistent_memory_service import parse_memory_scope
|
||
|
||
self._persist_scope_cache = parse_memory_scope(self.workflow_id)
|
||
return self._persist_scope_cache
|
||
|
||
def _build_subworkflow_input(
|
||
self, input_data: Dict[str, Any], input_mapping: Any
|
||
) -> Dict[str, Any]:
|
||
"""根据 mapping 组装子工作流输入。"""
|
||
if not isinstance(input_mapping, dict):
|
||
return input_data if isinstance(input_data, dict) else {"input": input_data}
|
||
|
||
sub_input: Dict[str, Any] = {}
|
||
for k, v in input_mapping.items():
|
||
if isinstance(v, str) and isinstance(input_data, dict):
|
||
# 支持字段名、{field}、以及嵌套路径
|
||
vv = (
|
||
input_data.get(v)
|
||
or input_data.get(v.strip("{}"))
|
||
or self._get_nested_value(input_data, v.strip("{}"))
|
||
)
|
||
sub_input[k] = vv if vv is not None else v
|
||
else:
|
||
sub_input[k] = v
|
||
return sub_input
|
||
|
||
    def _resolve_subworkflow_target(
        self, node_data: Dict[str, Any]
    ) -> Tuple[str, str, Dict[str, Any]]:
        """
        Resolve the sub-workflow target for a subworkflow node.

        Returns:
            (target_type, target_id, workflow_data), where target_type is
            ``"workflow"`` or ``"agent"`` and workflow_data holds ``nodes`` /
            ``edges``. An agent_id takes precedence over a workflow_id.

        Raises:
            ValueError: neither id was provided, or the target does not exist.
        """
        workflow_id = str(node_data.get("workflow_id") or "").strip()
        # Accept both "agent_id" and the legacy "target_agent_id" field.
        agent_id = str(
            node_data.get("agent_id")
            or node_data.get("target_agent_id")
            or ""
        ).strip()
        if not workflow_id and not agent_id:
            raise ValueError("subworkflow 节点缺少 workflow_id 或 agent_id")

        db = self.db or SessionLocal()
        own_db = self.db is None  # only close sessions we opened ourselves
        try:
            if agent_id:
                agent = db.query(Agent).filter(Agent.id == agent_id).first()
                if not agent:
                    raise ValueError(f"目标 Agent 不存在: {agent_id}")
                cfg = agent.workflow_config or {}
                return "agent", agent_id, {
                    "nodes": cfg.get("nodes", []),
                    "edges": cfg.get("edges", []),
                }

            wf = db.query(Workflow).filter(Workflow.id == workflow_id).first()
            if not wf:
                raise ValueError(f"目标工作流不存在: {workflow_id}")
            return "workflow", workflow_id, {
                "nodes": wf.nodes or [],
                "edges": wf.edges or [],
            }
        finally:
            if own_db:
                db.close()
|
||
|
||
def _looks_like_vector_upsert_payload(self, d: Any) -> bool:
|
||
"""判断是否为向量写入/upsert 返回的元数据(非用户可见话术)。"""
|
||
if not isinstance(d, dict):
|
||
return False
|
||
st = d.get("status")
|
||
if st in ("upserted", "inserted", "updated", "ok") and d.get("id") is not None:
|
||
return True
|
||
i = d.get("id")
|
||
if isinstance(i, str) and i.startswith("doc_") and st:
|
||
return True
|
||
return False
|
||
|
||
def _parse_zhini_final_json_dict(self, text: str) -> Optional[Tuple[Dict[str, Any], str]]:
|
||
"""
|
||
解析知你类 LLM 输出:整段为单行/多行合法 JSON,或「自然语言 + 最后一行单行 JSON」。
|
||
返回 (解析出的 dict, 末行 JSON 之前的前缀文本);整段仅为 JSON 时前缀为 ""。
|
||
解析失败返回 None。
|
||
"""
|
||
if not isinstance(text, str):
|
||
return None
|
||
s = text.strip()
|
||
if not s:
|
||
return None
|
||
if s.startswith("{"):
|
||
try:
|
||
o = json.loads(s)
|
||
if isinstance(o, dict):
|
||
return (o, "")
|
||
except Exception:
|
||
pass
|
||
last_nl = s.rfind("\n")
|
||
if last_nl < 0:
|
||
return None
|
||
prefix = s[:last_nl].rstrip()
|
||
last_line = s[last_nl + 1 :].strip()
|
||
if not last_line.startswith("{"):
|
||
return None
|
||
try:
|
||
o = json.loads(last_line)
|
||
if isinstance(o, dict):
|
||
return (o, prefix)
|
||
except Exception:
|
||
pass
|
||
return None
|
||
|
||
def _parse_reply_from_llm_value(self, out: Any) -> Optional[str]:
|
||
"""从 LLM 节点输出(JSON 字符串、纯文本或 dict)中取出可展示回复。"""
|
||
if out is None:
|
||
return None
|
||
if isinstance(out, dict):
|
||
r = out.get("reply")
|
||
if isinstance(r, str) and r.strip():
|
||
return r.strip()
|
||
try:
|
||
return json.dumps(out, ensure_ascii=False)
|
||
except Exception:
|
||
return str(out)
|
||
if isinstance(out, str):
|
||
s = out.strip()
|
||
if not s:
|
||
return None
|
||
zj = self._parse_zhini_final_json_dict(s)
|
||
if zj is not None:
|
||
obj, prefix = zj
|
||
r = obj.get("reply")
|
||
if isinstance(r, str) and r.strip():
|
||
return r.strip()
|
||
if prefix != "":
|
||
return prefix
|
||
return s
|
||
if s.startswith("{"):
|
||
try:
|
||
obj = json.loads(s)
|
||
except Exception:
|
||
return s
|
||
if isinstance(obj, dict):
|
||
r = obj.get("reply")
|
||
if isinstance(r, str) and r.strip():
|
||
return r.strip()
|
||
return s
|
||
return s
|
||
return None
|
||
|
||
def _extract_reply_from_llm_node_outputs(self) -> Optional[str]:
|
||
"""遍历已执行节点,优先从 llm-unified / llm 节点解析 reply。"""
|
||
items = list(self.node_outputs.items())
|
||
|
||
def sort_key(item: tuple) -> tuple:
|
||
nid = item[0].lower()
|
||
if "llm-unified" in nid:
|
||
return (0, nid)
|
||
if "llm" in nid:
|
||
return (1, nid)
|
||
return (2, nid)
|
||
|
||
for node_id, out in sorted(items, key=sort_key):
|
||
node = self.nodes.get(node_id) or {}
|
||
ntype = (node.get("type") or "").lower()
|
||
if ntype == "llm" or "llm" in node_id.lower():
|
||
got = self._parse_reply_from_llm_value(out)
|
||
if got:
|
||
return got
|
||
for _, out in items:
|
||
got = self._parse_reply_from_llm_value(out)
|
||
if got:
|
||
return got
|
||
return None
|
||
|
||
def _looks_like_unresolved_template(self, s: Any) -> bool:
|
||
"""是否为未替换的 {{...}} 占位串(含 {{memory.xxx}} 等)。"""
|
||
if not isinstance(s, str):
|
||
return False
|
||
t = s.strip()
|
||
if not t:
|
||
return False
|
||
return bool(re.fullmatch(r"\{\{\s*[\w.]+\s*\}\}", t))
|
||
|
||
def _coalesce_final_user_text(self) -> Optional[str]:
|
||
"""从已执行 LLM 节点取可展示文本,供替换滞留在 End/result 上的模板字面量。"""
|
||
fb = self._extract_reply_from_llm_node_outputs()
|
||
if isinstance(fb, str) and fb.strip() and not self._looks_like_unresolved_template(fb):
|
||
return fb.strip()
|
||
|
||
ranked: List[Tuple[int, str, str]] = []
|
||
for node_id, out in self.node_outputs.items():
|
||
node = self.nodes.get(node_id) or {}
|
||
ntype = (node.get("type") or "").lower()
|
||
if ntype != "llm" and "llm" not in node_id.lower():
|
||
continue
|
||
if not isinstance(out, str):
|
||
continue
|
||
s = out.strip()
|
||
if not s or self._looks_like_unresolved_template(s):
|
||
continue
|
||
pri = 0 if "llm-unified" in node_id else 1
|
||
ranked.append((pri, node_id, s))
|
||
ranked.sort(key=lambda x: (x[0], x[1]))
|
||
for _, _, s in ranked:
|
||
return s
|
||
return None
|
||
|
||
def _replace_if_template_placeholder(self, final_output: Any) -> Any:
|
||
"""若为未解析模板串,替换为 LLM 节点正文(字符串)。"""
|
||
if not isinstance(final_output, str):
|
||
return final_output
|
||
if not self._looks_like_unresolved_template(final_output):
|
||
return final_output
|
||
alt = self._coalesce_final_user_text()
|
||
if alt:
|
||
logger.info(
|
||
"[rjb] 模板占位符「%s」已替换为 LLM 输出(节选)",
|
||
final_output.strip()[:48],
|
||
)
|
||
return alt
|
||
return final_output
|
||
|
||
def _extract_user_profile_from_llm_node_outputs(self) -> Optional[Dict[str, Any]]:
|
||
"""从已执行 LLM 节点 JSON 输出中取 user_profile(用于缓存合并)。"""
|
||
for node_id, out in self.node_outputs.items():
|
||
if "llm" not in node_id.lower():
|
||
continue
|
||
if not isinstance(out, str):
|
||
continue
|
||
zj = self._parse_zhini_final_json_dict(out.strip())
|
||
if zj is None:
|
||
continue
|
||
obj, _ = zj
|
||
up = obj.get("user_profile")
|
||
if isinstance(up, dict):
|
||
return dict(up)
|
||
return None
|
||
|
||
def _memory_needs_backfill(self, mem: Any) -> bool:
|
||
"""上游若传了 memory: {} 或仅占位空对象,应允许从 cache-query 等节点补全。"""
|
||
if mem is None:
|
||
return True
|
||
if not isinstance(mem, dict):
|
||
return True
|
||
if not mem:
|
||
return True
|
||
up = mem.get("user_profile")
|
||
hist = mem.get("conversation_history")
|
||
ctx = mem.get("context")
|
||
has_up = isinstance(up, dict) and bool(up)
|
||
has_hist = isinstance(hist, list) and len(hist) > 0
|
||
has_ctx = isinstance(ctx, dict) and bool(ctx)
|
||
return not (has_up or has_hist or has_ctx)
|
||
|
||
def _extract_user_message_text(self, input_data: Any) -> str:
|
||
"""
|
||
从节点输入中提取用户当前轮发言。
|
||
需与 LLM 节点 user_query 提取路径一致:带 sourceHandle 时 query 常在 right/嵌套 input 下,
|
||
否则姓名补全与记忆写入会拿不到原文。
|
||
"""
|
||
if not isinstance(input_data, dict):
|
||
return ""
|
||
nested_input = input_data.get("input")
|
||
if isinstance(nested_input, dict):
|
||
for key in ("query", "input", "text", "message", "content", "user_input", "USER_INPUT"):
|
||
v = nested_input.get(key)
|
||
if isinstance(v, str) and v.strip():
|
||
return v
|
||
for key in ("query", "input", "text", "message", "content", "user_input", "USER_INPUT"):
|
||
if key not in input_data:
|
||
continue
|
||
value = input_data[key]
|
||
if isinstance(value, str) and value.strip():
|
||
return value
|
||
if isinstance(value, dict):
|
||
for sub_key in ("query", "input", "text", "message", "content", "user_input", "USER_INPUT"):
|
||
sv = value.get(sub_key)
|
||
if isinstance(sv, str) and sv.strip():
|
||
return sv
|
||
for bucket in ("right", "left", "output", "data"):
|
||
b = input_data.get(bucket)
|
||
if isinstance(b, dict):
|
||
for key in ("query", "USER_INPUT", "user_input", "input", "text", "message", "content"):
|
||
v = b.get(key)
|
||
if isinstance(v, str) and v.strip():
|
||
return v
|
||
return ""
|
||
|
||
    def _enrich_llm_json_user_profile(self, result: str, input_data: Any) -> str:
        """
        If the LLM output ends in a single-line JSON object (possibly preceded
        by prose) whose ``user_profile`` lacks a name, backfill the name from
        upstream memory or the user's own message, so cache merging and
        multi-turn name recall keep working.

        Tool mode commonly produces "multi-line explanation + final JSON line";
        the final line must be parsed or user_profile can never be merged.

        Returns the (possibly rewritten) result string; the JSON line is always
        re-serialized once a parse succeeded.
        """
        if not isinstance(result, str) or not result.strip():
            return result
        zj = self._parse_zhini_final_json_dict(result.strip())
        if zj is None:
            return result
        obj, prefix = zj
        if not isinstance(obj, dict):
            return result
        up = obj.get("user_profile")
        if not isinstance(up, dict):
            # Ensure a user_profile dict exists so later merges have a target.
            up = {}
            obj["user_profile"] = up
        if up.get("name"):
            # Name already present: just re-serialize and return.
            new_line = json.dumps(obj, ensure_ascii=False)
            return f"{prefix}\n{new_line}" if prefix != "" else new_line
        # Merge the profile the upstream cache already holds (top level or
        # memory.user_profile) so a model that omitted it doesn't "forget".
        if isinstance(input_data, dict):
            stored_profile = None
            if isinstance(input_data.get("user_profile"), dict) and input_data["user_profile"]:
                stored_profile = input_data["user_profile"]
            elif isinstance(input_data.get("memory"), dict):
                mu = input_data["memory"].get("user_profile")
                if isinstance(mu, dict) and mu:
                    stored_profile = mu
            if stored_profile:
                # Model-written fields (up) win over stored ones on conflict.
                merged = {**stored_profile, **up}
                obj["user_profile"] = merged
                up = merged
                if up.get("name"):
                    new_line = json.dumps(obj, ensure_ascii=False)
                    return f"{prefix}\n{new_line}" if prefix != "" else new_line
        # Last resort: try to pull a self-introduced name out of the user's
        # current message ("我叫…", "我的名字是…", "可以叫我…").
        q = self._extract_user_message_text(input_data)
        if not q:
            new_line = json.dumps(obj, ensure_ascii=False)
            return f"{prefix}\n{new_line}" if prefix != "" else new_line
        for pat in (
            r"我叫\s*([^\s,。!?,.!?]{2,32})",
            r"我的名字?(?:是|叫)?\s*([^\s,。!?,.!?]{2,32})",
            r"(?:称呼我|可以叫我)\s*([^\s,。!?,.!?]{2,32})",
        ):
            m = re.search(pat, q)
            if m:
                name = m.group(1).strip().strip(",。!?,.!?")
                # Reject too-short captures and question words ("叫什么" etc.).
                if not name or len(name) < 2:
                    continue
                if "什么" in name or name in ("谁", "哪位", "什么人"):
                    continue
                up["name"] = name
                new_line = json.dumps(obj, ensure_ascii=False)
                return f"{prefix}\n{new_line}" if prefix != "" else new_line
        new_line = json.dumps(obj, ensure_ascii=False)
        return f"{prefix}\n{new_line}" if prefix != "" else new_line
|
||
|
||
def _resolve_end_output_if_vector_metadata(self, final_output: Any, input_data: Any) -> Any:
|
||
"""
|
||
End 节点误接「写入向量库」等上游时,对外会变成 upsert 元数据。
|
||
若检测到该情况,则改为 LLM 输出中的 reply 文本。
|
||
"""
|
||
upsert_like = False
|
||
if self._looks_like_vector_upsert_payload(input_data):
|
||
upsert_like = True
|
||
elif isinstance(input_data, str):
|
||
t = input_data.strip()
|
||
if t.startswith("{") and "upserted" in t:
|
||
try:
|
||
if self._looks_like_vector_upsert_payload(json.loads(t)):
|
||
upsert_like = True
|
||
except Exception:
|
||
pass
|
||
if not upsert_like and isinstance(final_output, dict):
|
||
upsert_like = self._looks_like_vector_upsert_payload(final_output)
|
||
if not upsert_like and isinstance(final_output, str):
|
||
t = final_output.strip()
|
||
if t.startswith("{") and "upserted" in t:
|
||
try:
|
||
if self._looks_like_vector_upsert_payload(json.loads(t)):
|
||
upsert_like = True
|
||
except Exception:
|
||
pass
|
||
if not upsert_like:
|
||
return final_output
|
||
reply = self._extract_reply_from_llm_node_outputs()
|
||
if reply:
|
||
logger.info(
|
||
"[rjb] End 节点上游为向量写入元数据,已替换为 LLM reply 长度=%s",
|
||
len(reply),
|
||
)
|
||
return reply
|
||
return final_output
|
||
|
||
def build_execution_graph(self, active_edges: Optional[List[Dict[str, Any]]] = None) -> List[str]:
|
||
"""
|
||
构建执行图(DAG)并返回拓扑排序结果
|
||
|
||
Args:
|
||
active_edges: 活跃的边列表(用于条件分支过滤)
|
||
|
||
Returns:
|
||
拓扑排序后的节点ID列表
|
||
"""
|
||
# 使用活跃的边,如果没有提供则使用所有边
|
||
edges_to_use = active_edges if active_edges is not None else self.edges
|
||
|
||
# 构建邻接表和入度表
|
||
graph = defaultdict(list)
|
||
in_degree = defaultdict(int)
|
||
|
||
# 初始化所有节点的入度
|
||
for node_id in self.nodes.keys():
|
||
in_degree[node_id] = 0
|
||
|
||
# 构建图
|
||
for edge in edges_to_use:
|
||
source = edge['source']
|
||
target = edge['target']
|
||
graph[source].append(target)
|
||
in_degree[target] += 1
|
||
|
||
# 拓扑排序(Kahn算法)
|
||
queue = deque()
|
||
result = []
|
||
|
||
# 找到所有入度为0的节点(起始节点)
|
||
for node_id in self.nodes.keys():
|
||
if in_degree[node_id] == 0:
|
||
queue.append(node_id)
|
||
|
||
while queue:
|
||
node_id = queue.popleft()
|
||
result.append(node_id)
|
||
|
||
# 处理该节点的所有出边
|
||
for neighbor in graph[node_id]:
|
||
in_degree[neighbor] -= 1
|
||
if in_degree[neighbor] == 0:
|
||
queue.append(neighbor)
|
||
|
||
# 检查是否有环(只检查可达节点)
|
||
reachable_nodes = set(result)
|
||
if len(reachable_nodes) < len(self.nodes):
|
||
# 有些节点不可达,这是正常的(条件分支)
|
||
pass
|
||
|
||
self.execution_graph = result
|
||
return result
|
||
|
||
def _forward_reachable_nodes(self, active_edges: List[Dict[str, Any]]) -> set:
|
||
"""从所有 Start 沿 active_edges 正向可达的节点。用于互斥分支汇合:只要求「当前图上可达」的前驱已执行。"""
|
||
start_ids = [n["id"] for n in self.nodes.values() if n.get("type") == "start"]
|
||
graph: Dict[str, List[str]] = defaultdict(list)
|
||
for e in active_edges:
|
||
graph[e["source"]].append(e["target"])
|
||
seen = set(start_ids)
|
||
q = deque(start_ids)
|
||
while q:
|
||
u = q.popleft()
|
||
for v in graph.get(u, ()):
|
||
if v not in seen:
|
||
seen.add(v)
|
||
q.append(v)
|
||
return seen
|
||
|
||
    def get_node_input(self, node_id: str, node_outputs: Dict[str, Any], active_edges: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
        """
        Collect the input payload for a node from its upstream outputs.

        Args:
            node_id: Target node ID.
            node_outputs: Outputs of all executed nodes so far.
            active_edges: Edges to consider (conditional-branch filtering);
                defaults to all edges.

        Returns:
            The merged input dict for the node, with several backfills applied
            (memory, query, requirement_analysis, user_id, handle buckets).
        """
        # Use the active edge set when provided, otherwise all edges.
        edges_to_use = active_edges if active_edges is not None else self.edges

        # Merge the output of every upstream node feeding this node.
        input_data = {}

        for edge in edges_to_use:
            if edge['target'] == node_id:
                source_id = edge['source']
                source_output = node_outputs.get(source_id, {})
                logger.info(f"[rjb] 获取节点输入: target={node_id}, source={source_id}, source_output={source_output}, source_output_type={type(source_output)}, sourceHandle={edge.get('sourceHandle')}")

                # With a sourceHandle, nest the upstream output under that key...
                if 'sourceHandle' in edge and edge['sourceHandle']:
                    input_data[edge['sourceHandle']] = source_output
                    # ...but still hoist memory-related fields (conversation
                    # history, user profile, context) to the top level — they
                    # must always propagate downstream.
                    if isinstance(source_output, dict):
                        memory_fields = ['conversation_history', 'user_profile', 'context', 'memory']
                        for field in memory_fields:
                            if field in source_output:
                                input_data[field] = source_output[field]
                                logger.info(f"[rjb] 保留记忆字段 {field} 到节点 {node_id} 的输入")
                else:
                    # No handle: merge the upstream output wholesale.
                    if isinstance(source_output, dict):
                        # Flatten a nested 'output' dict onto the top level,
                        # keeping sibling fields (e.g. status) alongside it.
                        if 'output' in source_output and isinstance(source_output['output'], dict):
                            input_data.update(source_output['output'])
                            for key, value in source_output.items():
                                if key != 'output':
                                    input_data[key] = value
                        else:
                            # Merge the upstream dict directly.
                            input_data.update(source_output)
                        logger.info(f"[rjb] 展开source_output后: input_data={input_data}")
                    else:
                        # Non-dict outputs are wrapped under 'input'.
                        input_data['input'] = source_output
                        logger.info(f"[rjb] source_output不是字典,包装到input字段: input_data={input_data}")

        # LLM/cache nodes need a memory field; when upstream did not pass a
        # usable one, backfill it from any already-executed node so these
        # nodes can still see memory.
        node_type = None
        node = self.nodes.get(node_id)
        if node:
            node_type = node.get('type')

        if node_type in ['llm', 'cache'] and self._memory_needs_backfill(input_data.get('memory')):
            # Search executed nodes for a memory field.
            for executed_node_id, node_output in self.node_outputs.items():
                if isinstance(node_output, dict):
                    if 'memory' in node_output:
                        input_data['memory'] = node_output['memory']
                        logger.info(f"[rjb] 为{node_type}节点 {node_id} 从节点 {executed_node_id} 获取memory字段")
                        break
                    # Or rebuild a memory object from loose memory fields.
                    elif 'conversation_history' in node_output:
                        memory = {}
                        for field in ['conversation_history', 'user_profile', 'context']:
                            if field in node_output:
                                memory[field] = node_output[field]
                        if memory:
                            input_data['memory'] = memory
                            logger.info(f"[rjb] 为{node_type}节点 {node_id} 从节点 {executed_node_id} 构建memory对象: {list(memory.keys())}")
                            break

        # Backfill 'query' from executed nodes, preferring start nodes.
        if 'query' not in input_data:
            for node_id_key in ['start-1', 'start']:
                if node_id_key in node_outputs:
                    node_output = node_outputs[node_id_key]
                    if isinstance(node_output, dict):
                        # Top-level field (node_outputs stores the flattened
                        # 'output' content).
                        if 'query' in node_output:
                            input_data['query'] = node_output['query']
                            logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}")
                            break
                        # Legacy: query nested under 'output'.
                        elif 'output' in node_output and isinstance(node_output['output'], dict):
                            if 'query' in node_output['output']:
                                input_data['query'] = node_output['output']['query']
                                logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}")
                                break

            # Still missing: scan every executed node.
            if 'query' not in input_data:
                for node_id_key, node_output in node_outputs.items():
                    if isinstance(node_output, dict):
                        if 'query' in node_output:
                            input_data['query'] = node_output['query']
                            logger.debug(f"[rjb] 从节点 {node_id_key} 中获取query: {input_data['query']}")
                            break
                        elif 'output' in node_output and isinstance(node_output['output'], dict):
                            if 'query' in node_output['output']:
                                input_data['query'] = node_output['output']['query']
                                logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取query: {input_data['query']}")
                                break

        # Backfill 'requirement_analysis' the same way, preferring the
        # dedicated requirement-analysis nodes.
        if 'requirement_analysis' not in input_data:
            for node_id_key in ['llm-requirement-analysis', 'requirement-analysis']:
                if node_id_key in node_outputs:
                    node_output = node_outputs[node_id_key]
                    if isinstance(node_output, dict):
                        if 'requirement_analysis' in node_output:
                            input_data['requirement_analysis'] = node_output['requirement_analysis']
                            logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis")
                            break
                        elif 'output' in node_output and isinstance(node_output['output'], dict):
                            if 'requirement_analysis' in node_output['output']:
                                input_data['requirement_analysis'] = node_output['output']['requirement_analysis']
                                logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis")
                                break

            # Still missing: scan every executed node.
            if 'requirement_analysis' not in input_data:
                for node_id_key, node_output in node_outputs.items():
                    if isinstance(node_output, dict):
                        if 'requirement_analysis' in node_output:
                            input_data['requirement_analysis'] = node_output['requirement_analysis']
                            logger.debug(f"[rjb] 从节点 {node_id_key} 中获取requirement_analysis")
                            break
                        elif 'output' in node_output and isinstance(node_output['output'], dict):
                            if 'requirement_analysis' in node_output['output']:
                                input_data['requirement_analysis'] = node_output['output']['requirement_analysis']
                                logger.debug(f"[rjb] 从节点 {node_id_key} 的output中获取requirement_analysis")
                                break

        # Session identity: edges with a sourceHandle nest upstream output
        # under input_data[handle], so user_id may be missing at the top level.
        # Then cache key user_memory_{{user_id}} degrades to "default" and
        # cross-request Redis memory is no longer isolated per user.
        if isinstance(input_data, dict) and not input_data.get('user_id') and not input_data.get('USER_ID'):
            for nid, nmeta in self.nodes.items():
                if nmeta.get('type') == 'start':
                    st = node_outputs.get(nid)
                    if isinstance(st, dict):
                        uid = st.get('user_id') if st.get('user_id') is not None else st.get('USER_ID')
                        if uid is not None and str(uid).strip() != '':
                            input_data['user_id'] = uid
                            logger.info(f"[rjb] 从 Start 节点 {nid} 提升 user_id 到节点 {node_id} 输入顶层")
                            break
            # Fallback: lift user_id out of any nested upstream dict.
            if not input_data.get('user_id') and not input_data.get('USER_ID'):
                for v in input_data.values():
                    if isinstance(v, dict):
                        uid = v.get('user_id') if v.get('user_id') is not None else v.get('USER_ID')
                        if uid is not None and str(uid).strip() != '':
                            input_data['user_id'] = uid
                            logger.info(f"[rjb] 从嵌套上游输入提升 user_id 到节点 {node_id} 输入顶层")
                            break

        # Start→downstream with a sourceHandle often leaves query/USER_INPUT/
        # attachments under 'right'; promote them so LLM and other nodes can
        # read them at the top level (only when the top-level slot is empty).
        if isinstance(input_data, dict):
            _rk = input_data.get('right')
            if isinstance(_rk, dict):
                for _uk in (
                    'query',
                    'USER_INPUT',
                    'user_input',
                    'attachments',
                    'text',
                    'message',
                    'content',
                ):
                    if input_data.get(_uk) not in (None, ''):
                        continue
                    if _uk in _rk and _rk[_uk] is not None:
                        input_data[_uk] = _rk[_uk]
                        logger.debug(f"[rjb] 从 right 提升到顶层: {_uk}")

        logger.debug(f"[rjb] 节点输入结果: node_id={node_id}, input_data={input_data}")
        return input_data
|
||
|
||
def _get_nested_value(self, data: Dict[str, Any], path: str) -> Any:
|
||
"""
|
||
从嵌套字典中获取值(支持点号路径和数组索引)
|
||
|
||
Args:
|
||
data: 数据字典
|
||
path: 路径,如 "user.name" 或 "items[0].price"
|
||
|
||
Returns:
|
||
路径对应的值
|
||
"""
|
||
if not path:
|
||
return data
|
||
|
||
parts = path.split('.')
|
||
result = data
|
||
|
||
for part in parts:
|
||
if '[' in part and ']' in part:
|
||
# 处理数组索引,如 "items[0]"
|
||
key = part[:part.index('[')]
|
||
index_str = part[part.index('[') + 1:part.index(']')]
|
||
|
||
if isinstance(result, dict):
|
||
result = result.get(key)
|
||
elif isinstance(result, list):
|
||
try:
|
||
result = result[int(index_str)]
|
||
except (ValueError, IndexError):
|
||
return None
|
||
else:
|
||
return None
|
||
|
||
if result is None:
|
||
return None
|
||
else:
|
||
# 普通键访问
|
||
if isinstance(result, dict):
|
||
result = result.get(part)
|
||
else:
|
||
return None
|
||
|
||
if result is None:
|
||
return None
|
||
|
||
return result
|
||
|
||
def _resolve_llm_prompt_placeholder(self, input_data: Dict[str, Any], var_path: str) -> Any:
    """
    Resolve a ``{{path}}`` placeholder used in an LLM prompt.

    Cache-get steps often merge user_profile / conversation_history etc.
    into the top level of ``input_data`` without a ``memory`` wrapper, so a
    plain ``memory.*`` lookup can miss. Fall back to the same path at the
    top level, then inside the ``memory`` sub-object, plus a special case
    for ``assistant_display_name`` living under a ``context`` object.
    """
    direct = self._get_nested_value(input_data, var_path)
    if direct is not None:
        return direct

    # Fallbacks only apply to dict inputs with a "memory."-prefixed path.
    if not (isinstance(input_data, dict) and var_path.startswith("memory.")):
        return None

    tail = var_path[len("memory."):]
    if not tail:
        return input_data.get("memory")

    # 1) Same path without the "memory." prefix, at the top level.
    top_level = self._get_nested_value(input_data, tail)
    if top_level is not None:
        return top_level

    # 2) Inside the memory sub-object itself.
    mem = input_data.get("memory")
    if isinstance(mem, dict):
        nested = self._get_nested_value(mem, tail)
        if nested is not None:
            return nested
        if tail == "assistant_display_name":
            mem_ctx = mem.get("context")
            if isinstance(mem_ctx, dict):
                display = mem_ctx.get("assistant_display_name")
                if display is not None:
                    return display

    # 3) The display name may also sit in a top-level context object.
    if tail == "assistant_display_name":
        top_ctx = input_data.get("context")
        if isinstance(top_ctx, dict):
            return top_ctx.get("assistant_display_name")

    return None
|
||
|
||
def _format_prior_conversation_for_llm(
|
||
self, input_data: Dict[str, Any], original_prompt_template: str
|
||
) -> Optional[str]:
|
||
"""
|
||
Agent 多轮对话:执行请求若携带 conversation_history,而提示词未使用
|
||
{{memory.conversation_history}} 等占位符,则在此处拼进最终 prompt,避免模型「失忆」。
|
||
"""
|
||
t = original_prompt_template or ""
|
||
if "memory.conversation_history" in t or re.search(
|
||
r"\{\{[^}]*conversation_history[^}]*\}\}", t
|
||
):
|
||
return None
|
||
|
||
hist: Any = None
|
||
if isinstance(input_data, dict):
|
||
hist = input_data.get("conversation_history")
|
||
if hist is None and isinstance(input_data.get("memory"), dict):
|
||
hist = input_data["memory"].get("conversation_history")
|
||
if hist is None and isinstance(input_data.get("right"), dict):
|
||
r = input_data["right"]
|
||
hist = r.get("conversation_history")
|
||
if hist is None and isinstance(r.get("memory"), dict):
|
||
hist = r["memory"].get("conversation_history")
|
||
|
||
if not hist or not isinstance(hist, list):
|
||
return None
|
||
|
||
lines: List[str] = []
|
||
max_turns = 24
|
||
for msg in hist[-max_turns:]:
|
||
if not isinstance(msg, dict):
|
||
continue
|
||
role = msg.get("role", "")
|
||
content = msg.get("content", "")
|
||
if content is None:
|
||
continue
|
||
if not isinstance(content, str):
|
||
content = str(content)
|
||
content = content.strip()
|
||
if not content:
|
||
continue
|
||
if role == "user":
|
||
lines.append(f"用户:{content}")
|
||
elif role in ("assistant", "agent"):
|
||
lines.append(f"助手:{content}")
|
||
else:
|
||
lines.append(f"{role}:{content}")
|
||
|
||
if not lines:
|
||
return None
|
||
|
||
body = "\n".join(lines)
|
||
max_chars = 12000
|
||
if len(body) > max_chars:
|
||
body = body[-max_chars:] + "\n…(更早的对话已截断)"
|
||
|
||
return f"【本轮之前的对话】\n{body}"
|
||
|
||
def _resolve_vector_db_query_embedding(
    self, input_data: Any, query_vector_config: Any
) -> Optional[List[Any]]:
    """
    Resolve the query vector from the node's configured query_vector path
    and the merged upstream results.

    Upstream embeddings may hang under right/left/output, or under an HTTP
    response path such as data[0].embedding. Returns None when nothing can
    be resolved (the search should then degrade to an empty result instead
    of failing the whole graph).
    """

    def _is_numeric_vector(v: Any) -> bool:
        # A non-empty list (or tuple) whose first element is numeric.
        # Only the first element is checked — assumes homogeneous vectors.
        if v is None:
            return False
        if isinstance(v, tuple):
            v = list(v)
        if not isinstance(v, list) or len(v) == 0:
            return False
        return isinstance(v[0], (int, float))

    def _as_float_list(v: Any) -> Optional[List[Any]]:
        # Normalize tuples to lists; return None for anything that is not
        # a numeric vector.
        if isinstance(v, tuple):
            v = list(v)
        if _is_numeric_vector(v):
            return v
        return None

    def _deep_find_embedding(obj: Any, depth: int = 0) -> Optional[List[Any]]:
        """Find the first numeric embedding vector in a nested dict/list (depth-capped at 8)."""
        if depth > 8 or obj is None:
            return None
        if isinstance(obj, dict):
            # Prefer well-known embedding keys before recursing blindly.
            for key in ("embedding", "vector", "query_embedding"):
                got = _as_float_list(obj.get(key))
                if got is not None:
                    return got
            for vv in obj.values():
                got = _deep_find_embedding(vv, depth + 1)
                if got is not None:
                    return got
        elif isinstance(obj, list) and obj:
            # Only descend into lists of dicts; a list of numbers would
            # already have matched via _as_float_list at the parent.
            if isinstance(obj[0], dict):
                for item in obj:
                    got = _deep_find_embedding(item, depth + 1)
                    if got is not None:
                        return got
        return None

    query_vec: Any = None
    if isinstance(query_vector_config, str):
        # Config may be a "{path}" / "{{path}}" template; strip braces and
        # resolve the remaining dotted path against the merged input.
        path = query_vector_config.replace("{", "").replace("}", "").strip()
        if path and isinstance(input_data, dict):
            query_vec = self._get_nested_value(input_data, path)
    elif isinstance(query_vector_config, list):
        query_vec = query_vector_config

    # The path may point at a JSON string (some merge steps serialize arrays).
    if isinstance(query_vec, str) and query_vec.strip().startswith("["):
        try:
            parsed = json.loads(query_vec)
            pv = _as_float_list(parsed)
            if pv is not None:
                query_vec = pv
        except Exception:
            pass

    if not _is_numeric_vector(query_vec) and isinstance(input_data, dict):
        # Probe well-known top-level keys first.
        for k in ("embedding", "vector", "query_embedding", "query_vector"):
            v = input_data.get(k)
            pv = _as_float_list(v)
            if pv is not None:
                query_vec = pv
                break
        if not _is_numeric_vector(query_vec):
            # Merge nodes often put one side's whole output under
            # right / left / output (the vector itself may be the list).
            # NOTE: probe order matters — earlier paths win.
            for path in (
                "right",
                "left",
                "output",
                "right.embedding",
                "left.embedding",
                "output.embedding",
                "right.output.embedding",
                "left.output.embedding",
                "data.embedding",
                "output.data[0].embedding",
                "body.data[0].embedding",
                "result.embedding",
                "response.embedding",
            ):
                v = self._get_nested_value(input_data, path)
                pv = _as_float_list(v)
                if pv is not None:
                    query_vec = pv
                    break
                if isinstance(v, dict):
                    # The bucket itself may wrap the vector one level deeper.
                    inner = v.get("embedding") or v.get("vector")
                    pv = _as_float_list(inner)
                    if pv is not None:
                        query_vec = pv
                        break
            # OpenAI-style response: embeddings under a "data" array.
            if not _is_numeric_vector(query_vec):
                data = input_data.get("data")
                if isinstance(data, list) and data:
                    first = data[0]
                    if isinstance(first, dict):
                        emb = first.get("embedding")
                        pv = _as_float_list(emb)
                        if pv is not None:
                            query_vec = pv

    # Last resort: exhaustive (depth-limited) search of the whole input.
    if not _is_numeric_vector(query_vec):
        found = _deep_find_embedding(input_data, 0)
        if found is not None:
            query_vec = found

    if _is_numeric_vector(query_vec):
        return list(query_vec) if isinstance(query_vec, tuple) else query_vec
    return None
|
||
|
||
def _resolve_brace_template_var(self, expanded_input: Dict[str, Any], var_name: str) -> Any:
    """
    Resolve a ``{{var_name}}`` placeholder inside a transform mapping.

    LLM output usually hangs under sourceHandle=right; legacy configs that
    mistakenly used ``{{output}}`` fall back here to right/reply/etc., so a
    literal ``{{...}}`` string never leaks into a downstream json-parse step.
    """

    def _not_placeholder(val: Any) -> bool:
        # True for usable values; rejects None / "" and raw "{{...}}"
        # literals that were never substituted.
        if val is None or val == "":
            return False
        if isinstance(val, str):
            t = val.strip()
            if len(t) >= 4 and t.startswith("{{") and t.endswith("}}"):
                return False
        return True

    # Direct dotted-path lookup first.
    v = self._get_nested_value(expanded_input, var_name)
    if _not_placeholder(v):
        return v
    r = expanded_input.get("right")
    if var_name == "output":
        # Common top-level aliases for "the result", probed in order.
        for alt in ("reply", "result", "data", "content", "text"):
            x = self._get_nested_value(expanded_input, alt)
            if _not_placeholder(x):
                return x
        # The whole "right" bucket may be the string output itself.
        if isinstance(r, str) and r.strip() and _not_placeholder(r):
            return r
        if isinstance(r, dict):
            for alt in ("reply", "output", "data", "result", "content"):
                x = r.get(alt)
                if _not_placeholder(x):
                    return x
        # Last resort: pull the reply from previously recorded LLM node outputs.
        fb = self._extract_reply_from_llm_node_outputs()
        if fb:
            return fb
    if var_name == "reply":
        if isinstance(r, dict):
            x = r.get("reply")
            if _not_placeholder(x):
                return x
        # "right" may be a serialized JSON object containing the reply.
        if isinstance(r, str) and r.strip().startswith("{"):
            try:
                obj = json.loads(r)
                if isinstance(obj, dict) and obj.get("reply") is not None:
                    return obj.get("reply")
            except Exception:
                pass
        fb = self._extract_reply_from_llm_node_outputs()
        if fb:
            return fb
    if var_name == "user_profile":
        # Only dict-shaped profiles are accepted here.
        if isinstance(r, dict):
            x = r.get("user_profile")
            if isinstance(x, dict):
                return x
    if var_name == "result":
        fb = self._extract_reply_from_llm_node_outputs()
        if fb:
            return fb
    return None
|
||
|
||
async def _execute_loop_body(self, loop_node_id: str, loop_input: Dict[str, Any], iteration_index: int) -> Dict[str, Any]:
|
||
"""
|
||
执行循环体
|
||
|
||
Args:
|
||
loop_node_id: 循环节点ID
|
||
loop_input: 循环体的输入数据
|
||
iteration_index: 当前迭代索引
|
||
|
||
Returns:
|
||
循环体的执行结果
|
||
"""
|
||
# 找到循环节点的直接子节点(循环体开始节点)
|
||
loop_body_start_nodes = []
|
||
for edge in self.edges:
|
||
if edge.get('source') == loop_node_id:
|
||
target_id = edge.get('target')
|
||
if target_id and target_id in self.nodes:
|
||
loop_body_start_nodes.append(target_id)
|
||
|
||
if not loop_body_start_nodes:
|
||
# 如果没有子节点,直接返回输入数据
|
||
return {'output': loop_input, 'status': 'success'}
|
||
|
||
# 执行循环体:从循环体开始节点执行到循环结束节点或没有更多节点
|
||
# 简化处理:只执行第一个子节点链
|
||
executed_in_loop = set()
|
||
loop_results = {}
|
||
current_node_id = loop_body_start_nodes[0] # 简化:只执行第一个子节点链
|
||
|
||
# 执行循环体内的节点(简化版本:只执行直接连接的子节点)
|
||
max_iterations = 100 # 防止无限循环
|
||
iteration = 0
|
||
|
||
while current_node_id and iteration < max_iterations:
|
||
iteration += 1
|
||
|
||
if current_node_id in executed_in_loop:
|
||
break # 避免循环体内部循环
|
||
|
||
if current_node_id not in self.nodes:
|
||
break
|
||
|
||
node = self.nodes[current_node_id]
|
||
executed_in_loop.add(current_node_id)
|
||
|
||
# 如果是循环结束节点,停止执行
|
||
if node.get('type') == 'loop_end' or node.get('type') == 'end':
|
||
break
|
||
|
||
# 执行节点
|
||
result = await self.execute_node(node, loop_input)
|
||
loop_results[current_node_id] = result
|
||
|
||
if result.get('status') != 'success':
|
||
return result
|
||
|
||
# 更新输入数据为当前节点的输出
|
||
if result.get('output'):
|
||
if isinstance(result.get('output'), dict):
|
||
loop_input = {**loop_input, **result.get('output')}
|
||
else:
|
||
loop_input = {**loop_input, 'result': result.get('output')}
|
||
|
||
# 找到下一个节点(简化:只找第一个子节点)
|
||
next_node_id = None
|
||
for edge in self.edges:
|
||
if edge.get('source') == current_node_id:
|
||
target_id = edge.get('target')
|
||
if target_id and target_id in self.nodes and target_id not in executed_in_loop:
|
||
# 跳过循环节点本身
|
||
if target_id != loop_node_id:
|
||
next_node_id = target_id
|
||
break
|
||
|
||
current_node_id = next_node_id
|
||
|
||
# 返回最后一个节点的输出
|
||
if loop_results:
|
||
last_result = list(loop_results.values())[-1]
|
||
return last_result
|
||
|
||
return {'output': loop_input, 'status': 'success'}
|
||
|
||
def _mark_loop_body_executed(self, node_id: str, executed_nodes: set, active_edges: List[Dict[str, Any]]):
|
||
"""
|
||
递归标记循环体内的节点为已执行
|
||
|
||
Args:
|
||
node_id: 当前节点ID
|
||
executed_nodes: 已执行节点集合
|
||
active_edges: 活跃的边列表
|
||
"""
|
||
if node_id in executed_nodes:
|
||
return
|
||
|
||
executed_nodes.add(node_id)
|
||
|
||
# 查找所有子节点
|
||
for edge in active_edges:
|
||
if edge.get('source') == node_id:
|
||
target_id = edge.get('target')
|
||
if target_id in self.nodes:
|
||
target_node = self.nodes[target_id]
|
||
# 如果是循环结束节点,停止递归
|
||
if target_node.get('type') in ['loop_end', 'end']:
|
||
continue
|
||
# 递归标记子节点
|
||
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
|
||
|
||
async def execute_node(self, node: Dict[str, Any], input_data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
执行单个节点
|
||
|
||
Args:
|
||
node: 节点配置
|
||
input_data: 输入数据
|
||
|
||
Returns:
|
||
节点执行结果
|
||
"""
|
||
# 确保可以访问全局的 json 模块
|
||
import json as json_module
|
||
|
||
node_type = node.get('type', 'unknown')
|
||
node_id = node.get('id')
|
||
start_time = time.time()
|
||
|
||
# 记录节点开始执行
|
||
if self.logger:
|
||
self.logger.log_node_start(node_id, node_type, input_data)
|
||
|
||
try:
|
||
if node_type == 'start':
|
||
# 起始节点:返回输入数据
|
||
logger.debug(f"[rjb] 开始节点执行: node_id={node_id}, input_data={input_data}")
|
||
result = {'output': input_data, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
||
logger.debug(f"[rjb] 开始节点输出: node_id={node_id}, output={result.get('output')}")
|
||
return result
|
||
|
||
elif node_type == 'input':
|
||
# 输入节点:处理输入数据
|
||
result = {'output': input_data, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
||
return result
|
||
|
||
elif node_type == 'approval':
|
||
# 人工审批(HITL):无 __hil_decision 时挂起,由 Celery 落库 awaiting_approval
|
||
nd = node.get('data', {}) or {}
|
||
message = nd.get('message', '需要人工审批')
|
||
approved_handle = nd.get('approved_branch', 'approved')
|
||
rejected_handle = nd.get('rejected_branch', 'rejected')
|
||
root = self._initial_input_data if isinstance(self._initial_input_data, dict) else {}
|
||
merged: Dict[str, Any] = {**root}
|
||
if isinstance(input_data, dict):
|
||
merged = {**merged, **input_data}
|
||
decision = merged.get('__hil_decision')
|
||
comment = merged.get('__hil_comment')
|
||
if decision == 'approved':
|
||
out = {
|
||
'approved': True,
|
||
'message': message,
|
||
'comment': comment,
|
||
'input': input_data,
|
||
}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, out, duration)
|
||
return {
|
||
'output': out,
|
||
'status': 'success',
|
||
'branch': approved_handle,
|
||
}
|
||
if decision == 'rejected':
|
||
out = {
|
||
'approved': False,
|
||
'message': message,
|
||
'comment': comment,
|
||
'input': input_data,
|
||
}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, out, duration)
|
||
return {
|
||
'output': out,
|
||
'status': 'success',
|
||
'branch': rejected_handle,
|
||
}
|
||
if self.logger:
|
||
self.logger.info(
|
||
f"审批节点等待人工决策: {message}",
|
||
node_id=node_id,
|
||
node_type='approval',
|
||
)
|
||
return {'status': 'awaiting_approval'}
|
||
|
||
elif node_type == 'llm' or node_type == 'template':
|
||
self._llm_invocations += 1
|
||
if self._llm_invocations > self._cap_llm:
|
||
raise WorkflowExecutionError(
|
||
detail=f"已超过 LLM 节点调用预算({self._cap_llm} 次)",
|
||
node_id=node_id,
|
||
)
|
||
# LLM节点:调用AI模型
|
||
node_data = node.get('data', {})
|
||
logger.debug(f"[rjb] LLM节点执行: node_id={node_id}, input_data={input_data}, input_data type={type(input_data)}")
|
||
logger.debug(f"[rjb] LLM节点数据: node_id={node_id}, node_data keys={list(node_data.keys())}, api_key={'已配置' if node_data.get('api_key') else '未配置'}")
|
||
prompt = node_data.get('prompt', '')
|
||
|
||
# 如果 prompt 为空:不要用 {input} 展开整包 input_data。
|
||
# 预览/执行侧常带 user_id、memory、conversation_history 等大对象,模型易照抄成 ```json 回复。
|
||
if not prompt:
|
||
prompt = (
|
||
"请根据用户当前问题用自然语言回答;需要时可用工具。"
|
||
"不要向用户复述或输出完整的 input_data / API 请求 JSON。"
|
||
)
|
||
|
||
# 格式化prompt,替换变量
|
||
try:
|
||
# 将input_data转换为字符串用于格式化
|
||
if isinstance(input_data, dict):
|
||
# 支持两种格式的变量:{key} 和 {{key}}
|
||
formatted_prompt = prompt
|
||
has_unfilled_variables = False
|
||
has_any_placeholder = False
|
||
|
||
# 检查是否有任何占位符
|
||
has_any_placeholder = bool(re.search(r'\{\{?\w+\}?\}', prompt))
|
||
|
||
# 首先处理 {{variable}} 和 {{variable.path}} 格式(模板节点常用)
|
||
# 支持嵌套路径,如 {{memory.conversation_history}}
|
||
double_brace_vars = re.findall(r'\{\{([^}]+)\}\}', prompt)
|
||
for var_path in double_brace_vars:
|
||
# 尝试从input_data中获取值(支持嵌套路径;含 memory.* 与 Cache 顶层合并对齐)
|
||
value = self._resolve_llm_prompt_placeholder(input_data, var_path)
|
||
|
||
# 如果变量未找到,尝试常见的别名映射
|
||
if value is None:
|
||
# user_input 可以映射到 query、input、USER_INPUT 等字段
|
||
if var_path == 'user_input':
|
||
for alias in ['query', 'input', 'USER_INPUT', 'user_input', 'text', 'message', 'content']:
|
||
value = input_data.get(alias)
|
||
if value is not None:
|
||
# 如果值是字典,尝试从中提取字符串值
|
||
if isinstance(value, dict):
|
||
for sub_key in ['query', 'input', 'text', 'message', 'content']:
|
||
if sub_key in value:
|
||
value = value[sub_key]
|
||
break
|
||
break
|
||
# output 可以映射到 right 字段(LLM节点的输出通常存储在right字段中)
|
||
elif var_path == 'output':
|
||
# 尝试从right字段中提取
|
||
right_value = input_data.get('right')
|
||
logger.info(f"[rjb] LLM节点查找output变量: right_value类型={type(right_value)}, right_value={str(right_value)[:100] if right_value else None}")
|
||
if right_value is not None:
|
||
# 如果right是字符串,直接使用
|
||
if isinstance(right_value, str):
|
||
value = right_value
|
||
logger.info(f"[rjb] LLM节点从right字段(字符串)提取output: {value[:100]}")
|
||
# 如果right是字典,尝试递归查找字符串值
|
||
elif isinstance(right_value, dict):
|
||
# 尝试从right.right.right...中提取(处理嵌套的right字段)
|
||
current = right_value
|
||
depth = 0
|
||
while isinstance(current, dict) and depth < 10:
|
||
if 'right' in current:
|
||
current = current['right']
|
||
depth += 1
|
||
if isinstance(current, str):
|
||
value = current
|
||
logger.info(f"[rjb] LLM节点从right字段(嵌套{depth}层)提取output: {value[:100]}")
|
||
break
|
||
else:
|
||
# 如果没有right字段,尝试其他可能的字段
|
||
for key in ['content', 'text', 'message', 'output']:
|
||
if key in current and isinstance(current[key], str):
|
||
value = current[key]
|
||
logger.info(f"[rjb] LLM节点从right字段中找到{key}字段: {value[:100]}")
|
||
break
|
||
if value is not None:
|
||
break
|
||
break
|
||
if value is None:
|
||
logger.warning(f"[rjb] LLM节点无法从right字段中提取output,right结构: {str(right_value)[:200]}")
|
||
|
||
if value is not None:
|
||
# 替换 {{variable}} 或 {{variable.path}} 为实际值
|
||
# 特殊处理:如果是memory.conversation_history,格式化为易读的对话格式
|
||
if var_path == 'memory.conversation_history' and isinstance(value, list):
|
||
# 将对话历史格式化为易读的文本格式
|
||
formatted_history = []
|
||
for msg in value:
|
||
role = msg.get('role', 'unknown')
|
||
content = msg.get('content', '')
|
||
if role == 'user':
|
||
formatted_history.append(f"用户:{content}")
|
||
elif role == 'assistant':
|
||
formatted_history.append(f"助手:{content}")
|
||
else:
|
||
formatted_history.append(f"{role}:{content}")
|
||
replacement = '\n'.join(formatted_history) if formatted_history else '(暂无对话历史)'
|
||
else:
|
||
# 其他情况使用JSON格式
|
||
replacement = json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
|
||
formatted_prompt = formatted_prompt.replace(f'{{{{{var_path}}}}}', replacement)
|
||
# 对于conversation_history,显示完整内容以便调试
|
||
if var_path == 'memory.conversation_history':
|
||
logger.info(f"[rjb] LLM节点替换变量: {var_path} = {replacement[:500] if len(replacement) > 500 else replacement}")
|
||
else:
|
||
logger.info(f"[rjb] LLM节点替换变量: {var_path} = {str(replacement)[:200]}")
|
||
else:
|
||
has_unfilled_variables = True
|
||
logger.warning(f"[rjb] LLM节点变量未找到: {var_path}, input_data keys: {list(input_data.keys()) if isinstance(input_data, dict) else 'not dict'}")
|
||
|
||
# 然后处理 {key} 格式
|
||
for key, value in input_data.items():
|
||
placeholder = f'{{{key}}}'
|
||
if placeholder in formatted_prompt:
|
||
formatted_prompt = formatted_prompt.replace(
|
||
placeholder,
|
||
json_module.dumps(value, ensure_ascii=False) if isinstance(value, (dict, list)) else str(value)
|
||
)
|
||
|
||
# 如果还有{input}占位符,替换为整个input_data
|
||
if '{input}' in formatted_prompt:
|
||
formatted_prompt = formatted_prompt.replace(
|
||
'{input}',
|
||
json_module.dumps(input_data, ensure_ascii=False)
|
||
)
|
||
|
||
# 提取用户的实际查询内容(优先提取)
|
||
user_query = None
|
||
logger.info(f"[rjb] 开始提取user_query: input_data={input_data}, input_data_type={type(input_data)}")
|
||
if isinstance(input_data, dict):
|
||
# 首先检查是否有嵌套的input字段
|
||
nested_input = input_data.get('input')
|
||
logger.info(f"[rjb] 检查嵌套input: nested_input={nested_input}, nested_input_type={type(nested_input) if nested_input else None}")
|
||
if isinstance(nested_input, dict):
|
||
# 从嵌套的input中提取
|
||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||
if key in nested_input:
|
||
user_query = nested_input[key]
|
||
logger.info(f"[rjb] 从嵌套input中提取到user_query: key={key}, user_query={user_query}")
|
||
break
|
||
|
||
# 如果还没有,从顶层提取
|
||
if not user_query:
|
||
for key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||
if key in input_data:
|
||
value = input_data[key]
|
||
logger.info(f"[rjb] 从顶层提取: key={key}, value={value}, value_type={type(value)}")
|
||
# 如果值是字符串,直接使用
|
||
if isinstance(value, str):
|
||
user_query = value
|
||
logger.info(f"[rjb] 提取到字符串user_query: {user_query}")
|
||
break
|
||
# 如果值是字典,尝试从中提取
|
||
elif isinstance(value, dict):
|
||
for sub_key in ['query', 'input', 'text', 'message', 'content', 'user_input', 'USER_INPUT']:
|
||
if sub_key in value:
|
||
user_query = value[sub_key]
|
||
logger.info(f"[rjb] 从字典值中提取到user_query: sub_key={sub_key}, user_query={user_query}")
|
||
break
|
||
if user_query:
|
||
break
|
||
|
||
# Start→LLM 常见:边带 sourceHandle=right,Start 输出在 input_data["right"] 下,
|
||
# 顶层无 query/USER_INPUT,旧逻辑会退化为 JSON 整包,模型看不到附件路径。
|
||
if not user_query:
|
||
for bucket in ("right", "left", "output", "data"):
|
||
nested = input_data.get(bucket)
|
||
if not isinstance(nested, dict):
|
||
continue
|
||
for key in (
|
||
"query",
|
||
"input",
|
||
"text",
|
||
"message",
|
||
"content",
|
||
"user_input",
|
||
"USER_INPUT",
|
||
):
|
||
if key not in nested:
|
||
continue
|
||
value = nested[key]
|
||
logger.info(
|
||
f"[rjb] 从嵌套桶 {bucket}.{key} 提取 user_query 候选, type={type(value)}"
|
||
)
|
||
if isinstance(value, str):
|
||
user_query = value
|
||
logger.info(
|
||
f"[rjb] 从{bucket}.{key} 提取到字符串 user_query 长度={len(value)}"
|
||
)
|
||
break
|
||
if isinstance(value, dict):
|
||
for sub_key in (
|
||
"query",
|
||
"input",
|
||
"text",
|
||
"message",
|
||
"content",
|
||
"user_input",
|
||
"USER_INPUT",
|
||
):
|
||
if sub_key not in value:
|
||
continue
|
||
sv = value[sub_key]
|
||
if isinstance(sv, str):
|
||
user_query = sv
|
||
logger.info(
|
||
f"[rjb] 从{bucket}.{key}.{sub_key} 提取到 user_query"
|
||
)
|
||
break
|
||
if user_query:
|
||
break
|
||
if user_query:
|
||
break
|
||
|
||
# 如果还是没有,使用整个input_data(但排除系统字段)
|
||
if not user_query:
|
||
filtered_data = {k: v for k, v in input_data.items() if not k.startswith('_')}
|
||
logger.info(f"[rjb] 使用filtered_data: filtered_data={filtered_data}")
|
||
if filtered_data:
|
||
# 如果只有一个字段且是字符串,直接使用
|
||
if len(filtered_data) == 1:
|
||
single_value = list(filtered_data.values())[0]
|
||
if isinstance(single_value, str):
|
||
user_query = single_value
|
||
logger.info(f"[rjb] 从单个字符串字段提取到user_query: {user_query}")
|
||
elif isinstance(single_value, dict):
|
||
# 从字典中提取第一个字符串值
|
||
for v in single_value.values():
|
||
if isinstance(v, str):
|
||
user_query = v
|
||
logger.info(f"[rjb] 从字典的单个字段中提取到user_query: {user_query}")
|
||
break
|
||
if not user_query:
|
||
user_query = json_module.dumps(filtered_data, ensure_ascii=False) if len(filtered_data) > 1 else str(list(filtered_data.values())[0])
|
||
logger.info(f"[rjb] 使用JSON或字符串转换: user_query={user_query}")
|
||
|
||
logger.info(f"[rjb] 最终提取的user_query: {user_query}")
|
||
|
||
history_block = self._format_prior_conversation_for_llm(input_data, prompt)
|
||
|
||
# 如果prompt中没有占位符,或者仍有未填充的变量,将用户输入附加到prompt
|
||
is_generic_instruction = False # 初始化变量
|
||
if not has_any_placeholder:
|
||
# 如果prompt中没有占位符,将用户输入作为主要内容
|
||
if user_query:
|
||
# 判断是否是通用指令:简短且不包含具体任务描述
|
||
prompt_stripped = prompt.strip()
|
||
is_generic_instruction = (
|
||
len(prompt_stripped) < 30 or # 简短提示词
|
||
prompt_stripped in [
|
||
"请处理用户请求。", "请处理用户请求",
|
||
"请处理以下输入数据:", "请处理以下输入数据",
|
||
"请处理输入。", "请处理输入",
|
||
"处理用户请求", "处理请求",
|
||
"请回答用户问题", "请回答用户问题。",
|
||
"请帮助用户", "请帮助用户。"
|
||
] or
|
||
# 检查是否只包含通用指令关键词
|
||
(len(prompt_stripped) < 50 and any(keyword in prompt_stripped for keyword in [
|
||
"请处理", "处理", "请回答", "回答", "请帮助", "帮助", "请执行", "执行"
|
||
]) and not any(specific in prompt_stripped for specific in [
|
||
"翻译", "生成", "分析", "总结", "提取", "转换", "计算"
|
||
]))
|
||
)
|
||
|
||
if is_generic_instruction:
|
||
# 如果是通用指令,直接使用用户输入作为prompt
|
||
if history_block:
|
||
formatted_prompt = f"{history_block}\n\n{str(user_query)}"
|
||
else:
|
||
formatted_prompt = str(user_query)
|
||
logger.info(f"[rjb] 检测到通用指令,直接使用用户输入作为prompt: {user_query[:50] if user_query else 'None'}")
|
||
else:
|
||
# 否则,将用户输入附加到prompt
|
||
if history_block:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{history_block}\n\n{user_query}"
|
||
else:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{user_query}"
|
||
logger.info(f"[rjb] 非通用指令,将用户输入附加到prompt")
|
||
else:
|
||
# 如果没有提取到用户查询,附加整个input_data
|
||
tail = json_module.dumps(input_data, ensure_ascii=False)
|
||
if history_block:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{history_block}\n\n{tail}"
|
||
else:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{tail}"
|
||
elif has_unfilled_variables or re.search(r'\{\{[^}]+\}\}', formatted_prompt):
|
||
# 如果有占位符但未填充,先尝试清理所有未填充的模板变量
|
||
# 使用正则表达式替换所有 {{...}} 格式的未填充变量
|
||
formatted_prompt = re.sub(r'\{\{[^}]+\}\}', '', formatted_prompt)
|
||
# 如果有占位符但未填充,附加用户需求说明
|
||
if user_query:
|
||
user_tail = f"用户需求:{user_query}\n\n请根据用户需求来完成任务。"
|
||
if history_block:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{history_block}\n\n{user_tail}"
|
||
else:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{user_tail}"
|
||
else:
|
||
# 如果没有用户查询,附加整个input_data
|
||
data_tail = f"输入数据:{json_module.dumps(input_data, ensure_ascii=False)}\n\n请根据输入数据来完成任务。"
|
||
if history_block:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{history_block}\n\n{data_tail}"
|
||
else:
|
||
formatted_prompt = f"{formatted_prompt}\n\n{data_tail}"
|
||
|
||
logger.info(f"[rjb] LLM节点prompt格式化: node_id={node_id}, original_prompt='{prompt[:50] if len(prompt) > 50 else prompt}', has_any_placeholder={has_any_placeholder}, user_query={user_query}, is_generic_instruction={is_generic_instruction}, final_prompt前200字符='{formatted_prompt[:200] if len(formatted_prompt) > 200 else formatted_prompt}'")
|
||
prompt = formatted_prompt
|
||
else:
|
||
# 如果input_data不是dict,直接转换为字符串
|
||
if '{input}' in prompt:
|
||
prompt = prompt.replace('{input}', str(input_data))
|
||
else:
|
||
prompt = f"{prompt}\n\n输入:{str(input_data)}"
|
||
except Exception as e:
|
||
# 格式化失败,使用原始prompt和input_data
|
||
logger.warning(f"[rjb] Prompt格式化失败: {str(e)}")
|
||
try:
|
||
prompt = f"{prompt}\n\n输入数据:\n{json_module.dumps(input_data, ensure_ascii=False)}"
|
||
except:
|
||
prompt = f"{prompt}\n\n输入数据:{str(input_data)}"
|
||
|
||
# 获取LLM配置
|
||
provider = node_data.get('provider', 'openai')
|
||
model = node_data.get('model', 'gpt-3.5-turbo')
|
||
# 确保temperature是浮点数(节点模板中可能是字符串)
|
||
temperature_raw = node_data.get('temperature', 0.7)
|
||
if isinstance(temperature_raw, str):
|
||
try:
|
||
temperature = float(temperature_raw)
|
||
except (ValueError, TypeError):
|
||
temperature = 0.7
|
||
else:
|
||
temperature = float(temperature_raw) if temperature_raw is not None else 0.7
|
||
# 确保max_tokens是整数(节点模板中可能是字符串)
|
||
max_tokens_raw = node_data.get('max_tokens')
|
||
if max_tokens_raw is not None:
|
||
if isinstance(max_tokens_raw, str):
|
||
try:
|
||
max_tokens = int(max_tokens_raw)
|
||
except (ValueError, TypeError):
|
||
max_tokens = None
|
||
else:
|
||
max_tokens = int(max_tokens_raw) if max_tokens_raw is not None else None
|
||
else:
|
||
max_tokens = None
|
||
# 默认使用环境变量中的 Key;若节点绑定 model_config_id 且执行上下文可信,则注入用户保存的密钥与 endpoint
|
||
api_key: Optional[str] = None
|
||
base_url: Optional[str] = None
|
||
mc_cred: Optional[Dict[str, Any]] = None
|
||
try:
|
||
mc_cred = self._resolve_llm_credentials_from_model_config(node_data)
|
||
if mc_cred:
|
||
api_key = mc_cred.get("api_key")
|
||
base_url = mc_cred.get("base_url")
|
||
provider = mc_cred.get("provider", provider)
|
||
model = mc_cred.get("model", model)
|
||
except Exception as mc_err:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, mc_err, duration)
|
||
logger.error(f"[rjb] LLM 模型配置解析失败: {mc_err}", exc_info=True)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': str(mc_err),
|
||
}
|
||
|
||
llm_extra_kw: Dict[str, Any] = {}
|
||
if api_key is not None:
|
||
llm_extra_kw["api_key"] = api_key
|
||
if base_url is not None:
|
||
llm_extra_kw["base_url"] = base_url
|
||
_xb = node_data.get("extra_body")
|
||
if isinstance(_xb, dict) and _xb:
|
||
llm_extra_kw["extra_body"] = _xb
|
||
|
||
# 记录实际发送给LLM的prompt
|
||
logger.info(f"[rjb] 准备调用LLM: node_id={node_id}, provider={provider}, model={model}, prompt前200字符='{prompt[:200] if len(prompt) > 200 else prompt}'")
|
||
|
||
_raw_sys = node_data.get("system_prompt")
|
||
llm_system_prompt: Optional[str] = None
|
||
if isinstance(_raw_sys, str) and _raw_sys.strip():
|
||
llm_system_prompt = _raw_sys.strip()
|
||
elif _raw_sys is not None and not isinstance(_raw_sys, (dict, list)):
|
||
_ts = str(_raw_sys).strip()
|
||
if _ts:
|
||
llm_system_prompt = _ts
|
||
|
||
# 检查是否启用工具调用
|
||
enable_tools = node_data.get('enable_tools', False)
|
||
# 支持两种字段名:tools 和 selected_tools
|
||
tools_config = node_data.get('tools') or node_data.get('selected_tools') or []
|
||
|
||
# 如果启用了工具,加载工具定义
|
||
tools = []
|
||
if enable_tools and tools_config:
|
||
from app.services.tool_registry import tool_registry
|
||
# 从注册表加载工具定义
|
||
tools = tool_registry.get_tools_by_names(tools_config)
|
||
logger.info(f"[rjb] LLM节点启用工具调用: {len(tools)} 个工具, 工具列表: {tools_config}")
|
||
if not tools:
|
||
logger.warning(
|
||
"[rjb] LLM 已 enable_tools 但当前进程 tool_registry 中 0 个匹配 schema,"
|
||
"将无法发起 function calling(常见于 Celery Worker 未加载 tools_bootstrap)。配置=%s",
|
||
tools_config,
|
||
)
|
||
elif len(tools) < len(tools_config):
|
||
missing = [n for n in tools_config if not tool_registry.get_tool_schema(n)]
|
||
logger.warning(
|
||
"[rjb] LLM 工具部分缺失 schema,缺失=%s(可动手能力不完整)",
|
||
missing,
|
||
)
|
||
|
||
# 调用LLM服务
|
||
try:
|
||
if self.logger:
|
||
key_src = "模型配置(model_config_id)" if mc_cred else "环境变量默认"
|
||
logger.debug(
|
||
f"[rjb] LLM节点配置: provider={provider}, model={model}, API密钥来源={key_src}, 工具调用: {'启用' if tools else '禁用'}"
|
||
)
|
||
self.logger.info(f"调用LLM服务: {provider}/{model}", node_id=node_id, node_type=node_type)
|
||
|
||
# 根据是否启用工具选择不同的调用方式
|
||
if tools:
|
||
_tool_choice = node_data.get("tool_choice")
|
||
if not (isinstance(_tool_choice, str) and _tool_choice.strip()):
|
||
_tool_choice = None
|
||
# 单次执行内工具多轮迭代(默认 5,见 llm_service);节点可配 max_tool_iterations
|
||
_tool_extra: Dict[str, Any] = {}
|
||
_mi = node_data.get("max_tool_iterations") or node_data.get(
|
||
"max_tool_call_rounds"
|
||
)
|
||
if _mi is not None:
|
||
try:
|
||
_tool_extra["max_iterations"] = max(1, min(int(_mi), 64))
|
||
except (TypeError, ValueError):
|
||
pass
|
||
_rt = node_data.get("request_timeout")
|
||
if _rt is not None:
|
||
try:
|
||
_tool_extra["request_timeout"] = max(10.0, float(_rt))
|
||
except (TypeError, ValueError):
|
||
pass
|
||
_merged_tool_kw = dict(llm_extra_kw)
|
||
_merged_tool_kw.update(_tool_extra)
|
||
result = await llm_service.call_llm_with_tools(
|
||
prompt=prompt,
|
||
tools=tools,
|
||
provider=provider,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
execution_logger=self.logger,
|
||
tool_choice=_tool_choice,
|
||
on_tool_executed=self._on_tool_executed_budget,
|
||
system_prompt=llm_system_prompt,
|
||
**_merged_tool_kw,
|
||
)
|
||
result = self._enrich_llm_json_user_profile(result, input_data)
|
||
else:
|
||
result = await llm_service.call_llm(
|
||
prompt=prompt,
|
||
provider=provider,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
system_prompt=llm_system_prompt,
|
||
**llm_extra_kw,
|
||
)
|
||
result = self._enrich_llm_json_user_profile(result, input_data)
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
except Exception as e:
|
||
# LLM调用失败,返回错误
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
logger.error(f"[rjb] LLM节点执行失败: {str(e)}", exc_info=True)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'LLM调用失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'agent':
    # Agent node: autonomous ReAct loop that may perform multi-step tool calls.
    if self.logger:
        self.logger.info(
            "Agent 节点开始执行",
            data={"node_id": node_id, "input": input_data},
        )
    try:
        from app.agent_runtime.workflow_integration import run_agent_node

        # Forward the workflow's tool-budget callback to the agent, if present.
        _agent_on_tool = None
        if hasattr(self, '_on_tool_executed_budget'):
            _agent_on_tool = self._on_tool_executed_budget

        # The agent's LLM calls are charged against the workflow-wide LLM budget.
        def _on_agent_llm():
            self._llm_invocations += 1
            if self._llm_invocations > self._cap_llm:
                raise WorkflowExecutionError(
                    detail=f"已超过 LLM 节点调用预算({self._cap_llm} 次)",
                )

        result = await run_agent_node(
            node_data=node.get("data", {}),
            input_data=input_data,
            execution_logger=self.logger,
            user_id=self.trusted_model_config_user_id,
            on_tool_executed=_agent_on_tool,
            on_llm_invocation=_on_agent_llm,
            budget_limits={
                "max_llm_invocations": self._cap_llm,
                "max_tool_calls": self._cap_tool,
            },
        )
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(
                node_id, node_type, result.get("output"), duration,
            )
        return result
    except Exception as e:
        # Any agent failure is reported as a failed node result rather than
        # propagating and aborting the whole workflow.
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        logger.error(f"Agent 节点执行失败: {e}", exc_info=True)
        return {
            "output": None,
            "status": "failed",
            "error": f"Agent 执行失败: {e}",
        }
|
||
|
||
elif node_type == 'condition':
|
||
# 条件节点:判断分支(output 必须透传上游 dict,否则 sourceHandle true/false 下游只收到布尔值,丢失 reply/memory)
|
||
condition = node.get('data', {}).get('condition', '')
|
||
|
||
def _condition_passthrough(ok: bool, failed: bool = False) -> dict:
|
||
base = input_data if isinstance(input_data, dict) else {}
|
||
out = base.copy()
|
||
out['_condition_result'] = ok
|
||
if failed:
|
||
out['_condition_error'] = True
|
||
return out
|
||
|
||
if not condition:
|
||
exec_result = {
|
||
'output': _condition_passthrough(False),
|
||
'status': 'success',
|
||
'branch': 'false'
|
||
}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, {'result': False, 'branch': 'false'}, duration)
|
||
return exec_result
|
||
|
||
# 使用条件解析器评估表达式
|
||
try:
|
||
result = condition_parser.evaluate_condition(condition, input_data)
|
||
ok = bool(result)
|
||
exec_result = {
|
||
'output': _condition_passthrough(ok),
|
||
'status': 'success',
|
||
'branch': 'true' if ok else 'false'
|
||
}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, {'result': ok, 'branch': exec_result['branch']}, duration)
|
||
return exec_result
|
||
except Exception as e:
|
||
# 条件评估失败
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': _condition_passthrough(False, failed=True),
|
||
'status': 'failed',
|
||
'error': f'条件评估失败: {str(e)}',
|
||
'branch': 'false'
|
||
}
|
||
|
||
elif node_type == 'data' or node_type == 'transform':
    # Data-transform node: mapping / merge / filter / compute over the input payload.
    node_data = node.get('data', {})
    mapping = node_data.get('mapping', {})
    filter_rules = node_data.get('filter_rules', [])
    compute_rules = node_data.get('compute_rules', {})
    mode = node_data.get('mode', 'mapping')

    try:
        # Resolve {{variable}} placeholders in the mapping against input_data.
        # First, if input_data carries an `output` field, spread it to the top level.
        expanded_input = input_data.copy()
        if 'output' in input_data and isinstance(input_data['output'], dict):
            # Spread the contents of `output` to the top level, but keep `output` itself.
            expanded_input.update(input_data['output'])
        # Condition edges commonly use sourceHandle true/false; the payload sits
        # in a sub-dict and must be flattened to the top level.
        for _branch_key in ('true', 'false'):
            _bp = expanded_input.get(_branch_key)
            if isinstance(_bp, dict):
                expanded_input.update(_bp)
                for _k in ('true', 'false', '_condition_result', '_condition_error'):
                    expanded_input.pop(_k, None)
        # Flatten `left`: for dual-input transforms one upstream often arrives on
        # sourceHandle=left (the other being an LLM/code node's right).
        if isinstance(expanded_input.get('left'), dict):
            expanded_input.update(expanded_input['left'])
        # Flatten `right`: after merge / json-parse, reply and user_profile often
        # live under `right`, possibly as a nested JSON string.
        if isinstance(expanded_input.get('right'), dict):
            expanded_input.update(expanded_input['right'])
        elif isinstance(expanded_input.get('right'), str):
            rs = expanded_input['right'].strip()
            if rs.startswith('{'):
                try:
                    _rj = json.loads(rs)
                    if isinstance(_rj, dict):
                        expanded_input.update(_rj)
                except Exception:
                    pass
        # One more level: a dict `right` may itself contain a JSON string `right`.
        _r = expanded_input.get('right')
        if isinstance(_r, dict) and isinstance(_r.get('right'), str):
            _inner = _r['right'].strip()
            if _inner.startswith('{'):
                try:
                    _rj2 = json.loads(_inner)
                    if isinstance(_rj2, dict):
                        expanded_input.update(_rj2)
                except Exception:
                    pass

        processed_mapping = {}
        for target_key, source_expr in mapping.items():
            if isinstance(source_expr, str):
                # Support the {{variable}} placeholder format.
                double_brace_vars = re.findall(r'\{\{(\w+)\}\}', source_expr)
                if double_brace_vars:
                    # Look the variables up in expanded_input; stop at the first hit.
                    var_value = None
                    for var_name in double_brace_vars:
                        var_value = self._resolve_brace_template_var(
                            expanded_input, var_name
                        )
                        if var_value is not None:
                            break

                    if var_value is not None:
                        # A single variable is used directly as the value; with several
                        # variables each placeholder is substituted into the string.
                        if len(double_brace_vars) == 1:
                            processed_mapping[target_key] = var_value
                        else:
                            # Multiple variables: substitute into the expression.
                            processed_expr = source_expr
                            for var_name in double_brace_vars:
                                var_val = self._resolve_brace_template_var(
                                    expanded_input, var_name
                                )
                                if var_val is not None:
                                    # NOTE(review): `json_module` is not visible as an import
                                    # in this file view — confirm the alias exists.
                                    replacement = json_module.dumps(var_val, ensure_ascii=False) if isinstance(var_val, (dict, list)) else str(var_val)
                                    processed_expr = processed_expr.replace(f'{{{{{var_name}}}}}', replacement)
                            processed_mapping[target_key] = processed_expr
                    else:
                        # Variable not found: keep the original expression unchanged.
                        processed_mapping[target_key] = source_expr
                else:
                    # Not a {{variable}} expression: use as-is.
                    processed_mapping[target_key] = source_expr
            else:
                # Not a string: use as-is.
                processed_mapping[target_key] = source_expr

        # In merge mode all inputs are combined instead of projected via the mapping.
        if mode == 'merge':
            # Merge every upstream node's output (using the flattened data).
            result = expanded_input.copy()

            # Important: if the input carries memory fields (conversation_history,
            # user_profile, context), collect them into a `memory` object first.
            memory_fields = ['conversation_history', 'user_profile', 'context']
            memory_data = {}
            for field in memory_fields:
                if field in expanded_input:
                    memory_data[field] = expanded_input[field]

            # Attach the collected memory object, if any.
            if memory_data:
                result['memory'] = memory_data
                logger.info(f"[rjb] Transform节点 {node_id} 构建memory对象: {list(memory_data.keys())}")

            # Apply the mapping results (the mapping may overwrite memory fields).
            for key, value in processed_mapping.items():
                # If the mapped value is empty and the key is `memory`, fall back to
                # the memory object collected from expanded_input.
                if key == 'memory' and (value is None or value == '' or value == '{{output}}'):
                    if memory_data:
                        result[key] = memory_data
                        logger.info(f"[rjb] Transform节点 {node_id} mapping中的memory为空,使用构建的memory对象")
                    elif 'memory' in expanded_input:
                        result[key] = expanded_input['memory']
                else:
                    result[key] = value

            # Ensure the memory fields survive even if the mapping overwrote them.
            for field in memory_fields:
                if field in expanded_input and field not in result:
                    result[field] = expanded_input[field]
            # If the upstream `memory` field is a dict, reconcile its contents too.
            if 'memory' in expanded_input and isinstance(expanded_input['memory'], dict):
                if 'memory' not in result:
                    result['memory'] = expanded_input['memory'].copy()
                else:
                    # Merge the memory dicts (upstream values win on key conflicts).
                    if isinstance(result['memory'], dict):
                        result['memory'].update(expanded_input['memory'])

            logger.info(f"[rjb] Transform节点 {node_id} merge模式,结果keys: {list(result.keys())}")
            if 'memory' in result and isinstance(result['memory'], dict):
                if 'conversation_history' in result['memory']:
                    logger.info(f"[rjb] memory.conversation_history: {len(result['memory']['conversation_history'])} 条")
            elif 'conversation_history' in result:
                logger.info(f"[rjb] conversation_history: {len(result['conversation_history'])} 条")
        else:
            # Non-merge modes: delegate to the shared data transformer, using the
            # flattened input and the processed mapping.
            result = data_transformer.transform_data(
                input_data=expanded_input,
                mapping=processed_mapping,
                filter_rules=filter_rules,
                compute_rules=compute_rules,
                mode=mode
            )

        exec_result = {'output': result, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result, duration)
        return exec_result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        logger.error(f"[rjb] Transform节点执行失败: {str(e)}", exc_info=True)
        return {
            'output': None,
            'status': 'failed',
            'error': f'数据转换失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'loop' or node_type == 'foreach':
    # Loop node: iterate over an array, executing the loop body once per element.
    node_data = node.get('data', {})
    items_path = node_data.get('items_path', 'items')  # dotted path to the array in input_data
    item_variable = node_data.get('item_variable', 'item')  # loop-variable name exposed to the body

    # Pull the array out of the input payload.
    items = self._get_nested_value(input_data, items_path)

    # The target must be a list; anything else is a configuration error.
    if not isinstance(items, list):
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type,
                ValueError(f"路径 {items_path} 的值不是数组"), duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'路径 {items_path} 的值不是数组,当前类型: {type(items).__name__}'
        }

    if self.logger:
        self.logger.info(f"循环节点开始处理 {len(items)} 个元素",
            node_id=node_id, node_type=node_type,
            data={"items_count": len(items)})

    # Run the loop: execute the body once for each element.
    loop_results = []
    for index, item in enumerate(items):
        if self.logger:
            self.logger.info(f"循环迭代 {index + 1}/{len(items)}",
                node_id=node_id, node_type=node_type,
                data={"index": index, "item": item})

        # Build the body's input: the original payload plus per-iteration variables.
        loop_input = {
            **input_data,  # keep the original input
            item_variable: item,  # current element
            f'{item_variable}_index': index,  # zero-based index
            f'{item_variable}_total': len(items)  # total element count
        }

        # Execute the loop body (the loop node's child nodes).
        loop_body_result = await self._execute_loop_body(
            node_id, loop_input, index
        )

        if loop_body_result.get('status') == 'success':
            loop_results.append(loop_body_result.get('output', item))
        else:
            # On body failure, either stop the whole loop or continue per config.
            error_handling = node_data.get('error_handling', 'continue')
            if error_handling == 'stop':
                if self.logger:
                    duration = int((time.time() - start_time) * 1000)
                    self.logger.log_node_error(node_id, node_type,
                        Exception(f"循环体执行失败,停止循环: {loop_body_result.get('error')}"), duration)
                return {
                    'output': None,
                    'status': 'failed',
                    'error': f'循环体执行失败: {loop_body_result.get("error")}',
                    'completed_items': index,
                    'results': loop_results
                }
            else:
                # continue: log the error and carry on; a None placeholder keeps
                # result indices aligned with the input items.
                if self.logger:
                    self.logger.warn(f"循环迭代 {index + 1} 失败,继续执行",
                        node_id=node_id, node_type=node_type,
                        data={"error": loop_body_result.get('error')})
                loop_results.append(None)

    exec_result = {
        'output': loop_results,
        'status': 'success',
        'items_processed': len(items),
        'results_count': len(loop_results)
    }

    if self.logger:
        duration = int((time.time() - start_time) * 1000)
        self.logger.log_node_complete(node_id, node_type,
            {'results_count': len(loop_results)}, duration)

    return exec_result
|
||
|
||
elif node_type == 'http' or node_type == 'request':
    # HTTP request node: issue an outbound call with templated url/headers/params/body.
    node_data = node.get('data', {})
    url = node_data.get('url', '')
    method = node_data.get('method', 'GET').upper()
    headers = node_data.get('headers', {})
    params = node_data.get('params', {})
    body = node_data.get('body', {})
    timeout = node_data.get('timeout', 30)

    # Substitute variables from input_data into url/headers/params/body.
    def replace_variables(text: str, data: Dict[str, Any]) -> str:
        """Replace {key} / ${key} placeholders with values looked up in data."""
        if not isinstance(text, str):
            return text
        # Supports both {key} and ${key} forms; unresolved keys are left intact.
        pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
        def replacer(match):
            key = match.group(1) or match.group(2)
            value = self._get_nested_value(data, key)
            return str(value) if value is not None else match.group(0)
        return re.sub(pattern, replacer, text)

    # Substitute into the URL.
    if url:
        url = replace_variables(url, input_data)

    # Substitute into the headers; a string header config is parsed as JSON.
    if isinstance(headers, dict):
        headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
    elif isinstance(headers, str):
        try:
            headers = json.loads(replace_variables(headers, input_data))
        except:
            headers = {}

    # Substitute into the query params (only string values are templated).
    if isinstance(params, dict):
        params = {k: replace_variables(str(v), input_data) if isinstance(v, str) else v
                  for k, v in params.items()}
    elif isinstance(params, str):
        try:
            params = json.loads(replace_variables(params, input_data))
        except:
            params = {}

    # Substitute into the body.
    if isinstance(body, dict):
        # Recursively substitute inside nested dicts (keys are templated too).
        def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
            result = {}
            for k, v in d.items():
                new_k = replace_variables(k, data)
                if isinstance(v, dict):
                    result[new_k] = replace_dict_vars(v, data)
                elif isinstance(v, str):
                    result[new_k] = replace_variables(v, data)
                else:
                    result[new_k] = v
            return result
        body = replace_dict_vars(body, input_data)
    elif isinstance(body, str):
        body = replace_variables(body, input_data)
        # Best effort: a JSON-looking string body becomes a structured body.
        try:
            body = json.loads(body)
        except:
            pass

    try:
        import httpx
        async with httpx.AsyncClient(timeout=timeout) as client:
            if method == 'GET':
                response = await client.get(url, params=params, headers=headers)
            elif method == 'POST':
                response = await client.post(url, json=body, params=params, headers=headers)
            elif method == 'PUT':
                response = await client.put(url, json=body, params=params, headers=headers)
            elif method == 'DELETE':
                response = await client.delete(url, params=params, headers=headers)
            elif method == 'PATCH':
                response = await client.patch(url, json=body, params=params, headers=headers)
            else:
                raise ValueError(f"不支持的HTTP方法: {method}")

            # Try to decode a JSON response; fall back to the raw text.
            try:
                response_data = response.json()
            except:
                response_data = response.text

            result = {
                'output': {
                    'status_code': response.status_code,
                    'headers': dict(response.headers),
                    'data': response_data
                },
                'status': 'success'
            }

            if self.logger:
                duration = int((time.time() - start_time) * 1000)
                self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
            return result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'HTTP请求失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'database' or node_type == 'db':
    # Database node: run query/insert/update/delete against a configured data source.
    node_data = node.get('data', {})
    data_source_id = node_data.get('data_source_id')
    operation = node_data.get('operation', 'query')  # query/insert/update/delete
    sql = node_data.get('sql', '')
    table = node_data.get('table', '')
    data = node_data.get('data', {})
    where = node_data.get('where', {})

    # Substitute variables from input_data into the SQL text.
    if sql and isinstance(sql, str):
        def replace_sql_vars(text: str, data: Dict[str, Any]) -> str:
            pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
            def replacer(match):
                key = match.group(1) or match.group(2)
                value = self._get_nested_value(data, key)
                if value is None:
                    return match.group(0)
                # Strings need SQL escaping before inlining.
                if isinstance(value, str):
                    # Naive quote-doubling only — parameterized queries would be safer.
                    escaped_value = value.replace("'", "''")
                    return f"'{escaped_value}'"
                return str(value)
            return re.sub(pattern, replacer, text)
        sql = replace_sql_vars(sql, input_data)

    try:
        # Load the data-source configuration from the application database.
        if not self.db:
            raise ValueError("数据库会话未提供,无法执行数据库操作")

        from app.models.data_source import DataSource
        from app.services.data_source_connector import create_connector

        data_source = self.db.query(DataSource).filter(
            DataSource.id == data_source_id
        ).first()

        if not data_source:
            raise ValueError(f"数据源不存在: {data_source_id}")

        connector = create_connector(data_source.type, data_source.config)

        if operation == 'query':
            # Read query.
            if not sql:
                raise ValueError("查询操作需要提供SQL语句")
            query_params = {'query': sql}
            result_data = connector.query(query_params)
            result = {'output': result_data, 'status': 'success'}
        elif operation == 'insert':
            # Insert.
            if not table:
                raise ValueError("插入操作需要提供表名")
            # Build the INSERT statement. NOTE(review): values are inlined with quote
            # escaping only — still SQL-injection-prone; prefer bind parameters.
            columns = ', '.join(data.keys())
            # Escape single quotes in string values.
            def escape_value(v):
                if isinstance(v, str):
                    escaped = v.replace("'", "''")
                    return f"'{escaped}'"
                return str(v)
            values = ', '.join([escape_value(v) for v in data.values()])
            insert_sql = f"INSERT INTO {table} ({columns}) VALUES ({values})"
            query_params = {'query': insert_sql}
            result_data = connector.query(query_params)
            result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
        elif operation == 'update':
            # Update.
            if not table or not where:
                raise ValueError("更新操作需要提供表名和WHERE条件")
            # NOTE(review): unlike insert/delete, string values here are NOT
            # quote-escaped before inlining — confirm whether this is intended.
            set_clause = ', '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in data.items()])
            where_clause = ' AND '.join([f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" for k, v in where.items()])
            update_sql = f"UPDATE {table} SET {set_clause} WHERE {where_clause}"
            query_params = {'query': update_sql}
            result_data = connector.query(query_params)
            result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
        elif operation == 'delete':
            # Delete.
            if not table or not where:
                raise ValueError("删除操作需要提供表名和WHERE条件")
            # Escape single quotes in string values.
            def escape_sql_value(k, v):
                if isinstance(v, str):
                    escaped = v.replace("'", "''")
                    return f"{k} = '{escaped}'"
                return f"{k} = {v}"
            where_clause = ' AND '.join([escape_sql_value(k, v) for k, v in where.items()])
            delete_sql = f"DELETE FROM {table} WHERE {where_clause}"
            query_params = {'query': delete_sql}
            result_data = connector.query(query_params)
            result = {'output': {'affected_rows': 1, 'data': result_data}, 'status': 'success'}
        else:
            raise ValueError(f"不支持的数据库操作: {operation}")

        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
        return result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'数据库操作失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'file' or node_type == 'file_operation':
    # File node: read / write / upload / download files on the local filesystem.
    node_data = node.get('data', {})
    operation = node_data.get('operation', 'read')  # read/write/upload/download
    file_path = node_data.get('file_path', '')
    content = node_data.get('content', '')
    encoding = node_data.get('encoding', 'utf-8')

    # Substitute variables into the path and content.
    def replace_variables(text: str, data: Dict[str, Any]) -> str:
        """Replace {key} / ${key} placeholders with values looked up in data."""
        if not isinstance(text, str):
            return text
        pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
        def replacer(match):
            key = match.group(1) or match.group(2)
            value = self._get_nested_value(data, key)
            return str(value) if value is not None else match.group(0)
        return re.sub(pattern, replacer, text)

    if file_path:
        file_path = replace_variables(file_path, input_data)
    if isinstance(content, str):
        content = replace_variables(content, input_data)

    try:
        import os
        import base64
        from pathlib import Path

        if operation == 'read':
            # Read a file.
            if not file_path:
                raise ValueError("读取操作需要提供文件路径")

            if not os.path.exists(file_path):
                raise FileNotFoundError(f"文件不存在: {file_path}")

            # Choose the read mode from the file extension.
            file_ext = Path(file_path).suffix.lower()
            if file_ext == '.json':
                with open(file_path, 'r', encoding=encoding) as f:
                    data = json.load(f)
            elif file_ext in ['.txt', '.md', '.log']:
                with open(file_path, 'r', encoding=encoding) as f:
                    data = f.read()
            else:
                # Binary file: return base64-encoded content.
                with open(file_path, 'rb') as f:
                    data = base64.b64encode(f.read()).decode('utf-8')

            result = {'output': data, 'status': 'success'}

        elif operation == 'write':
            # Write a file.
            if not file_path:
                raise ValueError("写入操作需要提供文件路径")

            # Ensure the parent directory exists.
            os.makedirs(os.path.dirname(file_path) if os.path.dirname(file_path) else '.', exist_ok=True)

            # Serialize dict/list content as JSON text.
            if isinstance(content, (dict, list)):
                content = json.dumps(content, ensure_ascii=False, indent=2)

            # Choose the write mode from the file extension.
            file_ext = Path(file_path).suffix.lower()
            if file_ext == '.json':
                with open(file_path, 'w', encoding=encoding) as f:
                    json.dump(json.loads(content) if isinstance(content, str) else content, f, ensure_ascii=False, indent=2)
            else:
                with open(file_path, 'w', encoding=encoding) as f:
                    f.write(str(content))

            result = {'output': {'file_path': file_path, 'message': '文件写入成功'}, 'status': 'success'}

        elif operation == 'upload':
            # Upload: from base64 content, or fetched from a URL.
            upload_type = node_data.get('upload_type', 'base64')  # base64/url
            target_path = node_data.get('target_path', '')

            if upload_type == 'base64':
                # Take base64-encoded file content from the input payload.
                file_data = input_data.get('file_data') or input_data.get('content')
                if not file_data:
                    raise ValueError("上传操作需要提供file_data或content字段")

                # Decode base64 (raw bytes are accepted as-is).
                if isinstance(file_data, str):
                    file_bytes = base64.b64decode(file_data)
                else:
                    file_bytes = file_data

                # Write to the target path.
                if not target_path:
                    raise ValueError("上传操作需要提供target_path")

                os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
                with open(target_path, 'wb') as f:
                    f.write(file_bytes)

                result = {'output': {'file_path': target_path, 'message': '文件上传成功'}, 'status': 'success'}
            else:
                # URL upload: download the resource, then save it locally.
                import httpx
                url = node_data.get('url', '')
                if not url:
                    raise ValueError("URL上传需要提供url")

                async with httpx.AsyncClient() as client:
                    response = await client.get(url)
                    response.raise_for_status()

                    if not target_path:
                        # Derive a file name from the URL.
                        target_path = os.path.basename(url) or 'downloaded_file'

                    os.makedirs(os.path.dirname(target_path) if os.path.dirname(target_path) else '.', exist_ok=True)
                    with open(target_path, 'wb') as f:
                        f.write(response.content)

                result = {'output': {'file_path': target_path, 'message': '文件下载并保存成功'}, 'status': 'success'}

        elif operation == 'download':
            # Download: return base64-encoded content or the file path.
            download_format = node_data.get('download_format', 'base64')  # base64/url

            if not file_path:
                raise ValueError("下载操作需要提供文件路径")

            if not os.path.exists(file_path):
                raise FileNotFoundError(f"文件不存在: {file_path}")

            if download_format == 'base64':
                # Return the content base64-encoded.
                with open(file_path, 'rb') as f:
                    file_bytes = f.read()
                file_base64 = base64.b64encode(file_bytes).decode('utf-8')
                result = {'output': {'file_name': os.path.basename(file_path), 'content': file_base64, 'format': 'base64'}, 'status': 'success'}
            else:
                # Return the path (a real deployment might issue a temporary URL instead).
                result = {'output': {'file_path': file_path, 'format': 'path'}, 'status': 'success'}

        else:
            raise ValueError(f"不支持的文件操作: {operation}")

        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
        return result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'文件操作失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'webhook':
    # Webhook node: deliver a payload to an external system over HTTP.
    node_data = node.get('data', {})
    url = node_data.get('url', '')
    method = node_data.get('method', 'POST').upper()
    headers = node_data.get('headers', {})
    body = node_data.get('body', {})
    timeout = node_data.get('timeout', 30)

    # Substitute variables from input_data into url/headers/body.
    def replace_variables(text: str, data: Dict[str, Any]) -> str:
        """Replace {key} / ${key} placeholders with values looked up in data."""
        if not isinstance(text, str):
            return text
        pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
        def replacer(match):
            key = match.group(1) or match.group(2)
            value = self._get_nested_value(data, key)
            return str(value) if value is not None else match.group(0)
        return re.sub(pattern, replacer, text)

    # Substitute into the URL.
    if url:
        url = replace_variables(url, input_data)

    # Substitute into the headers; a string header config is parsed as JSON.
    if isinstance(headers, dict):
        headers = {k: replace_variables(str(v), input_data) for k, v in headers.items()}
    elif isinstance(headers, str):
        try:
            headers = json.loads(replace_variables(headers, input_data))
        except:
            headers = {}

    # Substitute into the body.
    if isinstance(body, dict):
        # Recursively substitute inside nested dicts (keys are templated too).
        def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
            result = {}
            for k, v in d.items():
                new_k = replace_variables(k, data)
                if isinstance(v, dict):
                    result[new_k] = replace_dict_vars(v, data)
                elif isinstance(v, str):
                    result[new_k] = replace_variables(v, data)
                else:
                    result[new_k] = v
            return result
        body = replace_dict_vars(body, input_data)
    elif isinstance(body, str):
        body = replace_variables(body, input_data)
        try:
            body = json.loads(body)
        except:
            pass

    # Default: forward the entire input payload when no body is configured.
    if not body:
        body = input_data

    try:
        import httpx
        async with httpx.AsyncClient(timeout=timeout) as client:
            if method == 'GET':
                response = await client.get(url, headers=headers)
            elif method == 'POST':
                response = await client.post(url, json=body, headers=headers)
            elif method == 'PUT':
                response = await client.put(url, json=body, headers=headers)
            elif method == 'PATCH':
                response = await client.patch(url, json=body, headers=headers)
            else:
                raise ValueError(f"Webhook不支持HTTP方法: {method}")

            # Try to decode a JSON response; fall back to the raw text.
            try:
                response_data = response.json()
            except:
                response_data = response.text

            result = {
                'output': {
                    'status_code': response.status_code,
                    'headers': dict(response.headers),
                    'data': response_data
                },
                'status': 'success'
            }

            if self.logger:
                duration = int((time.time() - start_time) * 1000)
                self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
            return result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'Webhook请求失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'schedule' or node_type == 'delay' or node_type == 'timer':
|
||
# 定时任务节点:延迟执行或定时执行
|
||
node_data = node.get('data', {})
|
||
delay_type = node_data.get('delay_type', 'fixed') # fixed: 固定延迟, cron: cron表达式
|
||
delay_value = node_data.get('delay_value', 0) # 延迟值(秒)
|
||
delay_unit = node_data.get('delay_unit', 'seconds') # seconds, minutes, hours
|
||
|
||
# 计算实际延迟时间(毫秒)
|
||
if delay_unit == 'seconds':
|
||
delay_ms = int(delay_value * 1000)
|
||
elif delay_unit == 'minutes':
|
||
delay_ms = int(delay_value * 60 * 1000)
|
||
elif delay_unit == 'hours':
|
||
delay_ms = int(delay_value * 60 * 60 * 1000)
|
||
else:
|
||
delay_ms = int(delay_value * 1000)
|
||
|
||
# 如果延迟时间大于0,则等待
|
||
if delay_ms > 0:
|
||
if self.logger:
|
||
self.logger.info(
|
||
f"定时任务节点等待 {delay_value} {delay_unit}",
|
||
node_id=node_id,
|
||
node_type=node_type,
|
||
data={'delay_ms': delay_ms, 'delay_value': delay_value, 'delay_unit': delay_unit}
|
||
)
|
||
await asyncio.sleep(delay_ms / 1000.0)
|
||
|
||
# 返回输入数据(定时节点只是延迟,不改变数据)
|
||
result = {'output': input_data, 'status': 'success', 'delay_ms': delay_ms}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
||
return result
|
||
|
||
elif node_type == 'email' or node_type == 'mail':
|
||
# 邮件节点:发送邮件通知
|
||
node_data = node.get('data', {})
|
||
smtp_host = node_data.get('smtp_host', '')
|
||
smtp_port = node_data.get('smtp_port', 587)
|
||
smtp_user = node_data.get('smtp_user', '')
|
||
smtp_password = node_data.get('smtp_password', '')
|
||
use_tls = node_data.get('use_tls', True)
|
||
from_email = node_data.get('from_email', '')
|
||
to_email = node_data.get('to_email', '')
|
||
cc_email = node_data.get('cc_email', '')
|
||
bcc_email = node_data.get('bcc_email', '')
|
||
subject = node_data.get('subject', '')
|
||
body = node_data.get('body', '')
|
||
body_type = node_data.get('body_type', 'text') # text/html
|
||
attachments = node_data.get('attachments', []) # 附件列表
|
||
|
||
# 替换变量
|
||
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
||
"""替换字符串中的变量占位符"""
|
||
if not isinstance(text, str):
|
||
return text
|
||
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
||
def replacer(match):
|
||
key = match.group(1) or match.group(2)
|
||
value = self._get_nested_value(data, key)
|
||
return str(value) if value is not None else match.group(0)
|
||
return re.sub(pattern, replacer, text)
|
||
|
||
# 替换所有配置中的变量
|
||
smtp_host = replace_variables(smtp_host, input_data)
|
||
smtp_user = replace_variables(smtp_user, input_data)
|
||
smtp_password = replace_variables(smtp_password, input_data)
|
||
from_email = replace_variables(from_email, input_data)
|
||
to_email = replace_variables(to_email, input_data)
|
||
cc_email = replace_variables(cc_email, input_data)
|
||
bcc_email = replace_variables(bcc_email, input_data)
|
||
subject = replace_variables(subject, input_data)
|
||
body = replace_variables(body, input_data)
|
||
|
||
# 验证必需参数
|
||
if not smtp_host:
|
||
raise ValueError("邮件节点需要配置SMTP服务器地址")
|
||
if not from_email:
|
||
raise ValueError("邮件节点需要配置发件人邮箱")
|
||
if not to_email:
|
||
raise ValueError("邮件节点需要配置收件人邮箱")
|
||
if not subject:
|
||
raise ValueError("邮件节点需要配置邮件主题")
|
||
|
||
try:
|
||
import aiosmtplib
|
||
from email.mime.text import MIMEText
|
||
from email.mime.multipart import MIMEMultipart
|
||
from email.mime.base import MIMEBase
|
||
from email import encoders
|
||
import base64
|
||
import os
|
||
|
||
# 创建邮件消息
|
||
msg = MIMEMultipart('alternative')
|
||
msg['From'] = from_email
|
||
msg['To'] = to_email
|
||
if cc_email:
|
||
msg['Cc'] = cc_email
|
||
msg['Subject'] = subject
|
||
|
||
# 添加邮件正文
|
||
if body_type == 'html':
|
||
msg.attach(MIMEText(body, 'html', 'utf-8'))
|
||
else:
|
||
msg.attach(MIMEText(body, 'plain', 'utf-8'))
|
||
|
||
# 处理附件
|
||
for attachment in attachments:
|
||
if isinstance(attachment, dict):
|
||
file_path = attachment.get('file_path', '')
|
||
file_name = attachment.get('file_name', '')
|
||
file_content = attachment.get('file_content', '') # base64编码的内容
|
||
|
||
# 替换变量
|
||
file_path = replace_variables(file_path, input_data)
|
||
file_name = replace_variables(file_name, input_data)
|
||
|
||
if file_path and os.path.exists(file_path):
|
||
# 从文件路径读取
|
||
with open(file_path, 'rb') as f:
|
||
file_data = f.read()
|
||
if not file_name:
|
||
file_name = os.path.basename(file_path)
|
||
elif file_content:
|
||
# 从base64内容读取
|
||
file_data = base64.b64decode(file_content)
|
||
if not file_name:
|
||
file_name = 'attachment'
|
||
else:
|
||
continue
|
||
|
||
# 添加附件
|
||
part = MIMEBase('application', 'octet-stream')
|
||
part.set_payload(file_data)
|
||
encoders.encode_base64(part)
|
||
part.add_header(
|
||
'Content-Disposition',
|
||
f'attachment; filename= {file_name}'
|
||
)
|
||
msg.attach(part)
|
||
|
||
# 发送邮件
|
||
recipients = [to_email]
|
||
if cc_email:
|
||
recipients.extend([email.strip() for email in cc_email.split(',')])
|
||
if bcc_email:
|
||
recipients.extend([email.strip() for email in bcc_email.split(',')])
|
||
|
||
async with aiosmtplib.SMTP(hostname=smtp_host, port=smtp_port) as smtp:
|
||
if use_tls:
|
||
await smtp.starttls()
|
||
if smtp_user and smtp_password:
|
||
await smtp.login(smtp_user, smtp_password)
|
||
await smtp.send_message(msg, recipients=recipients)
|
||
|
||
result = {
|
||
'output': {
|
||
'message': '邮件发送成功',
|
||
'from': from_email,
|
||
'to': to_email,
|
||
'subject': subject,
|
||
'recipients_count': len(recipients)
|
||
},
|
||
'status': 'success'
|
||
}
|
||
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
||
return result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'邮件发送失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'message_queue' or node_type == 'mq' or node_type == 'rabbitmq' or node_type == 'kafka':
|
||
# 消息队列节点:发送消息到RabbitMQ或Kafka
|
||
node_data = node.get('data', {})
|
||
queue_type = node_data.get('queue_type', 'rabbitmq') # rabbitmq/kafka
|
||
|
||
# 替换变量
|
||
def replace_variables(text: str, data: Dict[str, Any]) -> str:
|
||
"""替换字符串中的变量占位符"""
|
||
if not isinstance(text, str):
|
||
return text
|
||
pattern = r'\{([^}]+)\}|\$\{([^}]+)\}'
|
||
def replacer(match):
|
||
key = match.group(1) or match.group(2)
|
||
value = self._get_nested_value(data, key)
|
||
return str(value) if value is not None else match.group(0)
|
||
return re.sub(pattern, replacer, text)
|
||
|
||
try:
|
||
if queue_type == 'rabbitmq':
|
||
# RabbitMQ实现
|
||
import aio_pika
|
||
|
||
# 获取RabbitMQ配置
|
||
host = replace_variables(node_data.get('host', 'localhost'), input_data)
|
||
port = node_data.get('port', 5672)
|
||
username = replace_variables(node_data.get('username', 'guest'), input_data)
|
||
password = replace_variables(node_data.get('password', 'guest'), input_data)
|
||
exchange = replace_variables(node_data.get('exchange', ''), input_data)
|
||
routing_key = replace_variables(node_data.get('routing_key', ''), input_data)
|
||
queue_name = replace_variables(node_data.get('queue_name', ''), input_data)
|
||
message = node_data.get('message', input_data)
|
||
|
||
# 如果message是字符串,尝试替换变量
|
||
if isinstance(message, str):
|
||
message = replace_variables(message, input_data)
|
||
try:
|
||
message = json.loads(message)
|
||
except:
|
||
pass
|
||
elif isinstance(message, dict):
|
||
# 递归替换字典中的变量
|
||
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
|
||
result = {}
|
||
for k, v in d.items():
|
||
new_k = replace_variables(k, data)
|
||
if isinstance(v, dict):
|
||
result[new_k] = replace_dict_vars(v, data)
|
||
elif isinstance(v, str):
|
||
result[new_k] = replace_variables(v, data)
|
||
else:
|
||
result[new_k] = v
|
||
return result
|
||
message = replace_dict_vars(message, input_data)
|
||
|
||
# 如果没有配置message,使用input_data
|
||
if not message:
|
||
message = input_data
|
||
|
||
# 连接RabbitMQ
|
||
connection_url = f"amqp://{username}:{password}@{host}:{port}/"
|
||
connection = await aio_pika.connect_robust(connection_url)
|
||
channel = await connection.channel()
|
||
|
||
# 发送消息
|
||
message_body = json.dumps(message, ensure_ascii=False).encode('utf-8')
|
||
|
||
if exchange:
|
||
# 使用exchange和routing_key
|
||
await channel.default_exchange.publish(
|
||
aio_pika.Message(message_body),
|
||
routing_key=routing_key or queue_name
|
||
)
|
||
elif queue_name:
|
||
# 直接发送到队列
|
||
queue = await channel.declare_queue(queue_name, durable=True)
|
||
await channel.default_exchange.publish(
|
||
aio_pika.Message(message_body),
|
||
routing_key=queue_name
|
||
)
|
||
else:
|
||
raise ValueError("RabbitMQ节点需要配置exchange或queue_name")
|
||
|
||
await connection.close()
|
||
|
||
result = {
|
||
'output': {
|
||
'message': '消息已发送到RabbitMQ',
|
||
'queue_type': 'rabbitmq',
|
||
'exchange': exchange,
|
||
'routing_key': routing_key or queue_name,
|
||
'queue_name': queue_name,
|
||
'message_size': len(message_body)
|
||
},
|
||
'status': 'success'
|
||
}
|
||
|
||
elif queue_type == 'kafka':
|
||
# Kafka实现
|
||
from kafka import KafkaProducer
|
||
|
||
# 获取Kafka配置
|
||
bootstrap_servers = replace_variables(node_data.get('bootstrap_servers', 'localhost:9092'), input_data)
|
||
topic = replace_variables(node_data.get('topic', ''), input_data)
|
||
message = node_data.get('message', input_data)
|
||
|
||
# 如果message是字符串,尝试替换变量
|
||
if isinstance(message, str):
|
||
message = replace_variables(message, input_data)
|
||
try:
|
||
message = json.loads(message)
|
||
except:
|
||
pass
|
||
elif isinstance(message, dict):
|
||
# 递归替换字典中的变量
|
||
def replace_dict_vars(d: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
|
||
result = {}
|
||
for k, v in d.items():
|
||
new_k = replace_variables(k, data)
|
||
if isinstance(v, dict):
|
||
result[new_k] = replace_dict_vars(v, data)
|
||
elif isinstance(v, str):
|
||
result[new_k] = replace_variables(v, data)
|
||
else:
|
||
result[new_k] = v
|
||
return result
|
||
message = replace_dict_vars(message, input_data)
|
||
|
||
# 如果没有配置message,使用input_data
|
||
if not message:
|
||
message = input_data
|
||
|
||
if not topic:
|
||
raise ValueError("Kafka节点需要配置topic")
|
||
|
||
# 创建Kafka生产者(注意:kafka-python是同步的,需要在线程池中运行)
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
def send_kafka_message():
|
||
producer = KafkaProducer(
|
||
bootstrap_servers=bootstrap_servers.split(','),
|
||
value_serializer=lambda v: json.dumps(v, ensure_ascii=False).encode('utf-8')
|
||
)
|
||
future = producer.send(topic, message)
|
||
record_metadata = future.get(timeout=10)
|
||
producer.close()
|
||
return record_metadata
|
||
|
||
# 在线程池中执行同步操作
|
||
loop = asyncio.get_event_loop()
|
||
with ThreadPoolExecutor() as executor:
|
||
record_metadata = await loop.run_in_executor(executor, send_kafka_message)
|
||
|
||
result = {
|
||
'output': {
|
||
'message': '消息已发送到Kafka',
|
||
'queue_type': 'kafka',
|
||
'topic': topic,
|
||
'partition': record_metadata.partition,
|
||
'offset': record_metadata.offset
|
||
},
|
||
'status': 'success'
|
||
}
|
||
else:
|
||
raise ValueError(f"不支持的消息队列类型: {queue_type}")
|
||
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result.get('output'), duration)
|
||
return result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'消息队列发送失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'switch':
    # Switch node: route the flow to one of several branches based on a
    # single field's value.
    logger.info(f"[rjb] 执行Switch节点: node_id={node_id}, node_type={node_type}, input_data keys={list(input_data.keys()) if isinstance(input_data, dict) else 'not_dict'}")
    node_data = node.get('data', {})
    field = node_data.get('field', '')  # dotted path of the field to match on
    cases = node_data.get('cases', {})  # mapping: field value -> branch name
    default_case = node_data.get('default', 'default')  # branch when nothing matches
    logger.info(f"[rjb] Switch节点配置: field={field}, cases={cases}, default={default_case}")

    # Pre-process the input: recursively parse embedded JSON strings so that
    # fields nested inside serialized payloads become addressable.
    def parse_json_recursively(data, merge_parsed=True):
        """Recursively parse JSON strings found in *data*.

        Args:
            data: The value to process (str / dict / list / anything else).
            merge_parsed: When True, keys of a dict that was parsed out of a
                child value are also merged into the parent level (without
                overwriting existing keys) to make field extraction easier.
        """
        import json
        if isinstance(data, str):
            # Strings: attempt a JSON parse and recurse into the result.
            try:
                parsed = json.loads(data)
                # Recurse only into structured results; scalars are returned.
                if isinstance(parsed, (dict, list)):
                    return parse_json_recursively(parsed, merge_parsed)
                return parsed
            except:
                # Not JSON — keep the original string (deliberate best-effort).
                return data
        elif isinstance(data, dict):
            # Dicts: process every value recursively.
            result = {}
            for key, value in data.items():
                # NOTE: children are parsed with merge_parsed=False; flattening
                # only happens at this level, and never overwrites existing keys.
                parsed_value = parse_json_recursively(value, merge_parsed=False)
                result[key] = parsed_value
                # Optionally flatten a parsed dict's keys into this level
                # (convenience for dotted-path field extraction).
                if merge_parsed and isinstance(parsed_value, dict):
                    for k, v in parsed_value.items():
                        if k not in result:
                            result[k] = v
            return result
        elif isinstance(data, list):
            # Lists: process every element recursively.
            return [parse_json_recursively(item, merge_parsed) for item in data]
        else:
            # Any other type passes through unchanged.
            return data

    processed_input = parse_json_recursively(input_data, merge_parsed=True)

    # Resolve the routing field (supports dotted paths) from the processed input.
    field_value = self._get_nested_value(processed_input, field)
    field_value_str = str(field_value) if field_value is not None else ''

    # Match against the configured cases: string form first, then the raw value.
    # NOTE(review): the raw-value lookup raises TypeError for unhashable values
    # (dict/list); presumably case values are always scalars — confirm.
    matched_case = default_case
    if field_value_str in cases:
        matched_case = cases[field_value_str]
    elif field_value in cases:
        matched_case = cases[field_value]

    # Record detailed match information (console log and execution log).
    match_info = {
        'field': field,
        'field_value': field_value,
        'field_value_str': field_value_str,
        'matched_case': matched_case,
        'processed_input_keys': list(processed_input.keys()) if isinstance(processed_input, dict) else 'not_dict',
        'cases_keys': list(cases.keys())
    }
    logger.info(f"[rjb] Switch节点匹配: node_id={node_id}, {match_info}")
    if self.logger:
        self.logger.info(
            f"Switch节点匹配: field={field}, field_value={field_value}, matched_case={matched_case}",
            node_id=node_id,
            node_type=node_type,
            data=match_info
        )

    # 'branch' tells the engine which outgoing edge to follow.
    exec_result = {
        'output': processed_input,
        'status': 'success',
        'branch': matched_case,
        'matched_value': field_value
    }
    if self.logger:
        duration = int((time.time() - start_time) * 1000)
        self.logger.log_node_complete(node_id, node_type, {'branch': matched_case, 'value': field_value}, duration)
    return exec_result
|
||
|
||
elif node_type == 'merge':
|
||
# Merge节点:合并多个分支的数据流
|
||
node_data = node.get('data', {})
|
||
mode = node_data.get('mode', 'merge_all') # merge_all, merge_first, merge_last
|
||
strategy = node_data.get('strategy', 'array') # array, object, concat
|
||
|
||
# 获取所有上游节点的输出(通过input_data中的特殊字段)
|
||
# 如果input_data包含多个分支的数据,合并它们
|
||
merged_data = {}
|
||
|
||
if strategy == 'array':
|
||
# 数组策略:将所有输入数据作为数组元素
|
||
if isinstance(input_data, list):
|
||
merged_data = input_data
|
||
elif isinstance(input_data, dict):
|
||
# 如果包含多个分支数据,提取为数组
|
||
branch_data = []
|
||
for key, value in input_data.items():
|
||
if not key.startswith('_'):
|
||
branch_data.append(value)
|
||
merged_data = branch_data if branch_data else [input_data]
|
||
else:
|
||
merged_data = [input_data]
|
||
|
||
elif strategy == 'object':
|
||
# 对象策略:合并所有字段
|
||
if isinstance(input_data, dict):
|
||
merged_data = input_data.copy()
|
||
else:
|
||
merged_data = {'data': input_data}
|
||
|
||
elif strategy == 'concat':
|
||
# 连接策略:将所有数据连接为字符串
|
||
if isinstance(input_data, list):
|
||
merged_data = '\n'.join(str(item) for item in input_data)
|
||
elif isinstance(input_data, dict):
|
||
merged_data = '\n'.join(f"{k}: {v}" for k, v in input_data.items() if not k.startswith('_'))
|
||
else:
|
||
merged_data = str(input_data)
|
||
|
||
exec_result = {'output': merged_data, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, merged_data, duration)
|
||
return exec_result
|
||
|
||
elif node_type == 'wait':
    # Wait node: block the workflow until a condition holds, a fixed delay
    # elapses, or (future) an event arrives.
    node_data = node.get('data', {})
    wait_type = node_data.get('wait_type', 'condition')  # condition, time, event
    condition = node_data.get('condition', '')
    timeout = node_data.get('timeout', 300)  # seconds; default 5 minutes
    poll_interval = node_data.get('poll_interval', 5)  # seconds; default 5s

    if wait_type == 'condition':
        # Poll the condition until it evaluates truthy or the timeout elapses.
        start_wait = time.time()
        while time.time() - start_wait < timeout:
            try:
                result = condition_parser.evaluate_condition(condition, input_data)
                if result:
                    # Condition satisfied: pass the input through unchanged
                    # and report how long we waited.
                    exec_result = {'output': input_data, 'status': 'success', 'waited': time.time() - start_wait}
                    if self.logger:
                        duration = int((time.time() - start_time) * 1000)
                        self.logger.log_node_complete(node_id, node_type, exec_result, duration)
                    return exec_result
            except Exception as e:
                # An evaluation error does not abort the wait; keep polling.
                logger.warning(f"Wait节点条件评估失败: {str(e)}")

            # NOTE: the sleep happens after each check, so the total wait can
            # overshoot `timeout` by up to one poll_interval.
            await asyncio.sleep(poll_interval)

        # Timed out without the condition ever holding.
        exec_result = {
            'output': input_data,
            'status': 'failed',
            'error': f'等待条件超时: {timeout}秒'
        }
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, Exception("等待超时"), duration)
        return exec_result

    elif wait_type == 'time':
        # Fixed-delay wait: sleep for the configured number of seconds.
        wait_seconds = node_data.get('wait_seconds', 0)
        await asyncio.sleep(wait_seconds)
        exec_result = {'output': input_data, 'status': 'success', 'waited': wait_seconds}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, exec_result, duration)
        return exec_result

    else:
        # Other wait types (e.g. 'event') are not supported yet: pass through.
        exec_result = {'output': input_data, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, exec_result, duration)
        return exec_result
|
||
|
||
elif node_type == 'json':
    # JSON processing node: parse / stringify / extract / validate.
    # NOTE(review): `json_module` is presumably an alias of the stdlib `json`
    # bound earlier in this method — confirm against the enclosing scope.
    node_data = node.get('data', {})
    operation = node_data.get('operation', 'parse')  # parse, stringify, extract, validate
    path = node_data.get('path', '')
    schema = node_data.get('schema', {})

    try:
        if operation == 'parse':
            # Parse a JSON string into a Python structure.
            if isinstance(input_data, str):
                result = json_module.loads(input_data)
            elif isinstance(input_data, dict) and 'data' in input_data:
                # If the input carries a 'data' field, parse/unwrap that.
                if isinstance(input_data['data'], str):
                    result = json_module.loads(input_data['data'])
                else:
                    result = input_data['data']
            else:
                # Already structured — pass through unchanged.
                result = input_data

        elif operation == 'stringify':
            # Serialize the input to a pretty-printed JSON string.
            result = json_module.dumps(input_data, ensure_ascii=False, indent=2)

        elif operation == 'extract':
            # Simplified JSONPath-style extraction supporting `$.a.b` paths.
            if path and isinstance(input_data, dict):
                # Strip the leading `$.` / `$` and walk the dotted path.
                path = path.replace('$.', '').replace('$', '')
                keys = path.split('.')
                result = input_data
                for key in keys:
                    if key.endswith('[*]'):
                        # Array segment: return the whole list under `key[*]`.
                        # NOTE(review): elements are not iterated further;
                        # a miss here also does not reset result to None.
                        array_key = key[:-3]
                        if isinstance(result, dict) and array_key in result:
                            result = result[array_key]
                    elif isinstance(result, dict) and key in result:
                        result = result[key]
                    else:
                        # Path segment missing: extraction yields None.
                        result = None
                        break
            else:
                result = input_data

        elif operation == 'validate':
            # Minimal JSON-Schema-style validation: top-level type check only.
            # (A full implementation would use the jsonschema library.)
            if schema:
                if 'type' in schema:
                    expected_type = schema['type']
                    actual_type = type(input_data).__name__
                    if expected_type == 'object' and actual_type != 'dict':
                        raise ValueError(f"期望类型 {expected_type},实际类型 {actual_type}")
                    elif expected_type == 'array' and actual_type != 'list':
                        raise ValueError(f"期望类型 {expected_type},实际类型 {actual_type}")
            result = input_data

        else:
            # Unknown operation: pass the input through unchanged.
            result = input_data

        exec_result = {'output': result, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result, duration)
        return exec_result

    except Exception as e:
        # Any parse/validate failure marks the node failed with the reason.
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'JSON处理失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'text':
    # Text processing node: split / join / extract / replace / format.
    node_data = node.get('data', {})
    operation = node_data.get('operation', 'split')  # split, join, extract, replace, format
    delimiter = node_data.get('delimiter', '\n')
    regex = node_data.get('regex', '')
    template = node_data.get('template', '')

    try:
        # Resolve the working text from the input.
        input_text = input_data
        if isinstance(input_data, dict):
            # Pull the first string value found under these well-known keys.
            for key in ['text', 'content', 'message', 'input', 'output']:
                if key in input_data and isinstance(input_data[key], str):
                    input_text = input_data[key]
                    break
        # NOTE(review): both branches below coerce to str; the dict case is
        # redundant with the elif — kept as-is to preserve behavior.
        if isinstance(input_text, dict):
            input_text = str(input_text)
        elif not isinstance(input_text, str):
            input_text = str(input_text)

        if operation == 'split':
            # Split the text on the configured delimiter.
            result = input_text.split(delimiter)

        elif operation == 'join':
            # Join requires a list input; otherwise the text passes through.
            if isinstance(input_data, list):
                result = delimiter.join(str(item) for item in input_data)
            else:
                result = input_text

        elif operation == 'extract':
            # Regex extraction: list for multiple matches, scalar for one,
            # empty string for none.
            if regex:
                matches = re.findall(regex, input_text)
                result = matches if len(matches) > 1 else (matches[0] if matches else '')
            else:
                result = input_text

        elif operation == 'replace':
            # Replace either by regex (when configured) or plain substring.
            old_text = node_data.get('old_text', '')
            new_text = node_data.get('new_text', '')
            if regex:
                result = re.sub(regex, new_text, input_text)
            else:
                result = input_text.replace(old_text, new_text)

        elif operation == 'format':
            # Template formatting with {key} placeholders.
            if template:
                result = template
                if isinstance(input_data, dict):
                    # Substitute each input field into its placeholder.
                    for key, value in input_data.items():
                        result = result.replace(f'{{{key}}}', str(value))
                else:
                    # Non-dict input fills the generic {value} placeholder.
                    result = result.replace('{value}', str(input_data))
            else:
                result = input_text

        else:
            # Unknown operation: pass the resolved text through unchanged.
            result = input_text

        exec_result = {'output': result, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result, duration)
        return exec_result

    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'文本处理失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'cache':
|
||
# 缓存节点
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'get') # get, set, delete, clear
|
||
key = node_data.get('key', '')
|
||
ttl = node_data.get('ttl', 3600) # 默认1小时
|
||
# 默认优先使用redis(如果配置了),否则memory
|
||
backend = node_data.get('backend') or ('redis' if getattr(settings, 'REDIS_URL', None) else 'memory') # redis, memory
|
||
default_value = node_data.get('default_value', '{}')
|
||
value_template = node_data.get('value', '')
|
||
|
||
# 使用Redis作为持久化缓存(如果可用),否则使用内存缓存
|
||
# 注意:内存缓存在单次执行会话内有效,跨执行不会保留
|
||
use_redis = False
|
||
redis_client = None
|
||
# 默认尝试使用Redis(如果配置了),除非明确指定使用memory
|
||
if backend != 'memory':
|
||
try:
|
||
from app.core.redis_client import get_redis_client
|
||
redis_client = get_redis_client()
|
||
if redis_client:
|
||
use_redis = True
|
||
logger.info(f"[rjb] Cache节点 {node_id} 使用Redis缓存")
|
||
except Exception as e:
|
||
logger.warning(f"Redis不可用: {str(e)},使用内存缓存")
|
||
|
||
# 内存缓存(单次执行会话内有效)
|
||
if not hasattr(self, '_cache_store'):
|
||
self._cache_store = {}
|
||
self._cache_timestamps = {}
|
||
|
||
try:
|
||
# 替换key中的变量
|
||
if isinstance(input_data, dict):
|
||
# 首先处理 {{variable}} 格式
|
||
double_brace_vars = re.findall(r'\{\{(\w+)\}\}', key)
|
||
for var_name in double_brace_vars:
|
||
if var_name in input_data:
|
||
key = key.replace(f'{{{{{var_name}}}}}', str(input_data[var_name]))
|
||
else:
|
||
# 如果变量不存在,使用默认值
|
||
if var_name == 'user_id':
|
||
# 尝试从输入数据中提取user_id,如果没有则使用"default"
|
||
user_id = input_data.get('user_id') or input_data.get('USER_ID') or 'default'
|
||
key = key.replace(f'{{{{{var_name}}}}}', str(user_id))
|
||
else:
|
||
key = key.replace(f'{{{{{var_name}}}}}', 'default')
|
||
|
||
# 然后处理 {key} 格式
|
||
for k, v in input_data.items():
|
||
key = key.replace(f'{{{k}}}', str(v))
|
||
|
||
# 如果key中还有未替换的变量,使用默认值
|
||
if '{' in key:
|
||
key = key.replace('{user_id}', 'default').replace('{{user_id}}', 'default')
|
||
# 清理其他未替换的变量
|
||
key = re.sub(r'\{[^}]+\}', 'default', key)
|
||
|
||
logger.info(f"[rjb] Cache节点 {node_id} 处理后的key: {key}")
|
||
|
||
if operation == 'get':
|
||
# 获取缓存
|
||
result = None
|
||
cache_hit = False
|
||
|
||
if use_redis and redis_client:
|
||
# 从Redis获取
|
||
try:
|
||
cached_data = redis_client.get(key)
|
||
if cached_data:
|
||
result = json_module.loads(cached_data)
|
||
cache_hit = True
|
||
except Exception as e:
|
||
logger.warning(f"从Redis获取缓存失败: {str(e)}")
|
||
|
||
if result is None:
|
||
# 从内存缓存获取
|
||
if key in self._cache_store:
|
||
# 检查是否过期
|
||
if key in self._cache_timestamps:
|
||
if time.time() - self._cache_timestamps[key] > ttl:
|
||
# 过期,删除
|
||
del self._cache_store[key]
|
||
del self._cache_timestamps[key]
|
||
else:
|
||
result = self._cache_store[key]
|
||
cache_hit = True
|
||
else:
|
||
result = self._cache_store[key]
|
||
cache_hit = True
|
||
|
||
# 永久记忆:MySQL 与 Redis 合并(Redis 过期或冷启动时仍可从 DB 恢复)
|
||
try:
|
||
from app.services import persistent_memory_service as _pmem
|
||
|
||
if (
|
||
_pmem.persist_enabled()
|
||
and self.db
|
||
and _pmem.is_user_memory_redis_key(key)
|
||
):
|
||
sk = _pmem.session_key_from_user_memory_key(key)
|
||
skind, sid = self._get_persist_scope()
|
||
if sk and skind and sid:
|
||
db_payload = _pmem.load_persistent_memory(self.db, skind, sid, sk)
|
||
if db_payload is not None:
|
||
result = _pmem.merge_memory_payloads(db_payload, result)
|
||
except Exception as _pe:
|
||
logger.warning(f"加载持久记忆失败: {_pe}")
|
||
|
||
# 如果缓存未命中,使用default_value
|
||
if result is None:
|
||
try:
|
||
if isinstance(default_value, str):
|
||
result = json_module.loads(default_value) if default_value else {}
|
||
else:
|
||
result = default_value
|
||
except:
|
||
result = {}
|
||
cache_hit = False
|
||
logger.info(f"[rjb] Cache节点 {node_id} cache miss,使用default_value: {result}")
|
||
|
||
# 合并输入数据和缓存结果
|
||
output = input_data.copy() if isinstance(input_data, dict) else {}
|
||
if isinstance(result, dict):
|
||
output.update(result)
|
||
else:
|
||
output['memory'] = result
|
||
|
||
exec_result = {'output': output, 'status': 'success', 'cache_hit': cache_hit, 'memory': result}
|
||
|
||
elif operation == 'set':
|
||
# 设置缓存
|
||
# 处理value模板
|
||
if value_template:
|
||
# 处理模板语法 {{variable}}
|
||
value_str = value_template
|
||
|
||
# 替换 {{variable}} 格式的变量
|
||
# 注意:只替换 memory.* 路径的变量,user_input、output、timestamp 等变量在Python表达式执行阶段处理
|
||
template_vars = re.findall(r'\{\{(\w+(?:\.\w+)*)\}\}', value_str)
|
||
for var_path in template_vars:
|
||
# 跳过 user_input、output、timestamp 等变量,这些在Python表达式执行阶段处理
|
||
if var_path in ['user_input', 'output', 'timestamp']:
|
||
continue
|
||
|
||
# 支持嵌套路径,如 memory.conversation_history
|
||
var_parts = var_path.split('.')
|
||
var_value = input_data
|
||
try:
|
||
for part in var_parts:
|
||
if isinstance(var_value, dict) and part in var_value:
|
||
var_value = var_value[part]
|
||
else:
|
||
var_value = None
|
||
break
|
||
|
||
if var_value is not None:
|
||
# 替换模板变量
|
||
replacement = json_module.dumps(var_value, ensure_ascii=False) if isinstance(var_value, (dict, list)) else str(var_value)
|
||
value_str = value_str.replace(f'{{{{{var_path}}}}}', replacement)
|
||
else:
|
||
# 变量不存在,根据路径使用合适的默认值
|
||
if 'conversation_history' in var_path:
|
||
value_str = value_str.replace(f'{{{{{var_path}}}}}', '[]')
|
||
elif 'user_profile' in var_path or 'context' in var_path:
|
||
value_str = value_str.replace(f'{{{{{var_path}}}}}', '{}')
|
||
else:
|
||
# 对于其他变量,保留原样,让Python表达式执行阶段处理
|
||
pass
|
||
except Exception as e:
|
||
logger.warning(f"处理模板变量 {var_path} 失败: {str(e)}")
|
||
|
||
# 替换 {key} 格式的变量(但不要替换 {{variable}} 格式的)
|
||
for k, v in input_data.items():
|
||
placeholder = f'{{{k}}}'
|
||
# 确保不是 {{variable}} 格式
|
||
if placeholder in value_str and f'{{{{{k}}}}}' not in value_str:
|
||
replacement = json_module.dumps(v, ensure_ascii=False) if isinstance(v, (dict, list)) else str(v)
|
||
value_str = value_str.replace(placeholder, replacement)
|
||
|
||
# 解析处理后的value(可能是JSON字符串或Python表达式)
|
||
try:
|
||
# 尝试作为JSON解析
|
||
value = json_module.loads(value_str)
|
||
except:
|
||
# 如果不是有效的JSON,尝试作为Python表达式执行(安全限制)
|
||
try:
|
||
# 准备安全的环境变量
|
||
from datetime import datetime
|
||
memory = input_data.get('memory', {})
|
||
if not isinstance(memory, dict):
|
||
memory = {}
|
||
|
||
# 合并本轮 LLM 的 user_profile_update,便于多轮记住姓名等信息
|
||
upd = input_data.get('user_profile_update')
|
||
if isinstance(upd, str) and upd.strip().startswith('{'):
|
||
try:
|
||
upd = json_module.loads(upd)
|
||
except Exception:
|
||
upd = {}
|
||
if not isinstance(upd, dict):
|
||
upd = {}
|
||
if not upd.get("name"):
|
||
prof = self._extract_user_profile_from_llm_node_outputs()
|
||
if prof:
|
||
upd = {**upd, **prof}
|
||
base_up = memory.get('user_profile') or {}
|
||
if not isinstance(base_up, dict):
|
||
base_up = {}
|
||
memory['user_profile'] = {**base_up, **upd}
|
||
|
||
hb_upd = input_data.get('homework_board_update')
|
||
if isinstance(hb_upd, str) and hb_upd.strip().startswith('{'):
|
||
try:
|
||
hb_upd = json_module.loads(hb_upd)
|
||
except Exception:
|
||
hb_upd = {}
|
||
if not isinstance(hb_upd, dict):
|
||
hb_upd = {}
|
||
if hb_upd:
|
||
ctx = memory.get('context')
|
||
if not isinstance(ctx, dict):
|
||
ctx = {}
|
||
base_hb = ctx.get('homework_board')
|
||
if not isinstance(base_hb, dict):
|
||
base_hb = {}
|
||
merged_hb = {**base_hb, **hb_upd}
|
||
new_items = hb_upd.get('items')
|
||
old_items = base_hb.get('items')
|
||
if isinstance(new_items, list) and len(new_items) > 0:
|
||
merged_hb['items'] = new_items
|
||
elif isinstance(old_items, list):
|
||
merged_hb['items'] = old_items
|
||
ctx['homework_board'] = merged_hb
|
||
memory['context'] = ctx
|
||
|
||
# 确保memory中有必要的字段
|
||
if 'conversation_history' not in memory:
|
||
memory['conversation_history'] = []
|
||
if 'context' not in memory:
|
||
memory['context'] = {}
|
||
|
||
# 获取 user_input(与 LLM 一致,支持 right 等嵌套)
|
||
user_input = self._extract_user_message_text(input_data)
|
||
|
||
# 获取助手回复文本:避免 {{reply}} 占位或未解析的 JSON 串写入记忆
|
||
output = input_data.get('right', '')
|
||
if isinstance(output, dict):
|
||
output = output.get('reply') or output.get('right', '') or output.get('content', '') or str(output)
|
||
if isinstance(output, str) and output.strip().startswith('{'):
|
||
try:
|
||
_jo = json_module.loads(output)
|
||
if isinstance(_jo, dict) and _jo.get('reply'):
|
||
output = _jo['reply']
|
||
except Exception:
|
||
pass
|
||
if not output:
|
||
output = ''
|
||
os = str(output).strip()
|
||
if not os or os in ('{{reply}}', '{{right}}', '{{output}}') or (
|
||
os.startswith('{{') and os.endswith('}}')
|
||
):
|
||
reply_guess = self._extract_reply_from_llm_node_outputs()
|
||
if reply_guess:
|
||
output = reply_guess
|
||
|
||
timestamp = datetime.now().isoformat()
|
||
|
||
# 在Python表达式执行前,替换 {{user_input}}、{{output}}、{{timestamp}}
|
||
# 注意:模板中已经有引号了,所以需要转义字符串中的特殊字符,然后直接插入
|
||
# 使用json.dumps来正确转义,但去掉外层的引号(因为模板中已经有引号了)
|
||
user_input_escaped = json_module.dumps(user_input, ensure_ascii=False)[1:-1] # 去掉首尾引号
|
||
output_escaped = json_module.dumps(output, ensure_ascii=False)[1:-1]
|
||
timestamp_escaped = json_module.dumps(timestamp, ensure_ascii=False)[1:-1]
|
||
|
||
value_str = value_str.replace('{{user_input}}', user_input_escaped)
|
||
value_str = value_str.replace('{{output}}', output_escaped)
|
||
value_str = value_str.replace('{{timestamp}}', timestamp_escaped)
|
||
|
||
# 只允许基本的字典和列表操作
|
||
safe_dict = {
|
||
'memory': memory,
|
||
'user_input': user_input,
|
||
'output': output,
|
||
'timestamp': timestamp
|
||
}
|
||
|
||
logger.info(f"[rjb] Cache节点 {node_id} 执行value模板")
|
||
logger.info(f"[rjb] value_str前300字符: {value_str[:300]}")
|
||
logger.info(f"[rjb] user_input: {user_input[:50]}, output: {str(output)[:50]}, timestamp: {timestamp}")
|
||
value = eval(value_str, {"__builtins__": {}}, safe_dict)
|
||
logger.info(f"[rjb] Cache节点 {node_id} value模板执行成功,类型: {type(value)}")
|
||
|
||
# 确保 conversation_history 只保留最近若干条(性能优化,可在 Cache 节点 data.max_history_length 配置)
|
||
if isinstance(value, dict) and 'conversation_history' in value:
|
||
if isinstance(value['conversation_history'], list):
|
||
max_history_length = int(node_data.get('max_history_length', 20))
|
||
if len(value['conversation_history']) > max_history_length:
|
||
value['conversation_history'] = value['conversation_history'][-max_history_length:]
|
||
logger.info(f"[rjb] 对话历史已截断,保留最近 {max_history_length} 条")
|
||
|
||
if isinstance(value, dict):
|
||
logger.info(f"[rjb] keys: {list(value.keys())}")
|
||
if 'conversation_history' in value:
|
||
logger.info(f"[rjb] conversation_history: {len(value['conversation_history'])} 条")
|
||
if value['conversation_history']:
|
||
logger.info(f"[rjb] 第一条: {value['conversation_history'][0]}")
|
||
except Exception as e:
|
||
logger.error(f"Cache节点 {node_id} value模板执行失败: {str(e)}")
|
||
logger.error(f"value_str: {value_str[:500]}")
|
||
logger.error(f"safe_dict: {safe_dict}")
|
||
import traceback
|
||
logger.error(f"traceback: {traceback.format_exc()}")
|
||
# 如果都失败,使用原始输入数据
|
||
value = input_data
|
||
else:
|
||
# 没有value模板,使用输入数据
|
||
value = input_data
|
||
if isinstance(input_data, dict) and 'value' in input_data:
|
||
value = input_data['value']
|
||
|
||
# 存储到缓存
|
||
if use_redis and redis_client:
|
||
try:
|
||
redis_client.setex(key, ttl, json_module.dumps(value, ensure_ascii=False))
|
||
logger.info(f"[rjb] Cache节点 {node_id} 已存储到Redis: key={key}")
|
||
except Exception as e:
|
||
logger.warning(f"存储到Redis失败: {str(e)}")
|
||
|
||
# 永久记忆:写入 MySQL(与 user_memory_* 键一致)
|
||
try:
|
||
from app.services import persistent_memory_service as _pmem
|
||
|
||
if (
|
||
_pmem.persist_enabled()
|
||
and self.db
|
||
and _pmem.is_user_memory_redis_key(key)
|
||
and isinstance(value, dict)
|
||
):
|
||
sk = _pmem.session_key_from_user_memory_key(key)
|
||
skind, sid = self._get_persist_scope()
|
||
if sk and skind and sid:
|
||
_pmem.save_persistent_memory(self.db, skind, sid, sk, value)
|
||
logger.info(f"[rjb] 已持久化记忆到数据库: scope={skind}:{sid}, session={sk[:48]}")
|
||
except Exception as _pse:
|
||
logger.warning(f"持久化记忆到数据库失败: {_pse}")
|
||
|
||
# 同时存储到内存缓存
|
||
self._cache_store[key] = value
|
||
self._cache_timestamps[key] = time.time()
|
||
logger.info(f"[rjb] Cache节点 {node_id} 已存储: key={key}, value类型={type(value)}")
|
||
|
||
exec_result = {'output': input_data, 'status': 'success', 'cached_value': value}
|
||
|
||
elif operation == 'delete':
|
||
# 删除缓存
|
||
if use_redis and redis_client:
|
||
try:
|
||
redis_client.delete(key)
|
||
except Exception as _de:
|
||
logger.warning(f"从Redis删除缓存失败: {_de}")
|
||
if key in self._cache_store:
|
||
del self._cache_store[key]
|
||
if key in self._cache_timestamps:
|
||
del self._cache_timestamps[key]
|
||
try:
|
||
from app.services import persistent_memory_service as _pmem
|
||
|
||
if (
|
||
_pmem.persist_enabled()
|
||
and self.db
|
||
and _pmem.is_user_memory_redis_key(key)
|
||
):
|
||
sk = _pmem.session_key_from_user_memory_key(key)
|
||
skind, sid = self._get_persist_scope()
|
||
if sk and skind and sid:
|
||
_pmem.delete_persistent_memory(self.db, skind, sid, sk)
|
||
except Exception as _pde:
|
||
logger.warning(f"删除持久记忆失败: {_pde}")
|
||
exec_result = {'output': input_data, 'status': 'success'}
|
||
|
||
elif operation == 'clear':
|
||
# 清空缓存
|
||
self._cache_store.clear()
|
||
self._cache_timestamps.clear()
|
||
exec_result = {'output': input_data, 'status': 'success'}
|
||
|
||
else:
|
||
exec_result = {'output': input_data, 'status': 'success'}
|
||
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'缓存操作失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'vector_db':
|
||
# 向量数据库节点:向量存储、相似度搜索、RAG检索
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'search') # search, upsert, delete
|
||
collection = node_data.get('collection', 'default')
|
||
query_vector = node_data.get('query_vector', '')
|
||
top_k = node_data.get('top_k', 5)
|
||
|
||
# 简化的内存向量存储实现(实际生产环境应使用ChromaDB、Pinecone等)
|
||
if not hasattr(self, '_vector_store'):
|
||
self._vector_store = {}
|
||
|
||
try:
|
||
if operation == 'search':
|
||
# 向量相似度搜索(简化实现:使用余弦相似度)
|
||
if collection not in self._vector_store:
|
||
self._vector_store[collection] = []
|
||
|
||
query_vec = self._resolve_vector_db_query_embedding(input_data, query_vector)
|
||
|
||
# 可选:仅检索当前会话用户的向量(metadata.user_id 与请求一致)
|
||
filter_uid = None
|
||
if isinstance(input_data, dict):
|
||
filter_uid = input_data.get("user_id") or input_data.get("USER_ID")
|
||
if filter_uid is not None:
|
||
filter_uid = str(filter_uid)
|
||
|
||
results: List[Dict[str, Any]] = []
|
||
if query_vec is None:
|
||
logger.warning(
|
||
"vector_db search: 未解析到查询向量,返回空检索(collection=%s)。"
|
||
"请检查上游 embedding 与 merge 字段;并确认 API/Celery 已重启加载最新引擎代码。",
|
||
collection,
|
||
)
|
||
result = []
|
||
else:
|
||
# 计算相似度并排序
|
||
for item in self._vector_store[collection]:
|
||
if 'vector' in item:
|
||
md = item.get("metadata") or {}
|
||
if filter_uid and md.get("user_id") not in (None, "", filter_uid):
|
||
continue
|
||
vec1 = query_vec
|
||
vec2 = item['vector']
|
||
if len(vec1) != len(vec2):
|
||
continue
|
||
|
||
dot_product = sum(a * b for a, b in zip(vec1, vec2))
|
||
magnitude1 = math.sqrt(sum(a * a for a in vec1))
|
||
magnitude2 = math.sqrt(sum(a * a for a in vec2))
|
||
|
||
if magnitude1 == 0 or magnitude2 == 0:
|
||
similarity = 0
|
||
else:
|
||
similarity = dot_product / (magnitude1 * magnitude2)
|
||
|
||
results.append({
|
||
'id': item.get('id'),
|
||
'text': item.get('text', ''),
|
||
'metadata': item.get('metadata', {}),
|
||
'similarity': similarity
|
||
})
|
||
|
||
results.sort(key=lambda x: x['similarity'], reverse=True)
|
||
result = results[:top_k]
|
||
|
||
elif operation == 'upsert':
|
||
# 插入或更新向量
|
||
if collection not in self._vector_store:
|
||
self._vector_store[collection] = []
|
||
|
||
# 从输入数据中提取向量和文本
|
||
vector = input_data.get('embedding') or input_data.get('vector')
|
||
text = input_data.get('text') or input_data.get('content', '')
|
||
metadata = input_data.get('metadata', {})
|
||
doc_id = input_data.get('id') or f"doc_{len(self._vector_store[collection])}"
|
||
|
||
# 查找是否已存在
|
||
existing_index = None
|
||
for i, item in enumerate(self._vector_store[collection]):
|
||
if item.get('id') == doc_id:
|
||
existing_index = i
|
||
break
|
||
|
||
doc_item = {
|
||
'id': doc_id,
|
||
'vector': vector,
|
||
'text': text,
|
||
'metadata': metadata
|
||
}
|
||
|
||
if existing_index is not None:
|
||
self._vector_store[collection][existing_index] = doc_item
|
||
else:
|
||
self._vector_store[collection].append(doc_item)
|
||
|
||
result = {'id': doc_id, 'status': 'upserted'}
|
||
|
||
elif operation == 'delete':
|
||
# 删除向量
|
||
if collection in self._vector_store:
|
||
doc_id = node_data.get('doc_id') or input_data.get('id')
|
||
if doc_id:
|
||
self._vector_store[collection] = [
|
||
item for item in self._vector_store[collection]
|
||
if item.get('id') != doc_id
|
||
]
|
||
result = {'id': doc_id, 'status': 'deleted'}
|
||
else:
|
||
# 删除整个集合
|
||
del self._vector_store[collection]
|
||
result = {'collection': collection, 'status': 'deleted'}
|
||
else:
|
||
result = {'status': 'not_found'}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'向量数据库操作失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'log':
|
||
# 日志节点:记录日志、调试输出、性能监控
|
||
node_data = node.get('data', {})
|
||
level = node_data.get('level', 'info') # debug, info, warning, error
|
||
message = node_data.get('message', '')
|
||
include_data = node_data.get('include_data', True)
|
||
|
||
try:
|
||
# 格式化消息
|
||
if message:
|
||
# 替换变量
|
||
if isinstance(input_data, dict):
|
||
for key, value in input_data.items():
|
||
message = message.replace(f'{{{key}}}', str(value))
|
||
|
||
# 构建日志内容
|
||
log_data = {
|
||
'message': message or '节点执行',
|
||
'node_id': node_id,
|
||
'node_type': node_type,
|
||
'timestamp': time.time()
|
||
}
|
||
|
||
if include_data:
|
||
log_data['data'] = input_data
|
||
|
||
# 记录日志
|
||
log_message = f"[{node_id}] {log_data['message']}"
|
||
if include_data:
|
||
log_message += f" | 数据: {json_module.dumps(input_data, ensure_ascii=False)[:200]}"
|
||
|
||
if level == 'debug':
|
||
logger.debug(log_message)
|
||
elif level == 'info':
|
||
logger.info(log_message)
|
||
elif level == 'warning':
|
||
logger.warning(log_message)
|
||
elif level == 'error':
|
||
logger.error(log_message)
|
||
else:
|
||
logger.info(log_message)
|
||
|
||
# 如果使用执行日志记录器,也记录
|
||
if self.logger:
|
||
self.logger.info(log_data['message'], node_id=node_id, node_type=node_type, data=input_data if include_data else None)
|
||
|
||
exec_result = {'output': input_data, 'status': 'success', 'log': log_data}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': input_data,
|
||
'status': 'failed',
|
||
'error': f'日志记录失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'error_handler':
|
||
# 错误处理节点:捕获错误、错误重试、错误通知
|
||
# 注意:这个节点需要特殊处理,因为它应该包装其他节点的执行
|
||
# 这里我们实现一个简化版本,主要用于错误重试和通知
|
||
node_data = node.get('data', {})
|
||
retry_count = node_data.get('retry_count', 3)
|
||
retry_delay = node_data.get('retry_delay', 1000) # 毫秒
|
||
on_error = node_data.get('on_error', 'notify') # notify, retry, stop
|
||
error_handler_workflow = node_data.get('error_handler_workflow', '')
|
||
|
||
# 这个节点通常用于包装其他节点,但在这里我们只处理输入数据中的错误
|
||
try:
|
||
# 检查输入数据中是否有错误
|
||
if isinstance(input_data, dict) and input_data.get('status') == 'failed':
|
||
error = input_data.get('error', '未知错误')
|
||
|
||
if on_error == 'retry' and retry_count > 0:
|
||
# 重试逻辑(这里简化处理,实际应该重新执行前一个节点)
|
||
logger.warning(f"错误处理节点检测到错误,将重试: {error}")
|
||
# 注意:实际重试需要重新执行前一个节点,这里只记录
|
||
exec_result = {
|
||
'output': input_data,
|
||
'status': 'retry',
|
||
'retry_count': retry_count,
|
||
'error': error
|
||
}
|
||
elif on_error == 'notify':
|
||
# 通知错误(记录日志)
|
||
logger.error(f"错误处理节点捕获错误: {error}")
|
||
exec_result = {
|
||
'output': input_data,
|
||
'status': 'error_handled',
|
||
'error': error,
|
||
'notified': True
|
||
}
|
||
else:
|
||
# 停止执行
|
||
exec_result = {
|
||
'output': input_data,
|
||
'status': 'failed',
|
||
'error': error,
|
||
'stopped': True
|
||
}
|
||
else:
|
||
# 没有错误,正常通过
|
||
exec_result = {'output': input_data, 'status': 'success'}
|
||
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, exec_result.get('output'), duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': input_data,
|
||
'status': 'failed',
|
||
'error': f'错误处理失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'csv':
|
||
# CSV处理节点:CSV解析、生成、转换
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'parse') # parse, generate, convert
|
||
delimiter = node_data.get('delimiter', ',')
|
||
headers = node_data.get('headers', True)
|
||
encoding = node_data.get('encoding', 'utf-8')
|
||
|
||
try:
|
||
import csv
|
||
import io
|
||
|
||
if operation == 'parse':
|
||
# 解析CSV
|
||
csv_text = input_data
|
||
if isinstance(input_data, dict):
|
||
# 尝试从字典中提取CSV文本
|
||
for key in ['csv', 'data', 'content', 'text']:
|
||
if key in input_data and isinstance(input_data[key], str):
|
||
csv_text = input_data[key]
|
||
break
|
||
|
||
if not isinstance(csv_text, str):
|
||
csv_text = str(csv_text)
|
||
|
||
# 解析CSV
|
||
csv_reader = csv.DictReader(io.StringIO(csv_text), delimiter=delimiter) if headers else csv.reader(io.StringIO(csv_text), delimiter=delimiter)
|
||
|
||
if headers:
|
||
result = list(csv_reader)
|
||
else:
|
||
# 没有表头,返回数组的数组
|
||
result = list(csv_reader)
|
||
|
||
elif operation == 'generate':
|
||
# 生成CSV
|
||
data = input_data
|
||
if isinstance(input_data, dict) and 'data' in input_data:
|
||
data = input_data['data']
|
||
|
||
if not isinstance(data, list):
|
||
data = [data]
|
||
|
||
output = io.StringIO()
|
||
if data and isinstance(data[0], dict):
|
||
# 字典列表,使用DictWriter
|
||
fieldnames = data[0].keys()
|
||
writer = csv.DictWriter(output, fieldnames=fieldnames, delimiter=delimiter)
|
||
if headers:
|
||
writer.writeheader()
|
||
writer.writerows(data)
|
||
else:
|
||
# 数组列表,使用writer
|
||
writer = csv.writer(output, delimiter=delimiter)
|
||
if headers and data and isinstance(data[0], list):
|
||
# 假设第一行是表头
|
||
writer.writerow(data[0])
|
||
writer.writerows(data[1:])
|
||
else:
|
||
writer.writerows(data)
|
||
|
||
result = output.getvalue()
|
||
|
||
elif operation == 'convert':
|
||
# 转换CSV格式(改变分隔符等)
|
||
csv_text = input_data
|
||
if isinstance(input_data, dict):
|
||
for key in ['csv', 'data', 'content']:
|
||
if key in input_data and isinstance(input_data[key], str):
|
||
csv_text = input_data[key]
|
||
break
|
||
|
||
if not isinstance(csv_text, str):
|
||
csv_text = str(csv_text)
|
||
|
||
# 读取并重新写入
|
||
reader = csv.reader(io.StringIO(csv_text))
|
||
output = io.StringIO()
|
||
writer = csv.writer(output, delimiter=delimiter)
|
||
writer.writerows(reader)
|
||
result = output.getvalue()
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'CSV处理失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'object_storage':
|
||
# 对象存储节点:文件上传、下载、删除、列表
|
||
node_data = node.get('data', {})
|
||
provider = node_data.get('provider', 's3') # oss, s3, cos
|
||
operation = node_data.get('operation', 'upload')
|
||
bucket = node_data.get('bucket', '')
|
||
key = node_data.get('key', '')
|
||
file_data = node_data.get('file', '')
|
||
|
||
try:
|
||
# 简化实现:实际生产环境需要使用boto3(AWS S3)、oss2(阿里云OSS)等
|
||
# 这里提供一个接口框架,实际使用时需要安装相应的SDK
|
||
|
||
if operation == 'upload':
|
||
# 上传文件
|
||
if not file_data:
|
||
# 从input_data中获取文件数据
|
||
if isinstance(input_data, dict):
|
||
file_data = input_data.get('file') or input_data.get('data') or input_data.get('content')
|
||
else:
|
||
file_data = input_data
|
||
|
||
# 替换key中的变量
|
||
if isinstance(input_data, dict):
|
||
for k, v in input_data.items():
|
||
key = key.replace(f'{{{k}}}', str(v))
|
||
|
||
# 这里只是模拟上传,实际需要调用相应的SDK
|
||
logger.info(f"对象存储上传: provider={provider}, bucket={bucket}, key={key}")
|
||
result = {
|
||
'provider': provider,
|
||
'bucket': bucket,
|
||
'key': key,
|
||
'status': 'uploaded',
|
||
'url': f"{provider}://{bucket}/{key}" # 模拟URL
|
||
}
|
||
|
||
elif operation == 'download':
|
||
# 下载文件
|
||
if isinstance(input_data, dict):
|
||
for k, v in input_data.items():
|
||
key = key.replace(f'{{{k}}}', str(v))
|
||
|
||
logger.info(f"对象存储下载: provider={provider}, bucket={bucket}, key={key}")
|
||
# 这里只是模拟下载,实际需要调用相应的SDK
|
||
result = {
|
||
'provider': provider,
|
||
'bucket': bucket,
|
||
'key': key,
|
||
'status': 'downloaded',
|
||
'data': '模拟文件内容' # 实际应该是文件内容
|
||
}
|
||
|
||
elif operation == 'delete':
|
||
# 删除文件
|
||
if isinstance(input_data, dict):
|
||
for k, v in input_data.items():
|
||
key = key.replace(f'{{{k}}}', str(v))
|
||
|
||
logger.info(f"对象存储删除: provider={provider}, bucket={bucket}, key={key}")
|
||
result = {
|
||
'provider': provider,
|
||
'bucket': bucket,
|
||
'key': key,
|
||
'status': 'deleted'
|
||
}
|
||
|
||
elif operation == 'list':
|
||
# 列出文件
|
||
prefix = node_data.get('prefix', '')
|
||
logger.info(f"对象存储列表: provider={provider}, bucket={bucket}, prefix={prefix}")
|
||
result = {
|
||
'provider': provider,
|
||
'bucket': bucket,
|
||
'prefix': prefix,
|
||
'files': [] # 实际应该是文件列表
|
||
}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'对象存储操作失败: {str(e)}。注意:实际使用需要安装相应的SDK(如boto3、oss2等)'
|
||
}
|
||
|
||
elif node_type == 'slack':
|
||
# Slack节点:发送消息、创建频道、获取消息
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'send_message')
|
||
token = node_data.get('token', '')
|
||
channel = node_data.get('channel', '')
|
||
message = node_data.get('message', '')
|
||
attachments = node_data.get('attachments', [])
|
||
|
||
try:
|
||
import httpx
|
||
|
||
# 替换消息中的变量
|
||
if isinstance(input_data, dict):
|
||
for key, value in input_data.items():
|
||
message = message.replace(f'{{{key}}}', str(value))
|
||
channel = channel.replace(f'{{{key}}}', str(value))
|
||
|
||
if operation == 'send_message':
|
||
# 发送消息
|
||
url = 'https://slack.com/api/chat.postMessage'
|
||
headers = {
|
||
'Authorization': f'Bearer {token}',
|
||
'Content-Type': 'application/json'
|
||
}
|
||
payload = {
|
||
'channel': channel,
|
||
'text': message
|
||
}
|
||
if attachments:
|
||
payload['attachments'] = attachments
|
||
|
||
# 注意:实际使用时需要安装httpx库,这里提供接口框架
|
||
# async with httpx.AsyncClient() as client:
|
||
# response = await client.post(url, headers=headers, json=payload)
|
||
# result = response.json()
|
||
|
||
# 模拟响应
|
||
logger.info(f"Slack发送消息: channel={channel}, message={message[:50]}")
|
||
result = {
|
||
'ok': True,
|
||
'channel': channel,
|
||
'ts': str(time.time()),
|
||
'message': {'text': message}
|
||
}
|
||
|
||
elif operation == 'create_channel':
|
||
# 创建频道
|
||
url = 'https://slack.com/api/conversations.create'
|
||
headers = {
|
||
'Authorization': f'Bearer {token}',
|
||
'Content-Type': 'application/json'
|
||
}
|
||
payload = {'name': channel}
|
||
|
||
logger.info(f"Slack创建频道: channel={channel}")
|
||
result = {'ok': True, 'channel': {'name': channel, 'id': f'C{int(time.time())}'}}
|
||
|
||
elif operation == 'get_messages':
|
||
# 获取消息
|
||
url = f'https://slack.com/api/conversations.history'
|
||
headers = {
|
||
'Authorization': f'Bearer {token}',
|
||
'Content-Type': 'application/json'
|
||
}
|
||
params = {'channel': channel}
|
||
|
||
logger.info(f"Slack获取消息: channel={channel}")
|
||
result = {'ok': True, 'messages': []}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'Slack操作失败: {str(e)}。注意:需要配置有效的Slack Token'
|
||
}
|
||
|
||
elif node_type == 'dingtalk' or node_type == 'dingding':
|
||
# 钉钉节点:发送消息、创建群组、获取消息
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'send_message')
|
||
webhook_url = node_data.get('webhook_url', '')
|
||
access_token = node_data.get('access_token', '')
|
||
chat_id = node_data.get('chat_id', '')
|
||
message = node_data.get('message', '')
|
||
|
||
try:
|
||
import httpx
|
||
|
||
# 替换消息中的变量
|
||
if isinstance(input_data, dict):
|
||
for key, value in input_data.items():
|
||
message = message.replace(f'{{{key}}}', str(value))
|
||
chat_id = chat_id.replace(f'{{{key}}}', str(value))
|
||
|
||
if operation == 'send_message':
|
||
# 发送消息(通过Webhook或API)
|
||
if webhook_url:
|
||
# 使用Webhook
|
||
payload = {
|
||
'msgtype': 'text',
|
||
'text': {'content': message}
|
||
}
|
||
# async with httpx.AsyncClient() as client:
|
||
# response = await client.post(webhook_url, json=payload)
|
||
# result = response.json()
|
||
logger.info(f"钉钉发送消息(Webhook): message={message[:50]}")
|
||
result = {'errcode': 0, 'errmsg': 'ok'}
|
||
else:
|
||
# 使用API
|
||
url = f'https://oapi.dingtalk.com/chat/send'
|
||
headers = {
|
||
'Content-Type': 'application/json'
|
||
}
|
||
payload = {
|
||
'access_token': access_token,
|
||
'chatid': chat_id,
|
||
'msg': {
|
||
'msgtype': 'text',
|
||
'text': {'content': message}
|
||
}
|
||
}
|
||
logger.info(f"钉钉发送消息(API): chat_id={chat_id}, message={message[:50]}")
|
||
result = {'errcode': 0, 'errmsg': 'ok'}
|
||
|
||
elif operation == 'create_group':
|
||
# 创建群组
|
||
url = 'https://oapi.dingtalk.com/chat/create'
|
||
headers = {'Content-Type': 'application/json'}
|
||
payload = {
|
||
'access_token': access_token,
|
||
'name': chat_id,
|
||
'owner': node_data.get('owner', '')
|
||
}
|
||
logger.info(f"钉钉创建群组: name={chat_id}")
|
||
result = {'errcode': 0, 'errmsg': 'ok', 'chatid': f'chat_{int(time.time())}'}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'钉钉操作失败: {str(e)}。注意:需要配置有效的Webhook URL或Access Token'
|
||
}
|
||
|
||
elif node_type == 'wechat_work' or node_type == 'wecom':
|
||
# 企业微信节点:发送消息、创建群组、获取消息
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'send_message')
|
||
corp_id = node_data.get('corp_id', '')
|
||
corp_secret = node_data.get('corp_secret', '')
|
||
agent_id = node_data.get('agent_id', '')
|
||
chat_id = node_data.get('chat_id', '')
|
||
message = node_data.get('message', '')
|
||
|
||
try:
|
||
import httpx
|
||
|
||
# 替换消息中的变量
|
||
if isinstance(input_data, dict):
|
||
for key, value in input_data.items():
|
||
message = message.replace(f'{{{key}}}', str(value))
|
||
chat_id = chat_id.replace(f'{{{key}}}', str(value))
|
||
|
||
if operation == 'send_message':
|
||
# 先获取access_token
|
||
token_url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken'
|
||
token_params = {
|
||
'corpid': corp_id,
|
||
'corpsecret': corp_secret
|
||
}
|
||
# async with httpx.AsyncClient() as client:
|
||
# token_response = await client.get(token_url, params=token_params)
|
||
# token_data = token_response.json()
|
||
# access_token = token_data.get('access_token')
|
||
|
||
# 模拟获取token
|
||
access_token = 'mock_token'
|
||
|
||
# 发送消息
|
||
url = f'https://qyapi.weixin.qq.com/cgi-bin/message/send'
|
||
params = {'access_token': access_token}
|
||
payload = {
|
||
'touser': chat_id or '@all',
|
||
'msgtype': 'text',
|
||
'agentid': agent_id,
|
||
'text': {'content': message}
|
||
}
|
||
|
||
logger.info(f"企业微信发送消息: chat_id={chat_id}, message={message[:50]}")
|
||
result = {'errcode': 0, 'errmsg': 'ok'}
|
||
|
||
elif operation == 'create_group':
|
||
# 创建群组
|
||
url = 'https://qyapi.weixin.qq.com/cgi-bin/appchat/create'
|
||
params = {'access_token': access_token}
|
||
payload = {
|
||
'name': chat_id,
|
||
'owner': node_data.get('owner', ''),
|
||
'userlist': node_data.get('userlist', [])
|
||
}
|
||
logger.info(f"企业微信创建群组: name={chat_id}")
|
||
result = {'errcode': 0, 'errmsg': 'ok', 'chatid': f'chat_{int(time.time())}'}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'企业微信操作失败: {str(e)}。注意:需要配置有效的Corp ID和Secret'
|
||
}
|
||
|
||
elif node_type == 'sms':
|
||
# 短信节点(SMS):发送短信、批量发送、短信模板
|
||
node_data = node.get('data', {})
|
||
provider = node_data.get('provider', 'aliyun') # aliyun, tencent, twilio
|
||
operation = node_data.get('operation', 'send')
|
||
phone = node_data.get('phone', '')
|
||
template = node_data.get('template', '')
|
||
sign = node_data.get('sign', '')
|
||
access_key = node_data.get('access_key', '')
|
||
access_secret = node_data.get('access_secret', '')
|
||
|
||
try:
|
||
# 替换模板中的变量
|
||
if isinstance(input_data, dict):
|
||
for key, value in input_data.items():
|
||
template = template.replace(f'{{{key}}}', str(value))
|
||
phone = phone.replace(f'{{{key}}}', str(value))
|
||
|
||
if operation == 'send':
|
||
# 发送短信
|
||
if provider == 'aliyun':
|
||
# 阿里云短信(需要安装alibabacloud-dysmsapi20170525)
|
||
logger.info(f"阿里云短信发送: phone={phone}, template={template[:50]}")
|
||
result = {
|
||
'provider': 'aliyun',
|
||
'phone': phone,
|
||
'status': 'sent',
|
||
'message_id': f'sms_{int(time.time())}'
|
||
}
|
||
elif provider == 'tencent':
|
||
# 腾讯云短信(需要安装tencentcloud-sdk-python)
|
||
logger.info(f"腾讯云短信发送: phone={phone}, template={template[:50]}")
|
||
result = {
|
||
'provider': 'tencent',
|
||
'phone': phone,
|
||
'status': 'sent',
|
||
'message_id': f'sms_{int(time.time())}'
|
||
}
|
||
elif provider == 'twilio':
|
||
# Twilio短信(需要安装twilio)
|
||
logger.info(f"Twilio短信发送: phone={phone}, template={template[:50]}")
|
||
result = {
|
||
'provider': 'twilio',
|
||
'phone': phone,
|
||
'status': 'sent',
|
||
'message_id': f'sms_{int(time.time())}'
|
||
}
|
||
else:
|
||
result = {'error': f'不支持的短信提供商: {provider}'}
|
||
|
||
elif operation == 'batch_send':
|
||
# 批量发送
|
||
phones = node_data.get('phones', [])
|
||
if isinstance(phones, str):
|
||
phones = [p.strip() for p in phones.split(',')]
|
||
|
||
logger.info(f"批量发送短信: phones={len(phones)}, provider={provider}")
|
||
result = {
|
||
'provider': provider,
|
||
'phones': phones,
|
||
'status': 'sent',
|
||
'count': len(phones)
|
||
}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'短信发送失败: {str(e)}。注意:需要安装相应的SDK(如alibabacloud-dysmsapi20170525、tencentcloud-sdk-python、twilio)'
|
||
}
|
||
|
||
elif node_type == 'pdf':
|
||
# PDF处理节点:PDF解析、生成、合并、拆分
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'extract_text') # extract_text, generate, merge, split
|
||
pages = node_data.get('pages', '')
|
||
template = node_data.get('template', '')
|
||
|
||
try:
|
||
# 注意:需要安装PyPDF2或pdfplumber库
|
||
# pip install PyPDF2 pdfplumber
|
||
|
||
if operation == 'extract_text':
|
||
# 提取文本
|
||
pdf_data = input_data
|
||
if isinstance(input_data, dict):
|
||
pdf_data = input_data.get('pdf') or input_data.get('data') or input_data.get('file')
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# from PyPDF2 import PdfReader
|
||
# reader = PdfReader(io.BytesIO(pdf_data))
|
||
# text = ""
|
||
# for page in reader.pages:
|
||
# text += page.extract_text()
|
||
|
||
logger.info(f"PDF提取文本: pages={pages}")
|
||
result = {
|
||
'text': 'PDF文本提取结果(需要安装PyPDF2或pdfplumber)',
|
||
'pages': pages or 'all'
|
||
}
|
||
|
||
elif operation == 'generate':
|
||
# 生成PDF
|
||
content = input_data
|
||
if isinstance(input_data, dict):
|
||
content = input_data.get('content') or input_data.get('text') or input_data.get('data')
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# from reportlab.pdfgen import canvas
|
||
# 或使用其他PDF生成库
|
||
|
||
logger.info(f"PDF生成: template={template}")
|
||
result = {
|
||
'pdf': 'PDF生成结果(需要安装reportlab或其他PDF生成库)',
|
||
'template': template
|
||
}
|
||
|
||
elif operation == 'merge':
|
||
# 合并PDF
|
||
pdfs = input_data
|
||
if isinstance(input_data, dict):
|
||
pdfs = input_data.get('pdfs') or input_data.get('files')
|
||
|
||
if not isinstance(pdfs, list):
|
||
pdfs = [pdfs]
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# from PyPDF2 import PdfMerger
|
||
# merger = PdfMerger()
|
||
# for pdf in pdfs:
|
||
# merger.append(pdf)
|
||
# result_pdf = merger.write()
|
||
|
||
logger.info(f"PDF合并: count={len(pdfs)}")
|
||
result = {
|
||
'merged_pdf': '合并后的PDF(需要安装PyPDF2)',
|
||
'count': len(pdfs)
|
||
}
|
||
|
||
elif operation == 'split':
|
||
# 拆分PDF
|
||
pdf_data = input_data
|
||
if isinstance(input_data, dict):
|
||
pdf_data = input_data.get('pdf') or input_data.get('file')
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# from PyPDF2 import PdfReader, PdfWriter
|
||
# reader = PdfReader(pdf_data)
|
||
# writer = PdfWriter()
|
||
# for page_num in range(start_page, end_page):
|
||
# writer.add_page(reader.pages[page_num])
|
||
|
||
logger.info(f"PDF拆分: pages={pages}")
|
||
result = {
|
||
'split_pdfs': ['拆分后的PDF列表(需要安装PyPDF2)'],
|
||
'pages': pages
|
||
}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'PDF处理失败: {str(e)}。注意:需要安装PyPDF2或pdfplumber库(pip install PyPDF2 pdfplumber)'
|
||
}
|
||
|
||
elif node_type == 'image':
|
||
# 图像处理节点:图像缩放、裁剪、格式转换、OCR识别
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'resize') # resize, crop, convert, ocr
|
||
width = node_data.get('width', 800)
|
||
height = node_data.get('height', 600)
|
||
format_type = node_data.get('format', 'png')
|
||
|
||
try:
|
||
# 注意:需要安装Pillow库
|
||
# pip install Pillow
|
||
# OCR需要安装pytesseract和tesseract-ocr
|
||
# pip install pytesseract
|
||
|
||
image_data = input_data
|
||
if isinstance(input_data, dict):
|
||
image_data = input_data.get('image') or input_data.get('data') or input_data.get('file')
|
||
|
||
if operation == 'resize':
|
||
# 缩放图像
|
||
# 这里只是接口框架,实际需要:
|
||
# from PIL import Image
|
||
# import io
|
||
# img = Image.open(io.BytesIO(image_data))
|
||
# img_resized = img.resize((width, height))
|
||
# output = io.BytesIO()
|
||
# img_resized.save(output, format=format_type.upper())
|
||
# result = output.getvalue()
|
||
|
||
logger.info(f"图像缩放: {width}x{height}, format={format_type}")
|
||
result = {
|
||
'image': '缩放后的图像数据(需要安装Pillow)',
|
||
'width': width,
|
||
'height': height,
|
||
'format': format_type
|
||
}
|
||
|
||
elif operation == 'crop':
|
||
# 裁剪图像
|
||
x = node_data.get('x', 0)
|
||
y = node_data.get('y', 0)
|
||
crop_width = node_data.get('crop_width', width)
|
||
crop_height = node_data.get('crop_height', height)
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# from PIL import Image
|
||
# img = Image.open(io.BytesIO(image_data))
|
||
# img_cropped = img.crop((x, y, x + crop_width, y + crop_height))
|
||
|
||
logger.info(f"图像裁剪: ({x}, {y}, {crop_width}, {crop_height})")
|
||
result = {
|
||
'image': '裁剪后的图像数据(需要安装Pillow)',
|
||
'crop_box': (x, y, crop_width, crop_height)
|
||
}
|
||
|
||
elif operation == 'convert':
|
||
# 格式转换
|
||
target_format = node_data.get('target_format', format_type)
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# from PIL import Image
|
||
# img = Image.open(io.BytesIO(image_data))
|
||
# output = io.BytesIO()
|
||
# img.save(output, format=target_format.upper())
|
||
# result = output.getvalue()
|
||
|
||
logger.info(f"图像格式转换: {format_type} -> {target_format}")
|
||
result = {
|
||
'image': f'转换后的图像数据(需要安装Pillow)',
|
||
'format': target_format
|
||
}
|
||
|
||
elif operation == 'ocr':
|
||
# OCR识别
|
||
# 这里只是接口框架,实际需要:
|
||
# from PIL import Image
|
||
# import pytesseract
|
||
# img = Image.open(io.BytesIO(image_data))
|
||
# text = pytesseract.image_to_string(img, lang='chi_sim+eng')
|
||
|
||
logger.info(f"OCR识别")
|
||
result = {
|
||
'text': 'OCR识别结果(需要安装pytesseract和tesseract-ocr)',
|
||
'confidence': 0.95
|
||
}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'图像处理失败: {str(e)}。注意:需要安装Pillow库(pip install Pillow),OCR需要pytesseract和tesseract-ocr'
|
||
}
|
||
|
||
elif node_type == 'excel':
|
||
# Excel处理节点:Excel读取、写入、格式转换、公式计算
|
||
node_data = node.get('data', {})
|
||
operation = node_data.get('operation', 'read') # read, write, convert, formula
|
||
sheet = node_data.get('sheet', 'Sheet1')
|
||
range_str = node_data.get('range', '')
|
||
format_type = node_data.get('format', 'xlsx') # xlsx, xls, csv
|
||
|
||
try:
|
||
# 注意:需要安装openpyxl或pandas库
|
||
# pip install openpyxl pandas
|
||
|
||
if operation == 'read':
|
||
# 读取Excel
|
||
excel_data = input_data
|
||
if isinstance(input_data, dict):
|
||
excel_data = input_data.get('excel') or input_data.get('file') or input_data.get('data')
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# import pandas as pd
|
||
# df = pd.read_excel(io.BytesIO(excel_data), sheet_name=sheet)
|
||
# if range_str:
|
||
# # 解析范围,如 "A1:C10"
|
||
# df = df.loc[range_start:range_end]
|
||
# result = df.to_dict('records')
|
||
|
||
logger.info(f"Excel读取: sheet={sheet}, range={range_str}")
|
||
result = {
|
||
'data': [{'列1': '值1', '列2': '值2'}],
|
||
'sheet': sheet,
|
||
'range': range_str
|
||
}
|
||
|
||
elif operation == 'write':
|
||
# 写入Excel
|
||
data = input_data
|
||
if isinstance(input_data, dict):
|
||
data = input_data.get('data') or input_data.get('rows')
|
||
|
||
if not isinstance(data, list):
|
||
data = [data]
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# import pandas as pd
|
||
# df = pd.DataFrame(data)
|
||
# output = io.BytesIO()
|
||
# df.to_excel(output, sheet_name=sheet, index=False)
|
||
# result = output.getvalue()
|
||
|
||
logger.info(f"Excel写入: sheet={sheet}, rows={len(data)}")
|
||
result = {
|
||
'excel': '生成的Excel数据(需要安装openpyxl或pandas)',
|
||
'sheet': sheet,
|
||
'rows': len(data)
|
||
}
|
||
|
||
elif operation == 'convert':
|
||
# 格式转换
|
||
target_format = node_data.get('target_format', 'csv')
|
||
excel_data = input_data
|
||
if isinstance(input_data, dict):
|
||
excel_data = input_data.get('excel') or input_data.get('file')
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# import pandas as pd
|
||
# df = pd.read_excel(io.BytesIO(excel_data))
|
||
# if target_format == 'csv':
|
||
# result = df.to_csv(index=False)
|
||
# elif target_format == 'json':
|
||
# result = df.to_json(orient='records')
|
||
|
||
logger.info(f"Excel格式转换: {format_type} -> {target_format}")
|
||
result = {
|
||
'data': '转换后的数据(需要安装pandas)',
|
||
'format': target_format
|
||
}
|
||
|
||
elif operation == 'formula':
|
||
# 公式计算
|
||
formula = node_data.get('formula', '')
|
||
data = input_data
|
||
if isinstance(input_data, dict):
|
||
data = input_data.get('data')
|
||
|
||
# 这里只是接口框架,实际需要:
|
||
# import pandas as pd
|
||
# df = pd.DataFrame(data)
|
||
# # 使用eval或更安全的方式计算公式
|
||
# result = df.eval(formula)
|
||
|
||
logger.info(f"Excel公式计算: formula={formula}")
|
||
result = {
|
||
'result': '公式计算结果(需要安装pandas)',
|
||
'formula': formula
|
||
}
|
||
|
||
else:
|
||
result = input_data
|
||
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'Excel处理失败: {str(e)}。注意:需要安装openpyxl或pandas库(pip install openpyxl pandas)'
|
||
}
|
||
|
||
elif node_type == 'subworkflow' or node_type == 'invoke_agent':
    # Sub-workflow / delegation node: runs another workflow or an agent's
    # workflow as a child execution, with a depth guard against recursion.
    # A child Execution row is persisted (when a DB session exists) so the
    # run appears in execution history with parent linkage.
    node_data = node.get('data', {})
    input_mapping = node_data.get('input_mapping', {})
    try:
        # Depth guard: the current depth is carried in the input payload as
        # '__subworkflow_depth'; exceeding max depth aborts the node.
        max_depth = int(node_data.get('max_subworkflow_depth', 2) or 2)
        if max_depth < 1:
            max_depth = 1
        cur_depth = 0
        if isinstance(input_data, dict):
            try:
                cur_depth = int(input_data.get('__subworkflow_depth', 0) or 0)
            except (TypeError, ValueError):
                cur_depth = 0
        if cur_depth >= max_depth:
            raise ValueError(
                f"子工作流调用深度超限: current={cur_depth}, max={max_depth}"
            )

        # Map the current input into the child workflow's input, and stamp
        # the incremented depth onto it for the next level's guard.
        sub_input = self._build_subworkflow_input(input_data, input_mapping)
        if isinstance(sub_input, dict):
            sub_input['__subworkflow_depth'] = cur_depth + 1

        # Back-compat: accept 'target_agent_id' as an alias for 'agent_id'.
        if node_type == 'invoke_agent' and not node_data.get('agent_id'):
            _aid = node_data.get('target_agent_id')
            if _aid:
                node_data = {**node_data, 'agent_id': _aid}

        target_type, target_id, sub_workflow_data = self._resolve_subworkflow_target(node_data)
        sub_workflow_id = (
            f"agent_{target_id}" if target_type == "agent" else target_id
        )
        child_execution = None
        child_logger = self.logger  # fall back to parent's logger without a DB
        sub_started_at = time.time()
        parent_execution_id = None
        if self.logger and getattr(self.logger, "execution_id", None):
            parent_execution_id = str(self.logger.execution_id)

        if self.db is not None:
            # Persist the child execution up front (status 'running') so it
            # gets an id; the child gets its own ExecutionLogger bound to it.
            child_execution = Execution(
                workflow_id=target_id if target_type == "workflow" else None,
                agent_id=target_id if target_type == "agent" else None,
                input_data=sub_input,
                status="running",
                parent_execution_id=parent_execution_id,
                depth=cur_depth + 1,
            )
            self.db.add(child_execution)
            self.db.commit()
            self.db.refresh(child_execution)
            child_logger = ExecutionLogger(str(child_execution.id), self.db)
            child_logger.info(
                f"子工作流开始执行: target_type={target_type}, target_id={target_id}",
                node_id=node_id,
                node_type=node_type,
                data={"parent_execution_id": parent_execution_id, "depth": cur_depth + 1},
            )

        from app.services.execution_budget import merge_budget_for_execution

        # Budget: derive the child's limits from its Execution row when
        # persisted, otherwise inherit this engine's limits.
        if self.db is not None and child_execution is not None:
            child_budget = merge_budget_for_execution(self.db, child_execution)
        else:
            child_budget = self.budget_limits
        child_engine = WorkflowEngine(
            sub_workflow_id,
            sub_workflow_data,
            logger=child_logger,
            db=self.db,
            budget_limits=child_budget,
            trusted_model_config_user_id=self.trusted_model_config_user_id,
        )
        try:
            child_result = await child_engine.execute(sub_input)
            # Mark the child row completed with its output and wall time.
            if child_execution is not None:
                child_execution.status = "completed"
                child_execution.output_data = child_result
                child_execution.execution_time = int(
                    (time.time() - sub_started_at) * 1000
                )
                self.db.commit()
        except Exception as sub_e:
            # Mark the child row failed, then re-raise into the outer handler.
            if child_execution is not None:
                child_execution.status = "failed"
                child_execution.error_message = str(sub_e)
                child_execution.execution_time = int(
                    (time.time() - sub_started_at) * 1000
                )
                self.db.commit()
            raise

        result = {
            'target_type': target_type,
            'target_id': target_id,
            'child_execution_id': str(child_execution.id) if child_execution is not None else None,
            'input': sub_input,
            'status': child_result.get('status', 'completed'),
            'result': child_result.get('result'),
            'node_results': child_result.get('node_results'),
        }
        exec_result = {'output': result, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result, duration)
        return exec_result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'子工作流执行失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'code':
    # Code-execution node: runs a user-supplied Python snippet in a
    # restricted namespace; JavaScript is only a placeholder.
    node_data = node.get('data', {})
    language = node_data.get('language', 'python')
    code = node_data.get('code', '')
    # NOTE(review): 'timeout' is read from config but never enforced below —
    # exec() runs unbounded. Confirm whether a limit should be applied.
    timeout = node_data.get('timeout', 30)
    try:
        if language.lower() == 'python':
            # Restricted exec environment: a curated __builtins__ (running
            # without __builtins__ would break isinstance etc.).
            # loads/dumps are injected, and exec uses the SAME namespace for
            # globals and locals so nested functions' LOAD_GLOBAL can resolve
            # names like 'loads' that would otherwise live only in locals.
            _code_globs = {
                '__builtins__': _CODE_NODE_SAFE_BUILTINS,
                'hashlib': hashlib,
                're': re,
                'json': json,
            }
            shared_ns: Dict[str, Any] = dict(_code_globs)
            shared_ns.update(
                {
                    'input_data': input_data,
                    'result': None,
                    'loads': json.loads,
                    'dumps': json.dumps,
                }
            )
            exec(code, shared_ns, shared_ns)
            # The snippet communicates via 'result' (or 'output'); fall back
            # to passing the input through when neither was set.
            result = shared_ns.get(
                'result', shared_ns.get('output', input_data)
            )
        elif language.lower() == 'javascript':
            # JS execution would need an external runtime; placeholder only.
            result = {
                'status': 'not_implemented',
                'message': 'JavaScript执行需集成运行时'
            }
        else:
            result = {'status': 'failed', 'error': f'不支持的语言: {language}'}

        exec_result = {'output': result, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, result, duration)
        return exec_result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'代码执行失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'oauth':
    # OAuth node: obtain/refresh a token. Placeholder implementation — no
    # real OAuth flow is performed; a mock bearer token is fabricated.
    # client_id / client_secret are accepted but unused by the mock.
    node_data = node.get('data', {})
    provider = node_data.get('provider', 'google')
    client_id = node_data.get('client_id', '')
    client_secret = node_data.get('client_secret', '')
    scopes = node_data.get('scopes', [])
    try:
        # Simplified placeholder: return a mock token payload.
        token_data = {
            'access_token': f'mock_access_token_{provider}',
            'expires_in': 3600,
            'token_type': 'Bearer',
            'scope': scopes
        }
        exec_result = {'output': token_data, 'status': 'success'}
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_complete(node_id, node_type, token_data, duration)
        return exec_result
    except Exception as e:
        if self.logger:
            duration = int((time.time() - start_time) * 1000)
            self.logger.log_node_error(node_id, node_type, e, duration)
        return {
            'output': None,
            'status': 'failed',
            'error': f'OAuth处理失败: {str(e)}'
        }
|
||
|
||
elif node_type == 'validator':
|
||
# 数据验证节点:基于简化的schema检查
|
||
node_data = node.get('data', {})
|
||
schema = node_data.get('schema', {})
|
||
on_error = node_data.get('on_error', 'reject') # reject, continue, transform
|
||
try:
|
||
# 简单类型检查
|
||
if 'type' in schema:
|
||
expected_type = schema['type']
|
||
actual_type = type(input_data).__name__
|
||
if expected_type == 'object' and not isinstance(input_data, dict):
|
||
raise ValueError(f'期望类型object,实际类型{actual_type}')
|
||
if expected_type == 'array' and not isinstance(input_data, list):
|
||
raise ValueError(f'期望类型array,实际类型{actual_type}')
|
||
result = input_data
|
||
exec_result = {'output': result, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
except Exception as e:
|
||
if on_error == 'continue':
|
||
return {'output': input_data, 'status': 'success', 'warning': str(e)}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'数据验证失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'batch':
|
||
# 批处理节点:数据分批处理
|
||
node_data = node.get('data', {})
|
||
batch_size = node_data.get('batch_size', 100)
|
||
mode = node_data.get('mode', 'split') # split, group, aggregate
|
||
wait_for_completion = node_data.get('wait_for_completion', True)
|
||
try:
|
||
data_list = input_data if isinstance(input_data, list) else [input_data]
|
||
if mode == 'split':
|
||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||
result = batches
|
||
elif mode == 'group':
|
||
batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
|
||
result = batches
|
||
elif mode == 'aggregate':
|
||
result = {
|
||
'count': len(data_list),
|
||
'samples': data_list[:min(3, len(data_list))]
|
||
}
|
||
else:
|
||
result = data_list
|
||
exec_result = {'output': result, 'status': 'success', 'wait': wait_for_completion}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, result, duration)
|
||
return exec_result
|
||
except Exception as e:
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||
return {
|
||
'output': None,
|
||
'status': 'failed',
|
||
'error': f'批处理失败: {str(e)}'
|
||
}
|
||
|
||
elif node_type == 'output' or node_type == 'end':
|
||
# 输出节点:返回最终结果
|
||
import json as json_module
|
||
# 读取节点配置中的输出格式设置
|
||
node_data = node.get('data', {})
|
||
output_format = node_data.get('output_format', 'text') # 默认纯文本
|
||
|
||
logger.debug(f"[rjb] End节点处理: node_id={node_id}, output_format={output_format}, input_data={input_data}, input_data type={type(input_data)}")
|
||
final_output = input_data
|
||
# 上游常用 sourceHandle=right:整包在 input_data['right']。若不展开,仅余 user_id 等顶层字段会被当成最终文本
|
||
if isinstance(input_data, dict):
|
||
_ex = input_data.copy()
|
||
for _branch_key in ('true', 'false'):
|
||
_bp = _ex.get(_branch_key)
|
||
if isinstance(_bp, dict):
|
||
_ex.update(_bp)
|
||
for _k in ('true', 'false', '_condition_result', '_condition_error'):
|
||
_ex.pop(_k, None)
|
||
if isinstance(_ex.get('right'), dict):
|
||
_ex.update(_ex['right'])
|
||
elif isinstance(_ex.get('right'), str):
|
||
_rs = _ex['right'].strip()
|
||
if _rs.startswith('{'):
|
||
try:
|
||
_rj = json_module.loads(_rs)
|
||
if isinstance(_rj, dict):
|
||
_ex.update(_rj)
|
||
except Exception:
|
||
# 非 JSON 但以 { 开头(如 Markdown),仍当作正文
|
||
_ex['output'] = _ex['right']
|
||
else:
|
||
# LLM 节点 output 常为纯文本,经带 handle 的边落在 right;须写入 output,
|
||
# 否则下游按 dict 拼接会把 user_id(preview_xxx)拼在回复末尾。
|
||
_ex['output'] = _ex['right']
|
||
input_data = _ex
|
||
final_output = input_data
|
||
|
||
# 如果配置为JSON格式,直接返回原始数据(或格式化的JSON)
|
||
if output_format == 'json':
|
||
# 如果是字典,直接返回JSON格式
|
||
if isinstance(input_data, dict):
|
||
final_output = json_module.dumps(input_data, ensure_ascii=False, indent=2)
|
||
elif isinstance(input_data, str):
|
||
# 尝试解析为JSON,如果成功则格式化,否则直接返回
|
||
try:
|
||
parsed = json_module.loads(input_data)
|
||
final_output = json_module.dumps(parsed, ensure_ascii=False, indent=2)
|
||
except:
|
||
final_output = input_data
|
||
else:
|
||
final_output = json_module.dumps({'output': input_data}, ensure_ascii=False, indent=2)
|
||
else:
|
||
# 默认纯文本格式:递归解包,提取实际的文本内容
|
||
if isinstance(input_data, dict):
|
||
# 优先提取 'output' 字段(LLM节点的标准输出格式)
|
||
if 'output' in input_data and isinstance(input_data['output'], str):
|
||
final_output = input_data['output']
|
||
# 如果只有一个 key 且是 'input',提取其值
|
||
elif len(input_data) == 1 and 'input' in input_data:
|
||
final_output = input_data['input']
|
||
# 如果包含 'solution' 字段,提取其值
|
||
elif 'solution' in input_data and isinstance(input_data['solution'], str):
|
||
final_output = input_data['solution']
|
||
# 如果input_data是字符串类型的字典(JSON字符串),尝试解析
|
||
elif isinstance(input_data, str):
|
||
try:
|
||
parsed = json_module.loads(input_data)
|
||
if isinstance(parsed, dict) and 'output' in parsed:
|
||
final_output = parsed['output']
|
||
elif isinstance(parsed, str):
|
||
final_output = parsed
|
||
except:
|
||
final_output = input_data
|
||
logger.debug(f"[rjb] End节点提取第一层: final_output={final_output}, type={type(final_output)}")
|
||
# 如果提取的值仍然是字典且只有一个 'input' key,继续提取
|
||
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
|
||
final_output = final_output['input']
|
||
logger.debug(f"[rjb] End节点提取第二层: final_output={final_output}, type={type(final_output)}")
|
||
|
||
# 确保最终输出是字符串(对于人机交互场景)
|
||
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
|
||
if not isinstance(final_output, str):
|
||
if isinstance(final_output, dict):
|
||
# 如果是字典,尝试提取文本内容或转换为JSON字符串
|
||
# 优先查找常见的文本字段
|
||
if 'text' in final_output:
|
||
final_output = str(final_output['text'])
|
||
elif 'content' in final_output:
|
||
final_output = str(final_output['content'])
|
||
elif 'message' in final_output:
|
||
final_output = str(final_output['message'])
|
||
elif 'response' in final_output:
|
||
final_output = str(final_output['response'])
|
||
elif len(final_output) == 1:
|
||
# 如果只有一个key,直接使用其值
|
||
final_output = str(list(final_output.values())[0])
|
||
else:
|
||
# 否则转换为纯文本(不是JSON)
|
||
# 尝试提取所有文本字段并组合,但排除系统字段和用户查询字段
|
||
text_parts = []
|
||
exclude_keys = {
|
||
'status',
|
||
'error',
|
||
'timestamp',
|
||
'node_id',
|
||
'execution_time',
|
||
'query',
|
||
'USER_INPUT',
|
||
'user_input',
|
||
'user_query',
|
||
'user_id',
|
||
'USER_ID',
|
||
'userId',
|
||
'attachments',
|
||
'memory',
|
||
'conversation_history',
|
||
'user_profile',
|
||
'context',
|
||
'right',
|
||
'left',
|
||
'data',
|
||
}
|
||
# 优先使用input字段(LLM的实际输出)
|
||
if 'input' in final_output and isinstance(final_output['input'], str):
|
||
final_output = final_output['input']
|
||
else:
|
||
for key, value in final_output.items():
|
||
if key in exclude_keys:
|
||
continue
|
||
if isinstance(value, str) and value.strip():
|
||
# 如果值本身已经包含 "key: " 格式,直接使用
|
||
if value.strip().startswith(f"{key}:"):
|
||
text_parts.append(value.strip())
|
||
else:
|
||
text_parts.append(value.strip())
|
||
elif isinstance(value, (int, float, bool)):
|
||
text_parts.append(f"{key}: {value}")
|
||
if text_parts:
|
||
final_output = "\n".join(text_parts)
|
||
else:
|
||
final_output = str(final_output)
|
||
else:
|
||
final_output = str(final_output)
|
||
|
||
# 清理输出文本:移除常见的字段前缀(如 "input: ", "query: " 等)
|
||
if isinstance(final_output, str):
|
||
# 移除行首的 "input: ", "query: ", "output: " 等前缀
|
||
lines = final_output.split('\n')
|
||
cleaned_lines = []
|
||
for line in lines:
|
||
# 匹配行首的 "字段名: " 格式并移除
|
||
# 但保留内容本身
|
||
line = re.sub(r'^(input|query|output|result|response|message|content|text):\s*', '', line, flags=re.IGNORECASE)
|
||
if line.strip(): # 只保留非空行
|
||
cleaned_lines.append(line)
|
||
|
||
# 如果清理后还有内容,使用清理后的版本
|
||
if cleaned_lines:
|
||
final_output = '\n'.join(cleaned_lines)
|
||
# 如果清理后为空,使用原始输出(避免丢失所有内容)
|
||
elif final_output.strip():
|
||
# 如果原始输出不为空,但清理后为空,说明可能格式特殊,尝试更宽松的清理
|
||
# 只移除明显的 "input: " 和 "query: " 前缀,保留其他内容
|
||
final_output = re.sub(r'^(input|query):\s*', '', final_output, flags=re.IGNORECASE | re.MULTILINE)
|
||
if not final_output.strip():
|
||
final_output = str(input_data) # 如果还是空,使用原始输入
|
||
|
||
if output_format != "json":
|
||
final_output = self._replace_if_template_placeholder(final_output)
|
||
|
||
final_output = self._resolve_end_output_if_vector_metadata(final_output, input_data)
|
||
logger.debug(f"[rjb] End节点最终输出: output_format={output_format}, final_output={final_output[:100] if isinstance(final_output, str) else type(final_output)}")
|
||
result = {'output': final_output, 'status': 'success'}
|
||
if self.logger:
|
||
duration = int((time.time() - start_time) * 1000)
|
||
self.logger.log_node_complete(node_id, node_type, final_output, duration)
|
||
return result
|
||
|
||
else:
    # Unknown node type: pass the input through unchanged and report
    # success so the workflow can keep running; the message flags the
    # unimplemented type for the caller.
    logger.warning(f"[rjb] 未知节点类型: node_id={node_id}, node_type={node_type}, node keys={list(node.keys())}")
    return {
        'output': input_data,
        'status': 'success',
        'message': f'节点类型 {node_type} 暂未实现'
    }
|
||
|
||
except Exception as e:
    # Catch-all for the whole node dispatch above: log with traceback,
    # record per-node error timing, and return a failed result dict instead
    # of letting the exception propagate to the engine loop.
    logger.error(f"节点执行失败: {node_id} ({node_type}) - {str(e)}", exc_info=True)
    if self.logger:
        duration = int((time.time() - start_time) * 1000)
        self.logger.log_node_error(node_id, node_type, e, duration)
    return {
        'output': None,
        'status': 'failed',
        'error': str(e),
        'node_id': node_id,
        'node_type': node_type
    }
|
||
|
||
|
||
async def execute(
|
||
self,
|
||
input_data: Dict[str, Any],
|
||
resume_snapshot: Optional[Dict[str, Any]] = None,
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
执行完整工作流
|
||
|
||
Args:
|
||
input_data: 初始输入数据(恢复执行时须包含 __hil_decision 等)
|
||
resume_snapshot: 从挂起快照恢复(与 pause_state 一致)
|
||
|
||
Returns:
|
||
执行结果
|
||
"""
|
||
if not resume_snapshot and isinstance(input_data, dict):
|
||
dsl_box = input_data.get("scenario_dsl")
|
||
if dsl_box is None and isinstance(input_data.get("scenario"), dict):
|
||
dsl_box = input_data["scenario"]
|
||
if dsl_box is not None:
|
||
ok, errs = validate_scenario_dsl(dsl_box)
|
||
if not ok:
|
||
raise WorkflowExecutionError(
|
||
detail="scenario_dsl 校验失败: " + "; ".join(errs),
|
||
)
|
||
norm = normalize_scenario_dsl(dsl_box)
|
||
input_data = {**input_data, "_scenario": norm}
|
||
|
||
self._initial_input_data = input_data
|
||
# 记录工作流开始执行
|
||
if self.logger:
|
||
self.logger.info(
|
||
"工作流开始执行",
|
||
data={"input": input_data, "resume": bool(resume_snapshot)},
|
||
)
|
||
|
||
if resume_snapshot:
|
||
self.node_outputs = self._json_safe_copy(resume_snapshot.get("node_outputs", {}))
|
||
active_edges = list(resume_snapshot.get("active_edges", []))
|
||
executed_nodes = set(resume_snapshot.get("executed_nodes", []))
|
||
execution_sequence = list(resume_snapshot.get("execution_sequence", []))
|
||
self._steps_used = int(resume_snapshot.get("steps_used", 0))
|
||
self._llm_invocations = int(resume_snapshot.get("llm_invocations", 0))
|
||
self._tool_calls_used = int(resume_snapshot.get("tool_calls_used", 0))
|
||
results = self._json_safe_copy(resume_snapshot.get("node_results_partial", {}))
|
||
else:
|
||
# 初始化节点输出
|
||
self.node_outputs = {}
|
||
active_edges = self.edges.copy() # 活跃的边列表
|
||
executed_nodes = set() # 已执行的节点
|
||
execution_sequence = [] # 实际执行顺序(用于最终输出节点选择)
|
||
self._steps_used = 0
|
||
self._llm_invocations = 0
|
||
self._tool_calls_used = 0
|
||
results = {}
|
||
|
||
# 按拓扑顺序执行节点(动态构建执行图)
|
||
while True:
|
||
# 构建当前活跃的执行图
|
||
execution_order = self.build_execution_graph(active_edges)
|
||
forward_reachable = self._forward_reachable_nodes(active_edges)
|
||
order_pos = {nid: i for i, nid in enumerate(execution_order)}
|
||
# 拓扑序可能不含「双前驱但仅一条分支可达」的汇合点,故候选为全部未执行节点并按 execution_order 优先
|
||
pending_ids = sorted(
|
||
(nid for nid in self.nodes if nid not in executed_nodes),
|
||
key=lambda x: (order_pos.get(x, 1_000_000), x),
|
||
)
|
||
logger.debug(
|
||
f"[rjb] 当前执行图: {execution_order}, 活跃边数: {len(active_edges)}, 已执行节点: {executed_nodes}"
|
||
)
|
||
|
||
next_node_id = None
|
||
for node_id in pending_ids:
|
||
can_execute = True
|
||
incoming_edges = [e for e in active_edges if e["target"] == node_id]
|
||
if not incoming_edges:
|
||
# 没有入边:仅允许 Start;孤立节点跳过
|
||
if node_id not in [n["id"] for n in self.nodes.values() if n.get("type") == "start"]:
|
||
logger.debug(f"[rjb] 节点 {node_id} 没有入边,跳过执行")
|
||
continue
|
||
else:
|
||
for edge in incoming_edges:
|
||
src = edge["source"]
|
||
if src not in forward_reachable:
|
||
# 条件分支裁剪后不可达的前驱,不参与 gate(OR-join)
|
||
continue
|
||
if src not in executed_nodes:
|
||
can_execute = False
|
||
logger.debug(
|
||
f"[rjb] 节点 {node_id} 的前置节点 {src} 未执行,不能执行"
|
||
)
|
||
break
|
||
if can_execute:
|
||
next_node_id = node_id
|
||
logger.info(
|
||
f"[rjb] 选择执行节点: {next_node_id}, 类型: {self.nodes[next_node_id].get('type')}, 入边数: {len(incoming_edges)}"
|
||
)
|
||
break
|
||
|
||
if not next_node_id:
|
||
break # 没有更多节点可执行
|
||
|
||
node = self.nodes[next_node_id]
|
||
is_approval = node.get("type") == "approval"
|
||
if not is_approval:
|
||
executed_nodes.add(next_node_id)
|
||
execution_sequence.append(next_node_id)
|
||
|
||
# 调试:检查节点数据结构
|
||
if node.get('type') == 'llm':
|
||
logger.debug(f"[rjb] 执行LLM节点: node_id={next_node_id}, node keys={list(node.keys())}, data keys={list(node.get('data', {}).keys()) if node.get('data') else []}")
|
||
|
||
# 获取节点输入(使用活跃的边)
|
||
node_input = self.get_node_input(next_node_id, self.node_outputs, active_edges)
|
||
|
||
# 如果是起始节点,使用初始输入
|
||
if node.get('type') == 'start' and not node_input:
|
||
node_input = input_data
|
||
logger.info(f"[rjb] Start节点使用初始输入: node_id={next_node_id}, node_input={node_input}")
|
||
|
||
# 调试:记录节点输入数据
|
||
if node.get('type') == 'llm':
|
||
logger.info(f"[rjb] LLM节点输入: node_id={next_node_id}, node_input={node_input}, node_outputs keys={list(self.node_outputs.keys())}")
|
||
if 'start-1' in self.node_outputs:
|
||
logger.info(f"[rjb] Start节点输出内容: {self.node_outputs['start-1']}")
|
||
|
||
# 单执行步数预算(每执行一个节点计 1 步)
|
||
self._steps_used += 1
|
||
if self._steps_used > self._cap_steps:
|
||
raise WorkflowExecutionError(
|
||
detail=f"已超过单执行预算上限({self._cap_steps} 步),已熔断",
|
||
node_id=next_node_id,
|
||
)
|
||
|
||
# 执行节点
|
||
result = await self.execute_node(node, node_input)
|
||
|
||
if result.get("status") == "awaiting_approval":
|
||
self._steps_used -= 1
|
||
snap = self._build_pause_snapshot(
|
||
next_node_id, active_edges, executed_nodes, execution_sequence, results
|
||
)
|
||
raise WorkflowPaused(snap)
|
||
|
||
if is_approval:
|
||
executed_nodes.add(next_node_id)
|
||
execution_sequence.append(next_node_id)
|
||
|
||
results[next_node_id] = result
|
||
|
||
# 保存节点输出
|
||
if result.get('status') == 'success':
|
||
output_value = result.get('output', {})
|
||
self.node_outputs[next_node_id] = output_value
|
||
if node.get('type') == 'start':
|
||
logger.info(f"[rjb] Start节点输出已保存: node_id={next_node_id}, output={output_value}, output_type={type(output_value)}")
|
||
|
||
# 如果是条件节点或Switch节点,根据分支结果过滤边
|
||
if node.get('type') == 'condition':
|
||
branch = result.get('branch', 'false')
|
||
logger.info(f"[rjb] 条件节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||
# 移除不符合条件的边
|
||
# 只保留:1) 不是从条件节点出发的边,或 2) 从条件节点出发且sourceHandle匹配分支的边
|
||
edges_to_remove = []
|
||
edges_to_keep = []
|
||
for edge in active_edges:
|
||
if edge['source'] == next_node_id:
|
||
# 这是从条件节点出发的边
|
||
edge_handle = edge.get('sourceHandle')
|
||
if edge_handle == branch:
|
||
# sourceHandle匹配分支,保留
|
||
edges_to_keep.append(edge)
|
||
logger.info(f"[rjb] 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
|
||
else:
|
||
# sourceHandle不匹配或为None,移除
|
||
edges_to_remove.append(edge)
|
||
logger.info(f"[rjb] 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch})")
|
||
else:
|
||
# 不是从条件节点出发的边,保留
|
||
edges_to_keep.append(edge)
|
||
|
||
active_edges = edges_to_keep
|
||
|
||
elif node.get('type') == 'switch':
|
||
branch = result.get('branch', 'default')
|
||
logger.info(f"[rjb] Switch节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||
|
||
# 记录过滤前的边信息
|
||
edges_before = [e for e in active_edges if e['source'] == next_node_id]
|
||
logger.info(f"[rjb] Switch节点过滤前: 从节点出发的边有{len(edges_before)}条")
|
||
for edge in edges_before:
|
||
logger.info(f"[rjb] 边 {edge.get('id')}: sourceHandle={edge.get('sourceHandle')}, target={edge.get('target')}")
|
||
|
||
# 移除不匹配的边
|
||
edges_to_keep = []
|
||
edges_removed_count = 0
|
||
removed_source_nodes = set() # 记录被移除边的源节点
|
||
|
||
for edge in active_edges:
|
||
if edge['source'] == next_node_id:
|
||
# 这是从Switch节点出发的边
|
||
edge_handle = edge.get('sourceHandle')
|
||
if edge_handle == branch:
|
||
# sourceHandle匹配分支,保留
|
||
edges_to_keep.append(edge)
|
||
logger.info(f"[rjb] ✅ 保留边: {edge.get('id')} (sourceHandle={edge_handle} == branch={branch})")
|
||
else:
|
||
# sourceHandle不匹配,移除
|
||
edges_removed_count += 1
|
||
target_id = edge.get('target')
|
||
removed_source_nodes.add(target_id) # 记录目标节点(这些节点将不再可达)
|
||
logger.info(f"[rjb] ❌ 移除边: {edge.get('id')} (sourceHandle={edge_handle} != branch={branch}, target={target_id})")
|
||
else:
|
||
# 不是从Switch节点出发的边,保留
|
||
edges_to_keep.append(edge)
|
||
|
||
# 重要:移除那些指向被过滤节点的边(这些边来自被过滤的LLM节点)
|
||
# 例如:如果llm-question被过滤了,那么llm-question → merge-response的边也应该被移除
|
||
additional_removed = 0
|
||
for edge in list(edges_to_keep): # 使用list副本,因为我们要修改原列表
|
||
if edge['source'] in removed_source_nodes:
|
||
# 这条边来自被过滤的节点,也应该被移除
|
||
edges_to_keep.remove(edge)
|
||
additional_removed += 1
|
||
logger.info(f"[rjb] ❌ 移除来自被过滤节点的边: {edge.get('id')} ({edge.get('source')} → {edge.get('target')})")
|
||
|
||
edges_removed_count += additional_removed
|
||
|
||
active_edges = edges_to_keep
|
||
filter_info = {
|
||
'branch': branch,
|
||
'edges_before': len(edges_before),
|
||
'edges_kept': len([e for e in edges_to_keep if e['source'] == next_node_id]),
|
||
'edges_removed': edges_removed_count
|
||
}
|
||
logger.info(f"[rjb] Switch节点过滤后: 保留{len(active_edges)}条边(其中从Switch节点出发的{filter_info['edges_kept']}条),移除{edges_removed_count}条边")
|
||
# 记录过滤后的活跃边
|
||
remaining_switch_edges = [e for e in active_edges if e['source'] == next_node_id]
|
||
logger.info(f"[rjb] Switch节点过滤后剩余的边: {[e.get('id') + '->' + e.get('target') for e in remaining_switch_edges]}")
|
||
|
||
# 重要:找出那些不再可达的节点(这些节点只通过被移除的边连接)
|
||
removed_targets = set()
|
||
for edge in edges_before:
|
||
if edge not in edges_to_keep:
|
||
target_id = edge.get('target')
|
||
removed_targets.add(target_id)
|
||
logger.info(f"[rjb] ❌ 节点 {target_id} 的边已被移除,该节点将不会被执行")
|
||
|
||
# 关键修复:立即重新构建执行图,确保不再可达的节点不在执行图中
|
||
# 这样在下次循环时,这些节点就不会被选择执行
|
||
logger.info(f"[rjb] Switch节点过滤后,重新构建执行图(排除 {len(removed_targets)} 个不再可达的节点)")
|
||
|
||
# 同时记录到数据库
|
||
if self.logger:
|
||
self.logger.info(
|
||
f"Switch节点分支过滤: branch={branch}, 保留{filter_info['edges_kept']}条边,移除{edges_removed_count}条边",
|
||
node_id=next_node_id,
|
||
node_type='switch',
|
||
data=filter_info
|
||
)
|
||
|
||
elif node.get('type') == 'approval':
|
||
branch = result.get('branch', 'approved')
|
||
logger.info(f"[rjb] Approval节点分支过滤: node_id={next_node_id}, branch={branch}")
|
||
edges_to_keep = []
|
||
removed_source_nodes = set()
|
||
for edge in active_edges:
|
||
if edge['source'] == next_node_id:
|
||
edge_handle = edge.get('sourceHandle')
|
||
if edge_handle == branch:
|
||
edges_to_keep.append(edge)
|
||
else:
|
||
removed_source_nodes.add(edge.get('target'))
|
||
else:
|
||
edges_to_keep.append(edge)
|
||
for edge in list(edges_to_keep):
|
||
if edge['source'] in removed_source_nodes:
|
||
edges_to_keep.remove(edge)
|
||
active_edges = edges_to_keep
|
||
if self.logger:
|
||
self.logger.info(
|
||
f"Approval节点分支过滤: branch={branch}",
|
||
node_id=next_node_id,
|
||
node_type='approval',
|
||
)
|
||
|
||
# 如果是循环节点,跳过循环体的节点(循环体已在节点内部执行)
|
||
if node.get('type') in ['loop', 'foreach']:
|
||
# 标记循环体的节点为已执行(简化处理)
|
||
for edge in active_edges[:]: # 使用切片复制列表
|
||
if edge.get('source') == next_node_id:
|
||
target_id = edge.get('target')
|
||
if target_id in self.nodes:
|
||
# 检查是否是循环结束节点
|
||
target_node = self.nodes[target_id]
|
||
if target_node.get('type') not in ['loop_end', 'end']:
|
||
# 标记为已执行(循环体已在循环节点内部执行)
|
||
executed_nodes.add(target_id)
|
||
# 继续查找循环体内的节点
|
||
self._mark_loop_body_executed(target_id, executed_nodes, active_edges)
|
||
else:
|
||
# 执行失败,停止工作流
|
||
error_msg = result.get('error', '未知错误')
|
||
node_type = node.get('type', 'unknown')
|
||
logger.error(f"工作流执行失败 - 节点: {next_node_id} ({node_type}), 错误: {error_msg}")
|
||
raise WorkflowExecutionError(
|
||
detail=error_msg,
|
||
node_id=next_node_id
|
||
)
|
||
|
||
# 返回最终结果:优先取 End 类型且无出边的节点,避免向量写入等侧链与 End 同为 sink 时
|
||
# 因 executed_nodes 为 set 迭代顺序不确定而错误返回 upsert 元数据。
|
||
if executed_nodes:
|
||
sink_nodes = [
|
||
nid
|
||
for nid in executed_nodes
|
||
if not any(edge["source"] == nid for edge in active_edges)
|
||
]
|
||
|
||
def _pick_latest_in_sequence(cands: List[str]) -> Optional[str]:
    """Return the candidate node id that appears latest in ``execution_sequence``.

    Candidates that never appear in the recorded execution order are ignored;
    returns ``None`` when no candidate was ever executed. Reads
    ``execution_sequence`` from the enclosing scope.
    """
    # Map each executed node id to its first occurrence position
    # (list.index semantics), built in a single pass.
    first_pos: Dict[str, int] = {}
    for idx, nid in enumerate(execution_sequence):
        if nid not in first_pos:
            first_pos[nid] = idx
    ranked = [nid for nid in cands if nid in first_pos]
    if not ranked:
        return None
    # max() keeps the earliest maximal element, matching the original
    # strict-greater-than comparison's first-wins tie behavior.
    return max(ranked, key=lambda nid: first_pos[nid])
|
||
|
||
last_node_id: Optional[str] = None
|
||
end_sinks = [
|
||
nid
|
||
for nid in sink_nodes
|
||
if self.nodes.get(nid, {}).get("type") == "end"
|
||
]
|
||
if end_sinks:
|
||
last_node_id = _pick_latest_in_sequence(end_sinks)
|
||
if not last_node_id and sink_nodes:
|
||
last_node_id = _pick_latest_in_sequence(sink_nodes)
|
||
if not last_node_id and execution_sequence:
|
||
last_node_id = execution_sequence[-1]
|
||
|
||
# 获取最终结果
|
||
final_output = self.node_outputs.get(last_node_id)
|
||
|
||
# 如果最终输出是字典且只有一个 'input' key,提取其值
|
||
# 这样可以确保最终结果不是重复包装的格式
|
||
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
|
||
final_output = final_output['input']
|
||
# 如果提取的值仍然是字典且只有一个 'input' key,继续提取
|
||
if isinstance(final_output, dict) and len(final_output) == 1 and 'input' in final_output:
|
||
final_output = final_output['input']
|
||
|
||
# 确保最终结果是字符串(对于人机交互场景)
|
||
# 如果是字典,尝试转换为字符串;如果是其他类型,也转换为字符串
|
||
if not isinstance(final_output, str):
|
||
if isinstance(final_output, dict):
|
||
# 如果是字典,尝试提取文本内容或转换为JSON字符串
|
||
# 优先查找常见的文本字段
|
||
if 'text' in final_output:
|
||
final_output = str(final_output['text'])
|
||
elif 'content' in final_output:
|
||
final_output = str(final_output['content'])
|
||
elif 'message' in final_output:
|
||
final_output = str(final_output['message'])
|
||
elif 'response' in final_output:
|
||
final_output = str(final_output['response'])
|
||
elif len(final_output) == 1:
|
||
# 如果只有一个key,直接使用其值
|
||
final_output = str(list(final_output.values())[0])
|
||
else:
|
||
# 否则转换为JSON字符串
|
||
import json as json_module
|
||
final_output = json_module.dumps(final_output, ensure_ascii=False)
|
||
else:
|
||
final_output = str(final_output)
|
||
|
||
final_output = self._resolve_end_output_if_vector_metadata(final_output, final_output)
|
||
final_output = self._replace_if_template_placeholder(final_output)
|
||
|
||
final_result = {
|
||
'status': 'completed',
|
||
'result': final_output,
|
||
'node_results': results
|
||
}
|
||
|
||
# 记录工作流执行完成
|
||
if self.logger:
|
||
self.logger.info("工作流执行完成", data={"result": final_result.get('result')})
|
||
|
||
return final_result
|
||
|
||
if self.logger:
|
||
self.logger.warn("工作流执行完成,但没有执行任何节点")
|
||
|
||
return {'status': 'completed', 'result': None}
|