more typed orm (#28507)

This commit is contained in:
Asuka Minato
2025-11-21 22:45:51 +09:00
committed by GitHub
parent 63b8bbbab3
commit a6c6bcf95c
20 changed files with 196 additions and 134 deletions

View File

@@ -254,6 +254,8 @@ class DatasetService:
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(external_knowledge_api_id)
if not external_knowledge_api:
raise ValueError("External API template not found.")
if external_knowledge_id is None:
raise ValueError("external_knowledge_id is required")
external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset.id,

View File

@@ -257,12 +257,16 @@ class ExternalDatasetService:
db.session.add(dataset)
db.session.flush()
if args.get("external_knowledge_id") is None:
raise ValueError("external_knowledge_id is required")
if args.get("external_knowledge_api_id") is None:
raise ValueError("external_knowledge_api_id is required")
external_knowledge_binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset.id,
external_knowledge_api_id=args.get("external_knowledge_api_id"),
external_knowledge_id=args.get("external_knowledge_id"),
external_knowledge_api_id=args.get("external_knowledge_api_id") or "",
external_knowledge_id=args.get("external_knowledge_id") or "",
created_by=user_id,
)
db.session.add(external_knowledge_binding)

View File

@@ -82,7 +82,12 @@ class HitTestingService:
logger.debug("Hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)
@@ -118,7 +123,12 @@ class HitTestingService:
logger.debug("External knowledge hit testing retrieve in %s seconds", end - start)
dataset_query = DatasetQuery(
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
dataset_id=dataset.id,
content=query,
source="hit_testing",
source_app_id=None,
created_by_role="account",
created_by=account.id,
)
db.session.add(dataset_query)

View File

@@ -29,6 +29,8 @@ class OpsService:
if not app:
return None
tenant_id = app.tenant_id
if trace_config_data.tracing_config is None:
raise ValueError("Tracing config cannot be None.")
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
tenant_id, tracing_provider, trace_config_data.tracing_config
)

View File

@@ -1119,13 +1119,19 @@ class RagPipelineService:
with Session(db.engine) as session:
rag_pipeline_dsl_service = RagPipelineDslService(session)
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
if args.get("icon_info") is None:
args["icon_info"] = {}
if args.get("description") is None:
raise ValueError("Description is required")
if args.get("name") is None:
raise ValueError("Name is required")
pipeline_customized_template = PipelineCustomizedTemplate(
name=args.get("name"),
description=args.get("description"),
icon=args.get("icon_info"),
name=args.get("name") or "",
description=args.get("description") or "",
icon=args.get("icon_info") or {},
tenant_id=pipeline.tenant_id,
yaml_content=dsl,
install_count=0,
position=max_position + 1 if max_position else 1,
chunk_structure=dataset.chunk_structure,
language="en-US",

View File

@@ -322,9 +322,9 @@ class RagPipelineTransformService:
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=file_node_id,
)
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "notion_import":
@@ -350,9 +350,9 @@ class RagPipelineTransformService:
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=notion_node_id,
)
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)
elif document.data_source_type == "website_crawl":
@@ -379,8 +379,8 @@ class RagPipelineTransformService:
datasource_info=data_source_info,
input_data={},
created_by=document.created_by,
created_at=document.created_at,
datasource_node_id=datasource_node_id,
)
document_pipeline_execution_log.created_at = document.created_at
db.session.add(document)
db.session.add(document_pipeline_execution_log)