From 9bb282b1efd3278c2f02b2800ad71adbd1164f90 Mon Sep 17 00:00:00 2001 From: David Cook Date: Thu, 29 Feb 2024 10:09:32 -0600 Subject: [PATCH] Use unittest for tests (#330) --- .github/workflows/lint-python.yml | 4 +- poc/Makefile | 12 +- poc/daf.py | 78 --------- poc/field.py | 74 -------- poc/flp.py | 72 +------- poc/flp_generic.py | 181 +------------------- poc/idpf.py | 144 ++-------------- poc/idpf_poplar.py | 29 +--- poc/plot_prio3_multiproof_robustness.py | 7 +- poc/tests/__init__.py | 0 poc/tests/idpf.py | 113 +++++++++++++ poc/tests/test_daf.py | 79 +++++++++ poc/tests/test_field.py | 83 +++++++++ poc/tests/test_flp.py | 70 ++++++++ poc/tests/test_flp_generic.py | 195 +++++++++++++++++++++ poc/tests/test_idpf_poplar.py | 49 ++++++ poc/tests/test_vdaf_poplar1.py | 106 ++++++++++++ poc/tests/test_vdaf_prio3.py | 216 ++++++++++++++++++++++++ poc/tests/test_xof.py | 94 +++++++++++ poc/tests/vdaf.py | 27 +++ poc/vdaf.py | 24 --- poc/vdaf_poplar1.py | 101 +---------- poc/vdaf_prio3.py | 200 +--------------------- poc/xof.py | 94 +---------- 24 files changed, 1071 insertions(+), 981 deletions(-) create mode 100644 poc/tests/__init__.py create mode 100644 poc/tests/idpf.py create mode 100644 poc/tests/test_daf.py create mode 100644 poc/tests/test_field.py create mode 100644 poc/tests/test_flp.py create mode 100644 poc/tests/test_flp_generic.py create mode 100644 poc/tests/test_idpf_poplar.py create mode 100644 poc/tests/test_vdaf_poplar1.py create mode 100644 poc/tests/test_vdaf_prio3.py create mode 100644 poc/tests/test_xof.py create mode 100644 poc/tests/vdaf.py diff --git a/.github/workflows/lint-python.yml b/.github/workflows/lint-python.yml index c1dbf610..773aa607 100644 --- a/.github/workflows/lint-python.yml +++ b/.github/workflows/lint-python.yml @@ -25,11 +25,11 @@ jobs: - name: Run pyflakes working-directory: poc - run: pyflakes *.py + run: pyflakes *.py tests/*.py - name: Run autopep8 working-directory: poc - run: autopep8 --diff --exit-code *.py + run: autopep8 --diff --exit-code *.py tests/*.py - name: Run isort working-directory: poc diff --git a/poc/Makefile b/poc/Makefile index c5cec7df..9cc333e5 100644 --- a/poc/Makefile +++ b/poc/Makefile @@ -1,12 +1,2 @@ test: - sage -python common.py - sage -python field.py - sage -python xof.py - sage -python flp.py - sage -python flp_generic.py - sage -python idpf.py - sage -python idpf_poplar.py - sage -python daf.py - sage -python vdaf.py - sage -python vdaf_prio3.py - sage -python vdaf_poplar1.py + sage -python -m unittest diff --git a/poc/daf.py b/poc/daf.py index 1de96325..e6cd4c70 100644 --- a/poc/daf.py +++ b/poc/daf.py @@ -2,11 +2,7 @@ from __future__ import annotations -from functools import reduce - -import field from common import Bool, Unsigned, gen_rand -from xof import XofTurboShake128 class Daf: @@ -138,77 +134,3 @@ def run_daf(Daf, agg_result = Daf.unshard(agg_param, agg_shares, num_measurements) return agg_result - - -## -# TESTS -# - -class TestDaf(Daf): - """A simple DAF used for testing.""" - - # Operational parameters - Field = field.Field128 - - # Associated parameters - ID = 0xFFFFFFFF - SHARES = 2 - NONCE_SIZE = 0 - RAND_SIZE = 16 - - # Associated types - Measurement = Unsigned - PublicShare = None - InputShare = Field - OutShare = Field - AggShare = Field - AggResult = Unsigned - - @classmethod - def shard(cls, measurement, _nonce, rand): - helper_shares = XofTurboShake128.expand_into_vec(cls.Field, - rand, - b'', - b'', - cls.SHARES-1) - leader_share = cls.Field(measurement) - for helper_share in 
helper_shares: - leader_share -= helper_share - input_shares = [leader_share] + helper_shares - return (None, input_shares) - - @classmethod - def prep(cls, _agg_id, _agg_param, _nonce, _public_share, input_share): - # For this simple test DAF, the output share is the same as the input - # share. - return input_share - - @classmethod - def aggregate(cls, _agg_param, out_shares): - return reduce(lambda x, y: x + y, out_shares) - - @classmethod - def unshard(cls, _agg_param, agg_shares, _num_measurements): - return reduce(lambda x, y: x + y, agg_shares).as_unsigned() - - -def test_daf(Daf, - agg_param, - measurements, - expected_agg_result): - # Test that the algorithm identifier is in the correct range. - assert 0 <= Daf.ID and Daf.ID < 2 ** 32 - - # Run the DAF on the set of measurements. - nonces = [gen_rand(Daf.NONCE_SIZE) for _ in range(len(measurements))] - agg_result = run_daf(Daf, - agg_param, - measurements, - nonces) - if agg_result != expected_agg_result: - print('daf test failed ({} on {}): unexpected result: got {}; want {}'.format( - Daf.__class__, measurements, agg_result, expected_agg_result)) - - -if __name__ == '__main__': - test_daf(TestDaf, None, [1, 2, 3, 4], 10) diff --git a/poc/field.py b/poc/field.py index d037bee2..b98864c4 100644 --- a/poc/field.py +++ b/poc/field.py @@ -251,77 +251,3 @@ def poly_interp(Field, xs, ys): R = PolynomialRing(Field.gf, 'x') p = R.lagrange_polynomial([(x.val, y.val) for (x, y) in zip(xs, ys)]) return poly_strip(Field, list(map(lambda x: Field(x), p.coefficients()))) - - -## -# TESTS -# - -def test_field(cls): - # Test constructing a field element from an integer. - assert cls(1337) == cls(cls.gf(1337)) - - # Test generating a zero-vector. - vec = cls.zeros(23) - assert len(vec) == 23 - for x in vec: - assert x == cls(cls.gf.zero()) - - # Test generating a random vector. - vec = cls.rand_vec(23) - assert len(vec) == 23 - - # Test arithmetic. - x = cls(cls.gf.random_element()) - y = cls(cls.gf.random_element()) - assert x + y == cls(x.val + y.val) - assert x - y == cls(x.val - y.val) - assert -x == cls(-x.val) - assert x * y == cls(x.val * y.val) - assert x.inv() == cls(x.val**-1) - - # Test serialization. - want = cls.rand_vec(10) - got = cls.decode_vec(cls.encode_vec(want)) - assert got == want - - # Test encoding integer as bit vector. - vals = [i for i in range(15)] - bits = 4 - for val in vals: - encoded = cls.encode_into_bit_vector(val, bits) - assert cls.decode_from_bit_vector(encoded).as_unsigned() == val - - -def test_fft_field(cls): - test_field(cls) - - # Test generator. - assert cls.gen()**cls.GEN_ORDER == cls(1) - - -if __name__ == '__main__': - test_fft_field(Field64) - test_fft_field(Field96) - test_fft_field(Field128) - test_field(Field255) - - # Test GF(2). - assert Field2(1).as_unsigned() == 1 - assert Field2(0).as_unsigned() == 0 - assert Field2(1) + Field2(1) == Field2(0) - assert Field2(1) * Field2(1) == Field2(1) - assert -Field2(1) == Field2(1) - assert Field2(1).conditional_select(b'hello') == b'hello' - assert Field2(0).conditional_select(b'hello') == bytes([0, 0, 0, 0, 0]) - - # Test polynomial interpolation. 
- cls = Field64 - p = cls.rand_vec(10) - xs = [cls(x) for x in range(10)] - ys = [poly_eval(cls, p, x) for x in xs] - q = poly_interp(cls, xs, ys) - for x in xs: - a = poly_eval(cls, p, x) - b = poly_eval(cls, q, x) - assert a == b diff --git a/poc/flp.py b/poc/flp.py index 0d929292..6a75b662 100644 --- a/poc/flp.py +++ b/poc/flp.py @@ -1,9 +1,7 @@ """Fully linear proof (FLP) systems.""" -from copy import deepcopy - import field -from common import ERR_ENCODE, Bool, Unsigned, Vec, vec_add, vec_sub +from common import Bool, Unsigned, Vec, vec_add, vec_sub from field import Field @@ -124,71 +122,3 @@ def run_flp(flp, meas: Vec[Flp.Field], num_shares: Unsigned): # Verifier decides if the measurement is valid. return flp.decide(verifier) - - -## -# TESTS -# - - -class FlpTest(Flp): - """An insecure FLP used only for testing.""" - # Associated parameters - JOINT_RAND_LEN = 1 - PROVE_RAND_LEN = 2 - QUERY_RAND_LEN = 3 - MEAS_LEN = 2 - OUTPUT_LEN = 1 - PROOF_LEN = 2 - VERIFIER_LEN = 2 - - # Associated types - Measurement = Unsigned - AggResult = Unsigned - - # Operational parameters - meas_range = range(5) - - def encode(self, measurement): - if measurement not in self.meas_range: - raise ERR_ENCODE - return [self.Field(measurement)] * 2 - - def prove(self, meas, prove_rand, joint_rand): - # The proof is the measurement itself for this trivially insecure FLP. - return deepcopy(meas) - - def query(self, meas, proof, query_rand, joint_rand, _num_shares): - return deepcopy(proof) - - def decide(self, verifier): - """Decide if a verifier message was generated from a valid - measurement.""" - if len(verifier) != 2 or \ - verifier[0] != verifier[1] or \ - verifier[0].as_unsigned() not in self.meas_range: - return False - return True - - def truncate(self, meas): - return [meas[0]] - - def decode(self, output, _num_measurements): - return output[0].as_unsigned() - - -class FlpTestField128(FlpTest): - Field = field.Field128 - - @staticmethod - def with_joint_rand_len(joint_rand_len): - flp = FlpTestField128() - flp.JOINT_RAND_LEN = joint_rand_len - return flp - - -if __name__ == '__main__': - flp = FlpTestField128() - assert run_flp(flp, flp.encode(0), 3) == True - assert run_flp(flp, flp.encode(4), 3) == True - assert run_flp(flp, [field.Field128(1337)], 3) == False diff --git a/poc/flp_generic.py b/poc/flp_generic.py index 8128675c..9f2627dc 100644 --- a/poc/flp_generic.py +++ b/poc/flp_generic.py @@ -5,7 +5,7 @@ import field from common import ERR_ABORT, ERR_INPUT, Unsigned, Vec, next_power_of_2 from field import poly_eval, poly_interp, poly_mul, poly_strip -from flp import Flp, run_flp +from flp import Flp class Gadget: @@ -647,7 +647,7 @@ def test_vec_set_type_param(self, test_vec): class MultiHotHistogram(Valid): - """ + r""" A validity circuit that checks each Client's measurement is a bit vector with at most `max_count` number of 1s. @@ -875,180 +875,3 @@ def test_vec_set_type_param(self, test_vec): test_vec['bits'] = self.bits test_vec['chunk_length'] = self.chunk_length return ['length', 'bits', 'chunk_length'] - - -## -# TESTS -# - -class TestMultiGadget(Valid): - # Associated types - Field = field.Field64 - Measurement = Unsigned - - # Associated parameters - GADGETS = [Mul(), Mul()] - GADGET_CALLS = [1, 2] - MEAS_LEN = 1 - JOINT_RAND_LEN = 0 - OUTPUT_LEN = 1 - - def eval(self, meas, joint_rand, _num_shares): - self.check_valid_eval(meas, joint_rand) - # Not a very useful circuit, obviously. We just want to do something. 
- x = self.GADGETS[0].eval(self.Field, [meas[0], meas[0]]) - y = self.GADGETS[1].eval(self.Field, [meas[0], x]) - z = self.GADGETS[1].eval(self.Field, [x, y]) - return z - - def encode(self, measurement): - if measurement not in [0, 1]: - raise ERR_INPUT - return [self.Field(measurement)] - - def truncate(self, meas): - if len(meas) != 1: - raise ERR_INPUT - return meas - - def decode(self, output, _num_measurements): - return output[0].as_unsigned() - - -def test_gadget(g, Field, test_length): - """ - Test for equivalence of `Gadget.eval()` and `Gadget.eval_poly()`. - """ - meas_poly = [] - meas = [] - eval_at = Field.rand_vec(1)[0] - for _ in range(g.ARITY): - meas_poly.append(Field.rand_vec(test_length)) - meas.append(poly_eval(Field, meas_poly[-1], eval_at)) - out_poly = g.eval_poly(Field, meas_poly) - - want = g.eval(Field, meas) - got = poly_eval(Field, out_poly, eval_at) - assert got == want - - -def test_flp_generic(flp, test_cases): - for (g, g_calls) in zip(flp.Valid.GADGETS, flp.Valid.GADGET_CALLS): - test_gadget(g, flp.Field, next_power_of_2(g_calls + 1)) - - for (i, (meas, expected_decision)) in enumerate(test_cases): - assert len(meas) == flp.MEAS_LEN - assert len(flp.truncate(meas)) == flp.OUTPUT_LEN - - # Evaluate validity circuit. - joint_rand = flp.Field.rand_vec(flp.JOINT_RAND_LEN) - v = flp.Valid.eval(meas, joint_rand, 1) - if (v == flp.Field(0)) != expected_decision: - print('{}: test {} failed: validity circuit returned {}'.format( - flp.Valid.__class__.__name__, i, v)) - - # Run the FLP. - decision = run_flp(flp, meas, 2) - if decision != expected_decision: - print('{}: test {} failed: proof evaluation resulted in {}; want {}'.format( - flp.Valid.__class__.__name__, i, decision, expected_decision)) - - -class TestAverage(Sum): - """ - Flp subclass that calculates the average of integers. The result is rounded - down. - """ - # Associated types - AggResult = Unsigned - - def decode(self, output, num_measurements): - total = super().decode(output, num_measurements) - return total // num_measurements - - -# Test encoding, truncation, then decoding. -def test_encode_truncate_decode(flp, measurements): - for measurement in measurements: - assert measurement == flp.decode( - flp.truncate(flp.encode(measurement)), 1) - - -def test_encode_truncate_decode_with_fft_fields(cls, measurements, *args): - for f in [field.Field64, field.Field96, field.Field128]: - cls_with_field = cls.with_field(f) - assert cls_with_field.Field == f - obj = cls_with_field(*args) - assert isinstance(obj, cls) - test_encode_truncate_decode(FlpGeneric(obj), measurements) - - -def test(): - flp = FlpGeneric(Count()) - test_flp_generic(flp, [ - (flp.encode(0), True), - (flp.encode(1), True), - ([flp.Field(1337)], False), - ]) - - test_gadget(Range2(), field.Field128, 10) - - test_gadget(PolyEval([0, -23, 1, 3]), field.Field128, 10) - - flp = FlpGeneric(Sum(10)) - test_flp_generic(flp, [ - (flp.encode(0), True), - (flp.encode(100), True), - (flp.encode(2 ** 10 - 1), True), - (flp.Field.rand_vec(10), False), - ]) - test_encode_truncate_decode(flp, [0, 100, 2 ** 10 - 1]) - - flp = FlpGeneric(Histogram(4, 2)) - test_flp_generic(flp, [ - (flp.encode(0), True), - (flp.encode(1), True), - (flp.encode(2), True), - (flp.encode(3), True), - ([flp.Field(0)] * 4, False), - ([flp.Field(1)] * 4, False), - (flp.Field.rand_vec(4), False), - ]) - - # MultiHotHistogram with length = 4, max_count = 2, chunk_length = 2. 
- flp = FlpGeneric(MultiHotHistogram(4, 2, 2)) - # Successful cases: - cases = [ - (flp.encode([0, 0, 0, 0]), True), - (flp.encode([0, 1, 0, 0]), True), - (flp.encode([0, 1, 1, 0]), True), - (flp.encode([1, 1, 0, 0]), True), - ] - # Failure cases: too many number of 1s, should fail count check. - cases += [ - ( - [flp.Field(1)] * i + - [flp.Field(0)] * (flp.Valid.length - i) + - # Try to lie about the encoded count. - [flp.Field(0)] * flp.Valid.bits_for_count, - False - ) - for i in range(flp.Valid.max_count + 1, flp.Valid.length + 1) - ] - # Failure case: pass count check but fail bit check. - cases += [(flp.encode([flp.Field.MODULUS - 1, 1, 0, 0]), False)] - test_flp_generic(flp, cases) - - # SumVec with length 2, bits 4, chunk len 1. - test_encode_truncate_decode_with_fft_fields(SumVec, - [[1, 2], [3, 4], [5, 6], [7, 8]], - 2, 4, 1) - - flp = FlpGeneric(TestMultiGadget()) - test_flp_generic(flp, [ - (flp.encode(0), True), - ]) - - -if __name__ == '__main__': - test() diff --git a/poc/idpf.py b/poc/idpf.py index 4e26e850..c1f26f2c 100644 --- a/poc/idpf.py +++ b/poc/idpf.py @@ -2,13 +2,10 @@ from __future__ import annotations -import json -import os -from functools import reduce from typing import Tuple, Union import field -from common import TEST_VECTOR_PATH, Bool, Bytes, Unsigned, gen_rand, vec_add +from common import Bool, Bytes, Unsigned class Idpf: @@ -20,8 +17,8 @@ class Idpf: # Bit length of valid input values (i.e., the length of `alpha` in bits). BITS: Unsigned = None - # The length of each output vector (i.e., the length of `beta_leaf` and each - # element of `beta_inner`). + # The length of each output vector (i.e., the length of `beta_leaf` and + # each element of `beta_inner`). VALUE_LEN: Unsigned = None # Size in bytes of each IDPF key share. @@ -53,8 +50,8 @@ def gen(Idpf, """ Generates an IDPF public share and sequence of IDPF-keys of length `SHARES`. Value `alpha` is the input to encode. Values `beta_inner` and - `beta_leaf` are assigned to the values of the nodes on the non-zero path - of the IDPF tree. String `binder` is a binder string. + `beta_leaf` are assigned to the values of the nodes on the non-zero + path of the IDPF tree. String `binder` is a binder string. An error is raised if integer `alpha` is larger than or equal to `2^BITS`, any elment of `beta_inner` has length other than `VALUE_LEN`, @@ -71,25 +68,25 @@ def eval(Idpf, prefixes: Tuple[Unsigned, ...], binder: Bytes) -> Output: """ - Evaluate an IDPF key at a given level of the tree and with the given set - of prefixes. The output is a vector where each element is a vector of - length `VALUE_LEN`. The output field is `FieldLeaf` if `level == BITS` - and `FieldInner` otherwise. `binder` must match the binder string passed - by the Client to `gen`. + Evaluate an IDPF key at a given level of the tree and with the given + set of prefixes. The output is a vector where each element is a vector + of length `VALUE_LEN`. The output field is `FieldLeaf` if `level == + BITS` and `FieldInner` otherwise. `binder` must match the binder string + passed by the Client to `gen`. Let `LSB(x, N)` denote the least significant `N` bits of positive integer `x`. 
By definition, a positive integer `x` is said to be the - length-`L` prefix of positive integer `y` if `LSB(x, L)` is equal to the - most significant `L` bits of `LSB(y, BITS)`, For example, 6 (110 in + length-`L` prefix of positive integer `y` if `LSB(x, L)` is equal to + the most significant `L` bits of `LSB(y, BITS)`, For example, 6 (110 in binary) is the length-3 prefix of 25 (11001), but 7 (111) is not. Each element of `prefixes` is an integer in `[0, 2^level)`. For each element of `prefixes` that is the length-`level` prefix of the input - encoded by the IDPF-key generation algorithm (i.e., `alpha`), the sum of - the corresponding output shares will be equal to one of the programmed - output vectors (i.e., an element of `beta_inner + [beta_leaf]`). For all - other elements of `prefixes`, the corresponding output shares will sum - up to the 0-vector. + encoded by the IDPF-key generation algorithm (i.e., `alpha`), the sum + of the corresponding output shares will be equal to one of the + programmed output vectors (i.e., an element of `beta_inner + + [beta_leaf]`). For all other elements of `prefixes`, the corresponding + output shares will sum up to the 0-vector. An error is raised if any element of `prefixes` is larger than or equal to `2^level` or if `level` is greater than `BITS`. @@ -106,110 +103,3 @@ def is_prefix(Idpf, x: Unsigned, y: Unsigned, L: Unsigned) -> Bool: """Returns `True` iff `x` is the prefix of `y` of length `L`.""" assert 0 < L and L <= Idpf.BITS return y >> (Idpf.BITS - L) == x - - -def test_idpf(Idpf, alpha, level, prefixes): - """ - Generate a set of IDPF keys and evaluate them on the given set of prefix. - """ - beta_inner = [[Idpf.FieldInner(1)] * Idpf.VALUE_LEN] * (Idpf.BITS-1) - beta_leaf = [Idpf.FieldLeaf(1)] * Idpf.VALUE_LEN - - # Generate the IDPF keys. 
- rand = gen_rand(Idpf.RAND_SIZE) - binder = b'some nonce' - (public_share, keys) = Idpf.gen(alpha, beta_inner, beta_leaf, binder, rand) - - out = [Idpf.current_field(level).zeros(Idpf.VALUE_LEN)] * len(prefixes) - for agg_id in range(Idpf.SHARES): - out_share = Idpf.eval( - agg_id, public_share, keys[agg_id], level, prefixes, binder) - for i in range(len(prefixes)): - out[i] = vec_add(out[i], out_share[i]) - - for (got, prefix) in zip(out, prefixes): - if Idpf.is_prefix(prefix, alpha, level+1): - if level < Idpf.BITS-1: - want = beta_inner[level] - else: - want = beta_leaf - else: - want = Idpf.current_field(level).zeros(Idpf.VALUE_LEN) - - if got != want: - print('error: {0:b} {1:b} {2}: got {3}; want {4}'.format( - alpha, prefix, level, got, want)) - - -def gen_test_vec(Idpf, alpha, test_vec_instance): - beta_inner = [] - for level in range(Idpf.BITS-1): - beta_inner.append([Idpf.FieldInner(level)] * Idpf.VALUE_LEN) - beta_leaf = [Idpf.FieldLeaf(Idpf.BITS-1)] * Idpf.VALUE_LEN - rand = gen_rand(Idpf.RAND_SIZE) - binder = b'some nonce' - (public_share, keys) = Idpf.gen(alpha, beta_inner, beta_leaf, binder, rand) - - printable_beta_inner = [ - [str(elem.as_unsigned()) for elem in value] for value in beta_inner - ] - printable_beta_leaf = [str(elem.as_unsigned()) for elem in beta_leaf] - printable_keys = [key.hex() for key in keys] - test_vec = { - 'bits': int(Idpf.BITS), - 'alpha': str(alpha), - 'beta_inner': printable_beta_inner, - 'beta_leaf': printable_beta_leaf, - 'binder': binder.hex(), - 'public_share': public_share.hex(), - 'keys': printable_keys, - } - - os.system('mkdir -p {}'.format(TEST_VECTOR_PATH)) - with open('{}/{}_{}.json'.format( - TEST_VECTOR_PATH, Idpf.test_vec_name, test_vec_instance), 'w') as f: - json.dump(test_vec, f, indent=4, sort_keys=True) - f.write('\n') - - -def test_idpf_exhaustive(Idpf, alpha): - """Generate a set of IDPF keys and test every possible output.""" - - # Generate random outputs with which to program the IDPF. - beta_inner = [] - for _ in range(Idpf.BITS - 1): - beta_inner.append(Idpf.FieldInner.rand_vec(Idpf.VALUE_LEN)) - beta_leaf = Idpf.FieldLeaf.rand_vec(Idpf.VALUE_LEN) - - # Generate the IDPF keys. - rand = gen_rand(Idpf.RAND_SIZE) - binder = b"some nonce" - (public_share, keys) = Idpf.gen(alpha, beta_inner, beta_leaf, binder, rand) - - # Evaluate the IDPF at every node of the tree. - for level in range(Idpf.BITS): - prefixes = tuple(range(2 ** level)) - - out_shares = [] - for agg_id in range(Idpf.SHARES): - out_shares.append( - Idpf.eval(agg_id, public_share, - keys[agg_id], level, prefixes, binder)) - - # Check that each set of output shares for each prefix sums up to the - # correct value. 
- for prefix in prefixes: - got = reduce(lambda x, y: vec_add(x, y), - map(lambda x: x[prefix], out_shares)) - - if Idpf.is_prefix(prefix, alpha, level+1): - if level < Idpf.BITS-1: - want = beta_inner[level] - else: - want = beta_leaf - else: - want = Idpf.current_field(level).zeros(Idpf.VALUE_LEN) - - if got != want: - print('error: {0:b} {1:b} {2}: got {3}; want {4}'.format( - alpha, prefix, level, got, want)) diff --git a/poc/idpf_poplar.py b/poc/idpf_poplar.py index b654b934..2d2a9e80 100644 --- a/poc/idpf_poplar.py +++ b/poc/idpf_poplar.py @@ -3,10 +3,10 @@ import itertools import field -from common import (ERR_DECODE, ERR_INPUT, TEST_VECTOR, Bytes, Unsigned, Vec, - format_dst, vec_add, vec_neg, vec_sub, xor) +from common import (ERR_DECODE, ERR_INPUT, Bytes, Unsigned, Vec, format_dst, + vec_add, vec_neg, vec_sub, xor) from field import Field2 -from idpf import Idpf, gen_test_vec, test_idpf, test_idpf_exhaustive +from idpf import Idpf from xof import XofFixedKeyAes128 @@ -265,26 +265,3 @@ def unpack_bits(packed_bits: Bytes, length: Unsigned) -> Vec[Field2]: if (length + 7) // 8 != len(packed_bits) or leftover_bits != 0: raise ERR_DECODE return bits - - -if __name__ == '__main__': - cls = IdpfPoplar \ - .with_value_len(2) - if TEST_VECTOR: - gen_test_vec(cls.with_bits(10), 0, 0) - test_idpf(cls.with_bits(16), 0b1111000011110000, 15, (0b1111000011110000,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 14, (0b111100001111000,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 13, (0b11110000111100,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 12, (0b1111000011110,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 11, (0b111100001111,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 10, (0b11110000111,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 5, (0b111100,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 4, (0b11110,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 3, (0b1111,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 2, (0b111,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 1, (0b11,)) - test_idpf(cls.with_bits(16), 0b1111000011110000, 0, (0b1,)) - test_idpf(cls.with_bits(1000), 0, 999, (0,)) - test_idpf_exhaustive(cls.with_bits(1), 0) - test_idpf_exhaustive(cls.with_bits(1), 1) - test_idpf_exhaustive(cls.with_bits(8), 91) diff --git a/poc/plot_prio3_multiproof_robustness.py b/poc/plot_prio3_multiproof_robustness.py index a7b3cf24..b15df191 100644 --- a/poc/plot_prio3_multiproof_robustness.py +++ b/poc/plot_prio3_multiproof_robustness.py @@ -1,4 +1,5 @@ -# plot_prio3_multiproof_robustness.py - Plot robustness bounds for various parameters. +# plot_prio3_multiproof_robustness.py - Plot robustness bounds for various +# parameters. 
# Use `sage -python plot_prio3_multiproof_robustness.py` import math @@ -30,8 +31,8 @@ def robustness(epsilon, ro_queries, prep_queries, num_proofs, seed_bits): epsilon - soundness of the base FLP - ro_queries - random oracle queries, a proxy for the amount of precomputation - done by the adversary + ro_queries - random oracle queries, a proxy for the amount of + precomputation done by the adversary prep_queries - number of online attempts, a proxy for the batch size diff --git a/poc/tests/__init__.py b/poc/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/poc/tests/idpf.py b/poc/tests/idpf.py new file mode 100644 index 00000000..e1aba0b4 --- /dev/null +++ b/poc/tests/idpf.py @@ -0,0 +1,113 @@ +import json +import os +from functools import reduce + +from common import TEST_VECTOR_PATH, gen_rand, vec_add + + +def test_idpf(Idpf, alpha, level, prefixes): + """ + Generate a set of IDPF keys and evaluate them on the given set of prefix. + """ + beta_inner = [[Idpf.FieldInner(1)] * Idpf.VALUE_LEN] * (Idpf.BITS-1) + beta_leaf = [Idpf.FieldLeaf(1)] * Idpf.VALUE_LEN + + # Generate the IDPF keys. + rand = gen_rand(Idpf.RAND_SIZE) + binder = b'some nonce' + (public_share, keys) = Idpf.gen(alpha, beta_inner, beta_leaf, binder, rand) + + out = [Idpf.current_field(level).zeros(Idpf.VALUE_LEN)] * len(prefixes) + for agg_id in range(Idpf.SHARES): + out_share = Idpf.eval( + agg_id, public_share, keys[agg_id], level, prefixes, binder) + for i in range(len(prefixes)): + out[i] = vec_add(out[i], out_share[i]) + + for (got, prefix) in zip(out, prefixes): + if Idpf.is_prefix(prefix, alpha, level+1): + if level < Idpf.BITS-1: + want = beta_inner[level] + else: + want = beta_leaf + else: + want = Idpf.current_field(level).zeros(Idpf.VALUE_LEN) + + if got != want: + print('error: {0:b} {1:b} {2}: got {3}; want {4}'.format( + alpha, prefix, level, got, want)) + + +def test_idpf_exhaustive(Idpf, alpha): + """Generate a set of IDPF keys and test every possible output.""" + + # Generate random outputs with which to program the IDPF. + beta_inner = [] + for _ in range(Idpf.BITS - 1): + beta_inner.append(Idpf.FieldInner.rand_vec(Idpf.VALUE_LEN)) + beta_leaf = Idpf.FieldLeaf.rand_vec(Idpf.VALUE_LEN) + + # Generate the IDPF keys. + rand = gen_rand(Idpf.RAND_SIZE) + binder = b"some nonce" + (public_share, keys) = Idpf.gen(alpha, beta_inner, beta_leaf, binder, rand) + + # Evaluate the IDPF at every node of the tree. + for level in range(Idpf.BITS): + prefixes = tuple(range(2 ** level)) + + out_shares = [] + for agg_id in range(Idpf.SHARES): + out_shares.append( + Idpf.eval(agg_id, public_share, + keys[agg_id], level, prefixes, binder)) + + # Check that each set of output shares for each prefix sums up to the + # correct value. 
+ for prefix in prefixes: + got = reduce(lambda x, y: vec_add(x, y), + map(lambda x: x[prefix], out_shares)) + + if Idpf.is_prefix(prefix, alpha, level+1): + if level < Idpf.BITS-1: + want = beta_inner[level] + else: + want = beta_leaf + else: + want = Idpf.current_field(level).zeros(Idpf.VALUE_LEN) + + if got != want: + print('error: {0:b} {1:b} {2}: got {3}; want {4}'.format( + alpha, prefix, level, got, want)) + + +def gen_test_vec(Idpf, alpha, test_vec_instance): + beta_inner = [] + for level in range(Idpf.BITS-1): + beta_inner.append([Idpf.FieldInner(level)] * Idpf.VALUE_LEN) + beta_leaf = [Idpf.FieldLeaf(Idpf.BITS-1)] * Idpf.VALUE_LEN + rand = gen_rand(Idpf.RAND_SIZE) + binder = b'some nonce' + (public_share, keys) = Idpf.gen(alpha, beta_inner, beta_leaf, binder, rand) + + printable_beta_inner = [ + [str(elem.as_unsigned()) for elem in value] for value in beta_inner + ] + printable_beta_leaf = [str(elem.as_unsigned()) for elem in beta_leaf] + printable_keys = [key.hex() for key in keys] + test_vec = { + 'bits': int(Idpf.BITS), + 'alpha': str(alpha), + 'beta_inner': printable_beta_inner, + 'beta_leaf': printable_beta_leaf, + 'binder': binder.hex(), + 'public_share': public_share.hex(), + 'keys': printable_keys, + } + + os.system('mkdir -p {}'.format(TEST_VECTOR_PATH)) + filename = '{}/{}_{}.json'.format(TEST_VECTOR_PATH, Idpf.test_vec_name, + test_vec_instance) + with open(filename, 'w') as f: + json.dump(test_vec, f, indent=4, sort_keys=True) + f.write('\n') diff --git a/poc/tests/test_daf.py b/poc/tests/test_daf.py new file mode 100644 index 00000000..2488789b --- /dev/null +++ b/poc/tests/test_daf.py @@ -0,0 +1,79 @@ +import unittest +from functools import reduce + +from common import Unsigned, gen_rand +from daf import Daf, run_daf +from field import Field128 +from xof import XofTurboShake128 + + +class TestDaf(Daf): + """A simple DAF used for testing.""" + + # Operational parameters + Field = Field128 + + # Associated parameters + ID = 0xFFFFFFFF + SHARES = 2 + NONCE_SIZE = 0 + RAND_SIZE = 16 + + # Associated types + Measurement = Unsigned + PublicShare = None + InputShare = Field + OutShare = Field + AggShare = Field + AggResult = Unsigned + + @classmethod + def shard(cls, measurement, _nonce, rand): + helper_shares = XofTurboShake128.expand_into_vec(cls.Field, + rand, + b'', + b'', + cls.SHARES-1) + leader_share = cls.Field(measurement) + for helper_share in helper_shares: + leader_share -= helper_share + input_shares = [leader_share] + helper_shares + return (None, input_shares) + + @classmethod + def prep(cls, _agg_id, _agg_param, _nonce, _public_share, input_share): + # For this simple test DAF, the output share is the same as the input + # share. + return input_share + + @classmethod + def aggregate(cls, _agg_param, out_shares): + return reduce(lambda x, y: x + y, out_shares) + + @classmethod + def unshard(cls, _agg_param, agg_shares, _num_measurements): + return reduce(lambda x, y: x + y, agg_shares).as_unsigned() + + +def test_daf(Daf, + agg_param, + measurements, + expected_agg_result): + # Test that the algorithm identifier is in the correct range. + assert 0 <= Daf.ID and Daf.ID < 2 ** 32 + + # Run the DAF on the set of measurements. 
+ nonces = [gen_rand(Daf.NONCE_SIZE) for _ in range(len(measurements))] + agg_result = run_daf(Daf, + agg_param, + measurements, + nonces) + if agg_result != expected_agg_result: + print('daf test failed ({} on {}): unexpected result: got {}; want {}' + .format(Daf.__class__, measurements, agg_result, + expected_agg_result)) + + +class TestDafCase(unittest.TestCase): + def test_test_daf(self): + test_daf(TestDaf, None, [1, 2, 3, 4], 10) diff --git a/poc/tests/test_field.py b/poc/tests/test_field.py new file mode 100644 index 00000000..38167a29 --- /dev/null +++ b/poc/tests/test_field.py @@ -0,0 +1,83 @@ +import unittest + +from field import (Field2, Field64, Field96, Field128, Field255, poly_eval, + poly_interp) + + +def test_field(cls): + # Test constructing a field element from an integer. + assert cls(1337) == cls(cls.gf(1337)) + + # Test generating a zero-vector. + vec = cls.zeros(23) + assert len(vec) == 23 + for x in vec: + assert x == cls(cls.gf.zero()) + + # Test generating a random vector. + vec = cls.rand_vec(23) + assert len(vec) == 23 + + # Test arithmetic. + x = cls(cls.gf.random_element()) + y = cls(cls.gf.random_element()) + assert x + y == cls(x.val + y.val) + assert x - y == cls(x.val - y.val) + assert -x == cls(-x.val) + assert x * y == cls(x.val * y.val) + assert x.inv() == cls(x.val**-1) + + # Test serialization. + want = cls.rand_vec(10) + got = cls.decode_vec(cls.encode_vec(want)) + assert got == want + + # Test encoding integer as bit vector. + vals = [i for i in range(15)] + bits = 4 + for val in vals: + encoded = cls.encode_into_bit_vector(val, bits) + assert cls.decode_from_bit_vector(encoded).as_unsigned() == val + + +def test_fft_field(cls): + test_field(cls) + + # Test generator. + assert cls.gen()**cls.GEN_ORDER == cls(1) + + +class TestFields(unittest.TestCase): + def test_field64(self): + test_fft_field(Field64) + + def test_field96(self): + test_fft_field(Field96) + + def test_field128(self): + test_fft_field(Field128) + + def test_field255(self): + test_field(Field255) + + def test_field2(self): + # Test GF(2). + assert Field2(1).as_unsigned() == 1 + assert Field2(0).as_unsigned() == 0 + assert Field2(1) + Field2(1) == Field2(0) + assert Field2(1) * Field2(1) == Field2(1) + assert -Field2(1) == Field2(1) + assert Field2(1).conditional_select(b'hello') == b'hello' + assert Field2(0).conditional_select(b'hello') == bytes([0, 0, 0, 0, 0]) + + def test_interp(self): + # Test polynomial interpolation. 
+ cls = Field64 + p = cls.rand_vec(10) + xs = [cls(x) for x in range(10)] + ys = [poly_eval(cls, p, x) for x in xs] + q = poly_interp(cls, xs, ys) + for x in xs: + a = poly_eval(cls, p, x) + b = poly_eval(cls, q, x) + assert a == b diff --git a/poc/tests/test_flp.py b/poc/tests/test_flp.py new file mode 100644 index 00000000..4a032910 --- /dev/null +++ b/poc/tests/test_flp.py @@ -0,0 +1,70 @@ +import unittest +from copy import deepcopy + +from common import ERR_ENCODE, Unsigned +from field import Field128 +from flp import Flp, run_flp + + +class FlpTest(Flp): + """An insecure FLP used only for testing.""" + # Associated parameters + JOINT_RAND_LEN = 1 + PROVE_RAND_LEN = 2 + QUERY_RAND_LEN = 3 + MEAS_LEN = 2 + OUTPUT_LEN = 1 + PROOF_LEN = 2 + VERIFIER_LEN = 2 + + # Associated types + Measurement = Unsigned + AggResult = Unsigned + + # Operational parameters + meas_range = range(5) + + def encode(self, measurement): + if measurement not in self.meas_range: + raise ERR_ENCODE + return [self.Field(measurement)] * 2 + + def prove(self, meas, prove_rand, joint_rand): + # The proof is the measurement itself for this trivially insecure FLP. + return deepcopy(meas) + + def query(self, meas, proof, query_rand, joint_rand, _num_shares): + return deepcopy(proof) + + def decide(self, verifier): + """Decide if a verifier message was generated from a valid + measurement.""" + if len(verifier) != 2 or \ + verifier[0] != verifier[1] or \ + verifier[0].as_unsigned() not in self.meas_range: + return False + return True + + def truncate(self, meas): + return [meas[0]] + + def decode(self, output, _num_measurements): + return output[0].as_unsigned() + + +class FlpTestField128(FlpTest): + Field = Field128 + + @staticmethod + def with_joint_rand_len(joint_rand_len): + flp = FlpTestField128() + flp.JOINT_RAND_LEN = joint_rand_len + return flp + + +class TestFlp(unittest.TestCase): + def test_flp(self): + flp = FlpTestField128() + assert run_flp(flp, flp.encode(0), 3) is True + assert run_flp(flp, flp.encode(4), 3) is True + assert run_flp(flp, [Field128(1337)], 3) is False diff --git a/poc/tests/test_flp_generic.py b/poc/tests/test_flp_generic.py new file mode 100644 index 00000000..684ccd8c --- /dev/null +++ b/poc/tests/test_flp_generic.py @@ -0,0 +1,195 @@ +import unittest + +from common import ERR_INPUT, Unsigned, next_power_of_2 +from field import Field64, Field96, Field128, poly_eval +from flp import run_flp +from flp_generic import (Count, FlpGeneric, Histogram, Mul, MultiHotHistogram, + PolyEval, Range2, Sum, SumVec, Valid) + + +class TestMultiGadget(Valid): + # Associated types + Field = Field64 + Measurement = Unsigned + + # Associated parameters + GADGETS = [Mul(), Mul()] + GADGET_CALLS = [1, 2] + MEAS_LEN = 1 + JOINT_RAND_LEN = 0 + OUTPUT_LEN = 1 + + def eval(self, meas, joint_rand, _num_shares): + self.check_valid_eval(meas, joint_rand) + # Not a very useful circuit, obviously. We just want to do something. + x = self.GADGETS[0].eval(self.Field, [meas[0], meas[0]]) + y = self.GADGETS[1].eval(self.Field, [meas[0], x]) + z = self.GADGETS[1].eval(self.Field, [x, y]) + return z + + def encode(self, measurement): + if measurement not in [0, 1]: + raise ERR_INPUT + return [self.Field(measurement)] + + def truncate(self, meas): + if len(meas) != 1: + raise ERR_INPUT + return meas + + def decode(self, output, _num_measurements): + return output[0].as_unsigned() + + +def test_gadget(g, Field, test_length): + """ + Test for equivalence of `Gadget.eval()` and `Gadget.eval_poly()`. 
+ """ + meas_poly = [] + meas = [] + eval_at = Field.rand_vec(1)[0] + for _ in range(g.ARITY): + meas_poly.append(Field.rand_vec(test_length)) + meas.append(poly_eval(Field, meas_poly[-1], eval_at)) + out_poly = g.eval_poly(Field, meas_poly) + + want = g.eval(Field, meas) + got = poly_eval(Field, out_poly, eval_at) + assert got == want + + +def test_flp_generic(flp, test_cases): + for (g, g_calls) in zip(flp.Valid.GADGETS, flp.Valid.GADGET_CALLS): + test_gadget(g, flp.Field, next_power_of_2(g_calls + 1)) + + for (i, (meas, expected_decision)) in enumerate(test_cases): + assert len(meas) == flp.MEAS_LEN + assert len(flp.truncate(meas)) == flp.OUTPUT_LEN + + # Evaluate validity circuit. + joint_rand = flp.Field.rand_vec(flp.JOINT_RAND_LEN) + v = flp.Valid.eval(meas, joint_rand, 1) + if (v == flp.Field(0)) != expected_decision: + print('{}: test {} failed: validity circuit returned {}'.format( + flp.Valid.__class__.__name__, i, v)) + + # Run the FLP. + decision = run_flp(flp, meas, 2) + if decision != expected_decision: + print( + '{}: test {} failed: proof evaluation resulted in {}; want {}' + .format( + flp.Valid.__class__.__name__, i, decision, + expected_decision, + ) + ) + + +class TestAverage(Sum): + """ + Flp subclass that calculates the average of integers. The result is rounded + down. + """ + # Associated types + AggResult = Unsigned + + def decode(self, output, num_measurements): + total = super().decode(output, num_measurements) + return total // num_measurements + + +# Test encoding, truncation, then decoding. +def test_encode_truncate_decode(flp, measurements): + for measurement in measurements: + assert measurement == flp.decode( + flp.truncate(flp.encode(measurement)), 1) + + +def test_encode_truncate_decode_with_fft_fields(cls, measurements, *args): + for f in [Field64, Field96, Field128]: + cls_with_field = cls.with_field(f) + assert cls_with_field.Field == f + obj = cls_with_field(*args) + assert isinstance(obj, cls) + test_encode_truncate_decode(FlpGeneric(obj), measurements) + + +class TestFlpGeneric(unittest.TestCase): + def test_count(self): + flp = FlpGeneric(Count()) + test_flp_generic(flp, [ + (flp.encode(0), True), + (flp.encode(1), True), + ([flp.Field(1337)], False), + ]) + + def test_sum(self): + flp = FlpGeneric(Sum(10)) + test_flp_generic(flp, [ + (flp.encode(0), True), + (flp.encode(100), True), + (flp.encode(2 ** 10 - 1), True), + (flp.Field.rand_vec(10), False), + ]) + test_encode_truncate_decode(flp, [0, 100, 2 ** 10 - 1]) + + def test_histogram(self): + flp = FlpGeneric(Histogram(4, 2)) + test_flp_generic(flp, [ + (flp.encode(0), True), + (flp.encode(1), True), + (flp.encode(2), True), + (flp.encode(3), True), + ([flp.Field(0)] * 4, False), + ([flp.Field(1)] * 4, False), + (flp.Field.rand_vec(4), False), + ]) + + def test_multi_hot_histogram(self): + # MultiHotHistogram with length = 4, max_count = 2, chunk_length = 2. + flp = FlpGeneric(MultiHotHistogram(4, 2, 2)) + # Successful cases: + cases = [ + (flp.encode([0, 0, 0, 0]), True), + (flp.encode([0, 1, 0, 0]), True), + (flp.encode([0, 1, 1, 0]), True), + (flp.encode([1, 1, 0, 0]), True), + ] + # Failure cases: too many number of 1s, should fail count check. + cases += [ + ( + [flp.Field(1)] * i + + [flp.Field(0)] * (flp.Valid.length - i) + + # Try to lie about the encoded count. + [flp.Field(0)] * flp.Valid.bits_for_count, + False + ) + for i in range(flp.Valid.max_count + 1, flp.Valid.length + 1) + ] + # Failure case: pass count check but fail bit check. 
+ cases += [(flp.encode([flp.Field.MODULUS - 1, 1, 0, 0]), False)] + test_flp_generic(flp, cases) + + def test_sumvec(self): + # SumVec with length 2, bits 4, chunk len 1. + test_encode_truncate_decode_with_fft_fields( + SumVec, + [[1, 2], [3, 4], [5, 6], [7, 8]], + 2, + 4, + 1, + ) + + def test_multigadget(self): + flp = FlpGeneric(TestMultiGadget()) + test_flp_generic(flp, [ + (flp.encode(0), True), + ]) + + +class TestGadget(unittest.TestCase): + def test_range2(self): + test_gadget(Range2(), Field128, 10) + + def test_polyeval(self): + test_gadget(PolyEval([0, -23, 1, 3]), Field128, 10) diff --git a/poc/tests/test_idpf_poplar.py b/poc/tests/test_idpf_poplar.py new file mode 100644 index 00000000..64407dc2 --- /dev/null +++ b/poc/tests/test_idpf_poplar.py @@ -0,0 +1,49 @@ +import unittest + +from common import TEST_VECTOR +from idpf_poplar import IdpfPoplar +from tests.idpf import gen_test_vec, test_idpf, test_idpf_exhaustive + + +class TestIdpfPoplar(unittest.TestCase): + def test_idpfpoplar(self): + cls = IdpfPoplar \ + .with_value_len(2) + if TEST_VECTOR: + gen_test_vec(cls.with_bits(10), 0, 0) + test_idpf( + cls.with_bits(16), + 0b1111000011110000, + 15, + (0b1111000011110000,), + ) + test_idpf( + cls.with_bits(16), + 0b1111000011110000, + 14, + (0b111100001111000,), + ) + test_idpf( + cls.with_bits(16), + 0b1111000011110000, + 13, + (0b11110000111100,), + ) + test_idpf( + cls.with_bits(16), + 0b1111000011110000, + 12, + (0b1111000011110,), + ) + test_idpf(cls.with_bits(16), 0b1111000011110000, 11, (0b111100001111,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 10, (0b11110000111,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 5, (0b111100,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 4, (0b11110,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 3, (0b1111,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 2, (0b111,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 1, (0b11,)) + test_idpf(cls.with_bits(16), 0b1111000011110000, 0, (0b1,)) + test_idpf(cls.with_bits(1000), 0, 999, (0,)) + test_idpf_exhaustive(cls.with_bits(1), 0) + test_idpf_exhaustive(cls.with_bits(1), 1) + test_idpf_exhaustive(cls.with_bits(8), 91) diff --git a/poc/tests/test_vdaf_poplar1.py b/poc/tests/test_vdaf_poplar1.py new file mode 100644 index 00000000..5d74d7f1 --- /dev/null +++ b/poc/tests/test_vdaf_poplar1.py @@ -0,0 +1,106 @@ +import unittest + +from common import TEST_VECTOR, from_be_bytes +from tests.vdaf import test_vdaf +from vdaf_poplar1 import Poplar1 + + +class TestPoplar1(unittest.TestCase): + def test_poplar1(self): + test_vdaf(Poplar1.with_bits(15), (15, ()), [], []) + test_vdaf(Poplar1.with_bits(2), (1, (0b11,)), [], [0]) + test_vdaf( + Poplar1.with_bits(2), + (0, (0b0, 0b1)), + [0b10, 0b00, 0b11, 0b01, 0b11], + [2, 3], + ) + test_vdaf( + Poplar1.with_bits(2), + (1, (0b00, 0b01)), + [0b10, 0b00, 0b11, 0b01, 0b01], + [1, 2], + ) + test_vdaf( + Poplar1.with_bits(16), + (15, (0b1111000011110000,)), + [0b1111000011110000], + [1], + ) + test_vdaf( + Poplar1.with_bits(16), + (14, (0b111100001111000,)), + [ + 0b1111000011110000, + 0b1111000011110001, + 0b0111000011110000, + 0b1111000011110010, + 0b1111000000000000, + ], + [2], + ) + test_vdaf( + Poplar1.with_bits(128), + ( + 127, + (from_be_bytes(b'0123456789abcdef'),), + ), + [ + from_be_bytes(b'0123456789abcdef'), + ], + [1], + ) + test_vdaf( + Poplar1.with_bits(256), + ( + 63, + ( + from_be_bytes(b'00000000'), + from_be_bytes(b'01234567'), + ), + ), + [ + from_be_bytes(b'0123456789abcdef0123456789abcdef'), + 
from_be_bytes(b'01234567890000000000000000000000'), + ], + [0, 2], + ) + + def test_is_valid(self): + # Test `is_valid` returns False on repeated levels, and True otherwise. + cls = Poplar1.with_bits(256) + agg_params = [(0, ()), (1, (0,)), (1, (0, 1))] + assert cls.is_valid(agg_params[0], set([])) + assert cls.is_valid(agg_params[1], set(agg_params[:1])) + assert not cls.is_valid(agg_params[2], set(agg_params[:2])) + + def test_aggregation_parameter_encoding(self): + # Test aggregation parameter encoding. + cls = Poplar1.with_bits(256) + want = (0, ()) + assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) + want = (0, (0, 1)) + assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) + want = (2, (0, 1, 2, 3)) + assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) + want = (17, (0, 1, 1233, 2 ** 18 - 1)) + assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) + want = (255, (0, 1, 1233, 2 ** 256 - 1)) + assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) + + def test_generate_test_vectors(self): + # Generate test vectors. + cls = Poplar1.with_bits(4) + assert cls.ID == 0x00001000 + measurements = [0b1101] + tests = [ + # (level, prefixes, expected result) + (0, [0, 1], [0, 1]), + (1, [0, 1, 2, 3], [0, 0, 0, 1]), + (2, [0, 2, 4, 6], [0, 0, 0, 1]), + (3, [1, 3, 5, 7, 9, 13, 15], [0, 0, 0, 0, 0, 1, 0]), + ] + for (level, prefixes, expected_result) in tests: + agg_param = (int(level), tuple(map(int, prefixes))) + test_vdaf(cls, agg_param, measurements, expected_result, + print_test_vec=TEST_VECTOR, test_vec_instance=level) diff --git a/poc/tests/test_vdaf_prio3.py b/poc/tests/test_vdaf_prio3.py new file mode 100644 index 00000000..743b3799 --- /dev/null +++ b/poc/tests/test_vdaf_prio3.py @@ -0,0 +1,216 @@ +import unittest + +from common import TEST_VECTOR, Unsigned +from field import FftField, Field64 +from flp_generic import FlpGeneric, SumVec +from tests.test_flp import FlpTestField128 +from tests.test_flp_generic import TestAverage +from tests.vdaf import test_vdaf +from vdaf_prio3 import (Prio3, Prio3Count, Prio3Histogram, + Prio3MultiHotHistogram, Prio3Sum, Prio3SumVec, + Prio3SumVecWithMultiproof) +from xof import XofTurboShake128 + + +class TestPrio3Average(Prio3): + """ + A Prio3 instantiation to test use of num_measurements in the Valid + class's decode() method. + """ + + Xof = XofTurboShake128 + # NOTE 0xFFFFFFFF is reserved for testing. If we decide to standardize this + # Prio3 variant, then we'll need to pick a real codepoint for it. 
+ ID = 0xFFFFFFFF + VERIFY_KEY_SIZE = XofTurboShake128.SEED_SIZE + + @classmethod + def with_bits(cls, bits: Unsigned): + class TestPrio3AverageWithBits(TestPrio3Average): + Flp = FlpGeneric(TestAverage(bits)) + return TestPrio3AverageWithBits + + +def test_prio3sumvec(num_proofs: Unsigned, field: FftField): + valid_cls = SumVec.with_field(field) + assert Prio3SumVecWithMultiproof.is_recommended( + valid_cls, num_proofs, field) + + cls = Prio3SumVecWithMultiproof \ + .with_params(10, 8, 9, num_proofs, field) \ + .with_shares(2) + + assert cls.ID == 0xFFFFFFFF + assert cls.PROOFS == num_proofs + + test_vdaf( + cls, + None, + [[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]], + [1, 61, 86, 61, 23, 0, 255, 3, 2, 1] + ) + test_vdaf( + cls, + None, + [ + list(range(10)), + [1] * 10, + [255] * 10 + ], + list(range(256, 266)), + print_test_vec=False, + ) + cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3) + test_vdaf( + cls, + None, + [ + [10000, 32000, 9], + [19342, 19615, 3061], + [15986, 24671, 23910] + ], + [45328, 76286, 26980], + print_test_vec=False, + test_vec_instance=1, + ) + + +class TestPrio3(unittest.TestCase): + def test_flp_test(self): + cls = Prio3 \ + .with_xof(XofTurboShake128) \ + .with_flp(FlpTestField128()) \ + .with_shares(2) + cls.ID = 0xFFFFFFFF + test_vdaf(cls, None, [1, 2, 3, 4, 4], 14) + + # If JOINT_RAND_LEN == 0, then Fiat-Shamir isn't needed and we can skip + # generating the joint randomness. + cls = Prio3 \ + .with_xof(XofTurboShake128) \ + .with_flp(FlpTestField128.with_joint_rand_len(0)) \ + .with_shares(2) + cls.ID = 0xFFFFFFFF + test_vdaf(cls, None, [1, 2, 3, 4, 4], 14) + + def test_count(self): + cls = Prio3Count.with_shares(2) + assert cls.ID == 0x00000000 + test_vdaf(cls, None, [0, 1, 1, 0, 1], 3) + test_vdaf(cls, None, [1], 1, print_test_vec=TEST_VECTOR) + + def test_count_3_shares(self): + cls = Prio3Count.with_shares(3) + test_vdaf(cls, None, [1], 1, print_test_vec=TEST_VECTOR, + test_vec_instance=1) + + def test_sum(self): + cls = Prio3Sum.with_bits(8).with_shares(2) + assert cls.ID == 0x00000001 + test_vdaf(cls, None, [0, 147, 1, 0, 11, 0], 159) + test_vdaf(cls, None, [100], 100, print_test_vec=TEST_VECTOR) + + def test_sum_3_shares(self): + cls = Prio3Sum.with_bits(8).with_shares(3) + test_vdaf(cls, None, [100], 100, print_test_vec=TEST_VECTOR, + test_vec_instance=1) + + def test_sum_vec(self): + cls = Prio3SumVec.with_params(10, 8, 9).with_shares(2) + assert cls.ID == 0x00000002 + test_vdaf( + cls, + None, + [[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]], + [1, 61, 86, 61, 23, 0, 255, 3, 2, 1] + ) + test_vdaf( + cls, + None, + [ + list(range(10)), + [1] * 10, + [255] * 10 + ], + list(range(256, 266)), + print_test_vec=TEST_VECTOR, + ) + + def test_sum_vec_3_shares(self): + cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3) + test_vdaf( + cls, + None, + [ + [10000, 32000, 9], + [19342, 19615, 3061], + [15986, 24671, 23910] + ], + [45328, 76286, 26980], + print_test_vec=TEST_VECTOR, + test_vec_instance=1, + ) + + def test_histogram(self): + cls = Prio3Histogram \ + .with_params(4, 2) \ + .with_shares(2) + assert cls.ID == 0x00000003 + test_vdaf(cls, None, [0], [1, 0, 0, 0]) + test_vdaf(cls, None, [1], [0, 1, 0, 0]) + test_vdaf(cls, None, [2], [0, 0, 1, 0]) + test_vdaf(cls, None, [3], [0, 0, 0, 1]) + test_vdaf(cls, None, [0, 0, 1, 1, 2, 2, 3, 3], [2, 2, 2, 2]) + test_vdaf(cls, None, [2], [0, 0, 1, 0], print_test_vec=TEST_VECTOR) + cls = Prio3Histogram.with_params(11, 3).with_shares(3) + test_vdaf( + cls, + None, + [2], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + 
print_test_vec=TEST_VECTOR, + test_vec_instance=1, + ) + + def test_multi_hot_histogram(self): + # Prio3MultiHotHistogram with length = 4, max_count = 2, + # chunk_length = 2. + cls = Prio3MultiHotHistogram \ + .with_params(4, 2, 2) \ + .with_shares(2) + assert cls.ID == 0xFFFFFFFF + test_vdaf(cls, None, [[0, 0, 0, 0]], [0, 0, 0, 0]) + test_vdaf(cls, None, [[0, 1, 0, 0]], [0, 1, 0, 0]) + test_vdaf(cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0]) + test_vdaf(cls, None, [[0, 1, 1, 0], [0, 1, 0, 1]], [0, 2, 1, 1]) + test_vdaf( + cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0], print_test_vec=False + ) + + def test_multi_hot_histogram_3_shares(self): + # Prio3MultiHotHistogram with length = 11, max_count = 5, + # chunk_length = 3. + cls = Prio3MultiHotHistogram.with_params(11, 5, 3).with_shares(3) + test_vdaf( + cls, + None, + [[1] * 5 + [0] * 6], + [1] * 5 + [0] * 6, + print_test_vec=False, + test_vec_instance=1, + ) + + def test_average(self): + cls = TestPrio3Average.with_bits(3).with_shares(2) + test_vdaf(cls, None, [1, 5, 1, 1, 4, 1, 3, 2], 2) + + def test_is_valid(self): + cls = TestPrio3Average.with_bits(3).with_shares(2) + # Test `is_valid` returns True on empty previous_agg_params, and False + # otherwise. + assert cls.is_valid(None, set([])) + assert not cls.is_valid(None, set([None])) + + def test_multiproof(self): + for n in range(2, 5): + test_prio3sumvec(num_proofs=n, field=Field64) diff --git a/poc/tests/test_xof.py b/poc/tests/test_xof.py new file mode 100644 index 00000000..ed6b4136 --- /dev/null +++ b/poc/tests/test_xof.py @@ -0,0 +1,94 @@ +import json +import os +import unittest + +from common import (TEST_VECTOR, TEST_VECTOR_PATH, format_dst, gen_rand, + print_wrapped_line) +from field import Field64, Field128 +from xof import XofFixedKeyAes128, XofTurboShake128 + + +def test_xof(Xof, F, expanded_len): + dst = format_dst(7, 1337, 2) + binder = b'a string that binds some protocol artifact to the output' + seed = gen_rand(Xof.SEED_SIZE) + + # Test next + expanded_data = Xof(seed, dst, binder).next(expanded_len) + assert len(expanded_data) == expanded_len + + want = Xof(seed, dst, binder).next(700) + got = b'' + xof = Xof(seed, dst, binder) + for i in range(0, 700, 7): + got += xof.next(7) + assert got == want + + # Test derive + derived_seed = Xof.derive_seed(seed, dst, binder) + assert len(derived_seed) == Xof.SEED_SIZE + + # Test expand_into_vec + expanded_vec = Xof.expand_into_vec(F, seed, dst, binder, expanded_len) + assert len(expanded_vec) == expanded_len + + +def generate_test_vector(cls): + seed = gen_rand(cls.SEED_SIZE) + dst = b'domain separation tag' + binder = b'binder string' + length = 40 + + test_vector = { + 'seed': seed.hex(), + 'dst': dst.hex(), + 'binder': binder.hex(), + 'length': length, + 'derived_seed': None, # set below + 'expanded_vec_field128': None, # set below + } + + test_vector['derived_seed'] = cls.derive_seed( + seed, dst, binder).hex() + test_vector['expanded_vec_field128'] = Field128.encode_vec( + cls.expand_into_vec(Field128, seed, dst, binder, length)).hex() + + print('{}:'.format(cls.test_vec_name)) + print(' seed: "{}"'.format(test_vector['seed'])) + print(' dst: "{}"'.format(test_vector['dst'])) + print(' binder: "{}"'.format(test_vector['binder'])) + print(' length: {}'.format(test_vector['length'])) + print(' derived_seed: "{}"'.format(test_vector['derived_seed'])) + print(' expanded_vec_field128: >-') + print_wrapped_line(test_vector['expanded_vec_field128'], tab=4) + + os.system('mkdir -p {}'.format(TEST_VECTOR_PATH)) + with 
open('{}/{}.json'.format( + TEST_VECTOR_PATH, cls.__name__), 'w') as f: + json.dump(test_vector, f, indent=4, sort_keys=True) + f.write('\n') + + +class TestXof(unittest.TestCase): + def test_rejection_sampling(self): + # This test case was found through brute-force search using this tool: + # https://github.com/divergentdave/vdaf-rejection-sampling-search + expanded_vec = XofTurboShake128.expand_into_vec( + Field64, + bytes([0xd1, 0x95, 0xec, 0x90, 0xc1, 0xbc, 0xf1, 0xf2, 0xcb, 0x2c, + 0x7e, 0x74, 0xc5, 0xc5, 0xf6, 0xda]), + b'', # domain separation tag + b'', # binder + 140, + ) + assert expanded_vec[-1] == Field64(9734340616212735019) + + def test_turboshake128(self): + test_xof(XofTurboShake128, Field128, 23) + if TEST_VECTOR: + generate_test_vector(XofTurboShake128) + + def test_fixedkeyaes128(self): + test_xof(XofFixedKeyAes128, Field128, 23) + if TEST_VECTOR: + generate_test_vector(XofFixedKeyAes128) diff --git a/poc/tests/vdaf.py b/poc/tests/vdaf.py new file mode 100644 index 00000000..4c97b78d --- /dev/null +++ b/poc/tests/vdaf.py @@ -0,0 +1,27 @@ +from common import gen_rand +from vdaf import run_vdaf + + +def test_vdaf(Vdaf, + agg_param, + measurements, + expected_agg_result, + print_test_vec=False, + test_vec_instance=0): + # Test that the algorithm identifier is in the correct range. + assert 0 <= Vdaf.ID and Vdaf.ID < 2 ** 32 + + # Run the VDAF on the set of measurmenets. + nonces = [gen_rand(Vdaf.NONCE_SIZE) for _ in range(len(measurements))] + verify_key = gen_rand(Vdaf.VERIFY_KEY_SIZE) + agg_result = run_vdaf(Vdaf, + verify_key, + agg_param, + nonces, + measurements, + print_test_vec, + test_vec_instance) + if agg_result != expected_agg_result: + print('vdaf test failed ({} on {}): unexpected result: got {}; want {}' + .format(Vdaf.test_vec_name, measurements, agg_result, + expected_agg_result)) diff --git a/poc/vdaf.py b/poc/vdaf.py index 5a969d4e..61425d24 100644 --- a/poc/vdaf.py +++ b/poc/vdaf.py @@ -362,27 +362,3 @@ def pretty_print_vdaf_test_vec(Vdaf, test_vec, type_params): # Unshard print('agg_result: {}'.format(test_vec['agg_result'])) print() - - -def test_vdaf(Vdaf, - agg_param, - measurements, - expected_agg_result, - print_test_vec=False, - test_vec_instance=0): - # Test that the algorithm identifier is in the correct range. - assert 0 <= Vdaf.ID and Vdaf.ID < 2 ** 32 - - # Run the VDAF on the set of measurmenets. 
- nonces = [gen_rand(Vdaf.NONCE_SIZE) for _ in range(len(measurements))] - verify_key = gen_rand(Vdaf.VERIFY_KEY_SIZE) - agg_result = run_vdaf(Vdaf, - verify_key, - agg_param, - nonces, - measurements, - print_test_vec, - test_vec_instance) - if agg_result != expected_agg_result: - print('vdaf test failed ({} on {}): unexpected result: got {}; want {}'.format( - Vdaf.test_vec_name, measurements, agg_result, expected_agg_result)) diff --git a/poc/vdaf_poplar1.py b/poc/vdaf_poplar1.py index 39eccfd4..f4129cd5 100644 --- a/poc/vdaf_poplar1.py +++ b/poc/vdaf_poplar1.py @@ -7,9 +7,9 @@ import idpf import idpf_poplar import xof -from common import (ERR_INPUT, ERR_VERIFY, TEST_VECTOR, Bytes, Unsigned, byte, +from common import (ERR_INPUT, ERR_VERIFY, Bytes, Unsigned, byte, from_be_bytes, front, to_be_bytes, vec_add, vec_sub) -from vdaf import Vdaf, test_vdaf +from vdaf import Vdaf USAGE_SHARD_RAND = 1 USAGE_CORR_INNER = 2 @@ -378,100 +378,3 @@ def encode_idpf_field_vec(vec): Field = vec[0].__class__ encoded += Field.encode_vec(vec) return encoded - - -if __name__ == '__main__': - test_vdaf(Poplar1.with_bits(15), (15, ()), [], []) - test_vdaf(Poplar1.with_bits(2), (1, (0b11,)), [], [0]) - test_vdaf( - Poplar1.with_bits(2), - (0, (0b0, 0b1)), - [0b10, 0b00, 0b11, 0b01, 0b11], - [2, 3], - ) - test_vdaf( - Poplar1.with_bits(2), - (1, (0b00, 0b01)), - [0b10, 0b00, 0b11, 0b01, 0b01], - [1, 2], - ) - test_vdaf( - Poplar1.with_bits(16), - (15, (0b1111000011110000,)), - [0b1111000011110000], - [1], - ) - test_vdaf( - Poplar1.with_bits(16), - (14, (0b111100001111000,)), - [ - 0b1111000011110000, - 0b1111000011110001, - 0b0111000011110000, - 0b1111000011110010, - 0b1111000000000000, - ], - [2], - ) - test_vdaf( - Poplar1.with_bits(128), - ( - 127, - (from_be_bytes(b'0123456789abcdef'),), - ), - [ - from_be_bytes(b'0123456789abcdef'), - ], - [1], - ) - test_vdaf( - Poplar1.with_bits(256), - ( - 63, - ( - from_be_bytes(b'00000000'), - from_be_bytes(b'01234567'), - ), - ), - [ - from_be_bytes(b'0123456789abcdef0123456789abcdef'), - from_be_bytes(b'01234567890000000000000000000000'), - ], - [0, 2], - ) - - # Test `is_valid` returns False on repeated levels, and True otherwise. - cls = Poplar1.with_bits(256) - agg_params = [(0, ()), (1, (0,)), (1, (0, 1))] - assert cls.is_valid(agg_params[0], set([])) - assert cls.is_valid(agg_params[1], set(agg_params[:1])) - assert not cls.is_valid(agg_params[2], set(agg_params[:2])) - - # Test aggregation parameter encoding. - cls = Poplar1.with_bits(256) - want = (0, ()) - assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) - want = (0, (0, 1)) - assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) - want = (2, (0, 1, 2, 3)) - assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) - want = (17, (0, 1, 1233, 2 ** 18 - 1)) - assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) - want = (255, (0, 1, 1233, 2 ** 256 - 1)) - assert want == cls.decode_agg_param(cls.encode_agg_param(*want)) - - # Generate test vectors. 
- cls = Poplar1.with_bits(4) - assert cls.ID == 0x00001000 - measurements = [0b1101] - tests = [ - # (level, prefixes, expected result) - (0, [0, 1], [0, 1]), - (1, [0, 1, 2, 3], [0, 0, 0, 1]), - (2, [0, 2, 4, 6], [0, 0, 0, 1]), - (3, [1, 3, 5, 7, 9, 13, 15], [0, 0, 0, 0, 0, 1, 0]), - ] - for (level, prefixes, expected_result) in tests: - agg_param = (int(level), tuple(map(int, prefixes))) - test_vdaf(cls, agg_param, measurements, expected_result, - print_test_vec=TEST_VECTOR, test_vec_instance=level) diff --git a/poc/vdaf_prio3.py b/poc/vdaf_prio3.py index f01aa317..82785bfd 100644 --- a/poc/vdaf_prio3.py +++ b/poc/vdaf_prio3.py @@ -5,10 +5,10 @@ import flp import flp_generic import xof -from common import (ERR_INPUT, ERR_VERIFY, TEST_VECTOR, Unsigned, byte, concat, - front, vec_add, vec_sub, zeros) +from common import (ERR_INPUT, ERR_VERIFY, Unsigned, byte, concat, front, + vec_add, vec_sub, zeros) from field import FftField, Field64, Field128 -from vdaf import Vdaf, test_vdaf +from vdaf import Vdaf USAGE_MEAS_SHARE = 1 USAGE_PROOF_SHARE = 2 @@ -571,197 +571,3 @@ class Prio3MultiHotHistogramWithParams(Prio3MultiHotHistogram): length, max_count, chunk_length )) return Prio3MultiHotHistogramWithParams - - -## -# TESTS -# - -class TestPrio3Average(Prio3): - """ - A Prio3 instantiation to test use of num_measurements in the Valid - class's decode() method. - """ - - Xof = xof.XofTurboShake128 - # NOTE 0xFFFFFFFF is reserved for testing. If we decide to standardize this - # Prio3 variant, then we'll need to pick a real codepoint for it. - ID = 0xFFFFFFFF - VERIFY_KEY_SIZE = xof.XofTurboShake128.SEED_SIZE - - @classmethod - def with_bits(cls, bits: Unsigned): - class TestPrio3AverageWithBits(TestPrio3Average): - Flp = flp_generic.FlpGeneric(flp_generic.TestAverage(bits)) - return TestPrio3AverageWithBits - - -def _test_prio3sumvec(num_proofs: Unsigned, field: FftField): - valid_cls = flp_generic.SumVec.with_field(field) - assert Prio3SumVecWithMultiproof.is_recommended( - valid_cls, num_proofs, field) - - cls = Prio3SumVecWithMultiproof \ - .with_params(10, 8, 9, num_proofs, field) \ - .with_shares(2) - - assert cls.ID == 0xFFFFFFFF - assert cls.PROOFS == num_proofs - - test_vdaf( - cls, - None, - [[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]], - [1, 61, 86, 61, 23, 0, 255, 3, 2, 1] - ) - test_vdaf( - cls, - None, - [ - list(range(10)), - [1] * 10, - [255] * 10 - ], - list(range(256, 266)), - print_test_vec=False, - ) - cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3) - test_vdaf( - cls, - None, - [ - [10000, 32000, 9], - [19342, 19615, 3061], - [15986, 24671, 23910] - ], - [45328, 76286, 26980], - print_test_vec=False, - test_vec_instance=1, - ) - - -def test_prio3sumvec_with_multiproof(): - for n in range(2, 5): - _test_prio3sumvec(num_proofs=n, field=Field64) - - -if __name__ == '__main__': - num_shares = 2 # Must be in range `[2, 256)` - - cls = Prio3 \ - .with_xof(xof.XofTurboShake128) \ - .with_flp(flp.FlpTestField128()) \ - .with_shares(num_shares) - cls.ID = 0xFFFFFFFF - test_vdaf(cls, None, [1, 2, 3, 4, 4], 14) - - # If JOINT_RAND_LEN == 0, then Fiat-Shamir isn't needed and we can skip - # generating the joint randomness. 
- cls = Prio3 \ - .with_xof(xof.XofTurboShake128) \ - .with_flp(flp.FlpTestField128.with_joint_rand_len(0)) \ - .with_shares(num_shares) - cls.ID = 0xFFFFFFFF - test_vdaf(cls, None, [1, 2, 3, 4, 4], 14) - - cls = Prio3Count.with_shares(num_shares) - assert cls.ID == 0x00000000 - test_vdaf(cls, None, [0, 1, 1, 0, 1], 3) - test_vdaf(cls, None, [1], 1, print_test_vec=TEST_VECTOR) - cls = Prio3Count.with_shares(3) - test_vdaf(cls, None, [1], 1, print_test_vec=TEST_VECTOR, - test_vec_instance=1) - - cls = Prio3Sum.with_bits(8).with_shares(num_shares) - assert cls.ID == 0x00000001 - test_vdaf(cls, None, [0, 147, 1, 0, 11, 0], 159) - test_vdaf(cls, None, [100], 100, print_test_vec=TEST_VECTOR) - cls = Prio3Sum.with_bits(8).with_shares(3) - test_vdaf(cls, None, [100], 100, print_test_vec=TEST_VECTOR, - test_vec_instance=1) - - cls = Prio3SumVec.with_params(10, 8, 9).with_shares(2) - assert cls.ID == 0x00000002 - test_vdaf( - cls, - None, - [[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]], - [1, 61, 86, 61, 23, 0, 255, 3, 2, 1] - ) - test_vdaf( - cls, - None, - [ - list(range(10)), - [1] * 10, - [255] * 10 - ], - list(range(256, 266)), - print_test_vec=TEST_VECTOR, - ) - cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3) - test_vdaf( - cls, - None, - [ - [10000, 32000, 9], - [19342, 19615, 3061], - [15986, 24671, 23910] - ], - [45328, 76286, 26980], - print_test_vec=TEST_VECTOR, - test_vec_instance=1, - ) - - cls = Prio3Histogram \ - .with_params(4, 2) \ - .with_shares(num_shares) - assert cls.ID == 0x00000003 - test_vdaf(cls, None, [0], [1, 0, 0, 0]) - test_vdaf(cls, None, [1], [0, 1, 0, 0]) - test_vdaf(cls, None, [2], [0, 0, 1, 0]) - test_vdaf(cls, None, [3], [0, 0, 0, 1]) - test_vdaf(cls, None, [0, 0, 1, 1, 2, 2, 3, 3], [2, 2, 2, 2]) - test_vdaf(cls, None, [2], [0, 0, 1, 0], print_test_vec=TEST_VECTOR) - cls = Prio3Histogram.with_params(11, 3).with_shares(3) - test_vdaf( - cls, - None, - [2], - [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], - print_test_vec=TEST_VECTOR, - test_vec_instance=1, - ) - - # Prio3MultiHotHistogram with length = 4, max_count = 2, chunk_length = 2. - cls = Prio3MultiHotHistogram \ - .with_params(4, 2, 2) \ - .with_shares(num_shares) - assert cls.ID == 0xFFFFFFFF - test_vdaf(cls, None, [[0, 0, 0, 0]], [0, 0, 0, 0]) - test_vdaf(cls, None, [[0, 1, 0, 0]], [0, 1, 0, 0]) - test_vdaf(cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0]) - test_vdaf(cls, None, [[0, 1, 1, 0], [0, 1, 0, 1]], [0, 2, 1, 1]) - test_vdaf( - cls, None, [[0, 1, 1, 0]], [0, 1, 1, 0], print_test_vec=False - ) - # Prio3MultiHotHistogram with length = 11, max_count = 5, chunk_length = 3. - cls = Prio3MultiHotHistogram.with_params(11, 5, 3).with_shares(3) - test_vdaf( - cls, - None, - [[1] * 5 + [0] * 6], - [1] * 5 + [0] * 6, - print_test_vec=False, - test_vec_instance=1, - ) - - cls = TestPrio3Average.with_bits(3).with_shares(num_shares) - test_vdaf(cls, None, [1, 5, 1, 1, 4, 1, 3, 2], 2) - - # Test `is_valid` returns True on empty previous_agg_params, and False - # otherwise. 
- assert cls.is_valid(None, set([])) - assert not cls.is_valid(None, set([None])) - - test_prio3sumvec_with_multiproof() diff --git a/poc/xof.py b/poc/xof.py index 521f1509..22cf5fc2 100644 --- a/poc/xof.py +++ b/poc/xof.py @@ -5,9 +5,8 @@ from Cryptodome.Cipher import AES from Cryptodome.Hash import TurboSHAKE128 -from common import (TEST_VECTOR, TEST_VECTOR_PATH, Bytes, Unsigned, concat, - format_dst, from_le_bytes, gen_rand, next_power_of_2, - print_wrapped_line, to_le_bytes, xor) +from common import (Bytes, Unsigned, concat, from_le_bytes, next_power_of_2, + to_le_bytes, xor) class Xof: @@ -16,7 +15,8 @@ class Xof: # Size of the seed. SEED_SIZE: Unsigned - def __init__(self, seed: Bytes["Xof.SEED_SIZE"], dst: Bytes, binder: Bytes): + def __init__(self, seed: Bytes["Xof.SEED_SIZE"], dst: Bytes, + binder: Bytes): """ Construct a new instance of this XOF from the given seed, domain separation tag, and binder string. @@ -157,89 +157,3 @@ def hash_block(self, block): lo, hi = block[:8], block[8:] sigma_block = concat([hi, xor(hi, lo)]) return xor(self.cipher.encrypt(sigma_block), sigma_block) - - -## -# TESTS -# - -def test_xof(Xof, F, expanded_len): - dst = format_dst(7, 1337, 2) - binder = b'a string that binds some protocol artifact to the output' - seed = gen_rand(Xof.SEED_SIZE) - - # Test next - expanded_data = Xof(seed, dst, binder).next(expanded_len) - assert len(expanded_data) == expanded_len - - want = Xof(seed, dst, binder).next(700) - got = b'' - xof = Xof(seed, dst, binder) - for i in range(0, 700, 7): - got += xof.next(7) - assert got == want - - # Test derive - derived_seed = Xof.derive_seed(seed, dst, binder) - assert len(derived_seed) == Xof.SEED_SIZE - - # Test expand_into_vec - expanded_vec = Xof.expand_into_vec(F, seed, dst, binder, expanded_len) - assert len(expanded_vec) == expanded_len - - -if __name__ == '__main__': - import json - import os - - from field import Field64, Field128 - - # This test case was found through brute-force search using this tool: - # https://github.com/divergentdave/vdaf-rejection-sampling-search - expanded_vec = XofTurboShake128.expand_into_vec( - Field64, - bytes([0xd1, 0x95, 0xec, 0x90, 0xc1, 0xbc, 0xf1, 0xf2, 0xcb, 0x2c, - 0x7e, 0x74, 0xc5, 0xc5, 0xf6, 0xda]), - b'', # domain separation tag - b'', # binder - 140, - ) - assert expanded_vec[-1] == Field64(9734340616212735019) - - for cls in (XofTurboShake128, XofFixedKeyAes128): - test_xof(cls, Field128, 23) - - if TEST_VECTOR: - seed = gen_rand(cls.SEED_SIZE) - dst = b'domain separation tag' - binder = b'binder string' - length = 40 - - test_vector = { - 'seed': seed.hex(), - 'dst': dst.hex(), - 'binder': binder.hex(), - 'length': length, - 'derived_seed': None, # set below - 'expanded_vec_field128': None, # set below - } - - test_vector['derived_seed'] = cls.derive_seed( - seed, dst, binder).hex() - test_vector['expanded_vec_field128'] = Field128.encode_vec( - cls.expand_into_vec(Field128, seed, dst, binder, length)).hex() - - print('{}:'.format(cls.test_vec_name)) - print(' seed: "{}"'.format(test_vector['seed'])) - print(' dst: "{}"'.format(test_vector['dst'])) - print(' binder: "{}"'.format(test_vector['binder'])) - print(' length: {}'.format(test_vector['length'])) - print(' derived_seed: "{}"'.format(test_vector['derived_seed'])) - print(' expanded_vec_field128: >-') - print_wrapped_line(test_vector['expanded_vec_field128'], tab=4) - - os.system('mkdir -p {}'.format(TEST_VECTOR_PATH)) - with open('{}/{}.json'.format( - TEST_VECTOR_PATH, cls.__name__), 'w') as f: - 
json.dump(test_vector, f, indent=4, sort_keys=True) - f.write('\n')
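
As a rough guide to the new layout, here is a minimal, hypothetical sketch of how one of the `__main__`-style Prio3Count checks removed above might be expressed as a unittest case under poc/tests/. The module path tests.vdaf and the class and method names below are assumptions made for illustration; the test files added by this commit may arrange their cases differently.

import unittest

from tests.vdaf import test_vdaf
from vdaf_prio3 import Prio3Count


class TestPrio3Count(unittest.TestCase):
    def test_count(self):
        # Hypothetical sketch mirroring the removed `__main__` check
        # `test_vdaf(Prio3Count.with_shares(2), None, [0, 1, 1, 0, 1], 3)`.
        cls = Prio3Count.with_shares(2)
        self.assertEqual(cls.ID, 0x00000000)
        # test_vdaf() runs the full shard/prepare/aggregate flow and prints
        # a diagnostic if the aggregate result differs from the expected
        # value.
        test_vdaf(cls, None, [0, 1, 1, 0, 1], 3)

Because poc/tests/ is a package and its modules are named test_*.py, unittest's default discovery, run from the poc/ directory, should pick such a case up automatically.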