diff --git a/x/records/keeper/epoch_unbonding_record.go b/x/records/keeper/epoch_unbonding_record.go index 734d9c5a1..645f5e465 100644 --- a/x/records/keeper/epoch_unbonding_record.go +++ b/x/records/keeper/epoch_unbonding_record.go @@ -119,22 +119,21 @@ func (k Keeper) AddHostZoneToEpochUnbondingRecord(ctx sdk.Context, epochNumber u return &epochUnbondingRecord, true } -// TODO: unittest -func (k Keeper) SetHostZoneUnbondings(ctx sdk.Context, zone stakeibctypes.HostZone, epochUnbondingRecordIds []uint64, status types.HostZoneUnbonding_Status) error { +func (k Keeper) SetHostZoneUnbondings(ctx sdk.Context, chainId string, epochUnbondingRecordIds []uint64, status types.HostZoneUnbonding_Status) error { for _, epochUnbondingRecordId := range epochUnbondingRecordIds { k.Logger(ctx).Info(fmt.Sprintf("Updating host zone unbondings on EpochUnbondingRecord %d to status %s", epochUnbondingRecordId, status.String())) // fetch the host zone unbonding - hostZoneUnbonding, found := k.GetHostZoneUnbondingByChainId(ctx, epochUnbondingRecordId, zone.ChainId) + hostZoneUnbonding, found := k.GetHostZoneUnbondingByChainId(ctx, epochUnbondingRecordId, chainId) if !found { - errMsg := fmt.Sprintf("Error fetching host zone unbonding record for epoch: %d, host zone: %s", epochUnbondingRecordId, zone.ChainId) + errMsg := fmt.Sprintf("Error fetching host zone unbonding record for epoch: %d, host zone: %s", epochUnbondingRecordId, chainId) k.Logger(ctx).Error(errMsg) return sdkerrors.Wrapf(stakeibctypes.ErrHostZoneNotFound, errMsg) } hostZoneUnbonding.Status = status // save the updated hzu on the epoch unbonding record - updatedRecord, success := k.AddHostZoneToEpochUnbondingRecord(ctx, epochUnbondingRecordId, zone.ChainId, hostZoneUnbonding) + updatedRecord, success := k.AddHostZoneToEpochUnbondingRecord(ctx, epochUnbondingRecordId, chainId, hostZoneUnbonding) if !success { - errMsg := fmt.Sprintf("Error adding host zone unbonding record to epoch unbonding record: %d, host zone: %s", epochUnbondingRecordId, zone.ChainId) + errMsg := fmt.Sprintf("Error adding host zone unbonding record to epoch unbonding record: %d, host zone: %s", epochUnbondingRecordId, chainId) k.Logger(ctx).Error(errMsg) return sdkerrors.Wrap(types.ErrAddingHostZone, errMsg) } diff --git a/x/records/keeper/epoch_unbonding_record_test.go b/x/records/keeper/epoch_unbonding_record_test.go index c423003d4..9a0f87878 100644 --- a/x/records/keeper/epoch_unbonding_record_test.go +++ b/x/records/keeper/epoch_unbonding_record_test.go @@ -12,19 +12,45 @@ import ( "github.com/Stride-Labs/stride/v3/x/records/types" ) -func createNEpochUnbondingRecord(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.EpochUnbondingRecord { - items := make([]types.EpochUnbondingRecord, n) - for i, item := range items { - item.EpochNumber = uint64(i) - items[i] = item - keeper.SetEpochUnbondingRecord(ctx, item) +func createNEpochUnbondingRecord(keeper *keeper.Keeper, ctx sdk.Context, n int) ([]types.EpochUnbondingRecord, map[string]types.HostZoneUnbonding) { + hostZoneUnbondingsList := []types.HostZoneUnbonding{ + { + HostZoneId: "host-A", + Status: types.HostZoneUnbonding_UNBONDING_QUEUE, + }, + { + HostZoneId: "host-B", + Status: types.HostZoneUnbonding_UNBONDING_QUEUE, + }, + { + HostZoneId: "host-C", + Status: types.HostZoneUnbonding_UNBONDING_QUEUE, + }, } - return items + hostZoneUnbondingsMap := make(map[string]types.HostZoneUnbonding) + for _, hostZoneUnbonding := range hostZoneUnbondingsList { + hostZoneUnbondingsMap[hostZoneUnbonding.HostZoneId] = hostZoneUnbonding + } + + epochUnbondingRecords := make([]types.EpochUnbondingRecord, n) + for epochNumber, epochUnbondingRecord := range epochUnbondingRecords { + epochUnbondingRecord.EpochNumber = uint64(epochNumber) + + unbondingsCopy := make([]*types.HostZoneUnbonding, 3) + for i := range unbondingsCopy { + hostZoneUnbonding := hostZoneUnbondingsList[i] + epochUnbondingRecord.HostZoneUnbondings = append(epochUnbondingRecord.HostZoneUnbondings, &hostZoneUnbonding) + } + + epochUnbondingRecords[epochNumber] = epochUnbondingRecord + keeper.SetEpochUnbondingRecord(ctx, epochUnbondingRecord) + } + return epochUnbondingRecords, hostZoneUnbondingsMap } func TestEpochUnbondingRecordGet(t *testing.T) { keeper, ctx := keepertest.RecordsKeeper(t) - items := createNEpochUnbondingRecord(keeper, ctx, 10) + items, _ := createNEpochUnbondingRecord(keeper, ctx, 10) for _, item := range items { got, found := keeper.GetEpochUnbondingRecord(ctx, item.EpochNumber) require.True(t, found) @@ -37,7 +63,7 @@ func TestEpochUnbondingRecordGet(t *testing.T) { func TestEpochUnbondingRecordRemove(t *testing.T) { keeper, ctx := keepertest.RecordsKeeper(t) - items := createNEpochUnbondingRecord(keeper, ctx, 10) + items, _ := createNEpochUnbondingRecord(keeper, ctx, 10) for _, item := range items { keeper.RemoveEpochUnbondingRecord(ctx, item.EpochNumber) _, found := keeper.GetEpochUnbondingRecord(ctx, item.EpochNumber) @@ -47,7 +73,7 @@ func TestEpochUnbondingRecordRemove(t *testing.T) { func TestEpochUnbondingRecordGetAll(t *testing.T) { keeper, ctx := keepertest.RecordsKeeper(t) - items := createNEpochUnbondingRecord(keeper, ctx, 10) + items, _ := createNEpochUnbondingRecord(keeper, ctx, 10) require.ElementsMatch(t, nullify.Fill(items), nullify.Fill(keeper.GetAllEpochUnbondingRecord(ctx)), @@ -56,7 +82,7 @@ func TestEpochUnbondingRecordGetAll(t *testing.T) { func TestGetAllPreviousEpochUnbondingRecords(t *testing.T) { keeper, ctx := keepertest.RecordsKeeper(t) - items := createNEpochUnbondingRecord(keeper, ctx, 10) + items, _ := createNEpochUnbondingRecord(keeper, ctx, 10) currentEpoch := uint64(8) fetchedItems := items[:currentEpoch] require.ElementsMatch(t, @@ -64,3 +90,75 @@ func TestGetAllPreviousEpochUnbondingRecords(t *testing.T) { nullify.Fill(keeper.GetAllPreviousEpochUnbondingRecords(ctx, currentEpoch)), ) } + +func TestGetHostZoneUnbondingByChainId(t *testing.T) { + keeper, ctx := keepertest.RecordsKeeper(t) + _, hostZoneUnbondings := createNEpochUnbondingRecord(keeper, ctx, 10) + + expectedHostZoneUnbonding := hostZoneUnbondings["host-B"] + actualHostZoneUnbonding, found := keeper.GetHostZoneUnbondingByChainId(ctx, 1, "host-B") + + require.True(t, found) + require.Equal(t, + *actualHostZoneUnbonding, + expectedHostZoneUnbonding, + ) +} + +func TestAddHostZoneToEpochUnbondingRecord(t *testing.T) { + keeper, ctx := keepertest.RecordsKeeper(t) + epochUnbondingRecords, _ := createNEpochUnbondingRecord(keeper, ctx, 3) + + epochNumber := 0 + initialEpochUnbondingRecord := epochUnbondingRecords[epochNumber] + + // Add new host zone to initial epoch unbonding records + newHostZone := types.HostZoneUnbonding{ + HostZoneId: "host-D", + Status: types.HostZoneUnbonding_UNBONDING_QUEUE, + } + expectedEpochUnbondingRecord := initialEpochUnbondingRecord + expectedEpochUnbondingRecord.HostZoneUnbondings = append(expectedEpochUnbondingRecord.HostZoneUnbondings, &newHostZone) + + actualEpochUnbondingRecord, success := keeper.AddHostZoneToEpochUnbondingRecord(ctx, uint64(epochNumber), "host-D", &newHostZone) + + require.True(t, success) + require.Equal(t, + expectedEpochUnbondingRecord, + *actualEpochUnbondingRecord, + ) +} + +func TestSetHostZoneUnbondings(t *testing.T) { + keeper, ctx := keepertest.RecordsKeeper(t) + + initialEpochUnbondingRecords, _ := createNEpochUnbondingRecord(keeper, ctx, 4) + + epochsToUpdate := []uint64{1, 3} + hostIdToUpdate := "host-B" + newStatus := types.HostZoneUnbonding_UNBONDING_IN_PROGRESS + + expectedEpochUnbondingRecords := initialEpochUnbondingRecords + for _, epochUnbondingRecord := range expectedEpochUnbondingRecords { + for _, epochNumberToUpdate := range epochsToUpdate { + if epochUnbondingRecord.EpochNumber == epochNumberToUpdate { + for i, hostUnbonding := range epochUnbondingRecord.HostZoneUnbondings { + if hostUnbonding.HostZoneId == hostIdToUpdate { + updatedHostZoneUnbonding := hostUnbonding + updatedHostZoneUnbonding.Status = newStatus + epochUnbondingRecord.HostZoneUnbondings[i] = updatedHostZoneUnbonding + } + } + } + } + } + + err := keeper.SetHostZoneUnbondings(ctx, hostIdToUpdate, epochsToUpdate, newStatus) + require.Nil(t, err) + + actualEpochUnbondingRecord := keeper.GetAllEpochUnbondingRecord(ctx) + require.ElementsMatch(t, + expectedEpochUnbondingRecords, + actualEpochUnbondingRecord, + ) +} diff --git a/x/records/keeper/grpc_query_epoch_unbonding_record_test.go b/x/records/keeper/grpc_query_epoch_unbonding_record_test.go index bf2da6e6a..311dc38e7 100644 --- a/x/records/keeper/grpc_query_epoch_unbonding_record_test.go +++ b/x/records/keeper/grpc_query_epoch_unbonding_record_test.go @@ -18,7 +18,7 @@ import ( func TestEpochUnbondingRecordQuerySingle(t *testing.T) { keeper, ctx := keepertest.RecordsKeeper(t) wctx := sdk.WrapSDKContext(ctx) - msgs := createNEpochUnbondingRecord(keeper, ctx, 2) + msgs, _ := createNEpochUnbondingRecord(keeper, ctx, 2) for _, tc := range []struct { desc string request *types.QueryGetEpochUnbondingRecordRequest @@ -63,7 +63,7 @@ func TestEpochUnbondingRecordQuerySingle(t *testing.T) { func TestEpochUnbondingRecordQueryPaginated(t *testing.T) { keeper, ctx := keepertest.RecordsKeeper(t) wctx := sdk.WrapSDKContext(ctx) - msgs := createNEpochUnbondingRecord(keeper, ctx, 5) + msgs, _ := createNEpochUnbondingRecord(keeper, ctx, 5) request := func(next []byte, offset, limit uint64, total bool) *types.QueryAllEpochUnbondingRecordRequest { return &types.QueryAllEpochUnbondingRecordRequest{ diff --git a/x/stakeibc/keeper/icacallbacks_redemption.go b/x/stakeibc/keeper/icacallbacks_redemption.go index 102837119..f0b050e54 100644 --- a/x/stakeibc/keeper/icacallbacks_redemption.go +++ b/x/stakeibc/keeper/icacallbacks_redemption.go @@ -54,10 +54,6 @@ func RedemptionCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a if ack == nil { // handle timeout k.Logger(ctx).Error(fmt.Sprintf("RedemptionCallback timeout, ack is nil, packet %v", packet)) - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE) - if err != nil { - return err - } return nil } @@ -70,14 +66,14 @@ func RedemptionCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a if len(txMsgData.Data) == 0 { // handle tx failure k.Logger(ctx).Error(fmt.Sprintf("RedemptionCallback tx failed, txMsgData is empty, ack error, packet %v", packet)) - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE) + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE) if err != nil { return err } return nil } - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_CLAIMABLE) + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_CLAIMABLE) if err != nil { return err } diff --git a/x/stakeibc/keeper/icacallbacks_undelegate.go b/x/stakeibc/keeper/icacallbacks_undelegate.go index 41b84771c..7b7d36ac0 100644 --- a/x/stakeibc/keeper/icacallbacks_undelegate.go +++ b/x/stakeibc/keeper/icacallbacks_undelegate.go @@ -57,11 +57,6 @@ func UndelegateCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a // handle transaction failure cases if ack == nil { // handle timeout - // reset to UNBONDING_QUEUE - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_QUEUE) - if err != nil { - return err - } k.Logger(ctx).Error(fmt.Sprintf("UndelegateCallback timeout, txMsgData is nil, packet %v", packet)) return nil } @@ -73,7 +68,7 @@ func UndelegateCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a if len(txMsgData.Data) == 0 { // handle tx failure // reset to UNBONDING_QUEUE - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_QUEUE) + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_QUEUE) if err != nil { return err } @@ -103,7 +98,7 @@ func UndelegateCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a return err } // upon success, add host zone unbondings to the exit transfer queue - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE) + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE) if err != nil { return err } diff --git a/x/stakeibc/keeper/msg_server_restore_interchain_account.go b/x/stakeibc/keeper/msg_server_restore_interchain_account.go index 605d696f1..6d3edc726 100644 --- a/x/stakeibc/keeper/msg_server_restore_interchain_account.go +++ b/x/stakeibc/keeper/msg_server_restore_interchain_account.go @@ -44,15 +44,59 @@ func (k msgServer) RestoreInterchainAccount(goCtx context.Context, msg *types.Ms // If we're restoring a delegation account, we also have to reset record state if msg.AccountType == types.ICAAccountType_DELEGATION { - depositRecords := k.RecordsKeeper.GetAllDepositRecord(ctx) // revert DELEGATION_IN_PROGRESS records for the closed ICA channel (so that they can be staked) - for _, record := range depositRecords { + depositRecords := k.RecordsKeeper.GetAllDepositRecord(ctx) + for _, depositRecord := range depositRecords { + // only revert records for the select host zone + if depositRecord.HostZoneId == hostZone.ChainId && depositRecord.Status == recordtypes.DepositRecord_DELEGATION_IN_PROGRESS { + depositRecord.Status = recordtypes.DepositRecord_DELEGATION_QUEUE + k.Logger(ctx).Info(fmt.Sprintf("Setting DepositRecord %d to status DepositRecord_DELEGATION_IN_PROGRESS", depositRecord.Id)) + k.RecordsKeeper.SetDepositRecord(ctx, depositRecord) + } + } + + // revert epoch unbonding records for the closed ICA channel + epochUnbondingRecords := k.RecordsKeeper.GetAllEpochUnbondingRecord(ctx) + epochNumberForPendingUnbondingRecords := []uint64{} + epochNumberForPendingTransferRecords := []uint64{} + for _, epochUnbondingRecord := range epochUnbondingRecords { // only revert records for the select host zone - if record.HostZoneId == hostZone.ChainId && record.Status == recordtypes.DepositRecord_DELEGATION_IN_PROGRESS { - record.Status = recordtypes.DepositRecord_DELEGATION_QUEUE - k.Logger(ctx).Error(fmt.Sprintf("Setting DepositRecord %d to status DepositRecord_DELEGATION_IN_PROGRESS", record.Id)) - k.RecordsKeeper.SetDepositRecord(ctx, record) + hostZoneUnbonding, found := k.RecordsKeeper.GetHostZoneUnbondingByChainId(ctx, epochUnbondingRecord.EpochNumber, hostZone.ChainId) + if !found { + k.Logger(ctx).Info(fmt.Sprintf("No HostZoneUnbonding found for chainId: %s, epoch: %d", hostZone.ChainId, epochUnbondingRecord.EpochNumber)) + continue } + + // Revert UNBONDING_IN_PROGRESS and EXIT_TRANSFER_IN_PROGRESS records + if hostZoneUnbonding.Status == recordtypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS { + k.Logger(ctx).Info(fmt.Sprintf("HostZoneUnbonding for %s at EpochNumber %d is stuck in status %s", + hostZone.ChainId, epochUnbondingRecord.EpochNumber, recordtypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS.String(), + )) + epochNumberForPendingUnbondingRecords = append(epochNumberForPendingUnbondingRecords, epochUnbondingRecord.EpochNumber) + + } else if hostZoneUnbonding.Status == recordtypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS { + k.Logger(ctx).Info(fmt.Sprintf("HostZoneUnbonding for %s at EpochNumber %d to in status %s", + hostZone.ChainId, epochUnbondingRecord.EpochNumber, recordtypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS.String(), + )) + epochNumberForPendingTransferRecords = append(epochNumberForPendingTransferRecords, epochUnbondingRecord.EpochNumber) + } + } + // Revert UNBONDING_IN_PROGRESS records to UNBONDING_QUEUE + err := k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone.ChainId, epochNumberForPendingUnbondingRecords, recordtypes.HostZoneUnbonding_UNBONDING_QUEUE) + if err != nil { + errMsg := fmt.Sprintf("unable to update host zone unbonding record status to %s for chainId: %s and epochUnbondingRecordIds: %v, err: %s", + recordtypes.HostZoneUnbonding_UNBONDING_QUEUE.String(), hostZone.ChainId, epochNumberForPendingUnbondingRecords, err) + k.Logger(ctx).Error(errMsg) + return nil, err + } + + // Revert EXIT_TRANSFER_IN_PROGRESS records to EXIT_TRANSFER_QUEUE + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone.ChainId, epochNumberForPendingTransferRecords, recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE) + if err != nil { + errMsg := fmt.Sprintf("unable to update host zone unbonding record status to %s for chainId: %s and epochUnbondingRecordIds: %v, err: %s", + recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE.String(), hostZone.ChainId, epochNumberForPendingTransferRecords, err) + k.Logger(ctx).Error(errMsg) + return nil, err } } diff --git a/x/stakeibc/keeper/msg_server_restore_interchain_account_test.go b/x/stakeibc/keeper/msg_server_restore_interchain_account_test.go index d25fd86be..16f126df5 100644 --- a/x/stakeibc/keeper/msg_server_restore_interchain_account_test.go +++ b/x/stakeibc/keeper/msg_server_restore_interchain_account_test.go @@ -14,8 +14,20 @@ import ( stakeibc "github.com/Stride-Labs/stride/v3/x/stakeibc/types" ) +type DepositRecordStatusUpdate struct { + chainId string + initialStatus recordtypes.DepositRecord_Status + revertedStatus recordtypes.DepositRecord_Status +} + +type HostZoneUnbondingStatusUpdate struct { + initialStatus recordtypes.HostZoneUnbonding_Status + revertedStatus recordtypes.HostZoneUnbonding_Status +} type RestoreInterchainAccountTestCase struct { - validMsg stakeibc.MsgRestoreInterchainAccount + validMsg stakeibc.MsgRestoreInterchainAccount + depositRecordStatusUpdates []DepositRecordStatusUpdate + unbondingRecordStatusUpdate []HostZoneUnbondingStatusUpdate } func (s *KeeperTestSuite) SetupRestoreInterchainAccount() RestoreInterchainAccountTestCase { @@ -28,15 +40,73 @@ func (s *KeeperTestSuite) SetupRestoreInterchainAccount() RestoreInterchainAccou } s.App.StakeibcKeeper.SetHostZone(s.Ctx(), hostZone) - // Store pending records - for i := 0; i < 2; i++ { - depositRecord := recordtypes.DepositRecord{ - Id: uint64(i), - DepositEpochNumber: uint64(i), - HostZoneId: HostChainId, - Status: recordtypes.DepositRecord_DELEGATION_IN_PROGRESS, - } - s.App.RecordsKeeper.SetDepositRecord(s.Ctx(), depositRecord) + // Store deposit records with some in state pending + depositRecords := []DepositRecordStatusUpdate{ + { + // Status doesn't change + chainId: HostChainId, + initialStatus: recordtypes.DepositRecord_TRANSFER_IN_PROGRESS, + revertedStatus: recordtypes.DepositRecord_TRANSFER_IN_PROGRESS, + }, + { + // Status gets reverted from IN_PROGRESS to QUEUE + chainId: HostChainId, + initialStatus: recordtypes.DepositRecord_DELEGATION_IN_PROGRESS, + revertedStatus: recordtypes.DepositRecord_DELEGATION_QUEUE, + }, + { + // Status doesn't get reveted because it's a different host zone + chainId: "different_host_zone", + initialStatus: recordtypes.DepositRecord_DELEGATION_IN_PROGRESS, + revertedStatus: recordtypes.DepositRecord_DELEGATION_IN_PROGRESS, + }, + } + for i, depositRecord := range depositRecords { + s.App.RecordsKeeper.SetDepositRecord(s.Ctx(), recordtypes.DepositRecord{ + Id: uint64(i), + HostZoneId: depositRecord.chainId, + Status: depositRecord.initialStatus, + }) + } + + // Store epoch unbonding records with some in state pending + hostZoneUnbondingRecords := []HostZoneUnbondingStatusUpdate{ + { + // Status doesn't change + initialStatus: recordtypes.HostZoneUnbonding_UNBONDING_QUEUE, + revertedStatus: recordtypes.HostZoneUnbonding_UNBONDING_QUEUE, + }, + { + // Status gets reverted from IN_PROGRESS to QUEUE + initialStatus: recordtypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS, + revertedStatus: recordtypes.HostZoneUnbonding_UNBONDING_QUEUE, + }, + { + // Status doesn't change + initialStatus: recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE, + revertedStatus: recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE, + }, + { + // Status gets reverted from IN_PROGRESS to QUEUE + initialStatus: recordtypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS, + revertedStatus: recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE, + }, + } + for i, hostZoneUnbonding := range hostZoneUnbondingRecords { + s.App.RecordsKeeper.SetEpochUnbondingRecord(s.Ctx(), recordtypes.EpochUnbondingRecord{ + EpochNumber: uint64(i), + HostZoneUnbondings: []*recordtypes.HostZoneUnbonding{ + // The first unbonding record will get reverted, the other one will not + { + HostZoneId: HostChainId, + Status: hostZoneUnbonding.initialStatus, + }, + { + HostZoneId: "different_host_zone", + Status: hostZoneUnbonding.initialStatus, + }, + }, + }) } defaultMsg := stakeibc.MsgRestoreInterchainAccount{ @@ -46,33 +116,19 @@ func (s *KeeperTestSuite) SetupRestoreInterchainAccount() RestoreInterchainAccou } return RestoreInterchainAccountTestCase{ - validMsg: defaultMsg, + validMsg: defaultMsg, + depositRecordStatusUpdates: depositRecords, + unbondingRecordStatusUpdate: hostZoneUnbondingRecords, } } -func (s *KeeperTestSuite) TestRestoreInterchainAccount_Success() { - tc := s.SetupRestoreInterchainAccount() - owner := "GAIA.DELEGATION" - channelID := s.CreateICAChannel(owner) - portID := icatypes.PortPrefix + owner - - // Confirm there are two channels originally - channels := s.App.IBCKeeper.ChannelKeeper.GetAllChannels(s.Ctx()) - s.Require().Len(channels, 2, "there should be 2 channels initially (transfer + delegate)") - - // Close the delegation channel - channel, found := s.App.IBCKeeper.ChannelKeeper.GetChannel(s.Ctx(), portID, channelID) - s.Require().True(found, "delegation channel found") - channel.State = channeltypes.CLOSED - s.App.IBCKeeper.ChannelKeeper.SetChannel(s.Ctx(), portID, channelID, channel) - +func (s *KeeperTestSuite) RestoreChannelAndVerifySuccess(msg stakeibc.MsgRestoreInterchainAccount, portID string, channelID string) { // Restore the channel - msg := tc.validMsg _, err := s.GetMsgServer().RestoreInterchainAccount(sdk.WrapSDKContext(s.Ctx()), &msg) s.Require().NoError(err, "registered ica account successfully") - // Confirm the new channel was created - channels = s.App.IBCKeeper.ChannelKeeper.GetAllChannels(s.Ctx()) + // Confirm channel was created + channels := s.App.IBCKeeper.ChannelKeeper.GetAllChannels(s.Ctx()) s.Require().Len(channels, 3, "there should be 3 channels after restoring") // Confirm the new channel is in state INIT @@ -84,19 +140,67 @@ func (s *KeeperTestSuite) TestRestoreInterchainAccount_Success() { } } s.Require().True(newChannelActive, "a new channel should have been created") +} - // Verify the deposit record state was reverted - for i := 0; i < 2; i++ { - depositRecord, found := s.App.RecordsKeeper.GetDepositRecord(s.Ctx(), uint64(i)) +func (s *KeeperTestSuite) VerifyDepositRecordsStatus(expectedDepositRecords []DepositRecordStatusUpdate, revert bool) { + for i, expectedDepositRecord := range expectedDepositRecords { + actualDepositRecord, found := s.App.RecordsKeeper.GetDepositRecord(s.Ctx(), uint64(i)) s.Require().True(found, "deposit record found") - s.Require().Equal(recordtypes.DepositRecord_DELEGATION_QUEUE, depositRecord.Status, "deposit record status should be reverted") + + // Only revert records if the revert option is passed and the host zone matches + expectedStatus := expectedDepositRecord.initialStatus + if revert && actualDepositRecord.HostZoneId == HostChainId { + expectedStatus = expectedDepositRecord.revertedStatus + } + s.Require().Equal(expectedStatus.String(), actualDepositRecord.Status.String(), "deposit record %d status", i) } } +func (s *KeeperTestSuite) VerifyHostZoneUnbondingStatus(expectedUnbondingRecords []HostZoneUnbondingStatusUpdate, revert bool) { + for i, expectedUnbonding := range expectedUnbondingRecords { + epochUnbondingRecord, found := s.App.RecordsKeeper.GetEpochUnbondingRecord(s.Ctx(), uint64(i)) + s.Require().True(found, "epoch unbonding record found") + + for _, actualUnbonding := range epochUnbondingRecord.HostZoneUnbondings { + // Only revert records if the revert option is passed and the host zone matches + expectedStatus := expectedUnbonding.initialStatus + if revert && actualUnbonding.HostZoneId == HostChainId { + expectedStatus = expectedUnbonding.revertedStatus + } + s.Require().Equal(expectedStatus.String(), actualUnbonding.Status.String(), "host zone unbonding for epoch %d record status", i) + } + } +} + +func (s *KeeperTestSuite) TestRestoreInterchainAccount_Success() { + tc := s.SetupRestoreInterchainAccount() + owner := "GAIA.DELEGATION" + channelID := s.CreateICAChannel(owner) + portID := icatypes.PortPrefix + owner + + // Confirm there are two channels originally + channels := s.App.IBCKeeper.ChannelKeeper.GetAllChannels(s.Ctx()) + s.Require().Len(channels, 2, "there should be 2 channels initially (transfer + delegate)") + + // Close the delegation channel + channel, found := s.App.IBCKeeper.ChannelKeeper.GetChannel(s.Ctx(), portID, channelID) + s.Require().True(found, "delegation channel found") + channel.State = channeltypes.CLOSED + s.App.IBCKeeper.ChannelKeeper.SetChannel(s.Ctx(), portID, channelID, channel) + + // Confirm the new channel was created + s.RestoreChannelAndVerifySuccess(tc.validMsg, portID, channelID) + + // Verify the record status' were reverted + s.VerifyDepositRecordsStatus(tc.depositRecordStatusUpdates, true) + s.VerifyHostZoneUnbondingStatus(tc.unbondingRecordStatusUpdate, true) +} + func (s *KeeperTestSuite) TestRestoreInterchainAccount_CannotRestoreNonExistentAcct() { tc := s.SetupRestoreInterchainAccount() msg := tc.validMsg msg.AccountType = stakeibc.ICAAccountType_WITHDRAWAL + _, err := s.GetMsgServer().RestoreInterchainAccount(sdk.WrapSDKContext(s.Ctx()), &msg) expectedErrMSg := fmt.Sprintf("ICA controller account address not found: %s.WITHDRAWAL: invalid interchain account address", tc.validMsg.ChainId) @@ -107,6 +211,7 @@ func (s *KeeperTestSuite) TestRestoreInterchainAccount_FailsForIncorrectHostZone tc := s.SetupRestoreInterchainAccount() msg := tc.validMsg msg.ChainId = "incorrectchainid" + _, err := s.GetMsgServer().RestoreInterchainAccount(sdk.WrapSDKContext(s.Ctx()), &msg) expectedErrMsg := "host zone not registered" s.Require().EqualError(err, expectedErrMsg, "registered ica account fails for incorrect host zone") @@ -116,6 +221,7 @@ func (s *KeeperTestSuite) TestRestoreInterchainAccount_FailsIfAccountExists() { tc := s.SetupRestoreInterchainAccount() s.CreateICAChannel("GAIA.DELEGATION") msg := tc.validMsg + _, err := s.GetMsgServer().RestoreInterchainAccount(sdk.WrapSDKContext(s.Ctx()), &msg) expectedErrMsg := fmt.Sprintf("existing active channel channel-1 for portID icacontroller-%s.DELEGATION on connection %s for owner %s.DELEGATION: active channel already set for this owner", tc.validMsg.ChainId, @@ -129,6 +235,7 @@ func (s *KeeperTestSuite) TestRestoreInterchainAccount_RevertDepositRecords_Fail tc := s.SetupRestoreInterchainAccount() s.CreateICAChannel("GAIA.DELEGATION") msg := tc.validMsg + _, err := s.GetMsgServer().RestoreInterchainAccount(sdk.WrapSDKContext(s.Ctx()), &msg) expectedErrMsg := fmt.Sprintf("existing active channel channel-1 for portID icacontroller-%s.DELEGATION on connection %s for owner %s.DELEGATION: active channel already set for this owner", tc.validMsg.ChainId, @@ -136,16 +243,14 @@ func (s *KeeperTestSuite) TestRestoreInterchainAccount_RevertDepositRecords_Fail tc.validMsg.ChainId, ) s.Require().EqualError(err, expectedErrMsg, "registered ica account fails when account already exists") - // Verify the deposit record state was NOT reverted - for i := 0; i < 2; i++ { - depositRecord, found := s.App.RecordsKeeper.GetDepositRecord(s.Ctx(), uint64(i)) - s.Require().True(found, "deposit record found") - s.Require().Equal(recordtypes.DepositRecord_DELEGATION_IN_PROGRESS, depositRecord.Status, "deposit record status should NOT msg be reverted") - } + + // Verify the record status' were NOT reverted + s.VerifyDepositRecordsStatus(tc.depositRecordStatusUpdates, false) + s.VerifyHostZoneUnbondingStatus(tc.unbondingRecordStatusUpdate, false) } func (s *KeeperTestSuite) TestRestoreInterchainAccount_NoRecordChange_Success() { - // Here, we're closing and restoring the withdrawal channel so deposit records should not be reverted + // Here, we're closing and restoring the withdrawal channel so records should not be reverted tc := s.SetupRestoreInterchainAccount() owner := "GAIA.WITHDRAWAL" channelID := s.CreateICAChannel(owner) @@ -164,27 +269,9 @@ func (s *KeeperTestSuite) TestRestoreInterchainAccount_NoRecordChange_Success() // Restore the channel msg := tc.validMsg msg.AccountType = stakeibc.ICAAccountType_WITHDRAWAL - _, err := s.GetMsgServer().RestoreInterchainAccount(sdk.WrapSDKContext(s.Ctx()), &msg) - s.Require().NoError(err, "registered ica account successfully") - - // Confirm the new channel was created - channels = s.App.IBCKeeper.ChannelKeeper.GetAllChannels(s.Ctx()) - s.Require().Len(channels, 3, "there should be 3 channels after restoring") + s.RestoreChannelAndVerifySuccess(msg, portID, channelID) - // Confirm the new channel is in state INIT - newChannelActive := false - for _, channel := range channels { - // The new channel should have the same port, a new channel ID and be in state INIT - if channel.PortId == portID && channel.ChannelId != channelID && channel.State == channeltypes.INIT { - newChannelActive = true - } - } - s.Require().True(newChannelActive, "a new channel should have been created") - - // Verify the deposit record state was NOT reverted - for i := 0; i < 2; i++ { - depositRecord, found := s.App.RecordsKeeper.GetDepositRecord(s.Ctx(), uint64(i)) - s.Require().True(found, "deposit record found") - s.Require().Equal(recordtypes.DepositRecord_DELEGATION_IN_PROGRESS, depositRecord.Status, "deposit record status should NOT be reverted") - } + // Verify the record status' were NOT reverted + s.VerifyDepositRecordsStatus(tc.depositRecordStatusUpdates, false) + s.VerifyHostZoneUnbondingStatus(tc.unbondingRecordStatusUpdate, false) } diff --git a/x/stakeibc/keeper/unbonding_records.go b/x/stakeibc/keeper/unbonding_records.go index c91cb6a46..cbc69d312 100644 --- a/x/stakeibc/keeper/unbonding_records.go +++ b/x/stakeibc/keeper/unbonding_records.go @@ -230,7 +230,7 @@ func (k Keeper) InitiateAllHostZoneUnbondings(ctx sdk.Context, dayNumber uint64) failedUnbondings = append(failedUnbondings, hostZone.ChainId) continue } - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone, epochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS) + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone.ChainId, epochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS) if err != nil { k.Logger(ctx).Error(err.Error()) success = false @@ -359,7 +359,7 @@ func (k Keeper) SweepAllUnbondedTokensForHostZone(ctx sdk.Context, hostZone type if err != nil { k.Logger(ctx).Info(fmt.Sprintf("Failed to SubmitTxs, transfer to redemption account on %s", hostZone.ChainId)) } - err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone, epochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS) + err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone.ChainId, epochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS) if err != nil { k.Logger(ctx).Error(err.Error()) return false, 0