Files
order/.venv/Lib/site-packages/sqlacodegen/codegen.py
2025-08-27 21:11:48 +08:00

736 lines
32 KiB
Python

"""Contains the code generation logic and helper functions."""
from __future__ import unicode_literals, division, print_function, absolute_import

import inspect
import re
import sys
from collections import defaultdict
from collections import namedtuple
from importlib import import_module
from keyword import iskeyword

# inspect.ArgSpec was removed in Python 3.11; fall back to an equivalent namedtuple
try:
    from inspect import ArgSpec
except ImportError:
    ArgSpec = namedtuple('ArgSpec', 'args varargs keywords defaults')

import sqlalchemy
import sqlalchemy.exc
from sqlalchemy import (
    Enum, ForeignKeyConstraint, PrimaryKeyConstraint, CheckConstraint, UniqueConstraint, Table,
    Column, Float)
from sqlalchemy.schema import ForeignKey
from sqlalchemy.sql.sqltypes import NullType
from sqlalchemy.types import Boolean, String
from sqlalchemy.util import OrderedDict

# The generic ARRAY type was introduced in SQLAlchemy 1.1
try:
    from sqlalchemy import ARRAY
except ImportError:
    from sqlalchemy.dialects.postgresql import ARRAY

# SQLAlchemy 1.3.11+
try:
    from sqlalchemy import Computed
except ImportError:
    Computed = None

# Conditionally import Geoalchemy2 to enable reflection support
try:
    import geoalchemy2  # noqa: F401
except ImportError:
    pass
_re_boolean_check_constraint = re.compile(r"(?:(?:.*?)\.)?(.*?) IN \(0, 1\)")
_re_column_name = re.compile(r'(?:(["`]?)(?:.*)\1\.)?(["`]?)(.*)\2')
_re_enum_check_constraint = re.compile(r"(?:(?:.*?)\.)?(.*?) IN \((.+)\)")
_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
_re_invalid_identifier = re.compile(r'[^a-zA-Z0-9_]' if sys.version_info[0] < 3 else r'(?u)\W')
class _DummyInflectEngine(object):
@staticmethod
def singular_noun(noun):
return noun
# In SQLAlchemy 0.x, constraint.columns is sometimes a list, on 1.x onwards, always a
# ColumnCollection
def _get_column_names(constraint):
if isinstance(constraint.columns, list):
return constraint.columns
return list(constraint.columns.keys())
def _get_constraint_sort_key(constraint):
    """Return a string key that gives constraints a deterministic order.

    CHECK constraints are keyed by ``'C'`` plus their SQL text. Every other constraint is
    keyed by the first letter of its class name plus the repr of its column names.
    """
    if not isinstance(constraint, CheckConstraint):
        initial = constraint.__class__.__name__[0]
        return initial + repr(_get_column_names(constraint))
    return 'C{0}'.format(constraint.sqltext)
class ImportCollector(OrderedDict):
    """Ordered mapping of package name -> set of names to import from that package."""

    def add_import(self, obj):
        """Record the import needed for *obj*, which may be a class or an instance."""
        type_ = obj if isinstance(obj, type) else type(obj)
        pkgname = type_.__module__
        name = type_.__name__

        if not pkgname.startswith('sqlalchemy.dialects.'):
            # Generic types are imported from the top-level ``sqlalchemy`` namespace
            if name in sqlalchemy.__all__:
                pkgname = 'sqlalchemy'
        else:
            # Column types were already adapted towards generic types where possible, so a
            # type that is still vendor specific (e.g. MySQL INTEGER) is imported from its
            # dialect package: its constructor may take different parameters.
            dialect_pkgname = '.'.join(pkgname.split('.')[:3])
            if name in import_module(dialect_pkgname).__all__:
                pkgname = dialect_pkgname

        self.add_literal_import(pkgname, name)

    def add_literal_import(self, pkgname, name):
        """Record that *name* should be imported from the package *pkgname*."""
        self.setdefault(pkgname, set()).add(name)
