Skip to content

Commit

Permalink
add conditional progress printer on smt_batch so we don't spam on sta… (
Browse files Browse the repository at this point in the history
#1122)

* add conditional progress printer on smt_batch so we don't spam on stage_execute

* contructor for the insertBatchConfig

* fix tests
  • Loading branch information
V-Staykov authored and Stefan-Ethernal committed Sep 20, 2024
1 parent 47f58f9 commit 8c2c8a5
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 25 deletions.
3 changes: 2 additions & 1 deletion smt/pkg/blockinfo/block_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func BuildBlockInfoTree(
keys = append(keys, key)
vals = append(vals, val)

root, err := infoTree.smt.InsertBatch(context.Background(), "", keys, vals, nil, nil)
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "block_info_tree", false)
root, err := infoTree.smt.InsertBatch(insertBatchCfg, keys, vals, nil, nil)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions smt/pkg/blockinfo/block_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func TestBlockInfoHeader(t *testing.T) {
if err != nil {
t.Fatal(err)
}

root, err := infoTree.smt.InsertBatch(context.Background(), "", keys, vals, nil, nil)
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
root, err := infoTree.smt.InsertBatch(insertBatchCfg, keys, vals, nil, nil)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -213,8 +213,8 @@ func TestSetBlockTx(t *testing.T) {
if err != nil {
t.Fatal(err)
}

root, err2 := infoTree.smt.InsertBatch(context.Background(), "", keys, vals, nil, nil)
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
root, err2 := infoTree.smt.InsertBatch(insertBatchCfg, keys, vals, nil, nil)
if err2 != nil {
t.Fatal(err2)
}
Expand Down
3 changes: 2 additions & 1 deletion smt/pkg/smt/entity_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
}
}

_, err := s.InsertBatch(ctx, logPrefix, keysBatchStorage, valuesBatchStorage, nil, nil)
insertBatchCfg := NewInsertBatchConfig(ctx, logPrefix, true)
_, err := s.InsertBatch(insertBatchCfg, keysBatchStorage, valuesBatchStorage, nil, nil)
return keysBatchStorage, valuesBatchStorage, err
}

Expand Down
84 changes: 74 additions & 10 deletions smt/pkg/smt/smt_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,21 @@ import (
"github.com/ledgerwatch/erigon/zk"
)

