Skip to content

Commit

Permalink
Merge pull request #506 from ConsenSys/perf/kzg-in-circuit
Browse files Browse the repository at this point in the history
Perf: KZG in circuit
  • Loading branch information
yelhousni committed Mar 2, 2023
2 parents c866c30 + 9b4739d commit 3149bf5
Show file tree
Hide file tree
Showing 19 changed files with 582 additions and 16 deletions.
13 changes: 13 additions & 0 deletions std/algebra/fields_bls12377/e2.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,16 @@ func (e *E2) Select(api frontend.API, b frontend.Variable, r1, r2 E2) *E2 {

return e
}

// Lookup2 implements two-bit lookup. It returns:
// - r1 if b1=0 and b2=0,
// - r2 if b1=0 and b2=1,
// - r3 if b1=1 and b2=0,
// - r3 if b1=1 and b2=1.
func (e *E2) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E2) *E2 {

e.A0 = api.Lookup2(b1, b2, r1.A0, r2.A0, r3.A0, r4.A0)
e.A1 = api.Lookup2(b1, b2, r1.A1, r2.A1, r3.A1, r4.A1)

return e
}
13 changes: 13 additions & 0 deletions std/algebra/fields_bls24315/e2.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,16 @@ func (e *E2) Select(api frontend.API, b frontend.Variable, r1, r2 E2) *E2 {

return e
}

// Lookup2 implements two-bit lookup. It returns:
// - r1 if b1=0 and b2=0,
// - r2 if b1=0 and b2=1,
// - r3 if b1=1 and b2=0,
// - r3 if b1=1 and b2=1.
func (e *E2) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E2) *E2 {

e.A0 = api.Lookup2(b1, b2, r1.A0, r2.A0, r3.A0, r4.A0)
e.A1 = api.Lookup2(b1, b2, r1.A1, r2.A1, r3.A1, r4.A1)

return e
}
13 changes: 13 additions & 0 deletions std/algebra/fields_bls24315/e4.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,16 @@ func (e *E4) Select(api frontend.API, b frontend.Variable, r1, r2 E4) *E4 {

return e
}

// Lookup2 implements two-bit lookup. It returns:
// - r1 if b1=0 and b2=0,
// - r2 if b1=0 and b2=1,
// - r3 if b1=1 and b2=0,
// - r3 if b1=1 and b2=1.
func (e *E4) Lookup2(api frontend.API, b1, b2 frontend.Variable, r1, r2, r3, r4 E4) *E4 {

e.B0.Lookup2(api, b1, b2, r1.B0, r2.B0, r3.B0, r4.B0)
e.B1.Lookup2(api, b1, b2, r1.B1, r2.B1, r3.B1, r4.B1)

return e
}
33 changes: 33 additions & 0 deletions std/algebra/sw_bls12377/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,36 @@ func (p *G1Affine) DoubleAndAdd(api frontend.API, p1, p2 *G1Affine) *G1Affine {

return p
}

// ScalarMulBase computes s * g1 and returns it, where g1 is the fixed generator. It doesn't modify s.
func (P *G1Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G1Affine {

points := getCurvePoints()

sBits := api.ToBinary(s, 253)

var res, tmp G1Affine

// i = 1, 2
// gm[0] = 3g, gm[1] = 5g, gm[2] = 7g
res.X = api.Lookup2(sBits[1], sBits[2], points.G1x, points.G1m[0][0], points.G1m[1][0], points.G1m[2][0])
res.Y = api.Lookup2(sBits[1], sBits[2], points.G1y, points.G1m[0][1], points.G1m[1][1], points.G1m[2][1])

for i := 3; i < 253; i++ {
// gm[i] = [2^i]g
tmp.X = res.X
tmp.Y = res.Y
tmp.AddAssign(api, G1Affine{points.G1m[i][0], points.G1m[i][1]})
res.Select(api, sBits[i], tmp, res)
}

// i = 0
tmp.Neg(api, G1Affine{points.G1x, points.G1y})
tmp.AddAssign(api, res)
res.Select(api, sBits[0], res, tmp)

P.X = res.X
P.Y = res.Y

return P
}
31 changes: 31 additions & 0 deletions std/algebra/sw_bls12377/g1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,37 @@ func TestScalarMulG1(t *testing.T) {
assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761))
}

type g1varScalarMulBase struct {
C G1Affine `gnark:",public"`
R frontend.Variable
}

func (circuit *g1varScalarMulBase) Define(api frontend.API) error {
expected := G1Affine{}
expected.ScalarMulBase(api, circuit.R)
expected.AssertIsEqual(api, circuit.C)
return nil
}

func TestVarScalarMulBaseG1(t *testing.T) {
var c bls12377.G1Affine
gJac, _, _, _ := bls12377.Generators()

// create the cs
var circuit, witness g1varScalarMulBase
var r fr.Element
_, _ = r.SetRandom()
witness.R = r.String()
// compute the result
var br big.Int
gJac.ScalarMultiplication(&gJac, r.BigInt(&br))
c.FromJacobian(&gJac)
witness.C.Assign(&c)

assert := test.NewAssert(t)
assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761))
}

func randomPointG1() bls12377.G1Jac {

p1, _, _, _ := bls12377.Generators()
Expand Down
65 changes: 65 additions & 0 deletions std/algebra/sw_bls12377/g2.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,68 @@ func (p *G2Affine) DoubleAndAdd(api frontend.API, p1, p2 *G2Affine) *G2Affine {

return p
}

// ScalarMulBase computes s * g2 and returns it, where g2 is the fixed generator. It doesn't modify s.
func (P *G2Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G2Affine {

points := getTwistPoints()

sBits := api.ToBinary(s, 253)

var res, tmp G2Affine

// i = 1, 2
// gm[0] = 3g, gm[1] = 5g, gm[2] = 7g
res.X.Lookup2(api, sBits[1], sBits[2],
fields_bls12377.E2{
A0: points.G2x[0],
A1: points.G2x[1]},
fields_bls12377.E2{
A0: points.G2m[0][0],
A1: points.G2m[0][1]},
fields_bls12377.E2{
A0: points.G2m[1][0],
A1: points.G2m[1][1]},
fields_bls12377.E2{
A0: points.G2m[2][0],
A1: points.G2m[2][1]})
res.Y.Lookup2(api, sBits[1], sBits[2],
fields_bls12377.E2{
A0: points.G2y[0],
A1: points.G2y[1]},
fields_bls12377.E2{
A0: points.G2m[0][2],
A1: points.G2m[0][3]},
fields_bls12377.E2{
A0: points.G2m[1][2],
A1: points.G2m[1][3]},
fields_bls12377.E2{
A0: points.G2m[2][2],
A1: points.G2m[2][3]})

for i := 3; i < 253; i++ {
// gm[i] = [2^i]g
tmp.X = res.X
tmp.Y = res.Y
tmp.AddAssign(api, G2Affine{
fields_bls12377.E2{
A0: points.G2m[i][0],
A1: points.G2m[i][1]},
fields_bls12377.E2{
A0: points.G2m[i][2],
A1: points.G2m[i][3]}})
res.Select(api, sBits[i], tmp, res)
}

// i = 0
tmp.Neg(api, G2Affine{
fields_bls12377.E2{A0: points.G2x[0], A1: points.G2x[1]},
fields_bls12377.E2{A0: points.G2y[0], A1: points.G2y[1]}})
tmp.AddAssign(api, res)
res.Select(api, sBits[0], res, tmp)

P.X = res.X
P.Y = res.Y

return P
}
32 changes: 32 additions & 0 deletions std/algebra/sw_bls12377/g2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,38 @@ func TestScalarMulG2(t *testing.T) {
assert := test.NewAssert(t)
assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761))
}

type g2varScalarMulBase struct {
C G2Affine `gnark:",public"`
R frontend.Variable
}

func (circuit *g2varScalarMulBase) Define(api frontend.API) error {
expected := G2Affine{}
expected.ScalarMulBase(api, circuit.R)
expected.AssertIsEqual(api, circuit.C)
return nil
}

func TestVarScalarMulBaseG2(t *testing.T) {
var c bls12377.G2Affine
_, gJac, _, _ := bls12377.Generators()

// create the cs
var circuit, witness g2varScalarMulBase
var r fr.Element
_, _ = r.SetRandom()
witness.R = r.String()
// compute the result
var br big.Int
gJac.ScalarMultiplication(&gJac, r.BigInt(&br))
c.FromJacobian(&gJac)
witness.C.Assign(&c)

assert := test.NewAssert(t)
assert.SolvingSucceeded(&circuit, &witness, test.WithCurves(ecc.BW6_761))
}

func randomPointG2() bls12377.G2Jac {
_, p2, _, _ := bls12377.Generators()

Expand Down
42 changes: 42 additions & 0 deletions std/algebra/sw_bls12377/inner.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/consensys/gnark-crypto/ecc"
bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
"github.com/consensys/gnark/frontend"
)

Expand Down Expand Up @@ -62,3 +63,44 @@ func getInnerCurveConfig(outerCurveScalarField *big.Int) *innerConfig {

return &innerConfigBW6_761
}

var (
computedCurveTable [][2]*big.Int
computedTwistTable [][4]*big.Int
)

func init() {
computedCurveTable = computeCurveTable()
computedTwistTable = computeTwistTable()
}

