- 4.1 Fallback chain: LLM fallback_llm config in AgentLLMConfig, retry with alternate model on API failure; Agent fallback_agent in DAG nodes - 4.2 Knowledge sharing: GlobalKnowledge model with embedding-based semantic search, auto-extraction of tool names as tags after execution - 4.3 Async execution: execute_agent_task fully implemented with AgentRuntime, scheduler dual-path for workflow/non-workflow agents Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
176 lines
5.9 KiB
Python
176 lines
5.9 KiB
Python
"""
|
||
Agent 异步执行任务 — Celery 任务,支持定时调度和手动触发。
|
||
"""
|
||
from celery import Task
|
||
from app.core.tools_bootstrap import ensure_builtin_tools_registered
|
||
|
||
# Register the built-in tools as an import-time side effect. This call runs
# before the remaining application imports below; presumably those modules
# (e.g. the agent runtime) expect the tool registry to already be populated.
# NOTE(review): confirm that ordering requirement before moving this call.
ensure_builtin_tools_registered()
|
||
|
||
from app.core.celery_app import celery_app
|
||
from app.core.database import SessionLocal
|
||
from app.agent_runtime.core import AgentRuntime
|
||
from app.agent_runtime.schemas import (
|
||
AgentConfig,
|
||
AgentLLMConfig,
|
||
AgentToolConfig,
|
||
AgentMemoryConfig,
|
||
AgentBudgetConfig,
|
||
)
|
||
from app.models.agent import Agent
|
||
from app.models.execution import Execution
|
||
import asyncio
|
||
import logging
|
||
import time
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _node_value(data: dict, key: str, default):
    """Return ``data[key]``, or *default* when the key is missing or ``None``."""
    value = data.get(key)
    return default if value is None else value


def _build_agent_config_from_db(agent: Agent) -> AgentConfig:
    """Build an ``AgentConfig`` from a DB ``Agent`` record.

    Settings are read from ``agent.workflow_config["nodes"]``:

    * a node whose ``type`` is ``"start"`` supplies ``system_prompt``;
    * a node whose ``type`` is ``"agent"`` supplies ``tools``,
      ``exclude_tools``, ``model``, ``provider``, ``temperature`` and
      ``max_iterations``.

    Each setting is taken from the node's ``data`` dict. When several nodes of
    the same type exist, the last one wins. A key that is missing or explicitly
    ``None`` keeps the current (default) value, so a partially filled workflow
    does not crash the task. Non-dict nodes and non-dict ``data`` are ignored.

    Args:
        agent: The ``Agent`` ORM row to convert.

    Returns:
        The runtime configuration. Memory is enabled with DB persistence and
        learning turned on; the budget uses ``AgentBudgetConfig`` defaults.

    Raises:
        ValueError: if ``temperature`` / ``max_iterations`` are present but
            cannot be converted to ``float`` / ``int``.
    """
    wf = agent.workflow_config or {}
    # ``or []`` also covers an explicit ``"nodes": null``.
    nodes = wf.get("nodes") or []

    # Defaults, used when the workflow does not override them.
    system_prompt = "你是一个有用的AI助手。"
    tools_include = []
    tools_exclude = []
    model_name = "gpt-4o-mini"
    provider = "openai"
    temperature = 0.7
    max_iterations = 10

    for node in nodes:
        if not isinstance(node, dict):
            continue
        nd = node.get("data")
        if not isinstance(nd, dict):
            # Covers a missing key as well as ``"data": null``.
            nd = {}
        node_type = node.get("type", "")

        if node_type == "start":
            system_prompt = _node_value(nd, "system_prompt", system_prompt)
        elif node_type == "agent":
            tools_include = _node_value(nd, "tools", tools_include)
            tools_exclude = _node_value(nd, "exclude_tools", tools_exclude)
            model_name = _node_value(nd, "model", model_name)
            provider = _node_value(nd, "provider", provider)
            temperature = float(_node_value(nd, "temperature", temperature))
            max_iterations = int(_node_value(nd, "max_iterations", max_iterations))

    return AgentConfig(
        name=agent.name,
        system_prompt=system_prompt,
        user_id=str(agent.user_id) if agent.user_id else None,
        llm=AgentLLMConfig(
            provider=provider,
            model=model_name,
            temperature=temperature,
            max_iterations=max_iterations,
        ),
        tools=AgentToolConfig(
            # Normalise falsy values (e.g. "") to an empty list.
            include_tools=tools_include if tools_include else [],
            exclude_tools=tools_exclude if tools_exclude else [],
        ),
        memory=AgentMemoryConfig(
            enabled=True,
            persist_to_db=True,
            learning_enabled=True,
        ),
        budget=AgentBudgetConfig(),
    )
