# 补齐平台模板与场景 DSL、预算控制、执行看板和企业场景脚本,增强 Windows 启动/迁移与前端代理和聊天会话记忆,修复执行创建阶段 500 与异步链路排障体验。
"""
|
||
工作流任务
|
||
"""
from celery import Task

from app.core.tools_bootstrap import ensure_builtin_tools_registered

ensure_builtin_tools_registered()

from app.core.celery_app import celery_app
from app.services.workflow_engine import WorkflowEngine
from app.services.execution_logger import ExecutionLogger
from app.services.alert_service import AlertService
from app.core.database import SessionLocal
from app.core.config import settings
from app.core.exceptions import WorkflowExecutionError, WorkflowPaused
# 导入所有相关模型,确保关系可以正确解析
from app.models.execution import Execution
from app.models.agent import Agent
from app.models.workflow import Workflow
from app.services.execution_budget import merge_budget_for_execution
import asyncio
import time
from typing import Any, Dict, Optional
||
def _format_task_error(e: Exception) -> str:
|
||
"""Celery 任务异常写入 DB 时的可读文案(HTTPException.detail 等)。"""
|
||
detail = getattr(e, "detail", None)
|
||
if isinstance(detail, str) and detail.strip():
|
||
return detail
|
||
if detail is not None and detail != "":
|
||
return str(detail)
|
||
s = str(e).strip()
|
||
return s if s else repr(e)
|
||
|
||
|
||
def _snapshot_to_jsonable(snapshot: dict) -> dict:
|
||
import json as _json
|
||
|
||
return _json.loads(_json.dumps(snapshot, default=str))
|
||
|
||
|
||
@celery_app.task(bind=True)
def execute_workflow_task(
    self,
    execution_id: str,
    workflow_id: str,
    workflow_data: dict,
    input_data: dict,
    resume_snapshot: Optional[dict] = None,
) -> dict:
    """Execute a workflow as a Celery task.

    Drives the WorkflowEngine to completion while persisting status
    transitions (running / awaiting_approval / completed / failed) on the
    Execution row.

    Args:
        execution_id: ID of the Execution record to update.
        workflow_id: ID of the workflow being executed.
        workflow_data: Workflow graph payload (``nodes`` and ``edges``).
        input_data: Input payload passed to the engine.
        resume_snapshot: Snapshot used when resuming from a paused state
            (matches ``Execution.pause_state``).

    Returns:
        ``{"status": "completed", ...}`` on success, or
        ``{"status": "awaiting_approval", ...}`` when paused at an approval
        node. Raises on failure so Celery marks the task as failed.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger: Optional[ExecutionLogger] = None

    try:
        # Mark the execution as running before starting the engine.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "running"
            db.commit()

        # Report progress through Celery's task state.
        self.update_state(state='PROGRESS', meta={'progress': 0, 'status': 'running'})

        # Logger that persists structured log entries for this execution.
        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("工作流任务开始执行")

        # Build the engine with the merged budget limits for this execution.
        budget = merge_budget_for_execution(db, execution) if execution else {}
        engine = WorkflowEngine(
            workflow_id,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
        )

        # Retry loop: unexpected exceptions are retried with exponential
        # backoff; WorkflowExecutionError and pauses are never retried.
        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        for attempt in range(max_retries + 1):
            try:
                # The engine is async; run it to completion on a fresh event loop.
                result = asyncio.run(
                    engine.execute(input_data, resume_snapshot=resume_snapshot)
                )
                break
            except WorkflowPaused as paused:
                # Approval node reached: persist the pause snapshot so
                # resume_workflow_task can continue later, then return early.
                execution = db.query(Execution).filter(Execution.id == execution_id).first()
                if execution:
                    execution.status = "awaiting_approval"
                    execution.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    execution.error_message = None
                    db.commit()
                execution_logger.info(
                    "工作流在审批节点挂起,等待人工决策",
                    data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Business-level failures are terminal; no retry.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                # Exponential backoff, capped at 30 seconds.
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"工作流执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)
        if result is None:
            # Defensive: the loop either breaks with a result, returns, or raises.
            raise RuntimeError("工作流执行未返回结果")

        # Wall-clock duration in milliseconds.
        execution_time = int((time.time() - start_time) * 1000)

        # Persist the successful result and clear any leftover pause state.
        if execution:
            execution.status = "completed"
            execution.output_data = result
            execution.execution_time = execution_time
            execution.pause_state = None
            db.commit()

        execution_logger.info(f"工作流任务执行完成,耗时: {execution_time}ms")

        # Alert evaluation is best-effort; failures must not affect the result.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e:
                execution_logger.warn(f"告警检测失败: {str(e)}")

        return {
            'status': 'completed',
            'result': result,
            'execution_time': execution_time
        }

    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)

        # Derive a readable message and a stable error code for the log/DB.
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"工作流任务执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )

        # Re-query before the failure update rather than reusing the earlier
        # ORM instance, which may be stale after commits inside the loop.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()

        # Best-effort alert evaluation on failure as well.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")

        # Re-raise so Celery records the task as failed.
        raise

    finally:
        db.close()
@celery_app.task(bind=True)
def resume_workflow_task(
    self,
    execution_id: str,
    decision: str,
    comment: Optional[str] = None,
) -> dict:
    """Resume an execution paused in ``awaiting_approval`` (approve/reject).

    Validates the execution state, injects the approval decision into the
    input payload, rebuilds the workflow graph (from a Workflow row or an
    Agent's ``workflow_config``), and re-runs the engine from the persisted
    pause snapshot.

    Args:
        execution_id: ID of the paused Execution record.
        decision: Either ``"approved"`` or ``"rejected"``.
        comment: Optional reviewer comment attached to the decision.

    Returns:
        ``{"status": "error", ...}`` for validation failures,
        ``{"status": "awaiting_approval", ...}`` if paused again, or
        ``{"status": "completed", ...}`` on success.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger: Optional[ExecutionLogger] = None

    try:
        # --- Validation: the execution must exist, be paused, and carry a snapshot.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if not execution:
            return {"status": "error", "detail": "执行记录不存在"}
        if execution.status != "awaiting_approval":
            return {
                "status": "error",
                "detail": f"执行状态不是 awaiting_approval: {execution.status}",
            }
        if not execution.pause_state:
            return {"status": "error", "detail": "缺少 pause_state"}

        if decision not in ("approved", "rejected"):
            return {"status": "error", "detail": "decision 须为 approved 或 rejected"}

        # Inject the human decision into the input so approval nodes can read it.
        snapshot = execution.pause_state
        base_input: Dict[str, Any] = dict(execution.input_data or {})
        base_input["__hil_decision"] = decision
        if comment:
            base_input["__hil_comment"] = comment

        # Flip the execution back to running before handing off to the engine.
        execution.status = "running"
        execution.error_message = None
        execution.input_data = base_input
        db.commit()

        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("审批恢复执行", data={"decision": decision})

        # --- Rebuild the workflow graph from whichever source this execution used.
        workflow_data: dict
        wf_key: str
        if execution.workflow_id:
            wf = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first()
            if not wf:
                raise RuntimeError("工作流不存在")
            workflow_data = {"nodes": wf.nodes, "edges": wf.edges}
            wf_key = str(execution.workflow_id)
        elif execution.agent_id:
            ag = db.query(Agent).filter(Agent.id == execution.agent_id).first()
            if not ag or not ag.workflow_config:
                raise RuntimeError("Agent 或工作流配置不存在")
            workflow_data = {
                "nodes": ag.workflow_config.get("nodes", []),
                "edges": ag.workflow_config.get("edges", []),
            }
            wf_key = f"agent_{execution.agent_id}"
        else:
            raise RuntimeError("执行未关联工作流或 Agent")

        # Report progress through Celery's task state.
        self.update_state(state="PROGRESS", meta={"progress": 0, "status": "running"})

        # Engine setup mirrors execute_workflow_task, including budget limits.
        budget = merge_budget_for_execution(db, execution) if execution else {}
        engine = WorkflowEngine(
            wf_key,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
        )
        # Retry loop mirrors execute_workflow_task: unexpected exceptions retry
        # with capped exponential backoff; pauses and business errors do not.
        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(base_input, resume_snapshot=snapshot)
                )
                break
            except WorkflowPaused as paused:
                # Hit another approval node: persist the new snapshot and pause again.
                ex2 = db.query(Execution).filter(Execution.id == execution_id).first()
                if ex2:
                    ex2.status = "awaiting_approval"
                    ex2.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    ex2.error_message = None
                    db.commit()
                if execution_logger:
                    execution_logger.info(
                        "工作流再次在审批节点挂起",
                        data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                    )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Business-level failures are terminal; no retry.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                # Exponential backoff, capped at 30 seconds.
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"恢复执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)
        if result is None:
            # Defensive: the loop either breaks with a result, returns, or raises.
            raise RuntimeError("工作流执行未返回结果")

        execution_time = int((time.time() - start_time) * 1000)
        # Re-query before the final update to avoid stale ORM state.
        ex3 = db.query(Execution).filter(Execution.id == execution_id).first()
        if ex3:
            ex3.status = "completed"
            ex3.output_data = result
            ex3.execution_time = execution_time
            ex3.pause_state = None
            db.commit()

        if execution_logger:
            execution_logger.info(f"审批后工作流执行完成,耗时: {execution_time}ms")

        # Best-effort alert evaluation; failures are logged but never raised.
        if ex3:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, ex3))
            except Exception as e:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e)}")

        return {
            "status": "completed",
            "result": result,
            "execution_time": execution_time,
        }

    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)
        # Derive a readable message and a stable error code for the log/DB.
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"审批恢复执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )
        # Persist the failure on the execution row (fresh query, not the
        # possibly stale instance from before the engine run).
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()
        # Best-effort alert evaluation on failure as well.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")
        # Re-raise so Celery records the task as failed.
        raise

    finally:
        db.close()