init commit

This commit is contained in:
2025-05-06 20:44:33 +09:00
commit 91f0d54563
5567 changed files with 948185 additions and 0 deletions

View File

@@ -0,0 +1,27 @@
r"""
______ _____ _____ _____ __
| ___ \ ___/ ___|_ _| / _| | |
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /
| |\ \| |___/\__/ / | | | | | | | (_| | | | | | | __/\ V V / (_) | | | <
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
"""
__title__ = 'Django REST framework'
__version__ = '3.16.0'
__author__ = 'Tom Christie'
__license__ = 'BSD 3-Clause'
__copyright__ = 'Copyright 2011-2023 Encode OSS Ltd'
# Version synonym
VERSION = __version__
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
# Default datetime input and output formats
ISO_8601 = 'iso-8601'
class RemovedInDRF317Warning(PendingDeprecationWarning):
pass

View File

@@ -0,0 +1,10 @@
from django.apps import AppConfig
class RestFrameworkConfig(AppConfig):
name = 'rest_framework'
verbose_name = "Django REST framework"
def ready(self):
# Add System checks
from .checks import pagination_system_check # NOQA

View File

@@ -0,0 +1,232 @@
"""
Provides various authentication policies.
"""
import base64
import binascii
from django.contrib.auth import authenticate, get_user_model
from django.middleware.csrf import CsrfViewMiddleware
from django.utils.translation import gettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, exceptions
def get_authorization_header(request):
"""
Return request's 'Authorization:' header, as a bytestring.
Hide some test client ickyness where the header can be unicode.
"""
auth = request.META.get('HTTP_AUTHORIZATION', b'')
if isinstance(auth, str):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
return auth
class CSRFCheck(CsrfViewMiddleware):
def _reject(self, request, reason):
# Return the failure reason instead of an HttpResponse
return reason
class BaseAuthentication:
"""
All authentication classes should extend BaseAuthentication.
"""
def authenticate(self, request):
"""
Authenticate the request and return a two-tuple of (user, token).
"""
raise NotImplementedError(".authenticate() must be overridden.")
def authenticate_header(self, request):
"""
Return a string to be used as the value of the `WWW-Authenticate`
header in a `401 Unauthenticated` response, or `None` if the
authentication scheme should return `403 Permission Denied` responses.
"""
pass
class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
www_authenticate_realm = 'api'
def authenticate(self, request):
"""
Returns a `User` if a correct username and password have been supplied
using HTTP Basic authentication. Otherwise returns `None`.
"""
auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'basic':
return None
if len(auth) == 1:
msg = _('Invalid basic header. No credentials provided.')
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _('Invalid basic header. Credentials string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg)
try:
try:
auth_decoded = base64.b64decode(auth[1]).decode('utf-8')
except UnicodeDecodeError:
auth_decoded = base64.b64decode(auth[1]).decode('latin-1')
userid, password = auth_decoded.split(':', 1)
except (TypeError, ValueError, UnicodeDecodeError, binascii.Error):
msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(userid, password, request)
def authenticate_credentials(self, userid, password, request=None):
"""
Authenticate the userid and password against username and password
with optional request for context.
"""
credentials = {
get_user_model().USERNAME_FIELD: userid,
'password': password
}
user = authenticate(request=request, **credentials)
if user is None:
raise exceptions.AuthenticationFailed(_('Invalid username/password.'))
if not user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
return (user, None)
def authenticate_header(self, request):
return 'Basic realm="%s"' % self.www_authenticate_realm
class SessionAuthentication(BaseAuthentication):
"""
Use Django's session framework for authentication.
"""
def authenticate(self, request):
"""
Returns a `User` if the request session currently has a logged in user.
Otherwise returns `None`.
"""
# Get the session-based user from the underlying HttpRequest object
user = getattr(request._request, 'user', None)
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
return None
self.enforce_csrf(request)
# CSRF passed with authenticated user
return (user, None)
def enforce_csrf(self, request):
"""
Enforce CSRF validation for session based authentication.
"""
def dummy_get_response(request): # pragma: no cover
return None
check = CSRFCheck(dummy_get_response)
# populates request.META['CSRF_COOKIE'], which is used in process_view()
check.process_request(request)
reason = check.process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
class TokenAuthentication(BaseAuthentication):
"""
Simple token based authentication.
Clients should authenticate by passing the token key in the "Authorization"
HTTP header, prepended with the string "Token ". For example:
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
"""
keyword = 'Token'
model = None
def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token
"""
A custom token model may be used, but must have the following properties.
* key -- The string identifying the token
* user -- The user to which the token belongs
"""
def authenticate(self, request):
auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != self.keyword.lower().encode():
return None
if len(auth) == 1:
msg = _('Invalid token header. No credentials provided.')
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _('Invalid token header. Token string should not contain spaces.')
raise exceptions.AuthenticationFailed(msg)
try:
token = auth[1].decode()
except UnicodeError:
msg = _('Invalid token header. Token string should not contain invalid characters.')
raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(token)
def authenticate_credentials(self, key):
model = self.get_model()
try:
token = model.objects.select_related('user').get(key=key)
except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.'))
if not token.user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
return (token.user, token)
def authenticate_header(self, request):
return self.keyword
class RemoteUserAuthentication(BaseAuthentication):
"""
REMOTE_USER authentication.
To use this, set up your web server to perform authentication, which will
set the REMOTE_USER environment variable. You will need to have
'django.contrib.auth.backends.RemoteUserBackend in your
AUTHENTICATION_BACKENDS setting
"""
# Name of request header to grab username from. This will be the key as
# used in the request.META dictionary, i.e. the normalization of headers to
# all uppercase and the addition of "HTTP_" prefix apply.
header = "REMOTE_USER"
def authenticate(self, request):
user = authenticate(request=request, remote_user=request.META.get(self.header))
if user and user.is_active:
return (user, None)

View File

@@ -0,0 +1,54 @@
from django.contrib import admin
from django.contrib.admin.utils import quote
from django.contrib.admin.views.main import ChangeList
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from rest_framework.authtoken.models import Token, TokenProxy
User = get_user_model()
class TokenChangeList(ChangeList):
"""Map to matching User id"""
def url_for_result(self, result):
pk = result.user.pk
return reverse('admin:%s_%s_change' % (self.opts.app_label,
self.opts.model_name),
args=(quote(pk),),
current_app=self.model_admin.admin_site.name)
class TokenAdmin(admin.ModelAdmin):
list_display = ('key', 'user', 'created')
fields = ('user',)
search_fields = ('user__username',)
search_help_text = _('Username')
ordering = ('-created',)
actions = None # Actions not compatible with mapped IDs.
def get_changelist(self, request, **kwargs):
return TokenChangeList
def get_object(self, request, object_id, from_field=None):
"""
Map from User ID to matching Token.
"""
queryset = self.get_queryset(request)
field = User._meta.pk
try:
object_id = field.to_python(object_id)
user = User.objects.get(**{field.name: object_id})
return queryset.get(user=user)
except (queryset.model.DoesNotExist, User.DoesNotExist, ValidationError, ValueError):
return None
def delete_model(self, request, obj):
# Map back to actual Token, since delete() uses pk.
token = Token.objects.get(key=obj.key)
return super().delete_model(request, token)
admin.site.register(TokenProxy, TokenAdmin)

View File

@@ -0,0 +1,7 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
class AuthTokenConfig(AppConfig):
name = 'rest_framework.authtoken'
verbose_name = _("Auth Token")

View File

@@ -0,0 +1,45 @@
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError
from rest_framework.authtoken.models import Token
UserModel = get_user_model()
class Command(BaseCommand):
help = 'Create DRF Token for a given user'
def create_user_token(self, username, reset_token):
user = UserModel._default_manager.get_by_natural_key(username)
if reset_token:
Token.objects.filter(user=user).delete()
token = Token.objects.get_or_create(user=user)
return token[0]
def add_arguments(self, parser):
parser.add_argument('username', type=str)
parser.add_argument(
'-r',
'--reset',
action='store_true',
dest='reset_token',
default=False,
help='Reset existing User token and create a new one',
)
def handle(self, *args, **options):
username = options['username']
reset_token = options['reset_token']
try:
token = self.create_user_token(username, reset_token)
except UserModel.DoesNotExist:
raise CommandError(
'Cannot create the Token: user {} does not exist'.format(
username)
)
self.stdout.write(
'Generated token {} for user {}'.format(token.key, username))

View File

@@ -0,0 +1,23 @@
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.CreateModel(
name='Token',
fields=[
('key', models.CharField(primary_key=True, serialize=False, max_length=40)),
('created', models.DateTimeField(auto_now_add=True)),
('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token', on_delete=models.CASCADE)),
],
options={
},
bases=(models.Model,),
),
]

View File

@@ -0,0 +1,31 @@
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authtoken', '0001_initial'),
]
operations = [
migrations.AlterModelOptions(
name='token',
options={'verbose_name_plural': 'Tokens', 'verbose_name': 'Token'},
),
migrations.AlterField(
model_name='token',
name='created',
field=models.DateTimeField(verbose_name='Created', auto_now_add=True),
),
migrations.AlterField(
model_name='token',
name='key',
field=models.CharField(verbose_name='Key', max_length=40, primary_key=True, serialize=False),
),
migrations.AlterField(
model_name='token',
name='user',
field=models.OneToOneField(to=settings.AUTH_USER_MODEL, verbose_name='User', related_name='auth_token', on_delete=models.CASCADE),
),
]

View File

@@ -0,0 +1,25 @@
# Generated by Django 3.1.1 on 2020-09-28 09:34
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('authtoken', '0002_auto_20160226_1747'),
]
operations = [
migrations.CreateModel(
name='TokenProxy',
fields=[
],
options={
'verbose_name': 'token',
'proxy': True,
'indexes': [],
'constraints': [],
},
bases=('authtoken.token',),
),
]

View File

@@ -0,0 +1,17 @@
# Generated by Django 4.1.3 on 2022-11-24 21:07
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('authtoken', '0003_tokenproxy'),
]
operations = [
migrations.AlterModelOptions(
name='tokenproxy',
options={'verbose_name': 'Token', 'verbose_name_plural': 'Tokens'},
),
]

View File

@@ -0,0 +1,55 @@
import binascii
import os
from django.conf import settings
from django.db import models
from django.utils.translation import gettext_lazy as _
class Token(models.Model):
"""
The default authorization token model.
"""
key = models.CharField(_("Key"), max_length=40, primary_key=True)
user = models.OneToOneField(
settings.AUTH_USER_MODEL, related_name='auth_token',
on_delete=models.CASCADE, verbose_name=_("User")
)
created = models.DateTimeField(_("Created"), auto_now_add=True)
class Meta:
# Work around for a bug in Django:
# https://code.djangoproject.com/ticket/19422
#
# Also see corresponding ticket:
# https://github.com/encode/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
verbose_name = _("Token")
verbose_name_plural = _("Tokens")
def save(self, *args, **kwargs):
if not self.key:
self.key = self.generate_key()
return super().save(*args, **kwargs)
@classmethod
def generate_key(cls):
return binascii.hexlify(os.urandom(20)).decode()
def __str__(self):
return self.key
class TokenProxy(Token):
"""
Proxy mapping pk to user pk for use in admin.
"""
@property
def pk(self):
return self.user_id
class Meta:
proxy = 'rest_framework.authtoken' in settings.INSTALLED_APPS
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
verbose_name = _("Token")
verbose_name_plural = _("Tokens")

View File

@@ -0,0 +1,42 @@
from django.contrib.auth import authenticate
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField(
label=_("Username"),
write_only=True
)
password = serializers.CharField(
label=_("Password"),
style={'input_type': 'password'},
trim_whitespace=False,
write_only=True
)
token = serializers.CharField(
label=_("Token"),
read_only=True
)
def validate(self, attrs):
username = attrs.get('username')
password = attrs.get('password')
if username and password:
user = authenticate(request=self.context.get('request'),
username=username, password=password)
# The authenticate call simply returns None for is_active=False
# users. (Assuming the default ModelBackend authentication
# backend.)
if not user:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg, code='authorization')
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg, code='authorization')
attrs['user'] = user
return attrs

View File

@@ -0,0 +1,62 @@
from rest_framework import parsers, renderers
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.compat import coreapi, coreschema
from rest_framework.response import Response
from rest_framework.schemas import ManualSchema
from rest_framework.schemas import coreapi as coreapi_schema
from rest_framework.views import APIView
class ObtainAuthToken(APIView):
throttle_classes = ()
permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer
if coreapi_schema.is_enabled():
schema = ManualSchema(
fields=[
coreapi.Field(
name="username",
required=True,
location='form',
schema=coreschema.String(
title="Username",
description="Valid username for authentication",
),
),
coreapi.Field(
name="password",
required=True,
location='form',
schema=coreschema.String(
title="Password",
description="Valid password for authentication",
),
),
],
encoding="application/json",
)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})
obtain_auth_token = ObtainAuthToken.as_view()

View File

@@ -0,0 +1,21 @@
from django.core.checks import Tags, Warning, register
@register(Tags.compatibility)
def pagination_system_check(app_configs, **kwargs):
errors = []
# Use of default page size setting requires a default Paginator class
from rest_framework.settings import api_settings
if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS:
errors.append(
Warning(
"You have specified a default PAGE_SIZE pagination rest_framework setting, "
"without specifying also a DEFAULT_PAGINATION_CLASS.",
hint="The default for DEFAULT_PAGINATION_CLASS is None. "
"In previous versions this was PageNumberPagination. "
"If you wish to define PAGE_SIZE globally whilst defining "
"pagination_class on a per-view basis you may silence this check.",
id="rest_framework.W001"
)
)
return errors

View File

