diff --git a/common/djangoapps/third_party_auth/models.py b/common/djangoapps/third_party_auth/models.py index af5159764c92..1ea265e868f2 100644 --- a/common/djangoapps/third_party_auth/models.py +++ b/common/djangoapps/third_party_auth/models.py @@ -366,13 +366,12 @@ class OAuth2ProviderConfig(ProviderConfig): .. no_pii: """ - # We are keying the provider config by backend_name here as suggested in the python social - # auth documentation. In order to reuse a backend for a second provider, a subclass can be - # created with seperate name. + # We are keying the provider config by backend_name and site_id to support configuration per site. + # In order to reuse a backend for a second provider, a subclass can be created with seperate name. # example: # class SecondOpenIDProvider(OpenIDAuth): # name = "second-openId-provider" - KEY_FIELDS = ('backend_name',) + KEY_FIELDS = ('site_id', 'backend_name') prefix = 'oa2' backend_name = models.CharField( max_length=50, blank=False, db_index=True, @@ -401,6 +400,29 @@ class Meta: verbose_name = "Provider Configuration (OAuth)" verbose_name_plural = verbose_name + @classmethod + def current(cls, *args): + """ + Get the current config model for the provider according to the given backend and the current + site. + """ + site_id = Site.objects.get_current(get_current_request()).id + return super(OAuth2ProviderConfig, cls).current(site_id, *args) + + @property + def provider_id(self): + """ + Unique string key identifying this provider. Must be URL and css class friendly. + Ignoring site_id as the config is filtered using current method which fetches the configuration for the current + site_id. + """ + assert self.prefix is not None + return "-".join((self.prefix, ) + tuple( + str(getattr(self, field)) + for field in self.KEY_FIELDS + if field != 'site_id' + )) + def clean(self): """ Standardize and validate fields """ super().clean() diff --git a/common/djangoapps/third_party_auth/pipeline.py b/common/djangoapps/third_party_auth/pipeline.py index 9135ea556bea..fde2bb9cbc0c 100644 --- a/common/djangoapps/third_party_auth/pipeline.py +++ b/common/djangoapps/third_party_auth/pipeline.py @@ -854,7 +854,7 @@ def user_details_force_sync(auth_entry, strategy, details, user=None, *args, **k This step is controlled by the `sync_learner_profile_data` flag on the provider's configuration. """ current_provider = provider.Registry.get_from_pipeline({'backend': strategy.request.backend.name, 'kwargs': kwargs}) - if user and current_provider.sync_learner_profile_data: + if user and current_provider and current_provider.sync_learner_profile_data: # Keep track of which incoming values get applied. changed = {} @@ -931,7 +931,7 @@ def set_id_verification_status(auth_entry, strategy, details, user=None, *args, Use the user's authentication with the provider, if configured, as evidence of their identity being verified. """ current_provider = provider.Registry.get_from_pipeline({'backend': strategy.request.backend.name, 'kwargs': kwargs}) - if user and current_provider.enable_sso_id_verification: + if user and current_provider and current_provider.enable_sso_id_verification: # Get previous valid, non expired verification attempts for this SSO Provider and user verifications = SSOVerification.objects.filter( user=user, diff --git a/common/djangoapps/third_party_auth/tests/test_provider.py b/common/djangoapps/third_party_auth/tests/test_provider.py index 3fa8f80f4d1a..28e95c16b8f9 100644 --- a/common/djangoapps/third_party_auth/tests/test_provider.py +++ b/common/djangoapps/third_party_auth/tests/test_provider.py @@ -11,7 +11,9 @@ from common.djangoapps.third_party_auth import provider from common.djangoapps.third_party_auth.tests import testutil from common.djangoapps.third_party_auth.tests.utils import skip_unless_thirdpartyauth -from openedx.core.djangoapps.site_configuration.tests.test_util import with_site_configuration +from openedx.core.djangoapps.site_configuration.tests.test_util import ( + with_site_configuration, with_site_configuration_context +) SITE_DOMAIN_A = 'professionalx.example.com' SITE_DOMAIN_B = 'somethingelse.example.com' @@ -114,13 +116,13 @@ def test_providers_displayed_for_login(self): assert no_log_in_provider.provider_id not in provider_ids assert normal_provider.provider_id in provider_ids - def test_tpa_hint_provider_displayed_for_login(self): + def test_tpa_hint_exp_hidden_provider_displayed_for_login(self): """ - Tests to ensure that an enabled-but-not-visible provider is presented + Test to ensure that an explicitly enabled-but-not-visible provider is presented for use in the UI when the "tpa_hint" parameter is specified + A hidden provider should be accessible with tpa_hint (this is the main case) """ - # A hidden provider should be accessible with tpa_hint (this is the main case) hidden_provider = self.configure_google_provider(visible=False, enabled=True) provider_ids = [ idp.provider_id @@ -128,8 +130,14 @@ def test_tpa_hint_provider_displayed_for_login(self): ] assert hidden_provider.provider_id in provider_ids - # New providers are hidden (ie, not flagged as 'visible') by default - # The tpa_hint parameter should work for these providers as well + def test_tpa_hint_hidden_provider_displayed_for_login(self): + """ + Tests to ensure that an implicitly enabled-but-not-visible provider is presented + for use in the UI when the "tpa_hint" parameter is specified. + New providers are hidden (ie, not flagged as 'visible') by default + The tpa_hint parameter should work for these providers as well. + """ + implicitly_hidden_provider = self.configure_linkedin_provider(enabled=True) provider_ids = [ idp.provider_id @@ -137,7 +145,10 @@ def test_tpa_hint_provider_displayed_for_login(self): ] assert implicitly_hidden_provider.provider_id in provider_ids - # Disabled providers should not be matched in tpa_hint scenarios + def test_tpa_hint_disabled_hidden_provider_displayed_for_login(self): + """ + Disabled providers should not be matched in tpa_hint scenarios + """ disabled_provider = self.configure_twitter_provider(visible=True, enabled=False) provider_ids = [ idp.provider_id @@ -145,7 +156,10 @@ def test_tpa_hint_provider_displayed_for_login(self): ] assert disabled_provider.provider_id not in provider_ids - # Providers not utilized for learner authentication should not match tpa_hint + def test_tpa_hint_no_log_hidden_provider_displayed_for_login(self): + """ + Providers not utilized for learner authentication should not match tpa_hint + """ no_log_in_provider = self.configure_lti_provider() provider_ids = [ idp.provider_id @@ -153,6 +167,30 @@ def test_tpa_hint_provider_displayed_for_login(self): ] assert no_log_in_provider.provider_id not in provider_ids + def test_get_current_site_oauth_provider(self): + """ + Verify that correct provider for current site is returned even if same backend is used for multiple sites. + """ + site_a = Site.objects.get_or_create(domain=SITE_DOMAIN_A, name=SITE_DOMAIN_A)[0] + site_b = Site.objects.get_or_create(domain=SITE_DOMAIN_B, name=SITE_DOMAIN_B)[0] + site_a_provider = self.configure_google_provider(visible=True, enabled=True, site=site_a) + site_b_provider = self.configure_google_provider(visible=True, enabled=True, site=site_b) + with with_site_configuration_context(domain=SITE_DOMAIN_A): + assert site_a_provider.enabled_for_current_site is True + + # Registry.displayed_for_login gets providers enabled for current site + provider_ids = provider.Registry.displayed_for_login() + # Google oauth provider for current site should be displayed + assert site_a_provider in provider_ids + assert site_b_provider not in provider_ids + + # Similarly, the other site should only see its own providers + with with_site_configuration_context(domain=SITE_DOMAIN_B): + assert site_b_provider.enabled_for_current_site is True + provider_ids = provider.Registry.displayed_for_login() + assert site_b_provider in provider_ids + assert site_a_provider not in provider_ids + def test_provider_enabled_for_current_site(self): """ Verify that enabled_for_current_site returns True when the provider matches the current site. @@ -201,7 +239,7 @@ def test_oauth2_enabled_only_for_supplied_backend(self): def test_get_returns_none_if_provider_id_is_none(self): assert provider.Registry.get(None) is None - def test_get_returns_none_if_provider_not_enabled(self): + def test_get_returns_none_if_provider_not_enabled_change(self): linkedin_provider_id = "oa2-linkedin-oauth2" # At this point there should be no configuration entries at all so no providers should be enabled assert provider.Registry.enabled() == [] @@ -209,6 +247,12 @@ def test_get_returns_none_if_provider_not_enabled(self): # Now explicitly disabled this provider: self.configure_linkedin_provider(enabled=False) assert provider.Registry.get(linkedin_provider_id) is None + + def test_get_returns_provider_if_provider_enabled(self): + """ + Test to ensure that Registry gets enabled providers. + """ + linkedin_provider_id = "oa2-linkedin-oauth2" self.configure_linkedin_provider(enabled=True) assert provider.Registry.get(linkedin_provider_id).provider_id == linkedin_provider_id diff --git a/common/djangoapps/third_party_auth/views.py b/common/djangoapps/third_party_auth/views.py index c6a406c5301c..d24ba8cfd7db 100644 --- a/common/djangoapps/third_party_auth/views.py +++ b/common/djangoapps/third_party_auth/views.py @@ -47,7 +47,7 @@ def inactive_user_view(request): if third_party_auth.is_enabled() and pipeline.running(request): running_pipeline = pipeline.get(request) third_party_provider = provider.Registry.get_from_pipeline(running_pipeline) - if third_party_provider.skip_email_verification and not activated: + if third_party_provider and third_party_provider.skip_email_verification and not activated: user.is_active = True user.save() activated = True