Skip to content

Commit

Permalink
Merge pull request #144 from martindurant/guesses
Browse files Browse the repository at this point in the history
Remake format guess functions
  • Loading branch information
martindurant authored Aug 29, 2024
2 parents 3a355aa + 0906e06 commit ac4ddd3
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 74 deletions.
98 changes: 62 additions & 36 deletions src/snappy/snappy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,23 +149,15 @@ def __init__(self):
self.remains = None

@staticmethod
def check_format(data):
"""Checks that the given data starts with snappy framing format
stream identifier.
Raises UncompressError if it doesn't start with the identifier.
:return: None
def check_format(fin):
"""Does this stream start with a stream header block?
True indicates that the stream can likely be decoded using this class.
"""
if len(data) < 6:
raise UncompressError("Too short data length")
chunk_type = struct.unpack("<L", data[:4])[0]
size = (chunk_type >> 8)
chunk_type &= 0xff
if (chunk_type != _IDENTIFIER_CHUNK or
size != len(_STREAM_IDENTIFIER)):
raise UncompressError("stream missing snappy identifier")
chunk = data[4:4 + size]
if chunk != _STREAM_IDENTIFIER:
raise UncompressError("stream has invalid snappy identifier")
try:
return fin.read(len(_STREAM_HEADER_BLOCK)) == _STREAM_HEADER_BLOCK
except:
return False

def decompress(self, data: bytes):
"""Decompress 'data', returning a string containing the uncompressed
Expand Down Expand Up @@ -233,14 +225,21 @@ def __init__(self):
self.remains = b""

@staticmethod
def check_format(data):
"""Checks that there are enough bytes for a hadoop header
We cannot actually determine if the data is really hadoop-snappy
def check_format(fin):
"""Does this look like a hadoop snappy stream?
"""
if len(data) < 8:
raise UncompressError("Too short data length")
chunk_length = int.from_bytes(data[4:8], "big")
try:
from snappy.snappy_formats import check_unframed_format
size = fin.seek(0, 2)
fin.seek(0)
assert size >= 8

chunk_length = int.from_bytes(fin.read(4), "big")
assert chunk_length < size
fin.read(4)
return check_unframed_format(fin)
except:
return False

def decompress(self, data: bytes):
"""Decompress 'data', returning a string containing the uncompressed
Expand Down Expand Up @@ -319,16 +318,43 @@ def stream_decompress(src,
decompressor.flush() # makes sure the stream ended well


def check_format(fin=None, chunk=None,
                 blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
                 decompressor_cls=StreamDecompressor):
    """Probe whether the input matches *decompressor_cls*'s format.

    :param fin: binary file-like object; read only when *chunk* is None.
    :param chunk: optional pre-read header bytes to test instead of
        reading from *fin*.
    :param blocksize: number of bytes to read from *fin* for the probe.
    :param decompressor_cls: class whose ``check_format`` raises
        UncompressError on a mismatch.
    :return: tuple ``(ok, chunk)`` — whether the format matched, and the
        bytes consumed, so the caller can replay them to a decompressor.
    :raises UncompressError: if the input stream yields no data at all.
    """
    ok = True
    if chunk is None:
        chunk = fin.read(blocksize)
        if not chunk:
            raise UncompressError("Empty input stream")
    try:
        decompressor_cls.check_format(chunk)
    except UncompressError:
        # Mismatch is signalled by raising; report it as a flag instead.
        # (Dropped the unused ``as err`` binding.)
        ok = False
    return ok, chunk
def hadoop_stream_decompress(
    src,
    dst,
    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
):
    """Read *src* in blocks, decoding hadoop-snappy data into *dst*.

    :param src: readable binary file-like object with compressed data.
    :param dst: writable binary file-like object for the plain output.
    :param blocksize: how many bytes to pull from *src* per iteration.
    """
    decompressor = HadoopStreamDecompressor()
    data = src.read(blocksize)
    while data:
        decoded = decompressor.decompress(data)
        # The decompressor may buffer a partial record and return b"".
        if decoded:
            dst.write(decoded)
        data = src.read(blocksize)
    dst.flush()


def hadoop_stream_compress(
    src,
    dst,
    blocksize=_STREAM_TO_STREAM_BLOCK_SIZE,
):
    """Read *src* in blocks, writing hadoop-snappy data into *dst*.

    :param src: readable binary file-like object with plain data.
    :param dst: writable binary file-like object for compressed output.
    :param blocksize: how many bytes to pull from *src* per iteration.
    """
    compressor = HadoopStreamCompressor()
    data = src.read(blocksize)
    while data:
        encoded = compressor.compress(data)
        # The compressor may buffer input and return b"" for this block.
        if encoded:
            dst.write(encoded)
        data = src.read(blocksize)
    dst.flush()


def raw_stream_decompress(src, dst):
    """Decompress the whole of *src* as a single raw snappy block into *dst*.

    Raw snappy is not streamable, so the entire input is read at once.
    """
    dst.write(decompress(src.read()))


def raw_stream_compress(src, dst):
    """Compress the whole of *src* as a single raw snappy block into *dst*.

    Raw snappy is not streamable, so the entire input is read at once.
    """
    dst.write(compress(src.read()))
96 changes: 69 additions & 27 deletions src/snappy/snappy_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,65 +8,107 @@
from __future__ import absolute_import

from .snappy import (
stream_compress, stream_decompress, check_format, UncompressError)

HadoopStreamDecompressor, StreamDecompressor,
hadoop_stream_compress, hadoop_stream_decompress, raw_stream_compress,
raw_stream_decompress, stream_compress, stream_decompress,
UncompressError
)

