diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index df2927d207d..346327a9b7f 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -840,6 +840,40 @@ func (suite *KeeperTestSuite) TestChanUpgradeOpenCounterpartyStates() { } } +func (suite *KeeperTestSuite) TestWriteUpgradeOpenChannel() { + suite.SetupTest() + + path := ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Ordering = types.ORDERED + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Ordering = types.ORDERED + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) + + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.WriteUpgradeOpenChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + + channel := path.EndpointA.GetChannel() + + // Assert that channel state has been updated + suite.Require().Equal(types.OPEN, channel.State) + suite.Require().Equal(mock.UpgradeVersion, channel.Version) + suite.Require().Equal(types.ORDERED, channel.Ordering) + suite.Require().Equal(types.NOTINFLUSH, channel.FlushStatus) + + // Assert that state stored for upgrade has been deleted + upgrade, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(types.Upgrade{}, upgrade) + suite.Require().False(found) + lastPacketSequence, found := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.GetCounterpartyLastPacketSequence(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + suite.Require().Equal(uint64(0), lastPacketSequence) + suite.Require().False(found) +} + func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { var ( path *ibctesting.Path