From 83538c7316e9b5d9a9ed50896539fa44139ec5d1 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Wed, 18 Sep 2024 12:10:46 -0400 Subject: [PATCH 1/6] feat(admin): add prohibited email domains Signed-off-by: Mike Fiedler --- tests/unit/admin/test_routes.py | 15 +++++++++++++++ warehouse/admin/routes.py | 16 ++++++++++++++++ warehouse/admin/templates/admin/base.html | 4 ++-- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/tests/unit/admin/test_routes.py b/tests/unit/admin/test_routes.py index 9b1125594bda..ab481bf4cd75 100644 --- a/tests/unit/admin/test_routes.py +++ b/tests/unit/admin/test_routes.py @@ -284,6 +284,21 @@ def test_includeme(): "/admin/prohibited_user_names/bulk/", domain=warehouse, ), + pretend.call( + "admin.prohibited_email_domains.list", + "/admin/prohibited_email_domains/", + domain=warehouse, + ), + pretend.call( + "admin.prohibited_email_domains.add", + "/admin/prohibited_email_domains/add/", + domain=warehouse, + ), + pretend.call( + "admin.prohibited_email_domains.remove", + "/admin/prohibited_email_domains/remove/", + domain=warehouse, + ), pretend.call( "admin.observations.list", "/admin/observations/", diff --git a/warehouse/admin/routes.py b/warehouse/admin/routes.py index 6fb87d7a4520..0f8b941d7599 100644 --- a/warehouse/admin/routes.py +++ b/warehouse/admin/routes.py @@ -293,6 +293,22 @@ def includeme(config): "/admin/prohibited_user_names/bulk/", domain=warehouse, ) + # Prohibited Email related Admin pages + config.add_route( + "admin.prohibited_email_domains.list", + "/admin/prohibited_email_domains/", + domain=warehouse, + ) + config.add_route( + "admin.prohibited_email_domains.add", + "/admin/prohibited_email_domains/add/", + domain=warehouse, + ) + config.add_route( + "admin.prohibited_email_domains.remove", + "/admin/prohibited_email_domains/remove/", + domain=warehouse, + ) # Observation related Admin pages config.add_route( diff --git a/warehouse/admin/templates/admin/base.html b/warehouse/admin/templates/admin/base.html index e2bb9eefa842..7416eb3e08c8 100644 --- a/warehouse/admin/templates/admin/base.html +++ b/warehouse/admin/templates/admin/base.html @@ -175,9 +175,9 @@ From db024cec434e4937c96130212fcc93b9e42023c6 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Wed, 18 Sep 2024 12:11:16 -0400 Subject: [PATCH 2/6] add permissions Signed-off-by: Mike Fiedler --- tests/unit/test_config.py | 4 ++++ warehouse/authnz/_permissions.py | 3 +++ warehouse/config.py | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index d7f295448d7a..698ef2b2aa5f 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -572,6 +572,8 @@ def test_root_factory_access_control_list(): Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, Permissions.AdminOrganizationsWrite, + Permissions.AdminProhibitedEmailDomainsRead, + Permissions.AdminProhibitedEmailDomainsWrite, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedProjectsWrite, Permissions.AdminProhibitedUsernameRead, @@ -602,6 +604,7 @@ def test_root_factory_access_control_list(): Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, @@ -627,6 +630,7 @@ def test_root_factory_access_control_list(): Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, diff --git a/warehouse/authnz/_permissions.py b/warehouse/authnz/_permissions.py index 373788380ca2..62d60ac07088 100644 --- a/warehouse/authnz/_permissions.py +++ b/warehouse/authnz/_permissions.py @@ -60,6 +60,9 @@ class Permissions(StrEnum): AdminOrganizationsRead = "admin:organizations:read" AdminOrganizationsWrite = "admin:organizations:write" + AdminProhibitedEmailDomainsRead = "admin:prohibited-email-domains:read" + AdminProhibitedEmailDomainsWrite = "admin:prohibited-email-domains:write" + AdminProhibitedProjectsRead = "admin:prohibited-projects:read" AdminProhibitedProjectsWrite = "admin:prohibited-projects:write" diff --git a/warehouse/config.py b/warehouse/config.py index 753bb9aaee76..ea71cfb0edb5 100644 --- a/warehouse/config.py +++ b/warehouse/config.py @@ -87,6 +87,8 @@ class RootFactory: Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, Permissions.AdminOrganizationsWrite, + Permissions.AdminProhibitedEmailDomainsRead, + Permissions.AdminProhibitedEmailDomainsWrite, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedProjectsWrite, Permissions.AdminProhibitedUsernameRead, @@ -117,6 +119,7 @@ class RootFactory: Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, @@ -142,6 +145,7 @@ class RootFactory: Permissions.AdminObservationsRead, Permissions.AdminObservationsWrite, Permissions.AdminOrganizationsRead, + Permissions.AdminProhibitedEmailDomainsRead, Permissions.AdminProhibitedProjectsRead, Permissions.AdminProhibitedUsernameRead, Permissions.AdminProjectsRead, From 7bd6f6f28b260176b297fedfce88e7a507a1b8ef Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Wed, 18 Sep 2024 12:33:01 -0400 Subject: [PATCH 3/6] add views and templates Signed-off-by: Mike Fiedler --- tests/common/db/accounts.py | 19 +- .../views/test_prohibited_email_domains.py | 211 ++++++++++++++++++ .../admin/prohibited_email_domains/list.html | 183 +++++++++++++++ .../admin/views/prohibited_email_domains.py | 131 +++++++++++ 4 files changed, 543 insertions(+), 1 deletion(-) create mode 100644 tests/unit/admin/views/test_prohibited_email_domains.py create mode 100644 warehouse/admin/templates/admin/prohibited_email_domains/list.html create mode 100644 warehouse/admin/views/prohibited_email_domains.py diff --git a/tests/common/db/accounts.py b/tests/common/db/accounts.py index 435bd1717938..00e2756012cf 100644 --- a/tests/common/db/accounts.py +++ b/tests/common/db/accounts.py @@ -13,13 +13,21 @@ import datetime import factory +import faker from argon2 import PasswordHasher -from warehouse.accounts.models import Email, ProhibitedUserName, User +from warehouse.accounts.models import ( + Email, + ProhibitedEmailDomain, + ProhibitedUserName, + User, +) from .base import WarehouseFactory +fake = faker.Faker() + class UserFactory(WarehouseFactory): class Meta: @@ -90,6 +98,15 @@ class Meta: transient_bounces = 0 +class ProhibitedEmailDomainFactory(WarehouseFactory): + class Meta: + model = ProhibitedEmailDomain + + # TODO: Replace when factory_boy supports `unique`. + # See https://github.com/FactoryBoy/factory_boy/pull/997 + domain = factory.Sequence(lambda _: fake.unique.domain_name()) + + class ProhibitedUsernameFactory(WarehouseFactory): class Meta: model = ProhibitedUserName diff --git a/tests/unit/admin/views/test_prohibited_email_domains.py b/tests/unit/admin/views/test_prohibited_email_domains.py new file mode 100644 index 000000000000..5e8831736e6e --- /dev/null +++ b/tests/unit/admin/views/test_prohibited_email_domains.py @@ -0,0 +1,211 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pretend +import pytest + +from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther + +from warehouse.admin.views import prohibited_email_domains as views + +from ....common.db.accounts import ProhibitedEmailDomain, ProhibitedEmailDomainFactory + + +class TestProhibitedEmailDomainsList: + def test_no_query(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + + result = views.prohibited_email_domains(db_request) + + assert result == {"prohibited_email_domains": prohibited[:25], "query": None} + + def test_with_page(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + db_request.GET["page"] = "2" + + result = views.prohibited_email_domains(db_request) + + assert result == {"prohibited_email_domains": prohibited[25:], "query": None} + + def test_with_invalid_page(self): + request = pretend.stub(params={"page": "not an integer"}) + + with pytest.raises(HTTPBadRequest): + views.prohibited_email_domains(request) + + def test_basic_query(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + db_request.GET["q"] = prohibited[0].domain + + result = views.prohibited_email_domains(db_request) + + assert result == { + "prohibited_email_domains": [prohibited[0]], + "query": prohibited[0].domain, + } + + def test_wildcard_query(self, db_request): + prohibited = sorted( + ProhibitedEmailDomainFactory.create_batch(30), + key=lambda b: b.created, + ) + db_request.GET["q"] = f"{prohibited[0].domain[:-1]}%" + + result = views.prohibited_email_domains(db_request) + + assert result == { + "prohibited_email_domains": [prohibited[0]], + "query": f"{prohibited[0].domain[:-1]}%", + } + + +class TestProhibitedEmailDomainsAdd: + def test_no_email_domain(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {} + + with pytest.raises(HTTPSeeOther): + views.add_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Email domain is required.", queue="error") + ] + + def test_invalid_domain(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"email_domain": "invalid"} + + with pytest.raises(HTTPSeeOther): + views.add_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Invalid domain name 'invalid'", queue="error") + ] + + def test_duplicate_domain(self, db_request): + existing_domain = ProhibitedEmailDomainFactory.create() + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/add/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"email_domain": existing_domain.domain} + + with pytest.raises(HTTPSeeOther): + views.add_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call( + f"Email domain '{existing_domain.domain}' already exists.", + queue="error", + ) + ] + + def test_success(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = { + "email_domain": "example.com", + "is_mx_record": "on", + "comment": "testing", + } + + response = views.add_prohibited_email_domain(db_request) + + assert response.status_code == 303 + assert response.headers["Location"] == "/admin/prohibited_email_domains/list/" + assert db_request.session.flash.calls == [ + pretend.call("Prohibited email domain added.", queue="success") + ] + + query = db_request.db.query(ProhibitedEmailDomain).filter( + ProhibitedEmailDomain.domain == "example.com" + ) + assert query.count() == 1 + assert query.one().is_mx_record + assert query.one().comment == "testing" + + +class TestProhibitedEmailDomainsRemove: + def test_no_domain_name(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/remove/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {} + + with pytest.raises(HTTPSeeOther): + views.remove_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Domain name is required.", queue="error") + ] + + def test_domain_not_found(self, db_request): + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/remove/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"domain_name": "example.com"} + + with pytest.raises(HTTPSeeOther): + views.remove_prohibited_email_domain(db_request) + + assert db_request.session.flash.calls == [ + pretend.call("Domain not found.", queue="error") + ] + + def test_success(self, db_request): + domain = ProhibitedEmailDomainFactory.create() + db_request.method = "POST" + db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/" + db_request.session = pretend.stub( + flash=pretend.call_recorder(lambda *a, **kw: None) + ) + db_request.POST = {"domain_name": domain.domain} + + response = views.remove_prohibited_email_domain(db_request) + + assert response.status_code == 303 + assert response.headers["Location"] == "/admin/prohibited_email_domains/list/" + assert db_request.session.flash.calls == [ + pretend.call( + f"Prohibited email domain '{domain.domain}' removed.", queue="success" + ) + ] + + query = db_request.db.query(ProhibitedEmailDomain).filter( + ProhibitedEmailDomain.domain == domain.domain + ) + assert query.count() == 0 diff --git a/warehouse/admin/templates/admin/prohibited_email_domains/list.html b/warehouse/admin/templates/admin/prohibited_email_domains/list.html new file mode 100644 index 000000000000..3579ac9a4967 --- /dev/null +++ b/warehouse/admin/templates/admin/prohibited_email_domains/list.html @@ -0,0 +1,183 @@ +{# + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. + # You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. +-#} + +{% extends "admin/base.html" %} + +{% import "admin/utils/pagination.html" as pagination %} + +{% set perms_admin_prohibited_email_domain_write = request.has_permission(Permissions.AdminProhibitedEmailDomainsWrite) %} + +{% block title %} + Prohibited Email Domains +{% endblock title %} + +{% block breadcrumb %} + + +{% endblock breadcrumb %} + +{% block content %} +
+
+
+
+ + +
+ +
+
+
+
+
+
+
+ + + + + + + + + + + + + {% for prohibited_email_domain in prohibited_email_domains %} + + + + + + + + + {% endfor %} + +
Domain NameMX record?Prohibited byProhibited onComment
{{ prohibited_email_domain.domain }} + {% if prohibited_email_domain.is_mx_record %}{% endif %} + + + {{ prohibited_email_domain.prohibited_by.username }} + + {{ prohibited_email_domain.created | format_datetime }}{{ prohibited_email_domain.comment }} + + + + + +
+
+ +
+
+
+
+

