Skip to content

Commit

Permalink
Merge pull request #315 from pyupio/nicholas/packages-licenses
Browse files Browse the repository at this point in the history
Package license information on Safety.
  • Loading branch information
rafaelpivato authored Dec 13, 2020
2 parents 47f22f9 + 41714d8 commit 2e5b46b
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 16 deletions.
49 changes: 39 additions & 10 deletions safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import click
from safety import __version__
from safety import safety
from safety.formatter import report
from safety.formatter import report, license_report
import itertools
from safety.util import read_requirements, read_vulnerabilities
from safety.util import read_requirements, read_vulnerabilities, get_proxy_dict, get_packages_licenses
from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError

try:
Expand Down Expand Up @@ -66,20 +66,14 @@ def check(key, db, json, full_report, bare, stdin, files, cache, ignore, output,
d for d in pkg_resources.working_set
if d.key not in {"python", "wsgiref", "argparse"}
]
proxy_dictionary = {}
if proxyhost is not None:
if proxyprotocol in ["http", "https"]:
proxy_dictionary = {proxyprotocol: "{0}://{1}:{2}".format(proxyprotocol, proxyhost, str(proxyport))}
else:
click.secho("Proxy Protocol should be http or https only.", fg="red")
sys.exit(-1)
proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport)
try:
vulns = safety.check(packages=packages, key=key, db_mirror=db, cached=cache, ignore_ids=ignore, proxy=proxy_dictionary)
output_report = report(vulns=vulns,
full=full_report,
json_report=json,
bare_report=bare,
checked_packages=len(packages),
checked_packages=len(packages),
db=db,
key=key)

Expand Down Expand Up @@ -128,5 +122,40 @@ def review(full_report, bare, file):
click.secho(output_report, nl=False if bare and not vulns else True)


@cli.command()
@click.option("--key", default="", envvar="SAFETY_API_KEY",
help="API Key for pyup.io's vulnerability database. Can be set as SAFETY_API_KEY "
"environment variable. Default: empty")
@click.option("--db", default="",
help="Path to a local license database. Default: empty")
@click.option("--cache/--no-cache", default=True,
help='Whether license database file should be cached.'
'Default: --cache')
@click.option("files", "--file", "-r", multiple=True, type=click.File(),
help="Read input from one (or multiple) requirement files. Default: empty")
@click.option("proxyhost", "--proxy-host", "-ph", multiple=False, type=str, default=None,
help="Proxy host IP or DNS --proxy-host")
@click.option("proxyport", "--proxy-port", "-pp", multiple=False, type=int, default=80,
help="Proxy port number --proxy-port")
@click.option("proxyprotocol", "--proxy-protocol", "-pr", multiple=False, type=str, default='http',
help="Proxy protocol (https or http) --proxy-protocol")
def license(key, db, cache, files, proxyprotocol, proxyhost, proxyport):

if files:
packages = list(itertools.chain.from_iterable(read_requirements(f, resolve=True) for f in files))
else:
import pkg_resources
packages = [
d for d in pkg_resources.working_set
if d.key not in {"python", "wsgiref", "argparse"}
]

proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport)
licenses_db = safety.get_licenses(key, db, cache, proxy_dictionary)
filtered_packages_licenses = get_packages_licenses(packages, licenses_db)
output_report = license_report(packages=packages, licenses=filtered_packages_licenses)
click.secho(output_report, nl=True)


if __name__ == "__main__":
cli()
80 changes: 80 additions & 0 deletions safety/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import os
import textwrap

from .util import get_packages_licenses

# python 2.7 compat
try:
FileNotFoundError
Expand Down Expand Up @@ -69,6 +71,12 @@ class SheetReport(object):
+============================+===========+==========================+==========+
""".strip()

TABLE_HEADING_LICENSES = r"""
+=============================================+===========+====================+
| package | version | license |
+=============================================+===========+====================+
""".strip()

REPORT_HEADING = r"""
| REPORT |
""".strip()
Expand Down Expand Up @@ -125,6 +133,52 @@ def render(vulns, full, checked_packages, used_db):
content, SheetReport.REPORT_FOOTER]
)

@staticmethod
def render_licenses(packages, packages_licenses):
heading = SheetReport.REPORT_HEADING.replace(" ", "", 12).replace(
"REPORT", " Packages licenses"
)
if not packages_licenses:
content = "| {:76} |".format("No packages licenses found.")
return "\n".join(
[SheetReport.REPORT_BANNER, heading, SheetReport.REPORT_SECTION,
content, SheetReport.REPORT_FOOTER]
)

table = []
iteration = 1
for pkg_license in packages_licenses:
max_char = last_char = 43 # defines a limit for package name.
current_line = 1
package = pkg_license['package']
license = pkg_license['license']
version = pkg_license['version']
license_line = int(int(len(package) / max_char) / 2) + 1 # Calc to get which line to add the license info.

table.append("| {:43} | {:9} | {:18} |".format(
package[:max_char],
version[:9] if current_line == license_line else "",
license[:18] if current_line == license_line else "",
))

long_name = True if len(package[max_char:]) > 0 else False
while long_name: # If the package has a long name, break it into multiple lines.
current_line += 1
table.append("| {:43} | {:9} | {:18} |".format(
package[last_char:last_char+max_char],
version[:9] if current_line == license_line else "",
license[:18] if current_line == license_line else "",
))
last_char = last_char+max_char
long_name = True if len(package[last_char:]) > 0 else False

if iteration != len(packages_licenses): # Do not add dashes "----" for last package.
table.append("|" + ("-" * 78) + "|")
iteration += 1
return "\n".join(
[SheetReport.REPORT_BANNER, heading, SheetReport.TABLE_HEADING_LICENSES,
"\n".join(table), SheetReport.REPORT_FOOTER]
)

class BasicReport(object):
"""Basic report, intented to be used for terminals with < 80 columns"""
Expand Down Expand Up @@ -157,6 +211,24 @@ def render(vulns, full, checked_packages, used_db):
table
)

@staticmethod
def render_licenses(packages, packages_licenses):
table = [
"safety",
"packages licenses",
"---"
]
if not packages_licenses:
table.append("No packages licenses found.")
return "\n".join(table)

for pkg_license in packages_licenses:
text = pkg_license['package'] + \
", version " + pkg_license['version'] + \
", license " + pkg_license['license'] + "\n"
table.append(text)

return "\n".join(table)

class JsonReport(object):
"""Json report, for when the output is input for something else"""
Expand Down Expand Up @@ -192,3 +264,11 @@ def report(vulns, full=False, json_report=False, bare_report=False, checked_pack
if size.columns >= 80:
return SheetReport.render(vulns, full=full, checked_packages=checked_packages, used_db=used_db)
return BasicReport.render(vulns, full=full, checked_packages=checked_packages, used_db=used_db)


def license_report(packages, licenses):
size = get_terminal_size()

if size.columns >= 80:
return SheetReport.render_licenses(packages, licenses)
return BasicReport.render_licenses(packages, licenses)
23 changes: 23 additions & 0 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,26 @@ def review(vulnerabilities):
Vulnerability(**current_vuln)
)
return vulnerable


def get_licenses(key, db_mirror, cached, proxy):
key = key if key else os.environ.get("SAFETY_API_KEY", False)

if not key:
raise DatabaseFetchError("API-KEY not provided.")
if db_mirror:
mirrors = [db_mirror]
else:
mirrors = API_MIRRORS

db_name = "licenses.json"

for mirror in mirrors:
# mirror can either be a local path or a URL
if mirror.startswith("http://") or mirror.startswith("https://"):
licenses = fetch_database_url(mirror, db_name=db_name, key=key, cached=cached, proxy=proxy)
else:
licenses = fetch_database_file(mirror, db_name=db_name)
if licenses:
return licenses
raise DatabaseFetchError()
64 changes: 64 additions & 0 deletions safety/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dparse.parser import setuptools_parse_requirements_backport as _parse_requirements
from collections import namedtuple
from packaging.version import parse as parse_version
import click
import sys
import json
Expand Down Expand Up @@ -105,3 +106,66 @@ def read_requirements(fh, resolve=False):
)
except ValueError:
continue


def get_proxy_dict(proxyprotocol, proxyhost, proxyport):
proxy_dictionary = {}
if proxyhost is not None:
if proxyprotocol in ["http", "https"]:
proxy_dictionary = {proxyprotocol: "{0}://{1}:{2}".format(proxyprotocol, proxyhost, str(proxyport))}
else:
click.secho("Proxy Protocol should be http or https only.", fg="red")
sys.exit(-1)
return proxy_dictionary


def get_license_name_by_id(license_id, db):
licenses = db.get('licenses', [])
for name, id in licenses.items():
if id == license_id:
return name
return None

def get_packages_licenses(packages, licenses_db):
"""Get the licenses for the specified packages based on their version.
:param packages: packages list
:param licenses_db: the licenses db in the raw form.
:return: list of objects with the packages and their respectives licenses.
"""
packages_licenses_db = licenses_db.get('packages', {})
filtered_packages_licenses = []

for pkg in packages:
# Ignore recursive files not resolved
if isinstance(pkg, RequirementFile):
continue
# normalize the package name
pkg_name = pkg.key.replace("_", "-").lower()
# packages may have different licenses depending their version.
pkg_licenses = packages_licenses_db.get(pkg_name, [])
version_requested = parse_version(pkg.version)
license_id = None
license_name = None
for pkg_version in pkg_licenses:
license_start_version = parse_version(pkg_version['start_version'])
# Stops and return the previous stored license when a new
# license starts on a version above the requested one.
if version_requested >= license_start_version:
license_id = pkg_version['license_id']
else:
# We found the license for the version requested
break

if license_id:
license_name = get_license_name_by_id(license_id, licenses_db)
if not license_id or not license_name:
license_name = "N/A"

filtered_packages_licenses.append({
"package": pkg_name,
"version": pkg.version,
"license": license_name
})

return filtered_packages_licenses
13 changes: 13 additions & 0 deletions tests/test_db/licenses.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"licenses": {
"BSD-3-Clause": 1
},
"packages": {
"django": [
{
"start_version": "0.0",
"license_id": 1
}
]
}
}
46 changes: 40 additions & 6 deletions tests/test_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_check_from_file(self):
cached=False,
key=False,
ignore_ids=[],
proxy={}
proxy={},
)
self.assertEqual(len(vulns), 2)

Expand All @@ -160,7 +160,7 @@ def test_check_from_file_with_hash_pins(self):
cached=False,
key=False,
ignore_ids=[],
proxy={}
proxy={},
)
self.assertEqual(len(vulns), 2)

Expand All @@ -177,7 +177,7 @@ def test_multiple_versions(self):
cached=False,
key=False,
ignore_ids=[],
proxy={}
proxy={},
)
self.assertEqual(len(vulns), 4)

Expand All @@ -191,7 +191,7 @@ def test_check_live(self):
cached=False,
key=False,
ignore_ids=[],
proxy={}
proxy={},
)
self.assertEqual(len(vulns), 1)

Expand All @@ -205,7 +205,7 @@ def test_check_live_cached(self):
cached=True,
key=False,
ignore_ids=[],
proxy={}
proxy={},
)
self.assertEqual(len(vulns), 1)

Expand All @@ -218,10 +218,44 @@ def test_check_live_cached(self):
cached=True,
key=False,
ignore_ids=[],
proxy={}
proxy={},
)
self.assertEqual(len(vulns), 1)

def test_get_packages_licenses(self):
reqs = StringIO("Django==1.8.1\n\rinvalid==1.0.0")
packages = util.read_requirements(reqs)
licenses_db = safety.get_licenses(
db_mirror=os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"test_db"
),
cached=False,
key="foobarqux",
proxy={},
)
self.assertIn("licenses", licenses_db)
self.assertIn("packages", licenses_db)
self.assertIn("BSD-3-Clause", licenses_db['licenses'])
self.assertIn("django", licenses_db['packages'])

pkg_licenses = util.get_packages_licenses(packages, licenses_db)

self.assertIsInstance(pkg_licenses, list)
for pkg_license in pkg_licenses:
license = pkg_license['license']
version = pkg_license['version']
if pkg_license['package'] == 'django':
self.assertEqual(license, 'BSD-3-Clause')
self.assertEqual(version, '1.8.1')
elif pkg_license['package'] == 'invalid':
self.assertEqual(license, 'N/A')
self.assertEqual(version, '1.0.0')
else:
raise AssertionError(
"unexpected package '" + pkg_license['package'] + "' was found"
)


class ReadRequirementsTestCase(unittest.TestCase):

Expand Down

0 comments on commit 2e5b46b

Please sign in to comment.