From 93fd49d29ded8fd2f04f2a634957846b74d8d212 Mon Sep 17 00:00:00 2001 From: Marko Date: Thu, 5 Jan 2023 14:42:49 +0100 Subject: [PATCH] chore: backport pr 13881 (#14349) Co-authored-by: yihuang Co-authored-by: Julien Robert Co-authored-by: Matt Kocubinski --- CHANGELOG.md | 5 +- go.mod | 1 + go.sum | 2 + store/cachekv/benchmark_test.go | 161 ++++++++++++++ store/cachekv/internal/btree.go | 80 +++++++ store/cachekv/internal/btree_test.go | 202 ++++++++++++++++++ store/cachekv/internal/memiterator.go | 137 ++++++++++++ store/cachekv/{ => internal}/mergeiterator.go | 53 +++-- store/cachekv/memiterator.go | 57 ----- store/cachekv/search_benchmark_test.go | 4 +- store/cachekv/store.go | 47 ++-- store/cachekv/store_test.go | 146 +++++++++++++ 12 files changed, 782 insertions(+), 113 deletions(-) create mode 100644 store/cachekv/benchmark_test.go create mode 100644 store/cachekv/internal/btree.go create mode 100644 store/cachekv/internal/btree_test.go create mode 100644 store/cachekv/internal/memiterator.go rename store/cachekv/{ => internal}/mergeiterator.go (86%) delete mode 100644 store/cachekv/memiterator.go diff --git a/CHANGELOG.md b/CHANGELOG.md index dedbccb2a812..21526ee3e3cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ ### Improvements +* [#13881](https://github.com/cosmos/cosmos-sdk/pull/13881) Optimize iteration on nested cached KV stores and other operations in general. * (x/gov) [#14347](https://github.com/cosmos/cosmos-sdk/pull/14347) Support `v1.Proposal` message in `v1beta1.Proposal.Content`. ### Bug Fixes @@ -59,13 +60,12 @@ Ref: https://keepachangelog.com/en/1.0.0/ * (deps) Bump Tendermint version to [v0.34.24](https://github.com/tendermint/tendermint/releases/tag/v0.34.24). * [#13651](https://github.com/cosmos/cosmos-sdk/pull/13651) Update `server/config/config.GetConfig` function. -* [#13781](https://github.com/cosmos/cosmos-sdk/pull/13781) Remove `client/keys.KeysCdc`. * [#14175](https://github.com/cosmos/cosmos-sdk/pull/14175) Add `server.DefaultBaseappOptions(appopts)` function to reduce boiler plate in root.go. ### State Machine Breaking * (x/gov) [#14214](https://github.com/cosmos/cosmos-sdk/pull/14214) Fix gov v0.46 migration to v1 votes. - * Also provide a helper function `govv046.Migrate_V0466_To_V0467` for migrating a chain already on v0.46 with versions <=v0.46.6 to the latest v0.46.7 correct state. + * Also provide a helper function `govv046.Migrate_V0466_To_V0467` for migrating a chain already on v0.46 with versions <=v0.46.6 to the latest v0.46.7 correct state. * (x/group) [#14071](https://github.com/cosmos/cosmos-sdk/pull/14071) Don't re-tally proposal after voting period end if they have been marked as ACCEPTED or REJECTED. ### API Breaking Changes @@ -239,6 +239,7 @@ replace github.com/confio/ics23/go => github.com/cosmos/cosmos-sdk/ics23/go v0.8 * (x/group) [#12888](https://github.com/cosmos/cosmos-sdk/pull/12888) Fix event propagation to the current context of `x/group` message execution `[]sdk.Result`. * (x/upgrade) [#12906](https://github.com/cosmos/cosmos-sdk/pull/12906) Fix upgrade failure by moving downgrade verification logic after store migration. +* (store) [#12945](https://github.com/cosmos/cosmos-sdk/pull/12945) Fix nil end semantics in store/cachekv/iterator when iterating a dirty cache. 
## [v0.46.0](https://github.com/cosmos/cosmos-sdk/releases/tag/v0.46.0) - 2022-07-26 diff --git a/go.mod b/go.mod index af7a597cccc2..dc789b4e79c2 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,7 @@ require ( github.com/tendermint/go-amino v0.16.0 github.com/tendermint/tendermint v0.34.24 github.com/tendermint/tm-db v0.6.7 + github.com/tidwall/btree v1.5.0 golang.org/x/crypto v0.2.0 golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e google.golang.org/genproto v0.0.0-20221014213838-99cd37c6964a diff --git a/go.sum b/go.sum index 1a250e71fa80..354198c92ace 100644 --- a/go.sum +++ b/go.sum @@ -952,6 +952,8 @@ github.com/tendermint/tendermint v0.34.24 h1:879MKKJWYYPJEMMKME+DWUTY4V9f/FBpnZD github.com/tendermint/tendermint v0.34.24/go.mod h1:rXVrl4OYzmIa1I91av3iLv2HS0fGSiucyW9J4aMTpKI= github.com/tendermint/tm-db v0.6.7 h1:fE00Cbl0jayAoqlExN6oyQJ7fR/ZtoVOmvPJ//+shu8= github.com/tendermint/tm-db v0.6.7/go.mod h1:byQDzFkZV1syXr/ReXS808NxA2xvyuuVgXOJ/088L6I= +github.com/tidwall/btree v1.5.0 h1:iV0yVY/frd7r6qGBXfEYs7DH0gTDgrKTrDjS7xt/IyQ= +github.com/tidwall/btree v1.5.0/go.mod h1:LGm8L/DZjPLmeWGjv5kFrY8dL4uVhMmzmmLYmsObdKE= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= diff --git a/store/cachekv/benchmark_test.go b/store/cachekv/benchmark_test.go new file mode 100644 index 000000000000..2db62ba5d6c6 --- /dev/null +++ b/store/cachekv/benchmark_test.go @@ -0,0 +1,161 @@ +package cachekv_test + +import ( + fmt "fmt" + "testing" + + "github.com/cosmos/cosmos-sdk/store" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/tendermint/tendermint/libs/log" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" + dbm "github.com/tendermint/tm-db" +) + +func DoBenchmarkDeepContextStack(b *testing.B, depth int) { + begin := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} + end := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + key := storetypes.NewKVStoreKey("test") + + db := dbm.NewMemDB() + cms := store.NewCommitMultiStore(db) + cms.MountStoreWithDB(key, storetypes.StoreTypeIAVL, db) + cms.LoadLatestVersion() + ctx := sdk.NewContext(cms, tmproto.Header{}, false, log.NewNopLogger()) + + var stack ContextStack + stack.Reset(ctx) + + for i := 0; i < depth; i++ { + stack.Snapshot() + + store := stack.CurrentContext().KVStore(key) + store.Set(begin, []byte("value")) + } + + store := stack.CurrentContext().KVStore(key) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + it := store.Iterator(begin, end) + it.Valid() + it.Key() + it.Value() + it.Next() + it.Close() + } +} + +func BenchmarkDeepContextStack1(b *testing.B) { + DoBenchmarkDeepContextStack(b, 1) +} + +func BenchmarkDeepContextStack3(b *testing.B) { + DoBenchmarkDeepContextStack(b, 3) +} +func BenchmarkDeepContextStack10(b *testing.B) { + DoBenchmarkDeepContextStack(b, 10) +} + +func BenchmarkDeepContextStack13(b *testing.B) { + DoBenchmarkDeepContextStack(b, 13) +} + +// cachedContext is a pair of cache context and its corresponding commit method. +// They are obtained from the return value of `context.CacheContext()`. +type cachedContext struct { + ctx sdk.Context + commit func() +} + +// ContextStack manages the initial context and a stack of cached contexts, +// to support the `StateDB.Snapshot` and `StateDB.RevertToSnapshot` methods. 
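+//
+// A minimal usage sketch (illustrative only; `ctx`, `key`, `k`, and `v` are
+// assumed to exist in the caller's scope):
+//
+//	var cs ContextStack
+//	cs.Reset(ctx)
+//	rev := cs.Snapshot()                       // push a cached context
+//	cs.CurrentContext().KVStore(key).Set(k, v) // writes go to the top layer
+//	cs.RevertToSnapshot(rev)                   // drop the write, or
+//	cs.Commit()                                // flush every layer into initialCtx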
+//
+// Copied from an old version of ethermint.
+type ContextStack struct {
+	// Context of the initial state before transaction execution.
+	// It's the context used by `StateDB.CommittedState`.
+	initialCtx     sdk.Context
+	cachedContexts []cachedContext
+}
+
+// CurrentContext returns the top context of the cached stack;
+// if the stack is empty, it returns the initial context.
+func (cs *ContextStack) CurrentContext() sdk.Context {
+	l := len(cs.cachedContexts)
+	if l == 0 {
+		return cs.initialCtx
+	}
+	return cs.cachedContexts[l-1].ctx
+}
+
+// Reset sets the initial context and clears the cache context stack.
+func (cs *ContextStack) Reset(ctx sdk.Context) {
+	cs.initialCtx = ctx
+	if len(cs.cachedContexts) > 0 {
+		cs.cachedContexts = []cachedContext{}
+	}
+}
+
+// IsEmpty returns true if the cache context stack is empty.
+func (cs *ContextStack) IsEmpty() bool {
+	return len(cs.cachedContexts) == 0
+}
+
+// Commit commits all the cached contexts from top to bottom in order and clears the stack by setting an empty slice of cache contexts.
+func (cs *ContextStack) Commit() {
+	// commit in order from top to bottom
+	for i := len(cs.cachedContexts) - 1; i >= 0; i-- {
+		if cs.cachedContexts[i].commit == nil {
+			panic(fmt.Sprintf("commit function at index %d should not be nil", i))
+		} else {
+			cs.cachedContexts[i].commit()
+		}
+	}
+	cs.cachedContexts = []cachedContext{}
+}
+
+// CommitToRevision commits the cached contexts above the target revision,
+// to improve the efficiency of db operations.
+func (cs *ContextStack) CommitToRevision(target int) error {
+	if target < 0 || target >= len(cs.cachedContexts) {
+		return fmt.Errorf("snapshot index %d out of bound [%d..%d)", target, 0, len(cs.cachedContexts))
+	}
+
+	// commit in order from top to bottom
+	for i := len(cs.cachedContexts) - 1; i > target; i-- {
+		if cs.cachedContexts[i].commit == nil {
+			return fmt.Errorf("commit function at index %d should not be nil", i)
+		}
+		cs.cachedContexts[i].commit()
+	}
+	cs.cachedContexts = cs.cachedContexts[0 : target+1]
+
+	return nil
+}
+
+// Snapshot pushes a new cached context to the stack,
+// and returns the index of it.
+func (cs *ContextStack) Snapshot() int {
+	i := len(cs.cachedContexts)
+	ctx, commit := cs.CurrentContext().CacheContext()
+	cs.cachedContexts = append(cs.cachedContexts, cachedContext{ctx: ctx, commit: commit})
+	return i
+}
+
+// RevertToSnapshot pops all the cached contexts from the target index (inclusive) upwards.
+// The target should be a snapshot index returned by `Snapshot`.
+// This function panics if the index is out of bounds.
+func (cs *ContextStack) RevertToSnapshot(target int) {
+	if target < 0 || target >= len(cs.cachedContexts) {
+		panic(fmt.Errorf("snapshot index %d out of bound [%d..%d)", target, 0, len(cs.cachedContexts)))
+	}
+	cs.cachedContexts = cs.cachedContexts[:target]
+}
+
+// RevertAll discards all the cache contexts.
+func (cs *ContextStack) RevertAll() {
+	if len(cs.cachedContexts) > 0 {
+		cs.RevertToSnapshot(0)
+	}
+}
diff --git a/store/cachekv/internal/btree.go b/store/cachekv/internal/btree.go
new file mode 100644
index 000000000000..142f754bbd38
--- /dev/null
+++ b/store/cachekv/internal/btree.go
@@ -0,0 +1,80 @@
+package internal
+
+import (
+	"bytes"
+	"errors"
+
+	"github.com/tidwall/btree"
+)
+
+const (
+	// The approximate number of items and children per B-tree node. Tuned with benchmarks.
+	// Copied from memdb.
+	bTreeDegree = 32
+)
+
+var errKeyEmpty = errors.New("key cannot be empty")
+
+// BTree implements the sorted cache for the cachekv store.
+// We don't use MemDB here because cachekv is used extensively in the SDK core path,
+// where it needs to be as fast as possible; `MemDB` is mainly used as a mock DB in unit tests.
+//
+// We choose tidwall/btree over google/btree here because it provides an API to implement a step iterator directly.
+type BTree struct {
+	tree btree.BTreeG[item]
+}
+
+// NewBTree creates a wrapper around `btree.BTreeG`.
+func NewBTree() *BTree {
+	return &BTree{tree: *btree.NewBTreeGOptions(byKeys, btree.Options{
+		Degree: bTreeDegree,
+		// Contract: cachekv store must not be called concurrently
+		NoLocks: true,
+	})}
+}
+
+func (bt *BTree) Set(key, value []byte) {
+	bt.tree.Set(newItem(key, value))
+}
+
+func (bt *BTree) Get(key []byte) []byte {
+	i, found := bt.tree.Get(newItem(key, nil))
+	if !found {
+		return nil
+	}
+	return i.value
+}
+
+func (bt *BTree) Delete(key []byte) {
+	bt.tree.Delete(newItem(key, nil))
+}
+
+func (bt *BTree) Iterator(start, end []byte) (*memIterator, error) {
+	if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
+		return nil, errKeyEmpty
+	}
+	return NewMemIterator(start, end, bt, make(map[string]struct{}), true), nil
+}
+
+func (bt *BTree) ReverseIterator(start, end []byte) (*memIterator, error) {
+	if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) {
+		return nil, errKeyEmpty
+	}
+	return NewMemIterator(start, end, bt, make(map[string]struct{}), false), nil
+}
+
+// item is a btree item with byte slices as keys and values
+type item struct {
+	key   []byte
+	value []byte
+}
+
+// byKeys compares the items by key
+func byKeys(a, b item) bool {
+	return bytes.Compare(a.key, b.key) == -1
+}
+
+// newItem creates a new pair item.
+func newItem(key, value []byte) item {
+	return item{key: key, value: value}
+}
diff --git a/store/cachekv/internal/btree_test.go b/store/cachekv/internal/btree_test.go
new file mode 100644
index 000000000000..f85a8bbaf109
--- /dev/null
+++ b/store/cachekv/internal/btree_test.go
@@ -0,0 +1,202 @@
+package internal
+
+import (
+	"testing"
+
+	sdk "github.com/cosmos/cosmos-sdk/types"
+	"github.com/stretchr/testify/require"
+)
+
+func TestGetSetDelete(t *testing.T) {
+	db := NewBTree()
+
+	// A nonexistent key should return nil.
+	value := db.Get([]byte("a"))
+	require.Nil(t, value)
+
+	// Set and get a value.
+	db.Set([]byte("a"), []byte{0x01})
+	db.Set([]byte("b"), []byte{0x02})
+	value = db.Get([]byte("a"))
+	require.Equal(t, []byte{0x01}, value)
+
+	value = db.Get([]byte("b"))
+	require.Equal(t, []byte{0x02}, value)
+
+	// Deleting a non-existent value is fine.
+	db.Delete([]byte("x"))
+
+	// Delete a value.
+	db.Delete([]byte("a"))
+
+	value = db.Get([]byte("a"))
+	require.Nil(t, value)
+
+	db.Delete([]byte("b"))
+
+	value = db.Get([]byte("b"))
+	require.Nil(t, value)
+}
+
+func TestDBIterator(t *testing.T) {
+	db := NewBTree()
+
+	for i := 0; i < 10; i++ {
+		if i != 6 { // but skip 6.
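+			// (6 is left out on purpose so the range assertions below also
+			// exercise a gap in the key space.)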
+ db.Set(int642Bytes(int64(i)), []byte{}) + } + } + + // Blank iterator keys should error + _, err := db.ReverseIterator([]byte{}, nil) + require.Equal(t, errKeyEmpty, err) + _, err = db.ReverseIterator(nil, []byte{}) + require.Equal(t, errKeyEmpty, err) + + itr, err := db.Iterator(nil, nil) + require.NoError(t, err) + verifyIterator(t, itr, []int64{0, 1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator") + + ritr, err := db.ReverseIterator(nil, nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator") + + itr, err = db.Iterator(nil, int642Bytes(0)) + require.NoError(t, err) + verifyIterator(t, itr, []int64(nil), "forward iterator to 0") + + ritr, err = db.ReverseIterator(int642Bytes(10), nil) + require.NoError(t, err) + verifyIterator(t, ritr, []int64(nil), "reverse iterator from 10 (ex)") + + itr, err = db.Iterator(int642Bytes(0), nil) + require.NoError(t, err) + verifyIterator(t, itr, []int64{0, 1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator from 0") + + itr, err = db.Iterator(int642Bytes(1), nil) + require.NoError(t, err) + verifyIterator(t, itr, []int64{1, 2, 3, 4, 5, 7, 8, 9}, "forward iterator from 1") + + ritr, err = db.ReverseIterator(nil, int642Bytes(10)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{9, 8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 10 (ex)") + + ritr, err = db.ReverseIterator(nil, int642Bytes(9)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{8, 7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 9 (ex)") + + ritr, err = db.ReverseIterator(nil, int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{7, 5, 4, 3, 2, 1, 0}, "reverse iterator from 8 (ex)") + + itr, err = db.Iterator(int642Bytes(5), int642Bytes(6)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{5}, "forward iterator from 5 to 6") + + itr, err = db.Iterator(int642Bytes(5), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{5}, "forward iterator from 5 to 7") + + itr, err = db.Iterator(int642Bytes(5), int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{5, 7}, "forward iterator from 5 to 8") + + itr, err = db.Iterator(int642Bytes(6), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, itr, []int64(nil), "forward iterator from 6 to 7") + + itr, err = db.Iterator(int642Bytes(6), int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{7}, "forward iterator from 6 to 8") + + itr, err = db.Iterator(int642Bytes(7), int642Bytes(8)) + require.NoError(t, err) + verifyIterator(t, itr, []int64{7}, "forward iterator from 7 to 8") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(5)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{4}, "reverse iterator from 5 (ex) to 4") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(6)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{5, 4}, "reverse iterator from 6 (ex) to 4") + + ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, ritr, + []int64{5, 4}, "reverse iterator from 7 (ex) to 4") + + ritr, err = db.ReverseIterator(int642Bytes(5), int642Bytes(6)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{5}, "reverse iterator from 6 (ex) to 5") + + ritr, err = db.ReverseIterator(int642Bytes(5), int642Bytes(7)) + require.NoError(t, err) + verifyIterator(t, ritr, []int64{5}, "reverse iterator from 7 (ex) to 5") + + ritr, err = db.ReverseIterator(int642Bytes(6), int642Bytes(7)) + 
require.NoError(t, err)
+	verifyIterator(t, ritr,
+		[]int64(nil), "reverse iterator from 7 (ex) to 6")
+
+	ritr, err = db.ReverseIterator(int642Bytes(10), nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64(nil), "reverse iterator to 10")
+
+	ritr, err = db.ReverseIterator(int642Bytes(6), nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64{9, 8, 7}, "reverse iterator to 6")
+
+	ritr, err = db.ReverseIterator(int642Bytes(5), nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64{9, 8, 7, 5}, "reverse iterator to 5")
+
+	ritr, err = db.ReverseIterator(int642Bytes(8), int642Bytes(9))
+	require.NoError(t, err)
+	verifyIterator(t, ritr, []int64{8}, "reverse iterator from 9 (ex) to 8")
+
+	ritr, err = db.ReverseIterator(int642Bytes(2), int642Bytes(4))
+	require.NoError(t, err)
+	verifyIterator(t, ritr,
+		[]int64{3, 2}, "reverse iterator from 4 (ex) to 2")
+
+	ritr, err = db.ReverseIterator(int642Bytes(4), int642Bytes(2))
+	require.NoError(t, err)
+	verifyIterator(t, ritr,
+		[]int64(nil), "reverse iterator from 2 (ex) to 4")
+
+	// Ensure that the iterators don't panic with an empty database.
+	db2 := NewBTree()
+
+	itr, err = db2.Iterator(nil, nil)
+	require.NoError(t, err)
+	verifyIterator(t, itr, nil, "forward iterator with empty db")
+
+	ritr, err = db2.ReverseIterator(nil, nil)
+	require.NoError(t, err)
+	verifyIterator(t, ritr, nil, "reverse iterator with empty db")
+}
+
+func verifyIterator(t *testing.T, itr *memIterator, expected []int64, msg string) {
+	i := 0
+	for itr.Valid() {
+		key := itr.Key()
+		require.Equal(t, expected[i], bytes2Int64(key), "iterator: %d mismatches", i)
+		itr.Next()
+		i++
+	}
+	require.Equal(t, i, len(expected), "expected to have fully iterated over all the elements in iter")
+	require.NoError(t, itr.Close())
+}
+
+func int642Bytes(i int64) []byte {
+	return sdk.Uint64ToBigEndian(uint64(i))
+}
+
+func bytes2Int64(buf []byte) int64 {
+	return int64(sdk.BigEndianToUint64(buf))
+}
diff --git a/store/cachekv/internal/memiterator.go b/store/cachekv/internal/memiterator.go
new file mode 100644
index 000000000000..2bceb8bc77df
--- /dev/null
+++ b/store/cachekv/internal/memiterator.go
@@ -0,0 +1,137 @@
+package internal
+
+import (
+	"bytes"
+	"errors"
+
+	"github.com/cosmos/cosmos-sdk/store/types"
+	"github.com/tidwall/btree"
+)
+
+var _ types.Iterator = (*memIterator)(nil)
+
+// memIterator iterates over iterKVCache items.
+// If the value of an item is nil, it means the key was deleted.
+// Implements Iterator.
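+//
+// Note that the deleted map is shared with the cachekv store that creates the
+// iterator, so a key deleted while the iterator is live is visible to Value();
+// lastKey keeps the value returned for the current position stable (see Value).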
+type memIterator struct { + iter btree.GenericIter[item] + + start []byte + end []byte + ascending bool + lastKey []byte + deleted map[string]struct{} + valid bool +} + +func NewMemIterator(start, end []byte, items *BTree, deleted map[string]struct{}, ascending bool) *memIterator { + iter := items.tree.Iter() + var valid bool + if ascending { + if start != nil { + valid = iter.Seek(newItem(start, nil)) + } else { + valid = iter.First() + } + } else { + if end != nil { + valid = iter.Seek(newItem(end, nil)) + if !valid { + valid = iter.Last() + } else { + // end is exclusive + valid = iter.Prev() + } + } else { + valid = iter.Last() + } + } + + mi := &memIterator{ + iter: iter, + start: start, + end: end, + ascending: ascending, + lastKey: nil, + deleted: deleted, + valid: valid, + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } + + return mi +} + +func (mi *memIterator) Domain() (start []byte, end []byte) { + return mi.start, mi.end +} + +func (mi *memIterator) Close() error { + mi.iter.Release() + return nil +} + +func (mi *memIterator) Error() error { + if !mi.Valid() { + return errors.New("invalid memIterator") + } + return nil +} + +func (mi *memIterator) Valid() bool { + return mi.valid +} + +func (mi *memIterator) Next() { + mi.assertValid() + + if mi.ascending { + mi.valid = mi.iter.Next() + } else { + mi.valid = mi.iter.Prev() + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } +} + +func (mi *memIterator) keyInRange(key []byte) bool { + if mi.ascending && mi.end != nil && bytes.Compare(key, mi.end) >= 0 { + return false + } + if !mi.ascending && mi.start != nil && bytes.Compare(key, mi.start) < 0 { + return false + } + return true +} + +func (mi *memIterator) Key() []byte { + return mi.iter.Item().key +} + +func (mi *memIterator) Value() []byte { + item := mi.iter.Item() + key := item.key + // We need to handle the case where deleted is modified and includes our current key + // We handle this by maintaining a lastKey object in the iterator. + // If the current key is the same as the last key (and last key is not nil / the start) + // then we are calling value on the same thing as last time. + // Therefore we don't check the mi.deleted to see if this key is included in there. 
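+	// Illustrative sequence (assumed caller behavior, not part of this patch):
+	// Value() is first called on key "k" (lastKey becomes "k"), the caller then
+	// deletes "k" and calls Value() again at the same position; "k" is now in
+	// mi.deleted, but since key equals lastKey, the value cached in the tree is
+	// returned instead of nil.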
+ if _, ok := mi.deleted[string(key)]; ok { + if mi.lastKey == nil || !bytes.Equal(key, mi.lastKey) { + // not re-calling on old last key + return nil + } + } + mi.lastKey = key + return item.value +} + +func (mi *memIterator) assertValid() { + if err := mi.Error(); err != nil { + panic(err) + } +} diff --git a/store/cachekv/mergeiterator.go b/store/cachekv/internal/mergeiterator.go similarity index 86% rename from store/cachekv/mergeiterator.go rename to store/cachekv/internal/mergeiterator.go index a6c7a035aba0..4186a178a863 100644 --- a/store/cachekv/mergeiterator.go +++ b/store/cachekv/internal/mergeiterator.go @@ -1,4 +1,4 @@ -package cachekv +package internal import ( "bytes" @@ -18,17 +18,20 @@ type cacheMergeIterator struct { parent types.Iterator cache types.Iterator ascending bool + + valid bool } var _ types.Iterator = (*cacheMergeIterator)(nil) -func newCacheMergeIterator(parent, cache types.Iterator, ascending bool) *cacheMergeIterator { +func NewCacheMergeIterator(parent, cache types.Iterator, ascending bool) *cacheMergeIterator { iter := &cacheMergeIterator{ parent: parent, cache: cache, ascending: ascending, } + iter.valid = iter.skipUntilExistsOrInvalid() return iter } @@ -40,42 +43,38 @@ func (iter *cacheMergeIterator) Domain() (start, end []byte) { // Valid implements Iterator. func (iter *cacheMergeIterator) Valid() bool { - return iter.skipUntilExistsOrInvalid() + return iter.valid } // Next implements Iterator func (iter *cacheMergeIterator) Next() { - iter.skipUntilExistsOrInvalid() iter.assertValid() - // If parent is invalid, get the next cache item. - if !iter.parent.Valid() { + switch { + case !iter.parent.Valid(): + // If parent is invalid, get the next cache item. iter.cache.Next() - return - } - - // If cache is invalid, get the next parent item. - if !iter.cache.Valid() { + case !iter.cache.Valid(): + // If cache is invalid, get the next parent item. iter.parent.Next() - return - } - - // Both are valid. Compare keys. - keyP, keyC := iter.parent.Key(), iter.cache.Key() - switch iter.compare(keyP, keyC) { - case -1: // parent < cache - iter.parent.Next() - case 0: // parent == cache - iter.parent.Next() - iter.cache.Next() - case 1: // parent > cache - iter.cache.Next() + default: + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + switch iter.compare(keyP, keyC) { + case -1: // parent < cache + iter.parent.Next() + case 0: // parent == cache + iter.parent.Next() + iter.cache.Next() + case 1: // parent > cache + iter.cache.Next() + } } + iter.valid = iter.skipUntilExistsOrInvalid() } // Key implements Iterator func (iter *cacheMergeIterator) Key() []byte { - iter.skipUntilExistsOrInvalid() iter.assertValid() // If parent is invalid, get the cache key. @@ -106,7 +105,6 @@ func (iter *cacheMergeIterator) Key() []byte { // Value implements Iterator func (iter *cacheMergeIterator) Value() []byte { - iter.skipUntilExistsOrInvalid() iter.assertValid() // If parent is invalid, get the cache value. 
@@ -137,11 +135,12 @@ func (iter *cacheMergeIterator) Value() []byte { // Close implements Iterator func (iter *cacheMergeIterator) Close() error { + err1 := iter.cache.Close() if err := iter.parent.Close(); err != nil { return err } - return iter.cache.Close() + return err1 } // Error returns an error if the cacheMergeIterator is invalid defined by the diff --git a/store/cachekv/memiterator.go b/store/cachekv/memiterator.go deleted file mode 100644 index a12ff9acfd11..000000000000 --- a/store/cachekv/memiterator.go +++ /dev/null @@ -1,57 +0,0 @@ -package cachekv - -import ( - "bytes" - - dbm "github.com/tendermint/tm-db" - - "github.com/cosmos/cosmos-sdk/store/types" -) - -// memIterator iterates over iterKVCache items. -// if key is nil, means it was deleted. -// Implements Iterator. -type memIterator struct { - types.Iterator - - lastKey []byte - deleted map[string]struct{} -} - -func newMemIterator(start, end []byte, items *dbm.MemDB, deleted map[string]struct{}, ascending bool) *memIterator { - var ( - iter types.Iterator - err error - ) - - if ascending { - iter, err = items.Iterator(start, end) - } else { - iter, err = items.ReverseIterator(start, end) - } - - if err != nil { - panic(err) - } - - return &memIterator{ - Iterator: iter, - lastKey: nil, - deleted: deleted, - } -} - -func (mi *memIterator) Value() []byte { - key := mi.Iterator.Key() - // We need to handle the case where deleted is modified and includes our current key - // We handle this by maintaining a lastKey object in the iterator. - // If the current key is the same as the last key (and last key is not nil / the start) - // then we are calling value on the same thing as last time. - // Therefore we don't check the mi.deleted to see if this key is included in there. - reCallingOnOldLastKey := (mi.lastKey != nil) && bytes.Equal(key, mi.lastKey) - if _, ok := mi.deleted[string(key)]; ok && !reCallingOnOldLastKey { - return nil - } - mi.lastKey = key - return mi.Iterator.Value() -} diff --git a/store/cachekv/search_benchmark_test.go b/store/cachekv/search_benchmark_test.go index 921bff4e3864..4007c7cda202 100644 --- a/store/cachekv/search_benchmark_test.go +++ b/store/cachekv/search_benchmark_test.go @@ -4,7 +4,7 @@ import ( "strconv" "testing" - db "github.com/tendermint/tm-db" + "github.com/cosmos/cosmos-sdk/store/cachekv/internal" ) func BenchmarkLargeUnsortedMisses(b *testing.B) { @@ -39,6 +39,6 @@ func generateStore() *Store { return &Store{ cache: cache, unsortedCache: unsorted, - sortedCache: db.NewMemDB(), + sortedCache: internal.NewBTree(), } } diff --git a/store/cachekv/store.go b/store/cachekv/store.go index 0ebc52268548..42354fa78be3 100644 --- a/store/cachekv/store.go +++ b/store/cachekv/store.go @@ -9,10 +9,10 @@ import ( dbm "github.com/tendermint/tm-db" "github.com/cosmos/cosmos-sdk/internal/conv" + "github.com/cosmos/cosmos-sdk/store/cachekv/internal" "github.com/cosmos/cosmos-sdk/store/tracekv" "github.com/cosmos/cosmos-sdk/store/types" "github.com/cosmos/cosmos-sdk/types/kv" - "github.com/tendermint/tendermint/libs/math" ) // cValue represents a cached value. 
@@ -30,7 +30,7 @@ type Store struct { cache map[string]*cValue deleted map[string]struct{} unsortedCache map[string]struct{} - sortedCache *dbm.MemDB // always ascending sorted + sortedCache *internal.BTree // always ascending sorted parent types.KVStore } @@ -42,7 +42,7 @@ func NewStore(parent types.KVStore) *Store { cache: make(map[string]*cValue), deleted: make(map[string]struct{}), unsortedCache: make(map[string]struct{}), - sortedCache: dbm.NewMemDB(), + sortedCache: internal.NewBTree(), parent: parent, } } @@ -101,6 +101,11 @@ func (store *Store) Write() { store.mtx.Lock() defer store.mtx.Unlock() + if len(store.cache) == 0 && len(store.deleted) == 0 && len(store.unsortedCache) == 0 { + store.sortedCache = internal.NewBTree() + return + } + // We need a copy of all of the keys. // Not the best, but probably not a bottleneck depending. keys := make([]string, 0, len(store.cache)) @@ -144,7 +149,7 @@ func (store *Store) Write() { for key := range store.unsortedCache { delete(store.unsortedCache, key) } - store.sortedCache = dbm.NewMemDB() + store.sortedCache = internal.NewBTree() } // CacheWrap implements CacheWrapper. @@ -183,9 +188,9 @@ func (store *Store) iterator(start, end []byte, ascending bool) types.Iterator { } store.dirtyItems(start, end) - cache = newMemIterator(start, end, store.sortedCache, store.deleted, ascending) + cache = internal.NewMemIterator(start, end, store.sortedCache, store.deleted, ascending) - return newCacheMergeIterator(parent, cache, ascending) + return internal.NewCacheMergeIterator(parent, cache, ascending) } func findStartIndex(strL []string, startQ string) int { @@ -273,7 +278,7 @@ const minSortSize = 1024 // Constructs a slice of dirty items, to use w/ memIterator. func (store *Store) dirtyItems(start, end []byte) { startStr, endStr := conv.UnsafeBytesToStr(start), conv.UnsafeBytesToStr(end) - if startStr > endStr { + if end != nil && startStr > endStr { // Nothing to do here. return } @@ -288,6 +293,7 @@ func (store *Store) dirtyItems(start, end []byte) { // than just not having the cache. 
if n < minSortSize { for key := range store.unsortedCache { + // dbm.IsKeyInDomain is nil safe and returns true iff key is greater than start if dbm.IsKeyInDomain(conv.UnsafeStrToBytes(key), start, end) { cacheValue := store.cache[key] unsorted = append(unsorted, &kv.Pair{Key: []byte(key), Value: cacheValue.value}) @@ -308,24 +314,18 @@ func (store *Store) dirtyItems(start, end []byte) { // Now find the values within the domain // [start, end) startIndex := findStartIndex(strL, startStr) - endIndex := findEndIndex(strL, endStr) - - if endIndex < 0 { - endIndex = len(strL) - 1 - } if startIndex < 0 { startIndex = 0 } - // Since we spent cycles to sort the values, we should process and remove a reasonable amount - // ensure start to end is at least minSortSize in size - // if below minSortSize, expand it to cover additional values - // this amortizes the cost of processing elements across multiple calls - if endIndex-startIndex < minSortSize { - endIndex = math.MinInt(startIndex+minSortSize, len(strL)-1) - if endIndex-startIndex < minSortSize { - startIndex = math.MaxInt(endIndex-minSortSize, 0) - } + var endIndex int + if end == nil { + endIndex = len(strL) - 1 + } else { + endIndex = findEndIndex(strL, endStr) + } + if endIndex < 0 { + endIndex = len(strL) - 1 } kvL := make([]*kv.Pair, 0) @@ -364,10 +364,7 @@ func (store *Store) clearUnsortedCacheSubset(unsorted []*kv.Pair, sortState sort store.sortedCache.Set(item.Key, []byte{}) continue } - err := store.sortedCache.Set(item.Key, item.Value) - if err != nil { - panic(err) - } + store.sortedCache.Set(item.Key, item.Value) } } diff --git a/store/cachekv/store_test.go b/store/cachekv/store_test.go index d589932d30fc..3ef99fd6f144 100644 --- a/store/cachekv/store_test.go +++ b/store/cachekv/store_test.go @@ -120,6 +120,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, nItems, i) + require.NoError(t, itr.Close()) // iterate over none itr = st.Iterator(bz("money"), nil) @@ -128,6 +129,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, 0, i) + require.NoError(t, itr.Close()) // iterate over lower itr = st.Iterator(keyFmt(0), keyFmt(3)) @@ -139,6 +141,7 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, 3, i) + require.NoError(t, itr.Close()) // iterate over upper itr = st.Iterator(keyFmt(2), keyFmt(4)) @@ -150,6 +153,64 @@ func TestCacheKVIteratorBounds(t *testing.T) { i++ } require.Equal(t, 4, i) + require.NoError(t, itr.Close()) +} + +func TestCacheKVReverseIteratorBounds(t *testing.T) { + st := newCacheKVStore() + + // set some items + nItems := 5 + for i := 0; i < nItems; i++ { + st.Set(keyFmt(i), valFmt(i)) + } + + // iterate over all of them + itr := st.ReverseIterator(nil, nil) + i := 0 + for ; itr.Valid(); itr.Next() { + k, v := itr.Key(), itr.Value() + require.Equal(t, keyFmt(nItems-1-i), k) + require.Equal(t, valFmt(nItems-1-i), v) + i++ + } + require.Equal(t, nItems, i) + require.NoError(t, itr.Close()) + + // iterate over none + itr = st.ReverseIterator(bz("money"), nil) + i = 0 + for ; itr.Valid(); itr.Next() { + i++ + } + require.Equal(t, 0, i) + require.NoError(t, itr.Close()) + + // iterate over lower + end := 3 + itr = st.ReverseIterator(keyFmt(0), keyFmt(end)) + i = 0 + for ; itr.Valid(); itr.Next() { + i++ + k, v := itr.Key(), itr.Value() + require.Equal(t, keyFmt(end-i), k) + require.Equal(t, valFmt(end-i), v) + } + require.Equal(t, 3, i) + require.NoError(t, itr.Close()) + + // iterate over upper + end = 4 + itr = st.ReverseIterator(keyFmt(2), 
keyFmt(end))
+	i = 0
+	for ; itr.Valid(); itr.Next() {
+		i++
+		k, v := itr.Key(), itr.Value()
+		require.Equal(t, keyFmt(end-i), k)
+		require.Equal(t, valFmt(end-i), v)
+	}
+	require.Equal(t, 2, i)
+	require.NoError(t, itr.Close())
 }
 
 func TestCacheKVMergeIteratorBasics(t *testing.T) {
@@ -291,6 +352,25 @@ func TestCacheKVMergeIteratorChunks(t *testing.T) {
 	assertIterateDomainCheck(t, st, truth, []keyRange{{0, 15}, {25, 35}, {38, 40}, {45, 80}})
 }
 
+func TestCacheKVMergeIteratorDomain(t *testing.T) {
+	st := newCacheKVStore()
+
+	itr := st.Iterator(nil, nil)
+	start, end := itr.Domain()
+	require.Equal(t, start, end)
+	require.NoError(t, itr.Close())
+
+	itr = st.Iterator(keyFmt(40), keyFmt(60))
+	start, end = itr.Domain()
+	require.Equal(t, keyFmt(40), start)
+	require.Equal(t, keyFmt(60), end)
+	require.NoError(t, itr.Close())
+
+	start, end = st.ReverseIterator(keyFmt(0), keyFmt(80)).Domain()
+	require.Equal(t, keyFmt(0), start)
+	require.Equal(t, keyFmt(80), end)
+}
+
 func TestCacheKVMergeIteratorRandom(t *testing.T) {
 	st := newCacheKVStore()
 	truth := dbm.NewMemDB()
@@ -306,6 +386,67 @@ func TestCacheKVMergeIteratorRandom(t *testing.T) {
 	}
 }
 
+func TestNilEndIterator(t *testing.T) {
+	const SIZE = 3000
+
+	tests := []struct {
+		name       string
+		write      bool
+		startIndex int
+		end        []byte
+	}{
+		{name: "write=false, end=nil", write: false, end: nil, startIndex: 1000},
+		{name: "write=false, end=nil; full key scan", write: false, end: nil, startIndex: 2000},
+		{name: "write=true, end=nil", write: true, end: nil, startIndex: 1000},
+		{name: "write=false, end=non-nil", write: false, end: keyFmt(3000), startIndex: 1000},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			st := newCacheKVStore()
+
+			for i := 0; i < SIZE; i++ {
+				kstr := keyFmt(i)
+				st.Set(kstr, valFmt(i))
+			}
+
+			if tt.write {
+				st.Write()
+			}
+
+			itr := st.Iterator(keyFmt(tt.startIndex), tt.end)
+			i := tt.startIndex
+			j := 0
+			for itr.Valid() {
+				require.Equal(t, keyFmt(i), itr.Key())
+				require.Equal(t, valFmt(i), itr.Value())
+				itr.Next()
+				i++
+				j++
+			}
+
+			require.Equal(t, SIZE-tt.startIndex, j)
+			require.NoError(t, itr.Close())
+		})
+	}
+}
+
+// TestIteratorDeadlock demonstrates the deadlock issue in the cache store.
+func TestIteratorDeadlock(t *testing.T) {
+	mem := dbadapter.Store{DB: dbm.NewMemDB()}
+	store := cachekv.NewStore(mem)
+	// the channel buffer size is 64 and one item is received up front, so put at least 66 elements to fill it.
+	for i := 0; i < 66; i++ {
+		store.Set([]byte(fmt.Sprintf("key%d", i)), []byte{1})
+	}
+	it := store.Iterator(nil, nil)
+	defer it.Close()
+	store.Set([]byte("key20"), []byte{1})
+	// With the previous version this would block here, as it would if locks were enabled on the btree.
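+	// (Assumed mechanism, for context: the old MemDB-backed sortedCache served
+	// iterators from a goroutine that held a read lock while sending items over
+	// a channel with buffer size 64, so with more than 64 pending items the
+	// goroutine stayed blocked holding the lock, and the writes above made
+	// creating a second iterator deadlock. The btree iterator holds no such lock.)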
+ it2 := store.Iterator(nil, nil) + defer it2.Close() +} + //------------------------------------------------------------------------------------------- // do some random ops @@ -388,6 +529,7 @@ func assertIterateDomain(t *testing.T, st types.KVStore, expectedN int) { i++ } require.Equal(t, expectedN, i) + require.NoError(t, itr.Close()) } func assertIterateDomainCheck(t *testing.T, st types.KVStore, mem dbm.DB, r []keyRange) { @@ -419,6 +561,8 @@ func assertIterateDomainCheck(t *testing.T, st types.KVStore, mem dbm.DB, r []ke require.False(t, itr.Valid()) require.False(t, itr2.Valid()) + require.NoError(t, itr.Close()) + require.NoError(t, itr2.Close()) } func assertIterateDomainCompare(t *testing.T, st types.KVStore, mem dbm.DB) { @@ -428,6 +572,8 @@ func assertIterateDomainCompare(t *testing.T, st types.KVStore, mem dbm.DB) { require.NoError(t, err) checkIterators(t, itr, itr2) checkIterators(t, itr2, itr) + require.NoError(t, itr.Close()) + require.NoError(t, itr2.Close()) } func checkIterators(t *testing.T, itr, itr2 types.Iterator) {