Skip to content

Commit

Permalink
Adding in some functionality to use different clouds (#28528)
Browse files Browse the repository at this point in the history
* Adding in some functionality to use different clouds

This support is needed for the air gapped environments.  There are three
ways to add a new cloud environment.  These were all taken from examples
in the v1 sdk.

1) The SDK will look for a default configuration file and try to find
   cloud environments in there.
2) If you set an environment variable called ARM_CLOUD_METADATA_URL, it
   will look there for cloud configurations.  If you do not set this, it
   will use a default URL (ArmConstants.DEFAULT_URL in constants/_common.py)
   to find them.
3) The SDK exposes two new functions, add_cloud which will add the new
   configuration to the configuration file mentioned in #1, and
   update_cloud which will update the added configuration.

* Removing some of the functionality, only ARM check remains

* Adding unit test for new environments functionality

* fixed tests with mock

* removed print statement

* removed commented code

* removed print statements (oops)

* fixed tests and removed comments

* Fixing a testing bug

* Fixing a misspelled word

* Changing how we reach out to ARM, also fixing some pylint

* Fixing more lint errors

* Fixing more lint errors

* updated code per suggestions in PR

* fixed typo in warning

* added registry_endpoint to metadata url, also added tests for making sure all endpointurls are registered

* updated how the registry discovery endpoint is created. Uses a default region but region can be updated with environment variable

* fixed linting errors

* moved discovery url logic around to make sure it's not overwriting public regions

* Fixing small pylint errors

* Moving over to using HttpPipeline instead of requests

* fixed up based on comments in the PR

* fixed broken unit tests and mocked correctly

* Fixing pylint issues

---------

Co-authored-by: Ronald Shaw <ronaldshaw@microsoft.com>
  • Loading branch information
brownma-ms and ronaldshaw-work authored Feb 7, 2023
1 parent 77ae304 commit c3b7b30
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 11 deletions.
104 changes: 97 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_azure_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

from azure.ai.ml._utils.utils import _get_mfe_url_override
from azure.ai.ml.constants._common import AZUREML_CLOUD_ENV_NAME
from azure.ai.ml.constants._common import ArmConstants
from azure.core.rest import HttpRequest
from azure.mgmt.core import ARMPipelineClient



module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,6 +61,19 @@ class EndpointURLS: # pylint: disable=too-few-public-methods,no-init
},
}

# Module-level name initialised to None.
# NOTE(review): nothing in the visible code reads or reassigns this name - confirm it is still needed.
_requests_pipeline = None

def _get_cloud(cloud: str):
    """Return the configuration of the named cloud.

    A cloud already present in ``_environments`` is returned directly.
    Otherwise the clouds are loaded from the ARM metadata URL (read from the
    environment variable named by ``ArmConstants.METADATA_URL_ENV_NAME``,
    falling back to ``ArmConstants.DEFAULT_URL``) and the requested cloud is
    looked up in that result.

    :param cloud: cloud name
    :return: configuration of the requested cloud
    :raises Exception: if the cloud is neither in ``_environments`` nor in the
        clouds loaded from the metadata URL
    """
    if cloud in _environments:
        return _environments[cloud]

    arm_url = os.environ.get(ArmConstants.METADATA_URL_ENV_NAME, ArmConstants.DEFAULT_URL)
    arm_clouds = _get_clouds_by_metadata_url(arm_url)
    try:
        # Only this lookup can raise the KeyError that is translated below.
        new_cloud = arm_clouds[cloud]
    except KeyError as ex:
        # Chain explicitly so the failed lookup stays visible as the cause.
        raise Exception('Unknown cloud environment "{0}".'.format(cloud)) from ex

    # NOTE(review): dict.update() merges the cloud's own entries into the top
    # level of _environments; it does not store the cloud under its name, so
    # the membership check above will not find it on a later call. Confirm
    # whether `_environments[cloud] = new_cloud` was intended.
    _environments.update(new_cloud)
    return new_cloud

