From ab3e6406a2519720efa67518a1eafb5090b62a2f Mon Sep 17 00:00:00 2001
From: Konrad Mohrfeldt <konrad.mohrfeldt@farbdev.org>
Date: Fri, 7 Apr 2023 02:10:18 +0200
Subject: [PATCH] refactor: re-work note viewset to work with updated model

---
 program/filters.py     |   2 +-
 program/serializers.py |  13 ++---
 program/views.py       | 111 +++++++++++++----------------------------
 3 files changed, 43 insertions(+), 83 deletions(-)

diff --git a/program/filters.py b/program/filters.py
index e263f095..73038ac7 100644
--- a/program/filters.py
+++ b/program/filters.py
@@ -236,7 +236,7 @@ class NoteFilterSet(StaticFilterHelpTextMixin, filters.FilterSet):
         help_text="Return only notes matching the specified id(s).",
     )
     show_owner = IntegerInFilter(
-        field_name="show__owners",
+        field_name="timeslot__show__owners",
         help_text="Return only notes by show the specified owner(s): all notes the user may edit.",
     )
 
diff --git a/program/serializers.py b/program/serializers.py
index f1165f1a..f1c870b7 100644
--- a/program/serializers.py
+++ b/program/serializers.py
@@ -671,11 +671,12 @@ class NoteSerializer(serializers.ModelSerializer):
     contributors = serializers.PrimaryKeyRelatedField(queryset=Host.objects.all(), many=True)
     image = serializers.PrimaryKeyRelatedField(queryset=Image.objects.all(), required=False)
     links = NoteLinkSerializer(many=True, required=False)
-    timeslot = serializers.PrimaryKeyRelatedField(queryset=TimeSlot.objects.all())
+    timeslot = serializers.PrimaryKeyRelatedField(queryset=TimeSlot.objects.all(), required=False)
 
     class Meta:
         model = Note
         read_only_fields = (
+            "id",
             "created_at",
             "created_by",
             "updated_at",
@@ -704,9 +705,10 @@ class NoteSerializer(serializers.ModelSerializer):
         contributors = validated_data.pop("contributors", [])
 
         # the creator of the note is the owner
-        validated_data["owner"] = self.context["user_id"]
-
-        note = Note.objects.create(**validated_data | self.context)  # created_by
+        validated_data["owner"] = self.context["request"].user
+        note = Note.objects.create(
+            created_by=self.context["request"].user.username, **validated_data
+        )
 
         note.contributors.set(contributors)
 
@@ -736,7 +738,6 @@ class NoteSerializer(serializers.ModelSerializer):
         instance.cba_id = validated_data.get("cba_id", instance.cba_id)
         instance.content = validated_data.get("content", instance.content)
         instance.image = validated_data.get("image", instance.image)
-        instance.show = validated_data.get("show", instance.show)
         instance.slug = validated_data.get("slug", instance.slug)
         instance.summary = validated_data.get("summary", instance.summary)
         instance.timeslot = validated_data.get("timeslot", instance.timeslot)
@@ -754,7 +755,7 @@ class NoteSerializer(serializers.ModelSerializer):
             for link_data in links_data:
                 NoteLink.objects.create(note=instance, **link_data)
 
-        instance.updated_by = self.context.get("updated_by")
+        instance.updated_by = self.context.get("request").user.username
 
         instance.save()
 
diff --git a/program/views.py b/program/views.py
index 1bee42fb..f7b051f2 100644
--- a/program/views.py
+++ b/program/views.py
@@ -26,12 +26,13 @@ from textwrap import dedent
 
 from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view
 from rest_framework import mixins, permissions, status, viewsets
+from rest_framework.exceptions import ValidationError
 from rest_framework.pagination import LimitOffsetPagination
 from rest_framework.response import Response
 
 from django.contrib.auth.models import User
 from django.core.exceptions import FieldError
-from django.http import HttpResponse
+from django.http import Http404, HttpResponse
 from django.shortcuts import get_list_or_404, get_object_or_404
 from django.utils import timezone
 from django.utils.translation import gettext as _
@@ -772,87 +773,45 @@ class APINoteViewSet(
     viewsets.ModelViewSet,
 ):
     ROUTE_FILTER_LOOKUPS = {
-        "show_pk": "show",
+        "show_pk": "timeslot__show",
         "timeslot_pk": "timeslot",
     }
-
     queryset = Note.objects.all()
     serializer_class = NoteSerializer
-    permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly]
+    permission_classes = [permissions.IsAuthenticatedOrReadOnly]
     pagination_class = LimitOffsetPagination
