Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -13,11 +13,15 @@ from xinference_client.types import Embedding
|
||||
|
||||
|
||||
class MockTcvectordbClass:
|
||||
|
||||
def VectorDBClient(self, url=None, username='', key='',
|
||||
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
|
||||
timeout=5,
|
||||
adapter: HTTPAdapter = None):
|
||||
def VectorDBClient(
|
||||
self,
|
||||
url=None,
|
||||
username="",
|
||||
key="",
|
||||
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
|
||||
timeout=5,
|
||||
adapter: HTTPAdapter = None,
|
||||
):
|
||||
self._conn = None
|
||||
self._read_consistency = read_consistency
|
||||
|
||||
@@ -26,105 +30,96 @@ class MockTcvectordbClass:
|
||||
Database(
|
||||
conn=self._conn,
|
||||
read_consistency=self._read_consistency,
|
||||
name='dify',
|
||||
)]
|
||||
name="dify",
|
||||
)
|
||||
]
|
||||
|
||||
def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
|
||||
return []
|
||||
|
||||
def drop_collection(self, name: str, timeout: Optional[float] = None):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
def create_collection(
|
||||
self,
|
||||
name: str,
|
||||
shard: int,
|
||||
replicas: int,
|
||||
description: str,
|
||||
index: Index,
|
||||
embedding: Embedding = None,
|
||||
timeout: float = None,
|
||||
self,
|
||||
name: str,
|
||||
shard: int,
|
||||
replicas: int,
|
||||
description: str,
|
||||
index: Index,
|
||||
embedding: Embedding = None,
|
||||
timeout: float = None,
|
||||
) -> Collection:
|
||||
return Collection(self, name, shard, replicas, description, index, embedding=embedding,
|
||||
read_consistency=self._read_consistency, timeout=timeout)
|
||||
|
||||
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
|
||||
collection = Collection(
|
||||
return Collection(
|
||||
self,
|
||||
name,
|
||||
shard=1,
|
||||
replicas=2,
|
||||
description=name,
|
||||
timeout=timeout
|
||||
shard,
|
||||
replicas,
|
||||
description,
|
||||
index,
|
||||
embedding=embedding,
|
||||
read_consistency=self._read_consistency,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
|
||||
collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout)
|
||||
return collection
|
||||
|
||||
def collection_upsert(
|
||||
self,
|
||||
documents: list[Document],
|
||||
timeout: Optional[float] = None,
|
||||
build_index: bool = True,
|
||||
**kwargs
|
||||
self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs
|
||||
):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
def collection_search(
|
||||
self,
|
||||
vectors: list[list[float]],
|
||||
filter: Filter = None,
|
||||
params=None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
self,
|
||||
vectors: list[list[float]],
|
||||
filter: Filter = None,
|
||||
params=None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> list[list[dict]]:
|
||||
return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]
|
||||
return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]]
|
||||
|
||||
def collection_query(
|
||||
self,
|
||||
document_ids: Optional[list] = None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
filter: Optional[Filter] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
self,
|
||||
document_ids: Optional[list] = None,
|
||||
retrieve_vector: bool = False,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
filter: Optional[Filter] = None,
|
||||
output_fields: Optional[list[str]] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> list[dict]:
|
||||
return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]
|
||||
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
|
||||
|
||||
def collection_delete(
|
||||
self,
|
||||
document_ids: list[str] = None,
|
||||
filter: Filter = None,
|
||||
timeout: float = None,
|
||||
self,
|
||||
document_ids: list[str] = None,
|
||||
filter: Filter = None,
|
||||
timeout: float = None,
|
||||
):
|
||||
return {
|
||||
"code": 0,
|
||||
"msg": "operation success"
|
||||
}
|
||||
return {"code": 0, "msg": "operation success"}
|
||||
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
|
||||
monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
|
||||
monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
|
||||
monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
|
||||
monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
|
||||
monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
|
||||
monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
|
||||
monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
|
||||
monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)
|
||||
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient)
|
||||
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
|
||||
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)
|
||||
monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection)
|
||||
monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection)
|
||||
monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert)
|
||||
monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search)
|
||||
monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query)
|
||||
monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class AnalyticdbVectorTest(AbstractVectorTest):
|
||||
def run_all_tests(self):
|
||||
self.vector.delete()
|
||||
return super().run_all_tests()
|
||||
|
||||
|
||||
|
||||
def test_chroma_vector(setup_mock_redis):
|
||||
AnalyticdbVectorTest().run_all_tests()
|
||||
AnalyticdbVectorTest().run_all_tests()
|
||||
|
||||
@@ -14,13 +14,13 @@ class ChromaVectorTest(AbstractVectorTest):
|
||||
self.vector = ChromaVector(
|
||||
collection_name=self.collection_name,
|
||||
config=ChromaConfig(
|
||||
host='localhost',
|
||||
host="localhost",
|
||||
port=8000,
|
||||
tenant=chromadb.DEFAULT_TENANT,
|
||||
database=chromadb.DEFAULT_DATABASE,
|
||||
auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
|
||||
auth_credentials="difyai123456",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_full_text(self):
|
||||
|
||||
@@ -8,16 +8,11 @@ from tests.integration_tests.vdb.test_vector_store import (
|
||||
class ElasticSearchVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||
self.vector = ElasticSearchVector(
|
||||
index_name=self.collection_name.lower(),
|
||||
config=ElasticSearchConfig(
|
||||
host='http://localhost',
|
||||
port='9200',
|
||||
username='elastic',
|
||||
password='elastic'
|
||||
),
|
||||
attributes=self.attributes
|
||||
config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"),
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ class MilvusVectorTest(AbstractVectorTest):
|
||||
self.vector = MilvusVector(
|
||||
collection_name=self.collection_name,
|
||||
config=MilvusConfig(
|
||||
host='localhost',
|
||||
host="localhost",
|
||||
port=19530,
|
||||
user='root',
|
||||
password='Milvus',
|
||||
)
|
||||
user="root",
|
||||
password="Milvus",
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_full_text(self):
|
||||
@@ -25,7 +25,7 @@ class MilvusVectorTest(AbstractVectorTest):
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ class MyScaleVectorTest(AbstractVectorTest):
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
|
||||
|
||||
@@ -29,54 +29,55 @@ class TestOpenSearchVector:
|
||||
self.example_doc_id = "example_doc_id"
|
||||
self.vector = OpenSearchVector(
|
||||
collection_name=self.collection_name,
|
||||
config=OpenSearchConfig(
|
||||
host='localhost',
|
||||
port=9200,
|
||||
user='admin',
|
||||
password='password',
|
||||
secure=False
|
||||
)
|
||||
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
|
||||
)
|
||||
self.vector._client = MagicMock()
|
||||
|
||||
@pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [
|
||||
({
|
||||
'hits': {
|
||||
'total': {'value': 1},
|
||||
'hits': [
|
||||
{'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}}
|
||||
]
|
||||
}
|
||||
}, 1, "example_doc_id"),
|
||||
({
|
||||
'hits': {
|
||||
'total': {'value': 0},
|
||||
'hits': []
|
||||
}
|
||||
}, 0, None)
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"search_response, expected_length, expected_doc_id",
|
||||
[
|
||||
(
|
||||
{
|
||||
"hits": {
|
||||
"total": {"value": 1},
|
||||
"hits": [
|
||||
{
|
||||
"_source": {
|
||||
"page_content": get_example_text(),
|
||||
"metadata": {"document_id": "example_doc_id"},
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
1,
|
||||
"example_doc_id",
|
||||
),
|
||||
({"hits": {"total": {"value": 0}, "hits": []}}, 0, None),
|
||||
],
|
||||
)
|
||||
def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
|
||||
self.vector._client.search.return_value = search_response
|
||||
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == expected_length
|
||||
if expected_length > 0:
|
||||
assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id
|
||||
assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id
|
||||
|
||||
def test_search_by_vector(self):
|
||||
vector = [0.1] * 128
|
||||
mock_response = {
|
||||
'hits': {
|
||||
'total': {'value': 1},
|
||||
'hits': [
|
||||
"hits": {
|
||||
"total": {"value": 1},
|
||||
"hits": [
|
||||
{
|
||||
'_source': {
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: get_example_text(),
|
||||
Field.METADATA_KEY.value: {"document_id": self.example_doc_id}
|
||||
Field.METADATA_KEY.value: {"document_id": self.example_doc_id},
|
||||
},
|
||||
'_score': 1.0
|
||||
"_score": 1.0,
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
}
|
||||
self.vector._client.search.return_value = mock_response
|
||||
@@ -85,53 +86,45 @@ class TestOpenSearchVector:
|
||||
|
||||
print("Hits by vector:", hits_by_vector)
|
||||
print("Expected document ID:", self.example_doc_id)
|
||||
print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits")
|
||||
print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")
|
||||
|
||||
assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
|
||||
assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \
|
||||
f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
|
||||
assert (
|
||||
hits_by_vector[0].metadata["document_id"] == self.example_doc_id
|
||||
), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
|
||||
|
||||
def test_get_ids_by_metadata_field(self):
|
||||
mock_response = {
|
||||
'hits': {
|
||||
'total': {'value': 1},
|
||||
'hits': [{'_id': 'mock_id'}]
|
||||
}
|
||||
}
|
||||
mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
|
||||
self.vector._client.search.return_value = mock_response
|
||||
|
||||
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
|
||||
embedding = [0.1] * 128
|
||||
|
||||
with patch('opensearchpy.helpers.bulk') as mock_bulk:
|
||||
with patch("opensearchpy.helpers.bulk") as mock_bulk:
|
||||
mock_bulk.return_value = ([], [])
|
||||
self.vector.add_texts([doc], [embedding])
|
||||
|
||||
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
assert ids[0] == 'mock_id'
|
||||
assert ids[0] == "mock_id"
|
||||
|
||||
def test_add_texts(self):
|
||||
self.vector._client.index.return_value = {'result': 'created'}
|
||||
self.vector._client.index.return_value = {"result": "created"}
|
||||
|
||||
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
|
||||
embedding = [0.1] * 128
|
||||
|
||||
with patch('opensearchpy.helpers.bulk') as mock_bulk:
|
||||
with patch("opensearchpy.helpers.bulk") as mock_bulk:
|
||||
mock_bulk.return_value = ([], [])
|
||||
self.vector.add_texts([doc], [embedding])
|
||||
|
||||
mock_response = {
|
||||
'hits': {
|
||||
'total': {'value': 1},
|
||||
'hits': [{'_id': 'mock_id'}]
|
||||
}
|
||||
}
|
||||
mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
|
||||
self.vector._client.search.return_value = mock_response
|
||||
|
||||
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
assert ids[0] == 'mock_id'
|
||||
assert ids[0] == "mock_id"
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("setup_mock_redis")
|
||||
class TestOpenSearchVectorWithRedis:
|
||||
@@ -141,11 +134,11 @@ class TestOpenSearchVectorWithRedis:
|
||||
def test_search_by_full_text(self):
|
||||
self.tester.setup_method()
|
||||
search_response = {
|
||||
'hits': {
|
||||
'total': {'value': 1},
|
||||
'hits': [
|
||||
{'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}}
|
||||
]
|
||||
"hits": {
|
||||
"total": {"value": 1},
|
||||
"hits": [
|
||||
{"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}}
|
||||
],
|
||||
}
|
||||
}
|
||||
expected_length = 1
|
||||
|
||||
@@ -12,13 +12,13 @@ class PGVectoRSVectorTest(AbstractVectorTest):
|
||||
self.vector = PGVectoRS(
|
||||
collection_name=self.collection_name.lower(),
|
||||
config=PgvectoRSConfig(
|
||||
host='localhost',
|
||||
host="localhost",
|
||||
port=5431,
|
||||
user='postgres',
|
||||
password='difyai123456',
|
||||
database='dify',
|
||||
user="postgres",
|
||||
password="difyai123456",
|
||||
database="dify",
|
||||
),
|
||||
dim=128
|
||||
dim=128,
|
||||
)
|
||||
|
||||
def search_by_full_text(self):
|
||||
@@ -27,8 +27,9 @@ class PGVectoRSVectorTest(AbstractVectorTest):
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 1
|
||||
|
||||
|
||||
def test_pgvecot_rs(setup_mock_redis):
|
||||
PGVectoRSVectorTest().run_all_tests()
|
||||
|
||||
@@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import (
|
||||
class QdrantVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||
self.vector = QdrantVector(
|
||||
collection_name=self.collection_name,
|
||||
group_id=self.dataset_id,
|
||||
config=QdrantConfig(
|
||||
endpoint='http://localhost:6333',
|
||||
api_key='difyai123456',
|
||||
)
|
||||
endpoint="http://localhost:6333",
|
||||
api_key="difyai123456",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -7,18 +7,22 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge
|
||||
mock_client = MagicMock()
|
||||
mock_client.list_databases.return_value = [{"name": "test"}]
|
||||
|
||||
|
||||
class TencentVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vector = TencentVector("dify", TencentConfig(
|
||||
url="http://127.0.0.1",
|
||||
api_key="dify",
|
||||
timeout=30,
|
||||
username="dify",
|
||||
database="dify",
|
||||
shard=1,
|
||||
replicas=2,
|
||||
))
|
||||
self.vector = TencentVector(
|
||||
"dify",
|
||||
TencentConfig(
|
||||
url="http://127.0.0.1",
|
||||
api_key="dify",
|
||||
timeout=30,
|
||||
username="dify",
|
||||
database="dify",
|
||||
shard=1,
|
||||
replicas=2,
|
||||
),
|
||||
)
|
||||
|
||||
def search_by_vector(self):
|
||||
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
@@ -28,8 +32,6 @@ class TencentVectorTest(AbstractVectorTest):
|
||||
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock):
|
||||
|
||||
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
|
||||
TencentVectorTest().run_all_tests()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from models.dataset import Dataset
|
||||
|
||||
|
||||
def get_example_text() -> str:
|
||||
return 'test_text'
|
||||
return "test_text"
|
||||
|
||||
|
||||
def get_example_document(doc_id: str) -> Document:
|
||||
@@ -21,7 +21,7 @@ def get_example_document(doc_id: str) -> Document:
|
||||
"doc_hash": doc_id,
|
||||
"document_id": doc_id,
|
||||
"dataset_id": doc_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
return doc
|
||||
|
||||
@@ -45,7 +45,7 @@ class AbstractVectorTest:
|
||||
def __init__(self):
|
||||
self.vector = None
|
||||
self.dataset_id = str(uuid.uuid4())
|
||||
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test'
|
||||
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
|
||||
self.example_doc_id = str(uuid.uuid4())
|
||||
self.example_embedding = [1.001 * i for i in range(128)]
|
||||
|
||||
@@ -58,12 +58,12 @@ class AbstractVectorTest:
|
||||
def search_by_vector(self):
|
||||
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
|
||||
assert len(hits_by_vector) == 1
|
||||
assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
|
||||
assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id
|
||||
|
||||
def search_by_full_text(self):
|
||||
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
|
||||
assert len(hits_by_full_text) == 1
|
||||
assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
|
||||
assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id
|
||||
|
||||
def delete_vector(self):
|
||||
self.vector.delete()
|
||||
@@ -76,14 +76,14 @@ class AbstractVectorTest:
|
||||
documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
|
||||
embeddings = [self.example_embedding] * batch_size
|
||||
self.vector.add_texts(documents=documents, embeddings=embeddings)
|
||||
return [doc.metadata['doc_id'] for doc in documents]
|
||||
return [doc.metadata["doc_id"] for doc in documents]
|
||||
|
||||
def text_exists(self):
|
||||
assert self.vector.text_exists(self.example_doc_id)
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
with pytest.raises(NotImplementedError):
|
||||
self.vector.get_ids_by_metadata_field(key='key', value='value')
|
||||
self.vector.get_ids_by_metadata_field(key="key", value="value")
|
||||
|
||||
def run_all_tests(self):
|
||||
self.create_vector()
|
||||
|
||||
@@ -10,15 +10,15 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge
|
||||
@pytest.fixture
|
||||
def tidb_vector():
|
||||
return TiDBVector(
|
||||
collection_name='test_collection',
|
||||
collection_name="test_collection",
|
||||
config=TiDBVectorConfig(
|
||||
host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
|
||||
port="4000",
|
||||
user="xxx.root",
|
||||
password="xxxxxx",
|
||||
database="dify",
|
||||
program_name="langgenius/dify"
|
||||
)
|
||||
program_name="langgenius/dify",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class TiDBVectorTest(AbstractVectorTest):
|
||||
assert len(hits_by_full_text) == 0
|
||||
|
||||
def get_ids_by_metadata_field(self):
|
||||
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
|
||||
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
|
||||
assert len(ids) == 0
|
||||
|
||||
|
||||
@@ -50,12 +50,12 @@ def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session:
|
||||
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tidbvector_mock(tidb_vector, mock_session):
|
||||
with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'):
|
||||
with patch.object(tidb_vector._engine, 'connect'):
|
||||
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"):
|
||||
with patch.object(tidb_vector._engine, "connect"):
|
||||
yield tidb_vector
|
||||
|
||||
@@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import (
|
||||
class WeaviateVectorTest(AbstractVectorTest):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
|
||||
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
|
||||
self.vector = WeaviateVector(
|
||||
collection_name=self.collection_name,
|
||||
config=WeaviateConfig(
|
||||
endpoint='http://localhost:8080',
|
||||
api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
|
||||
endpoint="http://localhost:8080",
|
||||
api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
|
||||
),
|
||||
attributes=self.attributes
|
||||
attributes=self.attributes,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user