添加注册登录功能

This commit is contained in:
2025-08-29 00:34:40 +08:00
parent 09065f2ce7
commit 2fe3474d9e
3060 changed files with 29217 additions and 87137 deletions

View File

@@ -170,8 +170,10 @@ class _PGDialect_common_psycopg(PGDialect):
def _do_autocommit(self, connection, value):
connection.autocommit = value
def detect_autocommit_setting(self, dbapi_connection):
return bool(dbapi_connection.autocommit)
def do_ping(self, dbapi_connection):
cursor = None
before_autocommit = dbapi_connection.autocommit
if not before_autocommit:

View File

@@ -4,15 +4,18 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
import re
from typing import Any
from typing import Any as typing_Any
from typing import Iterable
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .operators import CONTAINED_BY
from .operators import CONTAINS
@@ -21,28 +24,52 @@ from ... import types as sqltypes
from ... import util
from ...sql import expression
from ...sql import operators
from ...sql._typing import _TypeEngineArgument
from ...sql.visitors import InternalTraversal
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql._typing import _ColumnExpressionArgument
from ...sql._typing import _TypeEngineArgument
from ...sql.elements import ColumnElement
from ...sql.elements import Grouping
from ...sql.expression import BindParameter
from ...sql.operators import OperatorType
from ...sql.selectable import _SelectIterable
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType
from ...sql.type_api import _ResultProcessorType
from ...sql.type_api import TypeEngine
from ...sql.visitors import _TraverseInternalsType
from ...util.typing import Self
_T = TypeVar("_T", bound=Any)
_T = TypeVar("_T", bound=typing_Any)
def Any(other, arrexpr, operator=operators.eq):
def Any(
other: typing_Any,
arrexpr: _ColumnExpressionArgument[_T],
operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
See that method for details.
"""
return arrexpr.any(other, operator)
return arrexpr.any(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501
def All(other, arrexpr, operator=operators.eq):
def All(
other: typing_Any,
arrexpr: _ColumnExpressionArgument[_T],
operator: OperatorType = operators.eq,
) -> ColumnElement[bool]:
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
See that method for details.
"""
return arrexpr.all(other, operator)
return arrexpr.all(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501
class array(expression.ExpressionClauseList[_T]):
@@ -66,11 +93,32 @@ class array(expression.ExpressionClauseList[_T]):
ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1
An instance of :class:`.array` will always have the datatype
:class:`_types.ARRAY`. The "inner" type of the array is inferred from
the values present, unless the ``type_`` keyword argument is passed::
:class:`_types.ARRAY`. The "inner" type of the array is inferred from the
values present, unless the :paramref:`_postgresql.array.type_` keyword
argument is passed::
array(["foo", "bar"], type_=CHAR)
When constructing an empty array, the :paramref:`_postgresql.array.type_`
argument is particularly important as PostgreSQL server typically requires
a cast to be rendered for the inner type in order to render an empty array.
SQLAlchemy's compilation for the empty array will produce this cast so
that::
stmt = array([], type_=Integer)
print(stmt.compile(dialect=postgresql.dialect()))
Produces:
.. sourcecode:: sql
ARRAY[]::INTEGER[]
As required by PostgreSQL for empty arrays.
.. versionadded:: 2.0.40 added support to render empty PostgreSQL array
literals with a required cast.
Multidimensional arrays are produced by nesting :class:`.array` constructs.
The dimensionality of the final :class:`_types.ARRAY`
type is calculated by
@@ -105,18 +153,33 @@ class array(expression.ExpressionClauseList[_T]):
__visit_name__ = "array"
stringify_dialect = "postgresql"
inherit_cache = True
def __init__(self, clauses, **kw):
type_arg = kw.pop("type_", None)
_traverse_internals: _TraverseInternalsType = [
("clauses", InternalTraversal.dp_clauseelement_tuple),
("type", InternalTraversal.dp_type),
]
def __init__(
self,
clauses: Iterable[_T],
*,
type_: Optional[_TypeEngineArgument[_T]] = None,
**kw: typing_Any,
):
r"""Construct an ARRAY literal.
:param clauses: iterable, such as a list, containing elements to be
rendered in the array
:param type\_: optional type. If omitted, the type is inferred
from the contents of the array.
"""
super().__init__(operators.comma_op, *clauses, **kw)
self._type_tuple = [arg.type for arg in self.clauses]
main_type = (
type_arg
if type_arg is not None
else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE
type_
if type_ is not None
else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE
)
if isinstance(main_type, ARRAY):
@@ -127,15 +190,21 @@ class array(expression.ExpressionClauseList[_T]):
if main_type.dimensions is not None
else 2
),
)
) # type: ignore[assignment]
else:
self.type = ARRAY(main_type)
self.type = ARRAY(main_type) # type: ignore[assignment]
@property
def _select_iterable(self):
def _select_iterable(self) -> _SelectIterable:
return (self,)
def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
def _bind_param(
self,
operator: OperatorType,
obj: typing_Any,
type_: Optional[TypeEngine[_T]] = None,
_assume_scalar: bool = False,
) -> BindParameter[_T]:
if _assume_scalar or operator is operators.getitem:
return expression.BindParameter(
None,
@@ -154,16 +223,18 @@ class array(expression.ExpressionClauseList[_T]):
)
for o in obj
]
)
) # type: ignore[return-value]
def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> Union[Self, Grouping[_T]]:
if against in (operators.any_op, operators.all_op, operators.getitem):
return expression.Grouping(self)
else:
return self
class ARRAY(sqltypes.ARRAY):
class ARRAY(sqltypes.ARRAY[_T]):
"""PostgreSQL ARRAY type.
The :class:`_postgresql.ARRAY` type is constructed in the same way
@@ -237,7 +308,7 @@ class ARRAY(sqltypes.ARRAY):
def __init__(
self,
item_type: _TypeEngineArgument[Any],
item_type: _TypeEngineArgument[_T],
as_tuple: bool = False,
dimensions: Optional[int] = None,
zero_indexes: bool = False,
@@ -286,7 +357,7 @@ class ARRAY(sqltypes.ARRAY):
self.dimensions = dimensions
self.zero_indexes = zero_indexes
class Comparator(sqltypes.ARRAY.Comparator):
class Comparator(sqltypes.ARRAY.Comparator[_T]):
"""Define comparison operations for :class:`_types.ARRAY`.
Note that these operations are in addition to those provided
@@ -296,7 +367,9 @@ class ARRAY(sqltypes.ARRAY):
"""
def contains(self, other, **kwargs):
def contains(
self, other: typing_Any, **kwargs: typing_Any
) -> ColumnElement[bool]:
"""Boolean expression. Test if elements are a superset of the
elements of the argument array expression.
@@ -305,7 +378,7 @@ class ARRAY(sqltypes.ARRAY):
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
def contained_by(self, other: typing_Any) -> ColumnElement[bool]:
"""Boolean expression. Test if elements are a proper subset of the
elements of the argument array expression.
"""
@@ -313,7 +386,7 @@ class ARRAY(sqltypes.ARRAY):
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def overlap(self, other):
def overlap(self, other: typing_Any) -> ColumnElement[bool]:
"""Boolean expression. Test if array has elements in common with
an argument array expression.
"""
@@ -321,35 +394,26 @@ class ARRAY(sqltypes.ARRAY):
comparator_factory = Comparator
@property
def hashable(self):
return self.as_tuple
@property
def python_type(self):
return list
def compare_values(self, x, y):
return x == y
@util.memoized_property
def _against_native_enum(self):
def _against_native_enum(self) -> bool:
return (
isinstance(self.item_type, sqltypes.Enum)
and self.item_type.native_enum
)
def literal_processor(self, dialect):
def literal_processor(
self, dialect: Dialect
) -> Optional[_LiteralProcessorType[_T]]:
item_proc = self.item_type.dialect_impl(dialect).literal_processor(
dialect
)
if item_proc is None:
return None
def to_str(elements):
def to_str(elements: Iterable[typing_Any]) -> str:
return f"ARRAY[{', '.join(elements)}]"
def process(value):
def process(value: Sequence[typing_Any]) -> str:
inner = self._apply_item_processor(
value, item_proc, self.dimensions, to_str
)
@@ -357,12 +421,16 @@ class ARRAY(sqltypes.ARRAY):
return process
def bind_processor(self, dialect):
def bind_processor(
self, dialect: Dialect
) -> Optional[_BindProcessorType[Sequence[typing_Any]]]:
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
dialect
)
def process(value):
def process(
value: Optional[Sequence[typing_Any]],
) -> Optional[list[typing_Any]]:
if value is None:
return value
else:
@@ -372,12 +440,16 @@ class ARRAY(sqltypes.ARRAY):
return process
def result_processor(self, dialect, coltype):
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[Sequence[typing_Any]]:
item_proc = self.item_type.dialect_impl(dialect).result_processor(
dialect, coltype
)
def process(value):
def process(
value: Sequence[typing_Any],
) -> Optional[Sequence[typing_Any]]:
if value is None:
return value
else:
@@ -392,11 +464,13 @@ class ARRAY(sqltypes.ARRAY):
super_rp = process
pattern = re.compile(r"^{(.*)}$")
def handle_raw_string(value):
inner = pattern.match(value).group(1)
def handle_raw_string(value: str) -> list[str]:
inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501
return _split_enum_values(inner)
def process(value):
def process(
value: Sequence[typing_Any],
) -> Optional[Sequence[typing_Any]]:
if value is None:
return value
# isinstance(value, str) is required to handle
@@ -411,7 +485,7 @@ class ARRAY(sqltypes.ARRAY):
return process
def _split_enum_values(array_string):
def _split_enum_values(array_string: str) -> list[str]:
if '"' not in array_string:
# no escape char is present so it can just split on the comma
return array_string.split(",") if array_string else []

View File

@@ -910,13 +910,16 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection):
asyncio.CancelledError,
OSError,
self.dbapi.asyncpg.PostgresError,
):
) as e:
# in the case where we are recycling an old connection
# that may have already been disconnected, close() will
# fail with the above timeout. in this case, terminate
# the connection without any further waiting.
# see issue #8419
self._connection.terminate()
if isinstance(e, asyncio.CancelledError):
# re-raise CancelledError if we were cancelled
raise
else:
# not in a greenlet; this is the gc cleanup case
self._connection.terminate()
@@ -1114,6 +1117,9 @@ class PGDialect_asyncpg(PGDialect):
def set_isolation_level(self, dbapi_connection, level):
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
def detect_autocommit_setting(self, dbapi_conn) -> bool:
return bool(dbapi_conn.autocommit)
def set_readonly(self, connection, value):
connection.readonly = value

View File

@@ -266,7 +266,7 @@ will remain consistent with the state of the transaction::
from sqlalchemy import event
postgresql_engine = create_engine(
"postgresql+pyscopg2://scott:tiger@hostname/dbname",
"postgresql+psycopg2://scott:tiger@hostname/dbname",
# disable default reset-on-return scheme
pool_reset_on_return=None,
)
@@ -978,6 +978,8 @@ PostgreSQL-Specific Index Options
Several extensions to the :class:`.Index` construct are available, specific
to the PostgreSQL dialect.
.. _postgresql_covering_indexes:
Covering Indexes
^^^^^^^^^^^^^^^^
@@ -990,6 +992,10 @@ would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)``
Note that this feature requires PostgreSQL 11 or later.
.. seealso::
:ref:`postgresql_constraint_options`
.. versionadded:: 1.4
.. _postgresql_partial_indexes:
@@ -1264,6 +1270,65 @@ with selected constraint constructs:
<https://www.postgresql.org/docs/current/static/sql-altertable.html>`_ -
in the PostgreSQL documentation.
* ``INCLUDE``: This option adds one or more columns as a "payload" to the
unique index created automatically by PostgreSQL for the constraint.
For example, the following table definition::
Table(
"mytable",
metadata,
Column("id", Integer, nullable=False),
Column("value", Integer, nullable=False),
UniqueConstraint("id", postgresql_include=["value"]),
)
would produce the DDL statement
.. sourcecode:: sql
CREATE TABLE mytable (
id INTEGER NOT NULL,
value INTEGER NOT NULL,
UNIQUE (id) INCLUDE (value)
)
Note that this feature requires PostgreSQL 11 or later.
.. versionadded:: 2.0.41
.. seealso::
:ref:`postgresql_covering_indexes`
.. seealso::
`PostgreSQL CREATE TABLE options
<https://www.postgresql.org/docs/current/static/sql-createtable.html>`_ -
in the PostgreSQL documentation.
* Column list with foreign key ``ON DELETE SET`` actions: This applies to
:class:`.ForeignKey` and :class:`.ForeignKeyConstraint`, the :paramref:`.ForeignKey.ondelete`
parameter will accept on the PostgreSQL backend only a string list of column
names inside parenthesis, following the ``SET NULL`` or ``SET DEFAULT``
phrases, which will limit the set of columns that are subject to the
action::
fktable = Table(
"fktable",
metadata,
Column("tid", Integer),
Column("id", Integer),
Column("fk_id_del_set_null", Integer),
ForeignKeyConstraint(
columns=["tid", "fk_id_del_set_null"],
refcolumns=[pktable.c.tid, pktable.c.id],
ondelete="SET NULL (fk_id_del_set_null)",
),
)
.. versionadded:: 2.0.40
.. _postgresql_table_valued_overview:
Table values, Table and Column valued functions, Row and Tuple objects
@@ -1482,6 +1547,7 @@ from functools import lru_cache
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
@@ -1672,6 +1738,7 @@ RESERVED_WORDS = {
"verbose",
}
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
@@ -1788,6 +1855,8 @@ class PGCompiler(compiler.SQLCompiler):
}"""
def visit_array(self, element, **kw):
if not element.clauses and not element.type.item_type._isnull:
return "ARRAY[]::%s" % element.type.compile(self.dialect)
return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
def visit_slice(self, element, **kw):
@@ -1811,9 +1880,23 @@ class PGCompiler(compiler.SQLCompiler):
kw["eager_grouping"] = True
return self._generate_generic_binary(
binary, " -> " if not _cast_applied else " ->> ", **kw
)
if (
not _cast_applied
and isinstance(binary.left.type, _json.JSONB)
and self.dialect._supports_jsonb_subscripting
):
# for pg14+JSONB use subscript notation: col['key'] instead
# of col -> 'key'
return "%s[%s]" % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
)
else:
# Fall back to arrow notation for older versions or when cast
# is applied
return self._generate_generic_binary(
binary, " -> " if not _cast_applied else " ->> ", **kw
)
def visit_json_path_getitem_op_binary(
self, binary, operator, _cast_applied=False, **kw
@@ -2001,9 +2084,10 @@ class PGCompiler(compiler.SQLCompiler):
for c in select._for_update_arg.of:
tables.update(sql_util.surface_selectables_only(c))
of_kw = dict(kw)
of_kw.update(ashint=True, use_schema=False)
tmp += " OF " + ", ".join(
self.process(table, ashint=True, use_schema=False, **kw)
for table in tables
self.process(table, **of_kw) for table in tables
)
if select._for_update_arg.nowait:
@@ -2232,6 +2316,18 @@ class PGDDLCompiler(compiler.DDLCompiler):
not_valid = constraint.dialect_options["postgresql"]["not_valid"]
return " NOT VALID" if not_valid else ""
def _define_include(self, obj):
includeclause = obj.dialect_options["postgresql"]["include"]
if not includeclause:
return ""
inclusions = [
obj.table.c[col] if isinstance(col, str) else col
for col in includeclause
]
return " INCLUDE (%s)" % ", ".join(
[self.preparer.quote(c.name) for c in inclusions]
)
def visit_check_constraint(self, constraint, **kw):
if constraint._type_bound:
typ = list(constraint.columns)[0].type
@@ -2255,6 +2351,29 @@ class PGDDLCompiler(compiler.DDLCompiler):
text += self._define_constraint_validity(constraint)
return text
def visit_primary_key_constraint(self, constraint, **kw):
text = super().visit_primary_key_constraint(constraint)
text += self._define_include(constraint)
return text
def visit_unique_constraint(self, constraint, **kw):
text = super().visit_unique_constraint(constraint)
text += self._define_include(constraint)
return text
@util.memoized_property
def _fk_ondelete_pattern(self):
return re.compile(
r"^(?:RESTRICT|CASCADE|SET (?:NULL|DEFAULT)(?:\s*\(.+\))?"
r"|NO ACTION)$",
re.I,
)
def define_constraint_ondelete_cascade(self, constraint):
return " ON DELETE %s" % self.preparer.validate_sql_phrase(
constraint.ondelete, self._fk_ondelete_pattern
)
def visit_create_enum_type(self, create, **kw):
type_ = create.element
@@ -2356,15 +2475,7 @@ class PGDDLCompiler(compiler.DDLCompiler):
)
)
includeclause = index.dialect_options["postgresql"]["include"]
if includeclause:
inclusions = [
index.table.c[col] if isinstance(col, str) else col
for col in includeclause
]
text += " INCLUDE (%s)" % ", ".join(
[preparer.quote(c.name) for c in inclusions]
)
text += self._define_include(index)
nulls_not_distinct = index.dialect_options["postgresql"][
"nulls_not_distinct"
@@ -3112,9 +3223,16 @@ class PGDialect(default.DefaultDialect):
"not_valid": False,
},
),
(
schema.PrimaryKeyConstraint,
{"include": None},
),
(
schema.UniqueConstraint,
{"nulls_not_distinct": None},
{
"include": None,
"nulls_not_distinct": None,
},
),
]
@@ -3123,6 +3241,7 @@ class PGDialect(default.DefaultDialect):
_backslash_escapes = True
_supports_create_index_concurrently = True
_supports_drop_index_concurrently = True
_supports_jsonb_subscripting = True
def __init__(
self,
@@ -3151,6 +3270,8 @@ class PGDialect(default.DefaultDialect):
)
self.supports_identity_columns = self.server_version_info >= (10,)
self._supports_jsonb_subscripting = self.server_version_info >= (14,)
def get_isolation_level_values(self, dbapi_conn):
# note the generic dialect doesn't have AUTOCOMMIT, however
# all postgresql dialects should include AUTOCOMMIT.
@@ -3601,6 +3722,7 @@ class PGDialect(default.DefaultDialect):
pg_catalog.pg_sequence.c.seqcache,
"cycle",
pg_catalog.pg_sequence.c.seqcycle,
type_=sqltypes.JSON(),
)
)
.select_from(pg_catalog.pg_sequence)
@@ -3742,8 +3864,8 @@ class PGDialect(default.DefaultDialect):
def _reflect_type(
self,
format_type: Optional[str],
domains: dict[str, ReflectedDomain],
enums: dict[str, ReflectedEnum],
domains: Dict[str, ReflectedDomain],
enums: Dict[str, ReflectedEnum],
type_description: str,
) -> sqltypes.TypeEngine[Any]:
"""
@@ -3813,7 +3935,8 @@ class PGDialect(default.DefaultDialect):
charlen = int(attype_args[0])
args = (charlen,)
elif attype.startswith("interval"):
# a domain or enum can start with interval, so be mindful of that.
elif attype == "interval" or attype.startswith("interval "):
schema_type = INTERVAL
field_match = re.match(r"interval (.+)", attype)
@@ -3830,7 +3953,6 @@ class PGDialect(default.DefaultDialect):
schema_type = ENUM
enum = enums[enum_or_domain_key]
args = tuple(enum["labels"])
kwargs["name"] = enum["name"]
if not enum["visible"]:
@@ -3995,21 +4117,35 @@ class PGDialect(default.DefaultDialect):
result = connection.execute(oid_q, params)
return result.all()
@lru_cache()
def _constraint_query(self, is_unique):
@util.memoized_property
def _constraint_query(self):
if self.server_version_info >= (11, 0):
indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
else:
indnkeyatts = pg_catalog.pg_index.c.indnatts.label("indnkeyatts")
if self.server_version_info >= (15,):
indnullsnotdistinct = pg_catalog.pg_index.c.indnullsnotdistinct
else:
indnullsnotdistinct = sql.false().label("indnullsnotdistinct")
con_sq = (
select(
pg_catalog.pg_constraint.c.conrelid,
pg_catalog.pg_constraint.c.conname,
pg_catalog.pg_constraint.c.conindid,
sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
"attnum"
),
sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
sql.func.generate_subscripts(
pg_catalog.pg_constraint.c.conkey, 1
pg_catalog.pg_index.c.indkey, 1
).label("ord"),
indnkeyatts,
indnullsnotdistinct,
pg_catalog.pg_description.c.description,
)
.join(
pg_catalog.pg_index,
pg_catalog.pg_constraint.c.conindid
== pg_catalog.pg_index.c.indexrelid,
)
.outerjoin(
pg_catalog.pg_description,
pg_catalog.pg_description.c.objoid
@@ -4018,6 +4154,9 @@ class PGDialect(default.DefaultDialect):
.where(
pg_catalog.pg_constraint.c.contype == bindparam("contype"),
pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
# NOTE: filtering also on pg_index.indrelid for oids does
# not seem to have a performance effect, but it may be an
# option if perf problems are reported
)
.subquery("con")
)
@@ -4026,9 +4165,10 @@ class PGDialect(default.DefaultDialect):
select(
con_sq.c.conrelid,
con_sq.c.conname,
con_sq.c.conindid,
con_sq.c.description,
con_sq.c.ord,
con_sq.c.indnkeyatts,
con_sq.c.indnullsnotdistinct,
pg_catalog.pg_attribute.c.attname,
)
.select_from(pg_catalog.pg_attribute)
@@ -4051,7 +4191,7 @@ class PGDialect(default.DefaultDialect):
.subquery("attr")
)
constraint_query = (
return (
select(
attr_sq.c.conrelid,
sql.func.array_agg(
@@ -4063,31 +4203,15 @@ class PGDialect(default.DefaultDialect):
).label("cols"),
attr_sq.c.conname,
sql.func.min(attr_sq.c.description).label("description"),
sql.func.min(attr_sq.c.indnkeyatts).label("indnkeyatts"),
sql.func.bool_and(attr_sq.c.indnullsnotdistinct).label(
"indnullsnotdistinct"
),
)
.group_by(attr_sq.c.conrelid, attr_sq.c.conname)
.order_by(attr_sq.c.conrelid, attr_sq.c.conname)
)
if is_unique:
if self.server_version_info >= (15,):
constraint_query = constraint_query.join(
pg_catalog.pg_index,
attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid,
).add_columns(
sql.func.bool_and(
pg_catalog.pg_index.c.indnullsnotdistinct
).label("indnullsnotdistinct")
)
else:
constraint_query = constraint_query.add_columns(
sql.false().label("indnullsnotdistinct")
)
else:
constraint_query = constraint_query.add_columns(
sql.null().label("extra")
)
return constraint_query
def _reflect_constraint(
self, connection, contype, schema, filter_names, scope, kind, **kw
):
@@ -4103,26 +4227,42 @@ class PGDialect(default.DefaultDialect):
batches[0:3000] = []
result = connection.execute(
self._constraint_query(is_unique),
self._constraint_query,
{"oids": [r[0] for r in batch], "contype": contype},
)
).mappings()
result_by_oid = defaultdict(list)
for oid, cols, constraint_name, comment, extra in result:
result_by_oid[oid].append(
(cols, constraint_name, comment, extra)
)
for row_dict in result:
result_by_oid[row_dict["conrelid"]].append(row_dict)
for oid, tablename in batch:
for_oid = result_by_oid.get(oid, ())
if for_oid:
for cols, constraint, comment, extra in for_oid:
if is_unique:
yield tablename, cols, constraint, comment, {
"nullsnotdistinct": extra
}
for row in for_oid:
# See note in get_multi_indexes
all_cols = row["cols"]
indnkeyatts = row["indnkeyatts"]
if len(all_cols) > indnkeyatts:
inc_cols = all_cols[indnkeyatts:]
cst_cols = all_cols[:indnkeyatts]
else:
yield tablename, cols, constraint, comment, None
inc_cols = []
cst_cols = all_cols
opts = {}
if self.server_version_info >= (11,):
opts["postgresql_include"] = inc_cols
if is_unique:
opts["postgresql_nulls_not_distinct"] = row[
"indnullsnotdistinct"
]
yield (
tablename,
cst_cols,
row["conname"],
row["description"],
opts,
)
else:
yield tablename, None, None, None, None
@@ -4148,20 +4288,27 @@ class PGDialect(default.DefaultDialect):
# only a single pk can be present for each table. Return an entry
# even if a table has no primary key
default = ReflectionDefaults.pk_constraint
def pk_constraint(pk_name, cols, comment, opts):
info = {
"constrained_columns": cols,
"name": pk_name,
"comment": comment,
}
if opts:
info["dialect_options"] = opts
return info
return (
(
(schema, table_name),
(
{
"constrained_columns": [] if cols is None else cols,
"name": pk_name,
"comment": comment,
}
pk_constraint(pk_name, cols, comment, opts)
if pk_name is not None
else default()
),
)
for table_name, cols, pk_name, comment, _ in result
for table_name, cols, pk_name, comment, opts in result
)
@reflection.cache
@@ -4255,7 +4402,8 @@ class PGDialect(default.DefaultDialect):
r"[\s]?(ON UPDATE "
r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
r"[\s]?(ON DELETE "
r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?"
r"(CASCADE|RESTRICT|NO ACTION|"
r"SET (?:NULL|DEFAULT)(?:\s\(.+\))?)+)?"
r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?"
r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
)
@@ -4371,7 +4519,10 @@ class PGDialect(default.DefaultDialect):
@util.memoized_property
def _index_query(self):
pg_class_index = pg_catalog.pg_class.alias("cls_idx")
# NOTE: pg_index is used as from two times to improve performance,
# since extraing all the index information from `idx_sq` to avoid
# the second pg_index use leads to a worse performing query in
# particular when querying for a single table (as of pg 17)
# NOTE: repeating oids clause improve query performance
# subquery to get the columns
@@ -4380,6 +4531,9 @@ class PGDialect(default.DefaultDialect):
pg_catalog.pg_index.c.indexrelid,
pg_catalog.pg_index.c.indrelid,
sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
sql.func.unnest(pg_catalog.pg_index.c.indclass).label(
"att_opclass"
),
sql.func.generate_subscripts(
pg_catalog.pg_index.c.indkey, 1
).label("ord"),
@@ -4411,6 +4565,8 @@ class PGDialect(default.DefaultDialect):
else_=pg_catalog.pg_attribute.c.attname.cast(TEXT),
).label("element"),
(idx_sq.c.attnum == 0).label("is_expr"),
pg_catalog.pg_opclass.c.opcname,
pg_catalog.pg_opclass.c.opcdefault,
)
.select_from(idx_sq)
.outerjoin(
@@ -4421,6 +4577,10 @@ class PGDialect(default.DefaultDialect):
pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
),
)
.outerjoin(
pg_catalog.pg_opclass,
pg_catalog.pg_opclass.c.oid == idx_sq.c.att_opclass,
)
.where(idx_sq.c.indrelid.in_(bindparam("oids")))
.subquery("idx_attr")
)
@@ -4435,6 +4595,12 @@ class PGDialect(default.DefaultDialect):
sql.func.array_agg(
aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord)
).label("elements_is_expr"),
sql.func.array_agg(
aggregate_order_by(attr_sq.c.opcname, attr_sq.c.ord)
).label("elements_opclass"),
sql.func.array_agg(
aggregate_order_by(attr_sq.c.opcdefault, attr_sq.c.ord)
).label("elements_opdefault"),
)
.group_by(attr_sq.c.indexrelid)
.subquery("idx_cols")
@@ -4443,7 +4609,7 @@ class PGDialect(default.DefaultDialect):
if self.server_version_info >= (11, 0):
indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
else:
indnkeyatts = sql.null().label("indnkeyatts")
indnkeyatts = pg_catalog.pg_index.c.indnatts.label("indnkeyatts")
if self.server_version_info >= (15,):
nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct
@@ -4453,13 +4619,13 @@ class PGDialect(default.DefaultDialect):
return (
select(
pg_catalog.pg_index.c.indrelid,
pg_class_index.c.relname.label("relname_index"),
pg_catalog.pg_class.c.relname,
pg_catalog.pg_index.c.indisunique,
pg_catalog.pg_constraint.c.conrelid.is_not(None).label(
"has_constraint"
),
pg_catalog.pg_index.c.indoption,
pg_class_index.c.reloptions,
pg_catalog.pg_class.c.reloptions,
pg_catalog.pg_am.c.amname,
# NOTE: pg_get_expr is very fast so this case has almost no
# performance impact
@@ -4477,6 +4643,8 @@ class PGDialect(default.DefaultDialect):
nulls_not_distinct,
cols_sq.c.elements,
cols_sq.c.elements_is_expr,
cols_sq.c.elements_opclass,
cols_sq.c.elements_opdefault,
)
.select_from(pg_catalog.pg_index)
.where(
@@ -4484,12 +4652,12 @@ class PGDialect(default.DefaultDialect):
~pg_catalog.pg_index.c.indisprimary,
)
.join(
pg_class_index,
pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid,
pg_catalog.pg_class,
pg_catalog.pg_index.c.indexrelid == pg_catalog.pg_class.c.oid,
)
.join(
pg_catalog.pg_am,
pg_class_index.c.relam == pg_catalog.pg_am.c.oid,
pg_catalog.pg_class.c.relam == pg_catalog.pg_am.c.oid,
)
.outerjoin(
cols_sq,
@@ -4506,7 +4674,9 @@ class PGDialect(default.DefaultDialect):
== sql.any_(_array.array(("p", "u", "x"))),
),
)
.order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname)
.order_by(
pg_catalog.pg_index.c.indrelid, pg_catalog.pg_class.c.relname
)
)
def get_multi_indexes(
@@ -4541,17 +4711,19 @@ class PGDialect(default.DefaultDialect):
continue
for row in result_by_oid[oid]:
index_name = row["relname_index"]
index_name = row["relname"]
table_indexes = indexes[(schema, table_name)]
all_elements = row["elements"]
all_elements_is_expr = row["elements_is_expr"]
all_elements_opclass = row["elements_opclass"]
all_elements_opdefault = row["elements_opdefault"]
indnkeyatts = row["indnkeyatts"]
# "The number of key columns in the index, not counting any
# included columns, which are merely stored and do not
# participate in the index semantics"
if indnkeyatts and len(all_elements) > indnkeyatts:
if len(all_elements) > indnkeyatts:
# this is a "covering index" which has INCLUDE columns
# as well as regular index columns
inc_cols = all_elements[indnkeyatts:]
@@ -4566,10 +4738,18 @@ class PGDialect(default.DefaultDialect):
not is_expr
for is_expr in all_elements_is_expr[indnkeyatts:]
)
idx_elements_opclass = all_elements_opclass[
:indnkeyatts
]
idx_elements_opdefault = all_elements_opdefault[
:indnkeyatts
]
else:
idx_elements = all_elements
idx_elements_is_expr = all_elements_is_expr
inc_cols = []
idx_elements_opclass = all_elements_opclass
idx_elements_opdefault = all_elements_opdefault
index = {"name": index_name, "unique": row["indisunique"]}
if any(idx_elements_is_expr):
@@ -4583,6 +4763,19 @@ class PGDialect(default.DefaultDialect):
else:
index["column_names"] = idx_elements
dialect_options = {}
if not all(idx_elements_opdefault):
dialect_options["postgresql_ops"] = {
name: opclass
for name, opclass, is_default in zip(
idx_elements,
idx_elements_opclass,
idx_elements_opdefault,
)
if not is_default
}
sorting = {}
for col_index, col_flags in enumerate(row["indoption"]):
col_sorting = ()
@@ -4602,7 +4795,6 @@ class PGDialect(default.DefaultDialect):
if row["has_constraint"]:
index["duplicates_constraint"] = index_name
dialect_options = {}
if row["reloptions"]:
dialect_options["postgresql_with"] = dict(
[
@@ -4681,12 +4873,7 @@ class PGDialect(default.DefaultDialect):
"comment": comment,
}
if options:
if options["nullsnotdistinct"]:
uc_dict["dialect_options"] = {
"postgresql_nulls_not_distinct": options[
"nullsnotdistinct"
]
}
uc_dict["dialect_options"] = options
uniques[(schema, table_name)].append(uc_dict)
return uniques.items()
@@ -5010,11 +5197,12 @@ class PGDialect(default.DefaultDialect):
key=lambda t: t[0],
)
for name, def_ in sorted_constraints:
# constraint is in the form "CHECK (expression)".
# constraint is in the form "CHECK (expression)"
# or "NOT NULL". Ignore the "NOT NULL" and
# remove "CHECK (" and the tailing ")".
check = def_[7:-1]
constraints.append({"name": name, "check": check})
if def_.casefold().startswith("check"):
check = def_[7:-1]
constraints.append({"name": name, "check": check})
domain_rec: ReflectedDomain = {
"name": domain["name"],
"schema": domain["schema"],

View File

@@ -8,6 +8,10 @@
from __future__ import annotations
from typing import Any
from typing import Iterable
from typing import List
from typing import Optional
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
@@ -23,13 +27,19 @@ from ...sql.schema import ColumnCollectionConstraint
from ...sql.sqltypes import TEXT
from ...sql.visitors import InternalTraversal
_T = TypeVar("_T", bound=Any)
if TYPE_CHECKING:
from ...sql._typing import _ColumnExpressionArgument
from ...sql.elements import ClauseElement
from ...sql.elements import ColumnElement
from ...sql.operators import OperatorType
from ...sql.selectable import FromClause
from ...sql.visitors import _CloneCallableType
from ...sql.visitors import _TraverseInternalsType
_T = TypeVar("_T", bound=Any)
class aggregate_order_by(expression.ColumnElement):
class aggregate_order_by(expression.ColumnElement[_T]):
"""Represent a PostgreSQL aggregate order by expression.
E.g.::
@@ -75,11 +85,32 @@ class aggregate_order_by(expression.ColumnElement):
("order_by", InternalTraversal.dp_clauseelement),
]
def __init__(self, target, *order_by):
self.target = coercions.expect(roles.ExpressionElementRole, target)
@overload
def __init__(
self,
target: ColumnElement[_T],
*order_by: _ColumnExpressionArgument[Any],
): ...
@overload
def __init__(
self,
target: _ColumnExpressionArgument[_T],
*order_by: _ColumnExpressionArgument[Any],
): ...
def __init__(
self,
target: _ColumnExpressionArgument[_T],
*order_by: _ColumnExpressionArgument[Any],
):
self.target: ClauseElement = coercions.expect(
roles.ExpressionElementRole, target
)
self.type = self.target.type
_lob = len(order_by)
self.order_by: ClauseElement
if _lob == 0:
raise TypeError("at least one ORDER BY element is required")
elif _lob == 1:
@@ -91,18 +122,22 @@ class aggregate_order_by(expression.ColumnElement):
*order_by, _literal_as_text_role=roles.ExpressionElementRole
)
def self_group(self, against=None):
def self_group(
self, against: Optional[OperatorType] = None
) -> ClauseElement:
return self
def get_children(self, **kwargs):
def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]:
return self.target, self.order_by
def _copy_internals(self, clone=elements._clone, **kw):
def _copy_internals(
self, clone: _CloneCallableType = elements._clone, **kw: Any
) -> None:
self.target = clone(self.target, **kw)
self.order_by = clone(self.order_by, **kw)
@property
def _from_objects(self):
def _from_objects(self) -> List[FromClause]:
return self.target._from_objects + self.order_by._from_objects

View File

@@ -4,8 +4,15 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from .array import ARRAY
from .array import array as _pg_array
@@ -21,13 +28,23 @@ from .operators import PATH_EXISTS
from .operators import PATH_MATCH
from ... import types as sqltypes
from ...sql import cast
from ...sql._typing import _T
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.elements import ColumnElement
from ...sql.type_api import _BindProcessorType
from ...sql.type_api import _LiteralProcessorType
from ...sql.type_api import TypeEngine
__all__ = ("JSON", "JSONB")
class JSONPathType(sqltypes.JSON.JSONPathType):
def _processor(self, dialect, super_proc):
def process(value):
def _processor(
self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]]
) -> Callable[[Any], Any]:
def process(value: Any) -> Any:
if isinstance(value, str):
# If it's already a string assume that it's in json path
# format. This allows using cast with json paths literals
@@ -44,11 +61,13 @@ class JSONPathType(sqltypes.JSON.JSONPathType):
return process
def bind_processor(self, dialect):
return self._processor(dialect, self.string_bind_processor(dialect))
def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]:
return self._processor(dialect, self.string_bind_processor(dialect)) # type: ignore[return-value] # noqa: E501
def literal_processor(self, dialect):
return self._processor(dialect, self.string_literal_processor(dialect))
def literal_processor(
self, dialect: Dialect
) -> _LiteralProcessorType[Any]:
return self._processor(dialect, self.string_literal_processor(dialect)) # type: ignore[return-value] # noqa: E501
class JSONPATH(JSONPathType):
@@ -148,9 +167,13 @@ class JSON(sqltypes.JSON):
""" # noqa
render_bind_cast = True
astext_type = sqltypes.Text()
astext_type: TypeEngine[str] = sqltypes.Text()
def __init__(self, none_as_null=False, astext_type=None):
def __init__(
self,
none_as_null: bool = False,
astext_type: Optional[TypeEngine[str]] = None,
):
"""Construct a :class:`_types.JSON` type.
:param none_as_null: if True, persist the value ``None`` as a
@@ -175,11 +198,13 @@ class JSON(sqltypes.JSON):
if astext_type is not None:
self.astext_type = astext_type
class Comparator(sqltypes.JSON.Comparator):
class Comparator(sqltypes.JSON.Comparator[_T]):
"""Define comparison operations for :class:`_types.JSON`."""
type: JSON
@property
def astext(self):
def astext(self) -> ColumnElement[str]:
"""On an indexed expression, use the "astext" (e.g. "->>")
conversion when rendered in SQL.
@@ -193,13 +218,13 @@ class JSON(sqltypes.JSON):
"""
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
return self.expr.left.operate(
return self.expr.left.operate( # type: ignore[no-any-return]
JSONPATH_ASTEXT,
self.expr.right,
result_type=self.type.astext_type,
)
else:
return self.expr.left.operate(
return self.expr.left.operate( # type: ignore[no-any-return]
ASTEXT, self.expr.right, result_type=self.type.astext_type
)
@@ -258,28 +283,30 @@ class JSONB(JSON):
__visit_name__ = "JSONB"
class Comparator(JSON.Comparator):
class Comparator(JSON.Comparator[_T]):
"""Define comparison operations for :class:`_types.JSON`."""
def has_key(self, other):
type: JSONB
def has_key(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of a key (equivalent of
the ``?`` operator). Note that the key may be a SQLA expression.
"""
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
def has_all(self, other):
def has_all(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of all keys in jsonb
(equivalent of the ``?&`` operator)
"""
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
def has_any(self, other):
def has_any(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of any key in jsonb
(equivalent of the ``?|`` operator)
"""
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
def contains(self, other, **kwargs):
def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]:
"""Boolean expression. Test if keys (or array) are a superset
of/contained the keys of the argument jsonb expression
(equivalent of the ``@>`` operator).
@@ -289,7 +316,7 @@ class JSONB(JSON):
"""
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
def contained_by(self, other):
def contained_by(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test if keys are a proper subset of the
keys of the argument jsonb expression
(equivalent of the ``<@`` operator).
@@ -298,7 +325,9 @@ class JSONB(JSON):
CONTAINED_BY, other, result_type=sqltypes.Boolean
)
def delete_path(self, array):
def delete_path(
self, array: Union[List[str], _pg_array[str]]
) -> ColumnElement[JSONB]:
"""JSONB expression. Deletes field or array element specified in
the argument array (equivalent of the ``#-`` operator).
@@ -312,7 +341,7 @@ class JSONB(JSON):
right_side = cast(array, ARRAY(sqltypes.TEXT))
return self.operate(DELETE_PATH, right_side, result_type=JSONB)
def path_exists(self, other):
def path_exists(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test for presence of item given by the
argument JSONPath expression (equivalent of the ``@?`` operator).
@@ -322,7 +351,7 @@ class JSONB(JSON):
PATH_EXISTS, other, result_type=sqltypes.Boolean
)
def path_match(self, other):
def path_match(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Test if JSONPath predicate given by the
argument JSONPath expression matches
(equivalent of the ``@@`` operator).

View File

@@ -7,7 +7,9 @@
# mypy: ignore-errors
from __future__ import annotations
from types import ModuleType
from typing import Any
from typing import Dict
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
@@ -25,10 +27,11 @@ from ...sql.ddl import InvokeCreateDDLBase
from ...sql.ddl import InvokeDropDDLBase
if TYPE_CHECKING:
from ...sql._typing import _CreateDropBind
from ...sql._typing import _TypeEngineArgument
class NamedType(sqltypes.TypeEngine):
class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine):
"""Base for named types."""
__abstract__ = True
@@ -36,7 +39,9 @@ class NamedType(sqltypes.TypeEngine):
DDLDropper: Type[NamedTypeDropper]
create_type: bool
def create(self, bind, checkfirst=True, **kw):
def create(
self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
) -> None:
"""Emit ``CREATE`` DDL for this type.
:param bind: a connectable :class:`_engine.Engine`,
@@ -50,7 +55,9 @@ class NamedType(sqltypes.TypeEngine):
"""
bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
def drop(self, bind, checkfirst=True, **kw):
def drop(
self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any
) -> None:
"""Emit ``DROP`` DDL for this type.
:param bind: a connectable :class:`_engine.Engine`,
@@ -63,7 +70,9 @@ class NamedType(sqltypes.TypeEngine):
"""
bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
def _check_for_name_in_memos(self, checkfirst, kw):
def _check_for_name_in_memos(
self, checkfirst: bool, kw: Dict[str, Any]
) -> bool:
"""Look in the 'ddl runner' for 'memos', then
note our name in that collection.
@@ -87,7 +96,13 @@ class NamedType(sqltypes.TypeEngine):
else:
return False
def _on_table_create(self, target, bind, checkfirst=False, **kw):
def _on_table_create(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
if (
checkfirst
or (
@@ -97,7 +112,13 @@ class NamedType(sqltypes.TypeEngine):
) and not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
def _on_table_drop(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
if (
not self.metadata
and not kw.get("_is_metadata_operation", False)
@@ -105,11 +126,23 @@ class NamedType(sqltypes.TypeEngine):
):
self.drop(bind=bind, checkfirst=checkfirst)
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
def _on_metadata_create(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
if not self._check_for_name_in_memos(checkfirst, kw):
self.create(bind=bind, checkfirst=checkfirst)
def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
def _on_metadata_drop(
self,
target: Any,
bind: _CreateDropBind,
checkfirst: bool = False,
**kw: Any,
) -> None:
if not self._check_for_name_in_memos(checkfirst, kw):
self.drop(bind=bind, checkfirst=checkfirst)
@@ -314,7 +347,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
return cls(**kw)
def create(self, bind=None, checkfirst=True):
def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
"""Emit ``CREATE TYPE`` for this
:class:`_postgresql.ENUM`.
@@ -335,7 +368,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
super().create(bind, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None:
"""Emit ``DROP TYPE`` for this
:class:`_postgresql.ENUM`.
@@ -355,7 +388,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
super().drop(bind, checkfirst=checkfirst)
def get_dbapi_type(self, dbapi):
def get_dbapi_type(self, dbapi: ModuleType) -> None:
"""dont return dbapi.STRING for ENUM in PostgreSQL, since that's
a different type"""
@@ -470,20 +503,6 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
def __test_init__(cls):
return cls("name", sqltypes.Integer)
def adapt(self, impl, **kw):
if self.default:
kw["default"] = self.default
if self.constraint_name is not None:
kw["constraint_name"] = self.constraint_name
if self.not_null:
kw["not_null"] = self.not_null
if self.check is not None:
kw["check"] = str(self.check)
if self.create_type:
kw["create_type"] = self.create_type
return super().adapt(impl, **kw)
class CreateEnumType(schema._CreateDropBase):
__visit_name__ = "create_enum_type"

View File

@@ -540,6 +540,9 @@ class PGDialect_pg8000(PGDialect):
cursor.execute("COMMIT")
cursor.close()
def detect_autocommit_setting(self, dbapi_conn) -> bool:
return bool(dbapi_conn.autocommit)
def set_readonly(self, connection, value):
cursor = connection.cursor()
try:

View File

@@ -4,7 +4,13 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import Sequence
from typing import TYPE_CHECKING
from .array import ARRAY
from .types import OID
@@ -23,31 +29,37 @@ from ...types import String
from ...types import Text
from ...types import TypeDecorator
if TYPE_CHECKING:
from ...engine.interfaces import Dialect
from ...sql.type_api import _ResultProcessorType
# types
class NAME(TypeDecorator):
class NAME(TypeDecorator[str]):
impl = String(64, collation="C")
cache_ok = True
class PG_NODE_TREE(TypeDecorator):
class PG_NODE_TREE(TypeDecorator[str]):
impl = Text(collation="C")
cache_ok = True
class INT2VECTOR(TypeDecorator):
class INT2VECTOR(TypeDecorator[Sequence[int]]):
impl = ARRAY(SmallInteger)
cache_ok = True
class OIDVECTOR(TypeDecorator):
class OIDVECTOR(TypeDecorator[Sequence[int]]):
impl = ARRAY(OID)
cache_ok = True
class _SpaceVector:
def result_processor(self, dialect, coltype):
def process(value):
def result_processor(
self, dialect: Dialect, coltype: object
) -> _ResultProcessorType[list[int]]:
def process(value: Any) -> Optional[list[int]]:
if value is None:
return value
return [int(p) for p in value.split(" ")]
@@ -298,3 +310,17 @@ pg_collation = Table(
Column("collicurules", Text, info={"server_version": (16,)}),
Column("collversion", Text, info={"server_version": (10,)}),
)
pg_opclass = Table(
"pg_opclass",
pg_catalog_meta,
Column("oid", OID, info={"server_version": (9, 3)}),
Column("opcmethod", NAME),
Column("opcname", NAME),
Column("opsnamespace", OID),
Column("opsowner", OID),
Column("opcfamily", OID),
Column("opcintype", OID),
Column("opcdefault", Boolean),
Column("opckeytype", OID),
)

View File

@@ -271,9 +271,9 @@ class Range(Generic[_T]):
value2 += step
value2_inc = False
if value1 < value2: # type: ignore
if value1 < value2:
return -1
elif value1 > value2: # type: ignore
elif value1 > value2:
return 1
elif only_values:
return 0

View File

@@ -52,28 +52,38 @@ class BYTEA(sqltypes.LargeBinary):
__visit_name__ = "BYTEA"
class INET(sqltypes.TypeEngine[str]):
class _NetworkAddressTypeMixin:
def coerce_compared_value(
self, op: Optional[OperatorType], value: Any
) -> TypeEngine[Any]:
if TYPE_CHECKING:
assert isinstance(self, TypeEngine)
return self
class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "INET"
PGInet = INET
class CIDR(sqltypes.TypeEngine[str]):
class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "CIDR"
PGCidr = CIDR
class MACADDR(sqltypes.TypeEngine[str]):
class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR"
PGMacAddr = MACADDR
class MACADDR8(sqltypes.TypeEngine[str]):
class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]):
__visit_name__ = "MACADDR8"