Skip to content

Commit

Permalink
Merge pull request #10 from Trihydro/pr/addressing-usdot-comments
Browse files Browse the repository at this point in the history
Addressing USDOT comments regarding ISS Health Check & Firmware Manager
  • Loading branch information
payneBrandon authored Apr 26, 2024
2 parents c534751 + 2ca5153 commit dbee31a
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 61 deletions.
35 changes: 33 additions & 2 deletions services/addons/images/firmware_manager/download_blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@
import os


# Download a blob from GCP Bucket Storage
def download_gcp_blob(blob_name, destination_file_name):
"""Download a file from a GCP Bucket Storage bucket to a local file.
Args:
blob_name (str): The name of the file in the bucket.
destination_file_name (str): The name of the local file to download the bucket file to.
"""

validate_file_type(blob_name)

gcp_project = os.environ.get("GCP_PROJECT")
bucket_name = os.environ.get("BLOB_STORAGE_BUCKET")
storage_client = storage.Client(gcp_project)
Expand All @@ -16,11 +24,34 @@ def download_gcp_blob(blob_name, destination_file_name):
)


# "Download" a blob from a directory mounted as a volume in a Docker container
def download_docker_blob(blob_name, destination_file_name):
"""Copy a file from a directory mounted as a volume in a Docker container to a local file.
Args:
blob_name (str): The name of the file in the directory.
destination_file_name (str): The name of the local file to copy the directory file to.
"""

validate_file_type(blob_name)

directory = "/mnt/blob_storage"
source_file_name = f"{directory}/{blob_name}"
os.system(f"cp {source_file_name} {destination_file_name}")
logging.info(
f"Copied storage object {blob_name} from directory {directory} to local file {destination_file_name}."
)

def validate_file_type(file_name):
"""Validate the file type of the file to be downloaded.
Args:
file_name (str): The name of the file to be downloaded.
"""
if not file_name.endswith(".tar"):
logging.error(f"Unsupported file type for storage object {file_name}. Only .tar files are supported.")
raise UnsupportedFileTypeException

class UnsupportedFileTypeException(Exception):
def __init__(self, message="Unsupported file type. Only .tar files are supported."):
self.message = message
super().__init__(self.message)
4 changes: 2 additions & 2 deletions services/addons/images/firmware_manager/sample.env
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ PG_DB_NAME=""
PG_DB_USER=""
PG_DB_PASS=""

# Blob storage variables
# Blob storage variables (only 'GCP' and 'DOCKER' are supported at this time)
BLOB_STORAGE_PROVIDER=DOCKER
## GCP Project and Bucket for BLOB storage (if using GCP)
GCP_PROJECT=
BLOB_STORAGE_BUCKET=
## Docker volume mount point for BLOB storage (if using Docker)
## Docker volume mount point for BLOB storage (if using DOCKER)
HOST_BLOB_STORAGE_DIRECTORY=./local_blob_storage

# For users using GCP cloud storage
Expand Down
5 changes: 5 additions & 0 deletions services/addons/images/firmware_manager/upgrader.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def download_blob(self):
download_blob.download_docker_blob(self.blob_name, self.local_file_name)
else:
logging.error("Unsupported blob storage provider")
raise StorageProviderNotSupportedException

# Notifies the firmware manager of the completion status for the upgrade
# success is a boolean
Expand All @@ -60,3 +61,7 @@ def notify_firmware_manager(self, success):
@abc.abstractclassmethod
def upgrade(self):
pass

class StorageProviderNotSupportedException(Exception):
def __init__(self):
super().__init__("Unsupported blob storage provider")
115 changes: 82 additions & 33 deletions services/addons/images/iss_health_check/iss_health_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,42 @@
import os
import iss_token
import common.pgquery as pgquery
from dataclasses import dataclass, field
from typing import Dict


def get_rsu_data():
# Set up logging
logger = logging.getLogger(__name__)

@dataclass
class RsuDataWrapper:
rsu_data: Dict[str, Dict[str, str]] = field(default_factory=dict)

def __init__(self, rsu_data):
self.rsu_data = rsu_data

def get_dict(self):
return self.rsu_data

def set_provisioner_company(self, scms_id, provisioner_company):
self.rsu_data[scms_id]["provisionerCompany"] = provisioner_company

def set_entity_type(self, scms_id, entity_type):
self.rsu_data[scms_id]["entityType"] = entity_type

def set_project_id(self, scms_id, project_id):
self.rsu_data[scms_id]["project_id"] = project_id

def set_device_health(self, scms_id, device_health):
self.rsu_data[scms_id]["deviceHealth"] = device_health

def set_expiration(self, scms_id, expiration):
self.rsu_data[scms_id]["expiration"] = expiration


