123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487 |
- """
- PostgreSQL database backend for Django.
- Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
- """
- import asyncio
- import threading
- import warnings
- from contextlib import contextmanager
- from django.conf import settings
- from django.core.exceptions import ImproperlyConfigured
- from django.db import DatabaseError as WrappedDatabaseError
- from django.db import connections
- from django.db.backends.base.base import BaseDatabaseWrapper
- from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
- from django.utils.asyncio import async_unsafe
- from django.utils.functional import cached_property
- from django.utils.safestring import SafeString
- from django.utils.version import get_version_tuple
- try:
- try:
- import psycopg as Database
- except ImportError:
- import psycopg2 as Database
- except ImportError:
- raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
def psycopg_version():
    """Return the installed driver's version as a tuple of ints.

    Works for both psycopg2 ("2.9.6 (dt dec pq3 ext lt64)") and psycopg 3
    ("3.1.8"): only the leading version token is parsed.
    """
    raw_version, _, _ = Database.__version__.partition(" ")
    return get_version_tuple(raw_version)
# Enforce the minimum supported driver versions at import time so a
# misconfigured environment fails loudly rather than with obscure errors.
if psycopg_version() < (2, 8, 4):
    raise ImproperlyConfigured(
        f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
    )
# A psycopg 3.x install older than 3.1.8 is also rejected.
if (3,) <= psycopg_version() < (3, 1, 8):
    raise ImproperlyConfigured(
        f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}"
    )
# Imported after the version checks above: psycopg_any dispatches on which
# driver is actually installed.
from .psycopg_any import IsolationLevel, is_psycopg3  # NOQA isort:skip

if is_psycopg3:
    from psycopg import adapters, sql
    from psycopg.pq import Format

    from .psycopg_any import get_adapters_template, register_tzloader

    # OID of the "timestamptz" type, used to look up the timezone-aware
    # loader when creating cursors.
    TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid

else:
    import psycopg2.extensions
    import psycopg2.extras

    # Make SafeString adapt like a plain quoted string, and enable native
    # UUID round-tripping.
    psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
    psycopg2.extras.register_uuid()

    # Register support for inet[] manually so we don't have to handle the Inet()
    # object on load all the time.
    INETARRAY_OID = 1041
    INETARRAY = psycopg2.extensions.new_array_type(
        (INETARRAY_OID,),
        "INETARRAY",
        psycopg2.extensions.UNICODE,
    )
    psycopg2.extensions.register_type(INETARRAY)
