python学习

This commit is contained in:
2019-07-17 16:36:59 +08:00
commit 596a0bf6bf
786 changed files with 300160 additions and 0 deletions

View File

@@ -0,0 +1,106 @@
# sql/__init__.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from .expression import Alias # noqa
from .expression import alias # noqa
from .expression import all_ # noqa
from .expression import and_ # noqa
from .expression import any_ # noqa
from .expression import asc # noqa
from .expression import between # noqa
from .expression import bindparam # noqa
from .expression import case # noqa
from .expression import cast # noqa
from .expression import ClauseElement # noqa
from .expression import collate # noqa
from .expression import column # noqa
from .expression import ColumnCollection # noqa
from .expression import ColumnElement # noqa
from .expression import CompoundSelect # noqa
from .expression import cte # noqa
from .expression import Delete # noqa
from .expression import delete # noqa
from .expression import desc # noqa
from .expression import distinct # noqa
from .expression import except_ # noqa
from .expression import except_all # noqa
from .expression import exists # noqa
from .expression import extract # noqa
from .expression import false # noqa
from .expression import False_ # noqa
from .expression import FromClause # noqa
from .expression import func # noqa
from .expression import funcfilter # noqa
from .expression import Insert # noqa
from .expression import insert # noqa
from .expression import intersect # noqa
from .expression import intersect_all # noqa
from .expression import Join # noqa
from .expression import join # noqa
from .expression import label # noqa
from .expression import lateral # noqa
from .expression import literal # noqa
from .expression import literal_column # noqa
from .expression import modifier # noqa
from .expression import not_ # noqa
from .expression import null # noqa
from .expression import nullsfirst # noqa
from .expression import nullslast # noqa
from .expression import or_ # noqa
from .expression import outerjoin # noqa
from .expression import outparam # noqa
from .expression import over # noqa
from .expression import quoted_name # noqa
from .expression import Select # noqa
from .expression import select # noqa
from .expression import Selectable # noqa
from .expression import subquery # noqa
from .expression import table # noqa
from .expression import TableClause # noqa
from .expression import TableSample # noqa
from .expression import tablesample # noqa
from .expression import text # noqa
from .expression import true # noqa
from .expression import True_ # noqa
from .expression import tuple_ # noqa
from .expression import type_coerce # noqa
from .expression import union # noqa
from .expression import union_all # noqa
from .expression import Update # noqa
from .expression import update # noqa
from .expression import within_group # noqa
from .visitors import ClauseVisitor # noqa
def __go(lcls):
global __all__
from .. import util as _sa_util
import inspect as _inspect
__all__ = sorted(
name
for name, obj in lcls.items()
if not (name.startswith("_") or _inspect.ismodule(obj))
)
from .annotation import _prepare_annotations
from .annotation import Annotated # noqa
from .elements import AnnotatedColumnElement
from .elements import ClauseList # noqa
from .selectable import AnnotatedFromClause # noqa
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
_sa_util.dependencies.resolve_all("sqlalchemy.sql")
from . import naming # noqa
__go(locals())

View File

