Initial commit

This commit is contained in:
John Wang
2023-05-15 08:51:32 +08:00
commit db896255d6
744 changed files with 56028 additions and 0 deletions

View File

@@ -0,0 +1,34 @@
from abc import ABC, abstractmethod
from typing import Optional
from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.data_structs import Node
from llama_index.vector_stores.types import VectorStore
class BaseVectorStoreClient(ABC):
    """Abstract client interface for a vector store backend.

    Concrete implementations wrap a specific vector database (e.g. Qdrant,
    Weaviate) and expose index construction plus index-config serialization.
    """

    @abstractmethod
    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        # Build (or attach to) the vector store index described by `config`.
        # `config` is the backend-specific dict produced by `to_index_config`.
        raise NotImplementedError

    @abstractmethod
    def to_index_config(self, index_id: str) -> dict:
        # Serialize `index_id` into the backend-specific config dict
        # accepted by `get_index`.
        raise NotImplementedError
class BaseGPTVectorStoreIndex(GPTVectorStoreIndex):
    """GPTVectorStoreIndex extended with node-level delete/exists helpers.

    Both helpers delegate to the underlying vector store, which is expected
    to implement the `EnhanceVectorStore` interface.
    """

    def delete_node(self, node_id: str):
        # Remove a single node from the backing vector store.
        self._vector_store.delete_node(node_id)

    def exists_by_node_id(self, node_id: str) -> bool:
        # True when the backing vector store contains a node with this id.
        return self._vector_store.exists_by_node_id(node_id)
class EnhanceVectorStore(ABC):
    """Extra node-level operations a vector store must support to back
    `BaseGPTVectorStoreIndex`."""

    @abstractmethod
    def delete_node(self, node_id: str):
        # Remove the node with the given id from the store.
        pass

    @abstractmethod
    def exists_by_node_id(self, node_id: str) -> bool:
        # Return True when a node with the given id exists in the store.
        pass

View File

