json.py 22 KB


  1. import json
  2. import warnings
  3. from django import forms
  4. from django.core import checks, exceptions
  5. from django.db import NotSupportedError, connections, router
  6. from django.db.models import expressions, lookups
  7. from django.db.models.constants import LOOKUP_SEP
  8. from django.db.models.fields import TextField
  9. from django.db.models.lookups import (
  10. FieldGetDbPrepValueMixin,
  11. PostgresOperatorLookup,
  12. Transform,
  13. )
  14. from django.utils.deprecation import RemovedInDjango51Warning
  15. from django.utils.translation import gettext_lazy as _
  16. from . import Field
  17. from .mixins import CheckFieldDefaultMixin
# Only JSONField is exported by ``from ... import *``; the lookups and
# transforms defined below remain reachable as module attributes.
__all__ = ["JSONField"]
  19. class JSONField(CheckFieldDefaultMixin, Field):
  20. empty_strings_allowed = False
  21. description = _("A JSON object")
  22. default_error_messages = {
  23. "invalid": _("Value must be valid JSON."),
  24. }
  25. _default_hint = ("dict", "{}")
  26. def __init__(
  27. self,
  28. verbose_name=None,
  29. name=None,
  30. encoder=None,
  31. decoder=None,
  32. **kwargs,
  33. ):
  34. if encoder and not callable(encoder):
  35. raise ValueError("The encoder parameter must be a callable object.")
  36. if decoder and not callable(decoder):
  37. raise ValueError("The decoder parameter must be a callable object.")
  38. self.encoder = encoder
  39. self.decoder = decoder
  40. super().__init__(verbose_name, name, **kwargs)
  41. def check(self, **kwargs):
  42. errors = super().check(**kwargs)
  43. databases = kwargs.get("databases") or []
  44. errors.extend(self._check_supported(databases))
  45. return errors
  46. def _check_supported(self, databases):
  47. errors = []
  48. for db in databases:
  49. if not router.allow_migrate_model(db, self.model):
  50. continue
  51. connection = connections[db]
  52. if (
  53. self.model._meta.required_db_vendor
  54. and self.model._meta.required_db_vendor != connection.vendor
  55. ):
  56. continue
  57. if not (
  58. "supports_json_field" in self.model._meta.required_db_features
  59. or connection.features.supports_json_field
  60. ):
  61. errors.append(
  62. checks.Error(
  63. "%s does not support JSONFields." % connection.display_name,
  64. obj=self.model,
  65. id="fields.E180",
  66. )
  67. )
  68. return errors
  69. def deconstruct(self):
  70. name, path, args, kwargs = super().deconstruct()
  71. if self.encoder is not None:
  72. kwargs["encoder"] = self.encoder
  73. if self.decoder is not None:
  74. kwargs["decoder"] = self.decoder
  75. return name, path, args, kwargs
  76. def from_db_value(self, value, expression, connection):
  77. if value is None:
  78. return value
  79. # Some backends (SQLite at least) extract non-string values in their
  80. # SQL datatypes.
  81. if isinstance(expression, KeyTransform) and not isinstance(value, str):
  82. return value
  83. try:
  84. return json.loads(value, cls=self.decoder)
  85. except json.JSONDecodeError:
  86. return value
  87. def get_internal_type(self):
  88. return "JSONField"
  89. def get_db_prep_value(self, value, connection, prepared=False):
  90. if not prepared:
  91. value = self.get_prep_value(value)
  92. # RemovedInDjango51Warning: When the deprecation ends, replace with:
  93. # if (
  94. # isinstance(value, expressions.Value)
  95. # and isinstance(value.output_field, JSONField)
  96. # ):
  97. # value = value.value
  98. # elif hasattr(value, "as_sql"): ...
  99. if isinstance(value, expressions.Value):
  100. if isinstance(value.value, str) and not isinstance(
  101. value.output_field, JSONField
  102. ):
  103. try:
  104. value = json.loads(value.value, cls=self.decoder)
  105. except json.JSONDecodeError:
  106. value = value.value
  107. else:
  108. warnings.warn(
  109. "Providing an encoded JSON string via Value() is deprecated. "
  110. f"Use Value({value!r}, output_field=JSONField()) instead.",
  111. category=RemovedInDjango51Warning,
  112. )
  113. elif isinstance(value.output_field, JSONField):
  114. value = value.value
  115. else:
  116. return value
  117. elif hasattr(value, "as_sql"):
  118. return value
  119. return connection.ops.adapt_json_value(value, self.encoder)
  120. def get_db_prep_save(self, value, connection):
  121. if value is None:
  122. return value
  123. return self.get_db_prep_value(value, connection)
  124. def get_transform(self, name):
  125. transform = super().get_transform(name)
  126. if transform:
  127. return transform
  128. return KeyTransformFactory(name)
  129. def validate(self, value, model_instance):
  130. super().validate(value, model_instance)
  131. try:
  132. json.dumps(value, cls=self.encoder)
  133. except TypeError:
  134. raise exceptions.ValidationError(
  135. self.error_messages["invalid"],
  136. code="invalid",
  137. params={"value": value},
  138. )
  139. def value_to_string(self, obj):
  140. return self.value_from_object(obj)
  141. def formfield(self, **kwargs):
  142. return super().formfield(
  143. **{
  144. "form_class": forms.JSONField,
  145. "encoder": self.encoder,
  146. "decoder": self.decoder,
  147. **kwargs,
  148. }
  149. )
  150. def compile_json_path(key_transforms, include_root=True):
  151. path = ["$"] if include_root else []
  152. for key_transform in key_transforms:
  153. try:
  154. num = int(key_transform)
  155. except ValueError: # non-integer
  156. path.append(".")
  157. path.append(json.dumps(key_transform))
  158. else:
  159. path.append("[%s]" % num)
  160. return "".join(path)
  161. class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
  162. lookup_name = "contains"
  163. postgres_operator = "@>"
  164. def as_sql(self, compiler, connection):
  165. if not connection.features.supports_json_field_contains:
  166. raise NotSupportedError(
  167. "contains lookup is not supported on this database backend."
  168. )
  169. lhs, lhs_params = self.process_lhs(compiler, connection)
  170. rhs, rhs_params = self.process_rhs(compiler, connection)
  171. params = tuple(lhs_params) + tuple(rhs_params)
  172. return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
  173. class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
  174. lookup_name = "contained_by"
  175. postgres_operator = "<@"
  176. def as_sql(self, compiler, connection):
  177. if not connection.features.supports_json_field_contains:
  178. raise NotSupportedError(
  179. "contained_by lookup is not supported on this database backend."
  180. )
  181. lhs, lhs_params = self.process_lhs(compiler, connection)
  182. rhs, rhs_params = self.process_rhs(compiler, connection)
  183. params = tuple(rhs_params) + tuple(lhs_params)
  184. return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
  185. class HasKeyLookup(PostgresOperatorLookup):
  186. logical_operator = None
  187. def compile_json_path_final_key(self, key_transform):
  188. # Compile the final key without interpreting ints as array elements.
  189. return ".%s" % json.dumps(key_transform)
  190. def as_sql(self, compiler, connection, template=None):
  191. # Process JSON path from the left-hand side.
  192. if isinstance(self.lhs, KeyTransform):
  193. lhs, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
  194. compiler, connection
  195. )
  196. lhs_json_path = compile_json_path(lhs_key_transforms)
  197. else:
  198. lhs, lhs_params = self.process_lhs(compiler, connection)
  199. lhs_json_path = "$"
  200. sql = template % lhs
  201. # Process JSON path from the right-hand side.
  202. rhs = self.rhs
  203. rhs_params = []
  204. if not isinstance(rhs, (list, tuple)):
  205. rhs = [rhs]
  206. for key in rhs:
  207. if isinstance(key, KeyTransform):
  208. *_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
  209. else:
  210. rhs_key_transforms = [key]
  211. *rhs_key_transforms, final_key = rhs_key_transforms
  212. rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
  213. rhs_json_path += self.compile_json_path_final_key(final_key)
  214. rhs_params.append(lhs_json_path + rhs_json_path)
  215. # Add condition for each key.
  216. if self.logical_operator:
  217. sql = "(%s)" % self.logical_operator.join([sql] * len(rhs_params))
  218. return sql, tuple(lhs_params) + tuple(rhs_params)
  219. def as_mysql(self, compiler, connection):
  220. return self.as_sql(
  221. compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %%s)"
  222. )
  223. def as_oracle(self, compiler, connection):
  224. sql, params = self.as_sql(
  225. compiler, connection, template="JSON_EXISTS(%s, '%%s')"
  226. )
  227. # Add paths directly into SQL because path expressions cannot be passed
  228. # as bind variables on Oracle.
  229. return sql % tuple(params), []
  230. def as_postgresql(self, compiler, connection):
  231. if isinstance(self.rhs, KeyTransform):
  232. *_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
  233. for key in rhs_key_transforms[:-1]:
  234. self.lhs = KeyTransform(key, self.lhs)
  235. self.rhs = rhs_key_transforms[-1]
  236. return super().as_postgresql(compiler, connection)
  237. def as_sqlite(self, compiler, connection):
  238. return self.as_sql(
  239. compiler, connection, template="JSON_TYPE(%s, %%s) IS NOT NULL"
  240. )
  241. class HasKey(HasKeyLookup):
  242. lookup_name = "has_key"
  243. postgres_operator = "?"
  244. prepare_rhs = False
  245. class HasKeys(HasKeyLookup):
  246. lookup_name = "has_keys"
  247. postgres_operator = "?&"
  248. logical_operator = " AND "
  249. def get_prep_lookup(self):
  250. return [str(item) for item in self.rhs]
  251. class HasAnyKeys(HasKeys):
  252. lookup_name = "has_any_keys"
  253. postgres_operator = "?|"
  254. logical_operator = " OR "
  255. class HasKeyOrArrayIndex(HasKey):
  256. def compile_json_path_final_key(self, key_transform):
  257. return compile_json_path([key_transform], include_root=False)
  258. class CaseInsensitiveMixin:
  259. """
  260. Mixin to allow case-insensitive comparison of JSON values on MySQL.
  261. MySQL handles strings used in JSON context using the utf8mb4_bin collation.
  262. Because utf8mb4_bin is a binary collation, comparison of JSON values is
  263. case-sensitive.
  264. """
  265. def process_lhs(self, compiler, connection):
  266. lhs, lhs_params = super().process_lhs(compiler, connection)
  267. if connection.vendor == "mysql":
  268. return "LOWER(%s)" % lhs, lhs_params
  269. return lhs, lhs_params
  270. def process_rhs(self, compiler, connection):
  271. rhs, rhs_params = super().process_rhs(compiler, connection)
  272. if connection.vendor == "mysql":
  273. return "LOWER(%s)" % rhs, rhs_params
  274. return rhs, rhs_params
  275. class JSONExact(lookups.Exact):
  276. can_use_none_as_rhs = True
  277. def process_rhs(self, compiler, connection):
  278. rhs, rhs_params = super().process_rhs(compiler, connection)
  279. # Treat None lookup values as null.
  280. if rhs == "%s" and rhs_params == [None]:
  281. rhs_params = ["null"]
  282. if connection.vendor == "mysql":
  283. func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
  284. rhs %= tuple(func)
  285. return rhs, rhs_params
  286. class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
  287. pass
  288. JSONField.register_lookup(DataContains)
  289. JSONField.register_lookup(ContainedBy)
  290. JSONField.register_lookup(HasKey)
  291. JSONField.register_lookup(HasKeys)
  292. JSONField.register_lookup(HasAnyKeys)
  293. JSONField.register_lookup(JSONExact)
  294. JSONField.register_lookup(JSONIContains)
  295. class KeyTransform(Transform):
  296. postgres_operator = "->"
  297. postgres_nested_operator = "#>"
  298. def __init__(self, key_name, *args, **kwargs):
  299. super().__init__(*args, **kwargs)
  300. self.key_name = str(key_name)
  301. def preprocess_lhs(self, compiler, connection):
  302. key_transforms = [self.key_name]
  303. previous = self.lhs
  304. while isinstance(previous, KeyTransform):
  305. key_transforms.insert(0, previous.key_name)
  306. previous = previous.lhs
  307. lhs, params = compiler.compile(previous)
  308. if connection.vendor == "oracle":
  309. # Escape string-formatting.
  310. key_transforms = [key.replace("%", "%%") for key in key_transforms]
  311. return lhs, params, key_transforms
  312. def as_mysql(self, compiler, connection):
  313. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  314. json_path = compile_json_path(key_transforms)
  315. return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
  316. def as_oracle(self, compiler, connection):
  317. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  318. json_path = compile_json_path(key_transforms)
  319. return (
  320. "COALESCE(JSON_QUERY(%s, '%s'), JSON_VALUE(%s, '%s'))"
  321. % ((lhs, json_path) * 2)
  322. ), tuple(params) * 2
  323. def as_postgresql(self, compiler, connection):
  324. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  325. if len(key_transforms) > 1:
  326. sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
  327. return sql, tuple(params) + (key_transforms,)
  328. try:
  329. lookup = int(self.key_name)
  330. except ValueError:
  331. lookup = self.key_name
  332. return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
  333. def as_sqlite(self, compiler, connection):
  334. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  335. json_path = compile_json_path(key_transforms)
  336. datatype_values = ",".join(
  337. [repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
  338. )
  339. return (
  340. "(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
  341. "THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
  342. ) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
  343. class KeyTextTransform(KeyTransform):
  344. postgres_operator = "->>"
  345. postgres_nested_operator = "#>>"
  346. output_field = TextField()
  347. def as_mysql(self, compiler, connection):
  348. if connection.mysql_is_mariadb:
  349. # MariaDB doesn't support -> and ->> operators (see MDEV-13594).
  350. sql, params = super().as_mysql(compiler, connection)
  351. return "JSON_UNQUOTE(%s)" % sql, params
  352. else:
  353. lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
  354. json_path = compile_json_path(key_transforms)
  355. return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
  356. @classmethod
  357. def from_lookup(cls, lookup):
  358. transform, *keys = lookup.split(LOOKUP_SEP)
  359. if not keys:
  360. raise ValueError("Lookup must contain key or index transforms.")
  361. for key in keys:
  362. transform = cls(key, transform)
  363. return transform
# Public shortcut: KT(lookup_string) builds nested KeyTextTransforms from a
# LOOKUP_SEP-separated lookup (see KeyTextTransform.from_lookup).
KT = KeyTextTransform.from_lookup
  365. class KeyTransformTextLookupMixin:
  366. """
  367. Mixin for combining with a lookup expecting a text lhs from a JSONField
  368. key lookup. On PostgreSQL, make use of the ->> operator instead of casting
  369. key values to text and performing the lookup on the resulting
  370. representation.
  371. """
  372. def __init__(self, key_transform, *args, **kwargs):
  373. if not isinstance(key_transform, KeyTransform):
  374. raise TypeError(
  375. "Transform should be an instance of KeyTransform in order to "
  376. "use this lookup."
  377. )
  378. key_text_transform = KeyTextTransform(
  379. key_transform.key_name,
  380. *key_transform.source_expressions,
  381. **key_transform.extra,
  382. )
  383. super().__init__(key_text_transform, *args, **kwargs)
  384. class KeyTransformIsNull(lookups.IsNull):
  385. # key__isnull=False is the same as has_key='key'
  386. def as_oracle(self, compiler, connection):
  387. sql, params = HasKeyOrArrayIndex(
  388. self.lhs.lhs,
  389. self.lhs.key_name,
  390. ).as_oracle(compiler, connection)
  391. if not self.rhs:
  392. return sql, params
  393. # Column doesn't have a key or IS NULL.
  394. lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
  395. return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
  396. def as_sqlite(self, compiler, connection):
  397. template = "JSON_TYPE(%s, %%s) IS NULL"
  398. if not self.rhs:
  399. template = "JSON_TYPE(%s, %%s) IS NOT NULL"
  400. return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
  401. compiler,
  402. connection,
  403. template=template,
  404. )
  405. class KeyTransformIn(lookups.In):
  406. def resolve_expression_parameter(self, compiler, connection, sql, param):
  407. sql, params = super().resolve_expression_parameter(
  408. compiler,
  409. connection,
  410. sql,
  411. param,
  412. )
  413. if (
  414. not hasattr(param, "as_sql")
  415. and not connection.features.has_native_json_field
  416. ):
  417. if connection.vendor == "oracle":
  418. value = json.loads(param)
  419. sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
  420. if isinstance(value, (list, dict)):
  421. sql %= "JSON_QUERY"
  422. else:
  423. sql %= "JSON_VALUE"
  424. elif connection.vendor == "mysql" or (
  425. connection.vendor == "sqlite"
  426. and params[0] not in connection.ops.jsonfield_datatype_values
  427. ):
  428. sql = "JSON_EXTRACT(%s, '$')"
  429. if connection.vendor == "mysql" and connection.mysql_is_mariadb:
  430. sql = "JSON_UNQUOTE(%s)" % sql
  431. return sql, params
  432. class KeyTransformExact(JSONExact):
  433. def process_rhs(self, compiler, connection):
  434. if isinstance(self.rhs, KeyTransform):
  435. return super(lookups.Exact, self).process_rhs(compiler, connection)
  436. rhs, rhs_params = super().process_rhs(compiler, connection)
  437. if connection.vendor == "oracle":
  438. func = []
  439. sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
  440. for value in rhs_params:
  441. value = json.loads(value)
  442. if isinstance(value, (list, dict)):
  443. func.append(sql % "JSON_QUERY")
  444. else:
  445. func.append(sql % "JSON_VALUE")
  446. rhs %= tuple(func)
  447. elif connection.vendor == "sqlite":
  448. func = []
  449. for value in rhs_params:
  450. if value in connection.ops.jsonfield_datatype_values:
  451. func.append("%s")
  452. else:
  453. func.append("JSON_EXTRACT(%s, '$')")
  454. rhs %= tuple(func)
  455. return rhs, rhs_params
  456. def as_oracle(self, compiler, connection):
  457. rhs, rhs_params = super().process_rhs(compiler, connection)
  458. if rhs_params == ["null"]:
  459. # Field has key and it's NULL.
  460. has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
  461. has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
  462. is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
  463. is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
  464. return (
  465. "%s AND %s" % (has_key_sql, is_null_sql),
  466. tuple(has_key_params) + tuple(is_null_params),
  467. )
  468. return super().as_sql(compiler, connection)
  469. class KeyTransformIExact(
  470. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
  471. ):
  472. pass
  473. class KeyTransformIContains(
  474. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
  475. ):
  476. pass
  477. class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
  478. pass
  479. class KeyTransformIStartsWith(
  480. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
  481. ):
  482. pass
  483. class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
  484. pass
  485. class KeyTransformIEndsWith(
  486. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
  487. ):
  488. pass
  489. class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
  490. pass
  491. class KeyTransformIRegex(
  492. CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
  493. ):
  494. pass
  495. class KeyTransformNumericLookupMixin:
  496. def process_rhs(self, compiler, connection):
  497. rhs, rhs_params = super().process_rhs(compiler, connection)
  498. if not connection.features.has_native_json_field:
  499. rhs_params = [json.loads(value) for value in rhs_params]
  500. return rhs, rhs_params
  501. class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
  502. pass
  503. class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
  504. pass
  505. class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
  506. pass
  507. class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
  508. pass
  509. KeyTransform.register_lookup(KeyTransformIn)
  510. KeyTransform.register_lookup(KeyTransformExact)
  511. KeyTransform.register_lookup(KeyTransformIExact)
  512. KeyTransform.register_lookup(KeyTransformIsNull)
  513. KeyTransform.register_lookup(KeyTransformIContains)
  514. KeyTransform.register_lookup(KeyTransformStartsWith)
  515. KeyTransform.register_lookup(KeyTransformIStartsWith)
  516. KeyTransform.register_lookup(KeyTransformEndsWith)
  517. KeyTransform.register_lookup(KeyTransformIEndsWith)
  518. KeyTransform.register_lookup(KeyTransformRegex)
  519. KeyTransform.register_lookup(KeyTransformIRegex)
  520. KeyTransform.register_lookup(KeyTransformLt)
  521. KeyTransform.register_lookup(KeyTransformLte)
  522. KeyTransform.register_lookup(KeyTransformGt)
  523. KeyTransform.register_lookup(KeyTransformGte)
  524. class KeyTransformFactory:
  525. def __init__(self, key_name):
  526. self.key_name = key_name
  527. def __call__(self, *args, **kwargs):
  528. return KeyTransform(self.key_name, *args, **kwargs)