diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 908f1f02..85e3c13c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -70,7 +70,6 @@ jobs: - id: dependencies run: | - sudo apt-get install -y libsnappy-dev pip install -r requirements.txt pip install -r requirements.dev.txt diff --git a/.pylintrc b/.pylintrc index f90a82a0..79ca45b4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,3 +1,7 @@ +[MAIN] + +extension-pkg-allow-list=cramjam + [MESSAGES CONTROL] disable= bad-option-value, diff --git a/make_release.py b/make_release.py index 18c0d227..75746bca 100755 --- a/make_release.py +++ b/make_release.py @@ -13,12 +13,14 @@ def make_release(version: str) -> None: version_filename.write_text(f'VERSION = "{version}"\n') subprocess.run(["git", "-C", str(project_directory), "add", str(version_filename)], check=True) subprocess.run(["git", "-C", str(project_directory), "commit", "-m", f"Bump to version {version}"], check=True) - subprocess.run(["git", "-C", str(project_directory), "tag", "-a", f"releases/{version}", "-m", f"Version {version}"], check=True) + subprocess.run( + ["git", "-C", str(project_directory), "tag", "-a", f"releases/{version}", "-m", f"Version {version}"], check=True + ) subprocess.run(["git", "-C", str(project_directory), "log", "-n", "1", "-p"], check=True) print("Run 'git push --tags' to confirm the release") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser("Make a rohmu release") parser.add_argument("version") args = parser.parse_args() diff --git a/mypy.ini b/mypy.ini index 7bd72294..a5647dd0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,14 +14,13 @@ ignore_missing_imports = True [mypy-azure.common.*] ignore_missing_imports = True -[mypy-dataclasses.*] +[mypy-cramjam.*] ignore_missing_imports = True - -[mypy-oauth2client.*] +[mypy-dataclasses.*] ignore_missing_imports = True -[mypy-snappy.*] +[mypy-oauth2client.*] ignore_missing_imports = True [mypy-swiftclient.*] diff --git a/rohmu.spec b/rohmu.spec index a3e43496..a3e5fc53 100644 --- a/rohmu.spec +++ b/rohmu.spec @@ -13,7 +13,8 @@ Requires: python3-cryptography >= 1.6 Requires: python3-dateutil Requires: python3-pydantic Requires: python3-requests -Requires: python3-snappy +# Requires: python3-snappy +# TODO: Create python3-cramjam Requires: python3-zstandard BuildRequires: python3-devel BuildRequires: python3-flake8 diff --git a/rohmu/compressor.py b/rohmu/compressor.py index b856baa4..8b219b81 100644 --- a/rohmu/compressor.py +++ b/rohmu/compressor.py @@ -7,16 +7,16 @@ from .errors import InvalidConfigurationError from .filewrap import Sink, Stream from .snappyfile import SnappyFile -from .typing import BinaryData, Compressor, Decompressor, FileLike, HasRead, HasWrite +from .typing import Algorithm, BinaryData, Compressor, Decompressor, FileLike, HasRead, HasWrite from .zstdfile import open as zstd_open from typing import cast, IO import lzma try: - import snappy + import cramjam except ImportError: - snappy = None # type: ignore + cramjam = None # type: ignore try: import zstandard as zstd @@ -24,37 +24,39 @@ zstd = None # type: ignore -def CompressionFile(dst_fp: FileLike, algorithm: str, level: int = 0, threads: int = 0) -> FileLike: +def CompressionFile(dst_fp: FileLike, algorithm: Algorithm, level: int = 0, threads: int = 0) -> FileLike: """This looks like a class to users, but is actually a function that instantiates a class based on algorithm.""" - if algorithm == "lzma": - return lzma.open(cast(IO[bytes], dst_fp), "w", preset=level) - - if algorithm == "snappy": - return SnappyFile(dst_fp, "wb") - - if algorithm == "zstd": - return zstd_open(dst_fp, "wb", level=level, threads=threads) - - if algorithm: - raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) - - return dst_fp + match algorithm: + case "lzma": + compression_fileobj = lzma.open(cast(IO[bytes], dst_fp), "w", preset=level) + case "snappy": + compression_fileobj = SnappyFile(dst_fp, "wb") + case "zstd": + compression_fileobj = zstd_open(dst_fp, "wb", level=level, threads=threads) + case _: + raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) + return compression_fileobj + + +def create_streaming_compressor(algorithm: Algorithm, level: int = 0) -> Compressor: + match algorithm: + case "lzma": + compressor = lzma.LZMACompressor(lzma.FORMAT_XZ, -1, level, None) + case "snappy": + compressor = cramjam.snappy.Compressor() + case "zstd": + compressor = zstd.ZstdCompressor(level=level).compressobj() + case _: + raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) + return compressor class CompressionStream(Stream): """Non-seekable stream of data that adds compression on top of given source stream""" - def __init__(self, src_fp: HasRead, algorithm: str, level: int = 0) -> None: + def __init__(self, src_fp: HasRead, algorithm: Algorithm, level: int = 0) -> None: super().__init__(src_fp, minimum_read_size=32 * 1024) - self._compressor: Compressor - if algorithm == "lzma": - self._compressor = lzma.LZMACompressor(lzma.FORMAT_XZ, -1, level, None) - elif algorithm == "snappy": - self._compressor = snappy.StreamCompressor() - elif algorithm == "zstd": - self._compressor = zstd.ZstdCompressor(level=level).compressobj() - else: - raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) + self._compressor = create_streaming_compressor(algorithm, level) def _process_chunk(self, data: bytes) -> bytes: return self._compressor.compress(data) @@ -63,42 +65,42 @@ def _finalize(self) -> bytes: return self._compressor.flush() -def DecompressionFile(src_fp: FileLike, algorithm: str) -> FileLike: +def DecompressionFile(src_fp: FileLike, algorithm: Algorithm) -> FileLike: """This looks like a class to users, but is actually a function that instantiates a class based on algorithm.""" - if algorithm == "lzma": - return lzma.open(cast(IO[bytes], src_fp), "r") - - if algorithm == "snappy": - return SnappyFile(src_fp, "rb") - - if algorithm == "zstd": - return zstd_open(src_fp, "rb") + match algorithm: + case "lzma": + return lzma.open(cast(IO[bytes], src_fp), "r") + case "snappy": + return SnappyFile(src_fp, "rb") + case "zstd": + return zstd_open(src_fp, "rb") + case _: + raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) - if algorithm: - raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) - return src_fp +def create_streaming_decompressor(algorithm: Algorithm) -> Decompressor: + match algorithm: + case "lzma": + decompressor = lzma.LZMADecompressor() + case "snappy": + decompressor = cramjam.snappy.Decompressor() + case "zstd": + decompressor = zstd.ZstdDecompressor().decompressobj() + case _: + raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm)) + return decompressor class DecompressSink(Sink): - def __init__(self, next_sink: HasWrite, compression_algorithm: str): + def __init__(self, next_sink: HasWrite, compression_algorithm: Algorithm): super().__init__(next_sink) - self.decompressor = self._create_decompressor(compression_algorithm) - - def _create_decompressor(self, alg: str) -> Decompressor: - if alg == "snappy": - return snappy.StreamDecompressor() - elif alg == "lzma": - return lzma.LZMADecompressor() - elif alg == "zstd": - return zstd.ZstdDecompressor().decompressobj() - raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(alg)) + self.decompressor = create_streaming_decompressor(compression_algorithm) def write(self, data: BinaryData) -> int: data = bytes(data) if not isinstance(data, bytes) else data written = len(data) if not data: return written - data = self.decompressor.decompress(data) - self._write_to_next_sink(data) + decompressed_data = self.decompressor.decompress(data) + self._write_to_next_sink(decompressed_data) return written diff --git a/rohmu/filewrap.py b/rohmu/filewrap.py index 0e141b7c..4006b32a 100644 --- a/rohmu/filewrap.py +++ b/rohmu/filewrap.py @@ -178,7 +178,6 @@ def read(self, size: int = -1) -> bytes: bytes_available += len(dst_data) if not src_data: self._eof = True - if size < 0 or bytes_available < size: data = b"".join(chunks) self._remainder = b"" diff --git a/rohmu/rohmufile.py b/rohmu/rohmufile.py index c735d28d..3ba95bd4 100644 --- a/rohmu/rohmufile.py +++ b/rohmu/rohmufile.py @@ -12,7 +12,7 @@ from .encryptor import DecryptorFile, DecryptSink, EncryptorFile from .errors import InvalidConfigurationError from .filewrap import ThrottleSink -from .typing import FileLike, HasWrite, Metadata +from .typing import Algorithm, FileLike, HasWrite, Metadata from contextlib import suppress from inspect import signature from rohmu.object_storage.base import IncrementalProgressCallbackType @@ -142,7 +142,7 @@ def read_file( def file_writer( *, fileobj: FileLike, - compression_algorithm: Optional[str] = None, + compression_algorithm: Optional[Algorithm] = None, compression_level: int = 0, compression_threads: int = 0, rsa_public_key: Union[None, str, bytes] = None, diff --git a/rohmu/snappyfile.py b/rohmu/snappyfile.py index deee951d..ea58722a 100644 --- a/rohmu/snappyfile.py +++ b/rohmu/snappyfile.py @@ -12,22 +12,22 @@ import io try: - import snappy + import cramjam except ImportError: - snappy = None # type: ignore + cramjam = None # type: ignore class SnappyFile(FileWrap): def __init__(self, next_fp: FileLike, mode: str) -> None: - if snappy is None: + if cramjam is None: raise io.UnsupportedOperation("Snappy is not available") if mode == "rb": - self.decr = snappy.StreamDecompressor() + self.decr = cramjam.snappy.Decompressor() self.encr = None elif mode == "wb": self.decr = None - self.encr = snappy.StreamCompressor() + self.encr = cramjam.snappy.Compressor() else: raise io.UnsupportedOperation("unsupported mode for SnappyFile") @@ -49,10 +49,11 @@ def write(self, data: BinaryData) -> int: # type: ignore [override] if self.encr is None: raise io.UnsupportedOperation("file not open for writing") data_as_bytes = bytes(data) - compressed_data = self.encr.compress(data_as_bytes) - self.next_fp.write(compressed_data) - self.offset += len(data_as_bytes) - return len(data_as_bytes) + block_size = self.encr.compress(data_as_bytes) + compressed_buffer = self.encr.flush() + self.next_fp.write(compressed_buffer) + self.offset += block_size + return block_size def writable(self) -> bool: return self.encr is not None @@ -62,19 +63,13 @@ def read(self, size: Optional[int] = -1) -> bytes: # pylint: disable=unused-arg self._check_not_closed() if self.decr is None: raise io.UnsupportedOperation("file not open for reading") - while not self.decr_done: - compressed = self.next_fp.read(IO_BLOCK_SIZE) - if not compressed: - self.decr_done = True - output = self.decr.flush() - else: - output = self.decr.decompress(compressed) - - if output: - self.offset += len(output) - return output - - return b"" + num_decompressed_bytes = 0 + while compressed := self.next_fp.read(IO_BLOCK_SIZE): + chunk_size = self.decr.decompress(compressed) + num_decompressed_bytes += chunk_size + self.offset += num_decompressed_bytes + output = self.decr.flush().read() + return output def readable(self) -> bool: return self.decr is not None diff --git a/rohmu/transfer_pool.py b/rohmu/transfer_pool.py index ad909543..0292f2b9 100644 --- a/rohmu/transfer_pool.py +++ b/rohmu/transfer_pool.py @@ -109,6 +109,7 @@ def put(self, transfer: TransferCacheItem) -> None: _BASE_TRANSFER_INSTANCE_ATTRS = {"config_model", "log", "notifier", "prefix", "stats"} _BASE_TRANSFER_ATTRS = {attr for attr in vars(BaseTransfer) if not attr.startswith("__")} | _BASE_TRANSFER_INSTANCE_ATTRS + # pylint: disable=abstract-method,super-init-not-called class SafeTransfer(BaseTransfer[StorageModel]): """Helper class that helps the users in finding bugs in their code handling transfers. diff --git a/rohmu/typing.py b/rohmu/typing.py index f7f32002..44e4feec 100644 --- a/rohmu/typing.py +++ b/rohmu/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import TracebackType -from typing import Any, Dict, Optional, Protocol, Type, TYPE_CHECKING, Union +from typing import Any, Dict, Literal, Optional, Protocol, Type, TYPE_CHECKING, Union try: # Remove when dropping support for Python 3.7 @@ -32,6 +32,8 @@ StrOrPathLike = Union[str, "PathLike[str]"] +Algorithm = Literal["lzma", "snappy", "zstd"] + class HasFileno(Protocol): def fileno(self) -> int: diff --git a/setup.cfg b/setup.cfg index 6214f113..cf28ae12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ packages = find: install_requires = azure-storage-blob >= 2.1.0 botocore + cramjam >= 2.7.0 cryptography google-api-python-client httplib2 @@ -32,7 +33,6 @@ install_requires = paramiko pydantic python-dateutil - python-snappy requests zstandard typing_extensions >= 3.10, < 5 diff --git a/test/test_compressor.py b/test/test_compressor.py index b13cdf04..172d9966 100644 --- a/test/test_compressor.py +++ b/test/test_compressor.py @@ -1,19 +1,19 @@ -from rohmu.compressor import CompressionFile, DecompressionFile +from rohmu.compressor import CompressionFile, CompressionStream, DecompressionFile, DecompressSink +from rohmu.typing import Algorithm +from typing import Final import io +import math import pytest +import random +import string +SAMPLE_BYTES: Final[bytes] = b"Some contents" -@pytest.mark.parametrize( - "algorithm,contents", - [ - ("lzma", b""), - ("snappy", b""), - ("lzma", b"Some contents"), - ("snappy", b"Some contents"), - ], -) -def test_compress_decompress_simple_file(algorithm: str, contents: bytes) -> None: + +@pytest.mark.parametrize("algorithm", ["lzma", "snappy", "zstd"]) +@pytest.mark.parametrize("contents", [b"", 100 * SAMPLE_BYTES], ids=["empty", "sample-bytes"]) +def test_compress_decompress_simple_file(algorithm: Algorithm, contents: bytes) -> None: bio = io.BytesIO() ef = CompressionFile(bio, algorithm=algorithm) ef.write(contents) @@ -23,3 +23,39 @@ def test_compress_decompress_simple_file(algorithm: str, contents: bytes) -> Non df = DecompressionFile(bio, algorithm=algorithm) data = df.read() assert data == contents + + +@pytest.mark.skip(reason="neither snappy nor zstd seem to handle multiple chunks") +@pytest.mark.parametrize("algorithm", ["lzma", "snappy", "zstd"]) +def test_compress_decompress_multiple_chunks(algorithm: Algorithm) -> None: + contents = "".join(random.choices(string.ascii_letters + string.digits, k=1_000_000)).encode() + num_bytes = len(contents) + print(f"Data size exponent (block size = 20): {math.log2(num_bytes)}") + bytes_buf = io.BytesIO() + ef = CompressionFile(bytes_buf, algorithm=algorithm) + ef.write(contents) + ef.close() + + bytes_buf.seek(0) + df = DecompressionFile(bytes_buf, algorithm=algorithm) + data = df.read() + assert data == contents + + +@pytest.mark.parametrize("algorithm", ["lzma", "snappy", "zstd"]) +@pytest.mark.parametrize("contents", [b"", 100 * SAMPLE_BYTES], ids=["empty", "sample-bytes"]) +def test_compress_decompress_streaming(algorithm: Algorithm, contents: bytes) -> None: + input_buffer = io.BytesIO() + input_buffer.write(contents) + input_buffer.seek(0) + compression_stream = CompressionStream(input_buffer, algorithm) + compressed_bytes = compression_stream.read() + assert compression_stream.tell() == len(compressed_bytes) + + output_buffer = io.BytesIO() + decompression_sink = DecompressSink(output_buffer, algorithm) + num_bytes_written = decompression_sink.write(compressed_bytes) + assert num_bytes_written == len(compressed_bytes) + output_buffer.seek(0) + output_data = output_buffer.read() + assert output_data == contents diff --git a/test/test_dates.py b/test/test_dates.py index 696715e5..06f4579e 100644 --- a/test/test_dates.py +++ b/test/test_dates.py @@ -4,11 +4,13 @@ Copyright (c) 2017 Ohmu Ltd See LICENSE for details """ +from dateutil.parser import UnknownTimezoneWarning from rohmu.dates import parse_timestamp import datetime import dateutil.tz import re +import warnings def test_parse_timestamp() -> None: @@ -24,7 +26,9 @@ def test_parse_timestamp() -> None: assert local_naive == local_aware.replace(tzinfo=None) str_unknown_aware = "2017-02-02 12:00:00 XYZ" - unknown_aware_utc = parse_timestamp(str_unknown_aware) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UnknownTimezoneWarning) + unknown_aware_utc = parse_timestamp(str_unknown_aware) assert unknown_aware_utc.tzinfo == datetime.timezone.utc assert unknown_aware_utc.isoformat() == "2017-02-02T12:00:00+00:00" diff --git a/test/test_factory.py b/test/test_factory.py index e49f3592..32b181fc 100644 --- a/test/test_factory.py +++ b/test/test_factory.py @@ -41,7 +41,6 @@ def test_get_transfer_s3( mock_notifier: Mock, config: Config, ) -> None: - expected_config_arg = dict(config) expected_config_arg.pop("storage_type") expected_config_arg.pop("notifier")