@@ -0,0 +1,209 @@
"""
The `compat` module provides support for backwards compatibility with older
versions of Django/Python, and compatibility wrappers around optional packages.
"""
import django
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.sql.query import Node
from django.views.generic import View
def unicode_http_header(value):
# Coerce HTTP header value to unicode.
if isinstance(value, bytes):
return value.decode('iso-8859-1')
return value
# django.contrib.postgres requires psycopg2
try:
from django.contrib.postgres import fields as postgres_fields
except ImportError:
postgres_fields = None
# coreapi is required for CoreAPI schema generation
try:
import coreapi
except ImportError:
coreapi = None
# uritemplate is required for OpenAPI and CoreAPI schema generation
try:
import uritemplate
except ImportError:
uritemplate = None
# coreschema is optional
try:
import coreschema
except ImportError:
coreschema = None
# pyyaml is optional
try:
import yaml
except ImportError:
yaml = None
# inflection is optional
try:
import inflection
except ImportError:
inflection = None
# requests is optional
try:
import requests
except ImportError:
requests = None
# PATCH method is not implemented by Django
if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
# Markdown is optional (version 3.0+ required)
try:
import markdown
HEADERID_EXT_PATH = 'markdown.extensions.toc'
LEVEL_PARAM = 'baselevel'
def apply_markdown(text):
"""
Simple wrapper around :func:`markdown.markdown` to set the base level
of '#' style headers to <h2>.
"""
extensions = [HEADERID_EXT_PATH]
extension_configs = {
HEADERID_EXT_PATH: {
LEVEL_PARAM: '2'
}
}
md = markdown.Markdown(
extensions=extensions, extension_configs=extension_configs
)
md_filter_add_syntax_highlight(md)
return md.convert(text)
except ImportError:
apply_markdown = None
markdown = None
try:
import pygments
from pygments.formatters import HtmlFormatter
from pygments.lexers import TextLexer, get_lexer_by_name
def pygments_highlight(text, lang, style):
lexer = get_lexer_by_name(lang, stripall=False)
formatter = HtmlFormatter(nowrap=True, style=style)
return pygments.highlight(text, lexer, formatter)
def pygments_css(style):
formatter = HtmlFormatter(style=style)
return formatter.get_style_defs('.highlight')
except ImportError:
pygments = None
def pygments_highlight(text, lang, style):
return text
def pygments_css(style):
return None
if markdown is not None and pygments is not None:
# starting from this blogpost and modified to support current markdown extensions API
# https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/
import re
from markdown.preprocessors import Preprocessor
class CodeBlockPreprocessor(Preprocessor):
pattern = re.compile(
r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S)
formatter = HtmlFormatter()
def run(self, lines):
def repl(m):
try:
lexer = get_lexer_by_name(m.group(1))
except (ValueError, NameError):
lexer = TextLexer()
code = m.group(2).replace('\t', ' ')
code = pygments.highlight(code, lexer, self.formatter)
code = code.replace('\n\n', '\n&nbsp;\n').replace('\n', '<br />').replace('\\@', '@')
return '\n\n%s\n\n' % code
ret = self.pattern.sub(repl, "\n".join(lines))
return ret.split("\n")
def md_filter_add_syntax_highlight(md):
md.preprocessors.register(CodeBlockPreprocessor(), 'highlight', 40)
return True
else:
def md_filter_add_syntax_highlight(md):
return False
if django.VERSION >= (5, 1):
# Django 5.1+: use the stock ip_address_validators function
# Note: Before Django 5.1, ip_address_validators returns a tuple containing
# 1) the list of validators and 2) the error message. Starting from
# Django 5.1 ip_address_validators only returns the list of validators
from django.core.validators import ip_address_validators
def get_referenced_base_fields_from_q(q):
return q.referenced_base_fields
else:
# Django <= 5.1: create a compatibility shim for ip_address_validators
from django.core.validators import \
ip_address_validators as _ip_address_validators
def ip_address_validators(protocol, unpack_ipv4):
return _ip_address_validators(protocol, unpack_ipv4)[0]
# Django < 5.1: create a compatibility shim for Q.referenced_base_fields
# https://github.com/django/django/blob/5.1a1/django/db/models/query_utils.py#L179
def _get_paths_from_expression(expr):
if isinstance(expr, models.F):
yield expr.name
elif hasattr(expr, 'flatten'):
for child in expr.flatten():
if isinstance(child, models.F):
yield child.name
elif isinstance(child, models.Q):
yield from _get_children_from_q(child)
def _get_children_from_q(q):
for child in q.children:
if isinstance(child, Node):
yield from _get_children_from_q(child)
elif isinstance(child, tuple):
lhs, rhs = child
yield lhs
if hasattr(rhs, 'resolve_expression'):
yield from _get_paths_from_expression(rhs)
elif hasattr(child, 'resolve_expression'):
yield from _get_paths_from_expression(child)
def get_referenced_base_fields_from_q(q):
return {
child.split(LOOKUP_SEP, 1)[0] for child in _get_children_from_q(q)
}
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767
SHORT_SEPARATORS = (',', ':')
LONG_SEPARATORS = (', ', ': ')
INDENT_SEPARATORS = (',', ': ')

View File

@@ -0,0 +1,233 @@
"""
The most important decorator in this module is `@api_view`, which is used
for writing function-based views with REST framework.
There are also various decorators for setting the API policies on function
based views, as well as the `@action` decorator, which is used to annotate
methods on viewsets that should be included by routers.
"""
import types
from django.forms.utils import pretty_name
from rest_framework.views import APIView
def api_view(http_method_names=None):
"""
Decorator that converts a function-based view into an APIView subclass.
Takes a list of allowed methods for the view as an argument.
"""
http_method_names = ['GET'] if (http_method_names is None) else http_method_names
def decorator(func):
WrappedAPIView = type(
'WrappedAPIView',
(APIView,),
{'__doc__': func.__doc__}
)
# Note, the above allows us to set the docstring.
# It is the equivalent of:
#
# class WrappedAPIView(APIView):
# pass
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
# api_view applied without (method_names)
assert not isinstance(http_method_names, types.FunctionType), \
'@api_view missing list of allowed HTTP methods'
# api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \
'@api_view expected a list of strings, received %s' % type(http_method_names).__name__
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)
WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
APIView.renderer_classes)
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
APIView.authentication_classes)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
APIView.throttle_classes)
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
APIView.permission_classes)
WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema)
return WrappedAPIView.as_view()
return decorator
def renderer_classes(renderer_classes):
def decorator(func):
func.renderer_classes = renderer_classes
return func
return decorator
def parser_classes(parser_classes):
def decorator(func):
func.parser_classes = parser_classes
return func
return decorator
def authentication_classes(authentication_classes):
def decorator(func):
func.authentication_classes = authentication_classes
return func
return decorator
def throttle_classes(throttle_classes):
def decorator(func):
func.throttle_classes = throttle_classes
return func
return decorator
def permission_classes(permission_classes):
def decorator(func):
func.permission_classes = permission_classes
return func
return decorator
def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
return decorator
def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
"""
Mark a ViewSet method as a routable action.
`@action`-decorated functions will be endowed with a `mapping` property,
a `MethodMapper` that can be used to add additional method-based behaviors
on the routed action.
:param methods: A list of HTTP method names this action responds to.
Defaults to GET only.
:param detail: Required. Determines whether this action applies to
instance/detail requests or collection/list requests.
:param url_path: Define the URL segment for this action. Defaults to the
name of the method decorated.
:param url_name: Define the internal (`reverse`) URL name for this action.
Defaults to the name of the method decorated with underscores
replaced with dashes.
:param kwargs: Additional properties to set on the view. This can be used
to override viewset-level *_classes settings, equivalent to
how the `@renderer_classes` etc. decorators work for function-
based API views.
"""
methods = ['get'] if methods is None else methods
methods = [method.lower() for method in methods]
assert detail is not None, (
"@action() missing required argument: 'detail'"
)
# name and suffix are mutually exclusive
if 'name' in kwargs and 'suffix' in kwargs:
raise TypeError("`name` and `suffix` are mutually exclusive arguments.")
def decorator(func):
func.mapping = MethodMapper(func, methods)
func.detail = detail
func.url_path = url_path if url_path else func.__name__
func.url_name = url_name if url_name else func.__name__.replace('_', '-')
# These kwargs will end up being passed to `ViewSet.as_view()` within
# the router, which eventually delegates to Django's CBV `View`,
# which assigns them as instance attributes for each request.
func.kwargs = kwargs
# Set descriptive arguments for viewsets
if 'name' not in kwargs and 'suffix' not in kwargs:
func.kwargs['name'] = pretty_name(func.__name__)
func.kwargs['description'] = func.__doc__ or None
return func
return decorator
class MethodMapper(dict):
"""
Enables mapping HTTP methods to different ViewSet methods for a single,
logical action.
Example usage:
class MyViewSet(ViewSet):
@action(detail=False)
def example(self, request, **kwargs):
...
@example.mapping.post
def create_example(self, request, **kwargs):
...
"""
def __init__(self, action, methods):
self.action = action
for method in methods:
self[method] = self.action.__name__
def _map(self, method, func):
assert method not in self, (
"Method '%s' has already been mapped to '.%s'." % (method, self[method]))
assert func.__name__ != self.action.__name__, (
"Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
self[method] = func.__name__
return func
def get(self, func):
return self._map('get', func)
def post(self, func):
return self._map('post', func)
def put(self, func):
return self._map('put', func)
def patch(self, func):
return self._map('patch', func)
def delete(self, func):
return self._map('delete', func)
def head(self, func):
return self._map('head', func)
def options(self, func):
return self._map('options', func)
def trace(self, func):
return self._map('trace', func)

View File

@@ -0,0 +1,88 @@
from django.urls import include, path
from rest_framework.renderers import (
CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer
)
from rest_framework.schemas import SchemaGenerator, get_schema_view
from rest_framework.settings import api_settings
def get_docs_view(
title=None, description=None, schema_url=None, urlconf=None,
public=True, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None):
if renderer_classes is None:
renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
return get_schema_view(
title=title,
url=schema_url,
urlconf=urlconf,
description=description,
renderer_classes=renderer_classes,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)
def get_schemajs_view(
title=None, description=None, schema_url=None, urlconf=None,
public=True, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
renderer_classes = [SchemaJSRenderer]
return get_schema_view(
title=title,
url=schema_url,
urlconf=urlconf,
description=description,
renderer_classes=renderer_classes,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)
def include_docs_urls(
title=None, description=None, schema_url=None, urlconf=None,
public=True, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None):
docs_view = get_docs_view(
title=title,
description=description,
schema_url=schema_url,
urlconf=urlconf,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
renderer_classes=renderer_classes,
permission_classes=permission_classes,
)
schema_js_view = get_schemajs_view(
title=title,
description=description,
schema_url=schema_url,
urlconf=urlconf,
public=public,
patterns=patterns,
generator_class=generator_class,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)
urls = [
path('', docs_view, name='docs-index'),
path('schema.js', schema_js_view, name='schema-js')
]
return include((urls, 'api-docs'), namespace='api-docs')

View File

@@ -0,0 +1,264 @@
"""
Handled exceptions raised by REST framework.
In addition, Django's built in 403 and 404 exceptions are handled.
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
"""
import math
from django.http import JsonResponse
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from django.utils.translation import ngettext
from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList
def _get_error_details(data, default_code=None):
"""
Descend into a nested data structure, forcing any
lazy translation strings or strings into `ErrorDetail`.
"""
if isinstance(data, (list, tuple)):
ret = [
_get_error_details(item, default_code) for item in data
]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
key: _get_error_details(value, default_code)
for key, value in data.items()
}
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret
text = force_str(data)
code = getattr(data, 'code', default_code)
return ErrorDetail(text, code)
def _get_codes(detail):
if isinstance(detail, list):
return [_get_codes(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_codes(value) for key, value in detail.items()}
return detail.code
def _get_full_details(detail):
if isinstance(detail, list):
return [_get_full_details(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_full_details(value) for key, value in detail.items()}
return {
'message': detail,
'code': detail.code
}
class ErrorDetail(str):
"""
A string-like object that can additionally have a code.
"""
code = None
def __new__(cls, string, code=None):
self = super().__new__(cls, string)
self.code = code
return self
def __eq__(self, other):
result = super().__eq__(other)
if result is NotImplemented:
return NotImplemented
try:
return result and self.code == other.code
except AttributeError:
return result
def __ne__(self, other):
result = self.__eq__(other)
if result is NotImplemented:
return NotImplemented
return not result
def __repr__(self):
return 'ErrorDetail(string=%r, code=%r)' % (
str(self),
self.code,
)
def __hash__(self):
return hash(str(self))
class APIException(Exception):
"""
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = _('A server error occurred.')
default_code = 'error'
def __init__(self, detail=None, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
self.detail = _get_error_details(detail, code)
def __str__(self):
return str(self.detail)
def get_codes(self):
"""
Return only the code part of the error details.
Eg. {"name": ["required"]}
"""
return _get_codes(self.detail)
def get_full_details(self):
"""
Return both the message & code parts of the error details.
Eg. {"name": [{"message": "This field is required.", "code": "required"}]}
"""
return _get_full_details(self.detail)
# The recommended style for using `ValidationError` is to keep it namespaced
# under `serializers`, in order to minimize potential confusion with Django's
# built in `ValidationError`. For example:
#
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Invalid input.')
default_code = 'invalid'
def __init__(self, detail=None, code=None):
if detail is None:
detail = self.default_detail
if code is None:
code = self.default_code
# For validation failures, we may collect many errors together,
# so the details should always be coerced to a list if not already.
if isinstance(detail, tuple):
detail = list(detail)
elif not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = _get_error_details(detail, code)
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Malformed request.')
default_code = 'parse_error'
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Incorrect authentication credentials.')
default_code = 'authentication_failed'
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Authentication credentials were not provided.')
default_code = 'not_authenticated'
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = _('You do not have permission to perform this action.')
default_code = 'permission_denied'
class NotFound(APIException):
status_code = status.HTTP_404_NOT_FOUND
default_detail = _('Not found.')
default_code = 'not_found'
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _('Method "{method}" not allowed.')
default_code = 'method_not_allowed'
def __init__(self, method, detail=None, code=None):
if detail is None:
detail = force_str(self.default_detail).format(method=method)
super().__init__(detail, code)
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = _('Could not satisfy the request Accept header.')
default_code = 'not_acceptable'
def __init__(self, detail=None, code=None, available_renderers=None):
self.available_renderers = available_renderers
super().__init__(detail, code)
class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
default_detail = _('Unsupported media type "{media_type}" in request.')
default_code = 'unsupported_media_type'
def __init__(self, media_type, detail=None, code=None):
if detail is None:
detail = force_str(self.default_detail).format(media_type=media_type)
super().__init__(detail, code)
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = _('Request was throttled.')
extra_detail_singular = _('Expected available in {wait} second.')
extra_detail_plural = _('Expected available in {wait} seconds.')
default_code = 'throttled'
def __init__(self, wait=None, detail=None, code=None):
if detail is None:
detail = force_str(self.default_detail)
if wait is not None:
wait = math.ceil(wait)
detail = ' '.join((
detail,
force_str(ngettext(self.extra_detail_singular.format(wait=wait),
self.extra_detail_plural.format(wait=wait),
wait))))
self.wait = wait
super().__init__(detail, code)
def server_error(request, *args, **kwargs):
"""
Generic 500 error handler.
"""
data = {
'error': 'Server Error (500)'
}
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def bad_request(request, exception, *args, **kwargs):
"""
Generic 400 error handler.
"""
data = {
'error': 'Bad Request (400)'
}
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,379 @@
"""
Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
import operator
import warnings
from functools import reduce
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.template import loader
from django.utils.encoding import force_str
from django.utils.text import smart_split, unescape_string_literal
from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema
from rest_framework.fields import CharField
from rest_framework.settings import api_settings
def search_smart_split(search_terms):
"""Returns sanitized search terms as a list."""
split_terms = []
for term in smart_split(search_terms):
# trim commas to avoid bad matching for quoted phrases
term = term.strip(',')
if term.startswith(('"', "'")) and term[0] == term[-1]:
# quoted phrases are kept together without any other split
split_terms.append(unescape_string_literal(term))
else:
# non-quoted tokens are split by comma, keeping only non-empty ones
for sub_term in term.split(','):
if sub_term:
split_terms.append(sub_term.strip())
return split_terms
class BaseFilterBackend:
"""
A base class from which all filter backend classes should inherit.
"""
def filter_queryset(self, request, queryset, view):
"""
Return a filtered queryset.
"""
raise NotImplementedError(".filter_queryset() must be overridden.")
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return []
def get_schema_operation_parameters(self, view):
return []
class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM
template = 'rest_framework/filters/search.html'
lookup_prefixes = {
'^': 'istartswith',
'=': 'iexact',
'@': 'search',
'$': 'iregex',
}
search_title = _('Search')
search_description = _('A search term.')
def get_search_fields(self, view, request):
"""
Search fields are obtained from the view, but the request is always
passed to this method. Sub-classes can override this method to
dynamically change the search fields based on request content.
"""
return getattr(view, 'search_fields', None)
def get_search_terms(self, request):
"""
Search terms are set by a ?search=... query parameter,
and may be whitespace delimited.
"""
value = request.query_params.get(self.search_param, '')
field = CharField(trim_whitespace=False, allow_blank=True)
cleaned_value = field.run_validation(value)
return search_smart_split(cleaned_value)
def construct_search(self, field_name, queryset):
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
# Use field_name if it includes a lookup.
opts = queryset.model._meta
lookup_fields = field_name.split(LOOKUP_SEP)
# Go through the fields, following all relations.
prev_field = None
for path_part in lookup_fields:
if path_part == "pk":
path_part = opts.pk.name
try:
field = opts.get_field(path_part)
except FieldDoesNotExist:
# Use valid query lookups.
if prev_field and prev_field.get_lookup(path_part):
return field_name
else:
prev_field = field
if hasattr(field, "path_infos"):
# Update opts to follow the relation.
opts = field.path_infos[-1].to_opts
# Otherwise, use the field with icontains.
lookup = 'icontains'
return LOOKUP_SEP.join([field_name, lookup])
def must_call_distinct(self, queryset, search_fields):
"""
Return True if 'distinct()' should be used to query the given lookups.
"""
for search_field in search_fields:
opts = queryset.model._meta
if search_field[0] in self.lookup_prefixes:
search_field = search_field[1:]
# Annotated fields do not need to be distinct
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
continue
parts = search_field.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
if hasattr(field, 'get_path_info'):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
if any(path.m2m for path in path_info):
# This field is a m2m relation so we know we need to call distinct
return True
else:
# This field has a custom __ query transform but is not a relational field.
break
return False
def filter_queryset(self, request, queryset, view):
search_fields = self.get_search_fields(view, request)
search_terms = self.get_search_terms(request)
if not search_fields or not search_terms:
return queryset
orm_lookups = [
self.construct_search(str(search_field), queryset)
for search_field in search_fields
]
base = queryset
# generator which for each term builds the corresponding search
conditions = (
reduce(
operator.or_,
(models.Q(**{orm_lookup: term}) for orm_lookup in orm_lookups)
) for term in search_terms
)
queryset = queryset.filter(reduce(operator.and_, conditions))
# Remove duplicates from results, if necessary
if self.must_call_distinct(queryset, search_fields):
# inspired by django.contrib.admin
# this is more accurate than .distinct form M2M relationship
# also is cross-database
queryset = queryset.filter(pk=models.OuterRef('pk'))
queryset = base.filter(models.Exists(queryset))
return queryset
def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None):
return ''
context = {
'param': self.search_param,
'term': request.query_params.get(self.search_param, ''),
}
template = loader.get_template(self.template)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.search_param,
required=False,
location='query',
schema=coreschema.String(
title=force_str(self.search_title),
description=force_str(self.search_description)
)
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.search_param,
'required': False,
'in': 'query',
'description': force_str(self.search_description),
'schema': {
'type': 'string',
},
},
]
class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
ordering_title = _('Ordering')
ordering_description = _('Which field to use when ordering the results.')
template = 'rest_framework/filters/ordering.html'
def get_ordering(self, request, queryset, view):
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
The `ordering` query parameter can be overridden by setting
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
if isinstance(ordering, str):
return (ordering,)
return ordering
def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'):
try:
serializer_class = view.get_serializer_class()
except AssertionError:
# Raised by the default implementation if
# no serializer_class was found
serializer_class = None
else:
serializer_class = getattr(view, 'serializer_class', None)
if serializer_class is None:
msg = (
"Cannot use %s on a view which does not have either a "
"'serializer_class', an overriding 'get_serializer_class' "
"or 'ordering_fields' attribute."
)
raise ImproperlyConfigured(msg % self.__class__.__name__)
model_class = queryset.model
model_property_names = [
# 'pk' is a property added in Django's Model class, however it is valid for ordering.
attr for attr in dir(model_class) if isinstance(getattr(model_class, attr), property) and attr != 'pk'
]
return [
(field.source.replace('.', '__') or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items()
if (
not getattr(field, 'write_only', False) and
not field.source == '*' and
field.source not in model_property_names
)
]
def get_valid_fields(self, queryset, view, context={}):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None:
# Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view, context)
elif valid_fields == '__all__':
# View explicitly allows filtering on any model field
valid_fields = [
(field.name, field.verbose_name) for field in queryset.model._meta.fields
]
valid_fields += [
(key, key.title().split('__'))
for key in queryset.query.annotations
]
else:
valid_fields = [
(item, item) if isinstance(item, str) else item
for item in valid_fields
]
return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
def term_valid(term):
if term.startswith("-"):
term = term[1:]
return term in valid_fields
return [term for term in fields if term_valid(term)]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)
if ordering:
return queryset.order_by(*ordering)
return queryset
def get_template_context(self, request, queryset, view):
current = self.get_ordering(request, queryset, view)
current = None if not current else current[0]
options = []
context = {
'request': request,
'current': current,
'param': self.ordering_param,
}
for key, label in self.get_valid_fields(queryset, view, context):
options.append((key, '%s - %s' % (label, _('ascending'))))
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
context['options'] = options
return context
def to_html(self, request, queryset, view):
template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.ordering_param,
required=False,
location='query',
schema=coreschema.String(
title=force_str(self.ordering_title),
description=force_str(self.ordering_description)
)
)
]
def get_schema_operation_parameters(self, view):
return [
{
'name': self.ordering_param,
'required': False,
'in': 'query',
'description': force_str(self.ordering_description),
'schema': {
'type': 'string',
},
},
]

