diff --git a/cmd/node/config/gasSchedules/gasScheduleV1.toml b/cmd/node/config/gasSchedules/gasScheduleV1.toml index e2e6005d3f6..7de5ad4cf40 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV1.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV1.toml @@ -15,6 +15,8 @@ ESDTNFTAddUri = 500000 ESDTNFTUpdateAttributes = 500000 ESDTNFTMultiTransfer = 1000000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV2.toml b/cmd/node/config/gasSchedules/gasScheduleV2.toml index 3c8f2f3c871..af868419d7c 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV2.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV2.toml @@ -15,6 +15,8 @@ ESDTNFTAddUri = 500000 ESDTNFTUpdateAttributes = 500000 ESDTNFTMultiTransfer = 1000000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV3.toml b/cmd/node/config/gasSchedules/gasScheduleV3.toml index 89b4106eae2..47c119704f7 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV3.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV3.toml @@ -15,6 +15,8 @@ ESDTNFTAddUri = 500000 ESDTNFTUpdateAttributes = 500000 ESDTNFTMultiTransfer = 1000000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV4.toml b/cmd/node/config/gasSchedules/gasScheduleV4.toml index 56e6c342c1f..53113f8e401 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV4.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV4.toml @@ -15,6 +15,8 @@ ESDTNFTAddUri = 50000 ESDTNFTUpdateAttributes = 50000 ESDTNFTMultiTransfer = 200000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV5.toml b/cmd/node/config/gasSchedules/gasScheduleV5.toml index 33f1fdbfd85..f445187c569 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV5.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV5.toml @@ -15,6 +15,8 @@ ESDTNFTAddUri = 50000 ESDTNFTUpdateAttributes = 50000 ESDTNFTMultiTransfer = 200000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV6.toml b/cmd/node/config/gasSchedules/gasScheduleV6.toml index e14027354a4..9b36efaada9 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV6.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV6.toml @@ -15,6 +15,8 @@ ESDTNFTAddUri = 50000 ESDTNFTUpdateAttributes = 50000 ESDTNFTMultiTransfer = 200000 + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/cmd/node/config/gasSchedules/gasScheduleV7.toml b/cmd/node/config/gasSchedules/gasScheduleV7.toml index 9e5589e9673..20bd46cae7f 100644 --- a/cmd/node/config/gasSchedules/gasScheduleV7.toml +++ b/cmd/node/config/gasSchedules/gasScheduleV7.toml @@ -16,6 +16,8 @@ ESDTNFTUpdateAttributes = 50000 ESDTNFTMultiTransfer = 200000 MultiESDTNFTTransfer = 200000 # should be the same value with the ESDTNFTMultiTransfer + TrieLoadPerNode = 20000 + TrieStorePerNode = 50000 [MetaChainSystemSCsCost] Stake = 5000000 diff --git a/common/interface.go b/common/interface.go index c5513288542..e69db0c277d 100644 --- a/common/interface.go +++ b/common/interface.go @@ -42,7 +42,7 @@ type Trie interface { // TrieLeafParser is used to parse trie leaves type TrieLeafParser interface { - ParseLeaf(key []byte, val []byte, version TrieNodeVersion) (core.KeyValueHolder, error) + ParseLeaf(key []byte, val []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) IsInterfaceNil() bool } diff --git a/common/trie.go b/common/trie.go index 76e96e688dd..510029a1dc7 100644 --- a/common/trie.go +++ b/common/trie.go @@ -6,24 +6,6 @@ import ( "github.com/multiversx/mx-chain-core-go/core" ) -// TrieNodeVersion defines the version of the trie node -type TrieNodeVersion uint8 - -const ( - // NotSpecified means that the value is not populated or is not important - NotSpecified TrieNodeVersion = iota - - // AutoBalanceEnabled is used for data tries, and only after the activation of AutoBalanceDataTriesEnableEpoch flag - AutoBalanceEnabled -) - -// TrieData holds the data that will be inserted into the trie -type TrieData struct { - Key []byte - Value []byte - Version TrieNodeVersion -} - // EmptyTrieHash returns the value with empty trie hash var EmptyTrieHash = make([]byte, 32) diff --git a/errors/errors.go b/errors/errors.go index c03015d9aac..d3b01dfad93 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -544,3 +544,9 @@ var ErrNilPersistentHandler = errors.New("nil persistent handler") // ErrNilGenesisNodesSetupHandler signals that a nil genesis nodes setup handler has been provided var ErrNilGenesisNodesSetupHandler = errors.New("nil genesis nodes setup handler") + +// ErrInvalidTrieNodeVersion signals that an invalid trie node version has been provided +var ErrInvalidTrieNodeVersion = errors.New("invalid trie node version") + +// ErrNilTrieMigrator signals that a nil trie migrator has been provided +var ErrNilTrieMigrator = errors.New("nil trie migrator") diff --git a/genesis/mock/userAccountMock.go b/genesis/mock/userAccountMock.go index 0ad04ab6340..8a332c0bf33 100644 --- a/genesis/mock/userAccountMock.go +++ b/genesis/mock/userAccountMock.go @@ -5,6 +5,7 @@ import ( "errors" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" ) @@ -141,7 +142,7 @@ func (uam *UserAccountMock) GetUserName() []byte { } // SaveDirtyData - -func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) ([]common.TrieData, error) { +func (uam *UserAccountMock) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { return nil, nil } diff --git a/go.mod b/go.mod index 71d740f2df2..62b94df45ec 100644 --- a/go.mod +++ b/go.mod @@ -13,16 +13,16 @@ require ( github.com/google/gops v0.3.18 github.com/gorilla/websocket v1.5.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/multiversx/mx-chain-core-go v1.1.33 + github.com/multiversx/mx-chain-core-go v1.1.35-0.20230302100006-5230818062e9 github.com/multiversx/mx-chain-crypto-go v1.2.5 github.com/multiversx/mx-chain-es-indexer-go v1.3.12 github.com/multiversx/mx-chain-logger-go v1.0.11 github.com/multiversx/mx-chain-p2p-go v1.0.13 github.com/multiversx/mx-chain-storage-go v1.0.7 - github.com/multiversx/mx-chain-vm-common-go v1.3.37 - github.com/multiversx/mx-chain-vm-v1_2-go v1.2.50 - github.com/multiversx/mx-chain-vm-v1_3-go v1.3.51 - github.com/multiversx/mx-chain-vm-v1_4-go v1.4.77 + github.com/multiversx/mx-chain-vm-common-go v1.3.38-0.20230302100330-6d0ec8963a31 + github.com/multiversx/mx-chain-vm-v1_2-go v1.2.51-0.20230302102214-94d6f0ba4d00 + github.com/multiversx/mx-chain-vm-v1_3-go v1.3.52-0.20230302101124-e32ddf31fbb9 + github.com/multiversx/mx-chain-vm-v1_4-go v1.4.78-0.20230302100627-28a21440e662 github.com/pelletier/go-toml v1.9.3 github.com/pkg/errors v0.9.1 github.com/shirou/gopsutil v3.21.11+incompatible diff --git a/go.sum b/go.sum index 40847d4228e..d6d34e5f073 100644 --- a/go.sum +++ b/go.sum @@ -609,9 +609,8 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS github.com/multiversx/concurrent-map v0.1.4 h1:hdnbM8VE4b0KYJaGY5yJS2aNIW9TFFsUYwbO0993uPI= github.com/multiversx/concurrent-map v0.1.4/go.mod h1:8cWFRJDOrWHOTNSqgYCUvwT7c7eFQ4U2vKMOp4A/9+o= github.com/multiversx/mx-chain-core-go v1.1.30/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= -github.com/multiversx/mx-chain-core-go v1.1.31/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= -github.com/multiversx/mx-chain-core-go v1.1.33 h1:qk+TlaOhHpu+9VncL3yowjY4KU8uJ0oSdPfU7SgVDnk= -github.com/multiversx/mx-chain-core-go v1.1.33/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= +github.com/multiversx/mx-chain-core-go v1.1.35-0.20230302100006-5230818062e9 h1:ffKx/z2aFBzaUs10z5W/9pBsYxeY43UvYui3z4iYPRs= +github.com/multiversx/mx-chain-core-go v1.1.35-0.20230302100006-5230818062e9/go.mod h1:8gGEQv6BWuuJwhd25qqhCOZbBSv9mk+hLeKvinSaSMk= github.com/multiversx/mx-chain-crypto-go v1.2.5 h1:tuq3BUNMhKud5DQbZi9DiVAAHUXypizy8zPH0NpTGZk= github.com/multiversx/mx-chain-crypto-go v1.2.5/go.mod h1:teqhNyWEqfMPgNn8sgWXlgtJ1a36jGCnhs/tRpXW6r4= github.com/multiversx/mx-chain-es-indexer-go v1.3.12 h1:b7B8KMrCHM0Ghh4W0s1jXLI5MknEAOo7ZChFAwUUYpY= @@ -623,15 +622,14 @@ github.com/multiversx/mx-chain-p2p-go v1.0.13/go.mod h1:j9Ueo2ptCnL7TQvQg6KS/KWA github.com/multiversx/mx-chain-storage-go v1.0.7 h1:UqLo/OLTD3IHiE/TB/SEdNRV1GG2f1R6vIP5ehHwCNw= github.com/multiversx/mx-chain-storage-go v1.0.7/go.mod h1:gtKoV32Cg2Uy8deHzF8Ud0qAl0zv92FvWgPSYIP0Zmg= github.com/multiversx/mx-chain-vm-common-go v1.3.34/go.mod h1:sZ2COLCxvf2GxAAJHGmGqWybObLtFuk2tZUyGqnMXE8= -github.com/multiversx/mx-chain-vm-common-go v1.3.36/go.mod h1:sZ2COLCxvf2GxAAJHGmGqWybObLtFuk2tZUyGqnMXE8= -github.com/multiversx/mx-chain-vm-common-go v1.3.37 h1:KeK6JCjeNUOHC5Z12/CTQIa8Z1at0dnnL9hY1LNrHS8= -github.com/multiversx/mx-chain-vm-common-go v1.3.37/go.mod h1:sZ2COLCxvf2GxAAJHGmGqWybObLtFuk2tZUyGqnMXE8= -github.com/multiversx/mx-chain-vm-v1_2-go v1.2.50 h1:ScUq7/wq78vthMTQ6v5Ux1DvSMQMHxQ2Sl7aPP26q1w= -github.com/multiversx/mx-chain-vm-v1_2-go v1.2.50/go.mod h1:e3uYdgoKzs3puaznbmSjDcRisJc5Do4tpg7VqyYwoek= -github.com/multiversx/mx-chain-vm-v1_3-go v1.3.51 h1:axtp5/mpA+xYJ1cu4KtAGETV4t6v6/tNfQh0HCclBYY= -github.com/multiversx/mx-chain-vm-v1_3-go v1.3.51/go.mod h1:oKj32V2nkd+KGNOL6emnwVkDRPpciwHHDqBmeorcL8k= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.77 h1:3Yh4brS5/Jye24l5AKy+Q6Yci6Rv55pHyj9/GR3AYos= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.77/go.mod h1:3IaAOHc1JfxL5ywQZIrcaHQu5+CVdZNDaoY64NGOtUE= +github.com/multiversx/mx-chain-vm-common-go v1.3.38-0.20230302100330-6d0ec8963a31 h1:CN5jVyk8LsvOJ5htRcxPD8+lF77S1p++3HlGSbG4Lu8= +github.com/multiversx/mx-chain-vm-common-go v1.3.38-0.20230302100330-6d0ec8963a31/go.mod h1:y/pwzZF5saK3GIdseiyCOqoq5OdKIOIKu9tQiIWh7BY= +github.com/multiversx/mx-chain-vm-v1_2-go v1.2.51-0.20230302102214-94d6f0ba4d00 h1:qnoFnrRAX3BsMBLDnxkdi/F+hOiyzZqn/VeCOXfiLhg= +github.com/multiversx/mx-chain-vm-v1_2-go v1.2.51-0.20230302102214-94d6f0ba4d00/go.mod h1:pIsa9gzTMRy6/hAbcH/4aC5VulmZO/pY7l7MXRISGBk= +github.com/multiversx/mx-chain-vm-v1_3-go v1.3.52-0.20230302101124-e32ddf31fbb9 h1:6r6wWBHiD2D5PDobyxL3W+xIso/pan+k59T+l9YOS+0= +github.com/multiversx/mx-chain-vm-v1_3-go v1.3.52-0.20230302101124-e32ddf31fbb9/go.mod h1:LwEnuDYX6Y0KU3z5F+/sp9zcJo5+3lU9aTRH7an4G5I= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.78-0.20230302100627-28a21440e662 h1:m7MajpOJsSKLxBhh6u+9BzC1O6Lpu7VYr79JkD3u8WY= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.78-0.20230302100627-28a21440e662/go.mod h1:fQzZWMzInDYOQkeGA+um/8odZiAfPs3qgj6/Hfzsin0= github.com/multiversx/mx-components-big-int v0.1.1 h1:695mYPKYOrmGEGgRH4/pZruDoe3CPP1LHrBxKfvj5l4= github.com/multiversx/mx-components-big-int v0.1.1/go.mod h1:0QrcFdfeLgJ/am10HGBeH0G0DNF+0Qx1E4DS/iozQls= github.com/multiversx/protobuf v1.3.2 h1:RaNkxvGTGbA0lMcnHAN24qE1G1i+Xs5yHA6MDvQ4mSM= diff --git a/node/node_test.go b/node/node_test.go index 1b95a5dbb15..37a09761875 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -332,11 +332,11 @@ func TestNode_GetKeyValuePairs(t *testing.T) { GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { suffix := append(k1, acc.AddressBytes()...) - trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), common.NotSpecified) + trieLeaf, _ := tlp.ParseLeaf(k1, append(v1, suffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf suffix = append(k2, acc.AddressBytes()...) - trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), common.NotSpecified) + trieLeaf2, _ := tlp.ParseLeaf(k2, append(v2, suffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf2 close(leavesChannels.LeavesChan) close(leavesChannels.ErrChan) @@ -940,13 +940,13 @@ func TestNode_GetAllIssuedESDTs(t *testing.T) { &trieMock.TrieStub{ GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), common.NotSpecified) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf - trieLeaf, _ = tlp.ParseLeaf(sftToken, append(sftMarshalledData, sftSuffix...), common.NotSpecified) + trieLeaf, _ = tlp.ParseLeaf(sftToken, append(sftMarshalledData, sftSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf - trieLeaf, _ = tlp.ParseLeaf(nftToken, append(nftMarshalledData, nftSuffix...), common.NotSpecified) + trieLeaf, _ = tlp.ParseLeaf(nftToken, append(nftMarshalledData, nftSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) close(leavesChannels.ErrChan) @@ -1032,7 +1032,7 @@ func TestNode_GetESDTsWithRole(t *testing.T) { &trieMock.TrieStub{ GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), common.NotSpecified) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) close(leavesChannels.ErrChan) @@ -1112,7 +1112,7 @@ func TestNode_GetESDTsRoles(t *testing.T) { &trieMock.TrieStub{ GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), common.NotSpecified) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) close(leavesChannels.ErrChan) @@ -1177,7 +1177,7 @@ func TestNode_GetNFTTokenIDsRegisteredByAddress(t *testing.T) { &trieMock.TrieStub{ GetAllLeavesOnChannelCalled: func(leavesChannels *common.TrieIteratorChannels, ctx context.Context, rootHash []byte, _ common.KeyBuilder, tlp common.TrieLeafParser) error { go func() { - trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), common.NotSpecified) + trieLeaf, _ := tlp.ParseLeaf(esdtToken, append(marshalledData, esdtSuffix...), core.NotSpecified) leavesChannels.LeavesChan <- trieLeaf close(leavesChannels.LeavesChan) close(leavesChannels.ErrChan) diff --git a/process/gasCost.go b/process/gasCost.go index 87e54722c92..6d2f4afbe89 100644 --- a/process/gasCost.go +++ b/process/gasCost.go @@ -28,6 +28,8 @@ type BuiltInCost struct { ESDTNFTAddUri uint64 ESDTNFTUpdateAttributes uint64 ESDTNFTMultiTransfer uint64 + TrieLoadPerNode uint64 + TrieStorePerNode uint64 } // GasCost holds all the needed gas costs for system smart contracts diff --git a/process/smartContract/builtInFunctions/factory_test.go b/process/smartContract/builtInFunctions/factory_test.go index b19bf6c87ab..4397a7f37e0 100644 --- a/process/smartContract/builtInFunctions/factory_test.go +++ b/process/smartContract/builtInFunctions/factory_test.go @@ -85,6 +85,8 @@ func fillGasMapBuiltInCosts(value uint64) map[string]uint64 { gasMap["ESDTNFTAddUri"] = value gasMap["ESDTNFTUpdateAttributes"] = value gasMap["ESDTNFTMultiTransfer"] = value + gasMap["TrieLoadPerNode"] = value + gasMap["TrieStorePerNode"] = value return gasMap } @@ -161,7 +163,7 @@ func TestCreateBuiltInFunctionContainer(t *testing.T) { args := createMockArguments() builtInFuncFactory, err := CreateBuiltInFunctionsFactory(args) assert.Nil(t, err) - assert.Equal(t, len(builtInFuncFactory.BuiltInFunctionContainer().Keys()), 31) + assert.Equal(t, 32, len(builtInFuncFactory.BuiltInFunctionContainer().Keys())) err = builtInFuncFactory.SetPayableHandler(&testscommon.BlockChainHookStub{}) assert.Nil(t, err) diff --git a/state/accountsDB.go b/state/accountsDB.go index 55356589f44..41a18673bd4 100644 --- a/state/accountsDB.go +++ b/state/accountsDB.go @@ -550,6 +550,9 @@ func (adb *AccountsDB) saveDataTrie(accountHandler baseAccountHandler) error { } adb.journalize(entry) + //TODO in order to avoid recomputing the root hash after every transaction for the same data trie, + // benchmark if it is better to cache the account and compute the rootHash only when the state is committed. + // For this to work, LoadAccount should check that cache first, and only after load from the trie. rootHash, err := accountHandler.DataTrie().RootHash() if err != nil { return err diff --git a/state/accountsDB_test.go b/state/accountsDB_test.go index c4bacc0b80a..45bd6ad72df 100644 --- a/state/accountsDB_test.go +++ b/state/accountsDB_test.go @@ -324,7 +324,7 @@ func TestAccountsDB_SaveAccountSavesCodeAndDataTrieForUserAccount(t *testing.T) GetCalled: func(_ []byte) ([]byte, uint32, error) { return nil, 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, RootCalled: func() (i []byte, err error) { @@ -851,7 +851,7 @@ func TestAccountsDB_CommitShouldCallCommitFromTrie(t *testing.T) { GetCalled: func(_ []byte) ([]byte, uint32, error) { return []byte("doge"), 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, CommitCalled: func() error { diff --git a/state/baseAccount.go b/state/baseAccount.go index 86a85f440f9..e5e94367ca8 100644 --- a/state/baseAccount.go +++ b/state/baseAccount.go @@ -1,6 +1,7 @@ package state import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -53,7 +54,7 @@ func (ba *baseAccount) SaveKeyValue(key []byte, value []byte) error { } // SaveDirtyData triggers SaveDirtyData form the underlying trackableDataTrie -func (ba *baseAccount) SaveDirtyData(trie common.Trie) ([]common.TrieData, error) { +func (ba *baseAccount) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { if check.IfNil(ba.dataTrieTracker) { return nil, ErrNilTrackableDataTrie } diff --git a/state/disabled/disabledTrackableDataTrie.go b/state/disabled/disabledTrackableDataTrie.go index 0f10a89c1f3..87812f0c2ba 100644 --- a/state/disabled/disabledTrackableDataTrie.go +++ b/state/disabled/disabledTrackableDataTrie.go @@ -1,6 +1,10 @@ package disabled -import "github.com/multiversx/mx-chain-go/common" +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" +) type disabledTrackableDataTrie struct { } @@ -30,8 +34,13 @@ func (dtdt *disabledTrackableDataTrie) DataTrie() common.DataTrieHandler { } // SaveDirtyData does nothing for this implementation -func (dtdt *disabledTrackableDataTrie) SaveDirtyData(_ common.Trie) ([]common.TrieData, error) { - return make([]common.TrieData, 0), nil +func (dtdt *disabledTrackableDataTrie) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { + return make([]core.TrieData, 0), nil +} + +// MigrateDataTrieLeaves does nothing for this implementation +func (dtdt *disabledTrackableDataTrie) MigrateDataTrieLeaves(_ core.TrieNodeVersion, _ core.TrieNodeVersion, _ vmcommon.DataTrieMigrator) error { + return nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/export_test.go b/state/export_test.go index a042d6fed91..8bfd259f450 100644 --- a/state/export_test.go +++ b/state/export_test.go @@ -1,6 +1,7 @@ package state import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -79,11 +80,28 @@ func EmptyErrChanReturningHadContained(errChan chan error) bool { } // DirtyData - -func (tdaw *trackableDataTrie) DirtyData() map[string][]byte { - return tdaw.dirtyData +type DirtyData struct { + Value []byte + OldVersion core.TrieNodeVersion + NewVersion core.TrieNodeVersion +} + +// DirtyData - +func (tdaw *trackableDataTrie) DirtyData() map[string]DirtyData { + dd := make(map[string]DirtyData, len(tdaw.dirtyData)) + + for key, value := range tdaw.dirtyData { + dd[key] = DirtyData{ + Value: value.value, + OldVersion: value.oldVersion.version, + NewVersion: value.newVersion, + } + } + + return dd } // SaveDirtyData - -func (a *userAccount) SaveDirtyData(trie common.Trie) ([]common.TrieData, error) { +func (a *userAccount) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { return a.dataTrieTracker.SaveDirtyData(trie) } diff --git a/state/interface.go b/state/interface.go index 4eaaf34afe9..e8cd46946b6 100644 --- a/state/interface.go +++ b/state/interface.go @@ -4,6 +4,7 @@ import ( "context" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" @@ -95,7 +96,8 @@ type DataTrieTracker interface { SaveKeyValue(key []byte, value []byte) error SetDataTrie(tr common.Trie) DataTrie() common.DataTrieHandler - SaveDirtyData(common.Trie) ([]common.TrieData, error) + SaveDirtyData(common.Trie) ([]core.TrieData, error) + MigrateDataTrieLeaves(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error IsInterfaceNil() bool } @@ -165,7 +167,7 @@ type baseAccountHandler interface { GetRootHash() []byte SetDataTrie(trie common.Trie) DataTrie() common.DataTrieHandler - SaveDirtyData(trie common.Trie) ([]common.TrieData, error) + SaveDirtyData(trie common.Trie) ([]core.TrieData, error) IsInterfaceNil() bool } @@ -222,5 +224,6 @@ type AccountsAdapterAPI interface { type dataTrie interface { common.Trie - UpdateWithVersion(key []byte, value []byte, version common.TrieNodeVersion) error + UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error + CollectLeavesForMigration(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error } diff --git a/state/journalEntries.go b/state/journalEntries.go index 38c954bddd4..a7f66fec8f3 100644 --- a/state/journalEntries.go +++ b/state/journalEntries.go @@ -4,9 +4,9 @@ import ( "bytes" "fmt" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/common" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -167,12 +167,12 @@ func (jea *journalEntryAccountCreation) IsInterfaceNil() bool { // JournalEntryDataTrieUpdates stores all the updates done to the account's data trie, // so it can be reverted in case of rollback type journalEntryDataTrieUpdates struct { - trieUpdates []common.TrieData + trieUpdates []core.TrieData account baseAccountHandler } // NewJournalEntryDataTrieUpdates outputs a new JournalEntryDataTrieUpdates implementation used to revert an account's data trie -func NewJournalEntryDataTrieUpdates(trieUpdates []common.TrieData, account baseAccountHandler) (*journalEntryDataTrieUpdates, error) { +func NewJournalEntryDataTrieUpdates(trieUpdates []core.TrieData, account baseAccountHandler) (*journalEntryDataTrieUpdates, error) { if check.IfNil(account) { return nil, fmt.Errorf("%w in NewJournalEntryDataTrieUpdates", ErrNilAccountHandler) } diff --git a/state/journalEntries_test.go b/state/journalEntries_test.go index bfb1f1fb2b2..86eaa51d256 100644 --- a/state/journalEntries_test.go +++ b/state/journalEntries_test.go @@ -4,8 +4,8 @@ import ( "errors" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" @@ -176,8 +176,8 @@ func TestJournalEntryAccountCreation_RevertUpdatesTheTrie(t *testing.T) { func TestNewJournalEntryDataTrieUpdates_NilAccountShouldErr(t *testing.T) { t.Parallel() - trieUpdates := make([]common.TrieData, 0) - trieUpdates = append(trieUpdates, common.TrieData{ + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ Key: []byte("a"), Value: []byte("b"), Version: 0, @@ -191,7 +191,7 @@ func TestNewJournalEntryDataTrieUpdates_NilAccountShouldErr(t *testing.T) { func TestNewJournalEntryDataTrieUpdates_EmptyTrieUpdatesShouldErr(t *testing.T) { t.Parallel() - trieUpdates := make([]common.TrieData, 0) + trieUpdates := make([]core.TrieData, 0) args := state.ArgsAccountCreation{ Hasher: &hashingMocks.HasherMock{}, Marshaller: &marshallerMock.MarshalizerMock{}, @@ -207,8 +207,8 @@ func TestNewJournalEntryDataTrieUpdates_EmptyTrieUpdatesShouldErr(t *testing.T) func TestNewJournalEntryDataTrieUpdates_OkValsShouldWork(t *testing.T) { t.Parallel() - trieUpdates := make([]common.TrieData, 0) - trieUpdates = append(trieUpdates, common.TrieData{ + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ Key: []byte("a"), Value: []byte("b"), Version: 0, @@ -230,8 +230,8 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenUpdateFails(t *testing.T) { expectedErr := errors.New("error") - trieUpdates := make([]common.TrieData, 0) - trieUpdates = append(trieUpdates, common.TrieData{ + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ Key: []byte("a"), Value: []byte("b"), Version: 0, @@ -239,7 +239,7 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenUpdateFails(t *testing.T) { accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return expectedErr }, } @@ -257,8 +257,8 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenAccountRootFails(t *testing. expectedErr := errors.New("error") - trieUpdates := make([]common.TrieData, 0) - trieUpdates = append(trieUpdates, common.TrieData{ + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ Key: []byte("a"), Value: []byte("b"), Version: 0, @@ -266,7 +266,7 @@ func TestJournalEntryDataTrieUpdates_RevertFailsWhenAccountRootFails(t *testing. accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, RootCalled: func() ([]byte, error) { @@ -288,8 +288,8 @@ func TestJournalEntryDataTrieUpdates_RevertShouldWork(t *testing.T) { updateWasCalled := false rootWasCalled := false - trieUpdates := make([]common.TrieData, 0) - trieUpdates = append(trieUpdates, common.TrieData{ + trieUpdates := make([]core.TrieData, 0) + trieUpdates = append(trieUpdates, core.TrieData{ Key: []byte("a"), Value: []byte("b"), Version: 0, @@ -297,7 +297,7 @@ func TestJournalEntryDataTrieUpdates_RevertShouldWork(t *testing.T) { accnt := stateMock.NewAccountWrapMock(nil) tr := &trieMock.TrieStub{ - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { updateWasCalled = true return nil }, diff --git a/state/parsers/dataTrieLeafParser.go b/state/parsers/dataTrieLeafParser.go index ebc2f98b004..6437fbb55b9 100644 --- a/state/parsers/dataTrieLeafParser.go +++ b/state/parsers/dataTrieLeafParser.go @@ -33,8 +33,8 @@ func NewDataTrieLeafParser(address []byte, marshaller marshal.Marshalizer, enabl } // ParseLeaf returns a new KeyValStorage with the actual key and value -func (tlp *dataTrieLeafParser) ParseLeaf(trieKey []byte, trieVal []byte, version common.TrieNodeVersion) (core.KeyValueHolder, error) { - if tlp.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() && version == common.AutoBalanceEnabled { +func (tlp *dataTrieLeafParser) ParseLeaf(trieKey []byte, trieVal []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) { + if tlp.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() && version == core.AutoBalanceEnabled { data := &dataTrieValue.TrieLeafData{} err := tlp.marshaller.Unmarshal(data, trieVal) if err != nil { diff --git a/state/parsers/dataTrieLeafParser_test.go b/state/parsers/dataTrieLeafParser_test.go index 0afeda055bd..ba18aa0e6c0 100644 --- a/state/parsers/dataTrieLeafParser_test.go +++ b/state/parsers/dataTrieLeafParser_test.go @@ -4,9 +4,9 @@ import ( "encoding/hex" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state/dataTrieValue" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -55,7 +55,7 @@ func TestTrieLeafParser_ParseLeaf(t *testing.T) { suffix := append(key, address...) tlp, _ := NewDataTrieLeafParser(address, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) - keyVal, err := tlp.ParseLeaf(key, append(val, suffix...), common.NotSpecified) + keyVal, err := tlp.ParseLeaf(key, append(val, suffix...), core.NotSpecified) assert.Nil(t, err) assert.Equal(t, key, keyVal.Key()) assert.Equal(t, val, keyVal.Value()) @@ -73,7 +73,7 @@ func TestTrieLeafParser_ParseLeaf(t *testing.T) { } tlp, _ := NewDataTrieLeafParser(address, &marshallerMock.MarshalizerMock{}, enableEpochsHandler) - keyVal, err := tlp.ParseLeaf(key, append(val, suffix...), common.NotSpecified) + keyVal, err := tlp.ParseLeaf(key, append(val, suffix...), core.NotSpecified) assert.Nil(t, err) assert.Equal(t, key, keyVal.Key()) assert.Equal(t, val, keyVal.Value()) @@ -98,7 +98,7 @@ func TestTrieLeafParser_ParseLeaf(t *testing.T) { } tlp, _ := NewDataTrieLeafParser(address, marshaller, enableEpochsHandler) - keyVal, err := tlp.ParseLeaf(hasher.Compute(string(key)), serializedLeafData, common.AutoBalanceEnabled) + keyVal, err := tlp.ParseLeaf(hasher.Compute(string(key)), serializedLeafData, core.AutoBalanceEnabled) assert.Nil(t, err) assert.Equal(t, key, keyVal.Key()) assert.Equal(t, val, keyVal.Value()) @@ -122,7 +122,7 @@ func TestTrieLeafParser_ParseLeaf(t *testing.T) { } tlp, _ := NewDataTrieLeafParser(addrBytes, marshaller, enableEpochsHandler) - keyVal, err := tlp.ParseLeaf(keyBytes, valWithAppendedData, common.NotSpecified) + keyVal, err := tlp.ParseLeaf(keyBytes, valWithAppendedData, core.NotSpecified) assert.Nil(t, err) assert.Equal(t, keyBytes, keyVal.Key()) assert.Equal(t, valBytes, keyVal.Value()) diff --git a/state/parsers/mainTrieLeafParser.go b/state/parsers/mainTrieLeafParser.go index 74e1a849cfe..8835608fd7c 100644 --- a/state/parsers/mainTrieLeafParser.go +++ b/state/parsers/mainTrieLeafParser.go @@ -3,7 +3,6 @@ package parsers import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/keyValStorage" - "github.com/multiversx/mx-chain-go/common" ) type mainTrieLeafParser struct { @@ -15,7 +14,7 @@ func NewMainTrieLeafParser() *mainTrieLeafParser { } // ParseLeaf returns the given key an value as a KeyValStorage -func (tlp *mainTrieLeafParser) ParseLeaf(trieKey []byte, trieVal []byte, _ common.TrieNodeVersion) (core.KeyValueHolder, error) { +func (tlp *mainTrieLeafParser) ParseLeaf(trieKey []byte, trieVal []byte, _ core.TrieNodeVersion) (core.KeyValueHolder, error) { return keyValStorage.NewKeyValStorage(trieKey, trieVal), nil } diff --git a/state/parsers/mainTrieLeafParser_test.go b/state/parsers/mainTrieLeafParser_test.go index aa88f4b4f0c..fc94dcc8ae6 100644 --- a/state/parsers/mainTrieLeafParser_test.go +++ b/state/parsers/mainTrieLeafParser_test.go @@ -3,8 +3,8 @@ package parsers import ( "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/common" "github.com/stretchr/testify/assert" ) @@ -24,7 +24,7 @@ func TestNewMainTrieLeafParser(t *testing.T) { value := []byte("value") dtlp := NewMainTrieLeafParser() - keyValHolder, err := dtlp.ParseLeaf(key, value, common.NotSpecified) + keyValHolder, err := dtlp.ParseLeaf(key, value, core.NotSpecified) assert.Nil(t, err) assert.Equal(t, key, keyValHolder.Key()) assert.Equal(t, value, keyValHolder.Value()) diff --git a/state/trackableDataTrie.go b/state/trackableDataTrie.go index 5639ff0a839..6b8ff6bd57d 100644 --- a/state/trackableDataTrie.go +++ b/state/trackableDataTrie.go @@ -9,12 +9,25 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + errorsCommon "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state/dataTrieValue" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) +type optionalVersion struct { + version core.TrieNodeVersion + isValueKnown bool +} + +type dirtyData struct { + value []byte + oldVersion optionalVersion + newVersion core.TrieNodeVersion +} + // TrackableDataTrie wraps a PatriciaMerkelTrie adding modifying data capabilities type trackableDataTrie struct { - dirtyData map[string][]byte + dirtyData map[string]dirtyData tr common.Trie hasher hashing.Hasher marshaller marshal.Marshalizer @@ -44,7 +57,7 @@ func NewTrackableDataTrie( tr: tr, hasher: hasher, marshaller: marshaller, - dirtyData: make(map[string][]byte), + dirtyData: make(map[string]dirtyData), identifier: identifier, enableEpochsHandler: enableEpochsHandler, }, nil @@ -55,9 +68,9 @@ func NewTrackableDataTrie( // Data must have been retrieved from its trie func (tdaw *trackableDataTrie) RetrieveValue(key []byte) ([]byte, uint32, error) { // search in dirty data cache - if value, found := tdaw.dirtyData[string(key)]; found { - log.Trace("retrieve value from dirty data", "key", key, "value", value) - return value, 0, nil + if dataEntry, found := tdaw.dirtyData[string(key)]; found { + log.Trace("retrieve value from dirty data", "key", key, "value", dataEntry.value) + return dataEntry.value, 0, nil } // ok, not in cache, retrieve from trie @@ -110,10 +123,91 @@ func (tdaw *trackableDataTrie) SaveKeyValue(key []byte, value []byte) error { return data.ErrLeafSizeTooBig } - tdaw.dirtyData[string(key)] = value + dataEntry := dirtyData{ + value: value, + oldVersion: optionalVersion{ + isValueKnown: false, + }, + newVersion: tdaw.getVersionForNewlyAddedData(), + } + + tdaw.dirtyData[string(key)] = dataEntry return nil } +// MigrateDataTrieLeaves migrates the data trie leaves from oldVersion to newVersion +func (tdaw *trackableDataTrie) MigrateDataTrieLeaves(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error { + if check.IfNil(tdaw.tr) { + return ErrNilTrie + } + if check.IfNil(trieMigrator) { + return errorsCommon.ErrNilTrieMigrator + } + + dtr, ok := tdaw.tr.(dataTrie) + if !ok { + return fmt.Errorf("invalid trie, type is %T", tdaw.tr) + } + + err := dtr.CollectLeavesForMigration(oldVersion, newVersion, trieMigrator) + if err != nil { + return err + } + + dataToBeMigrated := trieMigrator.GetLeavesToBeMigrated() + for _, leafData := range dataToBeMigrated { + dataEntry := dirtyData{ + value: leafData.Value, + oldVersion: optionalVersion{ + version: leafData.Version, + isValueKnown: true, + }, + newVersion: newVersion, + } + + tdaw.dirtyData[string(leafData.Key)] = dataEntry + } + + return nil +} + +func (tdaw *trackableDataTrie) getVersionForNewlyAddedData() core.TrieNodeVersion { + if tdaw.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { + return core.AutoBalanceEnabled + } + + return core.NotSpecified +} + +func (tdaw *trackableDataTrie) getKeyForVersion(key []byte, version core.TrieNodeVersion) []byte { + if version == core.AutoBalanceEnabled { + return tdaw.hasher.Compute(string(key)) + } + + return key +} + +func (tdaw *trackableDataTrie) getValueForVersion(key []byte, value []byte, version core.TrieNodeVersion) ([]byte, error) { + if len(value) == 0 { + return nil, nil + } + + if version == core.AutoBalanceEnabled { + trieVal := &dataTrieValue.TrieLeafData{ + Value: value, + Key: key, + Address: tdaw.identifier, + } + + return tdaw.marshaller.Marshal(trieVal) + } + + identifier := append(key, tdaw.identifier...) + valueWithAppendedData := append(value, identifier...) + + return valueWithAppendedData, nil +} + // SetDataTrie sets the internal data trie func (tdaw *trackableDataTrie) SetDataTrie(tr common.Trie) { tdaw.tr = tr @@ -125,9 +219,9 @@ func (tdaw *trackableDataTrie) DataTrie() common.DataTrieHandler { } // SaveDirtyData saved the dirty data to the trie -func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]common.TrieData, error) { +func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]core.TrieData, error) { if len(tdaw.dirtyData) == 0 { - return make([]common.TrieData, 0), nil + return make([]core.TrieData, 0), nil } if check.IfNil(tdaw.tr) { @@ -144,38 +238,24 @@ func (tdaw *trackableDataTrie) SaveDirtyData(mainTrie common.Trie) ([]common.Tri return nil, fmt.Errorf("invalid trie, type is %T", tdaw.tr) } - if tdaw.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { - return tdaw.updateTrieWithAutoBalancing(dtr) - } - - return tdaw.updateTrieV1(dtr) + return tdaw.updateTrie(dtr) } -func (tdaw *trackableDataTrie) updateTrieV1(selfDataTrie dataTrie) ([]common.TrieData, error) { - oldValues := make([]common.TrieData, len(tdaw.dirtyData)) +func (tdaw *trackableDataTrie) updateTrie(dtr dataTrie) ([]core.TrieData, error) { + oldValues := make([]core.TrieData, len(tdaw.dirtyData)) index := 0 - for key, val := range tdaw.dirtyData { - oldVal, _, err := tdaw.tr.Get([]byte(key)) + for key, dataEntry := range tdaw.dirtyData { + // TODO cache old value if it was previously retrieved from the trie + oldVal := tdaw.getOldValue([]byte(key), dataEntry) + oldValues[index] = oldVal + + err := tdaw.deleteOldEntryIfMigrated([]byte(key), dataEntry, oldVal) if err != nil { return nil, err } - oldEntry := common.TrieData{ - Key: []byte(key), - Value: oldVal, - Version: common.NotSpecified, - } - oldValues[index] = oldEntry - - var identifier []byte - if len(val) != 0 { - identifier = append([]byte(key), tdaw.identifier...) - } - - valueWithAppendedData := append(val, identifier...) - - err = selfDataTrie.UpdateWithVersion([]byte(key), valueWithAppendedData, common.NotSpecified) + err = tdaw.modifyTrie([]byte(key), dataEntry, oldVal, dtr) if err != nil { return nil, err } @@ -183,90 +263,90 @@ func (tdaw *trackableDataTrie) updateTrieV1(selfDataTrie dataTrie) ([]common.Tri index++ } - tdaw.dirtyData = make(map[string][]byte) + tdaw.dirtyData = make(map[string]dirtyData) + return oldValues, nil } -// TODO refactor to make the migration more generic. This code should be able to migrate between specified versions. - -func (tdaw *trackableDataTrie) updateTrieWithAutoBalancing(dtr dataTrie) ([]common.TrieData, error) { - oldValues := make([]common.TrieData, len(tdaw.dirtyData)) - - index := 0 - for key, val := range tdaw.dirtyData { - oldEntry, err := tdaw.getOldKeyAndValWithCleanup(key) - if err != nil { - return nil, err +func (tdaw *trackableDataTrie) getOldValue(key []byte, dataEntry dirtyData) core.TrieData { + if dataEntry.oldVersion.isValueKnown { + return core.TrieData{ + Key: key, + Value: dataEntry.value, + Version: dataEntry.oldVersion.version, } + } - oldValues[index] = oldEntry - - err = tdaw.updateValInTrieWithAutoBalancing([]byte(key), val, dtr) - if err != nil { - return nil, err + if tdaw.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { + hashedKey := tdaw.hasher.Compute(string(key)) + oldVal, _, err := tdaw.tr.Get(hashedKey) + if err == nil && len(oldVal) != 0 { + return core.TrieData{ + Key: hashedKey, + Value: oldVal, + Version: core.AutoBalanceEnabled, + } } + } - index++ + oldVal, _, err := tdaw.tr.Get(key) + if err == nil && len(oldVal) != 0 { + return core.TrieData{ + Key: key, + Value: oldVal, + Version: core.NotSpecified, + } } - tdaw.dirtyData = make(map[string][]byte) - return oldValues, nil + newDataVersion := tdaw.getVersionForNewlyAddedData() + return core.TrieData{ + Key: tdaw.getKeyForVersion(key, newDataVersion), + Value: nil, + Version: newDataVersion, + } } -func (tdaw *trackableDataTrie) getOldKeyAndValWithCleanup(key string) (common.TrieData, error) { - hashedKey := tdaw.hasher.Compute(key) - - oldVal, _, err := tdaw.tr.Get(hashedKey) - if err == nil && len(oldVal) != 0 { - return common.TrieData{ - Key: hashedKey, - Value: oldVal, - Version: common.AutoBalanceEnabled, - }, nil +func (tdaw *trackableDataTrie) deleteOldEntryIfMigrated(key []byte, newData dirtyData, oldEntry core.TrieData) error { + if !tdaw.enableEpochsHandler.IsAutoBalanceDataTriesEnabled() { + return nil } - oldVal, _, err = tdaw.tr.Get([]byte(key)) - if err != nil { - return common.TrieData{}, err + if oldEntry.Version == core.NotSpecified && newData.newVersion == core.AutoBalanceEnabled { + return tdaw.tr.Delete(key) } - if len(oldVal) == 0 { - return common.TrieData{ - Key: hashedKey, - Value: nil, - Version: common.NotSpecified, - }, nil + return nil +} + +func (tdaw *trackableDataTrie) modifyTrie(key []byte, dataEntry dirtyData, oldVal core.TrieData, dtr dataTrie) error { + if len(dataEntry.value) == 0 { + return tdaw.deleteFromTrie(oldVal, key, dtr) } - err = tdaw.tr.Delete([]byte(key)) + version := dataEntry.newVersion + newKey := tdaw.getKeyForVersion(key, version) + value, err := tdaw.getValueForVersion(key, dataEntry.value, version) if err != nil { - return common.TrieData{}, err + return err } - return common.TrieData{ - Key: []byte(key), - Value: oldVal, - Version: common.NotSpecified, - }, nil + return dtr.UpdateWithVersion(newKey, value, version) } -func (tdaw *trackableDataTrie) updateValInTrieWithAutoBalancing(key []byte, val []byte, selfDataTrie dataTrie) error { - if len(val) == 0 { - return tdaw.tr.Delete(tdaw.hasher.Compute(string(key))) +func (tdaw *trackableDataTrie) deleteFromTrie(oldVal core.TrieData, key []byte, dtr dataTrie) error { + if len(oldVal.Value) == 0 { + return nil } - trieVal := &dataTrieValue.TrieLeafData{ - Value: val, - Key: key, - Address: tdaw.identifier, + if oldVal.Version == core.AutoBalanceEnabled { + return dtr.Delete(tdaw.hasher.Compute(string(key))) } - serializedTrieVal, err := tdaw.marshaller.Marshal(trieVal) - if err != nil { - return err + if oldVal.Version == core.NotSpecified { + return dtr.Delete(key) } - return selfDataTrie.UpdateWithVersion(tdaw.hasher.Compute(string(key)), serializedTrieVal, common.AutoBalanceEnabled) + return nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/state/trackableDataTrie_test.go b/state/trackableDataTrie_test.go index a21bdd2f138..8f75c93eb3b 100644 --- a/state/trackableDataTrie_test.go +++ b/state/trackableDataTrie_test.go @@ -8,12 +8,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/common" + errorsCommon "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/dataTrieValue" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/pkg/errors" "github.com/stretchr/testify/assert" ) @@ -88,7 +90,7 @@ func TestTrackableDataTrie_SaveKeyValue(t *testing.T) { dirtyData := tdt.DirtyData() assert.Equal(t, 1, len(dirtyData)) - assert.Equal(t, value, dirtyData[string(keyExpected)]) + assert.Equal(t, value, dirtyData[string(keyExpected)].Value) }) } @@ -276,7 +278,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { GetCalled: func(_ []byte) ([]byte, uint32, error) { return nil, 0, nil }, - UpdateWithVersionCalled: func(_, _ []byte, _ common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(_, _ []byte, _ core.TrieNodeVersion) error { return nil }, }, nil @@ -321,7 +323,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { } return nil, 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { assert.Equal(t, hasher.Compute(string(expectedKey)), key) assert.Equal(t, serializedTrieVal, value) updateCalled = true @@ -368,7 +370,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { } return nil, 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { assert.Equal(t, expectedKey, key) assert.Equal(t, expectedVal, value) updateCalled = true @@ -426,7 +428,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { } return nil, 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { assert.Equal(t, hasher.Compute(string(expectedKey)), key) assert.Equal(t, serializedNewTrieVal, value) updateCalled = true @@ -473,7 +475,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { GetCalled: func(key []byte) ([]byte, uint32, error) { return nil, 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { assert.Equal(t, hasher.Compute(string(expectedKey)), key) assert.Equal(t, serializedNewTrieVal, value) updateCalled = true @@ -509,7 +511,7 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { GetCalled: func(key []byte) ([]byte, uint32, error) { return nil, 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { + UpdateWithVersionCalled: func(key, value []byte, version core.TrieNodeVersion) error { return nil }, } @@ -529,10 +531,9 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { updateCalled := false trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { - return nil, 0, nil + return []byte("value"), 0, nil }, - UpdateWithVersionCalled: func(key, value []byte, version common.TrieNodeVersion) error { - assert.Nil(t, value) + DeleteCalled: func(key []byte) error { assert.Equal(t, expectedKey, key) updateCalled = true return nil @@ -548,19 +549,48 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { assert.True(t, updateCalled) }) - t.Run("nil val autobalance enabled", func(t *testing.T) { + t.Run("nil val and nil old val", func(t *testing.T) { + t.Parallel() + + expectedKey := []byte("key") + deleteCalled := false + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + return nil, 0, nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, expectedKey, key) + deleteCalled = true + return nil + }, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + + _ = tdt.SaveKeyValue(expectedKey, nil) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + assert.False(t, deleteCalled) + }) + + t.Run("nil val autobalance enabled, old val saved at hashedKey", func(t *testing.T) { t.Parallel() hasher := &hashingMocks.HasherMock{} expectedKey := []byte("key") - updateCalled := false + deleteCalled := false trie := &trieMock.TrieStub{ GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(hasher.Compute(string(expectedKey)), key) { + return []byte("value"), 0, nil + } + return nil, 0, nil }, DeleteCalled: func(key []byte) error { assert.Equal(t, hasher.Compute(string(expectedKey)), key) - updateCalled = true + deleteCalled = true return nil }, } @@ -574,7 +604,122 @@ func TestTrackableDataTrie_SaveDirtyData(t *testing.T) { _, err := tdt.SaveDirtyData(trie) assert.Nil(t, err) assert.Equal(t, 0, len(tdt.DirtyData())) - assert.True(t, updateCalled) + assert.True(t, deleteCalled) + }) + + t.Run("nil val autobalance enabled, old val saved at key", func(t *testing.T) { + t.Parallel() + + expectedKey := []byte("key") + deleteCalled := false + trie := &trieMock.TrieStub{ + GetCalled: func(key []byte) ([]byte, uint32, error) { + if bytes.Equal(expectedKey, key) { + return []byte("value"), 0, nil + } + + return nil, 0, nil + }, + DeleteCalled: func(key []byte) error { + assert.Equal(t, expectedKey, key) + deleteCalled = true + return nil + }, + } + + enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), trie, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs) + + _ = tdt.SaveKeyValue(expectedKey, nil) + _, err := tdt.SaveDirtyData(trie) + assert.Nil(t, err) + assert.Equal(t, 0, len(tdt.DirtyData())) + assert.True(t, deleteCalled) + }) +} + +func TestTrackableDataTrie_MigrateDataTrieLeaves(t *testing.T) { + t.Parallel() + + t.Run("nil trie", func(t *testing.T) { + t.Parallel() + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), nil, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + err := tdt.MigrateDataTrieLeaves(core.NotSpecified, core.AutoBalanceEnabled, &trieMock.DataTrieMigratorStub{}) + assert.Equal(t, state.ErrNilTrie, err) + }) + + t.Run("nil trie migrator", func(t *testing.T) { + t.Parallel() + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), &trieMock.TrieStub{}, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + err := tdt.MigrateDataTrieLeaves(core.NotSpecified, core.AutoBalanceEnabled, nil) + assert.Equal(t, errorsCommon.ErrNilTrieMigrator, err) + }) + + t.Run("CollectLeavesForMigrationFails", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + tr := &trieMock.TrieStub{ + CollectLeavesForMigrationCalled: func(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error { + return expectedErr + }, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), tr, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}) + err := tdt.MigrateDataTrieLeaves(core.NotSpecified, core.AutoBalanceEnabled, &trieMock.DataTrieMigratorStub{}) + assert.Equal(t, expectedErr, err) + }) + + t.Run("leaves that need to be migrated are added to dirty data", func(t *testing.T) { + t.Parallel() + + leavesToBeMigrated := []core.TrieData{ + { + Key: []byte("key1"), + Value: []byte("value1"), + Version: core.AutoBalanceEnabled, + }, + { + Key: []byte("key2"), + Value: []byte("value2"), + Version: core.AutoBalanceEnabled, + }, + { + Key: []byte("key3"), + Value: []byte("value3"), + Version: core.AutoBalanceEnabled, + }, + } + tr := &trieMock.TrieStub{ + CollectLeavesForMigrationCalled: func(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error { + return nil + }, + } + dtm := &trieMock.DataTrieMigratorStub{ + GetLeavesToBeMigratedCalled: func() []core.TrieData { + return leavesToBeMigrated + }, + } + enableEpchs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + } + + tdt, _ := state.NewTrackableDataTrie([]byte("identifier"), tr, &hashingMocks.HasherMock{}, &marshallerMock.MarshalizerMock{}, enableEpchs) + err := tdt.MigrateDataTrieLeaves(core.NotSpecified, 100, dtm) + assert.Nil(t, err) + + dirtyData := tdt.DirtyData() + assert.Equal(t, len(leavesToBeMigrated), len(dirtyData)) + for i := range leavesToBeMigrated { + d := dirtyData[string(leavesToBeMigrated[i].Key)] + assert.Equal(t, leavesToBeMigrated[i].Value, d.Value) + assert.Equal(t, leavesToBeMigrated[i].Version, d.OldVersion) + assert.Equal(t, core.TrieNodeVersion(100), d.NewVersion) + } }) } diff --git a/testscommon/state/accountWrapperMock.go b/testscommon/state/accountWrapperMock.go index 0229c415030..002d586d016 100644 --- a/testscommon/state/accountWrapperMock.go +++ b/testscommon/state/accountWrapperMock.go @@ -6,6 +6,7 @@ import ( "fmt" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -188,7 +189,7 @@ func (awm *AccountWrapMock) DataTrie() common.DataTrieHandler { } // SaveDirtyData - -func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) ([]common.TrieData, error) { +func (awm *AccountWrapMock) SaveDirtyData(trie common.Trie) ([]core.TrieData, error) { return awm.trackableDataTrie.SaveDirtyData(trie) } diff --git a/testscommon/state/userAccountStub.go b/testscommon/state/userAccountStub.go index 95f63fde3ff..65b5cb8b06e 100644 --- a/testscommon/state/userAccountStub.go +++ b/testscommon/state/userAccountStub.go @@ -5,6 +5,7 @@ import ( "context" "math/big" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" ) @@ -157,7 +158,7 @@ func (u *UserAccountStub) SaveKeyValue(_ []byte, _ []byte) error { } // SaveDirtyData - -func (u *UserAccountStub) SaveDirtyData(_ common.Trie) ([]common.TrieData, error) { +func (u *UserAccountStub) SaveDirtyData(_ common.Trie) ([]core.TrieData, error) { return nil, nil } diff --git a/testscommon/trie/dataTrieMigratorStub.go b/testscommon/trie/dataTrieMigratorStub.go new file mode 100644 index 00000000000..57bab03dbc8 --- /dev/null +++ b/testscommon/trie/dataTrieMigratorStub.go @@ -0,0 +1,44 @@ +package trie + +import ( + "github.com/multiversx/mx-chain-core-go/core" +) + +// DataTrieMigratorStub - +type DataTrieMigratorStub struct { + ConsumeStorageLoadGasCalled func() bool + AddLeafToMigrationQueueCalled func(leafData core.TrieData, newLeafVersion core.TrieNodeVersion) (bool, error) + GetLeavesToBeMigratedCalled func() []core.TrieData +} + +// ConsumeStorageLoadGas - +func (d *DataTrieMigratorStub) ConsumeStorageLoadGas() bool { + if d.ConsumeStorageLoadGasCalled != nil { + return d.ConsumeStorageLoadGasCalled() + } + + return true +} + +// AddLeafToMigrationQueue - +func (d *DataTrieMigratorStub) AddLeafToMigrationQueue(leafData core.TrieData, newLeafVersion core.TrieNodeVersion) (bool, error) { + if d.AddLeafToMigrationQueueCalled != nil { + return d.AddLeafToMigrationQueueCalled(leafData, newLeafVersion) + } + + return true, nil +} + +// GetLeavesToBeMigrated - +func (d *DataTrieMigratorStub) GetLeavesToBeMigrated() []core.TrieData { + if d.GetLeavesToBeMigratedCalled != nil { + return d.GetLeavesToBeMigratedCalled() + } + + return nil +} + +// IsInterfaceNil - +func (d *DataTrieMigratorStub) IsInterfaceNil() bool { + return d == nil +} diff --git a/testscommon/trie/dataTrieTrackerStub.go b/testscommon/trie/dataTrieTrackerStub.go index eec79372bd8..94d541ccb6a 100644 --- a/testscommon/trie/dataTrieTrackerStub.go +++ b/testscommon/trie/dataTrieTrackerStub.go @@ -1,16 +1,20 @@ package trie import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // DataTrieTrackerStub - type DataTrieTrackerStub struct { - RetrieveValueCalled func(key []byte) ([]byte, uint32, error) - SaveKeyValueCalled func(key []byte, value []byte) error - SetDataTrieCalled func(tr common.Trie) - DataTrieCalled func() common.Trie - SaveDirtyDataCalled func(trie common.Trie) ([]common.TrieData, error) + RetrieveValueCalled func(key []byte) ([]byte, uint32, error) + SaveKeyValueCalled func(key []byte, value []byte) error + SetDataTrieCalled func(tr common.Trie) + DataTrieCalled func() common.Trie + SaveDirtyDataCalled func(trie common.Trie) ([]core.TrieData, error) + SaveTrieDataCalled func(trieData core.TrieData) error + MigrateDataTrieLeavesCalled func(version core.TrieNodeVersion, newVersion core.TrieNodeVersion, migrator vmcommon.DataTrieMigrator) error } // RetrieveValue - @@ -48,12 +52,21 @@ func (dtts *DataTrieTrackerStub) DataTrie() common.DataTrieHandler { } // SaveDirtyData - -func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) ([]common.TrieData, error) { +func (dtts *DataTrieTrackerStub) SaveDirtyData(mainTrie common.Trie) ([]core.TrieData, error) { if dtts.SaveDirtyDataCalled != nil { return dtts.SaveDirtyDataCalled(mainTrie) } - return make([]common.TrieData, 0), nil + return make([]core.TrieData, 0), nil +} + +// MigrateDataTrieLeaves - +func (dtts *DataTrieTrackerStub) MigrateDataTrieLeaves(version core.TrieNodeVersion, newVersion core.TrieNodeVersion, migrator vmcommon.DataTrieMigrator) error { + if dtts.MigrateDataTrieLeavesCalled != nil { + return dtts.MigrateDataTrieLeavesCalled(version, newVersion, migrator) + } + + return nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/testscommon/trie/trieStub.go b/testscommon/trie/trieStub.go index d5c97d74b0b..ba17991b86e 100644 --- a/testscommon/trie/trieStub.go +++ b/testscommon/trie/trieStub.go @@ -4,7 +4,9 @@ import ( "context" "errors" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var errNotImplemented = errors.New("not implemented") @@ -13,7 +15,7 @@ var errNotImplemented = errors.New("not implemented") type TrieStub struct { GetCalled func(key []byte) ([]byte, uint32, error) UpdateCalled func(key, value []byte) error - UpdateWithVersionCalled func(key, value []byte, version common.TrieNodeVersion) error + UpdateWithVersionCalled func(key, value []byte, version core.TrieNodeVersion) error DeleteCalled func(key []byte) error RootCalled func() ([]byte, error) CommitCalled func() error @@ -30,6 +32,7 @@ type TrieStub struct { GetSerializedNodeCalled func(bytes []byte) ([]byte, error) GetOldRootCalled func() []byte CloseCalled func() error + CollectLeavesForMigrationCalled func(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error } // GetStorageManager - @@ -87,7 +90,7 @@ func (ts *TrieStub) Update(key, value []byte) error { } // UpdateWithVersion - -func (ts *TrieStub) UpdateWithVersion(key []byte, value []byte, version common.TrieNodeVersion) error { +func (ts *TrieStub) UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error { if ts.UpdateWithVersionCalled != nil { return ts.UpdateWithVersionCalled(key, value, version) } @@ -95,6 +98,15 @@ func (ts *TrieStub) UpdateWithVersion(key []byte, value []byte, version common.T return errNotImplemented } +// CollectLeavesForMigration - +func (ts *TrieStub) CollectLeavesForMigration(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error { + if ts.CollectLeavesForMigrationCalled != nil { + return ts.CollectLeavesForMigrationCalled(oldVersion, newVersion, trieMigrator) + } + + return errNotImplemented +} + // Delete - func (ts *TrieStub) Delete(key []byte) error { if ts.DeleteCalled != nil { diff --git a/trie/branchNode.go b/trie/branchNode.go index e20e910acfe..fe32ee0c605 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/errors" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var _ = node(&branchNode{}) @@ -42,10 +43,10 @@ func newBranchNode(marshalizer marshal.Marshalizer, hasher hashing.Hasher) (*bra }, nil } -func (bn *branchNode) setVersionForChild(version common.TrieNodeVersion, childPos byte) { +func (bn *branchNode) setVersionForChild(version core.TrieNodeVersion, childPos byte) { sliceNotInitialized := len(bn.ChildrenVersion) == 0 - if version == common.NotSpecified && sliceNotInitialized { + if version == core.NotSpecified && sliceNotInitialized { return } @@ -489,7 +490,7 @@ func (bn *branchNode) getNext(key []byte, db common.DBWriteCacher) (node, []byte return bn.children[childPos], key, nil } -func (bn *branchNode) insert(newData common.TrieData, db common.DBWriteCacher) (node, [][]byte, error) { +func (bn *branchNode) insert(newData core.TrieData, db common.DBWriteCacher) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := bn.isEmptyOrNil() if err != nil { @@ -517,7 +518,7 @@ func (bn *branchNode) insert(newData common.TrieData, db common.DBWriteCacher) ( return bn.insertOnExistingChild(newData, childPos, db) } -func (bn *branchNode) insertOnNilChild(newData common.TrieData, childPos byte) (node, [][]byte, error) { +func (bn *branchNode) insertOnNilChild(newData core.TrieData, childPos byte) (node, [][]byte, error) { newLn, err := newLeafNode(newData, bn.marsh, bn.hasher) if err != nil { return nil, [][]byte{}, err @@ -532,7 +533,7 @@ func (bn *branchNode) insertOnNilChild(newData common.TrieData, childPos byte) ( return bn, modifiedHashes, nil } -func (bn *branchNode) insertOnExistingChild(newData common.TrieData, childPos byte, db common.DBWriteCacher) (node, [][]byte, error) { +func (bn *branchNode) insertOnExistingChild(newData core.TrieData, childPos byte, db common.DBWriteCacher) (node, [][]byte, error) { newNode, modifiedHashes, err := bn.children[childPos].insert(newData, db) if check.IfNil(newNode) || err != nil { return nil, [][]byte{}, err @@ -945,9 +946,9 @@ func (bn *branchNode) collectStats(ts common.TrieStatisticsHandler, depthLevel i return nil } -func (bn *branchNode) getVersion() (common.TrieNodeVersion, error) { +func (bn *branchNode) getVersion() (core.TrieNodeVersion, error) { if len(bn.ChildrenVersion) == 0 { - return common.NotSpecified, nil + return core.NotSpecified, nil } index := 0 @@ -968,11 +969,68 @@ func (bn *branchNode) getVersion() (common.TrieNodeVersion, error) { } if bn.ChildrenVersion[i] != nodeVersion { - return common.NotSpecified, nil + return core.NotSpecified, nil } } - return common.TrieNodeVersion(nodeVersion), nil + return core.TrieNodeVersion(nodeVersion), nil +} + +func (bn *branchNode) getVersionForChild(childIndex byte) core.TrieNodeVersion { + if len(bn.ChildrenVersion) == 0 { + return core.NotSpecified + } + + return core.TrieNodeVersion(bn.ChildrenVersion[childIndex]) +} + +func (bn *branchNode) collectLeavesForMigration( + oldVersion core.TrieNodeVersion, + newVersion core.TrieNodeVersion, + trieMigrator vmcommon.DataTrieMigrator, + db common.DBWriteCacher, + keyBuilder common.KeyBuilder, +) (bool, error) { + shouldContinue := trieMigrator.ConsumeStorageLoadGas() + if !shouldContinue { + return false, nil + } + + shouldMigrateNode, err := shouldMigrateCurrentNode(bn, oldVersion, newVersion) + if err != nil { + return false, err + } + if !shouldMigrateNode { + return true, nil + } + + for i := range bn.children { + if bn.children[i] == nil && len(bn.EncodedChildren[i]) == 0 { + continue + } + + if bn.getVersionForChild(byte(i)) != oldVersion { + continue + } + + err = resolveIfCollapsed(bn, byte(i), db) + if err != nil { + return false, err + } + + clonedKeyBuilder := keyBuilder.Clone() + clonedKeyBuilder.BuildKey([]byte{byte(i)}) + shouldContinueMigrating, err := bn.children[i].collectLeavesForMigration(oldVersion, newVersion, trieMigrator, db, clonedKeyBuilder) + if err != nil { + return false, err + } + + if !shouldContinueMigrating { + return false, nil + } + } + + return true, nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index b58e6d558d5..2229a7c593d 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -7,6 +7,7 @@ import ( "fmt" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/mock" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" @@ -29,11 +30,11 @@ func getTestMarshalizerAndHasher() (marshal.Marshalizer, hashing.Hasher) { return marsh, hash } -func getTrieDataWithDefaultVersion(key string, val string) common.TrieData { - return common.TrieData{ +func getTrieDataWithDefaultVersion(key string, val string) core.TrieData { + return core.TrieData{ Key: []byte(key), Value: []byte(val), - Version: common.NotSpecified, + Version: core.NotSpecified, } } @@ -1402,7 +1403,7 @@ func TestBranchNode_getVersion(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) version, err := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Nil(t, err) }) @@ -1411,12 +1412,12 @@ func TestBranchNode_getVersion(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) bn.ChildrenVersion = make([]byte, nrOfChildren) - bn.ChildrenVersion[2] = byte(common.NotSpecified) - bn.ChildrenVersion[6] = byte(common.NotSpecified) - bn.ChildrenVersion[13] = byte(common.NotSpecified) + bn.ChildrenVersion[2] = byte(core.NotSpecified) + bn.ChildrenVersion[6] = byte(core.NotSpecified) + bn.ChildrenVersion[13] = byte(core.NotSpecified) version, err := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Nil(t, err) }) @@ -1425,12 +1426,12 @@ func TestBranchNode_getVersion(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) bn.ChildrenVersion = make([]byte, nrOfChildren) - bn.ChildrenVersion[2] = byte(common.NotSpecified) - bn.ChildrenVersion[6] = byte(common.AutoBalanceEnabled) - bn.ChildrenVersion[13] = byte(common.NotSpecified) + bn.ChildrenVersion[2] = byte(core.NotSpecified) + bn.ChildrenVersion[6] = byte(core.AutoBalanceEnabled) + bn.ChildrenVersion[13] = byte(core.NotSpecified) version, err := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Nil(t, err) }) @@ -1439,12 +1440,12 @@ func TestBranchNode_getVersion(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) bn.ChildrenVersion = make([]byte, nrOfChildren) - bn.ChildrenVersion[2] = byte(common.AutoBalanceEnabled) - bn.ChildrenVersion[6] = byte(common.AutoBalanceEnabled) - bn.ChildrenVersion[13] = byte(common.AutoBalanceEnabled) + bn.ChildrenVersion[2] = byte(core.AutoBalanceEnabled) + bn.ChildrenVersion[6] = byte(core.AutoBalanceEnabled) + bn.ChildrenVersion[13] = byte(core.AutoBalanceEnabled) version, err := bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) assert.Nil(t, err) }) } diff --git a/trie/extensionNode.go b/trie/extensionNode.go index d97ad6f3a27..a9c100a6337 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/errors" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var _ = node(&extensionNode{}) @@ -383,7 +384,7 @@ func (en *extensionNode) getNext(key []byte, db common.DBWriteCacher) (node, []b return en.child, key, nil } -func (en *extensionNode) insert(newData common.TrieData, db common.DBWriteCacher) (node, [][]byte, error) { +func (en *extensionNode) insert(newData core.TrieData, db common.DBWriteCacher) (node, [][]byte, error) { emptyHashes := make([][]byte, 0) err := en.isEmptyOrNil() if err != nil { @@ -406,7 +407,7 @@ func (en *extensionNode) insert(newData common.TrieData, db common.DBWriteCacher return en.insertInNewBn(newData, keyMatchLen) } -func (en *extensionNode) insertInSameEn(newData common.TrieData, keyMatchLen int, db common.DBWriteCacher) (node, [][]byte, error) { +func (en *extensionNode) insertInSameEn(newData core.TrieData, keyMatchLen int, db common.DBWriteCacher) (node, [][]byte, error) { newData.Key = newData.Key[keyMatchLen:] newNode, oldHashes, err := en.child.insert(newData, db) if check.IfNil(newNode) || err != nil { @@ -425,7 +426,7 @@ func (en *extensionNode) insertInSameEn(newData common.TrieData, keyMatchLen int return newEn, oldHashes, nil } -func (en *extensionNode) insertInNewBn(newData common.TrieData, keyMatchLen int) (node, [][]byte, error) { +func (en *extensionNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, [][]byte, error) { oldHash := make([][]byte, 0) if !en.dirty { oldHash = append(oldHash, en.hash) @@ -486,7 +487,7 @@ func (en *extensionNode) insertOldChildInBn(bn *branchNode, oldChildPos byte, ke return nil } -func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData common.TrieData, newChildPos byte, keyMatchLen int) error { +func (en *extensionNode) insertNewChildInBn(bn *branchNode, newData core.TrieData, newChildPos byte, keyMatchLen int) error { newData.Key = newData.Key[keyMatchLen+1:] newLeaf, err := newLeafNode(newData, en.marsh, en.hasher) @@ -528,10 +529,10 @@ func (en *extensionNode) delete(key []byte, db common.DBWriteCacher) (bool, node switch newNode := newNode.(type) { case *leafNode: - newLeafData := common.TrieData{ + newLeafData := core.TrieData{ Key: concat(en.Key, newNode.Key...), Value: newNode.Value, - Version: common.TrieNodeVersion(newNode.Version), + Version: core.TrieNodeVersion(newNode.Version), } n, err := newLeafNode(newLeafData, en.marsh, en.hasher) if err != nil { @@ -800,13 +801,42 @@ func (en *extensionNode) collectStats(ts common.TrieStatisticsHandler, depthLeve return nil } -func (en *extensionNode) getVersion() (common.TrieNodeVersion, error) { +func (en *extensionNode) getVersion() (core.TrieNodeVersion, error) { if en.ChildVersion > math.MaxUint8 { log.Warn("invalid trie node version for extension node", "child version", en.ChildVersion, "max version", math.MaxUint8) - return common.NotSpecified, ErrInvalidNodeVersion + return core.NotSpecified, ErrInvalidNodeVersion } - return common.TrieNodeVersion(en.ChildVersion), nil + return core.TrieNodeVersion(en.ChildVersion), nil +} + +func (en *extensionNode) collectLeavesForMigration( + oldVersion core.TrieNodeVersion, + newVersion core.TrieNodeVersion, + trieMigrator vmcommon.DataTrieMigrator, + db common.DBWriteCacher, + keyBuilder common.KeyBuilder, +) (bool, error) { + hasEnoughGasToContinueMigration := trieMigrator.ConsumeStorageLoadGas() + if !hasEnoughGasToContinueMigration { + return false, nil + } + + shouldMigrateNode, err := shouldMigrateCurrentNode(en, oldVersion, newVersion) + if err != nil { + return false, err + } + if !shouldMigrateNode { + return true, nil + } + + err = resolveIfCollapsed(en, 0, db) + if err != nil { + return false, err + } + + keyBuilder.BuildKey(en.Key) + return en.child.collectLeavesForMigration(oldVersion, newVersion, trieMigrator, db, keyBuilder.Clone()) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 8dcca6247b0..c9d982b71e7 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -7,6 +7,7 @@ import ( "math" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/mock" "github.com/multiversx/mx-chain-go/common" chainErrors "github.com/multiversx/mx-chain-go/errors" @@ -1057,7 +1058,7 @@ func TestExtensionNode_getVersion(t *testing.T) { en.ChildVersion = math.MaxUint8 + 1 version, err := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Equal(t, ErrInvalidNodeVersion, err) }) @@ -1067,7 +1068,7 @@ func TestExtensionNode_getVersion(t *testing.T) { en, _ := getEnAndCollapsedEn() version, err := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Nil(t, err) }) @@ -1075,10 +1076,10 @@ func TestExtensionNode_getVersion(t *testing.T) { t.Parallel() en, _ := getEnAndCollapsedEn() - en.ChildVersion = uint32(common.AutoBalanceEnabled) + en.ChildVersion = uint32(core.AutoBalanceEnabled) version, err := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) assert.Nil(t, err) }) } diff --git a/trie/interface.go b/trie/interface.go index 73db88848b1..e0a32a09e24 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -29,7 +29,7 @@ type node interface { hashChildren() error tryGet(key []byte, depth uint32, db common.DBWriteCacher) ([]byte, uint32, error) getNext(key []byte, db common.DBWriteCacher) (node, []byte, error) - insert(newData common.TrieData, db common.DBWriteCacher) (node, [][]byte, error) + insert(newData core.TrieData, db common.DBWriteCacher) (node, [][]byte, error) delete(key []byte, db common.DBWriteCacher) (bool, node, [][]byte, error) reduceNode(pos int) (node, bool, error) isEmptyOrNil() error @@ -43,7 +43,8 @@ type node interface { getAllHashes(db common.DBWriteCacher) ([][]byte, error) getNextHashAndKey([]byte) (bool, []byte, []byte) getValue() []byte - getVersion() (common.TrieNodeVersion, error) + getVersion() (core.TrieNodeVersion, error) + collectLeavesForMigration(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator, db common.DBWriteCacher, keyBuilder common.KeyBuilder) (bool, error) commitDirty(level byte, maxTrieLevelInMemory uint, originDb common.DBWriteCacher, targetDb common.DBWriteCacher) error commitCheckpoint(originDb common.DBWriteCacher, targetDb common.DBWriteCacher, checkpointHashes CheckpointHashesHolder, leavesChan chan core.KeyValueHolder, ctx context.Context, stats common.TrieStatisticsHandler, idleProvider IdleNodeProvider, depthLevel int) error diff --git a/trie/leafNode.go b/trie/leafNode.go index e29b78a8575..62e4baa38ec 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -16,12 +16,13 @@ import ( "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/errors" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var _ = node(&leafNode{}) func newLeafNode( - newData common.TrieData, + newData core.TrieData, marshalizer marshal.Marshalizer, hasher hashing.Hasher, ) (*leafNode, error) { @@ -276,7 +277,7 @@ func (ln *leafNode) getNext(key []byte, _ common.DBWriteCacher) (node, []byte, e } return nil, nil, ErrNodeNotFound } -func (ln *leafNode) insert(newData common.TrieData, _ common.DBWriteCacher) (node, [][]byte, error) { +func (ln *leafNode) insert(newData core.TrieData, _ common.DBWriteCacher) (node, [][]byte, error) { err := ln.isEmptyOrNil() if err != nil { return nil, [][]byte{}, fmt.Errorf("insert error %w", err) @@ -311,7 +312,7 @@ func (ln *leafNode) insert(newData common.TrieData, _ common.DBWriteCacher) (nod return newEn, oldHash, nil } -func (ln *leafNode) insertInSameLn(newData common.TrieData, oldHashes [][]byte) (node, [][]byte, error) { +func (ln *leafNode) insertInSameLn(newData core.TrieData, oldHashes [][]byte) (node, [][]byte, error) { if bytes.Equal(ln.Value, newData.Value) { return nil, [][]byte{}, nil } @@ -323,7 +324,7 @@ func (ln *leafNode) insertInSameLn(newData common.TrieData, oldHashes [][]byte) return ln, oldHashes, nil } -func (ln *leafNode) insertInNewBn(newData common.TrieData, keyMatchLen int) (node, error) { +func (ln *leafNode) insertInNewBn(newData core.TrieData, keyMatchLen int) (node, error) { bn, err := newBranchNode(ln.marsh, ln.hasher) if err != nil { return nil, err @@ -340,7 +341,7 @@ func (ln *leafNode) insertInNewBn(newData common.TrieData, keyMatchLen int) (nod return nil, err } - oldLnData := common.TrieData{ + oldLnData := core.TrieData{ Key: ln.Key[keyMatchLen+1:], Value: ln.Value, Version: oldLnVersion, @@ -383,7 +384,7 @@ func (ln *leafNode) reduceNode(pos int) (node, bool, error) { return nil, false, err } - oldLnData := common.TrieData{ + oldLnData := core.TrieData{ Key: k, Value: ln.Value, Version: oldLnVersion, @@ -550,13 +551,53 @@ func (ln *leafNode) collectStats(ts common.TrieStatisticsHandler, depthLevel int return nil } -func (ln *leafNode) getVersion() (common.TrieNodeVersion, error) { +func (ln *leafNode) getVersion() (core.TrieNodeVersion, error) { if ln.Version > math.MaxUint8 { log.Warn("invalid trie node version", "version", ln.Version, "max version", math.MaxUint8) - return common.NotSpecified, ErrInvalidNodeVersion + return core.NotSpecified, ErrInvalidNodeVersion } - return common.TrieNodeVersion(ln.Version), nil + return core.TrieNodeVersion(ln.Version), nil +} + +func (ln *leafNode) collectLeavesForMigration( + oldVersion core.TrieNodeVersion, + newVersion core.TrieNodeVersion, + trieMigrator vmcommon.DataTrieMigrator, + _ common.DBWriteCacher, + keyBuilder common.KeyBuilder, +) (bool, error) { + shouldContinue := trieMigrator.ConsumeStorageLoadGas() + if !shouldContinue { + return false, nil + } + + shouldMigrateNode, err := shouldMigrateCurrentNode(ln, oldVersion, newVersion) + if err != nil { + return false, err + } + if !shouldMigrateNode { + return true, nil + } + + keyBuilder.BuildKey(ln.Key) + key, err := keyBuilder.GetKey() + if err != nil { + return false, err + } + + version, err := ln.getVersion() + if err != nil { + return false, err + } + + leafData := core.TrieData{ + Key: key, + Value: ln.Value, + Version: version, + } + + return trieMigrator.AddLeafToMigrationQueue(leafData, newVersion) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index ed54ce071e6..8b73b7a4ad7 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -9,7 +9,6 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/common" chainErrors "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/testscommon" @@ -753,7 +752,7 @@ func TestLeafNode_getVersion(t *testing.T) { ln.Version = math.MaxUint8 + 1 version, err := ln.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Equal(t, ErrInvalidNodeVersion, err) }) @@ -763,7 +762,7 @@ func TestLeafNode_getVersion(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) version, err := ln.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) assert.Nil(t, err) }) @@ -771,10 +770,10 @@ func TestLeafNode_getVersion(t *testing.T) { t.Parallel() ln := getLn(getTestMarshalizerAndHasher()) - ln.Version = uint32(common.AutoBalanceEnabled) + ln.Version = uint32(core.AutoBalanceEnabled) version, err := ln.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) assert.Nil(t, err) }) } diff --git a/trie/node.go b/trie/node.go index 2fc924ef457..439a6882de1 100644 --- a/trie/node.go +++ b/trie/node.go @@ -7,6 +7,7 @@ import ( "fmt" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" @@ -270,3 +271,24 @@ func treatCommitSnapshotError(err error, hash []byte, missingNodesChan chan []by log.Error("error during trie snapshot", "err", err.Error(), "hash", hash) missingNodesChan <- hash } + +func shouldMigrateCurrentNode( + currentNode node, + oldVersion core.TrieNodeVersion, + newVersion core.TrieNodeVersion, +) (bool, error) { + version, err := currentNode.getVersion() + if err != nil { + return false, err + } + + if version == newVersion { + return false, nil + } + + if version != oldVersion && version != core.NotSpecified { + return false, nil + } + + return true, nil +} diff --git a/trie/node_test.go b/trie/node_test.go index 9cddeb96d89..00f7e74b79c 100644 --- a/trie/node_test.go +++ b/trie/node_test.go @@ -632,15 +632,15 @@ func TestNodesVersion_insertInLn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) ln, ok := tr.root.(*leafNode) assert.True(t, ok) version, _ := ln.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), core.AutoBalanceEnabled) version, _ = ln.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("insert in leaf - create new branch node", func(t *testing.T) { @@ -648,20 +648,20 @@ func TestNodesVersion_insertInLn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) tr, _ = newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok = tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("insert in leaf - create new extension", func(t *testing.T) { @@ -669,20 +669,20 @@ func TestNodesVersion_insertInLn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) tr, _ = newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) } @@ -694,132 +694,132 @@ func TestNodesVersion_insertInEn(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) }) t.Run("insert in extension node - create new branch - change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("qqq"), []byte("qqq"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("qqq"), []byte("qqq"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.NotSpecified) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) }) t.Run("insert in extension node - create new branch - do not change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("qqq"), []byte("qqq"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("qqq"), []byte("qqq"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("insert in extension node - create new branch with following extension node - change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.NotSpecified) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) }) t.Run("insert in extension node - create new branch with following extension node - do not change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("zzz"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("insert in extension node - create new extension and branch - change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("bba"), []byte("bba"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("bba"), core.NotSpecified) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) }) t.Run("insert in extension node - create new extension and branch - do not change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("bba"), []byte("bba"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("bba"), core.AutoBalanceEnabled) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) } @@ -831,16 +831,16 @@ func TestNodesVersion_insertInBn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.AutoBalanceEnabled) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("insert in branch node on nil child - change version", func(t *testing.T) { @@ -848,16 +848,16 @@ func TestNodesVersion_insertInBn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.NotSpecified) version, _ = bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) }) t.Run("insert in branch node on existing child - same version", func(t *testing.T) { @@ -865,16 +865,16 @@ func TestNodesVersion_insertInBn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), core.AutoBalanceEnabled) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("insert in branch node on existing child - change version", func(t *testing.T) { @@ -882,16 +882,16 @@ func TestNodesVersion_insertInBn(t *testing.T) { tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aab"), core.AutoBalanceEnabled) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) } @@ -902,112 +902,112 @@ func TestNodesVersion_deleteFromEn(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("aaa")) ln, ok := tr.root.(*leafNode) assert.True(t, ok) version, _ = ln.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("new child is leaf node - same version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("baa")) ln, ok := tr.root.(*leafNode) assert.True(t, ok) version, _ = ln.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) }) t.Run("new child is extension node - same version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("zza"), []byte("zza"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zza"), []byte("zza"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) _ = tr.Delete([]byte("zza")) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("new child is extension node - change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("zza"), []byte("zza"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zza"), []byte("zza"), core.NotSpecified) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("zza")) en, ok = tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("new child is branch node - same version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bba"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) _ = tr.Delete([]byte("aaa")) bn, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("new child is branch node - change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bba"), []byte("baa"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("baa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("baa"), []byte("baa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("baa"), core.AutoBalanceEnabled) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ := en.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("aaa")) bn, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) } @@ -1018,94 +1018,94 @@ func TestNodesVersion_deleteFromBn(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) _ = tr.Delete([]byte("aaa")) bn, ok = tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("delete leaf - branch does not reduce - bn should change version", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("ccc"), []byte("ccc"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("aaa")) bn, ok = tr.root.(*branchNode) assert.True(t, ok) version, _ = bn.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("branch with branch child is reduced", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("qqq"), []byte("bbb"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("zzz"), []byte("ccc"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("qqq"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("ccc"), core.NotSpecified) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("zzz")) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("branch with extension child is reduced", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("bba"), []byte("bbb"), common.AutoBalanceEnabled) - _ = tr.UpdateWithVersion([]byte("zzz"), []byte("ccc"), common.NotSpecified) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("bba"), []byte("bbb"), core.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("zzz"), []byte("ccc"), core.NotSpecified) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("zzz")) en, ok := tr.root.(*extensionNode) assert.True(t, ok) version, _ = en.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) t.Run("branch with leaf child is reduced", func(t *testing.T) { t.Parallel() tr, _ := newEmptyTrie() - _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), common.NotSpecified) - _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), common.AutoBalanceEnabled) + _ = tr.UpdateWithVersion([]byte("aaa"), []byte("aaa"), core.NotSpecified) + _ = tr.UpdateWithVersion([]byte("bbb"), []byte("bbb"), core.AutoBalanceEnabled) bn, ok := tr.root.(*branchNode) assert.True(t, ok) version, _ := bn.getVersion() - assert.Equal(t, common.NotSpecified, version) + assert.Equal(t, core.NotSpecified, version) _ = tr.Delete([]byte("aaa")) ln, ok := tr.root.(*leafNode) assert.True(t, ok) version, _ = ln.getVersion() - assert.Equal(t, common.AutoBalanceEnabled, version) + assert.Equal(t, core.AutoBalanceEnabled, version) }) } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index 1f8ef089b95..c06b2675262 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -7,14 +7,17 @@ import ( "fmt" "sync" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" logger "github.com/multiversx/mx-chain-logger-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) var log = logger.GetOrCreate("trie") @@ -32,11 +35,12 @@ const rootDepthLevel = 0 type patriciaMerkleTrie struct { root node - trieStorage common.StorageManager - marshalizer marshal.Marshalizer - hasher hashing.Hasher - enableEpochsHandler common.EnableEpochsHandler - mutOperation sync.RWMutex + trieStorage common.StorageManager + marshalizer marshal.Marshalizer + hasher hashing.Hasher + enableEpochsHandler common.EnableEpochsHandler + trieNodeVersionVerifier core.TrieNodeVersionVerifier + mutOperation sync.RWMutex oldHashes [][]byte oldRoot []byte @@ -69,15 +73,21 @@ func NewTrie( } log.Trace("created new trie", "max trie level in memory", maxTrieLevelInMemory) + tnvv, err := core.NewTrieNodeVersionVerifier(enableEpochsHandler) + if err != nil { + return nil, err + } + return &patriciaMerkleTrie{ - trieStorage: trieStorage, - marshalizer: msh, - hasher: hsh, - oldHashes: make([][]byte, 0), - oldRoot: make([]byte, 0), - maxTrieLevelInMemory: maxTrieLevelInMemory, - chanClose: make(chan struct{}), - enableEpochsHandler: enableEpochsHandler, + trieStorage: trieStorage, + marshalizer: msh, + hasher: hsh, + oldHashes: make([][]byte, 0), + oldRoot: make([]byte, 0), + maxTrieLevelInMemory: maxTrieLevelInMemory, + chanClose: make(chan struct{}), + enableEpochsHandler: enableEpochsHandler, + trieNodeVersionVerifier: tnvv, }, nil } @@ -113,11 +123,11 @@ func (tr *patriciaMerkleTrie) Update(key, value []byte) error { "val", hex.EncodeToString(value), ) - return tr.update(key, value, common.NotSpecified) + return tr.update(key, value, core.NotSpecified) } // UpdateWithVersion does the same thing as Update, but the new leaf that is created will be of the specified version -func (tr *patriciaMerkleTrie) UpdateWithVersion(key []byte, value []byte, version common.TrieNodeVersion) error { +func (tr *patriciaMerkleTrie) UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error { tr.mutOperation.Lock() defer tr.mutOperation.Unlock() @@ -130,10 +140,10 @@ func (tr *patriciaMerkleTrie) UpdateWithVersion(key []byte, value []byte, versio return tr.update(key, value, version) } -func (tr *patriciaMerkleTrie) update(key []byte, value []byte, version common.TrieNodeVersion) error { +func (tr *patriciaMerkleTrie) update(key []byte, value []byte, version core.TrieNodeVersion) error { hexKey := keyBytesToHex(key) if len(value) != 0 { - newData := common.TrieData{ + newData := core.TrieData{ Key: hexKey, Value: value, Version: version, @@ -657,6 +667,52 @@ func (tr *patriciaMerkleTrie) GetTrieStats(address string, rootHash []byte) (*st return ts.GetTrieStats(), nil } +// CollectLeavesForMigration will collect trie leaves that need to be migrated. The leaves are collected in the trieMigrator. +// The traversing of the trie is done in a DFS manner, and it will stop when the gas runs out (this will be signaled by the trieMigrator). +func (tr *patriciaMerkleTrie) CollectLeavesForMigration( + oldVersion core.TrieNodeVersion, + newVersion core.TrieNodeVersion, + trieMigrator vmcommon.DataTrieMigrator, +) error { + tr.mutOperation.Lock() + defer tr.mutOperation.Unlock() + + if check.IfNil(tr.root) { + return nil + } + if check.IfNil(trieMigrator) { + return errors.ErrNilTrieMigrator + } + + err := tr.checkIfMigrationPossible(newVersion, oldVersion) + if err != nil { + return err + } + + _, err = tr.root.collectLeavesForMigration(oldVersion, newVersion, trieMigrator, tr.trieStorage, keyBuilder.NewKeyBuilder()) + if err != nil { + return err + } + + return nil +} + +func (tr *patriciaMerkleTrie) checkIfMigrationPossible(newVersion core.TrieNodeVersion, oldVersion core.TrieNodeVersion) error { + if !tr.trieNodeVersionVerifier.IsValidVersion(newVersion) { + return fmt.Errorf("%w: newVersion %v", errors.ErrInvalidTrieNodeVersion, newVersion) + } + + if !tr.trieNodeVersionVerifier.IsValidVersion(oldVersion) { + return fmt.Errorf("%w: oldVersion %v", errors.ErrInvalidTrieNodeVersion, oldVersion) + } + + if newVersion == core.NotSpecified && oldVersion == core.AutoBalanceEnabled { + return fmt.Errorf("%w: cannot migrate from %v to %v", errors.ErrInvalidTrieNodeVersion, core.AutoBalanceEnabled, core.NotSpecified) + } + + return nil +} + // Close stops all the active goroutines started by the trie func (tr *patriciaMerkleTrie) Close() error { tr.mutOperation.Lock() diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go index 76a01be7924..87faf4cd2c8 100644 --- a/trie/patriciaMerkleTrie_test.go +++ b/trie/patriciaMerkleTrie_test.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand" "strconv" + "strings" "sync" "testing" "time" @@ -28,6 +29,7 @@ import ( "github.com/multiversx/mx-chain-go/trie/hashesHolder" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/mock" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -40,6 +42,13 @@ func emptyTrie() common.Trie { return tr } +func emptyTrieWithCustomEnableEpochsHandler(handler common.EnableEpochsHandler) common.Trie { + storage, marshaller, hasher, _, maxTrieLevelInMem := getDefaultTrieParameters() + + tr, _ := trie.NewTrie(storage, marshaller, hasher, handler, maxTrieLevelInMem) + return tr +} + func getDefaultTrieStorageManagerParameters() trie.NewTrieStorageManagerArgs { marshalizer := &testscommon.ProtobufMarshalizerMock{} hasher := &testscommon.KeccakMock{} @@ -85,11 +94,15 @@ func initTrieMultipleValues(nr int) (common.Trie, [][]byte) { func initTrie() common.Trie { tr := emptyTrie() + addDefaultDataToTrie(tr) + + return tr +} + +func addDefaultDataToTrie(tr common.Trie) { _ = tr.Update([]byte("doe"), []byte("reindeer")) _ = tr.Update([]byte("dog"), []byte("puppy")) _ = tr.Update([]byte("ddog"), []byte("cat")) - - return tr } func TestNewTrieWithNilTrieStorage(t *testing.T) { @@ -1161,6 +1174,308 @@ func TestPatriciaMerkleTrie_GetSerializedNodesClose(t *testing.T) { } } +type dataTrie interface { + CollectLeavesForMigration(oldVersion core.TrieNodeVersion, newVersion core.TrieNodeVersion, trieMigrator vmcommon.DataTrieMigrator) error + UpdateWithVersion(key []byte, value []byte, version core.TrieNodeVersion) error +} + +func TestPatriciaMerkleTrie_CollectLeavesForMigration(t *testing.T) { + t.Parallel() + + t.Run("nil root", func(t *testing.T) { + t.Parallel() + + tr := emptyTrie() + + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + assert.Fail(t, "should not have called this function") + return false + }, + } + + err := tr.(dataTrie).CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + }) + + t.Run("nil trie migrator", func(t *testing.T) { + t.Parallel() + + tr := initTrie().(dataTrie) + + err := tr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, nil) + assert.Equal(t, errorsCommon.ErrNilTrieMigrator, err) + }) + + t.Run("data trie already migrated", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + _ = dtr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.AutoBalanceEnabled) + + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 1, numLoadsCalled) + }) + + t.Run("trie partially migrated", func(t *testing.T) { + t.Parallel() + + addLeafToMigrationQueueCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + key := []byte("dog") + value := []byte("reindeer") + _ = dtr.UpdateWithVersion(key, value, core.NotSpecified) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.AutoBalanceEnabled) + + dtm := &trieMock.DataTrieMigratorStub{ + AddLeafToMigrationQueueCalled: func(leafData core.TrieData, newLeafVersion core.TrieNodeVersion) (bool, error) { + assert.Equal(t, core.AutoBalanceEnabled, newLeafVersion) + assert.Equal(t, key, leafData.Key) + assert.Equal(t, value, leafData.Value) + assert.Equal(t, core.NotSpecified, leafData.Version) + addLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 1, addLeafToMigrationQueueCalled) + }) + + t.Run("not enough gas to load the whole trie", func(t *testing.T) { + t.Parallel() + + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + addDefaultDataToTrie(tr) + + dtr := tr.(dataTrie) + numLoads := 0 + numAddLeafToMigrationQueueCalled := 0 + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + if numLoads < 2 { + numLoads++ + return true + } + + numLoads++ + return false + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 3, numLoads) + assert.Equal(t, 1, numAddLeafToMigrationQueueCalled) + }) + + t.Run("not enough gas to migrate the whole trie", func(t *testing.T) { + t.Parallel() + + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + addDefaultDataToTrie(tr) + dtr := tr.(dataTrie) + numLoads := 0 + numAddLeafToMigrationQueueCalled := 0 + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoads++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + if numAddLeafToMigrationQueueCalled < 1 { + numAddLeafToMigrationQueueCalled++ + return true, nil + } + + numAddLeafToMigrationQueueCalled++ + return false, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 5, numLoads) + assert.Equal(t, 2, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate to non existent version", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + dtr := initTrie().(dataTrie) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.TrieNodeVersion(100), dtm) + assert.True(t, strings.Contains(err.Error(), errorsCommon.ErrInvalidTrieNodeVersion.Error())) + assert.Equal(t, 0, numLoadsCalled) + assert.Equal(t, 0, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate from non existent version", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + dtr := initTrie().(dataTrie) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.TrieNodeVersion(100), core.AutoBalanceEnabled, dtm) + assert.True(t, strings.Contains(err.Error(), errorsCommon.ErrInvalidTrieNodeVersion.Error())) + assert.Equal(t, 0, numLoadsCalled) + assert.Equal(t, 0, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate collapsed trie", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + addDefaultDataToTrie(tr) + _ = tr.Commit() + rootHash, _ := tr.RootHash() + collapsedTrie, _ := tr.Recreate(rootHash) + dtr := collapsedTrie.(dataTrie) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 6, numLoadsCalled) + assert.Equal(t, 3, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate all non migrated leaves", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + _ = dtr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.NotSpecified) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.NotSpecified, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 2, numLoadsCalled) + assert.Equal(t, 1, numAddLeafToMigrationQueueCalled) + }) + + t.Run("migrate to same version", func(t *testing.T) { + t.Parallel() + + numLoadsCalled := 0 + numAddLeafToMigrationQueueCalled := 0 + tr := emptyTrieWithCustomEnableEpochsHandler( + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsAutoBalanceDataTriesEnabledField: true, + }, + ) + dtr := tr.(dataTrie) + _ = dtr.UpdateWithVersion([]byte("dog"), []byte("reindeer"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("ddog"), []byte("puppy"), core.AutoBalanceEnabled) + _ = dtr.UpdateWithVersion([]byte("doe"), []byte("cat"), core.AutoBalanceEnabled) + dtm := &trieMock.DataTrieMigratorStub{ + ConsumeStorageLoadGasCalled: func() bool { + numLoadsCalled++ + return true + }, + AddLeafToMigrationQueueCalled: func(_ core.TrieData, _ core.TrieNodeVersion) (bool, error) { + numAddLeafToMigrationQueueCalled++ + return true, nil + }, + } + + err := dtr.CollectLeavesForMigration(core.AutoBalanceEnabled, core.AutoBalanceEnabled, dtm) + assert.Nil(t, err) + assert.Equal(t, 1, numLoadsCalled) + assert.Equal(t, 0, numAddLeafToMigrationQueueCalled) + }) +} + func BenchmarkPatriciaMerkleTree_Insert(b *testing.B) { tr := emptyTrie() hsh := keccak.NewKeccak() diff --git a/vm/gasCost.go b/vm/gasCost.go index 57762655960..0c9df17a54d 100644 --- a/vm/gasCost.go +++ b/vm/gasCost.go @@ -55,6 +55,8 @@ type BuiltInCost struct { ESDTNFTAddUri uint64 ESDTNFTUpdateAttributes uint64 ESDTNFTMultiTransfer uint64 + TrieLoadPerNode uint64 + TrieStorePerNode uint64 } // GasCost holds all the needed gas costs for system smart contracts diff --git a/vm/systemSmartContracts/defaults/gasMap.go b/vm/systemSmartContracts/defaults/gasMap.go index 9137f03cc35..99a78a523d8 100644 --- a/vm/systemSmartContracts/defaults/gasMap.go +++ b/vm/systemSmartContracts/defaults/gasMap.go @@ -47,6 +47,8 @@ func FillGasMapBuiltInCosts(value uint64) map[string]uint64 { gasMap["ESDTNFTAddUri"] = value gasMap["ESDTNFTUpdateAttributes"] = value gasMap["ESDTNFTMultiTransfer"] = value + gasMap["TrieLoadPerNode"] = value + gasMap["TrieStorePerNode"] = value return gasMap }