diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 0b6193b9f04..c0422a3ee41 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -805,9 +805,18 @@ func (k Keeper) constructProposedUpgrade(ctx sdk.Context, portID, channelID stri }, nil } -// AbortUpgrade will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted. -// any unnecessary state is deleted. An error receipt is written, and the OnChanUpgradeRestore callback is called. -func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err error) error { +// MustAbortUpgrade will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted. +// Any unnecessary state is deleted and an error receipt is written. +// This function is expected to always succeed, a panic will occur if an error occurs. +func (k Keeper) MustAbortUpgrade(ctx sdk.Context, portID, channelID string, err error) { + if err := k.abortUpgrade(ctx, portID, channelID, err); err != nil { + panic(err) + } +} + +// abortUpgrade will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted. +// Any unnecessary state is delete and an error receipt is written. +func (k Keeper) abortUpgrade(ctx sdk.Context, portID, channelID string, err error) error { if err == nil { return errorsmod.Wrap(types.ErrInvalidUpgradeError, "cannot abort upgrade handshake with nil error") } diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index d1d671d9ded..9ff671d4541 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -1305,10 +1305,10 @@ func (suite *KeeperTestSuite) TestAbortHandshake() { tc.malleate() - err := channelKeeper.AbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) - if tc.expPass { - suite.Require().NoError(err) + suite.Require().NotPanics(func() { + channelKeeper.MustAbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + }) channel, found := channelKeeper.GetChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) suite.Require().True(found, "channel should be found") @@ -1329,7 +1329,11 @@ func (suite *KeeperTestSuite) TestAbortHandshake() { suite.Require().False(found, "counterparty last packet sequence should not be found") } else { - suite.Require().Error(err) + + suite.Require().Panics(func() { + channelKeeper.MustAbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + }) + channel, found := channelKeeper.GetChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) if found { // test cases uses a channel that exists suite.Require().Equal(types.INITUPGRADE, channel.State, "channel state should not be restored to %s", types.INITUPGRADE.String()) diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index b7c68c0045b..92e1e41f2f9 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -769,9 +769,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh if err != nil { ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed")) if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok { - if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr); err != nil { - return nil, errorsmod.Wrap(err, "channel upgrade try (abort upgrade) failed") - } + k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr) // NOTE: a FAILURE result is returned to the client and an error receipt is written to state. // This signals to the relayer to begin the cancel upgrade handshake subprotocol. @@ -786,9 +784,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh upgradeVersion, err := cbs.OnChanUpgradeTry(cacheCtx, msg.PortId, msg.ChannelId, upgrade.Fields.Ordering, upgrade.Fields.ConnectionHops, upgrade.Fields.Version) if err != nil { ctx.Logger().Error("channel upgrade try callback failed", "port-id", msg.PortId, "channel-id", msg.ChannelId, "error", err.Error()) - if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, err); err != nil { - return nil, err - } + k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err) return &channeltypes.MsgChannelUpgradeTryResponse{Result: channeltypes.FAILURE}, nil } @@ -826,9 +822,7 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh if err != nil { ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed")) if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok { - if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr); err != nil { - return nil, errorsmod.Wrap(err, "channel upgrade ack (abort upgrade) failed") - } + k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr) // NOTE: a FAILURE result is returned to the client and an error receipt is written to state. // This signals to the relayer to begin the cancel upgrade handshake subprotocol. @@ -843,9 +837,7 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh err = cbs.OnChanUpgradeAck(cacheCtx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade.Fields.Version) if err != nil { ctx.Logger().Error("channel upgrade ack callback failed", "port-id", msg.PortId, "channel-id", msg.ChannelId, "error", err.Error()) - if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, err); err != nil { - return nil, errorsmod.Wrapf(err, "channel upgrade ack callback (abort upgrade) failed for port ID: %s, channel ID: %s", msg.PortId, msg.ChannelId) - } + k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err) return &channeltypes.MsgChannelUpgradeAckResponse{Result: channeltypes.FAILURE}, nil }