type curvePoints struct {
G1x *big.Int // base point x
G1y *big.Int // base point y
G1m [][2]*big.Int // m*base points (x,y)
}

func getCurvePoints() curvePoints {
_, _, g1aff, _ := bls12377.Generators()
return curvePoints{
G1x: g1aff.X.BigInt(new(big.Int)),
G1y: g1aff.Y.BigInt(new(big.Int)),
G1m: computedCurveTable,
}
}

type twistPoints struct {
G2x [2]*big.Int // base point x ∈ E2
G2y [2]*big.Int // base point y ∈ E2
G2m [][4]*big.Int // m*base points (x,y)
}

func getTwistPoints() twistPoints {
_, _, _, g2aff := bls12377.Generators()
return twistPoints{
G2x: [2]*big.Int{g2aff.X.A0.BigInt(new(big.Int)), g2aff.X.A1.BigInt(new(big.Int))},
G2y: [2]*big.Int{g2aff.Y.A0.BigInt(new(big.Int)), g2aff.Y.A1.BigInt(new(big.Int))},
G2m: computedTwistTable,
}

}
59 changes: 59 additions & 0 deletions std/algebra/sw_bls12377/inner_compute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package sw_bls12377

import (
"math/big"

bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377"
)

func computeCurveTable() [][2]*big.Int {
G1jac, _, _, _ := bls12377.Generators()
table := make([][2]*big.Int, 253)
tmp := new(bls12377.G1Jac).Set(&G1jac)
aff := new(bls12377.G1Affine)
jac := new(bls12377.G1Jac)
for i := 1; i < 253; i++ {
tmp = tmp.Double(tmp)
switch i {
case 1, 2:
jac.Set(tmp).AddAssign(&G1jac)
aff.FromJacobian(jac)
table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))}
case 3:
jac.Set(tmp).SubAssign(&G1jac)
aff.FromJacobian(jac)
table[i-1] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))}
fallthrough
default:
aff.FromJacobian(tmp)
table[i] = [2]*big.Int{aff.X.BigInt(new(big.Int)), aff.Y.BigInt(new(big.Int))}
}
}
return table[:]
}

func computeTwistTable() [][4]*big.Int {
_, G2jac, _, _ := bls12377.Generators()
table := make([][4]*big.Int, 253)
tmp := new(bls12377.G2Jac).Set(&G2jac)
aff := new(bls12377.G2Affine)
jac := new(bls12377.G2Jac)
for i := 1; i < 253; i++ {
tmp = tmp.Double(tmp)
switch i {
case 1, 2:
jac.Set(tmp).AddAssign(&G2jac)
aff.FromJacobian(jac)
table[i-1] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))}
case 3:
jac.Set(tmp).SubAssign(&G2jac)
aff.FromJacobian(jac)
table[i-1] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))}
fallthrough
default:
aff.FromJacobian(tmp)
table[i] = [4]*big.Int{aff.X.A0.BigInt(new(big.Int)), aff.X.A1.BigInt(new(big.Int)), aff.Y.A0.BigInt(new(big.Int)), aff.Y.A1.BigInt(new(big.Int))}
}
}
return table[:]
}
33 changes: 33 additions & 0 deletions std/algebra/sw_bls24315/g1.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,36 @@ func (p *G1Affine) DoubleAndAdd(api frontend.API, p1, p2 *G1Affine) *G1Affine {

return p
}

// ScalarMulBase computes s * g1 and returns it, where g1 is the fixed generator. It doesn't modify s.
func (P *G1Affine) ScalarMulBase(api frontend.API, s frontend.Variable) *G1Affine {

points := getCurvePoints()

sBits := api.ToBinary(s, 253)

var res, tmp G1Affine

// i = 1, 2
// gm[0] = 3g, gm[1] = 5g, gm[2] = 7g
res.X = api.Lookup2(sBits[1], sBits[2], points.G1x, points.G1m[0][0], points.G1m[1][0], points.G1m[2][0])
res.Y = api.Lookup2(sBits[1], sBits[2], points.G1y, points.G1m[0][1], points.G1m[1][1], points.G1m[2][1])

for i := 3; i < 253; i++ {
// gm[i] = [2^i]g
tmp.X = res.X
tmp.Y = res.Y
tmp.AddAssign(api, G1Affine{points.G1m[i][0], points.G1m[i][1]})
res.Select(api, sBits[i], tmp, res)
}

// i = 0
tmp.Neg(api, G1Affine{points.G1x, points.G1y})
tmp.AddAssign(api, res)
res.Select(api, sBits[0], res, tmp)

P.X = res.X
P.Y = res.Y

return P
}
Loading

0 comments on commit 3149bf5

Please sign in to comment.