- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
322 lines
10 KiB
Python
322 lines
10 KiB
Python
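"""
Agent 运行时单元测试:覆盖 AgentContext、AgentMemory(记忆系统)、AgentToolManager 与 Agent Schemas。

Agent runtime unit tests: AgentContext, AgentMemory, AgentToolManager and Agent schemas.
"""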
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import pytest
|
|
from unittest.mock import AsyncMock, patch, MagicMock
|
|
from datetime import datetime
|
|
|
|
|
|
class TestAgentContext:
    """Basic behaviour of ``AgentContext``.

    Covers construction, the message helpers (user / assistant / tool),
    the iteration and tool-call counters, ``reset()`` and system-prompt updates.
    """

    @pytest.mark.unit
    def test_context_initialization(self):
        from app.agent_runtime.context import AgentContext

        ctx = AgentContext(
            system_prompt="You are a helpful assistant.",
            user_id="test-user",
            session_id="test-session",
        )
        assert ctx.session_id == "test-session"
        assert ctx.user_id == "test-user"
        assert ctx.iteration == 0
        assert ctx.tool_calls_made == 0
        # The system prompt is exposed as the first entry of ``messages``.
        msgs = ctx.messages
        assert msgs[0]["role"] == "system"
        assert msgs[0]["content"] == "You are a helpful assistant."

    @pytest.mark.unit
    def test_add_user_message(self):
        from app.agent_runtime.context import AgentContext

        ctx = AgentContext(session_id="s1")
        ctx.add_user_message("Hello")
        assert len(ctx.messages) == 2  # system + user
        assert ctx.messages[1]["role"] == "user"
        assert ctx.messages[1]["content"] == "Hello"

    @pytest.mark.unit
    def test_add_assistant_message(self):
        from app.agent_runtime.context import AgentContext

        ctx = AgentContext(session_id="s1")
        ctx.add_user_message("Hi")
        ctx.add_assistant_message("Hello!", tool_calls=[{
            "id": "call_1",
            "type": "function",
            "function": {"name": "test", "arguments": "{}"},
        }])
        msgs = ctx.messages
        assistant_msgs = [m for m in msgs if m["role"] == "assistant"]
        assert len(assistant_msgs) == 1
        assert assistant_msgs[0]["content"] == "Hello!"
        # Tool calls passed to the helper must be carried on the message.
        assert "tool_calls" in assistant_msgs[0]

    @pytest.mark.unit
    def test_add_tool_result(self):
        from app.agent_runtime.context import AgentContext

        ctx = AgentContext(session_id="s1")
        ctx.add_tool_result("call_1", "test_tool", '{"result": "ok"}')
        tool_msgs = [m for m in ctx.messages if m["role"] == "tool"]
        assert len(tool_msgs) == 1
        assert tool_msgs[0]["tool_call_id"] == "call_1"
        assert tool_msgs[0]["name"] == "test_tool"

    @pytest.mark.unit
    def test_iteration_tracking(self):
        from app.agent_runtime.context import AgentContext

        # Pass session_id by keyword, like every other test here. A bare
        # positional "s1" would bind to the constructor's first parameter
        # (possibly system_prompt), not to the session id.
        ctx = AgentContext(session_id="s1")
        assert ctx.session_id == "s1"
        assert ctx.iteration == 0
        ctx.iteration += 1
        assert ctx.iteration == 1
        ctx.tool_calls_made += 2
        assert ctx.tool_calls_made == 2

    @pytest.mark.unit
    def test_context_reset(self):
        from app.agent_runtime.context import AgentContext

        ctx = AgentContext(system_prompt="Helpful.", session_id="s1")
        ctx.add_user_message("Hello")
        ctx.add_assistant_message("Hi")
        ctx.iteration = 5
        ctx.tool_calls_made = 3
        ctx.reset()
        assert ctx.iteration == 0
        assert ctx.tool_calls_made == 0
        # After reset, messages should contain only the system prompt.
        msgs = ctx.messages
        assert len(msgs) == 1
        assert msgs[0]["role"] == "system"

    @pytest.mark.unit
    def test_set_system_prompt(self):
        from app.agent_runtime.context import AgentContext

        ctx = AgentContext(system_prompt="Original.", session_id="s1")
        ctx.set_system_prompt("Updated.")
        # No messages have been added yet, so the prompt can still be updated.
        assert ctx.messages[0]["content"] == "Updated."
|
|
|
|
|
|
class TestAgentMemory:
    """Tests for the layered ``AgentMemory``.

    Covers initialisation and saving with persistence disabled, toggling
    vector memory, history trimming and history summarisation.
    """

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_initialize_no_persist(self):
        from app.agent_runtime.memory import AgentMemory

        mem = AgentMemory(
            scope_kind="agent",
            scope_id="test-agent",
            session_key="test-session",
            persist=False,
        )
        # With persistence disabled, initialize() is expected to return an
        # empty context string.
        assert await mem.initialize(query="Hello") == ""

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_save_context_no_persist(self):
        from app.agent_runtime.memory import AgentMemory

        mem = AgentMemory(
            scope_kind="agent",
            scope_id="test-agent",
            session_key="test-session",
            persist=False,
        )
        # The test passes as long as saving raises no exception.
        await mem.save_context(
            user_message="Hello",
            assistant_reply="Hi there!",
        )

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_vector_memory_disabled(self):
        from app.agent_runtime.memory import AgentMemory

        mem = AgentMemory(
            scope_kind="agent",
            scope_id="test-agent",
            session_key="test-session",
            persist=False,
            vector_memory_enabled=False,
        )
        # Saving with vector memory switched off must not raise.
        await mem.save_context(
            user_message="No vector",
            assistant_reply="OK",
        )

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_vector_search_no_results(self):
        from app.agent_runtime.memory import AgentMemory

        mem = AgentMemory(
            scope_kind="agent",
            scope_id="nonexistent",
            persist=False,
            vector_memory_enabled=True,
        )
        # An unknown scope yields no hits, which is reported as "".
        assert await mem._vector_search(query="test") == ""

    @pytest.mark.unit
    def test_trim_messages(self):
        from app.agent_runtime.memory import AgentMemory

        mem = AgentMemory(persist=False)
        history = [{"role": "system", "content": "You are helpful."}]
        for turn in range(30):
            history += [
                {"role": "user", "content": f"msg {turn}"},
                {"role": "assistant", "content": f"reply {turn}"},
            ]
        trimmed = mem.trim_messages(history)
        # max_history bounds the conversation; the system prompt may sit on top.
        assert len(trimmed) <= mem.max_history + 1
        assert trimmed[0]["role"] == "system"

    @pytest.mark.unit
    def test_summarize_history(self):
        from app.agent_runtime.memory import AgentMemory

        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
            {"role": "user", "content": "How are you?"},
        ]
        # The history contains two user turns; the summary is expected to
        # report "2 轮" ("2 rounds").
        assert "2 轮" in AgentMemory._summarize_history(conversation)
|
|
|
|
|
|
class TestToolManager:
    """Tests for ``AgentToolManager``.

    Covers include/exclude filters, tool-name extraction from tool specs,
    tool listing, and executing a built-in tool.
    """

    @pytest.mark.unit
    def test_include_filter(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        manager = AgentToolManager(include_tools=["math", "file_read"])
        # The include list is stored as a set; nothing is excluded.
        assert manager._include_tools == {"math", "file_read"}
        assert manager._exclude_tools == set()

    @pytest.mark.unit
    def test_exclude_filter(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        manager = AgentToolManager(exclude_tools=["dangerous_tool"])
        assert "dangerous_tool" in manager._exclude_tools

    @pytest.mark.unit
    def test_tool_name_extraction(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        # Nested spec: the name lives under the "function" key.
        spec = {"type": "function", "function": {"name": "my_tool"}}
        assert AgentToolManager._extract_tool_name(spec) == "my_tool"

    @pytest.mark.unit
    def test_tool_name_extraction_flat(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        # Flat spec: the name sits at the top level.
        spec = {"name": "flat_tool"}
        assert AgentToolManager._extract_tool_name(spec) == "flat_tool"

    @pytest.mark.unit
    def test_tool_name_extraction_empty(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        # A spec with no name yields None.
        assert AgentToolManager._extract_tool_name({}) is None

    @pytest.mark.unit
    def test_has_tools(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        # With no filters, the manager is expected to expose some tools.
        assert AgentToolManager().has_tools() is True

    @pytest.mark.unit
    def test_tool_names(self):
        from app.agent_runtime.tool_manager import AgentToolManager

        names = AgentToolManager().tool_names()
        assert isinstance(names, list)
        assert names  # at least one tool name

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_execute_delegates_to_registry(self):
        from app.agent_runtime.tool_manager import AgentToolManager
        # Not referenced below. NOTE(review): kept in case importing the
        # registry module is needed to register built-in tools — confirm.
        from app.services.tool_registry import tool_registry  # noqa: F401

        manager = AgentToolManager()
        # Run a built-in tool; its output is expected as a non-empty string.
        output = await manager.execute("datetime", {"format": "%Y"})
        assert isinstance(output, str)
        assert len(output) > 0
|
|
|
|
|
|
class TestAgentSchemas:
    """Default values and field wiring of the Agent schema models."""

    @pytest.mark.unit
    def test_agent_config_defaults(self):
        from app.agent_runtime.schemas import AgentConfig

        cfg = AgentConfig(name="Test Agent")
        assert cfg.name == "Test Agent"
        assert cfg.system_prompt == "你是一个有用的AI助手。请使用可用工具来帮助用户完成任务。"
        # Nested LLM settings fall back to their defaults.
        llm = cfg.llm
        assert llm.model == "gpt-4o-mini"
        assert llm.temperature == 0.7
        assert llm.max_iterations == 10

    @pytest.mark.unit
    def test_agent_memory_config_defaults(self):
        from app.agent_runtime.schemas import AgentMemoryConfig

        memory_cfg = AgentMemoryConfig()
        assert memory_cfg.enabled is True
        assert memory_cfg.vector_memory_enabled is True
        assert memory_cfg.vector_memory_top_k == 5
        assert memory_cfg.max_history_messages == 20

    @pytest.mark.unit
    def test_agent_budget_config_defaults(self):
        from app.agent_runtime.schemas import AgentBudgetConfig

        budget = AgentBudgetConfig()
        assert budget.max_llm_invocations == 200
        assert budget.max_tool_calls == 500

    @pytest.mark.unit
    def test_agent_result_fields(self):
        from app.agent_runtime.schemas import AgentResult, AgentStep

        steps = [
            AgentStep(iteration=1, type="think", content="Thinking..."),
            AgentStep(iteration=2, type="tool_call", content="Calling tool", tool_name="test"),
        ]
        result = AgentResult(
            content="Test result",
            iterations_used=3,
            tool_calls_made=5,
            truncated=False,
            steps=steps,
        )
        # `success` is not passed explicitly, so its default is being checked.
        assert result.success is True
        assert result.content == "Test result"
        assert result.iterations_used == 3
        assert len(result.steps) == 2
|