feat:mysql adaptation for metadb (#28188)
This commit is contained in:
@@ -2,11 +2,15 @@ import enum
|
||||
import uuid
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import CHAR, VARCHAR, TypeDecorator
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import CHAR, TEXT, VARCHAR, LargeBinary, TypeDecorator
|
||||
from sqlalchemy.dialects.mysql import LONGBLOB, LONGTEXT
|
||||
from sqlalchemy.dialects.postgresql import BYTEA, JSONB, UUID
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class StringUUID(TypeDecorator[uuid.UUID | str | None]):
|
||||
impl = CHAR
|
||||
@@ -34,6 +38,78 @@ class StringUUID(TypeDecorator[uuid.UUID | str | None]):
|
||||
return str(value)
|
||||
|
||||
|
||||
class LongText(TypeDecorator[str | None]):
|
||||
impl = TEXT
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return value
|
||||
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(TEXT())
|
||||
elif dialect.name == "mysql":
|
||||
return dialect.type_descriptor(LONGTEXT())
|
||||
else:
|
||||
return dialect.type_descriptor(TEXT())
|
||||
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
if value is None:
|
||||
return value
|
||||
return value
|
||||
|
||||
|
||||
class BinaryData(TypeDecorator[bytes | None]):
|
||||
impl = LargeBinary
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: bytes | None, dialect: Dialect) -> bytes | None:
|
||||
if value is None:
|
||||
return value
|
||||
return value
|
||||
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
return dialect.type_descriptor(BYTEA())
|
||||
elif dialect.name == "mysql":
|
||||
return dialect.type_descriptor(LONGBLOB())
|
||||
else:
|
||||
return dialect.type_descriptor(LargeBinary())
|
||||
|
||||
def process_result_value(self, value: bytes | None, dialect: Dialect) -> bytes | None:
|
||||
if value is None:
|
||||
return value
|
||||
return value
|
||||
|
||||
|
||||
class AdjustedJSON(TypeDecorator[dict | list | None]):
|
||||
impl = sa.JSON
|
||||
cache_ok = True
|
||||
|
||||
def __init__(self, astext_type=None):
|
||||
self.astext_type = astext_type
|
||||
super().__init__()
|
||||
|
||||
def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]:
|
||||
if dialect.name == "postgresql":
|
||||
if self.astext_type:
|
||||
return dialect.type_descriptor(JSONB(astext_type=self.astext_type))
|
||||
else:
|
||||
return dialect.type_descriptor(JSONB())
|
||||
elif dialect.name == "mysql":
|
||||
return dialect.type_descriptor(sa.JSON())
|
||||
else:
|
||||
return dialect.type_descriptor(sa.JSON())
|
||||
|
||||
def process_bind_param(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: dict | list | None, dialect: Dialect) -> dict | list | None:
|
||||
return value
|
||||
|
||||
|
||||
_E = TypeVar("_E", bound=enum.StrEnum)
|
||||
|
||||
|
||||
@@ -77,3 +153,11 @@ class EnumText(TypeDecorator[_E | None], Generic[_E]):
|
||||
if x is None or y is None:
|
||||
return x is y
|
||||
return x == y
|
||||
|
||||
|
||||
def adjusted_json_index(index_name, column_name):
|
||||
index_name = index_name or f"{column_name}_idx"
|
||||
if dify_config.DB_TYPE == "postgresql":
|
||||
return sa.Index(index_name, column_name, postgresql_using="gin")
|
||||
else:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user