Skip to content

Commit

Permalink
Trusted publishing: surface pending publisher name collisions (pypi#1…
Browse files Browse the repository at this point in the history
…6260)

* IProjectService.check_project_name

Extract the distribution name validation logic into a separate function
for reuse elsewhere.

* Wire pending publisher forms to IProjectService

* Update tests

* Invalidate any conflicting publisher with a similar name

* Add pending_project_name_ultranormalized index

* Drop redundant normalization clause

* Update translations

* Remove unnecessary or_

Co-authored-by: Dustin Ingram <di@users.noreply.github.com>

* Rebase migration

---------

Co-authored-by: Dustin Ingram <di@users.noreply.github.com>
  • Loading branch information
twm and di authored Dec 2, 2024
1 parent 304671b commit d6d91d5
Show file tree
Hide file tree
Showing 23 changed files with 387 additions and 163 deletions.
87 changes: 54 additions & 33 deletions tests/unit/accounts/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
OrganizationRole,
OrganizationRoleType,
)
from warehouse.packaging.interfaces import IProjectService
from warehouse.packaging.models import Role, RoleInvitation
from warehouse.rate_limiting.interfaces import IRateLimiter

Expand Down Expand Up @@ -3306,8 +3307,17 @@ def test_reauth_no_user(self, monkeypatch, pyramid_request):

class TestManageAccountPublishingViews:
def test_initializes(self, metrics):
project_service = pretend.stub(check_project_name=lambda name: None)

def find_service(iface, name=None, context=None):
if iface is IMetricsService:
return metrics
if iface is IProjectService:
return project_service
return pretend.stub()

request = pretend.stub(
find_service=pretend.call_recorder(lambda *a, **kw: metrics),
find_service=pretend.call_recorder(find_service),
route_url=pretend.stub(),
POST=MultiDict(),
registry=pretend.stub(
Expand All @@ -3320,9 +3330,11 @@ def test_initializes(self, metrics):

assert view.request is request
assert view.metrics is metrics
assert view.project_service is project_service

assert view.request.find_service.calls == [
pretend.call(IMetricsService, context=None)
pretend.call(IMetricsService, context=None),
pretend.call(IProjectService, context=None),
]

@pytest.mark.parametrize(
Expand All @@ -3348,6 +3360,8 @@ def test_ratelimiting(self, metrics, ip_exceeded, user_exceeded):
def find_service(iface, name=None, context=None):
if iface is IMetricsService:
return metrics
if iface is IProjectService:
return pretend.stub(check_project_name=lambda name: None)

if name == "user_oidc.publisher.register":
return user_rate_limiter
Expand Down Expand Up @@ -3375,6 +3389,7 @@ def find_service(iface, name=None, context=None):
}
assert request.find_service.calls == [
pretend.call(IMetricsService, context=None),
pretend.call(IProjectService, context=None),
pretend.call(IRateLimiter, name="user_oidc.publisher.register"),
pretend.call(IRateLimiter, name="ip_oidc.publisher.register"),
]
Expand Down Expand Up @@ -3402,16 +3417,17 @@ def test_manage_publishing(self, metrics, monkeypatch):
"github.token": "fake-api-token",
}
),
find_service=pretend.call_recorder(lambda *a, **kw: metrics),
find_service=lambda svc, **kw: {
IMetricsService: metrics,
IProjectService: project_service,
}[svc],
flags=pretend.stub(
disallow_oidc=pretend.call_recorder(lambda f=None: False)
),
POST=pretend.stub(),
)

project_factory = pretend.stub()
project_factory_cls = pretend.call_recorder(lambda r: project_factory)
monkeypatch.setattr(views, "ProjectFactory", project_factory_cls)
project_service = pretend.stub(check_project_name=lambda name: None)

