123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360 |
- """
- Code to manage the creation and SQL rendering of 'where' constraints.
- """
- import operator
- from functools import reduce
- from django.core.exceptions import EmptyResultSet, FullResultSet
- from django.db.models.expressions import Case, When
- from django.db.models.functions import Mod
- from django.db.models.lookups import Exact
- from django.utils import tree
- from django.utils.functional import cached_property
# Connection types: the connectors a WhereNode may use to combine its
# children. Each value is interpolated verbatim into the generated SQL.
AND = "AND"
OR = "OR"
XOR = "XOR"
class WhereNode(tree.Node):
    """
    An SQL WHERE clause.

    The class is tied to the Query class that created it (in order to create
    the correct SQL).

    A child is usually an expression producing boolean values. Most likely the
    expression is a Lookup instance.

    However, a child could also be any class with as_sql() and either
    relabeled_clone() method or relabel_aliases() and clone() methods and
    contains_aggregate attribute.
    """

    default = AND
    # Set to True on the clone returned by resolve_expression(); as_sql() then
    # parenthesizes the rendered SQL even when it has a single child.
    resolved = False
    conditional = True

    def split_having_qualify(self, negated=False, must_group_by=False):
        """
        Split this node by the filtering clause each part must be rendered in.

        Return a 3-tuple of possibly None nodes: the parts of self that should
        be included in the WHERE clause, the parts that must be included in
        the HAVING clause, and the parts that refer to window functions.

        `negated` is the negation state inherited from the enclosing nodes; it
        is XOR-ed with this node's own flag before being handed to children.
        When `must_group_by` is true, a node whose children must stay
        connected and that mixes window function predicates with plain WHERE
        predicates raises NotImplementedError.
        """
        if not (self.contains_aggregate or self.contains_over_clause):
            return self, None, None
        in_negated = negated ^ self.negated
        # Whether or not children must be connected in the same filtering
        # clause (WHERE > HAVING > QUALIFY) to maintain logical semantic.
        # A negated AND behaves like a disjunction, hence the swap.
        binding_connector = AND if in_negated else OR
        must_remain_connected = self.connector in (binding_connector, XOR)
        if (
            must_remain_connected
            and self.contains_aggregate
            and not self.contains_over_clause
        ):
            # It's much cheaper to short-circuit and stash everything in the
            # HAVING clause than split children if possible.
            return None, self, None
        where_parts, having_parts, qualify_parts = [], [], []
        for child in self.children:
            if hasattr(child, "split_having_qualify"):
                # Nested nodes know how to split themselves; merge whatever
                # they produced into the matching bucket.
                nested = child.split_having_qualify(in_negated, must_group_by)
                where_part, having_part, qualify_part = nested
                pairs = (
                    (where_parts, where_part),
                    (having_parts, having_part),
                    (qualify_parts, qualify_part),
                )
                for bucket, part in pairs:
                    if part is not None:
                        bucket.append(part)
            elif child.contains_over_clause:
                qualify_parts.append(child)
            elif child.contains_aggregate:
                having_parts.append(child)
            else:
                where_parts.append(child)
        if must_remain_connected and qualify_parts:
            # Disjunctive heterogeneous predicates can be pushed down to
            # qualify as long as no conditional aggregation is involved.
            if not where_parts or not must_group_by:
                return None, None, self
            # In theory this should only be enforced when dealing with
            # where_parts containing predicates against multi-valued
            # relationships that could affect aggregation results but this
            # is complex to infer properly.
            raise NotImplementedError(
                "Heterogeneous disjunctive predicates against window functions are "
                "not implemented when performing conditional aggregation."
            )
        # Rebuild one node per clause, keeping this node's connector and
        # negation; a clause without any part is reported as None.
        return tuple(
            self.create(parts, self.connector, self.negated) if parts else None
            for parts in (where_parts, having_parts, qualify_parts)
        )

    def as_sql(self, compiler, connection):
        """
        Return the SQL version of the where clause and the list of values to
        be substituted in.

        Raise FullResultSet if this node matches everything (this includes
        the case where no child produced any SQL) and EmptyResultSet if this
        node can't match anything.
        """
        if self.connector == XOR and not connection.features.supports_logical_xor:
            # Convert if the database doesn't support XOR:
            #   a XOR b XOR c XOR ...
            # to:
            #   (a OR b OR c OR ...) AND MOD(a + b + c + ..., 2) == 1
            # The result of an n-ary XOR is true when an odd number of operands
            # are true.
            lhs = self.__class__(self.children, OR)
            rhs_sum = reduce(
                operator.add,
                (Case(When(c, then=1), default=0) for c in self.children),
            )
            # With at most two operands the sum equals 1 exactly when an odd
            # number of them is true, so MOD is only applied beyond that.
            if len(self.children) > 2:
                rhs_sum = Mod(rhs_sum, 2)
            rhs = Exact(1, rhs_sum)
            return self.__class__([lhs, rhs], AND, self.negated).as_sql(
                compiler, connection
            )

        # Number of full/empty children required before the node as a whole
        # becomes full/empty: a single empty child empties an AND node while
        # every child has to be full to make it full; any other connector is
        # the other way around.
        child_count = len(self.children)
        if self.connector == AND:
            full_needed, empty_needed = child_count, 1
        else:
            full_needed, empty_needed = 1, child_count

        result = []
        result_params = []
        for child in self.children:
            try:
                sql, params = compiler.compile(child)
            except EmptyResultSet:
                empty_needed -= 1
            except FullResultSet:
                full_needed -= 1
            else:
                if sql:
                    result.append(sql)
                    result_params.extend(params)
                else:
                    # A child rendering to no SQL at all filters nothing.
                    full_needed -= 1
            # Use the counts to check, after each child, whether this node
            # already matches nothing or everything. Negation flips the
            # outcome.
            if empty_needed == 0:
                raise FullResultSet if self.negated else EmptyResultSet
            if full_needed == 0:
                raise EmptyResultSet if self.negated else FullResultSet

        conn = " %s " % self.connector
        sql_string = conn.join(result)
        if not sql_string:
            raise FullResultSet
        if self.negated:
            # Some backends (Oracle at least) need parentheses around the inner
            # SQL in the negated case, even if the inner SQL contains just a
            # single expression.
            sql_string = "NOT (%s)" % sql_string
        elif len(result) > 1 or self.resolved:
            sql_string = "(%s)" % sql_string
        return sql_string, result_params

    def get_group_by_cols(self):
        """Return the concatenated get_group_by_cols() of every child."""
        return [
            col for child in self.children for col in child.get_group_by_cols()
        ]

    def get_source_expressions(self):
        """Return a shallow copy of the children."""
        return self.children[:]

    def set_source_expressions(self, children):
        """Replace the children; the number of children must not change."""
        assert len(children) == len(self.children)
        self.children = children

    def relabel_aliases(self, change_map):
        """
        Relabel the alias values of any children. 'change_map' is a dictionary
        mapping old (current) alias values to the new values.

        Children are updated in place: those exposing relabel_aliases() are
        asked to relabel themselves, the others are swapped for their
        relabeled_clone() when they provide one.
        """
        for index, node in enumerate(self.children):
            if hasattr(node, "relabel_aliases"):
                # For example another WhereNode
                node.relabel_aliases(change_map)
            elif hasattr(node, "relabeled_clone"):
                self.children[index] = node.relabeled_clone(change_map)

    def clone(self):
        """
        Return a copy of this node with the same connector and negation.

        Children that have a clone() method are cloned; the others are shared
        with the copy.
        """
        duplicate = self.create(connector=self.connector, negated=self.negated)
        duplicate.children.extend(
            child.clone() if hasattr(child, "clone") else child
            for child in self.children
        )
        return duplicate

    def relabeled_clone(self, change_map):
        """Return a clone whose aliases were relabeled using `change_map`."""
        relabeled = self.clone()
        relabeled.relabel_aliases(change_map)
        return relabeled

    def replace_expressions(self, replacements):
        """
        Return the (truthy) replacement registered for this node if any,
        otherwise a new node built from the result of replace_expressions()
        on every child.
        """
        substitute = replacements.get(self)
        if substitute:
            return substitute
        rebuilt = self.create(connector=self.connector, negated=self.negated)
        rebuilt.children.extend(
            child.replace_expressions(replacements) for child in self.children
        )
        return rebuilt

    def get_refs(self):
        """Return the union of the get_refs() sets of every child."""
        collected = set()
        for child in self.children:
            collected |= child.get_refs()
        return collected

    @classmethod
    def _contains_aggregate(cls, obj):
        # Leaves carry the flag themselves; tree nodes are searched
        # recursively.
        if not isinstance(obj, tree.Node):
            return obj.contains_aggregate
        return any(cls._contains_aggregate(child) for child in obj.children)

    @cached_property
    def contains_aggregate(self):
        return self._contains_aggregate(self)

    @classmethod
    def _contains_over_clause(cls, obj):
        # Same traversal as _contains_aggregate(), for window functions.
        if not isinstance(obj, tree.Node):
            return obj.contains_over_clause
        return any(cls._contains_over_clause(child) for child in obj.children)

    @cached_property
    def contains_over_clause(self):
        return self._contains_over_clause(self)

    @property
    def is_summary(self):
        return any(child.is_summary for child in self.children)

    @staticmethod
    def _resolve_leaf(expr, query, *args, **kwargs):
        # Values without resolve_expression() are returned untouched.
        if not hasattr(expr, "resolve_expression"):
            return expr
        return expr.resolve_expression(query, *args, **kwargs)

    @classmethod
    def _resolve_node(cls, node, query, *args, **kwargs):
        # Walk the tree depth-first, then resolve in place whichever of the
        # lhs/rhs attributes the node exposes.
        if hasattr(node, "children"):
            for child in node.children:
                cls._resolve_node(child, query, *args, **kwargs)
        for side in ("lhs", "rhs"):
            if hasattr(node, side):
                resolved = cls._resolve_leaf(
                    getattr(node, side), query, *args, **kwargs
                )
                setattr(node, side, resolved)

    def resolve_expression(self, *args, **kwargs):
        """Return a clone, flagged as resolved, with its tree resolved."""
        resolved_clone = self.clone()
        resolved_clone._resolve_node(resolved_clone, *args, **kwargs)
        resolved_clone.resolved = True
        return resolved_clone

    @cached_property
    def output_field(self):
        # Imported at call time rather than at module level.
        from django.db.models import BooleanField

        return BooleanField()

    @property
    def _output_field_or_none(self):
        return self.output_field

    def select_format(self, compiler, sql, params):
        # Wrap filters with a CASE WHEN expression if a database backend
        # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
        # BY list.
        features = compiler.connection.features
        if not features.supports_boolean_expr_in_select_clause:
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def get_db_converters(self, connection):
        return self.output_field.get_db_converters(connection)

    def get_lookup(self, lookup):
        return self.output_field.get_lookup(lookup)

    def leaves(self):
        """Yield, in child order, every descendant that isn't a WhereNode."""
        for child in self.children:
            if isinstance(child, WhereNode):
                yield from child.leaves()
            else:
                yield child
