Skip to content

Commit

Permalink
SQUASH David round 1
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton committed Oct 7, 2024
1 parent c718974 commit fe5027f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
29 changes: 15 additions & 14 deletions draft-irtf-cfrg-vdaf.md
Original file line number Diff line number Diff line change
Expand Up @@ -2780,7 +2780,7 @@ def shard_with_joint_rand(
helper_seeds[i]
for i in range(1, (self.SHARES - 1) * 2, 2)
]
(leader_blind, prove), seeds = front(2, seeds)
(leader_blind, prove_seed), seeds = front(2, seeds)

# Shard the encoded measurement into shares and compute the
# joint randomness parts.
Expand All @@ -2798,7 +2798,7 @@ def shard_with_joint_rand(
ctx, 0, leader_blind, leader_meas_share, nonce))

# Generate the proof and shard it into proof shares.
prove_rands = self.prove_rands(ctx, prove)
prove_rands = self.prove_rands(ctx, prove_seed)
joint_rands = self.joint_rands(
ctx, self.joint_rand_seed(ctx, joint_rand_parts))
leader_proofs_share = []
Expand Down Expand Up @@ -2913,16 +2913,17 @@ def prep_init(

# Compute the joint randomness.
joint_rand: list[F] = []
corrected_joint_rand, joint_rand_part = None, None
corrected_joint_rand_seed, joint_rand_part = None, None
if self.flp.JOINT_RAND_LEN > 0:
assert blind is not None
assert joint_rand_parts is not None
joint_rand_part = self.joint_rand_part(
ctx, agg_id, blind, meas_share, nonce)
joint_rand_parts[agg_id] = joint_rand_part
corrected_joint_rand = self.joint_rand_seed(
corrected_joint_rand_seed = self.joint_rand_seed(
ctx, joint_rand_parts)
joint_rands = self.joint_rands(ctx, corrected_joint_rand)
joint_rands = self.joint_rands(
ctx, corrected_joint_rand_seed)

# Query the measurement and proof share.
query_rands = self.query_rands(verify_key, ctx, nonce)
Expand All @@ -2943,7 +2944,7 @@ def prep_init(
self.SHARES,
)

prep_state = (out_share, corrected_joint_rand)
prep_state = (out_share, corrected_joint_rand_seed)
prep_share = (verifiers_share, joint_rand_part)
return (prep_state, prep_share)

Expand All @@ -2953,12 +2954,12 @@ def prep_next(
prep_state: Prio3PrepState[F],
prep_msg: Optional[bytes]
) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]:
joint_rand = prep_msg
(out_share, corrected_joint_rand) = prep_state
joint_rand_seed = prep_msg
(out_share, corrected_joint_rand_seed) = prep_state

# If joint randomness was used, check that the value computed by
# the Aggregators matches the value indicated by the Client.
if joint_rand != corrected_joint_rand:
if joint_rand_seed != corrected_joint_rand_seed:
raise ValueError('joint randomness check failed')

return out_share
Expand Down Expand Up @@ -2987,10 +2988,10 @@ def prep_shares_to_prep(
# Combine the joint randomness parts computed by the
# Aggregators into the true joint randomness seed. This is
# used in the last step.
joint_rand = None
joint_rand_seed = None
if self.flp.JOINT_RAND_LEN > 0:
joint_rand = self.joint_rand_seed(ctx, joint_rand_parts)
return joint_rand
joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts)
return joint_rand_seed
~~~
{: #prio3-prep-state title="Preparation state for Prio3."}

Expand Down Expand Up @@ -3099,10 +3100,10 @@ def expand_input_share(
(meas_share, proofs_share, blind) = input_share
return (meas_share, proofs_share, blind)

def prove_rands(self, ctx: bytes, prove: bytes) -> list[F]:
def prove_rands(self, ctx: bytes, prove_seed: bytes) -> list[F]:
return self.xof.expand_into_vec(
self.flp.field,
prove,
prove_seed,
self.domain_separation_tag(USAGE_PROVE_RANDOMNESS, ctx),
byte(self.PROOFS),
self.flp.PROVE_RAND_LEN * self.PROOFS,
Expand Down
35 changes: 18 additions & 17 deletions poc/vdaf_poc/vdaf_prio3.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,17 @@ def prep_init(

# Compute the joint randomness.
joint_rand: list[F] = []
corrected_joint_rand, joint_rand_part = None, None
corrected_joint_rand_seed, joint_rand_part = None, None
if self.flp.JOINT_RAND_LEN > 0:
assert blind is not None
assert joint_rand_parts is not None
joint_rand_part = self.joint_rand_part(
ctx, agg_id, blind, meas_share, nonce)
joint_rand_parts[agg_id] = joint_rand_part
corrected_joint_rand = self.joint_rand_seed(
corrected_joint_rand_seed = self.joint_rand_seed(
ctx, joint_rand_parts)
joint_rands = self.joint_rands(ctx, corrected_joint_rand)
joint_rands = self.joint_rands(
ctx, corrected_joint_rand_seed)

# Query the measurement and proof share.
query_rands = self.query_rands(verify_key, ctx, nonce)
Expand All @@ -182,7 +183,7 @@ def prep_init(
self.SHARES,
)

prep_state = (out_share, corrected_joint_rand)
prep_state = (out_share, corrected_joint_rand_seed)
prep_share = (verifiers_share, joint_rand_part)
return (prep_state, prep_share)

Expand All @@ -192,12 +193,12 @@ def prep_next(
prep_state: Prio3PrepState[F],
prep_msg: Optional[bytes]
) -> tuple[Prio3PrepState[F], Prio3PrepShare[F]] | list[F]:
joint_rand = prep_msg
(out_share, corrected_joint_rand) = prep_state
joint_rand_seed = prep_msg
(out_share, corrected_joint_rand_seed) = prep_state

# If joint randomness was used, check that the value computed by
# the Aggregators matches the value indicated by the Client.
if joint_rand != corrected_joint_rand:
if joint_rand_seed != corrected_joint_rand_seed:
raise ValueError('joint randomness check failed')

return out_share
Expand Down Expand Up @@ -226,10 +227,10 @@ def prep_shares_to_prep(
# Combine the joint randomness parts computed by the
# Aggregators into the true joint randomness seed. This is
# used in the last step.
joint_rand = None
joint_rand_seed = None
if self.flp.JOINT_RAND_LEN > 0:
joint_rand = self.joint_rand_seed(ctx, joint_rand_parts)
return joint_rand
joint_rand_seed = self.joint_rand_seed(ctx, joint_rand_parts)
return joint_rand_seed

# NOTE: This method is excerpted in the document, de-indented, as
# figure {{prio3-out2agg}}. Its width should be limited to 69 columns
Expand Down Expand Up @@ -339,7 +340,7 @@ def shard_with_joint_rand(
helper_seeds[i]
for i in range(1, (self.SHARES - 1) * 2, 2)
]
(leader_blind, prove), seeds = front(2, seeds)
(leader_blind, prove_seed), seeds = front(2, seeds)

# Shard the encoded measurement into shares and compute the
# joint randomness parts.
Expand All @@ -357,7 +358,7 @@ def shard_with_joint_rand(
ctx, 0, leader_blind, leader_meas_share, nonce))

# Generate the proof and shard it into proof shares.
prove_rands = self.prove_rands(ctx, prove)
prove_rands = self.prove_rands(ctx, prove_seed)
joint_rands = self.joint_rands(
ctx, self.joint_rand_seed(ctx, joint_rand_parts))
leader_proofs_share = []
Expand Down Expand Up @@ -448,10 +449,10 @@ def expand_input_share(
(meas_share, proofs_share, blind) = input_share
return (meas_share, proofs_share, blind)

def prove_rands(self, ctx: bytes, prove: bytes) -> list[F]:
def prove_rands(self, ctx: bytes, prove_seed: bytes) -> list[F]:
return self.xof.expand_into_vec(
self.flp.field,
prove,
prove_seed,
self.domain_separation_tag(USAGE_PROVE_RANDOMNESS, ctx),
byte(self.PROOFS),
self.flp.PROVE_RAND_LEN * self.PROOFS,
Expand Down Expand Up @@ -542,10 +543,10 @@ def test_vec_encode_prep_share(self, prep_share: Prio3PrepShare[F]) -> bytes:
return encoded

def test_vec_encode_prep_msg(self, prep_message: Optional[bytes]) -> bytes:
joint_rand = prep_message
joint_rand_seed = prep_message
encoded = bytes()
if joint_rand is not None: # joint randomness used
encoded += joint_rand
if joint_rand_seed is not None: # joint randomness used
encoded += joint_rand_seed
return encoded


Expand Down

0 comments on commit fe5027f

Please sign in to comment.