class Model(object):
    """Base wrapper around a reflected table, shared by the table and class models.

    Construction replaces, in place, the type of every column that is not ``NullType`` with
    the result of :meth:`_get_adapted_type`.
    """

    def __init__(self, table):
        super(Model, self).__init__()
        self.table = table
        self.schema = table.schema

        # Adapt column types to the most reasonable generic types (ie. VARCHAR -> String)
        # NOTE(review): relies on ``column.table.bind`` exposing a ``dialect``; confirm the
        # metadata is bound for the SQLAlchemy version in use.
        for column in table.columns:
            if not isinstance(column.type, NullType):
                column.type = self._get_adapted_type(column.type, column.table.bind)

    def _get_adapted_type(self, coltype, bind):
        """Return the most generic type that still compiles like *coltype*.

        Walks the MRO of *coltype*'s class. Each public class that has a ``__visit_name__``
        is tried as an adaptation target, and the adaptation is accepted only if it compiles
        to the same DDL as the original on ``bind.dialect``. Float, and ARRAY of Float, are
        exempt from the same-DDL rule. The walk stops at the first accepted class whose name
        is not all uppercase. It also stops at the first class that cannot be adapted, fails
        to compile, or renders differently.

        :param coltype: the column type instance to adapt
        :param bind: object whose ``dialect`` attribute is used for compiling types
        :return: the adapted type, or *coltype* unchanged if no adaptation was accepted
        """
        compiled_type = coltype.compile(bind.dialect)
        for supercls in coltype.__class__.__mro__:
            # Only public classes that are SQLAlchemy visitable types are candidates
            if not supercls.__name__.startswith('_') and hasattr(supercls, '__visit_name__'):
                # Hack to fix adaptation of the Enum class which is broken since SQLAlchemy 1.2
                kw = {}
                if supercls is Enum:
                    kw['name'] = coltype.name

                try:
                    new_coltype = coltype.adapt(supercls)
                except TypeError:
                    # If the adaptation fails, don't try again
                    break

                # Re-apply the attributes collected above (currently only the Enum name)
                for key, value in kw.items():
                    setattr(new_coltype, key, value)

                # Arrays: adapt the element type recursively as well
                if isinstance(coltype, ARRAY):
                    new_coltype.item_type = self._get_adapted_type(new_coltype.item_type, bind)

                try:
                    # If the adapted column type does not render the same as the original, don't
                    # substitute it
                    if new_coltype.compile(bind.dialect) != compiled_type:
                        # Make an exception to the rule for Float and arrays of Float, since at
                        # least on PostgreSQL, Float can accurately represent both REAL and
                        # DOUBLE_PRECISION
                        if not isinstance(new_coltype, Float) and \
                           not (isinstance(new_coltype, ARRAY) and
                                isinstance(new_coltype.item_type, Float)):
                            break
                except sqlalchemy.exc.CompileError:
                    # If the adapted column type can't be compiled, don't substitute it
                    break

                # Stop on the first valid non-uppercase column type class
                coltype = new_coltype
                if supercls.__name__ != supercls.__name__.upper():
                    break

        return coltype

    def add_imports(self, collector):
        """Register with *collector* the imports needed to render this table.

        Covers ``Column``, every column type (plus ARRAY item types), ``Computed`` or
        ``text`` for server defaults, ``ForeignKey``/``ForeignKeyConstraint`` and
        multi-column ``UniqueConstraint``, the classes of other non-primary-key
        constraints, and the class of each multi-column index.
        """
        if self.table.columns:
            collector.add_import(Column)

        for column in self.table.columns:
            collector.add_import(column.type)
            if column.server_default:
                if Computed and isinstance(column.server_default, Computed):
                    collector.add_literal_import('sqlalchemy', 'Computed')
                else:
                    collector.add_literal_import('sqlalchemy', 'text')

            if isinstance(column.type, ARRAY):
                collector.add_import(column.type.item_type.__class__)

        for constraint in sorted(self.table.constraints, key=_get_constraint_sort_key):
            if isinstance(constraint, ForeignKeyConstraint):
                # Single-column FKs are rendered inline as ForeignKey(...) on the column
                if len(constraint.columns) > 1:
                    collector.add_literal_import('sqlalchemy', 'ForeignKeyConstraint')
                else:
                    collector.add_literal_import('sqlalchemy', 'ForeignKey')
            elif isinstance(constraint, UniqueConstraint):
                # Single-column unique constraints need no import (rendered as unique=True)
                if len(constraint.columns) > 1:
                    collector.add_literal_import('sqlalchemy', 'UniqueConstraint')
            elif not isinstance(constraint, PrimaryKeyConstraint):
                collector.add_import(constraint)

        for index in self.table.indexes:
            if len(index.columns) > 1:
                collector.add_import(index)

    @staticmethod
    def _convert_to_valid_identifier(name):
        """Turn *name* into a valid Python identifier.

        A name that starts with a digit or is a keyword gets a leading underscore. The name
        ``metadata`` becomes ``metadata_``, presumably to avoid clashing with the reserved
        ``metadata`` name (confirm). Any remaining invalid characters become underscores.

        :raises AssertionError: if *name* is empty (only while assertions are enabled)
        """
        assert name, 'Identifier cannot be empty'
        if name[0].isdigit() or iskeyword(name):
            name = '_' + name
        elif name == 'metadata':
            name = 'metadata_'

        return _re_invalid_identifier.sub('_', name)
