feat: support pinning, including, and excluding for model providers and tools (#7419)
Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
@@ -3,6 +3,7 @@ from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
@@ -19,6 +20,87 @@ def get_position_map(folder_path: str, *, file_name: str = "_position.yaml") ->
|
||||
return {name: index for index, name in enumerate(positions)}
|
||||
|
||||
|
||||
def get_tool_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping for tools from name to index from a YAML file.
|
||||
:param folder_path:
|
||||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_TOOL_PINS_LIST,
|
||||
)
|
||||
|
||||
|
||||
def get_provider_position_map(folder_path: str, file_name: str = "_position.yaml") -> dict[str, int]:
|
||||
"""
|
||||
Get the mapping for providers from name to index from a YAML file.
|
||||
:param folder_path:
|
||||
:param file_name: the YAML file name, default to '_position.yaml'
|
||||
:return: a dict with name as key and index as value
|
||||
"""
|
||||
position_map = get_position_map(folder_path, file_name=file_name)
|
||||
return pin_position_map(
|
||||
position_map,
|
||||
pin_list=dify_config.POSITION_PROVIDER_PINS_LIST,
|
||||
)
|
||||
|
||||
|
||||
def pin_position_map(original_position_map: dict[str, int], pin_list: list[str]) -> dict[str, int]:
|
||||
"""
|
||||
Pin the items in the pin list to the beginning of the position map.
|
||||
Overall logic: exclude > include > pin
|
||||
:param position_map: the position map to be sorted and filtered
|
||||
:param pin_list: the list of pins to be put at the beginning
|
||||
:return: the sorted position map
|
||||
"""
|
||||
positions = sorted(original_position_map.keys(), key=lambda x: original_position_map[x])
|
||||
|
||||
# Add pins to position map
|
||||
position_map = {name: idx for idx, name in enumerate(pin_list)}
|
||||
|
||||
# Add remaining positions to position map
|
||||
start_idx = len(position_map)
|
||||
for name in positions:
|
||||
if name not in position_map:
|
||||
position_map[name] = start_idx
|
||||
start_idx += 1
|
||||
|
||||
return position_map
|
||||
|
||||
|
||||
def is_filtered(
|
||||
include_set: set[str],
|
||||
exclude_set: set[str],
|
||||
data: Any,
|
||||
name_func: Callable[[Any], str],
|
||||
) -> bool:
|
||||
"""
|
||||
Chcek if the object should be filtered out.
|
||||
Overall logic: exclude > include > pin
|
||||
:param include_set: the set of names to be included
|
||||
:param exclude_set: the set of names to be excluded
|
||||
:param name_func: the function to get the name of the object
|
||||
:param data: the data to be filtered
|
||||
:return: True if the object should be filtered out, False otherwise
|
||||
"""
|
||||
if not data:
|
||||
return False
|
||||
if not include_set and not exclude_set:
|
||||
return False
|
||||
|
||||
name = name_func(data)
|
||||
|
||||
if name in exclude_set: # exclude_set is prioritized
|
||||
return True
|
||||
if include_set and name not in include_set: # filter out only if include_set is not empty
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def sort_by_position_map(
|
||||
position_map: dict[str, int],
|
||||
data: list[Any],
|
||||
|
||||
@@ -368,6 +368,15 @@ class ModelManager:
|
||||
|
||||
return ModelInstance(provider_model_bundle, model)
|
||||
|
||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||
"""
|
||||
Return first provider and the first model in the provider
|
||||
:param tenant_id: tenant id
|
||||
:param model_type: model type
|
||||
:return: provider name, model name
|
||||
"""
|
||||
return self._provider_manager.get_first_provider_first_model(tenant_id, model_type)
|
||||
|
||||
def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
|
||||
"""
|
||||
Get default model instance
|
||||
@@ -502,7 +511,6 @@ class LBModelManager:
|
||||
config.id
|
||||
)
|
||||
|
||||
|
||||
res = redis_client.exists(cooldown_cache_key)
|
||||
res = cast(bool, res)
|
||||
return res
|
||||
|
||||
@@ -151,9 +151,9 @@ class AIModel(ABC):
|
||||
os.path.join(provider_model_type_path, model_schema_yaml)
|
||||
for model_schema_yaml in os.listdir(provider_model_type_path)
|
||||
if not model_schema_yaml.startswith('__')
|
||||
and not model_schema_yaml.startswith('_')
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith('.yaml')
|
||||
and not model_schema_yaml.startswith('_')
|
||||
and os.path.isfile(os.path.join(provider_model_type_path, model_schema_yaml))
|
||||
and model_schema_yaml.endswith('.yaml')
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import get_position_map, sort_to_dict_by_position_map
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
@@ -234,7 +234,7 @@ class ModelProviderFactory:
|
||||
]
|
||||
|
||||
# get _position.yaml file path
|
||||
position_map = get_position_map(model_providers_path)
|
||||
position_map = get_provider_position_map(model_providers_path)
|
||||
|
||||
# traverse all model_provider_dir_paths
|
||||
model_providers: list[ModelProviderExtension] = []
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle
|
||||
from core.entities.provider_entities import (
|
||||
@@ -18,12 +19,9 @@ from core.entities.provider_entities import (
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import (
|
||||
CredentialFormSchema,
|
||||
FormType,
|
||||
ProviderEntity,
|
||||
)
|
||||
from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from extensions import ext_hosting_provider
|
||||
from extensions.ext_database import db
|
||||
@@ -45,6 +43,7 @@ class ProviderManager:
|
||||
"""
|
||||
ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.decoding_rsa_key = None
|
||||
self.decoding_cipher_rsa = None
|
||||
@@ -117,6 +116,16 @@ class ProviderManager:
|
||||
|
||||
# Construct ProviderConfiguration objects for each provider
|
||||
for provider_entity in provider_entities:
|
||||
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
|
||||
data=provider_entity,
|
||||
name_func=lambda x: x.provider,
|
||||
):
|
||||
continue
|
||||
|
||||
provider_name = provider_entity.provider
|
||||
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
|
||||
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
|
||||
@@ -271,6 +280,24 @@ class ProviderManager:
|
||||
)
|
||||
)
|
||||
|
||||
def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]:
|
||||
"""
|
||||
Get names of first model and its provider
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param model_type: model type
|
||||
:return: provider name, model name
|
||||
"""
|
||||
provider_configurations = self.get_configurations(tenant_id)
|
||||
|
||||
# get available models from provider_configurations
|
||||
all_models = provider_configurations.get_models(
|
||||
model_type=model_type,
|
||||
only_active=False
|
||||
)
|
||||
|
||||
return all_models[0].provider.provider, all_models[0].model
|
||||
|
||||
def update_default_model_record(self, tenant_id: str, model_type: ModelType, provider: str, model: str) \
|
||||
-> TenantDefaultModel:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os.path
|
||||
|
||||
from core.helper.position_helper import get_position_map, sort_by_position_map
|
||||
from core.helper.position_helper import get_tool_position_map, sort_by_position_map
|
||||
from core.tools.entities.api_entities import UserToolProvider
|
||||
|
||||
|
||||
@@ -10,11 +10,11 @@ class BuiltinToolProviderSort:
|
||||
@classmethod
|
||||
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
|
||||
if not cls._position:
|
||||
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||
cls._position = get_tool_position_map(os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
def name_func(provider: UserToolProvider) -> str:
|
||||
return provider.name
|
||||
|
||||
sorted_providers = sort_by_position_map(cls._position, providers, name_func)
|
||||
|
||||
return sorted_providers
|
||||
return sorted_providers
|
||||
|
||||
@@ -10,14 +10,11 @@ from configs import dify_config
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiToolProviderController
|
||||
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
|
||||
@@ -26,10 +23,7 @@ from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import (
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
@@ -38,6 +32,7 @@ from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
_builtin_provider_lock = Lock()
|
||||
_builtin_providers = {}
|
||||
@@ -107,7 +102,7 @@ class ToolManager:
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
|
||||
@@ -346,7 +341,7 @@ class ToolManager:
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
provider: BuiltinToolProviderController = provider_class()
|
||||
cls._builtin_providers[provider.identity.name] = provider
|
||||
@@ -414,6 +409,15 @@ class ToolManager:
|
||||
|
||||
# append builtin providers
|
||||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
user_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider,
|
||||
db_provider=find_db_builtin_provider(provider.identity.name),
|
||||
@@ -473,7 +477,7 @@ class ToolManager:
|
||||
|
||||
@classmethod
|
||||
def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
ApiToolProviderController, dict[str, Any]]:
|
||||
"""
|
||||
get the api provider
|
||||
|
||||
@@ -593,4 +597,5 @@ class ToolManager:
|
||||
else:
|
||||
raise ValueError(f"provider type {provider_type} not found")
|
||||
|
||||
|
||||
ToolManager.load_builtin_providers_cache()
|
||||
|
||||
Reference in New Issue
Block a user