diff --git a/packages/relayer/event.go b/packages/relayer/event.go index cc2afc3558..2389c40de9 100644 --- a/packages/relayer/event.go +++ b/packages/relayer/event.go @@ -160,6 +160,7 @@ type EventRepository interface { ) (uint64, error) DeleteAllAfterBlockID(blockID uint64, srcChainID uint64, destChainID uint64) error FindLatestBlockID( + ctx context.Context, event string, srcChainID uint64, destChainID uint64, diff --git a/packages/relayer/indexer/indexer.go b/packages/relayer/indexer/indexer.go index 55a05dd5d2..d1d21ffca5 100644 --- a/packages/relayer/indexer/indexer.go +++ b/packages/relayer/indexer/indexer.go @@ -472,7 +472,7 @@ func (i *Indexer) indexMessageSentEvents(ctx context.Context, } func (i *Indexer) checkReorg(ctx context.Context, emittedInBlockNumber uint64) error { - n, err := i.eventRepo.FindLatestBlockID(i.eventName, i.srcChainId.Uint64(), i.destChainId.Uint64()) + n, err := i.eventRepo.FindLatestBlockID(ctx, i.eventName, i.srcChainId.Uint64(), i.destChainId.Uint64()) if err != nil { return err } diff --git a/packages/relayer/indexer/set_initial_Indexing_block_by_mode.go b/packages/relayer/indexer/set_initial_Indexing_block_by_mode.go index f03cdb0ea3..a556e7f3cd 100644 --- a/packages/relayer/indexer/set_initial_Indexing_block_by_mode.go +++ b/packages/relayer/indexer/set_initial_Indexing_block_by_mode.go @@ -28,6 +28,7 @@ func (i *Indexer) setInitialIndexingBlockByMode( case Sync: // get most recently processed block height from the DB latest, err := i.eventRepo.FindLatestBlockID( + i.ctx, i.eventName, chainID.Uint64(), i.destChainId.Uint64(), diff --git a/packages/relayer/pkg/http/get_block_info.go b/packages/relayer/pkg/http/get_block_info.go index 0ad1b1eab1..30eb629258 100644 --- a/packages/relayer/pkg/http/get_block_info.go +++ b/packages/relayer/pkg/http/get_block_info.go @@ -80,6 +80,7 @@ func (srv *Server) GetBlockInfo(c echo.Context) error { } latestProcessedSrcBlock, err := srv.eventRepo.FindLatestBlockID( + c.Request().Context(), relayer.EventNameMessageSent, srcChainID.Uint64(), destChainID.Uint64(), @@ -89,6 +90,7 @@ func (srv *Server) GetBlockInfo(c echo.Context) error { } latestProcessedDestBlock, err := srv.eventRepo.FindLatestBlockID( + c.Request().Context(), relayer.EventNameMessageSent, destChainID.Uint64(), srcChainID.Uint64(), diff --git a/packages/relayer/pkg/mock/event_repository.go b/packages/relayer/pkg/mock/event_repository.go index 0456ba7c2c..f3593bbd7e 100644 --- a/packages/relayer/pkg/mock/event_repository.go +++ b/packages/relayer/pkg/mock/event_repository.go @@ -210,6 +210,7 @@ func (r *EventRepository) DeleteAllAfterBlockID(blockID uint64, srcChainID uint6 // GetLatestBlockID get latest block id func (r *EventRepository) FindLatestBlockID( + ctx context.Context, event string, srcChainID uint64, destChainID uint64, diff --git a/packages/relayer/pkg/repo/event.go b/packages/relayer/pkg/repo/event.go index 3309c9982f..b07b175855 100644 --- a/packages/relayer/pkg/repo/event.go +++ b/packages/relayer/pkg/repo/event.go @@ -2,10 +2,8 @@ package repo import ( "context" - "strings" - "time" - "net/http" + "strings" "github.com/morkid/paginate" "github.com/pkg/errors" @@ -64,7 +62,7 @@ func (r *EventRepository) Save(ctx context.Context, opts *relayer.SaveEventOpts) EmittedBlockID: opts.EmittedBlockID, } - if err := r.db.GormDB().Create(e).Error; err != nil { + if err := r.db.GormDB().WithContext(ctx).Create(e).Error; err != nil { return nil, errors.Wrap(err, "r.db.Create") } @@ -76,36 +74,54 @@ func (r *EventRepository) UpdateFeesAndProfitability( id int, opts *relayer.UpdateFeesAndProfitabilityOpts, ) error { - e := &relayer.Event{} - if err := r.db.GormDB().Where("id = ?", id).First(e).Error; err != nil { - return errors.Wrap(err, "r.db.First") + tx := r.db.GormDB().WithContext(ctx) + tx = tx.Model(&relayer.Event{}) + tx = tx.Where("id = ?", id) + + // check if existed. + var count int64 + if err := tx.Count(&count).Error; err != nil { + return errors.Wrap(err, "r.db.Count") + } + + if count == 0 { + return gorm.ErrRecordNotFound } - e.Fee = &opts.Fee - e.DestChainBaseFee = &opts.DestChainBaseFee - e.GasTipCap = &opts.GasTipCap - e.GasLimit = &opts.GasLimit - e.IsProfitable = &opts.IsProfitable - e.EstimatedOnchainFee = &opts.EstimatedOnchainFee - currentTime := time.Now().UTC() - e.IsProfitableEvaluatedAt = ¤tTime - - if err := r.db.GormDB().Save(e).Error; err != nil { - return errors.Wrap(err, "r.db.Save") + err := tx.Updates(map[string]interface{}{ + "fee": opts.Fee, + "dest_chain_base_fee": opts.DestChainBaseFee, + "gas_tip_cap": opts.GasTipCap, + "gas_limit": opts.GasLimit, + "is_profitable": opts.IsProfitable, + "estimated_onchain_fee": opts.EstimatedOnchainFee, + "is_profitable_evaluated_at": opts.IsProfitableEvaluatedAt, + }).Error + + if err != nil { + return errors.Wrap(err, "r.db.Commit") } return nil } func (r *EventRepository) UpdateStatus(ctx context.Context, id int, status relayer.EventStatus) error { - e := &relayer.Event{} - if err := r.db.GormDB().Where("id = ?", id).First(e).Error; err != nil { - return errors.Wrap(err, "r.db.First") + tx := r.db.GormDB().WithContext(ctx) + tx = tx.Model(&relayer.Event{}) + tx = tx.Where("id = ?", id) + + // check if existed. + var count int64 + if err := tx.Count(&count).Error; err != nil { + return errors.Wrap(err, "r.db.Count") } - e.Status = status - if err := r.db.GormDB().Save(e).Error; err != nil { - return errors.Wrap(err, "r.db.Save") + if count == 0 { + return gorm.ErrRecordNotFound + } + + if err := tx.Update("status", status).Error; err != nil { + return errors.Wrap(err, "tx.Commit") } return nil @@ -117,7 +133,7 @@ func (r *EventRepository) FirstByMsgHash( ) (*relayer.Event, error) { e := &relayer.Event{} // find all message sent events - if err := r.db.GormDB().Where("msg_hash = ?", msgHash). + if err := r.db.GormDB().WithContext(ctx).Where("msg_hash = ?", msgHash). First(&e).Error; err != nil { if err == gorm.ErrRecordNotFound { return nil, nil @@ -136,7 +152,7 @@ func (r *EventRepository) FirstByEventAndMsgHash( ) (*relayer.Event, error) { e := &relayer.Event{} // find all message sent events - if err := r.db.GormDB().Where("msg_hash = ?", msgHash). + if err := r.db.GormDB().WithContext(ctx).Where("msg_hash = ?", msgHash). Where("event = ?", event). First(&e).Error; err != nil { if err == gorm.ErrRecordNotFound { @@ -158,7 +174,7 @@ func (r *EventRepository) FindAllByAddress( DefaultSize: 100, }) - q := r.db.GormDB(). + q := r.db.GormDB().WithContext(ctx). Model(&relayer.Event{}). Where( "dest_owner_json = ? OR message_owner = ?", @@ -196,7 +212,7 @@ func (r *EventRepository) Delete( ctx context.Context, id int, ) error { - return r.db.GormDB().Delete(relayer.Event{}, id).Error + return r.db.GormDB().WithContext(ctx).Delete(relayer.Event{}, id).Error } func (r *EventRepository) ChainDataSyncedEventByBlockNumberOrGreater( @@ -207,7 +223,7 @@ func (r *EventRepository) ChainDataSyncedEventByBlockNumberOrGreater( ) (*relayer.Event, error) { e := &relayer.Event{} // find all message sent events - if err := r.db.GormDB().Where("name = ?", relayer.EventNameChainDataSynced). + if err := r.db.GormDB().WithContext(ctx).Where("name = ?", relayer.EventNameChainDataSynced). Where("chain_id = ?", srcChainId). Where("synced_chain_id = ?", syncedChainId). Where("block_id >= ?", blockNumber). @@ -231,7 +247,7 @@ func (r *EventRepository) LatestChainDataSyncedEvent( ) (uint64, error) { blockID := 0 // find all message sent events - if err := r.db.GormDB().Table("events"). + if err := r.db.GormDB().WithContext(ctx).Table("events"). Where("chain_id = ?", srcChainId). Where("synced_chain_id = ?", syncedChainId). Select("COALESCE(MAX(block_id), 0)"). @@ -257,6 +273,7 @@ WHERE block_id >= ? AND chain_id = ? AND dest_chain_id = ?` // GetLatestBlockID get latest block id func (r *EventRepository) FindLatestBlockID( + ctx context.Context, event string, srcChainID uint64, destChainID uint64, @@ -266,7 +283,8 @@ func (r *EventRepository) FindLatestBlockID( var b uint64 - if err := r.db.GormDB().Table("events").Raw(q, srcChainID, destChainID, event).Scan(&b).Error; err != nil { + if err := r.db.GormDB().WithContext(ctx).Table("events"). + Raw(q, srcChainID, destChainID, event).Scan(&b).Error; err != nil { return 0, err }