feat: introduce trigger functionality (#27644)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: Stream <Stream_2@qq.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do> Co-authored-by: Harry <xh001x@hotmail.com> Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com> Co-authored-by: yessenia <yessenia.contact@gmail.com> Co-authored-by: hjlarry <hjlarry@163.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: WTW0313 <twwu@dify.ai> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
|
||||
|
||||
|
||||
def test_should_prepare_user_inputs_defaults_to_true():
    """Without the skip flag, user-input preparation stays enabled."""
    generator_args = {"inputs": {}}

    assert WorkflowAppGenerator()._should_prepare_user_inputs(generator_args)
def test_should_prepare_user_inputs_skips_when_flag_truthy():
    """A truthy skip flag disables user-input preparation."""
    generator_args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True}

    assert not WorkflowAppGenerator()._should_prepare_user_inputs(generator_args)
def test_should_prepare_user_inputs_keeps_validation_when_flag_false():
    """An explicit False flag behaves the same as no flag at all."""
    generator_args = {"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: False}

    assert WorkflowAppGenerator()._should_prepare_user_inputs(generator_args)
655
api/tests/unit_tests/core/plugin/utils/test_http_parser.py
Normal file
655
api/tests/unit_tests/core/plugin/utils/test_http_parser.py
Normal file
@@ -0,0 +1,655 @@
|
||||
import pytest
|
||||
from flask import Request, Response
|
||||
|
||||
from core.plugin.utils.http_parser import (
|
||||
deserialize_request,
|
||||
deserialize_response,
|
||||
serialize_request,
|
||||
serialize_response,
|
||||
)
|
||||
|
||||
|
||||
class TestSerializeRequest:
    """serialize_request turns a WSGI Request into raw HTTP/1.1 bytes."""

    @staticmethod
    def _make_environ(extra=None):
        """Build a minimal WSGI environ for a GET against localhost:8000.

        *extra* entries override or extend the defaults.
        """
        environ = {
            "REQUEST_METHOD": "GET",
            "PATH_INFO": "/api/test",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": None,
            "wsgi.url_scheme": "http",
        }
        if extra:
            environ.update(extra)
        return environ

    def test_serialize_simple_get_request(self):
        request = Request(self._make_environ())

        raw_data = serialize_request(request)

        # Request line first, then the blank line separating headers from body.
        assert raw_data.startswith(b"GET /api/test HTTP/1.1\r\n")
        assert b"\r\n\r\n" in raw_data

    def test_serialize_request_with_query_params(self):
        request = Request(
            self._make_environ({"PATH_INFO": "/api/search", "QUERY_STRING": "q=test&limit=10"})
        )

        raw_data = serialize_request(request)

        # The query string must appear in the request target.
        assert raw_data.startswith(b"GET /api/search?q=test&limit=10 HTTP/1.1\r\n")

    def test_serialize_post_request_with_body(self):
        from io import BytesIO

        payload = b'{"name": "test", "value": 123}'
        request = Request(
            self._make_environ(
                {
                    "REQUEST_METHOD": "POST",
                    "PATH_INFO": "/api/data",
                    "wsgi.input": BytesIO(payload),
                    "CONTENT_LENGTH": str(len(payload)),
                    "CONTENT_TYPE": "application/json",
                    "HTTP_CONTENT_TYPE": "application/json",
                }
            )
        )

        raw_data = serialize_request(request)

        assert b"POST /api/data HTTP/1.1\r\n" in raw_data
        assert b"Content-Type: application/json" in raw_data
        # The body is the final segment of the serialized message.
        assert raw_data.endswith(payload)

    def test_serialize_request_with_custom_headers(self):
        request = Request(
            self._make_environ(
                {
                    "HTTP_AUTHORIZATION": "Bearer token123",
                    "HTTP_X_CUSTOM_HEADER": "custom-value",
                }
            )
        )

        raw_data = serialize_request(request)

        # HTTP_* environ keys come back as normalized header lines.
        assert b"Authorization: Bearer token123" in raw_data
        assert b"X-Custom-Header: custom-value" in raw_data
class TestDeserializeRequest:
    """deserialize_request turns raw HTTP/1.1 bytes back into a Request."""

    def test_deserialize_simple_get_request(self):
        wire = b"GET /api/test HTTP/1.1\r\nHost: localhost:8000\r\n\r\n"

        parsed = deserialize_request(wire)

        assert parsed.method == "GET"
        assert parsed.path == "/api/test"
        assert parsed.headers.get("Host") == "localhost:8000"

    def test_deserialize_request_with_query_params(self):
        wire = b"GET /api/search?q=test&limit=10 HTTP/1.1\r\nHost: example.com\r\n\r\n"

        parsed = deserialize_request(wire)

        assert parsed.method == "GET"
        assert parsed.path == "/api/search"
        assert parsed.query_string == b"q=test&limit=10"
        # The query string is also exposed as parsed args.
        assert parsed.args.get("q") == "test"
        assert parsed.args.get("limit") == "10"

    def test_deserialize_post_request_with_body(self):
        payload = b'{"name": "test", "value": 123}'
        wire = b"".join(
            [
                b"POST /api/data HTTP/1.1\r\n",
                b"Host: localhost\r\n",
                b"Content-Type: application/json\r\n",
                b"Content-Length: " + str(len(payload)).encode() + b"\r\n",
                b"\r\n",
                payload,
            ]
        )

        parsed = deserialize_request(wire)

        assert parsed.method == "POST"
        assert parsed.path == "/api/data"
        assert parsed.content_type == "application/json"
        assert parsed.get_data() == payload

    def test_deserialize_request_with_custom_headers(self):
        wire = b"".join(
            [
                b"GET /api/protected HTTP/1.1\r\n",
                b"Host: api.example.com\r\n",
                b"Authorization: Bearer token123\r\n",
                b"X-Custom-Header: custom-value\r\n",
                b"User-Agent: TestClient/1.0\r\n",
                b"\r\n",
            ]
        )

        parsed = deserialize_request(wire)

        assert parsed.method == "GET"
        assert parsed.headers.get("Authorization") == "Bearer token123"
        assert parsed.headers.get("X-Custom-Header") == "custom-value"
        assert parsed.headers.get("User-Agent") == "TestClient/1.0"

    def test_deserialize_request_with_multiline_body(self):
        payload = b"line1\r\nline2\r\nline3"
        wire = b"PUT /api/text HTTP/1.1\r\nHost: localhost\r\nContent-Type: text/plain\r\n\r\n" + payload

        parsed = deserialize_request(wire)

        assert parsed.method == "PUT"
        # CRLF inside the body must not be mistaken for header separators.
        assert parsed.get_data() == payload

    def test_deserialize_invalid_request_line(self):
        # A request line with only one token cannot be parsed.
        with pytest.raises(ValueError, match="Invalid request line"):
            deserialize_request(b"INVALID\r\n\r\n")

    def test_roundtrip_request(self):
        from io import BytesIO

        payload = b"test body content"
        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/echo",
            "QUERY_STRING": "format=json",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8080",
            "wsgi.input": BytesIO(payload),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(payload)),
            "CONTENT_TYPE": "text/plain",
            "HTTP_CONTENT_TYPE": "text/plain",
            "HTTP_X_REQUEST_ID": "req-123",
        }
        source = Request(environ)

        # serialize -> deserialize must preserve the request.
        restored = deserialize_request(serialize_request(source))

        assert restored.method == source.method
        assert restored.path == source.path
        assert restored.query_string == source.query_string
        assert restored.get_data() == payload
        assert restored.headers.get("X-Request-Id") == "req-123"
class TestSerializeResponse:
    """serialize_response turns a Flask Response into raw HTTP/1.1 bytes."""

    def test_serialize_simple_response(self):
        raw = serialize_response(Response("Hello, World!", status=200))

        assert raw.startswith(b"HTTP/1.1 200 OK\r\n")
        assert b"\r\n\r\n" in raw  # header/body separator
        assert raw.endswith(b"Hello, World!")

    def test_serialize_response_with_headers(self):
        raw = serialize_response(
            Response(
                '{"status": "success"}',
                status=201,
                headers={
                    "Content-Type": "application/json",
                    "X-Request-Id": "req-456",
                },
            )
        )

        assert b"HTTP/1.1 201 CREATED\r\n" in raw
        assert b"Content-Type: application/json" in raw
        assert b"X-Request-Id: req-456" in raw
        assert raw.endswith(b'{"status": "success"}')

    def test_serialize_error_response(self):
        raw = serialize_response(
            Response("Not Found", status=404, headers={"Content-Type": "text/plain"})
        )

        assert b"HTTP/1.1 404 NOT FOUND\r\n" in raw
        assert b"Content-Type: text/plain" in raw
        assert raw.endswith(b"Not Found")

    def test_serialize_response_without_body(self):
        raw = serialize_response(Response(status=204))  # No Content

        assert b"HTTP/1.1 204 NO CONTENT\r\n" in raw
        # Nothing follows the header/body separator.
        assert raw.endswith(b"\r\n\r\n")

    def test_serialize_response_with_binary_body(self):
        blob = b"\x00\x01\x02\x03\x04\x05"
        raw = serialize_response(
            Response(blob, status=200, headers={"Content-Type": "application/octet-stream"})
        )

        assert b"HTTP/1.1 200 OK\r\n" in raw
        assert b"Content-Type: application/octet-stream" in raw
        assert raw.endswith(blob)
class TestDeserializeResponse:
    """deserialize_response turns raw HTTP/1.1 bytes back into a Response."""

    def test_deserialize_simple_response(self):
        parsed = deserialize_response(
            b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\nHello, World!"
        )

        assert parsed.status_code == 200
        assert parsed.get_data() == b"Hello, World!"
        assert parsed.headers.get("Content-Type") == "text/plain"

    def test_deserialize_response_with_json(self):
        payload = b'{"result": "success", "data": [1, 2, 3]}'
        wire = b"".join(
            [
                b"HTTP/1.1 201 Created\r\n",
                b"Content-Type: application/json\r\n",
                b"Content-Length: " + str(len(payload)).encode() + b"\r\n",
                b"X-Custom-Header: test-value\r\n",
                b"\r\n",
                payload,
            ]
        )

        parsed = deserialize_response(wire)

        assert parsed.status_code == 201
        assert parsed.get_data() == payload
        assert parsed.headers.get("Content-Type") == "application/json"
        assert parsed.headers.get("X-Custom-Header") == "test-value"

    def test_deserialize_error_response(self):
        wire = b"HTTP/1.1 404 Not Found\r\nContent-Type: text/html\r\n\r\n<html><body>Page not found</body></html>"

        parsed = deserialize_response(wire)

        assert parsed.status_code == 404
        assert parsed.get_data() == b"<html><body>Page not found</body></html>"

    def test_deserialize_response_without_body(self):
        parsed = deserialize_response(b"HTTP/1.1 204 No Content\r\n\r\n")

        assert parsed.status_code == 204
        assert parsed.get_data() == b""

    def test_deserialize_response_with_multiline_body(self):
        payload = b"Line 1\r\nLine 2\r\nLine 3"

        parsed = deserialize_response(
            b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n" + payload
        )

        assert parsed.status_code == 200
        # CRLF inside the body must survive parsing intact.
        assert parsed.get_data() == payload

    def test_deserialize_response_minimal_status_line(self):
        # A status line with no reason phrase is still accepted.
        parsed = deserialize_response(b"HTTP/1.1 200\r\n\r\nOK")

        assert parsed.status_code == 200
        assert parsed.get_data() == b"OK"

    def test_deserialize_invalid_status_line(self):
        with pytest.raises(ValueError, match="Invalid status line"):
            deserialize_response(b"INVALID\r\n\r\n")

    def test_roundtrip_response(self):
        source = Response(
            '{"message": "test"}',
            status=200,
            headers={
                "Content-Type": "application/json",
                "X-Request-Id": "abc-123",
                "Cache-Control": "no-cache",
            },
        )

        # serialize -> deserialize must preserve the response.
        restored = deserialize_response(serialize_response(source))

        assert restored.status_code == source.status_code
        assert restored.get_data() == source.get_data()
        assert restored.headers.get("Content-Type") == "application/json"
        assert restored.headers.get("X-Request-Id") == "abc-123"
        assert restored.headers.get("Cache-Control") == "no-cache"
class TestEdgeCases:
    """Degenerate inputs the parser must still handle."""

    def test_request_with_empty_headers(self):
        parsed = deserialize_request(b"GET / HTTP/1.1\r\n\r\n")

        assert parsed.method == "GET"
        assert parsed.path == "/"

    def test_response_with_empty_headers(self):
        parsed = deserialize_response(b"HTTP/1.1 200 OK\r\n\r\nSuccess")

        assert parsed.status_code == 200
        assert parsed.get_data() == b"Success"

    def test_request_with_special_characters_in_path(self):
        parsed = deserialize_request(b"GET /api/test%20path?key=%26value HTTP/1.1\r\n\r\n")

        assert parsed.method == "GET"
        # Percent-escapes are preserved verbatim in the raw path.
        assert "/api/test%20path" in parsed.full_path

    def test_response_with_binary_content(self):
        blob = bytes(range(256))  # every possible byte value

        parsed = deserialize_response(
            b"HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\n\r\n" + blob
        )

        assert parsed.status_code == 200
        assert parsed.get_data() == blob
class TestFileUploads:
    """Multipart/form-data requests survive serialization and round-trips.

    Fix: the original fixtures wrote part delimiters as ``------{boundary}``
    while the Content-Type header declared ``boundary={boundary}``. Per
    RFC 2046 the delimiter is ``--`` + boundary (and ``--`` + boundary +
    ``--`` for the close), so the bodies were not valid multipart messages
    relative to their own headers. The assertions (substring checks on the
    raw bytes) are unaffected by the fix.
    """

    def test_serialize_request_with_text_file_upload(self):
        # Multipart/form-data request carrying a text file plus a form field.
        from io import BytesIO

        boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
        text_content = "Hello, this is a test file content!\nWith multiple lines."
        body = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="file"; filename="test.txt"\r\n'
            f"Content-Type: text/plain\r\n"
            f"\r\n"
            f"{text_content}\r\n"
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="description"\r\n'
            f"\r\n"
            f"Test file upload\r\n"
            f"--{boundary}--\r\n"
        ).encode()

        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/upload",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
        }
        request = Request(environ)

        raw_data = serialize_request(request)

        assert b"POST /api/upload HTTP/1.1\r\n" in raw_data
        assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
        assert b'Content-Disposition: form-data; name="file"; filename="test.txt"' in raw_data
        assert text_content.encode() in raw_data

    def test_deserialize_request_with_text_file_upload(self):
        # Deserializing a multipart/form-data request with a text file.
        boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
        text_content = "Sample text file content\nLine 2\nLine 3"
        body = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="document"; filename="document.txt"\r\n'
            f"Content-Type: text/plain\r\n"
            f"\r\n"
            f"{text_content}\r\n"
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="title"\r\n'
            f"\r\n"
            f"My Document\r\n"
            f"--{boundary}--\r\n"
        ).encode()

        raw_data = (
            b"POST /api/documents HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
            b"Content-Length: " + str(len(body)).encode() + b"\r\n"
            b"\r\n" + body
        )

        request = deserialize_request(raw_data)

        assert request.method == "POST"
        assert request.path == "/api/documents"
        assert "multipart/form-data" in request.content_type
        # The body should contain the multipart data verbatim.
        request_body = request.get_data()
        assert b"document.txt" in request_body
        assert text_content.encode() in request_body

    def test_serialize_request_with_binary_file_upload(self):
        # Multipart/form-data request with a binary file (e.g. an image).
        from io import BytesIO

        boundary = "----BoundaryString123"
        # Simulate a small PNG file header.
        binary_content = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x10\x00\x00\x00\x10"

        body_parts = []
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="image"; filename="test.png"')
        body_parts.append(b"Content-Type: image/png")
        body_parts.append(b"")
        body_parts.append(binary_content)
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="caption"')
        body_parts.append(b"")
        body_parts.append(b"Test image")
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/images",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "http",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
        }
        request = Request(environ)

        raw_data = serialize_request(request)

        assert b"POST /api/images HTTP/1.1\r\n" in raw_data
        assert f"Content-Type: multipart/form-data; boundary={boundary}".encode() in raw_data
        assert b'filename="test.png"' in raw_data
        assert b"Content-Type: image/png" in raw_data
        assert binary_content in raw_data

    def test_deserialize_request_with_binary_file_upload(self):
        # Deserializing a multipart/form-data request with a binary file.
        boundary = "----BoundaryABC123"
        # Simulate a small JPEG file header.
        binary_content = b"\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00"

        body_parts = []
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="photo"; filename="photo.jpg"')
        body_parts.append(b"Content-Type: image/jpeg")
        body_parts.append(b"")
        body_parts.append(binary_content)
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="album"')
        body_parts.append(b"")
        body_parts.append(b"Vacation 2024")
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        raw_data = (
            b"POST /api/photos HTTP/1.1\r\n"
            b"Host: api.example.com\r\n"
            b"Content-Type: multipart/form-data; boundary=" + boundary.encode() + b"\r\n"
            b"Content-Length: " + str(len(body)).encode() + b"\r\n"
            b"Accept: application/json\r\n"
            b"\r\n" + body
        )

        request = deserialize_request(raw_data)

        assert request.method == "POST"
        assert request.path == "/api/photos"
        assert "multipart/form-data" in request.content_type
        assert request.headers.get("Accept") == "application/json"

        # Verify the binary content is preserved byte-for-byte.
        request_body = request.get_data()
        assert b"photo.jpg" in request_body
        assert b"image/jpeg" in request_body
        assert binary_content in request_body
        assert b"Vacation 2024" in request_body

    def test_serialize_request_with_multiple_files(self):
        # Request carrying two files plus an additional form field.
        from io import BytesIO

        boundary = "----MultiFilesBoundary"
        text_file = b"Text file contents"
        binary_file = b"\x00\x01\x02\x03\x04\x05"

        body_parts = []
        # First file (text).
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="files"; filename="doc.txt"')
        body_parts.append(b"Content-Type: text/plain")
        body_parts.append(b"")
        body_parts.append(text_file)
        # Second file (binary).
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="files"; filename="data.bin"')
        body_parts.append(b"Content-Type: application/octet-stream")
        body_parts.append(b"")
        body_parts.append(binary_file)
        # Additional form field.
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="folder"')
        body_parts.append(b"")
        body_parts.append(b"uploads/2024")
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        environ = {
            "REQUEST_METHOD": "POST",
            "PATH_INFO": "/api/batch-upload",
            "QUERY_STRING": "",
            "SERVER_NAME": "localhost",
            "SERVER_PORT": "8000",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "https",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_X_FORWARDED_PROTO": "https",
        }
        request = Request(environ)

        raw_data = serialize_request(request)

        assert b"POST /api/batch-upload HTTP/1.1\r\n" in raw_data
        assert b"doc.txt" in raw_data
        assert b"data.bin" in raw_data
        assert text_file in raw_data
        assert binary_file in raw_data
        assert b"uploads/2024" in raw_data

    def test_roundtrip_file_upload_request(self):
        # A file-upload request must survive serialize -> deserialize.
        from io import BytesIO

        boundary = "----RoundTripBoundary"
        file_content = b"This is my file content with special chars: \xf0\x9f\x98\x80"

        body_parts = []
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="upload"; filename="emoji.txt"')
        body_parts.append(b"Content-Type: text/plain; charset=utf-8")
        body_parts.append(b"")
        body_parts.append(file_content)
        body_parts.append(f"--{boundary}".encode())
        body_parts.append(b'Content-Disposition: form-data; name="metadata"')
        body_parts.append(b"")
        body_parts.append(b'{"encoding": "utf-8", "size": 42}')
        body_parts.append(f"--{boundary}--".encode())

        body = b"\r\n".join(body_parts)

        environ = {
            "REQUEST_METHOD": "PUT",
            "PATH_INFO": "/api/files/123",
            "QUERY_STRING": "version=2",
            "SERVER_NAME": "storage.example.com",
            "SERVER_PORT": "443",
            "wsgi.input": BytesIO(body),
            "wsgi.url_scheme": "https",
            "CONTENT_LENGTH": str(len(body)),
            "CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_CONTENT_TYPE": f"multipart/form-data; boundary={boundary}",
            "HTTP_AUTHORIZATION": "Bearer token123",
            "HTTP_X_FORWARDED_PROTO": "https",
        }
        original_request = Request(environ)

        # Serialize and deserialize.
        raw_data = serialize_request(original_request)
        restored_request = deserialize_request(raw_data)

        # Verify the request envelope is preserved.
        assert restored_request.method == "PUT"
        assert restored_request.path == "/api/files/123"
        assert restored_request.query_string == b"version=2"
        assert "multipart/form-data" in restored_request.content_type
        assert boundary in restored_request.content_type

        # Verify file content is preserved.
        restored_body = restored_request.get_data()
        assert b"emoji.txt" in restored_body
        assert file_content in restored_body
        assert b'{"encoding": "utf-8", "size": 42}' in restored_body
