Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid race conditions in ics27 handshakes #2682

Merged
merged 10 commits into from
Nov 7, 2022
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/suite"

"github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller"
controllerkeeper "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/keeper"
"github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/types"
icatypes "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/types"
fee "github.com/cosmos/ibc-go/v6/modules/apps/29-fee"
Expand Down Expand Up @@ -840,3 +841,80 @@ func (suite *InterchainAccountsTestSuite) TestGetAppVersion() {
suite.Require().True(found)
suite.Require().Equal(path.EndpointA.ChannelConfig.Version, appVersion)
}

func (suite *InterchainAccountsTestSuite) TestInFlightHandshakeRespectsGoAPICaller() {
path := NewICAPath(suite.chainA, suite.chainB)
suite.coordinator.SetupConnections(path)

// initiate a channel handshake such that channel.State == INIT
err := RegisterInterchainAccount(path.EndpointA, suite.chainA.SenderAccount.GetAddress().String())
suite.Require().NoError(err)

// attempt to start a second handshake via the controller msg server
msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper)
msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), TestVersion)

res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount)
suite.Require().Error(err)
suite.Require().Nil(res)
}

func (suite *InterchainAccountsTestSuite) TestInFlightHandshakeRespectsMsgServerCaller() {
path := NewICAPath(suite.chainA, suite.chainB)
suite.coordinator.SetupConnections(path)

// initiate a channel handshake such that channel.State == INIT
msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper)
msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), TestVersion)

res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount)
suite.Require().NotNil(res)
suite.Require().NoError(err)

// attempt to start a second handshake via the legacy Go API
err = RegisterInterchainAccount(path.EndpointA, suite.chainA.SenderAccount.GetAddress().String())
suite.Require().Error(err)
}

func (suite *InterchainAccountsTestSuite) TestClosedChannelReopensWithMsgServer() {
path := NewICAPath(suite.chainA, suite.chainB)
suite.coordinator.SetupConnections(path)

err := SetupICAPath(path, suite.chainA.SenderAccount.GetAddress().String())
suite.Require().NoError(err)

// set the channel state to closed
err = path.EndpointA.SetChannelClosed()
suite.Require().NoError(err)
err = path.EndpointB.SetChannelClosed()
suite.Require().NoError(err)

// reset endpoint channel ids
path.EndpointA.ChannelID = ""
path.EndpointB.ChannelID = ""

// fetch the next channel sequence before reinitiating the channel handshake
channelSeq := suite.chainA.GetSimApp().GetIBCKeeper().ChannelKeeper.GetNextChannelSequence(suite.chainA.GetContext())

// route a new MsgRegisterInterchainAccount in order to reopen the
msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper)
msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), path.EndpointA.ChannelConfig.Version)

res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount)
suite.Require().NoError(err)
suite.Require().Equal(channeltypes.FormatChannelIdentifier(channelSeq), res.ChannelId)

// assign the channel sequence to endpointA before generating proofs and initiating the TRY step
path.EndpointA.ChannelID = channeltypes.FormatChannelIdentifier(channelSeq)

path.EndpointA.Chain.NextBlock()

err = path.EndpointB.ChanOpenTry()
suite.Require().NoError(err)

err = path.EndpointA.ChanOpenAck()
suite.Require().NoError(err)

err = path.EndpointB.ChanOpenConfirm()
suite.Require().NoError(err)
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (k Keeper) RegisterInterchainAccount(ctx sdk.Context, connectionID, owner,
return err
}

if k.IsMiddlewareDisabled(ctx, portID, connectionID) && !k.IsActiveChannelClosed(ctx, connectionID, portID) {
return sdkerrors.Wrap(icatypes.ErrInvalidChannelFlow, "channel is already active or a handshake is in flight")
}

k.SetMiddlewareEnabled(ctx, portID, connectionID)

_, err = k.registerInterchainAccount(ctx, connectionID, portID, version)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, state genesistypes.ControllerGe

