Skip to content

Commit

Permalink
contructor for the insertBatchConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
V-Staykov committed Sep 5, 2024
1 parent 99d079b commit 8e66707
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 58 deletions.
6 changes: 1 addition & 5 deletions smt/pkg/blockinfo/block_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,7 @@ func BuildBlockInfoTree(
keys = append(keys, key)
vals = append(vals, val)

insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "block_info_tree",
ShouldPrintProgress: false,
}
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "block_info_tree", false)
root, err := infoTree.smt.InsertBatch(insertBatchCfg, keys, vals, nil, nil)
if err != nil {
return nil, err
Expand Down
12 changes: 2 additions & 10 deletions smt/pkg/blockinfo/block_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,7 @@ func TestBlockInfoHeader(t *testing.T) {
if err != nil {
t.Fatal(err)
}
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: false,
}
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
root, err := infoTree.smt.InsertBatch(insertBatchCfg, keys, vals, nil, nil)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -217,11 +213,7 @@ func TestSetBlockTx(t *testing.T) {
if err != nil {
t.Fatal(err)
}
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: false,
}
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
root, err2 := infoTree.smt.InsertBatch(insertBatchCfg, keys, vals, nil, nil)
if err2 != nil {
t.Fatal(err2)
Expand Down
6 changes: 1 addition & 5 deletions smt/pkg/smt/entity_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
}
}

insertBatchCfg := InsertBatchConfig{
Ctx: ctx,
LogPrefix: logPrefix,
ShouldPrintProgress: true,
}
insertBatchCfg := NewInsertBatchConfig(ctx, logPrefix, true)
_, err := s.InsertBatch(insertBatchCfg, keysBatchStorage, valuesBatchStorage, nil, nil)
return keysBatchStorage, valuesBatchStorage, err
}
Expand Down
34 changes: 21 additions & 13 deletions smt/pkg/smt/smt_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ import (
)

type InsertBatchConfig struct {
Ctx context.Context
LogPrefix string
ShouldPrintProgress bool
ctx context.Context
logPrefix string
shouldPrintProgress bool
}

func NewInsertBatchConfig(ctx context.Context, logPrefix string, shouldPrintProgress bool) InsertBatchConfig {
return InsertBatchConfig{
ctx: ctx,
logPrefix: logPrefix,
shouldPrintProgress: shouldPrintProgress,
}
}

func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, nodeValues []*utils.NodeValue8, nodeValuesHashes []*[4]uint64, rootNodeHash *utils.NodeKey) (*SMTResponse, error) {
Expand All @@ -28,8 +36,8 @@ func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, node

var progressChanPre chan uint64
var stopProgressPrinterPre func()
if cfg.ShouldPrintProgress {
progressChanPre, stopProgressPrinterPre = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (pre-process)", cfg.LogPrefix), uint64(4), false)
if cfg.shouldPrintProgress {
progressChanPre, stopProgressPrinterPre = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (pre-process)", cfg.logPrefix), uint64(4), false)
} else {
progressChanPre = make(chan uint64, 100)
var once sync.Once
Expand Down Expand Up @@ -62,8 +70,8 @@ func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, node
stopProgressPrinterPre()
var progressChan chan uint64
var stopProgressPrinter func()
if cfg.ShouldPrintProgress {
progressChan, stopProgressPrinter = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (process)", cfg.LogPrefix), uint64(size), false)
if cfg.shouldPrintProgress {
progressChan, stopProgressPrinter = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (process)", cfg.logPrefix), uint64(size), false)
} else {
progressChan = make(chan uint64)
var once sync.Once
Expand All @@ -76,8 +84,8 @@ func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, node

for i := 0; i < size; i++ {
select {
case <-cfg.Ctx.Done():
return nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", cfg.LogPrefix))
case <-cfg.ctx.Done():
return nil, fmt.Errorf(fmt.Sprintf("[%s] Context done", cfg.logPrefix))
case progressChan <- uint64(1):
default:
}
Expand Down Expand Up @@ -185,8 +193,8 @@ func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, node

var progressChanDel chan uint64
var stopProgressPrinterDel func()
if cfg.ShouldPrintProgress {
progressChanDel, stopProgressPrinterDel = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (deletes)", cfg.LogPrefix), uint64(totalDeleteOps), false)
if cfg.shouldPrintProgress {
progressChanDel, stopProgressPrinterDel = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (deletes)", cfg.logPrefix), uint64(totalDeleteOps), false)
} else {
progressChanDel = make(chan uint64, 100)
var once sync.Once
Expand All @@ -212,8 +220,8 @@ func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, node
totalFinalizeOps := len(nodeValues)
var progressChanFin chan uint64
var stopProgressPrinterFin func()
if cfg.ShouldPrintProgress {
progressChanFin, stopProgressPrinterFin = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (finalize)", cfg.LogPrefix), uint64(totalFinalizeOps), false)
if cfg.shouldPrintProgress {
progressChanFin, stopProgressPrinterFin = zk.ProgressPrinter(fmt.Sprintf("[%s] SMT incremental progress (finalize)", cfg.logPrefix), uint64(totalFinalizeOps), false)
} else {
progressChanFin = make(chan uint64, 100)
var once sync.Once
Expand Down
32 changes: 7 additions & 25 deletions smt/pkg/smt/smt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,8 @@ func TestBatchSimpleInsert(t *testing.T) {

smtIncremental.InsertKA(k, valuesRaw[i])
}
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: true,
}

insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
_, err := smtBatch.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
assert.NilError(t, err)

Expand Down Expand Up @@ -115,11 +112,7 @@ func batchInsert(tree *smt.SMT, key, val []*big.Int) {
keyPointers = append(keyPointers, &k)
valuePointers = append(valuePointers, v)
}
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: true,
}
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
tree.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
}

Expand Down Expand Up @@ -368,11 +361,7 @@ func TestBatchWitness(t *testing.T) {

smtIncremental := smt.NewSMT(nil, false)
smtBatch := smt.NewSMT(nil, false)
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: true,
}
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
for i, k := range keys {
smtIncremental.Insert(k, values[i])
_, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
Expand Down Expand Up @@ -436,11 +425,7 @@ func TestBatchDelete(t *testing.T) {

smtIncremental := smt.NewSMT(nil, false)
smtBatch := smt.NewSMT(nil, false)
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: true,
}
insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
for i, k := range keys {
smtIncremental.Insert(k, values[i])
_, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
Expand Down Expand Up @@ -494,11 +479,8 @@ func TestBatchRawInsert(t *testing.T) {
t.Logf("Incremental insert %d values in %v\n", len(keysForIncremental), time.Since(startTime))

startTime = time.Now()
insertBatchCfg := smt.InsertBatchConfig{
Ctx: context.Background(),
LogPrefix: "",
ShouldPrintProgress: true,
}

insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
_, err := smtBatch.InsertBatch(insertBatchCfg, keysForBatch, valuesForBatch, nil, nil)
assert.NilError(t, err)
t.Logf("Batch insert %d values in %v\n", len(keysForBatch), time.Since(startTime))
Expand Down

0 comments on commit 8e66707

Please sign in to comment.