new models, frontend functions, public pages
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
__all__ = [
|
||||
'django_oauth_toolkit',
|
||||
'djangorestframework_camel_case',
|
||||
'rest_auth',
|
||||
'rest_framework',
|
||||
'rest_polymorphic',
|
||||
'rest_framework_dataclasses',
|
||||
'rest_framework_jwt',
|
||||
'rest_framework_simplejwt',
|
||||
'django_filters',
|
||||
'rest_framework_recursive',
|
||||
'rest_framework_gis',
|
||||
'pydantic',
|
||||
'knox_auth_token',
|
||||
]
|
||||
@@ -0,0 +1,305 @@
|
||||
from django.db import models
|
||||
|
||||
from drf_spectacular.drainage import add_trace_message, get_override, has_override, warn
|
||||
from drf_spectacular.extensions import OpenApiFilterExtension
|
||||
from drf_spectacular.plumbing import (
|
||||
build_array_type, build_basic_type, build_choice_description_list, build_parameter_type,
|
||||
follow_field_source, force_instance, get_manager, get_type_hints, get_view_model, is_basic_type,
|
||||
is_field,
|
||||
)
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter
|
||||
|
||||
_NoHint = object()
|
||||
|
||||
|
||||
class DjangoFilterExtension(OpenApiFilterExtension):
|
||||
"""
|
||||
Extensions that specifically deals with ``django-filter`` fields. The introspection
|
||||
attempts to estimate the underlying model field types to generate filter types.
|
||||
|
||||
However, there are under-specified filter fields for which heuristics need to be performed.
|
||||
This serves as an explicit list of all partially-handled filter fields:
|
||||
|
||||
- ``AllValuesFilter``: skip choices to prevent DB query
|
||||
- ``AllValuesMultipleFilter``: skip choices to prevent DB query, multi handled though
|
||||
- ``ChoiceFilter``: enum handled, type under-specified
|
||||
- ``DateRangeFilter``: N/A
|
||||
- ``LookupChoiceFilter``: N/A
|
||||
- ``ModelChoiceFilter``: enum handled
|
||||
- ``ModelMultipleChoiceFilter``: enum, multi handled
|
||||
- ``MultipleChoiceFilter``: enum, multi handled
|
||||
- ``RangeFilter``: min/max handled, type under-specified
|
||||
- ``TypedChoiceFilter``: enum handled
|
||||
- ``TypedMultipleChoiceFilter``: enum, multi handled
|
||||
|
||||
In case of warnings or incorrect filter types, you can manually override the underlying
|
||||
field type with a manual ``extend_schema_field`` decoration. Alternatively, if you have a
|
||||
filter method for your filter field, you can attach ``extend_schema_field`` to that filter
|
||||
method.
|
||||
|
||||
.. code-block::
|
||||
|
||||
class SomeFilter(FilterSet):
|
||||
some_field = extend_schema_field(OpenApiTypes.NUMBER)(
|
||||
RangeFilter(field_name='some_manually_annotated_field_in_qs')
|
||||
)
|
||||
|
||||
"""
|
||||
target_class = 'django_filters.rest_framework.DjangoFilterBackend'
|
||||
match_subclasses = True
|
||||
|
||||
def get_schema_operation_parameters(self, auto_schema, *args, **kwargs):
|
||||
model = get_view_model(auto_schema.view)
|
||||
if not model:
|
||||
return []
|
||||
|
||||
filterset_class = self.target.get_filterset_class(auto_schema.view, get_manager(model).none())
|
||||
if not filterset_class:
|
||||
return []
|
||||
|
||||
result = []
|
||||
with add_trace_message(filterset_class):
|
||||
for field_name, filter_field in filterset_class.base_filters.items():
|
||||
result += self.resolve_filter_field(
|
||||
auto_schema, model, filterset_class, field_name, filter_field
|
||||
)
|
||||
return result
|
||||
|
||||
def resolve_filter_field(self, auto_schema, model, filterset_class, field_name, filter_field):
|
||||
from django_filters import filters
|
||||
|
||||
unambiguous_mapping = {
|
||||
filters.CharFilter: OpenApiTypes.STR,
|
||||
filters.BooleanFilter: OpenApiTypes.BOOL,
|
||||
filters.DateFilter: OpenApiTypes.DATE,
|
||||
filters.DateTimeFilter: OpenApiTypes.DATETIME,
|
||||
filters.IsoDateTimeFilter: OpenApiTypes.DATETIME,
|
||||
filters.TimeFilter: OpenApiTypes.TIME,
|
||||
filters.UUIDFilter: OpenApiTypes.UUID,
|
||||
filters.DurationFilter: OpenApiTypes.DURATION,
|
||||
filters.OrderingFilter: OpenApiTypes.STR,
|
||||
filters.TimeRangeFilter: OpenApiTypes.TIME,
|
||||
filters.DateFromToRangeFilter: OpenApiTypes.DATE,
|
||||
filters.IsoDateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
|
||||
filters.DateTimeFromToRangeFilter: OpenApiTypes.DATETIME,
|
||||
}
|
||||
filter_method = self._get_filter_method(filterset_class, filter_field)
|
||||
filter_method_hint = self._get_filter_method_hint(filter_method)
|
||||
filter_choices = self._get_explicit_filter_choices(filter_field)
|
||||
schema_from_override = False
|
||||
|
||||
if has_override(filter_field, 'field') or has_override(filter_method, 'field'):
|
||||
schema_from_override = True
|
||||
annotation = (
|
||||
get_override(filter_field, 'field') or get_override(filter_method, 'field')
|
||||
)
|
||||
if is_basic_type(annotation):
|
||||
schema = build_basic_type(annotation)
|
||||
elif isinstance(annotation, dict):
|
||||
# allow injecting raw schema via @extend_schema_field decorator
|
||||
schema = annotation.copy()
|
||||
elif is_field(annotation):
|
||||
schema = auto_schema._map_serializer_field(force_instance(annotation), "request")
|
||||
else:
|
||||
warn(
|
||||
f"Unsupported annotation {annotation} on filter field {field_name}. defaulting to string."
|
||||
)
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
elif filter_method_hint is not _NoHint:
|
||||
if is_basic_type(filter_method_hint):
|
||||
schema = build_basic_type(filter_method_hint)
|
||||
else:
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
elif isinstance(filter_field, tuple(unambiguous_mapping)):
|
||||
for cls in filter_field.__class__.__mro__:
|
||||
if cls in unambiguous_mapping:
|
||||
schema = build_basic_type(unambiguous_mapping[cls])
|
||||
break
|
||||
elif isinstance(filter_field, (filters.NumberFilter, filters.NumericRangeFilter)):
|
||||
# NumberField is underspecified by itself. try to find the
|
||||
# type that makes the most sense or default to generic NUMBER
|
||||
model_field = self._get_model_field(filter_field, model)
|
||||
if isinstance(model_field, (models.IntegerField, models.AutoField)):
|
||||
schema = build_basic_type(OpenApiTypes.INT)
|
||||
elif isinstance(model_field, models.FloatField):
|
||||
schema = build_basic_type(OpenApiTypes.FLOAT)
|
||||
elif isinstance(model_field, models.DecimalField):
|
||||
schema = build_basic_type(OpenApiTypes.NUMBER) # TODO may be improved
|
||||
else:
|
||||
schema = build_basic_type(OpenApiTypes.NUMBER)
|
||||
elif isinstance(filter_field, (filters.ChoiceFilter, filters.MultipleChoiceFilter)):
|
||||
try:
|
||||
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
|
||||
except Exception:
|
||||
if filter_choices and is_basic_type(type(filter_choices[0])):
|
||||
# fallback to type guessing from first choice element
|
||||
schema = build_basic_type(type(filter_choices[0]))
|
||||
else:
|
||||
warn(
|
||||
f'Unable to guess choice types from values, filter method\'s type hint '
|
||||
f'or find "{field_name}" in model. Defaulting to string.'
|
||||
)
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
else:
|
||||
# the last resort is to look up the type via the model or queryset field
|
||||
# and emit a warning if we were unsuccessful.
|
||||
try:
|
||||
schema = self._get_schema_from_model_field(auto_schema, filter_field, model)
|
||||
except Exception as exc: # pragma: no cover
|
||||
warn(
|
||||
f'Exception raised while trying resolve model field for django-filter '
|
||||
f'field "{field_name}". Defaulting to string (Exception: {exc})'
|
||||
)
|
||||
schema = build_basic_type(OpenApiTypes.STR)
|
||||
|
||||
# primary keys are usually non-editable (readOnly=True) and map_model_field correctly
|
||||
# signals that attribute. however this does not apply in this context.
|
||||
schema.pop('readOnly', None)
|
||||
# enrich schema with additional info from filter_field
|
||||
enum = schema.pop('enum', None)
|
||||
# explicit filter choices may disable enum retrieved from model
|
||||
if not schema_from_override and filter_choices is not None:
|
||||
enum = filter_choices
|
||||
|
||||
description = schema.pop('description', None)
|
||||
if not schema_from_override:
|
||||
description = self._get_field_description(filter_field, description)
|
||||
|
||||
# parameter style variations based on filter base class
|
||||
if isinstance(filter_field, filters.BaseCSVFilter):
|
||||
schema = build_array_type(schema)
|
||||
field_names = [field_name]
|
||||
explode = False
|
||||
style = 'form'
|
||||
elif isinstance(filter_field, filters.MultipleChoiceFilter):
|
||||
schema = build_array_type(schema)
|
||||
field_names = [field_name]
|
||||
explode = True
|
||||
style = 'form'
|
||||
elif isinstance(filter_field, (filters.RangeFilter, filters.NumericRangeFilter)):
|
||||
try:
|
||||
suffixes = filter_field.field_class.widget.suffixes
|
||||
except AttributeError:
|
||||
suffixes = ['min', 'max']
|
||||
field_names = [
|
||||
f'{field_name}_{suffix}' if suffix else field_name for suffix in suffixes
|
||||
]
|
||||
explode = None
|
||||
style = None
|
||||
else:
|
||||
field_names = [field_name]
|
||||
explode = None
|
||||
style = None
|
||||
|
||||
return [
|
||||
build_parameter_type(
|
||||
name=field_name,
|
||||
required=filter_field.extra['required'],
|
||||
location=OpenApiParameter.QUERY,
|
||||
description=description,
|
||||
schema=schema,
|
||||
enum=enum,
|
||||
explode=explode,
|
||||
style=style
|
||||
)
|
||||
for field_name in field_names
|
||||
]
|
||||
|
||||
def _get_filter_method(self, filterset_class, filter_field):
|
||||
if callable(filter_field.method):
|
||||
return filter_field.method
|
||||
elif isinstance(filter_field.method, str):
|
||||
return getattr(filterset_class, filter_field.method)
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_filter_method_hint(self, filter_method):
|
||||
try:
|
||||
return get_type_hints(filter_method)['value']
|
||||
except: # noqa: E722
|
||||
return _NoHint
|
||||
|
||||
def _get_explicit_filter_choices(self, filter_field):
|
||||
if 'choices' not in filter_field.extra:
|
||||
return None
|
||||
elif callable(filter_field.extra['choices']):
|
||||
# choices function may utilize the DB, so refrain from actually calling it.
|
||||
return []
|
||||
else:
|
||||
return [c for c, _ in filter_field.extra['choices']]
|
||||
|
||||
def _get_model_field(self, filter_field, model):
|
||||
if not filter_field.field_name:
|
||||
return None
|
||||
path = filter_field.field_name.split('__')
|
||||
return follow_field_source(model, path, emit_warnings=False)
|
||||
|
||||
def _get_schema_from_model_field(self, auto_schema, filter_field, model):
|
||||
# Has potential to throw exceptions. Needs to be wrapped in try/except!
|
||||
#
|
||||
# first search for the field in the model as this has the least amount of
|
||||
# potential side effects. Only after that fails, attempt to call
|
||||
# get_queryset() to check for potential query annotations.
|
||||
model_field = self._get_model_field(filter_field, model)
|
||||
|
||||
# this is a cross feature between rest-framework-gis and django-filter. Regular
|
||||
# behavior needs to be sidestepped as the model information is lost down the line.
|
||||
# TODO for now this will be just a string to cover WKT, WKB, and urlencoded GeoJSON
|
||||
# build_geo_schema(model_field) would yield the correct result
|
||||
if self._is_gis(model_field):
|
||||
return build_basic_type(OpenApiTypes.STR)
|
||||
|
||||
if not isinstance(model_field, models.Field):
|
||||
qs = auto_schema.view.get_queryset()
|
||||
model_field = qs.query.annotations[filter_field.field_name].field
|
||||
return auto_schema._map_model_field(model_field, direction=None)
|
||||
|
||||
def _get_field_description(self, filter_field, description):
|
||||
# Try to improve description beyond auto-generated model description
|
||||
if filter_field.extra.get('help_text', None):
|
||||
description = filter_field.extra['help_text']
|
||||
elif filter_field.label is not None:
|
||||
description = filter_field.label
|
||||
|
||||
choices = filter_field.extra.get('choices')
|
||||
if choices and callable(choices):
|
||||
# remove auto-generated enum list, since choices come from a callable
|
||||
if '\n\n*' in (description or ''):
|
||||
description, _, _ = description.partition('\n\n*')
|
||||
elif (description or '').startswith('* `'):
|
||||
description = ''
|
||||
return description
|
||||
|
||||
choice_description = ''
|
||||
if spectacular_settings.ENUM_GENERATE_CHOICE_DESCRIPTION and choices and not callable(choices):
|
||||
choice_description = build_choice_description_list(choices)
|
||||
|
||||
if not choices:
|
||||
return description
|
||||
|
||||
if not description:
|
||||
return choice_description
|
||||
|
||||
if '\n\n*' in description:
|
||||
description, _, _ = description.partition('\n\n*')
|
||||
return description + '\n\n' + choice_description
|
||||
|
||||
if description.startswith('* `'):
|
||||
return choice_description
|
||||
|
||||
return description + '\n\n' + choice_description
|
||||
|
||||
@classmethod
|
||||
def _is_gis(cls, field):
|
||||
if not getattr(cls, '_has_gis', True):
|
||||
return False
|
||||
try:
|
||||
from django.contrib.gis.db.models import GeometryField
|
||||
from rest_framework_gis.filters import GeometryFilter
|
||||
|
||||
return isinstance(field, (GeometryField, GeometryFilter))
|
||||
except: # noqa
|
||||
cls._has_gis = False
|
||||
return False
|
||||
@@ -0,0 +1,49 @@
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
|
||||
|
||||
class DjangoOAuthToolkitScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'oauth2_provider.contrib.rest_framework.OAuth2Authentication'
|
||||
name = 'oauth2'
|
||||
|
||||
def get_security_requirement(self, auto_schema):
|
||||
from oauth2_provider.contrib.rest_framework import (
|
||||
IsAuthenticatedOrTokenHasScope, TokenHasScope, TokenMatchesOASRequirements,
|
||||
)
|
||||
view = auto_schema.view
|
||||
request = view.request
|
||||
|
||||
for permission in auto_schema.view.get_permissions():
|
||||
if isinstance(permission, TokenMatchesOASRequirements):
|
||||
alt_scopes = permission.get_required_alternate_scopes(request, view)
|
||||
alt_scopes = alt_scopes.get(auto_schema.method, [])
|
||||
return [{self.name: group} for group in alt_scopes]
|
||||
if isinstance(permission, IsAuthenticatedOrTokenHasScope):
|
||||
return {self.name: TokenHasScope().get_scopes(request, view)}
|
||||
if isinstance(permission, TokenHasScope):
|
||||
# catch-all for subclasses of TokenHasScope like TokenHasReadWriteScope
|
||||
return {self.name: permission.get_scopes(request, view)}
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
from oauth2_provider.scopes import get_scopes_backend
|
||||
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
|
||||
flows = {}
|
||||
for flow_type in spectacular_settings.OAUTH2_FLOWS:
|
||||
flows[flow_type] = {}
|
||||
if flow_type in ('implicit', 'authorizationCode'):
|
||||
flows[flow_type]['authorizationUrl'] = spectacular_settings.OAUTH2_AUTHORIZATION_URL
|
||||
if flow_type in ('password', 'clientCredentials', 'authorizationCode'):
|
||||
flows[flow_type]['tokenUrl'] = spectacular_settings.OAUTH2_TOKEN_URL
|
||||
if spectacular_settings.OAUTH2_REFRESH_URL:
|
||||
flows[flow_type]['refreshUrl'] = spectacular_settings.OAUTH2_REFRESH_URL
|
||||
if spectacular_settings.OAUTH2_SCOPES:
|
||||
flows[flow_type]['scopes'] = spectacular_settings.OAUTH2_SCOPES
|
||||
else:
|
||||
scope_backend = get_scopes_backend()
|
||||
flows[flow_type]['scopes'] = scope_backend.get_all_scopes()
|
||||
|
||||
return {
|
||||
'type': 'oauth2',
|
||||
'flows': flows
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
|
||||
def camelize_serializer_fields(result, generator, request, public):
|
||||
from django.conf import settings
|
||||
from djangorestframework_camel_case.settings import api_settings
|
||||
from djangorestframework_camel_case.util import camelize_re, underscore_to_camel
|
||||
|
||||
# prunes subtrees from camelization based on owning field name
|
||||
ignore_fields = api_settings.JSON_UNDERSCOREIZE.get("ignore_fields") or ()
|
||||
# ignore certain field names while camelizing
|
||||
ignore_keys = api_settings.JSON_UNDERSCOREIZE.get("ignore_keys") or ()
|
||||
|
||||
def has_middleware_installed():
|
||||
try:
|
||||
from djangorestframework_camel_case.middleware import CamelCaseMiddleWare
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
for middleware in [import_string(m) for m in settings.MIDDLEWARE]:
|
||||
try:
|
||||
if issubclass(middleware, CamelCaseMiddleWare):
|
||||
return True
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
def camelize_str(key: str) -> str:
|
||||
new_key = re.sub(camelize_re, underscore_to_camel, key) if "_" in key else key
|
||||
if key in ignore_keys or new_key in ignore_keys:
|
||||
return key
|
||||
return new_key
|
||||
|
||||
def camelize_component(schema: dict, name: Optional[str] = None) -> dict:
|
||||
if name is not None and (name in ignore_fields or camelize_str(name) in ignore_fields):
|
||||
return schema
|
||||
elif schema.get('type') == 'object':
|
||||
if 'properties' in schema:
|
||||
schema['properties'] = {
|
||||
camelize_str(field_name): camelize_component(field_schema, field_name)
|
||||
for field_name, field_schema in schema['properties'].items()
|
||||
}
|
||||
if 'required' in schema:
|
||||
schema['required'] = [camelize_str(field) for field in schema['required']]
|
||||
elif schema.get('type') == 'array':
|
||||
camelize_component(schema['items'])
|
||||
return schema
|
||||
|
||||
for (_, component_type), component in generator.registry._components.items():
|
||||
if component_type == 'schemas':
|
||||
camelize_component(component.schema)
|
||||
|
||||
if has_middleware_installed():
|
||||
for url_schema in result["paths"].values():
|
||||
for method_schema in url_schema.values():
|
||||
for parameter in method_schema.get("parameters", []):
|
||||
parameter["name"] = camelize_str(parameter["name"])
|
||||
|
||||
# inplace modification of components also affect result dict, so regeneration is not necessary
|
||||
return result
|
||||
@@ -0,0 +1,13 @@
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from drf_spectacular.plumbing import build_bearer_security_scheme_object
|
||||
|
||||
|
||||
class KnoxTokenScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'knox.auth.TokenAuthentication'
|
||||
name = 'knoxApiToken'
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
return build_bearer_security_scheme_object(
|
||||
header_name='Authorization',
|
||||
token_prefix=self.target.authenticate_header(""),
|
||||
)
|
||||
@@ -0,0 +1,50 @@
|
||||
from drf_spectacular.drainage import set_override, warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import ResolvedComponent, build_basic_type
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
|
||||
|
||||
class PydanticExtension(OpenApiSerializerExtension):
|
||||
"""
|
||||
Allows using pydantic models on @extend_schema(request=..., response=...) to
|
||||
describe your API.
|
||||
|
||||
We only have partial support for pydantic's version of dataclass, due to the way they
|
||||
are designed. The outermost class (the @extend_schema argument) has to be a subclass
|
||||
of pydantic.BaseModel. Inside this outermost BaseModel, any combination of dataclass
|
||||
and BaseModel can be used.
|
||||
"""
|
||||
|
||||
target_class = "pydantic.BaseModel"
|
||||
match_subclasses = True
|
||||
|
||||
def get_name(self, auto_schema, direction):
|
||||
# due to the fact that it is complicated to pull out every field member BaseModel class
|
||||
# of the entry model, we simply use the class name as string for object. This hack may
|
||||
# create false positive warnings, so turn it off. However, this may suppress correct
|
||||
# warnings involving the entry class.
|
||||
# TODO suppression may be migrated to new ComponentIdentity system
|
||||
set_override(self.target, 'suppress_collision_warning', True)
|
||||
return self.target.__name__
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
# let pydantic generate a JSON schema
|
||||
try:
|
||||
from pydantic.json_schema import model_json_schema
|
||||
except ImportError:
|
||||
warn("Only pydantic >= 2 is supported. defaulting to generic object.")
|
||||
return build_basic_type(OpenApiTypes.OBJECT)
|
||||
|
||||
schema = model_json_schema(self.target, ref_template="#/components/schemas/{model}", mode="serialization")
|
||||
|
||||
# pull out potential sub-schemas and put them into component section
|
||||
for sub_name, sub_schema in schema.pop("$defs", {}).items():
|
||||
component = ResolvedComponent(
|
||||
name=sub_name,
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=sub_name,
|
||||
schema=sub_schema,
|
||||
)
|
||||
auto_schema.registry.register_on_missing(component)
|
||||
|
||||
return schema
|
||||
@@ -0,0 +1,173 @@
|
||||
from django.conf import settings
|
||||
from django.utils.version import get_version_tuple
|
||||
from rest_framework import serializers
|
||||
|
||||
from drf_spectacular.contrib.rest_framework_simplejwt import (
|
||||
SimpleJWTScheme, TokenRefreshSerializerExtension,
|
||||
)
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiViewExtension
|
||||
from drf_spectacular.utils import extend_schema
|
||||
|
||||
|
||||
def get_dj_rest_auth_setting(class_name, setting_name):
|
||||
from dj_rest_auth.__version__ import __version__
|
||||
|
||||
if get_version_tuple(__version__) < (3, 0, 0):
|
||||
from dj_rest_auth import app_settings
|
||||
|
||||
return getattr(app_settings, class_name)
|
||||
else:
|
||||
from dj_rest_auth.app_settings import api_settings
|
||||
|
||||
return getattr(api_settings, setting_name)
|
||||
|
||||
|
||||
def get_token_serializer_class():
|
||||
from dj_rest_auth.__version__ import __version__
|
||||
|
||||
if get_version_tuple(__version__) < (3, 0, 0):
|
||||
use_jwt = getattr(settings, 'REST_USE_JWT', False)
|
||||
else:
|
||||
from dj_rest_auth.app_settings import api_settings
|
||||
|
||||
use_jwt = api_settings.USE_JWT
|
||||
|
||||
if use_jwt:
|
||||
return get_dj_rest_auth_setting('JWTSerializer', 'JWT_SERIALIZER')
|
||||
else:
|
||||
return get_dj_rest_auth_setting('TokenSerializer', 'TOKEN_SERIALIZER')
|
||||
|
||||
|
||||
class RestAuthDetailSerializer(serializers.Serializer):
|
||||
detail = serializers.CharField(read_only=True, required=False)
|
||||
|
||||
|
||||
class RestAuthDefaultResponseView(OpenApiViewExtension):
|
||||
def view_replacement(self):
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(responses=RestAuthDetailSerializer)
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class RestAuthLoginView(OpenApiViewExtension):
|
||||
target_class = 'dj_rest_auth.views.LoginView'
|
||||
|
||||
def view_replacement(self):
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(responses=get_token_serializer_class())
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class RestAuthLogoutView(OpenApiViewExtension):
|
||||
target_class = 'dj_rest_auth.views.LogoutView'
|
||||
|
||||
def view_replacement(self):
|
||||
if getattr(settings, 'ACCOUNT_LOGOUT_ON_GET', None):
|
||||
get_schema_params = {'responses': RestAuthDetailSerializer}
|
||||
else:
|
||||
get_schema_params = {'exclude': True}
|
||||
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(**get_schema_params)
|
||||
def get(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
@extend_schema(request=None, responses=RestAuthDetailSerializer)
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class RestAuthPasswordChangeView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.views.PasswordChangeView'
|
||||
|
||||
|
||||
class RestAuthPasswordResetView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.views.PasswordResetView'
|
||||
|
||||
|
||||
class RestAuthPasswordResetConfirmView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.views.PasswordResetConfirmView'
|
||||
|
||||
|
||||
class RestAuthVerifyEmailView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.registration.views.VerifyEmailView'
|
||||
optional = True
|
||||
|
||||
|
||||
class RestAuthResendEmailVerificationView(RestAuthDefaultResponseView):
|
||||
target_class = 'dj_rest_auth.registration.views.ResendEmailVerificationView'
|
||||
optional = True
|
||||
|
||||
|
||||
class RestAuthJWTSerializer(OpenApiSerializerExtension):
|
||||
target_class = 'dj_rest_auth.serializers.JWTSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
class Fixed(self.target_class):
|
||||
user = get_dj_rest_auth_setting('UserDetailsSerializer', 'USER_DETAILS_SERIALIZER')()
|
||||
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class CookieTokenRefreshSerializerExtension(TokenRefreshSerializerExtension):
|
||||
target_class = 'dj_rest_auth.jwt_auth.CookieTokenRefreshSerializer'
|
||||
optional = True
|
||||
|
||||
def get_name(self):
|
||||
return 'TokenRefresh'
|
||||
|
||||
|
||||
class RestAuthRegisterView(OpenApiViewExtension):
|
||||
target_class = 'dj_rest_auth.registration.views.RegisterView'
|
||||
optional = True
|
||||
|
||||
def view_replacement(self):
|
||||
from allauth.account.app_settings import EMAIL_VERIFICATION, EmailVerificationMethod
|
||||
|
||||
if EMAIL_VERIFICATION == EmailVerificationMethod.MANDATORY:
|
||||
response_serializer = RestAuthDetailSerializer
|
||||
else:
|
||||
response_serializer = get_token_serializer_class()
|
||||
|
||||
class Fixed(self.target_class):
|
||||
@extend_schema(responses=response_serializer)
|
||||
def post(self, request, *args, **kwargs):
|
||||
pass # pragma: no cover
|
||||
|
||||
return Fixed
|
||||
|
||||
|
||||
class SimpleJWTCookieScheme(SimpleJWTScheme):
|
||||
target_class = 'dj_rest_auth.jwt_auth.JWTCookieAuthentication'
|
||||
optional = True
|
||||
name = ['jwtHeaderAuth', 'jwtCookieAuth'] # type: ignore
|
||||
|
||||
def get_security_requirement(self, auto_schema):
|
||||
return [{name: []} for name in self.name]
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
cookie_name = get_dj_rest_auth_setting('JWT_AUTH_COOKIE', 'JWT_AUTH_COOKIE')
|
||||
if not cookie_name:
|
||||
cookie_name = 'jwt-auth'
|
||||
warn(
|
||||
f'"JWT_AUTH_COOKIE" setting required for JWTCookieAuthentication. '
|
||||
f'defaulting to {cookie_name}'
|
||||
)
|
||||
|
||||
return [
|
||||
super().get_security_definition(auto_schema), # JWT from header
|
||||
{
|
||||
'type': 'apiKey',
|
||||
'in': 'cookie',
|
||||
'name': cookie_name,
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
from drf_spectacular.extensions import OpenApiViewExtension
|
||||
|
||||
|
||||
class ObtainAuthTokenView(OpenApiViewExtension):
|
||||
target_class = 'rest_framework.authtoken.views.ObtainAuthToken'
|
||||
match_subclasses = True
|
||||
|
||||
def view_replacement(self):
|
||||
"""
|
||||
Prior to DRF 3.12.0, usage of ObtainAuthToken resulted in AssertionError
|
||||
|
||||
Incompatible AutoSchema used on View "ObtainAuthToken". Is DRF's DEFAULT_SCHEMA_CLASS ...
|
||||
|
||||
This is because DRF had a bug which made it NOT honor DEFAULT_SCHEMA_CLASS and instead
|
||||
injected an unsolicited coreschema class for this view and this view only. This extension
|
||||
fixes the view before the wrong schema class is used.
|
||||
|
||||
Bug in DRF that was fixed in later versions:
|
||||
https://github.com/encode/django-rest-framework/blob/4121b01b912668c049b26194a9a107c27a332429/rest_framework/authtoken/views.py#L16
|
||||
"""
|
||||
from rest_framework import VERSION
|
||||
|
||||
from drf_spectacular.openapi import AutoSchema
|
||||
|
||||
# no intervention needed
|
||||
if VERSION >= '3.12':
|
||||
return self.target
|
||||
|
||||
class FixedObtainAuthToken(self.target):
|
||||
schema = AutoSchema()
|
||||
|
||||
return FixedObtainAuthToken
|
||||
@@ -0,0 +1,36 @@
|
||||
from typing import Any
|
||||
|
||||
from drf_spectacular.drainage import get_override, has_override
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import ComponentIdentity, get_doc
|
||||
from drf_spectacular.utils import Direction
|
||||
|
||||
|
||||
class OpenApiDataclassSerializerExtensions(OpenApiSerializerExtension):
|
||||
target_class = "rest_framework_dataclasses.serializers.DataclassSerializer"
|
||||
match_subclasses = True
|
||||
|
||||
def get_name(self):
|
||||
"""Use the dataclass name in the schema, instead of the serializer prefix (which can be just Dataclass)."""
|
||||
if has_override(self.target, 'component_name'):
|
||||
return get_override(self.target, 'component_name')
|
||||
if getattr(getattr(self.target, 'Meta', None), 'ref_name', None) is not None:
|
||||
return self.target.Meta.ref_name
|
||||
if has_override(self.target.dataclass_definition.dataclass_type, 'component_name'):
|
||||
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
|
||||
return self.target.dataclass_definition.dataclass_type.__name__
|
||||
|
||||
def get_identity(self, auto_schema, direction: Direction) -> Any:
|
||||
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)
|
||||
|
||||
def strip_library_doc(self, schema):
|
||||
"""Strip the DataclassSerializer library documentation from the schema."""
|
||||
from rest_framework_dataclasses.serializers import DataclassSerializer
|
||||
if 'description' in schema and schema['description'] == get_doc(DataclassSerializer):
|
||||
del schema['description']
|
||||
return schema
|
||||
|
||||
def map_serializer(self, auto_schema, direction: Direction):
|
||||
""""Generate the schema for a DataclassSerializer."""
|
||||
schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
|
||||
return self.strip_library_doc(schema)
|
||||
@@ -0,0 +1,219 @@
|
||||
from rest_framework.utils.model_meta import get_field_info
|
||||
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension, OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.plumbing import (
|
||||
ResolvedComponent, build_array_type, build_object_type, follow_field_source, get_doc,
|
||||
)
|
||||
|
||||
|
||||
def build_point_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "number", "format": "float"},
|
||||
"example": [12.9721, 77.5933],
|
||||
"minItems": 2,
|
||||
"maxItems": 3,
|
||||
}
|
||||
|
||||
|
||||
def build_linestring_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": build_point_schema(),
|
||||
"example": [[22.4707, 70.0577], [12.9721, 77.5933]],
|
||||
"minItems": 2,
|
||||
}
|
||||
|
||||
|
||||
def build_polygon_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {**build_linestring_schema(), "minItems": 4},
|
||||
"example": [
|
||||
[
|
||||
[0.0, 0.0],
|
||||
[0.0, 50.0],
|
||||
[50.0, 50.0],
|
||||
[50.0, 0.0],
|
||||
[0.0, 0.0],
|
||||
],
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_geo_container_schema(name, coords):
|
||||
return build_object_type(
|
||||
properties={
|
||||
"type": {"type": "string", "enum": [name]},
|
||||
"coordinates": coords,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def build_point_geo_schema():
|
||||
return build_geo_container_schema("Point", build_point_schema())
|
||||
|
||||
|
||||
def build_linestring_geo_schema():
|
||||
return build_geo_container_schema("LineString", build_linestring_schema())
|
||||
|
||||
|
||||
def build_polygon_geo_schema():
|
||||
return build_geo_container_schema("Polygon", build_polygon_schema())
|
||||
|
||||
|
||||
def build_geometry_geo_schema():
|
||||
return {
|
||||
'oneOf': [
|
||||
build_point_geo_schema(),
|
||||
build_linestring_geo_schema(),
|
||||
build_polygon_geo_schema(),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_bbox_schema():
|
||||
return {
|
||||
"type": "array",
|
||||
"items": {"type": "number"},
|
||||
"minItems": 4,
|
||||
"maxItems": 4,
|
||||
"example": [12.9721, 77.5933, 12.9721, 77.5933],
|
||||
}
|
||||
|
||||
|
||||
def build_geo_schema(model_field):
|
||||
from django.contrib.gis.db import models
|
||||
|
||||
if isinstance(model_field, models.PointField):
|
||||
return build_point_geo_schema()
|
||||
elif isinstance(model_field, models.LineStringField):
|
||||
return build_linestring_geo_schema()
|
||||
elif isinstance(model_field, models.PolygonField):
|
||||
return build_polygon_geo_schema()
|
||||
elif isinstance(model_field, models.MultiPointField):
|
||||
return build_geo_container_schema(
|
||||
"MultiPoint", build_array_type(build_point_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.MultiLineStringField):
|
||||
return build_geo_container_schema(
|
||||
"MultiLineString", build_array_type(build_linestring_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.MultiPolygonField):
|
||||
return build_geo_container_schema(
|
||||
"MultiPolygon", build_array_type(build_polygon_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.GeometryCollectionField):
|
||||
return build_geo_container_schema(
|
||||
"GeometryCollection", build_array_type(build_geometry_geo_schema())
|
||||
)
|
||||
elif isinstance(model_field, models.GeometryField):
|
||||
return build_geometry_geo_schema()
|
||||
else:
|
||||
warn("Encountered unknown GIS geometry field")
|
||||
return {}
|
||||
|
||||
|
||||
def map_geo_field(serializer, geo_field_name):
|
||||
from rest_framework_gis.fields import GeometrySerializerMethodField
|
||||
|
||||
field = serializer.fields[geo_field_name]
|
||||
if isinstance(field, GeometrySerializerMethodField):
|
||||
warn("Geometry generation for GeometrySerializerMethodField is not supported.")
|
||||
return {}
|
||||
model_field = get_field_info(serializer.Meta.model).fields[geo_field_name]
|
||||
return build_geo_schema(model_field)
|
||||
|
||||
|
||||
def _inject_enum_collision_fix(collection):
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
if not collection and 'GisFeatureEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
|
||||
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureEnum'] = ('Feature',)
|
||||
if collection and 'GisFeatureCollectionEnum' not in spectacular_settings.ENUM_NAME_OVERRIDES:
|
||||
spectacular_settings.ENUM_NAME_OVERRIDES['GisFeatureCollectionEnum'] = ('FeatureCollection',)
|
||||
|
||||
|
||||
class GeoFeatureModelSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_gis.serializers.GeoFeatureModelSerializer'
|
||||
match_subclasses = True
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
_inject_enum_collision_fix(collection=False)
|
||||
|
||||
base_schema = auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
|
||||
return self.map_geo_feature_model_serializer(self.target, base_schema)
|
||||
|
||||
def map_geo_feature_model_serializer(self, serializer, base_schema):
|
||||
from rest_framework_gis.serializers import GeoFeatureModelSerializer
|
||||
|
||||
geo_properties = {
|
||||
"type": {"type": "string", "enum": ["Feature"]}
|
||||
}
|
||||
if serializer.Meta.id_field:
|
||||
geo_properties["id"] = base_schema["properties"].pop(serializer.Meta.id_field)
|
||||
|
||||
geo_properties["geometry"] = map_geo_field(serializer, serializer.Meta.geo_field)
|
||||
base_schema["properties"].pop(serializer.Meta.geo_field)
|
||||
|
||||
if serializer.Meta.auto_bbox or serializer.Meta.bbox_geo_field:
|
||||
geo_properties["bbox"] = build_bbox_schema()
|
||||
base_schema["properties"].pop(serializer.Meta.bbox_geo_field, None)
|
||||
|
||||
# only expose if description comes from the user
|
||||
description = base_schema.pop('description', None)
|
||||
if description == get_doc(GeoFeatureModelSerializer):
|
||||
description = None
|
||||
|
||||
# ignore this aspect for now
|
||||
base_schema.pop('required', None)
|
||||
|
||||
# nest remaining fields under property "properties"
|
||||
geo_properties["properties"] = base_schema
|
||||
|
||||
return build_object_type(
|
||||
properties=geo_properties,
|
||||
description=description,
|
||||
)
|
||||
|
||||
|
||||
class GeoFeatureModelListSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_gis.serializers.GeoFeatureModelListSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
_inject_enum_collision_fix(collection=True)
|
||||
|
||||
# build/retrieve feature component generated by GeoFeatureModelSerializerExtension.
|
||||
# wrap the ref in the special list structure and build another component based on that.
|
||||
feature_component = auto_schema.resolve_serializer(self.target.child, direction)
|
||||
collection_schema = build_object_type(
|
||||
properties={
|
||||
"type": {"type": "string", "enum": ["FeatureCollection"]},
|
||||
"features": build_array_type(feature_component.ref)
|
||||
}
|
||||
)
|
||||
list_component = ResolvedComponent(
|
||||
name=f'{feature_component.name}List',
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=self.target.child,
|
||||
schema=collection_schema
|
||||
)
|
||||
auto_schema.registry.register_on_missing(list_component)
|
||||
return list_component.ref
|
||||
|
||||
|
||||
class GeometryFieldExtension(OpenApiSerializerFieldExtension):
|
||||
target_class = 'rest_framework_gis.fields.GeometryField'
|
||||
match_subclasses = True
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
# running this extension for GeoFeatureModelSerializer's geo_field is superfluous
|
||||
# as above extension already handles that individually. We run it anyway because
|
||||
# robustly checking the proper condition is harder.
|
||||
try:
|
||||
model = self.target.parent.Meta.model
|
||||
model_field = follow_field_source(model, self.target.source.split('.'))
|
||||
return build_geo_schema(model_field)
|
||||
except: # noqa: E722
|
||||
warn(f'Encountered an issue resolving field {self.target}. defaulting to generic object.')
|
||||
return {}
|
||||
@@ -0,0 +1,16 @@
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension
|
||||
from drf_spectacular.plumbing import build_bearer_security_scheme_object
|
||||
|
||||
|
||||
class JWTScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'rest_framework_jwt.authentication.JSONWebTokenAuthentication'
|
||||
name = 'jwtAuth'
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
from rest_framework_jwt.settings import api_settings
|
||||
|
||||
return build_bearer_security_scheme_object(
|
||||
header_name='AUTHORIZATION',
|
||||
token_prefix=api_settings.JWT_AUTH_HEADER_PREFIX,
|
||||
bearer_format='JWT'
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
from drf_spectacular.extensions import OpenApiSerializerFieldExtension
|
||||
from drf_spectacular.plumbing import build_array_type, is_list_serializer
|
||||
|
||||
|
||||
class RecursiveFieldExtension(OpenApiSerializerFieldExtension):
|
||||
target_class = "rest_framework_recursive.fields.RecursiveField"
|
||||
|
||||
def map_serializer_field(self, auto_schema, direction):
|
||||
proxied = self.target.proxied
|
||||
|
||||
if is_list_serializer(proxied):
|
||||
component = auto_schema.resolve_serializer(proxied.child, direction)
|
||||
return build_array_type(component.ref)
|
||||
|
||||
component = auto_schema.resolve_serializer(proxied, direction)
|
||||
return component.ref
|
||||
@@ -0,0 +1,87 @@
|
||||
from rest_framework import serializers
|
||||
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiAuthenticationExtension, OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import build_bearer_security_scheme_object
|
||||
from drf_spectacular.utils import inline_serializer
|
||||
|
||||
|
||||
class TokenObtainPairSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenObtainPairSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
Fixed = inline_serializer('Fixed', fields={
|
||||
self.target_class.username_field: serializers.CharField(write_only=True),
|
||||
'password': serializers.CharField(write_only=True),
|
||||
'access': serializers.CharField(read_only=True),
|
||||
'refresh': serializers.CharField(read_only=True),
|
||||
})
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class TokenObtainSlidingSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
Fixed = inline_serializer('Fixed', fields={
|
||||
self.target_class.username_field: serializers.CharField(write_only=True),
|
||||
'password': serializers.CharField(write_only=True),
|
||||
'token': serializers.CharField(read_only=True),
|
||||
})
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class TokenRefreshSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenRefreshSerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
from rest_framework_simplejwt.settings import api_settings
|
||||
|
||||
if api_settings.ROTATE_REFRESH_TOKENS:
|
||||
class Fixed(serializers.Serializer):
|
||||
access = serializers.CharField(read_only=True)
|
||||
refresh = serializers.CharField()
|
||||
else:
|
||||
class Fixed(serializers.Serializer):
|
||||
access = serializers.CharField(read_only=True)
|
||||
refresh = serializers.CharField(write_only=True)
|
||||
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class TokenVerifySerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_framework_simplejwt.serializers.TokenVerifySerializer'
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
Fixed = inline_serializer('Fixed', fields={
|
||||
'token': serializers.CharField(write_only=True),
|
||||
})
|
||||
return auto_schema._map_serializer(Fixed, direction)
|
||||
|
||||
|
||||
class SimpleJWTScheme(OpenApiAuthenticationExtension):
|
||||
target_class = 'rest_framework_simplejwt.authentication.JWTAuthentication'
|
||||
name = 'jwtAuth'
|
||||
|
||||
def get_security_definition(self, auto_schema):
|
||||
from rest_framework_simplejwt.settings import api_settings
|
||||
|
||||
if len(api_settings.AUTH_HEADER_TYPES) > 1:
|
||||
warn(
|
||||
f'OpenAPI3 can only have one "bearerFormat". JWT Settings specify '
|
||||
f'{api_settings.AUTH_HEADER_TYPES}. Using the first one.'
|
||||
)
|
||||
|
||||
return build_bearer_security_scheme_object(
|
||||
header_name=getattr(api_settings, 'AUTH_HEADER_NAME', 'HTTP_AUTHORIZATION'),
|
||||
token_prefix=api_settings.AUTH_HEADER_TYPES[0],
|
||||
bearer_format='JWT'
|
||||
)
|
||||
|
||||
|
||||
class SimpleJWTTokenUserScheme(SimpleJWTScheme):
|
||||
target_class = 'rest_framework_simplejwt.authentication.JWTTokenUserAuthentication'
|
||||
|
||||
|
||||
class SimpleJWTStatelessUserScheme(SimpleJWTScheme):
|
||||
target_class = "rest_framework_simplejwt.authentication.JWTStatelessUserAuthentication"
|
||||
@@ -0,0 +1,81 @@
|
||||
from drf_spectacular.drainage import warn
|
||||
from drf_spectacular.extensions import OpenApiSerializerExtension
|
||||
from drf_spectacular.plumbing import (
|
||||
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
|
||||
is_patched_serializer,
|
||||
)
|
||||
from drf_spectacular.settings import spectacular_settings
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
|
||||
|
||||
class PolymorphicSerializerExtension(OpenApiSerializerExtension):
|
||||
target_class = 'rest_polymorphic.serializers.PolymorphicSerializer'
|
||||
match_subclasses = True
|
||||
|
||||
def map_serializer(self, auto_schema, direction):
|
||||
sub_components = []
|
||||
serializer = self.target
|
||||
|
||||
for sub_model in serializer.model_serializer_mapping:
|
||||
sub_serializer = serializer._get_serializer_from_model_or_instance(sub_model)
|
||||
sub_serializer.partial = serializer.partial
|
||||
resource_type = serializer.to_resource_type(sub_model)
|
||||
component = auto_schema.resolve_serializer(sub_serializer, direction)
|
||||
if not component:
|
||||
# rebuild a virtual schema-less component to model empty serializers
|
||||
component = ResolvedComponent(
|
||||
name=auto_schema._get_serializer_name(sub_serializer, direction),
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=ComponentIdentity('virtual')
|
||||
)
|
||||
typed_component = self.build_typed_component(
|
||||
auto_schema=auto_schema,
|
||||
component=component,
|
||||
resource_type_field_name=serializer.resource_type_field_name,
|
||||
patched=is_patched_serializer(sub_serializer, direction)
|
||||
)
|
||||
sub_components.append((resource_type, typed_component.ref))
|
||||
|
||||
if not resource_type:
|
||||
warn(
|
||||
f'discriminator mapping key is empty for {sub_serializer.__class__}. '
|
||||
f'this might lead to code generation issues.'
|
||||
)
|
||||
|
||||
one_of_list = []
|
||||
for _, ref in sub_components:
|
||||
if ref not in one_of_list:
|
||||
one_of_list.append(ref)
|
||||
|
||||
return {
|
||||
'oneOf': one_of_list,
|
||||
'discriminator': {
|
||||
'propertyName': serializer.resource_type_field_name,
|
||||
'mapping': {resource_type: ref['$ref'] for resource_type, ref in sub_components},
|
||||
}
|
||||
}
|
||||
|
||||
def build_typed_component(self, auto_schema, component, resource_type_field_name, patched):
|
||||
if spectacular_settings.COMPONENT_SPLIT_REQUEST and component.name.endswith('Request'):
|
||||
typed_component_name = component.name[:-len('Request')] + 'TypedRequest'
|
||||
else:
|
||||
typed_component_name = f'{component.name}Typed'
|
||||
|
||||
resource_type_schema = build_object_type(
|
||||
properties={resource_type_field_name: build_basic_type(OpenApiTypes.STR)},
|
||||
required=None if patched else [resource_type_field_name]
|
||||
)
|
||||
# if sub-serializer has an empty schema, only expose the resource_type field part
|
||||
if component.schema:
|
||||
schema = {'allOf': [resource_type_schema, component.ref]}
|
||||
else:
|
||||
schema = resource_type_schema
|
||||
|
||||
component_typed = ResolvedComponent(
|
||||
name=typed_component_name,
|
||||
type=ResolvedComponent.SCHEMA,
|
||||
object=component.object,
|
||||
schema=schema,
|
||||
)
|
||||
auto_schema.registry.register_on_missing(component_typed)
|
||||
return component_typed
|
||||
Reference in New Issue
Block a user