fix: Fix parent child retrieval issues (#12206)

Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: nite-knite <nkCoding@gmail.com>
This commit is contained in:
Wu Tianwei
2025-01-02 16:07:21 +08:00
committed by GitHub
parent 68757950ce
commit 09d759d196
34 changed files with 446 additions and 387 deletions

View File

@@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: currentRerankModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
{
provider: datasetConfigs.reranking_model?.reranking_provider_name,
model: datasetConfigs.reranking_model?.reranking_model_name,
},
)
const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) {
return {
provider_name: datasetConfigs.reranking_model.reranking_provider_name,
model_name: datasetConfigs.reranking_model.reranking_model_name,
}
const rerankModel = useMemo(() => {
return {
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
}
else if (rerankDefaultModel) {
return {
provider_name: rerankDefaultModel.provider.provider,
model_name: rerankDefaultModel.model,
}
}
})()
}, [datasetConfigs.reranking_model])
const handleParamChange = (key: string, value: number) => {
if (key === 'top_k') {
@@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({
}
const handleRerankModeChange = (mode: RerankingModeEnum) => {
if (mode === datasetConfigs.reranking_mode)
return
if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
onChange({
...datasetConfigs,
reranking_mode: mode,
@@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({
const canManuallyToggleRerank = useMemo(() => {
return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal
|| selectedDatasetsMode.allExternal
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
const showRerankModel = useMemo(() => {
if (!canManuallyToggleRerank)
return true
else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
return false
return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
if (!currentRerankModel && enable)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t])
useEffect(() => {
if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
onChange({
...datasetConfigs,
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
onChange({
...datasetConfigs,
reranking_enable: enable,
})
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [currentRerankModel, datasetConfigs, onChange])
return (
<div>
@@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({
<div className='flex items-center'>
{
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
<div
className='flex items-center'
onClick={handleDisabledSwitchClick}
>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => {
if (canManuallyToggleRerank) {
onChange({
...datasetConfigs,
reranking_enable: v,
})
}
}}
/>
</div>
<Switch
size='md'
defaultValue={showRerankModel}
disabled={!canManuallyToggleRerank}
onChange={handleDisabledSwitchClick}
/>
)
}
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
@@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({
triggerClassName='ml-1 w-4 h-4'
/>
</div>
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
{
showRerankModel && (
<div>
<ModelSelector
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
onSelect={(v) => {
onChange({
...datasetConfigs,
reranking_model: {
reranking_provider_name: v.provider,
reranking_model_name: v.model,
},
})
}}
modelList={rerankModelList}
/>
</div>
)}
</div>
)
}

View File

@@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal'
import Button from '@/app/components/base/button'
import { RETRIEVE_TYPE } from '@/types/app'
import Toast from '@/app/components/base/toast'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets'
import type { DataSet } from '@/models/datasets'
@@ -41,17 +41,27 @@ const ParamsConfig = ({
}, [datasetConfigs])
const {
defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid,
modelList: rerankModelList,
currentModel: rerankDefaultModel,
currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: isCurrentRerankModelValid,
} = useCurrentProviderAndModel(
rerankModelList,
{
provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
},
)
const isValid = () => {
let errMsg = ''
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
if (tempDataSetConfigs.reranking_enable
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
&& !isRerankDefaultModelValid
&& !isCurrentRerankModelValid
)
errMsg = t('appDebug.datasetConfig.rerankModelRequired')
}
@@ -66,16 +76,7 @@ const ParamsConfig = ({
const handleSave = () => {
if (!isValid())
return
const config = { ...tempDataSetConfigs }
if (config.retrieval_model === RETRIEVE_TYPE.multiWay
&& config.reranking_mode === RerankingModeEnum.RerankingModel
&& !config.reranking_model) {
config.reranking_model = {
reranking_provider_name: rerankDefaultModel?.provider?.provider,
reranking_model_name: rerankDefaultModel?.model,
} as any
}
setDatasetConfigs(config)
setDatasetConfigs(tempDataSetConfigs)
setRerankSettingModalOpen(false)
}
@@ -94,14 +95,14 @@ const ParamsConfig = ({
reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model,
model: rerankDefaultModel?.model,
})
setTempDataSetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
reranking_model_name: retrievalConfig.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,

View File

@@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider'
import Button from '@/app/components/base/button'
import Input from '@/app/components/base/input'
import Textarea from '@/app/components/base/textarea'
import { type DataSet, RerankingModeEnum } from '@/models/datasets'
import { type DataSet } from '@/models/datasets'
import { useToastContext } from '@/app/components/base/toast'
import { updateDatasetSetting } from '@/service/datasets'
import { useAppContext } from '@/context/app-context'
@@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app'
import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
@@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
}
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
rerankModelList,
retrievalConfig,
indexMethod,
@@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
const postRetrievalConfig = ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig: {
...retrievalConfig,
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
},
indexMethod,
})
try {
setLoading(true)
const { id, name, description, permission } = localeCurrentDataset
@@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
permission,
indexing_technique: indexMethod,
retrieval_model: {
...postRetrievalConfig,
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
...retrievalConfig,
score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
},
embedding_model: localeCurrentDataset.embedding_model,
embedding_model_provider: localeCurrentDataset.embedding_model_provider,
@@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
onSave({
...localeCurrentDataset,
indexing_technique: indexMethod,
retrieval_model_dict: postRetrievalConfig,
retrieval_model_dict: retrievalConfig,
})
}
catch (e) {

View File

@@ -287,9 +287,9 @@ const Configuration: FC = () => {
setDatasetConfigs({
...retrievalConfig,
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
reranking_model: {
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
},
retrieval_model,
score_threshold_enabled,