import re
from collections import namedtuple
from collections.abc import Callable
from enum import Enum, auto
from typing import TYPE_CHECKING, Any
import peewee as pw
from playhouse.migrate import Operation
from miggy.deconstructor import ForeignKeyFieldDeconstructor, deconstructor_factory
from miggy.state import State
from miggy.types import ModelCls
from miggy.utils import (
ModelIndex,
fk_postfix,
get_default_constraint_value,
get_single_index,
get_single_index_name,
has_single_index,
indexes_state,
resolve_field,
)
if TYPE_CHECKING:
from miggy.schema import SchemaMigrator
# Type of the callable accepted by RunPython: it is called with the schema
# migrator and a State, and is declared to return None.
RunPythonF = Callable[["SchemaMigrator", "State"], None]
class Dependency(namedtuple("Dependency", "model_name field_name type")):
    """
    Immutable record describing a dependency declared by a migrate operation.

    Fields:
        model_name: name of the model the dependency refers to.
        field_name: name of the field involved (NOTE(review): may be unused
            for some dependency kinds — confirm against callers).
        type: presumably a :class:`Dependency.Type` member classifying the
            dependency — confirm against callers.
    """

    # Without this, every instance of a namedtuple *subclass* gets its own
    # ``__dict__``; an empty ``__slots__`` keeps it as lean as a plain tuple.
    __slots__ = ()

    class Type(Enum):
        """Kinds of dependency an operation can declare."""

        REMOVE_PK = auto()
        CREATE = auto()
class MigrateOperation:
    """
    Base class for a migrate operation.

    A subclass describes one schema change. It implements
    :meth:`state_forwards` to apply the change to an in-memory :class:`State`
    and :meth:`database_forwards` to produce the database operations that
    apply the same change to the real schema.
    """

    def __new__(cls, *args, **kwargs):
        # Capture the constructor arguments before __init__ runs so that
        # deconstruct() can reproduce the call without every subclass having
        # to store them itself.
        instance = object.__new__(cls)
        instance._constructor_args = (args, kwargs)
        return instance

    @property
    def deps(self) -> list[Dependency]:
        """
        Dependencies of this operation; an empty list is created on first access.

        Creating it lazily here avoids initializing it in the constructor and
        calling ``super().__init__()`` in every subclass.
        """
        try:
            return self._deps
        except AttributeError:
            self._deps: list[Dependency] = []
            return self._deps

    def get_operation_call(self) -> str:
        """
        Return the migrator call for this operation as a string.

        The class name is converted from CamelCase to snake_case, so
        ``AddField`` becomes ``"migrator.add_field"``.
        """
        class_name = type(self).__name__
        snake_case = re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()
        return "migrator." + snake_case

    def deconstruct(self):
        """
        Return a 3-tuple ``(call, args, kwargs)`` that recreates this operation.

        ``call`` is the string returned by :meth:`get_operation_call`;
        ``args`` and ``kwargs`` are the positional and keyword arguments that
        were passed to the constructor (the stored objects, not copies).
        """
        args, kwargs = self._constructor_args
        return (self.get_operation_call(), args, kwargs)

    def state_forwards(self, state: State) -> None:
        """
        Apply this operation to ``state`` (the state produced by the previous
        operations), mutating it in place. Subclasses must override this.
        """
        raise NotImplementedError

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation] | list[Callable]:
        """
        Return the operations/callables that perform this change on the
        database schema in the normal (forwards) direction.

        ``from_state`` is the state before this operation and ``to_state`` the
        state after it. Implementations MUST NOT mutate either state.
        Subclasses must override this.
        """
        raise NotImplementedError
class RunPython(MigrateOperation):
    """
    Run custom Python code as part of a migration.

    **func** must be a callable accepting two arguments: an instance of
    :class:`SchemaMigrator` and an instance of :class:`State`. This operation
    leaves the state unchanged, and **func** receives the state from before
    the operation.

    Example::

        def save_user(schema_migrator: SchemaMigrator, current_state: State):
            User = current_state["user"]
            User(
                first_name="First",
                last_name="Last",
            ).save()

        migrator.add_operation(RunPython(save_user))
    """

    def __init__(self, func: RunPythonF) -> None:
        self.func = func

    def state_forwards(self, state: State) -> None:
        # Custom code does not change the schema state.
        return None

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation] | list[Callable]:
        # Defer the call: func only runs when the returned callable is invoked,
        # and ``self.func`` is looked up at that moment.
        def run_func():
            return self.func(schema_migrator, from_state)

        return [run_func]
class RunSql(MigrateOperation):
    """
    Run arbitrary SQL against the database.

    Useful for database-backend features that Miggy does not support directly.
    **params**, when given, are passed to the schema migrator together with
    **sql**.

    Example::

        migrator.add_operation(
            RunSql(
                'INSERT INTO "user" ("first_name", "last_name") VALUES (%s, %s)',
                (
                    "First",
                    "Last",
                ),
            )
        )
    """

    def __init__(self, sql: str, params: tuple[Any, ...] | None = None) -> None:
        self.sql = sql
        self.params = params

    def state_forwards(self, state: State) -> None:
        # Raw SQL is opaque to the state, which is left untouched.
        return None

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation] | list[Callable]:
        sql_operation = schema_migrator.sql(self.sql, self.params)
        return [sql_operation]
class CreateModel(MigrateOperation):
    """
    Create a new model in the :class:`State` and a matching table in the database.

    **name**, **fields** and **meta** are passed unchanged to ``State.add_model``.
    """

    def __init__(self, name: str, fields: dict[str, pw.Field], meta: dict[str, Any]) -> None:
        self.name = name
        self.fields = fields
        self.meta = meta

    def state_forwards(self, state: State) -> None:
        state.add_model(self.name, self.fields, self.meta)

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Callable]:
        # The table is built from the model as it exists after this operation.
        new_model = to_state[self.name]
        create_op = schema_migrator.create_table(new_model)
        return [create_op]
class RemoveModel(MigrateOperation):
    """
    Remove the model from the :class:`State` and drop its table from the database.
    """

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def state_forwards(self, state: State) -> None:
        state.remove_model(self.model_name)

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Callable]:
        # The model is gone from to_state, so the table comes from from_state.
        doomed_model = from_state[self.model_name]
        return [schema_migrator.drop_table(doomed_model)]
class AddIndex(MigrateOperation):
    """
    Create an index on the table of the model named **model_name**.

    The index is also recorded in the model's index state (via
    ``indexes_state``), keyed by the index name.
    """

    def __init__(
        self,
        model_name: str,
        *fields: str,
        name: str,
        unique: bool = False,
        where: pw.SQL | None = None,
        safe: bool = False,
        concurrently: bool = False,
    ) -> None:
        self.model_name = model_name
        self.fields = fields
        self.name = name
        self.unique = unique
        self.where = where
        self.safe = safe
        self.concurrently = concurrently
        # Built lazily by build_index() and then reused.
        self._index: ModelIndex | None = None

    def build_index(self, model: ModelCls) -> ModelIndex:
        """
        Return the ModelIndex for this operation, building it on first use.

        NOTE(review): the result is cached on the operation, so later calls
        return the same object even when a different model class is passed.
        """
        if self._index:
            return self._index
        resolved_fields = [resolve_field(model, field_name) for field_name in self.fields]
        self._index = ModelIndex(
            model=model,
            fields=resolved_fields,
            unique=self.unique,
            where=self.where,
            name=self.name,
            safe=self.safe,
            concurrently=self.concurrently,
        )
        return self._index

    def state_forwards(self, state: State) -> None:
        target_model = state[self.model_name]
        index = self.build_index(target_model)
        indexes_state(target_model)[index._name] = index

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        index = self.build_index(to_state[self.model_name])
        return [schema_migrator.add_model_index(index)]
class DropIndex(MigrateOperation):
    """
    Remove the index named **name** from the model named **model_name**,
    both from the model's index state and from the database.
    """

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def state_forwards(self, state: State) -> None:
        recorded_indexes = indexes_state(state[self.model_name])
        del recorded_indexes[self.name]

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        table_name = from_state[self.model_name]._meta.table_name
        return [schema_migrator.drop_index(table_name, self.name)]
class AddField(MigrateOperation):
    """
    Add the field **field** under **name** to a model, in the :class:`State`
    and as a new column in the database.
    """

    def __init__(self, model_name: str, name: str, field: pw.Field) -> None:
        self.model_name = model_name
        self.name = name
        self.field = field

    def state_forwards(self, state: State) -> None:
        state.add_field(self.model_name, self.name, self.field)

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        # Use the field as bound to the model in to_state, not self.field.
        bound_field = to_state[self.model_name]._meta.fields[self.name]
        return [schema_migrator.add_field(bound_field)]
