feat: Phase 4 - LLM/Agent fallback chain, cross-agent knowledge sharing, async agent execution

- 4.1 Fallback chain: LLM fallback_llm config in AgentLLMConfig, retry with alternate model on API failure; Agent fallback_agent in DAG nodes
- 4.2 Knowledge sharing: GlobalKnowledge model with embedding-based semantic search, auto-extraction of tool names as tags after execution
- 4.3 Async execution: execute_agent_task fully implemented with AgentRuntime, scheduler dual-path for workflow/non-workflow agents

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
renjianbo
2026-05-05 00:27:54 +08:00
parent 7e00b027d4
commit 592bca4f39
7 changed files with 461 additions and 70 deletions

View File

@@ -268,6 +268,8 @@ class AgentRuntime:
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
# 提取知识到全局知识池Agent 间知识共享)
await self._extract_global_knowledge(user_input, final_text, steps)
return AgentResult(
success=True,
content=final_text,
@@ -774,6 +776,35 @@ class AgentRuntime:
if db:
db.close()
async def _extract_global_knowledge(
self, user_input: str, final_answer: str, steps: List[AgentStep],
) -> None:
"""从 Agent 执行结果中提取知识写入全局知识池Agent 间共享)。"""
# 提取工具调用名称作为 tags
tool_names = list(dict.fromkeys(
s.tool_name for s in (steps or [])
if s.tool_name and s.type == "tool_result"
))
tags = tool_names[:5] if tool_names else ["对话"]
# 提取关键信息:用户问题摘要 + 回答要点(前 500 字)
content = (
f"问题: {user_input[:300]}\n"
f"回答要点: {final_answer[:500]}"
)
if tool_names:
content += f"\n使用工具: {', '.join(tool_names[:5])}"
source_agent_id = self.config.name if self.config.name != "default_agent" else ""
source_user_id = self.config.user_id or ""
await self.memory.save_global_knowledge(
content=content,
source_agent_id=source_agent_id,
source_user_id=source_user_id,
tags=tags,
)
async def _self_review(self, content: str, task_context: str = "") -> dict:
"""输出质量自检:用轻量 LLM 评判输出,返回 {score, passed, issues, suggestions}。"""
criteria = (
@@ -957,15 +988,7 @@ class _LLMClient:
iteration: int = 1,
on_completion: Optional[Callable[[Dict[str, Any]], Any]] = None,
) -> Any:
"""
调用 LLM。
优先使用 llm_service.call_openai_with_tools支持 ReAct 的多次工具调用)。
但为避免外层 ReAct 与内部 ReAct 冲突:
- 第 1 轮:使用标准 chat无内部 ReAct由外层 AgentRuntime 控制循环
- 后续轮次:也使用标准 chat仅追加工具结果
"""
# 直接用 OpenAI/DeepSeek SDK 调用,由 AgentRuntime 控制循环
"""调用 LLM主模型失败时自动切换 fallback_llm 重试。"""
from openai import AsyncOpenAI
from app.core.config import settings
@@ -974,17 +997,36 @@ class _LLMClient:
base_url = self._config.base_url or settings.OPENAI_BASE_URL or ""
if not api_key or api_key == "your-openai-api-key":
# 尝试 DeepSeek
api_key = self._config.api_key or settings.DEEPSEEK_API_KEY or ""
base_url = self._config.base_url or settings.DEEPSEEK_BASE_URL or "https://api.deepseek.com"
if not api_key:
raise ValueError("未配置 API Key")
return await self._do_chat(
api_key=api_key, base_url=base_url, model=self._config.model,
messages=messages, tools=tools, iteration=iteration,
on_completion=on_completion,
)
async def _do_chat(
self,
api_key: str,
base_url: str,
model: str,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
iteration: int = 1,
on_completion: Optional[Callable[[Dict[str, Any]], Any]] = None,
_is_fallback: bool = False,
) -> Any:
from openai import AsyncOpenAI
from app.core.config import settings
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
kwargs: Dict[str, Any] = {
"model": self._config.model,
"model": model,
"messages": messages,
"temperature": self._config.temperature,
"timeout": self._config.request_timeout,
@@ -1015,60 +1057,77 @@ class _LLMClient:
kwargs["tool_choice"] = "auto"
# LLM 响应缓存(仅不用工具时缓存,避免复杂序列化)
if self._config.cache_enabled and not tools:
if self._config.cache_enabled and not tools and not _is_fallback:
cache_key = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
cached = await _llm_cache_get(cache_key)
if cached is not None:
logger.info("LLM 响应命中缓存: model=%s", kwargs.get("model"))
# 构造简易 message 对象(含 content 字段即可)
class _CachedMsg:
content = cached
tool_calls = None
return _CachedMsg()
start_time = time.perf_counter()
last_error = None
try:
response = await client.chat.completions.create(**kwargs)
latency_ms = int((time.perf_counter() - start_time) * 1000)
message = response.choices[0].message
# 缓存写入(仅不用工具时)
if self._config.cache_enabled and not tools and message.content:
ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms)
# 提取 token 用量
usage = getattr(response, "usage", None)
prompt_tokens = usage.prompt_tokens if usage else 0
completion_tokens = usage.completion_tokens if usage else 0
total_tokens = usage.total_tokens if usage else 0
# 调用完成回调
if on_completion:
on_completion({
"model": self._config.model,
"provider": self._config.provider,
"prompt_tokens": prompt_tokens or 0,
"completion_tokens": completion_tokens or 0,
"total_tokens": total_tokens or 0,
"latency_ms": latency_ms,
"iteration_number": iteration,
"status": "success",
})
return message
except Exception as e:
latency_ms = int((time.perf_counter() - start_time) * 1000)
if on_completion:
on_completion({
"model": self._config.model,
"provider": self._config.provider,
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"latency_ms": latency_ms,
"iteration_number": iteration,
"status": "error",
"error_message": str(e),
})
last_error = e
# 降级回退:主模型失败时尝试 fallback_llm
fallback = self._config.fallback_llm
if fallback and isinstance(fallback, dict) and not _is_fallback:
fb_model = fallback.get("model")
fb_api_key = fallback.get("api_key")
fb_base_url = fallback.get("base_url")
if fb_model and (fb_api_key or fb_base_url):
logger.warning(
"主模型 %s 调用失败,降级到 %s: %s",
model, fb_model, str(e)[:200],
)
# 先报告主模型失败
latency_ms = int((time.perf_counter() - start_time) * 1000)
if on_completion:
on_completion({
"model": model, "provider": self._config.provider,
"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0,
"latency_ms": latency_ms, "iteration_number": iteration,
"status": "fallback", "error_message": str(e),
})
return await self._do_chat(
api_key=fb_api_key or api_key,
base_url=fb_base_url or base_url,
model=fb_model,
messages=messages, tools=tools,
iteration=iteration, on_completion=on_completion,
_is_fallback=True,
)
raise
latency_ms = int((time.perf_counter() - start_time) * 1000)
message = response.choices[0].message
# 缓存写入(仅不用工具时)
if self._config.cache_enabled and not tools and message.content:
ck = _llm_cache_key(kwargs.get("messages", []), kwargs.get("model", ""))
await _llm_cache_set(ck, message.content, self._config.cache_ttl_ms)
# 提取 token 用量
usage = getattr(response, "usage", None)
prompt_tokens = usage.prompt_tokens if usage else 0
completion_tokens = usage.completion_tokens if usage else 0
total_tokens = usage.total_tokens if usage else 0
# 调用完成回调
if on_completion:
on_completion({
"model": model,
"provider": self._config.provider,
"prompt_tokens": prompt_tokens or 0,
"completion_tokens": completion_tokens or 0,
"total_tokens": total_tokens or 0,
"latency_ms": latency_ms,
"iteration_number": iteration,
"status": "success",
})
return message

View File

@@ -95,6 +95,11 @@ class AgentMemory:
if vector_text:
parts.append(vector_text)
# 3. 全局知识检索:从 GlobalKnowledge 表加载相关条目
global_text = await self._global_knowledge_search(query)
if global_text:
parts.append(global_text)
return "\n\n".join(parts) if parts else ""
async def _vector_search(self, query: str = "") -> str:
@@ -171,6 +176,119 @@ class AgentMemory:
if db:
db.close()
async def _global_knowledge_search(self, query: str = "") -> str:
    """Retrieve relevant entries from the GlobalKnowledge table.

    With a query, ranks the 50 most recent entries by embedding similarity
    and returns the top matches. Without a query — or when the semantic
    path cannot be used (no stored embeddings, query embedding unavailable,
    no matches) — falls back to listing the 5 most recent entries instead
    of returning nothing. Returns a formatted text block, or "" when the
    table is empty or on failure (retrieval is best-effort).
    """
    from app.models.agent import GlobalKnowledge
    db: Optional[Session] = None
    try:
        db = SessionLocal()
        rows = (
            db.query(GlobalKnowledge)
            .order_by(GlobalKnowledge.created_at.desc())
            .limit(50)
            .all()
        )
        if not rows:
            return ""
        # Semantic path: rank stored entries against the query embedding.
        if query and query.strip():
            entries: List[VectorEntry] = []
            for row in rows:
                if not row.embedding:
                    continue
                try:
                    emb = embedding_service.deserialize_embedding(row.embedding)
                except Exception:
                    emb = []  # corrupt embedding — skip this row
                if emb:
                    entries.append({
                        "id": row.id,
                        "scope_kind": "global",
                        "scope_id": "global",
                        "content_text": row.content,
                        "embedding": emb,
                        "metadata": {
                            "source_agent_id": row.source_agent_id,
                            "tags": row.tags or [],
                        },
                    })
            if entries:
                query_emb = await embedding_service.generate_embedding(query)
                if query_emb:
                    matched = await embedding_service.similarity_search(
                        query_emb, entries, top_k=min(5, len(entries)),
                    )
                    if matched:
                        lines = ["## 全局知识库"]
                        for i, m in enumerate(matched, 1):
                            tags = m.get("metadata", {}).get("tags", [])
                            tag_str = f" [{', '.join(tags[:3])}]" if tags else ""
                            lines.append(f"{i}.{tag_str} {m.get('content_text', '')[:500]}")
                        return "\n".join(lines)
            # Semantic search unusable — degrade to the recency listing
            # below rather than returning an empty context.
        # Recency fallback: newest 5 entries (rows is non-empty here).
        recent = rows[:5]
        lines = ["## 全局知识库(最近)"]
        for i, row in enumerate(recent, 1):
            tag_str = f" [{(', '.join(row.tags[:3]))}]" if row.tags else ""
            lines.append(f"{i}.{tag_str} {row.content[:500]}")
        return "\n".join(lines)
    except Exception as e:
        logger.warning("全局知识检索失败: %s", e)
        return ""
    finally:
        if db:
            db.close()
async def save_global_knowledge(
    self, content: str, source_agent_id: str = "",
    source_user_id: str = "", tags: Optional[List[str]] = None,
) -> None:
    """Persist one knowledge entry into the shared global pool.

    Entries shorter than 20 characters are skipped. Embedding generation
    is best-effort; persistence failures are logged and rolled back, never
    raised to the caller.
    """
    from app.models.agent import GlobalKnowledge
    # Too short to be worth sharing across agents.
    if not content or len(content) < 20:
        return
    session: Optional[Session] = None
    try:
        session = SessionLocal()
        # Best-effort embedding; empty string means "no embedding stored".
        serialized = ""
        try:
            vector = await embedding_service.generate_embedding(content)
            if vector:
                serialized = embedding_service.serialize_embedding(vector) or ""
        except Exception:
            serialized = ""
        session.add(GlobalKnowledge(
            content=content[:2000],
            embedding=serialized or None,
            source_agent_id=source_agent_id or "",
            source_user_id=source_user_id or "",
            tags=tags or [],
            scope_kind=self.scope_kind,
            scope_id=self.scope_id or "global",
        ))
        session.commit()
        logger.info("已写入全局知识: agent=%s tags=%s", source_agent_id, tags)
    except Exception as e:
        logger.warning("保存全局知识失败: %s", e)
        if session:
            session.rollback()
    finally:
        if session:
            session.close()
async def save_context(
self, user_message: str, assistant_reply: str,
messages: Optional[List[Dict[str, Any]]] = None,

View File

@@ -46,6 +46,7 @@ class AgentLLMConfig(BaseModel):
self_review_threshold: float = 0.6 # self-review 通过阈值0-1
cache_enabled: bool = False # LLM 响应缓存(默认关闭,语义缓存有风险)
cache_ttl_ms: int = 300000 # LLM 缓存 TTL默认 5 分钟
fallback_llm: Optional[Dict[str, Any]] = None # 降级模型配置 {provider, model, api_key, base_url}
class AgentBudgetConfig(BaseModel):

View File

@@ -50,3 +50,21 @@ class AgentExtension(Base):
def __repr__(self):
return f"<AgentExtension(id={self.id}, type={self.extension_type}, name={self.name})>"
class GlobalKnowledge(Base):
    """Cross-agent knowledge-sharing table — the global knowledge pool.

    Each row is a summarized knowledge entry with an optional JSON-serialized
    embedding used for similarity-based retrieval.
    """
    __tablename__ = "global_knowledge"
    # Random UUID string primary key.
    id = Column(CHAR(36), primary_key=True, default=lambda: str(uuid.uuid4()), comment="知识ID")
    # Summarized knowledge content (writers truncate before insert).
    content = Column(Text, nullable=False, comment="知识内容摘要")
    # JSON-serialized embedding of `content`; NULL when unavailable.
    embedding = Column(Text, nullable=True, comment="内容 embeddingJSON 序列化)")
    # Agent that produced this entry (empty/NULL for anonymous sources).
    source_agent_id = Column(CHAR(36), nullable=True, comment="来源 Agent ID")
    # User on whose behalf the entry was produced.
    source_user_id = Column(CHAR(36), nullable=True, comment="来源用户 ID")
    # Free-form classification tags (JSON list).
    tags = Column(JSON, nullable=True, comment="分类标签")
    # Scope discriminator pair: kind (e.g. "agent") + id.
    scope_kind = Column(String(50), default="agent", comment="作用域类型")
    scope_id = Column(String(100), default="", comment="作用域 ID")
    # Row creation timestamp (database clock).
    created_at = Column(DateTime, default=func.now(), comment="创建时间")

    def __repr__(self):
        """Debug representation showing id and source agent."""
        return f"<GlobalKnowledge(id={self.id}, source_agent={self.source_agent_id})>"

View File

@@ -47,9 +47,6 @@ def create_execution_for_schedule(db: Session, schedule) -> Optional[str]:
if not agent:
logger.warning("定时任务 %s 关联的 Agent %s 不存在", schedule.id, schedule.agent_id)
return None
if not agent.workflow_config:
logger.warning("Agent %s 缺少 workflow_config无法执行定时任务", schedule.agent_id)
return None
# 创建执行记录(关联 schedule_id
execution = Execution(
@@ -61,16 +58,23 @@ def create_execution_for_schedule(db: Session, schedule) -> Optional[str]:
db.add(execution)
db.flush() # 获取 id
# 投递到 Celery
from app.tasks.workflow_tasks import execute_workflow_task
try:
task = execute_workflow_task.delay(
str(execution.id),
f"agent_{schedule.agent_id}",
agent.workflow_config,
{"message": schedule.input_message},
)
if agent.workflow_config and agent.workflow_config.get("nodes"):
# 有工作流配置:走完整工作流引擎
from app.tasks.workflow_tasks import execute_workflow_task
task = execute_workflow_task.delay(
str(execution.id),
f"agent_{schedule.agent_id}",
agent.workflow_config,
{"message": schedule.input_message},
)
else:
# 无工作流配置:走简单 Agent 异步执行
from app.tasks.agent_tasks import execute_agent_task
task = execute_agent_task.delay(
str(schedule.agent_id),
{"message": schedule.input_message},
)
execution.task_id = task.id
execution.status = "running"
db.commit()

View File

@@ -1956,6 +1956,38 @@ class WorkflowEngine:
)
return result
except Exception as e:
# fallback_agent 降级:主 Agent 失败时尝试备用 Agent
node_data = node.get("data", {}) or {}
fallback_agent_id = node_data.get("fallback_agent", "")
if fallback_agent_id and str(fallback_agent_id) != str(node_data.get("agent_id", "")):
if self.logger:
self.logger.warn(
"Agent 节点 %s 失败,降级到 fallback_agent: %s",
node_id, fallback_agent_id,
)
try:
fb_node_data = {**node_data, "agent_id": fallback_agent_id}
fb_node_data.pop("fallback_agent", None)
result = await run_agent_node(
node_data=fb_node_data,
input_data=input_data,
execution_logger=self.logger,
user_id=self.trusted_model_config_user_id,
on_tool_executed=_agent_on_tool,
on_llm_invocation=_on_agent_llm,
budget_limits={
"max_llm_invocations": self._cap_llm,
"max_tool_calls": self._cap_tool,
},
)
if self.logger:
self.logger.info("fallback_agent %s 执行成功", fallback_agent_id)
return result
except Exception as fb_e:
if self.logger:
self.logger.error("fallback_agent %s 也失败: %s", fallback_agent_id, fb_e)
logger.error(f"fallback_agent 执行失败: {fb_e}", exc_info=True)
if self.logger:
duration = int((time.time() - start_time) * 1000)
self.logger.log_node_error(node_id, node_type, e, duration)

View File

@@ -1,5 +1,5 @@
"""
Agent任务
Agent 异步执行任务 — Celery 任务,支持定时调度和手动触发。
"""
from celery import Task
from app.core.tools_bootstrap import ensure_builtin_tools_registered
@@ -7,10 +7,169 @@ from app.core.tools_bootstrap import ensure_builtin_tools_registered
ensure_builtin_tools_registered()
from app.core.celery_app import celery_app
from app.core.database import SessionLocal
from app.agent_runtime.core import AgentRuntime
from app.agent_runtime.schemas import (
AgentConfig,
AgentLLMConfig,
AgentToolConfig,
AgentMemoryConfig,
AgentBudgetConfig,
)
from app.models.agent import Agent
from app.models.execution import Execution
import asyncio
import logging
import time
logger = logging.getLogger(__name__)
def _build_agent_config_from_db(agent: Agent) -> AgentConfig:
    """Translate a DB Agent row into a runnable AgentConfig.

    Scans the agent's workflow nodes: the "start" node supplies the system
    prompt, the "agent" node supplies model/tool settings. Anything missing
    keeps a sensible default.
    """
    workflow = agent.workflow_config or {}
    graph_nodes = workflow.get("nodes", [])

    # Defaults, overridden by node data found below.
    prompt = "你是一个有用的AI助手。"
    include: list = []
    exclude: list = []
    model = "gpt-4o-mini"
    llm_provider = "openai"
    temp = 0.7
    iter_cap = 10

    for raw in graph_nodes:
        if not isinstance(raw, dict):
            continue
        data = raw.get("data", {})
        kind = raw.get("type", "")
        if kind == "start":
            prompt = data.get("system_prompt", prompt)
        elif kind == "agent":
            include = data.get("tools", include)
            exclude = data.get("exclude_tools", exclude)
            model = data.get("model", model)
            llm_provider = data.get("provider", llm_provider)
            temp = float(data.get("temperature", temp))
            iter_cap = int(data.get("max_iterations", iter_cap))

    return AgentConfig(
        name=agent.name,
        system_prompt=prompt,
        user_id=str(agent.user_id) if agent.user_id else None,
        llm=AgentLLMConfig(
            provider=llm_provider,
            model=model,
            temperature=temp,
            max_iterations=iter_cap,
        ),
        tools=AgentToolConfig(
            include_tools=include if include else [],
            exclude_tools=exclude if exclude else [],
        ),
        memory=AgentMemoryConfig(
            enabled=True,
            persist_to_db=True,
            learning_enabled=True,
        ),
        budget=AgentBudgetConfig(),
    )
@celery_app.task(bind=True)
def execute_agent_task(self, agent_id: str, input_data: dict):
    """Execute an agent asynchronously as a Celery task.

    Triggered by the scheduler (check_agent_schedules_task) or a manual API
    call. Creates an Execution record, runs the AgentRuntime, and persists
    the result (or failure) back onto that record.

    NOTE: the previous stub implementation (an early `return {"status":
    "pending", ...}`) was left above this body and made it unreachable; the
    stub has been removed so the task actually runs.

    Args:
        agent_id: Agent ID.
        input_data: input payload; must contain a "message" (or "query") key.

    Returns:
        A dict describing the outcome. Unexpected execution errors are
        re-raised so Celery records the task as failed.
    """
    db = SessionLocal()
    start_time = time.time()
    try:
        agent = db.query(Agent).filter(Agent.id == agent_id).first()
        if not agent:
            return {"status": "error", "detail": f"Agent {agent_id} 不存在"}
        user_message = input_data.get("message") or input_data.get("query") or ""
        if not user_message:
            return {"status": "error", "detail": "缺少 message 输入"}
        # Create the execution record up-front so progress is observable.
        execution = Execution(
            agent_id=agent_id,
            input_data=input_data,
            status="running",
        )
        db.add(execution)
        db.flush()
        # Surface progress through the Celery task state.
        self.update_state(
            state="PROGRESS",
            meta={
                "execution_id": str(execution.id),
                "agent_id": agent_id,
                "progress": 0,
                "status": "running",
            },
        )
        # Build runtime config from the stored workflow and run the agent.
        config = _build_agent_config_from_db(agent)
        runtime = AgentRuntime(config)
        try:
            result = asyncio.run(runtime.run(user_message))
            execution_time = int((time.time() - start_time) * 1000)
            if result.success:
                execution.status = "completed"
                execution.output_data = {
                    "content": result.content,
                    "iterations_used": result.iterations_used,
                    "tool_calls_made": result.tool_calls_made,
                }
                execution.execution_time = execution_time
                db.commit()
                logger.info(
                    "Agent 异步执行完成: agent=%s execution=%s time=%dms",
                    agent_id, execution.id, execution_time,
                )
                return {
                    "status": "completed",
                    "execution_id": str(execution.id),
                    "content": result.content,
                    "iterations_used": result.iterations_used,
                    "tool_calls_made": result.tool_calls_made,
                    "execution_time": execution_time,
                }
            else:
                execution.status = "failed"
                execution.error_message = result.error or "Agent 执行返回失败"
                execution.execution_time = int((time.time() - start_time) * 1000)
                db.commit()
                return {
                    "status": "failed",
                    "execution_id": str(execution.id),
                    "error": result.error,
                }
        except Exception as run_e:
            # Runtime blew up: persist the failure before re-raising.
            execution_time = int((time.time() - start_time) * 1000)
            execution.status = "failed"
            execution.error_message = f"Agent 执行异常: {run_e!s}"
            execution.execution_time = execution_time
            db.commit()
            logger.error("Agent 异步执行异常: agent=%s error=%s", agent_id, run_e)
            raise
    except Exception as e:
        logger.error("execute_agent_task 失败: agent=%s error=%s", agent_id, e)
        # Discard any uncommitted execution record (no-op after a commit).
        db.rollback()
        raise
    finally:
        db.close()