add redis lock on create collection in multiple thread mode (#3054)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2024-04-01 02:10:41 +08:00
committed by GitHub
parent 1716ac562c
commit 84d118de07
4 changed files with 128 additions and 105 deletions

View File

@@ -20,6 +20,7 @@ from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
@@ -77,6 +78,17 @@ class QdrantVector(BaseVector):
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
lock_name = 'vector_indexing_lock_{}'.format(collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = 'vector_indexing_{}'.format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
@@ -84,40 +96,35 @@ class QdrantVector(BaseVector):
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
# create collection
self.create_collection(collection_name, vector_size)
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
# creat full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
# create payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
# creat full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)