Skip to content

Commit

Permalink
Add tests for improve_runner
Browse files Browse the repository at this point in the history
Signed-off-by: Hritik Vijay <hritikxx8@gmail.com>
  • Loading branch information
Hritik14 authored and TG1999 committed Aug 28, 2023
1 parent 2ea3455 commit cccc233
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 8 deletions.
22 changes: 14 additions & 8 deletions vulnerabilities/improve_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def run(self) -> None:
@transaction.atomic
def process_inferences(inferences: List[Inference], advisory: Advisory, improver_name: str):
"""
Return number of inferences processed.
An atomic transaction that updates both the Advisory (e.g. date_improved)
and processes the given inferences to create or update corresponding
database fields.
Expand All @@ -65,10 +66,11 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
erroneous. Also, the atomic transaction for every advisory and its
inferences makes sure that date_improved of advisory is consistent.
"""
inferences_processed_count = 0

if not inferences:
logger.warn(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}")
return
logger.warning(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}")
return inferences_processed_count

logger.info(f"Improving advisory id: {advisory.id}")

Expand All @@ -80,7 +82,7 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
)

if not vulnerability:
logger.warn(f"Unable to get vulnerability for inference: {inference!r}")
logger.warning(f"Unable to get vulnerability for inference: {inference!r}")
continue

for ref in inference.references:
Expand Down Expand Up @@ -143,8 +145,12 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
cwe_obj, created = Weakness.objects.get_or_create(cwe_id=cwe_id)
cwe_obj.vulnerabilities.add(vulnerability)
cwe_obj.save()

inferences_processed_count += 1

advisory.date_improved = datetime.now(timezone.utc)
advisory.save()
return inferences_processed_count


def create_valid_vulnerability_reference(url, reference_id=None):
Expand All @@ -168,7 +174,7 @@ def create_valid_vulnerability_reference(url, reference_id=None):
return reference


def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summary):
def get_or_create_vulnerability_and_aliases(alias_names, vulnerability_id=None, summary=None):
"""
Get or create vulnerabilitiy and aliases such that all existing and new
aliases point to the same vulnerability
Expand All @@ -188,7 +194,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
# TODO: It is possible that all those vulnerabilities are actually
# the same at data level, figure out a way to merge them
if len(existing_vulns) > 1:
logger.warn(
logger.warning(
f"Given aliases {alias_names} already exist and do not point "
f"to a single vulnerability. Cannot improve. Skipped."
)
Expand All @@ -201,7 +207,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
and vulnerability_id
and existing_alias_vuln.vulnerability_id != vulnerability_id
):
logger.warn(
logger.warning(
f"Given aliases {alias_names!r} already exist and point to existing"
f"vulnerability {existing_alias_vuln}. Unable to create Vulnerability "
f"with vulnerability_id {vulnerability_id}. Skipped"
Expand All @@ -214,7 +220,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
try:
vulnerability = Vulnerability.objects.get(vulnerability_id=vulnerability_id)
except Vulnerability.DoesNotExist:
logger.warn(
logger.warning(
f"Given vulnerability_id: {vulnerability_id} does not exist in the database"
)
return
Expand All @@ -223,7 +229,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
vulnerability.save()

if summary and summary != vulnerability.summary:
logger.warn(
logger.warning(
f"Inconsistent summary for {vulnerability!r}. "
f"Existing: {vulnerability.summary}, provided: {summary}"
)
Expand Down
167 changes: 167 additions & 0 deletions vulnerabilities/tests/test_improve_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,27 @@
# See https://aboutcode.org for more information about nexB OSS projects.
#

from collections import Counter

import pytest
from django.utils import timezone
from packageurl import PackageURL
from pytest_django.asserts import assertQuerysetEqual

from vulnerabilities.importer import Reference
from vulnerabilities.improve_runner import create_valid_vulnerability_reference
from vulnerabilities.improve_runner import get_or_create_vulnerability_and_aliases
from vulnerabilities.improve_runner import process_inferences
from vulnerabilities.improver import Improver
from vulnerabilities.improver import Inference
from vulnerabilities.models import Advisory
from vulnerabilities.models import Alias
from vulnerabilities.models import Package
from vulnerabilities.models import PackageRelatedVulnerability
from vulnerabilities.models import Vulnerability
from vulnerabilities.models import VulnerabilityReference
from vulnerabilities.models import VulnerabilityRelatedReference
from vulnerabilities.models import VulnerabilitySeverity


@pytest.mark.django_db
Expand Down Expand Up @@ -37,3 +55,152 @@ def test_create_valid_vulnerability_reference_accepts_long_references():
url="https://foo.bar",
)
assert result


@pytest.mark.django_db
def test_get_or_create_vulnerability_and_aliases_with_new_vulnerability_and_new_aliases():
alias_names = ["TAYLOR-1337", "SWIFT-1337"]
summary = "Melodious vulnerability"
vulnerability = get_or_create_vulnerability_and_aliases(
alias_names=alias_names, summary=summary
)
assert vulnerability
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
assert Counter(alias_names_in_db) == Counter(alias_names)


@pytest.mark.django_db
def test_get_or_create_vulnerability_and_aliases_with_different_vulnerability_and_existing_aliases():
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
existing_vulnerability.save()
existing_aliases = []
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
for alias in existing_alias_names:
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
Alias.objects.bulk_create(existing_aliases)

different_vulnerability = Vulnerability(vulnerability_id="VCID-New")
different_vulnerability.save()
assert not get_or_create_vulnerability_and_aliases(
alias_names=existing_alias_names, vulnerability_id=different_vulnerability.vulnerability_id
)


@pytest.mark.django_db
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_new_aliases():
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
existing_vulnerability.save()

existing_alias_names = ["ALIAS-1", "ALIAS-2"]
vulnerability = get_or_create_vulnerability_and_aliases(
vulnerability_id="VCID-Existing", alias_names=existing_alias_names
)
assert existing_vulnerability == vulnerability

alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
assert Counter(alias_names_in_db) == Counter(existing_alias_names)


@pytest.mark.django_db
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_aliases():
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
existing_vulnerability.save()

existing_aliases = []
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
for alias in existing_alias_names:
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
Alias.objects.bulk_create(existing_aliases)

vulnerability = get_or_create_vulnerability_and_aliases(
vulnerability_id="VCID-Existing", alias_names=existing_alias_names
)
assert existing_vulnerability == vulnerability

alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
assert Counter(alias_names_in_db) == Counter(existing_alias_names)


@pytest.mark.django_db
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_and_new_aliases():
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
existing_vulnerability.save()

existing_aliases = []
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
for alias in existing_alias_names:
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
Alias.objects.bulk_create(existing_aliases)

new_alias_names = ["ALIAS-3", "ALIAS-4"]
alias_names = existing_alias_names + new_alias_names
vulnerability = get_or_create_vulnerability_and_aliases(
vulnerability_id="VCID-Existing", alias_names=alias_names
)
assert existing_vulnerability == vulnerability

alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
assert Counter(alias_names_in_db) == Counter(alias_names)


DUMMY_ADVISORY = Advisory(summary="dummy", created_by="tests", date_collected=timezone.now())


@pytest.mark.django_db
def test_process_inferences_with_no_inference():
assert not process_inferences(
inferences=[], advisory=DUMMY_ADVISORY, improver_name="test_improver"
)


@pytest.mark.django_db
def test_process_inferences_with_unknown_but_specified_vulnerability():
inference = Inference(vulnerability_id="VCID-Does-Not-Exist-In-DB", aliases=["MATRIX-Neo"])
assert not process_inferences(
inferences=[inference], advisory=DUMMY_ADVISORY, improver_name="test_improver"
)


INFERENCES = [
Inference(
aliases=["CVE-1", "CVE-2"],
summary="One upon a time, in a package far far away",
affected_purls=[
PackageURL(type="character", namespace="star-wars", name="anakin", version="1")
],
fixed_purl=PackageURL(
type="character", namespace="star-wars", name="darth-vader", version="1"
),
references=[Reference(reference_id="imperial-vessel-1", url="https://m47r1x.github.io")],
)
]


def get_objects_in_all_tables_used_by_process_inferences():
return {
"vulnerabilities": list(Vulnerability.objects.all()),
"aliases": list(Alias.objects.all()),
"references": list(VulnerabilityReference.objects.all()),
"advisories": list(Advisory.objects.all()),
"packages": list(Package.objects.all()),
"references": list(VulnerabilityReference.objects.all()),
"severity": list(VulnerabilitySeverity.objects.all()),
}


@pytest.mark.django_db
def test_process_inferences_idempotency():
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
all_objects = get_objects_in_all_tables_used_by_process_inferences()
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
assert all_objects == get_objects_in_all_tables_used_by_process_inferences()


@pytest.mark.django_db
def test_process_inference_idempotency_with_different_improver_names():
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_one")
all_objects = get_objects_in_all_tables_used_by_process_inferences()
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_two")
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_three")
assert all_objects == get_objects_in_all_tables_used_by_process_inferences()
2 changes: 2 additions & 0 deletions vulnerabilities/tests/test_improver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_inference_to_dict_method_with_vulnerability_id():
"affected_purls": [],
"fixed_purl": None,
"references": [],
"weaknesses": [],
}
assert expected == inference.to_dict()

Expand All @@ -46,6 +47,7 @@ def test_inference_to_dict_method_with_purls():
"affected_purls": [purl.to_dict()],
"fixed_purl": purl.to_dict(),
"references": [],
"weaknesses": [],
}
assert expected == inference.to_dict()

Expand Down

0 comments on commit cccc233

Please sign in to comment.