@@ -0,0 +1,147 @@
import os
from typing import cast, List
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult
from qdrant_client.http.models import Payload, Filter
import qdrant_client
from llama_index import ServiceContext, GPTVectorStoreIndex, GPTQdrantIndex
from llama_index.data_structs.data_structs_v2 import QdrantIndexDict
from llama_index.vector_stores import QdrantVectorStore
from qdrant_client.local.qdrant_local import QdrantLocal
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
class QdrantVectorStoreClient(BaseVectorStoreClient):
    """Vector store client backed by Qdrant (remote server or embedded)."""

    def __init__(self, url: str, api_key: str, root_path: str):
        self._client = self.init_from_config(url, api_key, root_path)

    @classmethod
    def init_from_config(cls, url: str, api_key: str, root_path: str):
        """Create a qdrant_client.QdrantClient from the app configuration.

        A URL with a "path:" prefix selects the embedded (local, on-disk)
        Qdrant mode; relative paths are resolved against `root_path`.
        """
        if not (url and url.startswith('path:')):
            # Remote Qdrant server.
            return qdrant_client.QdrantClient(
                url=url,
                api_key=api_key,
            )
        local_path = url.replace('path:', '')
        if not os.path.isabs(local_path):
            local_path = os.path.join(root_path, local_path)
        return qdrant_client.QdrantClient(
            path=local_path
        )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Build a Qdrant-backed index for the collection named in `config`.

        `config` has the shape {"collection_name": "Gpt_index_xxx"}.
        """
        if self._client is None:
            raise Exception("Vector client is not initialized.")
        collection_name = config.get('collection_name')
        if not collection_name:
            raise Exception("collection_name cannot be None.")
        return GPTQdrantEnhanceIndex(
            service_context=service_context,
            index_struct=QdrantIndexDict(),
            vector_store=QdrantEnhanceVectorStore(
                client=self._client,
                collection_name=collection_name
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        """Serialize an index id into the config dict `get_index` accepts."""
        return {"collection_name": index_id}
class GPTQdrantEnhanceIndex(GPTQdrantIndex, BaseGPTVectorStoreIndex):
    """GPTQdrantIndex with node-level delete/exists support mixed in."""
    pass
class QdrantEnhanceVectorStore(QdrantVectorStore, EnhanceVectorStore):
    """QdrantVectorStore extended with node-level delete/exists operations
    and a query implementation that returns embeddings with the results."""

    def delete_node(self, node_id: str):
        """
        Delete node from the index.
        :param node_id: node id
        """
        from qdrant_client.http import models as rest
        self._reload_if_needed()
        # Delete every point whose payload field "id" equals this node id.
        self._client.delete(
            collection_name=self._collection_name,
            points_selector=rest.Filter(
                must=[
                    rest.FieldCondition(
                        key="id", match=rest.MatchValue(value=node_id)
                    )
                ]
            ),
        )

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.
        :param node_id: node id
        """
        self._reload_if_needed()
        # retrieve() returns only the points that exist for the given ids.
        response = self._client.retrieve(
            collection_name=self._collection_name,
            ids=[node_id]
        )
        return len(response) > 0

    def query(
        self,
        query: VectorStoreQuery,
    ) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.
        Args:
            query (VectorStoreQuery): query
        """
        query_embedding = cast(List[float], query.query_embedding)
        self._reload_if_needed()
        # with_vectors=True so the returned Nodes can carry their embeddings.
        response = self._client.search(
            collection_name=self._collection_name,
            query_vector=query_embedding,
            limit=cast(int, query.similarity_top_k),
            query_filter=cast(Filter, self._build_query_filter(query)),
            with_vectors=True
        )
        nodes = []
        similarities = []
        ids = []
        for point in response:
            payload = cast(Payload, point.payload)
            node = Node(
                doc_id=str(point.id),
                text=payload.get("text"),
                embedding=point.vector,
                extra_info=payload.get("extra_info"),
                relationships={
                    # Fall back to the literal string "None" when the payload
                    # has no source doc id.
                    DocumentRelationship.SOURCE: payload.get("doc_id", "None"),
                },
            )
            nodes.append(node)
            similarities.append(point.score)
            ids.append(str(point.id))
        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)

    def _reload_if_needed(self):
        # NOTE(review): for the embedded (QdrantLocal) client this forces a
        # re-read of on-disk state via the private _load() — presumably to
        # pick up writes from other processes; confirm against qdrant_client
        # internals. No-op for remote clients.
        if isinstance(self._client._client, QdrantLocal):
            self._client._client._load()

View File

@@ -0,0 +1,61 @@
from flask import Flask
from llama_index import ServiceContext, GPTVectorStoreIndex
from requests import ReadTimeout
from tenacity import retry, retry_if_exception_type, stop_after_attempt
from core.vector_store.qdrant_vector_store_client import QdrantVectorStoreClient
from core.vector_store.weaviate_vector_store_client import WeaviateVectorStoreClient
SUPPORTED_VECTOR_STORES = ['weaviate', 'qdrant']
class VectorStore:
    """Flask extension wrapping a configured vector store backend.

    Call `init_app(app)` once at startup; afterwards `get_index` /
    `to_index_struct` proxy to the configured backend client.
    """

    def __init__(self):
        self._vector_store = None
        self._client = None

    def init_app(self, app: Flask):
        """Initialize the backend client from the Flask app config.

        Does nothing when VECTOR_STORE is unset or empty.
        :raises ValueError: when VECTOR_STORE names an unsupported backend.
        """
        # .get() so a missing VECTOR_STORE key behaves like an empty value
        # instead of raising KeyError.
        vector_store = app.config.get('VECTOR_STORE')
        if not vector_store:
            return

        self._vector_store = vector_store
        if self._vector_store not in SUPPORTED_VECTOR_STORES:
            raise ValueError(f"Vector store {self._vector_store} is not supported.")

        if self._vector_store == 'weaviate':
            self._client = WeaviateVectorStoreClient(
                endpoint=app.config['WEAVIATE_ENDPOINT'],
                api_key=app.config['WEAVIATE_API_KEY'],
                grpc_enabled=app.config['WEAVIATE_GRPC_ENABLED']
            )
        elif self._vector_store == 'qdrant':
            self._client = QdrantVectorStoreClient(
                url=app.config['QDRANT_URL'],
                api_key=app.config['QDRANT_API_KEY'],
                root_path=app.root_path
            )

        app.extensions['vector_store'] = self

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def get_index(self, service_context: ServiceContext, index_struct: dict) -> GPTVectorStoreIndex:
        """Build the backend index described by `index_struct`.

        Retries up to 3 times on ReadTimeout, then re-raises.
        :raises Exception: when the client is uninitialized or the struct
            has no 'vector_store' entry.
        """
        vector_store_config: dict = index_struct.get('vector_store')
        # Fail with a clear message instead of an AttributeError deeper in
        # the backend client when the struct is malformed.
        if vector_store_config is None:
            raise Exception("index_struct is missing the 'vector_store' config.")
        index = self.get_client().get_index(
            service_context=service_context,
            config=vector_store_config
        )
        return index

    def to_index_struct(self, index_id: str) -> dict:
        """Serialize `index_id` into a persistable index struct dict."""
        return {
            "type": self._vector_store,
            "vector_store": self.get_client().to_index_config(index_id)
        }

    def get_client(self):
        """Return the backend client, raising if `init_app` never ran."""
        if not self._client:
            raise Exception("Vector store client is not initialized.")
        return self._client

View File

@@ -0,0 +1,66 @@
from llama_index.indices.query.base import IS
from typing import (
Any,
Dict,
List,
Optional
)
from llama_index.docstore import BaseDocumentStore
from llama_index.indices.postprocessor.node import (
BaseNodePostprocessor,
)
from llama_index.indices.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.response.response_builder import ResponseMode
from llama_index.indices.service_context import ServiceContext
from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
from llama_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
SimpleInputPrompt,
)
from core.index.query.synthesizer import EnhanceResponseSynthesizer
class EnhanceGPTVectorStoreIndexQuery(GPTVectorStoreIndexQuery):
    """Vector store index query that routes response synthesis through
    EnhanceResponseSynthesizer."""

    @classmethod
    def from_args(
        cls,
        index_struct: IS,
        service_context: ServiceContext,
        docstore: Optional[BaseDocumentStore] = None,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        verbose: bool = False,
        # response synthesizer args
        response_mode: ResponseMode = ResponseMode.DEFAULT,
        text_qa_template: Optional[QuestionAnswerPrompt] = None,
        refine_template: Optional[RefinePrompt] = None,
        simple_template: Optional[SimpleInputPrompt] = None,
        response_kwargs: Optional[Dict] = None,
        use_async: bool = False,
        streaming: bool = False,
        optimizer: Optional[BaseTokenUsageOptimizer] = None,
        # class-specific args
        **kwargs: Any,
    ) -> "BaseGPTIndexQuery":
        """Build a query object whose synthesizer is the enhanced one."""
        # Collect the synthesizer-only arguments in one place, then forward.
        synthesizer_args = dict(
            service_context=service_context,
            text_qa_template=text_qa_template,
            refine_template=refine_template,
            simple_template=simple_template,
            response_mode=response_mode,
            response_kwargs=response_kwargs,
            use_async=use_async,
            streaming=streaming,
            optimizer=optimizer,
        )
        return cls(
            index_struct=index_struct,
            service_context=service_context,
            response_synthesizer=EnhanceResponseSynthesizer.from_args(**synthesizer_args),
            docstore=docstore,
            node_postprocessors=node_postprocessors,
            verbose=verbose,
            **kwargs,
        )

View File

@@ -0,0 +1,258 @@
import json
import weaviate
from dataclasses import field
from typing import List, Any, Dict, Optional
from core.vector_store.base import BaseVectorStoreClient, BaseGPTVectorStoreIndex, EnhanceVectorStore
from llama_index import ServiceContext, GPTWeaviateIndex, GPTVectorStoreIndex
from llama_index.data_structs.data_structs_v2 import WeaviateIndexDict, Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from llama_index.readers.weaviate.client import _class_name, NODE_SCHEMA, _logger
from llama_index.vector_stores import WeaviateVectorStore
from llama_index.vector_stores.types import VectorStoreQuery, VectorStoreQueryResult, VectorStoreQueryMode
from llama_index.readers.weaviate.utils import (
parse_get_response,
validate_client,
)
class WeaviateVectorStoreClient(BaseVectorStoreClient):
    """Vector store client backed by Weaviate."""

    def __init__(self, endpoint: str, api_key: str, grpc_enabled: bool):
        self._client = self.init_from_config(endpoint, api_key, grpc_enabled)

    @classmethod
    def init_from_config(cls, endpoint: str, api_key: str, grpc_enabled: bool):
        """Create a weaviate.Client from the app configuration.

        Made a classmethod for consistency with
        QdrantVectorStoreClient.init_from_config (backward compatible for
        existing instance-bound calls).
        """
        auth_config = weaviate.auth.AuthApiKey(api_key=api_key)
        # NOTE(review): this flips a module-level flag, so it affects every
        # weaviate connection in the process, not just this client.
        weaviate.connect.connection.has_grpc = grpc_enabled
        return weaviate.Client(
            url=endpoint,
            auth_client_secret=auth_config,
            timeout_config=(5, 15),
            startup_period=None
        )

    def get_index(self, service_context: ServiceContext, config: dict) -> GPTVectorStoreIndex:
        """Build a Weaviate-backed index for the class prefix in `config`.

        `config` has the shape {"class_prefix": "Gpt_index_xxx"}.
        :raises Exception: when the client is uninitialized or the prefix
            is missing.
        """
        index_struct = WeaviateIndexDict()
        if self._client is None:
            raise Exception("Vector client is not initialized.")
        class_prefix = config.get('class_prefix')
        if not class_prefix:
            raise Exception("class_prefix cannot be None.")
        return GPTWeaviateEnhanceIndex(
            service_context=service_context,
            index_struct=index_struct,
            vector_store=WeaviateWithSimilaritiesVectorStore(
                weaviate_client=self._client,
                class_prefix=class_prefix
            )
        )

    def to_index_config(self, index_id: str) -> dict:
        """Serialize an index id into the config dict `get_index` accepts."""
        return {"class_prefix": index_id}
class WeaviateWithSimilaritiesVectorStore(WeaviateVectorStore, EnhanceVectorStore):
    """WeaviateVectorStore that surfaces similarity scores and supports
    node-level delete/exists operations."""

    def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes."""
        nodes = self.weaviate_query(
            self._client,
            self._class_prefix,
            query,
        )
        nodes = nodes[: query.similarity_top_k]
        node_idxs = [str(i) for i in range(len(nodes))]

        # Pull the similarity scores that _to_node stashed in extra_info
        # out into the result, removing them from the node payloads.
        similarities = []
        for node in nodes:
            similarities.append(node.extra_info['similarity'])
            del node.extra_info['similarity']

        return VectorStoreQueryResult(nodes=nodes, ids=node_idxs, similarities=similarities)

    def weaviate_query(
        self,
        client: Any,
        class_prefix: str,
        query_spec: VectorStoreQuery,
    ) -> List[Node]:
        """Run the Weaviate GraphQL query and convert hits to Nodes."""
        validate_client(client)
        class_name = _class_name(class_prefix)
        prop_names = [p["name"] for p in NODE_SCHEMA]
        vector = query_spec.query_embedding

        # build query; "certainty" is requested so similarity can be reported
        query = client.query.get(class_name, prop_names).with_additional(["id", "vector", "certainty"])
        if query_spec.mode == VectorStoreQueryMode.DEFAULT:
            _logger.debug("Using vector search")
            if vector is not None:
                query = query.with_near_vector(
                    {
                        "vector": vector,
                    }
                )
        elif query_spec.mode == VectorStoreQueryMode.HYBRID:
            _logger.debug(f"Using hybrid search with alpha {query_spec.alpha}")
            query = query.with_hybrid(
                query=query_spec.query_str,
                alpha=query_spec.alpha,
                vector=vector,
            )
        query = query.with_limit(query_spec.similarity_top_k)
        _logger.debug(f"Using limit of {query_spec.similarity_top_k}")

        # execute query
        query_result = query.do()

        # parse results
        parsed_result = parse_get_response(query_result)
        entries = parsed_result[class_name]
        results = [self._to_node(entry) for entry in entries]
        return results

    def _to_node(self, entry: Dict) -> Node:
        """Convert a raw Weaviate entry dict to a Node.

        The entry's "certainty" (when present) is tucked into
        extra_info['similarity'] for `query` to extract.
        """
        extra_info_str = entry["extra_info"]
        if extra_info_str == "":
            extra_info = None
        else:
            extra_info = json.loads(extra_info_str)

        if 'certainty' in entry['_additional']:
            if extra_info:
                extra_info['similarity'] = entry['_additional']['certainty']
            else:
                extra_info = {'similarity': entry['_additional']['certainty']}

        node_info_str = entry["node_info"]
        if node_info_str == "":
            node_info = None
        else:
            node_info = json.loads(node_info_str)

        relationships_str = entry["relationships"]
        relationships: Dict[DocumentRelationship, str]
        if relationships_str == "":
            # BUG FIX: was `field(default_factory=dict)`, which assigns a
            # dataclasses.Field object (field() is only meaningful inside a
            # @dataclass definition) — use an empty dict instead.
            relationships = {}
        else:
            relationships = {
                DocumentRelationship(k): v for k, v in json.loads(relationships_str).items()
            }

        return Node(
            text=entry["text"],
            doc_id=entry["doc_id"],
            embedding=entry["_additional"]["vector"],
            extra_info=extra_info,
            node_info=node_info,
            relationships=relationships,
        )

    def delete(self, doc_id: str, **delete_kwargs: Any) -> None:
        """Delete a document.
        Args:
            doc_id (str): document id
        """
        delete_document(self._client, doc_id, self._class_prefix)

    def delete_node(self, node_id: str):
        """
        Delete node from the index.
        :param node_id: node id
        """
        delete_node(self._client, node_id, self._class_prefix)

    def exists_by_node_id(self, node_id: str) -> bool:
        """
        Get node from the index by node id.
        :param node_id: node id
        """
        entry = get_by_node_id(self._client, node_id, self._class_prefix)
        return True if entry else False
