Skip to content

Commit

Permalink
Use unittest for tests (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave authored Feb 29, 2024
1 parent da84b8a commit 9bb282b
Show file tree
Hide file tree
Showing 24 changed files with 1,071 additions and 981 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/lint-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ jobs:

- name: Run pyflakes
working-directory: poc
run: pyflakes *.py
run: pyflakes *.py tests/*.py

- name: Run autopep8
working-directory: poc
run: autopep8 --diff --exit-code *.py
run: autopep8 --diff --exit-code *.py tests/*.py

- name: Run isort
working-directory: poc
Expand Down
12 changes: 1 addition & 11 deletions poc/Makefile
Original file line number Diff line number Diff line change
@@ -1,12 +1,2 @@
test:
sage -python common.py
sage -python field.py
sage -python xof.py
sage -python flp.py
sage -python flp_generic.py
sage -python idpf.py
sage -python idpf_poplar.py
sage -python daf.py
sage -python vdaf.py
sage -python vdaf_prio3.py
sage -python vdaf_poplar1.py
sage -python -m unittest
78 changes: 0 additions & 78 deletions poc/daf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@

from __future__ import annotations

from functools import reduce

import field
from common import Bool, Unsigned, gen_rand
from xof import XofTurboShake128


class Daf:
Expand Down Expand Up @@ -138,77 +134,3 @@ def run_daf(Daf,
agg_result = Daf.unshard(agg_param, agg_shares,
num_measurements)
return agg_result


##
# TESTS
#

class TestDaf(Daf):
"""A simple DAF used for testing."""

# Operational parameters
Field = field.Field128

# Associated parameters
ID = 0xFFFFFFFF
SHARES = 2
NONCE_SIZE = 0
RAND_SIZE = 16

# Associated types
Measurement = Unsigned
PublicShare = None
InputShare = Field
OutShare = Field
AggShare = Field
AggResult = Unsigned

@classmethod
def shard(cls, measurement, _nonce, rand):
helper_shares = XofTurboShake128.expand_into_vec(cls.Field,
rand,
b'',
b'',
cls.SHARES-1)
leader_share = cls.Field(measurement)
for helper_share in helper_shares:
leader_share -= helper_share
input_shares = [leader_share] + helper_shares
return (None, input_shares)

@classmethod
def prep(cls, _agg_id, _agg_param, _nonce, _public_share, input_share):
# For this simple test DAF, the output share is the same as the input
# share.
return input_share

@classmethod
def aggregate(cls, _agg_param, out_shares):
return reduce(lambda x, y: x + y, out_shares)

@classmethod
def unshard(cls, _agg_param, agg_shares, _num_measurements):
return reduce(lambda x, y: x + y, agg_shares).as_unsigned()


def test_daf(Daf,
agg_param,
measurements,
expected_agg_result):
# Test that the algorithm identifier is in the correct range.
assert 0 <= Daf.ID and Daf.ID < 2 ** 32

# Run the DAF on the set of measurements.
nonces = [gen_rand(Daf.NONCE_SIZE) for _ in range(len(measurements))]
agg_result = run_daf(Daf,
agg_param,
measurements,
nonces)
if agg_result != expected_agg_result:
print('daf test failed ({} on {}): unexpected result: got {}; want {}'.format(
Daf.__class__, measurements, agg_result, expected_agg_result))


if __name__ == '__main__':
test_daf(TestDaf, None, [1, 2, 3, 4], 10)
74 changes: 0 additions & 74 deletions poc/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,77 +251,3 @@ def poly_interp(Field, xs, ys):
R = PolynomialRing(Field.gf, 'x')
p = R.lagrange_polynomial([(x.val, y.val) for (x, y) in zip(xs, ys)])
return poly_strip(Field, list(map(lambda x: Field(x), p.coefficients())))


##
# TESTS
#

def test_field(cls):
# Test constructing a field element from an integer.
assert cls(1337) == cls(cls.gf(1337))

# Test generating a zero-vector.
vec = cls.zeros(23)
assert len(vec) == 23
for x in vec:
assert x == cls(cls.gf.zero())

# Test generating a random vector.
vec = cls.rand_vec(23)
assert len(vec) == 23

# Test arithmetic.
x = cls(cls.gf.random_element())
y = cls(cls.gf.random_element())
assert x + y == cls(x.val + y.val)
assert x - y == cls(x.val - y.val)
assert -x == cls(-x.val)
assert x * y == cls(x.val * y.val)
assert x.inv() == cls(x.val**-1)

# Test serialization.
want = cls.rand_vec(10)
got = cls.decode_vec(cls.encode_vec(want))
assert got == want

# Test encoding integer as bit vector.
vals = [i for i in range(15)]
bits = 4
for val in vals:
encoded = cls.encode_into_bit_vector(val, bits)
assert cls.decode_from_bit_vector(encoded).as_unsigned() == val


def test_fft_field(cls):
test_field(cls)

# Test generator.
assert cls.gen()**cls.GEN_ORDER == cls(1)


if __name__ == '__main__':
test_fft_field(Field64)
test_fft_field(Field96)
test_fft_field(Field128)
test_field(Field255)

# Test GF(2).
assert Field2(1).as_unsigned() == 1
assert Field2(0).as_unsigned() == 0
assert Field2(1) + Field2(1) == Field2(0)
assert Field2(1) * Field2(1) == Field2(1)
assert -Field2(1) == Field2(1)
assert Field2(1).conditional_select(b'hello') == b'hello'
assert Field2(0).conditional_select(b'hello') == bytes([0, 0, 0, 0, 0])

# Test polynomial interpolation.
cls = Field64
p = cls.rand_vec(10)
xs = [cls(x) for x in range(10)]
ys = [poly_eval(cls, p, x) for x in xs]
q = poly_interp(cls, xs, ys)
for x in xs:
a = poly_eval(cls, p, x)
b = poly_eval(cls, q, x)
assert a == b
72 changes: 1 addition & 71 deletions poc/flp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
"""Fully linear proof (FLP) systems."""

from copy import deepcopy

import field
from common import ERR_ENCODE, Bool, Unsigned, Vec, vec_add, vec_sub
from common import Bool, Unsigned, Vec, vec_add, vec_sub
from field import Field


Expand Down Expand Up @@ -124,71 +122,3 @@ def run_flp(flp, meas: Vec[Flp.Field], num_shares: Unsigned):

# Verifier decides if the measurement is valid.
return flp.decide(verifier)


##
# TESTS
#


class FlpTest(Flp):
"""An insecure FLP used only for testing."""
# Associated parameters
JOINT_RAND_LEN = 1
PROVE_RAND_LEN = 2
QUERY_RAND_LEN = 3
MEAS_LEN = 2
OUTPUT_LEN = 1
PROOF_LEN = 2
VERIFIER_LEN = 2

# Associated types
Measurement = Unsigned
AggResult = Unsigned

# Operational parameters
meas_range = range(5)

def encode(self, measurement):
if measurement not in self.meas_range:
raise ERR_ENCODE
return [self.Field(measurement)] * 2

def prove(self, meas, prove_rand, joint_rand):
# The proof is the measurement itself for this trivially insecure FLP.
return deepcopy(meas)

def query(self, meas, proof, query_rand, joint_rand, _num_shares):
return deepcopy(proof)

def decide(self, verifier):
"""Decide if a verifier message was generated from a valid
measurement."""
if len(verifier) != 2 or \
verifier[0] != verifier[1] or \
verifier[0].as_unsigned() not in self.meas_range:
return False
return True

def truncate(self, meas):
return [meas[0]]

def decode(self, output, _num_measurements):
return output[0].as_unsigned()


class FlpTestField128(FlpTest):
Field = field.Field128

@staticmethod
def with_joint_rand_len(joint_rand_len):
flp = FlpTestField128()
flp.JOINT_RAND_LEN = joint_rand_len
return flp


if __name__ == '__main__':
flp = FlpTestField128()
assert run_flp(flp, flp.encode(0), 3) == True
assert run_flp(flp, flp.encode(4), 3) == True
assert run_flp(flp, [field.Field128(1337)], 3) == False
Loading

0 comments on commit 9bb282b

Please sign in to comment.