feat: mypy for all type check (#10921)
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Optional, cast
|
||||
import requests
|
||||
from flask import current_app
|
||||
|
||||
from core.entities.model_entities import ModelStatus, ProviderModelWithStatusEntity
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, ProviderModelWithStatusEntity
|
||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||
from core.model_runtime.model_providers import model_provider_factory
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@@ -100,23 +100,15 @@ class ModelProviderService:
|
||||
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str):
|
||||
"""
|
||||
get provider credentials.
|
||||
|
||||
:param tenant_id:
|
||||
:param provider:
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider configuration
|
||||
provider_configuration = provider_configurations.get(provider)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Get provider custom credentials from workspace
|
||||
return provider_configuration.get_custom_credentials(obfuscated=True)
|
||||
|
||||
def provider_credentials_validate(self, tenant_id: str, provider: str, credentials: dict) -> None:
|
||||
@@ -176,7 +168,7 @@ class ModelProviderService:
|
||||
# Remove custom provider credentials.
|
||||
provider_configuration.delete_custom_credentials()
|
||||
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> dict:
|
||||
def get_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str):
|
||||
"""
|
||||
get model credentials.
|
||||
|
||||
@@ -287,7 +279,7 @@ class ModelProviderService:
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
provider_models: dict[str, list[ModelWithProviderEntity]] = {}
|
||||
for model in models:
|
||||
if model.provider.provider not in provider_models:
|
||||
provider_models[model.provider.provider] = []
|
||||
@@ -362,7 +354,7 @@ class ModelProviderService:
|
||||
return []
|
||||
|
||||
# Call get_parameter_rules method of model instance to get model parameter rules
|
||||
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
|
||||
return list(model_type_instance.get_parameter_rules(model=model, credentials=credentials))
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
"""
|
||||
@@ -422,6 +414,7 @@ class ModelProviderService:
|
||||
"""
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
file_name: str | None = None
|
||||
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
@@ -439,6 +432,8 @@ class ModelProviderService:
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
if not file_name:
|
||||
return None, None
|
||||
|
||||
root_path = current_app.root_path
|
||||
provider_instance_path = os.path.dirname(
|
||||
@@ -524,7 +519,7 @@ class ModelProviderService:
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/apply"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
@@ -545,7 +540,7 @@ class ModelProviderService:
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL", "")
|
||||
api_url = api_base_url + "/api/v1/providers/qualification-verify"
|
||||
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
|
||||
Reference in New Issue
Block a user