diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index cf43109b60..3973b20f0a 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -30,16 +30,6 @@ jobs: run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: gofmt run: if [[ -n $(gofmt -l .) ]]; then echo "please run gofmt"; exit 1; fi - - name: go vet - run: go vet ./... - - name: staticcheck - run: | - go install honnef.co/go/tools/cmd/staticcheck@23e1086441d24fed9f668ad1cd4374245118b590 - staticcheck ./... - - name: gosec - run: | - go install github.com/securego/gosec/v2/cmd/gosec@latest - gosec -exclude G204 ./... - name: generated files should not be modified run: | go generate ./... diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index d33dc30d24..fcdaecced0 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -34,16 +34,6 @@ jobs: run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: gofmt run: if [[ -n $(gofmt -l .) ]]; then echo "please run gofmt"; exit 1; fi - - name: go vet - run: go vet ./... - - name: staticcheck - run: | - go install honnef.co/go/tools/cmd/staticcheck@23e1086441d24fed9f668ad1cd4374245118b590 - staticcheck ./... - - name: gosec - run: | - go install github.com/securego/gosec/v2/cmd/gosec@latest - gosec -exclude G204 ./... - name: generated files should not be modified run: | go generate ./... diff --git a/backend/plonkfri/plonkfri.go b/backend/plonkfri/plonkfri.go index b69f1a0acf..dd541e3e6d 100644 --- a/backend/plonkfri/plonkfri.go +++ b/backend/plonkfri/plonkfri.go @@ -41,6 +41,10 @@ import ( witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness" witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness" witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness" + + cs_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/cs" + plonk_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/plonkfri" + witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness" ) // Proof represents a Plonk proof generated by plonk.Prove @@ -86,6 +90,8 @@ func Setup(ccs frontend.CompiledConstraintSystem) (ProvingKey, VerifyingKey, err return plonk_bls24315.Setup(tccs) case *cs_bw6633.SparseR1CS: return plonk_bw6633.Setup(tccs) + case *cs_bls24317.SparseR1CS: + return plonk_bls24317.Setup(tccs) default: panic("unrecognized SparseR1CS curve type") } @@ -147,7 +153,12 @@ func Prove(ccs frontend.CompiledConstraintSystem, pk ProvingKey, fullWitness *wi return nil, witness.ErrInvalidWitness } return plonk_bls24315.Prove(tccs, pk.(*plonk_bls24315.ProvingKey), *w, opt) - + case *cs_bls24317.SparseR1CS: + w, ok := fullWitness.Vector.(*witness_bls24317.Witness) + if !ok { + return nil, witness.ErrInvalidWitness + } + return plonk_bls24317.Prove(tccs, pk.(*plonk_bls24317.ProvingKey), *w, opt) default: panic("unrecognized SparseR1CS curve type") } @@ -199,6 +210,12 @@ func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error return witness.ErrInvalidWitness } return plonk_bls24315.Verify(_proof, vk.(*plonk_bls24315.VerifyingKey), *w) + case *plonk_bls24317.Proof: + w, ok := publicWitness.Vector.(*witness_bls24317.Witness) + if !ok { + return witness.ErrInvalidWitness + } + return plonk_bls24317.Verify(_proof, vk.(*plonk_bls24317.VerifyingKey), *w) default: panic("unrecognized proof type") diff --git a/examples/emulated/emulated.go b/examples/emulated/emulated.go new file mode 100644 index 0000000000..d925b6df8a --- /dev/null +++ b/examples/emulated/emulated.go @@ -0,0 +1,23 @@ +package emulated + +import ( + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/math/emulated" +) + +type Circuit struct { + // Limbs of non-native elements X, Y and Res + X, Y, Res emulated.Element[emulated.Secp256k1] +} + +func (circuit *Circuit) Define(api frontend.API) error { + // wrap API to work in SECP256k1 scalar field + secp256k1, err := emulated.NewField[emulated.Secp256k1](api) + if err != nil { + return err + } + + tmp := secp256k1.Mul(circuit.X, circuit.Y) + secp256k1.AssertIsEqual(tmp, circuit.Res) + return nil +} diff --git a/examples/emulated/emulated_test.go b/examples/emulated/emulated_test.go new file mode 100644 index 0000000000..fa3b606dd4 --- /dev/null +++ b/examples/emulated/emulated_test.go @@ -0,0 +1,23 @@ +package emulated + +import ( + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/std" + "github.com/consensys/gnark/test" +) + +func TestEmulatedArithmetic(t *testing.T) { + assert := test.NewAssert(t) + std.RegisterHints() + + var circuit, witness Circuit + + witness.X.Assign("26959946673427741531515197488526605382048662297355296634326893985793") + witness.Y.Assign("53919893346855483063030394977053210764097324594710593268653787971586") + witness.Res.Assign("485279052387156144224396168012515269674445015885648619762653195154800") + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization()) +} diff --git a/frontend/schema/schema.go b/frontend/schema/schema.go index a2adb61b10..cac0ab481e 100644 --- a/frontend/schema/schema.go +++ b/frontend/schema/schema.go @@ -34,6 +34,12 @@ type Schema struct { // LeafHandler is the handler function that will be called when Visit reaches leafs of the struct type LeafHandler func(field *Field, tValue reflect.Value) error +// An object implementing an init hook knows how to "init" itself +// when parsed at compile time +type InitHook interface { + GnarkInitHook() // TODO @gbotrel find a better home for this +} + // Parse filters recursively input data struct and keeps only the fields containing slices, arrays of elements of // type frontend.Variable and return the corresponding Slices are converted to arrays. // @@ -278,6 +284,9 @@ func parse(r []Field, input interface{}, target reflect.Type, parentFullName, pa if fValue.CanAddr() && fValue.Addr().CanInterface() { value := fValue.Addr().Interface() + if ih, hasInitHook := value.(InitHook); hasInitHook { + ih.GnarkInitHook() + } var err error subFields, err = parse(subFields, value, target, getFullName(parentFullName, name, nameTag), name, nameTag, visibility, handler, nbPublic, nbSecret) if err != nil { diff --git a/go.mod b/go.mod index 3d6f34f309..48352d00b6 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/leanovate/gopter v0.2.9 github.com/rs/zerolog v1.26.1 github.com/stretchr/testify v1.7.1 + golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75 ) require ( diff --git a/go.sum b/go.sum index c7fb2e5952..2888afb28c 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20211215165025-cf75a172585e/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064 h1:S25/rfnfsMVgORT4/J61MJ7rdyseOZOyvLIrZEZ7s6s= golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75 h1:x03zeu7B2B11ySp+daztnwM5oBJ/8wGUSqrwcw9L0RA= +golang.org/x/exp v0.0.0-20220713135740-79cabaa25d75/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/internal/stats/latest.stats b/internal/stats/latest.stats index ccff5b1c6c..1bd31e94ad 100644 Binary files a/internal/stats/latest.stats and b/internal/stats/latest.stats differ diff --git a/internal/stats/snippet.go b/internal/stats/snippet.go index 7bcb086ee4..321d496594 100644 --- a/internal/stats/snippet.go +++ b/internal/stats/snippet.go @@ -11,6 +11,7 @@ import ( "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/hash/mimc" "github.com/consensys/gnark/std/math/bits" + "github.com/consensys/gnark/std/math/emulated" ) var ( @@ -78,6 +79,25 @@ func initSnippets() { mimc.Write(newVariable()) _ = mimc.Sum() }) + registerSnippet("math/emulated/secp256k1_32", func(api frontend.API, newVariable func() frontend.Variable) { + secp256k1, _ := emulated.NewField[emulated.Secp256k1](api) + + newElement := func() emulated.Element[emulated.Secp256k1] { + r := emulated.NewElement[emulated.Secp256k1](nil) + for i := 0; i < len(r.Limbs); i++ { + r.Limbs[i] = newVariable() + } + return r + } + + x13 := secp256k1.Mul(newElement(), newElement(), newElement()) + fx2 := secp256k1.Mul(5, newElement()) + nom := secp256k1.Sub(fx2, x13) + denom := secp256k1.Add(newElement(), newElement(), newElement(), newElement()) + free := secp256k1.Div(nom, denom) + res := secp256k1.Add(x13, fx2, free) + secp256k1.AssertIsEqual(res, newElement()) + }) registerSnippet("pairing_bls12377", func(api frontend.API, newVariable func() frontend.Variable) { diff --git a/internal/stats/stats.go b/internal/stats/stats.go index c9773cccd7..d747af0827 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -12,7 +12,6 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" - "github.com/consensys/gnark/logger" ) const nbCurves = 7 @@ -46,7 +45,7 @@ func init() { func NewGlobalStats() *globalStats { return &globalStats{ - Stats: make(map[string][backend.PLONK + 1][nbCurves + 1]snippetStats), + Stats: make(map[string][backend.PLONKFRI + 1][nbCurves + 1]snippetStats), } } @@ -80,7 +79,7 @@ func NewSnippetStats(curve ecc.ID, backendID backend.ID, circuit frontend.Circui switch backendID { case backend.GROTH16: newCompiler = r1cs.NewBuilder - case backend.PLONK: + case backend.PLONK, backend.PLONKFRI: newCompiler = scs.NewBuilder default: panic("not implemented") @@ -101,11 +100,6 @@ func NewSnippetStats(curve ecc.ID, backendID backend.ID, circuit frontend.Circui func (s *globalStats) Add(curve ecc.ID, backendID backend.ID, cs snippetStats, circuitName string) { s.Lock() defer s.Unlock() - if backendID == backend.PLONKFRI { - log := logger.Logger() - log.Warn().Msg("ignoring plonk_fri circuit") - return - } rs := s.Stats[circuitName] rs[backendID][CurveIdx(curve)] = cs s.Stats[circuitName] = rs @@ -118,7 +112,7 @@ type Circuit struct { type globalStats struct { sync.RWMutex - Stats map[string][backend.PLONK + 1][nbCurves + 1]snippetStats + Stats map[string][backend.PLONKFRI + 1][nbCurves + 1]snippetStats } type snippetStats struct { diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index b44fcf88cc..b536eda5be 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -29,10 +29,6 @@ func TestCircuitStatistics(t *testing.T) { for _, b := range backend.Implemented() { curve := curve backendID := b - if backendID == backend.PLONKFRI { - // TODO - continue - } name := name // copy the circuit now in case assert calls t.Parallel() circuit := c.Circuit diff --git a/std/hints.go b/std/hints.go index 877f22cb90..fda988d553 100644 --- a/std/hints.go +++ b/std/hints.go @@ -7,7 +7,7 @@ import ( "github.com/consensys/gnark/std/algebra/sw_bls12377" "github.com/consensys/gnark/std/algebra/sw_bls24315" "github.com/consensys/gnark/std/math/bits" - "github.com/consensys/gnark/std/math/nonnative" + "github.com/consensys/gnark/std/math/emulated" ) var registerOnce sync.Once @@ -30,5 +30,5 @@ func registerHints() { hint.Register(bits.NNAF) hint.Register(bits.IthBit) hint.Register(bits.NBits) - hint.Register(nonnative.GetHints()...) + hint.Register(emulated.GetHints()...) } diff --git a/std/math/nonnative/composition.go b/std/math/emulated/composition.go similarity index 51% rename from std/math/nonnative/composition.go rename to std/math/emulated/composition.go index 2218da9f67..0ca4986de8 100644 --- a/std/math/nonnative/composition.go +++ b/std/math/emulated/composition.go @@ -1,4 +1,4 @@ -package nonnative +package emulated import ( "fmt" @@ -25,6 +25,7 @@ func recompose(inputs []*big.Int, nbBits uint, res *big.Int) error { res.Lsh(res, nbBits) res.Add(res, inputs[len(inputs)-i-1]) } + // TODO @gbotrel mod reduce ? return nil } @@ -47,7 +48,7 @@ func decompose(input *big.Int, nbBits uint, res []*big.Int) error { base := new(big.Int).Lsh(big.NewInt(1), nbBits) tmp := new(big.Int).Set(input) for i := 0; i < len(res); i++ { - res[i] = new(big.Int).Mod(tmp, base) + res[i].Mod(tmp, base) tmp.Rsh(tmp, nbBits) } return nil @@ -67,75 +68,84 @@ func decompose(input *big.Int, nbBits uint, res []*big.Int) error { // // then no such underflow happens and s = a-b (mod p) as the padding is multiple // of p. -func subPadding(params *Params, current_overflow uint, nbLimbs uint) []*big.Int { - padLimbs := make([]*big.Int, nbLimbs) - for i := 0; i < len(padLimbs); i++ { - padLimbs[i] = new(big.Int).Lsh(big.NewInt(1), uint(current_overflow)+params.nbBits) +func subPadding[T FieldParams](overflow uint, nbLimbs uint) []*big.Int { + var fp T + p := fp.Modulus() + bitsPerLimbs := fp.BitsPerLimb() + + // first, we build a number nLimbs, such that nLimbs > b; + // here b is defined by its bounds, that is b is an element with nbLimbs of (bitsPerLimbs+overflow) + // so a number nLimbs > b, is simply taking the next power of 2 over this bound . + nLimbs := make([]*big.Int, nbLimbs) + for i := 0; i < len(nLimbs); i++ { + nLimbs[i] = new(big.Int).SetUint64(1) + nLimbs[i].Lsh(nLimbs[i], overflow+bitsPerLimbs) } - pad := new(big.Int) - if err := recompose(padLimbs, params.nbBits, pad); err != nil { + + // recompose n as the sum of the coefficients weighted by the limbs + n := new(big.Int) + if err := recompose(nLimbs, bitsPerLimbs, n); err != nil { panic(fmt.Sprintf("recompose: %v", err)) } - pad.Mod(pad, params.r) - pad.Sub(params.r, pad) - ret := make([]*big.Int, nbLimbs) - for i := range ret { - ret[i] = new(big.Int) + // mod reduce n, and negate it + n.Mod(n, p) + n.Sub(p, n) + + // construct pad such that: + // pad := n - neg(n mod p) == kp + pad := make([]*big.Int, nbLimbs) + for i := range pad { + pad[i] = new(big.Int) } - if err := decompose(pad, params.nbBits, ret); err != nil { + if err := decompose(n, bitsPerLimbs, pad); err != nil { panic(fmt.Sprintf("decompose: %v", err)) } - for i := range ret { - ret[i].Add(ret[i], padLimbs[i]) + for i := range pad { + pad[i].Add(pad[i], nLimbs[i]) } - return ret + return pad } -// regroupParams returns parameters which allow for most optimal regrouping of +// compact returns parameters which allow for most optimal regrouping of // limbs. In regrouping the limbs, we encode multiple existing limbs as a linear // combination in a single new limb. -func regroupParams(params *Params, nbNativeBits, nbMaxOverflow uint) *Params { +// compact returns a and b minimal (in number of limbs) representation that fits in the snark field +func (f *field[T]) compact(a, b Element[T]) (ac, bc []frontend.Variable, bitsPerLimb uint) { + maxOverflow := max(a.overflow, b.overflow) // subtract one bit as can not potentially use all bits of Fr and one bit as // grouping may overflow - maxFit := nbNativeBits - 2 - groupSize := (maxFit - nbMaxOverflow) / params.nbBits + maxNbBits := uint(f.api.Compiler().FieldBitLen()) - 2 - maxOverflow + groupSize := maxNbBits / a.fParams.BitsPerLimb() if groupSize == 0 { - // not sufficient space for regroup, return the same parameters. - return params - } - nbRegroupBits := params.nbBits * groupSize - nbRegroupLimbs := (params.nbLimbs + groupSize) / groupSize - return &Params{ - r: params.r, - hasInverses: params.hasInverses, - nbLimbs: nbRegroupLimbs, - nbBits: nbRegroupBits, + // no space for compact + return a.Limbs, b.Limbs, a.fParams.BitsPerLimb() } + + bitsPerLimb = a.fParams.BitsPerLimb() * groupSize + + ac = f.compactLimbs(a, groupSize, bitsPerLimb) + bc = f.compactLimbs(b, groupSize, bitsPerLimb) + return } -// regroupLimbs perform the regrouping of limbs between old and new parameters. -func regroupLimbs(api frontend.API, params, regroupParams *Params, limbs []frontend.Variable) []frontend.Variable { - if params.nbBits == regroupParams.nbBits { - // not regrouping - return limbs - } - if regroupParams.nbBits%params.nbBits != 0 { - panic("regroup bitwidth must be multiple of initial bitwidth") +// compactLimbs perform the regrouping of limbs between old and new parameters. +func (f *field[T]) compactLimbs(e Element[T], groupSize, bitsPerLimb uint) []frontend.Variable { + if f.fParams.BitsPerLimb() == bitsPerLimb { + return e.Limbs } - groupSize := regroupParams.nbBits / params.nbBits - nbLimbs := (uint(len(limbs)) + groupSize - 1) / groupSize - regrouped := make([]frontend.Variable, nbLimbs) + nbLimbs := (uint(len(e.Limbs)) + groupSize - 1) / groupSize + r := make([]frontend.Variable, nbLimbs) coeffs := make([]*big.Int, groupSize) one := big.NewInt(1) for i := range coeffs { coeffs[i] = new(big.Int) - coeffs[i].Lsh(one, params.nbBits*uint(i)) + coeffs[i].Lsh(one, e.fParams.BitsPerLimb()*uint(i)) } for i := uint(0); i < nbLimbs; i++ { - regrouped[i] = uint(0) - for j := uint(0); j < groupSize && i*groupSize+j < uint(len(limbs)); j++ { - regrouped[i] = api.Add(regrouped[i], api.Mul(coeffs[j], limbs[i*groupSize+j])) + r[i] = uint(0) + for j := uint(0); j < groupSize && i*groupSize+j < uint(len(e.Limbs)); j++ { + r[i] = f.api.Add(r[i], f.api.Mul(coeffs[j], e.Limbs[i*groupSize+j])) } } - return regrouped + return r } diff --git a/std/math/emulated/composition_test.go b/std/math/emulated/composition_test.go new file mode 100644 index 0000000000..f826cdd8f5 --- /dev/null +++ b/std/math/emulated/composition_test.go @@ -0,0 +1,65 @@ +package emulated + +import ( + "crypto/rand" + "fmt" + "math/big" + "testing" + + "github.com/consensys/gnark/test" +) + +func TestComposition(t *testing.T) { + testComposition[BN254Fp](t) + testComposition[Secp256k1](t) + testComposition[BLS12377Fp](t) + testComposition[Goldilocks](t) +} + +func testComposition[T FieldParams](t *testing.T) { + assert := test.NewAssert(t) + var fp T + assert.Run(func(assert *test.Assert) { + n, err := rand.Int(rand.Reader, fp.Modulus()) + if err != nil { + assert.FailNow("rand int", err) + } + res := make([]*big.Int, fp.NbLimbs()) + for i := range res { + res[i] = new(big.Int) + } + if err = decompose(n, fp.BitsPerLimb(), res); err != nil { + assert.FailNow("decompose", err) + } + n2 := new(big.Int) + if err = recompose(res, fp.BitsPerLimb(), n2); err != nil { + assert.FailNow("recompose", err) + } + if n2.Cmp(n) != 0 { + assert.FailNow("unequal") + } + }, testName[T]()) +} + +func TestSubPadding(t *testing.T) { + testSubPadding[BN254Fp](t) + testSubPadding[Secp256k1](t) + testSubPadding[BLS12377Fp](t) + testSubPadding[Goldilocks](t) +} + +func testSubPadding[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + for i := fp.NbLimbs(); i < 2*fp.NbLimbs(); i++ { + assert.Run(func(assert *test.Assert) { + limbs := subPadding[T](0, i) + padValue := new(big.Int) + if err := recompose(limbs, fp.BitsPerLimb(), padValue); err != nil { + assert.FailNow("recompose", err) + } + padValue.Mod(padValue, fp.Modulus()) + assert.Zero(padValue.Cmp(big.NewInt(0)), "padding not multiple of order") + }, fmt.Sprintf("%s/nbLimbs=%d", testName[T](), i)) + } +} diff --git a/std/math/nonnative/doc.go b/std/math/emulated/doc.go similarity index 99% rename from std/math/nonnative/doc.go rename to std/math/emulated/doc.go index 06d304cf8b..66c4b0f3b0 100644 --- a/std/math/nonnative/doc.go +++ b/std/math/emulated/doc.go @@ -1,5 +1,5 @@ /* -Package nonnative implements operations over any modulus. +Package emulated implements operations over any modulus. Non-native computation in circuit @@ -217,4 +217,4 @@ We do not assume particular value for placeholder constant and may its implementation to speed up compilation. */ -package nonnative +package emulated diff --git a/std/math/emulated/element.go b/std/math/emulated/element.go new file mode 100644 index 0000000000..78c48cedc7 --- /dev/null +++ b/std/math/emulated/element.go @@ -0,0 +1,553 @@ +package emulated + +// TODO: add checks which ensure that constants are not used as receivers +// TODO: add sanity checks before the operations (e.g. that overflow is +// sufficient and do not need to reduce) +// TODO: think about different "operation modes". Probably hand-optimized code +// is better than reducing eagerly, but the user should be at least aware during +// compile-time that values need to be reduced. But there should be an easy-mode +// where the user does not need to manually reduce and the library does it as +// necessary. +// TODO: check that the parameters coincide for elements. +// TODO: less equal than +// TODO: simple exponentiation before we implement Wesolowsky + +import ( + "errors" + "fmt" + "math" + "math/big" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/compiled" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/std/math/bits" + "golang.org/x/exp/constraints" +) + +type errOverflow struct { + op string + nextOverflow uint + maxOverflow uint + reduceRight bool +} + +func (e errOverflow) Error() string { + return fmt.Sprintf("op %s overflow %d exceeds max %d", e.op, e.nextOverflow, e.maxOverflow) +} + +// Element defines an element in the ring of integers modulo n. The integer +// value of the element is split into limbs of nbBits lengths and represented as +// a slice of limbs. +type Element[T FieldParams] struct { + Limbs []frontend.Variable `gnark:"limbs"` // in little-endian (least significant limb first) encoding + + // overflow indicates the number of additions on top of the normal form. To + // ensure that none of the limbs overflow the scalar field of the snark + // curve, we must check that nbBits+overflow < floor(log2(fr modulus)) + overflow uint `gnark:"-"` + + // f carries the ring parameters + fParams T +} + +// NewElement builds a new emulated element from input +// if input is a Element[T], this functions clones and return a new Element[T] +// else, it attemps to convert to big.Int , mod reduce if necessary and return a cannonical Element[T] +func NewElement[T FieldParams](v interface{}) Element[T] { + r := Element[T]{} + + if v == nil { + r.Limbs = make([]frontend.Variable, r.fParams.NbLimbs()) + for i := 0; i < len(r.Limbs); i++ { + r.Limbs[i] = 0 + } + + return r + } + switch tv := v.(type) { + case Element[T]: + r.Limbs = make([]frontend.Variable, len(tv.Limbs)) + copy(r.Limbs, tv.Limbs) + r.overflow = tv.overflow + return r + case *Element[T]: + r.Limbs = make([]frontend.Variable, len(tv.Limbs)) + copy(r.Limbs, tv.Limbs) + r.overflow = tv.overflow + return r + case compiled.LinearExpression: + // TODO @gbotrel don't like that + // return f.PackLimbs([]frontend.Variable{in}) + r.Limbs = []frontend.Variable{v} + return r + case compiled.Term: + // TODO @gbotrel don't like that + // return f.PackLimbs([]frontend.Variable{in}) + r.Limbs = []frontend.Variable{v} + return r + } + + // convert to big.Int + bValue := utils.FromInterface(v) + + // mod reduce + if r.fParams.Modulus().Cmp(&bValue) != 0 { + bValue.Mod(&bValue, r.fParams.Modulus()) + } + + // decompose into limbs + // TODO @gbotrel use big.Int pool here + limbs := make([]*big.Int, r.fParams.NbLimbs()) + for i := range limbs { + limbs[i] = new(big.Int) + } + if err := decompose(&bValue, r.fParams.BitsPerLimb(), limbs); err != nil { + panic(fmt.Errorf("decompose value: %w", err)) + } + + // assign limb values + r.Limbs = make([]frontend.Variable, r.fParams.NbLimbs()) + for i := range limbs { + r.Limbs[i] = frontend.Variable(limbs[i]) + } + + return r +} + +// toBits returns the bit representation of the Element in little-endian (LSB +// first) order. The returned bits are constrained to be 0-1. The number of +// returned bits is nbLimbs*nbBits+overflow. To obtain the bits of the canonical +// representation of Element, reduce Element first and take less significant +// bits corresponding to the bitwidth of the emulated modulus. +func (f *field[T]) toBits(a Element[T]) []frontend.Variable { + ba, aConst := f.ConstantValue(a) + if aConst { + return f.api.ToBinary(ba, int(f.fParams.BitsPerLimb()*f.fParams.NbLimbs())) + } + var carry frontend.Variable = 0 + var fullBits []frontend.Variable + var limbBits []frontend.Variable + for i := 0; i < len(a.Limbs); i++ { + limbBits = bits.ToBinary(f.api, f.api.Add(a.Limbs[i], carry), bits.WithNbDigits(int(a.fParams.BitsPerLimb()+a.overflow))) + fullBits = append(fullBits, limbBits[:a.fParams.BitsPerLimb()]...) + if a.overflow > 0 { + carry = bits.FromBinary(f.api, limbBits[a.fParams.BitsPerLimb():]) + } + } + fullBits = append(fullBits, limbBits[a.fParams.BitsPerLimb():a.fParams.BitsPerLimb()+a.overflow]...) + return fullBits +} + +// maxOverflow returns the maximal possible overflow for the element. If the +// overflow of the next operation exceeds the value returned by this method, +// then the limbs may overflow the native field. +func (f *field[T]) maxOverflow() uint { + f.maxOfOnce.Do(func() { + f.maxOf = uint(f.api.Compiler().FieldBitLen()-1) - f.fParams.BitsPerLimb() + }) + return f.maxOf +} + +// assertLimbsEqualitySlow is the main routine in the package. It asserts that the +// two slices of limbs represent the same integer value. This is also the most +// costly operation in the package as it does bit decomposition of the limbs. +func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { + + nbLimbs := max(len(l), len(r)) + maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) + maxValueShift := new(big.Int).Lsh(big.NewInt(1), nbCarryBits) + + var carry frontend.Variable = 0 + for i := 0; i < nbLimbs; i++ { + diff := api.Add(maxValue, carry) + if i < len(l) { + diff = api.Add(diff, l[i]) + } + if i < len(r) { + diff = api.Sub(diff, r[i]) + } + if i > 0 { + diff = api.Sub(diff, maxValueShift) + } + // TODO: more efficient methods for splitting a variable? Because we are + // splitting the value into two, then maybe we do not need the whole + // binary decomposition \sum_{i=0}^n a_i 2^i, but can use a * 2^nbits + + // b. Then we can also omit the FromBinary call. + diffBits := bits.ToBinary(api, diff, bits.WithNbDigits(int(nbBits+nbCarryBits+1)), bits.WithUnconstrainedOutputs()) + for j := uint(0); j < nbBits; j++ { + api.AssertIsEqual(diffBits[j], 0) + } + carry = bits.FromBinary(api, diffBits[nbBits:nbBits+nbCarryBits+1]) + } + api.AssertIsEqual(carry, maxValueShift) +} + +// AssertLimbsEquality asserts that the limbs represent a same integer value (up +// to overflow). This method does not ensure that the values are equal modulo +// the field order. For strict equality, use AssertIsEqual. +func (f *field[T]) AssertLimbsEquality(a, b Element[T]) { + ba, aConst := f.ConstantValue(a) + bb, bConst := f.ConstantValue(b) + if aConst && bConst { + ba.Mod(ba, f.fParams.Modulus()) + bb.Mod(bb, f.fParams.Modulus()) + if ba.Cmp(bb) != 0 { + panic(fmt.Errorf("constant values are different: %s != %s", ba.String(), bb.String())) + } + return + } + + // first, we check if we can compact the e and other; they could be using 8 limbs of 32bits + // but with our snark field, we could express them in 2 limbs of 128bits, which would make bit decomposition + // and limbs equality in-circuit (way) cheaper + ca, cb, bitsPerLimb := f.compact(a, b) + + // f.log.Trace().Int("len(a.limbs)", len(a.Limbs)). + // Int("len(b.limbs)", len(b.Limbs)). + // Int("len(cb.limbs)", len(cb)). + // Int("len(ca.limbs)", len(ca)). + // Msg("AssertLimbsEquality") + // slow path -- the overflows are different. Need to compare with carries. + // TODO: we previously assumed that one side was "larger" than the other + // side, but I think this assumption is not valid anymore + if a.overflow > b.overflow { + assertLimbsEqualitySlow(f.api, ca, cb, bitsPerLimb, a.overflow) + } else { + assertLimbsEqualitySlow(f.api, cb, ca, bitsPerLimb, b.overflow) + } +} + +// EnforceWidth enforces that the bitlength of the value is exactly the +// bitlength of the modulus. Any newly initialized variable should be +// constrained to ensure correct operations. +func (f *field[T]) EnforceWidth(a Element[T]) { + _, aConst := f.ConstantValue(a) + if aConst { + if len(a.Limbs) != int(f.fParams.NbLimbs()) { + panic("constant limb width doesn't match parametrized field") + } + } + + for i := range a.Limbs { + // TODO @gbotrel why check all the limbs here? if len(e.Limbs) <= modulus + // && last limb <= bits[lastLimbs] modulus, we're good ? + limbNbBits := int(a.fParams.BitsPerLimb()) + if i == len(a.Limbs)-1 { + // take only required bits from the most significant limb + limbNbBits = ((a.fParams.Modulus().BitLen() - 1) % int(a.fParams.BitsPerLimb())) + 1 + } + // bits.ToBinary restricts the least significant NbDigits to be equal to + // the limb value. This is sufficient to restrict for the bitlength and + // we can discard the bits themselves. + bits.ToBinary(f.api, a.Limbs[i], bits.WithNbDigits(limbNbBits)) + } +} + +func (f *field[T]) addPreCond(a, b Element[T]) (nextOverflow uint, err error) { + reduceRight := a.overflow < b.overflow + + nextOverflow = max(a.overflow, b.overflow) + 1 + + if nextOverflow > f.maxOverflow() { + err = errOverflow{op: "add", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} + } + return +} + +func (f *field[T]) add(a, b Element[T], nextOverflow uint) Element[T] { + ba, aConst := f.ConstantValue(a) + bb, bConst := f.ConstantValue(b) + if aConst && bConst { + ba.Add(ba, bb).Mod(ba, f.fParams.Modulus()) + return NewElement[T](ba) + } + + // TODO: figure out case when one element is a constant. If one addend is a + // constant, then we do not reduce it (but this is always case as the + // constant's overflow never increases?) + // TODO: check that the target is a variable (has an API) + // TODO: if both are constants, then add big ints + nbLimbs := max(len(a.Limbs), len(b.Limbs)) + limbs := make([]frontend.Variable, nbLimbs) + for i := range limbs { + limbs[i] = 0 + if i < len(a.Limbs) { + limbs[i] = f.api.Add(limbs[i], a.Limbs[i]) + } + if i < len(b.Limbs) { + limbs[i] = f.api.Add(limbs[i], b.Limbs[i]) + } + } + + e := Element[T]{ + Limbs: limbs, + overflow: nextOverflow, + } + return e +} + +func (f *field[T]) mulPreCond(a, b Element[T]) (nextOverflow uint, err error) { + reduceRight := a.overflow < b.overflow + nbResLimbs := nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)) + nextOverflow = f.fParams.BitsPerLimb() + uint(math.Log2(float64(2*nbResLimbs-1))) + 1 + a.overflow + b.overflow + if nextOverflow > f.maxOverflow() { + err = errOverflow{op: "mul", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} + } + return +} + +func (f *field[T]) mul(a, b Element[T], nextOverflow uint) Element[T] { + // TODO: when one element is constant. + ba, aConst := f.ConstantValue(a) + bb, bConst := f.ConstantValue(b) + if aConst && bConst { + ba.Mul(ba, bb).Mod(ba, f.fParams.Modulus()) + return NewElement[T](ba) + } + + // mulResult contains the result (out of circuit) of a * b school book multiplication + // len(mulResult) == len(a) + len(b) - 1 + mulResult, err := computeMultiplicationHint(f.api, f, a.Limbs, b.Limbs) + if err != nil { + panic(fmt.Sprintf("multiplication hint: %s", err)) + } + + // we computed the result of the mul outside the circuit (mulResult) + // and we want to constrain inside the circuit that this injected value + // actually matches the in-circuit a * b values + // create constraints (\sum_{i=0}^{m-1} a_i c^i) * (\sum_{i=0}^{m-1} b_i + // c^i) = (\sum_{i=0}^{2m-2} z_i c^i) for c \in {1, 2m-1} + w := new(big.Int) + for c := 1; c <= len(mulResult); c++ { + w.SetInt64(1) // c^i + l := a.Limbs[0] + r := b.Limbs[0] + o := mulResult[0] + + for i := 1; i < len(mulResult); i++ { + w.Lsh(w, uint(c)) + if i < len(a.Limbs) { + l = f.api.Add(l, f.api.Mul(a.Limbs[i], w)) + } + if i < len(b.Limbs) { + r = f.api.Add(r, f.api.Mul(b.Limbs[i], w)) + } + o = f.api.Add(o, f.api.Mul(mulResult[i], w)) + } + f.api.AssertIsEqual(f.api.Mul(l, r), o) + } + + return Element[T]{ + Limbs: mulResult, + overflow: nextOverflow, + } +} + +// reduce reduces a modulo modulus and assigns e to the reduced value. +func (f *field[T]) reduce(a Element[T]) Element[T] { + if a.overflow == 0 { + // fast path - already reduced, omit reduction. + return a + } + // sanity check + _, aConst := f.ConstantValue(a) + if aConst { + panic("trying to reduce a constant, which happen to have an overflow flag set") + } + + // slow path - use hint to reduce value + e, err := f.computeRemHint(a, f.Modulus()) + if err != nil { + panic(fmt.Sprintf("reduction hint: %v", err)) + } + // TODO @gbotrel fixme: assertIsEqual(a, e) crashes Pairing test + f.assertIsEqual(e, a) + return e +} + +// Assign a value to self (witness assignment) +func (e *Element[T]) Assign(val interface{}) { + *e = NewElement[T](val) +} + +func (e *Element[T]) GnarkInitHook() { + if e.Limbs == nil { + *e = NewElement[T](nil) + } +} + +// Set sets e to a and returns e. If a is constant, then it also enforces the +// widths of the limbs. +func (e *Element[T]) Set(a Element[T]) { + e.Limbs = make([]frontend.Variable, len(a.Limbs)) + e.overflow = a.overflow + copy(e.Limbs, a.Limbs) + // TODO @gbotrel this shouldn't happen anymore + // if a.f.api == nil { + // // we are setting from constant -- ensure that the widths of the limbs + // // are restricted + // e.EnforceWidth() + // } +} + +// AssertIsEqual ensures that a is equal to b modulo the modulus. +func (f *field[T]) assertIsEqual(a, b Element[T]) Element[T] { + ba, aConst := f.ConstantValue(a) + bb, bConst := f.ConstantValue(b) + if aConst && bConst { + ba.Mod(ba, f.fParams.Modulus()) + bb.Mod(bb, f.fParams.Modulus()) + if ba.Cmp(bb) != 0 { + panic(fmt.Sprintf("%s != %s", ba, bb)) + } + return NewElement[T](nil) // TODO @gbotrel un-used result + } + + diff := (f.Sub(b, a)).(Element[T]) + + // we compute k such that diff / p == k + // so essentially, we say "I know an element k such that k*p == diff" + // hence, diff == 0 mod p + p := f.Modulus() + // we compute k such that diff / p == k + // so essentially, we say "I know an element k such that k*p == diff" + // hence, diff == 0 mod p + k, err := f.computeQuoHint(diff) + if err != nil { + panic(fmt.Sprintf("hint error: %v", err)) + } + + kp := (f.Mul(k, p)).(Element[T]) + + f.AssertLimbsEquality(diff, kp) + + // TODO @gbotrel improve useless alloc + // we have this so that the signature of assertIsEqual matches expected in reduceAndOp + return NewElement[T](nil) +} + +// AssertIsEqualLessThan ensures that e is less or equal than e. +func (f *field[T]) AssertIsLessEqualThan(e, a Element[T]) { + if e.overflow+a.overflow > 0 { + panic("inputs must have 0 overflow") + } + eBits := f.toBits(e) + aBits := f.toBits(a) + ff := func(xbits, ybits []frontend.Variable) []frontend.Variable { + diff := len(xbits) - len(ybits) + ybits = append(ybits, make([]frontend.Variable, diff)...) + for i := len(ybits) - diff - 1; i < len(ybits); i++ { + ybits[i] = 0 + } + return ybits + } + if len(eBits) > len(aBits) { + aBits = ff(eBits, aBits) + } else { + eBits = ff(aBits, eBits) + } + p := make([]frontend.Variable, len(eBits)+1) + p[len(eBits)] = 1 + for i := len(eBits) - 1; i >= 0; i-- { + v := f.api.Mul(p[i+1], eBits[i]) + p[i] = f.api.Select(aBits[i], v, p[i+1]) + t := f.api.Select(aBits[i], 0, p[i+1]) + l := f.api.Sub(1, t, eBits[i]) + ll := f.api.Mul(l, eBits[i]) + f.api.AssertIsEqual(ll, 0) + } +} + +func (f *field[T]) subPreCond(a, b Element[T]) (nextOverflow uint, err error) { + + reduceRight := a.overflow < b.overflow+2 + nextOverflow = max(b.overflow+2, a.overflow) + if nextOverflow > f.maxOverflow() { + err = errOverflow{op: "sub", nextOverflow: nextOverflow, maxOverflow: f.maxOverflow(), reduceRight: reduceRight} + } + return +} + +func (f *field[T]) sub(a, b Element[T], nextOverflow uint) Element[T] { + ba, aConst := f.ConstantValue(a) + bb, bConst := f.ConstantValue(b) + if aConst && bConst { + ba.Sub(ba, bb).Mod(ba, f.fParams.Modulus()) + return NewElement[T](ba) + } + + // first we have to compute padding to ensure that the subtraction does not + // underflow. + nbLimbs := max(len(a.Limbs), len(b.Limbs)) + limbs := make([]frontend.Variable, nbLimbs) + padLimbs := subPadding[T](b.overflow, uint(nbLimbs)) + for i := range limbs { + limbs[i] = padLimbs[i] + if i < len(a.Limbs) { + limbs[i] = f.api.Add(limbs[i], a.Limbs[i]) + } + if i < len(b.Limbs) { + limbs[i] = f.api.Sub(limbs[i], b.Limbs[i]) + } + } + e := Element[T]{ + Limbs: limbs, + overflow: nextOverflow, + } + return e +} + +// Select sets e to a if selector == 0 and to b otherwise. +// assumes a overflow == b overflow +func (f *field[T]) _select(selector frontend.Variable, a, b Element[T]) Element[T] { + e := NewElement[T](nil) + e.overflow = a.overflow + for i := range a.Limbs { + e.Limbs[i] = f.api.Select(selector, a.Limbs[i], b.Limbs[i]) + } + return e +} + +// Lookup2 performs two-bit lookup between a, b, c, d based on lookup bits b1 +// and b2. Sets e to a if b0=b1=0, b if b0=1 and b1=0, c if b0=0 and b1=1, d if b0=b1=1. +func (f *field[T]) lookup2(b0, b1 frontend.Variable, a, b, c, d Element[T]) Element[T] { + if len(a.Limbs) != len(b.Limbs) || len(a.Limbs) != len(c.Limbs) || len(a.Limbs) != len(d.Limbs) { + panic("unequal limb counts for lookup") + } + if a.overflow != b.overflow || a.overflow != c.overflow || a.overflow != d.overflow { + panic("unequal overflows for lookup") + } + e := NewElement[T](nil) + e.Limbs = make([]frontend.Variable, len(a.Limbs)) + e.overflow = a.overflow + for i := range a.Limbs { + e.Limbs[i] = f.api.Lookup2(b0, b1, a.Limbs[i], b.Limbs[i], c.Limbs[i], d.Limbs[i]) + } + return e +} + +// reduceAndOp applies op on the inputs. If the pre-condition check preCond +// errs, then first reduces the input arguments. The reduction is done +// one-by-one with the element with highest overflow reduced first. +func (f *field[T]) reduceAndOp(op func(Element[T], Element[T], uint) Element[T], preCond func(Element[T], Element[T]) (uint, error), a, b Element[T]) Element[T] { + var nextOverflow uint + var err error + var target errOverflow + + for nextOverflow, err = preCond(a, b); errors.As(err, &target); nextOverflow, err = preCond(a, b) { + if !target.reduceRight { + a = f.reduce(a) + } else { + b = f.reduce(b) + } + } + return op(a, b, nextOverflow) +} + +func max[T constraints.Ordered](a, b T) T { + if a > b { + return a + } + return b +} diff --git a/std/math/emulated/element_test.go b/std/math/emulated/element_test.go new file mode 100644 index 0000000000..ee03e9f076 --- /dev/null +++ b/std/math/emulated/element_test.go @@ -0,0 +1,764 @@ +package emulated + +import ( + "crypto/rand" + "fmt" + "math/big" + "reflect" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/test" +) + +const testCurve = ecc.BN254 + +type AssertLimbEqualityCircuit[T FieldParams] struct { + A, B Element[T] +} + +func (c *AssertLimbEqualityCircuit[T]) Define(api frontend.API) error { + _f, err := NewField[T](api) + if err != nil { + return err + } + f := _f.(*field[T]) + f.AssertLimbsEquality(c.A, c.B) + return nil +} + +func testName[T FieldParams]() string { + var fp T + return fmt.Sprintf("%s/limb=%d", reflect.TypeOf(fp).Name(), fp.BitsPerLimb()) +} + +func TestAssertLimbEqualityNoOverflow(t *testing.T) { + testAssertLimbEqualityNoOverflow[Goldilocks](t) + testAssertLimbEqualityNoOverflow[Secp256k1](t) + testAssertLimbEqualityNoOverflow[BN254Fp](t) +} + +func testAssertLimbEqualityNoOverflow[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness AssertLimbEqualityCircuit[T] + val, _ := rand.Int(rand.Reader, fp.Modulus()) + witness.A.Assign(val) + witness.B.Assign(val) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +// TODO: add also cases which should fail + +type AssertIsLessEqualThanCircuit[T FieldParams] struct { + L, R Element[T] +} + +func (c *AssertIsLessEqualThanCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + f.AssertIsLessOrEqual(c.L, c.R) + return nil +} + +func TestAssertIsLessEqualThan(t *testing.T) { + testAssertIsLessEqualThan[Goldilocks](t) + testAssertIsLessEqualThan[Secp256k1](t) + testAssertIsLessEqualThan[BN254Fp](t) +} + +func testAssertIsLessEqualThan[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness AssertIsLessEqualThanCircuit[T] + R, _ := rand.Int(rand.Reader, fp.Modulus()) + L, _ := rand.Int(rand.Reader, R) + witness.R.Assign(R) + witness.L.Assign(L) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type AddCircuit[T FieldParams] struct { + A, B, C Element[T] +} + +func (c *AddCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Add(c.A, c.B) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestAddCircuitNoOverflow(t *testing.T) { + testAddCircuitNoOverflow[Goldilocks](t) + testAddCircuitNoOverflow[Secp256k1](t) + testAddCircuitNoOverflow[BN254Fp](t) +} + +func testAddCircuitNoOverflow[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness AddCircuit[T] + bound := new(big.Int).Rsh(fp.Modulus(), 1) + val1, _ := rand.Int(rand.Reader, bound) + val2, _ := rand.Int(rand.Reader, bound) + res := new(big.Int).Add(val1, val2) + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type MulNoOverflowCircuit[T FieldParams] struct { + A Element[T] + B Element[T] + C Element[T] +} + +func (c *MulNoOverflowCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Mul(c.A, c.B) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestMulCircuitNoOverflow(t *testing.T) { + // testMulCircuitNoOverflow[Goldilocks](t) + testMulCircuitNoOverflow[Secp256k1](t) + // testMulCircuitNoOverflow[BN254Fp](t) +} + +func testMulCircuitNoOverflow[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness MulNoOverflowCircuit[T] + val1, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), uint(fp.Modulus().BitLen())/2)) + val2, _ := rand.Int(rand.Reader, new(big.Int).Div(fp.Modulus(), val1)) + res := new(big.Int).Mul(val1, val2) + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16)) + }, testName[T]()) +} + +type MulCircuitOverflow[T FieldParams] struct { + A Element[T] + B Element[T] + C Element[T] +} + +func (c *MulCircuitOverflow[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Mul(c.A, c.B) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestMulCircuitOverflow(t *testing.T) { + testMulCircuitOverflow[Goldilocks](t) + testMulCircuitOverflow[Secp256k1](t) + testMulCircuitOverflow[BN254Fp](t) +} + +func testMulCircuitOverflow[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness MulCircuitOverflow[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, fp.Modulus()) + res := new(big.Int).Mul(val1, val2) + res.Mod(res, fp.Modulus()) + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type ReduceAfterAddCircuit[T FieldParams] struct { + A Element[T] + B Element[T] + C Element[T] +} + +func (c *ReduceAfterAddCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Add(c.A, c.B) + res = f.(*field[T]).reduce(res.(Element[T])) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestReduceAfterAdd(t *testing.T) { + testReduceAfterAdd[Goldilocks](t) + testReduceAfterAdd[Secp256k1](t) + testReduceAfterAdd[BN254Fp](t) +} + +func testReduceAfterAdd[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness ReduceAfterAddCircuit[T] + val2, _ := rand.Int(rand.Reader, fp.Modulus()) + val1, _ := rand.Int(rand.Reader, val2) + val3 := new(big.Int).Add(val1, fp.Modulus()) + val3.Sub(val3, val2) + witness.A.Assign(val3) + witness.B.Assign(val2) + witness.C.Assign(val1) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type SubtractCircuit[T FieldParams] struct { + A Element[T] + B Element[T] + C Element[T] +} + +func (c *SubtractCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Sub(c.A, c.B) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestSubtractNoOverflow(t *testing.T) { + testSubtractNoOverflow[Goldilocks](t) + testSubtractNoOverflow[Secp256k1](t) + testSubtractNoOverflow[BN254Fp](t) + + testSubtractOverflow[Goldilocks](t) + testSubtractOverflow[Secp256k1](t) + testSubtractOverflow[BN254Fp](t) +} + +func testSubtractNoOverflow[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness SubtractCircuit[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, val1) + res := new(big.Int).Sub(val1, val2) + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +func testSubtractOverflow[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness SubtractCircuit[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, new(big.Int).Sub(fp.Modulus(), val1)) + val2.Add(val2, val1) + res := new(big.Int).Sub(val1, val2) + res.Mod(res, fp.Modulus()) + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type NegationCircuit[T FieldParams] struct { + A Element[T] + B Element[T] +} + +func (c *NegationCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Neg(c.A) + f.AssertIsEqual(res, c.B) + return nil +} + +func TestNegation(t *testing.T) { + testNegation[Goldilocks](t) + testNegation[Secp256k1](t) + testNegation[BN254Fp](t) +} + +func testNegation[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness NegationCircuit[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + res := new(big.Int).Sub(fp.Modulus(), val1) + witness.A.Assign(val1) + witness.B.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type InverseCircuit[T FieldParams] struct { + A Element[T] + B Element[T] +} + +func (c *InverseCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Inverse(c.A) + f.AssertIsEqual(res, c.B) + return nil +} + +func TestInverse(t *testing.T) { + testInverse[Goldilocks](t) + testInverse[Secp256k1](t) + testInverse[BN254Fp](t) +} + +func testInverse[T FieldParams](t *testing.T) { + var fp T + if !fp.IsPrime() { + t.Skip() + } + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness InverseCircuit[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + res := new(big.Int).ModInverse(val1, fp.Modulus()) + witness.A.Assign(val1) + witness.B.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type DivisionCircuit[T FieldParams] struct { + A Element[T] + B Element[T] + C Element[T] +} + +func (c *DivisionCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Div(c.A, c.B) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestDivision(t *testing.T) { + testDivision[Goldilocks](t) + testDivision[Secp256k1](t) + testDivision[BN254Fp](t) +} + +func testDivision[T FieldParams](t *testing.T) { + var fp T + if !fp.IsPrime() { + t.Skip() + } + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness DivisionCircuit[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, fp.Modulus()) + res := new(big.Int) + res.ModInverse(val2, fp.Modulus()) + res.Mul(val1, res) + res.Mod(res, fp.Modulus()) + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type ToBinaryCircuit[T FieldParams] struct { + Value Element[T] + Bits []frontend.Variable +} + +func (c *ToBinaryCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + bits := f.ToBinary(c.Value) + if len(bits) != len(c.Bits) { + return fmt.Errorf("got %d bits, expected %d", len(bits), len(c.Bits)) + } + for i := range bits { + api.AssertIsEqual(bits[i], c.Bits[i]) + } + return nil +} + +func TestToBinary(t *testing.T) { + testToBinary[Goldilocks](t) + testToBinary[Secp256k1](t) + testToBinary[BN254Fp](t) +} + +func testToBinary[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness ToBinaryCircuit[T] + bitLen := fp.BitsPerLimb() * fp.NbLimbs() + circuit.Bits = make([]frontend.Variable, bitLen) + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + bits := make([]frontend.Variable, bitLen) + for i := 0; i < len(bits); i++ { + bits[i] = val1.Bit(i) + } + witness.Value.Assign(val1) + witness.Bits = bits + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type FromBinaryCircuit[T FieldParams] struct { + Bits []frontend.Variable + Res Element[T] +} + +func (c *FromBinaryCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.FromBinary(c.Bits) + f.AssertIsEqual(res, c.Res) + return nil +} + +func TestFromBinary(t *testing.T) { + testFromBinary[Goldilocks](t) + testFromBinary[Secp256k1](t) + testFromBinary[BN254Fp](t) +} + +func testFromBinary[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness FromBinaryCircuit[T] + bitLen := fp.Modulus().BitLen() + circuit.Bits = make([]frontend.Variable, bitLen) + + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + bits := make([]frontend.Variable, bitLen) + for i := 0; i < len(bits); i++ { + bits[i] = val1.Bit(i) + } + + witness.Res.Assign(val1) + witness.Bits = bits + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type EqualityCheckCircuit[T FieldParams] struct { + A Element[T] + B Element[T] +} + +func (c *EqualityCheckCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := c.A //f.Set(c.A) TODO @gbotrel fixme + f.AssertIsEqual(res, c.B) + return nil +} + +func TestConstantEqual(t *testing.T) { + testConstantEqual[Goldilocks](t) + testConstantEqual[BN254Fp](t) + testConstantEqual[Secp256k1](t) +} + +func testConstantEqual[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness EqualityCheckCircuit[T] + val, _ := rand.Int(rand.Reader, fp.Modulus()) + witness.A.Assign(val) + witness.B.Assign(val) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type SelectCircuit[T FieldParams] struct { + Selector frontend.Variable + A Element[T] + B Element[T] + C Element[T] +} + +func (c *SelectCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Select(c.Selector, c.A, c.B) + f.AssertIsEqual(res, c.C) + return nil +} + +func TestSelect(t *testing.T) { + testSelect[Goldilocks](t) + testSelect[Secp256k1](t) + testSelect[BN254Fp](t) +} + +func testSelect[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness SelectCircuit[T] + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, fp.Modulus()) + randbit, _ := rand.Int(rand.Reader, big.NewInt(2)) + b := randbit.Uint64() + + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign([]*big.Int{val1, val2}[1-b]) + witness.Selector = b + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type Lookup2Circuit[T FieldParams] struct { + Bit0 frontend.Variable + Bit1 frontend.Variable + A Element[T] + B Element[T] + C Element[T] + D Element[T] + E Element[T] +} + +func (c *Lookup2Circuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Lookup2(c.Bit0, c.Bit1, c.A, c.B, c.C, c.D) + f.AssertIsEqual(res, c.E) + return nil +} + +func TestLookup2(t *testing.T) { + testLookup2[Goldilocks](t) + testLookup2[Secp256k1](t) + testLookup2[BN254Fp](t) +} + +func testLookup2[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness Lookup2Circuit[T] + + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, fp.Modulus()) + val3, _ := rand.Int(rand.Reader, fp.Modulus()) + val4, _ := rand.Int(rand.Reader, fp.Modulus()) + randbit, _ := rand.Int(rand.Reader, big.NewInt(4)) + + witness.A.Assign(val1) + witness.B.Assign(val2) + witness.C.Assign(val3) + witness.D.Assign(val4) + witness.E.Assign([]*big.Int{val1, val2, val3, val4}[randbit.Uint64()]) + witness.Bit0 = randbit.Bit(0) + witness.Bit1 = randbit.Bit(1) + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +type ComputationCircuit[T FieldParams] struct { + noReduce bool + + X1, X2, X3, X4, X5, X6 Element[T] + Res Element[T] +} + +func (c *ComputationCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + // compute x1^3 + 5*x2 + (x3-x4) / (x5+x6) + x13 := f.Mul(c.X1, c.X1) + if !c.noReduce { + x13 = f.(*field[T]).reduce(x13.(Element[T])) + } + x13 = f.Mul(x13, c.X1) + if !c.noReduce { + x13 = f.(*field[T]).reduce(x13.(Element[T])) + } + + fx2 := f.Mul(5, c.X2) + fx2 = f.(*field[T]).reduce(fx2.(Element[T])) + + nom := f.Sub(c.X3, c.X4) + + denom := f.Add(c.X5, c.X6) + + free := f.Div(nom, denom) + + res := f.Add(x13, fx2, free) + + f.AssertIsEqual(res, c.Res) + return nil +} + +func TestComputation(t *testing.T) { + testComputation[Goldilocks](t) + testComputation[Secp256k1](t) + testComputation[BN254Fp](t) +} + +func testComputation[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness ComputationCircuit[T] + + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + val2, _ := rand.Int(rand.Reader, fp.Modulus()) + val3, _ := rand.Int(rand.Reader, fp.Modulus()) + val4, _ := rand.Int(rand.Reader, fp.Modulus()) + val5, _ := rand.Int(rand.Reader, fp.Modulus()) + val6, _ := rand.Int(rand.Reader, fp.Modulus()) + + tmp := new(big.Int) + res := new(big.Int) + // res = x1^3 + tmp.Exp(val1, big.NewInt(3), fp.Modulus()) + res.Set(tmp) + // res = x1^3 + 5*x2 + tmp.Mul(val2, big.NewInt(5)) + res.Add(res, tmp) + // tmp = (x3-x4) + tmp.Sub(val3, val4) + tmp.Mod(tmp, fp.Modulus()) + // tmp2 = (x5+x6) + tmp2 := new(big.Int) + tmp2.Add(val5, val6) + // tmp = (x3-x4)/(x5+x6) + tmp2.ModInverse(tmp2, fp.Modulus()) + tmp.Mul(tmp, tmp2) + tmp.Mod(tmp, fp.Modulus()) + // res = x1^3 + 5*x2 + (x3-x4)/(x5+x6) + res.Add(res, tmp) + res.Mod(res, fp.Modulus()) + + witness.X1.Assign(val1) + witness.X2.Assign(val2) + witness.X3.Assign(val3) + witness.X4.Assign(val4) + witness.X5.Assign(val5) + witness.X6.Assign(val6) + witness.Res.Assign(res) + + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} + +func TestOptimisation(t *testing.T) { + assert := test.NewAssert(t) + circuit := ComputationCircuit[BN254Fp]{ + noReduce: true, + } + ccs, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit) + assert.NoError(err) + assert.LessOrEqual(ccs.GetNbConstraints(), 3291) + ccs2, err := frontend.Compile(testCurve.ScalarField(), scs.NewBuilder, &circuit) + assert.NoError(err) + assert.LessOrEqual(ccs2.GetNbConstraints(), 10722) +} + +type FourMulsCircuit[T FieldParams] struct { + A Element[T] + Res Element[T] +} + +func (c *FourMulsCircuit[T]) Define(api frontend.API) error { + f, err := NewField[T](api) + if err != nil { + return err + } + res := f.Mul(c.A, c.A, c.A, c.A) + f.AssertIsEqual(res, c.Res) + return nil +} + +func TestFourMuls(t *testing.T) { + testFourMuls[Goldilocks](t) + testFourMuls[Secp256k1](t) + testFourMuls[BN254Fp](t) +} + +func testFourMuls[T FieldParams](t *testing.T) { + var fp T + assert := test.NewAssert(t) + assert.Run(func(assert *test.Assert) { + var circuit, witness FourMulsCircuit[T] + + val1, _ := rand.Int(rand.Reader, fp.Modulus()) + res := new(big.Int) + res.Mul(val1, val1) + res.Mul(res, val1) + res.Mul(res, val1) + res.Mod(res, fp.Modulus()) + + witness.A.Assign(val1) + witness.Res.Assign(res) + assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.NoSerialization(), test.WithBackends(backend.GROTH16, backend.PLONK)) + }, testName[T]()) +} diff --git a/std/math/emulated/field.go b/std/math/emulated/field.go new file mode 100644 index 0000000000..3cec8af812 --- /dev/null +++ b/std/math/emulated/field.go @@ -0,0 +1,676 @@ +package emulated + +import ( + "errors" + "fmt" + "math/big" + "reflect" + "strconv" + "sync" + + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/logger" + "github.com/consensys/gnark/std/math/bits" + "github.com/rs/zerolog" +) + +// field defines the parameters of the emulated ring of integers modulo n. If +// n is prime, then the ring is also a finite field where inverse and division +// are allowed. +type field[T FieldParams] struct { + // api is the native API + api frontend.API + builder frontend.Builder + + // f carries the ring parameters + fParams T + + // maxOf is the maximum overflow before the element must be reduced. + maxOf uint + maxOfOnce sync.Once + + // constants for often used elements n, 0 and 1. Allocated only once + nConstOnce sync.Once + nConst Element[T] + zeroConstOnce sync.Once + zeroConst Element[T] + oneConstOnce sync.Once + oneConst Element[T] + + log zerolog.Logger +} + +// NewField returns an object to be used in-circuit to perform emulated arithmetic. +// +// The returned object implements frontend.API and as such, is used transparently in a circuit. +// +// This is an experimental feature and performing emulated arithmetic in-circuit is extremly costly. +// See package doc for more info. +func NewField[T FieldParams](native frontend.API) (frontend.API, error) { + f := &field[T]{ + api: native, + log: logger.Logger(), + } + + // ensure prime is correctly set + if f.fParams.IsPrime() { + if !f.fParams.Modulus().ProbablyPrime(20) { + return nil, fmt.Errorf("invalid parametrization: modulus is not prime") + } + } + + if f.fParams.BitsPerLimb() < 3 { + // even three is way too small, but it should probably work. + return nil, fmt.Errorf("nbBits must be at least 3") + } + + if f.fParams.Modulus().Cmp(big.NewInt(1)) < 1 { + return nil, fmt.Errorf("n must be at least 2") + } + + nbLimbs := (uint(f.fParams.Modulus().BitLen()) + f.fParams.BitsPerLimb() - 1) / f.fParams.BitsPerLimb() + if nbLimbs != f.fParams.NbLimbs() { + return nil, fmt.Errorf("nbLimbs mismatch got %d expected %d", f.fParams.NbLimbs(), nbLimbs) + } + + if f.api == nil { + return f, fmt.Errorf("missing api") + } + + if uint(f.api.Compiler().FieldBitLen()) < 2*f.fParams.BitsPerLimb()+1 { + return nil, fmt.Errorf("elements with limb length %d does not fit into scalar field", f.fParams.BitsPerLimb()) + } + + return f, nil +} + +func (f *field[T]) varToElement(in frontend.Variable) Element[T] { + switch vv := in.(type) { + case Element[T]: + return vv + case *Element[T]: + return *vv + default: + return NewElement[T](in) + } +} + +func (f *field[T]) varsToElements(in ...frontend.Variable) []Element[T] { + var els []Element[T] + for i := range in { + switch v := in[i].(type) { + case []frontend.Variable: + subels := f.varsToElements(v...) + els = append(els, subels...) + case frontend.Variable: + els = append(els, f.varToElement(v)) + default: + // handle nil value + panic("can't convert to Element[T]") + } + } + return els +} + +func (f *field[T]) Add(i1 frontend.Variable, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + els := f.varsToElements(i1, i2, in) + res := f.reduceAndOp(f.add, f.addPreCond, els[0], els[1]) + for i := 2; i < len(els); i++ { + res = f.reduceAndOp(f.add, f.addPreCond, res, els[i]) // TODO @gbotrel re-use res memory, don't reallocate limbs ! + } + return res +} + +// Negate sets e to -a and returns e. The returned element may be larger than +// the modulus. +func (f *field[T]) Neg(i1 frontend.Variable) frontend.Variable { + el := f.varToElement(i1) + + return f.Sub(f.Zero(), el) +} + +func (f *field[T]) Sub(i1 frontend.Variable, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + els := f.varsToElements(i1, i2, in) + sub := NewElement[T](nil) + sub.Set(els[1]) + for i := 2; i < len(els); i++ { + sub = f.reduceAndOp(f.add, f.addPreCond, sub, els[i]) + } + res := f.reduceAndOp(f.sub, f.subPreCond, els[0], sub) + return res +} + +func (f *field[T]) Mul(i1 frontend.Variable, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + els := f.varsToElements(i1, i2, in) + res := f.reduceAndOp(f.mul, f.mulPreCond, els[0], els[1]) + for i := 2; i < len(els); i++ { + res = f.reduceAndOp(f.mul, f.mulPreCond, res, els[i]) + } + return res +} + +func (f *field[T]) DivUnchecked(i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { + return f.Div(i1, i2) +} + +// Div sets e to a/b and returns e. If modulus is not a prime, it panics. The +// result is less than the modulus. This method is more efficient than inverting +// b and multiplying it by a. +func (f *field[T]) Div(i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { + if !f.fParams.IsPrime() { + // TODO shouldn't we still try to do a classic int div in a hint, constraint the result, and let it fail? + // that would enable things like uint32 div ? + panic("modulus not a prime") + } + + els := f.varsToElements(i1, i2) + a := els[0] + b := els[1] + div, err := computeDivisionHint(f.api, f, a.Limbs, b.Limbs) + if err != nil { + panic(fmt.Sprintf("compute division: %v", err)) + } + e := NewElement[T](nil) + e.Limbs = div + e.overflow = 0 + f.EnforceWidth(e) + res := (f.Mul(e, b)).(Element[T]) + f.assertIsEqual(res, a) + return e +} + +// Inverse sets e to 1/a and returns e. If modulus is not a prime, it panics. +// The result is less than the modulus. +func (f *field[T]) Inverse(i1 frontend.Variable) frontend.Variable { + a := f.varToElement(i1) + if !f.fParams.IsPrime() { + panic("modulus not a prime") + } + k, err := computeInverseHint(f.api, f, a.Limbs) + if err != nil { + panic(fmt.Sprintf("compute inverse: %v", err)) + } + e := NewElement[T](nil) + e.Limbs = k + e.overflow = 0 + f.EnforceWidth(e) + res := (f.Mul(e, a)).(Element[T]) + one := f.One() + f.assertIsEqual(res, one) + return e +} + +func (f *field[T]) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { + el := f.varToElement(i1) + res := f.reduce(el) + out := f.toBits(res) + switch len(n) { + case 0: + case 1: + // TODO @gbotrel this can unecessarly constraint some bits + // and falsify test results where we only want to "mask" a part of the element + out = out[:n[0]] + default: + panic("only single vararg permitted to ToBinary") + } + return out +} + +func (f *field[T]) FromBinary(b ...frontend.Variable) frontend.Variable { + els := f.varsToElements(b) + in := make([]frontend.Variable, len(els)) + for i := range els { + f.AssertIsBoolean(els[i]) + in[i] = els[i].Limbs[0] + } + e := NewElement[T](nil) + nbLimbs := (uint(len(in)) + e.fParams.BitsPerLimb() - 1) / e.fParams.BitsPerLimb() + limbs := make([]frontend.Variable, nbLimbs) + for i := uint(0); i < nbLimbs-1; i++ { + limbs[i] = bits.FromBinary(f.api, in[i*e.fParams.BitsPerLimb():(i+1)*e.fParams.BitsPerLimb()]) + } + limbs[nbLimbs-1] = bits.FromBinary(f.api, in[(nbLimbs-1)*e.fParams.BitsPerLimb():]) + e.overflow = 0 + e.Limbs = limbs + return e +} + +func (f *field[T]) Xor(a frontend.Variable, b frontend.Variable) frontend.Variable { + els := f.varsToElements(a, b) + f.AssertIsBoolean(els[0]) + f.AssertIsBoolean(els[1]) + rv := f.api.Xor(els[0].Limbs[0], els[1].Limbs[0]) + r := f.PackLimbs([]frontend.Variable{rv}) + + f.EnforceWidth(r) + return r +} + +func (f *field[T]) Or(a frontend.Variable, b frontend.Variable) frontend.Variable { + els := f.varsToElements(a, b) + f.AssertIsBoolean(els[0]) + f.AssertIsBoolean(els[1]) + rv := f.api.Or(els[0].Limbs[0], els[1].Limbs[0]) + r := f.PackLimbs([]frontend.Variable{rv}) + + f.EnforceWidth(r) + return r +} + +func (f *field[T]) And(a frontend.Variable, b frontend.Variable) frontend.Variable { + els := f.varsToElements(a, b) + f.AssertIsBoolean(els[0]) + f.AssertIsBoolean(els[1]) + rv := f.api.And(els[0].Limbs[0], els[1].Limbs[0]) + r := f.PackLimbs([]frontend.Variable{rv}) + + f.EnforceWidth(r) + return r +} + +func (f *field[T]) Select(b frontend.Variable, i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { + els := f.varsToElements(i1, i2) + switch vv := b.(type) { + case Element[T]: + f.AssertIsBoolean(vv) + b = vv.Limbs[0] + case *Element[T]: + f.AssertIsBoolean(vv) + b = vv.Limbs[0] + } + if els[0].overflow == els[1].overflow && len(els[0].Limbs) == len(els[1].Limbs) { + return f._select(b, els[0], els[1]) + } + s0 := els[0] + s1 := els[1] + if s0.overflow != 0 || len(s0.Limbs) != int(f.fParams.NbLimbs()) { + s0 = f.reduce(s0) + } + if s1.overflow != 0 || len(s1.Limbs) != int(f.fParams.NbLimbs()) { + s1 = f.reduce(s1) + } + return f._select(b, s0, s1) +} + +func (f *field[T]) Lookup2(b0 frontend.Variable, b1 frontend.Variable, i0 frontend.Variable, i1 frontend.Variable, i2 frontend.Variable, i3 frontend.Variable) frontend.Variable { + els := f.varsToElements(i0, i1, i2, i3) + switch vv := b0.(type) { + case Element[T]: + f.AssertIsBoolean(vv) + b0 = vv.Limbs[0] + case *Element[T]: + f.AssertIsBoolean(vv) + b0 = vv.Limbs[0] + } + switch vv := b1.(type) { + case Element[T]: + f.AssertIsBoolean(vv) + b1 = vv.Limbs[0] + case *Element[T]: + f.AssertIsBoolean(vv) + b1 = vv.Limbs[0] + } + if els[0].overflow == els[1].overflow && els[0].overflow == els[2].overflow && els[0].overflow == els[3].overflow && len(els[0].Limbs) == len(els[1].Limbs) && len(els[0].Limbs) == len(els[2].Limbs) && len(els[0].Limbs) == len(els[3].Limbs) { + return f.lookup2(b0, b1, els[0], els[1], els[2], els[3]) + } + s0 := els[0] + s1 := els[1] + s2 := els[2] + s3 := els[3] + if s0.overflow != 0 || len(s0.Limbs) != int(f.fParams.NbLimbs()) { + s0 = f.reduce(s0) + } + if s1.overflow != 0 || len(s1.Limbs) != int(f.fParams.NbLimbs()) { + s1 = f.reduce(s1) + } + if s2.overflow != 0 || len(s2.Limbs) != int(f.fParams.NbLimbs()) { + s2 = f.reduce(s2) + } + if s3.overflow != 0 || len(s3.Limbs) != int(f.fParams.NbLimbs()) { + s3 = f.reduce(s3) + } + return f.lookup2(b0, b1, s0, s1, s2, s3) +} + +func (f *field[T]) IsZero(i1 frontend.Variable) frontend.Variable { + el := f.varToElement(i1) + reduced := f.reduce(el) + res := f.api.IsZero(reduced.Limbs[0]) + for i := 1; i < len(reduced.Limbs); i++ { + f.api.Mul(res, f.api.IsZero(reduced.Limbs[i])) + } + r := f.PackLimbs([]frontend.Variable{res}) + + f.EnforceWidth(r) + return r +} + +func (f *field[T]) Cmp(i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { + els := f.varsToElements(i1, i2) + rls := make([]Element[T], 2) + rls[0] = f.reduce(els[0]) + rls[1] = f.reduce(els[1]) + var res frontend.Variable = 0 + for i := int(f.fParams.NbLimbs() - 1); i >= 0; i-- { + lmbCmp := f.api.Cmp(rls[0].Limbs[i], rls[1].Limbs[i]) + res = f.api.Select(f.api.IsZero(res), lmbCmp, res) + } + return res +} + +func (f *field[T]) AssertIsEqual(i1 frontend.Variable, i2 frontend.Variable) { + els := f.varsToElements(i1, i2) + tmp := NewElement[T](els[0]) + // tmp.Set(els[0]) // TODO @gbotrel do we need to duplicate here? + f.reduceAndOp(func(a, b Element[T], nextOverflow uint) Element[T] { + f.assertIsEqual(a, b) + return NewElement[T](nil) + }, + func(e1, e2 Element[T]) (uint, error) { + nextOverflow, err := f.subPreCond(e2, e1) // TODO @gbotrel previously "tmp.sub..." + var target errOverflow + if err != nil && errors.As(err, &target) { + target.reduceRight = !target.reduceRight + return nextOverflow, target + } + return nextOverflow, err + }, tmp, els[1]) +} + +func (f *field[T]) AssertIsDifferent(i1 frontend.Variable, i2 frontend.Variable) { + els := f.varsToElements(i1, i2) + rls := []Element[T]{NewElement[T](nil), NewElement[T](nil)} + rls[0] = f.reduce(els[0]) + rls[1] = f.reduce(els[1]) + var res frontend.Variable = 0 + for i := 0; i < int(f.fParams.NbLimbs()); i++ { + cmp := f.api.Cmp(rls[0].Limbs[i], rls[1].Limbs[i]) + cmpsq := f.api.Mul(cmp, cmp) + res = f.api.Add(res, cmpsq) + } + f.api.AssertIsDifferent(res, 0) +} + +func (f *field[T]) AssertIsBoolean(i1 frontend.Variable) { + switch vv := i1.(type) { + case Element[T]: + v := f.reduce(vv) + f.api.AssertIsBoolean(v.Limbs[0]) + for i := 1; i < len(v.Limbs); i++ { + f.api.AssertIsEqual(v.Limbs[i], 0) + } + case *Element[T]: + v := f.reduce(*vv) + f.api.AssertIsBoolean(v.Limbs[0]) + for i := 1; i < len(v.Limbs); i++ { + f.api.AssertIsEqual(v.Limbs[i], 0) + } + default: + f.api.AssertIsBoolean(vv) + } +} + +func (f *field[T]) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { + els := f.varsToElements(v, bound) + l := f.reduce(els[0]) + r := f.reduce(els[1]) + f.AssertIsLessEqualThan(l, r) +} + +func (f *field[T]) Println(a ...frontend.Variable) { + els := f.varsToElements(a) + for i := range els { + f.api.Println(els[i].Limbs...) + } +} + +func (f *field[T]) Compiler() frontend.Compiler { + return f +} + +type typedInput struct { + pos int + nbLimbs int + isElement bool +} + +func (f *field[T]) NewHint(hf hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { + // this is a trick to allow calling hint functions using non-native + // elements. We use the fact that the hints take as inputs *big.Int values. + // Instead of supplying hf to the solver for calling, we wrap it with + // another function (implementing hint.Function), which takes as inputs the + // "expanded" version of inputs (where instead of Element[T] values we provide + // as inputs the limbs of every Element[T]) and returns nbLimbs*nbOutputs + // number of outputs (i.e. the limbs of non-native Element[T] values). The + // wrapper then recomposes and decomposes the *big.Int values at runtime and + // provides them as input to the initially provided hint function. + var expandedInputs []frontend.Variable + typedInputs := make([]typedInput, len(inputs)) + for i := range inputs { + switch vv := inputs[i].(type) { + case Element[T]: + expandedInputs = append(expandedInputs, vv.Limbs...) + typedInputs[i] = typedInput{ + pos: len(expandedInputs) - len(vv.Limbs), + nbLimbs: len(vv.Limbs), + isElement: true, + } + case *Element[T]: + expandedInputs = append(expandedInputs, vv.Limbs...) + typedInputs[i] = typedInput{ + pos: len(expandedInputs) - len(vv.Limbs), + nbLimbs: len(vv.Limbs), + isElement: true, + } + default: + expandedInputs = append(expandedInputs, inputs[i]) + typedInputs[i] = typedInput{ + pos: len(expandedInputs) - 1, + nbLimbs: 1, + isElement: false, + } + } + } + nbNativeOutputs := nbOutputs * int(f.fParams.NbLimbs()) + wrappedHint := func(_ *big.Int, expandedHintInputs []*big.Int, expandedHintOutputs []*big.Int) error { + hintInputs := make([]*big.Int, len(inputs)) + hintOutputs := make([]*big.Int, nbOutputs) + for i, ti := range typedInputs { + hintInputs[i] = new(big.Int) + if ti.isElement { + if err := recompose(expandedHintInputs[ti.pos:ti.pos+ti.nbLimbs], f.fParams.BitsPerLimb(), hintInputs[i]); err != nil { + return fmt.Errorf("recompose: %w", err) + } + } else { + hintInputs[i].Set(expandedHintInputs[ti.pos]) + } + } + for i := range hintOutputs { + hintOutputs[i] = new(big.Int) + } + if err := hf(f.fParams.Modulus(), hintInputs, hintOutputs); err != nil { + return fmt.Errorf("call hint: %w", err) + } + for i := range hintOutputs { + if err := decompose(hintOutputs[i], f.fParams.BitsPerLimb(), expandedHintOutputs[i*int(f.fParams.NbLimbs()):(i+1)*int(f.fParams.NbLimbs())]); err != nil { + return fmt.Errorf("decompose: %w", err) + } + } + return nil + } + hintRet, err := f.api.Compiler().NewHint(wrappedHint, nbNativeOutputs, expandedInputs...) + if err != nil { + return nil, fmt.Errorf("NewHint: %w", err) + } + ret := make([]frontend.Variable, nbOutputs) + for i := 0; i < nbOutputs; i++ { + el := NewElement[T](nil) + el.Limbs = hintRet[i*int(f.fParams.NbLimbs()) : (i+1)*int(f.fParams.NbLimbs())] + ret[i] = el + } + return ret, nil +} + +func (f *field[T]) Tag(name string) frontend.Tag { + return f.api.Compiler().Tag(name) +} + +func (f *field[T]) AddCounter(from frontend.Tag, to frontend.Tag) { + f.api.Compiler().AddCounter(from, to) +} + +func (f *field[T]) ConstantValue(v frontend.Variable) (*big.Int, bool) { + var limbs []frontend.Variable // emulated limbs + switch vv := v.(type) { + case Element[T]: + limbs = vv.Limbs + case *Element[T]: + limbs = vv.Limbs + case []frontend.Variable: + limbs = vv + default: + return f.api.Compiler().ConstantValue(vv) + } + var ok bool + + constLimbs := make([]*big.Int, len(limbs)) + for i, l := range limbs { + // for each limb we get it's constant value if we can, or fail. + if constLimbs[i], ok = f.ConstantValue(l); !ok { + return nil, false + } + } + + res := new(big.Int) + if err := recompose(constLimbs, f.fParams.BitsPerLimb(), res); err != nil { + f.log.Error().Err(err).Msg("recomposing constant") + return nil, false + } + return res, true +} + +func (f *field[T]) Field() *big.Int { + return f.fParams.Modulus() +} + +func (f *field[T]) FieldBitLen() int { + return f.fParams.Modulus().BitLen() +} + +func (f *field[T]) IsBoolean(v frontend.Variable) bool { + switch vv := v.(type) { + case Element[T]: + return f.api.Compiler().IsBoolean(vv.Limbs[0]) + case *Element[T]: + return f.api.Compiler().IsBoolean(vv.Limbs[0]) + default: + return f.api.Compiler().IsBoolean(vv) + } +} + +func (f *field[T]) MarkBoolean(v frontend.Variable) { + switch vv := v.(type) { + case Element[T]: + f.api.Compiler().MarkBoolean(vv.Limbs[0]) + case *Element[T]: + f.api.Compiler().MarkBoolean(vv.Limbs[0]) + default: + f.api.Compiler().MarkBoolean(vv) + } +} + +// Modulus returns the modulus of the emulated ring as a constant. The returned +// element is not safe to use as an operation receiver. +func (f *field[T]) Modulus() Element[T] { + f.nConstOnce.Do(func() { + f.nConst = NewElement[T](f.fParams.Modulus()) + }) + return f.nConst +} + +// Zero returns zero as a constant. The returned element is not safe to use as +// an operation receiver. +func (f *field[T]) Zero() Element[T] { + f.zeroConstOnce.Do(func() { + f.zeroConst = NewElement[T](nil) + }) + return f.zeroConst +} + +// One returns one as a constant. The returned element is not safe to use as an +// operation receiver. +func (f *field[T]) One() Element[T] { + f.oneConstOnce.Do(func() { + f.oneConst = NewElement[T](1) + }) + return f.oneConst +} + +// PackLimbs returns a constant element from the given limbs. The +// returned element is not safe to use as an operation receiver. +func (f *field[T]) PackLimbs(limbs []frontend.Variable) Element[T] { + // TODO: check that every limb does not overflow the expected width + + return Element[T]{ + Limbs: limbs, + overflow: 0, + } +} + +// builderWrapper returns a wrapper for the builder which is compatible to use +// as a frontend compile option. When using this wrapper, it is possible to +// extend existing circuits into any emulated field defined by +func builderWrapper[T FieldParams]() frontend.BuilderWrapper { + return func(b frontend.Builder) frontend.Builder { + fw, err := NewField[T](b) + if err != nil { + panic(err) + } + fw.(*field[T]).builder = b + return fw.(*field[T]) + } +} + +func (f *field[T]) Compile() (frontend.CompiledConstraintSystem, error) { + return f.builder.Compile() +} + +func (f *field[T]) SetSchema(s *schema.Schema) { + f.builder.SetSchema(s) +} + +func (f *field[T]) VariableCount(t reflect.Type) int { + return int(f.fParams.NbLimbs()) +} + +func (f *field[T]) addVariable(sf *schema.Field, recurseFn func(*schema.Field) frontend.Variable) frontend.Variable { + limbs := make([]frontend.Variable, f.fParams.NbLimbs()) + var subfs []schema.Field + for i := range limbs { + subf := schema.Field{ + Name: strconv.Itoa(i), + Visibility: sf.Visibility, + FullName: fmt.Sprintf("%s_%d", sf.FullName, i), + Type: schema.Leaf, + ArraySize: 1, + } + subfs = append(subfs, subf) + limbs[i] = recurseFn(&subf) + } + sf.ArraySize = len(subfs) + sf.Type = schema.Array + sf.SubFields = subfs + el := f.PackLimbs(limbs) + return el +} + +func (f *field[T]) AddPublicVariable(sf *schema.Field) frontend.Variable { + return f.addVariable(sf, f.builder.AddPublicVariable) +} + +func (f *field[T]) AddSecretVariable(sf *schema.Field) frontend.Variable { + return f.addVariable(sf, f.builder.AddSecretVariable) + +} diff --git a/std/math/emulated/field_test.go b/std/math/emulated/field_test.go new file mode 100644 index 0000000000..06bece4488 --- /dev/null +++ b/std/math/emulated/field_test.go @@ -0,0 +1,389 @@ +package emulated + +import ( + "crypto/rand" + "errors" + "fmt" + "math/big" + "sort" + "testing" + + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/internal/backend/circuits" + "github.com/consensys/gnark/std/algebra/fields_bls12377" + "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" +) + +func witnessData(q *big.Int) (X1, X2, X3, X4, X5, X6, Res *big.Int) { + x1, _ := rand.Int(rand.Reader, q) + x2, _ := rand.Int(rand.Reader, q) + x3, _ := rand.Int(rand.Reader, q) + x4, _ := rand.Int(rand.Reader, q) + x5, _ := rand.Int(rand.Reader, q) + x6, _ := rand.Int(rand.Reader, q) + + tmp := new(big.Int) + res := new(big.Int) + // res = x1^3 + tmp.Exp(x1, big.NewInt(3), q) + res.Set(tmp) + // res = x1^3 + 5*x2 + tmp.Mul(x2, big.NewInt(5)) + res.Add(res, tmp) + // tmp = (x3-x4) + tmp.Sub(x3, x4) + tmp.Mod(tmp, q) + // tmp2 = (x5+x6) + tmp2 := new(big.Int) + tmp2.Add(x5, x6) + // tmp = (x3-x4)/(x5+x6) + tmp2.ModInverse(tmp2, q) + tmp.Mul(tmp, tmp2) + tmp.Mod(tmp, q) + // res = x1^3 + 5*x2 + (x3-x4)/(x5+x6) + res.Add(res, tmp) + res.Mod(res, q) + return x1, x2, x3, x4, x5, x6, res +} + +type WrapperCircuit struct { + X1, X2, X3, X4, X5, X6 frontend.Variable + Res frontend.Variable +} + +func (c *WrapperCircuit) Define(api frontend.API) error { + // compute x1^3 + 5*x2 + (x3-x4) / (x5+x6) + x13 := api.Mul(c.X1, c.X1, c.X1) + fx2 := api.Mul(5, c.X2) + nom := api.Sub(c.X3, c.X4) + denom := api.Add(c.X5, c.X6) + free := api.Div(nom, denom) + res := api.Add(x13, fx2, free) + api.AssertIsEqual(res, c.Res) + return nil +} + +func TestTestEngineWrapper(t *testing.T) { + assert := test.NewAssert(t) + + circuit := WrapperCircuit{ + X1: NewElement[Secp256k1](nil), + X2: NewElement[Secp256k1](nil), + X3: NewElement[Secp256k1](nil), + X4: NewElement[Secp256k1](nil), + X5: NewElement[Secp256k1](nil), + X6: NewElement[Secp256k1](nil), + Res: NewElement[Secp256k1](nil), + } + + x1, x2, x3, x4, x5, x6, res := witnessData(Secp256k1{}.Modulus()) + witness := WrapperCircuit{ + X1: NewElement[Secp256k1](x1), + X2: NewElement[Secp256k1](x2), + X3: NewElement[Secp256k1](x3), + X4: NewElement[Secp256k1](x4), + X5: NewElement[Secp256k1](x5), + X6: NewElement[Secp256k1](x6), + Res: NewElement[Secp256k1](res), + } + wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { + napi, err := NewField[Secp256k1](api) + assert.NoError(err) + return napi + }) + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), wrapperOpt) + assert.NoError(err) +} + +func TestCompilerWrapper(t *testing.T) { + assert := test.NewAssert(t) + circuit := WrapperCircuit{ + X1: NewElement[Secp256k1](nil), + X2: NewElement[Secp256k1](nil), + X3: NewElement[Secp256k1](nil), + X4: NewElement[Secp256k1](nil), + X5: NewElement[Secp256k1](nil), + X6: NewElement[Secp256k1](nil), + Res: NewElement[Secp256k1](nil), + } + + x1, x2, x3, x4, x5, x6, res := witnessData(Secp256k1{}.Modulus()) + witness := WrapperCircuit{ + X1: NewElement[Secp256k1](x1), + X2: NewElement[Secp256k1](x2), + X3: NewElement[Secp256k1](x3), + X4: NewElement[Secp256k1](x4), + X5: NewElement[Secp256k1](x5), + X6: NewElement[Secp256k1](x6), + Res: NewElement[Secp256k1](res), + } + ccs, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithBuilderWrapper(builderWrapper[Secp256k1]())) + assert.NoError(err) + t.Log(ccs.GetNbConstraints()) + // TODO: create proof + _ = witness +} + +func TestIntegrationApi(t *testing.T) { + assert := test.NewAssert(t) + wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { + napi, err := NewField[Secp256k1](api) + assert.NoError(err) + return napi + }) + keys := make([]string, 0, len(circuits.Circuits)) + for k := range circuits.Circuits { + keys = append(keys, k) + } + sort.Strings(keys) + + for i := range keys { + name := keys[i] + if name == "inv" || name == "div" || name == "cmp" { + // TODO @gbotrel thes don't pass when we use emulated field modulus != snark field + continue + } + tData := circuits.Circuits[name] + assert.Run(func(assert *test.Assert) { + _, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, tData.Circuit, frontend.WithBuilderWrapper(builderWrapper[Secp256k1]())) + assert.NoError(err) + }, name, "compile") + for i := range tData.ValidAssignments { + assignment := tData.ValidAssignments[i] + assert.Run(func(assert *test.Assert) { + err := test.IsSolved(tData.Circuit, assignment, testCurve.ScalarField(), wrapperOpt) + assert.NoError(err) + }, name, fmt.Sprintf("valid=%d", i)) + } + for i := range tData.InvalidAssignments { + assignment := tData.InvalidAssignments[i] + assert.Run(func(assert *test.Assert) { + err := test.IsSolved(tData.Circuit, assignment, testCurve.ScalarField(), wrapperOpt) + assert.Error(err) + }, name, fmt.Sprintf("invalid=%d", i)) + } + } +} + +func TestVarToElements(t *testing.T) { + assert := require.New(t) + _f, _ := NewField[BN254Fp](nil) + + f := _f.(*field[BN254Fp]) + + { + in := []frontend.Variable{8000, 42} + out1 := f.varsToElements(in...) + out2 := f.varsToElements(in) + + assert.Equal(len(out1), len(out2)) + assert.Equal(len(out1), 2) + } + + defer func() { + if r := recover(); r == nil { + t.Fatal("nil input should panic") + } + }() + in := []frontend.Variable{8000, nil, 3000} + _ = f.varsToElements(in) +} + +type pairingBLS377 struct { + P sw_bls12377.G1Affine `gnark:",public"` + Q sw_bls12377.G2Affine + pairingRes bls12377.GT +} + +//lint:ignore U1000 skipped test +func (circuit *pairingBLS377) Define(api frontend.API) error { + pairingRes, _ := sw_bls12377.Pair(api, + []sw_bls12377.G1Affine{circuit.P}, + []sw_bls12377.G2Affine{circuit.Q}) + api.AssertIsEqual(pairingRes.C0.B0.A0, &circuit.pairingRes.C0.B0.A0) + api.AssertIsEqual(pairingRes.C0.B0.A1, &circuit.pairingRes.C0.B0.A1) + api.AssertIsEqual(pairingRes.C0.B1.A0, &circuit.pairingRes.C0.B1.A0) + api.AssertIsEqual(pairingRes.C0.B1.A1, &circuit.pairingRes.C0.B1.A1) + api.AssertIsEqual(pairingRes.C0.B2.A0, &circuit.pairingRes.C0.B2.A0) + api.AssertIsEqual(pairingRes.C0.B2.A1, &circuit.pairingRes.C0.B2.A1) + api.AssertIsEqual(pairingRes.C1.B0.A0, &circuit.pairingRes.C1.B0.A0) + api.AssertIsEqual(pairingRes.C1.B0.A1, &circuit.pairingRes.C1.B0.A1) + api.AssertIsEqual(pairingRes.C1.B1.A0, &circuit.pairingRes.C1.B1.A0) + api.AssertIsEqual(pairingRes.C1.B1.A1, &circuit.pairingRes.C1.B1.A1) + api.AssertIsEqual(pairingRes.C1.B2.A0, &circuit.pairingRes.C1.B2.A0) + api.AssertIsEqual(pairingRes.C1.B2.A1, &circuit.pairingRes.C1.B2.A1) + return nil +} + +func TestPairingBLS377(t *testing.T) { + t.Skip() + assert := test.NewAssert(t) + + _, _, P, Q := bls12377.Generators() + milRes, _ := bls12377.MillerLoop([]bls12377.G1Affine{P}, []bls12377.G2Affine{Q}) + pairingRes := bls12377.FinalExponentiation(&milRes) + + circuit := pairingBLS377{ + pairingRes: pairingRes, + P: sw_bls12377.G1Affine{ + X: NewElement[BLS12377Fp](nil), + Y: NewElement[BLS12377Fp](nil), + }, + Q: sw_bls12377.G2Affine{ + X: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + Y: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + }, + } + witness := pairingBLS377{ + pairingRes: pairingRes, + P: sw_bls12377.G1Affine{ + X: NewElement[BLS12377Fp](P.X), + Y: NewElement[BLS12377Fp](P.Y), + }, + Q: sw_bls12377.G2Affine{ + X: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](Q.X.A0), + A1: NewElement[BLS12377Fp](Q.X.A1), + }, + Y: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](Q.Y.A0), + A1: NewElement[BLS12377Fp](Q.Y.A1), + }, + }, + } + + wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { + napi, err := NewField[BLS12377Fp](api) + assert.NoError(err) + return napi + }) + // TODO @gbotrel test engine when test.SetAllVariablesAsConstants() also consider witness as + // constant. It shouldn't. + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), wrapperOpt) //, test.SetAllVariablesAsConstants()) + assert.NoError(err) + // _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithBuilderWrapper(builderWrapper[BLS12377Fp]())) + // assert.NoError(err) + // TODO: create proof +} + +type ConstantCircuit struct { +} + +func (c *ConstantCircuit) Define(api frontend.API) error { + f, err := NewField[Secp256k1](api) + if err != nil { + return err + } + { + c1 := NewElement[Secp256k1](42) + b1, ok := f.ConstantValue(c1) + if !ok { + return errors.New("42 should be constant") + } + if !(b1.IsUint64() && b1.Uint64() == 42) { + return errors.New("42 != constant(42)") + } + } + { + m := f.(*field[Secp256k1]).Modulus() + b1, ok := f.ConstantValue(m) + if !ok { + return errors.New("modulus should be constant") + } + if b1.Cmp(Secp256k1{}.Modulus()) != 0 { + return errors.New("modulus != constant(modulus)") + } + } + + return nil +} + +func TestConstantCircuit(t *testing.T) { + assert := test.NewAssert(t) + + var circuit, witness ConstantCircuit + + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), test.SetAllVariablesAsConstants()) + assert.NoError(err) + + _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.IgnoreUnconstrainedInputs()) + assert.NoError(err) +} + +type MulConstantCircuit struct { +} + +func (c *MulConstantCircuit) Define(api frontend.API) error { + f, err := NewField[Secp256k1](api) + if err != nil { + return err + } + c0 := NewElement[Secp256k1](0) + c1 := NewElement[Secp256k1](0) + c2 := NewElement[Secp256k1](0) + r := f.Mul(c0, c1) + f.AssertIsEqual(r, c2) + + return nil +} + +func TestMulConstantCircuit(t *testing.T) { + assert := test.NewAssert(t) + + var circuit, witness MulConstantCircuit + + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), test.SetAllVariablesAsConstants()) + assert.NoError(err) + + _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.IgnoreUnconstrainedInputs()) + assert.NoError(err) +} + +type SubConstantCircuit struct { +} + +func (c *SubConstantCircuit) Define(api frontend.API) error { + f, err := NewField[Secp256k1](api) + if err != nil { + return err + } + c0 := NewElement[Secp256k1](0) + c1 := NewElement[Secp256k1](0) + c2 := NewElement[Secp256k1](0) + r := f.Sub(c0, c1) + if r.(Element[Secp256k1]).overflow != 0 { + return fmt.Errorf("overflow %d != 0", r.(Element[Secp256k1]).overflow) + } + // rc, ok := f.ConstantValue(r) + // if !ok { + // return errors.New("0 - 0 is not constant") + // } + // if !rc.IsUint64() && rc.Uint64() == 0 { + // return fmt.Errorf("0 - 0 = %s", rc.String()) + // } + f.AssertIsEqual(r, c2) + + return nil +} + +func TestSubConstantCircuit(t *testing.T) { + assert := test.NewAssert(t) + + var circuit, witness SubConstantCircuit + + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), test.SetAllVariablesAsConstants()) + assert.NoError(err) + + _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.IgnoreUnconstrainedInputs()) + assert.NoError(err) +} diff --git a/std/math/nonnative/hints.go b/std/math/emulated/hints.go similarity index 61% rename from std/math/nonnative/hints.go rename to std/math/emulated/hints.go index 8d5cf2f058..7620d3180c 100644 --- a/std/math/nonnative/hints.go +++ b/std/math/emulated/hints.go @@ -1,4 +1,4 @@ -package nonnative +package emulated import ( "fmt" @@ -8,21 +8,28 @@ import ( "github.com/consensys/gnark/frontend" ) +// TODO @gbotrel hint[T FieldParams] would simplify this . Issue is when registering hint, if QuoRem[T] was declared +// inside a func, then it becomes anonymous and hint identification is screwed. + +func init() { + hint.Register(GetHints()...) +} + // GetHints returns all hint functions used in the package. func GetHints() []hint.Function { return []hint.Function{ DivHint, - EqualityHint, + QuoHint, InverseHint, MultiplicationHint, - ReductionHint, + RemHint, } } // computeMultiplicationHint packs the inputs for the MultiplicationHint hint function. -func computeMultiplicationHint(api frontend.API, params *Params, leftLimbs, rightLimbs []frontend.Variable) (mulLimbs []frontend.Variable, err error) { +func computeMultiplicationHint[T FieldParams](api frontend.API, params *field[T], leftLimbs, rightLimbs []frontend.Variable) (mulLimbs []frontend.Variable, err error) { hintInputs := []frontend.Variable{ - params.nbBits, + params.fParams.BitsPerLimb(), len(leftLimbs), len(rightLimbs), } @@ -75,117 +82,93 @@ func MultiplicationHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) err return nil } -// computeReductionHint packs inputs for the ReductionHint hint function. -func computeReductionHint(api frontend.API, params *Params, inLimbs []frontend.Variable) (reducedLimbs []frontend.Variable, err error) { +// computeRemHint packs inputs for the RemHint hint function. +// sets z to the remainder x%y for y != 0 and returns z. +func (f *field[T]) computeRemHint(x, y Element[T]) (z Element[T], err error) { + var fp T hintInputs := []frontend.Variable{ - params.nbBits, - params.nbLimbs, + fp.BitsPerLimb(), + len(x.Limbs), } - p := params.Modulus() - for i := range p.Limbs { - hintInputs = append(hintInputs, frontend.Variable(p.Limbs[i])) + hintInputs = append(hintInputs, x.Limbs...) + hintInputs = append(hintInputs, y.Limbs...) + limbs, err := f.api.NewHint(RemHint, int(len(y.Limbs)), hintInputs...) + if err != nil { + return Element[T]{}, err } - hintInputs = append(hintInputs, inLimbs...) - return api.NewHint(ReductionHint, int(params.nbLimbs), hintInputs...) + return f.PackLimbs(limbs), nil } -// ReductionHint computes the remainder r for input x = k*p + r and stores it -// in outputs. -func ReductionHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - if len(inputs) < 2 { - return fmt.Errorf("input must be at least two elements") - } - nbBits := uint(inputs[0].Uint64()) - nbLimbs := int(inputs[1].Int64()) - if len(inputs[2:]) < 2*nbLimbs { - return fmt.Errorf("reducible value missing") - } - if len(outputs) != nbLimbs { - return fmt.Errorf("result does not fit into output") +// RemHint sets z to the remainder x%y for y != 0 and returns z. +// If y == 0, returns an error. +// Rem implements truncated modulus (like Go); see QuoRem for more details. +func RemHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + nbBits, _, x, y, err := parseHintDivInputs(inputs) + if err != nil { + return err } - p := new(big.Int) - if err := recompose(inputs[2:2+nbLimbs], nbBits, p); err != nil { - return fmt.Errorf("recompose emulated order: %w", err) - } - x := new(big.Int) - if err := recompose(inputs[2+nbLimbs:], nbBits, x); err != nil { - return fmt.Errorf("recompose value: %w", err) - } - q := new(big.Int) r := new(big.Int) - q.QuoRem(x, p, r) + r.Rem(x, y) if err := decompose(r, nbBits, outputs); err != nil { return fmt.Errorf("decompose remainder: %w", err) } return nil } -// computeEqualityHint packs the inputs for EqualityHint function. -func computeEqualityHint(api frontend.API, params *Params, diff Element) (kpLimbs []frontend.Variable, err error) { - p := params.Modulus() - resLen := (uint(len(diff.Limbs))*params.nbBits + diff.overflow + 1 - // diff total bitlength - uint(params.r.BitLen()) + // subtract modulus bitlength - params.nbBits - 1) / // to round up - params.nbBits +// computeQuoHint packs the inputs for QuoHint function and returns z = x / y +// (discards remainder) +func (f *field[T]) computeQuoHint(x Element[T]) (z Element[T], err error) { + var fp T + resLen := (uint(len(x.Limbs))*fp.BitsPerLimb() + x.overflow + 1 - // diff total bitlength + uint(fp.Modulus().BitLen()) + // subtract modulus bitlength + fp.BitsPerLimb() - 1) / // to round up + fp.BitsPerLimb() + hintInputs := []frontend.Variable{ - params.nbBits, - params.nbLimbs, + fp.BitsPerLimb(), + len(x.Limbs), } + p := f.Modulus() + hintInputs = append(hintInputs, x.Limbs...) hintInputs = append(hintInputs, p.Limbs...) - hintInputs = append(hintInputs, diff.Limbs...) - return api.NewHint(EqualityHint, int(resLen), hintInputs...) -} -// EqualityHint computes k for input x = k*p and stores it in outputs. -func EqualityHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { - // first value is the number of bits per limb (nbBits) - // second value is the number of limbs for modulus - // then comes emulated modulus (p) - // and the rest is for the difference (diff) - // - // if the quotient z = diff / p is larger than the scalar modulus, then err. - // Otherwise we store the quotient in the output element (a single element). - // - // errors when does not divide evenly (we do not allow to generate invalid - // proofs) - if len(inputs) < 2 { - return fmt.Errorf("at least 2 inputs required") + limbs, err := f.api.NewHint(QuoHint, int(resLen), hintInputs...) + if err != nil { + return Element[T]{}, err } - nbBits := uint(inputs[0].Uint64()) - nbLimbs := int(inputs[1].Int64()) - if len(inputs[2:]) < nbLimbs { - return fmt.Errorf("modulus limbs missing") - } - p := new(big.Int) - diff := new(big.Int) - if err := recompose(inputs[2:2+nbLimbs], nbBits, p); err != nil { - return fmt.Errorf("recompose emulated order: %w", err) - } - if err := recompose(inputs[2+nbLimbs:], nbBits, diff); err != nil { - return fmt.Errorf("recompose diff") + + return f.PackLimbs(limbs), nil +} + +// QuoHint sets z to the quotient x/y for y != 0 and returns z. +// If y == 0, returns an error. +// Quo implements truncated division (like Go); see QuoRem for more details. +func QuoHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + nbBits, _, x, y, err := parseHintDivInputs(inputs) + if err != nil { + return err } - r := new(big.Int) z := new(big.Int) - z.QuoRem(diff, p, r) - if r.Cmp(big.NewInt(0)) != 0 { - return fmt.Errorf("modulus does not divide diff evenly") - } + z.Quo(x, y) //.Mod(z, y) + if err := decompose(z, nbBits, outputs); err != nil { return fmt.Errorf("decompose: %w", err) } + return nil } // computeInverseHint packs the inputs for the InverseHint hint function. -func computeInverseHint(api frontend.API, params *Params, inLimbs []frontend.Variable) (inverseLimbs []frontend.Variable, err error) { +func computeInverseHint[T FieldParams](api frontend.API, params *field[T], inLimbs []frontend.Variable) (inverseLimbs []frontend.Variable, err error) { + var fp T hintInputs := []frontend.Variable{ - params.nbBits, - params.nbLimbs, + fp.BitsPerLimb(), + fp.NbLimbs(), } p := params.Modulus() hintInputs = append(hintInputs, p.Limbs...) hintInputs = append(hintInputs, inLimbs...) - return api.NewHint(InverseHint, int(params.nbLimbs), hintInputs...) + return api.NewHint(InverseHint, int(fp.NbLimbs()), hintInputs...) } // InverseHint computes the inverse x^-1 for the input x and stores it in outputs. @@ -219,17 +202,18 @@ func InverseHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { } // computeDivisionHint packs the inputs for DivisionHint hint function. -func computeDivisionHint(api frontend.API, params *Params, nomLimbs, denomLimbs []frontend.Variable) (divLimbs []frontend.Variable, err error) { +func computeDivisionHint[T FieldParams](api frontend.API, params *field[T], nomLimbs, denomLimbs []frontend.Variable) (divLimbs []frontend.Variable, err error) { + var fp T hintInputs := []frontend.Variable{ - params.nbBits, - params.nbLimbs, + fp.BitsPerLimb(), + fp.NbLimbs(), len(nomLimbs), } p := params.Modulus() hintInputs = append(hintInputs, p.Limbs...) hintInputs = append(hintInputs, nomLimbs...) hintInputs = append(hintInputs, denomLimbs...) - return api.NewHint(DivHint, int(params.nbLimbs), hintInputs...) + return api.NewHint(DivHint, int(fp.NbLimbs()), hintInputs...) } // DivHint computes the value z = x/y for inputs x and y and stores z in @@ -272,3 +256,30 @@ func DivHint(mod *big.Int, inputs []*big.Int, outputs []*big.Int) error { } return nil } + +// input[0] = nbBits per limb +// input[1] = nbLimbs(x) +// input[2:2+nbLimbs(x)] = limbs(x) +// input[2+nbLimbs(x):] = limbs(y) +// errors if y == 0 +func parseHintDivInputs(inputs []*big.Int) (uint, int, *big.Int, *big.Int, error) { + if len(inputs) < 2 { + return 0, 0, nil, nil, fmt.Errorf("at least 2 inputs required") + } + nbBits := uint(inputs[0].Uint64()) + nbLimbs := int(inputs[1].Int64()) + if len(inputs[2:]) < nbLimbs { + return 0, 0, nil, nil, fmt.Errorf("x limbs missing") + } + x, y := new(big.Int), new(big.Int) + if err := recompose(inputs[2:2+nbLimbs], nbBits, x); err != nil { + return 0, 0, nil, nil, fmt.Errorf("recompose x: %w", err) + } + if err := recompose(inputs[2+nbLimbs:], nbBits, y); err != nil { + return 0, 0, nil, nil, fmt.Errorf("recompose y: %w", err) + } + if y.IsUint64() && y.Uint64() == 0 { + return 0, 0, nil, nil, fmt.Errorf("y == 0") + } + return nbBits, nbLimbs, x, y, nil +} diff --git a/std/math/emulated/pairing_test.go b/std/math/emulated/pairing_test.go new file mode 100644 index 0000000000..6391a038a5 --- /dev/null +++ b/std/math/emulated/pairing_test.go @@ -0,0 +1,116 @@ +package emulated + +import ( + "testing" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/consensys/gnark/std/algebra/fields_bls12377" + "github.com/consensys/gnark/std/algebra/sw_bls12377" + "github.com/consensys/gnark/test" +) + +type mlBLS377 struct { + R sw_bls12377.GT +} + +func (circuit *mlBLS377) Define(api frontend.API) error { + circuit.R, _ = e12Squares(api, circuit.R) + return nil +} + +func TestE12SquareBLS377(t *testing.T) { + if testing.Short() { + t.Skip() + } + assert := test.NewAssert(t) + + circuit := mlBLS377{ + R: sw_bls12377.GT{ + C0: fields_bls12377.E6{ + B0: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B1: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B2: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + }, + C1: fields_bls12377.E6{ + B0: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B1: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B2: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + }, + }, + } + + witness := mlBLS377{ + R: sw_bls12377.GT{ + C0: fields_bls12377.E6{ + B0: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B1: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B2: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + }, + C1: fields_bls12377.E6{ + B0: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B1: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + B2: fields_bls12377.E2{ + A0: NewElement[BLS12377Fp](nil), + A1: NewElement[BLS12377Fp](nil), + }, + }, + }, + } + + wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { + napi, err := NewField[BLS12377Fp](api) + assert.NoError(err) + return napi + }) + + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField(), wrapperOpt) //, test.SetAllVariablesAsConstants()) + assert.NoError(err) + + _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithBuilderWrapper(builderWrapper[BLS12377Fp]()), frontend.IgnoreUnconstrainedInputs()) + assert.NoError(err) + +} + +// e12Squares +func e12Squares(api frontend.API, R sw_bls12377.GT) (sw_bls12377.GT, error) { + const N = 4 + for i := 0; i < N; i++ { + R.Square(api, R) + } + + return R, nil +} diff --git a/std/math/emulated/params.go b/std/math/emulated/params.go new file mode 100644 index 0000000000..067b6c9f9f --- /dev/null +++ b/std/math/emulated/params.go @@ -0,0 +1,61 @@ +package emulated + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" +) + +// FieldParams describe the emulated field characteristics +type FieldParams interface { + NbLimbs() uint + BitsPerLimb() uint // limbSize is number of bits per limb. Top limb may contain less than limbSize bits. + IsPrime() bool + Modulus() *big.Int // TODO @gbotrel built-in don't copy value, we probably should. +} + +var ( + qSecp256k1 *big.Int + qGoldilocks *big.Int +) + +func init() { + qSecp256k1, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) + qGoldilocks, _ = new(big.Int).SetString("ffffffff00000001", 16) +} + +// Goldilocks provide type parametrization for emulated field on 1 limb of width 64bits +// for modulus 0xffffffff00000001 +type Goldilocks struct{} + +func (fp Goldilocks) NbLimbs() uint { return 1 } +func (fp Goldilocks) BitsPerLimb() uint { return 64 } +func (fp Goldilocks) IsPrime() bool { return true } +func (fp Goldilocks) Modulus() *big.Int { return qGoldilocks } + +// Secp256k1 provide type parametrization for emulated field on 8 limb of width 32bits +// for modulus 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f +type Secp256k1 struct{} + +func (fp Secp256k1) NbLimbs() uint { return 8 } +func (fp Secp256k1) BitsPerLimb() uint { return 32 } +func (fp Secp256k1) IsPrime() bool { return true } +func (fp Secp256k1) Modulus() *big.Int { return qSecp256k1 } + +// BN254Fp provide type parametrization for emulated field on 8 limb of width 32bits +// for modulus 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 +type BN254Fp struct{} + +func (fp BN254Fp) NbLimbs() uint { return 8 } +func (fp BN254Fp) BitsPerLimb() uint { return 32 } +func (fp BN254Fp) IsPrime() bool { return true } +func (fp BN254Fp) Modulus() *big.Int { return ecc.BN254.BaseField() } + +// BLS12377Fp provide type parametrization for emulated field on 6 limb of width 64bits +// for modulus 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +type BLS12377Fp struct{} + +func (fp BLS12377Fp) NbLimbs() uint { return 6 } +func (fp BLS12377Fp) BitsPerLimb() uint { return 64 } +func (fp BLS12377Fp) IsPrime() bool { return true } +func (fp BLS12377Fp) Modulus() *big.Int { return ecc.BLS12_377.BaseField() } diff --git a/std/math/nonnative/api.go b/std/math/nonnative/api.go deleted file mode 100644 index ca6aac9e00..0000000000 --- a/std/math/nonnative/api.go +++ /dev/null @@ -1,587 +0,0 @@ -package nonnative - -import ( - "errors" - "fmt" - "math/big" - "reflect" - "strconv" - - "github.com/consensys/gnark/backend/hint" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/compiled" - "github.com/consensys/gnark/frontend/schema" -) - -// BuilderWrapper returns a wrapper for the builder which is compatible to use -// as a frontend compile option. When using this wrapper, it is possible to -// extend existing circuits into any emulated field defined by params. -func BuilderWrapper(params *Params) frontend.BuilderWrapper { - return func(b frontend.Builder) frontend.Builder { - return &fakeAPI{ - builder: b, - api: b, - params: params, - } - } -} - -func (f *fakeAPI) Compile() (frontend.CompiledConstraintSystem, error) { - return f.builder.Compile() -} - -func (f *fakeAPI) SetSchema(s *schema.Schema) { - f.builder.SetSchema(s) -} - -func (f *fakeAPI) VariableCount(t reflect.Type) int { - return int(f.params.nbLimbs) -} - -func (f *fakeAPI) addVariable(field *schema.Field, recurseFn func(*schema.Field) frontend.Variable) frontend.Variable { - limbs := make([]frontend.Variable, f.params.nbLimbs) - var subfs []schema.Field - for i := range limbs { - subf := schema.Field{ - Name: strconv.Itoa(i), - Visibility: field.Visibility, - FullName: fmt.Sprintf("%s_%d", field.FullName, i), - Type: schema.Leaf, - ArraySize: 1, - } - subfs = append(subfs, subf) - limbs[i] = recurseFn(&subf) - } - field.ArraySize = len(subfs) - field.Type = schema.Array - field.SubFields = subfs - el := f.params.ConstantFromLimbs(limbs) - return el -} - -func (f *fakeAPI) AddPublicVariable(field *schema.Field) frontend.Variable { - return f.addVariable(field, f.builder.AddPublicVariable) -} - -func (f *fakeAPI) AddSecretVariable(field *schema.Field) frontend.Variable { - return f.addVariable(field, f.builder.AddSecretVariable) - -} - -// NewAPI wraps the existing native API such that all methods are performed -// using field emulation. -func NewAPI(native frontend.API, params *Params) frontend.API { - return &fakeAPI{ - api: native, - params: params, - } -} - -type fakeAPI struct { - // api is the native API - api frontend.API - builder frontend.Builder - params *Params -} - -func (f *fakeAPI) varToElement(in frontend.Variable) *Element { - var e *Element - switch vv := in.(type) { - case Element: - e = &vv - case *Element: - e = vv - case *big.Int: - el := f.params.ConstantFromBigOrPanic(vv) - e = &el - case big.Int: - el := f.params.ConstantFromBigOrPanic(&vv) - e = &el - case int: - el := f.params.ConstantFromBigOrPanic(big.NewInt(int64(vv))) - e = &el - case string: - elb := new(big.Int) - elb.SetString(vv, 10) - el := f.params.ConstantFromBigOrPanic(elb) - e = &el - case interface{ ToBigIntRegular(*big.Int) *big.Int }: - b := new(big.Int) - vv.ToBigIntRegular(b) - el := f.params.ConstantFromBigOrPanic(b) - e = &el - case compiled.LinearExpression: - el := f.params.ConstantFromLimbs([]frontend.Variable{in}) - e = &el - case compiled.Term: - el := f.params.ConstantFromLimbs([]frontend.Variable{in}) - e = &el - default: - panic(fmt.Sprintf("can not cast %T to *Element", in)) - } - if !f.params.isEqual(e.params) { - panic("incompatible Element parameters") - } - return e -} - -func (f *fakeAPI) varsToElements(in ...frontend.Variable) []*Element { - var els []*Element - for i := range in { - switch v := in[i].(type) { - case []frontend.Variable: - subels := f.varsToElements(v...) - els = append(els, subels...) - case frontend.Variable: - els = append(els, f.varToElement(v)) - } - } - return els -} - -func (f *fakeAPI) Add(i1 frontend.Variable, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - els := f.varsToElements(i1, i2, in) - res := f.params.Element(f.api) - res.reduceAndOp(res.add, res.addPreCond, els[0], els[1]) - for i := 2; i < len(els); i++ { - res.reduceAndOp(res.add, res.addPreCond, &res, els[i]) - } - return &res -} - -func (f *fakeAPI) Neg(i1 frontend.Variable) frontend.Variable { - el := f.varToElement(i1) - res := f.params.Element(f.api) - res.Negate(*el) - return &res -} - -func (f *fakeAPI) Sub(i1 frontend.Variable, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - els := f.varsToElements(i1, i2, in) - sub := f.params.Element(f.api) - sub.Set(*els[1]) - for i := 2; i < len(els); i++ { - sub.reduceAndOp(sub.add, sub.addPreCond, &sub, els[i]) - } - res := f.params.Element(f.api) - res.reduceAndOp(res.sub, res.subPreCond, els[0], &sub) - return &res -} - -func (f *fakeAPI) Mul(i1 frontend.Variable, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - els := f.varsToElements(i1, i2, in) - res := f.params.Element(f.api) - res.reduceAndOp(res.mul, res.mulPreCond, els[0], els[1]) - for i := 2; i < len(els); i++ { - res.reduceAndOp(res.mul, res.mulPreCond, &res, els[i]) - } - return &res -} - -func (f *fakeAPI) DivUnchecked(i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { - return f.Div(i1, i2) -} - -func (f *fakeAPI) Div(i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { - els := f.varsToElements(i1, i2) - res := f.params.Element(f.api) - res.Div(*els[0], *els[1]) - return &res -} - -func (f *fakeAPI) Inverse(i1 frontend.Variable) frontend.Variable { - el := f.varToElement(i1) - res := f.params.Element(f.api) - res.Inverse(*el) - return &res -} - -func (f *fakeAPI) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { - el := f.varToElement(i1) - res := f.params.Element(f.api) - res.Reduce(*el) - out := res.ToBits() - switch len(n) { - case 0: - case 1: - out = out[:n[0]] - default: - panic("only single vararg permitted to ToBinary") - } - return out -} - -func (f *fakeAPI) FromBinary(b ...frontend.Variable) frontend.Variable { - els := f.varsToElements(b) - bits := make([]frontend.Variable, len(els)) - for i := range els { - f.AssertIsBoolean(els[i]) - bits[i] = els[i].Limbs[0] - } - res := f.params.Element(f.api) - res.FromBits(bits) - return &res -} - -func (f *fakeAPI) Xor(a frontend.Variable, b frontend.Variable) frontend.Variable { - els := f.varsToElements(a, b) - f.AssertIsBoolean(els[0]) - f.AssertIsBoolean(els[1]) - rv := f.api.Xor(els[0].Limbs[0], els[1].Limbs[0]) - r := f.params.ConstantFromLimbs([]frontend.Variable{rv}) - r.api = f.api - r.EnforceWidth() - return r -} - -func (f *fakeAPI) Or(a frontend.Variable, b frontend.Variable) frontend.Variable { - els := f.varsToElements(a, b) - f.AssertIsBoolean(els[0]) - f.AssertIsBoolean(els[1]) - rv := f.api.Or(els[0].Limbs[0], els[1].Limbs[0]) - r := f.params.ConstantFromLimbs([]frontend.Variable{rv}) - r.api = f.api - r.EnforceWidth() - return r -} - -func (f *fakeAPI) And(a frontend.Variable, b frontend.Variable) frontend.Variable { - els := f.varsToElements(a, b) - f.AssertIsBoolean(els[0]) - f.AssertIsBoolean(els[1]) - rv := f.api.And(els[0].Limbs[0], els[1].Limbs[0]) - r := f.params.ConstantFromLimbs([]frontend.Variable{rv}) - r.api = f.api - r.EnforceWidth() - return r -} - -func (f *fakeAPI) Select(b frontend.Variable, i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { - els := f.varsToElements(i1, i2) - res := f.params.Element(f.api) - switch vv := b.(type) { - case Element: - f.AssertIsBoolean(vv) - b = vv.Limbs[0] - case *Element: - f.AssertIsBoolean(vv) - b = vv.Limbs[0] - } - if els[0].overflow == els[1].overflow && len(els[0].Limbs) == len(els[1].Limbs) { - res.Select(b, *els[0], *els[1]) - return &res - } - s0 := els[0] - s1 := els[1] - if s0.overflow != 0 || len(s0.Limbs) != int(f.params.nbLimbs) { - v := f.params.Element(f.api) - v.Reduce(*s0) - s0 = &v - } - if s1.overflow != 0 || len(s1.Limbs) != int(f.params.nbLimbs) { - v := f.params.Element(f.api) - v.Reduce(*s1) - s1 = &v - } - res.Select(b, *s0, *s1) - return &res -} - -func (f *fakeAPI) Lookup2(b0 frontend.Variable, b1 frontend.Variable, i0 frontend.Variable, i1 frontend.Variable, i2 frontend.Variable, i3 frontend.Variable) frontend.Variable { - els := f.varsToElements(i0, i1, i2, i3) - res := f.params.Element(f.api) - switch vv := b0.(type) { - case Element: - f.AssertIsBoolean(vv) - b0 = vv.Limbs[0] - case *Element: - f.AssertIsBoolean(vv) - b0 = vv.Limbs[0] - } - switch vv := b1.(type) { - case Element: - f.AssertIsBoolean(vv) - b1 = vv.Limbs[0] - case *Element: - f.AssertIsBoolean(vv) - b1 = vv.Limbs[0] - } - if els[0].overflow == els[1].overflow && els[0].overflow == els[2].overflow && els[0].overflow == els[3].overflow && len(els[0].Limbs) == len(els[1].Limbs) && len(els[0].Limbs) == len(els[2].Limbs) && len(els[0].Limbs) == len(els[3].Limbs) { - res.Lookup2(b0, b1, *els[0], *els[1], *els[2], *els[3]) - return &res - } - s0 := els[0] - s1 := els[1] - s2 := els[2] - s3 := els[3] - if s0.overflow != 0 || len(s0.Limbs) != int(f.params.nbLimbs) { - v := f.params.Element(f.api) - v.Reduce(*s0) - s0 = &v - } - if s1.overflow != 0 || len(s1.Limbs) != int(f.params.nbLimbs) { - v := f.params.Element(f.api) - v.Reduce(*s1) - s1 = &v - } - if s2.overflow != 0 || len(s2.Limbs) != int(f.params.nbLimbs) { - v := f.params.Element(f.api) - v.Reduce(*s2) - s2 = &v - } - if s3.overflow != 0 || len(s3.Limbs) != int(f.params.nbLimbs) { - v := f.params.Element(f.api) - v.Reduce(*s3) - s3 = &v - } - res.Lookup2(b0, b1, *s0, *s1, *s2, *s3) - return &res -} - -func (f *fakeAPI) IsZero(i1 frontend.Variable) frontend.Variable { - el := f.varToElement(i1) - reduced := f.params.Element(f.api) - reduced.Reduce(*el) - res := f.api.IsZero(reduced.Limbs[0]) - for i := 1; i < len(reduced.Limbs); i++ { - f.api.Mul(res, f.api.IsZero(reduced.Limbs[i])) - } - r := f.params.ConstantFromLimbs([]frontend.Variable{res}) - r.api = f.api - r.EnforceWidth() - return r -} - -func (f *fakeAPI) Cmp(i1 frontend.Variable, i2 frontend.Variable) frontend.Variable { - els := f.varsToElements(i1, i2) - rls := []Element{f.params.Element(f.api), f.params.Element(f.api)} - rls[0].Reduce(*els[0]) - rls[1].Reduce(*els[1]) - var res frontend.Variable = 0 - for i := int(f.params.nbLimbs - 1); i >= 0; i-- { - lmbCmp := f.api.Cmp(rls[0].Limbs[i], rls[1].Limbs[i]) - res = f.api.Select(f.api.IsZero(res), lmbCmp, res) - } - return res -} - -func (f *fakeAPI) AssertIsEqual(i1 frontend.Variable, i2 frontend.Variable) { - els := f.varsToElements(i1, i2) - tmp := f.params.Element(f.api) - tmp.Set(*els[0]) - tmp.reduceAndOp(func(a, b Element, nextOverflow uint) { a.AssertIsEqual(b) }, func(e1, e2 Element) (uint, error) { - nextOverflow, err := tmp.subPreCond(e2, e1) - var target errOverflow - if err != nil && errors.As(err, &target) { - target.reduceRight = !target.reduceRight - return nextOverflow, target - } - return nextOverflow, err - }, &tmp, els[1]) -} - -func (f *fakeAPI) AssertIsDifferent(i1 frontend.Variable, i2 frontend.Variable) { - els := f.varsToElements(i1, i2) - rls := []Element{f.params.Element(f.api), f.params.Element(f.api)} - rls[0].Reduce(*els[0]) - rls[1].Reduce(*els[1]) - var res frontend.Variable = 0 - for i := 0; i < int(f.params.nbLimbs); i++ { - cmp := f.api.Cmp(rls[0].Limbs[i], rls[1].Limbs[i]) - cmpsq := f.api.Mul(cmp, cmp) - res = f.api.Add(res, cmpsq) - } - f.api.AssertIsDifferent(res, 0) -} - -func (f *fakeAPI) AssertIsBoolean(i1 frontend.Variable) { - switch vv := i1.(type) { - case Element: - v := f.params.Element(f.api) - v.Reduce(vv) - f.api.AssertIsBoolean(v.Limbs[0]) - for i := 1; i < len(v.Limbs); i++ { - f.api.AssertIsEqual(v.Limbs[i], 0) - } - case *Element: - v := f.params.Element(f.api) - v.Reduce(*vv) - f.api.AssertIsBoolean(v.Limbs[0]) - for i := 1; i < len(v.Limbs); i++ { - f.api.AssertIsEqual(v.Limbs[i], 0) - } - default: - f.api.AssertIsBoolean(vv) - } -} - -func (f *fakeAPI) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { - els := f.varsToElements(v, bound) - l := f.params.Element(f.api) - l.Reduce(*els[0]) - r := f.params.Element(f.api) - r.Reduce(*els[1]) - l.AssertIsLessEqualThan(r) -} - -func (f *fakeAPI) Println(a ...frontend.Variable) { - els := f.varsToElements(a) - for i := range els { - f.api.Println(els[i].Limbs...) - } -} - -func (f *fakeAPI) Compiler() frontend.Compiler { - return f -} - -func (f *fakeAPI) NewHint(hf hint.Function, nbOutputs int, inputs ...frontend.Variable) ([]frontend.Variable, error) { - // this is a trick to allow calling hint functions using non-native - // elements. We use the fact that the hints take as inputs *big.Int values. - // Instead of supplying hf to the solver for calling, we wrap it with - // another function (implementing hint.Function), which takes as inputs the - // "expanded" version of inputs (where instead of Element values we provide - // as inputs the limbs of every Element) and returns nbLimbs*nbOutputs - // number of outputs (i.e. the limbs of non-native Element values). The - // wrapper then recomposes and decomposes the *big.Int values at runtime and - // provides them as input to the initially provided hint function. - var expandedInputs []frontend.Variable - type typedInput struct { - pos int - nbLimbs int - isElement bool - } - typedInputs := make([]typedInput, len(inputs)) - for i := range inputs { - switch vv := inputs[i].(type) { - case Element: - expandedInputs = append(expandedInputs, vv.Limbs...) - typedInputs[i] = typedInput{ - pos: len(expandedInputs) - len(vv.Limbs), - nbLimbs: len(vv.Limbs), - isElement: true, - } - case *Element: - expandedInputs = append(expandedInputs, vv.Limbs...) - typedInputs[i] = typedInput{ - pos: len(expandedInputs) - len(vv.Limbs), - nbLimbs: len(vv.Limbs), - isElement: true, - } - default: - expandedInputs = append(expandedInputs, inputs[i]) - typedInputs[i] = typedInput{ - pos: len(expandedInputs) - 1, - nbLimbs: 1, - isElement: false, - } - } - } - nbNativeOutputs := nbOutputs * int(f.params.nbLimbs) - wrappedHint := func(_ *big.Int, expandedHintInputs []*big.Int, expandedHintOutputs []*big.Int) error { - hintInputs := make([]*big.Int, len(inputs)) - hintOutputs := make([]*big.Int, nbOutputs) - for i, ti := range typedInputs { - hintInputs[i] = new(big.Int) - if ti.isElement { - if err := recompose(expandedHintInputs[ti.pos:ti.pos+ti.nbLimbs], f.params.nbBits, hintInputs[i]); err != nil { - return fmt.Errorf("recompose: %w", err) - } - } else { - hintInputs[i].Set(expandedHintInputs[ti.pos]) - } - } - for i := range hintOutputs { - hintOutputs[i] = new(big.Int) - } - if err := hf(f.params.r, hintInputs, hintOutputs); err != nil { - return fmt.Errorf("call hint: %w", err) - } - for i := range hintOutputs { - if err := decompose(hintOutputs[i], f.params.nbBits, expandedHintOutputs[i*int(f.params.nbLimbs):(i+1)*int(f.params.nbLimbs)]); err != nil { - return fmt.Errorf("decompose: %w", err) - } - } - return nil - } - hintRet, err := f.api.Compiler().NewHint(wrappedHint, nbNativeOutputs, expandedInputs...) - if err != nil { - return nil, fmt.Errorf("NewHint: %w", err) - } - ret := make([]frontend.Variable, nbOutputs) - for i := 0; i < nbOutputs; i++ { - el := f.params.Element(f.api) - el.Limbs = hintRet[i*int(f.params.nbLimbs) : (i+1)*int(f.params.nbLimbs)] - ret[i] = &el - } - return ret, nil -} - -func (f *fakeAPI) Tag(name string) frontend.Tag { - return f.api.Compiler().Tag(name) -} - -func (f *fakeAPI) AddCounter(from frontend.Tag, to frontend.Tag) { - f.api.Compiler().AddCounter(from, to) -} - -func (f *fakeAPI) ConstantValue(v frontend.Variable) (*big.Int, bool) { - var constLimbs []*big.Int - var nbBits uint - var succ bool - switch vv := v.(type) { - case Element: - nbBits = vv.params.nbBits - constLimbs = make([]*big.Int, len(vv.Limbs)) - for i := range vv.Limbs { - if constLimbs[i], succ = f.api.Compiler().ConstantValue(vv.Limbs[i]); !succ { - return nil, false - } - } - case *Element: - nbBits = vv.params.nbBits - constLimbs = make([]*big.Int, len(vv.Limbs)) - for i := range vv.Limbs { - if constLimbs[i], succ = f.api.Compiler().ConstantValue(vv.Limbs[i]); !succ { - return nil, false - } - } - default: - return f.api.Compiler().ConstantValue(vv) - } - res := new(big.Int) - if err := recompose(constLimbs, nbBits, res); err != nil { - return nil, false - } - return res, true -} - -func (f *fakeAPI) Field() *big.Int { - return f.params.r -} - -func (f *fakeAPI) FieldBitLen() int { - return f.params.r.BitLen() -} - -func (f *fakeAPI) IsBoolean(v frontend.Variable) bool { - switch vv := v.(type) { - case Element: - return f.api.Compiler().IsBoolean(vv.Limbs[0]) - case *Element: - return f.api.Compiler().IsBoolean(vv.Limbs[0]) - default: - return f.api.Compiler().IsBoolean(vv) - } -} - -func (f *fakeAPI) MarkBoolean(v frontend.Variable) { - switch vv := v.(type) { - case Element: - f.api.Compiler().MarkBoolean(vv.Limbs[0]) - case *Element: - f.api.Compiler().MarkBoolean(vv.Limbs[0]) - default: - f.api.Compiler().MarkBoolean(vv) - } -} diff --git a/std/math/nonnative/api_test.go b/std/math/nonnative/api_test.go deleted file mode 100644 index 425cad99bd..0000000000 --- a/std/math/nonnative/api_test.go +++ /dev/null @@ -1,289 +0,0 @@ -package nonnative - -import ( - "crypto/rand" - "fmt" - "math/big" - "sort" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/internal/backend/circuits" - "github.com/consensys/gnark/std/algebra/fields_bls12377" - "github.com/consensys/gnark/std/algebra/sw_bls12377" - "github.com/consensys/gnark/test" -) - -func witnessData(mod *big.Int) (X1, X2, X3, X4, X5, X6, Res *big.Int) { - val1, _ := rand.Int(rand.Reader, mod) - val2, _ := rand.Int(rand.Reader, mod) - val3, _ := rand.Int(rand.Reader, mod) - val4, _ := rand.Int(rand.Reader, mod) - val5, _ := rand.Int(rand.Reader, mod) - val6, _ := rand.Int(rand.Reader, mod) - - tmp := new(big.Int) - res := new(big.Int) - // res = x1^3 - tmp.Exp(val1, big.NewInt(3), mod) - res.Set(tmp) - // res = x1^3 + 5*x2 - tmp.Mul(val2, big.NewInt(5)) - res.Add(res, tmp) - // tmp = (x3-x4) - tmp.Sub(val3, val4) - tmp.Mod(tmp, mod) - // tmp2 = (x5+x6) - tmp2 := new(big.Int) - tmp2.Add(val5, val6) - // tmp = (x3-x4)/(x5+x6) - tmp2.ModInverse(tmp2, mod) - tmp.Mul(tmp, tmp2) - tmp.Mod(tmp, mod) - // res = x1^3 + 5*x2 + (x3-x4)/(x5+x6) - res.Add(res, tmp) - res.Mod(res, mod) - return val1, val2, val3, val4, val5, val6, res -} - -type EmulatedApiCircuit struct { - Params *Params - - X1, X2, X3, X4, X5, X6 Element - Res Element -} - -func (c *EmulatedApiCircuit) Define(api frontend.API) error { - if c.Params != nil { - api = NewAPI(api, c.Params) - } - // compute x1^3 + 5*x2 + (x3-x4) / (x5+x6) - x13 := api.Mul(c.X1, c.X1, c.X1) - fx2 := api.Mul(5, c.X2) - nom := api.Sub(c.X3, c.X4) - denom := api.Add(c.X5, c.X6) - free := api.Div(nom, denom) - res := api.Add(x13, fx2, free) - api.AssertIsEqual(res, c.Res) - return nil -} - -func TestEmulatedApi(t *testing.T) { - assert := test.NewAssert(t) - - r := ecc.BN254.ScalarField() - params, err := NewParams(32, r) - assert.NoError(err) - - circuit := EmulatedApiCircuit{ - Params: params, - X1: params.Placeholder(), - X2: params.Placeholder(), - X3: params.Placeholder(), - X4: params.Placeholder(), - X5: params.Placeholder(), - X6: params.Placeholder(), - Res: params.Placeholder(), - } - - x1, x2, x3, x4, x5, x6, res := witnessData(params.r) - witness := EmulatedApiCircuit{ - Params: params, - X1: params.ConstantFromBigOrPanic(x1), - X2: params.ConstantFromBigOrPanic(x2), - X3: params.ConstantFromBigOrPanic(x3), - X4: params.ConstantFromBigOrPanic(x4), - X5: params.ConstantFromBigOrPanic(x5), - X6: params.ConstantFromBigOrPanic(x6), - Res: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) -} - -type WrapperCircuit struct { - X1, X2, X3, X4, X5, X6 frontend.Variable - Res frontend.Variable -} - -func (c *WrapperCircuit) Define(api frontend.API) error { - // compute x1^3 + 5*x2 + (x3-x4) / (x5+x6) - x13 := api.Mul(c.X1, c.X1, c.X1) - fx2 := api.Mul(5, c.X2) - nom := api.Sub(c.X3, c.X4) - denom := api.Add(c.X5, c.X6) - free := api.Div(nom, denom) - res := api.Add(x13, fx2, free) - api.AssertIsEqual(res, c.Res) - return nil -} - -func TestTestEngineWrapper(t *testing.T) { - assert := test.NewAssert(t) - r := ecc.BN254.ScalarField() - params, err := NewParams(32, r) - assert.NoError(err) - - circuit := WrapperCircuit{ - X1: params.Placeholder(), - X2: params.Placeholder(), - X3: params.Placeholder(), - X4: params.Placeholder(), - X5: params.Placeholder(), - X6: params.Placeholder(), - Res: params.Placeholder(), - } - x1, x2, x3, x4, x5, x6, res := witnessData(params.r) - witness := WrapperCircuit{ - X1: params.ConstantFromBigOrPanic(x1), - X2: params.ConstantFromBigOrPanic(x2), - X3: params.ConstantFromBigOrPanic(x3), - X4: params.ConstantFromBigOrPanic(x4), - X5: params.ConstantFromBigOrPanic(x5), - X6: params.ConstantFromBigOrPanic(x6), - Res: params.ConstantFromBigOrPanic(res), - } - wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { - return NewAPI(api, params) - }) - err = test.IsSolved(&circuit, &witness, testCurve.ScalarField(), wrapperOpt) - assert.NoError(err) -} - -func TestCompilerWrapper(t *testing.T) { - assert := test.NewAssert(t) - r := ecc.BN254.ScalarField() - params, err := NewParams(32, r) - assert.NoError(err) - - circuit := WrapperCircuit{} - x1, x2, x3, x4, x5, x6, res := witnessData(params.r) - witness := WrapperCircuit{ - X1: params.ConstantFromBigOrPanic(x1), - X2: params.ConstantFromBigOrPanic(x2), - X3: params.ConstantFromBigOrPanic(x3), - X4: params.ConstantFromBigOrPanic(x4), - X5: params.ConstantFromBigOrPanic(x5), - X6: params.ConstantFromBigOrPanic(x6), - Res: params.ConstantFromBigOrPanic(res), - } - ccs, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithBuilderWrapper(BuilderWrapper(params))) - assert.NoError(err) - t.Log(ccs.GetNbConstraints()) - // TODO: create proof - _ = witness -} - -func TestIntegrationApi(t *testing.T) { - assert := test.NewAssert(t) - r := ecc.BN254.ScalarField() - params, err := NewParams(32, r) - assert.NoError(err) - wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { - return NewAPI(api, params) - }) - keys := make([]string, 0, len(circuits.Circuits)) - for k := range circuits.Circuits { - keys = append(keys, k) - } - sort.Strings(keys) - - for i := range keys { - name := keys[i] - tData := circuits.Circuits[name] - assert.Run(func(assert *test.Assert) { - _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, tData.Circuit, frontend.WithBuilderWrapper(BuilderWrapper(params))) - assert.NoError(err) - }, name, "compile") - for i := range tData.ValidAssignments { - assignment := tData.ValidAssignments[i] - assert.Run(func(assert *test.Assert) { - err = test.IsSolved(tData.Circuit, assignment, testCurve.ScalarField(), wrapperOpt) - assert.NoError(err) - }, name, fmt.Sprintf("valid=%d", i)) - } - for i := range tData.InvalidAssignments { - assignment := tData.InvalidAssignments[i] - assert.Run(func(assert *test.Assert) { - err = test.IsSolved(tData.Circuit, assignment, testCurve.ScalarField(), wrapperOpt) - assert.Error(err) - }, name, fmt.Sprintf("invalid=%d", i)) - } - } -} - -type pairingBLS377 struct { - P sw_bls12377.G1Affine `gnark:",public"` - Q sw_bls12377.G2Affine - pairingRes bls12377.GT -} - -//lint:ignore U1000 skipped test -func (circuit *pairingBLS377) Define(api frontend.API) error { - pairingRes, _ := sw_bls12377.Pair(api, - []sw_bls12377.G1Affine{circuit.P}, - []sw_bls12377.G2Affine{circuit.Q}) - api.AssertIsEqual(pairingRes.C0.B0.A0, &circuit.pairingRes.C0.B0.A0) - api.AssertIsEqual(pairingRes.C0.B0.A1, &circuit.pairingRes.C0.B0.A1) - api.AssertIsEqual(pairingRes.C0.B1.A0, &circuit.pairingRes.C0.B1.A0) - api.AssertIsEqual(pairingRes.C0.B1.A1, &circuit.pairingRes.C0.B1.A1) - api.AssertIsEqual(pairingRes.C0.B2.A0, &circuit.pairingRes.C0.B2.A0) - api.AssertIsEqual(pairingRes.C0.B2.A1, &circuit.pairingRes.C0.B2.A1) - api.AssertIsEqual(pairingRes.C1.B0.A0, &circuit.pairingRes.C1.B0.A0) - api.AssertIsEqual(pairingRes.C1.B0.A1, &circuit.pairingRes.C1.B0.A1) - api.AssertIsEqual(pairingRes.C1.B1.A0, &circuit.pairingRes.C1.B1.A0) - api.AssertIsEqual(pairingRes.C1.B1.A1, &circuit.pairingRes.C1.B1.A1) - api.AssertIsEqual(pairingRes.C1.B2.A0, &circuit.pairingRes.C1.B2.A0) - api.AssertIsEqual(pairingRes.C1.B2.A1, &circuit.pairingRes.C1.B2.A1) - return nil -} - -func TestPairingBLS377(t *testing.T) { - t.Skip() - assert := test.NewAssert(t) - params, err := NewParams(32, ecc.BW6_761.ScalarField()) - assert.NoError(err) - - _, _, P, Q := bls12377.Generators() - milRes, _ := bls12377.MillerLoop([]bls12377.G1Affine{P}, []bls12377.G2Affine{Q}) - pairingRes := bls12377.FinalExponentiation(&milRes) - - circuit := pairingBLS377{} - - pxb := new(big.Int) - pyb := new(big.Int) - qxab := new(big.Int) - qxbb := new(big.Int) - qyab := new(big.Int) - qybb := new(big.Int) - witness := pairingBLS377{ - pairingRes: pairingRes, - P: sw_bls12377.G1Affine{ - X: params.ConstantFromBigOrPanic(P.X.ToBigIntRegular(pxb)), - Y: params.ConstantFromBigOrPanic(P.Y.ToBigIntRegular(pyb)), - }, - Q: sw_bls12377.G2Affine{ - X: fields_bls12377.E2{ - A0: params.ConstantFromBigOrPanic(Q.X.A0.ToBigIntRegular(qxab)), - A1: params.ConstantFromBigOrPanic(Q.X.A1.ToBigIntRegular(qxbb)), - }, - Y: fields_bls12377.E2{ - A0: params.ConstantFromBigOrPanic(Q.Y.A0.ToBigIntRegular(qyab)), - A1: params.ConstantFromBigOrPanic(Q.Y.A1.ToBigIntRegular(qybb)), - }, - }, - } - - wrapperOpt := test.WithApiWrapper(func(api frontend.API) frontend.API { - return NewAPI(api, params) - }) - err = test.IsSolved(&circuit, &witness, testCurve.ScalarField(), wrapperOpt) - assert.NoError(err) - _, err = frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit, frontend.WithBuilderWrapper(BuilderWrapper(params))) - assert.NoError(err) - // TODO: create proof -} diff --git a/std/math/nonnative/composition_test.go b/std/math/nonnative/composition_test.go deleted file mode 100644 index c8321d162f..0000000000 --- a/std/math/nonnative/composition_test.go +++ /dev/null @@ -1,55 +0,0 @@ -package nonnative - -import ( - "crypto/rand" - "fmt" - "math/big" - "testing" - - "github.com/consensys/gnark/test" -) - -func TestComposition(t *testing.T) { - assert := test.NewAssert(t) - for _, fp := range emulatedFields(t) { - params := fp.params - assert.Run(func(assert *test.Assert) { - n, err := rand.Int(rand.Reader, params.r) - if err != nil { - assert.FailNow("rand int", err) - } - res := make([]*big.Int, params.nbLimbs) - for i := range res { - res[i] = new(big.Int) - } - if err = decompose(n, params.nbBits, res); err != nil { - assert.FailNow("decompose", err) - } - n2 := new(big.Int) - if err = recompose(res, params.nbBits, n2); err != nil { - assert.FailNow("recompose", err) - } - if n2.Cmp(n) != 0 { - assert.FailNow("unequal") - } - }, testName(fp)) - } -} - -func TestSubPadding(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - for i := params.nbLimbs; i < 2*params.nbLimbs; i++ { - assert.Run(func(assert *test.Assert) { - limbs := subPadding(params, 0, i) - padValue := new(big.Int) - if err := recompose(limbs, params.nbBits, padValue); err != nil { - assert.FailNow("recompose", err) - } - padValue.Mod(padValue, params.r) - assert.Zero(padValue.Cmp(big.NewInt(0)), "padding not multiple of order") - }, fmt.Sprintf("%s/nbLimbs=%d", testName(fp), i)) - } - } -} diff --git a/std/math/nonnative/variable.go b/std/math/nonnative/variable.go deleted file mode 100644 index fcd403946a..0000000000 --- a/std/math/nonnative/variable.go +++ /dev/null @@ -1,678 +0,0 @@ -package nonnative - -// TODO: add checks which ensure that constants are not used as receivers -// TODO: add sanity checks before the operations (e.g. that overflow is -// sufficient and do not need to reduce) -// TODO: think about different "operation modes". Probably hand-optimized code -// is better than reducing eagerly, but the user should be at least aware during -// compile-time that values need to be reduced. But there should be an easy-mode -// where the user does not need to manually reduce and the library does it as -// necessary. -// TODO: check that the parameters coincide for elements. -// TODO: less equal than -// TODO: simple exponentiation before we implement Wesolowsky - -import ( - "errors" - "fmt" - "math" - "math/big" - "sync" - - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/std/math/bits" -) - -type errOverflow struct { - op string - nextOverflow uint - maxOverflow uint - reduceRight bool -} - -func (e errOverflow) Error() string { - return fmt.Sprintf("op %s overflow %d exceeds max %d", e.op, e.nextOverflow, e.maxOverflow) -} - -// Params defines the parameters of the emulated ring of integers modulo n. If -// n is prime, then the ring is also a finite field where inverse and division -// are allowed. -type Params struct { - // r is the modulus - r *big.Int - // hasInverses indicates if order is prime - hasInverses bool - // nbLimbs is the number of limbs which fit reduced element - nbLimbs uint - // nbBits is number of bits per limb. Top limb may contain less than - // nbBits bits. - nbBits uint - // maxOf is the maximum overflow before the element must be reduced. - maxOf uint - maxOfOnce sync.Once - - // constants for often used elements n, 0 and 1. Allocated only once - nConstOnce sync.Once - nConst *Element - zeroConstOnce sync.Once - zeroConst *Element - oneConstOnce sync.Once - oneConst *Element -} - -// NewParams initializes the parameters for emulating operations modulo n where -// every limb of the element contains up to nbBits bits. Returns error if sanity -// checks fail. -// -// This method checks the primality of n to detect if parameters define a finite -// field. As such, invocation of this method is expensive and should be done -// once. -func NewParams(nbBits int, r *big.Int) (*Params, error) { - if r.Cmp(big.NewInt(1)) < 1 { - return nil, fmt.Errorf("n must be at least 2") - } - if nbBits < 3 { - // even three is way too small, but it should probably work. - return nil, fmt.Errorf("nbBits must be at least 3") - } - nbLimbs := (r.BitLen() + nbBits - 1) / nbBits - fp := &Params{ - r: r, - nbLimbs: uint(nbLimbs), - nbBits: uint(nbBits), - hasInverses: r.ProbablyPrime(20), - } - return fp, nil -} - -// Element defines an element in the ring of integers modulo n. The integer -// value of the element is split into limbs of nbBits lengths and represented as -// a slice of limbs. -type Element struct { - Limbs []frontend.Variable `gnark:"limbs"` // in little-endian (least significant limb first) encoding - - // params carries the ring parameters - params *Params `gnark:"-"` - // overflow indicates the number of additions on top of the normal form. To - // ensure that none of the limbs overflow the scalar field of the snark - // curve, we must check that nbBits+overflow < floor(log2(fr modulus)) - overflow uint `gnark:"-"` - // api references the API for variable elements - api frontend.API `gnark:"-"` -} - -// Element returns initialized element in the field. The value of this element -// is not constrained and it only safe to use as a receiver in operations. For -// elements initialized to values use Zero(), One() or Modulus(). -func (fp *Params) Element(api frontend.API) Element { - if uint(api.Compiler().FieldBitLen()) < 2*fp.nbBits+1 { - panic(fmt.Sprintf("elements with limb length %d does not fit into scalar field", fp.nbBits)) - } - e := Element{ - Limbs: make([]frontend.Variable, fp.nbLimbs), - params: fp, - overflow: 0, - api: api, - } - return e -} - -// Modulus returns the modulus of the emulated ring as a constant. The returned -// element is not safe to use as an operation receiver. -func (fp *Params) Modulus() Element { - fp.nConstOnce.Do(func() { - element, err := fp.ConstantFromBig(fp.r) - if err != nil { - // should not err for fp.order - panic(fmt.Sprintf("witness from order: %v", err)) - } - fp.nConst = &element - }) - return *fp.nConst -} - -// Zero returns zero as a constant. The returned element is not safe to use as -// an operation receiver. -func (fp *Params) Zero() Element { - fp.zeroConstOnce.Do(func() { - element, err := fp.ConstantFromBig(big.NewInt(0)) - if err != nil { - panic(fmt.Sprintf("witness from zero: %v", err)) - } - fp.zeroConst = &element - }) - return *fp.zeroConst -} - -// One returns one as a constant. The returned element is not safe to use as an -// operation receiver. -func (fp *Params) One() Element { - fp.oneConstOnce.Do(func() { - element, err := fp.ConstantFromBig(big.NewInt(1)) - if err != nil { - panic(fmt.Sprintf("witness from one: %v", err)) - } - fp.oneConst = &element - }) - return *fp.oneConst -} - -// ConstantFromBig returns a constant element from the value. The returned -// element is not safe to use as an operation receiver. -func (fp *Params) ConstantFromBig(value *big.Int) (Element, error) { - constValue := new(big.Int).Set(value) - if fp.r.Cmp(value) != 0 { - constValue.Mod(constValue, fp.r) - } - limbs := make([]*big.Int, fp.nbLimbs) - for i := range limbs { - limbs[i] = new(big.Int) - } - if err := decompose(constValue, fp.nbBits, limbs); err != nil { - return Element{}, fmt.Errorf("decompose value: %w", err) - } - limbVars := make([]frontend.Variable, len(limbs)) - for i := range limbs { - limbVars[i] = frontend.Variable(limbs[i]) - } - e := Element{ - Limbs: limbVars, - params: fp, - overflow: 0, - api: nil, - } - return e, nil -} - -// ConstantFromBigOrPanic returns a constant from value or panics if value does -// not define a valid element in the ring. -func (fp *Params) ConstantFromBigOrPanic(value *big.Int) Element { - el, err := fp.ConstantFromBig(value) - if err != nil { - panic(err) - } - return el -} - -// ConstantFromLimbs returns a constant element from the given limbs. The -// returned element is not safe to use as an operation receiver. -func (fp *Params) ConstantFromLimbs(limbs []frontend.Variable) Element { - // TODO: check that every limb does not overflow the expected width - return Element{ - Limbs: limbs, - params: fp, - overflow: 0, - api: nil, - } -} - -// Placeholder returns a constant which is safe to use as a placeholder when -// compiling a circuit. -func (fp *Params) Placeholder() Element { - e, err := fp.ConstantFromBig(big.NewInt(0)) - if err != nil { - panic(err) - } - return e -} - -// From returns an element by regrouping the limbs to these parameters. -func (fp *Params) From(api frontend.API, a Element) Element { - return Element{ - api: api, - params: fp, - overflow: a.overflow, - Limbs: regroupLimbs(api, a.params, fp, a.Limbs), - } -} - -// isEqual returns if fp is equivalent to other. -func (fp *Params) isEqual(other *Params) bool { - return fp.r.Cmp(other.r) == 0 && fp.nbBits == other.nbBits -} - -// ToBits returns the bit representation of the Element in little-endian (LSB -// first) order. The returned bits are constrained to be 0-1. The number of -// returned bits is nbLimbs*nbBits+overflow. To obtain the bits of the canonical -// representation of Element, reduce Element first and take less significant -// bits corresponding to the bitwidth of the emulated modulus. -func (e *Element) ToBits() []frontend.Variable { - var carry frontend.Variable = 0 - var fullBits []frontend.Variable - var limbBits []frontend.Variable - for i := 0; i < len(e.Limbs); i++ { - limbBits = bits.ToBinary(e.api, e.api.Add(e.Limbs[i], carry), bits.WithNbDigits(int(e.params.nbBits+e.overflow))) - fullBits = append(fullBits, limbBits[:e.params.nbBits]...) - if e.overflow > 0 { - carry = bits.FromBinary(e.api, limbBits[e.params.nbBits:]) - } - } - fullBits = append(fullBits, limbBits[e.params.nbBits:e.params.nbBits+e.overflow]...) - return fullBits -} - -// FromBits sets the value of e from the given boolean variables in. The method -// assumes that the bits are given from the canonical representation of element -// (less than modulus). -func (e *Element) FromBits(in []frontend.Variable) { - nbLimbs := (uint(len(in)) + e.params.nbBits - 1) / e.params.nbBits - limbs := make([]frontend.Variable, nbLimbs) - for i := uint(0); i < nbLimbs-1; i++ { - limbs[i] = bits.FromBinary(e.api, in[i*e.params.nbBits:(i+1)*e.params.nbBits]) - } - limbs[nbLimbs-1] = bits.FromBinary(e.api, in[(nbLimbs-1)*e.params.nbBits:]) - e.overflow = 0 - e.Limbs = limbs -} - -// maxOverflow returns the maximal possible overflow for the element. If the -// overflow of the next operation exceeds the value returned by this method, -// then the limbs may overflow the native field. -func (e Element) maxOverflow() uint { - e.params.maxOfOnce.Do(func() { - e.params.maxOf = uint(e.api.Compiler().FieldBitLen()-1) - e.params.nbBits - }) - return e.params.maxOf -} - -// assertLimbsEqualitySlow is the main routine in the package. It asserts that the -// two slices of limbs represent the same integer value. This is also the most -// costly operation in the package as it does bit decomposition of the limbs. -func assertLimbsEqualitySlow(api frontend.API, l, r []frontend.Variable, nbBits, nbCarryBits uint) { - nbLimbs := len(l) - if len(r) > nbLimbs { - nbLimbs = len(r) - } - maxValue := new(big.Int).Lsh(big.NewInt(1), nbBits+nbCarryBits) - maxValueShift := new(big.Int).Lsh(big.NewInt(1), nbCarryBits) - - var carry frontend.Variable = 0 - for i := 0; i < nbLimbs; i++ { - diff := api.Add(maxValue, carry) - if i < len(l) { - diff = api.Add(diff, l[i]) - } - if i < len(r) { - diff = api.Sub(diff, r[i]) - } - if i > 0 { - diff = api.Sub(diff, maxValueShift) - } - // TODO: more efficient methods for splitting a variable? Because we are - // splitting the value into two, then maybe we do not need the whole - // binary decomposition \sum_{i=0}^n a_i 2^i, but can use a * 2^nbits + - // b. Then we can also omit the FromBinary call. - diffBits := bits.ToBinary(api, diff, bits.WithNbDigits(int(nbBits+nbCarryBits+1)), bits.WithUnconstrainedOutputs()) - for j := uint(0); j < nbBits; j++ { - api.AssertIsEqual(diffBits[j], 0) - } - carry = bits.FromBinary(api, diffBits[nbBits:nbBits+nbCarryBits+1]) - } - api.AssertIsEqual(carry, maxValueShift) -} - -// AssertLimbsEquality asserts that the limbs represent a same integer value (up -// to overflow). This method does not ensure that the values are equal modulo -// the field order. For strict equality, use AssertIsEqual. -func (e *Element) AssertLimbsEquality(a Element) { - maxOverflow := e.overflow - if a.overflow > e.overflow { - maxOverflow = a.overflow - } - rgpar := regroupParams(e.params, uint(e.api.Compiler().FieldBitLen()), maxOverflow) - rge := rgpar.From(e.api, *e) - rga := rgpar.From(e.api, a) - // slow path -- the overflows are different. Need to compare with carries. - // TODO: we previously assumed that one side was "larger" than the other - // side, but I think this assumption is not valid anymore - if e.overflow > a.overflow { - assertLimbsEqualitySlow(rge.api, rge.Limbs, rga.Limbs, rge.params.nbBits, rge.overflow) - } else { - assertLimbsEqualitySlow(rge.api, rga.Limbs, rge.Limbs, rga.params.nbBits, rga.overflow) - } -} - -// EnforceWidth enforces that the bitlength of the value is exactly the -// bitlength of the modulus. Any newly initialized variable should be -// constrained to ensure correct operations. -func (e *Element) EnforceWidth() { - for i := range e.Limbs { - limbNbBits := int(e.params.nbBits) - if i == len(e.Limbs)-1 { - // take only required bits from the most significant limb - limbNbBits = ((e.params.r.BitLen() - 1) % int(e.params.nbBits)) + 1 - } - // bits.ToBinary restricts the least significant NbDigits to be equal to - // the limb value. This is sufficient to restrict for the bitlength and - // we can discard the bits themselves. - bits.ToBinary(e.api, e.Limbs[i], bits.WithNbDigits(limbNbBits)) - } -} - -// Add sets e to a+b and returns e. The returned element may not be reduced to -// be less than the ring modulus. -func (e *Element) Add(a, b Element) *Element { - // variable case only - // TODO: figure out case when one element is a constant. If one addend is a - // constant, then we do not reduce it (but this is always case as the - // constant's overflow never increases?) - // TODO: check that the target is a variable (has an API) - // TODO: if both are constants, then add big ints - overflow, err := e.addPreCond(a, b) - if err != nil { - panic(err) - } - e.add(a, b, overflow) - return e -} - -func (e Element) addPreCond(a, b Element) (nextOverflow uint, err error) { - nextOverflow = 1 - reduceRight := a.overflow < b.overflow - if a.overflow > b.overflow { - nextOverflow += a.overflow - } else { - nextOverflow += b.overflow - } - if nextOverflow > e.maxOverflow() { - err = errOverflow{op: "add", nextOverflow: nextOverflow, maxOverflow: e.maxOverflow(), reduceRight: reduceRight} - } - return -} - -func (e *Element) add(a, b Element, nextOverflow uint) { - nbLimbs := len(a.Limbs) - if len(b.Limbs) > nbLimbs { - nbLimbs = len(b.Limbs) - } - limbs := make([]frontend.Variable, nbLimbs) - for i := range limbs { - limbs[i] = 0 - if i < len(a.Limbs) { - limbs[i] = e.api.Add(limbs[i], a.Limbs[i]) - } - if i < len(b.Limbs) { - limbs[i] = e.api.Add(limbs[i], b.Limbs[i]) - } - } - e.Limbs = limbs - e.overflow = nextOverflow -} - -// Mul sets e to a*b and returns e. The returned element may not be reduced to -// be less than the ring modulus. -func (e *Element) Mul(a, b Element) *Element { - // XXX: currently variable case only - // TODO: when one element is constant. - // TODO: check that target is initialized (has an API) - // TODO: if both are constants, then do big int mul - overflow, err := e.mulPreCond(a, b) - if err != nil { - panic(err) - } - e.mul(a, b, overflow) - return e -} - -func (e Element) mulPreCond(a, b Element) (nextOverflow uint, err error) { - reduceRight := a.overflow < b.overflow - nbResLimbs := nbMultiplicationResLimbs(len(a.Limbs), len(b.Limbs)) - nextOverflow = e.params.nbBits + uint(math.Log2(float64(2*nbResLimbs-1))) + 1 + a.overflow + b.overflow - if nextOverflow > e.maxOverflow() { - err = errOverflow{op: "mul", nextOverflow: nextOverflow, maxOverflow: e.maxOverflow(), reduceRight: reduceRight} - } - return -} - -func (e *Element) mul(a, b Element, nextOverflow uint) { - limbs, err := computeMultiplicationHint(e.api, e.params, a.Limbs, b.Limbs) - if err != nil { - panic(fmt.Sprintf("multiplication hint: %s", err)) - } - // create constraints (\sum_{i=0}^{m-1} a_i c^i) * (\sum_{i=0}^{m-1} b_i - // c^i) = (\sum_{i=0}^{2m-2} z_i c^i) for c \in {1, 2m-1} - for c := 1; c <= len(limbs); c++ { - cb := big.NewInt(int64(c)) // c - bit := big.NewInt(1) // c^i - l := e.api.Mul(a.Limbs[0], bit) - for i := 1; i < len(a.Limbs); i++ { - bit.Mul(bit, cb) - l = e.api.Add(l, e.api.Mul(a.Limbs[i], bit)) - } - bit.SetInt64(1) - r := e.api.Mul(b.Limbs[0], bit) - for i := 1; i < len(b.Limbs); i++ { - bit.Mul(bit, cb) - r = e.api.Add(r, e.api.Mul(b.Limbs[i], bit)) - } - bit.SetInt64(1) - o := e.api.Mul(limbs[0], bit) - for i := 1; i < len(limbs); i++ { - bit.Mul(bit, cb) - o = e.api.Add(o, e.api.Mul(limbs[i], bit)) - } - e.api.AssertIsEqual(e.api.Mul(l, r), o) - } - e.Limbs = limbs - e.overflow = nextOverflow -} - -// Reduce reduces a modulo modulus and assigns e to the reduced value. -func (e *Element) Reduce(a Element) *Element { - if a.overflow == 0 { - // fast path - already reduced, omit reduction. - e.Set(a) - return e - } - // slow path - use hint to reduce value - r, err := computeReductionHint(e.api, a.params, a.Limbs) - if err != nil { - panic(fmt.Sprintf("reduction hint: %v", err)) - } - e.Limbs = r - e.overflow = 0 - e.AssertIsEqual(a) - return e -} - -// Set sets e to a and returns e. If a is constant, then it also enforces the -// widths of the limbs. -func (e *Element) Set(a Element) { - e.Limbs = make([]frontend.Variable, len(a.Limbs)) - e.overflow = a.overflow - copy(e.Limbs, a.Limbs) - if a.api == nil { - // we are setting from constant -- ensure that the widths of the limbs - // are restricted - e.EnforceWidth() - } -} - -// AssertIsEqual ensures that a is equal to e modulo the modulus. -func (e *Element) AssertIsEqual(a Element) { - diff := e.params.Element(e.api) - diff.Sub(a, *e) - kLimbs, err := computeEqualityHint(e.api, e.params, diff) - if err != nil { - panic(fmt.Sprintf("hint error: %v", err)) - } - k := e.params.ConstantFromLimbs(kLimbs) - p := e.params.Modulus() - kp := e.params.Element(e.api) - kp.Mul(k, p) - diff.AssertLimbsEquality(kp) -} - -// AssertIsEqualLessThan ensures that e is less or equal than e. -func (e *Element) AssertIsLessEqualThan(a Element) { - if e.overflow+a.overflow > 0 { - panic("inputs must have 0 overflow") - } - eBits := e.ToBits() - aBits := a.ToBits() - f := func(xbits, ybits []frontend.Variable) []frontend.Variable { - diff := len(xbits) - len(ybits) - ybits = append(ybits, make([]frontend.Variable, diff)...) - for i := len(ybits) - diff - 1; i < len(ybits); i++ { - ybits[i] = 0 - } - return ybits - } - if len(eBits) > len(aBits) { - aBits = f(eBits, aBits) - } else { - eBits = f(aBits, eBits) - } - p := make([]frontend.Variable, len(eBits)+1) - p[len(eBits)] = 1 - for i := len(eBits) - 1; i >= 0; i-- { - v := e.api.Mul(p[i+1], eBits[i]) - p[i] = e.api.Select(aBits[i], v, p[i+1]) - t := e.api.Select(aBits[i], 0, p[i+1]) - l := e.api.Sub(1, t, eBits[i]) - ll := e.api.Mul(l, eBits[i]) - e.api.AssertIsEqual(ll, 0) - } -} - -// Sub sets e to a-b and returns e. The returned element may not be reduced to -// be less than the ring modulus. -func (e *Element) Sub(a, b Element) *Element { - overflow, err := e.subPreCond(a, b) - if err != nil { - panic(err) - } - e.sub(a, b, overflow) - return e -} - -func (e Element) subPreCond(a, b Element) (nextOverflow uint, err error) { - reduceRight := a.overflow < b.overflow+2 - nextOverflow = b.overflow + 2 - if a.overflow > nextOverflow { - nextOverflow = a.overflow - } - if nextOverflow > e.maxOverflow() { - err = errOverflow{op: "sub", nextOverflow: nextOverflow, maxOverflow: e.maxOverflow(), reduceRight: reduceRight} - } - return -} - -func (e *Element) sub(a, b Element, nextOverflow uint) { - // first we have to compute padding to ensure that the subtraction does not - // underflow. - nbLimbs := len(a.Limbs) - if len(b.Limbs) > nbLimbs { - nbLimbs = len(b.Limbs) - } - limbs := make([]frontend.Variable, nbLimbs) - padLimbs := subPadding(e.params, b.overflow, uint(nbLimbs)) - for i := range limbs { - limbs[i] = padLimbs[i] - if i < len(a.Limbs) { - limbs[i] = e.api.Add(limbs[i], a.Limbs[i]) - } - if i < len(b.Limbs) { - limbs[i] = e.api.Sub(limbs[i], b.Limbs[i]) - } - } - e.Limbs = limbs - e.overflow = nextOverflow -} - -// Div sets e to a/b and returns e. If modulus is not a prime, it panics. The -// result is less than the modulus. This method is more efficient than inverting -// b and multiplying it by a. -func (e *Element) Div(a, b Element) *Element { - if !e.params.hasInverses { - panic("modulus not a prime") - } - div, err := computeDivisionHint(e.api, e.params, a.Limbs, b.Limbs) - if err != nil { - panic(fmt.Sprintf("compute division: %v", err)) - } - e.Limbs = div - e.overflow = 0 - e.EnforceWidth() - res := e.params.Element(e.api) - res.Mul(*e, b) - res.AssertIsEqual(a) - return e -} - -// Inverse sets e to 1/a and returns e. If modulus is not a prime, it panics. -// The result is less than the modulus. -func (e *Element) Inverse(a Element) *Element { - if !e.params.hasInverses { - panic("modulus not a prime") - } - k, err := computeInverseHint(e.api, e.params, a.Limbs) - if err != nil { - panic(fmt.Sprintf("compute inverse: %v", err)) - } - e.Limbs = k - e.overflow = 0 - e.EnforceWidth() - res := e.params.Element(e.api) - res.Mul(*e, a) - one := e.params.One() - res.AssertIsEqual(one) - return e -} - -// Negate sets e to -a and returns e. The returned element may be larger than -// the modulus. -func (e *Element) Negate(a Element) *Element { - z := e.params.Zero() - return e.Sub(z, a) -} - -// Select sets e to a if selector == 0 and to b otherwise. -func (e *Element) Select(selector frontend.Variable, a, b Element) *Element { - if len(a.Limbs) != len(b.Limbs) { - panic("unequal limb counts for select") - } - if a.overflow != b.overflow { - panic("unequal overflows for select") - } - e.Limbs = make([]frontend.Variable, len(a.Limbs)) - e.overflow = a.overflow - for i := range a.Limbs { - e.Limbs[i] = e.api.Select(selector, a.Limbs[i], b.Limbs[i]) - } - return e -} - -// Lookup2 performs two-bit lookup between a, b, c, d based on lookup bits b1 -// and b2. Sets e to a if b0=b1=0, b if b0=1 and b1=0, c if b0=0 and b1=1, d if b0=b1=1. -func (e *Element) Lookup2(b0, b1 frontend.Variable, a, b, c, d Element) *Element { - if len(a.Limbs) != len(b.Limbs) || len(a.Limbs) != len(c.Limbs) || len(a.Limbs) != len(d.Limbs) { - panic("unequal limb counts for lookup") - } - if a.overflow != b.overflow || a.overflow != c.overflow || a.overflow != d.overflow { - panic("unequal overflows for lookup") - } - e.Limbs = make([]frontend.Variable, len(a.Limbs)) - e.overflow = a.overflow - for i := range a.Limbs { - e.Limbs[i] = e.api.Lookup2(b0, b1, a.Limbs[i], b.Limbs[i], c.Limbs[i], d.Limbs[i]) - } - return e -} - -// reduceAndOp applies op on the inputs. If the pre-condition check preCond -// errs, then first reduces the input arguments. The reduction is done -// one-by-one with the element with highest overflow reduced first. -func (e *Element) reduceAndOp(op func(Element, Element, uint), preCond func(Element, Element) (uint, error), a, b *Element) { - var nextOverflow uint - var err error - var target errOverflow - for nextOverflow, err = preCond(*a, *b); errors.As(err, &target); nextOverflow, err = preCond(*a, *b) { - if !target.reduceRight { - a.Reduce(*a) - } else { - b.Reduce(*b) - } - } - op(*a, *b, nextOverflow) -} diff --git a/std/math/nonnative/variable_test.go b/std/math/nonnative/variable_test.go deleted file mode 100644 index 0bd9ddacec..0000000000 --- a/std/math/nonnative/variable_test.go +++ /dev/null @@ -1,936 +0,0 @@ -package nonnative - -import ( - "crypto/rand" - "fmt" - "math/big" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/r1cs" - "github.com/consensys/gnark/frontend/cs/scs" - "github.com/consensys/gnark/test" -) - -const testCurve = ecc.BN254 - -// TODO: add also cases which should fail - -type emulatedField struct { - params *Params - name string -} - -func emulatedFields(t *testing.T) []emulatedField { - t.Helper() - var ret []emulatedField - for _, limbLength := range []int{32, 48, 64, 120} { - bn254fp, err := NewParams(limbLength, ecc.BN254.BaseField()) - if err != nil { - t.Fatal(err) - } - ret = append(ret, emulatedField{bn254fp, "bn254fp"}) - secp256k1fp, err := NewParams(limbLength, new(big.Int).SetBytes([]byte{ - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, - 0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFC, 0x2F, - })) - if err != nil { - t.Fatal(err) - } - ret = append(ret, emulatedField{secp256k1fp, "secp256k1"}) - } - goldilocks, err := NewParams(64, new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01})) - if err != nil { - t.Fatal(err) - } - ret = append(ret, emulatedField{goldilocks, "goldilocks"}) - return ret -} - -func testName(ef emulatedField) string { - return fmt.Sprintf("%s/limb=%d", ef.name, ef.params.nbBits) -} - -type AssertLimbEqualityCircuit struct { - params *Params - - A Element - B Element -} - -func (c *AssertLimbEqualityCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Set(c.A) - res.AssertLimbsEquality(c.B) - return nil -} - -func TestAssertLimbEqualityNoOverflow(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := AssertLimbEqualityCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - } - - val, _ := rand.Int(rand.Reader, params.r) - witness := AssertLimbEqualityCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val), - B: params.ConstantFromBigOrPanic(val), - } - assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type AssertIsLessEqualThanCircuit struct { - params *Params - - L Element - R Element -} - -func (c *AssertIsLessEqualThanCircuit) Define(api frontend.API) error { - L := c.params.Element(api) - L.Set(c.L) - R := c.params.Element(api) - R.Set(c.R) - L.AssertIsLessEqualThan(R) - return nil -} - -func TestAssertIsLessEqualThan(t *testing.T) { - for _, fp := range emulatedFields(t)[:1] { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := AssertIsLessEqualThanCircuit{ - params: params, - L: params.Placeholder(), - R: params.Placeholder(), - } - R, _ := rand.Int(rand.Reader, params.r) - L, _ := rand.Int(rand.Reader, R) - witness := AssertIsLessEqualThanCircuit{ - params: params, - L: params.ConstantFromBigOrPanic(L), - R: params.ConstantFromBigOrPanic(R), - } - assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type AddCircuit struct { - params *Params - - A Element - B Element - C Element -} - -func (c *AddCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Add(c.A, c.B) - res.AssertLimbsEquality(c.C) - return nil -} - -func TestAddCircuitNoOverflow(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := AddCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, new(big.Int).Div(params.r, big.NewInt(2))) - val2, _ := rand.Int(rand.Reader, new(big.Int).Div(params.r, big.NewInt(2))) - res := new(big.Int).Add(val1, val2) - witness := AddCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type MulNoOverflowCircuit struct { - params *Params - - A Element - B Element - C Element -} - -func (c *MulNoOverflowCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Mul(c.A, c.B) - res.AssertLimbsEquality(c.C) - return nil -} - -func TestMulCircuitNoOverflow(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := MulNoOverflowCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), uint(params.r.BitLen())/2)) - val2, _ := rand.Int(rand.Reader, new(big.Int).Div(params.r, val1)) - res := new(big.Int).Mul(val1, val2) - witness := MulNoOverflowCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type MulCircuitOverflow struct { - params *Params - - A Element - B Element - C Element -} - -func (c *MulCircuitOverflow) Define(api frontend.API) error { - res := c.params.Element(api) - res.Mul(c.A, c.B) - res.AssertIsEqual(c.C) - return nil -} - -func TestMulCircuitOverflow(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := MulCircuitOverflow{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, params.r) - res := new(big.Int).Mul(val1, val2) - res.Mod(res, params.r) - witness := MulCircuitOverflow{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type ReduceAfterAddCircuit struct { - params *Params - - A Element - B Element - C Element -} - -func (c *ReduceAfterAddCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Add(c.A, c.B) - res.Reduce(res) - res.AssertIsEqual(c.C) - return nil -} - -func TestReduceAfterAdd(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := ReduceAfterAddCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val2, _ := rand.Int(rand.Reader, params.r) - val1, _ := rand.Int(rand.Reader, val2) - val3 := new(big.Int).Add(val1, params.r) - val3.Sub(val3, val2) - witness := ReduceAfterAddCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val3), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(val1), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type SubtractCircuit struct { - params *Params - - A Element - B Element - C Element -} - -func (c *SubtractCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Sub(c.A, c.B) - res.AssertIsEqual(c.C) - return nil -} - -func TestSubtractNoOverflow(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := SubtractCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, val1) - res := new(big.Int).Sub(val1, val2) - witness := SubtractCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -func TestSubtractOverflow(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := SubtractCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, new(big.Int).Sub(params.r, val1)) - val2.Add(val2, val1) - res := new(big.Int).Sub(val1, val2) - res.Mod(res, params.r) - witness := SubtractCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type NegationCircuit struct { - params *Params - - A Element - B Element -} - -func (c *NegationCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Negate(c.A) - res.AssertIsEqual(c.B) - return nil -} - -func TestNegation(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := NegationCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - res := new(big.Int).Sub(params.r, val1) - witness := NegationCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type InverseCircuit struct { - params *Params - - A Element - B Element -} - -func (c *InverseCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Inverse(c.A) - res.AssertIsEqual(c.B) - return nil -} - -func TestInverse(t *testing.T) { - for _, fp := range emulatedFields(t) { - if !fp.params.hasInverses { - continue - } - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := InverseCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - res := new(big.Int).ModInverse(val1, params.r) - witness := InverseCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type DivisionCircuit struct { - params *Params - A Element - B Element - C Element -} - -func (c *DivisionCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Div(c.A, c.B) - res.AssertIsEqual(c.C) - return nil -} - -func TestDivision(t *testing.T) { - for _, fp := range emulatedFields(t) { - if !fp.params.hasInverses { - continue - } - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := DivisionCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, params.r) - res := new(big.Int) - res.ModInverse(val2, params.r) - res.Mul(val1, res) - res.Mod(res, params.r) - witness := DivisionCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type ToBitsCircuit struct { - params *Params - - Value Element - Bits []frontend.Variable -} - -func (c *ToBitsCircuit) Define(api frontend.API) error { - el := c.params.Element(api) - el.Set(c.Value) - bits := el.ToBits() - if len(bits) != len(c.Bits) { - return fmt.Errorf("got %d bits, expected %d", len(bits), len(c.Bits)) - } - for i := range bits { - api.AssertIsEqual(bits[i], c.Bits[i]) - } - return nil -} - -func TestToBits(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - bitLen := params.nbBits * params.nbLimbs - circuit := ToBitsCircuit{ - params: params, - Value: params.Placeholder(), - Bits: make([]frontend.Variable, bitLen), - } - - val1, _ := rand.Int(rand.Reader, params.r) - bits := make([]frontend.Variable, bitLen) - for i := 0; i < len(bits); i++ { - bits[i] = val1.Bit(i) - } - witness := ToBitsCircuit{ - params: params, - Value: params.ConstantFromBigOrPanic(val1), - Bits: bits, - } - - assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type FromBitsCircuit struct { - params *Params - - Bits []frontend.Variable - Res Element -} - -func (c *FromBitsCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.FromBits(c.Bits) - res.AssertIsEqual(c.Res) - return nil -} - -func TestFromBits(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - bitLen := params.r.BitLen() - circuit := FromBitsCircuit{ - params: params, - Bits: make([]frontend.Variable, bitLen), - Res: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - bits := make([]frontend.Variable, bitLen) - for i := 0; i < len(bits); i++ { - bits[i] = val1.Bit(i) - } - witness := FromBitsCircuit{ - params: params, - Bits: bits, - Res: params.ConstantFromBigOrPanic(val1), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.WithProverOpts(backend.WithHints(GetHints()...))) - }, testName(fp)) - } -} - -type ConstantCircuit struct { - params *Params - - A Element - B Element -} - -func (c *ConstantCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Set(c.A) - res.AssertIsEqual(c.B) - return nil -} - -func TestConstant(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := ConstantCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - } - val, _ := rand.Int(rand.Reader, params.r) - witness := ConstantCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val), - B: params.ConstantFromBigOrPanic(val), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithCurves(testCurve), test.WithProverOpts(backend.WithHints(GetHints()...))) - }, testName(fp)) - } -} - -type SelectCircuit struct { - params *Params - - Selector frontend.Variable - A Element - B Element - C Element -} - -func (c *SelectCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Select(c.Selector, c.A, c.B) - res.AssertIsEqual(c.C) - return nil -} - -func TestSelect(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := SelectCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, params.r) - randbit, _ := rand.Int(rand.Reader, big.NewInt(2)) - b := randbit.Uint64() - witness := SelectCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic([]*big.Int{val1, val2}[1-b]), - Selector: b, - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type Lookup2Circuit struct { - params *Params - - Bit0 frontend.Variable - Bit1 frontend.Variable - A Element - B Element - C Element - D Element - E Element -} - -func (c *Lookup2Circuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Lookup2(c.Bit0, c.Bit1, c.A, c.B, c.C, c.D) - res.AssertIsEqual(c.E) - return nil -} - -func TestLookup2(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := Lookup2Circuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - D: params.Placeholder(), - E: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, params.r) - val3, _ := rand.Int(rand.Reader, params.r) - val4, _ := rand.Int(rand.Reader, params.r) - randbit, _ := rand.Int(rand.Reader, big.NewInt(4)) - witness := Lookup2Circuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(val3), - D: params.ConstantFromBigOrPanic(val4), - E: params.ConstantFromBigOrPanic([]*big.Int{val1, val2, val3, val4}[randbit.Uint64()]), - Bit0: randbit.Bit(0), - Bit1: randbit.Bit(1), - } - - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -type ComputationCircuit struct { - params *Params - noReduce bool - - X1, X2, X3, X4, X5, X6 Element - Res Element -} - -func (c *ComputationCircuit) Define(api frontend.API) error { - // compute x1^3 + 5*x2 + (x3-x4) / (x5+x6) - x13 := c.params.Element(api) - x13.Mul(c.X1, c.X1) - if !c.noReduce { - x13.Reduce(x13) - } - x13.Mul(x13, c.X1) - if !c.noReduce { - x13.Reduce(x13) - } - - fx2 := c.params.Element(api) - five, err := c.params.ConstantFromBig(big.NewInt(5)) - if err != nil { - return fmt.Errorf("five: %w", err) - } - fx2.Mul(five, c.X2) - fx2.Reduce(fx2) - - nom := c.params.Element(api) - nom.Sub(c.X3, c.X4) - - denom := c.params.Element(api) - denom.Add(c.X5, c.X6) - - free := c.params.Element(api) - free.Div(nom, denom) - - res := c.params.Element(api) - res.Add(x13, fx2) - res.Add(res, free) - - res.AssertIsEqual(c.Res) - return nil -} - -func TestComputation(t *testing.T) { - for _, fp := range emulatedFields(t) { - if !fp.params.hasInverses { - continue - } - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := ComputationCircuit{ - params: params, - X1: params.Placeholder(), - X2: params.Placeholder(), - X3: params.Placeholder(), - X4: params.Placeholder(), - X5: params.Placeholder(), - X6: params.Placeholder(), - Res: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, params.r) - val2, _ := rand.Int(rand.Reader, params.r) - val3, _ := rand.Int(rand.Reader, params.r) - val4, _ := rand.Int(rand.Reader, params.r) - val5, _ := rand.Int(rand.Reader, params.r) - val6, _ := rand.Int(rand.Reader, params.r) - - tmp := new(big.Int) - res := new(big.Int) - // res = x1^3 - tmp.Exp(val1, big.NewInt(3), params.r) - res.Set(tmp) - // res = x1^3 + 5*x2 - tmp.Mul(val2, big.NewInt(5)) - res.Add(res, tmp) - // tmp = (x3-x4) - tmp.Sub(val3, val4) - tmp.Mod(tmp, params.r) - // tmp2 = (x5+x6) - tmp2 := new(big.Int) - tmp2.Add(val5, val6) - // tmp = (x3-x4)/(x5+x6) - tmp2.ModInverse(tmp2, params.r) - tmp.Mul(tmp, tmp2) - tmp.Mod(tmp, params.r) - // res = x1^3 + 5*x2 + (x3-x4)/(x5+x6) - res.Add(res, tmp) - res.Mod(res, params.r) - - witness := ComputationCircuit{ - params: params, - X1: params.ConstantFromBigOrPanic(val1), - X2: params.ConstantFromBigOrPanic(val2), - X3: params.ConstantFromBigOrPanic(val3), - X4: params.ConstantFromBigOrPanic(val4), - X5: params.ConstantFromBigOrPanic(val5), - X6: params.ConstantFromBigOrPanic(val6), - Res: params.ConstantFromBigOrPanic(res), - } - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} - -func TestOptimisation(t *testing.T) { - assert := test.NewAssert(t) - params, err := NewParams(32, ecc.BN254.ScalarField()) - assert.NoError(err) - circuit := ComputationCircuit{ - params: params, - noReduce: true, - X1: params.Placeholder(), - X2: params.Placeholder(), - X3: params.Placeholder(), - X4: params.Placeholder(), - X5: params.Placeholder(), - X6: params.Placeholder(), - Res: params.Placeholder(), - } - ccs, err := frontend.Compile(testCurve.ScalarField(), r1cs.NewBuilder, &circuit) - assert.NoError(err) - assert.LessOrEqual(ccs.GetNbConstraints(), 3291) - ccs2, err := frontend.Compile(testCurve.ScalarField(), scs.NewBuilder, &circuit) - assert.NoError(err) - assert.LessOrEqual(ccs2.GetNbConstraints(), 10722) -} - -type FourMulsCircuit struct { - params *Params - A Element - Res Element -} - -func (c *FourMulsCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Mul(c.A, c.A) - res.Mul(res, c.A) - res.Mul(res, c.A) - res.AssertIsEqual(c.Res) - return nil -} - -func TestFourMuls(t *testing.T) { - assert := test.NewAssert(t) - params, err := NewParams(32, ecc.BN254.ScalarField()) - assert.NoError(err) - circuit := FourMulsCircuit{ - params: params, - A: params.Placeholder(), - Res: params.Placeholder(), - } - val1, _ := rand.Int(rand.Reader, params.r) - res := new(big.Int) - res.Mul(val1, val1) - res.Mul(res, val1) - res.Mul(res, val1) - res.Mod(res, params.r) - witness := FourMulsCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - Res: params.ConstantFromBigOrPanic(res), - } - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) -} - -type RegroupCircuit struct { - params *Params - - A Element - B Element - C Element -} - -func (c *RegroupCircuit) Define(api frontend.API) error { - res := c.params.Element(api) - res.Add(c.A, c.B) - res.AssertLimbsEquality(c.C) - params2 := regroupParams(c.params, uint(api.Compiler().FieldBitLen()), res.overflow) - res2 := params2.From(api, res) - C2 := params2.From(api, c.C) - res2.AssertLimbsEquality(C2) - return nil -} - -func TestRegroupCircuit(t *testing.T) { - for _, fp := range emulatedFields(t) { - params := fp.params - assert := test.NewAssert(t) - assert.Run(func(assert *test.Assert) { - circuit := RegroupCircuit{ - params: params, - A: params.Placeholder(), - B: params.Placeholder(), - C: params.Placeholder(), - } - - val1, _ := rand.Int(rand.Reader, new(big.Int).Div(params.r, big.NewInt(2))) - val2, _ := rand.Int(rand.Reader, new(big.Int).Div(params.r, big.NewInt(2))) - res := new(big.Int).Add(val1, val2) - witness := RegroupCircuit{ - params: params, - A: params.ConstantFromBigOrPanic(val1), - B: params.ConstantFromBigOrPanic(val2), - C: params.ConstantFromBigOrPanic(res), - } - assert.ProverSucceeded(&circuit, &witness, test.WithProverOpts(backend.WithHints(GetHints()...)), test.WithCurves(testCurve)) - }, testName(fp)) - } -} diff --git a/test/assert.go b/test/assert.go index 7e925a2853..e667aaa613 100644 --- a/test/assert.go +++ b/test/assert.go @@ -407,6 +407,12 @@ func (assert *Assert) fuzzer(fuzzer filler, circuit, w frontend.Circuit, b backe errVars := IsSolved(circuit, w, curve.ScalarField()) errConsts := IsSolved(circuit, w, curve.ScalarField(), SetAllVariablesAsConstants()) + if (errVars == nil) != (errConsts == nil) { + assert.Log("errVars", errVars) + assert.Log("errConsts", errConsts) + assert.FailNow("solving circuit with values as constants vs non-constants mismatched result") + } + if errVars == nil && errConsts == nil { // valid witness assert.solvingSucceeded(circuit, w, b, curve, opt) diff --git a/test/engine.go b/test/engine.go index 0aaa87b666..d823230608 100644 --- a/test/engine.go +++ b/test/engine.go @@ -27,6 +27,7 @@ import ( "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend/schema" + "github.com/consensys/gnark/logger" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark/backend" @@ -126,16 +127,29 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng } }() + log := logger.Logger() + log.Debug().Msg("running circuit in test engine") + cptAdd, cptMul, cptSub, cptToBinary, cptFromBinary, cptAssertIsEqual = 0, 0, 0, 0, 0, 0 api := e.apiWrapper(e) err = c.Define(api) + log.Debug().Uint64("add", cptAdd). + Uint64("sub", cptSub). + Uint64("mul", cptMul). + Uint64("equals", cptAssertIsEqual). + Uint64("toBinary", cptToBinary). + Uint64("fromBinary", cptFromBinary).Msg("counters") return } +var cptAdd, cptMul, cptSub, cptToBinary, cptFromBinary, cptAssertIsEqual uint64 + func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + cptAdd++ res := new(big.Int) res.Add(e.toBigInt(i1), e.toBigInt(i2)) for i := 0; i < len(in); i++ { + cptAdd++ res.Add(res, e.toBigInt(in[i])) } res.Mod(res, e.modulus()) @@ -143,9 +157,11 @@ func (e *engine) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend } func (e *engine) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + cptSub++ res := new(big.Int) res.Sub(e.toBigInt(i1), e.toBigInt(i2)) for i := 0; i < len(in); i++ { + cptSub++ res.Sub(res, e.toBigInt(in[i])) } res.Mod(res, e.modulus()) @@ -160,6 +176,7 @@ func (e *engine) Neg(i1 frontend.Variable) frontend.Variable { } func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { + cptMul++ b2 := e.toBigInt(i2) if len(in) == 0 && b2.IsUint64() && b2.Uint64() <= 1 { // special path to avoid useless allocations @@ -173,6 +190,7 @@ func (e *engine) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend res.Mul(b1, b2) res.Mod(res, e.modulus()) for i := 0; i < len(in); i++ { + cptMul++ res.Mul(res, e.toBigInt(in[i])) res.Mod(res, e.modulus()) } @@ -212,6 +230,7 @@ func (e *engine) Inverse(i1 frontend.Variable) frontend.Variable { } func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { + cptToBinary++ nbBits := e.FieldBitLen() if len(n) == 1 { nbBits = n[0] @@ -243,6 +262,7 @@ func (e *engine) ToBinary(i1 frontend.Variable, n ...int) []frontend.Variable { } func (e *engine) FromBinary(v ...frontend.Variable) frontend.Variable { + cptFromBinary++ bits := make([]bool, len(v)) for i := 0; i < len(v); i++ { be := e.toBigInt(v[i]) @@ -339,6 +359,7 @@ func (e *engine) Cmp(i1, i2 frontend.Variable) frontend.Variable { } func (e *engine) AssertIsEqual(i1, i2 frontend.Variable) { + cptAssertIsEqual++ b1, b2 := e.toBigInt(i1), e.toBigInt(i2) if b1.Cmp(b2) != 0 { panic(fmt.Sprintf("[assertIsEqual] %s == %s", b1.String(), b2.String()))