fix: Fix parent child retrieval issues (#12206)
Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: nite-knite <nkCoding@gmail.com>
This commit is contained in:
@@ -59,36 +59,24 @@ const ConfigContent: FC<Props> = ({
|
||||
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelValid,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
{
|
||||
provider: datasetConfigs.reranking_model?.reranking_provider_name,
|
||||
model: datasetConfigs.reranking_model?.reranking_model_name,
|
||||
},
|
||||
)
|
||||
|
||||
const rerankModel = (() => {
|
||||
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
||||
return {
|
||||
provider_name: datasetConfigs.reranking_model.reranking_provider_name,
|
||||
model_name: datasetConfigs.reranking_model.reranking_model_name,
|
||||
}
|
||||
const rerankModel = useMemo(() => {
|
||||
return {
|
||||
provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '',
|
||||
model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '',
|
||||
}
|
||||
else if (rerankDefaultModel) {
|
||||
return {
|
||||
provider_name: rerankDefaultModel.provider.provider,
|
||||
model_name: rerankDefaultModel.model,
|
||||
}
|
||||
}
|
||||
})()
|
||||
}, [datasetConfigs.reranking_model])
|
||||
|
||||
const handleParamChange = (key: string, value: number) => {
|
||||
if (key === 'top_k') {
|
||||
@@ -133,6 +121,12 @@ const ConfigContent: FC<Props> = ({
|
||||
}
|
||||
|
||||
const handleRerankModeChange = (mode: RerankingModeEnum) => {
|
||||
if (mode === datasetConfigs.reranking_mode)
|
||||
return
|
||||
|
||||
if (mode === RerankingModeEnum.RerankingModel && !currentRerankModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_mode: mode,
|
||||
@@ -162,31 +156,25 @@ const ConfigContent: FC<Props> = ({
|
||||
|
||||
const canManuallyToggleRerank = useMemo(() => {
|
||||
return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|
||||
|| selectedDatasetsMode.allExternal
|
||||
|| selectedDatasetsMode.allExternal
|
||||
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
|
||||
|
||||
const showRerankModel = useMemo(() => {
|
||||
if (!canManuallyToggleRerank)
|
||||
return true
|
||||
else if (canManuallyToggleRerank && !isRerankDefaultModelValid)
|
||||
return false
|
||||
|
||||
return datasetConfigs.reranking_enable
|
||||
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
|
||||
}, [datasetConfigs.reranking_enable, canManuallyToggleRerank])
|
||||
|
||||
const handleDisabledSwitchClick = useCallback(() => {
|
||||
if (!currentRerankModel && !showRerankModel)
|
||||
const handleDisabledSwitchClick = useCallback((enable: boolean) => {
|
||||
if (!currentRerankModel && enable)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
}, [currentRerankModel, showRerankModel, t])
|
||||
|
||||
useEffect(() => {
|
||||
if (canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: showRerankModel,
|
||||
})
|
||||
}
|
||||
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: enable,
|
||||
})
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [currentRerankModel, datasetConfigs, onChange])
|
||||
|
||||
return (
|
||||
<div>
|
||||
@@ -267,24 +255,12 @@ const ConfigContent: FC<Props> = ({
|
||||
<div className='flex items-center'>
|
||||
{
|
||||
selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && (
|
||||
<div
|
||||
className='flex items-center'
|
||||
onClick={handleDisabledSwitchClick}
|
||||
>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={showRerankModel}
|
||||
disabled={!currentRerankModel || !canManuallyToggleRerank}
|
||||
onChange={(v) => {
|
||||
if (canManuallyToggleRerank) {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={showRerankModel}
|
||||
disabled={!canManuallyToggleRerank}
|
||||
onChange={handleDisabledSwitchClick}
|
||||
/>
|
||||
)
|
||||
}
|
||||
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
|
||||
@@ -298,21 +274,24 @@ const ConfigContent: FC<Props> = ({
|
||||
triggerClassName='ml-1 w-4 h-4'
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
|
||||
onSelect={(v) => {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_model: {
|
||||
reranking_provider_name: v.provider,
|
||||
reranking_model_name: v.model,
|
||||
},
|
||||
})
|
||||
}}
|
||||
modelList={rerankModelList}
|
||||
/>
|
||||
</div>
|
||||
{
|
||||
showRerankModel && (
|
||||
<div>
|
||||
<ModelSelector
|
||||
defaultModel={rerankModel && { provider: rerankModel?.provider_name, model: rerankModel?.model_name }}
|
||||
onSelect={(v) => {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_model: {
|
||||
reranking_provider_name: v.provider,
|
||||
reranking_model_name: v.model,
|
||||
},
|
||||
})
|
||||
}}
|
||||
modelList={rerankModelList}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import Modal from '@/app/components/base/modal'
|
||||
import Button from '@/app/components/base/button'
|
||||
import { RETRIEVE_TYPE } from '@/types/app'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { RerankingModeEnum } from '@/models/datasets'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
@@ -41,17 +41,27 @@ const ParamsConfig = ({
|
||||
}, [datasetConfigs])
|
||||
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
currentModel: isRerankDefaultModelValid,
|
||||
modelList: rerankModelList,
|
||||
currentModel: rerankDefaultModel,
|
||||
currentProvider: rerankDefaultProvider,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel: isCurrentRerankModelValid,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
{
|
||||
provider: tempDataSetConfigs.reranking_model?.reranking_provider_name ?? '',
|
||||
model: tempDataSetConfigs.reranking_model?.reranking_model_name ?? '',
|
||||
},
|
||||
)
|
||||
|
||||
const isValid = () => {
|
||||
let errMsg = ''
|
||||
if (tempDataSetConfigs.retrieval_model === RETRIEVE_TYPE.multiWay) {
|
||||
if (tempDataSetConfigs.reranking_enable
|
||||
&& tempDataSetConfigs.reranking_mode === RerankingModeEnum.RerankingModel
|
||||
&& !isRerankDefaultModelValid
|
||||
&& !isCurrentRerankModelValid
|
||||
)
|
||||
errMsg = t('appDebug.datasetConfig.rerankModelRequired')
|
||||
}
|
||||
@@ -66,16 +76,7 @@ const ParamsConfig = ({
|
||||
const handleSave = () => {
|
||||
if (!isValid())
|
||||
return
|
||||
const config = { ...tempDataSetConfigs }
|
||||
if (config.retrieval_model === RETRIEVE_TYPE.multiWay
|
||||
&& config.reranking_mode === RerankingModeEnum.RerankingModel
|
||||
&& !config.reranking_model) {
|
||||
config.reranking_model = {
|
||||
reranking_provider_name: rerankDefaultModel?.provider?.provider,
|
||||
reranking_model_name: rerankDefaultModel?.model,
|
||||
} as any
|
||||
}
|
||||
setDatasetConfigs(config)
|
||||
setDatasetConfigs(tempDataSetConfigs)
|
||||
setRerankSettingModalOpen(false)
|
||||
}
|
||||
|
||||
@@ -94,14 +95,14 @@ const ParamsConfig = ({
|
||||
reranking_enable: restConfigs.reranking_enable,
|
||||
}, selectedDatasets, selectedDatasets, {
|
||||
provider: rerankDefaultProvider?.provider,
|
||||
model: isRerankDefaultModelValid?.model,
|
||||
model: rerankDefaultModel?.model,
|
||||
})
|
||||
|
||||
setTempDataSetConfigs({
|
||||
...retrievalConfig,
|
||||
reranking_model: restConfigs.reranking_model && {
|
||||
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
|
||||
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
|
||||
reranking_model: {
|
||||
reranking_provider_name: retrievalConfig.reranking_model?.provider || '',
|
||||
reranking_model_name: retrievalConfig.reranking_model?.model || '',
|
||||
},
|
||||
retrieval_model,
|
||||
score_threshold_enabled,
|
||||
|
||||
@@ -12,7 +12,7 @@ import Divider from '@/app/components/base/divider'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Textarea from '@/app/components/base/textarea'
|
||||
import { type DataSet, RerankingModeEnum } from '@/models/datasets'
|
||||
import { type DataSet } from '@/models/datasets'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
import { updateDatasetSetting } from '@/service/datasets'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
@@ -21,7 +21,7 @@ import type { RetrievalConfig } from '@/types/app'
|
||||
import RetrievalSettings from '@/app/components/datasets/external-knowledge-base/create/RetrievalSettings'
|
||||
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
|
||||
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
|
||||
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'
|
||||
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
|
||||
import PermissionSelector from '@/app/components/datasets/settings/permission-selector'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
@@ -99,8 +99,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
}
|
||||
if (
|
||||
!isReRankModelSelected({
|
||||
rerankDefaultModel,
|
||||
isRerankDefaultModelValid: !!isRerankDefaultModelValid,
|
||||
rerankModelList,
|
||||
retrievalConfig,
|
||||
indexMethod,
|
||||
@@ -109,14 +107,6 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
|
||||
return
|
||||
}
|
||||
const postRetrievalConfig = ensureRerankModelSelected({
|
||||
rerankDefaultModel: rerankDefaultModel!,
|
||||
retrievalConfig: {
|
||||
...retrievalConfig,
|
||||
reranking_enable: retrievalConfig.reranking_mode === RerankingModeEnum.RerankingModel,
|
||||
},
|
||||
indexMethod,
|
||||
})
|
||||
try {
|
||||
setLoading(true)
|
||||
const { id, name, description, permission } = localeCurrentDataset
|
||||
@@ -128,8 +118,8 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
permission,
|
||||
indexing_technique: indexMethod,
|
||||
retrieval_model: {
|
||||
...postRetrievalConfig,
|
||||
score_threshold: postRetrievalConfig.score_threshold_enabled ? postRetrievalConfig.score_threshold : 0,
|
||||
...retrievalConfig,
|
||||
score_threshold: retrievalConfig.score_threshold_enabled ? retrievalConfig.score_threshold : 0,
|
||||
},
|
||||
embedding_model: localeCurrentDataset.embedding_model,
|
||||
embedding_model_provider: localeCurrentDataset.embedding_model_provider,
|
||||
@@ -157,7 +147,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
|
||||
onSave({
|
||||
...localeCurrentDataset,
|
||||
indexing_technique: indexMethod,
|
||||
retrieval_model_dict: postRetrievalConfig,
|
||||
retrieval_model_dict: retrievalConfig,
|
||||
})
|
||||
}
|
||||
catch (e) {
|
||||
|
||||
@@ -287,9 +287,9 @@ const Configuration: FC = () => {
|
||||
|
||||
setDatasetConfigs({
|
||||
...retrievalConfig,
|
||||
reranking_model: restConfigs.reranking_model && {
|
||||
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
|
||||
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
|
||||
reranking_model: {
|
||||
reranking_provider_name: retrievalConfig?.reranking_model?.provider || '',
|
||||
reranking_model_name: retrievalConfig?.reranking_model?.model || '',
|
||||
},
|
||||
retrieval_model,
|
||||
score_threshold_enabled,
|
||||
|
||||
Reference in New Issue
Block a user