# introspection.py — SQLite introspection backend.
  1. from collections import namedtuple
  2. import sqlparse
  3. from django.db import DatabaseError
  4. from django.db.backends.base.introspection import BaseDatabaseIntrospection
  5. from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
  6. from django.db.backends.base.introspection import TableInfo
  7. from django.db.models import Index
  8. from django.utils.regex_helper import _lazy_re_compile
# Extend the base FieldInfo tuple with two SQLite-specific flags:
# "pk" — the column is (part of) the primary key (pragma pk flag == 1),
# "has_json_constraint" — the table DDL has a json_valid() CHECK on it.
FieldInfo = namedtuple(
    "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
)

# Matches "char(n)" / "varchar(n)" type names, capturing the size n.
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")
  13. def get_field_size(name):
  14. """Extract the size number from a "varchar(11)" type name"""
  15. m = field_size_re.search(name)
  16. return int(m[1]) if m else None
  17. # This light wrapper "fakes" a dictionary interface, because some SQLite data
  18. # types include variables in them -- e.g. "varchar(30)" -- and can't be matched
  19. # as a simple dictionary lookup.
  20. class FlexibleFieldLookupDict:
  21. # Maps SQL types to Django Field types. Some of the SQL types have multiple
  22. # entries here because SQLite allows for anything and doesn't normalize the
  23. # field type; it uses whatever was given.
  24. base_data_types_reverse = {
  25. "bool": "BooleanField",
  26. "boolean": "BooleanField",
  27. "smallint": "SmallIntegerField",
  28. "smallint unsigned": "PositiveSmallIntegerField",
  29. "smallinteger": "SmallIntegerField",
  30. "int": "IntegerField",
  31. "integer": "IntegerField",
  32. "bigint": "BigIntegerField",
  33. "integer unsigned": "PositiveIntegerField",
  34. "bigint unsigned": "PositiveBigIntegerField",
  35. "decimal": "DecimalField",
  36. "real": "FloatField",
  37. "text": "TextField",
  38. "char": "CharField",
  39. "varchar": "CharField",
  40. "blob": "BinaryField",
  41. "date": "DateField",
  42. "datetime": "DateTimeField",
  43. "time": "TimeField",
  44. }
  45. def __getitem__(self, key):
  46. key = key.lower().split("(", 1)[0].strip()
  47. return self.base_data_types_reverse[key]
