From dbcff45015337d734591dc33c0b4886c687b5af8 Mon Sep 17 00:00:00 2001 From: Cian Hatton Date: Thu, 23 May 2024 09:28:04 +0100 Subject: [PATCH] Refactor packet data unmarshalling to use specific version (#6354) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: specifically unmarshal v1 or v2 without brute force * chore: fix TestPacketDataUnmarshalerInterface test in transfer module * Pass dest values OnRecv, refactor GetExpectedEvents * chore: fixing TestGetCallbackData test * chore: fixed remaining tests in callbacks module * test: simplify callbacks test, revert back to previous behaviour * chore: fix test case name * chore: addressing PR feedback * chore: added docstring for unmarshalPacketDataBytesToICS20V2 --------- Co-authored-by: DimitrisJim Co-authored-by: Colin Axnér <25233464+colin-axner@users.noreply.github.com> --- modules/apps/callbacks/callbacks_test.go | 4 +- modules/apps/callbacks/ibc_middleware_test.go | 43 ++++++- .../apps/callbacks/types/callbacks_test.go | 117 ++++++++---------- modules/apps/callbacks/types/export_test.go | 8 ++ modules/apps/callbacks/types/types_test.go | 11 +- modules/apps/transfer/ibc_module.go | 71 ++++++++--- modules/apps/transfer/ibc_module_test.go | 44 ++++++- 7 files changed, 202 insertions(+), 96 deletions(-) diff --git a/modules/apps/callbacks/callbacks_test.go b/modules/apps/callbacks/callbacks_test.go index ef88db93483..f9874896a22 100644 --- a/modules/apps/callbacks/callbacks_test.go +++ b/modules/apps/callbacks/callbacks_test.go @@ -305,11 +305,11 @@ func GetExpectedEvent( gasMeter := storetypes.NewGasMeter(remainingGas) ctx = ctx.WithGasMeter(gasMeter) - // Mock packet. - packet := channeltypes.NewPacket(data, 0, srcPortID, "", "", "", clienttypes.ZeroHeight(), 0) if callbackType == types.CallbackTypeReceivePacket { + packet := channeltypes.NewPacket(data, seq, "", "", eventPortID, eventChannelID, clienttypes.ZeroHeight(), 0) callbackData, err = types.GetDestCallbackData(ctx, packetDataUnmarshaler, packet, maxCallbackGas) } else { + packet := channeltypes.NewPacket(data, seq, eventPortID, eventChannelID, "", "", clienttypes.ZeroHeight(), 0) callbackData, err = types.GetSourceCallbackData(ctx, packetDataUnmarshaler, packet, maxCallbackGas) } if err != nil { diff --git a/modules/apps/callbacks/ibc_middleware_test.go b/modules/apps/callbacks/ibc_middleware_test.go index 0b487d93611..7c394bcaa9f 100644 --- a/modules/apps/callbacks/ibc_middleware_test.go +++ b/modules/apps/callbacks/ibc_middleware_test.go @@ -971,8 +971,13 @@ func (s *CallbacksTestSuite) TestProcessCallback() { } } -func (s *CallbacksTestSuite) TestUnmarshalPacketData() { +func (s *CallbacksTestSuite) TestUnmarshalPacketDataV1() { s.setupChains() + s.path.EndpointA.ChannelConfig.PortID = ibctesting.TransferPort + s.path.EndpointB.ChannelConfig.PortID = ibctesting.TransferPort + s.path.EndpointA.ChannelConfig.Version = transfertypes.V1 + s.path.EndpointB.ChannelConfig.Version = transfertypes.V1 + s.path.Setup() // We will pass the function call down the transfer stack to the transfer module // transfer stack UnmarshalPacketData call order: callbacks -> fee -> transfer @@ -1006,15 +1011,43 @@ func (s *CallbacksTestSuite) TestUnmarshalPacketData() { portID := s.path.EndpointA.ChannelConfig.PortID channelID := s.path.EndpointA.ChannelID - // Unmarshal ICS20 v1 packet data + // Unmarshal ICS20 v1 packet data into v2 packet data data := expPacketDataICS20V1.GetBytes() packetData, err := unmarshalerStack.UnmarshalPacketData(s.chainA.GetContext(), portID, channelID, data) s.Require().NoError(err) s.Require().Equal(expPacketDataICS20V2, packetData) +} + +func (s *CallbacksTestSuite) TestUnmarshalPacketDataV2() { + s.SetupTransferTest() + + // We will pass the function call down the transfer stack to the transfer module + // transfer stack UnmarshalPacketData call order: callbacks -> fee -> transfer + transferStack, ok := s.chainA.App.GetIBCKeeper().PortKeeper.Route(transfertypes.ModuleName) + s.Require().True(ok) - // Unmarshal ICS20 v1 packet data - data = expPacketDataICS20V2.GetBytes() - packetData, err = unmarshalerStack.UnmarshalPacketData(s.chainA.GetContext(), portID, channelID, data) + unmarshalerStack, ok := transferStack.(types.CallbacksCompatibleModule) + s.Require().True(ok) + + expPacketDataICS20V2 := transfertypes.FungibleTokenPacketDataV2{ + Tokens: []transfertypes.Token{ + { + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Trace: nil, + }, + }, + Sender: ibctesting.TestAccAddress, + Receiver: ibctesting.TestAccAddress, + Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}, "dest_callback": {"address":"%s"}}`, ibctesting.TestAccAddress, ibctesting.TestAccAddress), + } + + portID := s.path.EndpointA.ChannelConfig.PortID + channelID := s.path.EndpointA.ChannelID + + // Unmarshal ICS20 v2 packet data + data := expPacketDataICS20V2.GetBytes() + packetData, err := unmarshalerStack.UnmarshalPacketData(s.chainA.GetContext(), portID, channelID, data) s.Require().NoError(err) s.Require().Equal(expPacketDataICS20V2, packetData) } diff --git a/modules/apps/callbacks/types/callbacks_test.go b/modules/apps/callbacks/types/callbacks_test.go index 754118dda90..be4098f9bc9 100644 --- a/modules/apps/callbacks/types/callbacks_test.go +++ b/modules/apps/callbacks/types/callbacks_test.go @@ -10,23 +10,20 @@ import ( "github.com/cometbft/cometbft/crypto/secp256k1" "github.com/cosmos/ibc-go/modules/apps/callbacks/types" - "github.com/cosmos/ibc-go/v8/modules/apps/transfer" transfertypes "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" - porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types" ibctesting "github.com/cosmos/ibc-go/v8/testing" ibcmock "github.com/cosmos/ibc-go/v8/testing/mock" ) func (s *CallbacksTypesTestSuite) TestGetCallbackData() { var ( - sender = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() - receiver = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() - packetDataUnmarshaler porttypes.PacketDataUnmarshaler - packetData []byte - remainingGas uint64 - callbackKey string + sender = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + receiver = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + packetData interface{} + remainingGas uint64 + callbackKey string ) // max gas is 1_000_000 @@ -40,14 +37,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: source callback", func() { remainingGas = 2_000_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"src_callback": {"address": "%s"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -61,15 +57,15 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: destination callback", func() { callbackKey = types.DestinationCallbackKey + remainingGas = 2_000_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -83,15 +79,15 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: destination callback with 0 user defined gas limit", func() { callbackKey = types.DestinationCallbackKey + remainingGas = 2_000_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit":"0"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -104,14 +100,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { { "success: source callback with gas limit < remaining gas < max gas", func() { - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "50000"}}`, sender), } - packetData = expPacketData.GetBytes() remainingGas = 100_000 }, @@ -127,14 +122,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: source callback with remaining gas < gas limit < max gas", func() { remainingGas = 100_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "200000"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -148,14 +142,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: source callback with remaining gas < max gas < gas limit", func() { remainingGas = 100_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "2000000"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -169,15 +162,15 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: destination callback with remaining gas < max gas < gas limit", func() { callbackKey = types.DestinationCallbackKey + remainingGas = 100_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"dest_callback": {"address": "%s", "gas_limit": "2000000"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -191,14 +184,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { "success: source callback with max gas < remaining gas < gas limit", func() { remainingGas = 2_000_000 - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: fmt.Sprintf(`{"src_callback": {"address": "%s", "gas_limit": "3000000"}}`, sender), } - packetData = expPacketData.GetBytes() }, types.CallbackData{ CallbackAddress: sender, @@ -208,19 +200,10 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { }, nil, }, - { - "failure: invalid packet data", - func() { - packetData = []byte("invalid packet data") - }, - types.CallbackData{}, - types.ErrCannotUnmarshalPacketData, - }, { "failure: packet data does not implement PacketDataProvider", func() { packetData = ibcmock.MockPacketData - packetDataUnmarshaler = ibcmock.IBCModule{} }, types.CallbackData{}, types.ErrNotPacketDataProvider, @@ -228,14 +211,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { { "failure: empty memo", func() { - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: "", } - packetData = expPacketData.GetBytes() }, types.CallbackData{}, types.ErrCallbackKeyNotFound, @@ -243,14 +225,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { { "failure: empty address", func() { - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: `{"src_callback": {"address": ""}}`, } - packetData = expPacketData.GetBytes() }, types.CallbackData{}, types.ErrCallbackAddressNotFound, @@ -258,14 +239,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { { "failure: space address", func() { - expPacketData := transfertypes.FungibleTokenPacketData{ + packetData = transfertypes.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), Sender: sender, Receiver: receiver, Memo: `{"src_callback": {"address": " "}}`, } - packetData = expPacketData.GetBytes() }, types.CallbackData{}, types.ErrCallbackAddressNotFound, @@ -275,27 +255,13 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { for _, tc := range testCases { tc := tc s.Run(tc.name, func() { - callbackKey = types.SourceCallbackKey + s.SetupTest() - packetDataUnmarshaler = transfer.IBCModule{} + callbackKey = types.SourceCallbackKey tc.malleate() - // Set up gas meter for context. - gasMeter := storetypes.NewGasMeter(remainingGas) - ctx := s.chain.GetContext().WithGasMeter(gasMeter) - - packet := channeltypes.NewPacket(packetData, 0, ibcmock.PortID, "", "", "", clienttypes.ZeroHeight(), 0) - - var ( - callbackData types.CallbackData - err error - ) - if callbackKey == types.DestinationCallbackKey { - callbackData, err = types.GetDestCallbackData(ctx, packetDataUnmarshaler, packet, uint64(1_000_000)) - } else { - callbackData, err = types.GetSourceCallbackData(ctx, packetDataUnmarshaler, packet, uint64(1_000_000)) - } + callbackData, err := types.GetCallbackData(packetData, transfertypes.PortID, remainingGas, uint64(1_000_000), callbackKey) expPass := tc.expError == nil if expPass { @@ -312,6 +278,8 @@ func (s *CallbacksTypesTestSuite) TestGetCallbackData() { } func (s *CallbacksTypesTestSuite) TestGetSourceCallbackDataTransfer() { + s.SetupTest() + sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() @@ -331,19 +299,32 @@ func (s *CallbacksTypesTestSuite) TestGetSourceCallbackDataTransfer() { CommitGasLimit: 1_000_000, } - packetUnmarshaler := transfer.IBCModule{} + s.path.EndpointA.ChannelConfig.Version = transfertypes.V1 + s.path.EndpointA.ChannelConfig.PortID = transfertypes.ModuleName + s.path.EndpointB.ChannelConfig.Version = transfertypes.V1 + s.path.EndpointB.ChannelConfig.PortID = transfertypes.ModuleName + + transferStack, ok := s.chainA.App.GetIBCKeeper().PortKeeper.Route(transfertypes.ModuleName) + s.Require().True(ok) + + packetUnmarshaler, ok := transferStack.(types.CallbacksCompatibleModule) + s.Require().True(ok) + + s.path.Setup() // Set up gas meter for context. gasMeter := storetypes.NewGasMeter(2_000_000) - ctx := s.chain.GetContext().WithGasMeter(gasMeter) + ctx := s.chainA.GetContext().WithGasMeter(gasMeter) - packet := channeltypes.NewPacket(packetDataBytes, 0, ibcmock.PortID, "", "", "", clienttypes.ZeroHeight(), 0) + packet := channeltypes.NewPacket(packetDataBytes, 0, transfertypes.PortID, s.path.EndpointA.ChannelID, transfertypes.PortID, s.path.EndpointB.ChannelID, clienttypes.ZeroHeight(), 0) callbackData, err := types.GetSourceCallbackData(ctx, packetUnmarshaler, packet, 1_000_000) s.Require().NoError(err) s.Require().Equal(expCallbackData, callbackData) } func (s *CallbacksTypesTestSuite) TestGetDestCallbackDataTransfer() { + s.SetupTest() + sender := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() receiver := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() @@ -363,12 +344,22 @@ func (s *CallbacksTypesTestSuite) TestGetDestCallbackDataTransfer() { CommitGasLimit: 1_000_000, } - packetUnmarshaler := transfer.IBCModule{} + s.path.EndpointA.ChannelConfig.Version = transfertypes.V1 + s.path.EndpointA.ChannelConfig.PortID = transfertypes.ModuleName + s.path.EndpointB.ChannelConfig.Version = transfertypes.V1 + s.path.EndpointB.ChannelConfig.PortID = transfertypes.ModuleName - gasMeter := storetypes.NewGasMeter(2_000_000) - ctx := s.chain.GetContext().WithGasMeter(gasMeter) + transferStack, ok := s.chainA.App.GetIBCKeeper().PortKeeper.Route(transfertypes.ModuleName) + s.Require().True(ok) + + packetUnmarshaler, ok := transferStack.(types.CallbacksCompatibleModule) + s.Require().True(ok) - packet := channeltypes.NewPacket(packetDataBytes, 0, ibcmock.PortID, "", "", "", clienttypes.ZeroHeight(), 0) + s.path.Setup() + + gasMeter := storetypes.NewGasMeter(2_000_000) + ctx := s.chainA.GetContext().WithGasMeter(gasMeter) + packet := channeltypes.NewPacket(packetDataBytes, 0, transfertypes.PortID, s.path.EndpointA.ChannelID, transfertypes.PortID, s.path.EndpointB.ChannelID, clienttypes.ZeroHeight(), 0) callbackData, err := types.GetDestCallbackData(ctx, packetUnmarshaler, packet, 1_000_000) s.Require().NoError(err) s.Require().Equal(expCallbackData, callbackData) diff --git a/modules/apps/callbacks/types/export_test.go b/modules/apps/callbacks/types/export_test.go index 5b0a32508e3..4cc85624827 100644 --- a/modules/apps/callbacks/types/export_test.go +++ b/modules/apps/callbacks/types/export_test.go @@ -4,6 +4,14 @@ package types This file is to allow for unexported functions to be accessible to the testing package. */ +// GetCallbackData is a wrapper around getCallbackData to allow the function to be directly called in tests. +func GetCallbackData( + packetData interface{}, srcPortID string, remainingGas, + maxGas uint64, callbackKey string, +) (CallbackData, error) { + return getCallbackData(packetData, srcPortID, remainingGas, maxGas, callbackKey) +} + // GetCallbackAddress is a wrapper around getCallbackAddress to allow the function to be directly called in tests. func GetCallbackAddress(callbackData map[string]interface{}) string { return getCallbackAddress(callbackData) diff --git a/modules/apps/callbacks/types/types_test.go b/modules/apps/callbacks/types/types_test.go index 7bb61f95162..5843f972249 100644 --- a/modules/apps/callbacks/types/types_test.go +++ b/modules/apps/callbacks/types/types_test.go @@ -15,12 +15,19 @@ type CallbacksTypesTestSuite struct { coord *ibctesting.Coordinator chain *ibctesting.TestChain + + chainA, chainB *ibctesting.TestChain + + path *ibctesting.Path } // SetupTest creates a coordinator with 1 test chain. -func (s *CallbacksTypesTestSuite) SetupSuite() { - s.coord = ibctesting.NewCoordinator(s.T(), 1) +func (s *CallbacksTypesTestSuite) SetupTest() { + s.coord = ibctesting.NewCoordinator(s.T(), 3) s.chain = s.coord.GetChain(ibctesting.GetChainID(1)) + s.chainA = s.coord.GetChain(ibctesting.GetChainID(2)) + s.chainB = s.coord.GetChain(ibctesting.GetChainID(3)) + s.path = ibctesting.NewPath(s.chainA, s.chainB) } func TestCallbacksTypesTestSuite(t *testing.T) { diff --git a/modules/apps/transfer/ibc_module.go b/modules/apps/transfer/ibc_module.go index a58e777a3e8..c87591e6c9f 100644 --- a/modules/apps/transfer/ibc_module.go +++ b/modules/apps/transfer/ibc_module.go @@ -174,25 +174,35 @@ func (IBCModule) OnChanCloseConfirm( return nil } -func (IBCModule) unmarshalPacketDataBytesToICS20V2(bz []byte) (types.FungibleTokenPacketDataV2, error) { - // TODO: remove support for this function parsing v1 packet data - // TODO: explicit check for packet data type against app version - - var datav1 types.FungibleTokenPacketData - if err := json.Unmarshal(bz, &datav1); err == nil { - if len(datav1.Denom) != 0 { - return convertinternal.PacketDataV1ToV2(datav1), nil +// unmarshalPacketDataBytesToICS20V2 attempts to unmarshal the provided packet data bytes into a FungibleTokenPacketDataV2. +// The version of ics20 should be provided and should be either ics20-1 or ics20-2. +func (IBCModule) unmarshalPacketDataBytesToICS20V2(bz []byte, ics20Version string) (types.FungibleTokenPacketDataV2, error) { + switch ics20Version { + case types.V1: + var datav1 types.FungibleTokenPacketData + if err := json.Unmarshal(bz, &datav1); err != nil { + return types.FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot unmarshal ICS20-V2 transfer packet data: %s", err.Error()) } - } - var data types.FungibleTokenPacketDataV2 - if err := json.Unmarshal(bz, &data); err == nil { - if len(data.Tokens) != 0 { - return data, nil + if err := datav1.ValidateBasic(); err != nil { + return types.FungibleTokenPacketDataV2{}, err + } + + return convertinternal.PacketDataV1ToV2(datav1), nil + case types.V2: + var datav2 types.FungibleTokenPacketDataV2 + if err := json.Unmarshal(bz, &datav2); err != nil { + return types.FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot unmarshal ICS20-V2 transfer packet data: %s", err.Error()) } - } - return types.FungibleTokenPacketDataV2{}, errorsmod.Wrapf(ibcerrors.ErrInvalidType, "cannot unmarshal ICS-20 transfer packet data") + if err := datav2.ValidateBasic(); err != nil { + return types.FungibleTokenPacketDataV2{}, err + } + + return datav2, nil + default: + return types.FungibleTokenPacketDataV2{}, errorsmod.Wrap(types.ErrInvalidVersion, ics20Version) + } } // OnRecvPacket implements the IBCModule interface. A successful acknowledgement @@ -205,8 +215,14 @@ func (im IBCModule) OnRecvPacket( ) ibcexported.Acknowledgement { ack := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) - data, ackErr := im.unmarshalPacketDataBytesToICS20V2(packet.GetData()) + ics20Version, found := im.keeper.GetICS4Wrapper().GetAppVersion(ctx, packet.DestinationPort, packet.DestinationChannel) + if !found { + return channeltypes.NewErrorAcknowledgement(errorsmod.Wrapf(ibcerrors.ErrNotFound, "app version not found for port %s and channel %s", packet.DestinationPort, packet.DestinationChannel)) + } + + data, ackErr := im.unmarshalPacketDataBytesToICS20V2(packet.GetData(), ics20Version) if ackErr != nil { + ackErr = errorsmod.Wrapf(ibcerrors.ErrInvalidType, ackErr.Error()) im.keeper.Logger(ctx).Error(fmt.Sprintf("%s sequence %d", ackErr.Error(), packet.Sequence)) ack = channeltypes.NewErrorAcknowledgement(ackErr) } @@ -261,7 +277,12 @@ func (im IBCModule) OnAcknowledgementPacket( return errorsmod.Wrapf(ibcerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet acknowledgement: %v", err) } - data, err := im.unmarshalPacketDataBytesToICS20V2(packet.GetData()) + ics20Version, found := im.keeper.GetICS4Wrapper().GetAppVersion(ctx, packet.SourcePort, packet.SourceChannel) + if !found { + return errorsmod.Wrapf(ibcerrors.ErrNotFound, "app version not found for port %s and channel %s", packet.SourcePort, packet.SourceChannel) + } + + data, err := im.unmarshalPacketDataBytesToICS20V2(packet.GetData(), ics20Version) if err != nil { return err } @@ -311,7 +332,12 @@ func (im IBCModule) OnTimeoutPacket( packet channeltypes.Packet, relayer sdk.AccAddress, ) error { - data, err := im.unmarshalPacketDataBytesToICS20V2(packet.GetData()) + ics20Version, found := im.keeper.GetICS4Wrapper().GetAppVersion(ctx, packet.SourcePort, packet.SourceChannel) + if !found { + return errorsmod.Wrapf(ibcerrors.ErrNotFound, "app version not found for port %s and channel %s", packet.SourcePort, packet.SourceChannel) + } + + data, err := im.unmarshalPacketDataBytesToICS20V2(packet.GetData(), ics20Version) if err != nil { return err } @@ -380,8 +406,13 @@ func (IBCModule) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, pr // UnmarshalPacketData attempts to unmarshal the provided packet data bytes // into a FungibleTokenPacketData. This function implements the optional // PacketDataUnmarshaler interface required for ADR 008 support. -func (im IBCModule) UnmarshalPacketData(_ sdk.Context, _, _ string, bz []byte) (interface{}, error) { - ftpd, err := im.unmarshalPacketDataBytesToICS20V2(bz) +func (im IBCModule) UnmarshalPacketData(ctx sdk.Context, portID, channelID string, bz []byte) (interface{}, error) { + ics20Version, found := im.keeper.GetICS4Wrapper().GetAppVersion(ctx, portID, channelID) + if !found { + return nil, errorsmod.Wrapf(ibcerrors.ErrNotFound, "app version not found for port %s and channel %s", portID, channelID) + } + + ftpd, err := im.unmarshalPacketDataBytesToICS20V2(bz, ics20Version) if err != nil { return nil, err } diff --git a/modules/apps/transfer/ibc_module_test.go b/modules/apps/transfer/ibc_module_test.go index 46a2302ff8e..e52b04fc7da 100644 --- a/modules/apps/transfer/ibc_module_test.go +++ b/modules/apps/transfer/ibc_module_test.go @@ -496,6 +496,7 @@ func (suite *TransferTestSuite) TestPacketDataUnmarshalerInterface() { sender = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() receiver = sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()).String() + path *ibctesting.Path data []byte initialPacketData interface{} ) @@ -508,6 +509,7 @@ func (suite *TransferTestSuite) TestPacketDataUnmarshalerInterface() { { "success: valid packet data single denom -> multidenom conversion with memo", func() { + path.EndpointA.ChannelConfig.Version = types.V1 initialPacketData = types.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), @@ -523,6 +525,7 @@ func (suite *TransferTestSuite) TestPacketDataUnmarshalerInterface() { { "success: valid packet data single denom -> multidenom conversion without memo", func() { + path.EndpointA.ChannelConfig.Version = types.V1 initialPacketData = types.FungibleTokenPacketData{ Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), @@ -538,6 +541,7 @@ func (suite *TransferTestSuite) TestPacketDataUnmarshalerInterface() { { "success: valid packet data single denom with trace -> multidenom conversion with trace", func() { + path.EndpointA.ChannelConfig.Version = types.V1 initialPacketData = types.FungibleTokenPacketData{ Denom: "transfer/channel-0/atom", Amount: ibctesting.TestCoin.Amount.String(), @@ -571,14 +575,15 @@ func (suite *TransferTestSuite) TestPacketDataUnmarshalerInterface() { nil, }, { - "success: valid packet data multidenom without memo", + "success: valid packet data multidenom nil trace", func() { + path.EndpointA.ChannelConfig.Version = types.V2 initialPacketData = types.FungibleTokenPacketDataV2{ Tokens: []types.Token{ { Denom: ibctesting.TestCoin.Denom, Amount: ibctesting.TestCoin.Amount.String(), - Trace: []string{""}, + Trace: nil, }, }, Sender: sender, @@ -590,21 +595,52 @@ func (suite *TransferTestSuite) TestPacketDataUnmarshalerInterface() { }, nil, }, + { + "failure: invalid token trace", + func() { + path.EndpointA.ChannelConfig.Version = types.V2 + initialPacketData = types.FungibleTokenPacketDataV2{ + Tokens: []types.Token{ + { + Denom: ibctesting.TestCoin.Denom, + Amount: ibctesting.TestCoin.Amount.String(), + Trace: []string{""}, + }, + }, + Sender: sender, + Receiver: receiver, + Memo: "", + } + + data = initialPacketData.(types.FungibleTokenPacketDataV2).GetBytes() + }, + errors.New("trace info must come in pairs of port and channel identifiers"), + }, { "failure: invalid packet data", func() { data = []byte("invalid packet data") }, - errors.New("cannot unmarshal ICS-20 transfer packet data"), + errors.New("cannot unmarshal ICS20-V2 transfer packet data"), }, } for _, tc := range testCases { tc := tc suite.Run(tc.name, func() { + path = ibctesting.NewTransferPath(suite.chainA, suite.chainB) + tc.malleate() - packetData, err := transfer.IBCModule{}.UnmarshalPacketData(suite.chainA.GetContext(), "", "", data) + path.Setup() + + transferStack, ok := suite.chainA.App.GetIBCKeeper().PortKeeper.Route(types.ModuleName) + suite.Require().True(ok) + + unmarshalerStack, ok := transferStack.(porttypes.PacketDataUnmarshaler) + suite.Require().True(ok) + + packetData, err := unmarshalerStack.UnmarshalPacketData(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, data) expPass := tc.expError == nil if expPass {