Skip to content
Snippets Groups Projects
Commit f7926bd9 authored by Konrad Mohrfeldt's avatar Konrad Mohrfeldt :koala:
Browse files

refactor: implement consistent pk/slug retrieval for shows

The show retrieve method allowed shows to be identified through the slug
or the id. This is handy, but was restricted to the retrieve and did not
apply to the update nor delete methods. The proper way to implement this
kind of behaviour is through overriding get_object so that route
identifiers are handled consistently.
parent a03f60a2
No related branches found
No related tags found
1 merge request!21Add API documentation
...@@ -109,14 +109,20 @@ def get_values( ...@@ -109,14 +109,20 @@ def get_values(
return int_if_digit(values[0]) return int_if_digit(values[0])
def get_pk_and_slug(kwargs: Dict[str, str]) -> Tuple[Optional[int], Optional[str]]: class DisabledObjectPermissionCheckMixin:
"""Get the pk and the slug from the kwargs.""" """
At the time of writing permission checks were entirely circumvented by manual
queries in viewsets. To make code refactoring easier and allow
the paced introduction of .get_object() in viewsets, object permission checks
need to be disabled until permission checks have been refactored as well.
pk, slug = None, None Object permissions checks should become mandatory once proper permission_classes
are assigned to viewsets. This mixin should be removed afterwards.
"""
try: # The text above becomes the viewset’s doc string otherwise and is displayed in
pk = int(kwargs["pk"]) # the generated OpenAPI schema.
except ValueError: __doc__ = None
slug = kwargs["pk"]
return pk, slug def check_object_permissions(self, request, obj):
pass
...@@ -59,7 +59,7 @@ from program.serializers import ( ...@@ -59,7 +59,7 @@ from program.serializers import (
TypeSerializer, TypeSerializer,
UserSerializer, UserSerializer,
) )
from program.utils import get_pk_and_slug, get_values, parse_date from program.utils import DisabledObjectPermissionCheckMixin, get_values, parse_date
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -270,7 +270,7 @@ class APIUserViewSet( ...@@ -270,7 +270,7 @@ class APIUserViewSet(
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
class APIShowViewSet(viewsets.ModelViewSet): class APIShowViewSet(DisabledObjectPermissionCheckMixin, viewsets.ModelViewSet):
""" """
Returns a list of available shows. Returns a list of available shows.
...@@ -283,6 +283,19 @@ class APIShowViewSet(viewsets.ModelViewSet): ...@@ -283,6 +283,19 @@ class APIShowViewSet(viewsets.ModelViewSet):
pagination_class = LimitOffsetPagination pagination_class = LimitOffsetPagination
filterset_class = filters.ShowFilterSet filterset_class = filters.ShowFilterSet
def get_object(self):
queryset = self.filter_queryset(self.get_queryset())
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_arg = self.kwargs[lookup_url_kwarg]
# allow object retrieval through id or slug
try:
filter_kwargs = {self.lookup_field: int(lookup_arg)}
except ValueError:
filter_kwargs = {"slug": lookup_arg}
obj = get_object_or_404(queryset, **filter_kwargs)
self.check_object_permissions(self.request, obj)
return obj
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
""" """
Create a show. Create a show.
...@@ -301,23 +314,6 @@ class APIShowViewSet(viewsets.ModelViewSet): ...@@ -301,23 +314,6 @@ class APIShowViewSet(viewsets.ModelViewSet):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def retrieve(self, request, *args, **kwargs):
"""Returns a single show"""
pk, slug = get_pk_and_slug(self.kwargs)
show = (
get_object_or_404(Show, pk=pk)
if pk
else get_object_or_404(Show, slug=slug)
if slug
else None
)
serializer = ShowSerializer(show)
return Response(serializer.data)
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
""" """
Update a show. Update a show.
...@@ -332,7 +328,7 @@ class APIShowViewSet(viewsets.ModelViewSet): ...@@ -332,7 +328,7 @@ class APIShowViewSet(viewsets.ModelViewSet):
): ):
return Response(status=status.HTTP_401_UNAUTHORIZED) return Response(status=status.HTTP_401_UNAUTHORIZED)
show = get_object_or_404(Show, pk=pk) show = self.get_object()
serializer = ShowSerializer( serializer = ShowSerializer(
show, data=request.data, context={"user": request.user} show, data=request.data, context={"user": request.user}
) )
...@@ -356,9 +352,7 @@ class APIShowViewSet(viewsets.ModelViewSet): ...@@ -356,9 +352,7 @@ class APIShowViewSet(viewsets.ModelViewSet):
if not request.user.is_superuser: if not request.user.is_superuser:
return Response(status=status.HTTP_401_UNAUTHORIZED) return Response(status=status.HTTP_401_UNAUTHORIZED)
pk = get_values(self.kwargs, "pk") self.get_object().delete()
Show.objects.get(pk=pk).delete()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment