diff --git a/services/addons/images/firmware_manager/download_blob.py b/services/addons/images/firmware_manager/download_blob.py
index 6dda3645..e0131068 100644
--- a/services/addons/images/firmware_manager/download_blob.py
+++ b/services/addons/images/firmware_manager/download_blob.py
@@ -3,8 +3,16 @@
 import os
 
 
-# Download a blob from GCP Bucket Storage
 def download_gcp_blob(blob_name, destination_file_name):
+    """Download a file from a GCP Bucket Storage bucket to a local file.
+
+    Args:
+        blob_name (str): The name of the file in the bucket.
+        destination_file_name (str): The name of the local file to download the bucket file to.
+    """
+
+    validate_file_type(blob_name)
+
     gcp_project = os.environ.get("GCP_PROJECT")
     bucket_name = os.environ.get("BLOB_STORAGE_BUCKET")
     storage_client = storage.Client(gcp_project)
@@ -16,11 +24,34 @@ def download_gcp_blob(blob_name, destination_file_name):
     )
 
 
-# "Download" a blob from a directory mounted as a volume in a Docker container
 def download_docker_blob(blob_name, destination_file_name):
+    """Copy a file from a directory mounted as a volume in a Docker container to a local file.
+
+    Args:
+        blob_name (str): The name of the file in the directory.
+        destination_file_name (str): The name of the local file to copy the directory file to.
+    """
+
+    validate_file_type(blob_name)
+
     directory = "/mnt/blob_storage"
     source_file_name = f"{directory}/{blob_name}"
     os.system(f"cp {source_file_name} {destination_file_name}")
     logging.info(
         f"Copied storage object {blob_name} from directory {directory} to local file {destination_file_name}."
     )
+
+def validate_file_type(file_name):
+    """Validate the file type of the file to be downloaded.
+
+    Args:
+        file_name (str): The name of the file to be downloaded.
+    """
+    if not file_name.endswith(".tar"):
+        logging.error(f"Unsupported file type for storage object {file_name}. Only .tar files are supported.")
+        raise UnsupportedFileTypeException
+
+class UnsupportedFileTypeException(Exception):
+    def __init__(self, message="Unsupported file type. Only .tar files are supported."):
+        self.message = message
+        super().__init__(self.message)
\ No newline at end of file
diff --git a/services/addons/images/firmware_manager/sample.env b/services/addons/images/firmware_manager/sample.env
index 43f286bb..dd1e046d 100644
--- a/services/addons/images/firmware_manager/sample.env
+++ b/services/addons/images/firmware_manager/sample.env
@@ -6,12 +6,12 @@ PG_DB_NAME=""
 PG_DB_USER=""
 PG_DB_PASS=""
 
-# Blob storage variables
+# Blob storage variables (only 'GCP' and 'DOCKER' are supported at this time)
 BLOB_STORAGE_PROVIDER=DOCKER
 ## GCP Project and Bucket for BLOB storage (if using GCP)
 GCP_PROJECT=
 BLOB_STORAGE_BUCKET=
-## Docker volume mount point for BLOB storage (if using Docker)
+## Docker volume mount point for BLOB storage (if using DOCKER)
 HOST_BLOB_STORAGE_DIRECTORY=./local_blob_storage
 
 # For users using GCP cloud storage
diff --git a/services/addons/images/firmware_manager/upgrader.py b/services/addons/images/firmware_manager/upgrader.py
index b93e4078..8e0d36a3 100644
--- a/services/addons/images/firmware_manager/upgrader.py
+++ b/services/addons/images/firmware_manager/upgrader.py
@@ -40,6 +40,7 @@ def download_blob(self):
             download_blob.download_docker_blob(self.blob_name, self.local_file_name)
         else:
             logging.error("Unsupported blob storage provider")
+            raise StorageProviderNotSupportedException
 
     # Notifies the firmware manager of the completion status for the upgrade
     # success is a boolean
@@ -60,3 +61,7 @@ def notify_firmware_manager(self, success):
     @abc.abstractclassmethod
     def upgrade(self):
         pass
+
+class StorageProviderNotSupportedException(Exception):
+    def __init__(self):
+        super().__init__("Unsupported blob storage provider")
\ No newline at end of file
diff --git a/services/addons/images/iss_health_check/iss_health_checker.py b/services/addons/images/iss_health_check/iss_health_checker.py
index 9ffffb8d..ec2d5a3d 100644
--- a/services/addons/images/iss_health_check/iss_health_checker.py
+++ b/services/addons/images/iss_health_check/iss_health_checker.py
@@ -4,9 +4,42 @@
 import os
 import iss_token
 import common.pgquery as pgquery
+from dataclasses import dataclass, field
+from typing import Dict
 
 
-def get_rsu_data():
+# Set up logging
+logger = logging.getLogger(__name__)
+
+@dataclass
+class RsuDataWrapper:
+    rsu_data: Dict[str, Dict[str, str]] = field(default_factory=dict)
+
+    def __init__(self, rsu_data):
+        self.rsu_data = rsu_data
+
+    def get_dict(self):
+        return self.rsu_data
+
+    def set_provisioner_company(self, scms_id, provisioner_company):
+        self.rsu_data[scms_id]["provisionerCompany"] = provisioner_company
+
+    def set_entity_type(self, scms_id, entity_type):
+        self.rsu_data[scms_id]["entityType"] = entity_type
+
+    def set_project_id(self, scms_id, project_id):
+        self.rsu_data[scms_id]["project_id"] = project_id
+
+    def set_device_health(self, scms_id, device_health):
+        self.rsu_data[scms_id]["deviceHealth"] = device_health
+
+    def set_expiration(self, scms_id, expiration):
+        self.rsu_data[scms_id]["expiration"] = expiration
+
+
+def get_rsu_data() -> RsuDataWrapper:
+    """Get RSU data from PostgreSQL and return it in a wrapper object"""
+
     result = {}
     query = (
         "SELECT jsonb_build_object('rsu_id', rsu_id, 'iss_scms_id', iss_scms_id) "
@@ -16,15 +49,17 @@
     )
 
     data = pgquery.query_db(query)
-    logging.debug("Parsing results...")
+    logger.debug("Parsing results...")
     for point in data:
         point_dict = dict(point[0])
         result[point_dict["iss_scms_id"]] = {"rsu_id": point_dict["rsu_id"]}
 
-    return result
+    return RsuDataWrapper(result)
 
 
 def get_scms_status_data():
+    """Get SCMS status data from ISS and return it as a dictionary"""
+
     rsu_data = get_rsu_data()
 
     # Create GET request headers
@@ -43,7 +78,7 @@
         iss_request = iss_base + "?pageSize={}&page={}&project_id={}".format(
             page_size, page, project_id
         )
-        logging.debug("GET: " + iss_request)
+        logger.debug("GET: " + iss_request)
         response = requests.get(iss_request, headers=iss_headers)
 
         enrollment_list = response.json()["data"]
@@ -52,50 +87,36 @@
 
         # Loop through each device on current page
         for enrollment_status in enrollment_list:
-            if enrollment_status["_id"] in rsu_data:
-                rsu_data[enrollment_status["_id"]][
-                    "provisionerCompany"
-                ] = enrollment_status["provisionerCompany_id"]
-                rsu_data[enrollment_status["_id"]]["entityType"] = enrollment_status[
-                    "entityType"
-                ]
-                rsu_data[enrollment_status["_id"]]["project_id"] = enrollment_status[
-                    "project_id"
-                ]
-                rsu_data[enrollment_status["_id"]]["deviceHealth"] = enrollment_status[
-                    "deviceHealth"
-                ]
+            es_id = enrollment_status["_id"]
+            if es_id in rsu_data.get_dict():
+                rsu_data.set_provisioner_company(es_id, enrollment_status["provisionerCompany_id"])
+                rsu_data.set_entity_type(es_id, enrollment_status["entityType"])
+                rsu_data.set_project_id(es_id, enrollment_status["project_id"])
+                rsu_data.set_device_health(es_id, enrollment_status["deviceHealth"])
 
                 # If the device has yet to download its first set of certs, set the expiration time to when it was enrolled
                 if "authorizationCertInfo" in enrollment_status["enrollments"][0]:
-                    rsu_data[enrollment_status["_id"]][
-                        "expiration"
-                    ] = enrollment_status["enrollments"][0]["authorizationCertInfo"][
-                        "expireTimeOfLatestDownloadedCert"
-                    ]
+                    rsu_data.set_expiration(es_id, enrollment_status["enrollments"][0]["authorizationCertInfo"]["expireTimeOfLatestDownloadedCert"])
                 else:
-                    rsu_data[enrollment_status["_id"]]["expiration"] = None
+                    rsu_data.set_expiration(es_id, None)
 
                 messages_processed = messages_processed + 1
 
         page = page + 1
 
-    logging.info("Processed {} messages".format(messages_processed))
-    return rsu_data
+    logger.info("Processed {} messages".format(messages_processed))
+    return rsu_data.get_dict()
 
 
 def insert_scms_data(data):
-    logging.info("Inserting SCMS data into PostgreSQL...")
+    logger.info("Inserting SCMS data into PostgreSQL...")
     now_ts = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S.000Z")
 
     query = (
         'INSERT INTO public.scms_health("timestamp", health, expiration, rsu_id) VALUES'
     )
     for value in data.values():
-        try:
-            value["deviceHealth"]
-        except KeyError:
-            logging.warning("deviceHealth not found in data for RSU with id {}, is it real data?".format(value["rsu_id"]))
+        if validate_scms_data(value) is False:
             continue
 
         health = "1" if value["deviceHealth"] == "Healthy" else "0"
@@ -107,11 +128,39 @@
         else:
             query = query + f" ('{now_ts}', '{health}', NULL, {value['rsu_id']}),"
 
-    pgquery.write_db(query[:-1])
-    logging.info(
+    query = query[:-1]  # remove comma
+    pgquery.write_db(query)
+    logger.info(
         "SCMS data inserted {} messages into PostgreSQL...".format(len(data.values()))
     )
+
+def validate_scms_data(value):
+    """Validate the SCMS data
+
+    Args:
+        value (dict): SCMS data
+    """
+
+    try:
+        value["rsu_id"]
+    except KeyError as e:
+        logger.warning("rsu_id not found in data, is it real data? exception: {}".format(e))
+        return False
+
+    try:
+        value["deviceHealth"]
+    except KeyError as e:
+        logger.warning("deviceHealth not found in data for RSU with id {}, is it real data? exception: {}".format(value["rsu_id"], e))
+        return False
+
+    try:
+        value["expiration"]
+    except KeyError as e:
+        logger.warning("expiration not found in data for RSU with id {}, is it real data? exception: {}".format(value["rsu_id"], e))
+        return False
+
+    return True
 
 
 if __name__ == "__main__":
     # Configure logging based on ENV var or use default if not set
@@ -121,4 +170,4 @@
     logging.basicConfig(format="%(levelname)s:%(message)s", level=log_level)
 
     scms_statuses = get_scms_status_data()
-    insert_scms_data(scms_statuses)
+    insert_scms_data(scms_statuses)
\ No newline at end of file
diff --git a/services/addons/images/iss_health_check/iss_token.py b/services/addons/images/iss_health_check/iss_token.py
index fb33829a..37cf9536 100644
--- a/services/addons/images/iss_health_check/iss_token.py
+++ b/services/addons/images/iss_health_check/iss_token.py
@@ -6,6 +6,10 @@
 import uuid
 import logging
 
+
+# Set up logging
+logger = logging.getLogger(__name__)
+
 # Get storage type from environment variable
 def get_storage_type():
     """Get the storage type for the ISS SCMS API token
@@ -13,7 +17,7 @@
     try :
         os.environ["STORAGE_TYPE"]
     except KeyError:
-        logging.error("STORAGE_TYPE environment variable not set, exiting")
+        logger.error("STORAGE_TYPE environment variable not set, exiting")
         exit(1)
 
     storageTypeCaseInsensitive = os.environ["STORAGE_TYPE"].casefold()
@@ -22,7 +26,7 @@
     elif storageTypeCaseInsensitive == "postgres":
         return "postgres"
     else:
-        logging.error("STORAGE_TYPE environment variable not set to a valid value, exiting")
+        logger.error("STORAGE_TYPE environment variable not set to a valid value, exiting")
         exit(1)
 
 
@@ -40,7 +44,7 @@ def create_secret(client, secret_id, parent):
             "secret": {"replication": {"automatic": {}}},
         }
     )
-    logging.debug("New secret created")
+    logger.debug("New secret created")
 
 
 def check_if_secret_exists(client, secret_id, parent):
@@ -54,7 +58,7 @@ def check_if_secret_exists(client, secret_id, parent):
     ):
         # secret names are in the form of "projects/project_id/secrets/secret_id"
         if secret.name.split("/")[-1] == secret_id:
-            logging.debug(f"Secret {secret_id} exists")
+            logger.debug(f"Secret {secret_id} exists")
             return True
     return False
 
@@ -84,7 +88,7 @@ def add_secret_version(client, secret_id, parent, data):
             "payload": {"data": str.encode(json.dumps(data))},
         }
     )
-    logging.debug("New version added")
+    logger.debug("New version added")
 
 
 # Postgres functions
@@ -116,7 +120,7 @@ def get_latest_data(table_name):
     toReturn["id"] = data[0][0]  # id
     toReturn["name"] = data[0][1]  # common_name
     toReturn["token"] = data[0][2]  # token
-    logging.debug(f"Received token: {toReturn['name']} with id {toReturn['id']}")
+    logger.debug(f"Received token: {toReturn['name']} with id {toReturn['id']}")
     return toReturn
 
 
@@ -147,10 +151,10 @@ def get_token():
         value = get_latest_secret_version(client, secret_id, parent)
         friendly_name = value["name"]
         token = value["token"]
-        logging.debug(f"Received token: {friendly_name}")
+        logger.debug(f"Received token: {friendly_name}")
     else:
         # If there is no available ISS token secret, create secret
-        logging.debug("Secret does not exist, creating secret")
+        logger.debug("Secret does not exist, creating secret")
         create_secret(client, secret_id, parent)
         # Use environment variable for first run with new secret
        token = os.environ["ISS_API_KEY"]
@@ -166,7 +170,7 @@
             id = value["id"]
             friendly_name = value["name"]
             token = value["token"]
-            logging.debug(f"Received token: {friendly_name} with id {id}")
+            logger.debug(f"Received token: {friendly_name} with id {id}")
         else:
             # if there is no data, use environment variable for first run
             token = os.environ["ISS_API_KEY"]
@@ -182,20 +186,20 @@
     iss_post_body = {"friendlyName": new_friendly_name, "expireDays": 1}
 
     # Create new ISS SCMS API Token to ensure its freshness
-    logging.debug("POST: " + iss_base)
+    logger.debug("POST: " + iss_base)
     response = requests.post(iss_base, json=iss_post_body, headers=iss_headers)
 
     try:
         new_token = response.json()["Item"]
     except requests.JSONDecodeError:
-        logging.error("Failed to decode JSON response from ISS SCMS API. Response: " + response.text)
+        logger.error("Failed to decode JSON response from ISS SCMS API. Response: " + response.text)
         exit(1)
-    logging.debug(f"Received new token: {new_friendly_name}")
+    logger.debug(f"Received new token: {new_friendly_name}")
 
     if data_exists:
         # If exists, delete previous API key to prevent key clutter
         iss_delete_body = {"friendlyName": friendly_name}
         requests.delete(iss_base, json=iss_delete_body, headers=iss_headers)
-        logging.debug(f"Old token has been deleted from ISS SCMS: {friendly_name}")
+        logger.debug(f"Old token has been deleted from ISS SCMS: {friendly_name}")
 
     version_data = {"name": new_friendly_name, "token": new_token}
diff --git a/services/addons/tests/firmware_manager/test_download_blob.py b/services/addons/tests/firmware_manager/test_download_blob.py
index 224b4299..c0c9cc9b 100644
--- a/services/addons/tests/firmware_manager/test_download_blob.py
+++ b/services/addons/tests/firmware_manager/test_download_blob.py
@@ -1,7 +1,9 @@
 from unittest.mock import MagicMock, patch
 import os
+import pytest
 
 from addons.images.firmware_manager import download_blob
+from addons.images.firmware_manager.download_blob import UnsupportedFileTypeException
 
 
 @patch.dict(
@@ -17,24 +19,42 @@ def test_download_gcp_blob(mock_storage_client, mock_logging):
 
     # run
     download_blob.download_gcp_blob(
-        blob_name="test.blob", destination_file_name="/home/test/"
+        blob_name="test.tar", destination_file_name="/home/test/"
     )
 
     # validate
     mock_storage_client.assert_called_with("test-project")
     mock_client.get_bucket.assert_called_with("test-bucket")
-    mock_bucket.blob.assert_called_with("test.blob")
+    mock_bucket.blob.assert_called_with("test.tar")
     mock_blob.download_to_filename.assert_called_with("/home/test/")
     mock_logging.info.assert_called_with(
-        "Downloaded storage object test.blob from bucket test-bucket to local file /home/test/."
+        "Downloaded storage object test.tar from bucket test-bucket to local file /home/test/."
     )
 
 
+@patch.dict(
+    os.environ, {"GCP_PROJECT": "test-project", "BLOB_STORAGE_BUCKET": "test-bucket"}
+)
+@patch("addons.images.firmware_manager.download_blob.logging")
+def test_download_gcp_blob_unsupported_file_type(mock_logging):
+    # prepare
+    blob_name = "test.blob"
+    destination_file_name = "/home/test/"
+
+    # run
+    with pytest.raises(UnsupportedFileTypeException):
+        download_blob.download_gcp_blob(blob_name, destination_file_name)
+
+    # validate
+    mock_logging.error.assert_called_with(
+        f"Unsupported file type for storage object {blob_name}. Only .tar files are supported."
+    )
+
 @patch("addons.images.firmware_manager.download_blob.logging")
 def test_download_docker_blob(mock_logging):
     # prepare
     os.system = MagicMock()
-    blob_name = "test.blob"
+    blob_name = "test.tar"
     destination_file_name = "/home/test/"
 
     # run
@@ -46,3 +66,20 @@
         f"Copied storage object {blob_name} from directory /mnt/blob_storage to local file {destination_file_name}."
     )
 
+
+@patch("addons.images.firmware_manager.download_blob.logging")
+def test_download_docker_blob_unsupported_file_type(mock_logging):
+    # prepare
+    os.system = MagicMock()
+    blob_name = "test.blob"
+    destination_file_name = "/home/test/"
+
+    # run
+    with pytest.raises(UnsupportedFileTypeException):
+        download_blob.download_docker_blob(blob_name, destination_file_name)
+
+    # validate
+    os.system.assert_not_called()
+    mock_logging.error.assert_called_with(
+        f"Unsupported file type for storage object {blob_name}. Only .tar files are supported."
+    )
\ No newline at end of file
diff --git a/services/addons/tests/firmware_manager/test_upgrader.py b/services/addons/tests/firmware_manager/test_upgrader.py
index 68f4df3a..75321355 100644
--- a/services/addons/tests/firmware_manager/test_upgrader.py
+++ b/services/addons/tests/firmware_manager/test_upgrader.py
@@ -1,7 +1,9 @@
 from unittest.mock import patch
 import os
+import pytest
 
 from addons.images.firmware_manager import upgrader
+from addons.images.firmware_manager.upgrader import StorageProviderNotSupportedException
 
 
 # Test class for testing the abstract class
@@ -109,11 +111,12 @@ def test_download_blob_not_supported(mock_Path, mock_download_gcp_blob, mock_logging):
     mock_path_obj = mock_Path.return_value
     test_upgrader = TestUpgrader(test_upgrade_info)
 
-    test_upgrader.download_blob()
+    with pytest.raises(StorageProviderNotSupportedException):
+        test_upgrader.download_blob()
 
     mock_path_obj.mkdir.assert_called_with(exist_ok=True)
     mock_download_gcp_blob.assert_not_called()
     mock_logging.error.assert_called_with("Unsupported blob storage provider")
 
 
 @patch("addons.images.firmware_manager.upgrader.logging")
diff --git a/services/addons/tests/iss_health_check/test_iss_health_checker.py b/services/addons/tests/iss_health_check/test_iss_health_checker.py
index ea4a4b38..0e2069d3 100644
--- a/services/addons/tests/iss_health_check/test_iss_health_checker.py
+++ b/services/addons/tests/iss_health_check/test_iss_health_checker.py
@@ -2,6 +2,7 @@
 import os
 
 from addons.images.iss_health_check import iss_health_checker
+from addons.images.iss_health_check.iss_health_checker import RsuDataWrapper
 
 
 @patch("addons.images.iss_health_check.iss_health_checker.pgquery.query_db")
@@ -10,7 +11,8 @@ def test_get_rsu_data_no_data(mock_query_db):
     result = iss_health_checker.get_rsu_data()
 
     # check
-    assert result == {}
+    expected = RsuDataWrapper({})
+    assert result == expected
     mock_query_db.assert_called_once()
     mock_query_db.assert_called_with(
         "SELECT jsonb_build_object('rsu_id', rsu_id, 'iss_scms_id', iss_scms_id) FROM public.rsus WHERE iss_scms_id IS NOT NULL ORDER BY rsu_id"
@@ -27,7 +29,7 @@ def test_get_rsu_data_with_data(mock_query_db):
     ]
 
     result = iss_health_checker.get_rsu_data()
-    expected_result = {"ABC": {"rsu_id": 1}, "DEF": {"rsu_id": 2}, "GHI": {"rsu_id": 3}}
+    expected_result = RsuDataWrapper({"ABC": {"rsu_id": 1}, "DEF": {"rsu_id": 2}, "GHI": {"rsu_id": 3}})
 
     # check
     assert result == expected_result
@@ -52,7 +54,7 @@
 def test_get_scms_status_data(
     mock_get_rsu_data, mock_get_token, mock_requests, mock_response
 ):
-    mock_get_rsu_data.return_value = {"ABC": {"rsu_id": 1}, "DEF": {"rsu_id": 2}}
+    mock_get_rsu_data.return_value = RsuDataWrapper({"ABC": {"rsu_id": 1}, "DEF": {"rsu_id": 2}})
     mock_get_token.get_token.return_value = "test-token"
     mock_requests.get.return_value = mock_response
     mock_response.json.side_effect = [
@@ -141,3 +143,66 @@ def test_insert_scms_data(mock_write_db, mock_datetime):
         "('2022-11-03T00:00:00.000Z', '0', NULL, 2)"
     )
     mock_write_db.assert_called_with(expectedQuery)
+
+
+@patch("addons.images.iss_health_check.iss_health_checker.datetime")
+@patch("addons.images.iss_health_check.iss_health_checker.pgquery.write_db")
+def test_insert_scms_data_no_rsu_id(mock_write_db, mock_datetime):
+    mock_datetime.strftime.return_value = "2022-11-03T00:00:00.000Z"
+    test_data = {
+        "ABC": {
+            "deviceHealth": "Healthy",
+            "expiration": "2022-11-02T00:00:00.000Z",
+        },
+        "DEF": {"rsu_id": 2, "deviceHealth": "Unhealthy", "expiration": None},
+    }
+    # call
+    iss_health_checker.insert_scms_data(test_data)
+
+    expectedQuery = (
+        'INSERT INTO public.scms_health("timestamp", health, expiration, rsu_id) VALUES '
+        "('2022-11-03T00:00:00.000Z', '0', NULL, 2)"
+    )
+    mock_write_db.assert_called_with(expectedQuery)
+
+
+@patch("addons.images.iss_health_check.iss_health_checker.datetime")
+@patch("addons.images.iss_health_check.iss_health_checker.pgquery.write_db")
+def test_insert_scms_data_no_deviceHealth(mock_write_db, mock_datetime):
+    mock_datetime.strftime.return_value = "2022-11-03T00:00:00.000Z"
+    test_data = {
+        "ABC": {
+            "rsu_id": 1,
+            "expiration": "2022-11-02T00:00:00.000Z",
+        },
+        "DEF": {"rsu_id": 2, "deviceHealth": "Unhealthy", "expiration": None},
+    }
+    # call
+    iss_health_checker.insert_scms_data(test_data)
+
+    expectedQuery = (
+        'INSERT INTO public.scms_health("timestamp", health, expiration, rsu_id) VALUES '
+        "('2022-11-03T00:00:00.000Z', '0', NULL, 2)"
+    )
+    mock_write_db.assert_called_with(expectedQuery)
+
+
+@patch("addons.images.iss_health_check.iss_health_checker.datetime")
+@patch("addons.images.iss_health_check.iss_health_checker.pgquery.write_db")
+def test_insert_scms_data_no_expiration(mock_write_db, mock_datetime):
+    mock_datetime.strftime.return_value = "2022-11-03T00:00:00.000Z"
+    test_data = {
+        "ABC": {
+            "rsu_id": 1,
+            "deviceHealth": "Healthy",
+        },
+        "DEF": {"rsu_id": 2, "deviceHealth": "Unhealthy", "expiration": "test"},
+    }
+    # call
+    iss_health_checker.insert_scms_data(test_data)
+
+    expectedQuery = (
+        'INSERT INTO public.scms_health("timestamp", health, expiration, rsu_id) VALUES '
+        "('2022-11-03T00:00:00.000Z', '0', 'test', 2)"
+    )
+    mock_write_db.assert_called_with(expectedQuery)
\ No newline at end of file
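
Reviewer sketch (not part of the patch): a minimal illustration of how the new .tar validation surfaces to callers, assuming the module paths used above; the file names here are made up for the example.

```python
from addons.images.firmware_manager import download_blob
from addons.images.firmware_manager.download_blob import UnsupportedFileTypeException

# validate_file_type() runs before any bucket access or file copy, so a bad
# extension fails fast instead of leaving a partial download behind.
try:
    download_blob.download_docker_blob("firmware.blob", "/tmp/firmware.blob")
except UnsupportedFileTypeException as err:
    print(err)  # "Unsupported file type. Only .tar files are supported."
```

upgrader.download_blob() follows the same fail-fast pattern for an unrecognized BLOB_STORAGE_PROVIDER, raising StorageProviderNotSupportedException after logging, which is what the updated test_download_blob_not_supported asserts with pytest.raises.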