Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sources/azure: refactor ssh key handling #1248

Merged
merged 1 commit into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 51 additions & 64 deletions cloudinit/sources/DataSourceAzure.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import os.path
import re
import xml.etree.ElementTree as ET
from collections import namedtuple
from enum import Enum
from functools import partial
from time import sleep, time
from typing import Optional
from typing import List, Optional
from xml.dom import minidom

import requests
Expand Down Expand Up @@ -71,10 +69,6 @@
IMDS_VER_WANT = "2021-08-01"
IMDS_EXTENDED_VER_MIN = "2021-03-01"

# This holds SSH key data including if the source was
# from IMDS, as well as the SSH key data itself.
SSHKeys = namedtuple("SSHKeys", ("keys_from_imds", "ssh_keys"))


class MetadataType(Enum):
ALL = "{}/instance".format(IMDS_URL)
Expand Down Expand Up @@ -740,63 +734,59 @@ def device_name_to_device(self, name):
return self.ds_cfg["disk_aliases"].get(name)

@azure_ds_telemetry_reporter
def get_public_ssh_keys(self):
def get_public_ssh_keys(self) -> List[str]:
"""
Retrieve public SSH keys.
"""
try:
return self._get_public_keys_from_imds(self.metadata["imds"])
except (KeyError, ValueError):
pass

return self._get_public_ssh_keys_and_source().ssh_keys
return self._get_public_keys_from_ovf()

def _get_public_ssh_keys_and_source(self):
"""
Try to get the ssh keys from IMDS first, and if that fails
(i.e. IMDS is unavailable) then fallback to getting the ssh
keys from OVF.
def _get_public_keys_from_imds(self, imds_md: dict) -> List[str]:
"""Get SSH keys from IMDS metadata.

The benefit to getting keys from IMDS is a large performance
advantage, so this is a strong preference. But we must keep
OVF as a second option for environments that don't have IMDS.
"""
:raises KeyError: if IMDS metadata is malformed/missing.
cjp256 marked this conversation as resolved.
Show resolved Hide resolved
:raises ValueError: if key format is not supported.

LOG.debug("Retrieving public SSH keys")
ssh_keys = []
keys_from_imds = True
LOG.debug("Attempting to get SSH keys from IMDS")
:returns: List of keys.
"""
try:
ssh_keys = [
public_key["keyData"]
for public_key in self.metadata["imds"]["compute"][
"publicKeys"
]
for public_key in imds_md["compute"]["publicKeys"]
]
for key in ssh_keys:
if not _key_is_openssh_formatted(key=key):
keys_from_imds = False
break

if not keys_from_imds:
log_msg = "Keys not in OpenSSH format, using OVF"
else:
log_msg = "Retrieved {} keys from IMDS".format(
len(ssh_keys) if ssh_keys is not None else 0
)
except KeyError:
log_msg = "Unable to get keys from IMDS, falling back to OVF"
keys_from_imds = False
finally:
log_msg = "No SSH keys found in IMDS metadata"
report_diagnostic_event(log_msg, logger_func=LOG.debug)
raise

if not keys_from_imds:
LOG.debug("Attempting to get SSH keys from OVF")
try:
ssh_keys = self.metadata["public-keys"]
log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
except KeyError:
log_msg = "No keys available from OVF"
finally:
report_diagnostic_event(log_msg, logger_func=LOG.debug)
if any(not _key_is_openssh_formatted(key=key) for key in ssh_keys):
log_msg = "Key(s) not in OpenSSH format"
report_diagnostic_event(log_msg, logger_func=LOG.debug)
raise ValueError(log_msg)

log_msg = "Retrieved {} keys from IMDS".format(len(ssh_keys))
report_diagnostic_event(log_msg, logger_func=LOG.debug)
return ssh_keys

return SSHKeys(keys_from_imds=keys_from_imds, ssh_keys=ssh_keys)
def _get_public_keys_from_ovf(self) -> List[str]:
"""Get SSH keys that were fetched from wireserver.

:returns: List of keys.
"""
ssh_keys = []
try:
ssh_keys = self.metadata["public-keys"]
log_msg = "Retrieved {} keys from OVF".format(len(ssh_keys))
report_diagnostic_event(log_msg, logger_func=LOG.debug)
except KeyError:
log_msg = "No keys available from OVF"
report_diagnostic_event(log_msg, logger_func=LOG.debug)

return ssh_keys

