Skip to content

Commit

Permalink
Migrating away from postgres CI fields.
Browse files Browse the repository at this point in the history
*Caveats:*

  - Assumes fields in LCI fields are all already in lowercase. No efforts are currently made to manipulate these fields. This is fixable if required.
  - Email fields are migrated to models.EmailField. Case insensitivity is not preserved.
  - Needs (much) more testing.

Implementation details:

  - Creates two LowerCase-fields to replace the LCI-fields.
  - Uses a manager to hook into the calls to objects.get, objects.filter, and objects.exclude (the latter is currently unused).
  - Uses a mixin for views to overload get_object for relevant detail views.
  - Overloaded get_queryset() usage to check for exists are handled manually, should be cleaned up.
  • Loading branch information
terjekv authored and Terje Kvernes committed Dec 5, 2023
1 parent e26fa52 commit 7eaa56e
Show file tree
Hide file tree
Showing 17 changed files with 315 additions and 56 deletions.
7 changes: 5 additions & 2 deletions hostpolicy/api/v1/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from django.contrib.auth.models import Group

from hostpolicy.models import HostPolicyAtom, HostPolicyRole
from mreg.api.v1.tests.tests import MregAPITestCase
from mreg.models.base import Label
from mreg.models.host import Host, Ipaddress
from mreg.models.network import NetGroupRegexPermission

from mreg.api.v1.tests.tests import MregAPITestCase


class HostPolicyUniqueNameSpace(MregAPITestCase):

Expand Down Expand Up @@ -43,6 +42,10 @@ def test_list_200_ok(self):
self.assertEqual(data['count'], 2)
self.assertEqual(len(data['results']), 2)

def test_case_insensitive(self):
"""Case insensitive lookups should work"""
self.assert_get(self.basejoin(self.object_one.name.upper()))

def test_get_404_not_found(self):
""""Getting a non-existing entry should return 404"""
self.assert_get_and_404(self.basejoin('nonexisting'))
Expand Down
29 changes: 15 additions & 14 deletions hostpolicy/api/v1/views.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from django.db.models import Prefetch
from rest_framework import status
from rest_framework.response import Response

from url_filter.filtersets import ModelFilterSet

from hostpolicy.models import HostPolicyAtom, HostPolicyRole
from hostpolicy.api.permissions import IsSuperOrHostPolicyAdminOrReadOnly
from mreg.api.v1.serializers import HostNameSerializer
from mreg.api.v1.views import (MregListCreateAPIView,
MregPermissionsListCreateAPIView,
MregPermissionsUpdateDestroy,
MregRetrieveUpdateDestroyAPIView,
)

from hostpolicy.models import HostPolicyAtom, HostPolicyRole
from mreg.api.v1.history import HistoryLog
from mreg.api.v1.serializers import HostNameSerializer
from mreg.api.v1.views import (
MregListCreateAPIView,
MregPermissionsListCreateAPIView,
MregPermissionsUpdateDestroy,
MregRetrieveUpdateDestroyAPIView,
)
from mreg.api.v1.views_m2m import M2MDetail, M2MList, M2MPermissions
from mreg.mixins import LowerCaseLookupMixin
from mreg.models.host import Host

from . import serializers
Expand Down Expand Up @@ -68,14 +68,15 @@ def get_queryset(self):

def post(self, request, *args, **kwargs):
if "name" in request.data:
if self.get_queryset().filter(name=request.data['name']).exists():
# Due to the overriding of get_queryset, we need to manually use lower()
if self.get_queryset().filter(name=request.data['name'].lower()).exists():
content = {'ERROR': 'name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)

return super().post(request, *args, **kwargs)


class HostPolicyAtomDetail(HostPolicyAtomLogMixin, MregRetrieveUpdateDestroyAPIView):
class HostPolicyAtomDetail(HostPolicyAtomLogMixin, LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):

queryset = HostPolicyAtom.objects.all()
serializer_class = serializers.HostPolicyAtomSerializer
Expand Down Expand Up @@ -103,14 +104,14 @@ def get_queryset(self):

def post(self, request, *args, **kwargs):
if "name" in request.data:
if self.get_queryset().filter(name=request.data['name']).exists():
# Due to the overriding of get_queryset, we need to manually use lower()
if self.get_queryset().filter(name=request.data['name'].lower()).exists():
content = {'ERROR': 'name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)

return super().post(request, *args, **kwargs)


class HostPolicyRoleDetail(HostPolicyRoleLogMixin, MregRetrieveUpdateDestroyAPIView):
class HostPolicyRoleDetail(HostPolicyRoleLogMixin, LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):

queryset = HostPolicyRole.objects.all()
serializer_class = serializers.HostPolicyRoleSerializer
Expand Down
25 changes: 25 additions & 0 deletions hostpolicy/migrations/0004_migrate_away_from_ci_fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Generated by Django 3.2.18 on 2023-06-29 09:26

from django.db import migrations
import hostpolicy.models
import mreg.fields


class Migration(migrations.Migration):

dependencies = [
('hostpolicy', '0003_hostpolicyrole_labels'),
]

operations = [
migrations.AlterField(
model_name='hostpolicyatom',
name='name',
field=mreg.fields.LowerCaseCharField(max_length=64, unique=True, validators=[hostpolicy.models._validate_atom_name]),
),
migrations.AlterField(
model_name='hostpolicyrole',
name='name',
field=mreg.fields.LowerCaseCharField(max_length=64, unique=True, validators=[hostpolicy.models._validate_role_name]),
),
]
15 changes: 10 additions & 5 deletions hostpolicy/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from django.core.exceptions import ValidationError
import datetime

from django.core.exceptions import ValidationError
from django.db import models

from mreg.fields import LCICharField
from mreg.models.host import Host
from mreg.fields import LowerCaseCharField
from mreg.managers import LowerCaseManager
from mreg.models.base import Label
from mreg.models.host import Host


class HostPolicyComponent(models.Model):
Expand All @@ -29,7 +30,9 @@ def _validate_atom_name(name):

class HostPolicyAtom(HostPolicyComponent):

name = LCICharField(max_length=64, unique=True, validators=[_validate_atom_name])
name = LowerCaseCharField(max_length=64, unique=True, validators=[_validate_atom_name])

objects = LowerCaseManager()

class Meta:
db_table = 'hostpolicy_atom'
Expand All @@ -44,11 +47,13 @@ def _validate_role_name(name):

class HostPolicyRole(HostPolicyComponent):

name = LCICharField(max_length=64, unique=True, validators=[_validate_role_name])
name = LowerCaseCharField(max_length=64, unique=True, validators=[_validate_role_name])
atoms = models.ManyToManyField(HostPolicyAtom, related_name='roles')
hosts = models.ManyToManyField(Host, related_name='hostpolicyroles')
labels = models.ManyToManyField(Label, blank=True, related_name='hostpolicyroles')

objects = LowerCaseManager()

class Meta:
db_table = 'hostpolicy_role'
ordering = ('name',)
4 changes: 4 additions & 0 deletions mreg/api/v1/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
PtrOverrideSerializer, SrvSerializer,
SshfpSerializer, TxtSerializer)

from mreg.mixins import LowerCaseLookupMixin

# These filtersets are used for applying generic filtering to all objects.
class CnameFilterSet(ModelFilterSet):
Expand Down Expand Up @@ -271,6 +272,7 @@ def get_queryset(self):


class CnameDetail(HostPermissionsUpdateDestroy,
LowerCaseLookupMixin,
MregRetrieveUpdateDestroyAPIView):
"""
get:
Expand Down Expand Up @@ -304,6 +306,7 @@ def get_queryset(self):


class HinfoDetail(HostPermissionsUpdateDestroy,
LowerCaseLookupMixin,
MregRetrieveUpdateDestroyAPIView):
"""
get:
Expand Down Expand Up @@ -409,6 +412,7 @@ def post(self, request, *args, **kwargs):


class HostDetail(HostPermissionsUpdateDestroy,
LowerCaseLookupMixin,
MregRetrieveUpdateDestroyAPIView):
"""
get:
Expand Down
7 changes: 5 additions & 2 deletions mreg/api/v1/views_hostgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
IsSuperOrGroupAdminOrReadOnly)
from mreg.models.host import Host, HostGroup

from mreg.mixins import LowerCaseLookupMixin

from . import serializers
from .history import HistoryLog
from .views import (MregListCreateAPIView,
Expand Down Expand Up @@ -85,14 +87,15 @@ def get_queryset(self):

def post(self, request, *args, **kwargs):
if "name" in request.data:
if self.get_queryset().filter(name=request.data['name']).exists():
# We need to manually use lower() here due to the overriden get_queryset()
if self.get_queryset().filter(name=request.data['name'].lower()).exists():
content = {'ERROR': 'hostgroup name already in use'}
return Response(content, status=status.HTTP_409_CONFLICT)
self.lookup_field = 'name'
return super().post(request, *args, **kwargs)


class HostGroupDetail(HostGroupPermissionsUpdateDestroy):
class HostGroupDetail(LowerCaseLookupMixin, HostGroupPermissionsUpdateDestroy):
"""
get:
Returns details for the specified hostgroup. Includes hostgroups that are members.
Expand Down
4 changes: 3 additions & 1 deletion mreg/api/v1/views_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .views import MregListCreateAPIView, MregRetrieveUpdateDestroyAPIView
from mreg.models.base import Label
from mreg.api.permissions import IsSuperOrAdminOrReadOnly

from mreg.mixins import LowerCaseLookupMixin
from . import serializers


Expand All @@ -31,7 +33,7 @@ def post(self, request, *args, **kwargs):
return super().post(request, *args, **kwargs)


class LabelDetail(MregRetrieveUpdateDestroyAPIView):
class LabelDetail(LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):
"""
get:
Returns details for a Label.
Expand Down
6 changes: 4 additions & 2 deletions mreg/api/v1/views_zones.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from mreg.models.host import Host
from mreg.models.zone import ForwardZone, ForwardZoneDelegation, ReverseZone, ReverseZoneDelegation

from mreg.mixins import LowerCaseLookupMixin

from mreg.api.permissions import (IsSuperGroupMember, IsAuthenticatedAndReadOnly)

from .serializers import (ForwardZoneDelegationSerializer, ForwardZoneSerializer,
Expand Down Expand Up @@ -172,7 +174,7 @@ class ReverseZoneDelegationList(ZoneDelegationList):
model = ReverseZone


class ZoneDetail(MregRetrieveUpdateDestroyAPIView):
class ZoneDetail(LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):
"""
get:
List details for a zone.
Expand Down Expand Up @@ -239,7 +241,7 @@ class ReverseZoneDetail(ZoneDetail):
queryset = ReverseZone.objects.all()


class ZoneDelegationDetail(MregRetrieveUpdateDestroyAPIView):
class ZoneDelegationDetail(LowerCaseLookupMixin, MregRetrieveUpdateDestroyAPIView):

lookup_field = 'delegation'
permission_classes = (IsSuperGroupMember | IsAuthenticatedAndReadOnly, )
Expand Down
15 changes: 15 additions & 0 deletions mreg/fields.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import django.contrib.postgres.fields as pgfields
import django.db.models as models

from .validators import validate_hostname

class LowerCaseCharField(models.CharField):
"""A CharField where the value is stored in lower case."""

def get_db_prep_save(self, value, connection):
if isinstance(value, str):
value = value.lower()
return super().get_db_prep_save(value, connection)

class LCICharField(pgfields.CICharField):
"""A pgfields.CICharField where the value is stored in lower case. """
Expand All @@ -11,6 +19,13 @@ def get_db_prep_save(self, value, connection):
value = value.lower()
return super().get_db_prep_save(value, connection)

class LowerCaseDNSNameField(LowerCaseCharField):
"""A field to hold DNS names."""
def __init__(self, *args, **kwargs):
kwargs['max_length'] = 253
if 'validators' not in kwargs:
kwargs['validators'] = [validate_hostname]
super().__init__(*args, **kwargs)

class DnsNameField(LCICharField):
"""
Expand Down
42 changes: 42 additions & 0 deletions mreg/managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from django.db import models

from .fields import LowerCaseCharField


class LowerCaseManager(models.Manager):
"""A manager that lowercases all values of LowerCaseCharFields in filter/exclude/get calls."""

@property
def lowercase_fields(self):
if not hasattr(self, "_lowercase_fields_cache"):
self._lowercase_fields_cache = [
field.name
for field in self.model._meta.get_fields()
if isinstance(field, LowerCaseCharField)
]
return self._lowercase_fields_cache

def _lowercase_fields(self, **kwargs):
lower_kwargs = {}
for key, value in kwargs.items():
field_name = key.split("__")[0]
if field_name in self.lowercase_fields and isinstance(value, str):
value = value.lower()
lower_kwargs[key] = value
return lower_kwargs

def filter(self, **kwargs):
return super().filter(**self._lowercase_fields(**kwargs))

def exclude(self, **kwargs):
return super().exclude(**self._lowercase_fields(**kwargs))

def get(self, **kwargs):
return super().get(**self._lowercase_fields(**kwargs))


def lower_case_manager_factory(base_manager):
class LowerCaseBaseManager(base_manager, LowerCaseManager):
pass

return LowerCaseBaseManager
Loading

0 comments on commit 7eaa56e

Please sign in to comment.