feat: use xinference client instead of xinference (#1339)
This commit is contained in:
@@ -1,16 +1,53 @@
|
||||
from typing import Optional, List, Any, Union, Generator
|
||||
from typing import Optional, List, Any, Union, Generator, Mapping
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms import Xinference
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from xinference.client import (
|
||||
from xinference_client.client.restful.restful_client import (
|
||||
RESTfulChatglmCppChatModelHandle,
|
||||
RESTfulChatModelHandle,
|
||||
RESTfulGenerateModelHandle,
|
||||
RESTfulGenerateModelHandle, Client,
|
||||
)
|
||||
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
class XinferenceLLM(LLM):
|
||||
client: Any
|
||||
server_url: Optional[str]
|
||||
"""URL of the xinference server"""
|
||||
model_uid: Optional[str]
|
||||
"""UID of the launched model"""
|
||||
|
||||
def __init__(
    self, server_url: Optional[str] = None, model_uid: Optional[str] = None
):
    """Initialise the wrapper and connect a client to the server.

    Args:
        server_url: URL of the xinference server.
        model_uid: UID of the launched model.

    Raises:
        ValueError: If ``server_url`` or ``model_uid`` is ``None`` after
            the parent initialiser has populated the fields.
    """
    # Let the parent initialiser populate the declared fields first so the
    # checks below read them back from the instance.
    super().__init__(server_url=server_url, model_uid=model_uid)

    if self.server_url is None:
        raise ValueError("Please provide server URL")
    if self.model_uid is None:
        raise ValueError("Please provide the model UID")

    self.client = Client(server_url)
|
||||
|
||||
@property
def _llm_type(self) -> str:
    """Type tag identifying this LLM implementation."""
    llm_type = "xinference"
    return llm_type
|
||||
|
||||
@property
def _identifying_params(self) -> Mapping[str, Any]:
    """Parameters that identify this instance: server URL and model UID."""
    params = {
        "server_url": self.server_url,
        "model_uid": self.model_uid,
    }
    return params
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
Reference in New Issue
Block a user