chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Bowen Liang
2024-08-23 23:52:25 +08:00
committed by GitHub
parent 2da63654e5
commit b035c02f78
155 changed files with 4279 additions and 5925 deletions

View File

@@ -13,11 +13,15 @@ from xinference_client.types import Embedding
class MockTcvectordbClass:
def VectorDBClient(self, url=None, username='', key='',
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=5,
adapter: HTTPAdapter = None):
def VectorDBClient(
self,
url=None,
username="",
key="",
read_consistency: ReadConsistency = ReadConsistency.EVENTUAL_CONSISTENCY,
timeout=5,
adapter: HTTPAdapter = None,
):
self._conn = None
self._read_consistency = read_consistency
@@ -26,105 +30,96 @@ class MockTcvectordbClass:
Database(
conn=self._conn,
read_consistency=self._read_consistency,
name='dify',
)]
name="dify",
)
]
def list_collections(self, timeout: Optional[float] = None) -> list[Collection]:
return []
def drop_collection(self, name: str, timeout: Optional[float] = None):
return {
"code": 0,
"msg": "operation success"
}
return {"code": 0, "msg": "operation success"}
def create_collection(
self,
name: str,
shard: int,
replicas: int,
description: str,
index: Index,
embedding: Embedding = None,
timeout: float = None,
self,
name: str,
shard: int,
replicas: int,
description: str,
index: Index,
embedding: Embedding = None,
timeout: float = None,
) -> Collection:
return Collection(self, name, shard, replicas, description, index, embedding=embedding,
read_consistency=self._read_consistency, timeout=timeout)
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
collection = Collection(
return Collection(
self,
name,
shard=1,
replicas=2,
description=name,
timeout=timeout
shard,
replicas,
description,
index,
embedding=embedding,
read_consistency=self._read_consistency,
timeout=timeout,
)
def describe_collection(self, name: str, timeout: Optional[float] = None) -> Collection:
collection = Collection(self, name, shard=1, replicas=2, description=name, timeout=timeout)
return collection
def collection_upsert(
self,
documents: list[Document],
timeout: Optional[float] = None,
build_index: bool = True,
**kwargs
self, documents: list[Document], timeout: Optional[float] = None, build_index: bool = True, **kwargs
):
return {
"code": 0,
"msg": "operation success"
}
return {"code": 0, "msg": "operation success"}
def collection_search(
self,
vectors: list[list[float]],
filter: Filter = None,
params=None,
retrieve_vector: bool = False,
limit: int = 10,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
self,
vectors: list[list[float]],
filter: Filter = None,
params=None,
retrieve_vector: bool = False,
limit: int = 10,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
) -> list[list[dict]]:
return [[{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]]
return [[{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]]
def collection_query(
self,
document_ids: Optional[list] = None,
retrieve_vector: bool = False,
limit: Optional[int] = None,
offset: Optional[int] = None,
filter: Optional[Filter] = None,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
self,
document_ids: Optional[list] = None,
retrieve_vector: bool = False,
limit: Optional[int] = None,
offset: Optional[int] = None,
filter: Optional[Filter] = None,
output_fields: Optional[list[str]] = None,
timeout: Optional[float] = None,
) -> list[dict]:
return [{'metadata': '{"doc_id":"foo1"}', 'text': 'text', 'doc_id': 'foo1', 'score': 0.1}]
return [{"metadata": '{"doc_id":"foo1"}', "text": "text", "doc_id": "foo1", "score": 0.1}]
def collection_delete(
self,
document_ids: list[str] = None,
filter: Filter = None,
timeout: float = None,
self,
document_ids: list[str] = None,
filter: Filter = None,
timeout: float = None,
):
return {
"code": 0,
"msg": "operation success"
}
return {"code": 0, "msg": "operation success"}
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(VectorDBClient, '__init__', MockTcvectordbClass.VectorDBClient)
monkeypatch.setattr(VectorDBClient, 'list_databases', MockTcvectordbClass.list_databases)
monkeypatch.setattr(Database, 'collection', MockTcvectordbClass.describe_collection)
monkeypatch.setattr(Database, 'list_collections', MockTcvectordbClass.list_collections)
monkeypatch.setattr(Database, 'drop_collection', MockTcvectordbClass.drop_collection)
monkeypatch.setattr(Database, 'create_collection', MockTcvectordbClass.create_collection)
monkeypatch.setattr(Collection, 'upsert', MockTcvectordbClass.collection_upsert)
monkeypatch.setattr(Collection, 'search', MockTcvectordbClass.collection_search)
monkeypatch.setattr(Collection, 'query', MockTcvectordbClass.collection_query)
monkeypatch.setattr(Collection, 'delete', MockTcvectordbClass.collection_delete)
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient)
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)
monkeypatch.setattr(Database, "drop_collection", MockTcvectordbClass.drop_collection)
monkeypatch.setattr(Database, "create_collection", MockTcvectordbClass.create_collection)
monkeypatch.setattr(Collection, "upsert", MockTcvectordbClass.collection_upsert)
monkeypatch.setattr(Collection, "search", MockTcvectordbClass.collection_search)
monkeypatch.setattr(Collection, "query", MockTcvectordbClass.collection_query)
monkeypatch.setattr(Collection, "delete", MockTcvectordbClass.collection_delete)
yield

View File

@@ -26,6 +26,7 @@ class AnalyticdbVectorTest(AbstractVectorTest):
def run_all_tests(self):
self.vector.delete()
return super().run_all_tests()
def test_chroma_vector(setup_mock_redis):
AnalyticdbVectorTest().run_all_tests()
AnalyticdbVectorTest().run_all_tests()

View File

@@ -14,13 +14,13 @@ class ChromaVectorTest(AbstractVectorTest):
self.vector = ChromaVector(
collection_name=self.collection_name,
config=ChromaConfig(
host='localhost',
host="localhost",
port=8000,
tenant=chromadb.DEFAULT_TENANT,
database=chromadb.DEFAULT_DATABASE,
auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
auth_credentials="difyai123456",
)
),
)
def search_by_full_text(self):