View File

@@ -0,0 +1,295 @@
"""
Generic views that provide commonly needed behaviour.
"""
from django.core.exceptions import ValidationError
from django.db.models.query import QuerySet
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import mixins, views
from rest_framework.settings import api_settings
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
"""
Same as Django's standard shortcut, but make sure to also raise 404
if the filter_kwargs don't match the required types.
"""
try:
return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
except (TypeError, ValueError, ValidationError):
raise Http404
class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call
# `get_queryset()` instead of accessing the `queryset` property directly,
# as `queryset` will get evaluated only once, and those results are cached
# for all subsequent requests.
queryset = None
serializer_class = None
# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
lookup_url_kwarg = None
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
# The style to use for queryset pagination.
pagination_class = api_settings.DEFAULT_PAGINATION_CLASS
# Allow generic typing checking for generic views.
def __class_getitem__(cls, *args, **kwargs):
return cls
def get_queryset(self):
"""
Get the list of items for this view.
This must be an iterable, and may be a queryset.
Defaults to using `self.queryset`.
This method should always be used rather than accessing `self.queryset`
directly, as `self.queryset` gets evaluated only once, and those results
are cached for all subsequent requests.
You may want to override this if you need to provide different
querysets depending on the incoming request.
(Eg. return a list of items that is specific to the user)
"""
assert self.queryset is not None, (
"'%s' should either include a `queryset` attribute, "
"or override the `get_queryset()` method."
% self.__class__.__name__
)
queryset = self.queryset
if isinstance(queryset, QuerySet):
# Ensure queryset is re-evaluated on each request.
queryset = queryset.all()
return queryset
def get_object(self):
"""
Returns the object the view is displaying.
You may want to override this if you need to provide non-standard
queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf.
"""
queryset = self.filter_queryset(self.get_queryset())
# Perform the lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
assert lookup_url_kwarg in self.kwargs, (
'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
(self.__class__.__name__, lookup_url_kwarg)
)
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied
self.check_object_permissions(self.request, obj)
return obj
def get_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)
def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method."
% self.__class__.__name__
)
return self.serializer_class
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
"""
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def filter_queryset(self, queryset):
"""
Given a queryset, filter it with whichever filter backend is in use.
You are unlikely to want to override this method, although you may need
to call it either from a list view, or from a custom `get_object`
method if you want to apply the configured filtering backend to the
default queryset.
"""
for backend in list(self.filter_backends):
queryset = backend().filter_queryset(self.request, queryset, self)
return queryset
@property
def paginator(self):
"""
The paginator instance associated with the view, or `None`.
"""
if not hasattr(self, '_paginator'):
if self.pagination_class is None:
self._paginator = None
else:
self._paginator = self.pagination_class()
return self._paginator
def paginate_queryset(self, queryset):
"""
Return a single page of results, or `None` if pagination is disabled.
"""
if self.paginator is None:
return None
return self.paginator.paginate_queryset(queryset, self.request, view=self)
def get_paginated_response(self, data):
"""
Return a paginated style `Response` object for the given output data.
"""
assert self.paginator is not None
return self.paginator.get_paginated_response(data)
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
"""
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
class ListAPIView(mixins.ListModelMixin,
GenericAPIView):
"""
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
class RetrieveAPIView(mixins.RetrieveModelMixin,
GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
class DestroyAPIView(mixins.DestroyModelMixin,
GenericAPIView):
"""
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
class UpdateAPIView(mixins.UpdateModelMixin,
GenericAPIView):
"""
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
GenericAPIView):
"""
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
GenericAPIView):
"""
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
GenericAPIView):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)

View File

@@ -0,0 +1,71 @@
from django.core.management.base import BaseCommand
from django.utils.module_loading import import_string
from rest_framework import renderers
from rest_framework.schemas import coreapi
from rest_framework.schemas.openapi import SchemaGenerator
OPENAPI_MODE = 'openapi'
COREAPI_MODE = 'coreapi'
class Command(BaseCommand):
help = "Generates configured API schema for project."
def get_mode(self):
return COREAPI_MODE if coreapi.is_enabled() else OPENAPI_MODE
def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default='', type=str)
parser.add_argument('--url', dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str)
if self.get_mode() == COREAPI_MODE:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
else:
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json'], default='openapi', type=str)
parser.add_argument('--urlconf', dest="urlconf", default=None, type=str)
parser.add_argument('--generator_class', dest="generator_class", default=None, type=str)
parser.add_argument('--file', dest="file", default=None, type=str)
parser.add_argument('--api_version', dest="api_version", default='', type=str)
def handle(self, *args, **options):
if options['generator_class']:
generator_class = import_string(options['generator_class'])
else:
generator_class = self.get_generator_class()
generator = generator_class(
url=options['url'],
title=options['title'],
description=options['description'],
urlconf=options['urlconf'],
version=options['api_version'],
)
schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format'])
output = renderer.render(schema, renderer_context={})
if options['file']:
with open(options['file'], 'wb') as f:
f.write(output)
else:
self.stdout.write(output.decode())
def get_renderer(self, format):
if self.get_mode() == COREAPI_MODE:
renderer_cls = {
'corejson': renderers.CoreJSONRenderer,
'openapi': renderers.CoreAPIOpenAPIRenderer,
'openapi-json': renderers.CoreAPIJSONOpenAPIRenderer,
}[format]
return renderer_cls()
renderer_cls = {
'openapi': renderers.OpenAPIRenderer,
'openapi-json': renderers.JSONOpenAPIRenderer,
}[format]
return renderer_cls()
def get_generator_class(self):
if self.get_mode() == COREAPI_MODE:
return coreapi.SchemaGenerator
return SchemaGenerator

View File

@@ -0,0 +1,152 @@
"""
The metadata API is used to allow customization of how `OPTIONS` requests
are handled. We currently provide a single default implementation that returns
some fairly ad-hoc information about the view.
Future implementations might use JSON schema or other definitions in order
to return this information in a more standardized way.
"""
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework.request import clone_request
from rest_framework.utils.field_mapping import ClassLookupDict
class BaseMetadata:
def determine_metadata(self, request, view):
"""
Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests.
"""
raise NotImplementedError(".determine_metadata() must be overridden.")
class SimpleMetadata(BaseMetadata):
"""
This is the default metadata implementation.
It returns an ad-hoc set of information about the view.
There are not any formalized standards for `OPTIONS` responses
for us to base this on.
"""
label_lookup = ClassLookupDict({
serializers.Field: 'field',
serializers.BooleanField: 'boolean',
serializers.CharField: 'string',
serializers.UUIDField: 'string',
serializers.URLField: 'url',
serializers.EmailField: 'email',
serializers.RegexField: 'regex',
serializers.SlugField: 'slug',
serializers.IntegerField: 'integer',
serializers.FloatField: 'float',
serializers.DecimalField: 'decimal',
serializers.DateField: 'date',
serializers.DateTimeField: 'datetime',
serializers.TimeField: 'time',
serializers.DurationField: 'duration',
serializers.ChoiceField: 'choice',
serializers.MultipleChoiceField: 'multiple choice',
serializers.FileField: 'file upload',
serializers.ImageField: 'image upload',
serializers.ListField: 'list',
serializers.DictField: 'nested object',
serializers.Serializer: 'nested object',
})
def determine_metadata(self, request, view):
metadata = {
"name": view.get_view_name(),
"description": view.get_view_description(),
"renders": [renderer.media_type for renderer in view.renderer_classes],
"parses": [parser.media_type for parser in view.parser_classes],
}
if hasattr(view, 'get_serializer'):
actions = self.determine_actions(request, view)
if actions:
metadata['actions'] = actions
return metadata
def determine_actions(self, request, view):
"""
For generic class based views we return information about
the fields that are accepted for 'PUT' and 'POST' methods.
"""
actions = {}
for method in {'PUT', 'POST'} & set(view.allowed_methods):
view.request = clone_request(request, method)
try:
# Test global permissions
if hasattr(view, 'check_permissions'):
view.check_permissions(view.request)
# Test object permissions
if method == 'PUT' and hasattr(view, 'get_object'):
view.get_object()
except (exceptions.APIException, PermissionDenied, Http404):
pass
else:
# If user has appropriate permissions for the view, include
# appropriate metadata about the fields that should be supplied.
serializer = view.get_serializer()
actions[method] = self.get_serializer_info(serializer)
finally:
view.request = request
return actions
def get_serializer_info(self, serializer):
"""
Given an instance of a serializer, return a dictionary of metadata
about its fields.
"""
if hasattr(serializer, 'child'):
# If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead.
serializer = serializer.child
return {
field_name: self.get_field_info(field)
for field_name, field in serializer.fields.items()
if not isinstance(field, serializers.HiddenField)
}
def get_field_info(self, field):
"""
Given an instance of a serializer field, return a dictionary
of metadata about it.
"""
field_info = {
"type": self.label_lookup[field],
"required": getattr(field, "required", False),
}
attrs = [
'read_only', 'label', 'help_text',
'min_length', 'max_length',
'min_value', 'max_value',
'max_digits', 'decimal_places'
]
for attr in attrs:
value = getattr(field, attr, None)
if value is not None and value != '':
field_info[attr] = force_str(value, strings_only=True)
if getattr(field, 'child', None):
field_info['child'] = self.get_field_info(field.child)
elif getattr(field, 'fields', None):
field_info['children'] = self.get_serializer_info(field)
if (not field_info.get('read_only') and
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and
hasattr(field, 'choices')):
field_info['choices'] = [
{
'value': choice_value,
'display_name': force_str(choice_name, strings_only=True)
}
for choice_value, choice_name in field.choices.items()
]
return field_info

View File

@@ -0,0 +1,95 @@
"""
Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
class CreateModelMixin:
"""
Create a model instance.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
def perform_create(self, serializer):
serializer.save()
def get_success_headers(self, data):
try:
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError):
return {}
class ListModelMixin:
"""
List a queryset.
"""
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(queryset)
if page is not None:
serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
class RetrieveModelMixin:
"""
Retrieve a model instance.
"""
def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)
class UpdateModelMixin:
"""
Update a model instance.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
return Response(serializer.data)
def perform_update(self, serializer):
serializer.save()
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return self.update(request, *args, **kwargs)
class DestroyModelMixin:
"""
Destroy a model instance.
"""
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)
return Response(status=status.HTTP_204_NO_CONTENT)
def perform_destroy(self, instance):
instance.delete()

