# File: aiagent/backend/tests/test_knowledge_loop.py
# (625 lines, 24 KiB, Python)

"""
知识闭环端到端测试
验证 执行日志 知识提取 知识存储 检索 注入 Prompt 的完整闭环
"""
import pytest
from unittest.mock import patch, MagicMock
@pytest.mark.unit
@pytest.mark.knowledge
class TestKnowledgeExtractionPipeline:
    """Knowledge extraction pipeline: log → LLM extraction → KnowledgeEntry storage."""

    def test_agent_execution_log_has_extraction_flag(self, db_session):
        """The AgentExecutionLog model exposes a ``knowledge_extracted`` field."""
        from app.models.agent_execution_log import AgentExecutionLog
        import uuid

        log = AgentExecutionLog(
            id=str(uuid.uuid4()),
            agent_name="test_agent",
            input_text="如何优化数据库查询性能?",
            output_text="可以通过建立索引、优化SQL语句、使用连接池等方式提升性能。",
            success=True,
            iterations_used=3,
            tool_calls_made=5,
        )
        db_session.add(log)
        db_session.commit()

        # The flag was never passed to the constructor, so this pins the
        # value the model supplies on its own.
        assert log.knowledge_extracted is False
        assert log.success is True
        assert log.agent_name == "test_agent"

    def test_extract_from_execution_with_llm(self, db_session):
        """The extractor turns an execution log into knowledge (LLM mocked)."""
        from app.services.knowledge_extractor import KnowledgeExtractor
        import json

        log_data = {
            "input_text": "Nginx 502 Bad Gateway 错误应该如何排查和修复?请给出完整步骤",
            "output_text": "先检查后端服务是否正常运行,再查看 Nginx 错误日志定位 upstream 连接问题,确认超时配置和缓冲区大小是否合理。",
            "success": True,
            "tool_chain": [{"name": "web_search", "result": "ok"}],
            "iterations_used": 3,
            "tool_calls_made": 2,
        }
        mock_llm_response = json.dumps({
            "title": "Nginx 502 错误排查方法",
            "category": "best_practice",
            "tags": ["nginx", "502", "排查"],
            "situation": "Nginx 返回 502 Bad Gateway",
            "solution": "1.检查后端服务状态 2.查看error.log 3.调整超时配置",
            "caveats": "注意区分 502 和 504 的不同排查路径",
            "confidence": 0.85,
        })

        extractor = KnowledgeExtractor(llm_model="test-model")
        # Replace the LLM round-trip with a canned JSON payload.
        with patch.object(extractor, "_sync_llm_call", return_value=mock_llm_response):
            result = extractor.extract_from_execution(log_data)

        assert result is not None
        assert result["title"] == "Nginx 502 错误排查方法"
        assert result["category"] == "best_practice"
        assert result["confidence"] == 0.85
        assert len(result["tags"]) == 3
        assert "检查后端服务状态" in result["solution"]

    def test_extract_skips_low_confidence(self, db_session):
        """Low-confidence knowledge (below 0.3) should be skipped."""
        from app.services.knowledge_extractor import KnowledgeExtractor
        import json

        log_data = {
            "input_text": "测试问题" * 10,
            "output_text": "测试回答" * 20,
            "success": True,
            "tool_chain": [],
            "iterations_used": 1,
            "tool_calls_made": 0,
        }
        mock_llm_response = json.dumps({
            "title": "低质量知识",
            "category": "insight",
            "tags": [],
            "situation": "测试",
            "solution": "测试",
            "caveats": "",
            "confidence": 0.15,
        })

        extractor = KnowledgeExtractor(llm_model="test-model")
        with patch.object(extractor, "_sync_llm_call", return_value=mock_llm_response):
            result = extractor.extract_from_execution(log_data)

        assert result is None

    def test_extract_skips_marked_skip(self, db_session):
        """The extraction should be skipped when the LLM marks it ``skip``."""
        from app.services.knowledge_extractor import KnowledgeExtractor
        import json

        log_data = {
            "input_text": "你好" * 10,
            "output_text": "你好!有什么可以帮助你的吗?" * 10,
            "success": True,
            "tool_chain": [],
            "iterations_used": 1,
            "tool_calls_made": 0,
        }
        mock_llm_response = json.dumps({"skip": True, "reason": "对话过于简单"})

        extractor = KnowledgeExtractor(llm_model="test-model")
        with patch.object(extractor, "_sync_llm_call", return_value=mock_llm_response):
            result = extractor.extract_from_execution(log_data)

        assert result is None

    def test_extract_skips_short_content(self, db_session):
        """Too-short input/output should be skipped without calling the LLM."""
        from app.services.knowledge_extractor import KnowledgeExtractor

        log_data = {
            "input_text": "hi",
            "output_text": "hello",
            "success": True,
            "tool_chain": [],
            "iterations_used": 1,
            "tool_calls_made": 0,
        }

        extractor = KnowledgeExtractor()
        # Patch the LLM hook so the test can never reach a real model and so
        # the "LLM is not called" claim is actually verified.
        with patch.object(extractor, "_sync_llm_call") as mock_llm:
            result = extractor.extract_from_execution(log_data)

        assert result is None
        mock_llm.assert_not_called()

    def test_knowledge_entry_created_from_extraction(self, db_session):
        """Extracted knowledge should be stored correctly as a KnowledgeEntry."""
        from app.models.knowledge_entry import KnowledgeEntry
        import uuid

        entry = KnowledgeEntry(
            title="MySQL 死锁重试策略",
            category="bug_fix",
            tags=["mysql", "deadlock", "retry"],
            situation="高并发下 MySQL 发生死锁",
            solution="捕获 DeadlockError 后休眠随机毫秒数再重试最多3次",
            caveats="幂等操作才可安全重试",
            source_execution_ids=[str(uuid.uuid4())],
            source_agent_name="db_admin_agent",
            embedding_text="MySQL 死锁重试策略 高并发下 MySQL 发生死锁 捕获 DeadlockError 后休眠随机毫秒数再重试",
            extracted_by="llm_auto",
            confidence=0.9,
        )
        db_session.add(entry)
        db_session.commit()

        assert entry.id is not None
        assert entry.category == "bug_fix"
        assert entry.confidence == 0.9
        # These three were not passed above, so the assertions pin the
        # values the model supplies on its own.
        assert entry.retrieval_count == 0
        assert entry.success_rate is None
        assert entry.is_active is True

    def test_pipeline_dedup_by_title(self, db_session):
        """Pipeline dedup: an entry with the same title can be found by title."""
        from app.models.knowledge_entry import KnowledgeEntry

        title = "唯一标题去重测试"
        entry1 = KnowledgeEntry(
            title=title, category="insight", extracted_by="llm_auto", confidence=0.8,
        )
        db_session.add(entry1)
        db_session.commit()

        # NOTE(review): only the lookup half of dedup is exercised here; no
        # second insert with the same title is attempted.
        existing = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.title == title
        ).first()
        assert existing is not None
        assert existing.id == entry1.id
