orm filter -> where (#22801)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Asuka Minato
2025-07-24 01:57:45 +09:00
committed by GitHub
parent e64e7563f6
commit ef51678c73
161 changed files with 828 additions and 857 deletions

View File

@@ -93,11 +93,11 @@ class Jieba(BaseKeyword):
documents = []
for chunk_index in sorted_chunk_indices:
segment_query = db.session.query(DocumentSegment).filter(
segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment:
@@ -214,7 +214,7 @@ class Jieba(BaseKeyword):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
)
if document_segment:

View File

@@ -127,7 +127,7 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
metadata_condition = (
@@ -145,7 +145,7 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
with Session(db.engine) as session:
return session.query(Dataset).filter(Dataset.id == dataset_id).first()
return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
@@ -294,7 +294,7 @@ class RetrievalService:
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
.filter(DatasetDocument.id.in_(document_ids))
.where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
}
@@ -318,7 +318,7 @@ class RetrievalService:
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
if not child_chunk:
@@ -326,7 +326,7 @@ class RetrievalService:
segment = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
@@ -381,7 +381,7 @@ class RetrievalService:
segment = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",

View File

@@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:

View File

@@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
@@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)

View File

@@ -47,7 +47,7 @@ class Vector:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:

View File

@@ -42,7 +42,7 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
output = {}
@@ -63,7 +63,7 @@ class DatasetDocumentStore:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
.where(DocumentSegment.document_id == self._document_id)
.scalar()
)
@@ -147,7 +147,7 @@ class DatasetDocumentStore:
segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
@@ -230,7 +230,7 @@ class DatasetDocumentStore:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
)

View File

@@ -366,7 +366,7 @@ class NotionExtractor(BaseExtractor):
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",

View File

@@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
.where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
@@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
@@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()
def retrieve(

View File

@@ -135,7 +135,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
@@ -242,7 +242,7 @@ class DatasetRetrieval:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
@@ -327,7 +327,7 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
results = []
if dataset.provider == "external":
@@ -516,14 +516,14 @@ class DatasetRetrieval:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.filter(DatasetDocument.id == document.metadata["document_id"])
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@@ -533,7 +533,7 @@ class DatasetRetrieval:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
@@ -541,13 +541,13 @@ class DatasetRetrieval:
)
db.session.commit()
else:
query = db.session.query(DocumentSegment).filter(
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
@@ -600,7 +600,7 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
@@ -685,7 +685,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
@@ -862,7 +862,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
document_query = db.session.query(DatasetDocument).filter(
document_query = db.session.query(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -930,9 +930,9 @@ class DatasetRetrieval:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@@ -958,7 +958,7 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None: