chore: apply ruff's pyupgrade linter rules to modernize Python code with targeted version (#2419)

This commit is contained in:
Bowen Liang
2024-02-09 15:21:33 +08:00
committed by GitHub
parent 589099a005
commit 063191889d
246 changed files with 912 additions and 937 deletions

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List
from typing import Any
from langchain.schema import BaseRetriever, Document
@@ -53,7 +53,7 @@ class BaseIndex(ABC):
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
raise NotImplementedError
def delete(self) -> None:

View File

@@ -1,5 +1,4 @@
import re
from typing import Set
import jieba
from jieba.analyse import default_tfidf
@@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
def __init__(self):
default_tfidf.stop_words = STOPWORDS
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf."""
keywords = jieba.analyse.extract_tags(
sentence=text,
@@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:
return set(self._expand_tokens_with_subtokens(keywords))
def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords."""
results = set()
for token in tokens:

View File

@@ -1,6 +1,6 @@
import json
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from langchain.schema import BaseRetriever, Document
from pydantic import BaseModel, Extra, Field
@@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex):
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
keyword_table = self._get_dataset_keyword_table()
search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
@@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex):
keywords = keyword_table_handler.extract_keywords(query)
# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[str, int] = defaultdict(int)
chunk_indices_count: dict[str, int] = defaultdict(int)
keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
for keyword in keywords:
for node_id in keyword_table[keyword]:
@@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex):
return sorted_chunk_indices[: k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = db.session.query(DocumentSegment).filter(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id == node_id
@@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex):
document_segment.keywords = keywords
db.session.commit()
def create_segment_keywords(self, node_id: str, keywords: List[str]):
def create_segment_keywords(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
self._update_segment_keywords(self.dataset.id, node_id, keywords)
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
@@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex):
keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
self._save_dataset_keyword_table(keyword_table)
def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
keyword_table = self._get_dataset_keyword_table()
keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
self._save_dataset_keyword_table(keyword_table)
@@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def get_relevant_documents(self, query: str) -> List[Document]:
def get_relevant_documents(self, query: str) -> list[Document]:
"""Get documents relevant for a query.
Args:
@@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
"""
return self.index.search(query, **self.search_kwargs)
async def aget_relevant_documents(self, query: str) -> List[Document]:
async def aget_relevant_documents(self, query: str) -> list[Document]:
raise NotImplementedError("KeywordTableRetriever does not support async")

View File

@@ -1,7 +1,7 @@
import json
import logging
from abc import abstractmethod
from typing import Any, List, cast
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex):
def search_by_full_text_index(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
raise NotImplementedError
def search(
self, query: str,
**kwargs: Any
) -> List[Document]:
) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)

View File

@@ -1,4 +1,4 @@
from typing import Any, List, cast
from typing import Any, cast
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
@@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex):
],
))
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, List, Optional, cast
from typing import Any, Optional, cast
import qdrant_client
from langchain.embeddings.base import Embeddings
@@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex):
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)

View File

@@ -1,4 +1,4 @@
from typing import Any, List, Optional, cast
from typing import Any, Optional, cast
import requests
import weaviate
@@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
return False
def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)