From 8f2f9eeeff6ba5ba6b7e88fc3b30edb17a23fecf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Freitag?= Date: Mon, 11 Apr 2022 15:50:22 +0200 Subject: [PATCH] Accept per-widget attrs for PhoneNumberPrefixWidget Modeled after https://github.com/django/django/blob/b8759093d8eaea32c8d177615df7de559b6571c7/django/forms/widgets.py#L962-L994. --- phonenumber_field/widgets.py | 11 +++++++---- tests/test_widgets.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/phonenumber_field/widgets.py b/phonenumber_field/widgets.py index c81cccb5..d63f086c 100644 --- a/phonenumber_field/widgets.py +++ b/phonenumber_field/widgets.py @@ -44,14 +44,14 @@ def localized_choices(language): class PhonePrefixSelect(Select): initial = None - def __init__(self, initial=None): + def __init__(self, initial=None, attrs=None): language = translation.get_language() or settings.LANGUAGE_CODE choices = localized_choices(language) if initial is None: initial = getattr(settings, "PHONENUMBER_DEFAULT_REGION", None) if initial in REGION_CODE_TO_COUNTRY_CODE: self.initial = initial - super().__init__(choices=sorted(choices, key=lambda item: item[1])) + super().__init__(attrs=attrs, choices=sorted(choices, key=lambda item: item[1])) def get_context(self, name, value, attrs): attrs = (attrs or {}).copy() @@ -66,8 +66,11 @@ class PhoneNumberPrefixWidget(MultiWidget): - an input for local phone number """ - def __init__(self, attrs=None, initial=None): - widgets = (PhonePrefixSelect(initial), TextInput()) + def __init__(self, attrs=None, initial=None, country_attrs=None, number_attrs=None): + widgets = ( + PhonePrefixSelect(initial, attrs=country_attrs), + TextInput(attrs=number_attrs), + ) super().__init__(widgets, attrs) def decompress(self, value): diff --git a/tests/test_widgets.py b/tests/test_widgets.py index 3615b246..d967a937 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -205,6 +205,17 @@ def test_maxlength_not_in_select(self): html = widget.render(name="widget", value=None, attrs={"maxlength": 32}) self.assertIn('', html) + self.assertInHTML( + '', html, count=1 + ) + class PhoneNumberInternationalFallbackWidgetTest(SimpleTestCase): def test_fallback_widget_switches_between_national_and_international(self):