|
||
|
||
|
||
def _elapsed_ms(started_at: float) -> int:
    """Return whole milliseconds elapsed since *started_at* (a ``time.monotonic()`` value)."""
    return int((time.monotonic() - started_at) * 1000)


def _mark_execution_failed(db, execution: Execution, message: str, started_at: float) -> None:
    """Set *execution* to ``failed`` with *message* and the elapsed time, then commit."""
    execution.status = "failed"
    execution.error_message = message
    execution.execution_time = _elapsed_ms(started_at)
    db.commit()


@celery_app.task(bind=True)
def execute_agent_task(self, agent_id: str, input_data: dict):
    """Run an Agent asynchronously as a Celery task.

    Triggered by the scheduler (``check_agent_schedules_task``) or manually via
    the API. Creates an ``Execution`` record, runs the Agent, and stores the
    outcome on that record.

    Args:
        agent_id: ID of the Agent to run.
        input_data: Input payload; must contain a non-empty ``"message"``
            (or, as a fallback, ``"query"``) field.

    Returns:
        A result dict whose ``"status"`` is one of:

        * ``"error"``: the agent does not exist or the input has no message.
          No execution record is created.
        * ``"completed"``: includes ``execution_id``, ``content``,
          ``iterations_used``, ``tool_calls_made`` and ``execution_time`` (ms).
        * ``"failed"``: the runtime reported failure; includes
          ``execution_id`` and ``error`` (the same message stored on the
          execution record).

    Raises:
        Exception: Any exception raised while running the agent is re-raised
            after the execution record has been marked ``failed``.
    """
    db = SessionLocal()
    # Monotonic clock: durations must not jump with wall-clock adjustments.
    start_time = time.monotonic()

    try:
        agent = db.query(Agent).filter(Agent.id == agent_id).first()
        if not agent:
            return {"status": "error", "detail": f"Agent {agent_id} 不存在"}

        input_data = input_data or {}
        user_message = input_data.get("message") or input_data.get("query") or ""
        if not user_message:
            return {"status": "error", "detail": "缺少 message 输入"}

        # Create the execution record. Commit (not just flush) so the
        # "running" row is visible to other sessions; its id is published in
        # the PROGRESS state below and may be polled while the agent runs.
        execution = Execution(
            agent_id=agent_id,
            input_data=input_data,
            status="running",
        )
        db.add(execution)
        db.commit()
        execution_id = str(execution.id)

        # Publish progress to the Celery result backend.
        self.update_state(
            state="PROGRESS",
            meta={
                "execution_id": execution_id,
                "agent_id": agent_id,
                "progress": 0,
                "status": "running",
            },
        )

        try:
            # Config building lives inside the try so a bad workflow_config
            # also ends up as a "failed" execution instead of a stuck "running" one.
            config = _build_agent_config_from_db(agent)
            runtime = AgentRuntime(config)
            result = asyncio.run(runtime.run(user_message))

            if not result.success:
                error_message = result.error or "Agent 执行返回失败"
                _mark_execution_failed(db, execution, error_message, start_time)
                return {
                    "status": "failed",
                    "execution_id": execution_id,
                    "error": error_message,
                }

            execution_time = _elapsed_ms(start_time)
            execution.status = "completed"
            execution.output_data = {
                "content": result.content,
                "iterations_used": result.iterations_used,
                "tool_calls_made": result.tool_calls_made,
            }
            execution.execution_time = execution_time
            db.commit()

            logger.info(
                "Agent 异步执行完成: agent=%s execution=%s time=%dms",
                agent_id, execution_id, execution_time,
            )
            return {
                "status": "completed",
                "execution_id": execution_id,
                "content": result.content,
                "iterations_used": result.iterations_used,
                "tool_calls_made": result.tool_calls_made,
                "execution_time": execution_time,
            }

        except Exception as run_e:
            logger.error("Agent 异步执行异常: agent=%s error=%s", agent_id, run_e)
            # If the failure came from the DB (e.g. a failed commit), the
            # session must be rolled back before it can be used again.
            db.rollback()
            try:
                _mark_execution_failed(
                    db, execution, f"Agent 执行异常: {run_e!s}", start_time
                )
            except Exception:
                # Never let bookkeeping failures mask the original error.
                logger.exception("记录 Agent 执行失败状态出错: execution=%s", execution_id)
                db.rollback()
            raise

    except Exception as e:
        logger.exception("execute_agent_task 失败: agent=%s error=%s", agent_id, e)
        raise

    finally:
        db.close()
|