feat: 实现 Agent 自主学习 — 从历史执行中优化工具选择

- 新增 AgentLearningPattern 模型和 agent_learning_service 服务
- 执行前注入历史学习模式到 system prompt 作为工具选择建议
- 执行后自动提取工具序列并保存/累计学习模式
- 支持任务分类(11类)、关键词提取、工具序列合并、有效性评分
- 集成到 AgentRuntime.run()/run_stream(),支持 bare chat 和 Agent 模式

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
renjianbo
2026-05-02 12:04:00 +08:00
parent c28cf40f61
commit e3802eff60
6 changed files with 459 additions and 7 deletions

View File

@@ -24,6 +24,12 @@ from app.agent_runtime.context import AgentContext
from app.agent_runtime.memory import AgentMemory
from app.agent_runtime.tool_manager import AgentToolManager
from app.core.exceptions import WorkflowExecutionError
from app.services.agent_learning_service import (
extract_pattern_from_result,
format_pattern_hint,
load_relevant_patterns,
save_learning_pattern,
)
logger = logging.getLogger(__name__)
@@ -98,6 +104,8 @@ class AgentRuntime:
self.on_llm_call = on_llm_call
self._memory_context_loaded = False
self._llm_invocations = 0
# 自主学习作用域: bare 聊天用 "bare", Agent 用 "agent"
self._learning_scope_kind = "bare" if "bare" in str(_mem_scope) else "agent"
# 预算回调:供 WorkflowEngine 注入,使 Agent 内部计数计入工作流预算
# 返回 True 表示预算充足;返回 False 或抛出异常表示超限
@@ -221,6 +229,13 @@ class AgentRuntime:
))
# 保存记忆
await self.memory.save_context(user_input, final_text, self.context.messages)
# 保存学习模式
if self.config.memory.learning_enabled:
await self._save_learning_pattern(
user_input, steps, success=True,
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
return AgentResult(
success=True,
content=final_text,
@@ -318,6 +333,13 @@ class AgentRuntime:
logger.warning("Agent 达到最大迭代次数 (%s)", max_iter)
await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages)
# 保存学习模式(即便截断,工具调用模式仍有参考价值)
if self.config.memory.learning_enabled:
await self._save_learning_pattern(
user_input, steps, success=True,
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
if last_content:
steps.append(AgentStep(
iteration=self.context.iteration,
@@ -446,6 +468,13 @@ class AgentRuntime:
"session_id": self.context.session_id,
}
await self.memory.save_context(user_input, final_text, self.context.messages)
# 保存学习模式
if self.config.memory.learning_enabled:
await self._save_learning_pattern(
user_input, steps, success=True,
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
return
# 有工具调用 → 先记录 assistant 消息
@@ -557,6 +586,13 @@ class AgentRuntime:
logger.warning("Agent 达到最大迭代次数 (%s)", max_iter)
await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages)
# 保存学习模式(即便截断,工具调用模式仍有参考价值)
if self.config.memory.learning_enabled:
await self._save_learning_pattern(
user_input, steps, success=True,
iterations_used=self.context.iteration,
tool_calls_made=self.context.tool_calls_made,
)
yield {
"type": "final",
"content": last_content or "已达最大迭代次数,但模型未返回最终回答。",
@@ -570,14 +606,62 @@ class AgentRuntime:
async def _inject_memory_context(self, query: str = "") -> None:
    """Build the enriched system prompt and install it on the context.

    The prompt starts from the configured system prompt (trailing newlines
    removed). Long-term memory text returned by ``self.memory.initialize``
    is appended when non-empty, followed by the learned tool-usage hint when
    learning is enabled and a hint is available.

    Args:
        query: Current user query, forwarded to the memory lookup and to the
            learning-pattern lookup.
    """
    mem_text = await self.memory.initialize(query=query)
    enriched = self.config.system_prompt.rstrip("\n")
    if mem_text:
        enriched += "\n\n" + mem_text
    # Append tool-selection hints learned from earlier executions.
    if self.config.memory.learning_enabled:
        pattern_hint = await self._inject_learning_patterns(query)
        if pattern_hint:
            enriched += "\n\n" + pattern_hint
    # Install the prompt once, after every optional section has been added,
    # so a learning hint is not lost when there is no memory text.
    self.context.set_system_prompt(enriched)
    logger.info("Agent 已注入长期记忆上下文")
async def _inject_learning_patterns(self, query: str) -> str:
    """Look up learned tool-usage patterns and return them as prompt text.

    Opens a short-lived database session, loads the patterns stored for this
    runtime's learning scope and formats them while the session is still
    open. The lookup is best-effort: any failure is logged as a warning and
    an empty string is returned so prompt construction can continue.

    Args:
        query: Current user query used to select relevant patterns.

    Returns:
        The formatted hint text, or "" when nothing applies or the lookup
        failed.
    """
    from app.core.database import SessionLocal
    db = None
    try:
        db = SessionLocal()
        patterns = load_relevant_patterns(
            db, self._learning_scope_kind, self.memory.scope_id, query
        )
        return format_pattern_hint(patterns, query)
    except Exception as e:
        logger.warning("加载学习模式失败: %s", e)
        return ""
    finally:
        # db stays None when SessionLocal() itself raised.
        if db:
            db.close()
async def _save_learning_pattern(
    self, query: str, steps: List[AgentStep],
    success: bool, iterations_used: int, tool_calls_made: int,
) -> None:
    """Derive a learning pattern from a finished run and persist it.

    Best-effort: every failure (session creation, extraction, persistence)
    is logged as a warning and swallowed so it never affects the run result.
    """
    from app.core.database import SessionLocal
    session = None
    try:
        session = SessionLocal()
        extracted = extract_pattern_from_result(
            query=query,
            steps=steps,
            success=success,
            iterations_used=iterations_used,
            tool_calls_made=tool_calls_made,
        )
        scope_kind = self._learning_scope_kind
        scope_id = self.memory.scope_id
        save_learning_pattern(session, scope_kind, scope_id, extracted)
    except Exception as exc:
        logger.warning("保存学习模式失败: %s", exc)
    finally:
        # The session is still None when SessionLocal() itself failed.
        if session:
            session.close()
@staticmethod
def _extract_tool_calls(response: Any) -> List[Dict[str, Any]]:

View File

@@ -22,6 +22,7 @@ class AgentMemoryConfig(BaseModel):
persist_to_db: bool = True # 是否写入 MySQL 长期记忆
vector_memory_enabled: bool = True # 是否启用向量记忆(语义检索)
vector_memory_top_k: int = 5 # 向量检索 Top-K
learning_enabled: bool = True # 是否启用自主学习(工具模式学习)
class AgentLLMConfig(BaseModel):

View File

@@ -48,5 +48,6 @@ def init_db():
import app.models.alert_rule
import app.models.agent_llm_log
import app.models.agent_vector_memory
import app.models.agent_learning_pattern
import app.models.knowledge_base
Base.metadata.create_all(bind=engine)

View File

@@ -14,6 +14,7 @@ from app.models.alert_rule import AlertRule, AlertLog
from app.models.persistent_user_memory import PersistentUserMemory
from app.models.agent_llm_log import AgentLLMLog
from app.models.agent_vector_memory import AgentVectorMemory
from app.models.agent_learning_pattern import AgentLearningPattern
from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "KnowledgeBase", "Document", "DocumentChunk"]
__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "KnowledgeBase", "Document", "DocumentChunk"]

View File

@@ -0,0 +1,25 @@
"""Agent 自主学习模式表:记录工具使用模式,支持自主学习优化"""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, Float, Integer, DateTime
from app.core.database import Base
class AgentLearningPattern(Base):
    """Agent learning pattern: relates a tool-call sequence to a task category.

    Each row carries a scope (kind + id), a task category, the suggested tool
    sequence and running statistics about the executions it was built from.
    """
    __tablename__ = "agent_learning_patterns"
    # Primary key: a random UUID4 rendered as a 36-character string.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # Scope of the pattern; the column comments name "agent"/"bare" kinds and
    # an agent_id/user_id identifier.
    scope_kind = Column(String(16), nullable=False, index=True, comment="作用域类型: agent/bare")
    scope_id = Column(String(64), nullable=False, index=True, comment="作用域 ID: agent_id/user_id")
    # Task category label; defaults to "general".
    task_category = Column(String(64), nullable=False, default="general", comment="任务分类")
    # Task keywords stored as a single string (at most 256 characters).
    task_keywords = Column(String(256), default="", comment="任务关键词")
    # Suggested tool sequence, stored as a JSON array serialized to text.
    suggested_tools = Column(Text, nullable=False, comment="推荐工具序列 (JSON array)")
    # Effectiveness score; the column comment documents a 0-1 range.
    effectiveness_score = Column(Float, default=0.0, comment="有效评分 0-1")
    # Run counters; both default to 1.
    total_runs = Column(Integer, default=1, comment="总运行次数")
    successful_runs = Column(Integer, default=1, comment="成功次数")
    # Running averages over the recorded runs.
    avg_iterations = Column(Float, default=1.0, comment="平均迭代次数")
    avg_tool_calls = Column(Float, default=1.0, comment="平均工具调用数")
    # Timestamps default to datetime.utcnow (naive values); updated_at is also
    # refreshed on UPDATE through onupdate.
    last_used_at = Column(DateTime, default=datetime.utcnow, comment="最后使用时间")
    created_at = Column(DateTime, default=datetime.utcnow, comment="创建时间")
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")

View File

@@ -0,0 +1,340 @@
"""
Agent 自主学习服务 — 从历史执行中提取工具使用模式,优化后续工具选择。
工作流程:
1. 执行后:提取工具序列 → 分类任务 → 保存/更新学习模式
2. 执行前:查询匹配模式 → 注入 system prompt 作为工具选择建议
"""
from __future__ import annotations
import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy.orm import Session
from app.core.database import SessionLocal
from app.models.agent_learning_pattern import AgentLearningPattern
logger = logging.getLogger(__name__)
# ─── 任务分类关键词映射 ──────────────────────────────────────
# Keyword table used to classify a user query into a task category.
# Matching is by substring against the lower-cased query.
# NOTE(review): the empty-string entries below look like keywords that were
# lost to an encoding problem — confirm the intended values and restore them.
# _classify_task ignores them so they cannot match every query.
_TASK_CATEGORIES = {
    "data_analysis": [
        "分析", "统计", "csv", "数据", "平均值", "总和", "图表", "报表",
        "aggregate", "analyze", "statistic", "correlation",
    ],
    "code_generation": [
        "", "实现", "函数", "代码", "编程", "脚本", "算法",
        "implement", "function", "code", "program", "script",
    ],
    "code_review": [
        "审查", "review", "检查", "代码质量", "重构", "优化",
        "refactor", "code style", "lint",
    ],
    "debugging": [
        "调试", "bug", "错误", "异常", "崩溃", "报错", "修复",
        "debug", "error", "exception", "crash", "fix",
    ],
    "file_operation": [
        "读取", "写入", "保存", "文件", "目录", "创建", "删除",
        "read file", "write file", "save", "directory",
    ],
    "web_search": [
        "搜索", "查询", "查找", "", "了解", "最新",
        "search", "find", "look up", "what is",
    ],
    "network": [
        "请求", "api", "http", "网络", "curl", "网址", "网站",
        "request", "endpoint", "network", "url",
    ],
    "database": [
        "数据库", "sql", "", "查询", "mysql", "postgresql",
        "database", "select", "insert", "update", "delete",
    ],
    "text_processing": [
        "摘要", "翻译", "总结", "提取", "格式化", "转换",
        "summarize", "translate", "extract", "format", "convert",
    ],
    "translation": [
        "翻译", "译成", "translate", "translation", "localization",
    ],
    "testing": [
        "测试", "单元测试", "集成测试", "用例", "pytest", "jest",
        "test", "unit test", "integration test",
    ],
}


def _classify_task(query: str) -> str:
    """Classify a user query into one of the known task categories.

    Each category is scored by how many of its keywords occur as substrings
    of the lower-cased query. The highest-scoring category wins; on a tie the
    category declared first in ``_TASK_CATEGORIES`` is kept.

    Args:
        query: Raw user query text.

    Returns:
        The winning category name, or "general" when no keyword matches.
    """
    q = query.lower().strip()
    best_category = "general"
    best_score = 0
    for category, keywords in _TASK_CATEGORIES.items():
        # Skip empty keywords: "" is a substring of every string, so counting
        # it would give the category a free point on every query and make
        # "general" unreachable.
        score = sum(1 for kw in keywords if kw and kw in q)
        # Strict ">" keeps the earliest category on ties.
        if score > best_score:
            best_category = category
            best_score = score
    return best_category
def _extract_keywords(query: str, max_words: int = 8) -> str:
    """Return up to ``max_words`` keywords from ``query`` joined by ", ".

    The query is lower-cased and split into runs of word characters. Tokens
    that are a single character long or appear in the stop-word set are
    dropped; the remaining tokens keep their original order.
    """
    # NOTE(review): the empty-string entries look like stop words lost to an
    # encoding problem; they are inert here because tokens are never empty.
    stop_words = {
        "", "", "", "", "", "", "", "", "", "", "", "",
        "一个", "", "", "", "", "", "", "", "", "", "",
        "没有", "", "", "自己", "", "", "", "", "",
        "the", "a", "an", "is", "are", "was", "were", "be", "been",
        "have", "has", "had", "do", "does", "did", "will", "would",
        "can", "could", "should", "may", "might", "shall",
        "i", "you", "he", "she", "it", "we", "they",
        "this", "that", "these", "those", "my", "your", "his", "her",
        "of", "in", "on", "at", "to", "for", "with", "by", "from", "as",
        "and", "or", "but", "not", "no", "so", "if",
    }
    kept = []
    for token in re.findall(r'[\w]+', query.lower()):
        if len(token) <= 1:
            continue
        if token.lower() in stop_words:
            continue
        kept.append(token)
    return ", ".join(kept[:max_words])
def _extract_tool_sequence(steps: List[Dict[str, Any]]) -> List[str]:
    """Return the ordered tool names used in ``steps``.

    A tool name equal to the most recently recorded one is not recorded
    again. Dict steps are read from "tool_name", falling back to the function
    name of the first entry in "tool_calls"; any other step is read from its
    ``tool_name`` attribute.
    """

    def _name_of(step: Any) -> Any:
        # Non-dict steps expose the tool name as an attribute.
        if not isinstance(step, dict):
            return getattr(step, "tool_name", None)
        direct = step.get("tool_name")
        if direct:
            return direct
        if step.get("tool_calls"):
            first_call = step.get("tool_calls", [{}])[0]
            return first_call.get("function", {}).get("name")
        return None

    sequence = []
    previous = None
    for step in steps:
        name = _name_of(step)
        if not name or name == previous:
            continue
        sequence.append(name)
        previous = name
    return sequence
def extract_pattern_from_result(
    query: str,
    steps: List[Any],
    success: bool,
    iterations_used: int,
    tool_calls_made: int,
) -> Dict[str, Any]:
    """Summarize one agent execution as a learning-pattern dict.

    The result holds the classified task category, the extracted keywords and
    the tool sequence, alongside the outcome and counters passed in unchanged.
    """
    return dict(
        task_category=_classify_task(query),
        task_keywords=_extract_keywords(query),
        suggested_tools=_extract_tool_sequence(steps),
        success=success,
        iterations_used=iterations_used,
        tool_calls_made=tool_calls_made,
    )
def save_learning_pattern(
    db: Session,
    scope_kind: str,
    scope_id: str,
    pattern_data: Dict[str, Any],
) -> AgentLearningPattern:
    """Create or update the learning pattern for a scope and task category.

    Looks up the first row matching (scope_kind, scope_id, task_category).
    When none exists a new row is inserted; otherwise its counters, running
    averages, tool sequence, keywords and effectiveness score are updated.
    The change is committed and the refreshed row is returned.

    Args:
        db: Open database session; this function commits on it.
        scope_kind: Scope type of the pattern.
        scope_id: Scope identifier of the pattern.
        pattern_data: Dict as produced by ``extract_pattern_from_result``;
            "task_category" is required, the other keys have defaults.

    Returns:
        The inserted or updated ``AgentLearningPattern`` row.
    """
    category = pattern_data["task_category"]
    new_tools = pattern_data.get("suggested_tools", [])
    succeeded = pattern_data.get("success", True)
    iteration_count = pattern_data.get("iterations_used", 0)
    call_count = pattern_data.get("tool_calls_made", 0)
    new_keywords = pattern_data.get("task_keywords", "")

    row = (
        db.query(AgentLearningPattern)
        .filter(
            AgentLearningPattern.scope_kind == scope_kind,
            AgentLearningPattern.scope_id == scope_id,
            AgentLearningPattern.task_category == category,
        )
        .first()
    )

    if not row:
        # First observation for this scope/category: insert a fresh row.
        record = AgentLearningPattern(
            scope_kind=scope_kind,
            scope_id=scope_id,
            task_category=category,
            task_keywords=new_keywords,
            suggested_tools=json.dumps(new_tools, ensure_ascii=False),
            effectiveness_score=1.0 if succeeded else 0.3,
            total_runs=1,
            successful_runs=1 if succeeded else 0,
            avg_iterations=float(iteration_count),
            avg_tool_calls=float(call_count),
            last_used_at=datetime.utcnow(),
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        logger.info(
            "学习模式创建: scope=%s/%s category=%s tools=%s",
            scope_kind, scope_id, category, new_tools,
        )
        return record

    # Fold this run into the stored counters and running averages.
    runs = row.total_runs + 1
    row.total_runs = runs
    if succeeded:
        row.successful_runs += 1
    row.avg_iterations = (
        (row.avg_iterations * (runs - 1) + iteration_count) / runs
    )
    row.avg_tool_calls = (
        (row.avg_tool_calls * (runs - 1) + call_count) / runs
    )

    # Merge tool sequences only when this run actually used tools.
    if new_tools:
        stored_tools = json.loads(row.suggested_tools) if row.suggested_tools else []
        combined_tools = _merge_tool_sequences(stored_tools, new_tools)
        row.suggested_tools = json.dumps(combined_tools, ensure_ascii=False)

    # Merge keywords: stored ones first, duplicates dropped, capped at 12.
    if new_keywords:
        if row.task_keywords:
            all_keywords = row.task_keywords.split(", ") + new_keywords.split(", ")
            unique_keywords = list(dict.fromkeys(all_keywords))
            row.task_keywords = ", ".join(unique_keywords[:12])
        else:
            row.task_keywords = new_keywords

    row.effectiveness_score = _calc_effectiveness(row)
    row.last_used_at = datetime.utcnow()
    row.updated_at = datetime.utcnow()
    db.commit()
    db.refresh(row)
    logger.info(
        "学习模式更新: scope=%s/%s category=%s runs=%d",
        scope_kind, scope_id, category, row.total_runs,
    )
    return row
def _merge_tool_sequences(
    old_seq: List[str], new_seq: List[str]
) -> List[str]:
    """Merge two tool sequences, giving priority to the new one.

    The result is ``new_seq`` exactly as given, followed by each tool that
    appears only in ``old_seq``, once each, in first-appearance order.
    """
    in_new = set(new_seq)
    # dict.fromkeys drops repeats from old_seq while preserving order.
    only_old = [tool for tool in dict.fromkeys(old_seq) if tool not in in_new]
    return list(new_seq) + only_old
def _calc_effectiveness(pattern: AgentLearningPattern) -> float:
    """Compute the effectiveness score of a pattern, rounded to 3 decimals.

    The score is 60% success rate plus 40% iteration efficiency, where
    efficiency is ``1 - avg_iterations / 10`` floored at 0. A pattern with
    zero recorded runs scores 0.0.
    """
    runs = pattern.total_runs
    if runs == 0:
        return 0.0
    success_rate = pattern.successful_runs / runs
    # Fewer iterations is better; the efficiency term reaches 0 at an
    # average of 10 iterations and stays there.
    raw_efficiency = 1.0 - (pattern.avg_iterations / 10.0)
    efficiency = raw_efficiency if raw_efficiency > 0 else 0
    weighted = success_rate * 0.6 + efficiency * 0.4
    return round(weighted, 3)
def load_relevant_patterns(
    db: Session, scope_kind: str, scope_id: str, query: str,
    top_k: int = 3,
) -> List[AgentLearningPattern]:
    """Load up to ``top_k`` learning patterns relevant to ``query``.

    Patterns are restricted to the given scope and ordered by effectiveness
    score, highest first. Rows whose category matches the classified query
    are preferred. When there are none and the category is not "general",
    the best patterns of any category in the scope are returned instead.
    """
    category = _classify_task(query)
    by_score = AgentLearningPattern.effectiveness_score.desc()
    in_scope = (
        AgentLearningPattern.scope_kind == scope_kind,
        AgentLearningPattern.scope_id == scope_id,
    )

    matches = (
        db.query(AgentLearningPattern)
        .filter(
            *in_scope,
            AgentLearningPattern.task_category == category,
            AgentLearningPattern.total_runs >= 1,
        )
        .order_by(by_score)
        .limit(top_k)
        .all()
    )
    if matches or category == "general":
        return matches

    # Fallback: nothing stored for this category, so take the best patterns
    # of any category within the same scope.
    return (
        db.query(AgentLearningPattern)
        .filter(*in_scope)
        .order_by(by_score)
        .limit(top_k)
        .all()
    )
def format_pattern_hint(
    patterns: List[AgentLearningPattern], query: str
) -> str:
    """Format learning patterns as a hint section for the system prompt.

    Only mature patterns are rendered: at least 2 recorded runs, an
    effectiveness score of at least 0.3 and a non-empty tool sequence.
    Patterns whose stored tool sequence is not a valid JSON array are skipped.

    Args:
        patterns: Candidate patterns, already ordered by the caller.
        query: Current user query; accepted for interface compatibility and
            not used by the formatting.

    Returns:
        A heading line followed by one bullet per rendered pattern, or ""
        when no pattern qualifies.
    """
    if not patterns:
        return ""
    mature_patterns = [
        p for p in patterns
        if p.total_runs >= 2 and p.effectiveness_score >= 0.3
    ]
    if not mature_patterns:
        return ""

    # Display labels per category; unknown categories fall back to their key.
    category_names = {
        "data_analysis": "数据分析",
        "code_generation": "代码生成",
        "code_review": "代码审查",
        "debugging": "调试排错",
        "file_operation": "文件操作",
        "web_search": "网络搜索",
        "network": "网络请求",
        "database": "数据库",
        "text_processing": "文本处理",
        "translation": "翻译",
        "testing": "测试",
        "general": "通用",
    }

    lines = ["## 历史经验参考(基于过往执行记录)"]
    for p in mature_patterns:
        try:
            tools = json.loads(p.suggested_tools) if p.suggested_tools else []
        except (TypeError, ValueError):
            # A corrupt row must not discard the hints of the other rows.
            continue
        if not tools or not isinstance(tools, list):
            continue
        cat_name = category_names.get(p.task_category, p.task_category)
        # Join with a visible separator; an empty separator would fuse the
        # tool names into one unreadable token.
        tool_str = " -> ".join(str(t) for t in tools)
        lines.append(
            f"- [{cat_name}] 对于此类任务,历史经验推荐工具序列: {tool_str} "
            f"(成功{p.successful_runs}/{p.total_runs}次, 评分{p.effectiveness_score:.2f})"
        )

    if len(lines) == 1:
        return ""  # heading only, nothing to show
    return "\n".join(lines)