""" Agent 自主学习服务 — 从历史执行中提取工具使用模式,优化后续工具选择。 工作流程: 1. 执行后:提取工具序列 → 分类任务 → 保存/更新学习模式 2. 执行前:查询匹配模式 → 注入 system prompt 作为工具选择建议 """ from __future__ import annotations import json import logging import re from datetime import datetime from typing import Any, Dict, List, Optional, Tuple from sqlalchemy.orm import Session from app.core.database import SessionLocal from app.models.agent_learning_pattern import AgentLearningPattern logger = logging.getLogger(__name__) # ─── 任务分类关键词映射 ────────────────────────────────────── _TASK_CATEGORIES = { "data_analysis": [ "分析", "统计", "csv", "数据", "平均值", "总和", "图表", "报表", "aggregate", "analyze", "statistic", "correlation", ], "code_generation": [ "写", "实现", "函数", "代码", "编程", "脚本", "算法", "implement", "function", "code", "program", "script", ], "code_review": [ "审查", "review", "检查", "代码质量", "重构", "优化", "refactor", "code style", "lint", ], "debugging": [ "调试", "bug", "错误", "异常", "崩溃", "报错", "修复", "debug", "error", "exception", "crash", "fix", ], "file_operation": [ "读取", "写入", "保存", "文件", "目录", "创建", "删除", "read file", "write file", "save", "directory", ], "web_search": [ "搜索", "查询", "查找", "找", "了解", "最新", "search", "find", "look up", "what is", ], "network": [ "请求", "api", "http", "网络", "curl", "网址", "网站", "request", "endpoint", "network", "url", ], "database": [ "数据库", "sql", "表", "查询", "mysql", "postgresql", "database", "select", "insert", "update", "delete", ], "text_processing": [ "摘要", "翻译", "总结", "提取", "格式化", "转换", "summarize", "translate", "extract", "format", "convert", ], "translation": [ "翻译", "译成", "translate", "translation", "localization", ], "testing": [ "测试", "单元测试", "集成测试", "用例", "pytest", "jest", "test", "unit test", "integration test", ], } def _classify_task(query: str) -> str: """根据用户查询内容分类任务类型。""" q = query.lower().strip() scores: List[Tuple[str, int]] = [] for category, keywords in _TASK_CATEGORIES.items(): score = sum(1 for kw in keywords if kw in q) if score > 0: scores.append((category, score)) if not scores: return "general" scores.sort(key=lambda x: -x[1]) return scores[0][0] def _extract_keywords(query: str, max_words: int = 8) -> str: """从查询中提取关键词。""" # 移除常见停用词 stop_words = { "的", "了", "在", "是", "我", "有", "和", "就", "不", "人", "都", "一", "一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", "没有", "看", "好", "自己", "这", "他", "她", "它", "们", "the", "a", "an", "is", "are", "was", "were", "be", "been", "have", "has", "had", "do", "does", "did", "will", "would", "can", "could", "should", "may", "might", "shall", "i", "you", "he", "she", "it", "we", "they", "this", "that", "these", "those", "my", "your", "his", "her", "of", "in", "on", "at", "to", "for", "with", "by", "from", "as", "and", "or", "but", "not", "no", "so", "if", } # 中文和英文分词 words = re.findall(r'[\w]+', query.lower()) keywords = [w for w in words if w.lower() not in stop_words and len(w) > 1] return ", ".join(keywords[:max_words]) def _extract_tool_sequence(steps: List[Dict[str, Any]]) -> List[str]: """从执行步骤中提取工具调用序列(去重连续重复)。""" tools = [] last_tool = None for step in steps: if isinstance(step, dict): tname = step.get("tool_name") or ( step.get("tool_calls", [{}])[0].get("function", {}).get("name") if step.get("tool_calls") else None ) else: tname = getattr(step, "tool_name", None) if tname and tname != last_tool: tools.append(tname) last_tool = tname return tools def extract_pattern_from_result( query: str, steps: List[Any], success: bool, iterations_used: int, tool_calls_made: int, ) -> Dict[str, Any]: """从 Agent 执行结果中提取学习模式数据。""" category = _classify_task(query) keywords = _extract_keywords(query) tool_sequence = _extract_tool_sequence(steps) return { "task_category": category, "task_keywords": keywords, "suggested_tools": tool_sequence, "success": success, "iterations_used": iterations_used, "tool_calls_made": tool_calls_made, } def save_learning_pattern( db: Session, scope_kind: str, scope_id: str, pattern_data: Dict[str, Any], ) -> AgentLearningPattern: """ 保存/更新学习模式。 对同 scope + task_category 做 upsert,累积统计。 """ category = pattern_data["task_category"] suggested_tools = pattern_data.get("suggested_tools", []) success = pattern_data.get("success", True) iterations = pattern_data.get("iterations_used", 0) tool_calls = pattern_data.get("tool_calls_made", 0) keywords = pattern_data.get("task_keywords", "") # 查询已有模式 existing = ( db.query(AgentLearningPattern) .filter( AgentLearningPattern.scope_kind == scope_kind, AgentLearningPattern.scope_id == scope_id, AgentLearningPattern.task_category == category, ) .first() ) if existing: # 更新统计 n = existing.total_runs + 1 existing.total_runs = n if success: existing.successful_runs += 1 existing.avg_iterations = ( (existing.avg_iterations * (n - 1) + iterations) / n ) existing.avg_tool_calls = ( (existing.avg_tool_calls * (n - 1) + tool_calls) / n ) # 合并工具序列:保留出现频率高的 if suggested_tools: old_tools = json.loads(existing.suggested_tools) if existing.suggested_tools else [] merged = _merge_tool_sequences(old_tools, suggested_tools) existing.suggested_tools = json.dumps(merged, ensure_ascii=False) # 合并关键词 if keywords and existing.task_keywords: merged_kw = list(dict.fromkeys(existing.task_keywords.split(", ") + keywords.split(", "))) existing.task_keywords = ", ".join(merged_kw[:12]) elif keywords: existing.task_keywords = keywords # 更新有效评分 existing.effectiveness_score = _calc_effectiveness(existing) existing.last_used_at = datetime.utcnow() existing.updated_at = datetime.utcnow() db.commit() db.refresh(existing) logger.info( "学习模式更新: scope=%s/%s category=%s runs=%d", scope_kind, scope_id, category, existing.total_runs, ) return existing else: # 创建新模式 record = AgentLearningPattern( scope_kind=scope_kind, scope_id=scope_id, task_category=category, task_keywords=keywords, suggested_tools=json.dumps(suggested_tools, ensure_ascii=False), effectiveness_score=1.0 if success else 0.3, total_runs=1, successful_runs=1 if success else 0, avg_iterations=float(iterations), avg_tool_calls=float(tool_calls), last_used_at=datetime.utcnow(), ) db.add(record) db.commit() db.refresh(record) logger.info( "学习模式创建: scope=%s/%s category=%s tools=%s", scope_kind, scope_id, category, suggested_tools, ) return record def _merge_tool_sequences( old_seq: List[str], new_seq: List[str] ) -> List[str]: """合并两套工具序列,保留新序列中工具及其顺序,并补充旧序列中独有的工具。""" seen = set(new_seq) result = list(new_seq) for t in old_seq: if t not in seen: result.append(t) seen.add(t) return result def _calc_effectiveness(pattern: AgentLearningPattern) -> float: """计算模式的有效评分 (0-1)。""" if pattern.total_runs == 0: return 0.0 success_rate = pattern.successful_runs / pattern.total_runs # 迭代效率:迭代越少越好,以 5 次为基准 efficiency = max(0, 1.0 - (pattern.avg_iterations / 10.0)) # 综合评分 return round(success_rate * 0.6 + efficiency * 0.4, 3) def load_relevant_patterns( db: Session, scope_kind: str, scope_id: str, query: str, top_k: int = 3, ) -> List[AgentLearningPattern]: """ 加载与当前查询相关的学习模式。 通过任务分类匹配来筛选。 """ category = _classify_task(query) patterns = ( db.query(AgentLearningPattern) .filter( AgentLearningPattern.scope_kind == scope_kind, AgentLearningPattern.scope_id == scope_id, AgentLearningPattern.task_category == category, AgentLearningPattern.total_runs >= 1, ) .order_by(AgentLearningPattern.effectiveness_score.desc()) .limit(top_k) .all() ) if not patterns and category != "general": # 降级:匹配 general 或任意分类 patterns = ( db.query(AgentLearningPattern) .filter( AgentLearningPattern.scope_kind == scope_kind, AgentLearningPattern.scope_id == scope_id, ) .order_by(AgentLearningPattern.effectiveness_score.desc()) .limit(top_k) .all() ) return patterns def format_pattern_hint( patterns: List[AgentLearningPattern], query: str ) -> str: """ 将学习模式格式化为 system prompt 中的提示文本。 如果模式不够成熟(runs < 2),返回空字符串。 """ if not patterns: return "" # 过滤掉模式不明显的数据(工具序列为空、运行次数太少) mature_patterns = [p for p in patterns if p.total_runs >= 2 and p.effectiveness_score >= 0.3] if not mature_patterns: return "" lines = ["## 历史经验参考(基于过往执行记录)"] for p in mature_patterns: tools = json.loads(p.suggested_tools) if p.suggested_tools else [] if not tools: continue category_names = { "data_analysis": "数据分析", "code_generation": "代码生成", "code_review": "代码审查", "debugging": "调试排错", "file_operation": "文件操作", "web_search": "网络搜索", "network": "网络请求", "database": "数据库", "text_processing": "文本处理", "translation": "翻译", "testing": "测试", "general": "通用", } cat_name = category_names.get(p.task_category, p.task_category) tool_str = " → ".join(tools) lines.append( f"- [{cat_name}] 对于此类任务,历史经验推荐工具序列: {tool_str} " f"(成功{p.successful_runs}/{p.total_runs}次, 评分{p.effectiveness_score:.2f})" ) if len(lines) == 1: return "" # 只有标题没有内容 return "\n".join(lines)