refactor: port reqparse to BaseModel (#28993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
@@ -27,6 +29,12 @@ from services.hit_testing_service import HitTestingService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HitTestingPayload(BaseModel):
|
||||
query: str = Field(max_length=250)
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DatasetsHitTestingBase:
|
||||
@staticmethod
|
||||
def get_and_validate_dataset(dataset_id: str):
|
||||
@@ -43,19 +51,9 @@ class DatasetsHitTestingBase:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def hit_testing_args_check(args):
|
||||
def hit_testing_args_check(args: dict[str, Any]):
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
@staticmethod
|
||||
def parse_args():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("query", type=str, location="json")
|
||||
.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||
.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@staticmethod
|
||||
def perform_hit_testing(dataset, args):
|
||||
assert isinstance(current_user, Account)
|
||||
|
||||
Reference in New Issue
Block a user