class AlterField(MigrateOperation):
    """
    Alter a field of a model.

    In the :class:`State` the field named **name** is replaced by **field**.
    In the database the column is renamed if its column name changed, and its
    primary key, type, foreign-key constraint, default, nullability and
    single-column index are updated to match the new definition.
    """

    def __init__(self, model_name: str, name: str, field: pw.Field) -> None:
        self.model_name = model_name
        self.name = name
        self.field = field

    def state_forwards(self, state: State) -> None:
        # NOTE(review): relies on State.add_field overwriting an existing field
        # of the same name.
        state.add_field(self.model_name, self.name, self.field)

    def handle_indexes(
        self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
    ) -> list[Operation]:
        """
        Return operations that recreate the single-column index when the
        field's ``unique``/``index`` settings changed.
        """
        # Unique on both sides: the unique index stays as it is.
        if new_field.unique and old_field.unique:
            return []
        # Neither side unique and the plain-index flag unchanged.
        if not new_field.unique and not old_field.unique and new_field.index == old_field.index:
            return []

        ops: list[Operation] = []
        table_name = old_field.model._meta.table_name
        if has_single_index(old_field):
            # The column has already been renamed, so derive the name from the new field.
            ops.append(schema_migrator.drop_index(table_name, get_single_index_name(new_field)))
        if model_index := get_single_index(new_field):
            ops.append(schema_migrator.add_model_index(model_index))
        return ops

    def handle_fk_constraint(
        self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
    ) -> list[Operation]:
        """
        Return operations that drop and/or add the foreign-key constraint.

        Nothing is emitted when both fields are foreign keys with identical FK
        parameters (as reported by ForeignKeyFieldDeconstructor). Otherwise an
        existing FK constraint is dropped and, if the new field is a foreign
        key, a new constraint is added.
        """
        ops: list[Operation] = []
        is_old_field_fk = isinstance(old_field, pw.ForeignKeyField)
        is_new_field_fk = isinstance(new_field, pw.ForeignKeyField)
        if (
            is_old_field_fk
            and is_new_field_fk
            and (
                ForeignKeyFieldDeconstructor(old_field).deconstruct_fk_params()
                == ForeignKeyFieldDeconstructor(new_field).deconstruct_fk_params()
            )
        ):
            # Nothing has changed for the foreign key.
            return ops

        table_name = old_field.model._meta.table_name
        if is_old_field_fk:
            # Use new_field.column_name: the column may already have been renamed.
            ops.append(schema_migrator.drop_foreign_key_constraint(table_name, new_field.column_name))
        if is_new_field_fk:
            ops.append(
                schema_migrator.add_foreign_key_constraint(
                    table_name,
                    new_field.column_name,
                    new_field.rel_model._meta.table_name,
                    new_field.rel_field.column_name,
                    new_field.on_delete,
                    new_field.on_update,
                    constraint_name=new_field.constraint_name,
                )
            )
        return ops

    def handle_default_constraint(
        self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
    ) -> list[Operation]:
        """
        Return an operation that sets or drops the column DEFAULT when it changed.

        A falsy value from ``get_default_constraint_value`` is treated as
        "no default".
        """
        old_value = get_default_constraint_value(old_field) or ""
        new_value = get_default_constraint_value(new_field) or ""
        if old_value == new_value:
            return []
        table_name = old_field.model._meta.table_name
        if new_value:
            return [schema_migrator.add_column_default(table_name, new_field.column_name, new_value)]
        return [schema_migrator.drop_column_default(table_name, new_field.column_name)]

    def handle_type(
        self, old_field: pw.Field, new_field: pw.Field, schema_migrator: "SchemaMigrator"
    ) -> list[Operation]:
        """
        Return an ALTER COLUMN TYPE operation when the column type or its
        modifiers differ between the two fields.
        """
        old_field_deconstructor = deconstructor_factory(old_field)
        new_field_deconstructor = deconstructor_factory(new_field)
        # Compare by equality, not identity: equal type values (e.g. type-name
        # strings) need not be the same object, and ``is not`` would then report
        # a spurious type change.
        if (
            old_field_deconstructor.field_type != new_field_deconstructor.field_type
            or old_field_deconstructor.get_type_modifiers() != new_field_deconstructor.get_type_modifiers()
        ):
            table_name = old_field.model._meta.table_name
            return [schema_migrator.alter_column_type(table_name, new_field.column_name, new_field)]
        return []

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        """
        Return the operations that turn the old column into the new one.

        The rename comes first because the later steps address the column by
        its new name.
        """
        old_model = from_state[self.model_name]
        # Look the old field up the same way as the new one (and as the other
        # operations do) rather than via getattr on the model class.
        old_field = old_model._meta.fields[self.name]
        table_name = old_model._meta.table_name
        field = to_state[self.model_name]._meta.fields[self.name]

        ops: list[Operation] = []
        if old_field.column_name != field.column_name:
            ops.append(schema_migrator.rename_field(table_name, old_field, field))
        ops.append(schema_migrator._change_primary_key(old_field, field))
        ops.extend(self.handle_type(old_field, field, schema_migrator))
        ops.extend(self.handle_fk_constraint(old_field, field, schema_migrator))
        ops.extend(self.handle_default_constraint(old_field, field, schema_migrator))
        if old_field.null != field.null:
            toggle_not_null = schema_migrator.drop_not_null if field.null else schema_migrator.add_not_null
            ops.append(toggle_not_null(table_name, field.column_name))
        ops.extend(self.handle_indexes(old_field, field, schema_migrator))
        return ops
