expressions.py 65 KB


  1. import copy
  2. import datetime
  3. import functools
  4. import inspect
  5. from collections import defaultdict
  6. from decimal import Decimal
  7. from types import NoneType
  8. from uuid import UUID
  9. from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
  10. from django.db import DatabaseError, NotSupportedError, connection
  11. from django.db.models import fields
  12. from django.db.models.constants import LOOKUP_SEP
  13. from django.db.models.query_utils import Q
  14. from django.utils.deconstruct import deconstructible
  15. from django.utils.functional import cached_property
  16. from django.utils.hashable import make_hashable
  17. class SQLiteNumericMixin:
  18. """
  19. Some expressions with output_field=DecimalField() must be cast to
  20. numeric to be properly filtered.
  21. """
  22. def as_sqlite(self, compiler, connection, **extra_context):
  23. sql, params = self.as_sql(compiler, connection, **extra_context)
  24. try:
  25. if self.output_field.get_internal_type() == "DecimalField":
  26. sql = "(CAST(%s AS NUMERIC))" % sql
  27. except FieldError:
  28. pass
  29. return sql, params
  30. class Combinable:
  31. """
  32. Provide the ability to combine one or two objects with
  33. some connector. For example F('foo') + F('bar').
  34. """
  35. # Arithmetic connectors
  36. ADD = "+"
  37. SUB = "-"
  38. MUL = "*"
  39. DIV = "/"
  40. POW = "^"
  41. # The following is a quoted % operator - it is quoted because it can be
  42. # used in strings that also have parameter substitution.
  43. MOD = "%%"
  44. # Bitwise operators - note that these are generated by .bitand()
  45. # and .bitor(), the '&' and '|' are reserved for boolean operator
  46. # usage.
  47. BITAND = "&"
  48. BITOR = "|"
  49. BITLEFTSHIFT = "<<"
  50. BITRIGHTSHIFT = ">>"
  51. BITXOR = "#"
  52. def _combine(self, other, connector, reversed):
  53. if not hasattr(other, "resolve_expression"):
  54. # everything must be resolvable to an expression
  55. other = Value(other)
  56. if reversed:
  57. return CombinedExpression(other, connector, self)
  58. return CombinedExpression(self, connector, other)
  59. #############
  60. # OPERATORS #
  61. #############
  62. def __neg__(self):
  63. return self._combine(-1, self.MUL, False)
  64. def __add__(self, other):
  65. return self._combine(other, self.ADD, False)
  66. def __sub__(self, other):
  67. return self._combine(other, self.SUB, False)
  68. def __mul__(self, other):
  69. return self._combine(other, self.MUL, False)
  70. def __truediv__(self, other):
  71. return self._combine(other, self.DIV, False)
  72. def __mod__(self, other):
  73. return self._combine(other, self.MOD, False)
  74. def __pow__(self, other):
  75. return self._combine(other, self.POW, False)
  76. def __and__(self, other):
  77. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  78. return Q(self) & Q(other)
  79. raise NotImplementedError(
  80. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  81. )
  82. def bitand(self, other):
  83. return self._combine(other, self.BITAND, False)
  84. def bitleftshift(self, other):
  85. return self._combine(other, self.BITLEFTSHIFT, False)
  86. def bitrightshift(self, other):
  87. return self._combine(other, self.BITRIGHTSHIFT, False)
  88. def __xor__(self, other):
  89. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  90. return Q(self) ^ Q(other)
  91. raise NotImplementedError(
  92. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  93. )
  94. def bitxor(self, other):
  95. return self._combine(other, self.BITXOR, False)
  96. def __or__(self, other):
  97. if getattr(self, "conditional", False) and getattr(other, "conditional", False):
  98. return Q(self) | Q(other)
  99. raise NotImplementedError(
  100. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  101. )
  102. def bitor(self, other):
  103. return self._combine(other, self.BITOR, False)
  104. def __radd__(self, other):
  105. return self._combine(other, self.ADD, True)
  106. def __rsub__(self, other):
  107. return self._combine(other, self.SUB, True)
  108. def __rmul__(self, other):
  109. return self._combine(other, self.MUL, True)
  110. def __rtruediv__(self, other):
  111. return self._combine(other, self.DIV, True)
  112. def __rmod__(self, other):
  113. return self._combine(other, self.MOD, True)
  114. def __rpow__(self, other):
  115. return self._combine(other, self.POW, True)
  116. def __rand__(self, other):
  117. raise NotImplementedError(
  118. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  119. )
  120. def __ror__(self, other):
  121. raise NotImplementedError(
  122. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  123. )
  124. def __rxor__(self, other):
  125. raise NotImplementedError(
  126. "Use .bitand(), .bitor(), and .bitxor() for bitwise logical operations."
  127. )
  128. def __invert__(self):
  129. return NegatedExpression(self)
  130. class BaseExpression:
  131. """Base class for all query expressions."""
  132. empty_result_set_value = NotImplemented
  133. # aggregate specific fields
  134. is_summary = False
  135. _output_field_resolved_to_none = False
  136. # Can the expression be used in a WHERE clause?
  137. filterable = True
  138. # Can the expression can be used as a source expression in Window?
  139. window_compatible = False
  140. # Can the expression be used as a database default value?
  141. allowed_default = False
  142. def __init__(self, output_field=None):
  143. if output_field is not None:
  144. self.output_field = output_field
  145. def __getstate__(self):
  146. state = self.__dict__.copy()
  147. state.pop("convert_value", None)
  148. return state
  149. def get_db_converters(self, connection):
  150. return (
  151. []
  152. if self.convert_value is self._convert_value_noop
  153. else [self.convert_value]
  154. ) + self.output_field.get_db_converters(connection)
  155. def get_source_expressions(self):
  156. return []
  157. def set_source_expressions(self, exprs):
  158. assert not exprs
  159. def _parse_expressions(self, *expressions):
  160. return [
  161. arg
  162. if hasattr(arg, "resolve_expression")
  163. else (F(arg) if isinstance(arg, str) else Value(arg))
  164. for arg in expressions
  165. ]
  166. def as_sql(self, compiler, connection):
  167. """
  168. Responsible for returning a (sql, [params]) tuple to be included
  169. in the current query.
  170. Different backends can provide their own implementation, by
  171. providing an `as_{vendor}` method and patching the Expression:
  172. ```
  173. def override_as_sql(self, compiler, connection):
  174. # custom logic
  175. return super().as_sql(compiler, connection)
  176. setattr(Expression, 'as_' + connection.vendor, override_as_sql)
  177. ```
  178. Arguments:
  179. * compiler: the query compiler responsible for generating the query.
  180. Must have a compile method, returning a (sql, [params]) tuple.
  181. Calling compiler(value) will return a quoted `value`.
  182. * connection: the database connection used for the current query.
  183. Return: (sql, params)
  184. Where `sql` is a string containing ordered sql parameters to be
  185. replaced with the elements of the list `params`.
  186. """
  187. raise NotImplementedError("Subclasses must implement as_sql()")
  188. @cached_property
  189. def contains_aggregate(self):
  190. return any(
  191. expr and expr.contains_aggregate for expr in self.get_source_expressions()
  192. )
  193. @cached_property
  194. def contains_over_clause(self):
  195. return any(
  196. expr and expr.contains_over_clause for expr in self.get_source_expressions()
  197. )
  198. @cached_property
  199. def contains_column_references(self):
  200. return any(
  201. expr and expr.contains_column_references
  202. for expr in self.get_source_expressions()
  203. )
  204. @cached_property
  205. def contains_subquery(self):
  206. return any(
  207. expr and (getattr(expr, "subquery", False) or expr.contains_subquery)
  208. for expr in self.get_source_expressions()
  209. )
  210. def resolve_expression(
  211. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  212. ):
  213. """
  214. Provide the chance to do any preprocessing or validation before being
  215. added to the query.
  216. Arguments:
  217. * query: the backend query implementation
  218. * allow_joins: boolean allowing or denying use of joins
  219. in this query
  220. * reuse: a set of reusable joins for multijoins
  221. * summarize: a terminal aggregate clause
  222. * for_save: whether this expression about to be used in a save or update
  223. Return: an Expression to be added to the query.
  224. """
  225. c = self.copy()
  226. c.is_summary = summarize
  227. c.set_source_expressions(
  228. [
  229. expr.resolve_expression(query, allow_joins, reuse, summarize)
  230. if expr
  231. else None
  232. for expr in c.get_source_expressions()
  233. ]
  234. )
  235. return c
  236. @property
  237. def conditional(self):
  238. return isinstance(self.output_field, fields.BooleanField)
  239. @property
  240. def field(self):
  241. return self.output_field
  242. @cached_property
  243. def output_field(self):
  244. """Return the output type of this expressions."""
  245. output_field = self._resolve_output_field()
  246. if output_field is None:
  247. self._output_field_resolved_to_none = True
  248. raise FieldError("Cannot resolve expression type, unknown output_field")
  249. return output_field
  250. @cached_property
  251. def _output_field_or_none(self):
  252. """
  253. Return the output field of this expression, or None if
  254. _resolve_output_field() didn't return an output type.
  255. """
  256. try:
  257. return self.output_field
  258. except FieldError:
  259. if not self._output_field_resolved_to_none:
  260. raise
  261. def _resolve_output_field(self):
  262. """
  263. Attempt to infer the output type of the expression.
  264. As a guess, if the output fields of all source fields match then simply
  265. infer the same type here.
  266. If a source's output field resolves to None, exclude it from this check.
  267. If all sources are None, then an error is raised higher up the stack in
  268. the output_field property.
  269. """
  270. # This guess is mostly a bad idea, but there is quite a lot of code
  271. # (especially 3rd party Func subclasses) that depend on it, we'd need a
  272. # deprecation path to fix it.
  273. sources_iter = (
  274. source for source in self.get_source_fields() if source is not None
  275. )
  276. for output_field in sources_iter:
  277. for source in sources_iter:
  278. if not isinstance(output_field, source.__class__):
  279. raise FieldError(
  280. "Expression contains mixed types: %s, %s. You must "
  281. "set output_field."
  282. % (
  283. output_field.__class__.__name__,
  284. source.__class__.__name__,
  285. )
  286. )
  287. return output_field
  288. @staticmethod
  289. def _convert_value_noop(value, expression, connection):
  290. return value
  291. @cached_property
  292. def convert_value(self):
  293. """
  294. Expressions provide their own converters because users have the option
  295. of manually specifying the output_field which may be a different type
  296. from the one the database returns.
  297. """
  298. field = self.output_field
  299. internal_type = field.get_internal_type()
  300. if internal_type == "FloatField":
  301. return (
  302. lambda value, expression, connection: None
  303. if value is None
  304. else float(value)
  305. )
  306. elif internal_type.endswith("IntegerField"):
  307. return (
  308. lambda value, expression, connection: None
  309. if value is None
  310. else int(value)
  311. )
  312. elif internal_type == "DecimalField":
  313. return (
  314. lambda value, expression, connection: None
  315. if value is None
  316. else Decimal(value)
  317. )
  318. return self._convert_value_noop
  319. def get_lookup(self, lookup):
  320. return self.output_field.get_lookup(lookup)
  321. def get_transform(self, name):
  322. return self.output_field.get_transform(name)
  323. def relabeled_clone(self, change_map):
  324. clone = self.copy()
  325. clone.set_source_expressions(
  326. [
  327. e.relabeled_clone(change_map) if e is not None else None
  328. for e in self.get_source_expressions()
  329. ]
  330. )
  331. return clone
  332. def replace_expressions(self, replacements):
  333. if replacement := replacements.get(self):
  334. return replacement
  335. clone = self.copy()
  336. source_expressions = clone.get_source_expressions()
  337. clone.set_source_expressions(
  338. [
  339. expr.replace_expressions(replacements) if expr else None
  340. for expr in source_expressions
  341. ]
  342. )
  343. return clone
  344. def get_refs(self):
  345. refs = set()
  346. for expr in self.get_source_expressions():
  347. refs |= expr.get_refs()
  348. return refs
  349. def copy(self):
  350. return copy.copy(self)
  351. def prefix_references(self, prefix):
  352. clone = self.copy()
  353. clone.set_source_expressions(
  354. [
  355. F(f"{prefix}{expr.name}")
  356. if isinstance(expr, F)
  357. else expr.prefix_references(prefix)
  358. for expr in self.get_source_expressions()
  359. ]
  360. )
  361. return clone
  362. def get_group_by_cols(self):
  363. if not self.contains_aggregate:
  364. return [self]
  365. cols = []
  366. for source in self.get_source_expressions():
  367. cols.extend(source.get_group_by_cols())
  368. return cols
  369. def get_source_fields(self):
  370. """Return the underlying field types used by this aggregate."""
  371. return [e._output_field_or_none for e in self.get_source_expressions()]
  372. def asc(self, **kwargs):
  373. return OrderBy(self, **kwargs)
  374. def desc(self, **kwargs):
  375. return OrderBy(self, descending=True, **kwargs)
  376. def reverse_ordering(self):
  377. return self
  378. def flatten(self):
  379. """
  380. Recursively yield this expression and all subexpressions, in
  381. depth-first order.
  382. """
  383. yield self
  384. for expr in self.get_source_expressions():
  385. if expr:
  386. if hasattr(expr, "flatten"):
  387. yield from expr.flatten()
  388. else:
  389. yield expr
  390. def select_format(self, compiler, sql, params):
  391. """
  392. Custom format for select clauses. For example, EXISTS expressions need
  393. to be wrapped in CASE WHEN on Oracle.
  394. """
  395. if hasattr(self.output_field, "select_format"):
  396. return self.output_field.select_format(compiler, sql, params)
  397. return sql, params
  398. @deconstructible
  399. class Expression(BaseExpression, Combinable):
  400. """An expression that can be combined with other expressions."""
  401. @cached_property
  402. def identity(self):
  403. constructor_signature = inspect.signature(self.__init__)
  404. args, kwargs = self._constructor_args
  405. signature = constructor_signature.bind_partial(*args, **kwargs)
  406. signature.apply_defaults()
  407. arguments = signature.arguments.items()
  408. identity = [self.__class__]
  409. for arg, value in arguments:
  410. if isinstance(value, fields.Field):
  411. if value.name and value.model:
  412. value = (value.model._meta.label, value.name)
  413. else:
  414. value = type(value)
  415. else:
  416. value = make_hashable(value)
  417. identity.append((arg, value))
  418. return tuple(identity)
  419. def __eq__(self, other):
  420. if not isinstance(other, Expression):
  421. return NotImplemented
  422. return other.identity == self.identity
  423. def __hash__(self):
  424. return hash(self.identity)
  425. # Type inference for CombinedExpression.output_field.
  426. # Missing items will result in FieldError, by design.
  427. #
  428. # The current approach for NULL is based on lowest common denominator behavior
  429. # i.e. if one of the supported databases is raising an error (rather than
  430. # return NULL) for `val <op> NULL`, then Django raises FieldError.
  431. _connector_combinations = [
  432. # Numeric operations - operands of same type.
  433. # PositiveIntegerField should take precedence over IntegerField (except
  434. # subtraction).
  435. {
  436. connector: [
  437. (
  438. fields.PositiveIntegerField,
  439. fields.PositiveIntegerField,
  440. fields.PositiveIntegerField,
  441. ),
  442. ]
  443. for connector in (
  444. Combinable.ADD,
  445. Combinable.MUL,
  446. Combinable.DIV,
  447. Combinable.MOD,
  448. Combinable.POW,
  449. )
  450. },
  451. # Other numeric operands.
  452. {
  453. connector: [
  454. (fields.IntegerField, fields.IntegerField, fields.IntegerField),
  455. (fields.FloatField, fields.FloatField, fields.FloatField),
  456. (fields.DecimalField, fields.DecimalField, fields.DecimalField),
  457. ]
  458. for connector in (
  459. Combinable.ADD,
  460. Combinable.SUB,
  461. Combinable.MUL,
  462. # Behavior for DIV with integer arguments follows Postgres/SQLite,
  463. # not MySQL/Oracle.
  464. Combinable.DIV,
  465. Combinable.MOD,
  466. Combinable.POW,
  467. )
  468. },
  469. # Numeric operations - operands of different type.
  470. {
  471. connector: [
  472. (fields.IntegerField, fields.DecimalField, fields.DecimalField),
  473. (fields.DecimalField, fields.IntegerField, fields.DecimalField),
  474. (fields.IntegerField, fields.FloatField, fields.FloatField),
  475. (fields.FloatField, fields.IntegerField, fields.FloatField),
  476. ]
  477. for connector in (
  478. Combinable.ADD,
  479. Combinable.SUB,
  480. Combinable.MUL,
  481. Combinable.DIV,
  482. Combinable.MOD,
  483. )
  484. },
  485. # Bitwise operators.
  486. {
  487. connector: [
  488. (fields.IntegerField, fields.IntegerField, fields.IntegerField),
  489. ]
  490. for connector in (
  491. Combinable.BITAND,
  492. Combinable.BITOR,
  493. Combinable.BITLEFTSHIFT,
  494. Combinable.BITRIGHTSHIFT,
  495. Combinable.BITXOR,
  496. )
  497. },
  498. # Numeric with NULL.
  499. {
  500. connector: [
  501. (field_type, NoneType, field_type),
  502. (NoneType, field_type, field_type),
  503. ]
  504. for connector in (
  505. Combinable.ADD,
  506. Combinable.SUB,
  507. Combinable.MUL,
  508. Combinable.DIV,
  509. Combinable.MOD,
  510. Combinable.POW,
  511. )
  512. for field_type in (fields.IntegerField, fields.DecimalField, fields.FloatField)
  513. },
  514. # Date/DateTimeField/DurationField/TimeField.
  515. {
  516. Combinable.ADD: [
  517. # Date/DateTimeField.
  518. (fields.DateField, fields.DurationField, fields.DateTimeField),
  519. (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
  520. (fields.DurationField, fields.DateField, fields.DateTimeField),
  521. (fields.DurationField, fields.DateTimeField, fields.DateTimeField),
  522. # DurationField.
  523. (fields.DurationField, fields.DurationField, fields.DurationField),
  524. # TimeField.
  525. (fields.TimeField, fields.DurationField, fields.TimeField),
  526. (fields.DurationField, fields.TimeField, fields.TimeField),
  527. ],
  528. },
  529. {
  530. Combinable.SUB: [
  531. # Date/DateTimeField.
  532. (fields.DateField, fields.DurationField, fields.DateTimeField),
  533. (fields.DateTimeField, fields.DurationField, fields.DateTimeField),
  534. (fields.DateField, fields.DateField, fields.DurationField),
  535. (fields.DateField, fields.DateTimeField, fields.DurationField),
  536. (fields.DateTimeField, fields.DateField, fields.DurationField),
  537. (fields.DateTimeField, fields.DateTimeField, fields.DurationField),
  538. # DurationField.
  539. (fields.DurationField, fields.DurationField, fields.DurationField),
  540. # TimeField.
  541. (fields.TimeField, fields.DurationField, fields.TimeField),
  542. (fields.TimeField, fields.TimeField, fields.DurationField),
  543. ],
  544. },
  545. ]
  546. _connector_combinators = defaultdict(list)
  547. def register_combinable_fields(lhs, connector, rhs, result):
  548. """
  549. Register combinable types:
  550. lhs <connector> rhs -> result
  551. e.g.
  552. register_combinable_fields(
  553. IntegerField, Combinable.ADD, FloatField, FloatField
  554. )
  555. """
  556. _connector_combinators[connector].append((lhs, rhs, result))
  557. for d in _connector_combinations:
  558. for connector, field_types in d.items():
  559. for lhs, rhs, result in field_types:
  560. register_combinable_fields(lhs, connector, rhs, result)
  561. @functools.lru_cache(maxsize=128)
  562. def _resolve_combined_type(connector, lhs_type, rhs_type):
  563. combinators = _connector_combinators.get(connector, ())
  564. for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
  565. if issubclass(lhs_type, combinator_lhs_type) and issubclass(
  566. rhs_type, combinator_rhs_type
  567. ):
  568. return combined_type
  569. class CombinedExpression(SQLiteNumericMixin, Expression):
  570. def __init__(self, lhs, connector, rhs, output_field=None):
  571. super().__init__(output_field=output_field)
  572. self.connector = connector
  573. self.lhs = lhs
  574. self.rhs = rhs
  575. def __repr__(self):
  576. return "<{}: {}>".format(self.__class__.__name__, self)
  577. def __str__(self):
  578. return "{} {} {}".format(self.lhs, self.connector, self.rhs)
  579. def get_source_expressions(self):
  580. return [self.lhs, self.rhs]
  581. def set_source_expressions(self, exprs):
  582. self.lhs, self.rhs = exprs
  583. def _resolve_output_field(self):
  584. # We avoid using super() here for reasons given in
  585. # Expression._resolve_output_field()
  586. combined_type = _resolve_combined_type(
  587. self.connector,
  588. type(self.lhs._output_field_or_none),
  589. type(self.rhs._output_field_or_none),
  590. )
  591. if combined_type is None:
  592. raise FieldError(
  593. f"Cannot infer type of {self.connector!r} expression involving these "
  594. f"types: {self.lhs.output_field.__class__.__name__}, "
  595. f"{self.rhs.output_field.__class__.__name__}. You must set "
  596. f"output_field."
  597. )
  598. return combined_type()
  599. def as_sql(self, compiler, connection):
  600. expressions = []
  601. expression_params = []
  602. sql, params = compiler.compile(self.lhs)
  603. expressions.append(sql)
  604. expression_params.extend(params)
  605. sql, params = compiler.compile(self.rhs)
  606. expressions.append(sql)
  607. expression_params.extend(params)
  608. # order of precedence
  609. expression_wrapper = "(%s)"
  610. sql = connection.ops.combine_expression(self.connector, expressions)
  611. return expression_wrapper % sql, expression_params
  612. def resolve_expression(
  613. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  614. ):
  615. lhs = self.lhs.resolve_expression(
  616. query, allow_joins, reuse, summarize, for_save
  617. )
  618. rhs = self.rhs.resolve_expression(
  619. query, allow_joins, reuse, summarize, for_save
  620. )
  621. if not isinstance(self, (DurationExpression, TemporalSubtraction)):
  622. try:
  623. lhs_type = lhs.output_field.get_internal_type()
  624. except (AttributeError, FieldError):
  625. lhs_type = None
  626. try:
  627. rhs_type = rhs.output_field.get_internal_type()
  628. except (AttributeError, FieldError):
  629. rhs_type = None
  630. if "DurationField" in {lhs_type, rhs_type} and lhs_type != rhs_type:
  631. return DurationExpression(
  632. self.lhs, self.connector, self.rhs
  633. ).resolve_expression(
  634. query,
  635. allow_joins,
  636. reuse,
  637. summarize,
  638. for_save,
  639. )
  640. datetime_fields = {"DateField", "DateTimeField", "TimeField"}
  641. if (
  642. self.connector == self.SUB
  643. and lhs_type in datetime_fields
  644. and lhs_type == rhs_type
  645. ):
  646. return TemporalSubtraction(self.lhs, self.rhs).resolve_expression(
  647. query,
  648. allow_joins,
  649. reuse,
  650. summarize,
  651. for_save,
  652. )
  653. c = self.copy()
  654. c.is_summary = summarize
  655. c.lhs = lhs
  656. c.rhs = rhs
  657. return c
  658. @cached_property
  659. def allowed_default(self):
  660. return self.lhs.allowed_default and self.rhs.allowed_default
  661. class DurationExpression(CombinedExpression):
  662. def compile(self, side, compiler, connection):
  663. try:
  664. output = side.output_field
  665. except FieldError:
  666. pass
  667. else:
  668. if output.get_internal_type() == "DurationField":
  669. sql, params = compiler.compile(side)
  670. return connection.ops.format_for_duration_arithmetic(sql), params
  671. return compiler.compile(side)
  672. def as_sql(self, compiler, connection):
  673. if connection.features.has_native_duration_field:
  674. return super().as_sql(compiler, connection)
  675. connection.ops.check_expression_support(self)
  676. expressions = []
  677. expression_params = []
  678. sql, params = self.compile(self.lhs, compiler, connection)
  679. expressions.append(sql)
  680. expression_params.extend(params)
  681. sql, params = self.compile(self.rhs, compiler, connection)
  682. expressions.append(sql)
  683. expression_params.extend(params)
  684. # order of precedence
  685. expression_wrapper = "(%s)"
  686. sql = connection.ops.combine_duration_expression(self.connector, expressions)
  687. return expression_wrapper % sql, expression_params
  688. def as_sqlite(self, compiler, connection, **extra_context):
  689. sql, params = self.as_sql(compiler, connection, **extra_context)
  690. if self.connector in {Combinable.MUL, Combinable.DIV}:
  691. try:
  692. lhs_type = self.lhs.output_field.get_internal_type()
  693. rhs_type = self.rhs.output_field.get_internal_type()
  694. except (AttributeError, FieldError):
  695. pass
  696. else:
  697. allowed_fields = {
  698. "DecimalField",
  699. "DurationField",
  700. "FloatField",
  701. "IntegerField",
  702. }
  703. if lhs_type not in allowed_fields or rhs_type not in allowed_fields:
  704. raise DatabaseError(
  705. f"Invalid arguments for operator {self.connector}."
  706. )
  707. return sql, params
  708. class TemporalSubtraction(CombinedExpression):
  709. output_field = fields.DurationField()
  710. def __init__(self, lhs, rhs):
  711. super().__init__(lhs, self.SUB, rhs)
  712. def as_sql(self, compiler, connection):
  713. connection.ops.check_expression_support(self)
  714. lhs = compiler.compile(self.lhs)
  715. rhs = compiler.compile(self.rhs)
  716. return connection.ops.subtract_temporals(
  717. self.lhs.output_field.get_internal_type(), lhs, rhs
  718. )
  719. @deconstructible(path="django.db.models.F")
  720. class F(Combinable):
  721. """An object capable of resolving references to existing query objects."""
  722. allowed_default = False
  723. def __init__(self, name):
  724. """
  725. Arguments:
  726. * name: the name of the field this expression references
  727. """
  728. self.name = name
  729. def __repr__(self):
  730. return "{}({})".format(self.__class__.__name__, self.name)
  731. def resolve_expression(
  732. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  733. ):
  734. return query.resolve_ref(self.name, allow_joins, reuse, summarize)
  735. def replace_expressions(self, replacements):
  736. return replacements.get(self, self)
  737. def asc(self, **kwargs):
  738. return OrderBy(self, **kwargs)
  739. def desc(self, **kwargs):
  740. return OrderBy(self, descending=True, **kwargs)
  741. def __eq__(self, other):
  742. return self.__class__ == other.__class__ and self.name == other.name
  743. def __hash__(self):
  744. return hash(self.name)
  745. def copy(self):
  746. return copy.copy(self)
  747. class ResolvedOuterRef(F):
  748. """
  749. An object that contains a reference to an outer query.
  750. In this case, the reference to the outer query has been resolved because
  751. the inner query has been used as a subquery.
  752. """
  753. contains_aggregate = False
  754. contains_over_clause = False
  755. def as_sql(self, *args, **kwargs):
  756. raise ValueError(
  757. "This queryset contains a reference to an outer query and may "
  758. "only be used in a subquery."
  759. )
  760. def resolve_expression(self, *args, **kwargs):
  761. col = super().resolve_expression(*args, **kwargs)
  762. if col.contains_over_clause:
  763. raise NotSupportedError(
  764. f"Referencing outer query window expression is not supported: "
  765. f"{self.name}."
  766. )
  767. # FIXME: Rename possibly_multivalued to multivalued and fix detection
  768. # for non-multivalued JOINs (e.g. foreign key fields). This should take
  769. # into account only many-to-many and one-to-many relationships.
  770. col.possibly_multivalued = LOOKUP_SEP in self.name
  771. return col
  772. def relabeled_clone(self, relabels):
  773. return self
  774. def get_group_by_cols(self):
  775. return []
  776. class OuterRef(F):
  777. contains_aggregate = False
  778. contains_over_clause = False
  779. def resolve_expression(self, *args, **kwargs):
  780. if isinstance(self.name, self.__class__):
  781. return self.name
  782. return ResolvedOuterRef(self.name)
  783. def relabeled_clone(self, relabels):
  784. return self
  785. @deconstructible(path="django.db.models.Func")
  786. class Func(SQLiteNumericMixin, Expression):
  787. """An SQL function call."""
  788. function = None
  789. template = "%(function)s(%(expressions)s)"
  790. arg_joiner = ", "
  791. arity = None # The number of arguments the function accepts.
  792. def __init__(self, *expressions, output_field=None, **extra):
  793. if self.arity is not None and len(expressions) != self.arity:
  794. raise TypeError(
  795. "'%s' takes exactly %s %s (%s given)"
  796. % (
  797. self.__class__.__name__,
  798. self.arity,
  799. "argument" if self.arity == 1 else "arguments",
  800. len(expressions),
  801. )
  802. )
  803. super().__init__(output_field=output_field)
  804. self.source_expressions = self._parse_expressions(*expressions)
  805. self.extra = extra
  806. def __repr__(self):
  807. args = self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  808. extra = {**self.extra, **self._get_repr_options()}
  809. if extra:
  810. extra = ", ".join(
  811. str(key) + "=" + str(val) for key, val in sorted(extra.items())
  812. )
  813. return "{}({}, {})".format(self.__class__.__name__, args, extra)
  814. return "{}({})".format(self.__class__.__name__, args)
  815. def _get_repr_options(self):
  816. """Return a dict of extra __init__() options to include in the repr."""
  817. return {}
  818. def get_source_expressions(self):
  819. return self.source_expressions
  820. def set_source_expressions(self, exprs):
  821. self.source_expressions = exprs
  822. def resolve_expression(
  823. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  824. ):
  825. c = self.copy()
  826. c.is_summary = summarize
  827. for pos, arg in enumerate(c.source_expressions):
  828. c.source_expressions[pos] = arg.resolve_expression(
  829. query, allow_joins, reuse, summarize, for_save
  830. )
  831. return c
    def as_sql(
        self,
        compiler,
        connection,
        function=None,
        template=None,
        arg_joiner=None,
        **extra_context,
    ):
        """
        Compile this function call to SQL.

        Each source expression is compiled in order. The SQL fragments are
        joined with ``arg_joiner`` and substituted into ``template`` under both
        the ``expressions`` and ``field`` keys.

        Return ``(sql, params)``. ``params`` is the concatenation of every
        argument's params, in argument order.

        Raise EmptyResultSet if an argument raises it and does not define an
        ``empty_result_set_value``.
        """
        connection.ops.check_expression_support(self)
        sql_parts = []
        params = []
        for arg in self.source_expressions:
            try:
                arg_sql, arg_params = compiler.compile(arg)
            except EmptyResultSet:
                # Substitute the argument's declared fallback value, if any;
                # NotImplemented (the default here) means "no fallback", so
                # the exception is re-raised.
                empty_result_set_value = getattr(
                    arg, "empty_result_set_value", NotImplemented
                )
                if empty_result_set_value is NotImplemented:
                    raise
                arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
            except FullResultSet:
                # Replace the argument with a literal True value.
                arg_sql, arg_params = compiler.compile(Value(True))
            sql_parts.append(arg_sql)
            params.extend(arg_params)
        data = {**self.extra, **extra_context}
        # Use the first supplied value in this order: the parameter to this
        # method, a value supplied in __init__()'s **extra (the value in
        # `data`), or the value defined on the class.
        if function is not None:
            data["function"] = function
        else:
            data.setdefault("function", self.function)
        template = template or data.get("template", self.template)
        arg_joiner = arg_joiner or data.get("arg_joiner", self.arg_joiner)
        # Exposed under two keys so templates may use either placeholder.
        data["expressions"] = data["field"] = arg_joiner.join(sql_parts)
        return template % data, params
  870. def copy(self):
  871. copy = super().copy()
  872. copy.source_expressions = self.source_expressions[:]
  873. copy.extra = self.extra.copy()
  874. return copy
  875. @cached_property
  876. def allowed_default(self):
  877. return all(expression.allowed_default for expression in self.source_expressions)
  878. @deconstructible(path="django.db.models.Value")
  879. class Value(SQLiteNumericMixin, Expression):
  880. """Represent a wrapped value as a node within an expression."""
  881. # Provide a default value for `for_save` in order to allow unresolved
  882. # instances to be compiled until a decision is taken in #25425.
  883. for_save = False
  884. allowed_default = True
  885. def __init__(self, value, output_field=None):
  886. """
  887. Arguments:
  888. * value: the value this expression represents. The value will be
  889. added into the sql parameter list and properly quoted.
  890. * output_field: an instance of the model field type that this
  891. expression will return, such as IntegerField() or CharField().
  892. """
  893. super().__init__(output_field=output_field)
  894. self.value = value
  895. def __repr__(self):
  896. return f"{self.__class__.__name__}({self.value!r})"
  897. def as_sql(self, compiler, connection):
  898. connection.ops.check_expression_support(self)
  899. val = self.value
  900. output_field = self._output_field_or_none
  901. if output_field is not None:
  902. if self.for_save:
  903. val = output_field.get_db_prep_save(val, connection=connection)
  904. else:
  905. val = output_field.get_db_prep_value(val, connection=connection)
  906. if hasattr(output_field, "get_placeholder"):
  907. return output_field.get_placeholder(val, compiler, connection), [val]
  908. if val is None:
  909. # oracledb does not always convert None to the appropriate
  910. # NULL type (like in case expressions using numbers), so we
  911. # use a literal SQL NULL
  912. return "NULL", []
  913. return "%s", [val]
  914. def resolve_expression(
  915. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  916. ):
  917. c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
  918. c.for_save = for_save
  919. return c
  920. def get_group_by_cols(self):
  921. return []
  922. def _resolve_output_field(self):
  923. if isinstance(self.value, str):
  924. return fields.CharField()
  925. if isinstance(self.value, bool):
  926. return fields.BooleanField()
  927. if isinstance(self.value, int):
  928. return fields.IntegerField()
  929. if isinstance(self.value, float):
  930. return fields.FloatField()
  931. if isinstance(self.value, datetime.datetime):
  932. return fields.DateTimeField()
  933. if isinstance(self.value, datetime.date):
  934. return fields.DateField()
  935. if isinstance(self.value, datetime.time):
  936. return fields.TimeField()
  937. if isinstance(self.value, datetime.timedelta):
  938. return fields.DurationField()
  939. if isinstance(self.value, Decimal):
  940. return fields.DecimalField()
  941. if isinstance(self.value, bytes):
  942. return fields.BinaryField()
  943. if isinstance(self.value, UUID):
  944. return fields.UUIDField()
  945. @property
  946. def empty_result_set_value(self):
  947. return self.value
  948. class RawSQL(Expression):
  949. allowed_default = True
  950. def __init__(self, sql, params, output_field=None):
  951. if output_field is None:
  952. output_field = fields.Field()
  953. self.sql, self.params = sql, params
  954. super().__init__(output_field=output_field)
  955. def __repr__(self):
  956. return "{}({}, {})".format(self.__class__.__name__, self.sql, self.params)
  957. def as_sql(self, compiler, connection):
  958. return "(%s)" % self.sql, self.params
  959. def get_group_by_cols(self):
  960. return [self]
  961. def resolve_expression(
  962. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  963. ):
  964. # Resolve parents fields used in raw SQL.
  965. if query.model:
  966. for parent in query.model._meta.get_parent_list():
  967. for parent_field in parent._meta.local_fields:
  968. _, column_name = parent_field.get_attname_column()
  969. if column_name.lower() in self.sql.lower():
  970. query.resolve_ref(
  971. parent_field.name, allow_joins, reuse, summarize
  972. )
  973. break
  974. return super().resolve_expression(
  975. query, allow_joins, reuse, summarize, for_save
  976. )
  977. class Star(Expression):
  978. def __repr__(self):
  979. return "'*'"
  980. def as_sql(self, compiler, connection):
  981. return "*", []
  982. class DatabaseDefault(Expression):
  983. """Placeholder expression for the database default in an insert query."""
  984. def as_sql(self, compiler, connection):
  985. return "DEFAULT", []
  986. class Col(Expression):
  987. contains_column_references = True
  988. possibly_multivalued = False
  989. def __init__(self, alias, target, output_field=None):
  990. if output_field is None:
  991. output_field = target
  992. super().__init__(output_field=output_field)
  993. self.alias, self.target = alias, target
  994. def __repr__(self):
  995. alias, target = self.alias, self.target
  996. identifiers = (alias, str(target)) if alias else (str(target),)
  997. return "{}({})".format(self.__class__.__name__, ", ".join(identifiers))
  998. def as_sql(self, compiler, connection):
  999. alias, column = self.alias, self.target.column
  1000. identifiers = (alias, column) if alias else (column,)
  1001. sql = ".".join(map(compiler.quote_name_unless_alias, identifiers))
  1002. return sql, []
  1003. def relabeled_clone(self, relabels):
  1004. if self.alias is None:
  1005. return self
  1006. return self.__class__(
  1007. relabels.get(self.alias, self.alias), self.target, self.output_field
  1008. )
  1009. def get_group_by_cols(self):
  1010. return [self]
  1011. def get_db_converters(self, connection):
  1012. if self.target == self.output_field:
  1013. return self.output_field.get_db_converters(connection)
  1014. return self.output_field.get_db_converters(
  1015. connection
  1016. ) + self.target.get_db_converters(connection)
  1017. class Ref(Expression):
  1018. """
  1019. Reference to column alias of the query. For example, Ref('sum_cost') in
  1020. qs.annotate(sum_cost=Sum('cost')) query.
  1021. """
  1022. def __init__(self, refs, source):
  1023. super().__init__()
  1024. self.refs, self.source = refs, source
  1025. def __repr__(self):
  1026. return "{}({}, {})".format(self.__class__.__name__, self.refs, self.source)
  1027. def get_source_expressions(self):
  1028. return [self.source]
  1029. def set_source_expressions(self, exprs):
  1030. (self.source,) = exprs
  1031. def resolve_expression(
  1032. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1033. ):
  1034. # The sub-expression `source` has already been resolved, as this is
  1035. # just a reference to the name of `source`.
  1036. return self
  1037. def get_refs(self):
  1038. return {self.refs}
  1039. def relabeled_clone(self, relabels):
  1040. clone = self.copy()
  1041. clone.source = self.source.relabeled_clone(relabels)
  1042. return clone
  1043. def as_sql(self, compiler, connection):
  1044. return connection.ops.quote_name(self.refs), []
  1045. def get_group_by_cols(self):
  1046. return [self]
  1047. class ExpressionList(Func):
  1048. """
  1049. An expression containing multiple expressions. Can be used to provide a
  1050. list of expressions as an argument to another expression, like a partition
  1051. clause.
  1052. """
  1053. template = "%(expressions)s"
  1054. def __init__(self, *expressions, **extra):
  1055. if not expressions:
  1056. raise ValueError(
  1057. "%s requires at least one expression." % self.__class__.__name__
  1058. )
  1059. super().__init__(*expressions, **extra)
  1060. def __str__(self):
  1061. return self.arg_joiner.join(str(arg) for arg in self.source_expressions)
  1062. def as_sqlite(self, compiler, connection, **extra_context):
  1063. # Casting to numeric is unnecessary.
  1064. return self.as_sql(compiler, connection, **extra_context)
  1065. def get_group_by_cols(self):
  1066. group_by_cols = []
  1067. for partition in self.get_source_expressions():
  1068. group_by_cols.extend(partition.get_group_by_cols())
  1069. return group_by_cols
  1070. class OrderByList(Func):
  1071. allowed_default = False
  1072. template = "ORDER BY %(expressions)s"
  1073. def __init__(self, *expressions, **extra):
  1074. expressions = (
  1075. (
  1076. OrderBy(F(expr[1:]), descending=True)
  1077. if isinstance(expr, str) and expr[0] == "-"
  1078. else expr
  1079. )
  1080. for expr in expressions
  1081. )
  1082. super().__init__(*expressions, **extra)
  1083. def as_sql(self, *args, **kwargs):
  1084. if not self.source_expressions:
  1085. return "", ()
  1086. return super().as_sql(*args, **kwargs)
  1087. def get_group_by_cols(self):
  1088. group_by_cols = []
  1089. for order_by in self.get_source_expressions():
  1090. group_by_cols.extend(order_by.get_group_by_cols())
  1091. return group_by_cols
  1092. @deconstructible(path="django.db.models.ExpressionWrapper")
  1093. class ExpressionWrapper(SQLiteNumericMixin, Expression):
  1094. """
  1095. An expression that can wrap another expression so that it can provide
  1096. extra context to the inner expression, such as the output_field.
  1097. """
  1098. def __init__(self, expression, output_field):
  1099. super().__init__(output_field=output_field)
  1100. self.expression = expression
  1101. def set_source_expressions(self, exprs):
  1102. self.expression = exprs[0]
  1103. def get_source_expressions(self):
  1104. return [self.expression]
  1105. def get_group_by_cols(self):
  1106. if isinstance(self.expression, Expression):
  1107. expression = self.expression.copy()
  1108. expression.output_field = self.output_field
  1109. return expression.get_group_by_cols()
  1110. # For non-expressions e.g. an SQL WHERE clause, the entire
  1111. # `expression` must be included in the GROUP BY clause.
  1112. return super().get_group_by_cols()
  1113. def as_sql(self, compiler, connection):
  1114. return compiler.compile(self.expression)
  1115. def __repr__(self):
  1116. return "{}({})".format(self.__class__.__name__, self.expression)
  1117. @property
  1118. def allowed_default(self):
  1119. return self.expression.allowed_default
  1120. class NegatedExpression(ExpressionWrapper):
  1121. """The logical negation of a conditional expression."""
  1122. def __init__(self, expression):
  1123. super().__init__(expression, output_field=fields.BooleanField())
  1124. def __invert__(self):
  1125. return self.expression.copy()
  1126. def as_sql(self, compiler, connection):
  1127. try:
  1128. sql, params = super().as_sql(compiler, connection)
  1129. except EmptyResultSet:
  1130. features = compiler.connection.features
  1131. if not features.supports_boolean_expr_in_select_clause:
  1132. return "1=1", ()
  1133. return compiler.compile(Value(True))
  1134. ops = compiler.connection.ops
  1135. # Some database backends (e.g. Oracle) don't allow EXISTS() and filters
  1136. # to be compared to another expression unless they're wrapped in a CASE
  1137. # WHEN.
  1138. if not ops.conditional_expression_supported_in_where_clause(self.expression):
  1139. return f"CASE WHEN {sql} = 0 THEN 1 ELSE 0 END", params
  1140. return f"NOT {sql}", params
  1141. def resolve_expression(
  1142. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1143. ):
  1144. resolved = super().resolve_expression(
  1145. query, allow_joins, reuse, summarize, for_save
  1146. )
  1147. if not getattr(resolved.expression, "conditional", False):
  1148. raise TypeError("Cannot negate non-conditional expressions.")
  1149. return resolved
  1150. def select_format(self, compiler, sql, params):
  1151. # Wrap boolean expressions with a CASE WHEN expression if a database
  1152. # backend (e.g. Oracle) doesn't support boolean expression in SELECT or
  1153. # GROUP BY list.
  1154. expression_supported_in_where_clause = (
  1155. compiler.connection.ops.conditional_expression_supported_in_where_clause
  1156. )
  1157. if (
  1158. not compiler.connection.features.supports_boolean_expr_in_select_clause
  1159. # Avoid double wrapping.
  1160. and expression_supported_in_where_clause(self.expression)
  1161. ):
  1162. sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
  1163. return sql, params
  1164. @deconstructible(path="django.db.models.When")
  1165. class When(Expression):
  1166. template = "WHEN %(condition)s THEN %(result)s"
  1167. # This isn't a complete conditional expression, must be used in Case().
  1168. conditional = False
  1169. def __init__(self, condition=None, then=None, **lookups):
  1170. if lookups:
  1171. if condition is None:
  1172. condition, lookups = Q(**lookups), None
  1173. elif getattr(condition, "conditional", False):
  1174. condition, lookups = Q(condition, **lookups), None
  1175. if condition is None or not getattr(condition, "conditional", False) or lookups:
  1176. raise TypeError(
  1177. "When() supports a Q object, a boolean expression, or lookups "
  1178. "as a condition."
  1179. )
  1180. if isinstance(condition, Q) and not condition:
  1181. raise ValueError("An empty Q() can't be used as a When() condition.")
  1182. super().__init__(output_field=None)
  1183. self.condition = condition
  1184. self.result = self._parse_expressions(then)[0]
  1185. def __str__(self):
  1186. return "WHEN %r THEN %r" % (self.condition, self.result)
  1187. def __repr__(self):
  1188. return "<%s: %s>" % (self.__class__.__name__, self)
  1189. def get_source_expressions(self):
  1190. return [self.condition, self.result]
  1191. def set_source_expressions(self, exprs):
  1192. self.condition, self.result = exprs
  1193. def get_source_fields(self):
  1194. # We're only interested in the fields of the result expressions.
  1195. return [self.result._output_field_or_none]
  1196. def resolve_expression(
  1197. self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
  1198. ):
  1199. c = self.copy()
  1200. c.is_summary = summarize
  1201. if hasattr(c.condition, "resolve_expression"):
  1202. c.condition = c.condition.resolve_expression(
  1203. query, allow_joins, reuse, summarize, False
  1204. )
  1205. c.result = c.result.resolve_expression(
  1206. query, allow_joins, reuse, summarize, for_save
  1207. )
  1208. return c
  1209. def as_sql(self, compiler, connection, template=None, **extra_context):
  1210. connection.ops.check_expression_support(self)
  1211. template_params = extra_context
  1212. sql_params = []
  1213. condition_sql, condition_params = compiler.compile(self.condition)
  1214. template_params["condition"] = condition_sql
  1215. result_sql, result_params = compiler.compile(self.result)
  1216. template_params["result"] = result_sql
  1217. template = template or self.template
  1218. return template % template_params, (
  1219. *sql_params,
  1220. *condition_params,
  1221. *result_params,
  1222. )
  1223. def get_group_by_cols(self):
  1224. # This is not a complete expression and cannot be used in GROUP BY.
  1225. cols = []
  1226. for source in self.get_source_expressions():
  1227. cols.extend(source.get_group_by_cols())
  1228. return cols
  1229. @cached_property
  1230. def allowed_default(self):
  1231. return self.condition.allowed_default and self.result.allowed_default
@deconstructible(path="django.db.models.Case")
class Case(SQLiteNumericMixin, Expression):
    """
    An SQL searched CASE expression:

        CASE
            WHEN n > 0
                THEN 'positive'
            WHEN n < 0
                THEN 'negative'
            ELSE 'zero'
        END

    Positional arguments must be When() instances. ``default`` supplies the
    ELSE branch. Extra keyword arguments are made available as template
    parameters in as_sql().
    """

    template = "CASE %(cases)s ELSE %(default)s END"
    case_joiner = " "

    def __init__(self, *cases, default=None, output_field=None, **extra):
        if not all(isinstance(case, When) for case in cases):
            raise TypeError("Positional arguments must all be When objects.")
        super().__init__(output_field)
        self.cases = list(cases)
        self.default = self._parse_expressions(default)[0]
        self.extra = extra

    def __str__(self):
        return "CASE %s, ELSE %r" % (
            ", ".join(str(c) for c in self.cases),
            self.default,
        )

    def __repr__(self):
        return "<%s: %s>" % (self.__class__.__name__, self)

    def get_source_expressions(self):
        # The default is always the last source expression.
        return self.cases + [self.default]

    def set_source_expressions(self, exprs):
        *self.cases, self.default = exprs

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        c = self.copy()
        c.is_summary = summarize
        for pos, case in enumerate(c.cases):
            c.cases[pos] = case.resolve_expression(
                query, allow_joins, reuse, summarize, for_save
            )
        c.default = c.default.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        return c

    def copy(self):
        # Give the copy its own list so resolve_expression() can replace
        # branches without touching the original.
        c = super().copy()
        c.cases = c.cases[:]
        return c

    def as_sql(
        self, compiler, connection, template=None, case_joiner=None, **extra_context
    ):
        """
        Compile to ``CASE <branches> ELSE <default> END``.

        A branch whose compilation raises EmptyResultSet is dropped. The first
        branch raising FullResultSet has its result used as the ELSE value,
        and any later branches are dropped. If no branch remains, only the
        default SQL is returned. When an output field is known, the SQL is
        wrapped with the backend's unification_cast_sql().
        """
        connection.ops.check_expression_support(self)
        if not self.cases:
            return compiler.compile(self.default)
        template_params = {**self.extra, **extra_context}
        case_parts = []
        sql_params = []
        default_sql, default_params = compiler.compile(self.default)
        for case in self.cases:
            try:
                case_sql, case_params = compiler.compile(case)
            except EmptyResultSet:
                continue
            except FullResultSet:
                default_sql, default_params = compiler.compile(case.result)
                break
            case_parts.append(case_sql)
            sql_params.extend(case_params)
        if not case_parts:
            return default_sql, default_params
        case_joiner = case_joiner or self.case_joiner
        template_params["cases"] = case_joiner.join(case_parts)
        template_params["default"] = default_sql
        # The ELSE params come after all WHEN params, matching the SQL order.
        sql_params.extend(default_params)
        template = template or template_params.get("template", self.template)
        sql = template % template_params
        if self._output_field_or_none is not None:
            sql = connection.ops.unification_cast_sql(self.output_field) % sql
        return sql, sql_params

    def get_group_by_cols(self):
        if not self.cases:
            return self.default.get_group_by_cols()
        return super().get_group_by_cols()

    @cached_property
    def allowed_default(self):
        return self.default.allowed_default and all(
            case_.allowed_default for case_ in self.cases
        )
  1321. class Subquery(BaseExpression, Combinable):
  1322. """
  1323. An explicit subquery. It may contain OuterRef() references to the outer
  1324. query which will be resolved when it is applied to that query.
  1325. """
  1326. template = "(%(subquery)s)"
  1327. contains_aggregate = False
  1328. empty_result_set_value = None
  1329. subquery = True
  1330. def __init__(self, queryset, output_field=None, **extra):
  1331. # Allow the usage of both QuerySet and sql.Query objects.
  1332. self.query = getattr(queryset, "query", queryset).clone()
  1333. self.query.subquery = True
  1334. self.extra = extra
  1335. super().__init__(output_field)
  1336. def get_source_expressions(self):
  1337. return [self.query]
  1338. def set_source_expressions(self, exprs):
  1339. self.query = exprs[0]
  1340. def _resolve_output_field(self):
  1341. return self.query.output_field
  1342. def copy(self):
  1343. clone = super().copy()
  1344. clone.query = clone.query.clone()
  1345. return clone
  1346. @property
  1347. def external_aliases(self):
  1348. return self.query.external_aliases
  1349. def get_external_cols(self):
  1350. return self.query.get_external_cols()
  1351. def as_sql(self, compiler, connection, template=None, **extra_context):
  1352. connection.ops.check_expression_support(self)
  1353. template_params = {**self.extra, **extra_context}
  1354. subquery_sql, sql_params = self.query.as_sql(compiler, connection)
  1355. template_params["subquery"] = subquery_sql[1:-1]
  1356. template = template or template_params.get("template", self.template)
  1357. sql = template % template_params
  1358. return sql, sql_params
  1359. def get_group_by_cols(self):
  1360. return self.query.get_group_by_cols(wrapper=self)
  1361. class Exists(Subquery):
  1362. template = "EXISTS(%(subquery)s)"
  1363. output_field = fields.BooleanField()
  1364. empty_result_set_value = False
  1365. def __init__(self, queryset, **kwargs):
  1366. super().__init__(queryset, **kwargs)
  1367. self.query = self.query.exists()
  1368. def select_format(self, compiler, sql, params):
  1369. # Wrap EXISTS() with a CASE WHEN expression if a database backend
  1370. # (e.g. Oracle) doesn't support boolean expression in SELECT or GROUP
  1371. # BY list.
  1372. if not compiler.connection.features.supports_boolean_expr_in_select_clause:
  1373. sql = "CASE WHEN {} THEN 1 ELSE 0 END".format(sql)
  1374. return sql, params
  1375. def as_sql(self, compiler, *args, **kwargs):
  1376. try:
  1377. return super().as_sql(compiler, *args, **kwargs)
  1378. except EmptyResultSet:
  1379. features = compiler.connection.features
  1380. if not features.supports_boolean_expr_in_select_clause:
  1381. return "1=0", ()
  1382. return compiler.compile(Value(False))
@deconstructible(path="django.db.models.OrderBy")
class OrderBy(Expression):
    """
    One ORDER BY term: ``<expression> ASC|DESC`` with optional NULL placement.

    ``nulls_first`` and ``nulls_last`` accept only True or None, and at most
    one of them may be set.
    """

    template = "%(expression)s %(ordering)s"
    conditional = False

    def __init__(self, expression, descending=False, nulls_first=None, nulls_last=None):
        if nulls_first and nulls_last:
            raise ValueError("nulls_first and nulls_last are mutually exclusive")
        if nulls_first is False or nulls_last is False:
            raise ValueError("nulls_first and nulls_last values must be True or None.")
        self.nulls_first = nulls_first
        self.nulls_last = nulls_last
        self.descending = descending
        if not hasattr(expression, "resolve_expression"):
            raise ValueError("expression must be an expression type")
        self.expression = expression

    def __repr__(self):
        return "{}({}, descending={})".format(
            self.__class__.__name__, self.expression, self.descending
        )

    def set_source_expressions(self, exprs):
        self.expression = exprs[0]

    def get_source_expressions(self):
        return [self.expression]

    def as_sql(self, compiler, connection, template=None, **extra_context):
        template = template or self.template
        if connection.features.supports_order_by_nulls_modifier:
            # Native support: append NULLS LAST / NULLS FIRST.
            if self.nulls_last:
                template = "%s NULLS LAST" % template
            elif self.nulls_first:
                template = "%s NULLS FIRST" % template
        else:
            # Emulation: sort first on "expr IS [NOT] NULL". This is skipped
            # when the backend's default NULL placement (flagged by
            # order_by_nulls_first, presumably "NULLs sort first when
            # ascending") already gives the requested order.
            if self.nulls_last and not (
                self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NULL, %s" % template
            elif self.nulls_first and not (
                not self.descending and connection.features.order_by_nulls_first
            ):
                template = "%%(expression)s IS NOT NULL, %s" % template
        connection.ops.check_expression_support(self)
        expression_sql, params = compiler.compile(self.expression)
        placeholders = {
            "expression": expression_sql,
            "ordering": "DESC" if self.descending else "ASC",
            **extra_context,
        }
        # The expression SQL may appear more than once in the template (see
        # the emulation above), so its params are repeated to match.
        params *= template.count("%(expression)s")
        return (template % placeholders).rstrip(), params

    def as_oracle(self, compiler, connection):
        # Oracle doesn't allow ORDER BY EXISTS() or filters unless it's wrapped
        # in a CASE WHEN.
        if connection.ops.conditional_expression_supported_in_where_clause(
            self.expression
        ):
            copy = self.copy()
            copy.expression = Case(
                When(self.expression, then=True),
                default=False,
            )
            return copy.as_sql(compiler, connection)
        return self.as_sql(compiler, connection)

    def get_group_by_cols(self):
        cols = []
        for source in self.get_source_expressions():
            cols.extend(source.get_group_by_cols())
        return cols

    def reverse_ordering(self):
        # Mutates in place: flips the direction and swaps any requested NULL
        # placement. Returns self.
        self.descending = not self.descending
        if self.nulls_first:
            self.nulls_last = True
            self.nulls_first = None
        elif self.nulls_last:
            self.nulls_first = True
            self.nulls_last = None
        return self

    def asc(self):
        # Mutates in place; returns None.
        self.descending = False

    def desc(self):
        # Mutates in place; returns None.
        self.descending = True
  1462. class Window(SQLiteNumericMixin, Expression):
  1463. template = "%(expression)s OVER (%(window)s)"
  1464. # Although the main expression may either be an aggregate or an
  1465. # expression with an aggregate function, the GROUP BY that will
  1466. # be introduced in the query as a result is not desired.
  1467. contains_aggregate = False
  1468. contains_over_clause = True
  1469. def __init__(
  1470. self,
  1471. expression,
  1472. partition_by=None,
  1473. order_by=None,
  1474. frame=None,
  1475. output_field=None,
  1476. ):
  1477. self.partition_by = partition_by
  1478. self.order_by = order_by
  1479. self.frame = frame
  1480. if not getattr(expression, "window_compatible", False):
  1481. raise ValueError(
  1482. "Expression '%s' isn't compatible with OVER clauses."
  1483. % expression.__class__.__name__
  1484. )
  1485. if self.partition_by is not None:
  1486. if not isinstance(self.partition_by, (tuple, list)):
  1487. self.partition_by = (self.partition_by,)
  1488. self.partition_by = ExpressionList(*self.partition_by)
  1489. if self.order_by is not None:
  1490. if isinstance(self.order_by, (list, tuple)):
  1491. self.order_by = OrderByList(*self.order_by)
  1492. elif isinstance(self.order_by, (BaseExpression, str)):
  1493. self.order_by = OrderByList(self.order_by)
  1494. else:
  1495. raise ValueError(
  1496. "Window.order_by must be either a string reference to a "
  1497. "field, an expression, or a list or tuple of them."
  1498. )
  1499. super().__init__(output_field=output_field)
  1500. self.source_expression = self._parse_expressions(expression)[0]
  1501. def _resolve_output_field(self):
  1502. return self.source_expression.output_field
  1503. def get_source_expressions(self):
  1504. return [self.source_expression, self.partition_by, self.order_by, self.frame]
  1505. def set_source_expressions(self, exprs):
  1506. self.source_expression, self.partition_by, self.order_by, self.frame = exprs
  1507. def as_sql(self, compiler, connection, template=None):
  1508. connection.ops.check_expression_support(self)
  1509. if not connection.features.supports_over_clause:
  1510. raise NotSupportedError("This backend does not support window expressions.")
  1511. expr_sql, params = compiler.compile(self.source_expression)
  1512. window_sql, window_params = [], ()
  1513. if self.partition_by is not None:
  1514. sql_expr, sql_params = self.partition_by.as_sql(
  1515. compiler=compiler,
  1516. connection=connection,
  1517. template="PARTITION BY %(expressions)s",
  1518. )
  1519. window_sql.append(sql_expr)
  1520. window_params += tuple(sql_params)
  1521. if self.order_by is not None:
  1522. order_sql, order_params = compiler.compile(self.order_by)
  1523. window_sql.append(order_sql)
  1524. window_params += tuple(order_params)
  1525. if self.frame:
  1526. frame_sql, frame_params = compiler.compile(self.frame)
  1527. window_sql.append(frame_sql)
  1528. window_params += tuple(frame_params)
  1529. template = template or self.template
  1530. return (
  1531. template % {"expression": expr_sql, "window": " ".join(window_sql).strip()},
  1532. (*params, *window_params),
  1533. )
  1534. def as_sqlite(self, compiler, connection):
  1535. if isinstance(self.output_field, fields.DecimalField):
  1536. # Casting to numeric must be outside of the window expression.
  1537. copy = self.copy()
  1538. source_expressions = copy.get_source_expressions()
  1539. source_expressions[0].output_field = fields.FloatField()
  1540. copy.set_source_expressions(source_expressions)
  1541. return super(Window, copy).as_sqlite(compiler, connection)
  1542. return self.as_sql(compiler, connection)
  1543. def __str__(self):
  1544. return "{} OVER ({}{}{})".format(
  1545. str(self.source_expression),
  1546. "PARTITION BY " + str(self.partition_by) if self.partition_by else "",
  1547. str(self.order_by or ""),
  1548. str(self.frame or ""),
  1549. )
  1550. def __repr__(self):
  1551. return "<%s: %s>" % (self.__class__.__name__, self)
  1552. def get_group_by_cols(self):
  1553. group_by_cols = []
  1554. if self.partition_by:
  1555. group_by_cols.extend(self.partition_by.get_group_by_cols())
  1556. if self.order_by is not None:
  1557. group_by_cols.extend(self.order_by.get_group_by_cols())
  1558. return group_by_cols
  1559. class WindowFrame(Expression):
  1560. """
  1561. Model the frame clause in window expressions. There are two types of frame
  1562. clauses which are subclasses, however, all processing and validation (by no
  1563. means intended to be complete) is done here. Thus, providing an end for a
  1564. frame is optional (the default is UNBOUNDED FOLLOWING, which is the last
  1565. row in the frame).
  1566. """
  1567. template = "%(frame_type)s BETWEEN %(start)s AND %(end)s"
  1568. def __init__(self, start=None, end=None):
  1569. self.start = Value(start)
  1570. self.end = Value(end)
  1571. def set_source_expressions(self, exprs):
  1572. self.start, self.end = exprs
  1573. def get_source_expressions(self):
  1574. return [self.start, self.end]
  1575. def as_sql(self, compiler, connection):
  1576. connection.ops.check_expression_support(self)
  1577. start, end = self.window_frame_start_end(
  1578. connection, self.start.value, self.end.value
  1579. )
  1580. return (
  1581. self.template
  1582. % {
  1583. "frame_type": self.frame_type,
  1584. "start": start,
  1585. "end": end,
  1586. },
  1587. [],
  1588. )
  1589. def __repr__(self):
  1590. return "<%s: %s>" % (self.__class__.__name__, self)
  1591. def get_group_by_cols(self):
  1592. return []
  1593. def __str__(self):
  1594. if self.start.value is not None and self.start.value < 0:
  1595. start = "%d %s" % (abs(self.start.value), connection.ops.PRECEDING)
  1596. elif self.start.value is not None and self.start.value == 0:
  1597. start = connection.ops.CURRENT_ROW
  1598. else:
  1599. start = connection.ops.UNBOUNDED_PRECEDING
  1600. if self.end.value is not None and self.end.value > 0:
  1601. end = "%d %s" % (self.end.value, connection.ops.FOLLOWING)
  1602. elif self.end.value is not None and self.end.value == 0:
  1603. end = connection.ops.CURRENT_ROW
  1604. else:
  1605. end = connection.ops.UNBOUNDED_FOLLOWING
  1606. return self.template % {
  1607. "frame_type": self.frame_type,
  1608. "start": start,
  1609. "end": end,
  1610. }
  1611. def window_frame_start_end(self, connection, start, end):
  1612. raise NotImplementedError("Subclasses must implement window_frame_start_end().")
  1613. class RowRange(WindowFrame):
  1614. frame_type = "ROWS"
  1615. def window_frame_start_end(self, connection, start, end):
  1616. return connection.ops.window_frame_rows_start_end(start, end)
  1617. class ValueRange(WindowFrame):
  1618. frame_type = "RANGE"
  1619. def window_frame_start_end(self, connection, start, end):
  1620. return connection.ops.window_frame_range_start_end(start, end)