- 新增 AgentLearningPattern 模型和 agent_learning_service 服务 - 执行前注入历史学习模式到 system prompt 作为工具选择建议 - 执行后自动提取工具序列并保存/累计学习模式 - 支持任务分类(11类)、关键词提取、工具序列合并、有效性评分 - 集成到 AgentRuntime.run()/run_stream(),支持 bare chat 和 Agent 模式 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
341 lines
12 KiB
Python
341 lines
12 KiB
Python
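"""
Agent self-learning service: extract tool-usage patterns from past executions
to improve tool selection in later runs.

Workflow:
1. After execution: extract the tool sequence -> classify the task -> save or
   update the learning pattern.
2. Before execution: look up matching patterns -> inject them into the system
   prompt as tool-selection suggestions.
"""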
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.database import SessionLocal
|
||
from app.models.agent_learning_pattern import AgentLearningPattern
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ─── 任务分类关键词映射 ──────────────────────────────────────
|
||
|
||
_TASK_CATEGORIES = {
|
||
"data_analysis": [
|
||
"分析", "统计", "csv", "数据", "平均值", "总和", "图表", "报表",
|
||
"aggregate", "analyze", "statistic", "correlation",
|
||
],
|
||
"code_generation": [
|
||
"写", "实现", "函数", "代码", "编程", "脚本", "算法",
|
||
"implement", "function", "code", "program", "script",
|
||
],
|
||
"code_review": [
|
||
"审查", "review", "检查", "代码质量", "重构", "优化",
|
||
"refactor", "code style", "lint",
|
||
],
|
||
"debugging": [
|
||
"调试", "bug", "错误", "异常", "崩溃", "报错", "修复",
|
||
"debug", "error", "exception", "crash", "fix",
|
||
],
|
||
"file_operation": [
|
||
"读取", "写入", "保存", "文件", "目录", "创建", "删除",
|
||
"read file", "write file", "save", "directory",
|
||
],
|
||
"web_search": [
|
||
"搜索", "查询", "查找", "找", "了解", "最新",
|
||
"search", "find", "look up", "what is",
|
||
],
|
||
"network": [
|
||
"请求", "api", "http", "网络", "curl", "网址", "网站",
|
||
"request", "endpoint", "network", "url",
|
||
],
|
||
"database": [
|
||
"数据库", "sql", "表", "查询", "mysql", "postgresql",
|
||
"database", "select", "insert", "update", "delete",
|
||
],
|
||
"text_processing": [
|
||
"摘要", "翻译", "总结", "提取", "格式化", "转换",
|
||
"summarize", "translate", "extract", "format", "convert",
|
||
],
|
||
"translation": [
|
||
"翻译", "译成", "translate", "translation", "localization",
|
||
],
|
||
"testing": [
|
||
"测试", "单元测试", "集成测试", "用例", "pytest", "jest",
|
||
"test", "unit test", "integration test",
|
||
],
|
||
}
|
||
|
||
|
||
def _classify_task(query: str) -> str:
    """Classify a user query into one of the task categories.

    Each category in ``_TASK_CATEGORIES`` scores one point per keyword that
    occurs as a substring of the lower-cased, stripped query. The category
    with the highest score is returned; ties go to the category declared
    first. ``"general"`` is returned when no keyword matches at all.
    """
    text = query.lower().strip()
    hits = {
        name: sum(kw in text for kw in kws)
        for name, kws in _TASK_CATEGORIES.items()
    }
    # ``max`` returns the first maximal key, which matches declaration order.
    best = max(hits, key=hits.get, default=None)
    if best is None or hits[best] == 0:
        return "general"
    return best
|
||
|
||
|
||
def _extract_keywords(query: str, max_words: int = 8) -> str:
|
||
"""从查询中提取关键词。"""
|
||
# 移除常见停用词
|
||
stop_words = {
|
||
"的", "了", "在", "是", "我", "有", "和", "就", "不", "人", "都", "一",
|
||
"一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着",
|
||
"没有", "看", "好", "自己", "这", "他", "她", "它", "们",
|
||
"the", "a", "an", "is", "are", "was", "were", "be", "been",
|
||
"have", "has", "had", "do", "does", "did", "will", "would",
|
||
"can", "could", "should", "may", "might", "shall",
|
||
"i", "you", "he", "she", "it", "we", "they",
|
||
"this", "that", "these", "those", "my", "your", "his", "her",
|
||
"of", "in", "on", "at", "to", "for", "with", "by", "from", "as",
|
||
"and", "or", "but", "not", "no", "so", "if",
|
||
}
|
||
# 中文和英文分词
|
||
words = re.findall(r'[\w]+', query.lower())
|
||
keywords = [w for w in words if w.lower() not in stop_words and len(w) > 1]
|
||
return ", ".join(keywords[:max_words])
|
||
|
||
|
||
def _extract_tool_sequence(steps: List[Dict[str, Any]]) -> List[str]:
|
||
"""从执行步骤中提取工具调用序列(去重连续重复)。"""
|
||
tools = []
|
||
last_tool = None
|
||
for step in steps:
|
||
if isinstance(step, dict):
|
||
tname = step.get("tool_name") or (
|
||
step.get("tool_calls", [{}])[0].get("function", {}).get("name")
|
||
if step.get("tool_calls") else None
|
||
)
|
||
else:
|
||
tname = getattr(step, "tool_name", None)
|
||
if tname and tname != last_tool:
|
||
tools.append(tname)
|
||
last_tool = tname
|
||
return tools
|
||
|
||
|
||
def extract_pattern_from_result(
    query: str,
    steps: List[Any],
    success: bool,
    iterations_used: int,
    tool_calls_made: int,
) -> Dict[str, Any]:
    """Build the learning-pattern payload for one finished agent run.

    The query is classified and reduced to keywords, the steps are reduced to
    a tool sequence, and the run statistics are passed through unchanged.

    Returns:
        A dict with the keys ``task_category``, ``task_keywords``,
        ``suggested_tools``, ``success``, ``iterations_used`` and
        ``tool_calls_made``.
    """
    pattern: Dict[str, Any] = {
        "task_category": _classify_task(query),
        "task_keywords": _extract_keywords(query),
        "suggested_tools": _extract_tool_sequence(steps),
    }
    pattern.update(
        success=success,
        iterations_used=iterations_used,
        tool_calls_made=tool_calls_made,
    )
    return pattern
|
||
|
||
|
||
def _accumulate_run(
    record: AgentLearningPattern,
    suggested_tools: List[str],
    success: bool,
    iterations: float,
    tool_calls: float,
    keywords: str,
) -> None:
    """Fold one more run into an existing pattern row (no commit)."""
    # ``or 0`` guards against NULL columns on older rows.
    # NOTE(review): confirm whether the model declares these non-nullable.
    runs = (record.total_runs or 0) + 1
    record.total_runs = runs
    if success:
        record.successful_runs = (record.successful_runs or 0) + 1

    # Incremental running means over all runs seen so far.
    record.avg_iterations = (
        ((record.avg_iterations or 0.0) * (runs - 1) + iterations) / runs
    )
    record.avg_tool_calls = (
        ((record.avg_tool_calls or 0.0) * (runs - 1) + tool_calls) / runs
    )

    if suggested_tools:
        try:
            previous = json.loads(record.suggested_tools) if record.suggested_tools else []
        except (TypeError, ValueError):
            # A corrupt stored value must not block recording the new run.
            logger.warning(
                "Discarding malformed suggested_tools on pattern category=%s",
                record.task_category,
            )
            previous = []
        if not isinstance(previous, list):
            previous = []
        merged_tools = _merge_tool_sequences(previous, suggested_tools)
        record.suggested_tools = json.dumps(merged_tools, ensure_ascii=False)

    if keywords and record.task_keywords:
        # Union of old and new keywords, first-seen order, capped at 12.
        merged_kw = list(
            dict.fromkeys(record.task_keywords.split(", ") + keywords.split(", "))
        )
        record.task_keywords = ", ".join(merged_kw[:12])
    elif keywords:
        record.task_keywords = keywords

    # Must run after the counters and averages above have been updated.
    record.effectiveness_score = _calc_effectiveness(record)


def save_learning_pattern(
    db: Session,
    scope_kind: str,
    scope_id: str,
    pattern_data: Dict[str, Any],
) -> AgentLearningPattern:
    """Save or update the learning pattern for one scope and task category.

    Performs an upsert keyed on ``(scope_kind, scope_id, task_category)``.
    When a row already exists its run counters, running averages, tool
    sequence, keywords and effectiveness score are updated; otherwise a new
    row is created. The change is committed and the row refreshed before it
    is returned.

    Args:
        db: Active SQLAlchemy session.
        scope_kind: Kind of scope the pattern belongs to.
        scope_id: Identifier of that scope.
        pattern_data: Payload as produced by ``extract_pattern_from_result``.
            ``task_category`` is required; the other keys are optional.

    Returns:
        The persisted ``AgentLearningPattern`` row.

    Raises:
        KeyError: If ``pattern_data`` has no ``task_category``.
        Exception: Any error raised while flushing or committing is re-raised
            after the session has been rolled back.
    """
    category = pattern_data["task_category"]
    # ``or []`` also covers an explicit None, which would otherwise be
    # serialised as the JSON literal "null".
    suggested_tools = pattern_data.get("suggested_tools") or []
    success = pattern_data.get("success", True)
    iterations = pattern_data.get("iterations_used", 0)
    tool_calls = pattern_data.get("tool_calls_made", 0)
    keywords = pattern_data.get("task_keywords", "")

    existing = (
        db.query(AgentLearningPattern)
        .filter(
            AgentLearningPattern.scope_kind == scope_kind,
            AgentLearningPattern.scope_id == scope_id,
            AgentLearningPattern.task_category == category,
        )
        .first()
    )

    # One timestamp so last_used_at and updated_at agree. Naive UTC is kept
    # on purpose to stay comparable with values already stored.
    now = datetime.utcnow()

    try:
        if existing is None:
            record = AgentLearningPattern(
                scope_kind=scope_kind,
                scope_id=scope_id,
                task_category=category,
                task_keywords=keywords,
                suggested_tools=json.dumps(suggested_tools, ensure_ascii=False),
                effectiveness_score=1.0 if success else 0.3,
                total_runs=1,
                successful_runs=1 if success else 0,
                avg_iterations=float(iterations),
                avg_tool_calls=float(tool_calls),
                last_used_at=now,
            )
            db.add(record)
        else:
            record = existing
            _accumulate_run(
                record, suggested_tools, success, iterations, tool_calls, keywords
            )
            record.last_used_at = now
            record.updated_at = now
        db.commit()
    except Exception:
        # Leave the session usable for the caller after a failed write.
        db.rollback()
        raise

    db.refresh(record)
    if existing is None:
        logger.info(
            "学习模式创建: scope=%s/%s category=%s tools=%s",
            scope_kind, scope_id, category, suggested_tools,
        )
    else:
        logger.info(
            "学习模式更新: scope=%s/%s category=%s runs=%d",
            scope_kind, scope_id, category, record.total_runs,
        )
    return record
|
||
|
||
|
||
def _merge_tool_sequences(
|
||
old_seq: List[str], new_seq: List[str]
|
||
) -> List[str]:
|
||
"""合并两套工具序列,保留新序列中工具及其顺序,并补充旧序列中独有的工具。"""
|
||
seen = set(new_seq)
|
||
result = list(new_seq)
|
||
for t in old_seq:
|
||
if t not in seen:
|
||
result.append(t)
|
||
seen.add(t)
|
||
return result
|
||
|
||
|
||
def _calc_effectiveness(pattern: AgentLearningPattern) -> float:
|
||
"""计算模式的有效评分 (0-1)。"""
|
||
if pattern.total_runs == 0:
|
||
return 0.0
|
||
success_rate = pattern.successful_runs / pattern.total_runs
|
||
# 迭代效率:迭代越少越好,以 5 次为基准
|
||
efficiency = max(0, 1.0 - (pattern.avg_iterations / 10.0))
|
||
# 综合评分
|
||
return round(success_rate * 0.6 + efficiency * 0.4, 3)
|
||
|
||
|
||
def load_relevant_patterns(
    db: Session, scope_kind: str, scope_id: str, query: str,
    top_k: int = 3,
) -> List[AgentLearningPattern]:
    """Load the learning patterns relevant to the current query.

    The query is classified into a task category and patterns of that
    category (with at least one recorded run) are fetched for the given
    scope, best effectiveness score first, limited to ``top_k`` rows.

    When nothing matches and the category is not ``"general"``, the lookup
    falls back to the best-scoring patterns of the scope regardless of
    category.
    """
    category = _classify_task(query)

    # Predicates and ordering shared by the primary and fallback lookups.
    in_scope = (
        AgentLearningPattern.scope_kind == scope_kind,
        AgentLearningPattern.scope_id == scope_id,
    )
    best_first = AgentLearningPattern.effectiveness_score.desc()

    matched = (
        db.query(AgentLearningPattern)
        .filter(
            *in_scope,
            AgentLearningPattern.task_category == category,
            AgentLearningPattern.total_runs >= 1,
        )
        .order_by(best_first)
        .limit(top_k)
        .all()
    )
    if matched or category == "general":
        return matched

    # Fallback: any category within the same scope.
    return (
        db.query(AgentLearningPattern)
        .filter(*in_scope)
        .order_by(best_first)
        .limit(top_k)
        .all()
    )
|
||
|
||
|
||
# Display names used when a task category is rendered into the prompt hint.
# Defined once at module level instead of being rebuilt for every pattern.
_CATEGORY_DISPLAY_NAMES = {
    "data_analysis": "数据分析",
    "code_generation": "代码生成",
    "code_review": "代码审查",
    "debugging": "调试排错",
    "file_operation": "文件操作",
    "web_search": "网络搜索",
    "network": "网络请求",
    "database": "数据库",
    "text_processing": "文本处理",
    "translation": "翻译",
    "testing": "测试",
    "general": "通用",
}


def _decode_suggested_tools(raw: Any) -> List[str]:
    """Decode a stored JSON tool list; return ``[]`` for empty or bad data."""
    if not raw:
        return []
    try:
        decoded = json.loads(raw)
    except (TypeError, ValueError):
        return []
    if not isinstance(decoded, list):
        return []
    return [tool for tool in decoded if isinstance(tool, str) and tool]


def format_pattern_hint(
    patterns: List[AgentLearningPattern], query: str
) -> str:
    """Render learning patterns as a hint block for the system prompt.

    Only mature patterns are rendered: at least two recorded runs and an
    effectiveness score of 0.3 or higher. Patterns without a usable tool
    sequence are skipped; this includes stored values that are empty, are
    not valid JSON, or do not decode to a list, so one corrupt row cannot
    break prompt construction.

    Args:
        patterns: Candidate patterns, typically already ordered by score.
        query: The current user query. Not used by the formatting itself.

    Returns:
        A heading line followed by one bullet per rendered pattern, joined
        with newlines, or ``""`` when no pattern qualifies.
    """
    if not patterns:
        return ""

    lines = ["## 历史经验参考(基于过往执行记录)"]
    for p in patterns:
        if not (p.total_runs >= 2 and p.effectiveness_score >= 0.3):
            continue
        tools = _decode_suggested_tools(p.suggested_tools)
        if not tools:
            continue
        # Unknown categories fall back to their raw identifier.
        cat_name = _CATEGORY_DISPLAY_NAMES.get(p.task_category, p.task_category)
        tool_str = " → ".join(tools)
        lines.append(
            f"- [{cat_name}] 对于此类任务,历史经验推荐工具序列: {tool_str} "
            f"(成功{p.successful_runs}/{p.total_runs}次, 评分{p.effectiveness_score:.2f})"
        )

    # A heading with no bullets carries no information.
    if len(lines) == 1:
        return ""
    return "\n".join(lines)
|