添加注册登录功能

This commit is contained in:
2025-08-29 00:34:40 +08:00
parent 09065f2ce7
commit 2fe3474d9e
3060 changed files with 29217 additions and 87137 deletions

View File

@@ -74,38 +74,49 @@ from .base import _de_clone
from .base import _from_objects
from .base import _NONE_NAME
from .base import _SentinelDefaultCharacterization
from .base import Executable
from .base import NO_ARG
from .elements import ClauseElement
from .elements import quoted_name
from .schema import Column
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import prefix_anon_map
from .visitors import Visitable
from .. import exc
from .. import util
from ..util import FastIntFlag
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypedDict
if typing.TYPE_CHECKING:
from .annotation import _AnnotationDict
from .base import _AmbiguousTableNameMap
from .base import CompileState
from .base import Executable
from .cache_key import CacheKey
from .ddl import ExecutableDDLElement
from .dml import Insert
from .dml import Update
from .dml import UpdateBase
from .dml import UpdateDMLState
from .dml import ValuesBase
from .elements import _truncated_label
from .elements import BinaryExpression
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import False_
from .elements import Label
from .elements import Null
from .elements import True_
from .functions import Function
from .schema import Column
from .schema import Constraint
from .schema import ForeignKeyConstraint
from .schema import Index
from .schema import PrimaryKeyConstraint
from .schema import Table
from .schema import UniqueConstraint
from .selectable import _ColumnsClauseElement
from .selectable import AliasedReturnsRows
from .selectable import CompoundSelectState
from .selectable import CTE
@@ -115,6 +126,10 @@ if typing.TYPE_CHECKING:
from .selectable import Select
from .selectable import SelectState
from .type_api import _BindProcessorType
from .type_api import TypeDecorator
from .type_api import TypeEngine
from .type_api import UserDefinedType
from .visitors import Visitable
from ..engine.cursor import CursorResultMetaData
from ..engine.interfaces import _CoreSingleExecuteParams
from ..engine.interfaces import _DBAPIAnyExecuteParams
@@ -126,6 +141,7 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import Dialect
from ..engine.interfaces import SchemaTranslateMapType
_FromHintsType = Dict["FromClause", str]
RESERVED_WORDS = {
@@ -870,6 +886,7 @@ class Compiled:
self.string = self.process(self.statement, **compile_kwargs)
if render_schema_translate:
assert schema_translate_map is not None
self.string = self.preparer._render_schema_translates(
self.string, schema_translate_map
)
@@ -902,7 +919,7 @@ class Compiled:
raise exc.UnsupportedCompilationError(self, type(element)) from err
@property
def sql_compiler(self):
def sql_compiler(self) -> SQLCompiler:
"""Return a Compiled that is capable of processing SQL expressions.
If this compiler is one, it would likely just return 'self'.
@@ -1791,7 +1808,7 @@ class SQLCompiler(Compiled):
return len(self.stack) > 1
@property
def sql_compiler(self):
def sql_compiler(self) -> Self:
return self
def construct_expanded_state(
@@ -2297,10 +2314,7 @@ class SQLCompiler(Compiled):
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
if typing.TYPE_CHECKING:
from ..engine import result
else:
result = util.preloaded.engine_result
result = util.preloaded.engine_result
assert self.compile_state is not None
statement = self.compile_state.statement
@@ -2342,7 +2356,7 @@ class SQLCompiler(Compiled):
return get
def default_from(self):
def default_from(self) -> str:
"""Called when a SELECT statement has no froms, and no FROM clause is
to be appended.
@@ -2734,16 +2748,16 @@ class SQLCompiler(Compiled):
return text
def visit_null(self, expr, **kw):
def visit_null(self, expr: Null, **kw: Any) -> str:
return "NULL"
def visit_true(self, expr, **kw):
def visit_true(self, expr: True_, **kw: Any) -> str:
if self.dialect.supports_native_boolean:
return "true"
else:
return "1"
def visit_false(self, expr, **kw):
def visit_false(self, expr: False_, **kw: Any) -> str:
if self.dialect.supports_native_boolean:
return "false"
else:
@@ -2879,14 +2893,18 @@ class SQLCompiler(Compiled):
def visit_over(self, over, **kwargs):
text = over.element._compiler_dispatch(self, **kwargs)
if over.range_:
if over.range_ is not None:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
over.range_, **kwargs
)
elif over.rows:
elif over.rows is not None:
range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
over.rows, **kwargs
)
elif over.groups is not None:
range_ = "GROUPS BETWEEN %s" % self._format_frame_clause(
over.groups, **kwargs
)
else:
range_ = None
@@ -2985,7 +3003,7 @@ class SQLCompiler(Compiled):
% self.dialect.name
)
def function_argspec(self, func, **kwargs):
def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
return func.clause_expr._compiler_dispatch(self, **kwargs)
def visit_compound_select(
@@ -3449,8 +3467,12 @@ class SQLCompiler(Compiled):
)
def _generate_generic_binary(
self, binary, opstring, eager_grouping=False, **kw
):
self,
binary: BinaryExpression[Any],
opstring: str,
eager_grouping: bool = False,
**kw: Any,
) -> str:
_in_operator_expression = kw.get("_in_operator_expression", False)
kw["_in_operator_expression"] = True
@@ -3619,19 +3641,25 @@ class SQLCompiler(Compiled):
**kw,
)
def visit_regexp_match_op_binary(self, binary, operator, **kw):
def visit_regexp_match_op_binary(
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
raise exc.CompileError(
"%s dialect does not support regular expressions"
% self.dialect.name
)
def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
def visit_not_regexp_match_op_binary(
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
raise exc.CompileError(
"%s dialect does not support regular expressions"
% self.dialect.name
)
def visit_regexp_replace_op_binary(self, binary, operator, **kw):
def visit_regexp_replace_op_binary(
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
raise exc.CompileError(
"%s dialect does not support regular expression replacements"
% self.dialect.name
@@ -3838,7 +3866,9 @@ class SQLCompiler(Compiled):
else:
return self.render_literal_value(value, bindparam.type)
def render_literal_value(self, value, type_):
def render_literal_value(
self, value: Any, type_: sqltypes.TypeEngine[Any]
) -> str:
"""Render the value of a bind parameter as a quoted literal.
This is used for statement sections that do not accept bind parameters
@@ -4074,15 +4104,28 @@ class SQLCompiler(Compiled):
del self.level_name_by_cte[existing_cte_reference_cte]
else:
# if the two CTEs are deep-copy identical, consider them
# the same, **if** they are clones, that is, they came from
# the ORM or other visit method
if (
cte._is_clone_of is not None
or existing_cte._is_clone_of is not None
) and cte.compare(existing_cte):
# if the two CTEs have the same hash, which we expect
# here means that one/both is an annotated of the other
(hash(cte) == hash(existing_cte))
# or...
or (
(
# if they are clones, i.e. they came from the ORM
# or some other visit method
cte._is_clone_of is not None
or existing_cte._is_clone_of is not None
)
# and are deep-copy identical
and cte.compare(existing_cte)
)
):
# then consider these two CTEs the same
is_new_cte = False
else:
# otherwise these are two CTEs that either will render
# differently, or were indicated separately by the user,
# with the same name
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
"the same name: %r" % cte_name
@@ -4115,7 +4158,7 @@ class SQLCompiler(Compiled):
if cte.recursive:
self.ctes_recursive = True
text = self.preparer.format_alias(cte, cte_name)
if cte.recursive:
if cte.recursive or cte.element.name_cte_columns:
col_source = cte.element
# TODO: can we get at the .columns_plus_names collection
@@ -4184,7 +4227,7 @@ class SQLCompiler(Compiled):
if self.preparer._requires_quotes(cte_name):
cte_name = self.preparer.quote(cte_name)
text += self.get_render_as_alias_suffix(cte_name)
return text
return text # type: ignore[no-any-return]
else:
return self.preparer.format_alias(cte, cte_name)
@@ -4246,7 +4289,7 @@ class SQLCompiler(Compiled):
inner = "(%s)" % (inner,)
return inner
else:
enclosing_alias = kwargs["enclosing_alias"] = alias
kwargs["enclosing_alias"] = alias
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
@@ -4336,7 +4379,13 @@ class SQLCompiler(Compiled):
)
return f"VALUES {tuples}"
def visit_values(self, element, asfrom=False, from_linter=None, **kw):
def visit_values(
self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw
):
if element._independent_ctes:
self._dispatch_independent_ctes(element, kw)
v = self._render_values(element, **kw)
if element._unnamed:
@@ -4357,7 +4406,12 @@ class SQLCompiler(Compiled):
name if name is not None else "(unnamed VALUES element)"
)
if name:
if visiting_cte is not None and visiting_cte.element is element:
if element._is_lateral:
raise exc.CompileError(
"Can't use a LATERAL VALUES expression inside of a CTE"
)
elif name:
kw["include_table"] = False
v = "%s(%s)%s (%s)" % (
lateral,
@@ -4541,7 +4595,52 @@ class SQLCompiler(Compiled):
elif isinstance(column, elements.TextClause):
render_with_label = False
elif isinstance(column, elements.UnaryExpression):
render_with_label = column.wraps_column_expression or asfrom
# unary expression. notes added as of #12681
#
# By convention, the visit_unary() method
# itself does not add an entry to the result map, and relies
# upon either the inner expression creating a result map
# entry, or if not, by creating a label here that produces
# the result map entry. Where that happens is based on whether
# or not the element immediately inside the unary is a
# NamedColumn subclass or not.
#
# Now, this also impacts how the SELECT is written; if
# we decide to generate a label here, we get the usual
# "~(x+y) AS anon_1" thing in the columns clause. If we
# don't, we don't get an AS at all, we get like
# "~table.column".
#
# But here is the important thing as of modernish (like 1.4)
# versions of SQLAlchemy - **whether or not the AS <label>
# is present in the statement is not actually important**.
# We target result columns **positionally** for a fully
# compiled ``Select()`` object; before 1.4 we needed those
# labels to match in cursor.description etc etc but now it
# really doesn't matter.
# So really, we could set render_with_label True in all cases.
# Or we could just have visit_unary() populate the result map
# in all cases.
#
# What we're doing here is strictly trying to not rock the
# boat too much with when we do/don't render "AS label";
# labels being present helps in the edge cases that we
# "fall back" to named cursor.description matching, labels
# not being present for columns keeps us from having awkward
# phrases like "SELECT DISTINCT table.x AS x".
render_with_label = (
(
# exception case to detect if we render "not boolean"
# as "not <col>" for native boolean or "<col> = 1"
# for non-native boolean. this is controlled by
# visit_is_<true|false>_unary_operator
column.operator
in (operators.is_false, operators.is_true)
and not self.dialect.supports_native_boolean
)
or column._wraps_unnamed_column()
or asfrom
)
elif (
# general class of expressions that don't have a SQL-column
# addressible name. includes scalar selects, bind parameters,
@@ -4599,7 +4698,9 @@ class SQLCompiler(Compiled):
def get_select_hint_text(self, byfroms):
return None
def get_from_hint_text(self, table, text):
def get_from_hint_text(
self, table: FromClause, text: Optional[str]
) -> Optional[str]:
return None
def get_crud_hint_text(self, table, text):
@@ -5084,7 +5185,7 @@ class SQLCompiler(Compiled):
else:
return "WITH"
def get_select_precolumns(self, select, **kw):
def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
"""Called when building a ``SELECT`` statement, position is just
before column list.
@@ -5129,7 +5230,7 @@ class SQLCompiler(Compiled):
def returning_clause(
self,
stmt: UpdateBase,
returning_cols: Sequence[ColumnElement[Any]],
returning_cols: Sequence[_ColumnsClauseElement],
*,
populate_result_map: bool,
**kw: Any,
@@ -5219,6 +5320,7 @@ class SQLCompiler(Compiled):
use_schema=True,
from_linter=None,
ambiguous_table_name_map=None,
enclosing_alias=None,
**kwargs,
):
if from_linter:
@@ -5237,7 +5339,11 @@ class SQLCompiler(Compiled):
ret = self.preparer.quote(table.name)
if (
not effective_schema
(
enclosing_alias is None
or enclosing_alias.element is not table
)
and not effective_schema
and ambiguous_table_name_map
and table.name in ambiguous_table_name_map
):
@@ -6142,11 +6248,18 @@ class SQLCompiler(Compiled):
"criteria within UPDATE"
)
def visit_update(self, update_stmt, visiting_cte=None, **kw):
def visit_update(
self,
update_stmt: Update,
visiting_cte: Optional[CTE] = None,
**kw: Any,
) -> str:
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
update_stmt = compile_state.statement
if TYPE_CHECKING:
assert isinstance(compile_state, UpdateDMLState)
update_stmt = compile_state.statement # type: ignore[assignment]
if visiting_cte is not None:
kw["visiting_cte"] = visiting_cte
@@ -6281,10 +6394,10 @@ class SQLCompiler(Compiled):
self.stack.pop(-1)
return text
return text # type: ignore[no-any-return]
def delete_extra_from_clause(
self, update_stmt, from_table, extra_froms, from_hints, **kw
self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
@@ -6506,7 +6619,7 @@ class StrSQLCompiler(SQLCompiler):
def returning_clause(
self,
stmt: UpdateBase,
returning_cols: Sequence[ColumnElement[Any]],
returning_cols: Sequence[_ColumnsClauseElement],
*,
populate_result_map: bool,
**kw: Any,
@@ -6527,7 +6640,7 @@ class StrSQLCompiler(SQLCompiler):
)
def delete_extra_from_clause(
self, update_stmt, from_table, extra_froms, from_hints, **kw
self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
kw["asfrom"] = True
return ", " + ", ".join(
@@ -6574,8 +6687,8 @@ class DDLCompiler(Compiled):
compile_kwargs: Mapping[str, Any] = ...,
): ...
@util.memoized_property
def sql_compiler(self):
@util.ro_memoized_property
def sql_compiler(self) -> SQLCompiler:
return self.dialect.statement_compiler(
self.dialect, None, schema_translate_map=self.schema_translate_map
)
@@ -6739,7 +6852,7 @@ class DDLCompiler(Compiled):
def visit_drop_view(self, drop, **kw):
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
def _verify_index_table(self, index):
def _verify_index_table(self, index: Index) -> None:
if index.table is None:
raise exc.CompileError(
"Index '%s' is not associated with any table." % index.name
@@ -6790,7 +6903,9 @@ class DDLCompiler(Compiled):
return text + self._prepared_index_name(index, include_schema=True)
def _prepared_index_name(self, index, include_schema=False):
def _prepared_index_name(
self, index: Index, include_schema: bool = False
) -> str:
if index.table is not None:
effective_schema = self.preparer.schema_for_object(index.table)
else:
@@ -6800,7 +6915,7 @@ class DDLCompiler(Compiled):
else:
schema_name = None
index_name = self.preparer.format_index(index)
index_name: str = self.preparer.format_index(index)
if schema_name:
index_name = schema_name + "." + index_name
@@ -6937,13 +7052,13 @@ class DDLCompiler(Compiled):
def post_create_table(self, table):
return ""
def get_column_default_string(self, column):
def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
if isinstance(column.server_default, schema.DefaultClause):
return self.render_default_string(column.server_default.arg)
else:
return None
def render_default_string(self, default):
def render_default_string(self, default: Union[Visitable, str]) -> str:
if isinstance(default, str):
return self.sql_compiler.render_literal_value(
default, sqltypes.STRINGTYPE
@@ -6981,7 +7096,9 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
def visit_primary_key_constraint(self, constraint, **kw):
def visit_primary_key_constraint(
self, constraint: PrimaryKeyConstraint, **kw: Any
) -> str:
if len(constraint) == 0:
return ""
text = ""
@@ -7030,7 +7147,9 @@ class DDLCompiler(Compiled):
return preparer.format_table(table)
def visit_unique_constraint(self, constraint, **kw):
def visit_unique_constraint(
self, constraint: UniqueConstraint, **kw: Any
) -> str:
if len(constraint) == 0:
return ""
text = ""
@@ -7045,22 +7164,37 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
def define_unique_constraint_distinct(self, constraint, **kw):
def define_unique_constraint_distinct(
self, constraint: UniqueConstraint, **kw: Any
) -> str:
return ""
def define_constraint_cascades(self, constraint):
def define_constraint_cascades(
self, constraint: ForeignKeyConstraint
) -> str:
text = ""
if constraint.ondelete is not None:
text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
constraint.ondelete, FK_ON_DELETE
)
text += self.define_constraint_ondelete_cascade(constraint)
if constraint.onupdate is not None:
text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
constraint.onupdate, FK_ON_UPDATE
)
text += self.define_constraint_onupdate_cascade(constraint)
return text
def define_constraint_deferrability(self, constraint):
def define_constraint_ondelete_cascade(
self, constraint: ForeignKeyConstraint
) -> str:
return " ON DELETE %s" % self.preparer.validate_sql_phrase(
constraint.ondelete, FK_ON_DELETE
)
def define_constraint_onupdate_cascade(
self, constraint: ForeignKeyConstraint
) -> str:
return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
constraint.onupdate, FK_ON_UPDATE
)
def define_constraint_deferrability(self, constraint: Constraint) -> str:
text = ""
if constraint.deferrable is not None:
if constraint.deferrable:
@@ -7100,19 +7234,21 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
def visit_FLOAT(self, type_, **kw):
def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
return "FLOAT"
def visit_DOUBLE(self, type_, **kw):
def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
return "DOUBLE"
def visit_DOUBLE_PRECISION(self, type_, **kw):
def visit_DOUBLE_PRECISION(
self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
) -> str:
return "DOUBLE PRECISION"
def visit_REAL(self, type_, **kw):
def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
return "REAL"
def visit_NUMERIC(self, type_, **kw):
def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
@@ -7123,7 +7259,7 @@ class GenericTypeCompiler(TypeCompiler):
"scale": type_.scale,
}
def visit_DECIMAL(self, type_, **kw):
def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
@@ -7134,128 +7270,138 @@ class GenericTypeCompiler(TypeCompiler):
"scale": type_.scale,
}
def visit_INTEGER(self, type_, **kw):
def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
return "INTEGER"
def visit_SMALLINT(self, type_, **kw):
def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
return "SMALLINT"
def visit_BIGINT(self, type_, **kw):
def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
return "BIGINT"
def visit_TIMESTAMP(self, type_, **kw):
def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
return "TIMESTAMP"
def visit_DATETIME(self, type_, **kw):
def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
return "DATETIME"
def visit_DATE(self, type_, **kw):
def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
return "DATE"
def visit_TIME(self, type_, **kw):
def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
return "TIME"
def visit_CLOB(self, type_, **kw):
def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
return "CLOB"
def visit_NCLOB(self, type_, **kw):
def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
return "NCLOB"
def _render_string_type(self, type_, name, length_override=None):
def _render_string_type(
self, name: str, length: Optional[int], collation: Optional[str]
) -> str:
text = name
if length_override:
text += "(%d)" % length_override
elif type_.length:
text += "(%d)" % type_.length
if type_.collation:
text += ' COLLATE "%s"' % type_.collation
if length:
text += f"({length})"
if collation:
text += f' COLLATE "{collation}"'
return text
def visit_CHAR(self, type_, **kw):
return self._render_string_type(type_, "CHAR")
def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
return self._render_string_type("CHAR", type_.length, type_.collation)
def visit_NCHAR(self, type_, **kw):
return self._render_string_type(type_, "NCHAR")
def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
return self._render_string_type("NCHAR", type_.length, type_.collation)
def visit_VARCHAR(self, type_, **kw):
return self._render_string_type(type_, "VARCHAR")
def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
return self._render_string_type(
"VARCHAR", type_.length, type_.collation
)
def visit_NVARCHAR(self, type_, **kw):
return self._render_string_type(type_, "NVARCHAR")
def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
return self._render_string_type(
"NVARCHAR", type_.length, type_.collation
)
def visit_TEXT(self, type_, **kw):
return self._render_string_type(type_, "TEXT")
def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
return self._render_string_type("TEXT", type_.length, type_.collation)
def visit_UUID(self, type_, **kw):
def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
return "UUID"
def visit_BLOB(self, type_, **kw):
def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
return "BLOB"
def visit_BINARY(self, type_, **kw):
def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
def visit_VARBINARY(self, type_, **kw):
def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
def visit_BOOLEAN(self, type_, **kw):
def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
return "BOOLEAN"
def visit_uuid(self, type_, **kw):
def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
if not type_.native_uuid or not self.dialect.supports_native_uuid:
return self._render_string_type(type_, "CHAR", length_override=32)
return self._render_string_type("CHAR", length=32, collation=None)
else:
return self.visit_UUID(type_, **kw)
def visit_large_binary(self, type_, **kw):
def visit_large_binary(
self, type_: sqltypes.LargeBinary, **kw: Any
) -> str:
return self.visit_BLOB(type_, **kw)
def visit_boolean(self, type_, **kw):
def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
return self.visit_BOOLEAN(type_, **kw)
def visit_time(self, type_, **kw):
def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
return self.visit_TIME(type_, **kw)
def visit_datetime(self, type_, **kw):
def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
return self.visit_DATETIME(type_, **kw)
def visit_date(self, type_, **kw):
def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
return self.visit_DATE(type_, **kw)
def visit_big_integer(self, type_, **kw):
def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
return self.visit_BIGINT(type_, **kw)
def visit_small_integer(self, type_, **kw):
def visit_small_integer(
self, type_: sqltypes.SmallInteger, **kw: Any
) -> str:
return self.visit_SMALLINT(type_, **kw)
def visit_integer(self, type_, **kw):
def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
return self.visit_INTEGER(type_, **kw)
def visit_real(self, type_, **kw):
def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
return self.visit_REAL(type_, **kw)
def visit_float(self, type_, **kw):
def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
return self.visit_FLOAT(type_, **kw)
def visit_double(self, type_, **kw):
def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
return self.visit_DOUBLE(type_, **kw)
def visit_numeric(self, type_, **kw):
def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
return self.visit_NUMERIC(type_, **kw)
def visit_string(self, type_, **kw):
def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
def visit_unicode(self, type_, **kw):
def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
def visit_text(self, type_, **kw):
def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
return self.visit_TEXT(type_, **kw)
def visit_unicode_text(self, type_, **kw):
def visit_unicode_text(
self, type_: sqltypes.UnicodeText, **kw: Any
) -> str:
return self.visit_TEXT(type_, **kw)
def visit_enum(self, type_, **kw):
def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
@@ -7265,10 +7411,14 @@ class GenericTypeCompiler(TypeCompiler):
"type on this Column?" % type_
)
def visit_type_decorator(self, type_, **kw):
def visit_type_decorator(
self, type_: TypeDecorator[Any], **kw: Any
) -> str:
return self.process(type_.type_engine(self.dialect), **kw)
def visit_user_defined(self, type_, **kw):
def visit_user_defined(
self, type_: UserDefinedType[Any], **kw: Any
) -> str:
return type_.get_col_spec(**kw)
@@ -7343,12 +7493,12 @@ class IdentifierPreparer:
def __init__(
self,
dialect,
initial_quote='"',
final_quote=None,
escape_quote='"',
quote_case_sensitive_collations=True,
omit_schema=False,
dialect: Dialect,
initial_quote: str = '"',
final_quote: Optional[str] = None,
escape_quote: str = '"',
quote_case_sensitive_collations: bool = True,
omit_schema: bool = False,
):
"""Construct a new ``IdentifierPreparer`` object.
@@ -7401,7 +7551,9 @@ class IdentifierPreparer:
prep._includes_none_schema_translate = includes_none
return prep
def _render_schema_translates(self, statement, schema_translate_map):
def _render_schema_translates(
self, statement: str, schema_translate_map: SchemaTranslateMapType
) -> str:
d = schema_translate_map
if None in d:
if not self._includes_none_schema_translate:
@@ -7413,7 +7565,7 @@ class IdentifierPreparer:
"schema_translate_map dictionaries."
)
d["_none"] = d[None]
d["_none"] = d[None] # type: ignore[index]
def replace(m):
name = m.group(2)
@@ -7606,7 +7758,9 @@ class IdentifierPreparer:
else:
return collation_name
def format_sequence(self, sequence, use_schema=True):
def format_sequence(
self, sequence: schema.Sequence, use_schema: bool = True
) -> str:
name = self.quote(sequence.name)
effective_schema = self.schema_for_object(sequence)
@@ -7643,7 +7797,9 @@ class IdentifierPreparer:
return ident
@util.preload_module("sqlalchemy.sql.naming")
def format_constraint(self, constraint, _alembic_quote=True):
def format_constraint(
self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
) -> Optional[str]:
naming = util.preloaded.sql_naming
if constraint.name is _NONE_NAME:
@@ -7656,6 +7812,7 @@ class IdentifierPreparer:
else:
name = constraint.name
assert name is not None
if constraint.__visit_name__ == "index":
return self.truncate_and_render_index_name(
name, _alembic_quote=_alembic_quote
@@ -7665,7 +7822,9 @@ class IdentifierPreparer:
name, _alembic_quote=_alembic_quote
)
def truncate_and_render_index_name(self, name, _alembic_quote=True):
def truncate_and_render_index_name(
self, name: str, _alembic_quote: bool = True
) -> str:
# calculate these at format time so that ad-hoc changes
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
@@ -7677,7 +7836,9 @@ class IdentifierPreparer:
name, max_, _alembic_quote
)
def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
def truncate_and_render_constraint_name(
self, name: str, _alembic_quote: bool = True
) -> str:
# calculate these at format time so that ad-hoc changes
# to dialect.max_identifier_length etc. can be reflected
# as IdentifierPreparer is long lived
@@ -7689,7 +7850,9 @@ class IdentifierPreparer:
name, max_, _alembic_quote
)
def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
def _truncate_and_render_maxlen_name(
self, name: str, max_: int, _alembic_quote: bool
) -> str:
if isinstance(name, elements._truncated_label):
if len(name) > max_:
name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
@@ -7701,13 +7864,21 @@ class IdentifierPreparer:
else:
return self.quote(name)
def format_index(self, index):
return self.format_constraint(index)
def format_index(self, index: Index) -> str:
name = self.format_constraint(index)
assert name is not None
return name
def format_table(self, table, use_schema=True, name=None):
def format_table(
self,
table: FromClause,
use_schema: bool = True,
name: Optional[str] = None,
) -> str:
"""Prepare a quoted table and schema name."""
if name is None:
if TYPE_CHECKING:
assert isinstance(table, NamedFromClause)
name = table.name
result = self.quote(name)
@@ -7739,17 +7910,18 @@ class IdentifierPreparer:
def format_column(
self,
column,
use_table=False,
name=None,
table_name=None,
use_schema=False,
anon_map=None,
):
column: ColumnElement[Any],
use_table: bool = False,
name: Optional[str] = None,
table_name: Optional[str] = None,
use_schema: bool = False,
anon_map: Optional[Mapping[str, Any]] = None,
) -> str:
"""Prepare a quoted column name."""
if name is None:
name = column.name
assert name is not None
if anon_map is not None and isinstance(
name, elements._truncated_label
@@ -7817,7 +7989,7 @@ class IdentifierPreparer:
)
return r
def unformat_identifiers(self, identifiers):
def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers