- 新增 embedding_service(语义检索)、knowledge_service(RAG)、text_chunker、document_parser - 新增 tool_registry(自定义工具注册表)并完善工具市场 API(CRUD + code/http 执行) - 新增 agent_vector_memory / knowledge_base 模型及对应数据库表 - 实现 SSE 流式响应与 Agent 预算控制 - AgentChat.vue 集成 MainLayout 导航布局 - 完善测试体系:7 个新测试文件共 110 个测试覆盖 - 修复 conftest.py SQLite 内存数据库连接隔离问题 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
147 lines
4.8 KiB
Python
147 lines
4.8 KiB
Python
"""
Unit tests for the Embedding service.
"""
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
from unittest.mock import patch, AsyncMock
|
|
from app.services.embedding_service import EmbeddingService
|
|
|
|
|
|
class TestEmbeddingService:
    """Unit tests for ``EmbeddingService``.

    Covers the static cosine-similarity helper, the async similarity search
    (empty input, ordering, ``top_k`` and ``min_score`` limits), the early-exit
    paths of embedding generation, and embedding (de)serialization.

    ``EmbeddingService`` comes from the module-level import; the tests do not
    re-import it locally.
    """

    # ------------------------------------------------------------------
    # cosine_similarity (called synchronously on the class)
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_cosine_similarity_identical(self):
        """Identical vectors are expected to score 1.0."""
        a = [1.0, 0.0, 0.0]
        b = [1.0, 0.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(1.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_orthogonal(self):
        """Orthogonal vectors are expected to score 0.0."""
        a = [1.0, 0.0]
        b = [0.0, 1.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(0.0, abs=1e-6)

    @pytest.mark.unit
    def test_cosine_similarity_opposite(self):
        """Opposite vectors are expected to score -1.0."""
        a = [1.0, 0.0]
        b = [-1.0, 0.0]
        sim = EmbeddingService.cosine_similarity(a, b)
        assert sim == pytest.approx(-1.0, abs=1e-6)

    # ------------------------------------------------------------------
    # similarity_search
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_empty(self):
        """Searching an empty entry list is expected to return an empty list."""
        svc = EmbeddingService()
        results = await svc.similarity_search([1.0, 0.0], [], top_k=5)
        assert results == []

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_ordering(self):
        """The entry closest to the query is expected to be ranked first.

        The two embeddings point in different directions so their scores
        differ under both cosine similarity and a plain dot product; vectors
        that are merely scaled copies of the query would all tie at cosine
        1.0. The best match is deliberately placed *second* in the input so
        the assertion cannot pass just because input order was preserved.
        """
        svc = EmbeddingService()
        entries = [
            {"content_text": "cats are independent", "embedding": [0.1, 0.9, 0.0]},
            {"content_text": "dogs are great pets", "embedding": [0.9, 0.1, 0.0]},
        ]
        query = [0.8, 0.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3, min_score=0.0)
        assert len(results) == 2
        assert results[0]["content_text"] == "dogs are great pets"
        assert results[1]["content_text"] == "cats are independent"

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_top_k(self):
        """With ten candidates and ``top_k=3``, exactly three results are expected."""
        svc = EmbeddingService()
        entries = [
            {"content_text": f"entry {i}", "embedding": [float(i) / 10, 0.0]}
            for i in range(10)
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=3)
        assert len(results) == 3

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_similarity_search_min_score(self):
        """Entries scoring below ``min_score`` are expected to be dropped."""
        svc = EmbeddingService()
        entries = [
            {"content_text": "close match", "embedding": [0.9, 0.0]},
            # Points away from the query, so it must fall under the 0.5 cutoff.
            {"content_text": "distant match", "embedding": [-0.5, 0.0]},
        ]
        query = [1.0, 0.0]
        results = await svc.similarity_search(query, entries, top_k=5, min_score=0.5)
        assert len(results) == 1
        assert results[0]["content_text"] == "close match"

    # ------------------------------------------------------------------
    # generate_embedding / generate_embeddings
    # ------------------------------------------------------------------

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_empty(self):
        """An empty input string is expected to produce ``None``."""
        svc = EmbeddingService()
        result = await svc.generate_embedding("")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embedding_no_api_key(self):
        """Without an API key (``_get_client`` patched to return ``None``), expect ``None``."""
        svc = EmbeddingService()
        with patch.object(svc, "_get_client", AsyncMock(return_value=None)):
            result = await svc.generate_embedding("test")
        assert result is None

    @pytest.mark.unit
    @pytest.mark.asyncio
    async def test_generate_embeddings_empty(self):
        """An empty batch is expected to produce an empty list."""
        svc = EmbeddingService()
        result = await svc.generate_embeddings([])
        assert result == []

    # ------------------------------------------------------------------
    # serialize_embedding / deserialize_embedding
    # ------------------------------------------------------------------

    @pytest.mark.unit
    def test_serialize_deserialize(self):
        """Serializing then deserializing is expected to return the original vector."""
        emb = [0.1, 0.2, 0.3]
        serialized = EmbeddingService.serialize_embedding(emb)
        deserialized = EmbeddingService.deserialize_embedding(serialized)
        assert deserialized == emb

    @pytest.mark.unit
    def test_deserialize_invalid(self):
        """Unparseable input is expected to deserialize to an empty list."""
        result = EmbeddingService.deserialize_embedding("invalid json")
        assert result == []

    @pytest.mark.unit
    def test_deserialize_list_already(self):
        """Input that is already a list is expected to be returned with equal contents."""
        result = EmbeddingService.deserialize_embedding([1.0, 2.0])
        assert result == [1.0, 2.0]
|