Files
aiagent/backend/app/tasks/workflow_tasks.py
renjianbo df4fab1e6e feat: Agent 批量测试、作业助手与上传预览;Windows 启动脚本与文档
- 新增 run_agent_test_cases 与示例 JSON、(红头)agent测试用例文档
- 扩展 test_agent_execution(--homework、UTF-8 控制台)
- 后端:uploads 预览、file_read、工作流与对话落盘等
- 前端:AgentChatPreview 与设计器相关调整
- 忽略 redis二进制、agent_workspaces、uploads、tessdata 等本机产物

Made-with: Cursor
2026-04-13 20:17:18 +08:00

384 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流任务
"""
from celery import Task
from app.core.tools_bootstrap import ensure_builtin_tools_registered
ensure_builtin_tools_registered()
from app.core.celery_app import celery_app
from app.services.workflow_engine import WorkflowEngine
from app.services.execution_logger import ExecutionLogger
from app.services.alert_service import AlertService
from app.core.database import SessionLocal
from app.core.config import settings
from app.core.exceptions import WorkflowExecutionError, WorkflowPaused
# 导入所有相关模型,确保关系可以正确解析
from app.models.execution import Execution
from app.models.agent import Agent
from app.models.workflow import Workflow
from app.services.execution_budget import merge_budget_for_execution
from app.services.agent_workspace_chat_log import try_append_agent_dialogue_after_success
import asyncio
import time
from typing import Any, Dict, Optional
def _format_task_error(e: Exception) -> str:
"""Celery 任务异常写入 DB 时的可读文案HTTPException.detail 等)。"""
detail = getattr(e, "detail", None)
if isinstance(detail, str) and detail.strip():
return detail
if detail is not None and detail != "":
return str(detail)
s = str(e).strip()
return s if s else repr(e)
def _snapshot_to_jsonable(snapshot: dict) -> dict:
import json as _json
return _json.loads(_json.dumps(snapshot, default=str))
@celery_app.task(bind=True)
def execute_workflow_task(
    self,
    execution_id: str,
    workflow_id: str,
    workflow_data: dict,
    input_data: dict,
    resume_snapshot: Optional[dict] = None,
):
    """
    Execute a workflow as a Celery task.

    Marks the Execution row as ``running``, runs ``WorkflowEngine.execute``
    and persists the outcome:

    * ``completed`` with output and execution time, or
    * ``awaiting_approval`` with a JSON-safe pause snapshot, or
    * ``failed`` with an error message.

    Generic exceptions are retried up to ``settings.WORKFLOW_TASK_MAX_RETRIES``
    times with capped exponential backoff. ``WorkflowExecutionError`` is not
    retried. Alert rules are checked after completion and after failure; an
    alert-check failure is only logged.

    Args:
        execution_id: Execution record ID.
        workflow_id: Workflow ID (passed to the engine).
        workflow_data: Workflow definition (nodes and edges).
        input_data: Input data for the run.
        resume_snapshot: Snapshot to resume from after a pause (same shape as
            ``Execution.pause_state``).

    Returns:
        ``{"status": "awaiting_approval", "execution_id", "pending_node_id"}``
        when an approval node pauses the run; otherwise
        ``{"status": "completed", "result", "execution_time"}``.

    Raises:
        Re-raises the exception that made the run fail, after recording it.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger = None
    try:
        # Mark the execution as running.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "running"
            db.commit()
        # Report task progress to Celery.
        self.update_state(state='PROGRESS', meta={'progress': 0, 'status': 'running'})
        # Create the per-execution logger.
        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("工作流任务开始执行")
        # Build the engine with the logger, the DB session and the merged execution budget.
        budget = merge_budget_for_execution(db, execution) if execution else {}
        engine = WorkflowEngine(
            workflow_id,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
        )
        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(input_data, resume_snapshot=resume_snapshot)
                )
                break
            except WorkflowPaused as paused:
                # An approval node suspended the run: persist its snapshot so a
                # later resume can continue from it.
                execution = db.query(Execution).filter(Execution.id == execution_id).first()
                if execution:
                    execution.status = "awaiting_approval"
                    execution.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    execution.error_message = None
                    db.commit()
                execution_logger.info(
                    "工作流在审批节点挂起,等待人工决策",
                    data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Workflow-level failures are not treated as transient.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                # A failed flush leaves the session unusable until rolled back;
                # without this every retry would fail immediately on the same
                # session instead of actually retrying.
                if not db.is_active:
                    db.rollback()
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"工作流执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)
        if result is None:
            raise RuntimeError("工作流执行未返回结果")
        # Wall-clock execution time in milliseconds.
        execution_time = int((time.time() - start_time) * 1000)
        # Persist the successful result.
        if execution:
            execution.status = "completed"
            execution.output_data = result
            execution.execution_time = execution_time
            execution.pause_state = None
            db.commit()
            try_append_agent_dialogue_after_success(
                db, execution, input_data, result, execution_logger
            )
        execution_logger.info(f"工作流任务执行完成,耗时: {execution_time}ms")
        # Check alert rules; a failure here must not change the task outcome.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e:
                execution_logger.warn(f"告警检测失败: {str(e)}")
        return {
            'status': 'completed',
            'result': result,
            'execution_time': execution_time
        }
    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)
        # If the failure came from the database, the session refuses further
        # work until rolled back; without this the logging and status update
        # below would raise, masking the real error and leaving the execution
        # stuck in "running".
        if not db.is_active:
            db.rollback()
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"工作流任务执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )
        # Mark the execution as failed.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()
        # Check alert rules; a failure here must not hide the original error.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")
        raise
    finally:
        db.close()
