Skip to content

Commit

Permalink
perf: big optim for JointScalarMul and MSM
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Mar 6, 2024
1 parent 92b6a8d commit c7d831d
Showing 1 changed file with 24 additions and 228 deletions.
252 changes: 24 additions & 228 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -939,9 +939,18 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
}
tablePhiS[3] = c.Neg(tablePhiS[2])

// suppose first bit is 1 and set:
// suppose first bits are 1 and set:
// Acc = Q + R + Φ(Q) + Φ(R)
Acc := c.Add(tableS[1], tablePhiS[1])
B1 := Acc
// and then add the base point G to it to avoid incomplete additions in the
// loop, because when doing doubleAndAdd(Acc, Bi) as (Acc+Bi)+Acc it might
// happen that Acc==Bi or -Bi. But now we force Acc to be different than
// the stored Bi. However we need at the end to subtract [2^nbits]G
// (hardcoded) from the result.
//
// Acc = Q + R + Φ(Q) + Φ(R) + G
Acc = c.Add(Acc, c.Generator())

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)
Expand All @@ -951,7 +960,6 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat

// At each iteration look up the point P from:
// B1 = +Q + R + Φ(Q) + Φ(R)
B1 := Acc
// B2 = +Q + R + Φ(Q) - Φ(R)
B2 := c.Add(tableS[1], tablePhiS[2])
// B3 = +Q + R - Φ(Q) + Φ(R)
Expand Down Expand Up @@ -1012,12 +1020,11 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
&B15.Y, &B7.Y, &B13.Y, &B5.Y, &B11.Y, &B3.Y, &B9.Y, &B1.Y,
),
}
Acc = c.double(Acc)
Acc = c.add(Acc, P)
Acc = c.doubleAndAdd(Acc, P)
}

// i = 0
// subtract the initial point from the accumulator when first bit was 0
// subtract the initial points from the accumulator when first bits are 0
tableQ[0] = c.Add(tableQ[0], Acc)
Acc = c.Select(s1bits[0], Acc, tableQ[0])
tablePhiQ[0] = c.Add(tablePhiQ[0], Acc)
Expand All @@ -1027,6 +1034,10 @@ func (c *Curve[B, S]) jointScalarMulGLVUnsafe(Q, R *AffinePoint[B], s, t *emulat
tablePhiR[0] = c.Add(tablePhiR[0], Acc)
Acc = c.Select(t2bits[0], Acc, tablePhiR[0])

// subtract [2^nbits]G since we added G at the beginning
gm := c.GeneratorMultiples()[nbits-1]
Acc = c.Add(Acc, c.Neg(&gm))

return Acc

}
Expand Down Expand Up @@ -1148,46 +1159,14 @@ func (c *Curve[B, S]) scalarMulBaseGeneric(s *emulated.Element[S], opts ...algop
return res
}

// JointScalarMulBase computes [s]Q + [t]G and returns it, where G is the
// fixed generator. It doesn't modify Q, s and t.
// JointScalarMulBase computes s2 * p + s1 * g and returns it, where g is the
// fixed generator. It doesn't modify p, s1 and s2.
//
// This function doesn't check that the Q is on the curve. See AssertIsOnCurve.
// ⚠️ p must NOT be (0,0).
// ⚠️ s1 and s2 must NOT be 0.
//
// JointScalarMulBase calls jointScalarMulGeneric or jointScalarMulBaseGLV depending on whether an efficient endomorphism is available.
func (c *Curve[B, S]) JointScalarMulBase(Q *AffinePoint[B], s, t *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
	// The GLV path needs both the eigenvalue and the third root of unity;
	// when either is missing fall back to the generic joint multiplication.
	hasEndomorphism := c.params.Eigenvalue != nil && c.params.ThirdRootOne != nil
	if !hasEndomorphism {
		// The generator is passed first, so the scalars are passed as (t, s).
		return c.jointScalarMulGeneric(c.Generator(), Q, t, s, opts...)
	}
	return c.jointScalarMulBaseGLV(Q, s, t, opts...)
}

