feat: support pinning, including, and excluding for model providers and tools (#7419)
Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
@@ -111,6 +111,12 @@ class AppService:
|
||||
'completion_params': {}
|
||||
}
|
||||
else:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config['model']['provider'] = provider
|
||||
default_model_config['model']['name'] = model
|
||||
default_model_dict = default_model_config['model']
|
||||
|
||||
default_model_config['model'] = json.dumps(default_model_dict)
|
||||
@@ -190,13 +196,14 @@ class AppService:
|
||||
"""
|
||||
Modified App class
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.__dict__.update(app.__dict__)
|
||||
|
||||
@property
|
||||
def app_model_config(self):
|
||||
return model_config
|
||||
|
||||
|
||||
app = ModifiedApp(app)
|
||||
|
||||
return app
|
||||
|
||||
@@ -30,6 +30,7 @@ class ModelProviderService:
|
||||
"""
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
@@ -387,18 +388,21 @@ class ModelProviderService:
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum
|
||||
)
|
||||
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
try:
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
)
|
||||
) if result else None
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
@@ -43,14 +45,14 @@ class BuiltinToolManageService:
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
@@ -78,7 +80,7 @@ class BuiltinToolManageService:
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
|
||||
try:
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
@@ -119,8 +121,8 @@ class BuiltinToolManageService:
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
@@ -135,7 +137,7 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider.provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
@@ -156,7 +158,7 @@ class BuiltinToolManageService:
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
@@ -165,8 +167,8 @@ class BuiltinToolManageService:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {'result': 'success'}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
@@ -179,7 +181,7 @@ class BuiltinToolManageService:
|
||||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
@@ -202,6 +204,15 @@ class BuiltinToolManageService:
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
@@ -226,4 +237,3 @@ class BuiltinToolManageService:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
Reference in New Issue
Block a user