def get_config_obj(self):
return self.cfg
Expand Down Expand Up @@ -832,10 +822,10 @@ def setup(self, is_new_instance):
self.get_instance_id(),
is_new_instance,
)
fabric_data = self._negotiate()
LOG.debug("negotiating returned %s", fabric_data)
if fabric_data:
self.metadata.update(fabric_data)
ssh_keys = self._negotiate()
LOG.debug("negotiating returned %s", ssh_keys)
if ssh_keys:
self.metadata["public-keys"] = ssh_keys
self._negotiated = True
else:
LOG.debug(
Expand Down Expand Up @@ -1462,24 +1452,21 @@ def _negotiate(self):
On failure, returns False.
"""
pubkey_info = None
ssh_keys_and_source = self._get_public_ssh_keys_and_source()

if not ssh_keys_and_source.keys_from_imds:
try:
self._get_public_keys_from_imds(self.metadata["imds"])
except (KeyError, ValueError):
pubkey_info = self.cfg.get("_pubkeys", None)
log_msg = "Retrieved {} fingerprints from OVF".format(
len(pubkey_info) if pubkey_info is not None else 0
)
report_diagnostic_event(log_msg, logger_func=LOG.debug)

metadata_func = partial(
get_metadata_from_fabric,
fallback_lease_file=self.dhclient_lease_file,
pubkey_info=pubkey_info,
)

LOG.debug("negotiating with fabric")
try:
fabric_data = metadata_func()
ssh_keys = get_metadata_from_fabric(
fallback_lease_file=self.dhclient_lease_file,
pubkey_info=pubkey_info,
)
except Exception as e:
report_diagnostic_event(
"Error communicating with Azure fabric; You may experience "
Expand All @@ -1491,7 +1478,7 @@ def _negotiate(self):
util.del_file(REPORTED_READY_MARKER_FILE)
util.del_file(REPROVISION_MARKER_FILE)
util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE)
return fabric_data
return ssh_keys

@azure_ds_telemetry_reporter
def activate(self, cfg, is_new_instance):
Expand Down
5 changes: 3 additions & 2 deletions cloudinit/sources/helpers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from contextlib import contextmanager
from datetime import datetime
from errno import ENOENT
from typing import List, Optional
from xml.etree import ElementTree
from xml.sax.saxutils import escape

Expand Down Expand Up @@ -1005,7 +1006,7 @@ def eject_iso(self, iso_dev) -> None:
@azure_ds_telemetry_reporter
def register_with_azure_and_fetch_data(
self, pubkey_info=None, iso_dev=None
) -> dict:
) -> Optional[List[str]]:
"""Gets the VM's GoalState from Azure, uses the GoalState information
to report ready/send the ready signal/provisioning complete signal to
Azure, and then uses pubkey_info to filter and obtain the user's
Expand Down Expand Up @@ -1038,7 +1039,7 @@ def register_with_azure_and_fetch_data(
self.eject_iso(iso_dev)

health_reporter.send_ready_signal()
return {"public-keys": ssh_keys}
return ssh_keys

@azure_ds_telemetry_reporter
def register_with_azure_and_report_failure(self, description: str) -> None:
Expand Down
10 changes: 4 additions & 6 deletions tests/unittests/sources/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,9 +762,7 @@ def _load_possible_azure_ds(seed_dir, cache_dir):
dsaz.BUILTIN_DS_CONFIG["data_dir"] = self.waagent_d

self.m_is_platform_viable = mock.MagicMock(autospec=True)
self.m_get_metadata_from_fabric = mock.MagicMock(
return_value={"public-keys": []}
)
self.m_get_metadata_from_fabric = mock.MagicMock(return_value=[])
self.m_report_failure_to_fabric = mock.MagicMock(autospec=True)
self.m_list_possible_azure_ds = mock.MagicMock(
side_effect=_load_possible_azure_ds
Expand Down Expand Up @@ -1725,10 +1723,10 @@ def test_exception_fetching_fabric_data_doesnt_propagate(self):

def test_fabric_data_included_in_metadata(self):
dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()})
self.m_get_metadata_from_fabric.return_value = {"test": "value"}
self.m_get_metadata_from_fabric.return_value = ["ssh-key-value"]
ret = self._get_and_setup(dsrc)
self.assertTrue(ret)
self.assertEqual("value", dsrc.metadata["test"])
self.assertEqual(["ssh-key-value"], dsrc.metadata["public-keys"])

def test_instance_id_case_insensitive(self):
"""Return the previous iid when current is a case-insensitive match."""
Expand Down Expand Up @@ -2008,7 +2006,7 @@ def test_get_public_ssh_keys_without_imds(self, m_get_metadata_from_imds):
"sys_cfg": sys_cfg,
}
dsrc = self._get_ds(data)
dsaz.get_metadata_from_fabric.return_value = {"public-keys": ["key2"]}
dsaz.get_metadata_from_fabric.return_value = ["key2"]
dsrc.get_data()
dsrc.setup(True)
ssh_keys = dsrc.get_public_ssh_keys()
Expand Down
8 changes: 4 additions & 4 deletions tests/unittests/sources/test_azure_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,16 +1204,16 @@ def test_certificates_used_to_determine_public_keys(self):
[mock.call(self.GoalState.return_value.certificates_xml)],
sslmgr.parse_certificates.call_args_list,
)
self.assertIn("expected-key", data["public-keys"])
self.assertIn("expected-no-value-key", data["public-keys"])
self.assertNotIn("should-not-be-found", data["public-keys"])
self.assertIn("expected-key", data)
self.assertIn("expected-no-value-key", data)
self.assertNotIn("should-not-be-found", data)

def test_absent_certificates_produces_empty_public_keys(self):
mypk = [{"fingerprint": "fp1", "path": "path1"}]
self.GoalState.return_value.certificates_xml = None
shim = wa_shim()
data = shim.register_with_azure_and_fetch_data(pubkey_info=mypk)
self.assertEqual([], data["public-keys"])
self.assertEqual([], data)

def test_correct_url_used_for_report_ready(self):
self.find_endpoint.return_value = "test_endpoint"
Expand Down