Skip to content

Commit

Permalink
Host arithmetic for Rust and Golang
Browse files Browse the repository at this point in the history
Co-authored-by: Yuval Shekel <yshekel@gmail.com>
  • Loading branch information
emirsoyturk and yshekel authored Dec 4, 2024
1 parent 379d3e5 commit f797dae
Show file tree
Hide file tree
Showing 64 changed files with 1,488 additions and 4 deletions.
32 changes: 32 additions & 0 deletions icicle/src/curves/ffi_extern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ extern "C" bool CONCAT_EXPAND(CURVE, eq)(projective_t* point1, projective_t* poi
(point2->z == point_field_t::zero()));
}

extern "C" void CONCAT_EXPAND(CURVE, ecsub)(projective_t* point1, projective_t* point2, projective_t* result)
{
*result = *point1 - *point2;
}

extern "C" void CONCAT_EXPAND(CURVE, ecadd)(projective_t* point1, projective_t* point2, projective_t* result)
{
*result = *point1 + *point2;
}

extern "C" void CONCAT_EXPAND(CURVE, mul_scalar)(projective_t* point, scalar_t* scalar, projective_t* result)
{
*result = *point * *scalar;
}

extern "C" void CONCAT_EXPAND(CURVE, to_affine)(projective_t* point, affine_t* point_out)
{
*point_out = projective_t::to_affine(*point);
Expand Down Expand Up @@ -46,6 +61,23 @@ extern "C" bool CONCAT_EXPAND(CURVE, g2_eq)(g2_projective_t* point1, g2_projecti
(point2->z == g2_point_field_t::zero()));
}

extern "C" void
CONCAT_EXPAND(CURVE, g2_ecsub)(g2_projective_t* point1, g2_projective_t* point2, g2_projective_t* result)
{
*result = *point1 - *point2;
}

extern "C" void
CONCAT_EXPAND(CURVE, g2_ecadd)(g2_projective_t* point1, g2_projective_t* point2, g2_projective_t* result)
{
*result = *point1 + *point2;
}

extern "C" void CONCAT_EXPAND(CURVE, g2_mul_scalar)(g2_projective_t* point, scalar_t* scalar, g2_projective_t* result)
{
*result = *point * *scalar;
}

