feat: add AWS Managed IAM auth for OpenSearch vector DB (#18963)
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from opensearchpy import OpenSearch, helpers
|
||||
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
|
||||
from opensearchpy.helpers import BulkIndexError
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
@@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
|
||||
class OpenSearchConfig(BaseModel):
|
||||
host: str
|
||||
port: int
|
||||
secure: bool = False
|
||||
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
|
||||
user: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
secure: bool = False
|
||||
aws_region: Optional[str] = None
|
||||
aws_service: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
|
||||
raise ValueError("config OPENSEARCH_HOST is required")
|
||||
if not values.get("port"):
|
||||
raise ValueError("config OPENSEARCH_PORT is required")
|
||||
if values.get("auth_method") == "aws_managed_iam":
|
||||
if not values.get("aws_region"):
|
||||
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
|
||||
if not values.get("aws_service"):
|
||||
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
|
||||
return values
|
||||
|
||||
def create_ssl_context(self) -> ssl.SSLContext:
    """Build an SSL context that skips hostname and certificate checks.

    Returns:
        A context obtained from ``ssl.create_default_context()`` with
        ``check_hostname`` switched off and ``verify_mode`` set to
        ``ssl.CERT_NONE``.
    """
    context = ssl.create_default_context()
    # Order matters: the ssl module refuses CERT_NONE while hostname
    # checking is still enabled, so turn that off first.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE  # Disable Certificate Validation
    return context
|
||||
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
    """Create the request signer used by the ``aws_managed_iam`` auth method.

    Credentials are read from a default ``boto3.Session()``; region and
    service are taken from ``aws_region`` and ``aws_service`` on this config.

    Returns:
        A ``Urllib3AWSV4SignerAuth`` built from those three values.
    """
    # Function-scope import: boto3 is only loaded when this method runs.
    import boto3  # type: ignore

    session_credentials = boto3.Session().get_credentials()
    return Urllib3AWSV4SignerAuth(
        credentials=session_credentials,
        region=self.aws_region,
        service=self.aws_service,  # type: ignore[arg-type]
    )
|
||||
|
||||
def to_opensearch_params(self) -> dict[str, Any]:
    """Translate this config into keyword arguments for the OpenSearch client.

    The base parameters (host/port, SSL flags, connection class, pool size)
    are always present. Authentication is then added according to
    ``auth_method``:

    * ``"basic"`` - ``http_auth`` is a ``(user, password)`` tuple, set only
      when both values are non-empty; when ``secure`` is true an
      ``ssl_context`` from ``create_ssl_context()`` is attached as well.
    * ``"aws_managed_iam"`` - ``http_auth`` is the signer returned by
      ``create_aws_managed_iam_auth()``.

    Returns:
        A dict of client keyword arguments.
    """
    params: dict[str, Any] = {
        "hosts": [{"host": self.host, "port": self.port}],
        "use_ssl": self.secure,
        "verify_certs": self.secure,
        "connection_class": Urllib3HttpConnection,
        "pool_maxsize": 20,
    }

    if self.auth_method == "basic":
        logger.info("Using basic authentication for OpenSearch Vector DB")

        # user/password are Optional; without this guard an unauthenticated
        # setup would hand a (None, None) tuple to the client.
        if self.user and self.password:
            params["http_auth"] = (self.user, self.password)
        if self.secure:
            # NOTE(review): this context disables certificate validation
            # while "verify_certs" above is True - confirm which is intended.
            params["ssl_context"] = self.create_ssl_context()
    elif self.auth_method == "aws_managed_iam":
        logger.info("Using AWS managed IAM role for OpenSearch Vector DB")

        params["http_auth"] = self.create_aws_managed_iam_auth()

    return params
|
||||
|
||||
|
||||
@@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
|
||||
action = {
|
||||
"_op_type": "index",
|
||||
"_index": self._collection_name.lower(),
|
||||
"_id": uuid4().hex,
|
||||
"_source": {
|
||||
Field.CONTENT_KEY.value: documents[i].page_content,
|
||||
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
|
||||
Field.METADATA_KEY.value: documents[i].metadata,
|
||||
},
|
||||
}
|
||||
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
|
||||
if self._client_config.aws_service not in ["aoss"]:
|
||||
action["_id"] = uuid4().hex
|
||||
actions.append(action)
|
||||
|
||||
helpers.bulk(self._client, actions)
|
||||
helpers.bulk(
|
||||
client=self._client,
|
||||
actions=actions,
|
||||
timeout=30,
|
||||
max_retries=3,
|
||||
)
|
||||
|
||||
def get_ids_by_metadata_field(self, key: str, value: str):
|
||||
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
|
||||
@@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
|
||||
},
|
||||
}
|
||||
|
||||
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
|
||||
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
|
||||
|
||||
redis_client.set(collection_exist_cache_key, 1, ex=3600)
|
||||
@@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
|
||||
open_search_config = OpenSearchConfig(
|
||||
host=dify_config.OPENSEARCH_HOST or "localhost",
|
||||
port=dify_config.OPENSEARCH_PORT,
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
|
||||
user=dify_config.OPENSEARCH_USER,
|
||||
password=dify_config.OPENSEARCH_PASSWORD,
|
||||
secure=dify_config.OPENSEARCH_SECURE,
|
||||
aws_region=dify_config.OPENSEARCH_AWS_REGION,
|
||||
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
|
||||
)
|
||||
|
||||
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
|
||||
|
||||
Reference in New Issue
Block a user