remove bare list, dict, Sequence, None, Any (#25058)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Asuka Minato
2025-09-06 04:32:23 +09:00
committed by GitHub
parent 2b0695bdde
commit a78339a040
306 changed files with 787 additions and 817 deletions

View File

@@ -105,14 +105,14 @@ class AccountService:
return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}"
@staticmethod
def _store_refresh_token(refresh_token: str, account_id: str) -> None:
def _store_refresh_token(refresh_token: str, account_id: str):
redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id)
redis_client.setex(
AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token
)
@staticmethod
def _delete_refresh_token(refresh_token: str, account_id: str) -> None:
def _delete_refresh_token(refresh_token: str, account_id: str):
redis_client.delete(AccountService._get_refresh_token_key(refresh_token))
redis_client.delete(AccountService._get_account_refresh_token_key(account_id))
@@ -312,12 +312,12 @@ class AccountService:
return True
@staticmethod
def delete_account(account: Account) -> None:
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
delete_account_task.delay(account.id)
@staticmethod
def link_account_integrate(provider: str, open_id: str, account: Account) -> None:
def link_account_integrate(provider: str, open_id: str, account: Account):
"""Link account integrate"""
try:
# Query whether there is an existing binding record for the same provider
@@ -344,7 +344,7 @@ class AccountService:
raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod
def close_account(account: Account) -> None:
def close_account(account: Account):
"""Close account"""
account.status = AccountStatus.CLOSED.value
db.session.commit()
@@ -374,7 +374,7 @@ class AccountService:
return account
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
def update_login_info(account: Account, *, ip_address: str):
"""Update last login time and ip"""
account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
@@ -398,7 +398,7 @@ class AccountService:
return TokenPair(access_token=access_token, refresh_token=refresh_token)
@staticmethod
def logout(*, account: Account) -> None:
def logout(*, account: Account):
refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
if refresh_token:
AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
@@ -705,7 +705,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None:
def add_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@@ -734,7 +734,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None:
def add_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@@ -763,7 +763,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_change_email_error_rate_limit(email: str) -> None:
def add_change_email_error_rate_limit(email: str):
key = f"change_email_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@@ -791,7 +791,7 @@ class AccountService:
@staticmethod
@redis_fallback(default_return=None)
def add_owner_transfer_error_rate_limit(email: str) -> None:
def add_owner_transfer_error_rate_limit(email: str):
key = f"owner_transfer_error_rate_limit:{email}"
count = redis_client.get(key)
if count is None:
@@ -970,7 +970,7 @@ class TenantService:
return tenant
@staticmethod
def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None:
def switch_tenant(account: Account, tenant_id: Optional[str] = None):
"""Switch the current workspace for the account"""
# Ensure tenant_id is provided
@@ -1067,7 +1067,7 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str):
"""Check member permission"""
perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@@ -1087,7 +1087,7 @@ class TenantService:
raise NoPermissionError(f"No permission to {action} member.")
@staticmethod
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account):
"""Remove member from tenant"""
if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.")
@@ -1102,7 +1102,7 @@ class TenantService:
db.session.commit()
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update")
@@ -1129,7 +1129,7 @@ class TenantService:
db.session.commit()
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
def get_custom_config(tenant_id: str):
tenant = db.get_or_404(Tenant, tenant_id)
return tenant.custom_config_dict
@@ -1150,7 +1150,7 @@ class RegisterService:
return f"member_invite:token:{token}"
@classmethod
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
def setup(cls, email: str, name: str, password: str, ip_address: str):
"""
Setup dify

View File

@@ -17,7 +17,7 @@ from models.model import AppMode
class AdvancedPromptTemplateService:
@classmethod
def get_prompt(cls, args: dict) -> dict:
def get_prompt(cls, args: dict):
app_mode = args["app_mode"]
model_mode = args["model_mode"]
model_name = args["model_name"]
@@ -29,7 +29,7 @@ class AdvancedPromptTemplateService:
return cls.get_common_prompt(app_mode, model_mode, has_context)
@classmethod
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str):
context_prompt = copy.deepcopy(CONTEXT)
if app_mode == AppMode.CHAT.value:
@@ -52,7 +52,7 @@ class AdvancedPromptTemplateService:
return {}
@classmethod
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
@@ -61,7 +61,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str):
if has_context == "true":
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
@@ -70,7 +70,7 @@ class AdvancedPromptTemplateService:
return prompt_template
@classmethod
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str):
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
if app_mode == AppMode.CHAT.value:

View File

@@ -16,7 +16,7 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough
class AgentService:
@classmethod
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str):
"""
Service to get agent logs
"""

View File

@@ -73,7 +73,7 @@ class AppAnnotationService:
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
def enable_app_annotation(cls, args: dict, app_id: str):
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(enable_app_annotation_key)
if cache_result is not None:
@@ -96,7 +96,7 @@ class AppAnnotationService:
return {"job_id": job_id, "job_status": "waiting"}
@classmethod
def disable_app_annotation(cls, app_id: str) -> dict:
def disable_app_annotation(cls, app_id: str):
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
cache_result = redis_client.get(disable_app_annotation_key)
if cache_result is not None:
@@ -315,7 +315,7 @@ class AppAnnotationService:
return {"deleted_count": deleted_count}
@classmethod
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
def batch_import_app_annotations(cls, app_id, file: FileStorage):
# get app info
app = (
db.session.query(App)
@@ -490,7 +490,7 @@ class AppAnnotationService:
}
@classmethod
def clear_all_annotations(cls, app_id: str) -> dict:
def clear_all_annotations(cls, app_id: str):
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")

View File

@@ -30,7 +30,7 @@ class APIBasedExtensionService:
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension) -> None:
def delete(extension_data: APIBasedExtension):
db.session.delete(extension_data)
db.session.commit()
@@ -51,7 +51,7 @@ class APIBasedExtensionService:
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension) -> None:
def _validation(cls, extension_data: APIBasedExtension):
# name
if not extension_data.name:
raise ValueError("name must not be empty")
@@ -95,7 +95,7 @@ class APIBasedExtensionService:
cls._ping_connection(extension_data)
@staticmethod
def _ping_connection(extension_data: APIBasedExtension) -> None:
def _ping_connection(extension_data: APIBasedExtension):
try:
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
resp = client.request(point=APIBasedExtensionPoint.PING, params={})

View File

@@ -566,7 +566,7 @@ class AppDslService:
@classmethod
def _append_workflow_export_data(
cls, *, export_data: dict, app_model: App, include_secret: bool, workflow_id: Optional[str] = None
) -> None:
):
"""
Append workflow export data
:param export_data: export data
@@ -608,7 +608,7 @@ class AppDslService:
]
@classmethod
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
def _append_model_config_export_data(cls, export_data: dict, app_model: App):
"""
Append model config export data
:param export_data: export data

View File

@@ -6,7 +6,7 @@ from models.model import AppMode
class AppModelConfigService:
@classmethod
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode):
if app_mode == AppMode.CHAT:
return ChatAppConfigManager.config_validate(tenant_id, config)
elif app_mode == AppMode.AGENT_CHAT:

View File

@@ -316,7 +316,7 @@ class AppService:
return app
def delete_app(self, app: App) -> None:
def delete_app(self, app: App):
"""
Delete app
:param app: App instance
@@ -331,7 +331,7 @@ class AppService:
# Trigger asynchronous deletion of app and related data
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
def get_app_meta(self, app_model: App) -> dict:
def get_app_meta(self, app_model: App):
"""
Get app meta info
:param app_model: app model

View File

@@ -8,7 +8,7 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list:
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))

View File

@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class ClearFreePlanTenantExpiredLogs:
@classmethod
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]) -> None:
def _clear_message_related_tables(cls, session: Session, tenant_id: str, batch_message_ids: list[str]):
"""
Clean up message-related tables to avoid data redundancy.
This method cleans up tables that have foreign key relationships with Message.
@@ -353,7 +353,7 @@ class ClearFreePlanTenantExpiredLogs:
thread_pool = ThreadPoolExecutor(max_workers=10)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
try:
if (
not dify_config.BILLING_ENABLED

View File

@@ -3,7 +3,7 @@ from extensions.ext_code_based_extension import code_based_extension
class CodeBasedExtensionService:
@staticmethod
def get_code_based_extension(module: str) -> list[dict]:
def get_code_based_extension(module: str):
module_extensions = code_based_extension.module_extensions(module)
return [
{

View File

@@ -250,7 +250,7 @@ class ConversationService:
variable_id: str,
user: Optional[Union[Account, EndUser]],
new_value: Any,
) -> dict:
):
"""
Update a conversation variable's value.

View File

@@ -719,7 +719,7 @@ class DatasetService:
)
@staticmethod
def get_dataset_auto_disable_logs(dataset_id: str) -> dict:
def get_dataset_auto_disable_logs(dataset_id: str):
features = FeatureService.get_features(current_user.current_tenant_id)
if not features.billing.enabled or features.billing.subscription.plan == "sandbox":
return {

View File

@@ -83,7 +83,7 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@@ -113,7 +113,7 @@ class ProviderWithModelsResponse(BaseModel):
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@@ -137,7 +137,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data) -> None:
def __init__(self, **data):
super().__init__(**data)
url_prefix = (
@@ -174,7 +174,7 @@ class ModelWithProviderEntityResponse(ProviderModelWithStatusEntity):
provider: SimpleProviderEntityResponse
def __init__(self, tenant_id: str, model: ModelWithProviderEntity) -> None:
def __init__(self, tenant_id: str, model: ModelWithProviderEntity):
dump_model = model.model_dump()
dump_model["provider"]["tenant_id"] = tenant_id
super().__init__(**dump_model)

View File

@@ -6,7 +6,7 @@ class InvokeError(Exception):
description: Optional[str] = None
def __init__(self, description: Optional[str] = None) -> None:
def __init__(self, description: Optional[str] = None):
self.description = description
def __str__(self):

View File

@@ -277,7 +277,7 @@ class ExternalDatasetService:
query: str,
external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None,
) -> list:
):
external_knowledge_binding = (
db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
)

View File

@@ -33,7 +33,7 @@ class HitTestingService:
retrieval_model: Any, # FIXME drop this any
external_retrieval_model: dict,
limit: int = 10,
) -> dict:
):
start = time.perf_counter()
# get retrieval model , if the model is not setting , using default
@@ -98,7 +98,7 @@ class HitTestingService:
account: Account,
external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict:
):
if dataset.provider != "external":
return {
"query": {"content": query},

View File

@@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
class ModelLoadBalancingService:
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model load balancing.
@@ -49,7 +49,7 @@ class ModelLoadBalancingService:
# Enable model load balancing
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model load balancing.
@@ -295,7 +295,7 @@ class ModelLoadBalancingService:
def update_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict], config_from: str
) -> None:
):
"""
Update load balancing configurations.
:param tenant_id: workspace id
@@ -478,7 +478,7 @@ class ModelLoadBalancingService:
model_type: str,
credentials: dict,
config_id: Optional[str] = None,
) -> None:
):
"""
Validate load balancing credentials.
:param tenant_id: workspace id
@@ -537,7 +537,7 @@ class ModelLoadBalancingService:
credentials: dict,
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
validate: bool = True,
) -> dict:
):
"""
Validate custom credentials.
:param tenant_id: workspace id
@@ -605,7 +605,7 @@ class ModelLoadBalancingService:
else:
raise ValueError("No credential schema found")
def _clear_credentials_cache(self, tenant_id: str, config_id: str) -> None:
def _clear_credentials_cache(self, tenant_id: str, config_id: str):
"""
Clear credentials cache.
:param tenant_id: workspace id

View File

@@ -26,7 +26,7 @@ class ModelProviderService:
Model Provider Service
"""
def __init__(self) -> None:
def __init__(self):
self.provider_manager = ProviderManager()
def _get_provider_configuration(self, tenant_id: str, provider: str):
@@ -142,7 +142,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
"""
validate provider credentials before saving.
@@ -193,7 +193,7 @@ class ModelProviderService:
credential_name=credential_name,
)
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
def remove_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
remove a saved provider credential (by credential_id).
:param tenant_id: workspace id
@@ -204,7 +204,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.delete_provider_credential(credential_id=credential_id)
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str) -> None:
def switch_active_provider_credential(self, tenant_id: str, provider: str, credential_id: str):
"""
:param tenant_id: workspace id
:param provider: provider name
@@ -232,9 +232,7 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def validate_model_credentials(
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
) -> None:
def validate_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict):
"""
validate model credentials.
@@ -303,9 +301,7 @@ class ModelProviderService:
credential_name=credential_name,
)
def remove_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
def remove_model_credential(self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str):
"""
remove model credentials.
@@ -323,7 +319,7 @@ class ModelProviderService:
def switch_active_custom_model_credential(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
):
"""
switch model credentials.
@@ -341,7 +337,7 @@ class ModelProviderService:
def add_model_credential_to_model_list(
self, tenant_id: str, provider: str, model_type: str, model: str, credential_id: str
) -> None:
):
"""
add model credentials to model list.
@@ -357,7 +353,7 @@ class ModelProviderService:
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
def remove_model(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
remove model credentials.
@@ -485,7 +481,7 @@ class ModelProviderService:
logger.debug("get_default_model_of_model_type error: %s", e)
return None
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None:
def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str):
"""
update default model of model type.
@@ -517,7 +513,7 @@ class ModelProviderService:
return byte_data, mime_type
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str) -> None:
def switch_preferred_provider(self, tenant_id: str, provider: str, preferred_provider_type: str):
"""
switch preferred provider.
@@ -534,7 +530,7 @@ class ModelProviderService:
# Switch preferred provider type
provider_configuration.switch_preferred_provider_type(preferred_provider_type_enum)
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def enable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
enable model.
@@ -547,7 +543,7 @@ class ModelProviderService:
provider_configuration = self._get_provider_configuration(tenant_id, provider)
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str):
"""
disable model.

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
def migrate(cls):
cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
@@ -26,7 +26,7 @@ class PluginDataMigration:
cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
@classmethod
def migrate_datasets(cls) -> None:
def migrate_datasets(cls):
table_name = "datasets"
provider_column_name = "embedding_model_provider"
@@ -126,9 +126,7 @@ limit 1000"""
)
@classmethod
def migrate_db_records(
cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
) -> None:
def migrate_db_records(cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]):
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0

View File

@@ -33,7 +33,7 @@ excluded_providers = ["time", "audio", "code", "webscraper"]
class PluginMigration:
@classmethod
def extract_plugins(cls, filepath: str, workers: int) -> None:
def extract_plugins(cls, filepath: str, workers: int):
"""
Migrate plugin.
"""
@@ -55,7 +55,7 @@ class PluginMigration:
thread_pool = ThreadPoolExecutor(max_workers=workers)
def process_tenant(flask_app: Flask, tenant_id: str) -> None:
def process_tenant(flask_app: Flask, tenant_id: str):
with flask_app.app_context():
nonlocal handled_tenant_count
try:
@@ -291,7 +291,7 @@ class PluginMigration:
return plugin_manifest[0].latest_package_identifier
@classmethod
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str) -> None:
def extract_unique_plugins_to_file(cls, extracted_plugins: str, output_file: str):
"""
Extract unique plugins.
"""
@@ -328,7 +328,7 @@ class PluginMigration:
return {"plugins": plugins, "plugin_not_exist": plugin_not_exist}
@classmethod
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100) -> None:
def install_plugins(cls, extracted_plugins: str, output_file: str, workers: int = 100):
"""
Install plugins.
"""
@@ -348,7 +348,7 @@ class PluginMigration:
if response.get("failed"):
plugin_install_failed.extend(response.get("failed", []))
def install(tenant_id: str, plugin_ids: list[str]) -> None:
def install(tenant_id: str, plugin_ids: list[str]):
logger.info("Installing %s plugins for tenant %s", len(plugin_ids), tenant_id)
# fetch plugin already installed
installed_plugins = manager.list_plugins(tenant_id)

View File

@@ -19,7 +19,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
def get_type(self) -> str:
return RecommendAppType.BUILDIN
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_builtin(language)
return result
@@ -28,7 +28,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return result
@classmethod
def _get_builtin_data(cls) -> dict:
def _get_builtin_data(cls):
"""
Get builtin data.
:return:
@@ -44,7 +44,7 @@ class BuildInRecommendAppRetrieval(RecommendAppRetrievalBase):
return cls.builtin_data or {}
@classmethod
def fetch_recommended_apps_from_builtin(cls, language: str) -> dict:
def fetch_recommended_apps_from_builtin(cls, language: str):
"""
Fetch recommended apps from builtin.
:param language: language

View File

@@ -13,7 +13,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
Retrieval recommended app from database
"""
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
result = self.fetch_recommended_apps_from_db(language)
return result
@@ -25,7 +25,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.DATABASE
@classmethod
def fetch_recommended_apps_from_db(cls, language: str) -> dict:
def fetch_recommended_apps_from_db(cls, language: str):
"""
Fetch recommended apps from db.
:param language: language

View File

@@ -5,7 +5,7 @@ class RecommendAppRetrievalBase(ABC):
"""Interface for recommend app retrieval."""
@abstractmethod
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
raise NotImplementedError
@abstractmethod

