Feat/support multimodal embedding (#29115)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
@@ -10,9 +10,9 @@ from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.entities.rerank_entities import RerankResult
-from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
+from core.model_runtime.entities.text_embedding_entities import EmbeddingResult
 from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
@@ -200,7 +200,7 @@ class ModelInstance:
 
     def invoke_text_embedding(
         self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT
-    ) -> TextEmbeddingResult:
+    ) -> EmbeddingResult:
         """
         Invoke large language model
 
@@ -212,7 +212,7 @@ class ModelInstance:
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
         return cast(
-            TextEmbeddingResult,
+            EmbeddingResult,
             self._round_robin_invoke(
                 function=self.model_type_instance.invoke,
                 model=self.model,
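
For callers, this change is confined to the annotated return type. A minimal usage sketch, assuming EmbeddingResult carries over the embeddings and usage fields of the old TextEmbeddingResult (those field names are not shown in this diff):

    # Hedged sketch: field names are assumed to carry over from TextEmbeddingResult.
    result = model_instance.invoke_text_embedding(
        texts=["hello world"],
        input_type=EmbeddingInputType.DOCUMENT,
    )
    vectors = result.embeddings  # assumed: one vector per input text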
@@ -223,6 +223,34 @@ class ModelInstance:
             ),
         )
 
+    def invoke_multimodal_embedding(
+        self,
+        multimodel_documents: list[dict],
+        user: str | None = None,
+        input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
+    ) -> EmbeddingResult:
+        """
+        Invoke large language model
+
+        :param multimodel_documents: multimodel documents to embed
+        :param user: unique user id
+        :param input_type: input type
+        :return: embeddings result
+        """
+        if not isinstance(self.model_type_instance, TextEmbeddingModel):
+            raise Exception("Model type instance is not TextEmbeddingModel")
+        return cast(
+            EmbeddingResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke,
+                model=self.model,
+                credentials=self.credentials,
+                multimodel_documents=multimodel_documents,
+                user=user,
+                input_type=input_type,
+            ),
+        )
+
     def get_text_embedding_num_tokens(self, texts: list[str]) -> list[int]:
         """
         Get number of tokens for text embedding
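
A hedged usage sketch for the new method. The diff does not define the schema of each entry in multimodel_documents, so the dict shapes below are purely illustrative:

    # Illustrative payload only; the per-document schema is not specified in this diff.
    docs = [
        {"type": "text", "content": "a red bicycle"},
        {"type": "image", "content": "https://example.com/bike.png"},  # hypothetical shape
    ]
    result = model_instance.invoke_multimodal_embedding(
        multimodel_documents=docs,
        input_type=EmbeddingInputType.DOCUMENT,
    )

Note that this path dispatches to the same TextEmbeddingModel invoke entry point as invoke_text_embedding; only the keyword arguments differ.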
@@ -276,6 +304,40 @@ class ModelInstance:
             ),
         )
 
+    def invoke_multimodal_rerank(
+        self,
+        query: dict,
+        docs: list[dict],
+        score_threshold: float | None = None,
+        top_n: int | None = None,
+        user: str | None = None,
+    ) -> RerankResult:
+        """
+        Invoke rerank model
+
+        :param query: search query
+        :param docs: docs for reranking
+        :param score_threshold: score threshold
+        :param top_n: top n
+        :param user: unique user id
+        :return: rerank result
+        """
+        if not isinstance(self.model_type_instance, RerankModel):
+            raise Exception("Model type instance is not RerankModel")
+        return cast(
+            RerankResult,
+            self._round_robin_invoke(
+                function=self.model_type_instance.invoke_multimodal_rerank,
+                model=self.model,
+                credentials=self.credentials,
+                query=query,
+                docs=docs,
+                score_threshold=score_threshold,
+                top_n=top_n,
+                user=user,
+            ),
+        )
+
     def invoke_moderation(self, text: str, user: str | None = None) -> bool:
         """
         Invoke moderation model
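
As with the embedding variant, a hedged sketch; the query and doc dict shapes are assumptions, not part of this diff:

    # Illustrative only: the query/doc dict shapes are hypothetical.
    ranked = model_instance.invoke_multimodal_rerank(
        query={"type": "text", "content": "mountain bike"},
        docs=[{"type": "image", "content": "https://example.com/bike.png"}],
        score_threshold=0.5,
        top_n=3,
    )

Unlike invoke_multimodal_embedding, this routes to a dedicated invoke_multimodal_rerank method on the model type instance rather than reusing the text entry point.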
@@ -461,6 +523,32 @@ class ModelManager:
             model=default_model_entity.model,
         )
 
+    def check_model_support_vision(self, tenant_id: str, provider: str, model: str, model_type: ModelType) -> bool:
+        """
+        Check if model supports vision
+        :param tenant_id: tenant id
+        :param provider: provider name
+        :param model: model name
+        :return: True if model supports vision, False otherwise
+        """
+        model_instance = self.get_model_instance(tenant_id, provider, model_type, model)
+        model_type_instance = model_instance.model_type_instance
+        match model_type:
+            case ModelType.LLM:
+                model_type_instance = cast(LargeLanguageModel, model_type_instance)
+            case ModelType.TEXT_EMBEDDING:
+                model_type_instance = cast(TextEmbeddingModel, model_type_instance)
+            case ModelType.RERANK:
+                model_type_instance = cast(RerankModel, model_type_instance)
+            case _:
+                raise ValueError(f"Model type {model_type} is not supported")
+        model_schema = model_type_instance.get_model_schema(model, model_instance.credentials)
+        if not model_schema:
+            return False
+        if model_schema.features and ModelFeature.VISION in model_schema.features:
+            return True
+        return False
+
 
 class LBModelManager:
     def __init__(
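
A sketch of how a caller might gate image inputs on this check; the tenant, provider, and model names below are placeholders:

    # Placeholder identifiers; ModelType.LLM is one of the supported match arms above.
    manager = ModelManager()
    if manager.check_model_support_vision(
        tenant_id="my-tenant-id",
        provider="openai",
        model="gpt-4o",
        model_type=ModelType.LLM,
    ):
        ...  # safe to attach image content to the request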