diff --git a/CHANGES b/CHANGES index e69de29..f2f7429 100644 --- a/CHANGES +++ b/CHANGES @@ -0,0 +1,8 @@ +0.3.23 +------ +- add func to validate site for subject +- SiteModelAdminMixin: + - feature to limit to related site in FK and M2M + - filter queryset by site + - check for site attr + diff --git a/edc_sites/__init__.py b/edc_sites/__init__.py index 5c8f545..d54d12d 100644 --- a/edc_sites/__init__.py +++ b/edc_sites/__init__.py @@ -1,8 +1,12 @@ -from .add_or_update_django_sites import add_or_update_django_sites # noqa -from .get_all_sites import get_all_sites # noqa -from .get_country import get_current_country # noqa -from .get_site_by_attr import get_site_by_attr # noqa -from .get_site_id import InvalidSiteError, get_site_id # noqa -from .get_site_name import get_site_name # noqa -from .get_sites_by_country import get_sites_by_country # noqa -from .get_sites_module import get_sites_module # noqa +from .add_or_update_django_sites import add_or_update_django_sites +from .get_all_sites import get_all_sites +from .get_country import get_current_country +from .get_site_by_attr import get_site_by_attr +from .get_site_id import InvalidSiteError, get_site_id +from .get_site_name import get_site_name +from .get_sites_by_country import get_sites_by_country +from .get_sites_module import get_sites_module +from .valid_site_for_subject_or_raise import ( + InvalidSiteForSubjectError, + valid_site_for_subject_or_raise, +) diff --git a/edc_sites/admin/__init__.py b/edc_sites/admin/__init__.py new file mode 100644 index 0000000..a41b6a9 --- /dev/null +++ b/edc_sites/admin/__init__.py @@ -0,0 +1 @@ +from .modeladmin_mixins import SiteModelAdminMixin diff --git a/edc_sites/admin/modeladmin_mixins.py b/edc_sites/admin/modeladmin_mixins.py new file mode 100644 index 0000000..a934e02 --- /dev/null +++ b/edc_sites/admin/modeladmin_mixins.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import collections + +from django.contrib import admin +from django.core.exceptions import FieldError + +from ..get_country import get_current_country +from ..get_language_choices_for_site import get_language_choices_for_site + + +class SiteModeAdminMixinError(Exception): + pass + + +class SiteModelAdminMixin: + language_db_field_name = "language" + + limit_related_to_current_country: list[str] = None + limit_related_to_current_site: list[str] = None + + @admin.display(description="Site", ordering="site__id") + def site_code(self, obj=None): + return obj.site.id + + def get_queryset(self, request): + """Limit modeladmin queryset for the current site only""" + qs = super().get_queryset(request) + if getattr(request, "site", None): + try: + qs = qs.filter(site_id=request.site.id) + except FieldError: + raise SiteModeAdminMixinError( + f"Model missing field `site`. Model `{self.model}`. Did you mean to use " + f"the SiteModelAdminMixin? See `{self}`." + ) + return qs + + def get_form(self, request, obj=None, change=False, **kwargs): + """Add current_site attr to form instance""" + form = super().get_form(request, obj=obj, change=change, **kwargs) + form.current_site = getattr(request, "site", None) + return form + + def formfield_for_choice_field(self, db_field, request, **kwargs): + """Use site id to select languages to show in choices.""" + if db_field.name == self.language_db_field_name: + try: + language_choices = get_language_choices_for_site(request.site, other=True) + except AttributeError as e: + if "WSGIRequest" not in str(e): + raise + else: + if language_choices: + kwargs["choices"] = language_choices + return super().formfield_for_choice_field(db_field, request, **kwargs) + + def formfield_for_foreignkey(self, db_field, request, **kwargs): + """Filter a ForeignKey field`s queryset by the current site + or country. + + Note, a queryset set by the ModelForm class will overwrite + the field's queryset added here. + """ + self.raise_on_dups_in_field_lists( + self.limit_related_to_current_country, + self.limit_related_to_current_site, + ) + if db_field.name in (self.limit_related_to_current_country or []): + self.raise_on_queryset_exists(db_field, kwargs) + country = get_current_country(request) + model_cls = getattr(self.model, db_field.name).field.related_model + kwargs["queryset"] = model_cls.objects.filter(siteprofile__country=country) + elif db_field.name in (self.limit_related_to_current_site or []) and getattr( + request, "site", None + ): + self.raise_on_queryset_exists(db_field, kwargs) + model_cls = getattr(self.model, db_field.name).field.related_model + kwargs["queryset"] = model_cls.objects.filter(id=request.site.id) + elif db_field.name in (self.limit_related_to_current_site or []): + self.raise_on_queryset_exists(db_field, kwargs) + model_cls = getattr(self.model, db_field.name).field.related_model + kwargs["queryset"] = model_cls.on_site.all() + return super().formfield_for_foreignkey(db_field, request, **kwargs) + + def formfield_for_manytomany(self, db_field, request, **kwargs): + """Filter a ManyToMany field`s queryset by the current site. + + Note, a queryset set by the ModelForm class will overwrite + the field's queryset added here. + """ + self.raise_on_dups_in_field_lists( + self.limit_related_to_current_country, + self.limit_related_to_current_site, + ) + if db_field.name in (self.limit_related_to_current_site or []): + self.raise_on_queryset_exists(db_field, kwargs) + model_cls = getattr(self.model, db_field.name).remote_field.model + kwargs["queryset"] = model_cls.on_site.all() + elif db_field.name in (self.limit_related_to_current_country or []): + country = get_current_country(request) + model_cls = getattr(self.model, db_field.name).remote_field.model + kwargs["queryset"] = model_cls.objects.filter(siteprofile__country=country) + return super().formfield_for_manytomany(db_field, request, **kwargs) + + def raise_on_queryset_exists(self, db_field, kwargs): + """Raise an exception if the `queryset` key exists in the + kwargs dict. + + If `queryset` exists, remove the field name from the class attr: + limit_fk_field_to_... + limit_m2m_field_to_... + """ + if "queryset" in kwargs: + raise SiteModeAdminMixinError( + f"Key `queryset` unexpectedly exists. Got field `{db_field.name}` " + f"from {self}." + f". Did you manually set key `queryset` for field `{db_field.name}`?" + ) + + @staticmethod + def raise_on_dups_in_field_lists(*field_lists: list[str]): + orig = [] + for field_list in field_lists: + orig.extend(field_list or []) + if dups := [item for item, count in collections.Counter(orig).items() if count > 1]: + raise SiteModeAdminMixinError( + f"Related field appears in more than one list. Got {dups}." + ) diff --git a/edc_sites/get_site_model_cls.py b/edc_sites/get_site_model_cls.py new file mode 100644 index 0000000..4fe3f66 --- /dev/null +++ b/edc_sites/get_site_model_cls.py @@ -0,0 +1,12 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from django.apps import apps as django_apps + +if TYPE_CHECKING: + from django.contrib.sites.models import Site + + +def get_site_model_cls() -> Site: + return django_apps.get_model("sites.site") diff --git a/edc_sites/model_mixins.py b/edc_sites/model_mixins.py index 54eead9..8593b41 100644 --- a/edc_sites/model_mixins.py +++ b/edc_sites/model_mixins.py @@ -1,8 +1,14 @@ +from __future__ import annotations + from django.contrib.sites.managers import CurrentSiteManager as BaseCurrentSiteManager from django.contrib.sites.models import Site from django.db import models +class SiteModelMixinError(Exception): + pass + + class CurrentSiteManager(BaseCurrentSiteManager): use_in_migrations = True @@ -18,9 +24,32 @@ class SiteModelMixin(models.Model): on_site = CurrentSiteManager() def save(self, *args, **kwargs): - if not self.site: - self.site = Site.objects.get_current() + if not self.id: + self.site = self.get_site_on_create() + elif "update_fields" in kwargs and "site" not in kwargs.get("update_fields"): + pass + else: + self.validate_site_against_current() super().save(*args, **kwargs) + def get_site_on_create(self) -> Site: + """Returns a site model instance. + + See also django-multisite. + """ + current_site = Site.objects.get_current() + return current_site if not self.site else self.site + + def validate_site_against_current(self) -> None: + """Validate existing site instance matches current_site.""" + pass + # current_site = Site.objects.get_current() + # if self.site != current_site: + # site = current_site + # raise SiteModelMixinError( + # f"Invalid attempt to change site! Expected `{self.site}`. " + # f"Tried to change to `{current_site}`. Model=`{self}`. id=`{self.id}`." + # ) + class Meta: abstract = True diff --git a/edc_sites/modeladmin_mixins.py b/edc_sites/modeladmin_mixins.py index bbe417c..0dddd3f 100644 --- a/edc_sites/modeladmin_mixins.py +++ b/edc_sites/modeladmin_mixins.py @@ -1,94 +1,9 @@ -import collections +import warnings -from django.contrib import admin -from django.core.exceptions import FieldError +from .admin import SiteModelAdminMixin # noqa -from .get_country import get_current_country -from .get_language_choices_for_site import get_language_choices_for_site - - -class SiteModeAdminMixinError(Exception): - pass - - -class SiteModelAdminMixin: - """Adds the current site to the form from the request object. - - Use together with the `SiteModelFormMixin`. - - - - """ - - language_db_field_name = "language" - limit_fk_field_to_current_country: list[str] = None - limit_fk_field_to_current_site: list[str] = None - limit_m2m_field_to_current_site: list[str] = None - - @admin.display(description="Site", ordering="site__id") - def site_code(self, obj=None): - return obj.site.id - - def get_queryset(self, request): - """Limit modeladmin queryset for the current site only""" - qs = super().get_queryset(request) - if getattr(request, "site", None): - try: - qs = qs.filter(site_id=request.site.id) - except FieldError: - pass - return qs - - def get_form(self, request, obj=None, change=False, **kwargs): - """Add current_site attr to form instance""" - form = super().get_form(request, obj=obj, change=change, **kwargs) - form.current_site = getattr(request, "site", None) - return form - - def formfield_for_choice_field(self, db_field, request, **kwargs): - if db_field.name == self.language_db_field_name: - try: - language_choices = get_language_choices_for_site(request.site, other=True) - except AttributeError as e: - if "WSGIRequest" not in str(e): - raise - else: - if language_choices: - kwargs["choices"] = language_choices - return super().formfield_for_choice_field(db_field, request, **kwargs) - - def formfield_for_foreignkey(self, db_field, request, **kwargs): - """Add a queryset to kwargs if a condition is a matched. - - Note, a queryset set at the form level will replace any - queryset added to kwargs here. - """ - self.raise_on_duplicates_in_fk_fields_lists() - if db_field.name in (self.limit_fk_field_to_current_country or []): - country = get_current_country(request) - model_cls = getattr(self.model, db_field.name).field.related_model - kwargs["queryset"] = model_cls.objects.filter(siteprofile__country=country) - elif db_field.name in (self.limit_fk_field_to_current_site or []) and getattr( - request, "site", None - ): - model_cls = getattr(self.model, db_field.name).field.related_model - kwargs["queryset"] = model_cls.objects.filter(id=request.site.id) - elif db_field.name in (self.limit_fk_field_to_current_site or []): - model_cls = getattr(self.model, db_field.name).field.related_model - kwargs["queryset"] = model_cls.on_site.all() - return super().formfield_for_foreignkey(db_field, request, **kwargs) - - def formfield_for_manytomany(self, db_field, request, **kwargs): - if db_field.name in (self.limit_m2m_field_to_current_site or []): - model_cls = getattr(self.model, db_field.name).remote_field.model - kwargs["queryset"] = model_cls.on_site.all() - return super().formfield_for_manytomany(db_field, request, **kwargs) - - def raise_on_duplicates_in_fk_fields_lists(self): - orig = (self.limit_fk_field_to_current_country or []) + ( - self.limit_fk_field_to_current_site or [] - ) - if dups := [item for item, count in collections.Counter(orig).items() if count > 1]: - raise SiteModeAdminMixinError( - f"FK field name appears in more than one list. Got {dups}." - ) +warnings.warn( + "This import path is deprecated. Use `edc_sites.admin` instead.", + DeprecationWarning, + stacklevel=2, +) diff --git a/edc_sites/tests/models.py b/edc_sites/tests/models.py index 7876236..032de78 100644 --- a/edc_sites/tests/models.py +++ b/edc_sites/tests/models.py @@ -6,6 +6,6 @@ class TestModelWithSite(SiteModelMixin, models.Model): f1 = models.CharField(max_length=10, default="1") - on_site = CurrentSiteManager() - objects = models.Manager() + + on_site = CurrentSiteManager() diff --git a/edc_sites/valid_site_for_subject_or_raise.py b/edc_sites/valid_site_for_subject_or_raise.py new file mode 100644 index 0000000..9ea81ac --- /dev/null +++ b/edc_sites/valid_site_for_subject_or_raise.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from django.core.exceptions import ObjectDoesNotExist +from edc_registration import get_registered_subject_model_cls + +from edc_sites.get_site_model_cls import get_site_model_cls + +if TYPE_CHECKING: + from django.contrib.sites.models import Site + + +class InvalidSiteForSubjectError(Exception): + pass + + +def valid_site_for_subject_or_raise(subject_identifier: str) -> Site: + """Raises an InvalidSiteError exception if the subject_identifier is not + from the current site. + + * Confirms by querying RegisteredSubject. + * If subject_identifier is invalid will raise ObjectDoesNotExist + """ + current_site = get_site_model_cls().objects.get_current() + try: + get_registered_subject_model_cls().objects.get( + site=current_site, subject_identifier=subject_identifier + ) + except ObjectDoesNotExist: + try: + obj = get_registered_subject_model_cls().objects.get( + subject_identifier=subject_identifier + ) + except ObjectDoesNotExist as e: + raise InvalidSiteForSubjectError( + "Unable to validate site for subject. subject_identifier=" + f"`{subject_identifier}`. Got `{e}`" + ) + else: + raise InvalidSiteForSubjectError( + f"Invalid site for subject. {subject_identifier}. Expected `{obj.site.name}`. " + f"Got `{current_site.name}`" + ) + return current_site