From 5c34d059312d8a2b9165b107cb753a5ce7d6aef5 Mon Sep 17 00:00:00 2001
From: Ernesto Rico Schmidt <ernesto@helsinki.at>
Date: Mon, 14 Mar 2022 18:02:51 -0400
Subject: [PATCH] Refactor and Clean-up Viewsets

- Reorder the code inside the methods to fail fast on autorization,
- Replace `int_or_none` with a more generic solution, and move to utils,
- Add `get_values` and move `pk_and_slug` as `get_pk_and_slug` into utils,
- Replace calls to static methos in models local queries,
- Return meaningful status code while creating and updating resources,
- Return `409` when creating or updating a schedule produces a conflict.
---
 program/utils.py |  28 +++++++
 program/views.py | 212 +++++++++++++++++++----------------------------
 2 files changed, 115 insertions(+), 125 deletions(-)

diff --git a/program/utils.py b/program/utils.py
index 78481e83..4b42ab84 100644
--- a/program/utils.py
+++ b/program/utils.py
@@ -19,6 +19,7 @@
 #
 
 from datetime import datetime, date, time
+from typing import Dict, Optional, Union, Tuple
 
 import requests
 from django.utils import timezone
@@ -73,3 +74,30 @@ def get_audio_url(cba_id):
         audio_url = requests.get(url).json()
 
     return audio_url
+
+
+def get_values(kwargs: Dict[str, str], *keys: str) -> Union[Tuple[Union[int, str, None], ...], int, str, None]:
+    """Get the values of the keys from the kwargs."""
+
+    def int_if_digit(value: Optional[str]) -> Optional[Union[int, str]]:
+        return int(value) if value and value.isdigit() else value
+
+    values = [kwargs.get(key) for key in keys]
+
+    if len(values) > 1:
+        return tuple(int_if_digit(value) for value in values)
+    else:
+        return int_if_digit(values[0])
+
+
+def get_pk_and_slug(kwargs: Dict[str, str]) -> Tuple[Optional[int], Optional[str]]:
+    """Get the pk and the slug from the kwargs."""
+
+    pk, slug = None, None
+
+    try:
+        pk = int(kwargs['pk'])
+    except ValueError:
+        slug = kwargs['pk']
+
+    return pk, slug
diff --git a/program/views.py b/program/views.py
index 7a8cdd9a..73073c42 100644
--- a/program/views.py
+++ b/program/views.py
@@ -36,7 +36,7 @@ from program.models import Type, MusicFocus, Language, Note, Show, Category, Fun
 from program.serializers import TypeSerializer, LanguageSerializer, MusicFocusSerializer, NoteSerializer, ShowSerializer, \
     ScheduleSerializer, CategorySerializer, FundingCategorySerializer, TopicSerializer, TimeSlotSerializer, HostSerializer, \
     UserSerializer
-from program.utils import parse_date
+from program.utils import parse_date, get_values, get_pk_and_slug
 
 logger = logging.getLogger(__name__)
 
@@ -137,22 +137,6 @@ def json_playout(request):
                         content_type="application/json; charset=utf-8")
 
 
-def int_or_none(key, kwargs):
-    return int(kwargs[key]) if key in kwargs else None
-
-
-def pk_and_slug(kwargs):
-    pk = None
-    slug = None
-
-    try:
-        pk = int(kwargs['pk'])
-    except ValueError:
-        slug = kwargs['pk']
-
-    return pk, slug
-
-
 class APIUserViewSet(viewsets.ModelViewSet):
     """
     /users returns oneself. Superusers see all users. Only superusers may create a user (GET, POST)
@@ -172,14 +156,9 @@ class APIUserViewSet(viewsets.ModelViewSet):
 
         return User.objects.filter(pk=self.request.user.id)
 
-    def list(self, request, *args, **kwargs):
-        users = self.get_queryset()
-        serializer = UserSerializer(users, many=True)
-        return Response(serializer.data)
-
     def retrieve(self, request, *args, **kwargs):
         """Returns a single user"""
-        pk = int_or_none('pk', self.kwargs)
+        pk = get_values(self.kwargs, 'pk')
 
         # Common users only see themselves
         if not request.user.is_superuser and pk != request.user.id:
@@ -202,12 +181,12 @@ class APIUserViewSet(viewsets.ModelViewSet):
 
         if serializer.is_valid():
             serializer.save()
-            return Response(serializer.data)
+            return Response(serializer.data, status=status.HTTP_201_CREATED)
 
         return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
 
     def update(self, request, *args, **kwargs):
-        pk = int_or_none('pk', self.kwargs)
+        pk = get_values(self.kwargs, 'pk')
 
         serializer = UserSerializer(data=request.data)
         # Common users may only edit themselves
@@ -225,7 +204,7 @@ class APIUserViewSet(viewsets.ModelViewSet):
 
     def destroy(self, request, *args, **kwargs):
         """Deleting users is prohibited: Set 'is_active' to False instead"""
-        return Response(status=status.HTTP_401_UNAUTHORIZED)
+        return Response(status=status.HTTP_400_BAD_REQUEST)
 
 
 class APIShowViewSet(viewsets.ModelViewSet):
@@ -329,14 +308,14 @@ class APIShowViewSet(viewsets.ModelViewSet):
 
         if serializer.is_valid():
             serializer.save()
-            return Response(serializer.data)
+            return Response(serializer.data, status=status.HTTP_201_CREATED)
 
         return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
 
     def retrieve(self, request, *args, **kwargs):
         """Returns a single show"""
 
-        pk, slug = pk_and_slug(self.kwargs)
+        pk, slug = get_pk_and_slug(self.kwargs)
 
         show = get_object_or_404(Show, pk=pk) if pk else get_object_or_404(Show, slug=slug) if slug else None
 
@@ -350,9 +329,9 @@ class APIShowViewSet(viewsets.ModelViewSet):
         Common users may only update shows they own
         """
 
