orm filter -> where (#22801)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Asuka Minato
2025-07-24 01:57:45 +09:00
committed by GitHub
parent e64e7563f6
commit ef51678c73
161 changed files with 828 additions and 857 deletions

View File

@@ -80,7 +80,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
@@ -92,14 +92,14 @@ class DatasetService:
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
# only show datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(Dataset.id.in_(permitted_dataset_ids))
query = query.where(Dataset.id.in_(permitted_dataset_ids))
else:
return [], 0
else:
if user.current_role != TenantAccountRole.OWNER or not include_all:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
query = query.filter(
query = query.where(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
@@ -112,7 +112,7 @@ class DatasetService:
)
)
else:
query = query.filter(
query = query.where(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
@@ -122,15 +122,15 @@ class DatasetService:
)
else:
# if no user, only show datasets that are shared with all team members
query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
query = query.filter(Dataset.name.ilike(f"%{search}%"))
query = query.where(Dataset.name.ilike(f"%{search}%"))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
if target_ids:
query = query.filter(Dataset.id.in_(target_ids))
query = query.where(Dataset.id.in_(target_ids))
else:
return [], 0
@@ -143,7 +143,7 @@ class DatasetService:
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == dataset_id)
.where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
@@ -158,7 +158,7 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
@@ -697,7 +697,7 @@ class DatasetService:
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
@@ -714,7 +714,7 @@ class DatasetService:
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.filter(
.where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
@@ -843,7 +843,7 @@ class DocumentService:
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
if document_id:
document = (
db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
return document
else:
@@ -851,7 +851,7 @@ class DocumentService:
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
document = db.session.query(Document).filter(Document.id == document_id).first()
document = db.session.query(Document).where(Document.id == document_id).first()
return document
@@ -859,7 +859,7 @@ class DocumentService:
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
@@ -873,7 +873,7 @@ class DocumentService:
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
@@ -886,7 +886,7 @@ class DocumentService:
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
@@ -901,7 +901,7 @@ class DocumentService:
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
return documents
@@ -910,7 +910,7 @@ class DocumentService:
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = (
db.session.query(Document)
.filter(
.where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
@@ -922,7 +922,7 @@ class DocumentService:
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none()
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
return file_detail
@staticmethod
@@ -950,7 +950,7 @@ class DocumentService:
@staticmethod
def delete_documents(dataset: Dataset, document_ids: list[str]):
documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
@@ -1189,7 +1189,7 @@ class DocumentService:
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -1270,7 +1270,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -1413,7 +1413,7 @@ class DocumentService:
def get_tenant_documents_count():
documents_count = (
db.session.query(Document)
.filter(
.where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
@@ -1469,7 +1469,7 @@ class DocumentService:
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
.filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -1489,7 +1489,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
.where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -2005,7 +2005,7 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == document.id)
.where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
@@ -2043,7 +2043,7 @@ class SegmentService:
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
@classmethod
@@ -2062,7 +2062,7 @@ class SegmentService:
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == document.id)
.where(DocumentSegment.document_id == document.id)
.scalar()
)
pre_segment_data_list = []
@@ -2201,7 +2201,7 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -2276,7 +2276,7 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -2295,7 +2295,7 @@ class SegmentService:
segment.status = "error"
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
return new_segment
@classmethod
@@ -2321,7 +2321,7 @@ class SegmentService:
index_node_ids = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2332,7 +2332,7 @@ class SegmentService:
index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete()
db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()
@classmethod
@@ -2340,7 +2340,7 @@ class SegmentService:
if action == "enable":
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2367,7 +2367,7 @@ class SegmentService:
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
.filter(
.where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2404,7 +2404,7 @@ class SegmentService:
index_node_hash = helper.generate_text_hash(content)
child_chunk_count = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
@@ -2414,7 +2414,7 @@ class SegmentService:
)
max_position = (
db.session.query(func.max(ChildChunk.position))
.filter(
.where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
@@ -2457,7 +2457,7 @@ class SegmentService:
) -> list[ChildChunk]:
child_chunks = (
db.session.query(ChildChunk)
.filter(
.where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
@@ -2578,7 +2578,7 @@ class SegmentService:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None
@@ -2594,15 +2594,15 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
query = select(DocumentSegment).filter(
query = select(DocumentSegment).where(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -2615,7 +2615,7 @@ class SegmentService:
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -2647,7 +2647,7 @@ class SegmentService:
# check segment
segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
@@ -2664,7 +2664,7 @@ class SegmentService:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None
@@ -2677,7 +2677,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
.where(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type,
@@ -2703,7 +2703,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.filter(
.where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
)
.order_by(DatasetCollectionBinding.created_at)
@@ -2722,7 +2722,7 @@ class DatasetPermissionService:
db.session.query(
DatasetPermission.account_id,
)
.filter(DatasetPermission.dataset_id == dataset_id)
.where(DatasetPermission.dataset_id == dataset_id)
.all()
)
@@ -2735,7 +2735,7 @@ class DatasetPermissionService:
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
permissions = []
for user in user_list:
permission = DatasetPermission(
@@ -2771,7 +2771,7 @@ class DatasetPermissionService:
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.commit()
except Exception as e:
db.session.rollback()