Skip to content

Commit

Permalink
Fix & refactor Huffman repeat tables for dictionaries
Browse files Browse the repository at this point in the history
The Huffman repeat mode checker assumed that the CTable was zeroed in the region `[maxSymbolValue + 1, 256)`.
This assumption didn't hold for tables built in the dictionaries, because it didn't go through the same codepath.

Since this code was originally written, we added a header to the CTable that specifies the `tableLog`.
Add `maxSymbolValue` to that header, and check that the table's `maxSymbolValue` is at least the block's `maxSymbolValue`.

This solution is cleaner because we write this header for every CTable we build, so it can't be missed in any code path.

Credit to OSS-Fuzz
  • Loading branch information
Nick Terrell committed Aug 25, 2023
1 parent 0fcb28c commit 671c157
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 18 deletions.
15 changes: 14 additions & 1 deletion lib/common/huf.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,22 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void

/** HUF_getNbBitsFromCTable() :
* Read nbBits from CTable symbolTable, for symbol `symbolValue` presumed <= HUF_SYMBOLVALUE_MAX
* Note 1 : is not inlined, as HUF_CElt definition is private */
* Note 1 : If symbolValue > HUF_readCTableHeader(symbolTable).maxSymbolValue, returns 0
* Note 2 : is not inlined, as HUF_CElt definition is private
*/
U32 HUF_getNbBitsFromCTable(const HUF_CElt* symbolTable, U32 symbolValue);

typedef struct {
BYTE tableLog;
BYTE maxSymbolValue;
BYTE unused[sizeof(size_t) - 2];
} HUF_CTableHeader;

/** HUF_readCTableHeader() :
* @returns The header from the CTable specifying the tableLog and the maxSymbolValue.
*/
HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable);