View File

@@ -24,7 +24,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin(app_id)
return result
def get_recommended_apps_and_categories(self, language: str) -> dict:
def get_recommended_apps_and_categories(self, language: str):
try:
result = self.fetch_recommended_apps_from_dify_official(language)
except Exception as e:
@@ -51,7 +51,7 @@ class RemoteRecommendAppRetrieval(RecommendAppRetrievalBase):
return data
@classmethod
def fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
def fetch_recommended_apps_from_dify_official(cls, language: str):
"""
Fetch recommended apps from dify official.
:param language: language

View File

@@ -6,7 +6,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa
class RecommendedAppService:
@classmethod
def get_recommended_apps_and_categories(cls, language: str) -> dict:
def get_recommended_apps_and_categories(cls, language: str):
"""
Get recommended apps and categories.
:param language: language

View File

@@ -12,7 +12,7 @@ from models.model import App, Tag, TagBinding
class TagService:
@staticmethod
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None):
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
@@ -25,7 +25,7 @@ class TagService:
return results
@staticmethod
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list):
# Check if tag_ids is not empty to avoid WHERE false condition
if not tag_ids or len(tag_ids) == 0:
return []
@@ -51,7 +51,7 @@ class TagService:
return results
@staticmethod
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list:
def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str):
if not tag_type or not tag_name:
return []
tags = (
@@ -64,7 +64,7 @@ class TagService:
return tags
@staticmethod
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)

