Files
aiagent/backend/app/tasks/workflow_tasks.py
renjianbo 4366312946 feat: DeepSeek v4 模型对齐、作业助手脚本与 Agent 对比测试
- 前端 WorkflowEditor/ModelConfigs/NodeTemplates:deepseek-v4-flash、v4-pro,弃用提示
- llm_service 默认 deepseek-v4-flash;workflow_engine 等与模型配置注入
- 作业管理脚本支持 AGENT_NAME 与 v4-pro;新增 compare_homework_agents 脚本
- 文档重命名为 (红头)项目核心文档汇总.md 并更新 DeepSeek 说明

Made-with: Cursor
2026-04-30 00:57:13 +08:00

401 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
工作流任务
"""
from celery import Task
from app.core.tools_bootstrap import ensure_builtin_tools_registered
ensure_builtin_tools_registered()
from app.core.celery_app import celery_app
from app.services.workflow_engine import WorkflowEngine
from app.services.execution_logger import ExecutionLogger
from app.services.alert_service import AlertService
from app.core.database import SessionLocal
from app.core.config import settings
from app.core.exceptions import WorkflowExecutionError, WorkflowPaused
# 导入所有相关模型,确保关系可以正确解析
from app.models.execution import Execution
from app.models.agent import Agent
from app.models.workflow import Workflow
from app.services.execution_budget import merge_budget_for_execution
from app.services.agent_workspace_chat_log import try_append_agent_dialogue_after_success
import asyncio
import time
from typing import Any, Dict, Optional
def _format_task_error(e: Exception) -> str:
    """Return a human-readable message for an exception raised inside a Celery task.

    The text is what gets written to the execution record, so it must never
    be blank. Resolution order:

    1. ``e.detail`` when it is a non-blank string (e.g. ``HTTPException.detail``),
       returned unchanged.
    2. ``str(e.detail)`` when ``detail`` is a non-string value other than ``None``.
    3. ``str(e)`` with surrounding whitespace removed.
    4. ``repr(e)`` when everything above is empty.

    Args:
        e: The exception to describe.

    Returns:
        A non-empty description of the error.
    """
    detail = getattr(e, "detail", None)
    if isinstance(detail, str):
        if detail.strip():
            return detail
        # Empty or whitespace-only detail carries no information: fall back
        # to the exception text instead of storing a blank message.
    elif detail is not None:
        return str(detail)

    text = str(e).strip()
    return text if text else repr(e)
def _snapshot_to_jsonable(snapshot: dict) -> dict:
    """Return a JSON-safe copy of *snapshot*.

    The snapshot is serialised to a JSON string and parsed back. Values the
    JSON encoder cannot handle natively are converted with ``str()``.
    """
    import json as _json

    encoded = _json.dumps(snapshot, default=str)
    return _json.loads(encoded)
def _trusted_user_for_execution(db, execution: Optional[Execution]) -> Optional[str]:
    """Resolve the user id used to validate model_configs referenced by LLM nodes.

    The owner is read from the Agent when the execution has an ``agent_id``;
    otherwise from the Workflow when it has a ``workflow_id``. ``None`` is
    returned when there is no execution, no linked id, or the linked row
    cannot be found.
    """
    if not execution:
        return None

    agent_id = execution.agent_id
    if agent_id:
        # An agent link takes precedence over a workflow link.
        agent = db.query(Agent).filter(Agent.id == agent_id).first()
        if not agent:
            return None
        return agent.user_id

    workflow_id = execution.workflow_id
    if not workflow_id:
        return None

    workflow = db.query(Workflow).filter(Workflow.id == workflow_id).first()
    if not workflow:
        return None
    return workflow.user_id
@celery_app.task(bind=True)
def execute_workflow_task(
    self,
    execution_id: str,
    workflow_id: str,
    workflow_data: dict,
    input_data: dict,
    resume_snapshot: Optional[dict] = None,
):
    """
    Run a workflow as a Celery task and persist the outcome on its Execution row.

    Args:
        execution_id: ID of the execution record.
        workflow_id: ID of the workflow.
        workflow_data: Workflow definition (nodes and edges).
        input_data: Input payload passed to the engine.
        resume_snapshot: Snapshot to resume from after a pause
            (same content as Execution.pause_state).

    Returns:
        ``{'status': 'completed', 'result': ..., 'execution_time': ms}`` on
        success, or ``{'status': 'awaiting_approval', 'execution_id': ...,
        'pending_node_id': ...}`` when the engine raises WorkflowPaused.

    Raises:
        Exception: Any failure is logged, recorded on the execution row
            (status "failed") and then re-raised. WorkflowExecutionError is
            never retried; other exceptions are retried up to
            ``settings.WORKFLOW_TASK_MAX_RETRIES`` times with backoff.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger = None

    try:
        # Mark the execution as running.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "running"
            db.commit()

        # Publish the Celery task state.
        self.update_state(state='PROGRESS', meta={'progress': 0, 'status': 'running'})

        # Create the per-execution logger.
        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("工作流任务开始执行")

        # Build the engine with the logger, the db session and the merged budget.
        budget = merge_budget_for_execution(db, execution) if execution else {}
        trusted_uid = _trusted_user_for_execution(db, execution)
        engine = WorkflowEngine(
            workflow_id,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
            trusted_model_config_user_id=trusted_uid,
        )

        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(input_data, resume_snapshot=resume_snapshot)
                )
                break
            except WorkflowPaused as paused:
                # The engine stopped at an approval node: store the snapshot
                # and end the task without marking it completed or failed.
                execution = db.query(Execution).filter(Execution.id == execution_id).first()
                if execution:
                    execution.status = "awaiting_approval"
                    execution.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    execution.error_message = None
                    db.commit()
                execution_logger.info(
                    "工作流在审批节点挂起,等待人工决策",
                    data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Workflow-level errors are never retried.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                # Exponential backoff, capped at 30 seconds.
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"工作流执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)

        if result is None:
            raise RuntimeError("工作流执行未返回结果")

        # Elapsed time in milliseconds.
        execution_time = int((time.time() - start_time) * 1000)

        # Persist the successful outcome.
        if execution:
            execution.status = "completed"
            execution.output_data = result
            execution.execution_time = execution_time
            execution.pause_state = None
            db.commit()
            try_append_agent_dialogue_after_success(
                db, execution, input_data, result, execution_logger
            )

        # Log completion.
        execution_logger.info(f"工作流任务执行完成,耗时: {execution_time}ms")

        # Evaluate alert rules (async API driven synchronously).
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e:
                # A failing alert check must not affect the execution result.
                execution_logger.warn(f"告警检测失败: {str(e)}")

        return {
            'status': 'completed',
            'result': result,
            'execution_time': execution_time
        }

    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)

        # If the failure came from a flush/commit, the SQLAlchemy session is
        # left inactive and every further statement would raise, masking the
        # original error and leaving the row stuck in "running". Roll back
        # only in that state so pending work is kept when the session is fine.
        if not db.is_active:
            db.rollback()

        # Log the error.
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"工作流任务执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )

        # Mark the execution as failed.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()

        # Evaluate alert rules (async API driven synchronously).
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                # A failing alert check must not affect error handling.
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")

        raise

    finally:
        db.close()
