Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: lookup blueprint compile time improvement #899

Merged
merged 6 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion constraint/bls12-377/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bls12-381/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bls24-315/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bls24-317/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions constraint/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ type Blueprint interface {
// NbOutputs return the number of output wires this blueprint creates.
NbOutputs(inst Instruction) int

// WireWalker returns a function that walks the wires appearing in the blueprint.
// This is used by the level builder to build a dependency graph between instructions.
WireWalker(inst Instruction) func(cb func(wire uint32))
// UpdateInstructionTree updates the instruction tree;
// since the blue print knows which wires it references, it updates
// the instruction tree with the level of the (new) wires.
UpdateInstructionTree(inst Instruction, tree InstructionTree) Level
}

// Solver represents the state of a constraint system solver at runtime. Blueprint can interact
Expand Down
45 changes: 29 additions & 16 deletions constraint/blueprint_hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package constraint

import (
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/debug"
)

type BlueprintGenericHint struct{}
Expand Down Expand Up @@ -71,24 +72,36 @@ func (b *BlueprintGenericHint) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintGenericHint) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
lenInputs := int(inst.Calldata[2])
j := 3
for i := 0; i < lenInputs; i++ {
n := int(inst.Calldata[j]) // len of linear expr
j++
func (b *BlueprintGenericHint) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
// BlueprintGenericHint knows the input and output to the instruction
maxLevel := LevelUnset

for k := 0; k < n; k++ {
t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}
if !t.IsConstant() {
cb(t.VID)
}
j += 2
// iterate over the inputs and find the max level
lenInputs := int(inst.Calldata[2])
j := 3
for i := 0; i < lenInputs; i++ {
n := int(inst.Calldata[j]) // len of linear expr
j++

for k := 0; k < n; k++ {
wireID := inst.Calldata[j+1]
j += 2
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level > maxLevel {
maxLevel = level
}
if debug.Debug && tree.GetWireLevel(wireID) == LevelUnset {
panic("wire we depend on is not in the instruction tree")
}
}
for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ {
cb(k)
}
}

// iterate over the outputs and insert them at maxLevel + 1
outputLevel := maxLevel + 1
for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ {
tree.InsertWire(k, outputLevel)
}
return outputLevel
}
77 changes: 47 additions & 30 deletions constraint/blueprint_logderivlookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (
// It is essentially a hint to the solver, but enables storing the table entries only once.
type BlueprintLookupHint struct {
EntriesCalldata []uint32

// stores the maxLevel of the entries computed by WireWalker
maxLevel Level
maxLevelPosition int
maxLevelOffset int
}

// ensures BlueprintLookupHint implements the BlueprintSolvable interface
Expand Down Expand Up @@ -65,47 +70,59 @@ func (b *BlueprintLookupHint) NbOutputs(inst Instruction) int {
return int(inst.Calldata[2])
}

// Wires returns a function that walks the wires appearing in the blueprint.
// This is used by the level builder to build a dependency graph between instructions.
func (b *BlueprintLookupHint) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
// depend on the table UP to the number of entries at time of instruction creation.
nbEntries := int(inst.Calldata[1])
func (b *BlueprintLookupHint) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
// depend on the table UP to the number of entries at time of instruction creation.
nbEntries := int(inst.Calldata[1])

// check if we already cached the max level
if b.maxLevelPosition-1 < nbEntries { // adjust for default value of b.maxLevelPosition (0)

// invoke the callback on each wire appearing in the table
j := 0
for i := 0; i < nbEntries; i++ {
j := b.maxLevelOffset // skip the entries we already processed
for i := b.maxLevelPosition; i < nbEntries; i++ {
// first we have the length of the linear expression
n := int(b.EntriesCalldata[j])
j++
for k := 0; k < n; k++ {
t := Term{CID: b.EntriesCalldata[j], VID: b.EntriesCalldata[j+1]}
if !t.IsConstant() {
cb(t.VID)
}
wireID := b.EntriesCalldata[j+1]
j += 2
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); (level + 1) > b.maxLevel {
b.maxLevel = level + 1
}
}
}
b.maxLevelOffset = j
b.maxLevelPosition = nbEntries
}

// invoke the callback on each wire appearing in the inputs
nbInputs := int(inst.Calldata[2])
j = 3
for i := 0; i < nbInputs; i++ {
// first we have the length of the linear expression
n := int(inst.Calldata[j])
j++
for k := 0; k < n; k++ {
t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}
if !t.IsConstant() {
cb(t.VID)
}
j += 2
maxLevel := b.maxLevel - 1 // offset for default value.

// update the max level with the lookup query inputs wires
nbInputs := int(inst.Calldata[2])
j := 3
for i := 0; i < nbInputs; i++ {
// first we have the length of the linear expression
n := int(inst.Calldata[j])
j++
for k := 0; k < n; k++ {
wireID := inst.Calldata[j+1]
j += 2
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level > maxLevel {
maxLevel = level
}
}
}

// finally we have the outputs
for i := 0; i < nbInputs; i++ {
cb(uint32(i + int(inst.WireOffset)))
}
// finally we have the outputs
maxLevel++
for i := 0; i < nbInputs; i++ {
tree.InsertWire(uint32(i+int(inst.WireOffset)), maxLevel)
}

return maxLevel
}
44 changes: 30 additions & 14 deletions constraint/blueprint_r1cs.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,39 @@ func (b *BlueprintGenericR1C) DecompressR1C(c *R1C, inst Instruction) {
copySlice(&c.O, lenO, offset+2*(lenL+lenR))
}

func (b *BlueprintGenericR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
lenL := int(inst.Calldata[1])
lenR := int(inst.Calldata[2])
lenO := int(inst.Calldata[3])
func (b *BlueprintGenericR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
// a R1C doesn't know which wires are input and which are outputs
lenL := int(inst.Calldata[1])
lenR := int(inst.Calldata[2])
lenO := int(inst.Calldata[3])

appendWires := func(expectedLen, idx int) {
for k := 0; k < expectedLen; k++ {
idx++
cb(inst.Calldata[idx])
idx++
outputWires := make([]uint32, 0)
maxLevel := LevelUnset
walkWires := func(n, idx int) {
for k := 0; k < n; k++ {
wireID := inst.Calldata[idx+1]
idx += 2 // advance the offset (coeffID + wireID)
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level == LevelUnset {
outputWires = append(outputWires, wireID)
} else if level > maxLevel {
maxLevel = level
}
}
}

const offset = 4
walkWires(lenL, offset)
walkWires(lenR, offset+2*lenL)
walkWires(lenO, offset+2*(lenL+lenR))

const offset = 4
appendWires(lenL, offset)
appendWires(lenR, offset+2*lenL)
appendWires(lenO, offset+2*(lenL+lenR))
// insert the new wires.
maxLevel++
for _, wireID := range outputWires {
tree.InsertWire(wireID, maxLevel)
}

return maxLevel
}
55 changes: 33 additions & 22 deletions constraint/blueprint_scs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ func (b *BlueprintGenericSparseR1C) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintGenericSparseR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
cb(inst.Calldata[1]) // xb
cb(inst.Calldata[2]) // xc
}
func (b *BlueprintGenericSparseR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:3], tree)
}

func (b *BlueprintGenericSparseR1C) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand Down Expand Up @@ -172,12 +168,8 @@ func (b *BlueprintSparseR1CMul) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintSparseR1CMul) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
cb(inst.Calldata[1]) // xb
cb(inst.Calldata[2]) // xc
}
func (b *BlueprintSparseR1CMul) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:3], tree)
}

func (b *BlueprintSparseR1CMul) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand Down Expand Up @@ -220,12 +212,8 @@ func (b *BlueprintSparseR1CAdd) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintSparseR1CAdd) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
cb(inst.Calldata[1]) // xb
cb(inst.Calldata[2]) // xc
}
func (b *BlueprintSparseR1CAdd) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:3], tree)
}

func (b *BlueprintSparseR1CAdd) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand Down Expand Up @@ -273,10 +261,8 @@ func (b *BlueprintSparseR1CBool) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintSparseR1CBool) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
}
func (b *BlueprintSparseR1CBool) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:1], tree)
}

func (b *BlueprintSparseR1CBool) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand All @@ -303,3 +289,28 @@ func (b *BlueprintSparseR1CBool) DecompressSparseR1C(c *SparseR1C, inst Instruct
c.QL = inst.Calldata[1]
c.QM = inst.Calldata[2]
}

func updateInstructionTree(wires []uint32, tree InstructionTree) Level {
// constraint has at most one unsolved wire.
var outputWire uint32
found := false
maxLevel := LevelUnset
for _, wireID := range wires {
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level == LevelUnset {
outputWire = wireID
found = true
} else if level > maxLevel {
maxLevel = level
}
}

maxLevel++
if found {
tree.InsertWire(outputWire, maxLevel)
}

return maxLevel
}
1 change: 0 additions & 1 deletion constraint/bn254/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bw6-633/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bw6-761/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading