Feat/dify rag (#2528)

Co-authored-by: jyong <jyong@dify.ai>
Jyong
2024-02-22 23:31:57 +08:00
committed by GitHub
parent 97fe817186
commit 6c4e6bf1d6
119 changed files with 3181 additions and 5892 deletions

View File

@@ -0,0 +1,10 @@
from enum import Enum
class Field(Enum):
CONTENT_KEY = "page_content"
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
TEXT_KEY = "text"
    PRIMARY_KEY = "id"
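
A quick sketch (illustrative, not part of the diff) of how these enum values act as the shared payload/field keys for the vector stores below; the record values are hypothetical:

from core.rag.datasource.vdb.field import Field

record = {
    Field.CONTENT_KEY.value: "some chunk text",                      # "page_content"
    Field.VECTOR.value: [0.1, 0.2, 0.3],                             # "vector"
    Field.METADATA_KEY.value: {"doc_id": "d1", "document_id": "x"},  # "metadata"
}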

View File

@@ -0,0 +1,214 @@
import logging
from typing import Any, Optional
from uuid import uuid4
from pydantic import BaseModel, root_validator
from pymilvus import MilvusClient, MilvusException, connections
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
logger = logging.getLogger(__name__)
class MilvusConfig(BaseModel):
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVector(BaseVector):
def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = 'Session'
self._fields = []
def get_type(self) -> str:
return 'milvus'
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
metadatas = [d.metadata for d in texts]
# Grab the existing collection if it exists
from pymilvus import utility
alias = uuid4().hex
if self._client_config.secure:
uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
else:
uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)
if not utility.has_collection(self._collection_name, using=alias):
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)
pks: list[str] = []
for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i:i + 1000]
# Insert into the collection.
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def delete_by_document_id(self, document_id: str):
ids = self.get_ids_by_metadata_field('document_id', document_id)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def get_ids_by_metadata_field(self, key: str, value: str):
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["{key}"] == "{value}"',
output_fields=["id"])
if result:
return [item["id"] for item in result]
else:
return None
def delete_by_metadata_field(self, key: str, value: str):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)
def delete_by_ids(self, doc_ids: list[str]) -> None:
self._client.delete(collection_name=self._collection_name, pks=doc_ids)
def delete(self) -> None:
from pymilvus import utility
utility.drop_collection(self._collection_name, None)
def text_exists(self, id: str) -> bool:
result = self._client.query(collection_name=self._collection_name,
filter=f'metadata["doc_id"] == "{id}"',
output_fields=["id"])
return len(result) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
results = self._client.search(collection_name=self._collection_name,
data=[query_vector],
limit=kwargs.get('top_k', 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
            metadata = result['entity'].get(Field.METADATA_KEY.value) or {}
metadata['score'] = result['distance']
            score_threshold = kwargs.get('score_threshold') or 0.0
if result['distance'] > score_threshold:
doc = Document(page_content=result['entity'].get(Field.CONTENT_KEY.value),
metadata=metadata)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # Milvus/Zilliz does not support BM25 full-text search; return no results.
return []
def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
) -> str:
from pymilvus import CollectionSchema, DataType, FieldSchema
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
self._fields.remove(Field.PRIMARY_KEY.value)
# Create the collection
collection_name = self._collection_name
self._client.create_collection_with_schema(collection_name=collection_name,
schema=schema, index_param=index_params,
consistency_level=self._consistency_level)
return collection_name
def _init_client(self, config) -> MilvusClient:
if config.secure:
uri = "https://" + str(config.host) + ":" + str(config.port)
else:
uri = "http://" + str(config.host) + ":" + str(config.port)
client = MilvusClient(uri=uri, user=config.user, password=config.password)
return client
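
A minimal usage sketch for MilvusVector, assuming a reachable Milvus instance; the host, credentials, collection name, and 768-dim embeddings are placeholders:

from core.rag.models.document import Document

config = MilvusConfig(host='127.0.0.1', port=19530, user='root', password='Milvus')
vector = MilvusVector(collection_name='Vector_index_example_Node', config=config)
docs = [Document(page_content='hello world',
                 metadata={'doc_id': 'd1', 'document_id': 'doc-1'})]
vector.create(texts=docs, embeddings=[[0.1] * 768])  # creates the collection on first call, then inserts
hits = vector.search_by_vector([0.1] * 768, top_k=4, score_threshold=0.5)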

View File

@@ -0,0 +1,360 @@
import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import qdrant_client
from pydantic import BaseModel
from qdrant_client.http import models as rest
from qdrant_client.http.models import (
FilterSelector,
HnswConfigDiff,
PayloadSchemaType,
TextIndexParams,
TextIndexType,
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
if TYPE_CHECKING:
from qdrant_client import grpc # noqa
from qdrant_client.conversions import common_types
from qdrant_client.http import models as rest
DictFilter = dict[str, Union[str, int, bool, dict, list]]
MetadataFilter = Union[DictFilter, common_types.Filter]
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith('path:'):
path = self.endpoint.replace('path:', '')
if not os.path.isabs(path):
path = os.path.join(self.root_path, path)
return {
'path': path
}
else:
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}
class QdrantVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = 'Cosine'):
super().__init__(collection_name)
self._client_config = config
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
self._distance_func = distance_func.upper()
self._group_id = group_id
def get_type(self) -> str:
return 'qdrant'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if texts:
# get embedding vector size
vector_size = len(embeddings[0])
# get collection name
collection_name = self._collection_name
collection_name = collection_name or uuid.uuid4().hex
all_collection_name = []
collections_response = self._client.get_collections()
collection_list = collections_response.collections
for collection in collection_list:
all_collection_name.append(collection.name)
if collection_name not in all_collection_name:
# create collection
self.create_collection(collection_name, vector_size)
self.add_texts(texts, embeddings, **kwargs)
def create_collection(self, collection_name: str, vector_size: int):
from qdrant_client.http import models as rest
vectors_config = rest.VectorParams(
size=vector_size,
distance=rest.Distance[self._distance_func],
)
hnsw_config = HnswConfigDiff(m=0, payload_m=16, ef_construct=100, full_scan_threshold=10000,
max_indexing_threads=0, on_disk=False)
self._client.recreate_collection(
collection_name=collection_name,
vectors_config=vectors_config,
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
)
# create payload index
self._client.create_payload_index(collection_name, Field.GROUP_KEY.value,
field_schema=PayloadSchemaType.KEYWORD,
field_type=PayloadSchemaType.KEYWORD)
        # create full text index
text_index_params = TextIndexParams(
type=TextIndexType.TEXT,
tokenizer=TokenizerType.MULTILINGUAL,
min_token_len=2,
max_token_len=20,
lowercase=True
)
self._client.create_payload_index(collection_name, Field.CONTENT_KEY.value,
field_schema=text_index_params)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
added_ids = []
for batch_ids, points in self._generate_rest_batches(
texts, embeddings, metadatas, uuids, 64, self._group_id
):
self._client.upsert(
collection_name=self._collection_name, points=points
)
added_ids.extend(batch_ids)
return added_ids
def _generate_rest_batches(
self,
texts: Iterable[str],
embeddings: list[list[float]],
metadatas: Optional[list[dict]] = None,
ids: Optional[Sequence[str]] = None,
batch_size: int = 64,
group_id: Optional[str] = None,
) -> Generator[tuple[list[str], list[rest.PointStruct]], None, None]:
from qdrant_client.http import models as rest
texts_iterator = iter(texts)
embeddings_iterator = iter(embeddings)
metadatas_iterator = iter(metadatas or [])
ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
while batch_texts := list(islice(texts_iterator, batch_size)):
# Take the corresponding metadata and id for each text in a batch
batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
batch_ids = list(islice(ids_iterator, batch_size))
# Generate the embeddings for all the texts in a batch
batch_embeddings = list(islice(embeddings_iterator, batch_size))
points = [
rest.PointStruct(
id=point_id,
vector=vector,
payload=payload,
)
for point_id, vector, payload in zip(
batch_ids,
batch_embeddings,
self._build_payloads(
batch_texts,
batch_metadatas,
Field.CONTENT_KEY.value,
Field.METADATA_KEY.value,
group_id,
Field.GROUP_KEY.value,
),
)
]
yield batch_ids, points
@classmethod
def _build_payloads(
cls,
texts: Iterable[str],
metadatas: Optional[list[dict]],
content_payload_key: str,
metadata_payload_key: str,
group_id: str,
group_payload_key: str
) -> list[dict]:
payloads = []
for i, text in enumerate(texts):
if text is None:
raise ValueError(
"At least one of the texts is None. Please remove it before "
"calling .from_texts or .add_texts on Qdrant instance."
)
metadata = metadatas[i] if metadatas is not None else None
payloads.append(
{
content_payload_key: text,
metadata_payload_key: metadata,
group_payload_key: group_id
}
)
return payloads
def delete_by_metadata_field(self, key: str, value: str):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
),
],
)
self._reload_if_needed()
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def delete(self):
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def delete_by_ids(self, ids: list[str]) -> None:
from qdrant_client.http import models
for node_id in ids:
filter = models.Filter(
must=[
models.FieldCondition(
key="metadata.doc_id",
match=models.MatchValue(value=node_id),
),
],
)
self._client.delete(
collection_name=self._collection_name,
points_selector=FilterSelector(
filter=filter
),
)
def text_exists(self, id: str) -> bool:
response = self._client.retrieve(
collection_name=self._collection_name,
ids=[id]
)
return len(response) > 0
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from qdrant_client.http import models
filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
],
)
results = self._client.search(
collection_name=self._collection_name,
query_vector=query_vector,
query_filter=filter,
limit=kwargs.get("top_k", 4),
with_payload=True,
with_vectors=True,
score_threshold=kwargs.get("score_threshold", .0)
)
docs = []
for result in results:
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
            # the search call above already applied score_threshold; re-check defensively
            score_threshold = kwargs.get('score_threshold') or 0.0
if result.score > score_threshold:
metadata['score'] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value),
metadata=metadata,
)
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs most similar by bm25.
Returns:
List of documents most similar to the query text and distance for each.
"""
from qdrant_client.http import models
scroll_filter = models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self._group_id),
),
models.FieldCondition(
key="page_content",
match=models.MatchText(text=query),
)
]
)
response = self._client.scroll(
collection_name=self._collection_name,
scroll_filter=scroll_filter,
limit=kwargs.get('top_k', 2),
with_payload=True,
with_vectors=True
)
results = response[0]
documents = []
for result in results:
if result:
documents.append(self._document_from_scored_point(
result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value
))
return documents
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
def _document_from_scored_point(
cls,
scored_point: Any,
content_payload_key: str,
metadata_payload_key: str,
) -> Document:
return Document(
page_content=scored_point.payload.get(content_payload_key),
metadata=scored_point.payload.get(metadata_payload_key) or {},
)
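
A matching usage sketch for QdrantVector; the endpoint and group id are placeholders, and docs/embeddings follow the Milvus sketch above. A 'path:'-prefixed endpoint would instead run Qdrant in embedded local mode, as to_qdrant_params shows:

config = QdrantConfig(endpoint='http://localhost:6333', api_key=None, root_path='/app')
vector = QdrantVector(collection_name='Vector_index_example_Node',
                      group_id='dataset-uuid', config=config)
vector.create(texts=docs, embeddings=embeddings)
hits = vector.search_by_full_text('hello', top_k=2)  # scrolls the full-text payload index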

View File

@@ -0,0 +1,62 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from core.rag.models.document import Document
class BaseVector(ABC):
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
raise NotImplementedError
@abstractmethod
def text_exists(self, id: str) -> bool:
raise NotImplementedError
@abstractmethod
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
@abstractmethod
def delete_by_metadata_field(self, key: str, value: str) -> None:
raise NotImplementedError
@abstractmethod
def search_by_vector(
self,
query_vector: list[float],
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
@abstractmethod
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:
raise NotImplementedError
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # Build a new list: calling texts.remove() while iterating would skip
        # the element that follows each removal.
        return [text for text in texts
                if not self.text_exists(text.metadata['doc_id'])]
def _get_uuids(self, texts: list[Document]) -> list[str]:
return [text.metadata['doc_id'] for text in texts]
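
A new backend only needs to subclass BaseVector and implement the abstract methods above. The in-memory store below is a hypothetical illustration (brute-force cosine similarity, no persistence), not a backend shipped in this commit:

class InMemoryVector(BaseVector):
    def __init__(self, collection_name: str):
        super().__init__(collection_name)
        self._store: list[tuple[Document, list[float]]] = []

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self.add_texts(texts, embeddings, **kwargs)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        self._store.extend(zip(documents, embeddings))

    def text_exists(self, id: str) -> bool:
        return any(d.metadata.get('doc_id') == id for d, _ in self._store)

    def delete_by_ids(self, ids: list[str]) -> None:
        self._store = [(d, e) for d, e in self._store if d.metadata.get('doc_id') not in ids]

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._store = [(d, e) for d, e in self._store if d.metadata.get(key) != value]

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        def cosine(a: list[float], b: list[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm = (sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5)
            return dot / norm if norm else 0.0
        ranked = sorted(self._store, key=lambda pair: cosine(query_vector, pair[1]), reverse=True)
        return [d for d, _ in ranked[:kwargs.get('top_k', 4)]]

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        return [d for d, _ in self._store if query in d.page_content][:kwargs.get('top_k', 2)]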

View File

@@ -0,0 +1,171 @@
from typing import Any, cast
from flask import current_app
from core.embedding.cached_embedding import CacheEmbedding
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class Vector:
def __init__(self, dataset: Dataset, attributes: list = None):
if attributes is None:
attributes = ['doc_id', 'dataset_id', 'document_id', 'doc_hash']
self._dataset = dataset
self._embeddings = self._get_embeddings()
self._attributes = attributes
self._vector_processor = self._init_vector()
def _init_vector(self) -> BaseVector:
config = cast(dict, current_app.config)
vector_type = config.get('VECTOR_STORE')
if self._dataset.index_struct_dict:
vector_type = self._dataset.index_struct_dict['type']
if not vector_type:
raise ValueError("Vector store must be specified.")
if vector_type == "weaviate":
from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return WeaviateVector(
collection_name=collection_name,
config=WeaviateConfig(
endpoint=config.get('WEAVIATE_ENDPOINT'),
api_key=config.get('WEAVIATE_API_KEY'),
batch_size=int(config.get('WEAVIATE_BATCH_SIZE'))
),
attributes=self._attributes
)
elif vector_type == "qdrant":
from core.rag.datasource.vdb.qdrant.qdrant_vector import QdrantConfig, QdrantVector
if self._dataset.collection_binding_id:
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
filter(DatasetCollectionBinding.id == self._dataset.collection_binding_id). \
one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
                    raise ValueError('Dataset collection binding does not exist.')
else:
if self._dataset.index_struct_dict:
                    class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return QdrantVector(
collection_name=collection_name,
group_id=self._dataset.id,
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
)
)
elif vector_type == "milvus":
from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig, MilvusVector
if self._dataset.index_struct_dict:
class_prefix: str = self._dataset.index_struct_dict['vector_store']['class_prefix']
collection_name = class_prefix
else:
dataset_id = self._dataset.id
collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
return MilvusVector(
collection_name=collection_name,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
)
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
def create(self, texts: list = None, **kwargs):
if texts:
embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
self._vector_processor.create(
texts=texts,
embeddings=embeddings,
**kwargs
)
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get('duplicate_check', False):
documents = self._filter_duplicate_texts(documents)
embeddings = self._embeddings.embed_documents([document.page_content for document in documents])
self._vector_processor.add_texts(
documents=documents,
embeddings=embeddings,
**kwargs
)
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(
self, query: str,
**kwargs: Any
) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
def search_by_full_text(
self, query: str,
**kwargs: Any
) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
self._vector_processor.delete()
def _get_embeddings(self) -> Embeddings:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=self._dataset.tenant_id,
provider=self._dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=self._dataset.embedding_model
)
return CacheEmbedding(embedding_model)
    def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
        # Build a new list: calling texts.remove() while iterating would skip
        # the element that follows each removal.
        return [text for text in texts
                if not self.text_exists(text.metadata['doc_id'])]
def __getattr__(self, name):
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):
return method
raise AttributeError(f"'vector_processor' object has no attribute '{name}'")
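
Typical call sites go through this factory rather than a concrete store; in this sketch the `dataset` row and `documents` list are assumed to exist:

vector = Vector(dataset=dataset)                   # store chosen from VECTOR_STORE or the dataset's index struct
vector.add_texts(documents, duplicate_check=True)  # embeds page_content; filters duplicates by doc_id
results = vector.search_by_vector('what is dify?', top_k=4, score_threshold=0.5)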

View File

@@ -0,0 +1,235 @@
import datetime
from typing import Any, Optional
import requests
import weaviate
from pydantic import BaseModel, root_validator
from core.rag.datasource.vdb.field import Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.models.document import Document
from models.dataset import Dataset
class WeaviateConfig(BaseModel):
endpoint: str
api_key: Optional[str]
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config WEAVIATE_ENDPOINT is required")
return values
class WeaviateVector(BaseVector):
def __init__(self, collection_name: str, config: WeaviateConfig, attributes: list):
super().__init__(collection_name)
self._client = self._init_client(config)
self._attributes = attributes
def _init_client(self, config: WeaviateConfig) -> weaviate.Client:
auth_config = weaviate.auth.AuthApiKey(api_key=config.api_key)
weaviate.connect.connection.has_grpc = False
try:
client = weaviate.Client(
url=config.endpoint,
auth_client_secret=auth_config,
timeout_config=(5, 60),
startup_period=None
)
except requests.exceptions.ConnectionError:
raise ConnectionError("Vector database connection error")
client.batch.configure(
# `batch_size` takes an `int` value to enable auto-batching
# (`None` is used for manual batching)
batch_size=config.batch_size,
# dynamically update the `batch_size` based on import speed
dynamic=True,
# `timeout_retries` takes an `int` value to retry on time outs
timeout_retries=3,
)
return client
def get_type(self) -> str:
return 'weaviate'
def get_collection_name(self, dataset: Dataset) -> str:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict['vector_store']['class_prefix']
            if not class_prefix.endswith('_Node'):
                # legacy class_prefix values lack the '_Node' suffix; append it for compatibility
                class_prefix += '_Node'
return class_prefix
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
"vector_store": {"class_prefix": self._collection_name}
}
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
schema = self._default_schema(self._collection_name)
# check whether the index already exists
if not self._client.schema.contains(schema):
# create collection
self._client.schema.create_class(schema)
# create vector
self.add_texts(texts, embeddings)
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
uuids = self._get_uuids(documents)
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
ids = []
with self._client.batch as batch:
for i, text in enumerate(texts):
data_properties = {Field.TEXT_KEY.value: text}
if metadatas is not None:
for key, val in metadatas[i].items():
data_properties[key] = self._json_serializable(val)
batch.add_data_object(
data_object=data_properties,
class_name=self._collection_name,
uuid=uuids[i],
vector=embeddings[i] if embeddings else None,
)
ids.append(uuids[i])
return ids
def delete_by_metadata_field(self, key: str, value: str):
where_filter = {
"operator": "Equal",
"path": [key],
"valueText": value
}
self._client.batch.delete_objects(
class_name=self._collection_name,
where=where_filter,
output='minimal'
)
def delete(self):
self._client.schema.delete_class(self._collection_name)
def text_exists(self, id: str) -> bool:
collection_name = self._collection_name
result = self._client.query.get(collection_name).with_additional(["id"]).with_where({
"path": ["doc_id"],
"operator": "Equal",
"valueText": id,
}).with_limit(1).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
entries = result["data"]["Get"][collection_name]
if len(entries) == 0:
return False
return True
def delete_by_ids(self, ids: list[str]) -> None:
self._client.data_object.delete(
ids,
class_name=self._collection_name
)
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""Look up similar documents by embedding vector in Weaviate."""
collection_name = self._collection_name
        # copy the attribute list so repeated searches don't append to self._attributes
        properties = self._attributes + [Field.TEXT_KEY.value]
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs = []
for doc, score in docs_and_scores:
            score_threshold = kwargs.get('score_threshold') or 0.0
# check score threshold
if score > score_threshold:
doc.metadata['score'] = score
docs.append(doc)
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
            top_k: Number of Documents to return. Defaults to 2.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
        # copy the attribute list so repeated searches don't append to self._attributes
        properties = self._attributes + [Field.TEXT_KEY.value]
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
if kwargs.get("additional"):
query_obj = query_obj.with_additional(kwargs.get("additional"))
        # restrict BM25 matching to the text property itself
        properties = [Field.TEXT_KEY.value]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get('top_k', 2)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
docs.append(Document(page_content=text, metadata=res))
return docs
def _default_schema(self, index_name: str) -> dict:
return {
"class": index_name,
"properties": [
{
"name": "text",
"dataType": ["text"],
}
],
}
def _json_serializable(self, value: Any) -> Any:
if isinstance(value, datetime.datetime):
return value.isoformat()
return value
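
Finally, a usage sketch for WeaviateVector, again with a placeholder endpoint; docs/embeddings follow the earlier sketches:

config = WeaviateConfig(endpoint='http://localhost:8080', api_key=None)
vector = WeaviateVector(collection_name='Vector_index_example_Node', config=config,
                        attributes=['doc_id', 'dataset_id', 'document_id', 'doc_hash'])
vector.create(texts=docs, embeddings=embeddings)
bm25_hits = vector.search_by_full_text('hello', top_k=2)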