View File

@@ -37,7 +37,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
@@ -103,7 +103,7 @@ class WorkflowToolManageService:
parameters: list[Mapping[str, Any]],
privacy_policy: str = "",
labels: list[str] | None = None,
) -> dict:
):
"""
Update a workflow tool.
:param user_id: the user id
@@ -217,7 +217,7 @@ class WorkflowToolManageService:
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Delete a workflow tool.
:param user_id: the user id
@@ -233,7 +233,7 @@ class WorkflowToolManageService:
return {"result": "success"}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@@ -249,7 +249,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str):
"""
Get a workflow tool.
:param user_id: the user id
@@ -265,7 +265,7 @@ class WorkflowToolManageService:
return cls._get_workflow_tool(tenant_id, db_tool)
@classmethod
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None) -> dict:
def _get_workflow_tool(cls, tenant_id: str, db_tool: WorkflowToolProvider | None):
"""
Get a workflow tool.
:db_tool: the database tool

View File

@@ -132,7 +132,7 @@ class WebsiteService:
return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
@classmethod
def document_create_args_validate(cls, args: dict) -> None:
def document_create_args_validate(cls, args: dict):
"""Validate arguments for document creation."""
try:
WebsiteCrawlApiRequest.from_args(args)

View File

@@ -217,7 +217,7 @@ class WorkflowConverter:
return app_config
def _convert_to_start_node(self, variables: list[VariableEntity]) -> dict:
def _convert_to_start_node(self, variables: list[VariableEntity]):
"""
Convert to Start Node
:param variables: list of variables
@@ -384,7 +384,7 @@ class WorkflowConverter:
prompt_template: PromptTemplateEntity,
file_upload: Optional[FileUploadConfig] = None,
external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict:
):
"""
Convert to LLM Node
:param original_app_mode: original app mode
@@ -550,7 +550,7 @@ class WorkflowConverter:
return template
def _convert_to_end_node(self) -> dict:
def _convert_to_end_node(self):
"""
Convert to End Node
:return:
@@ -566,7 +566,7 @@ class WorkflowConverter:
},
}
def _convert_to_answer_node(self) -> dict:
def _convert_to_answer_node(self):
"""
Convert to Answer Node
:return:
@@ -578,7 +578,7 @@ class WorkflowConverter:
"data": {"title": "ANSWER", "type": NodeType.ANSWER.value, "answer": "{{#llm.text#}}"},
}
def _create_edge(self, source: str, target: str) -> dict:
def _create_edge(self, source: str, target: str):
"""
Create Edge
:param source: source node id
@@ -587,7 +587,7 @@ class WorkflowConverter:
"""
return {"id": f"{source}-{target}", "source": source, "target": target}
def _append_node(self, graph: dict, node: dict) -> dict:
def _append_node(self, graph: dict, node: dict):
"""
Append Node to Graph

View File

@@ -23,7 +23,7 @@ class WorkflowAppService:
limit: int = 20,
created_by_end_user_session_id: str | None = None,
created_by_account: str | None = None,
) -> dict:
):
"""
Get paginate workflow app logs using SQLAlchemy 2.0 style
:param session: SQLAlchemy session

View File

@@ -67,7 +67,7 @@ class DraftVarLoader(VariableLoader):
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
) -> None:
):
self._engine = engine
self._app_id = app_id
self._tenant_id = tenant_id
@@ -117,7 +117,7 @@ class DraftVarLoader(VariableLoader):
class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
def __init__(self, session: Session):
"""
Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
@@ -438,7 +438,7 @@ def _batch_upsert_draft_variable(
session: Session,
draft_vars: Sequence[WorkflowDraftVariable],
policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
) -> None:
):
if not draft_vars:
return None
# Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:

View File

@@ -591,7 +591,7 @@ class WorkflowService:
return new_app
def validate_features_structure(self, app_model: App, features: dict) -> dict:
def validate_features_structure(self, app_model: App, features: dict):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
return AdvancedChatAppConfigManager.config_validate(
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True