From f797dae1fbaa30cf1a7133f89f37d333adb590f5 Mon Sep 17 00:00:00 2001 From: Emir Soyturk Date: Wed, 4 Dec 2024 13:27:19 +0300 Subject: [PATCH] Host arithmetic for Rust and Golang Co-authored-by: Yuval Shekel --- icicle/src/curves/ffi_extern.cpp | 32 ++++ icicle/src/fields/ffi_extern.cpp | 40 +++++ wrappers/golang/curves/bls12377/curve.go | 24 +++ wrappers/golang/curves/bls12377/g2/curve.go | 24 +++ .../golang/curves/bls12377/g2/include/curve.h | 2 + .../curves/bls12377/g2/include/scalar_field.h | 4 + .../golang/curves/bls12377/include/curve.h | 2 + .../curves/bls12377/include/scalar_field.h | 4 + .../golang/curves/bls12377/scalar_field.go | 58 ++++++++ .../curves/bls12377/tests/curve_test.go | 13 ++ .../curves/bls12377/tests/g2_curve_test.go | 13 ++ .../bls12377/tests/scalar_field_test.go | 29 ++++ wrappers/golang/curves/bls12381/curve.go | 24 +++ wrappers/golang/curves/bls12381/g2/curve.go | 24 +++ .../golang/curves/bls12381/g2/include/curve.h | 2 + .../curves/bls12381/g2/include/scalar_field.h | 4 + .../golang/curves/bls12381/include/curve.h | 2 + .../curves/bls12381/include/scalar_field.h | 4 + .../golang/curves/bls12381/scalar_field.go | 58 ++++++++ .../curves/bls12381/tests/curve_test.go | 13 ++ .../curves/bls12381/tests/g2_curve_test.go | 13 ++ .../bls12381/tests/scalar_field_test.go | 29 ++++ wrappers/golang/curves/bn254/curve.go | 24 +++ wrappers/golang/curves/bn254/g2/curve.go | 24 +++ .../golang/curves/bn254/g2/include/curve.h | 2 + .../curves/bn254/g2/include/scalar_field.h | 4 + wrappers/golang/curves/bn254/include/curve.h | 2 + .../curves/bn254/include/scalar_field.h | 4 + wrappers/golang/curves/bn254/scalar_field.go | 58 ++++++++ .../golang/curves/bn254/tests/curve_test.go | 13 ++ .../curves/bn254/tests/g2_curve_test.go | 13 ++ .../curves/bn254/tests/scalar_field_test.go | 29 ++++ wrappers/golang/curves/bw6761/curve.go | 24 +++ wrappers/golang/curves/bw6761/g2/curve.go | 24 +++ .../golang/curves/bw6761/g2/include/curve.h | 2 + .../curves/bw6761/g2/include/scalar_field.h | 4 + wrappers/golang/curves/bw6761/include/curve.h | 2 + .../curves/bw6761/include/scalar_field.h | 4 + wrappers/golang/curves/bw6761/scalar_field.go | 58 ++++++++ .../golang/curves/bw6761/tests/curve_test.go | 13 ++ .../curves/bw6761/tests/g2_curve_test.go | 13 ++ .../curves/bw6761/tests/scalar_field_test.go | 29 ++++ wrappers/golang/curves/grumpkin/curve.go | 24 +++ .../golang/curves/grumpkin/include/curve.h | 2 + .../curves/grumpkin/include/scalar_field.h | 4 + .../golang/curves/grumpkin/scalar_field.go | 58 ++++++++ .../curves/grumpkin/tests/curve_test.go | 13 ++ .../grumpkin/tests/scalar_field_test.go | 29 ++++ .../babybear/extension/extension_field.go | 58 ++++++++ .../babybear/extension/include/scalar_field.h | 4 + .../fields/babybear/include/scalar_field.h | 4 + .../golang/fields/babybear/scalar_field.go | 58 ++++++++ .../babybear/tests/extension_field_test.go | 29 ++++ .../babybear/tests/scalar_field_test.go | 29 ++++ .../generator/curves/templates/curve.go.tmpl | 24 +++ .../generator/curves/templates/curve.h.tmpl | 2 + .../curves/templates/curve_test.go.tmpl | 13 ++ .../generator/fields/templates/field.go.tmpl | 58 ++++++++ .../fields/templates/field_test.go.tmpl | 31 ++++ .../fields/templates/scalar_field.h.tmpl | 4 + wrappers/rust/icicle-core/src/curve.rs | 98 +++++++++++- wrappers/rust/icicle-core/src/field.rs | 140 +++++++++++++++++- wrappers/rust/icicle-core/src/tests.rs | 40 ++++- wrappers/rust/icicle-core/src/traits.rs | 6 + 64 files changed, 1488 insertions(+), 4 deletions(-) diff --git a/icicle/src/curves/ffi_extern.cpp b/icicle/src/curves/ffi_extern.cpp index db27a9126..99bc4ed10 100644 --- a/icicle/src/curves/ffi_extern.cpp +++ b/icicle/src/curves/ffi_extern.cpp @@ -15,6 +15,21 @@ extern "C" bool CONCAT_EXPAND(CURVE, eq)(projective_t* point1, projective_t* poi (point2->z == point_field_t::zero())); } +extern "C" void CONCAT_EXPAND(CURVE, ecsub)(projective_t* point1, projective_t* point2, projective_t* result) +{ + *result = *point1 - *point2; +} + +extern "C" void CONCAT_EXPAND(CURVE, ecadd)(projective_t* point1, projective_t* point2, projective_t* result) +{ + *result = *point1 + *point2; +} + +extern "C" void CONCAT_EXPAND(CURVE, mul_scalar)(projective_t* point, scalar_t* scalar, projective_t* result) +{ + *result = *point * *scalar; +} + extern "C" void CONCAT_EXPAND(CURVE, to_affine)(projective_t* point, affine_t* point_out) { *point_out = projective_t::to_affine(*point); @@ -46,6 +61,23 @@ extern "C" bool CONCAT_EXPAND(CURVE, g2_eq)(g2_projective_t* point1, g2_projecti (point2->z == g2_point_field_t::zero())); } +extern "C" void +CONCAT_EXPAND(CURVE, g2_ecsub)(g2_projective_t* point1, g2_projective_t* point2, g2_projective_t* result) +{ + *result = *point1 - *point2; +} + +extern "C" void +CONCAT_EXPAND(CURVE, g2_ecadd)(g2_projective_t* point1, g2_projective_t* point2, g2_projective_t* result) +{ + *result = *point1 + *point2; +} + +extern "C" void CONCAT_EXPAND(CURVE, g2_mul_scalar)(g2_projective_t* point, scalar_t* scalar, g2_projective_t* result) +{ + *result = *point * *scalar; +} + extern "C" void CONCAT_EXPAND(CURVE, g2_to_affine)(g2_projective_t* point, g2_affine_t* point_out) { *point_out = g2_projective_t::to_affine(*point); diff --git a/icicle/src/fields/ffi_extern.cpp b/icicle/src/fields/ffi_extern.cpp index 811c317a3..12a0ce146 100644 --- a/icicle/src/fields/ffi_extern.cpp +++ b/icicle/src/fields/ffi_extern.cpp @@ -8,9 +8,49 @@ extern "C" void CONCAT_EXPAND(FIELD, generate_scalars)(scalar_t* scalars, int si scalar_t::rand_host_many(scalars, size); } +extern "C" void CONCAT_EXPAND(FIELD, sub)(scalar_t* scalar1, scalar_t* scalar2, scalar_t* result) +{ + *result = *scalar1 - *scalar2; +} + +extern "C" void CONCAT_EXPAND(FIELD, add)(scalar_t* scalar1, scalar_t* scalar2, scalar_t* result) +{ + *result = *scalar1 + *scalar2; +} + +extern "C" void CONCAT_EXPAND(FIELD, mul)(scalar_t* scalar1, scalar_t* scalar2, scalar_t* result) +{ + *result = *scalar1 * *scalar2; +} + +extern "C" void CONCAT_EXPAND(FIELD, inv)(scalar_t* scalar1, scalar_t* result) +{ + *result = scalar_t::inverse(*scalar1); +} + #ifdef EXT_FIELD extern "C" void CONCAT_EXPAND(FIELD, extension_generate_scalars)(extension_t* scalars, int size) { extension_t::rand_host_many(scalars, size); } + +extern "C" void CONCAT_EXPAND(FIELD, extension_sub)(extension_t* scalar1, extension_t* scalar2, extension_t* result) +{ + *result = *scalar1 - *scalar2; +} + +extern "C" void CONCAT_EXPAND(FIELD, extension_add)(extension_t* scalar1, extension_t* scalar2, extension_t* result) +{ + *result = *scalar1 + *scalar2; +} + +extern "C" void CONCAT_EXPAND(FIELD, extension_mul)(extension_t* scalar1, extension_t* scalar2, extension_t* result) +{ + *result = *scalar1 * *scalar2; +} + +extern "C" void CONCAT_EXPAND(FIELD, extension_inv)(extension_t* scalar1, extension_t* result) +{ + *result = extension_t::inverse(*scalar1); +} #endif // EXT_FIELD diff --git a/wrappers/golang/curves/bls12377/curve.go b/wrappers/golang/curves/bls12377/curve.go index 8d013078d..9f417c7a4 100644 --- a/wrappers/golang/curves/bls12377/curve.go +++ b/wrappers/golang/curves/bls12377/curve.go @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool { return __ret == (C._Bool)(true) } +func (p Projective) Add(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bls12_377_ecadd(cP, cP2, cRes) + + return res +} + +func (p Projective) Sub(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bls12_377_ecsub(cP, cP2, cRes) + + return res +} + func (p *Projective) ToAffine() Affine { var a Affine diff --git a/wrappers/golang/curves/bls12377/g2/curve.go b/wrappers/golang/curves/bls12377/g2/curve.go index a6ae15635..a989a576f 100644 --- a/wrappers/golang/curves/bls12377/g2/curve.go +++ b/wrappers/golang/curves/bls12377/g2/curve.go @@ -54,6 +54,30 @@ func (p G2Projective) ProjectiveEq(p2 *G2Projective) bool { return __ret == (C._Bool)(true) } +func (p G2Projective) Add(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bls12_377_g2_ecadd(cP, cP2, cRes) + + return res +} + +func (p G2Projective) Sub(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bls12_377_g2_ecsub(cP, cP2, cRes) + + return res +} + func (p *G2Projective) ToAffine() G2Affine { var a G2Affine diff --git a/wrappers/golang/curves/bls12377/g2/include/curve.h b/wrappers/golang/curves/bls12377/g2/include/curve.h index 579955d92..4ef4effcb 100644 --- a/wrappers/golang/curves/bls12377/g2/include/curve.h +++ b/wrappers/golang/curves/bls12377/g2/include/curve.h @@ -12,6 +12,8 @@ typedef struct g2_affine_t g2_affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bls12_377_g2_eq(g2_projective_t* point1, g2_projective_t* point2); +void bls12_377_g2_ecadd(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); +void bls12_377_g2_ecsub(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); void bls12_377_g2_to_affine(g2_projective_t* point, g2_affine_t* point_out); void bls12_377_g2_from_affine(g2_affine_t* point, g2_projective_t* point_out); void bls12_377_g2_generate_projective_points(g2_projective_t* points, int size); diff --git a/wrappers/golang/curves/bls12377/g2/include/scalar_field.h b/wrappers/golang/curves/bls12377/g2/include/scalar_field.h index 97de9e414..0fe8729a0 100644 --- a/wrappers/golang/curves/bls12377/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bls12377/g2/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bls12_377_generate_scalars(scalar_t* scalars, int size); int bls12_377_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_377_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12377/include/curve.h b/wrappers/golang/curves/bls12377/include/curve.h index 9741831b3..d898d807d 100644 --- a/wrappers/golang/curves/bls12377/include/curve.h +++ b/wrappers/golang/curves/bls12377/include/curve.h @@ -12,6 +12,8 @@ typedef struct affine_t affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bls12_377_eq(projective_t* point1, projective_t* point2); +void bls12_377_ecadd(projective_t* point, projective_t* point2, projective_t* res); +void bls12_377_ecsub(projective_t* point, projective_t* point2, projective_t* res); void bls12_377_to_affine(projective_t* point, affine_t* point_out); void bls12_377_from_affine(affine_t* point, projective_t* point_out); void bls12_377_generate_projective_points(projective_t* points, int size); diff --git a/wrappers/golang/curves/bls12377/include/scalar_field.h b/wrappers/golang/curves/bls12377/include/scalar_field.h index 97de9e414..0fe8729a0 100644 --- a/wrappers/golang/curves/bls12377/include/scalar_field.h +++ b/wrappers/golang/curves/bls12377/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bls12_377_generate_scalars(scalar_t* scalars, int size); int bls12_377_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_377_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12377/scalar_field.go b/wrappers/golang/curves/bls12377/scalar_field.go index a21f480bf..c771f6ae3 100644 --- a/wrappers/golang/curves/bls12377/scalar_field.go +++ b/wrappers/golang/curves/bls12377/scalar_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] { return scalarSlice } +func (f ScalarField) Add(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_377_add(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Sub(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_377_sub(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Mul(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_377_mul(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Inv() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_377_inv(cF, cRes) + + return res +} + +func (f ScalarField) Sqr() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_377_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bls12377/tests/curve_test.go b/wrappers/golang/curves/bls12377/tests/curve_test.go index 8b921d2fa..09ed66b36 100644 --- a/wrappers/golang/curves/bls12377/tests/curve_test.go +++ b/wrappers/golang/curves/bls12377/tests/curve_test.go @@ -86,6 +86,18 @@ func testProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testProjectiveArithmetic(suite *suite.Suite) { + points := bls12_377.GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *CurveTestSuite) TestCurve() { s.Run("TestProjectiveZero", testWrapper(&s.Suite, testProjectiveZero)) s.Run("TestProjectiveFromLimbs", testWrapper(&s.Suite, testProjectiveFromLimbs)) s.Run("TestProjectiveFromAffine", testWrapper(&s.Suite, testProjectiveFromAffine)) + s.Run("TestProjectiveArithmetic", testWrapper(&s.Suite, testProjectiveArithmetic)) } func TestSuiteCurve(t *testing.T) { diff --git a/wrappers/golang/curves/bls12377/tests/g2_curve_test.go b/wrappers/golang/curves/bls12377/tests/g2_curve_test.go index 461e72292..93d21f84c 100644 --- a/wrappers/golang/curves/bls12377/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bls12377/tests/g2_curve_test.go @@ -86,6 +86,18 @@ func testG2ProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testG2ProjectiveArithmetic(suite *suite.Suite) { + points := g2.G2GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testG2ProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *G2CurveTestSuite) TestG2Curve() { s.Run("TestG2ProjectiveZero", testWrapper(&s.Suite, testG2ProjectiveZero)) s.Run("TestG2ProjectiveFromLimbs", testWrapper(&s.Suite, testG2ProjectiveFromLimbs)) s.Run("TestG2ProjectiveFromAffine", testWrapper(&s.Suite, testG2ProjectiveFromAffine)) + s.Run("TestG2ProjectiveArithmetic", testWrapper(&s.Suite, testG2ProjectiveArithmetic)) } func TestSuiteG2Curve(t *testing.T) { diff --git a/wrappers/golang/curves/bls12377/tests/scalar_field_test.go b/wrappers/golang/curves/bls12377/tests/scalar_field_test.go index 0912c833d..e7928eb19 100644 --- a/wrappers/golang/curves/bls12377/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bls12377/tests/scalar_field_test.go @@ -100,6 +100,34 @@ func testBls12_377GenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testScalarFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := bls12_377.GenerateScalars(size) + scalarsB := bls12_377.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := bls12_377.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testBls12_377MongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := bls12_377.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() { s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer)) s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes)) s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes)) + s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic)) s.Run("TestBls12_377GenerateScalars", testWrapper(&s.Suite, testBls12_377GenerateScalars)) s.Run("TestBls12_377MongtomeryConversion", testWrapper(&s.Suite, testBls12_377MongtomeryConversion)) } diff --git a/wrappers/golang/curves/bls12381/curve.go b/wrappers/golang/curves/bls12381/curve.go index a92a9dd72..23e549490 100644 --- a/wrappers/golang/curves/bls12381/curve.go +++ b/wrappers/golang/curves/bls12381/curve.go @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool { return __ret == (C._Bool)(true) } +func (p Projective) Add(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bls12_381_ecadd(cP, cP2, cRes) + + return res +} + +func (p Projective) Sub(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bls12_381_ecsub(cP, cP2, cRes) + + return res +} + func (p *Projective) ToAffine() Affine { var a Affine diff --git a/wrappers/golang/curves/bls12381/g2/curve.go b/wrappers/golang/curves/bls12381/g2/curve.go index 7edad758f..35ce63a6f 100644 --- a/wrappers/golang/curves/bls12381/g2/curve.go +++ b/wrappers/golang/curves/bls12381/g2/curve.go @@ -54,6 +54,30 @@ func (p G2Projective) ProjectiveEq(p2 *G2Projective) bool { return __ret == (C._Bool)(true) } +func (p G2Projective) Add(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bls12_381_g2_ecadd(cP, cP2, cRes) + + return res +} + +func (p G2Projective) Sub(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bls12_381_g2_ecsub(cP, cP2, cRes) + + return res +} + func (p *G2Projective) ToAffine() G2Affine { var a G2Affine diff --git a/wrappers/golang/curves/bls12381/g2/include/curve.h b/wrappers/golang/curves/bls12381/g2/include/curve.h index fb570e6a1..0ac871612 100644 --- a/wrappers/golang/curves/bls12381/g2/include/curve.h +++ b/wrappers/golang/curves/bls12381/g2/include/curve.h @@ -12,6 +12,8 @@ typedef struct g2_affine_t g2_affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bls12_381_g2_eq(g2_projective_t* point1, g2_projective_t* point2); +void bls12_381_g2_ecadd(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); +void bls12_381_g2_ecsub(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); void bls12_381_g2_to_affine(g2_projective_t* point, g2_affine_t* point_out); void bls12_381_g2_from_affine(g2_affine_t* point, g2_projective_t* point_out); void bls12_381_g2_generate_projective_points(g2_projective_t* points, int size); diff --git a/wrappers/golang/curves/bls12381/g2/include/scalar_field.h b/wrappers/golang/curves/bls12381/g2/include/scalar_field.h index f0807d9e9..54c08d6f4 100644 --- a/wrappers/golang/curves/bls12381/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bls12381/g2/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bls12_381_generate_scalars(scalar_t* scalars, int size); int bls12_381_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bls12_381_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_381_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_381_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_381_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12381/include/curve.h b/wrappers/golang/curves/bls12381/include/curve.h index 9e4b6f49d..fb3fbe2fd 100644 --- a/wrappers/golang/curves/bls12381/include/curve.h +++ b/wrappers/golang/curves/bls12381/include/curve.h @@ -12,6 +12,8 @@ typedef struct affine_t affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bls12_381_eq(projective_t* point1, projective_t* point2); +void bls12_381_ecadd(projective_t* point, projective_t* point2, projective_t* res); +void bls12_381_ecsub(projective_t* point, projective_t* point2, projective_t* res); void bls12_381_to_affine(projective_t* point, affine_t* point_out); void bls12_381_from_affine(affine_t* point, projective_t* point_out); void bls12_381_generate_projective_points(projective_t* points, int size); diff --git a/wrappers/golang/curves/bls12381/include/scalar_field.h b/wrappers/golang/curves/bls12381/include/scalar_field.h index f0807d9e9..54c08d6f4 100644 --- a/wrappers/golang/curves/bls12381/include/scalar_field.h +++ b/wrappers/golang/curves/bls12381/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bls12_381_generate_scalars(scalar_t* scalars, int size); int bls12_381_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bls12_381_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_381_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_381_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bls12_381_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bls12381/scalar_field.go b/wrappers/golang/curves/bls12381/scalar_field.go index 89092939a..ed6e7002e 100644 --- a/wrappers/golang/curves/bls12381/scalar_field.go +++ b/wrappers/golang/curves/bls12381/scalar_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] { return scalarSlice } +func (f ScalarField) Add(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_381_add(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Sub(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_381_sub(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Mul(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_381_mul(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Inv() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_381_inv(cF, cRes) + + return res +} + +func (f ScalarField) Sqr() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bls12_381_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bls12381/tests/curve_test.go b/wrappers/golang/curves/bls12381/tests/curve_test.go index f1e741be1..89286343e 100644 --- a/wrappers/golang/curves/bls12381/tests/curve_test.go +++ b/wrappers/golang/curves/bls12381/tests/curve_test.go @@ -86,6 +86,18 @@ func testProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testProjectiveArithmetic(suite *suite.Suite) { + points := bls12_381.GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *CurveTestSuite) TestCurve() { s.Run("TestProjectiveZero", testWrapper(&s.Suite, testProjectiveZero)) s.Run("TestProjectiveFromLimbs", testWrapper(&s.Suite, testProjectiveFromLimbs)) s.Run("TestProjectiveFromAffine", testWrapper(&s.Suite, testProjectiveFromAffine)) + s.Run("TestProjectiveArithmetic", testWrapper(&s.Suite, testProjectiveArithmetic)) } func TestSuiteCurve(t *testing.T) { diff --git a/wrappers/golang/curves/bls12381/tests/g2_curve_test.go b/wrappers/golang/curves/bls12381/tests/g2_curve_test.go index be51d74ea..cec4f0bf8 100644 --- a/wrappers/golang/curves/bls12381/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bls12381/tests/g2_curve_test.go @@ -86,6 +86,18 @@ func testG2ProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testG2ProjectiveArithmetic(suite *suite.Suite) { + points := g2.G2GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testG2ProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *G2CurveTestSuite) TestG2Curve() { s.Run("TestG2ProjectiveZero", testWrapper(&s.Suite, testG2ProjectiveZero)) s.Run("TestG2ProjectiveFromLimbs", testWrapper(&s.Suite, testG2ProjectiveFromLimbs)) s.Run("TestG2ProjectiveFromAffine", testWrapper(&s.Suite, testG2ProjectiveFromAffine)) + s.Run("TestG2ProjectiveArithmetic", testWrapper(&s.Suite, testG2ProjectiveArithmetic)) } func TestSuiteG2Curve(t *testing.T) { diff --git a/wrappers/golang/curves/bls12381/tests/scalar_field_test.go b/wrappers/golang/curves/bls12381/tests/scalar_field_test.go index fbc353803..c4b68d9cc 100644 --- a/wrappers/golang/curves/bls12381/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bls12381/tests/scalar_field_test.go @@ -100,6 +100,34 @@ func testBls12_381GenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testScalarFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := bls12_381.GenerateScalars(size) + scalarsB := bls12_381.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := bls12_381.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testBls12_381MongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := bls12_381.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() { s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer)) s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes)) s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes)) + s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic)) s.Run("TestBls12_381GenerateScalars", testWrapper(&s.Suite, testBls12_381GenerateScalars)) s.Run("TestBls12_381MongtomeryConversion", testWrapper(&s.Suite, testBls12_381MongtomeryConversion)) } diff --git a/wrappers/golang/curves/bn254/curve.go b/wrappers/golang/curves/bn254/curve.go index bc83f9061..b08b0f33b 100644 --- a/wrappers/golang/curves/bn254/curve.go +++ b/wrappers/golang/curves/bn254/curve.go @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool { return __ret == (C._Bool)(true) } +func (p Projective) Add(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bn254_ecadd(cP, cP2, cRes) + + return res +} + +func (p Projective) Sub(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bn254_ecsub(cP, cP2, cRes) + + return res +} + func (p *Projective) ToAffine() Affine { var a Affine diff --git a/wrappers/golang/curves/bn254/g2/curve.go b/wrappers/golang/curves/bn254/g2/curve.go index cd3e885b2..29193dd2c 100644 --- a/wrappers/golang/curves/bn254/g2/curve.go +++ b/wrappers/golang/curves/bn254/g2/curve.go @@ -54,6 +54,30 @@ func (p G2Projective) ProjectiveEq(p2 *G2Projective) bool { return __ret == (C._Bool)(true) } +func (p G2Projective) Add(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bn254_g2_ecadd(cP, cP2, cRes) + + return res +} + +func (p G2Projective) Sub(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bn254_g2_ecsub(cP, cP2, cRes) + + return res +} + func (p *G2Projective) ToAffine() G2Affine { var a G2Affine diff --git a/wrappers/golang/curves/bn254/g2/include/curve.h b/wrappers/golang/curves/bn254/g2/include/curve.h index 2ab4fedf0..a35d2e5ba 100644 --- a/wrappers/golang/curves/bn254/g2/include/curve.h +++ b/wrappers/golang/curves/bn254/g2/include/curve.h @@ -12,6 +12,8 @@ typedef struct g2_affine_t g2_affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bn254_g2_eq(g2_projective_t* point1, g2_projective_t* point2); +void bn254_g2_ecadd(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); +void bn254_g2_ecsub(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); void bn254_g2_to_affine(g2_projective_t* point, g2_affine_t* point_out); void bn254_g2_from_affine(g2_affine_t* point, g2_projective_t* point_out); void bn254_g2_generate_projective_points(g2_projective_t* points, int size); diff --git a/wrappers/golang/curves/bn254/g2/include/scalar_field.h b/wrappers/golang/curves/bn254/g2/include/scalar_field.h index 9101faa80..96cda48f9 100644 --- a/wrappers/golang/curves/bn254/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bn254/g2/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bn254_generate_scalars(scalar_t* scalars, int size); int bn254_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bn254_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bn254_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bn254_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bn254_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bn254/include/curve.h b/wrappers/golang/curves/bn254/include/curve.h index 43c4bc902..f9ac81700 100644 --- a/wrappers/golang/curves/bn254/include/curve.h +++ b/wrappers/golang/curves/bn254/include/curve.h @@ -12,6 +12,8 @@ typedef struct affine_t affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bn254_eq(projective_t* point1, projective_t* point2); +void bn254_ecadd(projective_t* point, projective_t* point2, projective_t* res); +void bn254_ecsub(projective_t* point, projective_t* point2, projective_t* res); void bn254_to_affine(projective_t* point, affine_t* point_out); void bn254_from_affine(affine_t* point, projective_t* point_out); void bn254_generate_projective_points(projective_t* points, int size); diff --git a/wrappers/golang/curves/bn254/include/scalar_field.h b/wrappers/golang/curves/bn254/include/scalar_field.h index 9101faa80..96cda48f9 100644 --- a/wrappers/golang/curves/bn254/include/scalar_field.h +++ b/wrappers/golang/curves/bn254/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bn254_generate_scalars(scalar_t* scalars, int size); int bn254_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bn254_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bn254_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bn254_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bn254_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bn254/scalar_field.go b/wrappers/golang/curves/bn254/scalar_field.go index 8372381fe..0c07ad3ed 100644 --- a/wrappers/golang/curves/bn254/scalar_field.go +++ b/wrappers/golang/curves/bn254/scalar_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] { return scalarSlice } +func (f ScalarField) Add(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bn254_add(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Sub(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bn254_sub(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Mul(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bn254_mul(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Inv() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bn254_inv(cF, cRes) + + return res +} + +func (f ScalarField) Sqr() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bn254_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bn254/tests/curve_test.go b/wrappers/golang/curves/bn254/tests/curve_test.go index 380b48e70..560946ef1 100644 --- a/wrappers/golang/curves/bn254/tests/curve_test.go +++ b/wrappers/golang/curves/bn254/tests/curve_test.go @@ -86,6 +86,18 @@ func testProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testProjectiveArithmetic(suite *suite.Suite) { + points := bn254.GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *CurveTestSuite) TestCurve() { s.Run("TestProjectiveZero", testWrapper(&s.Suite, testProjectiveZero)) s.Run("TestProjectiveFromLimbs", testWrapper(&s.Suite, testProjectiveFromLimbs)) s.Run("TestProjectiveFromAffine", testWrapper(&s.Suite, testProjectiveFromAffine)) + s.Run("TestProjectiveArithmetic", testWrapper(&s.Suite, testProjectiveArithmetic)) } func TestSuiteCurve(t *testing.T) { diff --git a/wrappers/golang/curves/bn254/tests/g2_curve_test.go b/wrappers/golang/curves/bn254/tests/g2_curve_test.go index 1cc2d50e1..ebd829e47 100644 --- a/wrappers/golang/curves/bn254/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bn254/tests/g2_curve_test.go @@ -86,6 +86,18 @@ func testG2ProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testG2ProjectiveArithmetic(suite *suite.Suite) { + points := g2.G2GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testG2ProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *G2CurveTestSuite) TestG2Curve() { s.Run("TestG2ProjectiveZero", testWrapper(&s.Suite, testG2ProjectiveZero)) s.Run("TestG2ProjectiveFromLimbs", testWrapper(&s.Suite, testG2ProjectiveFromLimbs)) s.Run("TestG2ProjectiveFromAffine", testWrapper(&s.Suite, testG2ProjectiveFromAffine)) + s.Run("TestG2ProjectiveArithmetic", testWrapper(&s.Suite, testG2ProjectiveArithmetic)) } func TestSuiteG2Curve(t *testing.T) { diff --git a/wrappers/golang/curves/bn254/tests/scalar_field_test.go b/wrappers/golang/curves/bn254/tests/scalar_field_test.go index 9156c9ad7..01f5c6609 100644 --- a/wrappers/golang/curves/bn254/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bn254/tests/scalar_field_test.go @@ -100,6 +100,34 @@ func testBn254GenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testScalarFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := bn254.GenerateScalars(size) + scalarsB := bn254.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := bn254.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testBn254MongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := bn254.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() { s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer)) s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes)) s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes)) + s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic)) s.Run("TestBn254GenerateScalars", testWrapper(&s.Suite, testBn254GenerateScalars)) s.Run("TestBn254MongtomeryConversion", testWrapper(&s.Suite, testBn254MongtomeryConversion)) } diff --git a/wrappers/golang/curves/bw6761/curve.go b/wrappers/golang/curves/bw6761/curve.go index a6778163a..165a0ee32 100644 --- a/wrappers/golang/curves/bw6761/curve.go +++ b/wrappers/golang/curves/bw6761/curve.go @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool { return __ret == (C._Bool)(true) } +func (p Projective) Add(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bw6_761_ecadd(cP, cP2, cRes) + + return res +} + +func (p Projective) Sub(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.bw6_761_ecsub(cP, cP2, cRes) + + return res +} + func (p *Projective) ToAffine() Affine { var a Affine diff --git a/wrappers/golang/curves/bw6761/g2/curve.go b/wrappers/golang/curves/bw6761/g2/curve.go index 9f880229e..66b21e447 100644 --- a/wrappers/golang/curves/bw6761/g2/curve.go +++ b/wrappers/golang/curves/bw6761/g2/curve.go @@ -54,6 +54,30 @@ func (p G2Projective) ProjectiveEq(p2 *G2Projective) bool { return __ret == (C._Bool)(true) } +func (p G2Projective) Add(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bw6_761_g2_ecadd(cP, cP2, cRes) + + return res +} + +func (p G2Projective) Sub(p2 *G2Projective) G2Projective { + var res G2Projective + + cP := (*C.g2_projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2)) + cRes := (*C.g2_projective_t)(unsafe.Pointer(&res)) + + C.bw6_761_g2_ecsub(cP, cP2, cRes) + + return res +} + func (p *G2Projective) ToAffine() G2Affine { var a G2Affine diff --git a/wrappers/golang/curves/bw6761/g2/include/curve.h b/wrappers/golang/curves/bw6761/g2/include/curve.h index b178fd678..fd579a6a0 100644 --- a/wrappers/golang/curves/bw6761/g2/include/curve.h +++ b/wrappers/golang/curves/bw6761/g2/include/curve.h @@ -12,6 +12,8 @@ typedef struct g2_affine_t g2_affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bw6_761_g2_eq(g2_projective_t* point1, g2_projective_t* point2); +void bw6_761_g2_ecadd(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); +void bw6_761_g2_ecsub(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res); void bw6_761_g2_to_affine(g2_projective_t* point, g2_affine_t* point_out); void bw6_761_g2_from_affine(g2_affine_t* point, g2_projective_t* point_out); void bw6_761_g2_generate_projective_points(g2_projective_t* points, int size); diff --git a/wrappers/golang/curves/bw6761/g2/include/scalar_field.h b/wrappers/golang/curves/bw6761/g2/include/scalar_field.h index 29ff19f4e..a40217357 100644 --- a/wrappers/golang/curves/bw6761/g2/include/scalar_field.h +++ b/wrappers/golang/curves/bw6761/g2/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bw6_761_generate_scalars(scalar_t* scalars, int size); int bw6_761_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bw6_761_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bw6_761_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bw6_761_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bw6_761_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bw6761/include/curve.h b/wrappers/golang/curves/bw6761/include/curve.h index 5364e783b..58487819e 100644 --- a/wrappers/golang/curves/bw6761/include/curve.h +++ b/wrappers/golang/curves/bw6761/include/curve.h @@ -12,6 +12,8 @@ typedef struct affine_t affine_t; typedef struct VecOpsConfig VecOpsConfig; bool bw6_761_eq(projective_t* point1, projective_t* point2); +void bw6_761_ecadd(projective_t* point, projective_t* point2, projective_t* res); +void bw6_761_ecsub(projective_t* point, projective_t* point2, projective_t* res); void bw6_761_to_affine(projective_t* point, affine_t* point_out); void bw6_761_from_affine(affine_t* point, projective_t* point_out); void bw6_761_generate_projective_points(projective_t* points, int size); diff --git a/wrappers/golang/curves/bw6761/include/scalar_field.h b/wrappers/golang/curves/bw6761/include/scalar_field.h index 29ff19f4e..a40217357 100644 --- a/wrappers/golang/curves/bw6761/include/scalar_field.h +++ b/wrappers/golang/curves/bw6761/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void bw6_761_generate_scalars(scalar_t* scalars, int size); int bw6_761_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void bw6_761_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bw6_761_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bw6_761_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void bw6_761_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/bw6761/scalar_field.go b/wrappers/golang/curves/bw6761/scalar_field.go index 5cb53afb4..8ab39796e 100644 --- a/wrappers/golang/curves/bw6761/scalar_field.go +++ b/wrappers/golang/curves/bw6761/scalar_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] { return scalarSlice } +func (f ScalarField) Add(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bw6_761_add(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Sub(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bw6_761_sub(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Mul(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bw6_761_mul(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Inv() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bw6_761_inv(cF, cRes) + + return res +} + +func (f ScalarField) Sqr() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.bw6_761_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/bw6761/tests/curve_test.go b/wrappers/golang/curves/bw6761/tests/curve_test.go index c7d9babd3..2a759ab26 100644 --- a/wrappers/golang/curves/bw6761/tests/curve_test.go +++ b/wrappers/golang/curves/bw6761/tests/curve_test.go @@ -86,6 +86,18 @@ func testProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testProjectiveArithmetic(suite *suite.Suite) { + points := bw6_761.GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *CurveTestSuite) TestCurve() { s.Run("TestProjectiveZero", testWrapper(&s.Suite, testProjectiveZero)) s.Run("TestProjectiveFromLimbs", testWrapper(&s.Suite, testProjectiveFromLimbs)) s.Run("TestProjectiveFromAffine", testWrapper(&s.Suite, testProjectiveFromAffine)) + s.Run("TestProjectiveArithmetic", testWrapper(&s.Suite, testProjectiveArithmetic)) } func TestSuiteCurve(t *testing.T) { diff --git a/wrappers/golang/curves/bw6761/tests/g2_curve_test.go b/wrappers/golang/curves/bw6761/tests/g2_curve_test.go index 267aaeac9..8125f9b8b 100644 --- a/wrappers/golang/curves/bw6761/tests/g2_curve_test.go +++ b/wrappers/golang/curves/bw6761/tests/g2_curve_test.go @@ -86,6 +86,18 @@ func testG2ProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testG2ProjectiveArithmetic(suite *suite.Suite) { + points := g2.G2GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testG2ProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *G2CurveTestSuite) TestG2Curve() { s.Run("TestG2ProjectiveZero", testWrapper(&s.Suite, testG2ProjectiveZero)) s.Run("TestG2ProjectiveFromLimbs", testWrapper(&s.Suite, testG2ProjectiveFromLimbs)) s.Run("TestG2ProjectiveFromAffine", testWrapper(&s.Suite, testG2ProjectiveFromAffine)) + s.Run("TestG2ProjectiveArithmetic", testWrapper(&s.Suite, testG2ProjectiveArithmetic)) } func TestSuiteG2Curve(t *testing.T) { diff --git a/wrappers/golang/curves/bw6761/tests/scalar_field_test.go b/wrappers/golang/curves/bw6761/tests/scalar_field_test.go index aeaec4e9b..93bd0aa3d 100644 --- a/wrappers/golang/curves/bw6761/tests/scalar_field_test.go +++ b/wrappers/golang/curves/bw6761/tests/scalar_field_test.go @@ -100,6 +100,34 @@ func testBw6_761GenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testScalarFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := bw6_761.GenerateScalars(size) + scalarsB := bw6_761.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := bw6_761.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testBw6_761MongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := bw6_761.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() { s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer)) s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes)) s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes)) + s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic)) s.Run("TestBw6_761GenerateScalars", testWrapper(&s.Suite, testBw6_761GenerateScalars)) s.Run("TestBw6_761MongtomeryConversion", testWrapper(&s.Suite, testBw6_761MongtomeryConversion)) } diff --git a/wrappers/golang/curves/grumpkin/curve.go b/wrappers/golang/curves/grumpkin/curve.go index 915e83227..fb8af0abc 100644 --- a/wrappers/golang/curves/grumpkin/curve.go +++ b/wrappers/golang/curves/grumpkin/curve.go @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool { return __ret == (C._Bool)(true) } +func (p Projective) Add(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.grumpkin_ecadd(cP, cP2, cRes) + + return res +} + +func (p Projective) Sub(p2 *Projective) Projective { + var res Projective + + cP := (*C.projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.projective_t)(unsafe.Pointer(p2)) + cRes := (*C.projective_t)(unsafe.Pointer(&res)) + + C.grumpkin_ecsub(cP, cP2, cRes) + + return res +} + func (p *Projective) ToAffine() Affine { var a Affine diff --git a/wrappers/golang/curves/grumpkin/include/curve.h b/wrappers/golang/curves/grumpkin/include/curve.h index e164919d1..073408cc8 100644 --- a/wrappers/golang/curves/grumpkin/include/curve.h +++ b/wrappers/golang/curves/grumpkin/include/curve.h @@ -12,6 +12,8 @@ typedef struct affine_t affine_t; typedef struct VecOpsConfig VecOpsConfig; bool grumpkin_eq(projective_t* point1, projective_t* point2); +void grumpkin_ecadd(projective_t* point, projective_t* point2, projective_t* res); +void grumpkin_ecsub(projective_t* point, projective_t* point2, projective_t* res); void grumpkin_to_affine(projective_t* point, affine_t* point_out); void grumpkin_from_affine(affine_t* point, projective_t* point_out); void grumpkin_generate_projective_points(projective_t* points, int size); diff --git a/wrappers/golang/curves/grumpkin/include/scalar_field.h b/wrappers/golang/curves/grumpkin/include/scalar_field.h index f9708c3db..c53e61bbd 100644 --- a/wrappers/golang/curves/grumpkin/include/scalar_field.h +++ b/wrappers/golang/curves/grumpkin/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void grumpkin_generate_scalars(scalar_t* scalars, int size); int grumpkin_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void grumpkin_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void grumpkin_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void grumpkin_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void grumpkin_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/curves/grumpkin/scalar_field.go b/wrappers/golang/curves/grumpkin/scalar_field.go index 8ef45b290..379975aa5 100644 --- a/wrappers/golang/curves/grumpkin/scalar_field.go +++ b/wrappers/golang/curves/grumpkin/scalar_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] { return scalarSlice } +func (f ScalarField) Add(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.grumpkin_add(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Sub(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.grumpkin_sub(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Mul(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.grumpkin_mul(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Inv() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.grumpkin_inv(cF, cRes) + + return res +} + +func (f ScalarField) Sqr() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.grumpkin_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/curves/grumpkin/tests/curve_test.go b/wrappers/golang/curves/grumpkin/tests/curve_test.go index 01743124f..1a40283aa 100644 --- a/wrappers/golang/curves/grumpkin/tests/curve_test.go +++ b/wrappers/golang/curves/grumpkin/tests/curve_test.go @@ -86,6 +86,18 @@ func testProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func testProjectiveArithmetic(suite *suite.Suite) { + points := grumpkin.GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func testProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *CurveTestSuite) TestCurve() { s.Run("TestProjectiveZero", testWrapper(&s.Suite, testProjectiveZero)) s.Run("TestProjectiveFromLimbs", testWrapper(&s.Suite, testProjectiveFromLimbs)) s.Run("TestProjectiveFromAffine", testWrapper(&s.Suite, testProjectiveFromAffine)) + s.Run("TestProjectiveArithmetic", testWrapper(&s.Suite, testProjectiveArithmetic)) } func TestSuiteCurve(t *testing.T) { diff --git a/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go b/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go index 9d6bad331..61e898cd6 100644 --- a/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go +++ b/wrappers/golang/curves/grumpkin/tests/scalar_field_test.go @@ -100,6 +100,34 @@ func testGrumpkinGenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testScalarFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := grumpkin.GenerateScalars(size) + scalarsB := grumpkin.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := grumpkin.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testGrumpkinMongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := grumpkin.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() { s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer)) s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes)) s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes)) + s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic)) s.Run("TestGrumpkinGenerateScalars", testWrapper(&s.Suite, testGrumpkinGenerateScalars)) s.Run("TestGrumpkinMongtomeryConversion", testWrapper(&s.Suite, testGrumpkinMongtomeryConversion)) } diff --git a/wrappers/golang/fields/babybear/extension/extension_field.go b/wrappers/golang/fields/babybear/extension/extension_field.go index 3d0a2ae33..e0d064e7f 100644 --- a/wrappers/golang/fields/babybear/extension/extension_field.go +++ b/wrappers/golang/fields/babybear/extension/extension_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ExtensionField] { return scalarSlice } +func (f ExtensionField) Add(f2 *ExtensionField) ExtensionField { + var res ExtensionField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_extension_add(cF, cF2, cRes) + + return res +} + +func (f ExtensionField) Sub(f2 *ExtensionField) ExtensionField { + var res ExtensionField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_extension_sub(cF, cF2, cRes) + + return res +} + +func (f ExtensionField) Mul(f2 *ExtensionField) ExtensionField { + var res ExtensionField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_extension_mul(cF, cF2, cRes) + + return res +} + +func (f ExtensionField) Inv() ExtensionField { + var res ExtensionField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_extension_inv(cF, cRes) + + return res +} + +func (f ExtensionField) Sqr() ExtensionField { + var res ExtensionField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_extension_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/fields/babybear/extension/include/scalar_field.h b/wrappers/golang/fields/babybear/extension/include/scalar_field.h index 447ac2f1f..79e4a4603 100644 --- a/wrappers/golang/fields/babybear/extension/include/scalar_field.h +++ b/wrappers/golang/fields/babybear/extension/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void babybear_extension_generate_scalars(scalar_t* scalars, int size); int babybear_extension_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void babybear_extension_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void babybear_extension_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void babybear_extension_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void babybear_extension_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/fields/babybear/include/scalar_field.h b/wrappers/golang/fields/babybear/include/scalar_field.h index 36d1aff10..9f8f0e5a7 100644 --- a/wrappers/golang/fields/babybear/include/scalar_field.h +++ b/wrappers/golang/fields/babybear/include/scalar_field.h @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void babybear_generate_scalars(scalar_t* scalars, int size); int babybear_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void babybear_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void babybear_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void babybear_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void babybear_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/golang/fields/babybear/scalar_field.go b/wrappers/golang/fields/babybear/scalar_field.go index 78d798aae..ef6185b94 100644 --- a/wrappers/golang/fields/babybear/scalar_field.go +++ b/wrappers/golang/fields/babybear/scalar_field.go @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] { return scalarSlice } +func (f ScalarField) Add(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_add(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Sub(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_sub(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Mul(f2 *ScalarField) ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_mul(cF, cF2, cRes) + + return res +} + +func (f ScalarField) Inv() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_inv(cF, cRes) + + return res +} + +func (f ScalarField) Sqr() ScalarField { + var res ScalarField + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.babybear_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/fields/babybear/tests/extension_field_test.go b/wrappers/golang/fields/babybear/tests/extension_field_test.go index 4ba29f61f..e468fc309 100644 --- a/wrappers/golang/fields/babybear/tests/extension_field_test.go +++ b/wrappers/golang/fields/babybear/tests/extension_field_test.go @@ -100,6 +100,34 @@ func testBabybear_extensionGenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testExtensionFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := babybear_extension.GenerateScalars(size) + scalarsB := babybear_extension.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := babybear_extension.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testBabybear_extensionMongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := babybear_extension.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ExtensionFieldTestSuite) TestExtensionField() { s.Run("TestExtensionFieldAsPointer", testWrapper(&s.Suite, testExtensionFieldAsPointer)) s.Run("TestExtensionFieldFromBytes", testWrapper(&s.Suite, testExtensionFieldFromBytes)) s.Run("TestExtensionFieldToBytes", testWrapper(&s.Suite, testExtensionFieldToBytes)) + s.Run("TestExtensionFieldArithmetic", testWrapper(&s.Suite, testExtensionFieldArithmetic)) s.Run("TestBabybear_extensionGenerateScalars", testWrapper(&s.Suite, testBabybear_extensionGenerateScalars)) s.Run("TestBabybear_extensionMongtomeryConversion", testWrapper(&s.Suite, testBabybear_extensionMongtomeryConversion)) } diff --git a/wrappers/golang/fields/babybear/tests/scalar_field_test.go b/wrappers/golang/fields/babybear/tests/scalar_field_test.go index fddfa25d9..abf8f0c3a 100644 --- a/wrappers/golang/fields/babybear/tests/scalar_field_test.go +++ b/wrappers/golang/fields/babybear/tests/scalar_field_test.go @@ -100,6 +100,34 @@ func testBabybearGenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func testScalarFieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := babybear.GenerateScalars(size) + scalarsB := babybear.GenerateScalars(size) + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := babybear.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + func testBabybearMongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := babybear.GenerateScalars(size) @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() { s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer)) s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes)) s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes)) + s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic)) s.Run("TestBabybearGenerateScalars", testWrapper(&s.Suite, testBabybearGenerateScalars)) s.Run("TestBabybearMongtomeryConversion", testWrapper(&s.Suite, testBabybearMongtomeryConversion)) } diff --git a/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl b/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl index 51e7dff68..ce6845239 100644 --- a/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl +++ b/wrappers/golang/internal/generator/curves/templates/curve.go.tmpl @@ -55,6 +55,30 @@ func (p {{.CurvePrefix}}Projective) ProjectiveEq(p2 *{{.CurvePrefix}}Projective) return __ret == (C._Bool)(true) } +func (p {{.CurvePrefix}}Projective) Add(p2 *{{.CurvePrefix}}Projective) {{.CurvePrefix}}Projective { + var res {{.CurvePrefix}}Projective + + cP := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(p2)) + cRes := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(&res)) + + C.{{.Curve}}{{toCNameBackwards .CurvePrefix}}_ecadd(cP, cP2, cRes) + + return res +} + +func (p {{.CurvePrefix}}Projective) Sub(p2 *{{.CurvePrefix}}Projective) {{.CurvePrefix}}Projective { + var res {{.CurvePrefix}}Projective + + cP := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(&p)) + cP2 := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(p2)) + cRes := (*C.{{toCName .CurvePrefix}}projective_t)(unsafe.Pointer(&res)) + + C.{{.Curve}}{{toCNameBackwards .CurvePrefix}}_ecsub(cP, cP2, cRes) + + return res +} + func (p *{{.CurvePrefix}}Projective) ToAffine() {{.CurvePrefix}}Affine { var a {{.CurvePrefix}}Affine diff --git a/wrappers/golang/internal/generator/curves/templates/curve.h.tmpl b/wrappers/golang/internal/generator/curves/templates/curve.h.tmpl index bf690e57b..3f9865b73 100644 --- a/wrappers/golang/internal/generator/curves/templates/curve.h.tmpl +++ b/wrappers/golang/internal/generator/curves/templates/curve.h.tmpl @@ -12,6 +12,8 @@ typedef struct {{toCName .CurvePrefix}}affine_t {{toCName .CurvePrefix}}affine_t typedef struct VecOpsConfig VecOpsConfig; bool {{.Curve}}{{toCNameBackwards .CurvePrefix}}_eq({{toCName .CurvePrefix}}projective_t* point1, {{toCName .CurvePrefix}}projective_t* point2); +void {{.Curve}}{{toCNameBackwards .CurvePrefix}}_ecadd({{toCName .CurvePrefix}}projective_t* point, {{toCName .CurvePrefix}}projective_t* point2, {{toCName .CurvePrefix}}projective_t* res); +void {{.Curve}}{{toCNameBackwards .CurvePrefix}}_ecsub({{toCName .CurvePrefix}}projective_t* point, {{toCName .CurvePrefix}}projective_t* point2, {{toCName .CurvePrefix}}projective_t* res); void {{.Curve}}{{toCNameBackwards .CurvePrefix}}_to_affine({{toCName .CurvePrefix}}projective_t* point, {{toCName .CurvePrefix}}affine_t* point_out); void {{.Curve}}{{toCNameBackwards .CurvePrefix}}_from_affine({{toCName .CurvePrefix}}affine_t* point, {{toCName .CurvePrefix}}projective_t* point_out); void {{.Curve}}{{toCNameBackwards .CurvePrefix}}_generate_projective_points({{toCName .CurvePrefix}}projective_t* points, int size); diff --git a/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl b/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl index 3ddacf4cd..6bf65b884 100644 --- a/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl +++ b/wrappers/golang/internal/generator/curves/templates/curve_test.go.tmpl @@ -86,6 +86,18 @@ func test{{.CurvePrefix}}ProjectiveFromLimbs(suite *suite.Suite) { suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs()) } +func test{{.CurvePrefix}}ProjectiveArithmetic(suite *suite.Suite) { + points := {{if eq .CurvePrefix "G2"}}g2{{else}}{{.Curve}}{{end}}.{{.CurvePrefix}}GenerateProjectivePoints(2) + + point1 := points[0] + point2 := points[1] + + add := point1.Add(&point2) + sub := add.Sub(&point2) + + suite.True(point1.ProjectiveEq(&sub)) +} + func test{{.CurvePrefix}}ProjectiveFromAffine(suite *suite.Suite) { randLimbs := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) randLimbs2 := test_helpers.GenerateRandomLimb(int({{.CurvePrefix}}BASE_LIMBS)) @@ -114,6 +126,7 @@ func (s *{{.CurvePrefix}}CurveTestSuite) Test{{.CurvePrefix}}Curve() { s.Run("Test{{.CurvePrefix}}ProjectiveZero", testWrapper(&s.Suite, test{{.CurvePrefix}}ProjectiveZero)) s.Run("Test{{.CurvePrefix}}ProjectiveFromLimbs", testWrapper(&s.Suite, test{{.CurvePrefix}}ProjectiveFromLimbs)) s.Run("Test{{.CurvePrefix}}ProjectiveFromAffine", testWrapper(&s.Suite, test{{.CurvePrefix}}ProjectiveFromAffine)) + s.Run("Test{{.CurvePrefix}}ProjectiveArithmetic", testWrapper(&s.Suite, test{{.CurvePrefix}}ProjectiveArithmetic)) } func TestSuite{{.CurvePrefix}}Curve(t *testing.T) { diff --git a/wrappers/golang/internal/generator/fields/templates/field.go.tmpl b/wrappers/golang/internal/generator/fields/templates/field.go.tmpl index 3e7946af1..c03943389 100644 --- a/wrappers/golang/internal/generator/fields/templates/field.go.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/field.go.tmpl @@ -113,6 +113,64 @@ func GenerateScalars(size int) core.HostSlice[{{.FieldPrefix}}Field] { return scalarSlice } +func (f {{.FieldPrefix}}Field) Add(f2 *{{.FieldPrefix}}Field) {{.FieldPrefix}}Field { + var res {{.FieldPrefix}}Field + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.{{.Field}}_add(cF, cF2, cRes) + + return res +} + +func (f {{.FieldPrefix}}Field) Sub(f2 *{{.FieldPrefix}}Field) {{.FieldPrefix}}Field { + var res {{.FieldPrefix}}Field + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.{{.Field}}_sub(cF, cF2, cRes) + + return res +} + +func (f {{.FieldPrefix}}Field) Mul(f2 *{{.FieldPrefix}}Field) {{.FieldPrefix}}Field { + var res {{.FieldPrefix}}Field + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cF2 := (*C.scalar_t)(unsafe.Pointer(f2)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.{{.Field}}_mul(cF, cF2, cRes) + + return res +} + +func (f {{.FieldPrefix}}Field) Inv() {{.FieldPrefix}}Field { + var res {{.FieldPrefix}}Field + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.{{.Field}}_inv(cF, cRes) + + return res +} + +func (f {{.FieldPrefix}}Field) Sqr() {{.FieldPrefix}}Field { + var res {{.FieldPrefix}}Field + + cF := (*C.scalar_t)(unsafe.Pointer(&f)) + cRes := (*C.scalar_t)(unsafe.Pointer(&res)) + + C.{{.Field}}_mul(cF, cF, cRes) + + return res +} + func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError { defaultCfg := core.DefaultVecOpsConfig() cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg) diff --git a/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl b/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl index 034d5ba86..b60362a61 100644 --- a/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/field_test.go.tmpl @@ -100,6 +100,36 @@ func test{{capitalize .Field}}GenerateScalars(suite *suite.Suite) { suite.NotContains(scalars, zeroScalar) } +func test{{.FieldPrefix}}FieldArithmetic(suite *suite.Suite) { + const size = 1 << 10 + + scalarsA := {{.Field}}.GenerateScalars(size) + scalarsB := {{.Field}}.GenerateScalars(size) + + + for i := 0; i < size; i++ { + result1 := scalarsA[i].Add(&scalarsB[i]) + result2 := result1.Sub(&scalarsB[i]) + + suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value") + } + + scalarA := scalarsA[0] + square := scalarA.Sqr() + mul := scalarA.Mul(&scalarA) + + suite.Equal(square, mul, "Square and multiplication do not yield the same value") + + inv := scalarA.Inv() + + one := scalarA.Mul(&inv) + expectedOne := {{.Field}}.GenerateScalars(1)[0] + expectedOne.One() + + suite.Equal(expectedOne, one) +} + + func test{{capitalize .Field}}MongtomeryConversion(suite *suite.Suite) { size := 1 << 20 scalars := {{.Field}}.GenerateScalars(size) @@ -135,6 +165,7 @@ func (s *{{.FieldPrefix}}FieldTestSuite) Test{{.FieldPrefix}}Field() { s.Run("Test{{.FieldPrefix}}FieldFromBytes", testWrapper(&s.Suite, test{{.FieldPrefix}}FieldFromBytes)) s.Run("Test{{.FieldPrefix}}FieldToBytes", testWrapper(&s.Suite, test{{.FieldPrefix}}FieldToBytes)) {{if .IsScalar -}} + s.Run("Test{{.FieldPrefix}}FieldArithmetic", testWrapper(&s.Suite, test{{.FieldPrefix}}FieldArithmetic)) s.Run("Test{{capitalize .Field}}GenerateScalars", testWrapper(&s.Suite, test{{capitalize .Field}}GenerateScalars)) s.Run("Test{{capitalize .Field}}MongtomeryConversion", testWrapper(&s.Suite, test{{capitalize .Field}}MongtomeryConversion)) {{- end}} diff --git a/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl b/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl index 74327f47e..59b7573b7 100644 --- a/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl +++ b/wrappers/golang/internal/generator/fields/templates/scalar_field.h.tmpl @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig; void {{.Field}}_generate_scalars(scalar_t* scalars, int size); int {{.Field}}_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out); +void {{.Field}}_add(const scalar_t* a, const scalar_t* b, scalar_t* result); +void {{.Field}}_sub(const scalar_t* a, const scalar_t* b, scalar_t* result); +void {{.Field}}_mul(const scalar_t* a, const scalar_t* b, scalar_t* result); +void {{.Field}}_inv(const scalar_t* a, scalar_t* result); #ifdef __cplusplus } diff --git a/wrappers/rust/icicle-core/src/curve.rs b/wrappers/rust/icicle-core/src/curve.rs index 1beaa0826..82c8dd402 100644 --- a/wrappers/rust/icicle-core/src/curve.rs +++ b/wrappers/rust/icicle-core/src/curve.rs @@ -1,6 +1,7 @@ use crate::traits::{FieldImpl, MontgomeryConvertible}; use icicle_runtime::{errors::eIcicleError, memory::HostOrDeviceSlice, stream::IcicleStream}; use std::fmt::Debug; +use std::ops::{Add, Mul, Sub}; pub trait Curve: Debug + PartialEq + Copy + Clone { type BaseField: FieldImpl; @@ -28,6 +29,12 @@ pub trait Curve: Debug + PartialEq + Copy + Clone { is_into: bool, stream: &IcicleStream, ) -> eIcicleError; + #[doc(hidden)] + fn add(point1: Projective, point2: Projective) -> Projective; + #[doc(hidden)] + fn sub(point1: Projective, point2: Projective) -> Projective; + #[doc(hidden)] + fn mul_scalar(point1: Projective, point2: Self::ScalarField) -> Projective; } /// A [projective](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) elliptic curve point. @@ -158,6 +165,30 @@ impl MontgomeryConvertible for Projective { } } +impl Add for Projective { + type Output = Self; + + fn add(self, other: Self) -> Self { + C::add(self, other) + } +} + +impl Sub for Projective { + type Output = Self; + + fn sub(self, other: Self) -> Self { + C::sub(self, other) + } +} + +impl Mul<::ScalarField> for Projective { + type Output = Self; + + fn mul(self, other: ::ScalarField) -> Self { + C::mul_scalar(self, other) + } +} + #[macro_export] macro_rules! impl_curve { ( @@ -176,7 +207,7 @@ macro_rules! impl_curve { pub type $projective_type = Projective<$curve>; mod $curve_prefix_ident { - use super::{eIcicleError, $affine_type, $projective_type, IcicleStream, VecOpsConfig}; + use super::{eIcicleError, $affine_type, $projective_type, $scalar_field, IcicleStream, VecOpsConfig}; extern "C" { #[link_name = concat!($curve_prefix, "_eq")] @@ -187,6 +218,24 @@ macro_rules! impl_curve { pub(crate) fn generate_projective_points(points: *mut $projective_type, size: usize); #[link_name = concat!($curve_prefix, "_generate_affine_points")] pub(crate) fn generate_affine_points(points: *mut $affine_type, size: usize); + #[link_name = concat!($curve_prefix, "_ecadd")] + pub(crate) fn add( + point1: *const $projective_type, + point2: *const $projective_type, + result: *mut $projective_type, + ); + #[link_name = concat!($curve_prefix, "_ecsub")] + pub(crate) fn sub( + point1: *const $projective_type, + point2: *const $projective_type, + result: *mut $projective_type, + ); + #[link_name = concat!($curve_prefix, "_mul_scalar")] + pub(crate) fn mul_scalar( + point1: *const $projective_type, + point2: *const $scalar_field, + result: *mut $projective_type, + ); #[link_name = concat!($curve_prefix, "_affine_convert_montgomery")] pub(crate) fn _convert_affine_montgomery( input: *const $affine_type, @@ -218,6 +267,48 @@ macro_rules! impl_curve { unsafe { $curve_prefix_ident::proj_to_affine(point, point_out) }; } + fn add(point1: $projective_type, point2: $projective_type) -> $projective_type { + let mut result = $projective_type::zero(); + + unsafe { + $curve_prefix_ident::add( + &point1 as *const $projective_type, + &point2 as *const $projective_type, + &mut result as *mut _ as *mut $projective_type, + ); + }; + + result + } + + fn sub(point1: $projective_type, point2: $projective_type) -> $projective_type { + let mut result = $projective_type::zero(); + + unsafe { + $curve_prefix_ident::sub( + &point1 as *const $projective_type, + &point2 as *const $projective_type, + &mut result as *mut _ as *mut $projective_type, + ); + }; + + result + } + + fn mul_scalar(point1: $projective_type, point2: $scalar_field) -> $projective_type { + let mut result = $projective_type::zero(); + + unsafe { + $curve_prefix_ident::mul_scalar( + &point1 as *const $projective_type, + &point2 as *const $scalar_field, + &mut result as *mut _ as *mut $projective_type, + ); + }; + + result + } + fn generate_random_projective_points(size: usize) -> Vec<$projective_type> { let mut res = vec![$projective_type::zero(); size]; unsafe { @@ -298,6 +389,11 @@ macro_rules! impl_curve_tests { initialize(); check_points_convert_montgomery::<$curve>() } + + #[test] + fn test_point_arithmetic() { + check_point_arithmetic::<$curve>(); + } } }; } diff --git a/wrappers/rust/icicle-core/src/field.rs b/wrappers/rust/icicle-core/src/field.rs index f0f1e1ad5..9e3a16c2e 100644 --- a/wrappers/rust/icicle-core/src/field.rs +++ b/wrappers/rust/icicle-core/src/field.rs @@ -1,10 +1,11 @@ -use crate::traits::{FieldConfig, FieldImpl, MontgomeryConvertible}; +use crate::traits::{Arithmetic, FieldConfig, FieldImpl, MontgomeryConvertible}; use hex::FromHex; use icicle_runtime::errors::eIcicleError; use icicle_runtime::memory::HostOrDeviceSlice; use icicle_runtime::stream::IcicleStream; use std::fmt::{Debug, Display}; use std::marker::PhantomData; +use std::ops::{Add, Mul, Sub}; #[derive(PartialEq, Copy, Clone)] #[repr(C)] @@ -106,6 +107,61 @@ pub trait MontgomeryConvertibleField { fn from_mont(values: &mut (impl HostOrDeviceSlice + ?Sized), stream: &IcicleStream) -> eIcicleError; } +#[doc(hidden)] +pub trait FieldArithmetic { + fn add(first: F, second: F) -> F; + fn sub(first: F, second: F) -> F; + fn mul(first: F, second: F) -> F; + fn sqr(first: F) -> F; + fn inv(first: F) -> F; +} + +impl Arithmetic for Field +where + F: FieldArithmetic, +{ + fn sqr(self) -> Self { + F::sqr(self) + } + + fn inv(self) -> Self { + F::inv(self) + } +} + +impl Add for Field +where + F: FieldArithmetic, +{ + type Output = Self; + + fn add(self, second: Self) -> Self { + F::add(self, second) + } +} + +impl Sub for Field +where + F: FieldArithmetic, +{ + type Output = Self; + + fn sub(self, second: Self) -> Self { + F::sub(self, second) + } +} + +impl Mul for Field +where + F: FieldArithmetic, +{ + type Output = Self; + + fn mul(self, second: Self) -> Self { + F::mul(self, second) + } +} + impl MontgomeryConvertible for Field where F: MontgomeryConvertibleField, @@ -148,7 +204,7 @@ macro_rules! impl_scalar_field { mod $field_prefix_ident { use super::{$field_name, HostOrDeviceSlice}; - use icicle_core::vec_ops::VecOpsConfig; + use icicle_core::{traits::FieldImpl, vec_ops::VecOpsConfig}; use icicle_runtime::errors::eIcicleError; use icicle_runtime::stream::{IcicleStream, IcicleStreamHandle}; @@ -164,6 +220,18 @@ macro_rules! impl_scalar_field { config: &VecOpsConfig, output: *mut $field_name, ) -> eIcicleError; + + #[link_name = concat!($field_prefix, "_add")] + pub(crate) fn add(a: *const $field_name, b: *const $field_name, result: *mut $field_name); + + #[link_name = concat!($field_prefix, "_sub")] + pub(crate) fn sub(a: *const $field_name, b: *const $field_name, result: *mut $field_name); + + #[link_name = concat!($field_prefix, "_mul")] + pub(crate) fn mul(a: *const $field_name, b: *const $field_name, result: *mut $field_name); + + #[link_name = concat!($field_prefix, "_inv")] + pub(crate) fn inv(a: *const $field_name, result: *mut $field_name); } pub(crate) fn convert_scalars_montgomery( @@ -176,6 +244,69 @@ macro_rules! impl_scalar_field { } } + impl icicle_core::field::FieldArithmetic<$field_name> for $field_cfg { + fn add(first: $field_name, second: $field_name) -> $field_name { + let mut result = $field_name::zero(); + unsafe { + $field_prefix_ident::add( + &first as *const $field_name, + &second as *const $field_name, + &mut result as *mut $field_name, + ); + } + + result + } + + fn sub(first: $field_name, second: $field_name) -> $field_name { + let mut result = $field_name::zero(); + unsafe { + $field_prefix_ident::sub( + &first as *const $field_name, + &second as *const $field_name, + &mut result as *mut $field_name, + ); + } + + result + } + + fn mul(first: $field_name, second: $field_name) -> $field_name { + let mut result = $field_name::zero(); + unsafe { + $field_prefix_ident::mul( + &first as *const $field_name, + &second as *const $field_name, + &mut result as *mut $field_name, + ); + } + + result + } + + fn sqr(first: $field_name) -> $field_name { + let mut result = $field_name::zero(); + unsafe { + $field_prefix_ident::mul( + &first as *const $field_name, + &first as *const $field_name, + &mut result as *mut $field_name, + ); + } + + result + } + + fn inv(first: $field_name) -> $field_name { + let mut result = $field_name::zero(); + unsafe { + $field_prefix_ident::inv(&first as *const $field_name, &mut result as *mut $field_name); + } + + result + } + } + impl GenerateRandom<$field_name> for $field_cfg { fn generate_random(size: usize) -> Vec<$field_name> { let mut res = vec![$field_name::zero(); size]; @@ -255,6 +386,11 @@ macro_rules! impl_field_tests { initialize(); check_field_equality::<$field_name>() } + + #[test] + fn test_field_arithmetic() { + check_field_arithmetic::<$field_name>() + } } }; } diff --git a/wrappers/rust/icicle-core/src/tests.rs b/wrappers/rust/icicle-core/src/tests.rs index 20d8b6d03..370fc01b7 100644 --- a/wrappers/rust/icicle-core/src/tests.rs +++ b/wrappers/rust/icicle-core/src/tests.rs @@ -1,7 +1,7 @@ use crate::{ curve::{Affine, Curve, Projective}, field::Field, - traits::{FieldConfig, FieldImpl, GenerateRandom, MontgomeryConvertible}, + traits::{Arithmetic, FieldConfig, FieldImpl, GenerateRandom, MontgomeryConvertible}, }; use icicle_runtime::{ memory::{DeviceVec, HostSlice}, @@ -16,6 +16,32 @@ pub fn check_field_equality() { assert_eq!(left, right); } +pub fn check_field_arithmetic() +where + F: FieldImpl + Arithmetic, + F::Config: GenerateRandom, +{ + let size = 1 << 10; + let scalars_a = F::Config::generate_random(size); + let scalars_b = F::Config::generate_random(size); + + for i in 0..size { + let result1 = scalars_a[i] + scalars_b[i]; + let result2 = result1 - scalars_b[i]; + assert_eq!(result2, scalars_a[i]); + } + + let scalar_a = scalars_a[0]; + let square = scalar_a.sqr(); + let mul = scalar_a.mul(scalar_a); + + assert_eq!(square, mul); + + let inv = scalar_a.inv(); + let one = scalar_a.mul(inv); + assert_eq!(one, F::one()); +} + pub fn check_affine_projective_convert() { let size = 1 << 10; let affine_points = C::generate_random_affine_points(size); @@ -30,6 +56,18 @@ pub fn check_affine_projective_convert() { } } +pub fn check_point_arithmetic() { + let size = 1 << 10; + let projective_points_a = C::generate_random_projective_points(size); + let projective_points_b = C::generate_random_projective_points(size); + + for i in 0..size { + let result1 = projective_points_a[i] + projective_points_b[i]; + let result2 = result1 - projective_points_b[i]; + assert_eq!(result2, projective_points_a[i]); + } +} + pub fn check_point_equality() where C: Curve>, diff --git a/wrappers/rust/icicle-core/src/traits.rs b/wrappers/rust/icicle-core/src/traits.rs index 85d8b775e..835a0dca0 100644 --- a/wrappers/rust/icicle-core/src/traits.rs +++ b/wrappers/rust/icicle-core/src/traits.rs @@ -2,6 +2,7 @@ use icicle_runtime::errors::eIcicleError; use icicle_runtime::memory::HostOrDeviceSlice; use icicle_runtime::stream::IcicleStream; use std::fmt::{Debug, Display}; +use std::ops::{Add, Mul, Sub}; #[doc(hidden)] pub trait GenerateRandom { @@ -30,3 +31,8 @@ pub trait MontgomeryConvertible: Sized { fn to_mont(values: &mut (impl HostOrDeviceSlice + ?Sized), stream: &IcicleStream) -> eIcicleError; fn from_mont(values: &mut (impl HostOrDeviceSlice + ?Sized), stream: &IcicleStream) -> eIcicleError; } + +pub trait Arithmetic: Sized + Add + Sub + Mul { + fn sqr(self) -> Self; + fn inv(self) -> Self; +}