Skip to content

Commit

Permalink
Perf: Improve MultiLin.Eval number of constraints (#788)
Browse files Browse the repository at this point in the history
* bench: multilin eval constraints number

* perf: fewer multilin folding constraints

* fix: correct nb constraints

* fix: panic if error

* perf: sometimes defer scaling of folding results
  • Loading branch information
Tabaie committed Jul 26, 2023
1 parent 860db7c commit b39b13f
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 12 deletions.
65 changes: 53 additions & 12 deletions std/polynomial/polynomial.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,71 @@ import (
type Polynomial []frontend.Variable
type MultiLin []frontend.Variable

var minFoldScaledLogSize = 16

// Evaluate assumes len(m) = 1 << len(at)
// it doesn't modify m
func (m MultiLin) Evaluate(api frontend.API, at []frontend.Variable) frontend.Variable {

eqs := make([]frontend.Variable, len(m))
eqs[0] = 1
for i, rI := range at {
prevSize := 1 << i
for j := prevSize - 1; j >= 0; j-- {
eqs[2*j+1] = api.Mul(rI, eqs[j])
eqs[2*j] = api.Sub(eqs[j], eqs[2*j+1]) // eq[2j] == (1 - rI) * eq[j]
_m := m.Clone()

/*minFoldScaledLogSize := 16
if api is r1cs {
minFoldScaledLogSize = math.MaxInt64 // no scaling for r1cs
}*/

scaleCorrectionFactor := frontend.Variable(1)
// at each iteration fold by at[i]
for len(_m) > 1 {
if len(_m) >= minFoldScaledLogSize {
scaleCorrectionFactor = api.Mul(scaleCorrectionFactor, _m.foldScaled(api, at[0]))
} else {
_m.fold(api, at[0])
}
_m = _m[:len(_m)/2]
at = at[1:]
}

if len(at) != 0 {
panic("incompatible evaluation vector size")
}

return api.Mul(_m[0], scaleCorrectionFactor)
}

// fold fixes the value of m's first variable to at, thus halving m's required bookkeeping table size
// WARNING: The user should halve m themselves after the call
func (m MultiLin) fold(api frontend.API, at frontend.Variable) {
zero := m[:len(m)/2]
one := m[len(m)/2:]
for j := range zero {
diff := api.Sub(one[j], zero[j])
zero[j] = api.MulAcc(zero[j], diff, at)
}
}

evaluation := frontend.Variable(0)
for j := range m {
evaluation = api.MulAcc(evaluation, eqs[j], m[j])
// foldScaled(m, at) = fold(m, at) / (1 - at)
// it returns 1 - at, for convenience
func (m MultiLin) foldScaled(api frontend.API, at frontend.Variable) (denom frontend.Variable) {
denom = api.Sub(1, at)
coeff := api.Div(at, denom)
zero := m[:len(m)/2]
one := m[len(m)/2:]
for j := range zero {
zero[j] = api.MulAcc(zero[j], one[j], coeff)
}
return evaluation
return
}

func (m MultiLin) NumVars() int {
return bits.TrailingZeros(uint(len(m)))
}

func (m MultiLin) Clone() MultiLin {
clone := make(MultiLin, len(m))
copy(clone, m)
return clone
}

func (p Polynomial) Eval(api frontend.API, at frontend.Variable) (pAt frontend.Variable) {
pAt = 0

Expand Down
50 changes: 50 additions & 0 deletions std/polynomial/polynomial_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package polynomial

import (
"errors"
"fmt"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/test"
"testing"
)
Expand Down Expand Up @@ -70,6 +73,31 @@ func TestEvalDeltasQuadratic(t *testing.T) {
testEvalDeltas(t, 3, []int64{1, -3, 3})
}

type foldMultiLinCircuit struct {
M []frontend.Variable
At frontend.Variable
Result []frontend.Variable
}

func (c *foldMultiLinCircuit) Define(api frontend.API) error {
if len(c.M) != 2*len(c.Result) {
return errors.New("folding size mismatch")
}
m := MultiLin(c.M)
m.fold(api, c.At)
for i := range c.Result {
api.AssertIsEqual(m[i], c.Result[i])
}
return nil
}

func TestFoldSmall(t *testing.T) {
test.NewAssert(t).SolvingSucceeded(
&foldMultiLinCircuit{M: make([]frontend.Variable, 4), Result: make([]frontend.Variable, 2)},
&foldMultiLinCircuit{M: []frontend.Variable{0, 1, 2, 3}, At: 2, Result: []frontend.Variable{4, 5}},
)
}

type evalMultiLinCircuit struct {
M []frontend.Variable `gnark:",public"`
At []frontend.Variable `gnark:",secret"`
Expand Down Expand Up @@ -204,3 +232,25 @@ func int64SliceToVariableSlice(slice []int64) []frontend.Variable {
}
return res
}

func ExampleMultiLin_Evaluate() {
const logSize = 20
const size = 1 << logSize
m := MultiLin(make([]frontend.Variable, size))
e := MultiLin(make([]frontend.Variable, logSize))

cs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &evalMultiLinCircuit{M: m, At: e, Evaluation: 0})
if err != nil {
panic(err)
}
fmt.Println("r1cs size:", cs.GetNbConstraints())

cs, err = frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &evalMultiLinCircuit{M: m, At: e, Evaluation: 0})
if err != nil {
panic(err)
}
fmt.Println("scs size:", cs.GetNbConstraints())

// Output: r1cs size: 1048627
//scs size: 2097226
}

0 comments on commit b39b13f

Please sign in to comment.