Fix/retrieval setting weight default value (#9622)

This commit is contained in:
zxhlyh
2024-10-22 18:31:39 +08:00
committed by GitHub
parent 7d7e0f9800
commit ff956cb546
7 changed files with 86 additions and 75 deletions

View File

@@ -13,6 +13,11 @@ import ContextVar from './context-var'
import ConfigContext from '@/context/debug-configuration'
import { AppType } from '@/types/app'
import type { DataSet } from '@/models/datasets'
import {
getMultipleRetrievalConfig,
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const Icon = (
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -31,13 +36,25 @@ const DatasetConfig: FC = () => {
setModelConfig,
showSelectDataSet,
isAgent,
datasetConfigs,
setDatasetConfigs,
} = useContext(ConfigContext)
const formattingChangedDispatcher = useFormattingChangedDispatcher()
const hasData = dataSet.length > 0
const {
currentModel: currentRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const onRemove = (id: string) => {
setDataSet(dataSet.filter(item => item.id !== id))
const filteredDataSets = dataSet.filter(item => item.id !== id)
setDataSet(filteredDataSets)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
setDatasetConfigs({
...(datasetConfigs as any),
...retrievalConfig,
})
formattingChangedDispatcher()
}

View File

@@ -55,7 +55,7 @@ const ConfigContent: FC<Props> = ({
retrieval_model: RETRIEVE_TYPE.multiWay,
}, isInWorkflow)
}
}, [type])
}, [type, datasetConfigs, isInWorkflow, onChange])
const {
modelList: rerankModelList,

View File

@@ -16,7 +16,6 @@ import type { DataSet } from '@/models/datasets'
import type { DatasetConfigs } from '@/models/debug'
import {
getMultipleRetrievalConfig,
getSelectedDatasetsMode,
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
type ParamsConfigProps = {
@@ -37,57 +36,8 @@ const ParamsConfig = ({
const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs)
useEffect(() => {
const {
allEconomic,
allHighQuality,
allHighQualityFullTextSearch,
allHighQualityVectorSearch,
allExternal,
mixtureHighQualityAndEconomic,
inconsistentEmbeddingModel,
mixtureInternalAndExternal,
} = getSelectedDatasetsMode(selectedDatasets)
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
setRerankSettingModalOpen(false)
if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1))
setRerankSettingModalOpen(true)
}, [selectedDatasets])
useEffect(() => {
const {
allEconomic,
allInternal,
allExternal,
} = getSelectedDatasetsMode(selectedDatasets)
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
let rerankEnable = restConfigs.reranking_enable
if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
rerankEnable = false
setTempDataSetConfigs({
...getMultipleRetrievalConfig({
top_k: restConfigs.top_k,
score_threshold: restConfigs.score_threshold,
reranking_model: restConfigs.reranking_model && {
provider: restConfigs.reranking_model.reranking_provider_name,
model: restConfigs.reranking_model.reranking_model_name,
},
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: rerankEnable,
}, selectedDatasets),
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
},
retrieval_model,
score_threshold_enabled,
datasets,
})
}, [selectedDatasets, datasetConfigs])
setTempDataSetConfigs(datasetConfigs)
}, [datasetConfigs])
const {
defaultModel: rerankDefaultModel,
@@ -135,7 +85,7 @@ const ParamsConfig = ({
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets)
}, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)
setTempDataSetConfigs({
...retrievalConfig,
@@ -180,6 +130,7 @@ const ParamsConfig = ({
<div className='mt-6 flex justify-end'>
<Button className='mr-2 flex-shrink-0' onClick={() => {
setTempDataSetConfigs(datasetConfigs)
setRerankSettingModalOpen(false)
}}>{t('common.operation.cancel')}</Button>
<Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>

View File

@@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration'
import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
@@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import Drawer from '@/app/components/base/drawer'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import {
useModelListAndDefaultModelAndCurrentProviderAndModel,
useTextGenerationCurrentProviderAndModelAndModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import { fetchCollectionList } from '@/service/tools'
import { type Collection } from '@/app/components/tools/types'
import { useStore as useAppStore } from '@/app/components/app/store'
@@ -217,6 +220,9 @@ const Configuration: FC = () => {
const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false)
const selectedIds = dataSets.map(item => item.id)
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
const {
currentModel: currentRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet()
@@ -263,7 +269,7 @@ const Configuration: FC = () => {
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable,
}, newDatasets)
}, newDatasets, dataSets, !!currentRerankModel)
setDatasetConfigs({
...retrievalConfig,
@@ -603,9 +609,11 @@ const Configuration: FC = () => {
syncToPublishedConfig(config)
setPublishedConfig(config)
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
setDatasetConfigs({
retrieval_model: RETRIEVE_TYPE.multiWay,
...modelConfig.dataset_configs,
...retrievalConfig,
})
setHasFetchedDetail(true)
})