@celery_app.task(bind=True)
def resume_workflow_task(
    self,
    execution_id: str,
    decision: str,
    comment: Optional[str] = None,
):
    """
    Resume an execution that is ``awaiting_approval`` (approved / rejected).

    Steps:

    1. Validate the execution and ``decision`` (no mutation before this).
    2. Record the decision, and the optional comment, in ``input_data``
       under ``__hil_decision`` / ``__hil_comment``.
    3. Rebuild the workflow definition from the linked Workflow or Agent.
    4. Re-run the engine from ``pause_state``.

    The outcome is persisted the same way as in ``execute_workflow_task``:
    paused again, completed, or failed. Generic exceptions are retried up to
    ``settings.WORKFLOW_TASK_MAX_RETRIES`` times with capped backoff.

    Args:
        execution_id: Execution record ID.
        decision: ``"approved"`` or ``"rejected"``.
        comment: Optional reviewer comment for this decision.

    Returns:
        ``{"status": "error", "detail": ...}`` when validation fails (the
        record is left untouched); ``{"status": "awaiting_approval", ...}``
        when another approval node pauses the run; otherwise
        ``{"status": "completed", "result", "execution_time"}``.

    Raises:
        Re-raises the exception that made the resumed run fail, after
        recording it.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger = None
    try:
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if not execution:
            return {"status": "error", "detail": "执行记录不存在"}
        if execution.status != "awaiting_approval":
            return {
                "status": "error",
                "detail": f"执行状态不是 awaiting_approval: {execution.status}",
            }
        if not execution.pause_state:
            return {"status": "error", "detail": "缺少 pause_state"}
        if decision not in ("approved", "rejected"):
            return {"status": "error", "detail": "decision 须为 approved 或 rejected"}
        snapshot = execution.pause_state
        base_input: Dict[str, Any] = dict(execution.input_data or {})
        base_input["__hil_decision"] = decision
        if comment:
            base_input["__hil_comment"] = comment
        else:
            # input_data persists across resumes (a run can pause at several
            # approval nodes); drop a comment left by an earlier approval so it
            # is not attributed to this decision.
            base_input.pop("__hil_comment", None)
        execution.status = "running"
        execution.error_message = None
        execution.input_data = base_input
        db.commit()
        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("审批恢复执行", data={"decision": decision})
        # Rebuild the workflow definition from the linked Workflow or Agent.
        workflow_data: dict
        wf_key: str
        if execution.workflow_id:
            wf = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first()
            if not wf:
                raise RuntimeError("工作流不存在")
            workflow_data = {"nodes": wf.nodes, "edges": wf.edges}
            wf_key = str(execution.workflow_id)
        elif execution.agent_id:
            ag = db.query(Agent).filter(Agent.id == execution.agent_id).first()
            if not ag or not ag.workflow_config:
                raise RuntimeError("Agent 或工作流配置不存在")
            workflow_data = {
                "nodes": ag.workflow_config.get("nodes", []),
                "edges": ag.workflow_config.get("edges", []),
            }
            wf_key = f"agent_{execution.agent_id}"
        else:
            raise RuntimeError("执行未关联工作流或 Agent")
        self.update_state(state="PROGRESS", meta={"progress": 0, "status": "running"})
        # ``execution`` is guaranteed non-None here (checked above).
        budget = merge_budget_for_execution(db, execution)
        engine = WorkflowEngine(
            wf_key,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
        )
        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(base_input, resume_snapshot=snapshot)
                )
                break
            except WorkflowPaused as paused:
                # Another approval node suspended the run: persist the new snapshot.
                ex2 = db.query(Execution).filter(Execution.id == execution_id).first()
                if ex2:
                    ex2.status = "awaiting_approval"
                    ex2.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    ex2.error_message = None
                    db.commit()
                if execution_logger:
                    execution_logger.info(
                        "工作流再次在审批节点挂起",
                        data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                    )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Workflow-level failures are not treated as transient.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                # A failed flush leaves the session unusable until rolled back;
                # without this every retry would fail immediately on the same
                # session instead of actually retrying.
                if not db.is_active:
                    db.rollback()
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"恢复执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)
        if result is None:
            raise RuntimeError("工作流执行未返回结果")
        execution_time = int((time.time() - start_time) * 1000)
        ex3 = db.query(Execution).filter(Execution.id == execution_id).first()
        if ex3:
            ex3.status = "completed"
            ex3.output_data = result
            ex3.execution_time = execution_time
            ex3.pause_state = None
            db.commit()
            try_append_agent_dialogue_after_success(
                db, ex3, base_input, result, execution_logger
            )
        if execution_logger:
            execution_logger.info(f"审批后工作流执行完成,耗时: {execution_time}ms")
        # Check alert rules; a failure here must not change the task outcome.
        if ex3:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, ex3))
            except Exception as e:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e)}")
        return {
            "status": "completed",
            "result": result,
            "execution_time": execution_time,
        }
    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)
        # If the failure came from the database, the session refuses further
        # work until rolled back; without this the logging and status update
        # below would raise, masking the real error and leaving the execution
        # stuck in "running".
        if not db.is_active:
            db.rollback()
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"审批恢复执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()
        # Check alert rules; a failure here must not hide the original error.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")
        raise
    finally:
        db.close()