feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963)
This commit is contained in:
@@ -23,13 +23,70 @@ def setup_mock_redis():
|
||||
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
|
||||
|
||||
|
||||
class TestOpenSearchConfig:
|
||||
def test_to_opensearch_params(self):
|
||||
config = OpenSearchConfig(
|
||||
host="localhost",
|
||||
port=9200,
|
||||
secure=True,
|
||||
user="admin",
|
||||
password="password",
|
||||
)
|
||||
|
||||
params = config.to_opensearch_params()
|
||||
|
||||
assert params["hosts"] == [{"host": "localhost", "port": 9200}]
|
||||
assert params["use_ssl"] is True
|
||||
assert params["verify_certs"] is True
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] == ("admin", "password")
|
||||
|
||||
@patch("boto3.Session")
|
||||
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
|
||||
def test_to_opensearch_params_with_aws_managed_iam(
|
||||
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
|
||||
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_aws_signer_auth.return_value = mock_auth_instance
|
||||
|
||||
aws_region = "ap-southeast-2"
|
||||
aws_service = "aoss"
|
||||
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
|
||||
port = 9201
|
||||
|
||||
config = OpenSearchConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
secure=True,
|
||||
auth_method="aws_managed_iam",
|
||||
aws_region=aws_region,
|
||||
aws_service=aws_service,
|
||||
)
|
||||
|
||||
params = config.to_opensearch_params()
|
||||
|
||||
assert params["hosts"] == [{"host": host, "port": port}]
|
||||
assert params["use_ssl"] is True
|
||||
assert params["verify_certs"] is True
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] is mock_auth_instance
|
||||
|
||||
mock_aws_signer_auth.assert_called_once_with(
|
||||
credentials=mock_credentials, region=aws_region, service=aws_service
|
||||
)
|
||||
assert mock_boto_session.return_value.get_credentials.called
|
||||
|
||||
|
||||
class TestOpenSearchVector:
|
||||
def setup_method(self):
|
||||
self.collection_name = "test_collection"
|
||||
self.example_doc_id = "example_doc_id"
|
||||
self.vector = OpenSearchVector(
|
||||
collection_name=self.collection_name,
|
||||
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
|
||||
config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
|
||||
)
|
||||
self.vector._client = MagicMock()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user