if ch.IsMiddlewareEnabled {
keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ConnectionId)
} else {
keeper.SetMiddlewareDisabled(ctx, ch.PortId, ch.ConnectionId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ func (suite *KeeperTestSuite) TestInitGenesis() {
ChannelId: ibctesting.FirstChannelID,
IsMiddlewareEnabled: true,
},
{
ConnectionId: "connection-1",
PortId: "test-port-1",
ChannelId: "channel-1",
IsMiddlewareEnabled: false,
},
},
InterchainAccounts: []genesistypes.RegisteredInterchainAccount{
{
Expand All @@ -40,6 +46,9 @@ func (suite *KeeperTestSuite) TestInitGenesis() {
isMiddlewareEnabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareEnabled(suite.chainA.GetContext(), TestPortID, ibctesting.FirstConnectionID)
suite.Require().True(isMiddlewareEnabled)

isMiddlewareDisabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareDisabled(suite.chainA.GetContext(), "test-port-1", "connection-1")
suite.Require().True(isMiddlewareDisabled)

accountAdrr, found := suite.chainA.GetSimApp().ICAControllerKeeper.GetInterchainAccountAddress(suite.chainA.GetContext(), ibctesting.FirstConnectionID, TestPortID)
suite.Require().True(found)
suite.Require().Equal(interchainAccAddr.String(), accountAdrr)
Expand Down
28 changes: 26 additions & 2 deletions modules/apps/27-interchain-accounts/controller/keeper/keeper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package keeper

import (
"bytes"
"fmt"
"strings"

Expand Down Expand Up @@ -146,6 +147,17 @@ func (k Keeper) GetOpenActiveChannel(ctx sdk.Context, connectionID, portID strin
return "", false
}

// IsActiveChannelClosed retrieves the active channel from the store and returns true if the channel state is CLOSED, otherwise false
func (k Keeper) IsActiveChannelClosed(ctx sdk.Context, connectionID, portID string) bool {
channelID, found := k.GetActiveChannelID(ctx, connectionID, portID)
if !found {
return false
}

channel, found := k.channelKeeper.GetChannel(ctx, portID, channelID)
return found && channel.State == channeltypes.CLOSED
}

// GetAllActiveChannels returns a list of all active interchain accounts controller channels and their associated connection and port identifiers
func (k Keeper) GetAllActiveChannels(ctx sdk.Context) []genesistypes.ActiveChannel {
store := ctx.KVStore(k.storeKey)
Expand Down Expand Up @@ -227,13 +239,25 @@ func (k Keeper) SetInterchainAccountAddress(ctx sdk.Context, connectionID, portI
// IsMiddlewareEnabled returns true if the underlying application callbacks are enabled for given port and connection identifier pair, otherwise false
func (k Keeper) IsMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) bool {
store := ctx.KVStore(k.storeKey)
return store.Has(icatypes.KeyIsMiddlewareEnabled(portID, connectionID))
return bytes.Equal(icatypes.MiddlewareEnabled, store.Get(icatypes.KeyIsMiddlewareEnabled(portID, connectionID)))
}

// IsMiddlewareDisabled returns true if the underlying application callbacks are disabled for the given port and connection identifier pair, otherwise false
func (k Keeper) IsMiddlewareDisabled(ctx sdk.Context, portID, connectionID string) bool {
store := ctx.KVStore(k.storeKey)
return bytes.Equal(icatypes.MiddlewareDisabled, store.Get(icatypes.KeyIsMiddlewareEnabled(portID, connectionID)))
}

// SetMiddlewareEnabled stores a flag to indicate that the underlying application callbacks should be enabled for the given port and connection identifier pair
func (k Keeper) SetMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) {
store := ctx.KVStore(k.storeKey)
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), []byte{byte(1)})
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), icatypes.MiddlewareEnabled)
}

// SetMiddlewareDisabled stores a flag to indicate that the underlying application callbacks should be disabled for the given port and connection identifier pair
func (k Keeper) SetMiddlewareDisabled(ctx sdk.Context, portID, connectionID string) {
store := ctx.KVStore(k.storeKey)
store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), icatypes.MiddlewareDisabled)
}

// DeleteMiddlewareEnabled deletes the middleware enabled flag stored in state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"

"github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/types"
icatypes "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/types"
Expand All @@ -30,6 +31,12 @@ func (s msgServer) RegisterInterchainAccount(goCtx context.Context, msg *types.M
return nil, err
}

if s.IsMiddlewareEnabled(ctx, portID, msg.ConnectionId) && !s.IsActiveChannelClosed(ctx, msg.ConnectionId, portID) {
return nil, sdkerrors.Wrap(icatypes.ErrInvalidChannelFlow, "channel is already active or a handshake is in flight")
}

s.SetMiddlewareDisabled(ctx, portID, msg.ConnectionId)

channelID, err := s.registerInterchainAccount(ctx, msg.ConnectionId, portID, msg.Version)
if err != nil {
return nil, err
Expand Down
6 changes: 6 additions & 0 deletions modules/apps/27-interchain-accounts/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ var (

// IsMiddlewareEnabledPrefix defines the key prefix used to store a flag for legacy API callback routing via ibc middleware
IsMiddlewareEnabledPrefix = "isMiddlewareEnabled"

// MiddlewareEnabled is the value used to signal that controller middleware is enabled
MiddlewareEnabled = []byte{0x01}

// MiddlewareDisabled is the value used to signal that controller midleware is disabled
MiddlewareDisabled = []byte{0x02}
Comment on lines 43 to +50
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be in favour of moving all the middleware enabled associated key/values to controllertypes. But I don't feel particularly strongly. Happy to leave in the top level icatypes as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No preference! Whatever you think is best

)

// KeyActiveChannel creates and returns a new key used for active channels store operations
Expand Down