diff --git a/internal/unionstore/arena/arena.go b/internal/unionstore/arena/arena.go index 3e3cec489..84075020e 100644 --- a/internal/unionstore/arena/arena.go +++ b/internal/unionstore/arena/arena.go @@ -259,7 +259,7 @@ type KeyFlagsGetter interface { // VlogMemDB is the interface of the memory buffer which supports vlog to revert node and inspect node. type VlogMemDB[G KeyFlagsGetter] interface { - RevertNode(hdr *MemdbVlogHdr) + RevertVAddr(hdr *MemdbVlogHdr) InspectNode(addr MemdbArenaAddr) (G, MemdbArenaAddr) } @@ -351,7 +351,7 @@ func (l *MemdbVlog[G, M]) RevertToCheckpoint(m M, cp *MemDBCheckpoint) { block := l.blocks[cursor.blocks-1].buf var hdr MemdbVlogHdr hdr.load(block[hdrOff:]) - m.RevertNode(&hdr) + m.RevertVAddr(&hdr) l.moveBackCursor(&cursor, &hdr) } } diff --git a/internal/unionstore/arena/arena_test.go b/internal/unionstore/arena/arena_test.go index 5816481bf..3c67054bf 100644 --- a/internal/unionstore/arena/arena_test.go +++ b/internal/unionstore/arena/arena_test.go @@ -42,7 +42,7 @@ import ( type dummyMemDB struct{} -func (m *dummyMemDB) RevertNode(hdr *MemdbVlogHdr) {} +func (m *dummyMemDB) RevertVAddr(hdr *MemdbVlogHdr) {} func (m *dummyMemDB) InspectNode(addr MemdbArenaAddr) (KeyFlagsGetter, MemdbArenaAddr) { return nil, NullAddr } diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go index 17e826327..29f0ea53e 100644 --- a/internal/unionstore/art/art.go +++ b/internal/unionstore/art/art.go @@ -16,6 +16,7 @@ package art import ( + "fmt" "math" tikverr "github.com/tikv/client-go/v2/error" @@ -76,7 +77,7 @@ func (t *ART) GetFlags(key []byte) (kv.KeyFlags, error) { if leaf.vAddr.IsNull() && leaf.isDeleted() { return 0, tikverr.ErrNotExist } - return leaf.getKeyFlags(), nil + return leaf.GetKeyFlags(), nil } func (t *ART) Set(key artKey, value []byte, ops ...kv.FlagsOp) error { @@ -324,8 +325,8 @@ func (t *ART) newLeaf(key artKey) (artNode, *artLeaf) { } func (t *ART) setValue(addr arena.MemdbArenaAddr, l *artLeaf, value []byte, ops []kv.FlagsOp) { - flags := l.getKeyFlags() - if flags == 0 && l.vAddr.IsNull() { + flags := l.GetKeyFlags() + if flags == 0 && l.vAddr.IsNull() || l.isDeleted() { t.len++ t.size += int(l.klen) } @@ -373,12 +374,12 @@ func (t *ART) trySwapValue(addr arena.MemdbArenaAddr, value []byte) (int, bool) } func (t *ART) Dirty() bool { - panic("unimplemented") + return t.dirty } // Mem returns the memory usage of MemBuffer. func (t *ART) Mem() uint64 { - panic("unimplemented") + return t.allocator.nodeAllocator.Capacity() + t.allocator.vlogAllocator.Capacity() } // Len returns the count of entries in the MemBuffer. @@ -392,51 +393,97 @@ func (t *ART) Size() int { } func (t *ART) checkpoint() arena.MemDBCheckpoint { - panic("unimplemented") + return t.allocator.vlogAllocator.Checkpoint() } -func (t *ART) RevertNode(hdr *arena.MemdbVlogHdr) { - panic("unimplemented") +func (t *ART) RevertVAddr(hdr *arena.MemdbVlogHdr) { + lf := t.allocator.getLeaf(hdr.NodeAddr) + if lf == nil { + panic("revert an invalid node") + } + lf.vAddr = hdr.OldValue + t.size -= int(hdr.ValueLen) + if hdr.OldValue.IsNull() { + keptFlags := lf.GetKeyFlags() + keptFlags = keptFlags.AndPersistent() + if keptFlags == 0 { + lf.markDelete() + t.len-- + t.size -= int(lf.klen) + } else { + lf.setKeyFlags(keptFlags) + } + } else { + t.size += len(t.allocator.vlogAllocator.GetValue(hdr.OldValue)) + } } func (t *ART) InspectNode(addr arena.MemdbArenaAddr) (*artLeaf, arena.MemdbArenaAddr) { - panic("unimplemented") + lf := t.allocator.getLeaf(addr) + return lf, lf.vAddr } // Checkpoint returns a checkpoint of ART. func (t *ART) Checkpoint() *arena.MemDBCheckpoint { - panic("unimplemented") + cp := t.allocator.vlogAllocator.Checkpoint() + return &cp } // RevertToCheckpoint reverts the ART to the checkpoint. func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) { - panic("unimplemented") + t.allocator.vlogAllocator.RevertToCheckpoint(t, cp) + t.allocator.vlogAllocator.Truncate(cp) + t.allocator.vlogAllocator.OnMemChange() } func (t *ART) Stages() []arena.MemDBCheckpoint { - panic("unimplemented") + return t.stages } func (t *ART) Staging() int { - return 0 + t.stages = append(t.stages, t.checkpoint()) + return len(t.stages) } func (t *ART) Release(h int) { + if h == 0 { + // 0 is the invalid and no-effect handle. + return + } + if h != len(t.stages) { + panic("cannot release staging buffer") + } + if h == 1 { + tail := t.checkpoint() + if !t.stages[0].IsSamePosition(&tail) { + t.dirty = true + } + } + t.stages = t.stages[:h-1] } func (t *ART) Cleanup(h int) { -} - -func (t *ART) revertToCheckpoint(cp *arena.MemDBCheckpoint) { - panic("unimplemented") -} - -func (t *ART) moveBackCursor(cursor *arena.MemDBCheckpoint, hdr *arena.MemdbVlogHdr) { - panic("unimplemented") -} + if h == 0 { + // 0 is the invalid and no-effect handle. + return + } + if h > len(t.stages) { + return + } + if h < len(t.stages) { + panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(db.stages)=%v", h, len(t.stages))) + } -func (t *ART) truncate(snap *arena.MemDBCheckpoint) { - panic("unimplemented") + cp := &t.stages[h-1] + if !t.vlogInvalid { + curr := t.checkpoint() + if !curr.IsSamePosition(cp) { + t.allocator.vlogAllocator.RevertToCheckpoint(t, cp) + t.allocator.vlogAllocator.Truncate(cp) + } + } + t.stages = t.stages[:h-1] + t.allocator.vlogAllocator.OnMemChange() } // Reset resets the MemBuffer to initial states. @@ -459,7 +506,10 @@ func (t *ART) DiscardValues() { // InspectStage used to inspect the value updates in the given stage. func (t *ART) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) { - panic("unimplemented") + idx := handle - 1 + tail := t.allocator.vlogAllocator.Checkpoint() + head := t.stages[idx] + t.allocator.vlogAllocator.InspectKVInLog(t, &head, &tail, f) } // SelectValueHistory select the latest value which makes `predicate` returns true from the modification history. diff --git a/internal/unionstore/art/art_node.go b/internal/unionstore/art/art_node.go index 0c4ee9090..f94794ce6 100644 --- a/internal/unionstore/art/art_node.go +++ b/internal/unionstore/art/art_node.go @@ -265,11 +265,6 @@ func (l *artLeaf) getKeyDepth(depth uint32) []byte { return unsafe.Slice((*byte)(base), int(l.klen)-int(depth)) } -// GetKeyFlags gets the flags of the leaf -func (l *artLeaf) GetKeyFlags() kv.KeyFlags { - panic("unimplemented") -} - func (l *artLeaf) match(depth uint32, key artKey) bool { return bytes.Equal(l.getKeyDepth(depth), key[depth:]) } @@ -278,7 +273,8 @@ func (l *artLeaf) setKeyFlags(flags kv.KeyFlags) { l.flags = uint16(flags) & flagMask } -func (l *artLeaf) getKeyFlags() kv.KeyFlags { +// GetKeyFlags gets the flags of the leaf +func (l *artLeaf) GetKeyFlags() kv.KeyFlags { return kv.KeyFlags(l.flags & flagMask) } @@ -288,8 +284,6 @@ const ( ) // markDelete marks the artLeaf as deleted -// -//nolint:unused func (l *artLeaf) markDelete() { l.flags = deleteFlag } diff --git a/internal/unionstore/memdb_norace_test.go b/internal/unionstore/memdb_norace_test.go index f669ef798..ffd4b3d41 100644 --- a/internal/unionstore/memdb_norace_test.go +++ b/internal/unionstore/memdb_norace_test.go @@ -38,6 +38,7 @@ package unionstore import ( + "context" rand2 "crypto/rand" "encoding/binary" "math/rand" @@ -166,3 +167,41 @@ func testRandomDeriveRecur(t *testing.T, db *MemDB, golden *leveldb.DB, depth in return opLog } + +func TestRandomAB(t *testing.T) { + testRandomAB(t, newRbtDBWithContext(), newArtDBWithContext()) +} + +func testRandomAB(t *testing.T, bufferA, bufferB MemBuffer) { + require := require.New(t) + + const cnt = 50000 + keys := make([][]byte, cnt) + for i := 0; i < cnt; i++ { + h := bufferA.Staging() + require.Equal(h, bufferB.Staging()) + + keys[i] = make([]byte, rand.Intn(19)+1) + rand2.Read(keys[i]) + + bufferA.Set(keys[i], keys[i]) + bufferB.Set(keys[i], keys[i]) + + if i%2 == 0 { + bufferA.Cleanup(h) + bufferB.Cleanup(h) + } else { + bufferA.Release(h) + bufferB.Release(h) + } + + require.Equal(bufferA.Dirty(), bufferB.Dirty()) + require.Equal(bufferA.Len(), bufferB.Len()) + require.Equal(bufferA.Size(), bufferB.Size(), i) + key := keys[rand.Intn(i+1)] + v1, err1 := bufferA.Get(context.Background(), key) + v2, err2 := bufferB.Get(context.Background(), key) + require.Equal(err1, err2) + require.Equal(v1, v2) + } +} diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 62ae308d1..f3996de86 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -392,6 +392,7 @@ func testReset(t *testing.T, db interface { func TestInspectStage(t *testing.T) { testInspectStage(t, newRbtDBWithContext()) + testInspectStage(t, newArtDBWithContext()) } func testInspectStage(t *testing.T, db MemBuffer) { @@ -449,6 +450,7 @@ func testInspectStage(t *testing.T, db MemBuffer) { func TestDirty(t *testing.T) { testDirty(t, func() MemBuffer { return newRbtDBWithContext() }) + testDirty(t, func() MemBuffer { return newArtDBWithContext() }) } func testDirty(t *testing.T, createDb func() MemBuffer) { @@ -782,8 +784,12 @@ func TestNewIteratorMin(t *testing.T) { } func TestMemDBStaging(t *testing.T) { + testMemDBStaging(t, newRbtDBWithContext()) + testMemDBStaging(t, newArtDBWithContext()) +} + +func testMemDBStaging(t *testing.T, buffer MemBuffer) { assert := assert.New(t) - buffer := NewMemDB() err := buffer.Set([]byte("x"), make([]byte, 2)) assert.Nil(err) @@ -809,6 +815,117 @@ func TestMemDBStaging(t *testing.T) { assert.Equal(len(v), 2) } +func TestMemDBMultiLevelStaging(t *testing.T) { + testMemDBMultiLevelStaging(t, newRbtDBWithContext()) + testMemDBMultiLevelStaging(t, newArtDBWithContext()) +} + +func testMemDBMultiLevelStaging(t *testing.T, buffer MemBuffer) { + assert := assert.New(t) + + key := []byte{0} + for i := 0; i < 100; i++ { + assert.Equal(i+1, buffer.Staging()) + buffer.Set(key, []byte{byte(i)}) + v, err := buffer.Get(context.Background(), key) + assert.Nil(err) + assert.Equal(v, []byte{byte(i)}) + } + + for i := 99; i >= 0; i-- { + expect := i + if i%2 == 1 { + expect = i - 1 + buffer.Cleanup(i + 1) + } else { + buffer.Release(i + 1) + } + v, err := buffer.Get(context.Background(), key) + assert.Nil(err) + assert.Equal(v, []byte{byte(expect)}) + } +} + +func TestInvalidStagingHandle(t *testing.T) { + testInvalidStagingHandle(t, newRbtDBWithContext()) + testInvalidStagingHandle(t, newArtDBWithContext()) +} + +func testInvalidStagingHandle(t *testing.T, buffer MemBuffer) { + // handle == 0 takes no effect + // MemBuffer.Release only accept the latest handle + // MemBuffer.Cleanup accept handle large or equal than the latest handle, but only takes effect when handle is the latest handle. + assert := assert.New(t) + + // test MemBuffer.Release + h1 := buffer.Staging() + assert.Positive(h1) + h2 := buffer.Staging() + assert.Positive(h2) + assert.Panics(func() { + buffer.Release(h2 + 1) + }) + assert.Panics(func() { + buffer.Release(h2 - 1) + }) + buffer.Release(0) + buffer.Release(h2) + buffer.Release(0) + buffer.Release(h1) + buffer.Release(0) + + // test MemBuffer.Cleanup + h1 = buffer.Staging() + assert.Positive(h1) + h2 = buffer.Staging() + assert.Positive(h2) + buffer.Cleanup(h2 + 1) // Cleanup is ok even if the handle is greater than the existing handles. + assert.Panics(func() { + buffer.Cleanup(h2 - 1) + }) + buffer.Cleanup(0) + buffer.Cleanup(h2) + buffer.Cleanup(0) + buffer.Cleanup(h1) + buffer.Cleanup(0) +} + +func TestMemDBCheckpoint(t *testing.T) { + testMemDBCheckpoint(t, newRbtDBWithContext()) + testMemDBCheckpoint(t, newArtDBWithContext()) +} + +func testMemDBCheckpoint(t *testing.T, buffer MemBuffer) { + assert := assert.New(t) + cp1 := buffer.Checkpoint() + + buffer.Set([]byte("x"), []byte("x")) + + cp2 := buffer.Checkpoint() + buffer.Set([]byte("y"), []byte("y")) + + h := buffer.Staging() + buffer.Set([]byte("z"), []byte("z")) + buffer.Release(h) + + for _, k := range []string{"x", "y", "z"} { + v, _ := buffer.Get(context.Background(), []byte(k)) + assert.Equal(v, []byte(k)) + } + + buffer.RevertToCheckpoint(cp2) + v, _ := buffer.Get(context.Background(), []byte("x")) + assert.Equal(v, []byte("x")) + for _, k := range []string{"y", "z"} { + _, err := buffer.Get(context.Background(), []byte(k)) + assert.NotNil(err) + } + + buffer.RevertToCheckpoint(cp1) + _, err := buffer.Get(context.Background(), []byte("x")) + assert.NotNil(err) +} + func TestBufferLimit(t *testing.T) { testBufferLimit(t, newRbtDBWithContext()) } @@ -897,3 +1014,36 @@ func testSnapshotGetIter(t *testing.T, db MemBuffer) { assert.Equal(reverseIter.Value(), []byte{byte(50)}) } } + +func TestCleanupKeepPersistentFlag(t *testing.T) { + testCleanupKeepPersistentFlag(t, newRbtDBWithContext()) + testCleanupKeepPersistentFlag(t, newArtDBWithContext()) +} + +func testCleanupKeepPersistentFlag(t *testing.T, db MemBuffer) { + assert := assert.New(t) + persistentFlag := kv.SetKeyLocked + nonPersistentFlag := kv.SetPresumeKeyNotExists + + h := db.Staging() + db.SetWithFlags([]byte{1}, []byte{1}, persistentFlag) + db.SetWithFlags([]byte{2}, []byte{2}, nonPersistentFlag) + db.SetWithFlags([]byte{3}, []byte{3}, persistentFlag, nonPersistentFlag) + db.Cleanup(h) + + for _, key := range [][]byte{{1}, {2}, {3}} { + // the values are reverted by MemBuffer.Cleanup + _, err := db.Get(context.Background(), key) + assert.NotNil(err) + } + + flag, err := db.GetFlags([]byte{1}) + assert.Nil(err) + assert.True(flag.HasLocked()) + _, err = db.GetFlags([]byte{2}) + assert.NotNil(err) + flag, err = db.GetFlags([]byte{3}) + assert.Nil(err) + assert.True(flag.HasLocked()) + assert.False(flag.HasPresumeKeyNotExists()) +} diff --git a/internal/unionstore/rbt/rbt.go b/internal/unionstore/rbt/rbt.go index 6dd9cecf9..40b3234c3 100644 --- a/internal/unionstore/rbt/rbt.go +++ b/internal/unionstore/rbt/rbt.go @@ -111,7 +111,7 @@ func (db *RBT) checkKeyInCache(key []byte) (MemdbNodeAddr, bool) { return nullNodeAddr, false } -func (db *RBT) RevertNode(hdr *arena.MemdbVlogHdr) { +func (db *RBT) RevertVAddr(hdr *arena.MemdbVlogHdr) { node := db.getNode(hdr.NodeAddr) node.vptr = hdr.OldValue db.size -= int(hdr.ValueLen) @@ -150,6 +150,10 @@ func (db *RBT) Staging() int { // Release publish all modifications in the latest staging buffer to upper level. func (db *RBT) Release(h int) { + if h == 0 { + // 0 is the invalid and no-effect handle. + return + } if h != len(db.stages) { // This should never happens in production environment. // Use panic to make debug easier. @@ -168,6 +172,10 @@ func (db *RBT) Release(h int) { // Cleanup cleanup the resources referenced by the StagingHandle. // If the changes are not published by `Release`, they will be discarded. func (db *RBT) Cleanup(h int) { + if h == 0 { + // 0 is the invalid and no-effect handle. + return + } if h > len(db.stages) { return }