@celery_app.task(bind=True)
def resume_workflow_task(
    self,
    execution_id: str,
    decision: str,
    comment: Optional[str] = None,
):
    """
    Resume an execution that is in ``awaiting_approval`` (approve / reject).

    The decision (and optional comment) are stored in the execution input
    under ``__hil_decision`` / ``__hil_comment``, and the engine is re-run
    from the saved ``pause_state`` snapshot.

    Args:
        execution_id: ID of the execution record to resume.
        decision: Either ``"approved"`` or ``"rejected"``.
        comment: Optional reviewer comment.

    Returns:
        ``{"status": "error", "detail": ...}`` when a precondition fails
        (missing record, wrong status, missing pause_state, invalid decision);
        ``{"status": "awaiting_approval", ...}`` when the workflow pauses
        again; ``{"status": "completed", "result": ..., "execution_time": ms}``
        on success.

    Raises:
        Exception: Any failure after the preconditions is logged, recorded on
            the execution row (status "failed") and then re-raised.
            WorkflowExecutionError is never retried; other exceptions are
            retried up to ``settings.WORKFLOW_TASK_MAX_RETRIES`` times.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger = None

    try:
        # Preconditions: reported as error payloads, not raised.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if not execution:
            return {"status": "error", "detail": "执行记录不存在"}
        if execution.status != "awaiting_approval":
            return {
                "status": "error",
                "detail": f"执行状态不是 awaiting_approval: {execution.status}",
            }
        if not execution.pause_state:
            return {"status": "error", "detail": "缺少 pause_state"}
        if decision not in ("approved", "rejected"):
            return {"status": "error", "detail": "decision 须为 approved 或 rejected"}

        # Capture the snapshot and inject the human decision into the input.
        snapshot = execution.pause_state
        base_input: Dict[str, Any] = dict(execution.input_data or {})
        base_input["__hil_decision"] = decision
        if comment:
            base_input["__hil_comment"] = comment

        execution.status = "running"
        execution.error_message = None
        execution.input_data = base_input
        db.commit()

        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("审批恢复执行", data={"decision": decision})

        # Load the workflow definition from the linked Workflow or Agent.
        workflow_data: dict
        wf_key: str
        if execution.workflow_id:
            wf = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first()
            if not wf:
                raise RuntimeError("工作流不存在")
            workflow_data = {"nodes": wf.nodes, "edges": wf.edges}
            wf_key = str(execution.workflow_id)
        elif execution.agent_id:
            ag = db.query(Agent).filter(Agent.id == execution.agent_id).first()
            if not ag or not ag.workflow_config:
                raise RuntimeError("Agent 或工作流配置不存在")
            workflow_data = {
                "nodes": ag.workflow_config.get("nodes", []),
                "edges": ag.workflow_config.get("edges", []),
            }
            wf_key = f"agent_{execution.agent_id}"
        else:
            raise RuntimeError("执行未关联工作流或 Agent")

        self.update_state(state="PROGRESS", meta={"progress": 0, "status": "running"})

        budget = merge_budget_for_execution(db, execution) if execution else {}
        trusted_uid = _trusted_user_for_execution(db, execution)
        engine = WorkflowEngine(
            wf_key,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
            trusted_model_config_user_id=trusted_uid,
        )

        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(base_input, resume_snapshot=snapshot)
                )
                break
            except WorkflowPaused as paused:
                # Paused again at an approval node: store the new snapshot.
                ex2 = db.query(Execution).filter(Execution.id == execution_id).first()
                if ex2:
                    ex2.status = "awaiting_approval"
                    ex2.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    ex2.error_message = None
                    db.commit()
                if execution_logger:
                    execution_logger.info(
                        "工作流再次在审批节点挂起",
                        data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                    )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Workflow-level errors are never retried.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                # Exponential backoff, capped at 30 seconds.
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"恢复执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)

        if result is None:
            raise RuntimeError("工作流执行未返回结果")

        # Elapsed time in milliseconds.
        execution_time = int((time.time() - start_time) * 1000)

        # Persist the successful outcome.
        ex3 = db.query(Execution).filter(Execution.id == execution_id).first()
        if ex3:
            ex3.status = "completed"
            ex3.output_data = result
            ex3.execution_time = execution_time
            ex3.pause_state = None
            db.commit()
            try_append_agent_dialogue_after_success(
                db, ex3, base_input, result, execution_logger
            )

        if execution_logger:
            execution_logger.info(f"审批后工作流执行完成,耗时: {execution_time}ms")

        # Evaluate alert rules; failures here must not affect the result.
        if ex3:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, ex3))
            except Exception as e:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e)}")

        return {
            "status": "completed",
            "result": result,
            "execution_time": execution_time,
        }

    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)

        # If the failure came from a flush/commit, the SQLAlchemy session is
        # left inactive and every further statement would raise, masking the
        # original error and leaving the row stuck in "running". Roll back
        # only in that state so pending work is kept when the session is fine.
        if not db.is_active:
            db.rollback()

        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"审批恢复执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )

        # Mark the execution as failed.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()

        # Evaluate alert rules; failures here must not affect error handling.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")

        raise

    finally:
        db.close()