From 382bd8e1ba4b9e6dfb70de0a7aaae4cca8f594e5 Mon Sep 17 00:00:00 2001 From: Youssef El Housni Date: Fri, 9 Feb 2024 15:03:19 +0100 Subject: [PATCH] perf(sw_emulated): optimize jointScalarMulGeneric --- std/algebra/emulated/sw_emulated/point.go | 64 ++++++- .../emulated/sw_emulated/point_test.go | 171 +++++++++++++++++- 2 files changed, 225 insertions(+), 10 deletions(-) diff --git a/std/algebra/emulated/sw_emulated/point.go b/std/algebra/emulated/sw_emulated/point.go index ac2c8f7066..4aefe8c78e 100644 --- a/std/algebra/emulated/sw_emulated/point.go +++ b/std/algebra/emulated/sw_emulated/point.go @@ -664,16 +664,64 @@ func (c *Curve[B, S]) jointScalarMul(p1, p2 *AffinePoint[B], s1, s2 *emulated.El } } -// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify the inputs. +// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify p1, p2 nor s1, s2. // -// ⚠️ p1, p2 must not be (0,0) and s1, s2 must not be 0, unless [algopts.WithCompleteArithmetic] option is set. +// ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set. func (c *Curve[B, S]) jointScalarMulGeneric(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { - res1 := c.scalarMulGeneric(p1, s1, opts...) - res2 := c.scalarMulGeneric(p2, s2, opts...) - return c.Add(res1, res2) + cfg, err := algopts.NewConfig(opts...) + if err != nil { + panic(fmt.Sprintf("parse opts: %v", err)) + } + if cfg.CompleteArithmetic { + // TODO @yelhousni: optimize + res1 := c.scalarMulGeneric(p1, s1, opts...) + res2 := c.scalarMulGeneric(p2, s2, opts...) + return c.AddUnified(res1, res2) + } else { + return c.jointScalarMulGenericUnsafe(p1, p2, s1, s2) + } +} + +// jointScalarMulGenericUnsafe computes [s1]p1 + [s2]p2 using Shamir's trick and returns it. It doesn't modify p1, p2 nor s1, s2. +// ⚠️ The scalars must be nonzero and the points different from (0,0). +func (c *Curve[B, S]) jointScalarMulGenericUnsafe(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S]) *AffinePoint[B] { + var Acc, B1, p1Neg, p2Neg *AffinePoint[B] + p1Neg = c.Neg(p1) + p2Neg = c.Neg(p2) + + // Acc = P1 + P2 + Acc = c.Add(p1, p2) + + s1bits := c.scalarApi.ToBits(s1) + s2bits := c.scalarApi.ToBits(s2) + + var st S + nbits := st.Modulus().BitLen() + + for i := nbits - 1; i > 0; i-- { + B1 = &AffinePoint[B]{ + X: p1Neg.X, + Y: *c.baseApi.Select(s1bits[i], &p1.Y, &p1Neg.Y), + } + Acc = c.doubleAndAdd(Acc, B1) + B1 = &AffinePoint[B]{ + X: p2Neg.X, + Y: *c.baseApi.Select(s2bits[i], &p2.Y, &p2Neg.Y), + } + Acc = c.Add(Acc, B1) + + } + + // i = 0 + p1Neg = c.Add(p1Neg, Acc) + Acc = c.Select(s1bits[0], Acc, p1Neg) + p2Neg = c.Add(p2Neg, Acc) + Acc = c.Select(s2bits[0], Acc, p2Neg) + + return Acc } -// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify P, Q nor s. +// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify p1, p2 nor s1, s2. // // ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set. func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] { @@ -691,8 +739,8 @@ func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated } } -// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify P, Q nor s. -// ⚠️ The scalar s must be nonzero and the point Q different from (0,0), unless [algopts.WithCompleteArithmetic] option is set. +// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify Q, R nor s, t. +// ⚠️ The scalars must be nonzero and the points different from (0,0). func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] { var st S frModulus := c.scalarApi.Modulus() diff --git a/std/algebra/emulated/sw_emulated/point_test.go b/std/algebra/emulated/sw_emulated/point_test.go index 2b912a972b..2c8f6c6b5c 100644 --- a/std/algebra/emulated/sw_emulated/point_test.go +++ b/std/algebra/emulated/sw_emulated/point_test.go @@ -1391,6 +1391,40 @@ func TestJointScalarMul6(t *testing.T) { assert.NoError(err) } +func TestJointScalarMul4(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s1, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + s2, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + p1x, p1y := p256.ScalarBaseMult(s1.Bytes()) + p2x, p2y := p256.ScalarBaseMult(s2.Bytes()) + resx, resy := p256.ScalarMult(p1x, p1y, s1.Bytes()) + tmpx, tmpy := p256.ScalarMult(p2x, p2y, s2.Bytes()) + resx, resy = p256.Add(resx, resy, tmpx, tmpy) + + circuit := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{} + witness := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](s1), + S2: emulated.ValueOf[emulated.P256Fr](s2), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p1x), + Y: emulated.ValueOf[emulated.P256Fp](p1y), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p2x), + Y: emulated.ValueOf[emulated.P256Fp](p2y), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](resx), + Y: emulated.ValueOf[emulated.P256Fp](resy), + }, + } + err = test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + type JointScalarMulEdgeCasesTest[T, S emulated.FieldParams] struct { P1, P2, Q AffinePoint[T] S1, S2 emulated.Element[S] @@ -1415,12 +1449,11 @@ func TestJointScalarMulEdgeCases6(t *testing.T) { s2 := new(big.Int) r1.BigInt(s1) r2.BigInt(s2) - var res, res1, res2, gen2, infinity bw6761.G1Affine + var res1, res2, gen2, infinity bw6761.G1Affine _, _, gen1, _ := bw6761.Generators() gen2.Double(&gen1) res1.ScalarMultiplication(&gen1, s1) res2.ScalarMultiplication(&gen2, s2) - res.Add(&res1, &res2) circuit := JointScalarMulEdgeCasesTest[emulated.BW6761Fp, emulated.BW6761Fr]{} // s1*(0,0) + s2*(0,0) == (0,0) @@ -1544,6 +1577,140 @@ func TestJointScalarMulEdgeCases6(t *testing.T) { assert.NoError(err) } +func TestJointScalarMulEdgeCases4(t *testing.T) { + assert := test.NewAssert(t) + p256 := elliptic.P256() + s1, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + s2, err := rand.Int(rand.Reader, p256.Params().N) + assert.NoError(err) + p1x, p1y := p256.ScalarBaseMult(s1.Bytes()) + p2x, p2y := p256.ScalarBaseMult(s2.Bytes()) + res1x, res1y := p256.ScalarMult(p1x, p1y, s1.Bytes()) + res2x, res2y := p256.ScalarMult(p2x, p2y, s2.Bytes()) + + circuit := JointScalarMulEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{} + // s1*(0,0) + s2*(0,0) == (0,0) + witness1 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](s1), + S2: emulated.ValueOf[emulated.P256Fr](s2), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField()) + assert.NoError(err) + + // s1*P + s2*(0,0) == s1*P + witness2 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](s1), + S2: emulated.ValueOf[emulated.P256Fr](s2), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p1x), + Y: emulated.ValueOf[emulated.P256Fp](p1y), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](res1x), + Y: emulated.ValueOf[emulated.P256Fp](res1y), + }, + } + err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField()) + assert.NoError(err) + + // s1*(0,0) + s2*Q == s2*Q + witness3 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](s1), + S2: emulated.ValueOf[emulated.P256Fr](s2), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p2x), + Y: emulated.ValueOf[emulated.P256Fp](p2y), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](res2x), + Y: emulated.ValueOf[emulated.P256Fp](res2y), + }, + } + err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField()) + assert.NoError(err) + + // 0*P + 0*Q == (0,0) + witness4 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](0), + S2: emulated.ValueOf[emulated.P256Fr](0), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p1x), + Y: emulated.ValueOf[emulated.P256Fp](p1y), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p2x), + Y: emulated.ValueOf[emulated.P256Fp](p2y), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](0), + Y: emulated.ValueOf[emulated.P256Fp](0), + }, + } + err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField()) + assert.NoError(err) + + // 0*P + s2*Q == s2*Q + witness5 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](0), + S2: emulated.ValueOf[emulated.P256Fr](s2), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p1x), + Y: emulated.ValueOf[emulated.P256Fp](p1y), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p2x), + Y: emulated.ValueOf[emulated.P256Fp](p2y), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](res2x), + Y: emulated.ValueOf[emulated.P256Fp](res2y), + }, + } + err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField()) + assert.NoError(err) + + // s1*P + 0*Q == s1*P + witness6 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{ + S1: emulated.ValueOf[emulated.P256Fr](s1), + S2: emulated.ValueOf[emulated.P256Fr](0), + P1: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p1x), + Y: emulated.ValueOf[emulated.P256Fp](p1y), + }, + P2: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](p2x), + Y: emulated.ValueOf[emulated.P256Fp](p2y), + }, + Q: AffinePoint[emulated.P256Fp]{ + X: emulated.ValueOf[emulated.P256Fp](res1x), + Y: emulated.ValueOf[emulated.P256Fp](res1y), + }, + } + err = test.IsSolved(&circuit, &witness6, testCurve.ScalarField()) + assert.NoError(err) +} + type MuxCircuitTest[T, S emulated.FieldParams] struct { Selector frontend.Variable Inputs [8]AffinePoint[T]