Files
aiagent/backend/app/tasks/agent_tasks.py
renjianbo 592bca4f39 feat: Phase 4 - LLM/Agent fallback chain, cross-agent knowledge sharing, async agent execution
- 4.1 Fallback chain: LLM fallback_llm config in AgentLLMConfig, retry with alternate model on API failure; Agent fallback_agent in DAG nodes
- 4.2 Knowledge sharing: GlobalKnowledge model with embedding-based semantic search, auto-extraction of tool names as tags after execution
- 4.3 Async execution: execute_agent_task fully implemented with AgentRuntime, scheduler dual-path for workflow/non-workflow agents

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-05 00:27:54 +08:00

176 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Agent 异步执行任务 — Celery 任务,支持定时调度和手动触发。
"""
from celery import Task
from app.core.tools_bootstrap import ensure_builtin_tools_registered
ensure_builtin_tools_registered()
from app.core.celery_app import celery_app
from app.core.database import SessionLocal
from app.agent_runtime.core import AgentRuntime
from app.agent_runtime.schemas import (
AgentConfig,
AgentLLMConfig,
AgentToolConfig,
AgentMemoryConfig,
AgentBudgetConfig,
)
from app.models.agent import Agent
from app.models.execution import Execution
import asyncio
import logging
import time
logger = logging.getLogger(__name__)
def _value_or_default(data: dict, key: str, default):
    """Return ``data[key]``, or ``default`` when the key is missing or None."""
    value = data.get(key)
    return default if value is None else value


def _build_agent_config_from_db(agent: Agent) -> AgentConfig:
    """Build an AgentConfig from a DB Agent record.

    Settings are read from ``agent.workflow_config["nodes"]``:

    * a node with ``type == "start"`` supplies ``data.system_prompt``;
    * a node with ``type == "agent"`` supplies ``data.tools``,
      ``data.exclude_tools``, ``data.model``, ``data.provider``,
      ``data.temperature`` and ``data.max_iterations``.

    Nodes are scanned in order, so when several nodes of the same type exist
    the last one wins. Values that are missing or stored as null keep their
    defaults. Non-dict node entries are skipped, and a non-dict ``data`` is
    treated as empty.

    Args:
        agent: Agent ORM record; its ``workflow_config`` may be None.

    Returns:
        An AgentConfig with memory (DB persistence and learning) enabled and
        the default budget.

    Raises:
        ValueError, TypeError: if ``temperature`` or ``max_iterations`` is
            present but cannot be converted to a number.
    """
    wf = agent.workflow_config or {}
    # ``or []`` also covers a "nodes" key explicitly stored as null.
    nodes = wf.get("nodes") or []

    # Defaults used when the workflow does not override them.
    system_prompt = "你是一个有用的AI助手。"
    tools_include = []
    tools_exclude = []
    model_name = "gpt-4o-mini"
    provider = "openai"
    temperature = 0.7
    max_iterations = 10

    for node in nodes:
        if not isinstance(node, dict):
            continue
        nd = node.get("data")
        if not isinstance(nd, dict):
            nd = {}
        node_type = node.get("type", "")
        if node_type == "start":
            system_prompt = _value_or_default(nd, "system_prompt", system_prompt)
        elif node_type == "agent":
            tools_include = _value_or_default(nd, "tools", tools_include)
            tools_exclude = _value_or_default(nd, "exclude_tools", tools_exclude)
            model_name = _value_or_default(nd, "model", model_name)
            provider = _value_or_default(nd, "provider", provider)
            temperature = float(_value_or_default(nd, "temperature", temperature))
            max_iterations = int(
                _value_or_default(nd, "max_iterations", max_iterations)
            )

    return AgentConfig(
        name=agent.name,
        system_prompt=system_prompt,
        user_id=str(agent.user_id) if agent.user_id else None,
        llm=AgentLLMConfig(
            provider=provider,
            model=model_name,
            temperature=temperature,
            max_iterations=max_iterations,
        ),
        tools=AgentToolConfig(
            include_tools=tools_include or [],
            exclude_tools=tools_exclude or [],
        ),
        memory=AgentMemoryConfig(
            enabled=True,
            persist_to_db=True,
            learning_enabled=True,
        ),
        budget=AgentBudgetConfig(),
    )
@celery_app.task(bind=True)
def execute_agent_task(self, agent_id: str, input_data: dict):
    """Run an Agent asynchronously as a Celery task.

    Triggered by the scheduler (check_agent_schedules_task) or manually via
    the API. Creates an Execution record, runs the Agent, and stores the
    outcome on that record.

    Args:
        agent_id: ID of the Agent to run.
        input_data: Input payload; must contain a non-empty "message" field
            ("query" is accepted as a fallback).

    Returns:
        dict with a "status" key:
        - "error" (+ "detail") when the Agent does not exist or no message was
          supplied; no Execution record is created in that case;
        - "completed" with the result content, iteration/tool-call counts,
          execution id and elapsed milliseconds;
        - "failed" with the execution id and the runtime's error when the
          runtime reports an unsuccessful result.

    Raises:
        Exception: anything raised while preparing or running the Agent (or
            while saving its result) is re-raised after the Execution record
            has been marked "failed".
    """
    db = SessionLocal()
    # monotonic() is immune to wall-clock adjustments, so durations stay valid.
    start_time = time.monotonic()

    def elapsed_ms() -> int:
        """Milliseconds elapsed since the task started."""
        return int((time.monotonic() - start_time) * 1000)

    try:
        agent = db.query(Agent).filter(Agent.id == agent_id).first()
        if not agent:
            return {"status": "error", "detail": f"Agent {agent_id} 不存在"}

        # A missing/None payload simply has no message.
        payload = input_data or {}
        user_message = payload.get("message") or payload.get("query") or ""
        if not user_message:
            return {"status": "error", "detail": "缺少 message 输入"}

        # Create the execution record.
        execution = Execution(
            agent_id=agent_id,
            input_data=input_data,
            status="running",
        )
        db.add(execution)
        # Commit (not just flush): the execution_id is published through the
        # Celery task state below, so the "running" row must be visible to
        # other sessions while the Agent runs.
        db.commit()
        execution_id = str(execution.id)

        try:
            self.update_state(
                state="PROGRESS",
                meta={
                    "execution_id": execution_id,
                    "agent_id": agent_id,
                    "progress": 0,
                    "status": "running",
                },
            )
            # Config building is inside the try so a bad workflow_config also
            # marks the execution as failed instead of leaving it "running".
            config = _build_agent_config_from_db(agent)
            runtime = AgentRuntime(config)
            result = asyncio.run(runtime.run(user_message))
            execution_time = elapsed_ms()

            if not result.success:
                execution.status = "failed"
                execution.error_message = result.error or "Agent 执行返回失败"
                execution.execution_time = execution_time
                db.commit()
                return {
                    "status": "failed",
                    "execution_id": execution_id,
                    "error": result.error,
                }

            execution.status = "completed"
            execution.output_data = {
                "content": result.content,
                "iterations_used": result.iterations_used,
                "tool_calls_made": result.tool_calls_made,
            }
            execution.execution_time = execution_time
            db.commit()
            logger.info(
                "Agent 异步执行完成: agent=%s execution=%s time=%dms",
                agent_id, execution_id, execution_time,
            )
            return {
                "status": "completed",
                "execution_id": execution_id,
                "content": result.content,
                "iterations_used": result.iterations_used,
                "tool_calls_made": result.tool_calls_made,
                "execution_time": execution_time,
            }
        except Exception as run_e:
            # Discard any half-applied changes so the failure can be recorded
            # even if the session was left in an error state; the "running"
            # row itself is already committed and survives the rollback.
            db.rollback()
            execution.status = "failed"
            execution.error_message = f"Agent 执行异常: {run_e!s}"
            execution.execution_time = elapsed_ms()
            db.commit()
            logger.error("Agent 异步执行异常: agent=%s error=%s", agent_id, run_e)
            raise
    except Exception as e:
        logger.error("execute_agent_task 失败: agent=%s error=%s", agent_id, e)
        raise
    finally:
        db.close()