pending_github_publisher_form_obj = pretend.stub()
pending_github_publisher_form_cls = pretend.call_recorder(
Expand Down Expand Up @@ -3466,24 +3482,26 @@ def test_manage_publishing(self, metrics, monkeypatch):
pretend.call(AdminFlagValue.DISALLOW_GOOGLE_OIDC),
pretend.call(AdminFlagValue.DISALLOW_ACTIVESTATE_OIDC),
]
assert project_factory_cls.calls == [pretend.call(request)]
assert pending_github_publisher_form_cls.calls == [
pretend.call(
request.POST,
api_token="fake-api-token",
route_url=route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]
assert pending_gitlab_publisher_form_cls.calls == [
pretend.call(
request.POST,
route_url=route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]

def test_manage_publishing_admin_disabled(self, monkeypatch, pyramid_request):
project_service = pretend.stub(check_project_name=lambda name: None)
pyramid_request.find_service = lambda _, **kw: project_service

pyramid_request.user = pretend.stub()
pyramid_request.registry = pretend.stub(
settings={
Expand All @@ -3497,10 +3515,6 @@ def test_manage_publishing_admin_disabled(self, monkeypatch, pyramid_request):
flash=pretend.call_recorder(lambda *a, **kw: None)
)

project_factory = pretend.stub()
project_factory_cls = pretend.call_recorder(lambda r: project_factory)
monkeypatch.setattr(views, "ProjectFactory", project_factory_cls)

pending_github_publisher_form_obj = pretend.stub()
pending_github_publisher_form_cls = pretend.call_recorder(
lambda *a, **kw: pending_github_publisher_form_obj
Expand Down Expand Up @@ -3568,14 +3582,14 @@ def test_manage_publishing_admin_disabled(self, monkeypatch, pyramid_request):
pyramid_request.POST,
api_token="fake-api-token",
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]
assert pending_gitlab_publisher_form_cls.calls == [
pretend.call(
pyramid_request.POST,
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]

Expand Down Expand Up @@ -3607,6 +3621,12 @@ def test_manage_publishing_admin_disabled(self, monkeypatch, pyramid_request):
def test_add_pending_oidc_publisher_admin_disabled(
self, monkeypatch, pyramid_request, view_name, flag, publisher_name
):
project_service = pretend.stub(check_project_name=lambda name: None)
pyramid_request.find_service = lambda interface, **kwargs: {
IProjectService: project_service,
IMetricsService: pretend.stub(),
}[interface]

pyramid_request.user = pretend.stub()
pyramid_request.registry = pretend.stub(
settings={
Expand All @@ -3620,10 +3640,6 @@ def test_add_pending_oidc_publisher_admin_disabled(
flash=pretend.call_recorder(lambda *a, **kw: None)
)

project_factory = pretend.stub()
project_factory_cls = pretend.call_recorder(lambda r: project_factory)
monkeypatch.setattr(views, "ProjectFactory", project_factory_cls)

pending_github_publisher_form_obj = pretend.stub()
pending_github_publisher_form_cls = pretend.call_recorder(
lambda *a, **kw: pending_github_publisher_form_obj
Expand Down Expand Up @@ -3698,14 +3714,14 @@ def test_add_pending_oidc_publisher_admin_disabled(
pyramid_request.POST,
api_token="fake-api-token",
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]
assert pending_gitlab_publisher_form_cls.calls == [
pretend.call(
pyramid_request.POST,
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]

Expand Down Expand Up @@ -3741,7 +3757,14 @@ def test_add_pending_oidc_publisher_user_cannot_register(
view_name,
flag,
publisher_name,
metrics,
):
project_service = pretend.stub(check_project_name=lambda name: None)
pyramid_request.find_service = lambda interface, **kwargs: {
IProjectService: project_service,
IMetricsService: metrics,
}[interface]

pyramid_request.registry = pretend.stub(
settings={
"github.token": "fake-api-token",
Expand All @@ -3757,10 +3780,6 @@ def test_add_pending_oidc_publisher_user_cannot_register(
flash=pretend.call_recorder(lambda *a, **kw: None)
)

project_factory = pretend.stub()
project_factory_cls = pretend.call_recorder(lambda r: project_factory)
monkeypatch.setattr(views, "ProjectFactory", project_factory_cls)

pending_github_publisher_form_obj = pretend.stub()
pending_github_publisher_form_cls = pretend.call_recorder(
lambda *a, **kw: pending_github_publisher_form_obj
Expand Down Expand Up @@ -3839,14 +3858,14 @@ def test_add_pending_oidc_publisher_user_cannot_register(
pyramid_request.POST,
api_token="fake-api-token",
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]
assert pending_gitlab_publisher_form_cls.calls == [
pretend.call(
pyramid_request.POST,
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]

Expand Down Expand Up @@ -4474,6 +4493,12 @@ def test_add_pending_oidc_publisher(
def test_delete_pending_oidc_publisher_admin_disabled(
self, monkeypatch, pyramid_request
):
project_service = pretend.stub(check_project_name=lambda name: None)
pyramid_request.find_service = lambda interface, **kwargs: {
IProjectService: project_service,
IMetricsService: pretend.stub(),
}[interface]

pyramid_request.user = pretend.stub()
pyramid_request.registry = pretend.stub(
settings={
Expand All @@ -4487,10 +4512,6 @@ def test_delete_pending_oidc_publisher_admin_disabled(
flash=pretend.call_recorder(lambda *a, **kw: None)
)

project_factory = pretend.stub()
project_factory_cls = pretend.call_recorder(lambda r: project_factory)
monkeypatch.setattr(views, "ProjectFactory", project_factory_cls)

pending_github_publisher_form_obj = pretend.stub()
pending_github_publisher_form_cls = pretend.call_recorder(
lambda *a, **kw: pending_github_publisher_form_obj
Expand Down Expand Up @@ -4558,14 +4579,14 @@ def test_delete_pending_oidc_publisher_admin_disabled(
pyramid_request.POST,
api_token="fake-api-token",
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]
assert pending_gitlab_publisher_form_cls.calls == [
pretend.call(
pyramid_request.POST,
route_url=pyramid_request.route_url,
project_factory=project_factory,
check_project_name=project_service.check_project_name,
)
]

Expand Down
27 changes: 16 additions & 11 deletions tests/unit/oidc/forms/test_activestate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from webob.multidict import MultiDict

from warehouse.oidc.forms import activestate
from warehouse.packaging.interfaces import ProjectNameUnavailableReason

fake_username = "some-username"
fake_org_name = "some-org"
Expand All @@ -30,14 +31,13 @@
_requests = requests


def _raise(exception):
raise exception


class TestPendingActiveStatePublisherForm:
def test_validate(self, monkeypatch):
project_factory = []
route_url = pretend.stub()

def check_project_name(name):
return None

data = MultiDict(
{
"organization": "some-org",
Expand All @@ -47,24 +47,29 @@ def test_validate(self, monkeypatch):
}
)
form = activestate.PendingActiveStatePublisherForm(
MultiDict(data), route_url=route_url, project_factory=project_factory
MultiDict(data),
route_url=route_url,
check_project_name=check_project_name,
)

# Test built-in validations
monkeypatch.setattr(form, "_lookup_actor", lambda *o: {"user_id": "some-id"})

monkeypatch.setattr(form, "_lookup_organization", lambda *o: None)

assert form._project_factory == project_factory
assert form._check_project_name == check_project_name
assert form._route_url == route_url
assert form.validate()

def test_validate_project_name_already_in_use(self, pyramid_config):
project_factory = ["some-project"]
route_url = pretend.call_recorder(lambda *args, **kwargs: "")

def check_project_name(name):
return ProjectNameUnavailableReason.AlreadyExists

form = activestate.PendingActiveStatePublisherForm(
route_url=route_url, project_factory=project_factory
route_url=route_url,
check_project_name=check_project_name,
)

field = pretend.stub(data="some-project")
Expand Down Expand Up @@ -208,7 +213,7 @@ def test_lookup_actor_non_json(self, monkeypatch):
response = pretend.stub(
status_code=200,
raise_for_status=pretend.call_recorder(lambda: None),
json=lambda: _raise(_requests.exceptions.JSONDecodeError("", "", 0)),
json=pretend.raiser(_requests.exceptions.JSONDecodeError("", "", 0)),
content=b"",
)

Expand Down Expand Up @@ -437,7 +442,7 @@ def test_lookup_organization_non_json(self, monkeypatch):
response = pretend.stub(
status_code=200,
raise_for_status=pretend.call_recorder(lambda: None),
json=lambda: _raise(_requests.exceptions.JSONDecodeError("", "", 0)),
json=pretend.raiser(_requests.exceptions.JSONDecodeError("", "", 0)),
content=b"",
)

Expand Down
Loading

0 comments on commit d6d91d5

Please sign in to comment.