from collections.abc import Callable, ValuesView
from typing import Any
import peewee as pw
from playhouse.migrate import (
SQL,
MySQLDatabase,
Operation,
PostgresqlDatabase,
SqliteDatabase,
operation,
)
from playhouse.migrate import MySQLMigrator as MqM
from playhouse.migrate import PostgresqlMigrator as PgM
from playhouse.migrate import SchemaMigrator as ScM
from playhouse.migrate import SqliteMigrator as SqM
from miggy import LOGGER
from miggy.auto import FieldComparer, resolve_field
from miggy.types import ModelCls
from miggy.utils import (
ModelIndex,
copy_model,
delete_field,
get_default_constraint,
get_default_constraint_value,
get_single_index,
get_single_index_name,
has_single_index,
indexes_state,
make_single_index,
)
ModelDict = dict[str, ModelCls]
RunPythonF = Callable[["SchemaMigrator", "State"], None]
[docs]
class State:
"""
Current state containing historical models that match the operation’s place in the project history.
This is a dict-like class that stores data in the format model_name: model_class.
The model_name is case-insensitive.
Example::
User = state["user"]
User.get(id=1)
"""
def __init__(self, data: ModelDict | None = None) -> None:
self.data: ModelDict = data or {}
self._snapshot: ModelDict | None = None
def normalize_key(self, key: str) -> str:
_key = key.lower()
if self._snapshot is not None:
if _key in self._snapshot:
self._snapshot[_key] = copy_model(self._snapshot[_key])
return _key
def __setitem__(self, key: str, val: ModelCls) -> None:
self.data[self.normalize_key(key)] = val
def __getitem__(self, key: str) -> ModelCls:
return self.data[self.normalize_key(key)]
def __delitem__(self, key: str) -> None:
del self.data[self.normalize_key(key)]
def __contains__(self, key: str) -> bool:
return self.normalize_key(key) in self.data
def values(self) -> ValuesView[ModelCls]:
return self.data.values()
def create_snapshot(self) -> None:
self._snapshot = self.data.copy()
def pop_snapshot(self) -> "State":
_snapshot = self._snapshot
self._snapshot = None
return State(_snapshot)
[docs]
class MigrateOperation:
"""
Base class for a migrate operation
"""
[docs]
def state_forwards(self, state: State) -> None:
"""
Take the state from the previous migration, and mutate it
so that it matches what this migration would perform.
"""
raise NotImplementedError
[docs]
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation] | list[Callable]:
"""
Perform the mutation on the database schema in the normal
(forwards) direction.
The method MUST NOT mutate provided states.
"""
raise NotImplementedError
[docs]
class RunPython(MigrateOperation):
"""
Allows to run custom Python code. **func** should be callable object that accept two arguments;
the first is an instance of :class:`SchemaMigrator` and the second is an instance of :class:`State`
Example::
def save_user(schema_migrator: SchemaMigrator, current_state: State):
User = current_state["user"]
User(
first_name="First",
last_name="Last",
).save()
migrator.add_operaion(RunPython(save_user))
"""
def __init__(self, func: RunPythonF) -> None:
self.func = func
def state_forwards(self, state: State) -> None:
pass
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation] | list[Callable]:
return [lambda: self.func(schema_migrator, from_state)]
[docs]
class RunSql(MigrateOperation):
"""
Allows running of arbitrary SQL on the database -
useful for more advanced features of database backends that Miggy doesn’t support directly.
Example::
migrator.add_operation(
RunSql(
'INSERT INTO "user" ("first_name", "last_name") VALUES (%s, %s)',
(
"First",
"Last",
),
)
)
"""
def __init__(self, sql: str, params: tuple[Any, ...] | None = None) -> None:
self.sql = sql
self.params = params
def state_forwards(self, state: State) -> None:
pass
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation] | list[Callable]:
return [schema_migrator.sql(self.sql, self.params)]
[docs]
class CreateModel(MigrateOperation):
"""
Creates a new model in the :class:`State` and a corresponding table in the database to match it.
"""
def __init__(self, model: ModelCls) -> None:
self.model = model
def state_forwards(self, state: State) -> None:
state[self.model._meta.name] = self.model
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Callable]:
model = to_state[self.model._meta.name]
return [schema_migrator.create_table(model)]
[docs]
class RemoveModel(MigrateOperation):
"""
Deletes the model from the :class:`State` and its table from the database.
"""
def __init__(self, model_name: str) -> None:
self.model_name = model_name
def state_forwards(self, state: State) -> None:
del state[self.model_name]
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Callable]:
model = from_state[self.model_name]
return [schema_migrator.drop_table(model)]
[docs]
class AddIndex(MigrateOperation):
"""
Creates an index in the database table for the model with model_name.
The index will be saved in **Model._meta.indexes_state** dict
"""
def __init__(
self,
model_name: str,
*fields: str,
name: str,
unique: bool = False,
where: pw.SQL | None = None,
safe: bool = False,
concurrently: bool = False,
) -> None:
self.model_name = model_name
self.fields = fields
self.unique = unique
self.where = where
self.name = name
self.safe = safe
self.concurrently = concurrently
self._index: ModelIndex | None = None
def build_index(self, model: ModelCls) -> ModelIndex:
if not self._index:
self._index = ModelIndex(
model=model,
fields=[resolve_field(model, f) for f in self.fields],
unique=self.unique,
where=self.where,
name=self.name,
safe=self.safe,
concurrently=self.concurrently,
)
return self._index
def state_forwards(self, state: State) -> None:
model = state[self.model_name]
model_index = self.build_index(model)
indexes_state(model)[model_index._name] = model_index
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
model = to_state[self.model_name]
model_index = self.build_index(model)
return [schema_migrator.add_model_index(model_index)]
[docs]
class DropIndex(MigrateOperation):
"""
Removes the index named name from the model with model_name.
"""
def __init__(self, model_name: str, name: str) -> None:
self.model_name = model_name
self.name = name
def state_forwards(self, state: State) -> None:
model = state[self.model_name]
del indexes_state(model)[self.name]
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
model = from_state[self.model_name]
return [schema_migrator.drop_index(model._meta.table_name, self.name)]
[docs]
class AddFields(MigrateOperation):
"""
Adds fields to a model.
"""
def __init__(self, model_name: str, **fields: pw.Field) -> None:
self.model_name = model_name
self.fields = fields
def state_forwards(self, state: State) -> None:
for name, field in self.fields.items():
state[self.model_name]._meta.add_field(name, field)
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
ops = []
model = to_state[self.model_name]
for field in self.fields.values():
ops.append(schema_migrator.add_column(model._meta.table_name, field.column_name, field))
return ops
[docs]
class ChangeFields(MigrateOperation):
"""
Change fields to a model.
"""
def __init__(self, model_name: str, **fields: pw.Field) -> None:
self.model_name = model_name
self.fields = fields
def state_forwards(self, state: State) -> None:
model = state[self.model_name]
for name, field in self.fields.items():
model._meta.add_field(name, field)
def handle_indexes(
self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
) -> list[Operation]:
_ops = []
_field = new_field
if _field.unique and old_field.unique:
return []
if not _field.unique and not old_field.unique and _field.index == old_field.index:
return []
table_name = old_field.model._meta.table_name
if has_single_index(old_field):
# We have already renamed the column so create name from the new field
_ops.append(schema_migrator.drop_index(table_name, get_single_index_name(_field)))
if model_index := get_single_index(_field):
_ops.append(schema_migrator.add_model_index(model_index))
return _ops
def handle_fk_constraint(
self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
) -> list[Operation]:
_ops: list[Operation] = []
is_old_field_fk = isinstance(old_field, pw.ForeignKeyField)
is_new_field_fk = isinstance(new_field, pw.ForeignKeyField)
if (
is_old_field_fk
and is_new_field_fk
and FieldComparer.fk_to_params(old_field) == FieldComparer.fk_to_params(new_field)
):
# Nothing's changed for fk
return _ops
table_name = old_field.model._meta.table_name
if is_old_field_fk:
_ops.append(schema_migrator.drop_foreign_key_constraint(table_name, new_field.column_name))
if is_new_field_fk:
_ops.append(
schema_migrator.add_foreign_key_constraint(
table_name,
new_field.column_name,
new_field.rel_model._meta.table_name,
new_field.rel_field.name,
new_field.on_delete,
new_field.on_update,
constraint_name=new_field.constraint_name,
)
)
return _ops
def handle_default_constraint(
self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
) -> list[Operation]:
old_value = get_default_constraint_value(old_field) or ""
new_value = get_default_constraint_value(new_field) or ""
table_name = old_field.model._meta.table_name
if old_value != new_value:
if new_value:
return [schema_migrator.add_column_default(table_name, new_field.column_name, new_value)]
else:
return [
schema_migrator.drop_column_default(
table_name,
new_field.column_name,
)
]
return []
def handle_type(
self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
) -> list[Operation]:
old_field_comparer = FieldComparer(old_field)
new_field_comparer = FieldComparer(new_field)
if (
old_field_comparer.field_type is not new_field_comparer.field_type
or old_field_comparer.get_type_params() != new_field_comparer.get_type_params()
):
table_name = old_field.model._meta.table_name
return [schema_migrator.alter_column_type(table_name, new_field.column_name, new_field)]
return []
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
_ops = []
model = from_state[self.model_name]
table_name = model._meta.table_name
for name, field in self.fields.items():
old_field = getattr(model, name)
old_column_name = old_field.column_name
if old_column_name != field.column_name:
_ops.append(schema_migrator.rename_field(table_name, old_field, field))
_ops.extend(self.handle_type(old_field, field, schema_migrator))
_ops.extend(self.handle_fk_constraint(old_field, field, schema_migrator))
_ops.extend(self.handle_default_constraint(old_field, field, schema_migrator))
if old_field.null != field.null:
_operation = schema_migrator.drop_not_null if field.null else schema_migrator.add_not_null
_ops.append(_operation(table_name, field.column_name))
_ops.extend(self.handle_indexes(old_field, field, schema_migrator))
return _ops
[docs]
class RemoveFields(MigrateOperation):
"""
Removes fields from a model
"""
def __init__(self, model_name: str, *names: str, cascade: bool = False) -> None:
self.model_name = model_name
self.cascade = cascade
self.names = names
def state_forwards(self, state: State) -> None:
model = state[self.model_name]
for name in self.names:
field = state[self.model_name]._meta.fields[name]
delete_field(model, field)
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
"""Remove fields from model."""
ops = []
model = from_state[self.model_name]
for name in self.names:
field = model._meta.fields[name]
ops.append(schema_migrator.drop_column(model._meta.table_name, field.column_name, cascade=self.cascade))
return ops
def fk_postfix(name: str) -> str:
return name if name.endswith("_id") else name + "_id"
[docs]
class RenameField(MigrateOperation):
"""
Changes a field’s name (and, unless **column_name** is set, its column name).
It also renames a single-column indexe, if it exists.
**Warning:**
This operation does not rename indexes created via the **Meta** class or the **add_index()** method.
You should explicitly specify index names if you plan to use this operation.
Otherwise, you will be prompted to recreate the index with a new name in the next migration.
"""
def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
self.model_name = model_name
self.old_field_name = old_name
self.new_field_name = new_name
def state_forwards(self, state: State) -> None:
model = state[self.model_name]
old_field = model._meta.fields[self.old_field_name]
new_field = old_field.clone()
delete_field(model, old_field)
new_field.column_name = self.resolve_new_name(old_field, self.new_field_name)
model._meta.add_field(self.new_field_name, new_field)
def resolve_new_name(self, old_field: pw.Field, new_name: str) -> str:
if isinstance(old_field, pw.ForeignKeyField):
if old_field.column_name == fk_postfix(old_field.name):
return fk_postfix(new_name)
if old_field.column_name == old_field.name:
return new_name
return old_field.column_name
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
old_model = from_state[self.model_name]
new_model = to_state[self.model_name]
old_field = old_model._meta.fields[self.old_field_name]
new_field = new_model._meta.fields[self.new_field_name]
if old_field.column_name != new_field.column_name:
return [schema_migrator.rename_field(new_model._meta.table_name, old_field, new_field)]
return []
class ChangeNullable(MigrateOperation):
def __init__(self, model_name: str, *names: str, is_null: bool) -> None:
self.model_name = model_name
self.names = names
self.is_null = is_null
def state_forwards(self, state: State) -> None:
model = state[self.model_name]
for name in self.names:
field = model._meta.fields[name]
field.null = self.is_null
def database_forwards(
self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
) -> list[Operation]:
ops = []
model = to_state[self.model_name]
for name in self.names:
field = model._meta.fields[name]
_operation = schema_migrator.drop_not_null if self.is_null else schema_migrator.add_not_null
ops.append(_operation(model._meta.table_name, field.column_name))
return ops
class Migration:
def __init__(self, state: State, schema_migrator: "SchemaMigrator", schema: str | None = None) -> None:
self.state = state
self.schema_migrator = schema_migrator
self.schema = schema
self.operations: list[Operation | Callable] = []
def append(self, op: MigrateOperation) -> None:
self.state.create_snapshot()
op.state_forwards(self.state)
from_state = self.state.pop_snapshot()
self.operations.extend(op.database_forwards(self.schema_migrator, from_state, self.state))
def apply(self, change_schema: bool) -> None:
if not change_schema:
return
if self.schema:
_ops = [self.schema_migrator.select_schema(self.schema), *self.operations]
else:
_ops = [*self.operations]
for op in _ops:
if isinstance(op, Operation):
LOGGER.info("%s %s", op.method, op.args)
op.run()
else:
op()
def clean(self) -> None:
self.operations = []
[docs]
class SchemaMigrator(ScM):
"""Extended **playhouse.migrate.SchemaMigrator** from **peewee**"""
@classmethod
def from_database(cls, database):
"""Initialize migrator by db."""
if isinstance(database, PostgresqlDatabase):
return PostgresqlMigrator(database)
if isinstance(database, SqliteDatabase):
return SqliteMigrator(database)
if isinstance(database, MySQLDatabase):
return MySQLMigrator(database)
return super(SchemaMigrator, cls).from_database(database)
@operation
def select_schema(self, schema):
"""Select database schema"""
raise NotImplementedError()
[docs]
@operation
def sql(self, sql, params: tuple[Any, ...] | None = None):
"""Execute raw SQL."""
return SQL(sql, params)
@operation
def add_column(self, table, column_name, field):
# Adding a column is complicated by the fact that if there are rows
# present and the field is non-null, then we need to first add the
# column as a nullable field, then set the value, then add a not null
# constraint.
default_constraint = get_default_constraint(field)
if not field.null and field.default is None and not default_constraint:
raise ValueError("%s is not null but has no default" % column_name)
is_foreign_key = isinstance(field, pw.ForeignKeyField)
if is_foreign_key and not field.rel_field:
raise ValueError("Foreign keys must specify a `field`.")
operations = [self.alter_add_column(table, column_name, field)]
# In the event the field is *not* nullable and has no default constraint, update with the default
# value and set not null.
if not field.null:
if not default_constraint:
operations.append(
self.apply_default(table, column_name, field),
)
operations.append(self.add_not_null(table, column_name))
if is_foreign_key and self.explicit_create_foreign_key:
operations.append(
self.add_foreign_key_constraint(
table,
column_name,
field.rel_model._meta.table_name,
field.rel_field.column_name,
field.on_delete,
field.on_update,
)
)
if model_index := get_single_index(field):
operations.append(self.add_model_index(model_index))
return operations
@operation
def add_model_index(self, model_index: ModelIndex):
ctx = self.make_context()
return ctx.sql(model_index)
[docs]
@operation
def rename_index(self, old_name: str, new_name: str):
"""Change index name"""
ctx = self.make_context()
return ctx.literal("ALTER INDEX ").sql(pw.Entity(old_name)).literal(" RENAME TO ").sql(pw.Entity(new_name))
@operation
def resolve_single_index_name(self, old_field: pw.Field, new_field: pw.Field):
operations = []
if old_model_index := get_single_index(old_field):
new_single_index = make_single_index(new_field)
operations.append(self.rename_index(old_model_index._name, new_single_index._name))
return operations
@operation
def rename_field(self, table: str, old_field: pw.Field, new_field: pw.Field):
operations = [self.rename_column(table, old_field.column_name, new_field.column_name)]
operations.append(self.resolve_single_index_name(old_field, new_field))
return operations
[docs]
def create_table(self, model: ModelCls, safe: bool = False) -> Callable:
"""
Create table from model class
"""
model._meta.database = self.database
model._meta.legacy_table_names = False
return lambda: model.create_table(safe=safe)
[docs]
def drop_table(self, model: ModelCls, safe: bool = False) -> Callable:
"""
Drop model table
"""
model._meta.database = self.database
return lambda: model.drop_table(safe=safe)
class MySQLMigrator(SchemaMigrator, MqM):
def alter_change_column(self, table, column, field):
"""Support change columns."""
ctx = self.make_context()
field_null, field.null = field.null, True
ctx = self._alter_table(ctx, table).literal(" MODIFY COLUMN ").sql(field.ddl(ctx))
field.null = field_null
return ctx
class PostgresqlMigrator(SchemaMigrator, PgM):
"""Support the migrations in postgresql."""
@operation
def select_schema(self, schema):
"""Select database schema"""
return self.set_search_path(schema)
def get_foreign_key_constraint(self, table: str, column_name: str) -> str:
sql = """
SELECT DISTINCT
kcu.constraint_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON (tc.constraint_name = kcu.constraint_name AND
tc.constraint_schema = kcu.constraint_schema AND
tc.table_name = kcu.table_name AND
tc.table_schema = kcu.table_schema)
JOIN information_schema.constraint_column_usage AS ccu
ON (ccu.constraint_name = tc.constraint_name AND
ccu.constraint_schema = tc.constraint_schema)
WHERE
tc.constraint_type = 'FOREIGN KEY' AND
tc.table_name = %s AND
tc.table_schema = current_schema() AND
kcu.column_name = %s"""
cursor = self.database.execute_sql(sql, (table, column_name))
return cursor.fetchall()[0][0]
@operation
def drop_foreign_key_constraint(self, table: str, column_name: str):
fk_constraint = self.get_foreign_key_constraint(table, column_name)
return self.drop_constraint(table, fk_constraint)
class SqliteMigrator(SchemaMigrator, SqM):
"""Support the migrations in sqlite."""
def drop_table(self, model, cascade=True):
"""SQLite doesnt support cascade syntax by default."""
return lambda: model.drop_table(cascade=False)
def alter_column_type(self, table, column, field):
"""Support change columns."""
return self._update_column(table, column, lambda a, b: b)
def drop_column(self, table, column_name, cascade=True, legacy=True, **kwargs):
"""drop_column will not work for FK so we should use the legacy version"""
return super(SqliteMigrator, self).drop_column(table, column_name, cascade, legacy, **kwargs)
[docs]
class Migrator(object):
"""
A class that provides shortcuts for adding migration operations.
"""
def __init__(self, database, schema=None):
"""Initialize the migrator."""
if isinstance(database, pw.Proxy):
database = database.obj
self.database = database
self.state = State()
self.schema_migrator = SchemaMigrator.from_database(self.database)
self.schema = schema
self.migration = Migration(self.state, self.schema_migrator, schema=schema)
[docs]
def add_operation(self, op: MigrateOperation) -> None:
"""
Adds a migrate operation
"""
self.migration.append(op)
def run(self, change_schema: bool = True):
self.migration.apply(change_schema)
self.clean()
[docs]
def python(self, func: RunPythonF):
"""A shortcut for adding a :class:`RunPython` operation."""
self.add_operation(RunPython(func))
[docs]
def sql(self, sql: str, params: tuple[Any, ...] | None = None) -> None:
"""A shortcut for adding a :class:`RunSql` operation."""
self.add_operation(RunSql(sql, params))
def clean(self):
"""Clean the operations."""
self.migration.clean()
[docs]
def create_model(self, model: ModelCls) -> ModelCls:
"""A shortcut for adding a :class:`CreateModel` operation."""
self.add_operation(CreateModel(model))
return model
create_table = create_model
[docs]
def remove_model(self, model_name: str) -> None:
"""A shortcut for adding a :class:`RemoveModel` operation."""
self.add_operation(RemoveModel(model_name))
drop_table = remove_model
[docs]
def add_fields(self, model_name: str, **fields: Any) -> None:
"""A shortcut for adding a :class:`AddFields` operation."""
self.add_operation(AddFields(model_name, **fields))
add_columns = add_fields
[docs]
def change_fields(self, model_name: str, **fields: pw.Field) -> None:
"""A shortcut for adding a :class:`ChangeFields` operation."""
return self.add_operation(ChangeFields(model_name, **fields))
change_columns = change_fields
[docs]
def remove_fields(self, model_name: str, *names: str, cascade: bool = False) -> None:
"""A shortcut for adding a :class:`RemoveFields` operation."""
self.add_operation(RemoveFields(model_name, *names, cascade=cascade))
drop_columns = remove_fields
[docs]
def rename_field(self, model_name: str, old_name: str, new_name: str) -> None:
"""A shortcut for adding a :class:`RenameField` operation."""
self.add_operation(RenameField(model_name, old_name, new_name))
rename_column = rename_field
[docs]
def rename_table(self, model_name: str, new_table_name: str) -> None:
"""A shortcut for adding a :class:`RenameTable` operation."""
self.add_operation(RenameTable(model_name, new_table_name))
rename_model = rename_table
[docs]
def add_index(
self,
model_name: str,
*fields: str,
name: str,
unique: bool = False,
where: pw.SQL | None = None,
safe: bool = False,
concurrently: bool = False,
) -> None:
"""A shortcut for adding a :class:`AddIndex` operation."""
self.add_operation(
AddIndex(model_name, *fields, name=name, unique=unique, where=where, safe=safe, concurrently=concurrently)
)
[docs]
def drop_index(self, model_name: str, name: str) -> None:
"""A shortcut for adding a :class:`DropIndex` operation."""
self.add_operation(DropIndex(model_name, name))
def add_not_null(self, model_name: str, *names: str) -> None:
self.add_operation(ChangeNullable(model_name, *names, is_null=False))
def drop_not_null(self, model_name: str, *names: str) -> None:
"""Drop not null."""
self.add_operation(ChangeNullable(model_name, *names, is_null=True))