diff --git a/x/ibc/core/04-channel/types/channel.go b/x/ibc/core/04-channel/types/channel.go index 9bddb84fa40c..8513a8123df8 100644 --- a/x/ibc/core/04-channel/types/channel.go +++ b/x/ibc/core/04-channel/types/channel.go @@ -55,7 +55,7 @@ func (ch Channel) GetVersion() string { // ValidateBasic performs a basic validation of the channel fields func (ch Channel) ValidateBasic() error { - if ch.State.String() == "" { + if ch.State == UNINITIALIZED { return ErrInvalidChannelState } if !(ch.Ordering == ORDERED || ch.Ordering == UNORDERED) { diff --git a/x/ibc/core/04-channel/types/channel_test.go b/x/ibc/core/04-channel/types/channel_test.go index 14592d4b4cff..30fee4443b2e 100644 --- a/x/ibc/core/04-channel/types/channel_test.go +++ b/x/ibc/core/04-channel/types/channel_test.go @@ -8,6 +8,33 @@ import ( "github.com/cosmos/cosmos-sdk/x/ibc/core/04-channel/types" ) +func TestChannelValidateBasic(t *testing.T) { + counterparty := types.Counterparty{"portidone", "channelidone"} + testCases := []struct { + name string + channel types.Channel + expPass bool + }{ + {"valid channel", types.NewChannel(types.TRYOPEN, types.ORDERED, counterparty, connHops, version), true}, + {"invalid state", types.NewChannel(types.UNINITIALIZED, types.ORDERED, counterparty, connHops, version), false}, + {"invalid order", types.NewChannel(types.TRYOPEN, types.NONE, counterparty, connHops, version), false}, + {"more than 1 connection hop", types.NewChannel(types.TRYOPEN, types.ORDERED, counterparty, []string{"connection1", "connection2"}, version), false}, + {"invalid connection hop identifier", types.NewChannel(types.TRYOPEN, types.ORDERED, counterparty, []string{"(invalid)"}, version), false}, + {"invalid counterparty", types.NewChannel(types.TRYOPEN, types.ORDERED, types.NewCounterparty("(invalidport)", "channelidone"), connHops, version), false}, + } + + for i, tc := range testCases { + tc := tc + + err := tc.channel.ValidateBasic() + if tc.expPass { + require.NoError(t, err, "valid test case %d failed: %s", i, tc.name) + } else { + require.Error(t, err, "invalid test case %d passed: %s", i, tc.name) + } + } +} + func TestCounterpartyValidateBasic(t *testing.T) { testCases := []struct { name string