feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963)

This commit is contained in:
Ahmad Zidan
2025-04-29 14:10:08 +07:00
committed by GitHub
parent 8b4ea01810
commit 8266815cda
5 changed files with 138 additions and 22 deletions

View File

@@ -1,10 +1,9 @@
import json
import logging
import ssl
from typing import Any, Optional
from typing import Any, Literal, Optional
from uuid import uuid4
from opensearchpy import OpenSearch, helpers
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
@@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
secure: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
@model_validator(mode="before")
@classmethod
@@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values
def create_ssl_context(self) -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
return ssl_context
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3 # type: ignore
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),
region=self.aws_region,
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
if self.user and self.password:
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
if self.secure:
params["ssl_context"] = self.create_ssl_context()
elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
return params
@@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
action["_id"] = uuid4().hex
actions.append(action)
helpers.bulk(self._client, actions)
helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
},
}
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
secure=dify_config.OPENSEARCH_SECURE,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)