feat: knowledge pipeline (#25360)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: jyong <718720800@qq.com>
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: quicksand <quicksandzn@gmail.com>
Co-authored-by: Jyong <76649700+JohnJyong@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: Yongtao Huang <yongtaoh2022@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: Hanqing Zhao <sherry9277@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Harry <xh001x@hotmail.com>
Authored by -LAN- on 2025-09-18 12:49:10 +08:00, committed by GitHub
parent 7dadb33003, commit 85cda47c70
1772 changed files with 102407 additions and 31710 deletions

@@ -16,9 +16,9 @@ from werkzeug.exceptions import NotFound
from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
+from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
@@ -43,9 +43,12 @@ from models.dataset import (
Document,
DocumentSegment,
ExternalKnowledgeBindings,
Pipeline,
)
from models.model import UploadFile
from models.provider_ids import ModelProviderID
from models.source import DataSourceOauthBinding
from models.workflow import Workflow
from services.entities.knowledge_entities.knowledge_entities import (
ChildChunkUpdateArgs,
KnowledgeConfig,
@@ -53,6 +56,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
RetrievalModel,
SegmentUpdateArgs,
)
from services.entities.knowledge_entities.rag_pipeline_entities import (
KnowledgeConfiguration,
RagPipelineDatasetCreateEntity,
)
from services.errors.account import NoPermissionError
from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError
from services.errors.dataset import DatasetNameDuplicateError
@@ -60,11 +67,13 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_index_update_task import deal_dataset_index_update_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
from tasks.delete_segment_from_index_task import delete_segment_from_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
@@ -256,6 +265,55 @@ class DatasetService:
db.session.commit()
return dataset
@staticmethod
def create_empty_rag_pipeline_dataset(
tenant_id: str,
rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity,
):
if rag_pipeline_dataset_create_entity.name:
# check if dataset name already exists
if (
db.session.query(Dataset)
.filter_by(name=rag_pipeline_dataset_create_entity.name, tenant_id=tenant_id)
.first()
):
raise DatasetNameDuplicateError(
f"Dataset with name {rag_pipeline_dataset_create_entity.name} already exists."
)
else:
# generate an incremental default name: Untitled 1, Untitled 2, ...
datasets = db.session.query(Dataset).filter_by(tenant_id=tenant_id).all()
names = [dataset.name for dataset in datasets]
rag_pipeline_dataset_create_entity.name = generate_incremental_name(
names,
"Untitled",
)
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
pipeline = Pipeline(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
created_by=current_user.id,
)
db.session.add(pipeline)
db.session.flush()
dataset = Dataset(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,
description=rag_pipeline_dataset_create_entity.description,
permission=rag_pipeline_dataset_create_entity.permission,
provider="vendor",
runtime_mode="rag_pipeline",
icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(),
created_by=current_user.id,
pipeline_id=pipeline.id,
)
db.session.add(dataset)
db.session.commit()
return dataset
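
For orientation (editor's sketch, not part of the diff): calling the new method from a request context that supplies current_user. RagPipelineDatasetCreateEntity is imported above; the IconInfo model and the "only_me" permission value are assumed stand-ins for whatever rag_pipeline_entities actually defines.

# Illustrative sketch only; names marked "assumed" are not confirmed by this diff.
entity = RagPipelineDatasetCreateEntity(
    name="",  # a falsy name triggers the incremental "Untitled N" fallback above
    description="Dataset fed by the new knowledge pipeline",
    permission="only_me",  # assumed permission value
    icon_info=IconInfo(icon="book", icon_type="emoji"),  # assumed model and fields
)
dataset = DatasetService.create_empty_rag_pipeline_dataset(
    tenant_id=current_user.current_tenant_id,
    rag_pipeline_dataset_create_entity=entity,
)
# The backing Pipeline row is flushed first so dataset.pipeline_id can reference it,
# then both rows are committed together.
assert dataset.runtime_mode == "rag_pipeline" and dataset.pipeline_id is not None
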
@staticmethod
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
@@ -339,6 +397,14 @@ class DatasetService:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found")
# check if the dataset name already exists
if DatasetService._has_dataset_same_name(
tenant_id=dataset.tenant_id,
dataset_id=dataset_id,
name=data.get("name", dataset.name),
):
raise ValueError("Dataset name already exists")
# Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
@@ -349,6 +415,19 @@ class DatasetService:
else:
return DatasetService._update_internal_dataset(dataset, data, user)
@staticmethod
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
dataset = (
db.session.query(Dataset)
.where(
Dataset.id != dataset_id,
Dataset.name == name,
Dataset.tenant_id == tenant_id,
)
.first()
)
return dataset is not None
@staticmethod
def _update_external_dataset(dataset, data, user):
"""
@@ -454,17 +533,105 @@ class DatasetService:
filtered_data["updated_at"] = naive_utc_now()
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
# update icon info
if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info")
# Update dataset in database
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
# update pipeline knowledge base node data
DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
return dataset
@staticmethod
def _update_pipeline_knowledge_base_node_data(dataset: Dataset, update_user_id: str):
"""
Update pipeline knowledge base node data.
"""
if dataset.runtime_mode != "rag_pipeline":
return
pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first()
if not pipeline:
return
try:
rag_pipeline_service = RagPipelineService()
published_workflow = rag_pipeline_service.get_published_workflow(pipeline)
draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline)
# update knowledge nodes
def update_knowledge_nodes(workflow_graph: str) -> str:
"""Update knowledge-index nodes in workflow graph."""
data: dict[str, Any] = json.loads(workflow_graph)
nodes = data.get("nodes", [])
updated = False
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
try:
knowledge_index_node_data = node.get("data", {})
knowledge_index_node_data["embedding_model"] = dataset.embedding_model
knowledge_index_node_data["embedding_model_provider"] = dataset.embedding_model_provider
knowledge_index_node_data["retrieval_model"] = dataset.retrieval_model
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
node["data"] = knowledge_index_node_data
updated = True
except Exception:
logging.exception("Failed to update knowledge node")
continue
if updated:
data["nodes"] = nodes
return json.dumps(data)
return workflow_graph
# Update published workflow
if published_workflow:
updated_graph = update_knowledge_nodes(published_workflow.graph)
if updated_graph != published_workflow.graph:
# Create new workflow version
workflow = Workflow.new(
tenant_id=pipeline.tenant_id,
app_id=pipeline.id,
type=published_workflow.type,
version=str(datetime.datetime.now(datetime.UTC).replace(tzinfo=None)),
graph=updated_graph,
features=published_workflow.features,
created_by=update_user_id,
environment_variables=published_workflow.environment_variables,
conversation_variables=published_workflow.conversation_variables,
rag_pipeline_variables=published_workflow.rag_pipeline_variables,
marked_name="",
marked_comment="",
)
db.session.add(workflow)
# Update draft workflow
if draft_workflow:
updated_graph = update_knowledge_nodes(draft_workflow.graph)
if updated_graph != draft_workflow.graph:
draft_workflow.graph = updated_graph
db.session.add(draft_workflow)
# Commit all changes in one transaction
db.session.commit()
except Exception:
logging.exception("Failed to update pipeline knowledge base node data")
db.session.rollback()
raise
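
To make the traversal above concrete, a sketch (from this diff alone; the node ids and the datasource node are invented) of the graph JSON that update_knowledge_nodes walks: a top-level "nodes" list in which each knowledge-index node carries the dataset settings inside its "data" dict.

# Hypothetical input; only the "nodes" list and the data["type"] check come from the diff.
workflow_graph = json.dumps(
    {
        "nodes": [
            {"id": "source-1", "data": {"type": "datasource"}},  # left untouched
            {
                "id": "kb-1",
                "data": {
                    "type": "knowledge-index",         # matched by the loop above
                    "embedding_model": "stale-model",  # overwritten from the Dataset row
                    "retrieval_model": {},
                },
            },
        ],
    }
)
updated_graph = update_knowledge_nodes(workflow_graph)  # returns new JSON only if a node changed
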
@staticmethod
def _handle_indexing_technique_change(dataset, data, filtered_data):
"""
@@ -654,6 +821,133 @@ class DatasetService:
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
@staticmethod
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
if not current_user or not current_user.current_tenant_id:
raise ValueError("Current user or current tenant not found")
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure
dataset.indexing_technique = knowledge_configuration.indexing_technique
if knowledge_configuration.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,  # non-None: validated above
provider=knowledge_configuration.embedding_model_provider or "",
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
elif knowledge_configuration.indexing_technique == "economy":
dataset.keyword_number = knowledge_configuration.keyword_number
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
raise ValueError("Chunk structure is not allowed to be updated.")
action = None
if dataset.indexing_technique != knowledge_configuration.indexing_technique:
# indexing technique is being changed
if knowledge_configuration.indexing_technique == "economy":
raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
elif knowledge_configuration.indexing_technique == "high_quality":
action = "add"
# get embedding model setting
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_configuration.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model,
)
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
dataset.collection_binding_id = dataset_collection_binding.id
dataset.indexing_technique = knowledge_configuration.indexing_technique
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
else:
# add the default plugin id to both provider settings so the plugin model provider comparison is consistent
# Skip embedding model checks if not provided in the update request
if dataset.indexing_technique == "high_quality":
skip_embedding_update = False
try:
# Handle existing model provider
plugin_model_provider = dataset.embedding_model_provider
plugin_model_provider_str = None
if plugin_model_provider:
plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
# Handle new model provider from request
new_plugin_model_provider = knowledge_configuration.embedding_model_provider
new_plugin_model_provider_str = None
if new_plugin_model_provider:
new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
# Only update embedding model if both values are provided and different from current
if (
plugin_model_provider_str != new_plugin_model_provider_str
or knowledge_configuration.embedding_model != dataset.embedding_model
):
action = "update"
model_manager = ModelManager()
embedding_model = None
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=knowledge_configuration.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model,
)
except ProviderTokenNotInitError:
# If the embedding model cannot be resolved, keep the existing
# settings and skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
if embedding_model:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
elif dataset.indexing_technique == "economy":
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
session.add(dataset)
session.commit()
if action:
deal_dataset_index_update_task.delay(dataset.id, action)
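
A sketch of how a caller might drive this method (session wiring assumed; the KnowledgeConfiguration instance would be parsed from the pipeline's knowledge-index node elsewhere):

from sqlalchemy.orm import Session

with Session(db.engine) as session:
    DatasetService.update_rag_pipeline_dataset_settings(
        session=session,
        dataset=dataset,
        knowledge_configuration=knowledge_configuration,
        # False: first-time setup, settings are applied verbatim.
        # True: pipeline already published -- chunk_structure is frozen and an
        # "add"/"update" job may be enqueued via deal_dataset_index_update_task.
        has_published=True,
    )
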
@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)
@@ -730,6 +1024,18 @@ class DatasetService:
.all()
)
@staticmethod
def update_dataset_api_status(dataset_id: str, status: bool):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
dataset.enable_api = status
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
dataset.updated_by = current_user.id
dataset.updated_at = naive_utc_now()
db.session.commit()
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str):
assert isinstance(current_user, Account)
@@ -974,7 +1280,7 @@ class DocumentService:
return
documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
document.data_source_info_dict.get("upload_file_id", "")
for document in documents
if document.data_source_type == "upload_file" and document.data_source_info_dict
]
@@ -1062,7 +1368,9 @@ class DocumentService:
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
-retry_document_indexing_task.delay(dataset_id, document_ids)
+if not current_user or not current_user.id:
+    raise ValueError("Current user or current user id not found")
+retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
@staticmethod
def sync_website_document(dataset_id: str, document: Document):
@@ -1211,7 +1519,7 @@ class DocumentService:
)
return [], ""
db.session.add(dataset_process_rule)
-db.session.commit()
+db.session.flush()
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
with redis_client.lock(lock_name, timeout=600):
position = DocumentService.get_documents_position(dataset.id)
@@ -1301,23 +1609,10 @@ class DocumentService:
exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
-data_source_binding = (
-    db.session.query(DataSourceOauthBinding)
-    .where(
-        db.and_(
-            DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-            DataSourceOauthBinding.provider == "notion",
-            DataSourceOauthBinding.disabled == False,
-            DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-        )
-    )
-    .first()
-)
-if not data_source_binding:
-    raise ValueError("Data source binding not found.")
for page in notion_info.pages:
if page.page_id not in exist_page_ids:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
@@ -1393,6 +1688,283 @@ class DocumentService:
return documents, batch
# @staticmethod
# def save_document_with_dataset_id(
# dataset: Dataset,
# knowledge_config: KnowledgeConfig,
# account: Account | Any,
# dataset_process_rule: Optional[DatasetProcessRule] = None,
# created_from: str = "web",
# ):
# # check document limit
# features = FeatureService.get_features(current_user.current_tenant_id)
# if features.billing.enabled:
# if not knowledge_config.original_document_id:
# count = 0
# if knowledge_config.data_source:
# if knowledge_config.data_source.info_list.data_source_type == "upload_file":
# upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
# # type: ignore
# count = len(upload_file_list)
# elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
# notion_info_list = knowledge_config.data_source.info_list.notion_info_list
# for notion_info in notion_info_list: # type: ignore
# count = count + len(notion_info.pages)
# elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
# website_info = knowledge_config.data_source.info_list.website_info_list
# count = len(website_info.urls) # type: ignore
# batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
# if features.billing.subscription.plan == "sandbox" and count > 1:
# raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
# if count > batch_upload_limit:
# raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
# DocumentService.check_documents_upload_quota(count, features)
# # if dataset is empty, update dataset data_source_type
# if not dataset.data_source_type:
# dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
# if not dataset.indexing_technique:
# if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
# raise ValueError("Indexing technique is invalid")
# dataset.indexing_technique = knowledge_config.indexing_technique
# if knowledge_config.indexing_technique == "high_quality":
# model_manager = ModelManager()
# if knowledge_config.embedding_model and knowledge_config.embedding_model_provider:
# dataset_embedding_model = knowledge_config.embedding_model
# dataset_embedding_model_provider = knowledge_config.embedding_model_provider
# else:
# embedding_model = model_manager.get_default_model_instance(
# tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
# )
# dataset_embedding_model = embedding_model.model
# dataset_embedding_model_provider = embedding_model.provider
# dataset.embedding_model = dataset_embedding_model
# dataset.embedding_model_provider = dataset_embedding_model_provider
# dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
# dataset_embedding_model_provider, dataset_embedding_model
# )
# dataset.collection_binding_id = dataset_collection_binding.id
# if not dataset.retrieval_model:
# default_retrieval_model = {
# "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
# "reranking_enable": False,
# "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
# "top_k": 2,
# "score_threshold_enabled": False,
# }
# dataset.retrieval_model = (
# knowledge_config.retrieval_model.model_dump()
# if knowledge_config.retrieval_model
# else default_retrieval_model
# ) # type: ignore
# documents = []
# if knowledge_config.original_document_id:
# document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account)
# documents.append(document)
# batch = document.batch
# else:
# batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
# # save process rule
# if not dataset_process_rule:
# process_rule = knowledge_config.process_rule
# if process_rule:
# if process_rule.mode in ("custom", "hierarchical"):
# dataset_process_rule = DatasetProcessRule(
# dataset_id=dataset.id,
# mode=process_rule.mode,
# rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
# created_by=account.id,
# )
# elif process_rule.mode == "automatic":
# dataset_process_rule = DatasetProcessRule(
# dataset_id=dataset.id,
# mode=process_rule.mode,
# rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
# created_by=account.id,
# )
# else:
# logging.warn(
# f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
# )
# return
# db.session.add(dataset_process_rule)
# db.session.commit()
# lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
# with redis_client.lock(lock_name, timeout=600):
# position = DocumentService.get_documents_position(dataset.id)
# document_ids = []
# duplicate_document_ids = []
# if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
# upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
# for file_id in upload_file_list:
# file = (
# db.session.query(UploadFile)
# .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
# .first()
# )
# # raise error if file not found
# if not file:
# raise FileNotExistsError()
# file_name = file.name
# data_source_info = {
# "upload_file_id": file_id,
# }
# # check duplicate
# if knowledge_config.duplicate:
# document = Document.query.filter_by(
# dataset_id=dataset.id,
# tenant_id=current_user.current_tenant_id,
# data_source_type="upload_file",
# enabled=True,
# name=file_name,
# ).first()
# if document:
# document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
# document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# document.created_from = created_from
# document.doc_form = knowledge_config.doc_form
# document.doc_language = knowledge_config.doc_language
# document.data_source_info = json.dumps(data_source_info)
# document.batch = batch
# document.indexing_status = "waiting"
# db.session.add(document)
# documents.append(document)
# duplicate_document_ids.append(document.id)
# continue
# document = DocumentService.build_document(
# dataset,
# dataset_process_rule.id, # type: ignore
# knowledge_config.data_source.info_list.data_source_type, # type: ignore
# knowledge_config.doc_form,
# knowledge_config.doc_language,
# data_source_info,
# created_from,
# position,
# account,
# file_name,
# batch,
# )
# db.session.add(document)
# db.session.flush()
# document_ids.append(document.id)
# documents.append(document)
# position += 1
# elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
# notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
# if not notion_info_list:
# raise ValueError("No notion info list found.")
# exist_page_ids = []
# exist_document = {}
# documents = Document.query.filter_by(
# dataset_id=dataset.id,
# tenant_id=current_user.current_tenant_id,
# data_source_type="notion_import",
# enabled=True,
# ).all()
# if documents:
# for document in documents:
# data_source_info = json.loads(document.data_source_info)
# exist_page_ids.append(data_source_info["notion_page_id"])
# exist_document[data_source_info["notion_page_id"]] = document.id
# for notion_info in notion_info_list:
# workspace_id = notion_info.workspace_id
# data_source_binding = DataSourceOauthBinding.query.filter(
# db.and_(
# DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
# DataSourceOauthBinding.provider == "notion",
# DataSourceOauthBinding.disabled == False,
# DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
# )
# ).first()
# if not data_source_binding:
# raise ValueError("Data source binding not found.")
# for page in notion_info.pages:
# if page.page_id not in exist_page_ids:
# data_source_info = {
# "notion_workspace_id": workspace_id,
# "notion_page_id": page.page_id,
# "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
# "type": page.type,
# }
# # Truncate page name to 255 characters to prevent DB field length errors
# truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
# document = DocumentService.build_document(
# dataset,
# dataset_process_rule.id, # type: ignore
# knowledge_config.data_source.info_list.data_source_type, # type: ignore
# knowledge_config.doc_form,
# knowledge_config.doc_language,
# data_source_info,
# created_from,
# position,
# account,
# truncated_page_name,
# batch,
# )
# db.session.add(document)
# db.session.flush()
# document_ids.append(document.id)
# documents.append(document)
# position += 1
# else:
# exist_document.pop(page.page_id)
# # delete not selected documents
# if len(exist_document) > 0:
# clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
# elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
# website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
# if not website_info:
# raise ValueError("No website info list found.")
# urls = website_info.urls
# for url in urls:
# data_source_info = {
# "url": url,
# "provider": website_info.provider,
# "job_id": website_info.job_id,
# "only_main_content": website_info.only_main_content,
# "mode": "crawl",
# }
# if len(url) > 255:
# document_name = url[:200] + "..."
# else:
# document_name = url
# document = DocumentService.build_document(
# dataset,
# dataset_process_rule.id, # type: ignore
# knowledge_config.data_source.info_list.data_source_type, # type: ignore
# knowledge_config.doc_form,
# knowledge_config.doc_language,
# data_source_info,
# created_from,
# position,
# account,
# document_name,
# batch,
# )
# db.session.add(document)
# db.session.flush()
# document_ids.append(document.id)
# documents.append(document)
# position += 1
# db.session.commit()
# # trigger async task
# if document_ids:
# document_indexing_task.delay(dataset.id, document_ids)
# if duplicate_document_ids:
# duplicate_document_indexing_task.delay(dataset.id, duplicate_document_ids)
# return documents, batch
@staticmethod
def check_documents_upload_quota(count: int, features: FeatureModel):
can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size
@@ -1404,7 +1976,7 @@ class DocumentService:
@staticmethod
def build_document(
dataset: Dataset,
-process_rule_id: str,
+process_rule_id: str | None,
data_source_type: str,
document_form: str,
document_language: str,
@@ -1540,6 +2112,7 @@ class DocumentService:
raise ValueError("Data source binding not found.")
for page in notion_info.pages:
data_source_info = {
"credential_id": notion_info.credential_id,
"notion_workspace_id": workspace_id,
"notion_page_id": page.page_id,
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
@@ -2352,6 +2925,8 @@ class SegmentService:
segment.error = str(e)
db.session.commit()
new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
if not new_segment:
raise ValueError("new_segment is not found")
return new_segment
@classmethod
@@ -2430,9 +3005,11 @@ class SegmentService:
if index_node_ids or child_node_ids:
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id, child_node_ids)
-document.word_count = (
-    document.word_count - total_words if document.word_count and document.word_count > total_words else 0
-)
+if document.word_count is None:
+    document.word_count = 0
+else:
+    document.word_count = max(0, document.word_count - total_words)
db.session.add(document)
# Delete database records