diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index 7a76c23560f..74d86056208 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -283,7 +283,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeTry() { // ensure clients are up to date to receive valid proofs suite.Require().NoError(path.EndpointB.UpdateClient()) - proofCounterpartyChannel, proofCounterpartyUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() upgrade, err := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTry( suite.chainB.GetContext(), @@ -292,8 +292,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTry() { proposedUpgrade.Fields.ConnectionHops, counterpartyUpgrade.Fields, path.EndpointA.GetChannel().UpgradeSequence, - proofCounterpartyChannel, - proofCounterpartyUpgrade, + proofChannel, + proofUpgrade, proofHeight, ) @@ -542,7 +542,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { tc.malleate() - proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() err = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeAck( suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, counterpartyUpgrade, @@ -839,7 +839,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeConfirm() { tc.malleate() - proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() err = suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeConfirm( suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, counterpartyChannelState, counterpartyUpgrade, @@ -1075,10 +1075,12 @@ func (suite *KeeperTestSuite) TestChanUpgradeOpen() { tc.malleate() - proofCounterpartyChannel, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight := path.EndpointB.QueryProof(channelKey) + err = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeOpen( suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, - path.EndpointB.GetChannel().State, proofCounterpartyChannel, proofHeight, + path.EndpointB.GetChannel().State, proofChannel, proofHeight, ) if tc.expError == nil { @@ -1489,9 +1491,9 @@ func (suite *KeeperTestSuite) TestWriteUpgradeCancelChannel() { func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { var ( - path *ibctesting.Path - proofHeight exported.Height - proofCounterpartyChannel []byte + path *ibctesting.Path + proofChannel []byte + proofHeight exported.Height ) timeoutUpgrade := func() { @@ -1510,7 +1512,9 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { "success: proof timestamp has passed", func() { timeoutUpgrade() - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) }, nil, }, @@ -1572,7 +1576,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) // modify state so the proof becomes invalid. channel.State = types.FLUSHING @@ -1592,7 +1597,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) }, types.ErrInvalidUpgradeSequence, }, @@ -1605,7 +1611,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointB.UpdateClient()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) }, types.ErrInvalidUpgradeTimeout, }, @@ -1620,7 +1627,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) }, types.ErrInvalidCounterparty, }, @@ -1640,7 +1648,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) suite.Require().NoError(path.EndpointB.UpdateClient()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) }, connectiontypes.ErrConnectionNotFound, }, @@ -1654,7 +1663,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) }, types.ErrUpgradeTimeoutFailed, }, @@ -1676,7 +1686,8 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) - proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight = path.EndpointB.QueryProof(channelKey) tc.malleate() @@ -1685,7 +1696,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.GetChannel(), - proofCounterpartyChannel, + proofChannel, proofHeight, ) diff --git a/modules/core/keeper/msg_server_test.go b/modules/core/keeper/msg_server_test.go index a791801d132..2aeec6c37a2 100644 --- a/modules/core/keeper/msg_server_test.go +++ b/modules/core/keeper/msg_server_test.go @@ -1005,7 +1005,7 @@ func (suite *KeeperTestSuite) TestChannelUpgradeTry() { counterpartyUpgrade, found := suite.chainA.GetSimApp().GetIBCKeeper().ChannelKeeper.GetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) suite.Require().True(found) - proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() msg = &channeltypes.MsgChannelUpgradeTry{ PortId: path.EndpointB.ChannelConfig.PortID, @@ -1174,7 +1174,7 @@ func (suite *KeeperTestSuite) TestChannelUpgradeAck() { counterpartyUpgrade := path.EndpointB.GetChannelUpgrade() - proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() msg = &channeltypes.MsgChannelUpgradeAck{ PortId: path.EndpointA.ChannelConfig.PortID, @@ -1247,7 +1247,7 @@ func (suite *KeeperTestSuite) TestChannelUpgradeConfirm() { counterpartyChannelState := path.EndpointA.GetChannel().State counterpartyUpgrade := path.EndpointA.GetChannelUpgrade() - proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() msg = &channeltypes.MsgChannelUpgradeConfirm{ PortId: path.EndpointB.ChannelConfig.PortID, @@ -1336,7 +1336,7 @@ func (suite *KeeperTestSuite) TestChannelUpgradeConfirm() { err := path.EndpointB.UpdateClient() suite.Require().NoError(err) - proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() msg.CounterpartyUpgrade = upgrade msg.ProofChannel = proofChannel @@ -1383,7 +1383,7 @@ func (suite *KeeperTestSuite) TestChannelUpgradeConfirm() { counterpartyChannelState := path.EndpointA.GetChannel().State counterpartyUpgrade := path.EndpointA.GetChannelUpgrade() - proofChannel, proofUpgrade, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, proofHeight := path.EndpointA.QueryChannelUpgradeProof() msg = &channeltypes.MsgChannelUpgradeConfirm{ PortId: path.EndpointB.ChannelConfig.PortID, @@ -1484,7 +1484,8 @@ func (suite *KeeperTestSuite) TestChannelUpgradeOpen() { suite.Require().NoError(err) counterpartyChannel := path.EndpointB.GetChannel() - proofChannel, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + proofChannel, proofHeight := path.EndpointB.QueryProof(channelKey) msg = &channeltypes.MsgChannelUpgradeOpen{ PortId: path.EndpointA.ChannelConfig.PortID, @@ -1646,7 +1647,9 @@ func (suite *KeeperTestSuite) TestChannelUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) - channelProof, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight := path.EndpointB.QueryProof(channelKey) + msg.ProofChannel = channelProof msg.ProofHeight = proofHeight }, @@ -1687,7 +1690,8 @@ func (suite *KeeperTestSuite) TestChannelUpgradeTimeout() { suite.Require().NoError(path.EndpointA.UpdateClient()) - _, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + _, _, proofHeight := path.EndpointB.QueryChannelUpgradeProof() + msg.ProofHeight = proofHeight msg.ProofChannel = []byte("invalid proof") }, @@ -1720,7 +1724,8 @@ func (suite *KeeperTestSuite) TestChannelUpgradeTimeout() { suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) - channelProof, _, proofHeight := path.EndpointA.QueryChannelUpgradeProof() + channelKey := host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) + channelProof, proofHeight := path.EndpointB.QueryProof(channelKey) msg = &channeltypes.MsgChannelUpgradeTimeout{ PortId: path.EndpointA.ChannelConfig.PortID, diff --git a/testing/endpoint.go b/testing/endpoint.go index 34a9dac9ad1..0738d5e735c 100644 --- a/testing/endpoint.go +++ b/testing/endpoint.go @@ -569,19 +569,14 @@ func (endpoint *Endpoint) TimeoutOnClose(packet channeltypes.Packet) error { } // QueryChannelUpgradeProof returns all the proofs necessary to execute UpgradeTry/UpgradeAck/UpgradeOpen. -// It returns the proof for the channel on the counterparty chain, the proof for the upgrade attempt on the -// counterparty chain, and the height at which the proof was queried. +// It returns the proof for the channel on the endpoint's chain, the proof for the upgrade attempt on the +// endpoint's chain, and the height at which the proof was queried. func (endpoint *Endpoint) QueryChannelUpgradeProof() ([]byte, []byte, clienttypes.Height) { - counterpartyChannelID := endpoint.Counterparty.ChannelID - counterpartyPortID := endpoint.Counterparty.ChannelConfig.PortID + channelKey := host.ChannelKey(endpoint.ChannelConfig.PortID, endpoint.ChannelID) + proofChannel, height := endpoint.QueryProof(channelKey) - // query proof for the channel on the counterparty - channelKey := host.ChannelKey(counterpartyPortID, counterpartyChannelID) - proofChannel, height := endpoint.Counterparty.QueryProof(channelKey) - - // query proof for the upgrade attempt on the counterparty - upgradeKey := host.ChannelUpgradeKey(counterpartyPortID, counterpartyChannelID) - proofUpgrade, _ := endpoint.Counterparty.QueryProof(upgradeKey) + upgradeKey := host.ChannelUpgradeKey(endpoint.ChannelConfig.PortID, endpoint.ChannelID) + proofUpgrade, _ := endpoint.QueryProof(upgradeKey) return proofChannel, proofUpgrade, height } @@ -629,7 +624,7 @@ func (endpoint *Endpoint) ChanUpgradeTry() error { require.NoError(endpoint.Chain.TB, err) upgrade := endpoint.GetProposedUpgrade() - proofChannel, proofUpgrade, height := endpoint.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, height := endpoint.Counterparty.QueryChannelUpgradeProof() counterpartyUpgrade, found := endpoint.Counterparty.Chain.App.GetIBCKeeper().ChannelKeeper.GetUpgrade(endpoint.Counterparty.Chain.GetContext(), endpoint.Counterparty.ChannelConfig.PortID, endpoint.Counterparty.ChannelID) require.True(endpoint.Chain.TB, found) @@ -658,7 +653,7 @@ func (endpoint *Endpoint) ChanUpgradeAck() error { err := endpoint.UpdateClient() require.NoError(endpoint.Chain.TB, err) - proofChannel, proofUpgrade, height := endpoint.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, height := endpoint.Counterparty.QueryChannelUpgradeProof() counterpartyUpgrade, found := endpoint.Counterparty.Chain.App.GetIBCKeeper().ChannelKeeper.GetUpgrade(endpoint.Counterparty.Chain.GetContext(), endpoint.Counterparty.ChannelConfig.PortID, endpoint.Counterparty.ChannelID) require.True(endpoint.Chain.TB, found) @@ -681,7 +676,7 @@ func (endpoint *Endpoint) ChanUpgradeConfirm() error { err := endpoint.UpdateClient() require.NoError(endpoint.Chain.TB, err) - proofChannel, proofUpgrade, height := endpoint.QueryChannelUpgradeProof() + proofChannel, proofUpgrade, height := endpoint.Counterparty.QueryChannelUpgradeProof() counterpartyUpgrade, found := endpoint.Counterparty.Chain.App.GetIBCKeeper().ChannelKeeper.GetUpgrade(endpoint.Counterparty.Chain.GetContext(), endpoint.Counterparty.ChannelConfig.PortID, endpoint.Counterparty.ChannelID) require.True(endpoint.Chain.TB, found)