feat: knowledge pipeline (#25360)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
@@ -26,7 +26,6 @@ from .dataset import (
     TidbAuthBinding,
     Whitelist,
 )
 from .engine import db
 from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom
 from .model import (
     ApiRequest,
@@ -57,6 +56,7 @@ from .model import (
     TraceAppConfig,
     UploadFile,
 )
+from .oauth import DatasourceOauthParamConfig, DatasourceProvider
 from .provider import (
     LoadBalancingModelConfig,
     Provider,
@@ -86,6 +86,7 @@ from .workflow import (
     WorkflowAppLog,
     WorkflowAppLogCreatedFrom,
     WorkflowNodeExecutionModel,
+    WorkflowNodeExecutionOffload,
     WorkflowNodeExecutionTriggeredFrom,
     WorkflowRun,
     WorkflowType,
@@ -123,6 +124,8 @@ __all__ = [
     "DatasetProcessRule",
     "DatasetQuery",
     "DatasetRetrieverResource",
+    "DatasourceOauthParamConfig",
+    "DatasourceProvider",
     "DifySetup",
     "Document",
     "DocumentSegment",
@@ -172,10 +175,10 @@ __all__ = [
     "WorkflowAppLog",
     "WorkflowAppLogCreatedFrom",
     "WorkflowNodeExecutionModel",
+    "WorkflowNodeExecutionOffload",
     "WorkflowNodeExecutionTriggeredFrom",
     "WorkflowRun",
     "WorkflowRunTriggeredFrom",
     "WorkflowToolProvider",
     "WorkflowType",
     "db",
 ]
@@ -15,7 +15,7 @@ from typing import Any, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, Session, mapped_column

from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
@@ -61,12 +61,35 @@ class Dataset(Base):
    created_by = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = mapped_column(StringUUID, nullable=True)
-    updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
-    embedding_model = mapped_column(String(255), nullable=True)
-    embedding_model_provider = mapped_column(String(255), nullable=True)
+    updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    embedding_model = mapped_column(db.String(255), nullable=True)
+    embedding_model_provider = mapped_column(db.String(255), nullable=True)
+    keyword_number = db.Column(db.Integer, nullable=True, server_default=db.text("10"))
    collection_binding_id = mapped_column(StringUUID, nullable=True)
    retrieval_model = mapped_column(JSONB, nullable=True)
-    built_in_field_enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
+    built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+    icon_info = db.Column(JSONB, nullable=True)
+    runtime_mode = db.Column(db.String(255), nullable=True, server_default=db.text("'general'::character varying"))
+    pipeline_id = db.Column(StringUUID, nullable=True)
+    chunk_structure = db.Column(db.String(255), nullable=True)
+    enable_api = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))

    @property
    def total_documents(self):
        return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()

    @property
    def total_available_documents(self):
        return (
            db.session.query(func.count(Document.id))
            .where(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def dataset_keyword_table(self):
@@ -150,7 +173,9 @@ class Dataset(Base):
        )

    @property
-    def doc_form(self):
+    def doc_form(self) -> str | None:
+        if self.chunk_structure:
+            return self.chunk_structure
        document = db.session.query(Document).where(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
@@ -206,6 +231,14 @@ class Dataset(Base):
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @property
    def is_published(self):
        if self.pipeline_id:
            pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
            if pipeline:
                return pipeline.is_published
        return False

    @property
    def doc_metadata(self):
        dataset_metadatas = db.session.scalars(
@@ -394,7 +427,7 @@ class Document(Base):
        return status

    @property
-    def data_source_info_dict(self) -> dict[str, Any] | None:
+    def data_source_info_dict(self) -> dict[str, Any]:
        if self.data_source_info:
            try:
                data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info)
@@ -402,7 +435,7 @@ class Document(Base):
                data_source_info_dict = {}

            return data_source_info_dict
-        return None
+        return {}

    @property
    def data_source_detail_dict(self) -> dict[str, Any]:
@@ -759,7 +792,7 @@ class DocumentSegment(Base):
        text = self.content

        # For data before v0.10.0
-        pattern = r"/files/([a-f0-9\-]+)/image-preview"
+        pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
        matches = re.finditer(pattern, text)
        for match in matches:
            upload_file_id = match.group(1)
@@ -771,11 +804,12 @@ class DocumentSegment(Base):
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
-            signed_url = f"{match.group(0)}?{params}"
+            base_url = f"/files/{upload_file_id}/image-preview"
+            signed_url = f"{base_url}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # For data after v0.10.0
-        pattern = r"/files/([a-f0-9\-]+)/file-preview"
+        pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
        matches = re.finditer(pattern, text)
        for match in matches:
            upload_file_id = match.group(1)
@@ -787,7 +821,27 @@ class DocumentSegment(Base):
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
-            signed_url = f"{match.group(0)}?{params}"
+            base_url = f"/files/{upload_file_id}/file-preview"
+            signed_url = f"{base_url}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # For tools directory - direct file formats (e.g., .png, .jpg, etc.)
        # Match URL including any query parameters up to common URL boundaries (space, parenthesis, quotes)
        pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
        matches = re.finditer(pattern, text)
        for match in matches:
            upload_file_id = match.group(1)
            file_extension = match.group(2)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            base_url = f"/files/tools/{upload_file_id}.{file_extension}"
            signed_url = f"{base_url}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # Reconstruct the text with signed URLs
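
All three branches above share one signing scheme; the switch from `match.group(0)` to a rebuilt `base_url` keeps a previously appended query string (now captured by the widened patterns) from surviving into the re-signed URL. A standalone sketch of the scheme, where the secret is a stand-in for `dify_config.SECRET_KEY` and the file ID is hypothetical:

import base64
import hashlib
import hmac
import os
import time

secret_key = b"example-secret-key"  # stand-in for dify_config.SECRET_KEY
upload_file_id = "0b1f5c6e-0000-0000-0000-000000000000"  # hypothetical

nonce = os.urandom(16).hex()
timestamp = str(int(time.time()))
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()

# Rebuilding from a clean base guarantees exactly one set of signing parameters.
signed_url = f"/files/{upload_file_id}/file-preview?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
print(signed_url)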
@@ -1166,3 +1220,112 @@ class DatasetMetadataBinding(Base):
    document_id = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    created_by = mapped_column(StringUUID, nullable=False)


class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
    __tablename__ = "pipeline_built_in_templates"
    __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_built_in_template_pkey"),)

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=False)
    chunk_structure = db.Column(db.String(255), nullable=False)
    icon = db.Column(db.JSON, nullable=False)
    yaml_content = db.Column(db.Text, nullable=False)
    copyright = db.Column(db.String(255), nullable=False)
    privacy_policy = db.Column(db.String(255), nullable=False)
    position = db.Column(db.Integer, nullable=False)
    install_count = db.Column(db.Integer, nullable=False, default=0)
    language = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    created_by = db.Column(StringUUID, nullable=False)
    updated_by = db.Column(StringUUID, nullable=True)

    @property
    def created_user_name(self):
        account = db.session.query(Account).where(Account.id == self.created_by).first()
        if account:
            return account.name
        return ""


class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
    __tablename__ = "pipeline_customized_templates"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="pipeline_customized_template_pkey"),
        db.Index("pipeline_customized_template_tenant_idx", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=False)
    chunk_structure = db.Column(db.String(255), nullable=False)
    icon = db.Column(db.JSON, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    yaml_content = db.Column(db.Text, nullable=False)
    install_count = db.Column(db.Integer, nullable=False, default=0)
    language = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    updated_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
    def created_user_name(self):
        account = db.session.query(Account).where(Account.id == self.created_by).first()
        if account:
            return account.name
        return ""


class Pipeline(Base):  # type: ignore[name-defined]
    __tablename__ = "pipelines"
    __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_pkey"),)

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
    workflow_id = db.Column(StringUUID, nullable=True)
    is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    is_published = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    created_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def retrieve_dataset(self, session: Session):
        return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()


class DocumentPipelineExecutionLog(Base):
    __tablename__ = "document_pipeline_execution_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pipeline_execution_log_pkey"),
        db.Index("document_pipeline_execution_logs_document_id_idx", "document_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    pipeline_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    datasource_type = db.Column(db.String(255), nullable=False)
    datasource_info = db.Column(db.Text, nullable=False)
    datasource_node_id = db.Column(db.String(255), nullable=False)
    input_data = db.Column(db.JSON, nullable=False)
    created_by = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class PipelineRecommendedPlugin(Base):
    __tablename__ = "pipeline_recommended_plugins"
    __table_args__ = (db.PrimaryKeyConstraint("id", name="pipeline_recommended_plugin_pkey"),)

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    plugin_id = db.Column(db.Text, nullable=False)
    provider_name = db.Column(db.Text, nullable=False)
    position = db.Column(db.Integer, nullable=False, default=0)
    active = db.Column(db.Boolean, nullable=False, default=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
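
Note that the linkage runs through `Dataset.pipeline_id`; `Pipeline` itself carries no dataset foreign key. A hedged sketch of resolving both directions (import paths assumed from this commit's layout):

from sqlalchemy.orm import Session

from models.dataset import Dataset, Pipeline  # assumed import path
from models.engine import db

with Session(db.engine) as session:
    pipeline = session.query(Pipeline).first()  # any pipeline, for illustration
    if pipeline is not None:
        # Pipeline -> Dataset goes through Dataset.pipeline_id.
        dataset = pipeline.retrieve_dataset(session)
        if dataset is not None:
            # Dataset -> Pipeline: Dataset.is_published delegates to the pipeline's flag.
            print(dataset.is_published)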
@@ -14,6 +14,8 @@ class UserFrom(StrEnum):
class WorkflowRunTriggeredFrom(StrEnum):
    DEBUGGING = "debugging"
    APP_RUN = "app-run"
+    RAG_PIPELINE_RUN = "rag-pipeline-run"
+    RAG_PIPELINE_DEBUGGING = "rag-pipeline-debugging"


class DraftVariableType(StrEnum):
@@ -30,3 +32,9 @@ class MessageStatus(StrEnum):

    NORMAL = "normal"
    ERROR = "error"


+class ExecutionOffLoadType(StrEnum):
+    INPUTS = "inputs"
+    PROCESS_DATA = "process_data"
+    OUTPUTS = "outputs"
@@ -6,14 +6,6 @@ from datetime import datetime
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Literal, Optional, cast

-from core.plugin.entities.plugin import GenericProviderID
-from core.tools.entities.tool_entities import ToolProviderType
-from core.tools.signature import sign_tool_file
-from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
-
-if TYPE_CHECKING:
-    from models.workflow import Workflow
-
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin  # type: ignore[import-untyped]
@@ -24,14 +16,20 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
+from core.tools.signature import sign_tool_file
+from core.workflow.enums import WorkflowExecutionStatus
from libs.helper import generate_string  # type: ignore[import-not-found]

from .account import Account, Tenant
from .base import Base
from .engine import db
from .enums import CreatorUserRole
+from .provider_ids import GenericProviderID
from .types import StringUUID

+if TYPE_CHECKING:
+    from models.workflow import Workflow


class DifySetup(Base):
    __tablename__ = "dify_setups"
@@ -47,6 +45,8 @@ class AppMode(StrEnum):
    CHAT = "chat"
    ADVANCED_CHAT = "advanced-chat"
    AGENT_CHAT = "agent-chat"
    CHANNEL = "channel"
+    RAG_PIPELINE = "rag-pipeline"

    @classmethod
    def value_of(cls, value: str) -> "AppMode":
@@ -163,7 +163,7 @@ class App(Base):

    @property
    def deleted_tools(self) -> list[dict[str, str]]:
-        from core.tools.tool_manager import ToolManager
+        from core.tools.tool_manager import ToolManager, ToolProviderType
        from services.plugin.plugin_service import PluginService

        # get agent mode tools
@@ -178,6 +178,7 @@ class App(Base):
        tools = agent_mode.get("tools", [])

        api_provider_ids: list[str] = []

        builtin_provider_ids: list[GenericProviderID] = []

        for tool in tools:
@@ -846,7 +847,8 @@ class Conversation(Base):

    @property
    def app(self) -> App | None:
-        return db.session.query(App).where(App.id == self.app_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            return session.query(App).where(App.id == self.app_id).first()

    @property
    def from_end_user_session_id(self):
@@ -1138,7 +1140,7 @@ class Message(Base):
        )

    @property
-    def retriever_resources(self) -> Any | list[Any]:
+    def retriever_resources(self) -> Any:
        return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else []

    @property
@@ -1621,6 +1623,9 @@ class UploadFile(Base):
        sa.Index("upload_file_tenant_idx", "tenant_id"),
    )

    # NOTE: The `id` field is generated within the application to minimize extra roundtrips
    # (especially when generating `source_url`).
    # The `server_default` serves as a fallback mechanism.
    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    storage_type: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -1629,12 +1634,32 @@ class UploadFile(Base):
    size: Mapped[int] = mapped_column(sa.Integer, nullable=False)
    extension: Mapped[str] = mapped_column(String(255), nullable=False)
    mime_type: Mapped[str] = mapped_column(String(255), nullable=True)

    # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`.
    # Its value is derived from the `CreatorUserRole` enumeration.
    created_by_role: Mapped[str] = mapped_column(
        String(255), nullable=False, server_default=sa.text("'account'::character varying")
    )

    # The `created_by` field stores the ID of the entity that created this upload file.
    #
    # If `created_by_role` is `ACCOUNT`, it corresponds to `Account.id`.
    # Otherwise, it corresponds to `EndUser.id`.
    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

    # The fields `used` and `used_by` are not consistently maintained.
    #
    # When using this model in new code, ensure the following:
    #
    # 1. Set `used` to `true` when the file is utilized.
    # 2. Assign `used_by` to the corresponding `Account.id` or `EndUser.id` based on the `created_by_role`.
    # 3. Avoid relying on these fields for logic, as their values may not always be accurate.
    #
    # `used` may indicate whether the file has been utilized by another service.
    used: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))

    # `used_by` may indicate the ID of the user who utilized this file.
    used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    used_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
    hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
@@ -1659,6 +1684,7 @@ class UploadFile(Base):
        hash: str | None = None,
        source_url: str = "",
    ):
+        self.id = str(uuid.uuid4())
        self.tenant_id = tenant_id
        self.storage_type = storage_type
        self.key = key
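
The notes on `used`/`used_by` above prescribe bookkeeping that callers must do by hand. A minimal sketch of what that looks like in service code (hypothetical helper, not part of this commit):

from datetime import datetime

def mark_upload_file_used(upload_file, user_id: str) -> None:
    # 1. Flag consumption; 2. record the consumer, matching created_by_role
    #    (Account.id or EndUser.id); 3. never branch on these fields elsewhere.
    upload_file.used = True
    upload_file.used_by = user_id
    upload_file.used_at = datetime.utcnow()  # naive UTC, matching the DateTime columns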
api/models/oauth.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from datetime import datetime

from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped

from .base import Base
from .engine import db
from .types import StringUUID


class DatasourceOauthParamConfig(Base):  # type: ignore[name-defined]
    __tablename__ = "datasource_oauth_params"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="datasource_oauth_config_pkey"),
        db.UniqueConstraint("plugin_id", "provider", name="datasource_oauth_config_datasource_id_provider_idx"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    system_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)


class DatasourceProvider(Base):
    __tablename__ = "datasource_providers"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
        db.UniqueConstraint("tenant_id", "plugin_id", "provider", "name", name="datasource_provider_unique_name"),
        db.Index("datasource_provider_auth_type_provider_idx", "tenant_id", "plugin_id", "provider"),
    )
    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name: Mapped[str] = db.Column(db.String(255), nullable=False)
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
    encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
    avatar_url: Mapped[str] = db.Column(db.String(255), nullable=True, default="default")
    is_default: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    expires_at: Mapped[int] = db.Column(db.Integer, nullable=False, server_default="-1")

    created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
    updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)


class DatasourceOauthTenantParamConfig(Base):
    __tablename__ = "datasource_oauth_tenant_params"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="datasource_oauth_tenant_config_pkey"),
        db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="datasource_oauth_tenant_config_unique"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuidv7()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    provider: Mapped[str] = db.Column(db.String(255), nullable=False)
    plugin_id: Mapped[str] = db.Column(db.String(255), nullable=False)
    client_params: Mapped[dict] = db.Column(JSONB, nullable=False, default={})
    enabled: Mapped[bool] = db.Column(db.Boolean, nullable=False, default=False)

    created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
    updated_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
api/models/provider_ids.py (new file, 59 lines)
@@ -0,0 +1,59 @@
"""Provider ID entities for plugin system."""

import re

from werkzeug.exceptions import NotFound


class GenericProviderID:
    organization: str
    plugin_name: str
    provider_name: str
    is_hardcoded: bool

    def to_string(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return f"{self.organization}/{self.plugin_name}/{self.provider_name}"

    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        if not value:
            raise NotFound("plugin not found, please add plugin")
        # check if the value is a valid plugin id with format: $organization/$plugin_name/$provider_name
        if not re.match(r"^[a-z0-9_-]+\/[a-z0-9_-]+\/[a-z0-9_-]+$", value):
            # if the value is a bare name matching [a-z0-9_-]+, expand it to langgenius/$value/$value
            if re.match(r"^[a-z0-9_-]+$", value):
                value = f"langgenius/{value}/{value}"
            else:
                raise ValueError(f"Invalid plugin id {value}")

        self.organization, self.plugin_name, self.provider_name = value.split("/")
        self.is_hardcoded = is_hardcoded

    def is_langgenius(self) -> bool:
        return self.organization == "langgenius"

    @property
    def plugin_id(self) -> str:
        return f"{self.organization}/{self.plugin_name}"


class ModelProviderID(GenericProviderID):
    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        super().__init__(value, is_hardcoded)
        if self.organization == "langgenius" and self.provider_name == "google":
            self.plugin_name = "gemini"


class ToolProviderID(GenericProviderID):
    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        super().__init__(value, is_hardcoded)
        if self.organization == "langgenius":
            if self.provider_name in ["jina", "siliconflow", "stepfun", "gitee_ai"]:
                self.plugin_name = f"{self.provider_name}_tool"


class DatasourceProviderID(GenericProviderID):
    def __init__(self, value: str, is_hardcoded: bool = False) -> None:
        super().__init__(value, is_hardcoded)
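
GenericProviderID accepts either a full `org/plugin/provider` triple or a bare provider name, which it expands under the `langgenius` organization. A quick illustration of the parsing rules above (provider names here are made up for the example):

from models.provider_ids import GenericProviderID, ToolProviderID

full = GenericProviderID("acme/weather/weather_api")  # hypothetical triple
assert full.plugin_id == "acme/weather"

short = GenericProviderID("duckduckgo")  # bare name gets expanded
assert str(short) == "langgenius/duckduckgo/duckduckgo"
assert short.is_langgenius()

tool = ToolProviderID("langgenius/jina/jina")  # remapped to the "_tool" plugin
assert tool.plugin_id == "langgenius/jina_tool"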
@@ -1,6 +1,7 @@
import json
from collections.abc import Mapping
from datetime import datetime
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
from urllib.parse import urlparse

import sqlalchemy as sa
@@ -8,9 +9,7 @@ from deprecated import deprecated
from sqlalchemy import ForeignKey, String, func
from sqlalchemy.orm import Mapped, mapped_column

from core.file import helpers as file_helpers
from core.helper import encrypter
from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@@ -20,6 +19,12 @@ from .engine import db
from .model import Account, App, Tenant
from .types import StringUUID

+if TYPE_CHECKING:
+    from core.mcp.types import Tool as MCPTool
+    from core.tools.entities.common_entities import I18nObject
+    from core.tools.entities.tool_bundle import ApiToolBundle
+    from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration


# system level tool oauth client params (client_id, client_secret, etc.)
class ToolOAuthSystemClient(TypeBase):
@@ -138,11 +143,15 @@ class ApiToolProvider(Base):
    updated_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

    @property
-    def schema_type(self) -> ApiProviderSchemaType:
+    def schema_type(self) -> "ApiProviderSchemaType":
+        from core.tools.entities.tool_entities import ApiProviderSchemaType
+
        return ApiProviderSchemaType.value_of(self.schema_type_str)

    @property
-    def tools(self) -> list[ApiToolBundle]:
+    def tools(self) -> list["ApiToolBundle"]:
+        from core.tools.entities.tool_bundle import ApiToolBundle
+
        return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]

    @property
@@ -230,7 +239,9 @@ class WorkflowToolProvider(Base):
        return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()

    @property
-    def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
+    def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
+        from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
+
        return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]

    @property
@@ -298,13 +309,17 @@ class MCPToolProvider(Base):
        return {}

    @property
-    def mcp_tools(self) -> list[Tool]:
-        return [Tool(**tool) for tool in json.loads(self.tools)]
+    def mcp_tools(self) -> list["MCPTool"]:
+        from core.mcp.types import Tool as MCPTool
+
+        return [MCPTool(**tool) for tool in json.loads(self.tools)]

    @property
-    def provider_icon(self) -> dict[str, str] | str:
+    def provider_icon(self) -> Mapping[str, str] | str:
+        from core.file import helpers as file_helpers
+
        try:
-            return cast(dict[str, str], json.loads(self.icon))
+            return json.loads(self.icon)
        except json.JSONDecodeError:
            return file_helpers.get_signed_file_url(self.icon)
@@ -534,5 +549,7 @@ class DeprecatedPublishedAppTool(Base):
    updated_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.text("CURRENT_TIMESTAMP(0)"))

    @property
-    def description_i18n(self) -> I18nObject:
+    def description_i18n(self) -> "I18nObject":
+        from core.tools.entities.common_entities import I18nObject
+
        return I18nObject(**json.loads(self.description))
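
The pattern applied throughout tools.py: heavy imports move under TYPE_CHECKING, annotations become strings, and the modules are re-imported locally inside the methods that need them at runtime, which breaks import cycles between the model layer and core. A runnable sketch of the idiom, using a stdlib module as a stand-in for the heavy or cyclic import:

import json
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; the string annotation below refers to it.
    from decimal import Decimal  # stands in for a heavy/cyclic import

class Record:
    payload = '"3.14"'

    @property
    def amount(self) -> "Decimal":
        # Deferred runtime import: the module loads even if `decimal` imported us back.
        from decimal import Decimal

        return Decimal(json.loads(self.payload))

print(Record().amount)  # Decimal('3.14')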
@@ -2,26 +2,28 @@ import json
import logging
from collections.abc import Mapping, Sequence
from datetime import datetime
-from enum import StrEnum, auto
-from typing import TYPE_CHECKING, Any, Union, cast
+from enum import StrEnum
+from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4

import sqlalchemy as sa
-from sqlalchemy import DateTime, exists, orm, select
+from sqlalchemy import DateTime, Select, exists, orm, select

from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from core.workflow.nodes.enums import NodeType
+from core.workflow.enums import NodeType
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
+from libs.uuid_utils import uuidv7

from ._workflow_exc import NodeNotFoundError, WorkflowDataError

if TYPE_CHECKING:
-    from models.model import AppMode
+    from models.model import AppMode, UploadFile

from sqlalchemy import Index, PrimaryKeyConstraint, String, UniqueConstraint, func
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
@@ -35,7 +37,7 @@ from libs import helper
from .account import Account
from .base import Base
from .engine import db
-from .enums import CreatorUserRole, DraftVariableType
+from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType
from .types import EnumText, StringUUID

logger = logging.getLogger(__name__)
@@ -46,8 +48,9 @@ class WorkflowType(StrEnum):
    Workflow Type Enum
    """

-    WORKFLOW = auto()
-    CHAT = auto()
+    WORKFLOW = "workflow"
+    CHAT = "chat"
+    RAG_PIPELINE = "rag-pipeline"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowType":
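
Why the switch away from auto(): in a StrEnum, auto() yields the lowercased member name, so a hyphenated value like "rag-pipeline" cannot come from RAG_PIPELINE = auto(), which would produce "rag_pipeline". A small demonstration:

from enum import StrEnum, auto

class Demo(StrEnum):
    WORKFLOW = auto()      # -> "workflow"
    RAG_PIPELINE = auto()  # -> "rag_pipeline": underscore, not hyphen

assert Demo.WORKFLOW == "workflow"
assert Demo.RAG_PIPELINE == "rag_pipeline"  # hence the explicit "rag-pipeline" above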
@@ -143,6 +146,9 @@ class Workflow(Base):
    _conversation_variables: Mapped[str] = mapped_column(
        "conversation_variables", sa.Text, nullable=False, server_default="{}"
    )
+    _rag_pipeline_variables: Mapped[str] = mapped_column(
+        "rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
+    )

    VERSION_DRAFT = "draft"
@@ -159,6 +165,7 @@ class Workflow(Base):
        created_by: str,
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable],
+        rag_pipeline_variables: list[dict],
        marked_name: str = "",
        marked_comment: str = "",
    ) -> "Workflow":
@@ -173,6 +180,7 @@ class Workflow(Base):
        workflow.created_by = created_by
        workflow.environment_variables = environment_variables or []
        workflow.conversation_variables = conversation_variables or []
+        workflow.rag_pipeline_variables = rag_pipeline_variables or []
        workflow.marked_name = marked_name
        workflow.marked_comment = marked_comment
        workflow.created_at = naive_utc_now()
@@ -314,6 +322,12 @@ class Workflow(Base):

        return variables

    def rag_pipeline_user_input_form(self) -> list:
        # get user_input_form from start node
        variables: list[Any] = self.rag_pipeline_variables

        return variables

    @property
    def unique_hash(self) -> str:
        """
@@ -354,7 +368,7 @@ class Workflow(Base):
        if not tenant_id:
            return []

-        environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables)
+        environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
        results = [
            variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
        ]
@@ -424,6 +438,7 @@ class Workflow(Base):
            "features": self.features_dict,
            "environment_variables": [var.model_dump(mode="json") for var in environment_variables],
            "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
+            "rag_pipeline_variables": self.rag_pipeline_variables,
        }
        return result
@@ -442,6 +457,23 @@ class Workflow(Base):
            ensure_ascii=False,
        )

    @property
    def rag_pipeline_variables(self) -> list[dict]:
        # TODO: find some way to init `self._rag_pipeline_variables` when the instance is created.
        if self._rag_pipeline_variables is None:
            self._rag_pipeline_variables = "{}"

        variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
        results = list(variables_dict.values())
        return results

    @rag_pipeline_variables.setter
    def rag_pipeline_variables(self, values: list[dict]) -> None:
        self._rag_pipeline_variables = json.dumps(
            {item["variable"]: item for item in values},
            ensure_ascii=False,
        )

    @staticmethod
    def version_from_datetime(d: datetime) -> str:
        return str(d)
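
The setter keys the stored JSON object by each entry's "variable" field and the getter returns the values as a list, so a list of distinct variable configs round-trips (dicts preserve insertion order). A small self-contained illustration of that serialization, outside the model, with made-up entries:

import json

values = [
    {"variable": "source_url", "label": "Source URL", "type": "string"},   # hypothetical
    {"variable": "chunk_size", "label": "Chunk size", "type": "number"},
]

# What the setter stores in the `rag_pipeline_variables` column:
stored = json.dumps({item["variable"]: item for item in values}, ensure_ascii=False)

# What the getter hands back:
restored = list(json.loads(stored).values())
assert restored == values  # note: duplicate "variable" keys would collapse to one entry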
@@ -606,9 +638,10 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum):

    SINGLE_STEP = "single-step"
    WORKFLOW_RUN = "workflow-run"
+    RAG_PIPELINE_RUN = "rag-pipeline-run"


-class WorkflowNodeExecutionModel(Base):
+class WorkflowNodeExecutionModel(Base):  # This model is expected to have `offload_data` preloaded in most cases.
    """
    Workflow Node Execution
@@ -725,6 +758,32 @@ class WorkflowNodeExecutionModel(Base):
    created_by: Mapped[str] = mapped_column(StringUUID)
    finished_at: Mapped[datetime | None] = mapped_column(DateTime)

    offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
        "WorkflowNodeExecutionOffload",
        primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
        uselist=True,
        lazy="raise",
        back_populates="execution",
    )

    @staticmethod
    def preload_offload_data(
        query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
    ):
        return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))

    @staticmethod
    def preload_offload_data_and_files(
        query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
    ):
        return query.options(
            orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
                # Using `joinedload` instead of `selectinload` to minimize database roundtrips,
                # as `selectinload` would require separate queries for `inputs_file` and `outputs_file`.
                orm.selectinload(WorkflowNodeExecutionOffload.file),
            )
        )

    @property
    def created_by_account(self):
        created_by_role = CreatorUserRole(self.created_by_role)
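
Because offload_data is declared with lazy="raise", touching it on an instance fetched without preloading raises instead of silently issuing a query. Callers are expected to wrap their statements roughly like this (run ID and session wiring assumed for the sketch):

import sqlalchemy as sa
from sqlalchemy.orm import Session

from models.engine import db
from models.workflow import WorkflowNodeExecutionModel

run_id = "00000000-0000-0000-0000-000000000000"  # hypothetical workflow run ID

stmt = sa.select(WorkflowNodeExecutionModel).where(
    WorkflowNodeExecutionModel.workflow_run_id == run_id  # column assumed from the existing model
)
stmt = WorkflowNodeExecutionModel.preload_offload_data_and_files(stmt)

with Session(db.engine) as session:
    for execution in session.scalars(stmt):
        # Safe to touch: offload_data and each offload's file were eagerly loaded.
        print(execution.inputs_truncated, execution.outputs_truncated)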
@@ -773,9 +832,132 @@ class WorkflowNodeExecutionModel(Base):
                provider_type=tool_info["provider_type"],
                provider_id=tool_info["provider_id"],
            )
        elif self.node_type == NodeType.DATASOURCE.value and "datasource_info" in self.execution_metadata_dict:
            datasource_info = self.execution_metadata_dict["datasource_info"]
            extras["icon"] = datasource_info.get("icon")
        return extras

    def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
        return next(iter([i for i in self.offload_data if i.type_ == type_]), None)

    @property
    def inputs_truncated(self) -> bool:
        """Check if inputs were truncated (offloaded to external storage)."""
        return self._get_offload_by_type(ExecutionOffLoadType.INPUTS) is not None

    @property
    def outputs_truncated(self) -> bool:
        """Check if outputs were truncated (offloaded to external storage)."""
        return self._get_offload_by_type(ExecutionOffLoadType.OUTPUTS) is not None

    @property
    def process_data_truncated(self) -> bool:
        """Check if process_data was truncated (offloaded to external storage)."""
        return self._get_offload_by_type(ExecutionOffLoadType.PROCESS_DATA) is not None

    @staticmethod
    def _load_full_content(session: orm.Session, file_id: str, storage: Storage):
        from .model import UploadFile

        stmt = sa.select(UploadFile).where(UploadFile.id == file_id)
        file = session.scalars(stmt).first()
        assert file is not None, f"UploadFile with id {file_id} should exist but was not found"
        content = storage.load(file.key)
        return json.loads(content)

    def load_full_inputs(self, session: orm.Session, storage: Storage) -> Mapping[str, Any] | None:
        offload = self._get_offload_by_type(ExecutionOffLoadType.INPUTS)
        if offload is None:
            return self.inputs_dict

        return self._load_full_content(session, offload.file_id, storage)

    def load_full_outputs(self, session: orm.Session, storage: Storage) -> Mapping[str, Any] | None:
        offload: WorkflowNodeExecutionOffload | None = self._get_offload_by_type(ExecutionOffLoadType.OUTPUTS)
        if offload is None:
            return self.outputs_dict

        return self._load_full_content(session, offload.file_id, storage)

    def load_full_process_data(self, session: orm.Session, storage: Storage) -> Mapping[str, Any] | None:
        offload: WorkflowNodeExecutionOffload | None = self._get_offload_by_type(ExecutionOffLoadType.PROCESS_DATA)
        if offload is None:
            return self.process_data_dict

        return self._load_full_content(session, offload.file_id, storage)


class WorkflowNodeExecutionOffload(Base):
    __tablename__ = "workflow_node_execution_offload"
    __table_args__ = (
        # PostgreSQL 14 treats NULL values as distinct in unique constraints by default,
        # allowing multiple records with NULL values for the same column combination.
        #
        # This behavior allows us to have multiple records with NULL node_execution_id,
        # simplifying the garbage collection process.
        UniqueConstraint(
            "node_execution_id",
            "type",
            # Note: PostgreSQL 15+ supports explicit `nulls distinct` behavior through
            # `postgresql_nulls_not_distinct=False`, which would make our intention clearer.
            # We rely on PostgreSQL's default behavior of treating NULLs as distinct values.
            # postgresql_nulls_not_distinct=False,
        ),
    )
    _HASH_COL_SIZE = 64

    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        server_default=sa.text("uuidv7()"),
    )

    created_at: Mapped[datetime] = mapped_column(
        DateTime, default=naive_utc_now, server_default=func.current_timestamp()
    )

    tenant_id: Mapped[str] = mapped_column(StringUUID)
    app_id: Mapped[str] = mapped_column(StringUUID)

    # `node_execution_id` indicates the `WorkflowNodeExecutionModel` associated with this offload record.
    # A value of `None` signifies that this offload record is not linked to any execution record
    # and should be considered for garbage collection.
    node_execution_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    type_: Mapped[ExecutionOffLoadType] = mapped_column(EnumText(ExecutionOffLoadType), name="type", nullable=False)

    # Design Decision: Combining inputs and outputs into a single object was considered to reduce I/O
    # operations. However, due to the current design of `WorkflowNodeExecutionRepository`,
    # the `save` method is called at two distinct times:
    #
    # - When the node starts execution: the `inputs` field exists, but the `outputs` field is absent
    # - When the node completes execution (either succeeded or failed): the `outputs` field becomes available
    #
    # It's difficult to correlate these two successive calls to `save` for combined storage.
    # Converting the `WorkflowNodeExecutionRepository` to buffer the first `save` call and flush
    # when execution completes was also considered, but this would make the execution state unobservable
    # until completion, significantly damaging the observability of workflow execution.
    #
    # Given these constraints, `inputs` and `outputs` are stored separately to maintain real-time
    # observability and system reliability.

    # `file_id` references the offloaded storage object containing the data.
    file_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    execution: Mapped[WorkflowNodeExecutionModel] = orm.relationship(
        foreign_keys=[node_execution_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id",
        back_populates="offload_data",
    )

    file: Mapped[Optional["UploadFile"]] = orm.relationship(
        foreign_keys=[file_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowNodeExecutionOffload.file_id == UploadFile.id",
    )
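
Reading back a potentially offloaded payload combines the pieces above: preloaded offload_data, the UploadFile row, and the storage backend. A hedged sketch (storage singleton and session wiring assumed):

import sqlalchemy as sa
from sqlalchemy.orm import Session

from extensions.ext_storage import storage  # assumption: Dify's storage singleton
from models.engine import db
from models.workflow import WorkflowNodeExecutionModel

with Session(db.engine) as session:
    stmt = WorkflowNodeExecutionModel.preload_offload_data(
        sa.select(WorkflowNodeExecutionModel).limit(1)
    )
    execution = session.scalars(stmt).first()
    if execution is not None:
        # Returns inputs_dict unchanged when nothing was offloaded; otherwise
        # resolves the UploadFile row and reads the full JSON from storage.
        full_inputs = execution.load_full_inputs(session, storage)
        print(execution.inputs_truncated, full_inputs)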
class WorkflowAppLogCreatedFrom(StrEnum):
    """
@@ -939,7 +1121,10 @@ class WorkflowDraftVariable(Base):
    ]

    __tablename__ = "workflow_draft_variables"
-    __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),)
+    __table_args__ = (
+        UniqueConstraint(*unique_app_id_node_id_name()),
+        Index("workflow_draft_variable_file_id_idx", "file_id"),
+    )
    # Required for instance variable annotation.
    __allow_unmapped__ = True
@@ -1000,9 +1185,16 @@ class WorkflowDraftVariable(Base):
    selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector")

    # The data type of this variable's value
    #
    # If the variable is offloaded, `value_type` represents the type of the truncated value,
    # which may differ from the original value's type. Typically, they are the same,
    # but in cases where the structurally truncated value still exceeds the size limit,
    # text slicing is applied, and the `value_type` is converted to `STRING`.
    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))

    # The variable's value serialized as a JSON string
    #
    # If the variable is offloaded, `value` contains a truncated version, not the full original value.
    value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value")

    # Controls whether the variable should be displayed in the variable inspection panel
@@ -1022,6 +1214,35 @@ class WorkflowDraftVariable(Base):
        default=None,
    )

    # Reference to WorkflowDraftVariableFile for offloaded large variables
    #
    # Indicates whether the current draft variable is offloaded.
    # If not offloaded, this field will be None.
    file_id: Mapped[str | None] = mapped_column(
        StringUUID,
        nullable=True,
        default=None,
        comment="Reference to WorkflowDraftVariableFile if variable is offloaded to external storage",
    )

    is_default_value: Mapped[bool] = mapped_column(
        sa.Boolean,
        nullable=False,
        default=False,
        comment=(
            "Indicates whether the current value is the default for a conversation variable. "
            "Always `FALSE` for other types of variables."
        ),
    )

    # Relationship to WorkflowDraftVariableFile
    variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
        foreign_keys=[file_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowDraftVariableFile.id == WorkflowDraftVariable.file_id",
    )

    # Cache for deserialized value
    #
    # NOTE(QuantumGhost): This field serves two purposes:
@@ -1171,6 +1392,9 @@ class WorkflowDraftVariable(Base):
            case _:
                return DraftVariableType.NODE

    def is_truncated(self) -> bool:
        return self.file_id is not None

    @classmethod
    def _new(
        cls,
@@ -1181,6 +1405,7 @@ class WorkflowDraftVariable(Base):
        value: Segment,
        node_execution_id: str | None,
        description: str = "",
        file_id: str | None = None,
    ) -> "WorkflowDraftVariable":
        variable = WorkflowDraftVariable()
        variable.created_at = _naive_utc_datetime()
@@ -1190,6 +1415,7 @@ class WorkflowDraftVariable(Base):
        variable.node_id = node_id
        variable.name = name
        variable.set_value(value)
        variable.file_id = file_id
        variable._set_selector(list(variable_utils.to_selector(node_id, name)))
        variable.node_execution_id = node_execution_id
        return variable
@@ -1245,6 +1471,7 @@ class WorkflowDraftVariable(Base):
        node_execution_id: str,
        visible: bool = True,
        editable: bool = True,
        file_id: str | None = None,
    ) -> "WorkflowDraftVariable":
        variable = cls._new(
            app_id=app_id,
@@ -1252,6 +1479,7 @@ class WorkflowDraftVariable(Base):
            name=name,
            node_execution_id=node_execution_id,
            value=value,
            file_id=file_id,
        )
        variable.visible = visible
        variable.editable = editable
@@ -1262,5 +1490,92 @@ class WorkflowDraftVariable(Base):
        return self.last_edited_at is not None


class WorkflowDraftVariableFile(Base):
    """Stores metadata about files associated with large workflow draft variables.

    This model acts as an intermediary between WorkflowDraftVariable and UploadFile,
    allowing for proper cleanup of orphaned files when variables are updated or deleted.

    The MIME type of the stored content is recorded in `UploadFile.mime_type`.
    Possible values are 'application/json' for JSON types other than plain text,
    and 'text/plain' for JSON strings.
    """

    __tablename__ = "workflow_draft_variable_files"

    # Primary key
    id: Mapped[str] = mapped_column(
        StringUUID,
        primary_key=True,
        default=uuidv7,
        server_default=sa.text("uuidv7()"),
    )

    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=_naive_utc_datetime,
        server_default=func.current_timestamp(),
    )

    tenant_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="The tenant to which the WorkflowDraftVariableFile belongs, referencing Tenant.id",
    )

    app_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="The application to which the WorkflowDraftVariableFile belongs, referencing App.id",
    )

    user_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="The owner of the WorkflowDraftVariableFile, referencing Account.id",
    )

    # Reference to the `UploadFile.id` field
    upload_file_id: Mapped[str] = mapped_column(
        StringUUID,
        nullable=False,
        comment="Reference to UploadFile containing the large variable data",
    )

    # -------------- metadata about the variable content --------------

    # The `size` is already recorded in UploadFile. It is duplicated here to avoid an additional database lookup.
    size: Mapped[int | None] = mapped_column(
        sa.BigInteger,
        nullable=False,
        comment="Size of the original variable content in bytes",
    )

    length: Mapped[int | None] = mapped_column(
        sa.Integer,
        nullable=True,
        comment=(
            "Length of the original variable content. For array and array-like types, "
            "this represents the number of elements. For object types, it indicates the number of keys. "
            "For other types, the value is NULL."
        ),
    )

    # The `value_type` field records the type of the original value.
    value_type: Mapped[SegmentType] = mapped_column(
        EnumText(SegmentType, length=20),
        nullable=False,
    )

    # Relationship to UploadFile
    upload_file: Mapped["UploadFile"] = orm.relationship(
        foreign_keys=[upload_file_id],
        lazy="raise",
        uselist=False,
        primaryjoin="WorkflowDraftVariableFile.upload_file_id == UploadFile.id",
    )


def is_system_variable_editable(name: str) -> bool:
    return name in _EDITABLE_SYSTEM_VARIABLE
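
Consumers can use is_truncated() to decide whether the inline value column is complete or only a preview of an offloaded variable. A rough sketch, assuming a WorkflowDraftVariable fetched with variable_file and upload_file preloaded (both relationships are lazy="raise") and assuming the storage backend returns bytes:

# `variable` is a WorkflowDraftVariable; `storage` is a Storage backend. Both assumed.
def full_variable_value(variable, storage) -> str:
    if not variable.is_truncated():
        return variable.value  # inline column already holds the complete serialized value
    # Offloaded: follow WorkflowDraftVariable -> WorkflowDraftVariableFile -> UploadFile.
    upload_file = variable.variable_file.upload_file
    return storage.load(upload_file.key).decode("utf-8")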