chore: add ast-grep rule to convert Optional[T] to T | None (#25560)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
@@ -31,10 +30,10 @@ class Callback(ABC):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
Before invoke callback
|
||||
@@ -60,10 +59,10 @@ class Callback(ABC):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
On new chunk callback
|
||||
@@ -90,10 +89,10 @@ class Callback(ABC):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
After invoke callback
|
||||
@@ -120,10 +119,10 @@ class Callback(ABC):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
Invoke error callback
|
||||
@@ -141,7 +140,7 @@ class Callback(ABC):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def print_text(self, text: str, color: Optional[str] = None, end: str = ""):
|
||||
def print_text(self, text: str, color: str | None = None, end: str = ""):
|
||||
"""Print text with highlighting and no end characters."""
|
||||
text_to_print = self._get_colored_text(text, color) if color else text
|
||||
print(text_to_print, end=end)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
@@ -20,10 +20,10 @@ class LoggingCallback(Callback):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
Before invoke callback
|
||||
@@ -76,10 +76,10 @@ class LoggingCallback(Callback):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
On new chunk callback
|
||||
@@ -106,10 +106,10 @@ class LoggingCallback(Callback):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
After invoke callback
|
||||
@@ -147,10 +147,10 @@ class LoggingCallback(Callback):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
):
|
||||
"""
|
||||
Invoke error callback
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -8,7 +6,7 @@ class I18nObject(BaseModel):
|
||||
Model class for i18n object.
|
||||
"""
|
||||
|
||||
zh_Hans: Optional[str] = None
|
||||
zh_Hans: str | None = None
|
||||
en_US: str
|
||||
|
||||
def __init__(self, **data):
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Mapping, Sequence
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, TypedDict, Union
|
||||
from typing import Any, TypedDict, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -150,13 +150,13 @@ class LLMResult(BaseModel):
|
||||
Model class for llm result.
|
||||
"""
|
||||
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
model: str
|
||||
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
|
||||
message: AssistantPromptMessage
|
||||
usage: LLMUsage
|
||||
system_fingerprint: Optional[str] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
system_fingerprint: str | None = None
|
||||
reasoning_content: str | None = None
|
||||
|
||||
|
||||
class LLMStructuredOutput(BaseModel):
|
||||
@@ -164,7 +164,7 @@ class LLMStructuredOutput(BaseModel):
|
||||
Model class for llm structured output.
|
||||
"""
|
||||
|
||||
structured_output: Optional[Mapping[str, Any]] = None
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
|
||||
|
||||
class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
|
||||
@@ -180,8 +180,8 @@ class LLMResultChunkDelta(BaseModel):
|
||||
|
||||
index: int
|
||||
message: AssistantPromptMessage
|
||||
usage: Optional[LLMUsage] = None
|
||||
finish_reason: Optional[str] = None
|
||||
usage: LLMUsage | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
class LLMResultChunk(BaseModel):
|
||||
@@ -191,7 +191,7 @@ class LLMResultChunk(BaseModel):
|
||||
|
||||
model: str
|
||||
prompt_messages: Sequence[PromptMessage] = Field(default_factory=list)
|
||||
system_fingerprint: Optional[str] = None
|
||||
system_fingerprint: str | None = None
|
||||
delta: LLMResultChunkDelta
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
from typing import Annotated, Any, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
@@ -146,8 +146,8 @@ class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
role: PromptMessageRole
|
||||
content: Optional[str | list[PromptMessageContentUnionTypes]] = None
|
||||
name: Optional[str] = None
|
||||
content: str | list[PromptMessageContentUnionTypes] | None = None
|
||||
name: str | None = None
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""
|
||||
@@ -193,8 +193,8 @@ class PromptMessage(ABC, BaseModel):
|
||||
|
||||
@field_serializer("content")
|
||||
def serialize_content(
|
||||
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
|
||||
) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
|
||||
self, content: Union[str, Sequence[PromptMessageContent]] | None
|
||||
) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None:
|
||||
if content is None or isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from decimal import Decimal
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
@@ -154,7 +154,7 @@ class ProviderModel(BaseModel):
|
||||
model: str
|
||||
label: I18nObject
|
||||
model_type: ModelType
|
||||
features: Optional[list[ModelFeature]] = None
|
||||
features: list[ModelFeature] | None = None
|
||||
fetch_from: FetchFrom
|
||||
model_properties: dict[ModelPropertyKey, Any]
|
||||
deprecated: bool = False
|
||||
@@ -171,15 +171,15 @@ class ParameterRule(BaseModel):
|
||||
"""
|
||||
|
||||
name: str
|
||||
use_template: Optional[str] = None
|
||||
use_template: str | None = None
|
||||
label: I18nObject
|
||||
type: ParameterType
|
||||
help: Optional[I18nObject] = None
|
||||
help: I18nObject | None = None
|
||||
required: bool = False
|
||||
default: Optional[Any] = None
|
||||
min: Optional[float] = None
|
||||
max: Optional[float] = None
|
||||
precision: Optional[int] = None
|
||||
default: Any | None = None
|
||||
min: float | None = None
|
||||
max: float | None = None
|
||||
precision: int | None = None
|
||||
options: list[str] = []
|
||||
|
||||
|
||||
@@ -189,7 +189,7 @@ class PriceConfig(BaseModel):
|
||||
"""
|
||||
|
||||
input: Decimal
|
||||
output: Optional[Decimal] = None
|
||||
output: Decimal | None = None
|
||||
unit: Decimal
|
||||
currency: str
|
||||
|
||||
@@ -200,7 +200,7 @@ class AIModelEntity(ProviderModel):
|
||||
"""
|
||||
|
||||
parameter_rules: list[ParameterRule] = []
|
||||
pricing: Optional[PriceConfig] = None
|
||||
pricing: PriceConfig | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_model(self):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum, StrEnum, auto
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
@@ -62,9 +61,9 @@ class CredentialFormSchema(BaseModel):
|
||||
label: I18nObject
|
||||
type: FormType
|
||||
required: bool = True
|
||||
default: Optional[str] = None
|
||||
options: Optional[list[FormOption]] = None
|
||||
placeholder: Optional[I18nObject] = None
|
||||
default: str | None = None
|
||||
options: list[FormOption] | None = None
|
||||
placeholder: I18nObject | None = None
|
||||
max_length: int = 0
|
||||
show_on: list[FormShowOnObject] = []
|
||||
|
||||
@@ -79,7 +78,7 @@ class ProviderCredentialSchema(BaseModel):
|
||||
|
||||
class FieldModelSchema(BaseModel):
|
||||
label: I18nObject
|
||||
placeholder: Optional[I18nObject] = None
|
||||
placeholder: I18nObject | None = None
|
||||
|
||||
|
||||
class ModelCredentialSchema(BaseModel):
|
||||
@@ -98,8 +97,8 @@ class SimpleProviderEntity(BaseModel):
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
models: list[AIModelEntity] = []
|
||||
|
||||
@@ -120,24 +119,24 @@ class ProviderEntity(BaseModel):
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
icon_small: Optional[I18nObject] = None
|
||||
icon_large: Optional[I18nObject] = None
|
||||
icon_small_dark: Optional[I18nObject] = None
|
||||
icon_large_dark: Optional[I18nObject] = None
|
||||
background: Optional[str] = None
|
||||
help: Optional[ProviderHelpEntity] = None
|
||||
description: I18nObject | None = None
|
||||
icon_small: I18nObject | None = None
|
||||
icon_large: I18nObject | None = None
|
||||
icon_small_dark: I18nObject | None = None
|
||||
icon_large_dark: I18nObject | None = None
|
||||
background: str | None = None
|
||||
help: ProviderHelpEntity | None = None
|
||||
supported_model_types: Sequence[ModelType]
|
||||
configurate_methods: list[ConfigurateMethod]
|
||||
models: list[AIModelEntity] = Field(default_factory=list)
|
||||
provider_credential_schema: Optional[ProviderCredentialSchema] = None
|
||||
model_credential_schema: Optional[ModelCredentialSchema] = None
|
||||
provider_credential_schema: ProviderCredentialSchema | None = None
|
||||
model_credential_schema: ModelCredentialSchema | None = None
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
# position from plugin _position.yaml
|
||||
position: Optional[dict[str, list[str]]] = {}
|
||||
position: dict[str, list[str]] | None = {}
|
||||
|
||||
@field_validator("models", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InvokeError(ValueError):
|
||||
"""Base class for all LLM exceptions."""
|
||||
|
||||
description: Optional[str] = None
|
||||
description: str | None = None
|
||||
|
||||
def __init__(self, description: Optional[str] = None):
|
||||
def __init__(self, description: str | None = None):
|
||||
self.description = description
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import decimal
|
||||
import hashlib
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@@ -99,7 +98,7 @@ class AIModel(BaseModel):
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
# get price info from predefined model schema
|
||||
price_config: Optional[PriceConfig] = None
|
||||
price_config: PriceConfig | None = None
|
||||
if model_schema and model_schema.pricing:
|
||||
price_config = model_schema.pricing
|
||||
|
||||
@@ -132,7 +131,7 @@ class AIModel(BaseModel):
|
||||
currency=price_config.currency,
|
||||
)
|
||||
|
||||
def get_model_schema(self, model: str, credentials: Optional[dict] = None) -> Optional[AIModelEntity]:
|
||||
def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema by model name and credentials
|
||||
|
||||
@@ -171,7 +170,7 @@ class AIModel(BaseModel):
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema from credentials
|
||||
|
||||
@@ -229,7 +228,7 @@ class AIModel(BaseModel):
|
||||
|
||||
return schema
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
|
||||
"""
|
||||
Get customizable model schema
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@@ -94,12 +94,12 @@ class LargeLanguageModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
model_parameters: dict | None = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@@ -243,11 +243,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Invoke result generator
|
||||
@@ -328,7 +328,7 @@ class LargeLanguageModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
@@ -403,11 +403,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger before invoke callbacks
|
||||
@@ -451,11 +451,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger new chunk callbacks
|
||||
@@ -498,11 +498,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger after invoke callbacks
|
||||
@@ -548,11 +548,11 @@ class LargeLanguageModel(AIModel):
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[Sequence[str]] = None,
|
||||
tools: list[PromptMessageTool] | None = None,
|
||||
stop: Sequence[str] | None = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
callbacks: Optional[list[Callback]] = None,
|
||||
user: str | None = None,
|
||||
callbacks: list[Callback] | None = None,
|
||||
):
|
||||
"""
|
||||
Trigger invoke error callbacks
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@@ -18,7 +17,7 @@ class ModerationModel(AIModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
|
||||
def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
@@ -19,9 +17,9 @@ class RerankModel(AIModel):
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
user: Optional[str] = None,
|
||||
score_threshold: float | None = None,
|
||||
top_n: int | None = None,
|
||||
user: str | None = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank model
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import IO, Optional
|
||||
from typing import IO
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@@ -17,7 +17,7 @@ class Speech2TextModel(AIModel):
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str:
|
||||
def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str:
|
||||
"""
|
||||
Invoke speech to text model
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from core.entities.embedding_type import EmbeddingInputType
|
||||
@@ -24,7 +22,7 @@ class TextEmbeddingModel(AIModel):
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
from threading import Lock
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_tokenizer: Optional[Any] = None
|
||||
_tokenizer: Any | None = None
|
||||
_lock = Lock()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@@ -28,7 +27,7 @@ class TTSModel(AIModel):
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
user: Optional[str] = None,
|
||||
user: str | None = None,
|
||||
) -> Iterable[bytes]:
|
||||
"""
|
||||
Invoke large language model
|
||||
@@ -56,7 +55,7 @@ class TTSModel(AIModel):
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None):
|
||||
def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None):
|
||||
"""
|
||||
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ import hashlib
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
import contexts
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@@ -206,9 +205,9 @@ class ModelProviderFactory:
|
||||
def get_models(
|
||||
self,
|
||||
*,
|
||||
provider: Optional[str] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
provider_configs: Optional[list[ProviderConfig]] = None,
|
||||
provider: str | None = None,
|
||||
model_type: ModelType | None = None,
|
||||
provider_configs: list[ProviderConfig] | None = None,
|
||||
) -> list[SimpleProviderEntity]:
|
||||
"""
|
||||
Get all models for given model type
|
||||
|
||||
@@ -8,7 +8,7 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6
|
||||
from pathlib import Path, PurePath
|
||||
from re import Pattern
|
||||
from types import GeneratorType
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, Literal, Union
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -98,7 +98,7 @@ def jsonable_encoder(
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None,
|
||||
custom_encoder: dict[Any, Callable[[Any], Any]] | None = None,
|
||||
sqlalchemy_safe: bool = True,
|
||||
) -> Any:
|
||||
custom_encoder = custom_encoder or {}
|
||||
|
||||
Reference in New Issue
Block a user