class ModelTable(Model):
    """Model for a table that is rendered as a plain ``Table(...)`` object."""

    def __init__(self, table):
        super(ModelTable, self).__init__(table)
        identifier = self._convert_to_valid_identifier(table.name)
        self.name = identifier

    def add_imports(self, collector):
        """Register the base model's imports plus :class:`Table` itself."""
        super(ModelTable, self).add_imports(collector)
        collector.add_import(Table)
class ModelClass(Model):
    """Model for a table that is rendered as a declarative class.

    The constructor derives a CamelCase class name and collects attributes in this order:
    one per column, one per many-to-one relationship, and one per many-to-many relationship
    through the given association tables. With ``detect_joined``, the first foreign key
    constraint (in sort order) whose columns equal the primary key columns makes the
    referenced class the parent (``parent_name``) instead of producing a relationship.
    """

    parent_name = 'Base'

    def __init__(self, table, association_tables, inflect_engine, detect_joined):
        super(ModelClass, self).__init__(table)
        self.name = self._tablename_to_classname(table.name, inflect_engine)
        # Models of subclasses nested under this class (appended by the code generator)
        self.children = []
        # Attribute name -> Column or Relationship, in rendering order
        self.attributes = OrderedDict()

        # Columns come first, in table order
        for column in table.columns:
            self._add_attribute(column.name, column)

        # Many-to-one relationships, or the parent class for joined-table inheritance
        pk_column_names = {col.name for col in table.primary_key.columns}
        fk_constraints = [c for c in sorted(table.constraints, key=_get_constraint_sort_key)
                          if isinstance(c, ForeignKeyConstraint)]
        for fk in fk_constraints:
            referred_table = fk.elements[0].column.table
            target_cls = self._tablename_to_classname(referred_table.name, inflect_engine)
            is_joined_parent = (detect_joined and self.parent_name == 'Base' and
                                set(_get_column_names(fk)) == pk_column_names)
            if is_joined_parent:
                self.parent_name = target_cls
                continue
            relationship_ = ManyToOneRelationship(self.name, target_cls, fk, inflect_engine)
            self._add_attribute(relationship_.preferred_name, relationship_)

        # Many-to-many relationships through association tables
        for association_table in association_tables:
            link_fks = sorted(
                (c for c in association_table.constraints
                 if isinstance(c, ForeignKeyConstraint)),
                key=_get_constraint_sort_key)
            target_table = link_fks[1].elements[0].column.table
            target_cls = self._tablename_to_classname(target_table.name, inflect_engine)
            relationship_ = ManyToManyRelationship(self.name, target_cls, association_table)
            self._add_attribute(relationship_.preferred_name, relationship_)

    @classmethod
    def _tablename_to_classname(cls, tablename, inflect_engine):
        """Convert *tablename* to CamelCase, then singularize it with *inflect_engine*."""
        identifier = cls._convert_to_valid_identifier(tablename)
        parts = [part[:1].upper() + part[1:] for part in identifier.split('_')]
        camel_case_name = ''.join(parts)
        singular = inflect_engine.singular_noun(camel_case_name)
        return singular if singular else camel_case_name

    def _add_attribute(self, attrname, value):
        """Store *value* under a unique, valid attribute name derived from *attrname*.

        When the name is taken, the suffixes 1, 2, ... are tried in turn.

        :return: the attribute name actually used
        """
        base = self._convert_to_valid_identifier(attrname)
        candidate = base
        suffix = 0
        while candidate in self.attributes:
            suffix += 1
            candidate = base + str(suffix)
        self.attributes[candidate] = value
        return candidate

    def add_imports(self, collector):
        """Register base imports, ``relationship`` if needed, and the children's imports."""
        super(ModelClass, self).add_imports(collector)
        has_relationships = any(isinstance(value, Relationship)
                                for value in self.attributes.values())
        if has_relationships:
            collector.add_literal_import('sqlalchemy.orm', 'relationship')

        for child in self.children:
            child.add_imports(collector)
class Relationship(object):
    """Base description of a ``relationship()`` between two generated classes.

    :param source_cls: name of the class that holds the relationship attribute
    :param target_cls: name of the related class
    """

    def __init__(self, source_cls, target_cls):
        super(Relationship, self).__init__()
        self.source_cls, self.target_cls = source_cls, target_cls
        # Extra keyword arguments, each value already rendered as source code text
        self.kwargs = OrderedDict()
