# operations.py (18 KB)

import uuid

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile


  10. class DatabaseOperations(BaseDatabaseOperations):
  11. compiler_module = "django.db.backends.mysql.compiler"
  12. # MySQL stores positive fields as UNSIGNED ints.
  13. integer_field_ranges = {
  14. **BaseDatabaseOperations.integer_field_ranges,
  15. "PositiveSmallIntegerField": (0, 65535),
  16. "PositiveIntegerField": (0, 4294967295),
  17. "PositiveBigIntegerField": (0, 18446744073709551615),
  18. }
  19. cast_data_types = {
  20. "AutoField": "signed integer",
  21. "BigAutoField": "signed integer",
  22. "SmallAutoField": "signed integer",
  23. "CharField": "char(%(max_length)s)",
  24. "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
  25. "TextField": "char",
  26. "IntegerField": "signed integer",
  27. "BigIntegerField": "signed integer",
  28. "SmallIntegerField": "signed integer",
  29. "PositiveBigIntegerField": "unsigned integer",
  30. "PositiveIntegerField": "unsigned integer",
  31. "PositiveSmallIntegerField": "unsigned integer",
  32. "DurationField": "signed integer",
  33. }
  34. cast_char_field_without_max_length = "char"
  35. explain_prefix = "EXPLAIN"
  36. # EXTRACT format cannot be passed in parameters.
  37. _extract_format_re = _lazy_re_compile(r"[A-Z_]+")
  38. def date_extract_sql(self, lookup_type, sql, params):
  39. # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
  40. if lookup_type == "week_day":
  41. # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
  42. return f"DAYOFWEEK({sql})", params
  43. elif lookup_type == "iso_week_day":
  44. # WEEKDAY() returns an integer, 0-6, Monday=0.
  45. return f"WEEKDAY({sql}) + 1", params
  46. elif lookup_type == "week":
  47. # Override the value of default_week_format for consistency with
  48. # other database backends.
  49. # Mode 3: Monday, 1-53, with 4 or more days this year.
  50. return f"WEEK({sql}, 3)", params
  51. elif lookup_type == "iso_year":
  52. # Get the year part from the YEARWEEK function, which returns a
  53. # number as year * 100 + week.
  54. return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
  55. else:
  56. # EXTRACT returns 1-53 based on ISO-8601 for the week number.
  57. lookup_type = lookup_type.upper()
  58. if not self._extract_format_re.fullmatch(lookup_type):
  59. raise ValueError(f"Invalid loookup type: {lookup_type!r}")
  60. return f"EXTRACT({lookup_type} FROM {sql})", params
  61. def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
  62. sql, params = self._convert_sql_to_tz(sql, params, tzname)
  63. fields = {
  64. "year": "%Y-01-01",
  65. "month": "%Y-%m-01",
  66. }
  67. if lookup_type in fields:
  68. format_str = fields[lookup_type]
  69. return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
  70. elif lookup_type == "quarter":
  71. return (
  72. f"MAKEDATE(YEAR({sql}), 1) + "
  73. f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
  74. (*params, *params),
  75. )
  76. elif lookup_type == "week":
  77. return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
  78. else:
  79. return f"DATE({sql})", params
  80. def _prepare_tzname_delta(self, tzname):
  81. tzname, sign, offset = split_tzname_delta(tzname)
  82. return f"{sign}{offset}" if offset else tzname
  83. def _convert_sql_to_tz(self, sql, params, tzname):
  84. if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
  85. return f"CONVERT_TZ({sql}, %s, %s)", (
  86. *params,
  87. self.connection.timezone_name,
  88. self._prepare_tzname_delta(tzname),
  89. )
  90. return sql, params
  91. def datetime_cast_date_sql(self, sql, params, tzname):
  92. sql, params = self._convert_sql_to_tz(sql, params, tzname)
  93. return f"DATE({sql})", params
  94. def datetime_cast_time_sql(self, sql, params, tzname):
  95. sql, params = self._convert_sql_to_tz(sql, params, tzname)
  96. return f"TIME({sql})", params
  97. def datetime_extract_sql(self, lookup_type, sql, params, tzname):
  98. sql, params = self._convert_sql_to_tz(sql, params, tzname)
  99. return self.date_extract_sql(lookup_type, sql, params)
  100. def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
  101. sql, params = self._convert_sql_to_tz(sql, params, tzname)
  102. fields = ["year", "month", "day", "hour", "minute", "second"]
  103. format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
  104. format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
  105. if lookup_type == "quarter":
  106. return (
  107. f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
  108. f"INTERVAL QUARTER({sql}) QUARTER - "
  109. f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
  110. ), (*params, *params, "%Y-%m-01 00:00:00")
  111. if lookup_type == "week":
  112. return (
  113. f"CAST(DATE_FORMAT("
  114. f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
  115. ), (*params, *params, "%Y-%m-%d 00:00:00")
  116. try:
  117. i = fields.index(lookup_type) + 1
  118. except ValueError:
  119. pass
  120. else:
  121. format_str = "".join(format[:i] + format_def[i:])
  122. return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
  123. return sql, params
  124. def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
  125. sql, params = self._convert_sql_to_tz(sql, params, tzname)
  126. fields = {
  127. "hour": "%H:00:00",
  128. "minute": "%H:%i:00",
  129. "second": "%H:%i:%s",
  130. }
  131. if lookup_type in fields:
  132. format_str = fields[lookup_type]
  133. return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
  134. else:
  135. return f"TIME({sql})", params
  136. def fetch_returned_insert_rows(self, cursor):
  137. """
  138. Given a cursor object that has just performed an INSERT...RETURNING
  139. statement into a table, return the tuple of returned data.
  140. """
  141. return cursor.fetchall()
  142. def format_for_duration_arithmetic(self, sql):
  143. return "INTERVAL %s MICROSECOND" % sql
  144. def force_no_ordering(self):
  145. """
  146. "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
  147. columns. If no ordering would otherwise be applied, we don't want any
  148. implicit sorting going on.
  149. """
  150. return [(None, ("NULL", [], False))]
  151. def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
  152. return value
  153. def last_executed_query(self, cursor, sql, params):
  154. # With MySQLdb, cursor objects have an (undocumented) "_executed"
  155. # attribute where the exact query sent to the database is saved.
  156. # See MySQLdb/cursors.py in the source distribution.
  157. # MySQLdb returns string, PyMySQL bytes.
  158. return force_str(getattr(cursor, "_executed", None), errors="replace")
  159. def no_limit_value(self):
  160. # 2**64 - 1, as recommended by the MySQL documentation
  161. return 18446744073709551615
  162. def quote_name(self, name):
  163. if name.startswith("`") and name.endswith("`"):
  164. return name # Quoting once is enough.
  165. return "`%s`" % name
  166. def return_insert_columns(self, fields):
  167. # MySQL and MariaDB < 10.5.0 don't support an INSERT...RETURNING
  168. # statement.
  169. if not fields:
  170. return "", ()
  171. columns = [
  172. "%s.%s"
  173. % (
  174. self.quote_name(field.model._meta.db_table),
  175. self.quote_name(field.column),
  176. )
  177. for field in fields
  178. ]
  179. return "RETURNING %s" % ", ".join(columns), ()
  180. def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
  181. if not tables:
  182. return []
  183. sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
  184. if reset_sequences:
  185. # It's faster to TRUNCATE tables that require a sequence reset
  186. # since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
  187. sql.extend(
  188. "%s %s;"
  189. % (
  190. style.SQL_KEYWORD("TRUNCATE"),
  191. style.SQL_FIELD(self.quote_name(table_name)),
  192. )
  193. for table_name in tables
  194. )
  195. else:
  196. # Otherwise issue a simple DELETE since it's faster than TRUNCATE
  197. # and preserves sequences.
  198. sql.extend(
  199. "%s %s %s;"
  200. % (
  201. style.SQL_KEYWORD("DELETE"),
  202. style.SQL_KEYWORD("FROM"),
  203. style.SQL_FIELD(self.quote_name(table_name)),
  204. )
  205. for table_name in tables
  206. )
  207. sql.append("SET FOREIGN_KEY_CHECKS = 1;")
  208. return sql
  209. def sequence_reset_by_name_sql(self, style, sequences):
  210. return [
  211. "%s %s %s %s = 1;"
  212. % (
  213. style.SQL_KEYWORD("ALTER"),
  214. style.SQL_KEYWORD("TABLE"),
  215. style.SQL_FIELD(self.quote_name(sequence_info["table"])),
  216. style.SQL_FIELD("AUTO_INCREMENT"),
  217. )
  218. for sequence_info in sequences
  219. ]
  220. def validate_autopk_value(self, value):
  221. # Zero in AUTO_INCREMENT field does not work without the
  222. # NO_AUTO_VALUE_ON_ZERO SQL mode.
  223. if value == 0 and not self.connection.features.allows_auto_pk_0:
  224. raise ValueError(
  225. "The database backend does not accept 0 as a value for AutoField."
  226. )
  227. return value
  228. def adapt_datetimefield_value(self, value):
  229. if value is None:
  230. return None
  231. # Expression values are adapted by the database.
  232. if hasattr(value, "resolve_expression"):
  233. return value
  234. # MySQL doesn't support tz-aware datetimes
  235. if timezone.is_aware(value):
  236. if settings.USE_TZ:
  237. value = timezone.make_naive(value, self.connection.timezone)
  238. else:
  239. raise ValueError(
  240. "MySQL backend does not support timezone-aware datetimes when "
  241. "USE_TZ is False."
  242. )
  243. return str(value)
  244. def adapt_timefield_value(self, value):
  245. if value is None:
  246. return None
  247. # Expression values are adapted by the database.
  248. if hasattr(value, "resolve_expression"):
  249. return value
  250. # MySQL doesn't support tz-aware times
  251. if timezone.is_aware(value):
  252. raise ValueError("MySQL backend does not support timezone-aware times.")
  253. return value.isoformat(timespec="microseconds")
  254. def max_name_length(self):
  255. return 64
  256. def pk_default_value(self):
  257. return "NULL"
  258. def bulk_insert_sql(self, fields, placeholder_rows):
  259. placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
  260. values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
  261. return "VALUES " + values_sql
  262. def combine_expression(self, connector, sub_expressions):
  263. if connector == "^":
  264. return "POW(%s)" % ",".join(sub_expressions)
  265. # Convert the result to a signed integer since MySQL's binary operators
  266. # return an unsigned integer.
  267. elif connector in ("&", "|", "<<", "#"):
  268. connector = "^" if connector == "#" else connector
  269. return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
  270. elif connector == ">>":
  271. lhs, rhs = sub_expressions
  272. return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
  273. return super().combine_expression(connector, sub_expressions)
  274. def get_db_converters(self, expression):
  275. converters = super().get_db_converters(expression)
  276. internal_type = expression.output_field.get_internal_type()
  277. if internal_type == "BooleanField":
  278. converters.append(self.convert_booleanfield_value)
  279. elif internal_type == "DateTimeField":
  280. if settings.USE_TZ:
  281. converters.append(self.convert_datetimefield_value)
  282. elif internal_type == "UUIDField":
  283. converters.append(self.convert_uuidfield_value)
  284. return converters
  285. def convert_booleanfield_value(self, value, expression, connection):
  286. if value in (0, 1):
  287. value = bool(value)
  288. return value
  289. def convert_datetimefield_value(self, value, expression, connection):
  290. if value is not None:
  291. value = timezone.make_aware(value, self.connection.timezone)
  292. return value
  293. def convert_uuidfield_value(self, value, expression, connection):
  294. if value is not None:
  295. value = uuid.UUID(value)
  296. return value
  297. def binary_placeholder_sql(self, value):
  298. return (
  299. "_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
  300. )
  301. def subtract_temporals(self, internal_type, lhs, rhs):
  302. lhs_sql, lhs_params = lhs
  303. rhs_sql, rhs_params = rhs
  304. if internal_type == "TimeField":
  305. if self.connection.mysql_is_mariadb:
  306. # MariaDB includes the microsecond component in TIME_TO_SEC as
  307. # a decimal. MySQL returns an integer without microseconds.
  308. return (
  309. "CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) "
  310. "* 1000000 AS SIGNED)"
  311. ) % {
  312. "lhs": lhs_sql,
  313. "rhs": rhs_sql,
  314. }, (
  315. *lhs_params,
  316. *rhs_params,
  317. )
  318. return (
  319. "((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
  320. " (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
  321. ) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
  322. rhs_params
  323. ) * 2
  324. params = (*rhs_params, *lhs_params)
  325. return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
  326. def explain_query_prefix(self, format=None, **options):
  327. # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
  328. if format and format.upper() == "TEXT":
  329. format = "TRADITIONAL"
  330. elif (
  331. not format and "TREE" in self.connection.features.supported_explain_formats
  332. ):
  333. # Use TREE by default (if supported) as it's more informative.
  334. format = "TREE"
  335. analyze = options.pop("analyze", False)
  336. prefix = super().explain_query_prefix(format, **options)
  337. if analyze and self.connection.features.supports_explain_analyze:
  338. # MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
  339. prefix = (
  340. "ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
  341. )
  342. if format and not (analyze and not self.connection.mysql_is_mariadb):
  343. # Only MariaDB supports the analyze option with formats.
  344. prefix += " FORMAT=%s" % format
  345. return prefix
  346. def regex_lookup(self, lookup_type):
  347. # REGEXP_LIKE doesn't exist in MariaDB.
  348. if self.connection.mysql_is_mariadb:
  349. if lookup_type == "regex":
  350. return "%s REGEXP BINARY %s"
  351. return "%s REGEXP %s"
  352. match_option = "c" if lookup_type == "regex" else "i"
  353. return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option
  354. def insert_statement(self, on_conflict=None):
  355. if on_conflict == OnConflict.IGNORE:
  356. return "INSERT IGNORE INTO"
  357. return super().insert_statement(on_conflict=on_conflict)
  358. def lookup_cast(self, lookup_type, internal_type=None):
  359. lookup = "%s"
  360. if internal_type == "JSONField":
  361. if self.connection.mysql_is_mariadb or lookup_type in (
  362. "iexact",
  363. "contains",
  364. "icontains",
  365. "startswith",
  366. "istartswith",
  367. "endswith",
  368. "iendswith",
  369. "regex",
  370. "iregex",
  371. ):
  372. lookup = "JSON_UNQUOTE(%s)"
  373. return lookup
  374. def conditional_expression_supported_in_where_clause(self, expression):
  375. # MySQL ignores indexes with boolean fields unless they're compared
  376. # directly to a boolean value.
  377. if isinstance(expression, (Exists, Lookup)):
  378. return True
  379. if isinstance(expression, ExpressionWrapper) and expression.conditional:
  380. return self.conditional_expression_supported_in_where_clause(
  381. expression.expression
  382. )
  383. if getattr(expression, "conditional", False):
  384. return False
  385. return super().conditional_expression_supported_in_where_clause(expression)
  386. def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
  387. if on_conflict == OnConflict.UPDATE:
  388. conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
  389. # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
  390. # aliases for the new row and its columns available in MySQL
  391. # 8.0.19+.
  392. if not self.connection.mysql_is_mariadb:
  393. if self.connection.mysql_version >= (8, 0, 19):
  394. conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
  395. field_sql = "%(field)s = new.%(field)s"
  396. else:
  397. field_sql = "%(field)s = VALUES(%(field)s)"
  398. # Use VALUE() on MariaDB.
  399. else:
  400. field_sql = "%(field)s = VALUE(%(field)s)"
  401. fields = ", ".join(
  402. [
  403. field_sql % {"field": field}
  404. for field in map(self.quote_name, update_fields)
  405. ]
  406. )
  407. return conflict_suffix_sql % {"fields": fields}
  408. return super().on_conflict_suffix_sql(
  409. fields,
  410. on_conflict,
  411. update_fields,
  412. unique_fields,
  413. )