From 0ec3e8a021ed482f8f0818d0dcf2e77082aebd1e Mon Sep 17 00:00:00 2001
From: Rishi Garg
Date: Sun, 13 Oct 2024 15:40:36 +0530
Subject: [PATCH] Update version ordering for packages

Signed-off-by: Rishi Garg
---
Review notes (ignored by `git am`):

- update_packages now ranks the versions of each package (type, namespace,
  name) together, reuses CustomVersion from models.py instead of a copy, and
  writes with bulk_update() so that Package.save() does not overwrite the
  computed order.
- CustomVersion tuples are uniformly typed and therefore always comparable;
  pre-releases sort before the matching final release.
- Package.save() and the command share is_pre_release().
- set_version_order() ignores siblings with a NULL version_order.
- Hunks that only removed comments, re-quoted strings, moved `purl` or
  re-indented views.py were dropped, as was the packages.html hunk
  (`purl` returns `package_url`, so it changed nothing).
- No schema migration is included for the new Package fields; one still has to
  be generated.
- Context lines were re-indented by hand and the `index` lines are unchanged
  from the original patch; apply with care.

 .../management/commands/update_packages.py    | 107 +++++++++
 vulnerabilities/models.py                     | 211 ++++++++++++++++--
 vulnerabilities/views.py                      |   2 +-
 3 files changed, 308 insertions(+), 12 deletions(-)
 create mode 100644 vulnerabilities/management/commands/update_packages.py

diff --git a/vulnerabilities/management/commands/update_packages.py b/vulnerabilities/management/commands/update_packages.py
new file mode 100644
index 000000000..aa56de2e0
--- /dev/null
+++ b/vulnerabilities/management/commands/update_packages.py
@@ -0,0 +1,107 @@
+from itertools import groupby
+
+from django.core.management.base import BaseCommand
+from django.core.management.base import CommandError
+
+from vulnerabilities.models import CustomVersion
+from vulnerabilities.models import Package
+from vulnerabilities.models import is_pre_release
+
+# Number of example packages listed under each warning.
+REPORT_LIMIT = 10
+
+
+class Command(BaseCommand):
+    help = "Update version ordering and pre-release fields for existing packages"
+
+    def add_arguments(self, parser):
+        parser.add_argument(
+            "--batch-size",
+            type=int,
+            default=1000,
+            help="Number of packages to process in each batch",
+        )
+
+    def handle(self, *args, **options):
+        batch_size = options["batch_size"]
+        if batch_size < 1:
+            raise CommandError("--batch-size must be a positive integer")
+
+        # Ordering by package keeps all the versions of a package contiguous, so
+        # that each version is ranked against the versions of that package only.
+        packages = Package.objects.order_by("type", "namespace", "name", "pk")
+        total_packages = packages.count()
+        processed_packages = 0
+        updated_packages = 0
+        pending = []
+        invalid_version_packages = []
+        packages_missing_version = []
+
+        self.stdout.write(f"Processing {total_packages} packages in batches of {batch_size}")
+
+        versions_by_package = groupby(
+            packages.iterator(chunk_size=batch_size),
+            key=lambda pkg: (pkg.type, pkg.namespace, pkg.name),
+        )
+        for _, group in versions_by_package:
+            # Sort on the parsed version: Package instances are not orderable.
+            versions = sorted(
+                ((CustomVersion(pkg.version), pkg) for pkg in group),
+                key=lambda item: item[0],
+            )
+
+            for index, (version, package) in enumerate(versions, start=1):
+                new_version_order = float(index)
+                is_prerelease = is_pre_release(version)
+
+                if (
+                    package.version_order != new_version_order
+                    or package.is_prerelease != is_prerelease
+                ):
+                    package.version_order = new_version_order
+                    package.is_prerelease = is_prerelease
+                    pending.append(package)
+
+                if version.parsed == CustomVersion.INVALID:
+                    invalid_version_packages.append(package)
+                if not package.version:
+                    packages_missing_version.append(package)
+
+            processed_packages += len(versions)
+            if len(pending) >= batch_size:
+                updated_packages += self.save_pending(pending)
+                self.stdout.write(f"Processed {processed_packages} / {total_packages} packages")
+
+        updated_packages += self.save_pending(pending)
+        self.stdout.write(f"Processed {processed_packages} / {total_packages} packages")
+        self.stdout.write(self.style.SUCCESS(f"Updated {updated_packages} packages successfully."))
+
+        if invalid_version_packages:
+            count = len(invalid_version_packages)
+            self.stdout.write(self.style.WARNING(f"Found {count} packages with invalid versions:"))
+            for pkg in invalid_version_packages[:REPORT_LIMIT]:
+                self.stdout.write(
+                    f"  - Package ID: {pkg.id}, Version: {pkg.version}, PURL: {pkg.package_url}"
+                )
+            if count > REPORT_LIMIT:
+                self.stdout.write(f"  ... and {count - REPORT_LIMIT} more.")
+
+        if packages_missing_version:
+            count = len(packages_missing_version)
+            self.stdout.write(self.style.WARNING(f"Found {count} packages without versions:"))
+            for pkg in packages_missing_version[:REPORT_LIMIT]:
+                self.stdout.write(f"  - Package ID: {pkg.id}, PURL: {pkg.package_url}")
+            if count > REPORT_LIMIT:
+                self.stdout.write(f"  ... and {count - REPORT_LIMIT} more.")
+
+    def save_pending(self, pending):
+        """
+        Write the ordering fields of the ``pending`` packages to the database,
+        empty the list and return the number of packages written.
+        """
+        # bulk_update() does not call Package.save(), which would otherwise
+        # compute its own version_order and overwrite the one set by this command.
+        Package.objects.bulk_update(pending, ["version_order", "is_prerelease"])
+        count = len(pending)
+        pending.clear()
+        return count
diff --git a/vulnerabilities/models.py b/vulnerabilities/models.py
index fc5eb5e3c..eb512d084 100644
--- a/vulnerabilities/models.py
+++ b/vulnerabilities/models.py
@@ -591,6 +591,134 @@ def get_purl_query_lookups(purl):
     return purl_to_dict(plain_purl, with_empty=False)
 
 