FRAMING_FORMAT = 'framing'

# Means format auto detection.
# For compression will be used framing format.
# In case of decompression will try to detect a format from the input stream
# header.
FORMAT_AUTO = 'auto'
DEFAULT_FORMAT = "auto"

DEFAULT_FORMAT = FORMAT_AUTO

ALL_SUPPORTED_FORMATS = [FRAMING_FORMAT, FORMAT_AUTO]
ALL_SUPPORTED_FORMATS = ["framing", "auto"]

_COMPRESS_METHODS = {
FRAMING_FORMAT: stream_compress,
"framing": stream_compress,
"hadoop": hadoop_stream_compress,
"raw": raw_stream_compress
}

_DECOMPRESS_METHODS = {
FRAMING_FORMAT: stream_decompress,
"framing": stream_decompress,
"hadoop": hadoop_stream_decompress,
"raw": raw_stream_decompress
}

# We will use framing format as the default to compression.
# And for decompression, if it's not defined explicitly, we will try to
# guess the format from the file header.
_DEFAULT_COMPRESS_FORMAT = FRAMING_FORMAT
_DEFAULT_COMPRESS_FORMAT = "framing"


def uvarint(fin):
    """Read a varint-encoded uint64 number from a binary stream.

    Little-endian base-128 encoding: each byte contributes its low
    7 bits, and the high bit flags that another byte follows.
    (Fixes the "nbumber" docstring typo.)

    :param fin: binary file-like object read one byte at a time.
    :return: the decoded non-negative integer.
    :raises IndexError: if the stream ends mid-varint — ``read(1)``
        returns b"" and indexing it fails; callers rely on their broad
        exception handlers to treat that as "not this format".
    """
    result = 0
    shift = 0
    while True:
        byte = fin.read(1)[0]
        result |= (byte & 0x7F) << shift
        if (byte & 0x80) == 0:
            break
        shift += 7
    return result


def check_unframed_format(fin, reset=False):
    """Can this be read using the raw codec?

    This function will return True for all snappy raw streams, but
    True does not mean that we can necessarily decode the stream: only
    the leading length varint and the first block tag are sniffed.

    :param fin: readable, seekable binary file-like object.
    :param reset: if True, rewind *fin* to the start before probing.
    :return: True when the header is plausibly raw snappy, else False.
    """
    if reset:
        fin.seek(0)
    try:
        size = uvarint(fin)
        # Explicit checks instead of ``assert`` so the probe still works
        # under ``python -O`` (asserts stripped).
        if size >= 2**32 - 1:
            # Raw snappy stores the uncompressed size as a 32-bit value.
            return False
        next_byte = fin.read(1)[0]
        end = fin.seek(0, 2)  # total stream length
        if size >= end:
            return False
        if next_byte & 0b11 != 0:
            # Must start with a literal block (tag low bits 00).
            return False
        return True
    except Exception:
        # Truncated or unseekable input means "not raw snappy".
        # (Narrowed from a bare ``except:``.)
        return False


# Maps a detected format name (as produced by guess_format_by_header)
# to its format-specific stream decompression function.
# NOTE(review): the framing format is keyed "framed" here, while the
# explicit-format tables above use "framing" — the two namespaces are
# distinct on purpose, but easy to confuse.
_DECOMPRESS_FORMAT_FUNCS = {
    "framed": stream_decompress,
    "hadoop": hadoop_stream_decompress,
    "raw": raw_stream_decompress
}


