From f7926bd9bebda856aa31ed5993e81a28a274b84a Mon Sep 17 00:00:00 2001 From: Konrad Mohrfeldt <konrad.mohrfeldt@farbdev.org> Date: Fri, 18 Mar 2022 14:55:25 +0100 Subject: [PATCH] refactor: implement consistent pk/slug retrieval for shows The show retrieve method allowed shows to be identified through the slug or the id. This is handy, but was restricted to the retrieve and did not apply to the update nor delete methods. The proper way to implement this kind of behaviour is through overriding get_object so that route identifiers are handled consistently. --- program/utils.py | 22 ++++++++++++++-------- program/views.py | 40 +++++++++++++++++----------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/program/utils.py b/program/utils.py index 53b6420d..305f657c 100644 --- a/program/utils.py +++ b/program/utils.py @@ -109,14 +109,20 @@ def get_values( return int_if_digit(values[0]) -def get_pk_and_slug(kwargs: Dict[str, str]) -> Tuple[Optional[int], Optional[str]]: - """Get the pk and the slug from the kwargs.""" +class DisabledObjectPermissionCheckMixin: + """ + At the time of writing permission checks were entirely circumvented by manual + queries in viewsets. To make code refactoring easier and allow + the paced introduction of .get_object() in viewsets, object permission checks + need to be disabled until permission checks have been refactored as well. - pk, slug = None, None + Object permissions checks should become mandatory once proper permission_classes + are assigned to viewsets. This mixin should be removed afterwards. + """ - try: - pk = int(kwargs["pk"]) - except ValueError: - slug = kwargs["pk"] + # The text above becomes the viewset’s doc string otherwise and is displayed in + # the generated OpenAPI schema. + __doc__ = None - return pk, slug + def check_object_permissions(self, request, obj): + pass diff --git a/program/views.py b/program/views.py index f28c0c3e..0f4503be 100644 --- a/program/views.py +++ b/program/views.py @@ -59,7 +59,7 @@ from program.serializers import ( TypeSerializer, UserSerializer, ) -from program.utils import get_pk_and_slug, get_values, parse_date +from program.utils import DisabledObjectPermissionCheckMixin, get_values, parse_date logger = logging.getLogger(__name__) @@ -270,7 +270,7 @@ class APIUserViewSet( return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -class APIShowViewSet(viewsets.ModelViewSet): +class APIShowViewSet(DisabledObjectPermissionCheckMixin, viewsets.ModelViewSet): """ Returns a list of available shows. @@ -283,6 +283,19 @@ class APIShowViewSet(viewsets.ModelViewSet): pagination_class = LimitOffsetPagination filterset_class = filters.ShowFilterSet + def get_object(self): + queryset = self.filter_queryset(self.get_queryset()) + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup_arg = self.kwargs[lookup_url_kwarg] + # allow object retrieval through id or slug + try: + filter_kwargs = {self.lookup_field: int(lookup_arg)} + except ValueError: + filter_kwargs = {"slug": lookup_arg} + obj = get_object_or_404(queryset, **filter_kwargs) + self.check_object_permissions(self.request, obj) + return obj + def create(self, request, *args, **kwargs): """ Create a show. @@ -301,23 +314,6 @@ class APIShowViewSet(viewsets.ModelViewSet): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def retrieve(self, request, *args, **kwargs): - """Returns a single show""" - - pk, slug = get_pk_and_slug(self.kwargs) - - show = ( - get_object_or_404(Show, pk=pk) - if pk - else get_object_or_404(Show, slug=slug) - if slug - else None - ) - - serializer = ShowSerializer(show) - - return Response(serializer.data) - def update(self, request, *args, **kwargs): """ Update a show. @@ -332,7 +328,7 @@ class APIShowViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - show = get_object_or_404(Show, pk=pk) + show = self.get_object() serializer = ShowSerializer( show, data=request.data, context={"user": request.user} ) @@ -356,9 +352,7 @@ class APIShowViewSet(viewsets.ModelViewSet): if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk = get_values(self.kwargs, "pk") - - Show.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) -- GitLab