extern "C" void CONCAT_EXPAND(CURVE, g2_to_affine)(g2_projective_t* point, g2_affine_t* point_out)
{
*point_out = g2_projective_t::to_affine(*point);
Expand Down
40 changes: 40 additions & 0 deletions icicle/src/fields/ffi_extern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,49 @@ extern "C" void CONCAT_EXPAND(FIELD, generate_scalars)(scalar_t* scalars, int si
scalar_t::rand_host_many(scalars, size);
}

extern "C" void CONCAT_EXPAND(FIELD, sub)(scalar_t* scalar1, scalar_t* scalar2, scalar_t* result)
{
*result = *scalar1 - *scalar2;
}

extern "C" void CONCAT_EXPAND(FIELD, add)(scalar_t* scalar1, scalar_t* scalar2, scalar_t* result)
{
*result = *scalar1 + *scalar2;
}

extern "C" void CONCAT_EXPAND(FIELD, mul)(scalar_t* scalar1, scalar_t* scalar2, scalar_t* result)
{
*result = *scalar1 * *scalar2;
}

extern "C" void CONCAT_EXPAND(FIELD, inv)(scalar_t* scalar1, scalar_t* result)
{
*result = scalar_t::inverse(*scalar1);
}

#ifdef EXT_FIELD
extern "C" void CONCAT_EXPAND(FIELD, extension_generate_scalars)(extension_t* scalars, int size)
{
extension_t::rand_host_many(scalars, size);
}

extern "C" void CONCAT_EXPAND(FIELD, extension_sub)(extension_t* scalar1, extension_t* scalar2, extension_t* result)
{
*result = *scalar1 - *scalar2;
}

extern "C" void CONCAT_EXPAND(FIELD, extension_add)(extension_t* scalar1, extension_t* scalar2, extension_t* result)
{
*result = *scalar1 + *scalar2;
}

extern "C" void CONCAT_EXPAND(FIELD, extension_mul)(extension_t* scalar1, extension_t* scalar2, extension_t* result)
{
*result = *scalar1 * *scalar2;
}

extern "C" void CONCAT_EXPAND(FIELD, extension_inv)(extension_t* scalar1, extension_t* result)
{
*result = extension_t::inverse(*scalar1);
}
#endif // EXT_FIELD
24 changes: 24 additions & 0 deletions wrappers/golang/curves/bls12377/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool {
return __ret == (C._Bool)(true)
}

func (p Projective) Add(p2 *Projective) Projective {
var res Projective

cP := (*C.projective_t)(unsafe.Pointer(&p))
cP2 := (*C.projective_t)(unsafe.Pointer(p2))
cRes := (*C.projective_t)(unsafe.Pointer(&res))

C.bls12_377_ecadd(cP, cP2, cRes)

return res
}

func (p Projective) Sub(p2 *Projective) Projective {
var res Projective

cP := (*C.projective_t)(unsafe.Pointer(&p))
cP2 := (*C.projective_t)(unsafe.Pointer(p2))
cRes := (*C.projective_t)(unsafe.Pointer(&res))

C.bls12_377_ecsub(cP, cP2, cRes)

return res
}

func (p *Projective) ToAffine() Affine {
var a Affine

Expand Down
24 changes: 24 additions & 0 deletions wrappers/golang/curves/bls12377/g2/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ func (p G2Projective) ProjectiveEq(p2 *G2Projective) bool {
return __ret == (C._Bool)(true)
}

func (p G2Projective) Add(p2 *G2Projective) G2Projective {
var res G2Projective

cP := (*C.g2_projective_t)(unsafe.Pointer(&p))
cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2))
cRes := (*C.g2_projective_t)(unsafe.Pointer(&res))

C.bls12_377_g2_ecadd(cP, cP2, cRes)

return res
}

func (p G2Projective) Sub(p2 *G2Projective) G2Projective {
var res G2Projective

cP := (*C.g2_projective_t)(unsafe.Pointer(&p))
cP2 := (*C.g2_projective_t)(unsafe.Pointer(p2))
cRes := (*C.g2_projective_t)(unsafe.Pointer(&res))

C.bls12_377_g2_ecsub(cP, cP2, cRes)

return res
}

func (p *G2Projective) ToAffine() G2Affine {
var a G2Affine

Expand Down
2 changes: 2 additions & 0 deletions wrappers/golang/curves/bls12377/g2/include/curve.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ typedef struct g2_affine_t g2_affine_t;
typedef struct VecOpsConfig VecOpsConfig;

bool bls12_377_g2_eq(g2_projective_t* point1, g2_projective_t* point2);
void bls12_377_g2_ecadd(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res);
void bls12_377_g2_ecsub(g2_projective_t* point, g2_projective_t* point2, g2_projective_t* res);
void bls12_377_g2_to_affine(g2_projective_t* point, g2_affine_t* point_out);
void bls12_377_g2_from_affine(g2_affine_t* point, g2_projective_t* point_out);
void bls12_377_g2_generate_projective_points(g2_projective_t* points, int size);
Expand Down
4 changes: 4 additions & 0 deletions wrappers/golang/curves/bls12377/g2/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig;

void bls12_377_generate_scalars(scalar_t* scalars, int size);
int bls12_377_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out);
void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_inv(const scalar_t* a, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
2 changes: 2 additions & 0 deletions wrappers/golang/curves/bls12377/include/curve.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ typedef struct affine_t affine_t;
typedef struct VecOpsConfig VecOpsConfig;

bool bls12_377_eq(projective_t* point1, projective_t* point2);
void bls12_377_ecadd(projective_t* point, projective_t* point2, projective_t* res);
void bls12_377_ecsub(projective_t* point, projective_t* point2, projective_t* res);
void bls12_377_to_affine(projective_t* point, affine_t* point_out);
void bls12_377_from_affine(affine_t* point, projective_t* point_out);
void bls12_377_generate_projective_points(projective_t* points, int size);
Expand Down
4 changes: 4 additions & 0 deletions wrappers/golang/curves/bls12377/include/scalar_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ typedef struct VecOpsConfig VecOpsConfig;

void bls12_377_generate_scalars(scalar_t* scalars, int size);
int bls12_377_scalar_convert_montgomery(const scalar_t* d_in, size_t n, bool is_into, const VecOpsConfig* ctx, scalar_t* d_out);
void bls12_377_add(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_sub(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_mul(const scalar_t* a, const scalar_t* b, scalar_t* result);
void bls12_377_inv(const scalar_t* a, scalar_t* result);

#ifdef __cplusplus
}
Expand Down
58 changes: 58 additions & 0 deletions wrappers/golang/curves/bls12377/scalar_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,64 @@ func GenerateScalars(size int) core.HostSlice[ScalarField] {
return scalarSlice
}

func (f ScalarField) Add(f2 *ScalarField) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cF2 := (*C.scalar_t)(unsafe.Pointer(f2))
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_377_add(cF, cF2, cRes)

return res
}

func (f ScalarField) Sub(f2 *ScalarField) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cF2 := (*C.scalar_t)(unsafe.Pointer(f2))
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_377_sub(cF, cF2, cRes)

return res
}

func (f ScalarField) Mul(f2 *ScalarField) ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cF2 := (*C.scalar_t)(unsafe.Pointer(f2))
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_377_mul(cF, cF2, cRes)

return res
}

func (f ScalarField) Inv() ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_377_inv(cF, cRes)

return res
}

func (f ScalarField) Sqr() ScalarField {
var res ScalarField

cF := (*C.scalar_t)(unsafe.Pointer(&f))
cRes := (*C.scalar_t)(unsafe.Pointer(&res))

C.bls12_377_mul(cF, cF, cRes)

return res
}

func convertScalarsMontgomery(scalars core.HostOrDeviceSlice, isInto bool) runtime.EIcicleError {
defaultCfg := core.DefaultVecOpsConfig()
cValues, _, _, cCfg, cSize := core.VecOpCheck(scalars, scalars, scalars, &defaultCfg)
Expand Down
13 changes: 13 additions & 0 deletions wrappers/golang/curves/bls12377/tests/curve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ func testProjectiveFromLimbs(suite *suite.Suite) {
suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs())
}

func testProjectiveArithmetic(suite *suite.Suite) {
points := bls12_377.GenerateProjectivePoints(2)

point1 := points[0]
point2 := points[1]

add := point1.Add(&point2)
sub := add.Sub(&point2)

suite.True(point1.ProjectiveEq(&sub))
}

func testProjectiveFromAffine(suite *suite.Suite) {
randLimbs := test_helpers.GenerateRandomLimb(int(BASE_LIMBS))
randLimbs2 := test_helpers.GenerateRandomLimb(int(BASE_LIMBS))
Expand Down Expand Up @@ -114,6 +126,7 @@ func (s *CurveTestSuite) TestCurve() {
s.Run("TestProjectiveZero", testWrapper(&s.Suite, testProjectiveZero))
s.Run("TestProjectiveFromLimbs", testWrapper(&s.Suite, testProjectiveFromLimbs))
s.Run("TestProjectiveFromAffine", testWrapper(&s.Suite, testProjectiveFromAffine))
s.Run("TestProjectiveArithmetic", testWrapper(&s.Suite, testProjectiveArithmetic))
}

func TestSuiteCurve(t *testing.T) {
Expand Down
13 changes: 13 additions & 0 deletions wrappers/golang/curves/bls12377/tests/g2_curve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ func testG2ProjectiveFromLimbs(suite *suite.Suite) {
suite.ElementsMatch(randLimbs3, projective.Z.GetLimbs())
}

func testG2ProjectiveArithmetic(suite *suite.Suite) {
points := g2.G2GenerateProjectivePoints(2)

point1 := points[0]
point2 := points[1]

add := point1.Add(&point2)
sub := add.Sub(&point2)

suite.True(point1.ProjectiveEq(&sub))
}

func testG2ProjectiveFromAffine(suite *suite.Suite) {
randLimbs := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS))
randLimbs2 := test_helpers.GenerateRandomLimb(int(G2BASE_LIMBS))
Expand Down Expand Up @@ -114,6 +126,7 @@ func (s *G2CurveTestSuite) TestG2Curve() {
s.Run("TestG2ProjectiveZero", testWrapper(&s.Suite, testG2ProjectiveZero))
s.Run("TestG2ProjectiveFromLimbs", testWrapper(&s.Suite, testG2ProjectiveFromLimbs))
s.Run("TestG2ProjectiveFromAffine", testWrapper(&s.Suite, testG2ProjectiveFromAffine))
s.Run("TestG2ProjectiveArithmetic", testWrapper(&s.Suite, testG2ProjectiveArithmetic))
}

func TestSuiteG2Curve(t *testing.T) {
Expand Down
29 changes: 29 additions & 0 deletions wrappers/golang/curves/bls12377/tests/scalar_field_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,34 @@ func testBls12_377GenerateScalars(suite *suite.Suite) {
suite.NotContains(scalars, zeroScalar)
}

func testScalarFieldArithmetic(suite *suite.Suite) {
const size = 1 << 10

scalarsA := bls12_377.GenerateScalars(size)
scalarsB := bls12_377.GenerateScalars(size)

for i := 0; i < size; i++ {
result1 := scalarsA[i].Add(&scalarsB[i])
result2 := result1.Sub(&scalarsB[i])

suite.Equal(scalarsA[i], result2, "Addition and subtraction do not yield the original value")
}

scalarA := scalarsA[0]
square := scalarA.Sqr()
mul := scalarA.Mul(&scalarA)

suite.Equal(square, mul, "Square and multiplication do not yield the same value")

inv := scalarA.Inv()

one := scalarA.Mul(&inv)
expectedOne := bls12_377.GenerateScalars(1)[0]
expectedOne.One()

suite.Equal(expectedOne, one)
}

func testBls12_377MongtomeryConversion(suite *suite.Suite) {
size := 1 << 20
scalars := bls12_377.GenerateScalars(size)
Expand Down Expand Up @@ -133,6 +161,7 @@ func (s *ScalarFieldTestSuite) TestScalarField() {
s.Run("TestScalarFieldAsPointer", testWrapper(&s.Suite, testScalarFieldAsPointer))
s.Run("TestScalarFieldFromBytes", testWrapper(&s.Suite, testScalarFieldFromBytes))
s.Run("TestScalarFieldToBytes", testWrapper(&s.Suite, testScalarFieldToBytes))
s.Run("TestScalarFieldArithmetic", testWrapper(&s.Suite, testScalarFieldArithmetic))
s.Run("TestBls12_377GenerateScalars", testWrapper(&s.Suite, testBls12_377GenerateScalars))
s.Run("TestBls12_377MongtomeryConversion", testWrapper(&s.Suite, testBls12_377MongtomeryConversion))
}
Expand Down
24 changes: 24 additions & 0 deletions wrappers/golang/curves/bls12381/curve.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ func (p Projective) ProjectiveEq(p2 *Projective) bool {
return __ret == (C._Bool)(true)
}

func (p Projective) Add(p2 *Projective) Projective {
var res Projective

cP := (*C.projective_t)(unsafe.Pointer(&p))
cP2 := (*C.projective_t)(unsafe.Pointer(p2))
cRes := (*C.projective_t)(unsafe.Pointer(&res))

C.bls12_381_ecadd(cP, cP2, cRes)

return res
}

func (p Projective) Sub(p2 *Projective) Projective {
var res Projective

cP := (*C.projective_t)(unsafe.Pointer(&p))
cP2 := (*C.projective_t)(unsafe.Pointer(p2))
cRes := (*C.projective_t)(unsafe.Pointer(&res))

C.bls12_381_ecsub(cP, cP2, cRes)

return res
}

func (p *Projective) ToAffine() Affine {
var a Affine

Expand Down
Loading

0 comments on commit f797dae

Please sign in to comment.