Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: Garfield Dai <dai.hai@foxmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn>
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
import anthropic
|
||||
from anthropic import Anthropic
|
||||
from anthropic.resources.completions import Completions
|
||||
from anthropic.types import completion_create_params, Completion
|
||||
from anthropic._types import NOT_GIVEN, NotGiven, Headers, Query, Body
|
||||
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from typing import List, Union, Literal, Any, Generator
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
|
||||
|
||||
class MockAnthropicClass(object):
|
||||
@staticmethod
|
||||
def mocked_anthropic_chat_create_sync(model: str) -> Completion:
|
||||
return Completion(
|
||||
completion='hello, I\'m a chatbot from anthropic',
|
||||
model=model,
|
||||
stop_reason='stop_sequence'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_anthropic_chat_create_stream(model: str) -> Generator[Completion, None, None]:
|
||||
full_response_text = "hello, I'm a chatbot from anthropic"
|
||||
|
||||
for i in range(0, len(full_response_text) + 1):
|
||||
sleep(0.1)
|
||||
if i == len(full_response_text):
|
||||
yield Completion(
|
||||
completion='',
|
||||
model=model,
|
||||
stop_reason='stop_sequence'
|
||||
)
|
||||
else:
|
||||
yield Completion(
|
||||
completion=full_response_text[i],
|
||||
model=model,
|
||||
stop_reason=''
|
||||
)
|
||||
|
||||
def mocked_anthropic(self: Completions, *,
|
||||
max_tokens_to_sample: int,
|
||||
model: Union[str, Literal["claude-2.1", "claude-instant-1"]],
|
||||
prompt: str,
|
||||
stream: Literal[True],
|
||||
**kwargs: Any
|
||||
) -> Union[Completion, Generator[Completion, None, None]]:
|
||||
if len(self._client.api_key) < 18:
|
||||
raise anthropic.AuthenticationError('Invalid API key')
|
||||
|
||||
if stream:
|
||||
return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
|
||||
else:
|
||||
return MockAnthropicClass.mocked_anthropic_chat_create_sync(model=model)
|
||||
|
||||
@pytest.fixture
|
||||
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Completions, 'create', MockAnthropicClass.mocked_anthropic)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
127
api/tests/integration_tests/model_runtime/__mock/google.py
Normal file
127
api/tests/integration_tests/model_runtime/__mock/google.py
Normal file
@@ -0,0 +1,127 @@
|
||||
from google.generativeai import GenerativeModel
|
||||
from google.generativeai.types import GenerateContentResponse
|
||||
from google.generativeai.types.generation_types import BaseGenerateContentResponse
|
||||
import google.generativeai.types.generation_types as generation_config_types
|
||||
import google.generativeai.types.content_types as content_types
|
||||
import google.generativeai.types.safety_types as safety_types
|
||||
from google.generativeai.client import _ClientManager, configure
|
||||
|
||||
from google.ai import generativelanguage as glm
|
||||
|
||||
from typing import Generator, List
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import pytest
|
||||
|
||||
current_api_key = ''
|
||||
|
||||
class MockGoogleResponseClass(object):
|
||||
_done = False
|
||||
|
||||
def __iter__(self):
|
||||
full_response_text = 'it\'s google!'
|
||||
|
||||
for i in range(0, len(full_response_text) + 1, 1):
|
||||
if i == len(full_response_text):
|
||||
self._done = True
|
||||
yield GenerateContentResponse(
|
||||
done=True,
|
||||
iterator=None,
|
||||
result=glm.GenerateContentResponse({
|
||||
|
||||
}),
|
||||
chunks=[]
|
||||
)
|
||||
else:
|
||||
yield GenerateContentResponse(
|
||||
done=False,
|
||||
iterator=None,
|
||||
result=glm.GenerateContentResponse({
|
||||
|
||||
}),
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
class MockGoogleResponseCandidateClass(object):
|
||||
finish_reason = 'stop'
|
||||
|
||||
class MockGoogleClass(object):
|
||||
@staticmethod
|
||||
def generate_content_sync() -> GenerateContentResponse:
|
||||
return GenerateContentResponse(
|
||||
done=True,
|
||||
iterator=None,
|
||||
result=glm.GenerateContentResponse({
|
||||
|
||||
}),
|
||||
chunks=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
|
||||
return MockGoogleResponseClass()
|
||||
|
||||
def generate_content(self: GenerativeModel,
|
||||
contents: content_types.ContentsType,
|
||||
*,
|
||||
generation_config: generation_config_types.GenerationConfigType | None = None,
|
||||
safety_settings: safety_types.SafetySettingOptions | None = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> GenerateContentResponse:
|
||||
global current_api_key
|
||||
|
||||
if len(current_api_key) < 16:
|
||||
raise Exception('Invalid API key')
|
||||
|
||||
if stream:
|
||||
return MockGoogleClass.generate_content_stream()
|
||||
|
||||
return MockGoogleClass.generate_content_sync()
|
||||
|
||||
@property
|
||||
def generative_response_text(self) -> str:
|
||||
return 'it\'s google!'
|
||||
|
||||
@property
|
||||
def generative_response_candidates(self) -> List[MockGoogleResponseCandidateClass]:
|
||||
return [MockGoogleResponseCandidateClass()]
|
||||
|
||||
def make_client(self: _ClientManager, name: str):
|
||||
global current_api_key
|
||||
|
||||
if name.endswith("_async"):
|
||||
name = name.split("_")[0]
|
||||
cls = getattr(glm, name.title() + "ServiceAsyncClient")
|
||||
else:
|
||||
cls = getattr(glm, name.title() + "ServiceClient")
|
||||
|
||||
# Attempt to configure using defaults.
|
||||
if not self.client_config:
|
||||
configure()
|
||||
|
||||
client_options = self.client_config.get("client_options", None)
|
||||
if client_options:
|
||||
current_api_key = client_options.api_key
|
||||
|
||||
def nop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
original_init = cls.__init__
|
||||
cls.__init__ = nop
|
||||
client: glm.GenerativeServiceClient = cls(**self.client_config)
|
||||
cls.__init__ = original_init
|
||||
|
||||
if not self.default_metadata:
|
||||
return client
|
||||
|
||||
@pytest.fixture
|
||||
def setup_google_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
|
||||
monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
|
||||
monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
|
||||
monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
|
||||
|
||||
yield
|
||||
|
||||
monkeypatch.undo()
|
||||
@@ -0,0 +1,21 @@
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
|
||||
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from typing import List, Dict, Any
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
@pytest.fixture
|
||||
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
@@ -0,0 +1,54 @@
|
||||
from huggingface_hub import InferenceClient
|
||||
from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse, Details, StreamDetails, Token
|
||||
from huggingface_hub.utils import BadRequestError
|
||||
|
||||
from typing import Literal, Optional, List, Generator, Union, Any
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
import re
|
||||
|
||||
class MockHuggingfaceChatClass(object):
|
||||
@staticmethod
|
||||
def generate_create_sync(model: str) -> TextGenerationResponse:
|
||||
response = TextGenerationResponse(
|
||||
generated_text="You can call me Miku Miku o~e~o~",
|
||||
details=Details(
|
||||
finish_reason="length",
|
||||
generated_tokens=6,
|
||||
tokens=[
|
||||
Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
|
||||
full_text = "You can call me Miku Miku o~e~o~"
|
||||
|
||||
for i in range(0, len(full_text)):
|
||||
response = TextGenerationStreamResponse(
|
||||
token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
|
||||
)
|
||||
response.generated_text = full_text[i]
|
||||
response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
|
||||
|
||||
yield response
|
||||
|
||||
def text_generation(self: InferenceClient, prompt: str, *,
|
||||
stream: Literal[False] = ...,
|
||||
model: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
|
||||
# check if key is valid
|
||||
if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
|
||||
raise BadRequestError('Invalid API key')
|
||||
|
||||
if model is None:
|
||||
raise BadRequestError('Invalid model')
|
||||
|
||||
if stream:
|
||||
return MockHuggingfaceChatClass.generate_create_stream(model)
|
||||
return MockHuggingfaceChatClass.generate_create_sync(model)
|
||||
|
||||
63
api/tests/integration_tests/model_runtime/__mock/openai.py
Normal file
63
api/tests/integration_tests/model_runtime/__mock/openai.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
|
||||
from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
|
||||
from openai.resources.completions import Completions
|
||||
from openai.resources.chat import Completions as ChatCompletions
|
||||
from openai.resources.models import Models
|
||||
from openai.resources.moderations import Moderations
|
||||
from openai.resources.audio.transcriptions import Transcriptions
|
||||
from openai.resources.embeddings import Embeddings
|
||||
|
||||
# import monkeypatch
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from typing import Literal, Callable, List
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
def mock_openai(monkeypatch: MonkeyPatch, methods: List[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
|
||||
"""
|
||||
mock openai module
|
||||
|
||||
:param monkeypatch: pytest monkeypatch fixture
|
||||
:return: unpatch function
|
||||
"""
|
||||
def unpatch() -> None:
|
||||
monkeypatch.undo()
|
||||
|
||||
if "completion" in methods:
|
||||
monkeypatch.setattr(Completions, "create", MockCompletionsClass.completion_create)
|
||||
|
||||
if "chat" in methods:
|
||||
monkeypatch.setattr(ChatCompletions, "create", MockChatClass.chat_create)
|
||||
|
||||
if "remote" in methods:
|
||||
monkeypatch.setattr(Models, "list", MockModelClass.list)
|
||||
|
||||
if "moderation" in methods:
|
||||
monkeypatch.setattr(Moderations, "create", MockModerationClass.moderation_create)
|
||||
|
||||
if "speech2text" in methods:
|
||||
monkeypatch.setattr(Transcriptions, "create", MockSpeech2TextClass.speech2text_create)
|
||||
|
||||
if "text_embedding" in methods:
|
||||
monkeypatch.setattr(Embeddings, "create", MockEmbeddingsClass.create_embeddings)
|
||||
|
||||
return unpatch
|
||||
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
@pytest.fixture
|
||||
def setup_openai_mock(request, monkeypatch):
|
||||
methods = request.param if hasattr(request, 'param') else []
|
||||
if MOCK:
|
||||
unpatch = mock_openai(monkeypatch, methods=methods)
|
||||
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
unpatch()
|
||||
235
api/tests/integration_tests/model_runtime/__mock/openai_chat.py
Normal file
235
api/tests/integration_tests/model_runtime/__mock/openai_chat.py
Normal file
@@ -0,0 +1,235 @@
|
||||
from openai import OpenAI
|
||||
from openai.types import Completion as CompletionMessage
|
||||
from openai._types import NotGiven, NOT_GIVEN
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, \
|
||||
ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, ChatCompletionMessageToolCall
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall, ChoiceDeltaFunctionCall,\
|
||||
Choice, ChoiceDelta, ChoiceDeltaToolCallFunction
|
||||
from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice, ChatCompletion as _ChatCompletion
|
||||
from openai.types.chat.chat_completion_message import FunctionCall, ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from openai.resources.chat.completions import Completions
|
||||
from openai import AzureOpenAI
|
||||
|
||||
import openai.types.chat.completion_create_params as completion_create_params
|
||||
|
||||
# import monkeypatch
|
||||
from typing import List, Any, Generator, Union, Optional, Literal
|
||||
from time import time, sleep
|
||||
from json import dumps, loads
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
import re
|
||||
|
||||
class MockChatClass(object):
|
||||
@staticmethod
|
||||
def generate_function_call(
|
||||
functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
) -> Optional[FunctionCall]:
|
||||
if not functions or len(functions) == 0:
|
||||
return None
|
||||
function: completion_create_params.Function = functions[0]
|
||||
function_name = function['name']
|
||||
function_description = function['description']
|
||||
function_parameters = function['parameters']
|
||||
function_parameters_type = function_parameters['type']
|
||||
if function_parameters_type != 'object':
|
||||
return None
|
||||
function_parameters_properties = function_parameters['properties']
|
||||
function_parameters_required = function_parameters['required']
|
||||
parameters = {}
|
||||
for parameter_name, parameter in function_parameters_properties.items():
|
||||
if parameter_name not in function_parameters_required:
|
||||
continue
|
||||
parameter_type = parameter['type']
|
||||
if parameter_type == 'string':
|
||||
if 'enum' in parameter:
|
||||
if len(parameter['enum']) == 0:
|
||||
continue
|
||||
parameters[parameter_name] = parameter['enum'][0]
|
||||
else:
|
||||
parameters[parameter_name] = 'kawaii'
|
||||
elif parameter_type == 'integer':
|
||||
parameters[parameter_name] = 114514
|
||||
elif parameter_type == 'number':
|
||||
parameters[parameter_name] = 1919810.0
|
||||
elif parameter_type == 'boolean':
|
||||
parameters[parameter_name] = True
|
||||
|
||||
return FunctionCall(name=function_name, arguments=dumps(parameters))
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_calls(
|
||||
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> Optional[List[ChatCompletionMessageToolCall]]:
|
||||
list_tool_calls = []
|
||||
if not tools or len(tools) == 0:
|
||||
return None
|
||||
tool: ChatCompletionToolParam = tools[0]
|
||||
|
||||
if tools['type'] != 'function':
|
||||
return None
|
||||
|
||||
function = tool['function']
|
||||
|
||||
function_call = MockChatClass.generate_function_call(functions=[function])
|
||||
if function_call is None:
|
||||
return None
|
||||
|
||||
list_tool_calls.append(ChatCompletionMessageToolCall(
|
||||
id='sakurajima-mai',
|
||||
function=Function(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
),
|
||||
type='function'
|
||||
))
|
||||
|
||||
return list_tool_calls
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_chat_create_sync(
|
||||
model: str,
|
||||
functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> CompletionMessage:
|
||||
tool_calls = []
|
||||
function_call = MockChatClass.generate_function_call(functions=functions)
|
||||
if not function_call:
|
||||
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
|
||||
|
||||
sleep(1)
|
||||
return _ChatCompletion(
|
||||
id='cmpl-3QJQa5jXJ5Z5X',
|
||||
choices=[
|
||||
_ChatCompletionChoice(
|
||||
finish_reason='content_filter',
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
content='elaina',
|
||||
role='assistant',
|
||||
function_call=function_call,
|
||||
tool_calls=tool_calls
|
||||
)
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object='chat.completion',
|
||||
system_fingerprint='',
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_chat_create_stream(
|
||||
model: str,
|
||||
functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
) -> Generator[ChatCompletionChunk, None, None]:
|
||||
tool_calls = []
|
||||
function_call = MockChatClass.generate_function_call(functions=functions)
|
||||
if not function_call:
|
||||
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
|
||||
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
sleep(0.1)
|
||||
if i == len(full_text):
|
||||
yield ChatCompletionChunk(
|
||||
id='cmpl-3QJQa5jXJ5Z5X',
|
||||
choices=[
|
||||
Choice(
|
||||
delta=ChoiceDelta(
|
||||
content='',
|
||||
function_call=ChoiceDeltaFunctionCall(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
) if function_call else None,
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id='misaka-mikoto',
|
||||
function=ChoiceDeltaToolCallFunction(
|
||||
name=tool_calls[0].function.name,
|
||||
arguments=tool_calls[0].function.arguments,
|
||||
),
|
||||
type='function'
|
||||
)
|
||||
] if tool_calls and len(tool_calls) > 0 else None
|
||||
),
|
||||
finish_reason='function_call',
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object='chat.completion.chunk',
|
||||
system_fingerprint='',
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield ChatCompletionChunk(
|
||||
id='cmpl-3QJQa5jXJ5Z5X',
|
||||
choices=[
|
||||
Choice(
|
||||
delta=ChoiceDelta(
|
||||
content=full_text[i],
|
||||
role='assistant',
|
||||
),
|
||||
finish_reason='content_filter',
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
created=int(time()),
|
||||
model=model,
|
||||
object='chat.completion.chunk',
|
||||
system_fingerprint='',
|
||||
)
|
||||
|
||||
def chat_create(self: Completions, *,
|
||||
messages: List[ChatCompletionMessageParam],
|
||||
model: Union[str,Literal[
|
||||
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
|
||||
],
|
||||
functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
|
||||
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
|
||||
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
|
||||
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any,
|
||||
):
|
||||
openai_models = [
|
||||
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
|
||||
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
|
||||
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
|
||||
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
|
||||
]
|
||||
azure_openai_models = [
|
||||
"gpt35", "gpt-4v", "gpt-35-turbo"
|
||||
]
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
|
||||
raise InvokeAuthorizationError('Invalid base url')
|
||||
if model in openai_models + azure_openai_models:
|
||||
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
|
||||
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
|
||||
# so we only check if model is in openai_models
|
||||
raise InvokeAuthorizationError('Invalid api key')
|
||||
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
|
||||
raise InvokeAuthorizationError('Invalid api key')
|
||||
if stream:
|
||||
return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
|
||||
|
||||
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
|
||||
@@ -0,0 +1,121 @@
|
||||
from openai import BadRequestError, OpenAI, AzureOpenAI
|
||||
from openai.types import Completion as CompletionMessage
|
||||
from openai._types import NotGiven, NOT_GIVEN
|
||||
from openai.types.completion import CompletionChoice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from openai.resources.completions import Completions
|
||||
|
||||
# import monkeypatch
|
||||
from typing import List, Any, Generator, Union, Optional, Literal
|
||||
from time import time, sleep
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
import re
|
||||
|
||||
class MockCompletionsClass(object):
|
||||
@staticmethod
|
||||
def mocked_openai_completion_create_sync(
|
||||
model: str
|
||||
) -> CompletionMessage:
|
||||
sleep(1)
|
||||
return CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text="mock",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=1,
|
||||
total_tokens=3,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mocked_openai_completion_create_stream(
|
||||
model: str
|
||||
) -> Generator[CompletionMessage, None, None]:
|
||||
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
|
||||
for i in range(0, len(full_text) + 1):
|
||||
sleep(0.1)
|
||||
if i == len(full_text):
|
||||
yield CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text="",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(
|
||||
prompt_tokens=2,
|
||||
completion_tokens=17,
|
||||
total_tokens=19,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield CompletionMessage(
|
||||
id="cmpl-3QJQa5jXJ5Z5X",
|
||||
object="text_completion",
|
||||
created=int(time()),
|
||||
model=model,
|
||||
system_fingerprint="",
|
||||
choices=[
|
||||
CompletionChoice(
|
||||
text=full_text[i],
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason="content_filter"
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def completion_create(self: Completions, *, model: Union[
|
||||
str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
|
||||
"text-davinci-003", "text-davinci-002", "text-davinci-001",
|
||||
"code-davinci-002", "text-curie-001", "text-babbage-001",
|
||||
"text-ada-001"],
|
||||
],
|
||||
prompt: Union[str, List[str], List[int], List[List[int]], None],
|
||||
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any
|
||||
):
|
||||
openai_models = [
|
||||
"babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
|
||||
"code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
|
||||
]
|
||||
azure_openai_models = [
|
||||
"gpt-35-turbo-instruct"
|
||||
]
|
||||
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
|
||||
raise InvokeAuthorizationError('Invalid base url')
|
||||
if model in openai_models + azure_openai_models:
|
||||
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
|
||||
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
|
||||
# so we only check if model is in openai_models
|
||||
raise InvokeAuthorizationError('Invalid api key')
|
||||
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
|
||||
raise InvokeAuthorizationError('Invalid api key')
|
||||
|
||||
if not prompt:
|
||||
raise BadRequestError('Invalid prompt')
|
||||
if stream:
|
||||
return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
|
||||
|
||||
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,67 @@
|
||||
from openai.resources.moderations import Moderations
|
||||
from openai.types import ModerationCreateResponse
|
||||
from openai.types.moderation import Moderation, Categories, CategoryScores
|
||||
from openai._types import NotGiven, NOT_GIVEN
|
||||
|
||||
from typing import Union, List, Literal, Any
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
import re
|
||||
|
||||
class MockModerationClass(object):
|
||||
def moderation_create(self: Moderations,*,
|
||||
input: Union[str, List[str]],
|
||||
model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any
|
||||
) -> ModerationCreateResponse:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
|
||||
raise InvokeAuthorizationError('Invalid base url')
|
||||
|
||||
if len(self._client.api_key) < 18:
|
||||
raise InvokeAuthorizationError('Invalid API key')
|
||||
|
||||
for text in input:
|
||||
result = []
|
||||
if 'kill' in text:
|
||||
moderation_categories = {
|
||||
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
|
||||
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
|
||||
'sexual/minors': False, 'violence': False, 'violence/graphic': False
|
||||
}
|
||||
moderation_categories_scores = {
|
||||
'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
|
||||
'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
|
||||
'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
|
||||
}
|
||||
|
||||
result.append(Moderation(
|
||||
flagged=True,
|
||||
categories=Categories(**moderation_categories),
|
||||
category_scores=CategoryScores(**moderation_categories_scores)
|
||||
))
|
||||
else:
|
||||
moderation_categories = {
|
||||
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
|
||||
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
|
||||
'sexual/minors': False, 'violence': False, 'violence/graphic': False
|
||||
}
|
||||
moderation_categories_scores = {
|
||||
'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
|
||||
'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
|
||||
'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
|
||||
}
|
||||
result.append(Moderation(
|
||||
flagged=False,
|
||||
categories=Categories(**moderation_categories),
|
||||
category_scores=CategoryScores(**moderation_categories_scores)
|
||||
))
|
||||
|
||||
return ModerationCreateResponse(
|
||||
id='shiroii kuloko',
|
||||
model=model,
|
||||
results=result
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from openai.resources.models import Models
|
||||
from openai.types.model import Model
|
||||
|
||||
from typing import List
|
||||
from time import time
|
||||
|
||||
class MockModelClass(object):
|
||||
"""
|
||||
mock class for openai.models.Models
|
||||
"""
|
||||
def list(
|
||||
self,
|
||||
**kwargs,
|
||||
) -> List[Model]:
|
||||
return [
|
||||
Model(
|
||||
id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
|
||||
created=int(time()),
|
||||
object='model',
|
||||
owned_by='organization:org-123',
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
from openai.resources.audio.transcriptions import Transcriptions
|
||||
from openai._types import NotGiven, NOT_GIVEN, FileTypes
|
||||
from openai.types.audio.transcription import Transcription
|
||||
|
||||
from typing import Union, List, Literal, Any
|
||||
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
|
||||
import re
|
||||
|
||||
class MockSpeech2TextClass(object):
|
||||
def speech2text_create(self: Transcriptions,
|
||||
*,
|
||||
file: FileTypes,
|
||||
model: Union[str, Literal["whisper-1"]],
|
||||
language: str | NotGiven = NOT_GIVEN,
|
||||
prompt: str | NotGiven = NOT_GIVEN,
|
||||
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
|
||||
temperature: float | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any
|
||||
) -> Transcription:
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
|
||||
raise InvokeAuthorizationError('Invalid base url')
|
||||
|
||||
if len(self._client.api_key) < 18:
|
||||
raise InvokeAuthorizationError('Invalid API key')
|
||||
|
||||
return Transcription(
|
||||
text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
|
||||
)
|
||||
142
api/tests/integration_tests/model_runtime/__mock/xinference.py
Normal file
142
api/tests/integration_tests/model_runtime/__mock/xinference.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from xinference_client.client.restful.restful_client import Client, \
|
||||
RESTfulChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatglmCppChatModelHandle, \
|
||||
RESTfulEmbeddingModelHandle, RESTfulRerankModelHandle
|
||||
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
|
||||
|
||||
from requests.sessions import Session
|
||||
from requests import Response
|
||||
from requests.exceptions import ConnectionError
|
||||
from typing import Union, List
|
||||
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
import pytest
|
||||
import os
|
||||
import re
|
||||
|
||||
class MockXinferenceClass(object):
|
||||
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
|
||||
if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
|
||||
raise RuntimeError('404 Not Found')
|
||||
|
||||
if 'generate' == model_uid:
|
||||
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url)
|
||||
if 'chat' == model_uid:
|
||||
return RESTfulChatModelHandle(model_uid, base_url=self.base_url)
|
||||
if 'embedding' == model_uid:
|
||||
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url)
|
||||
if 'rerank' == model_uid:
|
||||
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url)
|
||||
raise RuntimeError('404 Not Found')
|
||||
|
||||
def get(self: Session, url: str, **kwargs):
|
||||
if '/v1/models/' in url:
|
||||
response = Response()
|
||||
|
||||
# get model uid
|
||||
model_uid = url.split('/')[-1]
|
||||
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
|
||||
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
|
||||
response.status_code = 404
|
||||
raise ConnectionError('404 Not Found')
|
||||
|
||||
# check if url is valid
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
|
||||
response.status_code = 404
|
||||
raise ConnectionError('404 Not Found')
|
||||
|
||||
response.status_code = 200
|
||||
response._content = b'''{
|
||||
"model_type": "LLM",
|
||||
"address": "127.0.0.1:43877",
|
||||
"accelerators": [
|
||||
"0",
|
||||
"1"
|
||||
],
|
||||
"model_name": "chatglm3-6b",
|
||||
"model_lang": [
|
||||
"en"
|
||||
],
|
||||
"model_ability": [
|
||||
"generate",
|
||||
"chat"
|
||||
],
|
||||
"model_description": "latest chatglm3",
|
||||
"model_format": "pytorch",
|
||||
"model_size_in_billions": 7,
|
||||
"quantization": "none",
|
||||
"model_hub": "huggingface",
|
||||
"revision": null,
|
||||
"context_length": 2048,
|
||||
"replica": 1
|
||||
}'''
|
||||
return response
|
||||
|
||||
def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
|
||||
self._model_uid != 'rerank':
|
||||
raise RuntimeError('404 Not Found')
|
||||
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
|
||||
raise RuntimeError('404 Not Found')
|
||||
|
||||
if top_n is None:
|
||||
top_n = 1
|
||||
|
||||
return {
|
||||
'results': [
|
||||
{
|
||||
'index': i,
|
||||
'document': doc,
|
||||
'relevance_score': 0.9
|
||||
}
|
||||
for i, doc in enumerate(documents[:top_n])
|
||||
]
|
||||
}
|
||||
|
||||
def create_embedding(
|
||||
self: RESTfulGenerateModelHandle,
|
||||
input: Union[str, List[str]],
|
||||
**kwargs
|
||||
) -> dict:
|
||||
# check if self._model_uid is a valid uuid
|
||||
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
|
||||
self._model_uid != 'embedding':
|
||||
raise RuntimeError('404 Not Found')
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
ipt_len = len(input)
|
||||
|
||||
embedding = Embedding(
|
||||
object="list",
|
||||
model=self._model_uid,
|
||||
data=[
|
||||
EmbeddingData(
|
||||
index=i,
|
||||
object="embedding",
|
||||
embedding=[1919.810 for _ in range(768)]
|
||||
)
|
||||
for i in range(ipt_len)
|
||||
],
|
||||
usage=EmbeddingUsage(
|
||||
prompt_tokens=ipt_len,
|
||||
total_tokens=ipt_len
|
||||
)
|
||||
)
|
||||
|
||||
return embedding
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
|
||||
@pytest.fixture
|
||||
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
|
||||
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
|
||||
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
|
||||
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
116
api/tests/integration_tests/model_runtime/anthropic/test_llm.py
Normal file
116
api/tests/integration_tests/model_runtime/anthropic/test_llm.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeLanguageModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
|
||||
def test_validate_credentials(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='claude-instant-1',
|
||||
credentials={
|
||||
'anthropic_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='claude-instant-1',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
|
||||
def test_invoke_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='claude-instant-1',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
|
||||
'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'top_p': 1.0,
|
||||
'max_tokens_to_sample': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
|
||||
def test_invoke_stream_model(setup_anthropic_mock):
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='claude-instant-1',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens_to_sample': 100
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = AnthropicLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='claude-instant-1',
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 18
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
|
||||
def test_validate_provider_credentials(setup_anthropic_mock):
|
||||
provider = AnthropicProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
|
||||
}
|
||||
)
|
||||
BIN
api/tests/integration_tests/model_runtime/assets/audio.mp3
Normal file
BIN
api/tests/integration_tests/model_runtime/assets/audio.mp3
Normal file
Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = AzureOpenAITextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='embedding',
|
||||
credentials={
|
||||
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
|
||||
'openai_api_key': 'invalid_key',
|
||||
'base_model_name': 'text-embedding-ada-002'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='embedding',
|
||||
credentials={
|
||||
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
|
||||
'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
|
||||
'base_model_name': 'text-embedding-ada-002'
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = AzureOpenAITextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='embedding',
|
||||
credentials={
|
||||
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
|
||||
'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
|
||||
'base_model_name': 'text-embedding-ada-002'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = AzureOpenAITextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='embedding',
|
||||
credentials={
|
||||
'base_model_name': 'text-embedding-ada-002'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
190
api/tests/integration_tests/model_runtime/baichuan/test_llm.py
Normal file
190
api/tests/integration_tests/model_runtime/baichuan/test_llm.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
from time import sleep
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.llm.llm import BaichuanLarguageModel
|
||||
|
||||
def test_predefined_models():
|
||||
model = BaichuanLarguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'secret_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_model_with_system_message():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='请记住你是Kasumi。'
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='现在告诉我你是谁?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_stream_model():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_invoke_with_search():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='北京今天的天气怎么样'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
'with_search_enhance': True,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
total_message = ''
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
|
||||
total_message += chunk.delta.message.content
|
||||
|
||||
assert '不' not in total_message
|
||||
|
||||
def test_get_num_tokens():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
response = model.get_num_tokens(
|
||||
model='baichuan2-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[]
|
||||
)
|
||||
|
||||
assert isinstance(response, int)
|
||||
assert response == 9
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = BaichuanProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': 'hahahaha'
|
||||
}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel
|
||||
|
||||
def test_validate_credentials():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='baichuan-text-embedding',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='baichuan-text-embedding',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='baichuan-text-embedding',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 6
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = BaichuanTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='baichuan-text-embedding',
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
287
api/tests/integration_tests/model_runtime/chatglm/test_llm.py
Normal file
287
api/tests/integration_tests/model_runtime/chatglm/test_llm.py
Normal file
@@ -0,0 +1,287 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \
|
||||
SystemPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
def test_predefined_models():
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='chatglm2-6b',
|
||||
credentials={
|
||||
'api_base': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='chatglm2-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chatglm2-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_stream_model(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chatglm2-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_stream_model_with_functions(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chatglm3-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='波士顿天气如何?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user='abc-123',
|
||||
stream=True,
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_current_weather',
|
||||
description='Get the current weather in a given location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
call: LLMResultChunk = None
|
||||
chunks = []
|
||||
|
||||
for chunk in response:
|
||||
chunks.append(chunk)
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
|
||||
call = chunk
|
||||
break
|
||||
|
||||
assert call is not None
|
||||
assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_invoke_model_with_functions(setup_openai_mock):
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chatglm3-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='What is the weather like in San Francisco?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user='abc-123',
|
||||
stream=False,
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_current_weather',
|
||||
description='Get the current weather in a given location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
assert response.message.tool_calls[0].function.name == 'get_current_weather'
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ChatGLMLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='chatglm2-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_current_weather',
|
||||
description='Get the current weather in a given location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 77
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='chatglm2-6b',
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 21
|
||||
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_validate_provider_credentials(setup_openai_mock):
|
||||
provider = ChatGLMProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_base': 'hahahaha'
|
||||
}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_base': os.environ.get('CHATGLM_API_BASE')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,21 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.cohere import CohereProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = CohereProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = CohereRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='rerank-english-v2.0',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='rerank-english-v2.0',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = CohereRerankModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='rerank-english-v2.0',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
query="What is the capital of the United States?",
|
||||
docs=[
|
||||
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||
"Census, Carson City had a population of 55,274.",
|
||||
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
|
||||
"is the capital of the United States. It is a federal district. The President of the USA and many major "
|
||||
"national government offices are in the territory. This makes it the political center of the United "
|
||||
"States of America."
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 1
|
||||
assert result.docs[0].score >= 0.8
|
||||
229
api/tests/integration_tests/model_runtime/google/test_llm.py
Normal file
229
api/tests/integration_tests/model_runtime/google/test_llm.py
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.google.google import GoogleProvider
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
|
||||
def test_validate_provider_credentials(setup_google_mock):
|
||||
provider = GoogleProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'google_api_key': os.environ.get('GOOGLE_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,304 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import UserPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_hub.llm.llm import HuggingfaceHubLargeLanguageModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='HuggingFaceH4/zephyr-7b-beta',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='fake-model',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='HuggingFaceH4/zephyr-7b-beta',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='HuggingFaceH4/zephyr-7b-beta',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='HuggingFaceH4/zephyr-7b-beta',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='openchat/openchat_3.5',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': 'invalid_key',
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text-generation'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='openchat/openchat_3.5',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text-generation'
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='openchat/openchat_3.5',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text-generation'
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='openchat/openchat_3.5',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text-generation'
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='google/mt5-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': 'invalid_key',
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text2text-generation'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='google/mt5-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text2text-generation'
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='google/mt5-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text2text-generation'
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
|
||||
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='google/mt5-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text2text-generation'
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = HuggingfaceHubLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='google/mt5-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
|
||||
'task_type': 'text2text-generation'
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 7
|
||||
@@ -0,0 +1,120 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \
|
||||
HuggingfaceHubTextEmbeddingModel
|
||||
|
||||
|
||||
def test_hosted_inference_api_validate_credentials():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='facebook/bart-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': 'invalid_key',
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='facebook/bart-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_hosted_inference_api_invoke_model():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='facebook/bart-base',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'hosted_inference_api',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_inference_endpoints_validate_credentials():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='all-MiniLM-L6-v2',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': 'invalid_key',
|
||||
'huggingface_namespace': 'Dify-AI',
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
|
||||
'task_type': 'feature-extraction'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='all-MiniLM-L6-v2',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingface_namespace': 'Dify-AI',
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
|
||||
'task_type': 'feature-extraction'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_inference_endpoints_invoke_model():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='all-MiniLM-L6-v2',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingface_namespace': 'Dify-AI',
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
|
||||
'task_type': 'feature-extraction'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = HuggingfaceHubTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='all-MiniLM-L6-v2',
|
||||
credentials={
|
||||
'huggingfacehub_api_type': 'inference_endpoints',
|
||||
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
|
||||
'huggingface_namespace': 'Dify-AI',
|
||||
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
|
||||
'task_type': 'feature-extraction'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.jina.jina import JinaProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = JinaProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': 'hahahaha'
|
||||
}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': os.environ.get('JINA_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = JinaTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='jina-embeddings-v2-base-en',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='jina-embeddings-v2-base-en',
|
||||
credentials={
|
||||
'api_key': os.environ.get('JINA_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = JinaTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='jina-embeddings-v2-base-en',
|
||||
credentials={
|
||||
'api_key': os.environ.get('JINA_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 6
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = JinaTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='jina-embeddings-v2-base-en',
|
||||
credentials={
|
||||
'api_key': os.environ.get('JINA_API_KEY'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 6
|
||||
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
LocalAI Embedding Interface is temporarily unavaliable due to
|
||||
we could not find a way to test it for now.
|
||||
"""
|
||||
213
api/tests/integration_tests/model_runtime/localai/test_llm.py
Normal file
213
api/tests/integration_tests/model_runtime/localai/test_llm.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \
|
||||
SystemPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ParameterRule
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = LocalAILarguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='chinese-llama-2-7b',
|
||||
credentials={
|
||||
'server_url': 'hahahaha',
|
||||
'completion_type': 'completion',
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='chinese-llama-2-7b',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'completion',
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_completion_model():
|
||||
model = LocalAILarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chinese-llama-2-7b',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'completion',
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='ping'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=[],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_chat_model():
|
||||
model = LocalAILarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chinese-llama-2-7b',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'chat_completion',
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='ping'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=[],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_stream_completion_model():
|
||||
model = LocalAILarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chinese-llama-2-7b',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'completion',
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_invoke_stream_chat_model():
|
||||
model = LocalAILarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chinese-llama-2-7b',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'chat_completion',
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = LocalAILarguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='????',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'chat_completion',
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_current_weather',
|
||||
description='Get the current weather in a given location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 77
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='????',
|
||||
credentials={
|
||||
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
|
||||
'completion_type': 'chat_completion',
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 10
|
||||
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel
|
||||
|
||||
def test_validate_credentials():
|
||||
model = MinimaxTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='embo-01',
|
||||
credentials={
|
||||
'minimax_api_key': 'invalid_key',
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='embo-01',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model():
|
||||
model = MinimaxTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='embo-01',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 16
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = MinimaxTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='embo-01',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
158
api/tests/integration_tests/model_runtime/minimax/test_llm.py
Normal file
158
api/tests/integration_tests/model_runtime/minimax/test_llm.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
from time import sleep
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.minimax.llm.llm import MinimaxLargeLanguageModel
|
||||
|
||||
def test_predefined_models():
|
||||
model = MinimaxLargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
sleep(3)
|
||||
model = MinimaxLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='abab5.5-chat',
|
||||
credentials={
|
||||
'minimax_api_key': 'invalid_key',
|
||||
'minimax_group_id': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='abab5.5-chat',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model():
|
||||
sleep(3)
|
||||
model = MinimaxLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='abab5-chat',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_stream_model():
|
||||
sleep(3)
|
||||
model = MinimaxLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='abab5.5-chat',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_invoke_with_search():
|
||||
sleep(3)
|
||||
model = MinimaxLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='abab5.5-chat',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='北京今天的天气怎么样'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
'plugin_web_search': True,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
total_message = ''
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
total_message += chunk.delta.message.content
|
||||
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
|
||||
|
||||
assert '参考资料' in total_message
|
||||
|
||||
def test_get_num_tokens():
|
||||
sleep(3)
|
||||
model = MinimaxLargeLanguageModel()
|
||||
|
||||
response = model.get_num_tokens(
|
||||
model='abab5.5-chat',
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[]
|
||||
)
|
||||
|
||||
assert isinstance(response, int)
|
||||
assert response == 30
|
||||
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = MinimaxProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'minimax_api_key': 'hahahaha',
|
||||
'minimax_group_id': '123',
|
||||
}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
|
||||
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
|
||||
}
|
||||
)
|
||||
382
api/tests/integration_tests/model_runtime/openai/test_llm.py
Normal file
382
api/tests/integration_tests/model_runtime/openai/test_llm.py
Normal file
File diff suppressed because one or more lines are too long
@@ -0,0 +1,55 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = OpenAIModerationModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='text-moderation-stable',
|
||||
credentials={
|
||||
'openai_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='text-moderation-stable',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = OpenAIModerationModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='text-moderation-stable',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
},
|
||||
text="hello",
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, bool)
|
||||
assert result is False
|
||||
|
||||
result = model.invoke(
|
||||
model='text-moderation-stable',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
},
|
||||
text="i will kill you",
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, bool)
|
||||
assert result is True
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai.openai import OpenAIProvider
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
|
||||
def test_validate_provider_credentials(setup_openai_mock):
|
||||
provider = OpenAIProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai.speech2text.speech2text import OpenAISpeech2TextModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = OpenAISpeech2TextModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='whisper-1',
|
||||
credentials={
|
||||
'openai_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='whisper-1',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = OpenAISpeech2TextModel()
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Get assets directory
|
||||
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
|
||||
|
||||
# Construct the path to the audio file
|
||||
audio_file_path = os.path.join(assets_dir, 'audio.mp3')
|
||||
|
||||
# Open the file and get the file object
|
||||
with open(audio_file_path, 'rb') as audio_file:
|
||||
file = audio_file
|
||||
|
||||
result = model.invoke(
|
||||
model='whisper-1',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
},
|
||||
file=file,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
|
||||
@@ -0,0 +1,67 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
|
||||
def test_validate_credentials(setup_openai_mock):
|
||||
model = OpenAITextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'openai_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
|
||||
def test_invoke_model(setup_openai_mock):
|
||||
model = OpenAITextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'openai_api_base': 'https://api.openai.com'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = OpenAITextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'openai_api_base': 'https://api.openai.com'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@@ -0,0 +1,181 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, \
|
||||
SystemPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
"""
|
||||
Using Together.ai's OpenAI-compatible API as testing endpoint
|
||||
"""
|
||||
|
||||
def test_validate_credentials():
|
||||
model = OAIAPICompatLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
|
||||
'mode': 'chat'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
|
||||
'mode': 'chat'
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model():
|
||||
model = OAIAPICompatLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'endpoint_url': 'https://api.together.xyz/v1/completions',
|
||||
'mode': 'completion'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = OAIAPICompatLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('TOGETHER_API_KEY'),
|
||||
'endpoint_url': 'https://api.together.xyz/v1/chat/completions',
|
||||
'mode': 'chat'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
|
||||
# using OpenAI's ChatGPT-3.5 as testing endpoint
|
||||
def test_invoke_chat_model_with_tools():
|
||||
model = OAIAPICompatLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='gpt-3.5-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/chat/completions',
|
||||
'mode': 'chat'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
)
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_weather',
|
||||
description='Determine weather in my location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"celsius",
|
||||
"fahrenheit"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
),
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 1024
|
||||
},
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
assert len(result.message.tool_calls) > 0
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = OAIAPICompatLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/chat/completions'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 21
|
||||
@@ -0,0 +1,79 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import OAICompatEmbeddingModel
|
||||
|
||||
"""
|
||||
Using OpenAI's API as testing endpoint
|
||||
"""
|
||||
|
||||
def test_validate_credentials():
|
||||
model = OAICompatEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = OAICompatEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = OAICompatEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='text-embedding-ada-002',
|
||||
credentials={
|
||||
'api_key': os.environ.get('OPENAI_API_KEY'),
|
||||
'endpoint_url': 'https://api.openai.com/v1/embeddings',
|
||||
'context_size': 8184,
|
||||
'max_chunks': 32
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openllm.text_embedding.text_embedding import OpenLLMTextEmbeddingModel
|
||||
|
||||
def test_validate_credentials():
|
||||
model = OpenLLMTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'),
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = OpenLLMTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens > 0
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = OpenLLMTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
104
api/tests/integration_tests/model_runtime/openllm/test_llm.py
Normal file
104
api/tests/integration_tests/model_runtime/openllm/test_llm.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.openllm.llm.llm import OpenLLMLargeLanguageModel
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = OpenLLMLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': 'invalid_key',
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model():
|
||||
model = OpenLLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = OpenLLMLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = OpenLLMLargeLanguageModel()
|
||||
|
||||
response = model.get_num_tokens(
|
||||
model='NOT IMPORTANT',
|
||||
credentials={
|
||||
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[]
|
||||
)
|
||||
|
||||
assert isinstance(response, int)
|
||||
assert response == 3
|
||||
119
api/tests/integration_tests/model_runtime/replicate/test_llm.py
Normal file
119
api/tests/integration_tests/model_runtime/replicate/test_llm.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.replicate.llm.llm import ReplicateLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = ReplicateLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='meta/llama-2-13b-chat',
|
||||
credentials={
|
||||
'replicate_api_token': 'invalid_key',
|
||||
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='meta/llama-2-13b-chat',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = ReplicateLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='meta/llama-2-13b-chat',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = ReplicateLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='mistralai/mixtral-8x7b-instruct-v0.1',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 1.0,
|
||||
'top_k': 2,
|
||||
'top_p': 0.5,
|
||||
},
|
||||
stop=['How'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ReplicateLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 14
|
||||
@@ -0,0 +1,151 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials_one():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='replicate/all-mpnet-base-v2',
|
||||
credentials={
|
||||
'replicate_api_token': 'invalid_key',
|
||||
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='replicate/all-mpnet-base-v2',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_two():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='nateraw/bge-large-en-v1.5',
|
||||
credentials={
|
||||
'replicate_api_token': 'invalid_key',
|
||||
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='nateraw/bge-large-en-v1.5',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model_one():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='nateraw/bge-large-en-v1.5',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_invoke_model_two():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='andreasjansson/clip-features',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_invoke_model_three():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='replicate/all-mpnet-base-v2',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_invoke_model_four():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='nateraw/jina-embeddings-v2-base-en',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ReplicateEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='nateraw/jina-embeddings-v2-base-en',
|
||||
credentials={
|
||||
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
|
||||
'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
114
api/tests/integration_tests/model_runtime/spark/test_llm.py
Normal file
114
api/tests/integration_tests/model_runtime/spark/test_llm.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = SparkLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='spark-1.5',
|
||||
credentials={
|
||||
'app_id': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='spark-1.5',
|
||||
credentials={
|
||||
'app_id': os.environ.get('SPARK_APP_ID'),
|
||||
'api_secret': os.environ.get('SPARK_API_SECRET'),
|
||||
'api_key': os.environ.get('SPARK_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = SparkLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='spark-1.5',
|
||||
credentials={
|
||||
'app_id': os.environ.get('SPARK_APP_ID'),
|
||||
'api_secret': os.environ.get('SPARK_API_SECRET'),
|
||||
'api_key': os.environ.get('SPARK_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.5,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = SparkLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='spark-1.5',
|
||||
credentials={
|
||||
'app_id': os.environ.get('SPARK_APP_ID'),
|
||||
'api_secret': os.environ.get('SPARK_API_SECRET'),
|
||||
'api_key': os.environ.get('SPARK_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.5,
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = SparkLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='spark-1.5',
|
||||
credentials={
|
||||
'app_id': os.environ.get('SPARK_APP_ID'),
|
||||
'api_secret': os.environ.get('SPARK_API_SECRET'),
|
||||
'api_key': os.environ.get('SPARK_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 14
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.spark.spark import SparkProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = SparkProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'app_id': os.environ.get('SPARK_APP_ID'),
|
||||
'api_secret': os.environ.get('SPARK_API_SECRET'),
|
||||
'api_key': os.environ.get('SPARK_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,82 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.entities.provider_entities import SimpleProviderEntity, ProviderConfig, ProviderEntity
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory, ModelProviderExtension
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def test_get_providers():
|
||||
factory = ModelProviderFactory()
|
||||
providers = factory.get_providers()
|
||||
|
||||
for provider in providers:
|
||||
logger.debug(provider)
|
||||
|
||||
assert len(providers) >= 1
|
||||
assert isinstance(providers[0], ProviderEntity)
|
||||
|
||||
|
||||
def test_get_models():
|
||||
factory = ModelProviderFactory()
|
||||
providers = factory.get_models(
|
||||
model_type=ModelType.LLM,
|
||||
provider_configs=[
|
||||
ProviderConfig(
|
||||
provider='openai',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
logger.debug(providers)
|
||||
|
||||
assert len(providers) >= 1
|
||||
assert isinstance(providers[0], SimpleProviderEntity)
|
||||
|
||||
# all provider models type equals to ModelType.LLM
|
||||
for provider in providers:
|
||||
for provider_model in provider.models:
|
||||
assert provider_model.model_type == ModelType.LLM
|
||||
|
||||
providers = factory.get_models(
|
||||
provider='openai',
|
||||
provider_configs=[
|
||||
ProviderConfig(
|
||||
provider='openai',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert len(providers) == 1
|
||||
assert isinstance(providers[0], SimpleProviderEntity)
|
||||
assert providers[0].provider == 'openai'
|
||||
|
||||
|
||||
def test_provider_credentials_validate():
|
||||
factory = ModelProviderFactory()
|
||||
factory.provider_credentials_validate(
|
||||
provider='openai',
|
||||
credentials={
|
||||
'openai_api_key': os.environ.get('OPENAI_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test__get_model_provider_map():
|
||||
factory = ModelProviderFactory()
|
||||
model_providers = factory._get_model_provider_map()
|
||||
|
||||
for name, model_provider in model_providers.items():
|
||||
logger.debug(name)
|
||||
logger.debug(model_provider.provider_instance)
|
||||
|
||||
assert len(model_providers) >= 1
|
||||
assert isinstance(model_providers['openai'], ModelProviderExtension)
|
||||
107
api/tests/integration_tests/model_runtime/tongyi/test_llm.py
Normal file
107
api/tests/integration_tests/model_runtime/tongyi/test_llm.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.tongyi.llm.llm import TongyiLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = TongyiLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='qwen-turbo',
|
||||
credentials={
|
||||
'dashscope_api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='qwen-turbo',
|
||||
credentials={
|
||||
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = TongyiLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='qwen-turbo',
|
||||
credentials={
|
||||
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.5,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = TongyiLargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='qwen-turbo',
|
||||
credentials={
|
||||
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.5,
|
||||
'max_tokens': 100,
|
||||
'seed': 1234
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = TongyiLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='qwen-turbo',
|
||||
credentials={
|
||||
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 12
|
||||
@@ -0,0 +1,21 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.tongyi.tongyi import TongyiProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = TongyiProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
|
||||
}
|
||||
)
|
||||
271
api/tests/integration_tests/model_runtime/wenxin/test_llm.py
Normal file
271
api/tests/integration_tests/model_runtime/wenxin/test_llm.py
Normal file
@@ -0,0 +1,271 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
from time import sleep
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage, SystemPromptMessage
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel
|
||||
|
||||
def test_predefined_models():
|
||||
model = ErnieBotLarguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': 'invalid_key',
|
||||
'secret_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
def test_invoke_model_ernie_bot():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_model_ernie_bot_turbo():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot-turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_model_ernie_8k():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot-8k',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_model_ernie_bot_4():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot-4',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def test_invoke_stream_model():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_invoke_model_with_system():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='你是Kasumi'
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='你是谁?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert 'kasumi' in response.message.content.lower()
|
||||
|
||||
def test_invoke_with_search():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='北京今天的天气怎么样'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
'disable_search': True,
|
||||
},
|
||||
stop=[],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
total_message = ''
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
total_message += chunk.delta.message.content
|
||||
print(chunk.delta.message.content)
|
||||
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
|
||||
|
||||
# there should be 对不起、我不能、不支持……
|
||||
assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message)
|
||||
|
||||
def test_get_num_tokens():
|
||||
sleep(3)
|
||||
model = ErnieBotLarguageModel()
|
||||
|
||||
response = model.get_num_tokens(
|
||||
model='ernie-bot',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[]
|
||||
)
|
||||
|
||||
assert isinstance(response, int)
|
||||
assert response == 10
|
||||
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.wenxin.wenxin import WenxinProvider
|
||||
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = WenxinProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': 'hahahaha',
|
||||
'secret_key': 'hahahaha'
|
||||
}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock, MOCK
|
||||
|
||||
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
|
||||
def test_validate_credentials(setup_xinference_mock):
|
||||
model = XinferenceTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='bge-base-en',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='bge-base-en',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
|
||||
def test_invoke_model(setup_xinference_mock):
|
||||
model = XinferenceTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='bge-base-en',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens > 0
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = XinferenceTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='bge-base-en',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
392
api/tests/integration_tests/model_runtime/xinference/test_llm.py
Normal file
392
api/tests/integration_tests/model_runtime/xinference/test_llm.py
Normal file
@@ -0,0 +1,392 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from typing import Generator
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, \
|
||||
SystemPromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunkDelta, \
|
||||
LLMResultChunk
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='ChatGLM3',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='aaaaa',
|
||||
credentials={
|
||||
'server_url': '',
|
||||
'model_uid': ''
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='ChatGLM3',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ChatGLM3',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='ChatGLM3',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
"""
|
||||
Funtion calling of xinference does not support stream mode currently
|
||||
"""
|
||||
# def test_invoke_stream_chat_model_with_functions():
|
||||
# model = XinferenceAILargeLanguageModel()
|
||||
|
||||
# response = model.invoke(
|
||||
# model='ChatGLM3-6b',
|
||||
# credentials={
|
||||
# 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
# 'model_type': 'text-generation',
|
||||
# 'model_name': 'ChatGLM3',
|
||||
# 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
|
||||
# },
|
||||
# prompt_messages=[
|
||||
# SystemPromptMessage(
|
||||
# content='你是一个天气机器人,可以通过调用函数来获取天气信息',
|
||||
# ),
|
||||
# UserPromptMessage(
|
||||
# content='波士顿天气如何?'
|
||||
# )
|
||||
# ],
|
||||
# model_parameters={
|
||||
# 'temperature': 0,
|
||||
# 'top_p': 1.0,
|
||||
# },
|
||||
# stop=['you'],
|
||||
# user='abc-123',
|
||||
# stream=True,
|
||||
# tools=[
|
||||
# PromptMessageTool(
|
||||
# name='get_current_weather',
|
||||
# description='Get the current weather in a given location',
|
||||
# parameters={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state e.g. San Francisco, CA"
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": ["celsius", "fahrenheit"]
|
||||
# }
|
||||
# },
|
||||
# "required": [
|
||||
# "location"
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# ]
|
||||
# )
|
||||
|
||||
# assert isinstance(response, Generator)
|
||||
|
||||
# call: LLMResultChunk = None
|
||||
# chunks = []
|
||||
|
||||
# for chunk in response:
|
||||
# chunks.append(chunk)
|
||||
# assert isinstance(chunk, LLMResultChunk)
|
||||
# assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
# assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
# assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
# if chunk.delta.message.tool_calls and len(chunk.delta.message.tool_calls) > 0:
|
||||
# call = chunk
|
||||
# break
|
||||
|
||||
# assert call is not None
|
||||
# assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
|
||||
|
||||
# def test_invoke_chat_model_with_functions():
|
||||
# model = XinferenceAILargeLanguageModel()
|
||||
|
||||
# response = model.invoke(
|
||||
# model='ChatGLM3-6b',
|
||||
# credentials={
|
||||
# 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
# 'model_type': 'text-generation',
|
||||
# 'model_name': 'ChatGLM3',
|
||||
# 'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
|
||||
# },
|
||||
# prompt_messages=[
|
||||
# UserPromptMessage(
|
||||
# content='What is the weather like in San Francisco?'
|
||||
# )
|
||||
# ],
|
||||
# model_parameters={
|
||||
# 'temperature': 0.7,
|
||||
# 'top_p': 1.0,
|
||||
# },
|
||||
# stop=['you'],
|
||||
# user='abc-123',
|
||||
# stream=False,
|
||||
# tools=[
|
||||
# PromptMessageTool(
|
||||
# name='get_current_weather',
|
||||
# description='Get the current weather in a given location',
|
||||
# parameters={
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state e.g. San Francisco, CA"
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": [
|
||||
# "c",
|
||||
# "f"
|
||||
# ]
|
||||
# }
|
||||
# },
|
||||
# "required": [
|
||||
# "location"
|
||||
# ]
|
||||
# }
|
||||
# )
|
||||
# ]
|
||||
# )
|
||||
|
||||
# assert isinstance(response, LLMResult)
|
||||
# assert len(response.message.content) > 0
|
||||
# assert response.usage.total_tokens > 0
|
||||
# assert response.message.tool_calls[0].function.name == 'get_current_weather'
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
|
||||
def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='alapaca',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='alapaca',
|
||||
credentials={
|
||||
'server_url': '',
|
||||
'model_uid': ''
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='alapaca',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
|
||||
def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='alapaca',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='the United States is'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
user="abc-123",
|
||||
stream=False
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
|
||||
def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='alapaca',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='the United States is'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.7,
|
||||
'top_p': 1.0,
|
||||
},
|
||||
stop=['you'],
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = XinferenceAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='ChatGLM3',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name='get_current_weather',
|
||||
description='Get the current weather in a given location',
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state e.g. San Francisco, CA"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"c",
|
||||
"f"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"location"
|
||||
]
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 77
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='ChatGLM3',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert isinstance(num_tokens, int)
|
||||
assert num_tokens == 21
|
||||
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel
|
||||
|
||||
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock, MOCK
|
||||
|
||||
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
|
||||
def test_validate_credentials(setup_xinference_mock):
|
||||
model = XinferenceRerankModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='bge-reranker-base',
|
||||
credentials={
|
||||
'server_url': 'awdawdaw',
|
||||
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='bge-reranker-base',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
|
||||
}
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
|
||||
def test_invoke_model(setup_xinference_mock):
|
||||
model = XinferenceRerankModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='bge-reranker-base',
|
||||
credentials={
|
||||
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
|
||||
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
|
||||
},
|
||||
query="Who is Kasumi?",
|
||||
docs=[
|
||||
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
|
||||
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
|
||||
"and she leads a team named PopiParty."
|
||||
],
|
||||
score_threshold=0.8
|
||||
)
|
||||
|
||||
assert isinstance(result, RerankResult)
|
||||
assert len(result.docs) == 1
|
||||
assert result.docs[0].index == 0
|
||||
assert result.docs[0].score >= 0.8
|
||||
106
api/tests/integration_tests/model_runtime/zhipuai/test_llm.py
Normal file
106
api/tests/integration_tests/model_runtime/zhipuai/test_llm.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage, AssistantPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, \
|
||||
LLMResultChunkDelta
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = ZhipuAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='chatglm_turbo',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='chatglm_turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = ZhipuAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chatglm_turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Who are you?'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.9,
|
||||
'top_p': 0.7
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.message.content) > 0
|
||||
|
||||
|
||||
def test_invoke_stream_model():
|
||||
model = ZhipuAILargeLanguageModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='chatglm_turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.9,
|
||||
'top_p': 0.7
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, Generator)
|
||||
|
||||
for chunk in response:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ZhipuAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='chatglm_turbo',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 14
|
||||
@@ -0,0 +1,20 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.zhipuai.zhipuai import ZhipuaiProvider
|
||||
|
||||
def test_validate_provider_credentials():
|
||||
provider = ZhipuaiProvider()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
provider.validate_provider_credentials(
|
||||
credentials={}
|
||||
)
|
||||
|
||||
provider.validate_provider_credentials(
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,63 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = ZhipuAITextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='text_embedding',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='text_embedding',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = ZhipuAITextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='text_embedding',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 2
|
||||
assert result.usage.total_tokens == 2
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = ZhipuAITextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='text_embedding',
|
||||
credentials={
|
||||
'api_key': os.environ.get('ZHIPUAI_API_KEY')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 2
|
||||
Reference in New Issue
Block a user