Files
aiagent/backend/tests/test_embedding_service.py

147 lines
4.8 KiB
Python
Raw Normal View History

"""
Embedding 服务单元测试
"""
from __future__ import annotations
import pytest
from unittest.mock import patch, AsyncMock
from app.services.embedding_service import EmbeddingService
class TestEmbeddingService:
    """Unit tests for EmbeddingService.

    Relies on the module-level ``EmbeddingService`` import; the per-test
    re-imports that used to live in every method were redundant.
    """

    # ------------------------------------------------------------------
    # cosine_similarity (synchronous static helper)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_cosine_similarity_identical(self):
        """Identical vectors have similarity 1.0.

        ``cosine_similarity`` is called synchronously, so this test is a
        plain (non-async) test like its siblings below.
        """
        a = [1.0, 0.0, 0.0]
        b = [1.0, 0.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(1.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_scale_invariant(self):
        """Parallel vectors of different magnitude still score 1.0.

        Cosine similarity depends only on direction, so any test that needs
        distinct scores must use vectors pointing in different directions.
        """
        a = [1.0, 2.0, 3.0]
        b = [2.0, 4.0, 6.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(1.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_orthogonal(self):
        """Orthogonal vectors have similarity 0.0."""
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(0.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_opposite(self):
        """Opposite vectors have similarity -1.0."""
        a = [1.0, 0.0]
        b = [-1.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(-1.0, abs=1e-6)

    # ------------------------------------------------------------------
    # similarity_search
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_empty(self):
        """Searching an empty candidate list returns an empty list."""
        svc = EmbeddingService()
        results = await svc.similarity_search(
            [1.0, 0.0], [], top_k=5
        )
        assert results == []

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_ordering(self):
        """Results come back best match first.

        The embeddings point in different directions. Parallel vectors would
        all tie at cosine 1.0 regardless of magnitude, which is what the
        previous fixture did. The better match is listed *last*, so input
        order alone cannot satisfy the assertion.
        """
        svc = EmbeddingService()
        entries = [
            # cosine with query ~= 0.11
            {"content_text": "cats are independent", "embedding": [0.1, 0.9, 0.0]},
            # cosine with query ~= 0.99
            {"content_text": "dogs are great pets", "embedding": [0.9, 0.1, 0.0]},
        ]
        query = [1.0, 0.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
        assert len(results) == 2
        assert [r["content_text"] for r in results] == [
            "dogs are great pets",
            "cats are independent",
        ]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_top_k(self):
        """At most ``top_k`` results are returned.

        Every non-zero entry is parallel to the query, so scores are ~1.0 and
        only ``top_k`` limits the result size. Entry 0 is a zero vector.
        NOTE(review): this presumably relies on the service tolerating a
        zero-norm vector; confirm that this is intended coverage.
        """
        svc = EmbeddingService()
        entries = [
            {"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]} for i in range(10)
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3)
        assert len(results) == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_min_score(self):
        """Entries scoring below ``min_score`` are filtered out."""
        svc = EmbeddingService()
        entries = [
            {"content_text": "close match", "embedding": [0.9, 0.0]},     # cosine 1.0
            {"content_text": "distant match", "embedding": [-0.5, 0.0]},  # cosine -1.0
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5)
        assert len(results) == 1
        assert results[0]["content_text"] == "close match"

    # ------------------------------------------------------------------
    # generate_embedding(s)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_empty(self):
        """Empty input text yields None."""
        svc = EmbeddingService()
        result = await svc.generate_embedding("")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_no_api_key(self):
        """Return None when no API key / client is available.

        The missing key is simulated by patching ``_get_client`` to resolve
        to None.
        """
        svc = EmbeddingService()
        with patch.object(svc, "_get_client", AsyncMock(return_value=None)):
            result = await svc.generate_embedding("test")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embeddings_empty(self):
        """An empty batch yields an empty list."""
        svc = EmbeddingService()
        result = await svc.generate_embeddings([])
        assert result == []

    # ------------------------------------------------------------------
    # (de)serialization
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_serialize_deserialize(self):
        """Serialize followed by deserialize round-trips the vector."""
        emb = [0.1, 0.2, 0.3]
        serialized = EmbeddingService.serialize_embedding(emb)
        deserialized = EmbeddingService.deserialize_embedding(serialized)
        assert deserialized == emb

    @pytest.mark.unit
    def test_deserialize_invalid(self):
        """Unparseable input deserializes to an empty list."""
        result = EmbeddingService.deserialize_embedding("invalid json")
        assert result == []

    @pytest.mark.unit
    def test_deserialize_list_already(self):
        """A value that is already a list is returned as-is."""
        result = EmbeddingService.deserialize_embedding([1.0, 2.0])
        assert result == [1.0, 2.0]