102
api/tests/unit_tests/core/test_trigger_debug_event_selectors.py
Normal file
102
api/tests/unit_tests/core/test_trigger_debug_event_selectors.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from core.trigger.debug import event_selectors
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig
|
||||
|
||||
|
||||
class _DummyRedis:
|
||||
def __init__(self):
|
||||
self.store: dict[str, str] = {}
|
||||
|
||||
def get(self, key: str):
|
||||
return self.store.get(key)
|
||||
|
||||
def setex(self, name: str, time: int, value: str):
|
||||
self.store[name] = value
|
||||
|
||||
def expire(self, name: str, ttl: int):
|
||||
# Expiration not required for these tests.
|
||||
pass
|
||||
|
||||
def delete(self, name: str):
|
||||
self.store.pop(name, None)
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_schedule_config() -> ScheduleConfig:
    """Deterministic every-minute schedule pinned to Asia/Shanghai."""
    return ScheduleConfig(
        node_id="node-1",
        cron_expression="* * * * *",
        timezone="Asia/Shanghai",
    )
@pytest.fixture(autouse=True)
def patch_schedule_service(monkeypatch: pytest.MonkeyPatch, dummy_schedule_config: ScheduleConfig):
    """Force ScheduleService.to_schedule_config to return the fixture config."""

    def _fixed_config(*_args, **_kwargs) -> ScheduleConfig:
        return dummy_schedule_config

    monkeypatch.setattr(
        "services.trigger.schedule_service.ScheduleService.to_schedule_config",
        staticmethod(_fixed_config),
    )
def _make_poller(
    monkeypatch: pytest.MonkeyPatch, redis_client: _DummyRedis
) -> event_selectors.ScheduleTriggerDebugEventPoller:
    """Build a schedule debug poller wired to the in-memory redis stand-in."""
    # Swap the module-level redis client for the fake before construction.
    monkeypatch.setattr(event_selectors, "redis_client", redis_client)
    node_config = {"id": "node-1", "data": {"mode": "cron"}}
    return event_selectors.ScheduleTriggerDebugEventPoller(
        tenant_id="tenant-1",
        user_id="user-1",
        app_id="app-1",
        node_config=node_config,
        node_id="node-1",
    )
def test_schedule_poller_handles_aware_next_run(monkeypatch: pytest.MonkeyPatch):
    """A timezone-aware next-run timestamp must not break polling."""
    fake_redis = _DummyRedis()
    poller = _make_poller(monkeypatch, fake_redis)

    # "now" is naive while the computed next run is aware; the poller
    # has to reconcile the two.
    frozen_now = datetime(2025, 1, 1, 12, 0, 10)
    next_run_aware = datetime(2025, 1, 1, 12, 0, 5, tzinfo=UTC)

    monkeypatch.setattr(event_selectors, "naive_utc_now", lambda: frozen_now)
    monkeypatch.setattr(event_selectors, "calculate_next_run_at", lambda *_: next_run_aware)

    event = poller.poll()

    assert event is not None
    assert event.node_id == "node-1"
    assert event.workflow_args["inputs"] == {}
def test_schedule_runtime_cache_normalizes_timezone(
    monkeypatch: pytest.MonkeyPatch, dummy_schedule_config: ScheduleConfig
):
    """A cached aware timestamp is normalized back to naive UTC."""
    fake_redis = _DummyRedis()
    poller = _make_poller(monkeypatch, fake_redis)

    # 2025-01-01 20:00 in Asia/Shanghai (+08:00) is 12:00 UTC.
    shanghai_time = pytz.timezone("Asia/Shanghai").localize(datetime(2025, 1, 1, 20, 0, 0))

    cron_hash = hashlib.sha256(dummy_schedule_config.cron_expression.encode()).hexdigest()
    cache_key = poller.schedule_debug_runtime_key(cron_hash)

    # Seed the cache with an aware, zone-local next_run_at.
    fake_redis.store[cache_key] = json.dumps(
        {
            "cache_key": cache_key,
            "timezone": dummy_schedule_config.timezone,
            "cron_expression": dummy_schedule_config.cron_expression,
            "next_run_at": shanghai_time.isoformat(),
        }
    )

    runtime = poller.get_or_create_schedule_debug_runtime()

    expected_naive_utc = shanghai_time.astimezone(UTC).replace(tzinfo=None)
    assert runtime.next_run_at == expected_naive_utc
    assert runtime.next_run_at.tzinfo is None