View File

@@ -0,0 +1,97 @@
"""
Content negotiation deals with selecting an appropriate renderer given the
incoming request. Typically this will be based on the request's Accept header.
"""
from django.http import Http404
from rest_framework import exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import (
_MediaType, media_type_matches, order_by_precedence
)
class BaseContentNegotiation:
def select_parser(self, request, parsers):
raise NotImplementedError('.select_parser() must be implemented')
def select_renderer(self, request, renderers, format_suffix=None):
raise NotImplementedError('.select_renderer() must be implemented')
class DefaultContentNegotiation(BaseContentNegotiation):
settings = api_settings
def select_parser(self, request, parsers):
"""
Given a list of parsers and a media type, return the appropriate
parser to handle the incoming request.
"""
for parser in parsers:
if media_type_matches(parser.media_type, request.content_type):
return parser
return None
def select_renderer(self, request, renderers, format_suffix=None):
"""
Given a request and a list of renderers, return a two-tuple of:
(renderer, media type).
"""
# Allow URL style format override. eg. "?format=json
format_query_param = self.settings.URL_FORMAT_OVERRIDE
format = format_suffix or request.query_params.get(format_query_param)
if format:
renderers = self.filter_renderers(renderers, format)
accepts = self.get_accept_list(request)
# Check the acceptable media types against each renderer,
# attempting more specific media types first
# NB. The inner loop here isn't as bad as it first looks :)
# Worst case is we're looping over len(accept_list) * len(self.renderers)
for media_type_set in order_by_precedence(accepts):
for renderer in renderers:
for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted.
media_type_wrapper = _MediaType(media_type)
if (
_MediaType(renderer.media_type).precedence >
media_type_wrapper.precedence
):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
full_media_type = ';'.join(
(renderer.media_type,) +
tuple(
'{}={}'.format(key, value)
for key, value in media_type_wrapper.params.items()
)
)
return renderer, full_media_type
else:
# Eg client requests 'application/json; indent=8'
# Accepted media type is 'application/json; indent=8'
return renderer, media_type
raise exceptions.NotAcceptable(available_renderers=renderers)
def filter_renderers(self, renderers, format):
"""
If there is a '.json' style format suffix, filter the renderers
so that we only negotiation against those that accept that format.
"""
renderers = [renderer for renderer in renderers
if renderer.format == format]
if not renderers:
raise Http404
return renderers
def get_accept_list(self, request):
"""
Given the incoming request, return a tokenized list of media
type strings.
"""
header = request.META.get('HTTP_ACCEPT', '*/*')
return [token.strip() for token in header.split(',')]

View File

