  1. """
  2. PostgreSQL database backend for Django.
  3. Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
  4. """
  5. import asyncio
  6. import threading
  7. import warnings
  8. from contextlib import contextmanager
  9. from django.conf import settings
  10. from django.core.exceptions import ImproperlyConfigured
  11. from django.db import DatabaseError as WrappedDatabaseError
  12. from django.db import connections
  13. from django.db.backends.base.base import BaseDatabaseWrapper
  14. from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
  15. from django.utils.asyncio import async_unsafe
  16. from django.utils.functional import cached_property
  17. from django.utils.safestring import SafeString
  18. from django.utils.version import get_version_tuple
  19. try:
  20. try:
  21. import psycopg as Database
  22. except ImportError:
  23. import psycopg2 as Database
  24. except ImportError:
  25. raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
  26. def psycopg_version():
  27. version = Database.__version__.split(" ", 1)[0]
  28. return get_version_tuple(version)
  29. if psycopg_version() < (2, 8, 4):
  30. raise ImproperlyConfigured(
  31. f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
  32. )
  33. if (3,) <= psycopg_version() < (3, 1, 8):
  34. raise ImproperlyConfigured(
  35. f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}"
  36. )
  37. from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
  38. if is_psycopg3:
  39. from psycopg import adapters, sql
  40. from psycopg.pq import Format
  41. from .psycopg_any import get_adapters_template, register_tzloader
  42. TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
  43. else:
  44. import psycopg2.extensions
  45. import psycopg2.extras
  46. psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
  47. psycopg2.extras.register_uuid()
  48. # Register support for inet[] manually so we don't have to handle the Inet()
  49. # object on load all the time.
  50. INETARRAY_OID = 1041
  51. INETARRAY = psycopg2.extensions.new_array_type(
  52. (INETARRAY_OID,),
  53. "INETARRAY",
  54. psycopg2.extensions.UNICODE,
  55. )
  56. psycopg2.extensions.register_type(INETARRAY)
  57. # Some of these import psycopg, so import them after checking if it's installed.
  58. from .client import DatabaseClient # NOQA isort:skip
  59. from .creation import DatabaseCreation # NOQA isort:skip
  60. from .features import DatabaseFeatures # NOQA isort:skip
  61. from .introspection import DatabaseIntrospection # NOQA isort:skip
  62. from .operations import DatabaseOperations # NOQA isort:skip
  63. from .schema import DatabaseSchemaEditor # NOQA isort:skip
  64. def _get_varchar_column(data):
  65. if data["max_length"] is None:
  66. return "varchar"
  67. return "varchar(%(max_length)s)" % data
  68. class DatabaseWrapper(BaseDatabaseWrapper):
  69. vendor = "postgresql"
  70. display_name = "PostgreSQL"
  71. # This dictionary maps Field objects to their associated PostgreSQL column
  72. # types, as strings. Column-type strings can contain format strings; they'll
  73. # be interpolated against the values of Field.__dict__ before being output.
  74. # If a column type is set to None, it won't be included in the output.
  75. data_types = {
  76. "AutoField": "integer",
  77. "BigAutoField": "bigint",
  78. "BinaryField": "bytea",
  79. "BooleanField": "boolean",
  80. "CharField": _get_varchar_column,
  81. "DateField": "date",
  82. "DateTimeField": "timestamp with time zone",
  83. "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
  84. "DurationField": "interval",
  85. "FileField": "varchar(%(max_length)s)",
  86. "FilePathField": "varchar(%(max_length)s)",
  87. "FloatField": "double precision",
  88. "IntegerField": "integer",
  89. "BigIntegerField": "bigint",
  90. "IPAddressField": "inet",
  91. "GenericIPAddressField": "inet",
  92. "JSONField": "jsonb",
  93. "OneToOneField": "integer",
  94. "PositiveBigIntegerField": "bigint",
  95. "PositiveIntegerField": "integer",
  96. "PositiveSmallIntegerField": "smallint",
  97. "SlugField": "varchar(%(max_length)s)",
  98. "SmallAutoField": "smallint",
  99. "SmallIntegerField": "smallint",
  100. "TextField": "text",
  101. "TimeField": "time",
  102. "UUIDField": "uuid",
  103. }
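    # For illustration only (example values, not defaults): with the templates
    # above, a CharField with max_length=100 renders as "varchar(100)" via
    # _get_varchar_column (or plain "varchar" when max_length is None), and a
    # DecimalField with max_digits=5 and decimal_places=2 renders as
    # "numeric(5, 2)".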
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    data_types_suffix = {
        "AutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
    }
    operators = {
        "exact": "= %s",
        "iexact": "= UPPER(%s)",
        "contains": "LIKE %s",
        "icontains": "LIKE UPPER(%s)",
        "regex": "~ %s",
        "iregex": "~* %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s",
        "endswith": "LIKE %s",
        "istartswith": "LIKE UPPER(%s)",
        "iendswith": "LIKE UPPER(%s)",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = (
        r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
    )
    pattern_ops = {
        "contains": "LIKE '%%' || {} || '%%'",
        "icontains": "LIKE '%%' || UPPER({}) || '%%'",
        "startswith": "LIKE {} || '%%'",
        "istartswith": "LIKE UPPER({}) || '%%'",
        "endswith": "LIKE '%%' || {}",
        "iendswith": "LIKE '%%' || UPPER({})",
    }
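    # For illustration only: for a "contains" lookup whose right-hand side is an
    # expression rather than a raw string, pattern_esc and pattern_ops combine
    # into SQL along the lines of (once '%%' is reduced to '%'):
    #
    #   LIKE '%' || REPLACE(REPLACE(REPLACE("some_table"."some_col",
    #       E'\\', E'\\\\'), E'%', E'\\%'), E'_', E'\\_') || '%'
    #
    # where "some_table"."some_col" is a hypothetical column reference.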

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    # PostgreSQL backend-specific attributes.
    _named_cursor_idx = 0

    def get_database_version(self):
        """
        Return a tuple of the database's version.
        E.g. for pg_version 120004, return (12, 4).
        """
        return divmod(self.pg_version, 10000)

    def get_connection_params(self):
        settings_dict = self.settings_dict
        # None may be used to connect to the default 'postgres' db.
        if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
            "service"
        ):
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME or OPTIONS['service'] value."
            )
        if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
            raise ImproperlyConfigured(
                "The database name '%s' (%d characters) is longer than "
                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                "in settings.DATABASES."
                % (
                    settings_dict["NAME"],
                    len(settings_dict["NAME"]),
                    self.ops.max_name_length(),
                )
            )
        if settings_dict["NAME"]:
            conn_params = {
                "dbname": settings_dict["NAME"],
                **settings_dict["OPTIONS"],
            }
        elif settings_dict["NAME"] is None:
            # Connect to the default 'postgres' db.
            settings_dict.get("OPTIONS", {}).pop("service", None)
            conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]}
        else:
            conn_params = {**settings_dict["OPTIONS"]}
        conn_params["client_encoding"] = "UTF8"
        conn_params.pop("assume_role", None)
        conn_params.pop("isolation_level", None)
        server_side_binding = conn_params.pop("server_side_binding", None)
        conn_params.setdefault(
            "cursor_factory",
            ServerBindingCursor
            if is_psycopg3 and server_side_binding is True
            else Cursor,
        )
        if settings_dict["USER"]:
            conn_params["user"] = settings_dict["USER"]
        if settings_dict["PASSWORD"]:
            conn_params["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"]:
            conn_params["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            conn_params["port"] = settings_dict["PORT"]
        if is_psycopg3:
            conn_params["context"] = get_adapters_template(
                settings.USE_TZ, self.timezone
            )
            # Disable prepared statements by default to keep connection poolers
            # working. Can be reenabled via OPTIONS in the settings dict.
            conn_params["prepare_threshold"] = conn_params.pop(
                "prepare_threshold", None
            )
        return conn_params
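
    # A minimal sketch of how the backend-specific OPTIONS consumed above could
    # be supplied in settings.DATABASES (names such as "mydb" and
    # "readonly_role" are placeholders, not defaults):
    #
    #   DATABASES = {
    #       "default": {
    #           "ENGINE": "django.db.backends.postgresql",
    #           "NAME": "mydb",
    #           "OPTIONS": {
    #               "isolation_level": IsolationLevel.SERIALIZABLE,
    #               "assume_role": "readonly_role",
    #               "server_side_binding": True,  # psycopg 3 only
    #               "prepare_threshold": 5,  # re-enable prepared statements
    #           },
    #       },
    #   }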

    @async_unsafe
    def get_new_connection(self, conn_params):
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict["OPTIONS"]
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = self.Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        if not is_psycopg3:
            # Register dummy loads() to avoid a round trip from psycopg2's
            # decode to json.dumps() to json.loads(), when using a custom
            # decoder in JSONField.
            psycopg2.extras.register_default_jsonb(
                conn_or_curs=connection, loads=lambda x: x
            )
        return connection

    def ensure_timezone(self):
        if self.connection is None:
            return False
        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
        timezone_name = self.timezone_name
        if timezone_name and conn_timezone_name != timezone_name:
            with self.connection.cursor() as cursor:
                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
            return True
        return False

    def ensure_role(self):
        if self.connection is None:
            return False
        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
            with self.connection.cursor() as cursor:
                sql = self.ops.compose_sql("SET ROLE %s", [new_role])
                cursor.execute(sql)
            return True
        return False

    def init_connection_state(self):
        super().init_connection_state()

        # Commit after setting the time zone.
        commit_tz = self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential used
        # to log in is not the same as the role that owns database resources, as
        # can be the case when using temporary or ephemeral credentials.
        commit_role = self.ensure_role()

        if (commit_role or commit_tz) and not self.get_autocommit():
            self.connection.commit()

    @async_unsafe
    def create_cursor(self, name=None):
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(
                name, scrollable=False, withhold=self.connection.autocommit
            )
        else:
            cursor = self.connection.cursor()

        if is_psycopg3:
            # Register the cursor timezone only if the connection disagrees, to
            # avoid copying the adapter map.
            tzloader = self.connection.adapters.get_loader(
                TIMESTAMPTZ_OID, Format.TEXT
            )
            if self.timezone != tzloader.timezone:
                register_tzloader(self.timezone, cursor)
        else:
            cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
        return cursor

    def tzinfo_factory(self, offset):
        return self.timezone

    @async_unsafe
    def chunked_cursor(self):
        self._named_cursor_idx += 1
        # Get the current async task.
        # Note that right now this is behind @async_unsafe, so this is
        # unreachable, but in the future we'll start loosening this restriction.
        # For now, it's here so that every use of "threading" is
        # also async-compatible.
        try:
            current_task = asyncio.current_task()
        except RuntimeError:
            current_task = None
        # The current task can be None even if the current_task call didn't error.
        if current_task:
            task_ident = str(id(current_task))
        else:
            task_ident = "sync"
        # Use that and the thread ident to get a unique name.
        return self._cursor(
            name="_django_curs_%d_%s_%d"
            % (
                # Avoid reusing name in other threads / tasks
                threading.current_thread().ident,
                task_ident,
                self._named_cursor_idx,
            )
        )
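
    # For instance, the generated server-side cursor name has the form
    # "_django_curs_<thread ident>_<task ident>_<counter>", e.g.
    # "_django_curs_140012345678912_sync_1" (the numbers are illustrative).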

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.
        """
        with self.cursor() as cursor:
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")

    def is_usable(self):
        try:
            # Use a psycopg cursor directly, bypassing Django's utilities.
            with self.connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except Database.Error:
            return False
        else:
            return True

    @contextmanager
    def _nodb_cursor(self):
        cursor = None
        try:
            with super()._nodb_cursor() as cursor:
                yield cursor
        except (Database.DatabaseError, WrappedDatabaseError):
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Django will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Django was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            for connection in connections.all():
                if (
                    connection.vendor == "postgresql"
                    and connection.settings_dict["NAME"] != "postgres"
                ):
                    conn = self.__class__(
                        {
                            **self.settings_dict,
                            "NAME": connection.settings_dict["NAME"],
                        },
                        alias=self.alias,
                    )
                    try:
                        with conn.cursor() as cursor:
                            yield cursor
                    finally:
                        conn.close()
                    break
            else:
                raise

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return self.connection.info.server_version

    def make_debug_cursor(self, cursor):
        return CursorDebugWrapper(cursor, self)


if is_psycopg3:

    class CursorMixin:
        """
        A subclass of psycopg cursor implementing callproc.
        """

        def callproc(self, name, args=None):
            if not isinstance(name, sql.Identifier):
                name = sql.Identifier(name)

            qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
            if args:
                for item in args:
                    qparts.append(sql.Literal(item))
                    qparts.append(sql.SQL(","))
                del qparts[-1]
            qparts.append(sql.SQL(")"))

            stmt = sql.Composed(qparts)
            self.execute(stmt)
            return args
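
        # For illustration: callproc("my_func", [10, "x"]) would compose and
        # execute roughly SELECT * FROM "my_func"(10, 'x'); "my_func" and its
        # arguments are hypothetical examples.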

    class ServerBindingCursor(CursorMixin, Database.Cursor):
        pass

    class Cursor(CursorMixin, Database.ClientCursor):
        pass

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        def copy(self, statement):
            with self.debug_sql(statement):
                return self.cursor.copy(statement)

else:
    Cursor = psycopg2.extensions.cursor

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        def copy_expert(self, sql, file, *args):
            with self.debug_sql(sql):
                return self.cursor.copy_expert(sql, file, *args)

        def copy_to(self, file, table, *args, **kwargs):
            with self.debug_sql(sql="COPY %s TO STDOUT" % table):
                return self.cursor.copy_to(file, table, *args, **kwargs)