remove bare list, dict, Sequence, None, Any (#25058)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -105,14 +105,14 @@ class AccountService:
|
||||
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
|
||||
|
||||
@staticmethod
|
||||
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||
def _store_refresh_token(refresh_token: str, account_id: str):
|
||||
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
|
||||
redis_client.setex(
|
||||
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
|
||||
def _delete_refresh_token(refresh_token: str, account_id: str):
|
||||
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
|
||||
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
|
||||
|
||||
@@ -312,12 +312,12 @@ class AccountService:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_account(account: Account) -> None:
|
||||
def delete_account(account: Account):
|
||||
"""Delete account. This method only adds a task to the queue for deletion."""
|
||||
delete_account_task.delay(account.id)
|
||||
|
||||
@staticmethod
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
|
||||
def link_account_integrate(provider: str, open_id: str, account: Account):
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
@@ -344,7 +344,7 @@ class AccountService:
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account) -> None:
|
||||
def close_account(account: Account):
|
||||
"""Close account"""
|
||||
account.status = AccountStatus.CLOSED.value
|
||||
db.session.commit()
|
||||
@@ -374,7 +374,7 @@ class AccountService:
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_login_info(account: Account, *, ip_address: str) -> None:
|
||||
def update_login_info(account: Account, *, ip_address: str):
|
||||
"""Update last login time and ip"""
|
||||
account.last_login_at = naive_utc_now()
|
||||
account.last_login_ip = ip_address
|
||||
@@ -398,7 +398,7 @@ class AccountService:
|
||||
return TokenPair(access_token=access_token, refresh_token=refresh_token)
|
||||
|
||||
@staticmethod
|
||||
def logout(*, account: Account) -> None:
|
||||
def logout(*, account: Account):
|
||||
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
|
||||
if refresh_token:
|
||||
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
|
||||
@@ -705,7 +705,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_login_error_rate_limit(email: str) -> None:
|
||||
def add_login_error_rate_limit(email: str):
|
||||
key = f"login_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@@ -734,7 +734,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_forgot_password_error_rate_limit(email: str) -> None:
|
||||
def add_forgot_password_error_rate_limit(email: str):
|
||||
key = f"forgot_password_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@@ -763,7 +763,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_change_email_error_rate_limit(email: str) -> None:
|
||||
def add_change_email_error_rate_limit(email: str):
|
||||
key = f"change_email_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@@ -791,7 +791,7 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
@redis_fallback(default_return=None)
|
||||
def add_owner_transfer_error_rate_limit(email: str) -> None:
|
||||
def add_owner_transfer_error_rate_limit(email: str):
|
||||
key = f"owner_transfer_error_rate_limit:{email}"
|
||||
count = redis_client.get(key)
|
||||
if count is None:
|
||||
@@ -970,7 +970,7 @@ class TenantService:
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
|
||||
def switch_tenant(account: Account, tenant_id: Optional[str] = None):
|
||||
"""Switch the current workspace for the account"""
|
||||
|
||||
# Ensure tenant_id is provided
|
||||
@@ -1067,7 +1067,7 @@ class TenantService:
|
||||
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
|
||||
|
||||
@staticmethod
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
|
||||
"""Check member permission"""
|
||||
perms = {
|
||||
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
@@ -1087,7 +1087,7 @@ class TenantService:
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
|
||||
"""Remove member from tenant"""
|
||||
if operator.id == account.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
@@ -1102,7 +1102,7 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
@@ -1129,7 +1129,7 @@ class TenantService:
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_custom_config(tenant_id: str) -> dict:
|
||||
def get_custom_config(tenant_id: str):
|
||||
tenant = db.get_or_404(Tenant, tenant_id)
|
||||
|
||||
return tenant.custom_config_dict
|
||||
@@ -1150,7 +1150,7 @@ class RegisterService:
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str):
|
||||
"""
|
||||
Setup dify
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from models.model import AppMode
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
@classmethod
|
||||
def get_prompt(cls, args: dict) -> dict:
|
||||
def get_prompt(cls, args: dict):
|
||||
app_mode = args["app_mode"]
|
||||
model_mode = args["model_mode"]
|
||||
model_name = args["model_name"]
|
||||
@@ -29,7 +29,7 @@ class AdvancedPromptTemplateService:
|
||||
return cls.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
@classmethod
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
@@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
|
||||
if has_context == "true":
|
||||
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
|
||||
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
@@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
|
||||
if has_context == "true":
|
||||
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
|
||||
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
@@ -70,7 +70,7 @@ class AdvancedPromptTemplateService:
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
|
||||
@@ -16,7 +16,7 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
|
||||
@@ -73,7 +73,7 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
def enable_app_annotation(cls, args: dict, app_id: str):
|
||||
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@@ -96,7 +96,7 @@ class AppAnnotationService:
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str) -> dict:
|
||||
def disable_app_annotation(cls, app_id: str):
|
||||
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
@@ -315,7 +315,7 @@ class AppAnnotationService:
|
||||
return {"deleted_count": deleted_count}
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage):
|
||||
# get app info
|
||||
app = (
|
||||
db.session.query(App)
|
||||
@@ -490,7 +490,7 @@ class AppAnnotationService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def clear_all_annotations(cls, app_id: str) -> dict:
|
||||
def clear_all_annotations(cls, app_id: str):
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
|
||||
@@ -30,7 +30,7 @@ class APIBasedExtensionService:
|
||||
return extension_data
|
||||
|
||||
@staticmethod
|
||||
def delete(extension_data: APIBasedExtension) -> None:
|
||||
def delete(extension_data: APIBasedExtension):
|
||||
db.session.delete(extension_data)
|
||||
db.session.commit()
|
||||
|
||||
@@ -51,7 +51,7 @@ class APIBasedExtensionService:
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
def _validation(cls, extension_data: APIBasedExtension) -> None:
|
||||
def _validation(cls, extension_data: APIBasedExtension):
|
||||
# name
|
||||
if not extension_data.name:
|
||||
raise ValueError("name must not be empty")
|
||||
@@ -95,7 +95,7 @@ class APIBasedExtensionService:
|
||||
cls._ping_connection(extension_data)
|
||||
|
||||
@staticmethod
|
||||
def _ping_connection(extension_data: APIBasedExtension) -> None:
|
||||
def _ping_connection(extension_data: APIBasedExtension):
|
||||
try:
|
||||
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
|
||||
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
|
||||
|
||||
@@ -566,7 +566,7 @@ class AppDslService:
|
||||
@classmethod
|
||||
def _append_workflow_export_data(
|
||||
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Append workflow export data
|
||||
:param export_data: export data
|
||||
@@ -608,7 +608,7 @@ class AppDslService:
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
|
||||
"""
|
||||
Append model config export data
|
||||
:param export_data: export data
|
||||
|
||||
@@ -6,7 +6,7 @@ from models.model import AppMode
|
||||
|
||||
class AppModelConfigService:
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
|
||||
if app_mode == AppMode.CHAT:
|
||||
return ChatAppConfigManager.config_validate(tenant_id, config)
|
||||
elif app_mode == AppMode.AGENT_CHAT:
|
||||
|
||||
@@ -316,7 +316,7 @@ class AppService:
|
||||
|
||||
return app
|
||||
|
||||
def delete_app(self, app: App) -> None:
|
||||
def delete_app(self, app: App):
|
||||
"""
|
||||
Delete app
|
||||
:param app: App instance
|
||||
@@ -331,7 +331,7 @@ class AppService:
|
||||
# Trigger asynchronous deletion of app and related data
|
||||
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
|
||||
|
||||
def get_app_meta(self, app_model: App) -> dict:
|
||||
def get_app_meta(self, app_model: App):
|
||||
"""
|
||||
Get app meta info
|
||||
:param app_model: app model
|
||||
|
||||
@@ -8,7 +8,7 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
|
||||
|
||||
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ClearFreePlanTenantExpiredLogs:
|
||||
@classmethod
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
|
||||
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
|
||||
"""
|
||||
Clean up message-related tables to avoid data redundancy.
|
||||
This method cleans up tables that have foreign key relationships with Message.
|
||||
@@ -353,7 +353,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
try:
|
||||
if (
|
||||
not dify_config.BILLING_ENABLED
|
||||
|
||||
@@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
class CodeBasedExtensionService:
|
||||
@staticmethod
|
||||
def get_code_based_extension(module: str) -> list[dict]:
|
||||
def get_code_based_extension(module: str):
|
||||
module_extensions = code_based_extension.module_extensions(module)
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -250,7 +250,7 @@ class ConversationService:
|
||||
variable_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
new_value: Any,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a conversation variable's value.
|
||||
|
||||
|
||||
@@ -719,7 +719,7 @@ class DatasetService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
|
||||
def get_dataset_auto_disable_logs(dataset_id: str):
|
||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
|
||||
return {
|
||||
|
||||
@@ -83,7 +83,7 @@ class ProviderResponse(BaseModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@@ -113,7 +113,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
status: CustomConfigurationStatus
|
||||
models: list[ProviderModelWithStatusEntity]
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@@ -137,7 +137,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
|
||||
tenant_id: str
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (
|
||||
@@ -174,7 +174,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
|
||||
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
|
||||
dump_model = model.model_dump()
|
||||
dump_model["provider"]["tenant_id"] = tenant_id
|
||||
super().__init__(**dump_model)
|
||||
|
||||
@@ -6,7 +6,7 @@ class InvokeError(Exception):
|
||||
|
||||
description: Optional[str] = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None) -> None:
|
||||
def __init__(self, description: Optional[str] = None):
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -277,7 +277,7 @@ class ExternalDatasetService:
|
||||
query: str,
|
||||
external_retrieval_parameters: dict,
|
||||
metadata_condition: Optional[MetadataCondition] = None,
|
||||
) -> list:
|
||||
):
|
||||
external_knowledge_binding = (
|
||||
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ class HitTestingService:
|
||||
retrieval_model: Any, # FIXME drop this any
|
||||
external_retrieval_model: dict,
|
||||
limit: int = 10,
|
||||
) -> dict:
|
||||
):
|
||||
start = time.perf_counter()
|
||||
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
@@ -98,7 +98,7 @@ class HitTestingService:
|
||||
account: Account,
|
||||
external_retrieval_model: dict,
|
||||
metadata_filtering_conditions: dict,
|
||||
) -> dict:
|
||||
):
|
||||
if dataset.provider != "external":
|
||||
return {
|
||||
"query": {"content": query},
|
||||
|
||||
@@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelLoadBalancingService:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
enable model load balancing.
|
||||
|
||||
@@ -49,7 +49,7 @@ class ModelLoadBalancingService:
|
||||
# Enable model load balancing
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
disable model load balancing.
|
||||
|
||||
@@ -295,7 +295,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def update_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Update load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@@ -478,7 +478,7 @@ class ModelLoadBalancingService:
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None,
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
Validate load balancing credentials.
|
||||
:param tenant_id: workspace id
|
||||
@@ -537,7 +537,7 @@ class ModelLoadBalancingService:
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
validate: bool = True,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param tenant_id: workspace id
|
||||
@@ -605,7 +605,7 @@ class ModelLoadBalancingService:
|
||||
else:
|
||||
raise ValueError("No credential schema found")
|
||||
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
|
||||
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
|
||||
"""
|
||||
Clear credentials cache.
|
||||
:param tenant_id: workspace id
|
||||
|
||||
@@ -26,7 +26,7 @@ class ModelProviderService:
|
||||
Model Provider Service
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self):
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
def _get_provider_configuration(self, tenant_id: str, provider: str):
|
||||
@@ -142,7 +142,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
|
||||
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
||||
"""
|
||||
validate provider credentials before saving.
|
||||
|
||||
@@ -193,7 +193,7 @@ class ModelProviderService:
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
|
||||
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
remove a saved provider credential (by credential_id).
|
||||
:param tenant_id: workspace id
|
||||
@@ -204,7 +204,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.delete_provider_credential(credential_id=credential_id)
|
||||
|
||||
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
|
||||
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
@@ -232,9 +232,7 @@ class ModelProviderService:
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def validate_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@@ -303,9 +301,7 @@ class ModelProviderService:
|
||||
credential_name=credential_name,
|
||||
)
|
||||
|
||||
def remove_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
@@ -323,7 +319,7 @@ class ModelProviderService:
|
||||
|
||||
def switch_active_custom_model_credential(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
switch model credentials.
|
||||
|
||||
@@ -341,7 +337,7 @@ class ModelProviderService:
|
||||
|
||||
def add_model_credential_to_model_list(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
|
||||
) -> None:
|
||||
):
|
||||
"""
|
||||
add model credentials to model list.
|
||||
|
||||
@@ -357,7 +353,7 @@ class ModelProviderService:
|
||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||
)
|
||||
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
"""
|
||||
remove model credentials.
|
||||
|
||||
@@ -485,7 +481,7 @@ class ModelProviderService:
|
||||
logger.debug("get_default_model_of_model_type error: %s", e)
|
||||
return None
|
||||
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
|
||||
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
|
||||
"""
|
||||
update default model of model type.
|
||||
|
||||
@@ -517,7 +513,7 @@ class ModelProviderService:
|
||||
|
||||
return byte_data, mime_type
|
||||
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
|
||||
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
|
||||
"""
|
||||
switch preferred provider.
|
||||
|
||||
@@ -534,7 +530,7 @@ class ModelProviderService:
|
||||
# Switch preferred provider type
|
||||
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
|
||||
|
||||
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
enable model.
|
||||
|
||||
@@ -547,7 +543,7 @@ class ModelProviderService:
|
||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
|
||||
"""
|
||||
disable model.
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class PluginDataMigration:
|
||||
@classmethod
|
||||
def migrate(cls) -> None:
|
||||
def migrate(cls):
|
||||
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
|
||||
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
|
||||
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
|
||||
@@ -26,7 +26,7 @@ class PluginDataMigration:
|
||||
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
|
||||
|
||||
@classmethod
|
||||
def migrate_datasets(cls) -> None:
|
||||
def migrate_datasets(cls):
|
||||
table_name = "datasets"
|
||||
provider_column_name = "embedding_model_provider"
|
||||
|
||||
@@ -126,9 +126,7 @@ limit 1000"""
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def migrate_db_records(
|
||||
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
|
||||
) -> None:
|
||||
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
|
||||
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
|
||||
|
||||
processed_count = 0
|
||||
|
||||
@@ -33,7 +33,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
|
||||
|
||||
class PluginMigration:
|
||||
@classmethod
|
||||
def extract_plugins(cls, filepath: str, workers: int) -> None:
|
||||
def extract_plugins(cls, filepath: str, workers: int):
|
||||
"""
|
||||
Migrate plugin.
|
||||
"""
|
||||
@@ -55,7 +55,7 @@ class PluginMigration:
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=workers)
|
||||
|
||||
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
|
||||
def process_tenant(flask_app: Flask, tenant_id: str):
|
||||
with flask_app.app_context():
|
||||
nonlocal handled_tenant_count
|
||||
try:
|
||||
@@ -291,7 +291,7 @@ class PluginMigration:
|
||||
return plugin_manifest[0].latest_package_identifier
|
||||
|
||||
@classmethod
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
|
||||
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
|
||||
"""
|
||||
Extract unique plugins.
|
||||
"""
|
||||
@@ -328,7 +328,7 @@ class PluginMigration:
|
||||
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
|
||||
|
||||
@classmethod
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
|
||||
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
|
||||
"""
|
||||
Install plugins.
|
||||
"""
|
||||
@@ -348,7 +348,7 @@ class PluginMigration:
|
||||
if response.get("failed"):
|
||||
plugin_install_failed.extend(response.get("failed", []))
|
||||
|
||||
def install(tenant_id: str, plugin_ids: list[str]) -> None:
|
||||
def install(tenant_id: str, plugin_ids: list[str]):
|
||||
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
|
||||
# fetch plugin already installed
|
||||
installed_plugins = manager.list_plugins(tenant_id)
|
||||
|
||||
@@ -19,7 +19,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
def get_type(self) -> str:
|
||||
return RecommendAppType.BUILDIN
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
result = self.fetch_recommended_apps_from_builtin(language)
|
||||
return result
|
||||
|
||||
@@ -28,7 +28,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
def _get_builtin_data(cls):
|
||||
"""
|
||||
Get builtin data.
|
||||
:return:
|
||||
@@ -44,7 +44,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return cls.builtin_data or {}
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_builtin(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from builtin.
|
||||
:param language: language
|
||||
|
||||
@@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
Retrieval recommended app from database
|
||||
"""
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
result = self.fetch_recommended_apps_from_db(language)
|
||||
return result
|
||||
|
||||
@@ -25,7 +25,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return RecommendAppType.DATABASE
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_db(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from db.
|
||||
:param language: language
|
||||
|
||||
@@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
|
||||
"""Interface for recommend app retrieval."""
|
||||
|
||||
@abstractmethod
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -24,7 +24,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
|
||||
return result
|
||||
|
||||
def get_recommended_apps_and_categories(self, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(self, language: str):
|
||||
try:
|
||||
result = self.fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
@@ -51,7 +51,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
def fetch_recommended_apps_from_dify_official(cls, language: str):
|
||||
"""
|
||||
Fetch recommended apps from dify official.
|
||||
:param language: language
|
||||
|
||||
@@ -6,7 +6,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa
|
||||
|
||||
class RecommendedAppService:
|
||||
@classmethod
|
||||
def get_recommended_apps_and_categories(cls, language: str) -> dict:
|
||||
def get_recommended_apps_and_categories(cls, language: str):
|
||||
"""
|
||||
Get recommended apps and categories.
|
||||
:param language: language
|
||||
|
||||
@@ -12,7 +12,7 @@ from models.model import App, Tag, TagBinding
|
||||
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
|
||||
query = (
|
||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
@@ -25,7 +25,7 @@ class TagService:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
|
||||
# Check if tag_ids is not empty to avoid WHERE false condition
|
||||
if not tag_ids or len(tag_ids) == 0:
|
||||
return []
|
||||
@@ -51,7 +51,7 @@ class TagService:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
|
||||
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
|
||||
if not tag_type or not tag_name:
|
||||
return []
|
||||
tags = (
|
||||
@@ -64,7 +64,7 @@ class TagService:
|
||||
return tags
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
|
||||
@@ -37,7 +37,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
@@ -103,7 +103,7 @@ class WorkflowToolManageService:
|
||||
parameters: list[Mapping[str, Any]],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
@@ -217,7 +217,7 @@ class WorkflowToolManageService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Delete a workflow tool.
|
||||
:param user_id: the user id
|
||||
@@ -233,7 +233,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@@ -249,7 +249,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:param user_id: the user id
|
||||
@@ -265,7 +265,7 @@ class WorkflowToolManageService:
|
||||
return cls._get_workflow_tool(tenant_id, db_tool)
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
|
||||
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
|
||||
"""
|
||||
Get a workflow tool.
|
||||
:db_tool: the database tool
|
||||
|
||||
@@ -132,7 +132,7 @@ class WebsiteService:
|
||||
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict) -> None:
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
"""Validate arguments for document creation."""
|
||||
try:
|
||||
WebsiteCrawlApiRequest.from_args(args)
|
||||
|
||||
@@ -217,7 +217,7 @@ class WorkflowConverter:
|
||||
|
||||
return app_config
|
||||
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
|
||||
def _convert_to_start_node(self, variables: list[VariableEntity]):
|
||||
"""
|
||||
Convert to Start Node
|
||||
:param variables: list of variables
|
||||
@@ -384,7 +384,7 @@ class WorkflowConverter:
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileUploadConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Convert to LLM Node
|
||||
:param original_app_mode: original app mode
|
||||
@@ -550,7 +550,7 @@ class WorkflowConverter:
|
||||
|
||||
return template
|
||||
|
||||
def _convert_to_end_node(self) -> dict:
|
||||
def _convert_to_end_node(self):
|
||||
"""
|
||||
Convert to End Node
|
||||
:return:
|
||||
@@ -566,7 +566,7 @@ class WorkflowConverter:
|
||||
},
|
||||
}
|
||||
|
||||
def _convert_to_answer_node(self) -> dict:
|
||||
def _convert_to_answer_node(self):
|
||||
"""
|
||||
Convert to Answer Node
|
||||
:return:
|
||||
@@ -578,7 +578,7 @@ class WorkflowConverter:
|
||||
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
|
||||
}
|
||||
|
||||
def _create_edge(self, source: str, target: str) -> dict:
|
||||
def _create_edge(self, source: str, target: str):
|
||||
"""
|
||||
Create Edge
|
||||
:param source: source node id
|
||||
@@ -587,7 +587,7 @@ class WorkflowConverter:
|
||||
"""
|
||||
return {"id": f"{source}-{target}", "source": source, "target": target}
|
||||
|
||||
def _append_node(self, graph: dict, node: dict) -> dict:
|
||||
def _append_node(self, graph: dict, node: dict):
|
||||
"""
|
||||
Append Node to Graph
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class WorkflowAppService:
|
||||
limit: int = 20,
|
||||
created_by_end_user_session_id: str | None = None,
|
||||
created_by_account: str | None = None,
|
||||
) -> dict:
|
||||
):
|
||||
"""
|
||||
Get paginate workflow app logs using SQLAlchemy 2.0 style
|
||||
:param session: SQLAlchemy session
|
||||
|
||||
@@ -67,7 +67,7 @@ class DraftVarLoader(VariableLoader):
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
fallback_variables: Sequence[Variable] | None = None,
|
||||
) -> None:
|
||||
):
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
self._tenant_id = tenant_id
|
||||
@@ -117,7 +117,7 @@ class DraftVarLoader(VariableLoader):
|
||||
class WorkflowDraftVariableService:
|
||||
_session: Session
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
def __init__(self, session: Session):
|
||||
"""
|
||||
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
|
||||
|
||||
@@ -438,7 +438,7 @@ def _batch_upsert_draft_variable(
|
||||
session: Session,
|
||||
draft_vars: Sequence[WorkflowDraftVariable],
|
||||
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
|
||||
) -> None:
|
||||
):
|
||||
if not draft_vars:
|
||||
return None
|
||||
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
|
||||
|
||||
@@ -591,7 +591,7 @@ class WorkflowService:
|
||||
|
||||
return new_app
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
|
||||
Reference in New Issue
Block a user