diff --git a/poc/flp_generic.py b/poc/flp_generic.py index 4d335f7a..49127cc7 100644 --- a/poc/flp_generic.py +++ b/poc/flp_generic.py @@ -658,6 +658,12 @@ class SumVec(Valid): JOINT_RAND_LEN = 1 OUTPUT_LEN = None # Set by constructor + @classmethod + def with_field(SumVec, field: field.FftField): + class SumVecWithField(SumVec): + Field = field + return SumVecWithField + def __init__(self, length, bits, chunk_length): """ Instantiate the `SumVec` circuit for measurements with `length` @@ -829,6 +835,20 @@ def decode(self, output, num_measurements): return total // num_measurements +def _test_sumvec_with_field(f: field.FftField): + cls = SumVec.with_field(f) + assert cls.Field == f + flp = FlpGeneric(cls(2, 4, 1)) + # Roundtrip test with no proof generated. + for meas in [[1, 2], [3, 4], [5, 6], [7, 8]]: + assert meas == flp.decode(flp.truncate(flp.encode(meas)), 1) + + +def test_sumvec_with_field(): + for f in [field.Field64, field.Field96, field.Field128]: + _test_sumvec_with_field(f) + + def test(): flp = FlpGeneric(Count()) test_flp_generic(flp, [ @@ -874,6 +894,8 @@ def test(): (flp.encode(0), True), ]) + test_sumvec_with_field() + if __name__ == '__main__': test() diff --git a/poc/vdaf_prio3.py b/poc/vdaf_prio3.py index fc6fef07..de230d30 100644 --- a/poc/vdaf_prio3.py +++ b/poc/vdaf_prio3.py @@ -7,6 +7,7 @@ import xof from common import (ERR_INPUT, ERR_VERIFY, TEST_VECTOR, Unsigned, byte, concat, front, vec_add, vec_sub, zeros) +from field import FftField, Field64, Field96, Field128 from vdaf import Vdaf, test_vdaf USAGE_MEAS_SHARE = 1 @@ -31,6 +32,7 @@ class Prio3(Vdaf): RAND_SIZE = None # Computed from `Xof.SEED_SIZE` and `SHARES` ROUNDS = 1 SHARES = None # A number between `[2, 256)` set later + PROOFS = 1 # Number of independent proofs # Types required by `Vdaf` Measurement = Flp.Measurement @@ -73,12 +75,11 @@ def is_valid(agg_param, previous_agg_params): def prep_init(Prio3, verify_key, agg_id, _agg_param, nonce, public_share, input_share): k_joint_rand_parts = public_share - (meas_share, proof_share, k_blind) = \ + (meas_share, proof_shares, k_blind) = \ Prio3.expand_input_share(agg_id, input_share) out_share = Prio3.Flp.truncate(meas_share) # Compute the joint randomness. - joint_rand = [] k_corrected_joint_rand, k_joint_rand_part = None, None if Prio3.Flp.JOINT_RAND_LEN > 0: k_joint_rand_part = Prio3.joint_rand_part( @@ -86,15 +87,26 @@ def prep_init(Prio3, verify_key, agg_id, _agg_param, k_joint_rand_parts[agg_id] = k_joint_rand_part k_corrected_joint_rand = Prio3.joint_rand_seed( k_joint_rand_parts) - joint_rand = Prio3.joint_rand(k_corrected_joint_rand) + joint_rands = Prio3.joint_rands(k_corrected_joint_rand) # Query the measurement and proof share. - query_rand = Prio3.query_rand(verify_key, nonce) - verifier_share = Prio3.Flp.query(meas_share, - proof_share, - query_rand, - joint_rand, - Prio3.SHARES) + query_rands = Prio3.query_rand(verify_key, nonce) + verifier_share = [] + for _ in range(Prio3.PROOFS): + proof_share, proof_shares = front( + Prio3.Flp.PROOF_LEN, proof_shares) + if Prio3.Flp.JOINT_RAND_LEN > 0: + joint_rand, joint_rands = front( + Prio3.Flp.JOINT_RAND_LEN, joint_rands) + else: + joint_rand = [] + query_rand, query_rands = front( + Prio3.Flp.QUERY_RAND_LEN, query_rands) + verifier_share += Prio3.Flp.query(meas_share, + proof_share, + query_rand, + joint_rand, + Prio3.SHARES) prep_state = (out_share, k_corrected_joint_rand) prep_share = (verifier_share, k_joint_rand_part) @@ -115,16 +127,19 @@ def prep_next(Prio3, prep, prep_msg): @classmethod def prep_shares_to_prep(Prio3, _agg_param, prep_shares): # Unshard the verifier shares into the verifier message. - verifier = Prio3.Flp.Field.zeros(Prio3.Flp.VERIFIER_LEN) + verifiers = Prio3.Flp.Field.zeros( + Prio3.Flp.VERIFIER_LEN * Prio3.PROOFS) k_joint_rand_parts = [] for (verifier_share, k_joint_rand_part) in prep_shares: - verifier = vec_add(verifier, verifier_share) + verifiers = vec_add(verifiers, verifier_share) if Prio3.Flp.JOINT_RAND_LEN > 0: k_joint_rand_parts.append(k_joint_rand_part) - # Verify that the proof is well-formed and the input is valid. - if not Prio3.Flp.decide(verifier): - raise ERR_VERIFY # proof verifier check failed + # Verify that all the proofs are well-formed and the input is valid. + for _ in range(Prio3.PROOFS): + verifier, verifiers = front(Prio3.Flp.VERIFIER_LEN, verifiers) + if not Prio3.Flp.decide(verifier): + raise ERR_VERIFY # proof verifier check failed # Combine the joint randomness parts computed by the # Aggregators into the true joint randomness seed. This is @@ -173,8 +188,12 @@ def shard_without_joint_rand(Prio3, meas, seeds): ) # Generate the proof and shard it into proof shares. - prove_rand = Prio3.prove_rand(k_prove) - leader_proof_share = Prio3.Flp.prove(meas, prove_rand, []) + prove_rands = Prio3.prove_rands(k_prove) + leader_proof_share = [] + for _ in range(Prio3.PROOFS): + prove_rand, prove_rands = front( + Prio3.Flp.PROVE_RAND_LEN, prove_rands) + leader_proof_share += Prio3.Flp.prove(meas, prove_rand, []) for j in range(Prio3.SHARES-1): leader_proof_share = vec_sub( leader_proof_share, @@ -230,10 +249,16 @@ def shard_with_joint_rand(Prio3, meas, nonce, seeds): 0, k_leader_blind, leader_meas_share, nonce)) # Generate the proof and shard it into proof shares. - prove_rand = Prio3.prove_rand(k_prove) - joint_rand = Prio3.joint_rand( + prove_rands = Prio3.prove_rands(k_prove) + joint_rands = Prio3.joint_rands( Prio3.joint_rand_seed(k_joint_rand_parts)) - leader_proof_share = Prio3.Flp.prove(meas, prove_rand, joint_rand) + leader_proof_share = [] + for _ in range(Prio3.PROOFS): + joint_rand, joint_rands = front( + Prio3.Flp.JOINT_RAND_LEN, joint_rands) + prove_rand, prove_rands = front( + Prio3.Flp.PROVE_RAND_LEN, prove_rands) + leader_proof_share += Prio3.Flp.prove(meas, prove_rand, joint_rand) for j in range(Prio3.SHARES-1): leader_proof_share = vec_sub( leader_proof_share, @@ -274,7 +299,7 @@ def helper_proof_share(Prio3, agg_id, k_share): k_share, Prio3.domain_separation_tag(USAGE_PROOF_SHARE), byte(agg_id), - Prio3.Flp.PROOF_LEN, + Prio3.Flp.PROOF_LEN * Prio3.PROOFS, ) @classmethod @@ -286,13 +311,13 @@ def expand_input_share(Prio3, agg_id, input_share): return (meas_share, proof_share, k_blind) @classmethod - def prove_rand(Prio3, k_prove): + def prove_rands(Prio3, k_prove): return Prio3.Xof.expand_into_vec( Prio3.Flp.Field, k_prove, Prio3.domain_separation_tag(USAGE_PROVE_RANDOMNESS), b'', - Prio3.Flp.PROVE_RAND_LEN, + Prio3.Flp.PROVE_RAND_LEN * Prio3.PROOFS, ) @classmethod @@ -302,7 +327,7 @@ def query_rand(Prio3, verify_key, nonce): verify_key, Prio3.domain_separation_tag(USAGE_QUERY_RANDOMNESS), nonce, - Prio3.Flp.QUERY_RAND_LEN, + Prio3.Flp.QUERY_RAND_LEN * Prio3.PROOFS, ) @classmethod @@ -323,14 +348,14 @@ def joint_rand_seed(Prio3, k_joint_rand_parts): ) @classmethod - def joint_rand(Prio3, k_joint_rand_seed): + def joint_rands(Prio3, k_joint_rand_seed): """Derive the joint randomness from its seed.""" return Prio3.Xof.expand_into_vec( Prio3.Flp.Field, k_joint_rand_seed, Prio3.domain_separation_tag(USAGE_JOINT_RANDOMNESS), b'', - Prio3.Flp.JOINT_RAND_LEN, + Prio3.Flp.JOINT_RAND_LEN * Prio3.PROOFS, ) @classmethod @@ -370,6 +395,7 @@ def test_vec_encode_input_share(Prio3, input_share): (meas_share, proof_share, k_blind) = input_share encoded = bytes() if type(meas_share) == list and type(proof_share) == list: # leader + assert len(proof_share) == Prio3.Flp.PROOF_LEN * Prio3.PROOFS encoded += Prio3.Flp.Field.encode_vec(meas_share) encoded += Prio3.Flp.Field.encode_vec(proof_share) elif type(meas_share) == bytes and type(proof_share) == bytes: # helper @@ -394,6 +420,7 @@ def test_vec_encode_agg_share(Prio3, agg_share): def test_vec_encode_prep_share(Prio3, prep_share): (verifier_share, k_joint_rand_part) = prep_share encoded = bytes() + assert len(verifier_share) == Prio3.Flp.VERIFIER_LEN * Prio3.PROOFS encoded += Prio3.Flp.Field.encode_vec(verifier_share) if k_joint_rand_part != None: # joint randomness used encoded += k_joint_rand_part @@ -483,6 +510,41 @@ class Prio3HistogramWithLength(Prio3Histogram): return Prio3HistogramWithLength +class Prio3SumVecWithMultiproof(Prio3SumVec): + ID = 0xFFFFFFFF # TBD + + # Operational parameters. + test_vec_name = 'Prio3SumVecWithMultiproof' + + @staticmethod + def is_valid(num_proofs: Unsigned, field: FftField) -> bool: + # To be confirmed + if field == Field64: + return num_proofs >= 2 + elif field == Field96: + return num_proofs >= 2 + elif field == Field128: + return num_proofs >= 1 + return False + + @classmethod + def with_params(cls, + length: Unsigned, + bits: Unsigned, + chunk_length: Unsigned, + num_proofs: Unsigned, + field: FftField): + if not cls.is_valid(num_proofs, field): + raise ERR_INPUT + + valid_cls = flp_generic.SumVec.with_field(field) + + class Prio3SumVecWithMultiproofAndParams(cls): + PROOFS = num_proofs + Flp = flp_generic.FlpGeneric(valid_cls(length, bits, chunk_length)) + return Prio3SumVecWithMultiproofAndParams + + ## # TESTS # @@ -506,6 +568,53 @@ class TestPrio3AverageWithBits(TestPrio3Average): return TestPrio3AverageWithBits +def _test_prio3sumvec(num_proofs: Unsigned, field: FftField): + cls = Prio3SumVecWithMultiproof.with_params( + 10, 8, 9, num_proofs=num_proofs, field=field).with_shares(2) + + assert cls.ID == 0xFFFFFFFF + assert cls.PROOFS == num_proofs + + test_vdaf( + cls, + None, + [[1, 61, 86, 61, 23, 0, 255, 3, 2, 1]], + [1, 61, 86, 61, 23, 0, 255, 3, 2, 1] + ) + test_vdaf( + cls, + None, + [ + list(range(10)), + [1] * 10, + [255] * 10 + ], + list(range(256, 266)), + print_test_vec=TEST_VECTOR, + ) + cls = Prio3SumVec.with_params(3, 16, 7).with_shares(3) + test_vdaf( + cls, + None, + [ + [10000, 32000, 9], + [19342, 19615, 3061], + [15986, 24671, 23910] + ], + [45328, 76286, 26980], + print_test_vec=TEST_VECTOR, + test_vec_instance=1, + ) + + +def test_prio3sumvec_with_multiproof(): + for n in range(1, 5): + for f in [Field64, Field96, Field128]: + if not Prio3SumVecWithMultiproof.is_valid(n, f): + continue + _test_prio3sumvec(num_proofs=n, field=f) + + if __name__ == '__main__': num_shares = 2 # Must be in range `[2, 256)` @@ -601,3 +710,5 @@ class TestPrio3AverageWithBits(TestPrio3Average): # otherwise. assert cls.is_valid(None, set([])) assert not cls.is_valid(None, set([None])) + + test_prio3sumvec_with_multiproof()