refactor: port reqparse to BaseModel (#28993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define parsers for workflow APIs
|
||||
workflow_run_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
)
|
||||
|
||||
workflow_log_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
.add_argument("created_at__before", type=str, location="args")
|
||||
.add_argument("created_at__after", type=str, location="args")
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class WorkflowLogQuery(BaseModel):
|
||||
keyword: str | None = None
|
||||
status: Literal["succeeded", "failed", "stopped"] | None = None
|
||||
created_at__before: str | None = None
|
||||
created_at__after: str | None = None
|
||||
created_by_end_user_session_id: str | None = None
|
||||
created_by_account: str | None = None
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
@@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow")
|
||||
@service_api_ns.doc(description="Execute a workflow")
|
||||
@service_api_ns.doc(
|
||||
@@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/<string:workflow_id>/run")
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow_by_id")
|
||||
@service_api_ns.doc(description="Execute a specific workflow by ID")
|
||||
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
|
||||
@@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
@@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@service_api_ns.expect(workflow_log_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
|
||||
@service_api_ns.doc("get_workflow_logs")
|
||||
@service_api_ns.doc(description="Get workflow execution logs")
|
||||
@service_api_ns.doc(
|
||||
@@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
Returns paginated workflow execution logs with filtering options.
|
||||
"""
|
||||
args = workflow_log_parser.parse_args()
|
||||
args = WorkflowLogQuery.model_validate(request.args.to_dict())
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
|
||||
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword=args.keyword,
|
||||
status=args.status,
|
||||
created_at_before=args.created_at__before,
|
||||
created_at_after=args.created_at__after,
|
||||
status=status,
|
||||
created_at_before=created_at_before,
|
||||
created_at_after=created_at_after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
|
||||
Reference in New Issue
Block a user