Skip to content

Commit

Permalink
ChanUpgradeTimeout handler for 04-channel (#3829)
Browse files Browse the repository at this point in the history
  • Loading branch information
charleenfei authored Jun 17, 2023
1 parent e201873 commit 1625adb
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 1 deletion.
94 changes: 93 additions & 1 deletion modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,103 @@ func (k Keeper) ChanUpgradeTimeout(
ctx sdk.Context,
portID, channelID string,
counterpartyChannel types.Channel,
prevErrorReceipt types.ErrorReceipt,
prevErrorReceipt *types.ErrorReceipt,
proofCounterpartyChannel,
proofErrorReceipt []byte,
proofHeight exported.Height,
) error {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

if channel.State != types.INITUPGRADE {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "channel state is not INITUPGRADE (got %s)", channel.State)
}

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

connection, found := k.connectionKeeper.GetConnection(ctx, channel.ConnectionHops[0])
if !found {
return errorsmod.Wrap(
connectiontypes.ErrConnectionNotFound,
channel.ConnectionHops[0],
)
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(
connectiontypes.ErrInvalidConnectionState,
"connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String(),
)
}

// proof must be from a height after timeout has elapsed. Either timeoutHeight or timeoutTimestamp must be defined.
// if timeoutHeight is defined and proof is from before timeout height, abort transaction
proofTimestamp, err := k.connectionKeeper.GetTimestampAtHeight(ctx, connection, proofHeight)
if err != nil {
return err
}

timeout := upgrade.Timeout
proofHeightIsInvalid := timeout.Height.IsZero() || proofHeight.LT(timeout.Height)
proofTimestampIsInvalid := timeout.Timestamp == 0 || proofTimestamp < timeout.Timestamp
if proofHeightIsInvalid && proofTimestampIsInvalid {
return errorsmod.Wrap(types.ErrInvalidUpgradeTimeout, "timeout has not yet passed on counterparty chain")
}

// counterparty channel must be proved to still be in OPEN state or INITUPGRADE state (crossing hellos)
if !collections.Contains(counterpartyChannel.State, []types.State{types.OPEN, types.INITUPGRADE}) {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.OPEN, types.INITUPGRADE, counterpartyChannel.State)
}

// verify the counterparty channel state
if err := k.connectionKeeper.VerifyChannelState(
ctx,
connection,
proofHeight, proofCounterpartyChannel,
channel.Counterparty.PortId,
channel.Counterparty.ChannelId,
counterpartyChannel,
); err != nil {
return errorsmod.Wrap(err, "failed to verify counterparty channel state")
}

// Error receipt passed in is either nil or it is a stale error receipt from a previous upgrade
if prevErrorReceipt == nil {
if err := k.connectionKeeper.VerifyChannelUpgradeErrorAbsence(
ctx,
channel.Counterparty.PortId, channel.Counterparty.ChannelId,
connection,
proofErrorReceipt,
proofHeight,
); err != nil {
return errorsmod.Wrap(err, "failed to verify absence of counterparty channel upgrade error receipt")
}

return nil
}
// timeout for this sequence can only succeed if the error receipt written into the error path on the counterparty
// was for a previous sequence by the timeout deadline.
upgradeSequence := channel.UpgradeSequence
if upgradeSequence < prevErrorReceipt.Sequence {
return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "previous counterparty error receipt sequence is greater than our current upgrade sequence: %d > %d", prevErrorReceipt.Sequence, upgradeSequence)
}

if err := k.connectionKeeper.VerifyChannelUpgradeError(
ctx,
channel.Counterparty.PortId, channel.Counterparty.ChannelId,
connection,
*prevErrorReceipt,
proofErrorReceipt,
proofHeight,
); err != nil {
return errorsmod.Wrap(err, "failed to verify counterparty channel upgrade error receipt")
}

return nil
}

Expand Down
191 changes: 191 additions & 0 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,197 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
}
}

func (suite *KeeperTestSuite) TestChanUpgradeTimeout() {
var (
path *ibctesting.Path
errReceipt *types.ErrorReceipt
proofHeight exported.Height
proofCounterpartyChannel []byte
proofErrorReceipt []byte
)

testCases := []struct {
name string
malleate func()
expError error
}{
{
"success: proof height has passed",
func() {},
nil,
},
{
"success: proof timestamp has passed",
func() {
upgrade := path.EndpointA.GetProposedUpgrade()
upgrade.Timeout.Height = defaultTimeoutHeight
upgrade.Timeout.Timestamp = 5
suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade)

suite.Require().NoError(path.EndpointA.UpdateClient())

proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof()
upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
nil,
},
{
"success: non-nil error receipt",
func() {
errReceipt = &types.ErrorReceipt{
Sequence: 1,
Message: types.ErrInvalidUpgrade.Error(),
}

suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, *errReceipt)

suite.Require().NoError(path.EndpointB.UpdateClient())
suite.Require().NoError(path.EndpointA.UpdateClient())

proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof()
upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
nil,
},
{
"channel not found",
func() {
path.EndpointA.ChannelID = ibctesting.InvalidID
},
types.ErrChannelNotFound,
},
{
"channel state is not in INITUPGRADE state",
func() {
suite.Require().NoError(path.EndpointA.SetChannelState(types.ACKUPGRADE))
},
types.ErrInvalidChannelState,
},
{
"current upgrade not found",
func() {
suite.chainA.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
},
types.ErrUpgradeNotFound,
},
{
"connection not found",
func() {
channel := path.EndpointA.GetChannel()
channel.ConnectionHops[0] = ibctesting.InvalidID
path.EndpointA.SetChannel(channel)
},
connectiontypes.ErrConnectionNotFound,
},
{
"connection not open",
func() {
connectionEnd := path.EndpointA.GetConnection()
connectionEnd.State = connectiontypes.UNINITIALIZED
path.EndpointA.SetConnection(connectionEnd)
},
connectiontypes.ErrInvalidConnectionState,
},
{
"unable to retrieve timestamp at proof height",
func() {
proofHeight = suite.chainA.GetTimeoutHeight()
},
clienttypes.ErrConsensusStateNotFound,
},
{
"timeout has not passed",
func() {
upgrade := path.EndpointA.GetProposedUpgrade()
upgrade.Timeout.Height = suite.chainA.GetTimeoutHeight()
suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade)

suite.Require().NoError(path.EndpointA.UpdateClient())

proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof()
upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
types.ErrInvalidUpgradeTimeout,
},
{
"counterparty channel state is not OPEN or INITUPGRADE (crossing hellos)",
func() {
channel := path.EndpointB.GetChannel()
channel.State = types.TRYUPGRADE
path.EndpointB.SetChannel(channel)

suite.Require().NoError(path.EndpointB.UpdateClient())
suite.Require().NoError(path.EndpointA.UpdateClient())

proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof()
upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey)
},
types.ErrInvalidChannelState,
},
{
"non-nil error receipt: error receipt seq greater than current upgrade seq",
func() {
errReceipt = &types.ErrorReceipt{
Sequence: 3,
Message: types.ErrInvalidUpgrade.Error(),
}
},
types.ErrInvalidUpgradeSequence,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()
expPass := tc.expError == nil

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion

errReceipt = nil

// set timeout height to 1 to ensure timeout
path.EndpointA.ChannelConfig.ProposedUpgrade.Timeout.Height = clienttypes.NewHeight(1, 1)
suite.Require().NoError(path.EndpointA.ChanUpgradeInit())

// ensure clients are up to date to receive valid proofs
suite.Require().NoError(path.EndpointB.UpdateClient())
suite.Require().NoError(path.EndpointA.UpdateClient())

proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof()
upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey)

tc.malleate()

err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTimeout(
suite.chainA.GetContext(),
path.EndpointA.ChannelConfig.PortID,
path.EndpointA.ChannelID,
path.EndpointB.GetChannel(),
errReceipt,
proofCounterpartyChannel,
proofErrorReceipt,
proofHeight,
)

if expPass {
suite.Require().NoError(err)
} else {
suite.assertUpgradeError(err, tc.expError)
}
})
}
}

// TestStartFlushUpgradeHandshake tests the startFlushUpgradeHandshake.
// UpgradeInit will be run on chainA and startFlushUpgradeHandshake
// will be called on chainB
Expand Down
1 change: 1 addition & 0 deletions modules/core/04-channel/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ var (
ErrInvalidFlushStatus = errorsmod.Register(SubModuleName, 33, "invalid flush status")
ErrUpgradeRestoreFailed = errorsmod.Register(SubModuleName, 34, "restore failed")
ErrUpgradeTimeout = errorsmod.Register(SubModuleName, 35, "upgrade timed-out")
ErrInvalidUpgradeTimeout = errorsmod.Register(SubModuleName, 36, "upgrade timeout is invalid")
)

0 comments on commit 1625adb

Please sign in to comment.