fields.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. from django.db.migrations.utils import field_references
  2. from django.db.models import NOT_PROVIDED
  3. from django.utils.functional import cached_property
  4. from .base import Operation
  5. class FieldOperation(Operation):
  6. def __init__(self, model_name, name, field=None):
  7. self.model_name = model_name
  8. self.name = name
  9. self.field = field
  10. @cached_property
  11. def model_name_lower(self):
  12. return self.model_name.lower()
  13. @cached_property
  14. def name_lower(self):
  15. return self.name.lower()
  16. def is_same_model_operation(self, operation):
  17. return self.model_name_lower == operation.model_name_lower
  18. def is_same_field_operation(self, operation):
  19. return (
  20. self.is_same_model_operation(operation)
  21. and self.name_lower == operation.name_lower
  22. )
  23. def references_model(self, name, app_label):
  24. name_lower = name.lower()
  25. if name_lower == self.model_name_lower:
  26. return True
  27. if self.field:
  28. return bool(
  29. field_references(
  30. (app_label, self.model_name_lower),
  31. self.field,
  32. (app_label, name_lower),
  33. )
  34. )
  35. return False
  36. def references_field(self, model_name, name, app_label):
  37. model_name_lower = model_name.lower()
  38. # Check if this operation locally references the field.
  39. if model_name_lower == self.model_name_lower:
  40. if name == self.name:
  41. return True
  42. elif (
  43. self.field
  44. and hasattr(self.field, "from_fields")
  45. and name in self.field.from_fields
  46. ):
  47. return True
  48. # Check if this operation remotely references the field.
  49. if self.field is None:
  50. return False
  51. return bool(
  52. field_references(
  53. (app_label, self.model_name_lower),
  54. self.field,
  55. (app_label, model_name_lower),
  56. name,
  57. )
  58. )
  59. def reduce(self, operation, app_label):
  60. return super().reduce(operation, app_label) or not operation.references_field(
  61. self.model_name, self.name, app_label
  62. )
  63. class AddField(FieldOperation):
  64. """Add a field to a model."""
  65. def __init__(self, model_name, name, field, preserve_default=True):
  66. self.preserve_default = preserve_default
  67. super().__init__(model_name, name, field)
  68. def deconstruct(self):
  69. kwargs = {
  70. "model_name": self.model_name,
  71. "name": self.name,
  72. "field": self.field,
  73. }
  74. if self.preserve_default is not True:
  75. kwargs["preserve_default"] = self.preserve_default
  76. return (self.__class__.__name__, [], kwargs)
  77. def state_forwards(self, app_label, state):
  78. state.add_field(
  79. app_label,
  80. self.model_name_lower,
  81. self.name,
  82. self.field,
  83. self.preserve_default,
  84. )
  85. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  86. to_model = to_state.apps.get_model(app_label, self.model_name)
  87. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  88. from_model = from_state.apps.get_model(app_label, self.model_name)
  89. field = to_model._meta.get_field(self.name)
  90. if not self.preserve_default:
  91. field.default = self.field.default
  92. schema_editor.add_field(
  93. from_model,
  94. field,
  95. )
  96. if not self.preserve_default:
  97. field.default = NOT_PROVIDED
  98. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  99. from_model = from_state.apps.get_model(app_label, self.model_name)
  100. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  101. schema_editor.remove_field(
  102. from_model, from_model._meta.get_field(self.name)
  103. )
  104. def describe(self):
  105. return "Add field %s to %s" % (self.name, self.model_name)
  106. @property
  107. def migration_name_fragment(self):
  108. return "%s_%s" % (self.model_name_lower, self.name_lower)
  109. def reduce(self, operation, app_label):
  110. if isinstance(operation, FieldOperation) and self.is_same_field_operation(
  111. operation
  112. ):
  113. if isinstance(operation, AlterField):
  114. return [
  115. AddField(
  116. model_name=self.model_name,
  117. name=operation.name,
  118. field=operation.field,
  119. ),
  120. ]
  121. elif isinstance(operation, RemoveField):
  122. return []
  123. elif isinstance(operation, RenameField):
  124. return [
  125. AddField(
  126. model_name=self.model_name,
  127. name=operation.new_name,
  128. field=self.field,
  129. ),
  130. ]
  131. return super().reduce(operation, app_label)
  132. class RemoveField(FieldOperation):
  133. """Remove a field from a model."""
  134. def deconstruct(self):
  135. kwargs = {
  136. "model_name": self.model_name,
  137. "name": self.name,
  138. }
  139. return (self.__class__.__name__, [], kwargs)
  140. def state_forwards(self, app_label, state):
  141. state.remove_field(app_label, self.model_name_lower, self.name)
  142. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  143. from_model = from_state.apps.get_model(app_label, self.model_name)
  144. if self.allow_migrate_model(schema_editor.connection.alias, from_model):
  145. schema_editor.remove_field(
  146. from_model, from_model._meta.get_field(self.name)
  147. )
  148. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  149. to_model = to_state.apps.get_model(app_label, self.model_name)
  150. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  151. from_model = from_state.apps.get_model(app_label, self.model_name)
  152. schema_editor.add_field(from_model, to_model._meta.get_field(self.name))
  153. def describe(self):
  154. return "Remove field %s from %s" % (self.name, self.model_name)
  155. @property
  156. def migration_name_fragment(self):
  157. return "remove_%s_%s" % (self.model_name_lower, self.name_lower)
  158. def reduce(self, operation, app_label):
  159. from .models import DeleteModel
  160. if (
  161. isinstance(operation, DeleteModel)
  162. and operation.name_lower == self.model_name_lower
  163. ):
  164. return [operation]
  165. return super().reduce(operation, app_label)
  166. class AlterField(FieldOperation):
  167. """
  168. Alter a field's database column (e.g. null, max_length) to the provided
  169. new field.
  170. """
  171. def __init__(self, model_name, name, field, preserve_default=True):
  172. self.preserve_default = preserve_default
  173. super().__init__(model_name, name, field)
  174. def deconstruct(self):
  175. kwargs = {
  176. "model_name": self.model_name,
  177. "name": self.name,
  178. "field": self.field,
  179. }
  180. if self.preserve_default is not True:
  181. kwargs["preserve_default"] = self.preserve_default
  182. return (self.__class__.__name__, [], kwargs)
  183. def state_forwards(self, app_label, state):
  184. state.alter_field(
  185. app_label,
  186. self.model_name_lower,
  187. self.name,
  188. self.field,
  189. self.preserve_default,
  190. )
  191. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  192. to_model = to_state.apps.get_model(app_label, self.model_name)
  193. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  194. from_model = from_state.apps.get_model(app_label, self.model_name)
  195. from_field = from_model._meta.get_field(self.name)
  196. to_field = to_model._meta.get_field(self.name)
  197. if not self.preserve_default:
  198. to_field.default = self.field.default
  199. schema_editor.alter_field(from_model, from_field, to_field)
  200. if not self.preserve_default:
  201. to_field.default = NOT_PROVIDED
  202. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  203. self.database_forwards(app_label, schema_editor, from_state, to_state)
  204. def describe(self):
  205. return "Alter field %s on %s" % (self.name, self.model_name)
  206. @property
  207. def migration_name_fragment(self):
  208. return "alter_%s_%s" % (self.model_name_lower, self.name_lower)
  209. def reduce(self, operation, app_label):
  210. if isinstance(
  211. operation, (AlterField, RemoveField)
  212. ) and self.is_same_field_operation(operation):
  213. return [operation]
  214. elif (
  215. isinstance(operation, RenameField)
  216. and self.is_same_field_operation(operation)
  217. and self.field.db_column is None
  218. ):
  219. return [
  220. operation,
  221. AlterField(
  222. model_name=self.model_name,
  223. name=operation.new_name,
  224. field=self.field,
  225. ),
  226. ]
  227. return super().reduce(operation, app_label)
  228. class RenameField(FieldOperation):
  229. """Rename a field on the model. Might affect db_column too."""
  230. def __init__(self, model_name, old_name, new_name):
  231. self.old_name = old_name
  232. self.new_name = new_name
  233. super().__init__(model_name, old_name)
  234. @cached_property
  235. def old_name_lower(self):
  236. return self.old_name.lower()
  237. @cached_property
  238. def new_name_lower(self):
  239. return self.new_name.lower()
  240. def deconstruct(self):
  241. kwargs = {
  242. "model_name": self.model_name,
  243. "old_name": self.old_name,
  244. "new_name": self.new_name,
  245. }
  246. return (self.__class__.__name__, [], kwargs)
  247. def state_forwards(self, app_label, state):
  248. state.rename_field(
  249. app_label, self.model_name_lower, self.old_name, self.new_name
  250. )
  251. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  252. to_model = to_state.apps.get_model(app_label, self.model_name)
  253. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  254. from_model = from_state.apps.get_model(app_label, self.model_name)
  255. schema_editor.alter_field(
  256. from_model,
  257. from_model._meta.get_field(self.old_name),
  258. to_model._meta.get_field(self.new_name),
  259. )
  260. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  261. to_model = to_state.apps.get_model(app_label, self.model_name)
  262. if self.allow_migrate_model(schema_editor.connection.alias, to_model):
  263. from_model = from_state.apps.get_model(app_label, self.model_name)
  264. schema_editor.alter_field(
  265. from_model,
  266. from_model._meta.get_field(self.new_name),
  267. to_model._meta.get_field(self.old_name),
  268. )
  269. def describe(self):
  270. return "Rename field %s on %s to %s" % (
  271. self.old_name,
  272. self.model_name,
  273. self.new_name,
  274. )
  275. @property
  276. def migration_name_fragment(self):
  277. return "rename_%s_%s_%s" % (
  278. self.old_name_lower,
  279. self.model_name_lower,
  280. self.new_name_lower,
  281. )
  282. def references_field(self, model_name, name, app_label):
  283. return self.references_model(model_name, app_label) and (
  284. name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
  285. )
  286. def reduce(self, operation, app_label):
  287. if (
  288. isinstance(operation, RenameField)
  289. and self.is_same_model_operation(operation)
  290. and self.new_name_lower == operation.old_name_lower
  291. ):
  292. return [
  293. RenameField(
  294. self.model_name,
  295. self.old_name,
  296. operation.new_name,
  297. ),
  298. ]
  299. # Skip `FieldOperation.reduce` as we want to run `references_field`
  300. # against self.old_name and self.new_name.
  301. return super(FieldOperation, self).reduce(operation, app_label) or not (
  302. operation.references_field(self.model_name, self.old_name, app_label)
  303. or operation.references_field(self.model_name, self.new_name, app_label)
  304. )