constraints.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. import warnings
  2. from enum import Enum
  3. from types import NoneType
  4. from django.core.exceptions import FieldError, ValidationError
  5. from django.db import connections
  6. from django.db.models.expressions import Exists, ExpressionList, F, OrderBy
  7. from django.db.models.indexes import IndexExpression
  8. from django.db.models.lookups import Exact
  9. from django.db.models.query_utils import Q
  10. from django.db.models.sql.query import Query
  11. from django.db.utils import DEFAULT_DB_ALIAS
  12. from django.utils.deprecation import RemovedInDjango60Warning
  13. from django.utils.translation import gettext_lazy as _
  14. __all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
  15. class BaseConstraint:
  16. default_violation_error_message = _("Constraint “%(name)s” is violated.")
  17. violation_error_code = None
  18. violation_error_message = None
  19. # RemovedInDjango60Warning: When the deprecation ends, replace with:
  20. # def __init__(
  21. # self, *, name, violation_error_code=None, violation_error_message=None
  22. # ):
  23. def __init__(
  24. self, *args, name=None, violation_error_code=None, violation_error_message=None
  25. ):
  26. # RemovedInDjango60Warning.
  27. if name is None and not args:
  28. raise TypeError(
  29. f"{self.__class__.__name__}.__init__() missing 1 required keyword-only "
  30. f"argument: 'name'"
  31. )
  32. self.name = name
  33. if violation_error_code is not None:
  34. self.violation_error_code = violation_error_code
  35. if violation_error_message is not None:
  36. self.violation_error_message = violation_error_message
  37. else:
  38. self.violation_error_message = self.default_violation_error_message
  39. # RemovedInDjango60Warning.
  40. if args:
  41. warnings.warn(
  42. f"Passing positional arguments to {self.__class__.__name__} is "
  43. f"deprecated.",
  44. RemovedInDjango60Warning,
  45. stacklevel=2,
  46. )
  47. for arg, attr in zip(args, ["name", "violation_error_message"]):
  48. if arg:
  49. setattr(self, attr, arg)
  50. @property
  51. def contains_expressions(self):
  52. return False
  53. def constraint_sql(self, model, schema_editor):
  54. raise NotImplementedError("This method must be implemented by a subclass.")
  55. def create_sql(self, model, schema_editor):
  56. raise NotImplementedError("This method must be implemented by a subclass.")
  57. def remove_sql(self, model, schema_editor):
  58. raise NotImplementedError("This method must be implemented by a subclass.")
  59. def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
  60. raise NotImplementedError("This method must be implemented by a subclass.")
  61. def get_violation_error_message(self):
  62. return self.violation_error_message % {"name": self.name}
  63. def deconstruct(self):
  64. path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
  65. path = path.replace("django.db.models.constraints", "django.db.models")
  66. kwargs = {"name": self.name}
  67. if (
  68. self.violation_error_message is not None
  69. and self.violation_error_message != self.default_violation_error_message
  70. ):
  71. kwargs["violation_error_message"] = self.violation_error_message
  72. if self.violation_error_code is not None:
  73. kwargs["violation_error_code"] = self.violation_error_code
  74. return (path, (), kwargs)
  75. def clone(self):
  76. _, args, kwargs = self.deconstruct()
  77. return self.__class__(*args, **kwargs)
  78. class CheckConstraint(BaseConstraint):
  79. def __init__(
  80. self, *, check, name, violation_error_code=None, violation_error_message=None
  81. ):
  82. self.check = check
  83. if not getattr(check, "conditional", False):
  84. raise TypeError(
  85. "CheckConstraint.check must be a Q instance or boolean expression."
  86. )
  87. super().__init__(
  88. name=name,
  89. violation_error_code=violation_error_code,
  90. violation_error_message=violation_error_message,
  91. )
  92. def _get_check_sql(self, model, schema_editor):
  93. query = Query(model=model, alias_cols=False)
  94. where = query.build_where(self.check)
  95. compiler = query.get_compiler(connection=schema_editor.connection)
  96. sql, params = where.as_sql(compiler, schema_editor.connection)
  97. return sql % tuple(schema_editor.quote_value(p) for p in params)
  98. def constraint_sql(self, model, schema_editor):
  99. check = self._get_check_sql(model, schema_editor)
  100. return schema_editor._check_sql(self.name, check)
  101. def create_sql(self, model, schema_editor):
  102. check = self._get_check_sql(model, schema_editor)
  103. return schema_editor._create_check_sql(model, self.name, check)
  104. def remove_sql(self, model, schema_editor):
  105. return schema_editor._delete_check_sql(model, self.name)
  106. def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
  107. against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
  108. try:
  109. if not Q(self.check).check(against, using=using):
  110. raise ValidationError(
  111. self.get_violation_error_message(), code=self.violation_error_code
  112. )
  113. except FieldError:
  114. pass
  115. def __repr__(self):
  116. return "<%s: check=%s name=%s%s%s>" % (
  117. self.__class__.__qualname__,
  118. self.check,
  119. repr(self.name),
  120. (
  121. ""
  122. if self.violation_error_code is None
  123. else " violation_error_code=%r" % self.violation_error_code
  124. ),
  125. (
  126. ""
  127. if self.violation_error_message is None
  128. or self.violation_error_message == self.default_violation_error_message
  129. else " violation_error_message=%r" % self.violation_error_message
  130. ),
  131. )
  132. def __eq__(self, other):
  133. if isinstance(other, CheckConstraint):
  134. return (
  135. self.name == other.name
  136. and self.check == other.check
  137. and self.violation_error_code == other.violation_error_code
  138. and self.violation_error_message == other.violation_error_message
  139. )
  140. return super().__eq__(other)
  141. def deconstruct(self):
  142. path, args, kwargs = super().deconstruct()
  143. kwargs["check"] = self.check
  144. return path, args, kwargs
  145. class Deferrable(Enum):
  146. DEFERRED = "deferred"
  147. IMMEDIATE = "immediate"
  148. # A similar format was proposed for Python 3.10.
  149. def __repr__(self):
  150. return f"{self.__class__.__qualname__}.{self._name_}"
  151. class UniqueConstraint(BaseConstraint):
  152. def __init__(
  153. self,
  154. *expressions,
  155. fields=(),
  156. name=None,
  157. condition=None,
  158. deferrable=None,
  159. include=None,
  160. opclasses=(),
  161. nulls_distinct=None,
  162. violation_error_code=None,
  163. violation_error_message=None,
  164. ):
  165. if not name:
  166. raise ValueError("A unique constraint must be named.")
  167. if not expressions and not fields:
  168. raise ValueError(
  169. "At least one field or expression is required to define a "
  170. "unique constraint."
  171. )
  172. if expressions and fields:
  173. raise ValueError(
  174. "UniqueConstraint.fields and expressions are mutually exclusive."
  175. )
  176. if not isinstance(condition, (NoneType, Q)):
  177. raise ValueError("UniqueConstraint.condition must be a Q instance.")
  178. if condition and deferrable:
  179. raise ValueError("UniqueConstraint with conditions cannot be deferred.")
  180. if include and deferrable:
  181. raise ValueError("UniqueConstraint with include fields cannot be deferred.")
  182. if opclasses and deferrable:
  183. raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
  184. if expressions and deferrable:
  185. raise ValueError("UniqueConstraint with expressions cannot be deferred.")
  186. if expressions and opclasses:
  187. raise ValueError(
  188. "UniqueConstraint.opclasses cannot be used with expressions. "
  189. "Use django.contrib.postgres.indexes.OpClass() instead."
  190. )
  191. if not isinstance(deferrable, (NoneType, Deferrable)):
  192. raise TypeError(
  193. "UniqueConstraint.deferrable must be a Deferrable instance."
  194. )
  195. if not isinstance(include, (NoneType, list, tuple)):
  196. raise TypeError("UniqueConstraint.include must be a list or tuple.")
  197. if not isinstance(opclasses, (list, tuple)):
  198. raise TypeError("UniqueConstraint.opclasses must be a list or tuple.")
  199. if not isinstance(nulls_distinct, (NoneType, bool)):
  200. raise TypeError("UniqueConstraint.nulls_distinct must be a bool.")
  201. if opclasses and len(fields) != len(opclasses):
  202. raise ValueError(
  203. "UniqueConstraint.fields and UniqueConstraint.opclasses must "
  204. "have the same number of elements."
  205. )
  206. self.fields = tuple(fields)
  207. self.condition = condition
  208. self.deferrable = deferrable
  209. self.include = tuple(include) if include else ()
  210. self.opclasses = opclasses
  211. self.nulls_distinct = nulls_distinct
  212. self.expressions = tuple(
  213. F(expression) if isinstance(expression, str) else expression
  214. for expression in expressions
  215. )
  216. super().__init__(
  217. name=name,
  218. violation_error_code=violation_error_code,
  219. violation_error_message=violation_error_message,
  220. )
  221. @property
  222. def contains_expressions(self):
  223. return bool(self.expressions)
  224. def _get_condition_sql(self, model, schema_editor):
  225. if self.condition is None:
  226. return None
  227. query = Query(model=model, alias_cols=False)
  228. where = query.build_where(self.condition)
  229. compiler = query.get_compiler(connection=schema_editor.connection)
  230. sql, params = where.as_sql(compiler, schema_editor.connection)
  231. return sql % tuple(schema_editor.quote_value(p) for p in params)
  232. def _get_index_expressions(self, model, schema_editor):
  233. if not self.expressions:
  234. return None
  235. index_expressions = []
  236. for expression in self.expressions:
  237. index_expression = IndexExpression(expression)
  238. index_expression.set_wrapper_classes(schema_editor.connection)
  239. index_expressions.append(index_expression)
  240. return ExpressionList(*index_expressions).resolve_expression(
  241. Query(model, alias_cols=False),
  242. )
  243. def constraint_sql(self, model, schema_editor):
  244. fields = [model._meta.get_field(field_name) for field_name in self.fields]
  245. include = [
  246. model._meta.get_field(field_name).column for field_name in self.include
  247. ]
  248. condition = self._get_condition_sql(model, schema_editor)
  249. expressions = self._get_index_expressions(model, schema_editor)
  250. return schema_editor._unique_sql(
  251. model,
  252. fields,
  253. self.name,
  254. condition=condition,
  255. deferrable=self.deferrable,
  256. include=include,
  257. opclasses=self.opclasses,
  258. expressions=expressions,
  259. nulls_distinct=self.nulls_distinct,
  260. )
  261. def create_sql(self, model, schema_editor):
  262. fields = [model._meta.get_field(field_name) for field_name in self.fields]
  263. include = [
  264. model._meta.get_field(field_name).column for field_name in self.include
  265. ]
  266. condition = self._get_condition_sql(model, schema_editor)
  267. expressions = self._get_index_expressions(model, schema_editor)
  268. return schema_editor._create_unique_sql(
  269. model,
  270. fields,
  271. self.name,
  272. condition=condition,
  273. deferrable=self.deferrable,
  274. include=include,
  275. opclasses=self.opclasses,
  276. expressions=expressions,
  277. nulls_distinct=self.nulls_distinct,
  278. )
  279. def remove_sql(self, model, schema_editor):
  280. condition = self._get_condition_sql(model, schema_editor)
  281. include = [
  282. model._meta.get_field(field_name).column for field_name in self.include
  283. ]
  284. expressions = self._get_index_expressions(model, schema_editor)
  285. return schema_editor._delete_unique_sql(
  286. model,
  287. self.name,
  288. condition=condition,
  289. deferrable=self.deferrable,
  290. include=include,
  291. opclasses=self.opclasses,
  292. expressions=expressions,
  293. nulls_distinct=self.nulls_distinct,
  294. )
  295. def __repr__(self):
  296. return "<%s:%s%s%s%s%s%s%s%s%s%s>" % (
  297. self.__class__.__qualname__,
  298. "" if not self.fields else " fields=%s" % repr(self.fields),
  299. "" if not self.expressions else " expressions=%s" % repr(self.expressions),
  300. " name=%s" % repr(self.name),
  301. "" if self.condition is None else " condition=%s" % self.condition,
  302. "" if self.deferrable is None else " deferrable=%r" % self.deferrable,
  303. "" if not self.include else " include=%s" % repr(self.include),
  304. "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
  305. (
  306. ""
  307. if self.nulls_distinct is None
  308. else " nulls_distinct=%r" % self.nulls_distinct
  309. ),
  310. (
  311. ""
  312. if self.violation_error_code is None
  313. else " violation_error_code=%r" % self.violation_error_code
  314. ),
  315. (
  316. ""
  317. if self.violation_error_message is None
  318. or self.violation_error_message == self.default_violation_error_message
  319. else " violation_error_message=%r" % self.violation_error_message
  320. ),
  321. )
  322. def __eq__(self, other):
  323. if isinstance(other, UniqueConstraint):
  324. return (
  325. self.name == other.name
  326. and self.fields == other.fields
  327. and self.condition == other.condition
  328. and self.deferrable == other.deferrable
  329. and self.include == other.include
  330. and self.opclasses == other.opclasses
  331. and self.expressions == other.expressions
  332. and self.nulls_distinct is other.nulls_distinct
  333. and self.violation_error_code == other.violation_error_code
  334. and self.violation_error_message == other.violation_error_message
  335. )
  336. return super().__eq__(other)
  337. def deconstruct(self):
  338. path, args, kwargs = super().deconstruct()
  339. if self.fields:
  340. kwargs["fields"] = self.fields
  341. if self.condition:
  342. kwargs["condition"] = self.condition
  343. if self.deferrable:
  344. kwargs["deferrable"] = self.deferrable
  345. if self.include:
  346. kwargs["include"] = self.include
  347. if self.opclasses:
  348. kwargs["opclasses"] = self.opclasses
  349. if self.nulls_distinct is not None:
  350. kwargs["nulls_distinct"] = self.nulls_distinct
  351. return path, self.expressions, kwargs
  352. def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
  353. queryset = model._default_manager.using(using)
  354. if self.fields:
  355. lookup_kwargs = {}
  356. for field_name in self.fields:
  357. if exclude and field_name in exclude:
  358. return
  359. field = model._meta.get_field(field_name)
  360. lookup_value = getattr(instance, field.attname)
  361. if (
  362. self.nulls_distinct is not False
  363. and lookup_value is None
  364. or (
  365. lookup_value == ""
  366. and connections[
  367. using
  368. ].features.interprets_empty_strings_as_nulls
  369. )
  370. ):
  371. # A composite constraint containing NULL value cannot cause
  372. # a violation since NULL != NULL in SQL.
  373. return
  374. lookup_kwargs[field.name] = lookup_value
  375. queryset = queryset.filter(**lookup_kwargs)
  376. else:
  377. # Ignore constraints with excluded fields.
  378. if exclude:
  379. for expression in self.expressions:
  380. if hasattr(expression, "flatten"):
  381. for expr in expression.flatten():
  382. if isinstance(expr, F) and expr.name in exclude:
  383. return
  384. elif isinstance(expression, F) and expression.name in exclude:
  385. return
  386. replacements = {
  387. F(field): value
  388. for field, value in instance._get_field_value_map(
  389. meta=model._meta, exclude=exclude
  390. ).items()
  391. }
  392. expressions = []
  393. for expr in self.expressions:
  394. # Ignore ordering.
  395. if isinstance(expr, OrderBy):
  396. expr = expr.expression
  397. expressions.append(Exact(expr, expr.replace_expressions(replacements)))
  398. queryset = queryset.filter(*expressions)
  399. model_class_pk = instance._get_pk_val(model._meta)
  400. if not instance._state.adding and model_class_pk is not None:
  401. queryset = queryset.exclude(pk=model_class_pk)
  402. if not self.condition:
  403. if queryset.exists():
  404. if self.expressions:
  405. raise ValidationError(
  406. self.get_violation_error_message(),
  407. code=self.violation_error_code,
  408. )
  409. # When fields are defined, use the unique_error_message() for
  410. # backward compatibility.
  411. for model, constraints in instance.get_constraints():
  412. for constraint in constraints:
  413. if constraint is self:
  414. raise ValidationError(
  415. instance.unique_error_message(model, self.fields),
  416. )
  417. else:
  418. against = instance._get_field_value_map(meta=model._meta, exclude=exclude)
  419. try:
  420. if (self.condition & Exists(queryset.filter(self.condition))).check(
  421. against, using=using
  422. ):
  423. raise ValidationError(
  424. self.get_violation_error_message(),
  425. code=self.violation_error_code,
  426. )
  427. except FieldError:
  428. pass