init commit
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,177 @@
|
||||
import json
|
||||
|
||||
from django.core import checks
|
||||
from django.db.models import NOT_PROVIDED, Field
|
||||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields.tuple_lookups import (
|
||||
TupleExact,
|
||||
TupleGreaterThan,
|
||||
TupleGreaterThanOrEqual,
|
||||
TupleIn,
|
||||
TupleIsNull,
|
||||
TupleLessThan,
|
||||
TupleLessThanOrEqual,
|
||||
)
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
|
||||
class AttributeSetter:
|
||||
def __init__(self, name, value):
|
||||
setattr(self, name, value)
|
||||
|
||||
|
||||
class CompositeAttribute:
|
||||
def __init__(self, field):
|
||||
self.field = field
|
||||
|
||||
@property
|
||||
def attnames(self):
|
||||
return [field.attname for field in self.field.fields]
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
return tuple(getattr(instance, attname) for attname in self.attnames)
|
||||
|
||||
def __set__(self, instance, values):
|
||||
attnames = self.attnames
|
||||
length = len(attnames)
|
||||
|
||||
if values is None:
|
||||
values = (None,) * length
|
||||
|
||||
if not isinstance(values, (list, tuple)):
|
||||
raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
|
||||
if length != len(values):
|
||||
raise ValueError(f"{self.field.name!r} must have {length} elements.")
|
||||
|
||||
for attname, value in zip(attnames, values):
|
||||
setattr(instance, attname, value)
|
||||
|
||||
|
||||
class CompositePrimaryKey(Field):
|
||||
descriptor_class = CompositeAttribute
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if (
|
||||
not args
|
||||
or not all(isinstance(field, str) for field in args)
|
||||
or len(set(args)) != len(args)
|
||||
):
|
||||
raise ValueError("CompositePrimaryKey args must be unique strings.")
|
||||
if len(args) == 1:
|
||||
raise ValueError("CompositePrimaryKey must include at least two fields.")
|
||||
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("CompositePrimaryKey cannot have a default.")
|
||||
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("CompositePrimaryKey cannot have a database default.")
|
||||
if kwargs.get("db_column", None) is not None:
|
||||
raise ValueError("CompositePrimaryKey cannot have a db_column.")
|
||||
if kwargs.setdefault("editable", False):
|
||||
raise ValueError("CompositePrimaryKey cannot be editable.")
|
||||
if not kwargs.setdefault("primary_key", True):
|
||||
raise ValueError("CompositePrimaryKey must be a primary key.")
|
||||
if not kwargs.setdefault("blank", True):
|
||||
raise ValueError("CompositePrimaryKey must be blank.")
|
||||
|
||||
self.field_names = args
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
# args is always [] so it can be ignored.
|
||||
name, path, _, kwargs = super().deconstruct()
|
||||
return name, path, self.field_names, kwargs
|
||||
|
||||
@cached_property
|
||||
def fields(self):
|
||||
meta = self.model._meta
|
||||
return tuple(meta.get_field(field_name) for field_name in self.field_names)
|
||||
|
||||
@cached_property
|
||||
def columns(self):
|
||||
return tuple(field.column for field in self.fields)
|
||||
|
||||
def contribute_to_class(self, cls, name, private_only=False):
|
||||
super().contribute_to_class(cls, name, private_only=private_only)
|
||||
cls._meta.pk = self
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def get_attname_column(self):
|
||||
return self.get_attname(), None
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.fields)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.field_names)
|
||||
|
||||
@cached_property
|
||||
def cached_col(self):
|
||||
return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)
|
||||
|
||||
def get_col(self, alias, output_field=None):
|
||||
if alias == self.model._meta.db_table and (
|
||||
output_field is None or output_field == self
|
||||
):
|
||||
return self.cached_col
|
||||
|
||||
return ColPairs(alias, self.fields, self.fields, output_field)
|
||||
|
||||
def get_pk_value_on_save(self, instance):
|
||||
values = []
|
||||
|
||||
for field in self.fields:
|
||||
value = field.value_from_object(instance)
|
||||
if value is None:
|
||||
value = field.get_pk_value_on_save(instance)
|
||||
values.append(value)
|
||||
|
||||
return tuple(values)
|
||||
|
||||
def _check_field_name(self):
|
||||
if self.name == "pk":
|
||||
return []
|
||||
return [
|
||||
checks.Error(
|
||||
"'CompositePrimaryKey' must be named 'pk'.",
|
||||
obj=self,
|
||||
id="fields.E013",
|
||||
)
|
||||
]
|
||||
|
||||
def value_to_string(self, obj):
|
||||
values = []
|
||||
vals = self.value_from_object(obj)
|
||||
for field, value in zip(self.fields, vals):
|
||||
obj = AttributeSetter(field.attname, value)
|
||||
values.append(field.value_to_string(obj))
|
||||
return json.dumps(values, ensure_ascii=False)
|
||||
|
||||
def to_python(self, value):
|
||||
if isinstance(value, str):
|
||||
# Assume we're deserializing.
|
||||
vals = json.loads(value)
|
||||
value = [
|
||||
field.to_python(val)
|
||||
for field, val in zip(self.fields, vals, strict=True)
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
CompositePrimaryKey.register_lookup(TupleExact)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThan)
|
||||
CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual)
|
||||
CompositePrimaryKey.register_lookup(TupleLessThan)
|
||||
CompositePrimaryKey.register_lookup(TupleLessThanOrEqual)
|
||||
CompositePrimaryKey.register_lookup(TupleIn)
|
||||
CompositePrimaryKey.register_lookup(TupleIsNull)
|
||||
|
||||
|
||||
def unnest(fields):
|
||||
result = []
|
||||
|
||||
for field in fields:
|
||||
if isinstance(field, CompositePrimaryKey):
|
||||
result.extend(field.fields)
|
||||
else:
|
||||
result.append(field)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,538 @@
|
||||
import datetime
|
||||
import posixpath
|
||||
|
||||
from django import forms
|
||||
from django.core import checks
|
||||
from django.core.exceptions import FieldError
|
||||
from django.core.files.base import ContentFile, File
|
||||
from django.core.files.images import ImageFile
|
||||
from django.core.files.storage import Storage, default_storage
|
||||
from django.core.files.utils import validate_file_name
|
||||
from django.db.models import signals
|
||||
from django.db.models.expressions import DatabaseDefault
|
||||
from django.db.models.fields import Field
|
||||
from django.db.models.query_utils import DeferredAttribute
|
||||
from django.db.models.utils import AltersData
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.version import PY311
|
||||
|
||||
|
||||
class FieldFile(File, AltersData):
|
||||
def __init__(self, instance, field, name):
|
||||
super().__init__(None, name)
|
||||
self.instance = instance
|
||||
self.field = field
|
||||
self.storage = field.storage
|
||||
self._committed = True
|
||||
|
||||
def __eq__(self, other):
|
||||
# Older code may be expecting FileField values to be simple strings.
|
||||
# By overriding the == operator, it can remain backwards compatibility.
|
||||
if hasattr(other, "name"):
|
||||
return self.name == other.name
|
||||
return self.name == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
# The standard File contains most of the necessary properties, but
|
||||
# FieldFiles can be instantiated without a name, so that needs to
|
||||
# be checked for here.
|
||||
|
||||
def _require_file(self):
|
||||
if not self:
|
||||
raise ValueError(
|
||||
"The '%s' attribute has no file associated with it." % self.field.name
|
||||
)
|
||||
|
||||
def _get_file(self):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self._file = self.storage.open(self.name, "rb")
|
||||
return self._file
|
||||
|
||||
def _set_file(self, file):
|
||||
self._file = file
|
||||
|
||||
def _del_file(self):
|
||||
del self._file
|
||||
|
||||
file = property(_get_file, _set_file, _del_file)
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
self._require_file()
|
||||
return self.storage.path(self.name)
|
||||
|
||||
@property
|
||||
def url(self):
|
||||
self._require_file()
|
||||
return self.storage.url(self.name)
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
self._require_file()
|
||||
if not self._committed:
|
||||
return self.file.size
|
||||
return self.storage.size(self.name)
|
||||
|
||||
def open(self, mode="rb"):
|
||||
self._require_file()
|
||||
if getattr(self, "_file", None) is None:
|
||||
self.file = self.storage.open(self.name, mode)
|
||||
else:
|
||||
self.file.open(mode)
|
||||
return self
|
||||
|
||||
# open() doesn't alter the file's contents, but it does reset the pointer
|
||||
open.alters_data = True
|
||||
|
||||
# In addition to the standard File API, FieldFiles have extra methods
|
||||
# to further manipulate the underlying file, as well as update the
|
||||
# associated model instance.
|
||||
|
||||
def _set_instance_attribute(self, name, content):
|
||||
setattr(self.instance, self.field.attname, name)
|
||||
|
||||
def save(self, name, content, save=True):
|
||||
name = self.field.generate_filename(self.instance, name)
|
||||
self.name = self.storage.save(name, content, max_length=self.field.max_length)
|
||||
self._set_instance_attribute(self.name, content)
|
||||
self._committed = True
|
||||
|
||||
# Save the object because it has changed, unless save is False
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
save.alters_data = True
|
||||
|
||||
def delete(self, save=True):
|
||||
if not self:
|
||||
return
|
||||
# Only close the file if it's already open, which we know by the
|
||||
# presence of self._file
|
||||
if hasattr(self, "_file"):
|
||||
self.close()
|
||||
del self.file
|
||||
|
||||
self.storage.delete(self.name)
|
||||
|
||||
self.name = None
|
||||
setattr(self.instance, self.field.attname, self.name)
|
||||
self._committed = False
|
||||
|
||||
if save:
|
||||
self.instance.save()
|
||||
|
||||
delete.alters_data = True
|
||||
|
||||
@property
|
||||
def closed(self):
|
||||
file = getattr(self, "_file", None)
|
||||
return file is None or file.closed
|
||||
|
||||
def close(self):
|
||||
file = getattr(self, "_file", None)
|
||||
if file is not None:
|
||||
file.close()
|
||||
|
||||
def __getstate__(self):
|
||||
# FieldFile needs access to its associated model field, an instance and
|
||||
# the file's name. Everything else will be restored later, by
|
||||
# FileDescriptor below.
|
||||
return {
|
||||
"name": self.name,
|
||||
"closed": False,
|
||||
"_committed": True,
|
||||
"_file": None,
|
||||
"instance": self.instance,
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self.storage = self.field.storage
|
||||
|
||||
|
||||
class FileDescriptor(DeferredAttribute):
|
||||
"""
|
||||
The descriptor for the file attribute on the model instance. Return a
|
||||
FieldFile when accessed so you can write code like::
|
||||
|
||||
>>> from myapp.models import MyModel
|
||||
>>> instance = MyModel.objects.get(pk=1)
|
||||
>>> instance.file.size
|
||||
|
||||
Assign a file object on assignment so you can do::
|
||||
|
||||
>>> with open('/path/to/hello.world') as f:
|
||||
... instance.file = File(f)
|
||||
"""
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
|
||||
# This is slightly complicated, so worth an explanation.
|
||||
# instance.file needs to ultimately return some instance of `File`,
|
||||
# probably a subclass. Additionally, this returned object needs to have
|
||||
# the FieldFile API so that users can easily do things like
|
||||
# instance.file.path and have that delegated to the file storage engine.
|
||||
# Easy enough if we're strict about assignment in __set__, but if you
|
||||
# peek below you can see that we're not. So depending on the current
|
||||
# value of the field we have to dynamically construct some sort of
|
||||
# "thing" to return.
|
||||
|
||||
# The instance dict contains whatever was originally assigned
|
||||
# in __set__.
|
||||
file = super().__get__(instance, cls)
|
||||
|
||||
# If this value is a string (instance.file = "path/to/file") or None
|
||||
# then we simply wrap it with the appropriate attribute class according
|
||||
# to the file field. [This is FieldFile for FileFields and
|
||||
# ImageFieldFile for ImageFields; it's also conceivable that user
|
||||
# subclasses might also want to subclass the attribute class]. This
|
||||
# object understands how to convert a path to a file, and also how to
|
||||
# handle None.
|
||||
if isinstance(file, str) or file is None:
|
||||
attr = self.field.attr_class(instance, self.field, file)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# If this value is a DatabaseDefault, initialize the attribute class
|
||||
# for this field with its db_default value.
|
||||
elif isinstance(file, DatabaseDefault):
|
||||
attr = self.field.attr_class(instance, self.field, self.field.db_default)
|
||||
instance.__dict__[self.field.attname] = attr
|
||||
|
||||
# Other types of files may be assigned as well, but they need to have
|
||||
# the FieldFile interface added to them. Thus, we wrap any other type of
|
||||
# File inside a FieldFile (well, the field's attr_class, which is
|
||||
# usually FieldFile).
|
||||
elif isinstance(file, File) and not isinstance(file, FieldFile):
|
||||
file_copy = self.field.attr_class(instance, self.field, file.name)
|
||||
file_copy.file = file
|
||||
file_copy._committed = False
|
||||
instance.__dict__[self.field.attname] = file_copy
|
||||
|
||||
# Finally, because of the (some would say boneheaded) way pickle works,
|
||||
# the underlying FieldFile might not actually itself have an associated
|
||||
# file. So we need to reset the details of the FieldFile in those cases.
|
||||
elif isinstance(file, FieldFile) and not hasattr(file, "field"):
|
||||
file.instance = instance
|
||||
file.field = self.field
|
||||
file.storage = self.field.storage
|
||||
|
||||
# Make sure that the instance is correct.
|
||||
elif isinstance(file, FieldFile) and instance is not file.instance:
|
||||
file.instance = instance
|
||||
|
||||
# That was fun, wasn't it?
|
||||
return instance.__dict__[self.field.attname]
|
||||
|
||||
def __set__(self, instance, value):
|
||||
instance.__dict__[self.field.attname] = value
|
||||
|
||||
|
||||
class FileField(Field):
|
||||
# The class to wrap instance attributes in. Accessing the file object off
|
||||
# the instance will always return an instance of attr_class.
|
||||
attr_class = FieldFile
|
||||
|
||||
# The descriptor to use for accessing the attribute off of the class.
|
||||
descriptor_class = FileDescriptor
|
||||
|
||||
description = _("File")
|
||||
|
||||
def __init__(
|
||||
self, verbose_name=None, name=None, upload_to="", storage=None, **kwargs
|
||||
):
|
||||
self._primary_key_set_explicitly = "primary_key" in kwargs
|
||||
|
||||
self.storage = storage or default_storage
|
||||
if callable(self.storage):
|
||||
# Hold a reference to the callable for deconstruct().
|
||||
self._storage_callable = self.storage
|
||||
self.storage = self.storage()
|
||||
if not isinstance(self.storage, Storage):
|
||||
raise TypeError(
|
||||
"%s.storage must be a subclass/instance of %s.%s"
|
||||
% (
|
||||
self.__class__.__qualname__,
|
||||
Storage.__module__,
|
||||
Storage.__qualname__,
|
||||
)
|
||||
)
|
||||
self.upload_to = upload_to
|
||||
|
||||
kwargs.setdefault("max_length", 100)
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_primary_key(),
|
||||
*self._check_upload_to(),
|
||||
]
|
||||
|
||||
def _check_primary_key(self):
|
||||
if self._primary_key_set_explicitly:
|
||||
return [
|
||||
checks.Error(
|
||||
"'primary_key' is not a valid argument for a %s."
|
||||
% self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E201",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _check_upload_to(self):
|
||||
if isinstance(self.upload_to, str) and self.upload_to.startswith("/"):
|
||||
return [
|
||||
checks.Error(
|
||||
"%s's 'upload_to' argument must be a relative path, not an "
|
||||
"absolute path." % self.__class__.__name__,
|
||||
obj=self,
|
||||
id="fields.E202",
|
||||
hint="Remove the leading slash.",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if kwargs.get("max_length") == 100:
|
||||
del kwargs["max_length"]
|
||||
kwargs["upload_to"] = self.upload_to
|
||||
storage = getattr(self, "_storage_callable", self.storage)
|
||||
if storage is not default_storage:
|
||||
kwargs["storage"] = storage
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return "FileField"
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
# Need to convert File objects provided via a form to string for
|
||||
# database insertion.
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
def pre_save(self, model_instance, add):
|
||||
file = super().pre_save(model_instance, add)
|
||||
if file.name is None and file._file is not None:
|
||||
exc = FieldError(
|
||||
f"File for {self.name} must have "
|
||||
"the name attribute specified to be saved."
|
||||
)
|
||||
if PY311 and isinstance(file._file, ContentFile):
|
||||
exc.add_note("Pass a 'name' argument to ContentFile.")
|
||||
raise exc
|
||||
|
||||
if file and not file._committed:
|
||||
# Commit the file to storage prior to saving the model
|
||||
file.save(file.name, file.file, save=False)
|
||||
return file
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
setattr(cls, self.attname, self.descriptor_class(self))
|
||||
|
||||
def generate_filename(self, instance, filename):
|
||||
"""
|
||||
Apply (if callable) or prepend (if a string) upload_to to the filename,
|
||||
then delegate further processing of the name to the storage backend.
|
||||
Until the storage layer, all file paths are expected to be Unix style
|
||||
(with forward slashes).
|
||||
"""
|
||||
if callable(self.upload_to):
|
||||
filename = self.upload_to(instance, filename)
|
||||
else:
|
||||
dirname = datetime.datetime.now().strftime(str(self.upload_to))
|
||||
filename = posixpath.join(dirname, filename)
|
||||
filename = validate_file_name(filename, allow_relative_path=True)
|
||||
return self.storage.generate_filename(filename)
|
||||
|
||||
def save_form_data(self, instance, data):
|
||||
# Important: None means "no change", other false value means "clear"
|
||||
# This subtle distinction (rather than a more explicit marker) is
|
||||
# needed because we need to consume values that are also sane for a
|
||||
# regular (non Model-) Form to find in its cleaned_data dictionary.
|
||||
if data is not None:
|
||||
# This value will be converted to str and stored in the
|
||||
# database, so leaving False as-is is not acceptable.
|
||||
setattr(instance, self.name, data or "")
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.FileField,
|
||||
"max_length": self.max_length,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ImageFileDescriptor(FileDescriptor):
|
||||
"""
|
||||
Just like the FileDescriptor, but for ImageFields. The only difference is
|
||||
assigning the width/height to the width_field/height_field, if appropriate.
|
||||
"""
|
||||
|
||||
def __set__(self, instance, value):
|
||||
previous_file = instance.__dict__.get(self.field.attname)
|
||||
super().__set__(instance, value)
|
||||
|
||||
# To prevent recalculating image dimensions when we are instantiating
|
||||
# an object from the database (bug #11084), only update dimensions if
|
||||
# the field had a value before this assignment. Since the default
|
||||
# value for FileField subclasses is an instance of field.attr_class,
|
||||
# previous_file will only be None when we are called from
|
||||
# Model.__init__(). The ImageField.update_dimension_fields method
|
||||
# hooked up to the post_init signal handles the Model.__init__() cases.
|
||||
# Assignment happening outside of Model.__init__() will trigger the
|
||||
# update right here.
|
||||
if previous_file is not None:
|
||||
self.field.update_dimension_fields(instance, force=True)
|
||||
|
||||
|
||||
class ImageFieldFile(ImageFile, FieldFile):
|
||||
def _set_instance_attribute(self, name, content):
|
||||
setattr(self.instance, self.field.attname, content)
|
||||
# Update the name in case generate_filename() or storage.save() changed
|
||||
# it, but bypass the descriptor to avoid re-reading the file.
|
||||
self.instance.__dict__[self.field.attname] = self.name
|
||||
|
||||
def delete(self, save=True):
|
||||
# Clear the image dimensions cache
|
||||
if hasattr(self, "_dimensions_cache"):
|
||||
del self._dimensions_cache
|
||||
super().delete(save)
|
||||
|
||||
|
||||
class ImageField(FileField):
|
||||
attr_class = ImageFieldFile
|
||||
descriptor_class = ImageFileDescriptor
|
||||
description = _("Image")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
width_field=None,
|
||||
height_field=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.width_field, self.height_field = width_field, height_field
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
return [
|
||||
*super().check(**kwargs),
|
||||
*self._check_image_library_installed(),
|
||||
]
|
||||
|
||||
def _check_image_library_installed(self):
|
||||
try:
|
||||
from PIL import Image # NOQA
|
||||
except ImportError:
|
||||
return [
|
||||
checks.Error(
|
||||
"Cannot use ImageField because Pillow is not installed.",
|
||||
hint=(
|
||||
"Get Pillow at https://pypi.org/project/Pillow/ "
|
||||
'or run command "python -m pip install Pillow".'
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E210",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.width_field:
|
||||
kwargs["width_field"] = self.width_field
|
||||
if self.height_field:
|
||||
kwargs["height_field"] = self.height_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def contribute_to_class(self, cls, name, **kwargs):
|
||||
super().contribute_to_class(cls, name, **kwargs)
|
||||
# Attach update_dimension_fields so that dimension fields declared
|
||||
# after their corresponding image field don't stay cleared by
|
||||
# Model.__init__, see bug #11196.
|
||||
# Only run post-initialization dimension update on non-abstract models
|
||||
# with width_field/height_field.
|
||||
if not cls._meta.abstract and (self.width_field or self.height_field):
|
||||
signals.post_init.connect(self.update_dimension_fields, sender=cls)
|
||||
|
||||
def update_dimension_fields(self, instance, force=False, *args, **kwargs):
|
||||
"""
|
||||
Update field's width and height fields, if defined.
|
||||
|
||||
This method is hooked up to model's post_init signal to update
|
||||
dimensions after instantiating a model instance. However, dimensions
|
||||
won't be updated if the dimensions fields are already populated. This
|
||||
avoids unnecessary recalculation when loading an object from the
|
||||
database.
|
||||
|
||||
Dimensions can be forced to update with force=True, which is how
|
||||
ImageFileDescriptor.__set__ calls this method.
|
||||
"""
|
||||
# Nothing to update if the field doesn't have dimension fields or if
|
||||
# the field is deferred.
|
||||
has_dimension_fields = self.width_field or self.height_field
|
||||
if not has_dimension_fields or self.attname not in instance.__dict__:
|
||||
return
|
||||
|
||||
# getattr will call the ImageFileDescriptor's __get__ method, which
|
||||
# coerces the assigned value into an instance of self.attr_class
|
||||
# (ImageFieldFile in this case).
|
||||
file = getattr(instance, self.attname)
|
||||
|
||||
# Nothing to update if we have no file and not being forced to update.
|
||||
if not file and not force:
|
||||
return
|
||||
|
||||
dimension_fields_filled = not (
|
||||
(self.width_field and not getattr(instance, self.width_field))
|
||||
or (self.height_field and not getattr(instance, self.height_field))
|
||||
)
|
||||
# When both dimension fields have values, we are most likely loading
|
||||
# data from the database or updating an image field that already had
|
||||
# an image stored. In the first case, we don't want to update the
|
||||
# dimension fields because we are already getting their values from the
|
||||
# database. In the second case, we do want to update the dimensions
|
||||
# fields and will skip this return because force will be True since we
|
||||
# were called from ImageFileDescriptor.__set__.
|
||||
if dimension_fields_filled and not force:
|
||||
return
|
||||
|
||||
# file should be an instance of ImageFieldFile or should be None.
|
||||
if file:
|
||||
width = file.width
|
||||
height = file.height
|
||||
else:
|
||||
# No file, so clear dimensions fields.
|
||||
width = None
|
||||
height = None
|
||||
|
||||
# Update the width and height fields.
|
||||
if self.width_field:
|
||||
setattr(instance, self.width_field, width)
|
||||
if self.height_field:
|
||||
setattr(instance, self.height_field, height)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.ImageField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,197 @@
|
||||
from django.core import checks
|
||||
from django.db import connections, router
|
||||
from django.db.models.sql import Query
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from . import NOT_PROVIDED, Field
|
||||
|
||||
__all__ = ["GeneratedField"]
|
||||
|
||||
|
||||
class GeneratedField(Field):
|
||||
generated = True
|
||||
db_returning = True
|
||||
|
||||
_query = None
|
||||
output_field = None
|
||||
|
||||
def __init__(self, *, expression, output_field, db_persist=None, **kwargs):
|
||||
if kwargs.setdefault("editable", False):
|
||||
raise ValueError("GeneratedField cannot be editable.")
|
||||
if not kwargs.setdefault("blank", True):
|
||||
raise ValueError("GeneratedField must be blank.")
|
||||
if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("GeneratedField cannot have a default.")
|
||||
if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
|
||||
raise ValueError("GeneratedField cannot have a database default.")
|
||||
if db_persist not in (True, False):
|
||||
raise ValueError("GeneratedField.db_persist must be True or False.")
|
||||
|
||||
self.expression = expression
|
||||
self.output_field = output_field
|
||||
self.db_persist = db_persist
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@cached_property
|
||||
def cached_col(self):
|
||||
from django.db.models.expressions import Col
|
||||
|
||||
return Col(self.model._meta.db_table, self, self.output_field)
|
||||
|
||||
def get_col(self, alias, output_field=None):
|
||||
if alias != self.model._meta.db_table and output_field in (None, self):
|
||||
output_field = self.output_field
|
||||
return super().get_col(alias, output_field)
|
||||
|
||||
def contribute_to_class(self, *args, **kwargs):
|
||||
super().contribute_to_class(*args, **kwargs)
|
||||
|
||||
self._query = Query(model=self.model, alias_cols=False)
|
||||
# Register lookups from the output_field class.
|
||||
for lookup_name, lookup in self.output_field.get_class_lookups().items():
|
||||
self.register_lookup(lookup, lookup_name=lookup_name)
|
||||
|
||||
def generated_sql(self, connection):
|
||||
compiler = connection.ops.compiler("SQLCompiler")(
|
||||
self._query, connection=connection, using=None
|
||||
)
|
||||
resolved_expression = self.expression.resolve_expression(
|
||||
self._query, allow_joins=False
|
||||
)
|
||||
sql, params = compiler.compile(resolved_expression)
|
||||
if (
|
||||
getattr(self.expression, "conditional", False)
|
||||
and not connection.features.supports_boolean_expr_in_select_clause
|
||||
):
|
||||
sql = f"CASE WHEN {sql} THEN 1 ELSE 0 END"
|
||||
return sql, params
|
||||
|
||||
def check(self, **kwargs):
|
||||
databases = kwargs.get("databases") or []
|
||||
errors = [
|
||||
*super().check(**kwargs),
|
||||
*self._check_supported(databases),
|
||||
*self._check_persistence(databases),
|
||||
]
|
||||
output_field_clone = self.output_field.clone()
|
||||
output_field_clone.model = self.model
|
||||
output_field_checks = output_field_clone.check(databases=databases)
|
||||
if output_field_checks:
|
||||
separator = "\n "
|
||||
error_messages = separator.join(
|
||||
f"{output_check.msg} ({output_check.id})"
|
||||
for output_check in output_field_checks
|
||||
if isinstance(output_check, checks.Error)
|
||||
)
|
||||
if error_messages:
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"GeneratedField.output_field has errors:"
|
||||
f"{separator}{error_messages}",
|
||||
obj=self,
|
||||
id="fields.E223",
|
||||
)
|
||||
)
|
||||
warning_messages = separator.join(
|
||||
f"{output_check.msg} ({output_check.id})"
|
||||
for output_check in output_field_checks
|
||||
if isinstance(output_check, checks.Warning)
|
||||
)
|
||||
if warning_messages:
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
"GeneratedField.output_field has warnings:"
|
||||
f"{separator}{warning_messages}",
|
||||
obj=self,
|
||||
id="fields.W224",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
connection.features.supports_virtual_generated_columns
|
||||
or "supports_stored_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
) and not (
|
||||
connection.features.supports_stored_generated_columns
|
||||
or "supports_virtual_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{connection.display_name} does not support GeneratedFields.",
|
||||
obj=self,
|
||||
id="fields.E220",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def _check_persistence(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not self.db_persist and not (
|
||||
connection.features.supports_virtual_generated_columns
|
||||
or "supports_virtual_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{connection.display_name} does not support non-persisted "
|
||||
"GeneratedFields.",
|
||||
obj=self,
|
||||
id="fields.E221",
|
||||
hint="Set db_persist=True on the field.",
|
||||
)
|
||||
)
|
||||
if self.db_persist and not (
|
||||
connection.features.supports_stored_generated_columns
|
||||
or "supports_stored_generated_columns"
|
||||
in self.model._meta.required_db_features
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
f"{connection.display_name} does not support persisted "
|
||||
"GeneratedFields.",
|
||||
obj=self,
|
||||
id="fields.E222",
|
||||
hint="Set db_persist=False on the field.",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
del kwargs["blank"]
|
||||
del kwargs["editable"]
|
||||
kwargs["db_persist"] = self.db_persist
|
||||
kwargs["expression"] = self.expression
|
||||
kwargs["output_field"] = self.output_field
|
||||
return name, path, args, kwargs
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.output_field.get_internal_type()
|
||||
|
||||
def db_parameters(self, connection):
|
||||
return self.output_field.db_parameters(connection)
|
||||
|
||||
def db_type_parameters(self, connection):
|
||||
return self.output_field.db_type_parameters(connection)
|
||||
@@ -0,0 +1,664 @@
|
||||
import json
|
||||
|
||||
from django import forms
|
||||
from django.core import checks, exceptions
|
||||
from django.db import NotSupportedError, connections, router
|
||||
from django.db.models import expressions, lookups
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.fields import TextField
|
||||
from django.db.models.lookups import (
|
||||
FieldGetDbPrepValueMixin,
|
||||
PostgresOperatorLookup,
|
||||
Transform,
|
||||
)
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from . import Field
|
||||
from .mixins import CheckFieldDefaultMixin
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
|
||||
|
||||
class JSONField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("A JSON object")
|
||||
default_error_messages = {
|
||||
"invalid": _("Value must be valid JSON."),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
verbose_name=None,
|
||||
name=None,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
**kwargs,
|
||||
):
|
||||
if encoder and not callable(encoder):
|
||||
raise ValueError("The encoder parameter must be a callable object.")
|
||||
if decoder and not callable(decoder):
|
||||
raise ValueError("The decoder parameter must be a callable object.")
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
super().__init__(verbose_name, name, **kwargs)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
databases = kwargs.get("databases") or []
|
||||
errors.extend(self._check_supported(databases))
|
||||
return errors
|
||||
|
||||
def _check_supported(self, databases):
|
||||
errors = []
|
||||
for db in databases:
|
||||
if not router.allow_migrate_model(db, self.model):
|
||||
continue
|
||||
connection = connections[db]
|
||||
if (
|
||||
self.model._meta.required_db_vendor
|
||||
and self.model._meta.required_db_vendor != connection.vendor
|
||||
):
|
||||
continue
|
||||
if not (
|
||||
"supports_json_field" in self.model._meta.required_db_features
|
||||
or connection.features.supports_json_field
|
||||
):
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"%s does not support JSONFields." % connection.display_name,
|
||||
obj=self.model,
|
||||
id="fields.E180",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.encoder is not None:
|
||||
kwargs["encoder"] = self.encoder
|
||||
if self.decoder is not None:
|
||||
kwargs["decoder"] = self.decoder
|
||||
return name, path, args, kwargs
|
||||
|
||||
def from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
# Some backends (SQLite at least) extract non-string values in their
|
||||
# SQL datatypes.
|
||||
if isinstance(expression, KeyTransform) and not isinstance(value, str):
|
||||
return value
|
||||
try:
|
||||
return json.loads(value, cls=self.decoder)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
def get_internal_type(self):
|
||||
return "JSONField"
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if not prepared:
|
||||
value = self.get_prep_value(value)
|
||||
return connection.ops.adapt_json_value(value, self.encoder)
|
||||
|
||||
def get_db_prep_save(self, value, connection):
|
||||
# This slightly involved logic is to allow for `None` to be used to
|
||||
# store SQL `NULL` while `Value(None, JSONField())` can be used to
|
||||
# store JSON `null` while preventing compilable `as_sql` values from
|
||||
# making their way to `get_db_prep_value`, which is what the `super()`
|
||||
# implementation does.
|
||||
if value is None:
|
||||
return value
|
||||
if (
|
||||
isinstance(value, expressions.Value)
|
||||
and value.value is None
|
||||
and isinstance(value.output_field, JSONField)
|
||||
):
|
||||
value = None
|
||||
return super().get_db_prep_save(value, connection)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
try:
|
||||
json.dumps(value, cls=self.encoder)
|
||||
except TypeError:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["invalid"],
|
||||
code="invalid",
|
||||
params={"value": value},
|
||||
)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return self.value_from_object(obj)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.JSONField,
|
||||
"encoder": self.encoder,
|
||||
"decoder": self.decoder,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def compile_json_path(key_transforms, include_root=True):
|
||||
path = ["$"] if include_root else []
|
||||
for key_transform in key_transforms:
|
||||
try:
|
||||
num = int(key_transform)
|
||||
except ValueError: # non-integer
|
||||
path.append(".")
|
||||
path.append(json.dumps(key_transform))
|
||||
else:
|
||||
path.append("[%s]" % num)
|
||||
return "".join(path)
|
||||
|
||||
|
||||
class DataContains(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contains"
|
||||
postgres_operator = "@>"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contains lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(lhs_params) + tuple(rhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (lhs, rhs), params
|
||||
|
||||
|
||||
class ContainedBy(FieldGetDbPrepValueMixin, PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
postgres_operator = "<@"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_json_field_contains:
|
||||
raise NotSupportedError(
|
||||
"contained_by lookup is not supported on this database backend."
|
||||
)
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
params = tuple(rhs_params) + tuple(lhs_params)
|
||||
return "JSON_CONTAINS(%s, %s)" % (rhs, lhs), params
|
||||
|
||||
|
||||
class HasKeyLookup(PostgresOperatorLookup):
|
||||
logical_operator = None
|
||||
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
# Compile the final key without interpreting ints as array elements.
|
||||
return ".%s" % json.dumps(key_transform)
|
||||
|
||||
def _as_sql_parts(self, compiler, connection):
|
||||
# Process JSON path from the left-hand side.
|
||||
if isinstance(self.lhs, KeyTransform):
|
||||
lhs_sql, lhs_params, lhs_key_transforms = self.lhs.preprocess_lhs(
|
||||
compiler, connection
|
||||
)
|
||||
lhs_json_path = compile_json_path(lhs_key_transforms)
|
||||
else:
|
||||
lhs_sql, lhs_params = self.process_lhs(compiler, connection)
|
||||
lhs_json_path = "$"
|
||||
# Process JSON path from the right-hand side.
|
||||
rhs = self.rhs
|
||||
if not isinstance(rhs, (list, tuple)):
|
||||
rhs = [rhs]
|
||||
for key in rhs:
|
||||
if isinstance(key, KeyTransform):
|
||||
*_, rhs_key_transforms = key.preprocess_lhs(compiler, connection)
|
||||
else:
|
||||
rhs_key_transforms = [key]
|
||||
*rhs_key_transforms, final_key = rhs_key_transforms
|
||||
rhs_json_path = compile_json_path(rhs_key_transforms, include_root=False)
|
||||
rhs_json_path += self.compile_json_path_final_key(final_key)
|
||||
yield lhs_sql, lhs_params, lhs_json_path + rhs_json_path
|
||||
|
||||
def _combine_sql_parts(self, parts):
|
||||
# Add condition for each key.
|
||||
if self.logical_operator:
|
||||
return "(%s)" % self.logical_operator.join(parts)
|
||||
return "".join(parts)
|
||||
|
||||
def as_sql(self, compiler, connection, template=None):
|
||||
sql_parts = []
|
||||
params = []
|
||||
for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
|
||||
compiler, connection
|
||||
):
|
||||
sql_parts.append(template % (lhs_sql, "%s"))
|
||||
params.extend(lhs_params + [rhs_json_path])
|
||||
return self._combine_sql_parts(sql_parts), tuple(params)
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_CONTAINS_PATH(%s, 'one', %s)"
|
||||
)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
# Use a custom delimiter to prevent the JSON path from escaping the SQL
|
||||
# literal. See comment in KeyTransform.
|
||||
template = "JSON_EXISTS(%s, q'\uffff%s\uffff')"
|
||||
sql_parts = []
|
||||
params = []
|
||||
for lhs_sql, lhs_params, rhs_json_path in self._as_sql_parts(
|
||||
compiler, connection
|
||||
):
|
||||
# Add right-hand-side directly into SQL because it cannot be passed
|
||||
# as bind variables to JSON_EXISTS. It might result in invalid
|
||||
# queries but it is assumed that it cannot be evaded because the
|
||||
# path is JSON serialized.
|
||||
sql_parts.append(template % (lhs_sql, rhs_json_path))
|
||||
params.extend(lhs_params)
|
||||
return self._combine_sql_parts(sql_parts), tuple(params)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
*_, rhs_key_transforms = self.rhs.preprocess_lhs(compiler, connection)
|
||||
for key in rhs_key_transforms[:-1]:
|
||||
self.lhs = KeyTransform(key, self.lhs)
|
||||
self.rhs = rhs_key_transforms[-1]
|
||||
return super().as_postgresql(compiler, connection)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
return self.as_sql(
|
||||
compiler, connection, template="JSON_TYPE(%s, %s) IS NOT NULL"
|
||||
)
|
||||
|
||||
|
||||
class HasKey(HasKeyLookup):
|
||||
lookup_name = "has_key"
|
||||
postgres_operator = "?"
|
||||
prepare_rhs = False
|
||||
|
||||
|
||||
class HasKeys(HasKeyLookup):
|
||||
lookup_name = "has_keys"
|
||||
postgres_operator = "?&"
|
||||
logical_operator = " AND "
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return [str(item) for item in self.rhs]
|
||||
|
||||
|
||||
class HasAnyKeys(HasKeys):
|
||||
lookup_name = "has_any_keys"
|
||||
postgres_operator = "?|"
|
||||
logical_operator = " OR "
|
||||
|
||||
|
||||
class HasKeyOrArrayIndex(HasKey):
|
||||
def compile_json_path_final_key(self, key_transform):
|
||||
return compile_json_path([key_transform], include_root=False)
|
||||
|
||||
|
||||
class CaseInsensitiveMixin:
|
||||
"""
|
||||
Mixin to allow case-insensitive comparison of JSON values on MySQL.
|
||||
MySQL handles strings used in JSON context using the utf8mb4_bin collation.
|
||||
Because utf8mb4_bin is a binary collation, comparison of JSON values is
|
||||
case-sensitive.
|
||||
"""
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % lhs, lhs_params
|
||||
return lhs, lhs_params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "mysql":
|
||||
return "LOWER(%s)" % rhs, rhs_params
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class JSONExact(lookups.Exact):
|
||||
can_use_none_as_rhs = True
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Treat None lookup values as null.
|
||||
if rhs == "%s" and rhs_params == [None]:
|
||||
rhs_params = ["null"]
|
||||
if connection.vendor == "mysql":
|
||||
func = ["JSON_EXTRACT(%s, '$')"] * len(rhs_params)
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, lhs_params = self.process_lhs(compiler, connection)
|
||||
rhs, rhs_params = self.process_rhs(compiler, connection)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
lhs = f"JSON({lhs})"
|
||||
rhs = f"JSON({rhs})"
|
||||
return f"JSON_EQUAL({lhs}, {rhs} ERROR ON ERROR)", (*lhs_params, *rhs_params)
|
||||
|
||||
|
||||
class JSONIContains(CaseInsensitiveMixin, lookups.IContains):
|
||||
pass
|
||||
|
||||
|
||||
JSONField.register_lookup(DataContains)
|
||||
JSONField.register_lookup(ContainedBy)
|
||||
JSONField.register_lookup(HasKey)
|
||||
JSONField.register_lookup(HasKeys)
|
||||
JSONField.register_lookup(HasAnyKeys)
|
||||
JSONField.register_lookup(JSONExact)
|
||||
JSONField.register_lookup(JSONIContains)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
postgres_operator = "->"
|
||||
postgres_nested_operator = "#>"
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = str(key_name)
|
||||
|
||||
def preprocess_lhs(self, compiler, connection):
|
||||
key_transforms = [self.key_name]
|
||||
previous = self.lhs
|
||||
while isinstance(previous, KeyTransform):
|
||||
key_transforms.insert(0, previous.key_name)
|
||||
previous = previous.lhs
|
||||
lhs, params = compiler.compile(previous)
|
||||
if connection.vendor == "oracle":
|
||||
# Escape string-formatting.
|
||||
key_transforms = [key.replace("%", "%%") for key in key_transforms]
|
||||
return lhs, params, key_transforms
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "JSON_EXTRACT(%s, %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
if connection.features.supports_primitives_in_json_field:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff' DISALLOW SCALARS)"
|
||||
")"
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"COALESCE("
|
||||
"JSON_QUERY(%s, q'\uffff%s\uffff'),"
|
||||
"JSON_VALUE(%s, q'\uffff%s\uffff')"
|
||||
")"
|
||||
)
|
||||
# Add paths directly into SQL because path expressions cannot be passed
|
||||
# as bind variables on Oracle. Use a custom delimiter to prevent the
|
||||
# JSON path from escaping the SQL literal. Each key in the JSON path is
|
||||
# passed through json.dumps() with ensure_ascii=True (the default),
|
||||
# which converts the delimiter into the escaped \uffff format. This
|
||||
# ensures that the delimiter is not present in the JSON path.
|
||||
return sql % ((lhs, json_path) * 2), tuple(params) * 2
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
if len(key_transforms) > 1:
|
||||
sql = "(%s %s %%s)" % (lhs, self.postgres_nested_operator)
|
||||
return sql, tuple(params) + (key_transforms,)
|
||||
try:
|
||||
lookup = int(self.key_name)
|
||||
except ValueError:
|
||||
lookup = self.key_name
|
||||
return "(%s %s %%s)" % (lhs, self.postgres_operator), tuple(params) + (lookup,)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
datatype_values = ",".join(
|
||||
[repr(datatype) for datatype in connection.ops.jsonfield_datatype_values]
|
||||
)
|
||||
return (
|
||||
"(CASE WHEN JSON_TYPE(%s, %%s) IN (%s) "
|
||||
"THEN JSON_TYPE(%s, %%s) ELSE JSON_EXTRACT(%s, %%s) END)"
|
||||
) % (lhs, datatype_values, lhs, lhs), (tuple(params) + (json_path,)) * 3
|
||||
|
||||
|
||||
class KeyTextTransform(KeyTransform):
|
||||
postgres_operator = "->>"
|
||||
postgres_nested_operator = "#>>"
|
||||
output_field = TextField()
|
||||
|
||||
def as_mysql(self, compiler, connection):
|
||||
if connection.mysql_is_mariadb:
|
||||
# MariaDB doesn't support -> and ->> operators (see MDEV-13594).
|
||||
sql, params = super().as_mysql(compiler, connection)
|
||||
return "JSON_UNQUOTE(%s)" % sql, params
|
||||
else:
|
||||
lhs, params, key_transforms = self.preprocess_lhs(compiler, connection)
|
||||
json_path = compile_json_path(key_transforms)
|
||||
return "(%s ->> %%s)" % lhs, tuple(params) + (json_path,)
|
||||
|
||||
@classmethod
|
||||
def from_lookup(cls, lookup):
|
||||
transform, *keys = lookup.split(LOOKUP_SEP)
|
||||
if not keys:
|
||||
raise ValueError("Lookup must contain key or index transforms.")
|
||||
for key in keys:
|
||||
transform = cls(key, transform)
|
||||
return transform
|
||||
|
||||
|
||||
KT = KeyTextTransform.from_lookup
|
||||
|
||||
|
||||
class KeyTransformTextLookupMixin:
|
||||
"""
|
||||
Mixin for combining with a lookup expecting a text lhs from a JSONField
|
||||
key lookup. On PostgreSQL, make use of the ->> operator instead of casting
|
||||
key values to text and performing the lookup on the resulting
|
||||
representation.
|
||||
"""
|
||||
|
||||
def __init__(self, key_transform, *args, **kwargs):
|
||||
if not isinstance(key_transform, KeyTransform):
|
||||
raise TypeError(
|
||||
"Transform should be an instance of KeyTransform in order to "
|
||||
"use this lookup."
|
||||
)
|
||||
key_text_transform = KeyTextTransform(
|
||||
key_transform.key_name,
|
||||
*key_transform.source_expressions,
|
||||
**key_transform.extra,
|
||||
)
|
||||
super().__init__(key_text_transform, *args, **kwargs)
|
||||
|
||||
|
||||
class KeyTransformIsNull(lookups.IsNull):
|
||||
# key__isnull=False is the same as has_key='key'
|
||||
def as_oracle(self, compiler, connection):
|
||||
sql, params = HasKeyOrArrayIndex(
|
||||
self.lhs.lhs,
|
||||
self.lhs.key_name,
|
||||
).as_oracle(compiler, connection)
|
||||
if not self.rhs:
|
||||
return sql, params
|
||||
# Column doesn't have a key or IS NULL.
|
||||
lhs, lhs_params, _ = self.lhs.preprocess_lhs(compiler, connection)
|
||||
return "(NOT %s OR %s IS NULL)" % (sql, lhs), tuple(params) + tuple(lhs_params)
|
||||
|
||||
def as_sqlite(self, compiler, connection):
|
||||
template = "JSON_TYPE(%s, %s) IS NULL"
|
||||
if not self.rhs:
|
||||
template = "JSON_TYPE(%s, %s) IS NOT NULL"
|
||||
return HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name).as_sql(
|
||||
compiler,
|
||||
connection,
|
||||
template=template,
|
||||
)
|
||||
|
||||
|
||||
class KeyTransformIn(lookups.In):
|
||||
def resolve_expression_parameter(self, compiler, connection, sql, param):
|
||||
sql, params = super().resolve_expression_parameter(
|
||||
compiler,
|
||||
connection,
|
||||
sql,
|
||||
param,
|
||||
)
|
||||
if (
|
||||
not hasattr(param, "as_sql")
|
||||
and not connection.features.has_native_json_field
|
||||
):
|
||||
if connection.vendor == "oracle":
|
||||
value = json.loads(param)
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
if isinstance(value, (list, dict)):
|
||||
sql %= "JSON_QUERY"
|
||||
else:
|
||||
sql %= "JSON_VALUE"
|
||||
elif connection.vendor == "mysql" or (
|
||||
connection.vendor == "sqlite"
|
||||
and params[0] not in connection.ops.jsonfield_datatype_values
|
||||
):
|
||||
sql = "JSON_EXTRACT(%s, '$')"
|
||||
if connection.vendor == "mysql" and connection.mysql_is_mariadb:
|
||||
sql = "JSON_UNQUOTE(%s)" % sql
|
||||
return sql, params
|
||||
|
||||
|
||||
class KeyTransformExact(JSONExact):
|
||||
def process_rhs(self, compiler, connection):
|
||||
if isinstance(self.rhs, KeyTransform):
|
||||
return super(lookups.Exact, self).process_rhs(compiler, connection)
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if connection.vendor == "oracle":
|
||||
func = []
|
||||
sql = "%s(JSON_OBJECT('value' VALUE %%s FORMAT JSON), '$.value')"
|
||||
for value in rhs_params:
|
||||
value = json.loads(value)
|
||||
if isinstance(value, (list, dict)):
|
||||
func.append(sql % "JSON_QUERY")
|
||||
else:
|
||||
func.append(sql % "JSON_VALUE")
|
||||
rhs %= tuple(func)
|
||||
elif connection.vendor == "sqlite":
|
||||
func = []
|
||||
for value in rhs_params:
|
||||
if value in connection.ops.jsonfield_datatype_values:
|
||||
func.append("%s")
|
||||
else:
|
||||
func.append("JSON_EXTRACT(%s, '$')")
|
||||
rhs %= tuple(func)
|
||||
return rhs, rhs_params
|
||||
|
||||
def as_oracle(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if rhs_params == ["null"]:
|
||||
# Field has key and it's NULL.
|
||||
has_key_expr = HasKeyOrArrayIndex(self.lhs.lhs, self.lhs.key_name)
|
||||
has_key_sql, has_key_params = has_key_expr.as_oracle(compiler, connection)
|
||||
is_null_expr = self.lhs.get_lookup("isnull")(self.lhs, True)
|
||||
is_null_sql, is_null_params = is_null_expr.as_sql(compiler, connection)
|
||||
return (
|
||||
"%s AND %s" % (has_key_sql, is_null_sql),
|
||||
tuple(has_key_params) + tuple(is_null_params),
|
||||
)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class KeyTransformIExact(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IExact
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIContains(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IContains
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformStartsWith(KeyTransformTextLookupMixin, lookups.StartsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIStartsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IStartsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformEndsWith(KeyTransformTextLookupMixin, lookups.EndsWith):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIEndsWith(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IEndsWith
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformRegex(KeyTransformTextLookupMixin, lookups.Regex):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformIRegex(
|
||||
CaseInsensitiveMixin, KeyTransformTextLookupMixin, lookups.IRegex
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformNumericLookupMixin:
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
if not connection.features.has_native_json_field:
|
||||
rhs_params = [json.loads(value) for value in rhs_params]
|
||||
return rhs, rhs_params
|
||||
|
||||
|
||||
class KeyTransformLt(KeyTransformNumericLookupMixin, lookups.LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformLte(KeyTransformNumericLookupMixin, lookups.LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGt(KeyTransformNumericLookupMixin, lookups.GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class KeyTransformGte(KeyTransformNumericLookupMixin, lookups.GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformIn)
|
||||
KeyTransform.register_lookup(KeyTransformExact)
|
||||
KeyTransform.register_lookup(KeyTransformIExact)
|
||||
KeyTransform.register_lookup(KeyTransformIsNull)
|
||||
KeyTransform.register_lookup(KeyTransformIContains)
|
||||
KeyTransform.register_lookup(KeyTransformStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIStartsWith)
|
||||
KeyTransform.register_lookup(KeyTransformEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformIEndsWith)
|
||||
KeyTransform.register_lookup(KeyTransformRegex)
|
||||
KeyTransform.register_lookup(KeyTransformIRegex)
|
||||
|
||||
KeyTransform.register_lookup(KeyTransformLt)
|
||||
KeyTransform.register_lookup(KeyTransformLte)
|
||||
KeyTransform.register_lookup(KeyTransformGt)
|
||||
KeyTransform.register_lookup(KeyTransformGte)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return KeyTransform(self.key_name, *args, **kwargs)
|
||||
@@ -0,0 +1,81 @@
|
||||
import warnings
|
||||
|
||||
from django.core import checks
|
||||
from django.utils.deprecation import RemovedInDjango60Warning
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
NOT_PROVIDED = object()
|
||||
|
||||
|
||||
class FieldCacheMixin:
|
||||
"""
|
||||
An API for working with the model's fields value cache.
|
||||
|
||||
Subclasses must set self.cache_name to a unique entry for the cache -
|
||||
typically the field’s name.
|
||||
"""
|
||||
|
||||
# RemovedInDjango60Warning.
|
||||
def get_cache_name(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@cached_property
|
||||
def cache_name(self):
|
||||
# RemovedInDjango60Warning: when the deprecation ends, replace with:
|
||||
# raise NotImplementedError
|
||||
cache_name = self.get_cache_name()
|
||||
warnings.warn(
|
||||
f"Override {self.__class__.__qualname__}.cache_name instead of "
|
||||
"get_cache_name().",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=3,
|
||||
)
|
||||
return cache_name
|
||||
|
||||
def get_cached_value(self, instance, default=NOT_PROVIDED):
|
||||
try:
|
||||
return instance._state.fields_cache[self.cache_name]
|
||||
except KeyError:
|
||||
if default is NOT_PROVIDED:
|
||||
raise
|
||||
return default
|
||||
|
||||
def is_cached(self, instance):
|
||||
return self.cache_name in instance._state.fields_cache
|
||||
|
||||
def set_cached_value(self, instance, value):
|
||||
instance._state.fields_cache[self.cache_name] = value
|
||||
|
||||
def delete_cached_value(self, instance):
|
||||
del instance._state.fields_cache[self.cache_name]
|
||||
|
||||
|
||||
class CheckFieldDefaultMixin:
|
||||
_default_hint = ("<valid default>", "<invalid default>")
|
||||
|
||||
def _check_default(self):
|
||||
if (
|
||||
self.has_default()
|
||||
and self.default is not None
|
||||
and not callable(self.default)
|
||||
):
|
||||
return [
|
||||
checks.Warning(
|
||||
"%s default should be a callable instead of an instance "
|
||||
"so that it's not shared between all field instances."
|
||||
% (self.__class__.__name__,),
|
||||
hint=(
|
||||
"Use a callable instead, e.g., use `%s` instead of "
|
||||
"`%s`." % self._default_hint
|
||||
),
|
||||
obj=self,
|
||||
id="fields.E010",
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
errors.extend(self._check_default())
|
||||
return errors
|
||||
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Field-like classes that aren't really fields. It's easier to use objects that
|
||||
have the same attributes as fields sometimes (avoids a lot of special casing).
|
||||
"""
|
||||
|
||||
from django.db.models import fields
|
||||
|
||||
|
||||
class OrderWrt(fields.IntegerField):
|
||||
"""
|
||||
A proxy for the _order database field that is used when
|
||||
Meta.order_with_respect_to is specified.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["name"] = "_order"
|
||||
kwargs["editable"] = False
|
||||
super().__init__(*args, **kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,151 @@
|
||||
from django.db.models.expressions import ColPairs
|
||||
from django.db.models.fields import composite
|
||||
from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
|
||||
|
||||
def get_normalized_value(value, lhs):
|
||||
from django.db.models import Model
|
||||
|
||||
if isinstance(value, Model):
|
||||
if not value._is_pk_set():
|
||||
raise ValueError("Model instances passed to related filters must be saved.")
|
||||
value_list = []
|
||||
sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields)
|
||||
for source in sources:
|
||||
while not isinstance(value, source.model) and source.remote_field:
|
||||
source = source.remote_field.model._meta.get_field(
|
||||
source.remote_field.field_name
|
||||
)
|
||||
try:
|
||||
value_list.append(getattr(value, source.attname))
|
||||
except AttributeError:
|
||||
# A case like Restaurant.objects.filter(place=restaurant_instance),
|
||||
# where place is a OneToOneField and the primary key of Restaurant.
|
||||
pk = value.pk
|
||||
return pk if isinstance(pk, tuple) else (pk,)
|
||||
return tuple(value_list)
|
||||
if not isinstance(value, tuple):
|
||||
return (value,)
|
||||
return value
|
||||
|
||||
|
||||
class RelatedIn(In):
|
||||
def get_prep_lookup(self):
|
||||
from django.db.models.sql.query import Query # avoid circular import
|
||||
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if (
|
||||
isinstance(self.rhs, Query)
|
||||
and not self.rhs.has_select_fields
|
||||
and self.lhs.output_field.related_model is self.rhs.model
|
||||
):
|
||||
self.rhs.set_values([f.name for f in self.lhs.sources])
|
||||
else:
|
||||
if self.rhs_is_direct_value():
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
|
||||
# We need to run the related field's get_prep_value(). Consider
|
||||
# case ForeignKey to IntegerField given value 'abc'. The
|
||||
# ForeignKey itself doesn't have validation for non-integers,
|
||||
# so we must run validation using the target field.
|
||||
if hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Run the target field's get_prep_value. We can safely
|
||||
# assume there is only one as we don't get to the direct
|
||||
# value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[
|
||||
-1
|
||||
]
|
||||
self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
|
||||
elif not getattr(self.rhs, "has_select_fields", True) and not getattr(
|
||||
self.lhs.field.target_field, "primary_key", False
|
||||
):
|
||||
if (
|
||||
getattr(self.lhs.output_field, "primary_key", False)
|
||||
and self.lhs.output_field.model == self.rhs.model
|
||||
):
|
||||
# A case like
|
||||
# Restaurant.objects.filter(place__in=restaurant_qs), where
|
||||
# place is a OneToOneField and the primary key of
|
||||
# Restaurant.
|
||||
target_field = self.lhs.field.name
|
||||
else:
|
||||
target_field = self.lhs.field.target_field.name
|
||||
self.rhs.set_values([target_field])
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if self.rhs_is_direct_value():
|
||||
values = [get_normalized_value(value, self.lhs) for value in self.rhs]
|
||||
lookup = TupleIn(self.lhs, values)
|
||||
else:
|
||||
lookup = TupleIn(self.lhs, self.rhs)
|
||||
return compiler.compile(lookup)
|
||||
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedLookupMixin:
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.lhs, ColPairs) and not hasattr(
|
||||
self.rhs, "resolve_expression"
|
||||
):
|
||||
# If we get here, we are dealing with single-column relations.
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
|
||||
# We need to run the related field's get_prep_value(). Consider case
|
||||
# ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
|
||||
# doesn't have validation for non-integers, so we must run validation
|
||||
# using the target field.
|
||||
if self.prepare_rhs and hasattr(self.lhs.output_field, "path_infos"):
|
||||
# Get the target field. We can safely assume there is only one
|
||||
# as we don't get to the direct value branch otherwise.
|
||||
target_field = self.lhs.output_field.path_infos[-1].target_fields[-1]
|
||||
self.rhs = target_field.get_prep_value(self.rhs)
|
||||
|
||||
return super().get_prep_lookup()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
if not self.rhs_is_direct_value():
|
||||
raise ValueError(
|
||||
f"'{self.lookup_name}' doesn't support multi-column subqueries."
|
||||
)
|
||||
self.rhs = get_normalized_value(self.rhs, self.lhs)
|
||||
lookup_class = tuple_lookups[self.lookup_name]
|
||||
lookup = lookup_class(self.lhs, self.rhs)
|
||||
return compiler.compile(lookup)
|
||||
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class RelatedExact(RelatedLookupMixin, Exact):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThan(RelatedLookupMixin, LessThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
|
||||
pass
|
||||
|
||||
|
||||
class RelatedIsNull(RelatedLookupMixin, IsNull):
|
||||
pass
|
||||
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
"Rel objects" for related fields.
|
||||
|
||||
"Rel objects" (for lack of a better name) carry information about the relation
|
||||
modeled by a related field and provide some utility functions. They're stored
|
||||
in the ``remote_field`` attribute of the field.
|
||||
|
||||
They also act as reverse fields for the purposes of the Meta API because
|
||||
they're the closest concept currently available.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from django.core import exceptions
|
||||
from django.utils.deprecation import RemovedInDjango60Warning
|
||||
from django.utils.functional import cached_property
|
||||
from django.utils.hashable import make_hashable
|
||||
|
||||
from . import BLANK_CHOICE_DASH
|
||||
from .mixins import FieldCacheMixin
|
||||
|
||||
|
||||
class ForeignObjectRel(FieldCacheMixin):
|
||||
"""
|
||||
Used by ForeignObject to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
# Field flags
|
||||
auto_created = True
|
||||
concrete = False
|
||||
editable = False
|
||||
is_relation = True
|
||||
|
||||
# Reverse relations are always nullable (Django can't enforce that a
|
||||
# foreign key on the related model points to this model).
|
||||
null = True
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
self.field = field
|
||||
self.model = to
|
||||
self.related_name = related_name
|
||||
self.related_query_name = related_query_name
|
||||
self.limit_choices_to = {} if limit_choices_to is None else limit_choices_to
|
||||
self.parent_link = parent_link
|
||||
self.on_delete = on_delete
|
||||
|
||||
self.symmetrical = False
|
||||
self.multiple = True
|
||||
|
||||
# Some of the following cached_properties can't be initialized in
|
||||
# __init__ as the field doesn't have its model yet. Calling these methods
|
||||
# before field.contribute_to_class() has been called will result in
|
||||
# AttributeError
|
||||
@cached_property
|
||||
def hidden(self):
|
||||
"""Should the related object be hidden?"""
|
||||
return bool(self.related_name) and self.related_name[-1] == "+"
|
||||
|
||||
@cached_property
|
||||
def name(self):
|
||||
return self.field.related_query_name()
|
||||
|
||||
@property
|
||||
def remote_field(self):
|
||||
return self.field
|
||||
|
||||
@property
|
||||
def target_field(self):
|
||||
"""
|
||||
When filtering against this relation, return the field on the remote
|
||||
model against which the filtering should happen.
|
||||
"""
|
||||
target_fields = self.path_infos[-1].target_fields
|
||||
if len(target_fields) > 1:
|
||||
raise exceptions.FieldError(
|
||||
"Can't use target_field for multicolumn relations."
|
||||
)
|
||||
return target_fields[0]
|
||||
|
||||
@cached_property
|
||||
def related_model(self):
|
||||
if not self.field.model:
|
||||
raise AttributeError(
|
||||
"This property can't be accessed before self.field.contribute_to_class "
|
||||
"has been called."
|
||||
)
|
||||
return self.field.model
|
||||
|
||||
@cached_property
|
||||
def many_to_many(self):
|
||||
return self.field.many_to_many
|
||||
|
||||
@cached_property
|
||||
def many_to_one(self):
|
||||
return self.field.one_to_many
|
||||
|
||||
@cached_property
|
||||
def one_to_many(self):
|
||||
return self.field.many_to_one
|
||||
|
||||
@cached_property
|
||||
def one_to_one(self):
|
||||
return self.field.one_to_one
|
||||
|
||||
def get_lookup(self, lookup_name):
|
||||
return self.field.get_lookup(lookup_name)
|
||||
|
||||
def get_lookups(self):
|
||||
return self.field.get_lookups()
|
||||
|
||||
def get_transform(self, name):
|
||||
return self.field.get_transform(name)
|
||||
|
||||
def get_internal_type(self):
|
||||
return self.field.get_internal_type()
|
||||
|
||||
@property
|
||||
def db_type(self):
|
||||
return self.field.db_type
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s: %s.%s>" % (
|
||||
type(self).__name__,
|
||||
self.related_model._meta.app_label,
|
||||
self.related_model._meta.model_name,
|
||||
)
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return (
|
||||
self.field,
|
||||
self.model,
|
||||
self.related_name,
|
||||
self.related_query_name,
|
||||
make_hashable(self.limit_choices_to),
|
||||
self.parent_link,
|
||||
self.on_delete,
|
||||
self.symmetrical,
|
||||
self.multiple,
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, self.__class__):
|
||||
return NotImplemented
|
||||
return self.identity == other.identity
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.identity)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
# Delete the path_infos cached property because it can be recalculated
|
||||
# at first invocation after deserialization. The attribute must be
|
||||
# removed because subclasses like ManyToOneRel may have a PathInfo
|
||||
# which contains an intermediate M2M table that's been dynamically
|
||||
# created and doesn't exist in the .models module.
|
||||
# This is a reverse relation, so there is no reverse_path_infos to
|
||||
# delete.
|
||||
state.pop("path_infos", None)
|
||||
return state
|
||||
|
||||
def get_choices(
|
||||
self,
|
||||
include_blank=True,
|
||||
blank_choice=BLANK_CHOICE_DASH,
|
||||
limit_choices_to=None,
|
||||
ordering=(),
|
||||
):
|
||||
"""
|
||||
Return choices with a default blank choices included, for use
|
||||
as <select> choices for this field.
|
||||
|
||||
Analog of django.db.models.fields.Field.get_choices(), provided
|
||||
initially for utilization by RelatedFieldListFilter.
|
||||
"""
|
||||
limit_choices_to = limit_choices_to or self.limit_choices_to
|
||||
qs = self.related_model._default_manager.complex_filter(limit_choices_to)
|
||||
if ordering:
|
||||
qs = qs.order_by(*ordering)
|
||||
return (blank_choice if include_blank else []) + [(x.pk, str(x)) for x in qs]
|
||||
|
||||
def get_joining_columns(self):
|
||||
warnings.warn(
|
||||
"ForeignObjectRel.get_joining_columns() is deprecated. Use "
|
||||
"get_joining_fields() instead.",
|
||||
RemovedInDjango60Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return self.field.get_reverse_joining_columns()
|
||||
|
||||
def get_joining_fields(self):
|
||||
return self.field.get_reverse_joining_fields()
|
||||
|
||||
def get_extra_restriction(self, alias, related_alias):
|
||||
return self.field.get_extra_restriction(related_alias, alias)
|
||||
|
||||
def set_field_name(self):
|
||||
"""
|
||||
Set the related field's name, this is not available until later stages
|
||||
of app loading, so set_field_name is called from
|
||||
set_attributes_from_rel()
|
||||
"""
|
||||
# By default foreign object doesn't relate to any remote field (for
|
||||
# example custom multicolumn joins currently have no remote field).
|
||||
self.field_name = None
|
||||
|
||||
@cached_property
|
||||
def accessor_name(self):
|
||||
return self.get_accessor_name()
|
||||
|
||||
def get_accessor_name(self, model=None):
|
||||
# This method encapsulates the logic that decides what name to give an
|
||||
# accessor descriptor that retrieves related many-to-one or
|
||||
# many-to-many objects. It uses the lowercased object_name + "_set",
|
||||
# but this can be overridden with the "related_name" option. Due to
|
||||
# backwards compatibility ModelForms need to be able to provide an
|
||||
# alternate model. See BaseInlineFormSet.get_default_prefix().
|
||||
opts = model._meta if model else self.related_model._meta
|
||||
model = model or self.related_model
|
||||
if self.multiple:
|
||||
# If this is a symmetrical m2m relation on self, there is no
|
||||
# reverse accessor.
|
||||
if self.symmetrical and model == self.model:
|
||||
return None
|
||||
if self.related_name:
|
||||
return self.related_name
|
||||
return opts.model_name + ("_set" if self.multiple else "")
|
||||
|
||||
def get_path_info(self, filtered_relation=None):
|
||||
if filtered_relation:
|
||||
return self.field.get_reverse_path_info(filtered_relation)
|
||||
else:
|
||||
return self.field.reverse_path_infos
|
||||
|
||||
@cached_property
|
||||
def path_infos(self):
|
||||
return self.get_path_info()
|
||||
|
||||
@cached_property
|
||||
def cache_name(self):
|
||||
"""
|
||||
Return the name of the cache key to use for storing an instance of the
|
||||
forward model on the reverse model.
|
||||
"""
|
||||
return self.accessor_name
|
||||
|
||||
|
||||
class ManyToOneRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by the ForeignKey field to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
|
||||
Note: Because we somewhat abuse the Rel objects by using them as reverse
|
||||
fields we get the funny situation where
|
||||
``ManyToOneRel.many_to_one == False`` and
|
||||
``ManyToOneRel.one_to_many == True``. This is unfortunate but the actual
|
||||
ManyToOneRel class is a private API and there is work underway to turn
|
||||
reverse relations into actual fields.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.field_name = field_name
|
||||
|
||||
def __getstate__(self):
|
||||
state = super().__getstate__()
|
||||
state.pop("related_model", None)
|
||||
return state
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (self.field_name,)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the Field in the 'to' object to which this relationship is tied.
|
||||
"""
|
||||
field = self.model._meta.get_field(self.field_name)
|
||||
if not field.concrete:
|
||||
raise exceptions.FieldDoesNotExist(
|
||||
"No related field named '%s'" % self.field_name
|
||||
)
|
||||
return field
|
||||
|
||||
def set_field_name(self):
|
||||
self.field_name = self.field_name or self.model._meta.pk.name
|
||||
|
||||
|
||||
class OneToOneRel(ManyToOneRel):
|
||||
"""
|
||||
Used by OneToOneField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
parent_link=False,
|
||||
on_delete=None,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
field_name,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
parent_link=parent_link,
|
||||
on_delete=on_delete,
|
||||
)
|
||||
|
||||
self.multiple = False
|
||||
|
||||
|
||||
class ManyToManyRel(ForeignObjectRel):
|
||||
"""
|
||||
Used by ManyToManyField to store information about the relation.
|
||||
|
||||
``_meta.get_fields()`` returns this class to provide access to the field
|
||||
flags for the reverse relation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field,
|
||||
to,
|
||||
related_name=None,
|
||||
related_query_name=None,
|
||||
limit_choices_to=None,
|
||||
symmetrical=True,
|
||||
through=None,
|
||||
through_fields=None,
|
||||
db_constraint=True,
|
||||
):
|
||||
super().__init__(
|
||||
field,
|
||||
to,
|
||||
related_name=related_name,
|
||||
related_query_name=related_query_name,
|
||||
limit_choices_to=limit_choices_to,
|
||||
)
|
||||
|
||||
if through and not db_constraint:
|
||||
raise ValueError("Can't supply a through model and db_constraint=False")
|
||||
self.through = through
|
||||
|
||||
if through_fields and not through:
|
||||
raise ValueError("Cannot specify through_fields without a through model")
|
||||
self.through_fields = through_fields
|
||||
|
||||
self.symmetrical = symmetrical
|
||||
self.db_constraint = db_constraint
|
||||
|
||||
@property
|
||||
def identity(self):
|
||||
return super().identity + (
|
||||
self.through,
|
||||
make_hashable(self.through_fields),
|
||||
self.db_constraint,
|
||||
)
|
||||
|
||||
def get_related_field(self):
|
||||
"""
|
||||
Return the field in the 'to' object to which this relationship is tied.
|
||||
Provided for symmetry with ManyToOneRel.
|
||||
"""
|
||||
opts = self.through._meta
|
||||
if self.through_fields:
|
||||
field = opts.get_field(self.through_fields[0])
|
||||
else:
|
||||
for field in opts.fields:
|
||||
rel = getattr(field, "remote_field", None)
|
||||
if rel and rel.model == self.model:
|
||||
break
|
||||
return field.foreign_related_fields[0]
|
||||
@@ -0,0 +1,359 @@
|
||||
import itertools
|
||||
|
||||
from django.core.exceptions import EmptyResultSet
|
||||
from django.db.models import Field
|
||||
from django.db.models.expressions import (
|
||||
ColPairs,
|
||||
Func,
|
||||
ResolvedOuterRef,
|
||||
Subquery,
|
||||
Value,
|
||||
)
|
||||
from django.db.models.lookups import (
|
||||
Exact,
|
||||
GreaterThan,
|
||||
GreaterThanOrEqual,
|
||||
In,
|
||||
IsNull,
|
||||
LessThan,
|
||||
LessThanOrEqual,
|
||||
)
|
||||
from django.db.models.sql import Query
|
||||
from django.db.models.sql.where import AND, OR, WhereNode
|
||||
|
||||
|
||||
class Tuple(Func):
|
||||
allows_composite_expressions = True
|
||||
function = ""
|
||||
output_field = Field()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.source_expressions)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.source_expressions)
|
||||
|
||||
|
||||
class TupleLookupMixin:
|
||||
allows_composite_expressions = True
|
||||
|
||||
def get_prep_lookup(self):
|
||||
if self.rhs_is_direct_value():
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_length_equals_lhs_length()
|
||||
else:
|
||||
self.check_rhs_is_supported_expression()
|
||||
super().get_prep_lookup()
|
||||
return self.rhs
|
||||
|
||||
def check_rhs_is_tuple_or_list(self):
|
||||
if not isinstance(self.rhs, (tuple, list)):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} must be a tuple or a list"
|
||||
)
|
||||
|
||||
def check_rhs_length_equals_lhs_length(self):
|
||||
len_lhs = len(self.lhs)
|
||||
if len_lhs != len(self.rhs):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
|
||||
)
|
||||
|
||||
def check_rhs_is_supported_expression(self):
|
||||
if not isinstance(self.rhs, (ResolvedOuterRef, Query)):
|
||||
lhs_str = self.get_lhs_str()
|
||||
rhs_cls = self.rhs.__class__.__name__
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
||||
f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
|
||||
)
|
||||
|
||||
def get_lhs_str(self):
|
||||
if isinstance(self.lhs, ColPairs):
|
||||
return repr(self.lhs.field.name)
|
||||
else:
|
||||
names = ", ".join(repr(f.name) for f in self.lhs)
|
||||
return f"({names})"
|
||||
|
||||
def get_prep_lhs(self):
|
||||
if isinstance(self.lhs, (tuple, list)):
|
||||
return Tuple(*self.lhs)
|
||||
return super().get_prep_lhs()
|
||||
|
||||
def process_lhs(self, compiler, connection, lhs=None):
|
||||
sql, params = super().process_lhs(compiler, connection, lhs)
|
||||
if not isinstance(self.lhs, Tuple):
|
||||
sql = f"({sql})"
|
||||
return sql, params
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if self.rhs_is_direct_value():
|
||||
args = [
|
||||
Value(val, output_field=col.output_field)
|
||||
for col, val in zip(self.lhs, self.rhs)
|
||||
]
|
||||
return compiler.compile(Tuple(*args))
|
||||
else:
|
||||
sql, params = compiler.compile(self.rhs)
|
||||
if isinstance(self.rhs, ColPairs):
|
||||
return "(%s)" % sql, params
|
||||
elif isinstance(self.rhs, Query):
|
||||
return super().process_rhs(compiler, connection)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Composite field lookups only work with composite expressions."
|
||||
)
|
||||
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__}.get_fallback_sql() must be implemented "
|
||||
f"for backends that don't have the supports_tuple_lookups feature enabled."
|
||||
)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
if not connection.features.supports_tuple_lookups:
|
||||
return self.get_fallback_sql(compiler, connection)
|
||||
return super().as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleExact(TupleLookupMixin, Exact):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
if isinstance(self.rhs, Query):
|
||||
return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) == (x, y, z) as SQL:
|
||||
# WHERE a = x AND b = y AND c = z
|
||||
lookups = [Exact(col, val) for col, val in zip(self.lhs, self.rhs)]
|
||||
root = WhereNode(lookups, connector=AND)
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleIsNull(TupleLookupMixin, IsNull):
|
||||
def get_prep_lookup(self):
|
||||
rhs = self.rhs
|
||||
if isinstance(rhs, (tuple, list)) and len(rhs) == 1:
|
||||
rhs = rhs[0]
|
||||
if isinstance(rhs, bool):
|
||||
return rhs
|
||||
raise ValueError(
|
||||
"The QuerySet value for an isnull lookup must be True or False."
|
||||
)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
# e.g.: (a, b, c) is None as SQL:
|
||||
# WHERE a IS NULL OR b IS NULL OR c IS NULL
|
||||
# e.g.: (a, b, c) is not None as SQL:
|
||||
# WHERE a IS NOT NULL AND b IS NOT NULL AND c IS NOT NULL
|
||||
rhs = self.rhs
|
||||
lookups = [IsNull(col, rhs) for col in self.lhs]
|
||||
root = WhereNode(lookups, connector=OR if rhs else AND)
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleGreaterThan(TupleLookupMixin, GreaterThan):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) > (x, y, z) as SQL:
|
||||
# WHERE a > x OR (a = x AND (b > y OR (b = y AND c > z)))
|
||||
lookups = itertools.cycle([GreaterThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list[:-1])
|
||||
vals_iter = iter(vals_list[:-1])
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleGreaterThanOrEqual(TupleLookupMixin, GreaterThanOrEqual):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) >= (x, y, z) as SQL:
|
||||
# WHERE a > x OR (a = x AND (b > y OR (b = y AND (c > z OR c = z))))
|
||||
lookups = itertools.cycle([GreaterThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list)
|
||||
vals_iter = iter(vals_list)
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleLessThan(TupleLookupMixin, LessThan):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) < (x, y, z) as SQL:
|
||||
# WHERE a < x OR (a = x AND (b < y OR (b = y AND c < z)))
|
||||
lookups = itertools.cycle([LessThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list[:-1])
|
||||
vals_iter = iter(vals_list[:-1])
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleLessThanOrEqual(TupleLookupMixin, LessThanOrEqual):
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
# Process right-hand-side to trigger sanitization.
|
||||
self.process_rhs(compiler, connection)
|
||||
# e.g.: (a, b, c) <= (x, y, z) as SQL:
|
||||
# WHERE a < x OR (a = x AND (b < y OR (b = y AND (c < z OR c = z))))
|
||||
lookups = itertools.cycle([LessThan, Exact])
|
||||
connectors = itertools.cycle([OR, AND])
|
||||
cols_list = [col for col in self.lhs for _ in range(2)]
|
||||
vals_list = [val for val in self.rhs for _ in range(2)]
|
||||
cols_iter = iter(cols_list)
|
||||
vals_iter = iter(vals_list)
|
||||
col = next(cols_iter)
|
||||
val = next(vals_iter)
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
root = node = WhereNode([lookup(col, val)], connector=connector)
|
||||
|
||||
for col, val in zip(cols_iter, vals_iter):
|
||||
lookup = next(lookups)
|
||||
connector = next(connectors)
|
||||
child = WhereNode([lookup(col, val)], connector=connector)
|
||||
node.children.append(child)
|
||||
node = child
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
class TupleIn(TupleLookupMixin, In):
|
||||
def get_prep_lookup(self):
|
||||
if self.rhs_is_direct_value():
|
||||
self.check_rhs_is_tuple_or_list()
|
||||
self.check_rhs_is_collection_of_tuples_or_lists()
|
||||
self.check_rhs_elements_length_equals_lhs_length()
|
||||
else:
|
||||
self.check_rhs_is_query()
|
||||
super(TupleLookupMixin, self).get_prep_lookup()
|
||||
|
||||
return self.rhs # skip checks from mixin
|
||||
|
||||
def check_rhs_is_collection_of_tuples_or_lists(self):
|
||||
if not all(isinstance(vals, (tuple, list)) for vals in self.rhs):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} "
|
||||
"must be a collection of tuples or lists"
|
||||
)
|
||||
|
||||
def check_rhs_elements_length_equals_lhs_length(self):
|
||||
len_lhs = len(self.lhs)
|
||||
if not all(len_lhs == len(vals) for vals in self.rhs):
|
||||
lhs_str = self.get_lhs_str()
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} lookup of {lhs_str} "
|
||||
f"must have {len_lhs} elements each"
|
||||
)
|
||||
|
||||
def check_rhs_is_query(self):
|
||||
if not isinstance(self.rhs, (Query, Subquery)):
|
||||
lhs_str = self.get_lhs_str()
|
||||
rhs_cls = self.rhs.__class__.__name__
|
||||
raise ValueError(
|
||||
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
|
||||
f"must be a Query object (received {rhs_cls!r})"
|
||||
)
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
if not self.rhs_is_direct_value():
|
||||
return super(TupleLookupMixin, self).process_rhs(compiler, connection)
|
||||
|
||||
rhs = self.rhs
|
||||
if not rhs:
|
||||
raise EmptyResultSet
|
||||
|
||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||
# WHERE (a, b, c) IN ((x1, y1, z1), (x2, y2, z2))
|
||||
result = []
|
||||
lhs = self.lhs
|
||||
|
||||
for vals in rhs:
|
||||
result.append(
|
||||
Tuple(
|
||||
*[
|
||||
Value(val, output_field=col.output_field)
|
||||
for col, val in zip(lhs, vals)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return compiler.compile(Tuple(*result))
|
||||
|
||||
def get_fallback_sql(self, compiler, connection):
|
||||
rhs = self.rhs
|
||||
if not rhs:
|
||||
raise EmptyResultSet
|
||||
if not self.rhs_is_direct_value():
|
||||
return super(TupleLookupMixin, self).as_sql(compiler, connection)
|
||||
|
||||
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
|
||||
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
|
||||
root = WhereNode([], connector=OR)
|
||||
lhs = self.lhs
|
||||
|
||||
for vals in rhs:
|
||||
lookups = [Exact(col, val) for col, val in zip(lhs, vals)]
|
||||
root.children.append(WhereNode(lookups, connector=AND))
|
||||
|
||||
return root.as_sql(compiler, connection)
|
||||
|
||||
|
||||
tuple_lookups = {
|
||||
"exact": TupleExact,
|
||||
"gt": TupleGreaterThan,
|
||||
"gte": TupleGreaterThanOrEqual,
|
||||
"lt": TupleLessThan,
|
||||
"lte": TupleLessThanOrEqual,
|
||||
"in": TupleIn,
|
||||
"isnull": TupleIsNull,
|
||||
}
|
||||
Reference in New Issue
Block a user