def get_rsu_data() -> RsuDataWrapper:
"""Get RSU data from PostgreSQL and return it in a wrapper object"""

result = {}
query = (
"SELECT jsonb_build_object('rsu_id', rsu_id, 'iss_scms_id', iss_scms_id) "
Expand All @@ -16,15 +49,17 @@ def get_rsu_data():
)
data = pgquery.query_db(query)

logging.debug("Parsing results...")
logger.debug("Parsing results...")
for point in data:
point_dict = dict(point[0])
result[point_dict["iss_scms_id"]] = {"rsu_id": point_dict["rsu_id"]}

return result
return RsuDataWrapper(result)


def get_scms_status_data():
"""Get SCMS status data from ISS and return it as a dictionary"""

rsu_data = get_rsu_data()

# Create GET request headers
Expand All @@ -43,7 +78,7 @@ def get_scms_status_data():
iss_request = iss_base + "?pageSize={}&page={}&project_id={}".format(
page_size, page, project_id
)
logging.debug("GET: " + iss_request)
logger.debug("GET: " + iss_request)
response = requests.get(iss_request, headers=iss_headers)
enrollment_list = response.json()["data"]

Expand All @@ -52,50 +87,36 @@ def get_scms_status_data():

# Loop through each device on current page
for enrollment_status in enrollment_list:
if enrollment_status["_id"] in rsu_data:
rsu_data[enrollment_status["_id"]][
"provisionerCompany"
] = enrollment_status["provisionerCompany_id"]
rsu_data[enrollment_status["_id"]]["entityType"] = enrollment_status[
"entityType"
]
rsu_data[enrollment_status["_id"]]["project_id"] = enrollment_status[
"project_id"
]
rsu_data[enrollment_status["_id"]]["deviceHealth"] = enrollment_status[
"deviceHealth"
]
es_id = enrollment_status["_id"]
if es_id in rsu_data.get_dict():
rsu_data.set_provisioner_company(es_id, enrollment_status["provisionerCompany_id"])
rsu_data.set_entity_type(es_id, enrollment_status["entityType"])
rsu_data.set_project_id(es_id, enrollment_status["project_id"])
rsu_data.set_device_health(es_id, enrollment_status["deviceHealth"])

# If the device has yet to download its first set of certs, set the expiration time to when it was enrolled
if "authorizationCertInfo" in enrollment_status["enrollments"][0]:
rsu_data[enrollment_status["_id"]][
"expiration"
] = enrollment_status["enrollments"][0]["authorizationCertInfo"][
"expireTimeOfLatestDownloadedCert"
]
rsu_data.set_expiration(es_id, enrollment_status["enrollments"][0]["authorizationCertInfo"]["expireTimeOfLatestDownloadedCert"])
else:
rsu_data[enrollment_status["_id"]]["expiration"] = None
rsu_data.set_expiration(es_id, None)

messages_processed = messages_processed + 1

page = page + 1

logging.info("Processed {} messages".format(messages_processed))
return rsu_data
logger.info("Processed {} messages".format(messages_processed))
return rsu_data.get_dict()


def insert_scms_data(data):
logging.info("Inserting SCMS data into PostgreSQL...")
logger.info("Inserting SCMS data into PostgreSQL...")
now_ts = datetime.strftime(datetime.now(), "%Y-%m-%dT%H:%M:%S.000Z")

query = (
'INSERT INTO public.scms_health("timestamp", health, expiration, rsu_id) VALUES'
)
for value in data.values():
try:
value["deviceHealth"]
except KeyError:
logging.warning("deviceHealth not found in data for RSU with id {}, is it real data?".format(value["rsu_id"]))
if validate_scms_data(value) is False:
continue

health = "1" if value["deviceHealth"] == "Healthy" else "0"
Expand All @@ -107,11 +128,39 @@ def insert_scms_data(data):
else:
query = query + f" ('{now_ts}', '{health}', NULL, {value['rsu_id']}),"