@@ -0,0 +1,208 @@
# sql/annotation.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""The :class:`.Annotated` class and related routines; creates hash-equivalent
copies of SQL constructs which contain context-specific markers and
associations.
"""
from . import operators
from .. import util
class Annotated(object):
"""clones a ClauseElement and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
in hashed collections.
A reference to the original element is maintained, for the important
reason of keeping its hash value current. When GC'ed, the
hash value may be reused, causing conflicts.
.. note:: The rationale for Annotated producing a brand new class,
rather than placing the functionality directly within ClauseElement,
is **performance**. The __hash__() method is absent on plain
ClauseElement which leads to significantly reduced function call
overhead, as the use of sets and dictionaries against ClauseElement
objects is prevalent, but most are not "annotated".
"""
def __new__(cls, *args):
if not args:
# clone constructor
return object.__new__(cls)
else:
element, values = args
# pull appropriate subclass from registry of annotated
# classes
try:
cls = annotated_classes[element.__class__]
except KeyError:
cls = _new_annotation_type(element.__class__, cls)
return object.__new__(cls)
def __init__(self, element, values):
self.__dict__ = element.__dict__.copy()
self.__element = element
self._annotations = values
self._hash = hash(element)
def _annotate(self, values):
_values = self._annotations.copy()
_values.update(values)
return self._with_annotations(_values)
def _with_annotations(self, values):
clone = self.__class__.__new__(self.__class__)
clone.__dict__ = self.__dict__.copy()
clone._annotations = values
return clone
def _deannotate(self, values=None, clone=True):
if values is None:
return self.__element
else:
_values = self._annotations.copy()
for v in values:
_values.pop(v, None)
return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@property
def _constructor(self):
return self.__element._constructor
def _clone(self):
clone = self.__element._clone()
if clone is self.__element:
# detect immutable, don't change anything
return self
else:
# update the clone with any changes that have occurred
# to this object's __dict__.
clone.__dict__.update(self.__dict__)
return self.__class__(clone, self._annotations)
def __reduce__(self):
return self.__class__, (self.__element, self._annotations)
def __hash__(self):
return self._hash
def __eq__(self, other):
if isinstance(self.__element, operators.ColumnOperators):
return self.__element.__class__.__eq__(self, other)
else:
return hash(other) == hash(self)
# hard-generate Annotated subclasses. this technique
# is used instead of on-the-fly types (i.e. type.__new__())
# so that the resulting objects are pickleable.
annotated_classes = {}
def _deep_annotate(element, annotations, exclude=None):
"""Deep copy the given ClauseElement, annotating each element
with the given annotations dictionary.
Elements within the exclude collection will be cloned but not annotated.
"""
def clone(elem):
if (
exclude
and hasattr(elem, "proxy_set")
and elem.proxy_set.intersection(exclude)
):
newelem = elem._clone()
elif annotations != elem._annotations:
newelem = elem._annotate(annotations)
else:
newelem = elem
newelem._copy_internals(clone=clone)
return newelem
if element is not None:
element = clone(element)
return element
def _deep_deannotate(element, values=None):
"""Deep copy the given element, removing annotations."""
cloned = util.column_dict()
def clone(elem):
# if a values dict is given,
# the elem must be cloned each time it appears,
# as there may be different annotations in source
# elements that are remaining. if totally
# removing all annotations, can assume the same
# slate...
if values or elem not in cloned:
newelem = elem._deannotate(values=values, clone=True)
newelem._copy_internals(clone=clone)
if not values:
cloned[elem] = newelem
return newelem
else:
return cloned[elem]
if element is not None:
element = clone(element)
return element
def _shallow_annotate(element, annotations):
"""Annotate the given ClauseElement and copy its internals so that
internal objects refer to the new annotated object.
Basically used to apply a "dont traverse" annotation to a
selectable, without digging throughout the whole
structure wasting time.
"""
element = element._annotate(annotations)
element._copy_internals()
return element
def _new_annotation_type(cls, base_cls):
if issubclass(cls, Annotated):
return cls
elif cls in annotated_classes:
return annotated_classes[cls]
for super_ in cls.__mro__:
# check if an Annotated subclass more specific than
# the given base_cls is already registered, such
# as AnnotatedColumnElement.
if super_ in annotated_classes:
base_cls = annotated_classes[super_]
break
annotated_classes[cls] = anno_cls = type(
"Annotated%s" % cls.__name__, (base_cls, cls), {}
)
globals()["Annotated%s" % cls.__name__] = anno_cls
return anno_cls
def _prepare_annotations(target_hierarchy, base_cls):
stack = [target_hierarchy]
while stack:
cls = stack.pop()
stack.extend(cls.__subclasses__())
_new_annotation_type(cls, base_cls)

View File

@@ -0,0 +1,670 @@
# sql/base.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Foundational utilities common to many sql modules.
"""
import itertools
import re
from .visitors import ClauseVisitor
from .. import exc
from .. import util
PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
NO_ARG = util.symbol("NO_ARG")
class Immutable(object):
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
def unique_params(self, *optionaldict, **kwargs):
raise NotImplementedError("Immutable objects do not support copying")
def params(self, *optionaldict, **kwargs):
raise NotImplementedError("Immutable objects do not support copying")
def _clone(self):
return self
def _from_objects(*elements):
return itertools.chain(*[element._from_objects for element in elements])
@util.decorator
def _generative(fn, *args, **kw):
"""Mark a method as generative."""
self = args[0]._generate()
fn(self, *args[1:], **kw)
return self
class _DialectArgView(util.collections_abc.MutableMapping):
"""A dictionary view of dialect-level arguments in the form
<dialectname>_<argument_name>.
"""
def __init__(self, obj):
self.obj = obj
def _key(self, key):
try:
dialect, value_key = key.split("_", 1)
except ValueError:
raise KeyError(key)
else:
return dialect, value_key
def __getitem__(self, key):
dialect, value_key = self._key(key)
try:
opt = self.obj.dialect_options[dialect]
except exc.NoSuchModuleError:
raise KeyError(key)
else:
return opt[value_key]
def __setitem__(self, key, value):
try:
dialect, value_key = self._key(key)
except KeyError:
raise exc.ArgumentError(
"Keys must be of the form <dialectname>_<argname>"
)
else:
self.obj.dialect_options[dialect][value_key] = value
def __delitem__(self, key):
dialect, value_key = self._key(key)
del self.obj.dialect_options[dialect][value_key]
def __len__(self):
return sum(
len(args._non_defaults)
for args in self.obj.dialect_options.values()
)
def __iter__(self):
return (
util.safe_kwarg("%s_%s" % (dialect_name, value_name))
for dialect_name in self.obj.dialect_options
for value_name in self.obj.dialect_options[
dialect_name
]._non_defaults
)
class _DialectArgDict(util.collections_abc.MutableMapping):
"""A dictionary view of dialect-level arguments for a specific
dialect.
Maintains a separate collection of user-specified arguments
and dialect-specified default arguments.
"""
def __init__(self):
self._non_defaults = {}
self._defaults = {}
def __len__(self):
return len(set(self._non_defaults).union(self._defaults))
def __iter__(self):
return iter(set(self._non_defaults).union(self._defaults))
def __getitem__(self, key):
if key in self._non_defaults:
return self._non_defaults[key]
else:
return self._defaults[key]
def __setitem__(self, key, value):
self._non_defaults[key] = value
def __delitem__(self, key):
del self._non_defaults[key]
class DialectKWArgs(object):
"""Establish the ability for a class to have dialect-specific arguments
with defaults and constructor validation.
The :class:`.DialectKWArgs` interacts with the
:attr:`.DefaultDialect.construct_arguments` present on a dialect.
.. seealso::
:attr:`.DefaultDialect.construct_arguments`
"""
@classmethod
def argument_for(cls, dialect_name, argument_name, default):
"""Add a new kind of dialect-specific keyword argument for this class.
E.g.::
Index.argument_for("mydialect", "length", None)
some_index = Index('a', 'b', mydialect_length=5)
The :meth:`.DialectKWArgs.argument_for` method is a per-argument
way adding extra arguments to the
:attr:`.DefaultDialect.construct_arguments` dictionary. This
dictionary provides a list of argument names accepted by various
schema-level constructs on behalf of a dialect.
New dialects should typically specify this dictionary all at once as a
data member of the dialect class. The use case for ad-hoc addition of
argument names is typically for end-user code that is also using
a custom compilation scheme which consumes the additional arguments.
:param dialect_name: name of a dialect. The dialect must be
locatable, else a :class:`.NoSuchModuleError` is raised. The
dialect must also include an existing
:attr:`.DefaultDialect.construct_arguments` collection, indicating
that it participates in the keyword-argument validation and default
system, else :class:`.ArgumentError` is raised. If the dialect does
not include this collection, then any keyword argument can be
specified on behalf of this dialect already. All dialects packaged
within SQLAlchemy include this collection, however for third party
dialects, support may vary.
:param argument_name: name of the parameter.
:param default: default value of the parameter.
.. versionadded:: 0.9.4
"""
construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
if construct_arg_dictionary is None:
raise exc.ArgumentError(
"Dialect '%s' does have keyword-argument "
"validation and defaults enabled configured" % dialect_name
)
if cls not in construct_arg_dictionary:
construct_arg_dictionary[cls] = {}
construct_arg_dictionary[cls][argument_name] = default
@util.memoized_property
def dialect_kwargs(self):
"""A collection of keyword arguments specified as dialect-specific
options to this construct.
The arguments are present here in their original ``<dialect>_<kwarg>``
format. Only arguments that were actually passed are included;
unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
contains all options known by this dialect including defaults.
The collection is also writable; keys are accepted of the
form ``<dialect>_<kwarg>`` where the value will be assembled
into the list of options.
.. versionadded:: 0.9.2
.. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
collection is now writable.
.. seealso::
:attr:`.DialectKWArgs.dialect_options` - nested dictionary form
"""
return _DialectArgView(self)
@property
def kwargs(self):
"""A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
return self.dialect_kwargs
@util.dependencies("sqlalchemy.dialects")
def _kw_reg_for_dialect(dialects, dialect_name):
dialect_cls = dialects.registry.load(dialect_name)
if dialect_cls.construct_arguments is None:
return None
return dict(dialect_cls.construct_arguments)
_kw_registry = util.PopulateDict(_kw_reg_for_dialect)
def _kw_reg_for_dialect_cls(self, dialect_name):
construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
d = _DialectArgDict()
if construct_arg_dictionary is None:
d._defaults.update({"*": None})
else:
for cls in reversed(self.__class__.__mro__):
if cls in construct_arg_dictionary:
d._defaults.update(construct_arg_dictionary[cls])
return d
@util.memoized_property
def dialect_options(self):
"""A collection of keyword arguments specified as dialect-specific
options to this construct.
This is a two-level nested registry, keyed to ``<dialect_name>``
and ``<argument_name>``. For example, the ``postgresql_where``
argument would be locatable as::
arg = my_object.dialect_options['postgresql']['where']
.. versionadded:: 0.9.2
.. seealso::
:attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form
"""
return util.PopulateDict(
util.portable_instancemethod(self._kw_reg_for_dialect_cls)
)
def _validate_dialect_kwargs(self, kwargs):
# validate remaining kwargs that they all specify DB prefixes
if not kwargs:
return
for k in kwargs:
m = re.match("^(.+?)_(.+)$", k)
if not m:
raise TypeError(
"Additional arguments should be "
"named <dialectname>_<argument>, got '%s'" % k
)
dialect_name, arg_name = m.group(1, 2)
try:
construct_arg_dictionary = self.dialect_options[dialect_name]
except exc.NoSuchModuleError:
util.warn(
"Can't validate argument %r; can't "
"locate any SQLAlchemy dialect named %r"
% (k, dialect_name)
)
self.dialect_options[dialect_name] = d = _DialectArgDict()
d._defaults.update({"*": None})
d._non_defaults[arg_name] = kwargs[k]
else:
if (
"*" not in construct_arg_dictionary
and arg_name not in construct_arg_dictionary
):
raise exc.ArgumentError(
"Argument %r is not accepted by "
"dialect %r on behalf of %r"
% (k, dialect_name, self.__class__)
)
else:
construct_arg_dictionary[arg_name] = kwargs[k]
class Generative(object):
"""Allow a ClauseElement to generate itself via the
@_generative decorator.
"""
def _generate(self):
s = self.__class__.__new__(self.__class__)
s.__dict__ = self.__dict__.copy()
return s
class Executable(Generative):
"""Mark a ClauseElement as supporting execution.
:class:`.Executable` is a superclass for all "statement" types
of objects, including :func:`select`, :func:`delete`, :func:`update`,
:func:`insert`, :func:`text`.
"""
supports_execution = True
_execution_options = util.immutabledict()
_bind = None
@_generative
def execution_options(self, **kw):
""" Set non-SQL options for the statement which take effect during
execution.
Execution options can be set on a per-statement or
per :class:`.Connection` basis. Additionally, the
:class:`.Engine` and ORM :class:`~.orm.query.Query` objects provide
access to execution options which they in turn configure upon
connections.
The :meth:`execution_options` method is generative. A new
instance of this statement is returned that contains the options::
statement = select([table.c.x, table.c.y])
statement = statement.execution_options(autocommit=True)
Note that only a subset of possible execution options can be applied
to a statement - these include "autocommit" and "stream_results",
but not "isolation_level" or "compiled_cache".
See :meth:`.Connection.execution_options` for a full list of
possible options.
.. seealso::
:meth:`.Connection.execution_options`
:meth:`.Query.execution_options`
:meth:`.Executable.get_execution_options`
"""
if "isolation_level" in kw:
raise exc.ArgumentError(
"'isolation_level' execution option may only be specified "
"on Connection.execution_options(), or "
"per-engine using the isolation_level "
"argument to create_engine()."
)
if "compiled_cache" in kw:
raise exc.ArgumentError(
"'compiled_cache' execution option may only be specified "
"on Connection.execution_options(), not per statement."
)
self._execution_options = self._execution_options.union(kw)
def get_execution_options(self):
""" Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
.. seealso::
:meth:`.Executable.execution_options`
"""
return self._execution_options
def execute(self, *multiparams, **params):
"""Compile and execute this :class:`.Executable`."""
e = self.bind
if e is None:
label = getattr(self, "description", self.__class__.__name__)
msg = (
"This %s is not directly bound to a Connection or Engine. "
"Use the .execute() method of a Connection or Engine "
"to execute this construct." % label
)
raise exc.UnboundExecutionError(msg)
return e._execute_clauseelement(self, multiparams, params)
def scalar(self, *multiparams, **params):
"""Compile and execute this :class:`.Executable`, returning the
result's scalar representation.
"""
return self.execute(*multiparams, **params).scalar()
@property
def bind(self):
"""Returns the :class:`.Engine` or :class:`.Connection` to
which this :class:`.Executable` is bound, or None if none found.
This is a traversal which checks locally, then
checks among the "from" clauses of associated objects
until a bound engine or connection is found.
"""
if self._bind is not None:
return self._bind
for f in _from_objects(self):
if f is self:
continue
engine = f.bind
if engine is not None:
return engine
else:
return None
class SchemaEventTarget(object):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.
"""
def _set_parent(self, parent):
"""Associate with this SchemaEvent's parent object."""
def _set_parent_with_dispatch(self, parent):
self.dispatch.before_parent_attach(self, parent)
self._set_parent(parent)
self.dispatch.after_parent_attach(self, parent)
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
__traverse_options__ = {"schema_visitor": True}
class ColumnCollection(util.OrderedProperties):
"""An ordered dictionary that stores a list of ColumnElement
instances.
Overrides the ``__eq__()`` method to produce SQL clauses between
sets of correlated columns.
"""
__slots__ = "_all_columns"
def __init__(self, *columns):
super(ColumnCollection, self).__init__()
object.__setattr__(self, "_all_columns", [])
for c in columns:
self.add(c)
def __str__(self):
return repr([str(c) for c in self])
def replace(self, column):
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
e.g.::
t = Table('sometable', metadata, Column('col1', Integer))
t.columns.replace(Column('col1', Integer, key='columnone'))
will remove the original 'col1' from the collection, and add
the new column under the name 'columnname'.
Used by schema.Column to override columns during table reflection.
"""
remove_col = None
if column.name in self and column.key != column.name:
other = self[column.name]
if other.name == other.key:
remove_col = other
del self._data[other.key]
if column.key in self._data:
remove_col = self._data[column.key]
self._data[column.key] = column
if remove_col is not None:
self._all_columns[:] = [
column if c is remove_col else c for c in self._all_columns
]
else:
self._all_columns.append(column)
def add(self, column):
"""Add a column to this collection.
The key attribute of the column will be used as the hash key
for this dictionary.
"""
if not column.key:
raise exc.ArgumentError(
"Can't add unnamed column to column collection"
)
self[column.key] = column
def __delitem__(self, key):
raise NotImplementedError()
def __setattr__(self, key, obj):
raise NotImplementedError()
def __setitem__(self, key, value):
if key in self:
# this warning is primarily to catch select() statements
# which have conflicting column names in their exported
# columns collection
existing = self[key]
if existing is value:
return
if not existing.shares_lineage(value):
util.warn(
"Column %r on table %r being replaced by "
"%r, which has the same key. Consider "
"use_labels for select() statements."
% (key, getattr(existing, "table", None), value)
)
# pop out memoized proxy_set as this
# operation may very well be occurring
# in a _make_proxy operation
util.memoized_property.reset(value, "proxy_set")
self._all_columns.append(value)
self._data[key] = value
def clear(self):
raise NotImplementedError()
def remove(self, column):
del self._data[column.key]
self._all_columns[:] = [
c for c in self._all_columns if c is not column
]
def update(self, iter_):
cols = list(iter_)
all_col_set = set(self._all_columns)
self._all_columns.extend(
c for label, c in cols if c not in all_col_set
)
self._data.update((label, c) for label, c in cols)
def extend(self, iter_):
cols = list(iter_)
all_col_set = set(self._all_columns)
self._all_columns.extend(c for c in cols if c not in all_col_set)
self._data.update((c.key, c) for c in cols)
__hash__ = None
@util.dependencies("sqlalchemy.sql.elements")
def __eq__(self, elements, other):
l = []
for c in getattr(other, "_all_columns", other):
for local in self._all_columns:
if c.shares_lineage(local):
l.append(c == local)
return elements.and_(*l)
def __contains__(self, other):
if not isinstance(other, util.string_types):
raise exc.ArgumentError("__contains__ requires a string argument")
return util.OrderedProperties.__contains__(self, other)
def __getstate__(self):
return {"_data": self._data, "_all_columns": self._all_columns}
def __setstate__(self, state):
object.__setattr__(self, "_data", state["_data"])
object.__setattr__(self, "_all_columns", state["_all_columns"])
def contains_column(self, col):
return col in set(self._all_columns)
def as_immutable(self):
return ImmutableColumnCollection(self._data, self._all_columns)
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
def __init__(self, data, all_columns):
util.ImmutableProperties.__init__(self, data)
object.__setattr__(self, "_all_columns", all_columns)
extend = remove = util.ImmutableProperties._immutable
class ColumnSet(util.ordered_column_set):
def contains_column(self, col):
return col in self
def extend(self, cols):
for col in cols:
self.add(col)
def __add__(self, other):
return list(self) + list(other)
@util.dependencies("sqlalchemy.sql.elements")
def __eq__(self, elements, other):
l = []
for c in other:
for local in self:
if c.shares_lineage(local):
l.append(c == local)
return elements.and_(*l)
def __hash__(self):
return hash(tuple(x for x in self))
def _bind_or_error(schemaitem, msg=None):
bind = schemaitem.bind
if not bind:
name = schemaitem.__class__.__name__
label = getattr(
schemaitem, "fullname", getattr(schemaitem, "name", None)
)
if label:
item = "%s object %r" % (name, label)
else:
item = "%s object" % name
if msg is None:
msg = (
"%s is not bound to an Engine or Connection. "
"Execution can not proceed without a database to execute "
"against." % item
)
raise exc.UnboundExecutionError(msg)
return bind

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,799 @@
# sql/crud.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Functions used by compiler.py to determine the parameters rendered
within INSERT and UPDATE statements.
"""
import operator
from . import dml
from . import elements
from .. import exc
from .. import util
REQUIRED = util.symbol(
"REQUIRED",
"""
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.
This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.
""",
)
ISINSERT = util.symbol("ISINSERT")
ISUPDATE = util.symbol("ISUPDATE")
ISDELETE = util.symbol("ISDELETE")
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
restore_isinsert = compiler.isinsert
restore_isupdate = compiler.isupdate
restore_isdelete = compiler.isdelete
should_restore = (
restore_isinsert or restore_isupdate or restore_isdelete
) or len(compiler.stack) > 1
if local_stmt_type is ISINSERT:
compiler.isupdate = False
compiler.isinsert = True
elif local_stmt_type is ISUPDATE:
compiler.isupdate = True
compiler.isinsert = False
elif local_stmt_type is ISDELETE:
if not should_restore:
compiler.isdelete = True
else:
assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"
try:
if local_stmt_type in (ISINSERT, ISUPDATE):
return _get_crud_params(compiler, stmt, **kw)
finally:
if should_restore:
compiler.isinsert = restore_isinsert
compiler.isupdate = restore_isupdate
compiler.isdelete = restore_isdelete
def _get_crud_params(compiler, stmt, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
Also generates the Compiled object's postfetch, prefetch, and
returning column collections, used for default handling and ultimately
populating the ResultProxy's prefetch_cols() and postfetch_cols()
collections.
"""
compiler.postfetch = []
compiler.insert_prefetch = []
compiler.update_prefetch = []
compiler.returning = []
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and stmt.parameters is None:
return [
(c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
if stmt._has_multi_parameters:
stmt_parameters = stmt.parameters[0]
else:
stmt_parameters = stmt.parameters
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
(
_column_as_key,
_getattr_col_key,
_col_bind_name,
) = _key_getters_for_crud_column(compiler, stmt)
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
else:
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
if not stmt_parameters or key not in stmt_parameters
)
# create a list of column assignment clauses as tuples
values = []
if stmt_parameters is not None:
_get_stmt_parameters_params(
compiler, parameters, stmt_parameters, _column_as_key, values, kw
)
check_columns = {}
# special logic that only occurs for multi-table UPDATE
# statements
if compiler.isupdate and stmt._extra_froms and stmt_parameters:
_get_multitable_params(
compiler,
stmt,
stmt_parameters,
check_columns,
_col_bind_name,
_getattr_col_key,
values,
kw,
)
if compiler.isinsert and stmt.select_names:
_scan_insert_from_select_cols(
compiler,
stmt,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
)
else:
_scan_cols(
compiler,
stmt,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
)
if parameters and stmt_parameters:
check = (
set(parameters)
.intersection(_column_as_key(k) for k in stmt_parameters)
.difference(check_columns)
)
if check:
raise exc.CompileError(
"Unconsumed column names: %s"
% (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
values = _extend_values_for_multiparams(compiler, stmt, values, kw)
return values
def _create_bind_param(
compiler, col, value, process=True, required=False, name=None, **kw
):
if name is None:
name = col.key
bindparam = elements.BindParameter(
name, value, type_=col.type, required=required
)
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
return bindparam
def _key_getters_for_crud_column(compiler, stmt):
if compiler.isupdate and stmt._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
# the "main" table of the statement remains unqualified,
# allowing the most compatibility with a non-multi-table
# statement.
_et = set(stmt._extra_froms)
def _column_as_key(key):
str_key = elements._column_as_key(key)
if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
def _getattr_col_key(col):
if col.table in _et:
return (col.table.name, col.key)
else:
return col.key
def _col_bind_name(col):
if col.table in _et:
return "%s_%s" % (col.table.name, col.key)
else:
return col.key
else:
_column_as_key = elements._column_as_key
_getattr_col_key = _col_bind_name = operator.attrgetter("key")
return _column_as_key, _getattr_col_key, _col_bind_name
def _scan_insert_from_select_cols(
compiler,
stmt,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
):
(
need_pks,
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt)
cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
compiler._insert_from_select = stmt.select
add_select_cols = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
if col not in col_set and col.default:
cols.append(col)
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
parameters.pop(col_key)
values.append((c, None))
else:
_append_param_insert_select_hasdefault(
compiler, stmt, c, add_select_cols, kw
)
if add_select_cols:
values.extend(add_select_cols)
compiler._insert_from_select = compiler._insert_from_select._generate()
compiler._insert_from_select._raw_columns = tuple(
compiler._insert_from_select._raw_columns
) + tuple(expr for col, expr in add_select_cols)
def _scan_cols(
compiler,
stmt,
parameters,
_getattr_col_key,
_column_as_key,
_col_bind_name,
check_columns,
values,
kw,
):
(
need_pks,
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
) = _get_returning_modifiers(compiler, stmt)
if stmt._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in stmt._parameter_ordering
]
ordered_keys = set(parameter_ordering)
cols = [stmt.table.c[key] for key in parameter_ordering] + [
c for c in stmt.table.c if c.key not in ordered_keys
]
else:
cols = stmt.table.columns
for c in cols:
col_key = _getattr_col_key(c)
if col_key in parameters and col_key not in check_columns:
_append_param_parameter(
compiler,
stmt,
c,
col_key,
parameters,
_col_bind_name,
implicit_returning,
implicit_return_defaults,
values,
kw,
)
elif compiler.isinsert:
if (
c.primary_key
and need_pks
and (
implicit_returning
or not postfetch_lastrowid
or c is not stmt.table._autoincrement_column
)
):
if implicit_returning:
_append_param_insert_pk_returning(
compiler, stmt, c, values, kw
)
else:
_append_param_insert_pk(compiler, stmt, c, values, kw)
elif c.default is not None:
_append_param_insert_hasdefault(
compiler, stmt, c, implicit_return_defaults, values, kw
)
elif c.server_default is not None:
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif (
c.primary_key
and c is not stmt.table._autoincrement_column
and not c.nullable
):
_warn_pk_with_no_anticipated_value(c)
elif compiler.isupdate:
_append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
)
def _append_param_parameter(
compiler,
stmt,
c,
col_key,
parameters,
_col_bind_name,
implicit_returning,
implicit_return_defaults,
values,
kw,
):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = _create_bind_param(
compiler,
c,
value,
required=value is REQUIRED,
name=_col_bind_name(c)
if not stmt._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
if isinstance(value, elements.BindParameter) and value.type._isnull:
value = value._clone()
value.type = c.type
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
# postfetch specifically means, "we can SELECT the row we just
# inserted by primary key to get back the server generated
# defaults". so by definition this can't be used to get the primary
# key value back, because we need to have it ahead of time.
if not c.primary_key:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))
def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
"""Create a primary key expression in the INSERT statement and
possibly a RETURNING clause for it.
If the column has a Python-side default, we will create a bound
parameter for it and "pre-execute" the Python function. If
the column has a SQL expression default, or is a sequence,
we will add it directly into the INSERT statement and add a
RETURNING element to get the new value. If the column has a
server side default or is marked as the "autoincrement" column,
we will add a RETRUNING element to get at the value.
If all the above tests fail, that indicates a primary key column with no
noted default generation capabilities that has no parameter passed;
raise an exception.
"""
if c.default is not None:
if c.default.is_sequence:
if compiler.dialect.supports_sequences and (
not c.default.optional
or not compiler.dialect.sequences_optional
):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
(c, compiler.process(c.default.arg.self_group(), **kw))
)
compiler.returning.append(c)
else:
values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
_warn_pk_with_no_anticipated_value(c)
def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
compiler.insert_prefetch.append(c)
return param
def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
param = _create_bind_param(compiler, c, None, process=process, name=name)
compiler.update_prefetch.append(c)
return param
class _multiparam_column(elements.ColumnElement):
_is_multiparam_column = True
def __init__(self, original, index):
self.index = index
self.key = "%s_m%d" % (original.key, index + 1)
self.original = original
self.default = original.default
self.type = original.type
def __eq__(self, other):
return (
isinstance(other, _multiparam_column)
and other.key == self.key
and other.original == self.original
)
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
if not c.default:
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
"a Python-side value or SQL expression is required" % c
)
elif c.default.is_clause_element:
return compiler.process(c.default.arg.self_group(), **kw)
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
return _create_insert_prefetch_bind_param(compiler, col)
else:
return _create_update_prefetch_bind_param(compiler, col)
def _append_param_insert_pk(compiler, stmt, c, values, kw):
"""Create a bound parameter in the INSERT statement to receive a
'prefetched' default value.
The 'prefetched' value indicates that we are to invoke a Python-side
default function or expliclt SQL expression before the INSERT statement
proceeds, so that we have a primary key value available.
if the column has no noted default generation capabilities, it has
no value passed in either; raise an exception.
"""
if (
# column has a Python-side default
c.default is not None
and (
# and it won't be a Sequence
not c.default.is_sequence
or compiler.dialect.supports_sequences
)
) or (
# column is the "autoincrement column"
c is stmt.table._autoincrement_column
and (
# and it's either a "sequence" or a
# pre-executable "autoincrement" sequence
compiler.dialect.supports_sequences
or compiler.dialect.preexecute_autoincrement_sequences
)
):
values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
_warn_pk_with_no_anticipated_value(c)
def _append_param_insert_hasdefault(
compiler, stmt, c, implicit_return_defaults, values, kw
):
if c.default.is_sequence:
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
elif c.default.is_clause_element:
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if c.default.is_sequence:
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
proc = c.default
values.append((c, proc.next_value()))
elif c.default.is_clause_element:
proc = c.default.arg.self_group()
values.append((c, proc))
else:
values.append(
(c, _create_insert_prefetch_bind_param(compiler, c, process=False))
)
def _append_param_update(
compiler, stmt, c, implicit_return_defaults, values, kw
):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
(c, compiler.process(c.onupdate.arg.self_group(), **kw))
)
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
values.append((c, _create_update_prefetch_bind_param(compiler, c)))
elif c.server_onupdate is not None:
if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
elif (
implicit_return_defaults
and stmt._return_defaults is not True
and c in implicit_return_defaults
):
compiler.returning.append(c)
def _get_multitable_params(
compiler,
stmt,
stmt_parameters,
check_columns,
_col_bind_name,
_getattr_col_key,
values,
kw,
):
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
for c, param in stmt_parameters.items()
)
affected_tables = set()
for t in stmt._extra_froms:
for c in t.c:
if c in normalized_params:
affected_tables.add(t)
check_columns[_getattr_col_key(c)] = c
value = normalized_params[c]
if elements._is_literal(value):
value = _create_bind_param(
compiler,
c,
value,
required=value is REQUIRED,
name=_col_bind_name(c),
)
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
values.append((c, value))
# determine tables which are actually to be updated - process onupdate
# and server_onupdate for these
for t in affected_tables:
for c in t.c:
if c in normalized_params:
continue
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
(
c,
compiler.process(
c.onupdate.arg.self_group(), **kw
),
)
)
compiler.postfetch.append(c)
else:
values.append(
(
c,
_create_update_prefetch_bind_param(
compiler, c, name=_col_bind_name(c)
),
)
)
elif c.server_onupdate is not None:
compiler.postfetch.append(c)
def _extend_values_for_multiparams(compiler, stmt, values, kw):
values_0 = values
values = [values]
for i, row in enumerate(stmt.parameters[1:]):
extension = []
for (col, param) in values_0:
if col in row or col.key in row:
key = col if col in row else col.key
if elements._is_literal(row[key]):
new_param = _create_bind_param(
compiler,
col,
row[key],
name="%s_m%d" % (col.key, i + 1),
**kw
)
else:
new_param = compiler.process(row[key].self_group(), **kw)
else:
new_param = _process_multiparam_default_bind(
compiler, stmt, col, i, kw
)
extension.append((col, new_param))
values.append(extension)
return values
def _get_stmt_parameters_params(
compiler, parameters, stmt_parameters, _column_as_key, values, kw
):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
parameters.setdefault(colkey, v)
else:
# a non-Column expression on the left side;
# add it to values() in an "as-is" state,
# coercing right side to bound param
if elements._is_literal(v):
v = compiler.process(
elements.BindParameter(None, v, type_=k.type), **kw
)
else:
v = compiler.process(v.self_group(), **kw)
values.append((k, v))
def _get_returning_modifiers(compiler, stmt):
need_pks = (
compiler.isinsert
and not compiler.inline
and not stmt._returning
and not stmt._has_multi_parameters
)
implicit_returning = (
need_pks
and compiler.dialect.implicit_returning
and stmt.table.implicit_returning
)
if compiler.isinsert:
implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compiler.isupdate:
implicit_return_defaults = (
compiler.dialect.implicit_returning
and stmt.table.implicit_returning
and stmt._return_defaults
)
else:
# this line is unused, currently we are always
# isinsert or isupdate
implicit_return_defaults = False # pragma: no cover
if implicit_return_defaults:
if stmt._return_defaults is True:
implicit_return_defaults = set(stmt.table.c)
else:
implicit_return_defaults = set(stmt._return_defaults)
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
return (
need_pks,
implicit_returning,
implicit_return_defaults,
postfetch_lastrowid,
)
def _warn_pk_with_no_anticipated_value(c):
msg = (
"Column '%s.%s' is marked as a member of the "
"primary key for table '%s', "
"but has no Python-side or server-side default generator indicated, "
"nor does it indicate 'autoincrement=True' or 'nullable=True', "
"and no explicit value is passed. "
"Primary key columns typically may not store NULL."
% (c.table.fullname, c.name, c.table.fullname)
)
if len(c.table.primary_key) > 1:
msg += (
" Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
"indicated explicitly for composite (e.g. multicolumn) primary "
"keys if AUTO_INCREMENT/SERIAL/IDENTITY "
"behavior is expected for one of the columns in the primary key. "
"CREATE TABLE statements are impacted by this change as well on "
"most backends."
)
util.warn(msg)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,370 @@
# sql/default_comparator.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Default implementation of SQL comparison operations.
"""
from . import operators
from . import type_api
from .elements import _clause_element_as_expr
from .elements import _const_expr
from .elements import _is_literal
from .elements import _literal_as_text
from .elements import and_
from .elements import BinaryExpression
from .elements import BindParameter
from .elements import ClauseElement
from .elements import ClauseList
from .elements import collate
from .elements import CollectionAggregate
from .elements import ColumnElement
from .elements import False_
from .elements import Null
from .elements import or_
from .elements import TextClause
from .elements import True_
from .elements import Tuple
from .elements import UnaryExpression
from .elements import Visitable
from .selectable import Alias
from .selectable import ScalarSelect
from .selectable import Selectable
from .selectable import SelectBase
from .. import exc
from .. import util
def _boolean_compare(
expr,
op,
obj,
negate=None,
reverse=False,
_python_is_types=(util.NoneType, bool),
result_type=None,
**kwargs
):
if result_type is None:
result_type = type_api.BOOLEANTYPE
if isinstance(obj, _python_is_types + (Null, True_, False_)):
# allow x ==/!= True/False to be treated as a literal.
# this comes out to "== / != true/false" or "1/0" if those
# constants aren't supported and works on all platforms
if op in (operators.eq, operators.ne) and isinstance(
obj, (bool, True_, False_)
):
return BinaryExpression(
expr,
_literal_as_text(obj),
op,
type_=result_type,
negate=negate,
modifiers=kwargs,
)
elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
return BinaryExpression(
expr,
_literal_as_text(obj),
op,
type_=result_type,
negate=negate,
modifiers=kwargs,
)
else:
# all other None/True/False uses IS, IS NOT
if op in (operators.eq, operators.is_):
return BinaryExpression(
expr,
_const_expr(obj),
operators.is_,
negate=operators.isnot,
type_=result_type,
)
elif op in (operators.ne, operators.isnot):
return BinaryExpression(
expr,
_const_expr(obj),
operators.isnot,
negate=operators.is_,
type_=result_type,
)
else:
raise exc.ArgumentError(
"Only '=', '!=', 'is_()', 'isnot()', "
"'is_distinct_from()', 'isnot_distinct_from()' "
"operators can be used with None/True/False"
)
else:
obj = _check_literal(expr, op, obj)
if reverse:
return BinaryExpression(
obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs
)
else:
return BinaryExpression(
expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs
)
def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
if result_type is None:
if op.return_type:
result_type = op.return_type
elif op.is_comparison:
result_type = type_api.BOOLEANTYPE
return _binary_operate(
expr, op, obj, reverse=reverse, result_type=result_type, **kw
)
def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
obj = _check_literal(expr, op, obj)
if reverse:
left, right = obj, expr
else:
left, right = expr, obj
if result_type is None:
op, result_type = left.comparator._adapt_expression(
op, right.comparator
)
return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
def _conjunction_operate(expr, op, other, **kw):
if op is operators.and_:
return and_(expr, other)
elif op is operators.or_:
return or_(expr, other)
else:
raise NotImplementedError()
def _scalar(expr, op, fn, **kw):
return fn(expr)
def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
if isinstance(seq_or_selectable, ScalarSelect):
return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op)
elif isinstance(seq_or_selectable, SelectBase):
# TODO: if we ever want to support (x, y, z) IN (select x,
# y, z from table), we would need a multi-column version of
# as_scalar() to produce a multi- column selectable that
# does not export itself as a FROM clause
return _boolean_compare(
expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw
)
elif isinstance(seq_or_selectable, (Selectable, TextClause)):
return _boolean_compare(
expr, op, seq_or_selectable, negate=negate_op, **kw
)
elif isinstance(seq_or_selectable, ClauseElement):
if (
isinstance(seq_or_selectable, BindParameter)
and seq_or_selectable.expanding
):
if isinstance(expr, Tuple):
seq_or_selectable = seq_or_selectable._with_expanding_in_types(
[elem.type for elem in expr]
)
return _boolean_compare(
expr, op, seq_or_selectable, negate=negate_op
)
else:
raise exc.InvalidRequestError(
"in_() accepts"
" either a list of expressions, "
'a selectable, or an "expanding" bound parameter: %r'
% seq_or_selectable
)
# Handle non selectable arguments as sequences
args = []
for o in seq_or_selectable:
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
raise exc.InvalidRequestError(
"in_() accepts"
" either a list of expressions, "
'a selectable, or an "expanding" bound parameter: %r' % o
)
elif o is None:
o = Null()
else:
o = expr._bind_param(op, o)
args.append(o)
if len(args) == 0:
op, negate_op = (
(operators.empty_in_op, operators.empty_notin_op)
if op is operators.in_op
else (operators.empty_notin_op, operators.empty_in_op)
)
return _boolean_compare(
expr, op, ClauseList(*args).self_group(against=op), negate=negate_op
)
def _getitem_impl(expr, op, other, **kw):
if isinstance(expr.type, type_api.INDEXABLE):
other = _check_literal(expr, op, other)
return _binary_operate(expr, op, other, **kw)
else:
_unsupported_impl(expr, op, other, **kw)
def _unsupported_impl(expr, op, *arg, **kw):
raise NotImplementedError(
"Operator '%s' is not supported on " "this expression" % op.__name__
)
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
return expr._negate()
def _neg_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__neg__`."""
return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
def _match_impl(expr, op, other, **kw):
"""See :meth:`.ColumnOperators.match`."""
return _boolean_compare(
expr,
operators.match_op,
_check_literal(expr, operators.match_op, other),
result_type=type_api.MATCHTYPE,
negate=operators.notmatch_op
if op is operators.match_op
else operators.match_op,
**kw
)
def _distinct_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(
expr, operator=operators.distinct_op, type_=expr.type
)
def _between_impl(expr, op, cleft, cright, **kw):
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
expr,
ClauseList(
_check_literal(expr, operators.and_, cleft),
_check_literal(expr, operators.and_, cright),
operator=operators.and_,
group=False,
group_contents=False,
),
op,
negate=operators.notbetween_op
if op is operators.between_op
else operators.between_op,
modifiers=kw,
)
def _collate_impl(expr, op, other, **kw):
return collate(expr, other)
# a mapping of operators with the method they use, along with
# their negated operator for comparison operators
operator_lookup = {
"and_": (_conjunction_operate,),
"or_": (_conjunction_operate,),
"inv": (_inv_impl,),
"add": (_binary_operate,),
"mul": (_binary_operate,),
"sub": (_binary_operate,),
"div": (_binary_operate,),
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
"custom_op": (_custom_op_operate,),
"json_path_getitem_op": (_binary_operate,),
"json_getitem_op": (_binary_operate,),
"concat_op": (_binary_operate,),
"any_op": (_scalar, CollectionAggregate._create_any),
"all_op": (_scalar, CollectionAggregate._create_all),
"lt": (_boolean_compare, operators.ge),
"le": (_boolean_compare, operators.gt),
"ne": (_boolean_compare, operators.eq),
"gt": (_boolean_compare, operators.le),
"ge": (_boolean_compare, operators.lt),
"eq": (_boolean_compare, operators.ne),
"is_distinct_from": (_boolean_compare, operators.isnot_distinct_from),
"isnot_distinct_from": (_boolean_compare, operators.is_distinct_from),
"like_op": (_boolean_compare, operators.notlike_op),
"ilike_op": (_boolean_compare, operators.notilike_op),
"notlike_op": (_boolean_compare, operators.like_op),
"notilike_op": (_boolean_compare, operators.ilike_op),
"contains_op": (_boolean_compare, operators.notcontains_op),
"startswith_op": (_boolean_compare, operators.notstartswith_op),
"endswith_op": (_boolean_compare, operators.notendswith_op),
"desc_op": (_scalar, UnaryExpression._create_desc),
"asc_op": (_scalar, UnaryExpression._create_asc),
"nullsfirst_op": (_scalar, UnaryExpression._create_nullsfirst),
"nullslast_op": (_scalar, UnaryExpression._create_nullslast),
"in_op": (_in_impl, operators.notin_op),
"notin_op": (_in_impl, operators.in_op),
"is_": (_boolean_compare, operators.is_),
"isnot": (_boolean_compare, operators.isnot),
"collate": (_collate_impl,),
"match_op": (_match_impl,),
"notmatch_op": (_match_impl,),
"distinct_op": (_distinct_impl,),
"between_op": (_between_impl,),
"notbetween_op": (_between_impl,),
"neg": (_neg_impl,),
"getitem": (_getitem_impl,),
"lshift": (_unsupported_impl,),
"rshift": (_unsupported_impl,),
"contains": (_unsupported_impl,),
}
def _check_literal(expr, operator, other, bindparam_type=None):
if isinstance(other, (ColumnElement, TextClause)):
if isinstance(other, BindParameter) and other.type._isnull:
other = other._clone()
other.type = expr.type
return other
elif hasattr(other, "__clause_element__"):
other = other.__clause_element__()
elif isinstance(other, type_api.TypeEngine.Comparator):
other = other.expr
if isinstance(other, (SelectBase, Alias)):
return other.as_scalar()
elif not isinstance(other, Visitable):
return expr._bind_param(operator, other, type_=bindparam_type)
else:
return other

View File

@@ -0,0 +1,906 @@
# sql/dml.py
# Copyright (C) 2009-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""
Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
from .base import _from_objects
from .base import _generative
from .base import DialectKWArgs
from .base import Executable
from .elements import _clone
from .elements import _column_as_key
from .elements import _literal_as_text
from .elements import and_
from .elements import ClauseElement
from .elements import Null
from .selectable import _interpret_as_from
from .selectable import _interpret_as_select
from .selectable import HasCTE
from .selectable import HasPrefixes
from .. import exc
from .. import util
class UpdateBase(
HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement
):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
"""
__visit_name__ = "update_base"
_execution_options = Executable._execution_options.union(
{"autocommit": True}
)
_hints = util.immutabledict()
_parameter_ordering = None
_prefixes = ()
named_with_column = False
def _process_colparams(self, parameters):
def process_single(p):
if isinstance(p, (list, tuple)):
return dict((c.key, pval) for c, pval in zip(self.table.c, p))
else:
return p
if self._preserve_parameter_order and parameters is not None:
if not isinstance(parameters, list) or (
parameters and not isinstance(parameters[0], tuple)
):
raise ValueError(
"When preserve_parameter_order is True, "
"values() only accepts a list of 2-tuples"
)
self._parameter_ordering = [key for key, value in parameters]
return dict(parameters), False
if (
isinstance(parameters, (list, tuple))
and parameters
and isinstance(parameters[0], (list, tuple, dict))
):
if not self._supports_multi_parameters:
raise exc.InvalidRequestError(
"This construct does not support "
"multiple parameter sets."
)
return [process_single(p) for p in parameters], True
else:
return process_single(parameters), False
def params(self, *arg, **kw):
"""Set the parameters for the statement.
This method raises ``NotImplementedError`` on the base class,
and is overridden by :class:`.ValuesBase` to provide the
SET/VALUES clause of UPDATE and INSERT.
"""
raise NotImplementedError(
"params() is not supported for INSERT/UPDATE/DELETE statements."
" To set the values for an INSERT or UPDATE statement, use"
" stmt.values(**parameters)."
)
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
or a :class:`.Table` associated with it.
"""
return self._bind or self.table.bind
def _set_bind(self, bind):
self._bind = bind
bind = property(bind, _set_bind)
@_generative
def returning(self, *cols):
r"""Add a :term:`RETURNING` or equivalent clause to this statement.
e.g.::
stmt = table.update().\
where(table.c.data == 'value').\
values(status='X').\
returning(table.c.server_flag,
table.c.updated_timestamp)
for server_flag, updated_timestamp in connection.execute(stmt):
print(server_flag, updated_timestamp)
The given collection of column expressions should be derived from
the table that is
the target of the INSERT, UPDATE, or DELETE. While :class:`.Column`
objects are typical, the elements can also be expressions::
stmt = table.insert().returning(
(table.c.first_name + " " + table.c.last_name).
label('fullname'))
Upon compilation, a RETURNING clause, or database equivalent,
will be rendered within the statement. For INSERT and UPDATE,
the values are the newly inserted/updated values. For DELETE,
the values are those of the rows which were deleted.
Upon execution, the values of the columns to be returned are made
available via the result set and can be iterated using
:meth:`.ResultProxy.fetchone` and similar. For DBAPIs which do not
natively support returning values (i.e. cx_oracle), SQLAlchemy will
approximate this behavior at the result level so that a reasonable
amount of behavioral neutrality is provided.
Note that not all databases/DBAPIs
support RETURNING. For those backends with no support,
an exception is raised upon compilation and/or execution.
For those who do support it, the functionality across backends
varies greatly, including restrictions on executemany()
and other statements which return multiple rows. Please
read the documentation notes for the database in use in
order to determine the availability of RETURNING.
.. seealso::
:meth:`.ValuesBase.return_defaults` - an alternative method tailored
towards efficient fetching of server-side defaults and triggers
for single-row INSERTs or UPDATEs.
"""
self._returning = cols
@_generative
def with_hint(self, text, selectable=None, dialect_name="*"):
"""Add a table hint for a single table to this
INSERT/UPDATE/DELETE statement.
.. note::
:meth:`.UpdateBase.with_hint` currently applies only to
Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use
:meth:`.UpdateBase.prefix_with`.
The text of the hint is rendered in the appropriate
location for the database backend in use, relative
to the :class:`.Table` that is the subject of this
statement, or optionally to that of the given
:class:`.Table` passed as the ``selectable`` argument.
The ``dialect_name`` option will limit the rendering of a particular
hint to a particular backend. Such as, to add a hint
that only takes effect for SQL Server::
mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql")
:param text: Text of the hint.
:param selectable: optional :class:`.Table` that specifies
an element of the FROM clause within an UPDATE or DELETE
to be the subject of the hint - applies only to certain backends.
:param dialect_name: defaults to ``*``, if specified as the name
of a particular dialect, will apply these hints only when
that dialect is in use.
"""
if selectable is None:
selectable = self.table
self._hints = self._hints.union({(selectable, dialect_name): text})
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
INSERT and UPDATE constructs."""
__visit_name__ = "values_base"
_supports_multi_parameters = False
_has_multi_parameters = False
_preserve_parameter_order = False
select = None
_post_values_clause = None
def __init__(self, table, values, prefixes):
self.table = _interpret_as_from(table)
self.parameters, self._has_multi_parameters = self._process_colparams(
values
)
if prefixes:
self._setup_prefixes(prefixes)
@_generative
def values(self, *args, **kwargs):
r"""specify a fixed VALUES clause for an INSERT statement, or the SET
clause for an UPDATE.
Note that the :class:`.Insert` and :class:`.Update` constructs support
per-execution time formatting of the VALUES and/or SET clauses,
based on the arguments passed to :meth:`.Connection.execute`.
However, the :meth:`.ValuesBase.values` method can be used to "fix" a
particular set of parameters into the statement.
Multiple calls to :meth:`.ValuesBase.values` will produce a new
construct, each one with the parameter list modified to include
the new parameters sent. In the typical case of a single
dictionary of parameters, the newly passed keys will replace
the same keys in the previous construct. In the case of a list-based
"multiple values" construct, each new list of values is extended
onto the existing list of values.
:param \**kwargs: key value pairs representing the string key
of a :class:`.Column` mapped to the value to be rendered into the
VALUES or SET clause::
users.insert().values(name="some name")
users.update().where(users.c.id==5).values(name="some name")
:param \*args: As an alternative to passing key/value parameters,
a dictionary, tuple, or list of dictionaries or tuples can be passed
as a single positional argument in order to form the VALUES or
SET clause of the statement. The forms that are accepted vary
based on whether this is an :class:`.Insert` or an :class:`.Update`
construct.
For either an :class:`.Insert` or :class:`.Update` construct, a
single dictionary can be passed, which works the same as that of
the kwargs form::
users.insert().values({"name": "some name"})
users.update().values({"name": "some new name"})
Also for either form but more typically for the :class:`.Insert`
construct, a tuple that contains an entry for every column in the
table is also accepted::
users.insert().values((5, "some name"))
The :class:`.Insert` construct also supports being passed a list
of dictionaries or full-table-tuples, which on the server will
render the less common SQL syntax of "multiple values" - this
syntax is supported on backends such as SQLite, PostgreSQL, MySQL,
but not necessarily others::
users.insert().values([
{"name": "some name"},
{"name": "some other name"},
{"name": "yet another name"},
])
The above form would render a multiple VALUES statement similar to::
INSERT INTO users (name) VALUES
(:name_1),
(:name_2),
(:name_3)
It is essential to note that **passing multiple values is
NOT the same as using traditional executemany() form**. The above
syntax is a **special** syntax not typically used. To emit an
INSERT statement against multiple rows, the normal method is
to pass a multiple values list to the :meth:`.Connection.execute`
method, which is supported by all database backends and is generally
more efficient for a very large number of parameters.
.. seealso::
:ref:`execute_multiple` - an introduction to
the traditional Core method of multiple parameter set
invocation for INSERTs and other statements.
.. versionchanged:: 1.0.0 an INSERT that uses a multiple-VALUES
clause, even a list of length one,
implies that the :paramref:`.Insert.inline` flag is set to
True, indicating that the statement will not attempt to fetch
the "last inserted primary key" or other defaults. The
statement deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not
apply.
.. versionchanged:: 1.0.0 A multiple-VALUES INSERT now supports
columns with Python side default values and callables in the
same way as that of an "executemany" style of invocation; the
callable is invoked for each row. See :ref:`bug_3288`
for other details.
The :class:`.Update` construct supports a special form which is a
list of 2-tuples, which when provided must be passed in conjunction
with the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
parameter.
This form causes the UPDATE statement to render the SET clauses
using the order of parameters given to :meth:`.Update.values`, rather
than the ordering of columns given in the :class:`.Table`.
.. versionadded:: 1.0.10 - added support for parameter-ordered
UPDATE statements via the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
flag.
.. seealso::
:ref:`updates_order_parameters` - full example of the
:paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`
flag
.. seealso::
:ref:`inserts_and_updates` - SQL Expression
Language Tutorial
:func:`~.expression.insert` - produce an ``INSERT`` statement
:func:`~.expression.update` - produce an ``UPDATE`` statement
"""
if self.select is not None:
raise exc.InvalidRequestError(
"This construct already inserts from a SELECT"
)
if self._has_multi_parameters and kwargs:
raise exc.InvalidRequestError(
"This construct already has multiple parameter sets."
)
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
"dictionaries/tuples is accepted positionally."
)
v = args[0]
else:
v = {}
if self.parameters is None:
(
self.parameters,
self._has_multi_parameters,
) = self._process_colparams(v)
else:
if self._has_multi_parameters:
self.parameters = list(self.parameters)
p, self._has_multi_parameters = self._process_colparams(v)
if not self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
"formats in one statement"
)
self.parameters.extend(p)
else:
self.parameters = self.parameters.copy()
p, self._has_multi_parameters = self._process_colparams(v)
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
"formats in one statement"
)
self.parameters.update(p)
if kwargs:
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't pass kwargs and multiple parameter sets "
"simultaneously"
)
else:
self.parameters.update(kwargs)
@_generative
def return_defaults(self, *cols):
"""Make use of a :term:`RETURNING` clause for the purpose
of fetching server-side expressions and defaults.
E.g.::
stmt = table.insert().values(data='newdata').return_defaults()
result = connection.execute(stmt)
server_created_at = result.returned_defaults['created_at']
When used against a backend that supports RETURNING, all column
values generated by SQL expression or server-side-default will be
added to any existing RETURNING clause, provided that
:meth:`.UpdateBase.returning` is not used simultaneously. The column
values will then be available on the result using the
:attr:`.ResultProxy.returned_defaults` accessor as a dictionary,
referring to values keyed to the :class:`.Column` object as well as
its ``.key``.
This method differs from :meth:`.UpdateBase.returning` in these ways:
1. :meth:`.ValuesBase.return_defaults` is only intended for use with
an INSERT or an UPDATE statement that matches exactly one row.
While the RETURNING construct in the general sense supports
multiple rows for a multi-row UPDATE or DELETE statement, or for
special cases of INSERT that return multiple rows (e.g. INSERT from
SELECT, multi-valued VALUES clause),
:meth:`.ValuesBase.return_defaults` is intended only for an
"ORM-style" single-row INSERT/UPDATE statement. The row returned
by the statement is also consumed implicitly when
:meth:`.ValuesBase.return_defaults` is used. By contrast,
:meth:`.UpdateBase.returning` leaves the RETURNING result-set
intact with a collection of any number of rows.
2. It is compatible with the existing logic to fetch auto-generated
primary key values, also known as "implicit returning". Backends
that support RETURNING will automatically make use of RETURNING in
order to fetch the value of newly generated primary keys; while the
:meth:`.UpdateBase.returning` method circumvents this behavior,
:meth:`.ValuesBase.return_defaults` leaves it intact.
3. It can be called against any backend. Backends that don't support
RETURNING will skip the usage of the feature, rather than raising
an exception. The return value of
:attr:`.ResultProxy.returned_defaults` will be ``None``
:meth:`.ValuesBase.return_defaults` is used by the ORM to provide
an efficient implementation for the ``eager_defaults`` feature of
:func:`.mapper`.
:param cols: optional list of column key names or :class:`.Column`
objects. If omitted, all column expressions evaluated on the server
are added to the returning list.
.. versionadded:: 0.9.0
.. seealso::
:meth:`.UpdateBase.returning`
:attr:`.ResultProxy.returned_defaults`
"""
self._return_defaults = cols or True
class Insert(ValuesBase):
"""Represent an INSERT construct.
The :class:`.Insert` object is created using the
:func:`~.expression.insert()` function.
.. seealso::
:ref:`coretutorial_insert_expressions`
"""
__visit_name__ = "insert"
_supports_multi_parameters = True
def __init__(
self,
table,
values=None,
inline=False,
bind=None,
prefixes=None,
returning=None,
return_defaults=False,
**dialect_kw
):
"""Construct an :class:`.Insert` object.
Similar functionality is available via the
:meth:`~.TableClause.insert` method on
:class:`~.schema.Table`.
:param table: :class:`.TableClause` which is the subject of the
insert.
:param values: collection of values to be inserted; see
:meth:`.Insert.values` for a description of allowed formats here.
Can be omitted entirely; a :class:`.Insert` construct will also
dynamically render the VALUES clause at execution time based on
the parameters passed to :meth:`.Connection.execute`.
:param inline: if True, no attempt will be made to retrieve the
SQL-generated default values to be provided within the statement;
in particular,
this allows SQL expressions to be rendered 'inline' within the
statement without the need to pre-execute them beforehand; for
backends that support "returning", this turns off the "implicit
returning" feature for the statement.
If both `values` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
within `values` on a per-key basis.
The keys within `values` can be either
:class:`~sqlalchemy.schema.Column` objects or their string
identifiers. Each key may reference one of:
* a literal data value (i.e. string, number, etc.);
* a Column object;
* a SELECT statement.
If a ``SELECT`` statement is specified which references this
``INSERT`` statement's table, the statement will be correlated
against the ``INSERT`` statement.
.. seealso::
:ref:`coretutorial_insert_expressions` - SQL Expression Tutorial
:ref:`inserts_and_updates` - SQL Expression Tutorial
"""
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self.select = self.select_names = None
self.include_insert_from_select_defaults = False
self.inline = inline
self._returning = returning
self._validate_dialect_kwargs(dialect_kw)
self._return_defaults = return_defaults
def get_children(self, **kwargs):
if self.select is not None:
return (self.select,)
else:
return ()
@_generative
def from_select(self, names, select, include_defaults=True):
"""Return a new :class:`.Insert` construct which represents
an ``INSERT...FROM SELECT`` statement.
e.g.::
sel = select([table1.c.a, table1.c.b]).where(table1.c.c > 5)
ins = table2.insert().from_select(['a', 'b'], sel)
:param names: a sequence of string column names or :class:`.Column`
objects representing the target columns.
:param select: a :func:`.select` construct, :class:`.FromClause`
or other construct which resolves into a :class:`.FromClause`,
such as an ORM :class:`.Query` object, etc. The order of
columns returned from this FROM clause should correspond to the
order of columns sent as the ``names`` parameter; while this
is not checked before passing along to the database, the database
would normally raise an exception if these column lists don't
correspond.
:param include_defaults: if True, non-server default values and
SQL expressions as specified on :class:`.Column` objects
(as documented in :ref:`metadata_defaults_toplevel`) not
otherwise specified in the list of names will be rendered
into the INSERT and SELECT statements, so that these values are also
included in the data to be inserted.
.. note:: A Python-side default that uses a Python callable function
will only be invoked **once** for the whole statement, and **not
per row**.
.. versionadded:: 1.0.0 - :meth:`.Insert.from_select` now renders
Python-side and SQL expression column defaults into the
SELECT statement for columns otherwise not included in the
list of column names.
.. versionchanged:: 1.0.0 an INSERT that uses FROM SELECT
implies that the :paramref:`.insert.inline` flag is set to
True, indicating that the statement will not attempt to fetch
the "last inserted primary key" or other defaults. The statement
deals with an arbitrary number of rows, so the
:attr:`.ResultProxy.inserted_primary_key` accessor does not apply.
"""
if self.parameters:
raise exc.InvalidRequestError(
"This construct already inserts value expressions"
)
self.parameters, self._has_multi_parameters = self._process_colparams(
{_column_as_key(n): Null() for n in names}
)
self.select_names = names
self.inline = True
self.include_insert_from_select_defaults = include_defaults
self.select = _interpret_as_select(select)
def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self.parameters = self.parameters.copy()
if self.select is not None:
self.select = _clone(self.select)
class Update(ValuesBase):
"""Represent an Update construct.
The :class:`.Update` object is created using the :func:`update()`
function.
"""
__visit_name__ = "update"
def __init__(
self,
table,
whereclause=None,
values=None,
inline=False,
bind=None,
prefixes=None,
returning=None,
return_defaults=False,
preserve_parameter_order=False,
**dialect_kw
):
r"""Construct an :class:`.Update` object.
E.g.::
from sqlalchemy import update
stmt = update(users).where(users.c.id==5).\
values(name='user #5')
Similar functionality is available via the
:meth:`~.TableClause.update` method on
:class:`.Table`::
stmt = users.update().\
where(users.c.id==5).\
values(name='user #5')
:param table: A :class:`.Table` object representing the database
table to be updated.
:param whereclause: Optional SQL expression describing the ``WHERE``
condition of the ``UPDATE`` statement. Modern applications
may prefer to use the generative :meth:`~Update.where()`
method to specify the ``WHERE`` clause.
The WHERE clause can refer to multiple tables.
For databases which support this, an ``UPDATE FROM`` clause will
be generated, or on MySQL, a multi-table update. The statement
will fail on databases that don't have support for multi-table
update statements. A SQL-standard method of referring to
additional tables in the WHERE clause is to use a correlated
subquery::
users.update().values(name='ed').where(
users.c.name==select([addresses.c.email_address]).\
where(addresses.c.user_id==users.c.id).\
as_scalar()
)
:param values:
Optional dictionary which specifies the ``SET`` conditions of the
``UPDATE``. If left as ``None``, the ``SET``
conditions are determined from those parameters passed to the
statement during the execution and/or compilation of the
statement. When compiled standalone without any parameters,
the ``SET`` clause generates for all columns.
Modern applications may prefer to use the generative
:meth:`.Update.values` method to set the values of the
UPDATE statement.
:param inline:
if True, SQL defaults present on :class:`.Column` objects via
the ``default`` keyword will be compiled 'inline' into the statement
and not pre-executed. This means that their values will not
be available in the dictionary returned from
:meth:`.ResultProxy.last_updated_params`.
:param preserve_parameter_order: if True, the update statement is
expected to receive parameters **only** via the
:meth:`.Update.values` method, and they must be passed as a Python
``list`` of 2-tuples. The rendered UPDATE statement will emit the SET
clause for each referenced column maintaining this order.
.. versionadded:: 1.0.10
.. seealso::
:ref:`updates_order_parameters` - full example of the
:paramref:`~.update.preserve_parameter_order` flag
If both ``values`` and compile-time bind parameters are present, the
compile-time bind parameters override the information specified
within ``values`` on a per-key basis.
The keys within ``values`` can be either :class:`.Column`
objects or their string identifiers (specifically the "key" of the
:class:`.Column`, normally but not necessarily equivalent to
its "name"). Normally, the
:class:`.Column` objects used here are expected to be
part of the target :class:`.Table` that is the table
to be updated. However when using MySQL, a multiple-table
UPDATE statement can refer to columns from any of
the tables referred to in the WHERE clause.
The values referred to in ``values`` are typically:
* a literal data value (i.e. string, number, etc.)
* a SQL expression, such as a related :class:`.Column`,
a scalar-returning :func:`.select` construct,
etc.
When combining :func:`.select` constructs within the values
clause of an :func:`.update` construct,
the subquery represented by the :func:`.select` should be
*correlated* to the parent table, that is, providing criterion
which links the table inside the subquery to the outer table
being updated::
users.update().values(
name=select([addresses.c.email_address]).\
where(addresses.c.user_id==users.c.id).\
as_scalar()
)
.. seealso::
:ref:`inserts_and_updates` - SQL Expression
Language Tutorial
"""
self._preserve_parameter_order = preserve_parameter_order
ValuesBase.__init__(self, table, values, prefixes)
self._bind = bind
self._returning = returning
if whereclause is not None:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
self.inline = inline
self._validate_dialect_kwargs(dialect_kw)
self._return_defaults = return_defaults
def get_children(self, **kwargs):
if self._whereclause is not None:
return (self._whereclause,)
else:
return ()
def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self._whereclause = clone(self._whereclause, **kw)
self.parameters = self.parameters.copy()
@_generative
def where(self, whereclause):
"""return a new update() construct with the given expression added to
its WHERE clause, joined to the existing clause via AND, if any.
"""
if self._whereclause is not None:
self._whereclause = and_(
self._whereclause, _literal_as_text(whereclause)
)
else:
self._whereclause = _literal_as_text(whereclause)
@property
def _extra_froms(self):
froms = []
seen = {self.table}
if self._whereclause is not None:
for item in _from_objects(self._whereclause):
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
return froms
class Delete(UpdateBase):
"""Represent a DELETE construct.
The :class:`.Delete` object is created using the :func:`delete()`
function.
"""
__visit_name__ = "delete"
def __init__(
self,
table,
whereclause=None,
bind=None,
returning=None,
prefixes=None,
**dialect_kw
):
"""Construct :class:`.Delete` object.
Similar functionality is available via the
:meth:`~.TableClause.delete` method on
:class:`~.schema.Table`.
:param table: The table to delete rows from.
:param whereclause: A :class:`.ClauseElement` describing the ``WHERE``
condition of the ``DELETE`` statement. Note that the
:meth:`~Delete.where()` generative method may be used instead.
The WHERE clause can refer to multiple tables.
For databases which support this, a ``DELETE..USING`` or similar
clause will be generated. The statement
will fail on databases that don't have support for multi-table
delete statements. A SQL-standard method of referring to
additional tables in the WHERE clause is to use a correlated
subquery::
users.delete().where(
users.c.name==select([addresses.c.email_address]).\
where(addresses.c.user_id==users.c.id).\
as_scalar()
)
.. versionchanged:: 1.2.0
The WHERE clause of DELETE can refer to multiple tables.
.. seealso::
:ref:`deletes` - SQL Expression Tutorial
"""
self._bind = bind
self.table = _interpret_as_from(table)
self._returning = returning
if prefixes:
self._setup_prefixes(prefixes)
if whereclause is not None:
self._whereclause = _literal_as_text(whereclause)
else:
self._whereclause = None
self._validate_dialect_kwargs(dialect_kw)
def get_children(self, **kwargs):
if self._whereclause is not None:
return (self._whereclause,)
else:
return ()
@_generative
def where(self, whereclause):
"""Add the given WHERE clause to a newly returned delete construct."""
if self._whereclause is not None:
self._whereclause = and_(
self._whereclause, _literal_as_text(whereclause)
)
else:
self._whereclause = _literal_as_text(whereclause)
@property
def _extra_froms(self):
froms = []
seen = {self.table}
if self._whereclause is not None:
for item in _from_objects(self._whereclause):
if not seen.intersection(item._cloned_set):
froms.append(item)
seen.update(item._cloned_set)
return froms
def _copy_internals(self, clone=_clone, **kw):
# TODO: coverage
self._whereclause = clone(self._whereclause, **kw)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,263 @@
# sql/expression.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Defines the public namespace for SQL expression constructs.
Prior to version 0.9, this module contained all of "elements", "dml",
"default_comparator" and "selectable". The module was broken up
and most "factory" functions were moved to be grouped with their associated
class.
"""
__all__ = [
"Alias",
"any_",
"all_",
"ClauseElement",
"ColumnCollection",
"ColumnElement",
"CompoundSelect",
"Delete",
"FromClause",
"Insert",
"Join",
"Lateral",
"Select",
"Selectable",
"TableClause",
"Update",
"alias",
"and_",
"asc",
"between",
"bindparam",
"case",
"cast",
"column",
"cte",
"delete",
"desc",
"distinct",
"except_",
"except_all",
"exists",
"extract",
"func",
"modifier",
"collate",
"insert",
"intersect",
"intersect_all",
"join",
"label",
"lateral",
"literal",
"literal_column",
"not_",
"null",
"nullsfirst",
"nullslast",
"or_",
"outparam",
"outerjoin",
"over",
"select",
"subquery",
"table",
"text",
"tuple_",
"type_coerce",
"quoted_name",
"union",
"union_all",
"update",
"within_group",
"TableSample",
"tablesample",
]
from .base import _from_objects # noqa
from .base import ColumnCollection # noqa
from .base import Executable # noqa
from .base import Generative # noqa
from .base import PARSE_AUTOCOMMIT # noqa
from .dml import Delete # noqa
from .dml import Insert # noqa
from .dml import Update # noqa
from .dml import UpdateBase # noqa
from .dml import ValuesBase # noqa
from .elements import _clause_element_as_expr # noqa
from .elements import _clone # noqa
from .elements import _cloned_difference # noqa
from .elements import _cloned_intersection # noqa
from .elements import _column_as_key # noqa
from .elements import _corresponding_column_or_error # noqa
from .elements import _expression_literal_as_text # noqa
from .elements import _is_column # noqa
from .elements import _labeled # noqa
from .elements import _literal_as_binds # noqa
from .elements import _literal_as_column # noqa
from .elements import _literal_as_label_reference # noqa
from .elements import _literal_as_text # noqa
from .elements import _only_column_elements # noqa
from .elements import _select_iterables # noqa
from .elements import _string_or_unprintable # noqa
from .elements import _truncated_label # noqa
from .elements import between # noqa
from .elements import BinaryExpression # noqa
from .elements import BindParameter # noqa
from .elements import BooleanClauseList # noqa
from .elements import Case # noqa
from .elements import Cast # noqa
from .elements import ClauseElement # noqa
from .elements import ClauseList # noqa
from .elements import collate # noqa
from .elements import CollectionAggregate # noqa
from .elements import ColumnClause # noqa
from .elements import ColumnElement # noqa
from .elements import Extract # noqa
from .elements import False_ # noqa
from .elements import FunctionFilter # noqa
from .elements import Grouping # noqa
from .elements import Label # noqa
from .elements import literal # noqa
from .elements import literal_column # noqa
from .elements import not_ # noqa
from .elements import Null # noqa
from .elements import outparam # noqa
from .elements import Over # noqa
from .elements import quoted_name # noqa
from .elements import ReleaseSavepointClause # noqa
from .elements import RollbackToSavepointClause # noqa
from .elements import SavepointClause # noqa
from .elements import TextClause # noqa
from .elements import True_ # noqa
from .elements import Tuple # noqa
from .elements import TypeClause # noqa
from .elements import TypeCoerce # noqa
from .elements import UnaryExpression # noqa
from .elements import WithinGroup # noqa
from .functions import func # noqa
from .functions import Function # noqa
from .functions import FunctionElement # noqa
from .functions import modifier # noqa
from .selectable import _interpret_as_from # noqa
from .selectable import Alias # noqa
from .selectable import CompoundSelect # noqa
from .selectable import CTE # noqa
from .selectable import Exists # noqa
from .selectable import FromClause # noqa
from .selectable import FromGrouping # noqa
from .selectable import GenerativeSelect # noqa
from .selectable import HasCTE # noqa
from .selectable import HasPrefixes # noqa
from .selectable import HasSuffixes # noqa
from .selectable import Join # noqa
from .selectable import Lateral # noqa
from .selectable import ScalarSelect # noqa
from .selectable import Select # noqa
from .selectable import Selectable # noqa
from .selectable import SelectBase # noqa
from .selectable import subquery # noqa
from .selectable import TableClause # noqa
from .selectable import TableSample # noqa
from .selectable import TextAsFrom # noqa
from .visitors import Visitable # noqa
from ..util.langhelpers import public_factory # noqa
# factory functions - these pull class-bound constructors and classmethods
# from SQL elements and selectables into public functions. This allows
# the functions to be available in the sqlalchemy.sql.* namespace and
# to be auto-cross-documenting from the function to the class itself.
all_ = public_factory(CollectionAggregate._create_all, ".expression.all_")
any_ = public_factory(CollectionAggregate._create_any, ".expression.any_")
and_ = public_factory(BooleanClauseList.and_, ".expression.and_")
alias = public_factory(Alias._factory, ".expression.alias")
tablesample = public_factory(TableSample._factory, ".expression.tablesample")
lateral = public_factory(Lateral._factory, ".expression.lateral")
or_ = public_factory(BooleanClauseList.or_, ".expression.or_")
bindparam = public_factory(BindParameter, ".expression.bindparam")
select = public_factory(Select, ".expression.select")
text = public_factory(TextClause._create_text, ".expression.text")
table = public_factory(TableClause, ".expression.table")
column = public_factory(ColumnClause, ".expression.column")
over = public_factory(Over, ".expression.over")
within_group = public_factory(WithinGroup, ".expression.within_group")
label = public_factory(Label, ".expression.label")
case = public_factory(Case, ".expression.case")
cast = public_factory(Cast, ".expression.cast")
cte = public_factory(CTE._factory, ".expression.cte")
extract = public_factory(Extract, ".exp # noqaression.extract")
tuple_ = public_factory(Tuple, ".expression.tuple_")
except_ = public_factory(CompoundSelect._create_except, ".expression.except_")
except_all = public_factory(
CompoundSelect._create_except_all, ".expression.except_all"
)
intersect = public_factory(
CompoundSelect._create_intersect, ".expression.intersect"
)
intersect_all = public_factory(
CompoundSelect._create_intersect_all, ".expression.intersect_all"
)
union = public_factory(CompoundSelect._create_union, ".expression.union")
union_all = public_factory(
CompoundSelect._create_union_all, ".expression.union_all"
)
exists = public_factory(Exists, ".expression.exists")
nullsfirst = public_factory(
UnaryExpression._create_nullsfirst, ".expression.nullsfirst"
)
nullslast = public_factory(
UnaryExpression._create_nullslast, ".expression.nullslast"
)
asc = public_factory(UnaryExpression._create_asc, ".expression.asc")
desc = public_factory(UnaryExpression._create_desc, ".expression.desc")
distinct = public_factory(
UnaryExpression._create_distinct, ".expression.distinct"
)
type_coerce = public_factory(TypeCoerce, ".expression.type_coerce")
true = public_factory(True_._instance, ".expression.true")
false = public_factory(False_._instance, ".expression.false")
null = public_factory(Null._instance, ".expression.null")
join = public_factory(Join._create_join, ".expression.join")
outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
funcfilter = public_factory(FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
# these might be better off in some other namespace
# old names for compatibility
_Executable = Executable
_BindParamClause = BindParameter
_Label = Label
_SelectBase = SelectBase
_BinaryExpression = BinaryExpression
_Cast = Cast
_Null = Null
_False = False_
_True = True_
_TextClause = TextClause
_UnaryExpression = UnaryExpression
_Case = Case
_Tuple = Tuple
_Over = Over
_Generative = Generative
_TypeClause = TypeClause
_Extract = Extract
_Exists = Exists
_Grouping = Grouping
_FromGrouping = FromGrouping
_ScalarSelect = ScalarSelect

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,182 @@
# sqlalchemy/naming.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Establish constraint and index naming conventions.
"""
import re
from .elements import _defer_name
from .elements import _defer_none_name
from .elements import conv
from .schema import CheckConstraint
from .schema import Column
from .schema import Constraint
from .schema import ForeignKeyConstraint
from .schema import Index
from .schema import PrimaryKeyConstraint
from .schema import Table
from .schema import UniqueConstraint
from .. import event
from .. import events # noqa
from .. import exc
class ConventionDict(object):
def __init__(self, const, table, convention):
self.const = const
self._is_fk = isinstance(const, ForeignKeyConstraint)
self.table = table
self.convention = convention
self._const_name = const.name
def _key_table_name(self):
return self.table.name
def _column_X(self, idx):
if self._is_fk:
fk = self.const.elements[idx]
return fk.parent
else:
return list(self.const.columns)[idx]
def _key_constraint_name(self):
if isinstance(self._const_name, (type(None), _defer_none_name)):
raise exc.InvalidRequestError(
"Naming convention including "
"%(constraint_name)s token requires that "
"constraint is explicitly named."
)
if not isinstance(self._const_name, conv):
self.const.name = None
return self._const_name
def _key_column_X_key(self, idx):
# note this method was missing before
# [ticket:3989], meaning tokens like ``%(column_0_key)s`` weren't
# working even though documented.
return self._column_X(idx).key
def _key_column_X_name(self, idx):
return self._column_X(idx).name
def _key_column_X_label(self, idx):
return self._column_X(idx)._label
def _key_referred_table_name(self):
fk = self.const.elements[0]
refs = fk.target_fullname.split(".")
if len(refs) == 3:
refschema, reftable, refcol = refs
else:
reftable, refcol = refs
return reftable
def _key_referred_column_X_name(self, idx):
fk = self.const.elements[idx]
# note that before [ticket:3989], this method was returning
# the specification for the :class:`.ForeignKey` itself, which normally
# would be using the ``.key`` of the column, not the name.
return fk.column.name
def __getitem__(self, key):
if key in self.convention:
return self.convention[key](self.const, self.table)
elif hasattr(self, "_key_%s" % key):
return getattr(self, "_key_%s" % key)()
else:
col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key)
if col_template:
idx = col_template.group(1)
multiples = col_template.group(2)
if multiples:
if self._is_fk:
elems = self.const.elements
else:
elems = list(self.const.columns)
tokens = []
for idx, elem in enumerate(elems):
attr = "_key_" + key.replace("0" + multiples, "X")
try:
tokens.append(getattr(self, attr)(idx))
except AttributeError:
raise KeyError(key)
sep = "_" if multiples.startswith("_") else ""
return sep.join(tokens)
else:
attr = "_key_" + key.replace(idx, "X")
idx = int(idx)
if hasattr(self, attr):
return getattr(self, attr)(idx)
raise KeyError(key)
_prefix_dict = {
Index: "ix",
PrimaryKeyConstraint: "pk",
CheckConstraint: "ck",
UniqueConstraint: "uq",
ForeignKeyConstraint: "fk",
}
def _get_convention(dict_, key):
for super_ in key.__mro__:
if super_ in _prefix_dict and _prefix_dict[super_] in dict_:
return dict_[_prefix_dict[super_]]
elif super_ in dict_:
return dict_[super_]
else:
return None
def _constraint_name_for_table(const, table):
metadata = table.metadata
convention = _get_convention(metadata.naming_convention, type(const))
if isinstance(const.name, conv):
return const.name
elif (
convention is not None
and not isinstance(const.name, conv)
and (
const.name is None
or "constraint_name" in convention
or isinstance(const.name, _defer_name)
)
):
return conv(
convention
% ConventionDict(const, table, metadata.naming_convention)
)
elif isinstance(convention, _defer_none_name):
return None
@event.listens_for(Constraint, "after_parent_attach")
@event.listens_for(Index, "after_parent_attach")
def _constraint_name(const, table):
if isinstance(table, Column):
# for column-attached constraint, set another event
# to link the column attached to the table as this constraint
# associated with the table.
event.listen(
table,
"after_parent_attach",
lambda col, table: _constraint_name(const, table),
)
elif isinstance(table, Table):
if isinstance(const.name, (conv, _defer_name)):
return
newname = _constraint_name_for_table(const, table)
if newname is not None:
const.name = newname

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,936 @@
# sql/util.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""High level utilities which build upon other modules here.
"""
from collections import deque
from itertools import chain
from . import operators
from . import visitors
from .annotation import _deep_annotate # noqa
from .annotation import _deep_deannotate # noqa
from .annotation import _shallow_annotate # noqa
from .base import _from_objects
from .base import ColumnSet
from .ddl import sort_tables # noqa
from .elements import _expand_cloned
from .elements import _find_columns # noqa
from .elements import _label_reference
from .elements import _textual_label_reference
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import Null
from .elements import UnaryExpression
from .schema import Column
from .selectable import Alias
from .selectable import FromClause
from .selectable import FromGrouping
from .selectable import Join
from .selectable import ScalarSelect
from .selectable import SelectBase
from .selectable import TableClause
from .. import exc
from .. import util
join_condition = util.langhelpers.public_factory(
Join._join_condition, ".sql.util.join_condition"
)
def find_join_source(clauses, join_to):
"""Given a list of FROM clauses and a selectable,
return the first index and element from the list of
clauses which can be joined against the selectable. returns
None, None if no match is found.
e.g.::
clause1 = table1.join(table2)
clause2 = table4.join(table5)
join_to = table2.join(table3)
find_join_source([clause1, clause2], join_to) == clause1
"""
selectables = list(_from_objects(join_to))
idx = []
for i, f in enumerate(clauses):
for s in selectables:
if f.is_derived_from(s):
idx.append(i)
return idx
def find_left_clause_that_matches_given(clauses, join_from):
"""Given a list of FROM clauses and a selectable,
return the indexes from the list of
clauses which is derived from the selectable.
"""
selectables = list(_from_objects(join_from))
liberal_idx = []
for i, f in enumerate(clauses):
for s in selectables:
# basic check, if f is derived from s.
# this can be joins containing a table, or an aliased table
# or select statement matching to a table. This check
# will match a table to a selectable that is adapted from
# that table. With Query, this suits the case where a join
# is being made to an adapted entity
if f.is_derived_from(s):
liberal_idx.append(i)
break
# in an extremely small set of use cases, a join is being made where
# there are multiple FROM clauses where our target table is represented
# in more than one, such as embedded or similar. in this case, do
# another pass where we try to get a more exact match where we aren't
# looking at adaption relationships.
if len(liberal_idx) > 1:
conservative_idx = []
for idx in liberal_idx:
f = clauses[idx]
for s in selectables:
if set(surface_selectables(f)).intersection(
surface_selectables(s)
):
conservative_idx.append(idx)
break
if conservative_idx:
return conservative_idx
return liberal_idx
def find_left_clause_to_join_from(clauses, join_to, onclause):
"""Given a list of FROM clauses, a selectable,
and optional ON clause, return a list of integer indexes from the
clauses list indicating the clauses that can be joined from.
The presence of an "onclause" indicates that at least one clause can
definitely be joined from; if the list of clauses is of length one
and the onclause is given, returns that index. If the list of clauses
is more than length one, and the onclause is given, attempts to locate
which clauses contain the same columns.
"""
idx = []
selectables = set(_from_objects(join_to))
# if we are given more than one target clause to join
# from, use the onclause to provide a more specific answer.
# otherwise, don't try to limit, after all, "ON TRUE" is a valid
# on clause
if len(clauses) > 1 and onclause is not None:
resolve_ambiguity = True
cols_in_onclause = _find_columns(onclause)
else:
resolve_ambiguity = False
cols_in_onclause = None
for i, f in enumerate(clauses):
for s in selectables.difference([f]):
if resolve_ambiguity:
if set(f.c).union(s.c).issuperset(cols_in_onclause):
idx.append(i)
break
elif Join._can_join(f, s) or onclause is not None:
idx.append(i)
break
if len(idx) > 1:
# this is the same "hide froms" logic from
# Selectable._get_display_froms
toremove = set(
chain(*[_expand_cloned(f._hide_froms) for f in clauses])
)
idx = [i for i in idx if clauses[i] not in toremove]
# onclause was given and none of them resolved, so assume
# all indexes can match
if not idx and onclause is not None:
return range(len(clauses))
else:
return idx
def visit_binary_product(fn, expr):
"""Produce a traversal of the given expression, delivering
column comparisons to the given function.
The function is of the form::
def my_fn(binary, left, right)
For each binary expression located which has a
comparison operator, the product of "left" and
"right" will be delivered to that function,
in terms of that binary.
Hence an expression like::
and_(
(a + b) == q + func.sum(e + f),
j == r
)
would have the traversal::
a <eq> q
a <eq> e
a <eq> f
b <eq> q
b <eq> e
b <eq> f
j <eq> r
That is, every combination of "left" and
"right" that doesn't further contain
a binary comparison is passed as pairs.
"""
stack = []
def visit(element):
if isinstance(element, ScalarSelect):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
elif element.__visit_name__ == "binary" and operators.is_comparison(
element.operator
):
stack.insert(0, element)
for l in visit(element.left):
for r in visit(element.right):
fn(stack[0], l, r)
stack.pop(0)
for elem in element.get_children():
visit(elem)
else:
if isinstance(element, ColumnClause):
yield element
for elem in element.get_children():
for e in visit(elem):
yield e
list(visit(expr))
def find_tables(
clause,
check_columns=False,
include_aliases=False,
include_joins=False,
include_selects=False,
include_crud=False,
):
"""locate Table objects within the given expression."""
tables = []
_visitors = {}
if include_selects:
_visitors["select"] = _visitors["compound_select"] = tables.append
if include_joins:
_visitors["join"] = tables.append
if include_aliases:
_visitors["alias"] = tables.append
if include_crud:
_visitors["insert"] = _visitors["update"] = _visitors[
"delete"
] = lambda ent: tables.append(ent.table)
if check_columns:
def visit_column(column):
tables.append(column.table)
_visitors["column"] = visit_column
_visitors["table"] = tables.append
visitors.traverse(clause, {"column_collections": False}, _visitors)
return tables
def unwrap_order_by(clause):
"""Break up an 'order by' expression into individual column-expressions,
without DESC/ASC/NULLS FIRST/NULLS LAST"""
cols = util.column_set()
result = []
stack = deque([clause])
while stack:
t = stack.popleft()
if isinstance(t, ColumnElement) and (
not isinstance(t, UnaryExpression)
or not operators.is_ordering_modifier(t.modifier)
):
if isinstance(t, _label_reference):
t = t.element
if isinstance(t, (_textual_label_reference)):
continue
if t not in cols:
cols.add(t)
result.append(t)
else:
for c in t.get_children():
stack.append(c)
return result
def unwrap_label_reference(element):
def replace(elem):
if isinstance(elem, (_label_reference, _textual_label_reference)):
return elem.element
return visitors.replacement_traverse(element, {}, replace)
def expand_column_list_from_order_by(collist, order_by):
"""Given the columns clause and ORDER BY of a selectable,
return a list of column expressions that can be added to the collist
corresponding to the ORDER BY, without repeating those already
in the collist.
"""
cols_already_present = set(
[
col.element if col._order_by_label_element is not None else col
for col in collist
]
)
return [
col
for col in chain(*[unwrap_order_by(o) for o in order_by])
if col not in cols_already_present
]
def clause_is_present(clause, search):
"""Given a target clause and a second to search within, return True
if the target is plainly present in the search without any
subqueries or aliases involved.
Basically descends through Joins.
"""
for elem in surface_selectables(search):
if clause == elem: # use == here so that Annotated's compare
return True
else:
return False
def surface_selectables(clause):
stack = [clause]
while stack:
elem = stack.pop()
yield elem
if isinstance(elem, Join):
stack.extend((elem.left, elem.right))
elif isinstance(elem, FromGrouping):
stack.append(elem.element)
def surface_selectables_only(clause):
stack = [clause]
while stack:
elem = stack.pop()
if isinstance(elem, (TableClause, Alias)):
yield elem
if isinstance(elem, Join):
stack.extend((elem.left, elem.right))
elif isinstance(elem, FromGrouping):
stack.append(elem.element)
elif isinstance(elem, ColumnClause):
stack.append(elem.table)
def surface_column_elements(clause, include_scalar_selects=True):
"""traverse and yield only outer-exposed column elements, such as would
be addressable in the WHERE clause of a SELECT if this element were
in the columns clause."""
filter_ = (FromGrouping,)
if not include_scalar_selects:
filter_ += (SelectBase,)
stack = deque([clause])
while stack:
elem = stack.popleft()
yield elem
for sub in elem.get_children():
if isinstance(sub, filter_):
continue
stack.append(sub)
def selectables_overlap(left, right):
"""Return True if left/right have some overlapping selectable"""
return bool(
set(surface_selectables(left)).intersection(surface_selectables(right))
)
def bind_values(clause):
"""Return an ordered list of "bound" values in the given clause.
E.g.::
>>> expr = and_(
... table.c.foo==5, table.c.foo==7
... )
>>> bind_values(expr)
[5, 7]
"""
v = []
def visit_bindparam(bind):
v.append(bind.effective_value)
visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
return v
def _quote_ddl_expr(element):
if isinstance(element, util.string_types):
element = element.replace("'", "''")
return "'%s'" % element
else:
return repr(element)
class _repr_base(object):
_LIST = 0
_TUPLE = 1
_DICT = 2
__slots__ = ("max_chars",)
def trunc(self, value):
rep = repr(value)
lenrep = len(rep)
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
rep[0:segment_length]
+ (
" ... (%d characters truncated) ... "
% (lenrep - self.max_chars)
)
+ rep[-segment_length:]
)
return rep
class _repr_row(_repr_base):
"""Provide a string view of a row."""
__slots__ = ("row",)
def __init__(self, row, max_chars=300):
self.row = row
self.max_chars = max_chars
def __repr__(self):
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
"," if len(self.row) == 1 else "",
)
class _repr_params(_repr_base):
"""Provide a string view of bound parameters.
Truncates display to a given numnber of 'multi' parameter sets,
as well as long values to a given number of characters.
"""
__slots__ = "params", "batches"
def __init__(self, params, batches, max_chars=300):
self.params = params
self.batches = batches
self.max_chars = max_chars
def __repr__(self):
if isinstance(self.params, list):
typ = self._LIST
ismulti = self.params and isinstance(
self.params[0], (list, dict, tuple)
)
elif isinstance(self.params, tuple):
typ = self._TUPLE
ismulti = self.params and isinstance(
self.params[0], (list, dict, tuple)
)
elif isinstance(self.params, dict):
typ = self._DICT
ismulti = False
else:
return self.trunc(self.params)
if ismulti and len(self.params) > self.batches:
msg = " ... displaying %i of %i total bound parameter sets ... "
return " ".join(
(
self._repr_multi(self.params[: self.batches - 2], typ)[
0:-1
],
msg % (self.batches, len(self.params)),
self._repr_multi(self.params[-2:], typ)[1:],
)
)
elif ismulti:
return self._repr_multi(self.params, typ)
else:
return self._repr_params(self.params, typ)
def _repr_multi(self, multi_params, typ):
if multi_params:
if isinstance(multi_params[0], list):
elem_type = self._LIST
elif isinstance(multi_params[0], tuple):
elem_type = self._TUPLE
elif isinstance(multi_params[0], dict):
elem_type = self._DICT
else:
assert False, "Unknown parameter type %s" % (
type(multi_params[0])
)
elements = ", ".join(
self._repr_params(params, elem_type) for params in multi_params
)
else:
elements = ""
if typ == self._LIST:
return "[%s]" % elements
else:
return "(%s)" % elements
def _repr_params(self, params, typ):
trunc = self.trunc
if typ is self._DICT:
return "{%s}" % (
", ".join(
"%r: %s" % (key, trunc(value))
for key, value in params.items()
)
)
elif typ is self._TUPLE:
return "(%s%s)" % (
", ".join(trunc(value) for value in params),
"," if len(params) == 1 else "",
)
else:
return "[%s]" % (", ".join(trunc(value) for value in params))
def adapt_criterion_to_null(crit, nulls):
"""given criterion containing bind params, convert selected elements
to IS NULL.
"""
def visit_binary(binary):
if (
isinstance(binary.left, BindParameter)
and binary.left._identifying_key in nulls
):
# reverse order if the NULL is on the left side
binary.left = binary.right
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
elif (
isinstance(binary.right, BindParameter)
and binary.right._identifying_key in nulls
):
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(left, right, stop_on=None):
if left is None:
return right
stack = [(right, None)]
adapter = ClauseAdapter(left)
ret = None
while stack:
(right, prevright) = stack.pop()
if isinstance(right, Join) and right is not stop_on:
right = right._clone()
right._reset_exported()
right.onclause = adapter.traverse(right.onclause)
stack.append((right.left, right))
else:
right = adapter.traverse(right)
if prevright is not None:
prevright.left = right
if ret is None:
ret = right
return ret
def reduce_columns(columns, *clauses, **kw):
r"""given a list of columns, return a 'reduced' set based on natural
equivalents.
the set is reduced to the smallest list of columns which have no natural
equivalent present in the list. A "natural equivalent" means that two
columns will ultimately represent the same value because they are related
by a foreign key.
\*clauses is an optional list of join clauses which will be traversed
to further identify columns that are "equivalent".
\**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
whose tables are not yet configured, or columns that aren't yet present.
This function is primarily used to determine the most minimal "primary
key" from a selectable, by reducing the set of primary key columns present
in the selectable to just those that are not repeated.
"""
ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
only_synonyms = kw.pop("only_synonyms", False)
columns = util.ordered_column_set(columns)
omit = util.column_set()
for col in columns:
for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
for c in columns:
if c is col:
continue
try:
fk_col = fk.column
except exc.NoReferencedColumnError:
# TODO: add specific coverage here
# to test/sql/test_selectable ReduceTest
if ignore_nonexistent_tables:
continue
else:
raise
except exc.NoReferencedTableError:
# TODO: add specific coverage here
# to test/sql/test_selectable ReduceTest
if ignore_nonexistent_tables:
continue
else:
raise
if fk_col.shares_lineage(c) and (
not only_synonyms or c.name == col.name
):
omit.add(col)
break
if clauses:
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
chain(*[c.proxy_set for c in columns.difference(omit)])
)
if binary.left in cols and binary.right in cols:
for c in reversed(columns):
if c.shares_lineage(binary.right) and (
not only_synonyms or c.name == binary.left.name
):
omit.add(c)
break
for clause in clauses:
if clause is not None:
visitors.traverse(clause, {}, {"binary": visit_binary})
return ColumnSet(columns.difference(omit))
def criterion_as_pairs(
expression,
consider_as_foreign_keys=None,
consider_as_referenced_keys=None,
any_operator=False,
):
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
raise exc.ArgumentError(
"Can only specify one of "
"'consider_as_foreign_keys' or "
"'consider_as_referenced_keys'"
)
def col_is(a, b):
# return a is b
return a.compare(b)
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
if not isinstance(binary.left, ColumnElement) or not isinstance(
binary.right, ColumnElement
):
return
if consider_as_foreign_keys:
if binary.left in consider_as_foreign_keys and (
col_is(binary.right, binary.left)
or binary.right not in consider_as_foreign_keys
):
pairs.append((binary.right, binary.left))
elif binary.right in consider_as_foreign_keys and (
col_is(binary.left, binary.right)
or binary.left not in consider_as_foreign_keys
):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
if binary.left in consider_as_referenced_keys and (
col_is(binary.right, binary.left)
or binary.right not in consider_as_referenced_keys
):
pairs.append((binary.left, binary.right))
elif binary.right in consider_as_referenced_keys and (
col_is(binary.left, binary.right)
or binary.left not in consider_as_referenced_keys
):
pairs.append((binary.right, binary.left))
else:
if isinstance(binary.left, Column) and isinstance(
binary.right, Column
):
if binary.left.references(binary.right):
pairs.append((binary.right, binary.left))
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
pairs = []
visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""Clones and modifies clauses based on column correspondence.
E.g.::
table1 = Table('sometable', metadata,
Column('col1', Integer),
Column('col2', Integer)
)
table2 = Table('someothertable', metadata,
Column('col1', Integer),
Column('col2', Integer)
)
condition = table1.c.col1 == table2.c.col1
make an alias of table1::
s = table1.alias('foo')
calling ``ClauseAdapter(s).traverse(condition)`` converts
condition to read::
s.c.col1 == table2.c.col1
"""
def __init__(
self,
selectable,
equivalents=None,
include_fn=None,
exclude_fn=None,
adapt_on_names=False,
anonymize_labels=False,
):
self.__traverse_options__ = {
"stop_on": [selectable],
"anonymize_labels": anonymize_labels,
}
self.selectable = selectable
self.include_fn = include_fn
self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
def _corresponding_column(
self, col, require_embedded, _seen=util.EMPTY_SET
):
newcol = self.selectable.corresponding_column(
col, require_embedded=require_embedded
)
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
newcol = self._corresponding_column(
equiv,
require_embedded=require_embedded,
_seen=_seen.union([col]),
)
if newcol is not None:
return newcol
if self.adapt_on_names and newcol is None:
newcol = self.selectable.c.get(col.name)
return newcol
def replace(self, col):
if isinstance(col, FromClause) and self.selectable.is_derived_from(
col
):
return self.selectable
elif not isinstance(col, ColumnElement):
return None
elif self.include_fn and not self.include_fn(col):
return None
elif self.exclude_fn and self.exclude_fn(col):
return None
else:
return self._corresponding_column(col, True)
class ColumnAdapter(ClauseAdapter):
"""Extends ClauseAdapter with extra utility functions.
Key aspects of ColumnAdapter include:
* Expressions that are adapted are stored in a persistent
.columns collection; so that an expression E adapted into
an expression E1, will return the same object E1 when adapted
a second time. This is important in particular for things like
Label objects that are anonymized, so that the ColumnAdapter can
be used to present a consistent "adapted" view of things.
* Exclusion of items from the persistent collection based on
include/exclude rules, but also independent of hash identity.
This because "annotated" items all have the same hash identity as their
parent.
* "wrapping" capability is added, so that the replacement of an expression
E can proceed through a series of adapters. This differs from the
visitor's "chaining" feature in that the resulting object is passed
through all replacing functions unconditionally, rather than stopping
at the first one that returns non-None.
* An adapt_required option, used by eager loading to indicate that
We don't trust a result row column that is not translated.
This is to prevent a column from being interpreted as that
of the child row in a self-referential scenario, see
inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency
"""
def __init__(
self,
selectable,
equivalents=None,
adapt_required=False,
include_fn=None,
exclude_fn=None,
adapt_on_names=False,
allow_label_resolve=True,
anonymize_labels=False,
):
ClauseAdapter.__init__(
self,
selectable,
equivalents,
include_fn=include_fn,
exclude_fn=exclude_fn,
adapt_on_names=adapt_on_names,
anonymize_labels=anonymize_labels,
)
self.columns = util.populate_column_dict(self._locate_col)
if self.include_fn or self.exclude_fn:
self.columns = self._IncludeExcludeMapping(self, self.columns)
self.adapt_required = adapt_required
self.allow_label_resolve = allow_label_resolve
self._wrap = None
class _IncludeExcludeMapping(object):
def __init__(self, parent, columns):
self.parent = parent
self.columns = columns
def __getitem__(self, key):
if (
self.parent.include_fn and not self.parent.include_fn(key)
) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
if self.parent._wrap:
return self.parent._wrap.columns[key]
else:
return key
return self.columns[key]
def wrap(self, adapter):
ac = self.__class__.__new__(self.__class__)
ac.__dict__.update(self.__dict__)
ac._wrap = adapter
ac.columns = util.populate_column_dict(ac._locate_col)
if ac.include_fn or ac.exclude_fn:
ac.columns = self._IncludeExcludeMapping(ac, ac.columns)
return ac
def traverse(self, obj):
return self.columns[obj]
adapt_clause = traverse
adapt_list = ClauseAdapter.copy_and_process
def _locate_col(self, col):
c = ClauseAdapter.traverse(self, col)
if self._wrap:
c2 = self._wrap._locate_col(c)
if c2 is not None:
c = c2
if self.adapt_required and c is col:
return None
c._allow_label_resolve = self.allow_label_resolve
return c
def __getstate__(self):
d = self.__dict__.copy()
del d["columns"]
return d
def __setstate__(self, state):
self.__dict__.update(state)
self.columns = util.PopulateDict(self._locate_col)

View File

@@ -0,0 +1,342 @@
# sql/visitors.py
# Copyright (C) 2005-2019 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Visitor/traversal interface and library functions.
SQLAlchemy schema and expression constructs rely on a Python-centric
version of the classic "visitor" pattern as the primary way in which
they apply functionality. The most common use of this pattern
is statement compilation, where individual expression classes match
up to rendering methods that produce a string result. Beyond this,
the visitor system is also used to inspect expressions for various
information and patterns, as well as for usage in
some kinds of expression transformation. Other kinds of transformation
use a non-visitor traversal system.
For many examples of how the visit system is used, see the
sqlalchemy.sql.util and the sqlalchemy.sql.compiler modules.
For an introduction to clause adaption, see
http://techspot.zzzeek.org/2008/01/23/expression-transformations/
"""
from collections import deque
import operator
from .. import exc
from .. import util
__all__ = [
"VisitableType",
"Visitable",
"ClauseVisitor",
"CloningVisitor",
"ReplacingCloningVisitor",
"iterate",
"iterate_depthfirst",
"traverse_using",
"traverse",
"traverse_depthfirst",
"cloned_traverse",
"replacement_traverse",
]
class VisitableType(type):
"""Metaclass which assigns a `_compiler_dispatch` method to classes
having a `__visit_name__` attribute.
The _compiler_dispatch attribute becomes an instance method which
looks approximately like the following::
def _compiler_dispatch (self, visitor, **kw):
'''Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.'''
visit_attr = 'visit_%s' % self.__visit_name__
return getattr(visitor, visit_attr)(self, **kw)
Classes having no __visit_name__ attribute will remain unaffected.
"""
def __init__(cls, clsname, bases, clsdict):
if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
_generate_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
if "__visit_name__" in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
# There is an optimization opportunity here because the
# the string name of the class's __visit_name__ is known at
# this early stage (import time) so it can be pre-constructed.
getter = operator.attrgetter("visit_%s" % visit_name)
def _compiler_dispatch(self, visitor, **kw):
try:
meth = getter(visitor)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
visit_attr = "visit_%s" % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
_compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
class Visitable(util.with_metaclass(VisitableType, object)):
"""Base class for visitable objects, applies the
``VisitableType`` metaclass.
"""
class ClauseVisitor(object):
"""Base class for visitor objects which can traverse using
the traverse() function.
"""
__traverse_options__ = {}
def traverse_single(self, obj, **kw):
for v in self.visitor_iterator:
meth = getattr(v, "visit_%s" % obj.__visit_name__, None)
if meth:
return meth(obj, **kw)
def iterate(self, obj):
"""traverse the given expression structure, returning an iterator
of all elements.
"""
return iterate(obj, self.__traverse_options__)
def traverse(self, obj):
"""traverse and visit the given expression structure."""
return traverse(obj, self.__traverse_options__, self._visitor_dict)
@util.memoized_property
def _visitor_dict(self):
visitors = {}
for name in dir(self):
if name.startswith("visit_"):
visitors[name[6:]] = getattr(self, name)
return visitors
@property
def visitor_iterator(self):
"""iterate through this visitor and each 'chained' visitor."""
v = self
while v:
yield v
v = getattr(v, "_next", None)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
the chained visitor will receive all visit events after this one.
"""
tail = list(self.visitor_iterator)[-1]
tail._next = visitor
return self
class CloningVisitor(ClauseVisitor):
"""Base class for visitor objects which can traverse using
the cloned_traverse() function.
"""
def copy_and_process(self, list_):
"""Apply cloned traversal to the given list of elements, and return
the new list.
"""
return [self.traverse(x) for x in list_]
def traverse(self, obj):
"""traverse and visit the given expression structure."""
return cloned_traverse(
obj, self.__traverse_options__, self._visitor_dict
)
class ReplacingCloningVisitor(CloningVisitor):
"""Base class for visitor objects which can traverse using
the replacement_traverse() function.
"""
def replace(self, elem):
"""receive pre-copied elements during a cloning traversal.
If the method returns a new element, the element is used
instead of creating a simple copy of the element. Traversal
will halt on the newly returned element if it is re-encountered.
"""
return None
def traverse(self, obj):
"""traverse and visit the given expression structure."""
def replace(elem):
for v in self.visitor_iterator:
e = v.replace(elem)
if e is not None:
return e
return replacement_traverse(obj, self.__traverse_options__, replace)
def iterate(obj, opts):
"""traverse the given expression structure, returning an iterator.
traversal is configured to be breadth-first.
"""
# fasttrack for atomic elements like columns
children = obj.get_children(**opts)
if not children:
return [obj]
traversal = deque()
stack = deque([obj])
while stack:
t = stack.popleft()
traversal.append(t)
for c in t.get_children(**opts):
stack.append(c)
return iter(traversal)
def iterate_depthfirst(obj, opts):
"""traverse the given expression structure, returning an iterator.
traversal is configured to be depth-first.
"""
# fasttrack for atomic elements like columns
children = obj.get_children(**opts)
if not children:
return [obj]
stack = deque([obj])
traversal = deque()
while stack:
t = stack.pop()
traversal.appendleft(t)
for c in t.get_children(**opts):
stack.append(c)
return iter(traversal)
def traverse_using(iterator, obj, visitors):
"""visit the given expression structure using the given iterator of
objects.
"""
for target in iterator:
meth = visitors.get(target.__visit_name__, None)
if meth:
meth(target)
return obj
def traverse(obj, opts, visitors):
"""traverse and visit the given expression structure using the default
iterator.
"""
return traverse_using(iterate(obj, opts), obj, visitors)
def traverse_depthfirst(obj, opts, visitors):
"""traverse and visit the given expression structure using the
depth-first iterator.
"""
return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
def cloned_traverse(obj, opts, visitors):
"""clone the given expression structure, allowing
modifications by visitors."""
cloned = {}
stop_on = set(opts.get("stop_on", []))
def clone(elem):
if elem in stop_on:
return elem
else:
if id(elem) not in cloned:
cloned[id(elem)] = newelem = elem._clone()
newelem._copy_internals(clone=clone)
meth = visitors.get(newelem.__visit_name__, None)
if meth:
meth(newelem)
return cloned[id(elem)]
if obj is not None:
obj = clone(obj)
return obj
def replacement_traverse(obj, opts, replace):
"""clone the given expression structure, allowing element
replacement by a given replacement function."""
cloned = {}
stop_on = {id(x) for x in opts.get("stop_on", [])}
def clone(elem, **kw):
if (
id(elem) in stop_on
or "no_replacement_traverse" in elem._annotations
):
return elem
else:
newelem = replace(elem)
if newelem is not None:
stop_on.add(id(newelem))
return newelem
else:
if elem not in cloned:
cloned[elem] = newelem = elem._clone()
newelem._copy_internals(clone=clone, **kw)
return cloned[elem]
if obj is not None:
obj = clone(obj, **opts)
return obj