Addressing USDOT comments regarding ISS Health Check & Firmware Manager #10

Merged · 9 commits · Apr 26, 2024
35 changes: 33 additions & 2 deletions services/addons/images/firmware_manager/download_blob.py
@@ -3,8 +3,16 @@
import os


# Download a blob from GCP Bucket Storage
def download_gcp_blob(blob_name, destination_file_name):
"""Download a file from a GCP Bucket Storage bucket to a local file.

Args:
blob_name (str): The name of the file in the bucket.
destination_file_name (str): The name of the local file to download the bucket file to.
"""

validate_file_type(blob_name)

gcp_project = os.environ.get("GCP_PROJECT")
bucket_name = os.environ.get("BLOB_STORAGE_BUCKET")
storage_client = storage.Client(gcp_project)
@@ -16,11 +24,34 @@ def download_gcp_blob(blob_name, destination_file_name):
)


# "Download" a blob from a directory mounted as a volume in a Docker container
def download_docker_blob(blob_name, destination_file_name):
"""Copy a file from a directory mounted as a volume in a Docker container to a local file.

Args:
blob_name (str): The name of the file in the directory.
destination_file_name (str): The name of the local file to copy the directory file to.
"""

validate_file_type(blob_name)

directory = "/mnt/blob_storage"
source_file_name = f"{directory}/{blob_name}"
os.system(f"cp {source_file_name} {destination_file_name}")
logging.info(
f"Copied storage object {blob_name} from directory {directory} to local file {destination_file_name}."
)

def validate_file_type(file_name):
"""Validate the file type of the file to be downloaded.

Args:
file_name (str): The name of the file to be downloaded.
"""
if not file_name.endswith(".tar"):
logging.error(f"Unsupported file type for storage object {file_name}. Only .tar files are supported.")
raise UnsupportedFileTypeException

class UnsupportedFileTypeException(Exception):
def __init__(self, message="Unsupported file type. Only .tar files are supported."):
self.message = message
super().__init__(self.message)
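
A minimal usage sketch for the new validation path, assuming the helpers are imported from this module; the blob name and destination path below are hypothetical:

    import logging

    from download_blob import UnsupportedFileTypeException, download_docker_blob

    logging.basicConfig(level=logging.INFO)

    try:
        # Hypothetical file names; validate_file_type() accepts only .tar archives
        download_docker_blob("rsu_firmware.tar", "/tmp/rsu_firmware.tar")
    except UnsupportedFileTypeException as err:
        logging.error(f"Refusing to copy blob: {err}")
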
4 changes: 2 additions & 2 deletions services/addons/images/firmware_manager/sample.env
@@ -6,12 +6,12 @@ PG_DB_NAME=""
PG_DB_USER=""
PG_DB_PASS=""

# Blob storage variables
# Blob storage variables (only 'GCP' and 'DOCKER' are supported at this time)
BLOB_STORAGE_PROVIDER=DOCKER
## GCP Project and Bucket for BLOB storage (if using GCP)
GCP_PROJECT=
BLOB_STORAGE_BUCKET=
## Docker volume mount point for BLOB storage (if using Docker)
## Docker volume mount point for BLOB storage (if using DOCKER)
HOST_BLOB_STORAGE_DIRECTORY=./local_blob_storage

# For users using GCP cloud storage
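
A rough sketch of how these settings are consumed at runtime (simplified; the actual lookups live in download_blob.py and upgrader.py):

    import os

    # BLOB_STORAGE_PROVIDER selects the backend; 'GCP' and 'DOCKER' are the supported values
    provider = os.environ.get("BLOB_STORAGE_PROVIDER", "DOCKER")
    if provider == "GCP":
        project = os.environ.get("GCP_PROJECT")          # GCP project hosting the bucket
        bucket = os.environ.get("BLOB_STORAGE_BUCKET")   # bucket holding firmware blobs
    elif provider == "DOCKER":
        # Host directory mounted into the container at /mnt/blob_storage
        mount = os.environ.get("HOST_BLOB_STORAGE_DIRECTORY", "./local_blob_storage")
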
5 changes: 5 additions & 0 deletions services/addons/images/firmware_manager/upgrader.py
@@ -40,6 +40,7 @@ def download_blob(self):
download_blob.download_docker_blob(self.blob_name, self.local_file_name)
else:
logging.error("Unsupported blob storage provider")
raise StorageProviderNotSupportedException

# Notifies the firmware manager of the completion status for the upgrade
# success is a boolean
@@ -60,3 +61,7 @@ def notify_firmware_manager(self, success):
@abc.abstractclassmethod
def upgrade(self):
pass

class StorageProviderNotSupportedException(Exception):
def __init__(self):
super().__init__("Unsupported blob storage provider")
115 changes: 82 additions & 33 deletions services/addons/images/iss_health_check/iss_health_checker.py
@@ -4,9 +4,42 @@
import os
import iss_token
import common.pgquery as pgquery
from dataclasses import dataclass, field
from typing import Dict


def get_rsu_data():
# Set up logging
logger = logging.getLogger(__name__)

@dataclass
class RsuDataWrapper:
rsu_data: Dict[str, Dict[str, str]] = field(default_factory=dict)

def __init__(self, rsu_data):
self.rsu_data = rsu_data

def get_dict(self):
return self.rsu_data

def set_provisioner_company(self, scms_id, provisioner_company):
self.rsu_data[scms_id]["provisionerCompany"] = provisioner_company

def set_entity_type(self, scms_id, entity_type):
self.rsu_data[scms_id]["entityType"] = entity_type

def set_project_id(self, scms_id, project_id):
self.rsu_data[scms_id]["project_id"] = project_id

def set_device_health(self, scms_id, device_health):
self.rsu_data[scms_id]["deviceHealth"] = device_health

def set_expiration(self, scms_id, expiration):
self.rsu_data[scms_id]["expiration"] = expiration


def get_rsu_data() -> RsuDataWrapper:
"""Get RSU data from PostgreSQL and return it in a wrapper object"""

result = {}
query = (
"SELECT jsonb_build_object('rsu_id', rsu_id, 'iss_scms_id', iss_scms_id) "
@@ -16,15 +49,17 @@ def get_rsu_data():
)
data = pgquery.query_db(query)

logging.debug("Parsing results...")
logger.debug("Parsing results...")
for point in data:
point_dict = dict(point[0])
result[point_dict["iss_scms_id"]] = {"rsu_id": point_dict["rsu_id"]}

return result
return RsuDataWrapper(result)


def get_scms_status_data():
"""Get SCMS status data from ISS and return it as a dictionary"""

rsu_data = get_rsu_data()

# Create GET request headers
@@ -43,7 +78,7 @@
iss_request = iss_base + "?pageSize={}&page={}&project_id={}".format(
page_size, page, project_id
)
logging.debug("GET: " + iss_request)
logger.debug("GET: " + iss_request)
response = requests.get(iss_request, headers=iss_headers)
enrollment_list = response.json()["data"]

@@ -52,50 +87,36 @@

# Loop through each device on current page
for enrollment_status in enrollment_list:
if enrollment_status["_id"] in rsu_data:
rsu_data[enrollment_status["_id"]][
"provisionerCompany"
] = enrollment_status["provisionerCompany_id"]
rsu_data[enrollment_status["_id"]]["entityType"] = enrollment_status[
"entityType"
]
rsu_data[enrollment_status["_id"]]["project_id"] = enrollment_status[
"project_id"
]
rsu_data[enrollment_status["_id"]]["deviceHealth"] = enrollment_status[
"deviceHealth"
]
es_id = enrollment_status["_id"]
if es_id in rsu_data.get_dict():
rsu_data.set_provisioner_company(es_id, enrollment_status["provisionerCompany_id"])
rsu_data.set_entity_type(es_id, enrollment_status["entityType"])
rsu_data.set_project_id(es_id, enrollment_status["project_id"])
rsu_data.set_device_health(es_id, enrollment_status["deviceHealth"])

# If the device has yet to download its first set of certs, set the expiration time to when it was enrolled
if "authorizationCertInfo" in enrollment_status["enrollments"][0]:
rsu_data[enrollment_status["_id"]][
"expiration"
] = enrollment_status["enrollments"][0]["authorizationCertInfo"][
"expireTimeOfLatestDownloadedCert"
]
rsu_data.set_expiration(es_id, enrollment_status["enrollments"][0]["authorizationCertInfo"]["expireTimeOfLatestDownloadedCert"])
else:
rsu_data[enrollment_status["_id"]]["expiration"] = None
rsu_data.set_expiration(es_id, None)

messages_processed = messages_processed + 1

page = page + 1

logging.info("Processed {} messages".format(messages_processed))
return rsu_data
logger.info("Processed {} messages".format(messages_processed))
return rsu_data.get_dict()


def insert_scms_data(data):
logging.info("Inserting SCMS data into PostgreSQL...")
logger.info("Inserting SCMS data into PostgreSQL...")
now_ts = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S.000Z")

query = (
'INSERT INTO public.scms_health("timestamp", health, expiration, rsu_id) VALUES'
)
for value in data.values():
try:
value["deviceHealth"]
except KeyError:
logging.warning("deviceHealth not found in data for RSU with id {}, is it real data?".format(value["rsu_id"]))
if validate_scms_data(value) is False:
continue

health = "1" if value["deviceHealth"] == "Healthy" else "0"
@@ -107,11 +128,39 @@ def insert_scms_data(data):
else:
query = query + f" ('{now_ts}', '{health}', NULL, {value['rsu_id']}),"

pgquery.write_db(query[:-1])
logging.info(
query = query[:-1] # remove comma
pgquery.write_db(query)
logger.info(
"SCMS data inserted {} messages into PostgreSQL...".format(len(data.values()))
)

def validate_scms_data(value):
"""Validate the SCMS data

Args:
value (dict): SCMS data
"""

try:
value["rsu_id"]
except KeyError as e:
logger.warning("rsu_id not found in data, is it real data? exception: {}".format(e))
return False

try:
value["deviceHealth"]
except KeyError as e:
logger.warning("deviceHealth not found in data for RSU with id {}, is it real data? exception: {}".format(value["rsu_id"], e))
return False

try:
value["expiration"]
except KeyError as e:
logger.warning("expiration not found in data for RSU with id {}, is it real data? exception: {}".format(value["rsu_id"], e))
return False

return True


if __name__ == "__main__":
# Configure logging based on ENV var or use default if not set
@@ -121,4 +170,4 @@ def insert_scms_data(data):
logging.basicConfig(format="%(levelname)s:%(message)s", level=log_level)

scms_statuses = get_scms_status_data()
insert_scms_data(scms_statuses)
insert_scms_data(scms_statuses)
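
To illustrate the wrapper and the new validation helper together (the SCMS id and values are made up, and the module's own imports must resolve in your environment):

    from iss_health_checker import RsuDataWrapper, validate_scms_data

    data = RsuDataWrapper({"scms-0001": {"rsu_id": 42}})
    data.set_device_health("scms-0001", "Healthy")
    data.set_expiration("scms-0001", "2024-05-01T00:00:00.000Z")

    record = data.get_dict()["scms-0001"]
    assert validate_scms_data(record)  # rsu_id, deviceHealth, and expiration all present
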
30 changes: 17 additions & 13 deletions services/addons/images/iss_health_check/iss_token.py
@@ -6,14 +6,18 @@
import uuid
import logging


# Set up logging
logger = logging.getLogger(__name__)

# Get storage type from environment variable
def get_storage_type():
"""Get the storage type for the ISS SCMS API token
"""
try :
os.environ["STORAGE_TYPE"]
except KeyError:
logging.error("STORAGE_TYPE environment variable not set, exiting")
logger.error("STORAGE_TYPE environment variable not set, exiting")
exit(1)

storageTypeCaseInsensitive = os.environ["STORAGE_TYPE"].casefold()
@@ -22,7 +26,7 @@ def get_storage_type():
elif storageTypeCaseInsensitive == "postgres":
return "postgres"
else:
logging.error("STORAGE_TYPE environment variable not set to a valid value, exiting")
logger.error("STORAGE_TYPE environment variable not set to a valid value, exiting")
exit(1)


@@ -40,7 +44,7 @@ def create_secret(client, secret_id, parent):
"secret": {"replication": {"automatic": {}}},
}
)
logging.debug("New secret created")
logger.debug("New secret created")


def check_if_secret_exists(client, secret_id, parent):
@@ -54,7 +58,7 @@ def check_if_secret_exists(client, secret_id, parent):
):
# secret names are in the form of "projects/project_id/secrets/secret_id"
if secret.name.split("/")[-1] == secret_id:
logging.debug(f"Secret {secret_id} exists")
logger.debug(f"Secret {secret_id} exists")
return True
return False

@@ -84,7 +88,7 @@ def add_secret_version(client, secret_id, parent, data):
"payload": {"data": str.encode(json.dumps(data))},
}
)
logging.debug("New version added")
logger.debug("New version added")


# Postgres functions
@@ -116,7 +120,7 @@ def get_latest_data(table_name):
toReturn["id"] = data[0][0] # id
toReturn["name"] = data[0][1] # common_name
toReturn["token"] = data[0][2] # token
logging.debug(f"Received token: {toReturn['name']} with id {toReturn['id']}")
logger.debug(f"Received token: {toReturn['name']} with id {toReturn['id']}")
return toReturn


@@ -147,10 +151,10 @@ def get_token():
value = get_latest_secret_version(client, secret_id, parent)
friendly_name = value["name"]
token = value["token"]
logging.debug(f"Received token: {friendly_name}")
logger.debug(f"Received token: {friendly_name}")
else:
# If there is no available ISS token secret, create secret
logging.debug("Secret does not exist, creating secret")
logger.debug("Secret does not exist, creating secret")
create_secret(client, secret_id, parent)
# Use environment variable for first run with new secret
token = os.environ["ISS_API_KEY"]
Expand All @@ -166,7 +170,7 @@ def get_token():
id = value["id"]
friendly_name = value["name"]
token = value["token"]
logging.debug(f"Received token: {friendly_name} with id {id}")
logger.debug(f"Received token: {friendly_name} with id {id}")
else:
# if there is no data, use environment variable for first run
token = os.environ["ISS_API_KEY"]
Expand All @@ -182,20 +186,20 @@ def get_token():
iss_post_body = {"friendlyName": new_friendly_name, "expireDays": 1}

# Create new ISS SCMS API Token to ensure its freshness
logging.debug("POST: " + iss_base)
logger.debug("POST: " + iss_base)
response = requests.post(iss_base, json=iss_post_body, headers=iss_headers)
try:
new_token = response.json()["Item"]
except requests.JSONDecodeError:
logging.error("Failed to decode JSON response from ISS SCMS API. Response: " + response.text)
logger.error("Failed to decode JSON response from ISS SCMS API. Response: " + response.text)
exit(1)
logging.debug(f"Received new token: {new_friendly_name}")
logger.debug(f"Received new token: {new_friendly_name}")

if data_exists:
# If exists, delete previous API key to prevent key clutter
iss_delete_body = {"friendlyName": friendly_name}
requests.delete(iss_base, json=iss_delete_body, headers=iss_headers)
logging.debug(f"Old token has been deleted from ISS SCMS: {friendly_name}")
logger.debug(f"Old token has been deleted from ISS SCMS: {friendly_name}")

version_data = {"name": new_friendly_name, "token": new_token}

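
A usage sketch for the token flow; the values below are placeholders, and additional configuration (such as the ISS endpoint variables not shown in this diff) is also required:

    import os
    import iss_token

    os.environ["STORAGE_TYPE"] = "postgres"      # or "gcp"; any other value exits
    os.environ["ISS_API_KEY"] = "<placeholder>"  # consumed only on the first run

    token = iss_token.get_token()  # rotates the stored key, returns a fresh token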