chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Bowen Liang
2024-08-23 23:52:25 +08:00
committed by GitHub
parent 2da63654e5
commit b035c02f78
155 changed files with 4279 additions and 5925 deletions

View File

@@ -22,23 +22,20 @@ from anthropic.types import (
)
from anthropic.types.message_delta_event import Delta
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockAnthropicClass:
@staticmethod
def mocked_anthropic_chat_create_sync(model: str) -> Message:
return Message(
id='msg-123',
type='message',
role='assistant',
content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
id="msg-123",
type="message",
role="assistant",
content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
model=model,
stop_reason='stop_sequence',
usage=Usage(
input_tokens=1,
output_tokens=1
)
stop_reason="stop_sequence",
usage=Usage(input_tokens=1, output_tokens=1),
)
@staticmethod
@@ -46,52 +43,43 @@ class MockAnthropicClass:
full_response_text = "hello, I'm a chatbot from anthropic"
yield MessageStartEvent(
type='message_start',
type="message_start",
message=Message(
id='msg-123',
id="msg-123",
content=[],
role='assistant',
role="assistant",
model=model,
stop_reason=None,
type='message',
usage=Usage(
input_tokens=1,
output_tokens=1
)
)
type="message",
usage=Usage(input_tokens=1, output_tokens=1),
),
)
index = 0
for i in range(0, len(full_response_text)):
yield ContentBlockDeltaEvent(
type='content_block_delta',
delta=TextDelta(text=full_response_text[i], type='text_delta'),
index=index
type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
)
index += 1
yield MessageDeltaEvent(
type='message_delta',
delta=Delta(
stop_reason='stop_sequence'
),
usage=MessageDeltaUsage(
output_tokens=1
)
type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
)
yield MessageStopEvent(type='message_stop')
yield MessageStopEvent(type="message_stop")
def mocked_anthropic(self: Messages, *,
max_tokens: int,
messages: Iterable[MessageParam],
model: str,
stream: Literal[True],
**kwargs: Any
) -> Union[Message, Stream[MessageStreamEvent]]:
def mocked_anthropic(
self: Messages,
*,
max_tokens: int,
messages: Iterable[MessageParam],
model: str,
stream: Literal[True],
**kwargs: Any,
) -> Union[Message, Stream[MessageStreamEvent]]:
if len(self._client.api_key) < 18:
raise anthropic.AuthenticationError('Invalid API key')
raise anthropic.AuthenticationError("Invalid API key")
if stream:
return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
@@ -102,7 +90,7 @@ class MockAnthropicClass:
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
yield

View File

@@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse
from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = ''
current_api_key = ""
class MockGoogleResponseClass:
_done = False
def __iter__(self):
full_response_text = 'it\'s google!'
full_response_text = "it's google!"
for i in range(0, len(full_response_text) + 1, 1):
if i == len(full_response_text):
self._done = True
yield GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
)
else:
yield GenerateContentResponse(
done=False,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
)
class MockGoogleResponseCandidateClass:
finish_reason = 'stop'
finish_reason = "stop"
@property
def content(self) -> gag_content.Content:
return gag_content.Content(
parts=[
gag_content.Part(text='it\'s google!')
]
)
return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
class MockGoogleClass:
@staticmethod
def generate_content_sync() -> GenerateContentResponse:
return GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
)
return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
@staticmethod
def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
return MockGoogleResponseClass()
def generate_content(self: GenerativeModel,
def generate_content(
self: GenerativeModel,
contents: content_types.ContentsType,
*,
generation_config: generation_config_types.GenerationConfigType | None = None,
@@ -79,21 +62,21 @@ class MockGoogleClass:
global current_api_key
if len(current_api_key) < 16:
raise Exception('Invalid API key')
raise Exception("Invalid API key")
if stream:
return MockGoogleClass.generate_content_stream()
return MockGoogleClass.generate_content_sync()
@property
def generative_response_text(self) -> str:
return 'it\'s google!'
return "it's google!"
@property
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
@@ -121,7 +104,8 @@ class MockGoogleClass:
if not self.default_metadata:
return client
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
@@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
yield
monkeypatch.undo()
monkeypatch.undo()

View File

@@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
yield
if MOCK:
monkeypatch.undo()
monkeypatch.undo()

View File

@@ -22,10 +22,8 @@ class MockHuggingfaceChatClass:
details=Details(
finish_reason="length",
generated_tokens=6,
tokens=[
Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
]
)
tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
),
)
return response
@@ -36,26 +34,23 @@ class MockHuggingfaceChatClass:
for i in range(0, len(full_text)):
response = TextGenerationStreamResponse(
token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
)
response.generated_text = full_text[i]
response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
yield response
def text_generation(self: InferenceClient, prompt: str, *,
stream: Literal[False] = ...,
model: Optional[str] = None,
**kwargs: Any
def text_generation(
self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
# check if key is valid
if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
raise BadRequestError('Invalid API key')
if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
raise BadRequestError("Invalid API key")
if model is None:
raise BadRequestError('Invalid model')
raise BadRequestError("Invalid model")
if stream:
return MockHuggingfaceChatClass.generate_create_stream(model)
return MockHuggingfaceChatClass.generate_create_sync(model)

View File

@@ -5,10 +5,10 @@ class MockTEIClass:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
# During mock, we don't have a real server to query, so we just return a dummy value
if 'rerank' in model_name:
model_type = 'reranker'
if "rerank" in model_name:
model_type = "reranker"
else:
model_type = 'embedding'
model_type = "embedding"
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@@ -17,16 +17,16 @@ class MockTEIClass:
# Use space as token separator, and split the text into tokens
tokenized_texts = []
for text in texts:
tokens = text.split(' ')
tokens = text.split(" ")
current_index = 0
tokenized_text = []
for idx, token in enumerate(tokens):
s_token = {
'id': idx,
'text': token,
'special': False,
'start': current_index,
'stop': current_index + len(token),
"id": idx,
"text": token,
"special": False,
"start": current_index,
"stop": current_index + len(token),
}
current_index += len(token) + 1
tokenized_text.append(s_token)
@@ -55,18 +55,18 @@ class MockTEIClass:
embedding = [0.1] * 768
embeddings.append(
{
'object': 'embedding',
'embedding': embedding,
'index': idx,
"object": "embedding",
"embedding": embedding,
"index": idx,
}
)
return {
'object': 'list',
'data': embeddings,
'model': 'MODEL_NAME',
'usage': {
'prompt_tokens': sum(len(text.split(' ')) for text in texts),
'total_tokens': sum(len(text.split(' ')) for text in texts),
"object": "list",
"data": embeddings,
"model": "MODEL_NAME",
"usage": {
"prompt_tokens": sum(len(text.split(" ")) for text in texts),
"total_tokens": sum(len(text.split(" ")) for text in texts),
},
}
@@ -83,9 +83,9 @@ class MockTEIClass:
for idx, text in enumerate(texts):
reranked_docs.append(
{
'index': idx,
'text': text,
'score': 0.9,
"index": idx,
"text": text,
"score": 0.9,
}
)
# For mock, only return the first document

View File

@@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
def mock_openai(
monkeypatch: MonkeyPatch,
methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
"""
mock openai module
mock openai module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
@@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_openai_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_openai(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()
unpatch()

View File

@@ -43,62 +43,64 @@ class MockChatClass:
if not functions or len(functions) == 0:
return None
function: completion_create_params.Function = functions[0]
function_name = function['name']
function_description = function['description']
function_parameters = function['parameters']
function_parameters_type = function_parameters['type']
if function_parameters_type != 'object':
function_name = function["name"]
function_description = function["description"]
function_parameters = function["parameters"]
function_parameters_type = function_parameters["type"]
if function_parameters_type != "object":
return None
function_parameters_properties = function_parameters['properties']
function_parameters_required = function_parameters['required']
function_parameters_properties = function_parameters["properties"]
function_parameters_required = function_parameters["required"]
parameters = {}
for parameter_name, parameter in function_parameters_properties.items():
if parameter_name not in function_parameters_required:
continue
parameter_type = parameter['type']
if parameter_type == 'string':
if 'enum' in parameter:
if len(parameter['enum']) == 0:
parameter_type = parameter["type"]
if parameter_type == "string":
if "enum" in parameter:
if len(parameter["enum"]) == 0:
continue
parameters[parameter_name] = parameter['enum'][0]
parameters[parameter_name] = parameter["enum"][0]
else:
parameters[parameter_name] = 'kawaii'
elif parameter_type == 'integer':
parameters[parameter_name] = "kawaii"
elif parameter_type == "integer":
parameters[parameter_name] = 114514
elif parameter_type == 'number':
elif parameter_type == "number":
parameters[parameter_name] = 1919810.0
elif parameter_type == 'boolean':
elif parameter_type == "boolean":
parameters[parameter_name] = True
return FunctionCall(name=function_name, arguments=dumps(parameters))
@staticmethod
def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
list_tool_calls = []
if not tools or len(tools) == 0:
return None
tool = tools[0]
if 'type' in tools and tools['type'] != 'function':
if "type" in tools and tools["type"] != "function":
return None
function = tool['function']
function = tool["function"]
function_call = MockChatClass.generate_function_call(functions=[function])
if function_call is None:
return None
list_tool_calls.append(ChatCompletionMessageToolCall(
id='sakurajima-mai',
function=Function(
name=function_call.name,
arguments=function_call.arguments,
),
type='function'
))
list_tool_calls.append(
ChatCompletionMessageToolCall(
id="sakurajima-mai",
function=Function(
name=function_call.name,
arguments=function_call.arguments,
),
type="function",
)
)
return list_tool_calls
@staticmethod
def mocked_openai_chat_create_sync(
model: str,
@@ -111,30 +113,27 @@ class MockChatClass:
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
return _ChatCompletion(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
_ChatCompletionChoice(
finish_reason='content_filter',
finish_reason="content_filter",
index=0,
message=ChatCompletionMessage(
content='elaina',
role='assistant',
function_call=function_call,
tool_calls=tool_calls
)
content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
),
)
],
created=int(time()),
model=model,
object='chat.completion',
system_fingerprint='',
object="chat.completion",
system_fingerprint="",
usage=CompletionUsage(
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
)
),
)
@staticmethod
def mocked_openai_chat_create_stream(
model: str,
@@ -150,36 +149,40 @@ class MockChatClass:
for i in range(0, len(full_text) + 1):
if i == len(full_text):
yield ChatCompletionChunk(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
Choice(
delta=ChoiceDelta(
content='',
content="",
function_call=ChoiceDeltaFunctionCall(
name=function_call.name,
arguments=function_call.arguments,
) if function_call else None,
role='assistant',
)
if function_call
else None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id='misaka-mikoto',
id="misaka-mikoto",
function=ChoiceDeltaToolCallFunction(
name=tool_calls[0].function.name,
arguments=tool_calls[0].function.arguments,
),
type='function'
type="function",
)
] if tool_calls and len(tool_calls) > 0 else None
]
if tool_calls and len(tool_calls) > 0
else None,
),
finish_reason='function_call',
finish_reason="function_call",
index=0,
)
],
created=int(time()),
model=model,
object='chat.completion.chunk',
system_fingerprint='',
object="chat.completion.chunk",
system_fingerprint="",
usage=CompletionUsage(
prompt_tokens=2,
completion_tokens=17,
@@ -188,30 +191,45 @@ class MockChatClass:
)
else:
yield ChatCompletionChunk(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
Choice(
delta=ChoiceDelta(
content=full_text[i],
role='assistant',
role="assistant",
),
finish_reason='content_filter',
finish_reason="content_filter",
index=0,
)
],
created=int(time()),
model=model,
object='chat.completion.chunk',
system_fingerprint='',
object="chat.completion.chunk",
system_fingerprint="",
)
def chat_create(self: Completions, *,
def chat_create(
self: Completions,
*,
messages: list[ChatCompletionMessageParam],
model: Union[str,Literal[
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
model: Union[
str,
Literal[
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
],
],
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -220,24 +238,32 @@ class MockChatClass:
**kwargs: Any,
):
openai_models = [
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
]
azure_openai_models = [
"gpt35", "gpt-4v", "gpt-35-turbo"
]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if model in openai_models + azure_openai_models:
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
# so we only check if model is in openai_models
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if stream:
return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

View File

@@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockCompletionsClass:
@staticmethod
def mocked_openai_completion_create_sync(
model: str
) -> CompletionMessage:
def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
return CompletionMessage(
id="cmpl-3QJQa5jXJ5Z5X",
object="text_completion",
@@ -38,13 +36,11 @@ class MockCompletionsClass:
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
)
),
)
@staticmethod
def mocked_openai_completion_create_stream(
model: str
) -> Generator[CompletionMessage, None, None]:
def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
for i in range(0, len(full_text) + 1):
if i == len(full_text):
@@ -76,46 +72,59 @@ class MockCompletionsClass:
model=model,
system_fingerprint="",
choices=[
CompletionChoice(
text=full_text[i],
index=0,
logprobs=None,
finish_reason="content_filter"
)
CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
],
)
def completion_create(self: Completions, *, model: Union[
str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
"text-davinci-003", "text-davinci-002", "text-davinci-001",
"code-davinci-002", "text-curie-001", "text-babbage-001",
"text-ada-001"],
def completion_create(
self: Completions,
*,
model: Union[
str,
Literal[
"babbage-002",
"davinci-002",
"gpt-3.5-turbo-instruct",
"text-davinci-003",
"text-davinci-002",
"text-davinci-001",
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
],
],
prompt: Union[str, list[str], list[int], list[list[int]], None],
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
):
openai_models = [
"babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
"code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
]
azure_openai_models = [
"gpt-35-turbo-instruct"
"babbage-002",
"davinci-002",
"gpt-3.5-turbo-instruct",
"text-davinci-003",
"text-davinci-002",
"text-davinci-001",
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
]
azure_openai_models = ["gpt-35-turbo-instruct"]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if model in openai_models + azure_openai_models:
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
# so we only check if model is in openai_models
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if not prompt:
raise BadRequestError('Invalid prompt')
raise BadRequestError("Invalid prompt")
if stream:
return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)

File diff suppressed because one or more lines are too long

View File

@@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockModerationClass:
def moderation_create(self: Moderations,*,
def moderation_create(
self: Moderations,
*,
input: Union[str, list[str]],
model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
) -> ModerationCreateResponse:
if isinstance(input, str):
input = [input]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError('Invalid API key')
raise InvokeAuthorizationError("Invalid API key")
for text in input:
result = []
if 'kill' in text:
if "kill" in text:
moderation_categories = {
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
'sexual/minors': False, 'violence': False, 'violence/graphic': False
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
"harassment": 1.0,
"harassment/threatening": 1.0,
"hate": 1.0,
"hate/threatening": 1.0,
"self-harm": 1.0,
"self-harm/instructions": 1.0,
"self-harm/intent": 1.0,
"sexual": 1.0,
"sexual/minors": 1.0,
"violence": 1.0,
"violence/graphic": 1.0,
}
result.append(Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores)
))
result.append(
Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
else:
moderation_categories = {
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
'sexual/minors': False, 'violence': False, 'violence/graphic': False
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
"harassment": 0.0,
"harassment/threatening": 0.0,
"hate": 0.0,
"hate/threatening": 0.0,
"self-harm": 0.0,
"self-harm/instructions": 0.0,
"self-harm/intent": 0.0,
"sexual": 0.0,
"sexual/minors": 0.0,
"violence": 0.0,
"violence/graphic": 0.0,
}
result.append(Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores)
))
result.append(
Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
return ModerationCreateResponse(
id='shiroii kuloko',
model=model,
results=result
)
return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)

View File

@@ -6,17 +6,18 @@ from openai.types.model import Model
class MockModelClass:
"""
mock class for openai.models.Models
mock class for openai.models.Models
"""
def list(
self,
**kwargs,
) -> list[Model]:
return [
Model(
id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
created=int(time()),
object='model',
owned_by='organization:org-123',
object="model",
owned_by="organization:org-123",
)
]
]

View File

@@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockSpeech2TextClass:
def speech2text_create(self: Transcriptions,
def speech2text_create(
self: Transcriptions,
*,
file: FileTypes,
model: Union[str, Literal["whisper-1"]],
@@ -17,14 +18,12 @@ class MockSpeech2TextClass:
prompt: str | NotGiven = NOT_GIVEN,
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
) -> Transcription:
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError('Invalid API key')
return Transcription(
text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
)
raise InvokeAuthorizationError("Invalid API key")
return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")

View File

@@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
class MockXinferenceClass:
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
raise RuntimeError('404 Not Found')
if 'generate' == model_uid:
def get_chat_model(
self: Client, model_uid: str
) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
raise RuntimeError("404 Not Found")
if "generate" == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid:
if "chat" == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid:
if "embedding" == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid:
if "rerank" == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found')
raise RuntimeError("404 Not Found")
def get(self: Session, url: str, **kwargs):
response = Response()
if 'v1/models/' in url:
if "v1/models/" in url:
# get model uid
model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
model_uid = url.split("/")[-1] or ""
if not re.match(
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
response.status_code = 404
response._content = b'{}'
response._content = b"{}"
return response
# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
response.status_code = 404
response._content = b'{}'
response._content = b"{}"
return response
if model_uid in ['generate', 'chat']:
if model_uid in ["generate", "chat"]:
response.status_code = 200
response._content = b'''{
response._content = b"""{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
@@ -75,12 +78,12 @@ class MockXinferenceClass:
"revision": null,
"context_length": 2048,
"replica": 1
}'''
}"""
return response
elif model_uid == 'embedding':
elif model_uid == "embedding":
response.status_code = 200
response._content = b'''{
response._content = b"""{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
@@ -93,51 +96,48 @@ class MockXinferenceClass:
],
"revision": null,
"max_tokens": 512
}'''
}"""
return response
elif 'v1/cluster/auth' in url:
elif "v1/cluster/auth" in url:
response.status_code = 200
response._content = b'''{
response._content = b"""{
"auth": true
}'''
}"""
return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
def rerank(
self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'rerank':
raise RuntimeError('404 Not Found')
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
raise RuntimeError('404 Not Found')
if (
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
and self._model_uid != "rerank"
):
raise RuntimeError("404 Not Found")
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
raise RuntimeError("404 Not Found")
if top_n is None:
top_n = 1
return {
'results': [
{
'index': i,
'document': doc,
'relevance_score': 0.9
}
for i, doc in enumerate(documents[:top_n])
"results": [
{"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
]
}
def create_embedding(
self: RESTfulGenerateModelHandle,
input: Union[str, list[str]],
**kwargs
) -> dict:
def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'embedding':
raise RuntimeError('404 Not Found')
if (
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
and self._model_uid != "embedding"
):
raise RuntimeError("404 Not Found")
if isinstance(input, str):
input = [input]
@@ -147,32 +147,27 @@ class MockXinferenceClass:
object="list",
model=self._model_uid,
data=[
EmbeddingData(
index=i,
object="embedding",
embedding=[1919.810 for _ in range(768)]
)
EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
for i in range(ipt_len)
],
usage=EmbeddingUsage(
prompt_tokens=ipt_len,
total_tokens=ipt_len
)
usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
)
return embedding
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
yield
if MOCK:
monkeypatch.undo()
monkeypatch.undo()