View File

@@ -8,16 +8,11 @@ from tests.integration_tests.vdb.test_vector_store import (
class ElasticSearchVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self.vector = ElasticSearchVector(
index_name=self.collection_name.lower(),
config=ElasticSearchConfig(
host='http://localhost',
port='9200',
username='elastic',
password='elastic'
),
attributes=self.attributes
config=ElasticSearchConfig(host="http://localhost", port="9200", username="elastic", password="elastic"),
attributes=self.attributes,
)

View File

@@ -12,11 +12,11 @@ class MilvusVectorTest(AbstractVectorTest):
self.vector = MilvusVector(
collection_name=self.collection_name,
config=MilvusConfig(
host='localhost',
host="localhost",
port=19530,
user='root',
password='Milvus',
)
user="root",
password="Milvus",
),
)
def search_by_full_text(self):
@@ -25,7 +25,7 @@ class MilvusVectorTest(AbstractVectorTest):
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1

View File

@@ -21,7 +21,7 @@ class MyScaleVectorTest(AbstractVectorTest):
)
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1

View File

@@ -29,54 +29,55 @@ class TestOpenSearchVector:
self.example_doc_id = "example_doc_id"
self.vector = OpenSearchVector(
collection_name=self.collection_name,
config=OpenSearchConfig(
host='localhost',
port=9200,
user='admin',
password='password',
secure=False
)
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
)
self.vector._client = MagicMock()
@pytest.mark.parametrize("search_response, expected_length, expected_doc_id", [
({
'hits': {
'total': {'value': 1},
'hits': [
{'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}}
]
}
}, 1, "example_doc_id"),
({
'hits': {
'total': {'value': 0},
'hits': []
}
}, 0, None)
])
@pytest.mark.parametrize(
"search_response, expected_length, expected_doc_id",
[
(
{
"hits": {
"total": {"value": 1},
"hits": [
{
"_source": {
"page_content": get_example_text(),
"metadata": {"document_id": "example_doc_id"},
}
}
],
}
},
1,
"example_doc_id",
),
({"hits": {"total": {"value": 0}, "hits": []}}, 0, None),
],
)
def test_search_by_full_text(self, search_response, expected_length, expected_doc_id):
self.vector._client.search.return_value = search_response
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == expected_length
if expected_length > 0:
assert hits_by_full_text[0].metadata['document_id'] == expected_doc_id
assert hits_by_full_text[0].metadata["document_id"] == expected_doc_id
def test_search_by_vector(self):
vector = [0.1] * 128
mock_response = {
'hits': {
'total': {'value': 1},
'hits': [
"hits": {
"total": {"value": 1},
"hits": [
{
'_source': {
"_source": {
Field.CONTENT_KEY.value: get_example_text(),
Field.METADATA_KEY.value: {"document_id": self.example_doc_id}
Field.METADATA_KEY.value: {"document_id": self.example_doc_id},
},
'_score': 1.0
"_score": 1.0,
}
]
],
}
}
self.vector._client.search.return_value = mock_response
@@ -85,53 +86,45 @@ class TestOpenSearchVector:
print("Hits by vector:", hits_by_vector)
print("Expected document ID:", self.example_doc_id)
print("Actual document ID:", hits_by_vector[0].metadata['document_id'] if hits_by_vector else "No hits")
print("Actual document ID:", hits_by_vector[0].metadata["document_id"] if hits_by_vector else "No hits")
assert len(hits_by_vector) > 0, f"Expected at least one hit, got {len(hits_by_vector)}"
assert hits_by_vector[0].metadata['document_id'] == self.example_doc_id, \
f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
assert (
hits_by_vector[0].metadata["document_id"] == self.example_doc_id
), f"Expected document ID {self.example_doc_id}, got {hits_by_vector[0].metadata['document_id']}"
def test_get_ids_by_metadata_field(self):
mock_response = {
'hits': {
'total': {'value': 1},
'hits': [{'_id': 'mock_id'}]
}
}
mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
self.vector._client.search.return_value = mock_response
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
embedding = [0.1] * 128
with patch('opensearchpy.helpers.bulk') as mock_bulk:
with patch("opensearchpy.helpers.bulk") as mock_bulk:
mock_bulk.return_value = ([], [])
self.vector.add_texts([doc], [embedding])
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1
assert ids[0] == 'mock_id'
assert ids[0] == "mock_id"
def test_add_texts(self):
self.vector._client.index.return_value = {'result': 'created'}
self.vector._client.index.return_value = {"result": "created"}
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
embedding = [0.1] * 128
with patch('opensearchpy.helpers.bulk') as mock_bulk:
with patch("opensearchpy.helpers.bulk") as mock_bulk:
mock_bulk.return_value = ([], [])
self.vector.add_texts([doc], [embedding])
mock_response = {
'hits': {
'total': {'value': 1},
'hits': [{'_id': 'mock_id'}]
}
}
mock_response = {"hits": {"total": {"value": 1}, "hits": [{"_id": "mock_id"}]}}
self.vector._client.search.return_value = mock_response
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1
assert ids[0] == 'mock_id'
assert ids[0] == "mock_id"
@pytest.mark.usefixtures("setup_mock_redis")
class TestOpenSearchVectorWithRedis:
@@ -141,11 +134,11 @@ class TestOpenSearchVectorWithRedis:
def test_search_by_full_text(self):
self.tester.setup_method()
search_response = {
'hits': {
'total': {'value': 1},
'hits': [
{'_source': {'page_content': get_example_text(), 'metadata': {"document_id": "example_doc_id"}}}
]
"hits": {
"total": {"value": 1},
"hits": [
{"_source": {"page_content": get_example_text(), "metadata": {"document_id": "example_doc_id"}}}
],
}
}
expected_length = 1

View File

@@ -12,13 +12,13 @@ class PGVectoRSVectorTest(AbstractVectorTest):
self.vector = PGVectoRS(
collection_name=self.collection_name.lower(),
config=PgvectoRSConfig(
host='localhost',
host="localhost",
port=5431,
user='postgres',
password='difyai123456',
database='dify',
user="postgres",
password="difyai123456",
database="dify",
),
dim=128
dim=128,
)
def search_by_full_text(self):
@@ -27,8 +27,9 @@ class PGVectoRSVectorTest(AbstractVectorTest):
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 1
def test_pgvecot_rs(setup_mock_redis):
PGVectoRSVectorTest().run_all_tests()

View File

@@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import (
class QdrantVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self.vector = QdrantVector(
collection_name=self.collection_name,
group_id=self.dataset_id,
config=QdrantConfig(
endpoint='http://localhost:6333',
api_key='difyai123456',
)
endpoint="http://localhost:6333",
api_key="difyai123456",
),
)

View File

@@ -7,18 +7,22 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge
mock_client = MagicMock()
mock_client.list_databases.return_value = [{"name": "test"}]
class TencentVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.vector = TencentVector("dify", TencentConfig(
url="http://127.0.0.1",
api_key="dify",
timeout=30,
username="dify",
database="dify",
shard=1,
replicas=2,
))
self.vector = TencentVector(
"dify",
TencentConfig(
url="http://127.0.0.1",
api_key="dify",
timeout=30,
username="dify",
database="dify",
shard=1,
replicas=2,
),
)
def search_by_vector(self):
hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
@@ -28,8 +32,6 @@ class TencentVectorTest(AbstractVectorTest):
hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 0
def test_tencent_vector(setup_mock_redis,setup_tcvectordb_mock):
def test_tencent_vector(setup_mock_redis, setup_tcvectordb_mock):
TencentVectorTest().run_all_tests()

View File

@@ -10,7 +10,7 @@ from models.dataset import Dataset
def get_example_text() -> str:
return 'test_text'
return "test_text"
def get_example_document(doc_id: str) -> Document:
@@ -21,7 +21,7 @@ def get_example_document(doc_id: str) -> Document:
"doc_hash": doc_id,
"document_id": doc_id,
"dataset_id": doc_id,
}
},
)
return doc
@@ -45,7 +45,7 @@ class AbstractVectorTest:
def __init__(self):
self.vector = None
self.dataset_id = str(uuid.uuid4())
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + '_test'
self.collection_name = Dataset.gen_collection_name_by_id(self.dataset_id) + "_test"
self.example_doc_id = str(uuid.uuid4())
self.example_embedding = [1.001 * i for i in range(128)]
@@ -58,12 +58,12 @@ class AbstractVectorTest:
def search_by_vector(self):
hits_by_vector: list[Document] = self.vector.search_by_vector(query_vector=self.example_embedding)
assert len(hits_by_vector) == 1
assert hits_by_vector[0].metadata['doc_id'] == self.example_doc_id
assert hits_by_vector[0].metadata["doc_id"] == self.example_doc_id
def search_by_full_text(self):
hits_by_full_text: list[Document] = self.vector.search_by_full_text(query=get_example_text())
assert len(hits_by_full_text) == 1
assert hits_by_full_text[0].metadata['doc_id'] == self.example_doc_id
assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id
def delete_vector(self):
self.vector.delete()
@@ -76,14 +76,14 @@ class AbstractVectorTest:
documents = [get_example_document(doc_id=str(uuid.uuid4())) for _ in range(batch_size)]
embeddings = [self.example_embedding] * batch_size
self.vector.add_texts(documents=documents, embeddings=embeddings)
return [doc.metadata['doc_id'] for doc in documents]
return [doc.metadata["doc_id"] for doc in documents]
def text_exists(self):
assert self.vector.text_exists(self.example_doc_id)
def get_ids_by_metadata_field(self):
with pytest.raises(NotImplementedError):
self.vector.get_ids_by_metadata_field(key='key', value='value')
self.vector.get_ids_by_metadata_field(key="key", value="value")
def run_all_tests(self):
self.create_vector()

View File

@@ -10,15 +10,15 @@ from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, ge
@pytest.fixture
def tidb_vector():
return TiDBVector(
collection_name='test_collection',
collection_name="test_collection",
config=TiDBVectorConfig(
host="xxx.eu-central-1.xxx.aws.tidbcloud.com",
port="4000",
user="xxx.root",
password="xxxxxx",
database="dify",
program_name="langgenius/dify"
)
program_name="langgenius/dify",
),
)
@@ -40,7 +40,7 @@ class TiDBVectorTest(AbstractVectorTest):
assert len(hits_by_full_text) == 0
def get_ids_by_metadata_field(self):
ids = self.vector.get_ids_by_metadata_field(key='document_id', value=self.example_doc_id)
ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id)
assert len(ids) == 0
@@ -50,12 +50,12 @@ def test_tidb_vector(setup_mock_redis, setup_tidbvector_mock, tidb_vector, mock_
@pytest.fixture
def mock_session():
with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.Session', new_callable=MagicMock) as mock_session:
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.Session", new_callable=MagicMock) as mock_session:
yield mock_session
@pytest.fixture
def setup_tidbvector_mock(tidb_vector, mock_session):
with patch('core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine'):
with patch.object(tidb_vector._engine, 'connect'):
with patch("core.rag.datasource.vdb.tidb_vector.tidb_vector.create_engine"):
with patch.object(tidb_vector._engine, "connect"):
yield tidb_vector

View File

@@ -8,14 +8,14 @@ from tests.integration_tests.vdb.test_vector_store import (
class WeaviateVectorTest(AbstractVectorTest):
def __init__(self):
super().__init__()
self.attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"]
self.vector = WeaviateVector(
collection_name=self.collection_name,
config=WeaviateConfig(
endpoint='http://localhost:8080',
api_key='WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih',
endpoint="http://localhost:8080",
api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih",
),
attributes=self.attributes
attributes=self.attributes,
)