"""
工作流任务
"""
from celery import Task
from app.core.tools_bootstrap import ensure_builtin_tools_registered
ensure_builtin_tools_registered()
from app.core.celery_app import celery_app
from app.services.workflow_engine import WorkflowEngine
from app.services.execution_logger import ExecutionLogger
from app.services.alert_service import AlertService
from app.core.database import SessionLocal
from app.core.config import settings
from app.core.exceptions import WorkflowExecutionError, WorkflowPaused
# 导入所有相关模型,确保关系可以正确解析
from app.models.execution import Execution
from app.models.agent import Agent
from app.models.workflow import Workflow
from app.services.execution_budget import merge_budget_for_execution
from app.services.agent_workspace_chat_log import try_append_agent_dialogue_after_success
from app.services.notification_service import create_notification
from app.websocket.manager import websocket_manager
import asyncio
import json
import time
from typing import Any, Dict, Optional
def _format_task_error(e: Exception) -> str:
"""Celery 任务异常写入 DB 时的可读文案HTTPException.detail 等)。"""
detail = getattr(e, "detail", None)
if isinstance(detail, str) and detail.strip():
return detail
if detail is not None and detail != "":
return str(detail)
s = str(e).strip()
return s if s else repr(e)
def _snapshot_to_jsonable(snapshot: dict) -> dict:
import json as _json
return _json.loads(_json.dumps(snapshot, default=str))
async def _on_workflow_progress(execution_id: str, progress_data: dict):
"""工作流进度回调:写入 Redis + WebSocket 广播。"""
try:
from app.core.redis_client import get_redis_client
redis_client = get_redis_client()
if redis_client:
redis_client.setex(
f"workflow:progress:{execution_id}",
300,
json.dumps(progress_data, ensure_ascii=False),
)
except Exception:
pass
try:
await websocket_manager.broadcast_to_execution(execution_id, {
"type": "progress",
"execution_id": execution_id,
**progress_data,
})
except Exception:
pass
def _trusted_user_for_execution(db, execution: Optional[Execution]) -> Optional[str]:
    """Resolve the owner used to validate model_config references in LLM nodes
    (must match the Workflow / Agent owner). Returns None when unknown."""
    if not execution:
        return None
    # Agent-backed executions take precedence over workflow-backed ones.
    if execution.agent_id:
        owner = db.query(Agent).filter(Agent.id == execution.agent_id).first()
        return owner.user_id if owner else None
    if execution.workflow_id:
        owner = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first()
        return owner.user_id if owner else None
    return None
