From dda6990b53c2b7270071b2757a2bfe921cb87b57 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Wed, 20 Dec 2023 10:40:24 -0800 Subject: [PATCH] poc: Reduce test runtime by making TurboSHAKE128 stateful The reference implementation of TurboSHAKE128 provides a one-shot API. To use this efficiently, we restrict the number of bytes required for tests so that we can pre-compute the entire output buffer. This is controlled by MAX_XOF_OUT_STREAM_BYTES. This restriction is confusing and has side-effects for specs that use XofTurboShake128. However using the one-shot API without this optimization makes the unit tests prohibitively slow, about a minute on my machine. Implement a stateful API for TurboSHAKE128 and use it in XofTurboShake128. This reduces the runtime to 20 seconds on my machine. Accordingly, restore the unit tests for Poplar1 to what they were before we made this optimization. --- poc/Makefile | 1 + poc/turboshake.py | 189 ++++++++++++++++++++++++++++++++++++++++++++ poc/vdaf_poplar1.py | 20 ++--- poc/xof.py | 54 +++++++------ 4 files changed, 228 insertions(+), 36 deletions(-) create mode 100644 poc/turboshake.py diff --git a/poc/Makefile b/poc/Makefile index c5cec7df..3b53df47 100644 --- a/poc/Makefile +++ b/poc/Makefile @@ -2,6 +2,7 @@ test: sage -python common.py sage -python field.py sage -python xof.py + sage -python turboshake.py sage -python flp.py sage -python flp_generic.py sage -python idpf.py diff --git a/poc/turboshake.py b/poc/turboshake.py new file mode 100644 index 00000000..26cafd6b --- /dev/null +++ b/poc/turboshake.py @@ -0,0 +1,189 @@ +# A stateful implementation of TurboSHAKE adapted from the reference implementation +# +# We use TurboSHAKE in two steps: +# +# 1. Message fragments are absorbed into the hash state +# 2. Output fragments are squeezed out of the hash state +# +# The reference implementation of TurboSHAKE only provides a "one-shot" API, +# where the message and the length of the output are determined in advance. +# +# The stateful API is not needed if you know the desired output length in +# advance. Even if you don't know the desired output length, you can always do +# something like this: +# +# 1. Concatenate the message fragments into message `M` +# 2. Keep track of the output length `totalOutputBytesLen` squeezed so far and +# output `TurboSHAKE(c, M, D, totalOutputBytesLen+nextOutputBytesLen)`. +# +# However if the output length is large, then this is prohibitively slow, even +# for reference code. In particular, this makes the unit tests for Prio3 and +# Poplar1 take well over 30 seconds to run. Thus the purpose of implementing a +# stateful API is to make our unit tests run in a reasonable amount of time. + +import os +import sys + +kangarootwelve_path = \ + "%s/draft-irtf-cfrg-kangarootwelve/py" % os.path.dirname(__file__) # nopep8 +assert os.path.isdir(kangarootwelve_path) # nopep8 +sys.path.append(kangarootwelve_path) # nopep8 + +from TurboSHAKE import KeccakP1600, TurboSHAKE128 + + +class TurboSHAKEAbosrb: + '''TurboSHAKE in the absorb state.''' + + def __init__(self, c, D): + ''' + Initialize the absorb state with capacity `c` (number of bits) and + domain separation byte `D`. + ''' + self.D = D + self.rate_in_bytes = (1600-c)//8 + self.state = bytearray([0 for i in range(200)]) + self.state_offset = 0 + + def update(self, M): + ''' + Update the absorb state with message fragment `M`. + ''' + input_offset = 0 + while input_offset < len(M): + length = len(M)-input_offset + block_size = min(length, self.rate_in_bytes-self.state_offset) + for i in range(block_size): + self.state[i+self.state_offset] ^= M[i+input_offset] + input_offset += block_size + self.state_offset += block_size + if self.state_offset == self.rate_in_bytes: + self.state = KeccakP1600(self.state, 12) + self.state_offset = 0 + + def squeeze(self): + ''' + Consume the absorb state and return the TurboSHAKE squeeze state. + ''' + state = self.state[:] # deep copy + state[self.state_offset] ^= self.D + if (((self.D & 0x80) != 0) and + (self.state_offset == (self.rate_in_bytes-1))): + state = KeccakP1600(state, 12) + state[self.rate_in_bytes-1] = state[self.rate_in_bytes-1] ^ 0x80 + state = KeccakP1600(state, 12) + + squeeze = TurboSHAKESqueeze() + squeeze.rate_in_bytes = self.rate_in_bytes + squeeze.state = state + squeeze.state_offset = 0 + return squeeze + + +class TurboSHAKESqueeze: + '''TurboSHAKE in the squeeze state.''' + + def next(self, length): + ''' + Return the next `length` bytes of output and update the squeeze state. + ''' + output = bytearray() + while length > 0: + block_size = min(length, self.rate_in_bytes-self.state_offset) + length -= block_size + output += \ + self.state[self.state_offset:self.state_offset+block_size] + self.state_offset += block_size + if self.state_offset == self.rate_in_bytes: + self.state = KeccakP1600(self.state, 12) + self.state_offset = 0 + return output + + +def NewTurboSHAKE128(D): + ''' + Return the absorb state for TurboSHAKE128 with domain separation byte `D`. + ''' + return TurboSHAKEAbosrb(256, D) + + +def testAPI(stateful, oneshot): + '''Test that the outputs of the stateful and oneshot APIs match.''' + + test_cases = [ + { + 'fragments': [], + 'lengths': [], + }, + { + 'fragments': [], + 'lengths': [ + 1000, + ], + }, + { + 'fragments': [ + b'\xff' * 500, + ], + 'lengths': [ + 12, + ], + }, + { + 'fragments': [ + b'hello', + b', ', + b'', + b'world', + ], + 'lengths': [ + 1, + 17, + 256, + 128, + 0, + 7, + 14, + ], + }, + { + 'fragments': [ + b'\xff' * 1024, + b'\x17' * 23, + b'', + b'\xf1' * 512, + ], + 'lengths': [ + 1000, + 0, + 0, + 14, + ], + + } + ] + + D = 99 + for (i, test_case) in enumerate(test_cases): + absorb = stateful(D) + message = bytearray() + for fragment in test_case['fragments']: + absorb.update(fragment) + message += fragment + squeeze = absorb.squeeze() + output = b'' + output_len = 0 + for length in test_case['lengths']: + output += squeeze.next(length) + output_len += length + expected_output = oneshot(message, D, output_len) + if output != expected_output: + raise Exception('test case {} failed: got {}; want {}'.format( + i, + output.hex(), + expected_output.hex(), + )) + + +if __name__ == '__main__': + testAPI(NewTurboSHAKE128, TurboSHAKE128) diff --git a/poc/vdaf_poplar1.py b/poc/vdaf_poplar1.py index d8b3535c..39eccfd4 100644 --- a/poc/vdaf_poplar1.py +++ b/poc/vdaf_poplar1.py @@ -414,28 +414,28 @@ def encode_idpf_field_vec(vec): [2], ) test_vdaf( - Poplar1.with_bits(64), + Poplar1.with_bits(128), ( - 63, - (from_be_bytes(b'01234567'),), + 127, + (from_be_bytes(b'0123456789abcdef'),), ), [ - from_be_bytes(b'01234567'), + from_be_bytes(b'0123456789abcdef'), ], [1], ) test_vdaf( - Poplar1.with_bits(64), + Poplar1.with_bits(256), ( - 31, + 63, ( - from_be_bytes(b'0000'), - from_be_bytes(b'0123'), + from_be_bytes(b'00000000'), + from_be_bytes(b'01234567'), ), ), [ - from_be_bytes(b'01234567'), - from_be_bytes(b'01234000'), + from_be_bytes(b'0123456789abcdef0123456789abcdef'), + from_be_bytes(b'01234567890000000000000000000000'), ], [0, 2], ) diff --git a/poc/xof.py b/poc/xof.py index 6a271312..c5159090 100644 --- a/poc/xof.py +++ b/poc/xof.py @@ -2,25 +2,12 @@ from __future__ import annotations -import os -import sys - from Cryptodome.Cipher import AES -kangarootwelve_path = \ - "%s/draft-irtf-cfrg-kangarootwelve/py" % os.path.dirname(__file__) # nopep8 -assert os.path.isdir(kangarootwelve_path) # nopep8 -sys.path.append(kangarootwelve_path) # nopep8 -from TurboSHAKE import TurboSHAKE128 - from common import (TEST_VECTOR, VERSION, Bytes, Unsigned, concat, format_dst, from_le_bytes, gen_rand, next_power_of_2, print_wrapped_line, to_le_bytes, xor) - -# Maximum XOF output length that will be requested by any test in this package. -# Each time `XofTurboShake128` is constructed we call `TurboSHAKE128()` once -# and fill a buffer with the output stream. -MAX_XOF_OUT_STREAM_BYTES = 2000 +from turboshake import NewTurboSHAKE128, TurboSHAKE128 class Xof: @@ -84,18 +71,33 @@ class XofTurboShake128(Xof): test_vec_name = 'XofTurboShake128' def __init__(self, seed, dst, binder): + ''' + self.l = 0 + self.m = to_le_bytes(len(dst), 1) + dst + seed + binder + ''' self.length_consumed = 0 - self.stream = TurboSHAKE128( - to_le_bytes(len(dst), 1) + dst + seed + binder, - 1, - MAX_XOF_OUT_STREAM_BYTES, - ) + state = NewTurboSHAKE128(1) + state.update(to_le_bytes(len(dst), 1)) + state.update(dst) + state.update(seed) + state.update(binder) + self.state = state.squeeze() def next(self, length): - assert self.length_consumed + length < MAX_XOF_OUT_STREAM_BYTES - out = self.stream[self.length_consumed:self.length_consumed+length] - self.length_consumed += length - return out + ''' + self.l += length + + # Function `TurboSHAKE128(M, D, L)` is as defined in + # Section 2.2 of [TurboSHAKE]. + # + # Implementation note: Rather than re-generate the output + # stream each time `next()` is invoked, most implementations + # of TurboSHAKE128 will expose an "absorb-then-squeeze" API that + # allows stateful handling of the stream. + stream = TurboSHAKE128(self.m, 1, self.l) + return stream[-length:] + ''' + return self.state.next(length) class XofFixedKeyAes128(Xof): @@ -113,9 +115,9 @@ class XofFixedKeyAes128(Xof): def __init__(self, seed, dst, binder): self.length_consumed = 0 - # Use SHA-3 to derive a key from the binder string and domain - # separation tag. Note that the AES key does not need to be - # kept secret from any party. However, when used with + # Use TurboSHAKE128 to derive a key from the binder string and + # domain separation tag. Note that the AES key does not need + # to be kept secret from any party. However, when used with # IdpfPoplar, we require the binder to be a random nonce. # # Implementation note: This step can be cached across XOF