Skip to content

Commit

Permalink
sources/azure: refactor ssh key handling
Browse files Browse the repository at this point in the history
Split _get_public_ssh_keys_and_source() into
_get_public_keys_from_imds() and _get_public_keys_from_ovf().

Set _get_public_keys_from_imds() to take a parameter of the
IMDS metadata rather than assuming it is already set in
self.metadata.  This will allow us to move negotation into
local phase where self.metadata may not be set yet.  Update this
method to raise KeyError if IMDS metadata is missing/malformed,
and ValueError if SSH key format is not supported.  Update
get_public_ssh_keys() to catch these errors and fall back to the
OVF/Wireserver keys as needed.

To improve clarity, update register_with_azure_and_fetch_data()
to return the list of SSH keys, rather than bundling them into
a dictionary for updating against the metadata dictionary.

There should be no change in behavior with this refactor.

Signed-off-by: Chris Patterson <cpatterson@microsoft.com>
  • Loading branch information
cjp256 committed Feb 7, 2022
1 parent 826783d commit b4cb8fd
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 76 deletions.
115 changes: 51 additions & 64 deletions cloudinit/sources/DataSourceAzure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import os.path
import re
import xml.etree.ElementTree as ET
from collections import namedtuple
from enum import Enum
from functools import partial
from time import sleep, time
from typing import Optional
from typing import List, Optional
from xml.dom import minidom

import requests
Expand Down Expand Up @@ -71,10 +69,6 @@
IMDS_VER_WANT = "2021-08-01"
IMDS_EXTENDED_VER_MIN = "2021-03-01"

# This holds SSH key data including if the source was
# from IMDS, as well as the SSH key data itself.
SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys"))


class MetadataType(Enum):
ALL = "{}/instance".format(IMDS_URL)
Expand Down Expand Up @@ -740,63 +734,59 @@ def device_name_to_device(self, name):
return self.ds_cfg["disk_aliases"].get(name)

@azure_ds_telemetry_reporter
def get_public_ssh_keys(self):
def get_public_ssh_keys(self) -> List[str]:
"""
Retrieve public SSH keys.
"""
try:
return self._get_public_keys_from_imds(self.metadata["imds"])
except (KeyError, ValueError):
pass

return self._get_public_ssh_keys_and_source().ssh_keys
return self._get_public_keys_from_ovf()

def _get_public_ssh_keys_and_source(self):
"""
Try to get the ssh keys from IMDS first, and if that fails
(i.e. IMDS is unavailable) then fallback to getting the ssh
keys from OVF.
def _get_public_keys_from_imds(self, imds_md: dict) -> List[str]:
"""Get SSH keys from IMDS metadata.
The benefit to getting keys from IMDS is a large performance
advantage, so this is a strong preference. But we must keep
OVF as a second option for environments that don't have IMDS.
"""
:raises KeyError: if IMDS metadata is malformed/missing.
:raises ValueError: if key format is not supported.
LOG.debug("Retrieving public SSH keys")
ssh_keys = []
keys_from_imds = True
LOG.debug("Attempting to get SSH keys from IMDS")
:returns: List of keys.
"""
try:
ssh_keys = [
public_key["keyData"]
for public_key in self.metadata["imds"]["compute"][
"publicKeys"
]
for public_key in imds_md["compute"]["publicKeys"]
]
for key in ssh_keys:
if not _key_is_openssh_formatted(key=key):
keys_from_imds = False
break

if not keys_from_imds:
log_msg = "Keys not in OpenSSH format, using OVF"
else:
log_msg = "Retrieved {} keys from IMDS".format(
len(ssh_keys) if ssh_keys is not None else 0
)
except KeyError:
log_msg = "Unable to get keys from IMDS, falling back to OVF"
keys_from_imds = False
finally:
log_msg = "No SSH keys found in IMDS metadata"
report_diagnostic_event(log_msg, logger_func=LOG.debug)
raise ValueError(log_msg)

if not keys_from_imds:
LOG.debug("Attempting to get SSH keys from OVF")
try:
ssh_keys = self.metadata["public-keys"]
log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
except KeyError:
log_msg = "No keys available from OVF"
finally:
report_diagnostic_event(log_msg, logger_func=LOG.debug)
if any(not _key_is_openssh_formatted(key=key) for key in ssh_keys):
log_msg = "Key(s) not in OpenSSH format"
report_diagnostic_event(log_msg, logger_func=LOG.debug)
raise ValueError(log_msg)

return SSHKeys(keys_from_imds=keys_from_imds, ssh_keys=ssh_keys)
log_msg = "Retrieved {} keys from IMDS".format(len(ssh_keys))
report_diagnostic_event(log_msg, logger_func=LOG.debug)
return ssh_keys

def _get_public_keys_from_ovf(self) -> List[str]:
"""Get SSH keys that were fetched from wireserver.
:returns: List of keys.
"""
ssh_keys = []
try:
ssh_keys = self.metadata["public-keys"]
log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
report_diagnostic_event(log_msg, logger_func=LOG.debug)
except KeyError:
log_msg = "No keys available from OVF"
report_diagnostic_event(log_msg, logger_func=LOG.debug)

return ssh_keys

