subqueries.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
"""
Query subclasses which provide extra functionality beyond simple data retrieval.

Defines DeleteQuery, UpdateQuery, InsertQuery and AggregateQuery; each one
names (via its ``compiler`` attribute) the SQL compiler class used to render it.
"""
  4. from django.core.exceptions import FieldError
  5. from django.db.models.sql.constants import CURSOR, GET_ITERATOR_CHUNK_SIZE, NO_RESULTS
  6. from django.db.models.sql.query import Query
  7. __all__ = ["DeleteQuery", "UpdateQuery", "InsertQuery", "AggregateQuery"]
  8. class DeleteQuery(Query):
  9. """A DELETE SQL query."""
  10. compiler = "SQLDeleteCompiler"
  11. def do_query(self, table, where, using):
  12. self.alias_map = {table: self.alias_map[table]}
  13. self.where = where
  14. cursor = self.get_compiler(using).execute_sql(CURSOR)
  15. if cursor:
  16. with cursor:
  17. return cursor.rowcount
  18. return 0
  19. def delete_batch(self, pk_list, using):
  20. """
  21. Set up and execute delete queries for all the objects in pk_list.
  22. More than one physical query may be executed if there are a
  23. lot of values in pk_list.
  24. """
  25. # number of objects deleted
  26. num_deleted = 0
  27. field = self.get_meta().pk
  28. for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
  29. self.clear_where()
  30. self.add_filter(
  31. f"{field.attname}__in",
  32. pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE],
  33. )
  34. num_deleted += self.do_query(
  35. self.get_meta().db_table, self.where, using=using
  36. )
  37. return num_deleted
  38. class UpdateQuery(Query):
  39. """An UPDATE SQL query."""
  40. compiler = "SQLUpdateCompiler"
  41. def __init__(self, *args, **kwargs):
  42. super().__init__(*args, **kwargs)
  43. self._setup_query()
  44. def _setup_query(self):
  45. """
  46. Run on initialization and at the end of chaining. Any attributes that
  47. would normally be set in __init__() should go here instead.
  48. """
  49. self.values = []
  50. self.related_ids = None
  51. self.related_updates = {}
  52. def clone(self):
  53. obj = super().clone()
  54. obj.related_updates = self.related_updates.copy()
  55. return obj
  56. def update_batch(self, pk_list, values, using):
  57. self.add_update_values(values)
  58. for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
  59. self.clear_where()
  60. self.add_filter(
  61. "pk__in", pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]
  62. )
  63. self.get_compiler(using).execute_sql(NO_RESULTS)
  64. def add_update_values(self, values):
  65. """
  66. Convert a dictionary of field name to value mappings into an update
  67. query. This is the entry point for the public update() method on
  68. querysets.
  69. """
  70. values_seq = []
  71. for name, val in values.items():
  72. field = self.get_meta().get_field(name)
  73. direct = (
  74. not (field.auto_created and not field.concrete) or not field.concrete
  75. )
  76. model = field.model._meta.concrete_model
  77. if not direct or (field.is_relation and field.many_to_many):
  78. raise FieldError(
  79. "Cannot update model field %r (only non-relations and "
  80. "foreign keys permitted)." % field
  81. )
  82. if model is not self.get_meta().concrete_model:
  83. self.add_related_update(model, field, val)
  84. continue
  85. values_seq.append((field, model, val))
  86. return self.add_update_fields(values_seq)
  87. def add_update_fields(self, values_seq):
  88. """
  89. Append a sequence of (field, model, value) triples to the internal list
  90. that will be used to generate the UPDATE query. Might be more usefully
  91. called add_update_targets() to hint at the extra information here.
  92. """
  93. for field, model, val in values_seq:
  94. # Omit generated fields.
  95. if field.generated:
  96. continue
  97. if hasattr(val, "resolve_expression"):
  98. # Resolve expressions here so that annotations are no longer needed
  99. val = val.resolve_expression(self, allow_joins=False, for_save=True)
  100. self.values.append((field, model, val))
  101. def add_related_update(self, model, field, value):
  102. """
  103. Add (name, value) to an update query for an ancestor model.
  104. Update are coalesced so that only one update query per ancestor is run.
  105. """
  106. self.related_updates.setdefault(model, []).append((field, None, value))
  107. def get_related_updates(self):
  108. """
  109. Return a list of query objects: one for each update required to an
  110. ancestor model. Each query will have the same filtering conditions as
  111. the current query but will only update a single table.
  112. """
  113. if not self.related_updates:
  114. return []
  115. result = []
  116. for model, values in self.related_updates.items():
  117. query = UpdateQuery(model)
  118. query.values = values
  119. if self.related_ids is not None:
  120. query.add_filter("pk__in", self.related_ids[model])
  121. result.append(query)
  122. return result
  123. class InsertQuery(Query):
  124. compiler = "SQLInsertCompiler"
  125. def __init__(
  126. self, *args, on_conflict=None, update_fields=None, unique_fields=None, **kwargs
  127. ):
  128. super().__init__(*args, **kwargs)
  129. self.fields = []
  130. self.objs = []
  131. self.on_conflict = on_conflict
  132. self.update_fields = update_fields or []
  133. self.unique_fields = unique_fields or []
  134. def insert_values(self, fields, objs, raw=False):
  135. self.fields = fields
  136. self.objs = objs
  137. self.raw = raw
  138. class AggregateQuery(Query):
  139. """
  140. Take another query as a parameter to the FROM clause and only select the
  141. elements in the provided list.
  142. """
  143. compiler = "SQLAggregateCompiler"
  144. def __init__(self, model, inner_query):
  145. self.inner_query = inner_query
  146. super().__init__(model)