diff --git a/CMakeLists.txt b/CMakeLists.txt index ab600a0..72b697c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -86,7 +86,11 @@ if(STREAMVBYTE_SANITIZE) -fno-omit-frame-pointer -fno-sanitize-recover=all ) - add_compile_definitions(ASAN_OPTIONS=detect_leaks=1) + add_link_options( + -fsanitize=address + -fno-omit-frame-pointer + -fno-sanitize-recover=all + ) endif() if(MSVC) diff --git a/Makefile b/Makefile index 70b6383..2f312c9 100644 --- a/Makefile +++ b/Makefile @@ -83,10 +83,10 @@ writeseq: ./tests/writeseq.c $(HEADERS) $(OBJECTS) $(CC) $(CFLAGS) -o writeseq ./tests/writeseq.c -Iinclude $(OBJECTS) unit: ./tests/unit.c $(HEADERS) $(OBJECTS) - $(CC) $(CFLAGS) -o unit ./tests/unit.c -Iinclude $(OBJECTS) + $(CC) $(CFLAGS) -o unit ./tests/unit.c -Iinclude -Isrc $(OBJECTS) dynunit: ./tests/unit.c $(HEADERS) $(LIBNAME) $(LNLIBNAME) - $(CC) $(CFLAGS) -o dynunit ./tests/unit.c -Iinclude -L. -lstreamvbyte + $(CC) $(CFLAGS) -o dynunit ./tests/unit.c -Iinclude -Isrc -L. -lstreamvbyte clean: rm -f unit *.o $(LIBNAME) $(LNLIBNAME) example shuffle_tables perf writeseq dynunit diff --git a/README.md b/README.md index 8ae0c2e..b902096 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,19 @@ information along with the compressed stream. During decoding, the library may read up to `STREAMVBYTE_PADDING` extra bytes from the input buffer (these bytes are read but never used). +To verify that the expected size of a stream is correct you may validate it before +decoding: +```C +// compressedbuffer, compsize, recovdata, N are as above +if (streamvbyte_validate_stream(compressedbuffer, compsize, N)) { + // the stream is safe to decode + streamvbyte_decode(compressedbuffer, recovdata, N); +} else { + // there's a mismatch between the expected size of the data (N) and the contents of + // the stream, so performing a decode is unsafe since the behaviour is undefined +} +``` + diff --git a/include/streamvbyte.h b/include/streamvbyte.h index e88ab08..35d17aa 100644 --- a/include/streamvbyte.h +++ b/include/streamvbyte.h @@ -1,6 +1,7 @@ #ifndef INCLUDE_STREAMVBYTE_H_ #define INCLUDE_STREAMVBYTE_H_ +#include #include #include @@ -10,7 +11,7 @@ extern "C" { #define STREAMVBYTE_PADDING 16 -// Encode an array of a given length read from in to bout in varint format. +// Encode an array of a given length read from in to out in varint format. // Returns the number of bytes written. // The number of values being stored (length) is not encoded in the compressed stream, // the caller is responsible for keeping a record of this length. @@ -66,6 +67,17 @@ size_t streamvbyte_decode(const uint8_t* in, uint32_t* out, uint32_t length); // streamvbyte_encode_0124. size_t streamvbyte_decode_0124(const uint8_t* in, uint32_t* out, uint32_t length); +// Validate an encoded stream. +// This can be used to validate that data received from an untrusted source (disk, network, +// etc...) has a valid length stored alongside it. +// "inLength" is the size of the encoded data "in", and "outLength" is the expected number +// of integers that were compressed. +bool streamvbyte_validate_stream(const uint8_t* in, size_t inLength, uint32_t outLength); + +// Same as streamvbyte_validate_stream but is meant to be used for streams encoded with +// streamvbyte_encode_0124. +bool streamvbyte_validate_stream_0124(const uint8_t* in, size_t inLength, uint32_t outLength); + #ifdef __cplusplus } #endif diff --git a/include/streamvbytedelta.h b/include/streamvbytedelta.h index 9ee8040..55ef010 100644 --- a/include/streamvbytedelta.h +++ b/include/streamvbytedelta.h @@ -8,7 +8,7 @@ extern "C" { #endif -// Encode an array of a given length read from in to bout in StreamVByte format. +// Encode an array of a given length read from in to out in StreamVByte format. // Returns the number of bytes written. // The number of values being stored (length) is not encoded in the compressed stream, // the caller is responsible for keeping a record of this length. The pointer "in" should diff --git a/src/streamvbyte_0124_decode.c b/src/streamvbyte_0124_decode.c index 524cfe7..a6bb6f3 100644 --- a/src/streamvbyte_0124_decode.c +++ b/src/streamvbyte_0124_decode.c @@ -181,5 +181,49 @@ size_t streamvbyte_decode_0124(const uint8_t *in, uint32_t *out, uint32_t count) #endif return (size_t)(svb_decode_scalar(out, keyPtr, dataPtr, count) - in); +} + +bool streamvbyte_validate_stream_0124(const uint8_t *in, size_t inCount, + uint32_t outCount) { + if (inCount == 0 || outCount == 0) + return inCount == outCount; + + // 2-bits per key (rounded up) + // Note that we don't add to outCount in case it overflows + uint32_t keyLen = outCount / 4; + if (outCount & 3) + keyLen++; + + // Check that there's enough space for the keys + if (keyLen > inCount) + return false; + + // Accumulate the key sizes in a wider type to avoid overflow + const uint8_t *keyPtr = in; + uint64_t encodedSize = 0; + + // Give the compiler a hint that it can avoid branches in the inner loop + for (uint32_t c = 0; c < outCount / 4; c++) { + uint32_t key = *keyPtr++; + for (uint8_t shift = 0; shift < 8; shift += 2) { + const uint8_t code = (key >> shift) & 0x3; + encodedSize += (1 << code) >> 1; + } + } + outCount &= 3; + + // Process the remainder one at a time + uint8_t shift = 0; + uint32_t key = *keyPtr++; + for (uint32_t c = 0; c < outCount; c++) { + if (shift == 8) { + shift = 0; + key = *keyPtr++; + } + const uint8_t code = (key >> shift) & 0x3; + encodedSize += (1 << code) >> 1; + shift += 2; + } + return encodedSize == inCount - keyLen; } diff --git a/src/streamvbyte_arm_decode.c b/src/streamvbyte_arm_decode.c index 02caaba..7480255 100644 --- a/src/streamvbyte_arm_decode.c +++ b/src/streamvbyte_arm_decode.c @@ -51,4 +51,32 @@ static const uint8_t *svb_decode_vector(uint32_t *out, const uint8_t *keyPtr, co return dataPtr; } + +static uint64_t svb_validate_vector(const uint8_t **keyPtrPtr, + uint32_t *countPtr) { + // Reduce the count by how many we'll process + const uint32_t count = *countPtr & ~7U; + const uint8_t *keyPtr = *keyPtrPtr; + *countPtr &= 7; + *keyPtrPtr += count / 4; + + // Deal with each of the 4 keys in a separate lane + const int32x4_t shifts = {0, -2, -4, -6}; + const uint32x4_t mask = vdupq_n_u32(3); + uint32x4_t acc0 = vdupq_n_u32(0); + uint32x4_t acc1 = vdupq_n_u32(0); + + // Unrolling more than twice doesn't seem to improve performance + for (uint32_t c = 0; c < count; c += 8) { + uint32x4_t shifted0 = vshlq_u32(vdupq_n_u32(*keyPtr++), shifts); + acc0 = vaddq_u32(acc0, vandq_u32(shifted0, mask)); + uint32x4_t shifted1 = vshlq_u32(vdupq_n_u32(*keyPtr++), shifts); + acc1 = vaddq_u32(acc1, vandq_u32(shifted1, mask)); + } + + // Accumulate the sums and add the +1 for each element (count) + uint64x2_t sum0 = vpaddlq_u32(acc0); + uint64x2_t sum1 = vpaddlq_u32(acc1); + return sum0[0] + sum0[1] + sum1[0] + sum1[1] + count; +} #endif diff --git a/src/streamvbyte_decode.c b/src/streamvbyte_decode.c index 4049a2e..b99c7e3 100644 --- a/src/streamvbyte_decode.c +++ b/src/streamvbyte_decode.c @@ -84,5 +84,53 @@ size_t streamvbyte_decode(const uint8_t *in, uint32_t *out, uint32_t count) { #endif return (size_t)(svb_decode_scalar(out, keyPtr, dataPtr, count) - in); +} + +bool streamvbyte_validate_stream(const uint8_t *in, size_t inCount, + uint32_t outCount) { + if (inCount == 0 || outCount == 0) + return inCount == outCount; + + // 2-bits per key (rounded up) + // Note that we don't add to outCount in case it overflows + uint32_t keyLen = outCount / 4; + if (outCount & 3) + keyLen++; + + // Check that there's enough space for the keys + if (keyLen > inCount) + return false; + + // Accumulate the key sizes in a wider type to avoid overflow + const uint8_t *keyPtr = in; + uint64_t encodedSize = 0; + +#if defined(__ARM_NEON__) + encodedSize = svb_validate_vector(&keyPtr, &outCount); +#endif + + // Give the compiler a hint that it can avoid branches in the inner loop + for (uint32_t c = 0; c < outCount / 4; c++) { + uint32_t key = *keyPtr++; + for (uint8_t shift = 0; shift < 8; shift += 2) { + const uint8_t code = (key >> shift) & 0x3; + encodedSize += code + 1; + } + } + outCount &= 3; + + // Process the remainder one at a time + uint8_t shift = 0; + uint32_t key = *keyPtr++; + for (uint32_t c = 0; c < outCount; c++) { + if (shift == 8) { + shift = 0; + key = *keyPtr++; + } + const uint8_t code = (key >> shift) & 0x3; + encodedSize += code + 1; + shift += 2; + } + return encodedSize == inCount - keyLen; } diff --git a/src/streamvbyte_encode.c b/src/streamvbyte_encode.c index ef9667c..f1602bd 100644 --- a/src/streamvbyte_encode.c +++ b/src/streamvbyte_encode.c @@ -108,7 +108,7 @@ size_t streamvbyte_compressedbytes_0124(const uint32_t* in, uint32_t length) { } -// Encode an array of a given length read from in to bout in streamvbyte format. +// Encode an array of a given length read from in to out in streamvbyte format. // Returns the number of bytes written. size_t streamvbyte_encode(const uint32_t *in, uint32_t count, uint8_t *out) { #ifdef STREAMVBYTE_X64 diff --git a/tests/unit.c b/tests/unit.c index e3d23a4..30e5f2e 100644 --- a/tests/unit.c +++ b/tests/unit.c @@ -47,6 +47,8 @@ static int zigzagtests(void) { } } + free(deltadataout); + free(deltadataback); free(databack); free(dataout); free(datain); @@ -96,11 +98,18 @@ static int basictests(void) { for (uint32_t length = 0; length <= N;) { for (uint32_t gap = 1; gap <= 387420489; gap *= 3) { - for (uint32_t k = 0; k < length; ++k) - datain[k] = gap - 1 + ((uint32_t)rand() % 8); // sometimes start with zero + datain[0] = (uint32_t)rand() % 8; // sometimes start with zero + for (uint32_t k = 1; k < length; ++k) + datain[k] = datain[k - 1] + gap - 1 + (uint32_t)rand() % 8; // Default encoding: 1,2,3,4 bytes per value size_t compsize = streamvbyte_encode(datain, length, compressedbuffer); + if (!streamvbyte_validate_stream(compressedbuffer, compsize, length)) { + printf("[streamvbyte_validate_stream] code is buggy length=%d gap=%d: compsize=%d\n", + (int)length, (int)gap, (int)compsize); + return -1; + } + size_t usedbytes = streamvbyte_decode(compressedbuffer, recovdata, length); if (compsize != usedbytes) { printf("[streamvbyte_decode] code is buggy length=%d gap=%d: compsize=%d != " @@ -118,6 +127,12 @@ static int basictests(void) { // Alternative encoding: 0,1,2,4 bytes per value compsize = streamvbyte_encode_0124(datain, length, compressedbuffer); + if (!streamvbyte_validate_stream_0124(compressedbuffer, compsize, length)) { + printf("[streamvbyte_validate_stream_0124] code is buggy length=%d gap=%d: compsize=%d\n", + (int)length, (int)gap, (int)compsize); + return -1; + } + usedbytes = streamvbyte_decode_0124(compressedbuffer, recovdata, length); if (compsize != usedbytes) { printf("[streamvbyte_decode_0124] code is buggy length=%d gap=%d: compsize=%d != " @@ -197,29 +212,37 @@ static int aqrittests(void) { const int length = 4; size_t compsize = streamvbyte_encode((uint32_t *)in, length, compressedbuffer); - size_t usedbytes = streamvbyte_decode(compressedbuffer, (uint32_t *)recovdata, length); + if (!streamvbyte_validate_stream(compressedbuffer, compsize, length)) { + printf("[streamvbyte_validate_stream] code is buggy i=%i\n", i); + return -1; + } + size_t usedbytes = streamvbyte_decode(compressedbuffer, (uint32_t *)recovdata, length); if (compsize != usedbytes) { - printf("[streamvbyte_decode] code is buggy"); + printf("[streamvbyte_decode] code is buggy i=%i\n", i); return -1; } for (size_t k = 0; k < length * sizeof(uint32_t); ++k) { if (recovdata[k] != in[k]) { - printf("[streamvbyte_decode] code is buggy"); + printf("[streamvbyte_decode] code is buggy i=%i\n", i); return -1; } } compsize = streamvbyte_encode_0124((uint32_t *)in, length, compressedbuffer); - usedbytes = streamvbyte_decode_0124(compressedbuffer, (uint32_t *)recovdata, length); + if (!streamvbyte_validate_stream_0124(compressedbuffer, compsize, length)) { + printf("[streamvbyte_validate_stream_0124] code is buggy i=%i\n", i); + return -1; + } + usedbytes = streamvbyte_decode_0124(compressedbuffer, (uint32_t *)recovdata, length); if (compsize != usedbytes) { - printf("[streamvbyte_decode_0124] code is buggy"); + printf("[streamvbyte_decode_0124] code is buggy i=%i\n", i); return -1; } for (size_t k = 0; k < length * sizeof(uint32_t); ++k) { if (recovdata[k] != in[k]) { - printf("[streamvbyte_decode_0124] code is buggy"); + printf("[streamvbyte_decode_0124] code is buggy i=%i\n", i); return -1; } }