init commit
63
.venv/lib/python3.10/site-packages/django/db/__init__.py
Normal file
@@ -0,0 +1,63 @@
from django.core import signals
from django.db.utils import (
    DEFAULT_DB_ALIAS,
    DJANGO_VERSION_PICKLE_KEY,
    ConnectionHandler,
    ConnectionRouter,
    DatabaseError,
    DataError,
    Error,
    IntegrityError,
    InterfaceError,
    InternalError,
    NotSupportedError,
    OperationalError,
    ProgrammingError,
)
from django.utils.connection import ConnectionProxy

__all__ = [
    "close_old_connections",
    "connection",
    "connections",
    "reset_queries",
    "router",
    "DatabaseError",
    "IntegrityError",
    "InternalError",
    "ProgrammingError",
    "DataError",
    "NotSupportedError",
    "Error",
    "InterfaceError",
    "OperationalError",
    "DEFAULT_DB_ALIAS",
    "DJANGO_VERSION_PICKLE_KEY",
]

connections = ConnectionHandler()

router = ConnectionRouter()

# For backwards compatibility. Prefer connections['default'] instead.
connection = ConnectionProxy(connections, DEFAULT_DB_ALIAS)


# Register an event to reset saved queries when a Django request is started.
def reset_queries(**kwargs):
    for conn in connections.all(initialized_only=True):
        conn.queries_log.clear()


signals.request_started.connect(reset_queries)


# Register an event to reset transaction state and close connections past
# their lifetime.
def close_old_connections(**kwargs):
    for conn in connections.all(initialized_only=True):
        conn.close_if_unusable_or_obsolete()


signals.request_started.connect(close_old_connections)
signals.request_finished.connect(close_old_connections)
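A quick sketch of how this module is consumed, assuming a configured settings module; the table name below is hypothetical:

    from django.db import connection, reset_queries

    def count_users():
        # The proxy resolves to connections["default"] on first use.
        with connection.cursor() as cursor:
            cursor.execute("SELECT COUNT(*) FROM auth_user")  # hypothetical table
            return cursor.fetchone()[0]

    # reset_queries() normally fires via the request_started signal, but a
    # long-running worker can call it manually to clear queries_log.
    reset_queries()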
792
.venv/lib/python3.10/site-packages/django/db/backends/base/base.py
Normal file
@@ -0,0 +1,792 @@
import _thread
import copy
import datetime
import logging
import threading
import time
import warnings
import zoneinfo
from collections import deque
from contextlib import contextmanager

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DEFAULT_DB_ALIAS, DatabaseError, NotSupportedError
from django.db.backends import utils
from django.db.backends.base.validation import BaseDatabaseValidation
from django.db.backends.signals import connection_created
from django.db.backends.utils import debug_transaction
from django.db.transaction import TransactionManagementError
from django.db.utils import DatabaseErrorWrapper, ProgrammingError
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property

NO_DB_ALIAS = "__no_db__"
RAN_DB_VERSION_CHECK = set()

logger = logging.getLogger("django.db.backends.base")


class BaseDatabaseWrapper:
    """Represent a database connection."""

    # Mapping of Field objects to their column types.
    data_types = {}
    # Mapping of Field objects to their SQL suffix such as AUTOINCREMENT.
    data_types_suffix = {}
    # Mapping of Field objects to their SQL for CHECK constraints.
    data_type_check_constraints = {}
    ops = None
    vendor = "unknown"
    display_name = "unknown"
    SchemaEditorClass = None
    # Classes instantiated in __init__().
    client_class = None
    creation_class = None
    features_class = None
    introspection_class = None
    ops_class = None
    validation_class = BaseDatabaseValidation

    queries_limit = 9000

    def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS):
        # Connection related attributes.
        # The underlying database connection.
        self.connection = None
        # `settings_dict` should be a dictionary containing keys such as
        # NAME, USER, etc. It's called `settings_dict` instead of `settings`
        # to disambiguate it from Django settings modules.
        self.settings_dict = settings_dict
        self.alias = alias
        # Query logging in debug mode or when explicitly enabled.
        self.queries_log = deque(maxlen=self.queries_limit)
        self.force_debug_cursor = False

        # Transaction related attributes.
        # Tracks if the connection is in autocommit mode. Per PEP 249, by
        # default, it isn't.
        self.autocommit = False
        # Tracks if the connection is in a transaction managed by 'atomic'.
        self.in_atomic_block = False
        # Increment to generate unique savepoint ids.
        self.savepoint_state = 0
        # List of savepoints created by 'atomic'.
        self.savepoint_ids = []
        # Stack of active 'atomic' blocks.
        self.atomic_blocks = []
        # Tracks if the outermost 'atomic' block should commit on exit,
        # i.e. if autocommit was active on entry.
        self.commit_on_exit = True
        # Tracks if the transaction should be rolled back to the next
        # available savepoint because of an exception in an inner block.
        self.needs_rollback = False
        self.rollback_exc = None

        # Connection termination related attributes.
        self.close_at = None
        self.closed_in_transaction = False
        self.errors_occurred = False
        self.health_check_enabled = False
        self.health_check_done = False

        # Thread-safety related attributes.
        self._thread_sharing_lock = threading.Lock()
        self._thread_sharing_count = 0
        self._thread_ident = _thread.get_ident()

        # A list of no-argument functions to run when the transaction commits.
        # Each entry is an (sids, func, robust) tuple, where sids is a set of
        # the active savepoint IDs when this function was registered and robust
        # specifies whether it's allowed for the function to fail.
        self.run_on_commit = []

        # Should we run the on-commit hooks the next time set_autocommit(True)
        # is called?
        self.run_commit_hooks_on_set_autocommit_on = False

        # A stack of wrappers to be invoked around execute()/executemany()
        # calls. Each entry is a function taking five arguments: execute, sql,
        # params, many, and context. It's the function's responsibility to
        # call execute(sql, params, many, context).
        self.execute_wrappers = []

        self.client = self.client_class(self)
        self.creation = self.creation_class(self)
        self.features = self.features_class(self)
        self.introspection = self.introspection_class(self)
        self.ops = self.ops_class(self)
        self.validation = self.validation_class(self)

    def __repr__(self):
        return (
            f"<{self.__class__.__qualname__} "
            f"vendor={self.vendor!r} alias={self.alias!r}>"
        )

    def ensure_timezone(self):
        """
        Ensure the connection's timezone is set to `self.timezone_name` and
        return whether it changed or not.
        """
        return False

    @cached_property
    def timezone(self):
        """
        Return a tzinfo of the database connection time zone.

        This is only used when time zone support is enabled. When a datetime is
        read from the database, it is always returned in this time zone.

        When the database backend supports time zones, it doesn't matter which
        time zone Django uses, as long as aware datetimes are used everywhere.
        Other users connecting to the database can choose their own time zone.

        When the database backend doesn't support time zones, the time zone
        Django uses may be constrained by the requirements of other users of
        the database.
        """
        if not settings.USE_TZ:
            return None
        elif self.settings_dict["TIME_ZONE"] is None:
            return datetime.timezone.utc
        else:
            return zoneinfo.ZoneInfo(self.settings_dict["TIME_ZONE"])

    @cached_property
    def timezone_name(self):
        """
        Name of the time zone of the database connection.
        """
        if not settings.USE_TZ:
            return settings.TIME_ZONE
        elif self.settings_dict["TIME_ZONE"] is None:
            return "UTC"
        else:
            return self.settings_dict["TIME_ZONE"]

    @property
    def queries_logged(self):
        return self.force_debug_cursor or settings.DEBUG

    @property
    def queries(self):
        if len(self.queries_log) == self.queries_log.maxlen:
            warnings.warn(
                "Limit for query logging exceeded, only the last {} queries "
                "will be returned.".format(self.queries_log.maxlen),
                stacklevel=2,
            )
        return list(self.queries_log)

    def get_database_version(self):
        """Return a tuple of the database's version."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_database_version() "
            "method."
        )

    def check_database_version_supported(self):
        """
        Raise an error if the database version isn't supported by this
        version of Django.
        """
        if (
            self.features.minimum_database_version is not None
            and self.get_database_version() < self.features.minimum_database_version
        ):
            db_version = ".".join(map(str, self.get_database_version()))
            min_db_version = ".".join(map(str, self.features.minimum_database_version))
            raise NotSupportedError(
                f"{self.display_name} {min_db_version} or later is required "
                f"(found {db_version})."
            )

    # ##### Backend-specific methods for creating connections and cursors #####

    def get_connection_params(self):
        """Return a dict of parameters suitable for get_new_connection."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_connection_params() "
            "method"
        )

    def get_new_connection(self, conn_params):
        """Open a connection to the database."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a get_new_connection() "
            "method"
        )

    def init_connection_state(self):
        """Initialize the database connection settings."""
        if self.alias not in RAN_DB_VERSION_CHECK:
            self.check_database_version_supported()
            RAN_DB_VERSION_CHECK.add(self.alias)

    def create_cursor(self, name=None):
        """Create a cursor. Assume that a connection is established."""
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a create_cursor() method"
        )

    # ##### Backend-specific methods for creating connections #####

    @async_unsafe
    def connect(self):
        """Connect to the database. Assume that the connection is closed."""
        # Check for invalid configurations.
        self.check_settings()
        # In case the previous connection was closed while in an atomic block.
        self.in_atomic_block = False
        self.savepoint_ids = []
        self.atomic_blocks = []
        self.needs_rollback = False
        # Reset parameters defining when to close/health-check the connection.
        self.health_check_enabled = self.settings_dict["CONN_HEALTH_CHECKS"]
        max_age = self.settings_dict["CONN_MAX_AGE"]
        self.close_at = None if max_age is None else time.monotonic() + max_age
        self.closed_in_transaction = False
        self.errors_occurred = False
        # New connections are healthy.
        self.health_check_done = True
        # Establish the connection.
        conn_params = self.get_connection_params()
        self.connection = self.get_new_connection(conn_params)
        self.set_autocommit(self.settings_dict["AUTOCOMMIT"])
        self.init_connection_state()
        connection_created.send(sender=self.__class__, connection=self)

        self.run_on_commit = []

    def check_settings(self):
        if self.settings_dict["TIME_ZONE"] is not None and not settings.USE_TZ:
            raise ImproperlyConfigured(
                "Connection '%s' cannot set TIME_ZONE because USE_TZ is False."
                % self.alias
            )

    @async_unsafe
    def ensure_connection(self):
        """Guarantee that a connection to the database is established."""
        if self.connection is None:
            if self.in_atomic_block and self.closed_in_transaction:
                raise ProgrammingError(
                    "Cannot open a new connection in an atomic block."
                )
            with self.wrap_database_errors:
                self.connect()

    # ##### Backend-specific wrappers for PEP-249 connection methods #####

    def _prepare_cursor(self, cursor):
        """
        Validate the connection is usable and perform database cursor wrapping.
        """
        self.validate_thread_sharing()
        if self.queries_logged:
            wrapped_cursor = self.make_debug_cursor(cursor)
        else:
            wrapped_cursor = self.make_cursor(cursor)
        return wrapped_cursor

    def _cursor(self, name=None):
        self.close_if_health_check_failed()
        self.ensure_connection()
        with self.wrap_database_errors:
            return self._prepare_cursor(self.create_cursor(name))

    def _commit(self):
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), self.wrap_database_errors:
                return self.connection.commit()

    def _rollback(self):
        if self.connection is not None:
            with debug_transaction(self, "ROLLBACK"), self.wrap_database_errors:
                return self.connection.rollback()

    def _close(self):
        if self.connection is not None:
            with self.wrap_database_errors:
                return self.connection.close()

    # ##### Generic wrappers for PEP-249 connection methods #####

    @async_unsafe
    def cursor(self):
        """Create a cursor, opening a connection if necessary."""
        return self._cursor()

    @async_unsafe
    def commit(self):
        """Commit a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._commit()
        # A successful commit means that the database connection works.
        self.errors_occurred = False
        self.run_commit_hooks_on_set_autocommit_on = True

    @async_unsafe
    def rollback(self):
        """Roll back a transaction and reset the dirty flag."""
        self.validate_thread_sharing()
        self.validate_no_atomic_block()
        self._rollback()
        # A successful rollback means that the database connection works.
        self.errors_occurred = False
        self.needs_rollback = False
        self.run_on_commit = []

    @async_unsafe
    def close(self):
        """Close the connection to the database."""
        self.validate_thread_sharing()
        self.run_on_commit = []

        # Don't call validate_no_atomic_block() to avoid making it difficult
        # to get rid of a connection in an invalid state. The next connect()
        # will reset the transaction state anyway.
        if self.closed_in_transaction or self.connection is None:
            return
        try:
            self._close()
        finally:
            if self.in_atomic_block:
                self.closed_in_transaction = True
                self.needs_rollback = True
            else:
                self.connection = None

    # ##### Backend-specific savepoint management methods #####

    def _savepoint(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_create_sql(sid))

    def _savepoint_rollback(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_rollback_sql(sid))

    def _savepoint_commit(self, sid):
        with self.cursor() as cursor:
            cursor.execute(self.ops.savepoint_commit_sql(sid))

    def _savepoint_allowed(self):
        # Savepoints cannot be created outside a transaction.
        return self.features.uses_savepoints and not self.get_autocommit()

    # ##### Generic savepoint management methods #####

    @async_unsafe
    def savepoint(self):
        """
        Create a savepoint inside the current transaction. Return an
        identifier for the savepoint that will be used for the subsequent
        rollback or commit. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        thread_ident = _thread.get_ident()
        tid = str(thread_ident).replace("-", "")

        self.savepoint_state += 1
        sid = "s%s_x%d" % (tid, self.savepoint_state)

        self.validate_thread_sharing()
        self._savepoint(sid)

        return sid

    @async_unsafe
    def savepoint_rollback(self, sid):
        """
        Roll back to a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_rollback(sid)

        # Remove any callbacks registered while this savepoint was active.
        self.run_on_commit = [
            (sids, func, robust)
            for (sids, func, robust) in self.run_on_commit
            if sid not in sids
        ]

    @async_unsafe
    def savepoint_commit(self, sid):
        """
        Release a savepoint. Do nothing if savepoints are not supported.
        """
        if not self._savepoint_allowed():
            return

        self.validate_thread_sharing()
        self._savepoint_commit(sid)

    @async_unsafe
    def clean_savepoints(self):
        """
        Reset the counter used to generate unique savepoint ids in this thread.
        """
        self.savepoint_state = 0

    # ##### Backend-specific transaction management methods #####

    def _set_autocommit(self, autocommit):
        """
        Backend-specific implementation to enable or disable autocommit.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require a _set_autocommit() method"
        )

    # ##### Generic transaction management methods #####

    def get_autocommit(self):
        """Get the autocommit state."""
        self.ensure_connection()
        return self.autocommit

    def set_autocommit(
        self, autocommit, force_begin_transaction_with_broken_autocommit=False
    ):
        """
        Enable or disable autocommit.

        The usual way to start a transaction is to turn autocommit off.
        SQLite does not properly start a transaction when disabling
        autocommit. To avoid this buggy behavior and to actually enter a new
        transaction, an explicit BEGIN is required. Using
        force_begin_transaction_with_broken_autocommit=True will issue an
        explicit BEGIN with SQLite. This option will be ignored for other
        backends.
        """
        self.validate_no_atomic_block()
        self.close_if_health_check_failed()
        self.ensure_connection()

        start_transaction_under_autocommit = (
            force_begin_transaction_with_broken_autocommit
            and not autocommit
            and hasattr(self, "_start_transaction_under_autocommit")
        )

        if start_transaction_under_autocommit:
            self._start_transaction_under_autocommit()
        elif autocommit:
            self._set_autocommit(autocommit)
        else:
            with debug_transaction(self, "BEGIN"):
                self._set_autocommit(autocommit)
        self.autocommit = autocommit

        if autocommit and self.run_commit_hooks_on_set_autocommit_on:
            self.run_and_clear_commit_hooks()
            self.run_commit_hooks_on_set_autocommit_on = False

    def get_rollback(self):
        """Get the "needs rollback" flag -- for *advanced use* only."""
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        return self.needs_rollback

    def set_rollback(self, rollback):
        """
        Set or unset the "needs rollback" flag -- for *advanced use* only.
        """
        if not self.in_atomic_block:
            raise TransactionManagementError(
                "The rollback flag doesn't work outside of an 'atomic' block."
            )
        self.needs_rollback = rollback

    def validate_no_atomic_block(self):
        """Raise an error if an atomic block is active."""
        if self.in_atomic_block:
            raise TransactionManagementError(
                "This is forbidden when an 'atomic' block is active."
            )

    def validate_no_broken_transaction(self):
        if self.needs_rollback:
            raise TransactionManagementError(
                "An error occurred in the current transaction. You can't "
                "execute queries until the end of the 'atomic' block."
            ) from self.rollback_exc

    # ##### Foreign key constraints checks handling #####

    @contextmanager
    def constraint_checks_disabled(self):
        """
        Disable foreign key constraint checking.
        """
        disabled = self.disable_constraint_checking()
        try:
            yield
        finally:
            if disabled:
                self.enable_constraint_checking()

    def disable_constraint_checking(self):
        """
        Backends can implement as needed to temporarily disable foreign key
        constraint checking. Should return True if the constraints were
        disabled and will need to be reenabled.
        """
        return False

    def enable_constraint_checking(self):
        """
        Backends can implement as needed to re-enable foreign key constraint
        checking.
        """
        pass

    def check_constraints(self, table_names=None):
        """
        Backends can override this method if they can apply constraint
        checking (e.g. via "SET CONSTRAINTS ALL IMMEDIATE"). Should raise an
        IntegrityError if any invalid foreign key references are encountered.
        """
        pass

    # ##### Connection termination handling #####

    def is_usable(self):
        """
        Test if the database connection is usable.

        This method may assume that self.connection is not None.

        Actual implementations should take care not to raise exceptions
        as that may prevent Django from recycling unusable connections.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseWrapper may require an is_usable() method"
        )

    def close_if_health_check_failed(self):
        """Close existing connection if it fails a health check."""
        if (
            self.connection is None
            or not self.health_check_enabled
            or self.health_check_done
        ):
            return

        if not self.is_usable():
            self.close()
        self.health_check_done = True

    def close_if_unusable_or_obsolete(self):
        """
        Close the current connection if unrecoverable errors have occurred
        or if it outlived its maximum age.
        """
        if self.connection is not None:
            self.health_check_done = False
            # If the application didn't restore the original autocommit setting,
            # don't take chances, drop the connection.
            if self.get_autocommit() != self.settings_dict["AUTOCOMMIT"]:
                self.close()
                return

            # If an exception other than DataError or IntegrityError occurred
            # since the last commit / rollback, check if the connection works.
            if self.errors_occurred:
                if self.is_usable():
                    self.errors_occurred = False
                    self.health_check_done = True
                else:
                    self.close()
                    return

            if self.close_at is not None and time.monotonic() >= self.close_at:
                self.close()
                return

    # ##### Thread safety handling #####

    @property
    def allow_thread_sharing(self):
        with self._thread_sharing_lock:
            return self._thread_sharing_count > 0

    def inc_thread_sharing(self):
        with self._thread_sharing_lock:
            self._thread_sharing_count += 1

    def dec_thread_sharing(self):
        with self._thread_sharing_lock:
            if self._thread_sharing_count <= 0:
                raise RuntimeError(
                    "Cannot decrement the thread sharing count below zero."
                )
            self._thread_sharing_count -= 1

    def validate_thread_sharing(self):
        """
        Validate that the connection isn't accessed by another thread than the
        one which originally created it, unless the connection was explicitly
        authorized to be shared between threads (via the `inc_thread_sharing()`
        method). Raise an exception if the validation fails.
        """
        if not (self.allow_thread_sharing or self._thread_ident == _thread.get_ident()):
            raise DatabaseError(
                "DatabaseWrapper objects created in a "
                "thread can only be used in that same thread. The object "
                "with alias '%s' was created in thread id %s and this is "
                "thread id %s." % (self.alias, self._thread_ident, _thread.get_ident())
            )

    # ##### Miscellaneous #####

    def prepare_database(self):
        """
        Hook to do any database check or preparation, generally called before
        migrating a project or an app.
        """
        pass

    @cached_property
    def wrap_database_errors(self):
        """
        Context manager and decorator that re-throws backend-specific database
        exceptions using Django's common wrappers.
        """
        return DatabaseErrorWrapper(self)

    def chunked_cursor(self):
        """
        Return a cursor that tries to avoid caching in the database (if
        supported by the database), otherwise return a regular cursor.
        """
        return self.cursor()

    def make_debug_cursor(self, cursor):
        """Create a cursor that logs all queries in self.queries_log."""
        return utils.CursorDebugWrapper(cursor, self)

    def make_cursor(self, cursor):
        """Create a cursor without debug logging."""
        return utils.CursorWrapper(cursor, self)

    @contextmanager
    def temporary_connection(self):
        """
        Context manager that ensures that a connection is established, and
        if it opened one, closes it to avoid leaving a dangling connection.
        This is useful for operations outside of the request-response cycle.

        Provide a cursor: with self.temporary_connection() as cursor: ...
        """
        must_close = self.connection is None
        try:
            with self.cursor() as cursor:
                yield cursor
        finally:
            if must_close:
                self.close()

    @contextmanager
    def _nodb_cursor(self):
        """
        Return a cursor from an alternative connection to be used when there is
        no need to access the main database, specifically for test db
        creation/deletion. This also prevents the production database from
        being exposed to potential child threads while (or after) the test
        database is destroyed. Refs #10868, #17786, #16969.
        """
        conn = self.__class__({**self.settings_dict, "NAME": None}, alias=NO_DB_ALIAS)
        try:
            with conn.cursor() as cursor:
                yield cursor
        finally:
            conn.close()

    def schema_editor(self, *args, **kwargs):
        """
        Return a new instance of this backend's SchemaEditor.
        """
        if self.SchemaEditorClass is None:
            raise NotImplementedError(
                "The SchemaEditorClass attribute of this database wrapper is still None"
            )
        return self.SchemaEditorClass(self, *args, **kwargs)

    def on_commit(self, func, robust=False):
        if not callable(func):
            raise TypeError("on_commit()'s callback must be a callable.")
        if self.in_atomic_block:
            # Transaction in progress; save for execution on commit.
            self.run_on_commit.append((set(self.savepoint_ids), func, robust))
        elif not self.get_autocommit():
            raise TransactionManagementError(
                "on_commit() cannot be used in manual transaction management"
            )
        else:
            # No transaction in progress and in autocommit mode; execute
            # immediately.
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()

    def run_and_clear_commit_hooks(self):
        self.validate_no_atomic_block()
        current_run_on_commit = self.run_on_commit
        self.run_on_commit = []
        while current_run_on_commit:
            _, func, robust = current_run_on_commit.pop(0)
            if robust:
                try:
                    func()
                except Exception as e:
                    logger.error(
                        f"Error calling {func.__qualname__} in on_commit() during "
                        f"transaction (%s).",
                        e,
                        exc_info=True,
                    )
            else:
                func()

    @contextmanager
    def execute_wrapper(self, wrapper):
        """
        Return a context manager under which the wrapper is applied to suitable
        database query executions.
        """
        self.execute_wrappers.append(wrapper)
        try:
            yield
        finally:
            self.execute_wrappers.pop()

    def copy(self, alias=None):
        """
        Return a copy of this connection.

        For tests that require two connections to the same database.
        """
        settings_dict = copy.deepcopy(self.settings_dict)
        if alias is None:
            alias = self.alias
        return type(self)(settings_dict, alias)
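The on_commit() and execute_wrapper() hooks above are easiest to see from the calling side. A minimal sketch, assuming a configured default connection; log_sql is an invented wrapper name, not part of Django:

    from django.db import connection, transaction

    def log_sql(execute, sql, params, many, context):
        # An execute_wrappers entry receives the inner execute and must call it.
        print("about to run:", sql)
        return execute(sql, params, many, context)

    with connection.execute_wrapper(log_sql):
        with transaction.atomic():
            # Queued in run_on_commit; fires only if the outermost atomic commits.
            transaction.on_commit(lambda: print("committed"), robust=True)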
28
.venv/lib/python3.10/site-packages/django/db/backends/base/client.py
Normal file
@@ -0,0 +1,28 @@
import os
import subprocess


class BaseDatabaseClient:
    """Encapsulate backend-specific methods for opening a client shell."""

    # This should be a string representing the name of the executable
    # (e.g., "psql"). Subclasses must override this.
    executable_name = None

    def __init__(self, connection):
        # connection is an instance of BaseDatabaseWrapper.
        self.connection = connection

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        raise NotImplementedError(
            "subclasses of BaseDatabaseClient must provide a "
            "settings_to_cmd_args_env() method or override a runshell()."
        )

    def runshell(self, parameters):
        args, env = self.settings_to_cmd_args_env(
            self.connection.settings_dict, parameters
        )
        env = {**os.environ, **env} if env else None
        subprocess.run(args, env=env, check=True)
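runshell() above depends entirely on settings_to_cmd_args_env(). A sketch of the subclass contract for a hypothetical backend (the binary name and environment variable are invented, not a real Django backend):

    from django.db.backends.base.client import BaseDatabaseClient

    class ExampleDatabaseClient(BaseDatabaseClient):
        executable_name = "exampledb-cli"  # hypothetical shell binary

        @classmethod
        def settings_to_cmd_args_env(cls, settings_dict, parameters):
            args = [cls.executable_name, settings_dict["NAME"], *parameters]
            # Pass secrets via the environment rather than argv.
            env = {"EXAMPLEDB_PASSWORD": settings_dict["PASSWORD"]}
            return args, env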
384
.venv/lib/python3.10/site-packages/django/db/backends/base/creation.py
Normal file
@@ -0,0 +1,384 @@
import os
import sys
from io import StringIO

from django.apps import apps
from django.conf import settings
from django.core import serializers
from django.db import router
from django.db.transaction import atomic
from django.utils.module_loading import import_string

# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = "test_"


class BaseDatabaseCreation:
    """
    Encapsulate backend-specific differences pertaining to creation and
    destruction of the test database.
    """

    def __init__(self, connection):
        self.connection = connection

    def _nodb_cursor(self):
        return self.connection._nodb_cursor()

    def log(self, msg):
        sys.stderr.write(msg + os.linesep)

    def create_test_db(
        self, verbosity=1, autoclobber=False, serialize=True, keepdb=False
    ):
        """
        Create a test database, prompting the user for confirmation if the
        database already exists. Return the name of the test database created.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = "Creating"
            if keepdb:
                action = "Using existing"

            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        try:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                # Disable migrations for all apps.
                old_migration_modules = settings.MIGRATION_MODULES
                settings.MIGRATION_MODULES = {
                    app.label: None for app in apps.get_app_configs()
                }
            # We report migrate messages at one level lower than that
            # requested. This ensures we don't get flooded with messages during
            # testing (unless you really ask to be flooded).
            call_command(
                "migrate",
                verbosity=max(verbosity - 1, 0),
                interactive=False,
                database=self.connection.alias,
                run_syncdb=True,
            )
        finally:
            if self.connection.settings_dict["TEST"]["MIGRATE"] is False:
                settings.MIGRATION_MODULES = old_migration_modules

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so people
        # who are testing on databases without transactions or who are using
        # a TransactionTestCase still get a clean database on every test run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        call_command("createcachetable", database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
            self.mark_expected_failures_and_skips()

        return test_database_name

    def set_as_test_mirror(self, primary_settings_dict):
        """
        Set this database up to be used in testing as a mirror of a primary
        database whose settings are given.
        """
        self.connection.settings_dict["NAME"] = primary_settings_dict["NAME"]

    def serialize_db_to_string(self):
        """
        Serialize all data in the database into a JSON string.
        Designed only for test runner usage; will not handle large
        amounts of data.
        """

        # Iteratively return every object for all models to serialize.
        def get_objects():
            from django.db.migrations.loader import MigrationLoader

            loader = MigrationLoader(self.connection)
            for app_config in apps.get_app_configs():
                if (
                    app_config.models_module is not None
                    and app_config.label in loader.migrated_apps
                    and app_config.name not in settings.TEST_NON_SERIALIZED_APPS
                ):
                    for model in app_config.get_models():
                        if model._meta.can_migrate(
                            self.connection
                        ) and router.allow_migrate_model(self.connection.alias, model):
                            queryset = model._base_manager.using(
                                self.connection.alias,
                            ).order_by(model._meta.pk.name)
                            chunk_size = (
                                2000 if queryset._prefetch_related_lookups else None
                            )
                            yield from queryset.iterator(chunk_size=chunk_size)

        # Serialize to a string.
        out = StringIO()
        serializers.serialize("json", get_objects(), indent=None, stream=out)
        return out.getvalue()

    def deserialize_db_from_string(self, data):
        """
        Reload the database with data from a string generated by
        the serialize_db_to_string() method.
        """
        data = StringIO(data)
        table_names = set()
        # Load data in a transaction to handle forward references and cycles.
        with atomic(using=self.connection.alias):
            # Disable constraint checks, because some databases (MySQL) don't
            # support deferred checks.
            with self.connection.constraint_checks_disabled():
                for obj in serializers.deserialize(
                    "json", data, using=self.connection.alias
                ):
                    obj.save()
                    table_names.add(obj.object.__class__._meta.db_table)
            # Manually check for any invalid keys that might have been added,
            # because constraint checks were disabled.
            self.connection.check_constraints(table_names=table_names)

    def _get_database_display_str(self, verbosity, database_name):
        """
        Return display string for a database for use in various actions.
        """
        return "'%s'%s" % (
            self.connection.alias,
            (" ('%s')" % database_name) if verbosity >= 2 else "",
        )

    def _get_test_db_name(self):
        """
        Internal implementation - return the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        settings.
        """
        if self.connection.settings_dict["TEST"]["NAME"]:
            return self.connection.settings_dict["TEST"]["NAME"]
        return TEST_DATABASE_PREFIX + self.connection.settings_dict["NAME"]

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        cursor.execute("CREATE DATABASE %(dbname)s %(suffix)s" % parameters)

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - create the test db tables.
        """
        test_database_name = self._get_test_db_name()
        test_db_params = {
            "dbname": self.connection.ops.quote_name(test_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        # Create the test database and connect to it.
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception as e:
                # If we want to keep the db, then no need to do any of the below;
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                self.log("Got an error creating the test database: %s" % e)
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == "yes":
                    try:
                        if verbosity >= 1:
                            self.log(
                                "Destroying old test database for alias %s..."
                                % (
                                    self._get_database_display_str(
                                        verbosity, test_database_name
                                    ),
                                )
                            )
                        cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                        self._execute_create_test_db(cursor, test_db_params, keepdb)
                    except Exception as e:
                        self.log("Got an error recreating the test database: %s" % e)
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)

        return test_database_name

    def clone_test_db(self, suffix, verbosity=1, autoclobber=False, keepdb=False):
        """
        Clone a test database.
        """
        source_database_name = self.connection.settings_dict["NAME"]

        if verbosity >= 1:
            action = "Cloning test database"
            if keepdb:
                action = "Using existing clone"
            self.log(
                "%s for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, source_database_name),
                )
            )

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. See create_test_db for details.
        self._clone_test_db(suffix, verbosity, keepdb)

    def get_test_db_clone_settings(self, suffix):
        """
        Return a modified connection settings dict for the n-th clone of a DB.
        """
        # When this function is called, the test database has been created
        # already and its name has been copied to settings_dict['NAME'] so
        # we don't need to call _get_test_db_name.
        orig_settings_dict = self.connection.settings_dict
        return {
            **orig_settings_dict,
            "NAME": "{}_{}".format(orig_settings_dict["NAME"], suffix),
        }

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        """
        Internal implementation - duplicate the test db tables.
        """
        raise NotImplementedError(
            "The database backend doesn't support cloning databases. "
            "Disable the option to run tests in parallel processes."
        )

    def destroy_test_db(
        self, old_database_name=None, verbosity=1, keepdb=False, suffix=None
    ):
        """
        Destroy a test database, prompting the user for confirmation if the
        database already exists.
        """
        self.connection.close()
        if suffix is None:
            test_database_name = self.connection.settings_dict["NAME"]
        else:
            test_database_name = self.get_test_db_clone_settings(suffix)["NAME"]

        if verbosity >= 1:
            action = "Destroying"
            if keepdb:
                action = "Preserving"
            self.log(
                "%s test database for alias %s..."
                % (
                    action,
                    self._get_database_display_str(verbosity, test_database_name),
                )
            )

        # If we want to preserve the database,
        # skip the actual destroying piece.
        if not keepdb:
            self._destroy_test_db(test_database_name, verbosity)

        # Restore the original database name.
        if old_database_name is not None:
            settings.DATABASES[self.connection.alias]["NAME"] = old_database_name
            self.connection.settings_dict["NAME"] = old_database_name

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - remove the test db tables.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_cursor() as cursor:
            cursor.execute(
                "DROP DATABASE %s" % self.connection.ops.quote_name(test_database_name)
            )

    def mark_expected_failures_and_skips(self):
        """
        Mark tests in Django's test suite which are expected failures on this
        database and tests which should be skipped on this database.
        """
        # Only load unittest if we're actually testing.
        from unittest import expectedFailure, skip

        for test_name in self.connection.features.django_test_expected_failures:
            test_case_name, _, test_method_name = test_name.rpartition(".")
            test_app = test_name.split(".")[0]
            # Importing a test app that isn't installed raises RuntimeError.
            if test_app in settings.INSTALLED_APPS:
                test_case = import_string(test_case_name)
                test_method = getattr(test_case, test_method_name)
                setattr(test_case, test_method_name, expectedFailure(test_method))
        for reason, tests in self.connection.features.django_test_skips.items():
            for test_name in tests:
                test_case_name, _, test_method_name = test_name.rpartition(".")
                test_app = test_name.split(".")[0]
                # Importing a test app that isn't installed raises RuntimeError.
                if test_app in settings.INSTALLED_APPS:
                    test_case = import_string(test_case_name)
                    test_method = getattr(test_case, test_method_name)
                    setattr(test_case, test_method_name, skip(reason)(test_method))

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the test table creation statements.
        """
        return ""

    def test_db_signature(self):
        """
        Return a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        according to the RDBMS particularities.
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            self._get_test_db_name(),
        )

    def setup_worker_connection(self, _worker_id):
        settings_dict = self.get_test_db_clone_settings(str(_worker_id))
        # connection.settings_dict must be updated in place for changes to be
        # reflected in django.db.connections. If the following line assigned
        # connection.settings_dict = settings_dict, new threads would connect
        # to the default database instead of the appropriate clone.
        self.connection.settings_dict.update(settings_dict)
        self.connection.close()
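To make the clone plumbing concrete: get_test_db_clone_settings() rewrites only NAME, so parallel test worker N connects to test_<name>_<N>. A sketch, assuming a "default" alias whose test database is already named test_app (the app name is hypothetical):

    from django.db import connections

    creation = connections["default"].creation
    clone_settings = creation.get_test_db_clone_settings("1")
    print(clone_settings["NAME"])  # -> "test_app_1"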
423
.venv/lib/python3.10/site-packages/django/db/backends/base/features.py
Normal file
@@ -0,0 +1,423 @@
from django.db import ProgrammingError
from django.utils.functional import cached_property


class BaseDatabaseFeatures:
    # An optional tuple indicating the minimum supported database version.
    minimum_database_version = None
    gis_enabled = False
    # Oracle can't group by LOB (large object) data types.
    allows_group_by_lob = True
    allows_group_by_selected_pks = False
    allows_group_by_select_index = True
    empty_fetchmany_value = []
    update_can_self_select = True
    # Does the backend support self-reference subqueries in the DELETE
    # statement?
    delete_can_self_reference_subquery = True

    # Does the backend distinguish between '' and None?
    interprets_empty_strings_as_nulls = False

    # Does the backend allow inserting duplicate NULL rows in a nullable
    # unique field? All core backends implement this correctly, but other
    # databases such as SQL Server do not.
    supports_nullable_unique_constraints = True

    # Does the backend allow inserting duplicate rows when a unique_together
    # constraint exists and some fields are nullable but not all of them?
    supports_partially_nullable_unique_constraints = True

    # Does the backend support specifying whether NULL values should be
    # considered distinct in unique constraints?
    supports_nulls_distinct_unique_constraints = False

    # Does the backend support initially deferrable unique constraints?
    supports_deferrable_unique_constraints = False

    can_use_chunked_reads = True
    can_return_columns_from_insert = False
    can_return_rows_from_bulk_insert = False
    has_bulk_insert = True
    uses_savepoints = True
    can_release_savepoints = False

    # If True, don't use integer foreign keys referring to, e.g., positive
    # integer primary keys.
    related_fields_match_type = False
    allow_sliced_subqueries_with_in = True
    has_select_for_update = False
    has_select_for_update_nowait = False
    has_select_for_update_skip_locked = False
    has_select_for_update_of = False
    has_select_for_no_key_update = False
    # Does the database's SELECT FOR UPDATE OF syntax require a column rather
    # than a table?
    select_for_update_of_column = False

    # Does the default test database allow multiple connections?
    # Usually an indication that the test database is in-memory.
    test_db_allows_multiple_connections = True

    # Can an object be saved without an explicit primary key?
    supports_unspecified_pk = False

    # Can a fixture contain forward references? i.e., are
    # FK constraints checked at the end of transaction, or
    # at the end of each save operation?
    supports_forward_references = True

    # Does the backend truncate names properly when they are too long?
    truncates_names = False

    # Is there a REAL datatype in addition to floats/doubles?
    has_real_datatype = False
    supports_subqueries_in_group_by = True

    # Does the backend ignore unnecessary ORDER BY clauses in subqueries?
    ignores_unnecessary_order_by_in_subqueries = True

    # Is there a true datatype for uuid?
    has_native_uuid_field = False

    # Is there a true datatype for timedeltas?
    has_native_duration_field = False

    # Does the database driver support same-type temporal data subtraction
    # by returning the type used to store duration fields?
    supports_temporal_subtraction = False

    # Does the __regex lookup support backreferencing and grouping?
    supports_regex_backreferencing = True

    # Can date/datetime lookups be performed using a string?
    supports_date_lookup_using_string = True

    # Can datetimes with timezones be used?
    supports_timezones = True

    # Does the database have a copy of the zoneinfo database?
    has_zoneinfo_database = True

    # When performing a GROUP BY, is an ORDER BY NULL required
    # to remove any ordering?
    requires_explicit_null_ordering_when_grouping = False

    # Does the backend order NULL values as largest or smallest?
    nulls_order_largest = False

    # Does the backend support NULLS FIRST and NULLS LAST in ORDER BY?
    supports_order_by_nulls_modifier = True

    # Does the backend order NULLS FIRST by default?
    order_by_nulls_first = False

    # The database's limit on the number of query parameters.
    max_query_params = None

    # Can an object have an autoincrement primary key of 0?
    allows_auto_pk_0 = True

    # Do we need to NULL a ForeignKey out, or can the constraint check be
    # deferred?
    can_defer_constraint_checks = False

    # Does the backend support tablespaces? Default to False because it isn't
    # in the SQL standard.
    supports_tablespaces = False

    # Does the backend reset sequences between tests?
    supports_sequence_reset = True

    # Can the backend introspect the default value of a column?
    can_introspect_default = True

    # Confirm support for introspected foreign keys.
    # Every database can do this reliably, except MySQL,
    # which can't do it for MyISAM tables.
    can_introspect_foreign_keys = True

    # Map fields which some backends may not be able to differentiate to the
    # field it's introspected as.
    introspected_field_types = {
        "AutoField": "AutoField",
        "BigAutoField": "BigAutoField",
        "BigIntegerField": "BigIntegerField",
        "BinaryField": "BinaryField",
        "BooleanField": "BooleanField",
        "CharField": "CharField",
        "DurationField": "DurationField",
        "GenericIPAddressField": "GenericIPAddressField",
        "IntegerField": "IntegerField",
        "PositiveBigIntegerField": "PositiveBigIntegerField",
        "PositiveIntegerField": "PositiveIntegerField",
        "PositiveSmallIntegerField": "PositiveSmallIntegerField",
        "SmallAutoField": "SmallAutoField",
        "SmallIntegerField": "SmallIntegerField",
        "TimeField": "TimeField",
    }

    # Can the backend introspect the column order (ASC/DESC) for indexes?
    supports_index_column_ordering = True

    # Does the backend support introspection of materialized views?
    can_introspect_materialized_views = False

    # Support for the DISTINCT ON clause.
    can_distinct_on_fields = False

    # Does the backend prevent running SQL queries in broken transactions?
    atomic_transactions = True

    # Can we roll back DDL in a transaction?
    can_rollback_ddl = False

    schema_editor_uses_clientside_param_binding = False

    # Can we issue more than one ALTER COLUMN clause in an ALTER TABLE?
    supports_combined_alters = False

    # Does it support foreign keys?
    supports_foreign_keys = True

    # Can it create foreign key constraints inline when adding columns?
    can_create_inline_fk = True

    # Can an index be renamed?
    can_rename_index = False

    # Does it automatically index foreign keys?
    indexes_foreign_keys = True

    # Does it support CHECK constraints?
    supports_column_check_constraints = True
    supports_table_check_constraints = True
    # Does the backend support introspection of CHECK constraints?
    can_introspect_check_constraints = True

    # Does the backend support 'pyformat' style ("... %(name)s ...", {'name': value})
    # parameter passing? Note this can be provided by the backend even if not
    # supported by the Python driver.
    supports_paramstyle_pyformat = True

    # Does the backend require literal defaults, rather than parameterized ones?
    requires_literal_defaults = False

    # Does the backend support functions in defaults?
    supports_expression_defaults = True

    # Does the backend support the DEFAULT keyword in insert queries?
    supports_default_keyword_in_insert = True

    # Does the backend support the DEFAULT keyword in bulk insert queries?
    supports_default_keyword_in_bulk_insert = True

    # Does the backend require a connection reset after each material schema change?
    connection_persists_old_columns = False

    # What kind of error does the backend throw when accessing a closed cursor?
    closed_cursor_error_class = ProgrammingError

    # Does 'a' LIKE 'A' match?
    has_case_insensitive_like = False

    # Suffix for backends that don't support "SELECT xxx;" queries.
    bare_select_suffix = ""

    # If NULL is implied on columns without needing to be explicitly specified.
    implied_column_null = False

    # Does the backend support "select for update" queries with limit (and offset)?
    supports_select_for_update_with_limit = True

    # Does the backend ignore null expressions in GREATEST and LEAST queries unless
    # every expression is null?
    greatest_least_ignores_nulls = False

    # Can the backend clone databases for parallel test execution?
    # Defaults to False to allow third-party backends to opt-in.
    can_clone_databases = False

    # Does the backend consider table names with different casing to
    # be equal?
    ignores_table_name_case = False

    # Place FOR UPDATE right after FROM clause. Used on MSSQL.
    for_update_after_from = False

    # Combinatorial flags
    supports_select_union = True
    supports_select_intersection = True
    supports_select_difference = True
    supports_slicing_ordering_in_compound = False
    supports_parentheses_in_compound = True
    requires_compound_order_by_subquery = False

    # Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
    # expressions?
    supports_aggregate_filter_clause = False

    # Does the backend support indexing a TextField?
    supports_index_on_text_field = True

    # Does the backend support window expressions (expression OVER (...))?
    supports_over_clause = False
    supports_frame_range_fixed_distance = False
    supports_frame_exclusion = False
    only_supports_unbounded_with_preceding_and_following = False

    # Does the backend support CAST with precision?
    supports_cast_with_precision = True

    # How many second decimals does the database return when casting a value to
    # a type with time?
    time_cast_precision = 6

    # SQL to create a procedure for use by the Django test suite. The
    # functionality of the procedure isn't important.
    create_test_procedure_without_params_sql = None
    create_test_procedure_with_int_param_sql = None

    # SQL to create a table with a composite primary key for use by the Django
    # test suite.
    create_test_table_with_composite_primary_key = None

    # Does the backend support keyword parameters for cursor.callproc()?
    supports_callproc_kwargs = False

    # What formats does the backend EXPLAIN syntax support?
    supported_explain_formats = set()

    # Does the backend support the default parameter in lead() and lag()?
    supports_default_in_lead_lag = True

    # Does the backend support ignoring constraint or uniqueness errors during
    # INSERT?
    supports_ignore_conflicts = True
    # Does the backend support updating rows on constraint or uniqueness errors
    # during INSERT?
    supports_update_conflicts = False
    supports_update_conflicts_with_target = False

    # Does this backend require casting the results of CASE expressions used
    # in UPDATE statements to ensure the expression has the correct type?
    requires_casted_case_in_updates = False

    # Does the backend support partial indexes (CREATE INDEX ... WHERE ...)?
    supports_partial_indexes = True
    supports_functions_in_partial_indexes = True
    # Does the backend support covering indexes (CREATE INDEX ... INCLUDE ...)?
    supports_covering_indexes = False
    # Does the backend support indexes on expressions?
    supports_expression_indexes = True
    # Does the backend treat COLLATE as an indexed expression?
    collate_as_index_expression = False

    # Does the database allow more than one constraint or index on the same
    # field(s)?
    allows_multiple_constraints_on_same_fields = True

    # Does the backend support boolean expressions in SELECT and GROUP BY
    # clauses?
    supports_boolean_expr_in_select_clause = True
    # Does the backend support comparing boolean expressions in WHERE clauses?
    # E.g.: WHERE (price > 0) IS NOT NULL
    supports_comparing_boolean_expr = True

    # Does the backend support JSONField?
    supports_json_field = True
    # Can the backend introspect a JSONField?
    can_introspect_json_field = True
    # Does the backend support primitives in JSONField?
    supports_primitives_in_json_field = True
    # Is there a true datatype for JSON?
    has_native_json_field = False
    # Does the backend use PostgreSQL-style JSON operators like '->'?
    has_json_operators = False
    # Does the backend support __contains and __contained_by lookups for
    # a JSONField?
    supports_json_field_contains = True
    # Does value__d__contains={'f': 'g'} (without a list around the dict) match
    # {'d': [{'f': 'g'}]}?
    json_key_contains_list_matching_requires_list = False
    # Does the backend support the JSONObject() database function?
    has_json_object_function = True

    # Does the backend support column collations?
    supports_collation_on_charfield = True
    supports_collation_on_textfield = True
    # Does the backend support non-deterministic collations?
    supports_non_deterministic_collations = True
# Does the backend support column and table comments?
|
||||
supports_comments = False
|
||||
# Does the backend support column comments in ADD COLUMN statements?
|
||||
supports_comments_inline = False
|
||||
|
||||
# Does the backend support stored generated columns?
|
||||
supports_stored_generated_columns = False
|
||||
# Does the backend support virtual generated columns?
|
||||
supports_virtual_generated_columns = False
|
||||
|
||||
# Does the backend support the logical XOR operator?
|
||||
supports_logical_xor = False
|
||||
|
||||
# Set to (exception, message) if null characters in text are disallowed.
|
||||
prohibits_null_characters_in_text_exception = None
|
||||
|
||||
# Does the backend support unlimited character columns?
|
||||
supports_unlimited_charfield = False
|
||||
|
||||
# Does the backend support native tuple lookups (=, >, <, IN)?
|
||||
supports_tuple_lookups = True
|
||||
|
||||
# Collation names for use by the Django test suite.
|
||||
test_collations = {
|
||||
"ci": None, # Case-insensitive.
|
||||
"cs": None, # Case-sensitive.
|
||||
"non_default": None, # Non-default.
|
||||
"swedish_ci": None, # Swedish case-insensitive.
|
||||
"virtual": None, # A collation that can be used for virtual columns.
|
||||
}
|
||||
# SQL template override for tests.aggregation.tests.NowUTC
|
||||
test_now_utc_template = None
|
||||
|
||||
# SQL to create a model instance using the database defaults.
|
||||
insert_test_table_with_defaults = None
|
||||
|
||||
# Does the Round() database function round to even?
|
||||
rounds_to_even = False
|
||||
|
||||
# A set of dotted paths to tests in Django's test suite that are expected
|
||||
# to fail on this database.
|
||||
django_test_expected_failures = set()
|
||||
# A map of reasons to sets of dotted paths to tests in Django's test suite
|
||||
# that should be skipped for this database.
|
||||
django_test_skips = {}
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
|
||||
@cached_property
|
||||
def supports_explaining_query_execution(self):
|
||||
"""Does this backend support explaining query execution?"""
|
||||
return self.connection.ops.explain_prefix is not None
|
||||
|
||||
@cached_property
|
||||
def supports_transactions(self):
|
||||
"""Confirm support for transactions."""
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute("CREATE TABLE ROLLBACK_TEST (X INT)")
|
||||
self.connection.set_autocommit(False)
|
||||
cursor.execute("INSERT INTO ROLLBACK_TEST (X) VALUES (8)")
|
||||
self.connection.rollback()
|
||||
self.connection.set_autocommit(True)
|
||||
cursor.execute("SELECT COUNT(X) FROM ROLLBACK_TEST")
|
||||
(count,) = cursor.fetchone()
|
||||
cursor.execute("DROP TABLE ROLLBACK_TEST")
|
||||
return count == 0
|
||||
|
||||
def allows_group_by_selected_pks_on_model(self, model):
|
||||
if not self.allows_group_by_selected_pks:
|
||||
return False
|
||||
return model._meta.managed
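
# A minimal consumption sketch (assumes a configured "default" database;
# ensure_upsert_supported() is a hypothetical helper, not Django API).
# Backends override these flags on their own DatabaseFeatures subclass, and
# callers read them through connection.features:
#
#     from django.db import NotSupportedError, connection
#
#     def ensure_upsert_supported():
#         # supports_update_conflicts backs bulk_create(update_conflicts=True).
#         if not connection.features.supports_update_conflicts:
#             raise NotSupportedError("This backend cannot upsert.")
#
# Cached properties such as supports_transactions may probe the database once
# on first access, then memoize the result.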
@@ -0,0 +1,212 @@
from collections import namedtuple

# Structure returned by DatabaseIntrospection.get_table_list()
TableInfo = namedtuple("TableInfo", ["name", "type"])

# Structure returned by the DB-API cursor.description interface (PEP 249)
FieldInfo = namedtuple(
    "FieldInfo",
    "name type_code display_size internal_size precision scale null_ok "
    "default collation",
)


class BaseDatabaseIntrospection:
    """Encapsulate backend-specific introspection utilities."""

    data_types_reverse = {}

    def __init__(self, connection):
        self.connection = connection

    def get_field_type(self, data_type, description):
        """
        Hook for a database backend to use the cursor description to
        match a Django field type to a database column.

        For Oracle, the column data_type on its own is insufficient to
        distinguish between a FloatField and IntegerField, for example.
        """
        return self.data_types_reverse[data_type]

    def identifier_converter(self, name):
        """
        Apply a conversion to the identifier for the purposes of comparison.

        The default identifier converter is for case sensitive comparison.
        """
        return name

    def table_names(self, cursor=None, include_views=False):
        """
        Return a list of names of all tables that exist in the database.
        Sort the returned table list by Python's default sorting. Do NOT use
        the database's ORDER BY here to avoid subtle differences in sorting
        order between databases.
        """

        def get_names(cursor):
            return sorted(
                ti.name
                for ti in self.get_table_list(cursor)
                if include_views or ti.type == "t"
            )

        if cursor is None:
            with self.connection.cursor() as cursor:
                return get_names(cursor)
        return get_names(cursor)

    def get_table_list(self, cursor):
        """
        Return an unsorted list of TableInfo named tuples of all tables and
        views that exist in the database.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a get_table_list() "
            "method"
        )

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a "
            "get_table_description() method."
        )

    def get_migratable_models(self):
        from django.apps import apps
        from django.db import router

        return (
            model
            for app_config in apps.get_app_configs()
            for model in router.get_migratable_models(app_config, self.connection.alias)
            if model._meta.can_migrate(self.connection)
        )

    def django_table_names(self, only_existing=False, include_views=True):
        """
        Return a list of all table names that have associated Django models and
        are in INSTALLED_APPS.

        If only_existing is True, include only the tables in the database.
        """
        tables = set()
        for model in self.get_migratable_models():
            if not model._meta.managed:
                continue
            tables.add(model._meta.db_table)
            tables.update(
                f.m2m_db_table()
                for f in model._meta.local_many_to_many
                if f.remote_field.through._meta.managed
            )
        tables = list(tables)
        if only_existing:
            existing_tables = set(self.table_names(include_views=include_views))
            tables = [
                t for t in tables if self.identifier_converter(t) in existing_tables
            ]
        return tables

    def installed_models(self, tables):
        """
        Return a set of all models represented by the provided list of table
        names.
        """
        tables = set(map(self.identifier_converter, tables))
        return {
            m
            for m in self.get_migratable_models()
            if self.identifier_converter(m._meta.db_table) in tables
        }

    def sequence_list(self):
        """
        Return a list of information about all DB sequences for all models in
        all apps.
        """
        sequence_list = []
        with self.connection.cursor() as cursor:
            for model in self.get_migratable_models():
                if not model._meta.managed:
                    continue
                if model._meta.swapped:
                    continue
                sequence_list.extend(
                    self.get_sequences(
                        cursor, model._meta.db_table, model._meta.local_fields
                    )
                )
                for f in model._meta.local_many_to_many:
                    # If this is an m2m using an intermediate table,
                    # we don't need to reset the sequence.
                    if f.remote_field.through._meta.auto_created:
                        sequence = self.get_sequences(cursor, f.m2m_db_table())
                        sequence_list.extend(
                            sequence or [{"table": f.m2m_db_table(), "column": None}]
                        )
        return sequence_list

    def get_sequences(self, cursor, table_name, table_fields=()):
        """
        Return a list of introspected sequences for table_name. Each sequence
        is a dict: {'table': <table_name>, 'column': <column_name>}. An optional
        'name' key can be added if the backend supports named sequences.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a get_sequences() "
            "method"
        )

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a "
            "get_relations() method."
        )

    def get_primary_key_column(self, cursor, table_name):
        """
        Return the name of the primary key column for the given table.
        """
        columns = self.get_primary_key_columns(cursor, table_name)
        return columns[0] if columns else None

    def get_primary_key_columns(self, cursor, table_name):
        """Return a list of primary key columns for the given table."""
        for constraint in self.get_constraints(cursor, table_name).values():
            if constraint["primary_key"]:
                return constraint["columns"]
        return None

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index)
        across one or more columns.

        Return a dict mapping constraint names to their attributes,
        where attributes is a dict with keys:
         * columns: List of columns this covers
         * primary_key: True if primary key, False otherwise
         * unique: True if this is a unique constraint, False otherwise
         * foreign_key: (table, column) of target, or None
         * check: True if check constraint, False otherwise
         * index: True if index, False otherwise.
         * orders: The order (ASC/DESC) defined for the columns of indexes
         * type: The type of the index (btree, hash, etc.)

        Some backends may return special constraint names that don't exist
        if they don't name constraints of a certain type (e.g. SQLite)
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseIntrospection may require a get_constraints() "
            "method"
        )
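
# A minimal driving sketch (assumes a configured "default" connection whose
# backend implements get_table_list()):
#
#     from django.db import connection
#
#     with connection.cursor() as cursor:
#         for table_info in connection.introspection.get_table_list(cursor):
#             print(table_info.name, table_info.type)  # "t" table, "v" view
#
#     # table_names() opens its own cursor when none is passed.
#     names = connection.introspection.table_names(include_views=True)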
@@ -0,0 +1,806 @@
import datetime
import decimal
import json
import warnings
from importlib import import_module

import sqlparse

from django.conf import settings
from django.db import NotSupportedError, transaction
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.encoding import force_str


class BaseDatabaseOperations:
    """
    Encapsulate backend-specific differences, such as the way a backend
    performs ordering or calculates the ID of a recently-inserted row.
    """

    compiler_module = "django.db.models.sql.compiler"

    # Integer field safe ranges by `internal_type` as documented
    # in docs/ref/models/fields.txt.
    integer_field_ranges = {
        "SmallIntegerField": (-32768, 32767),
        "IntegerField": (-2147483648, 2147483647),
        "BigIntegerField": (-9223372036854775808, 9223372036854775807),
        "PositiveBigIntegerField": (0, 9223372036854775807),
        "PositiveSmallIntegerField": (0, 32767),
        "PositiveIntegerField": (0, 2147483647),
        "SmallAutoField": (-32768, 32767),
        "AutoField": (-2147483648, 2147483647),
        "BigAutoField": (-9223372036854775808, 9223372036854775807),
    }
    set_operators = {
        "union": "UNION",
        "intersection": "INTERSECT",
        "difference": "EXCEPT",
    }
    # Mapping of Field.get_internal_type() (typically the model field's class
    # name) to the data type to use for the Cast() function, if different from
    # DatabaseWrapper.data_types.
    cast_data_types = {}
    # CharField data type if the max_length argument isn't provided.
    cast_char_field_without_max_length = None

    # Start and end points for window expressions.
    PRECEDING = "PRECEDING"
    FOLLOWING = "FOLLOWING"
    UNBOUNDED_PRECEDING = "UNBOUNDED " + PRECEDING
    UNBOUNDED_FOLLOWING = "UNBOUNDED " + FOLLOWING
    CURRENT_ROW = "CURRENT ROW"

    # Prefix for EXPLAIN queries, or None if EXPLAIN isn't supported.
    explain_prefix = None

    def __init__(self, connection):
        self.connection = connection
        self._cache = None

    def autoinc_sql(self, table, column):
        """
        Return any SQL needed to support auto-incrementing primary keys, or
        None if no SQL is necessary.

        This SQL is executed when a table is created.
        """
        return None

    def bulk_batch_size(self, fields, objs):
        """
        Return the maximum allowed batch size for the backend. The fields
        are the fields going to be inserted in the batch, the objs contains
        all the objects to be inserted.
        """
        return len(objs)

    def format_for_duration_arithmetic(self, sql):
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a "
            "format_for_duration_arithmetic() method."
        )

    def cache_key_culling_sql(self):
        """
        Return an SQL query that retrieves the first cache key greater than the
        n smallest.

        This is used by the 'db' cache backend to determine where to start
        culling.
        """
        cache_key = self.quote_name("cache_key")
        return f"SELECT {cache_key} FROM %s ORDER BY {cache_key} LIMIT 1 OFFSET %%s"

    def unification_cast_sql(self, output_field):
        """
        Given a field instance, return the SQL that casts the result of a union
        to that type. The resulting string should contain a '%s' placeholder
        for the expression being cast.
        """
        return "%s"

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
        extracts a value from the given date field field_name.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a date_extract_sql() "
            "method"
        )

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Given a lookup_type of 'year', 'month', or 'day', return the SQL that
        truncates the given date or datetime field field_name to a date object
        with only the given specificity.

        If `tzname` is provided, the given value is truncated in a specific
        timezone.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a date_trunc_sql() "
            "method."
        )

    def datetime_cast_date_sql(self, sql, params, tzname):
        """
        Return the SQL to cast a datetime value to date value.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a "
            "datetime_cast_date_sql() method."
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        """
        Return the SQL to cast a datetime value to time value.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a "
            "datetime_cast_time_sql() method"
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """
        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
        'second', return the SQL that extracts a value from the given
        datetime field field_name.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a datetime_extract_sql() "
            "method"
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """
        Given a lookup_type of 'year', 'month', 'day', 'hour', 'minute', or
        'second', return the SQL that truncates the given datetime field
        field_name to a datetime object with only the given specificity.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a datetime_trunc_sql() "
            "method"
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Given a lookup_type of 'hour', 'minute' or 'second', return the SQL
        that truncates the given time or datetime field field_name to a time
        object with only the given specificity.

        If `tzname` is provided, the given value is truncated in a specific
        timezone.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a time_trunc_sql() method"
        )

    def time_extract_sql(self, lookup_type, sql, params):
        """
        Given a lookup_type of 'hour', 'minute', or 'second', return the SQL
        that extracts a value from the given time field field_name.
        """
        return self.date_extract_sql(lookup_type, sql, params)

    def deferrable_sql(self):
        """
        Return the SQL to make a constraint "initially deferred" during a
        CREATE TABLE statement.
        """
        return ""

    def distinct_sql(self, fields, params):
        """
        Return an SQL DISTINCT clause which removes duplicate rows from the
        result set. If any fields are given, only check the given fields for
        duplicates.
        """
        if fields:
            raise NotSupportedError(
                "DISTINCT ON fields is not supported by this database backend"
            )
        else:
            return ["DISTINCT"], []

    def fetch_returned_insert_columns(self, cursor, returning_params):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the newly created data.
        """
        return cursor.fetchone()

    def field_cast_sql(self, db_type, internal_type):
        """
        Given a column type (e.g. 'BLOB', 'VARCHAR') and an internal type
        (e.g. 'GenericIPAddressField'), return the SQL to cast it before using
        it in a WHERE statement. The resulting string should contain a '%s'
        placeholder for the column being searched against.
        """
        warnings.warn(
            (
                "DatabaseOperations.field_cast_sql() is deprecated; use "
                "DatabaseOperations.lookup_cast() instead."
            ),
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        return "%s"

    def force_group_by(self):
        """
        Return a GROUP BY clause to use with a HAVING clause when no grouping
        is specified.
        """
        return []

    def force_no_ordering(self):
        """
        Return a list used in the "ORDER BY" clause to force no ordering at
        all. Return an empty list to include nothing in the ordering.
        """
        return []

    def for_update_sql(self, nowait=False, skip_locked=False, of=(), no_key=False):
        """
        Return the FOR UPDATE SQL clause to lock rows for an update operation.
        """
        return "FOR%s UPDATE%s%s%s" % (
            " NO KEY" if no_key else "",
            " OF %s" % ", ".join(of) if of else "",
            " NOWAIT" if nowait else "",
            " SKIP LOCKED" if skip_locked else "",
        )

    def _get_limit_offset_params(self, low_mark, high_mark):
        offset = low_mark or 0
        if high_mark is not None:
            return (high_mark - offset), offset
        elif offset:
            return self.connection.ops.no_limit_value(), offset
        return None, offset

    def limit_offset_sql(self, low_mark, high_mark):
        """Return LIMIT/OFFSET SQL clause."""
        limit, offset = self._get_limit_offset_params(low_mark, high_mark)
        return " ".join(
            sql
            for sql in (
                ("LIMIT %d" % limit) if limit else None,
                ("OFFSET %d" % offset) if offset else None,
            )
            if sql
        )
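
    # Worked example (illustrative): with low_mark=5 and high_mark=15,
    # _get_limit_offset_params() returns (10, 5), so limit_offset_sql()
    # yields "LIMIT 10 OFFSET 5"; with only low_mark set, the backend's
    # no_limit_value() stands in for the missing upper bound.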
    def bulk_insert_sql(self, fields, placeholder_rows):
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join([f"({sql})" for sql in placeholder_rows_sql])
        return f"VALUES {values_sql}"

    def last_executed_query(self, cursor, sql, params):
        """
        Return a string of the query last executed by the given cursor, with
        placeholders replaced with actual values.

        `sql` is the raw query containing placeholders and `params` is the
        sequence of parameters. These are used by default, but this method
        exists for database backends to provide a better implementation
        according to their own quoting schemes.
        """

        # Convert params to contain string values.
        def to_string(s):
            return force_str(s, strings_only=True, errors="replace")

        if isinstance(params, (list, tuple)):
            u_params = tuple(to_string(val) for val in params)
        elif params is None:
            u_params = ()
        else:
            u_params = {to_string(k): to_string(v) for k, v in params.items()}

        return "QUERY = %r - PARAMS = %r" % (sql, u_params)

    def last_insert_id(self, cursor, table_name, pk_name):
        """
        Given a cursor object that has just performed an INSERT statement into
        a table that has an auto-incrementing ID, return the newly created ID.

        `pk_name` is the name of the primary-key column.
        """
        return cursor.lastrowid

    def lookup_cast(self, lookup_type, internal_type=None):
        """
        Return the string to use in a query when performing lookups
        ("contains", "like", etc.). It should contain a '%s' placeholder for
        the column being searched against.
        """
        return "%s"

    def max_in_list_size(self):
        """
        Return the maximum number of items that can be passed in a single 'IN'
        list condition, or None if the backend does not impose a limit.
        """
        return None

    def max_name_length(self):
        """
        Return the maximum length of table and column names, or None if there
        is no limit.
        """
        return None

    def no_limit_value(self):
        """
        Return the value to use for the LIMIT when we are wanting "LIMIT
        infinity". Return None if the limit clause can be omitted in this case.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a no_limit_value() method"
        )

    def pk_default_value(self):
        """
        Return the value to use during an INSERT statement to specify that
        the field should use its default value.
        """
        return "DEFAULT"

    def prepare_sql_script(self, sql):
        """
        Take an SQL script that may contain multiple lines and return a list
        of statements to feed to successive cursor.execute() calls.

        Since few databases are able to process raw SQL scripts in a single
        cursor.execute() call and PEP 249 doesn't talk about this use case,
        the default implementation is conservative.
        """
        return [
            sqlparse.format(statement, strip_comments=True)
            for statement in sqlparse.split(sql)
            if statement
        ]

    def process_clob(self, value):
        """
        Return the value of a CLOB column, for backends that return a locator
        object that requires additional processing.
        """
        return value

    def return_insert_columns(self, fields):
        """
        For backends that support returning columns as part of an insert query,
        return the SQL and params to append to the INSERT query. The returned
        fragment should contain a format string to hold the appropriate column.
        """
        pass

    def compiler(self, compiler_name):
        """
        Return the SQLCompiler class corresponding to the given name,
        in the namespace corresponding to the `compiler_module` attribute
        on this backend.
        """
        if self._cache is None:
            self._cache = import_module(self.compiler_module)
        return getattr(self._cache, compiler_name)

    def quote_name(self, name):
        """
        Return a quoted version of the given table, index, or column name. Do
        not quote the given name if it's already been quoted.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a quote_name() method"
        )

    def regex_lookup(self, lookup_type):
        """
        Return the string to use in a query when performing regular expression
        lookups (using "regex" or "iregex"). It should contain a '%s'
        placeholder for the column being searched against.

        If the feature is not supported (or part of it is not supported), raise
        NotImplementedError.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations may require a regex_lookup() method"
        )

    def savepoint_create_sql(self, sid):
        """
        Return the SQL for starting a new savepoint. Only required if the
        "uses_savepoints" feature is True. The "sid" parameter is a string
        for the savepoint id.
        """
        return "SAVEPOINT %s" % self.quote_name(sid)

    def savepoint_commit_sql(self, sid):
        """
        Return the SQL for committing the given savepoint.
        """
        return "RELEASE SAVEPOINT %s" % self.quote_name(sid)

    def savepoint_rollback_sql(self, sid):
        """
        Return the SQL for rolling back the given savepoint.
        """
        return "ROLLBACK TO SAVEPOINT %s" % self.quote_name(sid)

    def set_time_zone_sql(self):
        """
        Return the SQL that will set the connection's time zone.

        Return '' if the backend doesn't support time zones.
        """
        return ""

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        """
        Return a list of SQL statements required to remove all data from
        the given database tables (without actually removing the tables
        themselves).

        The `style` argument is a Style object as returned by either
        color_style() or no_style() in django.core.management.color.

        If `reset_sequences` is True, the list includes SQL statements required
        to reset the sequences.

        The `allow_cascade` argument determines whether truncation may cascade
        to tables with foreign keys pointing to the tables being truncated.
        PostgreSQL requires a cascade even if these tables are empty.
        """
        raise NotImplementedError(
            "subclasses of BaseDatabaseOperations must provide an sql_flush() method"
        )

    def execute_sql_flush(self, sql_list):
        """Execute a list of SQL statements to flush the database."""
        with transaction.atomic(
            using=self.connection.alias,
            savepoint=self.connection.features.can_rollback_ddl,
        ):
            with self.connection.cursor() as cursor:
                for sql in sql_list:
                    cursor.execute(sql)
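
    # Usage sketch (destructive; "app_book" is a placeholder table name):
    #
    #     from django.core.management.color import no_style
    #     from django.db import connection
    #
    #     statements = connection.ops.sql_flush(
    #         no_style(), ["app_book"], reset_sequences=True
    #     )
    #     connection.ops.execute_sql_flush(statements)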
    def sequence_reset_by_name_sql(self, style, sequences):
        """
        Return a list of the SQL statements required to reset sequences
        passed in `sequences`.

        The `style` argument is a Style object as returned by either
        color_style() or no_style() in django.core.management.color.
        """
        return []

    def sequence_reset_sql(self, style, model_list):
        """
        Return a list of the SQL statements required to reset sequences for
        the given models.

        The `style` argument is a Style object as returned by either
        color_style() or no_style() in django.core.management.color.
        """
        return []  # No sequence reset required by default.

    def start_transaction_sql(self):
        """Return the SQL statement required to start a transaction."""
        return "BEGIN;"

    def end_transaction_sql(self, success=True):
        """Return the SQL statement required to end a transaction."""
        if not success:
            return "ROLLBACK;"
        return "COMMIT;"

    def tablespace_sql(self, tablespace, inline=False):
        """
        Return the SQL that will be used in a query to define the tablespace.

        Return '' if the backend doesn't support tablespaces.

        If `inline` is True, append the SQL to a row; otherwise append it to
        the entire CREATE TABLE or CREATE INDEX statement.
        """
        return ""

    def prep_for_like_query(self, x):
        """Prepare a value for use in a LIKE query."""
        return str(x).replace("\\", "\\\\").replace("%", r"\%").replace("_", r"\_")

    # Same as prep_for_like_query(), but called for "iexact" matches, which
    # need not necessarily be implemented using "LIKE" in the backend.
    prep_for_iexact_query = prep_for_like_query
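
    # Worked example (illustrative): prep_for_like_query("50%_off") returns
    # r"50\%\_off", so a literal "%" or "_" in user input can't act as a
    # LIKE wildcard.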
    def validate_autopk_value(self, value):
        """
        Certain backends do not accept some values for "serial" fields
        (for example zero in MySQL). Raise a ValueError if the value is
        invalid, otherwise return the validated value.
        """
        return value

    def adapt_unknown_value(self, value):
        """
        Transform a value to something compatible with the backend driver.

        This method only depends on the type of the value. It's designed for
        cases where the target type isn't known, such as .raw() SQL queries.
        As a consequence it may not work perfectly in all circumstances.
        """
        if isinstance(value, datetime.datetime):  # must be before date
            return self.adapt_datetimefield_value(value)
        elif isinstance(value, datetime.date):
            return self.adapt_datefield_value(value)
        elif isinstance(value, datetime.time):
            return self.adapt_timefield_value(value)
        elif isinstance(value, decimal.Decimal):
            return self.adapt_decimalfield_value(value)
        else:
            return value

    def adapt_integerfield_value(self, value, internal_type):
        return value

    def adapt_datefield_value(self, value):
        """
        Transform a date value to an object compatible with what is expected
        by the backend driver for date columns.
        """
        if value is None:
            return None
        return str(value)

    def adapt_datetimefield_value(self, value):
        """
        Transform a datetime value to an object compatible with what is expected
        by the backend driver for datetime columns.
        """
        if value is None:
            return None
        return str(value)

    def adapt_timefield_value(self, value):
        """
        Transform a time value to an object compatible with what is expected
        by the backend driver for time columns.
        """
        if value is None:
            return None
        if timezone.is_aware(value):
            raise ValueError("Django does not support timezone-aware times.")
        return str(value)

    def adapt_decimalfield_value(self, value, max_digits=None, decimal_places=None):
        """
        Transform a decimal.Decimal value to an object compatible with what is
        expected by the backend driver for decimal (numeric) columns.
        """
        return value

    def adapt_ipaddressfield_value(self, value):
        """
        Transform a string representation of an IP address into the expected
        type for the backend driver.
        """
        return value or None

    def adapt_json_value(self, value, encoder):
        return json.dumps(value, cls=encoder)

    def year_lookup_bounds_for_date_field(self, value, iso_year=False):
        """
        Return a two-element list with the lower and upper bound to be used
        with a BETWEEN operator to query a DateField value using a year
        lookup.

        `value` is an int, containing the looked-up year.
        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
        """
        if iso_year:
            first = datetime.date.fromisocalendar(value, 1, 1)
            second = datetime.date.fromisocalendar(
                value + 1, 1, 1
            ) - datetime.timedelta(days=1)
        else:
            first = datetime.date(value, 1, 1)
            second = datetime.date(value, 12, 31)
        first = self.adapt_datefield_value(first)
        second = self.adapt_datefield_value(second)
        return [first, second]
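
    # Worked example (illustrative): for value=2024 the calendar bounds are
    # ["2024-01-01", "2024-12-31"]; with iso_year=True they follow ISO-8601
    # week numbering instead, ["2024-01-01", "2024-12-29"], because ISO year
    # 2025 already starts on 2024-12-30.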
    def year_lookup_bounds_for_datetime_field(self, value, iso_year=False):
        """
        Return a two-element list with the lower and upper bound to be used
        with a BETWEEN operator to query a DateTimeField value using a year
        lookup.

        `value` is an int, containing the looked-up year.
        If `iso_year` is True, return bounds for ISO-8601 week-numbering years.
        """
        if iso_year:
            first = datetime.datetime.fromisocalendar(value, 1, 1)
            second = datetime.datetime.fromisocalendar(
                value + 1, 1, 1
            ) - datetime.timedelta(microseconds=1)
        else:
            first = datetime.datetime(value, 1, 1)
            second = datetime.datetime(value, 12, 31, 23, 59, 59, 999999)
        if settings.USE_TZ:
            tz = timezone.get_current_timezone()
            first = timezone.make_aware(first, tz)
            second = timezone.make_aware(second, tz)
        first = self.adapt_datetimefield_value(first)
        second = self.adapt_datetimefield_value(second)
        return [first, second]

    def get_db_converters(self, expression):
        """
        Return a list of functions needed to convert field data.

        Some field types on some backends do not provide data in the correct
        format; this is the hook for converter functions.
        """
        return []

    def convert_durationfield_value(self, value, expression, connection):
        if value is not None:
            return datetime.timedelta(0, 0, value)

    def check_expression_support(self, expression):
        """
        Check that the backend supports the provided expression.

        This is used on specific backends to rule out known expressions
        that have problematic or nonexistent implementations. If the
        expression has a known problem, the backend should raise
        NotSupportedError.
        """
        pass

    def conditional_expression_supported_in_where_clause(self, expression):
        """
        Return True, if the conditional expression is supported in the WHERE
        clause.
        """
        return True

    def combine_expression(self, connector, sub_expressions):
        """
        Combine a list of subexpressions into a single expression, using
        the provided connecting operator. This is required because operators
        can vary between backends (e.g., Oracle with %% and &) and between
        subexpression types (e.g., date expressions).
        """
        conn = " %s " % connector
        return conn.join(sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        return self.combine_expression(connector, sub_expressions)

    def binary_placeholder_sql(self, value):
        """
        Some backends require special syntax to insert binary content (MySQL
        for example uses '_binary %s').
        """
        return "%s"

    def modify_insert_params(self, placeholder, params):
        """
        Allow modification of insert parameters. Needed for Oracle Spatial
        backend due to #10888.
        """
        return params

    def integer_field_range(self, internal_type):
        """
        Given an integer field internal type (e.g. 'PositiveIntegerField'),
        return a tuple of the (min_value, max_value) form representing the
        range of the column type bound to the field.
        """
        return self.integer_field_ranges[internal_type]

    def subtract_temporals(self, internal_type, lhs, rhs):
        if self.connection.features.supports_temporal_subtraction:
            lhs_sql, lhs_params = lhs
            rhs_sql, rhs_params = rhs
            return "(%s - %s)" % (lhs_sql, rhs_sql), (*lhs_params, *rhs_params)
        raise NotSupportedError(
            "This backend does not support %s subtraction." % internal_type
        )

    def window_frame_value(self, value):
        if isinstance(value, int):
            if value == 0:
                return self.CURRENT_ROW
            elif value < 0:
                return "%d %s" % (abs(value), self.PRECEDING)
            else:
                return "%d %s" % (value, self.FOLLOWING)

    def window_frame_rows_start_end(self, start=None, end=None):
        """
        Return SQL for start and end points in an OVER clause window frame.
        """
        if isinstance(start, int) and isinstance(end, int) and start > end:
            raise ValueError("start cannot be greater than end.")
        if start is not None and not isinstance(start, int):
            raise ValueError(
                f"start argument must be an integer, zero, or None, but got '{start}'."
            )
        if end is not None and not isinstance(end, int):
            raise ValueError(
                f"end argument must be an integer, zero, or None, but got '{end}'."
            )
        start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
        end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
        return start_, end_
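
    # Worked example (illustrative): window_frame_rows_start_end(-3, 0)
    # returns ("3 PRECEDING", "CURRENT ROW"); with both ends left as None it
    # falls back to ("UNBOUNDED PRECEDING", "UNBOUNDED FOLLOWING").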
    def window_frame_range_start_end(self, start=None, end=None):
        if (start is not None and not isinstance(start, int)) or (
            isinstance(start, int) and start > 0
        ):
            raise ValueError(
                "start argument must be a negative integer, zero, or None, "
                "but got '%s'." % start
            )
        if (end is not None and not isinstance(end, int)) or (
            isinstance(end, int) and end < 0
        ):
            raise ValueError(
                "end argument must be a positive integer, zero, or None, but got '%s'."
                % end
            )
        start_ = self.window_frame_value(start) or self.UNBOUNDED_PRECEDING
        end_ = self.window_frame_value(end) or self.UNBOUNDED_FOLLOWING
        features = self.connection.features
        if features.only_supports_unbounded_with_preceding_and_following and (
            (start and start < 0) or (end and end > 0)
        ):
            raise NotSupportedError(
                "%s only supports UNBOUNDED together with PRECEDING and "
                "FOLLOWING." % self.connection.display_name
            )
        return start_, end_

    def explain_query_prefix(self, format=None, **options):
        if not self.connection.features.supports_explaining_query_execution:
            raise NotSupportedError(
                "This backend does not support explaining query execution."
            )
        if format:
            supported_formats = self.connection.features.supported_explain_formats
            normalized_format = format.upper()
            if normalized_format not in supported_formats:
                msg = "%s is not a recognized format." % normalized_format
                if supported_formats:
                    msg += " Allowed formats: %s" % ", ".join(sorted(supported_formats))
                else:
                    msg += (
                        f" {self.connection.display_name} does not support any formats."
                    )
                raise ValueError(msg)
        if options:
            raise ValueError("Unknown options: %s" % ", ".join(sorted(options.keys())))
        return self.explain_prefix
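
    # Usage sketch (QuerySet.explain() funnels into this hook; "Book" is a
    # placeholder model, and supported formats vary per backend):
    #
    #     plan = Book.objects.filter(pk=1).explain(format="JSON")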
    def insert_statement(self, on_conflict=None):
        return "INSERT INTO"

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        return ""

    def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
        lhs_expr = Col(lhs_table, lhs_field)
        rhs_expr = Col(rhs_table, rhs_field)

        return lhs_expr, rhs_expr
2046
.venv/lib/python3.10/site-packages/django/db/backends/base/schema.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,29 @@
class BaseDatabaseValidation:
    """Encapsulate backend-specific validation."""

    def __init__(self, connection):
        self.connection = connection

    def check(self, **kwargs):
        return []

    def check_field(self, field, **kwargs):
        errors = []
        # Backends may implement a check_field_type() method.
        if (
            hasattr(self, "check_field_type")
            and
            # Ignore any related fields.
            not getattr(field, "remote_field", None)
        ):
            # Ignore fields with unsupported features.
            db_supports_all_required_features = all(
                getattr(self.connection.features, feature, False)
                for feature in field.model._meta.required_db_features
            )
            if db_supports_all_required_features:
                field_type = field.db_type(self.connection)
                # Ignore non-concrete fields.
                if field_type is not None:
                    errors.extend(self.check_field_type(field, field_type))
        return errors
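
# A minimal subclass sketch (hypothetical backend check; "myapp.W001" is a
# made-up check id):
#
#     from django.core import checks
#
#     class DatabaseValidation(BaseDatabaseValidation):
#         def check_field_type(self, field, field_type):
#             errors = []
#             if field_type.startswith("varchar") and field.db_index:
#                 errors.append(
#                     checks.Warning(
#                         "Indexes on long varchar columns may be truncated.",
#                         obj=field,
#                         id="myapp.W001",
#                     )
#                 )
#             return errors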
@@ -0,0 +1,270 @@
"""
Helpers to manipulate deferred DDL statements that might need to be adjusted
or discarded when executing a migration.
"""

from copy import deepcopy


class Reference:
    """Base class that defines the reference interface."""

    def references_table(self, table):
        """
        Return whether or not this instance references the specified table.
        """
        return False

    def references_column(self, table, column):
        """
        Return whether or not this instance references the specified column.
        """
        return False

    def references_index(self, table, index):
        """
        Return whether or not this instance references the specified index.
        """
        return False

    def rename_table_references(self, old_table, new_table):
        """
        Rename all references to the old_table to the new_table.
        """
        pass

    def rename_column_references(self, table, old_column, new_column):
        """
        Rename all references to the old_column to the new_column.
        """
        pass

    def __repr__(self):
        return "<%s %r>" % (self.__class__.__name__, str(self))

    def __str__(self):
        raise NotImplementedError(
            "Subclasses must define how they should be converted to string."
        )


class Table(Reference):
    """Hold a reference to a table."""

    def __init__(self, table, quote_name):
        self.table = table
        self.quote_name = quote_name

    def references_table(self, table):
        return self.table == table

    def references_index(self, table, index):
        return self.references_table(table) and str(self) == index

    def rename_table_references(self, old_table, new_table):
        if self.table == old_table:
            self.table = new_table

    def __str__(self):
        return self.quote_name(self.table)


class TableColumns(Table):
    """Base class for references to multiple columns of a table."""

    def __init__(self, table, columns):
        self.table = table
        self.columns = columns

    def references_column(self, table, column):
        return self.table == table and column in self.columns

    def rename_column_references(self, table, old_column, new_column):
        if self.table == table:
            for index, column in enumerate(self.columns):
                if column == old_column:
                    self.columns[index] = new_column


class Columns(TableColumns):
    """Hold a reference to one or many columns."""

    def __init__(self, table, columns, quote_name, col_suffixes=()):
        self.quote_name = quote_name
        self.col_suffixes = col_suffixes
        super().__init__(table, columns)

    def __str__(self):
        def col_str(column, idx):
            col = self.quote_name(column)
            try:
                suffix = self.col_suffixes[idx]
                if suffix:
                    col = "{} {}".format(col, suffix)
            except IndexError:
                pass
            return col

        return ", ".join(
            col_str(column, idx) for idx, column in enumerate(self.columns)
        )


class IndexName(TableColumns):
    """Hold a reference to an index name."""

    def __init__(self, table, columns, suffix, create_index_name):
        self.suffix = suffix
        self.create_index_name = create_index_name
        super().__init__(table, columns)

    def __str__(self):
        return self.create_index_name(self.table, self.columns, self.suffix)


class IndexColumns(Columns):
    def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
        self.opclasses = opclasses
        super().__init__(table, columns, quote_name, col_suffixes)

    def __str__(self):
        def col_str(column, idx):
            # Index.__init__() guarantees that self.opclasses is the same
            # length as self.columns.
            col = "{} {}".format(self.quote_name(column), self.opclasses[idx])
            try:
                suffix = self.col_suffixes[idx]
                if suffix:
                    col = "{} {}".format(col, suffix)
            except IndexError:
                pass
            return col

        return ", ".join(
            col_str(column, idx) for idx, column in enumerate(self.columns)
        )


class ForeignKeyName(TableColumns):
    """Hold a reference to a foreign key name."""

    def __init__(
        self,
        from_table,
        from_columns,
        to_table,
        to_columns,
        suffix_template,
        create_fk_name,
    ):
        self.to_reference = TableColumns(to_table, to_columns)
        self.suffix_template = suffix_template
        self.create_fk_name = create_fk_name
        super().__init__(
            from_table,
            from_columns,
        )

    def references_table(self, table):
        return super().references_table(table) or self.to_reference.references_table(
            table
        )

    def references_column(self, table, column):
        return super().references_column(
            table, column
        ) or self.to_reference.references_column(table, column)

    def rename_table_references(self, old_table, new_table):
        super().rename_table_references(old_table, new_table)
        self.to_reference.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        super().rename_column_references(table, old_column, new_column)
        self.to_reference.rename_column_references(table, old_column, new_column)

    def __str__(self):
        suffix = self.suffix_template % {
            "to_table": self.to_reference.table,
            "to_column": self.to_reference.columns[0],
        }
        return self.create_fk_name(self.table, self.columns, suffix)


class Statement(Reference):
    """
    Statement template and formatting parameters container.

    Allows keeping a reference to a statement without interpolating
    identifiers that might have to be adjusted if they're referencing a table
    or column that is removed.
    """

    def __init__(self, template, **parts):
        self.template = template
        self.parts = parts

    def references_table(self, table):
        return any(
            hasattr(part, "references_table") and part.references_table(table)
            for part in self.parts.values()
        )

    def references_column(self, table, column):
        return any(
            hasattr(part, "references_column") and part.references_column(table, column)
            for part in self.parts.values()
        )

    def references_index(self, table, index):
        return any(
            hasattr(part, "references_index") and part.references_index(table, index)
            for part in self.parts.values()
        )

    def rename_table_references(self, old_table, new_table):
        for part in self.parts.values():
            if hasattr(part, "rename_table_references"):
                part.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        for part in self.parts.values():
            if hasattr(part, "rename_column_references"):
                part.rename_column_references(table, old_column, new_column)

    def __str__(self):
        return self.template % self.parts


class Expressions(TableColumns):
    def __init__(self, table, expressions, compiler, quote_value):
        self.compiler = compiler
        self.expressions = expressions
        self.quote_value = quote_value
        columns = [
            col.target.column
            for col in self.compiler.query._gen_cols([self.expressions])
        ]
        super().__init__(table, columns)

    def rename_table_references(self, old_table, new_table):
        if self.table != old_table:
            return
        self.expressions = self.expressions.relabeled_clone({old_table: new_table})
        super().rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        if self.table != table:
            return
        expressions = deepcopy(self.expressions)
        self.columns = []
        for col in self.compiler.query._gen_cols([expressions]):
            if col.target.column == old_column:
                col.target.column = new_column
            self.columns.append(col.target.column)
        self.expressions = expressions

    def __str__(self):
        sql, params = self.compiler.compile(self.expressions)
        params = map(self.quote_value, params)
        return sql % tuple(params)
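
# A minimal driving sketch: schema editors keep deferred DDL as a Statement
# whose parts can still be retargeted before execution ("app_book" and
# "app_novel" are placeholder table names; quote_name is a trivial stand-in
# for the backend's real quoting):
#
#     def quote_name(name):
#         return '"%s"' % name
#
#     stmt = Statement(
#         "ALTER TABLE %(table)s ADD FOREIGN KEY (%(columns)s)",
#         table=Table("app_book", quote_name),
#         columns=Columns("app_book", ["author_id"], quote_name),
#     )
#     stmt.rename_table_references("app_book", "app_novel")
#     assert stmt.references_column("app_novel", "author_id")
#     sql = str(stmt)  # identifiers are only interpolated at this point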
@@ -0,0 +1,75 @@
"""
Dummy database backend for Django.

Django uses this if the database ENGINE setting is empty (None or empty
string).

Each of these API functions, except connection.close(), raises
ImproperlyConfigured.
"""

from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.base.client import BaseDatabaseClient
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.dummy.features import DummyDatabaseFeatures


def complain(*args, **kwargs):
    raise ImproperlyConfigured(
        "settings.DATABASES is improperly configured. "
        "Please supply the ENGINE value. Check "
        "settings documentation for more details."
    )


def ignore(*args, **kwargs):
    pass


class DatabaseOperations(BaseDatabaseOperations):
    quote_name = complain


class DatabaseClient(BaseDatabaseClient):
    runshell = complain


class DatabaseCreation(BaseDatabaseCreation):
    create_test_db = ignore
    destroy_test_db = ignore
    serialize_db_to_string = ignore


class DatabaseIntrospection(BaseDatabaseIntrospection):
    get_table_list = complain
    get_table_description = complain
    get_relations = complain
    get_indexes = complain


class DatabaseWrapper(BaseDatabaseWrapper):
    operators = {}
    # Override the base class implementations with null
    # implementations. Anything that tries to actually
    # do something raises via complain; anything that tries
    # to rollback or undo something is routed to ignore.
    _cursor = complain
    ensure_connection = complain
    _commit = complain
    _rollback = ignore
    _close = ignore
    _savepoint = ignore
    _savepoint_commit = complain
    _savepoint_rollback = ignore
    _set_autocommit = complain
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DummyDatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def is_usable(self):
        return True
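
# Behavior sketch: with DATABASES = {"default": {}}, any real work hits
# complain() and raises ImproperlyConfigured, while teardown paths stay
# silent no-ops:
#
#     from django.db import connection
#
#     connection.cursor()  # raises ImproperlyConfigured
#     connection.close()   # fine; close() is allowed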
|
||||
@@ -0,0 +1,6 @@
from django.db.backends.base.features import BaseDatabaseFeatures


class DummyDatabaseFeatures(BaseDatabaseFeatures):
    supports_transactions = False
    uses_savepoints = False
@@ -0,0 +1,449 @@
"""
MySQL database backend for Django.

Requires mysqlclient: https://pypi.org/project/mysqlclient/
"""

from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends import utils as backend_utils
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile

try:
    import MySQLdb as Database
except ImportError as err:
    raise ImproperlyConfigured(
        "Error loading MySQLdb module.\nDid you install mysqlclient?"
    ) from err

from MySQLdb.constants import CLIENT, FIELD_TYPE
from MySQLdb.converters import conversions

# Some of these import MySQLdb, so import them after checking if it's installed.
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor
from .validation import DatabaseValidation

version = Database.version_info
if version < (1, 4, 3):
    raise ImproperlyConfigured(
        "mysqlclient 1.4.3 or newer is required; you have %s." % Database.__version__
    )


# MySQLdb returns TIME columns as timedelta -- they are more like timedelta in
# terms of actual behavior as they are signed and include days -- and Django
# expects time.
django_conversions = {
    **conversions,
    **{FIELD_TYPE.TIME: backend_utils.typecast_time},
}

# This should match the numerical portion of the version numbers (we can treat
# versions like 5.0.24 and 5.0.24a as the same).
server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})")


class CursorWrapper:
    """
    A thin wrapper around MySQLdb's normal cursor class that catches particular
    exception instances and reraises them with the correct types.

    Implemented as a wrapper, rather than a subclass, so that it isn't stuck
    to the particular underlying representation returned by Connection.cursor().
    """

    codes_for_integrityerror = (
        1048,  # Column cannot be null
        1690,  # BIGINT UNSIGNED value is out of range
        3819,  # CHECK constraint is violated
        4025,  # CHECK constraint failed
    )

    def __init__(self, cursor):
        self.cursor = cursor

    def execute(self, query, args=None):
        try:
            # args is None means no string interpolation
            return self.cursor.execute(query, args)
        except Database.OperationalError as e:
            # Map some error codes to IntegrityError, since they seem to be
            # misclassified and Django would prefer the more logical place.
            if e.args[0] in self.codes_for_integrityerror:
                raise IntegrityError(*tuple(e.args))
            raise

    def executemany(self, query, args):
        try:
            return self.cursor.executemany(query, args)
        except Database.OperationalError as e:
            # Map some error codes to IntegrityError, since they seem to be
            # misclassified and Django would prefer the more logical place.
            if e.args[0] in self.codes_for_integrityerror:
                raise IntegrityError(*tuple(e.args))
            raise

    def __getattr__(self, attr):
        return getattr(self.cursor, attr)

    def __iter__(self):
        return iter(self.cursor)
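

# Editor's note: an illustrative, standalone sketch (not Django API) of the
# wrap-and-remap pattern CursorWrapper uses above -- intercept one call to
# translate exception types, delegate everything else via __getattr__. All
# names are hypothetical.
class RemappingProxy:
    def __init__(self, inner, remap):
        self._inner = inner
        self._remap = remap  # e.g. {SourceError: TargetError}

    def execute(self, *args, **kwargs):
        try:
            return self._inner.execute(*args, **kwargs)
        except tuple(self._remap) as exc:
            # Re-raise with the mapped type, preserving the original args.
            raise self._remap[type(exc)](*exc.args) from exc

    def __getattr__(self, attr):
        # Everything not intercepted above falls through to the wrapped object.
        return getattr(self._inner, attr)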


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = "mysql"
    # This dictionary maps Field objects to their associated MySQL column
    # types, as strings. Column-type strings can contain format strings;
    # they'll be interpolated against the values of Field.__dict__ before
    # being output. If a column type is set to None, it won't be included
    # in the output.

    _data_types = {
        "AutoField": "integer AUTO_INCREMENT",
        "BigAutoField": "bigint AUTO_INCREMENT",
        "BinaryField": "longblob",
        "BooleanField": "bool",
        "CharField": "varchar(%(max_length)s)",
        "DateField": "date",
        "DateTimeField": "datetime(6)",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "bigint",
        "FileField": "varchar(%(max_length)s)",
        "FilePathField": "varchar(%(max_length)s)",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "json",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint UNSIGNED",
        "PositiveIntegerField": "integer UNSIGNED",
        "PositiveSmallIntegerField": "smallint UNSIGNED",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "smallint AUTO_INCREMENT",
        "SmallIntegerField": "smallint",
        "TextField": "longtext",
        "TimeField": "time(6)",
        "UUIDField": "char(32)",
    }

    @cached_property
    def data_types(self):
        _data_types = self._data_types.copy()
        if self.features.has_native_uuid_field:
            _data_types["UUIDField"] = "uuid"
        return _data_types
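
    # Editor's note (illustrative): each column-type format string above is
    # interpolated against the field's attributes, e.g. for a CharField with
    # max_length=120:
    #
    #     "varchar(%(max_length)s)" % {"max_length": 120}  # -> 'varchar(120)'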

    # For these data types:
    # - MySQL < 8.0.13 doesn't accept default values and implicitly treats
    #   them as nullable
    # - all versions of MySQL and MariaDB don't support full width database
    #   indexes
    _limited_data_types = (
        "tinyblob",
        "blob",
        "mediumblob",
        "longblob",
        "tinytext",
        "text",
        "mediumtext",
        "longtext",
        "json",
    )

    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s",
        "contains": "LIKE BINARY %s",
        "icontains": "LIKE %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE BINARY %s",
        "endswith": "LIKE BINARY %s",
        "istartswith": "LIKE %s",
        "iendswith": "LIKE %s",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation). In those cases,
    # special characters for LIKE operators (e.g. \, %, _) should be escaped
    # on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a
    # wildcard for the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\\', '\\\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": "LIKE BINARY CONCAT('%%', {}, '%%')",
        "icontains": "LIKE CONCAT('%%', {}, '%%')",
        "startswith": "LIKE BINARY CONCAT({}, '%%')",
        "istartswith": "LIKE CONCAT({}, '%%')",
        "endswith": "LIKE BINARY CONCAT('%%', {})",
        "iendswith": "LIKE CONCAT('%%', {})",
    }
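
    # Editor's note (illustrative): pattern_esc is filled via str.format(),
    # e.g. pattern_esc.format("`author`.`name`") yields
    #   REPLACE(REPLACE(REPLACE(`author`.`name`, '\\', '\\\\'), '%%', '\%%'), '_', '\_')
    # The doubled percent signs survive on purpose: the string is later run
    # through '%'-style parameter interpolation, which halves them.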

    isolation_levels = {
        "read uncommitted",
        "read committed",
        "repeatable read",
        "serializable",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation

    def get_database_version(self):
        return self.mysql_version

    def get_connection_params(self):
        kwargs = {
            "conv": django_conversions,
            "charset": "utf8mb4",
        }
        settings_dict = self.settings_dict
        if settings_dict["USER"]:
            kwargs["user"] = settings_dict["USER"]
        if settings_dict["NAME"]:
            kwargs["database"] = settings_dict["NAME"]
        if settings_dict["PASSWORD"]:
            kwargs["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"].startswith("/"):
            kwargs["unix_socket"] = settings_dict["HOST"]
        elif settings_dict["HOST"]:
            kwargs["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            kwargs["port"] = int(settings_dict["PORT"])
        # We need the number of potentially affected rows after an
        # "UPDATE", not the number of changed rows.
        kwargs["client_flag"] = CLIENT.FOUND_ROWS
        # Validate the transaction isolation level, if specified.
        options = settings_dict["OPTIONS"].copy()
        isolation_level = options.pop("isolation_level", "read committed")
        if isolation_level:
            isolation_level = isolation_level.lower()
            if isolation_level not in self.isolation_levels:
                raise ImproperlyConfigured(
                    "Invalid transaction isolation level '%s' specified.\n"
                    "Use one of %s, or None."
                    % (
                        isolation_level,
                        ", ".join("'%s'" % s for s in sorted(self.isolation_levels)),
                    )
                )
        self.isolation_level = isolation_level
        kwargs.update(options)
        return kwargs
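
    # Editor's note: a hedged worked example of the mapping above (all values
    # hypothetical). Given
    #     {"NAME": "app", "USER": "app", "PASSWORD": "s3cret",
    #      "HOST": "/var/run/mysqld/mysqld.sock", "PORT": "", "OPTIONS": {}}
    # the method returns roughly
    #     {"conv": django_conversions, "charset": "utf8mb4", "user": "app",
    #      "database": "app", "password": "s3cret",
    #      "unix_socket": "/var/run/mysqld/mysqld.sock",
    #      "client_flag": CLIENT.FOUND_ROWS}
    # and sets self.isolation_level = "read committed" (the default).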

    @async_unsafe
    def get_new_connection(self, conn_params):
        connection = Database.connect(**conn_params)
        # The bytes encoder in mysqlclient doesn't work and was added only to
        # prevent KeyErrors in Django < 2.0. We can remove this workaround when
        # mysqlclient 2.1 becomes the minimal mysqlclient supported by Django.
        # See https://github.com/PyMySQL/mysqlclient/issues/489
        if connection.encoders.get(bytes) is bytes:
            connection.encoders.pop(bytes)
        return connection

    def init_connection_state(self):
        super().init_connection_state()
        assignments = []
        if self.features.is_sql_auto_is_null_enabled:
            # SQL_AUTO_IS_NULL controls whether an AUTO_INCREMENT column on
            # a recently inserted row will return when the field is tested
            # for NULL. Disabling this brings this aspect of MySQL in line
            # with SQL standards.
            assignments.append("SET SQL_AUTO_IS_NULL = 0")

        if self.isolation_level:
            assignments.append(
                "SET SESSION TRANSACTION ISOLATION LEVEL %s"
                % self.isolation_level.upper()
            )

        if assignments:
            with self.cursor() as cursor:
                cursor.execute("; ".join(assignments))

    @async_unsafe
    def create_cursor(self, name=None):
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _rollback(self):
        try:
            BaseDatabaseWrapper._rollback(self)
        except Database.NotSupportedError:
            pass

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit(autocommit)

    def disable_constraint_checking(self):
        """
        Disable foreign key checks, primarily for use in adding rows with
        forward references. Always return True to indicate constraint checks
        need to be re-enabled.
        """
        with self.cursor() as cursor:
            cursor.execute("SET foreign_key_checks=0")
        return True

    def enable_constraint_checking(self):
        """
        Re-enable foreign key checks after they have been disabled.
        """
        # Override needs_rollback in case constraint_checks_disabled is
        # nested inside transaction.atomic.
        self.needs_rollback, needs_rollback = False, self.needs_rollback
        try:
            with self.cursor() as cursor:
                cursor.execute("SET foreign_key_checks=1")
        finally:
            self.needs_rollback = needs_rollback

    def check_constraints(self, table_names=None):
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while
        constraint checks were off.
        """
        with self.cursor() as cursor:
            if table_names is None:
                table_names = self.introspection.table_names(cursor)
            for table_name in table_names:
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                if not primary_key_column_name:
                    continue
                relations = self.introspection.get_relations(cursor, table_name)
                for column_name, (
                    referenced_column_name,
                    referenced_table_name,
                ) in relations.items():
                    cursor.execute(
                        """
                        SELECT REFERRING.`%s`, REFERRING.`%s` FROM `%s` as REFERRING
                        LEFT JOIN `%s` as REFERRED
                        ON (REFERRING.`%s` = REFERRED.`%s`)
                        WHERE REFERRING.`%s` IS NOT NULL AND REFERRED.`%s` IS NULL
                        """
                        % (
                            primary_key_column_name,
                            column_name,
                            table_name,
                            referenced_table_name,
                            column_name,
                            referenced_column_name,
                            column_name,
                            referenced_column_name,
                        )
                    )
                    for bad_row in cursor.fetchall():
                        raise IntegrityError(
                            "The row in table '%s' with primary key '%s' has an "
                            "invalid foreign key: %s.%s contains a value '%s' that "
                            "does not have a corresponding value in %s.%s."
                            % (
                                table_name,
                                bad_row[0],
                                table_name,
                                column_name,
                                bad_row[1],
                                referenced_table_name,
                                referenced_column_name,
                            )
                        )

    def is_usable(self):
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    @cached_property
    def display_name(self):
        return "MariaDB" if self.mysql_is_mariadb else "MySQL"

    @cached_property
    def data_type_check_constraints(self):
        if self.features.supports_column_check_constraints:
            check_constraints = {
                "PositiveBigIntegerField": "`%(column)s` >= 0",
                "PositiveIntegerField": "`%(column)s` >= 0",
                "PositiveSmallIntegerField": "`%(column)s` >= 0",
            }
            return check_constraints
        return {}

    @cached_property
    def mysql_server_data(self):
        with self.temporary_connection() as cursor:
            # Select some server variables and test if the time zone
            # definitions are installed. CONVERT_TZ returns NULL if the 'UTC'
            # time zone isn't loaded into the mysql.time_zone table.
            cursor.execute(
                """
                SELECT VERSION(),
                       @@sql_mode,
                       @@default_storage_engine,
                       @@sql_auto_is_null,
                       @@lower_case_table_names,
                       CONVERT_TZ('2001-01-01 01:00:00', 'UTC', 'UTC') IS NOT NULL
                """
            )
            row = cursor.fetchone()
        return {
            "version": row[0],
            "sql_mode": row[1],
            "default_storage_engine": row[2],
            "sql_auto_is_null": bool(row[3]),
            "lower_case_table_names": bool(row[4]),
            "has_zoneinfo_database": bool(row[5]),
        }

    @cached_property
    def mysql_server_info(self):
        return self.mysql_server_data["version"]

    @cached_property
    def mysql_version(self):
        match = server_version_re.match(self.mysql_server_info)
        if not match:
            raise Exception(
                "Unable to determine MySQL version from version string %r"
                % self.mysql_server_info
            )
        return tuple(int(x) for x in match.groups())
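
    # Editor's note: a runnable illustration of the parsing above (version
    # string hypothetical):
    #
    #     import re
    #     m = re.match(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})", "10.6.16-MariaDB-log")
    #     assert tuple(int(x) for x in m.groups()) == (10, 6, 16)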

    @cached_property
    def mysql_is_mariadb(self):
        return "mariadb" in self.mysql_server_info.lower()

    @cached_property
    def sql_mode(self):
        sql_mode = self.mysql_server_data["sql_mode"]
        return set(sql_mode.split(",") if sql_mode else ())
@@ -0,0 +1,72 @@
import signal

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = "mysql"

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name]
        env = None
        database = settings_dict["OPTIONS"].get(
            "database",
            settings_dict["OPTIONS"].get("db", settings_dict["NAME"]),
        )
        user = settings_dict["OPTIONS"].get("user", settings_dict["USER"])
        password = settings_dict["OPTIONS"].get(
            "password",
            settings_dict["OPTIONS"].get("passwd", settings_dict["PASSWORD"]),
        )
        host = settings_dict["OPTIONS"].get("host", settings_dict["HOST"])
        port = settings_dict["OPTIONS"].get("port", settings_dict["PORT"])
        server_ca = settings_dict["OPTIONS"].get("ssl", {}).get("ca")
        client_cert = settings_dict["OPTIONS"].get("ssl", {}).get("cert")
        client_key = settings_dict["OPTIONS"].get("ssl", {}).get("key")
        defaults_file = settings_dict["OPTIONS"].get("read_default_file")
        charset = settings_dict["OPTIONS"].get("charset")
        # There seems to be no good way to set sql_mode with the CLI.

        if defaults_file:
            args += ["--defaults-file=%s" % defaults_file]
        if user:
            args += ["--user=%s" % user]
        if password:
            # The MYSQL_PWD environment variable usage is discouraged per
            # MySQL's documentation due to the possibility of exposure through
            # `ps` on old Unix flavors but --password suffers from the same
            # flaw on even more systems. Usage of an environment variable also
            # prevents password exposure if the subprocess.run(check=True) call
            # raises a CalledProcessError since the string representation of
            # the latter includes all of the provided `args`.
            env = {"MYSQL_PWD": password}
        if host:
            if "/" in host:
                args += ["--socket=%s" % host]
            else:
                args += ["--host=%s" % host]
        if port:
            args += ["--port=%s" % port]
        if server_ca:
            args += ["--ssl-ca=%s" % server_ca]
        if client_cert:
            args += ["--ssl-cert=%s" % client_cert]
        if client_key:
            args += ["--ssl-key=%s" % client_key]
        if charset:
            args += ["--default-character-set=%s" % charset]
        if database:
            args += [database]
        args.extend(parameters)
        return args, env
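
    # Editor's note: a hedged usage example (values hypothetical). For
    #     settings_dict = {"NAME": "app", "USER": "app", "PASSWORD": "s3cret",
    #                      "HOST": "db.example.com", "PORT": "3306",
    #                      "OPTIONS": {}}
    # the classmethod returns
    #     args == ["mysql", "--user=app", "--host=db.example.com",
    #              "--port=3306", "app"]
    #     env == {"MYSQL_PWD": "s3cret"}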

    def runshell(self, parameters):
        sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            # Allow SIGINT to pass to mysql to abort queries.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            # Restore the original SIGINT handler.
            signal.signal(signal.SIGINT, sigint_handler)
@@ -0,0 +1,72 @@
from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Col
from django.db.models.sql.compiler import SQLAggregateCompiler, SQLCompiler
from django.db.models.sql.compiler import SQLDeleteCompiler as BaseSQLDeleteCompiler
from django.db.models.sql.compiler import SQLInsertCompiler
from django.db.models.sql.compiler import SQLUpdateCompiler as BaseSQLUpdateCompiler

__all__ = [
    "SQLAggregateCompiler",
    "SQLCompiler",
    "SQLDeleteCompiler",
    "SQLInsertCompiler",
    "SQLUpdateCompiler",
]


class SQLDeleteCompiler(BaseSQLDeleteCompiler):
    def as_sql(self):
        # Prefer the non-standard DELETE FROM syntax over the SQL generated by
        # the SQLDeleteCompiler's default implementation when multiple tables
        # are involved since MySQL/MariaDB will generate a more efficient
        # query plan than when using a subquery.
        where, having, qualify = self.query.where.split_having_qualify(
            must_group_by=self.query.group_by is not None
        )
        if self.single_alias or having or qualify:
            # DELETE FROM cannot be used when filtering against aggregates or
            # window functions as it doesn't allow for GROUP BY/HAVING clauses
            # and the subquery wrapping (necessary to emulate QUALIFY).
            return super().as_sql()
        result = [
            "DELETE %s FROM"
            % self.quote_name_unless_alias(self.query.get_initial_alias())
        ]
        from_sql, params = self.get_from_clause()
        result.extend(from_sql)
        try:
            where_sql, where_params = self.compile(where)
        except FullResultSet:
            pass
        else:
            result.append("WHERE %s" % where_sql)
            params.extend(where_params)
        return " ".join(result), tuple(params)


class SQLUpdateCompiler(BaseSQLUpdateCompiler):
    def as_sql(self):
        update_query, update_params = super().as_sql()
        # MySQL and MariaDB support UPDATE ... ORDER BY syntax.
        if self.query.order_by:
            order_by_sql = []
            order_by_params = []
            db_table = self.query.get_meta().db_table
            try:
                for resolved, (sql, params, _) in self.get_order_by():
                    if (
                        isinstance(resolved.expression, Col)
                        and resolved.expression.alias != db_table
                    ):
                        # Ignore ordering if it contains joined fields,
                        # because they cannot be used in the ORDER BY clause.
                        raise FieldError
                    order_by_sql.append(sql)
                    order_by_params.extend(params)
                update_query += " ORDER BY " + ", ".join(order_by_sql)
                update_params += tuple(order_by_params)
            except FieldError:
                # Ignore ordering if it contains annotations, because they're
                # removed in .update() and cannot be resolved.
                pass
        return update_query, update_params
@@ -0,0 +1,87 @@
import os
import subprocess
import sys

from django.db.backends.base.creation import BaseDatabaseCreation

from .client import DatabaseClient


class DatabaseCreation(BaseDatabaseCreation):
    def sql_table_creation_suffix(self):
        suffix = []
        test_settings = self.connection.settings_dict["TEST"]
        if test_settings["CHARSET"]:
            suffix.append("CHARACTER SET %s" % test_settings["CHARSET"])
        if test_settings["COLLATION"]:
            suffix.append("COLLATE %s" % test_settings["COLLATION"])
        return " ".join(suffix)

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        try:
            super()._execute_create_test_db(cursor, parameters, keepdb)
        except Exception as e:
            if len(e.args) < 1 or e.args[0] != 1007:
                # All errors except "database exists" (1007) cancel tests.
                self.log("Got an error creating the test database: %s" % e)
                sys.exit(2)
            else:
                raise

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        source_database_name = self.connection.settings_dict["NAME"]
        target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
        test_db_params = {
            "dbname": self.connection.ops.quote_name(target_database_name),
            "suffix": self.sql_table_creation_suffix(),
        }
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception:
                if keepdb:
                    # If the database should be kept, skip everything else.
                    return
                try:
                    if verbosity >= 1:
                        self.log(
                            "Destroying old test database for alias %s..."
                            % (
                                self._get_database_display_str(
                                    verbosity, target_database_name
                                ),
                            )
                        )
                    cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                    self._execute_create_test_db(cursor, test_db_params, keepdb)
                except Exception as e:
                    self.log("Got an error recreating the test database: %s" % e)
                    sys.exit(2)
        self._clone_db(source_database_name, target_database_name)

    def _clone_db(self, source_database_name, target_database_name):
        cmd_args, cmd_env = DatabaseClient.settings_to_cmd_args_env(
            self.connection.settings_dict, []
        )
        dump_cmd = [
            "mysqldump",
            *cmd_args[1:-1],
            "--routines",
            "--events",
            source_database_name,
        ]
        dump_env = load_env = {**os.environ, **cmd_env} if cmd_env else None
        load_cmd = cmd_args
        load_cmd[-1] = target_database_name

        with subprocess.Popen(
            dump_cmd, stdout=subprocess.PIPE, env=dump_env
        ) as dump_proc:
            with subprocess.Popen(
                load_cmd,
                stdin=dump_proc.stdout,
                stdout=subprocess.DEVNULL,
                env=load_env,
            ):
                # Allow dump_proc to receive a SIGPIPE if the load process exits.
                dump_proc.stdout.close()
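

# Editor's note: the same dump-and-pipe pattern in isolation, as a runnable
# sketch (binary names and databases hypothetical; requires the MySQL client
# tools on PATH):
#
#     import subprocess
#
#     dump = subprocess.Popen(
#         ["mysqldump", "--user=app", "source_db"], stdout=subprocess.PIPE
#     )
#     load = subprocess.Popen(
#         ["mysql", "--user=app", "target_db"],
#         stdin=dump.stdout,
#         stdout=subprocess.DEVNULL,
#     )
#     dump.stdout.close()  # let mysqldump see SIGPIPE if the loader dies
#     load.wait()
#     dump.wait()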
@@ -0,0 +1,324 @@
import operator

from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property


class DatabaseFeatures(BaseDatabaseFeatures):
    empty_fetchmany_value = ()
    related_fields_match_type = True
    # MySQL doesn't support sliced subqueries with IN/ALL/ANY/SOME.
    allow_sliced_subqueries_with_in = False
    has_select_for_update = True
    has_select_for_update_nowait = True
    supports_forward_references = False
    supports_regex_backreferencing = False
    supports_date_lookup_using_string = False
    supports_timezones = False
    requires_explicit_null_ordering_when_grouping = True
    atomic_transactions = False
    can_clone_databases = True
    supports_comments = True
    supports_comments_inline = True
    supports_temporal_subtraction = True
    supports_slicing_ordering_in_compound = True
    supports_index_on_text_field = False
    supports_over_clause = True
    supports_frame_range_fixed_distance = True
    supports_update_conflicts = True
    delete_can_self_reference_subquery = False
    create_test_procedure_without_params_sql = """
        CREATE PROCEDURE test_procedure ()
        BEGIN
            DECLARE V_I INTEGER;
            SET V_I = 1;
        END;
    """
    create_test_procedure_with_int_param_sql = """
        CREATE PROCEDURE test_procedure (P_I INTEGER)
        BEGIN
            DECLARE V_I INTEGER;
            SET V_I = P_I;
        END;
    """
    create_test_table_with_composite_primary_key = """
        CREATE TABLE test_table_composite_pk (
            column_1 INTEGER NOT NULL,
            column_2 INTEGER NOT NULL,
            PRIMARY KEY(column_1, column_2)
        )
    """
    # Neither MySQL nor MariaDB supports partial indexes.
    supports_partial_indexes = False
    # COLLATE must be wrapped in parentheses because MySQL treats COLLATE as an
    # indexed expression.
    collate_as_index_expression = True
    insert_test_table_with_defaults = "INSERT INTO {} () VALUES ()"

    supports_order_by_nulls_modifier = False
    order_by_nulls_first = True
    supports_logical_xor = True

    supports_stored_generated_columns = True
    supports_virtual_generated_columns = True

    @cached_property
    def minimum_database_version(self):
        if self.connection.mysql_is_mariadb:
            return (10, 5)
        else:
            return (8, 0, 11)
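
    # Editor's note (schematic, not Django internals verbatim): at connect
    # time the backend compares get_database_version() against this tuple,
    # so e.g. MariaDB 10.4 yields (10, 4, ...) < (10, 5) and is rejected as
    # unsupported.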

    @cached_property
    def test_collations(self):
        return {
            "ci": "utf8mb4_general_ci",
            "non_default": "utf8mb4_esperanto_ci",
            "swedish_ci": "utf8mb4_swedish_ci",
            "virtual": "utf8mb4_esperanto_ci",
        }

    test_now_utc_template = "UTC_TIMESTAMP(6)"

    @cached_property
    def django_test_skips(self):
        skips = {
            "This doesn't work on MySQL.": {
                "db_functions.comparison.test_greatest.GreatestTests."
                "test_coalesce_workaround",
                "db_functions.comparison.test_least.LeastTests."
                "test_coalesce_workaround",
            },
            "MySQL doesn't support functional indexes on a function that "
            "returns JSON": {
                "schema.tests.SchemaTests.test_func_index_json_key_transform",
            },
            "MySQL supports multiplying and dividing DurationFields by a "
            "scalar value but it's not implemented (#25287).": {
                "expressions.tests.FTimeDeltaTests.test_durationfield_multiply_divide",
            },
            "UPDATE ... ORDER BY syntax on MySQL/MariaDB does not support "
            "ordering by related fields.": {
                "update.tests.AdvancedTests."
                "test_update_ordered_by_inline_m2m_annotation",
                "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation",
                "update.tests.AdvancedTests.test_update_ordered_by_m2m_annotation_desc",
            },
        }
        if self.connection.mysql_is_mariadb and (
            self.connection.mysql_version < (10, 5, 2)
        ):
            skips.update(
                {
                    "https://jira.mariadb.org/browse/MDEV-19598": {
                        "schema.tests.SchemaTests."
                        "test_alter_not_unique_field_to_primary_key",
                    },
                }
            )
        if not self.supports_explain_analyze:
            skips.update(
                {
                    "MariaDB and MySQL >= 8.0.18 specific.": {
                        "queries.test_explain.ExplainTests.test_mysql_analyze",
                    },
                }
            )
        if "ONLY_FULL_GROUP_BY" in self.connection.sql_mode:
            skips.update(
                {
                    "GROUP BY cannot contain nonaggregated column when "
                    "ONLY_FULL_GROUP_BY mode is enabled on MySQL, see #34262.": {
                        "aggregation.tests.AggregateTestCase."
                        "test_group_by_nested_expression_with_params",
                    },
                }
            )
        if self.connection.mysql_version < (8, 0, 31):
            skips.update(
                {
                    "Nesting of UNIONs at the right-hand side is not supported on "
                    "MySQL < 8.0.31": {
                        "queries.test_qs_combinators.QuerySetSetOperationTests."
                        "test_union_nested"
                    },
                }
            )
        if not self.connection.mysql_is_mariadb:
            skips.update(
                {
                    "MySQL doesn't allow renaming columns referenced by generated "
                    "columns": {
                        "migrations.test_operations.OperationTests."
                        "test_invalid_generated_field_changes_on_rename_stored",
                        "migrations.test_operations.OperationTests."
                        "test_invalid_generated_field_changes_on_rename_virtual",
                    },
                }
            )
        return skips

    @cached_property
    def _mysql_storage_engine(self):
        "Internal property used in Django tests. Don't rely on this from your code."
        return self.connection.mysql_server_data["default_storage_engine"]

    @cached_property
    def allows_auto_pk_0(self):
        """
        Autoincrement primary key can be set to 0 if it doesn't generate new
        autoincrement values.
        """
        return "NO_AUTO_VALUE_ON_ZERO" in self.connection.sql_mode

    @cached_property
    def update_can_self_select(self):
        return self.connection.mysql_is_mariadb

    @cached_property
    def can_introspect_foreign_keys(self):
        "Confirm support for introspected foreign keys."
        return self._mysql_storage_engine != "MyISAM"

    @cached_property
    def introspected_field_types(self):
        return {
            **super().introspected_field_types,
            "BinaryField": "TextField",
            "BooleanField": "IntegerField",
            "DurationField": "BigIntegerField",
            "GenericIPAddressField": "CharField",
        }

    @cached_property
    def can_return_columns_from_insert(self):
        return self.connection.mysql_is_mariadb

    can_return_rows_from_bulk_insert = property(
        operator.attrgetter("can_return_columns_from_insert")
    )

    @cached_property
    def has_zoneinfo_database(self):
        return self.connection.mysql_server_data["has_zoneinfo_database"]

    @cached_property
    def is_sql_auto_is_null_enabled(self):
        return self.connection.mysql_server_data["sql_auto_is_null"]

    @cached_property
    def supports_column_check_constraints(self):
        if self.connection.mysql_is_mariadb:
            return True
        return self.connection.mysql_version >= (8, 0, 16)

    supports_table_check_constraints = property(
        operator.attrgetter("supports_column_check_constraints")
    )

    @cached_property
    def can_introspect_check_constraints(self):
        if self.connection.mysql_is_mariadb:
            return True
        return self.connection.mysql_version >= (8, 0, 16)

    @cached_property
    def has_select_for_update_skip_locked(self):
        if self.connection.mysql_is_mariadb:
            return self.connection.mysql_version >= (10, 6)
        return True

    @cached_property
    def has_select_for_update_of(self):
        return not self.connection.mysql_is_mariadb

    @cached_property
    def supports_explain_analyze(self):
        return self.connection.mysql_is_mariadb or self.connection.mysql_version >= (
            8,
            0,
            18,
        )

    @cached_property
    def supported_explain_formats(self):
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other
        # backends.
        formats = {"JSON", "TEXT", "TRADITIONAL"}
        if not self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
            8,
            0,
            16,
        ):
            formats.add("TREE")
        return formats

    @cached_property
    def supports_transactions(self):
        """
        All storage engines except MyISAM support transactions.
        """
        return self._mysql_storage_engine != "MyISAM"

    @cached_property
    def ignores_table_name_case(self):
        return self.connection.mysql_server_data["lower_case_table_names"]

    @cached_property
    def supports_default_in_lead_lag(self):
        # To be added in https://jira.mariadb.org/browse/MDEV-12981.
        return not self.connection.mysql_is_mariadb

    @cached_property
    def can_introspect_json_field(self):
        if self.connection.mysql_is_mariadb:
            return self.can_introspect_check_constraints
        return True

    @cached_property
    def supports_index_column_ordering(self):
        if self._mysql_storage_engine != "InnoDB":
            return False
        if self.connection.mysql_is_mariadb:
            return self.connection.mysql_version >= (10, 8)
        return True

    @cached_property
    def supports_expression_indexes(self):
        return (
            not self.connection.mysql_is_mariadb
            and self._mysql_storage_engine != "MyISAM"
            and self.connection.mysql_version >= (8, 0, 13)
        )

    @cached_property
    def supports_select_intersection(self):
        is_mariadb = self.connection.mysql_is_mariadb
        return is_mariadb or self.connection.mysql_version >= (8, 0, 31)

    supports_select_difference = property(
        operator.attrgetter("supports_select_intersection")
    )

    @cached_property
    def can_rename_index(self):
        if self.connection.mysql_is_mariadb:
            return self.connection.mysql_version >= (10, 5, 2)
        return True

    @cached_property
    def supports_expression_defaults(self):
        if self.connection.mysql_is_mariadb:
            return True
        return self.connection.mysql_version >= (8, 0, 13)

    @cached_property
    def has_native_uuid_field(self):
        is_mariadb = self.connection.mysql_is_mariadb
        return is_mariadb and self.connection.mysql_version >= (10, 7)

    @cached_property
    def allows_group_by_selected_pks(self):
        if self.connection.mysql_is_mariadb:
            return "ONLY_FULL_GROUP_BY" not in self.connection.sql_mode
        return True
@@ -0,0 +1,358 @@
from collections import namedtuple

import sqlparse
from MySQLdb.constants import FIELD_TYPE

from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
from django.db.models import Index
from django.utils.datastructures import OrderedSet

FieldInfo = namedtuple(
    "FieldInfo",
    BaseFieldInfo._fields
    + ("extra", "is_unsigned", "has_json_constraint", "comment", "data_type"),
)
InfoLine = namedtuple(
    "InfoLine",
    "col_name data_type max_len num_prec num_scale extra column_default "
    "collation is_unsigned comment",
)
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))


class DatabaseIntrospection(BaseDatabaseIntrospection):
    data_types_reverse = {
        FIELD_TYPE.BLOB: "TextField",
        FIELD_TYPE.CHAR: "CharField",
        FIELD_TYPE.DECIMAL: "DecimalField",
        FIELD_TYPE.NEWDECIMAL: "DecimalField",
        FIELD_TYPE.DATE: "DateField",
        FIELD_TYPE.DATETIME: "DateTimeField",
        FIELD_TYPE.DOUBLE: "FloatField",
        FIELD_TYPE.FLOAT: "FloatField",
        FIELD_TYPE.INT24: "IntegerField",
        FIELD_TYPE.JSON: "JSONField",
        FIELD_TYPE.LONG: "IntegerField",
        FIELD_TYPE.LONGLONG: "BigIntegerField",
        FIELD_TYPE.SHORT: "SmallIntegerField",
        FIELD_TYPE.STRING: "CharField",
        FIELD_TYPE.TIME: "TimeField",
        FIELD_TYPE.TIMESTAMP: "DateTimeField",
        FIELD_TYPE.TINY: "IntegerField",
        FIELD_TYPE.TINY_BLOB: "TextField",
        FIELD_TYPE.MEDIUM_BLOB: "TextField",
        FIELD_TYPE.LONG_BLOB: "TextField",
        FIELD_TYPE.VAR_STRING: "CharField",
    }

    def get_field_type(self, data_type, description):
        field_type = super().get_field_type(data_type, description)
        if "auto_increment" in description.extra:
            if field_type == "IntegerField":
                return "AutoField"
            elif field_type == "BigIntegerField":
                return "BigAutoField"
            elif field_type == "SmallIntegerField":
                return "SmallAutoField"
        if description.is_unsigned:
            if field_type == "BigIntegerField":
                return "PositiveBigIntegerField"
            elif field_type == "IntegerField":
                return "PositiveIntegerField"
            elif field_type == "SmallIntegerField":
                return "PositiveSmallIntegerField"
        if description.data_type.upper() == "UUID":
            return "UUIDField"
        # JSON data type is an alias for LONGTEXT in MariaDB, use check
        # constraints clauses to introspect JSONField.
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        cursor.execute(
            """
            SELECT
                table_name,
                table_type,
                table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            """
        )
        return [
            TableInfo(row[0], {"BASE TABLE": "t", "VIEW": "v"}.get(row[1]), row[2])
            for row in cursor.fetchall()
        ]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        json_constraints = {}
        if (
            self.connection.mysql_is_mariadb
            and self.connection.features.can_introspect_json_field
        ):
            # JSON data type is an alias for LONGTEXT in MariaDB, select
            # JSON_VALID() constraints to introspect JSONField.
            cursor.execute(
                """
                SELECT c.constraint_name AS column_name
                FROM information_schema.check_constraints AS c
                WHERE
                    c.table_name = %s AND
                    LOWER(c.check_clause) =
                        'json_valid(`' + LOWER(c.constraint_name) + '`)' AND
                    c.constraint_schema = DATABASE()
                """,
                [table_name],
            )
            json_constraints = {row[0] for row in cursor.fetchall()}
        # A default collation for the given table.
        cursor.execute(
            """
            SELECT table_collation
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
                AND table_name = %s
            """,
            [table_name],
        )
        row = cursor.fetchone()
        default_column_collation = row[0] if row else ""
        # The information_schema database gives more accurate results for
        # some figures:
        # - varchar length returned by cursor.description is an internal
        #   length, not visible length (#5725)
        # - precision and scale (for decimal fields) (#5014)
        # - auto_increment is not available in cursor.description
        cursor.execute(
            """
            SELECT
                column_name, data_type, character_maximum_length,
                numeric_precision, numeric_scale, extra, column_default,
                CASE
                    WHEN collation_name = %s THEN NULL
                    ELSE collation_name
                END AS collation_name,
                CASE
                    WHEN column_type LIKE '%% unsigned' THEN 1
                    ELSE 0
                END AS is_unsigned,
                column_comment
            FROM information_schema.columns
            WHERE table_name = %s AND table_schema = DATABASE()
            """,
            [default_column_collation, table_name],
        )
        field_info = {line[0]: InfoLine(*line) for line in cursor.fetchall()}

        cursor.execute(
            "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
        )

        def to_int(i):
            return int(i) if i is not None else i

        fields = []
        for line in cursor.description:
            info = field_info[line[0]]
            fields.append(
                FieldInfo(
                    *line[:2],
                    to_int(info.max_len) or line[2],
                    to_int(info.max_len) or line[3],
                    to_int(info.num_prec) or line[4],
                    to_int(info.num_scale) or line[5],
                    line[6],
                    info.column_default,
                    info.collation,
                    info.extra,
                    info.is_unsigned,
                    line[0] in json_constraints,
                    info.comment,
                    info.data_type,
                )
            )
        return fields

    def get_sequences(self, cursor, table_name, table_fields=()):
        for field_info in self.get_table_description(cursor, table_name):
            if "auto_increment" in field_info.extra:
                # MySQL allows only one auto-increment column per table.
                return [{"table": table_name, "column": field_info.name}]
        return []

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            """
            SELECT column_name, referenced_column_name, referenced_table_name
            FROM information_schema.key_column_usage
            WHERE table_name = %s
                AND table_schema = DATABASE()
                AND referenced_table_schema = DATABASE()
                AND referenced_table_name IS NOT NULL
                AND referenced_column_name IS NOT NULL
            """,
            [table_name],
        )
        return {
            field_name: (other_field, other_table)
            for field_name, other_field, other_table in cursor.fetchall()
        }

    def get_storage_engine(self, cursor, table_name):
        """
        Retrieve the storage engine for a given table. Return the default
        storage engine if the table doesn't exist.
        """
        cursor.execute(
            """
            SELECT engine
            FROM information_schema.tables
            WHERE
                table_name = %s AND
                table_schema = DATABASE()
            """,
            [table_name],
        )
        result = cursor.fetchone()
        if not result:
            return self.connection.features._mysql_storage_engine
        return result[0]

    def _parse_constraint_columns(self, check_clause, columns):
        check_columns = OrderedSet()
        statement = sqlparse.parse(check_clause)[0]
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        for token in tokens:
            if (
                token.ttype == sqlparse.tokens.Name
                and self.connection.ops.quote_name(token.value) == token.value
                and token.value[1:-1] in columns
            ):
                check_columns.add(token.value[1:-1])
        return check_columns
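
    # Editor's note: a standalone illustration of the token scan above
    # (behavior as observed with sqlparse; the clause is hypothetical):
    #
    #     import sqlparse
    #
    #     clause = "json_valid(`data`)"
    #     quoted = [
    #         t.value[1:-1]
    #         for t in sqlparse.parse(clause)[0].flatten()
    #         if t.ttype == sqlparse.tokens.Name and t.value.startswith("`")
    #     ]
    #     # quoted == ["data"]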

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.
        """
        constraints = {}
        # Get the actual constraint names and columns.
        name_query = """
            SELECT kc.`constraint_name`, kc.`column_name`,
                kc.`referenced_table_name`, kc.`referenced_column_name`,
                c.`constraint_type`
            FROM
                information_schema.key_column_usage AS kc,
                information_schema.table_constraints AS c
            WHERE
                kc.table_schema = DATABASE() AND
                (
                    kc.referenced_table_schema = DATABASE() OR
                    kc.referenced_table_schema IS NULL
                ) AND
                c.table_schema = kc.table_schema AND
                c.constraint_name = kc.constraint_name AND
                c.constraint_type != 'CHECK' AND
                kc.table_name = %s
            ORDER BY kc.`ordinal_position`
        """
        cursor.execute(name_query, [table_name])
        for constraint, column, ref_table, ref_column, kind in cursor.fetchall():
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": OrderedSet(),
                    "primary_key": kind == "PRIMARY KEY",
                    "unique": kind in {"PRIMARY KEY", "UNIQUE"},
                    "index": False,
                    "check": False,
                    "foreign_key": (ref_table, ref_column) if ref_column else None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[constraint]["orders"] = []
            constraints[constraint]["columns"].add(column)
        # Add check constraints.
        if self.connection.features.can_introspect_check_constraints:
            unnamed_constraints_index = 0
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            if self.connection.mysql_is_mariadb:
                type_query = """
                    SELECT c.constraint_name, c.check_clause
                    FROM information_schema.check_constraints AS c
                    WHERE
                        c.constraint_schema = DATABASE() AND
                        c.table_name = %s
                """
            else:
                type_query = """
                    SELECT cc.constraint_name, cc.check_clause
                    FROM
                        information_schema.check_constraints AS cc,
                        information_schema.table_constraints AS tc
                    WHERE
                        cc.constraint_schema = DATABASE() AND
                        tc.table_schema = cc.constraint_schema AND
                        cc.constraint_name = tc.constraint_name AND
                        tc.constraint_type = 'CHECK' AND
                        tc.table_name = %s
                """
            cursor.execute(type_query, [table_name])
            for constraint, check_clause in cursor.fetchall():
                constraint_columns = self._parse_constraint_columns(
                    check_clause, columns
                )
                # Ensure uniqueness of unnamed constraints. Unnamed unique
                # and check columns constraints have the same name as
                # a column.
                if set(constraint_columns) == {constraint}:
                    unnamed_constraints_index += 1
                    constraint = "__unnamed_constraint_%s__" % unnamed_constraints_index
                constraints[constraint] = {
                    "columns": constraint_columns,
                    "primary_key": False,
                    "unique": False,
                    "index": False,
                    "check": True,
                    "foreign_key": None,
                }
        # Now add in the indexes.
        cursor.execute(
            "SHOW INDEX FROM %s" % self.connection.ops.quote_name(table_name)
        )
        for table, non_unique, index, colseq, column, order, type_ in [
            x[:6] + (x[10],) for x in cursor.fetchall()
        ]:
            if index not in constraints:
                constraints[index] = {
                    "columns": OrderedSet(),
                    "primary_key": False,
                    "unique": not non_unique,
                    "check": False,
                    "foreign_key": None,
                }
                if self.connection.features.supports_index_column_ordering:
                    constraints[index]["orders"] = []
            constraints[index]["index"] = True
            constraints[index]["type"] = (
                Index.suffix if type_ == "BTREE" else type_.lower()
            )
            constraints[index]["columns"].add(column)
            if self.connection.features.supports_index_column_ordering:
                constraints[index]["orders"].append("DESC" if order == "D" else "ASC")
        # Convert the sorted sets to lists.
        for constraint in constraints.values():
            constraint["columns"] = list(constraint["columns"])
        return constraints
@@ -0,0 +1,455 @@
import uuid

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta
from django.db.models import Exists, ExpressionWrapper, Lookup
from django.db.models.constants import OnConflict
from django.utils import timezone
from django.utils.encoding import force_str
from django.utils.regex_helper import _lazy_re_compile


class DatabaseOperations(BaseDatabaseOperations):
    compiler_module = "django.db.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        "PositiveSmallIntegerField": (0, 65535),
        "PositiveIntegerField": (0, 4294967295),
        "PositiveBigIntegerField": (0, 18446744073709551615),
    }
    cast_data_types = {
        "AutoField": "signed integer",
        "BigAutoField": "signed integer",
        "SmallAutoField": "signed integer",
        "CharField": "char(%(max_length)s)",
        "DecimalField": "decimal(%(max_digits)s, %(decimal_places)s)",
        "TextField": "char",
        "IntegerField": "signed integer",
        "BigIntegerField": "signed integer",
        "SmallIntegerField": "signed integer",
        "PositiveBigIntegerField": "unsigned integer",
        "PositiveIntegerField": "unsigned integer",
        "PositiveSmallIntegerField": "unsigned integer",
        "DurationField": "signed integer",
    }
    cast_char_field_without_max_length = "char"
    explain_prefix = "EXPLAIN"

    # EXTRACT format cannot be passed in parameters.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(self, lookup_type, sql, params):
        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
        if lookup_type == "week_day":
            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
            return f"DAYOFWEEK({sql})", params
        elif lookup_type == "iso_week_day":
            # WEEKDAY() returns an integer, 0-6, Monday=0.
            return f"WEEKDAY({sql}) + 1", params
        elif lookup_type == "week":
            # Override the value of default_week_format for consistency with
            # other database backends.
            # Mode 3: Monday, 1-53, with 4 or more days this year.
            return f"WEEK({sql}, 3)", params
        elif lookup_type == "iso_year":
            # Get the year part from the YEARWEEK function, which returns a
            # number as year * 100 + week.
            return f"TRUNCATE(YEARWEEK({sql}, 3), -2) / 100", params
        else:
            # EXTRACT returns 1-53 based on ISO-8601 for the week number.
            lookup_type = lookup_type.upper()
            if not self._extract_format_re.fullmatch(lookup_type):
                raise ValueError(f"Invalid lookup type: {lookup_type!r}")
            return f"EXTRACT({lookup_type} FROM {sql})", params

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "year": "%Y-01-01",
            "month": "%Y-%m-01",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATE)", (*params, format_str)
        elif lookup_type == "quarter":
            return (
                f"MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - INTERVAL 1 QUARTER",
                (*params, *params),
            )
        elif lookup_type == "week":
            return f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY)", (*params, *params)
        else:
            return f"DATE({sql})", params

    def _prepare_tzname_delta(self, tzname):
        tzname, sign, offset = split_tzname_delta(tzname)
        return f"{sign}{offset}" if offset else tzname

    def _convert_sql_to_tz(self, sql, params, tzname):
        if tzname and settings.USE_TZ and self.connection.timezone_name != tzname:
            return f"CONVERT_TZ({sql}, %s, %s)", (
                *params,
                self.connection.timezone_name,
                self._prepare_tzname_delta(tzname),
            )
        return sql, params

    def datetime_cast_date_sql(self, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"DATE({sql})", params

    def datetime_cast_time_sql(self, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"TIME({sql})", params

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return self.date_extract_sql(lookup_type, sql, params)

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = ["year", "month", "day", "hour", "minute", "second"]
        format = ("%Y-", "%m", "-%d", " %H:", "%i", ":%s")
        format_def = ("0000-", "01", "-01", " 00:", "00", ":00")
        if lookup_type == "quarter":
            return (
                f"CAST(DATE_FORMAT(MAKEDATE(YEAR({sql}), 1) + "
                f"INTERVAL QUARTER({sql}) QUARTER - "
                f"INTERVAL 1 QUARTER, %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-01 00:00:00")
        if lookup_type == "week":
            return (
                f"CAST(DATE_FORMAT("
                f"DATE_SUB({sql}, INTERVAL WEEKDAY({sql}) DAY), %s) AS DATETIME)"
            ), (*params, *params, "%Y-%m-%d 00:00:00")
        try:
            i = fields.index(lookup_type) + 1
        except ValueError:
            pass
        else:
            format_str = "".join(format[:i] + format_def[i:])
            return f"CAST(DATE_FORMAT({sql}, %s) AS DATETIME)", (*params, format_str)
        return sql, params

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        fields = {
            "hour": "%H:00:00",
            "minute": "%H:%i:00",
            "second": "%H:%i:%s",
        }
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return f"CAST(DATE_FORMAT({sql}, %s) AS TIME)", (*params, format_str)
        else:
            return f"TIME({sql})", params

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the tuple of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        return "INTERVAL %s MICROSECOND" % sql

    def force_no_ordering(self):
        """
        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], False))]

    def last_executed_query(self, cursor, sql, params):
        # With MySQLdb, cursor objects have an (undocumented) "_executed"
        # attribute where the exact query sent to the database is saved.
        # See MySQLdb/cursors.py in the source distribution.
        # MySQLdb returns string, PyMySQL bytes.
        return force_str(getattr(cursor, "_executed", None), errors="replace")

    def no_limit_value(self):
        # 2**64 - 1, as recommended by the MySQL documentation.
        return 18446744073709551615

    def quote_name(self, name):
        if name.startswith("`") and name.endswith("`"):
            return name  # Quoting once is enough.
        return "`%s`" % name

    def return_insert_columns(self, fields):
        # MySQL doesn't support an INSERT...RETURNING statement.
        if not fields:
            return "", ()
        columns = [
            "%s.%s"
            % (
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            )
            for field in fields
        ]
        return "RETURNING %s" % ", ".join(columns), ()

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        if not tables:
            return []

        sql = ["SET FOREIGN_KEY_CHECKS = 0;"]
        if reset_sequences:
            # It's faster to TRUNCATE tables that require a sequence reset
            # since ALTER TABLE AUTO_INCREMENT is slower than TRUNCATE.
            sql.extend(
                "%s %s;"
                % (
                    style.SQL_KEYWORD("TRUNCATE"),
                    style.SQL_FIELD(self.quote_name(table_name)),
                )
                for table_name in tables
            )
        else:
            # Otherwise issue a simple DELETE since it's faster than TRUNCATE
            # and preserves sequences.
            sql.extend(
                "%s %s %s;"
                % (
                    style.SQL_KEYWORD("DELETE"),
                    style.SQL_KEYWORD("FROM"),
                    style.SQL_FIELD(self.quote_name(table_name)),
                )
                for table_name in tables
            )
        sql.append("SET FOREIGN_KEY_CHECKS = 1;")
        return sql
|
||||
|
||||
def sequence_reset_by_name_sql(self, style, sequences):
|
||||
return [
|
||||
"%s %s %s %s = 1;"
|
||||
% (
|
||||
style.SQL_KEYWORD("ALTER"),
|
||||
style.SQL_KEYWORD("TABLE"),
|
||||
style.SQL_FIELD(self.quote_name(sequence_info["table"])),
|
||||
style.SQL_FIELD("AUTO_INCREMENT"),
|
||||
)
|
||||
for sequence_info in sequences
|
||||
]
|
||||
|
||||
def validate_autopk_value(self, value):
|
||||
# Zero in AUTO_INCREMENT field does not work without the
|
||||
# NO_AUTO_VALUE_ON_ZERO SQL mode.
|
||||
if value == 0 and not self.connection.features.allows_auto_pk_0:
|
||||
raise ValueError(
|
||||
"The database backend does not accept 0 as a value for AutoField."
|
||||
)
|
||||
return value
|
||||
|
||||
def adapt_datetimefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware datetimes
|
||||
if timezone.is_aware(value):
|
||||
if settings.USE_TZ:
|
||||
value = timezone.make_naive(value, self.connection.timezone)
|
||||
else:
|
||||
raise ValueError(
|
||||
"MySQL backend does not support timezone-aware datetimes when "
|
||||
"USE_TZ is False."
|
||||
)
|
||||
return str(value)
|
||||
|
||||
def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Expression values are adapted by the database.
|
||||
if hasattr(value, "resolve_expression"):
|
||||
return value
|
||||
|
||||
# MySQL doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("MySQL backend does not support timezone-aware times.")
|
||||
|
||||
return value.isoformat(timespec="microseconds")
|
||||
|
||||
def max_name_length(self):
|
||||
return 64
|
||||
|
||||
def pk_default_value(self):
|
||||
return "NULL"
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
if connector == "^":
|
||||
return "POW(%s)" % ",".join(sub_expressions)
|
||||
# Convert the result to a signed integer since MySQL's binary operators
|
||||
# return an unsigned integer.
|
||||
elif connector in ("&", "|", "<<", "#"):
|
||||
connector = "^" if connector == "#" else connector
|
||||
return "CONVERT(%s, SIGNED)" % connector.join(sub_expressions)
|
||||
elif connector == ">>":
|
||||
lhs, rhs = sub_expressions
|
||||
return "FLOOR(%(lhs)s / POW(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
|
||||
def get_db_converters(self, expression):
|
||||
converters = super().get_db_converters(expression)
|
||||
internal_type = expression.output_field.get_internal_type()
|
||||
if internal_type == "BooleanField":
|
||||
converters.append(self.convert_booleanfield_value)
|
||||
elif internal_type == "DateTimeField":
|
||||
if settings.USE_TZ:
|
||||
converters.append(self.convert_datetimefield_value)
|
||||
elif internal_type == "UUIDField":
|
||||
converters.append(self.convert_uuidfield_value)
|
||||
return converters
|
||||
|
||||
def convert_booleanfield_value(self, value, expression, connection):
|
||||
if value in (0, 1):
|
||||
value = bool(value)
|
||||
return value
|
||||
|
||||
def convert_datetimefield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = timezone.make_aware(value, self.connection.timezone)
|
||||
return value
|
||||
|
||||
def convert_uuidfield_value(self, value, expression, connection):
|
||||
if value is not None:
|
||||
value = uuid.UUID(value)
|
||||
return value
|
||||
|
||||
def binary_placeholder_sql(self, value):
|
||||
return (
|
||||
"_binary %s" if value is not None and not hasattr(value, "as_sql") else "%s"
|
||||
)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
if internal_type == "TimeField":
|
||||
if self.connection.mysql_is_mariadb:
|
||||
# MariaDB includes the microsecond component in TIME_TO_SEC as
|
||||
# a decimal. MySQL returns an integer without microseconds.
|
||||
return (
|
||||
"CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) "
|
||||
"* 1000000 AS SIGNED)"
|
||||
) % {
|
||||
"lhs": lhs_sql,
|
||||
"rhs": rhs_sql,
|
||||
}, (
|
||||
*lhs_params,
|
||||
*rhs_params,
|
||||
)
|
||||
return (
|
||||
"((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
|
||||
" (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
|
||||
) % {"lhs": lhs_sql, "rhs": rhs_sql}, tuple(lhs_params) * 2 + tuple(
|
||||
rhs_params
|
||||
) * 2
|
||||
params = (*rhs_params, *lhs_params)
|
||||
return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), params
|
||||
|
||||
def explain_query_prefix(self, format=None, **options):
|
||||
# Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
|
||||
if format and format.upper() == "TEXT":
|
||||
format = "TRADITIONAL"
|
||||
elif (
|
||||
not format and "TREE" in self.connection.features.supported_explain_formats
|
||||
):
|
||||
# Use TREE by default (if supported) as it's more informative.
|
||||
format = "TREE"
|
||||
analyze = options.pop("analyze", False)
|
||||
prefix = super().explain_query_prefix(format, **options)
|
||||
if analyze and self.connection.features.supports_explain_analyze:
|
||||
# MariaDB uses ANALYZE instead of EXPLAIN ANALYZE.
|
||||
prefix = (
|
||||
"ANALYZE" if self.connection.mysql_is_mariadb else prefix + " ANALYZE"
|
||||
)
|
||||
if format and not (analyze and not self.connection.mysql_is_mariadb):
|
||||
# Only MariaDB supports the analyze option with formats.
|
||||
prefix += " FORMAT=%s" % format
|
||||
return prefix
|
||||
|
||||
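    # A short sketch of how this prefix logic surfaces through the ORM.
    # ``Book`` is a hypothetical model and a configured MySQL/MariaDB
    # connection is assumed; the SQL in the comments traces the branches of
    # explain_query_prefix() above.
    #
    #     plan = Book.objects.filter(title__startswith="A").explain()
    #     # -> "EXPLAIN FORMAT=TREE ..." when the server supports TREE.
    #     plan = Book.objects.all().explain(format="JSON", analyze=True)
    #     # -> "ANALYZE FORMAT=JSON ..." on MariaDB; on MySQL the format is
    #     #    dropped (only MariaDB combines ANALYZE with a format) and
    #     #    "EXPLAIN ANALYZE ..." runs instead.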
    def regex_lookup(self, lookup_type):
        # REGEXP_LIKE doesn't exist in MariaDB.
        if self.connection.mysql_is_mariadb:
            if lookup_type == "regex":
                return "%s REGEXP BINARY %s"
            return "%s REGEXP %s"

        match_option = "c" if lookup_type == "regex" else "i"
        return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option

    def insert_statement(self, on_conflict=None):
        if on_conflict == OnConflict.IGNORE:
            return "INSERT IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def lookup_cast(self, lookup_type, internal_type=None):
        lookup = "%s"
        if internal_type == "JSONField":
            if self.connection.mysql_is_mariadb or lookup_type in (
                "iexact",
                "contains",
                "icontains",
                "startswith",
                "istartswith",
                "endswith",
                "iendswith",
                "regex",
                "iregex",
            ):
                lookup = "JSON_UNQUOTE(%s)"
        return lookup

    def conditional_expression_supported_in_where_clause(self, expression):
        # MySQL ignores indexes with boolean fields unless they're compared
        # directly to a boolean value.
        if isinstance(expression, (Exists, Lookup)):
            return True
        if isinstance(expression, ExpressionWrapper) and expression.conditional:
            return self.conditional_expression_supported_in_where_clause(
                expression.expression
            )
        if getattr(expression, "conditional", False):
            return False
        return super().conditional_expression_supported_in_where_clause(expression)

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        if on_conflict == OnConflict.UPDATE:
            conflict_suffix_sql = "ON DUPLICATE KEY UPDATE %(fields)s"
            # The use of VALUES() is deprecated in MySQL 8.0.20+. Instead, use
            # aliases for the new row and its columns available in MySQL
            # 8.0.19+.
            if not self.connection.mysql_is_mariadb:
                if self.connection.mysql_version >= (8, 0, 19):
                    conflict_suffix_sql = f"AS new {conflict_suffix_sql}"
                    field_sql = "%(field)s = new.%(field)s"
                else:
                    field_sql = "%(field)s = VALUES(%(field)s)"
            # Use VALUE() on MariaDB.
            else:
                field_sql = "%(field)s = VALUE(%(field)s)"

            fields = ", ".join(
                [
                    field_sql % {"field": field}
                    for field in map(self.quote_name, update_fields)
                ]
            )
            return conflict_suffix_sql % {"fields": fields}
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )
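    # A minimal sketch of how the two hooks above combine for upserts; this
    # is the path QuerySet.bulk_create() takes with update_conflicts=True.
    # A live MySQL/MariaDB connection and hypothetical ``title``/``isbn``
    # columns are assumed; the fragments in the comments follow the branches
    # in the code above.
    #
    #     from django.db import connection
    #     from django.db.models.constants import OnConflict
    #
    #     ops = connection.ops
    #     ops.insert_statement(on_conflict=OnConflict.IGNORE)
    #     # -> "INSERT IGNORE INTO"
    #     ops.on_conflict_suffix_sql(
    #         fields=[],
    #         on_conflict=OnConflict.UPDATE,
    #         update_fields=["title"],
    #         unique_fields=["isbn"],
    #     )
    #     # MySQL >= 8.0.19:
    #     #   "AS new ON DUPLICATE KEY UPDATE `title` = new.`title`"
    #     # MariaDB:
    #     #   "ON DUPLICATE KEY UPDATE `title` = VALUE(`title`)"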
@@ -0,0 +1,281 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import NOT_PROVIDED, F, UniqueConstraint
from django.db.models.constants import LOOKUP_SEP


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    sql_rename_table = "RENAME TABLE %(old_table)s TO %(new_table)s"

    sql_alter_column_null = "MODIFY %(column)s %(type)s NULL"
    sql_alter_column_not_null = "MODIFY %(column)s %(type)s NOT NULL"
    sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s%(comment)s"
    sql_alter_column_no_default_null = "ALTER COLUMN %(column)s SET DEFAULT NULL"

    # No 'CASCADE' which works as a no-op in MySQL but is undocumented
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"

    sql_delete_unique = "ALTER TABLE %(table)s DROP INDEX %(name)s"
    sql_create_column_inline_fk = (
        ", ADD CONSTRAINT %(name)s FOREIGN KEY (%(column)s) "
        "REFERENCES %(to_table)s(%(to_column)s)"
    )
    sql_delete_fk = "ALTER TABLE %(table)s DROP FOREIGN KEY %(name)s"

    sql_delete_index = "DROP INDEX %(name)s ON %(table)s"
    sql_rename_index = "ALTER TABLE %(table)s RENAME INDEX %(old_name)s TO %(new_name)s"

    sql_create_pk = (
        "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s PRIMARY KEY (%(columns)s)"
    )
    sql_delete_pk = "ALTER TABLE %(table)s DROP PRIMARY KEY"

    sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"

    sql_alter_table_comment = "ALTER TABLE %(table)s COMMENT = %(comment)s"
    sql_alter_column_comment = None

    @property
    def sql_delete_check(self):
        if self.connection.mysql_is_mariadb:
            # The name of the column check constraint is the same as the field
            # name on MariaDB. Adding an IF EXISTS clause prevents migrations
            # from crashing. The constraint is removed during a "MODIFY"
            # column statement.
            return "ALTER TABLE %(table)s DROP CONSTRAINT IF EXISTS %(name)s"
        return "ALTER TABLE %(table)s DROP CHECK %(name)s"

    @property
    def sql_rename_column(self):
        is_mariadb = self.connection.mysql_is_mariadb
        if is_mariadb and self.connection.mysql_version < (10, 5, 2):
            # MariaDB < 10.5.2 doesn't support an
            # "ALTER TABLE ... RENAME COLUMN" statement.
            return "ALTER TABLE %(table)s CHANGE %(old_column)s %(new_column)s %(type)s"
        return super().sql_rename_column

    def quote_value(self, value):
        self.connection.ensure_connection()
        # MySQLdb escapes to string, PyMySQL to bytes.
        quoted = self.connection.connection.escape(
            value, self.connection.connection.encoders
        )
        if isinstance(value, str) and isinstance(quoted, bytes):
            quoted = quoted.decode()
        return quoted

    def _is_limited_data_type(self, field):
        db_type = field.db_type(self.connection)
        return (
            db_type is not None
            and db_type.lower() in self.connection._limited_data_types
        )

    def _is_text_or_blob(self, field):
        db_type = field.db_type(self.connection)
        return db_type and db_type.lower().endswith(("blob", "text"))

    def skip_default(self, field):
        default_is_empty = self.effective_default(field) in ("", b"")
        if default_is_empty and self._is_text_or_blob(field):
            return True
        if not self._supports_limited_data_type_defaults:
            return self._is_limited_data_type(field)
        return False

    def skip_default_on_alter(self, field):
        default_is_empty = self.effective_default(field) in ("", b"")
        if default_is_empty and self._is_text_or_blob(field):
            return True
        if self._is_limited_data_type(field) and not self.connection.mysql_is_mariadb:
            # MySQL doesn't support defaults for BLOB and TEXT in the
            # ALTER COLUMN statement.
            return True
        return False

    @property
    def _supports_limited_data_type_defaults(self):
        # MariaDB and MySQL >= 8.0.13 support defaults for BLOB and TEXT.
        if self.connection.mysql_is_mariadb:
            return True
        return self.connection.mysql_version >= (8, 0, 13)

    def _column_default_sql(self, field):
        if (
            not self.connection.mysql_is_mariadb
            and self._supports_limited_data_type_defaults
            and self._is_limited_data_type(field)
        ):
            # MySQL supports defaults for BLOB and TEXT columns only if the
            # default value is written as an expression, i.e. in parentheses.
            return "(%s)"
        return super()._column_default_sql(field)

    def add_field(self, model, field):
        super().add_field(model, field)

        # Simulate the effect of a one-off default.
        # field.default may be unhashable, so a set isn't used for "in" check.
        if self.skip_default(field) and field.default not in (None, NOT_PROVIDED):
            effective_default = self.effective_default(field)
            self.execute(
                "UPDATE %(table)s SET %(column)s = %%s"
                % {
                    "table": self.quote_name(model._meta.db_table),
                    "column": self.quote_name(field.column),
                },
                [effective_default],
            )

    def remove_constraint(self, model, constraint):
        if (
            isinstance(constraint, UniqueConstraint)
            and constraint.create_sql(model, self) is not None
        ):
            self._create_missing_fk_index(
                model,
                fields=constraint.fields,
                expressions=constraint.expressions,
            )
        super().remove_constraint(model, constraint)

    def remove_index(self, model, index):
        self._create_missing_fk_index(
            model,
            fields=[field_name for field_name, _ in index.fields_orders],
            expressions=index.expressions,
        )
        super().remove_index(model, index)

    def _field_should_be_indexed(self, model, field):
        if not super()._field_should_be_indexed(model, field):
            return False

        storage = self.connection.introspection.get_storage_engine(
            self.connection.cursor(), model._meta.db_table
        )
        # No need to create an index for ForeignKey fields except if
        # db_constraint=False because the index from that constraint won't be
        # created.
        if (
            storage == "InnoDB"
            and field.get_internal_type() == "ForeignKey"
            and field.db_constraint
        ):
            return False
        return not self._is_limited_data_type(field)

    def _create_missing_fk_index(
        self,
        model,
        *,
        fields,
        expressions=None,
    ):
        """
        MySQL can remove an implicit FK index on a field when that field is
        covered by another index like a unique_together. "Covered" here means
        that the more complex index has the FK field as its first field (see
        https://bugs.mysql.com/bug.php?id=37910).

        Manually create an implicit FK index to make it possible to remove
        the composed index.
        """
        first_field_name = None
        if fields:
            first_field_name = fields[0]
        elif (
            expressions
            and self.connection.features.supports_expression_indexes
            and isinstance(expressions[0], F)
            and LOOKUP_SEP not in expressions[0].name
        ):
            first_field_name = expressions[0].name

        if not first_field_name:
            return

        first_field = model._meta.get_field(first_field_name)
        if first_field.get_internal_type() == "ForeignKey":
            column = self.connection.introspection.identifier_converter(
                first_field.column
            )
            with self.connection.cursor() as cursor:
                constraint_names = [
                    name
                    for name, infodict in self.connection.introspection.get_constraints(
                        cursor, model._meta.db_table
                    ).items()
                    if infodict["index"] and infodict["columns"][0] == column
                ]
                # There are no other indexes that start with the FK field,
                # only the index that is expected to be deleted.
                if len(constraint_names) == 1:
                    self.execute(
                        self._create_index_sql(model, fields=[first_field], suffix="")
                    )

    def _delete_composed_index(self, model, fields, *args):
        self._create_missing_fk_index(model, fields=fields)
        return super()._delete_composed_index(model, fields, *args)

    def _set_field_new_type(self, field, new_type):
        """
        Keep the NULL and DEFAULT properties of the old field. If it has
        changed, it will be handled separately.
        """
        if field.has_db_default():
            default_sql, params = self.db_default_sql(field)
            default_sql %= tuple(self.quote_value(p) for p in params)
            new_type += f" DEFAULT {default_sql}"
        if field.null:
            new_type += " NULL"
        else:
            new_type += " NOT NULL"
        return new_type

    def _alter_column_type_sql(
        self, model, old_field, new_field, new_type, old_collation, new_collation
    ):
        new_type = self._set_field_new_type(old_field, new_type)
        return super()._alter_column_type_sql(
            model, old_field, new_field, new_type, old_collation, new_collation
        )

    def _field_db_check(self, field, field_db_params):
        if self.connection.mysql_is_mariadb and self.connection.mysql_version >= (
            10,
            5,
            2,
        ):
            return super()._field_db_check(field, field_db_params)
        # On MySQL and MariaDB < 10.5.2 (no support for
        # "ALTER TABLE ... RENAME COLUMN" statements), check constraints with
        # the column name as it requires explicit recreation when the column
        # is renamed.
        return field_db_params["check"]

    def _rename_field_sql(self, table, old_field, new_field, new_type):
        new_type = self._set_field_new_type(old_field, new_type)
        return super()._rename_field_sql(table, old_field, new_field, new_type)

    def _alter_column_comment_sql(self, model, new_field, new_type, new_db_comment):
        # The comment is altered when altering the column type.
        return "", []

    def _comment_sql(self, comment):
        comment_sql = super()._comment_sql(comment)
        return f" COMMENT {comment_sql}"

    def _alter_column_null_sql(self, model, old_field, new_field):
        if not new_field.has_db_default():
            return super()._alter_column_null_sql(model, old_field, new_field)

        new_db_params = new_field.db_parameters(connection=self.connection)
        type_sql = self._set_field_new_type(new_field, new_db_params["type"])
        return (
            "MODIFY %(column)s %(type)s"
            % {
                "column": self.quote_name(new_field.column),
                "type": type_sql,
            },
            [],
        )
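# A hedged sketch of the situation _create_missing_fk_index() guards against,
# with hypothetical models (assumed to live in an installed app). Dropping the
# composed index directly would fail on MySQL/InnoDB because it doubles as the
# FK's implicit index (https://bugs.mysql.com/bug.php?id=37910); the schema
# editor first recreates a plain index on the FK column, then drops the
# composed one.
#
#     from django.db import models
#
#     class Author(models.Model):
#         name = models.CharField(max_length=100)
#
#     class Book(models.Model):
#         # ``author`` is the first column of the composed index below, so
#         # MySQL reuses that index for the FK constraint instead of
#         # creating its own.
#         author = models.ForeignKey(Author, on_delete=models.CASCADE)
#         title = models.CharField(max_length=100)
#
#         class Meta:
#             # A later migration removing this unique_together triggers
#             # _delete_composed_index(), which calls
#             # _create_missing_fk_index() before dropping the index.
#             unique_together = [("author", "title")]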
@@ -0,0 +1,77 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation
from django.utils.version import get_docs_version


class DatabaseValidation(BaseDatabaseValidation):
    def check(self, **kwargs):
        issues = super().check(**kwargs)
        issues.extend(self._check_sql_mode(**kwargs))
        return issues

    def _check_sql_mode(self, **kwargs):
        if not (
            self.connection.sql_mode & {"STRICT_TRANS_TABLES", "STRICT_ALL_TABLES"}
        ):
            return [
                checks.Warning(
                    "%s Strict Mode is not set for database connection '%s'"
                    % (self.connection.display_name, self.connection.alias),
                    hint=(
                        "%s's Strict Mode fixes many data integrity problems in "
                        "%s, such as data truncation upon insertion, by "
                        "escalating warnings into errors. It is strongly "
                        "recommended you activate it. See: "
                        "https://docs.djangoproject.com/en/%s/ref/databases/"
                        "#mysql-sql-mode"
                        % (
                            self.connection.display_name,
                            self.connection.display_name,
                            get_docs_version(),
                        )
                    ),
                    id="mysql.W002",
                )
            ]
        return []

    def check_field_type(self, field, field_type):
        """
        MySQL has the following field length restriction:
        No character (varchar) fields can have a length exceeding 255
        characters if they have a unique index on them.
        MySQL doesn't support a database index on some data types.
        """
        errors = []
        if (
            field_type.startswith("varchar")
            and field.unique
            and (field.max_length is None or int(field.max_length) > 255)
        ):
            errors.append(
                checks.Warning(
                    "%s may not allow unique CharFields to have a max_length "
                    "> 255." % self.connection.display_name,
                    obj=field,
                    hint=(
                        "See: https://docs.djangoproject.com/en/%s/ref/"
                        "databases/#mysql-character-fields" % get_docs_version()
                    ),
                    id="mysql.W003",
                )
            )

        if field.db_index and field_type.lower() in self.connection._limited_data_types:
            errors.append(
                checks.Warning(
                    "%s does not support a database index on %s columns."
                    % (self.connection.display_name, field_type),
                    hint=(
                        "An index won't be created. Silence this warning if "
                        "you don't care about it."
                    ),
                    obj=field,
                    id="fields.W162",
                )
            )
        return errors
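# One common way to satisfy the mysql.W002 check above, per the Django
# database notes: enable strict mode for the session from settings. The
# connection values below are placeholders; only the OPTIONS entry matters.
#
#     DATABASES = {
#         "default": {
#             "ENGINE": "django.db.backends.mysql",
#             "NAME": "mydb",        # placeholder
#             "USER": "myuser",      # placeholder
#             "PASSWORD": "secret",  # placeholder
#             "OPTIONS": {
#                 "init_command": "SET sql_mode='STRICT_TRANS_TABLES'",
#             },
#         }
#     }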
@@ -0,0 +1,668 @@
"""
Oracle database backend for Django.

Requires oracledb: https://oracle.github.io/python-oracledb/
"""

import datetime
import decimal
import os
import platform
from contextlib import contextmanager

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.oracle.oracledb_any import is_oracledb
from django.db.backends.utils import debug_transaction
from django.utils.asyncio import async_unsafe
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
from django.utils.version import get_version_tuple

try:
    from django.db.backends.oracle.oracledb_any import oracledb as Database
except ImportError as e:
    raise ImproperlyConfigured(f"Error loading oracledb module: {e}")


def _setup_environment(environ):
    # Cygwin requires some special voodoo to set the environment variables
    # properly so that Oracle will see them.
    if platform.system().upper().startswith("CYGWIN"):
        try:
            import ctypes
        except ImportError as e:
            raise ImproperlyConfigured(
                "Error loading ctypes: %s; "
                "the Oracle backend requires ctypes to "
                "operate correctly under Cygwin." % e
            )
        kernel32 = ctypes.CDLL("kernel32")
        for name, value in environ:
            kernel32.SetEnvironmentVariableA(name, value)
    else:
        os.environ.update(environ)


_setup_environment(
    [
        # Oracle takes client-side character set encoding from the environment.
        ("NLS_LANG", ".AL32UTF8"),
        # This prevents Unicode from getting mangled by getting encoded into the
        # potentially non-Unicode database character set.
        ("ORA_NCHAR_LITERAL_REPLACE", "TRUE"),
    ]
)


# Some of these import oracledb, so import them after checking if it's
# installed.
from .client import DatabaseClient  # NOQA
from .creation import DatabaseCreation  # NOQA
from .features import DatabaseFeatures  # NOQA
from .introspection import DatabaseIntrospection  # NOQA
from .operations import DatabaseOperations  # NOQA
from .schema import DatabaseSchemaEditor  # NOQA
from .utils import Oracle_datetime, dsn  # NOQA
from .validation import DatabaseValidation  # NOQA


@contextmanager
def wrap_oracle_errors():
    try:
        yield
    except Database.DatabaseError as e:
        # oracledb raises an oracledb.DatabaseError exception with the
        # following attributes and values:
        #   code = 2091
        #   message = 'ORA-02091: transaction rolled back
        #       'ORA-02291: integrity constraint (TEST_DJANGOTEST.SYS
        #       _C00102056) violated - parent key not found'
        # or:
        #       'ORA-00001: unique constraint (DJANGOTEST.DEFERRABLE_
        #       PINK_CONSTRAINT) violated
        # Convert that case to Django's IntegrityError exception.
        x = e.args[0]
        if (
            hasattr(x, "code")
            and hasattr(x, "message")
            and x.code == 2091
            and ("ORA-02291" in x.message or "ORA-00001" in x.message)
        ):
            raise IntegrityError(*tuple(e.args))
        raise
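# To make the matching above concrete, a self-contained sketch with a
# stand-in error object (the attribute names mirror what the comment above
# describes; the constraint name is made up):
#
#     class _FakeOracleError:
#         code = 2091
#         message = (
#             "ORA-02091: transaction rolled back\n"
#             "ORA-02291: integrity constraint (APP.FK_CHILD_PARENT)"
#             " violated - parent key not found"
#         )
#
#     err = _FakeOracleError()
#     # The same test wrap_oracle_errors() applies before re-raising as
#     # IntegrityError:
#     assert (
#         hasattr(err, "code")
#         and hasattr(err, "message")
#         and err.code == 2091
#         and ("ORA-02291" in err.message or "ORA-00001" in err.message)
#     )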
class _UninitializedOperatorsDescriptor:
    def __get__(self, instance, cls=None):
        # If connection.operators is looked up before a connection has been
        # created, transparently initialize connection.operators to avert an
        # AttributeError.
        if instance is None:
            raise AttributeError("operators not available as class attribute")
        # Creating a cursor will initialize the operators.
        instance.cursor().close()
        return instance.__dict__["operators"]


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = "oracle"
    display_name = "Oracle"
    # This dictionary maps Field objects to their associated Oracle column
    # types, as strings. Column-type strings can contain format strings;
    # they'll be interpolated against the values of Field.__dict__ before
    # being output. If a column type is set to None, it won't be included in
    # the output.
    #
    # Any format strings starting with "qn_" are quoted before being used in
    # the output (the "qn_" prefix is stripped before the lookup is
    # performed).
    data_types = {
        "AutoField": "NUMBER(11) GENERATED BY DEFAULT ON NULL AS IDENTITY",
        "BigAutoField": "NUMBER(19) GENERATED BY DEFAULT ON NULL AS IDENTITY",
        "BinaryField": "BLOB",
        "BooleanField": "NUMBER(1)",
        "CharField": "NVARCHAR2(%(max_length)s)",
        "DateField": "DATE",
        "DateTimeField": "TIMESTAMP",
        "DecimalField": "NUMBER(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "INTERVAL DAY(9) TO SECOND(6)",
        "FileField": "NVARCHAR2(%(max_length)s)",
        "FilePathField": "NVARCHAR2(%(max_length)s)",
        "FloatField": "DOUBLE PRECISION",
        "IntegerField": "NUMBER(11)",
        "JSONField": "NCLOB",
        "BigIntegerField": "NUMBER(19)",
        "IPAddressField": "VARCHAR2(15)",
        "GenericIPAddressField": "VARCHAR2(39)",
        "OneToOneField": "NUMBER(11)",
        "PositiveBigIntegerField": "NUMBER(19)",
        "PositiveIntegerField": "NUMBER(11)",
        "PositiveSmallIntegerField": "NUMBER(11)",
        "SlugField": "NVARCHAR2(%(max_length)s)",
        "SmallAutoField": "NUMBER(5) GENERATED BY DEFAULT ON NULL AS IDENTITY",
        "SmallIntegerField": "NUMBER(11)",
        "TextField": "NCLOB",
        "TimeField": "TIMESTAMP",
        "URLField": "VARCHAR2(%(max_length)s)",
        "UUIDField": "VARCHAR2(32)",
    }
    data_type_check_constraints = {
        "BooleanField": "%(qn_column)s IN (0,1)",
        "JSONField": "%(qn_column)s IS JSON",
        "PositiveBigIntegerField": "%(qn_column)s >= 0",
        "PositiveIntegerField": "%(qn_column)s >= 0",
        "PositiveSmallIntegerField": "%(qn_column)s >= 0",
    }

    # Oracle doesn't support a database index on these columns.
    _limited_data_types = ("clob", "nclob", "blob")

    operators = _UninitializedOperatorsDescriptor()

    _standard_operators = {
        "exact": "= %s",
        "iexact": "= UPPER(%s)",
        "contains": (
            "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        ),
        "icontains": (
            "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
            "ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        ),
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": (
            "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        ),
        "endswith": (
            "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        ),
        "istartswith": (
            "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
            "ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        ),
        "iendswith": (
            "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) "
            "ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        ),
    }

    _likec_operators = {
        **_standard_operators,
        "contains": "LIKEC %s ESCAPE '\\'",
        "icontains": "LIKEC UPPER(%s) ESCAPE '\\'",
        "startswith": "LIKEC %s ESCAPE '\\'",
        "endswith": "LIKEC %s ESCAPE '\\'",
        "istartswith": "LIKEC UPPER(%s) ESCAPE '\\'",
        "iendswith": "LIKEC UPPER(%s) ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an
    # expression or the result of a bilateral transformation). In those cases,
    # special characters for LIKE operators (e.g. \, %, _) should be escaped
    # on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a
    # wildcard for the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    _pattern_ops = {
        "contains": "'%%' || {} || '%%'",
        "icontains": "'%%' || UPPER({}) || '%%'",
        "startswith": "{} || '%%'",
        "istartswith": "UPPER({}) || '%%'",
        "endswith": "'%%' || {}",
        "iendswith": "'%%' || UPPER({})",
    }

    _standard_pattern_ops = {
        k: "LIKE TRANSLATE( " + v + " USING NCHAR_CS)"
        " ESCAPE TRANSLATE('\\' USING NCHAR_CS)"
        for k, v in _pattern_ops.items()
    }
    _likec_pattern_ops = {
        k: "LIKEC " + v + " ESCAPE '\\'" for k, v in _pattern_ops.items()
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    validation_class = DatabaseValidation
    _connection_pools = {}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        use_returning_into = self.settings_dict["OPTIONS"].get(
            "use_returning_into", True
        )
        self.features.can_return_columns_from_insert = use_returning_into

    @property
    def is_pool(self):
        return self.settings_dict["OPTIONS"].get("pool", False)

    @property
    def pool(self):
        if not self.is_pool:
            return None

        if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
            raise ImproperlyConfigured(
                "Pooling doesn't support persistent connections."
            )

        pool_key = (self.alias, self.settings_dict["USER"])
        if pool_key not in self._connection_pools:
            connect_kwargs = self.get_connection_params()
            pool_options = connect_kwargs.pop("pool")
            if pool_options is not True:
                connect_kwargs.update(pool_options)

            pool = Database.create_pool(
                user=self.settings_dict["USER"],
                password=self.settings_dict["PASSWORD"],
                dsn=dsn(self.settings_dict),
                **connect_kwargs,
            )
            self._connection_pools.setdefault(pool_key, pool)

        return self._connection_pools[pool_key]

    def close_pool(self):
        if self.pool:
            self.pool.close(force=True)
            pool_key = (self.alias, self.settings_dict["USER"])
            del self._connection_pools[pool_key]

    def get_database_version(self):
        return self.oracle_version

    def get_connection_params(self):
        # The pooling feature is only supported by oracledb.
        if self.is_pool and not is_oracledb:
            raise ImproperlyConfigured(
                "Pooling isn't supported by cx_Oracle. Use python-oracledb instead."
            )
        conn_params = self.settings_dict["OPTIONS"].copy()
        if "use_returning_into" in conn_params:
            del conn_params["use_returning_into"]
        return conn_params

    @async_unsafe
    def get_new_connection(self, conn_params):
        if self.pool:
            return self.pool.acquire()
        return Database.connect(
            user=self.settings_dict["USER"],
            password=self.settings_dict["PASSWORD"],
            dsn=dsn(self.settings_dict),
            **conn_params,
        )

    def init_connection_state(self):
        super().init_connection_state()
        cursor = self.create_cursor()
        # Set the territory first. The territory overrides NLS_DATE_FORMAT
        # and NLS_TIMESTAMP_FORMAT to the territory default. When all of
        # these are set in a single statement it isn't clear what is supposed
        # to happen.
        cursor.execute("ALTER SESSION SET NLS_TERRITORY = 'AMERICA'")
        # Set Oracle date to ANSI date format. This only needs to execute
        # once when we create a new connection. We also set the Territory
        # to 'AMERICA' which forces Sunday to evaluate to a '1' in
        # TO_CHAR().
        cursor.execute(
            "ALTER SESSION SET NLS_DATE_FORMAT = 'YYYY-MM-DD HH24:MI:SS'"
            " NLS_TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF'"
            + (" TIME_ZONE = 'UTC'" if settings.USE_TZ else "")
        )
        cursor.close()
        if "operators" not in self.__dict__:
            # Ticket #14149: Check whether our LIKE implementation will
            # work for this connection or we need to fall back on LIKEC.
            # This check is performed only once per DatabaseWrapper
            # instance per thread, since subsequent connections will use
            # the same settings.
            cursor = self.create_cursor()
            try:
                cursor.execute(
                    "SELECT 1 FROM DUAL WHERE DUMMY %s"
                    % self._standard_operators["contains"],
                    ["X"],
                )
            except Database.DatabaseError:
                self.operators = self._likec_operators
                self.pattern_ops = self._likec_pattern_ops
            else:
                self.operators = self._standard_operators
                self.pattern_ops = self._standard_pattern_ops
            cursor.close()
        self.connection.stmtcachesize = 20
        # Ensure all changes are preserved even when AUTOCOMMIT is False.
        if not self.get_autocommit():
            self.commit()

    @async_unsafe
    def create_cursor(self, name=None):
        return FormatStylePlaceholderCursor(self.connection, self)

    def _commit(self):
        if self.connection is not None:
            with debug_transaction(self, "COMMIT"), wrap_oracle_errors():
                return self.connection.commit()

    # Oracle doesn't support releasing savepoints. But we fake them when query
    # logging is enabled to keep query counts consistent with other backends.
    def _savepoint_commit(self, sid):
        if self.queries_logged:
            self.queries_log.append(
                {
                    "sql": "-- RELEASE SAVEPOINT %s (faked)" % self.ops.quote_name(sid),
                    "time": "0.000",
                }
            )

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to
        deferred afterward.
        """
        with self.cursor() as cursor:
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")

    def is_usable(self):
        try:
            self.connection.ping()
        except Database.Error:
            return False
        else:
            return True

    def close_if_health_check_failed(self):
        if self.pool:
            # The pool only returns healthy connections.
            return
        return super().close_if_health_check_failed()

    @cached_property
    def oracle_version(self):
        with self.temporary_connection():
            return tuple(int(x) for x in self.connection.version.split("."))

    @cached_property
    def oracledb_version(self):
        return get_version_tuple(Database.__version__)


class OracleParam:
    """
    Wrapper object for formatting parameters for Oracle. If the string
    representation of the value is large enough (greater than 4000 characters)
    the input size needs to be set as CLOB. Alternatively, if the parameter
    has an `input_size` attribute, then the value of the `input_size`
    attribute will be used instead. Otherwise, no input size will be set for
    the parameter when executing the query.
    """

    def __init__(self, param, cursor, strings_only=False):
        # With raw SQL queries, datetimes can reach this function
        # without being converted by DateTimeField.get_db_prep_value.
        if settings.USE_TZ and (
            isinstance(param, datetime.datetime)
            and not isinstance(param, Oracle_datetime)
        ):
            param = Oracle_datetime.from_datetime(param)

        string_size = 0
        has_boolean_data_type = (
            cursor.database.features.supports_boolean_expr_in_select_clause
        )
        if not has_boolean_data_type:
            # Oracle < 23c doesn't recognize True and False correctly.
            if param is True:
                param = 1
            elif param is False:
                param = 0
        if hasattr(param, "bind_parameter"):
            self.force_bytes = param.bind_parameter(cursor)
        elif isinstance(param, (Database.Binary, datetime.timedelta)):
            self.force_bytes = param
        else:
            # To transmit to the database, we need Unicode if supported.
            # To get the size right, we must consider bytes.
            self.force_bytes = force_str(param, cursor.charset, strings_only)
            if isinstance(self.force_bytes, str):
                # We could optimize by only converting up to 4000 bytes here
                string_size = len(force_bytes(param, cursor.charset, strings_only))
        if hasattr(param, "input_size"):
            # If the parameter has an `input_size` attribute, use that.
            self.input_size = param.input_size
        elif string_size > 4000:
            # Mark any string param greater than 4000 characters as a CLOB.
            self.input_size = Database.DB_TYPE_CLOB
        elif isinstance(param, datetime.datetime):
            self.input_size = Database.DB_TYPE_TIMESTAMP
        elif has_boolean_data_type and isinstance(param, bool):
            self.input_size = Database.DB_TYPE_BOOLEAN
        else:
            self.input_size = None


class VariableWrapper:
    """
    An adapter class for cursor variables that prevents the wrapped object
    from being converted into a string when used to instantiate an
    OracleParam. This can be used generally for any other object that should
    be passed into Cursor.execute as-is.
    """

    def __init__(self, var):
        self.var = var

    def bind_parameter(self, cursor):
        return self.var

    def __getattr__(self, key):
        return getattr(self.var, key)

    def __setattr__(self, key, value):
        if key == "var":
            self.__dict__[key] = value
        else:
            setattr(self.var, key, value)


class FormatStylePlaceholderCursor:
    """
    Django uses "format" (e.g. '%s') style placeholders, but Oracle uses
    ":var" style. This fixes it -- but note that if you want to use a literal
    "%s" in a query, you'll need to use "%%s".
    """

    charset = "utf-8"

    def __init__(self, connection, database):
        self.cursor = connection.cursor()
        self.cursor.outputtypehandler = self._output_type_handler
        self.database = database

    @staticmethod
    def _output_number_converter(value):
        return decimal.Decimal(value) if "." in value else int(value)

    @staticmethod
    def _get_decimal_converter(precision, scale):
        if scale == 0:
            return int
        context = decimal.Context(prec=precision)
        quantize_value = decimal.Decimal(1).scaleb(-scale)
        return lambda v: decimal.Decimal(v).quantize(quantize_value, context=context)

    @staticmethod
    def _output_type_handler(cursor, name, defaultType, length, precision, scale):
        """
        Called for each db column fetched from cursors. Return numbers as the
        appropriate Python type, and NCLOB with JSON as strings.
        """
        if defaultType == Database.NUMBER:
            if scale == -127:
                if precision == 0:
                    # NUMBER column: decimal-precision floating point.
                    # This will normally be an integer from a sequence,
                    # but it could be a decimal value.
                    outconverter = FormatStylePlaceholderCursor._output_number_converter
                else:
                    # FLOAT column: binary-precision floating point.
                    # This comes from FloatField columns.
                    outconverter = float
            elif precision > 0:
                # NUMBER(p,s) column: decimal-precision fixed point.
                # This comes from IntegerField and DecimalField columns.
                outconverter = FormatStylePlaceholderCursor._get_decimal_converter(
                    precision, scale
                )
            else:
                # No type information. This normally comes from a
                # mathematical expression in the SELECT list. Guess int
                # or Decimal based on whether it has a decimal point.
                outconverter = FormatStylePlaceholderCursor._output_number_converter
            return cursor.var(
                Database.STRING,
                size=255,
                arraysize=cursor.arraysize,
                outconverter=outconverter,
            )
        # oracledb 2.0.0+ returns NCLOB columns with IS JSON constraints as
        # dicts. Use a no-op converter to avoid this.
        elif defaultType == Database.DB_TYPE_NCLOB:
            return cursor.var(Database.DB_TYPE_NCLOB, arraysize=cursor.arraysize)

    def _format_params(self, params):
        try:
            return {k: OracleParam(v, self, True) for k, v in params.items()}
        except AttributeError:
            return tuple(OracleParam(p, self, True) for p in params)

    def _guess_input_sizes(self, params_list):
        # Try dict handling; if that fails, treat as sequence
        if hasattr(params_list[0], "keys"):
            sizes = {}
            for params in params_list:
                for k, value in params.items():
                    if value.input_size:
                        sizes[k] = value.input_size
            if sizes:
                self.setinputsizes(**sizes)
        else:
            # It's not a list of dicts; it's a list of sequences
            sizes = [None] * len(params_list[0])
            for params in params_list:
                for i, value in enumerate(params):
                    if value.input_size:
                        sizes[i] = value.input_size
            if sizes:
                self.setinputsizes(*sizes)

    def _param_generator(self, params):
        # Try dict handling; if that fails, treat as sequence
        if hasattr(params, "items"):
            return {k: v.force_bytes for k, v in params.items()}
        else:
            return [p.force_bytes for p in params]

    def _fix_for_params(self, query, params, unify_by_values=False):
        # oracledb wants no trailing ';' for SQL statements. For PL/SQL, it
        # does want a trailing ';' but not a trailing '/'. However, these
        # characters must be included in the original query in case the query
        # is being passed to SQL*Plus.
        if query.endswith(";") or query.endswith("/"):
            query = query[:-1]
        if params is None:
            params = []
        elif hasattr(params, "keys"):
            # Handle params as dict
            args = {k: ":%s" % k for k in params}
            query %= args
        elif unify_by_values and params:
            # Handle params as a dict with unified query parameters by their
            # values. It can be used only in single query execute() because
            # executemany() shares the formatted query with each of the params
            # list. e.g. for input params = [0.75, 2, 0.75, 'sth', 0.75]
            # params_dict = {
            #     (float, 0.75): ':arg0',
            #     (int, 2): ':arg1',
            #     (str, 'sth'): ':arg2',
            # }
            # args = [':arg0', ':arg1', ':arg0', ':arg2', ':arg0']
            # params = {':arg0': 0.75, ':arg1': 2, ':arg2': 'sth'}
            # The type of parameters in param_types keys is necessary to avoid
            # unifying 0/1 with False/True.
            param_types = [(type(param), param) for param in params]
            params_dict = {
                param_type: ":arg%d" % i
                for i, param_type in enumerate(dict.fromkeys(param_types))
            }
            args = [params_dict[param_type] for param_type in param_types]
            params = {
                placeholder: param for (_, param), placeholder in params_dict.items()
            }
            query %= tuple(args)
        else:
            # Handle params as sequence
            args = [(":arg%d" % i) for i in range(len(params))]
            query %= tuple(args)
        return query, self._format_params(params)

    def execute(self, query, params=None):
        query, params = self._fix_for_params(query, params, unify_by_values=True)
        self._guess_input_sizes([params])
        with wrap_oracle_errors():
            return self.cursor.execute(query, self._param_generator(params))

    def executemany(self, query, params=None):
        if not params:
            # No params given, nothing to do
            return None
        # uniform treatment for sequences and iterables
        params_iter = iter(params)
        query, firstparams = self._fix_for_params(query, next(params_iter))
        # we build a list of formatted params; as we're going to traverse it
        # more than once, we can't make it lazy by using a generator
        formatted = [firstparams] + [self._format_params(p) for p in params_iter]
        self._guess_input_sizes(formatted)
        with wrap_oracle_errors():
            return self.cursor.executemany(
                query, [self._param_generator(p) for p in formatted]
            )

    def close(self):
        try:
            self.cursor.close()
        except Database.InterfaceError:
            # already closed
            pass

    def var(self, *args):
        return VariableWrapper(self.cursor.var(*args))

    def arrayvar(self, *args):
        return VariableWrapper(self.cursor.arrayvar(*args))

    def __getattr__(self, attr):
        return getattr(self.cursor, attr)

    def __iter__(self):
        return iter(self.cursor)
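# A standalone sketch of the placeholder rewriting that
# _fix_for_params(unify_by_values=True) performs above: format-style '%s'
# markers become ':argN' bind names, and values that compare equal (by type
# and value) share a single bind. Params must be hashable, as in the
# original.
#
#     def unify_placeholders(query, params):
#         param_types = [(type(p), p) for p in params]
#         params_dict = {
#             param_type: ":arg%d" % i
#             for i, param_type in enumerate(dict.fromkeys(param_types))
#         }
#         args = [params_dict[param_type] for param_type in param_types]
#         binds = {name: value for (_, value), name in params_dict.items()}
#         return query % tuple(args), binds
#
#     unify_placeholders("SELECT %s, %s, %s FROM DUAL", [0.75, 2, 0.75])
#     # -> ("SELECT :arg0, :arg1, :arg0 FROM DUAL",
#     #     {":arg0": 0.75, ":arg1": 2})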
@@ -0,0 +1,27 @@
import shutil

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = "sqlplus"
    wrapper_name = "rlwrap"

    @staticmethod
    def connect_string(settings_dict):
        from django.db.backends.oracle.utils import dsn

        return '%s/"%s"@%s' % (
            settings_dict["USER"],
            settings_dict["PASSWORD"],
            dsn(settings_dict),
        )

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name, "-L", cls.connect_string(settings_dict)]
        wrapper_path = shutil.which(cls.wrapper_name)
        if wrapper_path:
            args = [wrapper_path, *args]
        args.extend(parameters)
        return args, None
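# For illustration, the argv this client builds, with placeholder
# credentials (the exact DSN string depends on the HOST/PORT handling in
# .utils.dsn, which is not shown here):
#
#     settings_dict = {
#         "USER": "scott", "PASSWORD": "tiger",
#         "NAME": "orclpdb1", "HOST": "", "PORT": "",
#     }
#     # DatabaseClient.connect_string(settings_dict)
#     # -> 'scott/"tiger"@orclpdb1'
#     # settings_to_cmd_args_env(settings_dict, ["@script.sql"]) returns
#     # (["sqlplus", "-L", 'scott/"tiger"@orclpdb1', "@script.sql"], None),
#     # with the rlwrap binary prepended when shutil.which() finds one.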
@@ -0,0 +1,467 @@
|
||||
import sys
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError
|
||||
from django.db.backends.base.creation import BaseDatabaseCreation
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
TEST_DATABASE_PREFIX = "test_"
|
||||
|
||||
|
||||
class DatabaseCreation(BaseDatabaseCreation):
|
||||
@cached_property
|
||||
def _maindb_connection(self):
|
||||
"""
|
||||
This is analogous to other backends' `_nodb_connection` property,
|
||||
which allows access to an "administrative" connection which can
|
||||
be used to manage the test databases.
|
||||
For Oracle, the only connection that can be used for that purpose
|
||||
is the main (non-test) connection.
|
||||
"""
|
||||
settings_dict = settings.DATABASES[self.connection.alias]
|
||||
user = settings_dict.get("SAVED_USER") or settings_dict["USER"]
|
||||
password = settings_dict.get("SAVED_PASSWORD") or settings_dict["PASSWORD"]
|
||||
settings_dict = {**settings_dict, "USER": user, "PASSWORD": password}
|
||||
DatabaseWrapper = type(self.connection)
|
||||
return DatabaseWrapper(settings_dict, alias=self.connection.alias)
|
||||
|
||||
def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
|
||||
parameters = self._get_test_db_params()
|
||||
with self._maindb_connection.cursor() as cursor:
|
||||
if self._test_database_create():
|
||||
try:
|
||||
self._execute_test_db_creation(
|
||||
cursor, parameters, verbosity, keepdb
|
||||
)
|
||||
except Exception as e:
|
||||
if "ORA-01543" not in str(e):
|
||||
# All errors except "tablespace already exists" cancel tests
|
||||
self.log("Got an error creating the test database: %s" % e)
|
||||
sys.exit(2)
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"It appears the test database, %s, already exists. "
|
||||
"Type 'yes' to delete it, or 'no' to cancel: "
|
||||
% parameters["user"]
|
||||
)
|
||||
if autoclobber or confirm == "yes":
|
||||
if verbosity >= 1:
|
||||
self.log(
|
||||
"Destroying old test database for alias '%s'..."
|
||||
% self.connection.alias
|
||||
)
|
||||
try:
|
||||
self._execute_test_db_destruction(
|
||||
cursor, parameters, verbosity
|
||||
)
|
||||
except DatabaseError as e:
|
||||
if "ORA-29857" in str(e):
|
||||
self._handle_objects_preventing_db_destruction(
|
||||
cursor, parameters, verbosity, autoclobber
|
||||
)
|
||||
else:
|
||||
# Ran into a database error that isn't about
|
||||
# leftover objects in the tablespace.
|
||||
self.log(
|
||||
"Got an error destroying the old test database: %s"
|
||||
% e
|
||||
)
|
||||
sys.exit(2)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
"Got an error destroying the old test database: %s" % e
|
||||
)
|
||||
sys.exit(2)
|
||||
try:
|
||||
self._execute_test_db_creation(
|
||||
cursor, parameters, verbosity, keepdb
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
"Got an error recreating the test database: %s" % e
|
||||
)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled.")
|
||||
sys.exit(1)
|
||||
|
||||
if self._test_user_create():
|
||||
if verbosity >= 1:
|
||||
self.log("Creating test user...")
|
||||
try:
|
||||
self._create_test_user(cursor, parameters, verbosity, keepdb)
|
||||
except Exception as e:
|
||||
if "ORA-01920" not in str(e):
|
||||
# All errors except "user already exists" cancel tests
|
||||
self.log("Got an error creating the test user: %s" % e)
|
||||
sys.exit(2)
|
||||
if not autoclobber:
|
||||
confirm = input(
|
||||
"It appears the test user, %s, already exists. Type "
|
||||
"'yes' to delete it, or 'no' to cancel: "
|
||||
% parameters["user"]
|
||||
)
|
||||
if autoclobber or confirm == "yes":
|
||||
try:
|
||||
if verbosity >= 1:
|
||||
self.log("Destroying old test user...")
|
||||
self._destroy_test_user(cursor, parameters, verbosity)
|
||||
if verbosity >= 1:
|
||||
self.log("Creating test user...")
|
||||
self._create_test_user(
|
||||
cursor, parameters, verbosity, keepdb
|
||||
)
|
||||
except Exception as e:
|
||||
self.log("Got an error recreating the test user: %s" % e)
|
||||
sys.exit(2)
|
||||
else:
|
||||
self.log("Tests cancelled.")
|
||||
sys.exit(1)
|
||||
# Done with main user -- test user and tablespaces created.
|
||||
self._maindb_connection.close()
|
||||
self._switch_to_test_user(parameters)
|
||||
return self.connection.settings_dict["NAME"]
|
||||
|
||||
def _switch_to_test_user(self, parameters):
|
||||
"""
|
||||
Switch to the user that's used for creating the test database.
|
||||
|
||||
Oracle doesn't have the concept of separate databases under the same
|
||||
user, so a separate user is used; see _create_test_db(). The main user
|
||||
is also needed for cleanup when testing is completed, so save its
|
||||
credentials in the SAVED_USER/SAVED_PASSWORD key in the settings dict.
|
||||
"""
|
||||
real_settings = settings.DATABASES[self.connection.alias]
|
||||
real_settings["SAVED_USER"] = self.connection.settings_dict["SAVED_USER"] = (
|
||||
self.connection.settings_dict["USER"]
|
||||
)
|
||||
real_settings["SAVED_PASSWORD"] = self.connection.settings_dict[
|
||||
"SAVED_PASSWORD"
|
||||
] = self.connection.settings_dict["PASSWORD"]
|
||||
real_test_settings = real_settings["TEST"]
|
||||
test_settings = self.connection.settings_dict["TEST"]
|
||||
real_test_settings["USER"] = real_settings["USER"] = test_settings["USER"] = (
|
||||
self.connection.settings_dict["USER"]
|
||||
) = parameters["user"]
|
||||
real_settings["PASSWORD"] = self.connection.settings_dict["PASSWORD"] = (
|
||||
parameters["password"]
|
||||
)
|
||||
|
||||
def set_as_test_mirror(self, primary_settings_dict):
|
||||
"""
|
||||
Set this database up to be used in testing as a mirror of a primary
|
||||
database whose settings are given.
|
||||
"""
|
||||
self.connection.settings_dict["USER"] = primary_settings_dict["USER"]
|
||||
self.connection.settings_dict["PASSWORD"] = primary_settings_dict["PASSWORD"]
|
||||
|
||||
    def _handle_objects_preventing_db_destruction(
        self, cursor, parameters, verbosity, autoclobber
    ):
        # There are objects in the test tablespace which prevent dropping it.
        # The easy fix is to drop the test user -- but are we allowed to do so?
        self.log(
            "There are objects in the old test database which prevent its destruction."
            "\nIf they belong to the test user, deleting the user will allow the test "
            "database to be recreated.\n"
            "Otherwise, you will need to find and remove each of these objects, "
            "or use a different tablespace.\n"
        )
        if self._test_user_create():
            if not autoclobber:
                confirm = input("Type 'yes' to delete user %s: " % parameters["user"])
            if autoclobber or confirm == "yes":
                try:
                    if verbosity >= 1:
                        self.log("Destroying old test user...")
                    self._destroy_test_user(cursor, parameters, verbosity)
                except Exception as e:
                    self.log("Got an error destroying the test user: %s" % e)
                    sys.exit(2)
                try:
                    if verbosity >= 1:
                        self.log(
                            "Destroying old test database for alias '%s'..."
                            % self.connection.alias
                        )
                    self._execute_test_db_destruction(cursor, parameters, verbosity)
                except Exception as e:
                    self.log("Got an error destroying the test database: %s" % e)
                    sys.exit(2)
            else:
                self.log("Tests cancelled -- test database cannot be recreated.")
                sys.exit(1)
        else:
            self.log(
                "Django is configured to use pre-existing test user '%s',"
                " and will not attempt to delete it." % parameters["user"]
            )
            self.log("Tests cancelled -- test database cannot be recreated.")
            sys.exit(1)

    def _destroy_test_db(self, test_database_name, verbosity=1):
        """
        Destroy a test database: drop the test user and tablespaces created
        by _create_test_db(), restoring the saved main-user credentials first.
        """
        if not self.connection.is_pool:
            self.connection.settings_dict["USER"] = self.connection.settings_dict[
                "SAVED_USER"
            ]
            self.connection.settings_dict["PASSWORD"] = self.connection.settings_dict[
                "SAVED_PASSWORD"
            ]
        self.connection.close()
        self.connection.close_pool()
        parameters = self._get_test_db_params()
        with self._maindb_connection.cursor() as cursor:
            if self._test_user_create():
                if verbosity >= 1:
                    self.log("Destroying test user...")
                self._destroy_test_user(cursor, parameters, verbosity)
            if self._test_database_create():
                if verbosity >= 1:
                    self.log("Destroying test database tables...")
                self._execute_test_db_destruction(cursor, parameters, verbosity)
        self._maindb_connection.close()
        self._maindb_connection.close_pool()

    def _execute_test_db_creation(self, cursor, parameters, verbosity, keepdb=False):
        if verbosity >= 2:
            self.log("_create_test_db(): dbname = %s" % parameters["user"])
        if self._test_database_oracle_managed_files():
            statements = [
                """
                CREATE TABLESPACE %(tblspace)s
                DATAFILE SIZE %(size)s
                AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
                """,
                """
                CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
                TEMPFILE SIZE %(size_tmp)s
                AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
                """,
            ]
        else:
            statements = [
                """
                CREATE TABLESPACE %(tblspace)s
                DATAFILE '%(datafile)s' SIZE %(size)s REUSE
                AUTOEXTEND ON NEXT %(extsize)s MAXSIZE %(maxsize)s
                """,
                """
                CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
                TEMPFILE '%(datafile_tmp)s' SIZE %(size_tmp)s REUSE
                AUTOEXTEND ON NEXT %(extsize_tmp)s MAXSIZE %(maxsize_tmp)s
                """,
            ]
        # Ignore "tablespace already exists" error when keepdb is on.
        acceptable_ora_err = "ORA-01543" if keepdb else None
        self._execute_allow_fail_statements(
            cursor, statements, parameters, verbosity, acceptable_ora_err
        )

    def _create_test_user(self, cursor, parameters, verbosity, keepdb=False):
        if verbosity >= 2:
            self.log("_create_test_user(): username = %s" % parameters["user"])
        statements = [
            """CREATE USER %(user)s
               IDENTIFIED BY "%(password)s"
               DEFAULT TABLESPACE %(tblspace)s
               TEMPORARY TABLESPACE %(tblspace_temp)s
               QUOTA UNLIMITED ON %(tblspace)s
            """,
            """GRANT CREATE SESSION,
                     CREATE TABLE,
                     CREATE SEQUENCE,
                     CREATE PROCEDURE,
                     CREATE TRIGGER
               TO %(user)s""",
        ]
        # Ignore "user already exists" error when keepdb is on.
        acceptable_ora_err = "ORA-01920" if keepdb else None
        success = self._execute_allow_fail_statements(
            cursor, statements, parameters, verbosity, acceptable_ora_err
        )
        # If the password was randomly generated, change the user accordingly.
        if not success and self._test_settings_get("PASSWORD") is None:
            set_password = 'ALTER USER %(user)s IDENTIFIED BY "%(password)s"'
            self._execute_statements(cursor, [set_password], parameters, verbosity)
        # Most test suites can be run without "create view" and
        # "create materialized view" privileges. But some need it.
        for object_type in ("VIEW", "MATERIALIZED VIEW"):
            extra = "GRANT CREATE %(object_type)s TO %(user)s"
            parameters["object_type"] = object_type
            success = self._execute_allow_fail_statements(
                cursor, [extra], parameters, verbosity, "ORA-01031"
            )
            if not success and verbosity >= 2:
                self.log(
                    "Failed to grant CREATE %s permission to test user. This may be ok."
                    % object_type
                )

    def _execute_test_db_destruction(self, cursor, parameters, verbosity):
        if verbosity >= 2:
            self.log("_execute_test_db_destruction(): dbname=%s" % parameters["user"])
        statements = [
            "DROP TABLESPACE %(tblspace)s "
            "INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
            "DROP TABLESPACE %(tblspace_temp)s "
            "INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS",
        ]
        self._execute_statements(cursor, statements, parameters, verbosity)

    def _destroy_test_user(self, cursor, parameters, verbosity):
        if verbosity >= 2:
            self.log("_destroy_test_user(): user=%s" % parameters["user"])
            self.log("Be patient. This can take some time...")
        statements = [
            "DROP USER %(user)s CASCADE",
        ]
        self._execute_statements(cursor, statements, parameters, verbosity)

    def _execute_statements(
        self, cursor, statements, parameters, verbosity, allow_quiet_fail=False
    ):
        for template in statements:
            stmt = template % parameters
            if verbosity >= 2:
                print(stmt)
            try:
                cursor.execute(stmt)
            except Exception as err:
                if (not allow_quiet_fail) or verbosity >= 2:
                    self.log("Failed (%s)" % (err))
                raise

    def _execute_allow_fail_statements(
        self, cursor, statements, parameters, verbosity, acceptable_ora_err
    ):
        """
        Execute statements which are allowed to fail silently if the Oracle
        error code given by `acceptable_ora_err` is raised. Return True if the
        statements execute without an exception, or False otherwise.
        """
        try:
            # Statement can fail when acceptable_ora_err is not None.
            allow_quiet_fail = (
                acceptable_ora_err is not None and len(acceptable_ora_err) > 0
            )
            self._execute_statements(
                cursor,
                statements,
                parameters,
                verbosity,
                allow_quiet_fail=allow_quiet_fail,
            )
            return True
        except DatabaseError as err:
            description = str(err)
            if acceptable_ora_err is None or acceptable_ora_err not in description:
                raise
            return False

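    # Illustrative sketch (not from the commit): with keepdb, callers above
    # pass the ORA code a rerun may legitimately hit, e.g.
    #   self._execute_allow_fail_statements(
    #       cursor, statements, parameters, verbosity, "ORA-01543"
    #   )
    # returns False when only "tablespace already exists" occurred, and
    # re-raises any other DatabaseError.
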
    def _get_test_db_params(self):
        return {
            "dbname": self._test_database_name(),
            "user": self._test_database_user(),
            "password": self._test_database_passwd(),
            "tblspace": self._test_database_tblspace(),
            "tblspace_temp": self._test_database_tblspace_tmp(),
            "datafile": self._test_database_tblspace_datafile(),
            "datafile_tmp": self._test_database_tblspace_tmp_datafile(),
            "maxsize": self._test_database_tblspace_maxsize(),
            "maxsize_tmp": self._test_database_tblspace_tmp_maxsize(),
            "size": self._test_database_tblspace_size(),
            "size_tmp": self._test_database_tblspace_tmp_size(),
            "extsize": self._test_database_tblspace_extsize(),
            "extsize_tmp": self._test_database_tblspace_tmp_extsize(),
        }

    def _test_settings_get(self, key, default=None, prefixed=None):
        """
        Return a value from the test settings dict, or a given default, or a
        prefixed entry from the main settings dict.
        """
        settings_dict = self.connection.settings_dict
        val = settings_dict["TEST"].get(key, default)
        if val is None and prefixed:
            val = TEST_DATABASE_PREFIX + settings_dict[prefixed]
        return val

    def _test_database_name(self):
        return self._test_settings_get("NAME", prefixed="NAME")

    def _test_database_create(self):
        return self._test_settings_get("CREATE_DB", default=True)

    def _test_user_create(self):
        return self._test_settings_get("CREATE_USER", default=True)

    def _test_database_user(self):
        return self._test_settings_get("USER", prefixed="USER")

    def _test_database_passwd(self):
        password = self._test_settings_get("PASSWORD")
        if password is None and self._test_user_create():
            # Oracle passwords are limited to 30 chars and can't contain symbols.
            password = get_random_string(30)
        return password

    def _test_database_tblspace(self):
        return self._test_settings_get("TBLSPACE", prefixed="USER")

    def _test_database_tblspace_tmp(self):
        settings_dict = self.connection.settings_dict
        return settings_dict["TEST"].get(
            "TBLSPACE_TMP", TEST_DATABASE_PREFIX + settings_dict["USER"] + "_temp"
        )

    def _test_database_tblspace_datafile(self):
        tblspace = "%s.dbf" % self._test_database_tblspace()
        return self._test_settings_get("DATAFILE", default=tblspace)

    def _test_database_tblspace_tmp_datafile(self):
        tblspace = "%s.dbf" % self._test_database_tblspace_tmp()
        return self._test_settings_get("DATAFILE_TMP", default=tblspace)

    def _test_database_tblspace_maxsize(self):
        return self._test_settings_get("DATAFILE_MAXSIZE", default="500M")

    def _test_database_tblspace_tmp_maxsize(self):
        return self._test_settings_get("DATAFILE_TMP_MAXSIZE", default="500M")

    def _test_database_tblspace_size(self):
        return self._test_settings_get("DATAFILE_SIZE", default="50M")

    def _test_database_tblspace_tmp_size(self):
        return self._test_settings_get("DATAFILE_TMP_SIZE", default="50M")

    def _test_database_tblspace_extsize(self):
        return self._test_settings_get("DATAFILE_EXTSIZE", default="25M")

    def _test_database_tblspace_tmp_extsize(self):
        return self._test_settings_get("DATAFILE_TMP_EXTSIZE", default="25M")

    def _test_database_oracle_managed_files(self):
        return self._test_settings_get("ORACLE_MANAGED_FILES", default=False)

    def _get_test_db_name(self):
        """
        Return the 'production' DB name to get the test DB creation machinery
        to work. This isn't a great deal in this case because DB names as
        handled by Django don't have real counterparts in Oracle.
        """
        return self.connection.settings_dict["NAME"]

    def test_db_signature(self):
        settings_dict = self.connection.settings_dict
        return (
            settings_dict["HOST"],
            settings_dict["PORT"],
            settings_dict["ENGINE"],
            settings_dict["NAME"],
            self._test_database_user(),
        )
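

# Illustrative sketch (assumed values, not from the commit): the TEST settings
# keys that _get_test_db_params() above reads via _test_settings_get(). The
# key names are real; the values are arbitrary examples (a None PASSWORD makes
# _test_database_passwd() generate one with get_random_string(30)).
EXAMPLE_ORACLE_TEST_SETTINGS = {
    "USER": "test_django",               # _test_database_user()
    "PASSWORD": None,                    # _test_database_passwd()
    "TBLSPACE": "test_django",           # _test_database_tblspace()
    "TBLSPACE_TMP": "test_django_temp",  # _test_database_tblspace_tmp()
    "DATAFILE_SIZE": "50M",              # initial tablespace size
    "DATAFILE_MAXSIZE": "500M",          # AUTOEXTEND ... MAXSIZE
    "DATAFILE_EXTSIZE": "25M",           # AUTOEXTEND ON NEXT
    "ORACLE_MANAGED_FILES": False,       # omit DATAFILE/TEMPFILE paths if True
}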
@@ -0,0 +1,230 @@
from django.db import DatabaseError, InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.oracle.oracledb_any import is_oracledb
from django.utils.functional import cached_property


class DatabaseFeatures(BaseDatabaseFeatures):
    minimum_database_version = (19,)
    # Oracle crashes with "ORA-00932: inconsistent datatypes: expected - got
    # BLOB" when grouping by LOBs (#24096).
    allows_group_by_lob = False
    # Although GROUP BY select index is supported by Oracle 23c+, it requires
    # GROUP_BY_POSITION_ENABLED to be enabled to avoid backward compatibility
    # issues. Introspection of this setting is not straightforward.
    allows_group_by_select_index = False
    interprets_empty_strings_as_nulls = True
    has_select_for_update = True
    has_select_for_update_nowait = True
    has_select_for_update_skip_locked = True
    has_select_for_update_of = True
    select_for_update_of_column = True
    can_return_columns_from_insert = True
    supports_subqueries_in_group_by = False
    ignores_unnecessary_order_by_in_subqueries = False
    supports_transactions = True
    supports_timezones = False
    has_native_duration_field = True
    can_defer_constraint_checks = True
    supports_partially_nullable_unique_constraints = False
    supports_deferrable_unique_constraints = True
    truncates_names = True
    supports_comments = True
    supports_tablespaces = True
    supports_sequence_reset = False
    can_introspect_materialized_views = True
    atomic_transactions = False
    nulls_order_largest = True
    requires_literal_defaults = True
    supports_default_keyword_in_bulk_insert = False
    closed_cursor_error_class = InterfaceError
    # Select for update with limit can be achieved on Oracle, but not with the
    # current backend.
    supports_select_for_update_with_limit = False
    supports_temporal_subtraction = True
    # Oracle doesn't ignore quoted identifiers case but the current backend
    # does by uppercasing all identifiers.
    ignores_table_name_case = True
    supports_index_on_text_field = False
    create_test_procedure_without_params_sql = """
        CREATE PROCEDURE "TEST_PROCEDURE" AS
            V_I INTEGER;
        BEGIN
            V_I := 1;
        END;
    """
    create_test_procedure_with_int_param_sql = """
        CREATE PROCEDURE "TEST_PROCEDURE" (P_I INTEGER) AS
            V_I INTEGER;
        BEGIN
            V_I := P_I;
        END;
    """
    create_test_table_with_composite_primary_key = """
        CREATE TABLE test_table_composite_pk (
            column_1 NUMBER(11) NOT NULL,
            column_2 NUMBER(11) NOT NULL,
            PRIMARY KEY (column_1, column_2)
        )
    """
    supports_callproc_kwargs = True
    supports_over_clause = True
    supports_frame_range_fixed_distance = True
    supports_ignore_conflicts = False
    max_query_params = 2**16 - 1
    supports_partial_indexes = False
    supports_stored_generated_columns = False
    supports_virtual_generated_columns = True
    can_rename_index = True
    supports_slicing_ordering_in_compound = True
    requires_compound_order_by_subquery = True
    allows_multiple_constraints_on_same_fields = False
    supports_json_field_contains = False
    supports_collation_on_textfield = False
    supports_tuple_lookups = False
    test_now_utc_template = "CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"
    django_test_expected_failures = {
        # A bug in Django/oracledb with respect to string handling (#23843).
        "annotations.tests.NonAggregateAnnotationTestCase.test_custom_functions",
        "annotations.tests.NonAggregateAnnotationTestCase."
        "test_custom_functions_can_ref_other_functions",
    }
    insert_test_table_with_defaults = (
        "INSERT INTO {} VALUES (DEFAULT, DEFAULT, DEFAULT)"
    )

    @cached_property
    def django_test_skips(self):
        skips = {
            "Oracle doesn't support SHA224.": {
                "db_functions.text.test_sha224.SHA224Tests.test_basic",
                "db_functions.text.test_sha224.SHA224Tests.test_transform",
            },
            "Oracle doesn't correctly calculate ISO 8601 week numbering before "
            "1583 (the Gregorian calendar was introduced in 1582).": {
                "db_functions.datetime.test_extract_trunc.DateFunctionTests."
                "test_trunc_week_before_1000",
                "db_functions.datetime.test_extract_trunc."
                "DateFunctionWithTimeZoneTests.test_trunc_week_before_1000",
            },
            "Oracle doesn't support bitwise XOR.": {
                "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor",
                "expressions.tests.ExpressionOperatorTests."
                "test_lefthand_bitwise_xor_null",
                "expressions.tests.ExpressionOperatorTests."
                "test_lefthand_bitwise_xor_right_null",
            },
            "Oracle requires ORDER BY in row_number, ANSI:SQL doesn't.": {
                "expressions_window.tests.WindowFunctionTests."
                "test_row_number_no_ordering",
                "prefetch_related.tests.PrefetchLimitTests.test_empty_order",
            },
            "Oracle doesn't support changing collations on indexed columns (#33671).": {
                "migrations.test_operations.OperationTests."
                "test_alter_field_pk_fk_db_collation",
            },
            "Oracle doesn't support comparing NCLOB to NUMBER.": {
                "generic_relations_regress.tests.GenericRelationTests."
                "test_textlink_filter",
            },
            "Oracle doesn't support casting filters to NUMBER.": {
                "lookup.tests.LookupQueryingTests.test_aggregate_combined_lookup",
            },
        }
        if self.connection.oracle_version < (23,):
            skips.update(
                {
                    "Raises ORA-00600 on Oracle < 23c: internal error code.": {
                        "model_fields.test_jsonfield.TestQuerying."
                        "test_usage_in_subquery",
                    },
                }
            )
        if self.connection.is_pool:
            skips.update(
                {
                    "Pooling does not support persistent connections.": {
                        "backends.base.test_base.ConnectionHealthChecksTests."
                        "test_health_checks_enabled",
                        "backends.base.test_base.ConnectionHealthChecksTests."
                        "test_health_checks_enabled_errors_occurred",
                        "backends.base.test_base.ConnectionHealthChecksTests."
                        "test_health_checks_disabled",
                        "backends.base.test_base.ConnectionHealthChecksTests."
                        "test_set_autocommit_health_checks_enabled",
                        "servers.tests.LiveServerTestCloseConnectionTest."
                        "test_closes_connections",
                        "backends.oracle.tests.TransactionalTests."
                        "test_password_with_at_sign",
                    },
                }
            )
        if is_oracledb and self.connection.oracledb_version >= (2, 1, 2):
            skips.update(
                {
                    "python-oracledb 2.1.2+ no longer hides 'ORA-1403: no data found' "
                    "exceptions raised in database triggers.": {
                        "backends.oracle.tests.TransactionalTests."
                        "test_hidden_no_data_found_exception"
                    },
                },
            )
        return skips

    @cached_property
    def introspected_field_types(self):
        return {
            **super().introspected_field_types,
            "GenericIPAddressField": "CharField",
            "PositiveBigIntegerField": "BigIntegerField",
            "PositiveIntegerField": "IntegerField",
            "PositiveSmallIntegerField": "IntegerField",
            "SmallIntegerField": "IntegerField",
            "TimeField": "DateTimeField",
        }

    @cached_property
    def test_collations(self):
        return {
            "ci": "BINARY_CI",
            "cs": "BINARY",
            "non_default": "SWEDISH_CI",
            "swedish_ci": "SWEDISH_CI",
            "virtual": "SWEDISH_CI" if self.supports_collation_on_charfield else None,
        }

    @cached_property
    def supports_collation_on_charfield(self):
        sql = "SELECT CAST('a' AS VARCHAR2(4001))" + self.bare_select_suffix
        with self.connection.cursor() as cursor:
            try:
                cursor.execute(sql)
            except DatabaseError as e:
                if e.args[0].code == 910:
                    return False
                raise
            return True

    @cached_property
    def supports_primitives_in_json_field(self):
        return self.connection.oracle_version >= (21,)

    @cached_property
    def supports_frame_exclusion(self):
        return self.connection.oracle_version >= (21,)

    @cached_property
    def supports_boolean_expr_in_select_clause(self):
        return self.connection.oracle_version >= (23,)

    @cached_property
    def supports_comparing_boolean_expr(self):
        return self.connection.oracle_version >= (23,)

    @cached_property
    def supports_aggregation_over_interval_types(self):
        return self.connection.oracle_version >= (23,)

    @cached_property
    def bare_select_suffix(self):
        return "" if self.connection.oracle_version >= (23,) else " FROM DUAL"
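

# Illustrative sketch (not from the commit; assumes a configured Oracle
# connection passed in by the caller): feature flags defined above are read at
# runtime through connection.features. For example, bare_select_suffix keeps
# hand-written SQL valid both before and after Oracle 23's FROM-less SELECT.
def current_utc_timestamp(connection):
    sql = (
        "SELECT CURRENT_TIMESTAMP AT TIME ZONE 'UTC'"
        + connection.features.bare_select_suffix
    )
    with connection.cursor() as cursor:
        cursor.execute(sql)
        return cursor.fetchone()[0]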
@@ -0,0 +1,26 @@
from django.db.models import DecimalField, DurationField, Func


class IntervalToSeconds(Func):
    function = ""
    template = """
    EXTRACT(day from %(expressions)s) * 86400 +
    EXTRACT(hour from %(expressions)s) * 3600 +
    EXTRACT(minute from %(expressions)s) * 60 +
    EXTRACT(second from %(expressions)s)
    """

    def __init__(self, expression, *, output_field=None, **extra):
        super().__init__(
            expression, output_field=output_field or DecimalField(), **extra
        )


class SecondsToInterval(Func):
    function = "NUMTODSINTERVAL"
    template = "%(function)s(%(expressions)s, 'SECOND')"

    def __init__(self, expression, *, output_field=None, **extra):
        super().__init__(
            expression, output_field=output_field or DurationField(), **extra
        )
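

# Illustrative sketch (not from the commit): the Func subclasses above compose
# like any other ORM expression. "duration" and "head_start" are hypothetical
# DurationField and IntegerField columns, named here only for the example.
from django.db.models import F

seconds_expr = IntervalToSeconds(F("duration"))     # INTERVAL -> NUMBER
interval_expr = SecondsToInterval(F("head_start"))  # NUMBER -> INTERVAL
# e.g. SomeModel.objects.annotate(secs=seconds_expr) on an Oracle connection.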
@@ -0,0 +1,414 @@
from collections import namedtuple

from django.db import models
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
from django.db.backends.oracle.oracledb_any import oracledb

FieldInfo = namedtuple(
    "FieldInfo", BaseFieldInfo._fields + ("is_autofield", "is_json", "comment")
)
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))


class DatabaseIntrospection(BaseDatabaseIntrospection):
    cache_bust_counter = 1

    # Maps type objects to Django Field types.
    data_types_reverse = {
        oracledb.DB_TYPE_DATE: "DateField",
        oracledb.DB_TYPE_BINARY_DOUBLE: "FloatField",
        oracledb.DB_TYPE_BLOB: "BinaryField",
        oracledb.DB_TYPE_CHAR: "CharField",
        oracledb.DB_TYPE_CLOB: "TextField",
        oracledb.DB_TYPE_INTERVAL_DS: "DurationField",
        oracledb.DB_TYPE_NCHAR: "CharField",
        oracledb.DB_TYPE_NCLOB: "TextField",
        oracledb.DB_TYPE_NVARCHAR: "CharField",
        oracledb.DB_TYPE_NUMBER: "DecimalField",
        oracledb.DB_TYPE_TIMESTAMP: "DateTimeField",
        oracledb.DB_TYPE_VARCHAR: "CharField",
    }

    def get_field_type(self, data_type, description):
        if data_type == oracledb.NUMBER:
            precision, scale = description[4:6]
            if scale == 0:
                if precision > 11:
                    return (
                        "BigAutoField"
                        if description.is_autofield
                        else "BigIntegerField"
                    )
                elif 1 < precision < 6 and description.is_autofield:
                    return "SmallAutoField"
                elif precision == 1:
                    return "BooleanField"
                elif description.is_autofield:
                    return "AutoField"
                else:
                    return "IntegerField"
            elif scale == -127:
                return "FloatField"
        elif data_type == oracledb.NCLOB and description.is_json:
            return "JSONField"

        return super().get_field_type(data_type, description)

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        cursor.execute(
            """
            SELECT
                user_tables.table_name,
                't',
                user_tab_comments.comments
            FROM user_tables
            LEFT OUTER JOIN
                user_tab_comments
                ON user_tab_comments.table_name = user_tables.table_name
            WHERE
                NOT EXISTS (
                    SELECT 1
                    FROM user_mviews
                    WHERE user_mviews.mview_name = user_tables.table_name
                )
            UNION ALL
            SELECT view_name, 'v', NULL FROM user_views
            UNION ALL
            SELECT mview_name, 'v', NULL FROM user_mviews
            """
        )
        return [
            TableInfo(self.identifier_converter(row[0]), row[1], row[2])
            for row in cursor.fetchall()
        ]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        # A default collation for the given table/view/materialized view.
        cursor.execute(
            """
            SELECT user_tables.default_collation
            FROM user_tables
            WHERE
                user_tables.table_name = UPPER(%s) AND
                NOT EXISTS (
                    SELECT 1
                    FROM user_mviews
                    WHERE user_mviews.mview_name = user_tables.table_name
                )
            UNION ALL
            SELECT user_views.default_collation
            FROM user_views
            WHERE user_views.view_name = UPPER(%s)
            UNION ALL
            SELECT user_mviews.default_collation
            FROM user_mviews
            WHERE user_mviews.mview_name = UPPER(%s)
            """,
            [table_name, table_name, table_name],
        )
        row = cursor.fetchone()
        default_table_collation = row[0] if row else ""
        # user_tab_columns gives data default for columns.
        cursor.execute(
            """
            SELECT
                user_tab_cols.column_name,
                user_tab_cols.data_default,
                CASE
                    WHEN user_tab_cols.collation = %s
                    THEN NULL
                    ELSE user_tab_cols.collation
                END collation,
                CASE
                    WHEN user_tab_cols.char_used IS NULL
                    THEN user_tab_cols.data_length
                    ELSE user_tab_cols.char_length
                END as display_size,
                CASE
                    WHEN user_tab_cols.identity_column = 'YES' THEN 1
                    ELSE 0
                END as is_autofield,
                CASE
                    WHEN EXISTS (
                        SELECT 1
                        FROM user_json_columns
                        WHERE
                            user_json_columns.table_name = user_tab_cols.table_name AND
                            user_json_columns.column_name = user_tab_cols.column_name
                    )
                    THEN 1
                    ELSE 0
                END as is_json,
                user_col_comments.comments as col_comment
            FROM user_tab_cols
            LEFT OUTER JOIN
                user_col_comments ON
                user_col_comments.column_name = user_tab_cols.column_name AND
                user_col_comments.table_name = user_tab_cols.table_name
            WHERE user_tab_cols.table_name = UPPER(%s)
            """,
            [default_table_collation, table_name],
        )
        field_map = {
            column: (
                display_size,
                default.rstrip() if default and default != "NULL" else None,
                collation,
                is_autofield,
                is_json,
                comment,
            )
            for (
                column,
                default,
                collation,
                display_size,
                is_autofield,
                is_json,
                comment,
            ) in cursor.fetchall()
        }
        self.cache_bust_counter += 1
        cursor.execute(
            "SELECT * FROM {} WHERE ROWNUM < 2 AND {} > 0".format(
                self.connection.ops.quote_name(table_name), self.cache_bust_counter
            )
        )
        description = []
        for desc in cursor.description:
            name = desc[0]
            (
                display_size,
                default,
                collation,
                is_autofield,
                is_json,
                comment,
            ) = field_map[name]
            name %= {}  # oracledb, for some reason, doubles percent signs.
            description.append(
                FieldInfo(
                    self.identifier_converter(name),
                    desc[1],
                    display_size,
                    desc[3],
                    desc[4] or 0,
                    desc[5] or 0,
                    *desc[6:],
                    default,
                    collation,
                    is_autofield,
                    is_json,
                    comment,
                )
            )
        return description

    def identifier_converter(self, name):
        """Identifier comparison is case insensitive under Oracle."""
        return name.lower()

    def get_sequences(self, cursor, table_name, table_fields=()):
        cursor.execute(
            """
            SELECT
                user_tab_identity_cols.sequence_name,
                user_tab_identity_cols.column_name
            FROM
                user_tab_identity_cols,
                user_constraints,
                user_cons_columns cols
            WHERE
                user_constraints.constraint_name = cols.constraint_name
                AND user_constraints.table_name = user_tab_identity_cols.table_name
                AND cols.column_name = user_tab_identity_cols.column_name
                AND user_constraints.constraint_type = 'P'
                AND user_tab_identity_cols.table_name = UPPER(%s)
            """,
            [table_name],
        )
        # Oracle allows only one identity column per table.
        row = cursor.fetchone()
        if row:
            return [
                {
                    "name": self.identifier_converter(row[0]),
                    "table": self.identifier_converter(table_name),
                    "column": self.identifier_converter(row[1]),
                }
            ]
        # To keep backward compatibility for AutoFields that aren't Oracle
        # identity columns.
        for f in table_fields:
            if isinstance(f, models.AutoField):
                return [{"table": table_name, "column": f.column}]
        return []

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        table_name = table_name.upper()
        cursor.execute(
            """
            SELECT ca.column_name, cb.table_name, cb.column_name
            FROM user_constraints, USER_CONS_COLUMNS ca, USER_CONS_COLUMNS cb
            WHERE user_constraints.table_name = %s AND
                  user_constraints.constraint_name = ca.constraint_name AND
                  user_constraints.r_constraint_name = cb.constraint_name AND
                  ca.position = cb.position""",
            [table_name],
        )

        return {
            self.identifier_converter(field_name): (
                self.identifier_converter(rel_field_name),
                self.identifier_converter(rel_table_name),
            )
            for field_name, rel_table_name, rel_field_name in cursor.fetchall()
        }

    def get_primary_key_columns(self, cursor, table_name):
        cursor.execute(
            """
            SELECT
                cols.column_name
            FROM
                user_constraints,
                user_cons_columns cols
            WHERE
                user_constraints.constraint_name = cols.constraint_name AND
                user_constraints.constraint_type = 'P' AND
                user_constraints.table_name = UPPER(%s)
            ORDER BY
                cols.position
            """,
            [table_name],
        )
        return [self.identifier_converter(row[0]) for row in cursor.fetchall()]

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.
        """
        constraints = {}
        # Loop over the constraints, getting PKs, uniques, and checks.
        cursor.execute(
            """
            SELECT
                user_constraints.constraint_name,
                LISTAGG(LOWER(cols.column_name), ',')
                    WITHIN GROUP (ORDER BY cols.position),
                CASE user_constraints.constraint_type
                    WHEN 'P' THEN 1
                    ELSE 0
                END AS is_primary_key,
                CASE
                    WHEN user_constraints.constraint_type IN ('P', 'U') THEN 1
                    ELSE 0
                END AS is_unique,
                CASE user_constraints.constraint_type
                    WHEN 'C' THEN 1
                    ELSE 0
                END AS is_check_constraint
            FROM
                user_constraints
            LEFT OUTER JOIN
                user_cons_columns cols
                ON user_constraints.constraint_name = cols.constraint_name
            WHERE
                user_constraints.constraint_type = ANY('P', 'U', 'C')
                AND user_constraints.table_name = UPPER(%s)
            GROUP BY user_constraints.constraint_name, user_constraints.constraint_type
            """,
            [table_name],
        )
        for constraint, columns, pk, unique, check in cursor.fetchall():
            constraint = self.identifier_converter(constraint)
            constraints[constraint] = {
                "columns": columns.split(","),
                "primary_key": pk,
                "unique": unique,
                "foreign_key": None,
                "check": check,
                "index": unique,  # All uniques come with an index.
            }
        # Foreign key constraints.
        cursor.execute(
            """
            SELECT
                cons.constraint_name,
                LISTAGG(LOWER(cols.column_name), ',')
                    WITHIN GROUP (ORDER BY cols.position),
                LOWER(rcols.table_name),
                LOWER(rcols.column_name)
            FROM
                user_constraints cons
            INNER JOIN
                user_cons_columns rcols
                ON rcols.constraint_name = cons.r_constraint_name AND rcols.position = 1
            LEFT OUTER JOIN
                user_cons_columns cols
                ON cons.constraint_name = cols.constraint_name
            WHERE
                cons.constraint_type = 'R' AND
                cons.table_name = UPPER(%s)
            GROUP BY cons.constraint_name, rcols.table_name, rcols.column_name
            """,
            [table_name],
        )
        for constraint, columns, other_table, other_column in cursor.fetchall():
            constraint = self.identifier_converter(constraint)
            constraints[constraint] = {
                "primary_key": False,
                "unique": False,
                "foreign_key": (other_table, other_column),
                "check": False,
                "index": False,
                "columns": columns.split(","),
            }
        # Now get indexes.
        cursor.execute(
            """
            SELECT
                ind.index_name,
                LOWER(ind.index_type),
                LOWER(ind.uniqueness),
                LISTAGG(LOWER(cols.column_name), ',')
                    WITHIN GROUP (ORDER BY cols.column_position),
                LISTAGG(cols.descend, ',') WITHIN GROUP (ORDER BY cols.column_position)
            FROM
                user_ind_columns cols, user_indexes ind
            WHERE
                cols.table_name = UPPER(%s) AND
                NOT EXISTS (
                    SELECT 1
                    FROM user_constraints cons
                    WHERE ind.index_name = cons.index_name
                ) AND cols.index_name = ind.index_name
            GROUP BY ind.index_name, ind.index_type, ind.uniqueness
            """,
            [table_name],
        )
        for constraint, type_, unique, columns, orders in cursor.fetchall():
            constraint = self.identifier_converter(constraint)
            constraints[constraint] = {
                "primary_key": False,
                "unique": unique == "unique",
                "foreign_key": None,
                "check": False,
                "index": True,
                "type": "idx" if type_ == "normal" else type_,
                "columns": columns.split(","),
                "orders": orders.split(","),
            }
        return constraints
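

# Illustrative sketch (not from the commit; assumes a configured Oracle
# connection passed in by the caller): the public entry points above are
# reached through connection.introspection.
def describe_table(connection, table_name):
    with connection.cursor() as cursor:
        tables = connection.introspection.get_table_list(cursor)
        constraints = connection.introspection.get_constraints(cursor, table_name)
    return tables, constraints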
@@ -0,0 +1,741 @@
import datetime
import uuid
from functools import lru_cache
from itertools import chain

from django.conf import settings
from django.db import DatabaseError, NotSupportedError
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.utils import split_tzname_delta, strip_quotes, truncate_name
from django.db.models import (
    AutoField,
    CompositePrimaryKey,
    Exists,
    ExpressionWrapper,
    Lookup,
)
from django.db.models.expressions import RawSQL
from django.db.models.sql.where import WhereNode
from django.utils import timezone
from django.utils.encoding import force_bytes, force_str
from django.utils.functional import cached_property
from django.utils.regex_helper import _lazy_re_compile

from .base import Database
from .utils import BulkInsertMapper, InsertVar, Oracle_datetime


class DatabaseOperations(BaseDatabaseOperations):
    # Oracle uses NUMBER(5), NUMBER(11), and NUMBER(19) for integer fields.
    # SmallIntegerField uses NUMBER(11) instead of NUMBER(5), which is used by
    # SmallAutoField, to preserve backward compatibility.
    integer_field_ranges = {
        "SmallIntegerField": (-99999999999, 99999999999),
        "IntegerField": (-99999999999, 99999999999),
        "BigIntegerField": (-9999999999999999999, 9999999999999999999),
        "PositiveBigIntegerField": (0, 9999999999999999999),
        "PositiveSmallIntegerField": (0, 99999999999),
        "PositiveIntegerField": (0, 99999999999),
        "SmallAutoField": (-99999, 99999),
        "AutoField": (-99999999999, 99999999999),
        "BigAutoField": (-9999999999999999999, 9999999999999999999),
    }
    set_operators = {**BaseDatabaseOperations.set_operators, "difference": "MINUS"}

    # TODO: colorize this SQL code with style.SQL_KEYWORD(), etc.
    _sequence_reset_sql = """
        DECLARE
            table_value integer;
            seq_value integer;
            seq_name user_tab_identity_cols.sequence_name%%TYPE;
        BEGIN
            BEGIN
                SELECT sequence_name INTO seq_name FROM user_tab_identity_cols
                WHERE table_name = '%(table_name)s' AND
                      column_name = '%(column_name)s';
                EXCEPTION WHEN NO_DATA_FOUND THEN
                    seq_name := '%(no_autofield_sequence_name)s';
            END;

            SELECT NVL(MAX(%(column)s), 0) INTO table_value FROM %(table)s;
            SELECT NVL(last_number - cache_size, 0) INTO seq_value FROM user_sequences
            WHERE sequence_name = seq_name;
            WHILE table_value > seq_value LOOP
                EXECUTE IMMEDIATE 'SELECT "'||seq_name||'".nextval%(suffix)s'
                INTO seq_value;
            END LOOP;
        END;
        /"""

    # Oracle doesn't support string without precision; use the max string size.
    cast_char_field_without_max_length = "NVARCHAR2(2000)"
    cast_data_types = {
        "AutoField": "NUMBER(11)",
        "BigAutoField": "NUMBER(19)",
        "SmallAutoField": "NUMBER(5)",
        "TextField": cast_char_field_without_max_length,
    }

    def cache_key_culling_sql(self):
        cache_key = self.quote_name("cache_key")
        return (
            f"SELECT {cache_key} "
            f"FROM %s "
            f"ORDER BY {cache_key} OFFSET %%s ROWS FETCH FIRST 1 ROWS ONLY"
        )

    # EXTRACT format cannot be passed in parameters.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(self, lookup_type, sql, params):
        extract_sql = f"TO_CHAR({sql}, %s)"
        extract_param = None
        if lookup_type == "week_day":
            # TO_CHAR(field, 'D') returns an integer from 1-7, where 1=Sunday.
            extract_param = "D"
        elif lookup_type == "iso_week_day":
            extract_sql = f"TO_CHAR({sql} - 1, %s)"
            extract_param = "D"
        elif lookup_type == "week":
            # IW = ISO week number.
            extract_param = "IW"
        elif lookup_type == "quarter":
            extract_param = "Q"
        elif lookup_type == "iso_year":
            extract_param = "IYYY"
        else:
            lookup_type = lookup_type.upper()
            if not self._extract_format_re.fullmatch(lookup_type):
                raise ValueError(f"Invalid lookup type: {lookup_type!r}")
            # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/EXTRACT-datetime.html
            return f"EXTRACT({lookup_type} FROM {sql})", params
        return extract_sql, (*params, extract_param)

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
        trunc_param = None
        if lookup_type in ("year", "month"):
            trunc_param = lookup_type.upper()
        elif lookup_type == "quarter":
            trunc_param = "Q"
        elif lookup_type == "week":
            trunc_param = "IW"
        else:
            return f"TRUNC({sql})", params
        return f"TRUNC({sql}, %s)", (*params, trunc_param)

    # Oracle crashes with "ORA-03113: end-of-file on communication channel"
    # if the time zone name is passed in parameter. Use interpolation instead.
    # https://groups.google.com/forum/#!msg/django-developers/zwQju7hbG78/9l934yelwfsJ
    # This regexp matches all time zone names from the zoneinfo database.
    _tzname_re = _lazy_re_compile(r"^[\w/:+-]+$")

    def _prepare_tzname_delta(self, tzname):
        tzname, sign, offset = split_tzname_delta(tzname)
        return f"{sign}{offset}" if offset else tzname

    def _convert_sql_to_tz(self, sql, params, tzname):
        if not (settings.USE_TZ and tzname):
            return sql, params
        if not self._tzname_re.match(tzname):
            raise ValueError("Invalid time zone name: %s" % tzname)
        # Convert from connection timezone to the local time, returning
        # TIMESTAMP WITH TIME ZONE and cast it back to TIMESTAMP to strip the
        # TIME ZONE details.
        if self.connection.timezone_name != tzname:
            from_timezone_name = self.connection.timezone_name
            to_timezone_name = self._prepare_tzname_delta(tzname)
            return (
                f"CAST((FROM_TZ({sql}, '{from_timezone_name}') AT TIME ZONE "
                f"'{to_timezone_name}') AS TIMESTAMP)",
                params,
            )
        return sql, params

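    # Illustrative sketch (not from the commit): with a connection timezone of
    # 'UTC' and tzname='Europe/Paris', _convert_sql_to_tz() wraps a column
    # roughly as
    #   CAST((FROM_TZ("T"."CREATED", 'UTC') AT TIME ZONE 'Europe/Paris')
    #       AS TIMESTAMP)
    # so the EXTRACT/TRUNC wrappers below see local wall-clock time with the
    # zone details stripped again.
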
    def datetime_cast_date_sql(self, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        return f"TRUNC({sql})", params

    def datetime_cast_time_sql(self, sql, params, tzname):
        # Since `TimeField` values are stored as TIMESTAMP, change to the
        # default date and convert the field to the specified timezone.
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        convert_datetime_sql = (
            f"TO_TIMESTAMP(CONCAT('1900-01-01 ', TO_CHAR({sql}, 'HH24:MI:SS.FF')), "
            f"'YYYY-MM-DD HH24:MI:SS.FF')"
        )
        return (
            f"CASE WHEN {sql} IS NOT NULL THEN {convert_datetime_sql} ELSE NULL END",
            (*params, *params),
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        if lookup_type == "second":
            # Truncate fractional seconds.
            return f"FLOOR(EXTRACT(SECOND FROM {sql}))", params
        return self.date_extract_sql(lookup_type, sql, params)

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
        trunc_param = None
        if lookup_type in ("year", "month"):
            trunc_param = lookup_type.upper()
        elif lookup_type == "quarter":
            trunc_param = "Q"
        elif lookup_type == "week":
            trunc_param = "IW"
        elif lookup_type == "hour":
            trunc_param = "HH24"
        elif lookup_type == "minute":
            trunc_param = "MI"
        elif lookup_type == "day":
            return f"TRUNC({sql})", params
        else:
            # Cast to DATE removes sub-second precision.
            return f"CAST({sql} AS DATE)", params
        return f"TRUNC({sql}, %s)", (*params, trunc_param)

    def time_extract_sql(self, lookup_type, sql, params):
        if lookup_type == "second":
            # Truncate fractional seconds.
            return f"FLOOR(EXTRACT(SECOND FROM {sql}))", params
        return self.date_extract_sql(lookup_type, sql, params)

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        # The implementation is similar to `datetime_trunc_sql` as both
        # `DateTimeField` and `TimeField` are stored as TIMESTAMP, where
        # the date part of the latter is ignored.
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        trunc_param = None
        if lookup_type == "hour":
            trunc_param = "HH24"
        elif lookup_type == "minute":
            trunc_param = "MI"
        elif lookup_type == "second":
            # Cast to DATE removes sub-second precision.
            return f"CAST({sql} AS DATE)", params
        return f"TRUNC({sql}, %s)", (*params, trunc_param)

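    # Illustrative sketch (not from the commit) of the lookup -> TRUNC mapping
    # implemented above:
    #   datetime_trunc_sql("month", col, ...)  -> TRUNC(col, 'MONTH')
    #   datetime_trunc_sql("minute", col, ...) -> TRUNC(col, 'MI')
    #   datetime_trunc_sql("day", col, ...)    -> TRUNC(col)
    # "second" falls through to CAST(col AS DATE), which drops sub-second
    # precision.
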
    def get_db_converters(self, expression):
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type in ["JSONField", "TextField"]:
            converters.append(self.convert_textfield_value)
        elif internal_type == "BinaryField":
            converters.append(self.convert_binaryfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        elif internal_type == "DateTimeField":
            if settings.USE_TZ:
                converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        # Oracle stores empty strings as null. If the field accepts the empty
        # string, undo this to adhere to the Django convention of using
        # the empty string instead of null.
        if expression.output_field.empty_strings_allowed:
            converters.append(
                self.convert_empty_bytes
                if internal_type == "BinaryField"
                else self.convert_empty_string
            )
        return converters

    def convert_textfield_value(self, value, expression, connection):
        if isinstance(value, Database.LOB):
            value = value.read()
        return value

    def convert_binaryfield_value(self, value, expression, connection):
        if isinstance(value, Database.LOB):
            value = force_bytes(value.read())
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        if value in (0, 1):
            value = bool(value)
        return value

    # oracledb always returns datetime.datetime objects for
    # DATE and TIMESTAMP columns, but Django wants to see a
    # python datetime.date, .time, or .datetime.

    def convert_datetimefield_value(self, value, expression, connection):
        if value is not None:
            value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        if isinstance(value, Database.Timestamp):
            value = value.date()
        return value

    def convert_timefield_value(self, value, expression, connection):
        if isinstance(value, Database.Timestamp):
            value = value.time()
        return value

    def convert_uuidfield_value(self, value, expression, connection):
        if value is not None:
            value = uuid.UUID(value)
        return value

    @staticmethod
    def convert_empty_string(value, expression, connection):
        return "" if value is None else value

    @staticmethod
    def convert_empty_bytes(value, expression, connection):
        return b"" if value is None else value

    def deferrable_sql(self):
        return " DEFERRABLE INITIALLY DEFERRED"

    def fetch_returned_insert_columns(self, cursor, returning_params):
        columns = []
        for param in returning_params:
            value = param.get_value()
            # Can be removed when cx_Oracle is no longer supported and
            # python-oracledb 2.1.2 becomes the minimum supported version.
            if value == []:
                raise DatabaseError(
                    "The database did not return a new row id. Probably "
                    '"ORA-1403: no data found" was raised internally but was '
                    "hidden by the Oracle OCI library (see "
                    "https://code.djangoproject.com/ticket/28859)."
                )
            columns.append(value[0])
        return tuple(columns)

    def no_limit_value(self):
        return None

    def limit_offset_sql(self, low_mark, high_mark):
        fetch, offset = self._get_limit_offset_params(low_mark, high_mark)
        return " ".join(
            sql
            for sql in (
                ("OFFSET %d ROWS" % offset) if offset else None,
                ("FETCH FIRST %d ROWS ONLY" % fetch) if fetch else None,
            )
            if sql
        )

    def last_executed_query(self, cursor, sql, params):
        # https://python-oracledb.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.statement
        # The DB API definition does not define this attribute.
        statement = cursor.statement
        # Unlike Psycopg's `query` and MySQLdb's `_executed`, oracledb's
        # `statement` doesn't contain the query parameters. Substitute
        # parameters manually.
        if statement and params:
            if isinstance(params, (tuple, list)):
                params = {
                    f":arg{i}": param for i, param in enumerate(dict.fromkeys(params))
                }
            elif isinstance(params, dict):
                params = {f":{key}": val for (key, val) in params.items()}
            for key in sorted(params, key=len, reverse=True):
                statement = statement.replace(
                    key, force_str(params[key], errors="replace")
                )
        return statement

    def last_insert_id(self, cursor, table_name, pk_name):
        sq_name = self._get_sequence_name(cursor, strip_quotes(table_name), pk_name)
        cursor.execute(
            'SELECT "%s".currval%s'
            % (sq_name, self.connection.features.bare_select_suffix)
        )
        return cursor.fetchone()[0]

    def lookup_cast(self, lookup_type, internal_type=None):
        if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
            return "UPPER(%s)"
        if lookup_type != "isnull" and internal_type in (
            "BinaryField",
            "TextField",
        ):
            return "DBMS_LOB.SUBSTR(%s)"
        return "%s"

    def max_in_list_size(self):
        return 1000

    def max_name_length(self):
        return 30

    def pk_default_value(self):
        return "NULL"

    def prep_for_iexact_query(self, x):
        return x

    def process_clob(self, value):
        if value is None:
            return ""
        return value.read()

    def quote_name(self, name):
        # SQL92 requires delimited (quoted) names to be case-sensitive. When
        # not quoted, Oracle has case-insensitive behavior for identifiers, but
        # always defaults to uppercase.
        # We simplify things by making Oracle identifiers always uppercase.
        if not name.startswith('"') and not name.endswith('"'):
            name = '"%s"' % truncate_name(name, self.max_name_length())
        # Oracle puts the query text into a (query % args) construct, so % signs
        # in names need to be escaped. The '%%' will be collapsed back to '%' at
        # that stage so we aren't really making the name longer here.
        name = name.replace("%", "%%")
        return name.upper()

    def regex_lookup(self, lookup_type):
        if lookup_type == "regex":
            match_option = "'c'"
        else:
            match_option = "'i'"
        return "REGEXP_LIKE(%%s, %%s, %s)" % match_option

    def return_insert_columns(self, fields):
        if not fields:
            return "", ()
        field_names = []
        params = []
        for field in fields:
            field_names.append(
                "%s.%s"
                % (
                    self.quote_name(field.model._meta.db_table),
                    self.quote_name(field.column),
                )
            )
            params.append(InsertVar(field))
        return "RETURNING %s INTO %s" % (
            ", ".join(field_names),
            ", ".join(["%s"] * len(params)),
        ), tuple(params)

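    # Illustrative sketch (not from the commit) of quote_name() above:
    #   quote_name("django_content_type") -> '"DJANGO_CONTENT_TYPE"'
    #   quote_name('"quoted"')            -> '"QUOTED"' (quotes kept, but still
    #                                        uppercased, per the comment above)
    # Any '%' is doubled so the later (query % args) interpolation is safe.
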
    def __foreign_key_constraints(self, table_name, recursive):
        with self.connection.cursor() as cursor:
            if recursive:
                cursor.execute(
                    """
                    SELECT
                        user_tables.table_name, rcons.constraint_name
                    FROM
                        user_tables
                    JOIN
                        user_constraints cons
                        ON (user_tables.table_name = cons.table_name
                        AND cons.constraint_type = ANY('P', 'U'))
                    LEFT JOIN
                        user_constraints rcons
                        ON (user_tables.table_name = rcons.table_name
                        AND rcons.constraint_type = 'R')
                    START WITH user_tables.table_name = UPPER(%s)
                    CONNECT BY
                        NOCYCLE PRIOR cons.constraint_name = rcons.r_constraint_name
                    GROUP BY
                        user_tables.table_name, rcons.constraint_name
                    HAVING user_tables.table_name != UPPER(%s)
                    ORDER BY MAX(level) DESC
                    """,
                    (table_name, table_name),
                )
            else:
                cursor.execute(
                    """
                    SELECT
                        cons.table_name, cons.constraint_name
                    FROM
                        user_constraints cons
                    WHERE
                        cons.constraint_type = 'R'
                        AND cons.table_name = UPPER(%s)
                    """,
                    (table_name,),
                )
            return cursor.fetchall()

    @cached_property
    def _foreign_key_constraints(self):
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Django's test suite.
        return lru_cache(maxsize=512)(self.__foreign_key_constraints)

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        if not tables:
            return []

        truncated_tables = {table.upper() for table in tables}
        constraints = set()
        # Oracle's TRUNCATE CASCADE only works with ON DELETE CASCADE foreign
        # keys which Django doesn't define. Emulate the PostgreSQL behavior
        # which truncates all dependent tables by manually retrieving all
        # foreign key constraints and resolving dependencies.
        for table in tables:
            for foreign_table, constraint in self._foreign_key_constraints(
                table, recursive=allow_cascade
            ):
                if allow_cascade:
                    truncated_tables.add(foreign_table)
                constraints.add((foreign_table, constraint))
        sql = (
            [
                "%s %s %s %s %s %s %s %s;"
                % (
                    style.SQL_KEYWORD("ALTER"),
                    style.SQL_KEYWORD("TABLE"),
                    style.SQL_FIELD(self.quote_name(table)),
                    style.SQL_KEYWORD("DISABLE"),
                    style.SQL_KEYWORD("CONSTRAINT"),
                    style.SQL_FIELD(self.quote_name(constraint)),
                    style.SQL_KEYWORD("KEEP"),
                    style.SQL_KEYWORD("INDEX"),
                )
                for table, constraint in constraints
            ]
            + [
                "%s %s %s;"
                % (
                    style.SQL_KEYWORD("TRUNCATE"),
                    style.SQL_KEYWORD("TABLE"),
                    style.SQL_FIELD(self.quote_name(table)),
                )
                for table in truncated_tables
            ]
            + [
                "%s %s %s %s %s %s;"
                % (
                    style.SQL_KEYWORD("ALTER"),
                    style.SQL_KEYWORD("TABLE"),
                    style.SQL_FIELD(self.quote_name(table)),
                    style.SQL_KEYWORD("ENABLE"),
                    style.SQL_KEYWORD("CONSTRAINT"),
                    style.SQL_FIELD(self.quote_name(constraint)),
                )
                for table, constraint in constraints
            ]
        )
        if reset_sequences:
            sequences = [
                sequence
                for sequence in self.connection.introspection.sequence_list()
                if sequence["table"].upper() in truncated_tables
            ]
            # Since we've just deleted all the rows, running our sequence ALTER
            # code will reset the sequence to 0.
            sql.extend(self.sequence_reset_by_name_sql(style, sequences))
        return sql

    def sequence_reset_by_name_sql(self, style, sequences):
        sql = []
        for sequence_info in sequences:
            no_autofield_sequence_name = self._get_no_autofield_sequence_name(
                sequence_info["table"]
            )
            table = self.quote_name(sequence_info["table"])
            column = self.quote_name(sequence_info["column"] or "id")
            query = self._sequence_reset_sql % {
                "no_autofield_sequence_name": no_autofield_sequence_name,
                "table": table,
                "column": column,
                "table_name": strip_quotes(table),
                "column_name": strip_quotes(column),
                "suffix": self.connection.features.bare_select_suffix,
            }
            sql.append(query)
        return sql

    def sequence_reset_sql(self, style, model_list):
        output = []
        query = self._sequence_reset_sql
        for model in model_list:
            for f in model._meta.local_fields:
                if isinstance(f, AutoField):
                    no_autofield_sequence_name = self._get_no_autofield_sequence_name(
                        model._meta.db_table
                    )
                    table = self.quote_name(model._meta.db_table)
                    column = self.quote_name(f.column)
                    output.append(
                        query
                        % {
                            "no_autofield_sequence_name": no_autofield_sequence_name,
                            "table": table,
                            "column": column,
                            "table_name": strip_quotes(table),
                            "column_name": strip_quotes(column),
                            "suffix": self.connection.features.bare_select_suffix,
                        }
                    )
                    # Only one AutoField is allowed per model, so don't
                    # continue to loop.
                    break
        return output

    def start_transaction_sql(self):
        return ""

    def tablespace_sql(self, tablespace, inline=False):
        if inline:
            return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
        else:
            return "TABLESPACE %s" % self.quote_name(tablespace)

    def adapt_datefield_value(self, value):
        """
        Transform a date value to an object compatible with what is expected
        by the backend driver for date columns.
        The default implementation transforms the date to text, but that is not
        necessary for Oracle.
        """
        return value

    def adapt_datetimefield_value(self, value):
        """
        Transform a datetime value to an object compatible with what is
        expected by the backend driver for datetime columns.

        If a naive datetime is passed, assume it is in UTC. Normally Django's
        models.DateTimeField makes sure that, if USE_TZ is True, the datetime
        passed is timezone-aware.
        """
        if value is None:
            return None

        # oracledb doesn't support tz-aware datetimes.
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError(
                    "Oracle backend does not support timezone-aware datetimes when "
                    "USE_TZ is False."
                )

        return Oracle_datetime.from_datetime(value)

def adapt_timefield_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if isinstance(value, str):
|
||||
return datetime.datetime.strptime(value, "%H:%M:%S")
|
||||
|
||||
# Oracle doesn't support tz-aware times
|
||||
if timezone.is_aware(value):
|
||||
raise ValueError("Oracle backend does not support timezone-aware times.")
|
||||
|
||||
return Oracle_datetime(
|
||||
1900, 1, 1, value.hour, value.minute, value.second, value.microsecond
|
||||
)
|
||||
|
||||
def combine_expression(self, connector, sub_expressions):
|
||||
lhs, rhs = sub_expressions
|
||||
if connector == "%%":
|
||||
return "MOD(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "&":
|
||||
return "BITAND(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "|":
|
||||
return "BITAND(-%(lhs)s-1,%(rhs)s)+%(lhs)s" % {"lhs": lhs, "rhs": rhs}
|
||||
elif connector == "<<":
|
||||
return "(%(lhs)s * POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
elif connector == ">>":
|
||||
return "FLOOR(%(lhs)s / POWER(2, %(rhs)s))" % {"lhs": lhs, "rhs": rhs}
|
||||
elif connector == "^":
|
||||
return "POWER(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "#":
|
||||
raise NotSupportedError("Bitwise XOR is not supported in Oracle.")
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
|
||||
def _get_no_autofield_sequence_name(self, table):
|
||||
"""
|
||||
Manually created sequence name to keep backward compatibility for
|
||||
AutoFields that aren't Oracle identity columns.
|
||||
"""
|
||||
name_length = self.max_name_length() - 3
|
||||
return "%s_SQ" % truncate_name(strip_quotes(table), name_length).upper()
|
||||
|
||||
def _get_sequence_name(self, cursor, table, pk_name):
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT sequence_name
|
||||
FROM user_tab_identity_cols
|
||||
WHERE table_name = UPPER(%s)
|
||||
AND column_name = UPPER(%s)""",
|
||||
[table, pk_name],
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return self._get_no_autofield_sequence_name(table) if row is None else row[0]
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
field_placeholders = [
|
||||
BulkInsertMapper.types.get(
|
||||
getattr(field, "target_field", field).get_internal_type(), "%s"
|
||||
)
|
||||
for field in fields
|
||||
if field
|
||||
]
|
||||
query = []
|
||||
for row in placeholder_rows:
|
||||
select = []
|
||||
for i, placeholder in enumerate(row):
|
||||
# A model without any fields has fields=[None].
|
||||
if fields[i]:
|
||||
placeholder = field_placeholders[i] % placeholder
|
||||
# Add columns aliases to the first select to avoid "ORA-00918:
|
||||
# column ambiguously defined" when two or more columns in the
|
||||
# first select have the same value.
|
||||
if not query:
|
||||
placeholder = "%s col_%s" % (placeholder, i)
|
||||
select.append(placeholder)
|
||||
suffix = self.connection.features.bare_select_suffix
|
||||
query.append(f"SELECT %s{suffix}" % ", ".join(select))
|
||||
# Bulk insert to tables with Oracle identity columns causes Oracle to
|
||||
# add sequence.nextval to it. Sequence.nextval cannot be used with the
|
||||
# UNION operator. To prevent incorrect SQL, move UNION to a subquery.
|
||||
return "SELECT * FROM (%s)" % " UNION ALL ".join(query)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
if internal_type == "DateField":
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
return (
|
||||
"NUMTODSINTERVAL(TO_NUMBER(%s - %s), 'DAY')" % (lhs_sql, rhs_sql),
|
||||
params,
|
||||
)
|
||||
return super().subtract_temporals(internal_type, lhs, rhs)
|
||||
|
||||
def bulk_batch_size(self, fields, objs):
|
||||
"""Oracle restricts the number of parameters in a query."""
|
||||
fields = list(
|
||||
chain.from_iterable(
|
||||
field.fields if isinstance(field, CompositePrimaryKey) else [field]
|
||||
for field in fields
|
||||
)
|
||||
)
|
||||
if fields:
|
||||
return self.connection.features.max_query_params // len(fields)
|
||||
return len(objs)
|
||||
|
||||
def conditional_expression_supported_in_where_clause(self, expression):
|
||||
"""
|
||||
Oracle supports only EXISTS(...) or filters in the WHERE clause, others
|
||||
must be compared with True.
|
||||
"""
|
||||
if isinstance(expression, (Exists, Lookup, WhereNode)):
|
||||
return True
|
||||
if isinstance(expression, ExpressionWrapper) and expression.conditional:
|
||||
return self.conditional_expression_supported_in_where_clause(
|
||||
expression.expression
|
||||
)
|
||||
if isinstance(expression, RawSQL) and expression.conditional:
|
||||
return True
|
||||
return False
|
||||
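A quick illustration of the connector translation above — a sketch only, not part of the commit, assuming `ops` is an instance of the Oracle operations class shown here:

# Sketch only: how the bitwise connectors above expand into Oracle SQL.
assert ops.combine_expression("<<", ['"n"', "3"]) == '("n" * POWER(2, 3))'
assert ops.combine_expression(">>", ['"n"', "3"]) == 'FLOOR("n" / POWER(2, 3))'
assert ops.combine_expression("&", ['"a"', '"b"']) == 'BITAND("a","b")'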
@@ -0,0 +1,20 @@
import warnings

from django.utils.deprecation import RemovedInDjango60Warning

try:
    import oracledb

    is_oracledb = True
except ImportError as e:
    try:
        import cx_Oracle as oracledb  # NOQA

        warnings.warn(
            "cx_Oracle is deprecated. Use oracledb instead.",
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        is_oracledb = False
    except ImportError:
        raise e from None
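The shim above lets callers stay driver-agnostic; a minimal consumer sketch follows (the import path is an assumption, since file names are not shown in this excerpt):

# Sketch only; the module path below is assumed, not confirmed by this diff.
from django.db.backends.oracle.oracledb_any import is_oracledb, oracledb

driver = "python-oracledb" if is_oracledb else "cx_Oracle (deprecated)"
print(f"Oracle driver in use: {driver}")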
@@ -0,0 +1,252 @@
import copy
import datetime
import re

from django.db import DatabaseError
from django.db.backends.base.schema import (
    BaseDatabaseSchemaEditor,
    _related_non_m2m_objects,
)
from django.utils.duration import duration_iso_string


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    sql_create_column = "ALTER TABLE %(table)s ADD %(column)s %(definition)s"
    sql_alter_column_type = "MODIFY %(column)s %(type)s%(collation)s"
    sql_alter_column_null = "MODIFY %(column)s NULL"
    sql_alter_column_not_null = "MODIFY %(column)s NOT NULL"
    sql_alter_column_default = "MODIFY %(column)s DEFAULT %(default)s"
    sql_alter_column_no_default = "MODIFY %(column)s DEFAULT NULL"
    sql_alter_column_no_default_null = sql_alter_column_no_default

    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
    )
    sql_delete_table = "DROP TABLE %(table)s CASCADE CONSTRAINTS"
    sql_create_index = "CREATE INDEX %(name)s ON %(table)s (%(columns)s)%(extra)s"

    def quote_value(self, value):
        if isinstance(value, (datetime.date, datetime.time, datetime.datetime)):
            return "'%s'" % value
        elif isinstance(value, datetime.timedelta):
            return "'%s'" % duration_iso_string(value)
        elif isinstance(value, str):
            return "'%s'" % value.replace("'", "''")
        elif isinstance(value, (bytes, bytearray, memoryview)):
            return "'%s'" % value.hex()
        elif isinstance(value, bool):
            return "1" if value else "0"
        else:
            return str(value)

    def remove_field(self, model, field):
        # If the column is an identity column, drop the identity before
        # removing the field.
        if self._is_identity_column(model._meta.db_table, field.column):
            self._drop_identity(model._meta.db_table, field.column)
        super().remove_field(model, field)

    def delete_model(self, model):
        # Run the superclass action.
        super().delete_model(model)
        # Clean up the manually created sequence.
        self.execute(
            """
            DECLARE
                i INTEGER;
            BEGIN
                SELECT COUNT(1) INTO i FROM USER_SEQUENCES
                    WHERE SEQUENCE_NAME = '%(sq_name)s';
                IF i = 1 THEN
                    EXECUTE IMMEDIATE 'DROP SEQUENCE "%(sq_name)s"';
                END IF;
            END;
            /"""
            % {
                "sq_name": self.connection.ops._get_no_autofield_sequence_name(
                    model._meta.db_table
                )
            }
        )

    def alter_field(self, model, old_field, new_field, strict=False):
        try:
            super().alter_field(model, old_field, new_field, strict)
        except DatabaseError as e:
            description = str(e)
            # If we're changing type to an unsupported type, we need a
            # SQLite-ish workaround.
            if "ORA-22858" in description or "ORA-22859" in description:
                self._alter_field_type_workaround(model, old_field, new_field)
            # If an identity column is changing to a non-numeric type, drop the
            # identity first.
            elif "ORA-30675" in description:
                self._drop_identity(model._meta.db_table, old_field.column)
                self.alter_field(model, old_field, new_field, strict)
            # If a primary key column is changing to an identity column, drop
            # the primary key first.
            elif "ORA-30673" in description and old_field.primary_key:
                self._delete_primary_key(model, strict=True)
                self._alter_field_type_workaround(model, old_field, new_field)
            # If a collation is changing on a primary key, drop the primary key
            # first.
            elif "ORA-43923" in description and old_field.primary_key:
                self._delete_primary_key(model, strict=True)
                self.alter_field(model, old_field, new_field, strict)
                # Restore a primary key, if needed.
                if new_field.primary_key:
                    self.execute(self._create_primary_key_sql(model, new_field))
            else:
                raise

    def _alter_field_type_workaround(self, model, old_field, new_field):
        """
        Oracle refuses to change certain column types directly from one type
        to another. What we need to do instead is:
        - Add a nullable version of the desired field with a temporary name. If
          the new column is an auto field, then the temporary column can't be
          nullable.
        - Update the table to transfer values from old to new.
        - Drop the old column.
        - Rename the new column and possibly drop the nullable property.
        """
        # Make a new field that's like the new one but with a temporary
        # column name.
        new_temp_field = copy.deepcopy(new_field)
        new_temp_field.null = new_field.get_internal_type() not in (
            "AutoField",
            "BigAutoField",
            "SmallAutoField",
        )
        new_temp_field.column = self._generate_temp_name(new_field.column)
        # Add it
        self.add_field(model, new_temp_field)
        # Explicit data type conversion
        # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf
        # /Data-Type-Comparison-Rules.html#GUID-D0C5A47E-6F93-4C2D-9E49-4F2B86B359DD
        new_value = self.quote_name(old_field.column)
        old_type = old_field.db_type(self.connection)
        if re.match("^N?CLOB", old_type):
            new_value = "TO_CHAR(%s)" % new_value
            old_type = "VARCHAR2"
        if re.match("^N?VARCHAR2", old_type):
            new_internal_type = new_field.get_internal_type()
            if new_internal_type == "DateField":
                new_value = "TO_DATE(%s, 'YYYY-MM-DD')" % new_value
            elif new_internal_type == "DateTimeField":
                new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
            elif new_internal_type == "TimeField":
                # TimeFields are stored as TIMESTAMP with a 1900-01-01 date part.
                new_value = "CONCAT('1900-01-01 ', %s)" % new_value
                new_value = "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" % new_value
        # Transfer values across
        self.execute(
            "UPDATE %s set %s=%s"
            % (
                self.quote_name(model._meta.db_table),
                self.quote_name(new_temp_field.column),
                new_value,
            )
        )
        # Drop the old field
        self.remove_field(model, old_field)
        # Rename and possibly make the new field NOT NULL
        super().alter_field(model, new_temp_field, new_field)
        # Recreate foreign key (if necessary) because the old field is not
        # passed to the alter_field() and data types of new_temp_field and
        # new_field always match.
        new_type = new_field.db_type(self.connection)
        if (
            (old_field.primary_key and new_field.primary_key)
            or (old_field.unique and new_field.unique)
        ) and old_type != new_type:
            for _, rel in _related_non_m2m_objects(new_temp_field, new_field):
                if rel.field.db_constraint:
                    self.execute(
                        self._create_fk_sql(rel.related_model, rel.field, "_fk")
                    )

    def _alter_column_type_sql(
        self, model, old_field, new_field, new_type, old_collation, new_collation
    ):
        auto_field_types = {"AutoField", "BigAutoField", "SmallAutoField"}
        # Drop the identity if migrating away from AutoField.
        if (
            old_field.get_internal_type() in auto_field_types
            and new_field.get_internal_type() not in auto_field_types
            and self._is_identity_column(model._meta.db_table, new_field.column)
        ):
            self._drop_identity(model._meta.db_table, new_field.column)
        return super()._alter_column_type_sql(
            model, old_field, new_field, new_type, old_collation, new_collation
        )

    def normalize_name(self, name):
        """
        Get the properly shortened and uppercased identifier as returned by
        quote_name() but without the quotes.
        """
        nn = self.quote_name(name)
        if nn[0] == '"' and nn[-1] == '"':
            nn = nn[1:-1]
        return nn

    def _generate_temp_name(self, for_name):
        """Generate temporary names for workarounds that need temp columns."""
        suffix = hex(hash(for_name)).upper()[1:]
        return self.normalize_name(for_name + "_" + suffix)

    def prepare_default(self, value):
        return self.quote_value(value)

    def _field_should_be_indexed(self, model, field):
        create_index = super()._field_should_be_indexed(model, field)
        db_type = field.db_type(self.connection)
        if (
            db_type is not None
            and db_type.lower() in self.connection._limited_data_types
        ):
            return False
        return create_index

    def _is_identity_column(self, table_name, column_name):
        if not column_name:
            return False
        with self.connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    CASE WHEN identity_column = 'YES' THEN 1 ELSE 0 END
                FROM user_tab_cols
                WHERE table_name = %s AND
                      column_name = %s
                """,
                [self.normalize_name(table_name), self.normalize_name(column_name)],
            )
            row = cursor.fetchone()
            return row[0] if row else False

    def _drop_identity(self, table_name, column_name):
        self.execute(
            "ALTER TABLE %(table)s MODIFY %(column)s DROP IDENTITY"
            % {
                "table": self.quote_name(table_name),
                "column": self.quote_name(column_name),
            }
        )

    def _get_default_collation(self, table_name):
        with self.connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT default_collation FROM user_tables WHERE table_name = %s
                """,
                [self.normalize_name(table_name)],
            )
            return cursor.fetchone()[0]

    def _collate_sql(self, collation, old_collation=None, table_name=None):
        if collation is None and old_collation is not None:
            collation = self._get_default_collation(table_name)
        return super()._collate_sql(collation, old_collation, table_name)
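For reference, the literal quoting implemented by quote_value() above behaves as follows — a sketch only, with `editor` assumed to be a DatabaseSchemaEditor bound to an Oracle connection:

# Sketch only: expected results from quote_value() as defined above.
assert editor.quote_value("O'Reilly") == "'O''Reilly'"  # single quotes doubled
assert editor.quote_value(True) == "1"                  # booleans become 0/1
assert editor.quote_value(b"\x00\xff") == "'00ff'"      # bytes hex-encoded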
@@ -0,0 +1,99 @@
import datetime
import decimal

from .base import Database


class InsertVar:
    """
    A late-binding cursor variable that can be passed to Cursor.execute
    as a parameter, in order to receive the id of the row created by an
    insert statement.
    """

    types = {
        "AutoField": int,
        "BigAutoField": int,
        "SmallAutoField": int,
        "IntegerField": int,
        "BigIntegerField": int,
        "SmallIntegerField": int,
        "PositiveBigIntegerField": int,
        "PositiveSmallIntegerField": int,
        "PositiveIntegerField": int,
        "BooleanField": int,
        "FloatField": Database.DB_TYPE_BINARY_DOUBLE,
        "DateTimeField": Database.DB_TYPE_TIMESTAMP,
        "DateField": Database.Date,
        "DecimalField": decimal.Decimal,
    }

    def __init__(self, field):
        internal_type = getattr(field, "target_field", field).get_internal_type()
        self.db_type = self.types.get(internal_type, str)
        self.bound_param = None

    def bind_parameter(self, cursor):
        self.bound_param = cursor.cursor.var(self.db_type)
        return self.bound_param

    def get_value(self):
        return self.bound_param.getvalue()


class Oracle_datetime(datetime.datetime):
    """
    A datetime object, with an additional class attribute
    to tell oracledb to save the microseconds too.
    """

    input_size = Database.DB_TYPE_TIMESTAMP

    @classmethod
    def from_datetime(cls, dt):
        return Oracle_datetime(
            dt.year,
            dt.month,
            dt.day,
            dt.hour,
            dt.minute,
            dt.second,
            dt.microsecond,
        )


class BulkInsertMapper:
    BLOB = "TO_BLOB(%s)"
    DATE = "TO_DATE(%s)"
    INTERVAL = "CAST(%s as INTERVAL DAY(9) TO SECOND(6))"
    NCLOB = "TO_NCLOB(%s)"
    NUMBER = "TO_NUMBER(%s)"
    TIMESTAMP = "TO_TIMESTAMP(%s)"

    types = {
        "AutoField": NUMBER,
        "BigAutoField": NUMBER,
        "BigIntegerField": NUMBER,
        "BinaryField": BLOB,
        "BooleanField": NUMBER,
        "DateField": DATE,
        "DateTimeField": TIMESTAMP,
        "DecimalField": NUMBER,
        "DurationField": INTERVAL,
        "FloatField": NUMBER,
        "IntegerField": NUMBER,
        "PositiveBigIntegerField": NUMBER,
        "PositiveIntegerField": NUMBER,
        "PositiveSmallIntegerField": NUMBER,
        "SmallAutoField": NUMBER,
        "SmallIntegerField": NUMBER,
        "TextField": NCLOB,
        "TimeField": TIMESTAMP,
    }


def dsn(settings_dict):
    if settings_dict["PORT"]:
        host = settings_dict["HOST"].strip() or "localhost"
        return Database.makedsn(host, int(settings_dict["PORT"]), settings_dict["NAME"])
    return settings_dict["NAME"]
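A sketch of how dsn() above resolves its two branches (illustrative values; assumes the Database module is the installed Oracle driver):

# Sketch only: with a PORT, a full connect descriptor is built via makedsn();
# without one, NAME is returned untouched (e.g. a TNS alias).
print(dsn({"HOST": "", "PORT": "1521", "NAME": "XEPDB1"}))  # descriptor for localhost:1521
print(dsn({"HOST": "", "PORT": "", "NAME": "prod_alias"}))  # -> 'prod_alias'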
@@ -0,0 +1,22 @@
from django.core import checks
from django.db.backends.base.validation import BaseDatabaseValidation


class DatabaseValidation(BaseDatabaseValidation):
    def check_field_type(self, field, field_type):
        """Oracle doesn't support a database index on some data types."""
        errors = []
        if field.db_index and field_type.lower() in self.connection._limited_data_types:
            errors.append(
                checks.Warning(
                    "Oracle does not support a database index on %s columns."
                    % field_type,
                    hint=(
                        "An index won't be created. Silence this warning if "
                        "you don't care about it."
                    ),
                    obj=field,
                    id="fields.W162",
                )
            )
        return errors
@@ -0,0 +1,611 @@
"""
PostgreSQL database backend for Django.

Requires psycopg2 >= 2.8.4 or psycopg >= 3.1.8
"""

import asyncio
import threading
import warnings
from contextlib import contextmanager

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.db import DatabaseError as WrappedDatabaseError
from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
from django.utils.asyncio import async_unsafe
from django.utils.functional import cached_property
from django.utils.safestring import SafeString
from django.utils.version import get_version_tuple

try:
    try:
        import psycopg as Database
    except ImportError:
        import psycopg2 as Database
except ImportError:
    raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")


def psycopg_version():
    version = Database.__version__.split(" ", 1)[0]
    return get_version_tuple(version)


if psycopg_version() < (2, 8, 4):
    raise ImproperlyConfigured(
        f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
    )
if (3,) <= psycopg_version() < (3, 1, 8):
    raise ImproperlyConfigured(
        f"psycopg version 3.1.8 or newer is required; you have {Database.__version__}"
    )


from .psycopg_any import IsolationLevel, is_psycopg3  # NOQA isort:skip

if is_psycopg3:
    from psycopg import adapters, sql
    from psycopg.pq import Format

    from .psycopg_any import get_adapters_template, register_tzloader

    TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid

else:
    import psycopg2.extensions
    import psycopg2.extras

    psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
    psycopg2.extras.register_uuid()

    # Register support for inet[] manually so we don't have to handle the Inet()
    # object on load all the time.
    INETARRAY_OID = 1041
    INETARRAY = psycopg2.extensions.new_array_type(
        (INETARRAY_OID,),
        "INETARRAY",
        psycopg2.extensions.UNICODE,
    )
    psycopg2.extensions.register_type(INETARRAY)

# Some of these import psycopg, so import them after checking if it's installed.
from .client import DatabaseClient  # NOQA isort:skip
from .creation import DatabaseCreation  # NOQA isort:skip
from .features import DatabaseFeatures  # NOQA isort:skip
from .introspection import DatabaseIntrospection  # NOQA isort:skip
from .operations import DatabaseOperations  # NOQA isort:skip
from .schema import DatabaseSchemaEditor  # NOQA isort:skip


def _get_varchar_column(data):
    if data["max_length"] is None:
        return "varchar"
    return "varchar(%(max_length)s)" % data


class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = "postgresql"
    display_name = "PostgreSQL"
    # This dictionary maps Field objects to their associated PostgreSQL column
    # types, as strings. Column-type strings can contain format strings; they'll
    # be interpolated against the values of Field.__dict__ before being output.
    # If a column type is set to None, it won't be included in the output.
    data_types = {
        "AutoField": "integer",
        "BigAutoField": "bigint",
        "BinaryField": "bytea",
        "BooleanField": "boolean",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "timestamp with time zone",
        "DecimalField": "numeric(%(max_digits)s, %(decimal_places)s)",
        "DurationField": "interval",
        "FileField": "varchar(%(max_length)s)",
        "FilePathField": "varchar(%(max_length)s)",
        "FloatField": "double precision",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "inet",
        "GenericIPAddressField": "inet",
        "JSONField": "jsonb",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint",
        "PositiveIntegerField": "integer",
        "PositiveSmallIntegerField": "smallint",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "smallint",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "uuid",
    }
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    data_types_suffix = {
        "AutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "BigAutoField": "GENERATED BY DEFAULT AS IDENTITY",
        "SmallAutoField": "GENERATED BY DEFAULT AS IDENTITY",
    }
    operators = {
        "exact": "= %s",
        "iexact": "= UPPER(%s)",
        "contains": "LIKE %s",
        "icontains": "LIKE UPPER(%s)",
        "regex": "~ %s",
        "iregex": "~* %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s",
        "endswith": "LIKE %s",
        "istartswith": "LIKE UPPER(%s)",
        "iendswith": "LIKE UPPER(%s)",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on the database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = (
        r"REPLACE(REPLACE(REPLACE({}, E'\\', E'\\\\'), E'%%', E'\\%%'), E'_', E'\\_')"
    )
    pattern_ops = {
        "contains": "LIKE '%%' || {} || '%%'",
        "icontains": "LIKE '%%' || UPPER({}) || '%%'",
        "startswith": "LIKE {} || '%%'",
        "istartswith": "LIKE UPPER({}) || '%%'",
        "endswith": "LIKE '%%' || {}",
        "iendswith": "LIKE '%%' || UPPER({})",
    }

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations
    # PostgreSQL backend-specific attributes.
    _named_cursor_idx = 0
    _connection_pools = {}

    @property
    def pool(self):
        pool_options = self.settings_dict["OPTIONS"].get("pool")
        if self.alias == NO_DB_ALIAS or not pool_options:
            return None

        if self.alias not in self._connection_pools:
            if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
                raise ImproperlyConfigured(
                    "Pooling doesn't support persistent connections."
                )
            # Set the default options.
            if pool_options is True:
                pool_options = {}

            try:
                from psycopg_pool import ConnectionPool
            except ImportError as err:
                raise ImproperlyConfigured(
                    "Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
                ) from err

            connect_kwargs = self.get_connection_params()
            # Ensure we run in autocommit; Django properly sets it later on.
            connect_kwargs["autocommit"] = True
            enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
            pool = ConnectionPool(
                kwargs=connect_kwargs,
                open=False,  # Do not open the pool during startup.
                configure=self._configure_connection,
                check=ConnectionPool.check_connection if enable_checks else None,
                **pool_options,
            )
            # setdefault() ensures that multiple threads don't set this in
            # parallel. Since we do not open the pool during its init above,
            # this means that at worst during startup multiple threads generate
            # pool objects and the first to set it wins.
            self._connection_pools.setdefault(self.alias, pool)

        return self._connection_pools[self.alias]

    def close_pool(self):
        if self.pool:
            self.pool.close()
            del self._connection_pools[self.alias]

    def get_database_version(self):
        """
        Return a tuple of the database's version.
        E.g. for pg_version 120004, return (12, 4).
        """
        return divmod(self.pg_version, 10000)

    def get_connection_params(self):
        settings_dict = self.settings_dict
        # None may be used to connect to the default 'postgres' db.
        if settings_dict["NAME"] == "" and not settings_dict["OPTIONS"].get("service"):
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME or OPTIONS['service'] value."
            )
        if len(settings_dict["NAME"] or "") > self.ops.max_name_length():
            raise ImproperlyConfigured(
                "The database name '%s' (%d characters) is longer than "
                "PostgreSQL's limit of %d characters. Supply a shorter NAME "
                "in settings.DATABASES."
                % (
                    settings_dict["NAME"],
                    len(settings_dict["NAME"]),
                    self.ops.max_name_length(),
                )
            )
        if settings_dict["NAME"]:
            conn_params = {
                "dbname": settings_dict["NAME"],
                **settings_dict["OPTIONS"],
            }
        elif settings_dict["NAME"] is None:
            # Connect to the default 'postgres' db.
            settings_dict["OPTIONS"].pop("service", None)
            conn_params = {"dbname": "postgres", **settings_dict["OPTIONS"]}
        else:
            conn_params = {**settings_dict["OPTIONS"]}
        conn_params["client_encoding"] = "UTF8"

        conn_params.pop("assume_role", None)
        conn_params.pop("isolation_level", None)

        pool_options = conn_params.pop("pool", None)
        if pool_options and not is_psycopg3:
            raise ImproperlyConfigured("Database pooling requires psycopg >= 3")

        server_side_binding = conn_params.pop("server_side_binding", None)
        conn_params.setdefault(
            "cursor_factory",
            (
                ServerBindingCursor
                if is_psycopg3 and server_side_binding is True
                else Cursor
            ),
        )
        if settings_dict["USER"]:
            conn_params["user"] = settings_dict["USER"]
        if settings_dict["PASSWORD"]:
            conn_params["password"] = settings_dict["PASSWORD"]
        if settings_dict["HOST"]:
            conn_params["host"] = settings_dict["HOST"]
        if settings_dict["PORT"]:
            conn_params["port"] = settings_dict["PORT"]
        if is_psycopg3:
            conn_params["context"] = get_adapters_template(
                settings.USE_TZ, self.timezone
            )
            # Disable prepared statements by default to keep connection poolers
            # working. Can be re-enabled via OPTIONS in the settings dict.
            conn_params["prepare_threshold"] = conn_params.pop(
                "prepare_threshold", None
            )
        return conn_params

    @async_unsafe
    def get_new_connection(self, conn_params):
        # self.isolation_level must be set:
        # - after connecting to the database in order to obtain the database's
        #   default when no value is explicitly specified in options.
        # - before calling _set_autocommit() because if autocommit is on, that
        #   will set connection.isolation_level to ISOLATION_LEVEL_AUTOCOMMIT.
        options = self.settings_dict["OPTIONS"]
        set_isolation_level = False
        try:
            isolation_level_value = options["isolation_level"]
        except KeyError:
            self.isolation_level = IsolationLevel.READ_COMMITTED
        else:
            # Set the isolation level to the value from OPTIONS.
            try:
                self.isolation_level = IsolationLevel(isolation_level_value)
                set_isolation_level = True
            except ValueError:
                raise ImproperlyConfigured(
                    f"Invalid transaction isolation level {isolation_level_value} "
                    f"specified. Use one of the psycopg.IsolationLevel values."
                )
        if self.pool:
            # If nothing else has opened the pool, open it now.
            self.pool.open()
            connection = self.pool.getconn()
        else:
            connection = self.Database.connect(**conn_params)
        if set_isolation_level:
            connection.isolation_level = self.isolation_level
        if not is_psycopg3:
            # Register dummy loads() to avoid a round trip from psycopg2's
            # decode to json.dumps() to json.loads(), when using a custom
            # decoder in JSONField.
            psycopg2.extras.register_default_jsonb(
                conn_or_curs=connection, loads=lambda x: x
            )
        return connection

    def ensure_timezone(self):
        # Close the pool so new connections pick up the correct timezone.
        self.close_pool()
        if self.connection is None:
            return False
        return self._configure_timezone(self.connection)

    def _configure_timezone(self, connection):
        conn_timezone_name = connection.info.parameter_status("TimeZone")
        timezone_name = self.timezone_name
        if timezone_name and conn_timezone_name != timezone_name:
            with connection.cursor() as cursor:
                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
            return True
        return False

    def _configure_role(self, connection):
        if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
            with connection.cursor() as cursor:
                sql = self.ops.compose_sql("SET ROLE %s", [new_role])
                cursor.execute(sql)
            return True
        return False

    def _configure_connection(self, connection):
        # This function is called from init_connection_state and from the
        # psycopg pool itself after a connection is opened.

        # Commit after setting the time zone.
        commit_tz = self._configure_timezone(connection)
        # Set the role on the connection. This is useful if the credential used
        # to log in is not the same as the role that owns database resources,
        # as can be the case when using temporary or ephemeral credentials.
        commit_role = self._configure_role(connection)

        return commit_role or commit_tz

    def _close(self):
        if self.connection is not None:
            # `wrap_database_errors` only works for `putconn` as long as there
            # is no `reset` function set in the pool because it is deferred
            # into a thread and not directly executed.
            with self.wrap_database_errors:
                if self.pool:
                    # Ensure the correct pool is returned. This is a workaround
                    # for tests so a pool can be changed on setting changes
                    # (e.g. USE_TZ, TIME_ZONE).
                    self.connection._pool.putconn(self.connection)
                    # Connection can no longer be used.
                    self.connection = None
                else:
                    return self.connection.close()

    def init_connection_state(self):
        super().init_connection_state()

        if self.connection is not None and not self.pool:
            commit = self._configure_connection(self.connection)

            if commit and not self.get_autocommit():
                self.connection.commit()

    @async_unsafe
    def create_cursor(self, name=None):
        if name:
            if is_psycopg3 and (
                self.settings_dict["OPTIONS"].get("server_side_binding") is not True
            ):
                # psycopg >= 3 forces the usage of server-side bindings for
                # named cursors so a specialized class that implements
                # server-side cursors while performing client-side bindings
                # must be used if `server_side_binding` is disabled (default).
                cursor = ServerSideCursor(
                    self.connection,
                    name=name,
                    scrollable=False,
                    withhold=self.connection.autocommit,
                )
            else:
                # In autocommit mode, the cursor will be used outside of a
                # transaction, hence use a holdable cursor.
                cursor = self.connection.cursor(
                    name, scrollable=False, withhold=self.connection.autocommit
                )
        else:
            cursor = self.connection.cursor()

        if is_psycopg3:
            # Register the cursor timezone only if the connection disagrees, to
            # avoid copying the adapter map.
            tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
            if self.timezone != tzloader.timezone:
                register_tzloader(self.timezone, cursor)
        else:
            cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
        return cursor

    def tzinfo_factory(self, offset):
        return self.timezone

    @async_unsafe
    def chunked_cursor(self):
        self._named_cursor_idx += 1
        # Get the current async task.
        # Note that right now this is behind @async_unsafe, so this is
        # unreachable, but in future we'll start loosening this restriction.
        # For now, it's here so that every use of "threading" is
        # also async-compatible.
        try:
            current_task = asyncio.current_task()
        except RuntimeError:
            current_task = None
        # The current task can be None even if the current_task call didn't error.
        if current_task:
            task_ident = str(id(current_task))
        else:
            task_ident = "sync"
        # Use that and the thread ident to get a unique name.
        return self._cursor(
            name="_django_curs_%d_%s_%d"
            % (
                # Avoid reusing name in other threads / tasks
                threading.current_thread().ident,
                task_ident,
                self._named_cursor_idx,
            )
        )

    def _set_autocommit(self, autocommit):
        with self.wrap_database_errors:
            self.connection.autocommit = autocommit

    def check_constraints(self, table_names=None):
        """
        Check constraints by setting them to immediate. Return them to deferred
        afterward.
        """
        with self.cursor() as cursor:
            cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")
            cursor.execute("SET CONSTRAINTS ALL DEFERRED")

    def is_usable(self):
        if self.connection is None:
            return False
        try:
            # Use a psycopg cursor directly, bypassing Django's utilities.
            with self.connection.cursor() as cursor:
                cursor.execute("SELECT 1")
        except Database.Error:
            return False
        else:
            return True

    def close_if_health_check_failed(self):
        if self.pool:
            # The pool only returns healthy connections.
            return
        return super().close_if_health_check_failed()

    @contextmanager
    def _nodb_cursor(self):
        cursor = None
        try:
            with super()._nodb_cursor() as cursor:
                yield cursor
        except (Database.DatabaseError, WrappedDatabaseError):
            if cursor is not None:
                raise
            warnings.warn(
                "Normally Django will use a connection to the 'postgres' database "
                "to avoid running initialization queries against the production "
                "database when it's not needed (for example, when running tests). "
                "Django was unable to create a connection to the 'postgres' database "
                "and will use the first PostgreSQL database instead.",
                RuntimeWarning,
            )
            for connection in connections.all():
                if (
                    connection.vendor == "postgresql"
                    and connection.settings_dict["NAME"] != "postgres"
                ):
                    conn = self.__class__(
                        {
                            **self.settings_dict,
                            "NAME": connection.settings_dict["NAME"],
                        },
                        alias=self.alias,
                    )
                    try:
                        with conn.cursor() as cursor:
                            yield cursor
                    finally:
                        conn.close()
                    break
            else:
                raise

    @cached_property
    def pg_version(self):
        with self.temporary_connection():
            return self.connection.info.server_version

    def make_debug_cursor(self, cursor):
        return CursorDebugWrapper(cursor, self)


if is_psycopg3:

    class CursorMixin:
        """
        A subclass of psycopg cursor implementing callproc.
        """

        def callproc(self, name, args=None):
            if not isinstance(name, sql.Identifier):
                name = sql.Identifier(name)

            qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
            if args:
                for item in args:
                    qparts.append(sql.Literal(item))
                    qparts.append(sql.SQL(","))
                del qparts[-1]

            qparts.append(sql.SQL(")"))
            stmt = sql.Composed(qparts)
            self.execute(stmt)
            return args

    class ServerBindingCursor(CursorMixin, Database.Cursor):
        pass

    class Cursor(CursorMixin, Database.ClientCursor):
        pass

    class ServerSideCursor(
        CursorMixin, Database.client_cursor.ClientCursorMixin, Database.ServerCursor
    ):
        """
        psycopg >= 3 forces the usage of server-side bindings when using named
        cursors but the ORM doesn't yet support the systematic generation of
        prepareable SQL (#20516).

        ClientCursorMixin forces the usage of client-side bindings while
        ServerCursor implements the logic required to declare and scroll
        through named cursors.

        Mixing ClientCursorMixin in wouldn't be necessary if Cursor allowed to
        specify how parameters should be bound instead, which ServerCursor
        would inherit, but that's not the case.
        """

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        def copy(self, statement):
            with self.debug_sql(statement):
                return self.cursor.copy(statement)

else:
    Cursor = psycopg2.extensions.cursor

    class CursorDebugWrapper(BaseCursorDebugWrapper):
        def copy_expert(self, sql, file, *args):
            with self.debug_sql(sql):
                return self.cursor.copy_expert(sql, file, *args)

        def copy_to(self, file, table, *args, **kwargs):
            with self.debug_sql(sql="COPY %s TO STDOUT" % table):
                return self.cursor.copy_to(file, table, *args, **kwargs)
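For reference, a settings sketch that exercises the `pool` property above (illustrative values, not part of the commit; requires psycopg >= 3 installed with the pool extra):

# Sketch only: example DATABASES entry enabling the connection pool.
DATABASES = {
    "default": {
        "ENGINE": "django.db.backends.postgresql",
        "NAME": "mydb",
        "OPTIONS": {
            # Forwarded as ConnectionPool(**pool_options) above; use True to
            # accept the pool's defaults instead of a dict.
            "pool": {"min_size": 2, "max_size": 8},
        },
        # The property above raises ImproperlyConfigured for nonzero values.
        "CONN_MAX_AGE": 0,
    }
}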
@@ -0,0 +1,64 @@
import signal

from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = "psql"

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name]
        options = settings_dict["OPTIONS"]

        host = settings_dict.get("HOST")
        port = settings_dict.get("PORT")
        dbname = settings_dict.get("NAME")
        user = settings_dict.get("USER")
        passwd = settings_dict.get("PASSWORD")
        passfile = options.get("passfile")
        service = options.get("service")
        sslmode = options.get("sslmode")
        sslrootcert = options.get("sslrootcert")
        sslcert = options.get("sslcert")
        sslkey = options.get("sslkey")

        if not dbname and not service:
            # Connect to the default 'postgres' db.
            dbname = "postgres"
        if user:
            args += ["-U", user]
        if host:
            args += ["-h", host]
        if port:
            args += ["-p", str(port)]
        args.extend(parameters)
        if dbname:
            args += [dbname]

        env = {}
        if passwd:
            env["PGPASSWORD"] = str(passwd)
        if service:
            env["PGSERVICE"] = str(service)
        if sslmode:
            env["PGSSLMODE"] = str(sslmode)
        if sslrootcert:
            env["PGSSLROOTCERT"] = str(sslrootcert)
        if sslcert:
            env["PGSSLCERT"] = str(sslcert)
        if sslkey:
            env["PGSSLKEY"] = str(sslkey)
        if passfile:
            env["PGPASSFILE"] = str(passfile)
        return args, (env or None)

    def runshell(self, parameters):
        sigint_handler = signal.getsignal(signal.SIGINT)
        try:
            # Allow SIGINT to pass to psql to abort queries.
            signal.signal(signal.SIGINT, signal.SIG_IGN)
            super().runshell(parameters)
        finally:
            # Restore the original SIGINT handler.
            signal.signal(signal.SIGINT, sigint_handler)
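A sketch of the args/env mapping produced by settings_to_cmd_args_env() above (illustrative values):

args, env = DatabaseClient.settings_to_cmd_args_env(
    {
        "NAME": "mydb",
        "USER": "alice",
        "PASSWORD": "s3cret",
        "HOST": "localhost",
        "PORT": "5432",
        "OPTIONS": {},
    },
    [],
)
assert args == ["psql", "-U", "alice", "-h", "localhost", "-p", "5432", "mydb"]
assert env == {"PGPASSWORD": "s3cret"}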
@@ -0,0 +1,50 @@
from django.db.models.sql.compiler import (
    SQLAggregateCompiler,
    SQLCompiler,
    SQLDeleteCompiler,
)
from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler
from django.db.models.sql.compiler import SQLUpdateCompiler

__all__ = [
    "SQLAggregateCompiler",
    "SQLCompiler",
    "SQLDeleteCompiler",
    "SQLInsertCompiler",
    "SQLUpdateCompiler",
]


class InsertUnnest(list):
    """
    Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the
    UNNEST strategy should be used for the bulk insert.
    """

    def __str__(self):
        return "UNNEST(%s)" % ", ".join(self)


class SQLInsertCompiler(BaseSQLInsertCompiler):
    def assemble_as_sql(self, fields, value_rows):
        # Specialize bulk-insertion of literal non-array values through
        # UNNEST to reduce the time spent planning the query.
        if (
            # The optimization is not worth doing if there is a single
            # row as it will result in the same number of placeholders.
            len(value_rows) <= 1
            # Lack of fields denotes the usage of the DEFAULT keyword
            # for the insertion of empty rows.
            or any(field is None for field in fields)
            # Compilables cannot be combined in an array of literal values.
            or any(any(hasattr(value, "as_sql") for value in row) for row in value_rows)
        ):
            return super().assemble_as_sql(fields, value_rows)
        db_types = [field.db_type(self.connection) for field in fields]
        # Abort if any of the fields are arrays, as UNNEST indiscriminately
        # flattens them instead of reducing their nesting by one.
        if any(db_type.endswith("]") for db_type in db_types):
            return super().assemble_as_sql(fields, value_rows)
        return InsertUnnest(["(%%s)::%s[]" % db_type for db_type in db_types]), [
            list(map(list, zip(*value_rows)))
        ]
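For reference, the sentinel's string form, which the backend's bulk_insert_sql() splices into the INSERT statement — a sketch only:

unnest = InsertUnnest(["(%s)::integer[]", "(%s)::text[]"])
assert str(unnest) == "UNNEST((%s)::integer[], (%s)::text[])"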
@@ -0,0 +1,91 @@
import sys

from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.creation import BaseDatabaseCreation
from django.db.backends.postgresql.psycopg_any import errors
from django.db.backends.utils import strip_quotes


class DatabaseCreation(BaseDatabaseCreation):
    def _quote_name(self, name):
        return self.connection.ops.quote_name(name)

    def _get_database_create_suffix(self, encoding=None, template=None):
        suffix = ""
        if encoding:
            suffix += " ENCODING '{}'".format(encoding)
        if template:
            suffix += " TEMPLATE {}".format(self._quote_name(template))
        return suffix and "WITH" + suffix

    def sql_table_creation_suffix(self):
        test_settings = self.connection.settings_dict["TEST"]
        if test_settings.get("COLLATION") is not None:
            raise ImproperlyConfigured(
                "PostgreSQL does not support collation setting at database "
                "creation time."
            )
        return self._get_database_create_suffix(
            encoding=test_settings["CHARSET"],
            template=test_settings.get("TEMPLATE"),
        )

    def _database_exists(self, cursor, database_name):
        cursor.execute(
            "SELECT 1 FROM pg_catalog.pg_database WHERE datname = %s",
            [strip_quotes(database_name)],
        )
        return cursor.fetchone() is not None

    def _execute_create_test_db(self, cursor, parameters, keepdb=False):
        try:
            if keepdb and self._database_exists(cursor, parameters["dbname"]):
                # If the database should be kept and it already exists, don't
                # try to create a new one.
                return
            super()._execute_create_test_db(cursor, parameters, keepdb)
        except Exception as e:
            if not isinstance(e.__cause__, errors.DuplicateDatabase):
                # All errors except "database already exists" cancel tests.
                self.log("Got an error creating the test database: %s" % e)
                sys.exit(2)
            elif not keepdb:
                # If the database should be kept, ignore "database already
                # exists".
                raise

    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
        # to the template database.
        self.connection.close()
        self.connection.close_pool()

        source_database_name = self.connection.settings_dict["NAME"]
        target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
        test_db_params = {
            "dbname": self._quote_name(target_database_name),
            "suffix": self._get_database_create_suffix(template=source_database_name),
        }
        with self._nodb_cursor() as cursor:
            try:
                self._execute_create_test_db(cursor, test_db_params, keepdb)
            except Exception:
                try:
                    if verbosity >= 1:
                        self.log(
                            "Destroying old test database for alias %s..."
                            % (
                                self._get_database_display_str(
                                    verbosity, target_database_name
                                ),
                            )
                        )
                    cursor.execute("DROP DATABASE %(dbname)s" % test_db_params)
                    self._execute_create_test_db(cursor, test_db_params, keepdb)
                except Exception as e:
                    self.log("Got an error cloning the test database: %s" % e)
                    sys.exit(2)

    def _destroy_test_db(self, test_database_name, verbosity):
        self.connection.close_pool()
        return super()._destroy_test_db(test_database_name, verbosity)
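A sketch of _get_database_create_suffix() above (quoting comes from the connection's quote_name(), i.e. double quotes on PostgreSQL; `creation` is assumed to be a DatabaseCreation instance):

assert creation._get_database_create_suffix(encoding="UTF8", template="template0") == (
    "WITH ENCODING 'UTF8' TEMPLATE \"template0\""
)
assert creation._get_database_create_suffix() == ""  # falsy when nothing is set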
@@ -0,0 +1,170 @@
import operator

from django.db import DataError, InterfaceError
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.backends.postgresql.psycopg_any import is_psycopg3
from django.utils.functional import cached_property


class DatabaseFeatures(BaseDatabaseFeatures):
    minimum_database_version = (14,)
    allows_group_by_selected_pks = True
    can_return_columns_from_insert = True
    can_return_rows_from_bulk_insert = True
    has_real_datatype = True
    has_native_uuid_field = True
    has_native_duration_field = True
    has_native_json_field = True
    can_defer_constraint_checks = True
    has_select_for_update = True
    has_select_for_update_nowait = True
    has_select_for_update_of = True
    has_select_for_update_skip_locked = True
    has_select_for_no_key_update = True
    can_release_savepoints = True
    supports_comments = True
    supports_tablespaces = True
    supports_transactions = True
    can_introspect_materialized_views = True
    can_distinct_on_fields = True
    can_rollback_ddl = True
    schema_editor_uses_clientside_param_binding = True
    supports_combined_alters = True
    nulls_order_largest = True
    closed_cursor_error_class = InterfaceError
    greatest_least_ignores_nulls = True
    can_clone_databases = True
    supports_temporal_subtraction = True
    supports_slicing_ordering_in_compound = True
    create_test_procedure_without_params_sql = """
        CREATE FUNCTION test_procedure () RETURNS void AS $$
            DECLARE
                V_I INTEGER;
            BEGIN
                V_I := 1;
            END;
        $$ LANGUAGE plpgsql;"""
    create_test_procedure_with_int_param_sql = """
        CREATE FUNCTION test_procedure (P_I INTEGER) RETURNS void AS $$
            DECLARE
                V_I INTEGER;
            BEGIN
                V_I := P_I;
            END;
        $$ LANGUAGE plpgsql;"""
    create_test_table_with_composite_primary_key = """
        CREATE TABLE test_table_composite_pk (
            column_1 INTEGER NOT NULL,
            column_2 INTEGER NOT NULL,
            PRIMARY KEY(column_1, column_2)
        )
    """
    requires_casted_case_in_updates = True
    supports_over_clause = True
    supports_frame_exclusion = True
    only_supports_unbounded_with_preceding_and_following = True
    supports_aggregate_filter_clause = True
    supported_explain_formats = {"JSON", "TEXT", "XML", "YAML"}
    supports_deferrable_unique_constraints = True
    has_json_operators = True
    json_key_contains_list_matching_requires_list = True
    supports_update_conflicts = True
    supports_update_conflicts_with_target = True
    supports_covering_indexes = True
    supports_stored_generated_columns = True
    supports_virtual_generated_columns = False
    can_rename_index = True
    test_collations = {
        "deterministic": "C",
        "non_default": "sv-x-icu",
        "swedish_ci": "sv-x-icu",
        "virtual": "sv-x-icu",
    }
    test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
    insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES"

    @cached_property
    def django_test_skips(self):
        skips = {
            "opclasses are PostgreSQL only.": {
                "indexes.tests.SchemaIndexesNotPostgreSQLTests."
                "test_create_index_ignores_opclasses",
            },
            "PostgreSQL requires casting to text.": {
                "lookup.tests.LookupTests.test_textfield_exact_null",
            },
        }
        if self.connection.settings_dict["OPTIONS"].get("pool"):
            skips.update(
                {
                    "Pool does implicit health checks": {
                        "backends.base.test_base.ConnectionHealthChecksTests."
                        "test_health_checks_enabled",
                        "backends.base.test_base.ConnectionHealthChecksTests."
                        "test_set_autocommit_health_checks_enabled",
                    },
                }
            )
        if self.uses_server_side_binding:
            skips.update(
                {
                    "The actual query cannot be determined for server side bindings": {
                        "backends.base.test_base.ExecuteWrapperTests."
                        "test_wrapper_debug",
                    }
                },
            )
        return skips

    @cached_property
    def django_test_expected_failures(self):
        expected_failures = set()
        if self.uses_server_side_binding:
            expected_failures.update(
                {
                    # Parameters passed to expressions in SELECT and GROUP BY
                    # clauses are not recognized as the same values when using
                    # server-side binding cursors (#34255).
                    "aggregation.tests.AggregateTestCase."
                    "test_group_by_nested_expression_with_params",
                }
            )
        return expected_failures

    @cached_property
    def uses_server_side_binding(self):
        options = self.connection.settings_dict["OPTIONS"]
        return is_psycopg3 and options.get("server_side_binding") is True

    @cached_property
    def prohibits_null_characters_in_text_exception(self):
        if is_psycopg3:
            return DataError, "PostgreSQL text fields cannot contain NUL (0x00) bytes"
        else:
            return ValueError, "A string literal cannot contain NUL (0x00) characters."

    @cached_property
    def introspected_field_types(self):
        return {
            **super().introspected_field_types,
            "PositiveBigIntegerField": "BigIntegerField",
            "PositiveIntegerField": "IntegerField",
            "PositiveSmallIntegerField": "SmallIntegerField",
        }

    @cached_property
    def is_postgresql_15(self):
        return self.connection.pg_version >= 150000

    @cached_property
    def is_postgresql_16(self):
        return self.connection.pg_version >= 160000

    @cached_property
    def is_postgresql_17(self):
        return self.connection.pg_version >= 170000

    supports_unlimited_charfield = True
    supports_nulls_distinct_unique_constraints = property(
        operator.attrgetter("is_postgresql_15")
    )
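Note that the class-level property at the end resolves per connection, so the flag tracks the server actually connected to — a sketch only:

# Sketch only: True exactly when connection.pg_version >= 150000 (PostgreSQL 15+).
from django.db import connection

if connection.features.supports_nulls_distinct_unique_constraints:
    print("NULLS [NOT] DISTINCT is available")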
@@ -0,0 +1,299 @@
from collections import namedtuple

from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo as BaseTableInfo
from django.db.models import Index

FieldInfo = namedtuple("FieldInfo", BaseFieldInfo._fields + ("is_autofield", "comment"))
TableInfo = namedtuple("TableInfo", BaseTableInfo._fields + ("comment",))


class DatabaseIntrospection(BaseDatabaseIntrospection):
    # Maps type codes to Django Field types.
    data_types_reverse = {
        16: "BooleanField",
        17: "BinaryField",
        20: "BigIntegerField",
        21: "SmallIntegerField",
        23: "IntegerField",
        25: "TextField",
        700: "FloatField",
        701: "FloatField",
        869: "GenericIPAddressField",
        1042: "CharField",  # blank-padded
        1043: "CharField",
        1082: "DateField",
        1083: "TimeField",
        1114: "DateTimeField",
        1184: "DateTimeField",
        1186: "DurationField",
        1266: "TimeField",
        1700: "DecimalField",
        2950: "UUIDField",
        3802: "JSONField",
    }
    # A hook for subclasses.
    index_default_access_method = "btree"

    ignored_tables = []

    def get_field_type(self, data_type, description):
        field_type = super().get_field_type(data_type, description)
        if description.is_autofield or (
            # Required for pre-Django 4.1 serial columns.
            description.default
            and "nextval" in description.default
        ):
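            # Illustrative example: a serial column's default looks like
            # "nextval('myapp_mymodel_id_seq'::regclass)" (sequence name
            # hypothetical), which satisfies the "nextval" check above.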
            if field_type == "IntegerField":
                return "AutoField"
            elif field_type == "BigIntegerField":
                return "BigAutoField"
            elif field_type == "SmallIntegerField":
                return "SmallAutoField"
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        cursor.execute(
            """
            SELECT
                c.relname,
                CASE
                    WHEN c.relispartition THEN 'p'
                    WHEN c.relkind IN ('m', 'v') THEN 'v'
                    ELSE 't'
                END,
                obj_description(c.oid, 'pg_class')
            FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
            """
        )
        return [
            TableInfo(*row)
            for row in cursor.fetchall()
            if row[0] not in self.ignored_tables
        ]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        # Query the pg_catalog tables as cursor.description does not reliably
        # return the nullable property and information_schema.columns does not
        # contain details of materialized views.
        cursor.execute(
            """
            SELECT
                a.attname AS column_name,
                NOT (a.attnotnull OR (t.typtype = 'd' AND t.typnotnull)) AS is_nullable,
                pg_get_expr(ad.adbin, ad.adrelid) AS column_default,
                CASE WHEN collname = 'default' THEN NULL ELSE collname END AS collation,
                a.attidentity != '' AS is_autofield,
                col_description(a.attrelid, a.attnum) AS column_comment
            FROM pg_attribute a
            LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
            LEFT JOIN pg_collation co ON a.attcollation = co.oid
            JOIN pg_type t ON a.atttypid = t.oid
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE c.relkind IN ('f', 'm', 'p', 'r', 'v')
                AND c.relname = %s
                AND n.nspname NOT IN ('pg_catalog', 'pg_toast')
                AND pg_catalog.pg_table_is_visible(c.oid)
            """,
            [table_name],
        )
        field_map = {line[0]: line[1:] for line in cursor.fetchall()}
        cursor.execute(
            "SELECT * FROM %s LIMIT 1" % self.connection.ops.quote_name(table_name)
        )
        return [
            FieldInfo(
                line.name,
                line.type_code,
                # display_size is always None on psycopg2.
                line.internal_size if line.display_size is None else line.display_size,
                line.internal_size,
                line.precision,
                line.scale,
                *field_map[line.name],
            )
            for line in cursor.description
        ]

    def get_sequences(self, cursor, table_name, table_fields=()):
        cursor.execute(
            """
            SELECT
                s.relname AS sequence_name,
                a.attname AS colname
            FROM
                pg_class s
                JOIN pg_depend d ON d.objid = s.oid
                    AND d.classid = 'pg_class'::regclass
                    AND d.refclassid = 'pg_class'::regclass
                JOIN pg_attribute a ON d.refobjid = a.attrelid
                    AND d.refobjsubid = a.attnum
                JOIN pg_class tbl ON tbl.oid = d.refobjid
                    AND tbl.relname = %s
                    AND pg_catalog.pg_table_is_visible(tbl.oid)
            WHERE
                s.relkind = 'S';
            """,
            [table_name],
        )
        return [
            {"name": row[0], "table": table_name, "column": row[1]}
            for row in cursor.fetchall()
        ]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            """
            SELECT a1.attname, c2.relname, a2.attname
            FROM pg_constraint con
            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
            LEFT JOIN
                pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
            LEFT JOIN
                pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
            WHERE
                c1.relname = %s AND
                con.contype = 'f' AND
                c1.relnamespace = c2.relnamespace AND
                pg_catalog.pg_table_is_visible(c1.oid)
            """,
            [table_name],
        )
        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns. Also retrieve the definition of expression-based
        indexes.
        """
        constraints = {}
        # Loop over the key table, collecting things as constraints. The column
        # array must return column names in the same order in which they were
        # created.
        cursor.execute(
            """
            SELECT
                c.conname,
                array(
                    SELECT attname
                    FROM unnest(c.conkey) WITH ORDINALITY cols(colid, arridx)
                    JOIN pg_attribute AS ca ON cols.colid = ca.attnum
                    WHERE ca.attrelid = c.conrelid
                    ORDER BY cols.arridx
                ),
                c.contype,
                (SELECT fkc.relname || '.' || fka.attname
                 FROM pg_attribute AS fka
                 JOIN pg_class AS fkc ON fka.attrelid = fkc.oid
                 WHERE fka.attrelid = c.confrelid AND fka.attnum = c.confkey[1]),
                cl.reloptions
            FROM pg_constraint AS c
            JOIN pg_class AS cl ON c.conrelid = cl.oid
            WHERE cl.relname = %s AND pg_catalog.pg_table_is_visible(cl.oid)
            """,
            [table_name],
        )
        for constraint, columns, kind, used_cols, options in cursor.fetchall():
            constraints[constraint] = {
                "columns": columns,
                "primary_key": kind == "p",
                "unique": kind in ["p", "u"],
                "foreign_key": tuple(used_cols.split(".", 1)) if kind == "f" else None,
                "check": kind == "c",
                "index": False,
                "definition": None,
                "options": options,
            }
        # Now get indexes
        cursor.execute(
            """
            SELECT
                indexname,
                array_agg(attname ORDER BY arridx),
                indisunique,
                indisprimary,
                array_agg(ordering ORDER BY arridx),
                amname,
                exprdef,
                s2.attoptions
            FROM (
                SELECT
                    c2.relname as indexname, idx.*, attr.attname, am.amname,
                    CASE
                        WHEN idx.indexprs IS NOT NULL THEN
                            pg_get_indexdef(idx.indexrelid)
                    END AS exprdef,
                    CASE am.amname
                        WHEN %s THEN
                            CASE (option & 1)
                                WHEN 1 THEN 'DESC' ELSE 'ASC'
                            END
                    END as ordering,
                    c2.reloptions as attoptions
                FROM (
                    SELECT *
                    FROM
                        pg_index i,
                        unnest(i.indkey, i.indoption)
                            WITH ORDINALITY koi(key, option, arridx)
                ) idx
                LEFT JOIN pg_class c ON idx.indrelid = c.oid
                LEFT JOIN pg_class c2 ON idx.indexrelid = c2.oid
                LEFT JOIN pg_am am ON c2.relam = am.oid
                LEFT JOIN
                    pg_attribute attr ON attr.attrelid = c.oid AND attr.attnum = idx.key
                WHERE c.relname = %s AND pg_catalog.pg_table_is_visible(c.oid)
            ) s2
            GROUP BY indexname, indisunique, indisprimary, amname, exprdef, attoptions;
            """,
            [self.index_default_access_method, table_name],
        )
        for (
            index,
            columns,
            unique,
            primary,
            orders,
            type_,
            definition,
            options,
        ) in cursor.fetchall():
            if index not in constraints:
                basic_index = (
                    type_ == self.index_default_access_method
                    and
                    # '_btree' references
                    # django.contrib.postgres.indexes.BTreeIndex.suffix.
                    not index.endswith("_btree")
                    and options is None
                )
                constraints[index] = {
                    "columns": columns if columns != [None] else [],
                    "orders": orders if orders != [None] else [],
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                    "type": Index.suffix if basic_index else type_,
                    "definition": definition,
                    "options": options,
                }
        return constraints
@@ -0,0 +1,422 @@
import json
from functools import lru_cache, partial

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.backends.postgresql.compiler import InsertUnnest
from django.db.backends.postgresql.psycopg_any import (
    Inet,
    Jsonb,
    errors,
    is_psycopg3,
    mogrify,
)
from django.db.backends.utils import split_tzname_delta
from django.db.models.constants import OnConflict
from django.db.models.functions import Cast
from django.utils.regex_helper import _lazy_re_compile


@lru_cache
def get_json_dumps(encoder):
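    # Because of @lru_cache above, calls with the same encoder class return
    # the identical dumps callable, e.g. (encoder class illustrative):
    #   get_json_dumps(DjangoJSONEncoder) is get_json_dumps(DjangoJSONEncoder)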
    if encoder is None:
        return json.dumps
    return partial(json.dumps, cls=encoder)


class DatabaseOperations(BaseDatabaseOperations):
    compiler_module = "django.db.backends.postgresql.compiler"
    cast_char_field_without_max_length = "varchar"
    explain_prefix = "EXPLAIN"
    explain_options = frozenset(
        [
            "ANALYZE",
            "BUFFERS",
            "COSTS",
            "GENERIC_PLAN",
            "MEMORY",
            "SETTINGS",
            "SERIALIZE",
            "SUMMARY",
            "TIMING",
            "VERBOSE",
            "WAL",
        ]
    )
    cast_data_types = {
        "AutoField": "integer",
        "BigAutoField": "bigint",
        "SmallAutoField": "smallint",
    }

    if is_psycopg3:
        from psycopg.types import numeric

        integerfield_type_map = {
            "SmallIntegerField": numeric.Int2,
            "IntegerField": numeric.Int4,
            "BigIntegerField": numeric.Int8,
            "PositiveSmallIntegerField": numeric.Int2,
            "PositiveIntegerField": numeric.Int4,
            "PositiveBigIntegerField": numeric.Int8,
        }

    def unification_cast_sql(self, output_field):
        internal_type = output_field.get_internal_type()
        if internal_type in (
            "GenericIPAddressField",
            "IPAddressField",
            "TimeField",
            "UUIDField",
        ):
            # PostgreSQL will resolve a union as type 'text' if input types are
            # 'unknown'.
            # https://www.postgresql.org/docs/current/typeconv-union-case.html
            # These fields cannot be implicitly cast back in the default
            # PostgreSQL configuration so we need to explicitly cast them.
            # We must also remove components of the type within brackets:
            # varchar(255) -> varchar.
            return (
                "CAST(%%s AS %s)" % output_field.db_type(self.connection).split("(")[0]
            )
        return "%s"

    # EXTRACT format cannot be passed in parameters.
    _extract_format_re = _lazy_re_compile(r"[A-Z_]+")

    def date_extract_sql(self, lookup_type, sql, params):
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
        if lookup_type == "week_day":
            # For consistency across backends, we return Sunday=1, Saturday=7.
            return f"EXTRACT(DOW FROM {sql}) + 1", params
        elif lookup_type == "iso_week_day":
            return f"EXTRACT(ISODOW FROM {sql})", params
        elif lookup_type == "iso_year":
            return f"EXTRACT(ISOYEAR FROM {sql})", params

        lookup_type = lookup_type.upper()
        if not self._extract_format_re.fullmatch(lookup_type):
            raise ValueError(f"Invalid lookup type: {lookup_type!r}")
        return f"EXTRACT({lookup_type} FROM {sql})", params

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        sql, params = self._convert_sql_to_tz(sql, params, tzname)
        # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
        return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)

    def _prepare_tzname_delta(self, tzname):
        tzname, sign, offset = split_tzname_delta(tzname)
        if offset:
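            # PostgreSQL reads POSIX-style offsets with the sign inverted
            # relative to ISO (e.g. "Etc/GMT+3" means UTC-3), so flip the
            # sign to keep "UTC+03:00"-style names behaving as expected.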
sign = "-" if sign == "+" else "+"
|
||||
return f"{tzname}{sign}{offset}"
|
||||
return tzname
|
||||
|
||||
def _convert_sql_to_tz(self, sql, params, tzname):
|
||||
if tzname and settings.USE_TZ:
|
||||
tzname_param = self._prepare_tzname_delta(tzname)
|
||||
return f"{sql} AT TIME ZONE %s", (*params, tzname_param)
|
||||
return sql, params
|
||||
|
||||
def datetime_cast_date_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"({sql})::date", params
|
||||
|
||||
def datetime_cast_time_sql(self, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"({sql})::time", params
|
||||
|
||||
def datetime_extract_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TRUNC
|
||||
return f"DATE_TRUNC(%s, {sql})", (lookup_type, *params)
|
||||
|
||||
def time_extract_sql(self, lookup_type, sql, params):
|
||||
if lookup_type == "second":
|
||||
# Truncate fractional seconds.
|
||||
return f"EXTRACT(SECOND FROM DATE_TRUNC(%s, {sql}))", ("second", *params)
|
||||
return self.date_extract_sql(lookup_type, sql, params)
|
||||
|
||||
def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
|
||||
sql, params = self._convert_sql_to_tz(sql, params, tzname)
|
||||
return f"DATE_TRUNC(%s, {sql})::time", (lookup_type, *params)
|
||||
|
||||
def deferrable_sql(self):
|
||||
return " DEFERRABLE INITIALLY DEFERRED"
|
||||
|
||||
def bulk_insert_sql(self, fields, placeholder_rows):
|
||||
if isinstance(placeholder_rows, InsertUnnest):
|
||||
return f"SELECT * FROM {placeholder_rows}"
|
||||
return super().bulk_insert_sql(fields, placeholder_rows)
|
||||
|
||||
def fetch_returned_insert_rows(self, cursor):
|
||||
"""
|
||||
Given a cursor object that has just performed an INSERT...RETURNING
|
||||
statement into a table, return the tuple of returned data.
|
||||
"""
|
||||
return cursor.fetchall()
|
||||
|
||||
def lookup_cast(self, lookup_type, internal_type=None):
|
||||
lookup = "%s"
|
||||
# Cast text lookups to text to allow things like filter(x__contains=4)
|
||||
if lookup_type in (
|
||||
"iexact",
|
||||
"contains",
|
||||
"icontains",
|
||||
"startswith",
|
||||
"istartswith",
|
||||
"endswith",
|
||||
"iendswith",
|
||||
"regex",
|
||||
"iregex",
|
||||
):
|
||||
if internal_type in ("IPAddressField", "GenericIPAddressField"):
|
||||
lookup = "HOST(%s)"
|
||||
else:
|
||||
lookup = "%s::text"
|
||||
|
||||
# Use UPPER(x) for case-insensitive lookups; it's faster.
|
||||
if lookup_type in ("iexact", "icontains", "istartswith", "iendswith"):
|
||||
lookup = "UPPER(%s)" % lookup
|
||||
|
||||
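        # e.g. icontains on a text column yields "UPPER(%s::text)", while the
        # same lookup on an inet column yields "UPPER(HOST(%s))".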
        return lookup

    def no_limit_value(self):
        return None

    def prepare_sql_script(self, sql):
        return [sql]

    def quote_name(self, name):
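        # e.g. quote_name("user") -> '"user"', keeping reserved words usable
        # as identifiers; an already-quoted name is returned unchanged.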
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def compose_sql(self, sql, params):
        return mogrify(sql, params, self.connection)

    def set_time_zone_sql(self):
        return "SELECT set_config('TimeZone', %s, false)"

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        if not tables:
            return []

        # Perform a single SQL 'TRUNCATE x, y, z...;' statement. It allows us
        # to truncate tables referenced by a foreign key in any other table.
        sql_parts = [
            style.SQL_KEYWORD("TRUNCATE"),
            ", ".join(style.SQL_FIELD(self.quote_name(table)) for table in tables),
        ]
        if reset_sequences:
            sql_parts.append(style.SQL_KEYWORD("RESTART IDENTITY"))
        if allow_cascade:
            sql_parts.append(style.SQL_KEYWORD("CASCADE"))
        return ["%s;" % " ".join(sql_parts)]

    def sequence_reset_by_name_sql(self, style, sequences):
        # 'ALTER SEQUENCE sequence_name RESTART WITH 1;'... style SQL statements
        # to reset sequence indices
        sql = []
        for sequence_info in sequences:
            table_name = sequence_info["table"]
            # 'id' will be the case if it's an m2m using an autogenerated
            # intermediate table (see BaseDatabaseIntrospection.sequence_list).
            column_name = sequence_info["column"] or "id"
            sql.append(
                "%s setval(pg_get_serial_sequence('%s','%s'), 1, false);"
                % (
                    style.SQL_KEYWORD("SELECT"),
                    style.SQL_TABLE(self.quote_name(table_name)),
                    style.SQL_FIELD(column_name),
                )
            )
        return sql

    def tablespace_sql(self, tablespace, inline=False):
        if inline:
            return "USING INDEX TABLESPACE %s" % self.quote_name(tablespace)
        else:
            return "TABLESPACE %s" % self.quote_name(tablespace)

    def sequence_reset_sql(self, style, model_list):
        from django.db import models

        output = []
        qn = self.quote_name
        for model in model_list:
            # Use `coalesce` to set the sequence for each model to the max pk
            # value if there are records, or 1 if there are none. Set the
            # `is_called` property (the third argument to `setval`) to true if
            # there are records (as the max pk value is already in use),
            # otherwise set it to false. Use pg_get_serial_sequence to get the
            # underlying sequence name from the table name and column name.

            for f in model._meta.local_fields:
                if isinstance(f, models.AutoField):
                    output.append(
                        "%s setval(pg_get_serial_sequence('%s','%s'), "
                        "coalesce(max(%s), 1), max(%s) %s null) %s %s;"
                        % (
                            style.SQL_KEYWORD("SELECT"),
                            style.SQL_TABLE(qn(model._meta.db_table)),
                            style.SQL_FIELD(f.column),
                            style.SQL_FIELD(qn(f.column)),
                            style.SQL_FIELD(qn(f.column)),
                            style.SQL_KEYWORD("IS NOT"),
                            style.SQL_KEYWORD("FROM"),
                            style.SQL_TABLE(qn(model._meta.db_table)),
                        )
                    )
                    # Only one AutoField is allowed per model, so don't bother
                    # continuing.
                    break
        return output

    def prep_for_iexact_query(self, x):
        return x

    def max_name_length(self):
        """
        Return the maximum length of an identifier.

        The maximum length of an identifier is 63 by default, but can be
        changed by recompiling PostgreSQL after editing the NAMEDATALEN
        macro in src/include/pg_config_manual.h.

        This implementation returns 63, but can be overridden by a custom
        database backend that inherits most of its behavior from this one.
        """
        return 63

    def distinct_sql(self, fields, params):
        if fields:
            params = [param for param_list in params for param in param_list]
            return (["DISTINCT ON (%s)" % ", ".join(fields)], params)
        else:
            return ["DISTINCT"], []

    if is_psycopg3:

        def last_executed_query(self, cursor, sql, params):
            if self.connection.features.uses_server_side_binding:
                try:
                    return self.compose_sql(sql, params)
                except errors.DataError:
                    return None
            else:
                if cursor._query and cursor._query.query is not None:
                    return cursor._query.query.decode()
                return None

    else:

        def last_executed_query(self, cursor, sql, params):
            # https://www.psycopg.org/docs/cursor.html#cursor.query
            # The query attribute is a Psycopg extension to the DB API 2.0.
            if cursor.query is not None:
                return cursor.query.decode()
            return None

    def return_insert_columns(self, fields):
        if not fields:
            return "", ()
        columns = [
            "%s.%s"
            % (
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            )
            for field in fields
        ]
        return "RETURNING %s" % ", ".join(columns), ()

    if is_psycopg3:

        def adapt_integerfield_value(self, value, internal_type):
            if value is None or hasattr(value, "resolve_expression"):
                return value
            return self.integerfield_type_map[internal_type](value)

    def adapt_datefield_value(self, value):
        return value

    def adapt_datetimefield_value(self, value):
        return value

    def adapt_timefield_value(self, value):
        return value

    def adapt_ipaddressfield_value(self, value):
        if value:
            return Inet(value)
        return None

    def adapt_json_value(self, value, encoder):
        return Jsonb(value, dumps=get_json_dumps(encoder))

    def subtract_temporals(self, internal_type, lhs, rhs):
        if internal_type == "DateField":
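            # PostgreSQL date - date yields an integer count of days;
            # multiplying by interval '1 day' converts that into the interval
            # value a DurationField expects.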
            lhs_sql, lhs_params = lhs
            rhs_sql, rhs_params = rhs
            params = (*lhs_params, *rhs_params)
            return "(interval '1 day' * (%s - %s))" % (lhs_sql, rhs_sql), params
        return super().subtract_temporals(internal_type, lhs, rhs)

    def explain_query_prefix(self, format=None, **options):
        extra = {}
        if serialize := options.pop("serialize", None):
            if serialize.upper() in {"TEXT", "BINARY"}:
                extra["SERIALIZE"] = serialize.upper()
        # Normalize options.
        if options:
            options = {
                name.upper(): "true" if value else "false"
                for name, value in options.items()
            }
            for valid_option in self.explain_options:
                value = options.pop(valid_option, None)
                if value is not None:
                    extra[valid_option] = value
        prefix = super().explain_query_prefix(format, **options)
        if format:
            extra["FORMAT"] = format
        if extra:
            prefix += " (%s)" % ", ".join("%s %s" % i for i in extra.items())
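        # e.g. explain_query_prefix("json", analyze=True) produces
        # "EXPLAIN (ANALYZE true, FORMAT json)".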
        return prefix

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        if on_conflict == OnConflict.IGNORE:
            return "ON CONFLICT DO NOTHING"
        if on_conflict == OnConflict.UPDATE:
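            # e.g. unique_fields=["email"], update_fields=["name"] renders:
            # ON CONFLICT("email") DO UPDATE SET "name" = EXCLUDED."name"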
return "ON CONFLICT(%s) DO UPDATE SET %s" % (
|
||||
", ".join(map(self.quote_name, unique_fields)),
|
||||
", ".join(
|
||||
[
|
||||
f"{field} = EXCLUDED.{field}"
|
||||
for field in map(self.quote_name, update_fields)
|
||||
]
|
||||
),
|
||||
)
|
||||
return super().on_conflict_suffix_sql(
|
||||
fields,
|
||||
on_conflict,
|
||||
update_fields,
|
||||
unique_fields,
|
||||
)
|
||||
|
||||
def prepare_join_on_clause(self, lhs_table, lhs_field, rhs_table, rhs_field):
|
||||
lhs_expr, rhs_expr = super().prepare_join_on_clause(
|
||||
lhs_table, lhs_field, rhs_table, rhs_field
|
||||
)
|
||||
|
||||
if lhs_field.db_type(self.connection) != rhs_field.db_type(self.connection):
|
||||
rhs_expr = Cast(rhs_expr, lhs_field)
|
||||
|
||||
return lhs_expr, rhs_expr
|
||||
@@ -0,0 +1,114 @@
import ipaddress
from functools import lru_cache

try:
    from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
    from psycopg.postgres import types
    from psycopg.types.datetime import TimestamptzLoader
    from psycopg.types.json import Jsonb
    from psycopg.types.range import Range, RangeDumper
    from psycopg.types.string import TextLoader

    Inet = ipaddress.ip_address

    DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
    RANGE_TYPES = (Range,)

    TSRANGE_OID = types["tsrange"].oid
    TSTZRANGE_OID = types["tstzrange"].oid

    def mogrify(sql, params, connection):
        with connection.cursor() as cursor:
            return ClientCursor(cursor.connection).mogrify(sql, params)

    # Adapters.
    class BaseTzLoader(TimestamptzLoader):
        """
        Load a PostgreSQL timestamptz using a specific timezone.
        The timezone can be None too, in which case it will be chopped.
        """

        timezone = None

        def load(self, data):
            res = super().load(data)
            return res.replace(tzinfo=self.timezone)

    def register_tzloader(tz, context):
        class SpecificTzLoader(BaseTzLoader):
            timezone = tz

        context.adapters.register_loader("timestamptz", SpecificTzLoader)

    class DjangoRangeDumper(RangeDumper):
        """A Range dumper customized for Django."""

        def upgrade(self, obj, format):
            # Dump ranges containing naive datetimes as tstzrange, because
            # Django doesn't use tz-aware ones.
            dumper = super().upgrade(obj, format)
            if dumper is not self and dumper.oid == TSRANGE_OID:
                dumper.oid = TSTZRANGE_OID
            return dumper

    @lru_cache
    def get_adapters_template(use_tz, timezone):
        # Create an adapters map extending the base one.
        ctx = adapt.AdaptersMap(adapters)
        # Register a no-op loader to avoid a round trip from psycopg version 3
        # decode to json.dumps() to json.loads(), when using a custom decoder
        # in JSONField.
        ctx.register_loader("jsonb", TextLoader)
        # Don't convert automatically from PostgreSQL network types to Python
        # ipaddress.
        ctx.register_loader("inet", TextLoader)
        ctx.register_loader("cidr", TextLoader)
        ctx.register_dumper(Range, DjangoRangeDumper)
        # Register a timestamptz loader configured on self.timezone.
        # This, however, can be overridden by create_cursor.
        register_tzloader(timezone, ctx)
        return ctx

    is_psycopg3 = True

except ImportError:
    from enum import IntEnum

    from psycopg2 import errors, extensions, sql  # NOQA
    from psycopg2.extras import (  # NOQA
        DateRange,
        DateTimeRange,
        DateTimeTZRange,
        Inet,
        Json,
        NumericRange,
        Range,
    )

    RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)

    class IsolationLevel(IntEnum):
        READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
        READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
        REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
        SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE

    def _quote(value, connection=None):
        adapted = extensions.adapt(value)
        if hasattr(adapted, "encoding"):
            adapted.encoding = "utf8"
        # getquoted() returns a quoted bytestring of the adapted value.
        return adapted.getquoted().decode()

    sql.quote = _quote

    def mogrify(sql, params, connection):
        with connection.cursor() as cursor:
            return cursor.mogrify(sql, params).decode()

    is_psycopg3 = False

    class Jsonb(Json):
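        # psycopg2's Json adapts values as quoted text; appending an explicit
        # cast (e.g. '{"a": 1}'::jsonb) makes the server parse it as jsonb.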
        def getquoted(self):
            quoted = super().getquoted()
            return quoted + b"::jsonb"
@@ -0,0 +1,380 @@
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import IndexColumns
from django.db.backends.postgresql.psycopg_any import sql
from django.db.backends.utils import strip_quotes


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    # Setting all constraints to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_update_with_default = (
        "UPDATE %(table)s SET %(column)s = %(default)s WHERE %(column)s IS NULL"
        "; SET CONSTRAINTS ALL IMMEDIATE"
    )
    sql_alter_sequence_type = "ALTER SEQUENCE IF EXISTS %(sequence)s AS %(type)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"

    sql_create_index = (
        "CREATE INDEX %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_create_index_concurrently = (
        "CREATE INDEX CONCURRENTLY %(name)s ON %(table)s%(using)s "
        "(%(columns)s)%(include)s%(extra)s%(condition)s"
    )
    sql_delete_index = "DROP INDEX IF EXISTS %(name)s"
    sql_delete_index_concurrently = "DROP INDEX CONCURRENTLY IF EXISTS %(name)s"

    # Setting the constraint to IMMEDIATE to allow changing data in the same
    # transaction.
    sql_create_column_inline_fk = (
        "CONSTRAINT %(name)s REFERENCES %(to_table)s(%(to_column)s)%(deferrable)s"
        "; SET CONSTRAINTS %(namespace)s%(name)s IMMEDIATE"
    )
    # Setting the constraint to IMMEDIATE runs any deferred checks to allow
    # dropping it in the same transaction.
    sql_delete_fk = (
        "SET CONSTRAINTS %(name)s IMMEDIATE; "
        "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
    )
    sql_delete_procedure = "DROP FUNCTION %(procedure)s(%(param_types)s)"

    def execute(self, sql, params=()):
        # Merge the query client-side, as PostgreSQL won't do it server-side.
        if params is None:
            return super().execute(sql, params)
        sql = self.connection.ops.compose_sql(str(sql), params)
        # Don't let the superclass touch anything.
        return super().execute(sql, None)

    sql_add_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s ADD "
        "GENERATED BY DEFAULT AS IDENTITY"
    )
    sql_drop_identity = (
        "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP IDENTITY IF EXISTS"
    )

    def quote_value(self, value):
        return sql.quote(value, self.connection.connection)

    def _field_indexes_sql(self, model, field):
        output = super()._field_indexes_sql(model, field)
        like_index_statement = self._create_like_index_sql(model, field)
        if like_index_statement is not None:
            output.append(like_index_statement)
        return output

    def _field_data_type(self, field):
        if field.is_relation:
            return field.rel_db_type(self.connection)
        return self.connection.data_types.get(
            field.get_internal_type(),
            field.db_type(self.connection),
        )

    def _field_base_data_types(self, field):
        # Yield base data types for array fields.
        if field.base_field.get_internal_type() == "ArrayField":
            yield from self._field_base_data_types(field.base_field)
        else:
            yield self._field_data_type(field.base_field)

    def _create_like_index_sql(self, model, field):
        """
        Return the statement to create an index with varchar operator pattern
        when the column type is 'varchar' or 'text', otherwise return None.
        """
        db_type = field.db_type(connection=self.connection)
        if db_type is not None and (field.db_index or field.unique):
            # Fields with database column types of `varchar` and `text` need
            # a second index that specifies their operator class, which is
            # needed when performing correct LIKE queries outside the
            # C locale. See #12234.
            #
            # The same doesn't apply to array fields such as varchar[size]
            # and text[size], so skip them.
            if "[" in db_type:
                return None
            # Non-deterministic collations on PostgreSQL don't support indexes
            # for operator classes varchar_pattern_ops/text_pattern_ops.
            collation_name = getattr(field, "db_collation", None)
            if not collation_name and field.is_relation:
                collation_name = getattr(field.target_field, "db_collation", None)
            if collation_name and not self._is_collation_deterministic(collation_name):
                return None
            if db_type.startswith("varchar"):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix="_like",
                    opclasses=["varchar_pattern_ops"],
                )
            elif db_type.startswith("text"):
                return self._create_index_sql(
                    model,
                    fields=[field],
                    suffix="_like",
                    opclasses=["text_pattern_ops"],
                )
        return None

    def _using_sql(self, new_field, old_field):
        if new_field.generated:
            return ""
        using_sql = " USING %(column)s::%(type)s"
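        # Rendered into the ALTER statement, e.g. (names illustrative):
        #   ALTER COLUMN "age" TYPE integer USING "age"::integer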
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        if new_internal_type == "ArrayField" and new_internal_type == old_internal_type:
            # Compare base data types for array fields.
            if list(self._field_base_data_types(old_field)) != list(
                self._field_base_data_types(new_field)
            ):
                return using_sql
        elif self._field_data_type(old_field) != self._field_data_type(new_field):
            return using_sql
        return ""

    def _get_sequence_name(self, table, column):
        with self.connection.cursor() as cursor:
            for sequence in self.connection.introspection.get_sequences(cursor, table):
                if sequence["column"] == column:
                    return sequence["name"]
        return None

    def _is_changing_type_of_indexed_text_column(self, old_field, old_type, new_type):
        return (old_field.db_index or old_field.unique) and (
            (old_type.startswith("varchar") and not new_type.startswith("varchar"))
            or (old_type.startswith("text") and not new_type.startswith("text"))
            or (old_type.startswith("citext") and not new_type.startswith("citext"))
        )

    def _alter_column_type_sql(
        self, model, old_field, new_field, new_type, old_collation, new_collation
    ):
        # Drop indexes on varchar/text/citext columns that are changing to a
        # different type.
        old_db_params = old_field.db_parameters(connection=self.connection)
        old_type = old_db_params["type"]
        if self._is_changing_type_of_indexed_text_column(old_field, old_type, new_type):
            index_name = self._create_index_name(
                model._meta.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_name))

        self.sql_alter_column_type = (
            "ALTER COLUMN %(column)s TYPE %(type)s%(collation)s"
        )
        # Cast when data type changed.
        if using_sql := self._using_sql(new_field, old_field):
            self.sql_alter_column_type += using_sql
        new_internal_type = new_field.get_internal_type()
        old_internal_type = old_field.get_internal_type()
        # Make ALTER TYPE with IDENTITY make sense.
        table = strip_quotes(model._meta.db_table)
        auto_field_types = {
            "AutoField",
            "BigAutoField",
            "SmallAutoField",
        }
        old_is_auto = old_internal_type in auto_field_types
        new_is_auto = new_internal_type in auto_field_types
        if new_is_auto and not old_is_auto:
            column = strip_quotes(new_field.column)
            return (
                (
                    self.sql_alter_column_type
                    % {
                        "column": self.quote_name(column),
                        "type": new_type,
                        "collation": "",
                    },
                    [],
                ),
                [
                    (
                        self.sql_add_identity
                        % {
                            "table": self.quote_name(table),
                            "column": self.quote_name(column),
                        },
                        [],
                    ),
                ],
            )
        elif old_is_auto and not new_is_auto:
            # Drop IDENTITY if exists (pre-Django 4.1 serial columns don't have
            # it).
            self.execute(
                self.sql_drop_identity
                % {
                    "table": self.quote_name(table),
                    "column": self.quote_name(strip_quotes(new_field.column)),
                }
            )
            column = strip_quotes(new_field.column)
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            # Drop the sequence if exists (Django 4.1+ identity columns don't
            # have it).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_delete_sequence
                        % {
                            "sequence": self.quote_name(sequence_name),
                        },
                        [],
                    )
                ]
            return fragment, other_actions
        elif new_is_auto and old_is_auto and old_internal_type != new_internal_type:
            fragment, _ = super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )
            column = strip_quotes(new_field.column)
            db_types = {
                "AutoField": "integer",
                "BigAutoField": "bigint",
                "SmallAutoField": "smallint",
            }
            # Alter the sequence type if exists (Django 4.1+ identity columns
            # don't have it).
            other_actions = []
            if sequence_name := self._get_sequence_name(table, column):
                other_actions = [
                    (
                        self.sql_alter_sequence_type
                        % {
                            "sequence": self.quote_name(sequence_name),
                            "type": db_types[new_internal_type],
                        },
                        [],
                    ),
                ]
            return fragment, other_actions
        else:
            return super()._alter_column_type_sql(
                model, old_field, new_field, new_type, old_collation, new_collation
            )

    def _alter_field(
        self,
        model,
        old_field,
        new_field,
        old_type,
        new_type,
        old_db_params,
        new_db_params,
        strict=False,
    ):
        super()._alter_field(
            model,
            old_field,
            new_field,
            old_type,
            new_type,
            old_db_params,
            new_db_params,
            strict,
        )
        # Added an index? Create any PostgreSQL-specific indexes.
        if (
            (not (old_field.db_index or old_field.unique) and new_field.db_index)
            or (not old_field.unique and new_field.unique)
            or (
                self._is_changing_type_of_indexed_text_column(
                    old_field, old_type, new_type
                )
            )
        ):
            like_index_statement = self._create_like_index_sql(model, new_field)
            if like_index_statement is not None:
                self.execute(like_index_statement)

        # Removed an index? Drop any PostgreSQL-specific indexes.
        if old_field.unique and not (new_field.db_index or new_field.unique):
            index_to_remove = self._create_index_name(
                model._meta.db_table, [old_field.column], suffix="_like"
            )
            self.execute(self._delete_index_sql(model, index_to_remove))

    def _index_columns(self, table, columns, col_suffixes, opclasses):
        if opclasses:
            return IndexColumns(
                table,
                columns,
                self.quote_name,
                col_suffixes=col_suffixes,
                opclasses=opclasses,
            )
        return super()._index_columns(table, columns, col_suffixes, opclasses)

    def add_index(self, model, index, concurrently=False):
        self.execute(
            index.create_sql(model, self, concurrently=concurrently), params=None
        )

    def remove_index(self, model, index, concurrently=False):
        self.execute(index.remove_sql(model, self, concurrently=concurrently))

    def _delete_index_sql(self, model, name, sql=None, concurrently=False):
        sql = (
            self.sql_delete_index_concurrently
            if concurrently
            else self.sql_delete_index
        )
        return super()._delete_index_sql(model, name, sql)

    def _create_index_sql(
        self,
        model,
        *,
        fields=None,
        name=None,
        suffix="",
        using="",
        db_tablespace=None,
        col_suffixes=(),
        sql=None,
        opclasses=(),
        condition=None,
        concurrently=False,
        include=None,
        expressions=None,
    ):
        sql = sql or (
            self.sql_create_index
            if not concurrently
            else self.sql_create_index_concurrently
        )
        return super()._create_index_sql(
            model,
            fields=fields,
            name=name,
            suffix=suffix,
            using=using,
            db_tablespace=db_tablespace,
            col_suffixes=col_suffixes,
            sql=sql,
            opclasses=opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
        )

    def _is_collation_deterministic(self, collation_name):
        with self.connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT collisdeterministic
                FROM pg_collation
                WHERE collname = %s
                """,
                [collation_name],
            )
            row = cursor.fetchone()
            return row[0] if row else None
@@ -0,0 +1,3 @@
from django.dispatch import Signal

connection_created = Signal()
@@ -0,0 +1,515 @@
"""
Implementations of SQL functions for SQLite.
"""

import functools
import random
import statistics
import zoneinfo
from datetime import timedelta
from hashlib import md5, sha1, sha224, sha256, sha384, sha512
from math import (
    acos,
    asin,
    atan,
    atan2,
    ceil,
    cos,
    degrees,
    exp,
    floor,
    fmod,
    log,
    pi,
    radians,
    sin,
    sqrt,
    tan,
)
from re import search as re_search

from django.db.backends.utils import (
    split_tzname_delta,
    typecast_time,
    typecast_timestamp,
)
from django.utils import timezone
from django.utils.duration import duration_microseconds


def register(connection):
    create_deterministic_function = functools.partial(
        connection.create_function,
        deterministic=True,
    )
    create_deterministic_function("django_date_extract", 2, _sqlite_datetime_extract)
    create_deterministic_function("django_date_trunc", 4, _sqlite_date_trunc)
    create_deterministic_function(
        "django_datetime_cast_date", 3, _sqlite_datetime_cast_date
    )
    create_deterministic_function(
        "django_datetime_cast_time", 3, _sqlite_datetime_cast_time
    )
    create_deterministic_function(
        "django_datetime_extract", 4, _sqlite_datetime_extract
    )
    create_deterministic_function("django_datetime_trunc", 4, _sqlite_datetime_trunc)
    create_deterministic_function("django_time_extract", 2, _sqlite_time_extract)
    create_deterministic_function("django_time_trunc", 4, _sqlite_time_trunc)
    create_deterministic_function("django_time_diff", 2, _sqlite_time_diff)
    create_deterministic_function("django_timestamp_diff", 2, _sqlite_timestamp_diff)
    create_deterministic_function("django_format_dtdelta", 3, _sqlite_format_dtdelta)
    create_deterministic_function("regexp", 2, _sqlite_regexp)
    create_deterministic_function("BITXOR", 2, _sqlite_bitxor)
    create_deterministic_function("COT", 1, _sqlite_cot)
    create_deterministic_function("LPAD", 3, _sqlite_lpad)
    create_deterministic_function("MD5", 1, _sqlite_md5)
    create_deterministic_function("REPEAT", 2, _sqlite_repeat)
    create_deterministic_function("REVERSE", 1, _sqlite_reverse)
    create_deterministic_function("RPAD", 3, _sqlite_rpad)
    create_deterministic_function("SHA1", 1, _sqlite_sha1)
    create_deterministic_function("SHA224", 1, _sqlite_sha224)
    create_deterministic_function("SHA256", 1, _sqlite_sha256)
    create_deterministic_function("SHA384", 1, _sqlite_sha384)
    create_deterministic_function("SHA512", 1, _sqlite_sha512)
    create_deterministic_function("SIGN", 1, _sqlite_sign)
    # Don't use the built-in RANDOM() function because it returns a value
    # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
    connection.create_function("RAND", 0, random.random)
    connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
    connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp)
    connection.create_aggregate("VAR_POP", 1, VarPop)
    connection.create_aggregate("VAR_SAMP", 1, VarSamp)
    # Some math functions are enabled by default in SQLite 3.35+.
    sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
    if not connection.execute(sql).fetchone()[0]:
        create_deterministic_function("ACOS", 1, _sqlite_acos)
        create_deterministic_function("ASIN", 1, _sqlite_asin)
        create_deterministic_function("ATAN", 1, _sqlite_atan)
        create_deterministic_function("ATAN2", 2, _sqlite_atan2)
        create_deterministic_function("CEILING", 1, _sqlite_ceiling)
        create_deterministic_function("COS", 1, _sqlite_cos)
        create_deterministic_function("DEGREES", 1, _sqlite_degrees)
        create_deterministic_function("EXP", 1, _sqlite_exp)
        create_deterministic_function("FLOOR", 1, _sqlite_floor)
        create_deterministic_function("LN", 1, _sqlite_ln)
        create_deterministic_function("LOG", 2, _sqlite_log)
        create_deterministic_function("MOD", 2, _sqlite_mod)
        create_deterministic_function("PI", 0, _sqlite_pi)
        create_deterministic_function("POWER", 2, _sqlite_power)
        create_deterministic_function("RADIANS", 1, _sqlite_radians)
        create_deterministic_function("SIN", 1, _sqlite_sin)
        create_deterministic_function("SQRT", 1, _sqlite_sqrt)
        create_deterministic_function("TAN", 1, _sqlite_tan)


def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    if dt is None:
        return None
    try:
        dt = typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None
    if conn_tzname:
        dt = dt.replace(tzinfo=zoneinfo.ZoneInfo(conn_tzname))
    if tzname is not None and tzname != conn_tzname:
        tzname, sign, offset = split_tzname_delta(tzname)
        if offset:
            hours, minutes = offset.split(":")
            offset_delta = timedelta(hours=int(hours), minutes=int(minutes))
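            # e.g. tzname "UTC+03:00" was split into ("UTC", "+", "03:00")
            # above, so dt is shifted forward by three hours before being
            # localized below.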
            dt += offset_delta if sign == "+" else -offset_delta
        # The tzname may originally be just the offset, e.g. "+3:00",
        # which becomes an empty string after splitting the sign and offset.
        # In this case, use the conn_tzname as fallback.
        dt = timezone.localtime(dt, zoneinfo.ZoneInfo(tzname or conn_tzname))
    return dt


def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
    dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if dt is None:
        return None
    if lookup_type == "year":
        return f"{dt.year:04d}-01-01"
    elif lookup_type == "quarter":
        month_in_quarter = dt.month - (dt.month - 1) % 3
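        # e.g. August (month 8): 8 - (8 - 1) % 3 = 7, so the date is
        # truncated to July 1, the start of the third quarter.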
return f"{dt.year:04d}-{month_in_quarter:02d}-01"
|
||||
elif lookup_type == "month":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-01"
|
||||
elif lookup_type == "week":
|
||||
dt -= timedelta(days=dt.weekday())
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
|
||||
elif lookup_type == "day":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d}"
|
||||
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
|
||||
|
||||
|
||||
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
if dt is None:
|
||||
return None
|
||||
dt_parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt_parsed is None:
|
||||
try:
|
||||
dt = typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
else:
|
||||
dt = dt_parsed
|
||||
if lookup_type == "hour":
|
||||
return f"{dt.hour:02d}:00:00"
|
||||
elif lookup_type == "minute":
|
||||
return f"{dt.hour:02d}:{dt.minute:02d}:00"
|
||||
elif lookup_type == "second":
|
||||
return f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
|
||||
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.date().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
return dt.time().isoformat()
|
||||
|
||||
|
||||
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "week_day":
|
||||
return (dt.isoweekday() % 7) + 1
|
||||
elif lookup_type == "iso_week_day":
|
||||
return dt.isoweekday()
|
||||
elif lookup_type == "week":
|
||||
return dt.isocalendar().week
|
||||
elif lookup_type == "quarter":
|
||||
return ceil(dt.month / 3)
|
||||
elif lookup_type == "iso_year":
|
||||
return dt.isocalendar().year
|
||||
else:
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
|
||||
dt = _sqlite_datetime_parse(dt, tzname, conn_tzname)
|
||||
if dt is None:
|
||||
return None
|
||||
if lookup_type == "year":
|
||||
return f"{dt.year:04d}-01-01 00:00:00"
|
||||
elif lookup_type == "quarter":
|
||||
month_in_quarter = dt.month - (dt.month - 1) % 3
|
||||
return f"{dt.year:04d}-{month_in_quarter:02d}-01 00:00:00"
|
||||
elif lookup_type == "month":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-01 00:00:00"
|
||||
elif lookup_type == "week":
|
||||
dt -= timedelta(days=dt.weekday())
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
|
||||
elif lookup_type == "day":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} 00:00:00"
|
||||
elif lookup_type == "hour":
|
||||
return f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} {dt.hour:02d}:00:00"
|
||||
elif lookup_type == "minute":
|
||||
return (
|
||||
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
|
||||
f"{dt.hour:02d}:{dt.minute:02d}:00"
|
||||
)
|
||||
elif lookup_type == "second":
|
||||
return (
|
||||
f"{dt.year:04d}-{dt.month:02d}-{dt.day:02d} "
|
||||
f"{dt.hour:02d}:{dt.minute:02d}:{dt.second:02d}"
|
||||
)
|
||||
raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
|
||||
|
||||
|
||||
def _sqlite_time_extract(lookup_type, dt):
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
dt = typecast_time(dt)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
return getattr(dt, lookup_type)
|
||||
|
||||
|
||||
def _sqlite_prepare_dtdelta_param(conn, param):
|
||||
if conn in ["+", "-"]:
|
||||
if isinstance(param, int):
|
||||
return timedelta(0, 0, param)
|
||||
else:
|
||||
return typecast_timestamp(param)
|
||||
return param
|
||||
|
||||
|
||||
def _sqlite_format_dtdelta(connector, lhs, rhs):
|
||||
"""
|
||||
LHS and RHS can be either:
|
||||
- An integer number of microseconds
|
||||
- A string representing a datetime
|
||||
- A scalar value, e.g. float
|
||||
"""
|
||||
if connector is None or lhs is None or rhs is None:
|
||||
return None
|
||||
connector = connector.strip()
|
||||
try:
|
||||
real_lhs = _sqlite_prepare_dtdelta_param(connector, lhs)
|
||||
real_rhs = _sqlite_prepare_dtdelta_param(connector, rhs)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
if connector == "+":
|
||||
# typecast_timestamp() returns a date or a datetime without timezone.
|
||||
# It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
|
||||
out = str(real_lhs + real_rhs)
|
||||
elif connector == "-":
|
||||
out = str(real_lhs - real_rhs)
|
||||
elif connector == "*":
|
||||
out = real_lhs * real_rhs
|
||||
else:
|
||||
out = real_lhs / real_rhs
|
||||
return out
|
||||
|
||||
|
||||
def _sqlite_time_diff(lhs, rhs):
|
||||
if lhs is None or rhs is None:
|
||||
return None
|
||||
left = typecast_time(lhs)
|
||||
right = typecast_time(rhs)
|
||||
return (
|
||||
(left.hour * 60 * 60 * 1000000)
|
||||
+ (left.minute * 60 * 1000000)
|
||||
+ (left.second * 1000000)
|
||||
+ (left.microsecond)
|
||||
- (right.hour * 60 * 60 * 1000000)
|
||||
- (right.minute * 60 * 1000000)
|
||||
- (right.second * 1000000)
|
||||
- (right.microsecond)
|
||||
)
|
||||
|
||||
|
||||
def _sqlite_timestamp_diff(lhs, rhs):
|
||||
if lhs is None or rhs is None:
|
||||
return None
|
||||
left = typecast_timestamp(lhs)
|
||||
right = typecast_timestamp(rhs)
|
||||
return duration_microseconds(left - right)
|
||||
|
||||
|
||||
def _sqlite_regexp(pattern, string):
|
||||
if pattern is None or string is None:
|
||||
return None
|
||||
if not isinstance(string, str):
|
||||
string = str(string)
|
||||
return bool(re_search(pattern, string))
|
||||
|
||||
|
||||
def _sqlite_acos(x):
|
||||
if x is None:
|
||||
return None
|
||||
return acos(x)
|
||||
|
||||
|
||||
def _sqlite_asin(x):
|
||||
if x is None:
|
||||
return None
|
||||
return asin(x)
|
||||
|
||||
|
||||
def _sqlite_atan(x):
|
||||
if x is None:
|
||||
return None
|
||||
return atan(x)
|
||||
|
||||
|
||||
def _sqlite_atan2(y, x):
|
||||
if y is None or x is None:
|
||||
return None
|
||||
return atan2(y, x)
|
||||
|
||||
|
||||
def _sqlite_bitxor(x, y):
|
||||
if x is None or y is None:
|
||||
return None
|
||||
return x ^ y
|
||||
|
||||
|
||||
def _sqlite_ceiling(x):
|
||||
if x is None:
|
||||
return None
|
||||
return ceil(x)
|
||||
|
||||
|
||||
def _sqlite_cos(x):
|
||||
if x is None:
|
||||
return None
|
||||
return cos(x)
|
||||
|
||||
|
||||
def _sqlite_cot(x):
|
||||
if x is None:
|
||||
return None
|
||||
return 1 / tan(x)
|
||||
|
||||
|
||||
def _sqlite_degrees(x):
|
||||
if x is None:
|
||||
return None
|
||||
return degrees(x)
|
||||
|
||||
|
||||
def _sqlite_exp(x):
|
||||
if x is None:
|
||||
return None
|
||||
return exp(x)
|
||||
|
||||
|
||||
def _sqlite_floor(x):
|
||||
if x is None:
|
||||
return None
|
||||
return floor(x)
|
||||
|
||||
|
||||
def _sqlite_ln(x):
|
||||
if x is None:
|
||||
return None
|
||||
return log(x)
|
||||
|
||||
|
||||
def _sqlite_log(base, x):
|
||||
if base is None or x is None:
|
||||
return None
|
||||
# Arguments reversed to match SQL standard.
|
||||
return log(x, base)
|
||||
|
||||
|
||||
def _sqlite_lpad(text, length, fill_text):
|
||||
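    # e.g. _sqlite_lpad("abc", 5, "*") -> "**abc", while text longer than the
    # requested length is truncated: _sqlite_lpad("abcdef", 3, "*") -> "abc".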
    if text is None or length is None or fill_text is None:
        return None
    delta = length - len(text)
    if delta <= 0:
        return text[:length]
    return (fill_text * length)[:delta] + text


def _sqlite_md5(text):
    if text is None:
        return None
    return md5(text.encode()).hexdigest()


def _sqlite_mod(x, y):
    if x is None or y is None:
        return None
    return fmod(x, y)


def _sqlite_pi():
    return pi


def _sqlite_power(x, y):
    if x is None or y is None:
        return None
    return x**y


def _sqlite_radians(x):
    if x is None:
        return None
    return radians(x)


def _sqlite_repeat(text, count):
    if text is None or count is None:
        return None
    return text * count


def _sqlite_reverse(text):
    if text is None:
        return None
    return text[::-1]


def _sqlite_rpad(text, length, fill_text):
    if text is None or length is None or fill_text is None:
        return None
    return (text + fill_text * length)[:length]


def _sqlite_sha1(text):
    if text is None:
        return None
    return sha1(text.encode()).hexdigest()


def _sqlite_sha224(text):
    if text is None:
        return None
    return sha224(text.encode()).hexdigest()


def _sqlite_sha256(text):
    if text is None:
        return None
    return sha256(text.encode()).hexdigest()


def _sqlite_sha384(text):
    if text is None:
        return None
    return sha384(text.encode()).hexdigest()


def _sqlite_sha512(text):
    if text is None:
        return None
    return sha512(text.encode()).hexdigest()


def _sqlite_sign(x):
if x is None:
|
||||
return None
|
||||
return (x > 0) - (x < 0)
|
||||
|
||||
|
||||
def _sqlite_sin(x):
|
||||
if x is None:
|
||||
return None
|
||||
return sin(x)
|
||||
|
||||
|
||||
def _sqlite_sqrt(x):
|
||||
if x is None:
|
||||
return None
|
||||
return sqrt(x)
|
||||
|
||||
|
||||
def _sqlite_tan(x):
|
||||
if x is None:
|
||||
return None
|
||||
return tan(x)
|
||||
|
||||
|
||||
class ListAggregate(list):
    step = list.append


class StdDevPop(ListAggregate):
    finalize = statistics.pstdev


class StdDevSamp(ListAggregate):
    finalize = statistics.stdev


class VarPop(ListAggregate):
    finalize = statistics.pvariance


class VarSamp(ListAggregate):
    finalize = statistics.variance

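# Sketch of how these aggregates are exposed (assumed registration call; the
# actual wiring lives in this module's register() function): sqlite3 user
# aggregates need step() and finalize(), which ListAggregate supplies by
# appending each row to itself and handing the list to a statistics function:
#   connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
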
@@ -0,0 +1,379 @@
"""
SQLite backend for the sqlite3 module in the standard library.
"""

import datetime
import decimal
import warnings
from collections.abc import Mapping
from itertools import chain, tee
from sqlite3 import dbapi2 as Database

from django.core.exceptions import ImproperlyConfigured
from django.db import IntegrityError
from django.db.backends.base.base import BaseDatabaseWrapper
from django.utils.asyncio import async_unsafe
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.regex_helper import _lazy_re_compile

from ._functions import register as register_functions
from .client import DatabaseClient
from .creation import DatabaseCreation
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .schema import DatabaseSchemaEditor

def decoder(conv_func):
    """
    Convert bytestrings from Python's sqlite3 interface to a regular string.
    """
    return lambda s: conv_func(s.decode())


def adapt_date(val):
    return val.isoformat()


def adapt_datetime(val):
    return val.isoformat(" ")


def _get_varchar_column(data):
    if data["max_length"] is None:
        return "varchar"
    return "varchar(%(max_length)s)" % data


Database.register_converter("bool", b"1".__eq__)
Database.register_converter("date", decoder(parse_date))
Database.register_converter("time", decoder(parse_time))
Database.register_converter("datetime", decoder(parse_datetime))
Database.register_converter("timestamp", decoder(parse_datetime))

Database.register_adapter(decimal.Decimal, str)
Database.register_adapter(datetime.date, adapt_date)
Database.register_adapter(datetime.datetime, adapt_datetime)

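# Illustrative round trip (assumption: PARSE_DECLTYPES is enabled, as in
# get_connection_params below): a column declared "date" is decoded through
# decoder(parse_date), so the stored bytes b"2024-01-31" come back as
# datetime.date(2024, 1, 31); adapt_date performs the reverse on the way in.
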
class DatabaseWrapper(BaseDatabaseWrapper):
    vendor = "sqlite"
    display_name = "SQLite"
    # SQLite doesn't actually support most of these types, but it "does the right
    # thing" given more verbose field definitions, so leave them as is so that
    # schema inspection is more useful.
    data_types = {
        "AutoField": "integer",
        "BigAutoField": "integer",
        "BinaryField": "BLOB",
        "BooleanField": "bool",
        "CharField": _get_varchar_column,
        "DateField": "date",
        "DateTimeField": "datetime",
        "DecimalField": "decimal",
        "DurationField": "bigint",
        "FileField": "varchar(%(max_length)s)",
        "FilePathField": "varchar(%(max_length)s)",
        "FloatField": "real",
        "IntegerField": "integer",
        "BigIntegerField": "bigint",
        "IPAddressField": "char(15)",
        "GenericIPAddressField": "char(39)",
        "JSONField": "text",
        "OneToOneField": "integer",
        "PositiveBigIntegerField": "bigint unsigned",
        "PositiveIntegerField": "integer unsigned",
        "PositiveSmallIntegerField": "smallint unsigned",
        "SlugField": "varchar(%(max_length)s)",
        "SmallAutoField": "integer",
        "SmallIntegerField": "smallint",
        "TextField": "text",
        "TimeField": "time",
        "UUIDField": "char(32)",
    }
    data_type_check_constraints = {
        "PositiveBigIntegerField": '"%(column)s" >= 0',
        "JSONField": '(JSON_VALID("%(column)s") OR "%(column)s" IS NULL)',
        "PositiveIntegerField": '"%(column)s" >= 0',
        "PositiveSmallIntegerField": '"%(column)s" >= 0',
    }
    data_types_suffix = {
        "AutoField": "AUTOINCREMENT",
        "BigAutoField": "AUTOINCREMENT",
        "SmallAutoField": "AUTOINCREMENT",
    }
    # SQLite requires LIKE statements to include an ESCAPE clause if the value
    # being escaped has a percent or underscore in it.
    # See https://www.sqlite.org/lang_expr.html for an explanation.
    operators = {
        "exact": "= %s",
        "iexact": "LIKE %s ESCAPE '\\'",
        "contains": "LIKE %s ESCAPE '\\'",
        "icontains": "LIKE %s ESCAPE '\\'",
        "regex": "REGEXP %s",
        "iregex": "REGEXP '(?i)' || %s",
        "gt": "> %s",
        "gte": ">= %s",
        "lt": "< %s",
        "lte": "<= %s",
        "startswith": "LIKE %s ESCAPE '\\'",
        "endswith": "LIKE %s ESCAPE '\\'",
        "istartswith": "LIKE %s ESCAPE '\\'",
        "iendswith": "LIKE %s ESCAPE '\\'",
    }

    # The patterns below are used to generate SQL pattern lookup clauses when
    # the right-hand side of the lookup isn't a raw string (it might be an expression
    # or the result of a bilateral transformation).
    # In those cases, special characters for LIKE operators (e.g. \, *, _) should be
    # escaped on database side.
    #
    # Note: we use str.format() here for readability as '%' is used as a wildcard for
    # the LIKE operator.
    pattern_esc = r"REPLACE(REPLACE(REPLACE({}, '\', '\\'), '%%', '\%%'), '_', '\_')"
    pattern_ops = {
        "contains": r"LIKE '%%' || {} || '%%' ESCAPE '\'",
        "icontains": r"LIKE '%%' || UPPER({}) || '%%' ESCAPE '\'",
        "startswith": r"LIKE {} || '%%' ESCAPE '\'",
        "istartswith": r"LIKE UPPER({}) || '%%' ESCAPE '\'",
        "endswith": r"LIKE '%%' || {} ESCAPE '\'",
        "iendswith": r"LIKE '%%' || UPPER({}) ESCAPE '\'",
    }

    transaction_modes = frozenset(["DEFERRED", "EXCLUSIVE", "IMMEDIATE"])

    Database = Database
    SchemaEditorClass = DatabaseSchemaEditor
    # Classes instantiated in __init__().
    client_class = DatabaseClient
    creation_class = DatabaseCreation
    features_class = DatabaseFeatures
    introspection_class = DatabaseIntrospection
    ops_class = DatabaseOperations

    def get_connection_params(self):
        settings_dict = self.settings_dict
        if not settings_dict["NAME"]:
            raise ImproperlyConfigured(
                "settings.DATABASES is improperly configured. "
                "Please supply the NAME value."
            )
        kwargs = {
            "database": settings_dict["NAME"],
            "detect_types": Database.PARSE_DECLTYPES | Database.PARSE_COLNAMES,
            **settings_dict["OPTIONS"],
        }
        # Always allow the underlying SQLite connection to be shareable
        # between multiple threads. The safe-guarding will be handled at a
        # higher level by the `BaseDatabaseWrapper.allow_thread_sharing`
        # property. This is necessary as the shareability is disabled by
        # default in sqlite3 and it cannot be changed once a connection is
        # opened.
        if "check_same_thread" in kwargs and kwargs["check_same_thread"]:
            warnings.warn(
                "The `check_same_thread` option was provided and set to "
                "True. It will be overridden with False. Use the "
                "`DatabaseWrapper.allow_thread_sharing` property instead "
                "for controlling thread shareability.",
                RuntimeWarning,
            )
        kwargs.update({"check_same_thread": False, "uri": True})
        transaction_mode = kwargs.pop("transaction_mode", None)
        if (
            transaction_mode is not None
            and transaction_mode.upper() not in self.transaction_modes
        ):
            allowed_transaction_modes = ", ".join(
                [f"{mode!r}" for mode in sorted(self.transaction_modes)]
            )
            raise ImproperlyConfigured(
                f"settings.DATABASES[{self.alias!r}]['OPTIONS']['transaction_mode'] "
                f"is improperly configured to '{transaction_mode}'. Use one of "
                f"{allowed_transaction_modes}, or None."
            )
        self.transaction_mode = transaction_mode.upper() if transaction_mode else None

        init_command = kwargs.pop("init_command", "")
        self.init_commands = init_command.split(";")
        return kwargs

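    # Illustrative OPTIONS (assumed example values; the key names match what
    # the method above pops from OPTIONS):
    #   DATABASES = {
    #       "default": {
    #           "ENGINE": "django.db.backends.sqlite3",
    #           "NAME": BASE_DIR / "db.sqlite3",
    #           "OPTIONS": {
    #               "transaction_mode": "IMMEDIATE",
    #               "init_command": "PRAGMA synchronous=NORMAL",
    #           },
    #       },
    #   }
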
    def get_database_version(self):
        return self.Database.sqlite_version_info

    @async_unsafe
    def get_new_connection(self, conn_params):
        conn = Database.connect(**conn_params)
        register_functions(conn)

        conn.execute("PRAGMA foreign_keys = ON")
        # The macOS bundled SQLite defaults legacy_alter_table ON, which
        # prevents atomic table renames.
        conn.execute("PRAGMA legacy_alter_table = OFF")
        for init_command in self.init_commands:
            if init_command := init_command.strip():
                conn.execute(init_command)
        return conn

    def create_cursor(self, name=None):
        return self.connection.cursor(factory=SQLiteCursorWrapper)

    @async_unsafe
    def close(self):
        self.validate_thread_sharing()
        # If the database is in memory, closing the connection destroys the
        # database. To prevent accidental data loss, ignore close requests on
        # an in-memory db.
        if not self.is_in_memory_db():
            BaseDatabaseWrapper.close(self)

    def _savepoint_allowed(self):
        # When 'isolation_level' is not None, sqlite3 commits before each
        # savepoint; it's a bug. When it is None, savepoints don't make sense
        # because autocommit is enabled. The only exception is inside 'atomic'
        # blocks. To work around that bug, on SQLite, 'atomic' starts a
        # transaction explicitly rather than simply disabling autocommit.
        return self.in_atomic_block

    def _set_autocommit(self, autocommit):
        if autocommit:
            level = None
        else:
            # sqlite3's internal default is ''. It's different from None.
            # See Modules/_sqlite/connection.c.
            level = ""
        # 'isolation_level' is a misleading API.
        # SQLite always runs at the SERIALIZABLE isolation level.
        with self.wrap_database_errors:
            self.connection.isolation_level = level

    def disable_constraint_checking(self):
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = OFF")
            # Foreign key constraints cannot be turned off while in a multi-
            # statement transaction. Fetch the current state of the pragma
            # to determine if constraints are effectively disabled.
            enabled = cursor.execute("PRAGMA foreign_keys").fetchone()[0]
        return not bool(enabled)

    def enable_constraint_checking(self):
        with self.cursor() as cursor:
            cursor.execute("PRAGMA foreign_keys = ON")

    def check_constraints(self, table_names=None):
        """
        Check each table name in `table_names` for rows with invalid foreign
        key references. This method is intended to be used in conjunction with
        `disable_constraint_checking()` and `enable_constraint_checking()`, to
        determine if rows with invalid references were entered while
        constraint checks were off.
        """
        with self.cursor() as cursor:
            if table_names is None:
                violations = cursor.execute("PRAGMA foreign_key_check").fetchall()
            else:
                violations = chain.from_iterable(
                    cursor.execute(
                        "PRAGMA foreign_key_check(%s)" % self.ops.quote_name(table_name)
                    ).fetchall()
                    for table_name in table_names
                )
            # See https://www.sqlite.org/pragma.html#pragma_foreign_key_check
            for (
                table_name,
                rowid,
                referenced_table_name,
                foreign_key_index,
            ) in violations:
                foreign_key = cursor.execute(
                    "PRAGMA foreign_key_list(%s)" % self.ops.quote_name(table_name)
                ).fetchall()[foreign_key_index]
                column_name, referenced_column_name = foreign_key[3:5]
                primary_key_column_name = self.introspection.get_primary_key_column(
                    cursor, table_name
                )
                primary_key_value, bad_value = cursor.execute(
                    "SELECT %s, %s FROM %s WHERE rowid = %%s"
                    % (
                        self.ops.quote_name(primary_key_column_name),
                        self.ops.quote_name(column_name),
                        self.ops.quote_name(table_name),
                    ),
                    (rowid,),
                ).fetchone()
                raise IntegrityError(
                    "The row in table '%s' with primary key '%s' has an "
                    "invalid foreign key: %s.%s contains a value '%s' that "
                    "does not have a corresponding value in %s.%s."
                    % (
                        table_name,
                        primary_key_value,
                        table_name,
                        column_name,
                        bad_value,
                        referenced_table_name,
                        referenced_column_name,
                    )
                )

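    # Typical usage (sketch; constraint_checks_disabled() is the base
    # backend's context manager wrapping the two methods above):
    #   with connection.constraint_checks_disabled():
    #       ...  # bulk-load rows that may arrive out of FK order
    #   connection.check_constraints(table_names=["myapp_book"])
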
    def is_usable(self):
        return True

    def _start_transaction_under_autocommit(self):
        """
        Start a transaction explicitly in autocommit mode.

        Staying in autocommit mode works around a bug of sqlite3 that breaks
        savepoints when autocommit is disabled.
        """
        if self.transaction_mode is None:
            self.cursor().execute("BEGIN")
        else:
            self.cursor().execute(f"BEGIN {self.transaction_mode}")

    def is_in_memory_db(self):
        return self.creation.is_in_memory_db(self.settings_dict["NAME"])


FORMAT_QMARK_REGEX = _lazy_re_compile(r"(?<!%)%s")
|
||||
|
||||
|
||||
class SQLiteCursorWrapper(Database.Cursor):
|
||||
"""
|
||||
Django uses the "format" and "pyformat" styles, but Python's sqlite3 module
|
||||
supports neither of these styles.
|
||||
|
||||
This wrapper performs the following conversions:
|
||||
|
||||
- "format" style to "qmark" style
|
||||
- "pyformat" style to "named" style
|
||||
|
||||
In both cases, if you want to use a literal "%s", you'll need to use "%%s".
|
||||
"""
|
||||
|
||||
def execute(self, query, params=None):
|
||||
if params is None:
|
||||
return super().execute(query)
|
||||
# Extract names if params is a mapping, i.e. "pyformat" style is used.
|
||||
param_names = list(params) if isinstance(params, Mapping) else None
|
||||
query = self.convert_query(query, param_names=param_names)
|
||||
return super().execute(query, params)
|
||||
|
||||
def executemany(self, query, param_list):
|
||||
# Extract names if params is a mapping, i.e. "pyformat" style is used.
|
||||
# Peek carefully as a generator can be passed instead of a list/tuple.
|
||||
peekable, param_list = tee(iter(param_list))
|
||||
if (params := next(peekable, None)) and isinstance(params, Mapping):
|
||||
param_names = list(params)
|
||||
else:
|
||||
param_names = None
|
||||
query = self.convert_query(query, param_names=param_names)
|
||||
return super().executemany(query, param_list)
|
||||
|
||||
def convert_query(self, query, *, param_names=None):
|
||||
if param_names is None:
|
||||
# Convert from "format" style to "qmark" style.
|
||||
return FORMAT_QMARK_REGEX.sub("?", query).replace("%%", "%")
|
||||
else:
|
||||
# Convert from "pyformat" style to "named" style.
|
||||
return query % {name: f":{name}" for name in param_names}
|
||||
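# Illustrative conversions performed by convert_query (sketch):
#   "SELECT %s, 100%%"      -> "SELECT ?, 100%"      (format -> qmark)
#   "WHERE name = %(name)s" -> "WHERE name = :name"  (pyformat -> named)
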
@@ -0,0 +1,10 @@
from django.db.backends.base.client import BaseDatabaseClient


class DatabaseClient(BaseDatabaseClient):
    executable_name = "sqlite3"

    @classmethod
    def settings_to_cmd_args_env(cls, settings_dict, parameters):
        args = [cls.executable_name, settings_dict["NAME"], *parameters]
        return args, None
@@ -0,0 +1,159 @@
import multiprocessing
import os
import shutil
import sqlite3
import sys
from pathlib import Path

from django.db import NotSupportedError
from django.db.backends.base.creation import BaseDatabaseCreation


class DatabaseCreation(BaseDatabaseCreation):
    @staticmethod
    def is_in_memory_db(database_name):
        return not isinstance(database_name, Path) and (
            database_name == ":memory:" or "mode=memory" in database_name
        )

    def _get_test_db_name(self):
        test_database_name = self.connection.settings_dict["TEST"]["NAME"] or ":memory:"
        if test_database_name == ":memory:":
            return "file:memorydb_%s?mode=memory&cache=shared" % self.connection.alias
        return test_database_name

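    # Example (sketch): for the "default" alias the in-memory test database
    # name becomes "file:memorydb_default?mode=memory&cache=shared", a URI
    # that lets several connections in one process share the same database.
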
    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        test_database_name = self._get_test_db_name()

        if keepdb:
            return test_database_name
        if not self.is_in_memory_db(test_database_name):
            # Erase the old test database
            if verbosity >= 1:
                self.log(
                    "Destroying old test database for alias %s..."
                    % (self._get_database_display_str(verbosity, test_database_name),)
                )
            if os.access(test_database_name, os.F_OK):
                if not autoclobber:
                    confirm = input(
                        "Type 'yes' if you would like to try deleting the test "
                        "database '%s', or 'no' to cancel: " % test_database_name
                    )
                if autoclobber or confirm == "yes":
                    try:
                        os.remove(test_database_name)
                    except Exception as e:
                        self.log("Got an error deleting the old test database: %s" % e)
                        sys.exit(2)
                else:
                    self.log("Tests cancelled.")
                    sys.exit(1)
        return test_database_name

    def get_test_db_clone_settings(self, suffix):
        orig_settings_dict = self.connection.settings_dict
        source_database_name = orig_settings_dict["NAME"] or ":memory:"

        if not self.is_in_memory_db(source_database_name):
            root, ext = os.path.splitext(source_database_name)
            return {**orig_settings_dict, "NAME": f"{root}_{suffix}{ext}"}

        start_method = multiprocessing.get_start_method()
        if start_method == "fork":
            return orig_settings_dict
        if start_method == "spawn":
            return {
                **orig_settings_dict,
                "NAME": f"{self.connection.alias}_{suffix}.sqlite3",
            }
        raise NotSupportedError(
            f"Cloning with start method {start_method!r} is not supported."
        )

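    # Example (sketch): with NAME "db.sqlite3" and suffix 1 the clone becomes
    # "db_1.sqlite3"; under the "spawn" start method an in-memory source is
    # instead mapped to an on-disk "<alias>_1.sqlite3" file.
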
    def _clone_test_db(self, suffix, verbosity, keepdb=False):
        source_database_name = self.connection.settings_dict["NAME"]
        target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
        if not self.is_in_memory_db(source_database_name):
            # Erase the old test database
            if os.access(target_database_name, os.F_OK):
                if keepdb:
                    return
                if verbosity >= 1:
                    self.log(
                        "Destroying old test database for alias %s..."
                        % (
                            self._get_database_display_str(
                                verbosity, target_database_name
                            ),
                        )
                    )
                try:
                    os.remove(target_database_name)
                except Exception as e:
                    self.log("Got an error deleting the old test database: %s" % e)
                    sys.exit(2)
            try:
                shutil.copy(source_database_name, target_database_name)
            except Exception as e:
                self.log("Got an error cloning the test database: %s" % e)
                sys.exit(2)
        # Forking automatically makes a copy of an in-memory database. Spawn
        # requires migrating the database to disk; it's re-opened in
        # setup_worker_connection().
        elif multiprocessing.get_start_method() == "spawn":
            ondisk_db = sqlite3.connect(target_database_name, uri=True)
            self.connection.connection.backup(ondisk_db)
            ondisk_db.close()

    def _destroy_test_db(self, test_database_name, verbosity):
        if test_database_name and not self.is_in_memory_db(test_database_name):
            # Remove the SQLite database file.
            os.remove(test_database_name)

    def test_db_signature(self):
        """
        Return a tuple that uniquely identifies a test database.

        This takes into account the special cases of ":memory:" and "" for
        SQLite since the databases will be distinct despite having the same
        TEST NAME. See https://www.sqlite.org/inmemorydb.html
        """
        test_database_name = self._get_test_db_name()
        sig = [self.connection.settings_dict["NAME"]]
        if self.is_in_memory_db(test_database_name):
            sig.append(self.connection.alias)
        else:
            sig.append(test_database_name)
        return tuple(sig)

    def setup_worker_connection(self, _worker_id):
        settings_dict = self.get_test_db_clone_settings(_worker_id)
        # connection.settings_dict must be updated in place for changes to be
        # reflected in django.db.connections. Otherwise new threads would
        # connect to the default database instead of the appropriate clone.
        start_method = multiprocessing.get_start_method()
        if start_method == "fork":
            # Update settings_dict in place.
            self.connection.settings_dict.update(settings_dict)
            self.connection.close()
        elif start_method == "spawn":
            alias = self.connection.alias
            connection_str = (
                f"file:memorydb_{alias}_{_worker_id}?mode=memory&cache=shared"
            )
            source_db = self.connection.Database.connect(
                f"file:{alias}_{_worker_id}.sqlite3?mode=ro", uri=True
            )
            target_db = sqlite3.connect(connection_str, uri=True)
            source_db.backup(target_db)
            source_db.close()
            # Update settings_dict in place.
            self.connection.settings_dict.update(settings_dict)
            self.connection.settings_dict["NAME"] = connection_str
            # Re-open connection to in-memory database before closing copy
            # connection.
            self.connection.connect()
            target_db.close()
            if os.environ.get("RUNNING_DJANGOS_TEST_SUITE") == "true":
                self.mark_expected_failures_and_skips()
@@ -0,0 +1,168 @@
import operator

from django.db import transaction
from django.db.backends.base.features import BaseDatabaseFeatures
from django.db.utils import OperationalError
from django.utils.functional import cached_property

from .base import Database


class DatabaseFeatures(BaseDatabaseFeatures):
    minimum_database_version = (3, 31)
    test_db_allows_multiple_connections = False
    supports_unspecified_pk = True
    supports_timezones = False
    max_query_params = 999
    supports_transactions = True
    atomic_transactions = False
    can_rollback_ddl = True
    can_create_inline_fk = False
    requires_literal_defaults = True
    can_clone_databases = True
    supports_temporal_subtraction = True
    ignores_table_name_case = True
    supports_cast_with_precision = False
    time_cast_precision = 3
    can_release_savepoints = True
    has_case_insensitive_like = True
    # Is "ALTER TABLE ... DROP COLUMN" supported?
    can_alter_table_drop_column = Database.sqlite_version_info >= (3, 35, 5)
    supports_parentheses_in_compound = False
    can_defer_constraint_checks = True
    supports_over_clause = True
    supports_frame_range_fixed_distance = True
    supports_frame_exclusion = True
    supports_aggregate_filter_clause = True
    order_by_nulls_first = True
    supports_json_field_contains = False
    supports_update_conflicts = True
    supports_update_conflicts_with_target = True
    supports_stored_generated_columns = True
    supports_virtual_generated_columns = True
    test_collations = {
        "ci": "nocase",
        "cs": "binary",
        "non_default": "nocase",
        "virtual": "nocase",
    }
    django_test_expected_failures = {
        # The django_format_dtdelta() function doesn't properly handle mixed
        # Date/DateTime fields and timedeltas.
        "expressions.tests.FTimeDeltaTests.test_mixed_comparisons1",
    }
    create_test_table_with_composite_primary_key = """
        CREATE TABLE test_table_composite_pk (
            column_1 INTEGER NOT NULL,
            column_2 INTEGER NOT NULL,
            PRIMARY KEY(column_1, column_2)
        )
    """
    insert_test_table_with_defaults = 'INSERT INTO {} ("null") VALUES (1)'
    supports_default_keyword_in_insert = False
    supports_unlimited_charfield = True
    supports_tuple_lookups = False

    @cached_property
    def django_test_skips(self):
        skips = {
            "SQLite stores values rounded to 15 significant digits.": {
                "model_fields.test_decimalfield.DecimalFieldTests."
                "test_fetch_from_db_without_float_rounding",
            },
            "SQLite naively remakes the table on field alteration.": {
                "schema.tests.SchemaTests.test_unique_no_unnecessary_fk_drops",
                "schema.tests.SchemaTests.test_unique_and_reverse_m2m",
                "schema.tests.SchemaTests."
                "test_alter_field_default_doesnt_perform_queries",
                "schema.tests.SchemaTests."
                "test_rename_column_renames_deferred_sql_references",
            },
            "SQLite doesn't support negative precision for ROUND().": {
                "db_functions.math.test_round.RoundTests."
                "test_null_with_negative_precision",
                "db_functions.math.test_round.RoundTests."
                "test_decimal_with_negative_precision",
                "db_functions.math.test_round.RoundTests."
                "test_float_with_negative_precision",
                "db_functions.math.test_round.RoundTests."
                "test_integer_with_negative_precision",
            },
            "The actual query cannot be determined on SQLite": {
                "backends.base.test_base.ExecuteWrapperTests.test_wrapper_debug",
            },
        }
        if self.connection.is_in_memory_db():
            skips.update(
                {
                    "the sqlite backend's close() method is a no-op when using an "
                    "in-memory database": {
                        "servers.test_liveserverthread.LiveServerThreadTest."
                        "test_closes_connections",
                        "servers.tests.LiveServerTestCloseConnectionTest."
                        "test_closes_connections",
                    },
                    "For SQLite in-memory tests, closing the connection destroys "
                    "the database.": {
                        "test_utils.tests.AssertNumQueriesUponConnectionTests."
                        "test_ignores_connection_configuration_queries",
                    },
                }
            )
        else:
            skips.update(
                {
                    "Only connections to in-memory SQLite databases are passed to the "
                    "server thread.": {
                        "servers.tests.LiveServerInMemoryDatabaseLockTest."
                        "test_in_memory_database_lock",
                    },
                    "multiprocessing's start method is checked only for in-memory "
                    "SQLite databases": {
                        "backends.sqlite.test_creation.TestDbSignatureTests."
                        "test_get_test_db_clone_settings_not_supported",
                    },
                }
            )
        if Database.sqlite_version_info < (3, 47):
            skips.update(
                {
                    "SQLite does not parse escaped double quotes in the JSON path "
                    "notation": {
                        "model_fields.test_jsonfield.TestQuerying."
                        "test_lookups_special_chars_double_quotes",
                    },
                }
            )
        return skips

    @cached_property
    def introspected_field_types(self):
        return {
            **super().introspected_field_types,
            "BigAutoField": "AutoField",
            "DurationField": "BigIntegerField",
            "GenericIPAddressField": "CharField",
            "SmallAutoField": "AutoField",
        }

    @cached_property
    def supports_json_field(self):
        with self.connection.cursor() as cursor:
            try:
                with transaction.atomic(self.connection.alias):
                    cursor.execute('SELECT JSON(\'{"a": "b"}\')')
            except OperationalError:
                return False
        return True

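    # Note (sketch): this probe runs "SELECT JSON(...)" once inside an atomic
    # block; SQLite builds compiled without the JSON1 extension raise
    # OperationalError, and the cached flag then stays False.
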
can_introspect_json_field = property(operator.attrgetter("supports_json_field"))
|
||||
has_json_object_function = property(operator.attrgetter("supports_json_field"))
|
||||
|
||||
@cached_property
|
||||
def can_return_columns_from_insert(self):
|
||||
return Database.sqlite_version_info >= (3, 35)
|
||||
|
||||
can_return_rows_from_bulk_insert = property(
|
||||
operator.attrgetter("can_return_columns_from_insert")
|
||||
)
|
||||
@@ -0,0 +1,440 @@
from collections import namedtuple

import sqlparse

from django.db import DatabaseError
from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile

FieldInfo = namedtuple(
    "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
)

field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")


def get_field_size(name):
    """Extract the size number from a "varchar(11)" type name."""
    m = field_size_re.search(name)
    return int(m[1]) if m else None

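# Examples (sketch): get_field_size("varchar(30)") -> 30, while
# get_field_size("text") -> None because no length is declared.
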
# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
    # Maps SQL types to Django Field types. Some of the SQL types have multiple
    # entries here because SQLite allows for anything and doesn't normalize the
    # field type; it uses whatever was given.
    base_data_types_reverse = {
        "bool": "BooleanField",
        "boolean": "BooleanField",
        "smallint": "SmallIntegerField",
        "smallint unsigned": "PositiveSmallIntegerField",
        "smallinteger": "SmallIntegerField",
        "int": "IntegerField",
        "integer": "IntegerField",
        "bigint": "BigIntegerField",
        "integer unsigned": "PositiveIntegerField",
        "bigint unsigned": "PositiveBigIntegerField",
        "decimal": "DecimalField",
        "real": "FloatField",
        "text": "TextField",
        "char": "CharField",
        "varchar": "CharField",
        "blob": "BinaryField",
        "date": "DateField",
        "datetime": "DateTimeField",
        "time": "TimeField",
    }

    def __getitem__(self, key):
        key = key.lower().split("(", 1)[0].strip()
        return self.base_data_types_reverse[key]

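# Example (sketch): data_types_reverse["varchar(30)"] -> "CharField"; the
# lookup lowercases the key and strips any "(...)" size suffix first.
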
class DatabaseIntrospection(BaseDatabaseIntrospection):
    data_types_reverse = FlexibleFieldLookupDict()

    def get_field_type(self, data_type, description):
        field_type = super().get_field_type(data_type, description)
        if description.pk and field_type in {
            "BigIntegerField",
            "IntegerField",
            "SmallIntegerField",
        }:
            # No support for BigAutoField or SmallAutoField as SQLite treats
            # all integer primary keys as signed 64-bit integers.
            return "AutoField"
        if description.has_json_constraint:
            return "JSONField"
        return field_type

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        # Skip the sqlite_sequence system table used for autoincrement key
        # generation.
        cursor.execute(
            """
            SELECT name, type FROM sqlite_master
            WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
            ORDER BY name"""
        )
        return [TableInfo(row[0], row[1][0]) for row in cursor.fetchall()]

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.
        """
        cursor.execute(
            "PRAGMA table_xinfo(%s)" % self.connection.ops.quote_name(table_name)
        )
        table_info = cursor.fetchall()
        if not table_info:
            raise DatabaseError(f"Table {table_name} does not exist (empty pragma).")
        collations = self._get_column_collations(cursor, table_name)
        json_columns = set()
        if self.connection.features.can_introspect_json_field:
            for line in table_info:
                column = line[1]
                json_constraint_sql = '%%json_valid("%s")%%' % column
                has_json_constraint = cursor.execute(
                    """
                    SELECT sql
                    FROM sqlite_master
                    WHERE
                        type = 'table' AND
                        name = %s AND
                        sql LIKE %s
                    """,
                    [table_name, json_constraint_sql],
                ).fetchone()
                if has_json_constraint:
                    json_columns.add(column)
        return [
            FieldInfo(
                name,
                data_type,
                get_field_size(data_type),
                None,
                None,
                None,
                not notnull,
                default,
                collations.get(name),
                pk == 1,
                name in json_columns,
            )
            for cid, name, data_type, notnull, default, pk, hidden in table_info
            if hidden
            in [
                0,  # Normal column.
                2,  # Virtual generated column.
                3,  # Stored generated column.
            ]
        ]

    def get_sequences(self, cursor, table_name, table_fields=()):
        pk_col = self.get_primary_key_column(cursor, table_name)
        return [{"table": table_name, "column": pk_col}]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {column_name: (ref_column_name, ref_table_name)}
        representing all foreign keys in the given table.
        """
        cursor.execute(
            "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        return {
            column_name: (ref_column_name, ref_table_name)
            for (
                _,
                _,
                ref_table_name,
                column_name,
                ref_column_name,
                *_,
            ) in cursor.fetchall()
        }

    def get_primary_key_columns(self, cursor, table_name):
        cursor.execute(
            "PRAGMA table_info(%s)" % self.connection.ops.quote_name(table_name)
        )
        return [name for _, name, *_, pk in cursor.fetchall() if pk]

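    # Example (sketch): "PRAGMA foreign_key_list(book)" yields rows like
    # (0, 0, "author", "author_id", "id", ...), which get_relations turns
    # into {"author_id": ("id", "author")}.
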
    def _parse_column_or_constraint_definition(self, tokens, columns):
        token = None
        is_constraint_definition = None
        field_name = None
        constraint_name = None
        unique = False
        unique_columns = []
        check = False
        check_columns = []
        braces_deep = 0
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                braces_deep += 1
            elif token.match(sqlparse.tokens.Punctuation, ")"):
                braces_deep -= 1
                if braces_deep < 0:
                    # End of columns and constraints for table definition.
                    break
            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
                # End of current column or constraint definition.
                break
            # Detect column or constraint definition by first token.
            if is_constraint_definition is None:
                is_constraint_definition = token.match(
                    sqlparse.tokens.Keyword, "CONSTRAINT"
                )
                if is_constraint_definition:
                    continue
            if is_constraint_definition:
                # Detect constraint name by second token.
                if constraint_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        constraint_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        constraint_name = token.value[1:-1]
                # Start constraint columns parsing after UNIQUE keyword.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique = True
                    unique_braces_deep = braces_deep
                elif unique:
                    if unique_braces_deep == braces_deep:
                        if unique_columns:
                            # Stop constraint parsing.
                            unique = False
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        unique_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        unique_columns.append(token.value[1:-1])
            else:
                # Detect field name by first token.
                if field_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        field_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        field_name = token.value[1:-1]
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique_columns = [field_name]
                # Start constraint columns parsing after CHECK keyword.
                if token.match(sqlparse.tokens.Keyword, "CHECK"):
                    check = True
                    check_braces_deep = braces_deep
                elif check:
                    if check_braces_deep == braces_deep:
                        if check_columns:
                            # Stop constraint parsing.
                            check = False
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        if token.value in columns:
                            check_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        if token.value[1:-1] in columns:
                            check_columns.append(token.value[1:-1])
        unique_constraint = (
            {
                "unique": True,
                "columns": unique_columns,
                "primary_key": False,
                "foreign_key": None,
                "check": False,
                "index": False,
            }
            if unique_columns
            else None
        )
        check_constraint = (
            {
                "check": True,
                "columns": check_columns,
                "primary_key": False,
                "unique": False,
                "foreign_key": None,
                "index": False,
            }
            if check_columns
            else None
        )
        return constraint_name, unique_constraint, check_constraint, token

    def _parse_table_constraints(self, sql, columns):
        # Check constraint parsing is based on the SQLite syntax diagram.
        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
        statement = sqlparse.parse(sql)[0]
        constraints = {}
        unnamed_constraints_index = 0
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        # Go to columns and constraint definition.
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                break
        # Parse columns and constraint definition.
        while True:
            (
                constraint_name,
                unique,
                check,
                end_token,
            ) = self._parse_column_or_constraint_definition(tokens, columns)
            if unique:
                if constraint_name:
                    constraints[constraint_name] = unique
                else:
                    unnamed_constraints_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constraints_index
                    ] = unique
            if check:
                if constraint_name:
                    constraints[constraint_name] = check
                else:
                    unnamed_constraints_index += 1
                    constraints[
                        "__unnamed_constraint_%s__" % unnamed_constraints_index
                    ] = check
            if end_token.match(sqlparse.tokens.Punctuation, ")"):
                break
        return constraints

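    # Example (sketch): for
    #   CREATE TABLE t (a INT, b INT, CONSTRAINT u UNIQUE (a, b), CHECK (a > 0))
    # this returns {"u": <unique constraint on ["a", "b"]>,
    # "__unnamed_constraint_1__": <check constraint on ["a"]>}.
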
    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.
        """
        constraints = {}
        # Find inline check constraints.
        try:
            table_schema = cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' and name=%s",
                [table_name],
            ).fetchone()[0]
        except TypeError:
            # table_name is a view.
            pass
        else:
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            constraints.update(self._parse_table_constraints(table_schema, columns))

        # Get the index info.
        cursor.execute(
            "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        for row in cursor.fetchall():
            # SQLite 3.8.9+ returns 5 columns; older versions return only 3.
            # Discard the last two columns if present.
            number, index, unique = row[:3]
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='index' AND name=%s",
                [index],
            )
            # There's at most one row.
            (sql,) = cursor.fetchone() or (None,)
            # Inline constraints are already detected in
            # _parse_table_constraints(). The reasons to avoid fetching inline
            # constraints from `PRAGMA index_list` are:
            # - Inline constraints can have a different name and information
            #   than what `PRAGMA index_list` gives.
            # - Not all inline constraints may appear in `PRAGMA index_list`.
            if not sql:
                # An inline constraint.
                continue
            # Get the index info for that index.
            cursor.execute(
                "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
            )
            for index_rank, column_rank, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": None,
                        "check": False,
                        "index": True,
                    }
                constraints[index]["columns"].append(column)
            # Add type and column orders for indexes.
            if constraints[index]["index"]:
                # SQLite doesn't support any index type other than b-tree.
                constraints[index]["type"] = Index.suffix
                orders = self._get_index_columns_orders(sql)
                if orders is not None:
                    constraints[index]["orders"] = orders
        # Get the PK.
        pk_columns = self.get_primary_key_columns(cursor, table_name)
        if pk_columns:
            # SQLite doesn't actually give a name to the PK constraint,
            # so we invent one. This is fine, as the SQLite backend never
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": pk_columns,
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": None,
                "check": False,
                "index": False,
            }
        relations = enumerate(self.get_relations(cursor, table_name).items())
        constraints.update(
            {
                f"fk_{index}": {
                    "columns": [column_name],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": (ref_table_name, ref_column_name),
                    "check": False,
                    "index": False,
                }
                for index, (column_name, (ref_column_name, ref_table_name)) in relations
            }
        )
        return constraints

    def _get_index_columns_orders(self, sql):
        tokens = sqlparse.parse(sql)[0]
        for token in tokens:
            if isinstance(token, sqlparse.sql.Parenthesis):
                columns = str(token).strip("()").split(", ")
                return ["DESC" if info.endswith("DESC") else "ASC" for info in columns]
        return None

    def _get_column_collations(self, cursor, table_name):
        row = cursor.execute(
            """
            SELECT sql
            FROM sqlite_master
            WHERE type = 'table' AND name = %s
            """,
            [table_name],
        ).fetchone()
        if not row:
            return {}

        sql = row[0]
        columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
        collations = {}
        for column in columns:
            tokens = column[1:].split()
            column_name = tokens[0].strip('"')
            for index, token in enumerate(tokens):
                if token == "COLLATE":
                    collation = tokens[index + 1]
                    break
            else:
                collation = None
            collations[column_name] = collation
        return collations
@@ -0,0 +1,443 @@
import datetime
import decimal
import uuid
from functools import lru_cache
from itertools import chain

from django.conf import settings
from django.core.exceptions import FieldError
from django.db import DatabaseError, NotSupportedError, models
from django.db.backends.base.operations import BaseDatabaseOperations
from django.db.models.constants import OnConflict
from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.functional import cached_property

from .base import Database


class DatabaseOperations(BaseDatabaseOperations):
    cast_char_field_without_max_length = "text"
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).
        """
        fields = list(
            chain.from_iterable(
                (
                    field.fields
                    if isinstance(field, models.CompositePrimaryKey)
                    else [field]
                )
                for field in fields
            )
        )
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)

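    # Worked example (sketch): with the default 999-variable limit, a
    # four-field insert batches at 999 // 4 = 249 objects per query, while a
    # single-field insert is capped at 500 by SQLITE_MAX_COMPOUND_SELECT.
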
    def check_expression_support(self, expression):
        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except (AttributeError, FieldError):
                    # Not every subexpression has an output_field, which is
                    # fine to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise NotSupportedError(
                            "You cannot use Sum, Avg, StdDev, and Variance "
                            "aggregations on date/time fields in sqlite3 "
                            "since date/time is saved as text."
                        )
        if (
            isinstance(expression, models.Aggregate)
            and expression.distinct
            and len(expression.source_expressions) > 1
        ):
            raise NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Support EXTRACT with a user-defined function django_date_extract()
        that's registered in connect(). Use single quotes because this is a
        string and could otherwise cause a collision with a field name.
        """
        return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)

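    # Example (sketch): for lookup_type "year" over a column placeholder
    # "pub_date" this returns ("django_date_extract(%s, pub_date)", ("year",)),
    # delegating the work to the user-defined function registered in connect().
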
    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the list of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        return f"django_date_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        return f"django_time_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def _convert_tznames_to_sql(self, tzname):
        if tzname and settings.USE_TZ:
            return tzname, self.connection.timezone_name
        return None, None

    def datetime_cast_date_sql(self, sql, params, tzname):
        return f"django_datetime_cast_date({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        return f"django_datetime_cast_time({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        return f"django_datetime_extract(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, sql, params):
        return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)

    def pk_default_value(self):
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!
        """
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index : index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
        # Bypass Django's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        # bind_parameters(state, self->statement, parameters);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, (list, tuple)):
                params = self._quote_params_for_last_executed_query(params)
            else:
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def no_limit_value(self):
        return -1

    def __references_graph(self, table_name):
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self):
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Django's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        if tables and allow_cascade:
            # Simulate TRUNCATE CASCADE by recursively collecting the tables
            # referencing the tables to be flushed.
            tables = set(
                chain.from_iterable(self._references_graph(table) for table in tables)
            )
        sql = [
            "%s %s %s;"
            % (
                style.SQL_KEYWORD("DELETE"),
                style.SQL_KEYWORD("FROM"),
                style.SQL_FIELD(self.quote_name(table)),
            )
            for table in tables
        ]
        if reset_sequences:
            sequences = [{"table": table} for table in tables]
            sql.extend(self.sequence_reset_by_name_sql(style, sequences))
        return sql

    def sequence_reset_by_name_sql(self, style, sequences):
        if not sequences:
            return []
        return [
            "%s %s %s %s = 0 %s %s %s (%s);"
            % (
                style.SQL_KEYWORD("UPDATE"),
                style.SQL_TABLE(self.quote_name("sqlite_sequence")),
                style.SQL_KEYWORD("SET"),
                style.SQL_FIELD(self.quote_name("seq")),
                style.SQL_KEYWORD("WHERE"),
                style.SQL_FIELD(self.quote_name("name")),
                style.SQL_KEYWORD("IN"),
                ", ".join(
                    ["'%s'" % sequence_info["table"] for sequence_info in sequences]
                ),
            ),
        ]

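    # Example (sketch): sql_flush(style, ["myapp_book"], reset_sequences=True)
    # emits 'DELETE FROM "myapp_book";' followed by an UPDATE on
    # sqlite_sequence that zeroes the table's AUTOINCREMENT counter.
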
    def adapt_datetimefield_value(self, value):
        if value is None:
            return None

        # SQLite doesn't support tz-aware datetimes.
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError(
                    "SQLite backend does not support timezone-aware datetimes when "
                    "USE_TZ is False."
                )

        return str(value)

    def adapt_timefield_value(self, value):
        if value is None:
            return None

        # SQLite doesn't support tz-aware times.
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "DecimalField":
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )

        else:

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)

        return converter

    def convert_uuidfield_value(self, value, expression, connection):
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        return bool(value) if value in (1, 0) else value

def combine_expression(self, connector, sub_expressions):
|
||||
# SQLite doesn't have a ^ operator, so use the user-defined POWER
|
||||
# function that's registered in connect().
|
||||
if connector == "^":
|
||||
return "POWER(%s)" % ",".join(sub_expressions)
|
||||
elif connector == "#":
|
||||
return "BITXOR(%s)" % ",".join(sub_expressions)
|
||||
return super().combine_expression(connector, sub_expressions)
|
||||
|
||||
def combine_duration_expression(self, connector, sub_expressions):
|
||||
if connector not in ["+", "-", "*", "/"]:
|
||||
raise DatabaseError("Invalid connector for timedelta: %s." % connector)
|
||||
fn_params = ["'%s'" % connector] + sub_expressions
|
||||
if len(fn_params) > 3:
|
||||
raise ValueError("Too many params for timedelta operations.")
|
||||
return "django_format_dtdelta(%s)" % ", ".join(fn_params)
|
||||
|
||||
def integer_field_range(self, internal_type):
|
||||
# SQLite doesn't enforce any integer constraints, but sqlite3 supports
|
||||
# integers up to 64 bits.
|
||||
if internal_type in [
|
||||
"PositiveBigIntegerField",
|
||||
"PositiveIntegerField",
|
||||
"PositiveSmallIntegerField",
|
||||
]:
|
||||
return (0, 9223372036854775807)
|
||||
return (-9223372036854775808, 9223372036854775807)
|
||||
|
||||
def subtract_temporals(self, internal_type, lhs, rhs):
|
||||
lhs_sql, lhs_params = lhs
|
||||
rhs_sql, rhs_params = rhs
|
||||
params = (*lhs_params, *rhs_params)
|
||||
if internal_type == "TimeField":
|
||||
return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
|
||||
return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params
|
||||
|
||||
def insert_statement(self, on_conflict=None):
|
||||
if on_conflict == OnConflict.IGNORE:
|
||||
return "INSERT OR IGNORE INTO"
|
||||
return super().insert_statement(on_conflict=on_conflict)
|
||||
|
||||
def return_insert_columns(self, fields):
|
||||
# SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
|
||||
if not fields:
|
||||
return "", ()
|
||||
columns = [
|
||||
"%s.%s"
|
||||
% (
|
||||
self.quote_name(field.model._meta.db_table),
|
||||
self.quote_name(field.column),
|
||||
)
|
||||
for field in fields
|
||||
]
|
||||
return "RETURNING %s" % ", ".join(columns), ()
|
||||
|
||||
def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
|
||||
if (
|
||||
on_conflict == OnConflict.UPDATE
|
||||
and self.connection.features.supports_update_conflicts_with_target
|
||||
):
|
||||
return "ON CONFLICT(%s) DO UPDATE SET %s" % (
|
||||
", ".join(map(self.quote_name, unique_fields)),
|
||||
", ".join(
|
||||
[
|
||||
f"{field} = EXCLUDED.{field}"
|
||||
for field in map(self.quote_name, update_fields)
|
||||
]
|
||||
),
|
||||
)
|
||||
return super().on_conflict_suffix_sql(
|
||||
fields,
|
||||
on_conflict,
|
||||
update_fields,
|
||||
unique_fields,
|
||||
)
|
||||
|
||||
def force_group_by(self):
|
||||
return ["GROUP BY TRUE"] if Database.sqlite_version_info < (3, 39) else []
|
||||
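# Editor's note: an illustrative sketch, not part of Django. Assuming a
# configured project with a SQLite default connection and two hypothetical
# tables, sql_flush() above emits one DELETE per table plus an optional
# sqlite_sequence reset:
#
#     from django.core.management.color import no_style
#     from django.db import connection
#
#     connection.ops.sql_flush(
#         no_style(), ["app_book", "app_author"], reset_sequences=True
#     )
#     # -> ['DELETE FROM "app_book";',
#     #     'DELETE FROM "app_author";',
#     #     'UPDATE "sqlite_sequence" SET "seq" = 0 WHERE "name" IN (\'app_book\', \'app_author\');']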
@@ -0,0 +1,503 @@
import copy
from decimal import Decimal

from django.apps.registry import Apps
from django.db import NotSupportedError
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.backends.utils import strip_quotes
from django.db.models import CompositePrimaryKey, UniqueConstraint


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    sql_delete_table = "DROP TABLE %(table)s"
    sql_create_fk = None
    sql_create_inline_fk = (
        "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    )
    sql_create_column_inline_fk = sql_create_inline_fk
    sql_delete_column = "ALTER TABLE %(table)s DROP COLUMN %(column)s"
    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
    sql_delete_unique = "DROP INDEX %(name)s"
    sql_alter_table_comment = None
    sql_alter_column_comment = None

    def __enter__(self):
        # Some SQLite schema alterations need foreign key constraints to be
        # disabled. Enforce it here for the duration of the schema edition.
        if not self.connection.disable_constraint_checking():
            raise NotSupportedError(
                "SQLite schema editor cannot be used while foreign key "
                "constraint checks are enabled. Make sure to disable them "
                "before entering a transaction.atomic() context because "
                "SQLite does not support disabling them in the middle of "
                "a multi-statement transaction."
            )
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.check_constraints()
        super().__exit__(exc_type, exc_value, traceback)
        self.connection.enable_constraint_checking()

    def quote_value(self, value):
        # The backend "mostly works" without this function and there are use
        # cases for compiling Python without the sqlite3 libraries (e.g.
        # security hardening).
        try:
            import sqlite3

            value = sqlite3.adapt(value)
        except ImportError:
            pass
        except sqlite3.ProgrammingError:
            pass
        # Manual emulation of SQLite parameter quoting
        if isinstance(value, bool):
            return str(int(value))
        elif isinstance(value, (Decimal, float, int)):
            return str(value)
        elif isinstance(value, str):
            return "'%s'" % value.replace("'", "''")
        elif value is None:
            return "NULL"
        elif isinstance(value, (bytes, bytearray, memoryview)):
            # Bytes are only allowed for BLOB fields, encoded as string
            # literals containing hexadecimal data and preceded by a single "X"
            # character.
            return "X'%s'" % value.hex()
        else:
            raise ValueError(
                "Cannot quote parameter value %r of type %s" % (value, type(value))
            )

    def prepare_default(self, value):
        return self.quote_value(value)

    def _remake_table(
        self, model, create_field=None, delete_field=None, alter_fields=None
    ):
        """
        Shortcut to transform a model from old_model into new_model.

        This follows the correct procedure to perform non-rename or column
        addition operations based on SQLite's documentation

        https://www.sqlite.org/lang_altertable.html#caution

        The essential steps are:
        1. Create a table with the updated definition called "new__app_model"
        2. Copy the data from the existing "app_model" table to the new table
        3. Drop the "app_model" table
        4. Rename the "new__app_model" table to "app_model"
        5. Restore any index of the previous "app_model" table.
        """

        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f):
            return f.is_relation and f.remote_field.model is model

        # Work out the new fields dict / mapping
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._meta.local_concrete_fields
        }

        # Since CompositePrimaryKey is not a concrete field (column is None),
        # it's not copied by default.
        pk = model._meta.pk
        if isinstance(pk, CompositePrimaryKey):
            body[pk.name] = pk.clone()

        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = {
            f.column: self.quote_name(f.column)
            for f in model._meta.local_concrete_fields
            if f.generated is False
        }
        # This maps field names (not columns) for things like unique_together
        rename_mapping = {}
        # If any of the new or altered fields is introducing a new PK,
        # remove the old one
        restore_pk_field = None
        alter_fields = alter_fields or []
        if getattr(create_field, "primary_key", False) or any(
            getattr(new_field, "primary_key", False) for _, new_field in alter_fields
        ):
            for name, field in list(body.items()):
                if field.primary_key and not any(
                    # Do not remove the old primary key when an altered field
                    # that introduces a primary key is the same field.
                    name == new_field.name
                    for _, new_field in alter_fields
                ):
                    field.primary_key = False
                    restore_pk_field = field
                    if field.auto_created:
                        del body[name]
                        del mapping[field.column]
        # Add in any created fields
        if create_field:
            body[create_field.name] = create_field
            # Choose a default and insert it into the copy map
            if (
                not create_field.has_db_default()
                and not (create_field.many_to_many or create_field.generated)
                and create_field.concrete
            ):
                mapping[create_field.column] = self.prepare_default(
                    self.effective_default(create_field)
                )
        # Add in any altered fields
        for alter_field in alter_fields:
            old_field, new_field = alter_field
            body.pop(old_field.name, None)
            mapping.pop(old_field.column, None)
            body[new_field.name] = new_field
            rename_mapping[old_field.name] = new_field.name
            if new_field.generated:
                continue
            if old_field.null and not new_field.null:
                if not new_field.has_db_default():
                    default = self.prepare_default(self.effective_default(new_field))
                else:
                    default, _ = self.db_default_sql(new_field)
                case_sql = "coalesce(%(col)s, %(default)s)" % {
                    "col": self.quote_name(old_field.column),
                    "default": default,
                }
                mapping[new_field.column] = case_sql
            else:
                mapping[new_field.column] = self.quote_name(old_field.column)
        # Remove any deleted fields
        if delete_field:
            del body[delete_field.name]
            mapping.pop(delete_field.column, None)
            # Remove any implicit M2M tables
            if (
                delete_field.many_to_many
                and delete_field.remote_field.through._meta.auto_created
            ):
                return self.delete_model(delete_field.remote_field.through)
        # Work inside a new app registry
        apps = Apps()

        # Work out the new value of unique_together, taking renames into
        # account
        unique_together = [
            [rename_mapping.get(n, n) for n in unique]
            for unique in model._meta.unique_together
        ]

        indexes = model._meta.indexes
        if delete_field:
            indexes = [
                index for index in indexes if delete_field.name not in index.fields
            ]

        constraints = list(model._meta.constraints)

        # Provide isolated instances of the fields to the new model body so
        # that the existing model's internals aren't interfered with when
        # the dummy model is constructed.
        body_copy = copy.deepcopy(body)

        # Construct a new model with the new fields to allow self referential
        # primary key to resolve to. This model won't ever be materialized as a
        # table and solely exists for foreign key reference resolution purposes.
        # This wouldn't be required if the schema editor was operating on model
        # states instead of rendered models.
        meta_contents = {
            "app_label": model._meta.app_label,
            "db_table": model._meta.db_table,
            "unique_together": unique_together,
            "indexes": indexes,
            "constraints": constraints,
            "apps": apps,
        }
        meta = type("Meta", (), meta_contents)
        body_copy["Meta"] = meta
        body_copy["__module__"] = model.__module__
        type(model._meta.object_name, model.__bases__, body_copy)

        # Construct a model with a renamed table name.
        body_copy = copy.deepcopy(body)
        meta_contents = {
            "app_label": model._meta.app_label,
            "db_table": "new__%s" % strip_quotes(model._meta.db_table),
            "unique_together": unique_together,
            "indexes": indexes,
            "constraints": constraints,
            "apps": apps,
        }
        meta = type("Meta", (), meta_contents)
        body_copy["Meta"] = meta
        body_copy["__module__"] = model.__module__
        new_model = type("New%s" % model._meta.object_name, model.__bases__, body_copy)

        # Remove the automatically recreated default primary key, if it has
        # been deleted.
        if delete_field and delete_field.attname == new_model._meta.pk.attname:
            auto_pk = new_model._meta.pk
            delattr(new_model, auto_pk.attname)
            new_model._meta.local_fields.remove(auto_pk)
            new_model.pk = None

        # Create a new table with the updated schema.
        self.create_model(new_model)

        # Copy data from the old table into the new table
        self.execute(
            "INSERT INTO %s (%s) SELECT %s FROM %s"
            % (
                self.quote_name(new_model._meta.db_table),
                ", ".join(self.quote_name(x) for x in mapping),
                ", ".join(mapping.values()),
                self.quote_name(model._meta.db_table),
            )
        )

        # Delete the old table to make way for the new
        self.delete_model(model, handle_autom2m=False)

        # Rename the new table to take the place of the old one
        self.alter_db_table(
            new_model,
            new_model._meta.db_table,
            model._meta.db_table,
        )

        # Run deferred SQL on correct table
        for sql in self.deferred_sql:
            self.execute(sql)
        self.deferred_sql = []
        # Fix any PK-removed field
        if restore_pk_field:
            restore_pk_field.primary_key = True

    def delete_model(self, model, handle_autom2m=True):
        if handle_autom2m:
            super().delete_model(model)
        else:
            # Delete the table (and only that)
            self.execute(
                self.sql_delete_table
                % {
                    "table": self.quote_name(model._meta.db_table),
                }
            )
            # Remove all deferred statements referencing the deleted table.
            for sql in list(self.deferred_sql):
                if isinstance(sql, Statement) and sql.references_table(
                    model._meta.db_table
                ):
                    self.deferred_sql.remove(sql)

    def add_field(self, model, field):
        """Create a field on a model."""
        from django.db.models.expressions import Value

        # Special-case implicit M2M tables.
        if field.many_to_many and field.remote_field.through._meta.auto_created:
            self.create_model(field.remote_field.through)
        elif isinstance(field, CompositePrimaryKey):
            # If a CompositePrimaryKey field was added, the existing primary key field
            # had to be altered too, resulting in an AddField, AlterField migration.
            # The table cannot be re-created on AddField, it would result in a
            # duplicate primary key error.
            return
        elif (
            # Primary keys and unique fields are not supported in ALTER TABLE
            # ADD COLUMN.
            field.primary_key
            or field.unique
            or not field.null
            # Fields with default values cannot be handled by the ALTER TABLE
            # ADD COLUMN statement because DROP DEFAULT is not supported in
            # ALTER TABLE.
            or self.effective_default(field) is not None
            # Fields with non-constant defaults cannot be handled by the ALTER
            # TABLE ADD COLUMN statement.
            or (field.has_db_default() and not isinstance(field.db_default, Value))
        ):
            self._remake_table(model, create_field=field)
        else:
            super().add_field(model, field)

    def remove_field(self, model, field):
        """
        Remove a field from a model. Usually involves deleting a column,
        but for M2Ms may involve deleting a table.
        """
        # M2M fields are a special case
        if field.many_to_many:
            # For implicit M2M tables, delete the auto-created table
            if field.remote_field.through._meta.auto_created:
                self.delete_model(field.remote_field.through)
            # For explicit "through" M2M fields, do nothing
        elif (
            self.connection.features.can_alter_table_drop_column
            # Primary keys, unique fields, indexed fields, and foreign keys are
            # not supported in ALTER TABLE DROP COLUMN.
            and not field.primary_key
            and not field.unique
            and not field.db_index
            and not (field.remote_field and field.db_constraint)
        ):
            super().remove_field(model, field)
        # For everything else, remake.
        else:
            # It might not actually have a column behind it
            if field.db_parameters(connection=self.connection)["type"] is None:
                return
            self._remake_table(model, delete_field=field)

    def _alter_field(
        self,
        model,
        old_field,
        new_field,
        old_type,
        new_type,
        old_db_params,
        new_db_params,
        strict=False,
    ):
        """Perform a "physical" (non-ManyToMany) field update."""
        # Use "ALTER TABLE ... RENAME COLUMN" if only the column name
        # changed and there aren't any constraints.
        if (
            old_field.column != new_field.column
            and self.column_sql(model, old_field) == self.column_sql(model, new_field)
            and not (
                old_field.remote_field
                and old_field.db_constraint
                or new_field.remote_field
                and new_field.db_constraint
            )
        ):
            return self.execute(
                self._rename_field_sql(
                    model._meta.db_table, old_field, new_field, new_type
                )
            )
        # Alter by remaking table
        self._remake_table(model, alter_fields=[(old_field, new_field)])
        # Rebuild tables with FKs pointing to this field.
        old_collation = old_db_params.get("collation")
        new_collation = new_db_params.get("collation")
        if new_field.unique and (
            old_type != new_type or old_collation != new_collation
        ):
            related_models = set()
            opts = new_field.model._meta
            for remote_field in opts.related_objects:
                # Ignore self-relationship since the table was already rebuilt.
                if remote_field.related_model == model:
                    continue
                if not remote_field.many_to_many:
                    if remote_field.field_name == new_field.name:
                        related_models.add(remote_field.related_model)
                elif new_field.primary_key and remote_field.through._meta.auto_created:
                    related_models.add(remote_field.through)
            if new_field.primary_key:
                for many_to_many in opts.many_to_many:
                    # Ignore self-relationship since the table was already rebuilt.
                    if many_to_many.related_model == model:
                        continue
                    if many_to_many.remote_field.through._meta.auto_created:
                        related_models.add(many_to_many.remote_field.through)
            for related_model in related_models:
                self._remake_table(related_model)

    def _alter_many_to_many(self, model, old_field, new_field, strict):
        """Alter M2Ms to repoint their to= endpoints."""
        if (
            old_field.remote_field.through._meta.db_table
            == new_field.remote_field.through._meta.db_table
        ):
            # The field name didn't change, but some options did, so we have to
            # propagate this altering.
            self._remake_table(
                old_field.remote_field.through,
                alter_fields=[
                    (
                        # The field that points to the target model is needed,
                        # so that table can be remade with the new m2m field -
                        # this is m2m_reverse_field_name().
                        old_field.remote_field.through._meta.get_field(
                            old_field.m2m_reverse_field_name()
                        ),
                        new_field.remote_field.through._meta.get_field(
                            new_field.m2m_reverse_field_name()
                        ),
                    ),
                    (
                        # The field that points to the model itself is needed,
                        # so that table can be remade with the new self field -
                        # this is m2m_field_name().
                        old_field.remote_field.through._meta.get_field(
                            old_field.m2m_field_name()
                        ),
                        new_field.remote_field.through._meta.get_field(
                            new_field.m2m_field_name()
                        ),
                    ),
                ],
            )
            return

        # Make a new through table
        self.create_model(new_field.remote_field.through)
        # Copy the data across
        self.execute(
            "INSERT INTO %s (%s) SELECT %s FROM %s"
            % (
                self.quote_name(new_field.remote_field.through._meta.db_table),
                ", ".join(
                    [
                        "id",
                        new_field.m2m_column_name(),
                        new_field.m2m_reverse_name(),
                    ]
                ),
                ", ".join(
                    [
                        "id",
                        old_field.m2m_column_name(),
                        old_field.m2m_reverse_name(),
                    ]
                ),
                self.quote_name(old_field.remote_field.through._meta.db_table),
            )
        )
        # Delete the old through table
        self.delete_model(old_field.remote_field.through)

    def add_constraint(self, model, constraint):
        if isinstance(constraint, UniqueConstraint) and (
            constraint.condition
            or constraint.contains_expressions
            or constraint.include
            or constraint.deferrable
        ):
            super().add_constraint(model, constraint)
        else:
            self._remake_table(model)

    def remove_constraint(self, model, constraint):
        if isinstance(constraint, UniqueConstraint) and (
            constraint.condition
            or constraint.contains_expressions
            or constraint.include
            or constraint.deferrable
        ):
            super().remove_constraint(model, constraint)
        else:
            self._remake_table(model)

    def _collate_sql(self, collation):
        return "COLLATE " + collation
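# Editor's note: an illustrative sketch, not part of Django. For a
# hypothetical model gaining a NOT NULL column, the five-step procedure in
# _remake_table() above reduces to roughly this SQL (the column lists and
# the coalesce() default are derived from the model's fields):
#
#     CREATE TABLE "new__app_person" (...updated definition...);
#     INSERT INTO "new__app_person" ("id", "name", "age")
#         SELECT "id", "name", coalesce("age", 0) FROM "app_person";
#     DROP TABLE "app_person";
#     ALTER TABLE "new__app_person" RENAME TO "app_person";
#     -- deferred index-creation SQL is then replayed against the new table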
342
.venv/lib/python3.10/site-packages/django/db/backends/utils.py
Normal file
@@ -0,0 +1,342 @@
import datetime
import decimal
import functools
import logging
import time
import warnings
from contextlib import contextmanager
from hashlib import md5

from django.apps import apps
from django.db import NotSupportedError
from django.utils.dateparse import parse_time

logger = logging.getLogger("django.db.backends")


class CursorWrapper:
    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    WRAP_ERROR_ATTRS = frozenset(["fetchone", "fetchmany", "fetchall", "nextset"])

    APPS_NOT_READY_WARNING_MSG = (
        "Accessing the database during app initialization is discouraged. To fix this "
        "warning, avoid executing queries in AppConfig.ready() or when your app "
        "modules are imported."
    )

    def __getattr__(self, attr):
        cursor_attr = getattr(self.cursor, attr)
        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
            return self.db.wrap_database_errors(cursor_attr)
        else:
            return cursor_attr

    def __iter__(self):
        with self.db.wrap_database_errors:
            yield from self.cursor

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close instead of passing through to avoid backend-specific behavior
        # (#17671). Catch errors liberally because errors in cleanup code
        # aren't useful.
        try:
            self.close()
        except self.db.Database.Error:
            pass

    # The following methods cannot be implemented in __getattr__, because the
    # code must run when the method is invoked, not just when it is accessed.

    def callproc(self, procname, params=None, kparams=None):
        # Keyword parameters for callproc aren't supported in PEP 249, but the
        # database driver may support them (e.g. oracledb).
        if kparams is not None and not self.db.features.supports_callproc_kwargs:
            raise NotSupportedError(
                "Keyword parameters for callproc are not supported on this "
                "database backend."
            )
        # Raise a warning during app initialization (stored_app_configs is only
        # ever set during testing).
        if not apps.ready and not apps.stored_app_configs:
            warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is None and kparams is None:
                return self.cursor.callproc(procname)
            elif kparams is None:
                return self.cursor.callproc(procname, params)
            else:
                params = params or ()
                return self.cursor.callproc(procname, params, kparams)

    def execute(self, sql, params=None):
        return self._execute_with_wrappers(
            sql, params, many=False, executor=self._execute
        )

    def executemany(self, sql, param_list):
        return self._execute_with_wrappers(
            sql, param_list, many=True, executor=self._executemany
        )

    def _execute_with_wrappers(self, sql, params, many, executor):
        context = {"connection": self.db, "cursor": self}
        for wrapper in reversed(self.db.execute_wrappers):
            executor = functools.partial(wrapper, executor)
        return executor(sql, params, many, context)

    def _execute(self, sql, params, *ignored_wrapper_args):
        # Raise a warning during app initialization (stored_app_configs is only
        # ever set during testing).
        if not apps.ready and not apps.stored_app_configs:
            warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            if params is None:
                # params default might be backend specific.
                return self.cursor.execute(sql)
            else:
                return self.cursor.execute(sql, params)

    def _executemany(self, sql, param_list, *ignored_wrapper_args):
        # Raise a warning during app initialization (stored_app_configs is only
        # ever set during testing).
        if not apps.ready and not apps.stored_app_configs:
            warnings.warn(self.APPS_NOT_READY_WARNING_MSG, category=RuntimeWarning)
        self.db.validate_no_broken_transaction()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)


class CursorDebugWrapper(CursorWrapper):
    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        with self.debug_sql(sql, params, use_last_executed_query=True):
            return super().execute(sql, params)

    def executemany(self, sql, param_list):
        with self.debug_sql(sql, param_list, many=True):
            return super().executemany(sql, param_list)

    @contextmanager
    def debug_sql(
        self, sql=None, params=None, use_last_executed_query=False, many=False
    ):
        start = time.monotonic()
        try:
            yield
        finally:
            stop = time.monotonic()
            duration = stop - start
            if use_last_executed_query:
                sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            try:
                times = len(params) if many else ""
            except TypeError:
                # params could be an iterator.
                times = "?"
            self.db.queries_log.append(
                {
                    "sql": "%s times: %s" % (times, sql) if many else sql,
                    "time": "%.3f" % duration,
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                duration,
                sql,
                params,
                self.db.alias,
                extra={
                    "duration": duration,
                    "sql": sql,
                    "params": params,
                    "alias": self.db.alias,
                },
            )
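# Editor's note: an illustrative sketch, not part of this module. The
# wrapper chain built by _execute_with_wrappers() above is what powers
# Django's documented connection.execute_wrapper() API; each wrapper
# receives the next executor plus the (sql, params, many, context) call:
#
#     from django.db import connection
#
#     def log_sql(execute, sql, params, many, context):
#         print("running:", sql)
#         return execute(sql, params, many, context)
#
#     with connection.execute_wrapper(log_sql):
#         ...  # ORM queries in this block pass through log_sql first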
@contextmanager
def debug_transaction(connection, sql):
    start = time.monotonic()
    try:
        yield
    finally:
        if connection.queries_logged:
            stop = time.monotonic()
            duration = stop - start
            connection.queries_log.append(
                {
                    "sql": "%s" % sql,
                    "time": "%.3f" % duration,
                }
            )
            logger.debug(
                "(%.3f) %s; args=%s; alias=%s",
                duration,
                sql,
                None,
                connection.alias,
                extra={
                    "duration": duration,
                    "sql": sql,
                    "alias": connection.alias,
                },
            )


def split_tzname_delta(tzname):
    """
    Split a time zone name into a 3-tuple of (name, sign, offset).
    """
    for sign in ["+", "-"]:
        if sign in tzname:
            name, offset = tzname.rsplit(sign, 1)
            if offset and parse_time(offset):
                if ":" not in offset:
                    offset = f"{offset}:00"
                return name, sign, offset
    return tzname, None, None
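# Editor's note: illustrative inputs and outputs for split_tzname_delta()
# above (a sketch, not part of Django):
#
#     split_tzname_delta("Asia/Tokyo")      # -> ("Asia/Tokyo", None, None)
#     split_tzname_delta("UTC+05:30")       # -> ("UTC", "+", "05:30")
#     split_tzname_delta("Etc/GMT-03:00")   # -> ("Etc/GMT", "-", "03:00")
#
# An offset without minutes is normalized to "HH:00" when parse_time()
# accepts it.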
###############################################
# Converters from database (string) to Python #
###############################################


def typecast_date(s):
    return (
        datetime.date(*map(int, s.split("-"))) if s else None
    )  # return None if s is null


def typecast_time(s):  # does NOT store time zone information
    if not s:
        return None
    hour, minutes, seconds = s.split(":")
    if "." in seconds:  # check whether seconds have a fractional part
        seconds, microseconds = seconds.split(".")
    else:
        microseconds = "0"
    return datetime.time(
        int(hour), int(minutes), int(seconds), int((microseconds + "000000")[:6])
    )


def typecast_timestamp(s):  # does NOT store time zone information
    # "2005-07-29 15:48:00.590358-05"
    # "2005-07-29 09:56:00-05"
    if not s:
        return None
    if " " not in s:
        return typecast_date(s)
    d, t = s.split()
    # Remove timezone information.
    if "-" in t:
        t, _ = t.split("-", 1)
    elif "+" in t:
        t, _ = t.split("+", 1)
    dates = d.split("-")
    times = t.split(":")
    seconds = times[2]
    if "." in seconds:  # check whether seconds have a fractional part
        seconds, microseconds = seconds.split(".")
    else:
        microseconds = "0"
    return datetime.datetime(
        int(dates[0]),
        int(dates[1]),
        int(dates[2]),
        int(times[0]),
        int(times[1]),
        int(seconds),
        int((microseconds + "000000")[:6]),
    )


###############################################
# Converters from Python to database (string) #
###############################################


def split_identifier(identifier):
    """
    Split an SQL identifier into a two element tuple of (namespace, name).

    The identifier could be a table, column, or sequence name that might be
    prefixed by a namespace.
    """
    try:
        namespace, name = identifier.split('"."')
    except ValueError:
        namespace, name = "", identifier
    return namespace.strip('"'), name.strip('"')


def truncate_name(identifier, length=None, hash_len=4):
    """
    Shorten an SQL identifier to a repeatable mangled version with the given
    length.

    If a quote stripped name contains a namespace, e.g. USERNAME"."TABLE,
    truncate the table portion only.
    """
    namespace, name = split_identifier(identifier)

    if length is None or len(name) <= length:
        return identifier

    digest = names_digest(name, length=hash_len)
    return "%s%s%s" % (
        '%s"."' % namespace if namespace else "",
        name[: length - hash_len],
        digest,
    )


def names_digest(*args, length):
    """
    Generate a 32-bit digest of a set of arguments that can be used to shorten
    identifying names.
    """
    h = md5(usedforsecurity=False)
    for arg in args:
        h.update(arg.encode())
    return h.hexdigest()[:length]
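# Editor's note: illustrative behavior of truncate_name() above (a sketch,
# not part of Django). With length=10 and the default hash_len=4, a long
# name keeps its first 6 characters and gains a 4-character md5 digest, so
# the result is stable across runs:
#
#     truncate_name("some_very_long_table_name", length=10)
#     # -> "some_v" + names_digest("some_very_long_table_name", length=4)
#     truncate_name("short", length=10)
#     # -> "short" (within the limit, returned unchanged)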
def format_number(value, max_digits, decimal_places):
    """
    Format a number into a string with the requisite number of digits and
    decimal places.
    """
    if value is None:
        return None
    context = decimal.getcontext().copy()
    if max_digits is not None:
        context.prec = max_digits
    if decimal_places is not None:
        value = value.quantize(
            decimal.Decimal(1).scaleb(-decimal_places), context=context
        )
    else:
        context.traps[decimal.Rounded] = 1
        value = context.create_decimal(value)
    return "{:f}".format(value)


def strip_quotes(table_name):
    """
    Strip quotes off of quoted table names to make them safe for use in index
    names, sequence names, etc. For example '"USER"."TABLE"' (an Oracle naming
    scheme) becomes 'USER"."TABLE'.
    """
    has_quotes = table_name.startswith('"') and table_name.endswith('"')
    return table_name[1:-1] if has_quotes else table_name
@@ -0,0 +1,2 @@
from .migration import Migration, swappable_dependency  # NOQA
from .operations import *  # NOQA

File diff suppressed because it is too large
@@ -0,0 +1,60 @@
from django.db import DatabaseError


class AmbiguityError(Exception):
    """More than one migration matches a name prefix."""

    pass


class BadMigrationError(Exception):
    """There's a bad migration (unreadable/bad format/etc.)."""

    pass


class CircularDependencyError(Exception):
    """There's an impossible-to-resolve circular dependency."""

    pass


class InconsistentMigrationHistory(Exception):
    """An applied migration has some of its dependencies not applied."""

    pass


class InvalidBasesError(ValueError):
    """A model's base classes can't be resolved."""

    pass


class IrreversibleError(RuntimeError):
    """An irreversible migration is about to be reversed."""

    pass


class NodeNotFoundError(LookupError):
    """An operation was attempted on a node that isn't available in the graph."""

    def __init__(self, message, node, origin=None):
        self.message = message
        self.origin = origin
        self.node = node

    def __str__(self):
        return self.message

    def __repr__(self):
        return "NodeNotFoundError(%r)" % (self.node,)


class MigrationSchemaMissing(DatabaseError):
    pass


class InvalidMigrationPlan(ValueError):
    pass
@@ -0,0 +1,413 @@
from django.apps.registry import apps as global_apps
from django.db import migrations, router

from .exceptions import InvalidMigrationPlan
from .loader import MigrationLoader
from .recorder import MigrationRecorder
from .state import ProjectState


class MigrationExecutor:
    """
    End-to-end migration execution - load migrations and run them up or down
    to a specified set of targets.
    """

    def __init__(self, connection, progress_callback=None):
        self.connection = connection
        self.loader = MigrationLoader(self.connection)
        self.recorder = MigrationRecorder(self.connection)
        self.progress_callback = progress_callback

    def migration_plan(self, targets, clean_start=False):
        """
        Given a set of targets, return a list of (Migration instance, backwards?).
        """
        plan = []
        if clean_start:
            applied = {}
        else:
            applied = dict(self.loader.applied_migrations)
        for target in targets:
            # If the target is (app_label, None), that means unmigrate everything
            if target[1] is None:
                for root in self.loader.graph.root_nodes():
                    if root[0] == target[0]:
                        for migration in self.loader.graph.backwards_plan(root):
                            if migration in applied:
                                plan.append((self.loader.graph.nodes[migration], True))
                                applied.pop(migration)
            # If the migration is already applied, do backwards mode,
            # otherwise do forwards mode.
            elif target in applied:
                # If the target is missing, it's likely a replaced migration.
                # Reload the graph without replacements.
                if (
                    self.loader.replace_migrations
                    and target not in self.loader.graph.node_map
                ):
                    self.loader.replace_migrations = False
                    self.loader.build_graph()
                    return self.migration_plan(targets, clean_start=clean_start)
                # Don't migrate backwards all the way to the target node (that
                # may roll back dependencies in other apps that don't need to
                # be rolled back); instead roll back through target's immediate
                # child(ren) in the same app, and no further.
                next_in_app = sorted(
                    n
                    for n in self.loader.graph.node_map[target].children
                    if n[0] == target[0]
                )
                for node in next_in_app:
                    for migration in self.loader.graph.backwards_plan(node):
                        if migration in applied:
                            plan.append((self.loader.graph.nodes[migration], True))
                            applied.pop(migration)
            else:
                for migration in self.loader.graph.forwards_plan(target):
                    if migration not in applied:
                        plan.append((self.loader.graph.nodes[migration], False))
                        applied[migration] = self.loader.graph.nodes[migration]
        return plan
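    # Editor's note: an illustrative sketch, not part of Django, of driving
    # the executor by hand the way the migrate command does:
    #
    #     from django.db import connections
    #
    #     executor = MigrationExecutor(connections["default"])
    #     # Plan forwards to the latest leaf of every app:
    #     plan = executor.migration_plan(executor.loader.graph.leaf_nodes())
    #     for migration, backwards in plan:
    #         print(migration.app_label, migration.name, backwards)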
    def _create_project_state(self, with_applied_migrations=False):
        """
        Create a project state including all the applications without
        migrations and applied migrations if with_applied_migrations=True.
        """
        state = ProjectState(real_apps=self.loader.unmigrated_apps)
        if with_applied_migrations:
            # Create the forwards plan Django would follow on an empty database
            full_plan = self.migration_plan(
                self.loader.graph.leaf_nodes(), clean_start=True
            )
            applied_migrations = {
                self.loader.graph.nodes[key]
                for key in self.loader.applied_migrations
                if key in self.loader.graph.nodes
            }
            for migration, _ in full_plan:
                if migration in applied_migrations:
                    migration.mutate_state(state, preserve=False)
        return state

    def migrate(self, targets, plan=None, state=None, fake=False, fake_initial=False):
        """
        Migrate the database up to the given targets.

        Django first needs to create all project states before a migration is
        (un)applied, and in a second step it runs all the database operations.
        """
        # The django_migrations table must be present to record applied
        # migrations, but don't create it if there are no migrations to apply.
        if plan == []:
            if not self.recorder.has_table():
                return self._create_project_state(with_applied_migrations=False)
        else:
            self.recorder.ensure_schema()

        if plan is None:
            plan = self.migration_plan(targets)
        # Create the forwards plan Django would follow on an empty database
        full_plan = self.migration_plan(
            self.loader.graph.leaf_nodes(), clean_start=True
        )

        all_forwards = all(not backwards for mig, backwards in plan)
        all_backwards = all(backwards for mig, backwards in plan)

        if not plan:
            if state is None:
                # The resulting state should include applied migrations.
                state = self._create_project_state(with_applied_migrations=True)
        elif all_forwards == all_backwards:
            # This should only happen if there's a mixed plan
            raise InvalidMigrationPlan(
                "Migration plans with both forwards and backwards migrations "
                "are not supported. Please split your migration process into "
                "separate plans of only forwards OR backwards migrations.",
                plan,
            )
        elif all_forwards:
            if state is None:
                # The resulting state should still include applied migrations.
                state = self._create_project_state(with_applied_migrations=True)
            state = self._migrate_all_forwards(
                state, plan, full_plan, fake=fake, fake_initial=fake_initial
            )
        else:
            # No need to check for `elif all_backwards` here, as that condition
            # would always evaluate to true.
            state = self._migrate_all_backwards(plan, full_plan, fake=fake)

        self.check_replacements()

        return state

    def _migrate_all_forwards(self, state, plan, full_plan, fake, fake_initial):
        """
        Take a list of 2-tuples of the form (migration instance, False) and
        apply them in the order they occur in the full_plan.
        """
        migrations_to_run = {m[0] for m in plan}
        for migration, _ in full_plan:
            if not migrations_to_run:
                # We remove every migration that we applied from these sets so
                # that we can bail out once the last migration has been applied
                # and don't always run until the very end of the migration
                # process.
                break
            if migration in migrations_to_run:
                if "apps" not in state.__dict__:
                    if self.progress_callback:
                        self.progress_callback("render_start")
                    state.apps  # Render all -- performance critical
                    if self.progress_callback:
                        self.progress_callback("render_success")
                state = self.apply_migration(
                    state, migration, fake=fake, fake_initial=fake_initial
                )
                migrations_to_run.remove(migration)

        return state

    def _migrate_all_backwards(self, plan, full_plan, fake):
        """
        Take a list of 2-tuples of the form (migration instance, True) and
        unapply them in the reverse of the order they occur in the full_plan.

        Since unapplying a migration requires the project state prior to that
        migration, Django will compute the migration states before each of them
        in a first run over the plan and then unapply them in a second run over
        the plan.
        """
        migrations_to_run = {m[0] for m in plan}
        # Holds all migration states prior to the migrations being unapplied
        states = {}
        state = self._create_project_state()
        applied_migrations = {
            self.loader.graph.nodes[key]
            for key in self.loader.applied_migrations
            if key in self.loader.graph.nodes
        }
        if self.progress_callback:
            self.progress_callback("render_start")
        for migration, _ in full_plan:
            if not migrations_to_run:
                # We remove every migration that we applied from this set so
                # that we can bail out once the last migration has been applied
                # and don't always run until the very end of the migration
                # process.
                break
            if migration in migrations_to_run:
                if "apps" not in state.__dict__:
                    state.apps  # Render all -- performance critical
                # The state before this migration
                states[migration] = state
                # The old state keeps as-is, we continue with the new state
                state = migration.mutate_state(state, preserve=True)
                migrations_to_run.remove(migration)
            elif migration in applied_migrations:
                # Only mutate the state if the migration is actually applied
                # to make sure the resulting state doesn't include changes
                # from unrelated migrations.
                migration.mutate_state(state, preserve=False)
        if self.progress_callback:
            self.progress_callback("render_success")

        for migration, _ in plan:
            self.unapply_migration(states[migration], migration, fake=fake)
            applied_migrations.remove(migration)

        # Generate the post migration state by starting from the state before
        # the last migration is unapplied and mutating it to include all the
        # remaining applied migrations.
        last_unapplied_migration = plan[-1][0]
        state = states[last_unapplied_migration]
        # Avoid mutating state with apps rendered as it's an expensive
        # operation.
        del state.apps
        for index, (migration, _) in enumerate(full_plan):
            if migration == last_unapplied_migration:
                for migration, _ in full_plan[index:]:
                    if migration in applied_migrations:
                        migration.mutate_state(state, preserve=False)
                break

        return state

    def apply_migration(self, state, migration, fake=False, fake_initial=False):
        """Run a migration forwards."""
        migration_recorded = False
        if self.progress_callback:
            self.progress_callback("apply_start", migration, fake)
        if not fake:
            if fake_initial:
                # Test to see if this is an already-applied initial migration
                applied, state = self.detect_soft_applied(state, migration)
                if applied:
                    fake = True
            if not fake:
                # Alright, do it normally
                with self.connection.schema_editor(
                    atomic=migration.atomic
                ) as schema_editor:
                    state = migration.apply(state, schema_editor)
                    if not schema_editor.deferred_sql:
                        self.record_migration(migration)
                        migration_recorded = True
        if not migration_recorded:
            self.record_migration(migration)
        # Report progress
        if self.progress_callback:
            self.progress_callback("apply_success", migration, fake)
        return state

    def record_migration(self, migration):
        # For replacement migrations, record individual statuses
        if migration.replaces:
            for app_label, name in migration.replaces:
                self.recorder.record_applied(app_label, name)
        else:
            self.recorder.record_applied(migration.app_label, migration.name)

    def unapply_migration(self, state, migration, fake=False):
        """Run a migration backwards."""
        if self.progress_callback:
            self.progress_callback("unapply_start", migration, fake)
        if not fake:
            with self.connection.schema_editor(
                atomic=migration.atomic
            ) as schema_editor:
                state = migration.unapply(state, schema_editor)
        # For replacement migrations, also record individual statuses.
        if migration.replaces:
            for app_label, name in migration.replaces:
                self.recorder.record_unapplied(app_label, name)
        self.recorder.record_unapplied(migration.app_label, migration.name)
        # Report progress
        if self.progress_callback:
            self.progress_callback("unapply_success", migration, fake)
        return state

    def check_replacements(self):
        """
        Mark replacement migrations applied if their replaced set all are.

        Do this unconditionally on every migrate, rather than just when
        migrations are applied or unapplied, to correctly handle the case
        when a new squash migration is pushed to a deployment that already had
        all its replaced migrations applied. In this case no new migration will
        be applied, but the applied state of the squashed migration must be
        maintained.
        """
        applied = self.recorder.applied_migrations()
        for key, migration in self.loader.replacements.items():
            all_applied = all(m in applied for m in migration.replaces)
            if all_applied and key not in applied:
                self.recorder.record_applied(*key)

    def detect_soft_applied(self, project_state, migration):
        """
        Test whether a migration has been implicitly applied - that the
        tables or columns it would create exist. This is intended only for use
        on initial migrations (as it only looks for CreateModel and AddField).
        """

        def should_skip_detecting_model(migration, model):
            """
            No need to detect tables for proxy models, unmanaged models, or
            models that can't be migrated on the current database.
            """
            return (
                model._meta.proxy
                or not model._meta.managed
                or not router.allow_migrate(
                    self.connection.alias,
                    migration.app_label,
                    model_name=model._meta.model_name,
                )
            )

        if migration.initial is None:
            # Bail if the migration isn't the first one in its app
            if any(app == migration.app_label for app, name in migration.dependencies):
                return False, project_state
        elif migration.initial is False:
            # Bail if it's NOT an initial migration
            return False, project_state

        if project_state is None:
            after_state = self.loader.project_state(
                (migration.app_label, migration.name), at_end=True
            )
        else:
            after_state = migration.mutate_state(project_state)
        apps = after_state.apps
        found_create_model_migration = False
        found_add_field_migration = False
        fold_identifier_case = self.connection.features.ignores_table_name_case
        with self.connection.cursor() as cursor:
            existing_table_names = set(
                self.connection.introspection.table_names(cursor)
            )
            if fold_identifier_case:
                existing_table_names = {
                    name.casefold() for name in existing_table_names
                }
        # Make sure all create model and add field operations are done
        for operation in migration.operations:
            if isinstance(operation, migrations.CreateModel):
                model = apps.get_model(migration.app_label, operation.name)
                if model._meta.swapped:
                    # We have to fetch the model to test with from the
                    # main app cache, as it's not a direct dependency.
                    model = global_apps.get_model(model._meta.swapped)
                if should_skip_detecting_model(migration, model):
                    continue
                db_table = model._meta.db_table
                if fold_identifier_case:
                    db_table = db_table.casefold()
                if db_table not in existing_table_names:
                    return False, project_state
                found_create_model_migration = True
            elif isinstance(operation, migrations.AddField):
                model = apps.get_model(migration.app_label, operation.model_name)
                if model._meta.swapped:
                    # We have to fetch the model to test with from the
                    # main app cache, as it's not a direct dependency.
                    model = global_apps.get_model(model._meta.swapped)
                if should_skip_detecting_model(migration, model):
                    continue

                table = model._meta.db_table
                field = model._meta.get_field(operation.name)

                # Handle implicit many-to-many tables created by AddField.
                if field.many_to_many:
                    through_db_table = field.remote_field.through._meta.db_table
                    if fold_identifier_case:
                        through_db_table = through_db_table.casefold()
                    if through_db_table not in existing_table_names:
                        return False, project_state
                    else:
                        found_add_field_migration = True
                    continue
                with self.connection.cursor() as cursor:
                    columns = self.connection.introspection.get_table_description(
                        cursor, table
                    )
                for column in columns:
                    field_column = field.column
                    column_name = column.name
                    if fold_identifier_case:
                        column_name = column_name.casefold()
                        field_column = field_column.casefold()
                    if column_name == field_column:
                        found_add_field_migration = True
                        break
                else:
                    return False, project_state
        # If we get this far and we found at least one CreateModel or AddField
        # migration, the migration is considered implicitly applied.
        return (found_create_model_migration or found_add_field_migration), after_state
333
.venv/lib/python3.10/site-packages/django/db/migrations/graph.py
Normal file
@@ -0,0 +1,333 @@
from functools import total_ordering
|
||||
|
||||
from django.db.migrations.state import ProjectState
|
||||
|
||||
from .exceptions import CircularDependencyError, NodeNotFoundError
|
||||
|
||||
|
||||
@total_ordering
|
||||
class Node:
|
||||
"""
|
||||
A single node in the migration graph. Contains direct links to adjacent
|
||||
nodes in either direction.
|
||||
"""
|
||||
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
self.children = set()
|
||||
self.parents = set()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.key < other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.key)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.key[item]
|
||||
|
||||
def __str__(self):
|
||||
return str(self.key)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: (%r, %r)>" % (self.__class__.__name__, self.key[0], self.key[1])
|
||||
|
||||
def add_child(self, child):
|
||||
self.children.add(child)
|
||||
|
||||
def add_parent(self, parent):
|
||||
self.parents.add(parent)
|
||||
|
||||
|
||||
class DummyNode(Node):
    """
    A node that doesn't correspond to a migration file on disk.
    (A squashed migration that was removed, for example.)

    After the migration graph is processed, all dummy nodes should be removed.
    If there are any left, a nonexistent dependency error is raised.
    """

    def __init__(self, key, origin, error_message):
        super().__init__(key)
        self.origin = origin
        self.error_message = error_message

    def raise_error(self):
        raise NodeNotFoundError(self.error_message, self.key, origin=self.origin)


class MigrationGraph:
    """
    Represent the digraph of all migrations in a project.

    Each migration is a node, and each dependency is an edge. There are
    no implicit dependencies between numbered migrations - the numbering is
    merely a convention to aid file listing. Every new numbered migration
    has a declared dependency to the previous number, meaning that VCS
    branch merges can be detected and resolved.

    Migration files can be marked as replacing another set of migrations -
    this is to support the "squash" feature. The graph handler isn't responsible
    for these; instead, the code to load them in here should examine the
    migration files and, if the replaced migrations are all either unapplied
    or not present, it should ignore the replaced ones, load in just the
    replacing migration, and repoint any dependencies that pointed to the
    replaced migrations to point to the replacing one.

    A node should be a tuple: (app_path, migration_name). The tree special-cases
    things within an app - namely, root nodes and leaf nodes ignore dependencies
    to other apps.
    """

    def __init__(self):
        self.node_map = {}
        self.nodes = {}

    def add_node(self, key, migration):
        assert key not in self.node_map
        node = Node(key)
        self.node_map[key] = node
        self.nodes[key] = migration

    def add_dummy_node(self, key, origin, error_message):
        node = DummyNode(key, origin, error_message)
        self.node_map[key] = node
        self.nodes[key] = None

    def add_dependency(self, migration, child, parent, skip_validation=False):
        """
        This may create dummy nodes if they don't yet exist. If
        `skip_validation=True`, validate_consistency() should be called
        afterward.
        """
        if child not in self.nodes:
            error_message = (
                "Migration %s dependencies reference nonexistent"
                " child node %r" % (migration, child)
            )
            self.add_dummy_node(child, migration, error_message)
        if parent not in self.nodes:
            error_message = (
                "Migration %s dependencies reference nonexistent"
                " parent node %r" % (migration, parent)
            )
            self.add_dummy_node(parent, migration, error_message)
        self.node_map[child].add_parent(self.node_map[parent])
        self.node_map[parent].add_child(self.node_map[child])
        if not skip_validation:
            self.validate_consistency()
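    # Minimal usage sketch (illustrative only, not part of Django's source;
    # migration_0001 and migration_0002 are hypothetical Migration instances):
    #
    #     graph = MigrationGraph()
    #     graph.add_node(("app", "0001_initial"), migration_0001)
    #     graph.add_node(("app", "0002_change"), migration_0002)
    #     graph.add_dependency(migration_0002, ("app", "0002_change"),
    #                          ("app", "0001_initial"))
    #     graph.forwards_plan(("app", "0002_change"))
    #     # -> [("app", "0001_initial"), ("app", "0002_change")]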
    def remove_replaced_nodes(self, replacement, replaced):
        """
        Remove each of the `replaced` nodes (when they exist). Any
        dependencies that were referencing them are changed to reference the
        `replacement` node instead.
        """
        # Cast list of replaced keys to set to speed up lookup later.
        replaced = set(replaced)
        try:
            replacement_node = self.node_map[replacement]
        except KeyError as err:
            raise NodeNotFoundError(
                "Unable to find replacement node %r. It was either never added"
                " to the migration graph, or has been removed." % (replacement,),
                replacement,
            ) from err
        for replaced_key in replaced:
            self.nodes.pop(replaced_key, None)
            replaced_node = self.node_map.pop(replaced_key, None)
            if replaced_node:
                for child in replaced_node.children:
                    child.parents.remove(replaced_node)
                    # We don't want to create dependencies between the replaced
                    # node and the replacement node as this would lead to
                    # self-referencing on the replacement node at a later iteration.
                    if child.key not in replaced:
                        replacement_node.add_child(child)
                        child.add_parent(replacement_node)
                for parent in replaced_node.parents:
                    parent.children.remove(replaced_node)
                    # Again, to avoid self-referencing.
                    if parent.key not in replaced:
                        replacement_node.add_parent(parent)
                        parent.add_child(replacement_node)

    def remove_replacement_node(self, replacement, replaced):
        """
        The inverse operation to `remove_replaced_nodes`. Almost. Remove the
        replacement node `replacement` and remap its child nodes to `replaced`
        - the list of nodes it would have replaced. Don't remap its parent
        nodes as they are expected to be correct already.
        """
        self.nodes.pop(replacement, None)
        try:
            replacement_node = self.node_map.pop(replacement)
        except KeyError as err:
            raise NodeNotFoundError(
                "Unable to remove replacement node %r. It was either never added"
                " to the migration graph, or has been removed already."
                % (replacement,),
                replacement,
            ) from err
        replaced_nodes = set()
        replaced_nodes_parents = set()
        for key in replaced:
            replaced_node = self.node_map.get(key)
            if replaced_node:
                replaced_nodes.add(replaced_node)
                replaced_nodes_parents |= replaced_node.parents
        # We're only interested in the latest replaced node, so filter out
        # replaced nodes that are parents of other replaced nodes.
        replaced_nodes -= replaced_nodes_parents
        for child in replacement_node.children:
            child.parents.remove(replacement_node)
            for replaced_node in replaced_nodes:
                replaced_node.add_child(child)
                child.add_parent(replaced_node)
        for parent in replacement_node.parents:
            parent.children.remove(replacement_node)
            # NOTE: There is no need to remap parent dependencies as we can
            # assume the replaced nodes already have the correct ancestry.

    def validate_consistency(self):
        """Ensure there are no dummy nodes remaining in the graph."""
        [n.raise_error() for n in self.node_map.values() if isinstance(n, DummyNode)]
    def forwards_plan(self, target):
        """
        Given a node, return a list of which previous nodes (dependencies) must
        be applied, ending with the node itself. This is the list you would
        follow if applying the migrations to a database.
        """
        if target not in self.nodes:
            raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
        return self.iterative_dfs(self.node_map[target])

    def backwards_plan(self, target):
        """
        Given a node, return a list of which dependent nodes (dependencies)
        must be unapplied, ending with the node itself. This is the list you
        would follow if removing the migrations from a database.
        """
        if target not in self.nodes:
            raise NodeNotFoundError("Node %r not a valid node" % (target,), target)
        return self.iterative_dfs(self.node_map[target], forwards=False)

    def iterative_dfs(self, start, forwards=True):
        """Iterative depth-first search for finding dependencies."""
        visited = []
        visited_set = set()
        stack = [(start, False)]
        while stack:
            node, processed = stack.pop()
            if node in visited_set:
                pass
            elif processed:
                visited_set.add(node)
                visited.append(node.key)
            else:
                stack.append((node, True))
                stack += [
                    (n, False)
                    for n in sorted(node.parents if forwards else node.children)
                ]
        return visited
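    # Illustrative note (not part of Django's source): the (node, processed)
    # flag implements an iterative post-order traversal. A node is pushed once
    # unprocessed, its dependencies are pushed on top of it, and its key is
    # emitted only when it is popped a second time - so every dependency
    # appears in `visited` before its dependents.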
    def root_nodes(self, app=None):
        """
        Return all root nodes - that is, nodes with no dependencies inside
        their app. These are the starting point for an app.
        """
        roots = set()
        for node in self.nodes:
            if all(key[0] != node[0] for key in self.node_map[node].parents) and (
                not app or app == node[0]
            ):
                roots.add(node)
        return sorted(roots)

    def leaf_nodes(self, app=None):
        """
        Return all leaf nodes - that is, nodes with no dependents in their app.
        These are the "most current" version of an app's schema.
        Having more than one per app is technically an error, but one that
        gets handled further up, in the interactive command - it's usually the
        result of a VCS merge and needs some user input.
        """
        leaves = set()
        for node in self.nodes:
            if all(key[0] != node[0] for key in self.node_map[node].children) and (
                not app or app == node[0]
            ):
                leaves.add(node)
        return sorted(leaves)

    def ensure_not_cyclic(self):
        # Algo from GvR:
        # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
        todo = set(self.nodes)
        while todo:
            node = todo.pop()
            stack = [node]
            while stack:
                top = stack[-1]
                for child in self.node_map[top].children:
                    # Use child.key instead of child to speed up the frequent
                    # hashing.
                    node = child.key
                    if node in stack:
                        cycle = stack[stack.index(node) :]
                        raise CircularDependencyError(
                            ", ".join("%s.%s" % n for n in cycle)
                        )
                    if node in todo:
                        stack.append(node)
                        todo.remove(node)
                        break
                else:
                    node = stack.pop()

    def __str__(self):
        return "Graph: %s nodes, %s edges" % self._nodes_and_edges()

    def __repr__(self):
        nodes, edges = self._nodes_and_edges()
        return "<%s: nodes=%s, edges=%s>" % (self.__class__.__name__, nodes, edges)

    def _nodes_and_edges(self):
        return len(self.nodes), sum(
            len(node.parents) for node in self.node_map.values()
        )

    def _generate_plan(self, nodes, at_end):
        plan = []
        for node in nodes:
            for migration in self.forwards_plan(node):
                if migration not in plan and (at_end or migration not in nodes):
                    plan.append(migration)
        return plan

    def make_state(self, nodes=None, at_end=True, real_apps=None):
        """
        Given a migration node or nodes, return a complete ProjectState for it.
        If at_end is False, return the state before the migration has run.
        If nodes is not provided, return the overall most current project state.
        """
        if nodes is None:
            nodes = list(self.leaf_nodes())
        if not nodes:
            return ProjectState()
        if not isinstance(nodes[0], tuple):
            nodes = [nodes]
        plan = self._generate_plan(nodes, at_end)
        project_state = ProjectState(real_apps=real_apps)
        for node in plan:
            project_state = self.nodes[node].mutate_state(project_state, preserve=False)
        return project_state

    def __contains__(self, node):
        return node in self.nodes
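    # Minimal usage sketch (illustrative only, not part of Django's source):
    #
    #     graph.make_state()                  # state at the latest leaf nodes
    #     ("app", "0001_initial") in graph    # __contains__ checks by key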
.venv/lib/python3.10/site-packages/django/db/migrations/loader.py
Normal file
@@ -0,0 +1,385 @@
import pkgutil
import sys
from importlib import import_module, reload

from django.apps import apps
from django.conf import settings
from django.db.migrations.graph import MigrationGraph
from django.db.migrations.recorder import MigrationRecorder

from .exceptions import (
    AmbiguityError,
    BadMigrationError,
    InconsistentMigrationHistory,
    NodeNotFoundError,
)

MIGRATIONS_MODULE_NAME = "migrations"


class MigrationLoader:
    """
    Load migration files from disk and their status from the database.

    Migration files are expected to live in the "migrations" directory of
    an app. Their names are entirely unimportant from a code perspective,
    but will probably follow the 1234_name.py convention.

    On initialization, this class will scan those directories, and open and
    read the Python files, looking for a class called Migration, which should
    inherit from django.db.migrations.Migration. See
    django.db.migrations.migration for what that looks like.

    Some migrations will be marked as "replacing" another set of migrations.
    These are loaded into a separate set of migrations away from the main ones.
    If all the migrations they replace are either unapplied or missing from
    disk, then they are injected into the main set, replacing the named migrations.
    Any dependency pointers to the replaced migrations are re-pointed to the
    new migration.

    This does mean that this class MUST also talk to the database as well as
    to disk, but this is probably fine. We're already not just operating
    in memory.
    """

    def __init__(
        self,
        connection,
        load=True,
        ignore_no_migrations=False,
        replace_migrations=True,
    ):
        self.connection = connection
        self.disk_migrations = None
        self.applied_migrations = None
        self.ignore_no_migrations = ignore_no_migrations
        self.replace_migrations = replace_migrations
        if load:
            self.build_graph()

    @classmethod
    def migrations_module(cls, app_label):
        """
        Return the path to the migrations module for the specified app_label
        and a boolean indicating if the module is specified in
        settings.MIGRATION_MODULES.
        """
        if app_label in settings.MIGRATION_MODULES:
            return settings.MIGRATION_MODULES[app_label], True
        else:
            app_package_name = apps.get_app_config(app_label).name
            return "%s.%s" % (app_package_name, MIGRATIONS_MODULE_NAME), False
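    # Illustrative note (not part of Django's source): MIGRATION_MODULES lets
    # a project relocate or disable an app's migrations, e.g. in settings.py:
    #
    #     MIGRATION_MODULES = {"myapp": "myproject.db_migrations.myapp"}
    #     MIGRATION_MODULES = {"myapp": None}  # treat the app as unmigrated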
    def load_disk(self):
        """Load the migrations from all INSTALLED_APPS from disk."""
        self.disk_migrations = {}
        self.unmigrated_apps = set()
        self.migrated_apps = set()
        for app_config in apps.get_app_configs():
            # Get the migrations module directory
            module_name, explicit = self.migrations_module(app_config.label)
            if module_name is None:
                self.unmigrated_apps.add(app_config.label)
                continue
            was_loaded = module_name in sys.modules
            try:
                module = import_module(module_name)
            except ModuleNotFoundError as e:
                if (explicit and self.ignore_no_migrations) or (
                    not explicit and MIGRATIONS_MODULE_NAME in e.name.split(".")
                ):
                    self.unmigrated_apps.add(app_config.label)
                    continue
                raise
            else:
                # Module is not a package (e.g. migrations.py).
                if not hasattr(module, "__path__"):
                    self.unmigrated_apps.add(app_config.label)
                    continue
                # Empty directories are namespaces. Namespace packages have no
                # __file__ and don't use a list for __path__. See
                # https://docs.python.org/3/reference/import.html#namespace-packages
                if getattr(module, "__file__", None) is None and not isinstance(
                    module.__path__, list
                ):
                    self.unmigrated_apps.add(app_config.label)
                    continue
                # Force a reload if it's already loaded (tests need this)
                if was_loaded:
                    reload(module)
            self.migrated_apps.add(app_config.label)
            migration_names = {
                name
                for _, name, is_pkg in pkgutil.iter_modules(module.__path__)
                if not is_pkg and name[0] not in "_~"
            }
            # Load migrations
            for migration_name in migration_names:
                migration_path = "%s.%s" % (module_name, migration_name)
                try:
                    migration_module = import_module(migration_path)
                except ImportError as e:
                    if "bad magic number" in str(e):
                        raise ImportError(
                            "Couldn't import %r as it appears to be a stale "
                            ".pyc file." % migration_path
                        ) from e
                    else:
                        raise
                if not hasattr(migration_module, "Migration"):
                    raise BadMigrationError(
                        "Migration %s in app %s has no Migration class"
                        % (migration_name, app_config.label)
                    )
                self.disk_migrations[app_config.label, migration_name] = (
                    migration_module.Migration(
                        migration_name,
                        app_config.label,
                    )
                )
    def get_migration(self, app_label, name_prefix):
        """Return the named migration or raise NodeNotFoundError."""
        return self.graph.nodes[app_label, name_prefix]

    def get_migration_by_prefix(self, app_label, name_prefix):
        """
        Return the migration(s) which match the given app label and name_prefix.
        """
        # Do the search
        results = []
        for migration_app_label, migration_name in self.disk_migrations:
            if migration_app_label == app_label and migration_name.startswith(
                name_prefix
            ):
                results.append((migration_app_label, migration_name))
        if len(results) > 1:
            raise AmbiguityError(
                "There is more than one migration for '%s' with the prefix '%s'"
                % (app_label, name_prefix)
            )
        elif not results:
            raise KeyError(
                f"There is no migration for '{app_label}' with the prefix "
                f"'{name_prefix}'"
            )
        else:
            return self.disk_migrations[results[0]]

    def check_key(self, key, current_app):
        if (key[1] != "__first__" and key[1] != "__latest__") or key in self.graph:
            return key
        # Special-case __first__, which means "the first migration" for
        # migrated apps, and is ignored for unmigrated apps. It allows
        # makemigrations to declare dependencies on apps before they even have
        # migrations.
        if key[0] == current_app:
            # Ignore __first__ references to the same app (#22325)
            return
        if key[0] in self.unmigrated_apps:
            # This app isn't migrated, but something depends on it.
            # The models will get auto-added into the state, though,
            # so we're fine.
            return
        if key[0] in self.migrated_apps:
            try:
                if key[1] == "__first__":
                    return self.graph.root_nodes(key[0])[0]
                else:  # "__latest__"
                    return self.graph.leaf_nodes(key[0])[0]
            except IndexError:
                if self.ignore_no_migrations:
                    return None
                else:
                    raise ValueError(
                        "Dependency on app with no migrations: %s" % key[0]
                    )
        raise ValueError("Dependency on unknown app: %s" % key[0])
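    # Illustrative note (not part of Django's source): check_key() resolves
    # the special dependency names - ("auth", "__first__") becomes the root
    # migration of the auth app, and ("auth", "__latest__") its current leaf.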
    def add_internal_dependencies(self, key, migration):
        """
        Internal dependencies need to be added first to ensure `__first__`
        dependencies find the correct root node.
        """
        for parent in migration.dependencies:
            # Ignore __first__ references to the same app.
            if parent[0] == key[0] and parent[1] != "__first__":
                self.graph.add_dependency(migration, key, parent, skip_validation=True)

    def add_external_dependencies(self, key, migration):
        for parent in migration.dependencies:
            # Skip internal dependencies
            if key[0] == parent[0]:
                continue
            parent = self.check_key(parent, key[0])
            if parent is not None:
                self.graph.add_dependency(migration, key, parent, skip_validation=True)
        for child in migration.run_before:
            child = self.check_key(child, key[0])
            if child is not None:
                self.graph.add_dependency(migration, child, key, skip_validation=True)
    def build_graph(self):
        """
        Build a migration dependency graph using both the disk and database.
        You'll need to rebuild the graph if you apply migrations. This isn't
        usually a problem as generally migration stuff runs in a one-shot process.
        """
        # Load disk data
        self.load_disk()
        # Load database data
        if self.connection is None:
            self.applied_migrations = {}
        else:
            recorder = MigrationRecorder(self.connection)
            self.applied_migrations = recorder.applied_migrations()
        # To start, populate the migration graph with nodes for ALL migrations
        # and their dependencies. Also make note of replacing migrations at this step.
        self.graph = MigrationGraph()
        self.replacements = {}
        for key, migration in self.disk_migrations.items():
            self.graph.add_node(key, migration)
            # Replacing migrations.
            if migration.replaces:
                self.replacements[key] = migration
        for key, migration in self.disk_migrations.items():
            # Internal (same app) dependencies.
            self.add_internal_dependencies(key, migration)
        # Add external dependencies now that the internal ones have been resolved.
        for key, migration in self.disk_migrations.items():
            self.add_external_dependencies(key, migration)
        # Carry out replacements where possible and if enabled.
        if self.replace_migrations:
            for key, migration in self.replacements.items():
                # Get applied status of each of this migration's replacement
                # targets.
                applied_statuses = [
                    (target in self.applied_migrations) for target in migration.replaces
                ]
                # The replacing migration is only marked as applied if all of
                # its replacement targets are.
                if all(applied_statuses):
                    self.applied_migrations[key] = migration
                else:
                    self.applied_migrations.pop(key, None)
                # A replacing migration can be used if either all or none of
                # its replacement targets have been applied.
                if all(applied_statuses) or (not any(applied_statuses)):
                    self.graph.remove_replaced_nodes(key, migration.replaces)
                else:
                    # This replacing migration cannot be used because it is
                    # partially applied. Remove it from the graph and remap
                    # dependencies to it (#25945).
                    self.graph.remove_replacement_node(key, migration.replaces)
        # Ensure the graph is consistent.
        try:
            self.graph.validate_consistency()
        except NodeNotFoundError as exc:
            # Check if the missing node could have been replaced by any squash
            # migration but wasn't because the squash migration was partially
            # applied before. In that case raise a more understandable exception
            # (#23556).
            # Get reverse replacements.
            reverse_replacements = {}
            for key, migration in self.replacements.items():
                for replaced in migration.replaces:
                    reverse_replacements.setdefault(replaced, set()).add(key)
            # Try to reraise exception with more detail.
            if exc.node in reverse_replacements:
                candidates = reverse_replacements.get(exc.node, set())
                is_replaced = any(
                    candidate in self.graph.nodes for candidate in candidates
                )
                if not is_replaced:
                    tries = ", ".join("%s.%s" % c for c in candidates)
                    raise NodeNotFoundError(
                        "Migration {0} depends on nonexistent node ('{1}', '{2}'). "
                        "Django tried to replace migration {1}.{2} with any of [{3}] "
                        "but wasn't able to because some of the replaced migrations "
                        "are already applied.".format(
                            exc.origin, exc.node[0], exc.node[1], tries
                        ),
                        exc.node,
                    ) from exc
            raise
        self.graph.ensure_not_cyclic()
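    # Minimal usage sketch (illustrative only, not part of Django's source):
    #
    #     from django.db import connection
    #     loader = MigrationLoader(connection)
    #     loader.graph.leaf_nodes()     # current (app_label, name) leaves
    #     loader.applied_migrations     # as recorded in django_migrations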
    def check_consistent_history(self, connection):
        """
        Raise InconsistentMigrationHistory if any applied migrations have
        unapplied dependencies.
        """
        recorder = MigrationRecorder(connection)
        applied = recorder.applied_migrations()
        for migration in applied:
            # If the migration is unknown, skip it.
            if migration not in self.graph.nodes:
                continue
            for parent in self.graph.node_map[migration].parents:
                if parent not in applied:
                    # Skip unapplied squashed migrations that have all of their
                    # `replaces` applied.
                    if parent in self.replacements:
                        if all(
                            m in applied for m in self.replacements[parent].replaces
                        ):
                            continue
                    raise InconsistentMigrationHistory(
                        "Migration {}.{} is applied before its dependency "
                        "{}.{} on database '{}'.".format(
                            migration[0],
                            migration[1],
                            parent[0],
                            parent[1],
                            connection.alias,
                        )
                    )

    def detect_conflicts(self):
        """
        Look through the loaded graph and detect any conflicts - apps
        with more than one leaf migration. Return a dict of the app labels
        that conflict with the migration names that conflict.
        """
        seen_apps = {}
        conflicting_apps = set()
        for app_label, migration_name in self.graph.leaf_nodes():
            if app_label in seen_apps:
                conflicting_apps.add(app_label)
            seen_apps.setdefault(app_label, set()).add(migration_name)
        return {
            app_label: sorted(seen_apps[app_label]) for app_label in conflicting_apps
        }

    def project_state(self, nodes=None, at_end=True):
        """
        Return a ProjectState object representing the most recent state
        that the loaded migrations represent.

        See graph.make_state() for the meaning of "nodes" and "at_end".
        """
        return self.graph.make_state(
            nodes=nodes, at_end=at_end, real_apps=self.unmigrated_apps
        )

    def collect_sql(self, plan):
        """
        Take a migration plan and return a list of collected SQL statements
        that represent the best-efforts version of that plan.
        """
        statements = []
        state = None
        for migration, backwards in plan:
            with self.connection.schema_editor(
                collect_sql=True, atomic=migration.atomic
            ) as schema_editor:
                if state is None:
                    state = self.project_state(
                        (migration.app_label, migration.name), at_end=False
                    )
                if not backwards:
                    state = migration.apply(state, schema_editor, collect_sql=True)
                else:
                    state = migration.unapply(state, schema_editor, collect_sql=True)
            statements.extend(schema_editor.collected_sql)
        return statements
.venv/lib/python3.10/site-packages/django/db/migrations/migration.py
Normal file
@@ -0,0 +1,239 @@
import re

from django.db.migrations.utils import get_migration_name_timestamp
from django.db.transaction import atomic

from .exceptions import IrreversibleError


class Migration:
    """
    The base class for all migrations.

    Migration files will import this from django.db.migrations.Migration
    and subclass it as a class called Migration. It will have one or more
    of the following attributes:

    - operations: A list of Operation instances, probably from
      django.db.migrations.operations
    - dependencies: A list of tuples of (app_path, migration_name)
    - run_before: A list of tuples of (app_path, migration_name)
    - replaces: A list of migration_names

    Note that all migrations come out of migration files and into the Loader
    or Graph as instances, having been initialized with their app label and
    name.
    """

    # Operations to apply during this migration, in order.
    operations = []

    # Other migrations that should be run before this migration.
    # Should be a list of (app, migration_name).
    dependencies = []

    # Other migrations that should be run after this one (i.e. have
    # this migration added to their dependencies). Useful to make third-party
    # apps' migrations run after your AUTH_USER replacement, for example.
    run_before = []

    # Migration names in this app that this migration replaces. If this is
    # non-empty, this migration will only be applied if all these migrations
    # are not applied.
    replaces = []

    # Is this an initial migration? Initial migrations are skipped on
    # --fake-initial if the table or fields already exist. If None, check
    # whether the migration has any dependencies to decide whether db
    # introspection needs to be done. If True, always perform introspection.
    # If False, never perform introspection.
    initial = None

    # Whether to wrap the whole migration in a transaction. Only has an effect
    # on database backends which support transactional DDL.
    atomic = True
    def __init__(self, name, app_label):
        self.name = name
        self.app_label = app_label
        # Copy dependencies & other attrs as we might mutate them at runtime
        self.operations = list(self.__class__.operations)
        self.dependencies = list(self.__class__.dependencies)
        self.run_before = list(self.__class__.run_before)
        self.replaces = list(self.__class__.replaces)

    def __eq__(self, other):
        return (
            isinstance(other, Migration)
            and self.name == other.name
            and self.app_label == other.app_label
        )

    def __repr__(self):
        return "<Migration %s.%s>" % (self.app_label, self.name)

    def __str__(self):
        return "%s.%s" % (self.app_label, self.name)

    def __hash__(self):
        return hash("%s.%s" % (self.app_label, self.name))
    def mutate_state(self, project_state, preserve=True):
        """
        Take a ProjectState and return a new one with the migration's
        operations applied to it. Preserve the original object state by
        default and return a mutated state from a copy.
        """
        new_state = project_state
        if preserve:
            new_state = project_state.clone()

        for operation in self.operations:
            operation.state_forwards(self.app_label, new_state)
        return new_state
    def apply(self, project_state, schema_editor, collect_sql=False):
        """
        Take a project_state representing all migrations prior to this one
        and a schema_editor for a live database and apply the migration
        in a forwards order.

        Return the resulting project state for efficient reuse by following
        Migrations.
        """
        for operation in self.operations:
            # If this operation cannot be represented as SQL, place a comment
            # there instead
            if collect_sql:
                schema_editor.collected_sql.append("--")
                schema_editor.collected_sql.append("-- %s" % operation.describe())
                schema_editor.collected_sql.append("--")
                if not operation.reduces_to_sql:
                    schema_editor.collected_sql.append(
                        "-- THIS OPERATION CANNOT BE WRITTEN AS SQL"
                    )
                    continue
                collected_sql_before = len(schema_editor.collected_sql)
            # Save the state before the operation has run
            old_state = project_state.clone()
            operation.state_forwards(self.app_label, project_state)
            # Run the operation
            atomic_operation = operation.atomic or (
                self.atomic and operation.atomic is not False
            )
            if not schema_editor.atomic_migration and atomic_operation:
                # Force a transaction on a non-transactional-DDL backend or an
                # atomic operation inside a non-atomic migration.
                with atomic(schema_editor.connection.alias):
                    operation.database_forwards(
                        self.app_label, schema_editor, old_state, project_state
                    )
            else:
                # Normal behaviour
                operation.database_forwards(
                    self.app_label, schema_editor, old_state, project_state
                )
            if collect_sql and collected_sql_before == len(schema_editor.collected_sql):
                schema_editor.collected_sql.append("-- (no-op)")
        return project_state
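    # Illustrative note (not part of Django's source): in apply() above, an
    # operation gets its own transaction when it asks for one (atomic=True),
    # or when the migration is atomic and the operation hasn't opted out with
    # atomic=False - but only if the schema editor isn't already running the
    # whole migration atomically.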
    def unapply(self, project_state, schema_editor, collect_sql=False):
        """
        Take a project_state representing all migrations prior to this one
        and a schema_editor for a live database and apply the migration
        in a reverse order.

        The backwards migration process consists of two phases:

        1. The intermediate states from right before the first until right
           after the last operation inside this migration are preserved.
        2. The operations are applied in reverse order using the states
           recorded in step 1.
        """
        # Construct all the intermediate states we need for a reverse migration
        to_run = []
        new_state = project_state
        # Phase 1
        for operation in self.operations:
            # If it's irreversible, error out
            if not operation.reversible:
                raise IrreversibleError(
                    "Operation %s in %s is not reversible" % (operation, self)
                )
            # Preserve new state from the previous run so the same state isn't
            # mutated across all operations
            new_state = new_state.clone()
            old_state = new_state.clone()
            operation.state_forwards(self.app_label, new_state)
            to_run.insert(0, (operation, old_state, new_state))

        # Phase 2
        for operation, to_state, from_state in to_run:
            if collect_sql:
                schema_editor.collected_sql.append("--")
                schema_editor.collected_sql.append("-- %s" % operation.describe())
                schema_editor.collected_sql.append("--")
                if not operation.reduces_to_sql:
                    schema_editor.collected_sql.append(
                        "-- THIS OPERATION CANNOT BE WRITTEN AS SQL"
                    )
                    continue
                collected_sql_before = len(schema_editor.collected_sql)
            atomic_operation = operation.atomic or (
                self.atomic and operation.atomic is not False
            )
            if not schema_editor.atomic_migration and atomic_operation:
                # Force a transaction on a non-transactional-DDL backend or an
                # atomic operation inside a non-atomic migration.
                with atomic(schema_editor.connection.alias):
                    operation.database_backwards(
                        self.app_label, schema_editor, from_state, to_state
                    )
            else:
                # Normal behaviour
                operation.database_backwards(
                    self.app_label, schema_editor, from_state, to_state
                )
            if collect_sql and collected_sql_before == len(schema_editor.collected_sql):
                schema_editor.collected_sql.append("-- (no-op)")
        return project_state
    def suggest_name(self):
        """
        Suggest a name for the operations this migration might represent. Names
        are not guaranteed to be unique, but put some effort into the fallback
        name to avoid VCS conflicts if possible.
        """
        if self.initial:
            return "initial"

        raw_fragments = [op.migration_name_fragment for op in self.operations]
        fragments = [re.sub(r"\W+", "_", name) for name in raw_fragments if name]

        if not fragments or len(fragments) != len(self.operations):
            return "auto_%s" % get_migration_name_timestamp()

        name = fragments[0]
        for fragment in fragments[1:]:
            new_name = f"{name}_{fragment}"
            if len(new_name) > 52:
                name = f"{name}_and_more"
                break
            name = new_name
        return name


class SwappableTuple(tuple):
    """
    Subclass of tuple so Django can tell this was originally a swappable
    dependency when it reads the migration file.
    """

    def __new__(cls, value, setting):
        self = tuple.__new__(cls, value)
        self.setting = setting
        return self


def swappable_dependency(value):
    """Turn a setting value into a dependency."""
    return SwappableTuple((value.split(".", 1)[0], "__first__"), value)
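# Illustrative example (not part of Django's source):
#
#     swappable_dependency("auth.User")
#     # -> ("auth", "__first__") as a SwappableTuple carrying
#     #    setting="auth.User", so the dependency can be rewritten if
#     #    AUTH_USER_MODEL is swapped out.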
.venv/lib/python3.10/site-packages/django/db/migrations/operations/__init__.py
Normal file
@@ -0,0 +1,46 @@
from .fields import AddField, AlterField, RemoveField, RenameField
from .models import (
    AddConstraint,
    AddIndex,
    AlterConstraint,
    AlterIndexTogether,
    AlterModelManagers,
    AlterModelOptions,
    AlterModelTable,
    AlterModelTableComment,
    AlterOrderWithRespectTo,
    AlterUniqueTogether,
    CreateModel,
    DeleteModel,
    RemoveConstraint,
    RemoveIndex,
    RenameIndex,
    RenameModel,
)
from .special import RunPython, RunSQL, SeparateDatabaseAndState

__all__ = [
    "CreateModel",
    "DeleteModel",
    "AlterModelTable",
    "AlterModelTableComment",
    "AlterUniqueTogether",
    "RenameModel",
    "AlterIndexTogether",
    "AlterModelOptions",
    "AddIndex",
    "RemoveIndex",
    "RenameIndex",
    "AddField",
    "RemoveField",
    "AlterField",
    "RenameField",
    "AddConstraint",
    "RemoveConstraint",
    "AlterConstraint",
    "SeparateDatabaseAndState",
    "RunSQL",
    "RunPython",
    "AlterOrderWithRespectTo",
    "AlterModelManagers",
]
.venv/lib/python3.10/site-packages/django/db/migrations/operations/base.py
Normal file
@@ -0,0 +1,166 @@
import enum

from django.db import router


class OperationCategory(str, enum.Enum):
    ADDITION = "+"
    REMOVAL = "-"
    ALTERATION = "~"
    PYTHON = "p"
    SQL = "s"
    MIXED = "?"


class Operation:
    """
    Base class for migration operations.

    It's responsible for both mutating the in-memory model state
    (see db/migrations/state.py) to represent what it performs, as well
    as actually performing it against a live database.

    Note that some operations won't modify memory state at all (e.g. data
    copying operations), and some will need their modifications to be
    optionally specified by the user (e.g. custom Python code snippets).

    Due to the way this class deals with deconstruction, it should be
    considered immutable.
    """

    # If this migration can be run in reverse.
    # Some operations are impossible to reverse, like deleting data.
    reversible = True

    # Can this migration be represented as SQL? (things like RunPython cannot)
    reduces_to_sql = True

    # Should this operation be forced as atomic even on backends with no
    # DDL transaction support (i.e., does it have no DDL, like RunPython)
    atomic = False

    # Should this operation be considered safe to elide and optimize across?
    elidable = False

    serialization_expand_args = []

    category = None
    def __new__(cls, *args, **kwargs):
        # We capture the arguments to make returning them trivial
        self = object.__new__(cls)
        self._constructor_args = (args, kwargs)
        return self

    def deconstruct(self):
        """
        Return a 3-tuple of class import path (or just name if it lives
        under django.db.migrations), positional arguments, and keyword
        arguments.
        """
        return (
            self.__class__.__name__,
            self._constructor_args[0],
            self._constructor_args[1],
        )
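    # Illustrative example (not part of Django's source): because __new__
    # captures the constructor arguments, deconstruct() can round-trip any
    # operation - a hypothetical Op("a", flag=True) deconstructs to
    # ("Op", ("a",), {"flag": True}).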
    def state_forwards(self, app_label, state):
        """
        Take the state from the previous migration, and mutate it
        so that it matches what this migration would perform.
        """
        raise NotImplementedError(
            "subclasses of Operation must provide a state_forwards() method"
        )

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        """
        Perform the mutation on the database schema in the normal
        (forwards) direction.
        """
        raise NotImplementedError(
            "subclasses of Operation must provide a database_forwards() method"
        )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        """
        Perform the mutation on the database schema in the reverse
        direction - e.g. if this were CreateModel, it would in fact
        drop the model's table.
        """
        raise NotImplementedError(
            "subclasses of Operation must provide a database_backwards() method"
        )

    def describe(self):
        """
        Output a brief summary of what the action does.
        """
        return "%s: %s" % (self.__class__.__name__, self._constructor_args)

    def formatted_description(self):
        """Output a description prefixed by a category symbol."""
        description = self.describe()
        if self.category is None:
            return f"{OperationCategory.MIXED.value} {description}"
        return f"{self.category.value} {description}"

    @property
    def migration_name_fragment(self):
        """
        A filename part suitable for automatically naming a migration
        containing this operation, or None if not applicable.
        """
        return None

    def references_model(self, name, app_label):
        """
        Return True if there is a chance this operation references the given
        model name (as a string), with an app label for accuracy.

        Used for optimization. If in doubt, return True;
        returning a false positive will merely make the optimizer a little
        less efficient, while returning a false negative may result in an
        unusable optimized migration.
        """
        return True

    def references_field(self, model_name, name, app_label):
        """
        Return True if there is a chance this operation references the given
        field name, with an app label for accuracy.

        Used for optimization. If in doubt, return True.
        """
        return self.references_model(model_name, app_label)

    def allow_migrate_model(self, connection_alias, model):
        """
        Return whether or not a model may be migrated.

        This is a thin wrapper around router.allow_migrate_model() that
        preemptively rejects any proxy, swapped out, or unmanaged model.
        """
        if not model._meta.can_migrate(connection_alias):
            return False

        return router.allow_migrate_model(connection_alias, model)
    def reduce(self, operation, app_label):
        """
        Return either a list of operations the actual operation should be
        replaced with or a boolean that indicates whether or not the specified
        operation can be optimized across.
        """
        if self.elidable:
            return [operation]
        elif operation.elidable:
            return [self]
        return False

    def __repr__(self):
        return "<%s %s%s>" % (
            self.__class__.__name__,
            ", ".join(map(repr, self._constructor_args[0])),
            ",".join(" %s=%r" % x for x in self._constructor_args[1].items()),
        )
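    # Illustrative note (not part of Django's source): reduce() gives the
    # optimizer three possible answers - a list of replacement operations, a
    # truthy "safe to optimize across", or False as an optimization barrier.
    # The elidable checks above collapse an elidable operation into its
    # neighbor.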
.venv/lib/python3.10/site-packages/django/db/migrations/operations/fields.py
Normal file
@@ -0,0 +1,365 @@
from django.db.migrations.utils import field_references
from django.db.models import NOT_PROVIDED
from django.utils.functional import cached_property

from .base import Operation, OperationCategory


class FieldOperation(Operation):
    def __init__(self, model_name, name, field=None):
        self.model_name = model_name
        self.name = name
        self.field = field

    @cached_property
    def model_name_lower(self):
        return self.model_name.lower()

    @cached_property
    def name_lower(self):
        return self.name.lower()

    def is_same_model_operation(self, operation):
        return self.model_name_lower == operation.model_name_lower

    def is_same_field_operation(self, operation):
        return (
            self.is_same_model_operation(operation)
            and self.name_lower == operation.name_lower
        )

    def references_model(self, name, app_label):
        name_lower = name.lower()
        if name_lower == self.model_name_lower:
            return True
        if self.field:
            return bool(
                field_references(
                    (app_label, self.model_name_lower),
                    self.field,
                    (app_label, name_lower),
                )
            )
        return False

    def references_field(self, model_name, name, app_label):
        model_name_lower = model_name.lower()
        # Check if this operation locally references the field.
        if model_name_lower == self.model_name_lower:
            if name == self.name:
                return True
            elif (
                self.field
                and hasattr(self.field, "from_fields")
                and name in self.field.from_fields
            ):
                return True
        # Check if this operation remotely references the field.
        if self.field is None:
            return False
        return bool(
            field_references(
                (app_label, self.model_name_lower),
                self.field,
                (app_label, model_name_lower),
                name,
            )
        )

    def reduce(self, operation, app_label):
        return super().reduce(operation, app_label) or not operation.references_field(
            self.model_name, self.name, app_label
        )
class AddField(FieldOperation):
    """Add a field to a model."""

    category = OperationCategory.ADDITION

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, app_label, state):
        state.add_field(
            app_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            field = to_model._meta.get_field(self.name)
            if not self.preserve_default:
                field.default = self.field.default
            schema_editor.add_field(
                from_model,
                field,
            )
            if not self.preserve_default:
                field.default = NOT_PROVIDED

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, from_model):
            schema_editor.remove_field(
                from_model, from_model._meta.get_field(self.name)
            )

    def describe(self):
        return "Add field %s to %s" % (self.name, self.model_name)

    @property
    def migration_name_fragment(self):
        return "%s_%s" % (self.model_name_lower, self.name_lower)

    def reduce(self, operation, app_label):
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(
            operation
        ):
            if isinstance(operation, AlterField):
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                    ),
                ]
        return super().reduce(operation, app_label)
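# Illustrative example (not part of Django's source): the optimizer uses
# reduce() to collapse adjacent operations on the same field, e.g.
#
#     AddField("book", "title", models.CharField(max_length=50))
#     AlterField("book", "title", models.CharField(max_length=100))
#
# reduces to a single AddField("book", "title", CharField(max_length=100)).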
class RemoveField(FieldOperation):
    """Remove a field from a model."""

    category = OperationCategory.REMOVAL

    def deconstruct(self):
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
        }
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, app_label, state):
        state.remove_field(app_label, self.model_name_lower, self.name)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        from_model = from_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, from_model):
            schema_editor.remove_field(
                from_model, from_model._meta.get_field(self.name)
            )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            schema_editor.add_field(from_model, to_model._meta.get_field(self.name))

    def describe(self):
        return "Remove field %s from %s" % (self.name, self.model_name)

    @property
    def migration_name_fragment(self):
        return "remove_%s_%s" % (self.model_name_lower, self.name_lower)

    def reduce(self, operation, app_label):
        from .models import DeleteModel

        if (
            isinstance(operation, DeleteModel)
            and operation.name_lower == self.model_name_lower
        ):
            return [operation]
        return super().reduce(operation, app_label)
class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    category = OperationCategory.ALTERATION

    def __init__(self, model_name, name, field, preserve_default=True):
        self.preserve_default = preserve_default
        super().__init__(model_name, name, field)

    def deconstruct(self):
        kwargs = {
            "model_name": self.model_name,
            "name": self.name,
            "field": self.field,
        }
        if self.preserve_default is not True:
            kwargs["preserve_default"] = self.preserve_default
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, app_label, state):
        state.alter_field(
            app_label,
            self.model_name_lower,
            self.name,
            self.field,
            self.preserve_default,
        )

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            from_field = from_model._meta.get_field(self.name)
            to_field = to_model._meta.get_field(self.name)
            if not self.preserve_default:
                to_field.default = self.field.default
            schema_editor.alter_field(from_model, from_field, to_field)
            if not self.preserve_default:
                to_field.default = NOT_PROVIDED

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        self.database_forwards(app_label, schema_editor, from_state, to_state)

    def describe(self):
        return "Alter field %s on %s" % (self.name, self.model_name)

    @property
    def migration_name_fragment(self):
        return "alter_%s_%s" % (self.model_name_lower, self.name_lower)

    def reduce(self, operation, app_label):
        if isinstance(
            operation, (AlterField, RemoveField)
        ) and self.is_same_field_operation(operation):
            return [operation]
        elif (
            isinstance(operation, RenameField)
            and self.is_same_field_operation(operation)
            and self.field.db_column is None
        ):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                ),
            ]
        return super().reduce(operation, app_label)
class RenameField(FieldOperation):
    """Rename a field on the model. Might affect db_column too."""

    category = OperationCategory.ALTERATION

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        return self.new_name.lower()

    def deconstruct(self):
        kwargs = {
            "model_name": self.model_name,
            "old_name": self.old_name,
            "new_name": self.new_name,
        }
        return (self.__class__.__name__, [], kwargs)

    def state_forwards(self, app_label, state):
        state.rename_field(
            app_label, self.model_name_lower, self.old_name, self.new_name
        )

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            schema_editor.alter_field(
                from_model,
                from_model._meta.get_field(self.old_name),
                to_model._meta.get_field(self.new_name),
            )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            schema_editor.alter_field(
                from_model,
                from_model._meta.get_field(self.new_name),
                to_model._meta.get_field(self.old_name),
            )

    def describe(self):
        return "Rename field %s on %s to %s" % (
            self.old_name,
            self.model_name,
            self.new_name,
        )

    @property
    def migration_name_fragment(self):
        return "rename_%s_%s_%s" % (
            self.old_name_lower,
            self.model_name_lower,
            self.new_name_lower,
        )

    def references_field(self, model_name, name, app_label):
        return self.references_model(model_name, app_label) and (
            name.lower() == self.old_name_lower or name.lower() == self.new_name_lower
        )

    def reduce(self, operation, app_label):
        if (
            isinstance(operation, RenameField)
            and self.is_same_model_operation(operation)
            and self.new_name_lower == operation.old_name_lower
        ):
            return [
                RenameField(
                    self.model_name,
                    self.old_name,
                    operation.new_name,
                ),
            ]
        # Skip `FieldOperation.reduce` as we want to run `references_field`
        # against self.old_name and self.new_name.
        return super(FieldOperation, self).reduce(operation, app_label) or not (
            operation.references_field(self.model_name, self.old_name, app_label)
            or operation.references_field(self.model_name, self.new_name, app_label)
        )
File diff suppressed because it is too large (django/db/migrations/operations/models.py)
.venv/lib/python3.10/site-packages/django/db/migrations/operations/special.py
Normal file
@@ -0,0 +1,211 @@
from django.db import router

from .base import Operation, OperationCategory


class SeparateDatabaseAndState(Operation):
    """
    Take two lists of operations - ones that will be used for the database,
    and ones that will be used for the state change. This allows operations
    that don't support state change to have one applied separately, operations
    that should affect only the state and not the database, and so on.
    """

    category = OperationCategory.MIXED
    serialization_expand_args = ["database_operations", "state_operations"]

    def __init__(self, database_operations=None, state_operations=None):
        self.database_operations = database_operations or []
        self.state_operations = state_operations or []

    def deconstruct(self):
        kwargs = {}
        if self.database_operations:
            kwargs["database_operations"] = self.database_operations
        if self.state_operations:
            kwargs["state_operations"] = self.state_operations
        return (self.__class__.__qualname__, [], kwargs)

    def state_forwards(self, app_label, state):
        for state_operation in self.state_operations:
            state_operation.state_forwards(app_label, state)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        # We calculate state separately in here since our state functions aren't useful
        for database_operation in self.database_operations:
            to_state = from_state.clone()
            database_operation.state_forwards(app_label, to_state)
            database_operation.database_forwards(
                app_label, schema_editor, from_state, to_state
            )
            from_state = to_state

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        # We calculate state separately in here since our state functions aren't useful
        to_states = {}
        for dbop in self.database_operations:
            to_states[dbop] = to_state
            to_state = to_state.clone()
            dbop.state_forwards(app_label, to_state)
        # to_state now has the states of all the database_operations applied
        # which is the from_state for the backwards migration of the last
        # operation.
        for database_operation in reversed(self.database_operations):
            from_state = to_state
            to_state = to_states[database_operation]
            database_operation.database_backwards(
                app_label, schema_editor, from_state, to_state
            )

    def describe(self):
        return "Custom state/database change combination"
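# Minimal usage sketch (illustrative only, not part of Django's source):
# removing a model from the project state without touching the database, by
# pairing an empty database side with state-only operations:
#
#     SeparateDatabaseAndState(
#         database_operations=[],
#         state_operations=[migrations.DeleteModel("OldModel")],
#     )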
class RunSQL(Operation):
    """
    Run some raw SQL. A reverse SQL statement may be provided.

    Also accept a list of operations that represent the state change effected
    by this SQL change, in case it's custom column/table creation/deletion.
    """

    category = OperationCategory.SQL
    noop = ""

    def __init__(
        self, sql, reverse_sql=None, state_operations=None, hints=None, elidable=False
    ):
        self.sql = sql
        self.reverse_sql = reverse_sql
        self.state_operations = state_operations or []
        self.hints = hints or {}
        self.elidable = elidable

    def deconstruct(self):
        kwargs = {
            "sql": self.sql,
        }
        if self.reverse_sql is not None:
            kwargs["reverse_sql"] = self.reverse_sql
        if self.state_operations:
            kwargs["state_operations"] = self.state_operations
        if self.hints:
            kwargs["hints"] = self.hints
        return (self.__class__.__qualname__, [], kwargs)

    @property
    def reversible(self):
        return self.reverse_sql is not None

    def state_forwards(self, app_label, state):
        for state_operation in self.state_operations:
            state_operation.state_forwards(app_label, state)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        if router.allow_migrate(
            schema_editor.connection.alias, app_label, **self.hints
        ):
            self._run_sql(schema_editor, self.sql)

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        if self.reverse_sql is None:
            raise NotImplementedError("You cannot reverse this operation")
        if router.allow_migrate(
            schema_editor.connection.alias, app_label, **self.hints
        ):
            self._run_sql(schema_editor, self.reverse_sql)

    def describe(self):
        return "Raw SQL operation"

    def _run_sql(self, schema_editor, sqls):
        if isinstance(sqls, (list, tuple)):
            for sql in sqls:
                params = None
                if isinstance(sql, (list, tuple)):
                    elements = len(sql)
                    if elements == 2:
                        sql, params = sql
                    else:
                        raise ValueError("Expected a 2-tuple but got %d" % elements)
                schema_editor.execute(sql, params=params)
        elif sqls != RunSQL.noop:
            statements = schema_editor.connection.ops.prepare_sql_script(sqls)
            for statement in statements:
                schema_editor.execute(statement, params=None)
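
(Aside: _run_sql above accepts either a single SQL script or a list whose items may be plain strings or (sql, params) 2-tuples, and RunSQL.noop marks a direction as intentionally empty. A hedged sketch of those forms; table and column names are invented.)

from django.db import migrations

operation = migrations.RunSQL(
    sql=[
        # Plain string item: executed as-is.
        "CREATE INDEX shop_product_sku_idx ON shop_product (sku);",
        # 2-tuple item: (sql, params), passed to schema_editor.execute().
        ("UPDATE shop_product SET sku = %s WHERE sku IS NULL", ["UNKNOWN"]),
    ],
    reverse_sql=migrations.RunSQL.noop,  # reversing is a deliberate no-op
)
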
class RunPython(Operation):
    """
    Run Python code in a context suitable for doing versioned ORM operations.
    """

    category = OperationCategory.PYTHON
    reduces_to_sql = False

    def __init__(
        self, code, reverse_code=None, atomic=None, hints=None, elidable=False
    ):
        self.atomic = atomic
        # Forwards code
        if not callable(code):
            raise ValueError("RunPython must be supplied with a callable")
        self.code = code
        # Reverse code
        if reverse_code is None:
            self.reverse_code = None
        else:
            if not callable(reverse_code):
                raise ValueError("RunPython must be supplied with callable arguments")
            self.reverse_code = reverse_code
        self.hints = hints or {}
        self.elidable = elidable

    def deconstruct(self):
        kwargs = {
            "code": self.code,
        }
        if self.reverse_code is not None:
            kwargs["reverse_code"] = self.reverse_code
        if self.atomic is not None:
            kwargs["atomic"] = self.atomic
        if self.hints:
            kwargs["hints"] = self.hints
        return (self.__class__.__qualname__, [], kwargs)

    @property
    def reversible(self):
        return self.reverse_code is not None

    def state_forwards(self, app_label, state):
        # RunPython objects have no state effect. To add some, combine this
        # with SeparateDatabaseAndState.
        pass

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        # RunPython has access to all models. Ensure that all models are
        # reloaded in case any are delayed.
        from_state.clear_delayed_apps_cache()
        if router.allow_migrate(
            schema_editor.connection.alias, app_label, **self.hints
        ):
            # We now execute the Python code in a context that contains a 'models'
            # object, representing the versioned models as an app registry.
            # We could try to override the global cache, but then people will still
            # use direct imports, so we go with a documentation approach instead.
            self.code(from_state.apps, schema_editor)

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        if self.reverse_code is None:
            raise NotImplementedError("You cannot reverse this operation")
        if router.allow_migrate(
            schema_editor.connection.alias, app_label, **self.hints
        ):
            self.reverse_code(from_state.apps, schema_editor)

    def describe(self):
        return "Raw Python operation"

    @staticmethod
    def noop(apps, schema_editor):
        return None
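
(Aside: a minimal sketch of a data migration using the versioned app registry passed to the callable; the "shop" app and Product model are hypothetical.)

from django.db import migrations


def forwards(apps, schema_editor):
    # Use the historical model, not a direct import from shop.models.
    Product = apps.get_model("shop", "Product")
    Product.objects.filter(sku="").update(sku=None)


operation = migrations.RunPython(forwards, reverse_code=migrations.RunPython.noop)
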
@@ -0,0 +1,69 @@
class MigrationOptimizer:
    """
    Power the optimization process, where you provide a list of Operations
    and you are returned a list of equal or shorter length - operations
    are merged into one if possible.

    For example, a CreateModel and an AddField can be optimized into a
    new CreateModel, and CreateModel and DeleteModel can be optimized into
    nothing.
    """

    def optimize(self, operations, app_label):
        """
        Main optimization entry point. Pass in a list of Operation instances,
        get out a new list of Operation instances.

        Unfortunately, due to the scope of the optimization (two combinable
        operations might be separated by several hundred others), this can't be
        done as a peephole optimization with checks/output implemented on
        the Operations themselves; instead, the optimizer looks at each
        individual operation and scans forwards in the list to see if there
        are any matches, stopping at boundaries - operations which can't
        be optimized over (RunSQL, operations on the same field/model, etc.)

        The inner loop is run until the starting list is the same as the result
        list, and then the result is returned. This means that operation
        optimization must be stable and always return an equal or shorter list.
        """
        # Internal tracking variable for test assertions about # of loops
        if app_label is None:
            raise TypeError("app_label must be a str.")
        self._iterations = 0
        while True:
            result = self.optimize_inner(operations, app_label)
            self._iterations += 1
            if result == operations:
                return result
            operations = result

    def optimize_inner(self, operations, app_label):
        """Inner optimization loop."""
        new_operations = []
        for i, operation in enumerate(operations):
            right = True  # Should we reduce on the right or on the left.
            # Compare it to each operation after it
            for j, other in enumerate(operations[i + 1 :]):
                result = operation.reduce(other, app_label)
                if isinstance(result, list):
                    in_between = operations[i + 1 : i + j + 1]
                    if right:
                        new_operations.extend(in_between)
                        new_operations.extend(result)
                    elif all(op.reduce(other, app_label) is True for op in in_between):
                        # Perform a left reduction if all of the in-between
                        # operations can optimize through other.
                        new_operations.extend(result)
                        new_operations.extend(in_between)
                    else:
                        # Otherwise keep trying.
                        new_operations.append(operation)
                        break
                    new_operations.extend(operations[i + j + 2 :])
                    return new_operations
                elif not result:
                    # Can't perform a right reduction.
                    right = False
            else:
                new_operations.append(operation)
        return new_operations
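
(Aside: one plausible way to exercise the loop above; merging CreateModel followed by AddField on the same model into a single CreateModel is the canonical case named in the class docstring. App and model names are invented.)

from django.db import migrations, models
from django.db.migrations.optimizer import MigrationOptimizer

operations = [
    migrations.CreateModel(
        "Tag", fields=[("id", models.AutoField(primary_key=True))]
    ),
    migrations.AddField(
        "Tag", "name", models.CharField(max_length=50, default="")
    ),
]
# Both operations touch the same model, so they should merge into one
# CreateModel that already contains the "name" field.
optimized = MigrationOptimizer().optimize(operations, app_label="shop")
assert len(optimized) == 1
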
@@ -0,0 +1,347 @@
import datetime
import importlib
import os
import sys

from django.apps import apps
from django.core.management.base import OutputWrapper
from django.db.models import NOT_PROVIDED
from django.utils import timezone
from django.utils.version import get_docs_version

from .loader import MigrationLoader


class MigrationQuestioner:
    """
    Give the autodetector responses to questions it might have.
    This base class has a built-in noninteractive mode, but the
    interactive subclass is what the command-line arguments will use.
    """

    def __init__(self, defaults=None, specified_apps=None, dry_run=None):
        self.defaults = defaults or {}
        self.specified_apps = specified_apps or set()
        self.dry_run = dry_run

    def ask_initial(self, app_label):
        """Should we create an initial migration for the app?"""
        # If it was specified on the command line, definitely true
        if app_label in self.specified_apps:
            return True
        # Otherwise, we look to see if it has a migrations module
        # without any Python files in it, apart from __init__.py.
        # Apps from the new app template will have these; the Python
        # file check will ensure we skip South ones.
        try:
            app_config = apps.get_app_config(app_label)
        except LookupError:  # It's a fake app.
            return self.defaults.get("ask_initial", False)
        migrations_import_path, _ = MigrationLoader.migrations_module(app_config.label)
        if migrations_import_path is None:
            # It's an application with migrations disabled.
            return self.defaults.get("ask_initial", False)
        try:
            migrations_module = importlib.import_module(migrations_import_path)
        except ImportError:
            return self.defaults.get("ask_initial", False)
        else:
            if getattr(migrations_module, "__file__", None):
                filenames = os.listdir(os.path.dirname(migrations_module.__file__))
            elif hasattr(migrations_module, "__path__"):
                if len(migrations_module.__path__) > 1:
                    return False
                filenames = os.listdir(list(migrations_module.__path__)[0])
            return not any(x.endswith(".py") for x in filenames if x != "__init__.py")

    def ask_not_null_addition(self, field_name, model_name):
        """Adding a NOT NULL field to a model."""
        # None means quit
        return None

    def ask_not_null_alteration(self, field_name, model_name):
        """Changing a NULL field to NOT NULL."""
        # None means quit
        return None

    def ask_rename(self, model_name, old_name, new_name, field_instance):
        """Was this field really renamed?"""
        return self.defaults.get("ask_rename", False)

    def ask_rename_model(self, old_model_state, new_model_state):
        """Was this model really renamed?"""
        return self.defaults.get("ask_rename_model", False)

    def ask_merge(self, app_label):
        """Should these migrations really be merged?"""
        return self.defaults.get("ask_merge", False)

    def ask_auto_now_add_addition(self, field_name, model_name):
        """Adding an auto_now_add field to a model."""
        # None means quit
        return None

    def ask_unique_callable_default_addition(self, field_name, model_name):
        """Adding a unique field with a callable default."""
        # None means continue.
        return None
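
(Aside: the defaults mapping is the only knob on this base class; a small sketch of how the keys consulted above drive the non-interactive answers.)

from django.db.migrations.questioner import MigrationQuestioner

questioner = MigrationQuestioner(defaults={"ask_rename_model": True})
# Read from defaults (falls back to False when the key is absent).
assert questioner.ask_merge("shop") is False
# "None means quit" for the NOT NULL addition question.
assert questioner.ask_not_null_addition("sku", "Product") is None
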
class InteractiveMigrationQuestioner(MigrationQuestioner):
    def __init__(
        self, defaults=None, specified_apps=None, dry_run=None, prompt_output=None
    ):
        super().__init__(
            defaults=defaults, specified_apps=specified_apps, dry_run=dry_run
        )
        self.prompt_output = prompt_output or OutputWrapper(sys.stdout)

    def _boolean_input(self, question, default=None):
        self.prompt_output.write(f"{question} ", ending="")
        result = input()
        if not result and default is not None:
            return default
        while not result or result[0].lower() not in "yn":
            self.prompt_output.write("Please answer yes or no: ", ending="")
            result = input()
        return result[0].lower() == "y"

    def _choice_input(self, question, choices):
        self.prompt_output.write(f"{question}")
        for i, choice in enumerate(choices):
            self.prompt_output.write(" %s) %s" % (i + 1, choice))
        self.prompt_output.write("Select an option: ", ending="")
        while True:
            try:
                result = input()
                value = int(result)
            except ValueError:
                pass
            except KeyboardInterrupt:
                self.prompt_output.write("\nCancelled.")
                sys.exit(1)
            else:
                if 0 < value <= len(choices):
                    return value
            self.prompt_output.write("Please select a valid option: ", ending="")

    def _ask_default(self, default=""):
        """
        Prompt for a default value.

        The ``default`` argument allows providing a custom default value (as a
        string) which will be shown to the user and used as the return value
        if the user doesn't provide any other input.
        """
        self.prompt_output.write("Please enter the default value as valid Python.")
        if default:
            self.prompt_output.write(
                f"Accept the default '{default}' by pressing 'Enter' or "
                f"provide another value."
            )
        self.prompt_output.write(
            "The datetime and django.utils.timezone modules are available, so "
            "it is possible to provide e.g. timezone.now as a value."
        )
        self.prompt_output.write("Type 'exit' to exit this prompt")
        while True:
            if default:
                prompt = "[default: {}] >>> ".format(default)
            else:
                prompt = ">>> "
            self.prompt_output.write(prompt, ending="")
            try:
                code = input()
            except KeyboardInterrupt:
                self.prompt_output.write("\nCancelled.")
                sys.exit(1)
            if not code and default:
                code = default
            if not code:
                self.prompt_output.write(
                    "Please enter some code, or 'exit' (without quotes) to exit."
                )
            elif code == "exit":
                sys.exit(1)
            else:
                try:
                    return eval(code, {}, {"datetime": datetime, "timezone": timezone})
                except Exception as e:
                    self.prompt_output.write(f"{e.__class__.__name__}: {e}")

    def ask_not_null_addition(self, field_name, model_name):
        """Adding a NOT NULL field to a model."""
        if not self.dry_run:
            choice = self._choice_input(
                f"It is impossible to add a non-nullable field '{field_name}' "
                f"to {model_name} without specifying a default. This is "
                f"because the database needs something to populate existing "
                f"rows.\n"
                f"Please select a fix:",
                [
                    (
                        "Provide a one-off default now (will be set on all existing "
                        "rows with a null value for this column)"
                    ),
                    "Quit and manually define a default value in models.py.",
                ],
            )
            if choice == 2:
                sys.exit(3)
            else:
                return self._ask_default()
        return None

    def ask_not_null_alteration(self, field_name, model_name):
        """Changing a NULL field to NOT NULL."""
        if not self.dry_run:
            choice = self._choice_input(
                f"It is impossible to change a nullable field '{field_name}' "
                f"on {model_name} to non-nullable without providing a "
                f"default. This is because the database needs something to "
                f"populate existing rows.\n"
                f"Please select a fix:",
                [
                    (
                        "Provide a one-off default now (will be set on all existing "
                        "rows with a null value for this column)"
                    ),
                    "Ignore for now. Existing rows that contain NULL values "
                    "will have to be handled manually, for example with a "
                    "RunPython or RunSQL operation.",
                    "Quit and manually define a default value in models.py.",
                ],
            )
            if choice == 2:
                return NOT_PROVIDED
            elif choice == 3:
                sys.exit(3)
            else:
                return self._ask_default()
        return None

    def ask_rename(self, model_name, old_name, new_name, field_instance):
        """Was this field really renamed?"""
        msg = "Was %s.%s renamed to %s.%s (a %s)? [y/N]"
        return self._boolean_input(
            msg
            % (
                model_name,
                old_name,
                model_name,
                new_name,
                field_instance.__class__.__name__,
            ),
            False,
        )

    def ask_rename_model(self, old_model_state, new_model_state):
        """Was this model really renamed?"""
        msg = "Was the model %s.%s renamed to %s? [y/N]"
        return self._boolean_input(
            msg
            % (old_model_state.app_label, old_model_state.name, new_model_state.name),
            False,
        )

    def ask_merge(self, app_label):
        return self._boolean_input(
            "\nMerging will only work if the operations printed above do not conflict\n"
            + "with each other (working on different fields or models)\n"
            + "Should these migration branches be merged? [y/N]",
            False,
        )

    def ask_auto_now_add_addition(self, field_name, model_name):
        """Adding an auto_now_add field to a model."""
        if not self.dry_run:
            choice = self._choice_input(
                f"It is impossible to add the field '{field_name}' with "
                f"'auto_now_add=True' to {model_name} without providing a "
                f"default. This is because the database needs something to "
                f"populate existing rows.\n",
                [
                    "Provide a one-off default now which will be set on all "
                    "existing rows",
                    "Quit and manually define a default value in models.py.",
                ],
            )
            if choice == 2:
                sys.exit(3)
            else:
                return self._ask_default(default="timezone.now")
        return None

    def ask_unique_callable_default_addition(self, field_name, model_name):
        """Adding a unique field with a callable default."""
        if not self.dry_run:
            version = get_docs_version()
            choice = self._choice_input(
                f"Callable default on unique field {model_name}.{field_name} "
                f"will not generate unique values upon migrating.\n"
                f"Please choose how to proceed:\n",
                [
                    f"Continue making this migration as the first step in "
                    f"writing a manual migration to generate unique values "
                    f"described here: "
                    f"https://docs.djangoproject.com/en/{version}/howto/"
                    f"writing-migrations/#migrations-that-add-unique-fields.",
                    "Quit and edit field options in models.py.",
                ],
            )
            if choice == 2:
                sys.exit(3)
        return None


class NonInteractiveMigrationQuestioner(MigrationQuestioner):
    def __init__(
        self,
        defaults=None,
        specified_apps=None,
        dry_run=None,
        verbosity=1,
        log=None,
    ):
        self.verbosity = verbosity
        self.log = log
        super().__init__(
            defaults=defaults,
            specified_apps=specified_apps,
            dry_run=dry_run,
        )

    def log_lack_of_migration(self, field_name, model_name, reason):
        if self.verbosity > 0:
            self.log(
                f"Field '{field_name}' on model '{model_name}' not migrated: "
                f"{reason}."
            )

    def ask_not_null_addition(self, field_name, model_name):
        # We can't ask the user, so act like the user aborted.
        self.log_lack_of_migration(
            field_name,
            model_name,
            "it is impossible to add a non-nullable field without specifying "
            "a default",
        )
        sys.exit(3)

    def ask_not_null_alteration(self, field_name, model_name):
        # We can't ask the user, so set as not provided.
        self.log(
            f"Field '{field_name}' on model '{model_name}' given a default of "
            f"NOT PROVIDED and must be corrected."
        )
        return NOT_PROVIDED

    def ask_auto_now_add_addition(self, field_name, model_name):
        # We can't ask the user, so act like the user aborted.
        self.log_lack_of_migration(
            field_name,
            model_name,
            "it is impossible to add a field with 'auto_now_add=True' without "
            "specifying a default",
        )
        sys.exit(3)
@@ -0,0 +1,111 @@
from django.apps.registry import Apps
from django.db import DatabaseError, models
from django.utils.functional import classproperty
from django.utils.timezone import now

from .exceptions import MigrationSchemaMissing


class MigrationRecorder:
    """
    Deal with storing migration records in the database.

    Because this table is actually itself used for dealing with model
    creation, it's the one thing we can't do normally via migrations.
    We manually handle table creation/schema updating (using schema backend)
    and then have a floating model to do queries with.

    If a migration is unapplied its row is removed from the table. Having
    a row in the table always means a migration is applied.
    """

    _migration_class = None

    @classproperty
    def Migration(cls):
        """
        Lazy load to avoid AppRegistryNotReady if installed apps import
        MigrationRecorder.
        """
        if cls._migration_class is None:

            class Migration(models.Model):
                app = models.CharField(max_length=255)
                name = models.CharField(max_length=255)
                applied = models.DateTimeField(default=now)

                class Meta:
                    apps = Apps()
                    app_label = "migrations"
                    db_table = "django_migrations"

                def __str__(self):
                    return "Migration %s for %s" % (self.name, self.app)

            cls._migration_class = Migration
        return cls._migration_class

    def __init__(self, connection):
        self.connection = connection
        self._has_table = False

    @property
    def migration_qs(self):
        return self.Migration.objects.using(self.connection.alias)

    def has_table(self):
        """Return True if the django_migrations table exists."""
        # If the migrations table has already been confirmed to exist, don't
        # recheck its existence.
        if self._has_table:
            return True
        # It hasn't been confirmed to exist, recheck.
        with self.connection.cursor() as cursor:
            tables = self.connection.introspection.table_names(cursor)

        self._has_table = self.Migration._meta.db_table in tables
        return self._has_table

    def ensure_schema(self):
        """Ensure the table exists and has the correct schema."""
        # If the table's there, that's fine - we've never changed its schema
        # in the codebase.
        if self.has_table():
            return
        # Make the table
        try:
            with self.connection.schema_editor() as editor:
                editor.create_model(self.Migration)
        except DatabaseError as exc:
            raise MigrationSchemaMissing(
                "Unable to create the django_migrations table (%s)" % exc
            )

    def applied_migrations(self):
        """
        Return a dict mapping (app_name, migration_name) to Migration instances
        for all applied migrations.
        """
        if self.has_table():
            return {
                (migration.app, migration.name): migration
                for migration in self.migration_qs
            }
        else:
            # If the django_migrations table doesn't exist, then no migrations
            # are applied.
            return {}

    def record_applied(self, app, name):
        """Record that a migration was applied."""
        self.ensure_schema()
        self.migration_qs.create(app=app, name=name)

    def record_unapplied(self, app, name):
        """Record that a migration was unapplied."""
        self.ensure_schema()
        self.migration_qs.filter(app=app, name=name).delete()

    def flush(self):
        """Delete all migration records. Useful for testing migrations."""
        self.migration_qs.all().delete()
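
(Aside: a hedged end-to-end sketch of the recorder API defined above; the app and migration names are invented, and this assumes a configured default database.)

from django.db import connection
from django.db.migrations.recorder import MigrationRecorder

recorder = MigrationRecorder(connection)
recorder.ensure_schema()  # creates django_migrations if it doesn't exist
recorder.record_applied("shop", "0002_add_sku")
assert ("shop", "0002_add_sku") in recorder.applied_migrations()
recorder.record_unapplied("shop", "0002_add_sku")
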
@@ -0,0 +1,405 @@
import builtins
import collections.abc
import datetime
import decimal
import enum
import functools
import math
import os
import pathlib
import re
import types
import uuid

from django.conf import SettingsReference
from django.db import models
from django.db.migrations.operations.base import Operation
from django.db.migrations.utils import COMPILED_REGEX_TYPE, RegexObject
from django.utils.functional import LazyObject, Promise
from django.utils.version import PY311, get_docs_version

FUNCTION_TYPES = (types.FunctionType, types.BuiltinFunctionType, types.MethodType)

if isinstance(functools._lru_cache_wrapper, type):
    # When using CPython's _functools C module, LRU cache function decorators
    # present as a class and not a function, so add that class to the list of
    # function types. In the pure Python implementation and PyPy they present
    # as normal functions which are already handled.
    FUNCTION_TYPES += (functools._lru_cache_wrapper,)


class BaseSerializer:
    def __init__(self, value):
        self.value = value

    def serialize(self):
        raise NotImplementedError(
            "Subclasses of BaseSerializer must implement the serialize() method."
        )


class BaseSequenceSerializer(BaseSerializer):
    def _format(self):
        raise NotImplementedError(
            "Subclasses of BaseSequenceSerializer must implement the _format() method."
        )

    def serialize(self):
        imports = set()
        strings = []
        for item in self.value:
            item_string, item_imports = serializer_factory(item).serialize()
            imports.update(item_imports)
            strings.append(item_string)
        value = self._format()
        return value % (", ".join(strings)), imports


class BaseUnorderedSequenceSerializer(BaseSequenceSerializer):
    def __init__(self, value):
        super().__init__(sorted(value, key=repr))


class BaseSimpleSerializer(BaseSerializer):
    def serialize(self):
        return repr(self.value), set()


class ChoicesSerializer(BaseSerializer):
    def serialize(self):
        return serializer_factory(self.value.value).serialize()


class DateTimeSerializer(BaseSerializer):
    """For datetime.*, except datetime.datetime."""

    def serialize(self):
        return repr(self.value), {"import datetime"}


class DatetimeDatetimeSerializer(BaseSerializer):
    """For datetime.datetime."""

    def serialize(self):
        if self.value.tzinfo is not None and self.value.tzinfo != datetime.timezone.utc:
            self.value = self.value.astimezone(datetime.timezone.utc)
        imports = ["import datetime"]
        return repr(self.value), set(imports)


class DecimalSerializer(BaseSerializer):
    def serialize(self):
        return repr(self.value), {"from decimal import Decimal"}


class DeconstructableSerializer(BaseSerializer):
    @staticmethod
    def serialize_deconstructed(path, args, kwargs):
        name, imports = DeconstructableSerializer._serialize_path(path)
        strings = []
        for arg in args:
            arg_string, arg_imports = serializer_factory(arg).serialize()
            strings.append(arg_string)
            imports.update(arg_imports)
        for kw, arg in sorted(kwargs.items()):
            arg_string, arg_imports = serializer_factory(arg).serialize()
            imports.update(arg_imports)
            strings.append("%s=%s" % (kw, arg_string))
        return "%s(%s)" % (name, ", ".join(strings)), imports

    @staticmethod
    def _serialize_path(path):
        module, name = path.rsplit(".", 1)
        if module == "django.db.models":
            imports = {"from django.db import models"}
            name = "models.%s" % name
        else:
            imports = {"import %s" % module}
            name = path
        return name, imports

    def serialize(self):
        return self.serialize_deconstructed(*self.value.deconstruct())


class DictionarySerializer(BaseSerializer):
    def serialize(self):
        imports = set()
        strings = []
        for k, v in sorted(self.value.items()):
            k_string, k_imports = serializer_factory(k).serialize()
            v_string, v_imports = serializer_factory(v).serialize()
            imports.update(k_imports)
            imports.update(v_imports)
            strings.append((k_string, v_string))
        return "{%s}" % (", ".join("%s: %s" % (k, v) for k, v in strings)), imports


class EnumSerializer(BaseSerializer):
    def serialize(self):
        enum_class = self.value.__class__
        module = enum_class.__module__
        if issubclass(enum_class, enum.Flag):
            if PY311:
                members = list(self.value)
            else:
                members, _ = enum._decompose(enum_class, self.value)
                members = reversed(members)
        else:
            members = (self.value,)
        return (
            " | ".join(
                [
                    f"{module}.{enum_class.__qualname__}[{item.name!r}]"
                    for item in members
                ]
            ),
            {"import %s" % module},
        )


class FloatSerializer(BaseSimpleSerializer):
    def serialize(self):
        if math.isnan(self.value) or math.isinf(self.value):
            return 'float("{}")'.format(self.value), set()
        return super().serialize()


class FrozensetSerializer(BaseUnorderedSequenceSerializer):
    def _format(self):
        return "frozenset([%s])"


class FunctionTypeSerializer(BaseSerializer):
    def serialize(self):
        if getattr(self.value, "__self__", None) and isinstance(
            self.value.__self__, type
        ):
            klass = self.value.__self__
            module = klass.__module__
            return "%s.%s.%s" % (module, klass.__qualname__, self.value.__name__), {
                "import %s" % module
            }
        # Further error checking
        if self.value.__name__ == "<lambda>":
            raise ValueError("Cannot serialize function: lambda")
        if self.value.__module__ is None:
            raise ValueError("Cannot serialize function %r: No module" % self.value)

        module_name = self.value.__module__

        if "<" not in self.value.__qualname__:  # Qualname can include <locals>
            return "%s.%s" % (module_name, self.value.__qualname__), {
                "import %s" % self.value.__module__
            }

        raise ValueError(
            "Could not find function %s in %s.\n" % (self.value.__name__, module_name)
        )


class FunctoolsPartialSerializer(BaseSerializer):
    def serialize(self):
        # Serialize functools.partial() arguments
        func_string, func_imports = serializer_factory(self.value.func).serialize()
        args_string, args_imports = serializer_factory(self.value.args).serialize()
        keywords_string, keywords_imports = serializer_factory(
            self.value.keywords
        ).serialize()
        # Add any imports needed by arguments
        imports = {"import functools", *func_imports, *args_imports, *keywords_imports}
        return (
            "functools.%s(%s, *%s, **%s)"
            % (
                self.value.__class__.__name__,
                func_string,
                args_string,
                keywords_string,
            ),
            imports,
        )


class IterableSerializer(BaseSerializer):
    def serialize(self):
        imports = set()
        strings = []
        for item in self.value:
            item_string, item_imports = serializer_factory(item).serialize()
            imports.update(item_imports)
            strings.append(item_string)
        # When len(strings)==0, the empty iterable should be serialized as
        # "()", not "(,)" because (,) is invalid Python syntax.
        value = "(%s)" if len(strings) != 1 else "(%s,)"
        return value % (", ".join(strings)), imports


class ModelFieldSerializer(DeconstructableSerializer):
    def serialize(self):
        attr_name, path, args, kwargs = self.value.deconstruct()
        return self.serialize_deconstructed(path, args, kwargs)


class ModelManagerSerializer(DeconstructableSerializer):
    def serialize(self):
        as_manager, manager_path, qs_path, args, kwargs = self.value.deconstruct()
        if as_manager:
            name, imports = self._serialize_path(qs_path)
            return "%s.as_manager()" % name, imports
        else:
            return self.serialize_deconstructed(manager_path, args, kwargs)


class OperationSerializer(BaseSerializer):
    def serialize(self):
        from django.db.migrations.writer import OperationWriter

        string, imports = OperationWriter(self.value, indentation=0).serialize()
        # Nested operation, trailing comma is handled in upper OperationWriter._write()
        return string.rstrip(","), imports


class PathLikeSerializer(BaseSerializer):
    def serialize(self):
        return repr(os.fspath(self.value)), {}


class PathSerializer(BaseSerializer):
    def serialize(self):
        # Convert concrete paths to pure paths to avoid issues with migrations
        # generated on one platform being used on a different platform.
        prefix = "Pure" if isinstance(self.value, pathlib.Path) else ""
        return "pathlib.%s%r" % (prefix, self.value), {"import pathlib"}


class RegexSerializer(BaseSerializer):
    def serialize(self):
        regex_pattern, pattern_imports = serializer_factory(
            self.value.pattern
        ).serialize()
        # Turn off default implicit flags (e.g. re.U) because regexes with the
        # same implicit and explicit flags aren't equal.
        flags = self.value.flags ^ re.compile("").flags
        regex_flags, flag_imports = serializer_factory(flags).serialize()
        imports = {"import re", *pattern_imports, *flag_imports}
        args = [regex_pattern]
        if flags:
            args.append(regex_flags)
        return "re.compile(%s)" % ", ".join(args), imports


class SequenceSerializer(BaseSequenceSerializer):
    def _format(self):
        return "[%s]"


class SetSerializer(BaseUnorderedSequenceSerializer):
    def _format(self):
        # Serialize as a set literal except when value is empty because {}
        # is an empty dict.
        return "{%s}" if self.value else "set(%s)"


class SettingsReferenceSerializer(BaseSerializer):
    def serialize(self):
        return "settings.%s" % self.value.setting_name, {
            "from django.conf import settings"
        }


class TupleSerializer(BaseSequenceSerializer):
    def _format(self):
        # When len(value)==0, the empty tuple should be serialized as "()",
        # not "(,)" because (,) is invalid Python syntax.
        return "(%s)" if len(self.value) != 1 else "(%s,)"


class TypeSerializer(BaseSerializer):
    def serialize(self):
        special_cases = [
            (models.Model, "models.Model", ["from django.db import models"]),
            (types.NoneType, "types.NoneType", ["import types"]),
        ]
        for case, string, imports in special_cases:
            if case is self.value:
                return string, set(imports)
        if hasattr(self.value, "__module__"):
            module = self.value.__module__
            if module == builtins.__name__:
                return self.value.__name__, set()
            else:
                return "%s.%s" % (module, self.value.__qualname__), {
                    "import %s" % module
                }


class UUIDSerializer(BaseSerializer):
    def serialize(self):
        return "uuid.%s" % repr(self.value), {"import uuid"}


class Serializer:
    _registry = {
        # Some of these are order-dependent.
        frozenset: FrozensetSerializer,
        list: SequenceSerializer,
        set: SetSerializer,
        tuple: TupleSerializer,
        dict: DictionarySerializer,
        models.Choices: ChoicesSerializer,
        enum.Enum: EnumSerializer,
        datetime.datetime: DatetimeDatetimeSerializer,
        (datetime.date, datetime.timedelta, datetime.time): DateTimeSerializer,
        SettingsReference: SettingsReferenceSerializer,
        float: FloatSerializer,
        (bool, int, types.NoneType, bytes, str, range): BaseSimpleSerializer,
        decimal.Decimal: DecimalSerializer,
        (functools.partial, functools.partialmethod): FunctoolsPartialSerializer,
        FUNCTION_TYPES: FunctionTypeSerializer,
        collections.abc.Iterable: IterableSerializer,
        (COMPILED_REGEX_TYPE, RegexObject): RegexSerializer,
        uuid.UUID: UUIDSerializer,
        pathlib.PurePath: PathSerializer,
        os.PathLike: PathLikeSerializer,
    }

    @classmethod
    def register(cls, type_, serializer):
        if not issubclass(serializer, BaseSerializer):
            raise ValueError(
                "'%s' must inherit from 'BaseSerializer'." % serializer.__name__
            )
        cls._registry[type_] = serializer

    @classmethod
    def unregister(cls, type_):
        cls._registry.pop(type_)


def serializer_factory(value):
    if isinstance(value, Promise):
        value = str(value)
    elif isinstance(value, LazyObject):
        # The unwrapped value is returned as the first item of the arguments
        # tuple.
        value = value.__reduce__()[1][0]

    if isinstance(value, models.Field):
        return ModelFieldSerializer(value)
    if isinstance(value, models.manager.BaseManager):
        return ModelManagerSerializer(value)
    if isinstance(value, Operation):
        return OperationSerializer(value)
    if isinstance(value, type):
        return TypeSerializer(value)
    # Anything that knows how to deconstruct itself.
    if hasattr(value, "deconstruct"):
        return DeconstructableSerializer(value)
    for type_, serializer_cls in Serializer._registry.items():
        if isinstance(value, type_):
            return serializer_cls(value)
    raise ValueError(
        "Cannot serialize: %r\nThere are some values Django cannot serialize into "
        "migration files.\nFor more, see https://docs.djangoproject.com/en/%s/"
        "topics/migrations/#migration-serializing" % (value, get_docs_version())
    )
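
(Aside: to make the registry concrete, a sketch of serializing a value and of registering a custom serializer. The Money type and its import path are invented for illustration.)

import decimal

from django.db.migrations.serializer import (
    BaseSerializer,
    Serializer,
    serializer_factory,
)

string, imports = serializer_factory(decimal.Decimal("1.5")).serialize()
# string == "Decimal('1.5')"; imports == {"from decimal import Decimal"}


class Money:  # invented example type
    def __init__(self, amount):
        self.amount = amount


class MoneySerializer(BaseSerializer):
    def serialize(self):
        return "Money(%r)" % self.value.amount, {"from myapp.money import Money"}


Serializer.register(Money, MoneySerializer)
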
1017
.venv/lib/python3.10/site-packages/django/db/migrations/state.py
Normal file
File diff suppressed because it is too large
129
.venv/lib/python3.10/site-packages/django/db/migrations/utils.py
Normal file
@@ -0,0 +1,129 @@
import datetime
import re
from collections import namedtuple

from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT

FieldReference = namedtuple("FieldReference", "to through")

COMPILED_REGEX_TYPE = type(re.compile(""))


class RegexObject:
    def __init__(self, obj):
        self.pattern = obj.pattern
        self.flags = obj.flags

    def __eq__(self, other):
        if not isinstance(other, RegexObject):
            return NotImplemented
        return self.pattern == other.pattern and self.flags == other.flags


def get_migration_name_timestamp():
    return datetime.datetime.now().strftime("%Y%m%d_%H%M")
def resolve_relation(model, app_label=None, model_name=None):
    """
    Take a model class or a model reference string and return a model tuple.

    app_label and model_name are used to resolve the scope of recursive and
    unscoped model relationships.
    """
    if isinstance(model, str):
        if model == RECURSIVE_RELATIONSHIP_CONSTANT:
            if app_label is None or model_name is None:
                raise TypeError(
                    "app_label and model_name must be provided to resolve "
                    "recursive relationships."
                )
            return app_label, model_name
        if "." in model:
            app_label, model_name = model.split(".", 1)
            return app_label, model_name.lower()
        if app_label is None:
            raise TypeError(
                "app_label must be provided to resolve unscoped model relationships."
            )
        return app_label, model.lower()
    return model._meta.app_label, model._meta.model_name
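
(Aside: the three string cases above in use; app and model names are invented.)

from django.db.migrations.utils import resolve_relation

# Dotted reference: split into (app_label, model_name), model lowercased.
assert resolve_relation("shop.Product") == ("shop", "product")
# Unscoped reference: resolved against the provided app_label.
assert resolve_relation("Product", app_label="shop") == ("shop", "product")
# Recursive ("self") reference: resolved to the owning model.
assert resolve_relation("self", app_label="shop", model_name="product") == (
    "shop",
    "product",
)
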
def field_references(
    model_tuple,
    field,
    reference_model_tuple,
    reference_field_name=None,
    reference_field=None,
):
    """
    Return either False or a FieldReference if `field` references provided
    context.

    False positives can be returned if `reference_field_name` is provided
    without `reference_field` because of the introspection limitation it
    incurs. This should not be an issue when this function is used to determine
    whether or not an optimization can take place.
    """
    remote_field = field.remote_field
    if not remote_field:
        return False
    references_to = None
    references_through = None
    if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
        to_fields = getattr(field, "to_fields", None)
        if (
            reference_field_name is None
            or
            # Unspecified to_field(s).
            to_fields is None
            or
            # Reference to primary key.
            (
                None in to_fields
                and (reference_field is None or reference_field.primary_key)
            )
            or
            # Reference to field.
            reference_field_name in to_fields
        ):
            references_to = (remote_field, to_fields)
    through = getattr(remote_field, "through", None)
    if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
        through_fields = remote_field.through_fields
        if (
            reference_field_name is None
            or
            # Unspecified through_fields.
            through_fields is None
            or
            # Reference to field.
            reference_field_name in through_fields
        ):
            references_through = (remote_field, through_fields)
    if not (references_to or references_through):
        return False
    return FieldReference(references_to, references_through)


def get_references(state, model_tuple, field_tuple=()):
    """
    Generator of (model_state, name, field, reference) referencing
    provided context.

    If field_tuple is provided only references to this particular field of
    model_tuple will be generated.
    """
    for state_model_tuple, model_state in state.models.items():
        for name, field in model_state.fields.items():
            reference = field_references(
                state_model_tuple, field, model_tuple, *field_tuple
            )
            if reference:
                yield model_state, name, field, reference


def field_is_referenced(state, model_tuple, field_tuple):
    """Return whether `field_tuple` is referenced by any state models."""
    return next(get_references(state, model_tuple, field_tuple), None) is not None
@@ -0,0 +1,316 @@
import os
import re
from importlib import import_module

from django import get_version
from django.apps import apps

# SettingsReference imported for backwards compatibility in Django 2.2.
from django.conf import SettingsReference  # NOQA
from django.db import migrations
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.serializer import Serializer, serializer_factory
from django.utils.inspect import get_func_args
from django.utils.module_loading import module_dir
from django.utils.timezone import now


class OperationWriter:
    def __init__(self, operation, indentation=2):
        self.operation = operation
        self.buff = []
        self.indentation = indentation

    def serialize(self):
        def _write(_arg_name, _arg_value):
            if _arg_name in self.operation.serialization_expand_args and isinstance(
                _arg_value, (list, tuple, dict)
            ):
                if isinstance(_arg_value, dict):
                    self.feed("%s={" % _arg_name)
                    self.indent()
                    for key, value in _arg_value.items():
                        key_string, key_imports = MigrationWriter.serialize(key)
                        arg_string, arg_imports = MigrationWriter.serialize(value)
                        args = arg_string.splitlines()
                        if len(args) > 1:
                            self.feed("%s: %s" % (key_string, args[0]))
                            for arg in args[1:-1]:
                                self.feed(arg)
                            self.feed("%s," % args[-1])
                        else:
                            self.feed("%s: %s," % (key_string, arg_string))
                        imports.update(key_imports)
                        imports.update(arg_imports)
                    self.unindent()
                    self.feed("},")
                else:
                    self.feed("%s=[" % _arg_name)
                    self.indent()
                    for item in _arg_value:
                        arg_string, arg_imports = MigrationWriter.serialize(item)
                        args = arg_string.splitlines()
                        if len(args) > 1:
                            for arg in args[:-1]:
                                self.feed(arg)
                            self.feed("%s," % args[-1])
                        else:
                            self.feed("%s," % arg_string)
                        imports.update(arg_imports)
                    self.unindent()
                    self.feed("],")
            else:
                arg_string, arg_imports = MigrationWriter.serialize(_arg_value)
                args = arg_string.splitlines()
                if len(args) > 1:
                    self.feed("%s=%s" % (_arg_name, args[0]))
                    for arg in args[1:-1]:
                        self.feed(arg)
                    self.feed("%s," % args[-1])
                else:
                    self.feed("%s=%s," % (_arg_name, arg_string))
                imports.update(arg_imports)

        imports = set()
        name, args, kwargs = self.operation.deconstruct()
        operation_args = get_func_args(self.operation.__init__)

        # See if this operation is in django.db.migrations. If it is,
        # We can just use the fact we already have that imported,
        # otherwise, we need to add an import for the operation class.
        if getattr(migrations, name, None) == self.operation.__class__:
            self.feed("migrations.%s(" % name)
        else:
            imports.add("import %s" % (self.operation.__class__.__module__))
            self.feed("%s.%s(" % (self.operation.__class__.__module__, name))

        self.indent()

        for i, arg in enumerate(args):
            arg_value = arg
            arg_name = operation_args[i]
            _write(arg_name, arg_value)

        i = len(args)
        # Only iterate over remaining arguments
        for arg_name in operation_args[i:]:
            if arg_name in kwargs:  # Don't sort to maintain signature order
                arg_value = kwargs[arg_name]
                _write(arg_name, arg_value)

        self.unindent()
        self.feed("),")
        return self.render(), imports

    def indent(self):
        self.indentation += 1

    def unindent(self):
        self.indentation -= 1

    def feed(self, line):
        self.buff.append(" " * (self.indentation * 4) + line)

    def render(self):
        return "\n".join(self.buff)
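
(Aside: what serialize() produces for a simple operation at the default indentation of 2; roughly, the operation rendered as it would appear inside a migration's operations list. The model name is invented.)

from django.db import migrations
from django.db.migrations.writer import OperationWriter

string, imports = OperationWriter(migrations.DeleteModel("Tag")).serialize()
# string is roughly:
#         migrations.DeleteModel(
#             name='Tag',
#         ),
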
class MigrationWriter:
    """
    Take a Migration instance and produce the contents of its migration
    file.
    """
    def __init__(self, migration, include_header=True):
        self.migration = migration
        self.include_header = include_header
        self.needs_manual_porting = False

    def as_string(self):
        """Return a string of the file contents."""
        items = {
            "replaces_str": "",
            "initial_str": "",
        }

        imports = set()

        # Deconstruct operations
        operations = []
        for operation in self.migration.operations:
            operation_string, operation_imports = OperationWriter(operation).serialize()
            imports.update(operation_imports)
            operations.append(operation_string)
        items["operations"] = "\n".join(operations) + "\n" if operations else ""

        # Format dependencies and write out swappable dependencies right
        dependencies = []
        for dependency in self.migration.dependencies:
            if dependency[0] == "__setting__":
                dependencies.append(
                    "        migrations.swappable_dependency(settings.%s),"
                    % dependency[1]
                )
                imports.add("from django.conf import settings")
            else:
                dependencies.append("        %s," % self.serialize(dependency)[0])
        items["dependencies"] = (
            "\n".join(sorted(dependencies)) + "\n" if dependencies else ""
        )

        # Format imports nicely, swapping imports of functions from migration files
        # for comments
        migration_imports = set()
        for line in list(imports):
            if re.match(r"^import (.*)\.\d+[^\s]*$", line):
                migration_imports.add(line.split("import")[1].strip())
                imports.remove(line)
                self.needs_manual_porting = True

        # django.db.migrations is always used, but models import may not be.
        # If models import exists, merge it with migrations import.
        if "from django.db import models" in imports:
            imports.discard("from django.db import models")
            imports.add("from django.db import migrations, models")
        else:
            imports.add("from django.db import migrations")

        # Sort imports by the package / module to be imported (the part after
        # "from" in "from ... import ..." or after "import" in "import ...").
        # First group the "import" statements, then "from ... import ...".
        sorted_imports = sorted(
            imports, key=lambda i: (i.split()[0] == "from", i.split()[1])
        )
        items["imports"] = "\n".join(sorted_imports) + "\n" if imports else ""
        if migration_imports:
            items["imports"] += (
                "\n\n# Functions from the following migrations need manual "
                "copying.\n# Move them and any dependencies into this file, "
                "then update the\n# RunPython operations to refer to the local "
                "versions:\n# %s"
            ) % "\n# ".join(sorted(migration_imports))
        # If there's a replaces, make a string for it
        if self.migration.replaces:
            items["replaces_str"] = (
                "\n    replaces = %s\n" % self.serialize(self.migration.replaces)[0]
            )
        # Hinting that goes into comment
        if self.include_header:
            items["migration_header"] = MIGRATION_HEADER_TEMPLATE % {
                "version": get_version(),
                "timestamp": now().strftime("%Y-%m-%d %H:%M"),
            }
        else:
            items["migration_header"] = ""

        if self.migration.initial:
            items["initial_str"] = "\n    initial = True\n"

        return MIGRATION_TEMPLATE % items

    @property
    def basedir(self):
        migrations_package_name, _ = MigrationLoader.migrations_module(
            self.migration.app_label
        )

        if migrations_package_name is None:
            raise ValueError(
                "Django can't create migrations for app '%s' because "
                "migrations have been disabled via the MIGRATION_MODULES "
                "setting." % self.migration.app_label
            )

        # See if we can import the migrations module directly
        try:
            migrations_module = import_module(migrations_package_name)
        except ImportError:
            pass
        else:
            try:
                return module_dir(migrations_module)
            except ValueError:
                pass

        # Alright, see if it's a direct submodule of the app
        app_config = apps.get_app_config(self.migration.app_label)
        (
            maybe_app_name,
            _,
            migrations_package_basename,
        ) = migrations_package_name.rpartition(".")
        if app_config.name == maybe_app_name:
            return os.path.join(app_config.path, migrations_package_basename)

        # In case of using MIGRATION_MODULES setting and the custom package
        # doesn't exist, create one, starting from an existing package
        existing_dirs, missing_dirs = migrations_package_name.split("."), []
        while existing_dirs:
            missing_dirs.insert(0, existing_dirs.pop(-1))
            try:
                base_module = import_module(".".join(existing_dirs))
            except (ImportError, ValueError):
                continue
            else:
                try:
                    base_dir = module_dir(base_module)
                except ValueError:
                    continue
                else:
                    break
        else:
            raise ValueError(
                "Could not locate an appropriate location to create "
                "migrations package %s. Make sure the toplevel "
                "package exists and can be imported." % migrations_package_name
            )

        final_dir = os.path.join(base_dir, *missing_dirs)
        os.makedirs(final_dir, exist_ok=True)
        for missing_dir in missing_dirs:
            base_dir = os.path.join(base_dir, missing_dir)
            with open(os.path.join(base_dir, "__init__.py"), "w"):
                pass

        return final_dir

    @property
    def filename(self):
        return "%s.py" % self.migration.name

    @property
    def path(self):
        return os.path.join(self.basedir, self.filename)

    @classmethod
    def serialize(cls, value):
        return serializer_factory(value).serialize()

    @classmethod
    def register_serializer(cls, type_, serializer):
        Serializer.register(type_, serializer)

    @classmethod
    def unregister_serializer(cls, type_):
        Serializer.unregister(type_)


MIGRATION_HEADER_TEMPLATE = """\
# Generated by Django %(version)s on %(timestamp)s

"""


MIGRATION_TEMPLATE = """\
%(migration_header)s%(imports)s

class Migration(migrations.Migration):
%(replaces_str)s%(initial_str)s
    dependencies = [
%(dependencies)s\
    ]

    operations = [
%(operations)s\
    ]
"""
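
(Aside: a minimal sketch of rendering a migration file with the class above; the app label is invented and the migration has no operations, so the rendered lists come out empty.)

from django.db import migrations
from django.db.migrations.writer import MigrationWriter

migration = migrations.Migration("0001_initial", "shop")
writer = MigrationWriter(migration, include_header=False)
source = writer.as_string()  # text destined for shop/migrations/0001_initial.py
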
127
.venv/lib/python3.10/site-packages/django/db/models/__init__.py
Normal file
@@ -0,0 +1,127 @@
from django.core.exceptions import ObjectDoesNotExist
from django.db.models import signals
from django.db.models.aggregates import *  # NOQA
from django.db.models.aggregates import __all__ as aggregates_all
from django.db.models.constraints import *  # NOQA
from django.db.models.constraints import __all__ as constraints_all
from django.db.models.deletion import (
    CASCADE,
    DO_NOTHING,
    PROTECT,
    RESTRICT,
    SET,
    SET_DEFAULT,
    SET_NULL,
    ProtectedError,
    RestrictedError,
)
from django.db.models.enums import *  # NOQA
from django.db.models.enums import __all__ as enums_all
from django.db.models.expressions import (
    Case,
    Exists,
    Expression,
    ExpressionList,
    ExpressionWrapper,
    F,
    Func,
    OrderBy,
    OuterRef,
    RowRange,
    Subquery,
    Value,
    ValueRange,
    When,
    Window,
    WindowFrame,
    WindowFrameExclusion,
)
from django.db.models.fields import *  # NOQA
from django.db.models.fields import __all__ as fields_all
from django.db.models.fields.composite import CompositePrimaryKey
from django.db.models.fields.files import FileField, ImageField
from django.db.models.fields.generated import GeneratedField
from django.db.models.fields.json import JSONField
from django.db.models.fields.proxy import OrderWrt
from django.db.models.indexes import *  # NOQA
from django.db.models.indexes import __all__ as indexes_all
from django.db.models.lookups import Lookup, Transform
from django.db.models.manager import Manager
from django.db.models.query import (
    Prefetch,
    QuerySet,
    aprefetch_related_objects,
    prefetch_related_objects,
)
from django.db.models.query_utils import FilteredRelation, Q

# Imports that would create circular imports if sorted
from django.db.models.base import DEFERRED, Model  # isort:skip
from django.db.models.fields.related import (  # isort:skip
    ForeignKey,
    ForeignObject,
    OneToOneField,
    ManyToManyField,
    ForeignObjectRel,
    ManyToOneRel,
    ManyToManyRel,
    OneToOneRel,
)


__all__ = aggregates_all + constraints_all + enums_all + fields_all + indexes_all
__all__ += [
    "ObjectDoesNotExist",
    "signals",
    "CASCADE",
    "DO_NOTHING",
    "PROTECT",
    "RESTRICT",
    "SET",
    "SET_DEFAULT",
    "SET_NULL",
    "ProtectedError",
    "RestrictedError",
    "Case",
    "CompositePrimaryKey",
    "Exists",
    "Expression",
    "ExpressionList",
    "ExpressionWrapper",
    "F",
    "Func",
    "OrderBy",
    "OuterRef",
    "RowRange",
    "Subquery",
    "Value",
    "ValueRange",
    "When",
    "Window",
    "WindowFrame",
    "WindowFrameExclusion",
    "FileField",
    "ImageField",
    "GeneratedField",
    "JSONField",
    "OrderWrt",
    "Lookup",
    "Transform",
    "Manager",
    "Prefetch",
    "Q",
    "QuerySet",
    "aprefetch_related_objects",
    "prefetch_related_objects",
    "DEFERRED",
    "Model",
    "FilteredRelation",
    "ForeignKey",
    "ForeignObject",
    "OneToOneField",
    "ManyToManyField",
    "ForeignObjectRel",
    "ManyToOneRel",
    "ManyToManyRel",
    "OneToOneRel",
]
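Everything re-exported above forms the single public django.db.models namespace; a minimal sketch of how it is normally consumed (the Author/Article models are hypothetical):

from django.db import models


class Author(models.Model):
    name = models.CharField(max_length=100)


class Article(models.Model):
    # ForeignKey and CASCADE both come from the re-exports above.
    author = models.ForeignKey(Author, on_delete=models.CASCADE)
    title = models.CharField(max_length=200)
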
237
.venv/lib/python3.10/site-packages/django/db/models/aggregates.py
Normal file
@@ -0,0 +1,237 @@
"""
Classes to represent the definitions of aggregate functions.
"""

from django.core.exceptions import FieldError, FullResultSet
from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
from django.db.models.fields import IntegerField
from django.db.models.functions import Coalesce
from django.db.models.functions.mixins import (
    FixDurationInputMixin,
    NumericOutputFieldMixin,
)

__all__ = [
    "Aggregate",
    "Avg",
    "Count",
    "Max",
    "Min",
    "StdDev",
    "Sum",
    "Variance",
]


class Aggregate(Func):
    template = "%(function)s(%(distinct)s%(expressions)s)"
    contains_aggregate = True
    name = None
    filter_template = "%s FILTER (WHERE %%(filter)s)"
    window_compatible = True
    allow_distinct = False
    empty_result_set_value = None

    def __init__(
        self, *expressions, distinct=False, filter=None, default=None, **extra
    ):
        if distinct and not self.allow_distinct:
            raise TypeError("%s does not allow distinct." % self.__class__.__name__)
        if default is not None and self.empty_result_set_value is not None:
            raise TypeError(f"{self.__class__.__name__} does not allow default.")
        self.distinct = distinct
        self.filter = filter
        self.default = default
        super().__init__(*expressions, **extra)

    def get_source_fields(self):
        # Don't return the filter expression since it's not a source field.
        return [e._output_field_or_none for e in super().get_source_expressions()]

    def get_source_expressions(self):
        source_expressions = super().get_source_expressions()
        return source_expressions + [self.filter]

    def set_source_expressions(self, exprs):
        *exprs, self.filter = exprs
        return super().set_source_expressions(exprs)

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        # Aggregates are not allowed in UPDATE queries, so ignore for_save.
        c = super().resolve_expression(query, allow_joins, reuse, summarize)
        c.filter = (
            c.filter.resolve_expression(query, allow_joins, reuse, summarize)
            if c.filter
            else None
        )
        if summarize:
            # Summarized aggregates cannot refer to summarized aggregates.
            for ref in c.get_refs():
                if query.annotations[ref].is_summary:
                    raise FieldError(
                        f"Cannot compute {c.name}('{ref}'): '{ref}' is an aggregate"
                    )
        elif not self.is_summary:
            # Call Aggregate.get_source_expressions() to avoid
            # returning self.filter and including that in this loop.
            expressions = super(Aggregate, c).get_source_expressions()
            for index, expr in enumerate(expressions):
                if expr.contains_aggregate:
                    before_resolved = self.get_source_expressions()[index]
                    name = (
                        before_resolved.name
                        if hasattr(before_resolved, "name")
                        else repr(before_resolved)
                    )
                    raise FieldError(
                        "Cannot compute %s('%s'): '%s' is an aggregate"
                        % (c.name, name, name)
                    )
        if (default := c.default) is None:
            return c
        if hasattr(default, "resolve_expression"):
            default = default.resolve_expression(query, allow_joins, reuse, summarize)
            if default._output_field_or_none is None:
                default.output_field = c._output_field_or_none
        else:
            default = Value(default, c._output_field_or_none)
        c.default = None  # Reset the default argument before wrapping.
        coalesce = Coalesce(c, default, output_field=c._output_field_or_none)
        coalesce.is_summary = c.is_summary
        return coalesce

    @property
    def default_alias(self):
        expressions = [
            expr for expr in self.get_source_expressions() if expr is not None
        ]
        if len(expressions) == 1 and hasattr(expressions[0], "name"):
            return "%s__%s" % (expressions[0].name, self.name.lower())
        raise TypeError("Complex expressions require an alias")

    def get_group_by_cols(self):
        return []

    def as_sql(self, compiler, connection, **extra_context):
        extra_context["distinct"] = "DISTINCT " if self.distinct else ""
        if self.filter:
            if connection.features.supports_aggregate_filter_clause:
                try:
                    filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                except FullResultSet:
                    pass
                else:
                    template = self.filter_template % extra_context.get(
                        "template", self.template
                    )
                    sql, params = super().as_sql(
                        compiler,
                        connection,
                        template=template,
                        filter=filter_sql,
                        **extra_context,
                    )
                    return sql, (*params, *filter_params)
            else:
                copy = self.copy()
                copy.filter = None
                source_expressions = copy.get_source_expressions()
                condition = When(self.filter, then=source_expressions[0])
                copy.set_source_expressions([Case(condition)] + source_expressions[1:])
                return super(Aggregate, copy).as_sql(
                    compiler, connection, **extra_context
                )
        return super().as_sql(compiler, connection, **extra_context)

    def _get_repr_options(self):
        options = super()._get_repr_options()
        if self.distinct:
            options["distinct"] = self.distinct
        if self.filter:
            options["filter"] = self.filter
        return options


class Avg(FixDurationInputMixin, NumericOutputFieldMixin, Aggregate):
    function = "AVG"
    name = "Avg"
    allow_distinct = True
    arity = 1


class Count(Aggregate):
    function = "COUNT"
    name = "Count"
    output_field = IntegerField()
    allow_distinct = True
    empty_result_set_value = 0
    arity = 1
    allows_composite_expressions = True

    def __init__(self, expression, filter=None, **extra):
        if expression == "*":
            expression = Star()
        if isinstance(expression, Star) and filter is not None:
            raise ValueError("Star cannot be used with filter. Please specify a field.")
        super().__init__(expression, filter=filter, **extra)

    def resolve_expression(self, *args, **kwargs):
        result = super().resolve_expression(*args, **kwargs)
        expr = result.source_expressions[0]

        # In case of composite primary keys, count the first column.
        if isinstance(expr, ColPairs):
            if self.distinct:
                raise ValueError(
                    "COUNT(DISTINCT) doesn't support composite primary keys"
                )
            cols = expr.get_cols()
            return Count(cols[0], filter=result.filter)

        return result


class Max(Aggregate):
    function = "MAX"
    name = "Max"
    arity = 1


class Min(Aggregate):
    function = "MIN"
    name = "Min"
    arity = 1


class StdDev(NumericOutputFieldMixin, Aggregate):
    name = "StdDev"
    arity = 1

    def __init__(self, expression, sample=False, **extra):
        self.function = "STDDEV_SAMP" if sample else "STDDEV_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        return {**super()._get_repr_options(), "sample": self.function == "STDDEV_SAMP"}


class Sum(FixDurationInputMixin, Aggregate):
    function = "SUM"
    name = "Sum"
    allow_distinct = True
    arity = 1


class Variance(NumericOutputFieldMixin, Aggregate):
    name = "Variance"
    arity = 1

    def __init__(self, expression, sample=False, **extra):
        self.function = "VAR_SAMP" if sample else "VAR_POP"
        super().__init__(expression, **extra)

    def _get_repr_options(self):
        return {**super()._get_repr_options(), "sample": self.function == "VAR_SAMP"}
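A usage sketch of the machinery above: filter= flows through as_sql() (a FILTER clause, or the CASE/WHEN fallback on backends without one), default= is rewritten into Coalesce() by resolve_expression(), and a lone field expression gets a default_alias such as price__avg. The Book model and its fields are hypothetical:

from django.db.models import Avg, Count, Q

Book.objects.aggregate(
    Avg("price"),                                  # keyed "price__avg" via default_alias
    total=Count("id"),
    recent=Count("id", filter=Q(year__gte=2020)),  # FILTER (WHERE ...) where supported
    avg_price=Avg("price", default=0),             # wrapped as Coalesce(Avg("price"), 0)
)
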
2582
.venv/lib/python3.10/site-packages/django/db/models/base.py
Normal file
File diff suppressed because it is too large
13
.venv/lib/python3.10/site-packages/django/db/models/constants.py
Normal file
@@ -0,0 +1,13 @@
"""
Constants used across the ORM in general.
"""

from enum import Enum

# Separator used to split filter strings apart.
LOOKUP_SEP = "__"


class OnConflict(Enum):
    IGNORE = "ignore"
    UPDATE = "update"
728
.venv/lib/python3.10/site-packages/django/db/models/constraints.py
Normal file
@@ -0,0 +1,728 @@
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from types import NoneType
|
||||
|
||||
from django.core import checks
|
||||
from django.core.exceptions import FieldDoesNotExist, FieldError, ValidationError
|
||||
from django.db import connections
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import Exists, ExpressionList, F, RawSQL
|
||||
from django.db.models.indexes import IndexExpression
|
||||
from django.db.models.lookups import Exact, IsNull
|
||||
from django.db.models.query_utils import Q
|
||||
from django.db.models.sql.query import Query
|
||||
from django.db.utils import DEFAULT_DB_ALIAS
|
||||
from django.utils.deprecation import RemovedInDjango60Warning
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
__all__ = ["BaseConstraint", "CheckConstraint", "Deferrable", "UniqueConstraint"]
|
||||
|
||||
|
||||
class BaseConstraint:
|
||||
default_violation_error_message = _("Constraint “%(name)s” is violated.")
|
||||
violation_error_code = None
|
||||
violation_error_message = None
|
||||
|
||||
non_db_attrs = ("violation_error_code", "violation_error_message")
|
||||
|
||||
# RemovedInDjango60Warning: When the deprecation ends, replace with:
|
||||
# def __init__(
|
||||
# self, *, name, violation_error_code=None, violation_error_message=None
|
||||
# ):
|
||||
def __init__(
|
||||
self, *args, name=None, violation_error_code=None, violation_error_message=None
|
||||
):
|
||||
# RemovedInDjango60Warning.
|
||||
if name is None and not args:
|
||||
raise TypeError(
|
||||
f"{self.__class__.__name__}.__init__() missing 1 required keyword-only "
|
||||
f"argument: 'name'"
|
||||
)
|
||||
self.name = name
|
||||
if violation_error_code is not None:
|
||||
self.violation_error_code = violation_error_code
|
||||
if violation_error_message is not None:
|
||||
self.violation_error_message = violation_error_message
|
||||
else:
|
||||
self.violation_error_message = self.default_violation_error_message
|
||||
# RemovedInDjango60Warning.
|
||||
if args:
|
||||
warnings.warn(
|
||||
f"Passing positional arguments to {self.__class__.__name__} is "
|
||||
f"deprecated.",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
for arg, attr in zip(args, ["name", "violation_error_message"]):
|
||||
if arg:
|
||||
setattr(self, attr, arg)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return False
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
@classmethod
|
||||
def _expression_refs_exclude(cls, model, expression, exclude):
|
||||
get_field = model._meta.get_field
|
||||
for field_name, *__ in model._get_expr_references(expression):
|
||||
if field_name in exclude:
|
||||
return True
|
||||
field = get_field(field_name)
|
||||
if field.generated and cls._expression_refs_exclude(
|
||||
model, field.expression, exclude
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
raise NotImplementedError("This method must be implemented by a subclass.")
|
||||
|
||||
def get_violation_error_message(self):
|
||||
return self.violation_error_message % {"name": self.name}
|
||||
|
||||
def _check(self, model, connection):
|
||||
return []
|
||||
|
||||
def _check_references(self, model, references):
|
||||
from django.db.models.fields.composite import CompositePrimaryKey
|
||||
|
||||
errors = []
|
||||
fields = set()
|
||||
for field_name, *lookups in references:
|
||||
# pk is an alias that won't be found by opts.get_field().
|
||||
if field_name != "pk" or isinstance(model._meta.pk, CompositePrimaryKey):
|
||||
fields.add(field_name)
|
||||
if not lookups:
|
||||
# If it has no lookups it cannot result in a JOIN.
|
||||
continue
|
||||
try:
|
||||
if field_name == "pk":
|
||||
field = model._meta.pk
|
||||
else:
|
||||
field = model._meta.get_field(field_name)
|
||||
if not field.is_relation or field.many_to_many or field.one_to_many:
|
||||
continue
|
||||
except FieldDoesNotExist:
|
||||
continue
|
||||
# JOIN must happen at the first lookup.
|
||||
first_lookup = lookups[0]
|
||||
if (
|
||||
hasattr(field, "get_transform")
|
||||
and hasattr(field, "get_lookup")
|
||||
and field.get_transform(first_lookup) is None
|
||||
and field.get_lookup(first_lookup) is None
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"'constraints' refers to the joined field '%s'."
|
||||
% LOOKUP_SEP.join([field_name] + lookups),
|
||||
obj=model,
|
||||
id="models.E041",
|
||||
)
|
||||
)
|
||||
errors.extend(model._check_local_fields(fields, "constraints"))
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
|
||||
path = path.replace("django.db.models.constraints", "django.db.models")
|
||||
kwargs = {"name": self.name}
|
||||
if (
|
||||
self.violation_error_message is not None
|
||||
and self.violation_error_message != self.default_violation_error_message
|
||||
):
|
||||
kwargs["violation_error_message"] = self.violation_error_message
|
||||
if self.violation_error_code is not None:
|
||||
kwargs["violation_error_code"] = self.violation_error_code
|
||||
return (path, (), kwargs)
|
||||
|
||||
def clone(self):
|
||||
_, args, kwargs = self.deconstruct()
|
||||
return self.__class__(*args, **kwargs)
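deconstruct() is also what the migration writer serializes, so clone() is just a deconstruct/rebuild round trip; a sketch (the constraint name is hypothetical):

from django.db.models import UniqueConstraint

c = UniqueConstraint(fields=["email"], name="uniq_email")
path, args, kwargs = c.deconstruct()
# path == "django.db.models.UniqueConstraint"
# kwargs == {"name": "uniq_email", "fields": ("email",)}
assert c == c.clone()
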
|
||||
|
||||
|
||||
class CheckConstraint(BaseConstraint):
|
||||
# RemovedInDjango60Warning: when the deprecation ends, replace with
|
||||
# def __init__(
|
||||
# self, *, condition, name, violation_error_code=None, violation_error_message=None
|
||||
# )
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name,
|
||||
condition=None,
|
||||
check=None,
|
||||
violation_error_code=None,
|
||||
violation_error_message=None,
|
||||
):
|
||||
if check is not None:
|
||||
warnings.warn(
|
||||
"CheckConstraint.check is deprecated in favor of `.condition`.",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
condition = check
|
||||
self.condition = condition
|
||||
if not getattr(condition, "conditional", False):
|
||||
raise TypeError(
|
||||
"CheckConstraint.condition must be a Q instance or boolean expression."
|
||||
)
|
||||
super().__init__(
|
||||
name=name,
|
||||
violation_error_code=violation_error_code,
|
||||
violation_error_message=violation_error_message,
|
||||
)
|
||||
|
||||
def _get_check(self):
|
||||
warnings.warn(
|
||||
"CheckConstraint.check is deprecated in favor of `.condition`.",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.condition
|
||||
|
||||
def _set_check(self, value):
|
||||
warnings.warn(
|
||||
"CheckConstraint.check is deprecated in favor of `.condition`.",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self.condition = value
|
||||
|
||||
check = property(_get_check, _set_check)
|
||||
|
||||
def _check(self, model, connection):
|
||||
errors = []
|
||||
if not (
|
||||
connection.features.supports_table_check_constraints
|
||||
or "supports_table_check_constraints" in model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"{connection.display_name} does not support check constraints.",
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=model,
|
||||
id="models.W027",
|
||||
)
|
||||
)
|
||||
elif (
|
||||
connection.features.supports_table_check_constraints
|
||||
or "supports_table_check_constraints"
|
||||
not in model._meta.required_db_features
|
||||
):
|
||||
references = set()
|
||||
condition = self.condition
|
||||
if isinstance(condition, Q):
|
||||
references.update(model._get_expr_references(condition))
|
||||
if any(isinstance(expr, RawSQL) for expr in condition.flatten()):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"Check constraint {self.name!r} contains RawSQL() expression "
|
||||
"and won't be validated during the model full_clean().",
|
||||
hint="Silence this warning if you don't care about it.",
|
||||
obj=model,
|
||||
id="models.W045",
|
||||
),
|
||||
)
|
||||
errors.extend(self._check_references(model, references))
|
||||
return errors
|
||||
|
||||
def _get_check_sql(self, model, schema_editor):
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._check_sql(self.name, check)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
check = self._get_check_sql(model, schema_editor)
|
||||
return schema_editor._create_check_sql(model, self.name, check)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
return schema_editor._delete_check_sql(model, self.name)
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
against = instance._get_field_expression_map(meta=model._meta, exclude=exclude)
|
||||
try:
|
||||
if not Q(self.condition).check(against, using=using):
|
||||
raise ValidationError(
|
||||
self.get_violation_error_message(), code=self.violation_error_code
|
||||
)
|
||||
except FieldError:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: condition=%s name=%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
self.condition,
|
||||
repr(self.name),
|
||||
(
|
||||
""
|
||||
if self.violation_error_code is None
|
||||
else " violation_error_code=%r" % self.violation_error_code
|
||||
),
|
||||
(
|
||||
""
|
||||
if self.violation_error_message is None
|
||||
or self.violation_error_message == self.default_violation_error_message
|
||||
else " violation_error_message=%r" % self.violation_error_message
|
||||
),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, CheckConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.condition == other.condition
|
||||
and self.violation_error_code == other.violation_error_code
|
||||
and self.violation_error_message == other.violation_error_message
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
kwargs["condition"] = self.condition
|
||||
return path, args, kwargs
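A minimal sketch of attaching the constraint class above to a model (the Product model is hypothetical); note the legacy check= keyword still works but emits RemovedInDjango60Warning:

from django.db import models
from django.db.models import Q


class Product(models.Model):
    price = models.DecimalField(max_digits=8, decimal_places=2)

    class Meta:
        constraints = [
            models.CheckConstraint(
                condition=Q(price__gte=0),
                name="price_non_negative",
            ),
        ]
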
|
||||
|
||||
|
||||
class Deferrable(Enum):
|
||||
DEFERRED = "deferred"
|
||||
IMMEDIATE = "immediate"
|
||||
|
||||
# A similar format was proposed for Python 3.10.
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__qualname__}.{self._name_}"
|
||||
|
||||
|
||||
class UniqueConstraint(BaseConstraint):
|
||||
def __init__(
|
||||
self,
|
||||
*expressions,
|
||||
fields=(),
|
||||
name=None,
|
||||
condition=None,
|
||||
deferrable=None,
|
||||
include=None,
|
||||
opclasses=(),
|
||||
nulls_distinct=None,
|
||||
violation_error_code=None,
|
||||
violation_error_message=None,
|
||||
):
|
||||
if not name:
|
||||
raise ValueError("A unique constraint must be named.")
|
||||
if not expressions and not fields:
|
||||
raise ValueError(
|
||||
"At least one field or expression is required to define a "
|
||||
"unique constraint."
|
||||
)
|
||||
if expressions and fields:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and expressions are mutually exclusive."
|
||||
)
|
||||
if not isinstance(condition, (NoneType, Q)):
|
||||
raise ValueError("UniqueConstraint.condition must be a Q instance.")
|
||||
if condition and deferrable:
|
||||
raise ValueError("UniqueConstraint with conditions cannot be deferred.")
|
||||
if include and deferrable:
|
||||
raise ValueError("UniqueConstraint with include fields cannot be deferred.")
|
||||
if opclasses and deferrable:
|
||||
raise ValueError("UniqueConstraint with opclasses cannot be deferred.")
|
||||
if expressions and deferrable:
|
||||
raise ValueError("UniqueConstraint with expressions cannot be deferred.")
|
||||
if expressions and opclasses:
|
||||
raise ValueError(
|
||||
"UniqueConstraint.opclasses cannot be used with expressions. "
|
||||
"Use django.contrib.postgres.indexes.OpClass() instead."
|
||||
)
|
||||
if not isinstance(deferrable, (NoneType, Deferrable)):
|
||||
raise TypeError(
|
||||
"UniqueConstraint.deferrable must be a Deferrable instance."
|
||||
)
|
||||
if not isinstance(include, (NoneType, list, tuple)):
|
||||
raise TypeError("UniqueConstraint.include must be a list or tuple.")
|
||||
if not isinstance(opclasses, (list, tuple)):
|
||||
raise TypeError("UniqueConstraint.opclasses must be a list or tuple.")
|
||||
if not isinstance(nulls_distinct, (NoneType, bool)):
|
||||
raise TypeError("UniqueConstraint.nulls_distinct must be a bool.")
|
||||
if opclasses and len(fields) != len(opclasses):
|
||||
raise ValueError(
|
||||
"UniqueConstraint.fields and UniqueConstraint.opclasses must "
|
||||
"have the same number of elements."
|
||||
)
|
||||
self.fields = tuple(fields)
|
||||
self.condition = condition
|
||||
self.deferrable = deferrable
|
||||
self.include = tuple(include) if include else ()
|
||||
self.opclasses = opclasses
|
||||
self.nulls_distinct = nulls_distinct
|
||||
self.expressions = tuple(
|
||||
F(expression) if isinstance(expression, str) else expression
|
||||
for expression in expressions
|
||||
)
|
||||
super().__init__(
|
||||
name=name,
|
||||
violation_error_code=violation_error_code,
|
||||
violation_error_message=violation_error_message,
|
||||
)
|
||||
|
||||
@property
|
||||
def contains_expressions(self):
|
||||
return bool(self.expressions)
|
||||
|
||||
def _check(self, model, connection):
|
||||
errors = model._check_local_fields({*self.fields, *self.include}, "constraints")
|
||||
required_db_features = model._meta.required_db_features
|
||||
if self.condition is not None and not (
|
||||
connection.features.supports_partial_indexes
|
||||
or "supports_partial_indexes" in required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"{connection.display_name} does not support unique constraints "
|
||||
"with conditions.",
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=model,
|
||||
id="models.W036",
|
||||
)
|
||||
)
|
||||
if self.deferrable is not None and not (
|
||||
connection.features.supports_deferrable_unique_constraints
|
||||
or "supports_deferrable_unique_constraints" in required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"{connection.display_name} does not support deferrable unique "
|
||||
"constraints.",
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=model,
|
||||
id="models.W038",
|
||||
)
|
||||
)
|
||||
if self.include and not (
|
||||
connection.features.supports_covering_indexes
|
||||
or "supports_covering_indexes" in required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"{connection.display_name} does not support unique constraints "
|
||||
"with non-key columns.",
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=model,
|
||||
id="models.W039",
|
||||
)
|
||||
)
|
||||
if self.contains_expressions and not (
|
||||
connection.features.supports_expression_indexes
|
||||
or "supports_expression_indexes" in required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"{connection.display_name} does not support unique constraints on "
|
||||
"expressions.",
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=model,
|
||||
id="models.W044",
|
||||
)
|
||||
)
|
||||
if self.nulls_distinct is not None and not (
|
||||
connection.features.supports_nulls_distinct_unique_constraints
|
||||
or "supports_nulls_distinct_unique_constraints" in required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
f"{connection.display_name} does not support unique constraints "
|
||||
"with nulls distinct.",
|
||||
hint=(
|
||||
"A constraint won't be created. Silence this warning if you "
|
||||
"don't care about it."
|
||||
),
|
||||
obj=model,
|
||||
id="models.W047",
|
||||
)
|
||||
)
|
||||
references = set()
|
||||
if (
|
||||
connection.features.supports_partial_indexes
|
||||
or "supports_partial_indexes" not in required_db_features
|
||||
) and isinstance(self.condition, Q):
|
||||
references.update(model._get_expr_references(self.condition))
|
||||
if self.contains_expressions and (
|
||||
connection.features.supports_expression_indexes
|
||||
or "supports_expression_indexes" not in required_db_features
|
||||
):
|
||||
for expression in self.expressions:
|
||||
references.update(model._get_expr_references(expression))
|
||||
errors.extend(self._check_references(model, references))
|
||||
return errors
|
||||
|
||||
def _get_condition_sql(self, model, schema_editor):
|
||||
if self.condition is None:
|
||||
return None
|
||||
query = Query(model=model, alias_cols=False)
|
||||
where = query.build_where(self.condition)
|
||||
compiler = query.get_compiler(connection=schema_editor.connection)
|
||||
sql, params = where.as_sql(compiler, schema_editor.connection)
|
||||
return sql % tuple(schema_editor.quote_value(p) for p in params)
|
||||
|
||||
def _get_index_expressions(self, model, schema_editor):
|
||||
if not self.expressions:
|
||||
return None
|
||||
index_expressions = []
|
||||
for expression in self.expressions:
|
||||
index_expression = IndexExpression(expression)
|
||||
index_expression.set_wrapper_classes(schema_editor.connection)
|
||||
index_expressions.append(index_expression)
|
||||
return ExpressionList(*index_expressions).resolve_expression(
|
||||
Query(model, alias_cols=False),
|
||||
)
|
||||
|
||||
def constraint_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
nulls_distinct=self.nulls_distinct,
|
||||
)
|
||||
|
||||
def create_sql(self, model, schema_editor):
|
||||
fields = [model._meta.get_field(field_name) for field_name in self.fields]
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._create_unique_sql(
|
||||
model,
|
||||
fields,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
nulls_distinct=self.nulls_distinct,
|
||||
)
|
||||
|
||||
def remove_sql(self, model, schema_editor):
|
||||
condition = self._get_condition_sql(model, schema_editor)
|
||||
include = [
|
||||
model._meta.get_field(field_name).column for field_name in self.include
|
||||
]
|
||||
expressions = self._get_index_expressions(model, schema_editor)
|
||||
return schema_editor._delete_unique_sql(
|
||||
model,
|
||||
self.name,
|
||||
condition=condition,
|
||||
deferrable=self.deferrable,
|
||||
include=include,
|
||||
opclasses=self.opclasses,
|
||||
expressions=expressions,
|
||||
nulls_distinct=self.nulls_distinct,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:%s%s%s%s%s%s%s%s%s%s>" % (
|
||||
self.__class__.__qualname__,
|
||||
"" if not self.fields else " fields=%s" % repr(self.fields),
|
||||
"" if not self.expressions else " expressions=%s" % repr(self.expressions),
|
||||
" name=%s" % repr(self.name),
|
||||
"" if self.condition is None else " condition=%s" % self.condition,
|
||||
"" if self.deferrable is None else " deferrable=%r" % self.deferrable,
|
||||
"" if not self.include else " include=%s" % repr(self.include),
|
||||
"" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
|
||||
(
|
||||
""
|
||||
if self.nulls_distinct is None
|
||||
else " nulls_distinct=%r" % self.nulls_distinct
|
||||
),
|
||||
(
|
||||
""
|
||||
if self.violation_error_code is None
|
||||
else " violation_error_code=%r" % self.violation_error_code
|
||||
),
|
||||
(
|
||||
""
|
||||
if self.violation_error_message is None
|
||||
or self.violation_error_message == self.default_violation_error_message
|
||||
else " violation_error_message=%r" % self.violation_error_message
|
||||
),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, UniqueConstraint):
|
||||
return (
|
||||
self.name == other.name
|
||||
and self.fields == other.fields
|
||||
and self.condition == other.condition
|
||||
and self.deferrable == other.deferrable
|
||||
and self.include == other.include
|
||||
and self.opclasses == other.opclasses
|
||||
and self.expressions == other.expressions
|
||||
and self.nulls_distinct is other.nulls_distinct
|
||||
and self.violation_error_code == other.violation_error_code
|
||||
and self.violation_error_message == other.violation_error_message
|
||||
)
|
||||
return super().__eq__(other)
|
||||
|
||||
def deconstruct(self):
|
||||
path, args, kwargs = super().deconstruct()
|
||||
if self.fields:
|
||||
kwargs["fields"] = self.fields
|
||||
if self.condition:
|
||||
kwargs["condition"] = self.condition
|
||||
if self.deferrable:
|
||||
kwargs["deferrable"] = self.deferrable
|
||||
if self.include:
|
||||
kwargs["include"] = self.include
|
||||
if self.opclasses:
|
||||
kwargs["opclasses"] = self.opclasses
|
||||
if self.nulls_distinct is not None:
|
||||
kwargs["nulls_distinct"] = self.nulls_distinct
|
||||
return path, self.expressions, kwargs
|
||||
|
||||
def validate(self, model, instance, exclude=None, using=DEFAULT_DB_ALIAS):
|
||||
queryset = model._default_manager.using(using)
|
||||
if self.fields:
|
||||
lookup_kwargs = {}
|
||||
generated_field_names = []
|
||||
for field_name in self.fields:
|
||||
if exclude and field_name in exclude:
|
||||
return
|
||||
field = model._meta.get_field(field_name)
|
||||
if field.generated:
|
||||
if exclude and self._expression_refs_exclude(
|
||||
model, field.expression, exclude
|
||||
):
|
||||
return
|
||||
generated_field_names.append(field.name)
|
||||
else:
|
||||
lookup_value = getattr(instance, field.attname)
|
||||
if (
|
||||
self.nulls_distinct is not False
|
||||
and lookup_value is None
|
||||
or (
|
||||
lookup_value == ""
|
||||
and connections[
|
||||
using
|
||||
].features.interprets_empty_strings_as_nulls
|
||||
)
|
||||
):
|
||||
# A composite constraint containing NULL value cannot cause
|
||||
# a violation since NULL != NULL in SQL.
|
||||
return
|
||||
lookup_kwargs[field.name] = lookup_value
|
||||
lookup_args = []
|
||||
if generated_field_names:
|
||||
field_expression_map = instance._get_field_expression_map(
|
||||
meta=model._meta, exclude=exclude
|
||||
)
|
||||
for field_name in generated_field_names:
|
||||
expression = field_expression_map[field_name]
|
||||
if self.nulls_distinct is False:
|
||||
lhs = F(field_name)
|
||||
condition = Q(Exact(lhs, expression)) | Q(
|
||||
IsNull(lhs, True), IsNull(expression, True)
|
||||
)
|
||||
lookup_args.append(condition)
|
||||
else:
|
||||
lookup_kwargs[field_name] = expression
|
||||
queryset = queryset.filter(*lookup_args, **lookup_kwargs)
|
||||
else:
|
||||
# Ignore constraints with excluded fields.
|
||||
if exclude and any(
|
||||
self._expression_refs_exclude(model, expression, exclude)
|
||||
for expression in self.expressions
|
||||
):
|
||||
return
|
||||
replacements = {
|
||||
F(field): value
|
||||
for field, value in instance._get_field_expression_map(
|
||||
meta=model._meta, exclude=exclude
|
||||
).items()
|
||||
}
|
||||
filters = []
|
||||
for expr in self.expressions:
|
||||
if hasattr(expr, "get_expression_for_validation"):
|
||||
expr = expr.get_expression_for_validation()
|
||||
rhs = expr.replace_expressions(replacements)
|
||||
condition = Exact(expr, rhs)
|
||||
if self.nulls_distinct is False:
|
||||
condition = Q(condition) | Q(IsNull(expr, True), IsNull(rhs, True))
|
||||
filters.append(condition)
|
||||
queryset = queryset.filter(*filters)
|
||||
model_class_pk = instance._get_pk_val(model._meta)
|
||||
if not instance._state.adding and instance._is_pk_set(model._meta):
|
||||
queryset = queryset.exclude(pk=model_class_pk)
|
||||
if not self.condition:
|
||||
if queryset.exists():
|
||||
if (
|
||||
self.fields
|
||||
and self.violation_error_message
|
||||
== self.default_violation_error_message
|
||||
):
|
||||
# When fields are defined, use the unique_error_message() as
|
||||
# a default for backward compatibility.
|
||||
validation_error_message = instance.unique_error_message(
|
||||
model, self.fields
|
||||
)
|
||||
raise ValidationError(
|
||||
validation_error_message,
|
||||
code=validation_error_message.code,
|
||||
)
|
||||
raise ValidationError(
|
||||
self.get_violation_error_message(),
|
||||
code=self.violation_error_code,
|
||||
)
|
||||
else:
|
||||
against = instance._get_field_expression_map(
|
||||
meta=model._meta, exclude=exclude
|
||||
)
|
||||
try:
|
||||
if (self.condition & Exists(queryset.filter(self.condition))).check(
|
||||
against, using=using
|
||||
):
|
||||
raise ValidationError(
|
||||
self.get_violation_error_message(),
|
||||
code=self.violation_error_code,
|
||||
)
|
||||
except FieldError:
|
||||
pass
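A sketch covering the main constructor branches above (hypothetical Account model): field-based with a partial condition, expression-based (mutually exclusive with fields), and a deferrable field-based constraint (deferrable cannot combine with conditions, includes, opclasses, or expressions):

from django.db import models
from django.db.models import Deferrable, Q, UniqueConstraint
from django.db.models.functions import Lower


class Account(models.Model):
    email = models.EmailField()
    handle = models.CharField(max_length=32)
    active = models.BooleanField(default=True)

    class Meta:
        constraints = [
            UniqueConstraint(
                fields=["email"], condition=Q(active=True), name="uniq_active_email"
            ),
            UniqueConstraint(Lower("email"), name="uniq_email_ci"),
            UniqueConstraint(
                fields=["handle"],
                deferrable=Deferrable.DEFERRED,
                name="uniq_handle_deferred",
            ),
        ]
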
517
.venv/lib/python3.10/site-packages/django/db/models/deletion.py
Normal file
@@ -0,0 +1,517 @@
|
||||
from collections import Counter, defaultdict
|
||||
from functools import partial, reduce
|
||||
from itertools import chain
|
||||
from operator import attrgetter, or_
|
||||
|
||||
from django.db import IntegrityError, connections, models, transaction
|
||||
from django.db.models import query_utils, signals, sql
|
||||
|
||||
|
||||
class ProtectedError(IntegrityError):
|
||||
def __init__(self, msg, protected_objects):
|
||||
self.protected_objects = protected_objects
|
||||
super().__init__(msg, protected_objects)
|
||||
|
||||
|
||||
class RestrictedError(IntegrityError):
|
||||
def __init__(self, msg, restricted_objects):
|
||||
self.restricted_objects = restricted_objects
|
||||
super().__init__(msg, restricted_objects)
|
||||
|
||||
|
||||
def CASCADE(collector, field, sub_objs, using):
|
||||
collector.collect(
|
||||
sub_objs,
|
||||
source=field.remote_field.model,
|
||||
source_attr=field.name,
|
||||
nullable=field.null,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
if field.null and not connections[using].features.can_defer_constraint_checks:
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
def PROTECT(collector, field, sub_objs, using):
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model '%s' because they are "
|
||||
"referenced through a protected foreign key: '%s.%s'"
|
||||
% (
|
||||
field.remote_field.model.__name__,
|
||||
sub_objs[0].__class__.__name__,
|
||||
field.name,
|
||||
),
|
||||
sub_objs,
|
||||
)
|
||||
|
||||
|
||||
def RESTRICT(collector, field, sub_objs, using):
|
||||
collector.add_restricted_objects(field, sub_objs)
|
||||
collector.add_dependency(field.remote_field.model, field.model)
|
||||
|
||||
|
||||
def SET(value):
|
||||
if callable(value):
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value(), sub_objs)
|
||||
|
||||
else:
|
||||
|
||||
def set_on_delete(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, value, sub_objs)
|
||||
|
||||
set_on_delete.lazy_sub_objs = True
|
||||
|
||||
set_on_delete.deconstruct = lambda: ("django.db.models.SET", (value,), {})
|
||||
return set_on_delete
|
||||
|
||||
|
||||
def SET_NULL(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, None, sub_objs)
|
||||
|
||||
|
||||
SET_NULL.lazy_sub_objs = True
|
||||
|
||||
|
||||
def SET_DEFAULT(collector, field, sub_objs, using):
|
||||
collector.add_field_update(field, field.get_default(), sub_objs)
|
||||
|
||||
|
||||
def DO_NOTHING(collector, field, sub_objs, using):
|
||||
pass
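These module-level handlers are what ForeignKey(on_delete=...) stores and what the Collector later invokes; a sketch (the models and the default_owner callable are hypothetical):

from django.db import models


def default_owner():
    return 1  # pk of a fallback owner, resolved lazily by SET()


class Ticket(models.Model):
    owner = models.ForeignKey(
        "User", null=True, on_delete=models.SET(default_owner)
    )
    project = models.ForeignKey("Project", on_delete=models.PROTECT)
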
|
||||
|
||||
|
||||
def get_candidate_relations_to_delete(opts):
|
||||
# The candidate relations are the ones that come from N-1 and 1-1 relations.
|
||||
# N-N (i.e., many-to-many) relations aren't candidates for deletion.
|
||||
return (
|
||||
f
|
||||
for f in opts.get_fields(include_hidden=True)
|
||||
if f.auto_created and not f.concrete and (f.one_to_one or f.one_to_many)
|
||||
)
|
||||
|
||||
|
||||
class Collector:
|
||||
def __init__(self, using, origin=None):
|
||||
self.using = using
|
||||
# A Model or QuerySet object.
|
||||
self.origin = origin
|
||||
# Initially, {model: {instances}}, later values become lists.
|
||||
self.data = defaultdict(set)
|
||||
# {(field, value): [instances, …]}
|
||||
self.field_updates = defaultdict(list)
|
||||
# {model: {field: {instances}}}
|
||||
self.restricted_objects = defaultdict(partial(defaultdict, set))
|
||||
# fast_deletes is a list of queryset-likes that can be deleted without
|
||||
# fetching the objects into memory.
|
||||
self.fast_deletes = []
|
||||
|
||||
# Tracks deletion-order dependency for databases without transactions
|
||||
# or ability to defer constraint checks. Only concrete model classes
|
||||
# should be included, as the dependencies exist only between actual
|
||||
# database tables; proxy models are represented here by their concrete
|
||||
# parent.
|
||||
self.dependencies = defaultdict(set) # {model: {models}}
|
||||
|
||||
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted. If the call is
|
||||
the result of a cascade, 'source' should be the model that caused it,
|
||||
and 'nullable' should be set to True if the relation can be null.
|
||||
|
||||
Return a list of all objects that were not already collected.
|
||||
"""
|
||||
if not objs:
|
||||
return []
|
||||
new_objs = []
|
||||
model = objs[0].__class__
|
||||
instances = self.data[model]
|
||||
for obj in objs:
|
||||
if obj not in instances:
|
||||
new_objs.append(obj)
|
||||
instances.update(new_objs)
|
||||
# Nullable relationships can be ignored -- they are nulled out before
|
||||
# deleting, and therefore do not affect the order in which objects have
|
||||
# to be deleted.
|
||||
if source is not None and not nullable:
|
||||
self.add_dependency(source, model, reverse_dependency=reverse_dependency)
|
||||
return new_objs
|
||||
|
||||
def add_dependency(self, model, dependency, reverse_dependency=False):
|
||||
if reverse_dependency:
|
||||
model, dependency = dependency, model
|
||||
self.dependencies[model._meta.concrete_model].add(
|
||||
dependency._meta.concrete_model
|
||||
)
|
||||
self.data.setdefault(dependency, self.data.default_factory())
|
||||
|
||||
def add_field_update(self, field, value, objs):
|
||||
"""
|
||||
Schedule a field update. 'objs' must be a homogeneous iterable
|
||||
collection of model instances (e.g. a QuerySet).
|
||||
"""
|
||||
self.field_updates[field, value].append(objs)
|
||||
|
||||
def add_restricted_objects(self, field, objs):
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
self.restricted_objects[model][field].update(objs)
|
||||
|
||||
def clear_restricted_objects_from_set(self, model, objs):
|
||||
if model in self.restricted_objects:
|
||||
self.restricted_objects[model] = {
|
||||
field: items - objs
|
||||
for field, items in self.restricted_objects[model].items()
|
||||
}
|
||||
|
||||
def clear_restricted_objects_from_queryset(self, model, qs):
|
||||
if model in self.restricted_objects:
|
||||
objs = set(
|
||||
qs.filter(
|
||||
pk__in=[
|
||||
obj.pk
|
||||
for objs in self.restricted_objects[model].values()
|
||||
for obj in objs
|
||||
]
|
||||
)
|
||||
)
|
||||
self.clear_restricted_objects_from_set(model, objs)
|
||||
|
||||
def _has_signal_listeners(self, model):
|
||||
return signals.pre_delete.has_listeners(
|
||||
model
|
||||
) or signals.post_delete.has_listeners(model)
|
||||
|
||||
def can_fast_delete(self, objs, from_field=None):
|
||||
"""
|
||||
Determine if the objects in the given queryset-like or single object
|
||||
can be fast-deleted. This can be done if there are no cascades, no
|
||||
parents and no signal listeners for the object class.
|
||||
|
||||
The 'from_field' tells where we are coming from - we need this to
|
||||
determine if the objects are in fact to be deleted. Allow also
|
||||
skipping parent -> child -> parent chain preventing fast delete of
|
||||
the child.
|
||||
"""
|
||||
if from_field and from_field.remote_field.on_delete is not CASCADE:
|
||||
return False
|
||||
if hasattr(objs, "_meta"):
|
||||
model = objs._meta.model
|
||||
elif hasattr(objs, "model") and hasattr(objs, "_raw_delete"):
|
||||
model = objs.model
|
||||
else:
|
||||
return False
|
||||
if self._has_signal_listeners(model):
|
||||
return False
|
||||
# The use of from_field comes from the need to avoid cascade back to
|
||||
# parent when parent delete is cascading to child.
|
||||
opts = model._meta
|
||||
return (
|
||||
all(
|
||||
link == from_field
|
||||
for link in opts.concrete_model._meta.parents.values()
|
||||
)
|
||||
and
|
||||
# Foreign keys pointing to this model.
|
||||
all(
|
||||
related.field.remote_field.on_delete is DO_NOTHING
|
||||
for related in get_candidate_relations_to_delete(opts)
|
||||
)
|
||||
and (
|
||||
# Something like generic foreign key.
|
||||
not any(
|
||||
hasattr(field, "bulk_related_objects")
|
||||
for field in opts.private_fields
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def get_del_batches(self, objs, fields):
|
||||
"""
|
||||
Return the objs in suitably sized batches for the used connection.
|
||||
"""
|
||||
conn_batch_size = max(
|
||||
connections[self.using].ops.bulk_batch_size(fields, objs), 1
|
||||
)
|
||||
if len(objs) > conn_batch_size:
|
||||
return [
|
||||
objs[i : i + conn_batch_size]
|
||||
for i in range(0, len(objs), conn_batch_size)
|
||||
]
|
||||
else:
|
||||
return [objs]
|
||||
|
||||
def collect(
|
||||
self,
|
||||
objs,
|
||||
source=None,
|
||||
nullable=False,
|
||||
collect_related=True,
|
||||
source_attr=None,
|
||||
reverse_dependency=False,
|
||||
keep_parents=False,
|
||||
fail_on_restricted=True,
|
||||
):
|
||||
"""
|
||||
Add 'objs' to the collection of objects to be deleted as well as all
|
||||
parent instances. 'objs' must be a homogeneous iterable collection of
|
||||
model instances (e.g. a QuerySet). If 'collect_related' is True,
|
||||
related objects will be handled by their respective on_delete handler.
|
||||
|
||||
If the call is the result of a cascade, 'source' should be the model
|
||||
that caused it and 'nullable' should be set to True, if the relation
|
||||
can be null.
|
||||
|
||||
If 'reverse_dependency' is True, 'source' will be deleted before the
|
||||
current model, rather than after. (Needed for cascading to parent
|
||||
models, the one case in which the cascade follows the forwards
|
||||
direction of an FK rather than the reverse direction.)
|
||||
|
||||
If 'keep_parents' is True, data of parent models will not be deleted.
|
||||
|
||||
If 'fail_on_restricted' is False, error won't be raised even if it's
|
||||
prohibited to delete such objects due to RESTRICT, that defers
|
||||
restricted object checking in recursive calls where the top-level call
|
||||
may need to collect more objects to determine whether restricted ones
|
||||
can be deleted.
|
||||
"""
|
||||
if self.can_fast_delete(objs):
|
||||
self.fast_deletes.append(objs)
|
||||
return
|
||||
new_objs = self.add(
|
||||
objs, source, nullable, reverse_dependency=reverse_dependency
|
||||
)
|
||||
if not new_objs:
|
||||
return
|
||||
|
||||
model = new_objs[0].__class__
|
||||
|
||||
if not keep_parents:
|
||||
# Recursively collect concrete model's parent models, but not their
|
||||
# related objects. These will be found by meta.get_fields()
|
||||
concrete_model = model._meta.concrete_model
|
||||
for ptr in concrete_model._meta.parents.values():
|
||||
if ptr:
|
||||
parent_objs = [getattr(obj, ptr.name) for obj in new_objs]
|
||||
self.collect(
|
||||
parent_objs,
|
||||
source=model,
|
||||
source_attr=ptr.remote_field.related_name,
|
||||
collect_related=False,
|
||||
reverse_dependency=True,
|
||||
fail_on_restricted=False,
|
||||
)
|
||||
if not collect_related:
|
||||
return
|
||||
|
||||
model_fast_deletes = defaultdict(list)
|
||||
protected_objects = defaultdict(list)
|
||||
for related in get_candidate_relations_to_delete(model._meta):
|
||||
# Preserve parent reverse relationships if keep_parents=True.
|
||||
if keep_parents and related.model in model._meta.all_parents:
|
||||
continue
|
||||
field = related.field
|
||||
on_delete = field.remote_field.on_delete
|
||||
if on_delete == DO_NOTHING:
|
||||
continue
|
||||
related_model = related.related_model
|
||||
if self.can_fast_delete(related_model, from_field=field):
|
||||
model_fast_deletes[related_model].append(field)
|
||||
continue
|
||||
batches = self.get_del_batches(new_objs, [field])
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, [field], batch)
|
||||
# Non-referenced fields can be deferred if no signal receivers
|
||||
# are connected for the related model as they'll never be
|
||||
# exposed to the user. Skip field deferring when some
|
||||
# relationships are select_related as interactions between both
|
||||
# features are hard to get right. This should only happen in
|
||||
# the rare cases where .related_objects is overridden anyway.
|
||||
if not (
|
||||
sub_objs.query.select_related
|
||||
or self._has_signal_listeners(related_model)
|
||||
):
|
||||
referenced_fields = set(
|
||||
chain.from_iterable(
|
||||
(rf.attname for rf in rel.field.foreign_related_fields)
|
||||
for rel in get_candidate_relations_to_delete(
|
||||
related_model._meta
|
||||
)
|
||||
)
|
||||
)
|
||||
sub_objs = sub_objs.only(*tuple(referenced_fields))
|
||||
if getattr(on_delete, "lazy_sub_objs", False) or sub_objs:
|
||||
try:
|
||||
on_delete(self, field, sub_objs, self.using)
|
||||
except ProtectedError as error:
|
||||
key = "'%s.%s'" % (field.model.__name__, field.name)
|
||||
protected_objects[key] += error.protected_objects
|
||||
if protected_objects:
|
||||
raise ProtectedError(
|
||||
"Cannot delete some instances of model %r because they are "
|
||||
"referenced through protected foreign keys: %s."
|
||||
% (
|
||||
model.__name__,
|
||||
", ".join(protected_objects),
|
||||
),
|
||||
set(chain.from_iterable(protected_objects.values())),
|
||||
)
|
||||
for related_model, related_fields in model_fast_deletes.items():
|
||||
batches = self.get_del_batches(new_objs, related_fields)
|
||||
for batch in batches:
|
||||
sub_objs = self.related_objects(related_model, related_fields, batch)
|
||||
self.fast_deletes.append(sub_objs)
|
||||
for field in model._meta.private_fields:
|
||||
if hasattr(field, "bulk_related_objects"):
|
||||
# It's something like generic foreign key.
|
||||
sub_objs = field.bulk_related_objects(new_objs, self.using)
|
||||
self.collect(
|
||||
sub_objs, source=model, nullable=True, fail_on_restricted=False
|
||||
)
|
||||
|
||||
if fail_on_restricted:
|
||||
# Raise an error if collected restricted objects (RESTRICT) aren't
|
||||
# candidates for deletion also collected via CASCADE.
|
||||
for related_model, instances in self.data.items():
|
||||
self.clear_restricted_objects_from_set(related_model, instances)
|
||||
for qs in self.fast_deletes:
|
||||
self.clear_restricted_objects_from_queryset(qs.model, qs)
|
||||
if self.restricted_objects.values():
|
||||
restricted_objects = defaultdict(list)
|
||||
for related_model, fields in self.restricted_objects.items():
|
||||
for field, objs in fields.items():
|
||||
if objs:
|
||||
key = "'%s.%s'" % (related_model.__name__, field.name)
|
||||
restricted_objects[key] += objs
|
||||
if restricted_objects:
|
||||
raise RestrictedError(
|
||||
"Cannot delete some instances of model %r because "
|
||||
"they are referenced through restricted foreign keys: "
|
||||
"%s."
|
||||
% (
|
||||
model.__name__,
|
||||
", ".join(restricted_objects),
|
||||
),
|
||||
set(chain.from_iterable(restricted_objects.values())),
|
||||
)
|
||||
|
||||
def related_objects(self, related_model, related_fields, objs):
|
||||
"""
|
||||
Get a QuerySet of the related model to objs via related fields.
|
||||
"""
|
||||
predicate = query_utils.Q.create(
|
||||
[(f"{related_field.name}__in", objs) for related_field in related_fields],
|
||||
connector=query_utils.Q.OR,
|
||||
)
|
||||
return related_model._base_manager.using(self.using).filter(predicate)
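The Q.create(..., connector=Q.OR) call above simply ORs one __in filter per incoming foreign key; spelled out by hand, given a list objs of collected instances (field and model names hypothetical):

from django.db.models import Q

predicate = Q(author__in=objs) | Q(editor__in=objs)
qs = Book._base_manager.using("default").filter(predicate)
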
|
||||
|
||||
def instances_with_model(self):
|
||||
for model, instances in self.data.items():
|
||||
for obj in instances:
|
||||
yield model, obj
|
||||
|
||||
def sort(self):
|
||||
sorted_models = []
|
||||
concrete_models = set()
|
||||
models = list(self.data)
|
||||
while len(sorted_models) < len(models):
|
||||
found = False
|
||||
for model in models:
|
||||
if model in sorted_models:
|
||||
continue
|
||||
dependencies = self.dependencies.get(model._meta.concrete_model)
|
||||
if not (dependencies and dependencies.difference(concrete_models)):
|
||||
sorted_models.append(model)
|
||||
concrete_models.add(model._meta.concrete_model)
|
||||
found = True
|
||||
if not found:
|
||||
return
|
||||
self.data = {model: self.data[model] for model in sorted_models}
|
||||
|
||||
def delete(self):
|
||||
# sort instance collections
|
||||
for model, instances in self.data.items():
|
||||
self.data[model] = sorted(instances, key=attrgetter("pk"))
|
||||
|
||||
# if possible, bring the models in an order suitable for databases that
|
||||
# don't support transactions or cannot defer constraint checks until the
|
||||
# end of a transaction.
|
||||
self.sort()
|
||||
# number of objects deleted for each model label
|
||||
deleted_counter = Counter()
|
||||
|
||||
# Optimize for the case with a single obj and no dependencies
|
||||
if len(self.data) == 1 and len(instances) == 1:
|
||||
instance = list(instances)[0]
|
||||
if self.can_fast_delete(instance):
|
||||
with transaction.mark_for_rollback_on_error(self.using):
|
||||
count = sql.DeleteQuery(model).delete_batch(
|
||||
[instance.pk], self.using
|
||||
)
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return count, {model._meta.label: count}
|
||||
|
||||
with transaction.atomic(using=self.using, savepoint=False):
|
||||
# send pre_delete signals
|
||||
for model, obj in self.instances_with_model():
|
||||
if not model._meta.auto_created:
|
||||
signals.pre_delete.send(
|
||||
sender=model,
|
||||
instance=obj,
|
||||
using=self.using,
|
||||
origin=self.origin,
|
||||
)
|
||||
|
||||
# fast deletes
|
||||
for qs in self.fast_deletes:
|
||||
count = qs._raw_delete(using=self.using)
|
||||
if count:
|
||||
deleted_counter[qs.model._meta.label] += count
|
||||
|
||||
# update fields
|
||||
for (field, value), instances_list in self.field_updates.items():
|
||||
updates = []
|
||||
objs = []
|
||||
for instances in instances_list:
|
||||
if (
|
||||
isinstance(instances, models.QuerySet)
|
||||
and instances._result_cache is None
|
||||
):
|
||||
updates.append(instances)
|
||||
else:
|
||||
objs.extend(instances)
|
||||
if updates:
|
||||
combined_updates = reduce(or_, updates)
|
||||
combined_updates.update(**{field.name: value})
|
||||
if objs:
|
||||
model = objs[0].__class__
|
||||
query = sql.UpdateQuery(model)
|
||||
query.update_batch(
|
||||
list({obj.pk for obj in objs}), {field.name: value}, self.using
|
||||
)
|
||||
|
||||
# reverse instance collections
|
||||
for instances in self.data.values():
|
||||
instances.reverse()
|
||||
|
||||
# delete instances
|
||||
for model, instances in self.data.items():
|
||||
query = sql.DeleteQuery(model)
|
||||
pk_list = [obj.pk for obj in instances]
|
||||
count = query.delete_batch(pk_list, self.using)
|
||||
if count:
|
||||
deleted_counter[model._meta.label] += count
|
||||
|
||||
if not model._meta.auto_created:
|
||||
for obj in instances:
|
||||
signals.post_delete.send(
|
||||
sender=model,
|
||||
instance=obj,
|
||||
using=self.using,
|
||||
origin=self.origin,
|
||||
)
|
||||
|
||||
for model, instances in self.data.items():
|
||||
for instance in instances:
|
||||
setattr(instance, model._meta.pk.attname, None)
|
||||
return sum(deleted_counter.values()), dict(deleted_counter)
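Collector.delete() is what backs Model.delete() and QuerySet.delete(), so both return this (total, per-model-label) pair; e.g. (hypothetical model):

total, per_model = Ticket.objects.filter(closed=True).delete()
# total == 3, per_model == {"helpdesk.Ticket": 3}, for example
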
123
.venv/lib/python3.10/site-packages/django/db/models/enums.py
Normal file
@@ -0,0 +1,123 @@
import enum
import warnings

from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import Promise
from django.utils.version import PY311, PY312

if PY311:
    from enum import EnumType, IntEnum, StrEnum
    from enum import property as enum_property
else:
    from enum import EnumMeta as EnumType
    from types import DynamicClassAttribute as enum_property

    class ReprEnum(enum.Enum):
        def __str__(self):
            return str(self.value)

    class IntEnum(int, ReprEnum):
        pass

    class StrEnum(str, ReprEnum):
        pass


__all__ = ["Choices", "IntegerChoices", "TextChoices"]


class ChoicesType(EnumType):
    """A metaclass for creating enum choices."""

    def __new__(metacls, classname, bases, classdict, **kwds):
        labels = []
        for key in classdict._member_names:
            value = classdict[key]
            if (
                isinstance(value, (list, tuple))
                and len(value) > 1
                and isinstance(value[-1], (Promise, str))
            ):
                *value, label = value
                value = tuple(value)
            else:
                label = key.replace("_", " ").title()
            labels.append(label)
            # Use dict.__setitem__() to suppress defenses against double
            # assignment in enum's classdict.
            dict.__setitem__(classdict, key, value)
        cls = super().__new__(metacls, classname, bases, classdict, **kwds)
        for member, label in zip(cls.__members__.values(), labels):
            member._label_ = label
        return enum.unique(cls)

    if not PY312:

        def __contains__(cls, member):
            if not isinstance(member, enum.Enum):
                # Allow non-enums to match against member values.
                return any(x.value == member for x in cls)
            return super().__contains__(member)

    @property
    def names(cls):
        empty = ["__empty__"] if hasattr(cls, "__empty__") else []
        return empty + [member.name for member in cls]

    @property
    def choices(cls):
        empty = [(None, cls.__empty__)] if hasattr(cls, "__empty__") else []
        return empty + [(member.value, member.label) for member in cls]

    @property
    def labels(cls):
        return [label for _, label in cls.choices]

    @property
    def values(cls):
        return [value for value, _ in cls.choices]


class Choices(enum.Enum, metaclass=ChoicesType):
    """Class for creating enumerated choices."""

    if PY311:
        do_not_call_in_templates = enum.nonmember(True)
    else:

        @property
        def do_not_call_in_templates(self):
            return True

    @enum_property
    def label(self):
        return self._label_

    # A similar format was proposed for Python 3.10.
    def __repr__(self):
        return f"{self.__class__.__qualname__}.{self._name_}"


class IntegerChoices(Choices, IntEnum):
    """Class for creating enumerated integer choices."""

    pass


class TextChoices(Choices, StrEnum):
    """Class for creating enumerated string choices."""

    @staticmethod
    def _generate_next_value_(name, start, count, last_values):
        return name


def __getattr__(name):
    if name == "ChoicesMeta":
        warnings.warn(
            "ChoicesMeta is deprecated in favor of ChoicesType.",
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        return ChoicesType
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
2122
.venv/lib/python3.10/site-packages/django/db/models/expressions.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,177 @@
import json

from django.core import checks
from django.db.models import NOT_PROVIDED, Field
from django.db.models.expressions import ColPairs
from django.db.models.fields.tuple_lookups import (
    TupleExact,
    TupleGreaterThan,
    TupleGreaterThanOrEqual,
    TupleIn,
    TupleIsNull,
    TupleLessThan,
    TupleLessThanOrEqual,
)
from django.utils.functional import cached_property


class AttributeSetter:
    def __init__(self, name, value):
        setattr(self, name, value)


class CompositeAttribute:
    def __init__(self, field):
        self.field = field

    @property
    def attnames(self):
        return [field.attname for field in self.field.fields]

    def __get__(self, instance, cls=None):
        return tuple(getattr(instance, attname) for attname in self.attnames)

    def __set__(self, instance, values):
        attnames = self.attnames
        length = len(attnames)

        if values is None:
            values = (None,) * length

        if not isinstance(values, (list, tuple)):
            raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
        if length != len(values):
            raise ValueError(f"{self.field.name!r} must have {length} elements.")

        for attname, value in zip(attnames, values):
            setattr(instance, attname, value)


class CompositePrimaryKey(Field):
    descriptor_class = CompositeAttribute

    def __init__(self, *args, **kwargs):
        if (
            not args
            or not all(isinstance(field, str) for field in args)
            or len(set(args)) != len(args)
        ):
            raise ValueError("CompositePrimaryKey args must be unique strings.")
        if len(args) == 1:
            raise ValueError("CompositePrimaryKey must include at least two fields.")
        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("CompositePrimaryKey cannot have a default.")
        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("CompositePrimaryKey cannot have a database default.")
        if kwargs.get("db_column", None) is not None:
            raise ValueError("CompositePrimaryKey cannot have a db_column.")
        if kwargs.setdefault("editable", False):
            raise ValueError("CompositePrimaryKey cannot be editable.")
        if not kwargs.setdefault("primary_key", True):
            raise ValueError("CompositePrimaryKey must be a primary key.")
        if not kwargs.setdefault("blank", True):
            raise ValueError("CompositePrimaryKey must be blank.")

        self.field_names = args
        super().__init__(**kwargs)

    def deconstruct(self):
        # args is always [] so it can be ignored.
        name, path, _, kwargs = super().deconstruct()
        return name, path, self.field_names, kwargs

    @cached_property
    def fields(self):
        meta = self.model._meta
        return tuple(meta.get_field(field_name) for field_name in self.field_names)

    @cached_property
    def columns(self):
        return tuple(field.column for field in self.fields)

    def contribute_to_class(self, cls, name, private_only=False):
        super().contribute_to_class(cls, name, private_only=private_only)
        cls._meta.pk = self
        setattr(cls, self.attname, self.descriptor_class(self))

    def get_attname_column(self):
        return self.get_attname(), None

    def __iter__(self):
        return iter(self.fields)

    def __len__(self):
        return len(self.field_names)

    @cached_property
    def cached_col(self):
        return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)

    def get_col(self, alias, output_field=None):
        if alias == self.model._meta.db_table and (
            output_field is None or output_field == self
        ):
            return self.cached_col

        return ColPairs(alias, self.fields, self.fields, output_field)

    def get_pk_value_on_save(self, instance):
        values = []

        for field in self.fields:
            value = field.value_from_object(instance)
            if value is None:
                value = field.get_pk_value_on_save(instance)
            values.append(value)

        return tuple(values)

    def _check_field_name(self):
        if self.name == "pk":
            return []
        return [
            checks.Error(
                "'CompositePrimaryKey' must be named 'pk'.",
                obj=self,
                id="fields.E013",
            )
        ]

    def value_to_string(self, obj):
        values = []
        vals = self.value_from_object(obj)
        for field, value in zip(self.fields, vals):
            obj = AttributeSetter(field.attname, value)
            values.append(field.value_to_string(obj))
        return json.dumps(values, ensure_ascii=False)

    def to_python(self, value):
        if isinstance(value, str):
            # Assume we're deserializing.
            vals = json.loads(value)
            value = [
                field.to_python(val)
                for field, val in zip(self.fields, vals, strict=True)
            ]
        return value


CompositePrimaryKey.register_lookup(TupleExact)
CompositePrimaryKey.register_lookup(TupleGreaterThan)
CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual)
CompositePrimaryKey.register_lookup(TupleLessThan)
CompositePrimaryKey.register_lookup(TupleLessThanOrEqual)
CompositePrimaryKey.register_lookup(TupleIn)
CompositePrimaryKey.register_lookup(TupleIsNull)


def unnest(fields):
    result = []

    for field in fields:
        if isinstance(field, CompositePrimaryKey):
            result.extend(field.fields)
        else:
            result.append(field)

    return result
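# Example (a sketch, not part of this commit): a composite primary key; the
# CompositeAttribute descriptor above exposes `pk` as a tuple of the named
# fields. `Release` is a hypothetical model.
from django.db import models

class Release(models.Model):
    pk = models.CompositePrimaryKey("product_id", "version")
    product_id = models.IntegerField()
    version = models.CharField(max_length=20)

release = Release(product_id=1, version="1.0")
assert release.pk == (1, "1.0")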
@@ -0,0 +1,538 @@
import datetime
import posixpath

from django import forms
from django.core import checks
from django.core.exceptions import FieldError
from django.core.files.base import ContentFile, File
from django.core.files.images import ImageFile
from django.core.files.storage import Storage, default_storage
from django.core.files.utils import validate_file_name
from django.db.models import signals
from django.db.models.expressions import DatabaseDefault
from django.db.models.fields import Field
from django.db.models.query_utils import DeferredAttribute
from django.db.models.utils import AltersData
from django.utils.translation import gettext_lazy as _
from django.utils.version import PY311


class FieldFile(File, AltersData):
    def __init__(self, instance, field, name):
        super().__init__(None, name)
        self.instance = instance
        self.field = field
        self.storage = field.storage
        self._committed = True

    def __eq__(self, other):
        # Older code may be expecting FileField values to be simple strings.
        # By overriding the == operator, it can remain backwards compatible.
        if hasattr(other, "name"):
            return self.name == other.name
        return self.name == other

    def __hash__(self):
        return hash(self.name)

    # The standard File contains most of the necessary properties, but
    # FieldFiles can be instantiated without a name, so that needs to
    # be checked for here.

    def _require_file(self):
        if not self:
            raise ValueError(
                "The '%s' attribute has no file associated with it." % self.field.name
            )

    def _get_file(self):
        self._require_file()
        if getattr(self, "_file", None) is None:
            self._file = self.storage.open(self.name, "rb")
        return self._file

    def _set_file(self, file):
        self._file = file

    def _del_file(self):
        del self._file

    file = property(_get_file, _set_file, _del_file)

    @property
    def path(self):
        self._require_file()
        return self.storage.path(self.name)

    @property
    def url(self):
        self._require_file()
        return self.storage.url(self.name)

    @property
    def size(self):
        self._require_file()
        if not self._committed:
            return self.file.size
        return self.storage.size(self.name)

    def open(self, mode="rb"):
        self._require_file()
        if getattr(self, "_file", None) is None:
            self.file = self.storage.open(self.name, mode)
        else:
            self.file.open(mode)
        return self

    # open() doesn't alter the file's contents, but it does reset the pointer
    open.alters_data = True

    # In addition to the standard File API, FieldFiles have extra methods
    # to further manipulate the underlying file, as well as update the
    # associated model instance.

    def _set_instance_attribute(self, name, content):
        setattr(self.instance, self.field.attname, name)

    def save(self, name, content, save=True):
        name = self.field.generate_filename(self.instance, name)
        self.name = self.storage.save(name, content, max_length=self.field.max_length)
        self._set_instance_attribute(self.name, content)
        self._committed = True

        # Save the object because it has changed, unless save is False
        if save:
            self.instance.save()

    save.alters_data = True

    def delete(self, save=True):
        if not self:
            return
        # Only close the file if it's already open, which we know by the
        # presence of self._file
        if hasattr(self, "_file"):
            self.close()
            del self.file

        self.storage.delete(self.name)

        self.name = None
        setattr(self.instance, self.field.attname, self.name)
        self._committed = False

        if save:
            self.instance.save()

    delete.alters_data = True

    @property
    def closed(self):
        file = getattr(self, "_file", None)
        return file is None or file.closed

    def close(self):
        file = getattr(self, "_file", None)
        if file is not None:
            file.close()

    def __getstate__(self):
        # FieldFile needs access to its associated model field, an instance and
        # the file's name. Everything else will be restored later, by
        # FileDescriptor below.
        return {
            "name": self.name,
            "closed": False,
            "_committed": True,
            "_file": None,
            "instance": self.instance,
            "field": self.field,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.storage = self.field.storage


class FileDescriptor(DeferredAttribute):
    """
    The descriptor for the file attribute on the model instance. Return a
    FieldFile when accessed so you can write code like::

        >>> from myapp.models import MyModel
        >>> instance = MyModel.objects.get(pk=1)
        >>> instance.file.size

    Assign a file object on assignment so you can do::

        >>> with open('/path/to/hello.world') as f:
        ...     instance.file = File(f)
    """

    def __get__(self, instance, cls=None):
        if instance is None:
            return self

        # This is slightly complicated, so worth an explanation.
        # instance.file needs to ultimately return some instance of `File`,
        # probably a subclass. Additionally, this returned object needs to have
        # the FieldFile API so that users can easily do things like
        # instance.file.path and have that delegated to the file storage engine.
        # Easy enough if we're strict about assignment in __set__, but if you
        # peek below you can see that we're not. So depending on the current
        # value of the field we have to dynamically construct some sort of
        # "thing" to return.

        # The instance dict contains whatever was originally assigned
        # in __set__.
        file = super().__get__(instance, cls)

        # If this value is a string (instance.file = "path/to/file") or None
        # then we simply wrap it with the appropriate attribute class according
        # to the file field. [This is FieldFile for FileFields and
        # ImageFieldFile for ImageFields; it's also conceivable that user
        # subclasses might also want to subclass the attribute class]. This
        # object understands how to convert a path to a file, and also how to
        # handle None.
        if isinstance(file, str) or file is None:
            attr = self.field.attr_class(instance, self.field, file)
            instance.__dict__[self.field.attname] = attr

        # If this value is a DatabaseDefault, initialize the attribute class
        # for this field with its db_default value.
        elif isinstance(file, DatabaseDefault):
            attr = self.field.attr_class(instance, self.field, self.field.db_default)
            instance.__dict__[self.field.attname] = attr

        # Other types of files may be assigned as well, but they need to have
        # the FieldFile interface added to them. Thus, we wrap any other type of
        # File inside a FieldFile (well, the field's attr_class, which is
        # usually FieldFile).
        elif isinstance(file, File) and not isinstance(file, FieldFile):
            file_copy = self.field.attr_class(instance, self.field, file.name)
            file_copy.file = file
            file_copy._committed = False
            instance.__dict__[self.field.attname] = file_copy

        # Finally, because of the (some would say boneheaded) way pickle works,
        # the underlying FieldFile might not actually itself have an associated
        # file. So we need to reset the details of the FieldFile in those cases.
        elif isinstance(file, FieldFile) and not hasattr(file, "field"):
            file.instance = instance
            file.field = self.field
            file.storage = self.field.storage

        # Make sure that the instance is correct.
        elif isinstance(file, FieldFile) and instance is not file.instance:
            file.instance = instance

        # That was fun, wasn't it?
        return instance.__dict__[self.field.attname]

    def __set__(self, instance, value):
        instance.__dict__[self.field.attname] = value


class FileField(Field):
    # The class to wrap instance attributes in. Accessing the file object off
    # the instance will always return an instance of attr_class.
    attr_class = FieldFile

    # The descriptor to use for accessing the attribute off of the class.
    descriptor_class = FileDescriptor

    description = _("File")

    def __init__(
        self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
    ):
        self._primary_key_set_explicitly = "primary_key" in kwargs

        self.storage = storage or default_storage
        if callable(self.storage):
            # Hold a reference to the callable for deconstruct().
            self._storage_callable = self.storage
            self.storage = self.storage()
            if not isinstance(self.storage, Storage):
                raise TypeError(
                    "%s.storage must be a subclass/instance of %s.%s"
                    % (
                        self.__class__.__qualname__,
                        Storage.__module__,
                        Storage.__qualname__,
                    )
                )
        self.upload_to = upload_to

        kwargs.setdefault("max_length", 100)
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        return [
            *super().check(**kwargs),
            *self._check_primary_key(),
            *self._check_upload_to(),
        ]

    def _check_primary_key(self):
        if self._primary_key_set_explicitly:
            return [
                checks.Error(
                    "'primary_key' is not a valid argument for a %s."
                    % self.__class__.__name__,
                    obj=self,
                    id="fields.E201",
                )
            ]
        else:
            return []

    def _check_upload_to(self):
        if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
            return [
                checks.Error(
                    "%s's 'upload_to' argument must be a relative path, not an "
                    "absolute path." % self.__class__.__name__,
                    obj=self,
                    id="fields.E202",
                    hint="Remove the leading slash.",
                )
            ]
        else:
            return []

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        if kwargs.get("max_length") == 100:
            del kwargs["max_length"]
        kwargs["upload_to"] = self.upload_to
        storage = getattr(self, "_storage_callable", self.storage)
        if storage is not default_storage:
            kwargs["storage"] = storage
        return name, path, args, kwargs

    def get_internal_type(self):
        return "FileField"

    def get_prep_value(self, value):
        value = super().get_prep_value(value)
        # Need to convert File objects provided via a form to string for
        # database insertion.
        if value is None:
            return None
        return str(value)

    def pre_save(self, model_instance, add):
        file = super().pre_save(model_instance, add)
        if file.name is None and file._file is not None:
            exc = FieldError(
                f"File for {self.name} must have "
                "the name attribute specified to be saved."
            )
            if PY311 and isinstance(file._file, ContentFile):
                exc.add_note("Pass a 'name' argument to ContentFile.")
            raise exc

        if file and not file._committed:
            # Commit the file to storage prior to saving the model
            file.save(file.name, file.file, save=False)
        return file

    def contribute_to_class(self, cls, name, **kwargs):
        super().contribute_to_class(cls, name, **kwargs)
        setattr(cls, self.attname, self.descriptor_class(self))

    def generate_filename(self, instance, filename):
        """
        Apply (if callable) or prepend (if a string) upload_to to the filename,
        then delegate further processing of the name to the storage backend.
        Until the storage layer, all file paths are expected to be Unix style
        (with forward slashes).
        """
        if callable(self.upload_to):
            filename = self.upload_to(instance, filename)
        else:
            dirname = datetime.datetime.now().strftime(str(self.upload_to))
            filename = posixpath.join(dirname, filename)
        filename = validate_file_name(filename, allow_relative_path=True)
        return self.storage.generate_filename(filename)

    def save_form_data(self, instance, data):
        # Important: None means "no change", other false value means "clear"
        # This subtle distinction (rather than a more explicit marker) is
        # needed because we need to consume values that are also sane for a
        # regular (non Model-) Form to find in its cleaned_data dictionary.
        if data is not None:
            # This value will be converted to str and stored in the
            # database, so leaving False as-is is not acceptable.
            setattr(instance, self.name, data or "")

    def formfield(self, **kwargs):
        return super().formfield(
            **{
                "form_class": forms.FileField,
                "max_length": self.max_length,
                **kwargs,
            }
        )


class ImageFileDescriptor(FileDescriptor):
    """
    Just like the FileDescriptor, but for ImageFields. The only difference is
    assigning the width/height to the width_field/height_field, if appropriate.
    """

    def __set__(self, instance, value):
        previous_file = instance.__dict__.get(self.field.attname)
        super().__set__(instance, value)

        # To prevent recalculating image dimensions when we are instantiating
        # an object from the database (bug #11084), only update dimensions if
        # the field had a value before this assignment. Since the default
        # value for FileField subclasses is an instance of field.attr_class,
        # previous_file will only be None when we are called from
        # Model.__init__(). The ImageField.update_dimension_fields method
        # hooked up to the post_init signal handles the Model.__init__() cases.
        # Assignment happening outside of Model.__init__() will trigger the
        # update right here.
        if previous_file is not None:
            self.field.update_dimension_fields(instance, force=True)


class ImageFieldFile(ImageFile, FieldFile):
    def _set_instance_attribute(self, name, content):
        setattr(self.instance, self.field.attname, content)
        # Update the name in case generate_filename() or storage.save() changed
        # it, but bypass the descriptor to avoid re-reading the file.
        self.instance.__dict__[self.field.attname] = self.name

    def delete(self, save=True):
        # Clear the image dimensions cache
        if hasattr(self, "_dimensions_cache"):
            del self._dimensions_cache
        super().delete(save)


class ImageField(FileField):
    attr_class = ImageFieldFile
    descriptor_class = ImageFileDescriptor
    description = _("Image")

    def __init__(
        self,
        verbose_name=None,
        name=None,
        width_field=None,
        height_field=None,
        **kwargs,
    ):
        self.width_field, self.height_field = width_field, height_field
        super().__init__(verbose_name, name, **kwargs)

    def check(self, **kwargs):
        return [
            *super().check(**kwargs),
            *self._check_image_library_installed(),
        ]

    def _check_image_library_installed(self):
        try:
            from PIL import Image  # NOQA
        except ImportError:
            return [
                checks.Error(
                    "Cannot use ImageField because Pillow is not installed.",
                    hint=(
                        "Get Pillow at https://pypi.org/project/Pillow/ "
                        'or run command "python -m pip install Pillow".'
                    ),
                    obj=self,
                    id="fields.E210",
                )
            ]
        else:
            return []

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        if self.width_field:
            kwargs["width_field"] = self.width_field
        if self.height_field:
            kwargs["height_field"] = self.height_field
        return name, path, args, kwargs

    def contribute_to_class(self, cls, name, **kwargs):
        super().contribute_to_class(cls, name, **kwargs)
        # Attach update_dimension_fields so that dimension fields declared
        # after their corresponding image field don't stay cleared by
        # Model.__init__, see bug #11196.
        # Only run post-initialization dimension update on non-abstract models
        # with width_field/height_field.
        if not cls._meta.abstract and (self.width_field or self.height_field):
            signals.post_init.connect(self.update_dimension_fields, sender=cls)

    def update_dimension_fields(self, instance, force=False, *args, **kwargs):
        """
        Update field's width and height fields, if defined.

        This method is hooked up to model's post_init signal to update
        dimensions after instantiating a model instance. However, dimensions
        won't be updated if the dimensions fields are already populated. This
        avoids unnecessary recalculation when loading an object from the
        database.

        Dimensions can be forced to update with force=True, which is how
        ImageFileDescriptor.__set__ calls this method.
        """
        # Nothing to update if the field doesn't have dimension fields or if
        # the field is deferred.
        has_dimension_fields = self.width_field or self.height_field
        if not has_dimension_fields or self.attname not in instance.__dict__:
            return

        # getattr will call the ImageFileDescriptor's __get__ method, which
        # coerces the assigned value into an instance of self.attr_class
        # (ImageFieldFile in this case).
        file = getattr(instance, self.attname)

        # Nothing to update if we have no file and not being forced to update.
        if not file and not force:
            return

        dimension_fields_filled = not (
            (self.width_field and not getattr(instance, self.width_field))
            or (self.height_field and not getattr(instance, self.height_field))
        )
        # When both dimension fields have values, we are most likely loading
        # data from the database or updating an image field that already had
        # an image stored. In the first case, we don't want to update the
        # dimension fields because we are already getting their values from the
        # database. In the second case, we do want to update the dimensions
        # fields and will skip this return because force will be True since we
        # were called from ImageFileDescriptor.__set__.
        if dimension_fields_filled and not force:
            return

        # file should be an instance of ImageFieldFile or should be None.
        if file:
            width = file.width
            height = file.height
        else:
            # No file, so clear dimensions fields.
            width = None
            height = None

        # Update the width and height fields.
        if self.width_field:
            setattr(instance, self.width_field, width)
        if self.height_field:
            setattr(instance, self.height_field, height)

    def formfield(self, **kwargs):
        return super().formfield(
            **{
                "form_class": forms.ImageField,
                **kwargs,
            }
        )
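# Example (a sketch, not part of this commit): driving the FieldFile API
# defined above. `myapp.Document` with
# `attachment = models.FileField(upload_to="docs/")` is hypothetical, and a
# default storage backend is assumed to be configured.
from django.core.files.base import ContentFile
from myapp.models import Document

doc = Document.objects.get(pk=1)
# save() commits the content to storage and, with save=True, updates the row.
doc.attachment.save("notes.txt", ContentFile(b"hello"), save=True)
print(doc.attachment.name, doc.attachment.size)
# delete() removes the stored file, clears the field, and saves the instance.
doc.attachment.delete(save=True)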
@@ -0,0 +1,197 @@
from django.core import checks
from django.db import connections, router
from django.db.models.sql import Query
from django.utils.functional import cached_property

from . import NOT_PROVIDED, Field

__all__ = ["GeneratedField"]


class GeneratedField(Field):
    generated = True
    db_returning = True

    _query = None
    output_field = None

    def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
        if kwargs.setdefault("editable", False):
            raise ValueError("GeneratedField cannot be editable.")
        if not kwargs.setdefault("blank", True):
            raise ValueError("GeneratedField must be blank.")
        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("GeneratedField cannot have a default.")
        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("GeneratedField cannot have a database default.")
        if db_persist not in (True, False):
            raise ValueError("GeneratedField.db_persist must be True or False.")

        self.expression = expression
        self.output_field = output_field
        self.db_persist = db_persist
        super().__init__(**kwargs)

    @cached_property
    def cached_col(self):
        from django.db.models.expressions import Col

        return Col(self.model._meta.db_table, self, self.output_field)

    def get_col(self, alias, output_field=None):
        if alias != self.model._meta.db_table and output_field in (None, self):
            output_field = self.output_field
        return super().get_col(alias, output_field)

    def contribute_to_class(self, *args, **kwargs):
        super().contribute_to_class(*args, **kwargs)

        self._query = Query(model=self.model, alias_cols=False)
        # Register lookups from the output_field class.
        for lookup_name, lookup in self.output_field.get_class_lookups().items():
            self.register_lookup(lookup, lookup_name=lookup_name)

    def generated_sql(self, connection):
        compiler = connection.ops.compiler("SQLCompiler")(
            self._query, connection=connection, using=None
        )
        resolved_expression = self.expression.resolve_expression(
            self._query, allow_joins=False
        )
        sql, params = compiler.compile(resolved_expression)
        if (
            getattr(self.expression, "conditional", False)
            and not connection.features.supports_boolean_expr_in_select_clause
        ):
            sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
        return sql, params

    def check(self, **kwargs):
        databases = kwargs.get("databases") or []
        errors = [
            *super().check(**kwargs),
            *self._check_supported(databases),
            *self._check_persistence(databases),
        ]
        output_field_clone = self.output_field.clone()
        output_field_clone.model = self.model
        output_field_checks = output_field_clone.check(databases=databases)
        if output_field_checks:
            separator = "\n    "
            error_messages = separator.join(
                f"{output_check.msg} ({output_check.id})"
                for output_check in output_field_checks
                if isinstance(output_check, checks.Error)
            )
            if error_messages:
                errors.append(
                    checks.Error(
                        "GeneratedField.output_field has errors:"
                        f"{separator}{error_messages}",
                        obj=self,
                        id="fields.E223",
                    )
                )
            warning_messages = separator.join(
                f"{output_check.msg} ({output_check.id})"
                for output_check in output_field_checks
                if isinstance(output_check, checks.Warning)
            )
            if warning_messages:
                errors.append(
                    checks.Warning(
                        "GeneratedField.output_field has warnings:"
                        f"{separator}{warning_messages}",
                        obj=self,
                        id="fields.W224",
                    )
                )
        return errors

    def _check_supported(self, databases):
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not (
                connection.features.supports_virtual_generated_columns
                or "supports_virtual_generated_columns"
                in self.model._meta.required_db_features
            ) and not (
                connection.features.supports_stored_generated_columns
                or "supports_stored_generated_columns"
                in self.model._meta.required_db_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support GeneratedFields.",
                        obj=self,
                        id="fields.E220",
                    )
                )
        return errors

    def _check_persistence(self, databases):
        errors = []
        for db in databases:
            if not router.allow_migrate_model(db, self.model):
                continue
            connection = connections[db]
            if (
                self.model._meta.required_db_vendor
                and self.model._meta.required_db_vendor != connection.vendor
            ):
                continue
            if not self.db_persist and not (
                connection.features.supports_virtual_generated_columns
                or "supports_virtual_generated_columns"
                in self.model._meta.required_db_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support non-persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E221",
                        hint="Set db_persist=True on the field.",
                    )
                )
            if self.db_persist and not (
                connection.features.supports_stored_generated_columns
                or "supports_stored_generated_columns"
                in self.model._meta.required_db_features
            ):
                errors.append(
                    checks.Error(
                        f"{connection.display_name} does not support persisted "
                        "GeneratedFields.",
                        obj=self,
                        id="fields.E222",
                        hint="Set db_persist=False on the field.",
                    )
                )
        return errors

    def deconstruct(self):
        name, path, args, kwargs = super().deconstruct()
        del kwargs["blank"]
        del kwargs["editable"]
        kwargs["db_persist"] = self.db_persist
        kwargs["expression"] = self.expression
        kwargs["output_field"] = self.output_field
        return name, path, args, kwargs

    def get_internal_type(self):
        return self.output_field.get_internal_type()

    def db_parameters(self, connection):
        return self.output_field.db_parameters(connection)

    def db_type_parameters(self, connection):
        return self.output_field.db_type_parameters(connection)
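# Example (a sketch, not part of this commit): a generated column whose value
# the database computes from `expression`; db_persist selects stored (True)
# versus virtual (False) per the checks above. `Square` is hypothetical.
from django.db import models
from django.db.models import F

class Square(models.Model):
    side = models.IntegerField()
    area = models.GeneratedField(
        expression=F("side") * F("side"),
        output_field=models.BigIntegerField(),
        db_persist=True,  # required: must be exactly True or False
    )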
@@ -0,0 +1,664 @@
|
||||
import json
|
||||
|
||||
from django import forms
|
||||
from django.core import checks, exceptions
|
||||
from django.db import NotSupportedError, connections, router
|
||||
from django.db.models import expressions, lookups
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.lookups import (
|
||||
FieldGetDbPrepValueMixin,
|
||||
PostgresOperatorLookup,
|
||||
Transform,
|
||||
)
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("A JSON object")
|
||||
default_error_messages = {
|
||||
"invalid": _("Value must be valid JSON."),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError("The encoder parameter must be a callable object.")
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError("The decoder parameter must be a callable object.")
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get("databases") or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
"supports_json_field" in self.model._meta.required_db_features
|
||||
or connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"%s does not support JSONFields." % connection.display_name,
|
||||
obj=self.model,
|
||||
id="fields.E180",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs["encoder"] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs["decoder"] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
# Some backends (SQLite at least) extract non-string values in their
|
||||
# SQL datatypes.
|
||||
if isinstance(expression, KeyTransform) and not isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "JSONField"
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not prepared:
|
||||
value = self.get_prep_value(value)
|
||||
return connection.ops.adapt_json_value(value, self.encoder)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
# This slightly involved logic is to allow for `None` to be used to
|
||||
# store SQL `NULL` while `Value(None, JSONField())` can be used to
|
||||
# store JSON `null` while preventing compilable `as_sql` values from
|
||||
# making their way to `get_db_prep_value`, which is what the `super()`
|
||||
# implementation does.
|
||||
if value is None:
|
||||
return value
|
||||
if (
|
||||
isinstance(value, expressions.Value)
|
||||
and value.value is None
|
||||
and isinstance(value.output_field, JSONField)
|
||||
):
|
||||
value = None
|
||||
return super().get_db_prep_save(value, connection)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
try:
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["invalid"],
|
||||
code="invalid",
|
||||
params={"value": value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.JSONField,
|
||||
"encoder": self.encoder,
|
||||
"decoder": self.decoder,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def compile_json_path(key_transforms, include_root=True):
|
||||
path = ["$"] if include_root else []
|
||||
for key_transform in key_transforms:
|
||||
try:
|
||||
num = int(key_transform)
|
||||
except ValueError: # non-integer
|
||||
path.append(".")
|
||||
path.append(json.dumps(key_transform))
|
||||
else:
|
||||
path.append("[%s]" % num)
|
||||
return "".join(path)
|
||||
|
||||
|
||||
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contains"
|
||||
postgres_operator = "@>"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contains lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
postgres_operator = "<@"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contained_by lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
logical_operator = None
|
||||
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
# Compile the final key without interpreting ints as array elements.
|
||||
return ".%s" % json.dumps(key_transform)
|
||||
|
||||
def _as_sql_parts(self, compiler, connection):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs_sql, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
|
||||
compiler, connection
|
||||
)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs_sql, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = "$"
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
if not isinstance(rhs, (list, tuple)):
|
||||
rhs = [rhs]
|
||||
for key in rhs:
|
||||
if isinstance(key, KeyTransform):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
*rhs_key_transforms, final_key = rhs_key_transforms
|
||||
rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
|
||||
rhs_json_path += self.compile_json_path_final_key(final_key)
|
||||
yield lhs_sql, lhs_params, lhs_json_path + rhs_json_path
|
||||
|
||||
def _combine_sql_parts(self, parts):
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
return "(%s)" % self.logical_operator.join(parts)
|
||||
return "".join(parts)
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
sql_parts = []
|
||||
params = []
|
||||
for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
|
||||
compiler, connection
|
||||
):
|
||||
sql_parts.append(template % (lhs_sql, "%s"))
|
||||
params.extend(lhs_params + [rhs_json_path])
|
||||
return self._combine_sql_parts(sql_parts), tuple(params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %s)"
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Use a custom delimiter to prevent the JSON path from escaping the SQL
|
||||
# literal. See comment in KeyTransform.
|
||||
template = "JSON_EXISTS(%s, q'\uffff%s\uffff')"
|
||||
sql_parts = []
|
||||
params = []
|
||||
for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
|
||||
compiler, connection
|
||||
):
|
||||
# Add right-hand-side directly into SQL because it cannot be passed
|
||||
# as bind variables to JSON_EXISTS. It might result in invalid
|
||||
# queries but it is assumed that it cannot be evaded because the
|
||||
# path is JSON serialized.
|
||||
sql_parts.append(template % (lhs_sql, rhs_json_path))
|
||||
params.extend(lhs_params)
|
||||
return self._combine_sql_parts(sql_parts), tuple(params)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
|
||||
for key in rhs_key_transforms[:-1]:
|
||||
self.lhs = KeyTransform(key, self.lhs)
|
||||
self.rhs = rhs_key_transforms[-1]
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_TYPE(%s, %s) IS NOT NULL"
|
||||
)
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = "has_key"
|
||||
postgres_operator = "?"
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = "has_keys"
|
||||
postgres_operator = "?&"
|
||||
logical_operator = " AND "
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = "has_any_keys"
|
||||
postgres_operator = "?|"
|
||||
logical_operator = " OR "
|
||||
|
||||
|
||||
class HasKeyOrArrayIndex(HasKey):
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
return compile_json_path([key_transform], include_root=False)
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == "%s" and rhs_params == [None]:
|
||||
rhs_params = ["null"]
|
||||
if connection.vendor == "mysql":
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
lhs = f"JSON({lhs})"
|
||||
rhs = f"JSON({rhs})"
|
||||
return f"JSON_EQUAL({lhs}, {rhs} ERROR ON ERROR)", (*lhs_params, *rhs_params)
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = str(key_name)
|
||||
|
||||
def preprocess_lhs(self, compiler, connection):
|
||||
key_transforms = [self.key_name]
|
||||
previous = self.lhs
|
||||
while isinstance(previous, KeyTransform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == "oracle":
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace("%", "%%") for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
|
||||
")"
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff')"
|
||||
")"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle. Use a custom delimiter to prevent the
|
||||
# JSON path from escaping the SQL literal. Each key in the JSON path is
|
||||
# passed through json.dumps() with ensure_ascii=True (the default),
|
||||
# which converts the delimiter into the escaped \uffff format. This
|
||||
# ensures that the delimiter is not present in the JSON path.
|
||||
return sql % ((lhs, json_path) * 2), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = "->>"
|
||||
postgres_nested_operator = "#>>"
|
||||
output_field = TextField()
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
if connection.mysql_is_mariadb:
|
||||
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
|
||||
sql, params = super().as_mysql(compiler, connection)
|
||||
return "JSON_UNQUOTE(%s)" % sql, params
|
||||
else:
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
@classmethod
|
||||
def from_lookup(cls, lookup):
|
||||
transform, *keys = lookup.split(LOOKUP_SEP)
|
||||
if not keys:
|
||||
raise ValueError("Lookup must contain key or index transforms.")
|
||||
for key in keys:
|
||||
transform = cls(key, transform)
|
||||
return transform
|
||||
|
||||
|
||||
KT = KeyTextTransform.from_lookup
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
"""
|
||||
Mixin for combining with a lookup expecting a text lhs from a JSONField
|
||||
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
"Transform should be an instance of KeyTransform in order to "
|
||||
"use this lookup."
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name,
|
||||
*key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = HasKeyOrArrayIndex(
|
||||
self.lhs.lhs,
|
||||
self.lhs.key_name,
|
||||
).as_oracle(compiler, connection)
|
||||
if not self.rhs:
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = "JSON_TYPE(%s, %s) IS NULL"
|
||||
if not self.rhs:
|
||||
template = "JSON_TYPE(%s, %s) IS NOT NULL"
|
||||
return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
    def resolve_expression_parameter(self, compiler, connection, sql, param):
        sql, params = super().resolve_expression_parameter(
            compiler,
            connection,
            sql,
            param,
        )
        if (
            not hasattr(param, "as_sql")
            and not connection.features.has_native_json_field
        ):
            if connection.vendor == "oracle":
                value = json.loads(param)
                sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
                if isinstance(value, (list, dict)):
                    sql %= "JSON_QUERY"
                else:
                    sql %= "JSON_VALUE"
            elif connection.vendor == "mysql" or (
                connection.vendor == "sqlite"
                and params[0] not in connection.ops.jsonfield_datatype_values
            ):
                sql = "JSON_EXTRACT(%s, '$')"
        if connection.vendor == "mysql" and connection.mysql_is_mariadb:
            sql = "JSON_UNQUOTE(%s)" % sql
        return sql, params


class KeyTransformExact(JSONExact):
    def process_rhs(self, compiler, connection):
        if isinstance(self.rhs, KeyTransform):
            return super(lookups.Exact, self).process_rhs(compiler, connection)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if connection.vendor == "oracle":
            func = []
            sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
            for value in rhs_params:
                value = json.loads(value)
                if isinstance(value, (list, dict)):
                    func.append(sql % "JSON_QUERY")
                else:
                    func.append(sql % "JSON_VALUE")
            rhs %= tuple(func)
        elif connection.vendor == "sqlite":
            func = []
            for value in rhs_params:
                if value in connection.ops.jsonfield_datatype_values:
                    func.append("%s")
                else:
                    func.append("JSON_EXTRACT(%s, '$')")
            rhs %= tuple(func)
        return rhs, rhs_params

    def as_oracle(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if rhs_params == ["null"]:
            # Field has key and it's NULL.
            has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
            has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
            is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
            is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
            return (
                "%s AND %s" % (has_key_sql, is_null_sql),
                tuple(has_key_params) + tuple(is_null_params),
            )
        return super().as_sql(compiler, connection)


class KeyTransformIExact(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
):
    pass


class KeyTransformIContains(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
):
    pass


class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
    pass


class KeyTransformIStartsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
):
    pass


class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
    pass


class KeyTransformIEndsWith(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
):
    pass


class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
    pass


class KeyTransformIRegex(
    CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
):
    pass


class KeyTransformNumericLookupMixin:
    def process_rhs(self, compiler, connection):
        rhs, rhs_params = super().process_rhs(compiler, connection)
        if not connection.features.has_native_json_field:
            rhs_params = [json.loads(value) for value in rhs_params]
        return rhs, rhs_params


class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
    pass


class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
    pass


class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
    pass


class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
    pass


KeyTransform.register_lookup(KeyTransformIn)
KeyTransform.register_lookup(KeyTransformExact)
KeyTransform.register_lookup(KeyTransformIExact)
KeyTransform.register_lookup(KeyTransformIsNull)
KeyTransform.register_lookup(KeyTransformIContains)
KeyTransform.register_lookup(KeyTransformStartsWith)
KeyTransform.register_lookup(KeyTransformIStartsWith)
KeyTransform.register_lookup(KeyTransformEndsWith)
KeyTransform.register_lookup(KeyTransformIEndsWith)
KeyTransform.register_lookup(KeyTransformRegex)
KeyTransform.register_lookup(KeyTransformIRegex)

KeyTransform.register_lookup(KeyTransformLt)
KeyTransform.register_lookup(KeyTransformLte)
KeyTransform.register_lookup(KeyTransformGt)
KeyTransform.register_lookup(KeyTransformGte)


class KeyTransformFactory:
    def __init__(self, key_name):
        self.key_name = key_name

    def __call__(self, *args, **kwargs):
        return KeyTransform(self.key_name, *args, **kwargs)
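
# Usage sketch (illustrative comment, not part of the module): each lookup
# registered above powers one JSONField key-transform query. Assuming a
# hypothetical model with ``data = models.JSONField()``:
#
#     Dog.objects.filter(data__owner__name__icontains="bob")  # KeyTransformIContains
#     Dog.objects.filter(data__age__lt=10)                    # KeyTransformLt
#
# On backends without a native JSON column type, the numeric lookups rely on
# KeyTransformNumericLookupMixin to json.loads() the right-hand side first.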
@@ -0,0 +1,81 @@
import warnings

from django.core import checks
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property

NOT_PROVIDED = object()


class FieldCacheMixin:
    """
    An API for working with the model's fields value cache.

    Subclasses must set self.cache_name to a unique entry for the cache -
    typically the field's name.
    """

    # RemovedInDjango60Warning.
    def get_cache_name(self):
        raise NotImplementedError

    @cached_property
    def cache_name(self):
        # RemovedInDjango60Warning: when the deprecation ends, replace with:
        # raise NotImplementedError
        cache_name = self.get_cache_name()
        warnings.warn(
            f"Override {self.__class__.__qualname__}.cache_name instead of "
            "get_cache_name().",
            RemovedInDjango60Warning,
            stacklevel=3,
        )
        return cache_name

    def get_cached_value(self, instance, default=NOT_PROVIDED):
        try:
            return instance._state.fields_cache[self.cache_name]
        except KeyError:
            if default is NOT_PROVIDED:
                raise
            return default

    def is_cached(self, instance):
        return self.cache_name in instance._state.fields_cache

    def set_cached_value(self, instance, value):
        instance._state.fields_cache[self.cache_name] = value

    def delete_cached_value(self, instance):
        del instance._state.fields_cache[self.cache_name]
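
# Usage sketch (illustrative comment, not part of the module): related-field
# descriptors use this mixin to memoize fetched objects per instance, e.g.:
#
#     field = Book._meta.get_field("author")   # hypothetical Book/Author models
#     if not field.is_cached(book):
#         field.set_cached_value(book, author)  # fills book._state.fields_cache
#     field.get_cached_value(book)              # cache hit, no extra query
#
# get_cached_value() raises KeyError on a miss unless a default is given.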
class CheckFieldDefaultMixin:
    _default_hint = ("<valid default>", "<invalid default>")

    def _check_default(self):
        if (
            self.has_default()
            and self.default is not None
            and not callable(self.default)
        ):
            return [
                checks.Warning(
                    "%s default should be a callable instead of an instance "
                    "so that it's not shared between all field instances."
                    % (self.__class__.__name__,),
                    hint=(
                        "Use a callable instead, e.g., use `%s` instead of "
                        "`%s`." % self._default_hint
                    ),
                    obj=self,
                    id="fields.E010",
                )
            ]
        else:
            return []

    def check(self, **kwargs):
        errors = super().check(**kwargs)
        errors.extend(self._check_default())
        return errors
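
# Usage sketch (illustrative comment, not part of the module): fields using
# this mixin flag shared mutable defaults at system-check time, e.g.:
#
#     data = models.JSONField(default={})    # emits the fields.E010 warning above
#     data = models.JSONField(default=dict)  # OK: callable, fresh dict per row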
@@ -0,0 +1,18 @@
"""
Field-like classes that aren't really fields. It's easier to use objects that
have the same attributes as fields sometimes (avoids a lot of special casing).
"""

from django.db.models import fields


class OrderWrt(fields.IntegerField):
    """
    A proxy for the _order database field that is used when
    Meta.order_with_respect_to is specified.
    """

    def __init__(self, *args, **kwargs):
        kwargs["name"] = "_order"
        kwargs["editable"] = False
        super().__init__(*args, **kwargs)
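
# Usage sketch (illustrative comment, not part of the module): declaring
#
#     class Answer(models.Model):
#         question = models.ForeignKey(Question, on_delete=models.CASCADE)
#
#         class Meta:
#             order_with_respect_to = "question"
#
# makes Django add a hidden integer column named ``_order``, represented in
# the model's _meta by this OrderWrt proxy field.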
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,151 @@
from django.db.models.expressions import ColPairs
from django.db.models.fields import composite
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
from django.db.models.lookups import (
    Exact,
    GreaterThan,
    GreaterThanOrEqual,
    In,
    IsNull,
    LessThan,
    LessThanOrEqual,
)


def get_normalized_value(value, lhs):
    from django.db.models import Model

    if isinstance(value, Model):
        if not value._is_pk_set():
            raise ValueError("Model instances passed to related filters must be saved.")
        value_list = []
        sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields)
        for source in sources:
            while not isinstance(value, source.model) and source.remote_field:
                source = source.remote_field.model._meta.get_field(
                    source.remote_field.field_name
                )
            try:
                value_list.append(getattr(value, source.attname))
            except AttributeError:
                # A case like Restaurant.objects.filter(place=restaurant_instance),
                # where place is a OneToOneField and the primary key of Restaurant.
                pk = value.pk
                return pk if isinstance(pk, tuple) else (pk,)
        return tuple(value_list)
    if not isinstance(value, tuple):
        return (value,)
    return value
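
# Behavior sketch (illustrative comment, not part of the module):
#
#     get_normalized_value(author_instance, lhs)  # -> (author_instance.pk,)
#     get_normalized_value(42, lhs)               # -> (42,)
#     get_normalized_value((1, 2), lhs)           # -> (1, 2) unchanged
#
# i.e. every right-hand value is normalized to a tuple of raw column values,
# and unsaved model instances are rejected with ValueError.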
class RelatedIn(In):
    def get_prep_lookup(self):
        from django.db.models.sql.query import Query  # avoid circular import

        if isinstance(self.lhs, ColPairs):
            if (
                isinstance(self.rhs, Query)
                and not self.rhs.has_select_fields
                and self.lhs.output_field.related_model is self.rhs.model
            ):
                self.rhs.set_values([f.name for f in self.lhs.sources])
        else:
            if self.rhs_is_direct_value():
                # If we get here, we are dealing with single-column relations.
                self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
                # We need to run the related field's get_prep_value(). Consider
                # case ForeignKey to IntegerField given value 'abc'. The
                # ForeignKey itself doesn't have validation for non-integers,
                # so we must run validation using the target field.
                if hasattr(self.lhs.output_field, "path_infos"):
                    # Run the target field's get_prep_value. We can safely
                    # assume there is only one as we don't get to the direct
                    # value branch otherwise.
                    target_field = self.lhs.output_field.path_infos[-1].target_fields[
                        -1
                    ]
                    self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
            elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
                self.lhs.field.target_field, "primary_key", False
            ):
                if (
                    getattr(self.lhs.output_field, "primary_key", False)
                    and self.lhs.output_field.model == self.rhs.model
                ):
                    # A case like
                    # Restaurant.objects.filter(place__in=restaurant_qs), where
                    # place is a OneToOneField and the primary key of
                    # Restaurant.
                    target_field = self.lhs.field.name
                else:
                    target_field = self.lhs.field.target_field.name
                self.rhs.set_values([target_field])
        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, ColPairs):
            if self.rhs_is_direct_value():
                values = [get_normalized_value(value, self.lhs) for value in self.rhs]
                lookup = TupleIn(self.lhs, values)
            else:
                lookup = TupleIn(self.lhs, self.rhs)
            return compiler.compile(lookup)

        return super().as_sql(compiler, connection)


class RelatedLookupMixin:
    def get_prep_lookup(self):
        if not isinstance(self.lhs, ColPairs) and not hasattr(
            self.rhs, "resolve_expression"
        ):
            # If we get here, we are dealing with single-column relations.
            self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
            # We need to run the related field's get_prep_value(). Consider case
            # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
            # doesn't have validation for non-integers, so we must run validation
            # using the target field.
            if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
                # Get the target field. We can safely assume there is only one
                # as we don't get to the direct value branch otherwise.
                target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
                self.rhs = target_field.get_prep_value(self.rhs)

        return super().get_prep_lookup()

    def as_sql(self, compiler, connection):
        if isinstance(self.lhs, ColPairs):
            if not self.rhs_is_direct_value():
                raise ValueError(
                    f"'{self.lookup_name}' doesn't support multi-column subqueries."
                )
            self.rhs = get_normalized_value(self.rhs, self.lhs)
            lookup_class = tuple_lookups[self.lookup_name]
            lookup = lookup_class(self.lhs, self.rhs)
            return compiler.compile(lookup)

        return super().as_sql(compiler, connection)


class RelatedExact(RelatedLookupMixin, Exact):
    pass


class RelatedLessThan(RelatedLookupMixin, LessThan):
    pass


class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
    pass


class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
    pass


class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
    pass


class RelatedIsNull(RelatedLookupMixin, IsNull):
    pass
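
# Usage sketch (illustrative comment, not part of the module): these lookups
# let related fields accept model instances directly, e.g.:
#
#     Book.objects.filter(author=author)        # RelatedExact -> author.pk
#     Book.objects.filter(author__in=[a1, a2])  # RelatedIn -> [a1.pk, a2.pk]
#
# For composite (multi-column) relations, ColPairs left-hand sides are
# rewritten to the tuple lookups from tuple_lookups.py instead.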
@@ -0,0 +1,416 @@
"""
"Rel objects" for related fields.

"Rel objects" (for lack of a better name) carry information about the relation
modeled by a related field and provide some utility functions. They're stored
in the ``remote_field`` attribute of the field.

They also act as reverse fields for the purposes of the Meta API because
they're the closest concept currently available.
"""

import warnings

from django.core import exceptions
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property
from django.utils.hashable import make_hashable

from . import BLANK_CHOICE_DASH
from .mixins import FieldCacheMixin


class ForeignObjectRel(FieldCacheMixin):
    """
    Used by ForeignObject to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.
    """

    # Field flags
    auto_created = True
    concrete = False
    editable = False
    is_relation = True

    # Reverse relations are always nullable (Django can't enforce that a
    # foreign key on the related model points to this model).
    null = True
    empty_strings_allowed = False

    def __init__(
        self,
        field,
        to,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        parent_link=False,
        on_delete=None,
    ):
        self.field = field
        self.model = to
        self.related_name = related_name
        self.related_query_name = related_query_name
        self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
        self.parent_link = parent_link
        self.on_delete = on_delete

        self.symmetrical = False
        self.multiple = True

    # Some of the following cached_properties can't be initialized in
    # __init__ as the field doesn't have its model yet. Calling these methods
    # before field.contribute_to_class() has been called will result in
    # AttributeError
    @cached_property
    def hidden(self):
        """Should the related object be hidden?"""
        return bool(self.related_name) and self.related_name[-1] == "+"

    @cached_property
    def name(self):
        return self.field.related_query_name()

    @property
    def remote_field(self):
        return self.field

    @property
    def target_field(self):
        """
        When filtering against this relation, return the field on the remote
        model against which the filtering should happen.
        """
        target_fields = self.path_infos[-1].target_fields
        if len(target_fields) > 1:
            raise exceptions.FieldError(
                "Can't use target_field for multicolumn relations."
            )
        return target_fields[0]

    @cached_property
    def related_model(self):
        if not self.field.model:
            raise AttributeError(
                "This property can't be accessed before self.field.contribute_to_class "
                "has been called."
            )
        return self.field.model

    @cached_property
    def many_to_many(self):
        return self.field.many_to_many

    @cached_property
    def many_to_one(self):
        return self.field.one_to_many

    @cached_property
    def one_to_many(self):
        return self.field.many_to_one

    @cached_property
    def one_to_one(self):
        return self.field.one_to_one

    def get_lookup(self, lookup_name):
        return self.field.get_lookup(lookup_name)

    def get_lookups(self):
        return self.field.get_lookups()

    def get_transform(self, name):
        return self.field.get_transform(name)

    def get_internal_type(self):
        return self.field.get_internal_type()

    @property
    def db_type(self):
        return self.field.db_type

    def __repr__(self):
        return "<%s: %s.%s>" % (
            type(self).__name__,
            self.related_model._meta.app_label,
            self.related_model._meta.model_name,
        )

    @property
    def identity(self):
        return (
            self.field,
            self.model,
            self.related_name,
            self.related_query_name,
            make_hashable(self.limit_choices_to),
            self.parent_link,
            self.on_delete,
            self.symmetrical,
            self.multiple,
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.identity == other.identity

    def __hash__(self):
        return hash(self.identity)

    def __getstate__(self):
        state = self.__dict__.copy()
        # Delete the path_infos cached property because it can be recalculated
        # at first invocation after deserialization. The attribute must be
        # removed because subclasses like ManyToOneRel may have a PathInfo
        # which contains an intermediate M2M table that's been dynamically
        # created and doesn't exist in the .models module.
        # This is a reverse relation, so there is no reverse_path_infos to
        # delete.
        state.pop("path_infos", None)
        return state

    def get_choices(
        self,
        include_blank=True,
        blank_choice=BLANK_CHOICE_DASH,
        limit_choices_to=None,
        ordering=(),
    ):
        """
        Return choices with a default blank choices included, for use
        as <select> choices for this field.

        Analog of django.db.models.fields.Field.get_choices(), provided
        initially for utilization by RelatedFieldListFilter.
        """
        limit_choices_to = limit_choices_to or self.limit_choices_to
        qs = self.related_model._default_manager.complex_filter(limit_choices_to)
        if ordering:
            qs = qs.order_by(*ordering)
        return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]

    def get_joining_columns(self):
        warnings.warn(
            "ForeignObjectRel.get_joining_columns() is deprecated. Use "
            "get_joining_fields() instead.",
            RemovedInDjango60Warning,
            stacklevel=2,
        )
        return self.field.get_reverse_joining_columns()

    def get_joining_fields(self):
        return self.field.get_reverse_joining_fields()

    def get_extra_restriction(self, alias, related_alias):
        return self.field.get_extra_restriction(related_alias, alias)

    def set_field_name(self):
        """
        Set the related field's name, this is not available until later stages
        of app loading, so set_field_name is called from
        set_attributes_from_rel()
        """
        # By default foreign object doesn't relate to any remote field (for
        # example custom multicolumn joins currently have no remote field).
        self.field_name = None

    @cached_property
    def accessor_name(self):
        return self.get_accessor_name()

    def get_accessor_name(self, model=None):
        # This method encapsulates the logic that decides what name to give an
        # accessor descriptor that retrieves related many-to-one or
        # many-to-many objects. It uses the lowercased object_name + "_set",
        # but this can be overridden with the "related_name" option. Due to
        # backwards compatibility ModelForms need to be able to provide an
        # alternate model. See BaseInlineFormSet.get_default_prefix().
        opts = model._meta if model else self.related_model._meta
        model = model or self.related_model
        if self.multiple:
            # If this is a symmetrical m2m relation on self, there is no
            # reverse accessor.
            if self.symmetrical and model == self.model:
                return None
        if self.related_name:
            return self.related_name
        return opts.model_name + ("_set" if self.multiple else "")

    def get_path_info(self, filtered_relation=None):
        if filtered_relation:
            return self.field.get_reverse_path_info(filtered_relation)
        else:
            return self.field.reverse_path_infos

    @cached_property
    def path_infos(self):
        return self.get_path_info()

    @cached_property
    def cache_name(self):
        """
        Return the name of the cache key to use for storing an instance of the
        forward model on the reverse model.
        """
        return self.accessor_name
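
# Usage sketch (illustrative comment, not part of the module): accessor_name
# is what the reverse manager is exposed as. For a hypothetical
#
#     class Book(models.Model):
#         author = models.ForeignKey(Author, on_delete=models.CASCADE)
#
# the default accessor is ``author_instance.book_set``; declaring
# ``related_name="books"`` changes it, and ``related_name="+"`` (see
# ``hidden`` above) disables the reverse accessor entirely.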
class ManyToOneRel(ForeignObjectRel):
    """
    Used by the ForeignKey field to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.

    Note: Because we somewhat abuse the Rel objects by using them as reverse
    fields we get the funny situation where
    ``ManyToOneRel.many_to_one == False`` and
    ``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
    ManyToOneRel class is a private API and there is work underway to turn
    reverse relations into actual fields.
    """

    def __init__(
        self,
        field,
        to,
        field_name,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        parent_link=False,
        on_delete=None,
    ):
        super().__init__(
            field,
            to,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            parent_link=parent_link,
            on_delete=on_delete,
        )

        self.field_name = field_name

    def __getstate__(self):
        state = super().__getstate__()
        state.pop("related_model", None)
        return state

    @property
    def identity(self):
        return super().identity + (self.field_name,)

    def get_related_field(self):
        """
        Return the Field in the 'to' object to which this relationship is tied.
        """
        field = self.model._meta.get_field(self.field_name)
        if not field.concrete:
            raise exceptions.FieldDoesNotExist(
                "No related field named '%s'" % self.field_name
            )
        return field

    def set_field_name(self):
        self.field_name = self.field_name or self.model._meta.pk.name


class OneToOneRel(ManyToOneRel):
    """
    Used by OneToOneField to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.
    """

    def __init__(
        self,
        field,
        to,
        field_name,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        parent_link=False,
        on_delete=None,
    ):
        super().__init__(
            field,
            to,
            field_name,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            parent_link=parent_link,
            on_delete=on_delete,
        )

        self.multiple = False


class ManyToManyRel(ForeignObjectRel):
    """
    Used by ManyToManyField to store information about the relation.

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation.
    """

    def __init__(
        self,
        field,
        to,
        related_name=None,
        related_query_name=None,
        limit_choices_to=None,
        symmetrical=True,
        through=None,
        through_fields=None,
        db_constraint=True,
    ):
        super().__init__(
            field,
            to,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
        )

        if through and not db_constraint:
            raise ValueError("Can't supply a through model and db_constraint=False")
        self.through = through

        if through_fields and not through:
            raise ValueError("Cannot specify through_fields without a through model")
        self.through_fields = through_fields

        self.symmetrical = symmetrical
        self.db_constraint = db_constraint

    @property
    def identity(self):
        return super().identity + (
            self.through,
            make_hashable(self.through_fields),
            self.db_constraint,
        )

    def get_related_field(self):
        """
        Return the field in the 'to' object to which this relationship is tied.
        Provided for symmetry with ManyToOneRel.
        """
        opts = self.through._meta
        if self.through_fields:
            field = opts.get_field(self.through_fields[0])
        else:
            for field in opts.fields:
                rel = getattr(field, "remote_field", None)
                if rel and rel.model == self.model:
                    break
        return field.foreign_related_fields[0]
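
# Usage sketch (illustrative comment, not part of the module): these Rel
# objects live on the forward field's ``remote_field`` attribute, e.g.:
#
#     rel = Book._meta.get_field("author").remote_field  # a ManyToOneRel
#     rel.model                 # Author (the "to" side)
#     rel.get_accessor_name()   # "book_set"
#
# Note the intentional flag quirk documented above: ManyToOneRel.one_to_many
# is True because the rel describes the *reverse* side of a ForeignKey.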
@@ -0,0 +1,359 @@
import itertools

from django.core.exceptions import EmptyResultSet
from django.db.models import Field
from django.db.models.expressions import (
    ColPairs,
    Func,
    ResolvedOuterRef,
    Subquery,
    Value,
)
from django.db.models.lookups import (
    Exact,
    GreaterThan,
    GreaterThanOrEqual,
    In,
    IsNull,
    LessThan,
    LessThanOrEqual,
)
from django.db.models.sql import Query
from django.db.models.sql.where import AND, OR, WhereNode


class Tuple(Func):
    allows_composite_expressions = True
    function = ""
    output_field = Field()

    def __len__(self):
        return len(self.source_expressions)

    def __iter__(self):
        return iter(self.source_expressions)
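
# Note (illustrative comment, not part of the module): with ``function = ""``,
# Func's default "%(function)s(%(expressions)s)" template renders a Tuple of
# columns a, b, c simply as "(a, b, c)", i.e. a SQL row value, which the
# lookups below can then compare as a single unit.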
class TupleLookupMixin:
    allows_composite_expressions = True

    def get_prep_lookup(self):
        if self.rhs_is_direct_value():
            self.check_rhs_is_tuple_or_list()
            self.check_rhs_length_equals_lhs_length()
        else:
            self.check_rhs_is_supported_expression()
            super().get_prep_lookup()
        return self.rhs

    def check_rhs_is_tuple_or_list(self):
        if not isinstance(self.rhs, (tuple, list)):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
            )

    def check_rhs_length_equals_lhs_length(self):
        len_lhs = len(self.lhs)
        if len_lhs != len(self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
            )

    def check_rhs_is_supported_expression(self):
        if not isinstance(self.rhs, (ResolvedOuterRef, Query)):
            lhs_str = self.get_lhs_str()
            rhs_cls = self.rhs.__class__.__name__
            raise ValueError(
                f"{self.lookup_name!r} subquery lookup of {lhs_str} "
                f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
            )

    def get_lhs_str(self):
        if isinstance(self.lhs, ColPairs):
            return repr(self.lhs.field.name)
        else:
            names = ", ".join(repr(f.name) for f in self.lhs)
            return f"({names})"

    def get_prep_lhs(self):
        if isinstance(self.lhs, (tuple, list)):
            return Tuple(*self.lhs)
        return super().get_prep_lhs()

    def process_lhs(self, compiler, connection, lhs=None):
        sql, params = super().process_lhs(compiler, connection, lhs)
        if not isinstance(self.lhs, Tuple):
            sql = f"({sql})"
        return sql, params

    def process_rhs(self, compiler, connection):
        if self.rhs_is_direct_value():
            args = [
                Value(val, output_field=col.output_field)
                for col, val in zip(self.lhs, self.rhs)
            ]
            return compiler.compile(Tuple(*args))
        else:
            sql, params = compiler.compile(self.rhs)
            if isinstance(self.rhs, ColPairs):
                return "(%s)" % sql, params
            elif isinstance(self.rhs, Query):
                return super().process_rhs(compiler, connection)
            else:
                raise ValueError(
                    "Composite field lookups only work with composite expressions."
                )

    def get_fallback_sql(self, compiler, connection):
        raise NotImplementedError(
            f"{self.__class__.__name__}.get_fallback_sql() must be implemented "
            f"for backends that don't have the supports_tuple_lookups feature enabled."
        )

    def as_sql(self, compiler, connection):
        if not connection.features.supports_tuple_lookups:
            return self.get_fallback_sql(compiler, connection)
        return super().as_sql(compiler, connection)


class TupleExact(TupleLookupMixin, Exact):
    def get_fallback_sql(self, compiler, connection):
        if isinstance(self.rhs, Query):
            return super(TupleLookupMixin, self).as_sql(compiler, connection)
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) == (x, y, z) as SQL:
        # WHERE a = x AND b = y AND c = z
        lookups = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)]
        root = WhereNode(lookups, connector=AND)

        return root.as_sql(compiler, connection)


class TupleIsNull(TupleLookupMixin, IsNull):
    def get_prep_lookup(self):
        rhs = self.rhs
        if isinstance(rhs, (tuple, list)) and len(rhs) == 1:
            rhs = rhs[0]
        if isinstance(rhs, bool):
            return rhs
        raise ValueError(
            "The QuerySet value for an isnull lookup must be True or False."
        )

    def as_sql(self, compiler, connection):
        # e.g.: (a, b, c) is None as SQL:
        # WHERE a IS NULL OR b IS NULL OR c IS NULL
        # e.g.: (a, b, c) is not None as SQL:
        # WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
        rhs = self.rhs
        lookups = [IsNull(col, rhs) for col in self.lhs]
        root = WhereNode(lookups, connector=OR if rhs else AND)
        return root.as_sql(compiler, connection)


class TupleGreaterThan(TupleLookupMixin, GreaterThan):
    def get_fallback_sql(self, compiler, connection):
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) > (x, y, z) as SQL:
        # WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
        lookups = itertools.cycle([GreaterThan, Exact])
        connectors = itertools.cycle([OR, AND])
        cols_list = [col for col in self.lhs for _ in range(2)]
        vals_list = [val for val in self.rhs for _ in range(2)]
        cols_iter = iter(cols_list[:-1])
        vals_iter = iter(vals_list[:-1])
        col = next(cols_iter)
        val = next(vals_iter)
        lookup = next(lookups)
        connector = next(connectors)
        root = node = WhereNode([lookup(col, val)], connector=connector)

        for col, val in zip(cols_iter, vals_iter):
            lookup = next(lookups)
            connector = next(connectors)
            child = WhereNode([lookup(col, val)], connector=connector)
            node.children.append(child)
            node = child

        return root.as_sql(compiler, connection)


class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
    def get_fallback_sql(self, compiler, connection):
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) >= (x, y, z) as SQL:
        # WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
        lookups = itertools.cycle([GreaterThan, Exact])
        connectors = itertools.cycle([OR, AND])
        cols_list = [col for col in self.lhs for _ in range(2)]
        vals_list = [val for val in self.rhs for _ in range(2)]
        cols_iter = iter(cols_list)
        vals_iter = iter(vals_list)
        col = next(cols_iter)
        val = next(vals_iter)
        lookup = next(lookups)
        connector = next(connectors)
        root = node = WhereNode([lookup(col, val)], connector=connector)

        for col, val in zip(cols_iter, vals_iter):
            lookup = next(lookups)
            connector = next(connectors)
            child = WhereNode([lookup(col, val)], connector=connector)
            node.children.append(child)
            node = child

        return root.as_sql(compiler, connection)


class TupleLessThan(TupleLookupMixin, LessThan):
    def get_fallback_sql(self, compiler, connection):
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) < (x, y, z) as SQL:
        # WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
        lookups = itertools.cycle([LessThan, Exact])
        connectors = itertools.cycle([OR, AND])
        cols_list = [col for col in self.lhs for _ in range(2)]
        vals_list = [val for val in self.rhs for _ in range(2)]
        cols_iter = iter(cols_list[:-1])
        vals_iter = iter(vals_list[:-1])
        col = next(cols_iter)
        val = next(vals_iter)
        lookup = next(lookups)
        connector = next(connectors)
        root = node = WhereNode([lookup(col, val)], connector=connector)

        for col, val in zip(cols_iter, vals_iter):
            lookup = next(lookups)
            connector = next(connectors)
            child = WhereNode([lookup(col, val)], connector=connector)
            node.children.append(child)
            node = child

        return root.as_sql(compiler, connection)


class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
    def get_fallback_sql(self, compiler, connection):
        # Process right-hand-side to trigger sanitization.
        self.process_rhs(compiler, connection)
        # e.g.: (a, b, c) <= (x, y, z) as SQL:
        # WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
        lookups = itertools.cycle([LessThan, Exact])
        connectors = itertools.cycle([OR, AND])
        cols_list = [col for col in self.lhs for _ in range(2)]
        vals_list = [val for val in self.rhs for _ in range(2)]
        cols_iter = iter(cols_list)
        vals_iter = iter(vals_list)
        col = next(cols_iter)
        val = next(vals_iter)
        lookup = next(lookups)
        connector = next(connectors)
        root = node = WhereNode([lookup(col, val)], connector=connector)

        for col, val in zip(cols_iter, vals_iter):
            lookup = next(lookups)
            connector = next(connectors)
            child = WhereNode([lookup(col, val)], connector=connector)
            node.children.append(child)
            node = child

        return root.as_sql(compiler, connection)


class TupleIn(TupleLookupMixin, In):
    def get_prep_lookup(self):
        if self.rhs_is_direct_value():
            self.check_rhs_is_tuple_or_list()
            self.check_rhs_is_collection_of_tuples_or_lists()
            self.check_rhs_elements_length_equals_lhs_length()
        else:
            self.check_rhs_is_query()
            super(TupleLookupMixin, self).get_prep_lookup()

        return self.rhs  # skip checks from mixin

    def check_rhs_is_collection_of_tuples_or_lists(self):
        if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                "must be a collection of tuples or lists"
            )

    def check_rhs_elements_length_equals_lhs_length(self):
        len_lhs = len(self.lhs)
        if not all(len_lhs == len(vals) for vals in self.rhs):
            lhs_str = self.get_lhs_str()
            raise ValueError(
                f"{self.lookup_name!r} lookup of {lhs_str} "
                f"must have {len_lhs} elements each"
            )

    def check_rhs_is_query(self):
        if not isinstance(self.rhs, (Query, Subquery)):
            lhs_str = self.get_lhs_str()
            rhs_cls = self.rhs.__class__.__name__
            raise ValueError(
                f"{self.lookup_name!r} subquery lookup of {lhs_str} "
                f"must be a Query object (received {rhs_cls!r})"
            )

    def process_rhs(self, compiler, connection):
        if not self.rhs_is_direct_value():
            return super(TupleLookupMixin, self).process_rhs(compiler, connection)

        rhs = self.rhs
        if not rhs:
            raise EmptyResultSet

        # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
        # WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
        result = []
        lhs = self.lhs

        for vals in rhs:
            result.append(
                Tuple(
                    *[
                        Value(val, output_field=col.output_field)
                        for col, val in zip(lhs, vals)
                    ]
                )
            )

        return compiler.compile(Tuple(*result))

    def get_fallback_sql(self, compiler, connection):
        rhs = self.rhs
        if not rhs:
            raise EmptyResultSet
        if not self.rhs_is_direct_value():
            return super(TupleLookupMixin, self).as_sql(compiler, connection)

        # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
        # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
        root = WhereNode([], connector=OR)
        lhs = self.lhs

        for vals in rhs:
            lookups = [Exact(col, val) for col, val in zip(lhs, vals)]
            root.children.append(WhereNode(lookups, connector=AND))

        return root.as_sql(compiler, connection)


tuple_lookups = {
    "exact": TupleExact,
    "gt": TupleGreaterThan,
    "gte": TupleGreaterThanOrEqual,
    "lt": TupleLessThan,
    "lte": TupleLessThanOrEqual,
    "in": TupleIn,
    "isnull": TupleIsNull,
}
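
# Usage sketch (illustrative comment, not part of the module): given a
# hypothetical model whose composite primary key is pk = (a, b), a filter like
#
#     Order.objects.filter(pk__in=[(1, 2), (3, 4)])
#
# compiles through TupleIn to "WHERE (a, b) IN ((1, 2), (3, 4))" on backends
# with supports_tuple_lookups, and to the expanded OR-of-ANDs fallback above
# otherwise.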
@@ -0,0 +1,193 @@
from .comparison import Cast, Coalesce, Collate, Greatest, Least, NullIf
from .datetime import (
    Extract,
    ExtractDay,
    ExtractHour,
    ExtractIsoWeekDay,
    ExtractIsoYear,
    ExtractMinute,
    ExtractMonth,
    ExtractQuarter,
    ExtractSecond,
    ExtractWeek,
    ExtractWeekDay,
    ExtractYear,
    Now,
    Trunc,
    TruncDate,
    TruncDay,
    TruncHour,
    TruncMinute,
    TruncMonth,
    TruncQuarter,
    TruncSecond,
    TruncTime,
    TruncWeek,
    TruncYear,
)
from .json import JSONArray, JSONObject
from .math import (
    Abs,
    ACos,
    ASin,
    ATan,
    ATan2,
    Ceil,
    Cos,
    Cot,
    Degrees,
    Exp,
    Floor,
    Ln,
    Log,
    Mod,
    Pi,
    Power,
    Radians,
    Random,
    Round,
    Sign,
    Sin,
    Sqrt,
    Tan,
)
from .text import (
    MD5,
    SHA1,
    SHA224,
    SHA256,
    SHA384,
    SHA512,
    Chr,
    Concat,
    ConcatPair,
    Left,
    Length,
    Lower,
    LPad,
    LTrim,
    Ord,
    Repeat,
    Replace,
    Reverse,
    Right,
    RPad,
    RTrim,
    StrIndex,
    Substr,
    Trim,
    Upper,
)
from .window import (
    CumeDist,
    DenseRank,
    FirstValue,
    Lag,
    LastValue,
    Lead,
    NthValue,
    Ntile,
    PercentRank,
    Rank,
    RowNumber,
)

__all__ = [
    # comparison and conversion
    "Cast",
    "Coalesce",
    "Collate",
    "Greatest",
    "Least",
    "NullIf",
    # datetime
    "Extract",
    "ExtractDay",
    "ExtractHour",
    "ExtractMinute",
    "ExtractMonth",
    "ExtractQuarter",
    "ExtractSecond",
    "ExtractWeek",
    "ExtractIsoWeekDay",
    "ExtractWeekDay",
    "ExtractIsoYear",
    "ExtractYear",
    "Now",
    "Trunc",
    "TruncDate",
    "TruncDay",
    "TruncHour",
    "TruncMinute",
    "TruncMonth",
    "TruncQuarter",
    "TruncSecond",
    "TruncTime",
    "TruncWeek",
    "TruncYear",
    # json
    "JSONArray",
    "JSONObject",
    # math
    "Abs",
    "ACos",
    "ASin",
    "ATan",
    "ATan2",
    "Ceil",
    "Cos",
    "Cot",
    "Degrees",
    "Exp",
    "Floor",
    "Ln",
    "Log",
    "Mod",
    "Pi",
    "Power",
    "Radians",
    "Random",
    "Round",
    "Sign",
    "Sin",
    "Sqrt",
    "Tan",
    # text
    "MD5",
    "SHA1",
    "SHA224",
    "SHA256",
    "SHA384",
    "SHA512",
    "Chr",
    "Concat",
    "ConcatPair",
    "Left",
    "Length",
    "Lower",
    "LPad",
    "LTrim",
    "Ord",
    "Repeat",
    "Replace",
    "Reverse",
    "Right",
    "RPad",
    "RTrim",
    "StrIndex",
    "Substr",
    "Trim",
    "Upper",
    # window
    "CumeDist",
    "DenseRank",
    "FirstValue",
    "Lag",
    "LastValue",
    "Lead",
    "NthValue",
    "Ntile",
    "PercentRank",
    "Rank",
    "RowNumber",
]
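
# Usage sketch (illustrative comment, not part of the module): everything
# re-exported here is importable from django.db.models.functions, e.g.:
#
#     from django.db.models.functions import Coalesce, Lower, TruncDate
#
#     Author.objects.annotate(name_lower=Lower("name"))  # hypothetical model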
@@ -0,0 +1,172 @@
"""Database functions that do comparisons or type conversions."""

from django.db.models.expressions import Func, Value
from django.utils.regex_helper import _lazy_re_compile


class Cast(Func):
    """Coerce an expression to a new field type."""

    function = "CAST"
    template = "%(function)s(%(expressions)s AS %(db_type)s)"

    def __init__(self, expression, output_field):
        super().__init__(expression, output_field=output_field)

    def as_sql(self, compiler, connection, **extra_context):
        extra_context["db_type"] = self.output_field.cast_db_type(connection)
        return super().as_sql(compiler, connection, **extra_context)

    def as_sqlite(self, compiler, connection, **extra_context):
        db_type = self.output_field.db_type(connection)
        if db_type in {"datetime", "time"}:
            # Use strftime as datetime/time don't keep fractional seconds.
            template = "strftime(%%s, %(expressions)s)"
            sql, params = super().as_sql(
                compiler, connection, template=template, **extra_context
            )
            format_string = "%H:%M:%f" if db_type == "time" else "%Y-%m-%d %H:%M:%f"
            params.insert(0, format_string)
            return sql, params
        elif db_type == "date":
            template = "date(%(expressions)s)"
            return super().as_sql(
                compiler, connection, template=template, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        template = None
        output_type = self.output_field.get_internal_type()
        # MySQL doesn't support explicit cast to float.
        if output_type == "FloatField":
            template = "(%(expressions)s + 0.0)"
        # MariaDB doesn't support explicit cast to JSON.
        elif output_type == "JSONField" and connection.mysql_is_mariadb:
            template = "JSON_EXTRACT(%(expressions)s, '$')"
        return self.as_sql(compiler, connection, template=template, **extra_context)

    def as_postgresql(self, compiler, connection, **extra_context):
        # CAST would be valid too, but the :: shortcut syntax is more readable.
        # 'expressions' is wrapped in parentheses in case it's a complex
        # expression.
        return self.as_sql(
            compiler,
            connection,
            template="(%(expressions)s)::%(db_type)s",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        if self.output_field.get_internal_type() == "JSONField":
            # Oracle doesn't support explicit cast to JSON.
            template = "JSON_QUERY(%(expressions)s, '$')"
            return super().as_sql(
                compiler, connection, template=template, **extra_context
            )
        return self.as_sql(compiler, connection, **extra_context)
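
# Usage sketch (illustrative comment, not part of the module; ``Measurement``
# is a hypothetical model with an integer column):
#
#     from django.db.models import F, FloatField
#     from django.db.models.functions import Cast
#
#     Measurement.objects.annotate(as_float=Cast(F("integer"), FloatField()))
#
# PostgreSQL uses the :: shortcut; the other backends fall back to CAST(...)
# or the vendor-specific workarounds defined above.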
class Coalesce(Func):
    """Return, from left to right, the first non-null expression."""

    function = "COALESCE"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Coalesce must take at least two expressions")
        super().__init__(*expressions, **extra)

    @property
    def empty_result_set_value(self):
        for expression in self.get_source_expressions():
            result = expression.empty_result_set_value
            if result is NotImplemented or result is not None:
                return result
        return None

    def as_oracle(self, compiler, connection, **extra_context):
        # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
        # so convert all fields to NCLOB when that type is expected.
        if self.output_field.get_internal_type() == "TextField":
            clone = self.copy()
            clone.set_source_expressions(
                [
                    Func(expression, function="TO_NCLOB")
                    for expression in self.get_source_expressions()
                ]
            )
            return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
        return self.as_sql(compiler, connection, **extra_context)
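
# Usage sketch (illustrative comment, not part of the module):
#
#     from django.db.models import Value
#     from django.db.models.functions import Coalesce
#
#     Author.objects.annotate(display=Coalesce("nickname", "name", Value("anon")))
#
# picks the first non-null of nickname, then name, then the literal fallback.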
class Collate(Func):
    function = "COLLATE"
    template = "%(expressions)s %(function)s %(collation)s"
    allowed_default = False
    # Inspired from
    # https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
    collation_re = _lazy_re_compile(r"^[\w-]+$")

    def __init__(self, expression, collation):
        if not (collation and self.collation_re.match(collation)):
            raise ValueError("Invalid collation name: %r." % collation)
        self.collation = collation
        super().__init__(expression)

    def as_sql(self, compiler, connection, **extra_context):
        extra_context.setdefault("collation", connection.ops.quote_name(self.collation))
        return super().as_sql(compiler, connection, **extra_context)
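
# Usage sketch (illustrative comment, not part of the module; collation names
# are backend-specific, "nocase" here being SQLite's case-insensitive one):
#
#     Author.objects.filter(name=Collate(Value("john"), "nocase"))
#
# The collation name is validated against collation_re and quoted via
# connection.ops.quote_name() before being interpolated into the template.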
class Greatest(Func):
    """
    Return the maximum expression.

    If any expression is null the return value is database-specific:
    On PostgreSQL, the maximum not-null expression is returned.
    On MySQL, Oracle, and SQLite, if any expression is null, null is returned.
    """

    function = "GREATEST"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Greatest must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Use the MAX function on SQLite."""
        return super().as_sqlite(compiler, connection, function="MAX", **extra_context)


class Least(Func):
    """
    Return the minimum expression.

    If any expression is null the return value is database-specific:
    On PostgreSQL, return the minimum not-null expression.
    On MySQL, Oracle, and SQLite, if any expression is null, return null.
    """

    function = "LEAST"

    def __init__(self, *expressions, **extra):
        if len(expressions) < 2:
            raise ValueError("Least must take at least two expressions")
        super().__init__(*expressions, **extra)

    def as_sqlite(self, compiler, connection, **extra_context):
        """Use the MIN function on SQLite."""
        return super().as_sqlite(compiler, connection, function="MIN", **extra_context)


class NullIf(Func):
    function = "NULLIF"
    arity = 2

    def as_oracle(self, compiler, connection, **extra_context):
        expression1 = self.get_source_expressions()[0]
        if isinstance(expression1, Value) and expression1.value is None:
            raise ValueError("Oracle does not allow Value(None) for expression1.")
        return super().as_sql(compiler, connection, **extra_context)
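
# Usage sketch (illustrative comment, not part of the module):
#
#     from django.db.models.functions import Greatest
#
#     Blog.objects.annotate(
#         last_updated=Greatest("comment__modified", "story__modified")
#     )
#
# Mind the NULL semantics in the docstrings above: PostgreSQL skips NULLs,
# while MySQL, Oracle, and SQLite return NULL if any input is NULL.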
@@ -0,0 +1,439 @@
from datetime import datetime

from django.conf import settings
from django.db.models.expressions import Func
from django.db.models.fields import (
    DateField,
    DateTimeField,
    DurationField,
    Field,
    IntegerField,
    TimeField,
)
from django.db.models.lookups import (
    Transform,
    YearExact,
    YearGt,
    YearGte,
    YearLt,
    YearLte,
)
from django.utils import timezone


class TimezoneMixin:
    tzinfo = None

    def get_tzname(self):
        # Timezone conversions must happen to the input datetime *before*
        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
        # database as 2016-01-01 01:00:00 +00:00. Any results should be
        # based on the input datetime not the stored datetime.
        tzname = None
        if settings.USE_TZ:
            if self.tzinfo is None:
                tzname = timezone.get_current_timezone_name()
            else:
                tzname = timezone._get_timezone_name(self.tzinfo)
        return tzname


class Extract(TimezoneMixin, Transform):
    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        else:
            # resolve_expression has already validated the output_field so this
            # assert should never be hit.
            assert False, "Tried to Extract from an invalid type."
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            return copy
        if not isinstance(field, (DateField, DateTimeField, TimeField, DurationField)):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a mistake.
        if type(field) is DateField and copy.lookup_name in (
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                "Cannot extract time component '%s' from DateField '%s'."
                % (copy.lookup_name, field.name)
            )
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                "Cannot extract component '%s' from DurationField '%s'."
                % (copy.lookup_name, field.name)
            )
        return copy


class ExtractYear(Extract):
    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    lookup_name = "month"


class ExtractDay(Extract):
    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
    week.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Return Monday=1 through Sunday=7, based on ISO-8601."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    lookup_name = "quarter"


class ExtractHour(Extract):
    lookup_name = "hour"


class ExtractMinute(Extract):
    lookup_name = "minute"


class ExtractSecond(Extract):
    lookup_name = "second"


DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)

TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)

DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)

ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)

ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)
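
# Usage sketch (illustrative comment, not part of the module): the
# registrations above are what make lookups like these work on a
# hypothetical Entry model:
#
#     Entry.objects.filter(pub_date__year=2024)   # ExtractYear + YearExact
#     Entry.objects.filter(start__hour__gte=9)    # ExtractHour on a DateTimeField
#     Entry.objects.annotate(q=ExtractQuarter("pub_date"))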
class Now(Func):
    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        return self.as_sql(
            compiler, connection, template="STATEMENT_TIMESTAMP()", **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler, connection, template="CURRENT_TIMESTAMP(6)", **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler,
            connection,
            template="STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
            **extra_context,
        )

    def as_oracle(self, compiler, connection, **extra_context):
        return self.as_sql(
            compiler, connection, template="LOCALTIMESTAMP", **extra_context
        )
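
# Usage sketch (illustrative comment, not part of the module): Now() is
# evaluated by the database, not by Python, so it reflects the database
# clock at query execution time:
#
#     Comment.objects.filter(pub_date__lte=Now())
#     Entry.objects.update(modified=Now())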
class TruncBase(TimezoneMixin, Transform):
    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, (DateField, TimeField)):
            raise TypeError(
                "%r isn't a DateField, TimeField, or DateTimeField." % field.name
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, (DateField, DateTimeField, TimeField)):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '%s' to %s."
                % (
                    field.name,
                    (
                        output_field.__class__.__name__
                        if has_explicit_output_field
                        else "DateTimeField"
                    ),
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '%s' to %s."
                % (
                    field.name,
                    (
                        output_field.__class__.__name__
                        if has_explicit_output_field
                        else "DateTimeField"
                    ),
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        if isinstance(self.output_field, DateTimeField):
            if not settings.USE_TZ:
                pass
            elif value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            if value is None:
                pass
            elif isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
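
# Usage sketch (not part of Django's source): with USE_TZ = True, TruncBase
# subclasses truncate in the current time zone unless an explicit tzinfo is
# passed. "Entry"/"pub_date" are hypothetical; TruncDay is defined below:
#
#   from zoneinfo import ZoneInfo
#   from django.db.models.functions import TruncDay
#
#   Entry.objects.annotate(
#       day=TruncDay("pub_date", tzinfo=ZoneInfo("America/New_York"))
#   )
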
class Trunc(TruncBase):
    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        self.kind = kind
        super().__init__(expression, output_field=output_field, tzinfo=tzinfo, **extra)
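
# Usage sketch (not part of Django's source): Trunc is the generic entry
# point; Trunc("pub_date", "month") behaves like TruncMonth("pub_date"),
# with "Entry"/"pub_date" hypothetical:
#
#   from django.db.models.functions import Trunc
#
#   Entry.objects.annotate(month=Trunc("pub_date", "month"))
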
class TruncYear(TruncBase):
    kind = "year"


class TruncQuarter(TruncBase):
    kind = "quarter"


class TruncMonth(TruncBase):
    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"


class TruncDay(TruncBase):
    kind = "day"
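
# Usage sketch (not part of Django's source): the kind-specific truncators
# are commonly combined with values()/annotate() to group rows by period.
# "Entry"/"pub_date" are hypothetical:
#
#   from django.db.models import Count
#   from django.db.models.functions import TruncMonth
#
#   (Entry.objects
#        .values(month=TruncMonth("pub_date"))
#        .annotate(n=Count("id")))
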
class TruncDate(TruncBase):
    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        # Cast to date rather than truncate to date.
        sql, params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(sql, tuple(params), tzname)


class TruncTime(TruncBase):
    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        # Cast to time rather than truncate to time.
        sql, params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(sql, tuple(params), tzname)
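
# Usage sketch (not part of Django's source): as the comments above note,
# TruncDate/TruncTime emit a CAST rather than a truncation, so the annotation
# keeps only the date (or time) component. "Entry"/"pub_date" hypothetical:
#
#   from django.db.models.functions import TruncDate
#
#   Entry.objects.annotate(date_only=TruncDate("pub_date"))
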
class TruncHour(TruncBase):
    kind = "hour"


class TruncMinute(TruncBase):
    kind = "minute"


class TruncSecond(TruncBase):
    kind = "second"
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)
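
# Usage sketch (not part of Django's source): because TruncDate and TruncTime
# are registered under the lookup names "date" and "time", they can be used
# directly in filters on a DateTimeField ("Entry"/"pub_date" hypothetical):
#
#   import datetime
#
#   Entry.objects.filter(pub_date__date=datetime.date(2023, 1, 1))
#   Entry.objects.filter(pub_date__time__gte=datetime.time(9, 0))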