# generated.py (7.2 KB)
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from django.utils.functional import cached_property

from . import NOT_PROVIDED, Field
  6. __all__ = ["GeneratedField"]
  7. class GeneratedField(Field):
  8. generated = True
  9. db_returning = True
  10. _query = None
  11. output_field = None
  12. def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
  13. if kwargs.setdefault("editable", False):
  14. raise ValueError("GeneratedField cannot be editable.")
  15. if not kwargs.setdefault("blank", True):
  16. raise ValueError("GeneratedField must be blank.")
  17. if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
  18. raise ValueError("GeneratedField cannot have a default.")
  19. if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
  20. raise ValueError("GeneratedField cannot have a database default.")
  21. if db_persist not in (True, False):
  22. raise ValueError("GeneratedField.db_persist must be True or False.")
  23. self.expression = expression
  24. self.output_field = output_field
  25. self.db_persist = db_persist
  26. super().__init__(**kwargs)
  27. @cached_property
  28. def cached_col(self):
  29. from django.db.models.expressions import Col
  30. return Col(self.model._meta.db_table, self, self.output_field)
  31. def get_col(self, alias, output_field=None):
  32. if alias != self.model._meta.db_table and output_field is None:
  33. output_field = self.output_field
  34. return super().get_col(alias, output_field)
  35. def contribute_to_class(self, *args, **kwargs):
  36. super().contribute_to_class(*args, **kwargs)
  37. self._query = Query(model=self.model, alias_cols=False)
  38. # Register lookups from the output_field class.
  39. for lookup_name, lookup in self.output_field.get_class_lookups().items():
  40. self.register_lookup(lookup, lookup_name=lookup_name)
  41. def generated_sql(self, connection):
  42. compiler = connection.ops.compiler("SQLCompiler")(
  43. self._query, connection=connection, using=None
  44. )
  45. resolved_expression = self.expression.resolve_expression(
  46. self._query, allow_joins=False
  47. )
  48. return compiler.compile(resolved_expression)
  49. def check(self, **kwargs):
  50. databases = kwargs.get("databases") or []
  51. errors = [
  52. *super().check(**kwargs),
  53. *self._check_supported(databases),
  54. *self._check_persistence(databases),
  55. ]
  56. output_field_clone = self.output_field.clone()
  57. output_field_clone.model = self.model
  58. output_field_checks = output_field_clone.check(databases=databases)
  59. if output_field_checks:
  60. separator = "\n "
  61. error_messages = separator.join(
  62. f"{output_check.msg} ({output_check.id})"
  63. for output_check in output_field_checks
  64. if isinstance(output_check, checks.Error)
  65. )
  66. if error_messages:
  67. errors.append(
  68. checks.Error(
  69. "GeneratedField.output_field has errors:"
  70. f"{separator}{error_messages}",
  71. obj=self,
  72. id="fields.E223",
  73. )
  74. )
  75. warning_messages = separator.join(
  76. f"{output_check.msg} ({output_check.id})"
  77. for output_check in output_field_checks
  78. if isinstance(output_check, checks.Warning)
  79. )
  80. if warning_messages:
  81. errors.append(
  82. checks.Warning(
  83. "GeneratedField.output_field has warnings:"
  84. f"{separator}{warning_messages}",
  85. obj=self,
  86. id="fields.W224",
  87. )
  88. )
  89. return errors
  90. def _check_supported(self, databases):
  91. errors = []
  92. for db in databases:
  93. if not router.allow_migrate_model(db, self.model):
  94. continue
  95. connection = connections[db]
  96. if (
  97. self.model._meta.required_db_vendor
  98. and self.model._meta.required_db_vendor != connection.vendor
  99. ):
  100. continue
  101. if not (
  102. connection.features.supports_virtual_generated_columns
  103. or "supports_stored_generated_columns"
  104. in self.model._meta.required_db_features
  105. ) and not (
  106. connection.features.supports_stored_generated_columns
  107. or "supports_virtual_generated_columns"
  108. in self.model._meta.required_db_features
  109. ):
  110. errors.append(
  111. checks.Error(
  112. f"{connection.display_name} does not support GeneratedFields.",
  113. obj=self,
  114. id="fields.E220",
  115. )
  116. )
  117. return errors
  118. def _check_persistence(self, databases):
  119. errors = []
  120. for db in databases:
  121. if not router.allow_migrate_model(db, self.model):
  122. continue
  123. connection = connections[db]
  124. if (
  125. self.model._meta.required_db_vendor
  126. and self.model._meta.required_db_vendor != connection.vendor
  127. ):
  128. continue
  129. if not self.db_persist and not (
  130. connection.features.supports_virtual_generated_columns
  131. or "supports_virtual_generated_columns"
  132. in self.model._meta.required_db_features
  133. ):
  134. errors.append(
  135. checks.Error(
  136. f"{connection.display_name} does not support non-persisted "
  137. "GeneratedFields.",
  138. obj=self,
  139. id="fields.E221",
  140. hint="Set db_persist=True on the field.",
  141. )
  142. )
  143. if self.db_persist and not (
  144. connection.features.supports_stored_generated_columns
  145. or "supports_stored_generated_columns"
  146. in self.model._meta.required_db_features
  147. ):
  148. errors.append(
  149. checks.Error(
  150. f"{connection.display_name} does not support persisted "
  151. "GeneratedFields.",
  152. obj=self,
  153. id="fields.E222",
  154. hint="Set db_persist=False on the field.",
  155. )
  156. )
  157. return errors
  158. def deconstruct(self):
  159. name, path, args, kwargs = super().deconstruct()
  160. del kwargs["blank"]
  161. del kwargs["editable"]
  162. kwargs["db_persist"] = self.db_persist
  163. kwargs["expression"] = self.expression
  164. kwargs["output_field"] = self.output_field
  165. return name, path, args, kwargs
  166. def get_internal_type(self):
  167. return self.output_field.get_internal_type()
  168. def db_parameters(self, connection):
  169. return self.output_field.db_parameters(connection)
  170. def db_type_parameters(self, connection):
  171. return self.output_field.db_type_parameters(connection)