@@ -0,0 +1,990 @@
"""
Pagination serializers determine the structure of the output that should
be used for paginated responses.
"""
import contextlib
import warnings
from base64 import b64decode, b64encode
from collections import namedtuple
from urllib import parse
from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.utils.urls import remove_query_param, replace_query_param
def _positive_int(integer_string, strict=False, cutoff=None):
"""
Cast a string to a strictly positive integer.
"""
ret = int(integer_string)
if ret < 0 or (ret == 0 and strict):
raise ValueError()
if cutoff:
return min(ret, cutoff)
return ret
def _divide_with_ceil(a, b):
"""
Returns 'a' divided by 'b', with any remainder rounded up.
"""
if a % b:
return (a // b) + 1
return a // b
def _get_displayed_page_numbers(current, final):
"""
This utility function determines a list of page numbers to display.
This gives us a nice contextually relevant set of page numbers.
For example:
current=14, final=16 -> [1, None, 13, 14, 15, 16]
This implementation gives one page to each side of the cursor,
or two pages to the side when the cursor is at the edge, then
ensures that any breaks between non-continuous page numbers never
remove only a single page.
For an alternative implementation which gives two pages to each side of
the cursor, eg. as in GitHub issue list pagination, see:
https://gist.github.com/tomchristie/321140cebb1c4a558b15
"""
assert current >= 1
assert final >= current
if final <= 5:
return list(range(1, final + 1))
# We always include the first two pages, last two pages, and
# two pages either side of the current page.
included = {1, current - 1, current, current + 1, final}
# If the break would only exclude a single page number then we
# may as well include the page number instead of the break.
if current <= 4:
included.add(2)
included.add(3)
if current >= final - 3:
included.add(final - 1)
included.add(final - 2)
# Now sort the page numbers and drop anything outside the limits.
included = [
idx for idx in sorted(included)
if 0 < idx <= final
]
# Finally insert any `...` breaks
if current > 4:
included.insert(1, None)
if current < final - 3:
included.insert(len(included) - 1, None)
return included
def _get_page_links(page_numbers, current, url_func):
"""
Given a list of page numbers and `None` page breaks,
return a list of `PageLink` objects.
"""
page_links = []
for page_number in page_numbers:
if page_number is None:
page_link = PAGE_BREAK
else:
page_link = PageLink(
url=url_func(page_number),
number=page_number,
is_active=(page_number == current),
is_break=False
)
page_links.append(page_link)
return page_links
def _reverse_ordering(ordering_tuple):
"""
Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`.
"""
def invert(x):
return x[1:] if x.startswith('-') else '-' + x
return tuple([invert(item) for item in ordering_tuple])
Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position'])
PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break'])
PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
class BasePagination:
display_page_controls = False
def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
raise NotImplementedError('paginate_queryset() must be implemented.')
def get_paginated_response(self, data): # pragma: no cover
raise NotImplementedError('get_paginated_response() must be implemented.')
def get_paginated_response_schema(self, schema):
return schema
def to_html(self): # pragma: no cover
raise NotImplementedError('to_html() must be implemented to display page controls.')
def get_results(self, data):
return data['results']
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
return []
def get_schema_operation_parameters(self, view):
return []
class PageNumberPagination(BasePagination):
"""
A simple page number based style that supports page numbers as
query parameters. For example:
http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100
"""
# The default page size.
# Defaults to `None`, meaning pagination is disabled.
page_size = api_settings.PAGE_SIZE
django_paginator_class = DjangoPaginator
# Client can control the page using this query parameter.
page_query_param = 'page'
page_query_description = _('A page number within the paginated result set.')
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
page_size_query_description = _('Number of results to return per page.')
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
max_page_size = None
last_page_strings = ('last',)
template = 'rest_framework/pagination/numbers.html'
invalid_page_message = _('Invalid page.')
def paginate_queryset(self, queryset, request, view=None):
"""
Paginate a queryset if required, either returning a
page object, or `None` if pagination is not configured for this view.
"""
self.request = request
page_size = self.get_page_size(request)
if not page_size:
return None
paginator = self.django_paginator_class(queryset, page_size)
page_number = self.get_page_number(request, paginator)
try:
self.page = paginator.page(page_number)
except InvalidPage as exc:
msg = self.invalid_page_message.format(
page_number=page_number, message=str(exc)
)
raise NotFound(msg)
if paginator.num_pages > 1 and self.template is not None:
# The browsable API should display pagination controls.
self.display_page_controls = True
return list(self.page)
def get_page_number(self, request, paginator):
page_number = request.query_params.get(self.page_query_param) or 1
if page_number in self.last_page_strings:
page_number = paginator.num_pages
return page_number
def get_paginated_response(self, data):
return Response({
'count': self.page.paginator.count,
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'results': data,
})
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'required': ['count', 'results'],
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{page_query_param}=4'.format(
page_query_param=self.page_query_param)
},
'previous': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{page_query_param}=2'.format(
page_query_param=self.page_query_param)
},
'results': schema,
},
}
def get_page_size(self, request):
if self.page_size_query_param:
with contextlib.suppress(KeyError, ValueError):
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
cutoff=self.max_page_size
)
return self.page_size
def get_next_link(self):
if not self.page.has_next():
return None
url = self.request.build_absolute_uri()
page_number = self.page.next_page_number()
return replace_query_param(url, self.page_query_param, page_number)
def get_previous_link(self):
if not self.page.has_previous():
return None
url = self.request.build_absolute_uri()
page_number = self.page.previous_page_number()
if page_number == 1:
return remove_query_param(url, self.page_query_param)
return replace_query_param(url, self.page_query_param, page_number)
def get_html_context(self):
base_url = self.request.build_absolute_uri()
def page_number_to_url(page_number):
if page_number == 1:
return remove_query_param(base_url, self.page_query_param)
else:
return replace_query_param(base_url, self.page_query_param, page_number)
current = self.page.number
final = self.page.paginator.num_pages
page_numbers = _get_displayed_page_numbers(current, final)
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link(),
'page_links': page_links
}
def to_html(self):
template = loader.get_template(self.template)
context = self.get_html_context()
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [
coreapi.Field(
name=self.page_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Page',
description=force_str(self.page_query_description)
)
)
]
if self.page_size_query_param is not None:
fields.append(
coreapi.Field(
name=self.page_size_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Page size',
description=force_str(self.page_size_query_description)
)
)
)
return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.page_query_param,
'required': False,
'in': 'query',
'description': force_str(self.page_query_description),
'schema': {
'type': 'integer',
},
},
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_str(self.page_size_query_description),
'schema': {
'type': 'integer',
},
},
)
return parameters
class LimitOffsetPagination(BasePagination):
"""
A limit/offset based style. For example:
http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100
"""
default_limit = api_settings.PAGE_SIZE
limit_query_param = 'limit'
limit_query_description = _('Number of results to return per page.')
offset_query_param = 'offset'
offset_query_description = _('The initial index from which to return the results.')
max_limit = None
template = 'rest_framework/pagination/numbers.html'
def paginate_queryset(self, queryset, request, view=None):
self.request = request
self.limit = self.get_limit(request)
if self.limit is None:
return None
self.count = self.get_count(queryset)
self.offset = self.get_offset(request)
if self.count > self.limit and self.template is not None:
self.display_page_controls = True
if self.count == 0 or self.offset > self.count:
return []
return list(queryset[self.offset:self.offset + self.limit])
def get_paginated_response(self, data):
return Response({
'count': self.count,
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'results': data
})
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'required': ['count', 'results'],
'properties': {
'count': {
'type': 'integer',
'example': 123,
},
'next': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{offset_param}=400&{limit_param}=100'.format(
offset_param=self.offset_query_param, limit_param=self.limit_query_param),
},
'previous': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{offset_param}=200&{limit_param}=100'.format(
offset_param=self.offset_query_param, limit_param=self.limit_query_param),
},
'results': schema,
},
}
def get_limit(self, request):
if self.limit_query_param:
with contextlib.suppress(KeyError, ValueError):
return _positive_int(
request.query_params[self.limit_query_param],
strict=True,
cutoff=self.max_limit
)
return self.default_limit
def get_offset(self, request):
try:
return _positive_int(
request.query_params[self.offset_query_param],
)
except (KeyError, ValueError):
return 0
def get_next_link(self):
if self.offset + self.limit >= self.count:
return None
url = self.request.build_absolute_uri()
url = replace_query_param(url, self.limit_query_param, self.limit)
offset = self.offset + self.limit
return replace_query_param(url, self.offset_query_param, offset)
def get_previous_link(self):
if self.offset <= 0:
return None
url = self.request.build_absolute_uri()
url = replace_query_param(url, self.limit_query_param, self.limit)
if self.offset - self.limit <= 0:
return remove_query_param(url, self.offset_query_param)
offset = self.offset - self.limit
return replace_query_param(url, self.offset_query_param, offset)
def get_html_context(self):
base_url = self.request.build_absolute_uri()
if self.limit:
current = _divide_with_ceil(self.offset, self.limit) + 1
# The number of pages is a little bit fiddly.
# We need to sum both the number of pages from current offset to end
# plus the number of pages up to the current offset.
# When offset is not strictly divisible by the limit then we may
# end up introducing an extra page as an artifact.
final = (
_divide_with_ceil(self.count - self.offset, self.limit) +
_divide_with_ceil(self.offset, self.limit)
)
final = max(final, 1)
else:
current = 1
final = 1
if current > final:
current = final
def page_number_to_url(page_number):
if page_number == 1:
return remove_query_param(base_url, self.offset_query_param)
else:
offset = self.offset + ((page_number - current) * self.limit)
return replace_query_param(base_url, self.offset_query_param, offset)
page_numbers = _get_displayed_page_numbers(current, final)
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link(),
'page_links': page_links
}
def to_html(self):
template = loader.get_template(self.template)
context = self.get_html_context()
return template.render(context)
def get_count(self, queryset):
"""
Determine an object count, supporting either querysets or regular lists.
"""
try:
return queryset.count()
except (AttributeError, TypeError):
return len(queryset)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
name=self.limit_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Limit',
description=force_str(self.limit_query_description)
)
),
coreapi.Field(
name=self.offset_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Offset',
description=force_str(self.offset_query_description)
)
)
]
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.limit_query_param,
'required': False,
'in': 'query',
'description': force_str(self.limit_query_description),
'schema': {
'type': 'integer',
},
},
{
'name': self.offset_query_param,
'required': False,
'in': 'query',
'description': force_str(self.offset_query_description),
'schema': {
'type': 'integer',
},
},
]
return parameters
class CursorPagination(BasePagination):
"""
The cursor pagination implementation is necessarily complex.
For an overview of the position/offset style we use, see this post:
https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api
"""
cursor_query_param = 'cursor'
cursor_query_description = _('The pagination cursor value.')
page_size = api_settings.PAGE_SIZE
invalid_cursor_message = _('Invalid cursor')
ordering = '-created'
template = 'rest_framework/pagination/previous_and_next.html'
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
page_size_query_description = _('Number of results to return per page.')
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
max_page_size = None
# The offset in the cursor is used in situations where we have a
# nearly-unique index. (Eg millisecond precision creation timestamps)
# We guard against malicious users attempting to cause expensive database
# queries, by having a hard cap on the maximum possible size of the offset.
offset_cutoff = 1000
def paginate_queryset(self, queryset, request, view=None):
self.request = request
self.page_size = self.get_page_size(request)
if not self.page_size:
return None
self.base_url = request.build_absolute_uri()
self.ordering = self.get_ordering(request, queryset, view)
self.cursor = self.decode_cursor(request)
if self.cursor is None:
(offset, reverse, current_position) = (0, False, None)
else:
(offset, reverse, current_position) = self.cursor
# Cursor pagination always enforces an ordering.
if reverse:
queryset = queryset.order_by(*_reverse_ordering(self.ordering))
else:
queryset = queryset.order_by(*self.ordering)
# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
order = self.ordering[0]
is_reversed = order.startswith('-')
order_attr = order.lstrip('-')
# Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed:
kwargs = {order_attr + '__lt': current_position}
else:
kwargs = {order_attr + '__gt': current_position}
queryset = queryset.filter(**kwargs)
# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
# page following on from this one.
results = list(queryset[offset:offset + self.page_size + 1])
self.page = list(results[:self.page_size])
# Determine the position of the final item following the page.
if len(results) > len(self.page):
has_following_position = True
following_position = self._get_position_from_instance(results[-1], self.ordering)
else:
has_following_position = False
following_position = None
if reverse:
# If we have a reverse queryset, then the query ordering was in reverse
# so we need to reverse the items again before returning them to the user.
self.page = list(reversed(self.page))
# Determine next and previous positions for reverse cursors.
self.has_next = (current_position is not None) or (offset > 0)
self.has_previous = has_following_position
if self.has_next:
self.next_position = current_position
if self.has_previous:
self.previous_position = following_position
else:
# Determine next and previous positions for forward cursors.
self.has_next = has_following_position
self.has_previous = (current_position is not None) or (offset > 0)
if self.has_next:
self.next_position = following_position
if self.has_previous:
self.previous_position = current_position
# Display page controls in the browsable API if there is more
# than one page.
if (self.has_previous or self.has_next) and self.template is not None:
self.display_page_controls = True
return self.page
def get_page_size(self, request):
if self.page_size_query_param:
with contextlib.suppress(KeyError, ValueError):
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
cutoff=self.max_page_size
)
return self.page_size
def get_next_link(self):
if not self.has_next:
return None
if self.page and self.cursor and self.cursor.reverse and self.cursor.offset != 0:
# If we're reversing direction and we have an offset cursor
# then we cannot use the first position we find as a marker.
compare = self._get_position_from_instance(self.page[-1], self.ordering)
else:
compare = self.next_position
offset = 0
has_item_with_unique_position = False
for item in reversed(self.page):
position = self._get_position_from_instance(item, self.ordering)
if position != compare:
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
break
# The item in this position has the same position as the item
# following it, we can't use it as a marker position, so increment
# the offset and keep seeking to the previous item.
compare = position
offset += 1
if self.page and not has_item_with_unique_position:
# There were no unique positions in the page.
if not self.has_previous:
# We are on the first page.
# Our cursor will have an offset equal to the page size,
# but no position to filter against yet.
offset = self.page_size
position = None
elif self.cursor.reverse:
# The change in direction will introduce a paging artifact,
# where we end up skipping forward a few extra items.
offset = 0
position = self.previous_position
else:
# Use the position from the existing cursor and increment
# it's offset by the page size.
offset = self.cursor.offset + self.page_size
position = self.previous_position
if not self.page:
position = self.next_position
cursor = Cursor(offset=offset, reverse=False, position=position)
return self.encode_cursor(cursor)
def get_previous_link(self):
if not self.has_previous:
return None
if self.page and self.cursor and not self.cursor.reverse and self.cursor.offset != 0:
# If we're reversing direction and we have an offset cursor
# then we cannot use the first position we find as a marker.
compare = self._get_position_from_instance(self.page[0], self.ordering)
else:
compare = self.previous_position
offset = 0
has_item_with_unique_position = False
for item in self.page:
position = self._get_position_from_instance(item, self.ordering)
if position != compare:
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
break
# The item in this position has the same position as the item
# following it, we can't use it as a marker position, so increment
# the offset and keep seeking to the previous item.
compare = position
offset += 1
if self.page and not has_item_with_unique_position:
# There were no unique positions in the page.
if not self.has_next:
# We are on the final page.
# Our cursor will have an offset equal to the page size,
# but no position to filter against yet.
offset = self.page_size
position = None
elif self.cursor.reverse:
# Use the position from the existing cursor and increment
# it's offset by the page size.
offset = self.cursor.offset + self.page_size
position = self.next_position
else:
# The change in direction will introduce a paging artifact,
# where we end up skipping back a few extra items.
offset = 0
position = self.next_position
if not self.page:
position = self.previous_position
cursor = Cursor(offset=offset, reverse=True, position=position)
return self.encode_cursor(cursor)
def get_ordering(self, request, queryset, view):
"""
Return a tuple of strings, that may be used in an `order_by` method.
"""
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
if hasattr(filter_cls, 'get_ordering')
]
if ordering_filters:
# If a filter exists on the view that implements `get_ordering`
# then we defer to that filter to determine the ordering.
filter_cls = ordering_filters[0]
filter_instance = filter_cls()
ordering_from_filter = filter_instance.get_ordering(request, queryset, view)
if ordering_from_filter:
ordering = ordering_from_filter
assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared '
'on the pagination class.'
)
assert '__' not in ordering, (
'Cursor pagination does not support double underscore lookups '
'for orderings. Orderings should be an unchanging, unique or '
'nearly-unique field on the model, such as "-created" or "pk".'
)
assert isinstance(ordering, (str, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format(
type=type(ordering).__name__
)
)
if isinstance(ordering, str):
return (ordering,)
return tuple(ordering)
def decode_cursor(self, request):
"""
Given a request with a cursor, return a `Cursor` instance.
"""
# Determine if we have a cursor, and if so then decode it.
encoded = request.query_params.get(self.cursor_query_param)
if encoded is None:
return None
try:
querystring = b64decode(encoded.encode('ascii')).decode('ascii')
tokens = parse.parse_qs(querystring, keep_blank_values=True)
offset = tokens.get('o', ['0'])[0]
offset = _positive_int(offset, cutoff=self.offset_cutoff)
reverse = tokens.get('r', ['0'])[0]
reverse = bool(int(reverse))
position = tokens.get('p', [None])[0]
except (TypeError, ValueError):
raise NotFound(self.invalid_cursor_message)
return Cursor(offset=offset, reverse=reverse, position=position)
def encode_cursor(self, cursor):
"""
Given a Cursor instance, return an url with encoded cursor.
"""
tokens = {}
if cursor.offset != 0:
tokens['o'] = str(cursor.offset)
if cursor.reverse:
tokens['r'] = '1'
if cursor.position is not None:
tokens['p'] = cursor.position
querystring = parse.urlencode(tokens, doseq=True)
encoded = b64encode(querystring.encode('ascii')).decode('ascii')
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
def _get_position_from_instance(self, instance, ordering):
field_name = ordering[0].lstrip('-')
if isinstance(instance, dict):
attr = instance[field_name]
else:
attr = getattr(instance, field_name)
return str(attr)
def get_paginated_response(self, data):
return Response({
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'results': data,
})
def get_paginated_response_schema(self, schema):
return {
'type': 'object',
'required': ['results'],
'properties': {
'next': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{cursor_query_param}=cD00ODY%3D"'.format(
cursor_query_param=self.cursor_query_param)
},
'previous': {
'type': 'string',
'nullable': True,
'format': 'uri',
'example': 'http://api.example.org/accounts/?{cursor_query_param}=cj0xJnA9NDg3'.format(
cursor_query_param=self.cursor_query_param)
},
'results': schema,
},
}
def get_html_context(self):
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link()
}
def to_html(self):
template = loader.get_template(self.template)
context = self.get_html_context()
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [
coreapi.Field(
name=self.cursor_query_param,
required=False,
location='query',
schema=coreschema.String(
title='Cursor',
description=force_str(self.cursor_query_description)
)
)
]
if self.page_size_query_param is not None:
fields.append(
coreapi.Field(
name=self.page_size_query_param,
required=False,
location='query',
schema=coreschema.Integer(
title='Page size',
description=force_str(self.page_size_query_description)
)
)
)
return fields
def get_schema_operation_parameters(self, view):
parameters = [
{
'name': self.cursor_query_param,
'required': False,
'in': 'query',
'description': force_str(self.cursor_query_description),
'schema': {
'type': 'string',
},
}
]
if self.page_size_query_param is not None:
parameters.append(
{
'name': self.page_size_query_param,
'required': False,
'in': 'query',
'description': force_str(self.page_size_query_description),
'schema': {
'type': 'integer',
},
}
)
return parameters

View File

@@ -0,0 +1,206 @@
"""
Parsers are used to parse the content of incoming HTTP requests.
They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data.
"""
import codecs
import contextlib
from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import ChunkIter
from django.http.multipartparser import \
MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError
from django.utils.http import parse_header_parameters
from rest_framework import renderers
from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings
from rest_framework.utils import json
class DataAndFiles:
def __init__(self, data, files):
self.data = data
self.files = files
class BaseParser:
"""
All parsers should extend `BaseParser`, specifying a `media_type`
attribute, and overriding the `.parse()` method.
"""
media_type = None
def parse(self, stream, media_type=None, parser_context=None):
"""
Given a stream to read from, return the parsed representation.
Should return parsed data, or a `DataAndFiles` object consisting of the
parsed data and files.
"""
raise NotImplementedError(".parse() must be overridden.")
class JSONParser(BaseParser):
"""
Parses JSON-serialized data.
"""
media_type = 'application/json'
renderer_class = renderers.JSONRenderer
strict = api_settings.STRICT_JSON
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as JSON and returns the resulting data.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
try:
decoded_stream = codecs.getreader(encoding)(stream)
parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc:
raise ParseError('JSON parse error - %s' % str(exc))
class FormParser(BaseParser):
"""
Parser for form data.
"""
media_type = 'application/x-www-form-urlencoded'
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as a URL encoded form,
and returns the resulting QueryDict.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
return QueryDict(stream.read(), encoding=encoding)
class MultiPartParser(BaseParser):
"""
Parser for multipart form data, which may include file data.
"""
media_type = 'multipart/form-data'
def parse(self, stream, media_type=None, parser_context=None):
"""
Parses the incoming bytestream as a multipart encoded form,
and returns a DataAndFiles object.
`.data` will be a `QueryDict` containing all the form parameters.
`.files` will be a `QueryDict` containing all the form files.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META.copy()
meta['CONTENT_TYPE'] = media_type
upload_handlers = request.upload_handlers
try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers, encoding)
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % str(exc))
class FileUploadParser(BaseParser):
"""
Parser for file upload data.
"""
media_type = '*/*'
errors = {
'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
}
def parse(self, stream, media_type=None, parser_context=None):
"""
Treats the incoming bytestream as a raw file upload and returns
a `DataAndFiles` object.
`.data` will be None (we expect request body to be a file content).
`.files` will be a `QueryDict` containing one 'file' element.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
meta = request.META
upload_handlers = request.upload_handlers
filename = self.get_filename(stream, media_type, parser_context)
if not filename:
raise ParseError(self.errors['no_filename'])
# Note that this code is extracted from Django's handling of
# file uploads in MultiPartParser.
content_type = meta.get('HTTP_CONTENT_TYPE',
meta.get('CONTENT_TYPE', ''))
try:
content_length = int(meta.get('HTTP_CONTENT_LENGTH',
meta.get('CONTENT_LENGTH', 0)))
except (ValueError, TypeError):
content_length = None
# See if the handler will want to take care of the parsing.
for handler in upload_handlers:
result = handler.handle_raw_input(stream,
meta,
content_length,
None,
encoding)
if result is not None:
return DataAndFiles({}, {'file': result[1]})
# This is the standard case.
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
chunk_size = min([2 ** 31 - 4] + possible_sizes)
chunks = ChunkIter(stream, chunk_size)
counters = [0] * len(upload_handlers)
for index, handler in enumerate(upload_handlers):
try:
handler.new_file(None, filename, content_type,
content_length, encoding)
except StopFutureHandlers:
upload_handlers = upload_handlers[:index + 1]
break
for chunk in chunks:
for index, handler in enumerate(upload_handlers):
chunk_length = len(chunk)
chunk = handler.receive_data_chunk(chunk, counters[index])
counters[index] += chunk_length
if chunk is None:
break
for index, handler in enumerate(upload_handlers):
file_obj = handler.file_complete(counters[index])
if file_obj is not None:
return DataAndFiles({}, {'file': file_obj})
raise ParseError(self.errors['unhandled'])
def get_filename(self, stream, media_type, parser_context):
"""
Detects the uploaded file name. First searches a 'filename' url kwarg.
Then tries to parse Content-Disposition header.
"""
with contextlib.suppress(KeyError):
return parser_context['kwargs']['filename']
with contextlib.suppress(AttributeError, KeyError, ValueError):
meta = parser_context['request'].META
disposition, params = parse_header_parameters(meta['HTTP_CONTENT_DISPOSITION'])
if 'filename*' in params:
return params['filename*']
return params['filename']