def _notify_schedule_result(
db,
execution,
status: str,
error_message: Optional[str] = None,
):
"""如果 execution 关联了定时任务,创建通知推送结果给用户。"""
if not execution or not execution.schedule_id:
return
try:
from app.models.agent_schedule import AgentSchedule
schedule = db.query(AgentSchedule).filter(AgentSchedule.id == execution.schedule_id).first()
if not schedule:
return
if status == "completed":
title = f"定时任务「{schedule.name}」执行成功"
content = f"Agent 已按计划执行完成。"
else:
title = f"定时任务「{schedule.name}」执行失败"
content = f"错误信息: {error_message or '未知错误'}"
create_notification(
db,
user_id=schedule.user_id,
title=title,
content=content,
category="schedule",
ref_type="execution",
ref_id=str(execution.id),
)
db.commit()
# 如果配置了飞书 webhook发送飞书通知非阻塞失败不影响主流程
if schedule.webhook_url:
try:
from app.services.feishu_notifier import send_feishu_card
detail_link = None
# 如果系统配置了外部访问地址,拼接 execution 详情链接
try:
from app.core.config import settings
if settings.EXTERNAL_URL:
detail_link = f"{settings.EXTERNAL_URL}/executions/{execution.id}"
except Exception:
pass
send_feishu_card(
webhook_url=schedule.webhook_url,
title=title,
body=content,
status=status,
detail_link=detail_link,
)
except Exception as e:
logger.warning("飞书 webhook 通知发送失败: %s", e)
# 如果用户绑定了飞书账号,通过飞书应用发送通知
try:
from app.models.user import User
from app.services.feishu_app_service import send_message_to_user
schedule_user = db.query(User).filter(User.id == schedule.user_id).first()
if schedule_user and schedule_user.feishu_open_id:
detail_link = None
try:
from app.core.config import settings
if settings.EXTERNAL_URL:
detail_link = f"{settings.EXTERNAL_URL}/executions/{execution.id}"
except Exception:
pass
send_message_to_user(
open_id=schedule_user.feishu_open_id,
title=title,
content=content,
status=status,
detail_link=detail_link,
)
except Exception as e:
logger.warning("飞书应用通知发送失败: %s", e)
except Exception as e:
logger.warning("创建定时任务通知失败: %s", e)
@celery_app.task(bind=True)
def execute_workflow_task(
    self,
    execution_id: str,
    workflow_id: str,
    workflow_data: dict,
    input_data: dict,
    resume_snapshot: Optional[dict] = None,
):
    """
    Execute a workflow as a Celery task.

    Args:
        execution_id: Execution record ID.
        workflow_id: Workflow ID.
        workflow_data: Workflow definition (nodes and edges).
        input_data: Input payload for the workflow.
        resume_snapshot: Snapshot used when resuming from a pause;
            matches Execution.pause_state.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger = None
    try:
        # Mark the execution record as running.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "running"
            db.commit()
        # Report progress to the Celery result backend.
        self.update_state(state='PROGRESS', meta={'progress': 0, 'status': 'running'})
        # Create the execution log writer.
        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("工作流任务开始执行")
        # Build the workflow engine (logger, db session, merged execution budget,
        # and the trusted owner used to validate model_config references).
        budget = merge_budget_for_execution(db, execution) if execution else {}
        trusted_uid = _trusted_user_for_execution(db, execution)
        engine = WorkflowEngine(
            workflow_id,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
            trusted_model_config_user_id=trusted_uid,
        )
        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        # Retry loop: transient (non-WorkflowExecutionError) failures are retried
        # with exponential backoff capped at 30s.
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(input_data, resume_snapshot=resume_snapshot,
                                   execution_id=execution_id, on_progress=_on_workflow_progress)
                )
                break
            except WorkflowPaused as paused:
                # Workflow hit an approval node: persist the snapshot and wait
                # for a human decision (resume_workflow_task picks it up).
                execution = db.query(Execution).filter(Execution.id == execution_id).first()
                if execution:
                    execution.status = "awaiting_approval"
                    execution.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    execution.error_message = None
                    db.commit()
                execution_logger.info(
                    "工作流在审批节点挂起,等待人工决策",
                    data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Business-level failures are not retried.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"工作流执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)
        if result is None:
            raise RuntimeError("工作流执行未返回结果")

        # Elapsed wall-clock time in milliseconds.
        execution_time = int((time.time() - start_time) * 1000)
        # Persist the successful result.
        if execution:
            execution.status = "completed"
            execution.output_data = result
            execution.execution_time = execution_time
            execution.pause_state = None
            db.commit()
            try_append_agent_dialogue_after_success(
                db, execution, input_data, result, execution_logger
            )

        # Log completion.
        execution_logger.info(f"工作流任务执行完成,耗时: {execution_time}ms")
        # Run alert-rule checks; a failure here must not affect the result.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e:
                execution_logger.warn(f"告警检测失败: {str(e)}")
        # Notify the owner if this execution belongs to a scheduled task.
        _notify_schedule_result(db, execution, "completed")

        return {
            'status': 'completed',
            'result': result,
            'execution_time': execution_time
        }
    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)
        # Build a readable error message and an error code for logging.
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )

        if execution_logger:
            execution_logger.error(
                f"工作流任务执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )

        # Mark the execution record as failed.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()
        # Alert-rule checks; a failure here must not affect error handling.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")
        # Scheduled-task failure notification.
        _notify_schedule_result(db, execution, "failed", error_message=err_text)

        raise
    finally:
        db.close()
@celery_app.task(bind=True)
def resume_workflow_task(
    self,
    execution_id: str,
    decision: str,
    comment: Optional[str] = None,
):
    """Resume an execution paused in awaiting_approval (approved/rejected).

    Args:
        execution_id: Execution record ID.
        decision: "approved" or "rejected".
        comment: Optional approver comment, injected into the workflow input.
    """
    db = SessionLocal()
    start_time = time.time()
    execution_logger = None
    try:
        # Validate the execution can actually be resumed.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if not execution:
            return {"status": "error", "detail": "执行记录不存在"}
        if execution.status != "awaiting_approval":
            return {
                "status": "error",
                "detail": f"执行状态不是 awaiting_approval: {execution.status}",
            }
        if not execution.pause_state:
            return {"status": "error", "detail": "缺少 pause_state"}
        if decision not in ("approved", "rejected"):
            return {"status": "error", "detail": "decision 须为 approved 或 rejected"}
        snapshot = execution.pause_state
        # Inject the human decision (and optional comment) into the input payload.
        base_input: Dict[str, Any] = dict(execution.input_data or {})
        base_input["__hil_decision"] = decision
        if comment:
            base_input["__hil_comment"] = comment
        execution.status = "running"
        execution.error_message = None
        execution.input_data = base_input
        db.commit()
        execution_logger = ExecutionLogger(execution_id, db)
        execution_logger.info("审批恢复执行", data={"decision": decision})
        # Reload the workflow definition from either the Workflow or the Agent.
        workflow_data: dict
        wf_key: str
        if execution.workflow_id:
            wf = db.query(Workflow).filter(Workflow.id == execution.workflow_id).first()
            if not wf:
                raise RuntimeError("工作流不存在")
            workflow_data = {"nodes": wf.nodes, "edges": wf.edges}
            wf_key = str(execution.workflow_id)
        elif execution.agent_id:
            ag = db.query(Agent).filter(Agent.id == execution.agent_id).first()
            if not ag or not ag.workflow_config:
                raise RuntimeError("Agent 或工作流配置不存在")
            workflow_data = {
                "nodes": ag.workflow_config.get("nodes", []),
                "edges": ag.workflow_config.get("edges", []),
            }
            wf_key = f"agent_{execution.agent_id}"
        else:
            raise RuntimeError("执行未关联工作流或 Agent")
        self.update_state(state="PROGRESS", meta={"progress": 0, "status": "running"})
        budget = merge_budget_for_execution(db, execution) if execution else {}
        trusted_uid = _trusted_user_for_execution(db, execution)
        engine = WorkflowEngine(
            wf_key,
            workflow_data,
            logger=execution_logger,
            db=db,
            budget_limits=budget,
            trusted_model_config_user_id=trusted_uid,
        )
        max_retries = max(0, int(getattr(settings, "WORKFLOW_TASK_MAX_RETRIES", 0) or 0))
        result: Optional[dict] = None
        # Retry loop mirrors execute_workflow_task: transient failures retried
        # with exponential backoff capped at 30s.
        for attempt in range(max_retries + 1):
            try:
                result = asyncio.run(
                    engine.execute(base_input, resume_snapshot=snapshot,
                                   execution_id=execution_id, on_progress=_on_workflow_progress)
                )
                break
            except WorkflowPaused as paused:
                # Paused again at another approval node: persist the new snapshot.
                ex2 = db.query(Execution).filter(Execution.id == execution_id).first()
                if ex2:
                    ex2.status = "awaiting_approval"
                    ex2.pause_state = _snapshot_to_jsonable(paused.snapshot)
                    ex2.error_message = None
                    db.commit()
                if execution_logger:
                    execution_logger.info(
                        "工作流再次在审批节点挂起",
                        data={"pending_node_id": paused.snapshot.get("pending_node_id")},
                    )
                return {
                    "status": "awaiting_approval",
                    "execution_id": execution_id,
                    "pending_node_id": paused.snapshot.get("pending_node_id"),
                }
            except WorkflowExecutionError:
                # Business-level failures are not retried.
                raise
            except Exception as run_e:
                if attempt >= max_retries:
                    raise
                delay = min(30.0, 1.5**attempt)
                if execution_logger:
                    execution_logger.warn(
                        f"恢复执行异常,将退避重试 ({attempt + 1}/{max_retries}): {_format_task_error(run_e)}",
                        data={
                            "error_code": "WORKFLOW_TASK_TRANSIENT_RETRY",
                            "attempt": attempt + 1,
                            "max_retries": max_retries,
                            "delay_sec": round(delay, 2),
                        },
                    )
                time.sleep(delay)
        if result is None:
            raise RuntimeError("工作流执行未返回结果")
        execution_time = int((time.time() - start_time) * 1000)
        # Persist the successful result (re-query to avoid a stale row).
        ex3 = db.query(Execution).filter(Execution.id == execution_id).first()
        if ex3:
            ex3.status = "completed"
            ex3.output_data = result
            ex3.execution_time = execution_time
            ex3.pause_state = None
            db.commit()
            try_append_agent_dialogue_after_success(
                db, ex3, base_input, result, execution_logger
            )
        if execution_logger:
            execution_logger.info(f"审批后工作流执行完成,耗时: {execution_time}ms")
        # Alert-rule checks; a failure here must not affect the result.
        if ex3:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, ex3))
            except Exception as e:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e)}")
        return {
            "status": "completed",
            "result": result,
            "execution_time": execution_time,
        }
    except Exception as e:
        execution_time = int((time.time() - start_time) * 1000)
        # Build a readable error message and an error code for logging.
        err_text = _format_task_error(e)
        err_code = getattr(e, "error_code", None)
        if not err_code:
            err_code = (
                "WORKFLOW_EXECUTION_ERROR"
                if isinstance(e, WorkflowExecutionError)
                else "WORKFLOW_TASK_ERROR"
            )
        if execution_logger:
            execution_logger.error(
                f"审批恢复执行失败: {err_text}",
                data={"error_type": type(e).__name__, "error_code": err_code},
            )
        # Mark the execution record as failed.
        execution = db.query(Execution).filter(Execution.id == execution_id).first()
        if execution:
            execution.status = "failed"
            execution.error_message = err_text
            execution.execution_time = execution_time
            db.commit()
        # Alert-rule checks; a failure here must not affect error handling.
        if execution:
            try:
                asyncio.run(AlertService.check_alerts_for_execution(db, execution))
            except Exception as e2:
                if execution_logger:
                    execution_logger.warn(f"告警检测失败: {str(e2)}")
        raise
    finally:
        db.close()