diff --git a/src/k8s-extension/azext_k8s_extension/custom.py b/src/k8s-extension/azext_k8s_extension/custom.py index c88461bb48e..fcb182b64b7 100644 --- a/src/k8s-extension/azext_k8s_extension/custom.py +++ b/src/k8s-extension/azext_k8s_extension/custom.py @@ -34,6 +34,7 @@ DefaultExtension, user_confirmation_factory, ) +from .partner_extensions.WorkloadIAM import WorkloadIAM from . import consts from ._client_factory import cf_resources @@ -51,6 +52,7 @@ def ExtensionFactory(extension_name): "microsoft.azureml.kubernetes": AzureMLKubernetes, "microsoft.dapr": Dapr, "microsoft.dataprotection.kubernetes": DataProtectionKubernetes, + "microsoft.workloadiam": WorkloadIAM, } # Return the extension if we find it in the map, else return the default diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/WorkloadIAM.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/WorkloadIAM.py new file mode 100644 index 00000000000..6e47a151528 --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/WorkloadIAM.py @@ -0,0 +1,139 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import subprocess + +from knack.log import get_logger +from knack.util import CLIError + +from azure.cli.core.azclierror import InvalidArgumentValueError + +from ..vendored_sdks.models import (Extension, Scope, ScopeCluster) + +from .DefaultExtension import DefaultExtension + +logger = get_logger(__name__) + +CONFIG_SETTINGS_USER_TRUST_DOMAIN = 'trustDomain' +CONFIG_SETTINGS_USER_LOCAL_AUTHORITY = 'localAuthority' +CONFIG_SETTINGS_USER_TENANT_ID = 'tenantID' +CONFIG_SETTINGS_USER_JOIN_TOKEN = 'joinToken' + +CONFIG_SETTINGS_HELM_TRUST_DOMAIN = 'global.workload-iam.trustDomain' +CONFIG_SETTINGS_HELM_TENANT_ID = 'global.workload-iam.tenantID' +CONFIG_SETTINGS_HELM_JOIN_TOKEN = 'workload-iam-local-authority.localAuthorityArgs.joinToken' + +class WorkloadIAM(DefaultExtension): + + def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_type, cluster_rp, + extension_type, scope, auto_upgrade_minor_version, release_train, version, target_namespace, + release_namespace, configuration_settings, configuration_protected_settings, + configuration_settings_file, configuration_protected_settings_file, + plan_name, plan_publisher, plan_product): + """ + Create method for ExtensionType 'microsoft.workloadiam'. + """ + + # Ensure that the values provided by the user for generic values of Arc extensions are + # valid, set sensible default values if not. + if release_train is None: + # TODO - Set this to 'stable' when the extension is ready + release_train = 'preview' + + if scope is None: + scope = 'cluster' + elif scope != 'cluster': + raise InvalidArgumentValueError( + f"Invalid scope '{scope}'. This extension can only be installed at 'cluster' scope.") + + # Scope is always cluster + scope_cluster = ScopeCluster(release_namespace=release_namespace) + ext_scope = Scope(cluster=scope_cluster, namespace=None) + + # Get user configuration values and remove them from the dictionary + trust_domain = configuration_settings.pop(CONFIG_SETTINGS_USER_TRUST_DOMAIN, None) + tenant_id = configuration_settings.pop(CONFIG_SETTINGS_USER_TENANT_ID, None) + local_authority = configuration_settings.pop(CONFIG_SETTINGS_USER_LOCAL_AUTHORITY, None) + join_token = configuration_settings.pop(CONFIG_SETTINGS_USER_JOIN_TOKEN, None) + + # A trust domain name is always required + if trust_domain is None: + raise InvalidArgumentValueError( + f"Invalid configuration settings '{configuration_settings}'. Please provide a trust " + "domain name.") + + if tenant_id is None: + raise InvalidArgumentValueError( + f"Invalid configuration settings '{configuration_settings}'. Please provide a " + "tenant ID.") + + # If the user hasn't provided a join token, create one + if join_token is None: + if local_authority is None: + raise InvalidArgumentValueError( + f"Invalid configuration settings '{configuration_settings}'. Either a join " + "token or a local authority name must be provided.") + join_token = self.get_join_token(trust_domain, local_authority) + else: + logger.info("Join token is provided") + + # Save configuration setting values to overwrite values in the Helm chart + configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] = trust_domain + configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] = tenant_id + configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] = join_token + + logger.debug("Configuration settings value for Helm: %s", str(configuration_settings)) + + create_identity = True + extension = Extension( + extension_type=extension_type, + auto_upgrade_minor_version=auto_upgrade_minor_version, + release_train=release_train, + version=version, + scope=ext_scope, + configuration_settings=configuration_settings, + configuration_protected_settings=configuration_protected_settings + ) + return extension, name, create_identity + + + def get_join_token(self, trust_domain, local_authority): + """ + Invoke the az command to obtain a join token. + """ + + logger.info("Getting a join token from the control plane") + + # Invoke az workload-iam command to obtain the join token + cmd = [ + "az", "workload-iam", "local-authority", "attestation-method", "create", + "--td", trust_domain, + "--la", local_authority, + "--type", "joinTokenAttestationMethod", + "--query", "singleUseToken", + "--dn", "myJoinToken", + ] + cmd_str = " ".join(cmd) + + try: + # Note: We can't use get_default_cli() here because its invoke() method + # always prints the console output, which we want to avoid. + result = subprocess.run(cmd, capture_output=True, shell=True) + except Exception as e: + logger.error(f"Error while generating a join token: {cmd_str}") + raise e + + if result.returncode != 0: + raise CLIError(f"Failed to generate a join token (exit code {result.returncode}): {cmd_str}") + + try: + # Strip double quotes from the output + command_output = result.stdout.decode("utf-8") + token = command_output.strip("\r\n").strip("\"") + except Exception as e: + logger.error(f"Failed to parse output of join token command: {cmd_str}") + raise e + + return token diff --git a/src/k8s-extension/azext_k8s_extension/tests/latest/test_workload_iam.py b/src/k8s-extension/azext_k8s_extension/tests/latest/test_workload_iam.py new file mode 100644 index 00000000000..d20d6e5d22d --- /dev/null +++ b/src/k8s-extension/azext_k8s_extension/tests/latest/test_workload_iam.py @@ -0,0 +1,318 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=protected-access + +import unittest + +from azure.cli.core.azclierror import InvalidArgumentValueError +from azext_k8s_extension.partner_extensions.WorkloadIAM import ( + WorkloadIAM, + CONFIG_SETTINGS_USER_TRUST_DOMAIN, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY, + CONFIG_SETTINGS_USER_TENANT_ID, + CONFIG_SETTINGS_USER_JOIN_TOKEN, + CONFIG_SETTINGS_HELM_TRUST_DOMAIN, + CONFIG_SETTINGS_HELM_TENANT_ID, + CONFIG_SETTINGS_HELM_JOIN_TOKEN +) + +from knack.util import CLIError + +from unittest.mock import patch + +class TestWorkloadIAM(unittest.TestCase): + + def test_workload_iam_create_without_join_token_success(self): + """ + Test that, when the user doesn't provide a join token, the Create() method calls + get_join_token() and creates a new one, and that the final configuration settings + are the expected ones. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_tenant_id = 'any_tenant_id' + mock_join_token = 'any_join_token' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + } + + def mock_extension_init(_self, *, extension_type, auto_upgrade_minor_version, release_train, + version, scope, configuration_settings, configuration_protected_settings): + assert(release_train == "dev") + assert(configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] == mock_join_token); + assert(configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] == mock_trust_domain_name) + assert(configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] == mock_tenant_id) + + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.Extension.__init__', + new=mock_extension_init), \ + patch('azext_k8s_extension.partner_extensions.WorkloadIAM.WorkloadIAM.get_join_token', + return_value=mock_join_token): + + # Test & assert + workload_iam = WorkloadIAM() + _, name, _ = workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + self.assertEqual(name, 'workload-iam') + + + def test_workload_iam_create_with_join_token_and_local_authority_success(self): + """ + Test that, when the user provides a join token, the Create() method doesn't call + get_join_token(), and that the final configuration settings are the expected ones. The + provided local authority is only required to generate a new join token. As no token is + created, the local authority will just be ignored. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_tenant_id = 'any_tenant_id' + mock_join_token = 'any_join_token' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + CONFIG_SETTINGS_USER_JOIN_TOKEN: mock_join_token, + } + + def mock_extension_init(_self, *, extension_type, auto_upgrade_minor_version, release_train, + version, scope, configuration_settings, configuration_protected_settings): + assert(release_train == "dev") + assert(configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] == mock_join_token); + assert(configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] == mock_trust_domain_name) + assert(configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] == mock_tenant_id) + + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.Extension.__init__', + new=mock_extension_init), \ + patch('azext_k8s_extension.partner_extensions.WorkloadIAM.WorkloadIAM.get_join_token', + return_value='BAD_JOIN_TOKEN'): + + # Test & assert + workload_iam = WorkloadIAM() + _, name, _ = workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + self.assertEqual(name, 'workload-iam') + + + def test_workload_iam_create_with_join_token_and_no_local_authority_success(self): + """ + Test that, when the user provides a join token, the Create() method doesn't call + get_join_token(), and that the final configuration settings are the expected ones. The + provided local authority is only required to generate a new join token, so the test should + pass even without it. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_tenant_id = 'any_tenant_id' + mock_join_token = 'any_join_token' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_JOIN_TOKEN: mock_join_token, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + } + + def mock_extension_init(_self, *, extension_type, auto_upgrade_minor_version, release_train, + version, scope, configuration_settings, configuration_protected_settings): + assert(release_train == "dev") + assert(configuration_settings[CONFIG_SETTINGS_HELM_JOIN_TOKEN] == mock_join_token); + assert(configuration_settings[CONFIG_SETTINGS_HELM_TRUST_DOMAIN] == mock_trust_domain_name) + assert(configuration_settings[CONFIG_SETTINGS_HELM_TENANT_ID] == mock_tenant_id) + + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.Extension.__init__', + new=mock_extension_init), \ + patch('azext_k8s_extension.partner_extensions.WorkloadIAM.WorkloadIAM.get_join_token', + return_value='BAD_JOIN_TOKEN'): + + # Test & assert + workload_iam = WorkloadIAM() + _, name, _ = workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + self.assertEqual(name, 'workload-iam') + + def test_workload_iam_create_with_trust_domain_local_authority_no_tenant_id(self): + """ + Test that, when the user doesn't provide a tenant ID, there is an error. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + } + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + self.assertEqual(str(context.exception), + f"Invalid configuration settings '{settings}'. Please provide a tenant ID.") + + def test_workload_iam_create_with_wrong_scope_fails(self): + """ + Test that when the user provides a scope that isn't "cluster" the method Create() fails. + """ + + bad_scope = 'namespace' + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope=bad_scope, auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=None, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + self.assertEqual(str(context.exception), + f"Invalid scope '{bad_scope}'. This extension can only be installed at 'cluster' scope.") + + + def test_workload_iam_create_with_not_enough_settings_fails(self): + """ + Test that when the user doesn't provide the trust domain or local authority the method + Create() fails. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_tenant_id = 'any_tenant_id' + + # Missing local authority + + settings = { + CONFIG_SETTINGS_USER_TRUST_DOMAIN: mock_trust_domain_name, + CONFIG_SETTINGS_USER_TENANT_ID: mock_tenant_id, + } + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + self.assertEqual(str(context.exception), + f"Invalid configuration settings '{settings}'. Either a join token or a local " + "authority name must be provided.") + + # Missing trust domain + + settings = { + CONFIG_SETTINGS_USER_LOCAL_AUTHORITY: mock_local_authority_name, + } + + with self.assertRaises(InvalidArgumentValueError) as context: + workload_iam = WorkloadIAM() + workload_iam.Create(cmd=None, client=None, resource_group_name=None, + cluster_name=None, name='workload-iam', cluster_type=None, cluster_rp=None, + extension_type=None, scope='cluster', auto_upgrade_minor_version=None, + release_train='dev', version='0.1.0', target_namespace=None, + release_namespace=None, configuration_settings=settings, + configuration_protected_settings=None, configuration_settings_file=None, + configuration_protected_settings_file=None, plan_name=None, plan_publisher=None, + plan_product=None) + + self.assertEqual(str(context.exception), + f"Invalid configuration settings '{settings}'. Please provide a trust " + "domain name.") + + def test_workload_iam_get_join_token_with_valid_argument_success(self): + """ + Test that when get_join_token() succeedes it returns a token in the right format (between + double quotes) and that the arguments passed to "az workload-iam" are the expected ones. + """ + + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_join_token = 'any_join_token' + + class MockResult(): + def __init__(self): + self.returncode = 0 + self.stdout = ('\"' + mock_join_token + '\"').encode('utf-8') + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.subprocess.run', + return_value=MockResult()): + # Test & assert + workload_iam = WorkloadIAM() + join_token = workload_iam.get_join_token(mock_trust_domain_name, mock_local_authority_name) + self.assertEqual(join_token, mock_join_token) + + + def test_workload_iam_get_join_token_with_bad_exit_code(self): + """ + Test that get_join_token() fails with the right error message if "az workload-iam" returns a + non-zero error code (and if no exception is raised). + """ + + # Set up mocks + mock_trust_domain_name = 'any_trust_domain_name.com' + mock_local_authority_name = 'any_local_authority_name' + mock_join_token = 'any_join_token' + mock_exit_code = 1 + + cmd = [ + "az", "workload-iam", "local-authority", "attestation-method", "create", + "--td", mock_trust_domain_name, + "--la", mock_local_authority_name, + "--type", "joinTokenAttestationMethod", + "--query", "singleUseToken", + "--dn", "myJoinToken", + ] + + class MockResult(): + def __init__(self): + self.returncode = mock_exit_code + + with patch('azext_k8s_extension.partner_extensions.WorkloadIAM.subprocess.run', + return_value=MockResult()): + # Test & assert + workload_iam = WorkloadIAM() + cmd_str = " ".join(cmd) + self.assertRaisesRegex(CLIError, + f"Failed to generate a join token \(exit code {mock_exit_code}\): {cmd_str}", + workload_iam.get_join_token, mock_trust_domain_name, mock_local_authority_name)