diff --git a/.gitignore b/.gitignore index 69da0bf513c4e8ebdb04cd38baa98e3ca5d75a58..76aa668227b7fddd47fc0d70d92af05afc97820c 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,6 @@ db.sqlite3 .mypy_cache *.pyc .pytest_cache +.cache/ static/ steering_data_model.png diff --git a/poetry.lock b/poetry.lock index 5432cf2b013dbcc6a69dc4f234854858596ae538..09f26e2d25aa9611e4006fe7f2b8f7e92af46b78 100644 --- a/poetry.lock +++ b/poetry.lock @@ -402,18 +402,18 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.10.7" +version = "3.11.0" description = "A platform independent file lock." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "filelock-3.10.7-py3-none-any.whl", hash = "sha256:bde48477b15fde2c7e5a0713cbe72721cb5a5ad32ee0b8f419907960b9d75536"}, - {file = "filelock-3.10.7.tar.gz", hash = "sha256:892be14aa8efc01673b5ed6589dbccb95f9a8596f0507e232626155495c18105"}, + {file = "filelock-3.11.0-py3-none-any.whl", hash = "sha256:f08a52314748335c6460fc8fe40cd5638b85001225db78c2aa01c8c0db83b318"}, + {file = "filelock-3.11.0.tar.gz", hash = "sha256:3618c0da67adcc0506b015fd11ef7faf1b493f0b40d87728e19986b536890c37"}, ] [package.extras] -docs = ["furo (>=2022.12.7)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] +docs = ["furo (>=2023.3.27)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] [[package]] diff --git a/program/filters.py b/program/filters.py index e263f095a9e821439cf82da9524a7158a9b122da..298111c8b04302ad5ba4ca41981af0b80e6e563f 100644 --- a/program/filters.py +++ b/program/filters.py @@ -231,12 +231,20 @@ class TimeSlotFilterSet(filters.FilterSet): class NoteFilterSet(StaticFilterHelpTextMixin, filters.FilterSet): + show = IntegerInFilter( + field_name="timeslot__show", + help_text="Return only notes that belong to the specified show(s).", + ) + timeslot = IntegerInFilter( + field_name="timeslot", + help_text="Return only notes that belong to the specified timeslot(s).", + ) ids = IntegerInFilter( field_name="id", help_text="Return only notes matching the specified id(s).", ) show_owner = IntegerInFilter( - field_name="show__owners", + field_name="timeslot__show__owners", help_text="Return only notes by show the specified owner(s): all notes the user may edit.", ) @@ -245,7 +253,7 @@ class NoteFilterSet(StaticFilterHelpTextMixin, filters.FilterSet): help_texts = { "owner": "Return only notes created by the specified user.", } - fields = ["ids", "owner", "show_owner"] + fields = ["ids", "owner", "show", "timeslot", "show_owner"] class ActiveFilterSet(StaticFilterHelpTextMixin, filters.FilterSet): diff --git a/program/serializers.py b/program/serializers.py index f1165f1a67174bb299d26325c095257793993bce..f1c870b70b0865d48c33a3ec0334e04a2fc44207 100644 --- a/program/serializers.py +++ b/program/serializers.py @@ -671,11 +671,12 @@ class NoteSerializer(serializers.ModelSerializer): contributors = serializers.PrimaryKeyRelatedField(queryset=Host.objects.all(), many=True) image = serializers.PrimaryKeyRelatedField(queryset=Image.objects.all(), required=False) links = NoteLinkSerializer(many=True, required=False) - timeslot = serializers.PrimaryKeyRelatedField(queryset=TimeSlot.objects.all()) + timeslot = serializers.PrimaryKeyRelatedField(queryset=TimeSlot.objects.all(), required=False) class Meta: model = Note read_only_fields = ( + "id", "created_at", "created_by", "updated_at", @@ -704,9 +705,10 @@ class NoteSerializer(serializers.ModelSerializer): contributors = validated_data.pop("contributors", []) # the creator of the note is the owner - validated_data["owner"] = self.context["user_id"] - - note = Note.objects.create(**validated_data | self.context) # created_by + validated_data["owner"] = self.context["request"].user + note = Note.objects.create( + created_by=self.context["request"].user.username, **validated_data + ) note.contributors.set(contributors) @@ -736,7 +738,6 @@ class NoteSerializer(serializers.ModelSerializer): instance.cba_id = validated_data.get("cba_id", instance.cba_id) instance.content = validated_data.get("content", instance.content) instance.image = validated_data.get("image", instance.image) - instance.show = validated_data.get("show", instance.show) instance.slug = validated_data.get("slug", instance.slug) instance.summary = validated_data.get("summary", instance.summary) instance.timeslot = validated_data.get("timeslot", instance.timeslot) @@ -754,7 +755,7 @@ class NoteSerializer(serializers.ModelSerializer): for link_data in links_data: NoteLink.objects.create(note=instance, **link_data) - instance.updated_by = self.context.get("updated_by") + instance.updated_by = self.context.get("request").user.username instance.save() diff --git a/program/tests/__init__.py b/program/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeef18a2d317dfa39aea5752ac257bc20d0e04bd --- /dev/null +++ b/program/tests/__init__.py @@ -0,0 +1,100 @@ +import datetime + +from django.contrib.auth.models import User +from django.utils.text import slugify +from django.utils.timezone import now +from program.models import Note, RRule, Schedule, Show, TimeSlot + + +class SteeringTestCaseMixin: + base_url = "/api/v1" + + def _url(self, *paths, **kwargs): + url = "/".join(str(p) for p in paths) + "/" + return f"{self.base_url}/{url.format(**kwargs)}" + + def _get_client(self, user=None): + client = self.client_class() + if user: + client.force_authenticate(user=user) + return client + + +class UserMixin: + user_admin: User + user_common: User + + def setUp(self): + self.user_admin = User.objects.create_superuser( + "admin", "admin@aura.radio", password="admin" + ) + self.user_common = User.objects.create_user( + "herbert", "herbert@aura.radio", password="herbert" + ) + + +class ShowMixin: + def _create_show(self, name: str, **kwargs): + kwargs["name"] = name + kwargs.setdefault("slug", slugify(name)) + kwargs.setdefault("short_description", f"The {name} show") + owners = kwargs.pop("owners", []) + show = Show.objects.create(**kwargs) + if owners: + show.owners.set(owners) + return show + + +class ScheduleMixin: + def _get_rrule(self): + rrule = RRule.objects.first() + if rrule is None: + rrule = RRule.objects.create(name="once", freq=0) + return rrule + + def _create_schedule(self, show: Show, **kwargs): + _first_date = kwargs.get("first_date", now().date()) + kwargs["show"] = show + kwargs.setdefault("first_date", _first_date) + kwargs.setdefault("start_time", "08:00") + kwargs.setdefault("last_date", _first_date + datetime.timedelta(days=365)) + kwargs.setdefault("end_time", "09:00") + kwargs.setdefault("rrule", self._get_rrule()) + return Schedule.objects.create(**kwargs) + + +class TimeSlotMixin: + def _create_timeslot(self, schedule: Schedule, **kwargs): + _start = kwargs.get("start", now()) + kwargs.setdefault("schedule", schedule) + kwargs.setdefault("show", schedule.show) + kwargs.setdefault("start", _start) + kwargs.setdefault("end", _start + datetime.timedelta(hours=1)) + return TimeSlot.objects.create(**kwargs) + + +class NoteMixin: + def _create_note(self, timeslot: TimeSlot, **kwargs): + note_count = Note.objects.all().count() + _title = kwargs.get("title", f"a random note #{note_count}") + kwargs["timeslot"] = timeslot + kwargs["title"] = _title + kwargs.setdefault("slug", slugify(_title)) + return Note.objects.create(**kwargs) + + def _create_random_note_content(self, **kwargs): + note_count = Note.objects.all().count() + _title = kwargs.get("title", f"a random note #{note_count}") + kwargs["title"] = _title + kwargs.setdefault("slug", slugify(_title)) + kwargs.setdefault("content", "some random content") + kwargs.setdefault("contributors", []) + return kwargs + + +class ProgramModelMixin(ShowMixin, ScheduleMixin, TimeSlotMixin, NoteMixin): + pass + + +class BaseMixin(UserMixin, ProgramModelMixin, SteeringTestCaseMixin): + pass diff --git a/program/tests/test_notes.py b/program/tests/test_notes.py new file mode 100644 index 0000000000000000000000000000000000000000..d21fdadcaabcc7bca81efd6da49eda400df4e325 --- /dev/null +++ b/program/tests/test_notes.py @@ -0,0 +1,124 @@ +from rest_framework.test import APITransactionTestCase + +from program import tests +from program.models import Schedule, Show + + +class NoteViewTestCase(tests.BaseMixin, APITransactionTestCase): + reset_sequences = True + + show_beatbetrieb: Show + schedule_beatbetrieb: Schedule + show_musikrotation: Show + schedule_musikrotation: Schedule + + def setUp(self) -> None: + super().setUp() + self.show_beatbetrieb = self._create_show("Beatbetrieb") + self.schedule_beatbetrieb = self._create_schedule(self.show_beatbetrieb) + self.show_musikrotation = self._create_show("Musikrotation", owners=[self.user_common]) + self.schedule_musikrotation = self._create_schedule( + self.show_musikrotation, start_time="10:00", end_time="12:00" + ) + + def test_everyone_can_read_notes(self): + self._create_note(self._create_timeslot(schedule=self.schedule_beatbetrieb)) + self._create_note(self._create_timeslot(schedule=self.schedule_musikrotation)) + res = self._get_client().get(self._url("notes")) + self.assertEqual(len(res.data), 2) + + def test_common_users_can_create_notes_for_owned_shows(self): + ts = self._create_timeslot(schedule=self.schedule_musikrotation) + client = self._get_client(self.user_common) + endpoint = self._url("notes") + res = client.post( + endpoint, self._create_random_note_content(timeslot=ts.id), format="json" + ) + self.assertEqual(res.status_code, 201) + + def test_common_users_cannot_create_notes_for_foreign_shows(self): + ts = self._create_timeslot(schedule=self.schedule_beatbetrieb) + client = self._get_client(self.user_common) + endpoint = self._url("notes") + res = client.post( + endpoint, self._create_random_note_content(timeslot=ts.id), format="json" + ) + self.assertEqual(res.status_code, 404) + + def test_common_user_can_update_owned_shows(self): + ts = self._create_timeslot(schedule=self.schedule_musikrotation) + note = self._create_note(ts) + client = self._get_client(self.user_common) + new_note_content = self._create_random_note_content(title="meh") + res = client.put(self._url("notes", note.id), new_note_content, format="json") + self.assertEqual(res.status_code, 200) + + def test_common_user_cannot_update_notes_of_foreign_shows(self): + ts = self._create_timeslot(schedule=self.schedule_beatbetrieb) + note = self._create_note(ts) + client = self._get_client(self.user_common) + new_note_content = self._create_random_note_content(title="meh") + res = client.put(self._url("notes", note.id), new_note_content, format="json") + self.assertEqual(res.status_code, 404) + + def test_admin_can_create_notes_for_all_timeslots(self): + timeslot = self._create_timeslot(schedule=self.schedule_musikrotation) + client = self._get_client(self.user_admin) + res = client.post( + self._url("notes"), + self._create_random_note_content(timeslot=timeslot.id), + format="json", + ) + self.assertEqual(res.status_code, 201) + + def test_notes_can_be_created_through_nested_routes(self): + client = self._get_client(self.user_admin) + + # /shows/{pk}/notes/ + ts1 = self._create_timeslot(schedule=self.schedule_musikrotation) + url = self._url("shows", self.show_musikrotation.id, "notes") + note = self._create_random_note_content(title="meh", timeslot=ts1.id) + res = client.post(url, note, format="json") + self.assertEqual(res.status_code, 201) + + # /shows/{pk}/timeslots/{pk}/note/ + ts2 = self._create_timeslot(schedule=self.schedule_musikrotation) + url = self._url("shows", self.show_musikrotation, "timeslots", ts2.id, "note") + note = self._create_random_note_content(title="cool") + res = client.post(url, note, format="json") + self.assertEqual(res.status_code, 201) + + def test_notes_can_be_filtered_through_nested_routes_and_query_params(self): + client = self._get_client() + + ts1 = self._create_timeslot(schedule=self.schedule_musikrotation) + ts2 = self._create_timeslot(schedule=self.schedule_beatbetrieb) + ts3 = self._create_timeslot(schedule=self.schedule_beatbetrieb) + n1 = self._create_note(timeslot=ts1) + n2 = self._create_note(timeslot=ts2) + n3 = self._create_note(timeslot=ts3) + + def _get_ids(res): + return set(ts["id"] for ts in res.data) + + # /shows/{pk}/notes/ + query_res = client.get(self._url("notes") + f"?show={self.show_beatbetrieb.id}") + route_res = client.get(self._url("shows", self.show_beatbetrieb.id, "notes")) + ids = {n2.id, n3.id} + self.assertEqual(_get_ids(query_res), ids) + self.assertEqual(_get_ids(route_res), ids) + + query_res = client.get(self._url("notes") + f"?show={self.show_musikrotation.id}") + route_res = client.get(self._url("shows", self.show_musikrotation.id, "notes")) + ids = {n1.id} + self.assertEqual(_get_ids(query_res), ids) + self.assertEqual(_get_ids(route_res), ids) + + # /shows/{pk}/timeslots/{pk}/note/ + query_res = client.get(self._url("notes") + f"?timeslot={ts2.id}") + route_res = client.get( + self._url("shows", self.show_beatbetrieb.id, "timeslots", ts2.id, "note") + ) + ids = {n2.id} + self.assertEqual(_get_ids(query_res), ids) + self.assertEqual(_get_ids(route_res), ids) diff --git a/program/views.py b/program/views.py index 1bee42fb3165b7031775b38a8d4fcf343327b664..f7b051f2723df78dc13b1b0cfca8aef77bda07b3 100644 --- a/program/views.py +++ b/program/views.py @@ -26,12 +26,13 @@ from textwrap import dedent from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view from rest_framework import mixins, permissions, status, viewsets +from rest_framework.exceptions import ValidationError from rest_framework.pagination import LimitOffsetPagination from rest_framework.response import Response from django.contrib.auth.models import User from django.core.exceptions import FieldError -from django.http import HttpResponse +from django.http import Http404, HttpResponse from django.shortcuts import get_list_or_404, get_object_or_404 from django.utils import timezone from django.utils.translation import gettext as _ @@ -772,87 +773,45 @@ class APINoteViewSet( viewsets.ModelViewSet, ): ROUTE_FILTER_LOOKUPS = { - "show_pk": "show", + "show_pk": "timeslot__show", "timeslot_pk": "timeslot", } - queryset = Note.objects.all() serializer_class = NoteSerializer - permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] + permission_classes = [permissions.IsAuthenticatedOrReadOnly] pagination_class = LimitOffsetPagination - filter_class = filters.NoteFilterSet - - def create(self, request, *args, **kwargs): - """ - Only admins can create new notes. - """ - show_pk, timeslot_pk = get_values(self.kwargs, "show_pk", "timeslot_pk") - - if not request.user.is_superuser and show_pk not in request.user.shows.values_list( - "id", flat=True - ): - return Response(status=status.HTTP_401_UNAUTHORIZED) - - serializer = NoteSerializer( - data={"show": show_pk, "timeslot": timeslot_pk} | request.data, - context={"user_id": request.user.id, "created_by": request.user.username}, - ) - - if serializer.is_valid(): - hosts = Host.objects.filter(shows__in=request.user.shows.values_list("id", flat=True)) - if not request.user.is_superuser and request.data["host"] not in hosts: - serializer.validated_data["host"] = None - - serializer.save() - return Response(serializer.data, status=status.HTTP_201_CREATED) - - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - def update(self, request, *args, **kwargs): - """ - Only admins can update existing notes. - """ - show_pk = get_values(self.kwargs, "show_pk") - - if not request.user.is_superuser and show_pk not in request.user.shows.values_list( - "id", flat=True - ): - return Response(status=status.HTTP_401_UNAUTHORIZED) - - note = self.get_object() - serializer = NoteSerializer( - note, data=request.data, context={"updated_by": request.user.username} - ) - - if serializer.is_valid(): - hosts = Host.objects.filter(shows__in=request.user.shows.values_list("id", flat=True)) - # Don't assign a host the user mustn't edit. Reassign the original value instead - if not request.user.is_superuser and int(request.data["host"]) not in hosts: - serializer.validated_data["host"] = Host.objects.filter(pk=note.host_id)[0] - - serializer.save() - return Response(serializer.data) - - return Response(status=status.HTTP_400_BAD_REQUEST) - - def partial_update(self, request, *args, **kwargs): - kwargs["partial"] = True - return self.update(request, *args, **kwargs) - - def destroy(self, request, *args, **kwargs): - """ - Only admins can delete existing notes. - """ - show_pk = get_values(self.kwargs, "show_pk") - - if not request.user.is_superuser and show_pk not in request.user.shows.values_list( - "id", flat=True - ): - return Response(status=status.HTTP_401_UNAUTHORIZED) - - self.get_object().delete() + filterset_class = filters.NoteFilterSet - return Response(status=status.HTTP_204_NO_CONTENT) + def get_queryset(self): + qs = super().get_queryset().order_by("slug") + # Users should always be able to see notes + if self.request.method not in permissions.SAFE_METHODS: + # If the request is not by an admin, + # check that the timeslot is owned by the current user. + if not self.request.user.is_superuser: + qs = qs.filter(timeslot__show__owners=self.request.user) + return qs + + def _get_timeslot(self): + # TODO: Once we remove nested routes, timeslot ownership + # should be checked in a permission class. + timeslot_pk = self.request.data.get("timeslot", None) + if timeslot_pk is None: + timeslot_pk = get_values(self.kwargs, "timeslot_pk") + if timeslot_pk is None: + raise ValidationError({"timeslot": [_("This field is required.")]}, code="required") + qs = TimeSlot.objects.all() + if not self.request.user.is_superuser: + qs = qs.filter(show__owners=self.request.user) + try: + return qs.get(pk=timeslot_pk) + except TimeSlot.DoesNotExist: + raise Http404() + + def perform_create(self, serializer): + # TODO: Once we remove nested routes, this should be removed + # and timeslot should be required in the serializer again. + serializer.save(timeslot=self._get_timeslot()) class ActiveFilterMixin: