Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(sw_emulated): optimize jointScalarMulGeneric #1049

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,16 +664,64 @@ func (c *Curve[B, S]) jointScalarMul(p1, p2 *AffinePoint[B], s1, s2 *emulated.El
}
}

// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify the inputs.
// jointScalarMulGeneric computes [s1]p1 + [s2]p2. It doesn't modify p1, p2 nor s1, s2.
//
// ⚠️ p1, p2 must not be (0,0) and s1, s2 must not be 0, unless [algopts.WithCompleteArithmetic] option is set.
// ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set.
func (c *Curve[B, S]) jointScalarMulGeneric(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
res1 := c.scalarMulGeneric(p1, s1, opts...)
res2 := c.scalarMulGeneric(p2, s2, opts...)
return c.Add(res1, res2)
cfg, err := algopts.NewConfig(opts...)
if err != nil {
panic(fmt.Sprintf("parse opts: %v", err))
}
if cfg.CompleteArithmetic {
// TODO @yelhousni: optimize
res1 := c.scalarMulGeneric(p1, s1, opts...)
res2 := c.scalarMulGeneric(p2, s2, opts...)
return c.AddUnified(res1, res2)
} else {
return c.jointScalarMulGenericUnsafe(p1, p2, s1, s2)
}
}

// jointScalarMulGenericUnsafe computes [s1]p1 + [s2]p2 using Shamir's trick and returns it. It doesn't modify p1, p2 nor s1, s2.
// ⚠️ The scalars must be nonzero and the points different from (0,0).
func (c *Curve[B, S]) jointScalarMulGenericUnsafe(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S]) *AffinePoint[B] {
var Acc, B1, p1Neg, p2Neg *AffinePoint[B]
p1Neg = c.Neg(p1)
p2Neg = c.Neg(p2)

// Acc = P1 + P2
Acc = c.Add(p1, p2)

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)

var st S
nbits := st.Modulus().BitLen()

for i := nbits - 1; i > 0; i-- {
B1 = &AffinePoint[B]{
X: p1Neg.X,
Y: *c.baseApi.Select(s1bits[i], &p1.Y, &p1Neg.Y),
}
Acc = c.doubleAndAdd(Acc, B1)
B1 = &AffinePoint[B]{
X: p2Neg.X,
Y: *c.baseApi.Select(s2bits[i], &p2.Y, &p2Neg.Y),
}
Acc = c.Add(Acc, B1)

}

// i = 0
p1Neg = c.Add(p1Neg, Acc)
Acc = c.Select(s1bits[0], Acc, p1Neg)
p2Neg = c.Add(p2Neg, Acc)
Acc = c.Select(s2bits[0], Acc, p2Neg)

return Acc
}

// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify P, Q nor s.
// jointScalarMulGLV computes [s1]p1 + [s2]p2 using an endomorphism. It doesn't modify p1, p2 nor s1, s2.
//
// ⚠️ The scalars s1, s2 must be nonzero and the point p1, p2 different from (0,0), unless [algopts.WithCompleteArithmetic] option is set.
func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
Expand All @@ -691,8 +739,8 @@ func (c *Curve[B, S]) jointScalarMulGLV(p1, p2 *AffinePoint[B], s1, s2 *emulated
}
}

// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify P, Q nor s.
// ⚠️ The scalar s must be nonzero and the point Q different from (0,0), unless [algopts.WithCompleteArithmetic] option is set.
// jointScalarMulGLVUnsafe computes [s]Q + [t]R using Shamir's trick with an efficient endomorphism and returns it. It doesn't modify Q, R nor s, t.
// ⚠️ The scalars must be nonzero and the points different from (0,0).
func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] {
var st S
frModulus := c.scalarApi.Modulus()
Expand Down
171 changes: 169 additions & 2 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,40 @@ func TestJointScalarMul6(t *testing.T) {
assert.NoError(err)
}

func TestJointScalarMul4(t *testing.T) {
assert := test.NewAssert(t)
p256 := elliptic.P256()
s1, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
s2, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
p1x, p1y := p256.ScalarBaseMult(s1.Bytes())
p2x, p2y := p256.ScalarBaseMult(s2.Bytes())
resx, resy := p256.ScalarMult(p1x, p1y, s1.Bytes())
tmpx, tmpy := p256.ScalarMult(p2x, p2y, s2.Bytes())
resx, resy = p256.Add(resx, resy, tmpx, tmpy)

circuit := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{}
witness := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](resx),
Y: emulated.ValueOf[emulated.P256Fp](resy),
},
}
err = test.IsSolved(&circuit, &witness, testCurve.ScalarField())
assert.NoError(err)
}

type JointScalarMulEdgeCasesTest[T, S emulated.FieldParams] struct {
P1, P2, Q AffinePoint[T]
S1, S2 emulated.Element[S]
Expand All @@ -1415,12 +1449,11 @@ func TestJointScalarMulEdgeCases6(t *testing.T) {
s2 := new(big.Int)
r1.BigInt(s1)
r2.BigInt(s2)
var res, res1, res2, gen2, infinity bw6761.G1Affine
var res1, res2, gen2, infinity bw6761.G1Affine
_, _, gen1, _ := bw6761.Generators()
gen2.Double(&gen1)
res1.ScalarMultiplication(&gen1, s1)
res2.ScalarMultiplication(&gen2, s2)
res.Add(&res1, &res2)

circuit := JointScalarMulEdgeCasesTest[emulated.BW6761Fp, emulated.BW6761Fr]{}
// s1*(0,0) + s2*(0,0) == (0,0)
Expand Down Expand Up @@ -1544,6 +1577,140 @@ func TestJointScalarMulEdgeCases6(t *testing.T) {
assert.NoError(err)
}

func TestJointScalarMulEdgeCases4(t *testing.T) {
assert := test.NewAssert(t)
p256 := elliptic.P256()
s1, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
s2, err := rand.Int(rand.Reader, p256.Params().N)
assert.NoError(err)
p1x, p1y := p256.ScalarBaseMult(s1.Bytes())
p2x, p2y := p256.ScalarBaseMult(s2.Bytes())
res1x, res1y := p256.ScalarMult(p1x, p1y, s1.Bytes())
res2x, res2y := p256.ScalarMult(p2x, p2y, s2.Bytes())

circuit := JointScalarMulEdgeCasesTest[emulated.P256Fp, emulated.P256Fr]{}
// s1*(0,0) + s2*(0,0) == (0,0)
witness1 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
}
err = test.IsSolved(&circuit, &witness1, testCurve.ScalarField())
assert.NoError(err)

// s1*P + s2*(0,0) == s1*P
witness2 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res1x),
Y: emulated.ValueOf[emulated.P256Fp](res1y),
},
}
err = test.IsSolved(&circuit, &witness2, testCurve.ScalarField())
assert.NoError(err)

// s1*(0,0) + s2*Q == s2*Q
witness3 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res2x),
Y: emulated.ValueOf[emulated.P256Fp](res2y),
},
}
err = test.IsSolved(&circuit, &witness3, testCurve.ScalarField())
assert.NoError(err)

// 0*P + 0*Q == (0,0)
witness4 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](0),
S2: emulated.ValueOf[emulated.P256Fr](0),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](0),
Y: emulated.ValueOf[emulated.P256Fp](0),
},
}
err = test.IsSolved(&circuit, &witness4, testCurve.ScalarField())
assert.NoError(err)

// 0*P + s2*Q == s2*Q
witness5 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](0),
S2: emulated.ValueOf[emulated.P256Fr](s2),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res2x),
Y: emulated.ValueOf[emulated.P256Fp](res2y),
},
}
err = test.IsSolved(&circuit, &witness5, testCurve.ScalarField())
assert.NoError(err)

// s1*P + 0*Q == s1*P
witness6 := JointScalarMulTest[emulated.P256Fp, emulated.P256Fr]{
S1: emulated.ValueOf[emulated.P256Fr](s1),
S2: emulated.ValueOf[emulated.P256Fr](0),
P1: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p1x),
Y: emulated.ValueOf[emulated.P256Fp](p1y),
},
P2: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](p2x),
Y: emulated.ValueOf[emulated.P256Fp](p2y),
},
Q: AffinePoint[emulated.P256Fp]{
X: emulated.ValueOf[emulated.P256Fp](res1x),
Y: emulated.ValueOf[emulated.P256Fp](res1y),
},
}
err = test.IsSolved(&circuit, &witness6, testCurve.ScalarField())
assert.NoError(err)
}

type MuxCircuitTest[T, S emulated.FieldParams] struct {
Selector frontend.Variable
Inputs [8]AffinePoint[T]
Expand Down
Loading