-        pk = int_or_none('pk', self.kwargs)
+        pk = get_values(self.kwargs, 'pk')
 
-        if not Show.is_editable(self, pk):
+        if not request.user.is_superuser and pk not in request.user.shows.values_list('id', flat=True):
             return Response(status=status.HTTP_401_UNAUTHORIZED)
 
         show = get_object_or_404(Show, pk=pk)
@@ -376,7 +355,8 @@ class APIShowViewSet(viewsets.ModelViewSet):
         if not request.user.is_superuser:
             return Response(status=status.HTTP_401_UNAUTHORIZED)
 
-        pk = int_or_none('pk', self.kwargs)
+        pk = get_values(self.kwargs, 'pk')
+
         Show.objects.get(pk=pk).delete()
 
         return Response(status=status.HTTP_204_NO_CONTENT)
@@ -397,30 +377,20 @@ class APIScheduleViewSet(viewsets.ModelViewSet):
     permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly]
 
     def get_queryset(self):
-        show_pk = int_or_none('show_pk', self.kwargs)
+        show_pk = get_values(self.kwargs, 'show_pk')
 
         if show_pk:
             return Schedule.objects.filter(show=show_pk)
 
         return Schedule.objects.all()
 
-    def list(self, request, *args, **kwargs):
-        """List Schedules of a show"""
-
-        schedules = self.get_queryset()
-        serializer = ScheduleSerializer(schedules, many=True)
-        return Response(serializer.data)
-
     def retrieve(self, request, *args, **kwargs):
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
+        pk, show_pk = get_values(self.kwargs, 'pk', 'show_pk')
 
-        if show_pk:
-            schedule = get_object_or_404(Schedule, pk=pk, show=show_pk)
-        else:
-            schedule = get_object_or_404(Schedule, pk=pk)
+        schedule = get_object_or_404(Schedule, pk=pk, show=show_pk) if show_pk else get_object_or_404(Schedule, pk=pk)
 
         serializer = ScheduleSerializer(schedule)
+
         return Response(serializer.data)
 
     def create(self, request, *args, **kwargs):
@@ -430,20 +400,19 @@ class APIScheduleViewSet(viewsets.ModelViewSet):
         Only superusers may add schedules
         TODO: Perhaps directly insert into database if no conflicts found
         """
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
 
-        # Only allow creating when calling /shows/1/schedules/
-        if show_pk is None or not request.user.is_superuser:
+        if not request.user.is_superuser:
             return Response(status=status.HTTP_401_UNAUTHORIZED)
 
-        # The schedule dict is mandatory
-        if 'schedule' not in request.data:
+        pk, show_pk = get_values(self.kwargs, 'pk', 'show_pk')
+
+        # Only allow creating when calling /shows/{show_pk}/schedules/ and with ehe `schedule` JSON object
+        if show_pk is None or 'schedule' not in request.data:
             return Response(status=status.HTTP_400_BAD_REQUEST)
 
         # First create submit -> return projected timeslots and collisions
         if 'solutions' not in request.data:
-            return Response(Schedule.make_conflicts(request.data['schedule'], pk, show_pk))
+            return Response(Schedule.make_conflicts(request.data['schedule'], pk, show_pk), status=status.HTTP_409_CONFLICT)
 
         # Otherwise try to resolve
         resolution = Schedule.resolve_conflicts(request.data, pk, show_pk)
@@ -453,7 +422,7 @@ class APIScheduleViewSet(viewsets.ModelViewSet):
             return Response(resolution, status=status.HTTP_201_CREATED)
 
         # Otherwise return conflicts
-        return Response(resolution)
+        return Response(resolution, status=status.HTTP_409_CONFLICT)
 
     def update(self, request, *args, **kwargs):
         """
