Skip to content
Snippets Groups Projects
Verified Commit 92651bc3 authored by Ernesto Rico Schmidt's avatar Ernesto Rico Schmidt
Browse files

feat: clean-up queryset usage, replace is_superuser checks with group membership checks

parent 770fae36
No related branches found
No related tags found
1 merge request!29Use docker main tag
Pipeline #7112 passed
...@@ -32,6 +32,7 @@ from rest_framework.exceptions import ValidationError ...@@ -32,6 +32,7 @@ from rest_framework.exceptions import ValidationError
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response from rest_framework.response import Response
from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import Http404, HttpResponse, JsonResponse from django.http import Http404, HttpResponse, JsonResponse
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
...@@ -260,18 +261,18 @@ class APIUserViewSet( ...@@ -260,18 +261,18 @@ class APIUserViewSet(
viewsets.GenericViewSet, viewsets.GenericViewSet,
): ):
serializer_class = UserSerializer serializer_class = UserSerializer
queryset = User.objects.all()
filter_backends = [drf_filters.SearchFilter] filter_backends = [drf_filters.SearchFilter]
search_fields = ["username", "first_name", "last_name", "email"] search_fields = ["username", "first_name", "last_name", "email"]
def get_queryset(self): def get_queryset(self):
queryset = super().get_queryset() """The queryset contains all the users only for privileged users."""
# Constrain access to oneself except for superusers. qs = User.objects.all()
if not self.request.user.is_superuser:
queryset = queryset.filter(pk=self.request.user.id)
return queryset if not self.request.user.groups.filter(name=settings.PRIVILEGED_GROUP).exists():
qs = qs.filter(pk=self.request.user.id)
return qs
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = UserSerializer( serializer = UserSerializer(
...@@ -301,7 +302,6 @@ class APIUserViewSet( ...@@ -301,7 +302,6 @@ class APIUserViewSet(
), ),
) )
class APIImageViewSet(viewsets.ModelViewSet): class APIImageViewSet(viewsets.ModelViewSet):
queryset = Image.objects.all()
serializer_class = ImageSerializer serializer_class = ImageSerializer
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
pagination_class = LimitOffsetPagination pagination_class = LimitOffsetPagination
...@@ -583,7 +583,7 @@ class APIScheduleViewSet( ...@@ -583,7 +583,7 @@ class APIScheduleViewSet(
them including notes. them including notes.
""" """
if not request.user.is_superuser: if not request.user.groups.filter(name=settings.PRIVILEGED_GROUP).exists():
return Response(status=status.HTTP_401_UNAUTHORIZED) return Response(status=status.HTTP_401_UNAUTHORIZED)
# Only allow updating when with the `schedule` JSON object # Only allow updating when with the `schedule` JSON object
...@@ -699,7 +699,6 @@ class APINoteViewSet( ...@@ -699,7 +699,6 @@ class APINoteViewSet(
filterset_class = filters.NoteFilterSet filterset_class = filters.NoteFilterSet
pagination_class = LimitOffsetPagination pagination_class = LimitOffsetPagination
permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly] permission_classes = [permissions.DjangoModelPermissionsOrAnonReadOnly]
queryset = Note.objects.all()
serializer_class = NoteSerializer serializer_class = NoteSerializer
def get_serializer_context(self): def get_serializer_context(self):
...@@ -710,13 +709,15 @@ class APINoteViewSet( ...@@ -710,13 +709,15 @@ class APINoteViewSet(
return context return context
def get_queryset(self): def get_queryset(self):
qs = super().get_queryset().order_by("slug") """The queryset contains all the notes if the method is safe, otherwise
# Users should always be able to see notes - if the user is not in the privileged group, the notes owned by the user are filtered."""
qs = Note.objects.all()
if self.request.method not in permissions.SAFE_METHODS: if self.request.method not in permissions.SAFE_METHODS:
# If the request is not by an admin, if not self.request.user.groups.filter(name=settings.PRIVILEGED_GROUP).exists():
# check that the timeslot is owned by the current user.
if not self.request.user.is_superuser:
qs = qs.filter(timeslot__schedule__show__owners=self.request.user) qs = qs.filter(timeslot__schedule__show__owners=self.request.user)
return qs return qs
def _get_timeslot(self): def _get_timeslot(self):
...@@ -728,7 +729,7 @@ class APINoteViewSet( ...@@ -728,7 +729,7 @@ class APINoteViewSet(
if timeslot_pk is None: if timeslot_pk is None:
raise ValidationError({"timeslot_id": [_("This field is required.")]}, code="required") raise ValidationError({"timeslot_id": [_("This field is required.")]}, code="required")
qs = TimeSlot.objects.all() qs = TimeSlot.objects.all()
if not self.request.user.is_superuser: if not self.request.user.groups.filter(name=settings.ENTITLED_GROUPS[0]):
qs = qs.filter(schedule__show__owners=self.request.user) qs = qs.filter(schedule__show__owners=self.request.user)
try: try:
return qs.get(pk=timeslot_pk) return qs.get(pk=timeslot_pk)
......
...@@ -49,12 +49,12 @@ from program.views import ( ...@@ -49,12 +49,12 @@ from program.views import (
admin.autodiscover() admin.autodiscover()
router = routers.DefaultRouter() router = routers.DefaultRouter()
router.register(r"users", APIUserViewSet) router.register(r"users", APIUserViewSet, basename="user")
router.register(r"hosts", APIHostViewSet) router.register(r"hosts", APIHostViewSet)
router.register(r"shows", APIShowViewSet) router.register(r"shows", APIShowViewSet)
router.register(r"schedules", APIScheduleViewSet) router.register(r"schedules", APIScheduleViewSet)
router.register(r"timeslots", APITimeSlotViewSet) router.register(r"timeslots", APITimeSlotViewSet)
router.register(r"notes", APINoteViewSet) router.register(r"notes", APINoteViewSet, basename="note")
router.register(r"categories", APICategoryViewSet) router.register(r"categories", APICategoryViewSet)
router.register(r"topics", APITopicViewSet) router.register(r"topics", APITopicViewSet)
router.register(r"types", APITypeViewSet) router.register(r"types", APITypeViewSet)
...@@ -64,7 +64,7 @@ router.register(r"languages", APILanguageViewSet) ...@@ -64,7 +64,7 @@ router.register(r"languages", APILanguageViewSet)
router.register(r"licenses", APILicenseViewSet) router.register(r"licenses", APILicenseViewSet)
router.register(r"link-types", APILinkTypeViewSet) router.register(r"link-types", APILinkTypeViewSet)
router.register(r"rrules", APIRRuleViewSet) router.register(r"rrules", APIRRuleViewSet)
router.register(r"images", APIImageViewSet) router.register(r"images", APIImageViewSet, basename="image")
# Nested Routers # Nested Routers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment