Skip to content

Commit

Permalink
chore: add type annotations and use named tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
jdobes authored and yungbender committed Nov 7, 2023
1 parent 81384d3 commit 6283173
Showing 1 changed file with 58 additions and 32 deletions.
90 changes: 58 additions & 32 deletions manager/status_handler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""
Module for /status API endpoint
"""
from collections import namedtuple
from datetime import datetime
from datetime import timezone
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from uuid import UUID

from peewee import DataError
from peewee import EXCLUDED
from peewee import fn
from peewee import IntegrityError
from peewee import Query
from peewee import ValuesList
from psycopg2 import IntegrityError as psycopg2IntegrityError

Expand All @@ -34,14 +41,17 @@

LOGGER = get_logger(__name__)

StatusPair = namedtuple("StatusPair", ["status_id", "status_text"])
SystemCvePair = namedtuple("SystemCvePair", ["inventory_id", "cve"])


class GetStatus(GetRequest):
"""GET to /v1/status"""

_endpoint_name = r"/v1/status"

@classmethod
def handle_get(cls, **kwargs): # pylint: disable=unused-argument
def handle_get(cls, **kwargs) -> Tuple[Dict[str, any], int]: # pylint: disable=unused-argument
"""Return the data from the Status table as JSON"""
query = (Status.select().order_by(Status.id.asc()).dicts())
status_list = []
Expand All @@ -56,7 +66,7 @@ class PatchStatus(PatchRequest):
_endpoint_name = r"/v1/status"

@staticmethod
def _prepare_data(data):
def _prepare_data(data: Dict[str, any]) -> Tuple[Optional[List[str]], List[str], Optional[int], Optional[str]]:
if "inventory_id" in data:
in_inventory_id_list = parse_str_or_list(data["inventory_id"])
else:
Expand All @@ -73,7 +83,9 @@ def _prepare_data(data):
return in_inventory_id_list, in_cve_list, in_status_id, in_status_text

@staticmethod
def _apply_system_list_filter(query, rh_account_id, in_inventory_id_list):
def _apply_system_list_filter(query: Query,
rh_account_id: int,
in_inventory_id_list: Optional[List[str]]) -> Query:
query = cyndi_join(query)
query = query.where((SystemPlatform.rh_account_id == rh_account_id) &
(SystemPlatform.when_deleted.is_null(True)))
Expand All @@ -82,7 +94,10 @@ def _apply_system_list_filter(query, rh_account_id, in_inventory_id_list):
return query

@classmethod
def _get_current_status(cls, rh_account_id, in_inventory_id_list, in_cve_list):
def _get_current_status(cls,
rh_account_id: int,
in_inventory_id_list: Optional[List[str]],
in_cve_list: List[str]) -> Dict[str, Dict[str, StatusPair]]:
# pair status
system_cve_details = (SystemCveData.select(SystemPlatform.inventory_id, CveMetadata.cve,
SystemCveData.status_id, SystemCveData.status_text)
Expand All @@ -94,7 +109,7 @@ def _get_current_status(cls, rh_account_id, in_inventory_id_list, in_cve_list):
current_status = {}
for system_cve_detail in system_cve_details:
current_status.setdefault(system_cve_detail["cve"], {})[system_cve_detail["inventory_id"]] = \
(system_cve_detail["status_id"], system_cve_detail["status_text"])
StatusPair(system_cve_detail["status_id"], system_cve_detail["status_text"])

# global status
cve_details = (CveAccountData.select(CveMetadata.cve, CveAccountData.status_id, CveAccountData.status_text)
Expand All @@ -103,11 +118,14 @@ def _get_current_status(cls, rh_account_id, in_inventory_id_list, in_cve_list):
(CveMetadata.cve << in_cve_list))
.dicts())
for cve_detail in cve_details:
current_status.setdefault(cve_detail["cve"], {})["global"] = (cve_detail["status_id"], cve_detail["status_text"])
current_status.setdefault(cve_detail["cve"], {})["global"] = StatusPair(cve_detail["status_id"], cve_detail["status_text"])
return current_status

@classmethod
def _get_affected_pairs(cls, rh_account_id, in_inventory_id_list, in_cve_list):
def _get_affected_pairs(cls,
rh_account_id: int,
in_inventory_id_list: Optional[List[str]],
in_cve_list: List[str]) -> Set[SystemCvePair]:
affected_pairs = set()
fixable_pairs = (SystemVulnerabilities.select(SystemPlatform.inventory_id, CveMetadata.cve)
.join(CveMetadata, on=(SystemVulnerabilities.cve_id == CveMetadata.id))
Expand All @@ -123,7 +141,7 @@ def _get_affected_pairs(cls, rh_account_id, in_inventory_id_list, in_cve_list):
.dicts())
fixable_pairs = cls._apply_system_list_filter(fixable_pairs, rh_account_id, in_inventory_id_list)
for pair in fixable_pairs:
affected_pairs.add((pair["inventory_id"], pair["cve"]))
affected_pairs.add(SystemCvePair(pair["inventory_id"], pair["cve"]))

unfixable_pairs = (SystemVulnerablePackage.select(SystemPlatform.inventory_id, CveMetadata.cve)
.join(VulnerablePackageCVE, on=(SystemVulnerablePackage.vulnerable_package_id == VulnerablePackageCVE.vulnerable_package_id))
Expand All @@ -136,31 +154,33 @@ def _get_affected_pairs(cls, rh_account_id, in_inventory_id_list, in_cve_list):
.dicts())
unfixable_pairs = cls._apply_system_list_filter(unfixable_pairs, rh_account_id, in_inventory_id_list)
for pair in unfixable_pairs:
affected_pairs.add((pair["inventory_id"], pair["cve"]))
affected_pairs.add(SystemCvePair(pair["inventory_id"], pair["cve"]))
return affected_pairs

@classmethod
def _get_target_status(cls, inventory_id, cve, current_status, in_status_id, in_status_text):
def _get_target_status(cls,
inventory_id: str,
cve: str,
current_status: Dict[str, Dict[str, StatusPair]],
in_status_id: Optional[int],
in_status_text: Optional[str]) -> StatusPair:
# set global CVE status_id if there is no status_id in request
global_status_id, global_status_text = current_status.get(cve, {}).get("global", (0, None))
current_status_id, current_status_text = current_status.get(cve, {}).get(inventory_id, (0, None))
global_status_pair = current_status.get(cve, {}).get("global", StatusPair(0, None))
current_status_pair = current_status.get(cve, {}).get(inventory_id, StatusPair(0, None))

if in_status_id is None and in_status_text is None:
target_status_id = global_status_id
target_status_text = global_status_text
target_status_pair = global_status_pair
else:
target_status_id = current_status_id
target_status_text = current_status_text
target_status_pair = current_status_pair

if in_status_id is not None:
target_status_id = in_status_id
target_status_text = in_status_text
target_status_pair = StatusPair(in_status_id, in_status_text)

return target_status_id, target_status_text
return target_status_pair

@classmethod
@RBAC.need_permissions(RbacRoutePermissions.SYSTEM_CVE_STATUS_EDIT)
def handle_patch(cls, **kwargs):
def handle_patch(cls, **kwargs) -> Dict[str, any]:
"""Update the "status" field for a system/cve combination"""
# pylint: disable=singleton-comparison
data = kwargs["data"]
Expand All @@ -181,19 +201,25 @@ def handle_patch(cls, **kwargs):
to_upsert = []
to_delete = []
updated = []
for inventory_id, cve in affected_pairs:
target_status_id, target_status_text = cls._get_target_status(inventory_id, cve, current_status, in_status_id, in_status_text)
current_status_row = current_status.get(cve, {}).get(inventory_id)
if not current_status_row: # insert new statuses
if target_status_id != 0 or target_status_text is not None:
to_upsert.append((UUID(inventory_id), cve, target_status_id, target_status_text))
updated.append({"inventory_id": inventory_id, "cve": cve})
for system_cve_pair in affected_pairs:
target_status_pair = cls._get_target_status(system_cve_pair.inventory_id, system_cve_pair.cve, current_status, in_status_id, in_status_text)
current_status_pair = current_status.get(system_cve_pair.cve, {}).get(system_cve_pair.inventory_id)
if not current_status_pair: # insert new statuses
if target_status_pair.status_id != 0 or target_status_pair.status_text is not None:
to_upsert.append((UUID(system_cve_pair.inventory_id),
system_cve_pair.cve,
target_status_pair.status_id,
target_status_pair.status_text))
updated.append({"inventory_id": system_cve_pair.inventory_id, "cve": system_cve_pair.cve})
else: # update existing statuses
if target_status_id != 0 or target_status_text is not None:
if target_status_id != current_status_row[0] or target_status_text != current_status_row[1]:
to_upsert.append((UUID(inventory_id), cve, target_status_id, target_status_text))
updated.append({"inventory_id": inventory_id, "cve": cve})
current_status.get(cve, {}).pop(inventory_id, None)
if target_status_pair.status_id != 0 or target_status_pair.status_text is not None:
if target_status_pair.status_id != current_status_pair.status_id or target_status_pair.status_text != current_status_pair.status_text:
to_upsert.append((UUID(system_cve_pair.inventory_id),
system_cve_pair.cve,
target_status_pair.status_id,
target_status_pair.status_text))
updated.append({"inventory_id": system_cve_pair.inventory_id, "cve": system_cve_pair.cve})
current_status.get(system_cve_pair.cve, {}).pop(system_cve_pair.inventory_id, None)

for cve, systems in current_status.items(): # delete statuses that are set to 0, or no longer relevant
for sys in systems:
Expand Down

0 comments on commit 6283173

Please sign in to comment.