+# NOTE(review): kept next to their first use because the module header is
+# outside of this hunk; move these imports to the top of the file.
+from django.db import models
+from django.urls import reverse
+from django.db.models import Prefetch
+from packageurl import PackageURL
+from vulnerabilities import utils
+from vulnerabilities.utils import normalize_purl
+from vulnerabilities.utils import purl_to_dict
+from packageurl.contrib.django.models import PackageURLMixin
+from packageurl.contrib.django.models import PackageURLQuerySet
+from vulnerabilities.utils import classproperty
+from packaging.version import parse, InvalidVersion
+from functools import total_ordering
+import re
+
+
+@total_ordering
+class CustomVersion:
+    """
+    Best-effort, ecosystem-agnostic sort key for a package version string.
+
+    ``parsed`` is either ``INVALID`` (empty or unparsable version, sorts last)
+    or a tuple ``(major, minor, micro, rank, qualifier_key)``. Each slot always
+    holds the same type, so any two instances can be compared.
+    """
+
+    # Sorts after every real version.
+    INVALID = (float("inf"),)
+    # ``rank`` slot: a pre-release sorts before the matching final release.
+    PRE_RELEASE = 0
+    FINAL = 1
+
+    DATE = re.compile(r"(\d{4})-(\d{2})-(\d{2})")
+    UNDERSCORE = re.compile(r"(\d+)\.(\d+)\.(\d+)(_\d+)")
+    QUALIFIED = re.compile(r"(\d+)\.(\d+)\.(\d+)([.-].+)")
+    COMMIT = re.compile(r"^[a-f0-9]{40}$")
+    CHUNKS = re.compile(r"[0-9]+|[^0-9]+")
+    PRE_RELEASE_TAG = re.compile(
+        r"(?<![a-z])(?:alpha|beta|preview|prerelease|pre|rc|dev|snapshot)(?![a-z])"
+        r"|^[._-]*[abc](?![a-z])"
+    )
+
+    def __init__(self, version_string):
+        self.original = str(version_string) if version_string is not None else ""
+        self.parsed = self.parse_custom(self.original)
+
+    def __repr__(self):
+        return f"CustomVersion({self.original!r})"
+
+    @classmethod
+    def natural_key(cls, text):
+        """
+        Return ``text`` split in digit and non-digit runs, as a tuple of
+        ``(0, int)`` and ``(1, str)`` pairs: "10" sorts after "9" and a number is
+        never compared with a string.
+        """
+        return tuple(
+            (0, int(chunk)) if "0" <= chunk[0] <= "9" else (1, chunk)
+            for chunk in cls.CHUNKS.findall(text)
+        )
+
+    @classmethod
+    def rank(cls, qualifier):
+        """
+        Return PRE_RELEASE if ``qualifier`` carries a pre-release tag, else FINAL.
+        """
+        if cls.PRE_RELEASE_TAG.search(qualifier.lower()):
+            return cls.PRE_RELEASE
+        return cls.FINAL
+
+    def parse_custom(self, version_string):
+        """
+        Return the comparison tuple described in the class docstring.
+        """
+        if not version_string:
+            return self.INVALID
+
+        # Handle date-like versions
+        date_match = self.DATE.match(version_string)
+        if date_match:
+            return (*map(int, date_match.groups()), self.FINAL, ())
+
+        # Handle versions with underscores (e.g., 1.8.0_361), build numbers or
+        # additional qualifiers (e.g., 1.2.3-alpha, 1.2.3.4)
+        match = self.UNDERSCORE.match(version_string) or self.QUALIFIED.match(version_string)
+        if match:
+            *release, qualifier = match.groups()
+            return (*map(int, release), self.rank(qualifier), self.natural_key(qualifier))
+
+        # Handle commit hashes
+        if self.COMMIT.match(version_string):
+            return (0, 0, 0, self.FINAL, self.natural_key(version_string))
+
+        # Default to packaging.version.parse for standard versions
+        try:
+            parsed = parse(version_string)
+        except InvalidVersion:
+            return self.INVALID
+
+        release_rank = self.PRE_RELEASE if parsed.is_prerelease else self.FINAL
+        # packaging reports a missing pre or post segment as None
+        pre = "".join(str(part) for part in parsed.pre or ())
+        qualifier_key = self.natural_key(pre) + ((0, parsed.post or 0),)
+        return (parsed.major, parsed.minor, parsed.micro, release_rank, qualifier_key)
+
+    def __lt__(self, other):
+        if not isinstance(other, CustomVersion):
+            return NotImplemented
+        return self.parsed < other.parsed
+
+    def __eq__(self, other):
+        if not isinstance(other, CustomVersion):
+            return NotImplemented
+        return self.parsed == other.parsed
+
+    def __hash__(self):
+        return hash(self.parsed)
+
+
+def is_pre_release(version):
+    """
+    Return True if the CustomVersion ``version`` was parsed as a pre-release.
+    """
+    parsed = version.parsed
+    return len(parsed) > 3 and parsed[3] == CustomVersion.PRE_RELEASE
+
+
 class Package(PackageURLMixin):
     """
     A software package with related vulnerabilities.
@@ -619,6 +747,21 @@ class Package(PackageURLMixin):
         db_index=True,
     )
 
+    version_order = models.FloatField(
+        null=True,
+        blank=True,
+        editable=False,
+        help_text="Numeric representation of version for sorting purposes",
+    )
+    is_prerelease = models.BooleanField(
+        default=False,
+        help_text="Indicates if this version is a pre-release.",
+    )
+    is_vulnerable = models.BooleanField(
+        default=False,
+        help_text="Indicates if this version is vulnerable to known issues.",
+    )
+
     is_ghost = models.BooleanField(
         default=False,
         help_text="True if the package does not exist in the upstream package manager or its repository.",
@@ -627,9 +770,13 @@ class Package(PackageURLMixin):
     objects = PackageQuerySet.as_manager()
 
     def save(self, *args, **kwargs):
         """
-        Save, normalizing PURL fields.
+        Save, normalizing PURL fields and refreshing the version ordering fields.
+
+        Pass ``skip_version_ordering=True`` to keep the current ``version_order``.
         """
+        skip_version_ordering = kwargs.pop("skip_version_ordering", False)
+
         purl = PackageURL(
             type=self.type,
             namespace=self.namespace,
@@ -649,11 +796,58 @@ def save(self, *args, **kwargs):
         self.package_url = str(normalized)
         plain_purl = utils.plain_purl(normalized)
         self.plain_package_url = str(plain_purl)
+
+        if self.version and not skip_version_ordering:
+            self.set_version_order()
+
+        # Derive the flag from the parsed version, like the update_packages
+        # command does; this is also safe when version is None or empty.
+        self.is_prerelease = is_pre_release(CustomVersion(self.version))
+
+        if self.is_vulnerable is None:
+            self.is_vulnerable = False
+
         super().save(*args, **kwargs)
 
+    def set_version_order(self):
+        """
+        Set ``version_order`` so that this package sorts by version among the
+        already ordered packages sharing its type, namespace and name.
+        """
+        custom_version = CustomVersion(self.version)
+
+        # Rows that were never ordered have a NULL version_order and cannot
+        # serve as neighbours in the arithmetic below.
+        siblings = list(
+            Package.objects.filter(type=self.type, namespace=self.namespace, name=self.name)
+            .exclude(pk=self.pk)
+            .exclude(version_order__isnull=True)
+            .order_by("version_order")
+            .values_list("version", "version_order")
+        )
+        orders = [order for _version, order in siblings]
+
+        # Index of the first sibling with a greater version, if any.
+        insert_position = len(siblings)
+        for position, (version, _order) in enumerate(siblings):
+            if CustomVersion(version) > custom_version:
+                insert_position = position
+                break
+
+        if not siblings:
+            self.version_order = 500000000.0
+        elif insert_position == 0:
+            self.version_order = orders[0] / 2
+        elif insert_position == len(siblings):
+            self.version_order = orders[-1] + 1000000
+        else:
+            self.version_order = (orders[insert_position - 1] + orders[insert_position]) / 2
+
+        self.version_order = max(0, min(self.version_order, 1000000000))
+
     @property
     def purl(self):
         return self.package_url
 
     class Meta:
         unique_together = ["type", "namespace", "name", "version", "qualifiers", "subpath"]
@@ -714,16 +908,11 @@ def sort_by_version(self, packages):
         """
         if not packages:
             return []
-        return sorted(packages, key=lambda x: self.version_class(x.version))
-
-    @cached_property
-    def version_class(self):
-        range_class = RANGE_CLASS_BY_SCHEMES.get(self.type)
-        return range_class.version_class if range_class else Version
+        return sorted(packages, key=lambda x: CustomVersion(x.version))
 
-    @cached_property
+    @property
     def current_version(self):
-        return self.version_class(self.version)
+        return CustomVersion(self.version)
 
     @property
     def next_non_vulnerable_version(self):
@@ -754,7 +943,7 @@ def get_non_vulnerable_versions(self):
         later_non_vulnerable_versions = [
             non_vuln_ver
             for non_vuln_ver in sorted_versions
-            if self.version_class(non_vuln_ver.version) > self.current_version
+            if CustomVersion(non_vuln_ver.version) > self.current_version
         ]
 
         if later_non_vulnerable_versions:
@@ -806,7 +995,7 @@ def get_affecting_vulnerabilities(self):
             for fixed_pkg in vuln.fixed_packages:
                 if fixed_pkg not in fixed_by_packages:
                     continue
-                fixed_version = self.version_class(fixed_pkg.version)
+                fixed_version = CustomVersion(fixed_pkg.version)
                 if fixed_version > self.current_version:
                     later_fixed_packages.append(fixed_pkg)
 
diff --git a/vulnerabilities/views.py b/vulnerabilities/views.py
index 394dc1c36..8527d0c4d 100644
--- a/vulnerabilities/views.py
+++ b/vulnerabilities/views.py
@@ -83,11 +83,11 @@ def get_queryset(self, query=None):
         """
         query = query or self.request.GET.get("search") or ""
         return (
             self.model.objects.search(query)
             .with_vulnerability_counts()
             .prefetch_related()
-            .order_by("package_url")
+            .order_by("type", "namespace", "name", "-version_order")
         )
 
 
 class VulnerabilitySearch(ListView):