chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from typing import Any
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
|
||||
@@ -53,7 +53,7 @@ class BaseIndex(ABC):
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete(self) -> None:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
import jieba
|
||||
from jieba.analyse import default_tfidf
|
||||
@@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
|
||||
def __init__(self):
|
||||
default_tfidf.stop_words = STOPWORDS
|
||||
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
|
||||
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
|
||||
"""Extract keywords with JIEBA tfidf."""
|
||||
keywords = jieba.analyse.extract_tags(
|
||||
sentence=text,
|
||||
@@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:
|
||||
|
||||
return set(self._expand_tokens_with_subtokens(keywords))
|
||||
|
||||
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
|
||||
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
|
||||
"""Get subtokens from a list of tokens., filtering for stopwords."""
|
||||
results = set()
|
||||
for token in tokens:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from pydantic import BaseModel, Extra, Field
|
||||
@@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
|
||||
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
|
||||
@@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keywords = keyword_table_handler.extract_keywords(query)
|
||||
|
||||
# go through text chunks in order of most matching keywords
|
||||
chunk_indices_count: Dict[str, int] = defaultdict(int)
|
||||
chunk_indices_count: dict[str, int] = defaultdict(int)
|
||||
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
|
||||
for keyword in keywords:
|
||||
for node_id in keyword_table[keyword]:
|
||||
@@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
|
||||
return sorted_chunk_indices[: k]
|
||||
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
|
||||
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
|
||||
document_segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.index_node_id == node_id
|
||||
@@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
document_segment.keywords = keywords
|
||||
db.session.commit()
|
||||
|
||||
def create_segment_keywords(self, node_id: str, keywords: List[str]):
|
||||
def create_segment_keywords(self, node_id: str, keywords: list[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
self._update_segment_keywords(self.dataset.id, node_id, keywords)
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
@@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex):
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
|
||||
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
|
||||
keyword_table = self._get_dataset_keyword_table()
|
||||
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
|
||||
self._save_dataset_keyword_table(keyword_table)
|
||||
@@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||
def get_relevant_documents(self, query: str) -> list[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
Args:
|
||||
@@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
|
||||
"""
|
||||
return self.index.search(query, **self.search_kwargs)
|
||||
|
||||
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||
async def aget_relevant_documents(self, query: str) -> list[Document]:
|
||||
raise NotImplementedError("KeywordTableRetriever does not support async")
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
@@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex):
|
||||
def search_by_full_text_index(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
raise NotImplementedError
|
||||
|
||||
def search(
|
||||
self, query: str,
|
||||
**kwargs: Any
|
||||
) -> List[Document]:
|
||||
) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
@@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
# milvus/zilliz doesn't support bm25 search
|
||||
return []
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import qdrant_client
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
import requests
|
||||
import weaviate
|
||||
@@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
|
||||
|
||||
return False
|
||||
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
|
||||
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user