diff --git a/backend/app/agent_runtime/core.py b/backend/app/agent_runtime/core.py index b921b16..bfe424b 100644 --- a/backend/app/agent_runtime/core.py +++ b/backend/app/agent_runtime/core.py @@ -24,6 +24,12 @@ from app.agent_runtime.context import AgentContext from app.agent_runtime.memory import AgentMemory from app.agent_runtime.tool_manager import AgentToolManager from app.core.exceptions import WorkflowExecutionError +from app.services.agent_learning_service import ( + extract_pattern_from_result, + format_pattern_hint, + load_relevant_patterns, + save_learning_pattern, +) logger = logging.getLogger(__name__) @@ -98,6 +104,8 @@ class AgentRuntime: self.on_llm_call = on_llm_call self._memory_context_loaded = False self._llm_invocations = 0 + # 自主学习作用域:bare 聊天用 "bare",Agent 用 "agent" + self._learning_scope_kind = "bare" if "bare" in str(_mem_scope) else "agent" # 预算回调:供 WorkflowEngine 注入,使 Agent 内部计数计入工作流预算 # 返回 True 表示预算充足;返回 False 或抛出异常表示超限 @@ -221,6 +229,13 @@ class AgentRuntime: )) # 保存记忆 await self.memory.save_context(user_input, final_text, self.context.messages) + # 保存学习模式 + if self.config.memory.learning_enabled: + await self._save_learning_pattern( + user_input, steps, success=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + ) return AgentResult( success=True, content=final_text, @@ -318,6 +333,13 @@ class AgentRuntime: logger.warning("Agent 达到最大迭代次数 (%s)", max_iter) await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages) + # 保存学习模式(即便截断,工具调用模式仍有参考价值) + if self.config.memory.learning_enabled: + await self._save_learning_pattern( + user_input, steps, success=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + ) if last_content: steps.append(AgentStep( iteration=self.context.iteration, @@ -446,6 +468,13 @@ class AgentRuntime: "session_id": self.context.session_id, } await self.memory.save_context(user_input, final_text, self.context.messages) + # 保存学习模式 + if self.config.memory.learning_enabled: + await self._save_learning_pattern( + user_input, steps, success=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + ) return # 有工具调用 → 先记录 assistant 消息 @@ -557,6 +586,13 @@ class AgentRuntime: logger.warning("Agent 达到最大迭代次数 (%s)", max_iter) await self.memory.save_context(user_input, last_content or "(已达最大迭代次数)", self.context.messages) + # 保存学习模式(即便截断,工具调用模式仍有参考价值) + if self.config.memory.learning_enabled: + await self._save_learning_pattern( + user_input, steps, success=True, + iterations_used=self.context.iteration, + tool_calls_made=self.context.tool_calls_made, + ) yield { "type": "final", "content": last_content or "已达最大迭代次数,但模型未返回最终回答。", @@ -570,14 +606,62 @@ class AgentRuntime: async def _inject_memory_context(self, query: str = "") -> None: """加载长期记忆并注入 system prompt。""" mem_text = await self.memory.initialize(query=query) + enriched = self.config.system_prompt.rstrip("\n") + if mem_text: - enriched = ( - self.config.system_prompt.rstrip("\n") - + "\n\n" - + mem_text + enriched += "\n\n" + mem_text + + # 注入学习模式提示(历史工具使用建议) + if self.config.memory.learning_enabled: + pattern_hint = await self._inject_learning_patterns(query) + if pattern_hint: + enriched += "\n\n" + pattern_hint + + self.context.set_system_prompt(enriched) + logger.info("Agent 已注入长期记忆上下文") + + async def _inject_learning_patterns(self, query: str) -> str: + """查询学习模式,返回格式化的提示文本。""" + from app.core.database import SessionLocal + db = None + try: + db = SessionLocal() + patterns = load_relevant_patterns( + db, self._learning_scope_kind, self.memory.scope_id, query ) - self.context.set_system_prompt(enriched) - logger.info("Agent 已注入长期记忆上下文") + return format_pattern_hint(patterns, query) + except Exception as e: + logger.warning("加载学习模式失败: %s", e) + return "" + finally: + if db: + db.close() + + async def _save_learning_pattern( + self, query: str, steps: List[AgentStep], + success: bool, iterations_used: int, tool_calls_made: int, + ) -> None: + """从执行结果中提取模式并保存。""" + from app.core.database import SessionLocal + db = None + try: + db = SessionLocal() + pattern_data = extract_pattern_from_result( + query=query, + steps=steps, + success=success, + iterations_used=iterations_used, + tool_calls_made=tool_calls_made, + ) + save_learning_pattern( + db, self._learning_scope_kind, + self.memory.scope_id, pattern_data, + ) + except Exception as e: + logger.warning("保存学习模式失败: %s", e) + finally: + if db: + db.close() @staticmethod def _extract_tool_calls(response: Any) -> List[Dict[str, Any]]: diff --git a/backend/app/agent_runtime/schemas.py b/backend/app/agent_runtime/schemas.py index 4e0f39f..abeaa71 100644 --- a/backend/app/agent_runtime/schemas.py +++ b/backend/app/agent_runtime/schemas.py @@ -22,6 +22,7 @@ class AgentMemoryConfig(BaseModel): persist_to_db: bool = True # 是否写入 MySQL 长期记忆 vector_memory_enabled: bool = True # 是否启用向量记忆(语义检索) vector_memory_top_k: int = 5 # 向量检索 Top-K + learning_enabled: bool = True # 是否启用自主学习(工具模式学习) class AgentLLMConfig(BaseModel): diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 4cb9ad5..fbaa82e 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -48,5 +48,6 @@ def init_db(): import app.models.alert_rule import app.models.agent_llm_log import app.models.agent_vector_memory + import app.models.agent_learning_pattern import app.models.knowledge_base Base.metadata.create_all(bind=engine) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 13b3fda..8640e50 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -14,6 +14,7 @@ from app.models.alert_rule import AlertRule, AlertLog from app.models.persistent_user_memory import PersistentUserMemory from app.models.agent_llm_log import AgentLLMLog from app.models.agent_vector_memory import AgentVectorMemory +from app.models.agent_learning_pattern import AgentLearningPattern from app.models.knowledge_base import KnowledgeBase, Document, DocumentChunk -__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "KnowledgeBase", "Document", "DocumentChunk"] \ No newline at end of file +__all__ = ["User", "Workflow", "WorkflowVersion", "Agent", "Execution", "ExecutionLog", "ModelConfig", "DataSource", "WorkflowTemplate", "TemplateRating", "TemplateFavorite", "NodeTemplate", "Role", "Permission", "WorkflowPermission", "AgentPermission", "AlertRule", "AlertLog", "PersistentUserMemory", "AgentLLMLog", "AgentVectorMemory", "AgentLearningPattern", "KnowledgeBase", "Document", "DocumentChunk"] \ No newline at end of file diff --git a/backend/app/models/agent_learning_pattern.py b/backend/app/models/agent_learning_pattern.py new file mode 100644 index 0000000..58f0162 --- /dev/null +++ b/backend/app/models/agent_learning_pattern.py @@ -0,0 +1,25 @@ +"""Agent 自主学习模式表:记录工具使用模式,支持自主学习优化""" +import uuid +from datetime import datetime +from sqlalchemy import Column, String, Text, Float, Integer, DateTime +from app.core.database import Base + + +class AgentLearningPattern(Base): + """Agent 学习模式 — 记录工具调用序列与任务类型的关系""" + __tablename__ = "agent_learning_patterns" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + scope_kind = Column(String(16), nullable=False, index=True, comment="作用域类型: agent/bare") + scope_id = Column(String(64), nullable=False, index=True, comment="作用域 ID: agent_id/user_id") + task_category = Column(String(64), nullable=False, default="general", comment="任务分类") + task_keywords = Column(String(256), default="", comment="任务关键词") + suggested_tools = Column(Text, nullable=False, comment="推荐工具序列 (JSON array)") + effectiveness_score = Column(Float, default=0.0, comment="有效评分 0-1") + total_runs = Column(Integer, default=1, comment="总运行次数") + successful_runs = Column(Integer, default=1, comment="成功次数") + avg_iterations = Column(Float, default=1.0, comment="平均迭代次数") + avg_tool_calls = Column(Float, default=1.0, comment="平均工具调用数") + last_used_at = Column(DateTime, default=datetime.utcnow, comment="最后使用时间") + created_at = Column(DateTime, default=datetime.utcnow, comment="创建时间") + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间") diff --git a/backend/app/services/agent_learning_service.py b/backend/app/services/agent_learning_service.py new file mode 100644 index 0000000..0ef0836 --- /dev/null +++ b/backend/app/services/agent_learning_service.py @@ -0,0 +1,340 @@ +""" +Agent 自主学习服务 — 从历史执行中提取工具使用模式,优化后续工具选择。 + +工作流程: +1. 执行后:提取工具序列 → 分类任务 → 保存/更新学习模式 +2. 执行前:查询匹配模式 → 注入 system prompt 作为工具选择建议 +""" +from __future__ import annotations + +import json +import logging +import re +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple +from sqlalchemy.orm import Session + +from app.core.database import SessionLocal +from app.models.agent_learning_pattern import AgentLearningPattern + +logger = logging.getLogger(__name__) + +# ─── 任务分类关键词映射 ────────────────────────────────────── + +_TASK_CATEGORIES = { + "data_analysis": [ + "分析", "统计", "csv", "数据", "平均值", "总和", "图表", "报表", + "aggregate", "analyze", "statistic", "correlation", + ], + "code_generation": [ + "写", "实现", "函数", "代码", "编程", "脚本", "算法", + "implement", "function", "code", "program", "script", + ], + "code_review": [ + "审查", "review", "检查", "代码质量", "重构", "优化", + "refactor", "code style", "lint", + ], + "debugging": [ + "调试", "bug", "错误", "异常", "崩溃", "报错", "修复", + "debug", "error", "exception", "crash", "fix", + ], + "file_operation": [ + "读取", "写入", "保存", "文件", "目录", "创建", "删除", + "read file", "write file", "save", "directory", + ], + "web_search": [ + "搜索", "查询", "查找", "找", "了解", "最新", + "search", "find", "look up", "what is", + ], + "network": [ + "请求", "api", "http", "网络", "curl", "网址", "网站", + "request", "endpoint", "network", "url", + ], + "database": [ + "数据库", "sql", "表", "查询", "mysql", "postgresql", + "database", "select", "insert", "update", "delete", + ], + "text_processing": [ + "摘要", "翻译", "总结", "提取", "格式化", "转换", + "summarize", "translate", "extract", "format", "convert", + ], + "translation": [ + "翻译", "译成", "translate", "translation", "localization", + ], + "testing": [ + "测试", "单元测试", "集成测试", "用例", "pytest", "jest", + "test", "unit test", "integration test", + ], +} + + +def _classify_task(query: str) -> str: + """根据用户查询内容分类任务类型。""" + q = query.lower().strip() + scores: List[Tuple[str, int]] = [] + for category, keywords in _TASK_CATEGORIES.items(): + score = sum(1 for kw in keywords if kw in q) + if score > 0: + scores.append((category, score)) + if not scores: + return "general" + scores.sort(key=lambda x: -x[1]) + return scores[0][0] + + +def _extract_keywords(query: str, max_words: int = 8) -> str: + """从查询中提取关键词。""" + # 移除常见停用词 + stop_words = { + "的", "了", "在", "是", "我", "有", "和", "就", "不", "人", "都", "一", + "一个", "上", "也", "很", "到", "说", "要", "去", "你", "会", "着", + "没有", "看", "好", "自己", "这", "他", "她", "它", "们", + "the", "a", "an", "is", "are", "was", "were", "be", "been", + "have", "has", "had", "do", "does", "did", "will", "would", + "can", "could", "should", "may", "might", "shall", + "i", "you", "he", "she", "it", "we", "they", + "this", "that", "these", "those", "my", "your", "his", "her", + "of", "in", "on", "at", "to", "for", "with", "by", "from", "as", + "and", "or", "but", "not", "no", "so", "if", + } + # 中文和英文分词 + words = re.findall(r'[\w]+', query.lower()) + keywords = [w for w in words if w.lower() not in stop_words and len(w) > 1] + return ", ".join(keywords[:max_words]) + + +def _extract_tool_sequence(steps: List[Dict[str, Any]]) -> List[str]: + """从执行步骤中提取工具调用序列(去重连续重复)。""" + tools = [] + last_tool = None + for step in steps: + if isinstance(step, dict): + tname = step.get("tool_name") or ( + step.get("tool_calls", [{}])[0].get("function", {}).get("name") + if step.get("tool_calls") else None + ) + else: + tname = getattr(step, "tool_name", None) + if tname and tname != last_tool: + tools.append(tname) + last_tool = tname + return tools + + +def extract_pattern_from_result( + query: str, + steps: List[Any], + success: bool, + iterations_used: int, + tool_calls_made: int, +) -> Dict[str, Any]: + """从 Agent 执行结果中提取学习模式数据。""" + category = _classify_task(query) + keywords = _extract_keywords(query) + tool_sequence = _extract_tool_sequence(steps) + + return { + "task_category": category, + "task_keywords": keywords, + "suggested_tools": tool_sequence, + "success": success, + "iterations_used": iterations_used, + "tool_calls_made": tool_calls_made, + } + + +def save_learning_pattern( + db: Session, + scope_kind: str, + scope_id: str, + pattern_data: Dict[str, Any], +) -> AgentLearningPattern: + """ + 保存/更新学习模式。 + 对同 scope + task_category 做 upsert,累积统计。 + """ + category = pattern_data["task_category"] + suggested_tools = pattern_data.get("suggested_tools", []) + success = pattern_data.get("success", True) + iterations = pattern_data.get("iterations_used", 0) + tool_calls = pattern_data.get("tool_calls_made", 0) + keywords = pattern_data.get("task_keywords", "") + + # 查询已有模式 + existing = ( + db.query(AgentLearningPattern) + .filter( + AgentLearningPattern.scope_kind == scope_kind, + AgentLearningPattern.scope_id == scope_id, + AgentLearningPattern.task_category == category, + ) + .first() + ) + + if existing: + # 更新统计 + n = existing.total_runs + 1 + existing.total_runs = n + if success: + existing.successful_runs += 1 + existing.avg_iterations = ( + (existing.avg_iterations * (n - 1) + iterations) / n + ) + existing.avg_tool_calls = ( + (existing.avg_tool_calls * (n - 1) + tool_calls) / n + ) + # 合并工具序列:保留出现频率高的 + if suggested_tools: + old_tools = json.loads(existing.suggested_tools) if existing.suggested_tools else [] + merged = _merge_tool_sequences(old_tools, suggested_tools) + existing.suggested_tools = json.dumps(merged, ensure_ascii=False) + # 合并关键词 + if keywords and existing.task_keywords: + merged_kw = list(dict.fromkeys(existing.task_keywords.split(", ") + keywords.split(", "))) + existing.task_keywords = ", ".join(merged_kw[:12]) + elif keywords: + existing.task_keywords = keywords + # 更新有效评分 + existing.effectiveness_score = _calc_effectiveness(existing) + existing.last_used_at = datetime.utcnow() + existing.updated_at = datetime.utcnow() + db.commit() + db.refresh(existing) + logger.info( + "学习模式更新: scope=%s/%s category=%s runs=%d", + scope_kind, scope_id, category, existing.total_runs, + ) + return existing + else: + # 创建新模式 + record = AgentLearningPattern( + scope_kind=scope_kind, + scope_id=scope_id, + task_category=category, + task_keywords=keywords, + suggested_tools=json.dumps(suggested_tools, ensure_ascii=False), + effectiveness_score=1.0 if success else 0.3, + total_runs=1, + successful_runs=1 if success else 0, + avg_iterations=float(iterations), + avg_tool_calls=float(tool_calls), + last_used_at=datetime.utcnow(), + ) + db.add(record) + db.commit() + db.refresh(record) + logger.info( + "学习模式创建: scope=%s/%s category=%s tools=%s", + scope_kind, scope_id, category, suggested_tools, + ) + return record + + +def _merge_tool_sequences( + old_seq: List[str], new_seq: List[str] +) -> List[str]: + """合并两套工具序列,保留新序列中工具及其顺序,并补充旧序列中独有的工具。""" + seen = set(new_seq) + result = list(new_seq) + for t in old_seq: + if t not in seen: + result.append(t) + seen.add(t) + return result + + +def _calc_effectiveness(pattern: AgentLearningPattern) -> float: + """计算模式的有效评分 (0-1)。""" + if pattern.total_runs == 0: + return 0.0 + success_rate = pattern.successful_runs / pattern.total_runs + # 迭代效率:迭代越少越好,以 5 次为基准 + efficiency = max(0, 1.0 - (pattern.avg_iterations / 10.0)) + # 综合评分 + return round(success_rate * 0.6 + efficiency * 0.4, 3) + + +def load_relevant_patterns( + db: Session, scope_kind: str, scope_id: str, query: str, + top_k: int = 3, +) -> List[AgentLearningPattern]: + """ + 加载与当前查询相关的学习模式。 + 通过任务分类匹配来筛选。 + """ + category = _classify_task(query) + + patterns = ( + db.query(AgentLearningPattern) + .filter( + AgentLearningPattern.scope_kind == scope_kind, + AgentLearningPattern.scope_id == scope_id, + AgentLearningPattern.task_category == category, + AgentLearningPattern.total_runs >= 1, + ) + .order_by(AgentLearningPattern.effectiveness_score.desc()) + .limit(top_k) + .all() + ) + + if not patterns and category != "general": + # 降级:匹配 general 或任意分类 + patterns = ( + db.query(AgentLearningPattern) + .filter( + AgentLearningPattern.scope_kind == scope_kind, + AgentLearningPattern.scope_id == scope_id, + ) + .order_by(AgentLearningPattern.effectiveness_score.desc()) + .limit(top_k) + .all() + ) + + return patterns + + +def format_pattern_hint( + patterns: List[AgentLearningPattern], query: str +) -> str: + """ + 将学习模式格式化为 system prompt 中的提示文本。 + 如果模式不够成熟(runs < 2),返回空字符串。 + """ + if not patterns: + return "" + + # 过滤掉模式不明显的数据(工具序列为空、运行次数太少) + mature_patterns = [p for p in patterns if p.total_runs >= 2 and p.effectiveness_score >= 0.3] + if not mature_patterns: + return "" + + lines = ["## 历史经验参考(基于过往执行记录)"] + for p in mature_patterns: + tools = json.loads(p.suggested_tools) if p.suggested_tools else [] + if not tools: + continue + category_names = { + "data_analysis": "数据分析", + "code_generation": "代码生成", + "code_review": "代码审查", + "debugging": "调试排错", + "file_operation": "文件操作", + "web_search": "网络搜索", + "network": "网络请求", + "database": "数据库", + "text_processing": "文本处理", + "translation": "翻译", + "testing": "测试", + "general": "通用", + } + cat_name = category_names.get(p.task_category, p.task_category) + tool_str = " → ".join(tools) + lines.append( + f"- [{cat_name}] 对于此类任务,历史经验推荐工具序列: {tool_str} " + f"(成功{p.successful_runs}/{p.total_runs}次, 评分{p.effectiveness_score:.2f})" + ) + + if len(lines) == 1: + return "" # 只有标题没有内容 + + return "\n".join(lines)