Initial commit
This commit is contained in:
34
api/core/vector_store/base.py
Normal file
34
api/core/vector_store/base.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import ServiceContext, GPTVectorStoreIndex
|
||||
from llama_index.data_structs import Node
|
||||
from llama_index.vector_stores.types import VectorStore
|
||||
|
||||
|
||||
class BaseVectorStoreClient(ABC):
    """Abstract interface for a vector-store backend client.

    A concrete implementation wraps one vector database and knows how to
    build an index object for it and how to describe an index as a plain
    config dict.
    """

    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Build a vector-store index from a backend-specific config dict."""
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        """Translate an index id into the backend-specific config dict."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    """GPTVectorStoreIndex extended with node-level delete/exists helpers.

    Both helpers simply delegate to the underlying vector store, which is
    expected to implement the EnhanceVectorStore operations.
    """

    def delete_node(self, node_id: str):
        """Remove the node with *node_id* from the underlying vector store."""
        store = self._vector_store
        store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        """Return True when a node with *node_id* exists in the vector store."""
        store = self._vector_store
        return store.exists_by_node_id(node_id)
|
||||
|
||||
|
||||
class EnhanceVectorStore(ABC):
    """Mixin interface adding node-level operations to a vector store."""

    @abstractmethod
    def delete_node(self, node_id: str):
        """Remove the node identified by *node_id* from the store."""

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        """Return whether a node with *node_id* is present in the store."""
|
||||
147
api/core/vector_store/qdrant_vector_store_client.py
Normal file
147
api/core/vector_store/qdrant_vector_store_client.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import os
|
||||
from typing import cast, List
|
||||
|
||||
from llama_index.data_structs import Node
|
||||
from llama_index.data_structs.node_v2 import DocumentRelationship
|
||||
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
|
||||
from qdrant_client.http.models import Payload, Filter
|
||||
|
||||
import qdrant_client
|
||||
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
|
||||
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
|
||||
from llama_index.vector_stores import QdrantVectorStore
|
||||
from qdrant_client.local.qdrant_local import QdrantLocal
|
||||
|
||||
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
|
||||
|
||||
|
||||
class QdrantVectorStoreClient(BaseVectorStoreClient):
    """Vector-store client backed by Qdrant (remote server or local path mode)."""

    def __init__(self, url: str, api_key: str, root_path: str):
        # Eagerly create the underlying qdrant client.
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        """Create a qdrant client.

        A url of the form ``path:<dir>`` selects the embedded/local mode,
        with relative paths resolved against *root_path*; anything else is
        treated as a remote server url.
        """
        if not (url and url.startswith('path:')):
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )

        local_path = url.replace('path:', '')
        if not os.path.isabs(local_path):
            local_path = os.path.join(root_path, local_path)

        return qdrant_client.QdrantClient(
            path=local_path
        )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Build a Qdrant-backed index for the collection named in *config*."""
        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"collection_name": "Gpt_index_xxx"}
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")

        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=QdrantIndexDict(),
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        """Map an index id onto the qdrant collection config."""
        return {"collection_name": index_id}
|
||||
|
||||
|
||||
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    # Combines the stock Qdrant index with the node-level delete/exists
    # helpers from BaseGPTVectorStoreIndex; adds no behavior of its own.
    pass
