Files
aitsc/.venv/Lib/site-packages/sqlalchemy/testing/suite/test_cte.py

238 lines
7.1 KiB
Python
Raw Normal View History

2025-02-23 09:07:52 +08:00
# testing/suite/test_cte.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
from .. import fixtures
from ..assertions import eq_
from ..schema import Column
from ..schema import Table
2025-08-29 00:34:40 +08:00
from ... import column
2025-02-23 09:07:52 +08:00
from ... import ForeignKey
from ... import Integer
from ... import select
from ... import String
from ... import testing
2025-08-29 00:34:40 +08:00
from ... import values
2025-02-23 09:07:52 +08:00
class CTETest(fixtures.TablesTest):
    """Backend round-trip tests for common table expressions (CTEs).

    The tests work against two tables with the same column names:
    ``some_table``, whose ``parent_id`` is a foreign key to its own ``id``,
    and ``some_other_table``, whose ``parent_id`` is a plain integer.
    ``insert_data`` populates only ``some_table``.
    """

    __backend__ = True
    __requires__ = ("ctes",)

    # both the insert phase and the delete phase are set to "each"
    run_inserts = run_deletes = "each"

    @classmethod
    def define_tables(cls, metadata):
        """Register ``some_table`` and ``some_other_table`` on ``metadata``."""
        # self-referential: parent_id points back at some_table.id
        tree_columns = (
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        )
        Table("some_table", metadata, *tree_columns)

        # same column names, but parent_id carries no foreign key here
        plain_columns = (
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        )
        Table("some_other_table", metadata, *plain_columns)

    @classmethod
    def insert_data(cls, connection):
        """Insert the five fixture rows (d1..d5) into ``some_table``."""
        fixture_rows = [
            {"id": 1, "data": "d1", "parent_id": None},
            {"id": 2, "data": "d2", "parent_id": 1},
            {"id": 3, "data": "d3", "parent_id": 1},
            {"id": 4, "data": "d4", "parent_id": 3},
            {"id": 5, "data": "d5", "parent_id": 3},
        ]
        insert_stmt = cls.tables.some_table.insert()
        connection.execute(insert_stmt, fixture_rows)

    def test_select_nonrecursive_round_trip(self, connection):
        """Select from a plain CTE and filter its rows a second time."""
        source = self.tables.some_table

        picked = (
            select(source)
            .where(source.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        stmt = select(picked.c.data).where(
            picked.c.data.in_(["d4", "d5"])
        )

        # only "d4" is present in both the CTE filter and the outer filter
        rows = connection.execute(stmt).fetchall()
        eq_(rows, [("d4",)])

    def test_select_recursive_round_trip(self, connection):
        """Select from a recursive CTE that joins rows to their parents."""
        source = self.tables.some_table

        anchor = (
            select(source)
            .where(source.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte", recursive=True)
        )
        anchor_alias = anchor.alias("c1")
        parent_rows = source.alias()

        # note that SQL Server requires this to be UNION ALL,
        # can't be UNION
        recursive_part = select(parent_rows).where(
            parent_rows.c.id == anchor_alias.c.parent_id
        )
        full_cte = anchor.union_all(recursive_part)

        stmt = (
            select(full_cte.c.data)
            .where(full_cte.c.data != "d2")
            .order_by(full_cte.c.data.desc())
        )
        rows = connection.execute(stmt).fetchall()
        eq_(
            rows,
            [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)],
        )

    def test_insert_from_select_round_trip(self, connection):
        """INSERT..FROM SELECT where the SELECT reads from a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        picked = (
            select(source)
            .where(source.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        insert_stmt = target.insert().from_select(
            ["id", "data", "parent_id"], select(picked)
        )
        connection.execute(insert_stmt)

        stored = connection.execute(
            select(target).order_by(target.c.id)
        ).fetchall()
        eq_(stored, [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)])

    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self, connection):
        """UPDATE rows whose ``data`` matches a row of a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        # begin with some_other_table holding a copy of some_table
        copy_stmt = target.insert().from_select(
            ["id", "data", "parent_id"], select(source)
        )
        connection.execute(copy_stmt)

        picked = (
            select(source)
            .where(source.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        update_stmt = (
            target.update()
            .values(parent_id=5)
            .where(target.c.data == picked.c.data)
        )
        connection.execute(update_stmt)

        stored = connection.execute(
            select(target).order_by(target.c.id)
        ).fetchall()
        eq_(
            stored,
            [
                (1, "d1", None),
                (2, "d2", 5),
                (3, "d3", 5),
                (4, "d4", 5),
                (5, "d5", 3),
            ],
        )

    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self, connection):
        """DELETE rows whose ``data`` matches a row of a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        # begin with some_other_table holding a copy of some_table
        copy_stmt = target.insert().from_select(
            ["id", "data", "parent_id"], select(source)
        )
        connection.execute(copy_stmt)

        picked = (
            select(source)
            .where(source.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        delete_stmt = target.delete().where(
            target.c.data == picked.c.data
        )
        connection.execute(delete_stmt)

        stored = connection.execute(
            select(target).order_by(target.c.id)
        ).fetchall()
        eq_(stored, [(1, "d1", None), (5, "d5", 3)])

    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self, connection):
        """DELETE using a correlated scalar subquery against a CTE."""
        source = self.tables.some_table
        target = self.tables.some_other_table

        # begin with some_other_table holding a copy of some_table
        copy_stmt = target.insert().from_select(
            ["id", "data", "parent_id"], select(source)
        )
        connection.execute(copy_stmt)

        picked = (
            select(source)
            .where(source.c.data.in_(["d2", "d3", "d4"]))
            .cte("some_cte")
        )
        # scalar subquery over the CTE, matched to the target row by id
        matching_data = (
            select(picked.c.data)
            .where(picked.c.id == target.c.id)
            .scalar_subquery()
        )
        delete_stmt = target.delete().where(
            target.c.data == matching_data
        )
        connection.execute(delete_stmt)

        stored = connection.execute(
            select(target).order_by(target.c.id)
        ).fetchall()
        eq_(stored, [(1, "d1", None), (5, "d5", 3)])

    @testing.variation("values_named", [True, False])
    @testing.variation("cte_named", [True, False])
    @testing.variation("literal_binds", [True, False])
    @testing.requires.ctes_with_values
    def test_values_named_via_cte(
        self, connection, values_named, cte_named, literal_binds
    ):
        """Select from a CTE built on top of a VALUES construct.

        The three variations toggle whether the VALUES construct is given
        a name, whether the CTE is given a name, and whether
        ``literal_binds`` is enabled on the VALUES construct.
        """
        if values_named:
            values_name = "some name"
        else:
            values_name = None

        if cte_named:
            cte_name = "cte1"
        else:
            cte_name = None

        values_construct = values(
            column("col1", String),
            column("col2", Integer),
            literal_binds=bool(literal_binds),
            name=values_name,
        )
        cte1 = values_construct.data([("a", 2), ("b", 3)]).cte(cte_name)

        fetched = connection.execute(select(cte1)).all()
        eq_(fetched, [("a", 2), ("b", 3)])