@@ -461,21 +430,18 @@ class APIScheduleViewSet(viewsets.ModelViewSet):
 
         Only superusers may update schedules
         """
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
 
-        # Only allow updating when calling /shows/1/schedules/1
-        if show_pk is None or not request.user.is_superuser:
+        if not request.user.is_superuser:
             return Response(status=status.HTTP_401_UNAUTHORIZED)
 
-        # The schedule dict is mandatory
-        if 'schedule' not in request.data:
+        pk, show_pk = get_values(self.kwargs, 'pk', 'show_pk')
+
+        # Only allow updating when calling /shows/{show_pk}/schedules/{pk}/ and with the `schedule` JSON object
+        if show_pk is None or 'schedule' not in request.data:
             return Response(status=status.HTTP_400_BAD_REQUEST)
 
-        # If we're updating the default playlist id
-        # TODO: If nothing else than default_playlist_id, or is_repetition changed -> just save and don't do anything
-        new_schedule = request.data.get('schedule')
-        if default_playlist_id := new_schedule.get('default_playlist_id'):
+        # If default playlist id or repetition are given, just update
+        if default_playlist_id := request.data.get('schedule').get('default_playlist_id'):
             schedule = get_object_or_404(Schedule, pk=pk, show=show_pk)
             schedule.default_playlist_id = int(default_playlist_id)
             schedule.save()
@@ -483,32 +449,43 @@ class APIScheduleViewSet(viewsets.ModelViewSet):
             serializer = ScheduleSerializer(schedule)
             return Response(serializer.data)
 
+        if is_repetition := request.data.get('schedule').get('is_repetition'):
+            schedule = get_object_or_404(Schedule, pk=pk, show=show_pk)
+            schedule.is_repetition = bool(is_repetition)
+            schedule.save()
+
+            serializer = ScheduleSerializer(schedule)
+            return Response(serializer.data)
+
         # First update submit -> return projected timeslots and collisions
         if 'solutions' not in request.data:
-            return Response(Schedule.make_conflicts(request.data['schedule'], pk, show_pk))
+            return Response(Schedule.make_conflicts(request.data['schedule'], pk, show_pk), status=status.HTTP_409_CONFLICT)
 
         # Otherwise try to resolve
         resolution = Schedule.resolve_conflicts(request.data, pk, show_pk)
 
         # If resolution went well
         if 'projected' not in resolution:
-            return Response(resolution, status=status.HTTP_200_OK)
+            return Response(resolution)
 
         # Otherwise return conflicts
-        return Response(resolution)
+        return Response(resolution, status=status.HTTP_409_CONFLICT)
 
     def destroy(self, request, *args, **kwargs):
         """
         Delete a schedule
         Only superusers may delete schedules
         """
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = self.kwargs.get('show_pk')
 
-        # Only allow deleting when calling /shows/1/schedules/1
-        if show_pk is None or not request.user.is_superuser:
+        if not request.user.is_superuser:
             return Response(status=status.HTTP_401_UNAUTHORIZED)
 
+        pk, show_pk = get_values(self.kwargs, 'pk', 'show_pk')
+
+        # Only allow deleting when calling /shows/{show_pk}/schedules/{pk}
+        if show_pk is None:
+            return Response(status=status.HTTP_400_BAD_REQUEST)
+
         Schedule.objects.get(pk=pk).delete()
 
         return Response(status=status.HTTP_204_NO_CONTENT)
@@ -534,8 +511,7 @@ class APITimeSlotViewSet(viewsets.ModelViewSet):
     queryset = TimeSlot.objects.none()
 
     def get_queryset(self):
-        show_pk = int_or_none('show_pk', self.kwargs)
-        schedule_pk = int_or_none('schedule_pk', self.kwargs)
+        show_pk, schedule_pk = get_values(self.kwargs, 'show_pk', 'schedule_pk')
         # Filters
 
         # Return next 60 days by default
@@ -591,8 +567,7 @@ class APITimeSlotViewSet(viewsets.ModelViewSet):
             return TimeSlot.objects.filter(start__gte=start, end__lte=end).order_by(order)
 
     def retrieve(self, request, *args, **kwargs):
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
+        pk, show_pk = get_values(self.kwargs, 'pk', 'show_pk')
 
         if show_pk:
             timeslot = get_object_or_404(TimeSlot, pk=pk, show=show_pk)
@@ -607,18 +582,20 @@ class APITimeSlotViewSet(viewsets.ModelViewSet):
         Timeslots may only be created by adding/updating schedules
         TODO: Adding single timeslot which fits to schedule?
         """
-        return Response(status=status.HTTP_401_UNAUTHORIZED)
+        return Response(status=status.HTTP_400_BAD_REQUEST)
 
     def update(self, request, *args, **kwargs):
         """Link a playlist_id to a timeslot"""
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
-        schedule_pk = int_or_none('schedule_pk', self.kwargs)
 
-        # Update is only allowed when calling /shows/1/schedules/1/timeslots/1 and if user owns the show
-        if schedule_pk is None or show_pk is None or not Show.is_editable(self, show_pk):
+        pk, show_pk, schedule_pk = get_values(self.kwargs, 'pk', 'show_pk', 'schedule_pk')
+
+        if not request.user.is_superuser and show_pk not in request.user.shows.values_lis('id', flat=True):
             return Response(status=status.HTTP_401_UNAUTHORIZED)
 
+        # Update is only allowed when calling /shows/1/schedules/1/timeslots/1 and if user owns the show
+        if schedule_pk is None or show_pk is None:
+            return Response(status=status.HTTP_400_BAD_REQUEST)
+
         timeslot = get_object_or_404(TimeSlot, pk=pk, schedule=schedule_pk, show=show_pk)
 
         serializer = TimeSlotSerializer(timeslot, data=request.data)
@@ -631,7 +608,7 @@ class APITimeSlotViewSet(viewsets.ModelViewSet):
             ts = TimeSlot.objects.filter(show=show_pk, start__gt=timeslot.start)[0]
             if ts.is_repetition:
                 serializer = TimeSlotSerializer(ts)
-                return Response(serializer.data, status=status.HTTP_200_OK)
+                return Response(serializer.data)
 
             # ...or nothing if there isn't one
             return Response(status=status.HTTP_200_OK)
@@ -643,16 +620,16 @@ class APITimeSlotViewSet(viewsets.ModelViewSet):
         Delete a timeslot
         Only superusers may delete timeslots
         """
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
+
+        if not request.user.is_superuser:
+            return Response(status=status.HTTP_401_UNAUTHORIZED)
+
+        pk, show_pk = get_values(self.kwargs, 'pk', 'show_pk')
 
         # Only allow when calling endpoint starting with /shows/1/...
         if show_pk is None:
             return Response(status=status.HTTP_400_BAD_REQUEST)
 
-        if not request.user.is_superuser:
-            return Response(status=status.HTTP_401_UNAUTHORIZED)
-
         TimeSlot.objects.get(pk=pk).delete()
 
         return Response(status=status.HTTP_204_NO_CONTENT)
@@ -682,8 +659,7 @@ class APINoteViewSet(viewsets.ModelViewSet):
     pagination_class = LimitOffsetPagination
 
     def get_queryset(self):
-        timeslot_pk = int_or_none('timeslot_pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
+        timeslot_pk, show_pk = get_values(self.kwargs, 'timeslot_pk', 'show_pk')
 
         # Endpoints
 
@@ -736,27 +712,25 @@ class APINoteViewSet(viewsets.ModelViewSet):
 
     def create(self, request, *args, **kwargs):
         """Create a note"""
-        show_pk = int_or_none('show_pk', self.kwargs)
-        schedule_pk = int_or_none('schedule_pk', self.kwargs)
-        timeslot_pk = int_or_none('timeslot_pk', self.kwargs)
+
+        show_pk, schedule_pk, timeslot_pk = get_values(self.kwargs, 'show_pk', 'schedule_pk', 'timelost_pk')
+
+        if not request.user.is_superuser and show_pk not in request.user.shows.values_list('id', flat=True):
+            return Response(status=status.HTTP_401_UNAUTHORIZED)
 
         # Only create a note if show_id, timeslot_id and schedule_id is given
         if show_pk is None or schedule_pk is None or timeslot_pk is None:
             return Response(status=status.HTTP_400_BAD_REQUEST)
 
-        if not Show.is_editable(self, show_pk):
-            return Response(status=status.HTTP_401_UNAUTHORIZED)
-
         serializer = NoteSerializer(data=request.data, context={'user_id': request.user.id})
 
         if serializer.is_valid():
-
-            # Don't assign a host the user mustn't edit
-            if not Host.is_editable(self, request.data['host']) or request.data['host'] is None:
+            hosts = Host.objects.filter(shows__in=request.user.shows.values_list('id', flat=True))
+            if not request.user.is_superuser and request.data['host'] not in hosts:
                 serializer.validated_data['host'] = None
 
             serializer.save()
-            return Response(serializer.data)
+            return Response(serializer.data, status=status.HTTP_201_CREATED)
 
         return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
 
@@ -770,10 +744,7 @@ class APINoteViewSet(viewsets.ModelViewSet):
         /shows/1/timeslots/1/note/1
         /shows/1/schedules/1/timeslots/1/note/1
         """
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
-        schedule_pk = int_or_none('schedule_pk', self.kwargs)
-        timeslot_pk = int_or_none('timeslot_pk', self.kwargs)
+        pk, show_pk, schedule_pk, timeslot_pk = get_values(self.kwargs, 'pk', 'show_pk', 'schedule_pk', 'timeslot_pk')
 
         #
         #      /shows/1/notes/1
