feat: support pinning, including, and excluding for model providers and tools (#7419)

Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
Xiyuan Chen
2024-08-20 23:16:43 -04:00
committed by GitHub
parent 6c25d7bed3
commit 4e7b6aec3a
14 changed files with 363 additions and 57 deletions

View File

@@ -111,6 +111,12 @@ class AppService:
'completion_params': {}
}
else:
provider, model = model_manager.get_default_provider_model_name(
tenant_id=account.current_tenant_id,
model_type=ModelType.LLM
)
default_model_config['model']['provider'] = provider
default_model_config['model']['name'] = model
default_model_dict = default_model_config['model']
default_model_config['model'] = json.dumps(default_model_dict)
@@ -190,13 +196,14 @@ class AppService:
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)
@property
def app_model_config(self):
return model_config
app = ModifiedApp(app)
return app

View File

@@ -30,6 +30,7 @@ class ModelProviderService:
"""
Model Provider Service
"""
def __init__(self) -> None:
self.provider_manager = ProviderManager()
@@ -387,18 +388,21 @@ class ModelProviderService:
tenant_id=tenant_id,
model_type=model_type_enum
)
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
try:
return DefaultModelResponse(
model=result.model,
model_type=result.model_type,
provider=SimpleProviderEntityResponse(
provider=result.provider.provider,
label=result.provider.label,
icon_small=result.provider.icon_small,
icon_large=result.provider.icon_large,
supported_model_types=result.provider.supported_model_types
)
) if result else None
except Exception as e:
logger.info(f"get_default_model_of_model_type error: {e}")
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
"""

View File

@@ -1,6 +1,8 @@
import json
import logging
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
@@ -43,14 +45,14 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
@@ -78,7 +80,7 @@ class BuiltinToolManageService:
BuiltinToolProvider.provider == provider_name,
).first()
try:
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
@@ -119,8 +121,8 @@ class BuiltinToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
@@ -135,7 +137,7 @@ class BuiltinToolManageService:
if provider is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@@ -156,7 +158,7 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider)
db.session.commit()
@@ -165,8 +167,8 @@ class BuiltinToolManageService:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
@@ -179,7 +181,7 @@ class BuiltinToolManageService:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
@@ -202,6 +204,15 @@ class BuiltinToolManageService:
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.identity.name
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -226,4 +237,3 @@ class BuiltinToolManageService:
raise e
return BuiltinToolProviderSort.sort(result)