creation.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import multiprocessing
  2. import os
  3. import shutil
  4. import sqlite3
  5. import sys
  6. from pathlib import Path
  7. from django.db import NotSupportedError
  8. from django.db.backends.base.creation import BaseDatabaseCreation
  9. class DatabaseCreation(BaseDatabaseCreation):
  10. @staticmethod
  11. def is_in_memory_db(database_name):
  12. return not isinstance(database_name, Path) and (
  13. database_name == ":memory:" or "mode=memory" in database_name
  14. )
  15. def _get_test_db_name(self):
  16. test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
  17. if test_database_name == ":memory:":
  18. return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
  19. return test_database_name
  20. def _create_test_db(self, verbosity, autoclobber, keepdb=False):
  21. test_database_name = self._get_test_db_name()
  22. if keepdb:
  23. return test_database_name
  24. if not self.is_in_memory_db(test_database_name):
  25. # Erase the old test database
  26. if verbosity >= 1:
  27. self.log(
  28. "Destroying old test database for alias %s..."
  29. % (self._get_database_display_str(verbosity, test_database_name),)
  30. )
  31. if os.access(test_database_name, os.F_OK):
  32. if not autoclobber:
  33. confirm = input(
  34. "Type 'yes' if you would like to try deleting the test "
  35. "database '%s', or 'no' to cancel: " % test_database_name
  36. )
  37. if autoclobber or confirm == "yes":
  38. try:
  39. os.remove(test_database_name)
  40. except Exception as e:
  41. self.log("Got an error deleting the old test database: %s" % e)
  42. sys.exit(2)
  43. else:
  44. self.log("Tests cancelled.")
  45. sys.exit(1)
  46. return test_database_name
  47. def get_test_db_clone_settings(self, suffix):
  48. orig_settings_dict = self.connection.settings_dict
  49. source_database_name = orig_settings_dict["NAME"] or ":memory:"
  50. if not self.is_in_memory_db(source_database_name):
  51. root, ext = os.path.splitext(source_database_name)
  52. return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"}
  53. start_method = multiprocessing.get_start_method()
  54. if start_method == "fork":
  55. return orig_settings_dict
  56. if start_method == "spawn":
  57. return {
  58. **orig_settings_dict,
  59. "NAME": f"{self.connection.alias}_{suffix}.sqlite3",
  60. }
  61. raise NotSupportedError(
  62. f"Cloning with start method {start_method!r} is not supported."
  63. )
  64. def _clone_test_db(self, suffix, verbosity, keepdb=False):
  65. source_database_name = self.connection.settings_dict["NAME"]
  66. target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
  67. if not self.is_in_memory_db(source_database_name):
  68. # Erase the old test database
  69. if os.access(target_database_name, os.F_OK):
  70. if keepdb:
  71. return
  72. if verbosity >= 1:
  73. self.log(
  74. "Destroying old test database for alias %s..."
  75. % (
  76. self._get_database_display_str(
  77. verbosity, target_database_name
  78. ),
  79. )
  80. )
  81. try:
  82. os.remove(target_database_name)
  83. except Exception as e:
  84. self.log("Got an error deleting the old test database: %s" % e)
  85. sys.exit(2)
  86. try:
  87. shutil.copy(source_database_name, target_database_name)
  88. except Exception as e:
  89. self.log("Got an error cloning the test database: %s" % e)
  90. sys.exit(2)
  91. # Forking automatically makes a copy of an in-memory database.
  92. # Spawn requires migrating to disk which will be re-opened in
  93. # setup_worker_connection.
  94. elif multiprocessing.get_start_method() == "spawn":
  95. ondisk_db = sqlite3.connect(target_database_name, uri=True)
  96. self.connection.connection.backup(ondisk_db)
  97. ondisk_db.close()
  98. def _destroy_test_db(self, test_database_name, verbosity):
  99. if test_database_name and not self.is_in_memory_db(test_database_name):
  100. # Remove the SQLite database file
  101. os.remove(test_database_name)
  102. def test_db_signature(self):
  103. """
  104. Return a tuple that uniquely identifies a test database.
  105. This takes into account the special cases of ":memory:" and "" for
  106. SQLite since the databases will be distinct despite having the same
  107. TEST NAME. See https://www.sqlite.org/inmemorydb.html
  108. """
  109. test_database_name = self._get_test_db_name()
  110. sig = [self.connection.settings_dict["NAME"]]
  111. if self.is_in_memory_db(test_database_name):
  112. sig.append(self.connection.alias)
  113. else:
  114. sig.append(test_database_name)
  115. return tuple(sig)
  116. def setup_worker_connection(self, _worker_id):
  117. settings_dict = self.get_test_db_clone_settings(_worker_id)
  118. # connection.settings_dict must be updated in place for changes to be
  119. # reflected in django.db.connections. Otherwise new threads would
  120. # connect to the default database instead of the appropriate clone.
  121. start_method = multiprocessing.get_start_method()
  122. if start_method == "fork":
  123. # Update settings_dict in place.
  124. self.connection.settings_dict.update(settings_dict)
  125. self.connection.close()
  126. elif start_method == "spawn":
  127. alias = self.connection.alias
  128. connection_str = (
  129. f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
  130. )
  131. source_db = self.connection.Database.connect(
  132. f"file:{alias}_{_worker_id}.sqlite3", uri=True
  133. )
  134. target_db = sqlite3.connect(connection_str, uri=True)
  135. source_db.backup(target_db)
  136. source_db.close()
  137. # Update settings_dict in place.
  138. self.connection.settings_dict.update(settings_dict)
  139. self.connection.settings_dict["NAME"] = connection_str
  140. # Re-open connection to in-memory database before closing copy
  141. # connection.
  142. self.connection.connect()
  143. target_db.close()
  144. if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
  145. self.mark_expected_failures_and_skips()