class DatabaseIntrospection(BaseDatabaseIntrospection):
    # SQLite type names are free-form, so use the flexible lookup wrapper
    # (handles e.g. "varchar(30)") instead of a plain dict.
    data_types_reverse = FlexibleFieldLookupDict()
  50. def get_field_type(self, data_type, description):
  51. field_type = super().get_field_type(data_type, description)
  52. if description.pk and field_type in {
  53. "BigIntegerField",
  54. "IntegerField",
  55. "SmallIntegerField",
  56. }:
  57. # No support for BigAutoField or SmallAutoField as SQLite treats
  58. # all integer primary keys as signed 64-bit integers.
  59. return "AutoField"
  60. if description.has_json_constraint:
  61. return "JSONField"
  62. return field_type
  63. def get_table_list(self, cursor):
  64. """Return a list of table and view names in the current database."""
  65. # Skip the sqlite_sequence system table used for autoincrement key
  66. # generation.
  67. cursor.execute(
  68. """
  69. SELECT name, type FROM sqlite_master
  70. WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
  71. ORDER BY name"""
  72. )
  73. return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]
  74. def get_table_description(self, cursor, table_name):
  75. """
  76. Return a description of the table with the DB-API cursor.description
  77. interface.
  78. """
  79. cursor.execute(
  80. "PRAGMA table_xinfo(%s)" % self.connection.ops.quote_name(table_name)
  81. )
  82. table_info = cursor.fetchall()
  83. if not table_info:
  84. raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
  85. collations = self._get_column_collations(cursor, table_name)
  86. json_columns = set()
  87. if self.connection.features.can_introspect_json_field:
  88. for line in table_info:
  89. column = line[1]
  90. json_constraint_sql = '%%json_valid("%s")%%' % column
  91. has_json_constraint = cursor.execute(
  92. """
  93. SELECT sql
  94. FROM sqlite_master
  95. WHERE
  96. type = 'table' AND
  97. name = %s AND
  98. sql LIKE %s
  99. """,
  100. [table_name, json_constraint_sql],
  101. ).fetchone()
  102. if has_json_constraint:
  103. json_columns.add(column)
  104. return [
  105. FieldInfo(
  106. name,
  107. data_type,
  108. get_field_size(data_type),
  109. None,
  110. None,
  111. None,
  112. not notnull,
  113. default,
  114. collations.get(name),
  115. pk == 1,
  116. name in json_columns,
  117. )
  118. for cid, name, data_type, notnull, default, pk, hidden in table_info
  119. if hidden
  120. in [
  121. 0, # Normal column.
  122. 2, # Virtual generated column.
  123. 3, # Stored generated column.
  124. ]
  125. ]
  126. def get_sequences(self, cursor, table_name, table_fields=()):
  127. pk_col = self.get_primary_key_column(cursor, table_name)
  128. return [{"table": table_name, "column": pk_col}]
  129. def get_relations(self, cursor, table_name):
  130. """
  131. Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
  132. representing all foreign keys in the given table.
  133. """
  134. cursor.execute(
  135. "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
  136. )
  137. return {
  138. column_name: (ref_column_name, ref_table_name)
  139. for (
  140. _,
  141. _,
  142. ref_table_name,
  143. column_name,
  144. ref_column_name,
  145. *_,
  146. ) in cursor.fetchall()
  147. }
  148. def get_primary_key_columns(self, cursor, table_name):
  149. cursor.execute(
  150. "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
  151. )
  152. return [name for _, name, *_, pk in cursor.fetchall() if pk]
    def _parse_column_or_constraint_definition(self, tokens, columns):
        """
        Consume one column definition or one table-level constraint definition
        from the flattened sqlparse token stream ``tokens``.

        ``columns`` is the set of known column names, used to filter the
        identifiers collected from CHECK expressions.

        Return a 4-tuple (constraint_name, unique_constraint, check_constraint,
        last_token), where the two constraint values are dicts in the
        get_constraints() format (or None), and last_token is the token that
        ended the definition ("," or the closing ")").
        """
        token = None
        is_constraint_definition = None
        field_name = None
        constraint_name = None
        unique = False
        unique_columns = []
        check = False
        check_columns = []
        # Parenthesis nesting depth relative to the start of this definition.
        braces_deep = 0
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                braces_deep += 1
            elif token.match(sqlparse.tokens.Punctuation, ")"):
                braces_deep -= 1
                if braces_deep < 0:
                    # End of columns and constraints for table definition.
                    break
            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
                # End of current column or constraint definition.
                break
            # Detect column or constraint definition by first token.
            if is_constraint_definition is None:
                is_constraint_definition = token.match(
                    sqlparse.tokens.Keyword, "CONSTRAINT"
                )
                if is_constraint_definition:
                    continue
            if is_constraint_definition:
                # Detect constraint name by second token.
                if constraint_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        constraint_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        # Quoted identifier: strip the surrounding quotes.
                        constraint_name = token.value[1:-1]
                # Start constraint columns parsing after UNIQUE keyword.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique = True
                    unique_braces_deep = braces_deep
                elif unique:
                    if unique_braces_deep == braces_deep:
                        if unique_columns:
                            # Stop constraint parsing.
                            unique = False
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        unique_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        unique_columns.append(token.value[1:-1])
            else:
                # Detect field name by first token.
                if field_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        field_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        field_name = token.value[1:-1]
                # An inline "UNIQUE" on a column definition constrains just
                # that column.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique_columns = [field_name]
            # Start constraint columns parsing after CHECK keyword.
            if token.match(sqlparse.tokens.Keyword, "CHECK"):
                check = True
                check_braces_deep = braces_deep
            elif check:
                if check_braces_deep == braces_deep:
                    if check_columns:
                        # Stop constraint parsing.
                        check = False
                    continue
                # Only keep identifiers that name actual table columns.
                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                    if token.value in columns:
                        check_columns.append(token.value)
                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                    if token.value[1:-1] in columns:
                        check_columns.append(token.value[1:-1])
        unique_constraint = (
            {
                "unique": True,
                "columns": unique_columns,
                "primary_key": False,
                "foreign_key": None,
                "check": False,
                "index": False,
            }
            if unique_columns
            else None
        )
        check_constraint = (
            {
                "check": True,
                "columns": check_columns,
                "primary_key": False,
                "unique": False,
                "foreign_key": None,
                "index": False,
            }
            if check_columns
            else None
        )
        return constraint_name, unique_constraint, check_constraint, token
    def _parse_table_constraints(self, sql, columns):
        """
        Parse UNIQUE and CHECK constraints out of a CREATE TABLE statement
        ``sql`` and return them as a {name: constraint_dict} mapping.
        Unnamed constraints get synthetic "__unnamed_constraint_N__" keys.
        """
        # Check constraint parsing is based on the SQLite syntax diagram.
        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
        statement = sqlparse.parse(sql)[0]
        constraints = {}
        unnamed_constrains_index = 0
        # A single generator shared with the helper below, so each call
        # resumes where the previous definition ended.
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        # Go to columns and constraint definition
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                break
        # Parse columns and constraint definition
        while True:
            (
                constraint_name,
                unique,
                check,
                end_token,
            ) = self._parse_column_or_constraint_definition(tokens, columns)
            if unique:
                if constraint_name:
                    constraints[constraint_name] = unique
                else:
                    unnamed_constrains_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constrains_index
                    ] = unique
            if check:
                if constraint_name:
                    constraints[constraint_name] = check
                else:
                    unnamed_constrains_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constrains_index
                    ] = check
            # The closing ")" of the column list ends the statement.
            if end_token.match(sqlparse.tokens.Punctuation, ")"):
                break
        return constraints
    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.
        """
        constraints = {}
        # Find inline check constraints.
        try:
            table_schema = cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' and name=%s"
                % (self.connection.ops.quote_name(table_name),)
            ).fetchone()[0]
        except TypeError:
            # table_name is a view (fetchone() returned None; None[0] raises
            # TypeError).
            pass
        else:
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            constraints.update(self._parse_table_constraints(table_schema, columns))
        # Get the index info
        cursor.execute(
            "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        for row in cursor.fetchall():
            # SQLite 3.8.9+ has 5 columns, however older versions only give 3
            # columns. Discard last 2 columns if there.
            number, index, unique = row[:3]
            cursor.execute(
                "SELECT sql FROM sqlite_master "
                "WHERE type='index' AND name=%s" % self.connection.ops.quote_name(index)
            )
            # There's at most one row.
            (sql,) = cursor.fetchone() or (None,)
            # Inline constraints are already detected in
            # _parse_table_constraints(). The reasons to avoid fetching inline
            # constraints from `PRAGMA index_list` are:
            # - Inline constraints can have a different name and information
            #   than what `PRAGMA index_list` gives.
            # - Not all inline constraints may appear in `PRAGMA index_list`.
            if not sql:
                # An inline constraint
                continue
            # Get the index info for that index
            cursor.execute(
                "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
            )
            for index_rank, column_rank, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": None,
                        "check": False,
                        "index": True,
                    }
                constraints[index]["columns"].append(column)
            # Add type and column orders for indexes
            # NOTE(review): if index_info returned no rows for this index
            # (e.g. an expression-only index), constraints[index] may be
            # missing here and this lookup would raise KeyError — confirm
            # whether that case can occur in practice.
            if constraints[index]["index"]:
                # SQLite doesn't support any index type other than b-tree
                constraints[index]["type"] = Index.suffix
                orders = self._get_index_columns_orders(sql)
                if orders is not None:
                    constraints[index]["orders"] = orders
        # Get the PK
        pk_columns = self.get_primary_key_columns(cursor, table_name)
        if pk_columns:
            # SQLite doesn't actually give a name to the PK constraint,
            # so we invent one. This is fine, as the SQLite backend never
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": pk_columns,
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": None,
                "check": False,
                "index": False,
            }
        # Synthesize fk_N names for foreign keys, which SQLite leaves unnamed.
        relations = enumerate(self.get_relations(cursor, table_name).items())
        constraints.update(
            {
                f"fk_{index}": {
                    "columns": [column_name],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": (ref_table_name, ref_column_name),
                    "check": False,
                    "index": False,
                }
                for index, (column_name, (ref_column_name, ref_table_name)) in relations
            }
        )
        return constraints
  385. def _get_index_columns_orders(self, sql):
  386. tokens = sqlparse.parse(sql)[0]
  387. for token in tokens:
  388. if isinstance(token, sqlparse.sql.Parenthesis):
  389. columns = str(token).strip("()").split(", ")
  390. return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
  391. return None
  392. def _get_column_collations(self, cursor, table_name):
  393. row = cursor.execute(
  394. """
  395. SELECT sql
  396. FROM sqlite_master
  397. WHERE type = 'table' AND name = %s
  398. """,
  399. [table_name],
  400. ).fetchone()
  401. if not row:
  402. return {}
  403. sql = row[0]
  404. columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
  405. collations = {}
  406. for column in columns:
  407. tokens = column[1:].split()
  408. column_name = tokens[0].strip('"')
  409. for index, token in enumerate(tokens):
  410. if token == "COLLATE":
  411. collation = tokens[index + 1]
  412. break
  413. else:
  414. collation = None
  415. collations[column_name] = collation
  416. return collations