@pytest.mark.unit
@pytest.mark.knowledge
class TestKnowledgeRetrievalPipeline:
    """Knowledge retrieval pipeline: keyword search → ranking → formatted injection."""

    def test_query_by_keyword(self, db_session):
        """A keyword search via LIKE should find the matching entry."""
        from app.models.knowledge_entry import KnowledgeEntry
        from sqlalchemy import or_

        entry = KnowledgeEntry(
            title="Redis 缓存穿透防护",
            category="best_practice",
            situation="大量请求查询不存在的缓存key导致DB压力",
            solution="使用布隆过滤器或缓存空值短TTL",
            caveats="布隆过滤器有误判率",
            extracted_by="llm_auto",
            confidence=0.9,
        )
        db_session.add(entry)
        db_session.commit()

        # Simulate the retriever's keyword LIKE query: every keyword is
        # matched against title, situation and solution, OR-ed together.
        keywords = ["Redis", "缓存穿透"]
        searched_columns = (
            KnowledgeEntry.title,
            KnowledgeEntry.situation,
            KnowledgeEntry.solution,
        )
        conditions = [
            column.like(f"%{kw}%")
            for kw in keywords
            for column in searched_columns
        ]
        results = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.is_active.is_(True),
            or_(*conditions),
        ).all()

        assert len(results) >= 1
        # The query has no ORDER BY, so check membership rather than
        # relying on the inserted row happening to come back first.
        assert entry.id in {r.id for r in results}
        assert any(r.title == "Redis 缓存穿透防护" for r in results)

    def test_retrieval_increments_count(self, db_session):
        """``retrieval_count`` should be incremented by one after a retrieval."""
        from app.models.knowledge_entry import KnowledgeEntry

        entry = KnowledgeEntry(
            title="测试检索计数",
            category="insight",
            situation="测试场景",
            solution="测试方案",
            extracted_by="llm_auto",
            confidence=0.8,
            retrieval_count=0,
        )
        db_session.add(entry)
        db_session.commit()
        entry_id = entry.id

        # Simulate the counter update performed on retrieval.
        entry.retrieval_count = (entry.retrieval_count or 0) + 1
        db_session.commit()

        updated = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.id == entry_id
        ).first()
        assert updated is not None
        assert updated.retrieval_count == 1

    def test_retrieve_respects_top_k(self, db_session):
        """Retrieval should return at most ``top_k`` entries."""
        from app.models.knowledge_entry import KnowledgeEntry
        from sqlalchemy import or_

        db_session.add_all([
            KnowledgeEntry(
                title=f"共享关键词条目{i}",
                category="insight",
                situation=f"场景{i}",
                solution=f"方案{i}",
                extracted_by="llm_auto",
                confidence=0.5 + i * 0.1,
            )
            for i in range(5)
        ])
        db_session.commit()

        # Simulate the top_k limit.
        top_k = 3
        conditions = [
            KnowledgeEntry.title.like(f"%{kw}%") for kw in ["共享关键词条目"]
        ]
        results = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.is_active.is_(True),
            or_(*conditions),
        ).limit(top_k).all()

        # Five matching rows were inserted above, so the limit must bind
        # exactly; "<= top_k" would also pass on an empty result.
        assert len(results) == top_k

    def test_retrieve_filter_by_category(self, db_session):
        """Retrieval results should be filterable by category."""
        from app.models.knowledge_entry import KnowledgeEntry

        entry_bug = KnowledgeEntry(
            title="公共关键词Bug条目", category="bug_fix",
            situation="bug", solution="fix",
            extracted_by="llm_auto", confidence=0.8,
        )
        entry_bp = KnowledgeEntry(
            title="公共关键词最佳实践", category="best_practice",
            situation="practice", solution="do",
            extracted_by="llm_auto", confidence=0.9,
        )
        db_session.add_all([entry_bug, entry_bp])
        db_session.commit()

        # Simulate the category filter.
        results = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.is_active.is_(True),
            KnowledgeEntry.category == "bug_fix",
        ).all()

        # Guard against a vacuous pass: the bug_fix row inserted above must
        # be returned and the best_practice row must not.
        result_ids = {r.id for r in results}
        assert entry_bug.id in result_ids
        assert entry_bp.id not in result_ids
        assert all(r.category == "bug_fix" for r in results)

    def test_format_for_prompt(self, db_session):
        """Formatted knowledge entries should contain the expected sections."""
        from app.services.knowledge_retriever import KnowledgeRetriever

        entries = [{
            "id": "test-1",
            "title": "测试知识",
            "category": "best_practice",
            "tags": ["test"],
            "situation": "测试场景",
            "solution": "执行以下步骤\n1. 步骤A\n2. 步骤B",
            "caveats": "注意安全",
            "confidence": 0.9,
        }]

        retriever = KnowledgeRetriever()
        formatted = retriever.format_for_prompt(entries)

        assert "相关知识库经验" in formatted
        assert "测试知识" in formatted
        assert "置信度: 90%" in formatted
        assert "步骤A" in formatted
        assert "注意安全" in formatted

    def test_inject_knowledge_appends_to_prompt(self, db_session):
        """``inject_knowledge`` should retrieve and append knowledge to the system prompt."""
        from app.models.knowledge_entry import KnowledgeEntry
        from app.services.knowledge_retriever import KnowledgeRetriever

        # Store a knowledge entry.
        entry = KnowledgeEntry(
            title="注入测试条目",
            category="insight",
            situation="注入测试",
            solution="应出现在结果中",
            extracted_by="llm_auto",
            confidence=0.9,
        )
        db_session.add(entry)
        db_session.commit()

        # Have the mocked ``retrieve`` return a known entry.
        mock_entries = [{
            "id": str(entry.id),
            "title": "注入测试条目",
            "category": "insight",
            "tags": [],
            "situation": "注入测试",
            "solution": "应出现在结果中",
            "caveats": "",
            "confidence": 0.9,
        }]

        retriever = KnowledgeRetriever()
        system_prompt = "你是一个有用的助手。"
        with patch.object(retriever, "retrieve", return_value=mock_entries):
            result = retriever.inject_knowledge(system_prompt, "注入测试")

        assert result.startswith(system_prompt)
        assert "相关知识库经验" in result
        assert "注入测试条目" in result

    def test_empty_retrieval_returns_original_prompt(self, db_session):
        """With no matches, ``inject_knowledge`` should return the original system prompt."""
        from app.services.knowledge_retriever import KnowledgeRetriever

        retriever = KnowledgeRetriever()
        system_prompt = "原始系统提示"
        with patch.object(retriever, "retrieve", return_value=[]):
            result = retriever.inject_knowledge(system_prompt, "不存在的关键词")

        assert result == system_prompt

    def test_knowledge_with_tags_stored(self, db_session):
        """Knowledge tags should be stored correctly."""
        from app.models.knowledge_entry import KnowledgeEntry

        tags = ["python", "async", "性能优化"]
        entry = KnowledgeEntry(
            title="Python 异步性能优化",
            category="optimization",
            tags=tags,
            situation="高并发异步任务",
            solution="使用 asyncio.gather 并发执行",
            extracted_by="llm_auto",
            confidence=0.85,
        )
        db_session.add(entry)
        db_session.commit()

        refreshed = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.id == entry.id
        ).first()
        assert refreshed is not None
        assert "python" in refreshed.tags
        assert len(refreshed.tags) == 3
