diff --git a/safety/cli.py b/safety/cli.py index 30a07056..3156d4d6 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -4,9 +4,9 @@ import click from safety import __version__ from safety import safety -from safety.formatter import report +from safety.formatter import report, license_report import itertools -from safety.util import read_requirements, read_vulnerabilities +from safety.util import read_requirements, read_vulnerabilities, get_proxy_dict, get_packages_licenses from safety.errors import DatabaseFetchError, DatabaseFileNotFoundError, InvalidKeyError try: @@ -66,20 +66,14 @@ def check(key, db, json, full_report, bare, stdin, files, cache, ignore, output, d for d in pkg_resources.working_set if d.key not in {"python", "wsgiref", "argparse"} ] - proxy_dictionary = {} - if proxyhost is not None: - if proxyprotocol in ["http", "https"]: - proxy_dictionary = {proxyprotocol: "{0}://{1}:{2}".format(proxyprotocol, proxyhost, str(proxyport))} - else: - click.secho("Proxy Protocol should be http or https only.", fg="red") - sys.exit(-1) + proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport) try: vulns = safety.check(packages=packages, key=key, db_mirror=db, cached=cache, ignore_ids=ignore, proxy=proxy_dictionary) output_report = report(vulns=vulns, full=full_report, json_report=json, bare_report=bare, - checked_packages=len(packages), + checked_packages=len(packages), db=db, key=key) @@ -128,5 +122,40 @@ def review(full_report, bare, file): click.secho(output_report, nl=False if bare and not vulns else True) +@cli.command() +@click.option("--key", default="", envvar="SAFETY_API_KEY", + help="API Key for pyup.io's vulnerability database. Can be set as SAFETY_API_KEY " + "environment variable. Default: empty") +@click.option("--db", default="", + help="Path to a local license database. Default: empty") +@click.option("--cache/--no-cache", default=True, + help='Whether license database file should be cached.' + 'Default: --cache') +@click.option("files", "--file", "-r", multiple=True, type=click.File(), + help="Read input from one (or multiple) requirement files. Default: empty") +@click.option("proxyhost", "--proxy-host", "-ph", multiple=False, type=str, default=None, + help="Proxy host IP or DNS --proxy-host") +@click.option("proxyport", "--proxy-port", "-pp", multiple=False, type=int, default=80, + help="Proxy port number --proxy-port") +@click.option("proxyprotocol", "--proxy-protocol", "-pr", multiple=False, type=str, default='http', + help="Proxy protocol (https or http) --proxy-protocol") +def license(key, db, cache, files, proxyprotocol, proxyhost, proxyport): + + if files: + packages = list(itertools.chain.from_iterable(read_requirements(f, resolve=True) for f in files)) + else: + import pkg_resources + packages = [ + d for d in pkg_resources.working_set + if d.key not in {"python", "wsgiref", "argparse"} + ] + + proxy_dictionary = get_proxy_dict(proxyprotocol, proxyhost, proxyport) + licenses_db = safety.get_licenses(key, db, cache, proxy_dictionary) + filtered_packages_licenses = get_packages_licenses(packages, licenses_db) + output_report = license_report(packages=packages, licenses=filtered_packages_licenses) + click.secho(output_report, nl=True) + + if __name__ == "__main__": cli() diff --git a/safety/formatter.py b/safety/formatter.py index 6cfa369a..af1e0903 100644 --- a/safety/formatter.py +++ b/safety/formatter.py @@ -5,6 +5,8 @@ import os import textwrap +from .util import get_packages_licenses + # python 2.7 compat try: FileNotFoundError @@ -69,6 +71,12 @@ class SheetReport(object): +============================+===========+==========================+==========+ """.strip() + TABLE_HEADING_LICENSES = r""" ++=============================================+===========+====================+ +| package | version | license | ++=============================================+===========+====================+ + """.strip() + REPORT_HEADING = r""" | REPORT | """.strip() @@ -125,6 +133,52 @@ def render(vulns, full, checked_packages, used_db): content, SheetReport.REPORT_FOOTER] ) + @staticmethod + def render_licenses(packages, packages_licenses): + heading = SheetReport.REPORT_HEADING.replace(" ", "", 12).replace( + "REPORT", " Packages licenses" + ) + if not packages_licenses: + content = "| {:76} |".format("No packages licenses found.") + return "\n".join( + [SheetReport.REPORT_BANNER, heading, SheetReport.REPORT_SECTION, + content, SheetReport.REPORT_FOOTER] + ) + + table = [] + iteration = 1 + for pkg_license in packages_licenses: + max_char = last_char = 43 # defines a limit for package name. + current_line = 1 + package = pkg_license['package'] + license = pkg_license['license'] + version = pkg_license['version'] + license_line = int(int(len(package) / max_char) / 2) + 1 # Calc to get which line to add the license info. + + table.append("| {:43} | {:9} | {:18} |".format( + package[:max_char], + version[:9] if current_line == license_line else "", + license[:18] if current_line == license_line else "", + )) + + long_name = True if len(package[max_char:]) > 0 else False + while long_name: # If the package has a long name, break it into multiple lines. + current_line += 1 + table.append("| {:43} | {:9} | {:18} |".format( + package[last_char:last_char+max_char], + version[:9] if current_line == license_line else "", + license[:18] if current_line == license_line else "", + )) + last_char = last_char+max_char + long_name = True if len(package[last_char:]) > 0 else False + + if iteration != len(packages_licenses): # Do not add dashes "----" for last package. + table.append("|" + ("-" * 78) + "|") + iteration += 1 + return "\n".join( + [SheetReport.REPORT_BANNER, heading, SheetReport.TABLE_HEADING_LICENSES, + "\n".join(table), SheetReport.REPORT_FOOTER] + ) class BasicReport(object): """Basic report, intented to be used for terminals with < 80 columns""" @@ -157,6 +211,24 @@ def render(vulns, full, checked_packages, used_db): table ) + @staticmethod + def render_licenses(packages, packages_licenses): + table = [ + "safety", + "packages licenses", + "---" + ] + if not packages_licenses: + table.append("No packages licenses found.") + return "\n".join(table) + + for pkg_license in packages_licenses: + text = pkg_license['package'] + \ + ", version " + pkg_license['version'] + \ + ", license " + pkg_license['license'] + "\n" + table.append(text) + + return "\n".join(table) class JsonReport(object): """Json report, for when the output is input for something else""" @@ -192,3 +264,11 @@ def report(vulns, full=False, json_report=False, bare_report=False, checked_pack if size.columns >= 80: return SheetReport.render(vulns, full=full, checked_packages=checked_packages, used_db=used_db) return BasicReport.render(vulns, full=full, checked_packages=checked_packages, used_db=used_db) + + +def license_report(packages, licenses): + size = get_terminal_size() + + if size.columns >= 80: + return SheetReport.render_licenses(packages, licenses) + return BasicReport.render_licenses(packages, licenses) diff --git a/safety/safety.py b/safety/safety.py index e11a5310..42b1c02c 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -176,3 +176,26 @@ def review(vulnerabilities): Vulnerability(**current_vuln) ) return vulnerable + + +def get_licenses(key, db_mirror, cached, proxy): + key = key if key else os.environ.get("SAFETY_API_KEY", False) + + if not key: + raise DatabaseFetchError("API-KEY not provided.") + if db_mirror: + mirrors = [db_mirror] + else: + mirrors = API_MIRRORS + + db_name = "licenses.json" + + for mirror in mirrors: + # mirror can either be a local path or a URL + if mirror.startswith("http://") or mirror.startswith("https://"): + licenses = fetch_database_url(mirror, db_name=db_name, key=key, cached=cached, proxy=proxy) + else: + licenses = fetch_database_file(mirror, db_name=db_name) + if licenses: + return licenses + raise DatabaseFetchError() diff --git a/safety/util.py b/safety/util.py index 82044574..760b21a5 100644 --- a/safety/util.py +++ b/safety/util.py @@ -1,5 +1,6 @@ from dparse.parser import setuptools_parse_requirements_backport as _parse_requirements from collections import namedtuple +from packaging.version import parse as parse_version import click import sys import json @@ -105,3 +106,66 @@ def read_requirements(fh, resolve=False): ) except ValueError: continue + + +def get_proxy_dict(proxyprotocol, proxyhost, proxyport): + proxy_dictionary = {} + if proxyhost is not None: + if proxyprotocol in ["http", "https"]: + proxy_dictionary = {proxyprotocol: "{0}://{1}:{2}".format(proxyprotocol, proxyhost, str(proxyport))} + else: + click.secho("Proxy Protocol should be http or https only.", fg="red") + sys.exit(-1) + return proxy_dictionary + + +def get_license_name_by_id(license_id, db): + licenses = db.get('licenses', []) + for name, id in licenses.items(): + if id == license_id: + return name + return None + +def get_packages_licenses(packages, licenses_db): + """Get the licenses for the specified packages based on their version. + + :param packages: packages list + :param licenses_db: the licenses db in the raw form. + :return: list of objects with the packages and their respectives licenses. + """ + packages_licenses_db = licenses_db.get('packages', {}) + filtered_packages_licenses = [] + + for pkg in packages: + # Ignore recursive files not resolved + if isinstance(pkg, RequirementFile): + continue + # normalize the package name + pkg_name = pkg.key.replace("_", "-").lower() + # packages may have different licenses depending their version. + pkg_licenses = packages_licenses_db.get(pkg_name, []) + version_requested = parse_version(pkg.version) + license_id = None + license_name = None + for pkg_version in pkg_licenses: + license_start_version = parse_version(pkg_version['start_version']) + # Stops and return the previous stored license when a new + # license starts on a version above the requested one. + if version_requested >= license_start_version: + license_id = pkg_version['license_id'] + else: + # We found the license for the version requested + break + + if license_id: + license_name = get_license_name_by_id(license_id, licenses_db) + if not license_id or not license_name: + license_name = "N/A" + + filtered_packages_licenses.append({ + "package": pkg_name, + "version": pkg.version, + "license": license_name + }) + + return filtered_packages_licenses diff --git a/tests/test_db/licenses.json b/tests/test_db/licenses.json new file mode 100644 index 00000000..3fe23a43 --- /dev/null +++ b/tests/test_db/licenses.json @@ -0,0 +1,13 @@ +{ + "licenses": { + "BSD-3-Clause": 1 + }, + "packages": { + "django": [ + { + "start_version": "0.0", + "license_id": 1 + } + ] + } +} \ No newline at end of file diff --git a/tests/test_safety.py b/tests/test_safety.py index be086751..9fe4833e 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -142,7 +142,7 @@ def test_check_from_file(self): cached=False, key=False, ignore_ids=[], - proxy={} + proxy={}, ) self.assertEqual(len(vulns), 2) @@ -160,7 +160,7 @@ def test_check_from_file_with_hash_pins(self): cached=False, key=False, ignore_ids=[], - proxy={} + proxy={}, ) self.assertEqual(len(vulns), 2) @@ -177,7 +177,7 @@ def test_multiple_versions(self): cached=False, key=False, ignore_ids=[], - proxy={} + proxy={}, ) self.assertEqual(len(vulns), 4) @@ -191,7 +191,7 @@ def test_check_live(self): cached=False, key=False, ignore_ids=[], - proxy={} + proxy={}, ) self.assertEqual(len(vulns), 1) @@ -205,7 +205,7 @@ def test_check_live_cached(self): cached=True, key=False, ignore_ids=[], - proxy={} + proxy={}, ) self.assertEqual(len(vulns), 1) @@ -218,10 +218,44 @@ def test_check_live_cached(self): cached=True, key=False, ignore_ids=[], - proxy={} + proxy={}, ) self.assertEqual(len(vulns), 1) + def test_get_packages_licenses(self): + reqs = StringIO("Django==1.8.1\n\rinvalid==1.0.0") + packages = util.read_requirements(reqs) + licenses_db = safety.get_licenses( + db_mirror=os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "test_db" + ), + cached=False, + key="foobarqux", + proxy={}, + ) + self.assertIn("licenses", licenses_db) + self.assertIn("packages", licenses_db) + self.assertIn("BSD-3-Clause", licenses_db['licenses']) + self.assertIn("django", licenses_db['packages']) + + pkg_licenses = util.get_packages_licenses(packages, licenses_db) + + self.assertIsInstance(pkg_licenses, list) + for pkg_license in pkg_licenses: + license = pkg_license['license'] + version = pkg_license['version'] + if pkg_license['package'] == 'django': + self.assertEqual(license, 'BSD-3-Clause') + self.assertEqual(version, '1.8.1') + elif pkg_license['package'] == 'invalid': + self.assertEqual(license, 'N/A') + self.assertEqual(version, '1.0.0') + else: + raise AssertionError( + "unexpected package '" + pkg_license['package'] + "' was found" + ) + class ReadRequirementsTestCase(unittest.TestCase):