feat: add multi model credentials (#24451)

Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
非法操作
2025-08-25 16:12:29 +08:00
committed by GitHub
parent b08bfa203a
commit 6010d5f24c
65 changed files with 5202 additions and 1814 deletions

View File

@@ -12,6 +12,7 @@ from configs import dify_config
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
from core.entities.provider_entities import (
CredentialConfiguration,
CustomConfiguration,
CustomModelConfiguration,
CustomProviderConfiguration,
@@ -40,7 +41,9 @@ from extensions.ext_redis import redis_client
from models.provider import (
LoadBalancingModelConfig,
Provider,
ProviderCredential,
ProviderModel,
ProviderModelCredential,
ProviderModelSetting,
ProviderType,
TenantDefaultModel,
@@ -488,6 +491,61 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
"""
Get provider all credentials.
:param tenant_id: workspace id
:param provider_name: provider name
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
.order_by(ProviderCredential.created_at.desc())
)
available_credentials = session.scalars(stmt).all()
return [
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
for credential in available_credentials
]
@staticmethod
def get_provider_model_available_credentials(
tenant_id: str, provider_name: str, model_name: str, model_type: str
) -> list[CredentialConfiguration]:
"""
Get provider custom model all credentials.
:param tenant_id: workspace id
:param provider_name: provider name
:param model_name: model name
:param model_type: model type
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(ProviderModelCredential)
.where(
ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name,
ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type,
)
.order_by(ProviderModelCredential.created_at.desc())
)
available_credentials = session.scalars(stmt).all()
return [
CredentialConfiguration(credential_id=credential.id, credential_name=credential.credential_name)
for credential in available_credentials
]
@staticmethod
def _init_trial_provider_records(
tenant_id: str, provider_name_to_provider_records_dict: dict[str, list[Provider]]
@@ -590,9 +648,6 @@ class ProviderManager:
if provider_record.provider_type == ProviderType.SYSTEM.value:
continue
if not provider_record.encrypted_config:
continue
custom_provider_record = provider_record
# Get custom provider credentials
@@ -611,8 +666,8 @@ class ProviderManager:
try:
# fix origin data
if custom_provider_record.encrypted_config is None:
raise ValueError("No credentials found")
if not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {}
elif not custom_provider_record.encrypted_config.startswith("{"):
provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
@@ -637,7 +692,14 @@ class ProviderManager:
else:
provider_credentials = cached_provider_credentials
custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials)
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials,
current_credential_name=custom_provider_record.credential_name,
current_credential_id=custom_provider_record.credential_id,
available_credentials=self.get_provider_available_credentials(
tenant_id, custom_provider_record.provider_name
),
)
# Get provider model credential secret variables
model_credential_secret_variables = self._extract_secret_variables(
@@ -649,8 +711,12 @@ class ProviderManager:
# Get custom provider model credentials
custom_model_configurations = []
for provider_model_record in provider_model_records:
if not provider_model_record.encrypted_config:
continue
available_model_credentials = self.get_provider_model_available_credentials(
tenant_id,
provider_model_record.provider_name,
provider_model_record.model_name,
provider_model_record.model_type,
)
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL
@@ -659,7 +725,7 @@ class ProviderManager:
# Get cached provider model credentials
cached_provider_model_credentials = provider_model_credentials_cache.get()
if not cached_provider_model_credentials:
if not cached_provider_model_credentials and provider_model_record.encrypted_config:
try:
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
except JSONDecodeError:
@@ -688,6 +754,9 @@ class ProviderManager:
model=provider_model_record.model_name,
model_type=ModelType.value_of(provider_model_record.model_type),
credentials=provider_model_credentials,
current_credential_id=provider_model_record.credential_id,
current_credential_name=provider_model_record.credential_name,
available_model_credentials=available_model_credentials,
)
)
@@ -899,6 +968,18 @@ class ProviderManager:
load_balancing_model_config.model_name == provider_model_setting.model_name
and load_balancing_model_config.model_type == provider_model_setting.model_type
):
if load_balancing_model_config.name == "__delete__":
# to calculate current model whether has invalidate lb configs
load_balancing_configs.append(
ModelLoadBalancingConfiguration(
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials={},
credential_source_type=load_balancing_model_config.credential_source_type,
)
)
continue
if not load_balancing_model_config.enabled:
continue
@@ -955,6 +1036,7 @@ class ProviderManager:
id=load_balancing_model_config.id,
name=load_balancing_model_config.name,
credentials=provider_model_credentials,
credential_source_type=load_balancing_model_config.credential_source_type,
)
)