@pytest.mark.unit
@pytest.mark.knowledge
class TestKnowledgeLoopEndToEnd:
    """Knowledge loop end to end: execute → extract → retrieve → inject → verify."""

    def test_full_loop_log_to_retrieval(self, db_session):
        """End to end: execution log → extracted knowledge → retrieval → prompt injection."""
        from app.models.agent_execution_log import AgentExecutionLog
        from app.models.knowledge_entry import KnowledgeEntry
        from app.services.knowledge_extractor import KnowledgeExtractor
        from app.services.knowledge_retriever import KnowledgeRetriever
        import json
        import uuid

        # Step 1: create the execution log.
        log = AgentExecutionLog(
            id=str(uuid.uuid4()),
            agent_name="devops_agent",
            input_text="生产环境 Redis 连接池耗尽导致服务不可用,如何预防?",
            output_text=(
                "1. 设置合理的 max_connections 上限\n"
                "2. 配置连接空闲超时和最大等待时间\n"
                "3. 添加连接池监控告警(可用连接数 < 20% 时告警)\n"
                "4. 使用连接池预热避免冷启动雪崩"
            ),
            success=True,
            iterations_used=4,
            tool_calls_made=3,
        )
        db_session.add(log)
        db_session.commit()

        # Step 2: simulate knowledge extraction (LLM mocked).
        mock_result = json.dumps({
            "title": "Redis 连接池耗尽预防策略",
            "category": "best_practice",
            "tags": ["redis", "连接池", "高可用"],
            "situation": "生产环境 Redis 连接池耗尽",
            "solution": "设置max_connections上限+空闲超时+监控告警+连接池预热",
            "caveats": "max_connections 需要根据实例规格合理设置",
            "confidence": 0.92,
        })
        extractor = KnowledgeExtractor()
        log_data = {
            "input_text": log.input_text,
            "output_text": log.output_text,
            "success": log.success,
            "tool_chain": [],
            "iterations_used": log.iterations_used,
            "tool_calls_made": log.tool_calls_made,
        }
        with patch.object(extractor, "_sync_llm_call", return_value=mock_result):
            knowledge = extractor.extract_from_execution(log_data)

        assert knowledge is not None
        assert knowledge["category"] == "best_practice"
        assert knowledge["confidence"] == 0.92

        # Step 3: create the KnowledgeEntry from the extracted fields.
        entry = KnowledgeEntry(
            title=knowledge["title"],
            category=knowledge["category"],
            tags=knowledge["tags"],
            situation=knowledge["situation"],
            solution=knowledge["solution"],
            caveats=knowledge["caveats"],
            source_execution_ids=[str(log.id)],
            source_agent_name=log.agent_name,
            embedding_text=f"{knowledge['title']} {knowledge['situation']} {knowledge['solution']}",
            extracted_by="llm_auto",
            confidence=knowledge["confidence"],
        )
        db_session.add(entry)
        db_session.commit()

        # Step 4: verify through a DB query that the knowledge was stored.
        stored = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.title == "Redis 连接池耗尽预防策略"
        ).first()
        assert stored is not None
        assert stored.category == "best_practice"
        assert stored.source_execution_ids == [str(log.id)]
        assert stored.source_agent_name == "devops_agent"

        # Step 5: format for the prompt. The real retrieval is bypassed;
        # format_for_prompt is exercised directly on the stored row.
        retriever = KnowledgeRetriever()
        mock_entries = [{
            "id": str(stored.id),
            "title": stored.title,
            "category": stored.category,
            "tags": stored.tags,
            "situation": stored.situation,
            "solution": stored.solution,
            "caveats": stored.caveats,
            "confidence": stored.confidence,
        }]
        enriched = retriever.format_for_prompt(mock_entries)
        assert "Redis 连接池耗尽预防策略" in enriched
        assert "相关知识库经验" in enriched

        # Step 6: verify injection into the system prompt through the real
        # inject_knowledge code path (retrieval still mocked). Asserting on
        # a hand-built "system_prompt + enriched" string would only test
        # string concatenation, not the code under test.
        system_prompt = "你是一个运维助手。"
        with patch.object(retriever, "retrieve", return_value=mock_entries):
            final_prompt = retriever.inject_knowledge(system_prompt, log.input_text)
        assert final_prompt.startswith(system_prompt)
        assert "相关知识库经验" in final_prompt
        assert "Redis 连接池耗尽预防策略" in final_prompt

    def test_multiple_executions_accumulate_knowledge(self, db_session):
        """Multiple executions should accumulate multiple knowledge entries."""
        from app.models.knowledge_entry import KnowledgeEntry
        import uuid

        entries_data = [
            ("Nginx配置优化", "optimization", 0.9),
            ("MySQL慢查询优化", "optimization", 0.85),
            ("Docker内存泄漏处理", "bug_fix", 0.8),
            ("API限流策略", "best_practice", 0.95),
            ("日志收集最佳实践", "best_practice", 0.88),
        ]
        db_session.add_all([
            KnowledgeEntry(
                title=title, category=cat, extracted_by="llm_auto",
                confidence=conf, embedding_text=title,
                source_execution_ids=[str(uuid.uuid4())],
            )
            for title, cat, conf in entries_data
        ])
        db_session.commit()

        all_entries = db_session.query(KnowledgeEntry).all()
        # Lower bounds only: other rows may already exist in the session.
        assert len(all_entries) >= len(entries_data)
        optimizations = [e for e in all_entries if e.category == "optimization"]
        assert len(optimizations) >= 2
        best_practices = [e for e in all_entries if e.category == "best_practice"]
        assert len(best_practices) >= 2

    def test_knowledge_entry_has_all_required_fields(self, db_session):
        """KnowledgeEntry should carry every field the loop needs."""
        from app.models.knowledge_entry import KnowledgeEntry
        import uuid

        entry = KnowledgeEntry(
            title="完整字段测试",
            category="insight",
            tags=["a", "b", "c"],
            situation="场景描述",
            solution="解决方案",
            caveats="注意事项",
            source_execution_ids=[str(uuid.uuid4()), str(uuid.uuid4())],
            source_agent_name="source_agent",
            source_model="deepseek-chat",
            embedding_text="embedding text",
            extracted_by="llm_auto",
            confidence=0.75,
            retrieval_count=5,
            success_rate=0.8,
            is_active=True,
        )
        db_session.add(entry)
        db_session.commit()

        refreshed = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.id == entry.id
        ).first()
        assert refreshed is not None
        assert refreshed.title == "完整字段测试"
        assert refreshed.confidence == 0.75
        assert refreshed.retrieval_count == 5
        assert refreshed.success_rate == 0.8
        assert refreshed.is_active is True
        assert len(refreshed.source_execution_ids) == 2
        assert len(refreshed.tags) == 3

    def test_deactivated_knowledge_not_retrieved(self, db_session):
        """Entries with ``is_active=False`` should not be retrieved."""
        from app.models.knowledge_entry import KnowledgeEntry

        inactive = KnowledgeEntry(
            title="已停用知识点", category="insight",
            situation="不应出现", solution="不应出现",
            extracted_by="llm_auto", confidence=0.9,
            is_active=False,
        )
        db_session.add(inactive)
        db_session.commit()

        # A query for active entries must not include the deactivated one.
        active_results = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.is_active.is_(True),
            KnowledgeEntry.title == "已停用知识点",
        ).all()
        assert len(active_results) == 0

        # The deactivated entry is nevertheless still in the database.
        all_results = db_session.query(KnowledgeEntry).filter(
            KnowledgeEntry.title == "已停用知识点",
        ).all()
        assert len(all_results) == 1
        assert all_results[0].is_active is False

    def test_extraction_pipeline_marks_log_processed(self, db_session):
        """After processing, the pipeline should mark the log as extracted."""
        from app.models.agent_execution_log import AgentExecutionLog
        import uuid

        log = AgentExecutionLog(
            id=str(uuid.uuid4()),
            agent_name="test_agent",
            input_text="生产环境数据库连接超时如何排查?",
            output_text="逐步排查1.检查网络连通性 2.检查连接池配置 3.检查数据库负载 4.检查慢查询",
            success=True,
            iterations_used=2,
            tool_calls_made=4,
        )
        db_session.add(log)
        db_session.commit()

        # Simulate the flag being set after extraction.
        log.knowledge_extracted = True
        db_session.commit()

        refreshed = db_session.query(AgentExecutionLog).filter(
            AgentExecutionLog.id == log.id
        ).first()
        assert refreshed is not None
        assert refreshed.knowledge_extracted is True