diff --git a/modules/apps/ccv/child/keeper/relay.go b/modules/apps/ccv/child/keeper/relay.go index 4423c3b3f72..652576d2579 100644 --- a/modules/apps/ccv/child/keeper/relay.go +++ b/modules/apps/ccv/child/keeper/relay.go @@ -8,13 +8,15 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/ibc-go/modules/apps/ccv/child/types" ccv "github.com/cosmos/ibc-go/modules/apps/ccv/types" + utils "github.com/cosmos/ibc-go/modules/apps/ccv/utils" channeltypes "github.com/cosmos/ibc-go/modules/core/04-channel/types" host "github.com/cosmos/ibc-go/modules/core/24-host" + abci "github.com/tendermint/tendermint/abci/types" ) // OnRecvPacket sets the pending validator set changes that will be flushed to ABCI on Endblock // and set the unbonding time for the packet so that we can WriteAcknowledgement after unbonding time is over. -func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data ccv.ValidatorSetChangePacketData) *channeltypes.Acknowledgement { +func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, newChanges ccv.ValidatorSetChangePacketData) *channeltypes.Acknowledgement { // packet is not sent on parent channel, return error acknowledgement and close channel if parentChannel, ok := k.GetParentChannel(ctx); ok && parentChannel != packet.DestinationChannel { ack := channeltypes.NewErrorAcknowledgement( @@ -29,12 +31,22 @@ func (k Keeper) OnRecvPacket(ctx sdk.Context, packet channeltypes.Packet, data c k.SetChannelStatus(ctx, packet.DestinationChannel, ccv.VALIDATING) k.SetParentChannel(ctx, packet.DestinationChannel) } - // Set PendingChanges to be flushed and the unbonding time and unbonding packet. - // TODO: Get PendingChanges and update the pending changes if they already exist - k.SetPendingChanges(ctx, data) + + // Set pending changes by accumulating changes from this packet with all prior changes + var pendingChanges []abci.ValidatorUpdate + currentChanges, exists := k.GetPendingChanges(ctx) + if !exists { + pendingChanges = newChanges.ValidatorUpdates + } else { + pendingChanges = utils.AccumulateChanges(currentChanges.ValidatorUpdates, newChanges.ValidatorUpdates) + } + k.SetPendingChanges(ctx, ccv.ValidatorSetChangePacketData{ValidatorUpdates: pendingChanges}) + + // Save unbonding time and packet unbondingTime := ctx.BlockTime().Add(types.UnbondingTime) k.SetUnbondingTime(ctx, packet.Sequence, uint64(unbondingTime.UnixNano())) k.SetUnbondingPacket(ctx, packet.Sequence, packet) + // ack will be sent asynchronously return nil } diff --git a/modules/apps/ccv/child/keeper/relay_test.go b/modules/apps/ccv/child/keeper/relay_test.go index bb881a5b8ce..46045da972c 100644 --- a/modules/apps/ccv/child/keeper/relay_test.go +++ b/modules/apps/ccv/child/keeper/relay_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "sort" "time" cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" @@ -22,56 +23,95 @@ func (suite *KeeperTestSuite) TestOnRecvPacket() { suite.Require().NoError(err) pk2, err := cryptocodec.ToTmProtoPublicKey(ed25519.GenPrivKey().PubKey()) suite.Require().NoError(err) + pk3, err := cryptocodec.ToTmProtoPublicKey(ed25519.GenPrivKey().PubKey()) + suite.Require().NoError(err) - pd := types.NewValidatorSetChangePacketData( - []abci.ValidatorUpdate{ - { - PubKey: pk1, - Power: 30, - }, - { - PubKey: pk2, - Power: 20, - }, + changes1 := []abci.ValidatorUpdate{ + { + PubKey: pk1, + Power: 30, + }, + { + PubKey: pk2, + Power: 20, + }, + } + + changes2 := []abci.ValidatorUpdate{ + { + PubKey: pk2, + Power: 40, + }, + { + PubKey: pk3, + Power: 10, }, + } + + pd := types.NewValidatorSetChangePacketData( + changes1, ) - packet := channeltypes.NewPacket(pd.GetBytes(), 1, parenttypes.PortID, suite.path.EndpointB.ChannelID, childtypes.PortID, suite.path.EndpointA.ChannelID, - clienttypes.NewHeight(1, 0), 0) + pd2 := types.NewValidatorSetChangePacketData( + changes2, + ) testCases := []struct { - name string - malleatePacket func() - expErrorAck bool + name string + packet channeltypes.Packet + newChanges types.ValidatorSetChangePacketData + expectedPendingChanges types.ValidatorSetChangePacketData + expErrorAck bool }{ { "success on first packet", - func() {}, + channeltypes.NewPacket(pd.GetBytes(), 1, parenttypes.PortID, suite.path.EndpointB.ChannelID, childtypes.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0), + types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, + types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, false, }, { "success on subsequent packet", - func() { - packet.Sequence = 2 - }, + channeltypes.NewPacket(pd.GetBytes(), 2, parenttypes.PortID, suite.path.EndpointB.ChannelID, childtypes.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0), + types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, + types.ValidatorSetChangePacketData{ValidatorUpdates: changes1}, + false, + }, + { + "success on packet with more changes", + channeltypes.NewPacket(pd2.GetBytes(), 3, parenttypes.PortID, suite.path.EndpointB.ChannelID, childtypes.PortID, suite.path.EndpointA.ChannelID, + clienttypes.NewHeight(1, 0), 0), + types.ValidatorSetChangePacketData{ValidatorUpdates: changes2}, + types.ValidatorSetChangePacketData{ValidatorUpdates: []abci.ValidatorUpdate{ + { + PubKey: pk1, + Power: 30, + }, + { + PubKey: pk2, + Power: 40, + }, + { + PubKey: pk3, + Power: 10, + }, + }}, false, }, { "invalid packet: different destination channel than parent channel", - func() { - packet.Sequence = 1 - // change destination channel to different channelID than parent channel - packet.DestinationChannel = "invalidChannel" - }, + channeltypes.NewPacket(pd.GetBytes(), 1, parenttypes.PortID, suite.path.EndpointB.ChannelID, childtypes.PortID, "InvalidChannel", + clienttypes.NewHeight(1, 0), 0), + types.ValidatorSetChangePacketData{ValidatorUpdates: []abci.ValidatorUpdate{}}, + types.ValidatorSetChangePacketData{ValidatorUpdates: []abci.ValidatorUpdate{}}, true, }, } for _, tc := range testCases { - // malleate packet for each case - tc.malleatePacket() - - ack := suite.childChain.GetSimApp().ChildKeeper.OnRecvPacket(suite.ctx, packet, pd) + ack := suite.childChain.GetSimApp().ChildKeeper.OnRecvPacket(suite.ctx, tc.packet, tc.newChanges) if tc.expErrorAck { suite.Require().NotNil(ack, "invalid test case: %s did not return ack", tc.name) @@ -82,17 +122,27 @@ func (suite *KeeperTestSuite) TestOnRecvPacket() { "channel status is not valdidating after receive packet for valid test case: %s", tc.name) parentChannel, ok := suite.childChain.GetSimApp().ChildKeeper.GetParentChannel(suite.ctx) suite.Require().True(ok) - suite.Require().Equal(packet.DestinationChannel, parentChannel, + suite.Require().Equal(tc.packet.DestinationChannel, parentChannel, "parent channel is not destination channel on successful receive for valid test case: %s", tc.name) - actualPd, ok := suite.childChain.GetSimApp().ChildKeeper.GetPendingChanges(suite.ctx) + + // Check that pending changes are accumulated and stored correctly + actualPendingChanges, ok := suite.childChain.GetSimApp().ChildKeeper.GetPendingChanges(suite.ctx) suite.Require().True(ok) - suite.Require().Equal(&pd, actualPd, "pending changes not equal to packet data after successful packet receive. case: %s", tc.name) + // Sort to avoid dumb inequalities + sort.SliceStable(actualPendingChanges.ValidatorUpdates, func(i, j int) bool { + return actualPendingChanges.ValidatorUpdates[i].PubKey.Compare(actualPendingChanges.ValidatorUpdates[j].PubKey) == -1 + }) + sort.SliceStable(tc.expectedPendingChanges.ValidatorUpdates, func(i, j int) bool { + return tc.expectedPendingChanges.ValidatorUpdates[i].PubKey.Compare(tc.expectedPendingChanges.ValidatorUpdates[j].PubKey) == -1 + }) + suite.Require().Equal(tc.expectedPendingChanges, *actualPendingChanges, "pending changes not equal to expected changes after successful packet receive. case: %s", tc.name) + expectedTime := uint64(suite.ctx.BlockTime().Add(childtypes.UnbondingTime).UnixNano()) - unbondingTime := suite.childChain.GetSimApp().ChildKeeper.GetUnbondingTime(suite.ctx, packet.Sequence) + unbondingTime := suite.childChain.GetSimApp().ChildKeeper.GetUnbondingTime(suite.ctx, tc.packet.Sequence) suite.Require().Equal(expectedTime, unbondingTime, "unbonding time has unexpected value for case: %s", tc.name) - unbondingPacket, err := suite.childChain.GetSimApp().ChildKeeper.GetUnbondingPacket(suite.ctx, packet.Sequence) + unbondingPacket, err := suite.childChain.GetSimApp().ChildKeeper.GetUnbondingPacket(suite.ctx, tc.packet.Sequence) suite.Require().NoError(err) - suite.Require().Equal(&packet, unbondingPacket, "packet is not added to unbonding queue after successful receive. case: %s", tc.name) + suite.Require().Equal(&tc.packet, unbondingPacket, "packet is not added to unbonding queue after successful receive. case: %s", tc.name) } } } diff --git a/modules/apps/ccv/utils/utils.go b/modules/apps/ccv/utils/utils.go new file mode 100644 index 00000000000..177cde6f588 --- /dev/null +++ b/modules/apps/ccv/utils/utils.go @@ -0,0 +1,24 @@ +package utils + +import ( + abci "github.com/tendermint/tendermint/abci/types" +) + +func AccumulateChanges(currentChanges, newChanges []abci.ValidatorUpdate) []abci.ValidatorUpdate { + m := make(map[string]abci.ValidatorUpdate) + + for i := 0; i < len(currentChanges); i++ { + m[currentChanges[i].PubKey.String()] = currentChanges[i] + } + + for i := 0; i < len(newChanges); i++ { + m[newChanges[i].PubKey.String()] = newChanges[i] + } + + var out []abci.ValidatorUpdate + + for _, update := range m { + out = append(out, update) + } + return out +}