fix: hf hosted inference check (#1128)
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain import HuggingFaceHub
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
@@ -9,6 +8,7 @@ from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.huggingface_endpoint_llm import HuggingFaceEndpointLLM
|
||||
from core.third_party.langchain.llms.huggingface_hub_llm import HuggingFaceHubLLM
|
||||
|
||||
|
||||
class HuggingfaceHubModel(BaseLLM):
|
||||
@@ -31,7 +31,7 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
streaming=streaming
|
||||
)
|
||||
else:
|
||||
client = HuggingFaceHub(
|
||||
client = HuggingFaceHubLLM(
|
||||
repo_id=self.name,
|
||||
task=self.credentials['task_type'],
|
||||
model_kwargs=provider_model_kwargs,
|
||||
@@ -88,4 +88,6 @@ class HuggingfaceHubModel(BaseLLM):
|
||||
if 'baichuan' in self.name.lower():
|
||||
return False
|
||||
|
||||
return True
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -89,7 +89,8 @@ class HuggingfaceHubProvider(BaseModelProvider):
|
||||
raise CredentialsValidateFailedError('Task Type must be provided.')
|
||||
|
||||
if credentials['task_type'] not in ("text2text-generation", "text-generation", "summarization"):
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, text-generation, summarization.')
|
||||
raise CredentialsValidateFailedError('Task Type must be one of text2text-generation, '
|
||||
'text-generation, summarization.')
|
||||
|
||||
try:
|
||||
llm = HuggingFaceEndpointLLM(
|
||||
|
||||
Reference in New Issue
Block a user