class GPTWeaviateEnhanceIndex(GPTWeaviateIndex, BaseGPTVectorStoreIndex):
    """GPTWeaviateIndex with node-level delete/exists support mixed in."""
    pass
def delete_document(client: Any, ref_doc_id: str, class_prefix: str) -> None:
    """Delete every entry whose ref_doc_id matches the given document id."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["ref_doc_id"],
        "operator": "Equal",
        "valueString": ref_doc_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )
    # Results come back a page at a time: keep querying and deleting until
    # a query returns no matching entries.
    while True:
        parsed_result = parse_get_response(query.do())
        entries = parsed_result[class_name]
        for entry in entries:
            client.data_object.delete(entry["_additional"]["id"], class_name)
        if not entries:
            break
def delete_node(client: Any, node_id: str, class_prefix: str) -> None:
    """Delete every entry whose doc_id matches the given node id."""
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )
    matches = parse_get_response(query.do())[class_name]
    for match in matches:
        client.data_object.delete(match["_additional"]["id"], class_name)
def get_by_node_id(client: Any, node_id: str, class_prefix: str) -> Optional[Dict]:
    """Fetch the first entry whose doc_id matches the given node id.

    Docstring fixed: this function retrieves an entry (it previously said
    "Delete entry", copied from the delete helpers above).

    :param client: weaviate client (validated before use)
    :param node_id: node id to look up
    :param class_prefix: prefix of the Weaviate class to search
    :return: the raw entry dict, or None when no entry matches
    """
    validate_client(client)
    class_name = _class_name(class_prefix)
    where_filter = {
        "path": ["doc_id"],
        "operator": "Equal",
        "valueString": node_id,
    }
    query = (
        client.query.get(class_name).with_additional(["id"]).with_where(where_filter)
    )
    query_result = query.do()
    parsed_result = parse_get_response(query_result)
    entries = parsed_result[class_name]
    if not entries:
        return None
    return entries[0]