Files
aiagent/backend/tests/test_embedding_service.py
renjianbo 7b9e0826de feat: 向量记忆 RAG、工具市场、SSE 流式响应、前端集成与测试覆盖
- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser
- 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行)
- 新增 agent_vector_memory / knowledge_base 模型及对应数据库表
- 实现 SSE 流式响应与 Agent 预算控制
- AgentChat.vue 集成 MainLayout 导航布局
- 完善测试体系:7 个新测试文件共 110 个测试覆盖
- 修复 conftest.py SQLite 内存数据库连接隔离问题

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-01 22:30:46 +08:00

147 lines
4.8 KiB
Python

"""
Embedding 服务单元测试
"""
from __future__ import annotations
import pytest
from unittest.mock import patch, AsyncMock
from app.services.embedding_service import EmbeddingService
class TestEmbeddingService:
    """Unit tests for EmbeddingService.

    Covers the static cosine-similarity and (de)serialization helpers, the
    in-memory ``similarity_search`` ranking/filtering, and the empty-input /
    missing-client paths of the embedding generators.
    """

    # ------------------------------------------------------------------
    # cosine_similarity (synchronous static helper)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_cosine_similarity_identical(self):
        """Identical vectors have similarity 1."""
        a = [1.0, 0.0, 0.0]
        b = [1.0, 0.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(1.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_orthogonal(self):
        """Orthogonal vectors have similarity 0."""
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(0.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_opposite(self):
        """Opposite vectors have similarity -1."""
        a = [1.0, 0.0]
        b = [-1.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(-1.0, abs=1e-6)

    # ------------------------------------------------------------------
    # similarity_search (async)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_empty(self):
        """Searching an empty candidate list yields no results."""
        svc = EmbeddingService()
        results = await svc.similarity_search([1.0, 0.0], [], top_k=5)
        assert results == []

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_ordering(self):
        """Results are ranked by descending cosine similarity to the query.

        The vectors are deliberately NOT collinear with the query: cosine
        similarity ignores magnitude, so collinear vectors would all score
        1.0 and the test could not tell a sorted result from an unsorted one.
        The weaker match is listed first so that skipping the sort would
        fail the assertion.
        """
        svc = EmbeddingService()
        entries = [
            # cos ~= 0.110 with the query below
            {"content_text": "cats are independent", "embedding": [0.1, 0.9, 0.0]},
            # cos ~= 0.994 with the query below
            {"content_text": "dogs are great pets", "embedding": [0.9, 0.1, 0.0]},
        ]
        query = [0.8, 0.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
        assert len(results) == 2
        assert results[0]["content_text"] == "dogs are great pets"
        assert results[1]["content_text"] == "cats are independent"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_top_k(self):
        """At most ``top_k`` results are returned.

        Every non-zero entry is collinear with the query, so they all score
        the same; only the result count is checked. Entry 0 is a zero
        vector, and how it is scored is up to the implementation.
        """
        svc = EmbeddingService()
        entries = [
            {"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]}
            for i in range(10)
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3)
        assert len(results) == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_min_score(self):
        """Entries scoring below ``min_score`` are filtered out."""
        svc = EmbeddingService()
        entries = [
            {"content_text": "close match", "embedding": [0.9, 0.0]},  # cos = 1.0
            {"content_text": "distant match", "embedding": [-0.5, 0.0]},  # cos = -1.0
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5)
        assert len(results) == 1
        assert results[0]["content_text"] == "close match"

    # ------------------------------------------------------------------
    # generate_embedding / generate_embeddings (async)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_empty(self):
        """An empty input string produces ``None`` rather than an embedding."""
        svc = EmbeddingService()
        result = await svc.generate_embedding("")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_no_api_key(self):
        """Returns ``None`` when no client is available.

        A missing API key is simulated by making ``_get_client`` resolve to
        ``None``.
        """
        svc = EmbeddingService()
        with patch.object(svc, "_get_client", AsyncMock(return_value=None)):
            result = await svc.generate_embedding("test")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embeddings_empty(self):
        """A batch call with no texts returns an empty list."""
        svc = EmbeddingService()
        result = await svc.generate_embeddings([])
        assert result == []

    # ------------------------------------------------------------------
    # serialize_embedding / deserialize_embedding (static helpers)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_serialize_deserialize(self):
        """Serializing then deserializing round-trips the vector exactly."""
        emb = [0.1, 0.2, 0.3]
        serialized = EmbeddingService.serialize_embedding(emb)
        deserialized = EmbeddingService.deserialize_embedding(serialized)
        assert deserialized == emb

    @pytest.mark.unit
    def test_deserialize_invalid(self):
        """Malformed input deserializes to an empty list instead of raising."""
        result = EmbeddingService.deserialize_embedding("invalid json")
        assert result == []

    @pytest.mark.unit
    def test_deserialize_list_already(self):
        """A value that is already a list is returned unchanged."""
        result = EmbeddingService.deserialize_embedding([1.0, 2.0])
        assert result == [1.0, 2.0]