chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>

Author: Bowen Liang
Date: 2024-08-23 23:52:25 +08:00
Committed by: GitHub
Parent: 2da63654e5
Commit: b035c02f78

155 changed files with 4279 additions and 5925 deletions
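
The commit title indicates the changes below were produced by ruff's formatter, whose default quote-style is "double" — matching the single-to-double quote rewrites throughout the diff. A minimal sketch of reproducing such a reformat; the exact flags, working directory, and project configuration used for this commit are assumptions:

    # rewrite all Python sources under the current directory in place
    ruff format .

    # report any files that would still be reformatted, without modifying them
    ruff format --check .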

View File

@@ -6,18 +6,21 @@ from flask import Flask
from configs.app_config import DifyConfig
EXAMPLE_ENV_FILENAME = '.env'
EXAMPLE_ENV_FILENAME = ".env"
@pytest.fixture
def example_env_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME)
file_path.write_text(dedent(
"""
file_path.write_text(
dedent(
"""
CONSOLE_API_URL=https://example.com
CONSOLE_WEB_URL=https://example.com
"""))
"""
)
)
return str(file_path)
@@ -29,7 +32,7 @@ def test_dify_config_undefined_entry(example_env_file):
# entries not defined in app settings
with pytest.raises(TypeError):
# TypeError: 'AppSettings' object is not subscriptable
assert config['LOG_LEVEL'] == 'INFO'
assert config["LOG_LEVEL"] == "INFO"
def test_dify_config(example_env_file):
@@ -37,10 +40,10 @@ def test_dify_config(example_env_file):
config = DifyConfig(_env_file=example_env_file)
# constant values
assert config.COMMIT_SHA == ''
assert config.COMMIT_SHA == ""
# default values
assert config.EDITION == 'SELF_HOSTED'
assert config.EDITION == "SELF_HOSTED"
assert config.API_COMPRESSION_ENABLED is False
assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
@@ -48,36 +51,36 @@ def test_dify_config(example_env_file):
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_flask_configs(example_env_file):
flask_app = Flask('app')
flask_app = Flask("app")
# clear system environment variables
os.environ.clear()
flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore
config = flask_app.config
# configs read from pydantic-settings
assert config['LOG_LEVEL'] == 'INFO'
assert config['COMMIT_SHA'] == ''
assert config['EDITION'] == 'SELF_HOSTED'
assert config['API_COMPRESSION_ENABLED'] is False
assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0
assert config['TESTING'] == False
assert config["LOG_LEVEL"] == "INFO"
assert config["COMMIT_SHA"] == ""
assert config["EDITION"] == "SELF_HOSTED"
assert config["API_COMPRESSION_ENABLED"] is False
assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0
assert config["TESTING"] == False
# value from env file
assert config['CONSOLE_API_URL'] == 'https://example.com'
assert config["CONSOLE_API_URL"] == "https://example.com"
# fallback to alias choices value as CONSOLE_API_URL
assert config['FILES_URL'] == 'https://example.com'
assert config["FILES_URL"] == "https://example.com"
assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify'
assert config['SQLALCHEMY_ENGINE_OPTIONS'] == {
'connect_args': {
'options': '-c timezone=UTC',
assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify"
assert config["SQLALCHEMY_ENGINE_OPTIONS"] == {
"connect_args": {
"options": "-c timezone=UTC",
},
'max_overflow': 10,
'pool_pre_ping': False,
'pool_recycle': 3600,
'pool_size': 30,
"max_overflow": 10,
"pool_pre_ping": False,
"pool_recycle": 3600,
"pool_size": 30,
}
assert config['CONSOLE_WEB_URL']=='https://example.com'
assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com']
assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*']
assert config["CONSOLE_WEB_URL"] == "https://example.com"
assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"]
assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"]

View File

@@ -17,31 +17,31 @@ from core.app.segments.exc import VariableError
def test_string_variable():
test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'}
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
result = factory.build_variable_from_mapping(test_data)
assert isinstance(result, StringVariable)
def test_integer_variable():
test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42}
test_data = {"value_type": "number", "name": "test_int", "value": 42}
result = factory.build_variable_from_mapping(test_data)
assert isinstance(result, IntegerVariable)
def test_float_variable():
test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14}
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
result = factory.build_variable_from_mapping(test_data)
assert isinstance(result, FloatVariable)
def test_secret_variable():
test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'}
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
result = factory.build_variable_from_mapping(test_data)
assert isinstance(result, SecretVariable)
def test_invalid_value_type():
test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
with pytest.raises(VariableError):
factory.build_variable_from_mapping(test_data)
@@ -49,51 +49,51 @@ def test_invalid_value_type():
def test_build_a_blank_string():
result = factory.build_variable_from_mapping(
{
'value_type': 'string',
'name': 'blank',
'value': '',
"value_type": "string",
"name": "blank",
"value": "",
}
)
assert isinstance(result, StringVariable)
assert result.value == ''
assert result.value == ""
def test_build_a_object_variable_with_none_value():
var = factory.build_segment(
{
'key1': None,
"key1": None,
}
)
assert isinstance(var, ObjectSegment)
assert var.value['key1'] is None
assert var.value["key1"] is None
def test_object_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'object',
'name': 'test_object',
'description': 'Description of the variable.',
'value': {
'key1': 'text',
'key2': 2,
"id": str(uuid4()),
"value_type": "object",
"name": "test_object",
"description": "Description of the variable.",
"value": {
"key1": "text",
"key2": 2,
},
}
variable = factory.build_variable_from_mapping(mapping)
assert isinstance(variable, ObjectSegment)
assert isinstance(variable.value['key1'], str)
assert isinstance(variable.value['key2'], int)
assert isinstance(variable.value["key1"], str)
assert isinstance(variable.value["key2"], int)
def test_array_string_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'array[string]',
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
'text',
'text',
"id": str(uuid4()),
"value_type": "array[string]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
"text",
"text",
],
}
variable = factory.build_variable_from_mapping(mapping)
@@ -104,11 +104,11 @@ def test_array_string_variable():
def test_array_number_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'array[number]',
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
"id": str(uuid4()),
"value_type": "array[number]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
1,
2.0,
],
@@ -121,18 +121,18 @@ def test_array_number_variable():
def test_array_object_variable():
mapping = {
'id': str(uuid4()),
'value_type': 'array[object]',
'name': 'test_array',
'description': 'Description of the variable.',
'value': [
"id": str(uuid4()),
"value_type": "array[object]",
"name": "test_array",
"description": "Description of the variable.",
"value": [
{
'key1': 'text',
'key2': 1,
"key1": "text",
"key2": 1,
},
{
'key1': 'text',
'key2': 1,
"key1": "text",
"key2": 1,
},
],
}
@@ -140,19 +140,19 @@ def test_array_object_variable():
assert isinstance(variable, ArrayObjectVariable)
assert isinstance(variable.value[0], dict)
assert isinstance(variable.value[1], dict)
assert isinstance(variable.value[0]['key1'], str)
assert isinstance(variable.value[0]['key2'], int)
assert isinstance(variable.value[1]['key1'], str)
assert isinstance(variable.value[1]['key2'], int)
assert isinstance(variable.value[0]["key1"], str)
assert isinstance(variable.value[0]["key2"], int)
assert isinstance(variable.value[1]["key1"], str)
assert isinstance(variable.value[1]["key2"], int)
def test_variable_cannot_large_than_5_kb():
with pytest.raises(VariableError):
factory.build_variable_from_mapping(
{
'id': str(uuid4()),
'value_type': 'string',
'name': 'test_text',
'value': 'a' * 1024 * 6,
"id": str(uuid4()),
"value_type": "string",
"name": "test_text",
"value": "a" * 1024 * 6,
}
)

View File

@@ -7,20 +7,20 @@ from core.workflow.enums import SystemVariableKey
def test_segment_group_to_text():
variable_pool = VariablePool(
system_variables={
SystemVariableKey('user_id'): 'fake-user-id',
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[
SecretVariable(name='secret_key', value='fake-secret-key'),
SecretVariable(name="secret_key", value="fake-secret-key"),
],
)
variable_pool.add(('node_id', 'custom_query'), 'fake-user-query')
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
template = (
'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.'
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
)
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.'
assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
assert (
segments_group.log
== f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}."
@@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group():
user_inputs={},
environment_variables=[],
)
template = 'Hello, world!'
template = "Hello, world!"
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
assert segments_group.text == 'Hello, world!'
assert segments_group.log == 'Hello, world!'
assert segments_group.text == "Hello, world!"
assert segments_group.log == "Hello, world!"
def test_convert_variable_to_segment_group():
variable_pool = VariablePool(
system_variables={
SystemVariableKey('user_id'): 'fake-user-id',
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={},
environment_variables=[],
)
template = '{{#sys.user_id#}}'
template = "{{#sys.user_id#}}"
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
assert segments_group.text == 'fake-user-id'
assert segments_group.log == 'fake-user-id'
assert segments_group.value == [StringSegment(value='fake-user-id')]
assert segments_group.text == "fake-user-id"
assert segments_group.log == "fake-user-id"
assert segments_group.value == [StringSegment(value="fake-user-id")]

View File

@@ -13,60 +13,60 @@ from core.app.segments import (
def test_frozen_variables():
var = StringVariable(name='text', value='text')
var = StringVariable(name="text", value="text")
with pytest.raises(ValidationError):
var.value = 'new value'
var.value = "new value"
int_var = IntegerVariable(name='integer', value=42)
int_var = IntegerVariable(name="integer", value=42)
with pytest.raises(ValidationError):
int_var.value = 100
float_var = FloatVariable(name='float', value=3.14)
float_var = FloatVariable(name="float", value=3.14)
with pytest.raises(ValidationError):
float_var.value = 2.718
secret_var = SecretVariable(name='secret', value='secret_value')
secret_var = SecretVariable(name="secret", value="secret_value")
with pytest.raises(ValidationError):
secret_var.value = 'new_secret_value'
secret_var.value = "new_secret_value"
def test_variable_value_type_immutable():
with pytest.raises(ValidationError):
StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text')
StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text")
with pytest.raises(ValidationError):
StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'})
StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"})
var = IntegerVariable(name='integer', value=42)
var = IntegerVariable(name="integer", value=42)
with pytest.raises(ValidationError):
IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
var = FloatVariable(name='float', value=3.14)
var = FloatVariable(name="float", value=3.14)
with pytest.raises(ValidationError):
FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
var = SecretVariable(name='secret', value='secret_value')
var = SecretVariable(name="secret", value="secret_value")
with pytest.raises(ValidationError):
SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
def test_object_variable_to_object():
var = ObjectVariable(
name='object',
name="object",
value={
'key1': {
'key2': 'value2',
"key1": {
"key2": "value2",
},
'key2': ['value5_1', 42, {}],
"key2": ["value5_1", 42, {}],
},
)
assert var.to_object() == {
'key1': {
'key2': 'value2',
"key1": {
"key2": "value2",
},
'key2': [
'value5_1',
"key2": [
"value5_1",
42,
{},
],
@@ -74,11 +74,11 @@ def test_object_variable_to_object():
def test_variable_to_object():
var = StringVariable(name='text', value='text')
assert var.to_object() == 'text'
var = IntegerVariable(name='integer', value=42)
var = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42
var = FloatVariable(name='float', value=3.14)
var = FloatVariable(name="float", value=3.14)
assert var.to_object() == 3.14
var = SecretVariable(name='secret', value='secret_value')
assert var.to_object() == 'secret_value'
var = SecretVariable(name="secret", value="secret_value")
assert var.to_object() == "secret_value"

View File

@@ -4,17 +4,17 @@ from unittest.mock import MagicMock, patch
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
@patch('httpx.request')
@patch("httpx.request")
def test_successful_request(mock_request):
mock_response = MagicMock()
mock_response.status_code = 200
mock_request.return_value = mock_response
response = make_request('GET', 'http://example.com')
response = make_request("GET", "http://example.com")
assert response.status_code == 200
@patch('httpx.request')
@patch("httpx.request")
def test_retry_exceed_max_retries(mock_request):
mock_response = MagicMock()
mock_response.status_code = 500
@@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request):
mock_request.side_effect = side_effects
try:
make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
raise AssertionError("Expected Exception not raised")
except Exception as e:
assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
@patch('httpx.request')
@patch("httpx.request")
def test_retry_logic_success(mock_request):
side_effects = []
@@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request):
mock_request.side_effect = side_effects
response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES)
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
assert response.status_code == 200
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
assert mock_request.call_args_list[0][1].get('method') == 'GET'
assert mock_request.call_args_list[0][1].get("method") == "GET"

View File

@@ -21,18 +21,18 @@ def test_max_chunks():
def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
return _MockTextEmbedding()
model = 'embedding-v1'
model = "embedding-v1"
credentials = {
'api_key': 'xxxx',
'secret_key': 'yyyy',
"api_key": "xxxx",
"secret_key": "yyyy",
}
embedding_model = WenxinTextEmbeddingModel()
context_size = embedding_model._get_context_size(model, credentials)
max_chunks = embedding_model._get_max_chunks(model, credentials)
embedding_model._create_text_embedding = _create_text_embedding
texts = ['0123456789' for i in range(0, max_chunks * 2)]
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
texts = ["0123456789" for i in range(0, max_chunks * 2)]
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
assert len(result.embeddings) == max_chunks * 2
@@ -41,16 +41,16 @@ def test_context_size():
return GPT2Tokenizer.get_num_tokens(text)
def mock_text(token_size: int) -> str:
_text = "".join(['0' for i in range(token_size)])
_text = "".join(["0" for i in range(token_size)])
num_tokens = get_num_tokens_by_gpt2(_text)
ratio = int(np.floor(len(_text) / num_tokens))
m_text = "".join([_text for i in range(ratio)])
return m_text
model = 'embedding-v1'
model = "embedding-v1"
credentials = {
'api_key': 'xxxx',
'secret_key': 'yyyy',
"api_key": "xxxx",
"secret_key": "yyyy",
}
embedding_model = WenxinTextEmbeddingModel()
context_size = embedding_model._get_context_size(model, credentials)
@@ -71,5 +71,5 @@ def test_context_size():
assert get_num_tokens_by_gpt2(text) == context_size * 2
texts = [text]
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
assert result.usage.tokens == context_size

View File

@@ -14,39 +14,24 @@ from models.model import Conversation
def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-3.5-turbo-instruct'
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-3.5-turbo-instruct"
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
prompt_template_config = CompletionModelPromptTemplate(
text=prompt_template
)
prompt_template_config = CompletionModelPromptTemplate(text=prompt_template)
memory_config = MemoryConfig(
role_prefix=MemoryConfig.RolePrefix(
user="Human",
assistant="Assistant"
),
window=MemoryConfig.WindowConfig(
enabled=False
)
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
window=MemoryConfig.WindowConfig(enabled=False),
)
inputs = {
"name": "John"
}
inputs = {"name": "John"}
files = []
context = "I am superman."
memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
history_prompt_messages = [
UserPromptMessage(content="Hi"),
AssistantPromptMessage(content="Hello")
]
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = AdvancedPromptTransform()
@@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages():
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config_mock
model_config=model_config_mock,
)
assert len(prompt_messages) == 1
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
"#context#": context,
"#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
f"{prompt.content}" for prompt in history_prompt_messages]),
**inputs,
})
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format(
{
"#context#": context,
"#histories#": "\n".join(
[
f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}"
for prompt in history_prompt_messages
]
),
**inputs,
}
)
def test__get_chat_model_prompt_messages(get_chat_model_args):
@@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
files = []
query = "Hi2."
memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
history_prompt_messages = [
UserPromptMessage(content="Hi1."),
AssistantPromptMessage(content="Hello1!")
]
history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = AdvancedPromptTransform()
@@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
context=context,
memory_config=memory_config,
memory=memory,
model_config=model_config_mock
model_config=model_config_mock,
)
assert len(prompt_messages) == 6
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(
template=messages[0].text
).format({**inputs, "#context#": context})
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
{**inputs, "#context#": context}
)
assert prompt_messages[5].content == query
@@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock
model_config=model_config_mock,
)
assert len(prompt_messages) == 3
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(
template=messages[0].text
).format({**inputs, "#context#": context})
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
{**inputs, "#context#": context}
)
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
@@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
image_config={
"detail": "high",
}
)
),
)
]
@@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
context=context,
memory_config=None,
memory=None,
model_config=model_config_mock
model_config=model_config_mock,
)
assert len(prompt_messages) == 4
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
assert prompt_messages[0].content == PromptTemplateParser(
template=messages[0].text
).format({**inputs, "#context#": context})
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
{**inputs, "#context#": context}
)
assert isinstance(prompt_messages[3].content, list)
assert len(prompt_messages[3].content) == 2
assert prompt_messages[3].content[1].data == files[0].url
@@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
@pytest.fixture
def get_chat_model_args():
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4'
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-4"
memory_config = MemoryConfig(
window=MemoryConfig.WindowConfig(
enabled=False
)
)
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
prompt_messages = [
ChatModelMessage(
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM
),
ChatModelMessage(
text="Hi.",
role=PromptMessageRole.USER
),
ChatModelMessage(
text="Hello!",
role=PromptMessageRole.ASSISTANT
)
ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
inputs = {
"name": "John"
}
inputs = {"name": "John"}
context = "I am superman."

View File

@@ -18,27 +18,28 @@ from models.model import Conversation
def test_get_prompt():
prompt_messages = [
SystemPromptMessage(content='System Template'),
UserPromptMessage(content='User Query'),
SystemPromptMessage(content="System Template"),
UserPromptMessage(content="User Query"),
]
history_messages = [
SystemPromptMessage(content='System Prompt 1'),
UserPromptMessage(content='User Prompt 1'),
AssistantPromptMessage(content='Assistant Thought 1'),
ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
SystemPromptMessage(content='System Prompt 2'),
UserPromptMessage(content='User Prompt 2'),
AssistantPromptMessage(content='Assistant Thought 2'),
ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
UserPromptMessage(content='User Prompt 3'),
AssistantPromptMessage(content='Assistant Thought 3'),
SystemPromptMessage(content="System Prompt 1"),
UserPromptMessage(content="User Prompt 1"),
AssistantPromptMessage(content="Assistant Thought 1"),
ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
SystemPromptMessage(content="System Prompt 2"),
UserPromptMessage(content="User Prompt 2"),
AssistantPromptMessage(content="Assistant Thought 2"),
ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
UserPromptMessage(content="User Prompt 3"),
AssistantPromptMessage(content="Assistant Thought 3"),
]
# use message number instead of token for testing
def side_effect_get_num_tokens(*args):
return len(args[2])
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
@@ -46,20 +47,17 @@ def test_get_prompt():
provider_model_bundle_mock.model_type_instance = large_language_model_mock
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.model = 'openai'
model_config_mock.model = "openai"
model_config_mock.credentials = {}
model_config_mock.provider_model_bundle = provider_model_bundle_mock
memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
transform = AgentHistoryPromptTransform(
model_config=model_config_mock,
prompt_messages=prompt_messages,
history_messages=history_messages,
memory=memory
memory=memory,
)
max_token_limit = 5

View File

@@ -12,19 +12,15 @@ from core.prompt.prompt_transform import PromptTransform
def test__calculate_rest_token():
model_schema_mock = MagicMock(spec=AIModelEntity)
parameter_rule_mock = MagicMock(spec=ParameterRule)
parameter_rule_mock.name = 'max_tokens'
model_schema_mock.parameter_rules = [
parameter_rule_mock
]
model_schema_mock.model_properties = {
ModelPropertyKey.CONTEXT_SIZE: 62
}
parameter_rule_mock.name = "max_tokens"
model_schema_mock.parameter_rules = [parameter_rule_mock]
model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
large_language_model_mock.get_num_tokens.return_value = 6
provider_mock = MagicMock(spec=ProviderEntity)
provider_mock.provider = 'openai'
provider_mock.provider = "openai"
provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
provider_configuration_mock.provider = provider_mock
@@ -35,11 +31,9 @@ def test__calculate_rest_token():
provider_model_bundle_mock.configuration = provider_configuration_mock
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.model = 'gpt-4'
model_config_mock.model = "gpt-4"
model_config_mock.credentials = {}
model_config_mock.parameters = {
'max_tokens': 50
}
model_config_mock.parameters = {"max_tokens": 50}
model_config_mock.model_schema = model_schema_mock
model_config_mock.provider_model_bundle = provider_model_bundle_mock
@@ -49,8 +43,10 @@ def test__calculate_rest_token():
rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
# Validate based on the mock configuration and expected logic
expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
- model_config_mock.parameters['max_tokens']
- large_language_model_mock.get_num_tokens.return_value)
expected_rest_tokens = (
model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
- model_config_mock.parameters["max_tokens"]
- large_language_model_mock.get_num_tokens.return_value
)
assert rest_tokens == expected_rest_tokens
assert rest_tokens == 6

View File

@@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm():
query_in_prompt=True,
with_memory_prompt=True,
)
prompt_rules = prompt_template['prompt_rules']
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+ pre_prompt + '\n'
+ prompt_rules['histories_prompt']
+ prompt_rules['query_prompt'])
assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"]
+ pre_prompt
+ "\n"
+ prompt_rules["histories_prompt"]
+ prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
def test_get_baichuan_chat_app_prompt_template_with_pcqm():
@@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm():
query_in_prompt=True,
with_memory_prompt=True,
)
prompt_rules = prompt_template['prompt_rules']
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+ pre_prompt + '\n'
+ prompt_rules['histories_prompt']
+ prompt_rules['query_prompt'])
assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"]
+ pre_prompt
+ "\n"
+ prompt_rules["histories_prompt"]
+ prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
def test_get_common_completion_app_prompt_template_with_pcq():
@@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq():
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template['prompt_rules']
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+ pre_prompt + '\n'
+ prompt_rules['query_prompt'])
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_baichuan_completion_app_prompt_template_with_pcq():
@@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
query_in_prompt=True,
with_memory_prompt=False,
)
print(prompt_template['prompt_template'].template)
prompt_rules = prompt_template['prompt_rules']
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+ pre_prompt + '\n'
+ prompt_rules['query_prompt'])
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
print(prompt_template["prompt_template"].template)
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_common_chat_app_prompt_template_with_q():
@@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q():
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template['prompt_rules']
assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
assert prompt_template['special_variable_keys'] == ['#query#']
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"]
assert prompt_template["special_variable_keys"] == ["#query#"]
def test_get_common_chat_app_prompt_template_with_cq():
@@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq():
query_in_prompt=True,
with_memory_prompt=False,
)
prompt_rules = prompt_template['prompt_rules']
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
+ prompt_rules['query_prompt'])
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
prompt_rules = prompt_template["prompt_rules"]
assert prompt_template["prompt_template"].template == (
prompt_rules["context_prompt"] + prompt_rules["query_prompt"]
)
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
def test_get_common_chat_app_prompt_template_with_p():
@@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p():
query_in_prompt=False,
with_memory_prompt=False,
)
assert prompt_template['prompt_template'].template == pre_prompt + '\n'
assert prompt_template['custom_variable_keys'] == ['name']
assert prompt_template['special_variable_keys'] == []
assert prompt_template["prompt_template"].template == pre_prompt + "\n"
assert prompt_template["custom_variable_keys"] == ["name"]
assert prompt_template["special_variable_keys"] == []
def test__get_chat_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-4'
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-4"
memory_mock = MagicMock(spec=TokenBufferMemory)
history_prompt_messages = [
UserPromptMessage(content="Hi"),
AssistantPromptMessage(content="Hello")
]
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
prompt_transform = SimplePromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
pre_prompt = "You are a helpful assistant {{name}}."
inputs = {
"name": "John"
}
inputs = {"name": "John"}
context = "yes or no."
query = "How are you?"
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
@@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages():
files=[],
context=context,
memory=memory_mock,
model_config=model_config_mock
model_config=model_config_mock,
)
prompt_template = prompt_transform.get_prompt_template(
@@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages():
with_memory_prompt=False,
)
full_inputs = {**inputs, '#context#': context}
real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
full_inputs = {**inputs, "#context#": context}
real_system_prompt = prompt_template["prompt_template"].format(full_inputs)
assert len(prompt_messages) == 4
assert prompt_messages[0].content == real_system_prompt
@@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages():
def test__get_completion_model_prompt_messages():
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
model_config_mock.provider = 'openai'
model_config_mock.model = 'gpt-3.5-turbo-instruct'
model_config_mock.provider = "openai"
model_config_mock.model = "gpt-3.5-turbo-instruct"
memory = TokenBufferMemory(
conversation=Conversation(),
model_instance=model_config_mock
)
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
history_prompt_messages = [
UserPromptMessage(content="Hi"),
AssistantPromptMessage(content="Hello")
]
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
prompt_transform = SimplePromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
pre_prompt = "You are a helpful assistant {{name}}."
inputs = {
"name": "John"
}
inputs = {"name": "John"}
context = "yes or no."
query = "How are you?"
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
@@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages():
files=[],
context=context,
memory=memory,
model_config=model_config_mock
model_config=model_config_mock,
)
prompt_template = prompt_transform.get_prompt_template(
@@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages():
with_memory_prompt=True,
)
prompt_rules = prompt_template['prompt_rules']
full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
max_token_limit=2000,
human_prefix=prompt_rules.get("human_prefix", "Human"),
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant")
)}
real_prompt = prompt_template['prompt_template'].format(full_inputs)
prompt_rules = prompt_template["prompt_rules"]
full_inputs = {
**inputs,
"#context#": context,
"#query#": query,
"#histories#": memory.get_history_prompt_text(
max_token_limit=2000,
human_prefix=prompt_rules.get("human_prefix", "Human"),
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
),
}
real_prompt = prompt_template["prompt_template"].format(full_inputs)
assert len(prompt_messages) == 1
assert stops == prompt_rules.get('stops')
assert stops == prompt_rules.get("stops")
assert prompt_messages[0].content == real_prompt

View File

@@ -5,20 +5,15 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
def test_default_value():
valid_config = {
'host': 'localhost',
'port': 19530,
'user': 'root',
'password': 'Milvus'
}
valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"}
for key in valid_config:
config = valid_config.copy()
del config[key]
with pytest.raises(ValidationError) as e:
MilvusConfig(**config)
assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required'
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
config = MilvusConfig(**valid_config)
assert config.secure is False
assert config.database == 'default'
assert config.database == "default"

View File

@@ -9,19 +9,17 @@ from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_resp
def test_firecrawl_web_extractor_crawl_mode(mocker):
url = "https://firecrawl.dev"
api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
base_url = 'https://api.firecrawl.dev'
firecrawl_app = FirecrawlApp(api_key=api_key,
base_url=base_url)
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
base_url = "https://api.firecrawl.dev"
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
params = {
'crawlerOptions': {
"crawlerOptions": {
"includes": [],
"excludes": [],
"generateImgAltText": True,
"maxDepth": 1,
"limit": 1,
'returnOnlyUrls': False,
"returnOnlyUrls": False,
}
}
mocked_firecrawl = {

View File

@@ -8,11 +8,8 @@ page_id = "page1"
extractor = notion_extractor.NotionExtractor(
notion_workspace_id='x',
notion_obj_id='x',
notion_page_type='page',
tenant_id='x',
notion_access_token='x')
notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
)
def _generate_page(page_title: str):
@@ -21,16 +18,10 @@ def _generate_page(page_title: str):
"id": page_id,
"properties": {
"Page": {
"type": "title",
"title": [
{
"type": "text",
"text": {"content": page_title},
"plain_text": page_title
}
]
"type": "title",
"title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}],
}
}
},
}
@@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str):
return {
"object": "block",
"id": block_id,
"parent": {
"type": "page_id",
"page_id": page_id
},
"parent": {"type": "page_id", "page_id": page_id},
"type": block_type,
"has_children": False,
block_type: {
@@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str):
{
"type": "text",
"text": {"content": block_text},
"plain_text": block_text,
}]
}
}
"plain_text": block_text,
}
]
},
}
def _mock_response(data):
@@ -63,7 +52,7 @@ def _mock_response(data):
def _remove_multiple_new_lines(text):
while '\n\n' in text:
while "\n\n" in text:
text = text.replace("\n\n", "\n")
return text.strip()
@@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text):
def test_notion_page(mocker):
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
mocked_notion_page = {
"object": "list",
"results": [
_generate_block("b1", "heading_1", texts[0]),
_generate_block("b2", "heading_2", texts[1]),
_generate_block("b3", "paragraph", texts[2]),
_generate_block("b4", "heading_3", texts[3])
],
"next_cursor": None
"object": "list",
"results": [
_generate_block("b1", "heading_1", texts[0]),
_generate_block("b2", "heading_2", texts[1]),
_generate_block("b3", "paragraph", texts[2]),
_generate_block("b4", "heading_3", texts[3]),
],
"next_cursor": None,
}
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
page_docs = extractor._load_data_as_documents(page_id, "page")
assert len(page_docs) == 1
content = _remove_multiple_new_lines(page_docs[0].page_content)
assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1'
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
def test_notion_database(mocker):
@@ -93,10 +82,10 @@ def test_notion_database(mocker):
mocked_notion_database = {
"object": "list",
"results": [_generate_page(i) for i in page_title_list],
"next_cursor": None
"next_cursor": None,
}
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
database_docs = extractor._load_data_as_documents(database_id, "database")
assert len(database_docs) == 1
content = _remove_multiple_new_lines(database_docs[0].page_content)
assert content == '\n'.join([f'Page:{i}' for i in page_title_list])
assert content == "\n".join([f"Page:{i}" for i in page_title_list])

View File

@@ -10,36 +10,24 @@ from core.model_runtime.entities.model_entities import ModelType
@pytest.fixture
def lb_model_manager():
load_balancing_configs = [
ModelLoadBalancingConfiguration(
id='id1',
name='__inherit__',
credentials={}
),
ModelLoadBalancingConfiguration(
id='id2',
name='first',
credentials={"openai_api_key": "fake_key"}
),
ModelLoadBalancingConfiguration(
id='id3',
name='second',
credentials={"openai_api_key": "fake_key"}
)
ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}),
ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}),
ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}),
]
lb_model_manager = LBModelManager(
tenant_id='tenant_id',
provider='openai',
tenant_id="tenant_id",
provider="openai",
model_type=ModelType.LLM,
model='gpt-4',
model="gpt-4",
load_balancing_configs=load_balancing_configs,
managed_credentials={"openai_api_key": "fake_key"}
managed_credentials={"openai_api_key": "fake_key"},
)
lb_model_manager.cooldown = MagicMock(return_value=None)
def is_cooldown(config: ModelLoadBalancingConfiguration):
if config.id == 'id1':
if config.id == "id1":
return True
return False
@@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
assert lb_model_manager.in_cooldown(config3) is False
start_index = 0
def incr(key):
nonlocal start_index
start_index += 1
return start_index
mocker.patch('redis.Redis.incr', side_effect=incr)
mocker.patch('redis.Redis.set', return_value=None)
mocker.patch('redis.Redis.expire', return_value=None)
mocker.patch("redis.Redis.incr", side_effect=incr)
mocker.patch("redis.Redis.set", return_value=None)
mocker.patch("redis.Redis.expire", return_value=None)
config = lb_model_manager.fetch_next()
assert config == config2

View File

@@ -11,62 +11,62 @@ def test__to_model_settings(mocker):
provider_entity = None
for provider in provider_entities:
if provider.provider == 'openai':
if provider.provider == "openai":
provider_entity = provider
# Mocking the inputs
provider_model_settings = [ProviderModelSetting(
id='id',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
enabled=True,
load_balancing_enabled=True
)]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id='id1',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
name='__inherit__',
encrypted_config=None,
enabled=True
),
LoadBalancingModelConfig(
id='id2',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
name='first',
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True
provider_model_settings = [
ProviderModelSetting(
id="id",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
enabled=True,
load_balancing_enabled=True,
)
]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id="id1",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="__inherit__",
encrypted_config=None,
enabled=True,
),
LoadBalancingModelConfig(
id="id2",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="first",
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True,
),
]
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity,
provider_model_settings,
load_balancing_model_configs
)
result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
# Asserting that the result is as expected
assert len(result) == 1
assert isinstance(result[0], ModelSettings)
assert result[0].model == 'gpt-4'
assert result[0].model == "gpt-4"
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 2
assert result[0].load_balancing_configs[0].name == '__inherit__'
assert result[0].load_balancing_configs[1].name == 'first'
assert result[0].load_balancing_configs[0].name == "__inherit__"
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker):
@@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker):
provider_entity = None
for provider in provider_entities:
if provider.provider == 'openai':
if provider.provider == "openai":
provider_entity = provider
# Mocking the inputs
provider_model_settings = [ProviderModelSetting(
id='id',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
enabled=True,
load_balancing_enabled=True
)]
provider_model_settings = [
ProviderModelSetting(
id="id",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
enabled=True,
load_balancing_enabled=True,
)
]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id='id1',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
name='__inherit__',
id="id1",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="__inherit__",
encrypted_config=None,
enabled=True
enabled=True,
)
]
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity,
provider_model_settings,
load_balancing_model_configs
)
result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
# Asserting that the result is as expected
assert len(result) == 1
assert isinstance(result[0], ModelSettings)
assert result[0].model == 'gpt-4'
assert result[0].model == "gpt-4"
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 0
@@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker):
provider_entity = None
for provider in provider_entities:
if provider.provider == 'openai':
if provider.provider == "openai":
provider_entity = provider
# Mocking the inputs
provider_model_settings = [ProviderModelSetting(
id='id',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
enabled=True,
load_balancing_enabled=False
)]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id='id1',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
name='__inherit__',
encrypted_config=None,
enabled=True
),
LoadBalancingModelConfig(
id='id2',
tenant_id='tenant_id',
provider_name='openai',
model_name='gpt-4',
model_type='text-generation',
name='first',
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True
provider_model_settings = [
ProviderModelSetting(
id="id",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
enabled=True,
load_balancing_enabled=False,
)
]
load_balancing_model_configs = [
LoadBalancingModelConfig(
id="id1",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="__inherit__",
encrypted_config=None,
enabled=True,
),
LoadBalancingModelConfig(
id="id2",
tenant_id="tenant_id",
provider_name="openai",
model_name="gpt-4",
model_type="text-generation",
name="first",
encrypted_config='{"openai_api_key": "fake_key"}',
enabled=True,
),
]
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity,
provider_model_settings,
load_balancing_model_configs
)
result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
# Asserting that the result is as expected
assert len(result) == 1
assert isinstance(result[0], ModelSettings)
assert result[0].model == 'gpt-4'
assert result[0].model == "gpt-4"
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 0

View File

@@ -5,52 +5,52 @@ from core.tools.utils.tool_parameter_converter import ToolParameterConverter
def test_get_parameter_type():
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string'
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string'
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean'
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number'
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
with pytest.raises(ValueError):
ToolParameterConverter.get_parameter_type('unsupported_type')
ToolParameterConverter.get_parameter_type("unsupported_type")
def test_cast_parameter_by_type():
# string
assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test'
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1'
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0'
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ''
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
# secret input
assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test'
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1'
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0'
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ''
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
# select
assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test'
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1'
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0'
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ''
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
# boolean
true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something']
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
for value in true_values:
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, '']
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
for value in false_values:
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
# number
assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
# unknown
assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1'
assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1'
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None

View File

@@ -11,29 +11,30 @@ from models.workflow import WorkflowNodeExecutionStatus
def test_execute_answer():
node = AnswerNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'answer',
'data': {
'title': '123',
'type': 'answer',
'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
}
}
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'weather'], 'sunny')
pool.add(['llm', 'text'], 'You are a helpful AI.')
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
# Mock db.session.close()
db.session.close = MagicMock()
@@ -42,4 +43,4 @@ def test_execute_answer():
result = node._run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."

View File

@@ -11,134 +11,81 @@ from models.workflow import WorkflowNodeExecutionStatus
def test_execute_if_else_result_true():
node = IfElseNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'if-else',
'data': {
'title': '123',
'type': 'if-else',
'logical_operator': 'and',
'conditions': [
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
'comparison_operator': 'contains',
'variable_selector': ['start', 'array_contains'],
'value': 'ab'
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
'comparison_operator': 'not contains',
'variable_selector': ['start', 'array_not_contains'],
'value': 'ab'
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
'comparison_operator': 'contains',
'variable_selector': ['start', 'contains'],
'value': 'ab'
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
'comparison_operator': 'not contains',
'variable_selector': ['start', 'not_contains'],
'value': 'ab'
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{
'comparison_operator': 'start with',
'variable_selector': ['start', 'start_with'],
'value': 'ab'
},
{
'comparison_operator': 'end with',
'variable_selector': ['start', 'end_with'],
'value': 'ab'
},
{
'comparison_operator': 'is',
'variable_selector': ['start', 'is'],
'value': 'ab'
},
{
'comparison_operator': 'is not',
'variable_selector': ['start', 'is_not'],
'value': 'ab'
},
{
'comparison_operator': 'empty',
'variable_selector': ['start', 'empty'],
'value': 'ab'
},
{
'comparison_operator': 'not empty',
'variable_selector': ['start', 'not_empty'],
'value': 'ab'
},
{
'comparison_operator': '=',
'variable_selector': ['start', 'equals'],
'value': '22'
},
{
'comparison_operator': '≠',
'variable_selector': ['start', 'not_equals'],
'value': '22'
},
{
'comparison_operator': '>',
'variable_selector': ['start', 'greater_than'],
'value': '22'
},
{
'comparison_operator': '<',
'variable_selector': ['start', 'less_than'],
'value': '22'
},
{
'comparison_operator': '≥',
'variable_selector': ['start', 'greater_than_or_equal'],
'value': '22'
},
{
'comparison_operator': '≤',
'variable_selector': ['start', 'less_than_or_equal'],
'value': '22'
},
{
'comparison_operator': 'null',
'variable_selector': ['start', 'null']
},
{
'comparison_operator': 'not null',
'variable_selector': ['start', 'not_null']
},
]
}
}
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
pool.add(['start', 'contains'], 'cabcde')
pool.add(['start', 'not_contains'], 'zacde')
pool.add(['start', 'start_with'], 'abc')
pool.add(['start', 'end_with'], 'zzab')
pool.add(['start', 'is'], 'ab')
pool.add(['start', 'is_not'], 'aab')
pool.add(['start', 'empty'], '')
pool.add(['start', 'not_empty'], 'aaa')
pool.add(['start', 'equals'], 22)
pool.add(['start', 'not_equals'], 23)
pool.add(['start', 'greater_than'], 23)
pool.add(['start', 'less_than'], 21)
pool.add(['start', 'greater_than_or_equal'], 22)
pool.add(['start', 'less_than_or_equal'], 21)
pool.add(['start', 'not_null'], '1212')
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["ab", "def"])
pool.add(["start", "array_not_contains"], ["ac", "def"])
pool.add(["start", "contains"], "cabcde")
pool.add(["start", "not_contains"], "zacde")
pool.add(["start", "start_with"], "abc")
pool.add(["start", "end_with"], "zzab")
pool.add(["start", "is"], "ab")
pool.add(["start", "is_not"], "aab")
pool.add(["start", "empty"], "")
pool.add(["start", "not_empty"], "aaa")
pool.add(["start", "equals"], 22)
pool.add(["start", "not_equals"], 23)
pool.add(["start", "greater_than"], 23)
pool.add(["start", "less_than"], 21)
pool.add(["start", "greater_than_or_equal"], 22)
pool.add(["start", "less_than_or_equal"], 21)
pool.add(["start", "not_null"], "1212")
# Mock db.session.close()
db.session.close = MagicMock()
@@ -147,46 +94,47 @@ def test_execute_if_else_result_true():
result = node._run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is True
assert result.outputs["result"] is True
def test_execute_if_else_result_false():
node = IfElseNode(
tenant_id='1',
app_id='1',
workflow_id='1',
user_id='1',
tenant_id="1",
app_id="1",
workflow_id="1",
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'if-else',
'data': {
'title': '123',
'type': 'if-else',
'logical_operator': 'or',
'conditions': [
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
'comparison_operator': 'contains',
'variable_selector': ['start', 'array_contains'],
'value': 'ab'
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
'comparison_operator': 'not contains',
'variable_selector': ['start', 'array_not_contains'],
'value': 'ab'
}
]
}
}
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariableKey.FILES: [],
SystemVariableKey.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
pool = VariablePool(
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
user_inputs={},
environment_variables=[],
)
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
# Mock db.session.close()
db.session.close = MagicMock()
@@ -195,4 +143,4 @@ def test_execute_if_else_result_false():
result = node._run(pool)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is False
assert result.outputs["result"] is False

View File

@@ -8,41 +8,41 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
DEFAULT_NODE_ID = 'node_id'
DEFAULT_NODE_ID = "node_id"
def test_overwrite_string_variable():
conversation_variable = StringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value='the first value',
name="test_conversation_variable",
value="the first value",
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
name="test_string_variable",
value="the second value",
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
tenant_id="tenant_id",
app_id="app_id",
workflow_id="workflow_id",
user_id="user_id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.OVER_WRITE.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
"id": "node_id",
"data": {
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -52,48 +52,48 @@ def test_overwrite_string_variable():
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.value == 'the second value'
assert got.to_object() == 'the second value'
assert got.value == "the second value"
assert got.to_object() == "the second value"
def test_append_variable_to_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
name="test_conversation_variable",
value=["the first value"],
)
input_variable = StringVariable(
id=str(uuid4()),
name='test_string_variable',
value='the second value',
name="test_string_variable",
value="the second value",
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
tenant_id="tenant_id",
app_id="app_id",
workflow_id="workflow_id",
user_id="user_id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.APPEND.value,
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
"id": "node_id",
"data": {
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -103,41 +103,41 @@ def test_append_variable_to_array():
input_variable,
)
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
node.run(variable_pool)
mock_run.assert_called_once()
got = variable_pool.get(['conversation', conversation_variable.name])
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == ['the first value', 'the second value']
assert got.to_object() == ["the first value", "the second value"]
def test_clear_array():
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name='test_conversation_variable',
value=['the first value'],
name="test_conversation_variable",
value=["the first value"],
)
node = VariableAssignerNode(
tenant_id='tenant_id',
app_id='app_id',
workflow_id='workflow_id',
user_id='user_id',
tenant_id="tenant_id",
app_id="app_id",
workflow_id="workflow_id",
user_id="user_id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
config={
'id': 'node_id',
'data': {
'assigned_variable_selector': ['conversation', conversation_variable.name],
'write_mode': WriteMode.CLEAR.value,
'input_variable_selector': [],
"id": "node_id",
"data": {
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
@@ -145,6 +145,6 @@ def test_clear_array():
node.run(variable_pool)
got = variable_pool.get(['conversation', conversation_variable.name])
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
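Note: a compact sketch of the three write modes these tests walk through — a hypothetical helper with the mode strings assumed to mirror WriteMode's values; the real node additionally persists the result via update_conversation_variable, which is why the tests mock it.
def apply_write_mode_sketch(mode: str, current, incoming):
    if mode == "over-write":  # replace the stored value wholesale
        return incoming
    if mode == "append":  # extend an array variable with the new item
        return [*current, incoming]
    if mode == "clear":  # reset an array variable to empty
        return []
    raise ValueError(f"unknown write mode: {mode}")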

View File

@@ -3,50 +3,46 @@ import pandas as pd
def test_pandas_csv(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data = {'col1': [1, 2.2, -3.3, 4.0, 5],
'col2': ['A', 'B', 'C', 'D', 'E']}
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
# write to csv file
csv_file_path = tmp_path.joinpath('example.csv')
csv_file_path = tmp_path.joinpath("example.csv")
df1.to_csv(csv_file_path, index=False)
# read from csv file
df2 = pd.read_csv(csv_file_path, on_bad_lines='skip')
assert df2[df2.columns[0]].to_list() == data['col1']
assert df2[df2.columns[1]].to_list() == data['col2']
df2 = pd.read_csv(csv_file_path, on_bad_lines="skip")
assert df2[df2.columns[0]].to_list() == data["col1"]
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data = {'col1': [1, 2.2, -3.3, 4.0, 5],
'col2': ['A', 'B', 'C', 'D', 'E']}
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data)
# write to xlsx file
xlsx_file_path = tmp_path.joinpath('example.xlsx')
xlsx_file_path = tmp_path.joinpath("example.xlsx")
df1.to_excel(xlsx_file_path, index=False)
# read from xlsx file
df2 = pd.read_excel(xlsx_file_path)
assert df2[df2.columns[0]].to_list() == data['col1']
assert df2[df2.columns[1]].to_list() == data['col2']
assert df2[df2.columns[0]].to_list() == data["col1"]
assert df2[df2.columns[1]].to_list() == data["col2"]
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
monkeypatch.chdir(tmp_path)
data1 = {'col1': [1, 2, 3, 4, 5],
'col2': ['A', 'B', 'C', 'D', 'E']}
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
df1 = pd.DataFrame(data1)
data2 = {'col1': [6, 7, 8, 9, 10],
'col2': ['F', 'G', 'H', 'I', 'J']}
data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]}
df2 = pd.DataFrame(data2)
# write to xlsx file with sheets
xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx')
sheet1 = 'Sheet1'
sheet2 = 'Sheet2'
xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx")
sheet1 = "Sheet1"
sheet2 = "Sheet2"
with pd.ExcelWriter(xlsx_file_path) as excel_writer:
df1.to_excel(excel_writer, sheet_name=sheet1, index=False)
df2.to_excel(excel_writer, sheet_name=sheet2, index=False)
@@ -54,9 +50,9 @@ def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
# read from xlsx file with sheets
with pd.ExcelFile(xlsx_file_path) as excel_file:
df1 = pd.read_excel(excel_file, sheet_name=sheet1)
assert df1[df1.columns[0]].to_list() == data1['col1']
assert df1[df1.columns[1]].to_list() == data1['col2']
assert df1[df1.columns[0]].to_list() == data1["col1"]
assert df1[df1.columns[1]].to_list() == data1["col2"]
df2 = pd.read_excel(excel_file, sheet_name=sheet2)
assert df2[df2.columns[0]].to_list() == data2['col1']
assert df2[df2.columns[1]].to_list() == data2['col2']
assert df2[df2.columns[0]].to_list() == data2["col1"]
assert df2[df2.columns[1]].to_list() == data2["col2"]

View File

@@ -15,7 +15,7 @@ def test_gmpy2_pkcs10aep_cipher() -> None:
private_rsa_key = RSA.import_key(private_key)
private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)
raw_text = 'raw_text'
raw_text = "raw_text"
raw_text_bytes = raw_text.encode()
# RSA encryption by public key and decryption by private key

View File

@@ -3,21 +3,21 @@ from yarl import URL
def test_yarl_urls():
expected_1 = 'https://dify.ai/api'
assert str(URL('https://dify.ai') / 'api') == expected_1
assert str(URL('https://dify.ai/') / 'api') == expected_1
expected_1 = "https://dify.ai/api"
assert str(URL("https://dify.ai") / "api") == expected_1
assert str(URL("https://dify.ai/") / "api") == expected_1
expected_2 = 'http://dify.ai:12345/api'
assert str(URL('http://dify.ai:12345') / 'api') == expected_2
assert str(URL('http://dify.ai:12345/') / 'api') == expected_2
expected_2 = "http://dify.ai:12345/api"
assert str(URL("http://dify.ai:12345") / "api") == expected_2
assert str(URL("http://dify.ai:12345/") / "api") == expected_2
expected_3 = 'https://dify.ai/api/v1'
assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
assert str(URL('https://dify.ai') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/api') / 'v1') == expected_3
assert str(URL('https://dify.ai/api/') / 'v1') == expected_3
expected_3 = "https://dify.ai/api/v1"
assert str(URL("https://dify.ai") / "api" / "v1") == expected_3
assert str(URL("https://dify.ai") / "api/v1") == expected_3
assert str(URL("https://dify.ai/") / "api/v1") == expected_3
assert str(URL("https://dify.ai/api") / "v1") == expected_3
assert str(URL("https://dify.ai/api/") / "v1") == expected_3
with pytest.raises(ValueError) as e1:
str(URL('https://dify.ai') / '/api')
str(URL("https://dify.ai") / "/api")
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"

View File

@@ -2,13 +2,13 @@ from models.account import TenantAccountRole
def test_account_is_privileged_role() -> None:
assert TenantAccountRole.ADMIN == 'admin'
assert TenantAccountRole.OWNER == 'owner'
assert TenantAccountRole.EDITOR == 'editor'
assert TenantAccountRole.NORMAL == 'normal'
assert TenantAccountRole.ADMIN == "admin"
assert TenantAccountRole.OWNER == "owner"
assert TenantAccountRole.EDITOR == "editor"
assert TenantAccountRole.NORMAL == "normal"
assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER)
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL)
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR)
assert not TenantAccountRole.is_privileged_role('')
assert not TenantAccountRole.is_privileged_role("")

View File

@@ -7,19 +7,19 @@ from models import ConversationVariable
def test_from_variable_and_to_variable():
variable = factory.build_variable_from_mapping(
{
'id': str(uuid4()),
'name': 'name',
'value_type': SegmentType.OBJECT,
'value': {
'key': {
'key': 'value',
"id": str(uuid4()),
"name": "name",
"value_type": SegmentType.OBJECT,
"value": {
"key": {
"key": "value",
}
},
}
)
conversation_variable = ConversationVariable.from_variable(
app_id='app_id', conversation_id='conversation_id', variable=variable
app_id="app_id", conversation_id="conversation_id", variable=variable
)
assert conversation_variable.to_variable() == variable

View File

@@ -8,30 +8,30 @@ from models.workflow import Workflow
def test_environment_variables():
contexts.tenant_id.set('tenant_id')
contexts.tenant_id.set("tenant_id")
# Create a Workflow instance
workflow = Workflow(
tenant_id='tenant_id',
app_id='app_id',
type='workflow',
version='draft',
graph='{}',
features='{}',
created_by='account_id',
tenant_id="tenant_id",
app_id="app_id",
type="workflow",
version="draft",
graph="{}",
features="{}",
created_by="account_id",
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())})
variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())})
variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())})
variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())})
variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())})
variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())})
variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())})
with (
mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
):
# Set the environment_variables property of the Workflow instance
variables = [variable1, variable2, variable3, variable4]
@@ -42,30 +42,30 @@ def test_environment_variables():
def test_update_environment_variables():
contexts.tenant_id.set('tenant_id')
contexts.tenant_id.set("tenant_id")
# Create a Workflow instance
workflow = Workflow(
tenant_id='tenant_id',
app_id='app_id',
type='workflow',
version='draft',
graph='{}',
features='{}',
created_by='account_id',
tenant_id="tenant_id",
app_id="app_id",
type="workflow",
version="draft",
graph="{}",
features="{}",
created_by="account_id",
environment_variables=[],
conversation_variables=[],
)
# Create some EnvironmentVariable instances
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())})
variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())})
variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())})
variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())})
variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())})
variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())})
variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())})
with (
mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
):
variables = [variable1, variable2, variable3, variable4]
@@ -76,28 +76,28 @@ def test_update_environment_variables():
# Update the name of variable3 and keep the value as it is
variables[2] = variable3.model_copy(
update={
'name': 'new name',
'value': HIDDEN_VALUE,
"name": "new name",
"value": HIDDEN_VALUE,
}
)
workflow.environment_variables = variables
assert workflow.environment_variables[2].name == 'new name'
assert workflow.environment_variables[2].name == "new name"
assert workflow.environment_variables[2].value == variable3.value
def test_to_dict():
contexts.tenant_id.set('tenant_id')
contexts.tenant_id.set("tenant_id")
# Create a Workflow instance
workflow = Workflow(
tenant_id='tenant_id',
app_id='app_id',
type='workflow',
version='draft',
graph='{}',
features='{}',
created_by='account_id',
tenant_id="tenant_id",
app_id="app_id",
type="workflow",
version="draft",
graph="{}",
features="{}",
created_by="account_id",
environment_variables=[],
conversation_variables=[],
)
@@ -105,19 +105,19 @@ def test_to_dict():
# Create some EnvironmentVariable instances
with (
mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
):
# Set the environment_variables property of the Workflow instance
workflow.environment_variables = [
SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}),
StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}),
SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}),
StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}),
]
workflow_dict = workflow.to_dict()
assert workflow_dict['environment_variables'][0]['value'] == ''
assert workflow_dict['environment_variables'][1]['value'] == 'text'
assert workflow_dict["environment_variables"][0]["value"] == ""
assert workflow_dict["environment_variables"][1]["value"] == "text"
workflow_dict = workflow.to_dict(include_secret=True)
assert workflow_dict['environment_variables'][0]['value'] == 'secret'
assert workflow_dict['environment_variables'][1]['value'] == 'text'
assert workflow_dict["environment_variables"][0]["value"] == "secret"
assert workflow_dict["environment_variables"][1]["value"] == "text"

View File

@@ -83,18 +83,12 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable",
type="api",
config={
"api_based_extension_id": api_based_extension_id
}
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model,
variables=default_variables,
external_data_variables=external_data_variables
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
@@ -105,10 +99,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {
"type": "bearer",
"api_key": "api_key"
}
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
@@ -153,18 +144,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
external_data_variables = [
ExternalDataVariableEntity(
variable="external_variable",
type="api",
config={
"api_based_extension_id": api_based_extension_id
}
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
)
]
nodes, _ = workflow_converter._convert_to_http_request_node(
app_model=app_model,
variables=default_variables,
external_data_variables=external_data_variables
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
)
assert len(nodes) == 2
@@ -175,10 +160,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
assert http_request_node["data"]["method"] == "post"
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
assert http_request_node["data"]["authorization"]["type"] == "api-key"
assert http_request_node["data"]["authorization"]["config"] == {
"type": "bearer",
"api_key": "api_key"
}
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
assert http_request_node["data"]["body"]["type"] == "json"
body_data = http_request_node["data"]["body"]["data"]
@@ -207,37 +189,25 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot():
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={
'reranking_provider_name': 'cohere',
'reranking_model_name': 'rerank-english-v2.0'
},
reranking_enabled=True
)
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(
provider='openai',
model='gpt-4',
mode='chat',
parameters={},
stop=[]
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode,
dataset_config=dataset_config,
model_config=model_config
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["sys", "query"]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert (node["data"]["retrieval_mode"]
== dataset_config.retrieve_config.retrieve_strategy.value)
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
@@ -251,37 +221,25 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app():
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
top_k=5,
score_threshold=0.8,
reranking_model={
'reranking_provider_name': 'cohere',
'reranking_model_name': 'rerank-english-v2.0'
},
reranking_enabled=True
)
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
reranking_enabled=True,
),
)
model_config = ModelConfigEntity(
provider='openai',
model='gpt-4',
mode='chat',
parameters={},
stop=[]
)
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
new_app_mode=new_app_mode,
dataset_config=dataset_config,
model_config=model_config
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
)
assert node["data"]["type"] == "knowledge-retrieval"
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
assert (node["data"]["retrieval_mode"]
== dataset_config.retrieve_config.retrieve_strategy.value)
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
assert node["data"]["multiple_retrieval_config"] == {
"top_k": dataset_config.retrieve_config.top_k,
"score_threshold": dataset_config.retrieve_config.score_threshold,
"reranking_model": dataset_config.retrieve_config.reranking_model
"reranking_model": dataset_config.retrieve_config.reranking_model,
}
@@ -293,14 +251,12 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [
start_node
],
"edges": [] # no need
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
@@ -308,7 +264,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}."
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
@@ -316,17 +272,17 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]['name'] == model
assert llm_node["data"]['model']["mode"] == model_mode.value
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
for v in default_variables:
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n'
assert llm_node["data"]['context']['enabled'] is False
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
@@ -337,14 +293,12 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [
start_node
],
"edges": [] # no need
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
@@ -352,7 +306,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}."
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
)
llm_node = workflow_converter._convert_to_llm_node(
@@ -360,17 +314,17 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]['name'] == model
assert llm_node["data"]['model']["mode"] == model_mode.value
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
template = prompt_template.simple_prompt_template
for v in default_variables:
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
assert llm_node["data"]["prompt_template"]['text'] == template + '\n'
assert llm_node["data"]['context']['enabled'] is False
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
assert llm_node["data"]["context"]["enabled"] is False
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
@@ -381,14 +335,12 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [
start_node
],
"edges": [] # no need
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
@@ -396,12 +348,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
prompt_template = PromptTemplateEntity(
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[
AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM),
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
])
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
messages=[
AdvancedChatMessageEntity(
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
role=PromptMessageRole.SYSTEM,
),
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
]
),
)
llm_node = workflow_converter._convert_to_llm_node(
@@ -409,18 +365,18 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]['name'] == model
assert llm_node["data"]['model']["mode"] == model_mode.value
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], list)
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
template = prompt_template.advanced_chat_prompt_template.messages[0].text
for v in default_variables:
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
assert llm_node["data"]["prompt_template"][0]['text'] == template
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"][0]["text"] == template
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
@@ -431,14 +387,12 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
workflow_converter = WorkflowConverter()
start_node = workflow_converter._convert_to_start_node(default_variables)
graph = {
"nodes": [
start_node
],
"edges": [] # no need
"nodes": [start_node],
"edges": [], # no need
}
model_config_mock = MagicMock(spec=ModelConfigEntity)
model_config_mock.provider = 'openai'
model_config_mock.provider = "openai"
model_config_mock.model = model
model_config_mock.mode = model_mode.value
model_config_mock.parameters = {}
@@ -448,12 +402,9 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n"
"Human: hi\nAssistant: ",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
user="Human",
assistant="Assistant"
)
)
"Human: hi\nAssistant: ",
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
),
)
llm_node = workflow_converter._convert_to_llm_node(
@@ -461,14 +412,14 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
new_app_mode=new_app_mode,
model_config=model_config_mock,
graph=graph,
prompt_template=prompt_template
prompt_template=prompt_template,
)
assert llm_node["data"]["type"] == "llm"
assert llm_node["data"]["model"]['name'] == model
assert llm_node["data"]['model']["mode"] == model_mode.value
assert llm_node["data"]["model"]["name"] == model
assert llm_node["data"]["model"]["mode"] == model_mode.value
assert isinstance(llm_node["data"]["prompt_template"], dict)
template = prompt_template.advanced_completion_prompt_template.prompt
for v in default_variables:
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
assert llm_node["data"]["prompt_template"]['text'] == template
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
assert llm_node["data"]["prompt_template"]["text"] == template

View File

@@ -8,8 +8,9 @@ from core.helper.position_helper import get_position_map, is_filtered, pin_posit
@pytest.fixture
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions.yaml").write_text(dedent(
"""\
tmp_path.joinpath("example_positions.yaml").write_text(
dedent(
"""\
- first
- second
# - commented
@@ -17,57 +18,54 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
- 9999999999999
- forth
"""))
"""
)
)
return str(tmp_path)
@pytest.fixture
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent(
"""\
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(
dedent(
"""\
# - commented1
# - commented2
-
-
"""))
"""
)
)
return str(tmp_path)
def test_position_helper(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
file_name='example_positions.yaml')
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
assert len(position_map) == 4
assert position_map == {
'first': 0,
'second': 1,
'third': 2,
'forth': 3,
"first": 0,
"second": 1,
"third": 2,
"forth": 3,
}
def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
position_map = get_position_map(
folder_path=prepare_empty_commented_positions_yaml,
file_name='example_positions_all_commented.yaml')
folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml"
)
assert position_map == {}
def test_excluded_position_data(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
file_name='example_positions.yaml'
)
pin_list = ['forth', 'first']
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
pin_list = ["forth", "first"]
include_set = set()
exclude_set = {'9999999999999'}
exclude_set = {"9999999999999"}
position_map = pin_position_map(
original_position_map=position_map,
pin_list=pin_list
)
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
data = [
"forth",
@@ -90,22 +88,16 @@ def test_excluded_position_data(prepare_example_positions_yaml):
)
# assert the result in the correct order
assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"]
def test_included_position_data(prepare_example_positions_yaml):
position_map = get_position_map(
folder_path=prepare_example_positions_yaml,
file_name='example_positions.yaml'
)
pin_list = ['forth', 'first']
include_set = {'forth', 'first'}
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
pin_list = ["forth", "first"]
include_set = {"forth", "first"}
exclude_set = {}
position_map = pin_position_map(
original_position_map=position_map,
pin_list=pin_list
)
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
data = [
"forth",
@@ -128,4 +120,4 @@ def test_included_position_data(prepare_example_positions_yaml):
)
# assert the result in the correct order
assert sorted_data == ['forth', 'first']
assert sorted_data == ["forth", "first"]

View File

@@ -5,17 +5,18 @@ from yaml import YAMLError
from core.tools.utils.yaml_utils import load_yaml_file
EXAMPLE_YAML_FILE = 'example_yaml.yaml'
INVALID_YAML_FILE = 'invalid_yaml.yaml'
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
EXAMPLE_YAML_FILE = "example_yaml.yaml"
INVALID_YAML_FILE = "invalid_yaml.yaml"
NON_EXISTING_YAML_FILE = "non_existing_file.yaml"
@pytest.fixture
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
file_path.write_text(dedent(
"""\
file_path.write_text(
dedent(
"""\
address:
city: Example City
country: Example Country
@@ -26,7 +27,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
- Java
- C++
empty_key:
"""))
"""
)
)
return str(file_path)
@@ -34,8 +37,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
file_path.write_text(dedent(
"""\
file_path.write_text(
dedent(
"""\
address:
city: Example City
country: Example Country
@@ -45,13 +49,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
- Python
- Java
- C++
"""))
"""
)
)
return str(file_path)
def test_load_yaml_non_existing_file():
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
assert load_yaml_file(file_path='') == {}
assert load_yaml_file(file_path="") == {}
with pytest.raises(FileNotFoundError):
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
@@ -60,12 +66,12 @@ def test_load_yaml_non_existing_file():
def test_load_valid_yaml_file(prepare_example_yaml_file):
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
assert len(yaml_data) > 0
assert yaml_data['age'] == 30
assert yaml_data['gender'] == 'male'
assert yaml_data['address']['city'] == 'Example City'
assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
assert yaml_data.get('empty_key') is None
assert yaml_data.get('non_existed_key') is None
assert yaml_data["age"] == 30
assert yaml_data["gender"] == "male"
assert yaml_data["address"]["city"] == "Example City"
assert set(yaml_data["languages"]) == {"Python", "Java", "C++"}
assert yaml_data.get("empty_key") is None
assert yaml_data.get("non_existed_key") is None
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):