diff --git a/plugins/module_utils/azure_rm_common.py b/plugins/module_utils/azure_rm_common.py index 17dd72a7d..ce5e5c6b3 100644 --- a/plugins/module_utils/azure_rm_common.py +++ b/plugins/module_utils/azure_rm_common.py @@ -1606,9 +1606,29 @@ def _get_profile(self, profile="default"): return None - def _get_msi_credentials(self, subscription_id=None, client_id=None, **kwargs): - credentials = MSIAuthentication(client_id=client_id) - credential = MSIAuthenticationWrapper(client_id=client_id) + def _get_msi_credentials(self, subscription_id=None, client_id=None, _cloud_environment=None, **kwargs): + # Get object `cloud_environment` from string `_cloud_environment` + cloud_environment = None + if not _cloud_environment: + cloud_environment = azure_cloud.AZURE_PUBLIC_CLOUD + else: + # try to look up "well-known" values via the name attribute on azure_cloud members + all_clouds = [x[1] for x in inspect.getmembers(azure_cloud) if isinstance(x[1], azure_cloud.Cloud)] + matched_clouds = [x for x in all_clouds if x.name == _cloud_environment] + if len(matched_clouds) == 1: + cloud_environment = matched_clouds[0] + elif len(matched_clouds) > 1: + self.fail("Azure SDK failure: more than one cloud matched for cloud_environment name '{0}'".format(_cloud_environment)) + else: + if not urlparse.urlparse(_cloud_environment).scheme: + self.fail("cloud_environment must be an endpoint discovery URL or one of {0}".format([x.name for x in all_clouds])) + try: + cloud_environment = azure_cloud.get_cloud_from_metadata_endpoint(_cloud_environment) + except Exception as exc: + self.fail("cloud_environment {0} could not be resolved: {1}".format(_cloud_environment, str(exc)), exception=traceback.format_exc()) + + credentials = MSIAuthentication(client_id=client_id, cloud_environment=cloud_environment) + credential = MSIAuthenticationWrapper(client_id=client_id, cloud_environment=cloud_environment) subscription_id = subscription_id or self._get_env('subscription_id') if not subscription_id: try: @@ -1623,6 +1643,7 @@ def _get_msi_credentials(self, subscription_id=None, client_id=None, **kwargs): 'credentials': credentials, 'credential': credential, 'subscription_id': subscription_id, + 'cloud_environment': cloud_environment, 'auth_source': 'msi' } @@ -1670,7 +1691,8 @@ def _get_credentials(self, auth_source=None, **params): if auth_source == 'msi': self.log('Retrieving credentials from MSI') - return self._get_msi_credentials(subscription_id=params.get('subscription_id'), client_id=params.get('client_id')) + return self._get_msi_credentials(subscription_id=params.get('subscription_id'), client_id=params.get('client_id'), + _cloud_environment=params.get('cloud_environment')) if auth_source == 'cli': if not HAS_AZURE_CLI_CORE: