147 lines
4.8 KiB
Python
147 lines
4.8 KiB
Python
|
|
"""
Unit tests for the embedding service (Embedding 服务单元测试).
"""
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import patch, AsyncMock
|
||
|
|
from app.services.embedding_service import EmbeddingService
|
||
|
|
|
||
|
|
|
||
|
|
class TestEmbeddingService:
    """Unit tests for ``EmbeddingService`` (Embedding 服务测试).

    Covers the static helpers (cosine similarity, embedding
    (de)serialization), the in-memory ``similarity_search`` ranking and
    filtering, and the early-return paths of the embedding generators.
    """

    # ------------------------------------------------------------------
    # cosine_similarity (synchronous static method)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_cosine_similarity_identical(self):
        """Identical vectors have similarity 1."""
        # cosine_similarity is called synchronously everywhere in this class,
        # so this test needs neither ``async def`` nor the asyncio marker.
        a = [1.0, 0.0, 0.0]
        b = [1.0, 0.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(1.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_orthogonal(self):
        """Orthogonal vectors have similarity 0."""
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(0.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_opposite(self):
        """Opposite vectors have similarity -1."""
        a = [1.0, 0.0]
        b = [-1.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(-1.0, abs=1e-6)

    # ------------------------------------------------------------------
    # similarity_search
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_empty(self):
        """Searching an empty candidate list yields no results."""
        svc = EmbeddingService()
        results = await svc.similarity_search([1.0, 0.0], [], top_k=5)
        assert results == []

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_ordering(self):
        """Results are ranked by descending similarity, not by input order."""
        svc = EmbeddingService()
        # Cosine similarity ignores magnitude, so candidates must point in
        # different *directions* to get different scores (~0.11 vs ~0.99
        # here). The weaker match is listed first so the assertions can only
        # pass if the search actually sorts by score.
        entries = [
            {"content_text": "cats are independent", "embedding": [0.1, 0.9, 0.0]},
            {"content_text": "dogs are great pets", "embedding": [0.9, 0.1, 0.0]},
        ]
        query = [0.8, 0.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
        assert len(results) == 2
        assert results[0]["content_text"] == "dogs are great pets"
        assert results[1]["content_text"] == "cats are independent"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_top_k(self):
        """Only the ``top_k`` best-scoring entries are returned."""
        svc = EmbeddingService()
        # Entry i scores 1 / sqrt(1 + i**2) against the query: strictly
        # decreasing in i and always > 0. This avoids zero vectors, whose
        # cosine similarity is undefined, and avoids ties. min_score is
        # explicit so the result depends only on top_k truncation, not on
        # the service's default threshold.
        entries = [
            {"content_text": f"entry {i}", "embedding": [1.0, float(i)]}
            for i in range(10)
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
        assert len(results) == 3
        assert [r["content_text"] for r in results] == ["entry 0", "entry 1", "entry 2"]

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_min_score(self):
        """Entries scoring below ``min_score`` are filtered out."""
        svc = EmbeddingService()
        entries = [
            {"content_text": "close match", "embedding": [0.9, 0.0]},  # similarity 1.0
            {"content_text": "distant match", "embedding": [-0.5, 0.0]},  # similarity -1.0
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5)
        assert len(results) == 1
        assert results[0]["content_text"] == "close match"

    # ------------------------------------------------------------------
    # generate_embedding / generate_embeddings
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_empty(self):
        """Empty input text produces ``None`` instead of an embedding."""
        svc = EmbeddingService()
        result = await svc.generate_embedding("")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_no_api_key(self):
        """Returns ``None`` when no API key is configured (no client available)."""
        svc = EmbeddingService()
        # Simulates a missing API key by making the client factory yield None.
        # NOTE(review): patched with AsyncMock, i.e. this assumes
        # ``_get_client`` is a coroutine — confirm against the service.
        with patch.object(svc, "_get_client", AsyncMock(return_value=None)):
            result = await svc.generate_embedding("test")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embeddings_empty(self):
        """An empty batch produces an empty list."""
        svc = EmbeddingService()
        result = await svc.generate_embeddings([])
        assert result == []

    # ------------------------------------------------------------------
    # serialize_embedding / deserialize_embedding
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_serialize_deserialize(self):
        """Serialization followed by deserialization round-trips the vector."""
        emb = [0.1, 0.2, 0.3]
        serialized = EmbeddingService.serialize_embedding(emb)
        deserialized = EmbeddingService.deserialize_embedding(serialized)
        assert deserialized == emb

    @pytest.mark.unit
    def test_deserialize_invalid(self):
        """Unparseable input deserializes to an empty list."""
        result = EmbeddingService.deserialize_embedding("invalid json")
        assert result == []

    @pytest.mark.unit
    def test_deserialize_list_already(self):
        """An already-deserialized list is returned unchanged."""
        result = EmbeddingService.deserialize_embedding([1.0, 2.0])
        assert result == [1.0, 2.0]
|