Prohibit email domain

+
+
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+
+
+{% endblock content %} diff --git a/warehouse/admin/views/prohibited_email_domains.py b/warehouse/admin/views/prohibited_email_domains.py new file mode 100644 index 000000000000..3cfd3af90e6c --- /dev/null +++ b/warehouse/admin/views/prohibited_email_domains.py @@ -0,0 +1,131 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paginate_sqlalchemy import SqlalchemyOrmPage as SQLAlchemyORMPage +from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther +from pyramid.view import view_config +from sqlalchemy import func +from tldextract import extract + +from warehouse.accounts.models import ProhibitedEmailDomain +from warehouse.authnz import Permissions +from warehouse.utils.paginate import paginate_url_factory + + +@view_config( + route_name="admin.prohibited_email_domains.list", + renderer="admin/prohibited_email_domains/list.html", + permission=Permissions.AdminProhibitedEmailDomainsRead, + request_method="GET", + uses_session=True, +) +def prohibited_email_domains(request): + q = request.params.get("q") + + try: + page_num = int(request.params.get("page", 1)) + except ValueError: + raise HTTPBadRequest("'page' must be an integer.") from None + + prohibited_email_domains_query = request.db.query(ProhibitedEmailDomain).order_by( + ProhibitedEmailDomain.created.desc() + ) + + if q: + prohibited_email_domains_query = prohibited_email_domains_query.filter( + ProhibitedEmailDomain.domain.ilike(q) + ) + + prohibited_email_domains = SQLAlchemyORMPage( + prohibited_email_domains_query, + page=page_num, + items_per_page=25, + url_maker=paginate_url_factory(request), + ) + + return {"prohibited_email_domains": prohibited_email_domains, "query": q} + + +@view_config( + route_name="admin.prohibited_email_domains.add", + permission=Permissions.AdminProhibitedEmailDomainsWrite, + request_method="POST", + uses_session=True, + require_methods=False, +) +def add_prohibited_email_domain(request): + email_domain = request.POST.get("email_domain") + if not email_domain: + request.session.flash("Email domain is required.", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + # validate that the domain is valid + if not extract(email_domain).registered_domain: + request.session.flash(f"Invalid domain name '{email_domain}'", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + # make sure we don't have a duplicate entry + if ( + request.db.query(func.count(ProhibitedEmailDomain.id)) + .filter(ProhibitedEmailDomain.domain == email_domain) + .scalar() + > 0 + ): + request.session.flash( + f"Email domain '{email_domain}' already exists.", queue="error" + ) + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + # Add the domain to the database + is_mx_record = bool(request.POST.get("is_mx_record")) + comment = request.POST.get("comment") + + prohibited_email_domain = ProhibitedEmailDomain( + domain=email_domain, + is_mx_record=is_mx_record, + prohibited_by=request.user, + comment=comment, + ) + + request.db.add(prohibited_email_domain) + request.session.flash("Prohibited email domain added.", queue="success") + + return HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + +@view_config( + route_name="admin.prohibited_email_domains.remove", + permission=Permissions.AdminProhibitedEmailDomainsWrite, + request_method="POST", + uses_session=True, + require_methods=False, +) +def remove_prohibited_email_domain(request): + domain_name = request.POST.get("domain_name") + if not domain_name: + request.session.flash("Domain name is required.", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + domain_record = ( + request.db.query(ProhibitedEmailDomain) + .filter(ProhibitedEmailDomain.domain == domain_name) + .first() + ) + + if not domain_record: + request.session.flash("Domain not found.", queue="error") + raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) + + request.db.delete(domain_record) + request.session.flash( + f"Prohibited email domain '{domain_record.domain}' removed.", queue="success" + ) + + return HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) From 715027b1892b505ad26bae3802d4b33c6130f625 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Thu, 19 Sep 2024 09:12:36 -0400 Subject: [PATCH 4/6] feat: handle non-exact domain inputs Signed-off-by: Mike Fiedler --- .../admin/views/test_prohibited_email_domains.py | 14 +++++++++++--- warehouse/admin/views/prohibited_email_domains.py | 9 +++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/unit/admin/views/test_prohibited_email_domains.py b/tests/unit/admin/views/test_prohibited_email_domains.py index 5e8831736e6e..9530aea62f43 100644 --- a/tests/unit/admin/views/test_prohibited_email_domains.py +++ b/tests/unit/admin/views/test_prohibited_email_domains.py @@ -127,14 +127,22 @@ def test_duplicate_domain(self, db_request): ) ] - def test_success(self, db_request): + @pytest.mark.parametrize( + ("input_domain", "expected_domain"), + [ + ("example.com", "example.com"), + ("mail.example.co.uk", "example.co.uk"), + ("https://example.com/", "example.com"), + ], + ) + def test_success(self, db_request, input_domain, expected_domain): db_request.method = "POST" db_request.route_path = lambda a: "/admin/prohibited_email_domains/list/" db_request.session = pretend.stub( flash=pretend.call_recorder(lambda *a, **kw: None) ) db_request.POST = { - "email_domain": "example.com", + "email_domain": input_domain, "is_mx_record": "on", "comment": "testing", } @@ -148,7 +156,7 @@ def test_success(self, db_request): ] query = db_request.db.query(ProhibitedEmailDomain).filter( - ProhibitedEmailDomain.domain == "example.com" + ProhibitedEmailDomain.domain == expected_domain ) assert query.count() == 1 assert query.one().is_mx_record diff --git a/warehouse/admin/views/prohibited_email_domains.py b/warehouse/admin/views/prohibited_email_domains.py index 3cfd3af90e6c..b0e479ffe0bd 100644 --- a/warehouse/admin/views/prohibited_email_domains.py +++ b/warehouse/admin/views/prohibited_email_domains.py @@ -68,18 +68,19 @@ def add_prohibited_email_domain(request): request.session.flash("Email domain is required.", queue="error") raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) # validate that the domain is valid - if not extract(email_domain).registered_domain: + registered_domain = extract(email_domain).registered_domain + if not registered_domain: request.session.flash(f"Invalid domain name '{email_domain}'", queue="error") raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) # make sure we don't have a duplicate entry if ( request.db.query(func.count(ProhibitedEmailDomain.id)) - .filter(ProhibitedEmailDomain.domain == email_domain) + .filter(ProhibitedEmailDomain.domain == registered_domain) .scalar() > 0 ): request.session.flash( - f"Email domain '{email_domain}' already exists.", queue="error" + f"Email domain '{registered_domain}' already exists.", queue="error" ) raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) @@ -88,7 +89,7 @@ def add_prohibited_email_domain(request): comment = request.POST.get("comment") prohibited_email_domain = ProhibitedEmailDomain( - domain=email_domain, + domain=registered_domain, is_mx_record=is_mx_record, prohibited_by=request.user, comment=comment, From fa450c1938eb279f0ed825dfb17d16ba7798d8a7 Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Thu, 19 Sep 2024 11:34:33 -0400 Subject: [PATCH 5/6] refactor query to use `exists()` subquery Signed-off-by: Mike Fiedler --- warehouse/admin/views/prohibited_email_domains.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/warehouse/admin/views/prohibited_email_domains.py b/warehouse/admin/views/prohibited_email_domains.py index b0e479ffe0bd..4a21103332ff 100644 --- a/warehouse/admin/views/prohibited_email_domains.py +++ b/warehouse/admin/views/prohibited_email_domains.py @@ -13,7 +13,7 @@ from paginate_sqlalchemy import SqlalchemyOrmPage as SQLAlchemyORMPage from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther from pyramid.view import view_config -from sqlalchemy import func +from sqlalchemy import exists, select from tldextract import extract from warehouse.accounts.models import ProhibitedEmailDomain @@ -73,11 +73,8 @@ def add_prohibited_email_domain(request): request.session.flash(f"Invalid domain name '{email_domain}'", queue="error") raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) # make sure we don't have a duplicate entry - if ( - request.db.query(func.count(ProhibitedEmailDomain.id)) - .filter(ProhibitedEmailDomain.domain == registered_domain) - .scalar() - > 0 + if request.db.scalar( + select(exists().where(ProhibitedEmailDomain.domain == registered_domain)) ): request.session.flash( f"Email domain '{registered_domain}' already exists.", queue="error" From 59702712ad063fd7953f07f430d1db99e5c5132c Mon Sep 17 00:00:00 2001 From: Mike Fiedler Date: Thu, 19 Sep 2024 11:48:41 -0400 Subject: [PATCH 6/6] fix: disallow using the live service during extraction If we have to do this a third time, we probably want to wrap the extractor in a utility function. Signed-off-by: Mike Fiedler --- warehouse/admin/views/prohibited_email_domains.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/warehouse/admin/views/prohibited_email_domains.py b/warehouse/admin/views/prohibited_email_domains.py index 4a21103332ff..73c4d9852aee 100644 --- a/warehouse/admin/views/prohibited_email_domains.py +++ b/warehouse/admin/views/prohibited_email_domains.py @@ -14,7 +14,7 @@ from pyramid.httpexceptions import HTTPBadRequest, HTTPSeeOther from pyramid.view import view_config from sqlalchemy import exists, select -from tldextract import extract +from tldextract import TLDExtract from warehouse.accounts.models import ProhibitedEmailDomain from warehouse.authnz import Permissions @@ -68,7 +68,8 @@ def add_prohibited_email_domain(request): request.session.flash("Email domain is required.", queue="error") raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list")) # validate that the domain is valid - registered_domain = extract(email_domain).registered_domain + extractor = TLDExtract(suffix_list_urls=()) # Updated during image build + registered_domain = extractor(email_domain).registered_domain if not registered_domain: request.session.flash(f"Invalid domain name '{email_domain}'", queue="error") raise HTTPSeeOther(request.route_path("admin.prohibited_email_domains.list"))