func (s *SMT) InsertBatch(ctx context.Context, logPrefix string, nodeKeys []*utils.NodeKey, nodeValues []*utils.NodeValue8, nodeValuesHashes []*[4]uint64, rootNodeHash *utils.NodeKey) (*SMTResponse, error) {
type InsertBatchConfig struct {
ctx context.Context
logPrefix string
shouldPrintProgress bool
}

func NewInsertBatchConfig(ctx context.Context, logPrefix string, shouldPrintProgress bool) InsertBatchConfig {
return InsertBatchConfig{
ctx: ctx,
logPrefix: logPrefix,
shouldPrintProgress: shouldPrintProgress,
}
}

func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, nodeValues []*utils.NodeValue8, nodeValuesHashes []*[4]uint64, rootNodeHash *utils.NodeKey) (*SMTResponse, error) {
s.clearUpMutex.Lock()
defer s.clearUpMutex.Unlock()

Expand All @@ -20,7 +34,18 @@ func (s *SMT) InsertBatch(ctx context.Context, logPrefix string, nodeKeys []*uti
var smtBatchNodeRoot *smtBatchNode
nodeHashesForDelete := make(map[uint64]map[uint64]map[uint64]map[uint64]*utils.NodeKey)

progressChanPre, stopProgressPrinterPre := zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (pre-process)", logPrefix), uint64(4), false)
var progressChanPre chan uint64
var stopProgressPrinterPre func()
if cfg.shouldPrintProgress {
progressChanPre, stopProgressPrinterPre = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (pre-process)", cfg.logPrefix), uint64(4), false)
} else {
progressChanPre = make(chan uint64, 100)
var once sync.Once

stopProgressPrinterPre = func() {
once.Do(func() { close(progressChanPre) })
}
}
defer stopProgressPrinterPre()

if err = validateDataLengths(nodeKeys, nodeValues, &nodeValuesHashes); err != nil {
Expand All @@ -43,17 +68,27 @@ func (s *SMT) InsertBatch(ctx context.Context, logPrefix string, nodeKeys []*uti
}
progressChanPre <- uint64(1)
stopProgressPrinterPre()
var progressChan chan uint64
var stopProgressPrinter func()
if cfg.shouldPrintProgress {
progressChan, stopProgressPrinter = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (process)", cfg.logPrefix), uint64(size), false)
} else {
progressChan = make(chan uint64)
var once sync.Once

progressChan, stopProgressPrinter := zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (process)", logPrefix), uint64(size), false)
stopProgressPrinter = func() {
once.Do(func() { close(progressChan) })
}
}
defer stopProgressPrinter()

for i := 0; i < size; i++ {
select {
case <-ctx.Done():
return nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", logPrefix))
case <-cfg.ctx.Done():
return nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", cfg.logPrefix))
case progressChan <- uint64(1):
default:
}
progressChan <- uint64(1)

insertingNodeKey := nodeKeys[i]
insertingNodeValue := nodeValues[i]
Expand Down Expand Up @@ -146,13 +181,28 @@ func (s *SMT) InsertBatch(ctx context.Context, logPrefix string, nodeKeys []*uti
maxInsertingNodePathLevel = insertingNodePathLevel
}
}
progressChan <- uint64(1)
select {
case progressChan <- uint64(1):
default:
}
stopProgressPrinter()

s.updateDepth(maxInsertingNodePathLevel)

totalDeleteOps := len(nodeHashesForDelete)
progressChanDel, stopProgressPrinterDel := zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (deletes)", logPrefix), uint64(totalDeleteOps), false)

var progressChanDel chan uint64
var stopProgressPrinterDel func()
if cfg.shouldPrintProgress {
progressChanDel, stopProgressPrinterDel = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (deletes)", cfg.logPrefix), uint64(totalDeleteOps), false)
} else {
progressChanDel = make(chan uint64, 100)
var once sync.Once

stopProgressPrinterDel = func() {
once.Do(func() { close(progressChanDel) })
}
}
defer stopProgressPrinterDel()
for _, mapLevel0 := range nodeHashesForDelete {
progressChanDel <- uint64(1)
Expand All @@ -168,10 +218,24 @@ func (s *SMT) InsertBatch(ctx context.Context, logPrefix string, nodeKeys []*uti
stopProgressPrinterDel()

totalFinalizeOps := len(nodeValues)
progressChanFin, stopProgressPrinterFin := zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (finalize)", logPrefix), uint64(totalFinalizeOps), false)
var progressChanFin chan uint64
var stopProgressPrinterFin func()
if cfg.shouldPrintProgress {
progressChanFin, stopProgressPrinterFin = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (finalize)", cfg.logPrefix), uint64(totalFinalizeOps), false)
} else {
progressChanFin = make(chan uint64, 100)
var once sync.Once

stopProgressPrinterFin = func() {
once.Do(func() { close(progressChanFin) })
}
}
defer stopProgressPrinterFin()
for i, nodeValue := range nodeValues {
progressChanFin <- uint64(1)
select {
case progressChanFin <- uint64(1):
default:
}
if !nodeValue.IsZero() {
err = s.hashSave(nodeValue.ToUintArray(), utils.BranchCapacity, *nodeValuesHashes[i])
if err != nil {
Expand Down
23 changes: 14 additions & 9 deletions smt/pkg/smt/smt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ func TestBatchSimpleInsert(t *testing.T) {
smtIncremental.InsertKA(k, valuesRaw[i])
}

_, err := smtBatch.InsertBatch(context.Background(), "", keyPointers, valuePointers, nil, nil)
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
_, err := smtBatch.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
assert.NilError(t, err)

_, err = smtBatchNoSave.InsertBatch(context.Background(), "", keyPointers, valuePointers, nil, nil)
_, err = smtBatchNoSave.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
assert.NilError(t, err)

smtIncremental.DumpTree()
Expand Down Expand Up @@ -111,7 +112,8 @@ func batchInsert(tree *smt.SMT, key, val []*big.Int) {
keyPointers = append(keyPointers, &k)
valuePointers = append(valuePointers, v)
}
tree.InsertBatch(context.Background(), "", keyPointers, valuePointers, nil, nil)
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
tree.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
}

func BenchmarkIncrementalInsert(b *testing.B) {
Expand Down Expand Up @@ -359,10 +361,10 @@ func TestBatchWitness(t *testing.T) {

smtIncremental := smt.NewSMT(nil, false)
smtBatch := smt.NewSMT(nil, false)

insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
for i, k := range keys {
smtIncremental.Insert(k, values[i])
_, err := smtBatch.InsertBatch(context.Background(), "", []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
_, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
assert.NilError(t, err)

smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
Expand Down Expand Up @@ -423,10 +425,10 @@ func TestBatchDelete(t *testing.T) {

smtIncremental := smt.NewSMT(nil, false)
smtBatch := smt.NewSMT(nil, false)

insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
for i, k := range keys {
smtIncremental.Insert(k, values[i])
_, err := smtBatch.InsertBatch(context.Background(), "", []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
_, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
assert.NilError(t, err)

smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
Expand Down Expand Up @@ -477,7 +479,9 @@ func TestBatchRawInsert(t *testing.T) {
t.Logf("Incremental insert %d values in %v\n", len(keysForIncremental), time.Since(startTime))

startTime = time.Now()
_, err := smtBatch.InsertBatch(context.Background(), "", keysForBatch, valuesForBatch, nil, nil)

insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", true)
_, err := smtBatch.InsertBatch(insertBatchCfg, keysForBatch, valuesForBatch, nil, nil)
assert.NilError(t, err)
t.Logf("Batch insert %d values in %v\n", len(keysForBatch), time.Since(startTime))

Expand Down Expand Up @@ -519,7 +523,8 @@ func TestBatchRawInsert(t *testing.T) {
t.Logf("Incremental delete %d values in %v\n", len(keysForIncrementalDelete), time.Since(startTime))

startTime = time.Now()
_, err = smtBatch.InsertBatch(context.Background(), "", keysForBatchDelete, valuesForBatchDelete, nil, nil)

_, err = smtBatch.InsertBatch(insertBatchCfg, keysForBatchDelete, valuesForBatchDelete, nil, nil)
assert.NilError(t, err)
t.Logf("Batch delete %d values in %v\n", len(keysForBatchDelete), time.Since(startTime))

Expand Down

0 comments on commit 8c2c8a5

Please sign in to comment.