Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate RECORD file using streams instead of reading in-memory #186

Merged
merged 3 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/installer/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import csv
import hashlib
import os
from typing import Iterable, Iterator, Optional, Tuple, cast
from typing import BinaryIO, Iterable, Iterator, Optional, Tuple, cast

from installer.utils import copyfileobj_with_hashing, get_stream_length

__all__ = [
"Hash",
Expand Down Expand Up @@ -144,6 +146,10 @@ def __eq__(self, other: object) -> bool:
def validate(self, data: bytes) -> bool:
"""Validate that ``data`` matches this instance.

.. attention::
.. deprecated:: 0.8.0
Use :py:meth:`validate_stream` instead, with ``BytesIO(data)``.

:param data: Contents of the file corresponding to this instance.
:return: whether ``data`` matches hash and size.
"""
Expand All @@ -155,6 +161,33 @@ def validate(self, data: bytes) -> bool:

return True

def validate_stream(self, stream: BinaryIO) -> bool:
    """Check the contents of ``stream`` against this entry's hash and size.

    When a hash is recorded, the stream is copied to ``os.devnull`` while
    being hashed, and both the digest and (if recorded) the size are
    compared. When only a size is recorded, the stream is read to count its
    length. When neither is recorded, the stream is not read at all.

    :param stream: Binary stream holding the contents of the file.
    :return: Whether the data read from ``stream`` matches hash and size.
    """
    if self.hash_ is None:
        # Nothing recorded to compare against: trivially valid.
        if self.size is None:
            return True
        # Only a size is recorded, so counting the bytes is sufficient.
        return get_stream_length(stream) == self.size

    # Hash (and size) are computed in a single pass; the copied bytes are
    # discarded by writing them to the null device.
    with open(os.devnull, "wb") as sink:
        digest, length = copyfileobj_with_hashing(
            stream, cast("BinaryIO", sink), self.hash_.name
        )

    size_matches = self.size is None or length == self.size
    return size_matches and self.hash_.value == digest

@classmethod
def from_elements(cls, path: str, hash_: str, size: str) -> "RecordEntry":
r"""Build a RecordEntry object, from values of the elements.
Expand Down
10 changes: 5 additions & 5 deletions src/installer/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,11 @@ def validate_record(self, *, validate_contents: bool = True) -> None:
f"In {self._zipfile.filename}, hash / size of {item.filename} is not included in RECORD"
)
if validate_contents:
data = self._zipfile.read(item)
if not record.validate(data):
issues.append(
f"In {self._zipfile.filename}, hash / size of {item.filename} didn't match RECORD"
)
with self._zipfile.open(item, "r") as stream:
if not record.validate_stream(cast(BinaryIO, stream)):
issues.append(
f"In {self._zipfile.filename}, hash / size of {item.filename} didn't match RECORD"
)

if issues:
raise _WheelFileValidationError(issues)
Expand Down
16 changes: 16 additions & 0 deletions src/installer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,22 @@ def copyfileobj_with_hashing(
return base64.urlsafe_b64encode(hasher.digest()).decode("ascii").rstrip("="), size


def get_stream_length(source: BinaryIO) -> int:
    """Count the bytes remaining in a stream by reading it to the end.

    The stream is read in chunks of ``_COPY_BUFSIZE`` until a read returns
    no data, so ``source`` is consumed by this call.

    :param source: buffer holding the source data
    :return: size of the contents
    """
    total = 0
    chunk = source.read(_COPY_BUFSIZE)
    while chunk:
        total += len(chunk)
        chunk = source.read(_COPY_BUFSIZE)
    return total


def get_launcher_kind() -> "LauncherKind": # pragma: no cover
"""Get the launcher kind for the current machine."""
if os.name != "nt":
Expand Down
34 changes: 26 additions & 8 deletions tests/test_records.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from io import BytesIO

import pytest

from installer.records import Hash, InvalidRecordEntry, RecordEntry, parse_record_file
Expand Down Expand Up @@ -51,23 +53,31 @@ def record_input(request):
b"test3\n",
True,
),
(
"purelib",
("test4.py", "sha256=Y0sCextp4SQtQNU-MSs7SsdxD1W-gfKJtUlEbvZ3i-4", 7),
b"test1\n",
False,
),
(
"purelib",
(
"test5.py",
"test4.py",
"sha256=Y0sCextp4SQtQNU-MSs7SsdxD1W-gfKJtUlEbvZ3i-4",
None,
),
b"test1\n",
True,
),
("purelib", ("test6.py", None, None), b"test1\n", True),
("purelib", ("test5.py", None, None), b"test1\n", True),
("purelib", ("test6.py", None, 6), b"test1\n", True),
(
"purelib",
("test7.py", "sha256=Y0sCextp4SQtQNU-MSs7SsdxD1W-gfKJtUlEbvZ3i-4", 7),
b"test1\n",
False,
),
(
"purelib",
("test7.py", "sha256=Y0sCextp4SQtQNU-MSs7SsdxD1W-gfKJtUlEbvZ3i-4", None),
b"not-test1\n",
False,
),
("purelib", ("test8.py", None, 10), b"test1\n", False),
]


Expand Down Expand Up @@ -130,6 +140,14 @@ def test_validation(self, scheme, elements, data, passes_validation):
record = RecordEntry.from_elements(*elements)
assert record.validate(data) == passes_validation

@pytest.mark.parametrize(
    ("scheme", "elements", "data", "passes_validation"), SAMPLE_RECORDS
)
def test_validate_stream(self, scheme, elements, data, passes_validation):
    """Stream validation gives the expected verdict for each sample record."""
    entry = RecordEntry.from_elements(*elements)
    outcome = entry.validate_stream(BytesIO(data))

    assert outcome == passes_validation

@pytest.mark.parametrize(
("scheme", "elements", "data", "passes_validation"), SAMPLE_RECORDS
)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
construct_record_file,
copyfileobj_with_hashing,
fix_shebang,
get_stream_length,
parse_entrypoints,
parse_metadata_file,
parse_wheel_filename,
Expand Down Expand Up @@ -134,6 +135,17 @@ def test_basic_functionality(self):
assert written_data == data


class TestGetStreamLength:
    """Tests for ``get_stream_length``."""

    def test_basic_functionality(self):
        """The reported length equals the number of bytes in the stream."""
        payload = b"input data is this"

        with BytesIO(payload) as stream:
            measured = get_stream_length(stream)

        assert measured == len(payload)


class TestScript:
@pytest.mark.parametrize(
("data", "expected"),
Expand Down