refactor: port reqparse to BaseModel (#28993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Define parsers for annotation API
|
||||
annotation_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
|
||||
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
|
||||
)
|
||||
|
||||
annotation_reply_action_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
|
||||
)
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
|
||||
)
|
||||
class AnnotationCreatePayload(BaseModel):
|
||||
question: str = Field(description="Annotation question")
|
||||
answer: str = Field(description="Annotation answer")
|
||||
|
||||
|
||||
class AnnotationReplyActionPayload(BaseModel):
|
||||
score_threshold: float = Field(description="Score threshold for annotation matching")
|
||||
embedding_provider_name: str = Field(description="Embedding provider name")
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@service_api_ns.expect(annotation_reply_action_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
|
||||
@service_api_ns.doc("annotation_reply_action")
|
||||
@service_api_ns.doc(description="Enable or disable annotation reply feature")
|
||||
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
|
||||
@@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = annotation_reply_action_parser.parse_args()
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
elif action == "disable":
|
||||
@@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
|
||||
"page": page,
|
||||
}
|
||||
|
||||
@service_api_ns.expect(annotation_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@service_api_ns.doc(description="Create a new annotation")
|
||||
@service_api_ns.doc(
|
||||
@@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = annotation_create_parser.parse_args()
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
return annotation, 201
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.expect(annotation_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("update_annotation")
|
||||
@service_api_ns.doc(description="Update an existing annotation")
|
||||
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})
|
||||
@@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = annotation_create_parser.parse_args()
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -84,19 +86,19 @@ class AudioApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
# Define parser for text-to-audio API
|
||||
text_to_audio_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
|
||||
.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
|
||||
.add_argument("text", type=str, location="json", help="Text to convert to audio")
|
||||
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
|
||||
)
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/text-to-audio")
|
||||
class TextApi(Resource):
|
||||
@service_api_ns.expect(text_to_audio_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
|
||||
@service_api_ns.doc("text_to_audio")
|
||||
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
|
||||
@service_api_ns.doc(
|
||||
@@ -114,11 +116,11 @@ class TextApi(Resource):
|
||||
Converts the provided text to audio using the specified voice.
|
||||
"""
|
||||
try:
|
||||
args = text_to_audio_parser.parse_args()
|
||||
payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = payload.message_id
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -26,7 +30,6 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
@@ -36,40 +39,31 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for completion API
|
||||
completion_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
|
||||
.add_argument("query", type=str, location="json", default="", help="The query string")
|
||||
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||
)
|
||||
class CompletionRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = Field(default="dev")
|
||||
|
||||
# Define parser for chat API
|
||||
chat_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
|
||||
.add_argument("query", type=str, required=True, location="json", help="The chat query")
|
||||
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||
.add_argument(
|
||||
"auto_generate_name",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=True,
|
||||
location="json",
|
||||
help="Auto generate conversation name",
|
||||
)
|
||||
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
|
||||
)
|
||||
|
||||
class ChatRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: UUID | None = None
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/completion-messages")
|
||||
class CompletionApi(Resource):
|
||||
@service_api_ns.expect(completion_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
|
||||
@service_api_ns.doc("create_completion")
|
||||
@service_api_ns.doc(description="Create a completion for the given prompt")
|
||||
@service_api_ns.doc(
|
||||
@@ -91,12 +85,13 @@ class CompletionApi(Resource):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
args = completion_parser.parse_args()
|
||||
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -162,7 +157,7 @@ class CompletionStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/chat-messages")
|
||||
class ChatApi(Resource):
|
||||
@service_api_ns.expect(chat_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
|
||||
@service_api_ns.doc("create_chat_message")
|
||||
@service_api_ns.doc(description="Send a message in a chat conversation")
|
||||
@service_api_ns.doc(
|
||||
@@ -186,13 +181,14 @@ class ChatApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = chat_parser.parse_args()
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
@@ -19,74 +24,44 @@ from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
# Define parsers for conversation APIs
|
||||
conversation_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of conversations to return",
|
||||
)
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
help="Sort order for conversations",
|
||||
)
|
||||
)
|
||||
|
||||
conversation_rename_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, location="json", help="New conversation name")
|
||||
.add_argument(
|
||||
"auto_generate",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False,
|
||||
location="json",
|
||||
help="Auto-generate conversation name",
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort order for conversations"
|
||||
)
|
||||
)
|
||||
|
||||
conversation_variables_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of variables to return",
|
||||
)
|
||||
)
|
||||
|
||||
conversation_variable_update_parser = reqparse.RequestParser().add_argument(
|
||||
# using lambda is for passing the already-typed value without modification
|
||||
# if no lambda, it will be converted to string
|
||||
# the string cannot be converted using json.loads
|
||||
"value",
|
||||
required=True,
|
||||
location="json",
|
||||
type=lambda x: x,
|
||||
help="New value for the conversation variable",
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str = Field(description="New conversation name")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
ConversationListQuery,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations")
|
||||
class ConversationApi(Resource):
|
||||
@service_api_ns.expect(conversation_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
|
||||
@service_api_ns.doc("list_conversations")
|
||||
@service_api_ns.doc(description="List all conversations for the current user")
|
||||
@service_api_ns.doc(
|
||||
@@ -107,7 +82,8 @@ class ConversationApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = conversation_list_parser.parse_args()
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
@@ -115,10 +91,10 @@ class ConversationApi(Resource):
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=last_id,
|
||||
limit=query_args.limit,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=args["sort_by"],
|
||||
sort_by=query_args.sort_by,
|
||||
)
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
@@ -155,7 +131,7 @@ class ConversationDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
class ConversationRenameApi(Resource):
|
||||
@service_api_ns.expect(conversation_rename_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
|
||||
@service_api_ns.doc("rename_conversation")
|
||||
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
@@ -176,17 +152,17 @@ class ConversationRenameApi(Resource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
args = conversation_rename_parser.parse_args()
|
||||
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@service_api_ns.expect(conversation_variables_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
|
||||
@service_api_ns.doc("list_conversation_variables")
|
||||
@service_api_ns.doc(description="List all variables for a conversation")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
@@ -211,11 +187,12 @@ class ConversationVariablesApi(Resource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
args = conversation_variables_parser.parse_args()
|
||||
query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, args["limit"], args["last_id"]
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@@ -223,7 +200,7 @@ class ConversationVariablesApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
|
||||
class ConversationVariableDetailApi(Resource):
|
||||
@service_api_ns.expect(conversation_variable_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_conversation_variable")
|
||||
@service_api_ns.doc(description="Update a conversation variable's value")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
|
||||
@@ -250,11 +227,11 @@ class ConversationVariableDetailApi(Resource):
|
||||
conversation_id = str(c_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
args = conversation_variable_update_parser.parse_args()
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, args["value"]
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
FileAccessDeniedError,
|
||||
@@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for file preview API
|
||||
file_preview_parser = reqparse.RequestParser().add_argument(
|
||||
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
|
||||
)
|
||||
class FilePreviewQuery(BaseModel):
|
||||
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, FilePreviewQuery)
|
||||
|
||||
|
||||
@service_api_ns.route("/files/<uuid:file_id>/preview")
|
||||
@@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
|
||||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
|
||||
@service_api_ns.expect(file_preview_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
|
||||
@service_api_ns.doc("preview_file")
|
||||
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
|
||||
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
|
||||
@@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
|
||||
file_id = str(file_id)
|
||||
|
||||
# Parse query parameters
|
||||
args = file_preview_parser.parse_args()
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
@@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
|
||||
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
|
||||
|
||||
# Build response with appropriate headers
|
||||
response = self._build_file_response(generator, upload_file, args["as_attachment"])
|
||||
response = self._build_file_response(generator, upload_file, args.as_attachment)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
@@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@@ -25,42 +29,26 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parsers for message APIs
|
||||
message_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of messages to return",
|
||||
)
|
||||
)
|
||||
|
||||
message_feedback_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
|
||||
.add_argument("content", type=str, location="json", help="Feedback content")
|
||||
)
|
||||
|
||||
feedback_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, default=1, location="args", help="Page number")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 101),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of feedbacks per page",
|
||||
)
|
||||
)
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Api | Namespace):
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Namespace):
|
||||
"""Build the message model for the API or Namespace."""
|
||||
# First build the nested models
|
||||
feedback_model = build_feedback_model(api_or_ns)
|
||||
@@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace):
|
||||
return api_or_ns.model("Message", message_fields)
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the message infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested message model first
|
||||
message_model = build_message_model(api_or_ns)
|
||||
@@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.expect(message_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
|
||||
@service_api_ns.doc("list_messages")
|
||||
@service_api_ns.doc(description="List messages in a conversation")
|
||||
@service_api_ns.doc(
|
||||
@@ -126,11 +114,13 @@ class MessageListApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = message_list_parser.parse_args()
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
conversation_id = str(query_args.conversation_id)
|
||||
first_id = str(query_args.first_id) if query_args.first_id else None
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, conversation_id, first_id, query_args.limit
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@@ -140,7 +130,7 @@ class MessageListApi(Resource):
|
||||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@service_api_ns.expect(message_feedback_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
|
||||
@service_api_ns.doc("create_message_feedback")
|
||||
@service_api_ns.doc(description="Submit feedback for a message")
|
||||
@service_api_ns.doc(params={"message_id": "Message ID"})
|
||||
@@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource):
|
||||
"""
|
||||
message_id = str(message_id)
|
||||
|
||||
args = message_feedback_parser.parse_args()
|
||||
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
rating=payload.rating,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
class AppGetFeedbacksApi(Resource):
|
||||
@service_api_ns.expect(feedback_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
|
||||
@service_api_ns.doc("get_app_feedbacks")
|
||||
@service_api_ns.doc(description="Get all feedbacks for the application")
|
||||
@service_api_ns.doc(
|
||||
@@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource):
|
||||
|
||||
Returns paginated list of all feedback submitted for messages in this app.
|
||||
"""
|
||||
args = feedback_list_parser.parse_args()
|
||||
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
|
||||
query_args = FeedbackListQuery.model_validate(request.args.to_dict())
|
||||
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
|
||||
return {"data": feedbacks}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define parsers for workflow APIs
|
||||
workflow_run_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
)
|
||||
|
||||
workflow_log_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
.add_argument("created_at__before", type=str, location="args")
|
||||
.add_argument("created_at__after", type=str, location="args")
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class WorkflowLogQuery(BaseModel):
|
||||
keyword: str | None = None
|
||||
status: Literal["succeeded", "failed", "stopped"] | None = None
|
||||
created_at__before: str | None = None
|
||||
created_at__after: str | None = None
|
||||
created_by_end_user_session_id: str | None = None
|
||||
created_by_account: str | None = None
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
@@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow")
|
||||
@service_api_ns.doc(description="Execute a workflow")
|
||||
@service_api_ns.doc(
|
||||
@@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/<string:workflow_id>/run")
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow_by_id")
|
||||
@service_api_ns.doc(description="Execute a specific workflow by ID")
|
||||
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
|
||||
@@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
@@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@service_api_ns.expect(workflow_log_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
|
||||
@service_api_ns.doc("get_workflow_logs")
|
||||
@service_api_ns.doc(description="Get workflow execution logs")
|
||||
@service_api_ns.doc(
|
||||
@@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
Returns paginated workflow execution logs with filtering options.
|
||||
"""
|
||||
args = workflow_log_parser.parse_args()
|
||||
args = WorkflowLogQuery.model_validate(request.args.to_dict())
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
|
||||
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword=args.keyword,
|
||||
status=args.status,
|
||||
created_at_before=args.created_at__before,
|
||||
created_at_after=args.created_at__after,
|
||||
status=status,
|
||||
created_at_before=created_at_before,
|
||||
created_at_after=created_at_after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
|
||||
Reference in New Issue
Block a user