From 5fea1a7a6292e6f78f533ee7d706bad23420f132 Mon Sep 17 00:00:00 2001 From: David JM Emmett Date: Wed, 7 Aug 2024 16:14:56 +0100 Subject: [PATCH] Adding --osv-url argument to allow use of private OSV vulnerability services https://github.com/pypa/pip-audit/issues/805 --- .gitignore | 1 + pip_audit/_cli.py | 20 ++++++++++++++++---- pip_audit/_service/osv.py | 13 ++++++++++--- pip_audit/_service/pypi.py | 2 +- test/test_cli.py | 2 +- 5 files changed, 29 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 4d279d60..90c1decc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ __pycache__/ html/ dist/ .python-version +/.pytest_cache/ diff --git a/pip_audit/_cli.py b/pip_audit/_cli.py index 85cc02c9..c399635f 100644 --- a/pip_audit/_cli.py +++ b/pip_audit/_cli.py @@ -98,11 +98,11 @@ class VulnerabilityServiceChoice(str, enum.Enum): Osv = "osv" Pypi = "pypi" - def to_service(self, timeout: int, cache_dir: Path | None) -> VulnerabilityService: + def to_service(self, **kwargs: dict) -> VulnerabilityService: if self is VulnerabilityServiceChoice.Osv: - return OsvService(cache_dir, timeout) + return OsvService(**kwargs) elif self is VulnerabilityServiceChoice.Pypi: - return PyPIService(cache_dir, timeout) + return PyPIService(**kwargs) else: assert_never(self) # pragma: no cover @@ -241,6 +241,14 @@ def _parser() -> argparse.ArgumentParser: # pragma: no cover VulnerabilityServiceChoice, ), ) + parser.add_argument( + "--osv-url", + type=str, + metavar="OSV_URL", + dest="osv_url", + default=os.environ.get("PIP_AUDIT_OSV_URL", OsvService.DEFAULT_OSV_URL), + help="URL to use for the OSV API instead of the default", + ) parser.add_argument( "-d", "--dry-run", @@ -418,7 +426,11 @@ def audit() -> None: # pragma: no cover parser = _parser() args = _parse_args(parser) - service = args.vulnerability_service.to_service(args.timeout, args.cache_dir) + service = args.vulnerability_service.to_service( + timeout=args.timeout, + cache_dir=args.cache_dir, + osv_url=args.osv_url, + ) output_desc = args.desc.to_bool(args.format) output_aliases = args.aliases.to_bool(args.format) formatter = args.format.to_format(output_desc, output_aliases) diff --git a/pip_audit/_service/osv.py b/pip_audit/_service/osv.py index 8b0fca28..e0e80d6b 100644 --- a/pip_audit/_service/osv.py +++ b/pip_audit/_service/osv.py @@ -31,7 +31,14 @@ class OsvService(VulnerabilityService): package vulnerability information. """ - def __init__(self, cache_dir: Path | None = None, timeout: int | None = None): + DEFAULT_OSV_URL = "https://api.osv.dev/v1/query" + + def __init__( + self, + cache_dir: Path | None = None, + timeout: int | None = None, + osv_url: str = DEFAULT_OSV_URL, + ): """ Create a new `OsvService`. @@ -43,6 +50,7 @@ def __init__(self, cache_dir: Path | None = None, timeout: int | None = None): """ self.session = caching_session(cache_dir, use_pip=False) self.timeout = timeout + self.osv_url = osv_url def query(self, spec: Dependency) -> tuple[Dependency, list[VulnerabilityResult]]: """ @@ -54,14 +62,13 @@ def query(self, spec: Dependency) -> tuple[Dependency, list[VulnerabilityResult] return spec, [] spec = cast(ResolvedDependency, spec) - url = "https://api.osv.dev/v1/query" query = { "package": {"name": spec.canonical_name, "ecosystem": "PyPI"}, "version": str(spec.version), } try: response: requests.Response = self.session.post( - url=url, + url=self.osv_url, data=json.dumps(query), timeout=self.timeout, ) diff --git a/pip_audit/_service/pypi.py b/pip_audit/_service/pypi.py index 448a934a..f0848719 100644 --- a/pip_audit/_service/pypi.py +++ b/pip_audit/_service/pypi.py @@ -32,7 +32,7 @@ class PyPIService(VulnerabilityService): package vulnerability information. """ - def __init__(self, cache_dir: Path | None = None, timeout: int | None = None) -> None: + def __init__(self, cache_dir: Path | None = None, timeout: int | None = None, **kwargs: dict) -> None: """ Create a new `PyPIService`. diff --git a/test/test_cli.py b/test/test_cli.py index 6943412f..cb6eea88 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -29,7 +29,7 @@ def test_str(self): class TestVulnerabilityServiceChoice: def test_to_service_is_exhaustive(self, cache_dir): for choice in VulnerabilityServiceChoice: - assert choice.to_service(0, cache_dir) is not None + assert choice.to_service(timeout=0, cache_dir=cache_dir) is not None def test_str(self): for choice in VulnerabilityServiceChoice: