Skip to content

Commit

Permalink
Make AbortUpgrade panic on failure (#4011)
Browse files Browse the repository at this point in the history
* chore: update abort upgrade function to panic on error

* apply review suggestions

---------

Co-authored-by: Carlos Rodriguez <carlos@interchain.io>
  • Loading branch information
chatton and crodriguezvega authored Jul 6, 2023
1 parent c249e1e commit 05c43dc
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 19 deletions.
15 changes: 12 additions & 3 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -805,9 +805,18 @@ func (k Keeper) constructProposedUpgrade(ctx sdk.Context, portID, channelID stri
}, nil
}

// AbortUpgrade will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted.
// any unnecessary state is deleted. An error receipt is written, and the OnChanUpgradeRestore callback is called.
func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err error) error {
// MustAbortUpgrade will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted.
// Any unnecessary state is deleted and an error receipt is written.
// This function is expected to always succeed, a panic will occur if an error occurs.
func (k Keeper) MustAbortUpgrade(ctx sdk.Context, portID, channelID string, err error) {
if err := k.abortUpgrade(ctx, portID, channelID, err); err != nil {
panic(err)
}
}

// abortUpgrade will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted.
// Any unnecessary state is delete and an error receipt is written.
func (k Keeper) abortUpgrade(ctx sdk.Context, portID, channelID string, err error) error {
if err == nil {
return errorsmod.Wrap(types.ErrInvalidUpgradeError, "cannot abort upgrade handshake with nil error")
}
Expand Down
12 changes: 8 additions & 4 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1305,10 +1305,10 @@ func (suite *KeeperTestSuite) TestAbortHandshake() {

tc.malleate()

err := channelKeeper.AbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError)

if tc.expPass {
suite.Require().NoError(err)
suite.Require().NotPanics(func() {
channelKeeper.MustAbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError)
})

channel, found := channelKeeper.GetChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
suite.Require().True(found, "channel should be found")
Expand All @@ -1329,7 +1329,11 @@ func (suite *KeeperTestSuite) TestAbortHandshake() {
suite.Require().False(found, "counterparty last packet sequence should not be found")

} else {
suite.Require().Error(err)

suite.Require().Panics(func() {
channelKeeper.MustAbortUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError)
})

channel, found := channelKeeper.GetChannel(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)
if found { // test cases uses a channel that exists
suite.Require().Equal(types.INITUPGRADE, channel.State, "channel state should not be restored to %s", types.INITUPGRADE.String())
Expand Down
16 changes: 4 additions & 12 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -769,9 +769,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
if err != nil {
ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed"))
if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok {
if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr); err != nil {
return nil, errorsmod.Wrap(err, "channel upgrade try (abort upgrade) failed")
}
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr)

// NOTE: a FAILURE result is returned to the client and an error receipt is written to state.
// This signals to the relayer to begin the cancel upgrade handshake subprotocol.
Expand All @@ -786,9 +784,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
upgradeVersion, err := cbs.OnChanUpgradeTry(cacheCtx, msg.PortId, msg.ChannelId, upgrade.Fields.Ordering, upgrade.Fields.ConnectionHops, upgrade.Fields.Version)
if err != nil {
ctx.Logger().Error("channel upgrade try callback failed", "port-id", msg.PortId, "channel-id", msg.ChannelId, "error", err.Error())
if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, err); err != nil {
return nil, err
}
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)

return &channeltypes.MsgChannelUpgradeTryResponse{Result: channeltypes.FAILURE}, nil
}
Expand Down Expand Up @@ -826,9 +822,7 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh
if err != nil {
ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed"))
if upgradeErr, ok := err.(*channeltypes.UpgradeError); ok {
if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr); err != nil {
return nil, errorsmod.Wrap(err, "channel upgrade ack (abort upgrade) failed")
}
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, upgradeErr)

// NOTE: a FAILURE result is returned to the client and an error receipt is written to state.
// This signals to the relayer to begin the cancel upgrade handshake subprotocol.
Expand All @@ -843,9 +837,7 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh
err = cbs.OnChanUpgradeAck(cacheCtx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade.Fields.Version)
if err != nil {
ctx.Logger().Error("channel upgrade ack callback failed", "port-id", msg.PortId, "channel-id", msg.ChannelId, "error", err.Error())
if err := k.ChannelKeeper.AbortUpgrade(ctx, msg.PortId, msg.ChannelId, err); err != nil {
return nil, errorsmod.Wrapf(err, "channel upgrade ack callback (abort upgrade) failed for port ID: %s, channel ID: %s", msg.PortId, msg.ChannelId)
}
k.ChannelKeeper.MustAbortUpgrade(ctx, msg.PortId, msg.ChannelId, err)

return &channeltypes.MsgChannelUpgradeAckResponse{Result: channeltypes.FAILURE}, nil
}
Expand Down

0 comments on commit 05c43dc

Please sign in to comment.