From 5cf62801976c24a151a3155befd9432c7f101976 Mon Sep 17 00:00:00 2001 From: toshiSat <10103480+toshiSat@users.noreply.github.com> Date: Thu, 30 Jan 2025 12:34:39 -0700 Subject: [PATCH] Make sure from address != to address in tc claim --- x/claim/keeper/msg_server_claim_thorchain.go | 7 ++- .../keeper/msg_server_claim_thorchain_test.go | 51 +++++++++++++++++++ x/claim/types/errors.go | 1 + 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/x/claim/keeper/msg_server_claim_thorchain.go b/x/claim/keeper/msg_server_claim_thorchain.go index 7a3a059c..9efccf63 100644 --- a/x/claim/keeper/msg_server_claim_thorchain.go +++ b/x/claim/keeper/msg_server_claim_thorchain.go @@ -14,6 +14,11 @@ func (k msgServer) ClaimThorchain(goCtx context.Context, msg *types.MsgClaimThor ctx := sdk.UnwrapSDKContext(goCtx) k.Logger(ctx).Info(msg.Creator) + // Add check for matching addresses + if msg.FromAddress == msg.ToAddress { + return nil, errors.Wrapf(types.ErrInvalidAddress, "from address and to address cannot be the same: %s", msg.FromAddress) + } + // only allow thorchain claim server address to call this function if msg.Creator != "tarkeo1z02ke8639m47g9dfrheegr2u9zecegt5qvtj00" && msg.Creator != "arkeo1z02ke8639m47g9dfrheegr2u9zecegt50fjg7v" { return nil, errors.Wrapf(types.ErrInvalidCreator, "Invalid Creator %s", msg.Creator) @@ -23,7 +28,7 @@ func (k msgServer) ClaimThorchain(goCtx context.Context, msg *types.MsgClaimThor if err != nil { return nil, errors.Wrapf(err, "failed to get claim record for %s", msg.FromAddress) } - if fromAddressClaimRecord.IsEmpty() || fromAddressClaimRecord.AmountClaim.IsZero() { + if fromAddressClaimRecord.IsEmpty() || (fromAddressClaimRecord.AmountClaim.IsZero() && fromAddressClaimRecord.AmountVote.IsZero() && fromAddressClaimRecord.AmountDelegate.IsZero()) { return nil, errors.Wrapf(types.ErrNoClaimableAmount, "no claimable amount for %s", msg.FromAddress) } diff --git a/x/claim/keeper/msg_server_claim_thorchain_test.go b/x/claim/keeper/msg_server_claim_thorchain_test.go index 2add8579..bfead1a3 100644 --- a/x/claim/keeper/msg_server_claim_thorchain_test.go +++ b/x/claim/keeper/msg_server_claim_thorchain_test.go @@ -156,3 +156,54 @@ func TestClaimThorchainMainnetAddress(t *testing.T) { _, err = msgServer.ClaimThorchain(ctx, &claimMessage) require.ErrorIs(t, err, types.ErrNoClaimableAmount) } + +func TestClaimThorchainFailureCases(t *testing.T) { + msgServer, keepers, ctx := setupMsgServer(t) + sdkCtx := sdk.UnwrapSDKContext(ctx) + + config := sdk.GetConfig() + config.SetBech32PrefixForAccount("arkeo", "arkeopub") + + arkeoServerAddress, err := sdk.AccAddressFromBech32("arkeo1z02ke8639m47g9dfrheegr2u9zecegt50fjg7v") + require.NoError(t, err) + + fromAddr := utils.GetRandomArkeoAddress() + toAddr := utils.GetRandomArkeoAddress() + + // Test case 1: Same from and to address + sameAddressMsg := types.MsgClaimThorchain{ + Creator: arkeoServerAddress.String(), + FromAddress: fromAddr.String(), + ToAddress: fromAddr.String(), + } + _, err = msgServer.ClaimThorchain(ctx, &sameAddressMsg) + require.ErrorIs(t, types.ErrInvalidAddress, err) + + // Test case 2: Empty claim record for from address + emptyFromMsg := types.MsgClaimThorchain{ + Creator: arkeoServerAddress.String(), + FromAddress: fromAddr.String(), + ToAddress: toAddr.String(), + } + _, err = msgServer.ClaimThorchain(ctx, &emptyFromMsg) + require.ErrorIs(t, types.ErrNoClaimableAmount, err) + + // Test case 3: Zero amount claim record + zeroClaimRecord := types.ClaimRecord{ + Chain: types.ARKEO, + Address: fromAddr.String(), + AmountClaim: sdk.NewInt64Coin(types.DefaultClaimDenom, 0), + AmountVote: sdk.NewInt64Coin(types.DefaultClaimDenom, 0), + AmountDelegate: sdk.NewInt64Coin(types.DefaultClaimDenom, 0), + } + err = keepers.ClaimKeeper.SetClaimRecord(sdkCtx, zeroClaimRecord) + require.NoError(t, err) + + zeroAmountMsg := types.MsgClaimThorchain{ + Creator: arkeoServerAddress.String(), + FromAddress: fromAddr.String(), + ToAddress: toAddr.String(), + } + _, err = msgServer.ClaimThorchain(ctx, &zeroAmountMsg) + require.ErrorIs(t, types.ErrNoClaimableAmount, err) +} diff --git a/x/claim/types/errors.go b/x/claim/types/errors.go index 3079c18b..7102b660 100644 --- a/x/claim/types/errors.go +++ b/x/claim/types/errors.go @@ -10,4 +10,5 @@ var ( ErrInvalidSignature = errors.Register(ModuleName, 3, "Invalid signature") ErrClaimRecordNotTransferrable = errors.Register(ModuleName, 4, "Claim record can not be transferred") ErrInvalidCreator = errors.Register(ModuleName, 5, "Invalid Creator") + ErrInvalidAddress = errors.Register(ModuleName, 6, "invalid address") )