feat: use xinference client instead of xinference (#1339)

This commit is contained in:
takatost
2023-10-13 15:46:09 +08:00
committed by GitHub
parent 9822f687f7
commit 3efaa713da
5 changed files with 83 additions and 14 deletions

View File

@@ -1,16 +1,53 @@
from typing import Optional, List, Any, Union, Generator
from typing import Optional, List, Any, Union, Generator, Mapping
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import Xinference
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from xinference.client import (
from xinference_client.client.restful.restful_client import (
RESTfulChatglmCppChatModelHandle,
RESTfulChatModelHandle,
RESTfulGenerateModelHandle,
RESTfulGenerateModelHandle, Client,
)
class XinferenceLLM(Xinference):
class XinferenceLLM(LLM):
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""
def __init__(
self, server_url: Optional[str] = None, model_uid: Optional[str] = None
):
super().__init__(
**{
"server_url": server_url,
"model_uid": model_uid,
}
)
if self.server_url is None:
raise ValueError("Please provide server URL")
if self.model_uid is None:
raise ValueError("Please provide the model UID")
self.client = Client(server_url)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "xinference"
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {
**{"server_url": self.server_url},
**{"model_uid": self.model_uid},
}
def _call(
self,
prompt: str,