diff --git a/.github/workflows/golang.yml b/.github/workflows/golang.yml index df07f9ee0..7fe7b396b 100644 --- a/.github/workflows/golang.yml +++ b/.github/workflows/golang.yml @@ -36,6 +36,9 @@ jobs: with: pr-number: ${{ github.event.pull_request.number }} + # TODO - add runtime tests to the workflow + # TODO - add core tests to the workflow + build-curves-linux: name: Build and test curves on Linux runs-on: [self-hosted, Linux, X64, icicle] diff --git a/wrappers/golang/core/internal/mock_curve.go b/wrappers/golang/core/internal/mock_curve.go index dddbddee7..ee5200202 100644 --- a/wrappers/golang/core/internal/mock_curve.go +++ b/wrappers/golang/core/internal/mock_curve.go @@ -47,6 +47,10 @@ func (a *MockAffine) Zero() MockAffine { return *a } +func (a *MockAffine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *MockAffine) FromLimbs(x, y []uint32) MockAffine { a.X.FromLimbs(x) a.Y.FromLimbs(y) diff --git a/wrappers/golang/core/internal/mock_field.go b/wrappers/golang/core/internal/mock_field.go index ae5ce68a7..245d143ef 100644 --- a/wrappers/golang/core/internal/mock_field.go +++ b/wrappers/golang/core/internal/mock_field.go @@ -53,6 +53,16 @@ func (f *MockBaseField) Zero() MockBaseField { return *f } +func (f *MockBaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *MockBaseField) One() MockBaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12377/base_field.go b/wrappers/golang/curves/bls12377/base_field.go index 247482119..84e789875 100644 --- a/wrappers/golang/curves/bls12377/base_field.go +++ b/wrappers/golang/curves/bls12377/base_field.go @@ -53,6 +53,16 @@ func (f *BaseField) Zero() BaseField { return *f } +func (f *BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *BaseField) One() BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12377/curve.go b/wrappers/golang/curves/bls12377/curve.go index 70a6c85d8..1c6de36ed 100644 --- a/wrappers/golang/curves/bls12377/curve.go +++ b/wrappers/golang/curves/bls12377/curve.go @@ -96,6 +96,10 @@ func (a *Affine) Zero() Affine { return *a } +func (a *Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *Affine) FromLimbs(x, y []uint32) Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *Affine) FromLimbs(x, y []uint32) Affine { func (a Affine) ToProjective() Projective { var p Projective - cA := (*C.affine_t)(unsafe.Pointer(&a)) - cP := (*C.projective_t)(unsafe.Pointer(&p)) - C.bls12_377_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.affine_t)(unsafe.Pointer(&a)) + // cP := (*C.projective_t)(unsafe.Pointer(&p)) + // C.bls12_377_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bls12377/g2/curve.go b/wrappers/golang/curves/bls12377/g2/curve.go index dd16d714a..fb418b412 100644 --- a/wrappers/golang/curves/bls12377/g2/curve.go +++ b/wrappers/golang/curves/bls12377/g2/curve.go @@ -96,6 +96,10 @@ func (a *G2Affine) Zero() G2Affine { return *a } +func (a *G2Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { func (a G2Affine) ToProjective() G2Projective { var p G2Projective - cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) - cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) - C.bls12_377_g2_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) + // cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + // C.bls12_377_g2_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bls12377/g2/g2base_field.go b/wrappers/golang/curves/bls12377/g2/g2base_field.go index c4073af97..8087a2bbf 100644 --- a/wrappers/golang/curves/bls12377/g2/g2base_field.go +++ b/wrappers/golang/curves/bls12377/g2/g2base_field.go @@ -53,6 +53,16 @@ func (f *G2BaseField) Zero() G2BaseField { return *f } +func (f *G2BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *G2BaseField) One() G2BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12377/scalar_field.go b/wrappers/golang/curves/bls12377/scalar_field.go index 928e0adfb..a21f480bf 100644 --- a/wrappers/golang/curves/bls12377/scalar_field.go +++ b/wrappers/golang/curves/bls12377/scalar_field.go @@ -60,6 +60,16 @@ func (f *ScalarField) Zero() ScalarField { return *f } +func (f *ScalarField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ScalarField) One() ScalarField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12377/tests/base_field_test.go b/wrappers/golang/curves/bls12377/tests/base_field_test.go index 68e06d715..da9f57531 100644 --- a/wrappers/golang/curves/bls12377/tests/base_field_test.go +++ b/wrappers/golang/curves/bls12377/tests/base_field_test.go @@ -5,85 +5,105 @@ import ( bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( BASE_LIMBS = bls12_377.BASE_LIMBS ) -func TestBaseFieldFromLimbs(t *testing.T) { +func testBaseFieldFromLimbs(suite suite.Suite) { emptyField := bls12_377.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestBaseFieldGetLimbs(t *testing.T) { +func testBaseFieldGetLimbs(suite suite.Suite) { emptyField := bls12_377.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") } -func TestBaseFieldOne(t *testing.T) { +func testBaseFieldOne(suite suite.Suite) { var emptyField bls12_377.BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") } -func TestBaseFieldZero(t *testing.T) { +func testBaseFieldZero(suite suite.Suite) { var emptyField bls12_377.BaseField emptyField.Zero() limbsZero := make([]uint32, BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") } -func TestBaseFieldSize(t *testing.T) { +func testBaseFieldSize(suite suite.Suite) { var emptyField bls12_377.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestBaseFieldAsPointer(t *testing.T) { +func testBaseFieldAsPointer(suite suite.Suite) { var emptyField bls12_377.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestBaseFieldFromBytes(t *testing.T) { +func testBaseFieldFromBytes(suite suite.Suite) { var emptyField bls12_377.BaseField bytes, expected := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestBaseFieldToBytes(t *testing.T) { +func testBaseFieldToBytes(suite suite.Suite) { var emptyField bls12_377.BaseField expected, limbs := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type BaseFieldTestSuite struct { + suite.Suite +} + +func (s *BaseFieldTestSuite) TestBaseField() { + s.Run("TestBaseFieldFromLimbs", testWrapper(s.Suite, testBaseFieldFromLimbs)) + s.Run("TestBaseFieldGetLimbs", testWrapper(s.Suite, testBaseFieldGetLimbs)) + s.Run("TestBaseFieldOne", testWrapper(s.Suite, testBaseFieldOne)) + s.Run("TestBaseFieldZero", testWrapper(s.Suite, testBaseFieldZero)) + s.Run("TestBaseFieldSize", testWrapper(s.Suite, testBaseFieldSize)) + s.Run("TestBaseFieldAsPointer", testWrapper(s.Suite, testBaseFieldAsPointer)) + s.Run("TestBaseFieldFromBytes", testWrapper(s.Suite, testBaseFieldFromBytes)) + s.Run("TestBaseFieldToBytes", testWrapper(s.Suite, testBaseFieldToBytes)) + +} + +func TestSuiteBaseField(t *testing.T) { + suite.Run(t, new(BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bls12377/tests/curve_test.go b/wrappers/golang/curves/bls12377/tests/curve_test.go index 2c4cf0a94..ff7d64437 100644 --- a/wrappers/golang/curves/bls12377/tests/curve_test.go +++ b/wrappers/golang/curves/bls12377/tests/curve_test.go @@ -5,15 +5,15 @@ import ( bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestAffineZero(t *testing.T) { +func testAffineZero(suite suite.Suite) { var fieldZero = bls12_377.BaseField{} var affineZero bls12_377.Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestAffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestAffineFromLimbs(t *testing.T) { +func testAffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var affine bls12_377.Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestAffineToProjective(t *testing.T) { +func testAffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bls12_377.BaseField @@ -49,31 +49,31 @@ func TestAffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestProjectiveZero(t *testing.T) { +func testProjectiveZero(suite suite.Suite) { var projectiveZero bls12_377.Projective projectiveZero.Zero() var fieldZero = bls12_377.BaseField{} var fieldOne bls12_377.BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var projective bls12_377.Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestProjectiveFromLimbs(t *testing.T) { +func testProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestProjectiveFromLimbs(t *testing.T) { var projective bls12_377.Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestProjectiveFromAffine(t *testing.T) { +func testProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bls12_377.BaseField @@ -100,5 +100,22 @@ func TestProjectiveFromAffine(t *testing.T) { var projectivePoint bls12_377.Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type CurveTestSuite struct { + suite.Suite +} + +func (s *CurveTestSuite) TestCurve() { + s.Run("TestAffineZero", testWrapper(s.Suite, testAffineZero)) + s.Run("TestAffineFromLimbs", testWrapper(s.Suite, testAffineFromLimbs)) + s.Run("TestAffineToProjective", testWrapper(s.Suite, testAffineToProjective)) + s.Run("TestProjectiveZero", testWrapper(s.Suite, testProjectiveZero)) + s.Run("TestProjectiveFromLimbs", testWrapper(s.Suite, testProjectiveFromLimbs)) + s.Run("TestProjectiveFromAffine", testWrapper(s.Suite, testProjectiveFromAffine)) +} + +func TestSuiteCurve(t *testing.T) { + suite.Run(t, new(CurveTestSuite)) } diff --git a/wrappers/golang/curves/bls12377/tests/ecntt_test.go b/wrappers/golang/curves/bls12377/tests/ecntt_test.go index a9b301ece..d1a171253 100644 --- a/wrappers/golang/curves/bls12377/tests/ecntt_test.go +++ b/wrappers/golang/curves/bls12377/tests/ecntt_test.go @@ -9,10 +9,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime/config_extension" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestECNtt(t *testing.T) { +func testECNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() ext := config_extension.Create() ext.SetInt(core.CUDA_NTT_ALGORITHM, int(core.Radix2)) @@ -31,7 +31,19 @@ func TestECNtt(t *testing.T) { output := make(core.HostSlice[bls12_377.Projective], testSize) e := ecntt.ECNtt(pointsCopy, core.KForward, &cfg, output) - assert.Equal(t, runtime.Success, e, "ECNtt failed") + suite.Equal(runtime.Success, e, "ECNtt failed") } } } + +type ECNttTestSuite struct { + suite.Suite +} + +func (s *ECNttTestSuite) TestECNtt() { + s.Run("TestECNtt", testWrapper(s.Suite, testECNtt)) +} + +func TestSuiteECNtt(t *testing.T) { + suite.Run(t, new(ECNttTestSuite)) +} diff --git a/wrappers/golang/curves/bls12377/tests/g2_curve_test.go b/wrappers/golang/curves/bls12377/tests/g2_curve_test.go index 2a02b324b..58324fb9f 100644 --- a/wrappers/golang/curves/bls12377/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bls12377/tests/g2_curve_test.go @@ -5,15 +5,15 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestG2AffineZero(t *testing.T) { +func testG2AffineZero(suite suite.Suite) { var fieldZero = g2.G2BaseField{} var affineZero g2.G2Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestG2AffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestG2AffineFromLimbs(t *testing.T) { +func testG2AffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var affine g2.G2Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestG2AffineToProjective(t *testing.T) { +func testG2AffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -49,31 +49,31 @@ func TestG2AffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestG2ProjectiveZero(t *testing.T) { +func testG2ProjectiveZero(suite suite.Suite) { var projectiveZero g2.G2Projective projectiveZero.Zero() var fieldZero = g2.G2BaseField{} var fieldOne g2.G2BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestG2ProjectiveFromLimbs(t *testing.T) { +func testG2ProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestG2ProjectiveFromLimbs(t *testing.T) { var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestG2ProjectiveFromAffine(t *testing.T) { +func testG2ProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -100,5 +100,22 @@ func TestG2ProjectiveFromAffine(t *testing.T) { var projectivePoint g2.G2Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type G2CurveTestSuite struct { + suite.Suite +} + +func (s *G2CurveTestSuite) TestG2Curve() { + s.Run("TestG2AffineZero", testWrapper(s.Suite, testG2AffineZero)) + s.Run("TestG2AffineFromLimbs", testWrapper(s.Suite, testG2AffineFromLimbs)) + s.Run("TestG2AffineToProjective", testWrapper(s.Suite, testG2AffineToProjective)) + s.Run("TestG2ProjectiveZero", testWrapper(s.Suite, testG2ProjectiveZero)) + s.Run("TestG2ProjectiveFromLimbs", testWrapper(s.Suite, testG2ProjectiveFromLimbs)) + s.Run("TestG2ProjectiveFromAffine", testWrapper(s.Suite, testG2ProjectiveFromAffine)) +} + +func TestSuiteG2Curve(t *testing.T) { + suite.Run(t, new(G2CurveTestSuite)) } diff --git a/wrappers/golang/curves/bls12377/tests/g2_g2base_field_test.go b/wrappers/golang/curves/bls12377/tests/g2_g2base_field_test.go index f4293a742..ad8ba74fd 100644 --- a/wrappers/golang/curves/bls12377/tests/g2_g2base_field_test.go +++ b/wrappers/golang/curves/bls12377/tests/g2_g2base_field_test.go @@ -5,85 +5,105 @@ import ( bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( G2BASE_LIMBS = bls12_377.G2BASE_LIMBS ) -func TestG2BaseFieldFromLimbs(t *testing.T) { +func testG2BaseFieldFromLimbs(suite suite.Suite) { emptyField := bls12_377.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestG2BaseFieldGetLimbs(t *testing.T) { +func testG2BaseFieldGetLimbs(suite suite.Suite) { emptyField := bls12_377.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") } -func TestG2BaseFieldOne(t *testing.T) { +func testG2BaseFieldOne(suite suite.Suite) { var emptyField bls12_377.G2BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(G2BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") } -func TestG2BaseFieldZero(t *testing.T) { +func testG2BaseFieldZero(suite suite.Suite) { var emptyField bls12_377.G2BaseField emptyField.Zero() limbsZero := make([]uint32, G2BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") } -func TestG2BaseFieldSize(t *testing.T) { +func testG2BaseFieldSize(suite suite.Suite) { var emptyField bls12_377.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestG2BaseFieldAsPointer(t *testing.T) { +func testG2BaseFieldAsPointer(suite suite.Suite) { var emptyField bls12_377.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestG2BaseFieldFromBytes(t *testing.T) { +func testG2BaseFieldFromBytes(suite suite.Suite) { var emptyField bls12_377.G2BaseField bytes, expected := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestG2BaseFieldToBytes(t *testing.T) { +func testG2BaseFieldToBytes(suite suite.Suite) { var emptyField bls12_377.G2BaseField expected, limbs := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type G2BaseFieldTestSuite struct { + suite.Suite +} + +func (s *G2BaseFieldTestSuite) TestG2BaseField() { + s.Run("TestG2BaseFieldFromLimbs", testWrapper(s.Suite, testG2BaseFieldFromLimbs)) + s.Run("TestG2BaseFieldGetLimbs", testWrapper(s.Suite, testG2BaseFieldGetLimbs)) + s.Run("TestG2BaseFieldOne", testWrapper(s.Suite, testG2BaseFieldOne)) + s.Run("TestG2BaseFieldZero", testWrapper(s.Suite, testG2BaseFieldZero)) + s.Run("TestG2BaseFieldSize", testWrapper(s.Suite, testG2BaseFieldSize)) + s.Run("TestG2BaseFieldAsPointer", testWrapper(s.Suite, testG2BaseFieldAsPointer)) + s.Run("TestG2BaseFieldFromBytes", testWrapper(s.Suite, testG2BaseFieldFromBytes)) + s.Run("TestG2BaseFieldToBytes", testWrapper(s.Suite, testG2BaseFieldToBytes)) + +} + +func TestSuiteG2BaseField(t *testing.T) { + suite.Run(t, new(G2BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bls12377/tests/g2_msm_test.go b/wrappers/golang/curves/bls12377/tests/g2_msm_test.go index 76b93f924..dc1cacfc6 100644 --- a/wrappers/golang/curves/bls12377/tests/g2_msm_test.go +++ b/wrappers/golang/curves/bls12377/tests/g2_msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377" @@ -62,7 +62,7 @@ func projectiveToGnarkAffineG2(p g2.G2Projective) bls12377.G2Affine { return *g2Affine.FromJacobian(&g2Jac) } -func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBls12_377.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2(suite suite.Suite, scalars core.HostSlice[icicleBls12_377.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -74,10 +74,10 @@ func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBls1 pointsFp[i] = projectiveToGnarkAffineG2(v.ToProjective()) } - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12377.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12377.G2Affine], out g2.G2Projective) { var msmRes bls12377.G2Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -86,7 +86,7 @@ func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.Ho icicleResAffine := projectiveToGnarkAffineG2(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []bls12377.G2Affine { @@ -119,7 +119,7 @@ func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []bls12377.G2Af return points } -func TestMSMG2(t *testing.T) { +func testMSMG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -133,11 +133,11 @@ func TestMSMG2(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -145,11 +145,11 @@ func TestMSMG2(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2GnarkCryptoTypes(t *testing.T) { +func testMSMG2GnarkCryptoTypes(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -169,22 +169,22 @@ func TestMSMG2GnarkCryptoTypes(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = g2.G2Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMG2Batch(t *testing.T) { +func testMSMG2Batch(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -197,10 +197,10 @@ func TestMSMG2Batch(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -209,15 +209,15 @@ func TestMSMG2Batch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsG2(t *testing.T) { +func testPrecomputePointsG2(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 @@ -234,20 +234,20 @@ func TestPrecomputePointsG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -257,13 +257,13 @@ func TestPrecomputePointsG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBasesG2(t *testing.T) { +func testPrecomputePointsSharedBasesG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -279,18 +279,18 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -300,13 +300,13 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMG2SkewedDistribution(t *testing.T) { +func testMSMG2SkewedDistribution(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -325,19 +325,19 @@ func TestMSMG2SkewedDistribution(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2MultiDevice(t *testing.T) { +func testMSMG2MultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -361,11 +361,11 @@ func TestMSMG2MultiDevice(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -373,9 +373,27 @@ func TestMSMG2MultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMG2TestSuite struct { + suite.Suite +} + +func (s *MSMG2TestSuite) TestMSMG2() { + s.Run("TestMSMG2", testWrapper(s.Suite, testMSMG2)) + s.Run("TestMSMG2GnarkCryptoTypes", testWrapper(s.Suite, testMSMG2GnarkCryptoTypes)) + s.Run("TestMSMG2Batch", testWrapper(s.Suite, testMSMG2Batch)) + s.Run("TestPrecomputePointsG2", testWrapper(s.Suite, testPrecomputePointsG2)) + s.Run("TestPrecomputePointsSharedBasesG2", testWrapper(s.Suite, testPrecomputePointsSharedBasesG2)) + s.Run("TestMSMG2SkewedDistribution", testWrapper(s.Suite, testMSMG2SkewedDistribution)) + s.Run("TestMSMG2MultiDevice", testWrapper(s.Suite, testMSMG2MultiDevice)) +} + +func TestSuiteMSMG2(t *testing.T) { + suite.Run(t, new(MSMG2TestSuite)) +} diff --git a/wrappers/golang/curves/bls12377/tests/main_test.go b/wrappers/golang/curves/bls12377/tests/main_test.go index da9bfdfdd..c04a8fbd8 100644 --- a/wrappers/golang/curves/bls12377/tests/main_test.go +++ b/wrappers/golang/curves/bls12377/tests/main_test.go @@ -2,12 +2,14 @@ package tests import ( "fmt" - "testing" - "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377" ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + "github.com/stretchr/testify/suite" + "os" + "sync" + "testing" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft" ) @@ -16,7 +18,10 @@ const ( largestTestSize = 20 ) -var DEVICE runtime.Device +var ( + DEVICE runtime.Device + exitCode int +) func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcicleError { rouMont, _ := fft.Generator(uint64(1 << largestTestSize)) @@ -29,6 +34,18 @@ func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcic return e } +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} + func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() devices, e := runtime.GetRegisteredDevices() @@ -36,6 +53,7 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) @@ -50,8 +68,10 @@ func TestMain(m *testing.M) { } } + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() // release domain e = ntt.ReleaseDomain() @@ -63,4 +83,6 @@ func TestMain(m *testing.M) { } } } + + os.Exit(exitCode) } diff --git a/wrappers/golang/curves/bls12377/tests/msm_test.go b/wrappers/golang/curves/bls12377/tests/msm_test.go index 92252d377..3563bd3b9 100644 --- a/wrappers/golang/curves/bls12377/tests/msm_test.go +++ b/wrappers/golang/curves/bls12377/tests/msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377" @@ -35,7 +35,7 @@ func projectiveToGnarkAffine(p icicleBls12_377.Projective) bls12377.G1Affine { return bls12377.G1Affine{X: *x, Y: *y} } -func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBls12_377.ScalarField], points core.HostSlice[icicleBls12_377.Affine], out icicleBls12_377.Projective) { +func testAgainstGnarkCryptoMsm(suite suite.Suite, scalars core.HostSlice[icicleBls12_377.ScalarField], points core.HostSlice[icicleBls12_377.Affine], out icicleBls12_377.Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -47,10 +47,10 @@ func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBls12_ pointsFp[i] = projectiveToGnarkAffine(v.ToProjective()) } - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12377.G1Affine], out icicleBls12_377.Projective) { +func testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12377.G1Affine], out icicleBls12_377.Projective) { var msmRes bls12377.G1Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -59,7 +59,7 @@ func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.Host icicleResAffine := projectiveToGnarkAffine(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleAffineToG1Affine(iciclePoints []icicleBls12_377.Affine) []bls12377.G1Affine { @@ -79,7 +79,7 @@ func convertIcicleAffineToG1Affine(iciclePoints []icicleBls12_377.Affine) []bls1 return points } -func TestMSM(t *testing.T) { +func testMSM(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -93,11 +93,11 @@ func TestMSM(t *testing.T) { var p icicleBls12_377.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -105,11 +105,11 @@ func TestMSM(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMGnarkCryptoTypes(t *testing.T) { +func testMSMGnarkCryptoTypes(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -129,22 +129,22 @@ func TestMSMGnarkCryptoTypes(t *testing.T) { var p icicleBls12_377.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = msm.Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMBatch(t *testing.T) { +func testMSMBatch(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -157,10 +157,10 @@ func TestMSMBatch(t *testing.T) { var p icicleBls12_377.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -169,15 +169,15 @@ func TestMSMBatch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePoints(t *testing.T) { +func testPrecomputePoints(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 @@ -194,20 +194,20 @@ func TestPrecomputePoints(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBls12_377.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -217,13 +217,13 @@ func TestPrecomputePoints(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBases(t *testing.T) { +func testPrecomputePointsSharedBases(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -239,18 +239,18 @@ func TestPrecomputePointsSharedBases(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBls12_377.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -260,13 +260,13 @@ func TestPrecomputePointsSharedBases(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMSkewedDistribution(t *testing.T) { +func testMSMSkewedDistribution(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -285,19 +285,19 @@ func TestMSMSkewedDistribution(t *testing.T) { var p icicleBls12_377.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMMultiDevice(t *testing.T) { +func testMSMMultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -321,11 +321,11 @@ func TestMSMMultiDevice(t *testing.T) { var p icicleBls12_377.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_377.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -333,9 +333,27 @@ func TestMSMMultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMTestSuite struct { + suite.Suite +} + +func (s *MSMTestSuite) TestMSM() { + s.Run("TestMSM", testWrapper(s.Suite, testMSM)) + s.Run("TestMSMGnarkCryptoTypes", testWrapper(s.Suite, testMSMGnarkCryptoTypes)) + s.Run("TestMSMBatch", testWrapper(s.Suite, testMSMBatch)) + s.Run("TestPrecomputePoints", testWrapper(s.Suite, testPrecomputePoints)) + s.Run("TestPrecomputePointsSharedBases", testWrapper(s.Suite, testPrecomputePointsSharedBases)) + s.Run("TestMSMSkewedDistribution", testWrapper(s.Suite, testMSMSkewedDistribution)) + s.Run("TestMSMMultiDevice", testWrapper(s.Suite, testMSMMultiDevice)) +} + +func TestSuiteMSM(t *testing.T) { + suite.Run(t, new(MSMTestSuite)) +} diff --git a/wrappers/golang/curves/bls12377/tests/ntt_test.go b/wrappers/golang/curves/bls12377/tests/ntt_test.go index 77c7121e1..74a23e669 100644 --- a/wrappers/golang/curves/bls12377/tests/ntt_test.go +++ b/wrappers/golang/curves/bls12377/tests/ntt_test.go @@ -11,10 +11,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bls12_377.ScalarField], output core.HostSlice[bls12_377.ScalarField], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNtt(suite suite.Suite, size int, scalars core.HostSlice[bls12_377.ScalarField], output core.HostSlice[bls12_377.ScalarField], order core.Ordering, direction core.NTTDir) { scalarsFr := make([]fr.Element, size) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -26,10 +26,10 @@ func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bl outputAsFr[i] = slice64 } - testAgainstGnarkCryptoNttGnarkTypes(t, size, scalarsFr, outputAsFr, order, direction) + testAgainstGnarkCryptoNttGnarkTypes(suite, size, scalarsFr, outputAsFr, order, direction) } -func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNttGnarkTypes(suite suite.Suite, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { domainWithPrecompute := fft.NewDomain(uint64(size)) // DIT + BitReverse == Ordering.kRR // DIT == Ordering.kRN @@ -51,25 +51,19 @@ func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core. if order == core.KNN || order == core.KRR { fft.BitReverse(scalarsFr) } - assert.Equal(t, scalarsFr, outputAsFr) + suite.Equal(scalarsFr, outputAsFr) } -func TestNTTGetDefaultConfig(t *testing.T) { +func testNTTGetDefaultConfig(suite suite.Suite) { actual := ntt.GetDefaultNttConfig() expected := test_helpers.GenerateLimbOne(int(bls12_377.SCALAR_LIMBS)) - assert.Equal(t, expected, actual.CosetGen[:]) + suite.Equal(expected, actual.CosetGen[:]) cosetGenField := bls12_377.ScalarField{} cosetGenField.One() - assert.ElementsMatch(t, cosetGenField.GetLimbs(), actual.CosetGen) + suite.ElementsMatch(cosetGenField.GetLimbs(), actual.CosetGen) } -func TestInitDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - cfg := core.GetDefaultNTTInitDomainConfig() - assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) }) -} - -func TestNtt(t *testing.T) { +func testNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bls12_377.GenerateScalars(1 << largestTestSize) @@ -87,11 +81,11 @@ func TestNtt(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttFrElement(t *testing.T) { +func testNttFrElement(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := make([]fr.Element, 4) var x fr.Element @@ -114,12 +108,12 @@ func TestNttFrElement(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNttGnarkTypes(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNttGnarkTypes(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttDeviceAsync(t *testing.T) { +func testNttDeviceAsync(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bls12_377.GenerateScalars(1 << largestTestSize) @@ -150,13 +144,13 @@ func TestNttDeviceAsync(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, direction) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, direction) } } } } -func TestNttBatch(t *testing.T) { +func testNttBatch(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 10 largestBatchSize := 20 @@ -194,16 +188,26 @@ func TestNttBatch(t *testing.T) { domainWithPrecompute.FFT(scalarsFr, fft.DIF) fft.BitReverse(scalarsFr) - if !assert.True(t, reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { - t.FailNow() + if !suite.True(reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { + suite.T().FailNow() } } } } } -func TestReleaseDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - e := ntt.ReleaseDomain() - assert.Equal(t, runtime.Success, e, "ReleasDomain failed") +type NTTTestSuite struct { + suite.Suite +} + +func (s *NTTTestSuite) TestNTT() { + s.Run("TestNTTGetDefaultConfig", testWrapper(s.Suite, testNTTGetDefaultConfig)) + s.Run("TestNTT", testWrapper(s.Suite, testNtt)) + s.Run("TestNTTFrElement", testWrapper(s.Suite, testNttFrElement)) + s.Run("TestNttDeviceAsync", testWrapper(s.Suite, testNttDeviceAsync)) + s.Run("TestNttBatch", testWrapper(s.Suite, testNttBatch)) +} + +func TestSuiteNTT(t *testing.T) { + suite.Run(t, new(NTTTestSuite)) } diff --git a/wrappers/golang/curves/bls12377/tests/polynomial_test.go b/wrappers/golang/curves/bls12377/tests/polynomial_test.go index a7aca6d82..1dfd58cb4 100644 --- a/wrappers/golang/curves/bls12377/tests/polynomial_test.go +++ b/wrappers/golang/curves/bls12377/tests/polynomial_test.go @@ -6,10 +6,9 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377" - // "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/polynomial" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) var one, two, three, four, five bls12_377.ScalarField @@ -41,7 +40,7 @@ func vecOp(a, b bls12_377.ScalarField, op core.VecOps) bls12_377.ScalarField { return out[0] } -func TestPolyCreateFromCoefficients(t *testing.T) { +func testPolyCreateFromCoefficients(suite suite.Suite) { scalars := bls12_377.GenerateScalars(33) var uniPoly polynomial.DensePolynomial @@ -49,7 +48,7 @@ func TestPolyCreateFromCoefficients(t *testing.T) { poly.Print() } -func TestPolyEval(t *testing.T) { +func testPolyEval(suite suite.Suite) { // testing correct evaluation of f(8) for f(x)=4x^2+2x+5 coeffs := core.HostSliceFromElements([]bls12_377.ScalarField{five, two, four}) var f polynomial.DensePolynomial @@ -62,10 +61,10 @@ func TestPolyEval(t *testing.T) { evals := make(core.HostSlice[bls12_377.ScalarField], 1) fEvaled := f.EvalOnDomain(domains, evals) var expected bls12_377.ScalarField - assert.Equal(t, expected.FromUint32(277), fEvaled.(core.HostSlice[bls12_377.ScalarField])[0]) + suite.Equal(expected.FromUint32(277), fEvaled.(core.HostSlice[bls12_377.ScalarField])[0]) } -func TestPolyClone(t *testing.T) { +func testPolyClone(suite suite.Suite) { f := randomPoly(8) x := rand() fx := f.Eval(x) @@ -76,11 +75,11 @@ func TestPolyClone(t *testing.T) { gx := g.Eval(x) fgx := fg.Eval(x) - assert.Equal(t, fx, gx) - assert.Equal(t, vecOp(fx, gx, core.Add), fgx) + suite.Equal(fx, gx) + suite.Equal(vecOp(fx, gx, core.Add), fgx) } -func TestPolyAddSubMul(t *testing.T) { +func testPolyAddSubMul(suite suite.Suite) { testSize := 1 << 10 f := randomPoly(testSize) g := randomPoly(testSize) @@ -91,26 +90,26 @@ func TestPolyAddSubMul(t *testing.T) { polyAdd := f.Add(&g) fxAddgx := vecOp(fx, gx, core.Add) - assert.Equal(t, polyAdd.Eval(x), fxAddgx) + suite.Equal(polyAdd.Eval(x), fxAddgx) polySub := f.Subtract(&g) fxSubgx := vecOp(fx, gx, core.Sub) - assert.Equal(t, polySub.Eval(x), fxSubgx) + suite.Equal(polySub.Eval(x), fxSubgx) polyMul := f.Multiply(&g) fxMulgx := vecOp(fx, gx, core.Mul) - assert.Equal(t, polyMul.Eval(x), fxMulgx) + suite.Equal(polyMul.Eval(x), fxMulgx) s1 := rand() polMulS1 := f.MultiplyByScalar(s1) - assert.Equal(t, polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) + suite.Equal(polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) s2 := rand() polMulS2 := f.MultiplyByScalar(s2) - assert.Equal(t, polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) + suite.Equal(polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) } -func TestPolyMonomials(t *testing.T) { +func testPolyMonomials(suite suite.Suite) { var zero bls12_377.ScalarField var f polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements([]bls12_377.ScalarField{one, zero, two})) @@ -119,20 +118,20 @@ func TestPolyMonomials(t *testing.T) { fx := f.Eval(x) f.AddMonomial(three, 1) fxAdded := f.Eval(x) - assert.Equal(t, fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) + suite.Equal(fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) f.SubMonomial(one, 0) fxSub := f.Eval(x) - assert.Equal(t, fxSub, vecOp(fxAdded, one, core.Sub)) + suite.Equal(fxSub, vecOp(fxAdded, one, core.Sub)) } -func TestPolyReadCoeffs(t *testing.T) { +func testPolyReadCoeffs(suite suite.Suite) { var f polynomial.DensePolynomial coeffs := core.HostSliceFromElements([]bls12_377.ScalarField{one, two, three, four}) f.CreateFromCoeffecitients(coeffs) coeffsCopied := make(core.HostSlice[bls12_377.ScalarField], coeffs.Len()) _, _ = f.CopyCoeffsRange(0, coeffs.Len()-1, coeffsCopied) - assert.ElementsMatch(t, coeffs, coeffsCopied) + suite.ElementsMatch(coeffs, coeffsCopied) var coeffsDevice core.DeviceSlice coeffsDevice.Malloc(one.Size(), coeffs.Len()) @@ -140,16 +139,16 @@ func TestPolyReadCoeffs(t *testing.T) { coeffsHost := make(core.HostSlice[bls12_377.ScalarField], coeffs.Len()) coeffsHost.CopyFromDevice(&coeffsDevice) - assert.ElementsMatch(t, coeffs, coeffsHost) + suite.ElementsMatch(coeffs, coeffsHost) } -func TestPolyOddEvenSlicing(t *testing.T) { +func testPolyOddEvenSlicing(suite suite.Suite) { size := 1<<10 - 3 f := randomPoly(size) even := f.Even() odd := f.Odd() - assert.Equal(t, f.Degree(), even.Degree()+odd.Degree()+1) + suite.Equal(f.Degree(), even.Degree()+odd.Degree()+1) x := rand() var evenExpected, oddExpected bls12_377.ScalarField @@ -164,13 +163,13 @@ func TestPolyOddEvenSlicing(t *testing.T) { } evenEvaled := even.Eval(x) - assert.Equal(t, evenExpected, evenEvaled) + suite.Equal(evenExpected, evenEvaled) oddEvaled := odd.Eval(x) - assert.Equal(t, oddExpected, oddEvaled) + suite.Equal(oddExpected, oddEvaled) } -func TestPolynomialDivision(t *testing.T) { +func testPolynomialDivision(suite suite.Suite) { // divide f(x)/g(x), compute q(x), r(x) and check f(x)=q(x)*g(x)+r(x) var f, g polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements(bls12_377.GenerateScalars(1 << 4))) @@ -184,10 +183,10 @@ func TestPolynomialDivision(t *testing.T) { x := bls12_377.GenerateScalars(1)[0] fEval := f.Eval(x) fReconEval := fRecon.Eval(x) - assert.Equal(t, fEval, fReconEval) + suite.Equal(fEval, fReconEval) } -func TestDivideByVanishing(t *testing.T) { +func testDivideByVanishing(suite suite.Suite) { // poly of x^4-1 vanishes ad 4th rou var zero bls12_377.ScalarField minus_one := vecOp(zero, one, core.Sub) @@ -200,31 +199,51 @@ func TestDivideByVanishing(t *testing.T) { fv := f.Multiply(&v) fDegree := f.Degree() fvDegree := fv.Degree() - assert.Equal(t, fDegree+4, fvDegree) + suite.Equal(fDegree+4, fvDegree) fReconstructed := fv.DivideByVanishing(4) - assert.Equal(t, fDegree, fReconstructed.Degree()) + suite.Equal(fDegree, fReconstructed.Degree()) x := rand() - assert.Equal(t, f.Eval(x), fReconstructed.Eval(x)) + suite.Equal(f.Eval(x), fReconstructed.Eval(x)) } -// func TestPolySlice(t *testing.T) { +// func TestPolySlice(suite suite.Suite) { // size := 4 // coeffs := bls12_377.GenerateScalars(size) // var f DensePolynomial // f.CreateFromCoeffecitients(coeffs) // fSlice := f.AsSlice() -// assert.True(t, fSlice.IsOnDevice()) -// assert.Equal(t, size, fSlice.Len()) +// suite.True(fSlice.IsOnDevice()) +// suite.Equal(size, fSlice.Len()) // hostSlice := make(core.HostSlice[bls12_377.ScalarField], size) // hostSlice.CopyFromDevice(fSlice) -// assert.Equal(t, coeffs, hostSlice) +// suite.Equal(coeffs, hostSlice) // cfg := ntt.GetDefaultNttConfig() // res := make(core.HostSlice[bls12_377.ScalarField], size) // ntt.Ntt(fSlice, core.KForward, cfg, res) -// assert.Equal(t, f.Eval(one), res[0]) +// suite.Equal(f.Eval(one), res[0]) // } + +type PolynomialTestSuite struct { + suite.Suite +} + +func (s *PolynomialTestSuite) TestPolynomial() { + s.Run("TestPolyCreateFromCoefficients", testWrapper(s.Suite, testPolyCreateFromCoefficients)) + s.Run("TestPolyEval", testWrapper(s.Suite, testPolyEval)) + s.Run("TestPolyClone", testWrapper(s.Suite, testPolyClone)) + s.Run("TestPolyAddSubMul", testWrapper(s.Suite, testPolyAddSubMul)) + s.Run("TestPolyMonomials", testWrapper(s.Suite, testPolyMonomials)) + s.Run("TestPolyReadCoeffs", testWrapper(s.Suite, testPolyReadCoeffs)) + s.Run("TestPolyOddEvenSlicing", testWrapper(s.Suite, testPolyOddEvenSlicing)) + s.Run("TestPolynomialDivision", testWrapper(s.Suite, testPolynomialDivision)) + s.Run("TestDivideByVanishing", testWrapper(s.Suite, testDivideByVanishing)) +} + +func TestSuitePolynomial(t *testing.T) { + suite.Run(t, new(PolynomialTestSuite)) +} diff --git a/wrappers/golang/curves/bls12377/tests/scalar_field_test.go b/wrappers/golang/curves/bls12377/tests/scalar_field_test.go index 5592a07e4..a3b57924e 100644 --- a/wrappers/golang/curves/bls12377/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bls12377/tests/scalar_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( SCALAR_LIMBS = bls12_377.SCALAR_LIMBS ) -func TestScalarFieldFromLimbs(t *testing.T) { +func testScalarFieldFromLimbs(suite suite.Suite) { emptyField := bls12_377.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestScalarFieldGetLimbs(t *testing.T) { +func testScalarFieldGetLimbs(suite suite.Suite) { emptyField := bls12_377.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") } -func TestScalarFieldOne(t *testing.T) { +func testScalarFieldOne(suite suite.Suite) { var emptyField bls12_377.ScalarField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(SCALAR_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") } -func TestScalarFieldZero(t *testing.T) { +func testScalarFieldZero(suite suite.Suite) { var emptyField bls12_377.ScalarField emptyField.Zero() limbsZero := make([]uint32, SCALAR_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") } -func TestScalarFieldSize(t *testing.T) { +func testScalarFieldSize(suite suite.Suite) { var emptyField bls12_377.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestScalarFieldAsPointer(t *testing.T) { +func testScalarFieldAsPointer(suite suite.Suite) { var emptyField bls12_377.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestScalarFieldFromBytes(t *testing.T) { +func testScalarFieldFromBytes(suite suite.Suite) { var emptyField bls12_377.ScalarField bytes, expected := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestScalarFieldToBytes(t *testing.T) { +func testScalarFieldToBytes(suite suite.Suite) { var emptyField bls12_377.ScalarField expected, limbs := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestBls12_377GenerateScalars(t *testing.T) { +func testBls12_377GenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := bls12_377.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := bls12_377.ScalarField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestBls12_377MongtomeryConversion(t *testing.T) { +func testBls12_377MongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := bls12_377.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestBls12_377MongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[bls12_377.ScalarField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) bls12_377.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ScalarFieldTestSuite struct { + suite.Suite +} + +func (s *ScalarFieldTestSuite) TestScalarField() { + s.Run("TestScalarFieldFromLimbs", testWrapper(s.Suite, testScalarFieldFromLimbs)) + s.Run("TestScalarFieldGetLimbs", testWrapper(s.Suite, testScalarFieldGetLimbs)) + s.Run("TestScalarFieldOne", testWrapper(s.Suite, testScalarFieldOne)) + s.Run("TestScalarFieldZero", testWrapper(s.Suite, testScalarFieldZero)) + s.Run("TestScalarFieldSize", testWrapper(s.Suite, testScalarFieldSize)) + s.Run("TestScalarFieldAsPointer", testWrapper(s.Suite, testScalarFieldAsPointer)) + s.Run("TestScalarFieldFromBytes", testWrapper(s.Suite, testScalarFieldFromBytes)) + s.Run("TestScalarFieldToBytes", testWrapper(s.Suite, testScalarFieldToBytes)) + s.Run("TestBls12_377GenerateScalars", testWrapper(s.Suite, testBls12_377GenerateScalars)) + s.Run("TestBls12_377MongtomeryConversion", testWrapper(s.Suite, testBls12_377MongtomeryConversion)) +} + +func TestSuiteScalarField(t *testing.T) { + suite.Run(t, new(ScalarFieldTestSuite)) } diff --git a/wrappers/golang/curves/bls12377/tests/vec_ops_test.go b/wrappers/golang/curves/bls12377/tests/vec_ops_test.go index 976e11c34..11a27d330 100644 --- a/wrappers/golang/curves/bls12377/tests/vec_ops_test.go +++ b/wrappers/golang/curves/bls12377/tests/vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_377 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12377/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestBls12_377VecOps(t *testing.T) { +func testBls12_377VecOps(suite suite.Suite) { testSize := 1 << 14 a := bls12_377.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestBls12_377VecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestBls12_377Transpose(t *testing.T) { +func testBls12_377Transpose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestBls12_377Transpose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestBls12_377Transpose(t *testing.T) { output := make(core.HostSlice[bls12_377.ScalarField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type Bls12_377VecOpsTestSuite struct { + suite.Suite +} + +func (s *Bls12_377VecOpsTestSuite) TestBls12_377VecOps() { + s.Run("TestBls12_377VecOps", testWrapper(s.Suite, testBls12_377VecOps)) + s.Run("TestBls12_377Transpose", testWrapper(s.Suite, testBls12_377Transpose)) +} + +func TestSuiteBls12_377VecOps(t *testing.T) { + suite.Run(t, new(Bls12_377VecOpsTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/base_field.go b/wrappers/golang/curves/bls12381/base_field.go index 724c8bca9..536af80e2 100644 --- a/wrappers/golang/curves/bls12381/base_field.go +++ b/wrappers/golang/curves/bls12381/base_field.go @@ -53,6 +53,16 @@ func (f *BaseField) Zero() BaseField { return *f } +func (f *BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *BaseField) One() BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12381/curve.go b/wrappers/golang/curves/bls12381/curve.go index 241761248..c0518e91a 100644 --- a/wrappers/golang/curves/bls12381/curve.go +++ b/wrappers/golang/curves/bls12381/curve.go @@ -96,6 +96,10 @@ func (a *Affine) Zero() Affine { return *a } +func (a *Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *Affine) FromLimbs(x, y []uint32) Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *Affine) FromLimbs(x, y []uint32) Affine { func (a Affine) ToProjective() Projective { var p Projective - cA := (*C.affine_t)(unsafe.Pointer(&a)) - cP := (*C.projective_t)(unsafe.Pointer(&p)) - C.bls12_381_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.affine_t)(unsafe.Pointer(&a)) + // cP := (*C.projective_t)(unsafe.Pointer(&p)) + // C.bls12_381_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bls12381/g2/curve.go b/wrappers/golang/curves/bls12381/g2/curve.go index fa6dad5d3..fe1a303e9 100644 --- a/wrappers/golang/curves/bls12381/g2/curve.go +++ b/wrappers/golang/curves/bls12381/g2/curve.go @@ -96,6 +96,10 @@ func (a *G2Affine) Zero() G2Affine { return *a } +func (a *G2Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { func (a G2Affine) ToProjective() G2Projective { var p G2Projective - cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) - cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) - C.bls12_381_g2_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) + // cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + // C.bls12_381_g2_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bls12381/g2/g2base_field.go b/wrappers/golang/curves/bls12381/g2/g2base_field.go index c4073af97..8087a2bbf 100644 --- a/wrappers/golang/curves/bls12381/g2/g2base_field.go +++ b/wrappers/golang/curves/bls12381/g2/g2base_field.go @@ -53,6 +53,16 @@ func (f *G2BaseField) Zero() G2BaseField { return *f } +func (f *G2BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *G2BaseField) One() G2BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12381/scalar_field.go b/wrappers/golang/curves/bls12381/scalar_field.go index 5757d0d56..89092939a 100644 --- a/wrappers/golang/curves/bls12381/scalar_field.go +++ b/wrappers/golang/curves/bls12381/scalar_field.go @@ -60,6 +60,16 @@ func (f *ScalarField) Zero() ScalarField { return *f } +func (f *ScalarField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ScalarField) One() ScalarField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bls12381/tests/base_field_test.go b/wrappers/golang/curves/bls12381/tests/base_field_test.go index 4c4ca22fe..2f37ef300 100644 --- a/wrappers/golang/curves/bls12381/tests/base_field_test.go +++ b/wrappers/golang/curves/bls12381/tests/base_field_test.go @@ -5,85 +5,105 @@ import ( bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( BASE_LIMBS = bls12_381.BASE_LIMBS ) -func TestBaseFieldFromLimbs(t *testing.T) { +func testBaseFieldFromLimbs(suite suite.Suite) { emptyField := bls12_381.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestBaseFieldGetLimbs(t *testing.T) { +func testBaseFieldGetLimbs(suite suite.Suite) { emptyField := bls12_381.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") } -func TestBaseFieldOne(t *testing.T) { +func testBaseFieldOne(suite suite.Suite) { var emptyField bls12_381.BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") } -func TestBaseFieldZero(t *testing.T) { +func testBaseFieldZero(suite suite.Suite) { var emptyField bls12_381.BaseField emptyField.Zero() limbsZero := make([]uint32, BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") } -func TestBaseFieldSize(t *testing.T) { +func testBaseFieldSize(suite suite.Suite) { var emptyField bls12_381.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestBaseFieldAsPointer(t *testing.T) { +func testBaseFieldAsPointer(suite suite.Suite) { var emptyField bls12_381.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestBaseFieldFromBytes(t *testing.T) { +func testBaseFieldFromBytes(suite suite.Suite) { var emptyField bls12_381.BaseField bytes, expected := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestBaseFieldToBytes(t *testing.T) { +func testBaseFieldToBytes(suite suite.Suite) { var emptyField bls12_381.BaseField expected, limbs := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type BaseFieldTestSuite struct { + suite.Suite +} + +func (s *BaseFieldTestSuite) TestBaseField() { + s.Run("TestBaseFieldFromLimbs", testWrapper(s.Suite, testBaseFieldFromLimbs)) + s.Run("TestBaseFieldGetLimbs", testWrapper(s.Suite, testBaseFieldGetLimbs)) + s.Run("TestBaseFieldOne", testWrapper(s.Suite, testBaseFieldOne)) + s.Run("TestBaseFieldZero", testWrapper(s.Suite, testBaseFieldZero)) + s.Run("TestBaseFieldSize", testWrapper(s.Suite, testBaseFieldSize)) + s.Run("TestBaseFieldAsPointer", testWrapper(s.Suite, testBaseFieldAsPointer)) + s.Run("TestBaseFieldFromBytes", testWrapper(s.Suite, testBaseFieldFromBytes)) + s.Run("TestBaseFieldToBytes", testWrapper(s.Suite, testBaseFieldToBytes)) + +} + +func TestSuiteBaseField(t *testing.T) { + suite.Run(t, new(BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/tests/curve_test.go b/wrappers/golang/curves/bls12381/tests/curve_test.go index 59e349559..750a9f020 100644 --- a/wrappers/golang/curves/bls12381/tests/curve_test.go +++ b/wrappers/golang/curves/bls12381/tests/curve_test.go @@ -5,15 +5,15 @@ import ( bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestAffineZero(t *testing.T) { +func testAffineZero(suite suite.Suite) { var fieldZero = bls12_381.BaseField{} var affineZero bls12_381.Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestAffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestAffineFromLimbs(t *testing.T) { +func testAffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var affine bls12_381.Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestAffineToProjective(t *testing.T) { +func testAffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bls12_381.BaseField @@ -49,31 +49,31 @@ func TestAffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestProjectiveZero(t *testing.T) { +func testProjectiveZero(suite suite.Suite) { var projectiveZero bls12_381.Projective projectiveZero.Zero() var fieldZero = bls12_381.BaseField{} var fieldOne bls12_381.BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var projective bls12_381.Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestProjectiveFromLimbs(t *testing.T) { +func testProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestProjectiveFromLimbs(t *testing.T) { var projective bls12_381.Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestProjectiveFromAffine(t *testing.T) { +func testProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bls12_381.BaseField @@ -100,5 +100,22 @@ func TestProjectiveFromAffine(t *testing.T) { var projectivePoint bls12_381.Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type CurveTestSuite struct { + suite.Suite +} + +func (s *CurveTestSuite) TestCurve() { + s.Run("TestAffineZero", testWrapper(s.Suite, testAffineZero)) + s.Run("TestAffineFromLimbs", testWrapper(s.Suite, testAffineFromLimbs)) + s.Run("TestAffineToProjective", testWrapper(s.Suite, testAffineToProjective)) + s.Run("TestProjectiveZero", testWrapper(s.Suite, testProjectiveZero)) + s.Run("TestProjectiveFromLimbs", testWrapper(s.Suite, testProjectiveFromLimbs)) + s.Run("TestProjectiveFromAffine", testWrapper(s.Suite, testProjectiveFromAffine)) +} + +func TestSuiteCurve(t *testing.T) { + suite.Run(t, new(CurveTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/tests/ecntt_test.go b/wrappers/golang/curves/bls12381/tests/ecntt_test.go index 9851b1a90..1b5d95c60 100644 --- a/wrappers/golang/curves/bls12381/tests/ecntt_test.go +++ b/wrappers/golang/curves/bls12381/tests/ecntt_test.go @@ -9,10 +9,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime/config_extension" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestECNtt(t *testing.T) { +func testECNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() ext := config_extension.Create() ext.SetInt(core.CUDA_NTT_ALGORITHM, int(core.Radix2)) @@ -31,7 +31,19 @@ func TestECNtt(t *testing.T) { output := make(core.HostSlice[bls12_381.Projective], testSize) e := ecntt.ECNtt(pointsCopy, core.KForward, &cfg, output) - assert.Equal(t, runtime.Success, e, "ECNtt failed") + suite.Equal(runtime.Success, e, "ECNtt failed") } } } + +type ECNttTestSuite struct { + suite.Suite +} + +func (s *ECNttTestSuite) TestECNtt() { + s.Run("TestECNtt", testWrapper(s.Suite, testECNtt)) +} + +func TestSuiteECNtt(t *testing.T) { + suite.Run(t, new(ECNttTestSuite)) +} diff --git a/wrappers/golang/curves/bls12381/tests/g2_curve_test.go b/wrappers/golang/curves/bls12381/tests/g2_curve_test.go index b0580a273..62a29328a 100644 --- a/wrappers/golang/curves/bls12381/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bls12381/tests/g2_curve_test.go @@ -5,15 +5,15 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestG2AffineZero(t *testing.T) { +func testG2AffineZero(suite suite.Suite) { var fieldZero = g2.G2BaseField{} var affineZero g2.G2Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestG2AffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestG2AffineFromLimbs(t *testing.T) { +func testG2AffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var affine g2.G2Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestG2AffineToProjective(t *testing.T) { +func testG2AffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -49,31 +49,31 @@ func TestG2AffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestG2ProjectiveZero(t *testing.T) { +func testG2ProjectiveZero(suite suite.Suite) { var projectiveZero g2.G2Projective projectiveZero.Zero() var fieldZero = g2.G2BaseField{} var fieldOne g2.G2BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestG2ProjectiveFromLimbs(t *testing.T) { +func testG2ProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestG2ProjectiveFromLimbs(t *testing.T) { var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestG2ProjectiveFromAffine(t *testing.T) { +func testG2ProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -100,5 +100,22 @@ func TestG2ProjectiveFromAffine(t *testing.T) { var projectivePoint g2.G2Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type G2CurveTestSuite struct { + suite.Suite +} + +func (s *G2CurveTestSuite) TestG2Curve() { + s.Run("TestG2AffineZero", testWrapper(s.Suite, testG2AffineZero)) + s.Run("TestG2AffineFromLimbs", testWrapper(s.Suite, testG2AffineFromLimbs)) + s.Run("TestG2AffineToProjective", testWrapper(s.Suite, testG2AffineToProjective)) + s.Run("TestG2ProjectiveZero", testWrapper(s.Suite, testG2ProjectiveZero)) + s.Run("TestG2ProjectiveFromLimbs", testWrapper(s.Suite, testG2ProjectiveFromLimbs)) + s.Run("TestG2ProjectiveFromAffine", testWrapper(s.Suite, testG2ProjectiveFromAffine)) +} + +func TestSuiteG2Curve(t *testing.T) { + suite.Run(t, new(G2CurveTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/tests/g2_g2base_field_test.go b/wrappers/golang/curves/bls12381/tests/g2_g2base_field_test.go index 069c23557..31c592ea9 100644 --- a/wrappers/golang/curves/bls12381/tests/g2_g2base_field_test.go +++ b/wrappers/golang/curves/bls12381/tests/g2_g2base_field_test.go @@ -5,85 +5,105 @@ import ( bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( G2BASE_LIMBS = bls12_381.G2BASE_LIMBS ) -func TestG2BaseFieldFromLimbs(t *testing.T) { +func testG2BaseFieldFromLimbs(suite suite.Suite) { emptyField := bls12_381.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestG2BaseFieldGetLimbs(t *testing.T) { +func testG2BaseFieldGetLimbs(suite suite.Suite) { emptyField := bls12_381.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") } -func TestG2BaseFieldOne(t *testing.T) { +func testG2BaseFieldOne(suite suite.Suite) { var emptyField bls12_381.G2BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(G2BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") } -func TestG2BaseFieldZero(t *testing.T) { +func testG2BaseFieldZero(suite suite.Suite) { var emptyField bls12_381.G2BaseField emptyField.Zero() limbsZero := make([]uint32, G2BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") } -func TestG2BaseFieldSize(t *testing.T) { +func testG2BaseFieldSize(suite suite.Suite) { var emptyField bls12_381.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestG2BaseFieldAsPointer(t *testing.T) { +func testG2BaseFieldAsPointer(suite suite.Suite) { var emptyField bls12_381.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestG2BaseFieldFromBytes(t *testing.T) { +func testG2BaseFieldFromBytes(suite suite.Suite) { var emptyField bls12_381.G2BaseField bytes, expected := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestG2BaseFieldToBytes(t *testing.T) { +func testG2BaseFieldToBytes(suite suite.Suite) { var emptyField bls12_381.G2BaseField expected, limbs := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type G2BaseFieldTestSuite struct { + suite.Suite +} + +func (s *G2BaseFieldTestSuite) TestG2BaseField() { + s.Run("TestG2BaseFieldFromLimbs", testWrapper(s.Suite, testG2BaseFieldFromLimbs)) + s.Run("TestG2BaseFieldGetLimbs", testWrapper(s.Suite, testG2BaseFieldGetLimbs)) + s.Run("TestG2BaseFieldOne", testWrapper(s.Suite, testG2BaseFieldOne)) + s.Run("TestG2BaseFieldZero", testWrapper(s.Suite, testG2BaseFieldZero)) + s.Run("TestG2BaseFieldSize", testWrapper(s.Suite, testG2BaseFieldSize)) + s.Run("TestG2BaseFieldAsPointer", testWrapper(s.Suite, testG2BaseFieldAsPointer)) + s.Run("TestG2BaseFieldFromBytes", testWrapper(s.Suite, testG2BaseFieldFromBytes)) + s.Run("TestG2BaseFieldToBytes", testWrapper(s.Suite, testG2BaseFieldToBytes)) + +} + +func TestSuiteG2BaseField(t *testing.T) { + suite.Run(t, new(G2BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/tests/g2_msm_test.go b/wrappers/golang/curves/bls12381/tests/g2_msm_test.go index 4b9708188..82e8c6dce 100644 --- a/wrappers/golang/curves/bls12381/tests/g2_msm_test.go +++ b/wrappers/golang/curves/bls12381/tests/g2_msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381" @@ -62,7 +62,7 @@ func projectiveToGnarkAffineG2(p g2.G2Projective) bls12381.G2Affine { return *g2Affine.FromJacobian(&g2Jac) } -func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBls12_381.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2(suite suite.Suite, scalars core.HostSlice[icicleBls12_381.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -74,10 +74,10 @@ func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBls1 pointsFp[i] = projectiveToGnarkAffineG2(v.ToProjective()) } - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12381.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12381.G2Affine], out g2.G2Projective) { var msmRes bls12381.G2Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -86,7 +86,7 @@ func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.Ho icicleResAffine := projectiveToGnarkAffineG2(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []bls12381.G2Affine { @@ -119,7 +119,7 @@ func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []bls12381.G2Af return points } -func TestMSMG2(t *testing.T) { +func testMSMG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -133,11 +133,11 @@ func TestMSMG2(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -145,11 +145,11 @@ func TestMSMG2(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2GnarkCryptoTypes(t *testing.T) { +func testMSMG2GnarkCryptoTypes(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -169,22 +169,22 @@ func TestMSMG2GnarkCryptoTypes(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = g2.G2Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMG2Batch(t *testing.T) { +func testMSMG2Batch(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -197,10 +197,10 @@ func TestMSMG2Batch(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -209,15 +209,15 @@ func TestMSMG2Batch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsG2(t *testing.T) { +func testPrecomputePointsG2(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 @@ -234,20 +234,20 @@ func TestPrecomputePointsG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -257,13 +257,13 @@ func TestPrecomputePointsG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBasesG2(t *testing.T) { +func testPrecomputePointsSharedBasesG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -279,18 +279,18 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -300,13 +300,13 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMG2SkewedDistribution(t *testing.T) { +func testMSMG2SkewedDistribution(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -325,19 +325,19 @@ func TestMSMG2SkewedDistribution(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2MultiDevice(t *testing.T) { +func testMSMG2MultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -361,11 +361,11 @@ func TestMSMG2MultiDevice(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -373,9 +373,27 @@ func TestMSMG2MultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMG2TestSuite struct { + suite.Suite +} + +func (s *MSMG2TestSuite) TestMSMG2() { + s.Run("TestMSMG2", testWrapper(s.Suite, testMSMG2)) + s.Run("TestMSMG2GnarkCryptoTypes", testWrapper(s.Suite, testMSMG2GnarkCryptoTypes)) + s.Run("TestMSMG2Batch", testWrapper(s.Suite, testMSMG2Batch)) + s.Run("TestPrecomputePointsG2", testWrapper(s.Suite, testPrecomputePointsG2)) + s.Run("TestPrecomputePointsSharedBasesG2", testWrapper(s.Suite, testPrecomputePointsSharedBasesG2)) + s.Run("TestMSMG2SkewedDistribution", testWrapper(s.Suite, testMSMG2SkewedDistribution)) + s.Run("TestMSMG2MultiDevice", testWrapper(s.Suite, testMSMG2MultiDevice)) +} + +func TestSuiteMSMG2(t *testing.T) { + suite.Run(t, new(MSMG2TestSuite)) +} diff --git a/wrappers/golang/curves/bls12381/tests/main_test.go b/wrappers/golang/curves/bls12381/tests/main_test.go index 2bda7ba6c..3e9ac72c2 100644 --- a/wrappers/golang/curves/bls12381/tests/main_test.go +++ b/wrappers/golang/curves/bls12381/tests/main_test.go @@ -2,12 +2,14 @@ package tests import ( "fmt" - "testing" - "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381" ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + "github.com/stretchr/testify/suite" + "os" + "sync" + "testing" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/fft" ) @@ -16,7 +18,10 @@ const ( largestTestSize = 20 ) -var DEVICE runtime.Device +var ( + DEVICE runtime.Device + exitCode int +) func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcicleError { rouMont, _ := fft.Generator(uint64(1 << largestTestSize)) @@ -29,6 +34,18 @@ func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcic return e } +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} + func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() devices, e := runtime.GetRegisteredDevices() @@ -36,6 +53,7 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) @@ -50,8 +68,10 @@ func TestMain(m *testing.M) { } } + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() // release domain e = ntt.ReleaseDomain() @@ -63,4 +83,6 @@ func TestMain(m *testing.M) { } } } + + os.Exit(exitCode) } diff --git a/wrappers/golang/curves/bls12381/tests/msm_test.go b/wrappers/golang/curves/bls12381/tests/msm_test.go index 50acdd6dc..4e9869927 100644 --- a/wrappers/golang/curves/bls12381/tests/msm_test.go +++ b/wrappers/golang/curves/bls12381/tests/msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381" @@ -35,7 +35,7 @@ func projectiveToGnarkAffine(p icicleBls12_381.Projective) bls12381.G1Affine { return bls12381.G1Affine{X: *x, Y: *y} } -func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBls12_381.ScalarField], points core.HostSlice[icicleBls12_381.Affine], out icicleBls12_381.Projective) { +func testAgainstGnarkCryptoMsm(suite suite.Suite, scalars core.HostSlice[icicleBls12_381.ScalarField], points core.HostSlice[icicleBls12_381.Affine], out icicleBls12_381.Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -47,10 +47,10 @@ func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBls12_ pointsFp[i] = projectiveToGnarkAffine(v.ToProjective()) } - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12381.G1Affine], out icicleBls12_381.Projective) { +func testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bls12381.G1Affine], out icicleBls12_381.Projective) { var msmRes bls12381.G1Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -59,7 +59,7 @@ func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.Host icicleResAffine := projectiveToGnarkAffine(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleAffineToG1Affine(iciclePoints []icicleBls12_381.Affine) []bls12381.G1Affine { @@ -79,7 +79,7 @@ func convertIcicleAffineToG1Affine(iciclePoints []icicleBls12_381.Affine) []bls1 return points } -func TestMSM(t *testing.T) { +func testMSM(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -93,11 +93,11 @@ func TestMSM(t *testing.T) { var p icicleBls12_381.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -105,11 +105,11 @@ func TestMSM(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMGnarkCryptoTypes(t *testing.T) { +func testMSMGnarkCryptoTypes(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -129,22 +129,22 @@ func TestMSMGnarkCryptoTypes(t *testing.T) { var p icicleBls12_381.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = msm.Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMBatch(t *testing.T) { +func testMSMBatch(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -157,10 +157,10 @@ func TestMSMBatch(t *testing.T) { var p icicleBls12_381.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -169,15 +169,15 @@ func TestMSMBatch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePoints(t *testing.T) { +func testPrecomputePoints(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 @@ -194,20 +194,20 @@ func TestPrecomputePoints(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBls12_381.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -217,13 +217,13 @@ func TestPrecomputePoints(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBases(t *testing.T) { +func testPrecomputePointsSharedBases(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -239,18 +239,18 @@ func TestPrecomputePointsSharedBases(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBls12_381.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -260,13 +260,13 @@ func TestPrecomputePointsSharedBases(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMSkewedDistribution(t *testing.T) { +func testMSMSkewedDistribution(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -285,19 +285,19 @@ func TestMSMSkewedDistribution(t *testing.T) { var p icicleBls12_381.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMMultiDevice(t *testing.T) { +func testMSMMultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -321,11 +321,11 @@ func TestMSMMultiDevice(t *testing.T) { var p icicleBls12_381.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBls12_381.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -333,9 +333,27 @@ func TestMSMMultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMTestSuite struct { + suite.Suite +} + +func (s *MSMTestSuite) TestMSM() { + s.Run("TestMSM", testWrapper(s.Suite, testMSM)) + s.Run("TestMSMGnarkCryptoTypes", testWrapper(s.Suite, testMSMGnarkCryptoTypes)) + s.Run("TestMSMBatch", testWrapper(s.Suite, testMSMBatch)) + s.Run("TestPrecomputePoints", testWrapper(s.Suite, testPrecomputePoints)) + s.Run("TestPrecomputePointsSharedBases", testWrapper(s.Suite, testPrecomputePointsSharedBases)) + s.Run("TestMSMSkewedDistribution", testWrapper(s.Suite, testMSMSkewedDistribution)) + s.Run("TestMSMMultiDevice", testWrapper(s.Suite, testMSMMultiDevice)) +} + +func TestSuiteMSM(t *testing.T) { + suite.Run(t, new(MSMTestSuite)) +} diff --git a/wrappers/golang/curves/bls12381/tests/ntt_test.go b/wrappers/golang/curves/bls12381/tests/ntt_test.go index e3143f352..c353a1a50 100644 --- a/wrappers/golang/curves/bls12381/tests/ntt_test.go +++ b/wrappers/golang/curves/bls12381/tests/ntt_test.go @@ -11,10 +11,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bls12_381.ScalarField], output core.HostSlice[bls12_381.ScalarField], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNtt(suite suite.Suite, size int, scalars core.HostSlice[bls12_381.ScalarField], output core.HostSlice[bls12_381.ScalarField], order core.Ordering, direction core.NTTDir) { scalarsFr := make([]fr.Element, size) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -26,10 +26,10 @@ func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bl outputAsFr[i] = slice64 } - testAgainstGnarkCryptoNttGnarkTypes(t, size, scalarsFr, outputAsFr, order, direction) + testAgainstGnarkCryptoNttGnarkTypes(suite, size, scalarsFr, outputAsFr, order, direction) } -func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNttGnarkTypes(suite suite.Suite, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { domainWithPrecompute := fft.NewDomain(uint64(size)) // DIT + BitReverse == Ordering.kRR // DIT == Ordering.kRN @@ -51,25 +51,19 @@ func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core. if order == core.KNN || order == core.KRR { fft.BitReverse(scalarsFr) } - assert.Equal(t, scalarsFr, outputAsFr) + suite.Equal(scalarsFr, outputAsFr) } -func TestNTTGetDefaultConfig(t *testing.T) { +func testNTTGetDefaultConfig(suite suite.Suite) { actual := ntt.GetDefaultNttConfig() expected := test_helpers.GenerateLimbOne(int(bls12_381.SCALAR_LIMBS)) - assert.Equal(t, expected, actual.CosetGen[:]) + suite.Equal(expected, actual.CosetGen[:]) cosetGenField := bls12_381.ScalarField{} cosetGenField.One() - assert.ElementsMatch(t, cosetGenField.GetLimbs(), actual.CosetGen) + suite.ElementsMatch(cosetGenField.GetLimbs(), actual.CosetGen) } -func TestInitDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - cfg := core.GetDefaultNTTInitDomainConfig() - assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) }) -} - -func TestNtt(t *testing.T) { +func testNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bls12_381.GenerateScalars(1 << largestTestSize) @@ -87,11 +81,11 @@ func TestNtt(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttFrElement(t *testing.T) { +func testNttFrElement(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := make([]fr.Element, 4) var x fr.Element @@ -114,12 +108,12 @@ func TestNttFrElement(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNttGnarkTypes(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNttGnarkTypes(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttDeviceAsync(t *testing.T) { +func testNttDeviceAsync(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bls12_381.GenerateScalars(1 << largestTestSize) @@ -150,13 +144,13 @@ func TestNttDeviceAsync(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, direction) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, direction) } } } } -func TestNttBatch(t *testing.T) { +func testNttBatch(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 10 largestBatchSize := 20 @@ -194,16 +188,26 @@ func TestNttBatch(t *testing.T) { domainWithPrecompute.FFT(scalarsFr, fft.DIF) fft.BitReverse(scalarsFr) - if !assert.True(t, reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { - t.FailNow() + if !suite.True(reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { + suite.T().FailNow() } } } } } -func TestReleaseDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - e := ntt.ReleaseDomain() - assert.Equal(t, runtime.Success, e, "ReleasDomain failed") +type NTTTestSuite struct { + suite.Suite +} + +func (s *NTTTestSuite) TestNTT() { + s.Run("TestNTTGetDefaultConfig", testWrapper(s.Suite, testNTTGetDefaultConfig)) + s.Run("TestNTT", testWrapper(s.Suite, testNtt)) + s.Run("TestNTTFrElement", testWrapper(s.Suite, testNttFrElement)) + s.Run("TestNttDeviceAsync", testWrapper(s.Suite, testNttDeviceAsync)) + s.Run("TestNttBatch", testWrapper(s.Suite, testNttBatch)) +} + +func TestSuiteNTT(t *testing.T) { + suite.Run(t, new(NTTTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/tests/polynomial_test.go b/wrappers/golang/curves/bls12381/tests/polynomial_test.go index cc8f3c0e2..aaa921286 100644 --- a/wrappers/golang/curves/bls12381/tests/polynomial_test.go +++ b/wrappers/golang/curves/bls12381/tests/polynomial_test.go @@ -6,10 +6,9 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381" - // "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/polynomial" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) var one, two, three, four, five bls12_381.ScalarField @@ -41,7 +40,7 @@ func vecOp(a, b bls12_381.ScalarField, op core.VecOps) bls12_381.ScalarField { return out[0] } -func TestPolyCreateFromCoefficients(t *testing.T) { +func testPolyCreateFromCoefficients(suite suite.Suite) { scalars := bls12_381.GenerateScalars(33) var uniPoly polynomial.DensePolynomial @@ -49,7 +48,7 @@ func TestPolyCreateFromCoefficients(t *testing.T) { poly.Print() } -func TestPolyEval(t *testing.T) { +func testPolyEval(suite suite.Suite) { // testing correct evaluation of f(8) for f(x)=4x^2+2x+5 coeffs := core.HostSliceFromElements([]bls12_381.ScalarField{five, two, four}) var f polynomial.DensePolynomial @@ -62,10 +61,10 @@ func TestPolyEval(t *testing.T) { evals := make(core.HostSlice[bls12_381.ScalarField], 1) fEvaled := f.EvalOnDomain(domains, evals) var expected bls12_381.ScalarField - assert.Equal(t, expected.FromUint32(277), fEvaled.(core.HostSlice[bls12_381.ScalarField])[0]) + suite.Equal(expected.FromUint32(277), fEvaled.(core.HostSlice[bls12_381.ScalarField])[0]) } -func TestPolyClone(t *testing.T) { +func testPolyClone(suite suite.Suite) { f := randomPoly(8) x := rand() fx := f.Eval(x) @@ -76,11 +75,11 @@ func TestPolyClone(t *testing.T) { gx := g.Eval(x) fgx := fg.Eval(x) - assert.Equal(t, fx, gx) - assert.Equal(t, vecOp(fx, gx, core.Add), fgx) + suite.Equal(fx, gx) + suite.Equal(vecOp(fx, gx, core.Add), fgx) } -func TestPolyAddSubMul(t *testing.T) { +func testPolyAddSubMul(suite suite.Suite) { testSize := 1 << 10 f := randomPoly(testSize) g := randomPoly(testSize) @@ -91,26 +90,26 @@ func TestPolyAddSubMul(t *testing.T) { polyAdd := f.Add(&g) fxAddgx := vecOp(fx, gx, core.Add) - assert.Equal(t, polyAdd.Eval(x), fxAddgx) + suite.Equal(polyAdd.Eval(x), fxAddgx) polySub := f.Subtract(&g) fxSubgx := vecOp(fx, gx, core.Sub) - assert.Equal(t, polySub.Eval(x), fxSubgx) + suite.Equal(polySub.Eval(x), fxSubgx) polyMul := f.Multiply(&g) fxMulgx := vecOp(fx, gx, core.Mul) - assert.Equal(t, polyMul.Eval(x), fxMulgx) + suite.Equal(polyMul.Eval(x), fxMulgx) s1 := rand() polMulS1 := f.MultiplyByScalar(s1) - assert.Equal(t, polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) + suite.Equal(polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) s2 := rand() polMulS2 := f.MultiplyByScalar(s2) - assert.Equal(t, polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) + suite.Equal(polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) } -func TestPolyMonomials(t *testing.T) { +func testPolyMonomials(suite suite.Suite) { var zero bls12_381.ScalarField var f polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements([]bls12_381.ScalarField{one, zero, two})) @@ -119,20 +118,20 @@ func TestPolyMonomials(t *testing.T) { fx := f.Eval(x) f.AddMonomial(three, 1) fxAdded := f.Eval(x) - assert.Equal(t, fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) + suite.Equal(fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) f.SubMonomial(one, 0) fxSub := f.Eval(x) - assert.Equal(t, fxSub, vecOp(fxAdded, one, core.Sub)) + suite.Equal(fxSub, vecOp(fxAdded, one, core.Sub)) } -func TestPolyReadCoeffs(t *testing.T) { +func testPolyReadCoeffs(suite suite.Suite) { var f polynomial.DensePolynomial coeffs := core.HostSliceFromElements([]bls12_381.ScalarField{one, two, three, four}) f.CreateFromCoeffecitients(coeffs) coeffsCopied := make(core.HostSlice[bls12_381.ScalarField], coeffs.Len()) _, _ = f.CopyCoeffsRange(0, coeffs.Len()-1, coeffsCopied) - assert.ElementsMatch(t, coeffs, coeffsCopied) + suite.ElementsMatch(coeffs, coeffsCopied) var coeffsDevice core.DeviceSlice coeffsDevice.Malloc(one.Size(), coeffs.Len()) @@ -140,16 +139,16 @@ func TestPolyReadCoeffs(t *testing.T) { coeffsHost := make(core.HostSlice[bls12_381.ScalarField], coeffs.Len()) coeffsHost.CopyFromDevice(&coeffsDevice) - assert.ElementsMatch(t, coeffs, coeffsHost) + suite.ElementsMatch(coeffs, coeffsHost) } -func TestPolyOddEvenSlicing(t *testing.T) { +func testPolyOddEvenSlicing(suite suite.Suite) { size := 1<<10 - 3 f := randomPoly(size) even := f.Even() odd := f.Odd() - assert.Equal(t, f.Degree(), even.Degree()+odd.Degree()+1) + suite.Equal(f.Degree(), even.Degree()+odd.Degree()+1) x := rand() var evenExpected, oddExpected bls12_381.ScalarField @@ -164,13 +163,13 @@ func TestPolyOddEvenSlicing(t *testing.T) { } evenEvaled := even.Eval(x) - assert.Equal(t, evenExpected, evenEvaled) + suite.Equal(evenExpected, evenEvaled) oddEvaled := odd.Eval(x) - assert.Equal(t, oddExpected, oddEvaled) + suite.Equal(oddExpected, oddEvaled) } -func TestPolynomialDivision(t *testing.T) { +func testPolynomialDivision(suite suite.Suite) { // divide f(x)/g(x), compute q(x), r(x) and check f(x)=q(x)*g(x)+r(x) var f, g polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements(bls12_381.GenerateScalars(1 << 4))) @@ -184,10 +183,10 @@ func TestPolynomialDivision(t *testing.T) { x := bls12_381.GenerateScalars(1)[0] fEval := f.Eval(x) fReconEval := fRecon.Eval(x) - assert.Equal(t, fEval, fReconEval) + suite.Equal(fEval, fReconEval) } -func TestDivideByVanishing(t *testing.T) { +func testDivideByVanishing(suite suite.Suite) { // poly of x^4-1 vanishes ad 4th rou var zero bls12_381.ScalarField minus_one := vecOp(zero, one, core.Sub) @@ -200,31 +199,51 @@ func TestDivideByVanishing(t *testing.T) { fv := f.Multiply(&v) fDegree := f.Degree() fvDegree := fv.Degree() - assert.Equal(t, fDegree+4, fvDegree) + suite.Equal(fDegree+4, fvDegree) fReconstructed := fv.DivideByVanishing(4) - assert.Equal(t, fDegree, fReconstructed.Degree()) + suite.Equal(fDegree, fReconstructed.Degree()) x := rand() - assert.Equal(t, f.Eval(x), fReconstructed.Eval(x)) + suite.Equal(f.Eval(x), fReconstructed.Eval(x)) } -// func TestPolySlice(t *testing.T) { +// func TestPolySlice(suite suite.Suite) { // size := 4 // coeffs := bls12_381.GenerateScalars(size) // var f DensePolynomial // f.CreateFromCoeffecitients(coeffs) // fSlice := f.AsSlice() -// assert.True(t, fSlice.IsOnDevice()) -// assert.Equal(t, size, fSlice.Len()) +// suite.True(fSlice.IsOnDevice()) +// suite.Equal(size, fSlice.Len()) // hostSlice := make(core.HostSlice[bls12_381.ScalarField], size) // hostSlice.CopyFromDevice(fSlice) -// assert.Equal(t, coeffs, hostSlice) +// suite.Equal(coeffs, hostSlice) // cfg := ntt.GetDefaultNttConfig() // res := make(core.HostSlice[bls12_381.ScalarField], size) // ntt.Ntt(fSlice, core.KForward, cfg, res) -// assert.Equal(t, f.Eval(one), res[0]) +// suite.Equal(f.Eval(one), res[0]) // } + +type PolynomialTestSuite struct { + suite.Suite +} + +func (s *PolynomialTestSuite) TestPolynomial() { + s.Run("TestPolyCreateFromCoefficients", testWrapper(s.Suite, testPolyCreateFromCoefficients)) + s.Run("TestPolyEval", testWrapper(s.Suite, testPolyEval)) + s.Run("TestPolyClone", testWrapper(s.Suite, testPolyClone)) + s.Run("TestPolyAddSubMul", testWrapper(s.Suite, testPolyAddSubMul)) + s.Run("TestPolyMonomials", testWrapper(s.Suite, testPolyMonomials)) + s.Run("TestPolyReadCoeffs", testWrapper(s.Suite, testPolyReadCoeffs)) + s.Run("TestPolyOddEvenSlicing", testWrapper(s.Suite, testPolyOddEvenSlicing)) + s.Run("TestPolynomialDivision", testWrapper(s.Suite, testPolynomialDivision)) + s.Run("TestDivideByVanishing", testWrapper(s.Suite, testDivideByVanishing)) +} + +func TestSuitePolynomial(t *testing.T) { + suite.Run(t, new(PolynomialTestSuite)) +} diff --git a/wrappers/golang/curves/bls12381/tests/scalar_field_test.go b/wrappers/golang/curves/bls12381/tests/scalar_field_test.go index 9a2206f77..159087b87 100644 --- a/wrappers/golang/curves/bls12381/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bls12381/tests/scalar_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( SCALAR_LIMBS = bls12_381.SCALAR_LIMBS ) -func TestScalarFieldFromLimbs(t *testing.T) { +func testScalarFieldFromLimbs(suite suite.Suite) { emptyField := bls12_381.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestScalarFieldGetLimbs(t *testing.T) { +func testScalarFieldGetLimbs(suite suite.Suite) { emptyField := bls12_381.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") } -func TestScalarFieldOne(t *testing.T) { +func testScalarFieldOne(suite suite.Suite) { var emptyField bls12_381.ScalarField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(SCALAR_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") } -func TestScalarFieldZero(t *testing.T) { +func testScalarFieldZero(suite suite.Suite) { var emptyField bls12_381.ScalarField emptyField.Zero() limbsZero := make([]uint32, SCALAR_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") } -func TestScalarFieldSize(t *testing.T) { +func testScalarFieldSize(suite suite.Suite) { var emptyField bls12_381.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestScalarFieldAsPointer(t *testing.T) { +func testScalarFieldAsPointer(suite suite.Suite) { var emptyField bls12_381.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestScalarFieldFromBytes(t *testing.T) { +func testScalarFieldFromBytes(suite suite.Suite) { var emptyField bls12_381.ScalarField bytes, expected := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestScalarFieldToBytes(t *testing.T) { +func testScalarFieldToBytes(suite suite.Suite) { var emptyField bls12_381.ScalarField expected, limbs := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestBls12_381GenerateScalars(t *testing.T) { +func testBls12_381GenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := bls12_381.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := bls12_381.ScalarField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestBls12_381MongtomeryConversion(t *testing.T) { +func testBls12_381MongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := bls12_381.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestBls12_381MongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[bls12_381.ScalarField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) bls12_381.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ScalarFieldTestSuite struct { + suite.Suite +} + +func (s *ScalarFieldTestSuite) TestScalarField() { + s.Run("TestScalarFieldFromLimbs", testWrapper(s.Suite, testScalarFieldFromLimbs)) + s.Run("TestScalarFieldGetLimbs", testWrapper(s.Suite, testScalarFieldGetLimbs)) + s.Run("TestScalarFieldOne", testWrapper(s.Suite, testScalarFieldOne)) + s.Run("TestScalarFieldZero", testWrapper(s.Suite, testScalarFieldZero)) + s.Run("TestScalarFieldSize", testWrapper(s.Suite, testScalarFieldSize)) + s.Run("TestScalarFieldAsPointer", testWrapper(s.Suite, testScalarFieldAsPointer)) + s.Run("TestScalarFieldFromBytes", testWrapper(s.Suite, testScalarFieldFromBytes)) + s.Run("TestScalarFieldToBytes", testWrapper(s.Suite, testScalarFieldToBytes)) + s.Run("TestBls12_381GenerateScalars", testWrapper(s.Suite, testBls12_381GenerateScalars)) + s.Run("TestBls12_381MongtomeryConversion", testWrapper(s.Suite, testBls12_381MongtomeryConversion)) +} + +func TestSuiteScalarField(t *testing.T) { + suite.Run(t, new(ScalarFieldTestSuite)) } diff --git a/wrappers/golang/curves/bls12381/tests/vec_ops_test.go b/wrappers/golang/curves/bls12381/tests/vec_ops_test.go index 96a1e5046..09375e95e 100644 --- a/wrappers/golang/curves/bls12381/tests/vec_ops_test.go +++ b/wrappers/golang/curves/bls12381/tests/vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bls12_381 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bls12381/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestBls12_381VecOps(t *testing.T) { +func testBls12_381VecOps(suite suite.Suite) { testSize := 1 << 14 a := bls12_381.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestBls12_381VecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestBls12_381Transpose(t *testing.T) { +func testBls12_381Transpose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestBls12_381Transpose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestBls12_381Transpose(t *testing.T) { output := make(core.HostSlice[bls12_381.ScalarField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type Bls12_381VecOpsTestSuite struct { + suite.Suite +} + +func (s *Bls12_381VecOpsTestSuite) TestBls12_381VecOps() { + s.Run("TestBls12_381VecOps", testWrapper(s.Suite, testBls12_381VecOps)) + s.Run("TestBls12_381Transpose", testWrapper(s.Suite, testBls12_381Transpose)) +} + +func TestSuiteBls12_381VecOps(t *testing.T) { + suite.Run(t, new(Bls12_381VecOpsTestSuite)) } diff --git a/wrappers/golang/curves/bn254/base_field.go b/wrappers/golang/curves/bn254/base_field.go index 0fbf3aeee..5fe4d8c57 100644 --- a/wrappers/golang/curves/bn254/base_field.go +++ b/wrappers/golang/curves/bn254/base_field.go @@ -53,6 +53,16 @@ func (f *BaseField) Zero() BaseField { return *f } +func (f *BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *BaseField) One() BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bn254/curve.go b/wrappers/golang/curves/bn254/curve.go index a3aaa8c7b..504656b44 100644 --- a/wrappers/golang/curves/bn254/curve.go +++ b/wrappers/golang/curves/bn254/curve.go @@ -96,6 +96,10 @@ func (a *Affine) Zero() Affine { return *a } +func (a *Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *Affine) FromLimbs(x, y []uint32) Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *Affine) FromLimbs(x, y []uint32) Affine { func (a Affine) ToProjective() Projective { var p Projective - cA := (*C.affine_t)(unsafe.Pointer(&a)) - cP := (*C.projective_t)(unsafe.Pointer(&p)) - C.bn254_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.affine_t)(unsafe.Pointer(&a)) + // cP := (*C.projective_t)(unsafe.Pointer(&p)) + // C.bn254_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bn254/g2/curve.go b/wrappers/golang/curves/bn254/g2/curve.go index 357e8e88c..b1b33c3fb 100644 --- a/wrappers/golang/curves/bn254/g2/curve.go +++ b/wrappers/golang/curves/bn254/g2/curve.go @@ -96,6 +96,10 @@ func (a *G2Affine) Zero() G2Affine { return *a } +func (a *G2Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { func (a G2Affine) ToProjective() G2Projective { var p G2Projective - cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) - cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) - C.bn254_g2_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) + // cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + // C.bn254_g2_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bn254/g2/g2base_field.go b/wrappers/golang/curves/bn254/g2/g2base_field.go index 409a7d5ad..78fa94afa 100644 --- a/wrappers/golang/curves/bn254/g2/g2base_field.go +++ b/wrappers/golang/curves/bn254/g2/g2base_field.go @@ -53,6 +53,16 @@ func (f *G2BaseField) Zero() G2BaseField { return *f } +func (f *G2BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *G2BaseField) One() G2BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bn254/scalar_field.go b/wrappers/golang/curves/bn254/scalar_field.go index bc37a327b..8372381fe 100644 --- a/wrappers/golang/curves/bn254/scalar_field.go +++ b/wrappers/golang/curves/bn254/scalar_field.go @@ -60,6 +60,16 @@ func (f *ScalarField) Zero() ScalarField { return *f } +func (f *ScalarField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ScalarField) One() ScalarField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bn254/tests/base_field_test.go b/wrappers/golang/curves/bn254/tests/base_field_test.go index 2d4c9e77f..a18f217c1 100644 --- a/wrappers/golang/curves/bn254/tests/base_field_test.go +++ b/wrappers/golang/curves/bn254/tests/base_field_test.go @@ -5,85 +5,105 @@ import ( bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( BASE_LIMBS = bn254.BASE_LIMBS ) -func TestBaseFieldFromLimbs(t *testing.T) { +func testBaseFieldFromLimbs(suite suite.Suite) { emptyField := bn254.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestBaseFieldGetLimbs(t *testing.T) { +func testBaseFieldGetLimbs(suite suite.Suite) { emptyField := bn254.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") } -func TestBaseFieldOne(t *testing.T) { +func testBaseFieldOne(suite suite.Suite) { var emptyField bn254.BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") } -func TestBaseFieldZero(t *testing.T) { +func testBaseFieldZero(suite suite.Suite) { var emptyField bn254.BaseField emptyField.Zero() limbsZero := make([]uint32, BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") } -func TestBaseFieldSize(t *testing.T) { +func testBaseFieldSize(suite suite.Suite) { var emptyField bn254.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestBaseFieldAsPointer(t *testing.T) { +func testBaseFieldAsPointer(suite suite.Suite) { var emptyField bn254.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestBaseFieldFromBytes(t *testing.T) { +func testBaseFieldFromBytes(suite suite.Suite) { var emptyField bn254.BaseField bytes, expected := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestBaseFieldToBytes(t *testing.T) { +func testBaseFieldToBytes(suite suite.Suite) { var emptyField bn254.BaseField expected, limbs := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type BaseFieldTestSuite struct { + suite.Suite +} + +func (s *BaseFieldTestSuite) TestBaseField() { + s.Run("TestBaseFieldFromLimbs", testWrapper(s.Suite, testBaseFieldFromLimbs)) + s.Run("TestBaseFieldGetLimbs", testWrapper(s.Suite, testBaseFieldGetLimbs)) + s.Run("TestBaseFieldOne", testWrapper(s.Suite, testBaseFieldOne)) + s.Run("TestBaseFieldZero", testWrapper(s.Suite, testBaseFieldZero)) + s.Run("TestBaseFieldSize", testWrapper(s.Suite, testBaseFieldSize)) + s.Run("TestBaseFieldAsPointer", testWrapper(s.Suite, testBaseFieldAsPointer)) + s.Run("TestBaseFieldFromBytes", testWrapper(s.Suite, testBaseFieldFromBytes)) + s.Run("TestBaseFieldToBytes", testWrapper(s.Suite, testBaseFieldToBytes)) + +} + +func TestSuiteBaseField(t *testing.T) { + suite.Run(t, new(BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bn254/tests/curve_test.go b/wrappers/golang/curves/bn254/tests/curve_test.go index 8d47e08e5..d69bc1350 100644 --- a/wrappers/golang/curves/bn254/tests/curve_test.go +++ b/wrappers/golang/curves/bn254/tests/curve_test.go @@ -5,15 +5,15 @@ import ( bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestAffineZero(t *testing.T) { +func testAffineZero(suite suite.Suite) { var fieldZero = bn254.BaseField{} var affineZero bn254.Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestAffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestAffineFromLimbs(t *testing.T) { +func testAffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var affine bn254.Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestAffineToProjective(t *testing.T) { +func testAffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bn254.BaseField @@ -49,31 +49,31 @@ func TestAffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestProjectiveZero(t *testing.T) { +func testProjectiveZero(suite suite.Suite) { var projectiveZero bn254.Projective projectiveZero.Zero() var fieldZero = bn254.BaseField{} var fieldOne bn254.BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var projective bn254.Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestProjectiveFromLimbs(t *testing.T) { +func testProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestProjectiveFromLimbs(t *testing.T) { var projective bn254.Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestProjectiveFromAffine(t *testing.T) { +func testProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bn254.BaseField @@ -100,5 +100,22 @@ func TestProjectiveFromAffine(t *testing.T) { var projectivePoint bn254.Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type CurveTestSuite struct { + suite.Suite +} + +func (s *CurveTestSuite) TestCurve() { + s.Run("TestAffineZero", testWrapper(s.Suite, testAffineZero)) + s.Run("TestAffineFromLimbs", testWrapper(s.Suite, testAffineFromLimbs)) + s.Run("TestAffineToProjective", testWrapper(s.Suite, testAffineToProjective)) + s.Run("TestProjectiveZero", testWrapper(s.Suite, testProjectiveZero)) + s.Run("TestProjectiveFromLimbs", testWrapper(s.Suite, testProjectiveFromLimbs)) + s.Run("TestProjectiveFromAffine", testWrapper(s.Suite, testProjectiveFromAffine)) +} + +func TestSuiteCurve(t *testing.T) { + suite.Run(t, new(CurveTestSuite)) } diff --git a/wrappers/golang/curves/bn254/tests/ecntt_test.go b/wrappers/golang/curves/bn254/tests/ecntt_test.go index b8e8ee7d7..87526c6f4 100644 --- a/wrappers/golang/curves/bn254/tests/ecntt_test.go +++ b/wrappers/golang/curves/bn254/tests/ecntt_test.go @@ -9,10 +9,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime/config_extension" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestECNtt(t *testing.T) { +func testECNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() ext := config_extension.Create() ext.SetInt(core.CUDA_NTT_ALGORITHM, int(core.Radix2)) @@ -31,7 +31,19 @@ func TestECNtt(t *testing.T) { output := make(core.HostSlice[bn254.Projective], testSize) e := ecntt.ECNtt(pointsCopy, core.KForward, &cfg, output) - assert.Equal(t, runtime.Success, e, "ECNtt failed") + suite.Equal(runtime.Success, e, "ECNtt failed") } } } + +type ECNttTestSuite struct { + suite.Suite +} + +func (s *ECNttTestSuite) TestECNtt() { + s.Run("TestECNtt", testWrapper(s.Suite, testECNtt)) +} + +func TestSuiteECNtt(t *testing.T) { + suite.Run(t, new(ECNttTestSuite)) +} diff --git a/wrappers/golang/curves/bn254/tests/g2_curve_test.go b/wrappers/golang/curves/bn254/tests/g2_curve_test.go index f63e05399..a1e2509ba 100644 --- a/wrappers/golang/curves/bn254/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bn254/tests/g2_curve_test.go @@ -5,15 +5,15 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestG2AffineZero(t *testing.T) { +func testG2AffineZero(suite suite.Suite) { var fieldZero = g2.G2BaseField{} var affineZero g2.G2Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestG2AffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestG2AffineFromLimbs(t *testing.T) { +func testG2AffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var affine g2.G2Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestG2AffineToProjective(t *testing.T) { +func testG2AffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -49,31 +49,31 @@ func TestG2AffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestG2ProjectiveZero(t *testing.T) { +func testG2ProjectiveZero(suite suite.Suite) { var projectiveZero g2.G2Projective projectiveZero.Zero() var fieldZero = g2.G2BaseField{} var fieldOne g2.G2BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestG2ProjectiveFromLimbs(t *testing.T) { +func testG2ProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestG2ProjectiveFromLimbs(t *testing.T) { var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestG2ProjectiveFromAffine(t *testing.T) { +func testG2ProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -100,5 +100,22 @@ func TestG2ProjectiveFromAffine(t *testing.T) { var projectivePoint g2.G2Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type G2CurveTestSuite struct { + suite.Suite +} + +func (s *G2CurveTestSuite) TestG2Curve() { + s.Run("TestG2AffineZero", testWrapper(s.Suite, testG2AffineZero)) + s.Run("TestG2AffineFromLimbs", testWrapper(s.Suite, testG2AffineFromLimbs)) + s.Run("TestG2AffineToProjective", testWrapper(s.Suite, testG2AffineToProjective)) + s.Run("TestG2ProjectiveZero", testWrapper(s.Suite, testG2ProjectiveZero)) + s.Run("TestG2ProjectiveFromLimbs", testWrapper(s.Suite, testG2ProjectiveFromLimbs)) + s.Run("TestG2ProjectiveFromAffine", testWrapper(s.Suite, testG2ProjectiveFromAffine)) +} + +func TestSuiteG2Curve(t *testing.T) { + suite.Run(t, new(G2CurveTestSuite)) } diff --git a/wrappers/golang/curves/bn254/tests/g2_g2base_field_test.go b/wrappers/golang/curves/bn254/tests/g2_g2base_field_test.go index 2ebfa22f4..8ee894075 100644 --- a/wrappers/golang/curves/bn254/tests/g2_g2base_field_test.go +++ b/wrappers/golang/curves/bn254/tests/g2_g2base_field_test.go @@ -5,85 +5,105 @@ import ( bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( G2BASE_LIMBS = bn254.G2BASE_LIMBS ) -func TestG2BaseFieldFromLimbs(t *testing.T) { +func testG2BaseFieldFromLimbs(suite suite.Suite) { emptyField := bn254.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestG2BaseFieldGetLimbs(t *testing.T) { +func testG2BaseFieldGetLimbs(suite suite.Suite) { emptyField := bn254.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") } -func TestG2BaseFieldOne(t *testing.T) { +func testG2BaseFieldOne(suite suite.Suite) { var emptyField bn254.G2BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(G2BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") } -func TestG2BaseFieldZero(t *testing.T) { +func testG2BaseFieldZero(suite suite.Suite) { var emptyField bn254.G2BaseField emptyField.Zero() limbsZero := make([]uint32, G2BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") } -func TestG2BaseFieldSize(t *testing.T) { +func testG2BaseFieldSize(suite suite.Suite) { var emptyField bn254.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestG2BaseFieldAsPointer(t *testing.T) { +func testG2BaseFieldAsPointer(suite suite.Suite) { var emptyField bn254.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestG2BaseFieldFromBytes(t *testing.T) { +func testG2BaseFieldFromBytes(suite suite.Suite) { var emptyField bn254.G2BaseField bytes, expected := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestG2BaseFieldToBytes(t *testing.T) { +func testG2BaseFieldToBytes(suite suite.Suite) { var emptyField bn254.G2BaseField expected, limbs := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type G2BaseFieldTestSuite struct { + suite.Suite +} + +func (s *G2BaseFieldTestSuite) TestG2BaseField() { + s.Run("TestG2BaseFieldFromLimbs", testWrapper(s.Suite, testG2BaseFieldFromLimbs)) + s.Run("TestG2BaseFieldGetLimbs", testWrapper(s.Suite, testG2BaseFieldGetLimbs)) + s.Run("TestG2BaseFieldOne", testWrapper(s.Suite, testG2BaseFieldOne)) + s.Run("TestG2BaseFieldZero", testWrapper(s.Suite, testG2BaseFieldZero)) + s.Run("TestG2BaseFieldSize", testWrapper(s.Suite, testG2BaseFieldSize)) + s.Run("TestG2BaseFieldAsPointer", testWrapper(s.Suite, testG2BaseFieldAsPointer)) + s.Run("TestG2BaseFieldFromBytes", testWrapper(s.Suite, testG2BaseFieldFromBytes)) + s.Run("TestG2BaseFieldToBytes", testWrapper(s.Suite, testG2BaseFieldToBytes)) + +} + +func TestSuiteG2BaseField(t *testing.T) { + suite.Run(t, new(G2BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bn254/tests/g2_msm_test.go b/wrappers/golang/curves/bn254/tests/g2_msm_test.go index 5e5eef5b3..d022ba1c8 100644 --- a/wrappers/golang/curves/bn254/tests/g2_msm_test.go +++ b/wrappers/golang/curves/bn254/tests/g2_msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254" @@ -62,7 +62,7 @@ func projectiveToGnarkAffineG2(p g2.G2Projective) bn254.G2Affine { return *g2Affine.FromJacobian(&g2Jac) } -func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBn254.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2(suite suite.Suite, scalars core.HostSlice[icicleBn254.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -74,10 +74,10 @@ func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBn25 pointsFp[i] = projectiveToGnarkAffineG2(v.ToProjective()) } - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bn254.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bn254.G2Affine], out g2.G2Projective) { var msmRes bn254.G2Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -86,7 +86,7 @@ func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.Ho icicleResAffine := projectiveToGnarkAffineG2(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []bn254.G2Affine { @@ -119,7 +119,7 @@ func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []bn254.G2Affin return points } -func TestMSMG2(t *testing.T) { +func testMSMG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -133,11 +133,11 @@ func TestMSMG2(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -145,11 +145,11 @@ func TestMSMG2(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2GnarkCryptoTypes(t *testing.T) { +func testMSMG2GnarkCryptoTypes(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -169,22 +169,22 @@ func TestMSMG2GnarkCryptoTypes(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = g2.G2Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMG2Batch(t *testing.T) { +func testMSMG2Batch(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -197,10 +197,10 @@ func TestMSMG2Batch(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -209,15 +209,15 @@ func TestMSMG2Batch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsG2(t *testing.T) { +func testPrecomputePointsG2(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 @@ -234,20 +234,20 @@ func TestPrecomputePointsG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -257,13 +257,13 @@ func TestPrecomputePointsG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBasesG2(t *testing.T) { +func testPrecomputePointsSharedBasesG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -279,18 +279,18 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -300,13 +300,13 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMG2SkewedDistribution(t *testing.T) { +func testMSMG2SkewedDistribution(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -325,19 +325,19 @@ func TestMSMG2SkewedDistribution(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2MultiDevice(t *testing.T) { +func testMSMG2MultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -361,11 +361,11 @@ func TestMSMG2MultiDevice(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -373,9 +373,27 @@ func TestMSMG2MultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMG2TestSuite struct { + suite.Suite +} + +func (s *MSMG2TestSuite) TestMSMG2() { + s.Run("TestMSMG2", testWrapper(s.Suite, testMSMG2)) + s.Run("TestMSMG2GnarkCryptoTypes", testWrapper(s.Suite, testMSMG2GnarkCryptoTypes)) + s.Run("TestMSMG2Batch", testWrapper(s.Suite, testMSMG2Batch)) + s.Run("TestPrecomputePointsG2", testWrapper(s.Suite, testPrecomputePointsG2)) + s.Run("TestPrecomputePointsSharedBasesG2", testWrapper(s.Suite, testPrecomputePointsSharedBasesG2)) + s.Run("TestMSMG2SkewedDistribution", testWrapper(s.Suite, testMSMG2SkewedDistribution)) + s.Run("TestMSMG2MultiDevice", testWrapper(s.Suite, testMSMG2MultiDevice)) +} + +func TestSuiteMSMG2(t *testing.T) { + suite.Run(t, new(MSMG2TestSuite)) +} diff --git a/wrappers/golang/curves/bn254/tests/main_test.go b/wrappers/golang/curves/bn254/tests/main_test.go index 711f5b678..299029534 100644 --- a/wrappers/golang/curves/bn254/tests/main_test.go +++ b/wrappers/golang/curves/bn254/tests/main_test.go @@ -2,12 +2,14 @@ package tests import ( "fmt" - "testing" - "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + "github.com/stretchr/testify/suite" + "os" + "sync" + "testing" "github.com/consensys/gnark-crypto/ecc/bn254/fr/fft" ) @@ -16,7 +18,10 @@ const ( largestTestSize = 20 ) -var DEVICE runtime.Device +var ( + DEVICE runtime.Device + exitCode int +) func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcicleError { rouMont, _ := fft.Generator(uint64(1 << largestTestSize)) @@ -29,6 +34,18 @@ func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcic return e } +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} + func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() devices, e := runtime.GetRegisteredDevices() @@ -36,6 +53,7 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) @@ -50,8 +68,10 @@ func TestMain(m *testing.M) { } } + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() // release domain e = ntt.ReleaseDomain() @@ -63,4 +83,6 @@ func TestMain(m *testing.M) { } } } + + os.Exit(exitCode) } diff --git a/wrappers/golang/curves/bn254/tests/msm_test.go b/wrappers/golang/curves/bn254/tests/msm_test.go index 84e092c4b..ee6d7defc 100644 --- a/wrappers/golang/curves/bn254/tests/msm_test.go +++ b/wrappers/golang/curves/bn254/tests/msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254" @@ -35,7 +35,7 @@ func projectiveToGnarkAffine(p icicleBn254.Projective) bn254.G1Affine { return bn254.G1Affine{X: *x, Y: *y} } -func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBn254.ScalarField], points core.HostSlice[icicleBn254.Affine], out icicleBn254.Projective) { +func testAgainstGnarkCryptoMsm(suite suite.Suite, scalars core.HostSlice[icicleBn254.ScalarField], points core.HostSlice[icicleBn254.Affine], out icicleBn254.Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -47,10 +47,10 @@ func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBn254. pointsFp[i] = projectiveToGnarkAffine(v.ToProjective()) } - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bn254.G1Affine], out icicleBn254.Projective) { +func testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bn254.G1Affine], out icicleBn254.Projective) { var msmRes bn254.G1Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -59,7 +59,7 @@ func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.Host icicleResAffine := projectiveToGnarkAffine(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleAffineToG1Affine(iciclePoints []icicleBn254.Affine) []bn254.G1Affine { @@ -79,7 +79,7 @@ func convertIcicleAffineToG1Affine(iciclePoints []icicleBn254.Affine) []bn254.G1 return points } -func TestMSM(t *testing.T) { +func testMSM(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -93,11 +93,11 @@ func TestMSM(t *testing.T) { var p icicleBn254.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -105,11 +105,11 @@ func TestMSM(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMGnarkCryptoTypes(t *testing.T) { +func testMSMGnarkCryptoTypes(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -129,22 +129,22 @@ func TestMSMGnarkCryptoTypes(t *testing.T) { var p icicleBn254.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = msm.Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMBatch(t *testing.T) { +func testMSMBatch(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -157,10 +157,10 @@ func TestMSMBatch(t *testing.T) { var p icicleBn254.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -169,15 +169,15 @@ func TestMSMBatch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePoints(t *testing.T) { +func testPrecomputePoints(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 @@ -194,20 +194,20 @@ func TestPrecomputePoints(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBn254.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -217,13 +217,13 @@ func TestPrecomputePoints(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBases(t *testing.T) { +func testPrecomputePointsSharedBases(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -239,18 +239,18 @@ func TestPrecomputePointsSharedBases(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBn254.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -260,13 +260,13 @@ func TestPrecomputePointsSharedBases(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMSkewedDistribution(t *testing.T) { +func testMSMSkewedDistribution(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -285,19 +285,19 @@ func TestMSMSkewedDistribution(t *testing.T) { var p icicleBn254.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMMultiDevice(t *testing.T) { +func testMSMMultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -321,11 +321,11 @@ func TestMSMMultiDevice(t *testing.T) { var p icicleBn254.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBn254.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -333,9 +333,27 @@ func TestMSMMultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMTestSuite struct { + suite.Suite +} + +func (s *MSMTestSuite) TestMSM() { + s.Run("TestMSM", testWrapper(s.Suite, testMSM)) + s.Run("TestMSMGnarkCryptoTypes", testWrapper(s.Suite, testMSMGnarkCryptoTypes)) + s.Run("TestMSMBatch", testWrapper(s.Suite, testMSMBatch)) + s.Run("TestPrecomputePoints", testWrapper(s.Suite, testPrecomputePoints)) + s.Run("TestPrecomputePointsSharedBases", testWrapper(s.Suite, testPrecomputePointsSharedBases)) + s.Run("TestMSMSkewedDistribution", testWrapper(s.Suite, testMSMSkewedDistribution)) + s.Run("TestMSMMultiDevice", testWrapper(s.Suite, testMSMMultiDevice)) +} + +func TestSuiteMSM(t *testing.T) { + suite.Run(t, new(MSMTestSuite)) +} diff --git a/wrappers/golang/curves/bn254/tests/ntt_test.go b/wrappers/golang/curves/bn254/tests/ntt_test.go index 79b8ecbe9..8ed5a248d 100644 --- a/wrappers/golang/curves/bn254/tests/ntt_test.go +++ b/wrappers/golang/curves/bn254/tests/ntt_test.go @@ -11,10 +11,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bn254.ScalarField], output core.HostSlice[bn254.ScalarField], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNtt(suite suite.Suite, size int, scalars core.HostSlice[bn254.ScalarField], output core.HostSlice[bn254.ScalarField], order core.Ordering, direction core.NTTDir) { scalarsFr := make([]fr.Element, size) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -26,10 +26,10 @@ func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bn outputAsFr[i] = slice64 } - testAgainstGnarkCryptoNttGnarkTypes(t, size, scalarsFr, outputAsFr, order, direction) + testAgainstGnarkCryptoNttGnarkTypes(suite, size, scalarsFr, outputAsFr, order, direction) } -func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNttGnarkTypes(suite suite.Suite, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { domainWithPrecompute := fft.NewDomain(uint64(size)) // DIT + BitReverse == Ordering.kRR // DIT == Ordering.kRN @@ -51,25 +51,19 @@ func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core. if order == core.KNN || order == core.KRR { fft.BitReverse(scalarsFr) } - assert.Equal(t, scalarsFr, outputAsFr) + suite.Equal(scalarsFr, outputAsFr) } -func TestNTTGetDefaultConfig(t *testing.T) { +func testNTTGetDefaultConfig(suite suite.Suite) { actual := ntt.GetDefaultNttConfig() expected := test_helpers.GenerateLimbOne(int(bn254.SCALAR_LIMBS)) - assert.Equal(t, expected, actual.CosetGen[:]) + suite.Equal(expected, actual.CosetGen[:]) cosetGenField := bn254.ScalarField{} cosetGenField.One() - assert.ElementsMatch(t, cosetGenField.GetLimbs(), actual.CosetGen) + suite.ElementsMatch(cosetGenField.GetLimbs(), actual.CosetGen) } -func TestInitDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - cfg := core.GetDefaultNTTInitDomainConfig() - assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) }) -} - -func TestNtt(t *testing.T) { +func testNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bn254.GenerateScalars(1 << largestTestSize) @@ -87,11 +81,11 @@ func TestNtt(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttFrElement(t *testing.T) { +func testNttFrElement(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := make([]fr.Element, 4) var x fr.Element @@ -114,12 +108,12 @@ func TestNttFrElement(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNttGnarkTypes(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNttGnarkTypes(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttDeviceAsync(t *testing.T) { +func testNttDeviceAsync(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bn254.GenerateScalars(1 << largestTestSize) @@ -150,13 +144,13 @@ func TestNttDeviceAsync(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, direction) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, direction) } } } } -func TestNttBatch(t *testing.T) { +func testNttBatch(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 10 largestBatchSize := 20 @@ -194,16 +188,26 @@ func TestNttBatch(t *testing.T) { domainWithPrecompute.FFT(scalarsFr, fft.DIF) fft.BitReverse(scalarsFr) - if !assert.True(t, reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { - t.FailNow() + if !suite.True(reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { + suite.T().FailNow() } } } } } -func TestReleaseDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - e := ntt.ReleaseDomain() - assert.Equal(t, runtime.Success, e, "ReleasDomain failed") +type NTTTestSuite struct { + suite.Suite +} + +func (s *NTTTestSuite) TestNTT() { + s.Run("TestNTTGetDefaultConfig", testWrapper(s.Suite, testNTTGetDefaultConfig)) + s.Run("TestNTT", testWrapper(s.Suite, testNtt)) + s.Run("TestNTTFrElement", testWrapper(s.Suite, testNttFrElement)) + s.Run("TestNttDeviceAsync", testWrapper(s.Suite, testNttDeviceAsync)) + s.Run("TestNttBatch", testWrapper(s.Suite, testNttBatch)) +} + +func TestSuiteNTT(t *testing.T) { + suite.Run(t, new(NTTTestSuite)) } diff --git a/wrappers/golang/curves/bn254/tests/polynomial_test.go b/wrappers/golang/curves/bn254/tests/polynomial_test.go index abd9dc5cb..a19dcb768 100644 --- a/wrappers/golang/curves/bn254/tests/polynomial_test.go +++ b/wrappers/golang/curves/bn254/tests/polynomial_test.go @@ -6,10 +6,9 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" - // "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/polynomial" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) var one, two, three, four, five bn254.ScalarField @@ -41,7 +40,7 @@ func vecOp(a, b bn254.ScalarField, op core.VecOps) bn254.ScalarField { return out[0] } -func TestPolyCreateFromCoefficients(t *testing.T) { +func testPolyCreateFromCoefficients(suite suite.Suite) { scalars := bn254.GenerateScalars(33) var uniPoly polynomial.DensePolynomial @@ -49,7 +48,7 @@ func TestPolyCreateFromCoefficients(t *testing.T) { poly.Print() } -func TestPolyEval(t *testing.T) { +func testPolyEval(suite suite.Suite) { // testing correct evaluation of f(8) for f(x)=4x^2+2x+5 coeffs := core.HostSliceFromElements([]bn254.ScalarField{five, two, four}) var f polynomial.DensePolynomial @@ -62,10 +61,10 @@ func TestPolyEval(t *testing.T) { evals := make(core.HostSlice[bn254.ScalarField], 1) fEvaled := f.EvalOnDomain(domains, evals) var expected bn254.ScalarField - assert.Equal(t, expected.FromUint32(277), fEvaled.(core.HostSlice[bn254.ScalarField])[0]) + suite.Equal(expected.FromUint32(277), fEvaled.(core.HostSlice[bn254.ScalarField])[0]) } -func TestPolyClone(t *testing.T) { +func testPolyClone(suite suite.Suite) { f := randomPoly(8) x := rand() fx := f.Eval(x) @@ -76,11 +75,11 @@ func TestPolyClone(t *testing.T) { gx := g.Eval(x) fgx := fg.Eval(x) - assert.Equal(t, fx, gx) - assert.Equal(t, vecOp(fx, gx, core.Add), fgx) + suite.Equal(fx, gx) + suite.Equal(vecOp(fx, gx, core.Add), fgx) } -func TestPolyAddSubMul(t *testing.T) { +func testPolyAddSubMul(suite suite.Suite) { testSize := 1 << 10 f := randomPoly(testSize) g := randomPoly(testSize) @@ -91,26 +90,26 @@ func TestPolyAddSubMul(t *testing.T) { polyAdd := f.Add(&g) fxAddgx := vecOp(fx, gx, core.Add) - assert.Equal(t, polyAdd.Eval(x), fxAddgx) + suite.Equal(polyAdd.Eval(x), fxAddgx) polySub := f.Subtract(&g) fxSubgx := vecOp(fx, gx, core.Sub) - assert.Equal(t, polySub.Eval(x), fxSubgx) + suite.Equal(polySub.Eval(x), fxSubgx) polyMul := f.Multiply(&g) fxMulgx := vecOp(fx, gx, core.Mul) - assert.Equal(t, polyMul.Eval(x), fxMulgx) + suite.Equal(polyMul.Eval(x), fxMulgx) s1 := rand() polMulS1 := f.MultiplyByScalar(s1) - assert.Equal(t, polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) + suite.Equal(polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) s2 := rand() polMulS2 := f.MultiplyByScalar(s2) - assert.Equal(t, polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) + suite.Equal(polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) } -func TestPolyMonomials(t *testing.T) { +func testPolyMonomials(suite suite.Suite) { var zero bn254.ScalarField var f polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements([]bn254.ScalarField{one, zero, two})) @@ -119,20 +118,20 @@ func TestPolyMonomials(t *testing.T) { fx := f.Eval(x) f.AddMonomial(three, 1) fxAdded := f.Eval(x) - assert.Equal(t, fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) + suite.Equal(fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) f.SubMonomial(one, 0) fxSub := f.Eval(x) - assert.Equal(t, fxSub, vecOp(fxAdded, one, core.Sub)) + suite.Equal(fxSub, vecOp(fxAdded, one, core.Sub)) } -func TestPolyReadCoeffs(t *testing.T) { +func testPolyReadCoeffs(suite suite.Suite) { var f polynomial.DensePolynomial coeffs := core.HostSliceFromElements([]bn254.ScalarField{one, two, three, four}) f.CreateFromCoeffecitients(coeffs) coeffsCopied := make(core.HostSlice[bn254.ScalarField], coeffs.Len()) _, _ = f.CopyCoeffsRange(0, coeffs.Len()-1, coeffsCopied) - assert.ElementsMatch(t, coeffs, coeffsCopied) + suite.ElementsMatch(coeffs, coeffsCopied) var coeffsDevice core.DeviceSlice coeffsDevice.Malloc(one.Size(), coeffs.Len()) @@ -140,16 +139,16 @@ func TestPolyReadCoeffs(t *testing.T) { coeffsHost := make(core.HostSlice[bn254.ScalarField], coeffs.Len()) coeffsHost.CopyFromDevice(&coeffsDevice) - assert.ElementsMatch(t, coeffs, coeffsHost) + suite.ElementsMatch(coeffs, coeffsHost) } -func TestPolyOddEvenSlicing(t *testing.T) { +func testPolyOddEvenSlicing(suite suite.Suite) { size := 1<<10 - 3 f := randomPoly(size) even := f.Even() odd := f.Odd() - assert.Equal(t, f.Degree(), even.Degree()+odd.Degree()+1) + suite.Equal(f.Degree(), even.Degree()+odd.Degree()+1) x := rand() var evenExpected, oddExpected bn254.ScalarField @@ -164,13 +163,13 @@ func TestPolyOddEvenSlicing(t *testing.T) { } evenEvaled := even.Eval(x) - assert.Equal(t, evenExpected, evenEvaled) + suite.Equal(evenExpected, evenEvaled) oddEvaled := odd.Eval(x) - assert.Equal(t, oddExpected, oddEvaled) + suite.Equal(oddExpected, oddEvaled) } -func TestPolynomialDivision(t *testing.T) { +func testPolynomialDivision(suite suite.Suite) { // divide f(x)/g(x), compute q(x), r(x) and check f(x)=q(x)*g(x)+r(x) var f, g polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements(bn254.GenerateScalars(1 << 4))) @@ -184,10 +183,10 @@ func TestPolynomialDivision(t *testing.T) { x := bn254.GenerateScalars(1)[0] fEval := f.Eval(x) fReconEval := fRecon.Eval(x) - assert.Equal(t, fEval, fReconEval) + suite.Equal(fEval, fReconEval) } -func TestDivideByVanishing(t *testing.T) { +func testDivideByVanishing(suite suite.Suite) { // poly of x^4-1 vanishes ad 4th rou var zero bn254.ScalarField minus_one := vecOp(zero, one, core.Sub) @@ -200,31 +199,51 @@ func TestDivideByVanishing(t *testing.T) { fv := f.Multiply(&v) fDegree := f.Degree() fvDegree := fv.Degree() - assert.Equal(t, fDegree+4, fvDegree) + suite.Equal(fDegree+4, fvDegree) fReconstructed := fv.DivideByVanishing(4) - assert.Equal(t, fDegree, fReconstructed.Degree()) + suite.Equal(fDegree, fReconstructed.Degree()) x := rand() - assert.Equal(t, f.Eval(x), fReconstructed.Eval(x)) + suite.Equal(f.Eval(x), fReconstructed.Eval(x)) } -// func TestPolySlice(t *testing.T) { +// func TestPolySlice(suite suite.Suite) { // size := 4 // coeffs := bn254.GenerateScalars(size) // var f DensePolynomial // f.CreateFromCoeffecitients(coeffs) // fSlice := f.AsSlice() -// assert.True(t, fSlice.IsOnDevice()) -// assert.Equal(t, size, fSlice.Len()) +// suite.True(fSlice.IsOnDevice()) +// suite.Equal(size, fSlice.Len()) // hostSlice := make(core.HostSlice[bn254.ScalarField], size) // hostSlice.CopyFromDevice(fSlice) -// assert.Equal(t, coeffs, hostSlice) +// suite.Equal(coeffs, hostSlice) // cfg := ntt.GetDefaultNttConfig() // res := make(core.HostSlice[bn254.ScalarField], size) // ntt.Ntt(fSlice, core.KForward, cfg, res) -// assert.Equal(t, f.Eval(one), res[0]) +// suite.Equal(f.Eval(one), res[0]) // } + +type PolynomialTestSuite struct { + suite.Suite +} + +func (s *PolynomialTestSuite) TestPolynomial() { + s.Run("TestPolyCreateFromCoefficients", testWrapper(s.Suite, testPolyCreateFromCoefficients)) + s.Run("TestPolyEval", testWrapper(s.Suite, testPolyEval)) + s.Run("TestPolyClone", testWrapper(s.Suite, testPolyClone)) + s.Run("TestPolyAddSubMul", testWrapper(s.Suite, testPolyAddSubMul)) + s.Run("TestPolyMonomials", testWrapper(s.Suite, testPolyMonomials)) + s.Run("TestPolyReadCoeffs", testWrapper(s.Suite, testPolyReadCoeffs)) + s.Run("TestPolyOddEvenSlicing", testWrapper(s.Suite, testPolyOddEvenSlicing)) + s.Run("TestPolynomialDivision", testWrapper(s.Suite, testPolynomialDivision)) + s.Run("TestDivideByVanishing", testWrapper(s.Suite, testDivideByVanishing)) +} + +func TestSuitePolynomial(t *testing.T) { + suite.Run(t, new(PolynomialTestSuite)) +} diff --git a/wrappers/golang/curves/bn254/tests/scalar_field_test.go b/wrappers/golang/curves/bn254/tests/scalar_field_test.go index 73743e2eb..fbe3686d6 100644 --- a/wrappers/golang/curves/bn254/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bn254/tests/scalar_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( SCALAR_LIMBS = bn254.SCALAR_LIMBS ) -func TestScalarFieldFromLimbs(t *testing.T) { +func testScalarFieldFromLimbs(suite suite.Suite) { emptyField := bn254.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestScalarFieldGetLimbs(t *testing.T) { +func testScalarFieldGetLimbs(suite suite.Suite) { emptyField := bn254.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") } -func TestScalarFieldOne(t *testing.T) { +func testScalarFieldOne(suite suite.Suite) { var emptyField bn254.ScalarField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(SCALAR_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") } -func TestScalarFieldZero(t *testing.T) { +func testScalarFieldZero(suite suite.Suite) { var emptyField bn254.ScalarField emptyField.Zero() limbsZero := make([]uint32, SCALAR_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") } -func TestScalarFieldSize(t *testing.T) { +func testScalarFieldSize(suite suite.Suite) { var emptyField bn254.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestScalarFieldAsPointer(t *testing.T) { +func testScalarFieldAsPointer(suite suite.Suite) { var emptyField bn254.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestScalarFieldFromBytes(t *testing.T) { +func testScalarFieldFromBytes(suite suite.Suite) { var emptyField bn254.ScalarField bytes, expected := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestScalarFieldToBytes(t *testing.T) { +func testScalarFieldToBytes(suite suite.Suite) { var emptyField bn254.ScalarField expected, limbs := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestBn254GenerateScalars(t *testing.T) { +func testBn254GenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := bn254.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := bn254.ScalarField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestBn254MongtomeryConversion(t *testing.T) { +func testBn254MongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := bn254.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestBn254MongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[bn254.ScalarField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) bn254.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ScalarFieldTestSuite struct { + suite.Suite +} + +func (s *ScalarFieldTestSuite) TestScalarField() { + s.Run("TestScalarFieldFromLimbs", testWrapper(s.Suite, testScalarFieldFromLimbs)) + s.Run("TestScalarFieldGetLimbs", testWrapper(s.Suite, testScalarFieldGetLimbs)) + s.Run("TestScalarFieldOne", testWrapper(s.Suite, testScalarFieldOne)) + s.Run("TestScalarFieldZero", testWrapper(s.Suite, testScalarFieldZero)) + s.Run("TestScalarFieldSize", testWrapper(s.Suite, testScalarFieldSize)) + s.Run("TestScalarFieldAsPointer", testWrapper(s.Suite, testScalarFieldAsPointer)) + s.Run("TestScalarFieldFromBytes", testWrapper(s.Suite, testScalarFieldFromBytes)) + s.Run("TestScalarFieldToBytes", testWrapper(s.Suite, testScalarFieldToBytes)) + s.Run("TestBn254GenerateScalars", testWrapper(s.Suite, testBn254GenerateScalars)) + s.Run("TestBn254MongtomeryConversion", testWrapper(s.Suite, testBn254MongtomeryConversion)) +} + +func TestSuiteScalarField(t *testing.T) { + suite.Run(t, new(ScalarFieldTestSuite)) } diff --git a/wrappers/golang/curves/bn254/tests/vec_ops_test.go b/wrappers/golang/curves/bn254/tests/vec_ops_test.go index 3e1215a5f..a179e2cef 100644 --- a/wrappers/golang/curves/bn254/tests/vec_ops_test.go +++ b/wrappers/golang/curves/bn254/tests/vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestBn254VecOps(t *testing.T) { +func testBn254VecOps(suite suite.Suite) { testSize := 1 << 14 a := bn254.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestBn254VecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestBn254Transpose(t *testing.T) { +func testBn254Transpose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestBn254Transpose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestBn254Transpose(t *testing.T) { output := make(core.HostSlice[bn254.ScalarField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type Bn254VecOpsTestSuite struct { + suite.Suite +} + +func (s *Bn254VecOpsTestSuite) TestBn254VecOps() { + s.Run("TestBn254VecOps", testWrapper(s.Suite, testBn254VecOps)) + s.Run("TestBn254Transpose", testWrapper(s.Suite, testBn254Transpose)) +} + +func TestSuiteBn254VecOps(t *testing.T) { + suite.Run(t, new(Bn254VecOpsTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/base_field.go b/wrappers/golang/curves/bw6761/base_field.go index b401e49d7..8a51fc944 100644 --- a/wrappers/golang/curves/bw6761/base_field.go +++ b/wrappers/golang/curves/bw6761/base_field.go @@ -53,6 +53,16 @@ func (f *BaseField) Zero() BaseField { return *f } +func (f *BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *BaseField) One() BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bw6761/curve.go b/wrappers/golang/curves/bw6761/curve.go index cbd4903ca..54da95360 100644 --- a/wrappers/golang/curves/bw6761/curve.go +++ b/wrappers/golang/curves/bw6761/curve.go @@ -96,6 +96,10 @@ func (a *Affine) Zero() Affine { return *a } +func (a *Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *Affine) FromLimbs(x, y []uint32) Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *Affine) FromLimbs(x, y []uint32) Affine { func (a Affine) ToProjective() Projective { var p Projective - cA := (*C.affine_t)(unsafe.Pointer(&a)) - cP := (*C.projective_t)(unsafe.Pointer(&p)) - C.bw6_761_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.affine_t)(unsafe.Pointer(&a)) + // cP := (*C.projective_t)(unsafe.Pointer(&p)) + // C.bw6_761_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bw6761/g2/curve.go b/wrappers/golang/curves/bw6761/g2/curve.go index c3e8e0f07..0661e5703 100644 --- a/wrappers/golang/curves/bw6761/g2/curve.go +++ b/wrappers/golang/curves/bw6761/g2/curve.go @@ -96,6 +96,10 @@ func (a *G2Affine) Zero() G2Affine { return *a } +func (a *G2Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *G2Affine) FromLimbs(x, y []uint32) G2Affine { func (a G2Affine) ToProjective() G2Projective { var p G2Projective - cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) - cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) - C.bw6_761_g2_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.g2_affine_t)(unsafe.Pointer(&a)) + // cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + // C.bw6_761_g2_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/bw6761/g2/g2base_field.go b/wrappers/golang/curves/bw6761/g2/g2base_field.go index c4073af97..8087a2bbf 100644 --- a/wrappers/golang/curves/bw6761/g2/g2base_field.go +++ b/wrappers/golang/curves/bw6761/g2/g2base_field.go @@ -53,6 +53,16 @@ func (f *G2BaseField) Zero() G2BaseField { return *f } +func (f *G2BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *G2BaseField) One() G2BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bw6761/scalar_field.go b/wrappers/golang/curves/bw6761/scalar_field.go index 5af4056d3..5cb53afb4 100644 --- a/wrappers/golang/curves/bw6761/scalar_field.go +++ b/wrappers/golang/curves/bw6761/scalar_field.go @@ -60,6 +60,16 @@ func (f *ScalarField) Zero() ScalarField { return *f } +func (f *ScalarField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ScalarField) One() ScalarField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/bw6761/tests/base_field_test.go b/wrappers/golang/curves/bw6761/tests/base_field_test.go index b00f54bd3..48cd1040b 100644 --- a/wrappers/golang/curves/bw6761/tests/base_field_test.go +++ b/wrappers/golang/curves/bw6761/tests/base_field_test.go @@ -5,85 +5,105 @@ import ( bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( BASE_LIMBS = bw6_761.BASE_LIMBS ) -func TestBaseFieldFromLimbs(t *testing.T) { +func testBaseFieldFromLimbs(suite suite.Suite) { emptyField := bw6_761.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestBaseFieldGetLimbs(t *testing.T) { +func testBaseFieldGetLimbs(suite suite.Suite) { emptyField := bw6_761.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") } -func TestBaseFieldOne(t *testing.T) { +func testBaseFieldOne(suite suite.Suite) { var emptyField bw6_761.BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") } -func TestBaseFieldZero(t *testing.T) { +func testBaseFieldZero(suite suite.Suite) { var emptyField bw6_761.BaseField emptyField.Zero() limbsZero := make([]uint32, BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") } -func TestBaseFieldSize(t *testing.T) { +func testBaseFieldSize(suite suite.Suite) { var emptyField bw6_761.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestBaseFieldAsPointer(t *testing.T) { +func testBaseFieldAsPointer(suite suite.Suite) { var emptyField bw6_761.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestBaseFieldFromBytes(t *testing.T) { +func testBaseFieldFromBytes(suite suite.Suite) { var emptyField bw6_761.BaseField bytes, expected := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestBaseFieldToBytes(t *testing.T) { +func testBaseFieldToBytes(suite suite.Suite) { var emptyField bw6_761.BaseField expected, limbs := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type BaseFieldTestSuite struct { + suite.Suite +} + +func (s *BaseFieldTestSuite) TestBaseField() { + s.Run("TestBaseFieldFromLimbs", testWrapper(s.Suite, testBaseFieldFromLimbs)) + s.Run("TestBaseFieldGetLimbs", testWrapper(s.Suite, testBaseFieldGetLimbs)) + s.Run("TestBaseFieldOne", testWrapper(s.Suite, testBaseFieldOne)) + s.Run("TestBaseFieldZero", testWrapper(s.Suite, testBaseFieldZero)) + s.Run("TestBaseFieldSize", testWrapper(s.Suite, testBaseFieldSize)) + s.Run("TestBaseFieldAsPointer", testWrapper(s.Suite, testBaseFieldAsPointer)) + s.Run("TestBaseFieldFromBytes", testWrapper(s.Suite, testBaseFieldFromBytes)) + s.Run("TestBaseFieldToBytes", testWrapper(s.Suite, testBaseFieldToBytes)) + +} + +func TestSuiteBaseField(t *testing.T) { + suite.Run(t, new(BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/tests/curve_test.go b/wrappers/golang/curves/bw6761/tests/curve_test.go index 30a40c578..607dce82d 100644 --- a/wrappers/golang/curves/bw6761/tests/curve_test.go +++ b/wrappers/golang/curves/bw6761/tests/curve_test.go @@ -5,15 +5,15 @@ import ( bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestAffineZero(t *testing.T) { +func testAffineZero(suite suite.Suite) { var fieldZero = bw6_761.BaseField{} var affineZero bw6_761.Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestAffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestAffineFromLimbs(t *testing.T) { +func testAffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var affine bw6_761.Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestAffineToProjective(t *testing.T) { +func testAffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bw6_761.BaseField @@ -49,31 +49,31 @@ func TestAffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestProjectiveZero(t *testing.T) { +func testProjectiveZero(suite suite.Suite) { var projectiveZero bw6_761.Projective projectiveZero.Zero() var fieldZero = bw6_761.BaseField{} var fieldOne bw6_761.BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var projective bw6_761.Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestProjectiveFromLimbs(t *testing.T) { +func testProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestProjectiveFromLimbs(t *testing.T) { var projective bw6_761.Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestProjectiveFromAffine(t *testing.T) { +func testProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne bw6_761.BaseField @@ -100,5 +100,22 @@ func TestProjectiveFromAffine(t *testing.T) { var projectivePoint bw6_761.Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type CurveTestSuite struct { + suite.Suite +} + +func (s *CurveTestSuite) TestCurve() { + s.Run("TestAffineZero", testWrapper(s.Suite, testAffineZero)) + s.Run("TestAffineFromLimbs", testWrapper(s.Suite, testAffineFromLimbs)) + s.Run("TestAffineToProjective", testWrapper(s.Suite, testAffineToProjective)) + s.Run("TestProjectiveZero", testWrapper(s.Suite, testProjectiveZero)) + s.Run("TestProjectiveFromLimbs", testWrapper(s.Suite, testProjectiveFromLimbs)) + s.Run("TestProjectiveFromAffine", testWrapper(s.Suite, testProjectiveFromAffine)) +} + +func TestSuiteCurve(t *testing.T) { + suite.Run(t, new(CurveTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/tests/ecntt_test.go b/wrappers/golang/curves/bw6761/tests/ecntt_test.go index d334e1218..95ee6e7dc 100644 --- a/wrappers/golang/curves/bw6761/tests/ecntt_test.go +++ b/wrappers/golang/curves/bw6761/tests/ecntt_test.go @@ -9,10 +9,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime/config_extension" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestECNtt(t *testing.T) { +func testECNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() ext := config_extension.Create() ext.SetInt(core.CUDA_NTT_ALGORITHM, int(core.Radix2)) @@ -31,7 +31,19 @@ func TestECNtt(t *testing.T) { output := make(core.HostSlice[bw6_761.Projective], testSize) e := ecntt.ECNtt(pointsCopy, core.KForward, &cfg, output) - assert.Equal(t, runtime.Success, e, "ECNtt failed") + suite.Equal(runtime.Success, e, "ECNtt failed") } } } + +type ECNttTestSuite struct { + suite.Suite +} + +func (s *ECNttTestSuite) TestECNtt() { + s.Run("TestECNtt", testWrapper(s.Suite, testECNtt)) +} + +func TestSuiteECNtt(t *testing.T) { + suite.Run(t, new(ECNttTestSuite)) +} diff --git a/wrappers/golang/curves/bw6761/tests/g2_curve_test.go b/wrappers/golang/curves/bw6761/tests/g2_curve_test.go index d33271eab..32f9dc847 100644 --- a/wrappers/golang/curves/bw6761/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bw6761/tests/g2_curve_test.go @@ -5,15 +5,15 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestG2AffineZero(t *testing.T) { +func testG2AffineZero(suite suite.Suite) { var fieldZero = g2.G2BaseField{} var affineZero g2.G2Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestG2AffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestG2AffineFromLimbs(t *testing.T) { +func testG2AffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var affine g2.G2Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestG2AffineToProjective(t *testing.T) { +func testG2AffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -49,31 +49,31 @@ func TestG2AffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestG2ProjectiveZero(t *testing.T) { +func testG2ProjectiveZero(suite suite.Suite) { var projectiveZero g2.G2Projective projectiveZero.Zero() var fieldZero = g2.G2BaseField{} var fieldOne g2.G2BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestG2ProjectiveFromLimbs(t *testing.T) { +func testG2ProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestG2ProjectiveFromLimbs(t *testing.T) { var projective g2.G2Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestG2ProjectiveFromAffine(t *testing.T) { +func testG2ProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) var fieldOne g2.G2BaseField @@ -100,5 +100,22 @@ func TestG2ProjectiveFromAffine(t *testing.T) { var projectivePoint g2.G2Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type G2CurveTestSuite struct { + suite.Suite +} + +func (s *G2CurveTestSuite) TestG2Curve() { + s.Run("TestG2AffineZero", testWrapper(s.Suite, testG2AffineZero)) + s.Run("TestG2AffineFromLimbs", testWrapper(s.Suite, testG2AffineFromLimbs)) + s.Run("TestG2AffineToProjective", testWrapper(s.Suite, testG2AffineToProjective)) + s.Run("TestG2ProjectiveZero", testWrapper(s.Suite, testG2ProjectiveZero)) + s.Run("TestG2ProjectiveFromLimbs", testWrapper(s.Suite, testG2ProjectiveFromLimbs)) + s.Run("TestG2ProjectiveFromAffine", testWrapper(s.Suite, testG2ProjectiveFromAffine)) +} + +func TestSuiteG2Curve(t *testing.T) { + suite.Run(t, new(G2CurveTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/tests/g2_g2base_field_test.go b/wrappers/golang/curves/bw6761/tests/g2_g2base_field_test.go index ec6496065..40051add8 100644 --- a/wrappers/golang/curves/bw6761/tests/g2_g2base_field_test.go +++ b/wrappers/golang/curves/bw6761/tests/g2_g2base_field_test.go @@ -5,85 +5,105 @@ import ( bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/g2" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( G2BASE_LIMBS = bw6_761.G2BASE_LIMBS ) -func TestG2BaseFieldFromLimbs(t *testing.T) { +func testG2BaseFieldFromLimbs(suite suite.Suite) { emptyField := bw6_761.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestG2BaseFieldGetLimbs(t *testing.T) { +func testG2BaseFieldGetLimbs(suite suite.Suite) { emptyField := bw6_761.G2BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the G2BaseField's limbs") } -func TestG2BaseFieldOne(t *testing.T) { +func testG2BaseFieldOne(suite suite.Suite) { var emptyField bw6_761.G2BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(G2BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "G2BaseField with limbs to field one did not work") } -func TestG2BaseFieldZero(t *testing.T) { +func testG2BaseFieldZero(suite suite.Suite) { var emptyField bw6_761.G2BaseField emptyField.Zero() limbsZero := make([]uint32, G2BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "G2BaseField with limbs to field zero failed") } -func TestG2BaseFieldSize(t *testing.T) { +func testG2BaseFieldSize(suite suite.Suite) { var emptyField bw6_761.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestG2BaseFieldAsPointer(t *testing.T) { +func testG2BaseFieldAsPointer(suite suite.Suite) { var emptyField bw6_761.G2BaseField randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestG2BaseFieldFromBytes(t *testing.T) { +func testG2BaseFieldFromBytes(suite suite.Suite) { var emptyField bw6_761.G2BaseField bytes, expected := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestG2BaseFieldToBytes(t *testing.T) { +func testG2BaseFieldToBytes(suite suite.Suite) { var emptyField bw6_761.G2BaseField expected, limbs := test_helpers.GenerateBytesArray(int(G2BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type G2BaseFieldTestSuite struct { + suite.Suite +} + +func (s *G2BaseFieldTestSuite) TestG2BaseField() { + s.Run("TestG2BaseFieldFromLimbs", testWrapper(s.Suite, testG2BaseFieldFromLimbs)) + s.Run("TestG2BaseFieldGetLimbs", testWrapper(s.Suite, testG2BaseFieldGetLimbs)) + s.Run("TestG2BaseFieldOne", testWrapper(s.Suite, testG2BaseFieldOne)) + s.Run("TestG2BaseFieldZero", testWrapper(s.Suite, testG2BaseFieldZero)) + s.Run("TestG2BaseFieldSize", testWrapper(s.Suite, testG2BaseFieldSize)) + s.Run("TestG2BaseFieldAsPointer", testWrapper(s.Suite, testG2BaseFieldAsPointer)) + s.Run("TestG2BaseFieldFromBytes", testWrapper(s.Suite, testG2BaseFieldFromBytes)) + s.Run("TestG2BaseFieldToBytes", testWrapper(s.Suite, testG2BaseFieldToBytes)) + +} + +func TestSuiteG2BaseField(t *testing.T) { + suite.Run(t, new(G2BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/tests/g2_msm_test.go b/wrappers/golang/curves/bw6761/tests/g2_msm_test.go index ab4e60118..8b4b377c3 100644 --- a/wrappers/golang/curves/bw6761/tests/g2_msm_test.go +++ b/wrappers/golang/curves/bw6761/tests/g2_msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761" @@ -35,7 +35,7 @@ func projectiveToGnarkAffineG2(p g2.G2Projective) bw6761.G2Affine { return bw6761.G2Affine{X: *x, Y: *y} } -func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBw6_761.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2(suite suite.Suite, scalars core.HostSlice[icicleBw6_761.ScalarField], points core.HostSlice[g2.G2Affine], out g2.G2Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -47,10 +47,10 @@ func testAgainstGnarkCryptoMsmG2(t *testing.T, scalars core.HostSlice[icicleBw6_ pointsFp[i] = projectiveToGnarkAffineG2(v.ToProjective()) } - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bw6761.G2Affine], out g2.G2Projective) { +func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bw6761.G2Affine], out g2.G2Projective) { var msmRes bw6761.G2Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -59,7 +59,7 @@ func testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t *testing.T, scalarsFr core.Ho icicleResAffine := projectiveToGnarkAffineG2(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleAffineToG2Affine(iciclePoints []g2.G2Affine) []bw6761.G2Affine { @@ -79,7 +79,7 @@ func convertIcicleAffineToG2Affine(iciclePoints []g2.G2Affine) []bw6761.G2Affine return points } -func TestMSMG2(t *testing.T) { +func testMSMG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -93,11 +93,11 @@ func TestMSMG2(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -105,11 +105,11 @@ func TestMSMG2(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2GnarkCryptoTypes(t *testing.T) { +func testMSMG2GnarkCryptoTypes(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -129,22 +129,22 @@ func TestMSMG2GnarkCryptoTypes(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = g2.G2Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmG2GnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMG2Batch(t *testing.T) { +func testMSMG2Batch(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -157,10 +157,10 @@ func TestMSMG2Batch(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -169,15 +169,15 @@ func TestMSMG2Batch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsG2(t *testing.T) { +func testPrecomputePointsG2(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 @@ -194,20 +194,20 @@ func TestPrecomputePointsG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -217,13 +217,13 @@ func TestPrecomputePointsG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBasesG2(t *testing.T) { +func testPrecomputePointsSharedBasesG2(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -239,18 +239,18 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = g2.G2PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p g2.G2Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -260,13 +260,13 @@ func TestPrecomputePointsSharedBasesG2(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsmG2(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsmG2(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMG2SkewedDistribution(t *testing.T) { +func testMSMG2SkewedDistribution(suite suite.Suite) { cfg := g2.G2GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -285,19 +285,19 @@ func TestMSMG2SkewedDistribution(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } } -func TestMSMG2MultiDevice(t *testing.T) { +func testMSMG2MultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -321,11 +321,11 @@ func TestMSMG2MultiDevice(t *testing.T) { var p g2.G2Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = g2.G2Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[g2.G2Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -333,9 +333,27 @@ func TestMSMG2MultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsmG2(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsmG2(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMG2TestSuite struct { + suite.Suite +} + +func (s *MSMG2TestSuite) TestMSMG2() { + s.Run("TestMSMG2", testWrapper(s.Suite, testMSMG2)) + s.Run("TestMSMG2GnarkCryptoTypes", testWrapper(s.Suite, testMSMG2GnarkCryptoTypes)) + s.Run("TestMSMG2Batch", testWrapper(s.Suite, testMSMG2Batch)) + s.Run("TestPrecomputePointsG2", testWrapper(s.Suite, testPrecomputePointsG2)) + s.Run("TestPrecomputePointsSharedBasesG2", testWrapper(s.Suite, testPrecomputePointsSharedBasesG2)) + s.Run("TestMSMG2SkewedDistribution", testWrapper(s.Suite, testMSMG2SkewedDistribution)) + s.Run("TestMSMG2MultiDevice", testWrapper(s.Suite, testMSMG2MultiDevice)) +} + +func TestSuiteMSMG2(t *testing.T) { + suite.Run(t, new(MSMG2TestSuite)) +} diff --git a/wrappers/golang/curves/bw6761/tests/main_test.go b/wrappers/golang/curves/bw6761/tests/main_test.go index 39100da61..5fd5fc2d3 100644 --- a/wrappers/golang/curves/bw6761/tests/main_test.go +++ b/wrappers/golang/curves/bw6761/tests/main_test.go @@ -2,12 +2,14 @@ package tests import ( "fmt" - "testing" - "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761" ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + "github.com/stretchr/testify/suite" + "os" + "sync" + "testing" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/fft" ) @@ -16,7 +18,10 @@ const ( largestTestSize = 20 ) -var DEVICE runtime.Device +var ( + DEVICE runtime.Device + exitCode int +) func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcicleError { rouMont, _ := fft.Generator(uint64(1 << largestTestSize)) @@ -29,6 +34,18 @@ func initDomain(largestTestSize int, cfg core.NTTInitDomainConfig) runtime.EIcic return e } +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} + func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() devices, e := runtime.GetRegisteredDevices() @@ -36,6 +53,7 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) @@ -50,8 +68,10 @@ func TestMain(m *testing.M) { } } + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() // release domain e = ntt.ReleaseDomain() @@ -63,4 +83,6 @@ func TestMain(m *testing.M) { } } } + + os.Exit(exitCode) } diff --git a/wrappers/golang/curves/bw6761/tests/msm_test.go b/wrappers/golang/curves/bw6761/tests/msm_test.go index 966a56487..5be5af22e 100644 --- a/wrappers/golang/curves/bw6761/tests/msm_test.go +++ b/wrappers/golang/curves/bw6761/tests/msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761" @@ -35,7 +35,7 @@ func projectiveToGnarkAffine(p icicleBw6_761.Projective) bw6761.G1Affine { return bw6761.G1Affine{X: *x, Y: *y} } -func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBw6_761.ScalarField], points core.HostSlice[icicleBw6_761.Affine], out icicleBw6_761.Projective) { +func testAgainstGnarkCryptoMsm(suite suite.Suite, scalars core.HostSlice[icicleBw6_761.ScalarField], points core.HostSlice[icicleBw6_761.Affine], out icicleBw6_761.Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -47,10 +47,10 @@ func testAgainstGnarkCryptoMsm(t *testing.T, scalars core.HostSlice[icicleBw6_76 pointsFp[i] = projectiveToGnarkAffine(v.ToProjective()) } - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bw6761.G1Affine], out icicleBw6_761.Projective) { +func testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[bw6761.G1Affine], out icicleBw6_761.Projective) { var msmRes bw6761.G1Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -59,7 +59,7 @@ func testAgainstGnarkCryptoMsmGnarkCryptoTypes(t *testing.T, scalarsFr core.Host icicleResAffine := projectiveToGnarkAffine(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } func convertIcicleAffineToG1Affine(iciclePoints []icicleBw6_761.Affine) []bw6761.G1Affine { @@ -79,7 +79,7 @@ func convertIcicleAffineToG1Affine(iciclePoints []icicleBw6_761.Affine) []bw6761 return points } -func TestMSM(t *testing.T) { +func testMSM(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -93,11 +93,11 @@ func TestMSM(t *testing.T) { var p icicleBw6_761.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -105,11 +105,11 @@ func TestMSM(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMGnarkCryptoTypes(t *testing.T) { +func testMSMGnarkCryptoTypes(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -129,22 +129,22 @@ func TestMSMGnarkCryptoTypes(t *testing.T) { var p icicleBw6_761.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = msm.Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsmGnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsmGnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } -func TestMSMBatch(t *testing.T) { +func testMSMBatch(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -157,10 +157,10 @@ func TestMSMBatch(t *testing.T) { var p icicleBw6_761.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -169,15 +169,15 @@ func TestMSMBatch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePoints(t *testing.T) { +func testPrecomputePoints(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 @@ -194,20 +194,20 @@ func TestPrecomputePoints(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBw6_761.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -217,13 +217,13 @@ func TestPrecomputePoints(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestPrecomputePointsSharedBases(t *testing.T) { +func testPrecomputePointsSharedBases(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -239,18 +239,18 @@ func TestPrecomputePointsSharedBases(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleBw6_761.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -260,13 +260,13 @@ func TestPrecomputePointsSharedBases(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0:size] out := outHost[i] - testAgainstGnarkCryptoMsm(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm(suite, scalarsSlice, pointsSlice, out) } } } } -func TestMSMSkewedDistribution(t *testing.T) { +func testMSMSkewedDistribution(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -285,19 +285,19 @@ func TestMSMSkewedDistribution(t *testing.T) { var p icicleBw6_761.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } } -func TestMSMMultiDevice(t *testing.T) { +func testMSMMultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -321,11 +321,11 @@ func TestMSMMultiDevice(t *testing.T) { var p icicleBw6_761.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleBw6_761.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -333,9 +333,27 @@ func TestMSMMultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Check with gnark-crypto - testAgainstGnarkCryptoMsm(t, scalars, points, outHost[0]) + testAgainstGnarkCryptoMsm(suite, scalars, points, outHost[0]) } }) } wg.Wait() } + +type MSMTestSuite struct { + suite.Suite +} + +func (s *MSMTestSuite) TestMSM() { + s.Run("TestMSM", testWrapper(s.Suite, testMSM)) + s.Run("TestMSMGnarkCryptoTypes", testWrapper(s.Suite, testMSMGnarkCryptoTypes)) + s.Run("TestMSMBatch", testWrapper(s.Suite, testMSMBatch)) + s.Run("TestPrecomputePoints", testWrapper(s.Suite, testPrecomputePoints)) + s.Run("TestPrecomputePointsSharedBases", testWrapper(s.Suite, testPrecomputePointsSharedBases)) + s.Run("TestMSMSkewedDistribution", testWrapper(s.Suite, testMSMSkewedDistribution)) + s.Run("TestMSMMultiDevice", testWrapper(s.Suite, testMSMMultiDevice)) +} + +func TestSuiteMSM(t *testing.T) { + suite.Run(t, new(MSMTestSuite)) +} diff --git a/wrappers/golang/curves/bw6761/tests/ntt_test.go b/wrappers/golang/curves/bw6761/tests/ntt_test.go index 86e7d783f..42a7bfe4a 100644 --- a/wrappers/golang/curves/bw6761/tests/ntt_test.go +++ b/wrappers/golang/curves/bw6761/tests/ntt_test.go @@ -11,10 +11,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bw6_761.ScalarField], output core.HostSlice[bw6_761.ScalarField], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNtt(suite suite.Suite, size int, scalars core.HostSlice[bw6_761.ScalarField], output core.HostSlice[bw6_761.ScalarField], order core.Ordering, direction core.NTTDir) { scalarsFr := make([]fr.Element, size) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -26,10 +26,10 @@ func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[bw outputAsFr[i] = slice64 } - testAgainstGnarkCryptoNttGnarkTypes(t, size, scalarsFr, outputAsFr, order, direction) + testAgainstGnarkCryptoNttGnarkTypes(suite, size, scalarsFr, outputAsFr, order, direction) } -func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNttGnarkTypes(suite suite.Suite, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { domainWithPrecompute := fft.NewDomain(uint64(size)) // DIT + BitReverse == Ordering.kRR // DIT == Ordering.kRN @@ -51,25 +51,19 @@ func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core. if order == core.KNN || order == core.KRR { fft.BitReverse(scalarsFr) } - assert.Equal(t, scalarsFr, outputAsFr) + suite.Equal(scalarsFr, outputAsFr) } -func TestNTTGetDefaultConfig(t *testing.T) { +func testNTTGetDefaultConfig(suite suite.Suite) { actual := ntt.GetDefaultNttConfig() expected := test_helpers.GenerateLimbOne(int(bw6_761.SCALAR_LIMBS)) - assert.Equal(t, expected, actual.CosetGen[:]) + suite.Equal(expected, actual.CosetGen[:]) cosetGenField := bw6_761.ScalarField{} cosetGenField.One() - assert.ElementsMatch(t, cosetGenField.GetLimbs(), actual.CosetGen) + suite.ElementsMatch(cosetGenField.GetLimbs(), actual.CosetGen) } -func TestInitDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - cfg := core.GetDefaultNTTInitDomainConfig() - assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) }) -} - -func TestNtt(t *testing.T) { +func testNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bw6_761.GenerateScalars(1 << largestTestSize) @@ -87,11 +81,11 @@ func TestNtt(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttFrElement(t *testing.T) { +func testNttFrElement(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := make([]fr.Element, 4) var x fr.Element @@ -114,12 +108,12 @@ func TestNttFrElement(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNttGnarkTypes(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNttGnarkTypes(suite, testSize, scalarsCopy, output, v, core.KForward) } } } -func TestNttDeviceAsync(t *testing.T) { +func testNttDeviceAsync(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := bw6_761.GenerateScalars(1 << largestTestSize) @@ -150,13 +144,13 @@ func TestNttDeviceAsync(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, direction) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, direction) } } } } -func TestNttBatch(t *testing.T) { +func testNttBatch(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 10 largestBatchSize := 20 @@ -194,16 +188,26 @@ func TestNttBatch(t *testing.T) { domainWithPrecompute.FFT(scalarsFr, fft.DIF) fft.BitReverse(scalarsFr) - if !assert.True(t, reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { - t.FailNow() + if !suite.True(reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { + suite.T().FailNow() } } } } } -func TestReleaseDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - e := ntt.ReleaseDomain() - assert.Equal(t, runtime.Success, e, "ReleasDomain failed") +type NTTTestSuite struct { + suite.Suite +} + +func (s *NTTTestSuite) TestNTT() { + s.Run("TestNTTGetDefaultConfig", testWrapper(s.Suite, testNTTGetDefaultConfig)) + s.Run("TestNTT", testWrapper(s.Suite, testNtt)) + s.Run("TestNTTFrElement", testWrapper(s.Suite, testNttFrElement)) + s.Run("TestNttDeviceAsync", testWrapper(s.Suite, testNttDeviceAsync)) + s.Run("TestNttBatch", testWrapper(s.Suite, testNttBatch)) +} + +func TestSuiteNTT(t *testing.T) { + suite.Run(t, new(NTTTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/tests/polynomial_test.go b/wrappers/golang/curves/bw6761/tests/polynomial_test.go index d8fdfc55c..6f4714681 100644 --- a/wrappers/golang/curves/bw6761/tests/polynomial_test.go +++ b/wrappers/golang/curves/bw6761/tests/polynomial_test.go @@ -6,10 +6,9 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761" - // "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/polynomial" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) var one, two, three, four, five bw6_761.ScalarField @@ -41,7 +40,7 @@ func vecOp(a, b bw6_761.ScalarField, op core.VecOps) bw6_761.ScalarField { return out[0] } -func TestPolyCreateFromCoefficients(t *testing.T) { +func testPolyCreateFromCoefficients(suite suite.Suite) { scalars := bw6_761.GenerateScalars(33) var uniPoly polynomial.DensePolynomial @@ -49,7 +48,7 @@ func TestPolyCreateFromCoefficients(t *testing.T) { poly.Print() } -func TestPolyEval(t *testing.T) { +func testPolyEval(suite suite.Suite) { // testing correct evaluation of f(8) for f(x)=4x^2+2x+5 coeffs := core.HostSliceFromElements([]bw6_761.ScalarField{five, two, four}) var f polynomial.DensePolynomial @@ -62,10 +61,10 @@ func TestPolyEval(t *testing.T) { evals := make(core.HostSlice[bw6_761.ScalarField], 1) fEvaled := f.EvalOnDomain(domains, evals) var expected bw6_761.ScalarField - assert.Equal(t, expected.FromUint32(277), fEvaled.(core.HostSlice[bw6_761.ScalarField])[0]) + suite.Equal(expected.FromUint32(277), fEvaled.(core.HostSlice[bw6_761.ScalarField])[0]) } -func TestPolyClone(t *testing.T) { +func testPolyClone(suite suite.Suite) { f := randomPoly(8) x := rand() fx := f.Eval(x) @@ -76,11 +75,11 @@ func TestPolyClone(t *testing.T) { gx := g.Eval(x) fgx := fg.Eval(x) - assert.Equal(t, fx, gx) - assert.Equal(t, vecOp(fx, gx, core.Add), fgx) + suite.Equal(fx, gx) + suite.Equal(vecOp(fx, gx, core.Add), fgx) } -func TestPolyAddSubMul(t *testing.T) { +func testPolyAddSubMul(suite suite.Suite) { testSize := 1 << 10 f := randomPoly(testSize) g := randomPoly(testSize) @@ -91,26 +90,26 @@ func TestPolyAddSubMul(t *testing.T) { polyAdd := f.Add(&g) fxAddgx := vecOp(fx, gx, core.Add) - assert.Equal(t, polyAdd.Eval(x), fxAddgx) + suite.Equal(polyAdd.Eval(x), fxAddgx) polySub := f.Subtract(&g) fxSubgx := vecOp(fx, gx, core.Sub) - assert.Equal(t, polySub.Eval(x), fxSubgx) + suite.Equal(polySub.Eval(x), fxSubgx) polyMul := f.Multiply(&g) fxMulgx := vecOp(fx, gx, core.Mul) - assert.Equal(t, polyMul.Eval(x), fxMulgx) + suite.Equal(polyMul.Eval(x), fxMulgx) s1 := rand() polMulS1 := f.MultiplyByScalar(s1) - assert.Equal(t, polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) + suite.Equal(polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) s2 := rand() polMulS2 := f.MultiplyByScalar(s2) - assert.Equal(t, polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) + suite.Equal(polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) } -func TestPolyMonomials(t *testing.T) { +func testPolyMonomials(suite suite.Suite) { var zero bw6_761.ScalarField var f polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements([]bw6_761.ScalarField{one, zero, two})) @@ -119,20 +118,20 @@ func TestPolyMonomials(t *testing.T) { fx := f.Eval(x) f.AddMonomial(three, 1) fxAdded := f.Eval(x) - assert.Equal(t, fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) + suite.Equal(fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) f.SubMonomial(one, 0) fxSub := f.Eval(x) - assert.Equal(t, fxSub, vecOp(fxAdded, one, core.Sub)) + suite.Equal(fxSub, vecOp(fxAdded, one, core.Sub)) } -func TestPolyReadCoeffs(t *testing.T) { +func testPolyReadCoeffs(suite suite.Suite) { var f polynomial.DensePolynomial coeffs := core.HostSliceFromElements([]bw6_761.ScalarField{one, two, three, four}) f.CreateFromCoeffecitients(coeffs) coeffsCopied := make(core.HostSlice[bw6_761.ScalarField], coeffs.Len()) _, _ = f.CopyCoeffsRange(0, coeffs.Len()-1, coeffsCopied) - assert.ElementsMatch(t, coeffs, coeffsCopied) + suite.ElementsMatch(coeffs, coeffsCopied) var coeffsDevice core.DeviceSlice coeffsDevice.Malloc(one.Size(), coeffs.Len()) @@ -140,16 +139,16 @@ func TestPolyReadCoeffs(t *testing.T) { coeffsHost := make(core.HostSlice[bw6_761.ScalarField], coeffs.Len()) coeffsHost.CopyFromDevice(&coeffsDevice) - assert.ElementsMatch(t, coeffs, coeffsHost) + suite.ElementsMatch(coeffs, coeffsHost) } -func TestPolyOddEvenSlicing(t *testing.T) { +func testPolyOddEvenSlicing(suite suite.Suite) { size := 1<<10 - 3 f := randomPoly(size) even := f.Even() odd := f.Odd() - assert.Equal(t, f.Degree(), even.Degree()+odd.Degree()+1) + suite.Equal(f.Degree(), even.Degree()+odd.Degree()+1) x := rand() var evenExpected, oddExpected bw6_761.ScalarField @@ -164,13 +163,13 @@ func TestPolyOddEvenSlicing(t *testing.T) { } evenEvaled := even.Eval(x) - assert.Equal(t, evenExpected, evenEvaled) + suite.Equal(evenExpected, evenEvaled) oddEvaled := odd.Eval(x) - assert.Equal(t, oddExpected, oddEvaled) + suite.Equal(oddExpected, oddEvaled) } -func TestPolynomialDivision(t *testing.T) { +func testPolynomialDivision(suite suite.Suite) { // divide f(x)/g(x), compute q(x), r(x) and check f(x)=q(x)*g(x)+r(x) var f, g polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements(bw6_761.GenerateScalars(1 << 4))) @@ -184,10 +183,10 @@ func TestPolynomialDivision(t *testing.T) { x := bw6_761.GenerateScalars(1)[0] fEval := f.Eval(x) fReconEval := fRecon.Eval(x) - assert.Equal(t, fEval, fReconEval) + suite.Equal(fEval, fReconEval) } -func TestDivideByVanishing(t *testing.T) { +func testDivideByVanishing(suite suite.Suite) { // poly of x^4-1 vanishes ad 4th rou var zero bw6_761.ScalarField minus_one := vecOp(zero, one, core.Sub) @@ -200,31 +199,51 @@ func TestDivideByVanishing(t *testing.T) { fv := f.Multiply(&v) fDegree := f.Degree() fvDegree := fv.Degree() - assert.Equal(t, fDegree+4, fvDegree) + suite.Equal(fDegree+4, fvDegree) fReconstructed := fv.DivideByVanishing(4) - assert.Equal(t, fDegree, fReconstructed.Degree()) + suite.Equal(fDegree, fReconstructed.Degree()) x := rand() - assert.Equal(t, f.Eval(x), fReconstructed.Eval(x)) + suite.Equal(f.Eval(x), fReconstructed.Eval(x)) } -// func TestPolySlice(t *testing.T) { +// func TestPolySlice(suite suite.Suite) { // size := 4 // coeffs := bw6_761.GenerateScalars(size) // var f DensePolynomial // f.CreateFromCoeffecitients(coeffs) // fSlice := f.AsSlice() -// assert.True(t, fSlice.IsOnDevice()) -// assert.Equal(t, size, fSlice.Len()) +// suite.True(fSlice.IsOnDevice()) +// suite.Equal(size, fSlice.Len()) // hostSlice := make(core.HostSlice[bw6_761.ScalarField], size) // hostSlice.CopyFromDevice(fSlice) -// assert.Equal(t, coeffs, hostSlice) +// suite.Equal(coeffs, hostSlice) // cfg := ntt.GetDefaultNttConfig() // res := make(core.HostSlice[bw6_761.ScalarField], size) // ntt.Ntt(fSlice, core.KForward, cfg, res) -// assert.Equal(t, f.Eval(one), res[0]) +// suite.Equal(f.Eval(one), res[0]) // } + +type PolynomialTestSuite struct { + suite.Suite +} + +func (s *PolynomialTestSuite) TestPolynomial() { + s.Run("TestPolyCreateFromCoefficients", testWrapper(s.Suite, testPolyCreateFromCoefficients)) + s.Run("TestPolyEval", testWrapper(s.Suite, testPolyEval)) + s.Run("TestPolyClone", testWrapper(s.Suite, testPolyClone)) + s.Run("TestPolyAddSubMul", testWrapper(s.Suite, testPolyAddSubMul)) + s.Run("TestPolyMonomials", testWrapper(s.Suite, testPolyMonomials)) + s.Run("TestPolyReadCoeffs", testWrapper(s.Suite, testPolyReadCoeffs)) + s.Run("TestPolyOddEvenSlicing", testWrapper(s.Suite, testPolyOddEvenSlicing)) + s.Run("TestPolynomialDivision", testWrapper(s.Suite, testPolynomialDivision)) + s.Run("TestDivideByVanishing", testWrapper(s.Suite, testDivideByVanishing)) +} + +func TestSuitePolynomial(t *testing.T) { + suite.Run(t, new(PolynomialTestSuite)) +} diff --git a/wrappers/golang/curves/bw6761/tests/scalar_field_test.go b/wrappers/golang/curves/bw6761/tests/scalar_field_test.go index 52bce5470..cafaf1a6a 100644 --- a/wrappers/golang/curves/bw6761/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bw6761/tests/scalar_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( SCALAR_LIMBS = bw6_761.SCALAR_LIMBS ) -func TestScalarFieldFromLimbs(t *testing.T) { +func testScalarFieldFromLimbs(suite suite.Suite) { emptyField := bw6_761.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestScalarFieldGetLimbs(t *testing.T) { +func testScalarFieldGetLimbs(suite suite.Suite) { emptyField := bw6_761.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") } -func TestScalarFieldOne(t *testing.T) { +func testScalarFieldOne(suite suite.Suite) { var emptyField bw6_761.ScalarField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(SCALAR_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") } -func TestScalarFieldZero(t *testing.T) { +func testScalarFieldZero(suite suite.Suite) { var emptyField bw6_761.ScalarField emptyField.Zero() limbsZero := make([]uint32, SCALAR_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") } -func TestScalarFieldSize(t *testing.T) { +func testScalarFieldSize(suite suite.Suite) { var emptyField bw6_761.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestScalarFieldAsPointer(t *testing.T) { +func testScalarFieldAsPointer(suite suite.Suite) { var emptyField bw6_761.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestScalarFieldFromBytes(t *testing.T) { +func testScalarFieldFromBytes(suite suite.Suite) { var emptyField bw6_761.ScalarField bytes, expected := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestScalarFieldToBytes(t *testing.T) { +func testScalarFieldToBytes(suite suite.Suite) { var emptyField bw6_761.ScalarField expected, limbs := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestBw6_761GenerateScalars(t *testing.T) { +func testBw6_761GenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := bw6_761.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := bw6_761.ScalarField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestBw6_761MongtomeryConversion(t *testing.T) { +func testBw6_761MongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := bw6_761.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestBw6_761MongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[bw6_761.ScalarField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) bw6_761.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ScalarFieldTestSuite struct { + suite.Suite +} + +func (s *ScalarFieldTestSuite) TestScalarField() { + s.Run("TestScalarFieldFromLimbs", testWrapper(s.Suite, testScalarFieldFromLimbs)) + s.Run("TestScalarFieldGetLimbs", testWrapper(s.Suite, testScalarFieldGetLimbs)) + s.Run("TestScalarFieldOne", testWrapper(s.Suite, testScalarFieldOne)) + s.Run("TestScalarFieldZero", testWrapper(s.Suite, testScalarFieldZero)) + s.Run("TestScalarFieldSize", testWrapper(s.Suite, testScalarFieldSize)) + s.Run("TestScalarFieldAsPointer", testWrapper(s.Suite, testScalarFieldAsPointer)) + s.Run("TestScalarFieldFromBytes", testWrapper(s.Suite, testScalarFieldFromBytes)) + s.Run("TestScalarFieldToBytes", testWrapper(s.Suite, testScalarFieldToBytes)) + s.Run("TestBw6_761GenerateScalars", testWrapper(s.Suite, testBw6_761GenerateScalars)) + s.Run("TestBw6_761MongtomeryConversion", testWrapper(s.Suite, testBw6_761MongtomeryConversion)) +} + +func TestSuiteScalarField(t *testing.T) { + suite.Run(t, new(ScalarFieldTestSuite)) } diff --git a/wrappers/golang/curves/bw6761/tests/vec_ops_test.go b/wrappers/golang/curves/bw6761/tests/vec_ops_test.go index 3ed9a0cb7..9c7dfd0dc 100644 --- a/wrappers/golang/curves/bw6761/tests/vec_ops_test.go +++ b/wrappers/golang/curves/bw6761/tests/vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" bw6_761 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bw6761/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestBw6_761VecOps(t *testing.T) { +func testBw6_761VecOps(suite suite.Suite) { testSize := 1 << 14 a := bw6_761.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestBw6_761VecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestBw6_761Transpose(t *testing.T) { +func testBw6_761Transpose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestBw6_761Transpose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestBw6_761Transpose(t *testing.T) { output := make(core.HostSlice[bw6_761.ScalarField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type Bw6_761VecOpsTestSuite struct { + suite.Suite +} + +func (s *Bw6_761VecOpsTestSuite) TestBw6_761VecOps() { + s.Run("TestBw6_761VecOps", testWrapper(s.Suite, testBw6_761VecOps)) + s.Run("TestBw6_761Transpose", testWrapper(s.Suite, testBw6_761Transpose)) +} + +func TestSuiteBw6_761VecOps(t *testing.T) { + suite.Run(t, new(Bw6_761VecOpsTestSuite)) } diff --git a/wrappers/golang/curves/grumpkin/base_field.go b/wrappers/golang/curves/grumpkin/base_field.go index cacd27674..f9fd84646 100644 --- a/wrappers/golang/curves/grumpkin/base_field.go +++ b/wrappers/golang/curves/grumpkin/base_field.go @@ -53,6 +53,16 @@ func (f *BaseField) Zero() BaseField { return *f } +func (f *BaseField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *BaseField) One() BaseField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/grumpkin/curve.go b/wrappers/golang/curves/grumpkin/curve.go index 3e93617ad..8c9e4b600 100644 --- a/wrappers/golang/curves/grumpkin/curve.go +++ b/wrappers/golang/curves/grumpkin/curve.go @@ -96,6 +96,10 @@ func (a *Affine) Zero() Affine { return *a } +func (a *Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *Affine) FromLimbs(x, y []uint32) Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -106,9 +110,19 @@ func (a *Affine) FromLimbs(x, y []uint32) Affine { func (a Affine) ToProjective() Projective { var p Projective - cA := (*C.affine_t)(unsafe.Pointer(&a)) - cP := (*C.projective_t)(unsafe.Pointer(&p)) - C.grumpkin_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.affine_t)(unsafe.Pointer(&a)) + // cP := (*C.projective_t)(unsafe.Pointer(&p)) + // C.grumpkin_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/curves/grumpkin/scalar_field.go b/wrappers/golang/curves/grumpkin/scalar_field.go index 51bbabe72..8ef45b290 100644 --- a/wrappers/golang/curves/grumpkin/scalar_field.go +++ b/wrappers/golang/curves/grumpkin/scalar_field.go @@ -60,6 +60,16 @@ func (f *ScalarField) Zero() ScalarField { return *f } +func (f *ScalarField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ScalarField) One() ScalarField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/curves/grumpkin/tests/base_field_test.go b/wrappers/golang/curves/grumpkin/tests/base_field_test.go index c3d414e71..6276e6e67 100644 --- a/wrappers/golang/curves/grumpkin/tests/base_field_test.go +++ b/wrappers/golang/curves/grumpkin/tests/base_field_test.go @@ -5,85 +5,105 @@ import ( grumpkin "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/grumpkin" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( BASE_LIMBS = grumpkin.BASE_LIMBS ) -func TestBaseFieldFromLimbs(t *testing.T) { +func testBaseFieldFromLimbs(suite suite.Suite) { emptyField := grumpkin.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestBaseFieldGetLimbs(t *testing.T) { +func testBaseFieldGetLimbs(suite suite.Suite) { emptyField := grumpkin.BaseField{} randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the BaseField's limbs") } -func TestBaseFieldOne(t *testing.T) { +func testBaseFieldOne(suite suite.Suite) { var emptyField grumpkin.BaseField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(BASE_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "BaseField with limbs to field one did not work") } -func TestBaseFieldZero(t *testing.T) { +func testBaseFieldZero(suite suite.Suite) { var emptyField grumpkin.BaseField emptyField.Zero() limbsZero := make([]uint32, BASE_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "BaseField with limbs to field zero failed") } -func TestBaseFieldSize(t *testing.T) { +func testBaseFieldSize(suite suite.Suite) { var emptyField grumpkin.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestBaseFieldAsPointer(t *testing.T) { +func testBaseFieldAsPointer(suite suite.Suite) { var emptyField grumpkin.BaseField randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestBaseFieldFromBytes(t *testing.T) { +func testBaseFieldFromBytes(suite suite.Suite) { var emptyField grumpkin.BaseField bytes, expected := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestBaseFieldToBytes(t *testing.T) { +func testBaseFieldToBytes(suite suite.Suite) { var emptyField grumpkin.BaseField expected, limbs := test_helpers.GenerateBytesArray(int(BASE_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") +} + +type BaseFieldTestSuite struct { + suite.Suite +} + +func (s *BaseFieldTestSuite) TestBaseField() { + s.Run("TestBaseFieldFromLimbs", testWrapper(s.Suite, testBaseFieldFromLimbs)) + s.Run("TestBaseFieldGetLimbs", testWrapper(s.Suite, testBaseFieldGetLimbs)) + s.Run("TestBaseFieldOne", testWrapper(s.Suite, testBaseFieldOne)) + s.Run("TestBaseFieldZero", testWrapper(s.Suite, testBaseFieldZero)) + s.Run("TestBaseFieldSize", testWrapper(s.Suite, testBaseFieldSize)) + s.Run("TestBaseFieldAsPointer", testWrapper(s.Suite, testBaseFieldAsPointer)) + s.Run("TestBaseFieldFromBytes", testWrapper(s.Suite, testBaseFieldFromBytes)) + s.Run("TestBaseFieldToBytes", testWrapper(s.Suite, testBaseFieldToBytes)) + +} + +func TestSuiteBaseField(t *testing.T) { + suite.Run(t, new(BaseFieldTestSuite)) } diff --git a/wrappers/golang/curves/grumpkin/tests/curve_test.go b/wrappers/golang/curves/grumpkin/tests/curve_test.go index da68e141c..af4dbd00f 100644 --- a/wrappers/golang/curves/grumpkin/tests/curve_test.go +++ b/wrappers/golang/curves/grumpkin/tests/curve_test.go @@ -5,15 +5,15 @@ import ( grumpkin "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/grumpkin" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestAffineZero(t *testing.T) { +func testAffineZero(suite suite.Suite) { var fieldZero = grumpkin.BaseField{} var affineZero grumpkin.Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -21,22 +21,22 @@ func TestAffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func TestAffineFromLimbs(t *testing.T) { +func testAffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var affine grumpkin.Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func TestAffineToProjective(t *testing.T) { +func testAffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne grumpkin.BaseField @@ -49,31 +49,31 @@ func TestAffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func TestProjectiveZero(t *testing.T) { +func testProjectiveZero(suite suite.Suite) { var projectiveZero grumpkin.Projective projectiveZero.Zero() var fieldZero = grumpkin.BaseField{} var fieldOne grumpkin.BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var projective grumpkin.Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func TestProjectiveFromLimbs(t *testing.T) { +func testProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -81,12 +81,12 @@ func TestProjectiveFromLimbs(t *testing.T) { var projective grumpkin.Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func TestProjectiveFromAffine(t *testing.T) { +func testProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) var fieldOne grumpkin.BaseField @@ -100,5 +100,22 @@ func TestProjectiveFromAffine(t *testing.T) { var projectivePoint grumpkin.Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type CurveTestSuite struct { + suite.Suite +} + +func (s *CurveTestSuite) TestCurve() { + s.Run("TestAffineZero", testWrapper(s.Suite, testAffineZero)) + s.Run("TestAffineFromLimbs", testWrapper(s.Suite, testAffineFromLimbs)) + s.Run("TestAffineToProjective", testWrapper(s.Suite, testAffineToProjective)) + s.Run("TestProjectiveZero", testWrapper(s.Suite, testProjectiveZero)) + s.Run("TestProjectiveFromLimbs", testWrapper(s.Suite, testProjectiveFromLimbs)) + s.Run("TestProjectiveFromAffine", testWrapper(s.Suite, testProjectiveFromAffine)) +} + +func TestSuiteCurve(t *testing.T) { + suite.Run(t, new(CurveTestSuite)) } diff --git a/wrappers/golang/curves/grumpkin/tests/main_test.go b/wrappers/golang/curves/grumpkin/tests/main_test.go index 9e03ffb9c..d41eaacbe 100644 --- a/wrappers/golang/curves/grumpkin/tests/main_test.go +++ b/wrappers/golang/curves/grumpkin/tests/main_test.go @@ -1,6 +1,10 @@ package tests import ( + "fmt" + "github.com/stretchr/testify/suite" + "os" + "sync" "testing" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" @@ -10,7 +14,22 @@ const ( largestTestSize = 20 ) -var DEVICE runtime.Device +var ( + DEVICE runtime.Device + exitCode int +) + +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() @@ -19,11 +38,16 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() } + + os.Exit(exitCode) } diff --git a/wrappers/golang/curves/grumpkin/tests/msm_test.go b/wrappers/golang/curves/grumpkin/tests/msm_test.go index 9911f2fd8..796241762 100644 --- a/wrappers/golang/curves/grumpkin/tests/msm_test.go +++ b/wrappers/golang/curves/grumpkin/tests/msm_test.go @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" icicleGrumpkin "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/grumpkin" @@ -13,7 +13,7 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" ) -func TestMSM(t *testing.T) { +func testMSM(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -27,11 +27,11 @@ func TestMSM(t *testing.T) { var p icicleGrumpkin.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleGrumpkin.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -42,7 +42,7 @@ func TestMSM(t *testing.T) { } } -func TestMSMBatch(t *testing.T) { +func testMSMBatch(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -55,10 +55,10 @@ func TestMSMBatch(t *testing.T) { var p icicleGrumpkin.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleGrumpkin.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -67,9 +67,9 @@ func TestMSMBatch(t *testing.T) { } } -func TestPrecomputePoints(t *testing.T) { +func testPrecomputePoints(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 @@ -86,20 +86,20 @@ func TestPrecomputePoints(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleGrumpkin.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleGrumpkin.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -109,7 +109,7 @@ func TestPrecomputePoints(t *testing.T) { } } -func TestPrecomputePointsSharedBases(t *testing.T) { +func testPrecomputePointsSharedBases(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -125,18 +125,18 @@ func TestPrecomputePointsSharedBases(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = msm.PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p icicleGrumpkin.Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[icicleGrumpkin.Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -146,7 +146,7 @@ func TestPrecomputePointsSharedBases(t *testing.T) { } } -func TestMSMSkewedDistribution(t *testing.T) { +func testMSMSkewedDistribution(suite suite.Suite) { cfg := msm.GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -165,10 +165,10 @@ func TestMSMSkewedDistribution(t *testing.T) { var p icicleGrumpkin.Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleGrumpkin.Projective], 1) outHost.CopyFromDevice(&out) out.Free() @@ -176,7 +176,7 @@ func TestMSMSkewedDistribution(t *testing.T) { } } -func TestMSMMultiDevice(t *testing.T) { +func testMSMMultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -200,11 +200,11 @@ func TestMSMMultiDevice(t *testing.T) { var p icicleGrumpkin.Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = msm.Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[icicleGrumpkin.Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -217,3 +217,20 @@ func TestMSMMultiDevice(t *testing.T) { } wg.Wait() } + +type MSMTestSuite struct { + suite.Suite +} + +func (s *MSMTestSuite) TestMSM() { + s.Run("TestMSM", testWrapper(s.Suite, testMSM)) + s.Run("TestMSMBatch", testWrapper(s.Suite, testMSMBatch)) + s.Run("TestPrecomputePoints", testWrapper(s.Suite, testPrecomputePoints)) + s.Run("TestPrecomputePointsSharedBases", testWrapper(s.Suite, testPrecomputePointsSharedBases)) + s.Run("TestMSMSkewedDistribution", testWrapper(s.Suite, testMSMSkewedDistribution)) + s.Run("TestMSMMultiDevice", testWrapper(s.Suite, testMSMMultiDevice)) +} + +func TestSuiteMSM(t *testing.T) { + suite.Run(t, new(MSMTestSuite)) +} diff --git a/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go b/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go index 1fddcb67e..2d02387e2 100644 --- a/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go +++ b/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" grumpkin "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/grumpkin" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( SCALAR_LIMBS = grumpkin.SCALAR_LIMBS ) -func TestScalarFieldFromLimbs(t *testing.T) { +func testScalarFieldFromLimbs(suite suite.Suite) { emptyField := grumpkin.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestScalarFieldGetLimbs(t *testing.T) { +func testScalarFieldGetLimbs(suite suite.Suite) { emptyField := grumpkin.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") } -func TestScalarFieldOne(t *testing.T) { +func testScalarFieldOne(suite suite.Suite) { var emptyField grumpkin.ScalarField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(SCALAR_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") } -func TestScalarFieldZero(t *testing.T) { +func testScalarFieldZero(suite suite.Suite) { var emptyField grumpkin.ScalarField emptyField.Zero() limbsZero := make([]uint32, SCALAR_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") } -func TestScalarFieldSize(t *testing.T) { +func testScalarFieldSize(suite suite.Suite) { var emptyField grumpkin.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestScalarFieldAsPointer(t *testing.T) { +func testScalarFieldAsPointer(suite suite.Suite) { var emptyField grumpkin.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestScalarFieldFromBytes(t *testing.T) { +func testScalarFieldFromBytes(suite suite.Suite) { var emptyField grumpkin.ScalarField bytes, expected := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestScalarFieldToBytes(t *testing.T) { +func testScalarFieldToBytes(suite suite.Suite) { var emptyField grumpkin.ScalarField expected, limbs := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestGrumpkinGenerateScalars(t *testing.T) { +func testGrumpkinGenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := grumpkin.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := grumpkin.ScalarField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestGrumpkinMongtomeryConversion(t *testing.T) { +func testGrumpkinMongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := grumpkin.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestGrumpkinMongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[grumpkin.ScalarField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) grumpkin.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ScalarFieldTestSuite struct { + suite.Suite +} + +func (s *ScalarFieldTestSuite) TestScalarField() { + s.Run("TestScalarFieldFromLimbs", testWrapper(s.Suite, testScalarFieldFromLimbs)) + s.Run("TestScalarFieldGetLimbs", testWrapper(s.Suite, testScalarFieldGetLimbs)) + s.Run("TestScalarFieldOne", testWrapper(s.Suite, testScalarFieldOne)) + s.Run("TestScalarFieldZero", testWrapper(s.Suite, testScalarFieldZero)) + s.Run("TestScalarFieldSize", testWrapper(s.Suite, testScalarFieldSize)) + s.Run("TestScalarFieldAsPointer", testWrapper(s.Suite, testScalarFieldAsPointer)) + s.Run("TestScalarFieldFromBytes", testWrapper(s.Suite, testScalarFieldFromBytes)) + s.Run("TestScalarFieldToBytes", testWrapper(s.Suite, testScalarFieldToBytes)) + s.Run("TestGrumpkinGenerateScalars", testWrapper(s.Suite, testGrumpkinGenerateScalars)) + s.Run("TestGrumpkinMongtomeryConversion", testWrapper(s.Suite, testGrumpkinMongtomeryConversion)) +} + +func TestSuiteScalarField(t *testing.T) { + suite.Run(t, new(ScalarFieldTestSuite)) } diff --git a/wrappers/golang/curves/grumpkin/tests/vec_ops_test.go b/wrappers/golang/curves/grumpkin/tests/vec_ops_test.go index 35a99a9e1..c5da4be60 100644 --- a/wrappers/golang/curves/grumpkin/tests/vec_ops_test.go +++ b/wrappers/golang/curves/grumpkin/tests/vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" grumpkin "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/grumpkin" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/grumpkin/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestGrumpkinVecOps(t *testing.T) { +func testGrumpkinVecOps(suite suite.Suite) { testSize := 1 << 14 a := grumpkin.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestGrumpkinVecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestGrumpkinTranspose(t *testing.T) { +func testGrumpkinTranspose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestGrumpkinTranspose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestGrumpkinTranspose(t *testing.T) { output := make(core.HostSlice[grumpkin.ScalarField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type GrumpkinVecOpsTestSuite struct { + suite.Suite +} + +func (s *GrumpkinVecOpsTestSuite) TestGrumpkinVecOps() { + s.Run("TestGrumpkinVecOps", testWrapper(s.Suite, testGrumpkinVecOps)) + s.Run("TestGrumpkinTranspose", testWrapper(s.Suite, testGrumpkinTranspose)) +} + +func TestSuiteGrumpkinVecOps(t *testing.T) { + suite.Run(t, new(GrumpkinVecOpsTestSuite)) } diff --git a/wrappers/golang/fields/babybear/extension/extension_field.go b/wrappers/golang/fields/babybear/extension/extension_field.go index 512db3910..3d0a2ae33 100644 --- a/wrappers/golang/fields/babybear/extension/extension_field.go +++ b/wrappers/golang/fields/babybear/extension/extension_field.go @@ -60,6 +60,16 @@ func (f *ExtensionField) Zero() ExtensionField { return *f } +func (f *ExtensionField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ExtensionField) One() ExtensionField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/fields/babybear/scalar_field.go b/wrappers/golang/fields/babybear/scalar_field.go index a5553d999..78d798aae 100644 --- a/wrappers/golang/fields/babybear/scalar_field.go +++ b/wrappers/golang/fields/babybear/scalar_field.go @@ -60,6 +60,16 @@ func (f *ScalarField) Zero() ScalarField { return *f } +func (f *ScalarField) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *ScalarField) One() ScalarField { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/fields/babybear/tests/extension_field_test.go b/wrappers/golang/fields/babybear/tests/extension_field_test.go index 7178d257e..ff899bbed 100644 --- a/wrappers/golang/fields/babybear/tests/extension_field_test.go +++ b/wrappers/golang/fields/babybear/tests/extension_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" babybear_extension "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/extension" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( EXTENSION_LIMBS = babybear_extension.EXTENSION_LIMBS ) -func TestExtensionFieldFromLimbs(t *testing.T) { +func testExtensionFieldFromLimbs(suite suite.Suite) { emptyField := babybear_extension.ExtensionField{} randLimbs := test_helpers.GenerateRandomLimb(int(EXTENSION_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ExtensionField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ExtensionField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestExtensionFieldGetLimbs(t *testing.T) { +func testExtensionFieldGetLimbs(suite suite.Suite) { emptyField := babybear_extension.ExtensionField{} randLimbs := test_helpers.GenerateRandomLimb(int(EXTENSION_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ExtensionField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ExtensionField's limbs") } -func TestExtensionFieldOne(t *testing.T) { +func testExtensionFieldOne(suite suite.Suite) { var emptyField babybear_extension.ExtensionField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(EXTENSION_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(EXTENSION_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ExtensionField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ExtensionField with limbs to field one did not work") } -func TestExtensionFieldZero(t *testing.T) { +func testExtensionFieldZero(suite suite.Suite) { var emptyField babybear_extension.ExtensionField emptyField.Zero() limbsZero := make([]uint32, EXTENSION_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(EXTENSION_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ExtensionField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ExtensionField with limbs to field zero failed") } -func TestExtensionFieldSize(t *testing.T) { +func testExtensionFieldSize(suite suite.Suite) { var emptyField babybear_extension.ExtensionField randLimbs := test_helpers.GenerateRandomLimb(int(EXTENSION_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestExtensionFieldAsPointer(t *testing.T) { +func testExtensionFieldAsPointer(suite suite.Suite) { var emptyField babybear_extension.ExtensionField randLimbs := test_helpers.GenerateRandomLimb(int(EXTENSION_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestExtensionFieldFromBytes(t *testing.T) { +func testExtensionFieldFromBytes(suite suite.Suite) { var emptyField babybear_extension.ExtensionField bytes, expected := test_helpers.GenerateBytesArray(int(EXTENSION_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestExtensionFieldToBytes(t *testing.T) { +func testExtensionFieldToBytes(suite suite.Suite) { var emptyField babybear_extension.ExtensionField expected, limbs := test_helpers.GenerateBytesArray(int(EXTENSION_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestBabybear_extensionGenerateScalars(t *testing.T) { +func testBabybear_extensionGenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := babybear_extension.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := babybear_extension.ExtensionField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestBabybear_extensionMongtomeryConversion(t *testing.T) { +func testBabybear_extensionMongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := babybear_extension.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestBabybear_extensionMongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[babybear_extension.ExtensionField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) babybear_extension.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ExtensionFieldTestSuite struct { + suite.Suite +} + +func (s *ExtensionFieldTestSuite) TestExtensionField() { + s.Run("TestExtensionFieldFromLimbs", testWrapper(s.Suite, testExtensionFieldFromLimbs)) + s.Run("TestExtensionFieldGetLimbs", testWrapper(s.Suite, testExtensionFieldGetLimbs)) + s.Run("TestExtensionFieldOne", testWrapper(s.Suite, testExtensionFieldOne)) + s.Run("TestExtensionFieldZero", testWrapper(s.Suite, testExtensionFieldZero)) + s.Run("TestExtensionFieldSize", testWrapper(s.Suite, testExtensionFieldSize)) + s.Run("TestExtensionFieldAsPointer", testWrapper(s.Suite, testExtensionFieldAsPointer)) + s.Run("TestExtensionFieldFromBytes", testWrapper(s.Suite, testExtensionFieldFromBytes)) + s.Run("TestExtensionFieldToBytes", testWrapper(s.Suite, testExtensionFieldToBytes)) + s.Run("TestBabybear_extensionGenerateScalars", testWrapper(s.Suite, testBabybear_extensionGenerateScalars)) + s.Run("TestBabybear_extensionMongtomeryConversion", testWrapper(s.Suite, testBabybear_extensionMongtomeryConversion)) +} + +func TestSuiteExtensionField(t *testing.T) { + suite.Run(t, new(ExtensionFieldTestSuite)) } diff --git a/wrappers/golang/fields/babybear/tests/extension_vec_ops_test.go b/wrappers/golang/fields/babybear/tests/extension_vec_ops_test.go index 7d0e038a5..f7523fe51 100644 --- a/wrappers/golang/fields/babybear/tests/extension_vec_ops_test.go +++ b/wrappers/golang/fields/babybear/tests/extension_vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" babybear_extension "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/extension" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/extension/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestBabybear_extensionVecOps(t *testing.T) { +func testBabybear_extensionVecOps(suite suite.Suite) { testSize := 1 << 14 a := babybear_extension.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestBabybear_extensionVecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestBabybear_extensionTranspose(t *testing.T) { +func testBabybear_extensionTranspose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestBabybear_extensionTranspose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestBabybear_extensionTranspose(t *testing.T) { output := make(core.HostSlice[babybear_extension.ExtensionField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type Babybear_extensionVecOpsTestSuite struct { + suite.Suite +} + +func (s *Babybear_extensionVecOpsTestSuite) TestBabybear_extensionVecOps() { + s.Run("TestBabybear_extensionVecOps", testWrapper(s.Suite, testBabybear_extensionVecOps)) + s.Run("TestBabybear_extensionTranspose", testWrapper(s.Suite, testBabybear_extensionTranspose)) +} + +func TestSuiteBabybear_extensionVecOps(t *testing.T) { + suite.Run(t, new(Babybear_extensionVecOpsTestSuite)) } diff --git a/wrappers/golang/fields/babybear/tests/main_test.go b/wrappers/golang/fields/babybear/tests/main_test.go index 556b7462f..91e934627 100644 --- a/wrappers/golang/fields/babybear/tests/main_test.go +++ b/wrappers/golang/fields/babybear/tests/main_test.go @@ -2,19 +2,24 @@ package tests import ( "fmt" - "testing" - "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" babybear "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear" ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + "github.com/stretchr/testify/suite" + "os" + "sync" + "testing" ) const ( largestTestSize = 20 ) -var DEVICE runtime.Device +var ( + DEVICE runtime.Device + exitCode int +) func initDomain(cfg core.NTTInitDomainConfig) runtime.EIcicleError { rouIcicle := babybear.ScalarField{} @@ -23,6 +28,18 @@ func initDomain(cfg core.NTTInitDomainConfig) runtime.EIcicleError { return e } +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} + func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() devices, e := runtime.GetRegisteredDevices() @@ -30,6 +47,7 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) @@ -44,8 +62,10 @@ func TestMain(m *testing.M) { } } + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() // release domain e = ntt.ReleaseDomain() @@ -57,4 +77,6 @@ func TestMain(m *testing.M) { } } } + + os.Exit(exitCode) } diff --git a/wrappers/golang/fields/babybear/tests/ntt_no_domain_test.go b/wrappers/golang/fields/babybear/tests/ntt_no_domain_test.go index 06c9903d4..acde109e1 100644 --- a/wrappers/golang/fields/babybear/tests/ntt_no_domain_test.go +++ b/wrappers/golang/fields/babybear/tests/ntt_no_domain_test.go @@ -1,6 +1,7 @@ package tests import ( + "github.com/stretchr/testify/suite" "testing" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" @@ -9,7 +10,7 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" ) -func TestNttNoDomain(t *testing.T) { +func testNttNoDomain(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := babybear_extension.GenerateScalars(1 << largestTestSize) @@ -29,7 +30,7 @@ func TestNttNoDomain(t *testing.T) { } } -func TestNttDeviceAsyncNoDomain(t *testing.T) { +func testNttDeviceAsyncNoDomain(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := babybear_extension.GenerateScalars(1 << largestTestSize) @@ -64,7 +65,7 @@ func TestNttDeviceAsyncNoDomain(t *testing.T) { } } -func TestNttBatchNoDomain(t *testing.T) { +func testNttBatchNoDomain(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 12 largestBatchSize := 100 @@ -87,3 +88,17 @@ func TestNttBatchNoDomain(t *testing.T) { } } } + +type NTTNoDomainTestSuite struct { + suite.Suite +} + +func (s *NTTNoDomainTestSuite) TestNTTNoDomain() { + s.Run("TestNTTNoDomain", testWrapper(s.Suite, testNttNoDomain)) + s.Run("TestNttDeviceAsyncNoDomain", testWrapper(s.Suite, testNttDeviceAsyncNoDomain)) + s.Run("TestNttBatchNoDomain", testWrapper(s.Suite, testNttBatchNoDomain)) +} + +func TestSuiteNTTNoDomain(t *testing.T) { + suite.Run(t, new(NTTNoDomainTestSuite)) +} diff --git a/wrappers/golang/fields/babybear/tests/ntt_test.go b/wrappers/golang/fields/babybear/tests/ntt_test.go index 641c2da65..db4b9b287 100644 --- a/wrappers/golang/fields/babybear/tests/ntt_test.go +++ b/wrappers/golang/fields/babybear/tests/ntt_test.go @@ -8,26 +8,20 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestNTTGetDefaultConfig(t *testing.T) { +func testNTTGetDefaultConfig(suite suite.Suite) { actual := ntt.GetDefaultNttConfig() expected := test_helpers.GenerateLimbOne(int(babybear.SCALAR_LIMBS)) - assert.Equal(t, expected, actual.CosetGen[:]) + suite.Equal(expected, actual.CosetGen[:]) cosetGenField := babybear.ScalarField{} cosetGenField.One() - assert.ElementsMatch(t, cosetGenField.GetLimbs(), actual.CosetGen) + suite.ElementsMatch(cosetGenField.GetLimbs(), actual.CosetGen) } -func TestInitDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - cfg := core.GetDefaultNTTInitDomainConfig() - assert.NotPanics(t, func() { initDomain(cfg) }) -} - -func TestNtt(t *testing.T) { +func testNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := babybear.GenerateScalars(1 << largestTestSize) @@ -48,7 +42,7 @@ func TestNtt(t *testing.T) { } } -func TestNttDeviceAsync(t *testing.T) { +func testNttDeviceAsync(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := babybear.GenerateScalars(1 << largestTestSize) @@ -83,7 +77,7 @@ func TestNttDeviceAsync(t *testing.T) { } } -func TestNttBatch(t *testing.T) { +func testNttBatch(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 10 largestBatchSize := 20 @@ -108,8 +102,17 @@ func TestNttBatch(t *testing.T) { } } -func TestReleaseDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - e := ntt.ReleaseDomain() - assert.Equal(t, runtime.Success, e, "ReleasDomain failed") +type NTTTestSuite struct { + suite.Suite +} + +func (s *NTTTestSuite) TestNTT() { + s.Run("TestNTTGetDefaultConfig", testWrapper(s.Suite, testNTTGetDefaultConfig)) + s.Run("TestNTT", testWrapper(s.Suite, testNtt)) + s.Run("TestNttDeviceAsync", testWrapper(s.Suite, testNttDeviceAsync)) + s.Run("TestNttBatch", testWrapper(s.Suite, testNttBatch)) +} + +func TestSuiteNTT(t *testing.T) { + suite.Run(t, new(NTTTestSuite)) } diff --git a/wrappers/golang/fields/babybear/tests/polynomial_test.go b/wrappers/golang/fields/babybear/tests/polynomial_test.go index 1e1ed1d09..d2b76d625 100644 --- a/wrappers/golang/fields/babybear/tests/polynomial_test.go +++ b/wrappers/golang/fields/babybear/tests/polynomial_test.go @@ -6,10 +6,9 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" babybear "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear" - // "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/polynomial" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) var one, two, three, four, five babybear.ScalarField @@ -41,7 +40,7 @@ func vecOp(a, b babybear.ScalarField, op core.VecOps) babybear.ScalarField { return out[0] } -func TestPolyCreateFromCoefficients(t *testing.T) { +func testPolyCreateFromCoefficients(suite suite.Suite) { scalars := babybear.GenerateScalars(33) var uniPoly polynomial.DensePolynomial @@ -49,7 +48,7 @@ func TestPolyCreateFromCoefficients(t *testing.T) { poly.Print() } -func TestPolyEval(t *testing.T) { +func testPolyEval(suite suite.Suite) { // testing correct evaluation of f(8) for f(x)=4x^2+2x+5 coeffs := core.HostSliceFromElements([]babybear.ScalarField{five, two, four}) var f polynomial.DensePolynomial @@ -62,10 +61,10 @@ func TestPolyEval(t *testing.T) { evals := make(core.HostSlice[babybear.ScalarField], 1) fEvaled := f.EvalOnDomain(domains, evals) var expected babybear.ScalarField - assert.Equal(t, expected.FromUint32(277), fEvaled.(core.HostSlice[babybear.ScalarField])[0]) + suite.Equal(expected.FromUint32(277), fEvaled.(core.HostSlice[babybear.ScalarField])[0]) } -func TestPolyClone(t *testing.T) { +func testPolyClone(suite suite.Suite) { f := randomPoly(8) x := rand() fx := f.Eval(x) @@ -76,11 +75,11 @@ func TestPolyClone(t *testing.T) { gx := g.Eval(x) fgx := fg.Eval(x) - assert.Equal(t, fx, gx) - assert.Equal(t, vecOp(fx, gx, core.Add), fgx) + suite.Equal(fx, gx) + suite.Equal(vecOp(fx, gx, core.Add), fgx) } -func TestPolyAddSubMul(t *testing.T) { +func testPolyAddSubMul(suite suite.Suite) { testSize := 1 << 10 f := randomPoly(testSize) g := randomPoly(testSize) @@ -91,26 +90,26 @@ func TestPolyAddSubMul(t *testing.T) { polyAdd := f.Add(&g) fxAddgx := vecOp(fx, gx, core.Add) - assert.Equal(t, polyAdd.Eval(x), fxAddgx) + suite.Equal(polyAdd.Eval(x), fxAddgx) polySub := f.Subtract(&g) fxSubgx := vecOp(fx, gx, core.Sub) - assert.Equal(t, polySub.Eval(x), fxSubgx) + suite.Equal(polySub.Eval(x), fxSubgx) polyMul := f.Multiply(&g) fxMulgx := vecOp(fx, gx, core.Mul) - assert.Equal(t, polyMul.Eval(x), fxMulgx) + suite.Equal(polyMul.Eval(x), fxMulgx) s1 := rand() polMulS1 := f.MultiplyByScalar(s1) - assert.Equal(t, polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) + suite.Equal(polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) s2 := rand() polMulS2 := f.MultiplyByScalar(s2) - assert.Equal(t, polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) + suite.Equal(polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) } -func TestPolyMonomials(t *testing.T) { +func testPolyMonomials(suite suite.Suite) { var zero babybear.ScalarField var f polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements([]babybear.ScalarField{one, zero, two})) @@ -119,20 +118,20 @@ func TestPolyMonomials(t *testing.T) { fx := f.Eval(x) f.AddMonomial(three, 1) fxAdded := f.Eval(x) - assert.Equal(t, fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) + suite.Equal(fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) f.SubMonomial(one, 0) fxSub := f.Eval(x) - assert.Equal(t, fxSub, vecOp(fxAdded, one, core.Sub)) + suite.Equal(fxSub, vecOp(fxAdded, one, core.Sub)) } -func TestPolyReadCoeffs(t *testing.T) { +func testPolyReadCoeffs(suite suite.Suite) { var f polynomial.DensePolynomial coeffs := core.HostSliceFromElements([]babybear.ScalarField{one, two, three, four}) f.CreateFromCoeffecitients(coeffs) coeffsCopied := make(core.HostSlice[babybear.ScalarField], coeffs.Len()) _, _ = f.CopyCoeffsRange(0, coeffs.Len()-1, coeffsCopied) - assert.ElementsMatch(t, coeffs, coeffsCopied) + suite.ElementsMatch(coeffs, coeffsCopied) var coeffsDevice core.DeviceSlice coeffsDevice.Malloc(one.Size(), coeffs.Len()) @@ -140,16 +139,16 @@ func TestPolyReadCoeffs(t *testing.T) { coeffsHost := make(core.HostSlice[babybear.ScalarField], coeffs.Len()) coeffsHost.CopyFromDevice(&coeffsDevice) - assert.ElementsMatch(t, coeffs, coeffsHost) + suite.ElementsMatch(coeffs, coeffsHost) } -func TestPolyOddEvenSlicing(t *testing.T) { +func testPolyOddEvenSlicing(suite suite.Suite) { size := 1<<10 - 3 f := randomPoly(size) even := f.Even() odd := f.Odd() - assert.Equal(t, f.Degree(), even.Degree()+odd.Degree()+1) + suite.Equal(f.Degree(), even.Degree()+odd.Degree()+1) x := rand() var evenExpected, oddExpected babybear.ScalarField @@ -164,13 +163,13 @@ func TestPolyOddEvenSlicing(t *testing.T) { } evenEvaled := even.Eval(x) - assert.Equal(t, evenExpected, evenEvaled) + suite.Equal(evenExpected, evenEvaled) oddEvaled := odd.Eval(x) - assert.Equal(t, oddExpected, oddEvaled) + suite.Equal(oddExpected, oddEvaled) } -func TestPolynomialDivision(t *testing.T) { +func testPolynomialDivision(suite suite.Suite) { // divide f(x)/g(x), compute q(x), r(x) and check f(x)=q(x)*g(x)+r(x) var f, g polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements(babybear.GenerateScalars(1 << 4))) @@ -184,10 +183,10 @@ func TestPolynomialDivision(t *testing.T) { x := babybear.GenerateScalars(1)[0] fEval := f.Eval(x) fReconEval := fRecon.Eval(x) - assert.Equal(t, fEval, fReconEval) + suite.Equal(fEval, fReconEval) } -func TestDivideByVanishing(t *testing.T) { +func testDivideByVanishing(suite suite.Suite) { // poly of x^4-1 vanishes ad 4th rou var zero babybear.ScalarField minus_one := vecOp(zero, one, core.Sub) @@ -200,31 +199,51 @@ func TestDivideByVanishing(t *testing.T) { fv := f.Multiply(&v) fDegree := f.Degree() fvDegree := fv.Degree() - assert.Equal(t, fDegree+4, fvDegree) + suite.Equal(fDegree+4, fvDegree) fReconstructed := fv.DivideByVanishing(4) - assert.Equal(t, fDegree, fReconstructed.Degree()) + suite.Equal(fDegree, fReconstructed.Degree()) x := rand() - assert.Equal(t, f.Eval(x), fReconstructed.Eval(x)) + suite.Equal(f.Eval(x), fReconstructed.Eval(x)) } -// func TestPolySlice(t *testing.T) { +// func TestPolySlice(suite suite.Suite) { // size := 4 // coeffs := babybear.GenerateScalars(size) // var f DensePolynomial // f.CreateFromCoeffecitients(coeffs) // fSlice := f.AsSlice() -// assert.True(t, fSlice.IsOnDevice()) -// assert.Equal(t, size, fSlice.Len()) +// suite.True(fSlice.IsOnDevice()) +// suite.Equal(size, fSlice.Len()) // hostSlice := make(core.HostSlice[babybear.ScalarField], size) // hostSlice.CopyFromDevice(fSlice) -// assert.Equal(t, coeffs, hostSlice) +// suite.Equal(coeffs, hostSlice) // cfg := ntt.GetDefaultNttConfig() // res := make(core.HostSlice[babybear.ScalarField], size) // ntt.Ntt(fSlice, core.KForward, cfg, res) -// assert.Equal(t, f.Eval(one), res[0]) +// suite.Equal(f.Eval(one), res[0]) // } + +type PolynomialTestSuite struct { + suite.Suite +} + +func (s *PolynomialTestSuite) TestPolynomial() { + s.Run("TestPolyCreateFromCoefficients", testWrapper(s.Suite, testPolyCreateFromCoefficients)) + s.Run("TestPolyEval", testWrapper(s.Suite, testPolyEval)) + s.Run("TestPolyClone", testWrapper(s.Suite, testPolyClone)) + s.Run("TestPolyAddSubMul", testWrapper(s.Suite, testPolyAddSubMul)) + s.Run("TestPolyMonomials", testWrapper(s.Suite, testPolyMonomials)) + s.Run("TestPolyReadCoeffs", testWrapper(s.Suite, testPolyReadCoeffs)) + s.Run("TestPolyOddEvenSlicing", testWrapper(s.Suite, testPolyOddEvenSlicing)) + s.Run("TestPolynomialDivision", testWrapper(s.Suite, testPolynomialDivision)) + s.Run("TestDivideByVanishing", testWrapper(s.Suite, testDivideByVanishing)) +} + +func TestSuitePolynomial(t *testing.T) { + suite.Run(t, new(PolynomialTestSuite)) +} diff --git a/wrappers/golang/fields/babybear/tests/scalar_field_test.go b/wrappers/golang/fields/babybear/tests/scalar_field_test.go index a9e0ec1c1..88b53007e 100644 --- a/wrappers/golang/fields/babybear/tests/scalar_field_test.go +++ b/wrappers/golang/fields/babybear/tests/scalar_field_test.go @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" babybear "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) const ( SCALAR_LIMBS = babybear.SCALAR_LIMBS ) -func TestScalarFieldFromLimbs(t *testing.T) { +func testScalarFieldFromLimbs(suite suite.Suite) { emptyField := babybear.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func TestScalarFieldGetLimbs(t *testing.T) { +func testScalarFieldGetLimbs(suite suite.Suite) { emptyField := babybear.ScalarField{} randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the ScalarField's limbs") } -func TestScalarFieldOne(t *testing.T) { +func testScalarFieldOne(suite suite.Suite) { var emptyField babybear.ScalarField emptyField.One() limbOne := test_helpers.GenerateLimbOne(int(SCALAR_LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "ScalarField with limbs to field one did not work") } -func TestScalarFieldZero(t *testing.T) { +func testScalarFieldZero(suite suite.Suite) { var emptyField babybear.ScalarField emptyField.Zero() limbsZero := make([]uint32, SCALAR_LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "ScalarField with limbs to field zero failed") } -func TestScalarFieldSize(t *testing.T) { +func testScalarFieldSize(suite suite.Suite) { var emptyField babybear.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func TestScalarFieldAsPointer(t *testing.T) { +func testScalarFieldAsPointer(suite suite.Suite) { var emptyField babybear.ScalarField randLimbs := test_helpers.GenerateRandomLimb(int(SCALAR_LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func TestScalarFieldFromBytes(t *testing.T) { +func testScalarFieldFromBytes(suite suite.Suite) { var emptyField babybear.ScalarField bytes, expected := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func TestScalarFieldToBytes(t *testing.T) { +func testScalarFieldToBytes(suite suite.Suite) { var emptyField babybear.ScalarField expected, limbs := test_helpers.GenerateBytesArray(int(SCALAR_LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } -func TestBabybearGenerateScalars(t *testing.T) { +func testBabybearGenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := babybear.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := babybear.ScalarField{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func TestBabybearMongtomeryConversion(t *testing.T) { +func testBabybearMongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := babybear.GenerateScalars(size) @@ -112,10 +112,31 @@ func TestBabybearMongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[babybear.ScalarField], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) babybear.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) +} + +type ScalarFieldTestSuite struct { + suite.Suite +} + +func (s *ScalarFieldTestSuite) TestScalarField() { + s.Run("TestScalarFieldFromLimbs", testWrapper(s.Suite, testScalarFieldFromLimbs)) + s.Run("TestScalarFieldGetLimbs", testWrapper(s.Suite, testScalarFieldGetLimbs)) + s.Run("TestScalarFieldOne", testWrapper(s.Suite, testScalarFieldOne)) + s.Run("TestScalarFieldZero", testWrapper(s.Suite, testScalarFieldZero)) + s.Run("TestScalarFieldSize", testWrapper(s.Suite, testScalarFieldSize)) + s.Run("TestScalarFieldAsPointer", testWrapper(s.Suite, testScalarFieldAsPointer)) + s.Run("TestScalarFieldFromBytes", testWrapper(s.Suite, testScalarFieldFromBytes)) + s.Run("TestScalarFieldToBytes", testWrapper(s.Suite, testScalarFieldToBytes)) + s.Run("TestBabybearGenerateScalars", testWrapper(s.Suite, testBabybearGenerateScalars)) + s.Run("TestBabybearMongtomeryConversion", testWrapper(s.Suite, testBabybearMongtomeryConversion)) +} + +func TestSuiteScalarField(t *testing.T) { + suite.Run(t, new(ScalarFieldTestSuite)) } diff --git a/wrappers/golang/fields/babybear/tests/vec_ops_test.go b/wrappers/golang/fields/babybear/tests/vec_ops_test.go index 772b964bd..e2aea1e65 100644 --- a/wrappers/golang/fields/babybear/tests/vec_ops_test.go +++ b/wrappers/golang/fields/babybear/tests/vec_ops_test.go @@ -6,10 +6,10 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" babybear "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/fields/babybear/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestBabybearVecOps(t *testing.T) { +func testBabybearVecOps(suite suite.Suite) { testSize := 1 << 14 a := babybear.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func TestBabybearVecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func TestBabybearTranspose(t *testing.T) { +func testBabybearTranspose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,7 +48,7 @@ func TestBabybearTranspose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -61,5 +61,18 @@ func TestBabybearTranspose(t *testing.T) { output := make(core.HostSlice[babybear.ScalarField], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) +} + +type BabybearVecOpsTestSuite struct { + suite.Suite +} + +func (s *BabybearVecOpsTestSuite) TestBabybearVecOps() { + s.Run("TestBabybearVecOps", testWrapper(s.Suite, testBabybearVecOps)) + s.Run("TestBabybearTranspose", testWrapper(s.Suite, testBabybearTranspose)) +} + +func TestSuiteBabybearVecOps(t *testing.T) { + suite.Run(t, new(BabybearVecOpsTestSuite)) } diff --git a/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl b/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl index eac33fbc6..9db359aae 100644 --- a/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl +++ b/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl @@ -97,6 +97,10 @@ func (a *{{.CurvePrefix}}Affine) Zero() {{.CurvePrefix}}Affine { return *a } +func (a *{{.CurvePrefix}}Affine) IsZero() bool { + return a.X.IsZero() && a.Y.IsZero() +} + func (a *{{.CurvePrefix}}Affine) FromLimbs(x, y []uint32) {{.CurvePrefix}}Affine { a.X.FromLimbs(x) a.Y.FromLimbs(y) @@ -108,9 +112,19 @@ func (a *{{.CurvePrefix}}Affine) FromLimbs(x, y []uint32) {{.CurvePrefix}}Affine func (a {{.CurvePrefix}}Affine) ToProjective() {{.CurvePrefix}}Projective { var p {{.CurvePrefix}}Projective - cA := (*C.{{toCName .CurvePrefix}}affine_t)(unsafe.Pointer(&a)) - cP := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(&p)) - C.{{.Curve}}{{toCNameBackwards .CurvePrefix}}_from_affine(cA, cP) + // TODO - Figure out why this sometimes returns an empty projective point, i.e. {x:0, y:0, z:0} + // cA := (*C.{{toCName .CurvePrefix}}affine_t)(unsafe.Pointer(&a)) + // cP := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(&p)) + // C.{{.Curve}}{{toCNameBackwards .CurvePrefix}}_from_affine(cA, cP) + + if a.IsZero() { + p.Zero() + } else { + p.X = a.X + p.Y = a.Y + p.Z.One() + } + return p } diff --git a/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl b/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl index 1ad840eea..b2514dd60 100644 --- a/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl +++ b/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl @@ -5,15 +5,15 @@ import ( {{if ne .CurvePrefix "G2"}}{{.Curve}}{{end}} "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func Test{{.CurvePrefix}}AffineZero(t *testing.T) { +func test{{.CurvePrefix}}AffineZero(suite suite.Suite) { var fieldZero = {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}BaseField{} var affineZero {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}Affine - assert.Equal(t, affineZero.X, fieldZero) - assert.Equal(t, affineZero.Y, fieldZero) + suite.Equal(affineZero.X, fieldZero) + suite.Equal(affineZero.Y, fieldZero) x := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) y := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) @@ -21,22 +21,22 @@ func Test{{.CurvePrefix}}AffineZero(t *testing.T) { affine.FromLimbs(x, y) affine.Zero() - assert.Equal(t, affine.X, fieldZero) - assert.Equal(t, affine.Y, fieldZero) + suite.Equal(affine.X, fieldZero) + suite.Equal(affine.Y, fieldZero) } -func Test{{.CurvePrefix}}AffineFromLimbs(t *testing.T) { +func test{{.CurvePrefix}}AffineFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) var affine {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}Affine affine.FromLimbs(randLimbs, randLimbs2) - assert.ElementsMatch(t, randLimbs, affine.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, affine.Y.GetLimbs()) + suite.ElementsMatch(randLimbs, affine.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, affine.Y.GetLimbs()) } -func Test{{.CurvePrefix}}AffineToProjective(t *testing.T) { +func test{{.CurvePrefix}}AffineToProjective(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) var fieldOne {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}BaseField @@ -49,31 +49,31 @@ func Test{{.CurvePrefix}}AffineToProjective(t *testing.T) { affine.FromLimbs(randLimbs, randLimbs2) projectivePoint := affine.ToProjective() - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) } -func Test{{.CurvePrefix}}ProjectiveZero(t *testing.T) { +func test{{.CurvePrefix}}ProjectiveZero(suite suite.Suite) { var projectiveZero {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}Projective projectiveZero.Zero() var fieldZero = {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}BaseField{} var fieldOne {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}BaseField fieldOne.One() - assert.Equal(t, projectiveZero.X, fieldZero) - assert.Equal(t, projectiveZero.Y, fieldOne) - assert.Equal(t, projectiveZero.Z, fieldZero) + suite.Equal(projectiveZero.X, fieldZero) + suite.Equal(projectiveZero.Y, fieldOne) + suite.Equal(projectiveZero.Z, fieldZero) randLimbs := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) var projective {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}Projective projective.FromLimbs(randLimbs, randLimbs, randLimbs) projective.Zero() - assert.Equal(t, projective.X, fieldZero) - assert.Equal(t, projective.Y, fieldOne) - assert.Equal(t, projective.Z, fieldZero) + suite.Equal(projective.X, fieldZero) + suite.Equal(projective.Y, fieldOne) + suite.Equal(projective.Z, fieldZero) } -func Test{{.CurvePrefix}}ProjectiveFromLimbs(t *testing.T) { +func test{{.CurvePrefix}}ProjectiveFromLimbs(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) randLimbs3 := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) @@ -81,12 +81,12 @@ func Test{{.CurvePrefix}}ProjectiveFromLimbs(t *testing.T) { var projective {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}Projective projective.FromLimbs(randLimbs, randLimbs2, randLimbs3) - assert.ElementsMatch(t, randLimbs, projective.X.GetLimbs()) - assert.ElementsMatch(t, randLimbs2, projective.Y.GetLimbs()) - assert.ElementsMatch(t, randLimbs3, projective.Z.GetLimbs()) + suite.ElementsMatch(randLimbs, projective.X.GetLimbs()) + suite.ElementsMatch(randLimbs2, projective.Y.GetLimbs()) + suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } -func Test{{.CurvePrefix}}ProjectiveFromAffine(t *testing.T) { +func test{{.CurvePrefix}}ProjectiveFromAffine(suite suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) var fieldOne {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}BaseField @@ -100,5 +100,22 @@ func Test{{.CurvePrefix}}ProjectiveFromAffine(t *testing.T) { var projectivePoint {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}Projective projectivePoint.FromAffine(affine) - assert.Equal(t, expected, projectivePoint) + suite.Equal(expected, projectivePoint) +} + +type {{.CurvePrefix}}CurveTestSuite struct { + suite.Suite +} + +func (s *{{.CurvePrefix}}CurveTestSuite) Test{{.CurvePrefix}}Curve() { + s.Run("Test{{.CurvePrefix}}AffineZero", testWrapper(s.Suite, test{{.CurvePrefix}}AffineZero)) + s.Run("Test{{.CurvePrefix}}AffineFromLimbs", testWrapper(s.Suite, test{{.CurvePrefix}}AffineFromLimbs)) + s.Run("Test{{.CurvePrefix}}AffineToProjective", testWrapper(s.Suite, test{{.CurvePrefix}}AffineToProjective)) + s.Run("Test{{.CurvePrefix}}ProjectiveZero", testWrapper(s.Suite, test{{.CurvePrefix}}ProjectiveZero)) + s.Run("Test{{.CurvePrefix}}ProjectiveFromLimbs", testWrapper(s.Suite, test{{.CurvePrefix}}ProjectiveFromLimbs)) + s.Run("Test{{.CurvePrefix}}ProjectiveFromAffine", testWrapper(s.Suite, test{{.CurvePrefix}}ProjectiveFromAffine)) +} + +func TestSuite{{.CurvePrefix}}Curve(t *testing.T) { + suite.Run(t, new({{.CurvePrefix}}CurveTestSuite)) } diff --git a/wrappers/golang/internal/generator/ecntt/templates/ecntt_test.go.tmpl b/wrappers/golang/internal/generator/ecntt/templates/ecntt_test.go.tmpl index 8b12a6af8..ee1bb522f 100644 --- a/wrappers/golang/internal/generator/ecntt/templates/ecntt_test.go.tmpl +++ b/wrappers/golang/internal/generator/ecntt/templates/ecntt_test.go.tmpl @@ -9,10 +9,10 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime/config_extension" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func TestECNtt(t *testing.T) { +func testECNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() ext := config_extension.Create() ext.SetInt(core.CUDA_NTT_ALGORITHM, int(core.Radix2)) @@ -31,7 +31,19 @@ func TestECNtt(t *testing.T) { output := make(core.HostSlice[{{.Curve}}.Projective], testSize) e := ecntt.ECNtt(pointsCopy, core.KForward, &cfg, output) - assert.Equal(t, runtime.Success, e, "ECNtt failed") + suite.Equal(runtime.Success, e, "ECNtt failed") } } } + +type ECNttTestSuite struct { + suite.Suite +} + +func (s *ECNttTestSuite) TestECNtt() { + s.Run("TestECNtt", testWrapper(s.Suite, testECNtt)) +} + +func TestSuiteECNtt(t *testing.T) { + suite.Run(t, new(ECNttTestSuite)) +} diff --git a/wrappers/golang/internal/generator/fields/templates/field.go.tmpl b/wrappers/golang/internal/generator/fields/templates/field.go.tmpl index 927aef1fd..3e7946af1 100644 --- a/wrappers/golang/internal/generator/fields/templates/field.go.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/field.go.tmpl @@ -63,6 +63,16 @@ func (f *{{.FieldPrefix}}Field) Zero() {{.FieldPrefix}}Field { return *f } +func (f *{{.FieldPrefix}}Field) IsZero() bool { + for _, limb := range f.limbs { + if limb != 0 { + return false + } + } + + return true +} + func (f *{{.FieldPrefix}}Field) One() {{.FieldPrefix}}Field { for i := range f.limbs { f.limbs[i] = 0 diff --git a/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl b/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl index 15062f92f..8718d9adf 100644 --- a/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl @@ -6,101 +6,101 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core"{{end}} {{if ne .FieldPrefix "G2"}}{{.Field}}{{end}} "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}{{if eq .FieldPrefix "G2"}}/g2{{end}}" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/test_helpers" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) - const ( +const ( {{toConst .FieldPrefix}}LIMBS = {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{toConst .FieldPrefix}}LIMBS ) -func Test{{.FieldPrefix}}FieldFromLimbs(t *testing.T) { +func test{{.FieldPrefix}}FieldFromLimbs(suite suite.Suite) { emptyField := {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field{} randLimbs := test_helpers.GenerateRandomLimb(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the {{.FieldPrefix}}Field's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the {{.FieldPrefix}}Field's limbs") randLimbs[0] = 100 - assert.NotEqual(t, randLimbs, emptyField.GetLimbs()) + suite.NotEqual(randLimbs, emptyField.GetLimbs()) } -func Test{{.FieldPrefix}}FieldGetLimbs(t *testing.T) { +func test{{.FieldPrefix}}FieldGetLimbs(suite suite.Suite) { emptyField := {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field{} randLimbs := test_helpers.GenerateRandomLimb(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.ElementsMatch(t, randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the {{.FieldPrefix}}Field's limbs") + suite.ElementsMatch(randLimbs, emptyField.GetLimbs(), "Limbs do not match; there was an issue with setting the {{.FieldPrefix}}Field's limbs") } -func Test{{.FieldPrefix}}FieldOne(t *testing.T) { +func test{{.FieldPrefix}}FieldOne(suite suite.Suite) { var emptyField {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field emptyField.One() limbOne := test_helpers.GenerateLimbOne(int({{toConst .FieldPrefix}}LIMBS)) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "Empty field to field one did not work") randLimbs := test_helpers.GenerateRandomLimb(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.One() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbOne, "{{.FieldPrefix}}Field with limbs to field one did not work") + suite.ElementsMatch(emptyField.GetLimbs(), limbOne, "{{.FieldPrefix}}Field with limbs to field one did not work") } -func Test{{.FieldPrefix}}FieldZero(t *testing.T) { +func test{{.FieldPrefix}}FieldZero(suite suite.Suite) { var emptyField {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field emptyField.Zero() limbsZero := make([]uint32, {{toConst .FieldPrefix}}LIMBS) - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "Empty field to field zero failed") randLimbs := test_helpers.GenerateRandomLimb(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(randLimbs[:]) emptyField.Zero() - assert.ElementsMatch(t, emptyField.GetLimbs(), limbsZero, "{{.FieldPrefix}}Field with limbs to field zero failed") + suite.ElementsMatch(emptyField.GetLimbs(), limbsZero, "{{.FieldPrefix}}Field with limbs to field zero failed") } -func Test{{.FieldPrefix}}FieldSize(t *testing.T) { +func test{{.FieldPrefix}}FieldSize(suite suite.Suite) { var emptyField {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field randLimbs := test_helpers.GenerateRandomLimb(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") + suite.Equal(len(randLimbs)*4, emptyField.Size(), "Size returned an incorrect value of bytes") } -func Test{{.FieldPrefix}}FieldAsPointer(t *testing.T) { +func test{{.FieldPrefix}}FieldAsPointer(suite suite.Suite) { var emptyField {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field randLimbs := test_helpers.GenerateRandomLimb(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(randLimbs[:]) - assert.Equal(t, randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") + suite.Equal(randLimbs[0], *emptyField.AsPointer(), "AsPointer returned pointer to incorrect value") } -func Test{{.FieldPrefix}}FieldFromBytes(t *testing.T) { +func test{{.FieldPrefix}}FieldFromBytes(suite suite.Suite) { var emptyField {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field bytes, expected := test_helpers.GenerateBytesArray(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromBytesLittleEndian(bytes) - assert.ElementsMatch(t, emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") + suite.ElementsMatch(emptyField.GetLimbs(), expected, "FromBytes returned incorrect values") } -func Test{{.FieldPrefix}}FieldToBytes(t *testing.T) { +func test{{.FieldPrefix}}FieldToBytes(suite suite.Suite) { var emptyField {{if eq .FieldPrefix "G2"}}g2{{else}}{{.Field}}{{end}}.{{.FieldPrefix}}Field expected, limbs := test_helpers.GenerateBytesArray(int({{toConst .FieldPrefix}}LIMBS)) emptyField.FromLimbs(limbs) - assert.ElementsMatch(t, emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") + suite.ElementsMatch(emptyField.ToBytesLittleEndian(), expected, "ToBytes returned incorrect values") } {{if .IsScalar}} -func Test{{capitalize .Field}}GenerateScalars(t *testing.T) { +func test{{capitalize .Field}}GenerateScalars(suite suite.Suite) { const numScalars = 8 scalars := {{.Field}}.GenerateScalars(numScalars) - assert.Implements(t, (*core.HostOrDeviceSlice)(nil), &scalars) + suite.Implements((*core.HostOrDeviceSlice)(nil), &scalars) - assert.Equal(t, numScalars, scalars.Len()) + suite.Equal(numScalars, scalars.Len()) zeroScalar := {{.Field}}.{{.FieldPrefix}}Field{} - assert.NotContains(t, scalars, zeroScalar) + suite.NotContains(scalars, zeroScalar) } -func Test{{capitalize .Field}}MongtomeryConversion(t *testing.T) { +func test{{capitalize .Field}}MongtomeryConversion(suite suite.Suite) { size := 1 << 20 scalars := {{.Field}}.GenerateScalars(size) @@ -112,10 +112,34 @@ func Test{{capitalize .Field}}MongtomeryConversion(t *testing.T) { scalarsMontHost := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], size) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.NotEqual(t, scalars, scalarsMontHost) + suite.NotEqual(scalars, scalarsMontHost) {{.Field}}.FromMontgomery(deviceScalars) scalarsMontHost.CopyFromDevice(&deviceScalars) - assert.Equal(t, scalars, scalarsMontHost) + suite.Equal(scalars, scalarsMontHost) }{{end}} + + +type {{.FieldPrefix}}FieldTestSuite struct { + suite.Suite +} + +func (s *{{.FieldPrefix}}FieldTestSuite) Test{{.FieldPrefix}}Field() { + s.Run("Test{{.FieldPrefix}}FieldFromLimbs", testWrapper(s.Suite, test{{.FieldPrefix}}FieldFromLimbs)) + s.Run("Test{{.FieldPrefix}}FieldGetLimbs", testWrapper(s.Suite, test{{.FieldPrefix}}FieldGetLimbs)) + s.Run("Test{{.FieldPrefix}}FieldOne", testWrapper(s.Suite, test{{.FieldPrefix}}FieldOne)) + s.Run("Test{{.FieldPrefix}}FieldZero", testWrapper(s.Suite, test{{.FieldPrefix}}FieldZero)) + s.Run("Test{{.FieldPrefix}}FieldSize", testWrapper(s.Suite, test{{.FieldPrefix}}FieldSize)) + s.Run("Test{{.FieldPrefix}}FieldAsPointer", testWrapper(s.Suite, test{{.FieldPrefix}}FieldAsPointer)) + s.Run("Test{{.FieldPrefix}}FieldFromBytes", testWrapper(s.Suite, test{{.FieldPrefix}}FieldFromBytes)) + s.Run("Test{{.FieldPrefix}}FieldToBytes", testWrapper(s.Suite, test{{.FieldPrefix}}FieldToBytes)) + {{if .IsScalar -}} + s.Run("Test{{capitalize .Field}}GenerateScalars", testWrapper(s.Suite, test{{capitalize .Field}}GenerateScalars)) + s.Run("Test{{capitalize .Field}}MongtomeryConversion", testWrapper(s.Suite, test{{capitalize .Field}}MongtomeryConversion)) + {{- end}} +} + +func TestSuite{{.FieldPrefix}}Field(t *testing.T) { + suite.Run(t, new({{.FieldPrefix}}FieldTestSuite)) +} diff --git a/wrappers/golang/internal/generator/msm/templates/msm_test.go.tmpl b/wrappers/golang/internal/generator/msm/templates/msm_test.go.tmpl index da4566f4f..05fd6d877 100644 --- a/wrappers/golang/internal/generator/msm/templates/msm_test.go.tmpl +++ b/wrappers/golang/internal/generator/msm/templates/msm_test.go.tmpl @@ -5,7 +5,7 @@ import ( "sync" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" {{if ne .GnarkImport "" -}} "github.com/consensys/gnark-crypto/ecc" @@ -82,7 +82,7 @@ func projectiveToGnarkAffineG2(p g2.G2Projective) {{toPackage .GnarkImport}}.G2A return *g2Affine.FromJacobian(&g2Jac) } {{end}} -func testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t *testing.T, scalars core.HostSlice[icicle{{capitalize .Curve}}.ScalarField], points core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Affine], out {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective) { +func testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite suite.Suite, scalars core.HostSlice[icicle{{capitalize .Curve}}.ScalarField], points core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Affine], out {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective) { scalarsFr := make([]fr.Element, len(scalars)) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -94,10 +94,10 @@ func testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t *testing.T, scalars core.HostSl pointsFp[i] = projectiveToGnarkAffine{{.CurvePrefix}}(v.ToProjective()) } - testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(t, scalarsFr, pointsFp, out) + testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(suite, scalarsFr, pointsFp, out) } -func testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(t *testing.T, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[{{toPackage .GnarkImport}}.{{if eq .CurvePrefix "G2"}}G2{{else}}G1{{end}}Affine], out {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective) { +func testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(suite suite.Suite, scalarsFr core.HostSlice[fr.Element], pointsFp core.HostSlice[{{toPackage .GnarkImport}}.{{if eq .CurvePrefix "G2"}}G2{{else}}G1{{end}}Affine], out {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective) { var msmRes {{toPackage .GnarkImport}}.{{if eq .CurvePrefix "G2"}}G2{{else}}G1{{end}}Jac msmRes.MultiExp(pointsFp, scalarsFr, ecc.MultiExpConfig{}) @@ -106,7 +106,7 @@ func testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(t *testing.T, sca icicleResAffine := projectiveToGnarkAffine{{.CurvePrefix}}(out) - assert.Equal(t, msmResAffine, icicleResAffine) + suite.Equal(msmResAffine, icicleResAffine) } {{$isBW6 := eq .Curve "bw6_761"}}{{$isG2 := eq .CurvePrefix "G2"}}{{$isG1 := ne .CurvePrefix "G2"}}{{if or $isBW6 $isG1 -}} @@ -157,7 +157,7 @@ func convertIcicleG2AffineToG2Affine(iciclePoints []g2.G2Affine) []{{toPackage . return points }{{end}}{{end}} -func TestMSM{{.CurvePrefix}}(t *testing.T) { +func testMSM{{.CurvePrefix}}(suite suite.Suite) { cfg := {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}GetDefaultMSMConfig() cfg.IsAsync = true for _, power := range []int{2, 3, 4, 5, 6} { @@ -171,11 +171,11 @@ func TestMSM{{.CurvePrefix}}(t *testing.T) { var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -184,12 +184,12 @@ func TestMSM{{.CurvePrefix}}(t *testing.T) { runtime.DestroyStream(stream) {{if ne .GnarkImport "" -}} // Check with gnark-crypto - testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t, scalars, points, outHost[0]){{end}} + testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite, scalars, points, outHost[0]){{end}} } } {{if ne .GnarkImport "" -}} -func TestMSM{{if eq .CurvePrefix "G2"}}G2{{end}}GnarkCryptoTypes(t *testing.T) { +func testMSM{{if eq .CurvePrefix "G2"}}G2{{end}}GnarkCryptoTypes(suite suite.Suite) { cfg := {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}GetDefaultMSMConfig() for _, power := range []int{3} { runtime.SetDevice(&DEVICE) @@ -209,22 +209,22 @@ func TestMSM{{if eq .CurvePrefix "G2"}}G2{{end}}GnarkCryptoTypes(t *testing.T) { var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.AreBasesMontgomeryForm = true cfg.AreScalarsMontgomeryForm = true e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalarsHost, pointsHost, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], 1) outHost.CopyFromDevice(&out) out.Free() // Check with gnark-crypto - testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(t, scalarsHost, pointsHost, outHost[0]) + testAgainstGnarkCryptoMsm{{.CurvePrefix}}GnarkCryptoTypes(suite, scalarsHost, pointsHost, outHost[0]) } } {{end}} -func TestMSM{{.CurvePrefix}}Batch(t *testing.T) { +func testMSM{{.CurvePrefix}}Batch(suite suite.Suite) { cfg := {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}GetDefaultMSMConfig() for _, power := range []int{5, 6} { for _, batchSize := range []int{1, 3, 5} { @@ -237,10 +237,10 @@ func TestMSM{{.CurvePrefix}}Batch(t *testing.T) { var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), batchSize) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -250,15 +250,15 @@ func TestMSM{{.CurvePrefix}}Batch(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite, scalarsSlice, pointsSlice, out) }{{end}} } } } -func TestPrecomputePoints{{.CurvePrefix}}(t *testing.T) { +func testPrecomputePoints{{.CurvePrefix}}(suite suite.Suite) { if DEVICE.GetDeviceType() == "CPU" { - t.Skip("Skipping cpu test") + suite.T().Skip("Skipping cpu test") } cfg := {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}GetDefaultMSMConfig() const precomputeFactor = 8 @@ -275,20 +275,20 @@ func TestPrecomputePoints{{.CurvePrefix}}(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") cfg.BatchSize = int32(batchSize) cfg.ArePointsSharedInBatch = false e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -299,13 +299,13 @@ func TestPrecomputePoints{{.CurvePrefix}}(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[i*size : (i+1)*size] out := outHost[i] - testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite, scalarsSlice, pointsSlice, out) }{{end}} } } } -func TestPrecomputePointsSharedBases{{.CurvePrefix}}(t *testing.T) { +func testPrecomputePointsSharedBases{{.CurvePrefix}}(suite suite.Suite) { cfg := {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}GetDefaultMSMConfig() const precomputeFactor = 8 cfg.PrecomputeFactor = precomputeFactor @@ -321,18 +321,18 @@ func TestPrecomputePointsSharedBases{{.CurvePrefix}}(t *testing.T) { var precomputeOut core.DeviceSlice _, e := precomputeOut.Malloc(points[0].Size(), points.Len()*int(precomputeFactor)) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for PrecomputeBases results failed") e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}PrecomputeBases(points, &cfg, precomputeOut) - assert.Equal(t, runtime.Success, e, "PrecomputeBases failed") + suite.Equal(runtime.Success, e, "PrecomputeBases failed") var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e = out.Malloc(p.Size(), batchSize) - assert.Equal(t, runtime.Success, e, "Allocating bytes on device for Projective results failed") + suite.Equal(runtime.Success, e, "Allocating bytes on device for Projective results failed") e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalars, precomputeOut, &cfg, out) - assert.Equal(t, runtime.Success, e, "Msm failed") + suite.Equal(runtime.Success, e, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], batchSize) outHost.CopyFromDevice(&out) out.Free() @@ -343,13 +343,13 @@ func TestPrecomputePointsSharedBases{{.CurvePrefix}}(t *testing.T) { scalarsSlice := scalars[i*size : (i+1)*size] pointsSlice := points[0 : size] out := outHost[i] - testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t, scalarsSlice, pointsSlice, out) + testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite, scalarsSlice, pointsSlice, out) }{{end}} } } } -func TestMSM{{.CurvePrefix}}SkewedDistribution(t *testing.T) { +func testMSM{{.CurvePrefix}}SkewedDistribution(suite suite.Suite) { cfg := {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5} { runtime.SetDevice(&DEVICE) @@ -368,20 +368,20 @@ func TestMSM{{.CurvePrefix}}SkewedDistribution(t *testing.T) { var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e := out.Malloc(p.Size(), 1) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], 1) outHost.CopyFromDevice(&out) out.Free() {{if ne .GnarkImport "" -}} // Check with gnark-crypto - testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t, scalars, points, outHost[0]){{end}} + testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite, scalars, points, outHost[0]){{end}} } } -func TestMSM{{.CurvePrefix}}MultiDevice(t *testing.T) { +func testMSM{{.CurvePrefix}}MultiDevice(suite suite.Suite) { numDevices, _ := runtime.GetDeviceCount() fmt.Println("There are ", numDevices, " ", DEVICE.GetDeviceType(), " devices available") wg := sync.WaitGroup{} @@ -405,11 +405,11 @@ func TestMSM{{.CurvePrefix}}MultiDevice(t *testing.T) { var p {{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective var out core.DeviceSlice _, e := out.MallocAsync(p.Size(), 1, stream) - assert.Equal(t, e, runtime.Success, "Allocating bytes on device for Projective results failed") + suite.Equal(e, runtime.Success, "Allocating bytes on device for Projective results failed") cfg.StreamHandle = stream e = {{if eq .CurvePrefix "G2"}}g2{{else}}msm{{end}}.{{.CurvePrefix}}Msm(scalars, points, &cfg, out) - assert.Equal(t, e, runtime.Success, "Msm failed") + suite.Equal(e, runtime.Success, "Msm failed") outHost := make(core.HostSlice[{{if ne .CurvePrefix "G2"}}icicle{{capitalize .Curve}}{{else}}g2{{end}}.{{.CurvePrefix}}Projective], 1) outHost.CopyFromDeviceAsync(&out, stream) out.FreeAsync(stream) @@ -417,9 +417,29 @@ func TestMSM{{.CurvePrefix}}MultiDevice(t *testing.T) { runtime.SynchronizeStream(stream) runtime.DestroyStream(stream) {{if ne .GnarkImport "" -}}// Check with gnark-crypto - testAgainstGnarkCryptoMsm{{.CurvePrefix}}(t, scalars, points, outHost[0]){{end}} + testAgainstGnarkCryptoMsm{{.CurvePrefix}}(suite, scalars, points, outHost[0]){{end}} } }) } wg.Wait() } + +type MSM{{.CurvePrefix}}TestSuite struct { + suite.Suite +} + +func (s *MSM{{.CurvePrefix}}TestSuite) TestMSM{{.CurvePrefix}}() { + s.Run("TestMSM{{.CurvePrefix}}", testWrapper(s.Suite, testMSM{{.CurvePrefix}})) + {{if ne .GnarkImport "" -}} + s.Run("TestMSM{{.CurvePrefix}}GnarkCryptoTypes", testWrapper(s.Suite, testMSM{{.CurvePrefix}}GnarkCryptoTypes)) + {{end -}} + s.Run("TestMSM{{.CurvePrefix}}Batch", testWrapper(s.Suite, testMSM{{.CurvePrefix}}Batch)) + s.Run("TestPrecomputePoints{{.CurvePrefix}}", testWrapper(s.Suite, testPrecomputePoints{{.CurvePrefix}})) + s.Run("TestPrecomputePointsSharedBases{{.CurvePrefix}}", testWrapper(s.Suite, testPrecomputePointsSharedBases{{.CurvePrefix}})) + s.Run("TestMSM{{.CurvePrefix}}SkewedDistribution", testWrapper(s.Suite, testMSM{{.CurvePrefix}}SkewedDistribution)) + s.Run("TestMSM{{.CurvePrefix}}MultiDevice", testWrapper(s.Suite, testMSM{{.CurvePrefix}}MultiDevice)) +} + +func TestSuiteMSM{{.CurvePrefix}}(t *testing.T) { + suite.Run(t, new(MSM{{.CurvePrefix}}TestSuite)) +} diff --git a/wrappers/golang/internal/generator/ntt/templates/ntt_no_domain_test.go.tmpl b/wrappers/golang/internal/generator/ntt/templates/ntt_no_domain_test.go.tmpl index af6d9a4d6..e99c03a58 100644 --- a/wrappers/golang/internal/generator/ntt/templates/ntt_no_domain_test.go.tmpl +++ b/wrappers/golang/internal/generator/ntt/templates/ntt_no_domain_test.go.tmpl @@ -1,10 +1,8 @@ package tests import ( - {{if ne .GnarkImport "" -}} - "reflect" - {{end -}} "testing" + "github.com/stretchr/testify/suite" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" @@ -12,7 +10,7 @@ import ( ntt "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}/ntt" ) -func TestNttNoDomain(t *testing.T) { +func testNttNoDomain(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := {{.FieldNoDomain}}.GenerateScalars(1 << largestTestSize) @@ -32,7 +30,7 @@ func TestNttNoDomain(t *testing.T) { } } -func TestNttDeviceAsyncNoDomain(t *testing.T) { +func testNttDeviceAsyncNoDomain(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := {{.FieldNoDomain}}.GenerateScalars(1 << largestTestSize) @@ -67,7 +65,7 @@ func TestNttDeviceAsyncNoDomain(t *testing.T) { } } -func TestNttBatchNoDomain(t *testing.T) { +func testNttBatchNoDomain(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 12 largestBatchSize := 100 @@ -90,3 +88,17 @@ func TestNttBatchNoDomain(t *testing.T) { } } } + +type NTTNoDomainTestSuite struct { + suite.Suite +} + +func (s *NTTNoDomainTestSuite) TestNTTNoDomain() { + s.Run("TestNTTNoDomain", testWrapper(s.Suite, testNttNoDomain)) + s.Run("TestNttDeviceAsyncNoDomain", testWrapper(s.Suite, testNttDeviceAsyncNoDomain)) + s.Run("TestNttBatchNoDomain", testWrapper(s.Suite, testNttBatchNoDomain)) +} + +func TestSuiteNTTNoDomain(t *testing.T) { + suite.Run(t, new(NTTNoDomainTestSuite)) +} diff --git a/wrappers/golang/internal/generator/ntt/templates/ntt_test.go.tmpl b/wrappers/golang/internal/generator/ntt/templates/ntt_test.go.tmpl index f91053526..e6db0b998 100644 --- a/wrappers/golang/internal/generator/ntt/templates/ntt_test.go.tmpl +++ b/wrappers/golang/internal/generator/ntt/templates/ntt_test.go.tmpl @@ -15,11 +15,11 @@ import ( "github.com/consensys/gnark-crypto/ecc/{{.GnarkImport}}/fr" "github.com/consensys/gnark-crypto/ecc/{{.GnarkImport}}/fr/fft" {{end -}} - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) {{if ne .GnarkImport "" -}} -func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], output core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNtt(suite suite.Suite, size int, scalars core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], output core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], order core.Ordering, direction core.NTTDir) { scalarsFr := make([]fr.Element, size) for i, v := range scalars { slice64, _ := fr.LittleEndian.Element((*[fr.Bytes]byte)(v.ToBytesLittleEndian())) @@ -31,10 +31,10 @@ func testAgainstGnarkCryptoNtt(t *testing.T, size int, scalars core.HostSlice[{{ outputAsFr[i] = slice64 } - testAgainstGnarkCryptoNttGnarkTypes(t, size, scalarsFr, outputAsFr, order, direction) + testAgainstGnarkCryptoNttGnarkTypes(suite, size, scalarsFr, outputAsFr, order, direction) } -func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { +func testAgainstGnarkCryptoNttGnarkTypes(suite suite.Suite, size int, scalarsFr core.HostSlice[fr.Element], outputAsFr core.HostSlice[fr.Element], order core.Ordering, direction core.NTTDir) { domainWithPrecompute := fft.NewDomain(uint64(size)) // DIT + BitReverse == Ordering.kRR // DIT == Ordering.kRN @@ -56,27 +56,21 @@ func testAgainstGnarkCryptoNttGnarkTypes(t *testing.T, size int, scalarsFr core. if order == core.KNN || order == core.KRR { fft.BitReverse(scalarsFr) } - assert.Equal(t, scalarsFr, outputAsFr) + suite.Equal(scalarsFr, outputAsFr) } {{end -}} -func TestNTTGetDefaultConfig(t *testing.T) { +func testNTTGetDefaultConfig(suite suite.Suite) { actual := ntt.GetDefaultNttConfig() expected := test_helpers.GenerateLimbOne(int({{.Field}}.{{toConst .FieldPrefix}}LIMBS)) - assert.Equal(t, expected, actual.CosetGen[:]) + suite.Equal(expected, actual.CosetGen[:]) cosetGenField := {{.Field}}.{{.FieldPrefix}}Field{} cosetGenField.One() - assert.ElementsMatch(t, cosetGenField.GetLimbs(), actual.CosetGen) + suite.ElementsMatch(cosetGenField.GetLimbs(), actual.CosetGen) } -func TestInitDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - cfg := core.GetDefaultNTTInitDomainConfig() - assert.NotPanics(t, func() { initDomain({{if ne .GnarkImport "" -}}largestTestSize, {{end -}}cfg) }) -} - -func TestNtt(t *testing.T) { +func testNtt(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := {{.Field}}.GenerateScalars(1 << largestTestSize) @@ -95,13 +89,13 @@ func TestNtt(t *testing.T) { {{if ne .GnarkImport "" -}} // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, core.KForward) {{end -}} } } } {{if ne .GnarkImport "" -}} -func TestNttFrElement(t *testing.T) { +func testNttFrElement(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := make([]fr.Element, 4) var x fr.Element @@ -124,12 +118,12 @@ func TestNttFrElement(t *testing.T) { ntt.Ntt(scalarsCopy, core.KForward, &cfg, output) // Compare with gnark-crypto - testAgainstGnarkCryptoNttGnarkTypes(t, testSize, scalarsCopy, output, v, core.KForward) + testAgainstGnarkCryptoNttGnarkTypes(suite, testSize, scalarsCopy, output, v, core.KForward) } } } {{end}} -func TestNttDeviceAsync(t *testing.T) { +func testNttDeviceAsync(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() scalars := {{.Field}}.GenerateScalars(1 << largestTestSize) @@ -161,14 +155,14 @@ func TestNttDeviceAsync(t *testing.T) { runtime.DestroyStream(stream) {{if ne .GnarkImport "" -}} // Compare with gnark-crypto - testAgainstGnarkCryptoNtt(t, testSize, scalarsCopy, output, v, direction) + testAgainstGnarkCryptoNtt(suite, testSize, scalarsCopy, output, v, direction) {{end -}} } } } } -func TestNttBatch(t *testing.T) { +func testNttBatch(suite suite.Suite) { cfg := ntt.GetDefaultNttConfig() largestTestSize := 10 largestBatchSize := 20 @@ -207,8 +201,8 @@ func TestNttBatch(t *testing.T) { domainWithPrecompute.FFT(scalarsFr, fft.DIF) fft.BitReverse(scalarsFr) - if !assert.True(t, reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { - t.FailNow() + if !suite.True(reflect.DeepEqual(scalarsFr, outputAsFr[i*testSize:(i+1)*testSize])) { + suite.T().FailNow() } } {{end -}} @@ -216,8 +210,20 @@ func TestNttBatch(t *testing.T) { } } -func TestReleaseDomain(t *testing.T) { - t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function") - e := ntt.ReleaseDomain() - assert.Equal(t, runtime.Success, e, "ReleasDomain failed") +type NTTTestSuite struct { + suite.Suite +} + +func (s *NTTTestSuite) TestNTT() { + s.Run("TestNTTGetDefaultConfig", testWrapper(s.Suite, testNTTGetDefaultConfig)) + s.Run("TestNTT", testWrapper(s.Suite, testNtt)) + {{if ne .GnarkImport "" -}} + s.Run("TestNTTFrElement", testWrapper(s.Suite, testNttFrElement)) + {{end -}} + s.Run("TestNttDeviceAsync", testWrapper(s.Suite, testNttDeviceAsync)) + s.Run("TestNttBatch", testWrapper(s.Suite, testNttBatch)) +} + +func TestSuiteNTT(t *testing.T) { + suite.Run(t, new(NTTTestSuite)) } diff --git a/wrappers/golang/internal/generator/polynomial/templates/polynomial_test.go.tmpl b/wrappers/golang/internal/generator/polynomial/templates/polynomial_test.go.tmpl index de13cfabc..5d1982ba1 100644 --- a/wrappers/golang/internal/generator/polynomial/templates/polynomial_test.go.tmpl +++ b/wrappers/golang/internal/generator/polynomial/templates/polynomial_test.go.tmpl @@ -6,10 +6,9 @@ import ( "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" {{.Field}} "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}" - // "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}/ntt" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}/polynomial" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) var one, two, three, four, five {{.Field}}.{{.FieldPrefix}}Field @@ -41,7 +40,7 @@ func vecOp(a, b {{.Field}}.{{.FieldPrefix}}Field, op core.VecOps) {{.Field}}.{{. return out[0] } -func TestPolyCreateFromCoefficients(t *testing.T) { +func testPolyCreateFromCoefficients(suite suite.Suite) { scalars := {{.Field}}.GenerateScalars(33) var uniPoly polynomial.DensePolynomial @@ -49,7 +48,7 @@ func TestPolyCreateFromCoefficients(t *testing.T) { poly.Print() } -func TestPolyEval(t *testing.T) { +func testPolyEval(suite suite.Suite) { // testing correct evaluation of f(8) for f(x)=4x^2+2x+5 coeffs := core.HostSliceFromElements([]{{.Field}}.{{.FieldPrefix}}Field{five, two, four}) var f polynomial.DensePolynomial @@ -62,10 +61,10 @@ func TestPolyEval(t *testing.T) { evals := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], 1) fEvaled := f.EvalOnDomain(domains, evals) var expected {{.Field}}.{{.FieldPrefix}}Field - assert.Equal(t, expected.FromUint32(277), fEvaled.(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field])[0]) + suite.Equal(expected.FromUint32(277), fEvaled.(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field])[0]) } -func TestPolyClone(t *testing.T) { +func testPolyClone(suite suite.Suite) { f := randomPoly(8) x := rand() fx := f.Eval(x) @@ -76,11 +75,11 @@ func TestPolyClone(t *testing.T) { gx := g.Eval(x) fgx := fg.Eval(x) - assert.Equal(t, fx, gx) - assert.Equal(t, vecOp(fx, gx, core.Add), fgx) + suite.Equal(fx, gx) + suite.Equal(vecOp(fx, gx, core.Add), fgx) } -func TestPolyAddSubMul(t *testing.T) { +func testPolyAddSubMul(suite suite.Suite) { testSize := 1 << 10 f := randomPoly(testSize) g := randomPoly(testSize) @@ -91,26 +90,26 @@ func TestPolyAddSubMul(t *testing.T) { polyAdd := f.Add(&g) fxAddgx := vecOp(fx, gx, core.Add) - assert.Equal(t, polyAdd.Eval(x), fxAddgx) + suite.Equal(polyAdd.Eval(x), fxAddgx) polySub := f.Subtract(&g) fxSubgx := vecOp(fx, gx, core.Sub) - assert.Equal(t, polySub.Eval(x), fxSubgx) + suite.Equal(polySub.Eval(x), fxSubgx) polyMul := f.Multiply(&g) fxMulgx := vecOp(fx, gx, core.Mul) - assert.Equal(t, polyMul.Eval(x), fxMulgx) + suite.Equal(polyMul.Eval(x), fxMulgx) s1 := rand() polMulS1 := f.MultiplyByScalar(s1) - assert.Equal(t, polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) + suite.Equal(polMulS1.Eval(x), vecOp(fx, s1, core.Mul)) s2 := rand() polMulS2 := f.MultiplyByScalar(s2) - assert.Equal(t, polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) + suite.Equal(polMulS2.Eval(x), vecOp(fx, s2, core.Mul)) } -func TestPolyMonomials(t *testing.T) { +func testPolyMonomials(suite suite.Suite) { var zero {{.Field}}.{{.FieldPrefix}}Field var f polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements([]{{.Field}}.{{.FieldPrefix}}Field{one, zero, two})) @@ -119,20 +118,20 @@ func TestPolyMonomials(t *testing.T) { fx := f.Eval(x) f.AddMonomial(three, 1) fxAdded := f.Eval(x) - assert.Equal(t, fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) + suite.Equal(fxAdded, vecOp(fx, vecOp(three, x, core.Mul), core.Add)) f.SubMonomial(one, 0) fxSub := f.Eval(x) - assert.Equal(t, fxSub, vecOp(fxAdded, one, core.Sub)) + suite.Equal(fxSub, vecOp(fxAdded, one, core.Sub)) } -func TestPolyReadCoeffs(t *testing.T) { +func testPolyReadCoeffs(suite suite.Suite) { var f polynomial.DensePolynomial coeffs := core.HostSliceFromElements([]{{.Field}}.{{.FieldPrefix}}Field{one, two, three, four}) f.CreateFromCoeffecitients(coeffs) coeffsCopied := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], coeffs.Len()) _, _ = f.CopyCoeffsRange(0, coeffs.Len()-1, coeffsCopied) - assert.ElementsMatch(t, coeffs, coeffsCopied) + suite.ElementsMatch(coeffs, coeffsCopied) var coeffsDevice core.DeviceSlice coeffsDevice.Malloc(one.Size(), coeffs.Len()) @@ -140,16 +139,16 @@ func TestPolyReadCoeffs(t *testing.T) { coeffsHost := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], coeffs.Len()) coeffsHost.CopyFromDevice(&coeffsDevice) - assert.ElementsMatch(t, coeffs, coeffsHost) + suite.ElementsMatch(coeffs, coeffsHost) } -func TestPolyOddEvenSlicing(t *testing.T) { +func testPolyOddEvenSlicing(suite suite.Suite) { size := 1<<10 - 3 f := randomPoly(size) even := f.Even() odd := f.Odd() - assert.Equal(t, f.Degree(), even.Degree()+odd.Degree()+1) + suite.Equal(f.Degree(), even.Degree()+odd.Degree()+1) x := rand() var evenExpected, oddExpected {{.Field}}.{{.FieldPrefix}}Field @@ -164,13 +163,13 @@ func TestPolyOddEvenSlicing(t *testing.T) { } evenEvaled := even.Eval(x) - assert.Equal(t, evenExpected, evenEvaled) + suite.Equal(evenExpected, evenEvaled) oddEvaled := odd.Eval(x) - assert.Equal(t, oddExpected, oddEvaled) + suite.Equal(oddExpected, oddEvaled) } -func TestPolynomialDivision(t *testing.T) { +func testPolynomialDivision(suite suite.Suite) { // divide f(x)/g(x), compute q(x), r(x) and check f(x)=q(x)*g(x)+r(x) var f, g polynomial.DensePolynomial f.CreateFromCoeffecitients(core.HostSliceFromElements({{.Field}}.GenerateScalars(1 << 4))) @@ -184,10 +183,10 @@ func TestPolynomialDivision(t *testing.T) { x := {{.Field}}.GenerateScalars(1)[0] fEval := f.Eval(x) fReconEval := fRecon.Eval(x) - assert.Equal(t, fEval, fReconEval) + suite.Equal(fEval, fReconEval) } -func TestDivideByVanishing(t *testing.T) { +func testDivideByVanishing(suite suite.Suite) { // poly of x^4-1 vanishes ad 4th rou var zero {{.Field}}.{{.FieldPrefix}}Field minus_one := vecOp(zero, one, core.Sub) @@ -200,31 +199,51 @@ func TestDivideByVanishing(t *testing.T) { fv := f.Multiply(&v) fDegree := f.Degree() fvDegree := fv.Degree() - assert.Equal(t, fDegree+4, fvDegree) + suite.Equal(fDegree+4, fvDegree) fReconstructed := fv.DivideByVanishing(4) - assert.Equal(t, fDegree, fReconstructed.Degree()) + suite.Equal(fDegree, fReconstructed.Degree()) x := rand() - assert.Equal(t, f.Eval(x), fReconstructed.Eval(x)) + suite.Equal(f.Eval(x), fReconstructed.Eval(x)) } -// func TestPolySlice(t *testing.T) { +// func TestPolySlice(suite suite.Suite) { // size := 4 // coeffs := {{.Field}}.GenerateScalars(size) // var f DensePolynomial // f.CreateFromCoeffecitients(coeffs) // fSlice := f.AsSlice() -// assert.True(t, fSlice.IsOnDevice()) -// assert.Equal(t, size, fSlice.Len()) +// suite.True(fSlice.IsOnDevice()) +// suite.Equal(size, fSlice.Len()) // hostSlice := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], size) // hostSlice.CopyFromDevice(fSlice) -// assert.Equal(t, coeffs, hostSlice) +// suite.Equal(coeffs, hostSlice) // cfg := ntt.GetDefaultNttConfig() // res := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], size) // ntt.Ntt(fSlice, core.KForward, cfg, res) -// assert.Equal(t, f.Eval(one), res[0]) +// suite.Equal(f.Eval(one), res[0]) // } + +type PolynomialTestSuite struct { + suite.Suite +} + +func (s *PolynomialTestSuite) TestPolynomial() { + s.Run("TestPolyCreateFromCoefficients", testWrapper(s.Suite, testPolyCreateFromCoefficients)) + s.Run("TestPolyEval", testWrapper(s.Suite, testPolyEval)) + s.Run("TestPolyClone", testWrapper(s.Suite, testPolyClone)) + s.Run("TestPolyAddSubMul", testWrapper(s.Suite, testPolyAddSubMul)) + s.Run("TestPolyMonomials", testWrapper(s.Suite, testPolyMonomials)) + s.Run("TestPolyReadCoeffs", testWrapper(s.Suite, testPolyReadCoeffs)) + s.Run("TestPolyOddEvenSlicing", testWrapper(s.Suite, testPolyOddEvenSlicing)) + s.Run("TestPolynomialDivision", testWrapper(s.Suite, testPolynomialDivision)) + s.Run("TestDivideByVanishing", testWrapper(s.Suite, testDivideByVanishing)) +} + +func TestSuitePolynomial(t *testing.T) { + suite.Run(t, new(PolynomialTestSuite)) +} diff --git a/wrappers/golang/internal/generator/tests/templates/main_test.go.tmpl b/wrappers/golang/internal/generator/tests/templates/main_test.go.tmpl index 8be26804e..fc864e899 100644 --- a/wrappers/golang/internal/generator/tests/templates/main_test.go.tmpl +++ b/wrappers/golang/internal/generator/tests/templates/main_test.go.tmpl @@ -2,8 +2,11 @@ package tests import ( "testing" - {{if .SupportsNTT -}} + "sync" "fmt" + "os" + "github.com/stretchr/testify/suite" + {{if .SupportsNTT -}} "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" {{.Field}} "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}"{{end}} @@ -19,7 +22,11 @@ import ( const ( largestTestSize = 20 ) -var DEVICE runtime.Device + +var ( + DEVICE runtime.Device + exitCode int +) {{if .SupportsNTT -}} func initDomain({{if ne .GnarkImport "" -}}largestTestSize int, {{end -}}cfg core.NTTInitDomainConfig) runtime.EIcicleError { @@ -38,6 +45,18 @@ func initDomain({{if ne .GnarkImport "" -}}largestTestSize int, {{end -}}cfg cor return e }{{end}} +func testWrapper(suite suite.Suite, fn func(suite.Suite)) func() { + return func() { + wg := sync.WaitGroup{} + wg.Add(1) + runtime.RunOnDevice(&DEVICE, func(args ...any) { + defer wg.Done() + fn(suite) + }) + wg.Wait() + } +} + func TestMain(m *testing.M) { runtime.LoadBackendFromEnvOrDefault() devices, e := runtime.GetRegisteredDevices() @@ -45,6 +64,7 @@ func TestMain(m *testing.M) { panic("Failed to load registered devices") } for _, deviceType := range devices { + fmt.Println("Running tests for device type:", deviceType) DEVICE = runtime.CreateDevice(deviceType, 0) runtime.SetDevice(&DEVICE) @@ -59,8 +79,10 @@ func TestMain(m *testing.M) { } }{{end}} + // TODO - run tests for each device type without calling `m.Run` multiple times + // see https://cs.opensource.google/go/go/+/refs/tags/go1.23.1:src/testing/testing.go;l=1936-1940 for more info // execute tests - m.Run() + exitCode |= m.Run() {{if .SupportsNTT -}}// release domain e = ntt.ReleaseDomain() @@ -72,4 +94,6 @@ func TestMain(m *testing.M) { } }{{end}} } + + os.Exit(exitCode) } diff --git a/wrappers/golang/internal/generator/vecOps/templates/vec_ops_test.go.tmpl b/wrappers/golang/internal/generator/vecOps/templates/vec_ops_test.go.tmpl index 66bb00a01..d704be991 100644 --- a/wrappers/golang/internal/generator/vecOps/templates/vec_ops_test.go.tmpl +++ b/wrappers/golang/internal/generator/vecOps/templates/vec_ops_test.go.tmpl @@ -6,10 +6,10 @@ import ( {{.Field}} "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/{{.BaseImportPath}}/vecOps" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -func Test{{capitalize .Field}}VecOps(t *testing.T) { +func test{{capitalize .Field}}VecOps(suite suite.Suite) { testSize := 1 << 14 a := {{.Field}}.GenerateScalars(testSize) @@ -27,14 +27,14 @@ func Test{{capitalize .Field}}VecOps(t *testing.T) { vecOps.VecOp(a, b, out, cfg, core.Add) vecOps.VecOp(out, b, out2, cfg, core.Sub) - assert.Equal(t, a, out2) + suite.Equal(a, out2) vecOps.VecOp(a, ones, out3, cfg, core.Mul) - assert.Equal(t, a, out3) + suite.Equal(a, out3) } -func Test{{capitalize .Field}}Transpose(t *testing.T) { +func test{{capitalize .Field}}Transpose(suite suite.Suite) { rowSize := 1 << 6 columnSize := 1 << 8 @@ -48,8 +48,7 @@ func Test{{capitalize .Field}}Transpose(t *testing.T) { vecOps.TransposeMatrix(matrix, out, columnSize, rowSize, cfg) vecOps.TransposeMatrix(out, out2, rowSize, columnSize, cfg) - - assert.Equal(t, matrix, out2) + suite.Equal(matrix, out2) var dMatrix, dOut, dOut2 core.DeviceSlice @@ -62,5 +61,18 @@ func Test{{capitalize .Field}}Transpose(t *testing.T) { output := make(core.HostSlice[{{.Field}}.{{.FieldPrefix}}Field], rowSize*columnSize) output.CopyFromDevice(&dOut2) - assert.Equal(t, matrix, output) + suite.Equal(matrix, output) } + +type {{capitalize .Field}}VecOpsTestSuite struct { + suite.Suite +} + +func (s *{{capitalize .Field}}VecOpsTestSuite) Test{{capitalize .Field}}VecOps() { + s.Run("Test{{capitalize .Field}}VecOps", testWrapper(s.Suite, test{{capitalize .Field}}VecOps)) + s.Run("Test{{capitalize .Field}}Transpose", testWrapper(s.Suite, test{{capitalize .Field}}Transpose)) +} + +func TestSuite{{capitalize .Field}}VecOps(t *testing.T) { + suite.Run(t, new({{capitalize .Field}}VecOpsTestSuite)) +} \ No newline at end of file diff --git a/wrappers/golang/runtime/runtime.go b/wrappers/golang/runtime/runtime.go index 68fdf4615..4d6d7dd39 100644 --- a/wrappers/golang/runtime/runtime.go +++ b/wrappers/golang/runtime/runtime.go @@ -46,11 +46,11 @@ func IsActiveDeviceMemory(ptr unsafe.Pointer) bool { return EIcicleError(cErr) == Success } -// RunOnDevice forces the provided function to run all GPU related calls within it -// on the same host thread and therefore the same GPU device. +// RunOnDevice forces the provided function to run all device related calls within it +// on the same host thread and therefore the same device. // // NOTE: Goroutines launched within funcToRun are not bound to the -// same host thread as funcToRun and therefore not to the same GPU device. +// same host thread as funcToRun and therefore not to the same device. // If that is a requirement, RunOnDevice should be called for each with the // same deviceId as the original call. // @@ -78,11 +78,11 @@ func IsActiveDeviceMemory(ptr unsafe.Pointer) bool { // // }, i) func RunOnDevice(device *Device, funcToRun func(args ...any), args ...any) { - go func(id *Device) { + go func(deviceToRunOn *Device) { defer runtime.UnlockOSThread() runtime.LockOSThread() originalDevice, _ := GetActiveDevice() - SetDevice(id) + SetDevice(deviceToRunOn) funcToRun(args...) SetDevice(originalDevice) }(device) diff --git a/wrappers/golang/runtime/tests/device_test.go b/wrappers/golang/runtime/tests/device_test.go index a938c2b75..a4c114389 100644 --- a/wrappers/golang/runtime/tests/device_test.go +++ b/wrappers/golang/runtime/tests/device_test.go @@ -1,6 +1,7 @@ package tests import ( + "os/exec" "testing" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" @@ -21,10 +22,16 @@ func TestGetDeviceType(t *testing.T) { func TestIsDeviceAvailable(t *testing.T) { runtime.LoadBackendFromEnvOrDefault() dev := runtime.CreateDevice("CUDA", 0) - err := runtime.SetDevice(&dev) + _ = runtime.SetDevice(&dev) res, err := runtime.GetDeviceCount() + + expectedNumDevices, error := exec.Command("nvidia-smi", "-L", "|", "wc", "-l").Output() + if error != nil { + t.Skip("Failed to get number of devices") + } + assert.Equal(t, runtime.Success, err) - assert.Equal(t, res, 2) + assert.Equal(t, expectedNumDevices, res) err = runtime.LoadBackendFromEnvOrDefault() assert.Equal(t, runtime.Success, err) @@ -39,7 +46,7 @@ func TestIsDeviceAvailable(t *testing.T) { func TestRegisteredDevices(t *testing.T) { err := runtime.LoadBackendFromEnvOrDefault() assert.Equal(t, runtime.Success, err) - devices, err := runtime.GetRegisteredDevices() + devices, _ := runtime.GetRegisteredDevices() assert.Equal(t, []string{"CUDA", "CPU"}, devices) } diff --git a/wrappers/rust/icicle-core/src/curve.rs b/wrappers/rust/icicle-core/src/curve.rs index 8b0f2cbe7..a668b6589 100644 --- a/wrappers/rust/icicle-core/src/curve.rs +++ b/wrappers/rust/icicle-core/src/curve.rs @@ -69,6 +69,10 @@ impl Affine { } pub fn to_projective(&self) -> Projective { + if *self == (Affine::::zero()) { + return Projective::::zero(); + } + Projective { x: self.x, y: self.y,