def get_config_obj(self):
return self.cfg
Expand Down Expand Up @@ -832,10 +822,10 @@ def setup(self, is_new_instance):
self.get_instance_id(),
is_new_instance,
)
fabric_data = self._negotiate()
LOG.debug("negotiating returned %s", fabric_data)
if fabric_data:
self.metadata.update(fabric_data)
ssh_keys = self._negotiate()
LOG.debug("negotiating returned %s", ssh_keys)
if ssh_keys:
self.metadata["public-keys"] = ssh_keys
self._negotiated = True
else:
LOG.debug(
Expand Down Expand Up @@ -1462,24 +1452,21 @@ def _negotiate(self):
On failure, returns False.
"""
pubkey_info = None
ssh_keys_and_source = self._get_public_ssh_keys_and_source()

if not ssh_keys_and_source.keys_from_imds:
try:
self._get_public_keys_from_imds(self.metadata["imds"])
except (KeyError, ValueError):
pubkey_info = self.cfg.get("_pubkeys", None)
log_msg = "Retrieved {} fingerprints from OVF".format(
len(pubkey_info) if pubkey_info is not None else 0
)
report_diagnostic_event(log_msg, logger_func=LOG.debug)

metadata_func = partial(
get_metadata_from_fabric,
fallback_lease_file=self.dhclient_lease_file,
pubkey_info=pubkey_info,
)

LOG.debug("negotiating with fabric")
try:
fabric_data = metadata_func()
ssh_keys = get_metadata_from_fabric(
fallback_lease_file=self.dhclient_lease_file,
pubkey_info=pubkey_info,
)
except Exception as e:
report_diagnostic_event(
"Error communicating with Azure fabric; You may experience "
Expand All @@ -1491,7 +1478,7 @@ def _negotiate(self):
util.del_file(REPORTED_READY_MARKER_FILE)
util.del_file(REPROVISION_MARKER_FILE)
util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE)
return fabric_data
return ssh_keys

@azure_ds_telemetry_reporter
def activate(self, cfg, is_new_instance):
Expand Down
5 changes: 3 additions & 2 deletions cloudinit/sources/helpers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from contextlib import contextmanager
from datetime import datetime
from errno import ENOENT
from typing import List, Optional
from xml.etree import ElementTree
from xml.sax.saxutils import escape

Expand Down Expand Up @@ -1005,7 +1006,7 @@ def eject_iso(self, iso_dev) -> None:
@azure_ds_telemetry_reporter
def register_with_azure_and_fetch_data(
self, pubkey_info=None, iso_dev=None
) -> dict:
) -> Optional[List[str]]:
"""Gets the VM's GoalState from Azure, uses the GoalState information
to report ready/send the ready signal/provisioning complete signal to
Azure, and then uses pubkey_info to filter and obtain the user's
Expand Down Expand Up @@ -1038,7 +1039,7 @@ def register_with_azure_and_fetch_data(
self.eject_iso(iso_dev)

health_reporter.send_ready_signal()
return {"public-keys": ssh_keys}
return ssh_keys

@azure_ds_telemetry_reporter
def register_with_azure_and_report_failure(self, description: str) -> None:
Expand Down
10 changes: 4 additions & 6 deletions tests/unittests/sources/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,9 +762,7 @@ def _load_possible_azure_ds(seed_dir, cache_dir):
dsaz.BUILTIN_DS_CONFIG["data_dir"] = self.waagent_d

self.m_is_platform_viable = mock.MagicMock(autospec=True)
self.m_get_metadata_from_fabric = mock.MagicMock(
return_value={"public-keys": []}
)
self.m_get_metadata_from_fabric = mock.MagicMock(return_value=[])
self.m_report_failure_to_fabric = mock.MagicMock(autospec=True)
self.m_list_possible_azure_ds = mock.MagicMock(
side_effect=_load_possible_azure_ds
Expand Down Expand Up @@ -1725,10 +1723,10 @@ def test_exception_fetching_fabric_data_doesnt_propagate(self):

def test_fabric_data_included_in_metadata(self):
dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()})
self.m_get_metadata_from_fabric.return_value = {"test": "value"}
self.m_get_metadata_from_fabric.return_value = ["ssh-key-value"]
ret = self._get_and_setup(dsrc)
self.assertTrue(ret)
self.assertEqual("value", dsrc.metadata["test"])
self.assertEqual(["ssh-key-value"], dsrc.metadata["public-keys"])

def test_instance_id_case_insensitive(self):
"""Return the previous iid when current is a case-insensitive match."""
Expand Down Expand Up @@ -2008,7 +2006,7 @@ def test_get_public_ssh_keys_without_imds(self, m_get_metadata_from_imds):
"sys_cfg": sys_cfg,
}
dsrc = self._get_ds(data)
dsaz.get_metadata_from_fabric.return_value = {"public-keys": ["key2"]}
dsaz.get_metadata_from_fabric.return_value = ["key2"]
dsrc.get_data()
dsrc.setup(True)
ssh_keys = dsrc.get_public_ssh_keys()
Expand Down
8 changes: 4 additions & 4 deletions tests/unittests/sources/test_azure_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,16 +1204,16 @@ def test_certificates_used_to_determine_public_keys(self):
[mock.call(self.GoalState.return_value.certificates_xml)],
sslmgr.parse_certificates.call_args_list,
)
self.assertIn("expected-key", data["public-keys"])
self.assertIn("expected-no-value-key", data["public-keys"])
self.assertNotIn("should-not-be-found", data["public-keys"])
self.assertIn("expected-key", data)
self.assertIn("expected-no-value-key", data)
self.assertNotIn("should-not-be-found", data)

def test_absent_certificates_produces_empty_public_keys(self):
mypk = [{"fingerprint": "fp1", "path": "path1"}]
self.GoalState.return_value.certificates_xml = None
shim = wa_shim()
data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk)
self.assertEqual([], data["public-keys"])
self.assertEqual([], data)

def test_correct_url_used_for_report_ready(self):
self.find_endpoint.return_value = "test_endpoint"
Expand Down

0 comments on commit b4cb8fd

Please sign in to comment.