refactor: port reqparse to BaseModel (#28993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from flask_restx.api import HTTPStatus
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
@@ -12,26 +14,24 @@ from fields.annotation_fields import annotation_fields, build_annotation_model
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
# Define parsers for annotation API
|
||||
annotation_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("question", required=True, type=str, location="json", help="Annotation question")
|
||||
.add_argument("answer", required=True, type=str, location="json", help="Annotation answer")
|
||||
)
|
||||
|
||||
annotation_reply_action_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"score_threshold", required=True, type=float, location="json", help="Score threshold for annotation matching"
|
||||
)
|
||||
.add_argument("embedding_provider_name", required=True, type=str, location="json", help="Embedding provider name")
|
||||
.add_argument("embedding_model_name", required=True, type=str, location="json", help="Embedding model name")
|
||||
)
|
||||
class AnnotationCreatePayload(BaseModel):
|
||||
question: str = Field(description="Annotation question")
|
||||
answer: str = Field(description="Annotation answer")
|
||||
|
||||
|
||||
class AnnotationReplyActionPayload(BaseModel):
|
||||
score_threshold: float = Field(description="Score threshold for annotation matching")
|
||||
embedding_provider_name: str = Field(description="Embedding provider name")
|
||||
embedding_model_name: str = Field(description="Embedding model name")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, AnnotationCreatePayload, AnnotationReplyActionPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotation-reply/<string:action>")
|
||||
class AnnotationReplyActionApi(Resource):
|
||||
@service_api_ns.expect(annotation_reply_action_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationReplyActionPayload.__name__])
|
||||
@service_api_ns.doc("annotation_reply_action")
|
||||
@service_api_ns.doc(description="Enable or disable annotation reply feature")
|
||||
@service_api_ns.doc(params={"action": "Action to perform: 'enable' or 'disable'"})
|
||||
@@ -44,7 +44,7 @@ class AnnotationReplyActionApi(Resource):
|
||||
@validate_app_token
|
||||
def post(self, app_model: App, action: Literal["enable", "disable"]):
|
||||
"""Enable or disable annotation reply feature."""
|
||||
args = annotation_reply_action_parser.parse_args()
|
||||
args = AnnotationReplyActionPayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
if action == "enable":
|
||||
result = AppAnnotationService.enable_app_annotation(args, app_model.id)
|
||||
elif action == "disable":
|
||||
@@ -126,7 +126,7 @@ class AnnotationListApi(Resource):
|
||||
"page": page,
|
||||
}
|
||||
|
||||
@service_api_ns.expect(annotation_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_annotation")
|
||||
@service_api_ns.doc(description="Create a new annotation")
|
||||
@service_api_ns.doc(
|
||||
@@ -139,14 +139,14 @@ class AnnotationListApi(Resource):
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns), code=HTTPStatus.CREATED)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new annotation."""
|
||||
args = annotation_create_parser.parse_args()
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.insert_app_annotation_directly(args, app_model.id)
|
||||
return annotation, 201
|
||||
|
||||
|
||||
@service_api_ns.route("/apps/annotations/<uuid:annotation_id>")
|
||||
class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.expect(annotation_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[AnnotationCreatePayload.__name__])
|
||||
@service_api_ns.doc("update_annotation")
|
||||
@service_api_ns.doc(description="Update an existing annotation")
|
||||
@service_api_ns.doc(params={"annotation_id": "Annotation ID"})
|
||||
@@ -163,7 +163,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
@service_api_ns.marshal_with(build_annotation_model(service_api_ns))
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
"""Update an existing annotation."""
|
||||
args = annotation_create_parser.parse_args()
|
||||
args = AnnotationCreatePayload.model_validate(service_api_ns.payload or {}).model_dump()
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id)
|
||||
return annotation
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -84,19 +86,19 @@ class AudioApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
# Define parser for text-to-audio API
|
||||
text_to_audio_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("message_id", type=str, required=False, location="json", help="Message ID")
|
||||
.add_argument("voice", type=str, location="json", help="Voice to use for TTS")
|
||||
.add_argument("text", type=str, location="json", help="Text to convert to audio")
|
||||
.add_argument("streaming", type=bool, location="json", help="Enable streaming response")
|
||||
)
|
||||
class TextToAudioPayload(BaseModel):
|
||||
message_id: str | None = Field(default=None, description="Message ID")
|
||||
voice: str | None = Field(default=None, description="Voice to use for TTS")
|
||||
text: str | None = Field(default=None, description="Text to convert to audio")
|
||||
streaming: bool | None = Field(default=None, description="Enable streaming response")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, TextToAudioPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/text-to-audio")
|
||||
class TextApi(Resource):
|
||||
@service_api_ns.expect(text_to_audio_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TextToAudioPayload.__name__])
|
||||
@service_api_ns.doc("text_to_audio")
|
||||
@service_api_ns.doc(description="Convert text to audio using text-to-speech")
|
||||
@service_api_ns.doc(
|
||||
@@ -114,11 +116,11 @@ class TextApi(Resource):
|
||||
Converts the provided text to audio using the specified voice.
|
||||
"""
|
||||
try:
|
||||
args = text_to_audio_parser.parse_args()
|
||||
payload = TextToAudioPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
message_id = payload.message_id
|
||||
text = payload.text
|
||||
voice = payload.voice
|
||||
response = AudioService.transcript_tts(
|
||||
app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
|
||||
)
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
@@ -26,7 +30,6 @@ from core.errors.error import (
|
||||
from core.helper.trace_id_helper import get_external_trace_id
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.app_task_service import AppTaskService
|
||||
@@ -36,40 +39,31 @@ from services.errors.llm import InvokeRateLimitError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for completion API
|
||||
completion_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for completion")
|
||||
.add_argument("query", type=str, location="json", default="", help="The query string")
|
||||
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||
)
|
||||
class CompletionRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str = Field(default="")
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = Field(default="dev")
|
||||
|
||||
# Define parser for chat API
|
||||
chat_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, location="json", help="Input parameters for chat")
|
||||
.add_argument("query", type=str, required=True, location="json", help="The chat query")
|
||||
.add_argument("files", type=list, required=False, location="json", help="List of file attachments")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json", help="Response mode")
|
||||
.add_argument("conversation_id", type=uuid_value, location="json", help="Existing conversation ID")
|
||||
.add_argument("retriever_from", type=str, required=False, default="dev", location="json", help="Retriever source")
|
||||
.add_argument(
|
||||
"auto_generate_name",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=True,
|
||||
location="json",
|
||||
help="Auto generate conversation name",
|
||||
)
|
||||
.add_argument("workflow_id", type=str, required=False, location="json", help="Workflow ID for advanced chat")
|
||||
)
|
||||
|
||||
class ChatRequestPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
query: str
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
conversation_id: UUID | None = None
|
||||
retriever_from: str = Field(default="dev")
|
||||
auto_generate_name: bool = Field(default=True, description="Auto generate conversation name")
|
||||
workflow_id: str | None = Field(default=None, description="Workflow ID for advanced chat")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, CompletionRequestPayload, ChatRequestPayload)
|
||||
|
||||
|
||||
@service_api_ns.route("/completion-messages")
|
||||
class CompletionApi(Resource):
|
||||
@service_api_ns.expect(completion_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[CompletionRequestPayload.__name__])
|
||||
@service_api_ns.doc("create_completion")
|
||||
@service_api_ns.doc(description="Create a completion for the given prompt")
|
||||
@service_api_ns.doc(
|
||||
@@ -91,12 +85,13 @@ class CompletionApi(Resource):
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise AppUnavailableError()
|
||||
|
||||
args = completion_parser.parse_args()
|
||||
payload = CompletionRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -162,7 +157,7 @@ class CompletionStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/chat-messages")
|
||||
class ChatApi(Resource):
|
||||
@service_api_ns.expect(chat_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChatRequestPayload.__name__])
|
||||
@service_api_ns.doc("create_chat_message")
|
||||
@service_api_ns.doc(description="Send a message in a chat conversation")
|
||||
@service_api_ns.doc(
|
||||
@@ -186,13 +181,14 @@ class ChatApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = chat_parser.parse_args()
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from flask_restx import Resource, reqparse
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from flask_restx._http import HTTPStatus
|
||||
from flask_restx.inputs import int_range
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
@@ -19,74 +24,44 @@ from fields.conversation_variable_fields import (
|
||||
build_conversation_variable_infinite_scroll_pagination_model,
|
||||
build_conversation_variable_model,
|
||||
)
|
||||
from libs.helper import uuid_value
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
# Define parsers for conversation APIs
|
||||
conversation_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args", help="Last conversation ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of conversations to return",
|
||||
)
|
||||
.add_argument(
|
||||
"sort_by",
|
||||
type=str,
|
||||
choices=["created_at", "-created_at", "updated_at", "-updated_at"],
|
||||
required=False,
|
||||
default="-updated_at",
|
||||
location="args",
|
||||
help="Sort order for conversations",
|
||||
)
|
||||
)
|
||||
|
||||
conversation_rename_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=False, location="json", help="New conversation name")
|
||||
.add_argument(
|
||||
"auto_generate",
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False,
|
||||
location="json",
|
||||
help="Auto-generate conversation name",
|
||||
class ConversationListQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last conversation ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of conversations to return")
|
||||
sort_by: Literal["created_at", "-created_at", "updated_at", "-updated_at"] = Field(
|
||||
default="-updated_at", description="Sort order for conversations"
|
||||
)
|
||||
)
|
||||
|
||||
conversation_variables_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args", help="Last variable ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of variables to return",
|
||||
)
|
||||
)
|
||||
|
||||
conversation_variable_update_parser = reqparse.RequestParser().add_argument(
|
||||
# using lambda is for passing the already-typed value without modification
|
||||
# if no lambda, it will be converted to string
|
||||
# the string cannot be converted using json.loads
|
||||
"value",
|
||||
required=True,
|
||||
location="json",
|
||||
type=lambda x: x,
|
||||
help="New value for the conversation variable",
|
||||
class ConversationRenamePayload(BaseModel):
|
||||
name: str = Field(description="New conversation name")
|
||||
auto_generate: bool = Field(default=False, description="Auto-generate conversation name")
|
||||
|
||||
|
||||
class ConversationVariablesQuery(BaseModel):
|
||||
last_id: UUID | None = Field(default=None, description="Last variable ID for pagination")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of variables to return")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
ConversationListQuery,
|
||||
ConversationRenamePayload,
|
||||
ConversationVariablesQuery,
|
||||
ConversationVariableUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations")
|
||||
class ConversationApi(Resource):
|
||||
@service_api_ns.expect(conversation_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationListQuery.__name__])
|
||||
@service_api_ns.doc("list_conversations")
|
||||
@service_api_ns.doc(description="List all conversations for the current user")
|
||||
@service_api_ns.doc(
|
||||
@@ -107,7 +82,8 @@ class ConversationApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = conversation_list_parser.parse_args()
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
@@ -115,10 +91,10 @@ class ConversationApi(Resource):
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
user=end_user,
|
||||
last_id=args["last_id"],
|
||||
limit=args["limit"],
|
||||
last_id=last_id,
|
||||
limit=query_args.limit,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
sort_by=args["sort_by"],
|
||||
sort_by=query_args.sort_by,
|
||||
)
|
||||
except services.errors.conversation.LastConversationNotExistsError:
|
||||
raise NotFound("Last Conversation Not Exists.")
|
||||
@@ -155,7 +131,7 @@ class ConversationDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/name")
|
||||
class ConversationRenameApi(Resource):
|
||||
@service_api_ns.expect(conversation_rename_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationRenamePayload.__name__])
|
||||
@service_api_ns.doc("rename_conversation")
|
||||
@service_api_ns.doc(description="Rename a conversation or auto-generate a name")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
@@ -176,17 +152,17 @@ class ConversationRenameApi(Resource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
args = conversation_rename_parser.parse_args()
|
||||
payload = ConversationRenamePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"])
|
||||
return ConversationService.rename(app_model, conversation_id, end_user, payload.name, payload.auto_generate)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables")
|
||||
class ConversationVariablesApi(Resource):
|
||||
@service_api_ns.expect(conversation_variables_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariablesQuery.__name__])
|
||||
@service_api_ns.doc("list_conversation_variables")
|
||||
@service_api_ns.doc(description="List all variables for a conversation")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID"})
|
||||
@@ -211,11 +187,12 @@ class ConversationVariablesApi(Resource):
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
args = conversation_variables_parser.parse_args()
|
||||
query_args = ConversationVariablesQuery.model_validate(request.args.to_dict())
|
||||
last_id = str(query_args.last_id) if query_args.last_id else None
|
||||
|
||||
try:
|
||||
return ConversationService.get_conversational_variable(
|
||||
app_model, conversation_id, end_user, args["limit"], args["last_id"]
|
||||
app_model, conversation_id, end_user, query_args.limit, last_id
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@@ -223,7 +200,7 @@ class ConversationVariablesApi(Resource):
|
||||
|
||||
@service_api_ns.route("/conversations/<uuid:c_id>/variables/<uuid:variable_id>")
|
||||
class ConversationVariableDetailApi(Resource):
|
||||
@service_api_ns.expect(conversation_variable_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_conversation_variable")
|
||||
@service_api_ns.doc(description="Update a conversation variable's value")
|
||||
@service_api_ns.doc(params={"c_id": "Conversation ID", "variable_id": "Variable ID"})
|
||||
@@ -250,11 +227,11 @@ class ConversationVariableDetailApi(Resource):
|
||||
conversation_id = str(c_id)
|
||||
variable_id = str(variable_id)
|
||||
|
||||
args = conversation_variable_update_parser.parse_args()
|
||||
payload = ConversationVariableUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
return ConversationService.update_conversation_variable(
|
||||
app_model, conversation_id, variable_id, end_user, args["value"]
|
||||
app_model, conversation_id, variable_id, end_user, payload.value
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
FileAccessDeniedError,
|
||||
@@ -17,10 +19,11 @@ from models.model import App, EndUser, Message, MessageFile, UploadFile
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parser for file preview API
|
||||
file_preview_parser = reqparse.RequestParser().add_argument(
|
||||
"as_attachment", type=bool, required=False, default=False, location="args", help="Download as attachment"
|
||||
)
|
||||
class FilePreviewQuery(BaseModel):
|
||||
as_attachment: bool = Field(default=False, description="Download as attachment")
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, FilePreviewQuery)
|
||||
|
||||
|
||||
@service_api_ns.route("/files/<uuid:file_id>/preview")
|
||||
@@ -32,7 +35,7 @@ class FilePreviewApi(Resource):
|
||||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
|
||||
@service_api_ns.expect(file_preview_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[FilePreviewQuery.__name__])
|
||||
@service_api_ns.doc("preview_file")
|
||||
@service_api_ns.doc(description="Preview or download a file uploaded via Service API")
|
||||
@service_api_ns.doc(params={"file_id": "UUID of the file to preview"})
|
||||
@@ -55,7 +58,7 @@ class FilePreviewApi(Resource):
|
||||
file_id = str(file_id)
|
||||
|
||||
# Parse query parameters
|
||||
args = file_preview_parser.parse_args()
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
@@ -67,7 +70,7 @@ class FilePreviewApi(Resource):
|
||||
raise FileNotFoundError(f"Failed to load file content: {str(e)}")
|
||||
|
||||
# Build response with appropriate headers
|
||||
response = self._build_file_response(generator, upload_file, args["as_attachment"])
|
||||
response = self._build_file_response(generator, upload_file, args.as_attachment)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask import request
|
||||
from flask_restx import Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import NotChatAppError
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
@@ -13,7 +17,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import build_message_file_model
|
||||
from fields.message_fields import build_agent_thought_model, build_feedback_model
|
||||
from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@@ -25,42 +29,26 @@ from services.message_service import MessageService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Define parsers for message APIs
|
||||
message_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("conversation_id", required=True, type=uuid_value, location="args", help="Conversation ID")
|
||||
.add_argument("first_id", type=uuid_value, location="args", help="First message ID for pagination")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 100),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of messages to return",
|
||||
)
|
||||
)
|
||||
|
||||
message_feedback_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("rating", type=str, choices=["like", "dislike", None], location="json", help="Feedback rating")
|
||||
.add_argument("content", type=str, location="json", help="Feedback content")
|
||||
)
|
||||
|
||||
feedback_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=int, default=1, location="args", help="Page number")
|
||||
.add_argument(
|
||||
"limit",
|
||||
type=int_range(1, 101),
|
||||
required=False,
|
||||
default=20,
|
||||
location="args",
|
||||
help="Number of feedbacks per page",
|
||||
)
|
||||
)
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Number of messages to return")
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Api | Namespace):
|
||||
class MessageFeedbackPayload(BaseModel):
|
||||
rating: Literal["like", "dislike"] | None = Field(default=None, description="Feedback rating")
|
||||
content: str | None = Field(default=None, description="Feedback content")
|
||||
|
||||
|
||||
class FeedbackListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, description="Page number")
|
||||
limit: int = Field(default=20, ge=1, le=101, description="Number of feedbacks per page")
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, MessageListQuery, MessageFeedbackPayload, FeedbackListQuery)
|
||||
|
||||
|
||||
def build_message_model(api_or_ns: Namespace):
|
||||
"""Build the message model for the API or Namespace."""
|
||||
# First build the nested models
|
||||
feedback_model = build_feedback_model(api_or_ns)
|
||||
@@ -90,7 +78,7 @@ def build_message_model(api_or_ns: Api | Namespace):
|
||||
return api_or_ns.model("Message", message_fields)
|
||||
|
||||
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
def build_message_infinite_scroll_pagination_model(api_or_ns: Namespace):
|
||||
"""Build the message infinite scroll pagination model for the API or Namespace."""
|
||||
# Build the nested message model first
|
||||
message_model = build_message_model(api_or_ns)
|
||||
@@ -105,7 +93,7 @@ def build_message_infinite_scroll_pagination_model(api_or_ns: Api | Namespace):
|
||||
|
||||
@service_api_ns.route("/messages")
|
||||
class MessageListApi(Resource):
|
||||
@service_api_ns.expect(message_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MessageListQuery.__name__])
|
||||
@service_api_ns.doc("list_messages")
|
||||
@service_api_ns.doc(description="List messages in a conversation")
|
||||
@service_api_ns.doc(
|
||||
@@ -126,11 +114,13 @@ class MessageListApi(Resource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
args = message_list_parser.parse_args()
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
conversation_id = str(query_args.conversation_id)
|
||||
first_id = str(query_args.first_id) if query_args.first_id else None
|
||||
|
||||
try:
|
||||
return MessageService.pagination_by_first_id(
|
||||
app_model, end_user, args["conversation_id"], args["first_id"], args["limit"]
|
||||
app_model, end_user, conversation_id, first_id, query_args.limit
|
||||
)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@@ -140,7 +130,7 @@ class MessageListApi(Resource):
|
||||
|
||||
@service_api_ns.route("/messages/<uuid:message_id>/feedbacks")
|
||||
class MessageFeedbackApi(Resource):
|
||||
@service_api_ns.expect(message_feedback_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MessageFeedbackPayload.__name__])
|
||||
@service_api_ns.doc("create_message_feedback")
|
||||
@service_api_ns.doc(description="Submit feedback for a message")
|
||||
@service_api_ns.doc(params={"message_id": "Message ID"})
|
||||
@@ -159,15 +149,15 @@ class MessageFeedbackApi(Resource):
|
||||
"""
|
||||
message_id = str(message_id)
|
||||
|
||||
args = message_feedback_parser.parse_args()
|
||||
payload = MessageFeedbackPayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
MessageService.create_feedback(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=args.get("rating"),
|
||||
content=args.get("content"),
|
||||
rating=payload.rating,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@@ -177,7 +167,7 @@ class MessageFeedbackApi(Resource):
|
||||
|
||||
@service_api_ns.route("/app/feedbacks")
|
||||
class AppGetFeedbacksApi(Resource):
|
||||
@service_api_ns.expect(feedback_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[FeedbackListQuery.__name__])
|
||||
@service_api_ns.doc("get_app_feedbacks")
|
||||
@service_api_ns.doc(description="Get all feedbacks for the application")
|
||||
@service_api_ns.doc(
|
||||
@@ -192,8 +182,8 @@ class AppGetFeedbacksApi(Resource):
|
||||
|
||||
Returns paginated list of all feedback submitted for messages in this app.
|
||||
"""
|
||||
args = feedback_list_parser.parse_args()
|
||||
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
|
||||
query_args = FeedbackListQuery.model_validate(request.args.to_dict())
|
||||
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=query_args.page, limit=query_args.limit)
|
||||
return {"data": feedbacks}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from flask import request
|
||||
from flask_restx import Api, Namespace, Resource, fields, reqparse
|
||||
from flask_restx.inputs import int_range
|
||||
from flask_restx import Api, Namespace, Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import (
|
||||
CompletionRequestError,
|
||||
@@ -41,37 +43,25 @@ from services.workflow_app_service import WorkflowAppService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define parsers for workflow APIs
|
||||
workflow_run_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("files", type=list, required=False, location="json")
|
||||
.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
)
|
||||
|
||||
workflow_log_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("keyword", type=str, location="args")
|
||||
.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
|
||||
.add_argument("created_at__before", type=str, location="args")
|
||||
.add_argument("created_at__after", type=str, location="args")
|
||||
.add_argument(
|
||||
"created_by_end_user_session_id",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument(
|
||||
"created_by_account",
|
||||
type=str,
|
||||
location="args",
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
.add_argument("page", type=int_range(1, 99999), default=1, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), default=20, location="args")
|
||||
)
|
||||
class WorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
files: list[dict[str, Any]] | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
|
||||
|
||||
class WorkflowLogQuery(BaseModel):
|
||||
keyword: str | None = None
|
||||
status: Literal["succeeded", "failed", "stopped"] | None = None
|
||||
created_at__before: str | None = None
|
||||
created_at__after: str | None = None
|
||||
created_by_end_user_session_id: str | None = None
|
||||
created_by_account: str | None = None
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
@@ -130,7 +120,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/run")
|
||||
class WorkflowRunApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow")
|
||||
@service_api_ns.doc(description="Execute a workflow")
|
||||
@service_api_ns.doc(
|
||||
@@ -154,11 +144,12 @@ class WorkflowRunApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@@ -185,7 +176,7 @@ class WorkflowRunApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/<string:workflow_id>/run")
|
||||
class WorkflowRunByIdApi(Resource):
|
||||
@service_api_ns.expect(workflow_run_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowRunPayload.__name__])
|
||||
@service_api_ns.doc("run_workflow_by_id")
|
||||
@service_api_ns.doc(description="Execute a specific workflow by ID")
|
||||
@service_api_ns.doc(params={"workflow_id": "Workflow ID to execute"})
|
||||
@@ -209,7 +200,8 @@ class WorkflowRunByIdApi(Resource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
args = workflow_run_parser.parse_args()
|
||||
payload = WorkflowRunPayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
# Add workflow_id to args for AppGenerateService
|
||||
args["workflow_id"] = workflow_id
|
||||
@@ -217,7 +209,7 @@ class WorkflowRunByIdApi(Resource):
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
if external_trace_id:
|
||||
args["external_trace_id"] = external_trace_id
|
||||
streaming = args.get("response_mode") == "streaming"
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = AppGenerateService.generate(
|
||||
@@ -279,7 +271,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
|
||||
@service_api_ns.route("/workflows/logs")
|
||||
class WorkflowAppLogApi(Resource):
|
||||
@service_api_ns.expect(workflow_log_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[WorkflowLogQuery.__name__])
|
||||
@service_api_ns.doc("get_workflow_logs")
|
||||
@service_api_ns.doc(description="Get workflow execution logs")
|
||||
@service_api_ns.doc(
|
||||
@@ -295,14 +287,11 @@ class WorkflowAppLogApi(Resource):
|
||||
|
||||
Returns paginated workflow execution logs with filtering options.
|
||||
"""
|
||||
args = workflow_log_parser.parse_args()
|
||||
args = WorkflowLogQuery.model_validate(request.args.to_dict())
|
||||
|
||||
args.status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
if args.created_at__before:
|
||||
args.created_at__before = isoparse(args.created_at__before)
|
||||
|
||||
if args.created_at__after:
|
||||
args.created_at__after = isoparse(args.created_at__after)
|
||||
status = WorkflowExecutionStatus(args.status) if args.status else None
|
||||
created_at_before = isoparse(args.created_at__before) if args.created_at__before else None
|
||||
created_at_after = isoparse(args.created_at__after) if args.created_at__after else None
|
||||
|
||||
# get paginate workflow app logs
|
||||
workflow_app_service = WorkflowAppService()
|
||||
@@ -311,9 +300,9 @@ class WorkflowAppLogApi(Resource):
|
||||
session=session,
|
||||
app_model=app_model,
|
||||
keyword=args.keyword,
|
||||
status=args.status,
|
||||
created_at_before=args.created_at__before,
|
||||
created_at_after=args.created_at__after,
|
||||
status=status,
|
||||
created_at_before=created_at_before,
|
||||
created_at_after=created_at_after,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
created_by_end_user_session_id=args.created_by_end_user_session_id,
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.wraps import edit_permission_required
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
|
||||
@@ -18,173 +20,83 @@ from core.provider_manager import ProviderManager
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from fields.tag_fields import build_dataset_tag_fields
|
||||
from libs.login import current_user
|
||||
from libs.validators import validate_description_length
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
|
||||
from services.tag_service import TagService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
class DatasetCreatePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
|
||||
indexing_technique: Literal["high_quality", "economy"] | None = None
|
||||
permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
|
||||
external_knowledge_api_id: str | None = None
|
||||
provider: str = "vendor"
|
||||
external_knowledge_id: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
|
||||
|
||||
# Define parsers for dataset operations
|
||||
dataset_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
help="Invalid permission.",
|
||||
required=False,
|
||||
nullable=False,
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="_validate_name",
|
||||
)
|
||||
.add_argument(
|
||||
"provider",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="vendor",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
class DatasetUpdatePayload(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=40)
|
||||
description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
|
||||
indexing_technique: Literal["high_quality", "economy"] | None = None
|
||||
permission: DatasetPermissionEnum | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
partial_member_list: list[str] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
external_knowledge_api_id: str | None = None
|
||||
|
||||
dataset_update_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument("description", location="json", store_missing=False, type=validate_description_length)
|
||||
.add_argument(
|
||||
"indexing_technique",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=Dataset.INDEXING_TECHNIQUE_LIST,
|
||||
nullable=True,
|
||||
help="Invalid indexing technique.",
|
||||
)
|
||||
.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
location="json",
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
help="Invalid permission.",
|
||||
)
|
||||
.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
|
||||
.add_argument("embedding_model_provider", type=str, location="json", help="Invalid embedding model provider.")
|
||||
.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
|
||||
.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")
|
||||
.add_argument(
|
||||
"external_retrieval_model",
|
||||
type=dict,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external retrieval model.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge id.",
|
||||
)
|
||||
.add_argument(
|
||||
"external_knowledge_api_id",
|
||||
type=str,
|
||||
required=False,
|
||||
nullable=True,
|
||||
location="json",
|
||||
help="Invalid external knowledge api id.",
|
||||
)
|
||||
)
|
||||
|
||||
tag_create_parser = reqparse.RequestParser().add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=lambda x: x
|
||||
if x and 1 <= len(x) <= 50
|
||||
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
|
||||
)
|
||||
class TagNamePayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=50)
|
||||
|
||||
tag_update_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 50 characters.",
|
||||
type=lambda x: x
|
||||
if x and 1 <= len(x) <= 50
|
||||
else (_ for _ in ()).throw(ValueError("Name must be between 1 to 50 characters.")),
|
||||
)
|
||||
.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str)
|
||||
)
|
||||
|
||||
tag_delete_parser = reqparse.RequestParser().add_argument(
|
||||
"tag_id", nullable=False, required=True, help="Id of a tag.", type=str
|
||||
)
|
||||
class TagCreatePayload(TagNamePayload):
|
||||
pass
|
||||
|
||||
tag_binding_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required.")
|
||||
.add_argument(
|
||||
"target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required."
|
||||
)
|
||||
)
|
||||
|
||||
tag_unbinding_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.")
|
||||
.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.")
|
||||
class TagUpdatePayload(TagNamePayload):
|
||||
tag_id: str
|
||||
|
||||
|
||||
class TagDeletePayload(BaseModel):
|
||||
tag_id: str
|
||||
|
||||
|
||||
class TagBindingPayload(BaseModel):
|
||||
tag_ids: list[str]
|
||||
target_id: str
|
||||
|
||||
@field_validator("tag_ids")
|
||||
@classmethod
|
||||
def validate_tag_ids(cls, value: list[str]) -> list[str]:
|
||||
if not value:
|
||||
raise ValueError("Tag IDs is required.")
|
||||
return value
|
||||
|
||||
|
||||
class TagUnbindingPayload(BaseModel):
|
||||
tag_id: str
|
||||
target_id: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
DatasetCreatePayload,
|
||||
DatasetUpdatePayload,
|
||||
TagCreatePayload,
|
||||
TagUpdatePayload,
|
||||
TagDeletePayload,
|
||||
TagBindingPayload,
|
||||
TagUnbindingPayload,
|
||||
)
|
||||
|
||||
|
||||
@@ -239,7 +151,7 @@ class DatasetListApi(DatasetApiResource):
|
||||
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(dataset_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset")
|
||||
@service_api_ns.doc(description="Create a new dataset")
|
||||
@service_api_ns.doc(
|
||||
@@ -252,42 +164,41 @@ class DatasetListApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id):
|
||||
"""Resource for creating datasets."""
|
||||
args = dataset_create_parser.parse_args()
|
||||
payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
try:
|
||||
assert isinstance(current_user, Account)
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args["name"],
|
||||
description=args["description"],
|
||||
indexing_technique=args["indexing_technique"],
|
||||
name=payload.name,
|
||||
description=payload.description,
|
||||
indexing_technique=payload.indexing_technique,
|
||||
account=current_user,
|
||||
permission=args["permission"],
|
||||
provider=args["provider"],
|
||||
external_knowledge_api_id=args["external_knowledge_api_id"],
|
||||
external_knowledge_id=args["external_knowledge_id"],
|
||||
embedding_model_provider=args["embedding_model_provider"],
|
||||
embedding_model_name=args["embedding_model"],
|
||||
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
|
||||
if args["retrieval_model"] is not None
|
||||
else None,
|
||||
permission=str(payload.permission) if payload.permission else None,
|
||||
provider=payload.provider,
|
||||
external_knowledge_api_id=payload.external_knowledge_api_id,
|
||||
external_knowledge_id=payload.external_knowledge_id,
|
||||
embedding_model_provider=payload.embedding_model_provider,
|
||||
embedding_model_name=payload.embedding_model,
|
||||
retrieval_model=payload.retrieval_model,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
@@ -353,7 +264,7 @@ class DatasetApi(DatasetApiResource):
|
||||
|
||||
return data, 200
|
||||
|
||||
@service_api_ns.expect(dataset_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset")
|
||||
@service_api_ns.doc(description="Update an existing dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@@ -372,36 +283,45 @@ class DatasetApi(DatasetApiResource):
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
args = dataset_update_parser.parse_args()
|
||||
data = request.get_json()
|
||||
payload_dict = service_api_ns.payload or {}
|
||||
payload = DatasetUpdatePayload.model_validate(payload_dict)
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
if payload.permission is not None:
|
||||
update_data["permission"] = str(payload.permission)
|
||||
if payload.retrieval_model is not None:
|
||||
update_data["retrieval_model"] = payload.retrieval_model.model_dump()
|
||||
|
||||
# check embedding model setting
|
||||
embedding_model_provider = data.get("embedding_model_provider")
|
||||
embedding_model = data.get("embedding_model")
|
||||
if data.get("indexing_technique") == "high_quality" or embedding_model_provider:
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if payload.indexing_technique == "high_quality" or embedding_model_provider:
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(
|
||||
dataset.tenant_id, embedding_model_provider, embedding_model
|
||||
)
|
||||
|
||||
retrieval_model = data.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
dataset.tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
DatasetPermissionService.check_permission(
|
||||
current_user, dataset, data.get("permission"), data.get("partial_member_list")
|
||||
current_user,
|
||||
dataset,
|
||||
str(payload.permission) if payload.permission else None,
|
||||
payload.partial_member_list,
|
||||
)
|
||||
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
|
||||
dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)
|
||||
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
@@ -410,15 +330,10 @@ class DatasetApi(DatasetApiResource):
|
||||
assert isinstance(current_user, Account)
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
if data.get("partial_member_list") and data.get("permission") == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
tenant_id, dataset_id_str, data.get("partial_member_list")
|
||||
)
|
||||
if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
|
||||
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
|
||||
# clear partial member list when permission is only_me or all_team_members
|
||||
elif (
|
||||
data.get("permission") == DatasetPermissionEnum.ONLY_ME
|
||||
or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
|
||||
):
|
||||
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
|
||||
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
|
||||
@@ -556,7 +471,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
return tags, 200
|
||||
|
||||
@service_api_ns.expect(tag_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_dataset_tag")
|
||||
@service_api_ns.doc(description="Add a knowledge type tag")
|
||||
@service_api_ns.doc(
|
||||
@@ -574,14 +489,13 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_create_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag = TagService.save_tags(args)
|
||||
payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})
|
||||
|
||||
response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(tag_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset_tag")
|
||||
@service_api_ns.doc(description="Update a knowledge type tag")
|
||||
@service_api_ns.doc(
|
||||
@@ -598,10 +512,10 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_update_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
tag_id = args["tag_id"]
|
||||
tag = TagService.update_tags(args, tag_id)
|
||||
payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
params = {"name": payload.name, "type": "knowledge"}
|
||||
tag_id = payload.tag_id
|
||||
tag = TagService.update_tags(params, tag_id)
|
||||
|
||||
binding_count = TagService.get_tag_binding_count(tag_id)
|
||||
|
||||
@@ -609,7 +523,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
|
||||
return response, 200
|
||||
|
||||
@service_api_ns.expect(tag_delete_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
|
||||
@service_api_ns.doc("delete_dataset_tag")
|
||||
@service_api_ns.doc(description="Delete a knowledge type tag")
|
||||
@service_api_ns.doc(
|
||||
@@ -623,15 +537,15 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
@edit_permission_required
|
||||
def delete(self, _, dataset_id):
|
||||
"""Delete a knowledge type tag."""
|
||||
args = tag_delete_parser.parse_args()
|
||||
TagService.delete_tag(args["tag_id"])
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
|
||||
class DatasetTagBindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(tag_binding_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
|
||||
@service_api_ns.doc("bind_dataset_tags")
|
||||
@service_api_ns.doc(description="Bind tags to a dataset")
|
||||
@service_api_ns.doc(
|
||||
@@ -648,16 +562,15 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_binding_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.save_tag_binding(args)
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
@service_api_ns.expect(tag_unbinding_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
|
||||
@service_api_ns.doc("unbind_dataset_tag")
|
||||
@service_api_ns.doc(description="Unbind a tag from a dataset")
|
||||
@service_api_ns.doc(
|
||||
@@ -674,9 +587,8 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
args = tag_unbinding_parser.parse_args()
|
||||
args["type"] = "knowledge"
|
||||
TagService.delete_tag_binding(args)
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return 204
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from typing import Self
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from pydantic import BaseModel, model_validator
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlalchemy import desc, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
@@ -37,22 +37,19 @@ from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
|
||||
# Define parsers for document operations
|
||||
document_text_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("text", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument("process_rule", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
|
||||
.add_argument("doc_language", type=str, default="English", required=False, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json"
|
||||
)
|
||||
.add_argument("retrieval_model", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model", type=str, required=False, nullable=True, location="json")
|
||||
.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
class DocumentTextCreatePayload(BaseModel):
|
||||
name: str
|
||||
text: str
|
||||
process_rule: ProcessRule | None = None
|
||||
original_document_id: str | None = None
|
||||
doc_form: str = Field(default="text_model")
|
||||
doc_language: str = Field(default="English")
|
||||
indexing_technique: str | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
@@ -72,7 +69,7 @@ class DocumentTextUpdate(BaseModel):
|
||||
return self
|
||||
|
||||
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
|
||||
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
|
||||
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
|
||||
|
||||
|
||||
@@ -83,7 +80,7 @@ for m in [ProcessRule, RetrievalModel, DocumentTextUpdate]:
|
||||
class DocumentAddByTextApi(DatasetApiResource):
|
||||
"""Resource for documents."""
|
||||
|
||||
@service_api_ns.expect(document_text_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_document_by_text")
|
||||
@service_api_ns.doc(description="Create a new document by providing text content")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@@ -99,7 +96,8 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create document by text."""
|
||||
args = document_text_create_parser.parse_args()
|
||||
payload = DocumentTextCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
@@ -111,33 +109,29 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
if not dataset.indexing_technique and not args["indexing_technique"]:
|
||||
raise ValueError("indexing_technique is required.")
|
||||
|
||||
text = args.get("text")
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both 'text' and 'name' must be non-null values.")
|
||||
|
||||
embedding_model_provider = args.get("embedding_model_provider")
|
||||
embedding_model = args.get("embedding_model")
|
||||
embedding_model_provider = payload.embedding_model_provider
|
||||
embedding_model = payload.embedding_model
|
||||
if embedding_model_provider and embedding_model:
|
||||
DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
if not current_user:
|
||||
raise ValueError("current_user is required")
|
||||
|
||||
upload_file = FileService(db.engine).upload_text(
|
||||
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
|
||||
text=payload.text, text_name=payload.name, user_id=current_user.id, tenant_id=tenant_id
|
||||
)
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
@@ -174,7 +168,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
"""Resource for update documents."""
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__], validate=True)
|
||||
@service_api_ns.expect(service_api_ns.models[DocumentTextUpdate.__name__])
|
||||
@service_api_ns.doc("update_document_by_text")
|
||||
@service_api_ns.doc(description="Update an existing document by providing text content")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@@ -189,22 +183,23 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""Update document by text."""
|
||||
args = DocumentTextUpdate.model_validate(service_api_ns.payload).model_dump(exclude_unset=True)
|
||||
payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {})
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first()
|
||||
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset does not exist.")
|
||||
|
||||
retrieval_model = args.get("retrieval_model")
|
||||
retrieval_model = payload.retrieval_model
|
||||
if (
|
||||
retrieval_model
|
||||
and retrieval_model.get("reranking_model")
|
||||
and retrieval_model.get("reranking_model").get("reranking_provider_name")
|
||||
and retrieval_model.reranking_model
|
||||
and retrieval_model.reranking_model.reranking_provider_name
|
||||
and retrieval_model.reranking_model.reranking_model_name
|
||||
):
|
||||
DatasetService.check_reranking_model_setting(
|
||||
tenant_id,
|
||||
retrieval_model.get("reranking_model").get("reranking_provider_name"),
|
||||
retrieval_model.get("reranking_model").get("reranking_model_name"),
|
||||
retrieval_model.reranking_model.reranking_provider_name,
|
||||
retrieval_model.reranking_model.reranking_model_name,
|
||||
)
|
||||
|
||||
# indexing_technique is already set in dataset since this is an update
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from typing import Literal
|
||||
|
||||
from flask_login import current_user
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.common.schema import register_schema_model, register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
@@ -14,25 +16,18 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
# Define parsers for metadata APIs
|
||||
metadata_create_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("type", type=str, required=True, nullable=False, location="json", help="Metadata type")
|
||||
.add_argument("name", type=str, required=True, nullable=False, location="json", help="Metadata name")
|
||||
)
|
||||
|
||||
metadata_update_parser = reqparse.RequestParser().add_argument(
|
||||
"name", type=str, required=True, nullable=False, location="json", help="New metadata name"
|
||||
)
|
||||
class MetadataUpdatePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
document_metadata_parser = reqparse.RequestParser().add_argument(
|
||||
"operation_data", type=list, required=True, nullable=False, location="json", help="Metadata operation data"
|
||||
)
|
||||
|
||||
register_schema_model(service_api_ns, MetadataUpdatePayload)
|
||||
register_schema_models(service_api_ns, MetadataArgs, MetadataOperationData)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata")
|
||||
class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@service_api_ns.expect(metadata_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataArgs.__name__])
|
||||
@service_api_ns.doc("create_dataset_metadata")
|
||||
@service_api_ns.doc(description="Create metadata for a dataset")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@@ -46,8 +41,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id, dataset_id):
|
||||
"""Create metadata for a dataset."""
|
||||
args = metadata_create_parser.parse_args()
|
||||
metadata_args = MetadataArgs.model_validate(args)
|
||||
metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@@ -79,7 +73,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/<uuid:metadata_id>")
|
||||
class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
@service_api_ns.expect(metadata_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_dataset_metadata")
|
||||
@service_api_ns.doc(description="Update metadata name")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "metadata_id": "Metadata ID"})
|
||||
@@ -93,7 +87,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id, dataset_id, metadata_id):
|
||||
"""Update metadata name."""
|
||||
args = metadata_update_parser.parse_args()
|
||||
payload = MetadataUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
@@ -102,7 +96,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args["name"])
|
||||
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, payload.name)
|
||||
return marshal(metadata, dataset_metadata_fields), 200
|
||||
|
||||
@service_api_ns.doc("delete_dataset_metadata")
|
||||
@@ -175,7 +169,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/metadata")
|
||||
class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
@service_api_ns.expect(document_metadata_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[MetadataOperationData.__name__])
|
||||
@service_api_ns.doc("update_documents_metadata")
|
||||
@service_api_ns.doc(description="Update metadata for multiple documents")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@@ -195,8 +189,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
args = document_metadata_parser.parse_args()
|
||||
metadata_args = MetadataOperationData.model_validate(args)
|
||||
metadata_args = MetadataOperationData.model_validate(service_api_ns.payload or {})
|
||||
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
|
||||
@@ -4,12 +4,12 @@ from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import reqparse
|
||||
from flask_restx.reqparse import ParseResult, RequestParser
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.errors import FilenameNotExistsError, NoFileUploadedError, TooManyFilesError
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
@@ -22,11 +22,25 @@ from models.dataset import Pipeline
|
||||
from models.engine import db
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
from services.file_service import FileService
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import DatasourceNodeRunApiEntity
|
||||
from services.rag_pipeline.entity.pipeline_service_api_entities import (
|
||||
DatasourceNodeRunApiEntity,
|
||||
PipelineRunApiEntity,
|
||||
)
|
||||
from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
|
||||
class DatasourceNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
is_published: bool
|
||||
|
||||
|
||||
register_schema_model(service_api_ns, DatasourceNodeRunPayload)
|
||||
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
||||
|
||||
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
@@ -88,22 +102,20 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
# Get query parameter to determine published or draft
|
||||
parser: RequestParser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
.add_argument("is_published", type=bool, required=True, location="json")
|
||||
)
|
||||
args: ParseResult = parser.parse_args()
|
||||
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
|
||||
{
|
||||
**payload.model_dump(exclude_none=True),
|
||||
"pipeline_id": str(pipeline.id),
|
||||
"node_id": node_id,
|
||||
}
|
||||
)
|
||||
return helper.compact_generate_response(
|
||||
PipelineGenerator.convert_to_event_stream(
|
||||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
@@ -147,25 +159,10 @@ class PipelineRunApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
"""Resource for running a rag pipeline."""
|
||||
parser: RequestParser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_published", type=bool, required=True, default=True, location="json")
|
||||
.add_argument(
|
||||
"response_mode",
|
||||
type=str,
|
||||
required=True,
|
||||
choices=["streaming", "blocking"],
|
||||
default="blocking",
|
||||
location="json",
|
||||
)
|
||||
)
|
||||
args: ParseResult = parser.parse_args()
|
||||
payload = PipelineRunApiEntity.model_validate(service_api_ns.payload or {})
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise Forbidden()
|
||||
@@ -176,9 +173,9 @@ class PipelineRunApi(DatasetApiResource):
|
||||
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.PUBLISHED if args.get("is_published") else InvokeFrom.DEBUGGER,
|
||||
streaming=args.get("response_mode") == "streaming",
|
||||
args=payload.model_dump(),
|
||||
invoke_from=InvokeFrom.PUBLISHED if payload.is_published else InvokeFrom.DEBUGGER,
|
||||
streaming=payload.response_mode == "streaming",
|
||||
)
|
||||
|
||||
return helper.compact_generate_response(response)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal, reqparse
|
||||
from flask_restx import marshal
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import (
|
||||
@@ -24,34 +28,42 @@ from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexing
|
||||
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
|
||||
# Define parsers for segment operations
|
||||
segment_create_parser = reqparse.RequestParser().add_argument(
|
||||
"segments", type=list, required=False, nullable=True, location="json"
|
||||
)
|
||||
|
||||
segment_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("status", type=str, action="append", default=[], location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
)
|
||||
class SegmentCreatePayload(BaseModel):
|
||||
segments: list[dict[str, Any]] | None = None
|
||||
|
||||
segment_update_parser = reqparse.RequestParser().add_argument(
|
||||
"segment", type=dict, required=False, nullable=True, location="json"
|
||||
)
|
||||
|
||||
child_chunk_create_parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
class SegmentListQuery(BaseModel):
|
||||
status: list[str] = Field(default_factory=list)
|
||||
keyword: str | None = None
|
||||
|
||||
child_chunk_list_parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("limit", type=int, default=20, location="args")
|
||||
.add_argument("keyword", type=str, default=None, location="args")
|
||||
.add_argument("page", type=int, default=1, location="args")
|
||||
)
|
||||
|
||||
child_chunk_update_parser = reqparse.RequestParser().add_argument(
|
||||
"content", type=str, required=True, nullable=False, location="json"
|
||||
class SegmentUpdatePayload(BaseModel):
|
||||
segment: SegmentUpdateArgs
|
||||
|
||||
|
||||
class ChildChunkCreatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChildChunkListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1)
|
||||
keyword: str | None = None
|
||||
page: int = Field(default=1, ge=1)
|
||||
|
||||
|
||||
class ChildChunkUpdatePayload(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
service_api_ns,
|
||||
SegmentCreatePayload,
|
||||
SegmentListQuery,
|
||||
SegmentUpdatePayload,
|
||||
ChildChunkCreatePayload,
|
||||
ChildChunkListQuery,
|
||||
ChildChunkUpdatePayload,
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +71,7 @@ child_chunk_update_parser = reqparse.RequestParser().add_argument(
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
|
||||
@service_api_ns.expect(segment_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_segments")
|
||||
@service_api_ns.doc(description="Create segments in a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@@ -106,20 +118,20 @@ class SegmentApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
args = segment_create_parser.parse_args()
|
||||
if args["segments"] is not None:
|
||||
payload = SegmentCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
if payload.segments is not None:
|
||||
segments_limit = dify_config.DATASET_MAX_SEGMENTS_PER_REQUEST
|
||||
if segments_limit > 0 and len(args["segments"]) > segments_limit:
|
||||
if segments_limit > 0 and len(payload.segments) > segments_limit:
|
||||
raise ValueError(f"Exceeded maximum segments limit of {segments_limit}.")
|
||||
|
||||
for args_item in args["segments"]:
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args["segments"], document, dataset)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@service_api_ns.expect(segment_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentListQuery.__name__])
|
||||
@service_api_ns.doc("list_segments")
|
||||
@service_api_ns.doc(description="List segments in a document")
|
||||
@service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@@ -160,13 +172,18 @@ class SegmentApi(DatasetApiResource):
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
args = segment_list_parser.parse_args()
|
||||
args = SegmentListQuery.model_validate(
|
||||
{
|
||||
"status": request.args.getlist("status"),
|
||||
"keyword": request.args.get("keyword"),
|
||||
}
|
||||
)
|
||||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
tenant_id=current_tenant_id,
|
||||
status_list=args["status"],
|
||||
keyword=args["keyword"],
|
||||
status_list=args.status,
|
||||
keyword=args.keyword,
|
||||
page=page,
|
||||
limit=limit,
|
||||
)
|
||||
@@ -217,7 +234,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(segment_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@service_api_ns.doc(description="Update a specific segment")
|
||||
@service_api_ns.doc(
|
||||
@@ -265,12 +282,9 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate args
|
||||
args = segment_update_parser.parse_args()
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
|
||||
)
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@@ -308,7 +322,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
class ChildChunkApi(DatasetApiResource):
|
||||
"""Resource for child chunks."""
|
||||
|
||||
@service_api_ns.expect(child_chunk_create_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@service_api_ns.doc("create_child_chunk")
|
||||
@service_api_ns.doc(description="Create a new child chunk for a segment")
|
||||
@service_api_ns.doc(
|
||||
@@ -360,16 +374,16 @@ class ChildChunkApi(DatasetApiResource):
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
# validate args
|
||||
args = child_chunk_create_parser.parse_args()
|
||||
payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
child_chunk = SegmentService.create_child_chunk(args["content"], segment, document, dataset)
|
||||
child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
||||
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
|
||||
|
||||
@service_api_ns.expect(child_chunk_list_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkListQuery.__name__])
|
||||
@service_api_ns.doc("list_child_chunks")
|
||||
@service_api_ns.doc(description="List child chunks for a segment")
|
||||
@service_api_ns.doc(
|
||||
@@ -400,11 +414,17 @@ class ChildChunkApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
args = child_chunk_list_parser.parse_args()
|
||||
args = ChildChunkListQuery.model_validate(
|
||||
{
|
||||
"limit": request.args.get("limit", default=20, type=int),
|
||||
"keyword": request.args.get("keyword"),
|
||||
"page": request.args.get("page", default=1, type=int),
|
||||
}
|
||||
)
|
||||
|
||||
page = args["page"]
|
||||
limit = min(args["limit"], 100)
|
||||
keyword = args["keyword"]
|
||||
page = args.page
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
|
||||
@@ -480,7 +500,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(child_chunk_update_parser)
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
@service_api_ns.doc(description="Update a specific child chunk")
|
||||
@service_api_ns.doc(
|
||||
@@ -533,10 +553,10 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
# validate args
|
||||
args = child_chunk_update_parser.parse_args()
|
||||
payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
try:
|
||||
child_chunk = SegmentService.update_child_chunk(args["content"], child_chunk, segment, document, dataset)
|
||||
child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset)
|
||||
except ChildChunkIndexingServiceError as e:
|
||||
raise ChildChunkIndexingError(str(e))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user