@@ -804,27 +775,23 @@ class APINoteViewSet(viewsets.ModelViewSet):
         return Response(serializer.data)
 
     def update(self, request, *args, **kwargs):
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
-        schedule_pk = int_or_none('schedule_pk', self.kwargs)
-        timeslot_pk = int_or_none('timeslot_pk', self.kwargs)
+        pk, show_pk, schedule_pk, timeslot_pk = get_values(self.kwargs, 'pk', 'show_pk', 'schedule_pk', 'timeslot_pk')
+
+        if not request.user.is_superuser and show_pk not in request.user.shows.values_list('id', flat=True):
+            return Response(status=status.HTTP_401_UNAUTHORIZED)
 
-        # Allow PUT only when calling /shows/1/schedules/1/timeslots/1/note/1
+        # Allow PUT only when calling /shows/{show_pk}/schedules/{schedule_pk}/timeslots/{timeslot_pk}/note/{pk}
         if show_pk is None or schedule_pk is None or timeslot_pk is None:
             return Response(status=status.HTTP_400_BAD_REQUEST)
 
         note = get_object_or_404(Note, pk=pk, timeslot=timeslot_pk, show=show_pk)
 
-        # Commons users may only edit notes of shows they own
-        if not Note.is_editable(self, note.id):
-            return Response(status=status.HTTP_401_UNAUTHORIZED)
-
         serializer = NoteSerializer(note, data=request.data)
 
         if serializer.is_valid():
-
+            hosts = Host.objects.filter(shows__in=request.user.shows.values_list('id', flat=True))
             # Don't assign a host the user mustn't edit. Reassign the original value instead
-            if not Host.is_editable(self, request.data['host']) and request.data['host']:
+            if not request.user.is_superuser and int(request.data['host']) not in hosts:
                 serializer.validated_data['host'] = Host.objects.filter(pk=note.host_id)[0]
 
             serializer.save()
@@ -833,22 +800,17 @@ class APINoteViewSet(viewsets.ModelViewSet):
         return Response(status=status.HTTP_400_BAD_REQUEST)
 
     def destroy(self, request, *args, **kwargs):
-        # Allow DELETE only when calling /shows/1/schedules/1/timeslots/1/note/1
-        pk = int_or_none('pk', self.kwargs)
-        show_pk = int_or_none('show_pk', self.kwargs)
-        schedule_pk = int_or_none('schedule_pk', self.kwargs)
-        timeslot_pk = int_or_none('timeslot_pk', self.kwargs)
+        pk, show_pk, schedule_pk, timeslot_pk = get_values(self.kwargs, 'pk', 'show_pk', 'schedule_pk', 'timeslot_pk')
 
-        if show_pk is None or schedule_pk is None or timeslot_pk is None:
-            return Response(status=status.HTTP_400_BAD_REQUEST)
+        if not request.user.is_superuser and show_pk not in request.user.shows.values_list('id', flat=True):
+            return Response(status=status.HTTP_401_UNAUTHORIZED)
 
-        note = get_object_or_404(Note, pk=pk)
+        if pk is None or show_pk is None or schedule_pk is None or timeslot_pk is None:
+            return Response(status=status.HTTP_400_BAD_REQUEST)
 
-        if Note.is_editable(self, note.id):
-            Note.objects.get(pk=pk).delete()
-            return Response(status=status.HTTP_204_NO_CONTENT)
+        Note.objects.get(pk=pk).delete()
 
-        return Response(status=status.HTTP_401_UNAUTHORIZED)
+        return Response(status=status.HTTP_204_NO_CONTENT)
 
 
 class ActiveInactiveViewSet(viewsets.ModelViewSet):
-- 
GitLab