refactor: port reqparse to BaseModel (#28993)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-12-08 15:31:19 +09:00
committed by GitHub
parent 2f96374837
commit 05fe92a541
44 changed files with 1531 additions and 1894 deletions

View File

@@ -1,8 +1,12 @@
from typing import Any
from flask import request
from flask_restx import marshal, reqparse
from flask_restx import marshal
from pydantic import BaseModel, Field
from werkzeug.exceptions import NotFound
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.service_api import service_api_ns
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
@@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
# Define parsers for segment operations
segment_create_parser = reqparse.RequestParser().add_argument(
"segments", type=list, required=False, nullable=True, location="json"
)
segment_list_parser = (
reqparse.RequestParser()
.add_argument("status", type=str, action="append", default=[], location="args")
.add_argument("keyword", type=str, default=None, location="args")
)
class SegmentCreatePayload(BaseModel):
segments: list[dict[str, Any]] | None = None
segment_update_parser = reqparse.RequestParser().add_argument(
"segment", type=dict, required=False, nullable=True, location="json"
)
child_chunk_create_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
)
class SegmentListQuery(BaseModel):
status: list[str] = Field(default_factory=list)
keyword: str | None = None
child_chunk_list_parser = (
reqparse.RequestParser()
.add_argument("limit", type=int, default=20, location="args")
.add_argument("keyword", type=str, default=None, location="args")
.add_argument("page", type=int, default=1, location="args")
)
child_chunk_update_parser = reqparse.RequestParser().add_argument(
"content", type=str, required=True, nullable=False, location="json"
class SegmentUpdatePayload(BaseModel):
segment: SegmentUpdateArgs
class ChildChunkCreatePayload(BaseModel):
content: str
class ChildChunkListQuery(BaseModel):
limit: int = Field(default=20, ge=1)
keyword: str | None = None
page: int = Field(default=1, ge=1)
class ChildChunkUpdatePayload(BaseModel):
content: str
register_schema_models(
service_api_ns,
SegmentCreatePayload,
SegmentListQuery,
SegmentUpdatePayload,
ChildChunkCreatePayload,
ChildChunkListQuery,
ChildChunkUpdatePayload,
)
@@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
class SegmentApi(DatasetApiResource):
"""Resource for segments."""
@service_api_ns.expect(segment_create_parser)
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
@service_api_ns.doc("create_segments")
@service_api_ns.doc(description="Create segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
args = segment_create_parser.parse_args()
if args["segments"] is not None:
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
if payload.segments is not None:
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
if segments_limit > 0 and len(args["segments"]) > segments_limit:
if segments_limit > 0 and len(payload.segments) > segments_limit:
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
for args_item in args["segments"]:
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
else:
return {"error": "Segments is required"}, 400
@service_api_ns.expect(segment_list_parser)
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
@service_api_ns.doc("list_segments")
@service_api_ns.doc(description="List segments in a document")
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
args = segment_list_parser.parse_args()
args = SegmentListQuery.model_validate(
{
"status": request.args.getlist("status"),
"keyword": request.args.get("keyword"),
}
)
segments, total = SegmentService.get_segments(
document_id=document_id,
tenant_id=current_tenant_id,
status_list=args["status"],
keyword=args["keyword"],
status_list=args.status,
keyword=args.keyword,
page=page,
limit=limit,
)
@@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
SegmentService.delete_segment(segment, document, dataset)
return 204
@service_api_ns.expect(segment_update_parser)
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
@service_api_ns.doc("update_segment")
@service_api_ns.doc(description="Update a specific segment")
@service_api_ns.doc(
@@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
# validate args
args = segment_update_parser.parse_args()
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
@service_api_ns.doc("get_segment")
@@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
class ChildChunkApi(DatasetApiResource):
"""Resource for child chunks."""
@service_api_ns.expect(child_chunk_create_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
@service_api_ns.doc("create_child_chunk")
@service_api_ns.doc(description="Create a new child chunk for a segment")
@service_api_ns.doc(
@@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
raise ProviderNotInitializeError(ex.description)
# validate args
args = child_chunk_create_parser.parse_args()
payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@service_api_ns.expect(child_chunk_list_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
@service_api_ns.doc("list_child_chunks")
@service_api_ns.doc(description="List child chunks for a segment")
@service_api_ns.doc(
@@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
args = child_chunk_list_parser.parse_args()
args = ChildChunkListQuery.model_validate(
{
"limit": request.args.get("limit", default=20, type=int),
"keyword": request.args.get("keyword"),
"page": request.args.get("page", default=1, type=int),
}
)
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
page = args.page
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
@@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
return 204
@service_api_ns.expect(child_chunk_update_parser)
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
@service_api_ns.doc("update_child_chunk")
@service_api_ns.doc(description="Update a specific child chunk")
@service_api_ns.doc(
@@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
raise NotFound("Child chunk not found.")
# validate args
args = child_chunk_update_parser.parse_args()
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
try:
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))