Skip to content

Commit

Permalink
Revert HostZoneUnbonding Status upon Channel Restoration (#387)
Browse files Browse the repository at this point in the history
Co-authored-by: Aidan Salzmann <aidan@stridelabs.co>
Co-authored-by: vish-stride <104537253+vish-stride@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 28, 2022
1 parent e74c34d commit 730cf3d
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 103 deletions.
11 changes: 5 additions & 6 deletions x/records/keeper/epoch_unbonding_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,21 @@ func (k Keeper) AddHostZoneToEpochUnbondingRecord(ctx sdk.Context, epochNumber u
return &epochUnbondingRecord, true
}

// TODO: unittest
func (k Keeper) SetHostZoneUnbondings(ctx sdk.Context, zone stakeibctypes.HostZone, epochUnbondingRecordIds []uint64, status types.HostZoneUnbonding_Status) error {
func (k Keeper) SetHostZoneUnbondings(ctx sdk.Context, chainId string, epochUnbondingRecordIds []uint64, status types.HostZoneUnbonding_Status) error {
for _, epochUnbondingRecordId := range epochUnbondingRecordIds {
k.Logger(ctx).Info(fmt.Sprintf("Updating host zone unbondings on EpochUnbondingRecord %d to status %s", epochUnbondingRecordId, status.String()))
// fetch the host zone unbonding
hostZoneUnbonding, found := k.GetHostZoneUnbondingByChainId(ctx, epochUnbondingRecordId, zone.ChainId)
hostZoneUnbonding, found := k.GetHostZoneUnbondingByChainId(ctx, epochUnbondingRecordId, chainId)
if !found {
errMsg := fmt.Sprintf("Error fetching host zone unbonding record for epoch: %d, host zone: %s", epochUnbondingRecordId, zone.ChainId)
errMsg := fmt.Sprintf("Error fetching host zone unbonding record for epoch: %d, host zone: %s", epochUnbondingRecordId, chainId)
k.Logger(ctx).Error(errMsg)
return sdkerrors.Wrapf(stakeibctypes.ErrHostZoneNotFound, errMsg)
}
hostZoneUnbonding.Status = status
// save the updated hzu on the epoch unbonding record
updatedRecord, success := k.AddHostZoneToEpochUnbondingRecord(ctx, epochUnbondingRecordId, zone.ChainId, hostZoneUnbonding)
updatedRecord, success := k.AddHostZoneToEpochUnbondingRecord(ctx, epochUnbondingRecordId, chainId, hostZoneUnbonding)
if !success {
errMsg := fmt.Sprintf("Error adding host zone unbonding record to epoch unbonding record: %d, host zone: %s", epochUnbondingRecordId, zone.ChainId)
errMsg := fmt.Sprintf("Error adding host zone unbonding record to epoch unbonding record: %d, host zone: %s", epochUnbondingRecordId, chainId)
k.Logger(ctx).Error(errMsg)
return sdkerrors.Wrap(types.ErrAddingHostZone, errMsg)
}
Expand Down
120 changes: 109 additions & 11 deletions x/records/keeper/epoch_unbonding_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,45 @@ import (
"github.com/Stride-Labs/stride/v3/x/records/types"
)

func createNEpochUnbondingRecord(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.EpochUnbondingRecord {
items := make([]types.EpochUnbondingRecord, n)
for i, item := range items {
item.EpochNumber = uint64(i)
items[i] = item
keeper.SetEpochUnbondingRecord(ctx, item)
func createNEpochUnbondingRecord(keeper *keeper.Keeper, ctx sdk.Context, n int) ([]types.EpochUnbondingRecord, map[string]types.HostZoneUnbonding) {
hostZoneUnbondingsList := []types.HostZoneUnbonding{
{
HostZoneId: "host-A",
Status: types.HostZoneUnbonding_UNBONDING_QUEUE,
},
{
HostZoneId: "host-B",
Status: types.HostZoneUnbonding_UNBONDING_QUEUE,
},
{
HostZoneId: "host-C",
Status: types.HostZoneUnbonding_UNBONDING_QUEUE,
},
}
return items
hostZoneUnbondingsMap := make(map[string]types.HostZoneUnbonding)
for _, hostZoneUnbonding := range hostZoneUnbondingsList {
hostZoneUnbondingsMap[hostZoneUnbonding.HostZoneId] = hostZoneUnbonding
}

epochUnbondingRecords := make([]types.EpochUnbondingRecord, n)
for epochNumber, epochUnbondingRecord := range epochUnbondingRecords {
epochUnbondingRecord.EpochNumber = uint64(epochNumber)

unbondingsCopy := make([]*types.HostZoneUnbonding, 3)
for i := range unbondingsCopy {
hostZoneUnbonding := hostZoneUnbondingsList[i]
epochUnbondingRecord.HostZoneUnbondings = append(epochUnbondingRecord.HostZoneUnbondings, &hostZoneUnbonding)
}

epochUnbondingRecords[epochNumber] = epochUnbondingRecord
keeper.SetEpochUnbondingRecord(ctx, epochUnbondingRecord)
}
return epochUnbondingRecords, hostZoneUnbondingsMap
}

func TestEpochUnbondingRecordGet(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
items := createNEpochUnbondingRecord(keeper, ctx, 10)
items, _ := createNEpochUnbondingRecord(keeper, ctx, 10)
for _, item := range items {
got, found := keeper.GetEpochUnbondingRecord(ctx, item.EpochNumber)
require.True(t, found)
Expand All @@ -37,7 +63,7 @@ func TestEpochUnbondingRecordGet(t *testing.T) {

func TestEpochUnbondingRecordRemove(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
items := createNEpochUnbondingRecord(keeper, ctx, 10)
items, _ := createNEpochUnbondingRecord(keeper, ctx, 10)
for _, item := range items {
keeper.RemoveEpochUnbondingRecord(ctx, item.EpochNumber)
_, found := keeper.GetEpochUnbondingRecord(ctx, item.EpochNumber)
Expand All @@ -47,7 +73,7 @@ func TestEpochUnbondingRecordRemove(t *testing.T) {

func TestEpochUnbondingRecordGetAll(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
items := createNEpochUnbondingRecord(keeper, ctx, 10)
items, _ := createNEpochUnbondingRecord(keeper, ctx, 10)
require.ElementsMatch(t,
nullify.Fill(items),
nullify.Fill(keeper.GetAllEpochUnbondingRecord(ctx)),
Expand All @@ -56,11 +82,83 @@ func TestEpochUnbondingRecordGetAll(t *testing.T) {

func TestGetAllPreviousEpochUnbondingRecords(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
items := createNEpochUnbondingRecord(keeper, ctx, 10)
items, _ := createNEpochUnbondingRecord(keeper, ctx, 10)
currentEpoch := uint64(8)
fetchedItems := items[:currentEpoch]
require.ElementsMatch(t,
nullify.Fill(fetchedItems),
nullify.Fill(keeper.GetAllPreviousEpochUnbondingRecords(ctx, currentEpoch)),
)
}

func TestGetHostZoneUnbondingByChainId(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
_, hostZoneUnbondings := createNEpochUnbondingRecord(keeper, ctx, 10)

expectedHostZoneUnbonding := hostZoneUnbondings["host-B"]
actualHostZoneUnbonding, found := keeper.GetHostZoneUnbondingByChainId(ctx, 1, "host-B")

require.True(t, found)
require.Equal(t,
*actualHostZoneUnbonding,
expectedHostZoneUnbonding,
)
}

func TestAddHostZoneToEpochUnbondingRecord(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
epochUnbondingRecords, _ := createNEpochUnbondingRecord(keeper, ctx, 3)

epochNumber := 0
initialEpochUnbondingRecord := epochUnbondingRecords[epochNumber]

// Add new host zone to initial epoch unbonding records
newHostZone := types.HostZoneUnbonding{
HostZoneId: "host-D",
Status: types.HostZoneUnbonding_UNBONDING_QUEUE,
}
expectedEpochUnbondingRecord := initialEpochUnbondingRecord
expectedEpochUnbondingRecord.HostZoneUnbondings = append(expectedEpochUnbondingRecord.HostZoneUnbondings, &newHostZone)

actualEpochUnbondingRecord, success := keeper.AddHostZoneToEpochUnbondingRecord(ctx, uint64(epochNumber), "host-D", &newHostZone)

require.True(t, success)
require.Equal(t,
expectedEpochUnbondingRecord,
*actualEpochUnbondingRecord,
)
}

func TestSetHostZoneUnbondings(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)

initialEpochUnbondingRecords, _ := createNEpochUnbondingRecord(keeper, ctx, 4)

epochsToUpdate := []uint64{1, 3}
hostIdToUpdate := "host-B"
newStatus := types.HostZoneUnbonding_UNBONDING_IN_PROGRESS

expectedEpochUnbondingRecords := initialEpochUnbondingRecords
for _, epochUnbondingRecord := range expectedEpochUnbondingRecords {
for _, epochNumberToUpdate := range epochsToUpdate {
if epochUnbondingRecord.EpochNumber == epochNumberToUpdate {
for i, hostUnbonding := range epochUnbondingRecord.HostZoneUnbondings {
if hostUnbonding.HostZoneId == hostIdToUpdate {
updatedHostZoneUnbonding := hostUnbonding
updatedHostZoneUnbonding.Status = newStatus
epochUnbondingRecord.HostZoneUnbondings[i] = updatedHostZoneUnbonding
}
}
}
}
}

err := keeper.SetHostZoneUnbondings(ctx, hostIdToUpdate, epochsToUpdate, newStatus)
require.Nil(t, err)

actualEpochUnbondingRecord := keeper.GetAllEpochUnbondingRecord(ctx)
require.ElementsMatch(t,
expectedEpochUnbondingRecords,
actualEpochUnbondingRecord,
)
}
4 changes: 2 additions & 2 deletions x/records/keeper/grpc_query_epoch_unbonding_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
func TestEpochUnbondingRecordQuerySingle(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
wctx := sdk.WrapSDKContext(ctx)
msgs := createNEpochUnbondingRecord(keeper, ctx, 2)
msgs, _ := createNEpochUnbondingRecord(keeper, ctx, 2)
for _, tc := range []struct {
desc string
request *types.QueryGetEpochUnbondingRecordRequest
Expand Down Expand Up @@ -63,7 +63,7 @@ func TestEpochUnbondingRecordQuerySingle(t *testing.T) {
func TestEpochUnbondingRecordQueryPaginated(t *testing.T) {
keeper, ctx := keepertest.RecordsKeeper(t)
wctx := sdk.WrapSDKContext(ctx)
msgs := createNEpochUnbondingRecord(keeper, ctx, 5)
msgs, _ := createNEpochUnbondingRecord(keeper, ctx, 5)

request := func(next []byte, offset, limit uint64, total bool) *types.QueryAllEpochUnbondingRecordRequest {
return &types.QueryAllEpochUnbondingRecordRequest{
Expand Down
8 changes: 2 additions & 6 deletions x/stakeibc/keeper/icacallbacks_redemption.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ func RedemptionCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a
if ack == nil {
// handle timeout
k.Logger(ctx).Error(fmt.Sprintf("RedemptionCallback timeout, ack is nil, packet %v", packet))
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE)
if err != nil {
return err
}
return nil
}

Expand All @@ -70,14 +66,14 @@ func RedemptionCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a
if len(txMsgData.Data) == 0 {
// handle tx failure
k.Logger(ctx).Error(fmt.Sprintf("RedemptionCallback tx failed, txMsgData is empty, ack error, packet %v", packet))
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE)
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE)
if err != nil {
return err
}
return nil
}

err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_CLAIMABLE)
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, redemptionCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_CLAIMABLE)
if err != nil {
return err
}
Expand Down
9 changes: 2 additions & 7 deletions x/stakeibc/keeper/icacallbacks_undelegate.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ func UndelegateCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a
// handle transaction failure cases
if ack == nil {
// handle timeout
// reset to UNBONDING_QUEUE
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_QUEUE)
if err != nil {
return err
}
k.Logger(ctx).Error(fmt.Sprintf("UndelegateCallback timeout, txMsgData is nil, packet %v", packet))
return nil
}
Expand All @@ -73,7 +68,7 @@ func UndelegateCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a
if len(txMsgData.Data) == 0 {
// handle tx failure
// reset to UNBONDING_QUEUE
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_QUEUE)
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_UNBONDING_QUEUE)
if err != nil {
return err
}
Expand Down Expand Up @@ -103,7 +98,7 @@ func UndelegateCallback(k Keeper, ctx sdk.Context, packet channeltypes.Packet, a
return err
}
// upon success, add host zone unbondings to the exit transfer queue
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE)
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, zone.ChainId, undelegateCallback.EpochUnbondingRecordIds, recordstypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE)
if err != nil {
return err
}
Expand Down
56 changes: 50 additions & 6 deletions x/stakeibc/keeper/msg_server_restore_interchain_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,59 @@ func (k msgServer) RestoreInterchainAccount(goCtx context.Context, msg *types.Ms

// If we're restoring a delegation account, we also have to reset record state
if msg.AccountType == types.ICAAccountType_DELEGATION {
depositRecords := k.RecordsKeeper.GetAllDepositRecord(ctx)
// revert DELEGATION_IN_PROGRESS records for the closed ICA channel (so that they can be staked)
for _, record := range depositRecords {
depositRecords := k.RecordsKeeper.GetAllDepositRecord(ctx)
for _, depositRecord := range depositRecords {
// only revert records for the select host zone
if depositRecord.HostZoneId == hostZone.ChainId && depositRecord.Status == recordtypes.DepositRecord_DELEGATION_IN_PROGRESS {
depositRecord.Status = recordtypes.DepositRecord_DELEGATION_QUEUE
k.Logger(ctx).Info(fmt.Sprintf("Setting DepositRecord %d to status DepositRecord_DELEGATION_IN_PROGRESS", depositRecord.Id))
k.RecordsKeeper.SetDepositRecord(ctx, depositRecord)
}
}

// revert epoch unbonding records for the closed ICA channel
epochUnbondingRecords := k.RecordsKeeper.GetAllEpochUnbondingRecord(ctx)
epochNumberForPendingUnbondingRecords := []uint64{}
epochNumberForPendingTransferRecords := []uint64{}
for _, epochUnbondingRecord := range epochUnbondingRecords {
// only revert records for the select host zone
if record.HostZoneId == hostZone.ChainId && record.Status == recordtypes.DepositRecord_DELEGATION_IN_PROGRESS {
record.Status = recordtypes.DepositRecord_DELEGATION_QUEUE
k.Logger(ctx).Error(fmt.Sprintf("Setting DepositRecord %d to status DepositRecord_DELEGATION_IN_PROGRESS", record.Id))
k.RecordsKeeper.SetDepositRecord(ctx, record)
hostZoneUnbonding, found := k.RecordsKeeper.GetHostZoneUnbondingByChainId(ctx, epochUnbondingRecord.EpochNumber, hostZone.ChainId)
if !found {
k.Logger(ctx).Info(fmt.Sprintf("No HostZoneUnbonding found for chainId: %s, epoch: %d", hostZone.ChainId, epochUnbondingRecord.EpochNumber))
continue
}

// Revert UNBONDING_IN_PROGRESS and EXIT_TRANSFER_IN_PROGRESS records
if hostZoneUnbonding.Status == recordtypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS {
k.Logger(ctx).Info(fmt.Sprintf("HostZoneUnbonding for %s at EpochNumber %d is stuck in status %s",
hostZone.ChainId, epochUnbondingRecord.EpochNumber, recordtypes.HostZoneUnbonding_UNBONDING_IN_PROGRESS.String(),
))
epochNumberForPendingUnbondingRecords = append(epochNumberForPendingUnbondingRecords, epochUnbondingRecord.EpochNumber)

} else if hostZoneUnbonding.Status == recordtypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS {
k.Logger(ctx).Info(fmt.Sprintf("HostZoneUnbonding for %s at EpochNumber %d to in status %s",
hostZone.ChainId, epochUnbondingRecord.EpochNumber, recordtypes.HostZoneUnbonding_EXIT_TRANSFER_IN_PROGRESS.String(),
))
epochNumberForPendingTransferRecords = append(epochNumberForPendingTransferRecords, epochUnbondingRecord.EpochNumber)
}
}
// Revert UNBONDING_IN_PROGRESS records to UNBONDING_QUEUE
err := k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone.ChainId, epochNumberForPendingUnbondingRecords, recordtypes.HostZoneUnbonding_UNBONDING_QUEUE)
if err != nil {
errMsg := fmt.Sprintf("unable to update host zone unbonding record status to %s for chainId: %s and epochUnbondingRecordIds: %v, err: %s",
recordtypes.HostZoneUnbonding_UNBONDING_QUEUE.String(), hostZone.ChainId, epochNumberForPendingUnbondingRecords, err)
k.Logger(ctx).Error(errMsg)
return nil, err
}

// Revert EXIT_TRANSFER_IN_PROGRESS records to EXIT_TRANSFER_QUEUE
err = k.RecordsKeeper.SetHostZoneUnbondings(ctx, hostZone.ChainId, epochNumberForPendingTransferRecords, recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE)
if err != nil {
errMsg := fmt.Sprintf("unable to update host zone unbonding record status to %s for chainId: %s and epochUnbondingRecordIds: %v, err: %s",
recordtypes.HostZoneUnbonding_EXIT_TRANSFER_QUEUE.String(), hostZone.ChainId, epochNumberForPendingTransferRecords, err)
k.Logger(ctx).Error(errMsg)
return nil, err
}
}

Expand Down
Loading

0 comments on commit 730cf3d

Please sign in to comment.