""" Embedding 服务单元测试 """ from __future__ import annotations import pytest from unittest.mock import patch, AsyncMock from app.services.embedding_service import EmbeddingService class TestEmbeddingService: """Embedding 服务测试""" @pytest.mark.unit @pytest.mark.asyncio async def test_cosine_similarity_identical(self): from app.services.embedding_service import EmbeddingService a = [1.0, 0.0, 0.0] b = [1.0, 0.0, 0.0] sim = EmbeddingService.cosine_similarity(a, b) assert sim == pytest.approx(1.0, abs=1e-6) @pytest.mark.unit def test_cosine_similarity_orthogonal(self): from app.services.embedding_service import EmbeddingService a = [1.0, 0.0] b = [0.0, 1.0] sim = EmbeddingService.cosine_similarity(a, b) assert sim == pytest.approx(0.0, abs=1e-6) @pytest.mark.unit def test_cosine_similarity_opposite(self): from app.services.embedding_service import EmbeddingService a = [1.0, 0.0] b = [-1.0, 0.0] sim = EmbeddingService.cosine_similarity(a, b) assert sim == pytest.approx(-1.0, abs=1e-6) @pytest.mark.unit @pytest.mark.asyncio async def test_similarity_search_empty(self): from app.services.embedding_service import EmbeddingService svc = EmbeddingService() results = await svc.similarity_search( [1.0, 0.0], [], top_k=5 ) assert results == [] @pytest.mark.unit @pytest.mark.asyncio async def test_similarity_search_ordering(self): from app.services.embedding_service import EmbeddingService svc = EmbeddingService() entries = [ {"content_text": "dogs are great pets", "embedding": [0.9, 0.0, 0.0]}, {"content_text": "cats are independent", "embedding": [0.1, 0.0, 0.0]}, ] query = [0.8, 0.0, 0.0] results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0) assert len(results) == 2 assert results[0]["content_text"] == "dogs are great pets" @pytest.mark.unit @pytest.mark.asyncio async def test_similarity_search_top_k(self): from app.services.embedding_service import EmbeddingService svc = EmbeddingService() entries = [ {"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]} for i in range(10) ] query = [1.0, 0.0] results = await svc.similarity_search(query, entries, top_k=3) assert len(results) == 3 @pytest.mark.unit @pytest.mark.asyncio async def test_similarity_search_min_score(self): from app.services.embedding_service import EmbeddingService svc = EmbeddingService() entries = [ {"content_text": "close match", "embedding": [0.9, 0.0]}, {"content_text": "distant match", "embedding": [-0.5, 0.0]}, ] query = [1.0, 0.0] results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5) assert len(results) == 1 assert results[0]["content_text"] == "close match" @pytest.mark.unit @pytest.mark.asyncio async def test_generate_embedding_empty(self): from app.services.embedding_service import EmbeddingService svc = EmbeddingService() result = await svc.generate_embedding("") assert result is None @pytest.mark.unit @pytest.mark.asyncio async def test_generate_embedding_no_api_key(self): """无 API Key 时返回 None""" from app.services.embedding_service import EmbeddingService svc = EmbeddingService() with patch.object(svc, "_get_client", AsyncMock(return_value=None)): result = await svc.generate_embedding("test") assert result is None @pytest.mark.unit @pytest.mark.asyncio async def test_generate_embeddings_empty(self): from app.services.embedding_service import EmbeddingService svc = EmbeddingService() result = await svc.generate_embeddings([]) assert result == [] @pytest.mark.unit def test_serialize_deserialize(self): from app.services.embedding_service import EmbeddingService emb = [0.1, 0.2, 0.3] serialized = EmbeddingService.serialize_embedding(emb) deserialized = EmbeddingService.deserialize_embedding(serialized) assert deserialized == emb @pytest.mark.unit def test_deserialize_invalid(self): from app.services.embedding_service import EmbeddingService result = EmbeddingService.deserialize_embedding("invalid json") assert result == [] @pytest.mark.unit def test_deserialize_list_already(self): from app.services.embedding_service import EmbeddingService result = EmbeddingService.deserialize_embedding([1.0, 2.0]) assert result == [1.0, 2.0]