123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254 |
- """
- Helpers to manipulate deferred DDL statements that might need to be adjusted or
- discarded when executing a migration.
- """
- from copy import deepcopy
class Reference:
    """
    Interface shared by every deferred DDL reference.

    The defaults describe an object that references nothing: the lookup
    methods answer False and the rename methods do nothing. Subclasses
    override whichever methods apply and must provide ``__str__()``.
    """

    def references_table(self, table):
        """Tell whether this instance refers to the given table."""
        return False

    def references_column(self, table, column):
        """Tell whether this instance refers to the given column of the table."""
        return False

    def rename_table_references(self, old_table, new_table):
        """Repoint every reference to old_table at new_table."""
        return None

    def rename_column_references(self, table, old_column, new_column):
        """Repoint every reference to old_column of the table at new_column."""
        return None

    def __repr__(self):
        # Relies on __str__(), so it raises for classes that do not define it.
        class_name = self.__class__.__name__
        return f"<{class_name} {str(self)!r}>"

    def __str__(self):
        message = "Subclasses must define how they should be converted to string."
        raise NotImplementedError(message)
class Table(Reference):
    """
    Reference to a single table.

    ``quote_name`` is a callable applied to the table name when the reference
    is converted to a string.
    """

    def __init__(self, table, quote_name):
        self.table, self.quote_name = table, quote_name

    def references_table(self, table):
        return self.table == table

    def rename_table_references(self, old_table, new_table):
        # Only a reference that currently points at old_table is affected.
        if self.table != old_table:
            return
        self.table = new_table

    def __str__(self):
        quote = self.quote_name
        return quote(self.table)
class TableColumns(Table):
    """
    Base class for references to multiple columns of a table.

    ``columns`` is a sequence of column names. No ``quote_name`` is stored
    here, so the ``__str__()`` inherited from ``Table`` only works on
    subclasses that set one or define their own ``__str__()``.
    """

    def __init__(self, table, columns):
        self.table = table
        self.columns = columns

    def references_column(self, table, column):
        """Return whether ``column`` of ``table`` is one of the held columns."""
        return self.table == table and column in self.columns

    def rename_column_references(self, table, old_column, new_column):
        """Replace every occurrence of old_column with new_column."""
        if self.table != table or old_column not in self.columns:
            return
        # Build a new list instead of assigning by index: item assignment
        # fails when columns was given as a tuple, and it would also rewrite
        # a list the caller may still be sharing with other references.
        self.columns = [
            new_column if column == old_column else column
            for column in self.columns
        ]
class Columns(TableColumns):
    """
    Reference to one or more columns, rendered as a comma-separated list.

    Each column name goes through ``quote_name``; a non-empty entry at the same
    position in ``col_suffixes`` is appended after a space.
    """

    def __init__(self, table, columns, quote_name, col_suffixes=()):
        self.quote_name = quote_name
        self.col_suffixes = col_suffixes
        super().__init__(table, columns)

    def __str__(self):
        rendered = []
        for position, name in enumerate(self.columns):
            piece = self.quote_name(name)
            try:
                suffix = self.col_suffixes[position]
            except IndexError:
                # col_suffixes may be shorter than columns.
                suffix = None
            if suffix:
                piece = "{} {}".format(piece, suffix)
            rendered.append(piece)
        return ", ".join(rendered)
class IndexName(TableColumns):
    """
    Reference to an index name.

    The name is produced on demand by calling ``create_index_name`` with the
    current table, columns and suffix, so renames are reflected in it.
    """

    def __init__(self, table, columns, suffix, create_index_name):
        super().__init__(table, columns)
        self.suffix = suffix
        self.create_index_name = create_index_name

    def __str__(self):
        make_name = self.create_index_name
        return make_name(self.table, self.columns, self.suffix)
