feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
321
backend/tests/test_agent_memory.py
Normal file
321
backend/tests/test_agent_memory.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Agent 记忆系统单元测试
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class TestAgentContext:
|
||||
"""AgentContext 基础功能测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_context_initialization(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(
|
||||
system_prompt="You are a helpful assistant.",
|
||||
user_id="test-user",
|
||||
session_id="test-session",
|
||||
)
|
||||
assert ctx.session_id == "test-session"
|
||||
assert ctx.user_id == "test-user"
|
||||
assert ctx.iteration == 0
|
||||
assert ctx.tool_calls_made == 0
|
||||
# system prompt 在 messages 中
|
||||
msgs = ctx.messages
|
||||
assert msgs[0]["role"] == "system"
|
||||
assert msgs[0]["content"] == "You are a helpful assistant."
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_user_message(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
ctx.add_user_message("Hello")
|
||||
assert len(ctx.messages) == 2 # system + user
|
||||
assert ctx.messages[1]["role"] == "user"
|
||||
assert ctx.messages[1]["content"] == "Hello"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_assistant_message(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
ctx.add_user_message("Hi")
|
||||
ctx.add_assistant_message("Hello!", tool_calls=[{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "test", "arguments": "{}"},
|
||||
}])
|
||||
msgs = ctx.messages
|
||||
assistant_msgs = [m for m in msgs if m["role"] == "assistant"]
|
||||
assert len(assistant_msgs) == 1
|
||||
assert assistant_msgs[0]["content"] == "Hello!"
|
||||
assert "tool_calls" in assistant_msgs[0]
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_tool_result(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(session_id="s1")
|
||||
ctx.add_tool_result("call_1", "test_tool", '{"result": "ok"}')
|
||||
tool_msgs = [m for m in ctx.messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) == 1
|
||||
assert tool_msgs[0]["tool_call_id"] == "call_1"
|
||||
assert tool_msgs[0]["name"] == "test_tool"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_iteration_tracking(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext("s1")
|
||||
assert ctx.iteration == 0
|
||||
ctx.iteration += 1
|
||||
assert ctx.iteration == 1
|
||||
ctx.tool_calls_made += 2
|
||||
assert ctx.tool_calls_made == 2
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_context_reset(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(system_prompt="Helpful.", session_id="s1")
|
||||
ctx.add_user_message("Hello")
|
||||
ctx.add_assistant_message("Hi")
|
||||
ctx.iteration = 5
|
||||
ctx.tool_calls_made = 3
|
||||
ctx.reset()
|
||||
assert ctx.iteration == 0
|
||||
assert ctx.tool_calls_made == 0
|
||||
# 重置后 messages 应仅含 system
|
||||
msgs = ctx.messages
|
||||
assert len(msgs) == 1
|
||||
assert msgs[0]["role"] == "system"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_set_system_prompt(self):
|
||||
from app.agent_runtime.context import AgentContext
|
||||
|
||||
ctx = AgentContext(system_prompt="Original.", session_id="s1")
|
||||
ctx.set_system_prompt("Updated.")
|
||||
# 未发送过消息,可以更新
|
||||
assert ctx.messages[0]["content"] == "Updated."
|
||||
|
||||
|
||||
class TestAgentMemory:
|
||||
"""AgentMemory 分层记忆测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_no_persist(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="test-agent",
|
||||
session_key="test-session",
|
||||
persist=False,
|
||||
)
|
||||
result = await memory.initialize(query="Hello")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_context_no_persist(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="test-agent",
|
||||
session_key="test-session",
|
||||
persist=False,
|
||||
)
|
||||
await memory.save_context(
|
||||
user_message="Hello",
|
||||
assistant_reply="Hi there!",
|
||||
)
|
||||
# 不报错即为通过
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_vector_memory_disabled(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="test-agent",
|
||||
session_key="test-session",
|
||||
persist=False,
|
||||
vector_memory_enabled=False,
|
||||
)
|
||||
await memory.save_context(
|
||||
user_message="No vector",
|
||||
assistant_reply="OK",
|
||||
)
|
||||
# 不报错即为通过
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_vector_search_no_results(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(
|
||||
scope_kind="agent",
|
||||
scope_id="nonexistent",
|
||||
persist=False,
|
||||
vector_memory_enabled=True,
|
||||
)
|
||||
result = await memory._vector_search(query="test")
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_trim_messages(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
memory = AgentMemory(persist=False)
|
||||
messages = [{"role": "system", "content": "You are helpful."}]
|
||||
for i in range(30):
|
||||
messages.append({"role": "user", "content": f"msg {i}"})
|
||||
messages.append({"role": "assistant", "content": f"reply {i}"})
|
||||
trimmed = memory.trim_messages(messages)
|
||||
assert len(trimmed) <= memory.max_history + 1 # +1 for system
|
||||
assert trimmed[0]["role"] == "system"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_summarize_history(self):
|
||||
from app.agent_runtime.memory import AgentMemory
|
||||
|
||||
history = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
summary = AgentMemory._summarize_history(history)
|
||||
assert "2 轮" in summary
|
||||
|
||||
|
||||
class TestToolManager:
|
||||
"""AgentToolManager 测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_include_filter(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager(include_tools=["math", "file_read"])
|
||||
assert mgr._include_tools == {"math", "file_read"}
|
||||
assert mgr._exclude_tools == set()
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_exclude_filter(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager(exclude_tools=["dangerous_tool"])
|
||||
assert "dangerous_tool" in mgr._exclude_tools
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_name_extraction(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
name = AgentToolManager._extract_tool_name({
|
||||
"type": "function",
|
||||
"function": {"name": "my_tool"},
|
||||
})
|
||||
assert name == "my_tool"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_name_extraction_flat(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
name = AgentToolManager._extract_tool_name({"name": "flat_tool"})
|
||||
assert name == "flat_tool"
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_name_extraction_empty(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
name = AgentToolManager._extract_tool_name({})
|
||||
assert name is None
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_has_tools(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager()
|
||||
assert mgr.has_tools() is True
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_tool_names(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
|
||||
mgr = AgentToolManager()
|
||||
names = mgr.tool_names()
|
||||
assert isinstance(names, list)
|
||||
assert len(names) >= 1
|
||||
|
||||
@pytest.mark.unit
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_delegates_to_registry(self):
|
||||
from app.agent_runtime.tool_manager import AgentToolManager
|
||||
from app.services.tool_registry import tool_registry
|
||||
|
||||
mgr = AgentToolManager()
|
||||
# 执行一个内置工具
|
||||
result = await mgr.execute("datetime", {"format": "%Y"})
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
class TestAgentSchemas:
|
||||
"""Agent Schemas 测试"""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_config_defaults(self):
|
||||
from app.agent_runtime.schemas import AgentConfig
|
||||
|
||||
config = AgentConfig(name="Test Agent")
|
||||
assert config.name == "Test Agent"
|
||||
assert config.system_prompt == "你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。"
|
||||
assert config.llm.model == "gpt-4o-mini"
|
||||
assert config.llm.temperature == 0.7
|
||||
assert config.llm.max_iterations == 10
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_memory_config_defaults(self):
|
||||
from app.agent_runtime.schemas import AgentMemoryConfig
|
||||
|
||||
cfg = AgentMemoryConfig()
|
||||
assert cfg.enabled is True
|
||||
assert cfg.vector_memory_enabled is True
|
||||
assert cfg.vector_memory_top_k == 5
|
||||
assert cfg.max_history_messages == 20
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_budget_config_defaults(self):
|
||||
from app.agent_runtime.schemas import AgentBudgetConfig
|
||||
|
||||
cfg = AgentBudgetConfig()
|
||||
assert cfg.max_llm_invocations == 200
|
||||
assert cfg.max_tool_calls == 500
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_result_fields(self):
|
||||
from app.agent_runtime.schemas import AgentResult, AgentStep
|
||||
|
||||
result = AgentResult(
|
||||
content="Test result",
|
||||
iterations_used=3,
|
||||
tool_calls_made=5,
|
||||
truncated=False,
|
||||
steps=[
|
||||
AgentStep(iteration=1, type="think", content="Thinking..."),
|
||||
AgentStep(iteration=2, type="tool_call", content="Calling tool", tool_name="test"),
|
||||
],
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.content == "Test result"
|
||||
assert result.iterations_used == 3
|
||||
assert len(result.steps) == 2
|
||||
Reference in New Issue
Block a user