-    filter_class = filters.NoteFilterSet
-
-    def create(self, request, *args, **kwargs):
-        """
-        Only admins can create new notes.
-        """
-        show_pk, timeslot_pk = get_values(self.kwargs, "show_pk", "timeslot_pk")
-
-        if not request.user.is_superuser and show_pk not in request.user.shows.values_list(
-            "id", flat=True
-        ):
-            return Response(status=status.HTTP_401_UNAUTHORIZED)
-
-        serializer = NoteSerializer(
-            data={"show": show_pk, "timeslot": timeslot_pk} | request.data,
-            context={"user_id": request.user.id, "created_by": request.user.username},
-        )
-
-        if serializer.is_valid():
-            hosts = Host.objects.filter(shows__in=request.user.shows.values_list("id", flat=True))
-            if not request.user.is_superuser and request.data["host"] not in hosts:
-                serializer.validated_data["host"] = None
-
-            serializer.save()
-            return Response(serializer.data, status=status.HTTP_201_CREATED)
-
-        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
-
-    def update(self, request, *args, **kwargs):
-        """
-        Only admins can update existing notes.
-        """
-        show_pk = get_values(self.kwargs, "show_pk")
-
-        if not request.user.is_superuser and show_pk not in request.user.shows.values_list(
-            "id", flat=True
-        ):
-            return Response(status=status.HTTP_401_UNAUTHORIZED)
-
-        note = self.get_object()
-        serializer = NoteSerializer(
-            note, data=request.data, context={"updated_by": request.user.username}
-        )
-
-        if serializer.is_valid():
-            hosts = Host.objects.filter(shows__in=request.user.shows.values_list("id", flat=True))
-            # Don't assign a host the user mustn't edit. Reassign the original value instead
-            if not request.user.is_superuser and int(request.data["host"]) not in hosts:
-                serializer.validated_data["host"] = Host.objects.filter(pk=note.host_id)[0]
-
-            serializer.save()
-            return Response(serializer.data)
-
-        return Response(status=status.HTTP_400_BAD_REQUEST)
-
-    def partial_update(self, request, *args, **kwargs):
-        kwargs["partial"] = True
-        return self.update(request, *args, **kwargs)
-
-    def destroy(self, request, *args, **kwargs):
-        """
-        Only admins can delete existing notes.
-        """
-        show_pk = get_values(self.kwargs, "show_pk")
-
-        if not request.user.is_superuser and show_pk not in request.user.shows.values_list(
-            "id", flat=True
-        ):
-            return Response(status=status.HTTP_401_UNAUTHORIZED)
-
-        self.get_object().delete()
+    filterset_class = filters.NoteFilterSet
 
-        return Response(status=status.HTTP_204_NO_CONTENT)
+    def get_queryset(self):
+        qs = super().get_queryset().order_by("slug")
+        # Users should always be able to see notes
+        if self.request.method not in permissions.SAFE_METHODS:
+            # If the request is not by an admin,
+            # check that the timeslot is owned by the current user.
+            if not self.request.user.is_superuser:
+                qs = qs.filter(timeslot__show__owners=self.request.user)
+        return qs
+
+    def _get_timeslot(self):
+        # TODO: Once we remove nested routes, timeslot ownership
+        #       should be checked in a permission class.
+        timeslot_pk = self.request.data.get("timeslot", None)
+        if timeslot_pk is None:
+            timeslot_pk = get_values(self.kwargs, "timeslot_pk")
+        if timeslot_pk is None:
+            raise ValidationError({"timeslot": [_("This field is required.")]}, code="required")
+        qs = TimeSlot.objects.all()
+        if not self.request.user.is_superuser:
+            qs = qs.filter(show__owners=self.request.user)
+        try:
+            return qs.get(pk=timeslot_pk)
+        except TimeSlot.DoesNotExist:
+            raise Http404()
+
+    def perform_create(self, serializer):
+        # TODO: Once we remove nested routes, this should be removed
+        #       and timeslot should be required in the serializer again.
+        serializer.save(timeslot=self._get_timeslot())
 
 
 class ActiveFilterMixin:
-- 
GitLab