/*
* HUF_decompress() does the following:
* 1. select the decompression algorithm (X1, X2) based on pre-computed heuristics
Expand Down
61 changes: 44 additions & 17 deletions lib/compress/huf_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,25 @@ static void HUF_setValue(HUF_CElt* elt, size_t value)
}
}

HUF_CTableHeader HUF_readCTableHeader(HUF_CElt const* ctable)
{
HUF_CTableHeader header;
ZSTD_memcpy(&header, ctable, sizeof(header));
return header;
}

static void HUF_writeCTableHeader(HUF_CElt* ctable, U32 tableLog, U32 maxSymbolValue)
{
HUF_CTableHeader header;
HUF_STATIC_ASSERT(sizeof(ctable[0]) == sizeof(header));
ZSTD_memset(&header, 0, sizeof(header));
assert(tableLog < 256);
header.tableLog = (BYTE)tableLog;
assert(maxSymbolValue < 256);
header.maxSymbolValue = (BYTE)maxSymbolValue;
ZSTD_memcpy(ctable, &header, sizeof(header));
}

typedef struct {
HUF_CompressWeightsWksp wksp;
BYTE bitsToWeight[HUF_TABLELOG_MAX + 1]; /* precomputed conversion table */
Expand All @@ -237,6 +256,9 @@ size_t HUF_writeCTable_wksp(void* dst, size_t maxDstSize,

HUF_STATIC_ASSERT(HUF_CTABLE_WORKSPACE_SIZE >= sizeof(HUF_WriteCTableWksp));

assert(HUF_readCTableHeader(CTable).maxSymbolValue == maxSymbolValue);
assert(HUF_readCTableHeader(CTable).tableLog == huffLog);

/* check conditions */
if (workspaceSize < sizeof(HUF_WriteCTableWksp)) return ERROR(GENERIC);
if (maxSymbolValue > HUF_SYMBOLVALUE_MAX) return ERROR(maxSymbolValue_tooLarge);
Expand Down Expand Up @@ -283,7 +305,9 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
if (tableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge);
if (nbSymbols > *maxSymbolValuePtr+1) return ERROR(maxSymbolValue_tooSmall);

CTable[0] = tableLog;
*maxSymbolValuePtr = nbSymbols - 1;

HUF_writeCTableHeader(CTable, tableLog, *maxSymbolValuePtr);

/* Prepare base value per rank */
{ U32 n, nextRankStart = 0;
Expand Down Expand Up @@ -315,14 +339,15 @@ size_t HUF_readCTable (HUF_CElt* CTable, unsigned* maxSymbolValuePtr, const void
{ U32 n; for (n=0; n<nbSymbols; n++) HUF_setValue(ct + n, valPerRank[HUF_getNbBits(ct[n])]++); }
}

*maxSymbolValuePtr = nbSymbols - 1;
return readSize;
}

U32 HUF_getNbBitsFromCTable(HUF_CElt const* CTable, U32 symbolValue)
{
const HUF_CElt* const ct = CTable + 1;
assert(symbolValue <= HUF_SYMBOLVALUE_MAX);
if (symbolValue > HUF_readCTableHeader(CTable).maxSymbolValue)
return 0;
return (U32)HUF_getNbBits(ct[symbolValue]);
}

Expand Down Expand Up @@ -723,7 +748,8 @@ static void HUF_buildCTableFromTree(HUF_CElt* CTable, nodeElt const* huffNode, i
HUF_setNbBits(ct + huffNode[n].byte, huffNode[n].nbBits); /* push nbBits per symbol, symbol order */
for (n=0; n<alphabetSize; n++)
HUF_setValue(ct + n, valPerRank[HUF_getNbBits(ct[n])]++); /* assign value within rank, symbol order */
CTable[0] = maxNbBits;

HUF_writeCTableHeader(CTable, maxNbBits, maxSymbolValue);
}

size_t
Expand Down Expand Up @@ -776,13 +802,20 @@ size_t HUF_estimateCompressedSize(const HUF_CElt* CTable, const unsigned* count,
}

int HUF_validateCTable(const HUF_CElt* CTable, const unsigned* count, unsigned maxSymbolValue) {
HUF_CElt const* ct = CTable + 1;
int bad = 0;
int s;
for (s = 0; s <= (int)maxSymbolValue; ++s) {
bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0);
}
return !bad;
HUF_CTableHeader header = HUF_readCTableHeader(CTable);
HUF_CElt const* ct = CTable + 1;
int bad = 0;
int s;

assert(header.tableLog <= HUF_TABLELOG_ABSOLUTEMAX);

if (header.maxSymbolValue < maxSymbolValue)
return 0;

for (s = 0; s <= (int)maxSymbolValue; ++s) {
bad |= (count[s] != 0) & (HUF_getNbBits(ct[s]) == 0);
}
return !bad;
}

size_t HUF_compressBound(size_t size) { return HUF_COMPRESSBOUND(size); }
Expand Down Expand Up @@ -1024,7 +1057,7 @@ HUF_compress1X_usingCTable_internal_body(void* dst, size_t dstSize,
const void* src, size_t srcSize,
const HUF_CElt* CTable)
{
U32 const tableLog = (U32)CTable[0];
U32 const tableLog = HUF_readCTableHeader(CTable).tableLog;
HUF_CElt const* ct = CTable + 1;
const BYTE* ip = (const BYTE*) src;
BYTE* const ostart = (BYTE*)dst;
Expand Down Expand Up @@ -1372,12 +1405,6 @@ HUF_compress_internal (void* dst, size_t dstSize,
huffLog = (U32)maxBits;
DEBUGLOG(6, "bit distribution completed (%zu symbols)", showCTableBits(table->CTable + 1, maxSymbolValue+1));
}
/* Zero unused symbols in CTable, so we can check it for validity */
{
size_t const ctableSize = HUF_CTABLE_SIZE_ST(maxSymbolValue);
size_t const unusedSize = sizeof(table->CTable) - ctableSize * sizeof(HUF_CElt);
ZSTD_memset(table->CTable + ctableSize, 0, unusedSize);
}

/* Write table description header */
{ CHECK_V_F(hSize, HUF_writeCTable_wksp(op, dstSize, table->CTable, maxSymbolValue, huffLog,
Expand Down

0 comments on commit 671c157

Please sign in to comment.