Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yongtao Huang
2025-09-02 10:30:19 +08:00
committed by GitHub
parent 2e89d29c87
commit be3af1e234
33 changed files with 226 additions and 260 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)

View File

@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
@@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
@@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue
@@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue

View File

@@ -18,6 +18,7 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -445,11 +446,8 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:

View File

@@ -20,6 +20,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@@ -3,6 +3,8 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Optional
from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -45,11 +47,10 @@ class Vector:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT