Fix: Enable Pyright and Fix Typing Errors in Datasets Controller (#26425)

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-09-30 00:38:59 +09:00
committed by GitHub
parent bbdcbac544
commit f79d8baf63
10 changed files with 53 additions and 75 deletions

View File

@@ -1,4 +1,5 @@
import flask_restx
from typing import Any, cast
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -31,12 +32,13 @@ from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fi
from fields.document_fields import document_status_fields
from libs.login import login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
def _validate_name(name):
def _validate_name(name: str) -> str:
if not name or len(name) < 1 or len(name) > 40:
raise ValueError("Name must be between 1 to 40 characters.")
return name
@@ -92,7 +94,7 @@ class DatasetListApi(Resource):
for embedding_model in embedding_models:
model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
data = marshal(datasets, dataset_detail_fields)
data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
for item in data:
# convert embedding_model_provider to plugin standard format
if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
@@ -192,7 +194,7 @@ class DatasetListApi(Resource):
name=args["name"],
description=args["description"],
indexing_technique=args["indexing_technique"],
account=current_user,
account=cast(Account, current_user),
permission=DatasetPermissionEnum.ONLY_ME,
provider=args["provider"],
external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -224,7 +226,7 @@ class DatasetApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
data = marshal(dataset, dataset_detail_fields)
data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
if dataset.indexing_technique == "high_quality":
if dataset.embedding_model_provider:
provider_id = ModelProviderID(dataset.embedding_model_provider)
@@ -369,7 +371,7 @@ class DatasetApi(Resource):
if dataset is None:
raise NotFound("Dataset not found.")
result_data = marshal(dataset, dataset_detail_fields)
result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
tenant_id = current_user.current_tenant_id
if data.get("partial_member_list") and data.get("permission") == "partial_members":
@@ -688,7 +690,7 @@ class DatasetApiKeyApi(Resource):
)
if current_key_count >= self.max_keys:
flask_restx.abort(
api.abort(
400,
message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
code="max_keys_exceeded",
@@ -733,7 +735,7 @@ class DatasetApiDeleteApi(Resource):
)
if key is None:
flask_restx.abort(404, message="API key not found")
api.abort(404, message="API key not found")
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()