class IndexColumns(Columns):
    """
    Columns of an index, each optionally followed by an operator class and a
    suffix.

    ``opclasses`` and ``col_suffixes`` are looked up by column position. A
    missing or empty entry in either is skipped, so no stray whitespace is
    emitted for it.
    """

    def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
        self.opclasses = opclasses
        super().__init__(table, columns, quote_name, col_suffixes)

    def __str__(self):
        def col_str(column, idx):
            col = self.quote_name(column)
            # opclasses defaults to an empty tuple, so it gets the same
            # tolerant lookup as col_suffixes instead of an unconditional
            # index that would raise IndexError. The operator class is
            # appended first, then the suffix.
            for extras in (self.opclasses, self.col_suffixes):
                try:
                    extra = extras[idx]
                except IndexError:
                    continue
                if extra:
                    col = "{} {}".format(col, extra)
            return col

        return ", ".join(
            col_str(column, idx) for idx, column in enumerate(self.columns)
        )
class ForeignKeyName(TableColumns):
    """
    Reference to a foreign key name.

    The instance itself tracks the referencing table and columns, while
    ``to_reference`` tracks the referenced ones. Lookups and renames are
    applied to both sides.
    """

    def __init__(
        self,
        from_table,
        from_columns,
        to_table,
        to_columns,
        suffix_template,
        create_fk_name,
    ):
        super().__init__(from_table, from_columns)
        self.to_reference = TableColumns(to_table, to_columns)
        self.suffix_template = suffix_template
        self.create_fk_name = create_fk_name

    def references_table(self, table):
        own_side = super().references_table(table)
        return own_side or self.to_reference.references_table(table)

    def references_column(self, table, column):
        own_side = super().references_column(table, column)
        return own_side or self.to_reference.references_column(table, column)

    def rename_table_references(self, old_table, new_table):
        super().rename_table_references(old_table, new_table)
        target = self.to_reference
        target.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        super().rename_column_references(table, old_column, new_column)
        target = self.to_reference
        target.rename_column_references(table, old_column, new_column)

    def __str__(self):
        target = self.to_reference
        # Only the first referenced column takes part in the suffix.
        suffix = self.suffix_template % {
            "to_table": target.table,
            "to_column": target.columns[0],
        }
        return self.create_fk_name(self.table, self.columns, suffix)
class Statement(Reference):
    """
    Container for a statement template and its formatting parameters.

    Identifiers are kept as separate parts and only interpolated into the
    template when the statement is converted to a string, so parts that
    reference a table or column can still be adjusted beforehand.
    """

    def __init__(self, template, **parts):
        self.template = template
        self.parts = parts

    def references_table(self, table):
        # Parts lacking the reference interface (e.g. plain strings) are skipped.
        for part in self.parts.values():
            if hasattr(part, "references_table") and part.references_table(table):
                return True
        return False

    def references_column(self, table, column):
        for part in self.parts.values():
            if hasattr(part, "references_column") and part.references_column(
                table, column
            ):
                return True
        return False

    def rename_table_references(self, old_table, new_table):
        for part in self.parts.values():
            if not hasattr(part, "rename_table_references"):
                continue
            part.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        for part in self.parts.values():
            if not hasattr(part, "rename_column_references"):
                continue
            part.rename_column_references(table, old_column, new_column)

    def __str__(self):
        return self.template % self.parts
class Expressions(TableColumns):
    """
    Reference to expressions on a table.

    The referenced columns are collected from the expressions through
    ``compiler.query._gen_cols()``. The string form is the compiled SQL with
    each parameter passed through ``quote_value`` and interpolated.
    """

    def __init__(self, table, expressions, compiler, quote_value):
        self.compiler = compiler
        self.expressions = expressions
        self.quote_value = quote_value
        referenced = []
        for col in compiler.query._gen_cols([expressions]):
            referenced.append(col.target.column)
        super().__init__(table, referenced)

    def rename_table_references(self, old_table, new_table):
        if self.table != old_table:
            return
        relabel_map = {old_table: new_table}
        self.expressions = self.expressions.relabeled_clone(relabel_map)
        super().rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        if self.table != table:
            return
        # Rename on a deep copy; self.expressions is only replaced once the
        # whole walk has finished.
        renamed = deepcopy(self.expressions)
        self.columns = columns = []
        for col in self.compiler.query._gen_cols([renamed]):
            target = col.target
            if target.column == old_column:
                target.column = new_column
            columns.append(target.column)
        self.expressions = renamed

    def __str__(self):
        sql, params = self.compiler.compile(self.expressions)
        quoted = tuple(self.quote_value(param) for param in params)
        return sql % quoted
|