def _get_default_cloud_name():
"""Return AzureCloud as the default cloud."""
Expand All @@ -74,17 +92,18 @@ def _get_cloud_details(cloud: str = AzureEnvironments.ENV_DEFAULT):
AzureEnvironments.ENV_DEFAULT,
)
cloud = _get_default_cloud_name()
try:
azure_environment = _environments[cloud]
module_logger.debug("Using the cloud configuration: '%s'.", azure_environment)
except KeyError:
raise Exception('Unknown cloud environment "{0}".'.format(cloud))
return azure_environment
return _get_cloud(cloud)


def _set_cloud(cloud: str = AzureEnvironments.ENV_DEFAULT):
"""Sets the current cloud
:param cloud: cloud name
"""
if cloud is not None:
if cloud not in _environments:
try:
_get_cloud(cloud)
except Exception:
raise Exception('Unknown cloud environment supplied: "{0}".'.format(cloud))
else:
cloud = _get_default_cloud_name()
Expand Down Expand Up @@ -189,3 +208,74 @@ def _resource_to_scopes(resource):
"""
scope = resource + "/.default"
return [scope]

def _get_registry_discovery_url(cloud, cloud_suffix=""):
    """Get or generate the registry discovery url

    For a cloud whose name is already a key of ``_environments`` the stored
    registry discovery endpoint is returned. Otherwise the url comes from the
    environment variable named by ``ArmConstants.REGISTRY_ENV_URL`` when it is
    set, and is generated from the cloud name, a region and the suffix when it
    is not.

    :param cloud: configuration of the cloud to get the registry_discovery_url from
    :param cloud_suffix: the suffix to use for the cloud, in the case that the registry_discovery_url
        must be generated
    :return: string of discovery url
    :raises KeyError: if ``cloud`` has no "name" entry, or the known cloud has
        no registry discovery endpoint entry
    """
    cloud_name = cloud["name"]
    if cloud_name in _environments:
        # Cloud configurations in this module are indexed by EndpointURLS keys
        # (see _convert_arm_to_cli), so the endpoint is read by key; attribute
        # access (`.registry_url`) fails on such a mapping.
        return _environments[cloud_name][EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT]

    # Region part of the generated host name; overridable through the environment.
    registry_discovery_region = os.environ.get(
        ArmConstants.REGISTRY_DISCOVERY_REGION_ENV_NAME,
        ArmConstants.REGISTRY_DISCOVERY_DEFAULT_REGION,
    )
    generated_discovery_url = "https://{}{}.api.azureml.{}/".format(
        cloud_name.lower(),
        registry_discovery_region,
        cloud_suffix,
    )
    # An explicitly configured discovery url takes precedence over the generated one.
    return os.environ.get(ArmConstants.REGISTRY_ENV_URL, generated_discovery_url)

def _get_clouds_by_metadata_url(metadata_url):
    """Get all the clouds by the specified metadata url

    Any failure while requesting, decoding or converting the metadata is
    logged as a warning and results in an empty dict instead of an exception.

    :param metadata_url: url the cloud metadata is requested from
    :return: dict of the clouds keyed by cloud name, empty when loading failed
    """
    try:
        module_logger.debug('Start : Loading cloud metadata from the url specified by %s', metadata_url)
        # Both the client and the response are context-managed so they are
        # released once the payload has been read.
        with ARMPipelineClient(base_url=metadata_url, policies=[]) as client:
            with client.send_request(HttpRequest("GET", metadata_url)) as meta_response:
                arm_cloud_dict = meta_response.json()
        cli_cloud_dict = _convert_arm_to_cli(arm_cloud_dict)
        module_logger.debug('Finish : Loading cloud metadata from the url specified by %s', metadata_url)
        return cli_cloud_dict
    except Exception as ex:  # pylint: disable=broad-except
        # Deliberately best-effort: the caller reports an unknown cloud when
        # nothing could be loaded.
        module_logger.warning("Error: Azure ML was unable to load cloud metadata from the url specified by %s. %s. "
                              "This may be due to a misconfiguration of networking controls. Azure Machine Learning Python "
                              "SDK requires outbound access to Azure Resource Manager. Please contact your networking team "
                              "to configure outbound access to Azure Resource Manager on both Network Security Group and "
                              "Firewall. For more details on required configurations, see "
                              "https://docs.microsoft.com/azure/machine-learning/how-to-access-azureml-behind-firewall.",
                              metadata_url, ex)
        return {}

def _convert_arm_to_cli(arm_cloud_metadata):
    """Convert ARM cloud metadata into configurations keyed by cloud name.

    A single metadata dict is handled as a one-element list. An entry missing
    a required property is reported with a warning and left out of the result.

    :param arm_cloud_metadata: one ARM cloud metadata dict, or an iterable of them
    :return: dict mapping each cloud name to its dict of EndpointURLS entries
    """
    if isinstance(arm_cloud_metadata, dict):
        entries = [arm_cloud_metadata]
    else:
        entries = arm_cloud_metadata

    converted = {}
    for entry in entries:
        try:
            name = entry["name"]
            portal = entry["portal"]
            # Suffix = the portal endpoint without its first two dot-separated
            # parts, with any "/" removed.
            suffix = ".".join(portal.split('.')[2:]).replace("/", "")
            discovery_url = _get_registry_discovery_url(entry, suffix)
            endpoints = {
                EndpointURLS.AZURE_PORTAL_ENDPOINT: portal,
                EndpointURLS.RESOURCE_MANAGER_ENDPOINT: entry["resourceManager"],
                EndpointURLS.ACTIVE_DIRECTORY_ENDPOINT: entry["authentication"]["loginEndpoint"],
                EndpointURLS.AML_RESOURCE_ID: "https://ml.azure.{}".format(suffix),
                EndpointURLS.STORAGE_ENDPOINT: entry["suffixes"]["storage"],
                EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT: discovery_url,
            }
        except KeyError as ex:
            module_logger.warning("Property on cloud not found in arm cloud metadata: %s", ex)
        else:
            converted[name] = endpoints
    return converted
6 changes: 6 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,12 @@ class ArmConstants(object):
AZURE_MGMT_KEYVAULT_API_VERSION = "2019-09-01"
AZURE_MGMT_CONTAINER_REG_API_VERSION = "2019-05-01"

DEFAULT_URL = "https://management.azure.com/metadata/endpoints?api-version=2019-05-01"
METADATA_URL_ENV_NAME = "ARM_CLOUD_METADATA_URL"
REGISTRY_DISCOVERY_DEFAULT_REGION = "west"
REGISTRY_DISCOVERY_REGION_ENV_NAME = "REGISTRY_DISCOVERY_ENDPOINT_REGION"
REGISTRY_ENV_URL = "REGISTRY_DISCOVERY_ENDPOINT_URL"


class HttpResponseStatusCode(object):
NOT_FOUND = 404
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,66 @@
import os

import mock
import pytest
from mock import MagicMock, patch

from azure.ai.ml._azure_environments import (
AzureEnvironments,
EndpointURLS,
_get_azure_portal_id_from_metadata,
_get_base_url_from_metadata,
_get_cloud_details,
_get_cloud_information_from_metadata,
_get_default_cloud_name,
_get_registry_discovery_endpoint_from_metadata,
_get_storage_endpoint_from_metadata,
_set_cloud,
)
from azure.ai.ml.constants._common import AZUREML_CLOUD_ENV_NAME
from azure.ai.ml.constants._common import ArmConstants, AZUREML_CLOUD_ENV_NAME
from azure.mgmt.core import ARMPipelineClient

def mocked_send_request_get(*args, **kwargs):
    """Return a fake response object for use as a ``send_request`` side effect.

    All arguments are ignored. The returned object has ``status_code`` 201,
    works as a context manager yielding itself, and its ``json()`` returns
    three cloud entries: two complete ones (TEST_ENV, TEST_ENV2) and one that
    only has a name (MISCONFIGURED).
    """

    def _complete_entry(name, portal):
        # The two complete entries differ only in name and portal endpoint.
        return {
            "name": name,
            "portal": portal,
            "resourceManager": "testresourcemanager.azure.com",
            "authentication": {"loginEndpoint": "testdirectoryendpoint.azure.com"},
            "suffixes": {"storage": "teststorageendpoint"},
        }

    class MockResponse:
        def __init__(self):
            self.status_code = 201

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            return None

        def json(self):
            # A fresh list is built on every call.
            return [
                _complete_entry("TEST_ENV", "testportal.azure.com"),
                _complete_entry("TEST_ENV2", "testportal.azure.windows.net"),
                {"name": "MISCONFIGURED"},
            ]

    return MockResponse()


@pytest.mark.unittest
@pytest.mark.core_sdk_test
class TestCloudEnvironments:

@mock.patch.dict(os.environ, {AZUREML_CLOUD_ENV_NAME: AzureEnvironments.ENV_DEFAULT}, clear=True)
def test_set_valid_cloud_details_china(self):
cloud_environment = AzureEnvironments.ENV_CHINA
Expand Down Expand Up @@ -70,7 +112,6 @@ def test_get_default_cloud(self):
with mock.patch("os.environ", {AZUREML_CLOUD_ENV_NAME: "yadadada"}):
cloud_name = _get_default_cloud_name()
assert cloud_name == "yadadada"


def test_get_registry_endpoint_from_public(self):
cloud_environment = AzureEnvironments.ENV_DEFAULT
Expand All @@ -88,4 +129,36 @@ def test_get_registry_endpoint_from_us_gov(self):
cloud_environment = AzureEnvironments.ENV_US_GOVERNMENT
_set_cloud(cloud_environment)
base_url = _get_registry_discovery_endpoint_from_metadata(cloud_environment)
assert "https://usgovarizona.api.ml.azure.us/" in base_url
assert "https://usgovarizona.api.ml.azure.us/" in base_url

@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get)
def test_get_cloud_from_arm(self, mock_arm_pipeline_client_send_request):
    """Setting TEST_ENV (a cloud returned by the mocked send_request) is reflected in the cloud information."""
    cloud_name = "TEST_ENV"
    _set_cloud(cloud_name)
    reported_cloud = _get_cloud_information_from_metadata(cloud_name).get("cloud")
    assert reported_cloud == cloud_name

@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get)
def test_all_endpointurls_used(self, mock_get):
    """Every non-dunder EndpointURLS attribute value must be a key of the TEST_ENV cloud details."""
    cloud_details = _get_cloud_details("TEST_ENV")
    endpoint_names = [a for a in dir(EndpointURLS) if not a.startswith('__')]
    for name in endpoint_names:
        key = getattr(EndpointURLS, name)
        # Assert membership directly rather than catching every exception
        # from a lookup, so unrelated errors are not reported as a missing url.
        assert key in cloud_details, "Url not found: {}".format(key)

@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get)
def test_metadata_registry_endpoint(self, mock_get):
    """The registry discovery endpoint of TEST_ENV2 equals the expected generated url."""
    expected_url = "https://test_env2west.api.azureml.windows.net/"
    details = _get_cloud_details("TEST_ENV2")
    actual_url = details.get(EndpointURLS.REGISTRY_DISCOVERY_ENDPOINT)
    assert actual_url == expected_url

@mock.patch.dict(os.environ, {}, clear=True)
@mock.patch("azure.mgmt.core.ARMPipelineClient.send_request", side_effect=mocked_send_request_get)
def test_arm_misconfigured(self, mock_get):
    """Setting the MISCONFIGURED cloud (name only in the mocked metadata) raises."""
    misconfigured_cloud = "MISCONFIGURED"
    with pytest.raises(Exception):
        _set_cloud(misconfigured_cloud)

0 comments on commit c3b7b30

Please sign in to comment.