feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -7,7 +7,7 @@ from typing import Optional, cast
import requests
from flask import current_app
from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -100,23 +100,15 @@ class ModelProviderService:
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
]
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
def get_provider_credentials(self, tenant_id: str, provider: str):
"""
get provider credentials.
:param tenant_id:
:param provider:
:return:
"""
# Get all provider configurations of the current workspace
provider_configurations = self.provider_manager.get_configurations(tenant_id)
# Get provider configuration
provider_configuration = provider_configurations.get(provider)
if not provider_configuration:
raise ValueError(f"Provider {provider} does not exist.")
# Get provider custom credentials from workspace
return provider_configuration.get_custom_credentials(obfuscated=True)
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
@@ -176,7 +168,7 @@ class ModelProviderService:
# Remove custom provider credentials.
provider_configuration.delete_custom_credentials()
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
"""
get model credentials.
@@ -287,7 +279,7 @@ class ModelProviderService:
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
# Group models by provider
provider_models = {}
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
for model in models:
if model.provider.provider not in provider_models:
provider_models[model.provider.provider] = []
@@ -362,7 +354,7 @@ class ModelProviderService:
return []
# Call get_parameter_rules method of model instance to get model parameter rules
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
"""
@@ -422,6 +414,7 @@ class ModelProviderService:
"""
provider_instance = model_provider_factory.get_provider_instance(provider)
provider_schema = provider_instance.get_provider_schema()
file_name: str | None = None
if icon_type.lower() == "icon_small":
if not provider_schema.icon_small:
@@ -439,6 +432,8 @@ class ModelProviderService:
file_name = provider_schema.icon_large.zh_Hans
else:
file_name = provider_schema.icon_large.en_US
if not file_name:
return None, None
root_path = current_app.root_path
provider_instance_path = os.path.dirname(
@@ -524,7 +519,7 @@ class ModelProviderService:
def free_quota_submit(self, tenant_id: str, provider: str):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
api_url = api_base_url + "/api/v1/providers/apply"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
@@ -545,7 +540,7 @@ class ModelProviderService:
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
api_url = api_base_url + "/api/v1/providers/qualification-verify"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}