@@ -4,7 +4,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.tools.utils.encryption import ProviderConfigEncrypter
|
||||
from core.helper.provider_encryption import ProviderConfigEncrypter
|
||||
|
||||
|
||||
# ---------------------------
|
||||
@@ -70,7 +70,7 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj
|
||||
data_in = {"username": "alice", "password": "plain_pwd"}
|
||||
data_copy = copy.deepcopy(data_in)
|
||||
|
||||
with patch("core.tools.utils.encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.encrypt_token", return_value="CIPHERTEXT") as mock_encrypt:
|
||||
out = encrypter_obj.encrypt(data_in)
|
||||
|
||||
assert out["username"] == "alice"
|
||||
@@ -81,14 +81,14 @@ def test_encrypt_only_secret_is_encrypted_and_non_secret_unchanged(encrypter_obj
|
||||
|
||||
def test_encrypt_missing_secret_key_is_ok(encrypter_obj):
    """A missing secret field is tolerated: no error and no encryption call."""
    with patch("core.helper.provider_encryption.encrypter.encrypt_token") as mock_encrypt:
        result = encrypter_obj.encrypt({"username": "alice"})
        assert result["username"] == "alice"
        mock_encrypt.assert_not_called()
||||
|
||||
|
||||
# ============================================================
# ProviderConfigEncrypter.mask_plugin_credentials()
# ============================================================
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ def test_mask_tool_credentials_long_secret(encrypter_obj, raw, prefix, suffix):
|
||||
data_in = {"username": "alice", "password": raw}
|
||||
data_copy = copy.deepcopy(data_in)
|
||||
|
||||
out = encrypter_obj.mask_tool_credentials(data_in)
|
||||
out = encrypter_obj.mask_plugin_credentials(data_in)
|
||||
masked = out["password"]
|
||||
|
||||
assert masked.startswith(prefix)
|
||||
@@ -122,7 +122,7 @@ def test_mask_tool_credentials_short_secret(encrypter_obj, raw):
|
||||
"""
|
||||
For length <= 6: fully mask with '*' of same length.
|
||||
"""
|
||||
out = encrypter_obj.mask_tool_credentials({"password": raw})
|
||||
out = encrypter_obj.mask_plugin_credentials({"password": raw})
|
||||
assert out["password"] == ("*" * len(raw))
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ def test_mask_tool_credentials_missing_key_noop(encrypter_obj):
|
||||
data_in = {"username": "alice"}
|
||||
data_copy = copy.deepcopy(data_in)
|
||||
|
||||
out = encrypter_obj.mask_tool_credentials(data_in)
|
||||
out = encrypter_obj.mask_plugin_credentials(data_in)
|
||||
assert out["username"] == "alice"
|
||||
assert data_in == data_copy
|
||||
|
||||
@@ -151,7 +151,7 @@ def test_decrypt_normal_flow(encrypter_obj):
|
||||
data_in = {"username": "alice", "password": "ENC"}
|
||||
data_copy = copy.deepcopy(data_in)
|
||||
|
||||
with patch("core.tools.utils.encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.decrypt_token", return_value="PLAIN") as mock_decrypt:
|
||||
out = encrypter_obj.decrypt(data_in)
|
||||
|
||||
assert out["username"] == "alice"
|
||||
@@ -163,7 +163,7 @@ def test_decrypt_normal_flow(encrypter_obj):
|
||||
@pytest.mark.parametrize("empty_val", ["", None])
|
||||
def test_decrypt_skip_empty_values(encrypter_obj, empty_val):
|
||||
"""Skip decrypt if value is empty or None, keep original."""
|
||||
with patch("core.tools.utils.encryption.encrypter.decrypt_token") as mock_decrypt:
|
||||
with patch("core.helper.provider_encryption.encrypter.decrypt_token") as mock_decrypt:
|
||||
out = encrypter_obj.decrypt({"password": empty_val})
|
||||
|
||||
mock_decrypt.assert_not_called()
|
||||
@@ -175,7 +175,7 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj):
|
||||
If decrypt_token raises, exception should be swallowed,
|
||||
and original value preserved.
|
||||
"""
|
||||
with patch("core.tools.utils.encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
|
||||
with patch("core.helper.provider_encryption.encrypter.decrypt_token", side_effect=Exception("boom")):
|
||||
out = encrypter_obj.decrypt({"password": "ENC_ERR"})
|
||||
|
||||
assert out["password"] == "ENC_ERR"
|
||||
|
||||
@@ -64,6 +64,15 @@ class _TestNode(Node):
|
||||
)
|
||||
self.data = dict(data)
|
||||
|
||||
node_type_value = data.get("type")
|
||||
if isinstance(node_type_value, NodeType):
|
||||
self.node_type = node_type_value
|
||||
elif isinstance(node_type_value, str):
|
||||
try:
|
||||
self.node_type = NodeType(node_type_value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -179,3 +188,22 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type(
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH
|
||||
|
||||
|
||||
def test_graph_validation_blocks_start_and_trigger_coexistence(
|
||||
graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]],
|
||||
) -> None:
|
||||
node_factory, graph_config = graph_init_dependencies
|
||||
graph_config["nodes"] = [
|
||||
{"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}},
|
||||
{
|
||||
"id": "trigger",
|
||||
"data": {"type": NodeType.TRIGGER_WEBHOOK, "title": "Webhook", "execution_type": NodeExecutionType.ROOT},
|
||||
},
|
||||
]
|
||||
graph_config["edges"] = []
|
||||
|
||||
with pytest.raises(GraphValidationError) as exc_info:
|
||||
Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues)
|
||||
|
||||
@@ -0,0 +1,308 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
Method,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
|
||||
|
||||
def test_method_enum():
|
||||
"""Test Method enum values."""
|
||||
assert Method.GET == "get"
|
||||
assert Method.POST == "post"
|
||||
assert Method.HEAD == "head"
|
||||
assert Method.PATCH == "patch"
|
||||
assert Method.PUT == "put"
|
||||
assert Method.DELETE == "delete"
|
||||
|
||||
# Test all enum values are strings
|
||||
for method in Method:
|
||||
assert isinstance(method.value, str)
|
||||
|
||||
|
||||
def test_content_type_enum():
|
||||
"""Test ContentType enum values."""
|
||||
assert ContentType.JSON == "application/json"
|
||||
assert ContentType.FORM_DATA == "multipart/form-data"
|
||||
assert ContentType.FORM_URLENCODED == "application/x-www-form-urlencoded"
|
||||
assert ContentType.TEXT == "text/plain"
|
||||
assert ContentType.BINARY == "application/octet-stream"
|
||||
|
||||
# Test all enum values are strings
|
||||
for content_type in ContentType:
|
||||
assert isinstance(content_type.value, str)
|
||||
|
||||
|
||||
def test_webhook_parameter_creation():
|
||||
"""Test WebhookParameter model creation and validation."""
|
||||
# Test with all fields
|
||||
param = WebhookParameter(name="api_key", required=True)
|
||||
assert param.name == "api_key"
|
||||
assert param.required is True
|
||||
|
||||
# Test with defaults
|
||||
param_default = WebhookParameter(name="optional_param")
|
||||
assert param_default.name == "optional_param"
|
||||
assert param_default.required is False
|
||||
|
||||
# Test validation - name is required
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookParameter()
|
||||
|
||||
|
||||
def test_webhook_body_parameter_creation():
|
||||
"""Test WebhookBodyParameter model creation and validation."""
|
||||
# Test with all fields
|
||||
body_param = WebhookBodyParameter(
|
||||
name="user_data",
|
||||
type="object",
|
||||
required=True,
|
||||
)
|
||||
assert body_param.name == "user_data"
|
||||
assert body_param.type == "object"
|
||||
assert body_param.required is True
|
||||
|
||||
# Test with defaults
|
||||
body_param_default = WebhookBodyParameter(name="message")
|
||||
assert body_param_default.name == "message"
|
||||
assert body_param_default.type == "string" # Default type
|
||||
assert body_param_default.required is False
|
||||
|
||||
# Test validation - name is required
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookBodyParameter()
|
||||
|
||||
|
||||
def test_webhook_body_parameter_types():
|
||||
"""Test WebhookBodyParameter type validation."""
|
||||
valid_types = [
|
||||
"string",
|
||||
"number",
|
||||
"boolean",
|
||||
"object",
|
||||
"array[string]",
|
||||
"array[number]",
|
||||
"array[boolean]",
|
||||
"array[object]",
|
||||
"file",
|
||||
]
|
||||
|
||||
for param_type in valid_types:
|
||||
param = WebhookBodyParameter(name="test", type=param_type)
|
||||
assert param.type == param_type
|
||||
|
||||
# Test invalid type
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookBodyParameter(name="test", type="invalid_type")
|
||||
|
||||
|
||||
def test_webhook_data_creation_minimal():
|
||||
"""Test WebhookData creation with minimal required fields."""
|
||||
data = WebhookData(title="Test Webhook")
|
||||
|
||||
assert data.title == "Test Webhook"
|
||||
assert data.method == Method.GET # Default
|
||||
assert data.content_type == ContentType.JSON # Default
|
||||
assert data.headers == [] # Default
|
||||
assert data.params == [] # Default
|
||||
assert data.body == [] # Default
|
||||
assert data.status_code == 200 # Default
|
||||
assert data.response_body == "" # Default
|
||||
assert data.webhook_id is None # Default
|
||||
assert data.timeout == 30 # Default
|
||||
|
||||
|
||||
def test_webhook_data_creation_full():
|
||||
"""Test WebhookData creation with all fields."""
|
||||
headers = [
|
||||
WebhookParameter(name="Authorization", required=True),
|
||||
WebhookParameter(name="Content-Type", required=False),
|
||||
]
|
||||
params = [
|
||||
WebhookParameter(name="version", required=True),
|
||||
WebhookParameter(name="format", required=False),
|
||||
]
|
||||
body = [
|
||||
WebhookBodyParameter(name="message", type="string", required=True),
|
||||
WebhookBodyParameter(name="count", type="number", required=False),
|
||||
WebhookBodyParameter(name="upload", type="file", required=True),
|
||||
]
|
||||
|
||||
# Use the alias for content_type to test it properly
|
||||
data = WebhookData(
|
||||
title="Full Webhook Test",
|
||||
desc="A comprehensive webhook test",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.FORM_DATA,
|
||||
headers=headers,
|
||||
params=params,
|
||||
body=body,
|
||||
status_code=201,
|
||||
response_body='{"success": true}',
|
||||
webhook_id="webhook_123",
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
assert data.title == "Full Webhook Test"
|
||||
assert data.desc == "A comprehensive webhook test"
|
||||
assert data.method == Method.POST
|
||||
assert data.content_type == ContentType.FORM_DATA
|
||||
assert len(data.headers) == 2
|
||||
assert len(data.params) == 2
|
||||
assert len(data.body) == 3
|
||||
assert data.status_code == 201
|
||||
assert data.response_body == '{"success": true}'
|
||||
assert data.webhook_id == "webhook_123"
|
||||
assert data.timeout == 60
|
||||
|
||||
|
||||
def test_webhook_data_content_type_alias():
|
||||
"""Test WebhookData content_type accepts both strings and enum values."""
|
||||
data1 = WebhookData(title="Test", content_type="application/json")
|
||||
assert data1.content_type == ContentType.JSON
|
||||
|
||||
data2 = WebhookData(title="Test", content_type=ContentType.FORM_DATA)
|
||||
assert data2.content_type == ContentType.FORM_DATA
|
||||
|
||||
|
||||
def test_webhook_data_model_dump():
|
||||
"""Test WebhookData model serialization."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.JSON,
|
||||
headers=[WebhookParameter(name="Authorization", required=True)],
|
||||
params=[WebhookParameter(name="version", required=False)],
|
||||
body=[WebhookBodyParameter(name="message", type="string", required=True)],
|
||||
status_code=200,
|
||||
response_body="OK",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
dumped = data.model_dump()
|
||||
|
||||
assert dumped["title"] == "Test Webhook"
|
||||
assert dumped["method"] == "post"
|
||||
assert dumped["content_type"] == "application/json"
|
||||
assert len(dumped["headers"]) == 1
|
||||
assert dumped["headers"][0]["name"] == "Authorization"
|
||||
assert dumped["headers"][0]["required"] is True
|
||||
assert len(dumped["params"]) == 1
|
||||
assert len(dumped["body"]) == 1
|
||||
assert dumped["body"][0]["type"] == "string"
|
||||
|
||||
|
||||
def test_webhook_data_model_dump_with_alias():
|
||||
"""Test WebhookData model serialization includes alias."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
content_type=ContentType.FORM_DATA,
|
||||
)
|
||||
|
||||
dumped = data.model_dump(by_alias=True)
|
||||
assert "content_type" in dumped
|
||||
assert dumped["content_type"] == "multipart/form-data"
|
||||
|
||||
|
||||
def test_webhook_data_validation_errors():
|
||||
"""Test WebhookData validation errors."""
|
||||
# Title is required (inherited from BaseNodeData)
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData()
|
||||
|
||||
# Invalid method
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData(title="Test", method="invalid_method")
|
||||
|
||||
# Invalid content_type
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData(title="Test", content_type="invalid/type")
|
||||
|
||||
# Invalid status_code (should be int) - use non-numeric string
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData(title="Test", status_code="invalid")
|
||||
|
||||
# Invalid timeout (should be int) - use non-numeric string
|
||||
with pytest.raises(ValidationError):
|
||||
WebhookData(title="Test", timeout="invalid")
|
||||
|
||||
# Valid cases that should NOT raise errors
|
||||
# These should work fine (pydantic converts string numbers to int)
|
||||
valid_data = WebhookData(title="Test", status_code="200", timeout="30")
|
||||
assert valid_data.status_code == 200
|
||||
assert valid_data.timeout == 30
|
||||
|
||||
|
||||
def test_webhook_data_sequence_fields():
|
||||
"""Test WebhookData sequence field behavior."""
|
||||
# Test empty sequences
|
||||
data = WebhookData(title="Test")
|
||||
assert data.headers == []
|
||||
assert data.params == []
|
||||
assert data.body == []
|
||||
|
||||
# Test immutable sequences
|
||||
headers = [WebhookParameter(name="test")]
|
||||
data = WebhookData(title="Test", headers=headers)
|
||||
|
||||
# Original list shouldn't affect the model
|
||||
headers.append(WebhookParameter(name="test2"))
|
||||
assert len(data.headers) == 1 # Should still be 1
|
||||
|
||||
|
||||
def test_webhook_data_sync_mode():
|
||||
"""Test WebhookData SyncMode nested enum."""
|
||||
# Test that SyncMode enum exists and has expected value
|
||||
assert hasattr(WebhookData, "SyncMode")
|
||||
assert WebhookData.SyncMode.SYNC == "async" # Note: confusingly named but correct
|
||||
|
||||
|
||||
def test_webhook_parameter_edge_cases():
|
||||
"""Test WebhookParameter edge cases."""
|
||||
# Test with special characters in name
|
||||
param = WebhookParameter(name="X-Custom-Header-123", required=True)
|
||||
assert param.name == "X-Custom-Header-123"
|
||||
|
||||
# Test with empty string name (should be valid if pydantic allows it)
|
||||
param_empty = WebhookParameter(name="", required=False)
|
||||
assert param_empty.name == ""
|
||||
|
||||
|
||||
def test_webhook_body_parameter_edge_cases():
|
||||
"""Test WebhookBodyParameter edge cases."""
|
||||
# Test file type parameter
|
||||
file_param = WebhookBodyParameter(name="upload", type="file", required=True)
|
||||
assert file_param.type == "file"
|
||||
assert file_param.required is True
|
||||
|
||||
# Test all valid types
|
||||
for param_type in [
|
||||
"string",
|
||||
"number",
|
||||
"boolean",
|
||||
"object",
|
||||
"array[string]",
|
||||
"array[number]",
|
||||
"array[boolean]",
|
||||
"array[object]",
|
||||
"file",
|
||||
]:
|
||||
param = WebhookBodyParameter(name=f"test_{param_type}", type=param_type)
|
||||
assert param.type == param_type
|
||||
|
||||
|
||||
def test_webhook_data_inheritance():
|
||||
"""Test WebhookData inherits from BaseNodeData correctly."""
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
|
||||
# Test that WebhookData is a subclass of BaseNodeData
|
||||
assert issubclass(WebhookData, BaseNodeData)
|
||||
|
||||
# Test that instances have BaseNodeData properties
|
||||
data = WebhookData(title="Test")
|
||||
assert hasattr(data, "title")
|
||||
assert hasattr(data, "desc") # Inherited from BaseNodeData
|
||||
@@ -0,0 +1,195 @@
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.base.exc import BaseNodeError
|
||||
from core.workflow.nodes.trigger_webhook.exc import (
|
||||
WebhookConfigError,
|
||||
WebhookNodeError,
|
||||
WebhookNotFoundError,
|
||||
WebhookTimeoutError,
|
||||
)
|
||||
|
||||
|
||||
def test_webhook_node_error_inheritance():
|
||||
"""Test WebhookNodeError inherits from BaseNodeError."""
|
||||
assert issubclass(WebhookNodeError, BaseNodeError)
|
||||
|
||||
# Test instantiation
|
||||
error = WebhookNodeError("Test error message")
|
||||
assert str(error) == "Test error message"
|
||||
assert isinstance(error, BaseNodeError)
|
||||
|
||||
|
||||
def test_webhook_timeout_error():
|
||||
"""Test WebhookTimeoutError functionality."""
|
||||
# Test inheritance
|
||||
assert issubclass(WebhookTimeoutError, WebhookNodeError)
|
||||
assert issubclass(WebhookTimeoutError, BaseNodeError)
|
||||
|
||||
# Test instantiation with message
|
||||
error = WebhookTimeoutError("Webhook request timed out")
|
||||
assert str(error) == "Webhook request timed out"
|
||||
|
||||
# Test instantiation without message
|
||||
error_no_msg = WebhookTimeoutError()
|
||||
assert isinstance(error_no_msg, WebhookTimeoutError)
|
||||
|
||||
|
||||
def test_webhook_not_found_error():
|
||||
"""Test WebhookNotFoundError functionality."""
|
||||
# Test inheritance
|
||||
assert issubclass(WebhookNotFoundError, WebhookNodeError)
|
||||
assert issubclass(WebhookNotFoundError, BaseNodeError)
|
||||
|
||||
# Test instantiation with message
|
||||
error = WebhookNotFoundError("Webhook trigger not found")
|
||||
assert str(error) == "Webhook trigger not found"
|
||||
|
||||
# Test instantiation without message
|
||||
error_no_msg = WebhookNotFoundError()
|
||||
assert isinstance(error_no_msg, WebhookNotFoundError)
|
||||
|
||||
|
||||
def test_webhook_config_error():
|
||||
"""Test WebhookConfigError functionality."""
|
||||
# Test inheritance
|
||||
assert issubclass(WebhookConfigError, WebhookNodeError)
|
||||
assert issubclass(WebhookConfigError, BaseNodeError)
|
||||
|
||||
# Test instantiation with message
|
||||
error = WebhookConfigError("Invalid webhook configuration")
|
||||
assert str(error) == "Invalid webhook configuration"
|
||||
|
||||
# Test instantiation without message
|
||||
error_no_msg = WebhookConfigError()
|
||||
assert isinstance(error_no_msg, WebhookConfigError)
|
||||
|
||||
|
||||
def test_webhook_error_hierarchy():
|
||||
"""Test the complete webhook error hierarchy."""
|
||||
# All webhook errors should inherit from WebhookNodeError
|
||||
webhook_errors = [
|
||||
WebhookTimeoutError,
|
||||
WebhookNotFoundError,
|
||||
WebhookConfigError,
|
||||
]
|
||||
|
||||
for error_class in webhook_errors:
|
||||
assert issubclass(error_class, WebhookNodeError)
|
||||
assert issubclass(error_class, BaseNodeError)
|
||||
|
||||
|
||||
def test_webhook_error_instantiation_with_args():
|
||||
"""Test webhook error instantiation with various arguments."""
|
||||
# Test with single string argument
|
||||
error1 = WebhookNodeError("Simple error message")
|
||||
assert str(error1) == "Simple error message"
|
||||
|
||||
# Test with multiple arguments
|
||||
error2 = WebhookTimeoutError("Timeout after", 30, "seconds")
|
||||
# Note: The exact string representation depends on Exception.__str__ implementation
|
||||
assert "Timeout after" in str(error2)
|
||||
|
||||
# Test with keyword arguments (if supported by base Exception)
|
||||
error3 = WebhookConfigError("Config error in field: timeout")
|
||||
assert "Config error in field: timeout" in str(error3)
|
||||
|
||||
|
||||
def test_webhook_error_as_exceptions():
|
||||
"""Test that webhook errors can be raised and caught properly."""
|
||||
# Test raising and catching WebhookNodeError
|
||||
with pytest.raises(WebhookNodeError) as exc_info:
|
||||
raise WebhookNodeError("Base webhook error")
|
||||
assert str(exc_info.value) == "Base webhook error"
|
||||
|
||||
# Test raising and catching specific errors
|
||||
with pytest.raises(WebhookTimeoutError) as exc_info:
|
||||
raise WebhookTimeoutError("Request timeout")
|
||||
assert str(exc_info.value) == "Request timeout"
|
||||
|
||||
with pytest.raises(WebhookNotFoundError) as exc_info:
|
||||
raise WebhookNotFoundError("Webhook not found")
|
||||
assert str(exc_info.value) == "Webhook not found"
|
||||
|
||||
with pytest.raises(WebhookConfigError) as exc_info:
|
||||
raise WebhookConfigError("Invalid config")
|
||||
assert str(exc_info.value) == "Invalid config"
|
||||
|
||||
|
||||
def test_webhook_error_catching_hierarchy():
|
||||
"""Test that webhook errors can be caught by their parent classes."""
|
||||
# WebhookTimeoutError should be catchable as WebhookNodeError
|
||||
with pytest.raises(WebhookNodeError):
|
||||
raise WebhookTimeoutError("Timeout error")
|
||||
|
||||
# WebhookNotFoundError should be catchable as WebhookNodeError
|
||||
with pytest.raises(WebhookNodeError):
|
||||
raise WebhookNotFoundError("Not found error")
|
||||
|
||||
# WebhookConfigError should be catchable as WebhookNodeError
|
||||
with pytest.raises(WebhookNodeError):
|
||||
raise WebhookConfigError("Config error")
|
||||
|
||||
# All webhook errors should be catchable as BaseNodeError
|
||||
with pytest.raises(BaseNodeError):
|
||||
raise WebhookTimeoutError("Timeout as base error")
|
||||
|
||||
with pytest.raises(BaseNodeError):
|
||||
raise WebhookNotFoundError("Not found as base error")
|
||||
|
||||
with pytest.raises(BaseNodeError):
|
||||
raise WebhookConfigError("Config as base error")
|
||||
|
||||
|
||||
def test_webhook_error_attributes():
|
||||
"""Test webhook error class attributes."""
|
||||
# Test that all error classes have proper __name__
|
||||
assert WebhookNodeError.__name__ == "WebhookNodeError"
|
||||
assert WebhookTimeoutError.__name__ == "WebhookTimeoutError"
|
||||
assert WebhookNotFoundError.__name__ == "WebhookNotFoundError"
|
||||
assert WebhookConfigError.__name__ == "WebhookConfigError"
|
||||
|
||||
# Test that all error classes have proper __module__
|
||||
expected_module = "core.workflow.nodes.trigger_webhook.exc"
|
||||
assert WebhookNodeError.__module__ == expected_module
|
||||
assert WebhookTimeoutError.__module__ == expected_module
|
||||
assert WebhookNotFoundError.__module__ == expected_module
|
||||
assert WebhookConfigError.__module__ == expected_module
|
||||
|
||||
|
||||
def test_webhook_error_docstrings():
|
||||
"""Test webhook error class docstrings."""
|
||||
assert WebhookNodeError.__doc__ == "Base webhook node error."
|
||||
assert WebhookTimeoutError.__doc__ == "Webhook timeout error."
|
||||
assert WebhookNotFoundError.__doc__ == "Webhook not found error."
|
||||
assert WebhookConfigError.__doc__ == "Webhook configuration error."
|
||||
|
||||
|
||||
def test_webhook_error_repr_and_str():
|
||||
"""Test webhook error string representations."""
|
||||
error = WebhookNodeError("Test message")
|
||||
|
||||
# Test __str__ method
|
||||
assert str(error) == "Test message"
|
||||
|
||||
# Test __repr__ method (should include class name)
|
||||
repr_str = repr(error)
|
||||
assert "WebhookNodeError" in repr_str
|
||||
assert "Test message" in repr_str
|
||||
|
||||
|
||||
def test_webhook_error_with_no_message():
|
||||
"""Test webhook errors with no message."""
|
||||
# Test that errors can be instantiated without messages
|
||||
errors = [
|
||||
WebhookNodeError(),
|
||||
WebhookTimeoutError(),
|
||||
WebhookNotFoundError(),
|
||||
WebhookConfigError(),
|
||||
]
|
||||
|
||||
for error in errors:
|
||||
# Should be instances of their respective classes
|
||||
assert isinstance(error, type(error))
|
||||
# Should be able to be raised
|
||||
with pytest.raises(type(error)):
|
||||
raise error
|
||||
@@ -0,0 +1,468 @@
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import StringVariable
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.trigger_webhook.entities import (
|
||||
ContentType,
|
||||
Method,
|
||||
WebhookBodyParameter,
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode:
|
||||
"""Helper function to create a webhook node with proper initialization."""
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": webhook_data.model_dump(),
|
||||
}
|
||||
|
||||
node = TriggerWebhookNode(
|
||||
id="1",
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
)
|
||||
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
def test_webhook_node_basic_initialization():
|
||||
"""Test basic webhook node initialization and configuration."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
method=Method.POST,
|
||||
content_type=ContentType.JSON,
|
||||
headers=[WebhookParameter(name="X-API-Key", required=True)],
|
||||
params=[WebhookParameter(name="version", required=False)],
|
||||
body=[WebhookBodyParameter(name="message", type="string", required=True)],
|
||||
status_code=200,
|
||||
response_body="OK",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
|
||||
assert node.node_type.value == "trigger-webhook"
|
||||
assert node.version() == "1"
|
||||
assert node._get_title() == "Test Webhook"
|
||||
assert node._node_data.method == Method.POST
|
||||
assert node._node_data.content_type == ContentType.JSON
|
||||
assert len(node._node_data.headers) == 1
|
||||
assert len(node._node_data.params) == 1
|
||||
assert len(node._node_data.body) == 1
|
||||
|
||||
|
||||
def test_webhook_node_default_config():
|
||||
"""Test webhook node default configuration."""
|
||||
config = TriggerWebhookNode.get_default_config()
|
||||
|
||||
assert config["type"] == "webhook"
|
||||
assert config["config"]["method"] == "get"
|
||||
assert config["config"]["content_type"] == "application/json"
|
||||
assert config["config"]["headers"] == []
|
||||
assert config["config"]["params"] == []
|
||||
assert config["config"]["body"] == []
|
||||
assert config["config"]["async_mode"] is True
|
||||
assert config["config"]["status_code"] == 200
|
||||
assert config["config"]["response_body"] == ""
|
||||
assert config["config"]["timeout"] == 30
|
||||
|
||||
|
||||
def test_webhook_node_run_with_headers():
|
||||
"""Test webhook node execution with header extraction."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
headers=[
|
||||
WebhookParameter(name="Authorization", required=True),
|
||||
WebhookParameter(name="Content-Type", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {
|
||||
"Authorization": "Bearer token123",
|
||||
"content-type": "application/json", # Different case
|
||||
"X-Custom": "custom-value",
|
||||
},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["Authorization"] == "Bearer token123"
|
||||
assert result.outputs["Content_Type"] == "application/json" # Case-insensitive match
|
||||
assert "_webhook_raw" in result.outputs
|
||||
|
||||
|
||||
def test_webhook_node_run_with_query_params():
|
||||
"""Test webhook node execution with query parameter extraction."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
params=[
|
||||
WebhookParameter(name="page", required=True),
|
||||
WebhookParameter(name="limit", required=False),
|
||||
WebhookParameter(name="missing", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {
|
||||
"page": "1",
|
||||
"limit": "10",
|
||||
},
|
||||
"body": {},
|
||||
"files": {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["page"] == "1"
|
||||
assert result.outputs["limit"] == "10"
|
||||
assert result.outputs["missing"] is None # Missing parameter should be None
|
||||
|
||||
|
||||
def test_webhook_node_run_with_body_params():
|
||||
"""Test webhook node execution with body parameter extraction."""
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
body=[
|
||||
WebhookBodyParameter(name="message", type="string", required=True),
|
||||
WebhookBodyParameter(name="count", type="number", required=False),
|
||||
WebhookBodyParameter(name="active", type="boolean", required=False),
|
||||
WebhookBodyParameter(name="metadata", type="object", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {
|
||||
"message": "Hello World",
|
||||
"count": 42,
|
||||
"active": True,
|
||||
"metadata": {"key": "value"},
|
||||
},
|
||||
"files": {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["message"] == "Hello World"
|
||||
assert result.outputs["count"] == 42
|
||||
assert result.outputs["active"] is True
|
||||
assert result.outputs["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_webhook_node_run_with_file_params():
|
||||
"""Test webhook node execution with file parameter extraction."""
|
||||
# Create mock file objects
|
||||
file1 = File(
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file1",
|
||||
filename="image.jpg",
|
||||
mime_type="image/jpeg",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
file2 = File(
|
||||
tenant_id="1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file2",
|
||||
filename="document.pdf",
|
||||
mime_type="application/pdf",
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
data = WebhookData(
|
||||
title="Test Webhook",
|
||||
body=[
|
||||
WebhookBodyParameter(name="upload", type="file", required=True),
|
||||
WebhookBodyParameter(name="document", type="file", required=False),
|
||||
WebhookBodyParameter(name="missing_file", type="file", required=False),
|
||||
],
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={
|
||||
"webhook_data": {
|
||||
"headers": {},
|
||||
"query_params": {},
|
||||
"body": {},
|
||||
"files": {
|
||||
"upload": file1,
|
||||
"document": file2,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
node = create_webhook_node(data, variable_pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs["upload"] == file1
|
||||
assert result.outputs["document"] == file2
|
||||
assert result.outputs["missing_file"] is None
|
||||
|
||||
|
||||
def test_webhook_node_run_mixed_parameters():
    """Headers, query params, body fields and files should all be extracted."""
    attachment = File(
        tenant_id="1",
        type=FileType.IMAGE,
        transfer_method=FileTransferMethod.LOCAL_FILE,
        related_id="file1",
        filename="test.jpg",
        mime_type="image/jpeg",
        storage_key="",
    )

    node_data = WebhookData(
        title="Test Webhook",
        headers=[WebhookParameter(name="Authorization", required=True)],
        params=[WebhookParameter(name="version", required=False)],
        body=[
            WebhookBodyParameter(name="message", type="string", required=True),
            WebhookBodyParameter(name="upload", type="file", required=False),
        ],
    )

    pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={
            "webhook_data": {
                "headers": {"Authorization": "Bearer token"},
                "query_params": {"version": "v1"},
                "body": {"message": "Test message"},
                "files": {"upload": attachment},
            }
        },
    )

    result = create_webhook_node(node_data, pool)._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    assert result.outputs["Authorization"] == "Bearer token"
    assert result.outputs["version"] == "v1"
    assert result.outputs["message"] == "Test message"
    assert result.outputs["upload"] == attachment
    # The raw payload is also exposed under the reserved key.
    assert "_webhook_raw" in result.outputs
|
||||
|
||||
|
||||
def test_webhook_node_run_empty_webhook_data():
    """Optional parameters should resolve to None when no webhook data arrives."""
    node_data = WebhookData(
        title="Test Webhook",
        headers=[WebhookParameter(name="Authorization", required=False)],
        params=[WebhookParameter(name="page", required=False)],
        body=[WebhookBodyParameter(name="message", type="string", required=False)],
    )

    # The pool deliberately carries no "webhook_data" entry at all.
    pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={},
    )

    result = create_webhook_node(node_data, pool)._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    for optional_name in ("Authorization", "page", "message"):
        assert result.outputs[optional_name] is None
    assert result.outputs["_webhook_raw"] == {}
|
||||
|
||||
|
||||
def test_webhook_node_run_case_insensitive_headers():
    """Header extraction should match declared names regardless of case."""
    node_data = WebhookData(
        title="Test Webhook",
        headers=[
            WebhookParameter(name="Content-Type", required=True),
            WebhookParameter(name="X-API-KEY", required=True),
            WebhookParameter(name="authorization", required=True),
        ],
    )

    # Incoming header names deliberately differ in case from the declarations.
    incoming_headers = {
        "content-type": "application/json",
        "x-api-key": "key123",
        "Authorization": "Bearer token",
    }
    pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={
            "webhook_data": {
                "headers": incoming_headers,
                "query_params": {},
                "body": {},
                "files": {},
            }
        },
    )

    result = create_webhook_node(node_data, pool)._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    # Output keys follow the declared names, with "-" replaced by "_".
    assert result.outputs["Content_Type"] == "application/json"
    assert result.outputs["X_API_KEY"] == "key123"
    assert result.outputs["authorization"] == "Bearer token"
|
||||
|
||||
|
||||
def test_webhook_node_variable_pool_user_inputs():
    """All user_inputs from the variable pool should flow into the node inputs."""
    pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={
            "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}},
            "other_var": "should_be_included",
        },
    )
    # An unrelated node-scoped variable must not interfere with user inputs.
    pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value"))

    result = create_webhook_node(WebhookData(title="Test Webhook"), pool)._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    # Every user_input key should be present in the recorded node inputs.
    recorded_inputs = dict(result.inputs)
    assert "webhook_data" in recorded_inputs
    assert "other_var" in recorded_inputs
    assert recorded_inputs["other_var"] == "should_be_included"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "method",
    [Method.GET, Method.POST, Method.PUT, Method.DELETE, Method.PATCH, Method.HEAD],
)
def test_webhook_node_different_methods(method):
    """Webhook nodes should execute successfully for every supported HTTP method."""
    empty_payload = {"headers": {}, "query_params": {}, "body": {}, "files": {}}
    pool = VariablePool(
        system_variables=SystemVariable.empty(),
        user_inputs={"webhook_data": empty_payload},
    )

    node = create_webhook_node(WebhookData(title="Test Webhook", method=method), pool)
    result = node._run()

    assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
    # The configured method must round-trip unchanged onto the node data.
    assert node._node_data.method == method
|
||||
|
||||
|
||||
def test_webhook_data_content_type_field():
    """content_type should coerce raw strings and accept enum members as-is."""
    from_string = WebhookData(title="Test", content_type="application/json")
    assert from_string.content_type == ContentType.JSON

    from_enum = WebhookData(title="Test", content_type=ContentType.FORM_DATA)
    assert from_enum.content_type == ContentType.FORM_DATA
|
||||
|
||||
|
||||
def test_webhook_parameter_models():
    """Parameter models should validate names and apply documented defaults."""
    # Fully-specified header/query parameter.
    header_param = WebhookParameter(name="test_param", required=True)
    assert header_param.name == "test_param"
    assert header_param.required is True

    # `required` defaults to False when omitted.
    assert WebhookParameter(name="test_param").required is False

    # Fully-specified body parameter.
    body_param = WebhookBodyParameter(name="test_body", type="string", required=True)
    assert body_param.name == "test_body"
    assert body_param.type == "string"
    assert body_param.required is True

    # Body parameters default to an optional string field.
    defaulted = WebhookBodyParameter(name="test_body")
    assert defaulted.type == "string"
    assert defaulted.required is False
|
||||
|
||||
|
||||
def test_webhook_data_field_defaults():
    """A WebhookData built from only a title should expose the documented defaults."""
    minimal = WebhookData(title="Minimal Webhook")

    assert minimal.method == Method.GET
    assert minimal.content_type == ContentType.JSON
    # Collection-valued fields default to empty lists, not None.
    for list_field in ("headers", "params", "body"):
        assert getattr(minimal, list_field) == []
    assert minimal.status_code == 200
    assert minimal.response_body == ""
    assert minimal.webhook_id is None
    assert minimal.timeout == 30
|
||||
@@ -131,6 +131,12 @@ class TestCelerySSLConfiguration:
|
||||
mock_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK = False
|
||||
mock_config.ENABLE_DATASETS_QUEUE_MONITOR = False
|
||||
mock_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK = False
|
||||
mock_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK = False
|
||||
mock_config.WORKFLOW_SCHEDULE_POLLER_INTERVAL = 1
|
||||
mock_config.WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE = 100
|
||||
mock_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK = 0
|
||||
mock_config.ENABLE_TRIGGER_PROVIDER_REFRESH_TASK = False
|
||||
mock_config.TRIGGER_PROVIDER_REFRESH_INTERVAL = 15
|
||||
|
||||
with patch("extensions.ext_celery.dify_config", mock_config):
|
||||
from dify_app import DifyApp
|
||||
|
||||
381
api/tests/unit_tests/libs/test_cron_compatibility.py
Normal file
381
api/tests/unit_tests/libs/test_cron_compatibility.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Enhanced cron syntax compatibility tests for croniter backend.
|
||||
|
||||
This test suite mirrors the frontend cron-parser tests to ensure
|
||||
complete compatibility between frontend and backend cron processing.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
from croniter import CroniterBadCronError
|
||||
|
||||
from libs.schedule_utils import calculate_next_run_at
|
||||
|
||||
|
||||
class TestCronCompatibility(unittest.TestCase):
    """Test enhanced cron syntax compatibility with frontend.

    Mirrors the frontend cron-parser test matrix so the croniter-backed
    backend accepts the same expressions the UI accepts.
    """

    def setUp(self):
        """Set up test environment with fixed time."""
        # Monday, January 15, 2024, 10:00 UTC — stable anchor for all cases.
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_enhanced_dayofweek_syntax(self):
        """Test enhanced day-of-week syntax compatibility."""
        # (expression, expected weekday in cron numbering where Sunday == 0)
        test_cases = [
            ("0 9 * * 7", 0),  # Sunday as 7
            ("0 9 * * 0", 0),  # Sunday as 0
            ("0 9 * * MON", 1),  # Monday abbreviation
            ("0 9 * * TUE", 2),  # Tuesday abbreviation
            ("0 9 * * WED", 3),  # Wednesday abbreviation
            ("0 9 * * THU", 4),  # Thursday abbreviation
            ("0 9 * * FRI", 5),  # Friday abbreviation
            ("0 9 * * SAT", 6),  # Saturday abbreviation
            ("0 9 * * SUN", 0),  # Sunday abbreviation
        ]

        for expr, expected_weekday in test_cases:
            with self.subTest(expr=expr):
                next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                assert next_time is not None
                # Map Python weekday() (Mon=0..Sun=6) to cron numbering (Sun=0..Sat=6).
                assert (next_time.weekday() + 1 if next_time.weekday() < 6 else 0) == expected_weekday
                assert next_time.hour == 9
                assert next_time.minute == 0

    def test_enhanced_month_syntax(self):
        """Test enhanced month syntax compatibility."""
        # (expression, expected calendar month 1..12)
        test_cases = [
            ("0 9 1 JAN *", 1),  # January abbreviation
            ("0 9 1 FEB *", 2),  # February abbreviation
            ("0 9 1 MAR *", 3),  # March abbreviation
            ("0 9 1 APR *", 4),  # April abbreviation
            ("0 9 1 MAY *", 5),  # May abbreviation
            ("0 9 1 JUN *", 6),  # June abbreviation
            ("0 9 1 JUL *", 7),  # July abbreviation
            ("0 9 1 AUG *", 8),  # August abbreviation
            ("0 9 1 SEP *", 9),  # September abbreviation
            ("0 9 1 OCT *", 10),  # October abbreviation
            ("0 9 1 NOV *", 11),  # November abbreviation
            ("0 9 1 DEC *", 12),  # December abbreviation
        ]

        for expr, expected_month in test_cases:
            with self.subTest(expr=expr):
                next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                assert next_time is not None
                assert next_time.month == expected_month
                assert next_time.day == 1
                assert next_time.hour == 9

    def test_predefined_expressions(self):
        """Test predefined cron expressions compatibility."""
        # Each validator checks the structural property the alias guarantees.
        test_cases = [
            ("@yearly", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0),
            ("@annually", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0),
            ("@monthly", lambda dt: dt.day == 1 and dt.hour == 0),
            ("@weekly", lambda dt: dt.weekday() == 6 and dt.hour == 0),  # Sunday = 6 in weekday()
            ("@daily", lambda dt: dt.hour == 0 and dt.minute == 0),
            ("@midnight", lambda dt: dt.hour == 0 and dt.minute == 0),
            ("@hourly", lambda dt: dt.minute == 0),
        ]

        for expr, validator in test_cases:
            with self.subTest(expr=expr):
                next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                assert next_time is not None
                assert validator(next_time), f"Validator failed for {expr}: {next_time}"

    def test_special_characters(self):
        """Test special characters in cron expressions."""
        test_cases = [
            "0 9 ? * 1",  # ? wildcard
            "0 12 * * 7",  # Sunday as 7
            "0 15 L * *",  # Last day of month
        ]

        for expr in test_cases:
            with self.subTest(expr=expr):
                # Parsing failure means the special character is unsupported.
                try:
                    next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                    assert next_time is not None
                    assert next_time > self.base_time
                except Exception as e:
                    self.fail(f"Expression '{expr}' should be valid but raised: {e}")

    def test_range_and_list_syntax(self):
        """Test range and list syntax with abbreviations."""
        test_cases = [
            "0 9 * * MON-FRI",  # Weekday range with abbreviations
            "0 9 * JAN-MAR *",  # Month range with abbreviations
            "0 9 * * SUN,WED,FRI",  # Weekday list with abbreviations
            "0 9 1 JAN,JUN,DEC *",  # Month list with abbreviations
        ]

        for expr in test_cases:
            with self.subTest(expr=expr):
                try:
                    next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                    assert next_time is not None
                    assert next_time > self.base_time
                except Exception as e:
                    self.fail(f"Expression '{expr}' should be valid but raised: {e}")

    def test_invalid_enhanced_syntax(self):
        """Test that invalid enhanced syntax is properly rejected."""
        invalid_expressions = [
            "0 12 * JANUARY *",  # Full month name (not supported)
            "0 12 * * MONDAY",  # Full day name (not supported)
            "0 12 32 JAN *",  # Invalid day with valid month
            "15 10 1 * 8",  # Invalid day of week
            "15 10 1 INVALID *",  # Invalid month abbreviation
            "15 10 1 * INVALID",  # Invalid day abbreviation
            "@invalid",  # Invalid predefined expression
        ]

        for expr in invalid_expressions:
            with self.subTest(expr=expr):
                with pytest.raises((CroniterBadCronError, ValueError)):
                    calculate_next_run_at(expr, "UTC", self.base_time)

    def test_edge_cases_with_enhanced_syntax(self):
        """Test edge cases with enhanced syntax."""
        test_cases = [
            ("0 0 29 FEB *", lambda dt: dt.month == 2 and dt.day == 29),  # Feb 29 with month abbreviation
        ]

        for expr, validator in test_cases:
            with self.subTest(expr=expr):
                try:
                    next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                    if next_time:  # Some combinations might not occur soon
                        assert validator(next_time), f"Validator failed for {expr}: {next_time}"
                except (CroniterBadCronError, ValueError):
                    # Some edge cases might be valid but not have upcoming occurrences
                    pass

        # Test complex expressions that have specific constraints
        complex_expr = "59 23 31 DEC SAT"  # December 31st at 23:59 on Saturday
        try:
            next_time = calculate_next_run_at(complex_expr, "UTC", self.base_time)
            if next_time:
                # The next occurrence might not be exactly Dec 31 if it's not a Saturday
                # Just verify it's a valid result
                assert next_time is not None
                assert next_time.hour == 23
                assert next_time.minute == 59
        except Exception:
            # Complex date constraints might not have near-future occurrences
            pass
|
||||
|
||||
|
||||
class TestTimezoneCompatibility(unittest.TestCase):
    """Test timezone compatibility between frontend and backend."""

    def setUp(self):
        """Set up test environment with a fixed UTC anchor time."""
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_timezone_consistency(self):
        """Test that calculations are consistent across different timezones."""
        timezones = [
            "UTC",
            "America/New_York",
            "Europe/London",
            "Asia/Tokyo",
            "Asia/Kolkata",
            "Australia/Sydney",
        ]

        expression = "0 12 * * *"  # Daily at noon

        for timezone in timezones:
            with self.subTest(timezone=timezone):
                next_time = calculate_next_run_at(expression, timezone, self.base_time)
                assert next_time is not None

                # Convert back to the target timezone to verify it's noon
                tz = pytz.timezone(timezone)
                local_time = next_time.astimezone(tz)
                assert local_time.hour == 12
                assert local_time.minute == 0

    def test_dst_handling(self):
        """Test DST boundary handling."""
        # Test around DST spring forward (March 2024)
        dst_base = datetime(2024, 3, 8, 10, 0, 0, tzinfo=UTC)
        expression = "0 2 * * *"  # 2 AM daily (problematic during DST)
        timezone = "America/New_York"

        try:
            next_time = calculate_next_run_at(expression, timezone, dst_base)
            assert next_time is not None

            # During DST spring forward, 2 AM becomes 3 AM - both are acceptable
            tz = pytz.timezone(timezone)
            local_time = next_time.astimezone(tz)
            assert local_time.hour in [2, 3]  # Either 2 AM or 3 AM is acceptable
        except Exception as e:
            self.fail(f"DST handling failed: {e}")

    def test_half_hour_timezones(self):
        """Test timezones with half-hour offsets."""
        timezones_with_offsets = [
            ("Asia/Kolkata", 17, 30),  # UTC+5:30 -> 12:00 UTC = 17:30 IST
            ("Australia/Adelaide", 22, 30),  # UTC+10:30 -> 12:00 UTC = 22:30 ACDT (summer time)
        ]

        expression = "0 12 * * *"  # Noon UTC

        for timezone, expected_hour, expected_minute in timezones_with_offsets:
            with self.subTest(timezone=timezone):
                try:
                    next_time = calculate_next_run_at(expression, timezone, self.base_time)
                    assert next_time is not None

                    tz = pytz.timezone(timezone)
                    local_time = next_time.astimezone(tz)
                    assert local_time.hour == expected_hour
                    assert local_time.minute == expected_minute
                except Exception:
                    # Deliberate best-effort: some complex timezone calculations might vary
                    pass

    def test_invalid_timezone_handling(self):
        """Test handling of invalid timezones."""
        expression = "0 12 * * *"
        invalid_timezone = "Invalid/Timezone"

        # `(ValueError, Exception)` was redundant — ValueError is already a
        # subclass of Exception, so the union is equivalent to Exception alone.
        with pytest.raises(Exception):
            calculate_next_run_at(expression, invalid_timezone, self.base_time)
|
||||
|
||||
|
||||
class TestFrontendBackendIntegration(unittest.TestCase):
    """Test integration patterns that mirror frontend usage."""

    def setUp(self):
        """Set up test environment with a fixed UTC anchor time."""
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_execution_time_calculator_pattern(self):
        """Test the pattern used by execution-time-calculator.ts."""
        # This mirrors the exact usage from execution-time-calculator.ts:47
        test_data = {
            "cron_expression": "30 14 * * 1-5",  # 2:30 PM weekdays
            "timezone": "America/New_York",
        }

        # Get next 5 execution times (like the frontend does)
        execution_times = []
        current_base = self.base_time

        for _ in range(5):
            next_time = calculate_next_run_at(test_data["cron_expression"], test_data["timezone"], current_base)
            assert next_time is not None
            execution_times.append(next_time)
            current_base = next_time + timedelta(seconds=1)  # Move slightly forward

        assert len(execution_times) == 5

        # Validate each execution time
        for exec_time in execution_times:
            # Convert to local timezone
            tz = pytz.timezone(test_data["timezone"])
            local_time = exec_time.astimezone(tz)

            # Should be weekdays (1-5)
            assert local_time.weekday() in [0, 1, 2, 3, 4]  # Mon-Fri in Python weekday

            # Should be 2:30 PM in local time
            assert local_time.hour == 14
            assert local_time.minute == 30
            assert local_time.second == 0

    def test_schedule_service_integration(self):
        """Test integration with ScheduleService patterns."""
        from core.workflow.nodes.trigger_schedule.entities import VisualConfig
        from services.trigger.schedule_service import ScheduleService

        # Test enhanced syntax through visual config conversion
        visual_configs = [
            # Test with month abbreviations
            {
                "frequency": "monthly",
                "config": VisualConfig(time="9:00 AM", monthly_days=[1]),
                "expected_cron": "0 9 1 * *",
            },
            # Test with weekday abbreviations
            {
                "frequency": "weekly",
                "config": VisualConfig(time="2:30 PM", weekdays=["mon", "wed", "fri"]),
                "expected_cron": "30 14 * * 1,3,5",
            },
        ]

        for test_case in visual_configs:
            with self.subTest(frequency=test_case["frequency"]):
                cron_expr = ScheduleService.visual_to_cron(test_case["frequency"], test_case["config"])
                assert cron_expr == test_case["expected_cron"]

                # Verify the generated cron expression is valid
                next_time = calculate_next_run_at(cron_expr, "UTC", self.base_time)
                assert next_time is not None

    def test_error_handling_consistency(self):
        """Test that error handling matches frontend expectations."""
        invalid_expressions = [
            "60 10 1 * *",  # Invalid minute
            "15 25 1 * *",  # Invalid hour
            "15 10 32 * *",  # Invalid day
            "15 10 1 13 *",  # Invalid month
            "15 10 1",  # Too few fields
            "15 10 1 * * *",  # 6 fields (not supported in frontend)
            "0 15 10 1 * * *",  # 7 fields (not supported in frontend)
            "invalid expression",  # Completely invalid
        ]

        for expr in invalid_expressions:
            with self.subTest(expr=repr(expr)):
                # `(CroniterBadCronError, ValueError, Exception)` was redundant —
                # both named types are Exception subclasses, so the union is
                # equivalent to Exception alone.
                with pytest.raises(Exception):
                    calculate_next_run_at(expr, "UTC", self.base_time)

        # Note: Empty/whitespace expressions are not tested here as they are
        # not expected in normal usage due to database constraints (nullable=False)

    def test_performance_requirements(self):
        """Test that complex expressions parse within reasonable time."""
        import time

        complex_expressions = [
            "*/5 9-17 * * 1-5",  # Every 5 minutes, weekdays, business hours
            "0 */2 1,15 * *",  # Every 2 hours on 1st and 15th
            "30 14 * * 1,3,5",  # Mon, Wed, Fri at 14:30
            "15,45 8-18 * * 1-5",  # 15 and 45 minutes past hour, weekdays
            "0 9 * JAN-MAR MON-FRI",  # Enhanced syntax: Q1 weekdays at 9 AM
            "0 12 ? * SUN",  # Enhanced syntax: Sundays at noon with ?
        ]

        # perf_counter is monotonic and unaffected by wall-clock adjustments,
        # making it the right tool for measuring elapsed time (time.time is not).
        start_time = time.perf_counter()

        for expr in complex_expressions:
            with self.subTest(expr=expr):
                try:
                    next_time = calculate_next_run_at(expr, "UTC", self.base_time)
                    assert next_time is not None
                except CroniterBadCronError:
                    # Some enhanced syntax might not be supported, that's OK
                    pass

        end_time = time.perf_counter()
        execution_time = (end_time - start_time) * 1000  # Convert to milliseconds

        # Should complete within reasonable time (less than 150ms like frontend)
        assert execution_time < 150, "Complex expressions should parse quickly"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # `timedelta` is already imported at module level; the previous local
    # re-import here was redundant and has been removed.
    unittest.main()
|
||||
411
api/tests/unit_tests/libs/test_schedule_utils_enhanced.py
Normal file
411
api/tests/unit_tests/libs/test_schedule_utils_enhanced.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Enhanced schedule_utils tests for new cron syntax support.
|
||||
|
||||
These tests verify that the backend schedule_utils functions properly support
|
||||
the enhanced cron syntax introduced in the frontend, ensuring full compatibility.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
from croniter import CroniterBadCronError
|
||||
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
|
||||
|
||||
class TestEnhancedCronSyntax(unittest.TestCase):
    """Test enhanced cron syntax in calculate_next_run_at."""

    def setUp(self):
        """Set up test with fixed time."""
        # Monday, January 15, 2024, 10:00 AM UTC
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_month_abbreviations(self):
        """Test month abbreviations (JAN, FEB, etc.)."""
        # (expression, expected calendar month 1..12)
        test_cases = [
            ("0 12 1 JAN *", 1),  # January
            ("0 12 1 FEB *", 2),  # February
            ("0 12 1 MAR *", 3),  # March
            ("0 12 1 APR *", 4),  # April
            ("0 12 1 MAY *", 5),  # May
            ("0 12 1 JUN *", 6),  # June
            ("0 12 1 JUL *", 7),  # July
            ("0 12 1 AUG *", 8),  # August
            ("0 12 1 SEP *", 9),  # September
            ("0 12 1 OCT *", 10),  # October
            ("0 12 1 NOV *", 11),  # November
            ("0 12 1 DEC *", 12),  # December
        ]

        for expr, expected_month in test_cases:
            with self.subTest(expr=expr):
                result = calculate_next_run_at(expr, "UTC", self.base_time)
                assert result is not None, f"Failed to parse: {expr}"
                assert result.month == expected_month
                assert result.day == 1
                assert result.hour == 12
                assert result.minute == 0

    def test_weekday_abbreviations(self):
        """Test weekday abbreviations (SUN, MON, etc.)."""
        # (expression, expected Python weekday() value: Mon=0 .. Sun=6)
        test_cases = [
            ("0 9 * * SUN", 6),  # Sunday (weekday() = 6)
            ("0 9 * * MON", 0),  # Monday (weekday() = 0)
            ("0 9 * * TUE", 1),  # Tuesday
            ("0 9 * * WED", 2),  # Wednesday
            ("0 9 * * THU", 3),  # Thursday
            ("0 9 * * FRI", 4),  # Friday
            ("0 9 * * SAT", 5),  # Saturday
        ]

        for expr, expected_weekday in test_cases:
            with self.subTest(expr=expr):
                result = calculate_next_run_at(expr, "UTC", self.base_time)
                assert result is not None, f"Failed to parse: {expr}"
                assert result.weekday() == expected_weekday
                assert result.hour == 9
                assert result.minute == 0

    def test_sunday_dual_representation(self):
        """Test Sunday as both 0 and 7."""
        base_time = datetime(2024, 1, 14, 10, 0, 0, tzinfo=UTC)  # Sunday

        # Both should give the same next Sunday
        result_0 = calculate_next_run_at("0 10 * * 0", "UTC", base_time)
        result_7 = calculate_next_run_at("0 10 * * 7", "UTC", base_time)
        result_SUN = calculate_next_run_at("0 10 * * SUN", "UTC", base_time)

        assert result_0 is not None
        assert result_7 is not None
        assert result_SUN is not None

        # All should be Sundays
        assert result_0.weekday() == 6  # Sunday = 6 in weekday()
        assert result_7.weekday() == 6
        assert result_SUN.weekday() == 6

        # Times should be identical
        assert result_0 == result_7
        assert result_0 == result_SUN

    def test_predefined_expressions(self):
        """Test predefined expressions (@daily, @weekly, etc.)."""
        # Each validator checks the structural property the alias guarantees.
        test_cases = [
            ("@yearly", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0),
            ("@annually", lambda dt: dt.month == 1 and dt.day == 1 and dt.hour == 0 and dt.minute == 0),
            ("@monthly", lambda dt: dt.day == 1 and dt.hour == 0 and dt.minute == 0),
            ("@weekly", lambda dt: dt.weekday() == 6 and dt.hour == 0 and dt.minute == 0),  # Sunday
            ("@daily", lambda dt: dt.hour == 0 and dt.minute == 0),
            ("@midnight", lambda dt: dt.hour == 0 and dt.minute == 0),
            ("@hourly", lambda dt: dt.minute == 0),
        ]

        for expr, validator in test_cases:
            with self.subTest(expr=expr):
                result = calculate_next_run_at(expr, "UTC", self.base_time)
                assert result is not None, f"Failed to parse: {expr}"
                assert validator(result), f"Validator failed for {expr}: {result}"

    def test_question_mark_wildcard(self):
        """Test ? wildcard character."""
        # ? in day position with specific weekday
        result_question = calculate_next_run_at("0 9 ? * 1", "UTC", self.base_time)  # Monday
        result_star = calculate_next_run_at("0 9 * * 1", "UTC", self.base_time)  # Monday

        assert result_question is not None
        assert result_star is not None

        # Both should return Mondays at 9:00
        assert result_question.weekday() == 0  # Monday
        assert result_star.weekday() == 0
        assert result_question.hour == 9
        assert result_star.hour == 9

        # ? must behave exactly like * in the day-of-month position here.
        assert result_question == result_star

    def test_last_day_of_month(self):
        """Test 'L' for last day of month."""
        expr = "0 12 L * *"  # Last day of month at noon

        # February 2024 is a leap year, so 'L' should resolve to the 29th.
        feb_base = datetime(2024, 2, 15, 10, 0, 0, tzinfo=UTC)
        result = calculate_next_run_at(expr, "UTC", feb_base)
        assert result is not None
        assert result.month == 2
        assert result.day == 29  # 2024 is a leap year
        assert result.hour == 12

    def test_range_with_abbreviations(self):
        """Test ranges using abbreviations."""
        test_cases = [
            "0 9 * * MON-FRI",  # Weekday range
            "0 12 * JAN-MAR *",  # Q1 months
            "0 15 * APR-JUN *",  # Q2 months
        ]

        for expr in test_cases:
            with self.subTest(expr=expr):
                result = calculate_next_run_at(expr, "UTC", self.base_time)
                assert result is not None, f"Failed to parse range expression: {expr}"
                assert result > self.base_time

    def test_list_with_abbreviations(self):
        """Test lists using abbreviations."""
        # (expression, acceptable weekday() values or month numbers)
        test_cases = [
            ("0 9 * * SUN,WED,FRI", [6, 2, 4]),  # Specific weekdays
            ("0 12 1 JAN,JUN,DEC *", [1, 6, 12]),  # Specific months
        ]

        for expr, expected_values in test_cases:
            with self.subTest(expr=expr):
                result = calculate_next_run_at(expr, "UTC", self.base_time)
                assert result is not None, f"Failed to parse list expression: {expr}"

                if "* *" in expr:  # Weekday test
                    assert result.weekday() in expected_values
                else:  # Month test
                    assert result.month in expected_values

    def test_mixed_syntax(self):
        """Test mixed traditional and enhanced syntax."""
        test_cases = [
            "30 14 15 JAN,JUN,DEC *",  # Numbers + month abbreviations
            "0 9 * JAN-MAR MON-FRI",  # Month range + weekday range
            "45 8 1,15 * MON",  # Numbers + weekday abbreviation
        ]

        for expr in test_cases:
            with self.subTest(expr=expr):
                result = calculate_next_run_at(expr, "UTC", self.base_time)
                assert result is not None, f"Failed to parse mixed syntax: {expr}"
                assert result > self.base_time

    def test_complex_enhanced_expressions(self):
        """Test complex expressions with multiple enhanced features."""
        # Note: Some of these might not be supported by croniter, that's OK
        complex_expressions = [
            "0 9 L JAN *",  # Last day of January
            "30 14 * * FRI#1",  # First Friday of month (if supported)
            "0 12 15 JAN-DEC/3 *",  # 15th of every 3rd month (quarterly)
        ]

        for expr in complex_expressions:
            with self.subTest(expr=expr):
                try:
                    result = calculate_next_run_at(expr, "UTC", self.base_time)
                    if result:  # If supported, should return valid result
                        assert result > self.base_time
                except Exception:
                    # Some complex expressions might not be supported - that's acceptable
                    pass
|
||||
|
||||
|
||||
class TestTimezoneHandlingEnhanced(unittest.TestCase):
    """Exercise enhanced cron syntax across multiple IANA timezones, including a DST boundary."""

    def setUp(self):
        """Pin a fixed, timezone-aware reference instant so results are deterministic."""
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_enhanced_syntax_with_timezones(self):
        """A 'Monday at noon' schedule must land on Monday 12:00 *local* time in every zone."""
        expression = "0 12 * * MON"  # Monday at noon
        for tz_name in ("UTC", "America/New_York", "Asia/Tokyo", "Europe/London"):
            with self.subTest(timezone=tz_name):
                next_run = calculate_next_run_at(expression, tz_name, self.base_time)
                assert next_run is not None

                # Interpret the result in the schedule's own zone before checking wall-clock fields.
                localized = next_run.astimezone(pytz.timezone(tz_name))
                assert localized.weekday() == 0  # Monday
                assert localized.hour == 12
                assert localized.minute == 0

    def test_predefined_expressions_with_timezones(self):
        """@daily must resolve to local midnight, not UTC midnight, per zone."""
        for tz_name in ("UTC", "America/New_York", "Asia/Tokyo"):
            with self.subTest(timezone=tz_name):
                next_run = calculate_next_run_at("@daily", tz_name, self.base_time)
                assert next_run is not None

                localized = next_run.astimezone(pytz.timezone(tz_name))
                assert localized.hour == 0
                assert localized.minute == 0

    def test_dst_with_enhanced_syntax(self):
        """A 2 AM Sunday schedule across the US spring-forward gap must not crash or skip Sunday."""
        # DST spring forward date in 2024
        dst_base = datetime(2024, 3, 8, 10, 0, 0, tzinfo=UTC)
        tz_name = "America/New_York"

        next_run = calculate_next_run_at("0 2 * * SUN", tz_name, dst_base)
        assert next_run is not None

        localized = next_run.astimezone(pytz.timezone(tz_name))
        assert localized.weekday() == 6  # Sunday
        # During DST spring forward, 2 AM might become 3 AM.
        assert localized.hour in (2, 3)
class TestErrorHandlingEnhanced(unittest.TestCase):
    """Validate that malformed enhanced syntax is rejected and valid boundary dates are accepted."""

    def setUp(self):
        """Pin a fixed reference instant for deterministic calculations."""
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_invalid_enhanced_syntax(self):
        """Each malformed expression must raise rather than silently produce a schedule."""
        malformed = (
            "0 12 * JANUARY *",  # Full month name
            "0 12 * * MONDAY",  # Full day name
            "0 12 32 JAN *",  # Invalid day with valid month
            "0 12 * * MON-SUN-FRI",  # Invalid range syntax
            "0 12 * JAN- *",  # Incomplete range
            "0 12 * * ,MON",  # Invalid list syntax
            "@INVALID",  # Invalid predefined
        )
        for expr in malformed:
            with self.subTest(expr=expr):
                with pytest.raises((CroniterBadCronError, ValueError)):
                    calculate_next_run_at(expr, "UTC", self.base_time)

    def test_boundary_values_with_enhanced_syntax(self):
        """Extreme-but-legal calendar boundaries must not raise."""
        boundary_exprs = (
            "0 0 1 JAN *",  # Minimum: January 1st midnight
            "59 23 31 DEC *",  # Maximum: December 31st 23:59
            "0 12 29 FEB *",  # Leap year boundary
        )
        for expr in boundary_exprs:
            with self.subTest(expr=expr):
                try:
                    next_run = calculate_next_run_at(expr, "UTC", self.base_time)
                except Exception as e:
                    # Some boundary cases might be complex to calculate
                    self.fail(f"Valid boundary expression failed: {expr} - {e}")
                else:
                    if next_run:  # Some dates might not occur soon
                        assert next_run > self.base_time
class TestPerformanceEnhanced(unittest.TestCase):
    """Smoke-check that schedule calculation stays fast with enhanced syntax.

    NOTE(review): wall-clock thresholds (100ms / 5ms) can be flaky on loaded CI
    machines — kept as-is to preserve existing behavior.
    """

    def setUp(self):
        """Pin a fixed reference instant for deterministic calculations."""
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_complex_expression_performance(self):
        """All enhanced expressions together must parse well under the latency budget."""
        import time

        candidates = (
            "*/5 9-17 * * MON-FRI",  # Every 5 min, weekdays, business hours
            "0 9 * JAN-MAR MON-FRI",  # Q1 weekdays at 9 AM
            "30 14 1,15 * * ",  # 1st and 15th at 14:30
            "0 12 ? * SUN",  # Sundays at noon with ?
            "@daily",  # Predefined expression
        )

        started = time.time()
        for expr in candidates:
            with self.subTest(expr=expr):
                try:
                    assert calculate_next_run_at(expr, "UTC", self.base_time) is not None
                except Exception:
                    # Some expressions might not be supported - acceptable
                    pass
        elapsed_ms = (time.time() - started) * 1000  # milliseconds

        # Should be fast (less than 100ms for all expressions)
        assert elapsed_ms < 100, "Enhanced expressions should parse quickly"

    def test_multiple_calculations_performance(self):
        """Chained next-run calculations must average out to a few ms each at most."""
        import time

        expression = "0 9 * * MON-FRI"  # Weekdays at 9 AM
        iterations = 20

        started = time.time()
        cursor = self.base_time
        for _ in range(iterations):
            next_run = calculate_next_run_at(expression, "UTC", cursor)
            assert next_run is not None
            cursor = next_run + timedelta(seconds=1)  # Move forward slightly
        avg_time = (time.time() - started) * 1000 / iterations

        # Average should be very fast (less than 5ms per calculation)
        assert avg_time < 5, f"Average calculation time too slow: {avg_time}ms"
class TestRegressionEnhanced(unittest.TestCase):
    """Guard that adding enhanced syntax did not regress classic cron or 12h parsing."""

    def setUp(self):
        """Pin a fixed reference instant for deterministic calculations."""
        self.base_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

    def test_traditional_syntax_still_works(self):
        """Plain five-field numeric cron must keep parsing and advancing."""
        legacy_exprs = (
            "15 10 1 * *",  # Monthly 1st at 10:15
            "0 0 * * 0",  # Weekly Sunday midnight
            "*/5 * * * *",  # Every 5 minutes
            "0 9-17 * * 1-5",  # Business hours weekdays
            "30 14 * * 1",  # Monday 14:30
            "0 0 1,15 * *",  # 1st and 15th midnight
        )
        for expr in legacy_exprs:
            with self.subTest(expr=expr):
                next_run = calculate_next_run_at(expr, "UTC", self.base_time)
                assert next_run is not None, f"Traditional expression failed: {expr}"
                assert next_run > self.base_time

    def test_convert_12h_to_24h_unchanged(self):
        """Spot-check the 12h→24h converter across AM/PM and the noon/midnight edge cases."""
        expectations = (
            ("12:00 AM", (0, 0)),  # Midnight
            ("12:00 PM", (12, 0)),  # Noon
            ("1:30 AM", (1, 30)),  # Early morning
            ("11:45 PM", (23, 45)),  # Late evening
            ("6:15 AM", (6, 15)),  # Morning
            ("3:30 PM", (15, 30)),  # Afternoon
        )
        for time_str, expected in expectations:
            with self.subTest(time_str=time_str):
                assert convert_12h_to_24h(time_str) == expected, f"12h conversion failed: {time_str}"
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
|
||||
22
api/tests/unit_tests/models/test_plugin_entities.py
Normal file
22
api/tests/unit_tests/models/test_plugin_entities.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import binascii
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.plugin.entities.request import TriggerDispatchResponse
|
||||
|
||||
|
||||
def test_trigger_dispatch_response():
    """A hex-encoded raw HTTP payload must round-trip into a parsed Flask-style response."""
    wire_bytes = b'HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n\r\n{"message": "Hello, world!"}'

    payload: Mapping[str, Any] = {
        "user_id": "123",
        "events": ["event1", "event2"],
        # The model expects the raw HTTP exchange hex-encoded as text.
        "response": binascii.hexlify(wire_bytes).decode(),
        "payload": {"key": "value"},
    }

    dispatch = TriggerDispatchResponse(**payload)

    # Status line, headers, and body should all survive the decode.
    assert dispatch.response.status_code == 200
    assert dispatch.response.headers["Content-Type"] == "application/json"
    assert dispatch.response.get_data(as_text=True) == '{"message": "Hello, world!"}'
||||
779
api/tests/unit_tests/services/test_schedule_service.py
Normal file
779
api/tests/unit_tests/services/test_schedule_service.py
Normal file
@@ -0,0 +1,779 @@
|
||||
import unittest
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig
|
||||
from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError
|
||||
from events.event_handlers.sync_workflow_schedule_when_app_published import (
|
||||
sync_schedule_from_workflow,
|
||||
)
|
||||
from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.trigger import WorkflowSchedulePlan
|
||||
from models.workflow import Workflow
|
||||
from services.trigger.schedule_service import ScheduleService
|
||||
|
||||
|
||||
class TestScheduleService(unittest.TestCase):
    """Test cases for ScheduleService class."""

    def test_calculate_next_run_at_valid_cron(self):
        """Test calculating next run time with valid cron expression."""
        # Test daily cron at 10:30 AM
        cron_expr = "30 10 * * *"
        timezone = "UTC"
        base_time = datetime(2025, 8, 29, 9, 0, 0, tzinfo=UTC)

        next_run = calculate_next_run_at(cron_expr, timezone, base_time)

        assert next_run is not None
        assert next_run.hour == 10
        assert next_run.minute == 30
        # Same day, since 10:30 is still ahead of the 09:00 base time.
        assert next_run.day == 29

    def test_calculate_next_run_at_with_timezone(self):
        """Test calculating next run time with different timezone."""
        cron_expr = "0 9 * * *"  # 9:00 AM
        timezone = "America/New_York"
        base_time = datetime(2025, 8, 29, 12, 0, 0, tzinfo=UTC)  # 8:00 AM EDT

        next_run = calculate_next_run_at(cron_expr, timezone, base_time)

        assert next_run is not None
        # 9:00 AM EDT = 13:00 UTC (during EDT)
        expected_utc_hour = 13
        assert next_run.hour == expected_utc_hour

    def test_calculate_next_run_at_with_last_day_of_month(self):
        """Test calculating next run time with 'L' (last day) syntax."""
        cron_expr = "0 10 L * *"  # 10:00 AM on last day of month
        timezone = "UTC"
        base_time = datetime(2025, 2, 15, 9, 0, 0, tzinfo=UTC)

        next_run = calculate_next_run_at(cron_expr, timezone, base_time)

        assert next_run is not None
        # February 2025 has 28 days
        assert next_run.day == 28
        assert next_run.month == 2

    def test_calculate_next_run_at_invalid_cron(self):
        """Test calculating next run time with invalid cron expression."""
        cron_expr = "invalid cron"
        timezone = "UTC"

        with pytest.raises(ValueError):
            calculate_next_run_at(cron_expr, timezone)

    def test_calculate_next_run_at_invalid_timezone(self):
        """Test calculating next run time with invalid timezone."""
        from pytz import UnknownTimeZoneError

        cron_expr = "30 10 * * *"
        timezone = "Invalid/Timezone"

        with pytest.raises(UnknownTimeZoneError):
            calculate_next_run_at(cron_expr, timezone)

    # FIX: patch the name where it is *used* (the service module), not where it is
    # defined (libs.schedule_utils). Patching the definition site leaves the service's
    # already-imported reference untouched, so the mock was never exercised. This also
    # matches the patch target used by every other test in this class.
    @patch("services.trigger.schedule_service.calculate_next_run_at")
    def test_create_schedule(self, mock_calculate_next_run):
        """Test creating a new schedule."""
        mock_session = MagicMock(spec=Session)
        mock_calculate_next_run.return_value = datetime(2025, 8, 30, 10, 30, 0, tzinfo=UTC)

        config = ScheduleConfig(
            node_id="start",
            cron_expression="30 10 * * *",
            timezone="UTC",
        )

        schedule = ScheduleService.create_schedule(
            session=mock_session,
            tenant_id="test-tenant",
            app_id="test-app",
            config=config,
        )

        assert schedule is not None
        assert schedule.tenant_id == "test-tenant"
        assert schedule.app_id == "test-app"
        assert schedule.node_id == "start"
        assert schedule.cron_expression == "30 10 * * *"
        assert schedule.timezone == "UTC"
        assert schedule.next_run_at is not None
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    @patch("services.trigger.schedule_service.calculate_next_run_at")
    def test_update_schedule(self, mock_calculate_next_run):
        """Test updating an existing schedule."""
        mock_session = MagicMock(spec=Session)
        mock_schedule = Mock(spec=WorkflowSchedulePlan)
        mock_schedule.cron_expression = "0 12 * * *"
        mock_schedule.timezone = "America/New_York"
        mock_session.get.return_value = mock_schedule
        mock_calculate_next_run.return_value = datetime(2025, 8, 30, 12, 0, 0, tzinfo=UTC)

        updates = SchedulePlanUpdate(
            cron_expression="0 12 * * *",
            timezone="America/New_York",
        )

        result = ScheduleService.update_schedule(
            session=mock_session,
            schedule_id="test-schedule-id",
            updates=updates,
        )

        assert result is not None
        assert result.cron_expression == "0 12 * * *"
        assert result.timezone == "America/New_York"
        mock_calculate_next_run.assert_called_once()
        mock_session.flush.assert_called_once()

    def test_update_schedule_not_found(self):
        """Test updating a non-existent schedule raises exception."""
        from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError

        mock_session = MagicMock(spec=Session)
        mock_session.get.return_value = None

        updates = SchedulePlanUpdate(
            cron_expression="0 12 * * *",
        )

        with pytest.raises(ScheduleNotFoundError) as context:
            ScheduleService.update_schedule(
                session=mock_session,
                schedule_id="non-existent-id",
                updates=updates,
            )

        assert "Schedule not found: non-existent-id" in str(context.value)
        mock_session.flush.assert_not_called()

    def test_delete_schedule(self):
        """Test deleting a schedule."""
        mock_session = MagicMock(spec=Session)
        mock_schedule = Mock(spec=WorkflowSchedulePlan)
        mock_session.get.return_value = mock_schedule

        # Should not raise exception and complete successfully
        ScheduleService.delete_schedule(
            session=mock_session,
            schedule_id="test-schedule-id",
        )

        mock_session.delete.assert_called_once_with(mock_schedule)
        mock_session.flush.assert_called_once()

    def test_delete_schedule_not_found(self):
        """Test deleting a non-existent schedule raises exception."""
        from core.workflow.nodes.trigger_schedule.exc import ScheduleNotFoundError

        mock_session = MagicMock(spec=Session)
        mock_session.get.return_value = None

        # Should raise ScheduleNotFoundError
        with pytest.raises(ScheduleNotFoundError) as context:
            ScheduleService.delete_schedule(
                session=mock_session,
                schedule_id="non-existent-id",
            )

        assert "Schedule not found: non-existent-id" in str(context.value)
        mock_session.delete.assert_not_called()

    @patch("services.trigger.schedule_service.select")
    def test_get_tenant_owner(self, mock_select):
        """Test getting tenant owner account."""
        mock_session = MagicMock(spec=Session)
        mock_account = Mock(spec=Account)
        mock_account.id = "owner-account-id"

        # Mock owner query
        mock_owner_result = Mock(spec=TenantAccountJoin)
        mock_owner_result.account_id = "owner-account-id"

        mock_session.execute.return_value.scalar_one_or_none.return_value = mock_owner_result
        mock_session.get.return_value = mock_account

        result = ScheduleService.get_tenant_owner(
            session=mock_session,
            tenant_id="test-tenant",
        )

        assert result is not None
        assert result.id == "owner-account-id"

    @patch("services.trigger.schedule_service.select")
    def test_get_tenant_owner_fallback_to_admin(self, mock_select):
        """Test getting tenant owner falls back to admin if no owner."""
        mock_session = MagicMock(spec=Session)
        mock_account = Mock(spec=Account)
        mock_account.id = "admin-account-id"

        # Mock admin query (owner returns None): first lookup misses, second hits.
        mock_admin_result = Mock(spec=TenantAccountJoin)
        mock_admin_result.account_id = "admin-account-id"

        mock_session.execute.return_value.scalar_one_or_none.side_effect = [None, mock_admin_result]
        mock_session.get.return_value = mock_account

        result = ScheduleService.get_tenant_owner(
            session=mock_session,
            tenant_id="test-tenant",
        )

        assert result is not None
        assert result.id == "admin-account-id"

    @patch("services.trigger.schedule_service.calculate_next_run_at")
    def test_update_next_run_at(self, mock_calculate_next_run):
        """Test updating next run time after schedule triggered."""
        mock_session = MagicMock(spec=Session)
        mock_schedule = Mock(spec=WorkflowSchedulePlan)
        mock_schedule.cron_expression = "30 10 * * *"
        mock_schedule.timezone = "UTC"
        mock_session.get.return_value = mock_schedule

        next_time = datetime(2025, 8, 31, 10, 30, 0, tzinfo=UTC)
        mock_calculate_next_run.return_value = next_time

        result = ScheduleService.update_next_run_at(
            session=mock_session,
            schedule_id="test-schedule-id",
        )

        assert result == next_time
        assert mock_schedule.next_run_at == next_time
        mock_session.flush.assert_called_once()
|
||||
class TestVisualToCron(unittest.TestCase):
    """Cover the visual-config → cron-expression conversion for every frequency and edge case."""

    def test_visual_to_cron_hourly(self):
        """Hourly config fires at the configured minute of every hour."""
        cron = ScheduleService.visual_to_cron("hourly", VisualConfig(on_minute=15))
        assert cron == "15 * * * *"

    def test_visual_to_cron_daily(self):
        """Daily config converts its 12h time into a 24h cron."""
        cron = ScheduleService.visual_to_cron("daily", VisualConfig(time="2:30 PM"))
        assert cron == "30 14 * * *"

    def test_visual_to_cron_weekly(self):
        """Weekly config maps weekday names onto cron day-of-week numbers."""
        cfg = VisualConfig(
            time="10:00 AM",
            weekdays=["mon", "wed", "fri"],
        )
        assert ScheduleService.visual_to_cron("weekly", cfg) == "0 10 * * 1,3,5"

    def test_visual_to_cron_monthly_with_specific_days(self):
        """Monthly config lists its days in the day-of-month field."""
        cfg = VisualConfig(
            time="11:30 AM",
            monthly_days=[1, 15],
        )
        assert ScheduleService.visual_to_cron("monthly", cfg) == "30 11 1,15 * *"

    def test_visual_to_cron_monthly_with_last_day(self):
        """'last' becomes the cron 'L' marker alongside numeric days."""
        cfg = VisualConfig(
            time="11:30 AM",
            monthly_days=[1, "last"],
        )
        assert ScheduleService.visual_to_cron("monthly", cfg) == "30 11 1,L * *"

    def test_visual_to_cron_monthly_only_last_day(self):
        """A lone 'last' yields just the 'L' marker."""
        cfg = VisualConfig(
            time="9:00 PM",
            monthly_days=["last"],
        )
        assert ScheduleService.visual_to_cron("monthly", cfg) == "0 21 L * *"

    def test_visual_to_cron_monthly_with_end_days_and_last(self):
        """End-of-month numbers coexist with 'L'; L covers whichever last day applies."""
        cfg = VisualConfig(
            time="3:45 PM",
            monthly_days=[29, 30, 31, "last"],
        )
        # Should have 29,30,31,L - the L handles all possible last days
        assert ScheduleService.visual_to_cron("monthly", cfg) == "45 15 29,30,31,L * *"

    def test_visual_to_cron_invalid_frequency(self):
        """Unknown frequency strings are rejected."""
        with pytest.raises(ScheduleConfigError, match="Unsupported frequency: invalid"):
            ScheduleService.visual_to_cron("invalid", VisualConfig())

    def test_visual_to_cron_weekly_no_weekdays(self):
        """Weekly without weekdays is a configuration error."""
        with pytest.raises(ScheduleConfigError, match="Weekdays are required for weekly schedules"):
            ScheduleService.visual_to_cron("weekly", VisualConfig(time="10:00 AM"))

    def test_visual_to_cron_hourly_no_minute(self):
        """Hourly falls back to the on_minute default of 0."""
        assert ScheduleService.visual_to_cron("hourly", VisualConfig()) == "0 * * * *"  # Should use default value 0

    def test_visual_to_cron_daily_no_time(self):
        """Daily without a time is a configuration error."""
        with pytest.raises(ScheduleConfigError, match="time is required for daily schedules"):
            ScheduleService.visual_to_cron("daily", VisualConfig(time=None))

    def test_visual_to_cron_weekly_no_time(self):
        """Weekly without a time is a configuration error."""
        cfg = VisualConfig(weekdays=["mon"])
        cfg.time = None  # Override default
        with pytest.raises(ScheduleConfigError, match="time is required for weekly schedules"):
            ScheduleService.visual_to_cron("weekly", cfg)

    def test_visual_to_cron_monthly_no_time(self):
        """Monthly without a time is a configuration error."""
        cfg = VisualConfig(monthly_days=[1])
        cfg.time = None  # Override default
        with pytest.raises(ScheduleConfigError, match="time is required for monthly schedules"):
            ScheduleService.visual_to_cron("monthly", cfg)

    def test_visual_to_cron_monthly_duplicate_days(self):
        """Repeated day numbers collapse to a single occurrence."""
        cfg = VisualConfig(
            time="10:00 AM",
            monthly_days=[1, 15, 1, 15, 31],  # Duplicates
        )
        assert ScheduleService.visual_to_cron("monthly", cfg) == "0 10 1,15,31 * *"  # Should be deduplicated

    def test_visual_to_cron_monthly_unsorted_days(self):
        """Out-of-order day numbers come out ascending."""
        cfg = VisualConfig(
            time="2:30 PM",
            monthly_days=[20, 5, 15, 1, 10],  # Unsorted
        )
        assert ScheduleService.visual_to_cron("monthly", cfg) == "30 14 1,5,10,15,20 * *"  # Should be sorted

    def test_visual_to_cron_weekly_all_weekdays(self):
        """All seven weekday names map to 0-6."""
        cfg = VisualConfig(
            time="8:00 AM",
            weekdays=["sun", "mon", "tue", "wed", "thu", "fri", "sat"],
        )
        assert ScheduleService.visual_to_cron("weekly", cfg) == "0 8 * * 0,1,2,3,4,5,6"

    def test_visual_to_cron_hourly_boundary_values(self):
        """Minutes 0 and 59 (the field extremes) are both accepted."""
        assert ScheduleService.visual_to_cron("hourly", VisualConfig(on_minute=0)) == "0 * * * *"
        assert ScheduleService.visual_to_cron("hourly", VisualConfig(on_minute=59)) == "59 * * * *"

    def test_visual_to_cron_daily_midnight_noon(self):
        """12 AM and 12 PM convert to hours 0 and 12 respectively."""
        assert ScheduleService.visual_to_cron("daily", VisualConfig(time="12:00 AM")) == "0 0 * * *"
        assert ScheduleService.visual_to_cron("daily", VisualConfig(time="12:00 PM")) == "0 12 * * *"

    def test_visual_to_cron_monthly_mixed_with_last_and_duplicates(self):
        """Mixed numeric days, repeated entries, and repeated 'last' normalize cleanly."""
        cfg = VisualConfig(
            time="11:45 PM",
            monthly_days=[15, 1, "last", 15, 30, 1, "last"],  # Mixed with duplicates
        )
        # Deduplicated and sorted with L at end
        assert ScheduleService.visual_to_cron("monthly", cfg) == "45 23 1,15,30,L * *"

    def test_visual_to_cron_weekly_single_day(self):
        """A single weekday produces a single day-of-week entry."""
        cfg = VisualConfig(
            time="6:30 PM",
            weekdays=["sun"],
        )
        assert ScheduleService.visual_to_cron("weekly", cfg) == "30 18 * * 0"

    def test_visual_to_cron_monthly_all_possible_days(self):
        """All 31 numeric days plus 'last' enumerate fully with L appended."""
        cfg = VisualConfig(
            time="12:01 AM",
            monthly_days=list(range(1, 32)) + ["last"],
        )
        expected_days = ",".join(str(d) for d in range(1, 32)) + ",L"
        assert ScheduleService.visual_to_cron("monthly", cfg) == f"1 0 {expected_days} * *"

    def test_visual_to_cron_monthly_no_days(self):
        """Monthly with an empty day list is a configuration error."""
        cfg = VisualConfig(time="10:00 AM", monthly_days=[])
        with pytest.raises(ScheduleConfigError, match="Monthly days are required for monthly schedules"):
            ScheduleService.visual_to_cron("monthly", cfg)

    def test_visual_to_cron_weekly_empty_weekdays_list(self):
        """Weekly with an empty weekday list is a configuration error."""
        cfg = VisualConfig(time="10:00 AM", weekdays=[])
        with pytest.raises(ScheduleConfigError, match="Weekdays are required for weekly schedules"):
            ScheduleService.visual_to_cron("weekly", cfg)
|
||||
class TestParseTime(unittest.TestCase):
    """Cover convert_12h_to_24h happy paths and every rejection branch."""

    def test_parse_time_am(self):
        """A plain AM time keeps its hour."""
        assert convert_12h_to_24h("9:30 AM") == (9, 30)

    def test_parse_time_pm(self):
        """A PM time is shifted by 12 hours."""
        assert convert_12h_to_24h("2:45 PM") == (14, 45)

    def test_parse_time_noon(self):
        """12 PM is noon, not zero."""
        assert convert_12h_to_24h("12:00 PM") == (12, 0)

    def test_parse_time_midnight(self):
        """12 AM wraps to hour 0."""
        assert convert_12h_to_24h("12:00 AM") == (0, 0)

    def test_parse_time_invalid_format(self):
        """A string without an AM/PM suffix is rejected as malformed."""
        with pytest.raises(ValueError, match="Invalid time format"):
            convert_12h_to_24h("25:00")

    def test_parse_time_invalid_hour(self):
        """Hours above 12 are invalid in 12h notation."""
        with pytest.raises(ValueError, match="Invalid hour: 13"):
            convert_12h_to_24h("13:00 PM")

    def test_parse_time_invalid_minute(self):
        """Minute 60 is out of range."""
        with pytest.raises(ValueError, match="Invalid minute: 60"):
            convert_12h_to_24h("10:60 AM")

    def test_parse_time_empty_string(self):
        """Empty input is rejected explicitly."""
        with pytest.raises(ValueError, match="Time string cannot be empty"):
            convert_12h_to_24h("")

    def test_parse_time_invalid_period(self):
        """Only AM/PM are accepted as period markers."""
        with pytest.raises(ValueError, match="Invalid period"):
            convert_12h_to_24h("10:30 XM")
|
||||
class TestExtractScheduleConfig(unittest.TestCase):
    """Cover extraction of schedule configuration from a workflow graph."""

    def test_extract_schedule_config_with_cron_mode(self):
        """A trigger-schedule node in cron mode yields its expression verbatim."""
        workflow = Mock(spec=Workflow)
        workflow.graph_dict = {
            "nodes": [
                {
                    "id": "schedule-node",
                    "data": {
                        "type": "trigger-schedule",
                        "mode": "cron",
                        "cron_expression": "0 10 * * *",
                        "timezone": "America/New_York",
                    },
                }
            ]
        }

        extracted = ScheduleService.extract_schedule_config(workflow)

        assert extracted is not None
        assert extracted.node_id == "schedule-node"
        assert extracted.cron_expression == "0 10 * * *"
        assert extracted.timezone == "America/New_York"

    def test_extract_schedule_config_with_visual_mode(self):
        """A trigger-schedule node in visual mode is converted to the equivalent cron."""
        workflow = Mock(spec=Workflow)
        workflow.graph_dict = {
            "nodes": [
                {
                    "id": "schedule-node",
                    "data": {
                        "type": "trigger-schedule",
                        "mode": "visual",
                        "frequency": "daily",
                        "visual_config": {"time": "10:30 AM"},
                        "timezone": "UTC",
                    },
                }
            ]
        }

        extracted = ScheduleService.extract_schedule_config(workflow)

        assert extracted is not None
        assert extracted.node_id == "schedule-node"
        # 10:30 AM daily → minute 30, hour 10.
        assert extracted.cron_expression == "30 10 * * *"
        assert extracted.timezone == "UTC"

    def test_extract_schedule_config_no_schedule_node(self):
        """A graph with no trigger-schedule node extracts to None."""
        workflow = Mock(spec=Workflow)
        workflow.graph_dict = {
            "nodes": [
                {
                    "id": "other-node",
                    "data": {"type": "llm"},
                }
            ]
        }

        assert ScheduleService.extract_schedule_config(workflow) is None

    def test_extract_schedule_config_invalid_graph(self):
        """A missing graph is a configuration error, not a silent None."""
        workflow = Mock(spec=Workflow)
        workflow.graph_dict = None

        with pytest.raises(ScheduleConfigError, match="Workflow graph is empty"):
            ScheduleService.extract_schedule_config(workflow)
|
||||
class TestScheduleWithTimezone(unittest.TestCase):
    """Test cases for schedule with timezone handling."""

    def test_visual_schedule_with_timezone_integration(self):
        """Test complete flow: visual config → cron → execution in different timezones.

        This test verifies that when a user in Shanghai sets a schedule for 10:30 AM,
        it runs at 10:30 AM Shanghai time, not 10:30 AM UTC.
        """
        # User in Shanghai wants to run a task at 10:30 AM local time
        visual_config = VisualConfig(
            time="10:30 AM",  # This is Shanghai time
            monthly_days=[1],
        )

        # Convert to cron expression
        cron_expr = ScheduleService.visual_to_cron("monthly", visual_config)
        assert cron_expr is not None
        assert cron_expr == "30 10 1 * *"  # Direct conversion

        # Now test execution with Shanghai timezone
        shanghai_tz = "Asia/Shanghai"
        # Base time: 2025-01-01 00:00:00 UTC (08:00:00 Shanghai)
        base_time = datetime(2025, 1, 1, 0, 0, 0, tzinfo=UTC)

        next_run = calculate_next_run_at(cron_expr, shanghai_tz, base_time)

        assert next_run is not None

        # Should run at 10:30 AM Shanghai time on Jan 1
        # 10:30 AM Shanghai = 02:30 AM UTC (Shanghai is UTC+8)
        assert next_run.year == 2025
        assert next_run.month == 1
        assert next_run.day == 1
        assert next_run.hour == 2  # 02:30 UTC
        assert next_run.minute == 30

    def test_visual_schedule_different_timezones_same_local_time(self):
        """Test that same visual config in different timezones runs at different UTC times.

        This verifies that a schedule set for "9:00 AM" runs at 9 AM local time
        regardless of the timezone.
        """
        visual_config = VisualConfig(
            time="9:00 AM",
            weekdays=["mon"],
        )

        cron_expr = ScheduleService.visual_to_cron("weekly", visual_config)
        assert cron_expr is not None
        assert cron_expr == "0 9 * * 1"

        # Base time: Sunday 2025-01-05 12:00:00 UTC
        base_time = datetime(2025, 1, 5, 12, 0, 0, tzinfo=UTC)

        # Test New York (UTC-5 in January)
        ny_next = calculate_next_run_at(cron_expr, "America/New_York", base_time)
        assert ny_next is not None
        # Monday 9 AM EST = Monday 14:00 UTC
        assert ny_next.day == 6
        assert ny_next.hour == 14  # 9 AM EST = 2 PM UTC

        # Test Tokyo (UTC+9)
        tokyo_next = calculate_next_run_at(cron_expr, "Asia/Tokyo", base_time)
        assert tokyo_next is not None
        # Monday 9 AM JST = Monday 00:00 UTC
        assert tokyo_next.day == 6
        assert tokyo_next.hour == 0  # 9 AM JST = 0 AM UTC

    def test_visual_schedule_daily_across_dst_change(self):
        """Test that daily schedules adjust correctly during DST changes.

        A schedule set for "10:00 AM" should always run at 10 AM local time,
        even when DST changes.
        """
        visual_config = VisualConfig(
            time="10:00 AM",
        )

        cron_expr = ScheduleService.visual_to_cron("daily", visual_config)
        assert cron_expr is not None
        assert cron_expr == "0 10 * * *"

        # Test before DST (EST - UTC-5)
        winter_base = datetime(2025, 2, 1, 0, 0, 0, tzinfo=UTC)
        winter_next = calculate_next_run_at(cron_expr, "America/New_York", winter_base)
        assert winter_next is not None
        # 10 AM EST = 15:00 UTC
        assert winter_next.hour == 15

        # Test during DST (EDT - UTC-4)
        summer_base = datetime(2025, 6, 1, 0, 0, 0, tzinfo=UTC)
        summer_next = calculate_next_run_at(cron_expr, "America/New_York", summer_base)
        assert summer_next is not None
        # 10 AM EDT = 14:00 UTC
        assert summer_next.hour == 14
class TestSyncScheduleFromWorkflow(unittest.TestCase):
    """Test cases for syncing schedule from workflow."""

    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
    def test_sync_schedule_create_new(self, mock_select, mock_service, mock_db):
        """Test creating new schedule when none exists."""
        mock_session = MagicMock()
        mock_db.engine = MagicMock()
        mock_session.__enter__ = MagicMock(return_value=mock_session)
        mock_session.__exit__ = MagicMock(return_value=None)
        Session = MagicMock(return_value=mock_session)
        with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
            mock_session.scalar.return_value = None  # No existing plan

            # Mock extract_schedule_config to return a ScheduleConfig object
            mock_config = Mock(spec=ScheduleConfig)
            mock_config.node_id = "start"
            mock_config.cron_expression = "30 10 * * *"
            mock_config.timezone = "UTC"
            mock_service.extract_schedule_config.return_value = mock_config

            mock_new_plan = Mock(spec=WorkflowSchedulePlan)
            mock_service.create_schedule.return_value = mock_new_plan

            workflow = Mock(spec=Workflow)
            result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)

            assert result == mock_new_plan
            mock_service.create_schedule.assert_called_once()
            mock_session.commit.assert_called_once()

    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
    def test_sync_schedule_update_existing(self, mock_select, mock_service, mock_db):
        """Test updating existing schedule."""
        mock_session = MagicMock()
        mock_db.engine = MagicMock()
        mock_session.__enter__ = MagicMock(return_value=mock_session)
        mock_session.__exit__ = MagicMock(return_value=None)
        Session = MagicMock(return_value=mock_session)

        with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
            mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
            mock_existing_plan.id = "existing-plan-id"
            mock_session.scalar.return_value = mock_existing_plan

            # Mock extract_schedule_config to return a ScheduleConfig object
            mock_config = Mock(spec=ScheduleConfig)
            mock_config.node_id = "start"
            mock_config.cron_expression = "0 12 * * *"
            mock_config.timezone = "America/New_York"
            mock_service.extract_schedule_config.return_value = mock_config

            mock_updated_plan = Mock(spec=WorkflowSchedulePlan)
            mock_service.update_schedule.return_value = mock_updated_plan

            workflow = Mock(spec=Workflow)
            result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)

            assert result == mock_updated_plan
            mock_service.update_schedule.assert_called_once()
            # Verify the arguments passed to update_schedule
            call_args = mock_service.update_schedule.call_args
            assert call_args.kwargs["session"] == mock_session
            assert call_args.kwargs["schedule_id"] == "existing-plan-id"
            updates_obj = call_args.kwargs["updates"]
            assert isinstance(updates_obj, SchedulePlanUpdate)
            assert updates_obj.node_id == "start"
            assert updates_obj.cron_expression == "0 12 * * *"
            assert updates_obj.timezone == "America/New_York"
            mock_session.commit.assert_called_once()

    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.db")
    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.ScheduleService")
    @patch("events.event_handlers.sync_workflow_schedule_when_app_published.select")
    def test_sync_schedule_remove_when_no_config(self, mock_select, mock_service, mock_db):
        """Test removing schedule when no schedule config in workflow."""
        mock_session = MagicMock()
        mock_db.engine = MagicMock()
        mock_session.__enter__ = MagicMock(return_value=mock_session)
        mock_session.__exit__ = MagicMock(return_value=None)
        Session = MagicMock(return_value=mock_session)

        with patch("events.event_handlers.sync_workflow_schedule_when_app_published.Session", Session):
            mock_existing_plan = Mock(spec=WorkflowSchedulePlan)
            mock_existing_plan.id = "existing-plan-id"
            mock_session.scalar.return_value = mock_existing_plan

            mock_service.extract_schedule_config.return_value = None  # No schedule config

            workflow = Mock(spec=Workflow)
            result = sync_schedule_from_workflow("tenant-id", "app-id", workflow)

            assert result is None
            # Now using ScheduleService.delete_schedule instead of session.delete
            mock_service.delete_schedule.assert_called_once_with(session=mock_session, schedule_id="existing-plan-id")
            mock_session.commit.assert_called_once()
if __name__ == "__main__":
    # Allow running this test module directly with the unittest runner.
    unittest.main()
482
api/tests/unit_tests/services/test_webhook_service.py
Normal file
482
api/tests/unit_tests/services/test_webhook_service.py
Normal file
@@ -0,0 +1,482 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
|
||||
class TestWebhookServiceUnit:
    """Unit tests for WebhookService focusing on business logic without database dependencies."""

    def test_extract_webhook_data_json(self):
        """Test webhook data extraction from JSON request."""
        app = Flask(__name__)

        with app.test_request_context(
            "/webhook",
            method="POST",
            headers={"Content-Type": "application/json", "Authorization": "Bearer token"},
            query_string="version=1&format=json",
            json={"message": "hello", "count": 42},
        ):
            webhook_trigger = MagicMock()
            webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

            assert webhook_data["method"] == "POST"
            assert webhook_data["headers"]["Authorization"] == "Bearer token"
            # Query params are now extracted as raw strings
            assert webhook_data["query_params"]["version"] == "1"
            assert webhook_data["query_params"]["format"] == "json"
            assert webhook_data["body"]["message"] == "hello"
            assert webhook_data["body"]["count"] == 42
            assert webhook_data["files"] == {}

    def test_extract_webhook_data_query_params_remain_strings(self):
        """Query parameters should be extracted as raw strings without automatic conversion."""
        app = Flask(__name__)

        # NOTE: the query string below previously contained an HTML-entity
        # mojibake ("...true¬e=text"); it is restored to the intended
        # "&note=text" so the assertion on query_params["note"] can pass.
        with app.test_request_context(
            "/webhook",
            method="GET",
            headers={"Content-Type": "application/json"},
            query_string="count=42&threshold=3.14&enabled=true&note=text",
        ):
            webhook_trigger = MagicMock()
            webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

            # After refactoring, raw extraction keeps query params as strings
            assert webhook_data["query_params"]["count"] == "42"
            assert webhook_data["query_params"]["threshold"] == "3.14"
            assert webhook_data["query_params"]["enabled"] == "true"
            assert webhook_data["query_params"]["note"] == "text"

    def test_extract_webhook_data_form_urlencoded(self):
        """Test webhook data extraction from form URL encoded request."""
        app = Flask(__name__)

        with app.test_request_context(
            "/webhook",
            method="POST",
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={"username": "test", "password": "secret"},
        ):
            webhook_trigger = MagicMock()
            webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

            assert webhook_data["method"] == "POST"
            assert webhook_data["body"]["username"] == "test"
            assert webhook_data["body"]["password"] == "secret"

    def test_extract_webhook_data_multipart_with_files(self):
        """Test webhook data extraction from multipart form with files."""
        app = Flask(__name__)

        # Create a mock file
        file_content = b"test file content"
        file_storage = FileStorage(stream=BytesIO(file_content), filename="test.txt", content_type="text/plain")

        with app.test_request_context(
            "/webhook",
            method="POST",
            headers={"Content-Type": "multipart/form-data"},
            data={"message": "test", "upload": file_storage},
        ):
            webhook_trigger = MagicMock()
            webhook_trigger.tenant_id = "test_tenant"

            with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
                mock_process_files.return_value = {"upload": "mocked_file_obj"}

                webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

                assert webhook_data["method"] == "POST"
                assert webhook_data["body"]["message"] == "test"
                assert webhook_data["files"]["upload"] == "mocked_file_obj"
                mock_process_files.assert_called_once()

    def test_extract_webhook_data_raw_text(self):
        """Test webhook data extraction from raw text request."""
        app = Flask(__name__)

        with app.test_request_context(
            "/webhook", method="POST", headers={"Content-Type": "text/plain"}, data="raw text content"
        ):
            webhook_trigger = MagicMock()
            webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

            assert webhook_data["method"] == "POST"
            assert webhook_data["body"]["raw"] == "raw text content"

    def test_extract_webhook_data_invalid_json(self):
        """Test webhook data extraction with invalid JSON."""
        app = Flask(__name__)

        with app.test_request_context(
            "/webhook", method="POST", headers={"Content-Type": "application/json"}, data="invalid json"
        ):
            webhook_trigger = MagicMock()
            webhook_data = WebhookService.extract_webhook_data(webhook_trigger)

            assert webhook_data["method"] == "POST"
            assert webhook_data["body"] == {}  # Should default to empty dict

    def test_generate_webhook_response_default(self):
        """Test webhook response generation with default values."""
        node_config = {"data": {}}

        response_data, status_code = WebhookService.generate_webhook_response(node_config)

        assert status_code == 200
        assert response_data["status"] == "success"
        assert "Webhook processed successfully" in response_data["message"]

    def test_generate_webhook_response_custom_json(self):
        """Test webhook response generation with custom JSON response."""
        node_config = {"data": {"status_code": 201, "response_body": '{"result": "created", "id": 123}'}}

        response_data, status_code = WebhookService.generate_webhook_response(node_config)

        assert status_code == 201
        assert response_data["result"] == "created"
        assert response_data["id"] == 123

    def test_generate_webhook_response_custom_text(self):
        """Test webhook response generation with custom text response."""
        node_config = {"data": {"status_code": 202, "response_body": "Request accepted for processing"}}

        response_data, status_code = WebhookService.generate_webhook_response(node_config)

        assert status_code == 202
        assert response_data["message"] == "Request accepted for processing"

    def test_generate_webhook_response_invalid_json(self):
        """Test webhook response generation with invalid JSON response."""
        node_config = {"data": {"status_code": 400, "response_body": '{"invalid": json}'}}

        response_data, status_code = WebhookService.generate_webhook_response(node_config)

        assert status_code == 400
        # Unparseable JSON falls back to being returned as a plain text message.
        assert response_data["message"] == '{"invalid": json}'

    def test_generate_webhook_response_empty_response_body(self):
        """Test webhook response generation with empty response body."""
        node_config = {"data": {"status_code": 204, "response_body": ""}}

        response_data, status_code = WebhookService.generate_webhook_response(node_config)

        assert status_code == 204
        assert response_data["status"] == "success"
        assert "Webhook processed successfully" in response_data["message"]

    def test_generate_webhook_response_array_json(self):
        """Test webhook response generation with JSON array response."""
        node_config = {"data": {"status_code": 200, "response_body": '[{"id": 1}, {"id": 2}]'}}

        response_data, status_code = WebhookService.generate_webhook_response(node_config)

        assert status_code == 200
        assert isinstance(response_data, list)
        assert len(response_data) == 2
        assert response_data[0]["id"] == 1
        assert response_data[1]["id"] == 2

    @patch("services.trigger.webhook_service.ToolFileManager")
    @patch("services.trigger.webhook_service.file_factory")
    def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
        """Test successful file upload processing."""
        # Mock ToolFileManager
        mock_tool_file_instance = MagicMock()
        mock_tool_file_manager.return_value = mock_tool_file_instance

        # Mock file creation
        mock_tool_file = MagicMock()
        mock_tool_file.id = "test_file_id"
        mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file

        # Mock file factory
        mock_file_obj = MagicMock()
        mock_file_factory.build_from_mapping.return_value = mock_file_obj

        # Create mock files
        files = {
            "file1": MagicMock(filename="test1.txt", content_type="text/plain"),
            "file2": MagicMock(filename="test2.jpg", content_type="image/jpeg"),
        }

        # Mock file reads
        files["file1"].read.return_value = b"content1"
        files["file2"].read.return_value = b"content2"

        webhook_trigger = MagicMock()
        webhook_trigger.tenant_id = "test_tenant"

        result = WebhookService._process_file_uploads(files, webhook_trigger)

        assert len(result) == 2
        assert "file1" in result
        assert "file2" in result

        # Verify file processing was called for each file
        assert mock_tool_file_manager.call_count == 2
        assert mock_file_factory.build_from_mapping.call_count == 2

    @patch("services.trigger.webhook_service.ToolFileManager")
    @patch("services.trigger.webhook_service.file_factory")
    def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
        """Test file upload processing with errors."""
        # Mock ToolFileManager
        mock_tool_file_instance = MagicMock()
        mock_tool_file_manager.return_value = mock_tool_file_instance

        # Mock file creation
        mock_tool_file = MagicMock()
        mock_tool_file.id = "test_file_id"
        mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file

        # Mock file factory
        mock_file_obj = MagicMock()
        mock_file_factory.build_from_mapping.return_value = mock_file_obj

        # Create mock files, one will fail
        files = {
            "good_file": MagicMock(filename="test.txt", content_type="text/plain"),
            "bad_file": MagicMock(filename="test.bad", content_type="text/plain"),
        }

        files["good_file"].read.return_value = b"content"
        files["bad_file"].read.side_effect = Exception("Read error")

        webhook_trigger = MagicMock()
        webhook_trigger.tenant_id = "test_tenant"

        result = WebhookService._process_file_uploads(files, webhook_trigger)

        # Should process the good file and skip the bad one
        assert len(result) == 1
        assert "good_file" in result
        assert "bad_file" not in result

    def test_process_file_uploads_empty_filename(self):
        """Test file upload processing with empty filename."""
        files = {
            "no_filename": MagicMock(filename="", content_type="text/plain"),
            "none_filename": MagicMock(filename=None, content_type="text/plain"),
        }

        webhook_trigger = MagicMock()
        webhook_trigger.tenant_id = "test_tenant"

        result = WebhookService._process_file_uploads(files, webhook_trigger)

        # Should skip files without filenames
        assert len(result) == 0

    def test_validate_json_value_string(self):
        """Test JSON value validation for string type."""
        # Valid string
        result = WebhookService._validate_json_value("name", "hello", "string")
        assert result == "hello"

        # Invalid string (number) - should raise ValueError
        with pytest.raises(ValueError, match="Expected string, got int"):
            WebhookService._validate_json_value("name", 123, "string")

    def test_validate_json_value_number(self):
        """Test JSON value validation for number type."""
        # Valid integer
        result = WebhookService._validate_json_value("count", 42, "number")
        assert result == 42

        # Valid float
        result = WebhookService._validate_json_value("price", 19.99, "number")
        assert result == 19.99

        # Invalid number (string) - should raise ValueError
        with pytest.raises(ValueError, match="Expected number, got str"):
            WebhookService._validate_json_value("count", "42", "number")

    def test_validate_json_value_bool(self):
        """Test JSON value validation for boolean type."""
        # Valid boolean
        result = WebhookService._validate_json_value("enabled", True, "boolean")
        assert result is True

        result = WebhookService._validate_json_value("enabled", False, "boolean")
        assert result is False

        # Invalid boolean (string) - should raise ValueError
        with pytest.raises(ValueError, match="Expected boolean, got str"):
            WebhookService._validate_json_value("enabled", "true", "boolean")

    def test_validate_json_value_object(self):
        """Test JSON value validation for object type."""
        # Valid object
        result = WebhookService._validate_json_value("user", {"name": "John", "age": 30}, "object")
        assert result == {"name": "John", "age": 30}

        # Invalid object (string) - should raise ValueError
        with pytest.raises(ValueError, match="Expected object, got str"):
            WebhookService._validate_json_value("user", "not_an_object", "object")

    def test_validate_json_value_array_string(self):
        """Test JSON value validation for array[string] type."""
        # Valid array of strings
        result = WebhookService._validate_json_value("tags", ["tag1", "tag2", "tag3"], "array[string]")
        assert result == ["tag1", "tag2", "tag3"]

        # Invalid - not an array
        with pytest.raises(ValueError, match="Expected array of strings, got str"):
            WebhookService._validate_json_value("tags", "not_an_array", "array[string]")

        # Invalid - array with non-strings
        with pytest.raises(ValueError, match="Expected array of strings, got list"):
            WebhookService._validate_json_value("tags", ["tag1", 123, "tag3"], "array[string]")

    def test_validate_json_value_array_number(self):
        """Test JSON value validation for array[number] type."""
        # Valid array of numbers
        result = WebhookService._validate_json_value("scores", [1, 2.5, 3, 4.7], "array[number]")
        assert result == [1, 2.5, 3, 4.7]

        # Invalid - array with non-numbers
        with pytest.raises(ValueError, match="Expected array of numbers, got list"):
            WebhookService._validate_json_value("scores", [1, "2", 3], "array[number]")

    def test_validate_json_value_array_bool(self):
        """Test JSON value validation for array[boolean] type."""
        # Valid array of booleans
        result = WebhookService._validate_json_value("flags", [True, False, True], "array[boolean]")
        assert result == [True, False, True]

        # Invalid - array with non-booleans
        with pytest.raises(ValueError, match="Expected array of booleans, got list"):
            WebhookService._validate_json_value("flags", [True, "false", True], "array[boolean]")

    def test_validate_json_value_array_object(self):
        """Test JSON value validation for array[object] type."""
        # Valid array of objects
        result = WebhookService._validate_json_value("users", [{"name": "John"}, {"name": "Jane"}], "array[object]")
        assert result == [{"name": "John"}, {"name": "Jane"}]

        # Invalid - array with non-objects
        with pytest.raises(ValueError, match="Expected array of objects, got list"):
            WebhookService._validate_json_value("users", [{"name": "John"}, "not_object"], "array[object]")

    def test_convert_form_value_string(self):
        """Test form value conversion for string type."""
        result = WebhookService._convert_form_value("test", "hello", "string")
        assert result == "hello"

    def test_convert_form_value_number(self):
        """Test form value conversion for number type."""
        # Integer
        result = WebhookService._convert_form_value("count", "42", "number")
        assert result == 42

        # Float
        result = WebhookService._convert_form_value("price", "19.99", "number")
        assert result == 19.99

        # Invalid number
        with pytest.raises(ValueError, match="Cannot convert 'not_a_number' to number"):
            WebhookService._convert_form_value("count", "not_a_number", "number")

    def test_convert_form_value_boolean(self):
        """Test form value conversion for boolean type."""
        # True values
        assert WebhookService._convert_form_value("flag", "true", "boolean") is True
        assert WebhookService._convert_form_value("flag", "1", "boolean") is True
        assert WebhookService._convert_form_value("flag", "yes", "boolean") is True

        # False values
        assert WebhookService._convert_form_value("flag", "false", "boolean") is False
        assert WebhookService._convert_form_value("flag", "0", "boolean") is False
        assert WebhookService._convert_form_value("flag", "no", "boolean") is False

        # Invalid boolean
        with pytest.raises(ValueError, match="Cannot convert 'maybe' to boolean"):
            WebhookService._convert_form_value("flag", "maybe", "boolean")

    def test_extract_and_validate_webhook_data_success(self):
        """Test successful unified data extraction and validation."""
        app = Flask(__name__)

        with app.test_request_context(
            "/webhook",
            method="POST",
            headers={"Content-Type": "application/json"},
            query_string="count=42&enabled=true",
            json={"message": "hello", "age": 25},
        ):
            webhook_trigger = MagicMock()
            node_config = {
                "data": {
                    "method": "post",
                    "content_type": "application/json",
                    "params": [
                        {"name": "count", "type": "number", "required": True},
                        {"name": "enabled", "type": "boolean", "required": True},
                    ],
                    "body": [
                        {"name": "message", "type": "string", "required": True},
                        {"name": "age", "type": "number", "required": True},
                    ],
                }
            }

            result = WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

            # Check that types are correctly converted
            assert result["query_params"]["count"] == 42  # Converted to int
            assert result["query_params"]["enabled"] is True  # Converted to bool
            assert result["body"]["message"] == "hello"  # Already string
            assert result["body"]["age"] == 25  # Already number

    def test_extract_and_validate_webhook_data_validation_error(self):
        """Test unified data extraction with validation error."""
        app = Flask(__name__)

        with app.test_request_context(
            "/webhook",
            method="GET",  # Wrong method
            headers={"Content-Type": "application/json"},
        ):
            webhook_trigger = MagicMock()
            node_config = {
                "data": {
                    "method": "post",  # Expects POST
                    "content_type": "application/json",
                }
            }

            with pytest.raises(ValueError, match="HTTP method mismatch"):
                WebhookService.extract_and_validate_webhook_data(webhook_trigger, node_config)

    def test_debug_mode_parameter_handling(self):
        """Test that the debug mode parameter is properly handled in _prepare_webhook_execution."""
        from controllers.trigger.webhook import _prepare_webhook_execution

        # Mock the WebhookService methods
        with (
            patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
            patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
        ):
            mock_trigger = MagicMock()
            mock_workflow = MagicMock()
            mock_config = {"data": {"test": "config"}}
            mock_data = {"test": "data"}

            mock_get_trigger.return_value = (mock_trigger, mock_workflow, mock_config)
            mock_extract.return_value = mock_data

            result = _prepare_webhook_execution("test_webhook", is_debug=False)
            assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)

            # Reset mock
            mock_get_trigger.reset_mock()

            result = _prepare_webhook_execution("test_webhook", is_debug=True)
            assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
Reference in New Issue
Block a user