use deco to avoid current_user (#26077)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-10-16 15:45:51 +09:00
committed by GitHub
parent bd01af6415
commit cced33d068
109 changed files with 526 additions and 788 deletions

View File

@@ -1,7 +1,6 @@
from typing import Any, cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound
@@ -30,10 +29,9 @@ from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from libs.login import current_account_with_tenant, login_required
from libs.validators import validate_description_length
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
@@ -138,6 +136,7 @@ class DatasetListApi(Resource):
@account_initialization_required
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
@@ -146,15 +145,15 @@ class DatasetListApi(Resource):
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_user.current_tenant_id, current_user, search, tag_ids, include_all
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
)
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -251,6 +250,7 @@ class DatasetListApi(Resource):
required=False,
)
args = parser.parse_args()
current_user, current_tenant_id = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
@@ -258,11 +258,11 @@ class DatasetListApi(Resource):
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=cast(Account, current_user),
account=current_user,
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -286,6 +286,7 @@ class DatasetApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -305,7 +306,7 @@ class DatasetApi(Resource):
# check embedding setting
provider_manager = ProviderManager()
configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
configurations = provider_manager.get_configurations(tenant_id=current_tenant_id)
embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
@@ -418,6 +419,7 @@ class DatasetApi(Resource):
)
args = parser.parse_args()
data = request.get_json()
current_user, current_tenant_id = current_account_with_tenant()
# check embedding model setting
if (
@@ -440,7 +442,7 @@ class DatasetApi(Resource):
raise NotFound("Dataset not found.")
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id
tenant_id = current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
DatasetPermissionService.update_partial_member_list(
@@ -464,9 +466,10 @@ class DatasetApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
def delete(self, dataset_id):
dataset_id_str = str(dataset_id)
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.is_editor or current_user.is_dataset_operator):
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
raise Forbidden()
try:
@@ -505,6 +508,7 @@ class DatasetQueryApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -556,15 +560,14 @@ class DatasetIndexingEstimateApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
_, current_tenant_id = current_account_with_tenant()
# validate args
DocumentService.estimate_args_validate(args)
extract_settings = []
if args["info_list"]["data_source_type"] == "upload_file":
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)
)
select(UploadFile).where(UploadFile.tenant_id == current_tenant_id, UploadFile.id.in_(file_ids))
).all()
if file_details is None:
@@ -592,7 +595,7 @@ class DatasetIndexingEstimateApi(Resource):
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
}
),
document_model=args["doc_form"],
@@ -608,7 +611,7 @@ class DatasetIndexingEstimateApi(Resource):
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"tenant_id": current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
@@ -621,7 +624,7 @@ class DatasetIndexingEstimateApi(Resource):
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
current_tenant_id,
extract_settings,
args["process_rule"],
args["doc_form"],
@@ -652,6 +655,7 @@ class DatasetRelatedAppListApi(Resource):
@account_initialization_required
@marshal_with(related_app_list)
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@@ -683,11 +687,10 @@ class DatasetIndexingStatusApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
_, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
documents = db.session.scalars(
select(Document).where(
Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id
)
select(Document).where(Document.dataset_id == dataset_id, Document.tenant_id == current_tenant_id)
).all()
documents_status = []
for document in documents:
@@ -739,10 +742,9 @@ class DatasetApiKeyApi(Resource):
@account_initialization_required
@marshal_with(api_key_list)
def get(self):
_, current_tenant_id = current_account_with_tenant()
keys = db.session.scalars(
select(ApiToken).where(
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
)
select(ApiToken).where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
).all()
return {"items": keys}
@@ -752,12 +754,13 @@ class DatasetApiKeyApi(Resource):
@marshal_with(api_key_fields)
def post(self):
# The role of the current user in the ta table must be admin or owner
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_admin_or_owner:
raise Forbidden()
current_key_count = (
db.session.query(ApiToken)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
.count()
)
@@ -770,7 +773,7 @@ class DatasetApiKeyApi(Resource):
key = ApiToken.generate_api_key(self.token_prefix, 24)
api_token = ApiToken()
api_token.tenant_id = current_user.current_tenant_id
api_token.tenant_id = current_tenant_id
api_token.token = key
api_token.type = self.resource_type
db.session.add(api_token)
@@ -790,6 +793,7 @@ class DatasetApiDeleteApi(Resource):
@login_required
@account_initialization_required
def delete(self, api_key_id):
current_user, current_tenant_id = current_account_with_tenant()
api_key_id = str(api_key_id)
# The role of the current user in the ta table must be admin or owner
@@ -799,7 +803,7 @@ class DatasetApiDeleteApi(Resource):
key = (
db.session.query(ApiToken)
.where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.tenant_id == current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
)
@@ -898,6 +902,7 @@ class DatasetPermissionUserListApi(Resource):
@login_required
@account_initialization_required
def get(self, dataset_id):
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None: