添加注册登录功能
This commit is contained in:
@@ -74,38 +74,49 @@ from .base import _de_clone
|
||||
from .base import _from_objects
|
||||
from .base import _NONE_NAME
|
||||
from .base import _SentinelDefaultCharacterization
|
||||
from .base import Executable
|
||||
from .base import NO_ARG
|
||||
from .elements import ClauseElement
|
||||
from .elements import quoted_name
|
||||
from .schema import Column
|
||||
from .sqltypes import TupleType
|
||||
from .type_api import TypeEngine
|
||||
from .visitors import prefix_anon_map
|
||||
from .visitors import Visitable
|
||||
from .. import exc
|
||||
from .. import util
|
||||
from ..util import FastIntFlag
|
||||
from ..util.typing import Literal
|
||||
from ..util.typing import Protocol
|
||||
from ..util.typing import Self
|
||||
from ..util.typing import TypedDict
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .annotation import _AnnotationDict
|
||||
from .base import _AmbiguousTableNameMap
|
||||
from .base import CompileState
|
||||
from .base import Executable
|
||||
from .cache_key import CacheKey
|
||||
from .ddl import ExecutableDDLElement
|
||||
from .dml import Insert
|
||||
from .dml import Update
|
||||
from .dml import UpdateBase
|
||||
from .dml import UpdateDMLState
|
||||
from .dml import ValuesBase
|
||||
from .elements import _truncated_label
|
||||
from .elements import BinaryExpression
|
||||
from .elements import BindParameter
|
||||
from .elements import ClauseElement
|
||||
from .elements import ColumnClause
|
||||
from .elements import ColumnElement
|
||||
from .elements import False_
|
||||
from .elements import Label
|
||||
from .elements import Null
|
||||
from .elements import True_
|
||||
from .functions import Function
|
||||
from .schema import Column
|
||||
from .schema import Constraint
|
||||
from .schema import ForeignKeyConstraint
|
||||
from .schema import Index
|
||||
from .schema import PrimaryKeyConstraint
|
||||
from .schema import Table
|
||||
from .schema import UniqueConstraint
|
||||
from .selectable import _ColumnsClauseElement
|
||||
from .selectable import AliasedReturnsRows
|
||||
from .selectable import CompoundSelectState
|
||||
from .selectable import CTE
|
||||
@@ -115,6 +126,10 @@ if typing.TYPE_CHECKING:
|
||||
from .selectable import Select
|
||||
from .selectable import SelectState
|
||||
from .type_api import _BindProcessorType
|
||||
from .type_api import TypeDecorator
|
||||
from .type_api import TypeEngine
|
||||
from .type_api import UserDefinedType
|
||||
from .visitors import Visitable
|
||||
from ..engine.cursor import CursorResultMetaData
|
||||
from ..engine.interfaces import _CoreSingleExecuteParams
|
||||
from ..engine.interfaces import _DBAPIAnyExecuteParams
|
||||
@@ -126,6 +141,7 @@ if typing.TYPE_CHECKING:
|
||||
from ..engine.interfaces import Dialect
|
||||
from ..engine.interfaces import SchemaTranslateMapType
|
||||
|
||||
|
||||
_FromHintsType = Dict["FromClause", str]
|
||||
|
||||
RESERVED_WORDS = {
|
||||
@@ -870,6 +886,7 @@ class Compiled:
|
||||
self.string = self.process(self.statement, **compile_kwargs)
|
||||
|
||||
if render_schema_translate:
|
||||
assert schema_translate_map is not None
|
||||
self.string = self.preparer._render_schema_translates(
|
||||
self.string, schema_translate_map
|
||||
)
|
||||
@@ -902,7 +919,7 @@ class Compiled:
|
||||
raise exc.UnsupportedCompilationError(self, type(element)) from err
|
||||
|
||||
@property
|
||||
def sql_compiler(self):
|
||||
def sql_compiler(self) -> SQLCompiler:
|
||||
"""Return a Compiled that is capable of processing SQL expressions.
|
||||
|
||||
If this compiler is one, it would likely just return 'self'.
|
||||
@@ -1791,7 +1808,7 @@ class SQLCompiler(Compiled):
|
||||
return len(self.stack) > 1
|
||||
|
||||
@property
|
||||
def sql_compiler(self):
|
||||
def sql_compiler(self) -> Self:
|
||||
return self
|
||||
|
||||
def construct_expanded_state(
|
||||
@@ -2297,10 +2314,7 @@ class SQLCompiler(Compiled):
|
||||
@util.memoized_property
|
||||
@util.preload_module("sqlalchemy.engine.result")
|
||||
def _inserted_primary_key_from_returning_getter(self):
|
||||
if typing.TYPE_CHECKING:
|
||||
from ..engine import result
|
||||
else:
|
||||
result = util.preloaded.engine_result
|
||||
result = util.preloaded.engine_result
|
||||
|
||||
assert self.compile_state is not None
|
||||
statement = self.compile_state.statement
|
||||
@@ -2342,7 +2356,7 @@ class SQLCompiler(Compiled):
|
||||
|
||||
return get
|
||||
|
||||
def default_from(self):
|
||||
def default_from(self) -> str:
|
||||
"""Called when a SELECT statement has no froms, and no FROM clause is
|
||||
to be appended.
|
||||
|
||||
@@ -2734,16 +2748,16 @@ class SQLCompiler(Compiled):
|
||||
|
||||
return text
|
||||
|
||||
def visit_null(self, expr, **kw):
|
||||
def visit_null(self, expr: Null, **kw: Any) -> str:
|
||||
return "NULL"
|
||||
|
||||
def visit_true(self, expr, **kw):
|
||||
def visit_true(self, expr: True_, **kw: Any) -> str:
|
||||
if self.dialect.supports_native_boolean:
|
||||
return "true"
|
||||
else:
|
||||
return "1"
|
||||
|
||||
def visit_false(self, expr, **kw):
|
||||
def visit_false(self, expr: False_, **kw: Any) -> str:
|
||||
if self.dialect.supports_native_boolean:
|
||||
return "false"
|
||||
else:
|
||||
@@ -2879,14 +2893,18 @@ class SQLCompiler(Compiled):
|
||||
|
||||
def visit_over(self, over, **kwargs):
|
||||
text = over.element._compiler_dispatch(self, **kwargs)
|
||||
if over.range_:
|
||||
if over.range_ is not None:
|
||||
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
|
||||
over.range_, **kwargs
|
||||
)
|
||||
elif over.rows:
|
||||
elif over.rows is not None:
|
||||
range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
|
||||
over.rows, **kwargs
|
||||
)
|
||||
elif over.groups is not None:
|
||||
range_ = "GROUPS BETWEEN %s" % self._format_frame_clause(
|
||||
over.groups, **kwargs
|
||||
)
|
||||
else:
|
||||
range_ = None
|
||||
|
||||
@@ -2985,7 +3003,7 @@ class SQLCompiler(Compiled):
|
||||
% self.dialect.name
|
||||
)
|
||||
|
||||
def function_argspec(self, func, **kwargs):
|
||||
def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
|
||||
return func.clause_expr._compiler_dispatch(self, **kwargs)
|
||||
|
||||
def visit_compound_select(
|
||||
@@ -3449,8 +3467,12 @@ class SQLCompiler(Compiled):
|
||||
)
|
||||
|
||||
def _generate_generic_binary(
|
||||
self, binary, opstring, eager_grouping=False, **kw
|
||||
):
|
||||
self,
|
||||
binary: BinaryExpression[Any],
|
||||
opstring: str,
|
||||
eager_grouping: bool = False,
|
||||
**kw: Any,
|
||||
) -> str:
|
||||
_in_operator_expression = kw.get("_in_operator_expression", False)
|
||||
|
||||
kw["_in_operator_expression"] = True
|
||||
@@ -3619,19 +3641,25 @@ class SQLCompiler(Compiled):
|
||||
**kw,
|
||||
)
|
||||
|
||||
def visit_regexp_match_op_binary(self, binary, operator, **kw):
|
||||
def visit_regexp_match_op_binary(
|
||||
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
|
||||
) -> str:
|
||||
raise exc.CompileError(
|
||||
"%s dialect does not support regular expressions"
|
||||
% self.dialect.name
|
||||
)
|
||||
|
||||
def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
|
||||
def visit_not_regexp_match_op_binary(
|
||||
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
|
||||
) -> str:
|
||||
raise exc.CompileError(
|
||||
"%s dialect does not support regular expressions"
|
||||
% self.dialect.name
|
||||
)
|
||||
|
||||
def visit_regexp_replace_op_binary(self, binary, operator, **kw):
|
||||
def visit_regexp_replace_op_binary(
|
||||
self, binary: BinaryExpression[Any], operator: Any, **kw: Any
|
||||
) -> str:
|
||||
raise exc.CompileError(
|
||||
"%s dialect does not support regular expression replacements"
|
||||
% self.dialect.name
|
||||
@@ -3838,7 +3866,9 @@ class SQLCompiler(Compiled):
|
||||
else:
|
||||
return self.render_literal_value(value, bindparam.type)
|
||||
|
||||
def render_literal_value(self, value, type_):
|
||||
def render_literal_value(
|
||||
self, value: Any, type_: sqltypes.TypeEngine[Any]
|
||||
) -> str:
|
||||
"""Render the value of a bind parameter as a quoted literal.
|
||||
|
||||
This is used for statement sections that do not accept bind parameters
|
||||
@@ -4074,15 +4104,28 @@ class SQLCompiler(Compiled):
|
||||
|
||||
del self.level_name_by_cte[existing_cte_reference_cte]
|
||||
else:
|
||||
# if the two CTEs are deep-copy identical, consider them
|
||||
# the same, **if** they are clones, that is, they came from
|
||||
# the ORM or other visit method
|
||||
if (
|
||||
cte._is_clone_of is not None
|
||||
or existing_cte._is_clone_of is not None
|
||||
) and cte.compare(existing_cte):
|
||||
# if the two CTEs have the same hash, which we expect
|
||||
# here means that one/both is an annotated of the other
|
||||
(hash(cte) == hash(existing_cte))
|
||||
# or...
|
||||
or (
|
||||
(
|
||||
# if they are clones, i.e. they came from the ORM
|
||||
# or some other visit method
|
||||
cte._is_clone_of is not None
|
||||
or existing_cte._is_clone_of is not None
|
||||
)
|
||||
# and are deep-copy identical
|
||||
and cte.compare(existing_cte)
|
||||
)
|
||||
):
|
||||
# then consider these two CTEs the same
|
||||
is_new_cte = False
|
||||
else:
|
||||
# otherwise these are two CTEs that either will render
|
||||
# differently, or were indicated separately by the user,
|
||||
# with the same name
|
||||
raise exc.CompileError(
|
||||
"Multiple, unrelated CTEs found with "
|
||||
"the same name: %r" % cte_name
|
||||
@@ -4115,7 +4158,7 @@ class SQLCompiler(Compiled):
|
||||
if cte.recursive:
|
||||
self.ctes_recursive = True
|
||||
text = self.preparer.format_alias(cte, cte_name)
|
||||
if cte.recursive:
|
||||
if cte.recursive or cte.element.name_cte_columns:
|
||||
col_source = cte.element
|
||||
|
||||
# TODO: can we get at the .columns_plus_names collection
|
||||
@@ -4184,7 +4227,7 @@ class SQLCompiler(Compiled):
|
||||
if self.preparer._requires_quotes(cte_name):
|
||||
cte_name = self.preparer.quote(cte_name)
|
||||
text += self.get_render_as_alias_suffix(cte_name)
|
||||
return text
|
||||
return text # type: ignore[no-any-return]
|
||||
else:
|
||||
return self.preparer.format_alias(cte, cte_name)
|
||||
|
||||
@@ -4246,7 +4289,7 @@ class SQLCompiler(Compiled):
|
||||
inner = "(%s)" % (inner,)
|
||||
return inner
|
||||
else:
|
||||
enclosing_alias = kwargs["enclosing_alias"] = alias
|
||||
kwargs["enclosing_alias"] = alias
|
||||
|
||||
if asfrom or ashint:
|
||||
if isinstance(alias.name, elements._truncated_label):
|
||||
@@ -4336,7 +4379,13 @@ class SQLCompiler(Compiled):
|
||||
)
|
||||
return f"VALUES {tuples}"
|
||||
|
||||
def visit_values(self, element, asfrom=False, from_linter=None, **kw):
|
||||
def visit_values(
|
||||
self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw
|
||||
):
|
||||
|
||||
if element._independent_ctes:
|
||||
self._dispatch_independent_ctes(element, kw)
|
||||
|
||||
v = self._render_values(element, **kw)
|
||||
|
||||
if element._unnamed:
|
||||
@@ -4357,7 +4406,12 @@ class SQLCompiler(Compiled):
|
||||
name if name is not None else "(unnamed VALUES element)"
|
||||
)
|
||||
|
||||
if name:
|
||||
if visiting_cte is not None and visiting_cte.element is element:
|
||||
if element._is_lateral:
|
||||
raise exc.CompileError(
|
||||
"Can't use a LATERAL VALUES expression inside of a CTE"
|
||||
)
|
||||
elif name:
|
||||
kw["include_table"] = False
|
||||
v = "%s(%s)%s (%s)" % (
|
||||
lateral,
|
||||
@@ -4541,7 +4595,52 @@ class SQLCompiler(Compiled):
|
||||
elif isinstance(column, elements.TextClause):
|
||||
render_with_label = False
|
||||
elif isinstance(column, elements.UnaryExpression):
|
||||
render_with_label = column.wraps_column_expression or asfrom
|
||||
# unary expression. notes added as of #12681
|
||||
#
|
||||
# By convention, the visit_unary() method
|
||||
# itself does not add an entry to the result map, and relies
|
||||
# upon either the inner expression creating a result map
|
||||
# entry, or if not, by creating a label here that produces
|
||||
# the result map entry. Where that happens is based on whether
|
||||
# or not the element immediately inside the unary is a
|
||||
# NamedColumn subclass or not.
|
||||
#
|
||||
# Now, this also impacts how the SELECT is written; if
|
||||
# we decide to generate a label here, we get the usual
|
||||
# "~(x+y) AS anon_1" thing in the columns clause. If we
|
||||
# don't, we don't get an AS at all, we get like
|
||||
# "~table.column".
|
||||
#
|
||||
# But here is the important thing as of modernish (like 1.4)
|
||||
# versions of SQLAlchemy - **whether or not the AS <label>
|
||||
# is present in the statement is not actually important**.
|
||||
# We target result columns **positionally** for a fully
|
||||
# compiled ``Select()`` object; before 1.4 we needed those
|
||||
# labels to match in cursor.description etc etc but now it
|
||||
# really doesn't matter.
|
||||
# So really, we could set render_with_label True in all cases.
|
||||
# Or we could just have visit_unary() populate the result map
|
||||
# in all cases.
|
||||
#
|
||||
# What we're doing here is strictly trying to not rock the
|
||||
# boat too much with when we do/don't render "AS label";
|
||||
# labels being present helps in the edge cases that we
|
||||
# "fall back" to named cursor.description matching, labels
|
||||
# not being present for columns keeps us from having awkward
|
||||
# phrases like "SELECT DISTINCT table.x AS x".
|
||||
render_with_label = (
|
||||
(
|
||||
# exception case to detect if we render "not boolean"
|
||||
# as "not <col>" for native boolean or "<col> = 1"
|
||||
# for non-native boolean. this is controlled by
|
||||
# visit_is_<true|false>_unary_operator
|
||||
column.operator
|
||||
in (operators.is_false, operators.is_true)
|
||||
and not self.dialect.supports_native_boolean
|
||||
)
|
||||
or column._wraps_unnamed_column()
|
||||
or asfrom
|
||||
)
|
||||
elif (
|
||||
# general class of expressions that don't have a SQL-column
|
||||
# addressible name. includes scalar selects, bind parameters,
|
||||
@@ -4599,7 +4698,9 @@ class SQLCompiler(Compiled):
|
||||
def get_select_hint_text(self, byfroms):
|
||||
return None
|
||||
|
||||
def get_from_hint_text(self, table, text):
|
||||
def get_from_hint_text(
|
||||
self, table: FromClause, text: Optional[str]
|
||||
) -> Optional[str]:
|
||||
return None
|
||||
|
||||
def get_crud_hint_text(self, table, text):
|
||||
@@ -5084,7 +5185,7 @@ class SQLCompiler(Compiled):
|
||||
else:
|
||||
return "WITH"
|
||||
|
||||
def get_select_precolumns(self, select, **kw):
|
||||
def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
|
||||
"""Called when building a ``SELECT`` statement, position is just
|
||||
before column list.
|
||||
|
||||
@@ -5129,7 +5230,7 @@ class SQLCompiler(Compiled):
|
||||
def returning_clause(
|
||||
self,
|
||||
stmt: UpdateBase,
|
||||
returning_cols: Sequence[ColumnElement[Any]],
|
||||
returning_cols: Sequence[_ColumnsClauseElement],
|
||||
*,
|
||||
populate_result_map: bool,
|
||||
**kw: Any,
|
||||
@@ -5219,6 +5320,7 @@ class SQLCompiler(Compiled):
|
||||
use_schema=True,
|
||||
from_linter=None,
|
||||
ambiguous_table_name_map=None,
|
||||
enclosing_alias=None,
|
||||
**kwargs,
|
||||
):
|
||||
if from_linter:
|
||||
@@ -5237,7 +5339,11 @@ class SQLCompiler(Compiled):
|
||||
ret = self.preparer.quote(table.name)
|
||||
|
||||
if (
|
||||
not effective_schema
|
||||
(
|
||||
enclosing_alias is None
|
||||
or enclosing_alias.element is not table
|
||||
)
|
||||
and not effective_schema
|
||||
and ambiguous_table_name_map
|
||||
and table.name in ambiguous_table_name_map
|
||||
):
|
||||
@@ -6142,11 +6248,18 @@ class SQLCompiler(Compiled):
|
||||
"criteria within UPDATE"
|
||||
)
|
||||
|
||||
def visit_update(self, update_stmt, visiting_cte=None, **kw):
|
||||
def visit_update(
|
||||
self,
|
||||
update_stmt: Update,
|
||||
visiting_cte: Optional[CTE] = None,
|
||||
**kw: Any,
|
||||
) -> str:
|
||||
compile_state = update_stmt._compile_state_factory(
|
||||
update_stmt, self, **kw
|
||||
)
|
||||
update_stmt = compile_state.statement
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(compile_state, UpdateDMLState)
|
||||
update_stmt = compile_state.statement # type: ignore[assignment]
|
||||
|
||||
if visiting_cte is not None:
|
||||
kw["visiting_cte"] = visiting_cte
|
||||
@@ -6281,10 +6394,10 @@ class SQLCompiler(Compiled):
|
||||
|
||||
self.stack.pop(-1)
|
||||
|
||||
return text
|
||||
return text # type: ignore[no-any-return]
|
||||
|
||||
def delete_extra_from_clause(
|
||||
self, update_stmt, from_table, extra_froms, from_hints, **kw
|
||||
self, delete_stmt, from_table, extra_froms, from_hints, **kw
|
||||
):
|
||||
"""Provide a hook to override the generation of an
|
||||
DELETE..FROM clause.
|
||||
@@ -6506,7 +6619,7 @@ class StrSQLCompiler(SQLCompiler):
|
||||
def returning_clause(
|
||||
self,
|
||||
stmt: UpdateBase,
|
||||
returning_cols: Sequence[ColumnElement[Any]],
|
||||
returning_cols: Sequence[_ColumnsClauseElement],
|
||||
*,
|
||||
populate_result_map: bool,
|
||||
**kw: Any,
|
||||
@@ -6527,7 +6640,7 @@ class StrSQLCompiler(SQLCompiler):
|
||||
)
|
||||
|
||||
def delete_extra_from_clause(
|
||||
self, update_stmt, from_table, extra_froms, from_hints, **kw
|
||||
self, delete_stmt, from_table, extra_froms, from_hints, **kw
|
||||
):
|
||||
kw["asfrom"] = True
|
||||
return ", " + ", ".join(
|
||||
@@ -6574,8 +6687,8 @@ class DDLCompiler(Compiled):
|
||||
compile_kwargs: Mapping[str, Any] = ...,
|
||||
): ...
|
||||
|
||||
@util.memoized_property
|
||||
def sql_compiler(self):
|
||||
@util.ro_memoized_property
|
||||
def sql_compiler(self) -> SQLCompiler:
|
||||
return self.dialect.statement_compiler(
|
||||
self.dialect, None, schema_translate_map=self.schema_translate_map
|
||||
)
|
||||
@@ -6739,7 +6852,7 @@ class DDLCompiler(Compiled):
|
||||
def visit_drop_view(self, drop, **kw):
|
||||
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
|
||||
|
||||
def _verify_index_table(self, index):
|
||||
def _verify_index_table(self, index: Index) -> None:
|
||||
if index.table is None:
|
||||
raise exc.CompileError(
|
||||
"Index '%s' is not associated with any table." % index.name
|
||||
@@ -6790,7 +6903,9 @@ class DDLCompiler(Compiled):
|
||||
|
||||
return text + self._prepared_index_name(index, include_schema=True)
|
||||
|
||||
def _prepared_index_name(self, index, include_schema=False):
|
||||
def _prepared_index_name(
|
||||
self, index: Index, include_schema: bool = False
|
||||
) -> str:
|
||||
if index.table is not None:
|
||||
effective_schema = self.preparer.schema_for_object(index.table)
|
||||
else:
|
||||
@@ -6800,7 +6915,7 @@ class DDLCompiler(Compiled):
|
||||
else:
|
||||
schema_name = None
|
||||
|
||||
index_name = self.preparer.format_index(index)
|
||||
index_name: str = self.preparer.format_index(index)
|
||||
|
||||
if schema_name:
|
||||
index_name = schema_name + "." + index_name
|
||||
@@ -6937,13 +7052,13 @@ class DDLCompiler(Compiled):
|
||||
def post_create_table(self, table):
|
||||
return ""
|
||||
|
||||
def get_column_default_string(self, column):
|
||||
def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
|
||||
if isinstance(column.server_default, schema.DefaultClause):
|
||||
return self.render_default_string(column.server_default.arg)
|
||||
else:
|
||||
return None
|
||||
|
||||
def render_default_string(self, default):
|
||||
def render_default_string(self, default: Union[Visitable, str]) -> str:
|
||||
if isinstance(default, str):
|
||||
return self.sql_compiler.render_literal_value(
|
||||
default, sqltypes.STRINGTYPE
|
||||
@@ -6981,7 +7096,9 @@ class DDLCompiler(Compiled):
|
||||
text += self.define_constraint_deferrability(constraint)
|
||||
return text
|
||||
|
||||
def visit_primary_key_constraint(self, constraint, **kw):
|
||||
def visit_primary_key_constraint(
|
||||
self, constraint: PrimaryKeyConstraint, **kw: Any
|
||||
) -> str:
|
||||
if len(constraint) == 0:
|
||||
return ""
|
||||
text = ""
|
||||
@@ -7030,7 +7147,9 @@ class DDLCompiler(Compiled):
|
||||
|
||||
return preparer.format_table(table)
|
||||
|
||||
def visit_unique_constraint(self, constraint, **kw):
|
||||
def visit_unique_constraint(
|
||||
self, constraint: UniqueConstraint, **kw: Any
|
||||
) -> str:
|
||||
if len(constraint) == 0:
|
||||
return ""
|
||||
text = ""
|
||||
@@ -7045,22 +7164,37 @@ class DDLCompiler(Compiled):
|
||||
text += self.define_constraint_deferrability(constraint)
|
||||
return text
|
||||
|
||||
def define_unique_constraint_distinct(self, constraint, **kw):
|
||||
def define_unique_constraint_distinct(
|
||||
self, constraint: UniqueConstraint, **kw: Any
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
def define_constraint_cascades(self, constraint):
|
||||
def define_constraint_cascades(
|
||||
self, constraint: ForeignKeyConstraint
|
||||
) -> str:
|
||||
text = ""
|
||||
if constraint.ondelete is not None:
|
||||
text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
|
||||
constraint.ondelete, FK_ON_DELETE
|
||||
)
|
||||
text += self.define_constraint_ondelete_cascade(constraint)
|
||||
|
||||
if constraint.onupdate is not None:
|
||||
text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
|
||||
constraint.onupdate, FK_ON_UPDATE
|
||||
)
|
||||
text += self.define_constraint_onupdate_cascade(constraint)
|
||||
return text
|
||||
|
||||
def define_constraint_deferrability(self, constraint):
|
||||
def define_constraint_ondelete_cascade(
|
||||
self, constraint: ForeignKeyConstraint
|
||||
) -> str:
|
||||
return " ON DELETE %s" % self.preparer.validate_sql_phrase(
|
||||
constraint.ondelete, FK_ON_DELETE
|
||||
)
|
||||
|
||||
def define_constraint_onupdate_cascade(
|
||||
self, constraint: ForeignKeyConstraint
|
||||
) -> str:
|
||||
return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
|
||||
constraint.onupdate, FK_ON_UPDATE
|
||||
)
|
||||
|
||||
def define_constraint_deferrability(self, constraint: Constraint) -> str:
|
||||
text = ""
|
||||
if constraint.deferrable is not None:
|
||||
if constraint.deferrable:
|
||||
@@ -7100,19 +7234,21 @@ class DDLCompiler(Compiled):
|
||||
|
||||
|
||||
class GenericTypeCompiler(TypeCompiler):
|
||||
def visit_FLOAT(self, type_, **kw):
|
||||
def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
|
||||
return "FLOAT"
|
||||
|
||||
def visit_DOUBLE(self, type_, **kw):
|
||||
def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
|
||||
return "DOUBLE"
|
||||
|
||||
def visit_DOUBLE_PRECISION(self, type_, **kw):
|
||||
def visit_DOUBLE_PRECISION(
|
||||
self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
|
||||
) -> str:
|
||||
return "DOUBLE PRECISION"
|
||||
|
||||
def visit_REAL(self, type_, **kw):
|
||||
def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
|
||||
return "REAL"
|
||||
|
||||
def visit_NUMERIC(self, type_, **kw):
|
||||
def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
|
||||
if type_.precision is None:
|
||||
return "NUMERIC"
|
||||
elif type_.scale is None:
|
||||
@@ -7123,7 +7259,7 @@ class GenericTypeCompiler(TypeCompiler):
|
||||
"scale": type_.scale,
|
||||
}
|
||||
|
||||
def visit_DECIMAL(self, type_, **kw):
|
||||
def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
|
||||
if type_.precision is None:
|
||||
return "DECIMAL"
|
||||
elif type_.scale is None:
|
||||
@@ -7134,128 +7270,138 @@ class GenericTypeCompiler(TypeCompiler):
|
||||
"scale": type_.scale,
|
||||
}
|
||||
|
||||
def visit_INTEGER(self, type_, **kw):
|
||||
def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
|
||||
return "INTEGER"
|
||||
|
||||
def visit_SMALLINT(self, type_, **kw):
|
||||
def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
|
||||
return "SMALLINT"
|
||||
|
||||
def visit_BIGINT(self, type_, **kw):
|
||||
def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
|
||||
return "BIGINT"
|
||||
|
||||
def visit_TIMESTAMP(self, type_, **kw):
|
||||
def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
|
||||
return "TIMESTAMP"
|
||||
|
||||
def visit_DATETIME(self, type_, **kw):
|
||||
def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
|
||||
return "DATETIME"
|
||||
|
||||
def visit_DATE(self, type_, **kw):
|
||||
def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
|
||||
return "DATE"
|
||||
|
||||
def visit_TIME(self, type_, **kw):
|
||||
def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
|
||||
return "TIME"
|
||||
|
||||
def visit_CLOB(self, type_, **kw):
|
||||
def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
|
||||
return "CLOB"
|
||||
|
||||
def visit_NCLOB(self, type_, **kw):
|
||||
def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
|
||||
return "NCLOB"
|
||||
|
||||
def _render_string_type(self, type_, name, length_override=None):
|
||||
def _render_string_type(
|
||||
self, name: str, length: Optional[int], collation: Optional[str]
|
||||
) -> str:
|
||||
text = name
|
||||
if length_override:
|
||||
text += "(%d)" % length_override
|
||||
elif type_.length:
|
||||
text += "(%d)" % type_.length
|
||||
if type_.collation:
|
||||
text += ' COLLATE "%s"' % type_.collation
|
||||
if length:
|
||||
text += f"({length})"
|
||||
if collation:
|
||||
text += f' COLLATE "{collation}"'
|
||||
return text
|
||||
|
||||
def visit_CHAR(self, type_, **kw):
|
||||
return self._render_string_type(type_, "CHAR")
|
||||
def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
|
||||
return self._render_string_type("CHAR", type_.length, type_.collation)
|
||||
|
||||
def visit_NCHAR(self, type_, **kw):
|
||||
return self._render_string_type(type_, "NCHAR")
|
||||
def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
|
||||
return self._render_string_type("NCHAR", type_.length, type_.collation)
|
||||
|
||||
def visit_VARCHAR(self, type_, **kw):
|
||||
return self._render_string_type(type_, "VARCHAR")
|
||||
def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
|
||||
return self._render_string_type(
|
||||
"VARCHAR", type_.length, type_.collation
|
||||
)
|
||||
|
||||
def visit_NVARCHAR(self, type_, **kw):
|
||||
return self._render_string_type(type_, "NVARCHAR")
|
||||
def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
|
||||
return self._render_string_type(
|
||||
"NVARCHAR", type_.length, type_.collation
|
||||
)
|
||||
|
||||
def visit_TEXT(self, type_, **kw):
|
||||
return self._render_string_type(type_, "TEXT")
|
||||
def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
|
||||
return self._render_string_type("TEXT", type_.length, type_.collation)
|
||||
|
||||
def visit_UUID(self, type_, **kw):
|
||||
def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
|
||||
return "UUID"
|
||||
|
||||
def visit_BLOB(self, type_, **kw):
|
||||
def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
|
||||
return "BLOB"
|
||||
|
||||
def visit_BINARY(self, type_, **kw):
|
||||
def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
|
||||
return "BINARY" + (type_.length and "(%d)" % type_.length or "")
|
||||
|
||||
def visit_VARBINARY(self, type_, **kw):
|
||||
def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
|
||||
return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
|
||||
|
||||
def visit_BOOLEAN(self, type_, **kw):
|
||||
def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
|
||||
return "BOOLEAN"
|
||||
|
||||
def visit_uuid(self, type_, **kw):
|
||||
def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
|
||||
if not type_.native_uuid or not self.dialect.supports_native_uuid:
|
||||
return self._render_string_type(type_, "CHAR", length_override=32)
|
||||
return self._render_string_type("CHAR", length=32, collation=None)
|
||||
else:
|
||||
return self.visit_UUID(type_, **kw)
|
||||
|
||||
def visit_large_binary(self, type_, **kw):
|
||||
def visit_large_binary(
|
||||
self, type_: sqltypes.LargeBinary, **kw: Any
|
||||
) -> str:
|
||||
return self.visit_BLOB(type_, **kw)
|
||||
|
||||
def visit_boolean(self, type_, **kw):
|
||||
def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
|
||||
return self.visit_BOOLEAN(type_, **kw)
|
||||
|
||||
def visit_time(self, type_, **kw):
|
||||
def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
|
||||
return self.visit_TIME(type_, **kw)
|
||||
|
||||
def visit_datetime(self, type_, **kw):
|
||||
def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
|
||||
return self.visit_DATETIME(type_, **kw)
|
||||
|
||||
def visit_date(self, type_, **kw):
|
||||
def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
|
||||
return self.visit_DATE(type_, **kw)
|
||||
|
||||
def visit_big_integer(self, type_, **kw):
|
||||
def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
|
||||
return self.visit_BIGINT(type_, **kw)
|
||||
|
||||
def visit_small_integer(self, type_, **kw):
|
||||
def visit_small_integer(
|
||||
self, type_: sqltypes.SmallInteger, **kw: Any
|
||||
) -> str:
|
||||
return self.visit_SMALLINT(type_, **kw)
|
||||
|
||||
def visit_integer(self, type_, **kw):
|
||||
def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
|
||||
return self.visit_INTEGER(type_, **kw)
|
||||
|
||||
def visit_real(self, type_, **kw):
|
||||
def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
|
||||
return self.visit_REAL(type_, **kw)
|
||||
|
||||
def visit_float(self, type_, **kw):
|
||||
def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
|
||||
return self.visit_FLOAT(type_, **kw)
|
||||
|
||||
def visit_double(self, type_, **kw):
|
||||
def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
|
||||
return self.visit_DOUBLE(type_, **kw)
|
||||
|
||||
def visit_numeric(self, type_, **kw):
|
||||
def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
|
||||
return self.visit_NUMERIC(type_, **kw)
|
||||
|
||||
def visit_string(self, type_, **kw):
|
||||
def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
|
||||
return self.visit_VARCHAR(type_, **kw)
|
||||
|
||||
def visit_unicode(self, type_, **kw):
|
||||
def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
|
||||
return self.visit_VARCHAR(type_, **kw)
|
||||
|
||||
def visit_text(self, type_, **kw):
|
||||
def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
|
||||
return self.visit_TEXT(type_, **kw)
|
||||
|
||||
def visit_unicode_text(self, type_, **kw):
|
||||
def visit_unicode_text(
|
||||
self, type_: sqltypes.UnicodeText, **kw: Any
|
||||
) -> str:
|
||||
return self.visit_TEXT(type_, **kw)
|
||||
|
||||
def visit_enum(self, type_, **kw):
|
||||
def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
|
||||
return self.visit_VARCHAR(type_, **kw)
|
||||
|
||||
def visit_null(self, type_, **kw):
|
||||
@@ -7265,10 +7411,14 @@ class GenericTypeCompiler(TypeCompiler):
|
||||
"type on this Column?" % type_
|
||||
)
|
||||
|
||||
def visit_type_decorator(self, type_, **kw):
|
||||
def visit_type_decorator(
|
||||
self, type_: TypeDecorator[Any], **kw: Any
|
||||
) -> str:
|
||||
return self.process(type_.type_engine(self.dialect), **kw)
|
||||
|
||||
def visit_user_defined(self, type_, **kw):
|
||||
def visit_user_defined(
|
||||
self, type_: UserDefinedType[Any], **kw: Any
|
||||
) -> str:
|
||||
return type_.get_col_spec(**kw)
|
||||
|
||||
|
||||
@@ -7343,12 +7493,12 @@ class IdentifierPreparer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dialect,
|
||||
initial_quote='"',
|
||||
final_quote=None,
|
||||
escape_quote='"',
|
||||
quote_case_sensitive_collations=True,
|
||||
omit_schema=False,
|
||||
dialect: Dialect,
|
||||
initial_quote: str = '"',
|
||||
final_quote: Optional[str] = None,
|
||||
escape_quote: str = '"',
|
||||
quote_case_sensitive_collations: bool = True,
|
||||
omit_schema: bool = False,
|
||||
):
|
||||
"""Construct a new ``IdentifierPreparer`` object.
|
||||
|
||||
@@ -7401,7 +7551,9 @@ class IdentifierPreparer:
|
||||
prep._includes_none_schema_translate = includes_none
|
||||
return prep
|
||||
|
||||
def _render_schema_translates(self, statement, schema_translate_map):
|
||||
def _render_schema_translates(
|
||||
self, statement: str, schema_translate_map: SchemaTranslateMapType
|
||||
) -> str:
|
||||
d = schema_translate_map
|
||||
if None in d:
|
||||
if not self._includes_none_schema_translate:
|
||||
@@ -7413,7 +7565,7 @@ class IdentifierPreparer:
|
||||
"schema_translate_map dictionaries."
|
||||
)
|
||||
|
||||
d["_none"] = d[None]
|
||||
d["_none"] = d[None] # type: ignore[index]
|
||||
|
||||
def replace(m):
|
||||
name = m.group(2)
|
||||
@@ -7606,7 +7758,9 @@ class IdentifierPreparer:
|
||||
else:
|
||||
return collation_name
|
||||
|
||||
def format_sequence(self, sequence, use_schema=True):
|
||||
def format_sequence(
|
||||
self, sequence: schema.Sequence, use_schema: bool = True
|
||||
) -> str:
|
||||
name = self.quote(sequence.name)
|
||||
|
||||
effective_schema = self.schema_for_object(sequence)
|
||||
@@ -7643,7 +7797,9 @@ class IdentifierPreparer:
|
||||
return ident
|
||||
|
||||
@util.preload_module("sqlalchemy.sql.naming")
|
||||
def format_constraint(self, constraint, _alembic_quote=True):
|
||||
def format_constraint(
|
||||
self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
|
||||
) -> Optional[str]:
|
||||
naming = util.preloaded.sql_naming
|
||||
|
||||
if constraint.name is _NONE_NAME:
|
||||
@@ -7656,6 +7812,7 @@ class IdentifierPreparer:
|
||||
else:
|
||||
name = constraint.name
|
||||
|
||||
assert name is not None
|
||||
if constraint.__visit_name__ == "index":
|
||||
return self.truncate_and_render_index_name(
|
||||
name, _alembic_quote=_alembic_quote
|
||||
@@ -7665,7 +7822,9 @@ class IdentifierPreparer:
|
||||
name, _alembic_quote=_alembic_quote
|
||||
)
|
||||
|
||||
def truncate_and_render_index_name(self, name, _alembic_quote=True):
|
||||
def truncate_and_render_index_name(
|
||||
self, name: str, _alembic_quote: bool = True
|
||||
) -> str:
|
||||
# calculate these at format time so that ad-hoc changes
|
||||
# to dialect.max_identifier_length etc. can be reflected
|
||||
# as IdentifierPreparer is long lived
|
||||
@@ -7677,7 +7836,9 @@ class IdentifierPreparer:
|
||||
name, max_, _alembic_quote
|
||||
)
|
||||
|
||||
def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
|
||||
def truncate_and_render_constraint_name(
|
||||
self, name: str, _alembic_quote: bool = True
|
||||
) -> str:
|
||||
# calculate these at format time so that ad-hoc changes
|
||||
# to dialect.max_identifier_length etc. can be reflected
|
||||
# as IdentifierPreparer is long lived
|
||||
@@ -7689,7 +7850,9 @@ class IdentifierPreparer:
|
||||
name, max_, _alembic_quote
|
||||
)
|
||||
|
||||
def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
|
||||
def _truncate_and_render_maxlen_name(
|
||||
self, name: str, max_: int, _alembic_quote: bool
|
||||
) -> str:
|
||||
if isinstance(name, elements._truncated_label):
|
||||
if len(name) > max_:
|
||||
name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
|
||||
@@ -7701,13 +7864,21 @@ class IdentifierPreparer:
|
||||
else:
|
||||
return self.quote(name)
|
||||
|
||||
def format_index(self, index):
|
||||
return self.format_constraint(index)
|
||||
def format_index(self, index: Index) -> str:
|
||||
name = self.format_constraint(index)
|
||||
assert name is not None
|
||||
return name
|
||||
|
||||
def format_table(self, table, use_schema=True, name=None):
|
||||
def format_table(
|
||||
self,
|
||||
table: FromClause,
|
||||
use_schema: bool = True,
|
||||
name: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Prepare a quoted table and schema name."""
|
||||
|
||||
if name is None:
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(table, NamedFromClause)
|
||||
name = table.name
|
||||
|
||||
result = self.quote(name)
|
||||
@@ -7739,17 +7910,18 @@ class IdentifierPreparer:
|
||||
|
||||
def format_column(
|
||||
self,
|
||||
column,
|
||||
use_table=False,
|
||||
name=None,
|
||||
table_name=None,
|
||||
use_schema=False,
|
||||
anon_map=None,
|
||||
):
|
||||
column: ColumnElement[Any],
|
||||
use_table: bool = False,
|
||||
name: Optional[str] = None,
|
||||
table_name: Optional[str] = None,
|
||||
use_schema: bool = False,
|
||||
anon_map: Optional[Mapping[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Prepare a quoted column name."""
|
||||
|
||||
if name is None:
|
||||
name = column.name
|
||||
assert name is not None
|
||||
|
||||
if anon_map is not None and isinstance(
|
||||
name, elements._truncated_label
|
||||
@@ -7817,7 +7989,7 @@ class IdentifierPreparer:
|
||||
)
|
||||
return r
|
||||
|
||||
def unformat_identifiers(self, identifiers):
|
||||
def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
|
||||
"""Unpack 'schema.table.column'-like strings into components."""
|
||||
|
||||
r = self._r_identifiers
|
||||
|
||||
Reference in New Issue
Block a user