Fix basedpyright type errors (#25435)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,19 @@ if TYPE_CHECKING:
|
||||
MetadataFilter = Union[DictFilter, common_types.Filter]
|
||||
|
||||
|
||||
class PathQdrantParams(BaseModel):
|
||||
path: str
|
||||
|
||||
|
||||
class UrlQdrantParams(BaseModel):
|
||||
url: str
|
||||
api_key: Optional[str]
|
||||
timeout: float
|
||||
verify: bool
|
||||
grpc_port: int
|
||||
prefer_grpc: bool
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
|
||||
endpoint: str
|
||||
api_key: Optional[str] = None
|
||||
@@ -50,7 +63,7 @@ class QdrantConfig(BaseModel):
|
||||
replication_factor: int = 1
|
||||
write_consistency_factor: int = 1
|
||||
|
||||
def to_qdrant_params(self):
|
||||
def to_qdrant_params(self) -> PathQdrantParams | UrlQdrantParams:
|
||||
if self.endpoint and self.endpoint.startswith("path:"):
|
||||
path = self.endpoint.replace("path:", "")
|
||||
if not os.path.isabs(path):
|
||||
@@ -58,23 +71,23 @@ class QdrantConfig(BaseModel):
|
||||
raise ValueError("Root path is not set")
|
||||
path = os.path.join(self.root_path, path)
|
||||
|
||||
return {"path": path}
|
||||
return PathQdrantParams(path=path)
|
||||
else:
|
||||
return {
|
||||
"url": self.endpoint,
|
||||
"api_key": self.api_key,
|
||||
"timeout": self.timeout,
|
||||
"verify": self.endpoint.startswith("https"),
|
||||
"grpc_port": self.grpc_port,
|
||||
"prefer_grpc": self.prefer_grpc,
|
||||
}
|
||||
return UrlQdrantParams(
|
||||
url=self.endpoint,
|
||||
api_key=self.api_key,
|
||||
timeout=self.timeout,
|
||||
verify=self.endpoint.startswith("https"),
|
||||
grpc_port=self.grpc_port,
|
||||
prefer_grpc=self.prefer_grpc,
|
||||
)
|
||||
|
||||
|
||||
class QdrantVector(BaseVector):
|
||||
def __init__(self, collection_name: str, group_id: str, config: QdrantConfig, distance_func: str = "Cosine"):
|
||||
super().__init__(collection_name)
|
||||
self._client_config = config
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params())
|
||||
self._client = qdrant_client.QdrantClient(**self._client_config.to_qdrant_params().model_dump())
|
||||
self._distance_func = distance_func.upper()
|
||||
self._group_id = group_id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user