class ManyToOneRelationship(Relationship):
    """Relationship from the table holding a foreign key to the table it references.

    The attribute is named after the first FK column with its ``_id`` suffix removed. If
    there is no such suffix, it is named after the referenced table as returned by
    ``inflect_engine.singular_noun`` (or the raw table name if that returns a falsy value).
    """

    def __init__(self, source_cls, target_cls, constraint, inflect_engine):
        super(ManyToOneRelationship, self).__init__(source_cls, target_cls)

        column_names = _get_column_names(constraint)
        colname = column_names[0]
        has_id_suffix = colname.endswith('_id')
        tablename = constraint.elements[0].column.table.name
        if has_id_suffix:
            self.preferred_name = colname[:-3]
        else:
            self.preferred_name = inflect_engine.singular_noun(tablename) or tablename

        # One-to-one: the FK columns are exactly the columns of a primary key or unique
        # constraint, so render uselist=False
        fk_column_set = set(column_names)
        for candidate in constraint.table.constraints:
            if not isinstance(candidate, (PrimaryKeyConstraint, UniqueConstraint)):
                continue
            if set(col.name for col in candidate.columns) == fk_column_set:
                self.kwargs['uselist'] = 'False'
                break

        # Self-referential relationships get a fixed name and an explicit remote side
        if source_cls == target_cls:
            self.preferred_name = colname[:-3] if has_id_suffix else 'parent'
            pk_col_names = [col.name for col in constraint.table.primary_key]
            self.kwargs['remote_side'] = '[{0}]'.format(', '.join(pk_col_names))

        # With more than one FK constraint between the two tables, SQLAlchemy needs an
        # explicit primaryjoin to figure out which column(s) to join with
        referred_column = constraint.elements[0].column
        common_fk_constraints = self.get_common_fk_constraints(
            constraint.table, referred_column.table)
        if len(common_fk_constraints) > 1:
            self.kwargs['primaryjoin'] = "'{0}.{1} == {2}.{3}'".format(
                source_cls, colname, target_cls, referred_column.name)

    @staticmethod
    def get_common_fk_constraints(table1, table2):
        """Return the set of FK constraints in either table that reference the other one."""
        def referencing(source, target):
            return set(c for c in source.constraints
                       if isinstance(c, ForeignKeyConstraint) and
                       c.elements[0].column.table == target)

        return referencing(table1, table2).union(referencing(table2, table1))
class ManyToManyRelationship(Relationship):
    """Relationship between two classes through an association (link) table.

    The second FK constraint of the association table (in sort order) points at the target;
    its first column determines the attribute name.
    """

    def __init__(self, source_cls, target_cls, assocation_table):
        super(ManyToManyRelationship, self).__init__(source_cls, target_cls)

        # "secondary" names the link table, schema-qualified when it has a schema
        secondary = assocation_table.name
        if assocation_table.schema:
            secondary = assocation_table.schema + '.' + secondary
        self.kwargs['secondary'] = repr(secondary)

        fk_constraints = sorted(
            (c for c in assocation_table.constraints if isinstance(c, ForeignKeyConstraint)),
            key=_get_constraint_sort_key)
        target_fk = fk_constraints[1]
        colname = _get_column_names(target_fk)[0]
        has_id_suffix = colname.endswith('_id')
        tablename = target_fk.elements[0].column.table.name
        if has_id_suffix:
            self.preferred_name = colname[:-3] + 's'
        else:
            self.preferred_name = tablename

        # Self-referential relationships need explicit join conditions on both sides
        if source_cls == target_cls:
            self.preferred_name = colname[:-3] + 's' if has_id_suffix else 'parents'
            source_fk = fk_constraints[0]
            link_name = assocation_table.name

            def build_joins(cls_name, fk):
                return ['{0}.{1} == {2}.c.{3}'.format(cls_name, elem.column.name, link_name, col)
                        for col, elem in zip(_get_column_names(fk), fk.elements)]

            def as_expression(joins):
                if len(joins) > 1:
                    return repr('and_({0})'.format(', '.join(joins)))
                return repr(joins[0])

            self.kwargs['primaryjoin'] = as_expression(build_joins(source_cls, source_fk))
            self.kwargs['secondaryjoin'] = as_expression(build_joins(target_cls, target_fk))
