feat: Phase 4 - LLM/Agent fallback chain, cross-agent knowledge sharing, async agent execution
- 4.1 Fallback chain: LLM fallback_llm config in AgentLLMConfig, retry with alternate model on API failure; Agent fallback_agent in DAG nodes - 4.2 Knowledge sharing: GlobalKnowledge model with embedding-based semantic search, auto-extraction of tool names as tags after execution - 4.3 Async execution: execute_agent_task fully implemented with AgentRuntime, scheduler dual-path for workflow/non-workflow agents Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -268,6 +268,8 @@ class AgentRuntime:
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
)
|
||||
# 提取知识到全局知识池(Agent 间知识共享)
|
||||
await self._extract_global_knowledge(user_input, final_text, steps)
|
||||
return AgentResult(
|
||||
success=True,
|
||||
content=final_text,
|
||||
@@ -774,6 +776,35 @@ class AgentRuntime:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def _extract_global_knowledge(
|
||||
self, user_input: str, final_answer: str, steps: List[AgentStep],
|
||||
) -> None:
|
||||
"""从 Agent 执行结果中提取知识,写入全局知识池(Agent 间共享)。"""
|
||||
# 提取工具调用名称作为 tags
|
||||
tool_names = list(dict.fromkeys(
|
||||
s.tool_name for s in (steps or [])
|
||||
if s.tool_name and s.type == "tool_result"
|
||||
))
|
||||
tags = tool_names[:5] if tool_names else ["对话"]
|
||||
|
||||
# 提取关键信息:用户问题摘要 + 回答要点(前 500 字)
|
||||
content = (
|
||||
f"问题: {user_input[:300]}\n"
|
||||
f"回答要点: {final_answer[:500]}"
|
||||
)
|
||||
if tool_names:
|
||||
content += f"\n使用工具: {', '.join(tool_names[:5])}"
|
||||
|
||||
source_agent_id = self.config.name if self.config.name != "default_agent" else ""
|
||||
source_user_id = self.config.user_id or ""
|
||||
|
||||
await self.memory.save_global_knowledge(
|
||||
content=content,
|
||||
source_agent_id=source_agent_id,
|
||||
source_user_id=source_user_id,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
async def _self_review(self, content: str, task_context: str = "") -> dict:
|
||||
"""输出质量自检:用轻量 LLM 评判输出,返回 {score, passed, issues, suggestions}。"""
|
||||
criteria = (
|
||||
@@ -957,15 +988,7 @@ class _LLMClient:
|
||||
iteration: int = 1,
|
||||
on_completion: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
调用 LLM。
|
||||
优先使用 llm_service.call_openai_with_tools(支持 ReAct 的多次工具调用)。
|
||||
|
||||
但为避免外层 ReAct 与内部 ReAct 冲突:
|
||||
- 第 1 轮:使用标准 chat(无内部 ReAct),由外层 AgentRuntime 控制循环
|
||||
- 后续轮次:也使用标准 chat,仅追加工具结果
|
||||
"""
|
||||
# 直接用 OpenAI/DeepSeek SDK 调用,由 AgentRuntime 控制循环
|
||||
"""调用 LLM,主模型失败时自动切换 fallback_llm 重试。"""
|
||||
from openai import AsyncOpenAI
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -974,17 +997,36 @@ class _LLMClient:
|
||||
base_url = self._config.base_url or settings.OPENAI_BASE_URL or ""
|
||||
|
||||
if not api_key or api_key == "your-openai-api-key":
|
||||
# 尝试 DeepSeek
|
||||
api_key = self._config.api_key or settings.DEEPSEEK_API_KEY or ""
|
||||
base_url = self._config.base_url or settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置 API Key")
|
||||
|
||||
return await self._do_chat(
|
||||
api_key=api_key, base_url=base_url, model=self._config.model,
|
||||
messages=messages, tools=tools, iteration=iteration,
|
||||
on_completion=on_completion,
|
||||
)
|
||||
|
||||
async def _do_chat(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
iteration: int = 1,
|
||||
on_completion: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
||||
_is_fallback: bool = False,
|
||||
) -> Any:
|
||||
from openai import AsyncOpenAI
|
||||
from app.core.config import settings
|
||||
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": self._config.model,
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": self._config.temperature,
|
||||
"timeout": self._config.request_timeout,
|
||||
@@ -1015,60 +1057,77 @@ class _LLMClient:
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
# LLM 响应缓存(仅不用工具时缓存,避免复杂序列化)
|
||||
if self._config.cache_enabled and not tools:
|
||||
if self._config.cache_enabled and not tools and not _is_fallback:
|
||||
cache_key = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
|
||||
cached = await _llm_cache_get(cache_key)
|
||||
if cached is not None:
|
||||
logger.info("LLM 响应命中缓存: model=%s", kwargs.get("model"))
|
||||
# 构造简易 message 对象(含 content 字段即可)
|
||||
class _CachedMsg:
|
||||
content = cached
|
||||
tool_calls = None
|
||||
return _CachedMsg()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
last_error = None
|
||||
try:
|
||||
response = await client.chat.completions.create(**kwargs)
|
||||
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
message = response.choices[0].message
|
||||
|
||||
# 缓存写入(仅不用工具时)
|
||||
if self._config.cache_enabled and not tools and message.content:
|
||||
ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
|
||||
await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms)
|
||||
|
||||
# 提取 token 用量
|
||||
usage = getattr(response, "usage", None)
|
||||
prompt_tokens = usage.prompt_tokens if usage else 0
|
||||
completion_tokens = usage.completion_tokens if usage else 0
|
||||
total_tokens = usage.total_tokens if usage else 0
|
||||
|
||||
# 调用完成回调
|
||||
if on_completion:
|
||||
on_completion({
|
||||
"model": self._config.model,
|
||||
"provider": self._config.provider,
|
||||
"prompt_tokens": prompt_tokens or 0,
|
||||
"completion_tokens": completion_tokens or 0,
|
||||
"total_tokens": total_tokens or 0,
|
||||
"latency_ms": latency_ms,
|
||||
"iteration_number": iteration,
|
||||
"status": "success",
|
||||
})
|
||||
|
||||
return message
|
||||
except Exception as e:
|
||||
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
if on_completion:
|
||||
on_completion({
|
||||
"model": self._config.model,
|
||||
"provider": self._config.provider,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"latency_ms": latency_ms,
|
||||
"iteration_number": iteration,
|
||||
"status": "error",
|
||||
"error_message": str(e),
|
||||
})
|
||||
last_error = e
|
||||
# 降级回退:主模型失败时尝试 fallback_llm
|
||||
fallback = self._config.fallback_llm
|
||||
if fallback and isinstance(fallback, dict) and not _is_fallback:
|
||||
fb_model = fallback.get("model")
|
||||
fb_api_key = fallback.get("api_key")
|
||||
fb_base_url = fallback.get("base_url")
|
||||
if fb_model and (fb_api_key or fb_base_url):
|
||||
logger.warning(
|
||||
"主模型 %s 调用失败,降级到 %s: %s",
|
||||
model, fb_model, str(e)[:200],
|
||||
)
|
||||
# 先报告主模型失败
|
||||
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
if on_completion:
|
||||
on_completion({
|
||||
"model": model, "provider": self._config.provider,
|
||||
"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
|
||||
"latency_ms": latency_ms, "iteration_number": iteration,
|
||||
"status": "fallback", "error_message": str(e),
|
||||
})
|
||||
return await self._do_chat(
|
||||
api_key=fb_api_key or api_key,
|
||||
base_url=fb_base_url or base_url,
|
||||
model=fb_model,
|
||||
messages=messages, tools=tools,
|
||||
iteration=iteration, on_completion=on_completion,
|
||||
_is_fallback=True,
|
||||
)
|
||||
raise
|
||||
|
||||
latency_ms = int((time.perf_counter() - start_time) * 1000)
|
||||
message = response.choices[0].message
|
||||
|
||||
# 缓存写入(仅不用工具时)
|
||||
if self._config.cache_enabled and not tools and message.content:
|
||||
ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
|
||||
await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms)
|
||||
|
||||
# 提取 token 用量
|
||||
usage = getattr(response, "usage", None)
|
||||
prompt_tokens = usage.prompt_tokens if usage else 0
|
||||
completion_tokens = usage.completion_tokens if usage else 0
|
||||
total_tokens = usage.total_tokens if usage else 0
|
||||
|
||||
# 调用完成回调
|
||||
if on_completion:
|
||||
on_completion({
|
||||
"model": model,
|
||||
"provider": self._config.provider,
|
||||
"prompt_tokens": prompt_tokens or 0,
|
||||
"completion_tokens": completion_tokens or 0,
|
||||
"total_tokens": total_tokens or 0,
|
||||
"latency_ms": latency_ms,
|
||||
"iteration_number": iteration,
|
||||
"status": "success",
|
||||
})
|
||||
|
||||
return message
|
||||
|
||||
@@ -95,6 +95,11 @@ class AgentMemory:
|
||||
if vector_text:
|
||||
parts.append(vector_text)
|
||||
|
||||
# 3. 全局知识检索:从 GlobalKnowledge 表加载相关条目
|
||||
global_text = await self._global_knowledge_search(query)
|
||||
if global_text:
|
||||
parts.append(global_text)
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
async def _vector_search(self, query: str = "") -> str:
|
||||
@@ -171,6 +176,119 @@ class AgentMemory:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def _global_knowledge_search(self, query: str = "") -> str:
|
||||
"""从 GlobalKnowledge 表检索相关的全局知识条目。"""
|
||||
from app.models.agent import GlobalKnowledge
|
||||
|
||||
db: Optional[Session] = None
|
||||
try:
|
||||
db = SessionLocal()
|
||||
rows = (
|
||||
db.query(GlobalKnowledge)
|
||||
.order_by(GlobalKnowledge.created_at.desc())
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
if not rows:
|
||||
return ""
|
||||
|
||||
# 如果有 query,用向量相似度筛选;否则返回最近的条目
|
||||
if query and query.strip():
|
||||
entries: List[VectorEntry] = []
|
||||
for row in rows:
|
||||
if not row.embedding:
|
||||
continue
|
||||
try:
|
||||
emb = embedding_service.deserialize_embedding(row.embedding)
|
||||
except Exception:
|
||||
emb = []
|
||||
if emb:
|
||||
entries.append({
|
||||
"id": row.id,
|
||||
"scope_kind": "global",
|
||||
"scope_id": "global",
|
||||
"content_text": row.content,
|
||||
"embedding": emb,
|
||||
"metadata": {
|
||||
"source_agent_id": row.source_agent_id,
|
||||
"tags": row.tags or [],
|
||||
},
|
||||
})
|
||||
|
||||
if entries:
|
||||
query_emb = await embedding_service.generate_embedding(query)
|
||||
if query_emb:
|
||||
matched = await embedding_service.similarity_search(
|
||||
query_emb, entries, top_k=min(5, len(entries)),
|
||||
)
|
||||
if matched:
|
||||
lines = ["## 全局知识库"]
|
||||
for i, m in enumerate(matched, 1):
|
||||
tags = m.get("metadata", {}).get("tags", [])
|
||||
tag_str = f" [{', '.join(tags[:3])}]" if tags else ""
|
||||
lines.append(f"{i}.{tag_str} {m.get('content_text', '')[:500]}")
|
||||
return "\n".join(lines)
|
||||
else:
|
||||
# 无 query,返回最近 5 条全局知识
|
||||
recent = rows[:5]
|
||||
if recent:
|
||||
lines = ["## 全局知识库(最近)"]
|
||||
for i, row in enumerate(recent, 1):
|
||||
tag_str = f" [{(', '.join(row.tags[:3]))}]" if row.tags else ""
|
||||
lines.append(f"{i}.{tag_str} {row.content[:500]}")
|
||||
return "\n".join(lines)
|
||||
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.warning("全局知识检索失败: %s", e)
|
||||
return ""
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def save_global_knowledge(
|
||||
self, content: str, source_agent_id: str = "",
|
||||
source_user_id: str = "", tags: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""将知识条目写入全局知识池。"""
|
||||
from app.models.agent import GlobalKnowledge
|
||||
|
||||
if not content or len(content) < 20:
|
||||
return
|
||||
|
||||
db: Optional[Session] = None
|
||||
try:
|
||||
db = SessionLocal()
|
||||
|
||||
# 生成 embedding
|
||||
embedding_json = ""
|
||||
try:
|
||||
emb = await embedding_service.generate_embedding(content)
|
||||
if emb:
|
||||
embedding_json = embedding_service.serialize_embedding(emb) or ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
record = GlobalKnowledge(
|
||||
content=content[:2000],
|
||||
embedding=embedding_json or None,
|
||||
source_agent_id=source_agent_id or "",
|
||||
source_user_id=source_user_id or "",
|
||||
tags=tags or [],
|
||||
scope_kind=self.scope_kind,
|
||||
scope_id=self.scope_id or "global",
|
||||
)
|
||||
db.add(record)
|
||||
db.commit()
|
||||
logger.info("已写入全局知识: agent=%s tags=%s", source_agent_id, tags)
|
||||
except Exception as e:
|
||||
logger.warning("保存全局知识失败: %s", e)
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
async def save_context(
|
||||
self, user_message: str, assistant_reply: str,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
|
||||
@@ -46,6 +46,7 @@ class AgentLLMConfig(BaseModel):
|
||||
self_review_threshold: float = 0.6 # self-review 通过阈值(0-1)
|
||||
cache_enabled: bool = False # LLM 响应缓存(默认关闭,语义缓存有风险)
|
||||
cache_ttl_ms: int = 300000 # LLM 缓存 TTL,默认 5 分钟
|
||||
fallback_llm: Optional[Dict[str, Any]] = None # 降级模型配置 {provider, model, api_key, base_url}
|
||||
|
||||
|
||||
class AgentBudgetConfig(BaseModel):
|
||||
|
||||
@@ -50,3 +50,21 @@ class AgentExtension(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AgentExtension(id={self.id}, type={self.extension_type}, name={self.name})>"
|
||||
|
||||
|
||||
class GlobalKnowledge(Base):
|
||||
"""Agent 间知识共享表 — 跨 Agent 的全局知识池"""
|
||||
__tablename__ = "global_knowledge"
|
||||
|
||||
id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="知识ID")
|
||||
content = Column(Text, nullable=False, comment="知识内容摘要")
|
||||
embedding = Column(Text, nullable=True, comment="内容 embedding(JSON 序列化)")
|
||||
source_agent_id = Column(CHAR(36), nullable=True, comment="来源 Agent ID")
|
||||
source_user_id = Column(CHAR(36), nullable=True, comment="来源用户 ID")
|
||||
tags = Column(JSON, nullable=True, comment="分类标签")
|
||||
scope_kind = Column(String(50), default="agent", comment="作用域类型")
|
||||
scope_id = Column(String(100), default="", comment="作用域 ID")
|
||||
created_at = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GlobalKnowledge(id={self.id}, source_agent={self.source_agent_id})>"
|
||||
|
||||
@@ -47,9 +47,6 @@ def create_execution_for_schedule(db: Session, schedule) -> Optional[str]:
|
||||
if not agent:
|
||||
logger.warning("定时任务 %s 关联的 Agent %s 不存在", schedule.id, schedule.agent_id)
|
||||
return None
|
||||
if not agent.workflow_config:
|
||||
logger.warning("Agent %s 缺少 workflow_config,无法执行定时任务", schedule.agent_id)
|
||||
return None
|
||||
|
||||
# 创建执行记录(关联 schedule_id)
|
||||
execution = Execution(
|
||||
@@ -61,16 +58,23 @@ def create_execution_for_schedule(db: Session, schedule) -> Optional[str]:
|
||||
db.add(execution)
|
||||
db.flush() # 获取 id
|
||||
|
||||
# 投递到 Celery
|
||||
from app.tasks.workflow_tasks import execute_workflow_task
|
||||
|
||||
try:
|
||||
task = execute_workflow_task.delay(
|
||||
str(execution.id),
|
||||
f"agent_{schedule.agent_id}",
|
||||
agent.workflow_config,
|
||||
{"message": schedule.input_message},
|
||||
)
|
||||
if agent.workflow_config and agent.workflow_config.get("nodes"):
|
||||
# 有工作流配置:走完整工作流引擎
|
||||
from app.tasks.workflow_tasks import execute_workflow_task
|
||||
task = execute_workflow_task.delay(
|
||||
str(execution.id),
|
||||
f"agent_{schedule.agent_id}",
|
||||
agent.workflow_config,
|
||||
{"message": schedule.input_message},
|
||||
)
|
||||
else:
|
||||
# 无工作流配置:走简单 Agent 异步执行
|
||||
from app.tasks.agent_tasks import execute_agent_task
|
||||
task = execute_agent_task.delay(
|
||||
str(schedule.agent_id),
|
||||
{"message": schedule.input_message},
|
||||
)
|
||||
execution.task_id = task.id
|
||||
execution.status = "running"
|
||||
db.commit()
|
||||
|
||||
@@ -1956,6 +1956,38 @@ class WorkflowEngine:
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
# fallback_agent 降级:主 Agent 失败时尝试备用 Agent
|
||||
node_data = node.get("data", {}) or {}
|
||||
fallback_agent_id = node_data.get("fallback_agent", "")
|
||||
if fallback_agent_id and str(fallback_agent_id) != str(node_data.get("agent_id", "")):
|
||||
if self.logger:
|
||||
self.logger.warn(
|
||||
"Agent 节点 %s 失败,降级到 fallback_agent: %s",
|
||||
node_id, fallback_agent_id,
|
||||
)
|
||||
try:
|
||||
fb_node_data = {**node_data, "agent_id": fallback_agent_id}
|
||||
fb_node_data.pop("fallback_agent", None)
|
||||
result = await run_agent_node(
|
||||
node_data=fb_node_data,
|
||||
input_data=input_data,
|
||||
execution_logger=self.logger,
|
||||
user_id=self.trusted_model_config_user_id,
|
||||
on_tool_executed=_agent_on_tool,
|
||||
on_llm_invocation=_on_agent_llm,
|
||||
budget_limits={
|
||||
"max_llm_invocations": self._cap_llm,
|
||||
"max_tool_calls": self._cap_tool,
|
||||
},
|
||||
)
|
||||
if self.logger:
|
||||
self.logger.info("fallback_agent %s 执行成功", fallback_agent_id)
|
||||
return result
|
||||
except Exception as fb_e:
|
||||
if self.logger:
|
||||
self.logger.error("fallback_agent %s 也失败: %s", fallback_agent_id, fb_e)
|
||||
logger.error(f"fallback_agent 执行失败: {fb_e}", exc_info=True)
|
||||
|
||||
if self.logger:
|
||||
duration = int((time.time() - start_time) * 1000)
|
||||
self.logger.log_node_error(node_id, node_type, e, duration)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Agent任务
|
||||
Agent 异步执行任务 — Celery 任务,支持定时调度和手动触发。
|
||||
"""
|
||||
from celery import Task
|
||||
from app.core.tools_bootstrap import ensure_builtin_tools_registered
|
||||
@@ -7,10 +7,169 @@ from app.core.tools_bootstrap import ensure_builtin_tools_registered
|
||||
ensure_builtin_tools_registered()
|
||||
|
||||
from app.core.celery_app import celery_app
|
||||
from app.core.database import SessionLocal
|
||||
from app.agent_runtime.core import AgentRuntime
|
||||
from app.agent_runtime.schemas import (
|
||||
AgentConfig,
|
||||
AgentLLMConfig,
|
||||
AgentToolConfig,
|
||||
AgentMemoryConfig,
|
||||
AgentBudgetConfig,
|
||||
)
|
||||
from app.models.agent import Agent
|
||||
from app.models.execution import Execution
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_agent_config_from_db(agent: Agent) -> AgentConfig:
|
||||
"""从 DB Agent 记录构建 AgentConfig。"""
|
||||
wf = agent.workflow_config or {}
|
||||
nodes = wf.get("nodes", [])
|
||||
edges = wf.get("edges", [])
|
||||
|
||||
# 从工作流节点中查找 start 节点获取 system_prompt
|
||||
system_prompt = "你是一个有用的AI助手。"
|
||||
tools_include = []
|
||||
tools_exclude = []
|
||||
model_name = "gpt-4o-mini"
|
||||
provider = "openai"
|
||||
temperature = 0.7
|
||||
max_iterations = 10
|
||||
|
||||
for node in nodes:
|
||||
nd = node.get("data", {}) if isinstance(node, dict) else {}
|
||||
node_type = node.get("type", "") if isinstance(node, dict) else ""
|
||||
if node_type == "start":
|
||||
system_prompt = nd.get("system_prompt", system_prompt)
|
||||
elif node_type == "agent":
|
||||
tools_include = nd.get("tools", tools_include)
|
||||
tools_exclude = nd.get("exclude_tools", tools_exclude)
|
||||
model_name = nd.get("model", model_name)
|
||||
provider = nd.get("provider", provider)
|
||||
temperature = float(nd.get("temperature", temperature))
|
||||
max_iterations = int(nd.get("max_iterations", max_iterations))
|
||||
|
||||
return AgentConfig(
|
||||
name=agent.name,
|
||||
system_prompt=system_prompt,
|
||||
user_id=str(agent.user_id) if agent.user_id else None,
|
||||
llm=AgentLLMConfig(
|
||||
provider=provider,
|
||||
model=model_name,
|
||||
temperature=temperature,
|
||||
max_iterations=max_iterations,
|
||||
),
|
||||
tools=AgentToolConfig(
|
||||
include_tools=tools_include if tools_include else [],
|
||||
exclude_tools=tools_exclude if tools_exclude else [],
|
||||
),
|
||||
memory=AgentMemoryConfig(
|
||||
enabled=True,
|
||||
persist_to_db=True,
|
||||
learning_enabled=True,
|
||||
),
|
||||
budget=AgentBudgetConfig(),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(bind=True)
|
||||
def execute_agent_task(self, agent_id: str, input_data: dict):
|
||||
"""执行Agent任务"""
|
||||
# TODO: 实现Agent执行逻辑
|
||||
return {"status": "pending", "agent_id": agent_id}
|
||||
"""异步执行 Agent 任务。
|
||||
|
||||
由定时调度 (check_agent_schedules_task) 或手动 API 触发。
|
||||
创建 Execution 记录,运行 Agent,更新结果。
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
input_data: 输入数据,至少包含 "message" 字段
|
||||
"""
|
||||
db = SessionLocal()
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
agent = db.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
return {"status": "error", "detail": f"Agent {agent_id} 不存在"}
|
||||
|
||||
user_message = input_data.get("message") or input_data.get("query") or ""
|
||||
if not user_message:
|
||||
return {"status": "error", "detail": "缺少 message 输入"}
|
||||
|
||||
# 创建执行记录
|
||||
execution = Execution(
|
||||
agent_id=agent_id,
|
||||
input_data=input_data,
|
||||
status="running",
|
||||
)
|
||||
db.add(execution)
|
||||
db.flush()
|
||||
|
||||
# 更新 Celery 任务状态
|
||||
self.update_state(
|
||||
state="PROGRESS",
|
||||
meta={
|
||||
"execution_id": str(execution.id),
|
||||
"agent_id": agent_id,
|
||||
"progress": 0,
|
||||
"status": "running",
|
||||
},
|
||||
)
|
||||
|
||||
# 构建配置并执行
|
||||
config = _build_agent_config_from_db(agent)
|
||||
runtime = AgentRuntime(config)
|
||||
|
||||
try:
|
||||
result = asyncio.run(runtime.run(user_message))
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
if result.success:
|
||||
execution.status = "completed"
|
||||
execution.output_data = {
|
||||
"content": result.content,
|
||||
"iterations_used": result.iterations_used,
|
||||
"tool_calls_made": result.tool_calls_made,
|
||||
}
|
||||
execution.execution_time = execution_time
|
||||
db.commit()
|
||||
|
||||
logger.info(
|
||||
"Agent 异步执行完成: agent=%s execution=%s time=%dms",
|
||||
agent_id, execution.id, execution_time,
|
||||
)
|
||||
return {
|
||||
"status": "completed",
|
||||
"execution_id": str(execution.id),
|
||||
"content": result.content,
|
||||
"iterations_used": result.iterations_used,
|
||||
"tool_calls_made": result.tool_calls_made,
|
||||
"execution_time": execution_time,
|
||||
}
|
||||
else:
|
||||
execution.status = "failed"
|
||||
execution.error_message = result.error or "Agent 执行返回失败"
|
||||
execution.execution_time = int((time.time() - start_time) * 1000)
|
||||
db.commit()
|
||||
return {
|
||||
"status": "failed",
|
||||
"execution_id": str(execution.id),
|
||||
"error": result.error,
|
||||
}
|
||||
except Exception as run_e:
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
execution.status = "failed"
|
||||
execution.error_message = f"Agent 执行异常: {run_e!s}"
|
||||
execution.execution_time = execution_time
|
||||
db.commit()
|
||||
logger.error("Agent 异步执行异常: agent=%s error=%s", agent_id, run_e)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error("execute_agent_task 失败: agent=%s error=%s", agent_id, e)
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
Reference in New Issue
Block a user