Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: clean up witness package, introduces clean witness.Witness interface #450

Merged
merged 11 commits into from
Feb 1, 2023
10 changes: 7 additions & 3 deletions backend/groth16/bellman_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,17 @@ func TestVerifyBellmanProof(t *testing.T) {

// verify groth16 proof
// we need to prepend the number of elements in the witness.
// witness package expects [nbPublic nbSecret] followed by [n | elements];
// note that n is redundant with nbPublic + nbSecret
var buf bytes.Buffer
_ = binary.Write(&buf, binary.BigEndian, uint32(len(inputsBytes)/(fr.Limbs*8)))
_ = binary.Write(&buf, binary.BigEndian, uint32(0))
_ = binary.Write(&buf, binary.BigEndian, uint32(len(inputsBytes)/(fr.Limbs*8)))
buf.Write(inputsBytes)

witness := &witness.Witness{
CurveID: ecc.BLS12_381,
}
witness, err := witness.New(ecc.BLS12_381.ScalarField())
require.NoError(t, err)

err = witness.UnmarshalBinary(buf.Bytes())
require.NoError(t, err)

Expand Down
74 changes: 37 additions & 37 deletions backend/groth16/groth16.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ import (
cs_bw6633 "github.com/consensys/gnark/constraint/bw6-633"
cs_bw6761 "github.com/consensys/gnark/constraint/bw6-761"

witness_bls12377 "github.com/consensys/gnark/internal/backend/bls12-377/witness"
witness_bls12381 "github.com/consensys/gnark/internal/backend/bls12-381/witness"
witness_bls24315 "github.com/consensys/gnark/internal/backend/bls24-315/witness"
witness_bls24317 "github.com/consensys/gnark/internal/backend/bls24-317/witness"
witness_bn254 "github.com/consensys/gnark/internal/backend/bn254/witness"
witness_bw6633 "github.com/consensys/gnark/internal/backend/bw6-633/witness"
witness_bw6761 "github.com/consensys/gnark/internal/backend/bw6-761/witness"
fr_bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr_bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
fr_bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr"
fr_bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr"
fr_bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr"
fr_bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr"
fr_bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr"

gnarkio "github.com/consensys/gnark/io"

Expand Down Expand Up @@ -109,51 +109,51 @@ type VerifyingKey interface {
}

// Verify runs the groth16.Verify algorithm on provided proof with given witness
func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error {
func Verify(proof Proof, vk VerifyingKey, publicWitness witness.Witness) error {

switch _proof := proof.(type) {
case *groth16_bls12377.Proof:
w, ok := publicWitness.Vector.(*witness_bls12377.Witness)
w, ok := publicWitness.Vector().(fr_bls12377.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls12377.Verify(_proof, vk.(*groth16_bls12377.VerifyingKey), *w)
return groth16_bls12377.Verify(_proof, vk.(*groth16_bls12377.VerifyingKey), w)
case *groth16_bls12381.Proof:
w, ok := publicWitness.Vector.(*witness_bls12381.Witness)
w, ok := publicWitness.Vector().(fr_bls12381.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls12381.Verify(_proof, vk.(*groth16_bls12381.VerifyingKey), *w)
return groth16_bls12381.Verify(_proof, vk.(*groth16_bls12381.VerifyingKey), w)
case *groth16_bn254.Proof:
w, ok := publicWitness.Vector.(*witness_bn254.Witness)
w, ok := publicWitness.Vector().(fr_bn254.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bn254.Verify(_proof, vk.(*groth16_bn254.VerifyingKey), *w)
return groth16_bn254.Verify(_proof, vk.(*groth16_bn254.VerifyingKey), w)
case *groth16_bw6761.Proof:
w, ok := publicWitness.Vector.(*witness_bw6761.Witness)
w, ok := publicWitness.Vector().(fr_bw6761.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bw6761.Verify(_proof, vk.(*groth16_bw6761.VerifyingKey), *w)
return groth16_bw6761.Verify(_proof, vk.(*groth16_bw6761.VerifyingKey), w)
case *groth16_bls24317.Proof:
w, ok := publicWitness.Vector.(*witness_bls24317.Witness)
w, ok := publicWitness.Vector().(fr_bls24317.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls24317.Verify(_proof, vk.(*groth16_bls24317.VerifyingKey), *w)
return groth16_bls24317.Verify(_proof, vk.(*groth16_bls24317.VerifyingKey), w)
case *groth16_bls24315.Proof:
w, ok := publicWitness.Vector.(*witness_bls24315.Witness)
w, ok := publicWitness.Vector().(fr_bls24315.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bls24315.Verify(_proof, vk.(*groth16_bls24315.VerifyingKey), *w)
return groth16_bls24315.Verify(_proof, vk.(*groth16_bls24315.VerifyingKey), w)
case *groth16_bw6633.Proof:
w, ok := publicWitness.Vector.(*witness_bw6633.Witness)
w, ok := publicWitness.Vector().(fr_bw6633.Vector)
if !ok {
return witness.ErrInvalidWitness
}
return groth16_bw6633.Verify(_proof, vk.(*groth16_bw6633.VerifyingKey), *w)
return groth16_bw6633.Verify(_proof, vk.(*groth16_bw6633.VerifyingKey), w)
default:
panic("unrecognized R1CS curve type")
}
Expand All @@ -166,7 +166,7 @@ func Verify(proof Proof, vk VerifyingKey, publicWitness *witness.Witness) error
// will execute all the prover computations, even if the witness is invalid
// will produce an invalid proof
// internally, the solution vector to the R1CS will be filled with random values which may impact benchmarking
func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness.Witness, opts ...backend.ProverOption) (Proof, error) {
func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness witness.Witness, opts ...backend.ProverOption) (Proof, error) {

// apply options
opt, err := backend.NewProverConfig(opts...)
Expand All @@ -176,47 +176,47 @@ func Prove(r1cs constraint.ConstraintSystem, pk ProvingKey, fullWitness *witness

switch _r1cs := r1cs.(type) {
case *cs_bls12377.R1CS:
w, ok := fullWitness.Vector.(*witness_bls12377.Witness)
w, ok := fullWitness.Vector().(fr_bls12377.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), *w, opt)
return groth16_bls12377.Prove(_r1cs, pk.(*groth16_bls12377.ProvingKey), w, opt)
case *cs_bls12381.R1CS:
w, ok := fullWitness.Vector.(*witness_bls12381.Witness)
w, ok := fullWitness.Vector().(fr_bls12381.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), *w, opt)
return groth16_bls12381.Prove(_r1cs, pk.(*groth16_bls12381.ProvingKey), w, opt)
case *cs_bn254.R1CS:
w, ok := fullWitness.Vector.(*witness_bn254.Witness)
w, ok := fullWitness.Vector().(fr_bn254.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), *w, opt)
return groth16_bn254.Prove(_r1cs, pk.(*groth16_bn254.ProvingKey), w, opt)
case *cs_bw6761.R1CS:
w, ok := fullWitness.Vector.(*witness_bw6761.Witness)
w, ok := fullWitness.Vector().(fr_bw6761.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), *w, opt)
return groth16_bw6761.Prove(_r1cs, pk.(*groth16_bw6761.ProvingKey), w, opt)
case *cs_bls24317.R1CS:
w, ok := fullWitness.Vector.(*witness_bls24317.Witness)
w, ok := fullWitness.Vector().(fr_bls24317.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), *w, opt)
return groth16_bls24317.Prove(_r1cs, pk.(*groth16_bls24317.ProvingKey), w, opt)
case *cs_bls24315.R1CS:
w, ok := fullWitness.Vector.(*witness_bls24315.Witness)
w, ok := fullWitness.Vector().(fr_bls24315.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), *w, opt)
return groth16_bls24315.Prove(_r1cs, pk.(*groth16_bls24315.ProvingKey), w, opt)
case *cs_bw6633.R1CS:
w, ok := fullWitness.Vector.(*witness_bw6633.Witness)
w, ok := fullWitness.Vector().(fr_bw6633.Vector)
if !ok {
return nil, witness.ErrInvalidWitness
}
return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), *w, opt)
return groth16_bw6633.Prove(_r1cs, pk.(*groth16_bw6633.ProvingKey), w, opt)
default:
panic("unrecognized R1CS curve type")
}
Expand Down
124 changes: 124 additions & 0 deletions backend/groth16/groth16_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package groth16_test

import (
"math/big"
"testing"

"github.com/consensys/gnark"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend/groth16"
"github.com/consensys/gnark/constraint"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/r1cs"
)

//--------------------//
// benches //
//--------------------//

func BenchmarkSetup(b *testing.B) {
for _, curve := range getCurves() {
b.Run(curve.String(), func(b *testing.B) {
r1cs, _ := referenceCircuit(curve)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _, _ = groth16.Setup(r1cs)
}
})
}
}

func BenchmarkProver(b *testing.B) {
for _, curve := range getCurves() {
b.Run(curve.String(), func(b *testing.B) {
r1cs, _solution := referenceCircuit(curve)
fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
if err != nil {
b.Fatal(err)
}
pk, err := groth16.DummySetup(r1cs)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = groth16.Prove(r1cs, pk, fullWitness)
}
})
}
}

func BenchmarkVerifier(b *testing.B) {
for _, curve := range getCurves() {
b.Run(curve.String(), func(b *testing.B) {
r1cs, _solution := referenceCircuit(curve)
fullWitness, err := frontend.NewWitness(_solution, curve.ScalarField())
if err != nil {
b.Fatal(err)
}
publicWitness, err := fullWitness.Public()
if err != nil {
b.Fatal(err)
}

pk, vk, err := groth16.Setup(r1cs)
if err != nil {
b.Fatal(err)
}
proof, err := groth16.Prove(r1cs, pk, fullWitness)
if err != nil {
panic(err)
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = groth16.Verify(proof, vk, publicWitness)
}
})
}
}

type refCircuit struct {
nbConstraints int
X frontend.Variable
Y frontend.Variable `gnark:",public"`
}

func (circuit *refCircuit) Define(api frontend.API) error {
for i := 0; i < circuit.nbConstraints; i++ {
circuit.X = api.Mul(circuit.X, circuit.X)
}
api.AssertIsEqual(circuit.X, circuit.Y)
return nil
}

func referenceCircuit(curve ecc.ID) (constraint.ConstraintSystem, frontend.Circuit) {
const nbConstraints = 40000
circuit := refCircuit{
nbConstraints: nbConstraints,
}
r1cs, err := frontend.Compile(curve.ScalarField(), r1cs.NewBuilder, &circuit)
if err != nil {
panic(err)
}

var good refCircuit
good.X = 2

// compute expected Y
expectedY := new(big.Int).SetUint64(2)
exp := big.NewInt(1)
exp.Lsh(exp, nbConstraints)
expectedY.Exp(expectedY, exp, curve.ScalarField())

good.Y = expectedY

return r1cs, &good
}

func getCurves() []ecc.ID {
if testing.Short() {
return []ecc.ID{ecc.BN254}
}
return gnark.Curves()
}
Loading