View File

@@ -0,0 +1,314 @@
"""
Provides a set of pluggable permission policies.
"""
from django.http import Http404
from rest_framework import exceptions
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
class OperationHolderMixin:
def __and__(self, other):
return OperandHolder(AND, self, other)
def __or__(self, other):
return OperandHolder(OR, self, other)
def __rand__(self, other):
return OperandHolder(AND, other, self)
def __ror__(self, other):
return OperandHolder(OR, other, self)
def __invert__(self):
return SingleOperandHolder(NOT, self)
class SingleOperandHolder(OperationHolderMixin):
def __init__(self, operator_class, op1_class):
self.operator_class = operator_class
self.op1_class = op1_class
def __call__(self, *args, **kwargs):
op1 = self.op1_class(*args, **kwargs)
return self.operator_class(op1)
class OperandHolder(OperationHolderMixin):
def __init__(self, operator_class, op1_class, op2_class):
self.operator_class = operator_class
self.op1_class = op1_class
self.op2_class = op2_class
def __call__(self, *args, **kwargs):
op1 = self.op1_class(*args, **kwargs)
op2 = self.op2_class(*args, **kwargs)
return self.operator_class(op1, op2)
def __eq__(self, other):
return (
isinstance(other, OperandHolder) and
self.operator_class == other.operator_class and
self.op1_class == other.op1_class and
self.op2_class == other.op2_class
)
def __hash__(self):
return hash((self.operator_class, self.op1_class, self.op2_class))
class AND:
def __init__(self, op1, op2):
self.op1 = op1
self.op2 = op2
def has_permission(self, request, view):
return (
self.op1.has_permission(request, view) and
self.op2.has_permission(request, view)
)
def has_object_permission(self, request, view, obj):
return (
self.op1.has_object_permission(request, view, obj) and
self.op2.has_object_permission(request, view, obj)
)
class OR:
def __init__(self, op1, op2):
self.op1 = op1
self.op2 = op2
def has_permission(self, request, view):
return (
self.op1.has_permission(request, view) or
self.op2.has_permission(request, view)
)
def has_object_permission(self, request, view, obj):
return (
self.op1.has_permission(request, view)
and self.op1.has_object_permission(request, view, obj)
) or (
self.op2.has_permission(request, view)
and self.op2.has_object_permission(request, view, obj)
)
class NOT:
def __init__(self, op1):
self.op1 = op1
def has_permission(self, request, view):
return not self.op1.has_permission(request, view)
def has_object_permission(self, request, view, obj):
return not self.op1.has_object_permission(request, view, obj)
class BasePermissionMetaclass(OperationHolderMixin, type):
pass
class BasePermission(metaclass=BasePermissionMetaclass):
"""
A base class from which all permission classes should inherit.
"""
def has_permission(self, request, view):
"""
Return `True` if permission is granted, `False` otherwise.
"""
return True
def has_object_permission(self, request, view, obj):
"""
Return `True` if permission is granted, `False` otherwise.
"""
return True
class AllowAny(BasePermission):
"""
Allow any access.
This isn't strictly required, since you could use an empty
permission_classes list, but it's useful because it makes the intention
more explicit.
"""
def has_permission(self, request, view):
return True
class IsAuthenticated(BasePermission):
"""
Allows access only to authenticated users.
"""
def has_permission(self, request, view):
return bool(request.user and request.user.is_authenticated)
class IsAdminUser(BasePermission):
"""
Allows access only to admin users.
"""
def has_permission(self, request, view):
return bool(request.user and request.user.is_staff)
class IsAuthenticatedOrReadOnly(BasePermission):
"""
The request is authenticated as a user, or is a read-only request.
"""
def has_permission(self, request, view):
return bool(
request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated
)
class DjangoModelPermissions(BasePermission):
"""
The request is authenticated using `django.contrib.auth` permissions.
See: https://docs.djangoproject.com/en/dev/topics/auth/#permissions
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the model.
This permission can only be applied against view classes that
provide a `.queryset` attribute.
"""
# Map methods into required permission codes.
# Override this if you need to also provide 'view' permissions,
# or if you want to provide custom permission codes.
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
authenticated_users_only = True
def get_required_permissions(self, method, model_cls):
"""
Given a model and an HTTP method, return the list of permission
codes that the user is required to have.
"""
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
}
if method not in self.perms_map:
raise exceptions.MethodNotAllowed(method)
return [perm % kwargs for perm in self.perms_map[method]]
def _queryset(self, view):
assert hasattr(view, 'get_queryset') \
or getattr(view, 'queryset', None) is not None, (
'Cannot apply {} on a view that does not set '
'`.queryset` or have a `.get_queryset()` method.'
).format(self.__class__.__name__)
if hasattr(view, 'get_queryset'):
queryset = view.get_queryset()
assert queryset is not None, (
'{}.get_queryset() returned None'.format(view.__class__.__name__)
)
return queryset
return view.queryset
def has_permission(self, request, view):
if not request.user or (
not request.user.is_authenticated and self.authenticated_users_only):
return False
# Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter.
if getattr(view, '_ignore_model_permissions', False):
return True
queryset = self._queryset(view)
perms = self.get_required_permissions(request.method, queryset.model)
return request.user.has_perms(perms)
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
"""
Similar to DjangoModelPermissions, except that anonymous users are
allowed read-only access.
"""
authenticated_users_only = False
class DjangoObjectPermissions(DjangoModelPermissions):
"""
The request is authenticated using Django's object-level permissions.
It requires an object-permissions-enabled backend, such as Django Guardian.
It ensures that the user is authenticated, and has the appropriate
`add`/`change`/`delete` permissions on the object using .has_perms.
This permission can only be applied against view classes that
provide a `.queryset` attribute.
"""
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
}
def get_required_object_permissions(self, method, model_cls):
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
}
if method not in self.perms_map:
raise exceptions.MethodNotAllowed(method)
return [perm % kwargs for perm in self.perms_map[method]]
def has_object_permission(self, request, view, obj):
# authentication checks have already executed via has_permission
queryset = self._queryset(view)
model_cls = queryset.model
user = request.user
perms = self.get_required_object_permissions(request.method, model_cls)
if not user.has_perms(perms, obj):
# If the user does not have permissions we need to determine if
# they have read permissions to see 403, or not, and simply see
# a 404 response.
if request.method in SAFE_METHODS:
# Read permissions already checked and failed, no need
# to make another lookup.
raise Http404
read_perms = self.get_required_object_permissions('GET', model_cls)
if not user.has_perms(read_perms, obj):
raise Http404
# Has read permissions.
return False
return True

View File

@@ -0,0 +1,585 @@
import contextlib
import sys
from operator import attrgetter
from urllib import parse
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Manager
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils.encoding import smart_str, uri_to_iri
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import (
Field, SkipField, empty, get_attribute, is_simple_callable, iter_options
)
from rest_framework.reverse import reverse
from rest_framework.settings import api_settings
from rest_framework.utils import html
def method_overridden(method_name, klass, instance):
"""
Determine if a method has been overridden.
"""
method = getattr(klass, method_name)
default_method = getattr(method, '__func__', method) # Python 3 compat
return default_method is not getattr(instance, method_name).__func__
class ObjectValueError(ValueError):
"""
Raised when `queryset.get()` failed due to an underlying `ValueError`.
Wrapping prevents calling code conflating this with unrelated errors.
"""
class ObjectTypeError(TypeError):
"""
Raised when `queryset.get()` failed due to an underlying `TypeError`.
Wrapping prevents calling code conflating this with unrelated errors.
"""
class Hyperlink(str):
"""
A string like object that additionally has an associated name.
We use this for hyperlinked URLs that may render as a named link
in some contexts, or render as a plain URL in others.
"""
def __new__(cls, url, obj):
ret = super().__new__(cls, url)
ret.obj = obj
return ret
def __getnewargs__(self):
return (str(self), self.name)
@property
def name(self):
# This ensures that we only called `__str__` lazily,
# as in some cases calling __str__ on a model instances *might*
# involve a database lookup.
return str(self.obj)
is_hyperlink = True
class PKOnlyObject:
"""
This is a mock object, used for when we only need the pk of the object
instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance.
"""
def __init__(self, pk):
self.pk = pk
def __str__(self):
return "%s" % self.pk
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
MANY_RELATION_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages', 'allow_empty',
'html_cutoff', 'html_cutoff_text'
)
class RelatedField(Field):
queryset = None
html_cutoff = None
html_cutoff_text = None
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', self.queryset)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
)
if not method_overridden('get_queryset', RelatedField, self):
assert self.queryset is not None or kwargs.get('read_only'), (
'Relational field must provide a `queryset` argument, '
'override `get_queryset`, or set read_only=`True`.'
)
assert not (self.queryset is not None and kwargs.get('read_only')), (
'Relational fields should not provide a `queryset` argument, '
'when setting read_only=`True`.'
)
kwargs.pop('many', None)
kwargs.pop('allow_empty', None)
super().__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ManyRelatedField` classes instead when `many=True` is set.
if kwargs.pop('many', False):
return cls.many_init(*args, **kwargs)
return super().__new__(cls, *args, **kwargs)
@classmethod
def many_init(cls, *args, **kwargs):
"""
This method handles creating a parent `ManyRelatedField` instance
when the `many=True` keyword argument is passed.
Typically you won't need to override this method.
Note that we're over-cautious in passing most arguments to both parent
and child classes in order to try to cover the general case. If you're
overriding this method you'll probably want something much simpler, eg:
@classmethod
def many_init(cls, *args, **kwargs):
kwargs['child'] = cls()
return CustomManyRelatedField(*args, **kwargs)
"""
list_kwargs = {'child_relation': cls(*args, **kwargs)}
for key in kwargs:
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
return ManyRelatedField(**list_kwargs)
def run_validation(self, data=empty):
# We force empty strings to None values for relational fields.
if data == '':
data = None
return super().run_validation(data)
def get_queryset(self):
queryset = self.queryset
if isinstance(queryset, (QuerySet, Manager)):
# Ensure queryset is re-evaluated whenever used.
# Note that actually a `Manager` class may also be used as the
# queryset argument. This occurs on ModelSerializer fields,
# as it allows us to generate a more expressive 'repr' output
# for the field.
# Eg: 'MyRelationship(queryset=ExampleModel.objects.all())'
queryset = queryset.all()
return queryset
def use_pk_only_optimization(self):
return False
def get_attribute(self, instance):
if self.use_pk_only_optimization() and self.source_attrs:
# Optimized case, return a mock object only containing the pk attribute.
with contextlib.suppress(AttributeError):
attribute_instance = get_attribute(instance, self.source_attrs[:-1])
value = attribute_instance.serializable_value(self.source_attrs[-1])
if is_simple_callable(value):
# Handle edge case where the relationship `source` argument
# points to a `get_relationship()` method on the model.
value = value()
# Handle edge case where relationship `source` argument points
# to an instance instead of a pk (e.g., a `@property`).
value = getattr(value, 'pk', value)
return PKOnlyObject(pk=value)
# Standard case, return the object instance.
return super().get_attribute(instance)
def get_choices(self, cutoff=None):
queryset = self.get_queryset()
if queryset is None:
# Ensure that field.choices returns something sensible
# even when accessed with a read-only field.
return {}
if cutoff is not None:
queryset = queryset[:cutoff]
return {
self.to_representation(item): self.display_value(item) for item in queryset
}
@property
def choices(self):
return self.get_choices()
@property
def grouped_choices(self):
return self.choices
def iter_options(self):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text
)
def display_value(self, instance):
return str(instance)
class StringRelatedField(RelatedField):
"""
A read only field that represents its targets using their
plain string representation.
"""
def __init__(self, **kwargs):
kwargs['read_only'] = True
super().__init__(**kwargs)
def to_representation(self, value):
return str(value)
class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
'required': _('This field is required.'),
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}
def __init__(self, **kwargs):
self.pk_field = kwargs.pop('pk_field', None)
super().__init__(**kwargs)
def use_pk_only_optimization(self):
return True
def to_internal_value(self, data):
if self.pk_field is not None:
data = self.pk_field.to_internal_value(data)
queryset = self.get_queryset()
try:
if isinstance(data, bool):
raise TypeError
return queryset.get(pk=data)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data)
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
def to_representation(self, value):
if self.pk_field is not None:
return self.pk_field.to_representation(value.pk)
return value.pk
class HyperlinkedRelatedField(RelatedField):
lookup_field = 'pk'
view_name = None
default_error_messages = {
'required': _('This field is required.'),
'no_match': _('Invalid hyperlink - No URL match.'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),
'does_not_exist': _('Invalid hyperlink - Object does not exist.'),
'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'),
}
def __init__(self, view_name=None, **kwargs):
if view_name is not None:
self.view_name = view_name
assert self.view_name is not None, 'The `view_name` argument is required.'
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None)
# We include this simply for dependency injection in tests.
# We can't add it as a class attributes or it would expect an
# implicit `self` argument to be passed.
self.reverse = reverse
super().__init__(**kwargs)
def use_pk_only_optimization(self):
return self.lookup_field == 'pk'
def get_object(self, view_name, view_args, view_kwargs):
"""
Return the object corresponding to a matched URL.
Takes the matched URL conf arguments, and should return an
object instance, or raise an `ObjectDoesNotExist` exception.
"""
lookup_value = view_kwargs[self.lookup_url_kwarg]
lookup_kwargs = {self.lookup_field: lookup_value}
queryset = self.get_queryset()
try:
return queryset.get(**lookup_kwargs)
except ValueError:
exc = ObjectValueError(str(sys.exc_info()[1]))
raise exc.with_traceback(sys.exc_info()[2])
except TypeError:
exc = ObjectTypeError(str(sys.exc_info()[1]))
raise exc.with_traceback(sys.exc_info()[2])
def get_url(self, obj, view_name, request, format):
"""
Given an object, return the URL that hyperlinks to the object.
May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
attributes are not configured to correctly match the URL conf.
"""
# Unsaved objects will not yet have a valid URL.
if hasattr(obj, 'pk') and obj.pk in (None, ''):
return None
lookup_value = getattr(obj, self.lookup_field)
kwargs = {self.lookup_url_kwarg: lookup_value}
return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
def to_internal_value(self, data):
request = self.context.get('request')
try:
http_prefix = data.startswith(('http:', 'https:'))
except AttributeError:
self.fail('incorrect_type', data_type=type(data).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
data = parse.urlparse(data).path
prefix = get_script_prefix()
if data.startswith(prefix):
data = '/' + data[len(prefix):]
data = uri_to_iri(parse.unquote(data))
try:
match = resolve(data)
except Resolver404:
self.fail('no_match')
try:
expected_viewname = request.versioning_scheme.get_versioned_viewname(
self.view_name, request
)
except AttributeError:
expected_viewname = self.view_name
if match.view_name != expected_viewname:
self.fail('incorrect_match')
try:
return self.get_object(match.view_name, match.args, match.kwargs)
except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError):
self.fail('does_not_exist')
def to_representation(self, value):
assert 'request' in self.context, (
"`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
"the serializer." % self.__class__.__name__
)
request = self.context['request']
format = self.context.get('format')
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
#
# Eg. Consider a HyperlinkedIdentityField pointing from a json
# representation to an html property of that representation...
#
# '/snippets/1/' should link to '/snippets/1/highlight/'
# ...but...
# '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
if format and self.format and self.format != format:
format = self.format
# Return the hyperlink, or error if incorrectly configured.
try:
url = self.get_url(value, self.view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
)
if value in ('', None):
value_string = {'': 'the empty string', None: 'None'}[value]
msg += (
" WARNING: The value of the field on the model instance "
"was %s, which may be why it didn't match any "
"entries in your URL conf." % value_string
)
raise ImproperlyConfigured(msg % self.view_name)
if url is None:
return None
return Hyperlink(url, value)
class HyperlinkedIdentityField(HyperlinkedRelatedField):
"""
A read-only field that represents the identity URL for an object, itself.
This is in contrast to `HyperlinkedRelatedField` which represents the
URL of relationships to other objects.
"""
def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
super().__init__(view_name, **kwargs)
def use_pk_only_optimization(self):
# We have the complete object instance already. We don't need
# to run the 'only get the pk for this relationship' code.
return False
class SlugRelatedField(RelatedField):
"""
A read-write field that represents the target of the relationship
by a unique 'slug' attribute.
"""
default_error_messages = {
'does_not_exist': _('Object with {slug_name}={value} does not exist.'),
'invalid': _('Invalid value.'),
}
def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
self.slug_field = slug_field
super().__init__(**kwargs)
def to_internal_value(self, data):
queryset = self.get_queryset()
try:
return queryset.get(**{self.slug_field: data})
except ObjectDoesNotExist:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_str(data))
except (TypeError, ValueError):
self.fail('invalid')
def to_representation(self, obj):
slug = self.slug_field
if "__" in slug:
# handling nested relationship if defined
slug = slug.replace('__', '.')
return attrgetter(slug)(obj)
class ManyRelatedField(Field):
"""
Relationships with `many=True` transparently get coerced into instead being
a ManyRelatedField with a child relationship.
The `ManyRelatedField` class is responsible for handling iterating through
the values and passing each one to the child relationship.
This class is treated as private API.
You shouldn't generally need to be using this class directly yourself,
and should instead simply set 'many=True' on the relationship.
"""
initial = []
default_empty_html = []
default_error_messages = {
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.')
}
html_cutoff = None
html_cutoff_text = None
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
self.allow_empty = kwargs.pop('allow_empty', True)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
)
assert child_relation is not None, '`child_relation` is a required argument.'
super().__init__(*args, **kwargs)
self.child_relation.bind(field_name='', parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
# lists in HTML forms.
if html.is_html_input(dictionary):
# Don't return [] if the update is partial
if self.field_name not in dictionary:
if getattr(self.root, 'partial', False):
return empty
return dictionary.getlist(self.field_name)
return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
if isinstance(data, str) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail('empty')
return [
self.child_relation.to_internal_value(item)
for item in data
]
def get_attribute(self, instance):
# Can't have any relationships if not created
if hasattr(instance, 'pk') and instance.pk is None:
return []
try:
relationship = get_attribute(instance, self.source_attrs)
except (KeyError, AttributeError) as exc:
if self.default is not empty:
return self.get_default()
if self.allow_null:
return None
if not self.required:
raise SkipField()
msg = (
'Got {exc_type} when attempting to get a value for field '
'`{field}` on serializer `{serializer}`.\nThe serializer '
'field might be named incorrectly and not match '
'any attribute or key on the `{instance}` instance.\n'
'Original exception text was: {exc}.'.format(
exc_type=type(exc).__name__,
field=self.field_name,
serializer=self.parent.__class__.__name__,
instance=instance.__class__.__name__,
exc=exc
)
)
raise type(exc)(msg)
return relationship.all() if hasattr(relationship, 'all') else relationship
def to_representation(self, iterable):
return [
self.child_relation.to_representation(value)
for value in iterable
]
def get_choices(self, cutoff=None):
return self.child_relation.get_choices(cutoff)
@property
def choices(self):
return self.get_choices()
@property
def grouped_choices(self):
return self.choices
def iter_options(self):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,449 @@
"""
The Request class is used as a wrapper around the standard request object.
The wrapped request then offers a richer API, in particular :
- content automatically parsed according to `Content-Type` header,
and available as `request.data`
- full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content
"""
import io
import sys
from contextlib import contextmanager
from django.conf import settings
from django.http import HttpRequest, QueryDict
from django.http.request import RawPostDataException
from django.utils.datastructures import MultiValueDict
from django.utils.http import parse_header_parameters
from rest_framework import exceptions
from rest_framework.settings import api_settings
def is_form_media_type(media_type):
"""
Return True if the media type is a valid form media type.
"""
base_media_type, params = parse_header_parameters(media_type)
return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data')
class override_method:
"""
A context manager that temporarily overrides the method on a request,
additionally setting the `view.request` attribute.
Usage:
with override_method(view, request, 'POST') as request:
... # Do stuff with `view` and `request`
"""
def __init__(self, view, request, method):
self.view = view
self.request = request
self.method = method
self.action = getattr(view, 'action', None)
def __enter__(self):
self.view.request = clone_request(self.request, self.method)
# For viewsets we also set the `.action` attribute.
action_map = getattr(self.view, 'action_map', {})
self.view.action = action_map.get(self.method.lower())
return self.view.request
def __exit__(self, *args, **kwarg):
self.view.request = self.request
self.view.action = self.action
class WrappedAttributeError(Exception):
pass
@contextmanager
def wrap_attributeerrors():
"""
Used to re-raise AttributeErrors caught during authentication, preventing
these errors from otherwise being handled by the attribute access protocol.
"""
try:
yield
except AttributeError:
info = sys.exc_info()
exc = WrappedAttributeError(str(info[1]))
raise exc.with_traceback(info[2])
class Empty:
"""
Placeholder for unset attributes.
Cannot use `None`, as that may be a valid value.
"""
pass
def _hasattr(obj, name):
return not getattr(obj, name) is Empty
def clone_request(request, method):
"""
Internal helper method to clone a request, replacing with a different
HTTP method. Used for checking permissions against other methods.
"""
ret = Request(request=request._request,
parsers=request.parsers,
authenticators=request.authenticators,
negotiator=request.negotiator,
parser_context=request.parser_context)
ret._data = request._data
ret._files = request._files
ret._full_data = request._full_data
ret._content_type = request._content_type
ret._stream = request._stream
ret.method = method
if hasattr(request, '_user'):
ret._user = request._user
if hasattr(request, '_auth'):
ret._auth = request._auth
if hasattr(request, '_authenticator'):
ret._authenticator = request._authenticator
if hasattr(request, 'accepted_renderer'):
ret.accepted_renderer = request.accepted_renderer
if hasattr(request, 'accepted_media_type'):
ret.accepted_media_type = request.accepted_media_type
if hasattr(request, 'version'):
ret.version = request.version
if hasattr(request, 'versioning_scheme'):
ret.versioning_scheme = request.versioning_scheme
return ret
class ForcedAuthentication:
"""
This authentication class is used if the test client or request factory
forcibly authenticated the request.
"""
def __init__(self, force_user, force_token):
self.force_user = force_user
self.force_token = force_token
def authenticate(self, request):
return (self.force_user, self.force_token)
class Request:
"""
Wrapper allowing to enhance a standard `HttpRequest` instance.
Kwargs:
- request(HttpRequest). The original request instance.
- parsers(list/tuple). The parsers to use for parsing the
request content.
- authenticators(list/tuple). The authenticators used to try
authenticating the request's user.
"""
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
assert isinstance(request, HttpRequest), (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `{}.{}`.'
.format(request.__class__.__module__, request.__class__.__name__)
)
self._request = request
self.parsers = parsers or ()
self.authenticators = authenticators or ()
self.negotiator = negotiator or self._default_negotiator()
self.parser_context = parser_context
self._data = Empty
self._files = Empty
self._full_data = Empty
self._content_type = Empty
self._stream = Empty
if self.parser_context is None:
self.parser_context = {}
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None)
force_token = getattr(request, '_force_auth_token', None)
if force_user is not None or force_token is not None:
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
def __repr__(self):
return '<%s.%s: %s %r>' % (
self.__class__.__module__,
self.__class__.__name__,
self.method,
self.get_full_path())
# Allow generic typing checking for requests.
def __class_getitem__(cls, *args, **kwargs):
return cls
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
@property
def content_type(self):
meta = self._request.META
return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
@property
def stream(self):
"""
Returns an object that may be used to stream the request content.
"""
if not _hasattr(self, '_stream'):
self._load_stream()
return self._stream
@property
def query_params(self):
"""
More semantically correct name for request.GET.
"""
return self._request.GET
@property
def data(self):
if not _hasattr(self, '_full_data'):
with wrap_attributeerrors():
self._load_data_and_files()
return self._full_data
@property
def user(self):
"""
Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
with wrap_attributeerrors():
self._authenticate()
return self._user
@user.setter
def user(self, value):
"""
Sets the user on the current request. This is necessary to maintain
compatibility with django.contrib.auth where the user property is
set in the login and logout functions.
Note that we also set the user on Django's underlying `HttpRequest`
instance, ensuring that it is available to any middleware in the stack.
"""
self._user = value
self._request.user = value
@property
def auth(self):
"""
Returns any non-user authentication information associated with the
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
with wrap_attributeerrors():
self._authenticate()
return self._auth
@auth.setter
def auth(self, value):
"""
Sets any non-user authentication information associated with the
request, such as an authentication token.
"""
self._auth = value
self._request.auth = value
@property
def successful_authenticator(self):
"""
Return the instance of the authentication instance class that was used
to authenticate the request, or `None`.
"""
if not hasattr(self, '_authenticator'):
with wrap_attributeerrors():
self._authenticate()
return self._authenticator
def _load_data_and_files(self):
"""
Parses the request content into `self.data`.
"""
if not _hasattr(self, '_data'):
self._data, self._files = self._parse()
if self._files:
self._full_data = self._data.copy()
self._full_data.update(self._files)
else:
self._full_data = self._data
# if a form media type, copy data & files refs to the underlying
# http request so that closable objects are handled appropriately.
if is_form_media_type(self.content_type):
self._request._post = self.POST
self._request._files = self.FILES
def _load_stream(self):
"""
Return the content body of the request, as a stream.
"""
meta = self._request.META
try:
content_length = int(
meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))
)
except (ValueError, TypeError):
content_length = 0
if content_length == 0:
self._stream = None
elif not self._request._read_started:
self._stream = self._request
else:
self._stream = io.BytesIO(self.body)
def _supports_form_parsing(self):
"""
Return True if this requests supports parsing form data.
"""
form_media = (
'application/x-www-form-urlencoded',
'multipart/form-data'
)
return any(parser.media_type in form_media for parser in self.parsers)
def _parse(self):
"""
Parse the request content, returning a two-tuple of (data, files)
May raise an `UnsupportedMediaType`, or `ParseError` exception.
"""
media_type = self.content_type
try:
stream = self.stream
except RawPostDataException:
if not hasattr(self._request, '_post'):
raise
# If request.POST has been accessed in middleware, and a method='POST'
# request was made with 'multipart/form-data', then the request stream
# will already have been exhausted.
if self._supports_form_parsing():
return (self._request.POST, self._request.FILES)
stream = None
if stream is None or media_type is None:
if media_type and is_form_media_type(media_type):
empty_data = QueryDict('', encoding=self._request._encoding)
else:
empty_data = {}
empty_files = MultiValueDict()
return (empty_data, empty_files)
parser = self.negotiator.select_parser(self, self.parsers)
if not parser:
raise exceptions.UnsupportedMediaType(media_type)
try:
parsed = parser.parse(stream, media_type, self.parser_context)
except Exception:
# If we get an exception during parsing, fill in empty data and
# re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when
# logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding)
self._files = MultiValueDict()
self._full_data = self._data
raise
# Parser classes may return the raw data, or a
# DataAndFiles object. Unpack the result as required.
try:
return (parsed.data, parsed.files)
except AttributeError:
empty_files = MultiValueDict()
return (parsed, empty_files)
def _authenticate(self):
"""
Attempt to authenticate the request using each authentication instance
in turn.
"""
for authenticator in self.authenticators:
try:
user_auth_tuple = authenticator.authenticate(self)
except exceptions.APIException:
self._not_authenticated()
raise
if user_auth_tuple is not None:
self._authenticator = authenticator
self.user, self.auth = user_auth_tuple
return
self._not_authenticated()
def _not_authenticated(self):
"""
Set authenticator, user & authtoken representing an unauthenticated request.
Defaults are None, AnonymousUser & None.
"""
self._authenticator = None
if api_settings.UNAUTHENTICATED_USER:
self.user = api_settings.UNAUTHENTICATED_USER()
else:
self.user = None
if api_settings.UNAUTHENTICATED_TOKEN:
self.auth = api_settings.UNAUTHENTICATED_TOKEN()
else:
self.auth = None
def __getattr__(self, attr):
"""
If an attribute does not exist on this instance, then we also attempt
to proxy it to the underlying HttpRequest object.
"""
try:
_request = self.__getattribute__("_request")
return getattr(_request, attr)
except AttributeError:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attr}'")
@property
def POST(self):
# Ensure that request.POST uses our request parsing.
if not _hasattr(self, '_data'):
with wrap_attributeerrors():
self._load_data_and_files()
if is_form_media_type(self.content_type):
return self._data
return QueryDict('', encoding=self._request._encoding)
@property
def FILES(self):
# Leave this one alone for backwards compat with Django's request.FILES
# Different from the other two cases, which are not valid property
# names on the WSGIRequest class.
if not _hasattr(self, '_files'):
with wrap_attributeerrors():
self._load_data_and_files()
return self._files
def force_plaintext_errors(self, value):
# Hack to allow our exception handler to force choice of
# plaintext or html error responses.
self._request.is_ajax = lambda: value

