Migrate SQLAlchemy from 1.x to 2.0 with automated and manual adjustments (#23224)

Co-authored-by: Yongtao Huang <99629139+hyongtao-db@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yongtao Huang
2025-09-02 10:30:19 +08:00
committed by GitHub
parent 2e89d29c87
commit be3af1e234
33 changed files with 226 additions and 260 deletions

View File

@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
@@ -135,7 +135,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
@@ -240,15 +241,12 @@ class DatasetRetrieval:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
@@ -327,7 +325,8 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
results = []
if dataset.provider == "external":
@@ -514,22 +513,18 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)
@@ -600,7 +595,8 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if not dataset:
return []
@@ -685,7 +681,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
@@ -958,7 +955,8 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(metadata_stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None: