""" 工作流任务 """ from celery import Task from app.core.tools_bootstrap import ensure_builtin_tools_registered ensure_builtin_tools_registered() from app.core.celery_app import celery_app from app.services.workflow_engine import WorkflowEngine from app.services.execution_logger import ExecutionLogger from app.services.alert_service import AlertService from app.core.database import SessionLocal from app.core.config import settings from app.core.exceptions import WorkflowExecutionError, WorkflowPaused # 导入所有相关模型,确保关系可以正确解析 from app.models.execution import Execution from app.models.agent import Agent from app.models.workflow import Workflow from app.services.execution_budget import merge_budget_for_execution from app.services.agent_workspace_chat_log import try_append_agent_dialogue_after_success import asyncio import time from typing import Any, Dict, Optional def _format_task_error(e: Exception) -> str: """Celery 任务异常写入 DB 时的可读文案(HTTPException.detail 等)。""" detail = getattr(e, "detail", None) if isinstance(detail, str) and detail.strip(): return detail if detail is not None and detail != "": return str(detail) s = str(e).strip() return s if s else repr(e) def _snapshot_to_jsonable(snapshot: dict) -> dict: import json as _json return _json.loads(_json.dumps(snapshot, default=str)) def _trusted_user_for_execution(db, execution: Optional[Execution]) -> Optional[str]: """用于校验 LLM 节点引用的 model_configs 归属(与 Workflow / Agent 所有者一致)。""" if not execution: return None if execution.agent_id: ag = db.query(Agent).filter(Agent.id == execution.agent_id).first() return ag.user_id if ag else None if execution.workflow_id: wf = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first() return wf.user_id if wf else None return None @celery_app.task(bind=True) def execute_workflow_task( self, execution_id: str, workflow_id: str, workflow_data: dict, input_data: dict, resume_snapshot: Optional[dict] = None, ): """ 执行工作流任务 Args: execution_id: 执行记录ID workflow_id: 工作流ID workflow_data: 工作流数据(nodes和edges) input_data: 输入数据 resume_snapshot: 从挂起恢复时的快照(与 Execution.pause_state 一致) """ db = SessionLocal() start_time = time.time() execution_logger = None try: # 更新执行状态为运行中 execution = db.query(Execution).filter(Execution.id == execution_id).first() if execution: execution.status = "running" db.commit() # 更新任务状态 self.update_state(state='PROGRESS', meta={'progress': 0, 'status': 'running'}) # 创建执行日志记录器 execution_logger = ExecutionLogger(execution_id, db) execution_logger.info("工作流任务开始执行") # 创建工作流引擎(传入logger、db、合并后的执行预算) budget = merge_budget_for_execution(db, execution) if execution else {} trusted_uid = _trusted_user_for_execution(db, execution) engine = WorkflowEngine( workflow_id, workflow_data, logger=execution_logger, db=db, budget_limits=budget, trusted_model_config_user_id=trusted_uid, ) max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0)) result: Optional[dict] = None for attempt in range(max_retries + 1): try: result = asyncio.run( engine.execute(input_data, resume_snapshot=resume_snapshot) ) break except WorkflowPaused as paused: execution = db.query(Execution).filter(Execution.id == execution_id).first() if execution: execution.status = "awaiting_approval" execution.pause_state = _snapshot_to_jsonable(paused.snapshot) execution.error_message = None db.commit() execution_logger.info( "工作流在审批节点挂起,等待人工决策", data={"pending_node_id": paused.snapshot.get("pending_node_id")}, ) return { "status": "awaiting_approval", "execution_id": execution_id, "pending_node_id": paused.snapshot.get("pending_node_id"), } except WorkflowExecutionError: raise except Exception as run_e: if attempt >= max_retries: raise delay = min(30.0, 1.5**attempt) if execution_logger: execution_logger.warn( f"工作流执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}", data={ "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY", "attempt": attempt + 1, "max_retries": max_retries, "delay_sec": round(delay, 2), }, ) time.sleep(delay) if result is None: raise RuntimeError("工作流执行未返回结果") # 计算执行时间 execution_time = int((time.time() - start_time) * 1000) # 更新执行记录 if execution: execution.status = "completed" execution.output_data = result execution.execution_time = execution_time execution.pause_state = None db.commit() try_append_agent_dialogue_after_success( db, execution, input_data, result, execution_logger ) # 记录执行完成日志 execution_logger.info(f"工作流任务执行完成,耗时: {execution_time}ms") # 检查告警规则(异步) if execution: try: asyncio.run(AlertService.check_alerts_for_execution(db, execution)) except Exception as e: # 告警检测失败不影响执行结果 execution_logger.warn(f"告警检测失败: {str(e)}") return { 'status': 'completed', 'result': result, 'execution_time': execution_time } except Exception as e: execution_time = int((time.time() - start_time) * 1000) # 记录错误日志 err_text = _format_task_error(e) err_code = getattr(e, "error_code", None) if not err_code: err_code = ( "WORKFLOW_EXECUTION_ERROR" if isinstance(e, WorkflowExecutionError) else "WORKFLOW_TASK_ERROR" ) if execution_logger: execution_logger.error( f"工作流任务执行失败: {err_text}", data={"error_type": type(e).__name__, "error_code": err_code}, ) # 更新执行记录为失败 execution = db.query(Execution).filter(Execution.id == execution_id).first() if execution: execution.status = "failed" execution.error_message = err_text execution.execution_time = execution_time db.commit() # 检查告警规则(异步) if execution: try: asyncio.run(AlertService.check_alerts_for_execution(db, execution)) except Exception as e2: # 告警检测失败不影响错误处理 if execution_logger: execution_logger.warn(f"告警检测失败: {str(e2)}") raise finally: db.close() @celery_app.task(bind=True) def resume_workflow_task( self, execution_id: str, decision: str, comment: Optional[str] = None, ): """在 awaiting_approval 时恢复执行(审批通过/拒绝)。""" db = SessionLocal() start_time = time.time() execution_logger = None try: execution = db.query(Execution).filter(Execution.id == execution_id).first() if not execution: return {"status": "error", "detail": "执行记录不存在"} if execution.status != "awaiting_approval": return { "status": "error", "detail": f"执行状态不是 awaiting_approval: {execution.status}", } if not execution.pause_state: return {"status": "error", "detail": "缺少 pause_state"} if decision not in ("approved", "rejected"): return {"status": "error", "detail": "decision 须为 approved 或 rejected"} snapshot = execution.pause_state base_input: Dict[str, Any] = dict(execution.input_data or {}) base_input["__hil_decision"] = decision if comment: base_input["__hil_comment"] = comment execution.status = "running" execution.error_message = None execution.input_data = base_input db.commit() execution_logger = ExecutionLogger(execution_id, db) execution_logger.info("审批恢复执行", data={"decision": decision}) workflow_data: dict wf_key: str if execution.workflow_id: wf = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first() if not wf: raise RuntimeError("工作流不存在") workflow_data = {"nodes": wf.nodes, "edges": wf.edges} wf_key = str(execution.workflow_id) elif execution.agent_id: ag = db.query(Agent).filter(Agent.id == execution.agent_id).first() if not ag or not ag.workflow_config: raise RuntimeError("Agent 或工作流配置不存在") workflow_data = { "nodes": ag.workflow_config.get("nodes", []), "edges": ag.workflow_config.get("edges", []), } wf_key = f"agent_{execution.agent_id}" else: raise RuntimeError("执行未关联工作流或 Agent") self.update_state(state="PROGRESS", meta={"progress": 0, "status": "running"}) budget = merge_budget_for_execution(db, execution) if execution else {} trusted_uid = _trusted_user_for_execution(db, execution) engine = WorkflowEngine( wf_key, workflow_data, logger=execution_logger, db=db, budget_limits=budget, trusted_model_config_user_id=trusted_uid, ) max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0)) result: Optional[dict] = None for attempt in range(max_retries + 1): try: result = asyncio.run( engine.execute(base_input, resume_snapshot=snapshot) ) break except WorkflowPaused as paused: ex2 = db.query(Execution).filter(Execution.id == execution_id).first() if ex2: ex2.status = "awaiting_approval" ex2.pause_state = _snapshot_to_jsonable(paused.snapshot) ex2.error_message = None db.commit() if execution_logger: execution_logger.info( "工作流再次在审批节点挂起", data={"pending_node_id": paused.snapshot.get("pending_node_id")}, ) return { "status": "awaiting_approval", "execution_id": execution_id, "pending_node_id": paused.snapshot.get("pending_node_id"), } except WorkflowExecutionError: raise except Exception as run_e: if attempt >= max_retries: raise delay = min(30.0, 1.5**attempt) if execution_logger: execution_logger.warn( f"恢复执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}", data={ "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY", "attempt": attempt + 1, "max_retries": max_retries, "delay_sec": round(delay, 2), }, ) time.sleep(delay) if result is None: raise RuntimeError("工作流执行未返回结果") execution_time = int((time.time() - start_time) * 1000) ex3 = db.query(Execution).filter(Execution.id == execution_id).first() if ex3: ex3.status = "completed" ex3.output_data = result ex3.execution_time = execution_time ex3.pause_state = None db.commit() try_append_agent_dialogue_after_success( db, ex3, base_input, result, execution_logger ) if execution_logger: execution_logger.info(f"审批后工作流执行完成,耗时: {execution_time}ms") if ex3: try: asyncio.run(AlertService.check_alerts_for_execution(db, ex3)) except Exception as e: if execution_logger: execution_logger.warn(f"告警检测失败: {str(e)}") return { "status": "completed", "result": result, "execution_time": execution_time, } except Exception as e: execution_time = int((time.time() - start_time) * 1000) err_text = _format_task_error(e) err_code = getattr(e, "error_code", None) if not err_code: err_code = ( "WORKFLOW_EXECUTION_ERROR" if isinstance(e, WorkflowExecutionError) else "WORKFLOW_TASK_ERROR" ) if execution_logger: execution_logger.error( f"审批恢复执行失败: {err_text}", data={"error_type": type(e).__name__, "error_code": err_code}, ) execution = db.query(Execution).filter(Execution.id == execution_id).first() if execution: execution.status = "failed" execution.error_message = err_text execution.execution_time = execution_time db.commit() if execution: try: asyncio.run(AlertService.check_alerts_for_execution(db, execution)) except Exception as e2: if execution_logger: execution_logger.warn(f"告警检测失败: {str(e2)}") raise finally: db.close()