|
||||
|
||||
|
||||
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    """QdrantVectorStore with node-level delete/exists and a score-aware query."""

    def delete_node(self, node_id: str):
        """Delete the point whose ``id`` payload field equals *node_id*."""
        from qdrant_client.http import models as rest

        self._reload_if_needed()

        selector = rest.Filter(
            must=[
                rest.FieldCondition(
                    key="id", match=rest.MatchValue(value=node_id)
                )
            ]
        )
        self._client.delete(
            collection_name=self._collection_name,
            points_selector=selector,
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """Return True when a point with id *node_id* is stored in the collection."""
        self._reload_if_needed()

        records = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )
        return len(records) > 0

    def query(
        self,
        query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Run a top-k similarity search and convert the hits into Nodes.

        Args:
            query (VectorStoreQuery): query
        """
        embedding = cast(List[float], query.query_embedding)

        self._reload_if_needed()

        hits = self._client.search(
            collection_name=self._collection_name,
            query_vector=embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )

        nodes, similarities, ids = [], [], []
        for hit in hits:
            payload = cast(Payload, hit.payload)
            nodes.append(Node(
                doc_id=str(hit.id),
                text=payload.get("text"),
                embedding=hit.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            ))
            similarities.append(hit.score)
            ids.append(str(hit.id))

        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        # Local-mode qdrant persists to disk; force a reload so reads observe
        # writes made since the client was created.
        client_impl = self._client._client
        if isinstance(client_impl, QdrantLocal):
            client_impl._load()
|
||||
61
api/core/vector_store/vector_store.py
Normal file
61
api/core/vector_store/vector_store.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from flask import Flask
|
||||
from llama_index import ServiceContext, GPTVectorStoreIndex
|
||||
from requests import ReadTimeout
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt
|
||||
|
||||
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
|
||||
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
|
||||
|
||||
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
|
||||
|
||||
|
||||
class VectorStore:
    """Flask extension exposing the configured vector-store client."""

    def __init__(self):
        # Populated by init_app(); None until then.
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        """Read the VECTOR_STORE config and build the matching backend client."""
        vector_store_type = app.config['VECTOR_STORE']
        if not vector_store_type:
            return

        self._vector_store = vector_store_type
        if vector_store_type not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if vector_store_type == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED']
            )
        elif vector_store_type == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        """Build the backend index described by *index_struct*.

        Retries up to 3 times on ReadTimeout, then re-raises.
        """
        vector_store_config: dict = index_struct.get('vector_store')
        return self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )

    def to_index_struct(self, index_id: str) -> dict:
        """Serialize an index id into a storable index-struct dict."""
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        """Return the configured client, failing fast if init_app never ran."""
        if not self._client:
            raise Exception("Vector store client is not initialized.")
        return self._client
|
||||
66
api/core/vector_store/vector_store_index_query.py
Normal file
66
api/core/vector_store/vector_store_index_query.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from llama_index.indices.query.base import IS
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional
|
||||
)
|
||||
|
||||
from llama_index.docstore import BaseDocumentStore
|
||||
from llama_index.indices.postprocessor.node import (
|
||||
BaseNodePostprocessor,
|
||||
)
|
||||
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
|
||||
from llama_index.indices.response.response_builder import ResponseMode
|
||||
from llama_index.indices.service_context import ServiceContext
|
||||
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
|
||||
from llama_index.prompts.prompts import (
|
||||
QuestionAnswerPrompt,
|
||||
RefinePrompt,
|
||||
SimpleInputPrompt,
|
||||
)
|
||||
|
||||
from core.index.query.synthesizer import EnhanceResponseSynthesizer
|
||||
|
||||
|
||||
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    """GPTVectorStoreIndexQuery whose responses come from EnhanceResponseSynthesizer."""

    @classmethod
    def from_args(
        cls,
        index_struct: IS,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        verbose: bool = False,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_kwargs: Optional[Dict] = None,
        use_async: bool = False,
        streaming: bool = False,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
        # class-specific args
        **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        """Construct the query, wiring in the project's enhanced synthesizer.

        Mirrors the stock from_args but substitutes EnhanceResponseSynthesizer
        for the default response synthesizer.
        """
        synthesizer = EnhanceResponseSynthesizer.from_args(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=synthesizer,
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )
|
||||
258
api/core/vector_store/weaviate_vector_store_client.py
Normal file
258
api/core/vector_store/weaviate_vector_store_client.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import json
|
||||
import weaviate
|
||||
from dataclasses import field
|
||||
from typing import List, Any, Dict, Optional
|
||||
|
||||
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
|
||||
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
|
||||
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
|
||||
from llama_index.data_structs.node_v2 import DocumentRelationship
|
||||
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
|
||||
from llama_index.vector_stores import WeaviateVectorStore
|
||||
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
|
||||
from llama_index.readers.weaviate.utils import (
|
||||
parse_get_response,
|
||||
validate_client,
|
||||
)
|
||||
|
||||
|
||||
class WeaviateVectorStoreClient(BaseVectorStoreClient):
    """Vector-store client backed by a Weaviate server."""

    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled)

    def init_from_config(self, endpoint: str, api_key: str, grpc_enabled: bool):
        """Create a weaviate client for *endpoint* using API-key auth."""
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)

        # NOTE(review): this flips a module-level switch in the weaviate
        # package, so it affects every connection in the process, not just
        # this client — confirm that is intended.
        weaviate.connect.connection.has_grpc = grpc_enabled

        return weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 15),
            startup_period=None
        )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Build a Weaviate-backed index for the class prefix named in *config*."""
        if self._client is None:
            raise Exception("Vector client is not initialized.")

        # {"class_prefix": "Gpt_index_xxx"}
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")

        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=WeaviateIndexDict(),
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        """Map an index id onto the weaviate class-prefix config."""
        return {"class_prefix": index_id}
|
||||
|
||||
|
||||
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    """WeaviateVectorStore that also reports similarity scores and supports
    node-level delete/exists operations."""

    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes, with similarity scores."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        # _to_node stashes the certainty under extra_info['similarity'];
        # pull it out here so callers see it as the similarity score.
        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
        self,
        client: Any,
        class_prefix: str,
        query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Run the weaviate GraphQL query and convert the hits to Nodes.

        :param client: weaviate client
        :param class_prefix: index class prefix
        :param query_spec: vector / hybrid query specification
        """
        validate_client(client)

        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query; 'certainty' is requested so results carry a score
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert a raw weaviate entry into a Node.

        When the entry carries a certainty, it is stored in
        ``extra_info['similarity']`` for query() to pick up.
        """
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            # BUG FIX: the original assigned dataclasses.field(default_factory=dict),
            # which outside a dataclass yields a Field descriptor object, not a
            # dict. Use a plain empty dict so Node.relationships has the right type.
            relationships = {}
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete every entry belonging to a document.

        Args:
            doc_id (str): document id

        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.

        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Return True when a node with *node_id* exists in the index.

        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False
|
||||
|
||||
|
||||
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    # Stock Weaviate index combined with the node-level delete/exists helpers
    # from BaseGPTVectorStoreIndex; adds no behavior of its own.
    pass
|
||||
|
||||
|
||||
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete every entry whose ref_doc_id matches, repeating until none remain."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    # Results may be paged, so keep querying and deleting until the filter
    # stops matching anything.
    while True:
        query_result = query.do()
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        if not entries:
            return
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)
|
||||
|
||||
|
||||
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete the entries whose doc_id equals *node_id* (single pass)."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    node_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(node_filter)
    )

    parsed_result = parse_get_response(query.do())
    for entry in parsed_result[class_name]:
        client.data_object.delete(entry["_additional"]["id"], class_name)
|
||||
|
||||
|
||||
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Fetch the first entry whose doc_id equals *node_id*, or None.

    DOC FIX: the original docstring said "Delete entry." (copied from the
    delete helpers) — this function only reads; it never deletes.

    :param client: weaviate client
    :param node_id: node id to look up
    :param class_prefix: index class prefix
    :return: the matching entry dict, or None when nothing matches
    """
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )

    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if len(entries) == 0:
        return None

    return entries[0]
|
||||
Reference in New Issue
Block a user