diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index e39b0a2c80c..6aaf37b55df 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -355,11 +355,20 @@ func (k Keeper) ChanUpgradeOpen( if !found { return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID) } + // If counterparty has reached OPEN, we must use the upgraded connection to verify the counterparty channel + upgradeConnection, err := k.GetConnection(ctx, upgrade.Fields.ConnectionHops[0]) + if err != nil { + return errorsmod.Wrap(err, "failed to retrieve connection using the upgrade connection hops") + } + + if upgradeConnection.GetState() != int32(connectiontypes.OPEN) { + return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(upgradeConnection.GetState()).String()) + } counterpartyChannel = types.Channel{ State: types.OPEN, Ordering: upgrade.Fields.Ordering, - ConnectionHops: upgrade.Fields.ConnectionHops, + ConnectionHops: []string{upgradeConnection.GetCounterparty().GetConnectionID()}, Counterparty: types.NewCounterparty(portID, channelID), Version: upgrade.Fields.Version, UpgradeSequence: channel.UpgradeSequence, diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index bb850cf68cb..df2927d207d 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -754,6 +754,92 @@ func (suite *KeeperTestSuite) TestChanUpgradeOpen() { } } +// TestChanUpgradeOpenCounterPartyStates tests the handshake in the cases where +// the counterparty is in a state other than OPEN. +func (suite *KeeperTestSuite) TestChanUpgradeOpenCounterpartyStates() { + var path *ibctesting.Path + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success, counterparty in OPEN", + func() { + err := path.EndpointB.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeAck() + suite.Require().NoError(err) + + // TODO: Remove when #4030 is closed. Channel will automatically + // move to OPEN in that case. + err = path.EndpointB.ChanUpgradeOpen() + suite.Require().NoError(err) + + suite.coordinator.CommitBlock(suite.chainA, suite.chainB) + suite.Require().NoError(path.EndpointA.UpdateClient()) + }, + nil, + }, + { + "success, counterparty in TRYUPGRADE", + func() { + err := path.EndpointA.ChanUpgradeInit() + suite.Require().NoError(err) + + err = path.EndpointB.ChanUpgradeTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanUpgradeAck() + suite.Require().NoError(err) + }, + nil, + }, + } + + // Create an initial path used only to invoke ConnOpenInit/ChanOpenInit handlers. + // This bumps the connection/channel identifiers generated for chain A on the + // next path used to run the upgrade handshake. + // See issue 4062. + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.SetupClients(path) + suite.Require().NoError(path.EndpointA.ConnOpenInit()) + suite.coordinator.SetupConnections(path) + suite.Require().NoError(path.EndpointA.ChanOpenInit()) + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + tc.malleate() + + proofCounterpartyChannel, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeOpen( + suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, + path.EndpointB.GetChannel().State, proofCounterpartyChannel, proofHeight, + ) + + expPass := tc.expError == nil + if expPass { + suite.Require().NoError(err) + } else { + suite.Require().ErrorIs(err, tc.expError) + } + }) + } +} + func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { var ( path *ibctesting.Path