diff --git a/poc/draft-irtf-cfrg-vdaf b/poc/draft-irtf-cfrg-vdaf index 510c00c..3d2dd77 160000 --- a/poc/draft-irtf-cfrg-vdaf +++ b/poc/draft-irtf-cfrg-vdaf @@ -1 +1 @@ -Subproject commit 510c00c59402f4643ec564f873e639ae63be618f +Subproject commit 3d2dd77befcf149c818b4bae1744e60c00d0a4f3 diff --git a/poc/flp_pine.py b/poc/flp_pine.py index 62d5d95..a2b24e5 100644 --- a/poc/flp_pine.py +++ b/poc/flp_pine.py @@ -8,9 +8,9 @@ dir_name = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.join(dir_name, "draft-irtf-cfrg-vdaf", "poc")) from common import Unsigned, front, gen_rand, next_power_of_2 -from field import Field, Field128 +from field import Field, Field128, Field64 from flp_generic import FlpGeneric, Mul, ParallelSum, Valid, test_flp_generic -from xof import Xof, XofShake128 +from xof import Xof, XofFixedKeyAes128 class PineValid(Valid): @@ -37,10 +37,25 @@ class PineValid(Valid): # Number of bits in the encoded measurement. bit_checked_len = None + # XOF for `PineValid`. + # Note we currently use `XofFixedKeyAes128` by default, because + # `XofTurboShake128` in VDAF poc has a limit of how many bytes can be + # sampled: + # https://github.com/cfrg/draft-irtf-cfrg-vdaf/blob/fd7a2dc4993babbf3acffc7d498cd7925890064b/poc/xof.py#L18-L21. + # TODO(junyechen1996): Switch to TurboShake128 after + # https://github.com/cfrg/draft-irtf-cfrg-vdaf/pull/322 is merged. + Xof = XofFixedKeyAes128 + # Associated types for `Valid`. Measurement = list[float] AggResult = list[float] - Field = Field128 + Field = None # Set by `with_field()`. + + @classmethod + def with_field(PineValid, TheField): + class PineValidWithField(PineValid): + Field = TheField + return PineValidWithField def __init__(self, l2_norm_bound: float, @@ -513,7 +528,7 @@ def test_bit_chunks(): assert bits_str == bit_chunks_str def test_pine_valid_roundtrip(): - valid = PineValid(1.0, 15, 2, 1) + valid = PineValid.with_field(Field128)(1.0, 15, 2, 1) f64_vals = [0.5, 0.5] assert f64_vals == valid.decode(valid.truncate(valid.encode(f64_vals)), 1) @@ -521,22 +536,24 @@ def test(): test_bit_chunks() test_pine_valid_roundtrip() + # `PineValid` with `l2_norm_bound = 1.0`, `num_frac_bits = 4`, + # `dimension = 4`, `chunk_length = 150`. l2_norm_bound = 1.0 - # 4 fractional bits should be enough to keep 1 decimal digit. - num_frac_bits = 4 dimension = 4 + args = [l2_norm_bound, 4, dimension, 150] # A gradient with a L2-norm of exactly 1.0. measurement = [l2_norm_bound / 2] * dimension - pine_valid = PineValid(l2_norm_bound, num_frac_bits, dimension, 150) - flp = FlpGeneric(pine_valid) - - # Test PINE FLP with verification. - xof = XofShake128(gen_rand(16), b"", b"") - encoded_gradient = flp.encode(measurement) - (wr_check_bits, wr_dot_prods) = \ - pine_valid.run_wr_checks(encoded_gradient, xof) - flp_meas = encoded_gradient + wr_check_bits + wr_dot_prods - test_flp_generic(flp, [(flp_meas, True)]) + for field in [Field64, Field128]: + pine_valid = PineValid.with_field(field)(*args) + flp = FlpGeneric(pine_valid) + + # Test PINE FLP with verification. + xof = XofFixedKeyAes128(gen_rand(16), b"", b"") + encoded_gradient = flp.encode(measurement) + (wr_check_bits, wr_dot_prods) = \ + pine_valid.run_wr_checks(encoded_gradient, xof) + meas = encoded_gradient + wr_check_bits + wr_dot_prods + test_flp_generic(flp, [(meas, True)]) if __name__ == '__main__': diff --git a/poc/vdaf_pine.py b/poc/vdaf_pine.py index be8e09d..9b9dcce 100644 --- a/poc/vdaf_pine.py +++ b/poc/vdaf_pine.py @@ -10,6 +10,7 @@ import xof from common import (Unsigned, byte, concat, front, gen_rand, vec_add, vec_sub, zeros) +from field import Field, Field128, Field64 from flp_generic import FlpGeneric from flp_pine import PineValid from vdaf import Vdaf, test_vdaf @@ -32,16 +33,19 @@ class Pine(Vdaf): """The Pine VDAF.""" - # Internal parameters of `Pine`. - Xof = xof.XofShake128 - Flp = None # Set by constructor based on the user parameters. - MEAS_LEN = None # Set by constructor, based on `Flp.MEAS_LEN`, minus + # Operational parameters set by user. + Flp = FlpGeneric # Set by `with_params()`. It is a `FlpGeneric` with a + # concrete `PineValid`. + PROOFS = 1 # Set by `with_params()`, number of proofs to run with `Flp`. + + # Associated parameters for `Pine`. + MEAS_LEN = None # Set by `with_params()`, based on `Flp.MEAS_LEN`, minus # the number of wraparound checks, because Clients # don't send the dot products in wraparound checks. # Associated parameters required by `Vdaf`. ID = 0xFFFFFFFF # Private codepoint that will be updated later. - VERIFY_KEY_SIZE = Xof.SEED_SIZE # Set based on `Xof`. + VERIFY_KEY_SIZE = PineValid.Xof.SEED_SIZE # Set based on `Xof`. NONCE_SIZE = 16 RAND_SIZE = None # Computed from `Xof.SEED_SIZE` and `SHARES` ROUNDS = 1 @@ -56,25 +60,25 @@ class Pine(Vdaf): InputShare = tuple[ Union[ # Leader: expanded measurement share and proof share. - tuple[list[PineValid.Field], list[PineValid.Field]], + tuple[list[Flp.Field], list[Flp.Field]], # Helper: seeds both measurement share and proof share. tuple[bytes, bytes] ], bytes, # wraparound joint randomness blind bytes, # verification joint randomness blind ] - OutShare = list[PineValid.Field] - AggShare = list[PineValid.Field] + OutShare = list[Flp.Field] + AggShare = list[Flp.Field] AggResult = PineValid.AggResult PrepShare = tuple[ - list[PineValid.Field], # verifier share - bytes, # wraparound joint randomness part - bytes, # verification joint randomness part + list[Flp.Field], # verifier share + bytes, # wraparound joint randomness part + bytes, # verification joint randomness part ] PrepState = tuple[ - list[PineValid.Field], # output share - bytes, # corrected wraparound joint randomness seed - bytes, # corrected verification joint randomness seed + list[Flp.Field], # output share + bytes, # corrected wraparound joint randomness seed + bytes, # corrected verification joint randomness seed ] # Joint randomness seed check for both wraparound joint randomness # and verification joint randomness. @@ -86,11 +90,16 @@ def with_params(Pine, num_frac_bits: Unsigned, dimension: Unsigned, chunk_length: Unsigned, - num_shares: Unsigned): + num_shares: Unsigned, + field: Field, + num_proofs: Unsigned): class PineWithParams(Pine): - Flp = FlpGeneric( - PineValid(l2_norm_bound, num_frac_bits, dimension, chunk_length) - ) + # TODO(issue#39) Decide how many proofs to use and enforce + # robustness. + Flp = FlpGeneric(PineValid.with_field(field)( + l2_norm_bound, num_frac_bits, dimension, chunk_length + )) + PROOFS = num_proofs MEAS_LEN = Flp.MEAS_LEN - Flp.Valid.NUM_WR_CHECKS # The size of randomness is the seed size times the sum of # the following: @@ -100,16 +109,13 @@ class PineWithParams(Pine): # - Two joint randomness seed blind for each Aggregator, one for # wraparound check, one for verification. RAND_SIZE = (1 + 2 * (num_shares - 1) + 2 * num_shares) * \ - Pine.Xof.SEED_SIZE + Flp.Valid.Xof.SEED_SIZE SHARES = num_shares return PineWithParams @classmethod - def shard(Pine, - measurement: Measurement, - nonce: bytes, - rand: bytes) -> tuple[PublicShare, list[InputShare]]: - l = Pine.Xof.SEED_SIZE + def shard(Pine, measurement, nonce, rand): + l = Pine.Flp.Valid.Xof.SEED_SIZE seeds = [rand[i:i+l] for i in range(0, Pine.RAND_SIZE, l)] encoded_gradient = Pine.Flp.encode(measurement) @@ -131,7 +137,7 @@ def shard(Pine, for i in range(0, (Pine.SHARES - 1) * num_helper_seeds, num_helper_seeds) ] - k_helper_proof_shares = [ + k_helper_proofs_shares = [ k_helper_seeds[i] for i in range(1, (Pine.SHARES - 1) * num_helper_seeds, num_helper_seeds) @@ -193,21 +199,31 @@ def shard(Pine, USAGE_JOINT_RAND_PART ) # Compute verification joint randomness field elements. - vf_joint_rand = Pine.vf_joint_rand(Pine.joint_rand_seed( + vf_joint_rands = Pine.vf_joint_rands(Pine.joint_rand_seed( k_vf_joint_rand_parts, USAGE_JOINT_RAND_SEED, )) # Generate the proof and shard it into proof shares. - prove_rand = Pine.prove_rand(k_prove) - # PINE's `eval()` function expects the dot products in wraparound checks - # to be appended after `encoded_measurement`. - leader_proof_share = Pine.Flp.prove( - encoded_measurement + wr_dot_prods, prove_rand, vf_joint_rand - ) + prove_rands = Pine.prove_rands(k_prove) + # `PineValid.eval()` function expects the dot products in wraparound + # checks to be appended after `encoded_measurement`. + flp_meas = encoded_measurement + wr_dot_prods + leader_proofs_share = [] + for _ in range(Pine.PROOFS): + (prove_rand, prove_rands) = \ + front(Pine.Flp.PROVE_RAND_LEN, prove_rands) + (vf_joint_rand, vf_joint_rands) = \ + front(Pine.Flp.JOINT_RAND_LEN, vf_joint_rands) + leader_proofs_share += Pine.Flp.prove(flp_meas, + prove_rand, + vf_joint_rand) + # Sanity check: + assert len(prove_rands) == 0 + assert len(vf_joint_rands) == 0 for j in range(Pine.SHARES-1): - leader_proof_share = vec_sub( - leader_proof_share, - Pine.helper_proof_share(j+1, k_helper_proof_shares[j]), + leader_proofs_share = vec_sub( + leader_proofs_share, + Pine.helper_proofs_share(j+1, k_helper_proofs_shares[j]), ) # Each Aggregator's input share contains: @@ -220,14 +236,14 @@ def shard(Pine, input_shares = [] input_shares.append(( leader_meas_share, - leader_proof_share, + leader_proofs_share, k_leader_wr_joint_rand_blind, k_leader_vf_joint_rand_blind, )) for j in range(Pine.SHARES-1): input_shares.append(( k_helper_meas_shares[j], - k_helper_proof_shares[j], + k_helper_proofs_shares[j], k_helper_wr_joint_rand_blinds[j], k_helper_vf_joint_rand_blinds[j], )) @@ -237,16 +253,16 @@ def shard(Pine, @classmethod def prep_init(Pine, - verify_key: bytes, - agg_id: Unsigned, + verify_key, + agg_id, _agg_param, - nonce: bytes, - public_share: PublicShare, - input_share: InputShare): + nonce, + public_share, + input_share): (k_wr_joint_rand_parts, k_vf_joint_rand_parts) = public_share ( meas_share, - proof_share, + proofs_share, k_wr_joint_rand_blind, k_vf_joint_rand_blind ) = Pine.expand_input_share(agg_id, input_share) @@ -284,15 +300,29 @@ def prep_init(Pine, USAGE_JOINT_RAND_PART, USAGE_JOINT_RAND_SEED, ) - vf_joint_rand = Pine.vf_joint_rand(k_corrected_vf_joint_rand_seed) + vf_joint_rands = Pine.vf_joint_rands(k_corrected_vf_joint_rand_seed) # Query the measurement and proof share. - query_rand = Pine.query_rand(verify_key, nonce) - verifier_share = Pine.Flp.query(meas_share + wr_dot_prods, - proof_share, - query_rand, - vf_joint_rand, - Pine.SHARES) + # `PineValid.eval()` expects the dot products for wraparound checks to be + # appended at the end of Client's encoded measurement. + flp_meas_share = meas_share + wr_dot_prods + query_rands = Pine.query_rands(verify_key, nonce) + verifiers_share = [] + for _ in range(Pine.PROOFS): + (proof_share, proofs_share) = front(Pine.Flp.PROOF_LEN, proofs_share) + (vf_joint_rand, vf_joint_rands) = \ + front(Pine.Flp.JOINT_RAND_LEN, vf_joint_rands) + (query_rand, query_rands) = \ + front(Pine.Flp.QUERY_RAND_LEN, query_rands) + verifiers_share += Pine.Flp.query(flp_meas_share, + proof_share, + query_rand, + vf_joint_rand, + Pine.SHARES) + # Sanity check: + assert len(proofs_share) == 0 + assert len(vf_joint_rands) == 0 + assert len(query_rands) == 0 return ( # Prepare state: @@ -303,33 +333,34 @@ def prep_init(Pine, ), # Prepare share that is exchanged with other Aggregators: ( - verifier_share, + verifiers_share, k_wr_joint_rand_part, k_vf_joint_rand_part ) ) @classmethod - def prep_shares_to_prep( - Pine, - _agg_param, - prep_shares: list[PrepShare], - ) -> PrepMessage: + def prep_shares_to_prep(Pine, _agg_param, prep_shares): # Unshard the verifier shares into the verifier message. - verifier = Pine.Flp.Field.zeros(Pine.Flp.VERIFIER_LEN) + verifiers = Pine.Flp.Field.zeros(Pine.Flp.VERIFIER_LEN * Pine.PROOFS) k_wr_joint_rand_parts = [] k_vf_joint_rand_parts = [] - for (verifier_share, + for (verifiers_share, k_wr_joint_rand_part, k_vf_joint_rand_part) in prep_shares: - verifier = vec_add(verifier, verifier_share) + verifiers = vec_add(verifiers, verifiers_share) k_wr_joint_rand_parts.append(k_wr_joint_rand_part) k_vf_joint_rand_parts.append(k_vf_joint_rand_part) - # Verify that the proof is well-formed and the input is valid. - if not Pine.Flp.decide(verifier): - raise ValueError("Decision function failed after combining all " - "verifier shares.") + # Verify that each proof is well-formed and input is valid. + for _ in range(Pine.PROOFS): + (verifier, verifiers) = front(Pine.Flp.VERIFIER_LEN, verifiers) + if not Pine.Flp.decide(verifier): + # Proof verifier check failed. + raise ValueError("Decision function failed after combining all " + "verifier shares.") + # Sanity check: + assert len(verifiers) == 0 # Combine the joint randomness parts computed by the Aggregators # into the true joint randomness seeds, which are checked by all @@ -343,11 +374,7 @@ def prep_shares_to_prep( return (k_wr_joint_rand_seed, k_vf_joint_rand_seed) @classmethod - def prep_next( - Pine, - prep_state: PrepState, - prep_msg: PrepMessage, - ) -> Union[tuple[PrepState, PrepShare], OutShare]: + def prep_next(Pine, prep_state, prep_msg): # Joint randomness seeds sent by the Leader. (k_wr_joint_rand_seed, k_vf_joint_rand_seed) = prep_msg ( @@ -365,16 +392,14 @@ def prep_next( return out_share @classmethod - def aggregate(Pine, _agg_param, out_shares: list[OutShare]) -> AggShare: + def aggregate(Pine, _agg_param, out_shares): agg_share = Pine.Flp.Field.zeros(Pine.Flp.OUTPUT_LEN) for out_share in out_shares: agg_share = vec_add(agg_share, out_share) return agg_share @classmethod - def unshard(Pine, _agg_param, - agg_shares: list[list[PineValid.Field]], - num_measurements: Unsigned) -> AggResult: + def unshard(Pine, _agg_param, agg_shares, num_measurements): agg = Pine.Flp.Field.zeros(Pine.Flp.OUTPUT_LEN) for agg_share in agg_shares: agg = vec_add(agg, agg_share) @@ -386,12 +411,12 @@ def unshard(Pine, _agg_param, def helper_meas_share(Pine, agg_id: Unsigned, k_share: bytes, - meas_len: Unsigned) -> list[PineValid.Field]: + meas_len: Unsigned) -> list[Flp.Field]: """ Generate the helper measurement share up to length `meas_len`, for aggregator ID `agg_id`, with measurement share seed `k_share`. """ - return Pine.Xof.expand_into_vec( + return Pine.Flp.Valid.Xof.expand_into_vec( Pine.Flp.Field, k_share, Pine.domain_separation_tag(USAGE_MEAS_SHARE), @@ -400,19 +425,19 @@ def helper_meas_share(Pine, ) @classmethod - def helper_proof_share(Pine, - agg_id: Unsigned, - k_share: bytes) -> list[PineValid.Field]: + def helper_proofs_share(Pine, + agg_id: Unsigned, + k_share: bytes) -> list[Flp.Field]: """ - Generate the helper proof share for aggregator ID `agg_id`, with + Generate the helper proofs share for aggregator ID `agg_id`, with proof share seed `k_share`. """ - return Pine.Xof.expand_into_vec( + return Pine.Flp.Valid.Xof.expand_into_vec( Pine.Flp.Field, k_share, Pine.domain_separation_tag(USAGE_PROOF_SHARE), - byte(agg_id), - Pine.Flp.PROOF_LEN, + byte(Pine.PROOFS) + byte(agg_id), + Pine.Flp.PROOF_LEN * Pine.PROOFS ) @classmethod @@ -420,58 +445,58 @@ def expand_input_share( Pine, agg_id: Unsigned, input_share: InputShare, - ) -> tuple[list[PineValid.Field], list[PineValid.Field], bytes, bytes]: + ) -> tuple[list[Flp.Field], list[Flp.Field], bytes, bytes]: """Expand Helper's seed into a vector of field elements. """ ( meas_share, - proof_share, + proofs_share, k_wr_joint_rand_blind, k_vf_joint_rand_blind ) = input_share if agg_id > 0: meas_share = \ Pine.helper_meas_share(agg_id, meas_share, Pine.MEAS_LEN) - proof_share = Pine.helper_proof_share(agg_id, proof_share) + proofs_share = Pine.helper_proofs_share(agg_id, proofs_share) return (meas_share, - proof_share, + proofs_share, k_wr_joint_rand_blind, k_vf_joint_rand_blind) @classmethod - def prove_rand(Pine, k_prove: bytes) -> list[PineValid.Field]: + def prove_rands(Pine, k_prove: bytes) -> list[Flp.Field]: """Generate the prover randomness based on the seed blind `k_prove`.""" - return Pine.Xof.expand_into_vec( + return Pine.Flp.Valid.Xof.expand_into_vec( Pine.Flp.Field, k_prove, Pine.domain_separation_tag(USAGE_PROVE_RANDOMNESS), - b'', - Pine.Flp.PROVE_RAND_LEN, + byte(Pine.PROOFS), + Pine.Flp.PROVE_RAND_LEN * Pine.PROOFS ) @classmethod - def query_rand(Pine, - verify_key: bytes, - nonce: bytes) -> list[PineValid.Field]: + def query_rands(Pine, + verify_key: bytes, + nonce: bytes) -> list[Flp.Field]: """ Generate the query randomness based on the verification key and nonce. """ - return Pine.Xof.expand_into_vec( + return Pine.Flp.Valid.Xof.expand_into_vec( Pine.Flp.Field, verify_key, Pine.domain_separation_tag(USAGE_QUERY_RANDOMNESS), - nonce, - Pine.Flp.QUERY_RAND_LEN, + byte(Pine.PROOFS) + nonce, + Pine.Flp.QUERY_RAND_LEN * Pine.PROOFS ) @classmethod def joint_rand_part(Pine, agg_id: Unsigned, k_blind: bytes, - meas_share: list[PineValid.Field], + meas_share: list[Flp.Field], nonce: bytes, usage: Unsigned) -> bytes: """Derive joint randomness part for an Aggregator. """ - return Pine.Xof.derive_seed( + return Pine.Flp.Valid.Xof.derive_seed( k_blind, Pine.domain_separation_tag(usage), byte(agg_id) + nonce + Pine.Flp.Field.encode_vec(meas_share), @@ -480,13 +505,13 @@ def joint_rand_part(Pine, @classmethod def leader_meas_share_and_joint_rand_parts( Pine, - encoded_measurement: list[PineValid.Field], + encoded_measurement: list[Flp.Field], k_helper_joint_rand_blinds: list[bytes], k_helper_meas_shares: list[bytes], k_leader_joint_rand_blind: bytes, nonce: bytes, part_usage: Unsigned - ) -> tuple[list[PineValid.Field], list[bytes]]: + ) -> tuple[list[Flp.Field], list[bytes]]: """ Return the leader measurement share and joint randomness parts with domain separation tag `part_usage`. @@ -517,8 +542,8 @@ def joint_rand_seed(Pine, """ Derive the joint randomness seed from its parts and based on the usage. """ - return Pine.Xof.derive_seed( - zeros(Pine.Xof.SEED_SIZE), + return Pine.Flp.Valid.Xof.derive_seed( + zeros(Pine.Flp.Valid.Xof.SEED_SIZE), Pine.domain_separation_tag(usage), concat(k_joint_rand_parts), ) @@ -528,7 +553,7 @@ def joint_rand_part_and_seed_for_aggregator( Pine, agg_id: Unsigned, k_joint_rand_blind: bytes, - meas_share: list[PineValid.Field], + meas_share: list[Flp.Field], nonce: bytes, k_joint_rand_parts: list[bytes], joint_rand_part_usage: Unsigned, @@ -554,64 +579,63 @@ def joint_rand_part_and_seed_for_aggregator( return (k_joint_rand_part, k_corrected_joint_rand_seed) @classmethod - def wr_joint_rand_xof(Pine, k_wr_joint_rand_seed: bytes) -> Xof: + def wr_joint_rand_xof(Pine, k_wr_joint_rand_seed: bytes) -> PineValid.Xof: """Initialize the XOF to sample wraparound joint randomness. """ - return Pine.Xof( + return Pine.Flp.Valid.Xof( k_wr_joint_rand_seed, Pine.domain_separation_tag(USAGE_WR_JOINT_RANDOMNESS), b'', ) @classmethod - def vf_joint_rand(Pine, - k_joint_rand_seed: bytes) -> list[PineValid.Field]: + def vf_joint_rands(Pine, + k_joint_rand_seed: bytes) -> list[Flp.Field]: """ Derive the verification joint randomness based on the initial seed. """ - return Pine.Xof.expand_into_vec( + return Pine.Flp.Valid.Xof.expand_into_vec( Pine.Flp.Field, k_joint_rand_seed, Pine.domain_separation_tag(USAGE_JOINT_RANDOMNESS), - b'', - Pine.Flp.JOINT_RAND_LEN, + byte(Pine.PROOFS), + Pine.Flp.JOINT_RAND_LEN * Pine.PROOFS ) @classmethod - def test_vec_encode_input_share(Pine, input_share: InputShare) -> bytes: + def test_vec_encode_input_share(Pine, input_share): ( meas_share, - proof_share, + proofs_share, k_wr_joint_rand_blind, k_vf_joint_rand_blind ) = input_share encoded = bytes() - if type(meas_share) == list and type(proof_share) == list: # leader + if type(meas_share) == list and type(proofs_share) == list: # leader encoded += Pine.Flp.Field.encode_vec(meas_share) - encoded += Pine.Flp.Field.encode_vec(proof_share) - elif type(meas_share) == bytes and type(proof_share) == bytes: # helper + encoded += Pine.Flp.Field.encode_vec(proofs_share) + elif type(meas_share) == bytes and type(proofs_share) == bytes: # helper encoded += meas_share - encoded += proof_share + encoded += proofs_share return encoded + k_wr_joint_rand_blind + k_vf_joint_rand_blind @classmethod - def test_vec_encode_public_share(Pine, public_share: PublicShare) -> bytes: + def test_vec_encode_public_share(Pine, public_share): (k_wr_joint_rand_parts, k_vf_joint_rand_parts) = public_share return concat(k_wr_joint_rand_parts) + concat(k_vf_joint_rand_parts) @classmethod - def test_vec_encode_agg_share(Pine, - agg_share: list[PineValid.Field]) -> bytes: + def test_vec_encode_agg_share(Pine, agg_share): return Pine.Flp.Field.encode_vec(agg_share) @classmethod - def test_vec_encode_prep_share(Pine, prep_share: PrepShare) -> bytes: + def test_vec_encode_prep_share(Pine, prep_share): (verifier_share, k_wr_joint_rand_part, k_vf_joint_rand_part) = \ prep_share return Pine.Flp.Field.encode_vec(verifier_share) + \ k_wr_joint_rand_part + k_vf_joint_rand_part @classmethod - def test_vec_encode_prep_msg(Pine, prep_message: PrepMessage): + def test_vec_encode_prep_msg(Pine, prep_message): (k_wr_joint_rand_seed, k_vf_joint_rand_seed) = prep_message return k_wr_joint_rand_seed + k_vf_joint_rand_seed @@ -629,16 +653,16 @@ def test_shard_result_share_length(Vdaf: Pine): [wr_joint_rand_parts, vf_joint_rand_parts] = public_share assert len(wr_joint_rand_parts) == Vdaf.SHARES assert len(vf_joint_rand_parts) == Vdaf.SHARES - assert(all(len(part) == Pine.Xof.SEED_SIZE + assert(all(len(part) == Vdaf.Flp.Valid.Xof.SEED_SIZE for part in wr_joint_rand_parts)) - assert(all(len(part) == Pine.Xof.SEED_SIZE + assert(all(len(part) == Vdaf.Flp.Valid.Xof.SEED_SIZE for part in vf_joint_rand_parts)) # Check leader share length. - (meas_share, proof_share, wr_joint_rand_blind, vf_joint_rand_blind) = \ + (meas_share, proofs_share, wr_joint_rand_blind, vf_joint_rand_blind) = \ input_shares[0] assert len(meas_share) == Vdaf.MEAS_LEN - assert len(proof_share) == Vdaf.Flp.PROOF_LEN + assert len(proofs_share) == Vdaf.Flp.PROOF_LEN * Vdaf.PROOFS if __name__ == '__main__': usages = [USAGE_MEAS_SHARE, USAGE_PROOF_SHARE, USAGE_JOINT_RANDOMNESS, @@ -648,17 +672,24 @@ def test_shard_result_share_length(Vdaf: Pine): raise ValueError("Expect Prio3's domain separation tags to be unique " "from 1 to " + str(len(usages)) + ".") - # `Pine` with `l2_norm_bound = 1.0`, `num_frac_bits = 4`, `dimension = 4`, + # Instantiate `Pine` with different field sizes and number of proofs, but + # with the same user parameters: + # `l2_norm_bound = 1.0`, `num_frac_bits = 4`, `dimension = 4`, # `chunk_length = 150`, `num_shares = 2`. - ConcretePine = Pine.with_params(1.0, 4, 4, 150, 2) - test_shard_result_share_length(ConcretePine) - - test_vdaf( - ConcretePine, - None, - [ - [1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0], - ], - [1.0, 1.0, 0.0, 0.0], - ) + args = [1.0, 4, 4, 150, 2] + + # Test happy cases. + for (field, num_proofs) in [(Field64, 2), (Field128, 1)]: + concrete_pine = Pine.with_params(*args, field, num_proofs) + assert concrete_pine.Flp.Field == field + assert concrete_pine.PROOFS == num_proofs + test_shard_result_share_length(concrete_pine) + test_vdaf( + concrete_pine, + None, + [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ], + [1.0, 1.0, 0.0, 0.0], + )