"""
Agent 自主学习服务 — 从历史执行中提取工具使用模式,优化后续工具选择。

工作流程:
1. 执行后:提取工具序列 → 分类任务 → 保存/更新学习模式
2. 执行前:查询匹配模式 → 注入 system prompt 作为工具选择建议
"""
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import re
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.core.database import SessionLocal
|
|||
|
|
from app.models.agent_learning_pattern import AgentLearningPattern
|
|||
|
|
|
|||
|
|
logger = logging.getLogger(__name__)

# ─── Task-classification keyword map ──────────────────────────
# Maps each task category to the (Chinese + English) substrings that
# _classify_task counts inside a lowercased query; the category with the
# most hits wins. Keywords must therefore be lowercase.

_TASK_CATEGORIES = {
    "data_analysis": [
        "分析", "统计", "csv", "数据", "平均值", "总和", "图表", "报表",
        "aggregate", "analyze", "statistic", "correlation",
    ],
    "code_generation": [
        "写", "实现", "函数", "代码", "编程", "脚本", "算法",
        "implement", "function", "code", "program", "script",
    ],
    "code_review": [
        "审查", "review", "检查", "代码质量", "重构", "优化",
        "refactor", "code style", "lint",
    ],
    "debugging": [
        "调试", "bug", "错误", "异常", "崩溃", "报错", "修复",
        "debug", "error", "exception", "crash", "fix",
    ],
    "file_operation": [
        "读取", "写入", "保存", "文件", "目录", "创建", "删除",
        "read file", "write file", "save", "directory",
    ],
    "web_search": [
        "搜索", "查询", "查找", "找", "了解", "最新",
        "search", "find", "look up", "what is",
    ],
    "network": [
        "请求", "api", "http", "网络", "curl", "网址", "网站",
        "request", "endpoint", "network", "url",
    ],
    "database": [
        "数据库", "sql", "表", "查询", "mysql", "postgresql",
        "database", "select", "insert", "update", "delete",
    ],
    "text_processing": [
        "摘要", "翻译", "总结", "提取", "格式化", "转换",
        "summarize", "translate", "extract", "format", "convert",
    ],
    "translation": [
        "翻译", "译成", "translate", "translation", "localization",
    ],
    "testing": [
        "测试", "单元测试", "集成测试", "用例", "pytest", "jest",
        "test", "unit test", "integration test",
    ],
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _classify_task(query: str) -> str:
    """Classify a user query into one of the known task categories.

    Counts how many of each category's keywords occur as substrings of
    the lowercased query and returns the category with the most hits.
    Ties go to the category declared first in _TASK_CATEGORIES; a query
    matching nothing is classified as "general".
    """
    text = query.lower().strip()
    best_category = "general"
    best_hits = 0
    for name, keyword_list in _TASK_CATEGORIES.items():
        hits = 0
        for keyword in keyword_list:
            if keyword in text:
                hits += 1
        # Strict > keeps the earliest category on ties, matching the
        # stable-sort behavior of the original implementation.
        if hits > best_hits:
            best_category = name
            best_hits = hits
    return best_category
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_keywords(query: str, max_words: int = 8) -> str:
|
|||
|
|
"""从查询中提取关键词。"""
|
|||
|
|
# 移除常见停用词
|
|||
|
|
stop_words = {
|
|||
|
|
"的", "了", "在", "是", "我", "有", "和", "就", "不", "人", "都", "一",
|
|||
|
|
"一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着",
|
|||
|
|
"没有", "看", "好", "自己", "这", "他", "她", "它", "们",
|
|||
|
|
"the", "a", "an", "is", "are", "was", "were", "be", "been",
|
|||
|
|
"have", "has", "had", "do", "does", "did", "will", "would",
|
|||
|
|
"can", "could", "should", "may", "might", "shall",
|
|||
|
|
"i", "you", "he", "she", "it", "we", "they",
|
|||
|
|
"this", "that", "these", "those", "my", "your", "his", "her",
|
|||
|
|
"of", "in", "on", "at", "to", "for", "with", "by", "from", "as",
|
|||
|
|
"and", "or", "but", "not", "no", "so", "if",
|
|||
|
|
}
|
|||
|
|
# 中文和英文分词
|
|||
|
|
words = re.findall(r'[\w]+', query.lower())
|
|||
|
|
keywords = [w for w in words if w.lower() not in stop_words and len(w) > 1]
|
|||
|
|
return ", ".join(keywords[:max_words])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _extract_tool_sequence(steps: List[Dict[str, Any]]) -> List[str]:
|
|||
|
|
"""从执行步骤中提取工具调用序列(去重连续重复)。"""
|
|||
|
|
tools = []
|
|||
|
|
last_tool = None
|
|||
|
|
for step in steps:
|
|||
|
|
if isinstance(step, dict):
|
|||
|
|
tname = step.get("tool_name") or (
|
|||
|
|
step.get("tool_calls", [{}])[0].get("function", {}).get("name")
|
|||
|
|
if step.get("tool_calls") else None
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
tname = getattr(step, "tool_name", None)
|
|||
|
|
if tname and tname != last_tool:
|
|||
|
|
tools.append(tname)
|
|||
|
|
last_tool = tname
|
|||
|
|
return tools
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_pattern_from_result(
    query: str,
    steps: List[Any],
    success: bool,
    iterations_used: int,
    tool_calls_made: int,
) -> Dict[str, Any]:
    """Distill one agent run into a learning-pattern payload.

    Classifies the query, extracts its keywords and the deduplicated
    tool sequence, and packages them with the run statistics in the
    shape expected by save_learning_pattern.
    """
    return {
        "task_category": _classify_task(query),
        "task_keywords": _extract_keywords(query),
        "suggested_tools": _extract_tool_sequence(steps),
        "success": success,
        "iterations_used": iterations_used,
        "tool_calls_made": tool_calls_made,
    }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_learning_pattern(
    db: Session,
    scope_kind: str,
    scope_id: str,
    pattern_data: Dict[str, Any],
) -> AgentLearningPattern:
    """
    Create or update (upsert) a learning pattern row.

    One row is kept per (scope_kind, scope_id, task_category); repeated
    runs accumulate into that row's running statistics.

    Args:
        db: Open SQLAlchemy session. This function commits on it.
        scope_kind: Kind of the owning scope (paired with scope_id).
        scope_id: Identifier of the owning scope.
        pattern_data: Payload as produced by extract_pattern_from_result;
            "task_category" is required, the other keys have defaults.

    Returns:
        The refreshed AgentLearningPattern row that was created/updated.
    """
    category = pattern_data["task_category"]
    suggested_tools = pattern_data.get("suggested_tools", [])
    success = pattern_data.get("success", True)
    iterations = pattern_data.get("iterations_used", 0)
    tool_calls = pattern_data.get("tool_calls_made", 0)
    keywords = pattern_data.get("task_keywords", "")

    # Look up the existing row for this scope + category, if any.
    existing = (
        db.query(AgentLearningPattern)
        .filter(
            AgentLearningPattern.scope_kind == scope_kind,
            AgentLearningPattern.scope_id == scope_id,
            AgentLearningPattern.task_category == category,
        )
        .first()
    )

    if existing:
        # Update running statistics: incremental mean over n runs.
        n = existing.total_runs + 1
        existing.total_runs = n
        if success:
            existing.successful_runs += 1
        existing.avg_iterations = (
            (existing.avg_iterations * (n - 1) + iterations) / n
        )
        existing.avg_tool_calls = (
            (existing.avg_tool_calls * (n - 1) + tool_calls) / n
        )
        # Merge tool sequences: the new run's order wins; tools only
        # present in the stored sequence are appended at the end
        # (see _merge_tool_sequences). Stored as a JSON array string.
        if suggested_tools:
            old_tools = json.loads(existing.suggested_tools) if existing.suggested_tools else []
            merged = _merge_tool_sequences(old_tools, suggested_tools)
            existing.suggested_tools = json.dumps(merged, ensure_ascii=False)
        # Merge keywords, de-duplicating while preserving first-seen
        # order, and cap the stored list at 12 entries.
        if keywords and existing.task_keywords:
            merged_kw = list(dict.fromkeys(existing.task_keywords.split(", ") + keywords.split(", ")))
            existing.task_keywords = ", ".join(merged_kw[:12])
        elif keywords:
            existing.task_keywords = keywords
        # Recompute the effectiveness score from the updated stats.
        existing.effectiveness_score = _calc_effectiveness(existing)
        # NOTE(review): utcnow() yields naive datetimes (and is
        # deprecated in 3.12+); confirm column tz-awareness before
        # switching to datetime.now(timezone.utc).
        existing.last_used_at = datetime.utcnow()
        existing.updated_at = datetime.utcnow()
        db.commit()
        db.refresh(existing)
        logger.info(
            "学习模式更新: scope=%s/%s category=%s runs=%d",
            scope_kind, scope_id, category, existing.total_runs,
        )
        return existing
    else:
        # First observation for this scope + category: create the row.
        # A failed first run still creates a pattern, but with a low
        # initial effectiveness score (0.3 vs 1.0).
        record = AgentLearningPattern(
            scope_kind=scope_kind,
            scope_id=scope_id,
            task_category=category,
            task_keywords=keywords,
            suggested_tools=json.dumps(suggested_tools, ensure_ascii=False),
            effectiveness_score=1.0 if success else 0.3,
            total_runs=1,
            successful_runs=1 if success else 0,
            avg_iterations=float(iterations),
            avg_tool_calls=float(tool_calls),
            last_used_at=datetime.utcnow(),
        )
        db.add(record)
        db.commit()
        db.refresh(record)
        logger.info(
            "学习模式创建: scope=%s/%s category=%s tools=%s",
            scope_kind, scope_id, category, suggested_tools,
        )
        return record
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _merge_tool_sequences(
|
|||
|
|
old_seq: List[str], new_seq: List[str]
|
|||
|
|
) -> List[str]:
|
|||
|
|
"""合并两套工具序列,保留新序列中工具及其顺序,并补充旧序列中独有的工具。"""
|
|||
|
|
seen = set(new_seq)
|
|||
|
|
result = list(new_seq)
|
|||
|
|
for t in old_seq:
|
|||
|
|
if t not in seen:
|
|||
|
|
result.append(t)
|
|||
|
|
seen.add(t)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _calc_effectiveness(pattern: AgentLearningPattern) -> float:
|
|||
|
|
"""计算模式的有效评分 (0-1)。"""
|
|||
|
|
if pattern.total_runs == 0:
|
|||
|
|
return 0.0
|
|||
|
|
success_rate = pattern.successful_runs / pattern.total_runs
|
|||
|
|
# 迭代效率:迭代越少越好,以 5 次为基准
|
|||
|
|
efficiency = max(0, 1.0 - (pattern.avg_iterations / 10.0))
|
|||
|
|
# 综合评分
|
|||
|
|
return round(success_rate * 0.6 + efficiency * 0.4, 3)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_relevant_patterns(
    db: Session, scope_kind: str, scope_id: str, query: str,
    top_k: int = 3,
) -> List[AgentLearningPattern]:
    """
    Load learned patterns relevant to *query* for the given scope.

    The primary match is on the query's task category, ordered by
    effectiveness. When nothing matches and the category is specific
    (not "general"), fall back to the scope's best patterns of any
    category.
    """
    category = _classify_task(query)

    # Base query restricted to this scope; .filter() returns new query
    # objects, so the base can be safely reused for the fallback.
    scoped = db.query(AgentLearningPattern).filter(
        AgentLearningPattern.scope_kind == scope_kind,
        AgentLearningPattern.scope_id == scope_id,
    )

    patterns = (
        scoped.filter(
            AgentLearningPattern.task_category == category,
            AgentLearningPattern.total_runs >= 1,
        )
        .order_by(AgentLearningPattern.effectiveness_score.desc())
        .limit(top_k)
        .all()
    )
    if patterns or category == "general":
        return patterns

    # Fallback: best patterns in this scope, regardless of category.
    return (
        scoped.order_by(AgentLearningPattern.effectiveness_score.desc())
        .limit(top_k)
        .all()
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def format_pattern_hint(
    patterns: List[AgentLearningPattern], query: str
) -> str:
    """
    Format learned patterns as a hint section for the system prompt.

    Only "mature" patterns are rendered (at least 2 runs and an
    effectiveness score of 0.3 or higher); patterns with an empty tool
    sequence are skipped. Returns "" when nothing qualifies.

    Args:
        patterns: Candidate patterns, typically from load_relevant_patterns.
        query: The current user query (kept for interface compatibility;
            not used in the formatting itself).

    Returns:
        A markdown section string, or "" when there is nothing to show.
    """
    if not patterns:
        return ""

    # Drop immature / low-signal patterns.
    mature_patterns = [p for p in patterns if p.total_runs >= 2 and p.effectiveness_score >= 0.3]
    if not mature_patterns:
        return ""

    # Display names for categories. Hoisted out of the loop below — the
    # original rebuilt this literal on every iteration.
    category_names = {
        "data_analysis": "数据分析",
        "code_generation": "代码生成",
        "code_review": "代码审查",
        "debugging": "调试排错",
        "file_operation": "文件操作",
        "web_search": "网络搜索",
        "network": "网络请求",
        "database": "数据库",
        "text_processing": "文本处理",
        "translation": "翻译",
        "testing": "测试",
        "general": "通用",
    }

    lines = ["## 历史经验参考(基于过往执行记录)"]
    for p in mature_patterns:
        # suggested_tools is stored as a JSON array string.
        tools = json.loads(p.suggested_tools) if p.suggested_tools else []
        if not tools:
            continue
        cat_name = category_names.get(p.task_category, p.task_category)
        tool_str = " → ".join(tools)
        lines.append(
            f"- [{cat_name}] 对于此类任务,历史经验推荐工具序列: {tool_str} "
            f"(成功{p.successful_runs}/{p.total_runs}次, 评分{p.effectiveness_score:.2f})"
        )

    if len(lines) == 1:
        return ""  # Header only: every pattern had an empty tool list.

    return "\n".join(lines)
|