diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 479d30835fc8d3a306e435f098928db0e4e8f835..5ce2cb1ee90e9219a2d70e70b9d43b2a59b2981e 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -7,8 +7,29 @@ - Dockerfile stages: + - build - release +build-openapi-scheme: + stage: build + image: python:3.9-alpine + variables: + SECRET_KEY: not-a-very-secret-key + DJANGO_SETTINGS_MODULE: steering.settings + # steering.settings switches to sqlite if the VIRTUAL_ENV environment + # variable is present. It’s probably a good idea to refactor this to + # something more explicit. + VIRTUAL_ENV: 1 + before_script: + - apk add gcc musl-dev zlib-dev jpeg-dev libmagic + - pip install -r requirements.txt + script: + - python3 -m django spectacular --validate --lang en --file openapi.yaml + artifacts: + paths: + - openapi.yaml + + docker-push: # Use the official docker image. image: docker:latest @@ -32,16 +53,16 @@ docker-push: fi # TODO: maybe isolate docker build and docker push - docker push "$AURA_IMAGE_NAME" --all-tags - rules: + rules: - *release-rules # every commit on master/main branch should trigger a push to docker-hub as unstable without a release - - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH + - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH exists: - Dockerfile release_job: stage: release - needs: + needs: - docker-push image: registry.gitlab.com/gitlab-org/release-cli:latest rules: *release-rules diff --git a/program/auth.py b/program/auth.py index 19d2b99756e835716636361c96bcf9b7457556a1..723b949523e764d96245aa15061de55812b4df67 100644 --- a/program/auth.py +++ b/program/auth.py @@ -18,6 +18,8 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. # +from drf_spectacular.extensions import OpenApiAuthenticationExtension +from drf_spectacular.plumbing import build_bearer_security_scheme_object from oidc_provider.lib.utils.oauth2 import extract_access_token from oidc_provider.models import Token from rest_framework import authentication, exceptions @@ -40,3 +42,27 @@ class OidcOauth2Auth(authentication.BaseAuthentication): raise exceptions.AuthenticationFailed("The oauth2 token has expired") return oauth2_token.user, None + + +class OidcOauth2AuthenticationScheme(OpenApiAuthenticationExtension): + target_class = "program.auth.OidcOauth2Auth" + name = "tokenAuth" + + def get_security_definition(self, auto_schema): + # One might be inclined to return a list here, because the bearer token + # can also be passed as an `access_token` query parameter. + # Officially this seems to be supported, but once we return a list + # the authorization helper in the generated documentation is empty, + # so we only return the primary mode instead. + return build_bearer_security_scheme_object("Authorization", "Bearer") + + # This should work, but doesn’t for the reasons above: + # + # return [ + # build_bearer_security_scheme_object("Authorization", "Bearer"), + # { + # "type": "apiKey", + # "in": "query", + # "name": "access_token", + # } + # ] diff --git a/program/filters.py b/program/filters.py index e05e58d70c58d55ae4171ceb22055487f03bfe8d..8478b6f9baca2d7c92b9d4127f47c50b86df1d99 100644 --- a/program/filters.py +++ b/program/filters.py @@ -4,7 +4,6 @@ from django_filters import rest_framework as filters from django_filters import widgets from django import forms -from django.contrib.auth.models import User from django.db.models import Q, QuerySet from django.utils import timezone from program import models @@ -20,55 +19,59 @@ class StaticFilterHelpTextMixin: return _filter -class ModelMultipleChoiceFilter(filters.ModelMultipleChoiceFilter): +class IntegerInFilter(filters.BaseInFilter): + class QueryArrayWidget(widgets.QueryArrayWidget): + # see: https://github.com/carltongibson/django-filter/issues/1047 + def value_from_datadict(self, data, files, name): + new_data = {} + for key in data.keys(): + if len(data.getlist(key)) == 1 and "," in data[key]: + new_data[key] = data[key] + else: + new_data[key] = data.getlist(key) + return super().value_from_datadict(new_data, files, name) + + field_class = forms.IntegerField + def __init__(self, *args, **kwargs): - kwargs.setdefault("widget", widgets.CSVWidget()) - kwargs["lookup_expr"] = "in" + kwargs.setdefault("widget", self.QueryArrayWidget()) super().__init__(*args, **kwargs) - def get_filter_predicate(self, v): - # There is something wrong with using ModelMultipleChoiceFilter - # along the CSVWidget that causes lookups to fail. - # May be related to: https://github.com/carltongibson/django-filter/issues/1103 - return super().get_filter_predicate([v.pk]) - class ShowFilterSet(StaticFilterHelpTextMixin, filters.FilterSet): active = filters.BooleanFilter( field_name="is_active", method="filter_active", help_text=( - "Return only currently running shows if true or past or upcoming shows if false.", + "Return only currently running shows (with timeslots in the future) if true " + "or past or upcoming shows if false." ), ) - host = ModelMultipleChoiceFilter( - queryset=models.Host.objects.all(), + host = IntegerInFilter( field_name="hosts", help_text="Return only shows assigned to the given host(s).", ) # TODO: replace `musicfocus` with `music_focus` when dashboard is updated - musicfocus = ModelMultipleChoiceFilter( - queryset=models.MusicFocus.objects.all(), + musicfocus = IntegerInFilter( field_name="music_focus", help_text="Return only shows with given music focus(es).", ) - owner = ModelMultipleChoiceFilter( - queryset=User.objects.all(), + owner = IntegerInFilter( field_name="owners", help_text="Return only shows that belong to the given owner(s).", ) - category = ModelMultipleChoiceFilter( - queryset=models.Category.objects.all(), + category = IntegerInFilter( help_text="Return only shows of the given category or categories.", ) - language = ModelMultipleChoiceFilter( - queryset=models.Language.objects.all(), + language = IntegerInFilter( help_text="Return only shows of the given language(s).", ) - topic = ModelMultipleChoiceFilter( - queryset=models.Topic.objects.all(), + topic = IntegerInFilter( help_text="Return only shows of the given topic(s).", ) + type = IntegerInFilter( + help_text="Return only shows of a given type.", + ) public = filters.BooleanFilter( field_name="is_public", help_text="Return only shows that are public/non-public.", @@ -102,9 +105,6 @@ class ShowFilterSet(StaticFilterHelpTextMixin, filters.FilterSet): class Meta: model = models.Show - help_texts = { - "type": "Return only shows of a given type.", - } fields = [ "active", "category", @@ -182,16 +182,10 @@ class TimeSlotFilterSet(filters.FilterSet): queryset = self.filter_surrounding(queryset, "surrounding", timezone.now()) return queryset - class Meta: - model = models.TimeSlot - fields = [ - "order", - "start", - "end", - "surrounding", - ] + def get_form_class(self): + form_cls = super().get_form_class() - class form(forms.Form): + class TimeSlotFilterSetFormWithDefaults(form_cls): def clean_start(self): start = self.cleaned_data.get("start", None) return start or timezone.now().date() @@ -200,16 +194,31 @@ class TimeSlotFilterSet(filters.FilterSet): end = self.cleaned_data.get("end", None) return end or self.cleaned_data["start"] + datetime.timedelta(days=60) + # We only want defaults to apply in the context of the list action. + # When accessing individual timeslots we don’t want the queryset to be restricted + # to the default range of 60 days as get_object would yield a 404 otherwise. + if self.request.parser_context["view"].action == "list": + return TimeSlotFilterSetFormWithDefaults + else: + return form_cls + + class Meta: + model = models.TimeSlot + fields = [ + "order", + "start", + "end", + "surrounding", + ] + class NoteFilterSet(StaticFilterHelpTextMixin, filters.FilterSet): - ids = ModelMultipleChoiceFilter( + ids = IntegerInFilter( field_name="id", - queryset=models.Note.objects.all(), help_text="Return only notes matching the specified id(s).", ) - owner = ModelMultipleChoiceFilter( + owner = IntegerInFilter( field_name="show__owners", - queryset=models.User.objects.all(), help_text="Return only notes by show the specified owner(s): all notes the user may edit.", ) diff --git a/program/migrations/0018_auto_20220322_2113.py b/program/migrations/0018_auto_20220322_2113.py new file mode 100644 index 0000000000000000000000000000000000000000..70dc947b08dc8984f8bda61504aea8813dd009e2 --- /dev/null +++ b/program/migrations/0018_auto_20220322_2113.py @@ -0,0 +1,103 @@ +# Generated by Django 3.2.12 on 2022-03-22 20:13 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("program", "0017_auto_20220302_1711"), + ] + + operations = [ + migrations.AlterField( + model_name="schedule", + name="add_business_days_only", + field=models.BooleanField( + default=False, + help_text="Whether to add add_days_no but skipping the weekends. E.g. if weekday is Friday, the date returned will be the next Monday.", # noqa: E501 + ), + ), + migrations.AlterField( + model_name="schedule", + name="add_days_no", + field=models.IntegerField( + blank=True, + help_text="Add a number of days to the generated dates. This can be useful for repetitions, like 'On the following day'.", # noqa: E501 + null=True, + ), + ), + migrations.AlterField( + model_name="schedule", + name="by_weekday", + field=models.IntegerField( + choices=[ + (0, "Monday"), + (1, "Tuesday"), + (2, "Wednesday"), + (3, "Thursday"), + (4, "Friday"), + (5, "Saturday"), + (6, "Sunday"), + ], + help_text="Number of the Weekday.", + ), + ), + migrations.AlterField( + model_name="schedule", + name="default_playlist_id", + field=models.IntegerField( + blank=True, + help_text="A tank ID in case the timeslot's playlist_id is empty.", + null=True, + ), + ), + migrations.AlterField( + model_name="schedule", + name="end_time", + field=models.TimeField(help_text="End time of schedule."), + ), + migrations.AlterField( + model_name="schedule", + name="first_date", + field=models.DateField(help_text="Start date of schedule."), + ), + migrations.AlterField( + model_name="schedule", + name="is_repetition", + field=models.BooleanField( + default=False, help_text="Whether the schedule is a repetition." + ), + ), + migrations.AlterField( + model_name="schedule", + name="last_date", + field=models.DateField(help_text="End date of schedule."), + ), + migrations.AlterField( + model_name="schedule", + name="rrule", + field=models.ForeignKey( + help_text="\nA recurrence rule.\n\n* 1 = once,\n* 2 = daily,\n* 3 = business days,\n* 4 = weekly,\n* 5 = biweekly,\n* 6 = every four weeks,\n* 7 = every even calendar week (ISO 8601),\n* 8 = every odd calendar week (ISO 8601),\n* 9 = every 1st week of month,\n* 10 = every 2nd week of month,\n* 11 = every 3rd week of month,\n* 12 = every 4th week of month,\n* 13 = every 5th week of month\n", # noqa: E501 + on_delete=django.db.models.deletion.CASCADE, + related_name="schedules", + to="program.rrule", + ), + ), + migrations.AlterField( + model_name="schedule", + name="show", + field=models.ForeignKey( + help_text="Show the schedule belongs to.", + on_delete=django.db.models.deletion.CASCADE, + related_name="schedules", + to="program.show", + ), + ), + migrations.AlterField( + model_name="schedule", + name="start_time", + field=models.TimeField(help_text="Start time of schedule."), + ), + ] diff --git a/program/models.py b/program/models.py index 6a4762df388c594a6d4534e0e6e966b12234588f..cb5991e06d0e241d180eb5b86ef33de5dd62092f 100644 --- a/program/models.py +++ b/program/models.py @@ -19,9 +19,11 @@ # from datetime import datetime, time, timedelta +from textwrap import dedent from dateutil.relativedelta import relativedelta from dateutil.rrule import rrule +from rest_framework.exceptions import ValidationError from versatileimagefield.fields import PPOIField, VersatileImageField from django.contrib.auth.models import User @@ -224,17 +226,76 @@ class RRule(models.Model): class Schedule(models.Model): - rrule = models.ForeignKey(RRule, on_delete=models.CASCADE, related_name="schedules") - show = models.ForeignKey(Show, on_delete=models.CASCADE, related_name="schedules") - by_weekday = models.IntegerField() - first_date = models.DateField() - start_time = models.TimeField() - end_time = models.TimeField() - last_date = models.DateField() - is_repetition = models.BooleanField(default=False) - add_days_no = models.IntegerField(blank=True, null=True) - add_business_days_only = models.BooleanField(default=False) - default_playlist_id = models.IntegerField(blank=True, null=True) + rrule = models.ForeignKey( + RRule, + on_delete=models.CASCADE, + related_name="schedules", + help_text=dedent( + """ + A recurrence rule. + + * 1 = once, + * 2 = daily, + * 3 = business days, + * 4 = weekly, + * 5 = biweekly, + * 6 = every four weeks, + * 7 = every even calendar week (ISO 8601), + * 8 = every odd calendar week (ISO 8601), + * 9 = every 1st week of month, + * 10 = every 2nd week of month, + * 11 = every 3rd week of month, + * 12 = every 4th week of month, + * 13 = every 5th week of month + """ + ), + ) + show = models.ForeignKey( + Show, + on_delete=models.CASCADE, + related_name="schedules", + help_text="Show the schedule belongs to.", + ) + by_weekday = models.IntegerField( + help_text="Number of the Weekday.", + choices=[ + (0, "Monday"), + (1, "Tuesday"), + (2, "Wednesday"), + (3, "Thursday"), + (4, "Friday"), + (5, "Saturday"), + (6, "Sunday"), + ], + ) + first_date = models.DateField(help_text="Start date of schedule.") + start_time = models.TimeField(help_text="Start time of schedule.") + end_time = models.TimeField(help_text="End time of schedule.") + last_date = models.DateField(help_text="End date of schedule.") + is_repetition = models.BooleanField( + default=False, + help_text="Whether the schedule is a repetition.", + ) + add_days_no = models.IntegerField( + blank=True, + null=True, + help_text=( + "Add a number of days to the generated dates. " + "This can be useful for repetitions, like 'On the following day'." + ), + ) + add_business_days_only = models.BooleanField( + default=False, + help_text=( + "Whether to add add_days_no but skipping the weekends. " + "E.g. if weekday is Friday, the date returned will be the next Monday." + ), + ) + default_playlist_id = models.IntegerField( + blank=True, + null=True, + help_text="A tank ID in case the timeslot's playlist_id is empty.", + ) class Meta: ordering = ("first_date", "start_time") @@ -525,10 +586,10 @@ class Schedule(models.Model): # Get note try: - note = Note.objects.get(timeslot=c.id).values_list("id", flat=True) - collision["note_id"] = note + note = Note.objects.get(timeslot=c.id) + collision["note_id"] = note.pk except ObjectDoesNotExist: - pass + collision["note_id"] = None collisions.append(collision) @@ -606,6 +667,7 @@ class Schedule(models.Model): projected_entry["collisions"] = collisions projected_entry["solution_choices"] = solution_choices + projected_entry["error"] = None projected.append(projected_entry) conflicts["projected"] = projected @@ -673,17 +735,26 @@ class Schedule(models.Model): conflicts = Schedule.make_conflicts(sdl, schedule_pk, show_pk) if schedule.rrule.freq > 0 and schedule.first_date == schedule.last_date: - return {"detail": _("Start and until dates mustn't be the same")} + raise ValidationError( + _("Start and until dates mustn't be the same"), + code="no-same-day-start-and-end", + ) if schedule.last_date < schedule.first_date: - return {"detail": _("Until date mustn't be before start")} + raise ValidationError( + _("Until date mustn't be before start"), + code="no-start-after-end", + ) num_conflicts = len( [pr for pr in conflicts["projected"] if len(pr["collisions"]) > 0] ) if len(solutions) != num_conflicts: - return {"detail": _("Numbers of conflicts and solutions don't match.")} + raise ValidationError( + _("Numbers of conflicts and solutions don't match."), + code="one-solution-per-conflict", + ) # Projected timeslots to create create = [] diff --git a/program/serializers.py b/program/serializers.py index 7e9e8af40a3adee5c4a6bb2a12811d2c7fbe00d7..76a903cb11f29e09e4197458805371a47ca2dc36 100644 --- a/program/serializers.py +++ b/program/serializers.py @@ -20,6 +20,7 @@ from profile.models import Profile from profile.serializers import ProfileSerializer +from typing import List from rest_framework import serializers @@ -44,6 +45,37 @@ from program.models import ( from program.utils import get_audio_url from steering.settings import THUMBNAIL_SIZES +SOLUTION_CHOICES = { + "theirs": "Discard projected timeslot. Keep existing timeslot(s).", + "ours": "Create projected timeslot. Delete existing timeslot(s).", + "theirs-start": ( + "Keep existing timeslot. Create projected timeslot with start time of existing end." + ), + "ours-start": ( + "Create projected timeslot. Change end of existing timeslot to projected start time." + ), + "theirs-end": ( + "Keep existing timeslot. Create projected timeslot with end of existing start time." + ), + "ours-end": ( + "Create projected timeslot. Change start of existing timeslot to projected end time." + ), + "theirs-both": ( + "Keep existing timeslot. " + "Create two projected timeslots with end of existing start and start of existing end." + ), + "ours-both": ( + "Create projected timeslot. Split existing timeslot into two: \n\n" + "* set existing end time to projected start,\n" + "* create another timeslot with start = projected end and end = existing end." + ), +} + + +class ErrorSerializer(serializers.Serializer): + message = serializers.CharField() + code = serializers.CharField(allow_null=True) + class UserSerializer(serializers.ModelSerializer): # Add profile fields to JSON @@ -139,10 +171,10 @@ class LinkSerializer(serializers.ModelSerializer): class HostSerializer(serializers.ModelSerializer): links = LinkSerializer(many=True, required=False) - thumbnails = serializers.SerializerMethodField() # Read-only + thumbnails = serializers.SerializerMethodField() @staticmethod - def get_thumbnails(host): + def get_thumbnails(host) -> List[str]: """Returns thumbnails""" thumbnails = [] @@ -261,10 +293,10 @@ class ShowSerializer(serializers.HyperlinkedModelSerializer): predecessor = serializers.PrimaryKeyRelatedField( queryset=Show.objects.all(), required=False, allow_null=True ) - thumbnails = serializers.SerializerMethodField() # Read-only + thumbnails = serializers.SerializerMethodField() @staticmethod - def get_thumbnails(show): + def get_thumbnails(show) -> List[str]: """Returns thumbnails""" thumbnails = [] @@ -372,14 +404,44 @@ class ShowSerializer(serializers.HyperlinkedModelSerializer): class ScheduleSerializer(serializers.ModelSerializer): - rrule = serializers.PrimaryKeyRelatedField(queryset=RRule.objects.all()) - show = serializers.PrimaryKeyRelatedField(queryset=Show.objects.all()) + rrule = serializers.PrimaryKeyRelatedField( + queryset=RRule.objects.all(), + help_text=Schedule.rrule.field.help_text, + ) + show = serializers.PrimaryKeyRelatedField( + queryset=Show.objects.all(), + help_text=Schedule.show.field.help_text, + ) # TODO: remove this when the dashboard is updated - byweekday = serializers.IntegerField(source="by_weekday") - dstart = serializers.DateField(source="first_date") - tstart = serializers.TimeField(source="start_time") - tend = serializers.TimeField(source="end_time") - until = serializers.DateField(source="last_date") + byweekday = serializers.IntegerField( + source="by_weekday", + help_text=Schedule.by_weekday.field.help_text, + ) + dstart = serializers.DateField( + source="first_date", + help_text=Schedule.first_date.field.help_text, + ) + tstart = serializers.TimeField( + source="start_time", + help_text=Schedule.start_time.field.help_text, + ) + tend = serializers.TimeField( + source="end_time", + help_text=Schedule.end_time.field.help_text, + ) + until = serializers.DateField( + source="last_date", + help_text=Schedule.last_date.field.help_text, + ) + dryrun = serializers.BooleanField( + write_only=True, + required=False, + help_text=( + "Whether to simulate the database changes. If true, no database changes will occur. " + "Instead a list of objects that would be created, updated and deleted if dryrun was " + "false will be returned." + ), + ) class Meta: model = Schedule @@ -423,6 +485,88 @@ class ScheduleSerializer(serializers.ModelSerializer): return instance +class CollisionSerializer(serializers.Serializer): + id = serializers.IntegerField() + start = serializers.DateTimeField() + end = serializers.DateTimeField() + playlist_id = serializers.IntegerField(allow_null=True) + show = serializers.IntegerField() + show_name = serializers.CharField() + is_repetition = serializers.BooleanField() + schedule = serializers.IntegerField() + memo = serializers.CharField() + note_id = serializers.IntegerField(allow_null=True) + + +class ProjectedTimeSlotSerializer(serializers.Serializer): + hash = serializers.CharField() + start = serializers.DateTimeField() + end = serializers.DateTimeField() + collisions = CollisionSerializer(many=True) + error = serializers.CharField(allow_null=True) + solution_choices = serializers.ListField( + child=serializers.ChoiceField(SOLUTION_CHOICES) + ) + + +class DryRunTimeSlotSerializer(serializers.Serializer): + id = serializers.PrimaryKeyRelatedField( + queryset=TimeSlot.objects.all(), allow_null=True + ) + schedule = serializers.PrimaryKeyRelatedField( + queryset=Schedule.objects.all(), allow_null=True + ) + playlist_id = serializers.IntegerField(allow_null=True) + start = serializers.DateField() + end = serializers.DateField() + is_repetition = serializers.BooleanField() + memo = serializers.CharField() + + +class ScheduleCreateUpdateRequestSerializer(serializers.Serializer): + schedule = ScheduleSerializer() + solutions = serializers.DictField(child=serializers.ChoiceField(SOLUTION_CHOICES)) + notes = serializers.DictField(child=serializers.IntegerField(), required=False) + playlists = serializers.DictField(child=serializers.IntegerField(), required=False) + + +# TODO: There shouldn’t be a separate ScheduleSerializer for use in responses. +# Instead the default serializer should be used. Unfortunately, the +# code that generates the data creates custom dicts with this particular format. +class ScheduleInResponseSerializer(serializers.Serializer): + # "Schedule schema type" is the rendered name of the ScheduleSerializer. + """ + For documentation on the individual fields see the + Schedule schema type. + """ + add_business_days_only = serializers.BooleanField() + add_days_no = serializers.IntegerField(allow_null=True) + by_weekday = serializers.IntegerField() + default_playlist_id = serializers.IntegerField(allow_null=True) + end_time = serializers.TimeField() + first_date = serializers.DateField() + id = serializers.PrimaryKeyRelatedField(queryset=Schedule.objects.all()) + is_repetition = serializers.BooleanField() + last_date = serializers.DateField() + rrule = serializers.PrimaryKeyRelatedField(queryset=RRule.objects.all()) + show = serializers.PrimaryKeyRelatedField(queryset=Note.objects.all()) + start_time = serializers.TimeField() + + +class ScheduleConflictResponseSerializer(serializers.Serializer): + projected = ProjectedTimeSlotSerializer(many=True) + solutions = serializers.DictField(child=serializers.ChoiceField(SOLUTION_CHOICES)) + notes = serializers.DictField(child=serializers.IntegerField()) + playlists = serializers.DictField(child=serializers.IntegerField()) + schedule = ScheduleInResponseSerializer() + + +class ScheduleDryRunResponseSerializer(serializers.Serializer): + created = DryRunTimeSlotSerializer(many=True) + updated = DryRunTimeSlotSerializer(many=True) + deleted = DryRunTimeSlotSerializer(many=True) + + class TimeSlotSerializer(serializers.ModelSerializer): show = serializers.PrimaryKeyRelatedField(queryset=Show.objects.all()) schedule = serializers.PrimaryKeyRelatedField(queryset=Schedule.objects.all()) @@ -452,11 +596,10 @@ class NoteSerializer(serializers.ModelSerializer): show = serializers.PrimaryKeyRelatedField(queryset=Show.objects.all()) timeslot = serializers.PrimaryKeyRelatedField(queryset=TimeSlot.objects.all()) host = serializers.PrimaryKeyRelatedField(queryset=Host.objects.all()) - thumbnails = serializers.SerializerMethodField() # Read-only - cba_id = serializers.IntegerField(required=False, write_only=True) + thumbnails = serializers.SerializerMethodField() @staticmethod - def get_thumbnails(note): + def get_thumbnails(note) -> List[str]: """Returns thumbnails""" thumbnails = [] diff --git a/program/utils.py b/program/utils.py index 53b6420d1b95bc99e98df51ac8e905e0d073eeff..d21eaaf81ea9f1e66d2b4551d59a190fdd1cc262 100644 --- a/program/utils.py +++ b/program/utils.py @@ -23,6 +23,7 @@ from datetime import date, datetime, time from typing import Dict, Optional, Tuple, Union import requests +from rest_framework import exceptions from django.utils import timezone from steering.settings import CBA_AJAX_URL, CBA_API_KEY, DEBUG @@ -109,14 +110,39 @@ def get_values( return int_if_digit(values[0]) -def get_pk_and_slug(kwargs: Dict[str, str]) -> Tuple[Optional[int], Optional[str]]: - """Get the pk and the slug from the kwargs.""" +class DisabledObjectPermissionCheckMixin: + """ + At the time of writing permission checks were entirely circumvented by manual + queries in viewsets. To make code refactoring easier and allow + the paced introduction of .get_object() in viewsets, object permission checks + need to be disabled until permission checks have been refactored as well. - pk, slug = None, None + Object permissions checks should become mandatory once proper permission_classes + are assigned to viewsets. This mixin should be removed afterwards. + """ - try: - pk = int(kwargs["pk"]) - except ValueError: - slug = kwargs["pk"] + # The text above becomes the viewset’s doc string otherwise and is displayed in + # the generated OpenAPI schema. + __doc__ = None + + def check_object_permissions(self, request, obj): + pass + + +class NestedObjectFinderMixin: + ROUTE_FILTER_LOOKUPS = {} + + def _get_route_filters(self) -> Dict[str, int]: + filter_kwargs = {} + for key, value in self.kwargs.items(): + if key in self.ROUTE_FILTER_LOOKUPS: + try: + filter_kwargs[self.ROUTE_FILTER_LOOKUPS[key]] = int(value) + except ValueError: + raise exceptions.ValidationError( + detail=f"{key} must map to an integer value.", code="invalid-pk" + ) + return filter_kwargs - return pk, slug + def get_queryset(self): + return super().get_queryset().filter(**self._get_route_filters()) diff --git a/program/views.py b/program/views.py index f28c0c3ef6303449a86993ff71247a0d19acfc01..e0f7b4d893f9c8d1ec21198f1b0b2f60441df75a 100644 --- a/program/views.py +++ b/program/views.py @@ -21,7 +21,9 @@ import json import logging from datetime import date, datetime, time +from textwrap import dedent +from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view from rest_framework import mixins, permissions, status, viewsets from rest_framework.pagination import LimitOffsetPagination from rest_framework.response import Response @@ -47,11 +49,15 @@ from program.models import ( ) from program.serializers import ( CategorySerializer, + ErrorSerializer, FundingCategorySerializer, HostSerializer, LanguageSerializer, MusicFocusSerializer, NoteSerializer, + ScheduleConflictResponseSerializer, + ScheduleCreateUpdateRequestSerializer, + ScheduleDryRunResponseSerializer, ScheduleSerializer, ShowSerializer, TimeSlotSerializer, @@ -59,7 +65,12 @@ from program.serializers import ( TypeSerializer, UserSerializer, ) -from program.utils import get_pk_and_slug, get_values, parse_date +from program.utils import ( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, + get_values, + parse_date, +) logger = logging.getLogger(__name__) @@ -187,19 +198,36 @@ def json_playout(request): ) +@extend_schema_view( + create=extend_schema(summary="Create a new user."), + retrieve=extend_schema( + summary="Retrieve a single user.", + description="Non-admin users may only retrieve their own user record.", + ), + update=extend_schema( + summary="Update an existing user.", + description="Non-admin users may only update their own user record.", + ), + partial_update=extend_schema( + summary="Partially update an existing user.", + description="Non-admin users may only update their own user record.", + ), + list=extend_schema( + summary="List all users.", + description=( + "The returned list of records will only contain a single record " + "for non-admin users which is their own user account." + ), + ), +) class APIUserViewSet( + DisabledObjectPermissionCheckMixin, mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet, ): - """ - Returns a list of users. - - Only returns the user that is currently authenticated unless the user is a superuser. - """ - permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] serializer_class = UserSerializer queryset = User.objects.all() @@ -213,23 +241,9 @@ class APIUserViewSet( return queryset - def retrieve(self, request, *args, **kwargs): - """Returns a single user.""" - pk = get_values(self.kwargs, "pk") - - # Common users only see themselves - if not request.user.is_superuser and pk != request.user.id: - return Response(status=status.HTTP_401_UNAUTHORIZED) - - user = get_object_or_404(User, pk=pk) - serializer = UserSerializer(user) - return Response(serializer.data) - def create(self, request, *args, **kwargs): """ - Create a User. - - Only superusers may create users. + Only admins may create users. """ if not request.user.is_superuser: @@ -243,51 +257,38 @@ class APIUserViewSet( return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def update(self, request, *args, **kwargs): - """ - Updates the user’s data. - - Non-superusers may not be able to edit all of the available data. - """ - pk = get_values(self.kwargs, "pk") - - serializer = UserSerializer(data=request.data) - # Common users may only edit themselves - if not request.user.is_superuser and pk != request.user.id: - return Response( - serializer.initial_data, status=status.HTTP_401_UNAUTHORIZED - ) - - user = get_object_or_404(User, pk=pk) - serializer = UserSerializer( - user, data=request.data, context={"user": request.user} - ) - - if serializer.is_valid(): - serializer.save() - return Response(serializer.data) - - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - -class APIShowViewSet(viewsets.ModelViewSet): - """ - Returns a list of available shows. - - Only superusers may add and delete shows. - """ +@extend_schema_view( + create=extend_schema(summary="Create a new show."), + retrieve=extend_schema(summary="Retrieve a single show."), + update=extend_schema(summary="Update an existing show."), + partial_update=extend_schema(summary="Partially update an existing show."), + destroy=extend_schema(summary="Delete an existing show."), + list=extend_schema(summary="List all shows."), +) +class APIShowViewSet(DisabledObjectPermissionCheckMixin, viewsets.ModelViewSet): queryset = Show.objects.all() serializer_class = ShowSerializer permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] pagination_class = LimitOffsetPagination filterset_class = filters.ShowFilterSet + def get_object(self): + queryset = self.filter_queryset(self.get_queryset()) + lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field + lookup_arg = self.kwargs[lookup_url_kwarg] + # allow object retrieval through id or slug + try: + filter_kwargs = {self.lookup_field: int(lookup_arg)} + except ValueError: + filter_kwargs = {"slug": lookup_arg} + obj = get_object_or_404(queryset, **filter_kwargs) + self.check_object_permissions(self.request, obj) + return obj + def create(self, request, *args, **kwargs): """ - Create a show. - - Only superusers may create a show. + Only admins may create a show. """ if not request.user.is_superuser: @@ -301,28 +302,9 @@ class APIShowViewSet(viewsets.ModelViewSet): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def retrieve(self, request, *args, **kwargs): - """Returns a single show""" - - pk, slug = get_pk_and_slug(self.kwargs) - - show = ( - get_object_or_404(Show, pk=pk) - if pk - else get_object_or_404(Show, slug=slug) - if slug - else None - ) - - serializer = ShowSerializer(show) - - return Response(serializer.data) - def update(self, request, *args, **kwargs): """ - Update a show. - - Common users may only update shows they own. + Non-admin users may only update shows they own. """ pk = get_values(self.kwargs, "pk") @@ -332,7 +314,7 @@ class APIShowViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - show = get_object_or_404(Show, pk=pk) + show = self.get_object() serializer = ShowSerializer( show, data=request.data, context={"user": request.user} ) @@ -348,60 +330,139 @@ class APIShowViewSet(viewsets.ModelViewSet): def destroy(self, request, *args, **kwargs): """ - Delete a show. - - Only superusers may delete shows. + Only admins may delete shows. """ if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk = get_values(self.kwargs, "pk") - - Show.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) -class APIScheduleViewSet(viewsets.ModelViewSet): - """ - Returns a list of schedules. - - Only superusers may create and update schedules. - """ +@extend_schema_view( + create=extend_schema( + summary="Create a new schedule.", + responses={ + status.HTTP_201_CREATED: OpenApiResponse( + response=ScheduleConflictResponseSerializer, + description=( + "Signals the successful creation of the schedule and of the projected " + "timeslots." + ), + ), + status.HTTP_202_ACCEPTED: OpenApiResponse( + response=ScheduleDryRunResponseSerializer, + description=( + "Returns the list of timeslots that would be created, updated and deleted if " + "the schedule request would not have been sent with the dryrun flag." + ), + ), + status.HTTP_400_BAD_REQUEST: OpenApiResponse( + response=ErrorSerializer(many=True), + description=dedent( + """ + Returned in case the request contained invalid data. + + This may happen if: + * the until date is before the start date (`no-start-after-end`), + in which case you should correct either the start or until date. + * The start and until date are the same (`no-same-day-start-and-end`). + This is only allowed for single timeslots with the recurrence rule + set to `once`. You should fix either the start or until date. + * The number of conflicts and solutions aren’t the same + (`one-solution-per-conflict`). Only one solution is allowed per conflict, + so you either offered too many or not enough solutions for any reported + conflicts. + """ + ), + ), + status.HTTP_403_FORBIDDEN: OpenApiResponse( + response=ErrorSerializer, + description=( + "Returned in case the request contained no or invalid authenticated data " + "or the authenticated user does not have authorization to perform the " + "requested operation." + ), + ), + status.HTTP_409_CONFLICT: OpenApiResponse( + response=ScheduleConflictResponseSerializer, + description=dedent( + """ + Returns the list of projected timeslots and any collisions that may have + been found for existing timeslots. + + Errors on projected timeslots may include: + * 'This change on the timeslot is not allowed.' + When adding: There was a change in the schedule's data during conflict + resolution. + When updating: Fields 'start', 'end', 'byweekday' or 'rrule' have changed, + which is not allowed. + * 'No solution given': No solution was provided for the conflict in + `solutions`. Provide a value of `solution_choices`. + * 'Given solution is not accepted for this conflict.': + The solution has a value which is not part of `solution_choices`. + Provide a value of `solution_choices` (at least `ours` or `theirs`). + """ + ), + ), + }, + ), + retrieve=extend_schema(summary="Retrieve a single schedule."), + update=extend_schema(summary="Update an existing schedule."), + partial_update=extend_schema(summary="Partially update an existing schedule."), + destroy=extend_schema(summary="Delete an existing schedule."), + list=extend_schema(summary="List all schedules."), +) +class APIScheduleViewSet( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, + viewsets.ModelViewSet, +): + ROUTE_FILTER_LOOKUPS = { + "show_pk": "show", + } queryset = Schedule.objects.all() serializer_class = ScheduleSerializer permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] - def get_queryset(self): - queryset = super().get_queryset() + def get_serializer_class(self): + if self.action in ("create", "update", "partial_update"): + return ScheduleCreateUpdateRequestSerializer + return super().get_serializer_class() - # subroute filters - show_pk = get_values(self.kwargs, "show_pk") - if show_pk: - queryset = queryset.filter(show=show_pk) + def create(self, request, *args, **kwargs): + """ + Create a schedule, generate timeslots, test for collisions and resolve them + (including notes). - return queryset + Note that creating or updating a schedule is the only way to create timeslots. - def retrieve(self, request, *args, **kwargs): - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") + Only admins may add schedules. - schedule = ( - get_object_or_404(Schedule, pk=pk, show=show_pk) - if show_pk - else get_object_or_404(Schedule, pk=pk) - ) + The projected timeslots defined by the schedule are matched against existing + timeslots. The API will return an object that contains - serializer = ScheduleSerializer(schedule) + * the schedule's data, + * projected timeslots, + * detected collisions, + * and possible solutions. - return Response(serializer.data) + As long as no `solutions` object has been set or unresolved collisions exist, + no data is written to the database. A schedule is only created if at least + one timeslot was generated by it. - def create(self, request, *args, **kwargs): - """ - Create a schedule, generate timeslots, test for collisions and resolve them including notes + In order to resolve any possible conflicts, the client must submit a new request with + a solution for each conflict. Possible solutions are listed as part of the projected + timeslot in the `solution_choices` array. In a best-case scenario with no detected + conflicts an empty solutions object will suffice. For more details on the individual + types of solutions see the SolutionChoicesEnum. - Only superusers may add schedules. + **Please note**: + If there's more than one collision for a projected timeslot, only `theirs` and `ours` + are currently supported as solutions. """ if not request.user.is_superuser: @@ -425,6 +486,10 @@ class APIScheduleViewSet(viewsets.ModelViewSet): # Otherwise try to resolve resolution = Schedule.resolve_conflicts(request.data, pk, show_pk) + if all(key in resolution for key in ["create", "update", "delete"]): + # this is a dry-run + return Response(resolution, status=status.HTTP_202_ACCEPTED) + # If resolution went well if "projected" not in resolution: return Response(resolution, status=status.HTTP_201_CREATED) @@ -438,24 +503,22 @@ class APIScheduleViewSet(viewsets.ModelViewSet): Update a schedule, generate timeslots, test for collisions and resolve them including notes. - Only superusers may update schedules. + Only admins may update schedules. """ if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - # Only allow updating when calling /shows/{show_pk}/schedules/{pk}/ and with the `schedule` - # JSON object - if show_pk is None or "schedule" not in request.data: + # Only allow updating when with the `schedule` JSON object + if "schedule" not in request.data: return Response(status=status.HTTP_400_BAD_REQUEST) + schedule = self.get_object() + # If default playlist id or repetition are given, just update if default_playlist_id := request.data.get("schedule").get( "default_playlist_id" ): - schedule = get_object_or_404(Schedule, pk=pk, show=show_pk) schedule.default_playlist_id = int(default_playlist_id) schedule.save() @@ -463,7 +526,6 @@ class APIScheduleViewSet(viewsets.ModelViewSet): return Response(serializer.data) if is_repetition := request.data.get("schedule").get("is_repetition"): - schedule = get_object_or_404(Schedule, pk=pk, show=show_pk) schedule.is_repetition = bool(is_repetition) schedule.save() @@ -474,11 +536,15 @@ class APIScheduleViewSet(viewsets.ModelViewSet): if "solutions" not in request.data: # TODO: respond with status.HTTP_409_CONFLICT when the dashboard can handle it return Response( - Schedule.make_conflicts(request.data["schedule"], pk, show_pk) + Schedule.make_conflicts( + request.data["schedule"], schedule.pk, schedule.show.pk + ) ) # Otherwise try to resolve - resolution = Schedule.resolve_conflicts(request.data, pk, show_pk) + resolution = Schedule.resolve_conflicts( + request.data, schedule.pk, schedule.show.pk + ) # If resolution went well if "projected" not in resolution: @@ -490,21 +556,13 @@ class APIScheduleViewSet(viewsets.ModelViewSet): def destroy(self, request, *args, **kwargs): """ - Delete a schedule. - - Only superusers may delete schedules. + Only admins may delete schedules. """ if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - # Only allow deleting when calling /shows/{show_pk}/schedules/{pk} - if show_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - Schedule.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) @@ -512,21 +570,34 @@ class APIScheduleViewSet(viewsets.ModelViewSet): # TODO: Create is currently not implemented because timeslots are supposed to be inserted # by creating or updating a schedule. # There might be a use case for adding a single timeslot without any conflicts though. +@extend_schema_view( + retrieve=extend_schema(summary="Retrieve a single timeslot."), + update=extend_schema(summary="Update an existing timeslot."), + partial_update=extend_schema(summary="Partially update an existing timeslot."), + destroy=extend_schema(summary="Delete an existing timeslot."), + list=extend_schema( + summary="List all timeslots.", + description=dedent( + """ + By default, only timeslots ranging from now + 60 days will be displayed. + You may override this default overriding start and/or end parameter. + """ + ), + ), +) class APITimeSlotViewSet( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet, ): - """ - Returns a list of timeslots. - - By default, only timeslots ranging from now + 60 days will be displayed. - You may override this default overriding start and/or end parameter. - - Timeslots may only be added by creating/updating a schedule. - """ + ROUTE_FILTER_LOOKUPS = { + "show_pk": "show", + "schedule_pk": "schedule", + } permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] serializer_class = TimeSlotSerializer @@ -534,35 +605,8 @@ class APITimeSlotViewSet( queryset = TimeSlot.objects.all().order_by("-start") filterset_class = filters.TimeSlotFilterSet - def get_queryset(self): - queryset = super().get_queryset() - - # subroute filters - show_pk, schedule_pk = get_values(self.kwargs, "show_pk", "schedule_pk") - if show_pk: - queryset = queryset.filter(show=show_pk) - if schedule_pk: - queryset = queryset.filter(schedule=schedule_pk) - - return queryset - - def retrieve(self, request, *args, **kwargs): - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - if show_pk: - timeslot = get_object_or_404(TimeSlot, pk=pk, show=show_pk) - else: - timeslot = get_object_or_404(TimeSlot, pk=pk) - - serializer = TimeSlotSerializer(timeslot) - return Response(serializer.data) - def update(self, request, *args, **kwargs): - """Link a playlist_id to a timeslot""" - - pk, show_pk, schedule_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk" - ) + show_pk = get_values(self.kwargs, "show_pk") if ( not request.user.is_superuser @@ -570,15 +614,7 @@ class APITimeSlotViewSet( ): return Response(status=status.HTTP_401_UNAUTHORIZED) - # Update is only allowed when calling /shows/1/schedules/1/timeslots/1 and if user owns the - # show - if schedule_pk is None or show_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - timeslot = get_object_or_404( - TimeSlot, pk=pk, schedule=schedule_pk, show=show_pk - ) - + timeslot = self.get_object() serializer = TimeSlotSerializer(timeslot, data=request.data) if serializer.is_valid(): serializer.save() @@ -598,31 +634,37 @@ class APITimeSlotViewSet( def destroy(self, request, *args, **kwargs): """ - Deletes a timeslot. - - Only superusers may delete timeslots. + Only admins may delete timeslots. """ if not request.user.is_superuser: return Response(status=status.HTTP_401_UNAUTHORIZED) - pk, show_pk = get_values(self.kwargs, "pk", "show_pk") - - # Only allow when calling endpoint starting with /shows/1/... - if show_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - TimeSlot.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) -class APINoteViewSet(viewsets.ModelViewSet): - """ - Returns a list of notes. - - Superusers may access and update all notes. - """ +@extend_schema_view( + create=extend_schema(summary="Create a new note."), + retrieve=extend_schema(summary="Retrieve a single note."), + update=extend_schema(summary="Update an existing note."), + partial_update=extend_schema( + summary="Partially update an existing note.", + description="Only admins can partially update existing notes.", + ), + destroy=extend_schema(summary="Delete an existing note."), + list=extend_schema(summary="List all notes."), +) +class APINoteViewSet( + DisabledObjectPermissionCheckMixin, + NestedObjectFinderMixin, + viewsets.ModelViewSet, +): + ROUTE_FILTER_LOOKUPS = { + "show_pk": "show", + "timeslot_pk": "timeslot", + } queryset = Note.objects.all() serializer_class = NoteSerializer @@ -630,24 +672,11 @@ class APINoteViewSet(viewsets.ModelViewSet): pagination_class = LimitOffsetPagination filter_class = filters.NoteFilterSet - def get_queryset(self): - queryset = super().get_queryset() - - # subroute filters - show_pk, timeslot_pk = get_values(self.kwargs, "show_pk", "timeslot_pk") - if show_pk: - queryset = queryset.filter(show=show_pk) - if timeslot_pk: - queryset = queryset.filter(timeslot=timeslot_pk) - - return queryset - def create(self, request, *args, **kwargs): - """Create a note""" - - show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "show_pk", "schedule_pk", "timeslot_pk" - ) + """ + Only admins can create new notes. + """ + show_pk, timeslot_pk = get_values(self.kwargs, "show_pk", "timeslot_pk") if ( not request.user.is_superuser @@ -655,10 +684,6 @@ class APINoteViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - # Only create a note if show_id, timeslot_id and schedule_id is given - if show_pk is None or schedule_pk is None or timeslot_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - serializer = NoteSerializer( data={"show": show_pk, "timeslot": timeslot_pk} | request.data, context={"user_id": request.user.id}, @@ -676,52 +701,11 @@ class APINoteViewSet(viewsets.ModelViewSet): return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def retrieve(self, request, *args, **kwargs): + def update(self, request, *args, **kwargs): """ - Returns a single note - - Called by: - /notes/1 - /shows/1/notes/1 - /shows/1/timeslots/1/note/1 - /shows/1/schedules/1/timeslots/1/note/1 + Only admins can update existing notes. """ - pk, show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk", "timeslot_pk" - ) - - # - # /shows/1/notes/1 - # - # Returns a note to a show - # - if show_pk and timeslot_pk is None and schedule_pk is None: - note = get_object_or_404(Note, pk=pk, show=show_pk) - - # - # /shows/1/timeslots/1/note/1 - # /shows/1/schedules/1/timeslots/1/note/1 - # - # Return a note to a timeslot - # - elif show_pk and timeslot_pk: - note = get_object_or_404(Note, pk=pk, show=show_pk, timeslot=timeslot_pk) - - # - # /notes/1 - # - # Returns the given note - # - else: - note = get_object_or_404(Note, pk=pk) - - serializer = NoteSerializer(note) - return Response(serializer.data) - - def update(self, request, *args, **kwargs): - pk, show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk", "timeslot_pk" - ) + show_pk = get_values(self.kwargs, "show_pk") if ( not request.user.is_superuser @@ -729,13 +713,7 @@ class APINoteViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - # Allow PUT only when calling - # /shows/{show_pk}/schedules/{schedule_pk}/timeslots/{timeslot_pk}/note/{pk} - if show_pk is None or schedule_pk is None or timeslot_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - note = get_object_or_404(Note, pk=pk, timeslot=timeslot_pk, show=show_pk) - + note = self.get_object() serializer = NoteSerializer(note, data=request.data) if serializer.is_valid(): @@ -754,9 +732,10 @@ class APINoteViewSet(viewsets.ModelViewSet): return Response(status=status.HTTP_400_BAD_REQUEST) def destroy(self, request, *args, **kwargs): - pk, show_pk, schedule_pk, timeslot_pk = get_values( - self.kwargs, "pk", "show_pk", "schedule_pk", "timeslot_pk" - ) + """ + Only admins can delete existing notes. + """ + show_pk = get_values(self.kwargs, "show_pk") if ( not request.user.is_superuser @@ -764,10 +743,7 @@ class APINoteViewSet(viewsets.ModelViewSet): ): return Response(status=status.HTTP_401_UNAUTHORIZED) - if pk is None or show_pk is None or schedule_pk is None or timeslot_pk is None: - return Response(status=status.HTTP_400_BAD_REQUEST) - - Note.objects.get(pk=pk).delete() + self.get_object().delete() return Response(status=status.HTTP_204_NO_CONTENT) @@ -776,65 +752,95 @@ class ActiveFilterMixin: filter_class = filters.ActiveFilterSet +@extend_schema_view( + create=extend_schema(summary="Create a new category."), + retrieve=extend_schema(summary="Retrieve a single category."), + update=extend_schema(summary="Update an existing category."), + partial_update=extend_schema(summary="Partially update an existing category."), + destroy=extend_schema(summary="Delete an existing category."), + list=extend_schema(summary="List all categories."), +) class APICategoryViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of categories. - """ - queryset = Category.objects.all() serializer_class = CategorySerializer +@extend_schema_view( + create=extend_schema(summary="Create a new type."), + retrieve=extend_schema(summary="Retrieve a single type."), + update=extend_schema(summary="Update an existing type."), + partial_update=extend_schema(summary="Partially update an existing type."), + destroy=extend_schema(summary="Delete an existing type."), + list=extend_schema(summary="List all types."), +) class APITypeViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of types. - """ - queryset = Type.objects.all() serializer_class = TypeSerializer +@extend_schema_view( + create=extend_schema(summary="Create a new topic."), + retrieve=extend_schema(summary="Retrieve a single topic."), + update=extend_schema(summary="Update an existing topic."), + partial_update=extend_schema(summary="Partially update an existing topic."), + destroy=extend_schema(summary="Delete an existing topic."), + list=extend_schema(summary="List all topics."), +) class APITopicViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of topics. - """ - queryset = Topic.objects.all() serializer_class = TopicSerializer +@extend_schema_view( + create=extend_schema(summary="Create a new music focus."), + retrieve=extend_schema(summary="Retrieve a single music focus."), + update=extend_schema(summary="Update an existing music focus."), + partial_update=extend_schema(summary="Partially update an existing music focus."), + destroy=extend_schema(summary="Delete an existing music focus."), + list=extend_schema(summary="List all music focuses."), +) class APIMusicFocusViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of music focuses. - """ - queryset = MusicFocus.objects.all() serializer_class = MusicFocusSerializer +@extend_schema_view( + create=extend_schema(summary="Create a new funding category."), + retrieve=extend_schema(summary="Retrieve a single funding category."), + update=extend_schema(summary="Update an existing funding category."), + partial_update=extend_schema( + summary="Partially update an existing funding category." + ), + destroy=extend_schema(summary="Delete an existing funding category."), + list=extend_schema(summary="List all funding categories."), +) class APIFundingCategoryViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of funding categories. - """ - queryset = FundingCategory.objects.all() serializer_class = FundingCategorySerializer +@extend_schema_view( + create=extend_schema(summary="Create a new language."), + retrieve=extend_schema(summary="Retrieve a single language."), + update=extend_schema(summary="Update an existing language."), + partial_update=extend_schema(summary="Partially update an existing language."), + destroy=extend_schema(summary="Delete an existing language."), + list=extend_schema(summary="List all languages."), +) class APILanguageViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of languages. - """ - queryset = Language.objects.all() serializer_class = LanguageSerializer +@extend_schema_view( + create=extend_schema(summary="Create a new host."), + retrieve=extend_schema(summary="Retrieve a single host."), + update=extend_schema(summary="Update an existing host."), + partial_update=extend_schema(summary="Partially update an existing host."), + destroy=extend_schema(summary="Delete an existing host."), + list=extend_schema(summary="List all hosts."), +) class APIHostViewSet(ActiveFilterMixin, viewsets.ModelViewSet): - """ - Returns a list of hosts. - """ - queryset = Host.objects.all() serializer_class = HostSerializer pagination_class = LimitOffsetPagination diff --git a/requirements.txt b/requirements.txt index 8ef0a542428224ef5041843867732511557ee527..74eb2c55e17ba9fcec003a91be6225b78a34300a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ django-filter==21.1 django-oidc-provider==0.7.0 django-versatileimagefield==2.2 djangorestframework==3.13.1 +drf_spectacular==0.21.2 drf-nested-routers==0.93.4 future==0.18.2 gunicorn==20.1.0 diff --git a/steering/schema.py b/steering/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..8505a547fe503fb0c2161cb8187eb3adef2e462a --- /dev/null +++ b/steering/schema.py @@ -0,0 +1,59 @@ +from typing import Iterable, Tuple + + +def _generate_choices_description(choices: Iterable[Tuple[str, str]]): + def _gen(): + for key, value in choices: + yield f"**{key}**: {value}\n\n" + + return "\n".join(_gen()).strip() + + +def add_enum_documentation(result, generator, request, public): + """ + Choice descriptions are available through the assigned choices values + but are not incorporated into the schema by default. + + This post-processing hook adds them to the appropriate objects. + """ + # TODO: The logic behind this might be a worthwhile addition to drf-spectacular. + from program.models import Schedule + from program.serializers import SOLUTION_CHOICES + + weekday_choices_desc = _generate_choices_description( + Schedule.by_weekday.field.choices + ) + solutions_choices_desc = _generate_choices_description(SOLUTION_CHOICES.items()) + schema = result["components"]["schemas"] + schema["ByWeekdayEnum"]["description"] = weekday_choices_desc + schema["SolutionChoicesEnum"]["description"] = solutions_choices_desc + for item in ["ScheduleCreateUpdateRequest", "PatchedScheduleCreateUpdateRequest"]: + solutions_props = schema[item]["properties"]["solutions"][ + "additionalProperties" + ] + solutions_props["description"] = solutions_choices_desc + return result + + +def fix_schedule_pk_type(result, generator, request, public): + """ + schedule_pk’s type cannot be deduced in note routes because the Note class + has no schedule field drf-spectacular can map it to. + + Normally we would define this by using @extend_schema on the note viewset, but + as the schedule_pk field does not exist for __all__ note routes this would + inadvertently add the field to routes that don’t even have the parameter like + /api/v1/notes/. + + So we patch the type of schedule_pk fields in post-processing and ignore + the warning until we find a better solution. + """ + for path, methods in result["paths"].items(): + if not ("{schedule_pk}" in path and "/note" in path): + continue + for method_name, method_def in methods.items(): + for parameter in method_def["parameters"]: + if parameter["in"] == "path" and parameter["name"] == "schedule_pk": + parameter["schema"]["type"] = "integer" + break + return result diff --git a/steering/settings.py b/steering/settings.py index 16eb0590c4253180e16d823442f84b957627526e..bbedf9d9ec39df40d716670b0bdffa7bbc3026dc 100644 --- a/steering/settings.py +++ b/steering/settings.py @@ -109,6 +109,19 @@ REST_FRAMEWORK = { "program.auth.OidcOauth2Auth", ], "DEFAULT_FILTER_BACKENDS": ["django_filters.rest_framework.DjangoFilterBackend"], + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", + "EXCEPTION_HANDLER": "steering.views.full_details_exception_handler", +} + +SPECTACULAR_SETTINGS = { + "TITLE": "AURA Steering API", + "DESCRIPTION": "Programme/schedule management for Aura", + "POSTPROCESSING_HOOKS": [ + "drf_spectacular.hooks.postprocess_schema_enums", + "steering.schema.add_enum_documentation", + "steering.schema.fix_schedule_pk_type", + ], + "VERSION": "1.0.0", } INSTALLED_APPS = ( @@ -125,6 +138,7 @@ INSTALLED_APPS = ( "rest_framework", "rest_framework_nested", "django_filters", + "drf_spectacular", "oidc_provider", "corsheaders", ) diff --git a/steering/urls.py b/steering/urls.py index fdad4b383cf77da35c0620733567525c9a9c14cb..b0e7edf8b1f9a99cd8e0a6c39fcd7b6f457b71ca 100644 --- a/steering/urls.py +++ b/steering/urls.py @@ -18,6 +18,7 @@ # along with this program. If not, see <http://www.gnu.org/licenses/>. # +from drf_spectacular.views import SpectacularAPIView, SpectacularSwaggerView from rest_framework_nested import routers from django.contrib import admin @@ -100,5 +101,11 @@ urlpatterns = [ path("api/v1/playout", json_playout), path("api/v1/program/week", json_playout), path("api/v1/program/<int:year>/<int:month>/<int:day>)/", json_day_schedule), + path("api/v1/schema/", SpectacularAPIView.as_view(), name="schema"), + path( + "api/v1/schema/swagger-ui/", + SpectacularSwaggerView.as_view(url_name="schema"), + name="swagger-ui", + ), path("admin/", admin.site.urls), ] diff --git a/steering/views.py b/steering/views.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb278d58db73c88f476616c19ad60ad54d12076 --- /dev/null +++ b/steering/views.py @@ -0,0 +1,8 @@ +from rest_framework.exceptions import APIException +from rest_framework.views import exception_handler + + +def full_details_exception_handler(exc, context): + if isinstance(exc, APIException): + exc.detail = exc.get_full_details() + return exception_handler(exc, context)