// jointScalarMulBaseGLV computes [s]Q + [t]G and returns it, where G is the
// fixed generator. It doesn't modify Q, s and t.
//
// When the CompleteArithmetic option is set, the two scalar multiplications
// are computed separately and combined with AddUnified. Otherwise the work is
// delegated to jointScalarMulBaseGLVUnsafe, which builds its Q tables from the
// decomposition of s and its generator tables from the decomposition of t.
//
// It panics if opts cannot be parsed.
func (c *Curve[B, S]) jointScalarMulBaseGLV(Q *AffinePoint[B], s, t *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
	cfg, err := algopts.NewConfig(opts...)
	if err != nil {
		panic(fmt.Sprintf("parse opts: %v", err))
	}
	if !cfg.CompleteArithmetic {
		return c.jointScalarMulBaseGLVUnsafe(Q, s, t)
	}
	// TODO @yelhousni: optimize
	//
	// Pair the scalars exactly as the unsafe path does: t multiplies the
	// generator and s multiplies Q. Using s for the generator and t for Q here
	// would make the result depend on whether CompleteArithmetic is enabled.
	baseTerm := c.ScalarMulBase(t, opts...)
	pointTerm := c.scalarMulGLV(Q, s, opts...)
	return c.AddUnified(baseTerm, pointTerm)
}

// jointScalarMulBaseGLVUnsafe computes [s]Q + [t]G using an endomorphism and returns it, where G is the
// fixed generator. It doesn't modify Q, s and t.
//
// ⚠️ Q must NOT be (0,0).
// ⚠️ s and t must NOT be 0.
//
// JointScalarMulBase is used to verify an ECDSA signature (r,s) on the
// secp256k1 curve. In this case, p is a public key, s2=r/s and s1=hash/s.
// JointScalarMulBase is used to verify an ECDSA signature (r,s) for example on
// the secp256k1 curve. In this case, p is a public key, s2=r/s and s1=hash/s.
// - hash cannot be 0, because of pre-image resistance.
// - r cannot be 0, because r is the x coordinate of a random point on
// secp256k1 (y²=x³+7 mod p) and 7 is not a square mod p. For any other
Expand All @@ -1196,191 +1175,8 @@ func (c *Curve[B, S]) jointScalarMulBaseGLV(Q *AffinePoint[B], s, t *emulated.El
//
// The [EVM] specifies these checks, which are performed on the zkEVM
// arithmetization side before calling the circuit that uses this method.
func (c *Curve[B, S]) jointScalarMulBaseGLVUnsafe(Q *AffinePoint[B], s, t *emulated.Element[S]) *AffinePoint[B] {
var st S
frModulus := c.scalarApi.Modulus()
sd, err := c.scalarApi.NewHint(decomposeScalarG1, 7, s, c.eigenvalue, frModulus)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
s1, s2, s3, s4, s5, s6 := sd[0], sd[1], sd[3], sd[4], sd[5], sd[6]

td, err := c.scalarApi.NewHint(decomposeScalarG1, 7, t, c.eigenvalue, frModulus)
if err != nil {
// err is non-nil only for invalid number of inputs
panic(err)
}
t1, t2, t3, t4, t5, t6 := td[0], td[1], td[3], td[4], td[5], td[6]

c.scalarApi.AssertIsEqual(
c.scalarApi.Add(s5, c.scalarApi.Mul(s6, c.eigenvalue)),
c.scalarApi.Add(s, c.scalarApi.Mul(frModulus, sd[2])),
)
c.scalarApi.AssertIsEqual(
c.scalarApi.Add(t5, c.scalarApi.Mul(t6, c.eigenvalue)),
c.scalarApi.Add(t, c.scalarApi.Mul(frModulus, td[2])),
)

// s1, s2 can be negative (bigints) in the hint, but will be reduced
// in-circuit modulo the SNARK scalar field and not the emulated field. So
// we return in the hint both |s1|, |s2| and the flags s3=0/1, s4=0/1 to
// negate the point instead of the corresponding scalar. Since s3, s4 are
// either 0 or 1, we only need to check the first limb is zero or not.
// Same goes for t1, t2.
selector1 := c.api.IsZero(s3.Limbs[0])
selector2 := c.api.IsZero(s4.Limbs[0])
selector3 := c.api.IsZero(t3.Limbs[0])
selector4 := c.api.IsZero(t4.Limbs[0])

// precompute -Q, -Φ(Q), Φ(Q)
var tableQ, tablePhiQ [2]*AffinePoint[B]
negQY := c.baseApi.Neg(&Q.Y)
tableQ[1] = &AffinePoint[B]{
X: Q.X,
Y: *c.baseApi.Select(selector1, negQY, &Q.Y),
}
tableQ[0] = c.Neg(tableQ[1])
tablePhiQ[1] = &AffinePoint[B]{
X: *c.baseApi.Mul(&Q.X, c.thirdRootOne),
Y: *c.baseApi.Select(selector2, negQY, &Q.Y),
}
tablePhiQ[0] = c.Neg(tablePhiQ[1])
// precompute -R, -Φ(R), Φ(R)
R := c.Generator()
negRY := c.baseApi.Neg(&R.Y)
var tableR, tablePhiR [2]*AffinePoint[B]
tableR[1] = &AffinePoint[B]{
X: R.X,
Y: *c.baseApi.Select(selector3, negRY, &R.Y),
}
tableR[0] = c.Neg(tableR[1])
tablePhiR[1] = &AffinePoint[B]{
X: *c.baseApi.Mul(&R.X, c.thirdRootOne),
Y: *c.baseApi.Select(selector4, negRY, &R.Y),
}
tablePhiR[0] = c.Neg(tablePhiR[1])
// precompute Q+R, -Q-R, Q-R, -Q+R, Φ(Q)+Φ(R), -Φ(Q)-Φ(R), Φ(Q)-Φ(R), -Φ(Q)+Φ(R)
var tableS, tablePhiS [4]*AffinePoint[B]
tableS[0] = tableQ[0]
tableS[0] = c.Add(tableS[0], tableR[0])
tableS[1] = c.Neg(tableS[0])
tableS[2] = tableQ[1]
tableS[2] = c.Add(tableS[2], tableR[0])
tableS[3] = c.Neg(tableS[2])
f0 := c.baseApi.Mul(&tableS[0].X, c.thirdRootOne)
f2 := c.baseApi.Mul(&tableS[2].X, c.thirdRootOne)
tablePhiS[0] = &AffinePoint[B]{
X: *c.baseApi.Select(c.api.Xor(selector2, selector4), f2, f0),
Y: *c.baseApi.Lookup2(selector2, selector4, &tableS[0].Y, &tableS[2].Y, &tableS[3].Y, &tableS[1].Y),
}
tablePhiS[1] = c.Neg(tablePhiS[0])
tablePhiS[2] = &AffinePoint[B]{
X: *c.baseApi.Select(c.api.Xor(selector2, selector4), f0, f2),
Y: *c.baseApi.Lookup2(selector2, selector4, &tableS[2].Y, &tableS[0].Y, &tableS[1].Y, &tableS[3].Y),
}
tablePhiS[3] = c.Neg(tablePhiS[2])

// suppose first bit is 1 and set:
// Acc = Q + R + Φ(Q) + Φ(R)
Acc := c.Add(tableS[1], tablePhiS[1])
B1 := Acc
// and then add R to it to avoid incomplete additions in the loop, because
// when doing doubleAndAdd(Acc, Bi) as (Acc+Bi)+Acc it might happen that
// Acc==Bi. But now we force Acc to be different than the stored Bi.
// However we need to subtract [2^{nbits-1}]R (hardcoded) from the result.
//
// Acc = Q + [2]R + Φ(Q) + Φ(R)
Acc = c.Add(Acc, tableR[1])

s1bits := c.scalarApi.ToBits(s1)
s2bits := c.scalarApi.ToBits(s2)
t1bits := c.scalarApi.ToBits(t1)
t2bits := c.scalarApi.ToBits(t2)
nbits := st.Modulus().BitLen()>>1 + 2

// At each iteration look up the point P from:
// B1 = +Q + R + Φ(Q) + Φ(R)
// B2 = +Q + R + Φ(Q) - Φ(R)
B2 := c.Add(tableS[1], tablePhiS[2])
// B3 = +Q + R - Φ(Q) + Φ(R)
B3 := c.Add(tableS[1], tablePhiS[3])
// B4 = +Q + R - Φ(Q) - Φ(R)
B4 := c.Add(tableS[1], tablePhiS[0])
// B5 = +Q - R + Φ(Q) + Φ(R)
B5 := c.Add(tableS[2], tablePhiS[1])
// B6 = +Q - R + Φ(Q) - Φ(R)
B6 := c.Add(tableS[2], tablePhiS[2])
// B7 = +Q - R - Φ(Q) + Φ(R)
B7 := c.Add(tableS[2], tablePhiS[3])
// B8 = +Q - R - Φ(Q) - Φ(R)
B8 := c.Add(tableS[2], tablePhiS[0])
// B9 = -Q + R + Φ(Q) + Φ(R)
B9 := c.Neg(B8)
// B10 = -Q + R + Φ(Q) - Φ(R)
B10 := c.Neg(B7)
// B11 = -Q + R - Φ(Q) + Φ(R)
B11 := c.Neg(B6)
// B12 = -Q + R - Φ(Q) - Φ(R)
B12 := c.Neg(B5)
// B13 = -Q - R + Φ(Q) + Φ(R)
B13 := c.Neg(B4)
// B14 = -Q - R + Φ(Q) - Φ(R)
B14 := c.Neg(B3)
// B15 = -Q - R - Φ(Q) + Φ(R)
B15 := c.Neg(B2)
// B16 = -Q - R - Φ(Q) - Φ(R)
B16 := c.Neg(B1)
// and compute [2]Acc+P with doubleAndAdd. Its edge cases (Acc==Bi) are avoided by the extra R added to Acc above.
// Note: half of the Bi points have the same X coordinates.

var P *AffinePoint[B]
for i := nbits - 1; i > 0; i-- {
// selectorY takes values in [0,15]
selectorY := c.api.Add(
s1bits[i],
c.api.Mul(s2bits[i], 2),
c.api.Mul(t1bits[i], 4),
c.api.Mul(t2bits[i], 8),
)
// selectorX takes values in [0,7] with:
// - when selectorY < 8: selectorX = selectorY
// - when selectorY >= 8: selectorX = 15 - selectorY
selectorX := c.api.Add(
c.api.Mul(selectorY, c.api.Sub(1, c.api.Mul(t2bits[i], 2))),
c.api.Mul(t2bits[i], 15),
)
// Bi.Y are distinct so we need a 16-to-1 multiplexer,
// but only half of the Bi.X are distinct so we need a 8-to-1.
P = &AffinePoint[B]{
X: *c.baseApi.Mux(selectorX,
&B16.X, &B8.X, &B14.X, &B6.X, &B12.X, &B4.X, &B10.X, &B2.X,
),
Y: *c.baseApi.Mux(selectorY,
&B16.Y, &B8.Y, &B14.Y, &B6.Y, &B12.Y, &B4.Y, &B10.Y, &B2.Y,
&B15.Y, &B7.Y, &B13.Y, &B5.Y, &B11.Y, &B3.Y, &B9.Y, &B1.Y,
),
}
Acc = c.doubleAndAdd(Acc, P)
}

// i = 0
// subtract the initial point from the accumulator when first bit was 0
tableQ[0] = c.Add(tableQ[0], Acc)
Acc = c.Select(s1bits[0], Acc, tableQ[0])
tablePhiQ[0] = c.Add(tablePhiQ[0], Acc)
Acc = c.Select(s2bits[0], Acc, tablePhiQ[0])
tableR[0] = c.Add(tableR[0], Acc)
Acc = c.Select(t1bits[0], Acc, tableR[0])
tablePhiR[0] = c.Add(tablePhiR[0], Acc)
Acc = c.Select(t2bits[0], Acc, tablePhiR[0])

// subtract [2^{nbits-1}]R since we added R at the beginning
gm := c.GeneratorMultiples()[nbits-1]
Acc = c.Add(Acc, c.Neg(&gm))

return Acc

func (c *Curve[B, S]) JointScalarMulBase(p *AffinePoint[B], s2, s1 *emulated.Element[S], opts ...algopts.AlgebraOption) *AffinePoint[B] {
	// The generator is passed first, so the scalars are passed as (s1, s2):
	// s1 goes with the generator and s2 with p.
	g := c.Generator()
	return c.jointScalarMul(g, p, s1, s2, opts...)
}

// MultiScalarMul computes the multi scalar multiplication of the points P and
Expand Down

0 comments on commit c7d831d

Please sign in to comment.