Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -6,18 +6,21 @@ from flask import Flask
|
||||
|
||||
from configs.app_config import DifyConfig
|
||||
|
||||
EXAMPLE_ENV_FILENAME = '.env'
|
||||
EXAMPLE_ENV_FILENAME = ".env"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_env_file(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME)
|
||||
file_path.write_text(dedent(
|
||||
"""
|
||||
file_path.write_text(
|
||||
dedent(
|
||||
"""
|
||||
CONSOLE_API_URL=https://example.com
|
||||
CONSOLE_WEB_URL=https://example.com
|
||||
"""))
|
||||
"""
|
||||
)
|
||||
)
|
||||
return str(file_path)
|
||||
|
||||
|
||||
@@ -29,7 +32,7 @@ def test_dify_config_undefined_entry(example_env_file):
|
||||
# entries not defined in app settings
|
||||
with pytest.raises(TypeError):
|
||||
# TypeError: 'AppSettings' object is not subscriptable
|
||||
assert config['LOG_LEVEL'] == 'INFO'
|
||||
assert config["LOG_LEVEL"] == "INFO"
|
||||
|
||||
|
||||
def test_dify_config(example_env_file):
|
||||
@@ -37,10 +40,10 @@ def test_dify_config(example_env_file):
|
||||
config = DifyConfig(_env_file=example_env_file)
|
||||
|
||||
# constant values
|
||||
assert config.COMMIT_SHA == ''
|
||||
assert config.COMMIT_SHA == ""
|
||||
|
||||
# default values
|
||||
assert config.EDITION == 'SELF_HOSTED'
|
||||
assert config.EDITION == "SELF_HOSTED"
|
||||
assert config.API_COMPRESSION_ENABLED is False
|
||||
assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
|
||||
|
||||
@@ -48,36 +51,36 @@ def test_dify_config(example_env_file):
|
||||
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
|
||||
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
|
||||
def test_flask_configs(example_env_file):
|
||||
flask_app = Flask('app')
|
||||
flask_app = Flask("app")
|
||||
# clear system environment variables
|
||||
os.environ.clear()
|
||||
flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore
|
||||
config = flask_app.config
|
||||
|
||||
# configs read from pydantic-settings
|
||||
assert config['LOG_LEVEL'] == 'INFO'
|
||||
assert config['COMMIT_SHA'] == ''
|
||||
assert config['EDITION'] == 'SELF_HOSTED'
|
||||
assert config['API_COMPRESSION_ENABLED'] is False
|
||||
assert config['SENTRY_TRACES_SAMPLE_RATE'] == 1.0
|
||||
assert config['TESTING'] == False
|
||||
assert config["LOG_LEVEL"] == "INFO"
|
||||
assert config["COMMIT_SHA"] == ""
|
||||
assert config["EDITION"] == "SELF_HOSTED"
|
||||
assert config["API_COMPRESSION_ENABLED"] is False
|
||||
assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0
|
||||
assert config["TESTING"] == False
|
||||
|
||||
# value from env file
|
||||
assert config['CONSOLE_API_URL'] == 'https://example.com'
|
||||
assert config["CONSOLE_API_URL"] == "https://example.com"
|
||||
# fallback to alias choices value as CONSOLE_API_URL
|
||||
assert config['FILES_URL'] == 'https://example.com'
|
||||
assert config["FILES_URL"] == "https://example.com"
|
||||
|
||||
assert config['SQLALCHEMY_DATABASE_URI'] == 'postgresql://postgres:@localhost:5432/dify'
|
||||
assert config['SQLALCHEMY_ENGINE_OPTIONS'] == {
|
||||
'connect_args': {
|
||||
'options': '-c timezone=UTC',
|
||||
assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify"
|
||||
assert config["SQLALCHEMY_ENGINE_OPTIONS"] == {
|
||||
"connect_args": {
|
||||
"options": "-c timezone=UTC",
|
||||
},
|
||||
'max_overflow': 10,
|
||||
'pool_pre_ping': False,
|
||||
'pool_recycle': 3600,
|
||||
'pool_size': 30,
|
||||
"max_overflow": 10,
|
||||
"pool_pre_ping": False,
|
||||
"pool_recycle": 3600,
|
||||
"pool_size": 30,
|
||||
}
|
||||
|
||||
assert config['CONSOLE_WEB_URL']=='https://example.com'
|
||||
assert config['CONSOLE_CORS_ALLOW_ORIGINS']==['https://example.com']
|
||||
assert config['WEB_API_CORS_ALLOW_ORIGINS'] == ['*']
|
||||
assert config["CONSOLE_WEB_URL"] == "https://example.com"
|
||||
assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"]
|
||||
assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"]
|
||||
|
||||
@@ -17,31 +17,31 @@ from core.app.segments.exc import VariableError
|
||||
|
||||
|
||||
def test_string_variable():
|
||||
test_data = {'value_type': 'string', 'name': 'test_text', 'value': 'Hello, World!'}
|
||||
test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, StringVariable)
|
||||
|
||||
|
||||
def test_integer_variable():
|
||||
test_data = {'value_type': 'number', 'name': 'test_int', 'value': 42}
|
||||
test_data = {"value_type": "number", "name": "test_int", "value": 42}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, IntegerVariable)
|
||||
|
||||
|
||||
def test_float_variable():
|
||||
test_data = {'value_type': 'number', 'name': 'test_float', 'value': 3.14}
|
||||
test_data = {"value_type": "number", "name": "test_float", "value": 3.14}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, FloatVariable)
|
||||
|
||||
|
||||
def test_secret_variable():
|
||||
test_data = {'value_type': 'secret', 'name': 'test_secret', 'value': 'secret_value'}
|
||||
test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"}
|
||||
result = factory.build_variable_from_mapping(test_data)
|
||||
assert isinstance(result, SecretVariable)
|
||||
|
||||
|
||||
def test_invalid_value_type():
|
||||
test_data = {'value_type': 'unknown', 'name': 'test_invalid', 'value': 'value'}
|
||||
test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"}
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(test_data)
|
||||
|
||||
@@ -49,51 +49,51 @@ def test_invalid_value_type():
|
||||
def test_build_a_blank_string():
|
||||
result = factory.build_variable_from_mapping(
|
||||
{
|
||||
'value_type': 'string',
|
||||
'name': 'blank',
|
||||
'value': '',
|
||||
"value_type": "string",
|
||||
"name": "blank",
|
||||
"value": "",
|
||||
}
|
||||
)
|
||||
assert isinstance(result, StringVariable)
|
||||
assert result.value == ''
|
||||
assert result.value == ""
|
||||
|
||||
|
||||
def test_build_a_object_variable_with_none_value():
|
||||
var = factory.build_segment(
|
||||
{
|
||||
'key1': None,
|
||||
"key1": None,
|
||||
}
|
||||
)
|
||||
assert isinstance(var, ObjectSegment)
|
||||
assert var.value['key1'] is None
|
||||
assert var.value["key1"] is None
|
||||
|
||||
|
||||
def test_object_variable():
|
||||
mapping = {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'object',
|
||||
'name': 'test_object',
|
||||
'description': 'Description of the variable.',
|
||||
'value': {
|
||||
'key1': 'text',
|
||||
'key2': 2,
|
||||
"id": str(uuid4()),
|
||||
"value_type": "object",
|
||||
"name": "test_object",
|
||||
"description": "Description of the variable.",
|
||||
"value": {
|
||||
"key1": "text",
|
||||
"key2": 2,
|
||||
},
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
assert isinstance(variable, ObjectSegment)
|
||||
assert isinstance(variable.value['key1'], str)
|
||||
assert isinstance(variable.value['key2'], int)
|
||||
assert isinstance(variable.value["key1"], str)
|
||||
assert isinstance(variable.value["key2"], int)
|
||||
|
||||
|
||||
def test_array_string_variable():
|
||||
mapping = {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'array[string]',
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
'text',
|
||||
'text',
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[string]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
"text",
|
||||
"text",
|
||||
],
|
||||
}
|
||||
variable = factory.build_variable_from_mapping(mapping)
|
||||
@@ -104,11 +104,11 @@ def test_array_string_variable():
|
||||
|
||||
def test_array_number_variable():
|
||||
mapping = {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'array[number]',
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[number]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
1,
|
||||
2.0,
|
||||
],
|
||||
@@ -121,18 +121,18 @@ def test_array_number_variable():
|
||||
|
||||
def test_array_object_variable():
|
||||
mapping = {
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'array[object]',
|
||||
'name': 'test_array',
|
||||
'description': 'Description of the variable.',
|
||||
'value': [
|
||||
"id": str(uuid4()),
|
||||
"value_type": "array[object]",
|
||||
"name": "test_array",
|
||||
"description": "Description of the variable.",
|
||||
"value": [
|
||||
{
|
||||
'key1': 'text',
|
||||
'key2': 1,
|
||||
"key1": "text",
|
||||
"key2": 1,
|
||||
},
|
||||
{
|
||||
'key1': 'text',
|
||||
'key2': 1,
|
||||
"key1": "text",
|
||||
"key2": 1,
|
||||
},
|
||||
],
|
||||
}
|
||||
@@ -140,19 +140,19 @@ def test_array_object_variable():
|
||||
assert isinstance(variable, ArrayObjectVariable)
|
||||
assert isinstance(variable.value[0], dict)
|
||||
assert isinstance(variable.value[1], dict)
|
||||
assert isinstance(variable.value[0]['key1'], str)
|
||||
assert isinstance(variable.value[0]['key2'], int)
|
||||
assert isinstance(variable.value[1]['key1'], str)
|
||||
assert isinstance(variable.value[1]['key2'], int)
|
||||
assert isinstance(variable.value[0]["key1"], str)
|
||||
assert isinstance(variable.value[0]["key2"], int)
|
||||
assert isinstance(variable.value[1]["key1"], str)
|
||||
assert isinstance(variable.value[1]["key2"], int)
|
||||
|
||||
|
||||
def test_variable_cannot_large_than_5_kb():
|
||||
with pytest.raises(VariableError):
|
||||
factory.build_variable_from_mapping(
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'value_type': 'string',
|
||||
'name': 'test_text',
|
||||
'value': 'a' * 1024 * 6,
|
||||
"id": str(uuid4()),
|
||||
"value_type": "string",
|
||||
"name": "test_text",
|
||||
"value": "a" * 1024 * 6,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -7,20 +7,20 @@ from core.workflow.enums import SystemVariableKey
|
||||
def test_segment_group_to_text():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[
|
||||
SecretVariable(name='secret_key', value='fake-secret-key'),
|
||||
SecretVariable(name="secret_key", value="fake-secret-key"),
|
||||
],
|
||||
)
|
||||
variable_pool.add(('node_id', 'custom_query'), 'fake-user-query')
|
||||
variable_pool.add(("node_id", "custom_query"), "fake-user-query")
|
||||
template = (
|
||||
'Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}.'
|
||||
"Hello, {{#sys.user_id#}}! Your query is {{#node_id.custom_query#}}. And your key is {{#env.secret_key#}}."
|
||||
)
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
|
||||
assert segments_group.text == 'Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key.'
|
||||
assert segments_group.text == "Hello, fake-user-id! Your query is fake-user-query. And your key is fake-secret-key."
|
||||
assert (
|
||||
segments_group.log
|
||||
== f"Hello, fake-user-id! Your query is fake-user-query. And your key is {encrypter.obfuscated_token('fake-secret-key')}."
|
||||
@@ -33,22 +33,22 @@ def test_convert_constant_to_segment_group():
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
template = 'Hello, world!'
|
||||
template = "Hello, world!"
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
assert segments_group.text == 'Hello, world!'
|
||||
assert segments_group.log == 'Hello, world!'
|
||||
assert segments_group.text == "Hello, world!"
|
||||
assert segments_group.log == "Hello, world!"
|
||||
|
||||
|
||||
def test_convert_variable_to_segment_group():
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey('user_id'): 'fake-user-id',
|
||||
SystemVariableKey("user_id"): "fake-user-id",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
template = '{{#sys.user_id#}}'
|
||||
template = "{{#sys.user_id#}}"
|
||||
segments_group = parser.convert_template(template=template, variable_pool=variable_pool)
|
||||
assert segments_group.text == 'fake-user-id'
|
||||
assert segments_group.log == 'fake-user-id'
|
||||
assert segments_group.value == [StringSegment(value='fake-user-id')]
|
||||
assert segments_group.text == "fake-user-id"
|
||||
assert segments_group.log == "fake-user-id"
|
||||
assert segments_group.value == [StringSegment(value="fake-user-id")]
|
||||
|
||||
@@ -13,60 +13,60 @@ from core.app.segments import (
|
||||
|
||||
|
||||
def test_frozen_variables():
|
||||
var = StringVariable(name='text', value='text')
|
||||
var = StringVariable(name="text", value="text")
|
||||
with pytest.raises(ValidationError):
|
||||
var.value = 'new value'
|
||||
var.value = "new value"
|
||||
|
||||
int_var = IntegerVariable(name='integer', value=42)
|
||||
int_var = IntegerVariable(name="integer", value=42)
|
||||
with pytest.raises(ValidationError):
|
||||
int_var.value = 100
|
||||
|
||||
float_var = FloatVariable(name='float', value=3.14)
|
||||
float_var = FloatVariable(name="float", value=3.14)
|
||||
with pytest.raises(ValidationError):
|
||||
float_var.value = 2.718
|
||||
|
||||
secret_var = SecretVariable(name='secret', value='secret_value')
|
||||
secret_var = SecretVariable(name="secret", value="secret_value")
|
||||
with pytest.raises(ValidationError):
|
||||
secret_var.value = 'new_secret_value'
|
||||
secret_var.value = "new_secret_value"
|
||||
|
||||
|
||||
def test_variable_value_type_immutable():
|
||||
with pytest.raises(ValidationError):
|
||||
StringVariable(value_type=SegmentType.ARRAY_ANY, name='text', value='text')
|
||||
StringVariable(value_type=SegmentType.ARRAY_ANY, name="text", value="text")
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
StringVariable.model_validate({'value_type': 'not text', 'name': 'text', 'value': 'text'})
|
||||
StringVariable.model_validate({"value_type": "not text", "name": "text", "value": "text"})
|
||||
|
||||
var = IntegerVariable(name='integer', value=42)
|
||||
var = IntegerVariable(name="integer", value=42)
|
||||
with pytest.raises(ValidationError):
|
||||
IntegerVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
|
||||
|
||||
var = FloatVariable(name='float', value=3.14)
|
||||
var = FloatVariable(name="float", value=3.14)
|
||||
with pytest.raises(ValidationError):
|
||||
FloatVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
|
||||
|
||||
var = SecretVariable(name='secret', value='secret_value')
|
||||
var = SecretVariable(name="secret", value="secret_value")
|
||||
with pytest.raises(ValidationError):
|
||||
SecretVariable(value_type=SegmentType.ARRAY_ANY, name=var.name, value=var.value)
|
||||
|
||||
|
||||
def test_object_variable_to_object():
|
||||
var = ObjectVariable(
|
||||
name='object',
|
||||
name="object",
|
||||
value={
|
||||
'key1': {
|
||||
'key2': 'value2',
|
||||
"key1": {
|
||||
"key2": "value2",
|
||||
},
|
||||
'key2': ['value5_1', 42, {}],
|
||||
"key2": ["value5_1", 42, {}],
|
||||
},
|
||||
)
|
||||
|
||||
assert var.to_object() == {
|
||||
'key1': {
|
||||
'key2': 'value2',
|
||||
"key1": {
|
||||
"key2": "value2",
|
||||
},
|
||||
'key2': [
|
||||
'value5_1',
|
||||
"key2": [
|
||||
"value5_1",
|
||||
42,
|
||||
{},
|
||||
],
|
||||
@@ -74,11 +74,11 @@ def test_object_variable_to_object():
|
||||
|
||||
|
||||
def test_variable_to_object():
|
||||
var = StringVariable(name='text', value='text')
|
||||
assert var.to_object() == 'text'
|
||||
var = IntegerVariable(name='integer', value=42)
|
||||
var = StringVariable(name="text", value="text")
|
||||
assert var.to_object() == "text"
|
||||
var = IntegerVariable(name="integer", value=42)
|
||||
assert var.to_object() == 42
|
||||
var = FloatVariable(name='float', value=3.14)
|
||||
var = FloatVariable(name="float", value=3.14)
|
||||
assert var.to_object() == 3.14
|
||||
var = SecretVariable(name='secret', value='secret_value')
|
||||
assert var.to_object() == 'secret_value'
|
||||
var = SecretVariable(name="secret", value="secret_value")
|
||||
assert var.to_object() == "secret_value"
|
||||
|
||||
@@ -4,17 +4,17 @@ from unittest.mock import MagicMock, patch
|
||||
from core.helper.ssrf_proxy import SSRF_DEFAULT_MAX_RETRIES, STATUS_FORCELIST, make_request
|
||||
|
||||
|
||||
@patch('httpx.request')
|
||||
@patch("httpx.request")
|
||||
def test_successful_request(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
response = make_request('GET', 'http://example.com')
|
||||
response = make_request("GET", "http://example.com")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@patch('httpx.request')
|
||||
@patch("httpx.request")
|
||||
def test_retry_exceed_max_retries(mock_request):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
@@ -23,13 +23,13 @@ def test_retry_exceed_max_retries(mock_request):
|
||||
mock_request.side_effect = side_effects
|
||||
|
||||
try:
|
||||
make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||
make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES - 1)
|
||||
raise AssertionError("Expected Exception not raised")
|
||||
except Exception as e:
|
||||
assert str(e) == f"Reached maximum retries ({SSRF_DEFAULT_MAX_RETRIES - 1}) for URL http://example.com"
|
||||
|
||||
|
||||
@patch('httpx.request')
|
||||
@patch("httpx.request")
|
||||
def test_retry_logic_success(mock_request):
|
||||
side_effects = []
|
||||
|
||||
@@ -45,8 +45,8 @@ def test_retry_logic_success(mock_request):
|
||||
|
||||
mock_request.side_effect = side_effects
|
||||
|
||||
response = make_request('GET', 'http://example.com', max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
||||
response = make_request("GET", "http://example.com", max_retries=SSRF_DEFAULT_MAX_RETRIES)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_request.call_count == SSRF_DEFAULT_MAX_RETRIES + 1
|
||||
assert mock_request.call_args_list[0][1].get('method') == 'GET'
|
||||
assert mock_request.call_args_list[0][1].get("method") == "GET"
|
||||
|
||||
@@ -21,18 +21,18 @@ def test_max_chunks():
|
||||
def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return _MockTextEmbedding()
|
||||
|
||||
model = 'embedding-v1'
|
||||
model = "embedding-v1"
|
||||
credentials = {
|
||||
'api_key': 'xxxx',
|
||||
'secret_key': 'yyyy',
|
||||
"api_key": "xxxx",
|
||||
"secret_key": "yyyy",
|
||||
}
|
||||
embedding_model = WenxinTextEmbeddingModel()
|
||||
context_size = embedding_model._get_context_size(model, credentials)
|
||||
max_chunks = embedding_model._get_max_chunks(model, credentials)
|
||||
embedding_model._create_text_embedding = _create_text_embedding
|
||||
|
||||
texts = ['0123456789' for i in range(0, max_chunks * 2)]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
|
||||
texts = ["0123456789" for i in range(0, max_chunks * 2)]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
|
||||
assert len(result.embeddings) == max_chunks * 2
|
||||
|
||||
|
||||
@@ -41,16 +41,16 @@ def test_context_size():
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
||||
def mock_text(token_size: int) -> str:
|
||||
_text = "".join(['0' for i in range(token_size)])
|
||||
_text = "".join(["0" for i in range(token_size)])
|
||||
num_tokens = get_num_tokens_by_gpt2(_text)
|
||||
ratio = int(np.floor(len(_text) / num_tokens))
|
||||
m_text = "".join([_text for i in range(ratio)])
|
||||
return m_text
|
||||
|
||||
model = 'embedding-v1'
|
||||
model = "embedding-v1"
|
||||
credentials = {
|
||||
'api_key': 'xxxx',
|
||||
'secret_key': 'yyyy',
|
||||
"api_key": "xxxx",
|
||||
"secret_key": "yyyy",
|
||||
}
|
||||
embedding_model = WenxinTextEmbeddingModel()
|
||||
context_size = embedding_model._get_context_size(model, credentials)
|
||||
@@ -71,5 +71,5 @@ def test_context_size():
|
||||
assert get_num_tokens_by_gpt2(text) == context_size * 2
|
||||
|
||||
texts = [text]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, "test")
|
||||
assert result.usage.tokens == context_size
|
||||
|
||||
@@ -14,39 +14,24 @@ from models.model import Conversation
|
||||
|
||||
def test__get_completion_model_prompt_messages():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-3.5-turbo-instruct'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-3.5-turbo-instruct"
|
||||
|
||||
prompt_template = "Context:\n{{#context#}}\n\nHistories:\n{{#histories#}}\n\nyou are {{name}}."
|
||||
prompt_template_config = CompletionModelPromptTemplate(
|
||||
text=prompt_template
|
||||
)
|
||||
prompt_template_config = CompletionModelPromptTemplate(text=prompt_template)
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
role_prefix=MemoryConfig.RolePrefix(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
),
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
|
||||
window=MemoryConfig.WindowConfig(enabled=False),
|
||||
)
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
files = []
|
||||
context = "I am superman."
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Hello")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
|
||||
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
@@ -59,16 +44,22 @@ def test__get_completion_model_prompt_messages():
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 1
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format({
|
||||
"#context#": context,
|
||||
"#histories#": "\n".join([f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: "
|
||||
f"{prompt.content}" for prompt in history_prompt_messages]),
|
||||
**inputs,
|
||||
})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=prompt_template).format(
|
||||
{
|
||||
"#context#": context,
|
||||
"#histories#": "\n".join(
|
||||
[
|
||||
f"{'Human' if prompt.role.value == 'user' else 'Assistant'}: " f"{prompt.content}"
|
||||
for prompt in history_prompt_messages
|
||||
]
|
||||
),
|
||||
**inputs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
@@ -77,15 +68,9 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
files = []
|
||||
query = "Hi2."
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi1."),
|
||||
AssistantPromptMessage(content="Hello1!")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi1."), AssistantPromptMessage(content="Hello1!")]
|
||||
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
@@ -98,14 +83,14 @@ def test__get_chat_model_prompt_messages(get_chat_model_args):
|
||||
context=context,
|
||||
memory_config=memory_config,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 6
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
|
||||
{**inputs, "#context#": context}
|
||||
)
|
||||
assert prompt_messages[5].content == query
|
||||
|
||||
|
||||
@@ -124,14 +109,14 @@ def test__get_chat_model_prompt_messages_no_memory(get_chat_model_args):
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 3
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
|
||||
{**inputs, "#context#": context}
|
||||
)
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_args):
|
||||
@@ -148,7 +133,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
image_config={
|
||||
"detail": "high",
|
||||
}
|
||||
)
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -162,14 +147,14 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
context=context,
|
||||
memory_config=None,
|
||||
memory=None,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].role == PromptMessageRole.SYSTEM
|
||||
assert prompt_messages[0].content == PromptTemplateParser(
|
||||
template=messages[0].text
|
||||
).format({**inputs, "#context#": context})
|
||||
assert prompt_messages[0].content == PromptTemplateParser(template=messages[0].text).format(
|
||||
{**inputs, "#context#": context}
|
||||
)
|
||||
assert isinstance(prompt_messages[3].content, list)
|
||||
assert len(prompt_messages[3].content) == 2
|
||||
assert prompt_messages[3].content[1].data == files[0].url
|
||||
@@ -178,33 +163,20 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
@pytest.fixture
|
||||
def get_chat_model_args():
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-4'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-4"
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
window=MemoryConfig.WindowConfig(
|
||||
enabled=False
|
||||
)
|
||||
)
|
||||
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
|
||||
|
||||
prompt_messages = [
|
||||
ChatModelMessage(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", role=PromptMessageRole.SYSTEM
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hi.",
|
||||
role=PromptMessageRole.USER
|
||||
),
|
||||
ChatModelMessage(
|
||||
text="Hello!",
|
||||
role=PromptMessageRole.ASSISTANT
|
||||
)
|
||||
ChatModelMessage(text="Hi.", role=PromptMessageRole.USER),
|
||||
ChatModelMessage(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
|
||||
context = "I am superman."
|
||||
|
||||
|
||||
@@ -18,27 +18,28 @@ from models.model import Conversation
|
||||
|
||||
def test_get_prompt():
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content='System Template'),
|
||||
UserPromptMessage(content='User Query'),
|
||||
SystemPromptMessage(content="System Template"),
|
||||
UserPromptMessage(content="User Query"),
|
||||
]
|
||||
history_messages = [
|
||||
SystemPromptMessage(content='System Prompt 1'),
|
||||
UserPromptMessage(content='User Prompt 1'),
|
||||
AssistantPromptMessage(content='Assistant Thought 1'),
|
||||
ToolPromptMessage(content='Tool 1-1', name='Tool 1-1', tool_call_id='1'),
|
||||
ToolPromptMessage(content='Tool 1-2', name='Tool 1-2', tool_call_id='2'),
|
||||
SystemPromptMessage(content='System Prompt 2'),
|
||||
UserPromptMessage(content='User Prompt 2'),
|
||||
AssistantPromptMessage(content='Assistant Thought 2'),
|
||||
ToolPromptMessage(content='Tool 2-1', name='Tool 2-1', tool_call_id='3'),
|
||||
ToolPromptMessage(content='Tool 2-2', name='Tool 2-2', tool_call_id='4'),
|
||||
UserPromptMessage(content='User Prompt 3'),
|
||||
AssistantPromptMessage(content='Assistant Thought 3'),
|
||||
SystemPromptMessage(content="System Prompt 1"),
|
||||
UserPromptMessage(content="User Prompt 1"),
|
||||
AssistantPromptMessage(content="Assistant Thought 1"),
|
||||
ToolPromptMessage(content="Tool 1-1", name="Tool 1-1", tool_call_id="1"),
|
||||
ToolPromptMessage(content="Tool 1-2", name="Tool 1-2", tool_call_id="2"),
|
||||
SystemPromptMessage(content="System Prompt 2"),
|
||||
UserPromptMessage(content="User Prompt 2"),
|
||||
AssistantPromptMessage(content="Assistant Thought 2"),
|
||||
ToolPromptMessage(content="Tool 2-1", name="Tool 2-1", tool_call_id="3"),
|
||||
ToolPromptMessage(content="Tool 2-2", name="Tool 2-2", tool_call_id="4"),
|
||||
UserPromptMessage(content="User Prompt 3"),
|
||||
AssistantPromptMessage(content="Assistant Thought 3"),
|
||||
]
|
||||
|
||||
# use message number instead of token for testing
|
||||
def side_effect_get_num_tokens(*args):
|
||||
return len(args[2])
|
||||
|
||||
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
large_language_model_mock.get_num_tokens = MagicMock(side_effect=side_effect_get_num_tokens)
|
||||
|
||||
@@ -46,20 +47,17 @@ def test_get_prompt():
|
||||
provider_model_bundle_mock.model_type_instance = large_language_model_mock
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_mock.model = 'openai'
|
||||
model_config_mock.model = "openai"
|
||||
model_config_mock.credentials = {}
|
||||
model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
transform = AgentHistoryPromptTransform(
|
||||
model_config=model_config_mock,
|
||||
prompt_messages=prompt_messages,
|
||||
history_messages=history_messages,
|
||||
memory=memory
|
||||
memory=memory,
|
||||
)
|
||||
|
||||
max_token_limit = 5
|
||||
|
||||
@@ -12,19 +12,15 @@ from core.prompt.prompt_transform import PromptTransform
|
||||
def test__calculate_rest_token():
|
||||
model_schema_mock = MagicMock(spec=AIModelEntity)
|
||||
parameter_rule_mock = MagicMock(spec=ParameterRule)
|
||||
parameter_rule_mock.name = 'max_tokens'
|
||||
model_schema_mock.parameter_rules = [
|
||||
parameter_rule_mock
|
||||
]
|
||||
model_schema_mock.model_properties = {
|
||||
ModelPropertyKey.CONTEXT_SIZE: 62
|
||||
}
|
||||
parameter_rule_mock.name = "max_tokens"
|
||||
model_schema_mock.parameter_rules = [parameter_rule_mock]
|
||||
model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62}
|
||||
|
||||
large_language_model_mock = MagicMock(spec=LargeLanguageModel)
|
||||
large_language_model_mock.get_num_tokens.return_value = 6
|
||||
|
||||
provider_mock = MagicMock(spec=ProviderEntity)
|
||||
provider_mock.provider = 'openai'
|
||||
provider_mock.provider = "openai"
|
||||
|
||||
provider_configuration_mock = MagicMock(spec=ProviderConfiguration)
|
||||
provider_configuration_mock.provider = provider_mock
|
||||
@@ -35,11 +31,9 @@ def test__calculate_rest_token():
|
||||
provider_model_bundle_mock.configuration = provider_configuration_mock
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.model = 'gpt-4'
|
||||
model_config_mock.model = "gpt-4"
|
||||
model_config_mock.credentials = {}
|
||||
model_config_mock.parameters = {
|
||||
'max_tokens': 50
|
||||
}
|
||||
model_config_mock.parameters = {"max_tokens": 50}
|
||||
model_config_mock.model_schema = model_schema_mock
|
||||
model_config_mock.provider_model_bundle = provider_model_bundle_mock
|
||||
|
||||
@@ -49,8 +43,10 @@ def test__calculate_rest_token():
|
||||
rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock)
|
||||
|
||||
# Validate based on the mock configuration and expected logic
|
||||
expected_rest_tokens = (model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
- model_config_mock.parameters['max_tokens']
|
||||
- large_language_model_mock.get_num_tokens.return_value)
|
||||
expected_rest_tokens = (
|
||||
model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE]
|
||||
- model_config_mock.parameters["max_tokens"]
|
||||
- large_language_model_mock.get_num_tokens.return_value
|
||||
)
|
||||
assert rest_tokens == expected_rest_tokens
|
||||
assert rest_tokens == 6
|
||||
|
||||
@@ -19,12 +19,15 @@ def test_get_common_chat_app_prompt_template_with_pcqm():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=True,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['histories_prompt']
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"]
|
||||
+ pre_prompt
|
||||
+ "\n"
|
||||
+ prompt_rules["histories_prompt"]
|
||||
+ prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
|
||||
|
||||
|
||||
def test_get_baichuan_chat_app_prompt_template_with_pcqm():
|
||||
@@ -39,12 +42,15 @@ def test_get_baichuan_chat_app_prompt_template_with_pcqm():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=True,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['histories_prompt']
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#histories#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"]
|
||||
+ pre_prompt
|
||||
+ "\n"
|
||||
+ prompt_rules["histories_prompt"]
|
||||
+ prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#histories#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_completion_app_prompt_template_with_pcq():
|
||||
@@ -59,11 +65,11 @@ def test_get_common_completion_app_prompt_template_with_pcq():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_baichuan_completion_app_prompt_template_with_pcq():
|
||||
@@ -78,12 +84,12 @@ def test_get_baichuan_completion_app_prompt_template_with_pcq():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
print(prompt_template['prompt_template'].template)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ pre_prompt + '\n'
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
|
||||
print(prompt_template["prompt_template"].template)
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"] + pre_prompt + "\n" + prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_q():
|
||||
@@ -98,9 +104,9 @@ def test_get_common_chat_app_prompt_template_with_q():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == prompt_rules['query_prompt']
|
||||
assert prompt_template['special_variable_keys'] == ['#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == prompt_rules["query_prompt"]
|
||||
assert prompt_template["special_variable_keys"] == ["#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_cq():
|
||||
@@ -115,10 +121,11 @@ def test_get_common_chat_app_prompt_template_with_cq():
|
||||
query_in_prompt=True,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
assert prompt_template['prompt_template'].template == (prompt_rules['context_prompt']
|
||||
+ prompt_rules['query_prompt'])
|
||||
assert prompt_template['special_variable_keys'] == ['#context#', '#query#']
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
assert prompt_template["prompt_template"].template == (
|
||||
prompt_rules["context_prompt"] + prompt_rules["query_prompt"]
|
||||
)
|
||||
assert prompt_template["special_variable_keys"] == ["#context#", "#query#"]
|
||||
|
||||
|
||||
def test_get_common_chat_app_prompt_template_with_p():
|
||||
@@ -133,30 +140,25 @@ def test_get_common_chat_app_prompt_template_with_p():
|
||||
query_in_prompt=False,
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
assert prompt_template['prompt_template'].template == pre_prompt + '\n'
|
||||
assert prompt_template['custom_variable_keys'] == ['name']
|
||||
assert prompt_template['special_variable_keys'] == []
|
||||
assert prompt_template["prompt_template"].template == pre_prompt + "\n"
|
||||
assert prompt_template["custom_variable_keys"] == ["name"]
|
||||
assert prompt_template["special_variable_keys"] == []
|
||||
|
||||
|
||||
def test__get_chat_model_prompt_messages():
|
||||
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-4'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-4"
|
||||
|
||||
memory_mock = MagicMock(spec=TokenBufferMemory)
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Hello")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
|
||||
memory_mock.get_history_prompt_messages.return_value = history_prompt_messages
|
||||
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
|
||||
pre_prompt = "You are a helpful assistant {{name}}."
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
context = "yes or no."
|
||||
query = "How are you?"
|
||||
prompt_messages, _ = prompt_transform._get_chat_model_prompt_messages(
|
||||
@@ -167,7 +169,7 @@ def test__get_chat_model_prompt_messages():
|
||||
files=[],
|
||||
context=context,
|
||||
memory=memory_mock,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
prompt_template = prompt_transform.get_prompt_template(
|
||||
@@ -180,8 +182,8 @@ def test__get_chat_model_prompt_messages():
|
||||
with_memory_prompt=False,
|
||||
)
|
||||
|
||||
full_inputs = {**inputs, '#context#': context}
|
||||
real_system_prompt = prompt_template['prompt_template'].format(full_inputs)
|
||||
full_inputs = {**inputs, "#context#": context}
|
||||
real_system_prompt = prompt_template["prompt_template"].format(full_inputs)
|
||||
|
||||
assert len(prompt_messages) == 4
|
||||
assert prompt_messages[0].content == real_system_prompt
|
||||
@@ -192,26 +194,18 @@ def test__get_chat_model_prompt_messages():
|
||||
|
||||
def test__get_completion_model_prompt_messages():
|
||||
model_config_mock = MagicMock(spec=ModelConfigWithCredentialsEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.model = 'gpt-3.5-turbo-instruct'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = "gpt-3.5-turbo-instruct"
|
||||
|
||||
memory = TokenBufferMemory(
|
||||
conversation=Conversation(),
|
||||
model_instance=model_config_mock
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=Conversation(), model_instance=model_config_mock)
|
||||
|
||||
history_prompt_messages = [
|
||||
UserPromptMessage(content="Hi"),
|
||||
AssistantPromptMessage(content="Hello")
|
||||
]
|
||||
history_prompt_messages = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
|
||||
memory.get_history_prompt_messages = MagicMock(return_value=history_prompt_messages)
|
||||
|
||||
prompt_transform = SimplePromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
pre_prompt = "You are a helpful assistant {{name}}."
|
||||
inputs = {
|
||||
"name": "John"
|
||||
}
|
||||
inputs = {"name": "John"}
|
||||
context = "yes or no."
|
||||
query = "How are you?"
|
||||
prompt_messages, stops = prompt_transform._get_completion_model_prompt_messages(
|
||||
@@ -222,7 +216,7 @@ def test__get_completion_model_prompt_messages():
|
||||
files=[],
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config_mock
|
||||
model_config=model_config_mock,
|
||||
)
|
||||
|
||||
prompt_template = prompt_transform.get_prompt_template(
|
||||
@@ -235,14 +229,19 @@ def test__get_completion_model_prompt_messages():
|
||||
with_memory_prompt=True,
|
||||
)
|
||||
|
||||
prompt_rules = prompt_template['prompt_rules']
|
||||
full_inputs = {**inputs, '#context#': context, '#query#': query, '#histories#': memory.get_history_prompt_text(
|
||||
max_token_limit=2000,
|
||||
human_prefix=prompt_rules.get("human_prefix", "Human"),
|
||||
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant")
|
||||
)}
|
||||
real_prompt = prompt_template['prompt_template'].format(full_inputs)
|
||||
prompt_rules = prompt_template["prompt_rules"]
|
||||
full_inputs = {
|
||||
**inputs,
|
||||
"#context#": context,
|
||||
"#query#": query,
|
||||
"#histories#": memory.get_history_prompt_text(
|
||||
max_token_limit=2000,
|
||||
human_prefix=prompt_rules.get("human_prefix", "Human"),
|
||||
ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
|
||||
),
|
||||
}
|
||||
real_prompt = prompt_template["prompt_template"].format(full_inputs)
|
||||
|
||||
assert len(prompt_messages) == 1
|
||||
assert stops == prompt_rules.get('stops')
|
||||
assert stops == prompt_rules.get("stops")
|
||||
assert prompt_messages[0].content == real_prompt
|
||||
|
||||
@@ -5,20 +5,15 @@ from core.rag.datasource.vdb.milvus.milvus_vector import MilvusConfig
|
||||
|
||||
|
||||
def test_default_value():
|
||||
valid_config = {
|
||||
'host': 'localhost',
|
||||
'port': 19530,
|
||||
'user': 'root',
|
||||
'password': 'Milvus'
|
||||
}
|
||||
valid_config = {"host": "localhost", "port": 19530, "user": "root", "password": "Milvus"}
|
||||
|
||||
for key in valid_config:
|
||||
config = valid_config.copy()
|
||||
del config[key]
|
||||
with pytest.raises(ValidationError) as e:
|
||||
MilvusConfig(**config)
|
||||
assert e.value.errors()[0]['msg'] == f'Value error, config MILVUS_{key.upper()} is required'
|
||||
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
|
||||
|
||||
config = MilvusConfig(**valid_config)
|
||||
assert config.secure is False
|
||||
assert config.database == 'default'
|
||||
assert config.database == "default"
|
||||
|
||||
@@ -9,19 +9,17 @@ from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_resp
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
url = "https://firecrawl.dev"
|
||||
api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
|
||||
base_url = 'https://api.firecrawl.dev'
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=base_url)
|
||||
api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-"
|
||||
base_url = "https://api.firecrawl.dev"
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url)
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"crawlerOptions": {
|
||||
"includes": [],
|
||||
"excludes": [],
|
||||
"generateImgAltText": True,
|
||||
"maxDepth": 1,
|
||||
"limit": 1,
|
||||
'returnOnlyUrls': False,
|
||||
|
||||
"returnOnlyUrls": False,
|
||||
}
|
||||
}
|
||||
mocked_firecrawl = {
|
||||
|
||||
@@ -8,11 +8,8 @@ page_id = "page1"
|
||||
|
||||
|
||||
extractor = notion_extractor.NotionExtractor(
|
||||
notion_workspace_id='x',
|
||||
notion_obj_id='x',
|
||||
notion_page_type='page',
|
||||
tenant_id='x',
|
||||
notion_access_token='x')
|
||||
notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x"
|
||||
)
|
||||
|
||||
|
||||
def _generate_page(page_title: str):
|
||||
@@ -21,16 +18,10 @@ def _generate_page(page_title: str):
|
||||
"id": page_id,
|
||||
"properties": {
|
||||
"Page": {
|
||||
"type": "title",
|
||||
"title": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": page_title},
|
||||
"plain_text": page_title
|
||||
}
|
||||
]
|
||||
"type": "title",
|
||||
"title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}],
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -38,10 +29,7 @@ def _generate_block(block_id: str, block_type: str, block_text: str):
|
||||
return {
|
||||
"object": "block",
|
||||
"id": block_id,
|
||||
"parent": {
|
||||
"type": "page_id",
|
||||
"page_id": page_id
|
||||
},
|
||||
"parent": {"type": "page_id", "page_id": page_id},
|
||||
"type": block_type,
|
||||
"has_children": False,
|
||||
block_type: {
|
||||
@@ -49,10 +37,11 @@ def _generate_block(block_id: str, block_type: str, block_text: str):
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": block_text},
|
||||
"plain_text": block_text,
|
||||
}]
|
||||
}
|
||||
}
|
||||
"plain_text": block_text,
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _mock_response(data):
|
||||
@@ -63,7 +52,7 @@ def _mock_response(data):
|
||||
|
||||
|
||||
def _remove_multiple_new_lines(text):
|
||||
while '\n\n' in text:
|
||||
while "\n\n" in text:
|
||||
text = text.replace("\n\n", "\n")
|
||||
return text.strip()
|
||||
|
||||
@@ -71,21 +60,21 @@ def _remove_multiple_new_lines(text):
|
||||
def test_notion_page(mocker):
|
||||
texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"]
|
||||
mocked_notion_page = {
|
||||
"object": "list",
|
||||
"results": [
|
||||
_generate_block("b1", "heading_1", texts[0]),
|
||||
_generate_block("b2", "heading_2", texts[1]),
|
||||
_generate_block("b3", "paragraph", texts[2]),
|
||||
_generate_block("b4", "heading_3", texts[3])
|
||||
],
|
||||
"next_cursor": None
|
||||
"object": "list",
|
||||
"results": [
|
||||
_generate_block("b1", "heading_1", texts[0]),
|
||||
_generate_block("b2", "heading_2", texts[1]),
|
||||
_generate_block("b3", "paragraph", texts[2]),
|
||||
_generate_block("b4", "heading_3", texts[3]),
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("requests.request", return_value=_mock_response(mocked_notion_page))
|
||||
|
||||
page_docs = extractor._load_data_as_documents(page_id, "page")
|
||||
assert len(page_docs) == 1
|
||||
content = _remove_multiple_new_lines(page_docs[0].page_content)
|
||||
assert content == '# Head 1\n## 1.1\nparagraph 1\n### 1.1.1'
|
||||
assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1"
|
||||
|
||||
|
||||
def test_notion_database(mocker):
|
||||
@@ -93,10 +82,10 @@ def test_notion_database(mocker):
|
||||
mocked_notion_database = {
|
||||
"object": "list",
|
||||
"results": [_generate_page(i) for i in page_title_list],
|
||||
"next_cursor": None
|
||||
"next_cursor": None,
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_notion_database))
|
||||
database_docs = extractor._load_data_as_documents(database_id, "database")
|
||||
assert len(database_docs) == 1
|
||||
content = _remove_multiple_new_lines(database_docs[0].page_content)
|
||||
assert content == '\n'.join([f'Page:{i}' for i in page_title_list])
|
||||
assert content == "\n".join([f"Page:{i}" for i in page_title_list])
|
||||
|
||||
@@ -10,36 +10,24 @@ from core.model_runtime.entities.model_entities import ModelType
|
||||
@pytest.fixture
|
||||
def lb_model_manager():
|
||||
load_balancing_configs = [
|
||||
ModelLoadBalancingConfiguration(
|
||||
id='id1',
|
||||
name='__inherit__',
|
||||
credentials={}
|
||||
),
|
||||
ModelLoadBalancingConfiguration(
|
||||
id='id2',
|
||||
name='first',
|
||||
credentials={"openai_api_key": "fake_key"}
|
||||
),
|
||||
ModelLoadBalancingConfiguration(
|
||||
id='id3',
|
||||
name='second',
|
||||
credentials={"openai_api_key": "fake_key"}
|
||||
)
|
||||
ModelLoadBalancingConfiguration(id="id1", name="__inherit__", credentials={}),
|
||||
ModelLoadBalancingConfiguration(id="id2", name="first", credentials={"openai_api_key": "fake_key"}),
|
||||
ModelLoadBalancingConfiguration(id="id3", name="second", credentials={"openai_api_key": "fake_key"}),
|
||||
]
|
||||
|
||||
lb_model_manager = LBModelManager(
|
||||
tenant_id='tenant_id',
|
||||
provider='openai',
|
||||
tenant_id="tenant_id",
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model='gpt-4',
|
||||
model="gpt-4",
|
||||
load_balancing_configs=load_balancing_configs,
|
||||
managed_credentials={"openai_api_key": "fake_key"}
|
||||
managed_credentials={"openai_api_key": "fake_key"},
|
||||
)
|
||||
|
||||
lb_model_manager.cooldown = MagicMock(return_value=None)
|
||||
|
||||
def is_cooldown(config: ModelLoadBalancingConfiguration):
|
||||
if config.id == 'id1':
|
||||
if config.id == "id1":
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -61,14 +49,15 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
|
||||
assert lb_model_manager.in_cooldown(config3) is False
|
||||
|
||||
start_index = 0
|
||||
|
||||
def incr(key):
|
||||
nonlocal start_index
|
||||
start_index += 1
|
||||
return start_index
|
||||
|
||||
mocker.patch('redis.Redis.incr', side_effect=incr)
|
||||
mocker.patch('redis.Redis.set', return_value=None)
|
||||
mocker.patch('redis.Redis.expire', return_value=None)
|
||||
mocker.patch("redis.Redis.incr", side_effect=incr)
|
||||
mocker.patch("redis.Redis.set", return_value=None)
|
||||
mocker.patch("redis.Redis.expire", return_value=None)
|
||||
|
||||
config = lb_model_manager.fetch_next()
|
||||
assert config == config2
|
||||
|
||||
@@ -11,62 +11,62 @@ def test__to_model_settings(mocker):
|
||||
|
||||
provider_entity = None
|
||||
for provider in provider_entities:
|
||||
if provider.provider == 'openai':
|
||||
if provider.provider == "openai":
|
||||
provider_entity = provider
|
||||
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [ProviderModelSetting(
|
||||
id='id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
enabled=True,
|
||||
load_balancing_enabled=True
|
||||
)]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id='id1',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
name='__inherit__',
|
||||
encrypted_config=None,
|
||||
enabled=True
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id='id2',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
name='first',
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity,
|
||||
provider_model_settings,
|
||||
load_balancing_model_configs
|
||||
)
|
||||
result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == 'gpt-4'
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 2
|
||||
assert result[0].load_balancing_configs[0].name == '__inherit__'
|
||||
assert result[0].load_balancing_configs[1].name == 'first'
|
||||
assert result[0].load_balancing_configs[0].name == "__inherit__"
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker):
|
||||
@@ -75,47 +75,47 @@ def test__to_model_settings_only_one_lb(mocker):
|
||||
|
||||
provider_entity = None
|
||||
for provider in provider_entities:
|
||||
if provider.provider == 'openai':
|
||||
if provider.provider == "openai":
|
||||
provider_entity = provider
|
||||
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [ProviderModelSetting(
|
||||
id='id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
enabled=True,
|
||||
load_balancing_enabled=True
|
||||
)]
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=True,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id='id1',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
name='__inherit__',
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True
|
||||
enabled=True,
|
||||
)
|
||||
]
|
||||
|
||||
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity,
|
||||
provider_model_settings,
|
||||
load_balancing_model_configs
|
||||
)
|
||||
result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == 'gpt-4'
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
@@ -127,57 +127,57 @@ def test__to_model_settings_lb_disabled(mocker):
|
||||
|
||||
provider_entity = None
|
||||
for provider in provider_entities:
|
||||
if provider.provider == 'openai':
|
||||
if provider.provider == "openai":
|
||||
provider_entity = provider
|
||||
|
||||
# Mocking the inputs
|
||||
provider_model_settings = [ProviderModelSetting(
|
||||
id='id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
enabled=True,
|
||||
load_balancing_enabled=False
|
||||
)]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id='id1',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
name='__inherit__',
|
||||
encrypted_config=None,
|
||||
enabled=True
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id='id2',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openai',
|
||||
model_name='gpt-4',
|
||||
model_type='text-generation',
|
||||
name='first',
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True
|
||||
provider_model_settings = [
|
||||
ProviderModelSetting(
|
||||
id="id",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
]
|
||||
load_balancing_model_configs = [
|
||||
LoadBalancingModelConfig(
|
||||
id="id1",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="__inherit__",
|
||||
encrypted_config=None,
|
||||
enabled=True,
|
||||
),
|
||||
LoadBalancingModelConfig(
|
||||
id="id2",
|
||||
tenant_id="tenant_id",
|
||||
provider_name="openai",
|
||||
model_name="gpt-4",
|
||||
model_type="text-generation",
|
||||
name="first",
|
||||
encrypted_config='{"openai_api_key": "fake_key"}',
|
||||
enabled=True,
|
||||
),
|
||||
]
|
||||
|
||||
mocker.patch('core.helper.model_provider_cache.ProviderCredentialsCache.get', return_value={"openai_api_key": "fake_key"})
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity,
|
||||
provider_model_settings,
|
||||
load_balancing_model_configs
|
||||
)
|
||||
result = provider_manager._to_model_settings(provider_entity, provider_model_settings, load_balancing_model_configs)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], ModelSettings)
|
||||
assert result[0].model == 'gpt-4'
|
||||
assert result[0].model == "gpt-4"
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
@@ -5,52 +5,52 @@ from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
|
||||
|
||||
def test_get_parameter_type():
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == 'string'
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == 'string'
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == 'boolean'
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == 'number'
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.STRING) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.SELECT) == "string"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.BOOLEAN) == "boolean"
|
||||
assert ToolParameterConverter.get_parameter_type(ToolParameter.ToolParameterType.NUMBER) == "number"
|
||||
with pytest.raises(ValueError):
|
||||
ToolParameterConverter.get_parameter_type('unsupported_type')
|
||||
ToolParameterConverter.get_parameter_type("unsupported_type")
|
||||
|
||||
|
||||
def test_cast_parameter_by_type():
|
||||
# string
|
||||
assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.STRING) == 'test'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == '1'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == '1.0'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ''
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.STRING) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.STRING) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.STRING) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.STRING) == ""
|
||||
|
||||
# secret input
|
||||
assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SECRET_INPUT) == 'test'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == '1'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == '1.0'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ''
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SECRET_INPUT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SECRET_INPUT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SECRET_INPUT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SECRET_INPUT) == ""
|
||||
|
||||
# select
|
||||
assert ToolParameterConverter.cast_parameter_by_type('test', ToolParameter.ToolParameterType.SELECT) == 'test'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == '1'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == '1.0'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ''
|
||||
assert ToolParameterConverter.cast_parameter_by_type("test", ToolParameter.ToolParameterType.SELECT) == "test"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.SELECT) == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.SELECT) == "1.0"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.SELECT) == ""
|
||||
|
||||
# boolean
|
||||
true_values = [True, 'True', 'true', '1', 'YES', 'Yes', 'yes', 'y', 'something']
|
||||
true_values = [True, "True", "true", "1", "YES", "Yes", "yes", "y", "something"]
|
||||
for value in true_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is True
|
||||
|
||||
false_values = [False, 'False', 'false', '0', 'NO', 'No', 'no', 'n', None, '']
|
||||
false_values = [False, "False", "false", "0", "NO", "No", "no", "n", None, ""]
|
||||
for value in false_values:
|
||||
assert ToolParameterConverter.cast_parameter_by_type(value, ToolParameter.ToolParameterType.BOOLEAN) is False
|
||||
|
||||
# number
|
||||
assert ToolParameterConverter.cast_parameter_by_type('1', ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type('1.0', ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type('-1.0', ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1.0", ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type("-1.0", ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, ToolParameter.ToolParameterType.NUMBER) == 1
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1.0, ToolParameter.ToolParameterType.NUMBER) == 1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(-1.0, ToolParameter.ToolParameterType.NUMBER) == -1.0
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
|
||||
# unknown
|
||||
assert ToolParameterConverter.cast_parameter_by_type('1', 'unknown_type') == '1'
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, 'unknown_type') == '1'
|
||||
assert ToolParameterConverter.cast_parameter_by_type("1", "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(1, "unknown_type") == "1"
|
||||
assert ToolParameterConverter.cast_parameter_by_type(None, ToolParameter.ToolParameterType.NUMBER) is None
|
||||
|
||||
@@ -11,29 +11,30 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
def test_execute_answer():
|
||||
node = AnswerNode(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_id='1',
|
||||
user_id='1',
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'answer',
|
||||
'data': {
|
||||
'title': '123',
|
||||
'type': 'answer',
|
||||
'answer': 'Today\'s weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.'
|
||||
}
|
||||
}
|
||||
"id": "answer",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "answer",
|
||||
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'weather'], 'sunny')
|
||||
pool.add(['llm', 'text'], 'You are a helpful AI.')
|
||||
pool = VariablePool(
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["start", "weather"], "sunny")
|
||||
pool.add(["llm", "text"], "You are a helpful AI.")
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
@@ -42,4 +43,4 @@ def test_execute_answer():
|
||||
result = node._run(pool)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs['answer'] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
|
||||
assert result.outputs["answer"] == "Today's weather is sunny\nYou are a helpful AI.\n{{img}}\nFin."
|
||||
|
||||
@@ -11,134 +11,81 @@ from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
def test_execute_if_else_result_true():
|
||||
node = IfElseNode(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_id='1',
|
||||
user_id='1',
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'if-else',
|
||||
'data': {
|
||||
'title': '123',
|
||||
'type': 'if-else',
|
||||
'logical_operator': 'and',
|
||||
'conditions': [
|
||||
"id": "if-else",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "if-else",
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
'comparison_operator': 'contains',
|
||||
'variable_selector': ['start', 'array_contains'],
|
||||
'value': 'ab'
|
||||
"comparison_operator": "contains",
|
||||
"variable_selector": ["start", "array_contains"],
|
||||
"value": "ab",
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'not contains',
|
||||
'variable_selector': ['start', 'array_not_contains'],
|
||||
'value': 'ab'
|
||||
"comparison_operator": "not contains",
|
||||
"variable_selector": ["start", "array_not_contains"],
|
||||
"value": "ab",
|
||||
},
|
||||
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
|
||||
{
|
||||
'comparison_operator': 'contains',
|
||||
'variable_selector': ['start', 'contains'],
|
||||
'value': 'ab'
|
||||
"comparison_operator": "not contains",
|
||||
"variable_selector": ["start", "not_contains"],
|
||||
"value": "ab",
|
||||
},
|
||||
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
|
||||
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
|
||||
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
|
||||
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
|
||||
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
|
||||
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
|
||||
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
|
||||
{"comparison_operator": "≠", "variable_selector": ["start", "not_equals"], "value": "22"},
|
||||
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
|
||||
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
|
||||
{
|
||||
'comparison_operator': 'not contains',
|
||||
'variable_selector': ['start', 'not_contains'],
|
||||
'value': 'ab'
|
||||
"comparison_operator": "≥",
|
||||
"variable_selector": ["start", "greater_than_or_equal"],
|
||||
"value": "22",
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'start with',
|
||||
'variable_selector': ['start', 'start_with'],
|
||||
'value': 'ab'
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'end with',
|
||||
'variable_selector': ['start', 'end_with'],
|
||||
'value': 'ab'
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'is',
|
||||
'variable_selector': ['start', 'is'],
|
||||
'value': 'ab'
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'is not',
|
||||
'variable_selector': ['start', 'is_not'],
|
||||
'value': 'ab'
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'empty',
|
||||
'variable_selector': ['start', 'empty'],
|
||||
'value': 'ab'
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'not empty',
|
||||
'variable_selector': ['start', 'not_empty'],
|
||||
'value': 'ab'
|
||||
},
|
||||
{
|
||||
'comparison_operator': '=',
|
||||
'variable_selector': ['start', 'equals'],
|
||||
'value': '22'
|
||||
},
|
||||
{
|
||||
'comparison_operator': '≠',
|
||||
'variable_selector': ['start', 'not_equals'],
|
||||
'value': '22'
|
||||
},
|
||||
{
|
||||
'comparison_operator': '>',
|
||||
'variable_selector': ['start', 'greater_than'],
|
||||
'value': '22'
|
||||
},
|
||||
{
|
||||
'comparison_operator': '<',
|
||||
'variable_selector': ['start', 'less_than'],
|
||||
'value': '22'
|
||||
},
|
||||
{
|
||||
'comparison_operator': '≥',
|
||||
'variable_selector': ['start', 'greater_than_or_equal'],
|
||||
'value': '22'
|
||||
},
|
||||
{
|
||||
'comparison_operator': '≤',
|
||||
'variable_selector': ['start', 'less_than_or_equal'],
|
||||
'value': '22'
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'null',
|
||||
'variable_selector': ['start', 'null']
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'not null',
|
||||
'variable_selector': ['start', 'not_null']
|
||||
},
|
||||
]
|
||||
}
|
||||
}
|
||||
{"comparison_operator": "≤", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
|
||||
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
|
||||
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||
pool.add(['start', 'contains'], 'cabcde')
|
||||
pool.add(['start', 'not_contains'], 'zacde')
|
||||
pool.add(['start', 'start_with'], 'abc')
|
||||
pool.add(['start', 'end_with'], 'zzab')
|
||||
pool.add(['start', 'is'], 'ab')
|
||||
pool.add(['start', 'is_not'], 'aab')
|
||||
pool.add(['start', 'empty'], '')
|
||||
pool.add(['start', 'not_empty'], 'aaa')
|
||||
pool.add(['start', 'equals'], 22)
|
||||
pool.add(['start', 'not_equals'], 23)
|
||||
pool.add(['start', 'greater_than'], 23)
|
||||
pool.add(['start', 'less_than'], 21)
|
||||
pool.add(['start', 'greater_than_or_equal'], 22)
|
||||
pool.add(['start', 'less_than_or_equal'], 21)
|
||||
pool.add(['start', 'not_null'], '1212')
|
||||
pool = VariablePool(
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["start", "array_contains"], ["ab", "def"])
|
||||
pool.add(["start", "array_not_contains"], ["ac", "def"])
|
||||
pool.add(["start", "contains"], "cabcde")
|
||||
pool.add(["start", "not_contains"], "zacde")
|
||||
pool.add(["start", "start_with"], "abc")
|
||||
pool.add(["start", "end_with"], "zzab")
|
||||
pool.add(["start", "is"], "ab")
|
||||
pool.add(["start", "is_not"], "aab")
|
||||
pool.add(["start", "empty"], "")
|
||||
pool.add(["start", "not_empty"], "aaa")
|
||||
pool.add(["start", "equals"], 22)
|
||||
pool.add(["start", "not_equals"], 23)
|
||||
pool.add(["start", "greater_than"], 23)
|
||||
pool.add(["start", "less_than"], 21)
|
||||
pool.add(["start", "greater_than_or_equal"], 22)
|
||||
pool.add(["start", "less_than_or_equal"], 21)
|
||||
pool.add(["start", "not_null"], "1212")
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
@@ -147,46 +94,47 @@ def test_execute_if_else_result_true():
|
||||
result = node._run(pool)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs['result'] is True
|
||||
assert result.outputs["result"] is True
|
||||
|
||||
|
||||
def test_execute_if_else_result_false():
|
||||
node = IfElseNode(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_id='1',
|
||||
user_id='1',
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'if-else',
|
||||
'data': {
|
||||
'title': '123',
|
||||
'type': 'if-else',
|
||||
'logical_operator': 'or',
|
||||
'conditions': [
|
||||
"id": "if-else",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"type": "if-else",
|
||||
"logical_operator": "or",
|
||||
"conditions": [
|
||||
{
|
||||
'comparison_operator': 'contains',
|
||||
'variable_selector': ['start', 'array_contains'],
|
||||
'value': 'ab'
|
||||
"comparison_operator": "contains",
|
||||
"variable_selector": ["start", "array_contains"],
|
||||
"value": "ab",
|
||||
},
|
||||
{
|
||||
'comparison_operator': 'not contains',
|
||||
'variable_selector': ['start', 'array_not_contains'],
|
||||
'value': 'ab'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
"comparison_operator": "not contains",
|
||||
"variable_selector": ["start", "array_not_contains"],
|
||||
"value": "ab",
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
pool = VariablePool(
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["start", "array_contains"], ["1ab", "def"])
|
||||
pool.add(["start", "array_not_contains"], ["ab", "def"])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
@@ -195,4 +143,4 @@ def test_execute_if_else_result_false():
|
||||
result = node._run(pool)
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs['result'] is False
|
||||
assert result.outputs["result"] is False
|
||||
|
||||
@@ -8,41 +8,41 @@ from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode
|
||||
|
||||
DEFAULT_NODE_ID = 'node_id'
|
||||
DEFAULT_NODE_ID = "node_id"
|
||||
|
||||
|
||||
def test_overwrite_string_variable():
|
||||
conversation_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value='the first value',
|
||||
name="test_conversation_variable",
|
||||
value="the first value",
|
||||
)
|
||||
|
||||
input_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_string_variable',
|
||||
value='the second value',
|
||||
name="test_string_variable",
|
||||
value="the second value",
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
tenant_id="tenant_id",
|
||||
app_id="app_id",
|
||||
workflow_id="workflow_id",
|
||||
user_id="user_id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.OVER_WRITE.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.OVER_WRITE.value,
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@@ -52,48 +52,48 @@ def test_overwrite_string_variable():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.value == 'the second value'
|
||||
assert got.to_object() == 'the second value'
|
||||
assert got.value == "the second value"
|
||||
assert got.to_object() == "the second value"
|
||||
|
||||
|
||||
def test_append_variable_to_array():
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
name="test_conversation_variable",
|
||||
value=["the first value"],
|
||||
)
|
||||
|
||||
input_variable = StringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_string_variable',
|
||||
value='the second value',
|
||||
name="test_string_variable",
|
||||
value="the second value",
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
tenant_id="tenant_id",
|
||||
app_id="app_id",
|
||||
workflow_id="workflow_id",
|
||||
user_id="user_id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.APPEND.value,
|
||||
'input_variable_selector': [DEFAULT_NODE_ID, input_variable.name],
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.APPEND.value,
|
||||
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@@ -103,41 +103,41 @@ def test_append_variable_to_array():
|
||||
input_variable,
|
||||
)
|
||||
|
||||
with mock.patch('core.workflow.nodes.variable_assigner.node.update_conversation_variable') as mock_run:
|
||||
with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run:
|
||||
node.run(variable_pool)
|
||||
mock_run.assert_called_once()
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == ['the first value', 'the second value']
|
||||
assert got.to_object() == ["the first value", "the second value"]
|
||||
|
||||
|
||||
def test_clear_array():
|
||||
conversation_variable = ArrayStringVariable(
|
||||
id=str(uuid4()),
|
||||
name='test_conversation_variable',
|
||||
value=['the first value'],
|
||||
name="test_conversation_variable",
|
||||
value=["the first value"],
|
||||
)
|
||||
|
||||
node = VariableAssignerNode(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
workflow_id='workflow_id',
|
||||
user_id='user_id',
|
||||
tenant_id="tenant_id",
|
||||
app_id="app_id",
|
||||
workflow_id="workflow_id",
|
||||
user_id="user_id",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config={
|
||||
'id': 'node_id',
|
||||
'data': {
|
||||
'assigned_variable_selector': ['conversation', conversation_variable.name],
|
||||
'write_mode': WriteMode.CLEAR.value,
|
||||
'input_variable_selector': [],
|
||||
"id": "node_id",
|
||||
"data": {
|
||||
"assigned_variable_selector": ["conversation", conversation_variable.name],
|
||||
"write_mode": WriteMode.CLEAR.value,
|
||||
"input_variable_selector": [],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: 'conversation_id'},
|
||||
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[conversation_variable],
|
||||
@@ -145,6 +145,6 @@ def test_clear_array():
|
||||
|
||||
node.run(variable_pool)
|
||||
|
||||
got = variable_pool.get(['conversation', conversation_variable.name])
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
assert got is not None
|
||||
assert got.to_object() == []
|
||||
|
||||
@@ -3,50 +3,46 @@ import pandas as pd
|
||||
|
||||
def test_pandas_csv(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data = {'col1': [1, 2.2, -3.3, 4.0, 5],
|
||||
'col2': ['A', 'B', 'C', 'D', 'E']}
|
||||
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data)
|
||||
|
||||
# write to csv file
|
||||
csv_file_path = tmp_path.joinpath('example.csv')
|
||||
csv_file_path = tmp_path.joinpath("example.csv")
|
||||
df1.to_csv(csv_file_path, index=False)
|
||||
|
||||
# read from csv file
|
||||
df2 = pd.read_csv(csv_file_path, on_bad_lines='skip')
|
||||
assert df2[df2.columns[0]].to_list() == data['col1']
|
||||
assert df2[df2.columns[1]].to_list() == data['col2']
|
||||
df2 = pd.read_csv(csv_file_path, on_bad_lines="skip")
|
||||
assert df2[df2.columns[0]].to_list() == data["col1"]
|
||||
assert df2[df2.columns[1]].to_list() == data["col2"]
|
||||
|
||||
|
||||
def test_pandas_xlsx(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data = {'col1': [1, 2.2, -3.3, 4.0, 5],
|
||||
'col2': ['A', 'B', 'C', 'D', 'E']}
|
||||
data = {"col1": [1, 2.2, -3.3, 4.0, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data)
|
||||
|
||||
# write to xlsx file
|
||||
xlsx_file_path = tmp_path.joinpath('example.xlsx')
|
||||
xlsx_file_path = tmp_path.joinpath("example.xlsx")
|
||||
df1.to_excel(xlsx_file_path, index=False)
|
||||
|
||||
# read from xlsx file
|
||||
df2 = pd.read_excel(xlsx_file_path)
|
||||
assert df2[df2.columns[0]].to_list() == data['col1']
|
||||
assert df2[df2.columns[1]].to_list() == data['col2']
|
||||
assert df2[df2.columns[0]].to_list() == data["col1"]
|
||||
assert df2[df2.columns[1]].to_list() == data["col2"]
|
||||
|
||||
|
||||
def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
|
||||
monkeypatch.chdir(tmp_path)
|
||||
data1 = {'col1': [1, 2, 3, 4, 5],
|
||||
'col2': ['A', 'B', 'C', 'D', 'E']}
|
||||
data1 = {"col1": [1, 2, 3, 4, 5], "col2": ["A", "B", "C", "D", "E"]}
|
||||
df1 = pd.DataFrame(data1)
|
||||
|
||||
data2 = {'col1': [6, 7, 8, 9, 10],
|
||||
'col2': ['F', 'G', 'H', 'I', 'J']}
|
||||
data2 = {"col1": [6, 7, 8, 9, 10], "col2": ["F", "G", "H", "I", "J"]}
|
||||
df2 = pd.DataFrame(data2)
|
||||
|
||||
# write to xlsx file with sheets
|
||||
xlsx_file_path = tmp_path.joinpath('example_with_sheets.xlsx')
|
||||
sheet1 = 'Sheet1'
|
||||
sheet2 = 'Sheet2'
|
||||
xlsx_file_path = tmp_path.joinpath("example_with_sheets.xlsx")
|
||||
sheet1 = "Sheet1"
|
||||
sheet2 = "Sheet2"
|
||||
with pd.ExcelWriter(xlsx_file_path) as excel_writer:
|
||||
df1.to_excel(excel_writer, sheet_name=sheet1, index=False)
|
||||
df2.to_excel(excel_writer, sheet_name=sheet2, index=False)
|
||||
@@ -54,9 +50,9 @@ def test_pandas_xlsx_with_sheets(tmp_path, monkeypatch):
|
||||
# read from xlsx file with sheets
|
||||
with pd.ExcelFile(xlsx_file_path) as excel_file:
|
||||
df1 = pd.read_excel(excel_file, sheet_name=sheet1)
|
||||
assert df1[df1.columns[0]].to_list() == data1['col1']
|
||||
assert df1[df1.columns[1]].to_list() == data1['col2']
|
||||
assert df1[df1.columns[0]].to_list() == data1["col1"]
|
||||
assert df1[df1.columns[1]].to_list() == data1["col2"]
|
||||
|
||||
df2 = pd.read_excel(excel_file, sheet_name=sheet2)
|
||||
assert df2[df2.columns[0]].to_list() == data2['col1']
|
||||
assert df2[df2.columns[1]].to_list() == data2['col2']
|
||||
assert df2[df2.columns[0]].to_list() == data2["col1"]
|
||||
assert df2[df2.columns[1]].to_list() == data2["col2"]
|
||||
|
||||
@@ -15,7 +15,7 @@ def test_gmpy2_pkcs10aep_cipher() -> None:
|
||||
private_rsa_key = RSA.import_key(private_key)
|
||||
private_cipher_rsa = gmpy2_pkcs10aep_cipher.new(private_rsa_key)
|
||||
|
||||
raw_text = 'raw_text'
|
||||
raw_text = "raw_text"
|
||||
raw_text_bytes = raw_text.encode()
|
||||
|
||||
# RSA encryption by public key and decryption by private key
|
||||
|
||||
@@ -3,21 +3,21 @@ from yarl import URL
|
||||
|
||||
|
||||
def test_yarl_urls():
|
||||
expected_1 = 'https://dify.ai/api'
|
||||
assert str(URL('https://dify.ai') / 'api') == expected_1
|
||||
assert str(URL('https://dify.ai/') / 'api') == expected_1
|
||||
expected_1 = "https://dify.ai/api"
|
||||
assert str(URL("https://dify.ai") / "api") == expected_1
|
||||
assert str(URL("https://dify.ai/") / "api") == expected_1
|
||||
|
||||
expected_2 = 'http://dify.ai:12345/api'
|
||||
assert str(URL('http://dify.ai:12345') / 'api') == expected_2
|
||||
assert str(URL('http://dify.ai:12345/') / 'api') == expected_2
|
||||
expected_2 = "http://dify.ai:12345/api"
|
||||
assert str(URL("http://dify.ai:12345") / "api") == expected_2
|
||||
assert str(URL("http://dify.ai:12345/") / "api") == expected_2
|
||||
|
||||
expected_3 = 'https://dify.ai/api/v1'
|
||||
assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
|
||||
assert str(URL('https://dify.ai') / 'api/v1') == expected_3
|
||||
assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
|
||||
assert str(URL('https://dify.ai/api') / 'v1') == expected_3
|
||||
assert str(URL('https://dify.ai/api/') / 'v1') == expected_3
|
||||
expected_3 = "https://dify.ai/api/v1"
|
||||
assert str(URL("https://dify.ai") / "api" / "v1") == expected_3
|
||||
assert str(URL("https://dify.ai") / "api/v1") == expected_3
|
||||
assert str(URL("https://dify.ai/") / "api/v1") == expected_3
|
||||
assert str(URL("https://dify.ai/api") / "v1") == expected_3
|
||||
assert str(URL("https://dify.ai/api/") / "v1") == expected_3
|
||||
|
||||
with pytest.raises(ValueError) as e1:
|
||||
str(URL('https://dify.ai') / '/api')
|
||||
str(URL("https://dify.ai") / "/api")
|
||||
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"
|
||||
|
||||
@@ -2,13 +2,13 @@ from models.account import TenantAccountRole
|
||||
|
||||
|
||||
def test_account_is_privileged_role() -> None:
|
||||
assert TenantAccountRole.ADMIN == 'admin'
|
||||
assert TenantAccountRole.OWNER == 'owner'
|
||||
assert TenantAccountRole.EDITOR == 'editor'
|
||||
assert TenantAccountRole.NORMAL == 'normal'
|
||||
assert TenantAccountRole.ADMIN == "admin"
|
||||
assert TenantAccountRole.OWNER == "owner"
|
||||
assert TenantAccountRole.EDITOR == "editor"
|
||||
assert TenantAccountRole.NORMAL == "normal"
|
||||
|
||||
assert TenantAccountRole.is_privileged_role(TenantAccountRole.ADMIN)
|
||||
assert TenantAccountRole.is_privileged_role(TenantAccountRole.OWNER)
|
||||
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.NORMAL)
|
||||
assert not TenantAccountRole.is_privileged_role(TenantAccountRole.EDITOR)
|
||||
assert not TenantAccountRole.is_privileged_role('')
|
||||
assert not TenantAccountRole.is_privileged_role("")
|
||||
|
||||
@@ -7,19 +7,19 @@ from models import ConversationVariable
|
||||
def test_from_variable_and_to_variable():
|
||||
variable = factory.build_variable_from_mapping(
|
||||
{
|
||||
'id': str(uuid4()),
|
||||
'name': 'name',
|
||||
'value_type': SegmentType.OBJECT,
|
||||
'value': {
|
||||
'key': {
|
||||
'key': 'value',
|
||||
"id": str(uuid4()),
|
||||
"name": "name",
|
||||
"value_type": SegmentType.OBJECT,
|
||||
"value": {
|
||||
"key": {
|
||||
"key": "value",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
conversation_variable = ConversationVariable.from_variable(
|
||||
app_id='app_id', conversation_id='conversation_id', variable=variable
|
||||
app_id="app_id", conversation_id="conversation_id", variable=variable
|
||||
)
|
||||
|
||||
assert conversation_variable.to_variable() == variable
|
||||
|
||||
@@ -8,30 +8,30 @@ from models.workflow import Workflow
|
||||
|
||||
|
||||
def test_environment_variables():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
contexts.tenant_id.set("tenant_id")
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
tenant_id="tenant_id",
|
||||
app_id="app_id",
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
created_by="account_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
|
||||
variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())})
|
||||
variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())})
|
||||
variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())})
|
||||
variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())})
|
||||
variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())})
|
||||
variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())})
|
||||
variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())})
|
||||
|
||||
with (
|
||||
mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
|
||||
mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
|
||||
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
||||
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
||||
):
|
||||
# Set the environment_variables property of the Workflow instance
|
||||
variables = [variable1, variable2, variable3, variable4]
|
||||
@@ -42,30 +42,30 @@ def test_environment_variables():
|
||||
|
||||
|
||||
def test_update_environment_variables():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
contexts.tenant_id.set("tenant_id")
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
tenant_id="tenant_id",
|
||||
app_id="app_id",
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
created_by="account_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
# Create some EnvironmentVariable instances
|
||||
variable1 = StringVariable.model_validate({'name': 'var1', 'value': 'value1', 'id': str(uuid4())})
|
||||
variable2 = IntegerVariable.model_validate({'name': 'var2', 'value': 123, 'id': str(uuid4())})
|
||||
variable3 = SecretVariable.model_validate({'name': 'var3', 'value': 'secret', 'id': str(uuid4())})
|
||||
variable4 = FloatVariable.model_validate({'name': 'var4', 'value': 3.14, 'id': str(uuid4())})
|
||||
variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())})
|
||||
variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())})
|
||||
variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())})
|
||||
variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())})
|
||||
|
||||
with (
|
||||
mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
|
||||
mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
|
||||
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
||||
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
||||
):
|
||||
variables = [variable1, variable2, variable3, variable4]
|
||||
|
||||
@@ -76,28 +76,28 @@ def test_update_environment_variables():
|
||||
# Update the name of variable3 and keep the value as it is
|
||||
variables[2] = variable3.model_copy(
|
||||
update={
|
||||
'name': 'new name',
|
||||
'value': HIDDEN_VALUE,
|
||||
"name": "new name",
|
||||
"value": HIDDEN_VALUE,
|
||||
}
|
||||
)
|
||||
|
||||
workflow.environment_variables = variables
|
||||
assert workflow.environment_variables[2].name == 'new name'
|
||||
assert workflow.environment_variables[2].name == "new name"
|
||||
assert workflow.environment_variables[2].value == variable3.value
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
contexts.tenant_id.set('tenant_id')
|
||||
contexts.tenant_id.set("tenant_id")
|
||||
|
||||
# Create a Workflow instance
|
||||
workflow = Workflow(
|
||||
tenant_id='tenant_id',
|
||||
app_id='app_id',
|
||||
type='workflow',
|
||||
version='draft',
|
||||
graph='{}',
|
||||
features='{}',
|
||||
created_by='account_id',
|
||||
tenant_id="tenant_id",
|
||||
app_id="app_id",
|
||||
type="workflow",
|
||||
version="draft",
|
||||
graph="{}",
|
||||
features="{}",
|
||||
created_by="account_id",
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
@@ -105,19 +105,19 @@ def test_to_dict():
|
||||
# Create some EnvironmentVariable instances
|
||||
|
||||
with (
|
||||
mock.patch('core.helper.encrypter.encrypt_token', return_value='encrypted_token'),
|
||||
mock.patch('core.helper.encrypter.decrypt_token', return_value='secret'),
|
||||
mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"),
|
||||
mock.patch("core.helper.encrypter.decrypt_token", return_value="secret"),
|
||||
):
|
||||
# Set the environment_variables property of the Workflow instance
|
||||
workflow.environment_variables = [
|
||||
SecretVariable.model_validate({'name': 'secret', 'value': 'secret', 'id': str(uuid4())}),
|
||||
StringVariable.model_validate({'name': 'text', 'value': 'text', 'id': str(uuid4())}),
|
||||
SecretVariable.model_validate({"name": "secret", "value": "secret", "id": str(uuid4())}),
|
||||
StringVariable.model_validate({"name": "text", "value": "text", "id": str(uuid4())}),
|
||||
]
|
||||
|
||||
workflow_dict = workflow.to_dict()
|
||||
assert workflow_dict['environment_variables'][0]['value'] == ''
|
||||
assert workflow_dict['environment_variables'][1]['value'] == 'text'
|
||||
assert workflow_dict["environment_variables"][0]["value"] == ""
|
||||
assert workflow_dict["environment_variables"][1]["value"] == "text"
|
||||
|
||||
workflow_dict = workflow.to_dict(include_secret=True)
|
||||
assert workflow_dict['environment_variables'][0]['value'] == 'secret'
|
||||
assert workflow_dict['environment_variables'][1]['value'] == 'text'
|
||||
assert workflow_dict["environment_variables"][0]["value"] == "secret"
|
||||
assert workflow_dict["environment_variables"][1]["value"] == "text"
|
||||
|
||||
@@ -83,18 +83,12 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable",
|
||||
type="api",
|
||||
config={
|
||||
"api_based_extension_id": api_based_extension_id
|
||||
}
|
||||
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
|
||||
)
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=default_variables,
|
||||
external_data_variables=external_data_variables
|
||||
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
@@ -105,10 +99,7 @@ def test__convert_to_http_request_node_for_chatbot(default_variables):
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"] == {
|
||||
"type": "bearer",
|
||||
"api_key": "api_key"
|
||||
}
|
||||
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
|
||||
assert http_request_node["data"]["body"]["type"] == "json"
|
||||
|
||||
body_data = http_request_node["data"]["body"]["data"]
|
||||
@@ -153,18 +144,12 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
|
||||
external_data_variables = [
|
||||
ExternalDataVariableEntity(
|
||||
variable="external_variable",
|
||||
type="api",
|
||||
config={
|
||||
"api_based_extension_id": api_based_extension_id
|
||||
}
|
||||
variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id}
|
||||
)
|
||||
]
|
||||
|
||||
nodes, _ = workflow_converter._convert_to_http_request_node(
|
||||
app_model=app_model,
|
||||
variables=default_variables,
|
||||
external_data_variables=external_data_variables
|
||||
app_model=app_model, variables=default_variables, external_data_variables=external_data_variables
|
||||
)
|
||||
|
||||
assert len(nodes) == 2
|
||||
@@ -175,10 +160,7 @@ def test__convert_to_http_request_node_for_workflow_app(default_variables):
|
||||
assert http_request_node["data"]["method"] == "post"
|
||||
assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint
|
||||
assert http_request_node["data"]["authorization"]["type"] == "api-key"
|
||||
assert http_request_node["data"]["authorization"]["config"] == {
|
||||
"type": "bearer",
|
||||
"api_key": "api_key"
|
||||
}
|
||||
assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"}
|
||||
assert http_request_node["data"]["body"]["type"] == "json"
|
||||
|
||||
body_data = http_request_node["data"]["body"]["data"]
|
||||
@@ -207,37 +189,25 @@ def test__convert_to_knowledge_retrieval_node_for_chatbot():
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={
|
||||
'reranking_provider_name': 'cohere',
|
||||
'reranking_model_name': 'rerank-english-v2.0'
|
||||
},
|
||||
reranking_enabled=True
|
||||
)
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(
|
||||
provider='openai',
|
||||
model='gpt-4',
|
||||
mode='chat',
|
||||
parameters={},
|
||||
stop=[]
|
||||
)
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config
|
||||
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
|
||||
)
|
||||
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["query_variable_selector"] == ["sys", "query"]
|
||||
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
|
||||
assert (node["data"]["retrieval_mode"]
|
||||
== dataset_config.retrieve_config.retrieve_strategy.value)
|
||||
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
|
||||
assert node["data"]["multiple_retrieval_config"] == {
|
||||
"top_k": dataset_config.retrieve_config.top_k,
|
||||
"score_threshold": dataset_config.retrieve_config.score_threshold,
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model,
|
||||
}
|
||||
|
||||
|
||||
@@ -251,37 +221,25 @@ def test__convert_to_knowledge_retrieval_node_for_workflow_app():
|
||||
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE,
|
||||
top_k=5,
|
||||
score_threshold=0.8,
|
||||
reranking_model={
|
||||
'reranking_provider_name': 'cohere',
|
||||
'reranking_model_name': 'rerank-english-v2.0'
|
||||
},
|
||||
reranking_enabled=True
|
||||
)
|
||||
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"},
|
||||
reranking_enabled=True,
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ModelConfigEntity(
|
||||
provider='openai',
|
||||
model='gpt-4',
|
||||
mode='chat',
|
||||
parameters={},
|
||||
stop=[]
|
||||
)
|
||||
model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[])
|
||||
|
||||
node = WorkflowConverter()._convert_to_knowledge_retrieval_node(
|
||||
new_app_mode=new_app_mode,
|
||||
dataset_config=dataset_config,
|
||||
model_config=model_config
|
||||
new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config
|
||||
)
|
||||
|
||||
assert node["data"]["type"] == "knowledge-retrieval"
|
||||
assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable]
|
||||
assert node["data"]["dataset_ids"] == dataset_config.dataset_ids
|
||||
assert (node["data"]["retrieval_mode"]
|
||||
== dataset_config.retrieve_config.retrieve_strategy.value)
|
||||
assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value
|
||||
assert node["data"]["multiple_retrieval_config"] == {
|
||||
"top_k": dataset_config.retrieve_config.top_k,
|
||||
"score_threshold": dataset_config.retrieve_config.score_threshold,
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model
|
||||
"reranking_model": dataset_config.retrieve_config.reranking_model,
|
||||
}
|
||||
|
||||
|
||||
@@ -293,14 +251,12 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [
|
||||
start_node
|
||||
],
|
||||
"edges": [] # no need
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
@@ -308,7 +264,7 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}."
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
@@ -316,17 +272,17 @@ def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables):
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]['name'] == model
|
||||
assert llm_node["data"]['model']["mode"] == model_mode.value
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
for v in default_variables:
|
||||
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
|
||||
assert llm_node["data"]["prompt_template"][0]['text'] == template + '\n'
|
||||
assert llm_node["data"]['context']['enabled'] is False
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n"
|
||||
assert llm_node["data"]["context"]["enabled"] is False
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables):
|
||||
@@ -337,14 +293,12 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [
|
||||
start_node
|
||||
],
|
||||
"edges": [] # no need
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
@@ -352,7 +306,7 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.SIMPLE,
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}."
|
||||
simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.",
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
@@ -360,17 +314,17 @@ def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variab
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]['name'] == model
|
||||
assert llm_node["data"]['model']["mode"] == model_mode.value
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
template = prompt_template.simple_prompt_template
|
||||
for v in default_variables:
|
||||
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
|
||||
assert llm_node["data"]["prompt_template"]['text'] == template + '\n'
|
||||
assert llm_node["data"]['context']['enabled'] is False
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template + "\n"
|
||||
assert llm_node["data"]["context"]["enabled"] is False
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables):
|
||||
@@ -381,14 +335,12 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [
|
||||
start_node
|
||||
],
|
||||
"edges": [] # no need
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
@@ -396,12 +348,16 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
|
||||
|
||||
prompt_template = PromptTemplateEntity(
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(messages=[
|
||||
AdvancedChatMessageEntity(text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM),
|
||||
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
])
|
||||
advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity(
|
||||
messages=[
|
||||
AdvancedChatMessageEntity(
|
||||
text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}",
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
),
|
||||
AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER),
|
||||
AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
@@ -409,18 +365,18 @@ def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables)
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]['name'] == model
|
||||
assert llm_node["data"]['model']["mode"] == model_mode.value
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], list)
|
||||
assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages)
|
||||
template = prompt_template.advanced_chat_prompt_template.messages[0].text
|
||||
for v in default_variables:
|
||||
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
|
||||
assert llm_node["data"]["prompt_template"][0]['text'] == template
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"][0]["text"] == template
|
||||
|
||||
|
||||
def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables):
|
||||
@@ -431,14 +387,12 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
|
||||
workflow_converter = WorkflowConverter()
|
||||
start_node = workflow_converter._convert_to_start_node(default_variables)
|
||||
graph = {
|
||||
"nodes": [
|
||||
start_node
|
||||
],
|
||||
"edges": [] # no need
|
||||
"nodes": [start_node],
|
||||
"edges": [], # no need
|
||||
}
|
||||
|
||||
model_config_mock = MagicMock(spec=ModelConfigEntity)
|
||||
model_config_mock.provider = 'openai'
|
||||
model_config_mock.provider = "openai"
|
||||
model_config_mock.model = model
|
||||
model_config_mock.mode = model_mode.value
|
||||
model_config_mock.parameters = {}
|
||||
@@ -448,12 +402,9 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
|
||||
prompt_type=PromptTemplateEntity.PromptType.ADVANCED,
|
||||
advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity(
|
||||
prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\n"
|
||||
"Human: hi\nAssistant: ",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(
|
||||
user="Human",
|
||||
assistant="Assistant"
|
||||
)
|
||||
)
|
||||
"Human: hi\nAssistant: ",
|
||||
role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"),
|
||||
),
|
||||
)
|
||||
|
||||
llm_node = workflow_converter._convert_to_llm_node(
|
||||
@@ -461,14 +412,14 @@ def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_var
|
||||
new_app_mode=new_app_mode,
|
||||
model_config=model_config_mock,
|
||||
graph=graph,
|
||||
prompt_template=prompt_template
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
assert llm_node["data"]["type"] == "llm"
|
||||
assert llm_node["data"]["model"]['name'] == model
|
||||
assert llm_node["data"]['model']["mode"] == model_mode.value
|
||||
assert llm_node["data"]["model"]["name"] == model
|
||||
assert llm_node["data"]["model"]["mode"] == model_mode.value
|
||||
assert isinstance(llm_node["data"]["prompt_template"], dict)
|
||||
template = prompt_template.advanced_completion_prompt_template.prompt
|
||||
for v in default_variables:
|
||||
template = template.replace('{{' + v.variable + '}}', '{{#start.' + v.variable + '#}}')
|
||||
assert llm_node["data"]["prompt_template"]['text'] == template
|
||||
template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}")
|
||||
assert llm_node["data"]["prompt_template"]["text"] == template
|
||||
|
||||
@@ -8,8 +8,9 @@ from core.helper.position_helper import get_position_map, is_filtered, pin_posit
|
||||
@pytest.fixture
|
||||
def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tmp_path.joinpath("example_positions.yaml").write_text(dedent(
|
||||
"""\
|
||||
tmp_path.joinpath("example_positions.yaml").write_text(
|
||||
dedent(
|
||||
"""\
|
||||
- first
|
||||
- second
|
||||
# - commented
|
||||
@@ -17,57 +18,54 @@ def prepare_example_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
|
||||
- 9999999999999
|
||||
- forth
|
||||
"""))
|
||||
"""
|
||||
)
|
||||
)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_empty_commented_positions_yaml(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(dedent(
|
||||
"""\
|
||||
tmp_path.joinpath("example_positions_all_commented.yaml").write_text(
|
||||
dedent(
|
||||
"""\
|
||||
# - commented1
|
||||
# - commented2
|
||||
-
|
||||
-
|
||||
|
||||
"""))
|
||||
"""
|
||||
)
|
||||
)
|
||||
return str(tmp_path)
|
||||
|
||||
|
||||
def test_position_helper(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml')
|
||||
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
|
||||
assert len(position_map) == 4
|
||||
assert position_map == {
|
||||
'first': 0,
|
||||
'second': 1,
|
||||
'third': 2,
|
||||
'forth': 3,
|
||||
"first": 0,
|
||||
"second": 1,
|
||||
"third": 2,
|
||||
"forth": 3,
|
||||
}
|
||||
|
||||
|
||||
def test_position_helper_with_all_commented(prepare_empty_commented_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_empty_commented_positions_yaml,
|
||||
file_name='example_positions_all_commented.yaml')
|
||||
folder_path=prepare_empty_commented_positions_yaml, file_name="example_positions_all_commented.yaml"
|
||||
)
|
||||
assert position_map == {}
|
||||
|
||||
|
||||
def test_excluded_position_data(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml'
|
||||
)
|
||||
pin_list = ['forth', 'first']
|
||||
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
|
||||
pin_list = ["forth", "first"]
|
||||
include_set = set()
|
||||
exclude_set = {'9999999999999'}
|
||||
exclude_set = {"9999999999999"}
|
||||
|
||||
position_map = pin_position_map(
|
||||
original_position_map=position_map,
|
||||
pin_list=pin_list
|
||||
)
|
||||
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
|
||||
|
||||
data = [
|
||||
"forth",
|
||||
@@ -90,22 +88,16 @@ def test_excluded_position_data(prepare_example_positions_yaml):
|
||||
)
|
||||
|
||||
# assert the result in the correct order
|
||||
assert sorted_data == ['forth', 'first', 'second', 'third', 'extra1', 'extra2']
|
||||
assert sorted_data == ["forth", "first", "second", "third", "extra1", "extra2"]
|
||||
|
||||
|
||||
def test_included_position_data(prepare_example_positions_yaml):
|
||||
position_map = get_position_map(
|
||||
folder_path=prepare_example_positions_yaml,
|
||||
file_name='example_positions.yaml'
|
||||
)
|
||||
pin_list = ['forth', 'first']
|
||||
include_set = {'forth', 'first'}
|
||||
position_map = get_position_map(folder_path=prepare_example_positions_yaml, file_name="example_positions.yaml")
|
||||
pin_list = ["forth", "first"]
|
||||
include_set = {"forth", "first"}
|
||||
exclude_set = {}
|
||||
|
||||
position_map = pin_position_map(
|
||||
original_position_map=position_map,
|
||||
pin_list=pin_list
|
||||
)
|
||||
position_map = pin_position_map(original_position_map=position_map, pin_list=pin_list)
|
||||
|
||||
data = [
|
||||
"forth",
|
||||
@@ -128,4 +120,4 @@ def test_included_position_data(prepare_example_positions_yaml):
|
||||
)
|
||||
|
||||
# assert the result in the correct order
|
||||
assert sorted_data == ['forth', 'first']
|
||||
assert sorted_data == ["forth", "first"]
|
||||
|
||||
@@ -5,17 +5,18 @@ from yaml import YAMLError
|
||||
|
||||
from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
EXAMPLE_YAML_FILE = 'example_yaml.yaml'
|
||||
INVALID_YAML_FILE = 'invalid_yaml.yaml'
|
||||
NON_EXISTING_YAML_FILE = 'non_existing_file.yaml'
|
||||
EXAMPLE_YAML_FILE = "example_yaml.yaml"
|
||||
INVALID_YAML_FILE = "invalid_yaml.yaml"
|
||||
NON_EXISTING_YAML_FILE = "non_existing_file.yaml"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(EXAMPLE_YAML_FILE)
|
||||
file_path.write_text(dedent(
|
||||
"""\
|
||||
file_path.write_text(
|
||||
dedent(
|
||||
"""\
|
||||
address:
|
||||
city: Example City
|
||||
country: Example Country
|
||||
@@ -26,7 +27,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
- Java
|
||||
- C++
|
||||
empty_key:
|
||||
"""))
|
||||
"""
|
||||
)
|
||||
)
|
||||
return str(file_path)
|
||||
|
||||
|
||||
@@ -34,8 +37,9 @@ def prepare_example_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_path = tmp_path.joinpath(INVALID_YAML_FILE)
|
||||
file_path.write_text(dedent(
|
||||
"""\
|
||||
file_path.write_text(
|
||||
dedent(
|
||||
"""\
|
||||
address:
|
||||
city: Example City
|
||||
country: Example Country
|
||||
@@ -45,13 +49,15 @@ def prepare_invalid_yaml_file(tmp_path, monkeypatch) -> str:
|
||||
- Python
|
||||
- Java
|
||||
- C++
|
||||
"""))
|
||||
"""
|
||||
)
|
||||
)
|
||||
return str(file_path)
|
||||
|
||||
|
||||
def test_load_yaml_non_existing_file():
|
||||
assert load_yaml_file(file_path=NON_EXISTING_YAML_FILE) == {}
|
||||
assert load_yaml_file(file_path='') == {}
|
||||
assert load_yaml_file(file_path="") == {}
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_yaml_file(file_path=NON_EXISTING_YAML_FILE, ignore_error=False)
|
||||
@@ -60,12 +66,12 @@ def test_load_yaml_non_existing_file():
|
||||
def test_load_valid_yaml_file(prepare_example_yaml_file):
|
||||
yaml_data = load_yaml_file(file_path=prepare_example_yaml_file)
|
||||
assert len(yaml_data) > 0
|
||||
assert yaml_data['age'] == 30
|
||||
assert yaml_data['gender'] == 'male'
|
||||
assert yaml_data['address']['city'] == 'Example City'
|
||||
assert set(yaml_data['languages']) == {'Python', 'Java', 'C++'}
|
||||
assert yaml_data.get('empty_key') is None
|
||||
assert yaml_data.get('non_existed_key') is None
|
||||
assert yaml_data["age"] == 30
|
||||
assert yaml_data["gender"] == "male"
|
||||
assert yaml_data["address"]["city"] == "Example City"
|
||||
assert set(yaml_data["languages"]) == {"Python", "Java", "C++"}
|
||||
assert yaml_data.get("empty_key") is None
|
||||
assert yaml_data.get("non_existed_key") is None
|
||||
|
||||
|
||||
def test_load_invalid_yaml_file(prepare_invalid_yaml_file):
|
||||
|
||||
Reference in New Issue
Block a user