diff --git a/edc_sites/admin/site_model_admin_mixin.py b/edc_sites/admin/site_model_admin_mixin.py index 00a34d4..0ffc90c 100644 --- a/edc_sites/admin/site_model_admin_mixin.py +++ b/edc_sites/admin/site_model_admin_mixin.py @@ -1,6 +1,7 @@ from __future__ import annotations import collections +from typing import TYPE_CHECKING, Type from django.contrib import admin from django.core.exceptions import FieldError, ObjectDoesNotExist @@ -10,6 +11,10 @@ from ..site import sites from .list_filters import SiteListFilter +if TYPE_CHECKING: + from django.contrib.admin import SimpleListFilter + + __all__ = ["SiteModelAdminMixin"] @@ -35,7 +40,7 @@ def site_name(self, obj=None): return obj.site.name return f"{site_profile.site.id} {site_profile.description}" - def get_list_filter(self, request): + def get_list_filter(self, request) -> tuple[str | Type[SimpleListFilter], ...]: """Insert `SiteListFilter` before field name `created`. Remove site from the list if user does not have access diff --git a/edc_sites/modelform_mixins.py b/edc_sites/modelform_mixins.py index d05e36b..6e296c3 100644 --- a/edc_sites/modelform_mixins.py +++ b/edc_sites/modelform_mixins.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING - from django import forms - -if TYPE_CHECKING: - from django.contrib.sites.models import Site +from django.contrib.sites.models import Site __all__ = ["SiteModelFormMixin"] @@ -32,7 +28,11 @@ def clean(self) -> dict: @property def site(self) -> Site: - return self.cleaned_data.get("site") or self.instance.site or self.related_visit.site + if related_visit := getattr(self, "related_visit", None): + return related_visit.site + return ( + self.cleaned_data.get("site") or self.instance.site or Site.objects.get_current() + ) def validate_with_current_site(self) -> None: current_site = getattr(self, "current_site", None)