Skip to content

Commit

Permalink
A few cleanup fixes.
Browse files Browse the repository at this point in the history
  - Uses filterset_class correctly.
  - Applies the filter via MregMixin.
  - Test basic filtering in hostgroups.
  • Loading branch information
terjekv authored and Terje Kvernes committed Dec 5, 2023
1 parent aa9c83d commit 3d2a97f
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 26 deletions.
4 changes: 2 additions & 2 deletions hostpolicy/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class HostPolicyAtomList(HostPolicyAtomLogMixin, MregListCreateAPIView):
serializer_class = serializers.HostPolicyAtomSerializer
permission_classes = (IsSuperOrHostPolicyAdminOrReadOnly, )
lookup_field = 'name'
filter_class = HostPolicyRoleFilterSet
filterset_class = HostPolicyAtomFilterSet

def post(self, request, *args, **kwargs):
if "name" in request.data:
Expand Down Expand Up @@ -95,7 +95,7 @@ class HostPolicyRoleList(HostPolicyRoleLogMixin, MregListCreateAPIView):
serializer_class = serializers.HostPolicyRoleSerializer
permission_classes = (IsSuperOrHostPolicyAdminOrReadOnly, )
lookup_field = 'name'
filter_class = HostPolicyRoleFilterSet
filterset_class = HostPolicyRoleFilterSet

def post(self, request, *args, **kwargs):
if "name" in request.data:
Expand Down
4 changes: 2 additions & 2 deletions mreg/api/v1/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
)


class JSONFieldExactFilter(filters.BaseCSVFilter, filters.CharFilter):
class JSONFieldExactFilter(filters.CharFilter):
pass


class CIDRFieldExactFilter(filters.BaseCSVFilter, filters.CharFilter):
class CIDRFieldExactFilter(filters.CharFilter):
pass


Expand Down
13 changes: 9 additions & 4 deletions mreg/api/v1/tests/test_hostgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
class APIHostGroupsTestCase(MregAPITestCase):
"""This class defines the test suite for api/hostgroups"""

def _url_returns_count(self, url, count):
"""Check that the url returns exactly count results."""
response = self.assert_get(url)
data = response.json()
self.assertEqual(data['count'], count)
self.assertEqual(len(data['results']), count)

def setUp(self):
"""Define the test client and other test variables."""
super().setUp()
Expand All @@ -21,10 +28,8 @@ def test_hostgroups_get_200_ok(self):

def test_hostgroups_list_200_ok(self):
"""List all hosts should return 200"""
response = self.assert_get('/hostgroups/')
data = response.json()
self.assertEqual(data['count'], 3)
self.assertEqual(len(data['results']), 3)
self._url_returns_count('/hostgroups/', 3)
self._url_returns_count('/hostgroups/?name=testgroup1', 1)

def test_hostgroups_get_404_not_found(self):
""""Getting a non-existing entry should return 404"""
Expand Down
33 changes: 18 additions & 15 deletions mreg/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from django.db.models import Prefetch
from django.shortcuts import get_object_or_404

from django_filters import rest_framework as rest_filters

from rest_framework import filters, generics, status
from rest_framework.decorators import api_view
from rest_framework.exceptions import MethodNotAllowed, ParseError
Expand Down Expand Up @@ -68,6 +70,7 @@
class MregMixin:
filter_backends = (
filters.SearchFilter,
rest_filters.DjangoFilterBackend,
filters.OrderingFilter,
)
ordering_fields = "__all__"
Expand Down Expand Up @@ -206,7 +209,7 @@ class CnameList(HostPermissionsListCreateAPIView):
queryset = Cname.objects.all()
serializer_class = CnameSerializer
lookup_field = "name"
filter_class = CnameFilterSet
filterset_class = CnameFilterSet


class CnameDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -237,7 +240,7 @@ class HinfoList(HostPermissionsListCreateAPIView):

queryset = Hinfo.objects.all().order_by("host")
serializer_class = HinfoSerializer
filter_class = HinfoFilterSet
filterset_class = HinfoFilterSet


class HinfoDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -279,7 +282,7 @@ class HostList(HostPermissionsListCreateAPIView):
# so HostFilterSet would need to implement these changes.
# However, we also reuse _host_prefetcher in the HostDetail view below
# so this would all require a bit of careful refactoring...
# filter_class = HostFilterSet
# filterset_class = HostFilterSet

def get_queryset(self):
qs = _host_prefetcher(super().get_queryset())
Expand Down Expand Up @@ -380,7 +383,7 @@ def patch(self, request, *args, **kwargs):
class HistoryList(MregMixin, generics.ListAPIView):
queryset = History.objects.all().order_by('id')
serializer_class = HistorySerializer
filter_class = HistoryFilterSet
filterset_class = HistoryFilterSet


class HistoryDetail(MregMixin, generics.RetrieveAPIView):
Expand All @@ -400,7 +403,7 @@ class IpaddressList(HostPermissionsListCreateAPIView):

