Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(chore) Refactor code around forwarding validation #6706

Merged
merged 12 commits into from
Jun 26, 2024
Merged
71 changes: 36 additions & 35 deletions modules/apps/transfer/types/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,22 @@ func NewMsgTransfer(
// NOTE: The recipient addresses format is not validated as the format defined by
// the chain is not known to IBC.
func (msg MsgTransfer) ValidateBasic() error {
if err := validateSourcePortAndChannel(msg); err != nil {
return err // The actual error and its message are already wrapped in the called function.
if err := msg.validateForwarding(); err != nil {
return err
}
if !msg.Forwarding.Unwind {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ate my other comment 😅

this seems odd outside of the validateForwarding, seems like it could be mvoed there if not for early return of if !msg.ShouldBeForwarded. Non blocker for me though, should probably check the walkthrough call first to see concern raised!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to achieve the same, but couldn't find a way that's actually nice.

The problem is that, if we move this inside validateForwarding, we are actually validating something that is not forwarding-related (since we have len(hops)==0 && !msg.Forwarind.Unwind) hidden there.

I thought of the current situation as a valid compromise but I'm open to suggestion

// We verify that portID and channelID are valid IDs only if
// we are not setting unwind to true.
// In that case, validation that they are empty is performed in
// validateForwarding().
if err := host.PortIdentifierValidator(msg.SourcePort); err != nil {
return errorsmod.Wrap(err, "invalid source port ID")
}

if err := host.ChannelIdentifierValidator(msg.SourceChannel); err != nil {
return errorsmod.Wrap(err, "invalid source channel ID")
}
}
if len(msg.Tokens) == 0 && !isValidIBCCoin(msg.Token) {
return errorsmod.Wrap(ibcerrors.ErrInvalidCoins, "either token or token array must be filled")
}
Expand All @@ -99,30 +111,41 @@ func (msg MsgTransfer) ValidateBasic() error {
return errorsmod.Wrapf(ErrInvalidMemo, "memo must not exceed %d bytes", MaximumMemoLength)
}

for _, coin := range msg.GetCoins() {
if err := validateIBCCoin(coin); err != nil {
return errorsmod.Wrapf(ibcerrors.ErrInvalidCoins, "%s: %s", err.Error(), coin.String())
}
}

return nil
}

func (msg MsgTransfer) validateForwarding() error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets slap a docusting on this here method

if !msg.ShouldBeForwarded() {
return nil
}
if err := msg.Forwarding.Validate(); err != nil {
return err
}

if msg.ShouldBeForwarded() {
if !msg.TimeoutHeight.IsZero() {
// when forwarding, the timeout height must not be set
if !msg.TimeoutHeight.IsZero() {
return errorsmod.Wrapf(ErrInvalidPacketTimeout, "timeout height must not be set if forwarding path hops is not empty: %s, %s", msg.TimeoutHeight, msg.Forwarding.Hops)
}
return errorsmod.Wrapf(ErrInvalidPacketTimeout, "timeout height must not be set if forwarding path hops is not empty: %s, %s", msg.TimeoutHeight, msg.Forwarding.Hops)
bznein marked this conversation as resolved.
Show resolved Hide resolved
}

if msg.Forwarding.Unwind {
// When unwinding, we must have at most one token.
if msg.SourcePort != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source port must be empty when unwind is set, got %s instead", msg.SourcePort)
}
if msg.SourceChannel != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source channel must be empty when unwind is set, got %s instead", msg.SourceChannel)
}
if len(msg.GetCoins()) > 1 {
// When unwinding, we must have at most one token.
return errorsmod.Wrap(ibcerrors.ErrInvalidCoins, "cannot unwind more than one token")
}
}

for _, coin := range msg.GetCoins() {
if err := validateIBCCoin(coin); err != nil {
return errorsmod.Wrapf(ibcerrors.ErrInvalidCoins, "%s: %s", err.Error(), coin.String())
}
}

return nil
}

Expand Down Expand Up @@ -164,25 +187,3 @@ func validateIBCCoin(coin sdk.Coin) error {

return nil
}

func validateSourcePortAndChannel(msg MsgTransfer) error {
// If unwind is set, we want to ensure that port and channel are empty.
if msg.Forwarding.Unwind {
if msg.SourcePort != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source port must be empty when unwind is set, got %s instead", msg.SourcePort)
}
if msg.SourceChannel != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source channel must be empty when unwind is set, got %s instead", msg.SourceChannel)
}
return nil
}

// Otherwise, we just do the usual validation of the port and channel identifiers.
if err := host.PortIdentifierValidator(msg.SourcePort); err != nil {
return errorsmod.Wrap(err, "invalid source port ID")
}
if err := host.ChannelIdentifierValidator(msg.SourceChannel); err != nil {
return errorsmod.Wrap(err, "invalid source channel ID")
}
return nil
}
2 changes: 2 additions & 0 deletions modules/apps/transfer/types/msgs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func TestMsgTransferValidation(t *testing.T) {
{"invalid forwarding info port", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, types.Hop{PortId: invalidPort, ChannelId: validChannel})), types.ErrInvalidForwarding},
{"invalid forwarding info channel", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, types.Hop{PortId: validPort, ChannelId: invalidChannel})), types.ErrInvalidForwarding},
{"invalid forwarding info too many hops", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, generateHops(types.MaximumNumberOfForwardingHops+1)...)), types.ErrInvalidForwarding},
{"invalid portID when forwarding is set but unwind is not", types.NewMsgTransfer("", validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, validHop)), host.ErrInvalidID},
{"invalid channelID when forwarding is set but unwind is not", types.NewMsgTransfer(validPort, "", coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, validHop)), host.ErrInvalidID},
{"unwind specified but source port is not empty", types.NewMsgTransfer(validPort, "", coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), types.ErrInvalidForwarding},
{"unwind specified but source channel is not empty", types.NewMsgTransfer("", validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), types.ErrInvalidForwarding},
{"unwind specified but more than one coin in the message", types.NewMsgTransfer("", "", coins.Add(sdk.NewCoin("atom", ibctesting.TestCoin.Amount)), sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), ibcerrors.ErrInvalidCoins},
Expand Down
Loading