feat: add jina embedding (#1647)

Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
zxhlyh
2023-11-29 14:58:11 +08:00
committed by GitHub
parent 454577c6b1
commit 451af66be0
22 changed files with 662 additions and 4 deletions

View File

@@ -75,6 +75,9 @@ class ModelProviderFactory:
elif provider_name == 'cohere':
from core.model_providers.providers.cohere_provider import CohereProvider
return CohereProvider
elif provider_name == 'jina':
from core.model_providers.providers.jina_provider import JinaProvider
return JinaProvider
else:
raise NotImplementedError

View File

@@ -0,0 +1,25 @@
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.embedding.base import BaseEmbedding
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
class JinaEmbedding(BaseEmbedding):
    """Embedding model that delegates to the ``JinaEmbeddings`` client."""

    def __init__(self, model_provider: BaseModelProvider, name: str):
        """
        Build the Jina client from the provider's credentials.

        :param model_provider: provider queried for the model credentials
        :param name: model name, forwarded to the client as ``model``
        """
        client_kwargs = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        # The credentials dict is expanded straight into the client constructor.
        jina_client = JinaEmbeddings(model=name, **client_kwargs)

        super().__init__(model_provider, jina_client, name)

    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Map a client exception onto the provider error type.

        :param ex: exception raised while using the client
        :return: ``LLMBadRequestError`` when ``ex`` is a ``ValueError``,
            otherwise ``ex`` itself, unchanged
        """
        if not isinstance(ex, ValueError):
            return ex
        return LLMBadRequestError(f"Jina: {str(ex)}")

View File

@@ -0,0 +1,141 @@
import json
from json import JSONDecodeError
from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.jina_embedding import JinaEmbedding
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.embeddings.jina_embedding import JinaEmbeddings
from models.provider import ProviderType
class JinaProvider(BaseModelProvider):
    """Model provider for Jina, exposing a fixed list of embedding models."""

    @property
    def provider_name(self):
        """
        Returns the name of a provider.
        """
        return 'jina'

    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
        """
        Returns the built-in models for the given model type.

        :param model_type:
        :return: the two Jina embedding models for ``ModelType.EMBEDDINGS``,
            an empty list for every other type
        """
        if model_type == ModelType.EMBEDDINGS:
            return [
                {
                    'id': 'jina-embeddings-v2-base-en',
                    'name': 'jina-embeddings-v2-base-en',
                },
                {
                    'id': 'jina-embeddings-v2-small-en',
                    'name': 'jina-embeddings-v2-small-en',
                }
            ]
        else:
            return []

    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
        """
        Returns the model class.

        :param model_type:
        :return:
        :raises NotImplementedError: for any type other than embeddings
        """
        if model_type == ModelType.EMBEDDINGS:
            model_class = JinaEmbedding
        else:
            raise NotImplementedError

        return model_class

    @classmethod
    def is_provider_credentials_valid_or_raise(cls, credentials: dict):
        """
        Validates the given credentials.

        The key is checked by embedding a short probe text with it.

        :param credentials: dict expected to carry ``api_key``
        :raises CredentialsValidateFailedError: when the key is missing or
            empty, or when the probe request fails for any reason
        """
        # Reject a missing or empty key up front instead of spending a
        # network round trip on a request that cannot succeed.
        if not credentials.get('api_key'):
            raise CredentialsValidateFailedError('Jina API Key must be provided.')

        try:
            credential_kwargs = {
                'api_key': credentials['api_key'],
            }

            embedding = JinaEmbeddings(
                model='jina-embeddings-v2-small-en',
                **credential_kwargs
            )

            embedding.embed_query("ping")
        except Exception as ex:
            # Chain the cause so the original traceback is not lost.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @classmethod
    def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
        """
        Encrypts the API key for storage.

        Note: ``credentials`` is modified in place and also returned.

        :param tenant_id:
        :param credentials: dict carrying the plain ``api_key``
        :return: the same dict with ``api_key`` replaced by its encrypted form
        """
        credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
        return credentials

    def get_provider_credentials(self, obfuscated: bool = False) -> dict:
        """
        Returns the stored provider credentials with the API key decrypted.

        :param obfuscated: when true, the decrypted key is passed through
            ``encrypter.obfuscated_token`` before being returned
        :return: ``{}`` for non-custom providers; otherwise the stored config
            with ``api_key`` set (``None`` when no usable key is stored)
        """
        if self.provider.provider_type != ProviderType.CUSTOM.value:
            return {}

        try:
            credentials = json.loads(self.provider.encrypted_config)
        except (JSONDecodeError, TypeError):
            # json.loads raises TypeError (not JSONDecodeError) for non-string
            # input, e.g. when encrypted_config is None.
            credentials = None

        # Valid JSON that is not an object (e.g. "null") carries no key either.
        if not isinstance(credentials, dict):
            credentials = {
                'api_key': None,
            }

        api_key = credentials.get('api_key')
        if api_key:
            api_key = encrypter.decrypt_token(
                self.provider.tenant_id,
                api_key
            )

            # Only a decrypted key is obfuscated; with no key stored there
            # is nothing to mask.
            if obfuscated:
                api_key = encrypter.obfuscated_token(api_key)

        credentials['api_key'] = api_key
        return credentials

    @classmethod
    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
        """
        check model credentials valid.

        This implementation performs no check and always returns ``None``.

        :param model_name:
        :param model_type:
        :param credentials:
        """
        return

    @classmethod
    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
                                  credentials: dict) -> dict:
        """
        encrypt model credentials for save.

        This implementation always returns an empty dict.

        :param tenant_id:
        :param model_name:
        :param model_type:
        :param credentials:
        :return:
        """
        return {}

    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
        """
        get credentials for llm use.

        The provider-level credentials are returned for every model.

        :param model_name:
        :param model_type:
        :param obfuscated:
        :return:
        """
        return self.get_provider_credentials(obfuscated)

    def _get_text_generation_model_mode(self, model_name) -> str:
        """Not implemented for this provider."""
        raise NotImplementedError

    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
        """Not implemented for this provider."""
        raise NotImplementedError

View File

@@ -14,5 +14,6 @@
"xinference",
"openllm",
"localai",
"cohere"
"cohere",
"jina"
]

View File

@@ -0,0 +1,10 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed",
"supported_model_types": [
"embeddings"
]
}

View File

@@ -0,0 +1,69 @@
"""Wrapper around Jina embedding models."""
from typing import Any, List
import requests
from pydantic import BaseModel, Extra
from langchain.embeddings.base import Embeddings
class JinaEmbeddings(BaseModel, Embeddings):
    """Wrapper around Jina embedding models.

    Each text is embedded with its own POST request to the Jina
    embeddings endpoint, authenticated with ``api_key``.
    """

    client: Any  #: :meta private:
    api_key: str
    model: str

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to Jina's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        return [
            [float(value) for value in self.invoke_embedding(text=text)]
            for text in texts
        ]

    def invoke_embedding(self, text):
        """Request the embedding of a single text.

        Args:
            text: The text to embed.

        Returns:
            The ``embedding`` entry of the first item in the response's
            ``data`` list.

        Raises:
            ValueError: If the HTTP status is not OK, or the response body
                does not have the expected ``data[0].embedding`` shape.
        """
        params = {
            "model": self.model,
            "input": [
                text
            ]
        }

        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
        response = requests.post(
            'https://api.jina.ai/v1/embeddings',
            headers=headers,
            json=params,
            # (connect, read) timeouts in seconds. requests applies no
            # timeout by default, so a stalled connection would otherwise
            # block the caller indefinitely.
            timeout=(10, 60)
        )

        if not response.ok:
            raise ValueError(f"Jina HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        try:
            return json_response["data"][0]["embedding"]
        except (KeyError, IndexError, TypeError) as ex:
            # Report a malformed success body with the same exception type
            # as HTTP errors instead of leaking a bare lookup error.
            raise ValueError(f"Jina returned an unexpected response: {response.text}") from ex

    def embed_query(self, text: str) -> List[float]:
        """Call out to Jina's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]