def guess_format_by_header(fin):
    """Tries to guess a compression format for the given input file by its
    header.

    Probes most-specific first: framing has a magic header block, hadoop
    has a recognizable length-framed layout, and raw is the loosest
    heuristic.

    :param fin: readable, seekable binary file-like object; the probes
        read from it and may leave its position changed, so callers
        should seek before decompressing.
    :return: format name (str), stream decompress function (callable)
    :raises UncompressError: if no known format matches the header.
    """
    if StreamDecompressor.check_format(fin):
        form = "framed"
    elif HadoopStreamDecompressor.check_format(fin):
        form = "hadoop"
    elif check_unframed_format(fin, reset=True):
        form = "raw"
    else:
        raise UncompressError("Can't detect format")
    return form, _DECOMPRESS_FORMAT_FUNCS[form]


def get_decompress_function(specified_format, fin):
    """Return the stream-decompress callable for *specified_format*.

    For "auto" the format is sniffed from *fin*'s header; otherwise the
    explicitly named format is looked up directly.

    :param specified_format: format name, or "auto" to detect it.
    :param fin: seekable binary input, consulted only in "auto" mode.
    :raises UncompressError: in "auto" mode when no format matches.
    :raises KeyError: for an unknown explicit format name.
    """
    if specified_format == "auto":
        # Only the callable is needed; the detected name is discarded.
        # (Avoids binding the unused name ``format``, which also
        # shadowed the builtin.)
        _, decompress_func = guess_format_by_header(fin)
        return decompress_func
    return _DECOMPRESS_METHODS[specified_format]


def get_compress_function(specified_format):
    """Return the stream-compress callable for *specified_format*.

    "auto" falls back to the default compression format (framing),
    since there is nothing to sniff when producing output.

    :raises KeyError: for an unknown explicit format name.
    """
    if specified_format == "auto":
        return _COMPRESS_METHODS[_DEFAULT_COMPRESS_FORMAT]
    return _COMPRESS_METHODS[specified_format]
45 changes: 34 additions & 11 deletions test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from unittest import TestCase

from snappy import snappy_formats as formats
from snappy.snappy import _CHUNK_MAX, UncompressError


class TestFormatBase(TestCase):
compress_format = formats.FORMAT_AUTO
decompress_format = formats.FORMAT_AUTO
compress_format = "auto"
decompress_format = "auto"
success = True

def runTest(self):
Expand All @@ -18,34 +17,58 @@ def runTest(self):
compressed_stream = io.BytesIO()
compress_func(instream, compressed_stream)
compressed_stream.seek(0)
decompress_func, read_chunk = formats.get_decompress_function(
decompress_func = formats.get_decompress_function(
self.decompress_format, compressed_stream
)
compressed_stream.seek(0)
decompressed_stream = io.BytesIO()
decompress_func(
compressed_stream,
decompressed_stream,
start_chunk=read_chunk
)
decompressed_stream.seek(0)
self.assertEqual(data, decompressed_stream.read())


class TestFormatFramingFraming(TestFormatBase):
    # Explicit framing format on both sides: round trip must succeed.
    compress_format = "framing"
    decompress_format = "framing"
    success = True


class TestFormatFramingAuto(TestFormatBase):
    # Framing output must be auto-detected on decompression.
    compress_format = "framing"
    decompress_format = "auto"
    success = True


class TestFormatAutoFraming(TestFormatBase):
    # "auto" compression defaults to framing, so explicit framing
    # decompression must succeed.
    compress_format = "auto"
    decompress_format = "framing"
    success = True


class TestFormatHadoop(TestFormatBase):
    # Explicit hadoop-snappy on both sides: round trip must succeed.
    compress_format = "hadoop"
    decompress_format = "hadoop"
    success = True


class TestFormatRaw(TestFormatBase):
    # Explicit raw (unframed) snappy on both sides: round trip must succeed.
    compress_format = "raw"
    decompress_format = "raw"
    success = True


class TestFormatHadoopAuto(TestFormatBase):
    # Hadoop-snappy output must be auto-detected on decompression.
    compress_format = "hadoop"
    decompress_format = "auto"
    success = True


class TestFormatRawAuto(TestFormatBase):
    # Raw snappy output must be auto-detected on decompression.
    compress_format = "raw"
    decompress_format = "auto"
    success = True


Expand Down

0 comments on commit ac4ddd3

Please sign in to comment.