feat: 实现 Agent 自主学习 — 从历史执行中优化工具选择
- 新增 AgentLearningPattern 模型和 agent_learning_service 服务
- 执行前注入历史学习模式到 system prompt 作为工具选择建议
- 执行后自动提取工具序列并保存/累计学习模式
- 支持任务分类(11类)、关键词提取、工具序列合并、有效性评分
- 集成到 AgentRuntime.run()/run_stream(),支持 bare chat 和 Agent 模式

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -24,6 +24,12 @@ from app.agent_runtime.context import AgentContext
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
from app.core.exceptions import WorkflowExecutionError
|
||||
from app.services.agent_learning_service import (
|
||||
extract_pattern_from_result,
|
||||
format_pattern_hint,
|
||||
load_relevant_patterns,
|
||||
save_learning_pattern,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -98,6 +104,8 @@ class AgentRuntime:
|
||||
self.on_llm_call = on_llm_call
|
||||
self._memory_context_loaded = False
|
||||
self._llm_invocations = 0
|
||||
# 自主学习作用域:bare 聊天用 "bare",Agent 用 "agent"
|
||||
self._learning_scope_kind = "bare" if "bare" in str(_mem_scope) else "agent"
|
||||
|
||||
# 预算回调:供 WorkflowEngine 注入,使 Agent 内部计数计入工作流预算
|
||||
# 返回 True 表示预算充足;返回 False 或抛出异常表示超限
|
||||
@@ -221,6 +229,13 @@ class AgentRuntime:
|
||||
))
|
||||
# 保存记忆
|
||||
await self.memory.save_context(user_input, final_text, self.context.messages)
|
||||
# 保存学习模式
|
||||
if self.config.memory.learning_enabled:
|
||||
await self._save_learning_pattern(
|
||||
user_input, steps, success=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
)
|
||||
return AgentResult(
|
||||
success=True,
|
||||
content=final_text,
|
||||
@@ -318,6 +333,13 @@ class AgentRuntime:
|
||||
|
||||
logger.warning("Agent 达到最大迭代次数 (%s)", max_iter)
|
||||
await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages)
|
||||
# 保存学习模式(即便截断,工具调用模式仍有参考价值)
|
||||
if self.config.memory.learning_enabled:
|
||||
await self._save_learning_pattern(
|
||||
user_input, steps, success=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
)
|
||||
if last_content:
|
||||
steps.append(AgentStep(
|
||||
iteration=self.context.iteration,
|
||||
@@ -446,6 +468,13 @@ class AgentRuntime:
|
||||
"session_id": self.context.session_id,
|
||||
}
|
||||
await self.memory.save_context(user_input, final_text, self.context.messages)
|
||||
# 保存学习模式
|
||||
if self.config.memory.learning_enabled:
|
||||
await self._save_learning_pattern(
|
||||
user_input, steps, success=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
)
|
||||
return
|
||||
|
||||
# 有工具调用 → 先记录 assistant 消息
|
||||
@@ -557,6 +586,13 @@ class AgentRuntime:
|
||||
|
||||
logger.warning("Agent 达到最大迭代次数 (%s)", max_iter)
|
||||
await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages)
|
||||
# 保存学习模式(即便截断,工具调用模式仍有参考价值)
|
||||
if self.config.memory.learning_enabled:
|
||||
await self._save_learning_pattern(
|
||||
user_input, steps, success=True,
|
||||
iterations_used=self.context.iteration,
|
||||
tool_calls_made=self.context.tool_calls_made,
|
||||
)
|
||||
yield {
|
||||
"type": "final",
|
||||
"content": last_content or "已达最大迭代次数,但模型未返回最终回答。",
|
||||
@@ -570,14 +606,62 @@ class AgentRuntime:
|
||||
async def _inject_memory_context(self, query: str = "") -> None:
    """Assemble the enriched system prompt and install it on the context.

    The prompt starts from ``self.config.system_prompt`` with trailing
    newlines stripped. The text returned by
    ``self.memory.initialize(query=query)`` is appended when it is non-empty.
    When ``self.config.memory.learning_enabled`` is truthy, the hint returned
    by ``self._inject_learning_patterns(query)`` is appended as well (again
    only when non-empty). Sections are separated by one blank line.

    The resulting prompt is always passed to
    ``self.context.set_system_prompt`` and the info log line is always
    emitted, even when neither optional section was added.

    Args:
        query: Text forwarded to the memory initialisation and to the
            learning-pattern lookup.
    """
    mem_text = await self.memory.initialize(query=query)

    # Collect the prompt sections in order, then join them once.
    sections = [self.config.system_prompt.rstrip("\n")]
    if mem_text:
        sections.append(mem_text)

    # Optional tool-usage hints derived from earlier runs.
    if self.config.memory.learning_enabled:
        pattern_hint = await self._inject_learning_patterns(query)
        if pattern_hint:
            sections.append(pattern_hint)

    self.context.set_system_prompt("\n\n".join(sections))
    logger.info("Agent 已注入长期记忆上下文")
|
||||
|
||||
async def _inject_learning_patterns(self, query: str) -> str:
    """Look up stored learning patterns and return them as hint text.

    Opens a session via ``SessionLocal``, passes it together with
    ``self._learning_scope_kind``, ``self.memory.scope_id`` and ``query`` to
    ``load_relevant_patterns``, and formats the result with
    ``format_pattern_hint``.

    This is best-effort: any exception raised while opening the session,
    loading or formatting is logged as a warning and an empty string is
    returned instead. The session, if one was created, is closed on every
    path.

    Args:
        query: Text used to select and format the relevant patterns.

    Returns:
        The formatted hint text, or ``""`` when the lookup failed.
    """
    from app.core.database import SessionLocal

    db = None
    try:
        db = SessionLocal()
        patterns = load_relevant_patterns(
            db, self._learning_scope_kind, self.memory.scope_id, query
        )
        # Only return the hint here; _inject_memory_context appends it to
        # the prompt and installs the prompt on the context.
        return format_pattern_hint(patterns, query)
    except Exception as e:
        logger.warning("加载学习模式失败: %s", e)
        return ""
    finally:
        # db stays None when SessionLocal() itself raised.
        if db is not None:
            db.close()
|
||||
|
||||
async def _save_learning_pattern(
    self, query: str, steps: List[AgentStep],
    success: bool, iterations_used: int, tool_calls_made: int,
) -> None:
    """Extract a learning pattern from a finished run and store it.

    The arguments are forwarded by keyword to ``extract_pattern_from_result``.
    Its return value is then handed to ``save_learning_pattern`` together
    with a fresh session, ``self._learning_scope_kind`` and
    ``self.memory.scope_id``.

    This is best-effort: any exception is logged as a warning and not
    re-raised. The session, if one was created, is closed on every path.

    Args:
        query: The user input of the run.
        steps: The steps recorded during the run.
        success: Success flag forwarded to the extractor.
        iterations_used: Iteration count forwarded to the extractor.
        tool_calls_made: Tool-call count forwarded to the extractor.
    """
    from app.core.database import SessionLocal

    # NOTE(review): the calls below are not awaited and look synchronous;
    # confirm they cannot stall the event loop for long.
    session = None
    try:
        session = SessionLocal()
        extracted = extract_pattern_from_result(
            query=query,
            steps=steps,
            success=success,
            iterations_used=iterations_used,
            tool_calls_made=tool_calls_made,
        )
        scope_kind = self._learning_scope_kind
        scope_id = self.memory.scope_id
        save_learning_pattern(session, scope_kind, scope_id, extracted)
    except Exception as exc:
        logger.warning("保存学习模式失败: %s", exc)
    finally:
        # session is still None when SessionLocal() itself raised.
        if session:
            session.close()
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_calls(response: Any) -> List[Dict[str, Any]]:
|
||||
|
||||
Reference in New Issue
Block a user