class RemoveField(MigrateOperation):
    """
    Remove a field from a model, in the :class:`State` and as a column in the
    database (dropped with ``cascade=False``).
    """

    def __init__(self, model_name: str, name: str) -> None:
        self.model_name = model_name
        self.name = name

    def state_forwards(self, state: State) -> None:
        state.remove_field(self.model_name, self.name)

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        # The field no longer exists in to_state, so read it from from_state.
        source_model = from_state[self.model_name]
        column_name = source_model._meta.fields[self.name].column_name
        drop_op = schema_migrator.drop_column(source_model._meta.table_name, column_name, cascade=False)
        return [drop_op]
class RenameField(MigrateOperation):
    """
    Change a field's name and, when the column uses the default naming, its
    column name as well.

    A column is treated as default-named when it equals the field name, or,
    for a foreign key, the field name passed through ``fk_postfix``. A custom
    column name is left unchanged. The operation also renames a single-column
    index, if it exists.

    **Warning:**
    This operation does not rename indexes created via the **Meta** class or
    the **add_index()** method. You should explicitly specify index names if
    you plan to use this operation. Otherwise, you will be prompted to
    recreate the index with a new name in the next migration.
    """

    def __init__(self, model_name: str, old_name: str, new_name: str) -> None:
        self.model_name = model_name
        self.old_field_name = old_name
        self.new_field_name = new_name

    def state_forwards(self, state: State) -> None:
        old_field = state[self.model_name]._meta.fields[self.old_field_name]
        renamed_field = old_field.clone()
        renamed_field.column_name = self.resolve_new_name(old_field, self.new_field_name)
        state.remove_field(self.model_name, self.old_field_name)
        state.add_field(self.model_name, self.new_field_name, renamed_field)

    def resolve_new_name(self, old_field: pw.Field, new_name: str) -> str:
        """
        Return the column name the field should have after being renamed to
        ``new_name``. Default-derived column names follow the rename; custom
        ones are kept.
        """
        current_column = old_field.column_name
        if isinstance(old_field, pw.ForeignKeyField) and current_column == fk_postfix(old_field.name):
            return fk_postfix(new_name)
        if current_column == old_field.name:
            return new_name
        return current_column

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        new_model = to_state[self.model_name]
        old_field = from_state[self.model_name]._meta.fields[self.old_field_name]
        new_field = new_model._meta.fields[self.new_field_name]
        # A custom column name means only the Python-side name changed.
        if old_field.column_name == new_field.column_name:
            return []
        return [schema_migrator.rename_field(new_model._meta.table_name, old_field, new_field)]
class ChangeNullable(MigrateOperation):
    """
    Set ``null`` to **is_null** on one or more fields of a model.

    The fields named in **names** are updated in the :class:`State` (their
    ``null`` attribute is set in place). In the database, the NOT NULL
    constraint is dropped when **is_null** is true and added otherwise.
    """

    def __init__(self, model_name: str, *names: str, is_null: bool) -> None:
        self.model_name = model_name
        self.names = names
        self.is_null = is_null

    def state_forwards(self, state: State) -> None:
        field_map = state[self.model_name]._meta.fields
        for field_name in self.names:
            field_map[field_name].null = self.is_null

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        model = to_state[self.model_name]
        toggle = schema_migrator.drop_not_null if self.is_null else schema_migrator.add_not_null
        return [
            toggle(model._meta.table_name, model._meta.fields[field_name].column_name)
            for field_name in self.names
        ]
class AddPrimaryKeyConstraint(MigrateOperation):
    """
    Make the given **fields** the model's primary key.

    The state receives a ``pw.CompositeKey`` over the fields, and a
    primary-key constraint over their columns is added in the database.
    """

    def __init__(self, model_name: str, *fields: str) -> None:
        self.model_name = model_name
        self.fields = fields

    def state_forwards(self, state: State) -> None:
        composite_key = pw.CompositeKey(*self.fields)
        state.add_composite_key(self.model_name, composite_key)

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        model = to_state[self.model_name]
        field_map = model._meta.fields
        column_names = [field_map[field_name].column_name for field_name in self.fields]
        return [schema_migrator.add_primary_key_constraint(model._meta.table_name, *column_names)]
class RemovePrimaryKeyConstraint(MigrateOperation):
    """
    Remove the model's composite primary key from the :class:`State` and drop
    the table's primary-key constraint in the database.
    """

    def __init__(self, model_name: str) -> None:
        self.model_name = model_name

    def state_forwards(self, state: State) -> None:
        state.remove_composite_key(self.model_name)

    def database_forwards(
        self, schema_migrator: "SchemaMigrator", from_state: State, to_state: State
    ) -> list[Operation]:
        table_name = from_state[self.model_name]._meta.table_name
        return [schema_migrator.drop_primary_key_constraint(table_name)]