- # Some of these import psycopg, so import them after checking if it's installed.
- from .client import DatabaseClient # NOQA isort:skip
- from .creation import DatabaseCreation # NOQA isort:skip
- from .features import DatabaseFeatures # NOQA isort:skip
- from .introspection import DatabaseIntrospection # NOQA isort:skip
- from .operations import DatabaseOperations # NOQA isort:skip
- from .schema import DatabaseSchemaEditor # NOQA isort:skip
- def _get_varchar_column(data):
- if data["max_length"] is None:
- return "varchar"
- return "varchar(%(max_length)s)" % data
class DatabaseWrapper(BaseDatabaseWrapper):
    """PostgreSQL backend wrapper supporting both psycopg2 and psycopg 3."""

    vendor = "postgresql"
    display_name = "PostgreSQL"
    # This dictionary maps Field objects to their associated PostgreSQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "AutoField": "integer",
        "BigAutoField": "bigint",
        "BinaryField": "bytea",
        "BooleanField": "boolean",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "timestamp with time zone",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "interval",
        "FileField": "varchar(%(max_length)s)",
        "FilePathField": "varchar(%(max_length)s)",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "inet",
        "GenericIPAddressField": "inet",
        "JSONField": "jsonb",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint",
        "PositiveIntegerField": "integer",
        "PositiveSmallIntegerField": "smallint",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "smallint",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "uuid",
    }
    # CHECK constraints enforcing non-negativity for the "Positive*" fields,
    # since PostgreSQL has no unsigned integer types.
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    # Column-type suffixes for auto-incrementing (identity) columns.
    data_types_suffix = {
        "AutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
    }
    # SQL fragments for each lookup type; case-insensitive variants rely on
    # the ORM upper-casing the left-hand side to match UPPER(%s) here.
    operators = {
        "exact": "= %s",
        "iexact": "= UPPER(%s)",
        "contains": "LIKE %s",
        "icontains": "LIKE UPPER(%s)",
        "regex": "~ %s",
        "iregex": "~* %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s",
        "endswith": "LIKE %s",
        "istartswith": "LIKE UPPER(%s)",
        "iendswith": "LIKE UPPER(%s)",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = (
        r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
    )
    pattern_ops = {
        "contains": "LIKE '%%' || {} || '%%'",
        "icontains": "LIKE '%%' || UPPER({}) || '%%'",
        "startswith": "LIKE {} || '%%'",
        "istartswith": "LIKE UPPER({}) || '%%'",
        "endswith": "LIKE '%%' || {}",
        "iendswith": "LIKE '%%' || UPPER({})",
    }

    # Expose the driver module on the wrapper (BaseDatabaseWrapper contract).
    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    # PostgreSQL backend-specific attributes.
    # Monotonic counter used to build unique server-side cursor names.
    _named_cursor_idx = 0

    def get_database_version(self):
        """
        Return a tuple of the database's version.
        E.g. for pg_version 120004, return (12, 4).
        """
        return divmod(self.pg_version, 10000)

    def get_connection_params(self):
        """Build the keyword arguments passed to Database.connect()."""
        settings_dict = self.settings_dict
        # None may be used to connect to the default 'postgres' db
        if settings_dict["NAME"] == "" and not settings_dict.get("OPTIONS", {}).get(
            "service"
        ):
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME or OPTIONS['service'] value."
            )
        if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
            raise ImproperlyConfigured(
                "The database name '%s' (%d characters) is longer than "
                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                "in settings.DATABASES."
                % (
                    settings_dict["NAME"],
                    len(settings_dict["NAME"]),
                    self.ops.max_name_length(),
                )
            )
        if settings_dict["NAME"]:
            conn_params = {
                "dbname": settings_dict["NAME"],
                **settings_dict["OPTIONS"],
            }
        elif settings_dict["NAME"] is None:
            # Connect to the default 'postgres' db.
            settings_dict.get("OPTIONS", {}).pop("service", None)
            conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]}
        else:
            conn_params = {**settings_dict["OPTIONS"]}

        conn_params["client_encoding"] = "UTF8"
        # These OPTIONS keys are consumed by Django, not by the driver.
        conn_params.pop("assume_role", None)
        conn_params.pop("isolation_level", None)
        server_side_binding = conn_params.pop("server_side_binding", None)
        conn_params.setdefault(
            "cursor_factory",
            ServerBindingCursor
            if is_psycopg3 and server_side_binding is True
            else Cursor,
        )
        if settings_dict["USER"]:
            conn_params["user"] = settings_dict["USER"]
        if settings_dict["PASSWORD"]:
            conn_params["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"]:
            conn_params["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            conn_params["port"] = settings_dict["PORT"]
        if is_psycopg3:
            conn_params["context"] = get_adapters_template(
                settings.USE_TZ, self.timezone
            )
            # Disable prepared statements by default to keep connection poolers
            # working. Can be reenabled via OPTIONS in the settings dict.
            conn_params["prepare_threshold"] = conn_params.pop(
                "prepare_threshold", None
            )
        return conn_params

    @async_unsafe
    def get_new_connection(self, conn_params):
        """Open and return a new driver connection using conn_params."""
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict["OPTIONS"]
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        connection = self.Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        if not is_psycopg3:
            # Register dummy loads() to avoid a round trip from psycopg2's
            # decode to json.dumps() to json.loads(), when using a custom
            # decoder in JSONField.
            psycopg2.extras.register_default_jsonb(
                conn_or_curs=connection, loads=lambda x: x
            )
        return connection

    def ensure_timezone(self):
        """
        Set the connection's time zone to match self.timezone_name when they
        disagree. Return True if a SET was issued (caller may need to commit).
        """
        if self.connection is None:
            return False
        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
        timezone_name = self.timezone_name
        if timezone_name and conn_timezone_name != timezone_name:
            with self.connection.cursor() as cursor:
                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
            return True
        return False

    def ensure_role(self):
        """
        SET ROLE from OPTIONS['assume_role'] if configured. Return True if a
        SET was issued (caller may need to commit).
        """
        if self.connection is None:
            return False
        if new_role := self.settings_dict.get("OPTIONS", {}).get("assume_role"):
            with self.connection.cursor() as cursor:
                sql = self.ops.compose_sql("SET ROLE %s", [new_role])
                cursor.execute(sql)
            return True
        return False

    def init_connection_state(self):
        """Apply time zone and role settings to a freshly opened connection."""
        super().init_connection_state()

        # Commit after setting the time zone.
        commit_tz = self.ensure_timezone()
        # Set the role on the connection. This is useful if the credential used
        # to login is not the same as the role that owns database resources. As
        # can be the case when using temporary or ephemeral credentials.
        commit_role = self.ensure_role()

        if (commit_role or commit_tz) and not self.get_autocommit():
            self.connection.commit()

    @async_unsafe
    def create_cursor(self, name=None):
        """Create a cursor; a name selects a server-side (named) cursor."""
        if name:
            # In autocommit mode, the cursor will be used outside of a
            # transaction, hence use a holdable cursor.
            cursor = self.connection.cursor(
                name, scrollable=False, withhold=self.connection.autocommit
            )
        else:
            cursor = self.connection.cursor()

        if is_psycopg3:
            # Register the cursor timezone only if the connection disagrees, to
            # avoid copying the adapter map.
            tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
            if self.timezone != tzloader.timezone:
                register_tzloader(self.timezone, cursor)
        else:
            cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
        return cursor

    def tzinfo_factory(self, offset):
        # psycopg2 hook: attach the connection's timezone to loaded
        # timestamps, ignoring the server-reported offset.
        return self.timezone

    @async_unsafe
    def chunked_cursor(self):
        """Return a uniquely named server-side cursor for chunked reads."""
        self._named_cursor_idx += 1
        # Get the current async task
        # Note that right now this is behind @async_unsafe, so this is
        # unreachable, but in future we'll start loosening this restriction.
        # For now, it's here so that every use of "threading" is
        # also async-compatible.
        try:
            current_task = asyncio.current_task()
        except RuntimeError:
            current_task = None
        # Current task can be none even if the current_task call didn't error
        if current_task:
            task_ident = str(id(current_task))
        else:
            task_ident = "sync"
        # Use that and the thread ident to get a unique name
        return self._cursor(
            name="_django_curs_%d_%s_%d"
            % (
                # Avoid reusing name in other threads / tasks
                threading.current_thread().ident,
                task_ident,
                self._named_cursor_idx,
            )
        )

    def _set_autocommit(self, autocommit):
        # Delegate to the driver; wrap_database_errors maps driver exceptions
        # to Django's database exception hierarchy.
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.
        """
        with self.cursor() as cursor:
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")

    def is_usable(self):
        """Return True if a trivial query succeeds on the connection."""
        try:
            # Use a psycopg cursor directly, bypassing Django's utilities.
            with self.connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except Database.Error:
            return False
        else:
            return True

    @contextmanager
    def _nodb_cursor(self):
        """
        Yield a cursor on a connection to the default 'postgres' database,
        falling back to another configured PostgreSQL database (with a
        warning) when 'postgres' isn't reachable.
        """
        cursor = None
        try:
            with super()._nodb_cursor() as cursor:
                yield cursor
        except (Database.DatabaseError, WrappedDatabaseError):
            # If a cursor was obtained, the failure happened while using it;
            # re-raise rather than retrying on another database.
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Django will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Django was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            for connection in connections.all():
                if (
                    connection.vendor == "postgresql"
                    and connection.settings_dict["NAME"] != "postgres"
                ):
                    conn = self.__class__(
                        {
                            **self.settings_dict,
                            "NAME": connection.settings_dict["NAME"],
                        },
                        alias=self.alias,
                    )
                    try:
                        with conn.cursor() as cursor:
                            yield cursor
                    finally:
                        conn.close()
                    break
            else:
                # No other PostgreSQL database configured: propagate the
                # original connection error.
                raise

    @cached_property
    def pg_version(self):
        # Server version as an int, e.g. 120004 for 12.4; cached for the
        # wrapper's lifetime.
        with self.temporary_connection():
            return self.connection.info.server_version

    def make_debug_cursor(self, cursor):
        """Wrap a cursor in the backend-specific debug/logging wrapper."""
        return CursorDebugWrapper(cursor, self)
# Define the cursor classes used by get_connection_params()'s
# "cursor_factory" and by make_debug_cursor(), per installed driver.
if is_psycopg3:

    class CursorMixin:
        """
        A subclass of psycopg cursor implementing callproc.
        """

        def callproc(self, name, args=None):
            # Compose "SELECT * FROM name(arg1, arg2, ...)" with psycopg's
            # sql module so the identifier and literals are safely quoted.
            if not isinstance(name, sql.Identifier):
                name = sql.Identifier(name)

            qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
            if args:
                for item in args:
                    qparts.append(sql.Literal(item))
                    qparts.append(sql.SQL(","))
                # Drop the trailing comma appended after the last argument.
                del qparts[-1]

            qparts.append(sql.SQL(")"))
            stmt = sql.Composed(qparts)
            self.execute(stmt)
            return args

    class ServerBindingCursor(CursorMixin, Database.Cursor):
        # Server-side parameter binding (psycopg 3 native cursor).
        pass

    class Cursor(CursorMixin, Database.ClientCursor):
        # Client-side parameter binding, matching psycopg2 behavior.
        pass

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        """Debug wrapper logging COPY statements alongside regular queries."""

        def copy(self, statement):
            with self.debug_sql(statement):
                return self.cursor.copy(statement)

else:
    Cursor = psycopg2.extensions.cursor

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        """Debug wrapper logging psycopg2's COPY helpers."""

        def copy_expert(self, sql, file, *args):
            with self.debug_sql(sql):
                return self.cursor.copy_expert(sql, file, *args)

        def copy_to(self, file, table, *args, **kwargs):
            with self.debug_sql(sql="COPY %s TO STDOUT" % table):
                return self.cursor.copy_to(file, table, *args, **kwargs)
|