diff --git a/cloudinit/sources/DataSourceAzure.py b/cloudinit/sources/DataSourceAzure.py index 2566bd8e9e92..41cd3ab27a65 100755 --- a/cloudinit/sources/DataSourceAzure.py +++ b/cloudinit/sources/DataSourceAzure.py @@ -589,9 +589,23 @@ def crawl_metadata(self): crawled_data["metadata"]["random_seed"] = seed crawled_data["metadata"]["instance-id"] = self._iid() - if pps_type != PPSType.NONE: - LOG.info("Reporting ready to Azure after getting ReprovisionData") - self._report_ready() + if self._negotiated is False and self._is_ephemeral_networking_up(): + # Report ready and fetch public-keys from Wireserver, if required. + pubkey_info = self._determine_wireserver_pubkey_info( + cfg=cfg, imds_md=imds_md + ) + try: + ssh_keys = self._report_ready(pubkey_info=pubkey_info) + except Exception: + # Failed to report ready, but continue with best effort. + pass + else: + LOG.debug("negotiating returned %s", ssh_keys) + if ssh_keys: + crawled_data["metadata"]["public-keys"] = ssh_keys + + self._cleanup_markers() + self._negotiated = True return crawled_data @@ -820,24 +834,6 @@ def _iid(self, previous=None): return previous return iid - @azure_ds_telemetry_reporter - def setup(self, is_new_instance): - if self._negotiated is False: - LOG.debug( - "negotiating for %s (new_instance=%s)", - self.get_instance_id(), - is_new_instance, - ) - ssh_keys = self._negotiate() - LOG.debug("negotiating returned %s", ssh_keys) - if ssh_keys: - self.metadata["public-keys"] = ssh_keys - self._negotiated = True - else: - LOG.debug( - "negotiating already done for %s", self.get_instance_id() - ) - @azure_ds_telemetry_reporter def _wait_for_nic_detach(self, nl_sock): """Use the netlink socket provided to wait for nic detach event. @@ -960,8 +956,9 @@ def _report_ready_for_pps(self) -> None: :raises sources.InvalidMetaDataException: On error reporting ready. """ - report_ready_succeeded = self._report_ready() - if not report_ready_succeeded: + try: + self._report_ready() + except Exception: msg = "Failed reporting ready while in the preprovisioning pool." report_diagnostic_event(msg, logger_func=LOG.error) raise sources.InvalidMetaDataException(msg) @@ -1377,25 +1374,32 @@ def _report_failure(self, description: Optional[str] = None) -> bool: return False - def _report_ready(self) -> bool: + @azure_ds_telemetry_reporter + def _report_ready( + self, *, pubkey_info: Optional[List[str]] = None + ) -> Optional[List[str]]: """Tells the fabric provisioning has completed. - @return: The success status of sending the ready signal. + :param pubkey_info: Fingerprints of keys to request from Wireserver. + + :raises Exception: if failed to report. + + :returns: List of SSH keys, if requested. """ try: - get_metadata_from_fabric( + return get_metadata_from_fabric( fallback_lease_file=None, dhcp_opts=self._wireserver_endpoint, iso_dev=self.iso_dev, + pubkey_info=pubkey_info, ) - return True except Exception as e: report_diagnostic_event( "Error communicating with Azure fabric; You may experience " "connectivity issues: %s" % e, logger_func=LOG.warning, ) - return False + raise def _ppstype_from_imds(self, imds_md: dict) -> Optional[str]: try: @@ -1441,6 +1445,7 @@ def _write_reprovision_marker(self): "{pid}: {time}\n".format(pid=os.getpid(), time=time()), ) + @azure_ds_telemetry_reporter def _reprovision(self): """Initiate the reprovisioning workflow. @@ -1456,40 +1461,29 @@ def _reprovision(self): return (md, ud, cfg, {"ovf-env.xml": contents}) @azure_ds_telemetry_reporter - def _negotiate(self): - """Negotiate with fabric and return data from it. + def _determine_wireserver_pubkey_info( + self, *, cfg: dict, imds_md: dict + ) -> Optional[List[str]]: + """Determine the fingerprints we need to retrieve from Wireserver. - On success, returns a dictionary including 'public_keys'. - On failure, returns False. + :return: List of keys to request from Wireserver, if any, else None. """ - pubkey_info = None + pubkey_info: Optional[List[str]] = None try: - self._get_public_keys_from_imds(self.metadata["imds"]) + self._get_public_keys_from_imds(imds_md) except (KeyError, ValueError): - pubkey_info = self.cfg.get("_pubkeys", None) + pubkey_info = cfg.get("_pubkeys", None) log_msg = "Retrieved {} fingerprints from OVF".format( len(pubkey_info) if pubkey_info is not None else 0 ) report_diagnostic_event(log_msg, logger_func=LOG.debug) + return pubkey_info - LOG.debug("negotiating with fabric") - try: - ssh_keys = get_metadata_from_fabric( - fallback_lease_file=self.dhclient_lease_file, - pubkey_info=pubkey_info, - ) - except Exception as e: - report_diagnostic_event( - "Error communicating with Azure fabric; You may experience " - "connectivity issues: %s" % e, - logger_func=LOG.warning, - ) - return False - + def _cleanup_markers(self): + """Cleanup any marker files.""" util.del_file(REPORTED_READY_MARKER_FILE) util.del_file(REPROVISION_MARKER_FILE) util.del_file(REPROVISION_NIC_DETACHED_MARKER_FILE) - return ssh_keys @azure_ds_telemetry_reporter def activate(self, cfg, is_new_instance): diff --git a/tests/unittests/sources/test_azure.py b/tests/unittests/sources/test_azure.py index bd525b13b673..e850ec9e8797 100644 --- a/tests/unittests/sources/test_azure.py +++ b/tests/unittests/sources/test_azure.py @@ -1225,7 +1225,10 @@ def test_crawl_metadata_on_reprovision_reports_ready_using_lease( dsrc.crawl_metadata() - assert m_report_ready.mock_calls == [mock.call(), mock.call()] + assert m_report_ready.mock_calls == [ + mock.call(), + mock.call(pubkey_info=None), + ] def test_waagent_d_has_0700_perms(self): # we expect /var/lib/waagent to be created 0700 @@ -1603,12 +1606,23 @@ def test_ovf_can_include_unicode(self): def test_dsaz_report_ready_returns_true_when_report_succeeds(self): dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()}) - self.assertTrue(dsrc._report_ready()) + assert dsrc._report_ready() == [] - def test_dsaz_report_ready_returns_false_and_does_not_propagate_exc(self): + @mock.patch(MOCKPATH + "report_diagnostic_event") + def test_dsaz_report_ready_failure_reports_telemetry(self, m_report_diag): dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()}) - self.m_get_metadata_from_fabric.side_effect = Exception - self.assertFalse(dsrc._report_ready()) + self.m_get_metadata_from_fabric.side_effect = Exception("foo") + + with pytest.raises(Exception): + dsrc._report_ready() + + assert m_report_diag.mock_calls == [ + mock.call( + "Error communicating with Azure fabric; " + "You may experience connectivity issues: foo", + logger_func=dsaz.LOG.warning, + ) + ] def test_dsaz_report_failure_returns_true_when_report_succeeds(self): dsrc = self._get_ds({"ovfcontent": construct_valid_ovf_env()}) @@ -3282,7 +3296,7 @@ def test_poll_imds_report_ready_failure_raises_exc_and_doesnt_write_marker( } ] m_media_switch.return_value = None - m_report_ready.return_value = False + m_report_ready.side_effect = [Exception("fail")] dsa = dsaz.DataSourceAzure({}, distro=None, paths=self.paths) self.assertFalse(os.path.exists(report_file)) with mock.patch(MOCKPATH + "REPORTED_READY_MARKER_FILE", report_file):