diff --git a/program/utils.py b/program/utils.py index 305f657cc6212fa77eae3bf34a215b2a6cf88139..d21eaaf81ea9f1e66d2b4551d59a190fdd1cc262 100644 --- a/program/utils.py +++ b/program/utils.py @@ -23,6 +23,7 @@ from datetime import date, datetime, time from typing import Dict, Optional, Tuple, Union import requests +from rest_framework import exceptions from django.utils import timezone from steering.settings import CBA_AJAX_URL, CBA_API_KEY, DEBUG @@ -126,3 +127,22 @@ class DisabledObjectPermissionCheckMixin: def check_object_permissions(self, request, obj): pass + + +class NestedObjectFinderMixin: + ROUTE_FILTER_LOOKUPS = {} + + def _get_route_filters(self) -> Dict[str, int]: + filter_kwargs = {} + for key, value in self.kwargs.items(): + if key in self.ROUTE_FILTER_LOOKUPS: + try: + filter_kwargs[self.ROUTE_FILTER_LOOKUPS[key]] = int(value) + except ValueError: + raise exceptions.ValidationError( + detail=f"{key} must map to an integer value.", code="invalid-pk" + ) + return filter_kwargs + + def get_queryset(self): + return super().get_queryset().filter(**self._get_route_filters()) diff --git a/program/views.py b/program/views.py index 0f4503be1a89452076b6225ebaa17f517a009c86..b3b8fcd7c96d1de182ef7c62622c917070247866 100644 --- a/program/views.py +++ b/program/views.py @@ -59,7 +59,12 @@ from program.serializers import ( TypeSerializer, UserSerializer, ) -from program.utils import DisabledObjectPermissionCheckMixin, get_values, parse_date +from program.utils import ( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, + get_values, + parse_date, +) logger = logging.getLogger(__name__) @@ -357,40 +362,25 @@ class APIShowViewSet(DisabledObjectPermissionCheckMixin, viewsets.ModelViewSet): return Response(status=status.HTTP_204_NO_CONTENT) -class APIScheduleViewSet(viewsets.ModelViewSet): +class APIScheduleViewSet( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, + viewsets.ModelViewSet, +): """ Returns a list of schedules. Only superusers may create and update schedules. """ + ROUTE_FILTER_LOOKUPS = { + "show_pk": "show", + } + queryset = Schedule.objects.all() serializer_class = ScheduleSerializer permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] - def get_queryset(self): - queryset = super().get_queryset() - - # subroute filters - show_pk = get_values(self.kwargs, "show_pk") - if show_pk: - queryset = queryset.filter(show=show_pk) - - return queryset - - def retrieve(self, request, *args, **kwargs): - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - schedule = ( - get_object_or_404(Schedule, pk=pk, show=show_pk) - if show_pk - else get_object_or_404(Schedule, pk=pk) - ) - - serializer = ScheduleSerializer(schedule) - - return Response(serializer.data) - def create(self, request, *args, **kwargs): """ Create a schedule, generate timeslots, test for collisions and resolve them including notes @@ -438,18 +428,16 @@ class APIScheduleViewSet(viewsets.ModelViewSet): if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - # Only allow updating when calling /shows/{show_pk}/schedules/{pk}/ and with the `schedule` - # JSON object - if show_pk is None or "schedule" not in request.data: + # Only allow updating when with the `schedule` JSON object + if "schedule" not in request.data: return Response(status=status.HTTP_400_BAD_REQUEST) + schedule = self.get_object() + # If default playlist id or repetition are given, just update if default_playlist_id := request.data.get("schedule").get( "default_playlist_id" ): - schedule = get_object_or_404(Schedule, pk=pk, show=show_pk) schedule.default_playlist_id = int(default_playlist_id) schedule.save() @@ -457,7 +445,6 @@ class APIScheduleViewSet(viewsets.ModelViewSet): return Response(serializer.data) if is_repetition := request.data.get("schedule").get("is_repetition"): - schedule = get_object_or_404(Schedule, pk=pk, show=show_pk) schedule.is_repetition = bool(is_repetition) schedule.save() @@ -468,11 +455,15 @@ class APIScheduleViewSet(viewsets.ModelViewSet): if "solutions" not in request.data: # TODO: respond with status.HTTP_409_CONFLICT when the dashboard can handle it return Response( - Schedule.make_conflicts(request.data["schedule"], pk, show_pk) + Schedule.make_conflicts( + request.data["schedule"], schedule.pk, schedule.show.pk + ) ) # Otherwise try to resolve - resolution = Schedule.resolve_conflicts(request.data, pk, show_pk) + resolution = Schedule.resolve_conflicts( + request.data, schedule.pk, schedule.show.pk + ) # If resolution went well if "projected" not in resolution: @@ -492,13 +483,7 @@ class APIScheduleViewSet(viewsets.ModelViewSet): if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - # Only allow deleting when calling /shows/{show_pk}/schedules/{pk} - if show_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - Schedule.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) @@ -507,6 +492,8 @@ class APIScheduleViewSet(viewsets.ModelViewSet): # by creating or updating a schedule. # There might be a use case for adding a single timeslot without any conflicts though. class APITimeSlotViewSet( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, @@ -522,41 +509,21 @@ class APITimeSlotViewSet( Timeslots may only be added by creating/updating a schedule. """ + ROUTE_FILTER_LOOKUPS = { + "show_pk": "show", + "schedule_pk": "schedule", + } + permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] serializer_class = TimeSlotSerializer pagination_class = LimitOffsetPagination queryset = TimeSlot.objects.all().order_by("-start") filterset_class = filters.TimeSlotFilterSet - def get_queryset(self): - queryset = super().get_queryset() - - # subroute filters - show_pk, schedule_pk = get_values(self.kwargs, "show_pk", "schedule_pk") - if show_pk: - queryset = queryset.filter(show=show_pk) - if schedule_pk: - queryset = queryset.filter(schedule=schedule_pk) - - return queryset - - def retrieve(self, request, *args, **kwargs): - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - if show_pk: - timeslot = get_object_or_404(TimeSlot, pk=pk, show=show_pk) - else: - timeslot = get_object_or_404(TimeSlot, pk=pk) - - serializer = TimeSlotSerializer(timeslot) - return Response(serializer.data) - def update(self, request, *args, **kwargs): """Link a playlist_id to a timeslot""" - pk, show_pk, schedule_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk" - ) + show_pk = get_values(self.kwargs, "show_pk") if ( not request.user.is_superuser @@ -564,15 +531,7 @@ class APITimeSlotViewSet( ): return Response(status=status.HTTP_401_UNAUTHORIZED) - # Update is only allowed when calling /shows/1/schedules/1/timeslots/1 and if user owns the - # show - if schedule_pk is None or show_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - timeslot = get_object_or_404( - TimeSlot, pk=pk, schedule=schedule_pk, show=show_pk - ) - + timeslot = self.get_object() serializer = TimeSlotSerializer(timeslot, data=request.data) if serializer.is_valid(): serializer.save() @@ -600,48 +559,38 @@ class APITimeSlotViewSet( if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - # Only allow when calling endpoint starting with /shows/1/... - if show_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - TimeSlot.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) -class APINoteViewSet(viewsets.ModelViewSet): +class APINoteViewSet( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, + viewsets.ModelViewSet, +): """ Returns a list of notes. Superusers may access and update all notes. """ + ROUTE_FILTER_LOOKUPS = { + "show_pk": "show", + "timeslot_pk": "timeslot", + } + queryset = Note.objects.all() serializer_class = NoteSerializer permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] pagination_class = LimitOffsetPagination filter_class = filters.NoteFilterSet - def get_queryset(self): - queryset = super().get_queryset() - - # subroute filters - show_pk, timeslot_pk = get_values(self.kwargs, "show_pk", "timeslot_pk") - if show_pk: - queryset = queryset.filter(show=show_pk) - if timeslot_pk: - queryset = queryset.filter(timeslot=timeslot_pk) - - return queryset - def create(self, request, *args, **kwargs): - """Create a note""" - - show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "show_pk", "schedule_pk", "timeslot_pk" - ) + """ + Only superusers can create new notes. + """ + show_pk, timeslot_pk = get_values(self.kwargs, "show_pk", "timeslot_pk") if ( not request.user.is_superuser @@ -649,10 +598,6 @@ class APINoteViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - # Only create a note if show_id, timeslot_id and schedule_id is given - if show_pk is None or schedule_pk is None or timeslot_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - serializer = NoteSerializer( data={"show": show_pk, "timeslot": timeslot_pk} | request.data, context={"user_id": request.user.id}, @@ -670,52 +615,11 @@ class APINoteViewSet(viewsets.ModelViewSet): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def retrieve(self, request, *args, **kwargs): + def update(self, request, *args, **kwargs): """ - Returns a single note - - Called by: - /notes/1 - /shows/1/notes/1 - /shows/1/timeslots/1/note/1 - /shows/1/schedules/1/timeslots/1/note/1 + Only superusers can update existing notes. """ - pk, show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk", "timeslot_pk" - ) - - # - # /shows/1/notes/1 - # - # Returns a note to a show - # - if show_pk and timeslot_pk is None and schedule_pk is None: - note = get_object_or_404(Note, pk=pk, show=show_pk) - - # - # /shows/1/timeslots/1/note/1 - # /shows/1/schedules/1/timeslots/1/note/1 - # - # Return a note to a timeslot - # - elif show_pk and timeslot_pk: - note = get_object_or_404(Note, pk=pk, show=show_pk, timeslot=timeslot_pk) - - # - # /notes/1 - # - # Returns the given note - # - else: - note = get_object_or_404(Note, pk=pk) - - serializer = NoteSerializer(note) - return Response(serializer.data) - - def update(self, request, *args, **kwargs): - pk, show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk", "timeslot_pk" - ) + show_pk = get_values(self.kwargs, "show_pk") if ( not request.user.is_superuser @@ -723,13 +627,7 @@ class APINoteViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - # Allow PUT only when calling - # /shows/{show_pk}/schedules/{schedule_pk}/timeslots/{timeslot_pk}/note/{pk} - if show_pk is None or schedule_pk is None or timeslot_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - note = get_object_or_404(Note, pk=pk, timeslot=timeslot_pk, show=show_pk) - + note = self.get_object() serializer = NoteSerializer(note, data=request.data) if serializer.is_valid(): @@ -748,9 +646,10 @@ class APINoteViewSet(viewsets.ModelViewSet): return Response(status=status.HTTP_400_BAD_REQUEST) def destroy(self, request, *args, **kwargs): - pk, show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk", "timeslot_pk" - ) + """ + Only superusers can delete existing notes. + """ + show_pk = get_values(self.kwargs, "show_pk") if ( not request.user.is_superuser @@ -758,10 +657,7 @@ class APINoteViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - if pk is None or show_pk is None or schedule_pk is None or timeslot_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - Note.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT)