diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ab0aae79..1f6f61a0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -70,7 +70,6 @@ jobs: - id: dependencies run: | - sudo apt-get install -y libsnappy-dev pip install -r requirements.txt pip install -r requirements.dev.txt diff --git a/.pylintrc b/.pylintrc index f95b17ec..ecae4979 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,5 +1,5 @@ [MAIN] -extension-pkg-allow-list=pydantic +extension-pkg-allow-list=cramjam,pydantic [MESSAGES CONTROL] disable= diff --git a/mypy.ini b/mypy.ini index db530bf0..42b5aab0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,10 +14,10 @@ ignore_missing_imports = True [mypy-azure.common.*] ignore_missing_imports = True -[mypy-oauth2client.*] +[mypy-cramjam.*] ignore_missing_imports = True -[mypy-snappy.*] +[mypy-oauth2client.*] ignore_missing_imports = True [mypy-swiftclient.*] diff --git a/rohmu.spec b/rohmu.spec index a3e43496..a3e5fc53 100644 --- a/rohmu.spec +++ b/rohmu.spec @@ -13,7 +13,8 @@ Requires: python3-cryptography >= 1.6 Requires: python3-dateutil Requires: python3-pydantic Requires: python3-requests -Requires: python3-snappy +# Requires: python3-snappy +# TODO: Create python3-cramjam Requires: python3-zstandard BuildRequires: python3-devel BuildRequires: python3-flake8 diff --git a/rohmu/compressor.py b/rohmu/compressor.py index f8c88c53..cb18d541 100644 --- a/rohmu/compressor.py +++ b/rohmu/compressor.py @@ -7,16 +7,43 @@ from .errors import InvalidConfigurationError from .filewrap import Sink, Stream from .snappyfile import SnappyFile -from .typing import BinaryData, Compressor, Decompressor, FileLike, HasRead, HasWrite +from .typing import BinaryData, CompressionAlgorithm, Compressor, Decompressor, FileLike, HasRead, HasWrite from .zstdfile import open as zstd_open from typing import cast, IO import lzma + try: - import snappy + import cramjam + + # Cramjam streaming classes are lazy and diverge from Compressor and Decompressor interfaces. + # Adapt the parent classes to flush and return the inner buffer after compress and decompress calls. + class CramjamStreamingCompressor(Compressor): + def __init__(self) -> None: + self._compressor = cramjam.snappy.Compressor() + + def compress(self, data: bytes) -> bytes: + self._compressor.compress(data) + return self.flush() + + def flush(self) -> bytes: + buf = self._compressor.flush() + return buf.read() + + class CramjamStreamingDecompressor(Decompressor): + def __init__(self) -> None: + self._decompressor = cramjam.snappy.Decompressor() + + def decompress(self, data: bytes) -> bytes: + self._decompressor.decompress(data) + buf = self._decompressor.flush() + return buf.read() + except ImportError: - snappy = None # type: ignore + cramjam = None # type: ignore + CramjamStreamingCompressor: Compressor | None = None # type: ignore[no-redef] + CramjamStreamingDecompressor: Decompressor | None = None # type: ignore[no-redef] try: import zstandard as zstd @@ -24,37 +51,43 @@ zstd = None # type: ignore -def CompressionFile(dst_fp: FileLike, algorithm: str, level: int = 0, threads: int = 0) -> FileLike: +def CompressionFile(dst_fp: FileLike, algorithm: CompressionAlgorithm, level: int = 0, threads: int = 0) -> FileLike: """This looks like a class to users, but is actually a function that instantiates a class based on algorithm.""" - if algorithm == "lzma": - return lzma.open(cast(IO[bytes], dst_fp), "w", preset=level) - - if algorithm == "snappy": - return SnappyFile(dst_fp, "wb") - - if algorithm == "zstd": - return zstd_open(dst_fp, "wb", level=level, threads=threads) - - if algorithm: - raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") - - return dst_fp + compression_fileobj: FileLike + match algorithm: + case "lzma": + compression_fileobj = lzma.open(cast(IO[bytes], dst_fp), "w", preset=level) + case "snappy": + compression_fileobj = SnappyFile(dst_fp, "wb") + case "zstd": + compression_fileobj = zstd_open(dst_fp, "wb", level=level, threads=threads) + case _: + raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") + return compression_fileobj + + +def create_streaming_compressor(algorithm: CompressionAlgorithm, level: int = 0) -> Compressor: + compressor: Compressor + match algorithm: + case "lzma": + compressor = lzma.LZMACompressor(lzma.FORMAT_XZ, -1, level, None) + case "snappy": + if CramjamStreamingCompressor is None: + raise ImportError("Unable to import cramjam") + compressor = CramjamStreamingCompressor() + case "zstd": + compressor = zstd.ZstdCompressor(level=level).compressobj() + case _: + raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") + return compressor class CompressionStream(Stream): """Non-seekable stream of data that adds compression on top of given source stream""" - def __init__(self, src_fp: HasRead, algorithm: str, level: int = 0) -> None: + def __init__(self, src_fp: HasRead, algorithm: CompressionAlgorithm, level: int = 0) -> None: super().__init__(src_fp, minimum_read_size=32 * 1024) - self._compressor: Compressor - if algorithm == "lzma": - self._compressor = lzma.LZMACompressor(lzma.FORMAT_XZ, -1, level, None) - elif algorithm == "snappy": - self._compressor = snappy.StreamCompressor() - elif algorithm == "zstd": - self._compressor = zstd.ZstdCompressor(level=level).compressobj() - else: - raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") + self._compressor = create_streaming_compressor(algorithm, level) def _process_chunk(self, data: bytes) -> bytes: return self._compressor.compress(data) @@ -63,42 +96,45 @@ def _finalize(self) -> bytes: return self._compressor.flush() -def DecompressionFile(src_fp: FileLike, algorithm: str) -> FileLike: +def DecompressionFile(src_fp: FileLike, algorithm: CompressionAlgorithm) -> FileLike: """This looks like a class to users, but is actually a function that instantiates a class based on algorithm.""" - if algorithm == "lzma": - return lzma.open(cast(IO[bytes], src_fp), "r") - - if algorithm == "snappy": - return SnappyFile(src_fp, "rb") - - if algorithm == "zstd": - return zstd_open(src_fp, "rb") + match algorithm: + case "lzma": + return lzma.open(cast(IO[bytes], src_fp), "r") + case "snappy": + return SnappyFile(src_fp, "rb") + case "zstd": + return zstd_open(src_fp, "rb") + case _: + raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") - if algorithm: - raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") - return src_fp +def create_streaming_decompressor(algorithm: CompressionAlgorithm) -> Decompressor: + decompressor: Decompressor + match algorithm: + case "lzma": + decompressor = lzma.LZMADecompressor() + case "snappy": + if CramjamStreamingDecompressor is None: + raise ImportError("Unable to import cramjam") + decompressor = CramjamStreamingDecompressor() + case "zstd": + decompressor = zstd.ZstdDecompressor().decompressobj() + case _: + raise InvalidConfigurationError(f"invalid compression algorithm: {repr(algorithm)}") + return decompressor class DecompressSink(Sink): - def __init__(self, next_sink: HasWrite, compression_algorithm: str): + def __init__(self, next_sink: HasWrite, compression_algorithm: CompressionAlgorithm): super().__init__(next_sink) - self.decompressor = self._create_decompressor(compression_algorithm) - - def _create_decompressor(self, alg: str) -> Decompressor: - if alg == "snappy": - return snappy.StreamDecompressor() - elif alg == "lzma": - return lzma.LZMADecompressor() - elif alg == "zstd": - return zstd.ZstdDecompressor().decompressobj() - raise InvalidConfigurationError(f"invalid compression algorithm: {repr(alg)}") + self.decompressor = create_streaming_decompressor(compression_algorithm) def write(self, data: BinaryData) -> int: data = bytes(data) if not isinstance(data, bytes) else data written = len(data) if not data: return written - data = self.decompressor.decompress(data) - self._write_to_next_sink(data) + decompressed_data = self.decompressor.decompress(data) + self._write_to_next_sink(decompressed_data) return written diff --git a/rohmu/filewrap.py b/rohmu/filewrap.py index 0e141b7c..4006b32a 100644 --- a/rohmu/filewrap.py +++ b/rohmu/filewrap.py @@ -178,7 +178,6 @@ def read(self, size: int = -1) -> bytes: bytes_available += len(dst_data) if not src_data: self._eof = True - if size < 0 or bytes_available < size: data = b"".join(chunks) self._remainder = b"" diff --git a/rohmu/rohmufile.py b/rohmu/rohmufile.py index 21c9aa70..a16bd0cf 100644 --- a/rohmu/rohmufile.py +++ b/rohmu/rohmufile.py @@ -12,7 +12,7 @@ from .encryptor import DecryptorFile, DecryptSink, EncryptorFile from .errors import InvalidConfigurationError from .filewrap import ThrottleSink -from .typing import FileLike, HasWrite, Metadata +from .typing import CompressionAlgorithm, FileLike, HasWrite, Metadata from contextlib import suppress from inspect import signature from rohmu.object_storage.base import IncrementalProgressCallbackType @@ -143,7 +143,7 @@ def read_file( def file_writer( *, fileobj: FileLike, - compression_algorithm: Optional[str] = None, + compression_algorithm: Optional[CompressionAlgorithm] = None, compression_level: int = 0, compression_threads: int = 0, rsa_public_key: Union[None, str, bytes] = None, @@ -162,7 +162,7 @@ def write_file( input_obj: FileLike, output_obj: FileLike, progress_callback: IncrementalProgressCallbackType = None, - compression_algorithm: Optional[str] = None, + compression_algorithm: Optional[CompressionAlgorithm] = None, compression_level: int = 0, compression_threads: int = 0, rsa_public_key: Union[None, str, bytes] = None, diff --git a/rohmu/snappyfile.py b/rohmu/snappyfile.py index deee951d..ea58722a 100644 --- a/rohmu/snappyfile.py +++ b/rohmu/snappyfile.py @@ -12,22 +12,22 @@ import io try: - import snappy + import cramjam except ImportError: - snappy = None # type: ignore + cramjam = None # type: ignore class SnappyFile(FileWrap): def __init__(self, next_fp: FileLike, mode: str) -> None: - if snappy is None: + if cramjam is None: raise io.UnsupportedOperation("Snappy is not available") if mode == "rb": - self.decr = snappy.StreamDecompressor() + self.decr = cramjam.snappy.Decompressor() self.encr = None elif mode == "wb": self.decr = None - self.encr = snappy.StreamCompressor() + self.encr = cramjam.snappy.Compressor() else: raise io.UnsupportedOperation("unsupported mode for SnappyFile") @@ -49,10 +49,11 @@ def write(self, data: BinaryData) -> int: # type: ignore [override] if self.encr is None: raise io.UnsupportedOperation("file not open for writing") data_as_bytes = bytes(data) - compressed_data = self.encr.compress(data_as_bytes) - self.next_fp.write(compressed_data) - self.offset += len(data_as_bytes) - return len(data_as_bytes) + block_size = self.encr.compress(data_as_bytes) + compressed_buffer = self.encr.flush() + self.next_fp.write(compressed_buffer) + self.offset += block_size + return block_size def writable(self) -> bool: return self.encr is not None @@ -62,19 +63,13 @@ def read(self, size: Optional[int] = -1) -> bytes: # pylint: disable=unused-arg self._check_not_closed() if self.decr is None: raise io.UnsupportedOperation("file not open for reading") - while not self.decr_done: - compressed = self.next_fp.read(IO_BLOCK_SIZE) - if not compressed: - self.decr_done = True - output = self.decr.flush() - else: - output = self.decr.decompress(compressed) - - if output: - self.offset += len(output) - return output - - return b"" + num_decompressed_bytes = 0 + while compressed := self.next_fp.read(IO_BLOCK_SIZE): + chunk_size = self.decr.decompress(compressed) + num_decompressed_bytes += chunk_size + self.offset += num_decompressed_bytes + output = self.decr.flush().read() + return output def readable(self) -> bool: return self.decr is not None diff --git a/rohmu/typing.py b/rohmu/typing.py index f7f32002..34ab3eff 100644 --- a/rohmu/typing.py +++ b/rohmu/typing.py @@ -1,7 +1,7 @@ from __future__ import annotations from types import TracebackType -from typing import Any, Dict, Optional, Protocol, Type, TYPE_CHECKING, Union +from typing import Any, Dict, Literal, Optional, Protocol, Type, TYPE_CHECKING, Union try: # Remove when dropping support for Python 3.7 @@ -32,6 +32,8 @@ StrOrPathLike = Union[str, "PathLike[str]"] +CompressionAlgorithm = Literal["lzma", "snappy", "zstd"] + class HasFileno(Protocol): def fileno(self) -> int: diff --git a/setup.cfg b/setup.cfg index 0efb458e..b5668504 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ packages = find: install_requires = azure-storage-blob >= 2.1.0 botocore + cramjam >= 2.7.0 cryptography google-api-python-client httplib2 @@ -32,7 +33,6 @@ install_requires = paramiko pydantic < 2 python-dateutil - python-snappy requests zstandard typing_extensions >= 3.10, < 5 diff --git a/test/test_compressor.py b/test/test_compressor.py index b13cdf04..65683275 100644 --- a/test/test_compressor.py +++ b/test/test_compressor.py @@ -1,19 +1,19 @@ -from rohmu.compressor import CompressionFile, DecompressionFile +from rohmu.compressor import CompressionFile, CompressionStream, DecompressionFile, DecompressSink +from rohmu.typing import CompressionAlgorithm +from typing import Final import io +import math import pytest +import random +import string +SAMPLE_BYTES: Final[bytes] = b"Some contents" -@pytest.mark.parametrize( - "algorithm,contents", - [ - ("lzma", b""), - ("snappy", b""), - ("lzma", b"Some contents"), - ("snappy", b"Some contents"), - ], -) -def test_compress_decompress_simple_file(algorithm: str, contents: bytes) -> None: + +@pytest.mark.parametrize("algorithm", ["lzma", "snappy", "zstd"]) +@pytest.mark.parametrize("contents", [b"", 100 * SAMPLE_BYTES], ids=["empty", "sample-bytes"]) +def test_compress_decompress_simple_file(algorithm: CompressionAlgorithm, contents: bytes) -> None: bio = io.BytesIO() ef = CompressionFile(bio, algorithm=algorithm) ef.write(contents) @@ -23,3 +23,39 @@ def test_compress_decompress_simple_file(algorithm: str, contents: bytes) -> Non df = DecompressionFile(bio, algorithm=algorithm) data = df.read() assert data == contents + + +@pytest.mark.skip(reason="neither snappy nor zstd seem to handle multiple chunks") +@pytest.mark.parametrize("algorithm", ["lzma", "snappy", "zstd"]) +def test_compress_decompress_multiple_chunks(algorithm: CompressionAlgorithm) -> None: + contents = "".join(random.choices(string.ascii_letters + string.digits, k=1_000_000)).encode() + num_bytes = len(contents) + print(f"Data size exponent (block size = 20): {math.log2(num_bytes)}") + bytes_buf = io.BytesIO() + ef = CompressionFile(bytes_buf, algorithm=algorithm) + ef.write(contents) + ef.close() + + bytes_buf.seek(0) + df = DecompressionFile(bytes_buf, algorithm=algorithm) + data = df.read() + assert data == contents + + +@pytest.mark.parametrize("algorithm", ["lzma", "snappy", "zstd"]) +@pytest.mark.parametrize("contents", [b"", 100 * SAMPLE_BYTES], ids=["empty", "sample-bytes"]) +def test_compress_decompress_streaming(algorithm: CompressionAlgorithm, contents: bytes) -> None: + input_buffer = io.BytesIO() + input_buffer.write(contents) + input_buffer.seek(0) + compression_stream = CompressionStream(input_buffer, algorithm) + compressed_bytes = compression_stream.read() + assert compression_stream.tell() == len(compressed_bytes) + + output_buffer = io.BytesIO() + decompression_sink = DecompressSink(output_buffer, algorithm) + num_bytes_written = decompression_sink.write(compressed_bytes) + assert num_bytes_written == len(compressed_bytes) + output_buffer.seek(0) + output_data = output_buffer.read() + assert output_data == contents diff --git a/test/test_dates.py b/test/test_dates.py index ff8ae69d..d251d52b 100644 --- a/test/test_dates.py +++ b/test/test_dates.py @@ -4,11 +4,13 @@ Copyright (c) 2017 Ohmu Ltd See LICENSE for details """ +from dateutil.parser import UnknownTimezoneWarning # type: ignore[attr-defined] from rohmu.dates import parse_timestamp import datetime import dateutil.tz import re +import warnings def test_parse_timestamp() -> None: @@ -24,6 +26,8 @@ def test_parse_timestamp() -> None: assert local_naive == local_aware.replace(tzinfo=None) str_unknown_aware = "2017-02-02 12:00:00 XYZ" + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UnknownTimezoneWarning) unknown_aware_utc = parse_timestamp(str_unknown_aware) assert unknown_aware_utc.tzinfo == datetime.timezone.utc assert unknown_aware_utc.isoformat() == "2017-02-02T12:00:00+00:00"