From a0cfa43721c81f612ebb248900f1d5d7be7008c8 Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 1 Mar 2023 15:19:24 -0600 Subject: [PATCH 1/2] fix: closes #509 api did not handle AssertIsLessOrEqual with constant as first param --- frontend/cs/r1cs/api_assertions.go | 48 ++++++++++++++++++++++-------- frontend/cs/scs/api_assertions.go | 39 +++++++++++++++++------- frontend/cs/scs/builder.go | 6 ++++ 3 files changed, 70 insertions(+), 23 deletions(-) diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index 2d2db8f010..9218f34531 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -17,12 +17,12 @@ limitations under the License. package r1cs import ( + "fmt" "math/big" "github.com/consensys/gnark/debug" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/internal/expr" - "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/math/bits" ) @@ -80,18 +80,34 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { // // derived from: // https://github.com/zcash/zips/blob/main/protocol/protocol.pdf -func (builder *builder) AssertIsLessOrEqual(_v frontend.Variable, bound frontend.Variable) { - v := builder.toVariable(_v) +func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { + cv, vConst := builder.constantValue(v) + cb, bConst := builder.constantValue(bound) + + // both inputs are constants + if vConst && bConst { + bv, bb := builder.cs.ToBigInt(&cv), builder.cs.ToBigInt(&cb) + if bv.Cmp(bb) == 1 { + panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", bv.String(), bb.String())) + } + } - if b, ok := bound.(expr.LinearExpression); ok { - assertIsSet(b) - builder.mustBeLessOrEqVar(v, b) - } else { - builder.mustBeLessOrEqCst(v, utils.FromInterface(bound)) + // bound is constant + if bConst { + vv := builder.toVariable(v) + builder.mustBeLessOrEqCst(vv, builder.cs.ToBigInt(&cb)) + return } + + builder.mustBeLessOrEqVar(v, bound) } -func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) { +func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { + // here bound is NOT a constant, + // but a can be either constant or a wire. + + _, aConst := builder.constantValue(a) + debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) nbBits := builder.cs.FieldBitLen() @@ -128,16 +144,22 @@ func (builder *builder) mustBeLessOrEqVar(a, bound expr.LinearExpression) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i].(expr.LinearExpression)) // this does not create a constraint + builder.MarkBoolean(aBits[i]) // this does not create a constraint - added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], zero))) + if aConst { + // aBits[i] is a constant; + l = builder.Mul(l, aBits[i]) + added = append(added, builder.cs.AddConstraint(builder.newR1C(l, zero, zero))) + } else { + added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], zero))) + } } builder.cs.AttachDebugInfo(debug, added) } -func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound big.Int) { +func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.Int) { nbBits := builder.cs.FieldBitLen() @@ -187,7 +209,7 @@ func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound big.Int l = builder.Sub(l, aBits[i]) added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], builder.cstZero()))) - builder.MarkBoolean(aBits[i].(expr.LinearExpression)) + builder.MarkBoolean(aBits[i]) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index 976a60f1ab..46af67ed92 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -110,13 +110,13 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) { switch b := bound.(type) { case expr.Term: - builder.mustBeLessOrEqVar(v.(expr.Term), b) + builder.mustBeLessOrEqVar(v, b) default: - builder.mustBeLessOrEqCst(v.(expr.Term), utils.FromInterface(b)) + builder.mustBeLessOrEqCst(v, utils.FromInterface(b)) } } -func (builder *builder) mustBeLessOrEqVar(a expr.Term, bound expr.Term) { +func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term) { debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) @@ -147,18 +147,29 @@ func (builder *builder) mustBeLessOrEqVar(a expr.Term, bound expr.Term) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i].(expr.Term)) // this does not create a constraint + builder.MarkBoolean(aBits[i]) // this does not create a constraint + + if ai, ok := builder.constantValue(aBits[i]); ok { + // a is constant; ensure l == 0 + builder.cs.Mul(&l.Coeff, &ai) + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + qL: l.Coeff, + }, debug) + } else { + // l * a[i] == 0 + builder.addPlonkConstraint(sparseR1C{ + xa: l.VID, + xb: aBits[i].(expr.Term).VID, + qM: l.Coeff, + }, debug) + } - builder.addPlonkConstraint(sparseR1C{ - xa: l.VID, - xb: aBits[i].(expr.Term).VID, - qM: l.Coeff, - }, debug) } } -func (builder *builder) mustBeLessOrEqCst(a expr.Term, bound big.Int) { +func (builder *builder) mustBeLessOrEqCst(a frontend.Variable, bound big.Int) { nbBits := builder.cs.FieldBitLen() @@ -170,6 +181,14 @@ func (builder *builder) mustBeLessOrEqCst(a expr.Term, bound big.Int) { panic("AssertIsLessOrEqual: bound is too large, constraint will never be satisfied") } + if ca, ok := builder.constantValue(a); ok { + // a is constant, compare the big int values + ba := builder.cs.ToBigInt(&ca) + if ba.Cmp(&bound) == 1 { + panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", ba.String(), bound.String())) + } + } + // debug info debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound) diff --git a/frontend/cs/scs/builder.go b/frontend/cs/scs/builder.go index 5a9159ce24..ce670203ae 100644 --- a/frontend/cs/scs/builder.go +++ b/frontend/cs/scs/builder.go @@ -136,6 +136,12 @@ func (builder *builder) addMulGate(a, b, c expr.Term, debug ...constraint.DebugI // addPlonkConstraint adds a sparseR1C to the underlying constraint system func (builder *builder) addPlonkConstraint(c sparseR1C, debug ...constraint.DebugInfo) { + if !c.qM.IsZero() && (c.xa == 0 || c.xb == 0) { + // TODO this is internal but not easy to detect; if qM is set, but one or both of xa / xb is not, + // since wireID == 0 is a valid wire, it may trigger unexpected behavior. + log := logger.Logger() + log.Warn().Msg("adding a plonk constraint with qM set but xa or xb == 0 (wire 0)") + } L := builder.cs.MakeTerm(&c.qL, c.xa) R := builder.cs.MakeTerm(&c.qR, c.xb) O := builder.cs.MakeTerm(&c.qO, c.xc) From 2432700da8f7d20589394ab793ff976e3096758a Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Wed, 1 Mar 2023 23:15:15 -0600 Subject: [PATCH 2/2] style: remove useless MarkBoolean on non-returned bits --- frontend/cs/r1cs/api_assertions.go | 2 -- frontend/cs/scs/api_assertions.go | 1 - 2 files changed, 3 deletions(-) diff --git a/frontend/cs/r1cs/api_assertions.go b/frontend/cs/r1cs/api_assertions.go index 9218f34531..ec7351b4af 100644 --- a/frontend/cs/r1cs/api_assertions.go +++ b/frontend/cs/r1cs/api_assertions.go @@ -144,7 +144,6 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i]) // this does not create a constraint if aConst { // aBits[i] is a constant; @@ -209,7 +208,6 @@ func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.In l = builder.Sub(l, aBits[i]) added = append(added, builder.cs.AddConstraint(builder.newR1C(l, aBits[i], builder.cstZero()))) - builder.MarkBoolean(aBits[i]) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/frontend/cs/scs/api_assertions.go b/frontend/cs/scs/api_assertions.go index 46af67ed92..27b5962420 100644 --- a/frontend/cs/scs/api_assertions.go +++ b/frontend/cs/scs/api_assertions.go @@ -147,7 +147,6 @@ func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 // → this is a boolean constraint // if bound[i] == 0, t must be 0 or 1, thus ai must be 0 or 1 too - builder.MarkBoolean(aBits[i]) // this does not create a constraint if ai, ok := builder.constantValue(aBits[i]); ok { // a is constant; ensure l == 0