queryset = Ipaddress.objects.get_queryset().order_by("id")
serializer_class = IpaddressSerializer
filter_class = IpaddressFilterSet
filterset_class = IpaddressFilterSet


class IpaddressDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -430,7 +433,7 @@ class LocList(HostPermissionsListCreateAPIView):

queryset = Loc.objects.all().order_by("host")
serializer_class = LocSerializer
filter_class = LocFilterSet
filterset_class = LocFilterSet


class LocDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -460,7 +463,7 @@ class MxList(HostPermissionsListCreateAPIView):

queryset = Mx.objects.get_queryset().order_by("id")
serializer_class = MxSerializer
filter_class = MxFilterSet
filterset_class = MxFilterSet


class MxDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -490,7 +493,7 @@ class NaptrList(HostPermissionsListCreateAPIView):

queryset = Naptr.objects.all()
serializer_class = NaptrSerializer
filter_class = NaptrFilterSet
filterset_class = NaptrFilterSet


class NaptrDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -521,7 +524,7 @@ class NameServerList(HostPermissionsListCreateAPIView):
queryset = NameServer.objects.all()
serializer_class = NameServerSerializer
lookup_field = "name"
filter_class = NameServerFilterSet
filterset_class = NameServerFilterSet


class NameServerDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -552,7 +555,7 @@ class PtrOverrideList(HostPermissionsListCreateAPIView):

queryset = PtrOverride.objects.get_queryset().order_by("id")
serializer_class = PtrOverrideSerializer
filter_class = PtrOverrideFilterSet
filterset_class = PtrOverrideFilterSet


class PtrOverrideDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -582,7 +585,7 @@ class SshfpList(HostPermissionsListCreateAPIView):

queryset = Sshfp.objects.get_queryset().order_by("id")
serializer_class = SshfpSerializer
filter_class = SshfpFilterSet
filterset_class = SshfpFilterSet


class SshfpDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -612,7 +615,7 @@ class SrvList(HostPermissionsListCreateAPIView):

queryset = Srv.objects.all()
serializer_class = SrvSerializer
filter_class = SrvFilterSet
filterset_class = SrvFilterSet


class SrvDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand Down Expand Up @@ -661,7 +664,7 @@ class NetworkList(MregListCreateAPIView):
serializer_class = NetworkSerializer
permission_classes = (IsSuperGroupMember | IsAuthenticatedAndReadOnly,)
lookup_field = "network"
filter_class = NetworkFilterSet
filterset_class = NetworkFilterSet

def post(self, request, *args, **kwargs):
error = _overlap_check(request.data["network"])
Expand Down Expand Up @@ -863,7 +866,7 @@ class TxtList(HostPermissionsListCreateAPIView):

queryset = Txt.objects.get_queryset().order_by("id")
serializer_class = TxtSerializer
filter_class = TxtFilterSet
filterset_class = TxtFilterSet


class TxtDetail(HostPermissionsUpdateDestroy, MregRetrieveUpdateDestroyAPIView):
Expand All @@ -888,7 +891,7 @@ class NetGroupRegexPermissionList(MregMixin, generics.ListCreateAPIView):
queryset = NetGroupRegexPermission.objects.all().order_by('id')
serializer_class = NetGroupRegexPermissionSerializer
permission_classes = (IsSuperOrAdminOrReadOnly,)
filter_class = NetGroupRegexPermissionFilterSet
filterset_class = NetGroupRegexPermissionFilterSet


class NetGroupRegexPermissionDetail(MregRetrieveUpdateDestroyAPIView):
Expand Down
2 changes: 1 addition & 1 deletion mreg/api/v1/views_bacnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class BACnetIDList(MregListCreateAPIView):
permission_classes = (IsGrantedNetGroupRegexPermission,)
lookup_field = "id"
filterset_fields = "id"
filter_class = BACnetIDFilterSet
filterset_class = BACnetIDFilterSet

def post(self, request, *args, **kwargs):
# request.data is immutable
Expand Down
2 changes: 1 addition & 1 deletion mreg/api/v1/views_hostgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class HostGroupList(HostGroupLogMixin, MregListCreateAPIView):
queryset = HostGroup.objects.all()
serializer_class = serializers.HostGroupSerializer
permission_classes = (IsSuperOrGroupAdminOrReadOnly, )
filter_class = HostGroupFilterSet
filterset_class = HostGroupFilterSet

def post(self, request, *args, **kwargs):
if "name" in request.data:
Expand Down
2 changes: 1 addition & 1 deletion mreg/api/v1/views_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class LabelList(MregListCreateAPIView):
queryset = Label.objects.all()
serializer_class = serializers.LabelSerializer
permission_classes = (IsSuperOrAdminOrReadOnly,)
filter_class = LabelFilterSet
filterset_class = LabelFilterSet

def post(self, request, *args, **kwargs):
if "name" in request.data:
Expand Down

0 comments on commit 3d2a97f

Please sign in to comment.