commit b035c02f78 (parent 2da63654e5)
Author: Bowen Liang
Date: 2024-08-23 23:52:25 +08:00
Committed by: GitHub

chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>

155 changed files with 4279 additions and 5925 deletions
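
Nearly every hunk in this commit is a mechanical formatting change with no behavioral effect: ruff's formatter rewrites single-quoted string literals to double quotes, collapses calls and literals that fit within the configured line length onto a single line, and adds trailing commas to constructs that stay multi-line. The exact invocation is not recorded on this page; presumably it was something along the lines of running `ruff format` over the api tests, with the formatter settings taken from the project's pyproject.toml.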

View File

@@ -17,31 +17,31 @@ from core.app.segments.exc import VariableError
 def test_string_variable():
-    test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'}
+    test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
     result = factory.build_variable_from_mapping(test_data)
     assert isinstance(result, StringVariable)
 
 
 def test_integer_variable():
-    test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42}
+    test_data = {"value_type": "number", "name": "test_int", "value": 42}
     result = factory.build_variable_from_mapping(test_data)
     assert isinstance(result, IntegerVariable)
 
 
 def test_float_variable():
-    test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14}
+    test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
     result = factory.build_variable_from_mapping(test_data)
     assert isinstance(result, FloatVariable)
 
 
 def test_secret_variable():
-    test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'}
+    test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
     result = factory.build_variable_from_mapping(test_data)
     assert isinstance(result, SecretVariable)
 
 
 def test_invalid_value_type():
-    test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
+    test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
     with pytest.raises(VariableError):
         factory.build_variable_from_mapping(test_data)
@@ -49,51 +49,51 @@ def test_invalid_value_type():
 def test_build_a_blank_string():
     result = factory.build_variable_from_mapping(
         {
-            'value_type': 'string',
-            'name': 'blank',
-            'value': '',
+            "value_type": "string",
+            "name": "blank",
+            "value": "",
         }
     )
     assert isinstance(result, StringVariable)
-    assert result.value == ''
+    assert result.value == ""
 
 
 def test_build_a_object_variable_with_none_value():
     var = factory.build_segment(
         {
-            'key1': None,
+            "key1": None,
         }
     )
     assert isinstance(var, ObjectSegment)
-    assert var.value['key1'] is None
+    assert var.value["key1"] is None
 
 
 def test_object_variable():
     mapping = {
-        'id': str(uuid4()),
-        'value_type': 'object',
-        'name': 'test_object',
-        'description': 'Description of the variable.',
-        'value': {
-            'key1': 'text',
-            'key2': 2,
+        "id": str(uuid4()),
+        "value_type": "object",
+        "name": "test_object",
+        "description": "Description of the variable.",
+        "value": {
+            "key1": "text",
+            "key2": 2,
         },
     }
     variable = factory.build_variable_from_mapping(mapping)
     assert isinstance(variable, ObjectSegment)
-    assert isinstance(variable.value['key1'], str)
-    assert isinstance(variable.value['key2'], int)
+    assert isinstance(variable.value["key1"], str)
+    assert isinstance(variable.value["key2"], int)
 
 
 def test_array_string_variable():
     mapping = {
-        'id': str(uuid4()),
-        'value_type': 'array[string]',
-        'name': 'test_array',
-        'description': 'Description of the variable.',
-        'value': [
-            'text',
-            'text',
+        "id": str(uuid4()),
+        "value_type": "array[string]",
+        "name": "test_array",
+        "description": "Description of the variable.",
+        "value": [
+            "text",
+            "text",
        ],
     }
     variable = factory.build_variable_from_mapping(mapping)
@@ -104,11 +104,11 @@ def test_array_string_variable():
 def test_array_number_variable():
     mapping = {
-        'id': str(uuid4()),
-        'value_type': 'array[number]',
-        'name': 'test_array',
-        'description': 'Description of the variable.',
-        'value': [
+        "id": str(uuid4()),
+        "value_type": "array[number]",
+        "name": "test_array",
+        "description": "Description of the variable.",
+        "value": [
             1,
             2.0,
         ],
@@ -121,18 +121,18 @@ def test_array_number_variable():
 def test_array_object_variable():
     mapping = {
-        'id': str(uuid4()),
-        'value_type': 'array[object]',
-        'name': 'test_array',
-        'description': 'Description of the variable.',
-        'value': [
+        "id": str(uuid4()),
+        "value_type": "array[object]",
+        "name": "test_array",
+        "description": "Description of the variable.",
+        "value": [
             {
-                'key1': 'text',
-                'key2': 1,
+                "key1": "text",
+                "key2": 1,
             },
             {
-                'key1': 'text',
-                'key2': 1,
+                "key1": "text",
+                "key2": 1,
             },
         ],
     }
@@ -140,19 +140,19 @@ def test_array_object_variable():
     assert isinstance(variable, ArrayObjectVariable)
     assert isinstance(variable.value[0], dict)
     assert isinstance(variable.value[1], dict)
-    assert isinstance(variable.value[0]['key1'], str)
-    assert isinstance(variable.value[0]['key2'], int)
-    assert isinstance(variable.value[1]['key1'], str)
-    assert isinstance(variable.value[1]['key2'], int)
+    assert isinstance(variable.value[0]["key1"], str)
+    assert isinstance(variable.value[0]["key2"], int)
+    assert isinstance(variable.value[1]["key1"], str)
+    assert isinstance(variable.value[1]["key2"], int)
 
 
 def test_variable_cannot_large_than_5_kb():
     with pytest.raises(VariableError):
         factory.build_variable_from_mapping(
             {
-                'id': str(uuid4()),
-                'value_type': 'string',
-                'name': 'test_text',
-                'value': 'a' * 1024 * 6,
+                "id": str(uuid4()),
+                "value_type": "string",
+                "name": "test_text",
+                "value": "a" * 1024 * 6,
             }
         )
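
The file above pins down the observable contract of the segment factory: `build_variable_from_mapping` dispatches on the mapping's `value_type`, splits the shared `number` type into integer and float by the runtime type of `value`, and raises `VariableError` both for an unknown type and for an oversized string value. A minimal sketch of that dispatch, inferred from these tests alone (the real implementation lives in `core.app.segments` and surely differs in detail):

    # Illustrative sketch inferred from the tests above, not the project's code.
    from collections.abc import Mapping

    from core.app.segments import FloatVariable, IntegerVariable, SecretVariable, StringVariable
    from core.app.segments.exc import VariableError


    def build_variable_from_mapping_sketch(mapping: Mapping):
        name = mapping.get("name")
        value = mapping.get("value")
        value_type = mapping.get("value_type")
        if isinstance(value, str) and len(value) > 5 * 1024:
            # test_variable_cannot_large_than_5_kb feeds a 6 KB string and expects
            # VariableError; the exact threshold is an assumption here.
            raise VariableError("variable value is too large")
        if value_type == "string":
            return StringVariable(name=name, value=value)
        if value_type == "number":
            # "number" covers both ints and floats; dispatch on the runtime type.
            cls = IntegerVariable if isinstance(value, int) else FloatVariable
            return cls(name=name, value=value)
        if value_type == "secret":
            return SecretVariable(name=name, value=value)
        raise VariableError(f"unknown value_type: {value_type}")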

View File

@@ -7,20 +7,20 @@ from core.workflow.enums import SystemVariableKey
 def test_segment_group_to_text():
     variable_pool = VariablePool(
         system_variables={
-            SystemVariableKey('user_id'): 'fake-user-id',
+            SystemVariableKey("user_id"): "fake-user-id",
         },
         user_inputs={},
         environment_variables=[
-            SecretVariable(name='secret_key', value='fake-secret-key'),
+            SecretVariable(name="secret_key", value="fake-secret-key"),
         ],
     )
-    variable_pool.add(('node_id', 'custom_query'), 'fake-user-query')
+    variable_pool.add(("node_id", "custom_query"), "fake-user-query")
     template = (
-        'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.'
+        "Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
     )
     segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
-    assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.'
+    assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
     assert (
         segments_group.log
         == f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}."
@@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group():
         user_inputs={},
         environment_variables=[],
     )
-    template = 'Hello, world!'
+    template = "Hello, world!"
     segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
-    assert segments_group.text == 'Hello, world!'
-    assert segments_group.log == 'Hello, world!'
+    assert segments_group.text == "Hello, world!"
+    assert segments_group.log == "Hello, world!"
 
 
 def test_convert_variable_to_segment_group():
     variable_pool = VariablePool(
         system_variables={
-            SystemVariableKey('user_id'): 'fake-user-id',
+            SystemVariableKey("user_id"): "fake-user-id",
         },
         user_inputs={},
         environment_variables=[],
     )
-    template = '{{#sys.user_id#}}'
+    template = "{{#sys.user_id#}}"
     segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
-    assert segments_group.text == 'fake-user-id'
-    assert segments_group.log == 'fake-user-id'
-    assert segments_group.value == [StringSegment(value='fake-user-id')]
+    assert segments_group.text == "fake-user-id"
+    assert segments_group.log == "fake-user-id"
+    assert segments_group.value == [StringSegment(value="fake-user-id")]

View File

@@ -13,60 +13,60 @@ from core.app.segments import (
 def test_frozen_variables():
-    var = StringVariable(name='text', value='text')
+    var = StringVariable(name="text", value="text")
     with pytest.raises(ValidationError):
-        var.value = 'new value'
+        var.value = "new value"
 
-    int_var = IntegerVariable(name='integer', value=42)
+    int_var = IntegerVariable(name="integer", value=42)
     with pytest.raises(ValidationError):
         int_var.value = 100
 
-    float_var = FloatVariable(name='float', value=3.14)
+    float_var = FloatVariable(name="float", value=3.14)
     with pytest.raises(ValidationError):
         float_var.value = 2.718
 
-    secret_var = SecretVariable(name='secret', value='secret_value')
+    secret_var = SecretVariable(name="secret", value="secret_value")
     with pytest.raises(ValidationError):
-        secret_var.value = 'new_secret_value'
+        secret_var.value = "new_secret_value"
 
 
 def test_variable_value_type_immutable():
     with pytest.raises(ValidationError):
-        StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text')
+        StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text")
 
     with pytest.raises(ValidationError):
-        StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'})
+        StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"})
 
-    var = IntegerVariable(name='integer', value=42)
+    var = IntegerVariable(name="integer", value=42)
     with pytest.raises(ValidationError):
         IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
 
-    var = FloatVariable(name='float', value=3.14)
+    var = FloatVariable(name="float", value=3.14)
     with pytest.raises(ValidationError):
         FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
 
-    var = SecretVariable(name='secret', value='secret_value')
+    var = SecretVariable(name="secret", value="secret_value")
     with pytest.raises(ValidationError):
         SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
 
 
 def test_object_variable_to_object():
     var = ObjectVariable(
-        name='object',
+        name="object",
         value={
-            'key1': {
-                'key2': 'value2',
+            "key1": {
+                "key2": "value2",
             },
-            'key2': ['value5_1', 42, {}],
+            "key2": ["value5_1", 42, {}],
         },
     )
 
     assert var.to_object() == {
-        'key1': {
-            'key2': 'value2',
+        "key1": {
+            "key2": "value2",
         },
-        'key2': [
-            'value5_1',
+        "key2": [
+            "value5_1",
             42,
             {},
         ],
@@ -74,11 +74,11 @@ def test_object_variable_to_object():
 def test_variable_to_object():
-    var = StringVariable(name='text', value='text')
-    assert var.to_object() == 'text'
-    var = IntegerVariable(name='integer', value=42)
+    var = StringVariable(name="text", value="text")
+    assert var.to_object() == "text"
+    var = IntegerVariable(name="integer", value=42)
     assert var.to_object() == 42
-    var = FloatVariable(name='float', value=3.14)
+    var = FloatVariable(name="float", value=3.14)
     assert var.to_object() == 3.14
-    var = SecretVariable(name='secret', value='secret_value')
-    assert var.to_object() == 'secret_value'
+    var = SecretVariable(name="secret", value="secret_value")
+    assert var.to_object() == "secret_value"

View File

@@ -4,17 +4,17 @@ from unittest.mock import MagicMock, patch
 from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
 
 
-@patch('httpx.request')
+@patch("httpx.request")
 def test_successful_request(mock_request):
     mock_response = MagicMock()
     mock_response.status_code = 200
     mock_request.return_value = mock_response
-    response = make_request('GET', 'http://example.com')
+    response = make_request("GET", "http://example.com")
     assert response.status_code == 200
 
 
-@patch('httpx.request')
+@patch("httpx.request")
 def test_retry_exceed_max_retries(mock_request):
     mock_response = MagicMock()
     mock_response.status_code = 500
@@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request):
     mock_request.side_effect = side_effects
 
     try:
-        make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
+        make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
         raise AssertionError("Expected Exception not raised")
     except Exception as e:
         assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
 
 
-@patch('httpx.request')
+@patch("httpx.request")
 def test_retry_logic_success(mock_request):
     side_effects = []
@@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request):
     mock_request.side_effect = side_effects
 
-    response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES)
+    response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
     assert response.status_code == 200
     assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
-    assert mock_request.call_args_list[0][1].get('method') == 'GET'
+    assert mock_request.call_args_list[0][1].get("method") == "GET"
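
These three tests jointly fix the retry contract of `make_request`: every attempt goes through `httpx.request`, a response whose status code is not in `STATUS_FORCELIST` is returned immediately, and once `max_retries` retries are exhausted an exception naming the retry count and the URL is raised, so a run that only succeeds on the final attempt makes `max_retries + 1` calls. A hedged sketch of a loop that satisfies exactly those assertions (the real `core.helper.ssrf_proxy.make_request` additionally routes traffic through an SSRF-safe proxy, which is omitted here):

    # Sketch of the retry behavior the tests pin down; not the real implementation.
    import httpx

    from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST


    def make_request_sketch(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
        retries = 0
        while retries <= max_retries:
            response = httpx.request(method=method, url=url, **kwargs)
            if response.status_code not in STATUS_FORCELIST:
                return response
            retries += 1
        raise Exception(f"Reached maximum retries ({max_retries}) for URL {url}")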

View File

@@ -21,18 +21,18 @@ def test_max_chunks():
     def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
         return _MockTextEmbedding()
 
-    model = 'embedding-v1'
+    model = "embedding-v1"
     credentials = {
-        'api_key': 'xxxx',
-        'secret_key': 'yyyy',
+        "api_key": "xxxx",
+        "secret_key": "yyyy",
     }
     embedding_model = WenxinTextEmbeddingModel()
     context_size = embedding_model._get_context_size(model, credentials)
     max_chunks = embedding_model._get_max_chunks(model, credentials)
     embedding_model._create_text_embedding = _create_text_embedding
 
-    texts = ['0123456789' for i in range(0, max_chunks * 2)]
-    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
+    texts = ["0123456789" for i in range(0, max_chunks * 2)]
+    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
     assert len(result.embeddings) == max_chunks * 2
@@ -41,16 +41,16 @@ def test_context_size():
         return GPT2Tokenizer.get_num_tokens(text)
 
     def mock_text(token_size: int) -> str:
-        _text = "".join(['0' for i in range(token_size)])
+        _text = "".join(["0" for i in range(token_size)])
         num_tokens = get_num_tokens_by_gpt2(_text)
         ratio = int(np.floor(len(_text) / num_tokens))
         m_text = "".join([_text for i in range(ratio)])
         return m_text
 
-    model = 'embedding-v1'
+    model = "embedding-v1"
     credentials = {
-        'api_key': 'xxxx',
-        'secret_key': 'yyyy',
+        "api_key": "xxxx",
+        "secret_key": "yyyy",
     }
 
     embedding_model = WenxinTextEmbeddingModel()
     context_size = embedding_model._get_context_size(model, credentials)
@@ -71,5 +71,5 @@ def test_context_size():
     assert get_num_tokens_by_gpt2(text) == context_size * 2
 
     texts = [text]
-    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
+    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
     assert result.usage.tokens == context_size

View File

@@ -14,39 +14,24 @@ from models.model import Conversation
 def test__get_completion_model_prompt_messages():
     model_config_mock = MagicMock(spec=ModelConfigEntity)
-    model_config_mock.provider = 'openai'
-    model_config_mock.model = 'gpt-3.5-turbo-instruct'
+    model_config_mock.provider = "openai"
+    model_config_mock.model = "gpt-3.5-turbo-instruct"
 
     prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
-    prompt_template_config = CompletionModelPromptTemplate(
-        text=prompt_template
-    )
+    prompt_template_config = CompletionModelPromptTemplate(text=prompt_template)
 
     memory_config = MemoryConfig(
-        role_prefix=MemoryConfig.RolePrefix(
-            user="Human",
-            assistant="Assistant"
-        ),
-        window=MemoryConfig.WindowConfig(
-            enabled=False
-        )
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=False),
     )
 
-    inputs = {
-        "name": "John"
-    }
+    inputs = {"name": "John"}
     files = []
     context = "I am superman."
 
-    memory = TokenBufferMemory(
-        conversation=Conversation(),
-        model_instance=model_config_mock
-    )
+    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
 
-    history_prompt_messages = [
-        UserPromptMessage(content="Hi"),
-        AssistantPromptMessage(content="Hello")
-    ]
+    history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
     memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
 
     prompt_transform = AdvancedPromptTransform()
@@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages():
         context=context,
         memory_config=memory_config,
         memory=memory,
-        model_config=model_config_mock
+        model_config=model_config_mock,
     )
 
     assert len(prompt_messages) == 1
-    assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
-        "#context#": context,
-        "#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
-                                  f"{prompt.content}" for prompt in history_prompt_messages]),
-        **inputs,
-    })
+    assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format(
+        {
+            "#context#": context,
+            "#histories#": "\n".join(
+                [
+                    f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}"
+                    for prompt in history_prompt_messages
+                ]
+            ),
+            **inputs,
+        }
+    )
 
 
 def test__get_chat_model_prompt_messages(get_chat_model_args):
@@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
     files = []
     query = "Hi2."
 
-    memory = TokenBufferMemory(
-        conversation=Conversation(),
-        model_instance=model_config_mock
-    )
+    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
 
-    history_prompt_messages = [
-        UserPromptMessage(content="Hi1."),
-        AssistantPromptMessage(content="Hello1!")
-    ]
+    history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
     memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
 
     prompt_transform = AdvancedPromptTransform()
@@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
         context=context,
         memory_config=memory_config,
         memory=memory,
-        model_config=model_config_mock
+        model_config=model_config_mock,
     )
 
     assert len(prompt_messages) == 6
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
-    assert prompt_messages[0].content == PromptTemplateParser(
-        template=messages[0].text
-    ).format({**inputs, "#context#": context})
+    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
+        {**inputs, "#context#": context}
+    )
     assert prompt_messages[5].content == query
@@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
         context=context,
         memory_config=None,
         memory=None,
-        model_config=model_config_mock
+        model_config=model_config_mock,
     )
 
     assert len(prompt_messages) == 3
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
-    assert prompt_messages[0].content == PromptTemplateParser(
-        template=messages[0].text
-    ).format({**inputs, "#context#": context})
+    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
+        {**inputs, "#context#": context}
+    )
 
 
 def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
@@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
                 image_config={
                     "detail": "high",
                 }
-            )
+            ),
         )
     ]
@@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
         context=context,
         memory_config=None,
         memory=None,
-        model_config=model_config_mock
+        model_config=model_config_mock,
     )
 
     assert len(prompt_messages) == 4
     assert prompt_messages[0].role == PromptMessageRole.SYSTEM
-    assert prompt_messages[0].content == PromptTemplateParser(
-        template=messages[0].text
-    ).format({**inputs, "#context#": context})
+    assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
+        {**inputs, "#context#": context}
+    )
     assert isinstance(prompt_messages[3].content, list)
     assert len(prompt_messages[3].content) == 2
     assert prompt_messages[3].content[1].data == files[0].url
@@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
 @pytest.fixture
 def get_chat_model_args():
     model_config_mock = MagicMock(spec=ModelConfigEntity)
-    model_config_mock.provider = 'openai'
-    model_config_mock.model = 'gpt-4'
+    model_config_mock.provider = "openai"
+    model_config_mock.model = "gpt-4"
 
-    memory_config = MemoryConfig(
-        window=MemoryConfig.WindowConfig(
-            enabled=False
-        )
-    )
+    memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
 
     prompt_messages = [
         ChatModelMessage(
-            text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
-            role=PromptMessageRole.SYSTEM
+            text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM
         ),
-        ChatModelMessage(
-            text="Hi.",
-            role=PromptMessageRole.USER
-        ),
-        ChatModelMessage(
-            text="Hello!",
-            role=PromptMessageRole.ASSISTANT
-        )
+        ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
+        ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
     ]
 
-    inputs = {
-        "name": "John"
-    }
+    inputs = {"name": "John"}
 
     context = "I am superman."

View File

@@ -18,27 +18,28 @@ from models.model import Conversation
 def test_get_prompt():
     prompt_messages = [
-        SystemPromptMessage(content='System Template'),
-        UserPromptMessage(content='User Query'),
+        SystemPromptMessage(content="System Template"),
+        UserPromptMessage(content="User Query"),
     ]
     history_messages = [
-        SystemPromptMessage(content='System Prompt 1'),
-        UserPromptMessage(content='User Prompt 1'),
-        AssistantPromptMessage(content='Assistant Thought 1'),
-        ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
-        ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
-        SystemPromptMessage(content='System Prompt 2'),
-        UserPromptMessage(content='User Prompt 2'),
-        AssistantPromptMessage(content='Assistant Thought 2'),
-        ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
-        ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
-        UserPromptMessage(content='User Prompt 3'),
-        AssistantPromptMessage(content='Assistant Thought 3'),
+        SystemPromptMessage(content="System Prompt 1"),
+        UserPromptMessage(content="User Prompt 1"),
+        AssistantPromptMessage(content="Assistant Thought 1"),
+        ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
+        ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
+        SystemPromptMessage(content="System Prompt 2"),
+        UserPromptMessage(content="User Prompt 2"),
+        AssistantPromptMessage(content="Assistant Thought 2"),
+        ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
+        ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
+        UserPromptMessage(content="User Prompt 3"),
+        AssistantPromptMessage(content="Assistant Thought 3"),
     ]
 
     # use message number instead of token for testing
     def side_effect_get_num_tokens(*args):
         return len(args[2])
+
     large_language_model_mock = MagicMock(spec=LargeLanguageModel)
     large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
@@ -46,20 +47,17 @@ def test_get_prompt():
     provider_model_bundle_mock.model_type_instance = large_language_model_mock
 
     model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
-    model_config_mock.model = 'openai'
+    model_config_mock.model = "openai"
     model_config_mock.credentials = {}
     model_config_mock.provider_model_bundle = provider_model_bundle_mock
 
-    memory = TokenBufferMemory(
-        conversation=Conversation(),
-        model_instance=model_config_mock
-    )
+    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
 
     transform = AgentHistoryPromptTransform(
         model_config=model_config_mock,
         prompt_messages=prompt_messages,
         history_messages=history_messages,
-        memory=memory
+        memory=memory,
     )
 
     max_token_limit = 5

View File

@@ -12,19 +12,15 @@ from core.prompt.prompt_transform import PromptTransform
 def test__calculate_rest_token():
     model_schema_mock = MagicMock(spec=AIModelEntity)
     parameter_rule_mock = MagicMock(spec=ParameterRule)
-    parameter_rule_mock.name = 'max_tokens'
-    model_schema_mock.parameter_rules = [
-        parameter_rule_mock
-    ]
-    model_schema_mock.model_properties = {
-        ModelPropertyKey.CONTEXT_SIZE: 62
-    }
+    parameter_rule_mock.name = "max_tokens"
+    model_schema_mock.parameter_rules = [parameter_rule_mock]
+    model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
 
     large_language_model_mock = MagicMock(spec=LargeLanguageModel)
     large_language_model_mock.get_num_tokens.return_value = 6
 
     provider_mock = MagicMock(spec=ProviderEntity)
-    provider_mock.provider = 'openai'
+    provider_mock.provider = "openai"
 
     provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
     provider_configuration_mock.provider = provider_mock
@@ -35,11 +31,9 @@ def test__calculate_rest_token():
     provider_model_bundle_mock.configuration = provider_configuration_mock
 
     model_config_mock = MagicMock(spec=ModelConfigEntity)
-    model_config_mock.model = 'gpt-4'
+    model_config_mock.model = "gpt-4"
     model_config_mock.credentials = {}
-    model_config_mock.parameters = {
-        'max_tokens': 50
-    }
+    model_config_mock.parameters = {"max_tokens": 50}
     model_config_mock.model_schema = model_schema_mock
     model_config_mock.provider_model_bundle = provider_model_bundle_mock
@@ -49,8 +43,10 @@ def test__calculate_rest_token():
     rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
 
     # Validate based on the mock configuration and expected logic
-    expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
-                            - model_config_mock.parameters['max_tokens']
-                            - large_language_model_mock.get_num_tokens.return_value)
+    expected_rest_tokens = (
+        model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
+        - model_config_mock.parameters["max_tokens"]
+        - large_language_model_mock.get_num_tokens.return_value
+    )
     assert rest_tokens == expected_rest_tokens
     assert rest_tokens == 6
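
As a quick check of the arithmetic the mocks set up: a context size of 62, minus the configured max_tokens of 50, minus the 6 tokens the mocked model reports for the prompt, leaves 62 - 50 - 6 = 6, which is exactly what the final assertion requires.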

View File

@@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm():
         query_in_prompt=True,
         with_memory_prompt=True,
     )
-    prompt_rules = prompt_template['prompt_rules']
-    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
-                                                           + pre_prompt + '\n'
-                                                           + prompt_rules['histories_prompt']
-                                                           + prompt_rules['query_prompt'])
-    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
+    prompt_rules = prompt_template["prompt_rules"]
+    assert prompt_template["prompt_template"].template == (
+        prompt_rules["context_prompt"]
+        + pre_prompt
+        + "\n"
+        + prompt_rules["histories_prompt"]
+        + prompt_rules["query_prompt"]
+    )
+    assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
 
 
 def test_get_baichuan_chat_app_prompt_template_with_pcqm():
@@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm():
         query_in_prompt=True,
         with_memory_prompt=True,
     )
-    prompt_rules = prompt_template['prompt_rules']
-    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
-                                                           + pre_prompt + '\n'
-                                                           + prompt_rules['histories_prompt']
-                                                           + prompt_rules['query_prompt'])
-    assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
+    prompt_rules = prompt_template["prompt_rules"]
+    assert prompt_template["prompt_template"].template == (
+        prompt_rules["context_prompt"]
+        + pre_prompt
+        + "\n"
+        + prompt_rules["histories_prompt"]
+        + prompt_rules["query_prompt"]
+    )
+    assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
 
 
 def test_get_common_completion_app_prompt_template_with_pcq():
@@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq():
         query_in_prompt=True,
         with_memory_prompt=False,
     )
-    prompt_rules = prompt_template['prompt_rules']
-    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
-                                                           + pre_prompt + '\n'
-                                                           + prompt_rules['query_prompt'])
-    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
+    prompt_rules = prompt_template["prompt_rules"]
+    assert prompt_template["prompt_template"].template == (
+        prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
+    )
+    assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
 
 
 def test_get_baichuan_completion_app_prompt_template_with_pcq():
@@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
         query_in_prompt=True,
         with_memory_prompt=False,
     )
-    print(prompt_template['prompt_template'].template)
-    prompt_rules = prompt_template['prompt_rules']
-    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
-                                                           + pre_prompt + '\n'
-                                                           + prompt_rules['query_prompt'])
-    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
+    print(prompt_template["prompt_template"].template)
+    prompt_rules = prompt_template["prompt_rules"]
+    assert prompt_template["prompt_template"].template == (
+        prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
+    )
+    assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
 
 
 def test_get_common_chat_app_prompt_template_with_q():
@@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q():
         query_in_prompt=True,
         with_memory_prompt=False,
    )
-    prompt_rules = prompt_template['prompt_rules']
-    assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
-    assert prompt_template['special_variable_keys'] == ['#query#']
+    prompt_rules = prompt_template["prompt_rules"]
+    assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"]
+    assert prompt_template["special_variable_keys"] == ["#query#"]
 
 
 def test_get_common_chat_app_prompt_template_with_cq():
@@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq():
         query_in_prompt=True,
         with_memory_prompt=False,
     )
-    prompt_rules = prompt_template['prompt_rules']
-    assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
-                                                           + prompt_rules['query_prompt'])
-    assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
+    prompt_rules = prompt_template["prompt_rules"]
+    assert prompt_template["prompt_template"].template == (
+        prompt_rules["context_prompt"] + prompt_rules["query_prompt"]
+    )
+    assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
 
 
 def test_get_common_chat_app_prompt_template_with_p():
@@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p():
         query_in_prompt=False,
         with_memory_prompt=False,
     )
-    assert prompt_template['prompt_template'].template == pre_prompt + '\n'
-    assert prompt_template['custom_variable_keys'] == ['name']
-    assert prompt_template['special_variable_keys'] == []
+    assert prompt_template["prompt_template"].template == pre_prompt + "\n"
+    assert prompt_template["custom_variable_keys"] == ["name"]
+    assert prompt_template["special_variable_keys"] == []
 
 
 def test__get_chat_model_prompt_messages():
     model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
-    model_config_mock.provider = 'openai'
-    model_config_mock.model = 'gpt-4'
+    model_config_mock.provider = "openai"
+    model_config_mock.model = "gpt-4"
 
     memory_mock = MagicMock(spec=TokenBufferMemory)
-    history_prompt_messages = [
-        UserPromptMessage(content="Hi"),
-        AssistantPromptMessage(content="Hello")
-    ]
+    history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
     memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
 
     prompt_transform = SimplePromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
 
     pre_prompt = "You are a helpful assistant {{name}}."
-    inputs = {
-        "name": "John"
-    }
+    inputs = {"name": "John"}
     context = "yes or no."
     query = "How are you?"
     prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
@@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages():
         files=[],
         context=context,
         memory=memory_mock,
-        model_config=model_config_mock
+        model_config=model_config_mock,
     )
 
     prompt_template = prompt_transform.get_prompt_template(
@@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages():
         with_memory_prompt=False,
     )
 
-    full_inputs = {**inputs, '#context#': context}
-    real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
+    full_inputs = {**inputs, "#context#": context}
+    real_system_prompt = prompt_template["prompt_template"].format(full_inputs)
 
     assert len(prompt_messages) == 4
     assert prompt_messages[0].content == real_system_prompt
@@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages():
 def test__get_completion_model_prompt_messages():
     model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
-    model_config_mock.provider = 'openai'
-    model_config_mock.model = 'gpt-3.5-turbo-instruct'
+    model_config_mock.provider = "openai"
+    model_config_mock.model = "gpt-3.5-turbo-instruct"
 
-    memory = TokenBufferMemory(
-        conversation=Conversation(),
-        model_instance=model_config_mock
-    )
+    memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
 
-    history_prompt_messages = [
-        UserPromptMessage(content="Hi"),
-        AssistantPromptMessage(content="Hello")
-    ]
+    history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
     memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
 
     prompt_transform = SimplePromptTransform()
     prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
 
     pre_prompt = "You are a helpful assistant {{name}}."
-    inputs = {
-        "name": "John"
-    }
+    inputs = {"name": "John"}
     context = "yes or no."
     query = "How are you?"
     prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
@@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages():
         files=[],
         context=context,
         memory=memory,
-        model_config=model_config_mock
+        model_config=model_config_mock,
     )
 
     prompt_template = prompt_transform.get_prompt_template(
@@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages():
         with_memory_prompt=True,
     )
 
-    prompt_rules = prompt_template['prompt_rules']
-    full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
-        max_token_limit=2000,
-        human_prefix=prompt_rules.get("human_prefix", "Human"),
-        ai_prefix=prompt_rules.get("assistant_prefix", "Assistant")
-    )}
-    real_prompt = prompt_template['prompt_template'].format(full_inputs)
+    prompt_rules = prompt_template["prompt_rules"]
+    full_inputs = {
+        **inputs,
+        "#context#": context,
+        "#query#": query,
+        "#histories#": memory.get_history_prompt_text(
+            max_token_limit=2000,
+            human_prefix=prompt_rules.get("human_prefix", "Human"),
+            ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
+        ),
+    }
+    real_prompt = prompt_template["prompt_template"].format(full_inputs)
 
     assert len(prompt_messages) == 1
-    assert stops == prompt_rules.get('stops')
+    assert stops == prompt_rules.get("stops")
     assert prompt_messages[0].content == real_prompt

View File

@@ -5,20 +5,15 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
 def test_default_value():
-    valid_config = {
-        'host': 'localhost',
-        'port': 19530,
-        'user': 'root',
-        'password': 'Milvus'
-    }
+    valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"}
 
     for key in valid_config:
         config = valid_config.copy()
         del config[key]
         with pytest.raises(ValidationError) as e:
             MilvusConfig(**config)
-        assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required'
+        assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
 
     config = MilvusConfig(**valid_config)
     assert config.secure is False
-    assert config.database == 'default'
+    assert config.database == "default"

View File

@@ -9,19 +9,17 @@ from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
 def test_firecrawl_web_extractor_crawl_mode(mocker):
     url = "https://firecrawl.dev"
-    api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
-    base_url = 'https://api.firecrawl.dev'
-    firecrawl_app = FirecrawlApp(api_key=api_key,
-                                 base_url=base_url)
+    api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
+    base_url = "https://api.firecrawl.dev"
+    firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
     params = {
-        'crawlerOptions': {
+        "crawlerOptions": {
             "includes": [],
             "excludes": [],
             "generateImgAltText": True,
             "maxDepth": 1,
             "limit": 1,
-            'returnOnlyUrls': False,
+            "returnOnlyUrls": False,
         }
     }
     mocked_firecrawl = {

View File

@@ -8,11 +8,8 @@ page_id = "page1"
 extractor = notion_extractor.NotionExtractor(
-    notion_workspace_id='x',
-    notion_obj_id='x',
-    notion_page_type='page',
-    tenant_id='x',
-    notion_access_token='x')
+    notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
+)
 
 
 def _generate_page(page_title: str):
@@ -21,16 +18,10 @@ def _generate_page(page_title: str):
         "id": page_id,
         "properties": {
             "Page": {
-                "type": "title",
-                "title": [
-                    {
-                        "type": "text",
-                        "text": {"content": page_title},
-                        "plain_text": page_title
-                    }
-                ]
+                "type": "title",
+                "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}],
             }
-        }
+        },
     }
@@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str):
     return {
         "object": "block",
        "id": block_id,
-        "parent": {
-            "type": "page_id",
-            "page_id": page_id
-        },
+        "parent": {"type": "page_id", "page_id": page_id},
         "type": block_type,
         "has_children": False,
         block_type: {
@@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str):
             {
                 "type": "text",
                 "text": {"content": block_text},
-                "plain_text": block_text,
-            }]
-        }
-    }
+                "plain_text": block_text,
+            }
+        ]
+    },
+}
 
 
 def _mock_response(data):
@@ -63,7 +52,7 @@ def _mock_response(data):
 def _remove_multiple_new_lines(text):
-    while '\n\n' in text:
+    while "\n\n" in text:
         text = text.replace("\n\n", "\n")
     return text.strip()
@@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text):
 def test_notion_page(mocker):
     texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
     mocked_notion_page = {
-        "object": "list",
-        "results": [
-            _generate_block("b1", "heading_1", texts[0]),
-            _generate_block("b2", "heading_2", texts[1]),
-            _generate_block("b3", "paragraph", texts[2]),
-            _generate_block("b4", "heading_3", texts[3])
-        ],
-        "next_cursor": None
+        "object": "list",
+        "results": [
+            _generate_block("b1", "heading_1", texts[0]),
+            _generate_block("b2", "heading_2", texts[1]),
+            _generate_block("b3", "paragraph", texts[2]),
+            _generate_block("b4", "heading_3", texts[3]),
+        ],
+        "next_cursor": None,
     }
     mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
 
     page_docs = extractor._load_data_as_documents(page_id, "page")
     assert len(page_docs) == 1
     content = _remove_multiple_new_lines(page_docs[0].page_content)
-    assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1'
+    assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
 
 
 def test_notion_database(mocker):
@@ -93,10 +82,10 @@ def test_notion_database(mocker):
     mocked_notion_database = {
         "object": "list",
         "results": [_generate_page(i) for i in page_title_list],
-        "next_cursor": None
+        "next_cursor": None,
     }
     mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
     database_docs = extractor._load_data_as_documents(database_id, "database")
     assert len(database_docs) == 1
     content = _remove_multiple_new_lines(database_docs[0].page_content)
-    assert content == '\n'.join([f'Page:{i}' for i in page_title_list])
+    assert content == "\n".join([f"Page:{i}" for i in page_title_list])

View File

@@ -10,36 +10,24 @@ from core.model_runtime.entities.model_entities import ModelType
 @pytest.fixture
 def lb_model_manager():
     load_balancing_configs = [
-        ModelLoadBalancingConfiguration(
-            id='id1',
-            name='__inherit__',
-            credentials={}
-        ),
-        ModelLoadBalancingConfiguration(
-            id='id2',
-            name='first',
-            credentials={"openai_api_key": "fake_key"}
-        ),
-        ModelLoadBalancingConfiguration(
-            id='id3',
-            name='second',
-            credentials={"openai_api_key": "fake_key"}
-        )
+        ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}),
+        ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}),
+        ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}),
     ]
 
     lb_model_manager = LBModelManager(
-        tenant_id='tenant_id',
-        provider='openai',
+        tenant_id="tenant_id",
+        provider="openai",
         model_type=ModelType.LLM,
-        model='gpt-4',
+        model="gpt-4",
         load_balancing_configs=load_balancing_configs,
-        managed_credentials={"openai_api_key": "fake_key"}
+        managed_credentials={"openai_api_key": "fake_key"},
    )
 
     lb_model_manager.cooldown = MagicMock(return_value=None)
 
     def is_cooldown(config: ModelLoadBalancingConfiguration):
-        if config.id == 'id1':
+        if config.id == "id1":
             return True
         return False
@@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
     assert lb_model_manager.in_cooldown(config3) is False
 
     start_index = 0
+
     def incr(key):
         nonlocal start_index
         start_index += 1
         return start_index
 
-    mocker.patch('redis.Redis.incr', side_effect=incr)
-    mocker.patch('redis.Redis.set', return_value=None)
-    mocker.patch('redis.Redis.expire', return_value=None)
+    mocker.patch("redis.Redis.incr", side_effect=incr)
+    mocker.patch("redis.Redis.set", return_value=None)
+    mocker.patch("redis.Redis.expire", return_value=None)
 
     config = lb_model_manager.fetch_next()
     assert config == config2
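
The fixture and test sketch out how `LBModelManager.fetch_next` rotates across configurations: a shared Redis counter (`incr`) supplies a monotonically increasing cursor, the cursor is reduced modulo the number of configs, and entries in cooldown are skipped. With the mocked counter starting at 0, the first fetch lands on index 1, which is `config2`, matching the assertion. A rough sketch of that selection loop (the attribute and key names below are guesses, not dify's real internals):

    # Rough sketch of round-robin selection with cooldown skipping; the cursor
    # key attribute below is hypothetical.
    from extensions.ext_redis import redis_client


    def fetch_next_sketch(self):
        total = len(self._load_balancing_configs)
        for _ in range(total):
            cursor = redis_client.incr(self._cursor_cache_key)  # hypothetical key
            config = self._load_balancing_configs[cursor % total]
            if not self.in_cooldown(config):
                return config
        return None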

View File

@@ -11,62 +11,62 @@ def test__to_model_settings(mocker):
     provider_entity = None
     for provider in provider_entities:
-        if provider.provider == 'openai':
+        if provider.provider == "openai":
             provider_entity = provider
 
     # Mocking the inputs
-    provider_model_settings = [ProviderModelSetting(
-        id='id',
-        tenant_id='tenant_id',
-        provider_name='openai',
-        model_name='gpt-4',
-        model_type='text-generation',
-        enabled=True,
-        load_balancing_enabled=True
-    )]
-    load_balancing_model_configs = [
-        LoadBalancingModelConfig(
-            id='id1',
-            tenant_id='tenant_id',
-            provider_name='openai',
-            model_name='gpt-4',
-            model_type='text-generation',
-            name='__inherit__',
-            encrypted_config=None,
-            enabled=True
-        ),
-        LoadBalancingModelConfig(
-            id='id2',
-            tenant_id='tenant_id',
-            provider_name='openai',
-            model_name='gpt-4',
-            model_type='text-generation',
-            name='first',
-            encrypted_config='{"openai_api_key": "fake_key"}',
-            enabled=True
-        )
-    ]
+    provider_model_settings = [
+        ProviderModelSetting(
+            id="id",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            enabled=True,
+            load_balancing_enabled=True,
+        )
+    ]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id="id1",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            name="__inherit__",
+            encrypted_config=None,
+            enabled=True,
+        ),
+        LoadBalancingModelConfig(
+            id="id2",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            name="first",
+            encrypted_config='{"openai_api_key": "fake_key"}',
+            enabled=True,
+        ),
+    ]
 
-    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+    mocker.patch(
+        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
+    )
 
     provider_manager = ProviderManager()
 
     # Running the method
-    result = provider_manager._to_model_settings(
-        provider_entity,
-        provider_model_settings,
-        load_balancing_model_configs
-    )
+    result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
 
     # Asserting that the result is as expected
     assert len(result) == 1
     assert isinstance(result[0], ModelSettings)
-    assert result[0].model == 'gpt-4'
+    assert result[0].model == "gpt-4"
     assert result[0].model_type == ModelType.LLM
     assert result[0].enabled is True
     assert len(result[0].load_balancing_configs) == 2
-    assert result[0].load_balancing_configs[0].name == '__inherit__'
-    assert result[0].load_balancing_configs[1].name == 'first'
+    assert result[0].load_balancing_configs[0].name == "__inherit__"
+    assert result[0].load_balancing_configs[1].name == "first"
 
 
 def test__to_model_settings_only_one_lb(mocker):
@@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker):
     provider_entity = None
     for provider in provider_entities:
-        if provider.provider == 'openai':
+        if provider.provider == "openai":
             provider_entity = provider
 
     # Mocking the inputs
-    provider_model_settings = [ProviderModelSetting(
-        id='id',
-        tenant_id='tenant_id',
-        provider_name='openai',
-        model_name='gpt-4',
-        model_type='text-generation',
-        enabled=True,
-        load_balancing_enabled=True
-    )]
+    provider_model_settings = [
+        ProviderModelSetting(
+            id="id",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            enabled=True,
+            load_balancing_enabled=True,
+        )
+    ]
     load_balancing_model_configs = [
         LoadBalancingModelConfig(
-            id='id1',
-            tenant_id='tenant_id',
-            provider_name='openai',
-            model_name='gpt-4',
-            model_type='text-generation',
-            name='__inherit__',
+            id="id1",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            name="__inherit__",
             encrypted_config=None,
-            enabled=True
+            enabled=True,
         )
     ]
 
-    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+    mocker.patch(
+        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
+    )
 
     provider_manager = ProviderManager()
 
     # Running the method
-    result = provider_manager._to_model_settings(
-        provider_entity,
-        provider_model_settings,
-        load_balancing_model_configs
-    )
+    result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
 
     # Asserting that the result is as expected
     assert len(result) == 1
     assert isinstance(result[0], ModelSettings)
-    assert result[0].model == 'gpt-4'
+    assert result[0].model == "gpt-4"
     assert result[0].model_type == ModelType.LLM
     assert result[0].enabled is True
     assert len(result[0].load_balancing_configs) == 0
@@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker):
     provider_entity = None
     for provider in provider_entities:
-        if provider.provider == 'openai':
+        if provider.provider == "openai":
             provider_entity = provider
 
     # Mocking the inputs
-    provider_model_settings = [ProviderModelSetting(
-        id='id',
-        tenant_id='tenant_id',
-        provider_name='openai',
-        model_name='gpt-4',
-        model_type='text-generation',
-        enabled=True,
-        load_balancing_enabled=False
-    )]
-    load_balancing_model_configs = [
-        LoadBalancingModelConfig(
-            id='id1',
-            tenant_id='tenant_id',
-            provider_name='openai',
-            model_name='gpt-4',
-            model_type='text-generation',
-            name='__inherit__',
-            encrypted_config=None,
-            enabled=True
-        ),
-        LoadBalancingModelConfig(
-            id='id2',
-            tenant_id='tenant_id',
-            provider_name='openai',
-            model_name='gpt-4',
-            model_type='text-generation',
-            name='first',
-            encrypted_config='{"openai_api_key": "fake_key"}',
-            enabled=True
-        )
-    ]
+    provider_model_settings = [
+        ProviderModelSetting(
+            id="id",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            enabled=True,
+            load_balancing_enabled=False,
+        )
+    ]
+    load_balancing_model_configs = [
+        LoadBalancingModelConfig(
+            id="id1",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            name="__inherit__",
+            encrypted_config=None,
+            enabled=True,
+        ),
+        LoadBalancingModelConfig(
+            id="id2",
+            tenant_id="tenant_id",
+            provider_name="openai",
+            model_name="gpt-4",
+            model_type="text-generation",
+            name="first",
+            encrypted_config='{"openai_api_key": "fake_key"}',
+            enabled=True,
+        ),
+    ]
 
-    mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
+    mocker.patch(
+        "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
+    )
 
     provider_manager = ProviderManager()
 
     # Running the method
-    result = provider_manager._to_model_settings(
-        provider_entity,
-        provider_model_settings,
-        load_balancing_model_configs
-    )
+    result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
 
     # Asserting that the result is as expected
     assert len(result) == 1
     assert isinstance(result[0], ModelSettings)
-    assert result[0].model == 'gpt-4'
+    assert result[0].model == "gpt-4"
     assert result[0].model_type == ModelType.LLM
     assert result[0].enabled is True
     assert len(result[0].load_balancing_configs) == 0

View File

@@ -5,52 +5,52 @@ from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 def test_get_parameter_type():
-    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string'
-    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string'
-    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean'
-    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number'
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
+    assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
 
     with pytest.raises(ValueError):
-        ToolParameterConverter.get_parameter_type('unsupported_type')
+        ToolParameterConverter.get_parameter_type("unsupported_type")
 
 
 def test_cast_parameter_by_type():
     # string
-    assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test'
-    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1'
-    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0'
-    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ''
+    assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
 
     # secret input
-    assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test'
-    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1'
-    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0'
-    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ''
+    assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
 
     # select
-    assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test'
-    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1'
-    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0'
-    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ''
+    assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
+    assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
+    assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
+    assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
 
     # boolean
-    true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something']
+    true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
     for value in true_values:
         assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
 
-    false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, '']
+    false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
    for value in false_values:
         assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
 
     # number
-    assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1
-    assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0
-    assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0
+    assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
+    assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
+    assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
     assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
     assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
     assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
     assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
 
     # unknown
-    assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1'
-    assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1'
+    assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
+    assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
     assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
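
One detail worth noticing in the boolean block: "something" appears in `true_values`. The cast evidently treats any value outside the enumerated false spellings (`False`, "False"/"false", "0", "NO"/"No"/"no", "n", `None`, and the empty string) as true, rather than matching against an explicit list of true spellings.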

View File

@@ -11,29 +11,30 @@ from models.workflow import WorkflowNodeExecutionStatus
 def test_execute_answer():
     node = AnswerNode(
-        tenant_id='1',
-        app_id='1',
-        workflow_id='1',
-        user_id='1',
+        tenant_id="1",
+        app_id="1",
+        workflow_id="1",
+        user_id="1",
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
         config={
-            'id': 'answer',
-            'data': {
-                'title': '123',
-                'type': 'answer',
-                'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
-            }
-        }
+            "id": "answer",
+            "data": {
+                "title": "123",
+                "type": "answer",
+                "answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
+            },
+        },
     )
 
     # construct variable pool
-    pool = VariablePool(system_variables={
-        SystemVariableKey.FILES: [],
-        SystemVariableKey.USER_ID: 'aaa'
-    }, user_inputs={}, environment_variables=[])
-    pool.add(['start', 'weather'], 'sunny')
-    pool.add(['llm', 'text'], 'You are a helpful AI.')
+    pool = VariablePool(
+        system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
+        user_inputs={},
+        environment_variables=[],
+    )
+    pool.add(["start", "weather"], "sunny")
+    pool.add(["llm", "text"], "You are a helpful AI.")
 
     # Mock db.session.close()
     db.session.close = MagicMock()
@@ -42,4 +43,4 @@ def test_execute_answer():
     result = node._run(pool)
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
-    assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
+    assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
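
Note what the final assertion pins down: only `{{#node.variable#}}` selectors are resolved against the variable pool, so the plain `{{img}}` placeholder passes through to the answer output verbatim.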

View File

@@ -11,134 +11,81 @@ from models.workflow import WorkflowNodeExecutionStatus
def test_execute_if_else_result_true():
node = IfElseNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'if-else',
'data': {
'title': '123',
'type': 'if-else',
'logical_operator': 'and',
'conditions': [
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
'comparison_operator': 'contains',
'variable_selector': ['start', 'array_contains'],
'value': 'ab'
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
'comparison_operator': 'not contains',
'variable_selector': ['start', 'array_not_contains'],
'value': 'ab'
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
'comparison_operator': 'contains',
'variable_selector': ['start', 'contains'],
'value': 'ab'
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
'comparison_operator': 'not contains',
'variable_selector': ['start', 'not_contains'],
'value': 'ab'
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{
'comparison_operator': 'start with',
'variable_selector': ['start', 'start_with'],
'value': 'ab'
},
{
'comparison_operator': 'end with',
'variable_selector': ['start', 'end_with'],
'value': 'ab'
},
{
'comparison_operator': 'is',
'variable_selector': ['start', 'is'],
'value': 'ab'
},
{
'comparison_operator': 'is not',
'variable_selector': ['start', 'is_not'],
'value': 'ab'
},
{
'comparison_operator': 'empty',
'variable_selector': ['start', 'empty'],
'value': 'ab'
},
{
'comparison_operator': 'not empty',
'variable_selector': ['start', 'not_empty'],
'value': 'ab'
},
{
'comparison_operator': '=',
'variable_selector': ['start', 'equals'],
'value': '22'
},
{
'comparison_operator': '≠',
'variable_selector': ['start', 'not_equals'],
'value': '22'
},
{
'comparison_operator': '>',
'variable_selector': ['start', 'greater_than'],
'value': '22'
},
{
'comparison_operator': '<',
'variable_selector': ['start', 'less_than'],
'value': '22'
},
{
'comparison_operator': '≥',
'variable_selector': ['start', 'greater_than_or_equal'],
'value': '22'
},
{
'comparison_operator': '≤',
'variable_selector': ['start', 'less_than_or_equal'],
'value': '22'
},
{
'comparison_operator': 'null',
'variable_selector': ['start', 'null']
},
{
'comparison_operator': 'not null',
'variable_selector': ['start', 'not_null']
},
]
}
}
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
pool.add(['start', 'contains'], 'cabcde')
pool.add(['start', 'not_contains'], 'zacde')
pool.add(['start', 'start_with'], 'abc')
pool.add(['start', 'end_with'], 'zzab')
pool.add(['start', 'is'], 'ab')
pool.add(['start', 'is_not'], 'aab')
pool.add(['start', 'empty'], '')
pool.add(['start', 'not_empty'], 'aaa')
pool.add(['start', 'equals'], 22)
pool.add(['start', 'not_equals'], 23)
pool.add(['start', 'greater_than'], 23)
pool.add(['start', 'less_than'], 21)
pool.add(['start', 'greater_than_or_equal'], 22)
pool.add(['start', 'less_than_or_equal'], 21)
pool.add(['start', 'not_null'], '1212')
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
pool.add(["start", "not_contains"], "zacde")
pool.add(["start", "start_with"], "abc")
pool.add(["start", "end_with"], "zzab")
pool.add(["start", "is"], "ab")
pool.add(["start", "is_not"], "aab")
pool.add(["start", "empty"], "")
pool.add(["start", "not_empty"], "aaa")
pool.add(["start", "equals"], 22)
pool.add(["start", "not_equals"], 23)
pool.add(["start", "greater_than"], 23)
pool.add(["start", "less_than"], 21)
pool.add(["start", "greater_than_or_equal"], 22)
pool.add(["start", "less_than_or_equal"], 21)
pool.add(["start", "not_null"], "1212")
# Mock db.session.close()
db.session.close = MagicMock()
@@ -147,46 +94,47 @@ def test_execute_if_else_result_true():
result = node._run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is True
assert result.outputs["result"] is True
def test_execute_if_else_result_false():
node = IfElseNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'if-else',
'data': {
'title': '123',
'type': 'if-else',
'logical_operator': 'or',
'conditions': [
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
'comparison_operator': 'contains',
'variable_selector': ['start', 'array_contains'],
'value': 'ab'
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
'comparison_operator': 'not contains',
'variable_selector': ['start', 'array_not_contains'],
'value': 'ab'
}
]
}
}
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
# Mock db.session.close()
db.session.close = MagicMock()
@@ -195,4 +143,4 @@ def test_execute_if_else_result_false():
result = node._run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is False
assert result.outputs["result"] is False


@@ -8,41 +8,41 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
DEFAULT_NODE_ID = 'node_id'
DEFAULT_NODE_ID = "node_id"
def test_overwrite_string_variable():
conversation_variable = StringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value='the first value',
name="test_conversation_variable",
value="the first value",
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
name="test_string_variable",
value="the second value",
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
tenant_id="tenant_id",
app_id="app_id",
workflow_id="workflow_id",
user_id="user_id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.OVER_WRITE.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
"id": "node_id",
"data": {
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -52,48 +52,48 @@ def test_overwrite_string_variable():
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.value == 'the second value'
assert got.to_object() == 'the second value'
assert got.value == "the second value"
assert got.to_object() == "the second value"
def test_append_variable_to_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
name="test_conversation_variable",
value=["the first value"],
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
name="test_string_variable",
value="the second value",
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
tenant_id="tenant_id",
app_id="app_id",
workflow_id="workflow_id",
user_id="user_id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.APPEND.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
"id": "node_id",
"data": {
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -103,41 +103,41 @@ def test_append_variable_to_array():
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == ['the first value', 'the second value']
assert got.to_object() == ["the first value", "the second value"]
def test_clear_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
name="test_conversation_variable",
value=["the first value"],
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
tenant_id="tenant_id",
app_id="app_id",
workflow_id="workflow_id",
user_id="user_id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.CLEAR.value,
'input_variable_selector': [],
"id": "node_id",
"data": {
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -145,6 +145,6 @@ def test_clear_array():
node.run(variable_pool)
got = variable_pool.get(['conversation', conversation_variable.name])
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
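# A hedged summary sketch of the three WriteMode branches these tests cover.
# The mode strings are illustrative stand-ins for the enum values; only the
# observable outcomes are taken from the assertions above.
def apply_write_mode(current, incoming, mode: str):
    if mode == "over-write":  # OVER_WRITE: replace the stored value
        return incoming
    if mode == "append":      # APPEND: extend the stored array
        return [*current, incoming]
    if mode == "clear":       # CLEAR: reset to the empty value for the type
        return []
    raise ValueError(f"unknown write mode: {mode}")

assert apply_write_mode("the first value", "the second value", "over-write") == "the second value"
assert apply_write_mode(["the first value"], "the second value", "append") == ["the first value", "the second value"]
assert apply_write_mode(["the first value"], None, "clear") == []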