pgquery.write_db(query[:-1])
logging.info(
query = query[:-1] # remove comma
pgquery.write_db(query)
logger.info(
"SCMS data inserted {} messages into PostgreSQL...".format(len(data.values()))
)

def validate_scms_data(value):
"""Validate the SCMS data
Args:
value (dict): SCMS data
"""

try:
value["rsu_id"]
except KeyError as e:
logger.warning("rsu_id not found in data, is it real data? exception: {}".format(e))
return False

try:
value["deviceHealth"]
except KeyError as e:
logger.warning("deviceHealth not found in data for RSU with id {}, is it real data? exception: {}".format(value["rsu_id"], e))
return False

try:
value["expiration"]
except KeyError as e:
logger.warning("expiration not found in data for RSU with id {}, is it real data? exception: {}".format(value["rsu_id"], e))
return False

return True


if __name__ == "__main__":
# Configure logging based on ENV var or use default if not set
Expand All @@ -121,4 +170,4 @@ def insert_scms_data(data):
logging.basicConfig(format="%(levelname)s:%(message)s", level=log_level)

scms_statuses = get_scms_status_data()
insert_scms_data(scms_statuses)
insert_scms_data(scms_statuses)
30 changes: 17 additions & 13 deletions services/addons/images/iss_health_check/iss_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
import uuid
import logging


# Set up logging
logger = logging.getLogger(__name__)

# Get storage type from environment variable
def get_storage_type():
"""Get the storage type for the ISS SCMS API token
"""
try :
os.environ["STORAGE_TYPE"]
except KeyError:
logging.error("STORAGE_TYPE environment variable not set, exiting")
logger.error("STORAGE_TYPE environment variable not set, exiting")
exit(1)

storageTypeCaseInsensitive = os.environ["STORAGE_TYPE"].casefold()
Expand All @@ -22,7 +26,7 @@ def get_storage_type():
elif storageTypeCaseInsensitive == "postgres":
return "postgres"
else:
logging.error("STORAGE_TYPE environment variable not set to a valid value, exiting")
logger.error("STORAGE_TYPE environment variable not set to a valid value, exiting")
exit(1)


Expand All @@ -40,7 +44,7 @@ def create_secret(client, secret_id, parent):
"secret": {"replication": {"automatic": {}}},
}
)
logging.debug("New secret created")
logger.debug("New secret created")


def check_if_secret_exists(client, secret_id, parent):
Expand All @@ -54,7 +58,7 @@ def check_if_secret_exists(client, secret_id, parent):
):
# secret names are in the form of "projects/project_id/secrets/secret_id"
if secret.name.split("/")[-1] == secret_id:
logging.debug(f"Secret {secret_id} exists")
logger.debug(f"Secret {secret_id} exists")
return True
return False

Expand Down Expand Up @@ -84,7 +88,7 @@ def add_secret_version(client, secret_id, parent, data):
"payload": {"data": str.encode(json.dumps(data))},
}
)
logging.debug("New version added")
logger.debug("New version added")


# Postgres functions
Expand Down Expand Up @@ -116,7 +120,7 @@ def get_latest_data(table_name):
toReturn["id"] = data[0][0] # id
toReturn["name"] = data[0][1] # common_name
toReturn["token"] = data[0][2] # token
logging.debug(f"Received token: {toReturn['name']} with id {toReturn['id']}")
logger.debug(f"Received token: {toReturn['name']} with id {toReturn['id']}")
return toReturn


Expand Down Expand Up @@ -147,10 +151,10 @@ def get_token():
value = get_latest_secret_version(client, secret_id, parent)
friendly_name = value["name"]
token = value["token"]
logging.debug(f"Received token: {friendly_name}")
logger.debug(f"Received token: {friendly_name}")
else:
# If there is no available ISS token secret, create secret
logging.debug("Secret does not exist, creating secret")
logger.debug("Secret does not exist, creating secret")
create_secret(client, secret_id, parent)
# Use environment variable for first run with new secret
token = os.environ["ISS_API_KEY"]
Expand All @@ -166,7 +170,7 @@ def get_token():
id = value["id"]
friendly_name = value["name"]
token = value["token"]
logging.debug(f"Received token: {friendly_name} with id {id}")
logger.debug(f"Received token: {friendly_name} with id {id}")
else:
# if there is no data, use environment variable for first run
token = os.environ["ISS_API_KEY"]
Expand All @@ -182,20 +186,20 @@ def get_token():
iss_post_body = {"friendlyName": new_friendly_name, "expireDays": 1}

# Create new ISS SCMS API Token to ensure its freshness
logging.debug("POST: " + iss_base)
logger.debug("POST: " + iss_base)
response = requests.post(iss_base, json=iss_post_body, headers=iss_headers)
try:
new_token = response.json()["Item"]
except requests.JSONDecodeError:
logging.error("Failed to decode JSON response from ISS SCMS API. Response: " + response.text)
logger.error("Failed to decode JSON response from ISS SCMS API. Response: " + response.text)
exit(1)
logging.debug(f"Received new token: {new_friendly_name}")
logger.debug(f"Received new token: {new_friendly_name}")

if data_exists:
# If exists, delete previous API key to prevent key clutter
iss_delete_body = {"friendlyName": friendly_name}
requests.delete(iss_base, json=iss_delete_body, headers=iss_headers)
logging.debug(f"Old token has been deleted from ISS SCMS: {friendly_name}")
logger.debug(f"Old token has been deleted from ISS SCMS: {friendly_name}")

version_data = {"name": new_friendly_name, "token": new_token}

Expand Down
Loading

0 comments on commit dbee31a

Please sign in to comment.