View File

@@ -0,0 +1,107 @@
"""
The Response class in REST framework is similar to HTTPResponse, except that
it is initialized with unrendered data, instead of a pre-rendered string.
The appropriate renderer is called during Django's template response rendering.
"""
from http.client import responses
from django.template.response import SimpleTemplateResponse
from rest_framework.serializers import Serializer
class Response(SimpleTemplateResponse):
"""
An HttpResponse that allows its data to be rendered into
arbitrary media types.
"""
def __init__(self, data=None, status=None,
template_name=None, headers=None,
exception=False, content_type=None):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
Setting 'renderer' and 'media_type' will typically be deferred,
For example being set automatically by the `APIView`.
"""
super().__init__(None, status=status)
if isinstance(data, Serializer):
msg = (
'You passed a Serializer instance as data, but '
'probably meant to pass serialized `.data` or '
'`.error`. representation.'
)
raise AssertionError(msg)
self.data = data
self.template_name = template_name
self.exception = exception
self.content_type = content_type
if headers:
for name, value in headers.items():
self[name] = value
# Allow generic typing checking for responses.
def __class_getitem__(cls, *args, **kwargs):
return cls
@property
def rendered_content(self):
renderer = getattr(self, 'accepted_renderer', None)
accepted_media_type = getattr(self, 'accepted_media_type', None)
context = getattr(self, 'renderer_context', None)
assert renderer, ".accepted_renderer not set on Response"
assert accepted_media_type, ".accepted_media_type not set on Response"
assert context is not None, ".renderer_context not set on Response"
context['response'] = self
media_type = renderer.media_type
charset = renderer.charset
content_type = self.content_type
if content_type is None and charset is not None:
content_type = "{}; charset={}".format(media_type, charset)
elif content_type is None:
content_type = media_type
self['Content-Type'] = content_type
ret = renderer.render(self.data, accepted_media_type, context)
if isinstance(ret, str):
assert charset, (
'renderer returned unicode, and did not specify '
'a charset value.'
)
return ret.encode(charset)
if not ret:
del self['Content-Type']
return ret
@property
def status_text(self):
"""
Returns reason text corresponding to our HTTP response status code.
Provided for convenience.
"""
return responses.get(self.status_code, '')
def __getstate__(self):
"""
Remove attributes from the response that shouldn't be cached.
"""
state = super().__getstate__()
for key in (
'accepted_renderer', 'renderer_context', 'resolver_match',
'client', 'request', 'json', 'wsgi_request'
):
if key in state:
del state[key]
state['_closable_objects'] = []
return state

View File

@@ -0,0 +1,66 @@
"""
Provide urlresolver functions that return fully qualified URLs or view names
"""
from django.urls import NoReverseMatch
from django.urls import reverse as django_reverse
from django.utils.functional import lazy
from rest_framework.settings import api_settings
from rest_framework.utils.urls import replace_query_param
def preserve_builtin_query_params(url, request=None):
"""
Given an incoming request, and an outgoing URL representation,
append the value of any built-in query parameters.
"""
if request is None:
return url
overrides = [
api_settings.URL_FORMAT_OVERRIDE,
]
for param in overrides:
if param and (param in request.GET):
value = request.GET[param]
url = replace_query_param(url, param, value)
return url
def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
If versioning is being used then we pass any `reverse` calls through
to the versioning scheme instance, so that the resulting URL
can be modified if needed.
"""
scheme = getattr(request, 'versioning_scheme', None)
if scheme is not None:
try:
url = scheme.reverse(viewname, args, kwargs, request, format, **extra)
except NoReverseMatch:
# In case the versioning scheme reversal fails, fallback to the
# default implementation
url = _reverse(viewname, args, kwargs, request, format, **extra)
else:
url = _reverse(viewname, args, kwargs, request, format, **extra)
return preserve_builtin_query_params(url, request)
def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra):
"""
Same as `django.urls.reverse`, but optionally takes a request
and returns a fully qualified URL, using the request to get the base URL.
"""
if format is not None:
kwargs = kwargs or {}
kwargs['format'] = format
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)
return url
reverse_lazy = lazy(reverse, str)

View File

@@ -0,0 +1,390 @@
"""
Routers provide a convenient and consistent way of automatically
determining the URL conf for your API.
They are used by simply instantiating a Router class, and then registering
all the required ViewSets with that router.
For example, you might have a `urls.py` that looks something like this:
router = routers.DefaultRouter()
router.register('users', UserViewSet, 'user')
router.register('accounts', AccountViewSet, 'account')
urlpatterns = router.urls
"""
import itertools
from collections import namedtuple
from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch, path, re_path
from rest_framework import views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
def escape_curly_brackets(url_path):
"""
Double brackets in regex of url_path for escape string formatting
"""
return url_path.replace('{', '{{').replace('}', '}}')
def flatten(list_of_lists):
"""
Takes an iterable of iterables, returns a single iterable containing all items
"""
return itertools.chain(*list_of_lists)
class BaseRouter:
def __init__(self):
self.registry = []
def register(self, prefix, viewset, basename=None):
if basename is None:
basename = self.get_default_basename(viewset)
if self.is_already_registered(basename):
msg = (f'Router with basename "{basename}" is already registered. '
f'Please provide a unique basename for viewset "{viewset}"')
raise ImproperlyConfigured(msg)
self.registry.append((prefix, viewset, basename))
# invalidate the urls cache
if hasattr(self, '_urls'):
del self._urls
def is_already_registered(self, new_basename):
"""
Check if `basename` is already registered
"""
return any(basename == new_basename for _prefix, _viewset, basename in self.registry)
def get_default_basename(self, viewset):
"""
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
raise NotImplementedError('get_default_basename must be overridden')
def get_urls(self):
"""
Return a list of URL patterns, given the registered viewsets.
"""
raise NotImplementedError('get_urls must be overridden')
@property
def urls(self):
if not hasattr(self, '_urls'):
self._urls = self.get_urls()
return self._urls
class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
'post': 'create'
},
name='{basename}-list',
detail=False,
initkwargs={'suffix': 'List'}
),
# Dynamically generated list routes. Generated using
# @action(detail=False) decorator on methods of the viewset.
DynamicRoute(
url=r'^{prefix}/{url_path}{trailing_slash}$',
name='{basename}-{url_name}',
detail=False,
initkwargs={}
),
# Detail route.
Route(
url=r'^{prefix}/{lookup}{trailing_slash}$',
mapping={
'get': 'retrieve',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy'
},
name='{basename}-detail',
detail=True,
initkwargs={'suffix': 'Instance'}
),
# Dynamically generated detail routes. Generated using
# @action(detail=True) decorator on methods of the viewset.
DynamicRoute(
url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
name='{basename}-{url_name}',
detail=True,
initkwargs={}
),
]
def __init__(self, trailing_slash=True, use_regex_path=True):
self.trailing_slash = '/' if trailing_slash else ''
self._use_regex = use_regex_path
if use_regex_path:
self._base_pattern = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
self._default_value_pattern = '[^/.]+'
self._url_conf = re_path
else:
self._base_pattern = '<{lookup_value}:{lookup_prefix}{lookup_url_kwarg}>'
self._default_value_pattern = 'str'
self._url_conf = path
# remove regex characters from routes
_routes = []
for route in self.routes:
url_param = route.url
if url_param[0] == '^':
url_param = url_param[1:]
if url_param[-1] == '$':
url_param = url_param[:-1]
_routes.append(route._replace(url=url_param))
self.routes = _routes
super().__init__()
def get_default_basename(self, viewset):
"""
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
queryset = getattr(viewset, 'queryset', None)
assert queryset is not None, '`basename` argument not specified, and could ' \
'not automatically determine the name from the viewset, as ' \
'it does not have a `.queryset` attribute.'
return queryset.model._meta.object_name.lower()
def get_routes(self, viewset):
"""
Augment `self.routes` with any dynamically generated routes.
Returns a list of the Route namedtuple.
"""
# converting to list as iterables are good for one pass, known host needs to be checked again and again for
# different functions.
known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
extra_actions = viewset.get_extra_actions()
# checking action names against the known actions list
not_allowed = [
action.__name__ for action in extra_actions
if action.__name__ in known_actions
]
if not_allowed:
msg = ('Cannot use the @action decorator on the following '
'methods, as they are existing routes: %s')
raise ImproperlyConfigured(msg % ', '.join(not_allowed))
# partition detail and list actions
detail_actions = [action for action in extra_actions if action.detail]
list_actions = [action for action in extra_actions if not action.detail]
routes = []
for route in self.routes:
if isinstance(route, DynamicRoute) and route.detail:
routes += [self._get_dynamic_route(route, action) for action in detail_actions]
elif isinstance(route, DynamicRoute) and not route.detail:
routes += [self._get_dynamic_route(route, action) for action in list_actions]
else:
routes.append(route)
return routes
def _get_dynamic_route(self, route, action):
initkwargs = route.initkwargs.copy()
initkwargs.update(action.kwargs)
url_path = escape_curly_brackets(action.url_path)
return Route(
url=route.url.replace('{url_path}', url_path),
mapping=action.mapping,
name=route.name.replace('{url_name}', action.url_name),
detail=route.detail,
initkwargs=initkwargs,
)
def get_method_map(self, viewset, method_map):
"""
Given a viewset, and a mapping of http methods to actions,
return a new mapping which only includes any mappings that
are actually implemented by the viewset.
"""
bound_methods = {}
for method, action in method_map.items():
if hasattr(viewset, action):
bound_methods[method] = action
return bound_methods
def get_lookup_regex(self, viewset, lookup_prefix=''):
"""
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
Note that lookup_prefix is not used directly inside REST rest_framework
itself, but is required in order to nicely support nested router
implementations, such as drf-nested-routers.
https://github.com/alanjds/drf-nested-routers
"""
# Use `pk` as default field, unset set. Default regex should not
# consume `.json` style suffixes and should break at '/' boundaries.
lookup_field = getattr(viewset, 'lookup_field', 'pk')
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
lookup_value = None
if not self._use_regex:
# try to get a more appropriate attribute when not using regex
lookup_value = getattr(viewset, 'lookup_value_converter', None)
if lookup_value is None:
# fallback to legacy
lookup_value = getattr(viewset, 'lookup_value_regex', self._default_value_pattern)
return self._base_pattern.format(
lookup_prefix=lookup_prefix,
lookup_url_kwarg=lookup_url_kwarg,
lookup_value=lookup_value
)
def get_urls(self):
"""
Use the registered viewsets to generate a list of URL patterns.
"""
ret = []
for prefix, viewset, basename in self.registry:
lookup = self.get_lookup_regex(viewset)
routes = self.get_routes(viewset)
for route in routes:
# Only actions which actually exist on the viewset will be bound
mapping = self.get_method_map(viewset, route.mapping)
if not mapping:
continue
# Build the url pattern
regex = route.url.format(
prefix=prefix,
lookup=lookup,
trailing_slash=self.trailing_slash
)
# If there is no prefix, the first part of the url is probably
# controlled by project's urls.py and the router is in an app,
# so a slash in the beginning will (A) cause Django to give
# warnings and (B) generate URLS that will require using '//'.
if not prefix:
if self._url_conf is path:
if regex[0] == '/':
regex = regex[1:]
elif regex[:2] == '^/':
regex = '^' + regex[2:]
initkwargs = route.initkwargs.copy()
initkwargs.update({
'basename': basename,
'detail': route.detail,
})
view = viewset.as_view(mapping, **initkwargs)
name = route.name.format(basename=basename)
ret.append(self._url_conf(regex, view, name=name))
return ret
class APIRootView(views.APIView):
"""
The default basic root view for DefaultRouter
"""
_ignore_model_permissions = True
schema = None # exclude from schema
api_root_dict = None
def get(self, request, *args, **kwargs):
# Return a plain {"name": "hyperlink"} response.
ret = {}
namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items():
if namespace:
url_name = namespace + ':' + url_name
try:
ret[key] = reverse(
url_name,
args=args,
kwargs=kwargs,
request=request,
format=kwargs.get('format')
)
except NoReverseMatch:
# Don't bail out if eg. no list routes exist, only detail routes.
continue
return Response(ret)
class DefaultRouter(SimpleRouter):
"""
The default router extends the SimpleRouter, but also adds in a default
API root view, and adds format suffix patterns to the URLs.
"""
include_root_view = True
include_format_suffixes = True
root_view_name = 'api-root'
default_schema_renderers = None
APIRootView = APIRootView
APISchemaView = SchemaView
SchemaGenerator = SchemaGenerator
def __init__(self, *args, **kwargs):
if 'root_renderers' in kwargs:
self.root_renderers = kwargs.pop('root_renderers')
else:
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
super().__init__(*args, **kwargs)
def get_api_root_view(self, api_urls=None):
"""
Return a basic root view.
"""
api_root_dict = {}
list_name = self.routes[0].name
for prefix, viewset, basename in self.registry:
api_root_dict[prefix] = list_name.format(basename=basename)
return self.APIRootView.as_view(api_root_dict=api_root_dict)
def get_urls(self):
"""
Generate the list of URL patterns, including a default root view
for the API, and appending `.json` style format suffixes.
"""
urls = super().get_urls()
if self.include_root_view:
view = self.get_api_root_view(api_urls=urls)
root_url = path('', view, name=self.root_view_name)
urls.append(root_url)
if self.include_format_suffixes:
urls = format_suffix_patterns(urls)
return urls

View File

@@ -0,0 +1,58 @@
"""
rest_framework.schemas
schemas:
__init__.py
generators.py # Top-down schema generation
inspectors.py # Per-endpoint view introspection
utils.py # Shared helper functions
views.py # Houses `SchemaView`, `APIView` subclass.
We expose a minimal "public" API directly from `schemas`. This covers the
basic use-cases:
from rest_framework.schemas import (
AutoSchema,
ManualSchema,
get_schema_view,
SchemaGenerator,
)
Other access should target the submodules directly
"""
from rest_framework.settings import api_settings
from . import coreapi, openapi
from .coreapi import AutoSchema, ManualSchema, SchemaGenerator # noqa
from .inspectors import DefaultSchema # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
version=None):
"""
Return a schema view.
"""
if generator_class is None:
if coreapi.is_enabled():
generator_class = coreapi.SchemaGenerator
else:
generator_class = openapi.SchemaGenerator
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns, version=version
)
# Avoid import cycle on APIView
from .views import SchemaView
return SchemaView.as_view(
renderer_classes=renderer_classes,
schema_generator=generator,
public=public,
authentication_classes=authentication_classes,
permission_classes=permission_classes,
)

Some files were not shown because too many files have changed in this diff Show More