refactor: port reqparse to Pydantic model (#28913)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato
2025-11-30 16:09:42 +09:00
committed by GitHub
parent bb096f4ae3
commit 247069c7e9
19 changed files with 1013 additions and 1369 deletions

View File

@@ -1,7 +1,9 @@
import logging
from typing import Literal
from flask_restx import Resource, fields, marshal_with, reqparse
from flask_restx.inputs import int_range
from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
@@ -33,6 +35,67 @@ from services.errors.message import MessageNotExistsError, SuggestedQuestionsAft
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
conversation_id: str = Field(..., description="Conversation ID")
first_id: str | None = Field(default=None, description="First message ID for pagination")
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return (1-100)")
@field_validator("first_id", mode="before")
@classmethod
def empty_to_none(cls, value: str | None) -> str | None:
if value == "":
return None
return value
@field_validator("conversation_id", "first_id")
@classmethod
def validate_uuid(cls, value: str | None) -> str | None:
if value is None:
return value
return uuid_value(value)
class MessageFeedbackPayload(BaseModel):
message_id: str = Field(..., description="Message ID")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
@field_validator("message_id")
@classmethod
def validate_message_id(cls, value: str) -> str:
return uuid_value(value)
class FeedbackExportQuery(BaseModel):
from_source: Literal["user", "admin"] | None = Field(default=None, description="Filter by feedback source")
rating: Literal["like", "dislike"] | None = Field(default=None, description="Filter by rating")
has_comment: bool | None = Field(default=None, description="Only include feedback with comments")
start_date: str | None = Field(default=None, description="Start date (YYYY-MM-DD)")
end_date: str | None = Field(default=None, description="End date (YYYY-MM-DD)")
format: Literal["csv", "json"] = Field(default="csv", description="Export format")
@field_validator("has_comment", mode="before")
@classmethod
def parse_bool(cls, value: bool | str | None) -> bool | None:
if isinstance(value, bool) or value is None:
return value
lowered = value.lower()
if lowered in {"true", "1", "yes", "on"}:
return True
if lowered in {"false", "0", "no", "off"}:
return False
raise ValueError("has_comment must be a boolean value")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -157,12 +220,7 @@ class ChatMessageListApi(Resource):
@console_ns.doc("list_chat_messages")
@console_ns.doc(description="Get chat messages for a conversation with pagination")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.parser()
.add_argument("conversation_id", type=str, required=True, location="args", help="Conversation ID")
.add_argument("first_id", type=str, location="args", help="First message ID for pagination")
.add_argument("limit", type=int, location="args", default=20, help="Number of messages to return (1-100)")
)
@console_ns.expect(console_ns.models[ChatMessagesQuery.__name__])
@console_ns.response(200, "Success", message_infinite_scroll_pagination_model)
@console_ns.response(404, "Conversation not found")
@login_required
@@ -172,27 +230,21 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
parser = (
reqparse.RequestParser()
.add_argument("conversation_id", required=True, type=uuid_value, location="args")
.add_argument("first_id", type=uuid_value, location="args")
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
)
args = parser.parse_args()
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
.where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id)
.first()
)
if not conversation:
raise NotFound("Conversation Not Exists.")
if args["first_id"]:
if args.first_id:
first_message = (
db.session.query(Message)
.where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.where(Message.conversation_id == conversation.id, Message.id == args.first_id)
.first()
)
@@ -207,7 +259,7 @@ class ChatMessageListApi(Resource):
Message.id != first_message.id,
)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.limit(args.limit)
.all()
)
else:
@@ -215,12 +267,12 @@ class ChatMessageListApi(Resource):
db.session.query(Message)
.where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.limit(args.limit)
.all()
)
# Initialize has_more based on whether we have a full page
if len(history_messages) == args["limit"]:
if len(history_messages) == args.limit:
current_page_first_message = history_messages[-1]
# Check if there are more messages before the current page
has_more = db.session.scalar(
@@ -238,7 +290,7 @@ class ChatMessageListApi(Resource):
history_messages = list(reversed(history_messages))
return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@console_ns.route("/apps/<uuid:app_id>/feedbacks")
@@ -246,15 +298,7 @@ class MessageFeedbackApi(Resource):
@console_ns.doc("create_message_feedback")
@console_ns.doc(description="Create or update message feedback (like/dislike)")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(
console_ns.model(
"MessageFeedbackRequest",
{
"message_id": fields.String(required=True, description="Message ID"),
"rating": fields.String(enum=["like", "dislike"], description="Feedback rating"),
},
)
)
@console_ns.expect(console_ns.models[MessageFeedbackPayload.__name__])
@console_ns.response(200, "Feedback updated successfully")
@console_ns.response(404, "Message not found")
@console_ns.response(403, "Insufficient permissions")
@@ -265,14 +309,9 @@ class MessageFeedbackApi(Resource):
def post(self, app_model):
current_user, _ = current_account_with_tenant()
parser = (
reqparse.RequestParser()
.add_argument("message_id", required=True, type=uuid_value, location="json")
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
)
args = parser.parse_args()
args = MessageFeedbackPayload.model_validate(console_ns.payload)
message_id = str(args["message_id"])
message_id = str(args.message_id)
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
@@ -281,18 +320,21 @@ class MessageFeedbackApi(Resource):
feedback = message.admin_feedback
if not args["rating"] and feedback:
if not args.rating and feedback:
db.session.delete(feedback)
elif args["rating"] and feedback:
feedback.rating = args["rating"]
elif not args["rating"] and not feedback:
elif args.rating and feedback:
feedback.rating = args.rating
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
rating_value = args.rating
if rating_value is None:
raise ValueError("rating is required to create feedback")
feedback = MessageFeedback(
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=args["rating"],
rating=rating_value,
from_source="admin",
from_account_id=current_user.id,
)
@@ -369,24 +411,12 @@ class MessageSuggestedQuestionApi(Resource):
return {"data": questions}
# Shared parser for feedback export (used for both documentation and runtime parsing)
feedback_export_parser = (
console_ns.parser()
.add_argument("from_source", type=str, choices=["user", "admin"], location="args", help="Filter by feedback source")
.add_argument("rating", type=str, choices=["like", "dislike"], location="args", help="Filter by rating")
.add_argument("has_comment", type=bool, location="args", help="Only include feedback with comments")
.add_argument("start_date", type=str, location="args", help="Start date (YYYY-MM-DD)")
.add_argument("end_date", type=str, location="args", help="End date (YYYY-MM-DD)")
.add_argument("format", type=str, choices=["csv", "json"], default="csv", location="args", help="Export format")
)
@console_ns.route("/apps/<uuid:app_id>/feedbacks/export")
class MessageFeedbackExportApi(Resource):
@console_ns.doc("export_feedbacks")
@console_ns.doc(description="Export user feedback data for Google Sheets")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(feedback_export_parser)
@console_ns.expect(console_ns.models[FeedbackExportQuery.__name__])
@console_ns.response(200, "Feedback data exported successfully")
@console_ns.response(400, "Invalid parameters")
@console_ns.response(500, "Internal server error")
@@ -395,7 +425,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = feedback_export_parser.parse_args()
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
@@ -403,12 +433,12 @@ class MessageFeedbackExportApi(Resource):
try:
export_data = FeedbackService.export_feedbacks(
app_id=app_model.id,
from_source=args.get("from_source"),
rating=args.get("rating"),
has_comment=args.get("has_comment"),
start_date=args.get("start_date"),
end_date=args.get("end_date"),
format_type=args.get("format", "csv"),
from_source=args.from_source,
rating=args.rating,
has_comment=args.has_comment,
start_date=args.start_date,
end_date=args.end_date,
format_type=args.format,
)
return export_data