class CodeGenerator(object):
    """Generate SQLAlchemy model source code from reflected :class:`MetaData`.

    All analysis happens in the constructor: tables are wrapped in ``table_model`` or
    ``class_model`` instances, and the required imports are collected. :meth:`render`
    then writes the generated module.

    NOTE: the constructor mutates the reflected tables in place. Column types may be
    replaced, and indexes and constraints may be removed, depending on the flags.
    """

    template = """\
# coding: utf-8
{imports}
{metadata_declarations}
{models}"""

    # NOTE(review): ``indentation`` defaults to a single space, which yields one-space
    # indented output; confirm this is intended (four spaces is the PEP 8 norm).
    def __init__(self, metadata, noindexes=False, noconstraints=False, nojoined=False,
                 noinflect=False, noclasses=False, indentation=' ', model_separator='\n\n',
                 ignored_tables=('alembic_version', 'migrate_version'), table_model=ModelTable,
                 class_model=ModelClass, template=None, nocomments=False):
        """
        :param metadata: reflected MetaData whose tables are turned into models
        :param noindexes: drop all indexes before generating code
        :param noconstraints: keep only primary keys; drop other constraints and all FKs
        :param nojoined: do not detect joined-table inheritance
        :param noinflect: do not singularize class names
        :param noclasses: always render plain ``Table`` objects instead of classes
        :param indentation: string used for one indentation level in the generated code
        :param model_separator: string placed between rendered models
        :param ignored_tables: names of tables that are skipped entirely
        :param table_model: model class used for plain tables
        :param class_model: model class used for declarative classes
        :param template: replaces the module :attr:`template` when truthy
        :param nocomments: omit column comments from the generated code
        """
        super(CodeGenerator, self).__init__()
        self.metadata = metadata
        self.noindexes = noindexes
        self.noconstraints = noconstraints
        self.nojoined = nojoined
        self.noinflect = noinflect
        self.noclasses = noclasses
        self.indentation = indentation
        self.model_separator = model_separator
        self.ignored_tables = ignored_tables
        self.table_model = table_model
        self.class_model = class_model
        self.nocomments = nocomments
        self.inflect_engine = self.create_inflect_engine()
        if template:
            self.template = template

        # Pick association tables from the metadata into their own set, don't process them
        # normally. ``links`` maps the table referenced by an association table's first FK
        # constraint (in sort order) to that table's association tables.
        links = defaultdict(list)
        association_tables = set()
        for table in metadata.tables.values():
            # Link tables have exactly two foreign key constraints and all columns are
            # involved in them
            fk_constraints = [constr for constr in table.constraints
                              if isinstance(constr, ForeignKeyConstraint)]
            if len(fk_constraints) == 2 and all(col.foreign_keys for col in table.columns):
                association_tables.add(table.name)
                tablename = sorted(
                    fk_constraints, key=_get_constraint_sort_key)[0].elements[0].column.table.name
                links[tablename].append(table)

        # Iterate through the tables and create model classes when possible
        self.models = []
        self.collector = ImportCollector()
        classes = {}
        for table in metadata.sorted_tables:
            # Support for Alembic and sqlalchemy-migrate -- never expose the schema version
            # tables
            if table.name in self.ignored_tables:
                continue

            if noindexes:
                table.indexes.clear()

            if noconstraints:
                table.constraints = {table.primary_key}
                table.foreign_keys.clear()
                for col in table.columns:
                    col.foreign_keys.clear()
            else:
                # Detect check constraints for boolean and enum columns
                for constraint in table.constraints.copy():
                    if isinstance(constraint, CheckConstraint):
                        sqltext = self._get_compiled_expression(constraint.sqltext)

                        # Turn any integer-like column with a CheckConstraint like
                        # "column IN (0, 1)" into a Boolean. The regex may capture a whole
                        # expression instead of a column name (e.g. "a > 0 AND b IN (0, 1)"),
                        # so only act when it names an actual column of this table.
                        match = _re_boolean_check_constraint.match(sqltext)
                        if match:
                            colname = _re_column_name.match(match.group(1)).group(3)
                            if colname in table.c:
                                table.constraints.remove(constraint)
                                table.c[colname].type = Boolean()
                                continue

                        # Turn any string-type column with a CheckConstraint like
                        # "column IN (...)" into an Enum
                        match = _re_enum_check_constraint.match(sqltext)
                        if match:
                            colname = _re_column_name.match(match.group(1)).group(3)
                            items = match.group(2)
                            if colname in table.c and isinstance(table.c[colname].type, String):
                                table.constraints.remove(constraint)
                                if not isinstance(table.c[colname].type, Enum):
                                    options = _re_enum_item.findall(items)
                                    table.c[colname].type = Enum(*options, native_enum=False)
                                continue

            # Only form model classes for tables that have a primary key and are not
            # association tables
            if noclasses or not table.primary_key or table.name in association_tables:
                model = self.table_model(table)
            else:
                model = self.class_model(table, links[table.name], self.inflect_engine,
                                         not nojoined)
                classes[model.name] = model

            self.models.append(model)
            model.add_imports(self.collector)

        # Nest inherited classes in their superclasses to ensure proper ordering
        for model in classes.values():
            if model.parent_name != 'Base':
                classes[model.parent_name].children.append(model)
                self.models.remove(model)

        # Add either the MetaData or declarative_base import depending on whether there are
        # mapped classes or not
        if not any(isinstance(model, self.class_model) for model in self.models):
            self.collector.add_literal_import('sqlalchemy', 'MetaData')
        else:
            self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base')

    def create_inflect_engine(self):
        """Return an ``inflect`` engine, or a no-op stand-in when inflection is disabled."""
        if self.noinflect:
            return _DummyInflectEngine()
        else:
            import inflect
            return inflect.engine()

    def render_imports(self):
        """Render one ``from <package> import <names>`` line per package (names sorted)."""
        return '\n'.join('from {0} import {1}'.format(package, ', '.join(sorted(names)))
                         for package, names in self.collector.items())

    def render_metadata_declarations(self):
        """Render the declarative ``Base``/``metadata`` pair, or a plain ``MetaData()``."""
        if 'sqlalchemy.ext.declarative' in self.collector:
            return 'Base = declarative_base()\nmetadata = Base.metadata'
        return 'metadata = MetaData()'

    def _get_compiled_expression(self, statement):
        """Return the statement in a form where any placeholders have been filled in.

        Compiles against ``self.metadata.bind`` with literal binds enabled.
        """
        return str(statement.compile(
            self.metadata.bind, compile_kwargs={"literal_binds": True}))

    @staticmethod
    def _getargspec_init(method):
        """Return an argspec for *method*, or a synthetic one if introspection raises TypeError.

        The synthetic spec is just ``self`` for ``object.__init__``, and
        ``self, *args, **kwargs`` otherwise.
        """
        try:
            if hasattr(inspect, 'getfullargspec'):
                return inspect.getfullargspec(method)
            else:
                return inspect.getargspec(method)
        except TypeError:
            if method is object.__init__:
                return ArgSpec(['self'], None, None, None)
            else:
                return ArgSpec(['self'], 'args', 'kwargs', None)

    @classmethod
    def render_column_type(cls, coltype):
        """Render *coltype* as a constructor call in source form.

        Constructor arguments whose attribute values differ from their defaults are rendered
        positionally until the first argument that is missing or equals its default. After
        that point they are rendered as keywords. Var-positional values and an Enum ``name``
        are appended.
        """
        args = []
        kwargs = OrderedDict()
        argspec = cls._getargspec_init(coltype.__class__.__init__)
        defaults = dict(zip(argspec.args[-len(argspec.defaults or ()):],
                            argspec.defaults or ()))
        missing = object()
        use_kwargs = False
        for attr in argspec.args[1:]:
            # Remove annoyances like _warn_on_bytestring
            if attr.startswith('_'):
                continue

            value = getattr(coltype, attr, missing)
            default = defaults.get(attr, missing)
            if value is missing or value == default:
                use_kwargs = True
            elif use_kwargs:
                kwargs[attr] = repr(value)
            else:
                args.append(repr(value))

        if argspec.varargs and hasattr(coltype, argspec.varargs):
            varargs_repr = [repr(arg) for arg in getattr(coltype, argspec.varargs)]
            args.extend(varargs_repr)

        if isinstance(coltype, Enum) and coltype.name is not None:
            kwargs['name'] = repr(coltype.name)

        for key, value in kwargs.items():
            args.append('{}={}'.format(key, value))

        rendered = coltype.__class__.__name__
        if args:
            rendered += '({0})'.format(', '.join(args))

        return rendered

    def render_constraint(self, constraint):
        """Render a FK, FK constraint, check or unique constraint as source code.

        Returns None for any other constraint type.
        """
        def render_fk_options(*opts):
            opts = [repr(opt) for opt in opts]
            for attr in 'ondelete', 'onupdate', 'deferrable', 'initially', 'match':
                value = getattr(constraint, attr, None)
                if value:
                    opts.append('{0}={1!r}'.format(attr, value))

            return ', '.join(opts)

        if isinstance(constraint, ForeignKey):
            remote_column = '{0}.{1}'.format(constraint.column.table.fullname,
                                             constraint.column.name)
            return 'ForeignKey({0})'.format(render_fk_options(remote_column))
        elif isinstance(constraint, ForeignKeyConstraint):
            local_columns = _get_column_names(constraint)
            remote_columns = ['{0}.{1}'.format(fk.column.table.fullname, fk.column.name)
                              for fk in constraint.elements]
            return 'ForeignKeyConstraint({0})'.format(
                render_fk_options(local_columns, remote_columns))
        elif isinstance(constraint, CheckConstraint):
            return 'CheckConstraint({0!r})'.format(
                self._get_compiled_expression(constraint.sqltext))
        elif isinstance(constraint, UniqueConstraint):
            columns = [repr(col.name) for col in constraint.columns]
            return 'UniqueConstraint({0})'.format(', '.join(columns))

    @staticmethod
    def render_index(index):
        """Render *index* as an ``Index(name, *columns[, unique=True])`` call."""
        extra_args = [repr(col.name) for col in index.columns]
        if index.unique:
            extra_args.append('unique=True')
        return 'Index({0!r}, {1})'.format(index.name, ', '.join(extra_args))

    def render_column(self, column, show_name):
        """Render *column* as a ``Column(...)`` call.

        NOTE: this sets ``column.unique`` or ``column.index`` to True when a matching
        single-column unique constraint or index exists, so that they render as kwargs.

        :param show_name: include the column name as the first argument
        """
        kwarg = []
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        dedicated_fks = [c for c in column.foreign_keys if len(c.constraint.columns) == 1]
        is_unique = any(isinstance(c, UniqueConstraint) and set(c.columns) == {column}
                        for c in column.table.constraints)
        is_unique = is_unique or any(i.unique and set(i.columns) == {column}
                                     for i in column.table.indexes)
        has_index = any(set(i.columns) == {column} for i in column.table.indexes)
        server_default = None

        # Render the column type if there are no foreign keys on it or any of them points back
        # to itself
        render_coltype = not dedicated_fks or any(fk.column is column for fk in dedicated_fks)

        if column.key != column.name:
            kwarg.append('key')
        if column.primary_key:
            kwarg.append('primary_key')
        if not column.nullable and not is_sole_pk:
            kwarg.append('nullable')
        if is_unique:
            column.unique = True
            kwarg.append('unique')
        elif has_index:
            column.index = True
            kwarg.append('index')

        if Computed and isinstance(column.server_default, Computed):
            expression = self._get_compiled_expression(column.server_default.sqltext)

            persist_arg = ''
            if column.server_default.persisted is not None:
                persist_arg = ', persisted={}'.format(column.server_default.persisted)

            server_default = 'Computed({!r}{})'.format(expression, persist_arg)
        elif column.server_default:
            # The quote escaping does not cover pathological cases (e.g. a triple quote inside
            # a multi-line default) but should mostly work. Backslashes are escaped first so
            # the generated literal reproduces them verbatim instead of forming escape
            # sequences such as \x00 or \b.
            default_expr = self._get_compiled_expression(column.server_default.arg)
            default_expr = default_expr.replace('\\', '\\\\')
            if '\n' in default_expr:
                server_default = 'server_default=text("""\\\n{0}""")'.format(default_expr)
            else:
                default_expr = default_expr.replace('"', '\\"')
                server_default = 'server_default=text("{0}")'.format(default_expr)

        comment = getattr(column, 'comment', None)
        return 'Column({0})'.format(', '.join(
            ([repr(column.name)] if show_name else []) +
            ([self.render_column_type(column.type)] if render_coltype else []) +
            [self.render_constraint(x) for x in dedicated_fks] +
            [repr(x) for x in column.constraints] +
            ['{0}={1}'.format(k, repr(getattr(column, k))) for k in kwarg] +
            ([server_default] if server_default else []) +
            (['comment={!r}'.format(comment)] if comment and not self.nocomments else [])
        ))

    def render_relationship(self, relationship):
        """Render a ``relationship(...)`` call; one argument per line if it has a secondaryjoin."""
        rendered = 'relationship('
        args = [repr(relationship.target_cls)]

        if 'secondaryjoin' in relationship.kwargs:
            rendered += '\n{0}{0}'.format(self.indentation)
            delimiter, end = (',\n{0}{0}'.format(self.indentation),
                              '\n{0})'.format(self.indentation))
        else:
            delimiter, end = ', ', ')'

        args.extend([key + '=' + value for key, value in relationship.kwargs.items()])
        return rendered + delimiter.join(args) + end

    def render_table(self, model):
        """Render *model* as a ``t_<name> = Table(...)`` assignment."""
        rendered = 't_{0} = Table(\n{2}{1!r}, metadata,\n'.format(
            model.name, model.table.name, self.indentation)

        for column in model.table.columns:
            rendered += '{0}{1},\n'.format(self.indentation, self.render_column(column, True))

        for constraint in sorted(model.table.constraints, key=_get_constraint_sort_key):
            if isinstance(constraint, PrimaryKeyConstraint):
                continue
            # Single-column FK/unique constraints were already rendered on their columns
            if (isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)) and
                    len(constraint.columns) == 1):
                continue

            rendered += '{0}{1},\n'.format(self.indentation, self.render_constraint(constraint))

        for index in model.table.indexes:
            if len(index.columns) > 1:
                rendered += '{0}{1},\n'.format(self.indentation, self.render_index(index))

        if model.schema:
            rendered += "{0}schema='{1}',\n".format(self.indentation, model.schema)

        table_comment = getattr(model.table, 'comment', None)
        if table_comment:
            # Escape backslashes first, then quotes and line breaks, so the single-quoted
            # literal is always valid and reproduces the comment exactly
            quoted_comment = (table_comment.replace('\\', '\\\\')
                              .replace("'", "\\'")
                              .replace('"', '\\"')
                              .replace('\n', '\\n')
                              .replace('\r', '\\r'))
            rendered += "{0}comment='{1}',\n".format(self.indentation, quoted_comment)

        return rendered.rstrip('\n,') + '\n)\n'

    def render_class(self, model):
        """Render *model* as a declarative class, followed by its nested subclasses."""
        rendered = 'class {0}({1}):\n'.format(model.name, model.parent_name)
        rendered += '{0}__tablename__ = {1!r}\n'.format(self.indentation, model.table.name)

        # Render constraints and indexes as __table_args__
        table_args = []
        for constraint in sorted(model.table.constraints, key=_get_constraint_sort_key):
            if isinstance(constraint, PrimaryKeyConstraint):
                continue
            if (isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)) and
                    len(constraint.columns) == 1):
                continue
            table_args.append(self.render_constraint(constraint))
        for index in model.table.indexes:
            if len(index.columns) > 1:
                table_args.append(self.render_index(index))

        table_kwargs = {}
        if model.schema:
            table_kwargs['schema'] = model.schema

        table_comment = getattr(model.table, 'comment', None)
        if table_comment:
            table_kwargs['comment'] = table_comment

        kwargs_items = ', '.join('{0!r}: {1!r}'.format(key, table_kwargs[key])
                                 for key in table_kwargs)
        kwargs_items = '{{{0}}}'.format(kwargs_items) if kwargs_items else None
        if table_kwargs and not table_args:
            rendered += '{0}__table_args__ = {1}\n'.format(self.indentation, kwargs_items)
        elif table_args:
            if kwargs_items:
                table_args.append(kwargs_items)
            # A one-element tuple needs a trailing comma
            if len(table_args) == 1:
                table_args[0] += ','
            table_args_joined = ',\n{0}{0}'.format(self.indentation).join(table_args)
            rendered += '{0}__table_args__ = (\n{0}{0}{1}\n{0})\n'.format(
                self.indentation, table_args_joined)

        # Render columns
        rendered += '\n'
        for attr, column in model.attributes.items():
            if isinstance(column, Column):
                show_name = attr != column.name
                rendered += '{0}{1} = {2}\n'.format(
                    self.indentation, attr, self.render_column(column, show_name))

        # Render relationships
        if any(isinstance(value, Relationship) for value in model.attributes.values()):
            rendered += '\n'
        for attr, relationship in model.attributes.items():
            if isinstance(relationship, Relationship):
                rendered += '{0}{1} = {2}\n'.format(
                    self.indentation, attr, self.render_relationship(relationship))

        # Render subclasses
        for child_class in model.children:
            rendered += self.model_separator + self.render_class(child_class)

        return rendered

    def render(self, outfile=None):
        """Write the generated module source to *outfile*.

        :param outfile: writable text stream; ``None`` (the default) means the ``sys.stdout``
            that is current when this method is called
        """
        if outfile is None:
            # Resolved at call time (not in the signature) so redirected or captured stdout
            # is honored
            outfile = sys.stdout

        rendered_models = []
        for model in self.models:
            if isinstance(model, self.class_model):
                rendered_models.append(self.render_class(model))
            elif isinstance(model, self.table_model):
                rendered_models.append(self.render_table(model))

        output = self.template.format(
            imports=self.render_imports(),
            metadata_declarations=self.render_metadata_declarations(),
            models=self.model_separator.join(rendered_models).rstrip('\n'))
        print(output, file=outfile)