# writer.py (11 KB)
import os
import re
from importlib import import_module

from django import get_version
from django.apps import apps

# SettingsReference imported for backwards compatibility in Django 2.2.
from django.conf import SettingsReference  # NOQA
from django.db import migrations
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.serializer import Serializer, serializer_factory
from django.utils.inspect import get_func_args
from django.utils.module_loading import module_dir
from django.utils.timezone import now
  14. class OperationWriter:
  15. def __init__(self, operation, indentation=2):
  16. self.operation = operation
  17. self.buff = []
  18. self.indentation = indentation
  19. def serialize(self):
  20. def _write(_arg_name, _arg_value):
  21. if _arg_name in self.operation.serialization_expand_args and isinstance(
  22. _arg_value, (list, tuple, dict)
  23. ):
  24. if isinstance(_arg_value, dict):
  25. self.feed("%s={" % _arg_name)
  26. self.indent()
  27. for key, value in _arg_value.items():
  28. key_string, key_imports = MigrationWriter.serialize(key)
  29. arg_string, arg_imports = MigrationWriter.serialize(value)
  30. args = arg_string.splitlines()
  31. if len(args) > 1:
  32. self.feed("%s: %s" % (key_string, args[0]))
  33. for arg in args[1:-1]:
  34. self.feed(arg)
  35. self.feed("%s," % args[-1])
  36. else:
  37. self.feed("%s: %s," % (key_string, arg_string))
  38. imports.update(key_imports)
  39. imports.update(arg_imports)
  40. self.unindent()
  41. self.feed("},")
  42. else:
  43. self.feed("%s=[" % _arg_name)
  44. self.indent()
  45. for item in _arg_value:
  46. arg_string, arg_imports = MigrationWriter.serialize(item)
  47. args = arg_string.splitlines()
  48. if len(args) > 1:
  49. for arg in args[:-1]:
  50. self.feed(arg)
  51. self.feed("%s," % args[-1])
  52. else:
  53. self.feed("%s," % arg_string)
  54. imports.update(arg_imports)
  55. self.unindent()
  56. self.feed("],")
  57. else:
  58. arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
  59. args = arg_string.splitlines()
  60. if len(args) > 1:
  61. self.feed("%s=%s" % (_arg_name, args[0]))
  62. for arg in args[1:-1]:
  63. self.feed(arg)
  64. self.feed("%s," % args[-1])
  65. else:
  66. self.feed("%s=%s," % (_arg_name, arg_string))
  67. imports.update(arg_imports)
  68. imports = set()
  69. name, args, kwargs = self.operation.deconstruct()
  70. operation_args = get_func_args(self.operation.__init__)
  71. # See if this operation is in django.db.migrations. If it is,
  72. # We can just use the fact we already have that imported,
  73. # otherwise, we need to add an import for the operation class.
  74. if getattr(migrations, name, None) == self.operation.__class__:
  75. self.feed("migrations.%s(" % name)
  76. else:
  77. imports.add("import %s" % (self.operation.__class__.__module__))
  78. self.feed("%s.%s(" % (self.operation.__class__.__module__, name))
  79. self.indent()
  80. for i, arg in enumerate(args):
  81. arg_value = arg
  82. arg_name = operation_args[i]
  83. _write(arg_name, arg_value)
  84. i = len(args)
  85. # Only iterate over remaining arguments
  86. for arg_name in operation_args[i:]:
  87. if arg_name in kwargs: # Don't sort to maintain signature order
  88. arg_value = kwargs[arg_name]
  89. _write(arg_name, arg_value)
  90. self.unindent()
  91. self.feed("),")
  92. return self.render(), imports
  93. def indent(self):
  94. self.indentation += 1
  95. def unindent(self):
  96. self.indentation -= 1
  97. def feed(self, line):
  98. self.buff.append(" " * (self.indentation * 4) + line)
  99. def render(self):
  100. return "\n".join(self.buff)
  101. class MigrationWriter:
  102. """
  103. Take a Migration instance and is able to produce the contents
  104. of the migration file from it.
  105. """
  106. def __init__(self, migration, include_header=True):
  107. self.migration = migration
  108. self.include_header = include_header
  109. self.needs_manual_porting = False
  110. def as_string(self):
  111. """Return a string of the file contents."""
  112. items = {
  113. "replaces_str": "",
  114. "initial_str": "",
  115. }
  116. imports = set()
  117. # Deconstruct operations
  118. operations = []
  119. for operation in self.migration.operations:
  120. operation_string, operation_imports = OperationWriter(operation).serialize()
  121. imports.update(operation_imports)
  122. operations.append(operation_string)
  123. items["operations"] = "\n".join(operations) + "\n" if operations else ""
  124. # Format dependencies and write out swappable dependencies right
  125. dependencies = []
  126. for dependency in self.migration.dependencies:
  127. if dependency[0] == "__setting__":
  128. dependencies.append(
  129. " migrations.swappable_dependency(settings.%s),"
  130. % dependency[1]
  131. )
  132. imports.add("from django.conf import settings")
  133. else:
  134. dependencies.append(" %s," % self.serialize(dependency)[0])
  135. items["dependencies"] = (
  136. "\n".join(sorted(dependencies)) + "\n" if dependencies else ""
  137. )
  138. # Format imports nicely, swapping imports of functions from migration files
  139. # for comments
  140. migration_imports = set()
  141. for line in list(imports):
  142. if re.match(r"^import (.*)\.\d+[^\s]*$", line):
  143. migration_imports.add(line.split("import")[1].strip())
  144. imports.remove(line)
  145. self.needs_manual_porting = True
  146. # django.db.migrations is always used, but models import may not be.
  147. # If models import exists, merge it with migrations import.
  148. if "from django.db import models" in imports:
  149. imports.discard("from django.db import models")
  150. imports.add("from django.db import migrations, models")
  151. else:
  152. imports.add("from django.db import migrations")
  153. # Sort imports by the package / module to be imported (the part after
  154. # "from" in "from ... import ..." or after "import" in "import ...").
  155. # First group the "import" statements, then "from ... import ...".
  156. sorted_imports = sorted(
  157. imports, key=lambda i: (i.split()[0] == "from", i.split()[1])
  158. )
  159. items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
  160. if migration_imports:
  161. items["imports"] += (
  162. "\n\n# Functions from the following migrations need manual "
  163. "copying.\n# Move them and any dependencies into this file, "
  164. "then update the\n# RunPython operations to refer to the local "
  165. "versions:\n# %s"
  166. ) % "\n# ".join(sorted(migration_imports))
  167. # If there's a replaces, make a string for it
  168. if self.migration.replaces:
  169. items["replaces_str"] = (
  170. "\n replaces = %s\n" % self.serialize(self.migration.replaces)[0]
  171. )
  172. # Hinting that goes into comment
  173. if self.include_header:
  174. items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
  175. "version": get_version(),
  176. "timestamp": now().strftime("%Y-%m-%d %H:%M"),
  177. }
  178. else:
  179. items["migration_header"] = ""
  180. if self.migration.initial:
  181. items["initial_str"] = "\n initial = True\n"
  182. return MIGRATION_TEMPLATE % items
  183. @property
  184. def basedir(self):
  185. migrations_package_name, _ = MigrationLoader.migrations_module(
  186. self.migration.app_label
  187. )
  188. if migrations_package_name is None:
  189. raise ValueError(
  190. "Django can't create migrations for app '%s' because "
  191. "migrations have been disabled via the MIGRATION_MODULES "
  192. "setting." % self.migration.app_label
  193. )
  194. # See if we can import the migrations module directly
  195. try:
  196. migrations_module = import_module(migrations_package_name)
  197. except ImportError:
  198. pass
  199. else:
  200. try:
  201. return module_dir(migrations_module)
  202. except ValueError:
  203. pass
  204. # Alright, see if it's a direct submodule of the app
  205. app_config = apps.get_app_config(self.migration.app_label)
  206. (
  207. maybe_app_name,
  208. _,
  209. migrations_package_basename,
  210. ) = migrations_package_name.rpartition(".")
  211. if app_config.name == maybe_app_name:
  212. return os.path.join(app_config.path, migrations_package_basename)
  213. # In case of using MIGRATION_MODULES setting and the custom package
  214. # doesn't exist, create one, starting from an existing package
  215. existing_dirs, missing_dirs = migrations_package_name.split("."), []
  216. while existing_dirs:
  217. missing_dirs.insert(0, existing_dirs.pop(-1))
  218. try:
  219. base_module = import_module(".".join(existing_dirs))
  220. except (ImportError, ValueError):
  221. continue
  222. else:
  223. try:
  224. base_dir = module_dir(base_module)
  225. except ValueError:
  226. continue
  227. else:
  228. break
  229. else:
  230. raise ValueError(
  231. "Could not locate an appropriate location to create "
  232. "migrations package %s. Make sure the toplevel "
  233. "package exists and can be imported." % migrations_package_name
  234. )
  235. final_dir = os.path.join(base_dir, *missing_dirs)
  236. os.makedirs(final_dir, exist_ok=True)
  237. for missing_dir in missing_dirs:
  238. base_dir = os.path.join(base_dir, missing_dir)
  239. with open(os.path.join(base_dir, "__init__.py"), "w"):
  240. pass
  241. return final_dir
  242. @property
  243. def filename(self):
  244. return "%s.py" % self.migration.name
  245. @property
  246. def path(self):
  247. return os.path.join(self.basedir, self.filename)
  248. @classmethod
  249. def serialize(cls, value):
  250. return serializer_factory(value).serialize()
  251. @classmethod
  252. def register_serializer(cls, type_, serializer):
  253. Serializer.register(type_, serializer)
  254. @classmethod
  255. def unregister_serializer(cls, type_):
  256. Serializer.unregister(type_)
  257. MIGRATION_HEADER_TEMPLATE = """\
  258. # Generated by Django %(version)s on %(timestamp)s
  259. """
  260. MIGRATION_TEMPLATE = """\
  261. %(migration_header)s%(imports)s
  262. class Migration(migrations.Migration):
  263. %(replaces_str)s%(initial_str)s
  264. dependencies = [
  265. %(dependencies)s\
  266. ]
  267. operations = [
  268. %(operations)s\
  269. ]
  270. """