refactor: port reqparse to BaseModel (#28993)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,20 +1,63 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import make_response, redirect, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from libs.helper import StrLen
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
class DatasourceCredentialPayload(BaseModel):
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
credentials: dict[str, Any]
|
||||
|
||||
|
||||
class DatasourceCredentialDeletePayload(BaseModel):
|
||||
credential_id: str
|
||||
|
||||
|
||||
class DatasourceCredentialUpdatePayload(BaseModel):
|
||||
credential_id: str
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
credentials: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DatasourceCustomClientPayload(BaseModel):
|
||||
client_params: dict[str, Any] | None = None
|
||||
enable_oauth_custom_client: bool | None = None
|
||||
|
||||
|
||||
class DatasourceDefaultPayload(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class DatasourceUpdateNamePayload(BaseModel):
|
||||
credential_id: str
|
||||
name: str = Field(max_length=100)
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DatasourceCredentialPayload,
|
||||
DatasourceCredentialDeletePayload,
|
||||
DatasourceCredentialUpdatePayload,
|
||||
DatasourceCustomClientPayload,
|
||||
DatasourceDefaultPayload,
|
||||
DatasourceUpdateNamePayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/oauth/plugin/<path:provider_id>/datasource/get-authorization-url")
|
||||
class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@setup_required
|
||||
@@ -121,16 +164,9 @@ class DatasourceOAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
|
||||
|
||||
|
||||
parser_datasource = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json", default=None)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>")
|
||||
class DatasourceAuth(Resource):
|
||||
@console_ns.expect(parser_datasource)
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -138,7 +174,7 @@ class DatasourceAuth(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource.parse_args()
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@@ -146,8 +182,8 @@ class DatasourceAuth(Resource):
|
||||
datasource_provider_service.add_datasource_api_key_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
provider_id=datasource_provider_id,
|
||||
credentials=args["credentials"],
|
||||
name=args["name"],
|
||||
credentials=payload.credentials,
|
||||
name=payload.name,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ValueError(str(ex))
|
||||
@@ -169,14 +205,9 @@ class DatasourceAuth(Resource):
|
||||
return {"result": datasources}, 200
|
||||
|
||||
|
||||
parser_datasource_delete = reqparse.RequestParser().add_argument(
|
||||
"credential_id", type=str, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/delete")
|
||||
class DatasourceAuthDeleteApi(Resource):
|
||||
@console_ns.expect(parser_datasource_delete)
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialDeletePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -188,28 +219,20 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
|
||||
args = parser_datasource_delete.parse_args()
|
||||
payload = DatasourceCredentialDeletePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
auth_id=payload.credential_id,
|
||||
provider=provider_name,
|
||||
plugin_id=plugin_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_datasource_update = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("name", type=StrLen(max_length=100), required=False, nullable=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update")
|
||||
class DatasourceAuthUpdateApi(Resource):
|
||||
@console_ns.expect(parser_datasource_update)
|
||||
@console_ns.expect(console_ns.models[DatasourceCredentialUpdatePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -218,16 +241,16 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
args = parser_datasource_update.parse_args()
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
auth_id=args["credential_id"],
|
||||
auth_id=payload.credential_id,
|
||||
provider=datasource_provider_id.provider_name,
|
||||
plugin_id=datasource_provider_id.plugin_id,
|
||||
credentials=args.get("credentials", {}),
|
||||
name=args.get("name", None),
|
||||
credentials=payload.credentials or {},
|
||||
name=payload.name,
|
||||
)
|
||||
return {"result": "success"}, 201
|
||||
|
||||
@@ -258,16 +281,9 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
|
||||
|
||||
parser_datasource_custom = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
|
||||
.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/custom-client")
|
||||
class DatasourceAuthOauthCustomClient(Resource):
|
||||
@console_ns.expect(parser_datasource_custom)
|
||||
@console_ns.expect(console_ns.models[DatasourceCustomClientPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -275,14 +291,14 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_datasource_custom.parse_args()
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.setup_oauth_custom_client_params(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
client_params=args.get("client_params", {}),
|
||||
enabled=args.get("enable_oauth_custom_client", False),
|
||||
client_params=payload.client_params or {},
|
||||
enabled=payload.enable_oauth_custom_client or False,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -301,12 +317,9 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("id", type=str, required=True, nullable=False, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/default")
|
||||
class DatasourceAuthDefaultApi(Resource):
|
||||
@console_ns.expect(parser_default)
|
||||
@console_ns.expect(console_ns.models[DatasourceDefaultPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -314,27 +327,20 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_default.parse_args()
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.set_default_datasource_provider(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
credential_id=args["id"],
|
||||
credential_id=payload.id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
parser_update_name = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("name", type=StrLen(max_length=100), required=True, nullable=False, location="json")
|
||||
.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/auth/plugin/datasource/<path:provider_id>/update-name")
|
||||
class DatasourceUpdateProviderNameApi(Resource):
|
||||
@console_ns.expect(parser_update_name)
|
||||
@console_ns.expect(console_ns.models[DatasourceUpdateNamePayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -342,13 +348,13 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
args = parser_update_name.parse_args()
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.update_datasource_provider_name(
|
||||
tenant_id=current_tenant_id,
|
||||
datasource_provider_id=datasource_provider_id,
|
||||
name=args["name"],
|
||||
credential_id=args["credential_id"],
|
||||
name=payload.name,
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
@@ -20,18 +22,6 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_name(name: str) -> str:
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description: str) -> str:
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/templates")
|
||||
class PipelineTemplateListApi(Resource):
|
||||
@setup_required
|
||||
@@ -59,6 +49,15 @@ class PipelineTemplateDetailApi(Resource):
|
||||
return pipeline_template, 200
|
||||
|
||||
|
||||
class Payload(BaseModel):
|
||||
name: str = Field(..., min_length=1, max_length=40)
|
||||
description: str = Field(default="", max_length=400)
|
||||
icon_info: dict[str, object] | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, Payload)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/customized/templates/<string:template_id>")
|
||||
class CustomizedPipelineTemplateApi(Resource):
|
||||
@setup_required
|
||||
@@ -66,31 +65,8 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump())
|
||||
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||
return 200
|
||||
|
||||
@@ -119,36 +95,14 @@ class CustomizedPipelineTemplateApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<string:pipeline_id>/customized/publish")
|
||||
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[Payload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
.add_argument(
|
||||
"description",
|
||||
type=_validate_description_length,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
location="json",
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = Payload.model_validate(console_ns.payload or {})
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump())
|
||||
return {"result": "success"}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from flask_restx import Resource, marshal, reqparse
|
||||
from flask_restx import Resource, marshal
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
@@ -19,22 +21,22 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineDatasetImportPayload(BaseModel):
|
||||
yaml_content: str
|
||||
|
||||
|
||||
register_schema_model(console_ns, RagPipelineDatasetImportPayload)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipeline/dataset")
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@console_ns.expect(console_ns.models[RagPipelineDatasetImportPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
payload = RagPipelineDatasetImportPayload.model_validate(console_ns.payload or {})
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
@@ -49,7 +51,7 @@ class CreateRagPipelineDatasetApi(Resource):
|
||||
),
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
partial_member_list=None,
|
||||
yaml_content=args["yaml_content"],
|
||||
yaml_content=payload.yaml_content,
|
||||
)
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import logging
|
||||
from typing import NoReturn
|
||||
from typing import Any, NoReturn
|
||||
|
||||
from flask import Response
|
||||
from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
DraftWorkflowNotExist,
|
||||
@@ -33,19 +35,21 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_pagination_parser():
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(
|
||||
"page",
|
||||
type=inputs.int_range(1, 100_000),
|
||||
required=False,
|
||||
default=1,
|
||||
location="args",
|
||||
help="the page of data requested",
|
||||
)
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
return parser
|
||||
class PaginationQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=100_000)
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
register_schema_models(console_ns, PaginationQuery)
|
||||
|
||||
return PaginationQuery
|
||||
|
||||
|
||||
class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
name: str | None = None
|
||||
value: Any | None = None
|
||||
|
||||
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||
@@ -93,8 +97,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
parser = _create_pagination_parser()
|
||||
args = parser.parse_args()
|
||||
pagination = _create_pagination_parser()
|
||||
query = pagination.model_validate(request.args.to_dict())
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
@@ -109,8 +113,8 @@ class RagPipelineVariableCollectionApi(Resource):
|
||||
)
|
||||
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||
app_id=pipeline.id,
|
||||
page=args.page,
|
||||
limit=args.limit,
|
||||
page=query.page,
|
||||
limit=query.limit,
|
||||
)
|
||||
|
||||
return workflow_vars
|
||||
@@ -186,6 +190,7 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
# Request payload for file types:
|
||||
#
|
||||
@@ -208,16 +213,11 @@ class RagPipelineVariableApi(Resource):
|
||||
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
|
||||
# }
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||
.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
|
||||
)
|
||||
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
args = parser.parse_args(strict=True)
|
||||
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
if variable is None:
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from flask_restx import Resource, marshal_with, reqparse # type: ignore
|
||||
from flask import request
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
@@ -16,6 +19,25 @@ from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportPayload(BaseModel):
|
||||
mode: str
|
||||
yaml_content: str | None = None
|
||||
yaml_url: str | None = None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon: str | None = None
|
||||
icon_background: str | None = None
|
||||
pipeline_id: str | None = None
|
||||
|
||||
|
||||
class IncludeSecretQuery(BaseModel):
|
||||
include_secret: str = Field(default="false")
|
||||
|
||||
|
||||
register_schema_models(console_ns, RagPipelineImportPayload, IncludeSecretQuery)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/imports")
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@@ -23,23 +45,11 @@ class RagPipelineImportApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
@console_ns.expect(console_ns.models[RagPipelineImportPayload.__name__])
|
||||
def post(self):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("mode", type=str, required=True, location="json")
|
||||
.add_argument("yaml_content", type=str, location="json")
|
||||
.add_argument("yaml_url", type=str, location="json")
|
||||
.add_argument("name", type=str, location="json")
|
||||
.add_argument("description", type=str, location="json")
|
||||
.add_argument("icon_type", type=str, location="json")
|
||||
.add_argument("icon", type=str, location="json")
|
||||
.add_argument("icon_background", type=str, location="json")
|
||||
.add_argument("pipeline_id", type=str, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = RagPipelineImportPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
@@ -48,11 +58,11 @@ class RagPipelineImportApi(Resource):
|
||||
account = current_user
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
dataset_name=args.get("name"),
|
||||
import_mode=payload.mode,
|
||||
yaml_content=payload.yaml_content,
|
||||
yaml_url=payload.yaml_url,
|
||||
pipeline_id=payload.pipeline_id,
|
||||
dataset_name=payload.name,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
@@ -114,13 +124,12 @@ class RagPipelineExportApi(Resource):
|
||||
@edit_permission_required
|
||||
def get(self, pipeline: Pipeline):
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser().add_argument("include_secret", type=str, default="false", location="args")
|
||||
args = parser.parse_args()
|
||||
query = IncludeSecretQuery.model_validate(request.args.to_dict())
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(
|
||||
pipeline=pipeline, include_secret=args["include_secret"] == "true"
|
||||
pipeline=pipeline, include_secret=query.include_secret == "true"
|
||||
)
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import cast
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
|
||||
from flask import abort, request
|
||||
from flask_restx import Resource, inputs, marshal_with, reqparse # type: ignore # type: ignore
|
||||
from flask_restx.inputs import int_range # type: ignore
|
||||
from flask_restx import Resource, marshal_with # type: ignore
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import (
|
||||
ConversationCompletedError,
|
||||
@@ -36,7 +38,7 @@ from fields.workflow_run_fields import (
|
||||
workflow_run_pagination_fields,
|
||||
)
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, current_user, login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
@@ -51,6 +53,91 @@ from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTran
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DraftWorkflowSyncPayload(BaseModel):
|
||||
graph: dict[str, Any]
|
||||
hash: str | None = None
|
||||
environment_variables: list[dict[str, Any]] | None = None
|
||||
conversation_variables: list[dict[str, Any]] | None = None
|
||||
rag_pipeline_variables: list[dict[str, Any]] | None = None
|
||||
features: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class NodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class NodeRunRequiredPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
|
||||
|
||||
class DatasourceNodeRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
credential_id: str | None = None
|
||||
|
||||
|
||||
class DraftWorkflowRunPayload(BaseModel):
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
datasource_info_list: list[dict[str, Any]]
|
||||
start_node_id: str
|
||||
|
||||
|
||||
class PublishedWorkflowRunPayload(DraftWorkflowRunPayload):
|
||||
is_preview: bool = False
|
||||
response_mode: Literal["streaming", "blocking"] = "streaming"
|
||||
original_document_id: str | None = None
|
||||
|
||||
|
||||
class DefaultBlockConfigQuery(BaseModel):
|
||||
q: str | None = None
|
||||
|
||||
|
||||
class WorkflowListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999)
|
||||
limit: int = Field(default=10, ge=1, le=100)
|
||||
user_id: str | None = None
|
||||
named_only: bool = False
|
||||
|
||||
|
||||
class WorkflowUpdatePayload(BaseModel):
|
||||
marked_name: str | None = Field(default=None, max_length=20)
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class NodeIdQuery(BaseModel):
|
||||
node_id: str
|
||||
|
||||
|
||||
class WorkflowRunQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class DatasourceVariablesPayload(BaseModel):
|
||||
datasource_type: str
|
||||
datasource_info: dict[str, Any]
|
||||
start_node_id: str
|
||||
start_node_title: str
|
||||
|
||||
|
||||
register_schema_models(
|
||||
console_ns,
|
||||
DraftWorkflowSyncPayload,
|
||||
NodeRunPayload,
|
||||
NodeRunRequiredPayload,
|
||||
DatasourceNodeRunPayload,
|
||||
DraftWorkflowRunPayload,
|
||||
PublishedWorkflowRunPayload,
|
||||
DefaultBlockConfigQuery,
|
||||
WorkflowListQuery,
|
||||
WorkflowUpdatePayload,
|
||||
NodeIdQuery,
|
||||
WorkflowRunQuery,
|
||||
DatasourceVariablesPayload,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft")
|
||||
class DraftRagPipelineApi(Resource):
|
||||
@setup_required
|
||||
@@ -88,15 +175,7 @@ class DraftRagPipelineApi(Resource):
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("graph", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("hash", type=str, required=False, location="json")
|
||||
.add_argument("environment_variables", type=list, required=False, location="json")
|
||||
.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
.add_argument("rag_pipeline_variables", type=list, required=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload_dict = console_ns.payload or {}
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
data = json.loads(request.data.decode("utf-8"))
|
||||
@@ -106,7 +185,7 @@ class DraftRagPipelineApi(Resource):
|
||||
if not isinstance(data.get("graph"), dict):
|
||||
raise ValueError("graph is not a dict")
|
||||
|
||||
args = {
|
||||
payload_dict = {
|
||||
"graph": data.get("graph"),
|
||||
"features": data.get("features"),
|
||||
"hash": data.get("hash"),
|
||||
@@ -119,24 +198,26 @@ class DraftRagPipelineApi(Resource):
|
||||
else:
|
||||
abort(415)
|
||||
|
||||
payload = DraftWorkflowSyncPayload.model_validate(payload_dict)
|
||||
|
||||
try:
|
||||
environment_variables_list = args.get("environment_variables") or []
|
||||
environment_variables_list = payload.environment_variables or []
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = args.get("conversation_variables") or []
|
||||
conversation_variables_list = payload.conversation_variables or []
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=args["graph"],
|
||||
unique_hash=args.get("hash"),
|
||||
graph=payload.graph,
|
||||
unique_hash=payload.hash,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
rag_pipeline_variables=args.get("rag_pipeline_variables") or [],
|
||||
rag_pipeline_variables=payload.rag_pipeline_variables or [],
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
@@ -148,12 +229,9 @@ class DraftRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_run = reqparse.RequestParser().add_argument("inputs", type=dict, location="json")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
@console_ns.expect(parser_run)
|
||||
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -166,7 +244,8 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_run.parse_args()
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_iteration(
|
||||
@@ -187,7 +266,7 @@ class RagPipelineDraftRunIterationNodeApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/loop/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
@console_ns.expect(parser_run)
|
||||
@console_ns.expect(console_ns.models[NodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -200,7 +279,8 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_run.parse_args()
|
||||
payload = NodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate_single_loop(
|
||||
@@ -219,18 +299,9 @@ class RagPipelineDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
parser_draft_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/run")
|
||||
class DraftRagPipelineRunApi(Resource):
|
||||
@console_ns.expect(parser_draft_run)
|
||||
@console_ns.expect(console_ns.models[DraftWorkflowRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -243,7 +314,8 @@ class DraftRagPipelineRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_draft_run.parse_args()
|
||||
payload = DraftWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate(
|
||||
@@ -259,21 +331,9 @@ class DraftRagPipelineRunApi(Resource):
|
||||
raise InvokeRateLimitHttpError(ex.description)
|
||||
|
||||
|
||||
parser_published_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info_list", type=list, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("is_preview", type=bool, required=True, location="json", default=False)
|
||||
.add_argument("response_mode", type=str, required=True, location="json", default="streaming")
|
||||
.add_argument("original_document_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/run")
|
||||
class PublishedRagPipelineRunApi(Resource):
|
||||
@console_ns.expect(parser_published_run)
|
||||
@console_ns.expect(console_ns.models[PublishedWorkflowRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -286,16 +346,16 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_published_run.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
payload = PublishedWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
streaming = payload.response_mode == "streaming"
|
||||
|
||||
try:
|
||||
response = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
user=current_user,
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.DEBUGGER if args.get("is_preview") else InvokeFrom.PUBLISHED,
|
||||
invoke_from=InvokeFrom.DEBUGGER if payload.is_preview else InvokeFrom.PUBLISHED,
|
||||
streaming=streaming,
|
||||
)
|
||||
|
||||
@@ -387,17 +447,9 @@ class PublishedRagPipelineRunApi(Resource):
|
||||
#
|
||||
# return result
|
||||
#
|
||||
parser_rag_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("credential_id", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
@console_ns.expect(parser_rag_run)
|
||||
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -410,14 +462,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return helper.compact_generate_response(
|
||||
@@ -425,11 +470,11 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
user_inputs=payload.inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
datasource_type=payload.datasource_type,
|
||||
is_published=False,
|
||||
credential_id=args.get("credential_id"),
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -437,7 +482,7 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
@console_ns.expect(parser_rag_run)
|
||||
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@@ -450,14 +495,7 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_rag_run.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs is None:
|
||||
raise ValueError("missing inputs")
|
||||
datasource_type = args.get("datasource_type")
|
||||
if datasource_type is None:
|
||||
raise ValueError("missing datasource_type")
|
||||
payload = DatasourceNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
return helper.compact_generate_response(
|
||||
@@ -465,24 +503,19 @@ class RagPipelineDraftDatasourceNodeRunApi(Resource):
|
||||
rag_pipeline_service.run_datasource_workflow_node(
|
||||
pipeline=pipeline,
|
||||
node_id=node_id,
|
||||
user_inputs=inputs,
|
||||
user_inputs=payload.inputs,
|
||||
account=current_user,
|
||||
datasource_type=datasource_type,
|
||||
datasource_type=payload.datasource_type,
|
||||
is_published=False,
|
||||
credential_id=args.get("credential_id"),
|
||||
credential_id=payload.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
parser_run_api = reqparse.RequestParser().add_argument(
|
||||
"inputs", type=dict, required=True, nullable=False, location="json"
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/nodes/<string:node_id>/run")
|
||||
class RagPipelineDraftNodeRunApi(Resource):
|
||||
@console_ns.expect(parser_run_api)
|
||||
@console_ns.expect(console_ns.models[NodeRunRequiredPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@edit_permission_required
|
||||
@@ -496,11 +529,8 @@ class RagPipelineDraftNodeRunApi(Resource):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_run_api.parse_args()
|
||||
|
||||
inputs = args.get("inputs")
|
||||
if inputs == None:
|
||||
raise ValueError("missing inputs")
|
||||
payload = NodeRunRequiredPayload.model_validate(console_ns.payload or {})
|
||||
inputs = payload.inputs
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.run_draft_workflow_node(
|
||||
@@ -602,12 +632,8 @@ class DefaultRagPipelineBlockConfigsApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_configs()
|
||||
|
||||
|
||||
parser_default = reqparse.RequestParser().add_argument("q", type=str, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/default-workflow-block-configs/<string:block_type>")
|
||||
class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
@console_ns.expect(parser_default)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -617,14 +643,12 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
"""
|
||||
Get default block config
|
||||
"""
|
||||
args = parser_default.parse_args()
|
||||
|
||||
q = args.get("q")
|
||||
query = DefaultBlockConfigQuery.model_validate(request.args.to_dict())
|
||||
|
||||
filters = None
|
||||
if q:
|
||||
if query.q:
|
||||
try:
|
||||
filters = json.loads(args.get("q", ""))
|
||||
filters = json.loads(query.q)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid filters")
|
||||
|
||||
@@ -633,18 +657,8 @@ class DefaultRagPipelineBlockConfigApi(Resource):
|
||||
return rag_pipeline_service.get_default_block_config(node_type=block_type, filters=filters)
|
||||
|
||||
|
||||
parser_wf = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args")
|
||||
.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=10, location="args")
|
||||
.add_argument("user_id", type=str, required=False, location="args")
|
||||
.add_argument("named_only", type=inputs.boolean, required=False, default=False, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows")
|
||||
class PublishedAllRagPipelineApi(Resource):
|
||||
@console_ns.expect(parser_wf)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -657,16 +671,16 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_wf.parse_args()
|
||||
page = args["page"]
|
||||
limit = args["limit"]
|
||||
user_id = args.get("user_id")
|
||||
named_only = args.get("named_only", False)
|
||||
query = WorkflowListQuery.model_validate(request.args.to_dict())
|
||||
|
||||
page = query.page
|
||||
limit = query.limit
|
||||
user_id = query.user_id
|
||||
named_only = query.named_only
|
||||
|
||||
if user_id:
|
||||
if user_id != current_user.id:
|
||||
raise Forbidden()
|
||||
user_id = cast(str, user_id)
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
with Session(db.engine) as session:
|
||||
@@ -687,16 +701,8 @@ class PublishedAllRagPipelineApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_wf_id = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("marked_name", type=str, required=False, location="json")
|
||||
.add_argument("marked_comment", type=str, required=False, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/<string:workflow_id>")
|
||||
class RagPipelineByIdApi(Resource):
|
||||
@console_ns.expect(parser_wf_id)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -710,20 +716,8 @@ class RagPipelineByIdApi(Resource):
|
||||
# Check permission
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = parser_wf_id.parse_args()
|
||||
|
||||
# Validate name and comment length
|
||||
if args.marked_name and len(args.marked_name) > 20:
|
||||
raise ValueError("Marked name cannot exceed 20 characters")
|
||||
if args.marked_comment and len(args.marked_comment) > 100:
|
||||
raise ValueError("Marked comment cannot exceed 100 characters")
|
||||
|
||||
# Prepare update data
|
||||
update_data = {}
|
||||
if args.get("marked_name") is not None:
|
||||
update_data["marked_name"] = args["marked_name"]
|
||||
if args.get("marked_comment") is not None:
|
||||
update_data["marked_comment"] = args["marked_comment"]
|
||||
payload = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_data:
|
||||
return {"message": "No valid fields to update"}, 400
|
||||
@@ -749,12 +743,8 @@ class RagPipelineByIdApi(Resource):
|
||||
return workflow
|
||||
|
||||
|
||||
parser_parameters = reqparse.RequestParser().add_argument("node_id", type=str, required=True, location="args")
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/processing/parameters")
|
||||
class PublishedRagPipelineSecondStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -764,10 +754,8 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
||||
return {
|
||||
@@ -777,7 +765,6 @@ class PublishedRagPipelineSecondStepApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/pre-processing/parameters")
|
||||
class PublishedRagPipelineFirstStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -787,10 +774,8 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=False)
|
||||
return {
|
||||
@@ -800,7 +785,6 @@ class PublishedRagPipelineFirstStepApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/pre-processing/parameters")
|
||||
class DraftRagPipelineFirstStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -810,10 +794,8 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||
"""
|
||||
Get first step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_first_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
||||
return {
|
||||
@@ -823,7 +805,6 @@ class DraftRagPipelineFirstStepApi(Resource):
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/processing/parameters")
|
||||
class DraftRagPipelineSecondStepApi(Resource):
|
||||
@console_ns.expect(parser_parameters)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -833,10 +814,8 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
"""
|
||||
Get second step parameters of rag pipeline
|
||||
"""
|
||||
args = parser_parameters.parse_args()
|
||||
node_id = args.get("node_id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required")
|
||||
query = NodeIdQuery.model_validate(request.args.to_dict())
|
||||
node_id = query.node_id
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
variables = rag_pipeline_service.get_second_step_parameters(pipeline=pipeline, node_id=node_id, is_draft=True)
|
||||
@@ -845,16 +824,8 @@ class DraftRagPipelineSecondStepApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
parser_wf_run = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("last_id", type=uuid_value, location="args")
|
||||
.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflow-runs")
|
||||
class RagPipelineWorkflowRunListApi(Resource):
|
||||
@console_ns.expect(parser_wf_run)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -864,7 +835,16 @@ class RagPipelineWorkflowRunListApi(Resource):
|
||||
"""
|
||||
Get workflow run list
|
||||
"""
|
||||
args = parser_wf_run.parse_args()
|
||||
query = WorkflowRunQuery.model_validate(
|
||||
{
|
||||
"last_id": request.args.get("last_id"),
|
||||
"limit": request.args.get("limit", type=int, default=20),
|
||||
}
|
||||
)
|
||||
args = {
|
||||
"last_id": str(query.last_id) if query.last_id else None,
|
||||
"limit": query.limit,
|
||||
}
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline=pipeline, args=args)
|
||||
@@ -964,18 +944,9 @@ class RagPipelineTransformApi(Resource):
|
||||
return result
|
||||
|
||||
|
||||
parser_var = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("datasource_type", type=str, required=True, location="json")
|
||||
.add_argument("datasource_info", type=dict, required=True, location="json")
|
||||
.add_argument("start_node_id", type=str, required=True, location="json")
|
||||
.add_argument("start_node_title", type=str, required=True, location="json")
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/datasource/variables-inspect")
|
||||
class RagPipelineDatasourceVariableApi(Resource):
|
||||
@console_ns.expect(parser_var)
|
||||
@console_ns.expect(console_ns.models[DatasourceVariablesPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@@ -987,7 +958,7 @@ class RagPipelineDatasourceVariableApi(Resource):
|
||||
Set datasource variables
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = parser_var.parse_args()
|
||||
args = DatasourceVariablesPayload.model_validate(console_ns.payload or {}).model_dump()
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow_node_execution = rag_pipeline_service.set_datasource_variables(
|
||||
|
||||
Reference in New Issue
Block a user