class NothingNode:
    """A node that matches nothing."""

    # Nothing is rendered, so neither aggregates nor window functions can be
    # involved.
    contains_aggregate = False
    contains_over_clause = False

    def as_sql(self, compiler=None, connection=None):
        """Always raise EmptyResultSet; both arguments are ignored."""
        raise EmptyResultSet
class ExtraWhere:
    """Raw SQL conditions, each parenthesized and joined with AND."""

    # The contents are a black box - assume no aggregates or windows are used.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, sqls, params):
        # `sqls` is an iterable of SQL fragments; `params` holds the values
        # bound to them and may be None or empty.
        self.sqls = sqls
        self.params = params

    def as_sql(self, compiler=None, connection=None):
        """
        Return the AND-ed, parenthesized fragments and a fresh parameter list.

        `compiler` and `connection` are accepted but not used.
        """
        combined = " AND ".join("(%s)" % fragment for fragment in self.sqls)
        bound = list(self.params) if self.params else []
        return combined, bound
class SubqueryConstraint:
    """A condition whose SQL is produced by the compiler of another query."""

    # Even if aggregates or windows would be used in a subquery,
    # the outer query isn't interested about those.
    contains_aggregate = False
    contains_over_clause = False

    def __init__(self, alias, columns, targets, query_object):
        self.alias, self.columns, self.targets = alias, columns, targets
        # Note that this clears the ordering on the caller's query object
        # itself, not on a copy.
        query_object.clear_ordering(clear_default=True)
        self.query_object = query_object

    def as_sql(self, compiler, connection):
        """
        Set `targets` as the values of the stored query, then return what its
        compiler's as_subquery_condition() produces for `alias`, `columns` and
        the outer `compiler`.
        """
        subquery = self.query_object
        subquery.set_values(self.targets)
        subquery_compiler = subquery.get_compiler(connection=connection)
        return subquery_compiler.as_subquery_condition(
            self.alias, self.columns, compiler
        )
|