From 150f7443fdd338f0ff29f81027ac7171673fd2d7 Mon Sep 17 00:00:00 2001 From: Juan Farber Date: Thu, 19 Dec 2024 14:34:44 -0300 Subject: [PATCH 1/4] [NONEVM-706][Solana] - Refactor TXM + Rebroadcast Expired Tx functionality (#946) * refactor so txm owns blockhash assignment * lastValidBlockHeight shouldn't be exported * better comment * refactor sendWithRetry to make it clearer * confirm loop refactor * fix infinite loop * move accountID inside msg * lint fix * base58 does not contain lower l * fix hash errors * fix generate random hash * remove blockhash as we only need block height * expired tx changes without tests * add maybe to mocks * expiration tests * send txes through queue * revert pendingtx leakage of information. overwrite blockhash * fix order of confirm loop and not found signature check * fix mocks * prevent confirmation loop to mark tx as errored when it needs to be rebroadcasted * fix test * fix pointer * add comments * reduce rpc calls + refactors * tests + check to save rpc calls * address feedback + remove redundant impl * iface comment * address feedback on compute unit limit and lastValidBlockHeight assignment * blockhash assignment inside txm.sendWithRetry * address feedback * Merge branch 'develop' into nonevm-706-support-custom-bumping-strategy-rpc-expiration-within-confirmation * refactors after merge * fix interactive rebase * fix whitespace diffs * fix import * fix mocks * add on prebroadcaste error * remove rebroadcast count and fix package * improve docs * fix comparison against blockHeight instead of slotHeight * address feedback * fix lint * fix log * address feedback * remove useless slot height * address feedback * validate that tx doesn't exist in any of maps when adding new tx * callers set lastValidBlockheight + get blockhash on expiration + integration tests * add enq iface comm to help callers * address feedback --- docs/relay/README.md | 5 +- pkg/solana/chain.go | 2 +- pkg/solana/chain_test.go | 13 +- pkg/solana/config/config.go | 28 +- pkg/solana/config/mocks/config.go | 45 ++ pkg/solana/config/toml.go | 7 + pkg/solana/relay.go | 10 +- pkg/solana/transmitter.go | 2 +- pkg/solana/transmitter_test.go | 2 +- pkg/solana/txm/pendingtx.go | 147 +++--- pkg/solana/txm/pendingtx_test.go | 291 +++++++++-- pkg/solana/txm/txm.go | 657 ++++++++++++++----------- pkg/solana/txm/txm_integration_test.go | 187 +++++++ pkg/solana/txm/txm_internal_test.go | 483 ++++++++++++++++-- pkg/solana/txm/txm_load_test.go | 36 +- pkg/solana/txm/txm_race_test.go | 17 +- 16 files changed, 1459 insertions(+), 473 deletions(-) create mode 100644 pkg/solana/txm/txm_integration_test.go diff --git a/docs/relay/README.md b/docs/relay/README.md index f1dbffe81..07476babb 100644 --- a/docs/relay/README.md +++ b/docs/relay/README.md @@ -43,7 +43,8 @@ chainlink nodes solana create --name= --chain-id= --url= confirmationTimeout } if tx, exists := c.confirmedTxs[id]; exists { @@ -246,7 +268,7 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { return ErrSigDoesNotExist } // Transactions should only move to processed from broadcasted - tx, exists := c.broadcastedTxs[id] + tx, exists := c.broadcastedProcessedTxs[id] if !exists { return ErrTransactionNotFound } @@ -266,14 +288,14 @@ func (c *pendingTxContext) OnProcessed(sig solana.Signature) (string, error) { if !sigExists { return id, ErrSigDoesNotExist } - tx, exists := c.broadcastedTxs[id] + tx, exists := c.broadcastedProcessedTxs[id] if !exists { return id, ErrTransactionNotFound } // update tx state to Processed tx.state = Processed // save updated tx back to the broadcasted map - c.broadcastedTxs[id] = tx + c.broadcastedProcessedTxs[id] = tx return id, nil }) } @@ -290,7 +312,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { return ErrAlreadyInExpectedState } // Transactions should only move to confirmed from broadcasted/processed - if _, exists := c.broadcastedTxs[id]; !exists { + if _, exists := c.broadcastedProcessedTxs[id]; !exists { return ErrTransactionNotFound } return nil @@ -305,7 +327,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { if !sigExists { return id, ErrSigDoesNotExist } - tx, exists := c.broadcastedTxs[id] + tx, exists := c.broadcastedProcessedTxs[id] if !exists { return id, ErrTransactionNotFound } @@ -319,7 +341,7 @@ func (c *pendingTxContext) OnConfirmed(sig solana.Signature) (string, error) { // move tx to confirmed map c.confirmedTxs[id] = tx // remove tx from broadcasted map - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) return id, nil }) } @@ -331,7 +353,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti return ErrSigDoesNotExist } // Allow transactions to transition from broadcasted, processed, or confirmed state in case there are delays between status checks - _, broadcastedExists := c.broadcastedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[id] _, confirmedExists := c.confirmedTxs[id] if !broadcastedExists && !confirmedExists { return ErrTransactionNotFound @@ -350,7 +372,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti } var tx, tempTx pendingTx var broadcastedExists, confirmedExists bool - if tempTx, broadcastedExists = c.broadcastedTxs[id]; broadcastedExists { + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { tx = tempTx } if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { @@ -366,7 +388,7 @@ func (c *pendingTxContext) OnFinalized(sig solana.Signature, retentionTimeout ti delete(c.cancelBy, id) } // delete from broadcasted map, if exists - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) // delete from confirmed map, if exists delete(c.confirmedTxs, id) // remove all related signatures from the sigToID map to skip picking up this tx in the confirmation logic @@ -397,7 +419,7 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. if tx, exists := c.finalizedErroredTxs[id]; exists && tx.state == txState { return ErrAlreadyInExpectedState } - _, broadcastedExists := c.broadcastedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[id] _, confirmedExists := c.confirmedTxs[id] if broadcastedExists || confirmedExists { return ErrIDAlreadyExists @@ -410,10 +432,11 @@ func (c *pendingTxContext) OnPrebroadcastError(id string, retentionTimeout time. // upgrade to write lock if id does not exist in other maps and is not in expected state already _, err = c.withWriteLock(func() (string, error) { - if tx, exists := c.finalizedErroredTxs[id]; exists && tx.state == txState { + tx, exists := c.finalizedErroredTxs[id] + if exists && tx.state == txState { return "", ErrAlreadyInExpectedState } - _, broadcastedExists := c.broadcastedTxs[id] + _, broadcastedExists := c.broadcastedProcessedTxs[id] _, confirmedExists := c.confirmedTxs[id] if broadcastedExists || confirmedExists { return "", ErrIDAlreadyExists @@ -437,7 +460,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D } // transaction can transition from any non-finalized state var broadcastedExists, confirmedExists bool - _, broadcastedExists = c.broadcastedTxs[id] + _, broadcastedExists = c.broadcastedProcessedTxs[id] _, confirmedExists = c.confirmedTxs[id] // transcation does not exist in any tx maps if !broadcastedExists && !confirmedExists { @@ -457,7 +480,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D } var tx, tempTx pendingTx var broadcastedExists, confirmedExists bool - if tempTx, broadcastedExists = c.broadcastedTxs[id]; broadcastedExists { + if tempTx, broadcastedExists = c.broadcastedProcessedTxs[id]; broadcastedExists { tx = tempTx } if tempTx, confirmedExists = c.confirmedTxs[id]; confirmedExists { @@ -473,7 +496,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D delete(c.cancelBy, id) } // delete from broadcasted map, if exists - delete(c.broadcastedTxs, id) + delete(c.broadcastedProcessedTxs, id) // delete from confirmed map, if exists delete(c.confirmedTxs, id) // remove all related signatures from the sigToID map to skip picking up this tx in the confirmation logic @@ -497,7 +520,7 @@ func (c *pendingTxContext) OnError(sig solana.Signature, retentionTimeout time.D func (c *pendingTxContext) GetTxState(id string) (TxState, error) { c.lock.RLock() defer c.lock.RUnlock() - if tx, exists := c.broadcastedTxs[id]; exists { + if tx, exists := c.broadcastedProcessedTxs[id]; exists { return tx.state, nil } if tx, exists := c.confirmedTxs[id]; exists { @@ -594,16 +617,20 @@ func (c *pendingTxContextWithProm) OnConfirmed(sig solana.Signature) (string, er return id, err } -func (c *pendingTxContextWithProm) Remove(sig solana.Signature) (string, error) { - return c.pendingTx.Remove(sig) +func (c *pendingTxContextWithProm) Remove(id string) (string, error) { + return c.pendingTx.Remove(id) } -func (c *pendingTxContextWithProm) ListAll() []solana.Signature { - sigs := c.pendingTx.ListAll() +func (c *pendingTxContextWithProm) ListAllSigs() []solana.Signature { + sigs := c.pendingTx.ListAllSigs() promSolTxmPendingTxs.WithLabelValues(c.chainID).Set(float64(len(sigs))) return sigs } +func (c *pendingTxContextWithProm) ListAllExpiredBroadcastedTxs(currBlockNumber uint64) []pendingTx { + return c.pendingTx.ListAllExpiredBroadcastedTxs(currBlockNumber) +} + func (c *pendingTxContextWithProm) Expired(sig solana.Signature, lifespan time.Duration) bool { return c.pendingTx.Expired(sig, lifespan) } diff --git a/pkg/solana/txm/pendingtx_test.go b/pkg/solana/txm/pendingtx_test.go index e7b7fc51e..a79f9f7aa 100644 --- a/pkg/solana/txm/pendingtx_test.go +++ b/pkg/solana/txm/pendingtx_test.go @@ -48,19 +48,21 @@ func TestPendingTxContext_add_remove_multiple(t *testing.T) { // cannot add signature for non existent ID require.Error(t, txs.AddSignature(uuid.New().String(), solana.Signature{})) - // return list of signatures - list := txs.ListAll() + list := make([]string, 0, n) + for _, id := range txs.sigToID { + list = append(list, id) + } assert.Equal(t, n, len(list)) // stop all sub processes for i := 0; i < len(list); i++ { - id, err := txs.Remove(list[i]) + txID := list[i] + _, err := txs.Remove(txID) assert.NoError(t, err) - assert.Equal(t, n-i-1, len(txs.ListAll())) - assert.Equal(t, ids[list[i]], id) + assert.Equal(t, n-i-1, len(txs.ListAllSigs())) // second remove should not return valid id - already removed - id, err = txs.Remove(list[i]) + id, err := txs.Remove(txID) require.Error(t, err) assert.Equal(t, "", id) } @@ -76,29 +78,55 @@ func TestPendingTxContext_new(t *testing.T) { // Create new transaction msg := pendingTx{id: uuid.NewString()} err := txs.New(msg, sig, cancel) - require.NoError(t, err) + require.NoError(t, err, "expected no error when adding a new transaction") - // Check it exists in signature map + // Check it exists in signature map and mapped to the correct txID id, exists := txs.sigToID[sig] - require.True(t, exists) - require.Equal(t, msg.id, id) + require.True(t, exists, "signature should exist in sigToID map") + require.Equal(t, msg.id, id, "signature should map to correct transaction ID") - // Check it exists in broadcasted map - tx, exists := txs.broadcastedTxs[msg.id] - require.True(t, exists) - require.Len(t, tx.signatures, 1) - require.Equal(t, sig, tx.signatures[0]) + // Check it exists in broadcasted map and that sigs match + tx, exists := txs.broadcastedProcessedTxs[msg.id] + require.True(t, exists, "transaction should exist in broadcastedProcessedTxs map") + require.Len(t, tx.signatures, 1, "transaction should have one signature") + require.Equal(t, sig, tx.signatures[0], "signature should match") // Check status is Broadcasted - require.Equal(t, Broadcasted, tx.state) + require.Equal(t, Broadcasted, tx.state, "transaction state should be Broadcasted") - // Check it does not exist in confirmed map + // Check it does not exist in confirmed nor finalized maps _, exists = txs.confirmedTxs[msg.id] - require.False(t, exists) - - // Check it does not exist in finalized map + require.False(t, exists, "transaction should not exist in confirmedTxs map") _, exists = txs.finalizedErroredTxs[msg.id] - require.False(t, exists) + require.False(t, exists, "transaction should not exist in finalizedErroredTxs map") + + // Attempt to add the same transaction again with the same signature + err = txs.New(msg, sig, cancel) + require.ErrorIs(t, err, ErrSigAlreadyExists, "expected ErrSigAlreadyExists when adding duplicate signature") + + // Attempt to add a new transaction with the same transaction ID but different signature + err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding duplicate transaction ID") + + // Attempt to add a new transaction with a different transaction ID but same signature + err = txs.New(pendingTx{id: uuid.NewString()}, sig, cancel) + require.ErrorIs(t, err, ErrSigAlreadyExists, "expected ErrSigAlreadyExists when adding duplicate signature") + + // Simulate moving the transaction to confirmedTxs map + _, err = txs.OnConfirmed(sig) + require.NoError(t, err, "expected no error when confirming transaction") + + // Attempt to add a new transaction with the same ID (now in confirmedTxs) and new signature + err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding transaction ID that exists in confirmedTxs") + + // Simulate moving the transaction to finalizedErroredTxs map + _, err = txs.OnFinalized(sig, 10*time.Second) + require.NoError(t, err, "expected no error when finalizing transaction") + + // Attempt to add a new transaction with the same ID (now in finalizedErroredTxs) and new signature + err = txs.New(pendingTx{id: msg.id}, randomSignature(t), cancel) + require.ErrorIs(t, err, ErrIDAlreadyExists, "expected ErrIDAlreadyExists when adding transaction ID that exists in finalizedErroredTxs") } func TestPendingTxContext_add_signature(t *testing.T) { @@ -127,7 +155,7 @@ func TestPendingTxContext_add_signature(t *testing.T) { require.Equal(t, msg.id, id) // Check broadcasted map - tx, exists := txs.broadcastedTxs[msg.id] + tx, exists := txs.broadcastedProcessedTxs[msg.id] require.True(t, exists) require.Len(t, tx.signatures, 2) require.Equal(t, sig1, tx.signatures[0]) @@ -216,7 +244,7 @@ func TestPendingTxContext_on_broadcasted_processed(t *testing.T) { require.Equal(t, msg.id, id) // Check it exists in broadcasted map - tx, exists := txs.broadcastedTxs[msg.id] + tx, exists := txs.broadcastedProcessedTxs[msg.id] require.True(t, exists) require.Len(t, tx.signatures, 1) require.Equal(t, sig, tx.signatures[0]) @@ -351,7 +379,7 @@ func TestPendingTxContext_on_confirmed(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists = txs.broadcastedTxs[msg.id] + _, exists = txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it exists in confirmed map @@ -463,7 +491,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -513,7 +541,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -558,7 +586,7 @@ func TestPendingTxContext_on_finalized(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -613,7 +641,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -651,7 +679,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -684,7 +712,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it exists in errored map @@ -718,7 +746,7 @@ func TestPendingTxContext_on_error(t *testing.T) { require.Equal(t, msg.id, id) // Check it does not exist in broadcasted map - _, exists := txs.broadcastedTxs[msg.id] + _, exists := txs.broadcastedProcessedTxs[msg.id] require.False(t, exists) // Check it does not exist in confirmed map @@ -825,22 +853,27 @@ func TestPendingTxContext_remove(t *testing.T) { txs := newPendingTxContext() retentionTimeout := 5 * time.Second + broadcastedID := uuid.NewString() broadcastedSig1 := randomSignature(t) broadcastedSig2 := randomSignature(t) + processedID := uuid.NewString() processedSig := randomSignature(t) + confirmedID := uuid.NewString() confirmedSig := randomSignature(t) + finalizedID := uuid.NewString() finalizedSig := randomSignature(t) + erroredID := uuid.NewString() erroredSig := randomSignature(t) // Create new broadcasted transaction with extra sig - broadcastedMsg := pendingTx{id: uuid.NewString()} + broadcastedMsg := pendingTx{id: broadcastedID} err := txs.New(broadcastedMsg, broadcastedSig1, cancel) require.NoError(t, err) err = txs.AddSignature(broadcastedMsg.id, broadcastedSig2) require.NoError(t, err) // Create new processed transaction - processedMsg := pendingTx{id: uuid.NewString()} + processedMsg := pendingTx{id: processedID} err = txs.New(processedMsg, processedSig, cancel) require.NoError(t, err) id, err := txs.OnProcessed(processedSig) @@ -848,7 +881,7 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, processedMsg.id, id) // Create new confirmed transaction - confirmedMsg := pendingTx{id: uuid.NewString()} + confirmedMsg := pendingTx{id: confirmedID} err = txs.New(confirmedMsg, confirmedSig, cancel) require.NoError(t, err) id, err = txs.OnConfirmed(confirmedSig) @@ -856,7 +889,7 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, confirmedMsg.id, id) // Create new finalized transaction - finalizedMsg := pendingTx{id: uuid.NewString()} + finalizedMsg := pendingTx{id: finalizedID} err = txs.New(finalizedMsg, finalizedSig, cancel) require.NoError(t, err) id, err = txs.OnFinalized(finalizedSig, retentionTimeout) @@ -864,7 +897,7 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, finalizedMsg.id, id) // Create new errored transaction - erroredMsg := pendingTx{id: uuid.NewString()} + erroredMsg := pendingTx{id: erroredID} err = txs.New(erroredMsg, erroredSig, cancel) require.NoError(t, err) id, err = txs.OnError(erroredSig, retentionTimeout, Errored, 0) @@ -872,11 +905,11 @@ func TestPendingTxContext_remove(t *testing.T) { require.Equal(t, erroredMsg.id, id) // Remove broadcasted transaction - id, err = txs.Remove(broadcastedSig1) + id, err = txs.Remove(broadcastedID) require.NoError(t, err) require.Equal(t, broadcastedMsg.id, id) // Check removed from broadcasted map - _, exists := txs.broadcastedTxs[broadcastedMsg.id] + _, exists := txs.broadcastedProcessedTxs[broadcastedMsg.id] require.False(t, exists) // Check all signatures removed from sig map _, exists = txs.sigToID[broadcastedSig1] @@ -885,18 +918,18 @@ func TestPendingTxContext_remove(t *testing.T) { require.False(t, exists) // Remove processed transaction - id, err = txs.Remove(processedSig) + id, err = txs.Remove(processedID) require.NoError(t, err) require.Equal(t, processedMsg.id, id) // Check removed from broadcasted map - _, exists = txs.broadcastedTxs[processedMsg.id] + _, exists = txs.broadcastedProcessedTxs[processedMsg.id] require.False(t, exists) // Check all signatures removed from sig map _, exists = txs.sigToID[processedSig] require.False(t, exists) // Remove confirmed transaction - id, err = txs.Remove(confirmedSig) + id, err = txs.Remove(confirmedID) require.NoError(t, err) require.Equal(t, confirmedMsg.id, id) // Check removed from confirmed map @@ -907,17 +940,17 @@ func TestPendingTxContext_remove(t *testing.T) { require.False(t, exists) // Check remove cannot be called on finalized transaction - id, err = txs.Remove(finalizedSig) + id, err = txs.Remove(finalizedID) require.Error(t, err) require.Equal(t, "", id) // Check remove cannot be called on errored transaction - id, err = txs.Remove(erroredSig) + id, err = txs.Remove(erroredID) require.Error(t, err) require.Equal(t, "", id) // Check sig list is empty after all removals - require.Empty(t, txs.ListAll()) + require.Empty(t, txs.ListAllSigs()) } func TestPendingTxContext_trim_finalized_errored_txs(t *testing.T) { t.Parallel() @@ -959,23 +992,24 @@ func TestPendingTxContext_expired(t *testing.T) { _, cancel := context.WithCancel(tests.Context(t)) sig := solana.Signature{} txs := newPendingTxContext() + txID := uuid.NewString() - msg := pendingTx{id: uuid.NewString()} + msg := pendingTx{id: txID} err := txs.New(msg, sig, cancel) assert.NoError(t, err) - msg, exists := txs.broadcastedTxs[msg.id] + msg, exists := txs.broadcastedProcessedTxs[msg.id] require.True(t, exists) // Set createTs to 10 seconds ago msg.createTs = time.Now().Add(-10 * time.Second) - txs.broadcastedTxs[msg.id] = msg + txs.broadcastedProcessedTxs[msg.id] = msg assert.False(t, txs.Expired(sig, 0*time.Second)) // false if timeout 0 assert.True(t, txs.Expired(sig, 5*time.Second)) // expired for 5s lifetime assert.False(t, txs.Expired(sig, 60*time.Second)) // not expired for 60s lifetime - id, err := txs.Remove(sig) + id, err := txs.Remove(txID) assert.NoError(t, err) assert.Equal(t, msg.id, id) assert.False(t, txs.Expired(sig, 60*time.Second)) // no longer exists, should return false @@ -1025,18 +1059,19 @@ func TestPendingTxContext_race(t *testing.T) { t.Run("remove", func(t *testing.T) { txCtx := newPendingTxContext() - msg := pendingTx{id: uuid.NewString()} + txID := uuid.NewString() + msg := pendingTx{id: txID} err := txCtx.New(msg, solana.Signature{}, func() {}) require.NoError(t, err) var wg sync.WaitGroup wg.Add(2) go func() { - assert.NotPanics(t, func() { txCtx.Remove(solana.Signature{}) }) //nolint // no need to check error + assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error wg.Done() }() go func() { - assert.NotPanics(t, func() { txCtx.Remove(solana.Signature{}) }) //nolint // no need to check error + assert.NotPanics(t, func() { txCtx.Remove(txID) }) //nolint // no need to check error wg.Done() }() @@ -1137,3 +1172,157 @@ func randomSignature(t *testing.T) solana.Signature { return solana.SignatureFromBytes(sig) } + +func TestPendingTxContext_ListAllExpiredBroadcastedTxs(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T, ctx *pendingTxContext) + currBlockHeight uint64 + expectedTxIDs []string + }{ + { + name: "No broadcasted transactions", + setup: func(t *testing.T, ctx *pendingTxContext) { + // No setup needed; broadcastedProcessedTxs remains empty + }, + currBlockHeight: 1000, + expectedTxIDs: []string{}, + }, + { + name: "No expired broadcasted transactions", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1500, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 1600, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + }, + currBlockHeight: 1400, + expectedTxIDs: []string{}, + }, + { + name: "Some expired broadcasted transactions", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1000, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 1500, + } + tx3 := pendingTx{ + id: "tx3", + state: Broadcasted, + lastValidBlockHeight: 900, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + ctx.broadcastedProcessedTxs["tx3"] = tx3 + }, + currBlockHeight: 1200, + expectedTxIDs: []string{"tx1", "tx3"}, + }, + { + name: "All broadcasted transactions expired with maxUint64", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1000, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 1500, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + }, + currBlockHeight: ^uint64(0), // maxUint64 + expectedTxIDs: []string{"tx1", "tx2"}, + }, + { + name: "Only broadcasted transactions are considered", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 800, + } + tx2 := pendingTx{ + id: "tx2", + state: Processed, // Not Broadcasted + lastValidBlockHeight: 700, + } + tx3 := pendingTx{ + id: "tx3", + state: Processed, // Not Broadcasted + lastValidBlockHeight: 600, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + ctx.broadcastedProcessedTxs["tx3"] = tx3 + }, + currBlockHeight: 900, + expectedTxIDs: []string{"tx1"}, + }, + { + name: "Broadcasted transactions with edge block heights", + setup: func(t *testing.T, ctx *pendingTxContext) { + tx1 := pendingTx{ + id: "tx1", + state: Broadcasted, + lastValidBlockHeight: 1000, + } + tx2 := pendingTx{ + id: "tx2", + state: Broadcasted, + lastValidBlockHeight: 999, + } + tx3 := pendingTx{ + id: "tx3", + state: Broadcasted, + lastValidBlockHeight: 1, + } + ctx.broadcastedProcessedTxs["tx1"] = tx1 + ctx.broadcastedProcessedTxs["tx2"] = tx2 + ctx.broadcastedProcessedTxs["tx3"] = tx3 + }, + currBlockHeight: 1000, + expectedTxIDs: []string{"tx2", "tx3"}, + }, + } + + for _, tt := range tests { + tt := tt // capture range variable + t.Run(tt.name, func(t *testing.T) { + // Initialize a new PendingTxContext + ctx := newPendingTxContext() + + // Setup the test case + tt.setup(t, ctx) + + // Execute the function under test + result := ctx.ListAllExpiredBroadcastedTxs(tt.currBlockHeight) + + // Extract the IDs from the result + var resultIDs []string + for _, tx := range result { + resultIDs = append(resultIDs, tx.id) + } + + // Assert that the expected IDs match the result IDs (order does not matter) + assert.ElementsMatch(t, tt.expectedTxIDs, resultIDs) + }) + } +} diff --git a/pkg/solana/txm/txm.go b/pkg/solana/txm/txm.go index 10cc1acd2..3e169d88a 100644 --- a/pkg/solana/txm/txm.go +++ b/pkg/solana/txm/txm.go @@ -142,6 +142,10 @@ func (txm *Txm) Start(ctx context.Context) error { }) } +// run is a goroutine that continuously processes transactions from the chSend channel. +// It attempts to send each transaction with retry logic and, upon success, enqueues the transaction for simulation. +// If a transaction fails to send, it logs the error and resets the client to handle potential bad RPCs. +// The function runs until the chStop channel signals to stop. func (txm *Txm) run() { defer txm.done.Done() ctx, cancel := txm.chStop.NewCtx() @@ -175,197 +179,198 @@ func (txm *Txm) run() { } } +// sendWithRetry attempts to send a transaction with exponential backoff retry logic. +// It builds, signs, sends the initial tx, and starts a retry routine with fee bumping if needed. +// The function returns the signed transaction, its ID, and the initial signature for use in simulation. func (txm *Txm) sendWithRetry(ctx context.Context, msg pendingTx) (solanaGo.Transaction, string, solanaGo.Signature, error) { - // get key - // fee payer account is index 0 account - // https://github.com/gagliardetto/solana-go/blob/main/transaction.go#L252 - key := msg.tx.Message.AccountKeys[0].String() - - // base compute unit price should only be calculated once - // prevent underlying base changing when bumping (could occur with RPC based estimation) - getFee := func(count int) fees.ComputeUnitPrice { - fee := fees.CalculateFee( - msg.cfg.BaseComputeUnitPrice, - msg.cfg.ComputeUnitPriceMax, - msg.cfg.ComputeUnitPriceMin, - uint(count), //nolint:gosec // reasonable number of bumps should never cause overflow - ) - return fees.ComputeUnitPrice(fee) - } - - baseTx := msg.tx - - // add compute unit limit instruction - static for the transaction - // skip if compute unit limit = 0 (otherwise would always fail) - if msg.cfg.ComputeUnitLimit != 0 { - if computeUnitLimitErr := fees.SetComputeUnitLimit(&baseTx, fees.ComputeUnitLimit(msg.cfg.ComputeUnitLimit)); computeUnitLimitErr != nil { - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to add compute unit limit instruction: %w", computeUnitLimitErr) - } - } - - buildTx := func(ctx context.Context, base solanaGo.Transaction, retryCount int) (solanaGo.Transaction, error) { - newTx := base // make copy - - // set fee - // fee bumping can be enabled by moving the setting & signing logic to the broadcaster - if computeUnitErr := fees.SetComputeUnitPrice(&newTx, getFee(retryCount)); computeUnitErr != nil { - return solanaGo.Transaction{}, computeUnitErr - } - - // sign tx - txMsg, marshalErr := newTx.Message.MarshalBinary() - if marshalErr != nil { - return solanaGo.Transaction{}, fmt.Errorf("error in soltxm.SendWithRetry.MarshalBinary: %w", marshalErr) - } - sigBytes, signErr := txm.ks.Sign(ctx, key, txMsg) - if signErr != nil { - return solanaGo.Transaction{}, fmt.Errorf("error in soltxm.SendWithRetry.Sign: %w", signErr) - } - var finalSig [64]byte - copy(finalSig[:], sigBytes) - newTx.Signatures = append(newTx.Signatures, finalSig) - - return newTx, nil - } - - initTx, initBuildErr := buildTx(ctx, baseTx, 0) - if initBuildErr != nil { - return solanaGo.Transaction{}, "", solanaGo.Signature{}, initBuildErr + // Build and sign initial transaction setting compute unit price and limit + initTx, err := txm.buildTx(ctx, msg, 0) + if err != nil { + return solanaGo.Transaction{}, "", solanaGo.Signature{}, err } - // create timeout context + // Send initial transaction ctx, cancel := context.WithTimeout(ctx, msg.cfg.Timeout) - - // send initial tx (do not retry and exit early if fails) sig, initSendErr := txm.sendTx(ctx, &initTx) if initSendErr != nil { - cancel() // cancel context when exiting early + // Do not retry and exit early if fails + cancel() stateTransitionErr := txm.txs.OnPrebroadcastError(msg.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("tx failed initial transmit: %w", errors.Join(initSendErr, stateTransitionErr)) } - // store tx signature + cancel function - initStoreErr := txm.txs.New(msg, sig, cancel) - if initStoreErr != nil { - cancel() // cancel context when exiting early - return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save tx signature (%s) to inflight txs: %w", sig, initStoreErr) + // Store tx signature and cancel function + if err := txm.txs.New(msg, sig, cancel); err != nil { + cancel() // Cancel context when exiting early + return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save tx signature (%s) to inflight txs: %w", sig, err) } - // used for tracking rebroadcasting only in SendWithRetry - var sigs signatureList + txm.lggr.Debugw("tx initial broadcast", "id", msg.id, "fee", msg.cfg.BaseComputeUnitPrice, "signature", sig, "lastValidBlockHeight", msg.lastValidBlockHeight) + + // Initialize signature list with initialTx signature. This list will be used to add new signatures and track retry attempts. + sigs := &signatureList{} sigs.Allocate() if initSetErr := sigs.Set(0, sig); initSetErr != nil { return solanaGo.Transaction{}, "", solanaGo.Signature{}, fmt.Errorf("failed to save initial signature in signature list: %w", initSetErr) } - txm.lggr.Debugw("tx initial broadcast", "id", msg.id, "fee", getFee(0), "signature", sig) - + // pass in copy of msg (to build new tx with bumped fee) and broadcasted tx == initTx (to retry tx without bumping) txm.done.Add(1) - // retry with exponential backoff - // until context cancelled by timeout or called externally - // pass in copy of baseTx (used to build new tx with bumped fee) and broadcasted tx == initTx (used to retry tx without bumping) - go func(ctx context.Context, baseTx, currentTx solanaGo.Transaction) { + go func() { defer txm.done.Done() - deltaT := 1 // ms - tick := time.After(0) - bumpCount := 0 - bumpTime := time.Now() - var wg sync.WaitGroup + txm.retryTx(ctx, msg, initTx, sigs) + }() - for { - select { - case <-ctx.Done(): - // stop sending tx after retry tx ctx times out (does not stop confirmation polling for tx) - wg.Wait() - txm.lggr.Debugw("stopped tx retry", "id", msg.id, "signatures", sigs.List(), "err", context.Cause(ctx)) - return - case <-tick: - var shouldBump bool - // bump if period > 0 and past time - if msg.cfg.FeeBumpPeriod != 0 && time.Since(bumpTime) > msg.cfg.FeeBumpPeriod { - bumpCount++ - bumpTime = time.Now() - shouldBump = true - } + // Return signed tx, id, signature for use in simulation + return initTx, msg.id, sig, nil +} - // if fee should be bumped, build new tx and replace currentTx - if shouldBump { - var retryBuildErr error - currentTx, retryBuildErr = buildTx(ctx, baseTx, bumpCount) - if retryBuildErr != nil { - txm.lggr.Errorw("failed to build bumped retry tx", "error", retryBuildErr, "id", msg.id) - return // exit func if cannot build tx for retrying - } - ind := sigs.Allocate() - if ind != bumpCount { - txm.lggr.Errorw("INVARIANT VIOLATION: index (%d) != bumpCount (%d)", ind, bumpCount) - return - } - } +// buildTx builds and signs the transaction with the appropriate compute unit price. +func (txm *Txm) buildTx(ctx context.Context, msg pendingTx, retryCount int) (solanaGo.Transaction, error) { + // work with a copy + newTx := msg.tx - // take currentTx and broadcast, if bumped fee -> save signature to list - wg.Add(1) - go func(bump bool, count int, retryTx solanaGo.Transaction) { - defer wg.Done() - - retrySig, retrySendErr := txm.sendTx(ctx, &retryTx) - // this could occur if endpoint goes down or if ctx cancelled - if retrySendErr != nil { - if strings.Contains(retrySendErr.Error(), "context canceled") || strings.Contains(retrySendErr.Error(), "context deadline exceeded") { - txm.lggr.Debugw("ctx error on send retry transaction", "error", retrySendErr, "signatures", sigs.List(), "id", msg.id) - } else { - txm.lggr.Warnw("failed to send retry transaction", "error", retrySendErr, "signatures", sigs.List(), "id", msg.id) - } - return - } - - // save new signature if fee bumped - if bump { - if retryStoreErr := txm.txs.AddSignature(msg.id, retrySig); retryStoreErr != nil { - txm.lggr.Warnw("error in adding retry transaction", "error", retryStoreErr, "id", msg.id) - return - } - if setErr := sigs.Set(count, retrySig); setErr != nil { - // this should never happen - txm.lggr.Errorw("INVARIANT VIOLATION", "error", setErr) - } - txm.lggr.Debugw("tx rebroadcast with bumped fee", "id", msg.id, "fee", getFee(count), "signatures", sigs.List()) - } - - // prevent locking on waitgroup when ctx is closed - wait := make(chan struct{}) - go func() { - defer close(wait) - sigs.Wait(count) // wait until bump tx has set the tx signature to compare rebroadcast signatures - }() - select { - case <-ctx.Done(): - return - case <-wait: - } - - // this should never happen (should match the signature saved to sigs) - if fetchedSig, fetchErr := sigs.Get(count); fetchErr != nil || retrySig != fetchedSig { - txm.lggr.Errorw("original signature does not match retry signature", "expectedSignatures", sigs.List(), "receivedSignature", retrySig, "error", fetchErr) - } - }(shouldBump, bumpCount, currentTx) - } + // Set compute unit limit if specified + if msg.cfg.ComputeUnitLimit != 0 { + if err := fees.SetComputeUnitLimit(&newTx, fees.ComputeUnitLimit(msg.cfg.ComputeUnitLimit)); err != nil { + return solanaGo.Transaction{}, fmt.Errorf("failed to add compute unit limit instruction: %w", err) + } + } + + // Set compute unit price (fee) + fee := fees.ComputeUnitPrice( + fees.CalculateFee( + msg.cfg.BaseComputeUnitPrice, + msg.cfg.ComputeUnitPriceMax, + msg.cfg.ComputeUnitPriceMin, + uint(retryCount), //nolint:gosec // reasonable number of bumps should never cause overflow + )) + if err := fees.SetComputeUnitPrice(&newTx, fee); err != nil { + return solanaGo.Transaction{}, err + } + + // Sign transaction + // NOTE: fee payer account is index 0 account. https://github.com/gagliardetto/solana-go/blob/main/transaction.go#L252 + txMsg, err := newTx.Message.MarshalBinary() + if err != nil { + return solanaGo.Transaction{}, fmt.Errorf("error in MarshalBinary: %w", err) + } + sigBytes, err := txm.ks.Sign(ctx, msg.tx.Message.AccountKeys[0].String(), txMsg) + if err != nil { + return solanaGo.Transaction{}, fmt.Errorf("error in Sign: %w", err) + } + var finalSig [64]byte + copy(finalSig[:], sigBytes) + newTx.Signatures = append(newTx.Signatures, finalSig) + + return newTx, nil +} - // exponential increase in wait time, capped at 250ms - deltaT *= 2 - if deltaT > MaxRetryTimeMs { - deltaT = MaxRetryTimeMs +// retryTx contains the logic for retrying the transaction, including exponential backoff and fee bumping. +// Retries until context cancelled by timeout or called externally. +// It uses handleRetry helper function to handle each retry attempt. +func (txm *Txm) retryTx(ctx context.Context, msg pendingTx, currentTx solanaGo.Transaction, sigs *signatureList) { + deltaT := 1 // initial delay in ms + tick := time.After(0) + bumpCount := 0 + bumpTime := time.Now() + var wg sync.WaitGroup + + for { + select { + case <-ctx.Done(): + // stop sending tx after retry tx ctx times out (does not stop confirmation polling for tx) + wg.Wait() + txm.lggr.Debugw("stopped tx retry", "id", msg.id, "signatures", sigs.List(), "err", context.Cause(ctx)) + return + case <-tick: + // determines whether the fee should be bumped based on the fee bump period. + shouldBump := msg.cfg.FeeBumpPeriod != 0 && time.Since(bumpTime) > msg.cfg.FeeBumpPeriod + if shouldBump { + bumpCount++ + bumpTime = time.Now() + // Build new transaction with bumped fee and replace current tx + var err error + currentTx, err = txm.buildTx(ctx, msg, bumpCount) + if err != nil { + // Exit if unable to build transaction for retrying + txm.lggr.Errorw("failed to build bumped retry tx", "error", err, "id", msg.id) + return + } + // allocates space for new signature that will be introduced in handleRetry if needs bumping. + index := sigs.Allocate() + if index != bumpCount { + txm.lggr.Errorw("invariant violation: index does not match bumpCount", "index", index, "bumpCount", bumpCount) + return + } } - tick = time.After(time.Duration(deltaT) * time.Millisecond) + + // Start a goroutine to handle the retry attempt + // takes currentTx and rebroadcast. If needs bumping it will new signature to already allocated space in signatureList. + wg.Add(1) + go func(bump bool, count int, retryTx solanaGo.Transaction) { + defer wg.Done() + txm.handleRetry(ctx, msg, bump, count, retryTx, sigs) + }(shouldBump, bumpCount, currentTx) } - }(ctx, baseTx, initTx) - // return signed tx, id, signature for use in simulation - return initTx, msg.id, sig, nil + // updates the exponential backoff delay up to a maximum limit. + deltaT = deltaT * 2 + if deltaT > MaxRetryTimeMs { + deltaT = MaxRetryTimeMs + } + tick = time.After(time.Duration(deltaT) * time.Millisecond) + } +} + +// handleRetry handles the logic for each retry attempt, including sending the transaction, updating signatures, and logging. +func (txm *Txm) handleRetry(ctx context.Context, msg pendingTx, bump bool, count int, retryTx solanaGo.Transaction, sigs *signatureList) { + // send retry transaction + retrySig, err := txm.sendTx(ctx, &retryTx) + if err != nil { + // this could occur if endpoint goes down or if ctx cancelled + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + txm.lggr.Debugw("ctx error on send retry transaction", "error", err, "signatures", sigs.List(), "id", msg.id) + } else { + txm.lggr.Warnw("failed to send retry transaction", "error", err, "signatures", sigs.List(), "id", msg.id) + } + return + } + + // if bump is true, update signature list and set new signature in space already allocated. + if bump { + if err := txm.txs.AddSignature(msg.id, retrySig); err != nil { + txm.lggr.Warnw("error in adding retry transaction", "error", err, "id", msg.id) + return + } + if err := sigs.Set(count, retrySig); err != nil { + // this should never happen + txm.lggr.Errorw("INVARIANT VIOLATION: failed to set signature", "error", err, "id", msg.id) + return + } + txm.lggr.Debugw("tx rebroadcast with bumped fee", "id", msg.id, "retryCount", count, "fee", msg.cfg.BaseComputeUnitPrice, "signatures", sigs.List()) + } + + // prevent locking on waitgroup when ctx is closed + wait := make(chan struct{}) + go func() { + defer close(wait) + sigs.Wait(count) // wait until bump tx has set the tx signature to compare rebroadcast signatures + }() + select { + case <-ctx.Done(): + return + case <-wait: + } + + // this should never happen (should match the signature saved to sigs) + if fetchedSig, err := sigs.Get(count); err != nil || retrySig != fetchedSig { + txm.lggr.Errorw("original signature does not match retry signature", "expectedSignatures", sigs.List(), "receivedSignature", retrySig, "error", err) + } } -// goroutine that polls to confirm implementation -// cancels the exponential retry once confirmed +// confirm is a goroutine that continuously polls for transaction confirmations and handles rebroadcasts expired transactions if enabled. +// The function runs until the chStop channel signals to stop. func (txm *Txm) confirm() { defer txm.done.Done() ctx, cancel := txm.chStop.NewCtx() @@ -377,139 +382,227 @@ func (txm *Txm) confirm() { case <-ctx.Done(): return case <-tick: - // get list of tx signatures to confirm - sigs := txm.txs.ListAll() - - // exit switch if not txs to confirm - if len(sigs) == 0 { + // If no signatures to confirm and rebroadcast, we can break loop as there's nothing to process. + if txm.InflightTxs() == 0 { break } - // get client client, err := txm.client.Get() if err != nil { - txm.lggr.Errorw("failed to get client in soltxm.confirm", "error", err) - break // exit switch + txm.lggr.Errorw("failed to get client in txm.confirm", "error", err) + break + } + txm.processConfirmations(ctx, client) + if txm.cfg.TxExpirationRebroadcast() { + txm.rebroadcastExpiredTxs(ctx, client) } + } + tick = time.After(utils.WithJitter(txm.cfg.ConfirmPollPeriod())) + } +} - // batch sigs no more than MaxSigsToConfirm each - sigsBatch, err := utils.BatchSplit(sigs, MaxSigsToConfirm) - if err != nil { // this should never happen - txm.lggr.Fatalw("failed to batch signatures", "error", err) - break // exit switch +// processConfirmations checks the status of transaction signatures on-chain and updates our in-memory state accordingly. +// It splits the signatures into batches, retrieves their statuses with an RPC call, and processes each status accordingly. +// The function handles transitions, managing expiration, errors, and transitions between different states like broadcasted, processed, confirmed, and finalized. +// It also determines when to end polling based on the status of each signature cancelling the exponential retry. +func (txm *Txm) processConfirmations(ctx context.Context, client client.ReaderWriter) { + sigsBatch, err := utils.BatchSplit(txm.txs.ListAllSigs(), MaxSigsToConfirm) + if err != nil { // this should never happen + txm.lggr.Fatalw("failed to batch signatures", "error", err) + return + } + + var wg sync.WaitGroup + for i := 0; i < len(sigsBatch); i++ { + statuses, err := client.SignatureStatuses(ctx, sigsBatch[i]) + if err != nil { + txm.lggr.Errorw("failed to get signature statuses in txm.confirm", "error", err) + break + } + + wg.Add(1) + // nonblocking: process batches as soon as they come in + go func(index int) { + defer wg.Done() + + // to process successful first + sortedSigs, sortedRes, err := SortSignaturesAndResults(sigsBatch[i], statuses) + if err != nil { + txm.lggr.Errorw("sorting error", "error", err) + return } - // process signatures - processSigs := func(s []solanaGo.Signature, res []*rpc.SignatureStatusesResult) { - // sort signatures and results process successful first - s, res, err := SortSignaturesAndResults(s, res) - if err != nil { - txm.lggr.Errorw("sorting error", "error", err) - return + for j := 0; j < len(sortedRes); j++ { + sig, status := sortedSigs[j], sortedRes[j] + // sig not found could mean invalid tx or not picked up yet, keep polling + if status == nil { + txm.handleNotFoundSignatureStatus(sig) + continue } - for i := 0; i < len(res); i++ { - // if status is nil (sig not found), continue polling - // sig not found could mean invalid tx or not picked up yet - if res[i] == nil { - txm.lggr.Debugw("tx state: not found", - "signature", s[i], - ) - - // check confirm timeout exceeded - if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) - if err != nil { - txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) - } else { - txm.lggr.Debugw("failed to find transaction within confirm timeout", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout()) - } - } - continue - } - - // if signature has an error, end polling - if res[i].Err != nil { - // Process error to determine the corresponding state and type. - // Skip marking as errored if error considered to not be a failure. - if txState, errType := txm.ProcessError(s[i], res[i].Err, false); errType != NoFailure { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), txState, errType) - if err != nil { - txm.lggr.Infow(fmt.Sprintf("failed to mark transaction as %s", txState.String()), "id", id, "signature", s[i], "error", err) - } else { - txm.lggr.Debugw(fmt.Sprintf("marking transaction as %s", txState.String()), "id", id, "signature", s[i], "error", res[i].Err, "status", res[i].ConfirmationStatus) - } - } - continue - } + // if signature has an error, end polling unless blockhash not found and expiration rebroadcast is enabled + if status.Err != nil { + txm.handleErrorSignatureStatus(sig, status) + continue + } + switch status.ConfirmationStatus { + case rpc.ConfirmationStatusProcessed: // if signature is processed, keep polling for confirmed or finalized status - if res[i].ConfirmationStatus == rpc.ConfirmationStatusProcessed { - // update transaction state in local memory - id, err := txm.txs.OnProcessed(s[i]) - if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { - txm.lggr.Errorw("failed to mark transaction as processed", "signature", s[i], "error", err) - } else if err == nil { - txm.lggr.Debugw("marking transaction as processed", "id", id, "signature", s[i]) - } - // check confirm timeout exceeded if TxConfirmTimeout set - if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(s[i], txm.cfg.TxConfirmTimeout()) { - id, err := txm.txs.OnError(s[i], txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) - if err != nil { - txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) - } else { - txm.lggr.Debugw("tx failed to move beyond 'processed' within confirm timeout", "id", id, "signature", s[i], "timeoutSeconds", txm.cfg.TxConfirmTimeout()) - } - } - continue - } - + txm.handleProcessedSignatureStatus(sig) + continue + case rpc.ConfirmationStatusConfirmed: // if signature is confirmed, keep polling for finalized status - if res[i].ConfirmationStatus == rpc.ConfirmationStatusConfirmed { - id, err := txm.txs.OnConfirmed(s[i]) - if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { - txm.lggr.Errorw("failed to mark transaction as confirmed", "id", id, "signature", s[i], "error", err) - } else if err == nil { - txm.lggr.Debugw("marking transaction as confirmed", "id", id, "signature", s[i]) - } - continue - } - + txm.handleConfirmedSignatureStatus(sig) + continue + case rpc.ConfirmationStatusFinalized: // if signature is finalized, end polling - if res[i].ConfirmationStatus == rpc.ConfirmationStatusFinalized { - id, err := txm.txs.OnFinalized(s[i], txm.cfg.TxRetentionTimeout()) - if err != nil { - txm.lggr.Errorw("failed to mark transaction as finalized", "id", id, "signature", s[i], "error", err) - } else { - txm.lggr.Debugw("marking transaction as finalized", "id", id, "signature", s[i]) - } - continue - } + txm.handleFinalizedSignatureStatus(sig) + continue + default: + txm.lggr.Warnw("unknown confirmation status", "signature", sig, "status", status.ConfirmationStatus) + continue } } + }(i) + } + wg.Wait() // wait for processing to finish +} - // waitgroup for processing - var wg sync.WaitGroup +// handleNotFoundSignatureStatus handles the case where a transaction signature is not found on-chain. +// If the confirmation timeout has been exceeded it marks the transaction as errored. +func (txm *Txm) handleNotFoundSignatureStatus(sig solanaGo.Signature) { + txm.lggr.Debugw("tx state: not found", "signature", sig) + if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(sig, txm.cfg.TxConfirmTimeout()) { + id, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + if err != nil { + txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) + } else { + txm.lggr.Debugw("failed to find transaction within confirm timeout", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout()) + } + } +} - // loop through batch - for i := 0; i < len(sigsBatch); i++ { - // fetch signature statuses - statuses, err := client.SignatureStatuses(ctx, sigsBatch[i]) - if err != nil { - txm.lggr.Errorw("failed to get signature statuses in soltxm.confirm", "error", err) - break // exit for loop - } +// handleErrorSignatureStatus handles the case where a transaction signature has an error on-chain. +// If the error is BlockhashNotFound and expiration rebroadcast is enabled, it skips error handling to allow rebroadcasting. +// Otherwise, it marks the transaction as errored. +func (txm *Txm) handleErrorSignatureStatus(sig solanaGo.Signature, status *rpc.SignatureStatusesResult) { + // We want to rebroadcast rather than drop tx if expiration rebroadcast is enabled when blockhash was not found. + // converting error to string so we are able to check if it contains the error message. + if status.Err != nil && strings.Contains(fmt.Sprintf("%v", status.Err), "BlockhashNotFound") && txm.cfg.TxExpirationRebroadcast() { + return + } - wg.Add(1) - // nonblocking: process batches as soon as they come in - go func(index int) { - defer wg.Done() - processSigs(sigsBatch[index], statuses) - }(i) - } - wg.Wait() // wait for processing to finish + // Process error to determine the corresponding state and type. + // Skip marking as errored if error considered to not be a failure. + if txState, errType := txm.ProcessError(sig, status.Err, false); errType != NoFailure { + id, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), txState, errType) + if err != nil { + txm.lggr.Infow(fmt.Sprintf("failed to mark transaction as %s", txState.String()), "id", id, "signature", sig, "error", err) + } else { + txm.lggr.Debugw(fmt.Sprintf("marking transaction as %s", txState.String()), "id", id, "signature", sig, "error", status.Err, "status", status.ConfirmationStatus) } - tick = time.After(utils.WithJitter(txm.cfg.ConfirmPollPeriod())) + } +} + +// handleProcessedSignatureStatus handles the case where a transaction signature is in the "processed" state on-chain. +// It updates the transaction state in the local memory and checks if the confirmation timeout has been exceeded. +// If the timeout is exceeded, it marks the transaction as errored. +func (txm *Txm) handleProcessedSignatureStatus(sig solanaGo.Signature) { + // update transaction state in local memory + id, err := txm.txs.OnProcessed(sig) + if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { + txm.lggr.Errorw("failed to mark transaction as processed", "signature", sig, "error", err) + } else if err == nil { + txm.lggr.Debugw("marking transaction as processed", "id", id, "signature", sig) + } + // check confirm timeout exceeded if TxConfirmTimeout set + if txm.cfg.TxConfirmTimeout() != 0*time.Second && txm.txs.Expired(sig, txm.cfg.TxConfirmTimeout()) { + id, err := txm.txs.OnError(sig, txm.cfg.TxRetentionTimeout(), Errored, TxFailDrop) + if err != nil { + txm.lggr.Infow("failed to mark transaction as errored", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout(), "error", err) + } else { + txm.lggr.Debugw("tx failed to move beyond 'processed' within confirm timeout", "id", id, "signature", sig, "timeoutSeconds", txm.cfg.TxConfirmTimeout()) + } + } +} + +// handleConfirmedSignatureStatus handles the case where a transaction signature is in the "confirmed" state on-chain. +// It updates the transaction state in the local memory. +func (txm *Txm) handleConfirmedSignatureStatus(sig solanaGo.Signature) { + id, err := txm.txs.OnConfirmed(sig) + if err != nil && !errors.Is(err, ErrAlreadyInExpectedState) { + txm.lggr.Errorw("failed to mark transaction as confirmed", "id", id, "signature", sig, "error", err) + } else if err == nil { + txm.lggr.Debugw("marking transaction as confirmed", "id", id, "signature", sig) + } +} + +// handleFinalizedSignatureStatus handles the case where a transaction signature is in the "finalized" state on-chain. +// It updates the transaction state in the local memory. +func (txm *Txm) handleFinalizedSignatureStatus(sig solanaGo.Signature) { + id, err := txm.txs.OnFinalized(sig, txm.cfg.TxRetentionTimeout()) + if err != nil { + txm.lggr.Errorw("failed to mark transaction as finalized", "id", id, "signature", sig, "error", err) + } else { + txm.lggr.Debugw("marking transaction as finalized", "id", id, "signature", sig) + } +} + +// rebroadcastExpiredTxs attempts to rebroadcast all transactions that are in broadcasted state and have expired. +// An expired tx is one where it's blockhash lastValidBlockHeight (last valid block number) is smaller than the current block height (block number). +// The function loops through all expired txes, rebroadcasts them with a new blockhash, and updates the lastValidBlockHeight. +// If any error occurs during rebroadcast attempt, they are discarded, and the function continues with the next transaction. +func (txm *Txm) rebroadcastExpiredTxs(ctx context.Context, client client.ReaderWriter) { + currBlock, err := client.GetLatestBlock(ctx) + if err != nil || currBlock == nil || currBlock.BlockHeight == nil { + txm.lggr.Errorw("failed to get current block height", "error", err) + return + } + + // Get all expired broadcasted transactions at current block number. Safe to quit if no txes are found. + expiredBroadcastedTxes := txm.txs.ListAllExpiredBroadcastedTxs(*currBlock.BlockHeight) + if len(expiredBroadcastedTxes) == 0 { + return + } + + blockhash, err := client.LatestBlockhash(ctx) + if err != nil { + txm.lggr.Errorw("failed to getLatestBlockhash for rebroadcast", "error", err) + return + } + if blockhash == nil || blockhash.Value == nil { + txm.lggr.Errorw("nil pointer returned from getLatestBlockhash for rebroadcast") + return + } + + // rebroadcast each expired tx after updating blockhash, lastValidBlockHeight and compute unit price (priority fee) + for _, tx := range expiredBroadcastedTxes { + txm.lggr.Debugw("transaction expired, rebroadcasting", "id", tx.id, "signature", tx.signatures, "lastValidBlockHeight", tx.lastValidBlockHeight, "currentBlockHeight", *currBlock.BlockHeight) + // Removes all signatures associated to prior tx and cancels context. + _, err := txm.txs.Remove(tx.id) + if err != nil { + txm.lggr.Errorw("failed to remove expired transaction", "id", tx.id, "error", err) + continue + } + + tx.tx.Message.RecentBlockhash = blockhash.Value.Blockhash + tx.cfg.BaseComputeUnitPrice = txm.fee.BaseComputeUnitPrice() + rebroadcastTx := pendingTx{ + tx: tx.tx, + cfg: tx.cfg, + id: tx.id, // using same id in case it was set by caller and we need to maintain it. + lastValidBlockHeight: blockhash.Value.LastValidBlockHeight, + } + // call sendWithRetry directly to avoid enqueuing + _, _, _, sendErr := txm.sendWithRetry(ctx, rebroadcastTx) + if sendErr != nil { + stateTransitionErr := txm.txs.OnPrebroadcastError(tx.id, txm.cfg.TxRetentionTimeout(), Errored, TxFailReject) + txm.lggr.Errorw("failed to rebroadcast transaction", "id", tx.id, "error", errors.Join(sendErr, stateTransitionErr)) + continue + } + + txm.lggr.Debugw("rebroadcast transaction sent", "id", tx.id) } } @@ -580,7 +673,7 @@ func (txm *Txm) reap() { } // Enqueue enqueues a msg destined for the solana chain. -func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txCfgs ...SetTxConfig) error { +func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Transaction, txID *string, txLastValidBlockHeight uint64, txCfgs ...SetTxConfig) error { if err := txm.Ready(); err != nil { return fmt.Errorf("error in soltxm.Enqueue: %w", err) } @@ -628,9 +721,10 @@ func (txm *Txm) Enqueue(ctx context.Context, accountID string, tx *solanaGo.Tran } msg := pendingTx{ - tx: *tx, - cfg: cfg, - id: id, + id: id, + tx: *tx, + cfg: cfg, + lastValidBlockHeight: txLastValidBlockHeight, } select { @@ -745,7 +839,7 @@ func (txm *Txm) simulateTx(ctx context.Context, tx *solanaGo.Transaction) (res * return } -// processError parses and handles relevant errors found in simulation results +// ProcessError parses and handles relevant errors found in simulation results func (txm *Txm) ProcessError(sig solanaGo.Signature, resErr interface{}, simulation bool) (txState TxState, errType TxErrType) { if resErr != nil { // handle various errors @@ -827,8 +921,9 @@ func (txm *Txm) ProcessError(sig solanaGo.Signature, resErr interface{}, simulat return } +// InflightTxs returns the number of signatures being tracked for all transactions not yet finalized or errored func (txm *Txm) InflightTxs() int { - return len(txm.txs.ListAll()) + return len(txm.txs.ListAllSigs()) } // Close close service diff --git a/pkg/solana/txm/txm_integration_test.go b/pkg/solana/txm/txm_integration_test.go new file mode 100644 index 000000000..154a42f6a --- /dev/null +++ b/pkg/solana/txm/txm_integration_test.go @@ -0,0 +1,187 @@ +//go:build integration + +package txm_test + +import ( + "context" + "testing" + "time" + + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/programs/system" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" + "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" + + solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" + keyMocks "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm/mocks" +) + +func TestTxm_Integration_ExpirationRebroadcast(t *testing.T) { + t.Parallel() + url := solanaClient.SetupLocalSolNode(t) // live validator + + type TestCase struct { + name string + txExpirationRebroadcast bool + useValidBlockHash bool + expectRebroadcast bool + expectTransactionStatus types.TransactionStatus + } + + testCases := []TestCase{ + { + name: "WithRebroadcast", + txExpirationRebroadcast: true, + useValidBlockHash: false, + expectRebroadcast: true, + expectTransactionStatus: types.Finalized, + }, + { + name: "WithoutRebroadcast", + txExpirationRebroadcast: false, + useValidBlockHash: false, + expectRebroadcast: false, + expectTransactionStatus: types.Failed, + }, + { + name: "ConfirmedBeforeRebroadcast", + txExpirationRebroadcast: true, + useValidBlockHash: true, + expectRebroadcast: false, + expectTransactionStatus: types.Finalized, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx, client, txmInstance, senderPubKey, receiverPubKey, observer := setup(t, url, tc.txExpirationRebroadcast) + + // Record initial balance + initSenderBalance, err := client.Balance(ctx, senderPubKey) + require.NoError(t, err) + const amount = 1 * solana.LAMPORTS_PER_SOL + + // Create and enqueue tx + txID := tc.name + tx, lastValidBlockHeight := createTransaction(ctx, t, client, senderPubKey, receiverPubKey, amount, tc.useValidBlockHash) + require.NoError(t, txmInstance.Enqueue(ctx, "", tx, &txID, lastValidBlockHeight)) + + // Wait for the transaction to reach the expected status + require.Eventually(t, func() bool { + status, statusErr := txmInstance.GetTransactionStatus(ctx, txID) + if statusErr != nil { + return false + } + return status == tc.expectTransactionStatus + }, 60*time.Second, 1*time.Second, "Transaction should eventually reach expected status") + + // Verify balances + finalSenderBalance, err := client.Balance(ctx, senderPubKey) + require.NoError(t, err) + finalReceiverBalance, err := client.Balance(ctx, receiverPubKey) + require.NoError(t, err) + + if tc.expectTransactionStatus == types.Finalized { + require.Less(t, finalSenderBalance, initSenderBalance, "Sender balance should decrease") + require.Equal(t, amount, finalReceiverBalance, "Receiver should receive the transferred amount") + } else { + require.Equal(t, initSenderBalance, finalSenderBalance, "Sender balance should remain the same") + require.Equal(t, uint64(0), finalReceiverBalance, "Receiver should not receive any funds") + } + + // Verify rebroadcast logs + rebroadcastLogs := observer.FilterMessageSnippet("rebroadcast transaction sent").Len() + rebroadcastLogs2 := observer.FilterMessageSnippet("transaction expired, rebroadcasting").Len() + if tc.expectRebroadcast { + require.Equal(t, 1, rebroadcastLogs, "Expected rebroadcast log message not found") + require.Equal(t, 1, rebroadcastLogs2, "Expected rebroadcast log message not found") + } else { + require.Equal(t, 0, rebroadcastLogs, "Rebroadcast should not occur") + require.Equal(t, 0, rebroadcastLogs2, "Rebroadcast should not occur") + } + }) + } +} + +func setup(t *testing.T, url string, txExpirationRebroadcast bool) (context.Context, *solanaClient.Client, *txm.Txm, solana.PublicKey, solana.PublicKey, *observer.ObservedLogs) { + ctx := tests.Context(t) + + // Generate sender and receiver keys and fund sender account + senderKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + senderPubKey := senderKey.PublicKey() + receiverKey, err := solana.NewRandomPrivateKey() + require.NoError(t, err) + receiverPubKey := receiverKey.PublicKey() + solanaClient.FundTestAccounts(t, []solana.PublicKey{senderPubKey}, url) + + // Set up mock keystore with sender key + mkey := keyMocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, senderPubKey.String(), mock.Anything).Return(func(_ context.Context, _ string, data []byte) []byte { + sig, _ := senderKey.Sign(data) + return sig[:] + }, nil) + + // Set configs + cfg := config.NewDefault() + cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) // to get the finalized tx status + + // Initialize the Solana client and TXM + lggr, obs := logger.TestObserved(t, zapcore.DebugLevel) + client, err := solanaClient.NewClient(url, cfg, 2*time.Second, lggr) + require.NoError(t, err) + loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { return client, nil }) + txmInstance := txm.NewTxm("localnet", loader, nil, cfg, mkey, lggr) + servicetest.Run(t, txmInstance) + + return ctx, client, txmInstance, senderPubKey, receiverPubKey, obs +} + +// createTransaction is a helper function to create a transaction based on the test case. +func createTransaction(ctx context.Context, t *testing.T, client *solanaClient.Client, senderPubKey, receiverPubKey solana.PublicKey, amount uint64, useValidBlockHash bool) (*solana.Transaction, uint64) { + var blockhash solana.Hash + var lastValidBlockHeight uint64 + + if useValidBlockHash { + // Get a valid recent blockhash + recentBlockHashResult, err := client.LatestBlockhash(ctx) + require.NoError(t, err) + blockhash = recentBlockHashResult.Value.Blockhash + lastValidBlockHeight = recentBlockHashResult.Value.LastValidBlockHeight + } else { + // Use empty blockhash to simulate expiration + blockhash = solana.Hash{} + lastValidBlockHeight = 0 + } + + // Create the transaction + tx, err := solana.NewTransaction( + []solana.Instruction{ + system.NewTransferInstruction( + amount, + senderPubKey, + receiverPubKey, + ).Build(), + }, + blockhash, + solana.TransactionPayer(senderPubKey), + ) + require.NoError(t, err) + + return tx, lastValidBlockHeight +} diff --git a/pkg/solana/txm/txm_internal_test.go b/pkg/solana/txm/txm_internal_test.go index 0054e0a2b..13c861362 100644 --- a/pkg/solana/txm/txm_internal_test.go +++ b/pkg/solana/txm/txm_internal_test.go @@ -161,7 +161,6 @@ func TestTxm(t *testing.T) { return out }, nil, ) - // happy path (send => simulate success => tx: nil => tx: processed => tx: confirmed => finalized => done) t.Run("happyPath", func(t *testing.T) { sig := randomSignature(t) @@ -204,7 +203,8 @@ func TestTxm(t *testing.T) { // send tx testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // no transactions stored inflight txs list @@ -240,7 +240,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed // no transactions stored inflight txs list @@ -272,7 +273,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared quickly @@ -308,7 +310,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -348,7 +351,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -399,7 +403,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -441,7 +446,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // txs cleared after timeout @@ -486,7 +492,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // inflight txs cleared after timeout @@ -538,7 +545,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // inflight txs cleared after timeout @@ -576,7 +584,8 @@ func TestTxm(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, waitDuration, txm, prom, empty) // inflight txs cleared after timeout @@ -622,7 +631,8 @@ func TestTxm(t *testing.T) { // send tx testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // no transactions stored inflight txs list @@ -676,7 +686,8 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0))) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight, SetFeeBumpPeriod(0))) wg.Wait() // no transactions stored inflight txs list @@ -728,7 +739,8 @@ func TestTxm(t *testing.T) { // send tx - with disabled fee bumping and disabled compute unit limit testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight, SetFeeBumpPeriod(0), SetComputeUnitLimit(0))) wg.Wait() // no transactions stored inflight txs list @@ -836,7 +848,8 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { // tx should be able to queue testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait to be picked up and processed waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout @@ -875,7 +888,8 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { // tx should be able to queue testTxID := uuid.NewString() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout @@ -920,7 +934,8 @@ func TestTxm_disabled_confirm_timeout_with_retention(t *testing.T) { // tx should be able to queue testTxID := uuid.NewString() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // wait till send tx waitFor(t, 5*time.Second, txm, prom, empty) // inflight txs cleared after timeout @@ -1040,7 +1055,8 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { // send tx testTxID := uuid.New().String() - assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID)) + lastValidBlockHeight := uint64(100) + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &testTxID, lastValidBlockHeight)) wg.Wait() // no transactions stored inflight txs list @@ -1069,7 +1085,8 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("simulation failed")).Once() // tx should NOT be able to queue - assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, nil)) + lastValidBlockHeight := uint64(0) + assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, nil, lastValidBlockHeight)) }) t.Run("simulation_returns_error", func(t *testing.T) { @@ -1084,8 +1101,9 @@ func TestTxm_compute_unit_limit_estimation(t *testing.T) { mc.On("SimulateTx", mock.Anything, simulateTx, mock.Anything).Return(&rpc.SimulateTransactionResult{Err: errors.New("InstructionError")}, nil).Once() txID := uuid.NewString() + lastValidBlockHeight := uint64(100) // tx should NOT be able to queue - assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, &txID)) + assert.Error(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) // tx should be stored in-memory and moved to errored state status, err := txm.GetTransactionStatus(ctx, txID) require.NoError(t, err) @@ -1131,6 +1149,7 @@ func TestTxm_Enqueue(t *testing.T) { ) require.NoError(t, err) + lastValidBlockHeight := uint64(0) invalidTx, err := solana.NewTransaction( []solana.Instruction{ system.NewTransferInstruction( @@ -1147,28 +1166,29 @@ func TestTxm_Enqueue(t *testing.T) { loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) txm := NewTxm("enqueue_test", loader, nil, cfg, mkey, lggr) - require.ErrorContains(t, txm.Enqueue(ctx, "txmUnstarted", &solana.Transaction{}, nil), "not started") + require.ErrorContains(t, txm.Enqueue(ctx, "txmUnstarted", &solana.Transaction{}, nil, lastValidBlockHeight), "not started") require.NoError(t, txm.Start(ctx)) t.Cleanup(func() { require.NoError(t, txm.Close()) }) txs := []struct { - name string - tx *solana.Transaction - fail bool + name string + tx *solana.Transaction + lastValidBlockHeight uint64 + fail bool }{ - {"success", tx, false}, - {"invalid_key", invalidTx, true}, - {"nil_pointer", nil, true}, - {"empty_tx", &solana.Transaction{}, true}, + {"success", tx, 100, false}, + {"invalid_key", invalidTx, 0, true}, + {"nil_pointer", nil, 0, true}, + {"empty_tx", &solana.Transaction{}, 0, true}, } for _, run := range txs { t.Run(run.name, func(t *testing.T) { if !run.fail { - assert.NoError(t, txm.Enqueue(ctx, run.name, run.tx, nil)) + assert.NoError(t, txm.Enqueue(ctx, run.name, run.tx, nil, run.lastValidBlockHeight)) return } - assert.Error(t, txm.Enqueue(ctx, run.name, run.tx, nil)) + assert.Error(t, txm.Enqueue(ctx, run.name, run.tx, nil, run.lastValidBlockHeight)) }) } } @@ -1186,3 +1206,406 @@ func addSigAndLimitToTx(t *testing.T, keystore SimpleKeystore, pubkey solana.Pub require.NoError(t, fees.SetComputeUnitLimit(&txCopy, limit)) return &txCopy } + +func TestTxm_ExpirationRebroadcast(t *testing.T) { + t.Parallel() + estimator := "fixed" + id := "mocknet-" + estimator + "-" + uuid.NewString() + cfg := config.NewDefault() + cfg.Chain.FeeEstimatorMode = &estimator + cfg.Chain.TxConfirmTimeout = relayconfig.MustNewDuration(5 * time.Second) + cfg.Chain.TxRetentionTimeout = relayconfig.MustNewDuration(10 * time.Second) // Enable retention to keep transactions after finality and be able to check their statuses. + lggr := logger.Test(t) + ctx := tests.Context(t) + + // Helper function to set up common test environment + setupTxmTest := func( + txExpirationRebroadcast bool, + latestBlockhashFunc func() (*rpc.GetLatestBlockhashResult, error), + getLatestBlockFunc func() (*rpc.GetBlockResult, error), + sendTxFunc func() (solana.Signature, error), + statuses map[solana.Signature]func() *rpc.SignatureStatusesResult, + ) (*Txm, *mocks.ReaderWriter, *keyMocks.SimpleKeystore) { + cfg.Chain.TxExpirationRebroadcast = &txExpirationRebroadcast + + mc := mocks.NewReaderWriter(t) + if latestBlockhashFunc != nil { + mc.On("LatestBlockhash", mock.Anything).Return( + func(_ context.Context) (*rpc.GetLatestBlockhashResult, error) { + return latestBlockhashFunc() + }, + ).Maybe() + } + if getLatestBlockFunc != nil { + mc.On("GetLatestBlock", mock.Anything).Return( + func(_ context.Context) (*rpc.GetBlockResult, error) { + return getLatestBlockFunc() + }, + ).Maybe() + } + if sendTxFunc != nil { + mc.On("SendTx", mock.Anything, mock.Anything).Return( + func(_ context.Context, _ *solana.Transaction) (solana.Signature, error) { + return sendTxFunc() + }, + ).Maybe() + } + + mc.On("SimulateTx", mock.Anything, mock.Anything, mock.Anything).Return(&rpc.SimulateTransactionResult{}, nil).Maybe() + if statuses != nil { + mc.On("SignatureStatuses", mock.Anything, mock.AnythingOfType("[]solana.Signature")).Return( + func(_ context.Context, sigs []solana.Signature) ([]*rpc.SignatureStatusesResult, error) { + var out []*rpc.SignatureStatusesResult + for _, sig := range sigs { + getStatus, exists := statuses[sig] + if !exists { + out = append(out, nil) + } else { + out = append(out, getStatus()) + } + } + return out, nil + }, + ).Maybe() + } + + mkey := keyMocks.NewSimpleKeystore(t) + mkey.On("Sign", mock.Anything, mock.Anything, mock.Anything).Return([]byte{}, nil) + + loader := utils.NewLazyLoad(func() (client.ReaderWriter, error) { return mc, nil }) + txm := NewTxm(id, loader, nil, cfg, mkey, lggr) + require.NoError(t, txm.Start(ctx)) + t.Cleanup(func() { require.NoError(t, txm.Close()) }) + + return txm, mc, mkey + } + + // tracking prom metrics + prom := soltxmProm{id: id} + + t.Run("WithRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock getLatestBlock to return a value greater than 0 for blockHeight + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + // rebroadcast call will go through because lastValidBlockHeight is bigger than blockHeight + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + nowTs := time.Now() + sigStatusCallCount := 0 + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // First transaction should be rebroadcasted. + if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { + return nil + } + // Second transaction should reach finalization. + sigStatusCallCount++ + if sigStatusCallCount == 1 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } + if sigStatusCallCount == 2 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusConfirmed, + } + } + wg.Done() + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusFinalized, + } + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast" + lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Check that transaction for txID has been finalized and rebroadcasted 1 time. + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + require.Equal(t, 1, rebroadcastCount) + }) + + t.Run("WithoutRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := false + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + rebroadcastCount := 0 + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + nowTs := time.Now() + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // Transaction remains unconfirmed and should not be rebroadcasted. + if time.Since(nowTs) < cfg.TxConfirmTimeout() { + return nil + } + wg.Done() + return nil + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, nil, nil, sendTxFunc, statuses) + + tx, _ := getTx(t, 5, mkey) + txID := "test-no-rebroadcast" + lastValidBlockHeight := uint64(0) // original lastValidBlockHeight is invalid + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.drop++ + prom.error++ + prom.assertEqual(t) + + // Check that transaction for txID has not been finalized and has not been rebroadcasted + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + require.Equal(t, 0, rebroadcastCount) + }) + + t.Run("WithMultipleRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // Mock getLatestBlock to return a value greater than 0 + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + // Mock LatestBlockhash to return an invalid blockhash in the first 2 attempts to rebroadcast. + // the last one is valid because it is greater than the blockHeight + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + if rebroadcastCount < 2 { + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(1000), + }, + }, nil + } + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + nowTs := time.Now() + sigStatusCallCount := 0 + var wg sync.WaitGroup + wg.Add(1) + statuses[sig1] = func() *rpc.SignatureStatusesResult { + // transaction should be rebroadcasted multiple times. + if time.Since(nowTs) < cfg.TxConfirmTimeout()-2*time.Second { + return nil + } + // Second transaction should reach finalization. + sigStatusCallCount++ + if sigStatusCallCount == 1 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusProcessed, + } + } else if sigStatusCallCount == 2 { + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusConfirmed, + } + } + wg.Done() + return &rpc.SignatureStatusesResult{ + ConfirmationStatus: rpc.ConfirmationStatusFinalized, + } + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast" + lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Check that transaction for txID has been finalized and rebroadcasted multiple times. + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + require.Equal(t, 3, rebroadcastCount) + }) + + t.Run("ConfirmedBeforeRebroadcast", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + // Mock getLatestBlock to return a value greater than 0 + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(1000), + }, + }, nil + } + + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig1] = func() *rpc.SignatureStatusesResult { + defer func() { count++ }() + + out := &rpc.SignatureStatusesResult{} + if count == 1 { + out.ConfirmationStatus = rpc.ConfirmationStatusConfirmed + return out + } + if count == 2 { + out.ConfirmationStatus = rpc.ConfirmationStatusFinalized + wg.Done() + return out + } + out.ConfirmationStatus = rpc.ConfirmationStatusProcessed + return out + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-confirmed-before-rebroadcast" + lastValidBlockHeight := uint64(1500) // original lastValidBlockHeight is valid + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, txm.cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.confirmed++ + prom.finalized++ + prom.assertEqual(t) + + // Check that transaction has been finalized without rebroadcast + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Finalized, status) + require.Equal(t, 0, rebroadcastCount) + }) + + t.Run("RebroadcastWithError", func(t *testing.T) { + txExpirationRebroadcast := true + statuses := map[solana.Signature]func() *rpc.SignatureStatusesResult{} + + // To force rebroadcast, first call needs to be smaller than blockHeight + // following rebroadcast call will go through because lastValidBlockHeight will be bigger than blockHeight + getLatestBlockFunc := func() (*rpc.GetBlockResult, error) { + val := uint64(1500) + return &rpc.GetBlockResult{ + BlockHeight: &val, + }, nil + } + + rebroadcastCount := 0 + latestBlockhashFunc := func() (*rpc.GetLatestBlockhashResult, error) { + defer func() { rebroadcastCount++ }() + return &rpc.GetLatestBlockhashResult{ + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: uint64(2000), + }, + }, nil + } + + sig1 := randomSignature(t) + sendTxFunc := func() (solana.Signature, error) { + return sig1, nil + } + + var wg sync.WaitGroup + wg.Add(1) + count := 0 + statuses[sig1] = func() *rpc.SignatureStatusesResult { + defer func() { count++ }() + // Transaction remains unconfirmed + if count == 1 { + wg.Done() + } + return nil + } + + txm, _, mkey := setupTxmTest(txExpirationRebroadcast, latestBlockhashFunc, getLatestBlockFunc, sendTxFunc, statuses) + tx, _ := getTx(t, 0, mkey) + txID := "test-rebroadcast-error" + lastValidBlockHeight := uint64(100) // lastValidBlockHeight is smaller than blockHeight + assert.NoError(t, txm.Enqueue(ctx, t.Name(), tx, &txID, lastValidBlockHeight)) + wg.Wait() + waitFor(t, cfg.TxConfirmTimeout(), txm, prom, empty) + + // check prom metric + prom.drop++ + prom.error++ + prom.assertEqual(t) + + // Transaction should be moved to failed after trying to rebroadcast 1 time. + status, err := txm.GetTransactionStatus(ctx, txID) + require.NoError(t, err) + require.Equal(t, types.Failed, status) + require.Equal(t, 1, rebroadcastCount) + }) +} diff --git a/pkg/solana/txm/txm_load_test.go b/pkg/solana/txm/txm_load_test.go index 5d5a8061b..333c95e23 100644 --- a/pkg/solana/txm/txm_load_test.go +++ b/pkg/solana/txm/txm_load_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" solanaClient "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" "github.com/smartcontractkit/chainlink-solana/pkg/solana/config" "github.com/smartcontractkit/chainlink-solana/pkg/solana/txm" @@ -22,6 +21,7 @@ import ( relayconfig "github.com/smartcontractkit/chainlink-common/pkg/config" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/services/servicetest" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" ) @@ -83,12 +83,11 @@ func TestTxm_Integration(t *testing.T) { // already started assert.Error(t, txm.Start(ctx)) - - createTx := func(signer solana.PublicKey, sender solana.PublicKey, receiver solana.PublicKey, amt uint64) *solana.Transaction { + createTx := func(signer solana.PublicKey, sender solana.PublicKey, receiver solana.PublicKey, amt uint64) (*solana.Transaction, uint64) { // create transfer tx - hash, err := client.LatestBlockhash(ctx) - assert.NoError(t, err) - tx, err := solana.NewTransaction( + hash, blockhashErr := client.LatestBlockhash(ctx) + assert.NoError(t, blockhashErr) + tx, txErr := solana.NewTransaction( []solana.Instruction{ system.NewTransferInstruction( amt, @@ -99,22 +98,27 @@ func TestTxm_Integration(t *testing.T) { hash.Value.Blockhash, solana.TransactionPayer(signer), ) - require.NoError(t, err) - return tx + require.NoError(t, txErr) + return tx, hash.Value.LastValidBlockHeight } - // enqueue txs (must pass to move on to load test) - require.NoError(t, txm.Enqueue(ctx, "test_success_0", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL), nil)) - require.Error(t, txm.Enqueue(ctx, "test_invalidSigner", createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL), nil)) // cannot sign tx before enqueuing - require.NoError(t, txm.Enqueue(ctx, "test_invalidReceiver", createTx(pubKey, pubKey, solana.PublicKey{}, solana.LAMPORTS_PER_SOL), nil)) + tx, lastValidBlockHeight := createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_success_0", tx, nil, lastValidBlockHeight)) + tx2, lastValidBlockHeight2 := createTx(pubKeyReceiver, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) + require.Error(t, txm.Enqueue(ctx, "test_invalidSigner", tx2, nil, lastValidBlockHeight2)) // cannot sign tx before enqueuing + tx3, lastValidBlockHeight3 := createTx(pubKey, pubKey, solana.PublicKey{}, solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_invalidReceiver", tx3, nil, lastValidBlockHeight3)) time.Sleep(500 * time.Millisecond) // pause 0.5s for new blockhash - require.NoError(t, txm.Enqueue(ctx, "test_success_1", createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL), nil)) - require.NoError(t, txm.Enqueue(ctx, "test_txFail", createTx(pubKey, pubKey, pubKeyReceiver, 1000*solana.LAMPORTS_PER_SOL), nil)) + tx4, lastValidBlockHeight4 := createTx(pubKey, pubKey, pubKeyReceiver, solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_success_1", tx4, nil, lastValidBlockHeight4)) + tx5, lastValidBlockHeight5 := createTx(pubKey, pubKey, pubKeyReceiver, 1000*solana.LAMPORTS_PER_SOL) + require.NoError(t, txm.Enqueue(ctx, "test_txFail", tx5, nil, lastValidBlockHeight5)) // load test: try to overload txs, confirm, or simulation for i := 0; i < 1000; i++ { - assert.NoError(t, txm.Enqueue(ctx, fmt.Sprintf("load_%d", i), createTx(loadTestKey.PublicKey(), loadTestKey.PublicKey(), loadTestKey.PublicKey(), uint64(i)), nil)) - time.Sleep(10 * time.Millisecond) // ~100 txs per second (note: have run 5ms delays for ~200tx/s succesfully) + tx6, lastValidBlockHeight6 := createTx(loadTestKey.PublicKey(), loadTestKey.PublicKey(), loadTestKey.PublicKey(), uint64(i)) + assert.NoError(t, txm.Enqueue(ctx, fmt.Sprintf("load_%d", i), tx6, nil, lastValidBlockHeight6)) + time.Sleep(10 * time.Millisecond) // ~100 txs per second (note: have run 5ms delays for ~200tx/s successfully) } // check to make sure all txs are closed out from inflight list (longest should last MaxConfirmTimeout) diff --git a/pkg/solana/txm/txm_race_test.go b/pkg/solana/txm/txm_race_test.go index 42062718f..33ec0f7bf 100644 --- a/pkg/solana/txm/txm_race_test.go +++ b/pkg/solana/txm/txm_race_test.go @@ -62,7 +62,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { // assemble minimal tx for testing retry msg := NewTestMsg() - testRunner := func(t *testing.T, client solanaClient.ReaderWriter) { // build minimal txm loader := utils.NewLazyLoad(func() (solanaClient.ReaderWriter, error) { @@ -81,10 +80,8 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { lastLog := observer.All()[len(observer.All())-1] assert.Contains(t, lastLog.Message, "stopped tx retry") // assert that all retry goroutines exit successfully } - + client := clientmocks.NewReaderWriter(t) t.Run("delay in rebroadcasting tx", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock txs := map[string]solanaGo.Signature{} var lock sync.RWMutex client.On("SendTx", mock.Anything, mock.Anything).Return( @@ -121,8 +118,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { }) t.Run("delay in broadcasting new tx", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock txs := map[string]solanaGo.Signature{} var lock sync.RWMutex client.On("SendTx", mock.Anything, mock.Anything).Return( @@ -157,8 +152,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { }) t.Run("overlapping bumping tx", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock txs := map[string]solanaGo.Signature{} var lock sync.RWMutex client.On("SendTx", mock.Anything, mock.Anything).Return( @@ -204,8 +197,7 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { }) t.Run("bumping tx errors and ctx cleans up waitgroup blocks", func(t *testing.T) { - client := clientmocks.NewReaderWriter(t) - // client mock - first tx is always successful + // first tx is always successful msg0 := NewTestMsg() require.NoError(t, fees.SetComputeUnitPrice(&msg0.tx, 0)) require.NoError(t, fees.SetComputeUnitLimit(&msg0.tx, 200_000)) @@ -217,7 +209,7 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { require.NoError(t, fees.SetComputeUnitPrice(&msg1.tx, 1)) require.NoError(t, fees.SetComputeUnitLimit(&msg1.tx, 200_000)) msg1.tx.Signatures = make([]solanaGo.Signature, 1) - client.On("SendTx", mock.Anything, &msg1.tx).Return(solanaGo.Signature{}, fmt.Errorf("BUMP FAILED")).Once() + client.On("SendTx", mock.Anything, &msg1.tx).Return(solanaGo.Signature{}, fmt.Errorf("BUMP FAILED")) client.On("SendTx", mock.Anything, &msg1.tx).Return(solanaGo.Signature{2}, nil) // init bump tx success, rebroadcast fails @@ -225,7 +217,7 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { require.NoError(t, fees.SetComputeUnitPrice(&msg2.tx, 2)) require.NoError(t, fees.SetComputeUnitLimit(&msg2.tx, 200_000)) msg2.tx.Signatures = make([]solanaGo.Signature, 1) - client.On("SendTx", mock.Anything, &msg2.tx).Return(solanaGo.Signature{3}, nil).Once() + client.On("SendTx", mock.Anything, &msg2.tx).Return(solanaGo.Signature{3}, nil) client.On("SendTx", mock.Anything, &msg2.tx).Return(solanaGo.Signature{}, fmt.Errorf("REBROADCAST FAILED")) // always successful @@ -234,7 +226,6 @@ func TestTxm_SendWithRetry_Race(t *testing.T) { require.NoError(t, fees.SetComputeUnitLimit(&msg3.tx, 200_000)) msg3.tx.Signatures = make([]solanaGo.Signature, 1) client.On("SendTx", mock.Anything, &msg3.tx).Return(solanaGo.Signature{4}, nil) - testRunner(t, client) }) } From 2433636fe66bad9b01be276eb6ddd9e4f2d658f5 Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 19 Dec 2024 13:22:41 -0500 Subject: [PATCH 2/4] logs are forwarded to a processor in slot and trx order (#953) * logs are forwarded to a processor in slot and trx order * make tests pass again * simplify block and expectation ordering --- pkg/solana/logpoller/job.go | 25 +- pkg/solana/logpoller/loader.go | 214 ++++++++++++- pkg/solana/logpoller/loader_test.go | 406 ++++++++++++++++++++---- pkg/solana/logpoller/log_data_parser.go | 3 +- 4 files changed, 562 insertions(+), 86 deletions(-) diff --git a/pkg/solana/logpoller/job.go b/pkg/solana/logpoller/job.go index 1d827a85b..165c0b5fe 100644 --- a/pkg/solana/logpoller/job.go +++ b/pkg/solana/logpoller/job.go @@ -33,7 +33,8 @@ func (j retryableJob) Run(ctx context.Context) error { } type eventDetail struct { - blockNumber uint64 + slotNumber uint64 + blockHeight uint64 blockHash solana.Hash trxIdx int trxSig solana.Signature @@ -54,12 +55,18 @@ func (j *processEventJob) Run(_ context.Context) error { return j.parser.Process(j.event) } +type wrappedParser interface { + ProgramEventProcessor + ExpectBlock(uint64) + ExpectTxs(uint64, int) +} + // getTransactionsFromBlockJob is a job that fetches transaction signatures from a block and loads // the job queue with getTransactionLogsJobs for each transaction found in the block. type getTransactionsFromBlockJob struct { slotNumber uint64 client RPCClient - parser ProgramEventProcessor + parser wrappedParser chJobs chan Job } @@ -103,17 +110,20 @@ func (j *getTransactionsFromBlockJob) Run(ctx context.Context) error { } detail := eventDetail{ - blockHash: block.Blockhash, + slotNumber: j.slotNumber, + blockHash: block.Blockhash, } if block.BlockHeight != nil { - detail.blockNumber = *block.BlockHeight + detail.blockHeight = *block.BlockHeight } if len(block.Transactions) != len(blockSigsOnly.Signatures) { return fmt.Errorf("block %d has %d transactions but %d signatures", j.slotNumber, len(block.Transactions), len(blockSigsOnly.Signatures)) } + j.parser.ExpectTxs(j.slotNumber, len(block.Transactions)) + for idx, trx := range block.Transactions { detail.trxIdx = idx if len(blockSigsOnly.Signatures)-1 <= idx { @@ -130,14 +140,15 @@ func messagesToEvents(messages []string, parser ProgramEventProcessor, detail ev var logIdx uint for _, outputs := range parseProgramLogs(messages) { for _, event := range outputs.Events { - logIdx++ - - event.BlockNumber = detail.blockNumber + event.SlotNumber = detail.slotNumber + event.BlockHeight = detail.blockHeight event.BlockHash = detail.blockHash event.TransactionHash = detail.trxSig event.TransactionIndex = detail.trxIdx event.TransactionLogIndex = logIdx + logIdx++ + chJobs <- &processEventJob{ parser: parser, event: event, diff --git a/pkg/solana/logpoller/loader.go b/pkg/solana/logpoller/loader.go index 56fcef25c..d714f08ad 100644 --- a/pkg/solana/logpoller/loader.go +++ b/pkg/solana/logpoller/loader.go @@ -1,8 +1,12 @@ package logpoller import ( + "container/list" "context" "errors" + "fmt" + "slices" + "sync" "sync/atomic" "time" @@ -40,7 +44,8 @@ type EncodedLogCollector struct { // dependencies and configuration client RPCClient - parser ProgramEventProcessor + ordered *orderedParser + unordered *unorderedParser lggr logger.Logger rpcTimeLimit time.Duration @@ -62,7 +67,7 @@ func NewEncodedLogCollector( ) *EncodedLogCollector { c := &EncodedLogCollector{ client: client, - parser: parser, + unordered: newUnorderedParser(parser), chSlot: make(chan uint64), chBlock: make(chan uint64, 1), chJobs: make(chan Job, 1), @@ -74,8 +79,9 @@ func NewEncodedLogCollector( Name: "EncodedLogCollector", NewSubServices: func(lggr logger.Logger) []services.Service { c.workers = NewWorkerGroup(DefaultWorkerCount, lggr) + c.ordered = newOrderedParser(parser, lggr) - return []services.Service{c.workers} + return []services.Service{c.workers, c.ordered} }, Start: c.start, Close: c.close, @@ -127,7 +133,7 @@ func (c *EncodedLogCollector) BackfillForAddress(ctx context.Context, address st if err := c.workers.Do(ctx, &getTransactionsFromBlockJob{ slotNumber: sig.Slot, client: c.client, - parser: c.parser, + parser: c.unordered, chJobs: c.chJobs, }); err != nil { return err @@ -138,7 +144,7 @@ func (c *EncodedLogCollector) BackfillForAddress(ctx context.Context, address st return nil } -func (c *EncodedLogCollector) start(ctx context.Context) error { +func (c *EncodedLogCollector) start(_ context.Context) error { c.engine.Go(c.runSlotPolling) c.engine.Go(c.runSlotProcessing) c.engine.Go(c.runBlockProcessing) @@ -201,10 +207,15 @@ func (c *EncodedLogCollector) runSlotProcessing(ctx context.Context) { continue } + from := c.highestSlot.Load() + 1 + if c.highestSlot.Load() == 0 { + from = slot + } + c.highestSlot.Store(slot) // load blocks in slot range - c.loadRange(ctx, c.highestSlotLoaded.Load()+1, slot) + c.loadRange(ctx, from, slot) } } } @@ -214,11 +225,11 @@ func (c *EncodedLogCollector) runBlockProcessing(ctx context.Context) { select { case <-ctx.Done(): return - case block := <-c.chBlock: + case slot := <-c.chBlock: if err := c.workers.Do(ctx, &getTransactionsFromBlockJob{ - slotNumber: block, + slotNumber: slot, client: c.client, - parser: c.parser, + parser: c.ordered, chJobs: c.chJobs, }); err != nil { c.lggr.Errorf("failed to add job to queue: %s", err) @@ -269,7 +280,21 @@ func (c *EncodedLogCollector) loadSlotBlocksRange(ctx context.Context, start, en return err } + // as a safety mechanism, order the blocks ascending (oldest to newest) in the extreme case + // that the RPC changes and results get jumbled. + slices.SortFunc(result, func(a, b uint64) int { + if a < b { + return -1 + } else if a > b { + return 1 + } + + return 0 + }) + for _, block := range result { + c.ordered.ExpectBlock(block) + select { case <-ctx.Done(): return nil @@ -279,3 +304,174 @@ func (c *EncodedLogCollector) loadSlotBlocksRange(ctx context.Context, start, en return nil } + +type unorderedParser struct { + parser ProgramEventProcessor +} + +func newUnorderedParser(parser ProgramEventProcessor) *unorderedParser { + return &unorderedParser{parser: parser} +} + +func (p *unorderedParser) ExpectBlock(_ uint64) {} +func (p *unorderedParser) ExpectTxs(_ uint64, _ int) {} +func (p *unorderedParser) Process(evt ProgramEvent) error { + return p.parser.Process(evt) +} + +type orderedParser struct { + // service state management + services.Service + engine *services.Engine + + // internal state + parser ProgramEventProcessor + mu sync.Mutex + blocks *list.List + expect map[uint64]int + actual map[uint64][]ProgramEvent +} + +func newOrderedParser(parser ProgramEventProcessor, lggr logger.Logger) *orderedParser { + op := &orderedParser{ + parser: parser, + blocks: list.New(), + expect: make(map[uint64]int), + actual: make(map[uint64][]ProgramEvent), + } + + op.Service, op.engine = services.Config{ + Name: "OrderedParser", + Start: op.start, + Close: op.close, + }.NewServiceEngine(lggr) + + return op +} + +// ExpectBlock should be called in block order to preserve block progression. +func (p *orderedParser) ExpectBlock(block uint64) { + p.mu.Lock() + defer p.mu.Unlock() + + p.blocks.PushBack(block) +} + +func (p *orderedParser) ExpectTxs(block uint64, quantity int) { + p.mu.Lock() + defer p.mu.Unlock() + + p.expect[block] = quantity + p.actual[block] = make([]ProgramEvent, 0, quantity) +} + +func (p *orderedParser) Process(event ProgramEvent) error { + p.mu.Lock() + defer p.mu.Unlock() + + if err := p.addToExpectations(event); err != nil { + // TODO: log error because this is an unrecoverable error + return nil + } + + return p.sendReadySlots() +} + +func (p *orderedParser) start(_ context.Context) error { + p.engine.GoTick(services.NewTicker(time.Second), p.run) + + return nil +} + +func (p *orderedParser) close() error { + return nil +} + +func (p *orderedParser) addToExpectations(evt ProgramEvent) error { + _, ok := p.expect[evt.SlotNumber] + if !ok { + return fmt.Errorf("%w: %d", errExpectationsNotSet, evt.SlotNumber) + } + + evts, ok := p.actual[evt.SlotNumber] + if !ok { + return fmt.Errorf("%w: %d", errExpectationsNotSet, evt.SlotNumber) + } + + p.actual[evt.SlotNumber] = append(evts, evt) + + return nil +} + +func (p *orderedParser) expectations(block uint64) (int, bool, error) { + expectations, ok := p.expect[block] + if !ok { + return 0, false, fmt.Errorf("%w: %d", errExpectationsNotSet, block) + } + + evts, ok := p.actual[block] + if !ok { + return 0, false, fmt.Errorf("%w: %d", errExpectationsNotSet, block) + } + + return expectations, expectations == len(evts), nil +} + +func (p *orderedParser) clearExpectations(block uint64) { + delete(p.expect, block) + delete(p.actual, block) +} + +func (p *orderedParser) run(_ context.Context) { + p.mu.Lock() + defer p.mu.Unlock() + + _ = p.sendReadySlots() +} + +func (p *orderedParser) sendReadySlots() error { + // start at the lowest block and find ready blocks + for element := p.blocks.Front(); element != nil; element = p.blocks.Front() { + block := element.Value.(uint64) + // if no expectations are set, we are still waiting on information for the block. + // if expectations set and not met, we are still waiting on information for the block + // no other block data should be sent until this is resolved + exp, met, err := p.expectations(block) + if err != nil || !met { + break + } + + // if expectations are 0 -> remove and continue + if exp == 0 { + p.clearExpectations(block) + p.blocks.Remove(element) + + continue + } + + evts, ok := p.actual[block] + if !ok { + return errInvalidState + } + + var errs error + for _, evt := range evts { + errs = errors.Join(errs, p.parser.Process(evt)) + } + + // need possible retry + if errs != nil { + return errs + } + + p.blocks.Remove(element) + p.clearExpectations(block) + } + + return nil +} + +var ( + errExpectationsNotSet = errors.New("expectations not set") + errInvalidState = errors.New("invalid state") +) diff --git a/pkg/solana/logpoller/loader_test.go b/pkg/solana/logpoller/loader_test.go index 69a37702b..e3cbb7700 100644 --- a/pkg/solana/logpoller/loader_test.go +++ b/pkg/solana/logpoller/loader_test.go @@ -3,6 +3,7 @@ package logpoller_test import ( "context" "crypto/rand" + "reflect" "sync" "sync/atomic" "testing" @@ -32,6 +33,8 @@ var ( ) func TestEncodedLogCollector_StartClose(t *testing.T) { + t.Parallel() + client := new(mocks.RPCClient) ctx := tests.Context(t) @@ -42,6 +45,8 @@ func TestEncodedLogCollector_StartClose(t *testing.T) { } func TestEncodedLogCollector_ParseSingleEvent(t *testing.T) { + t.Parallel() + client := new(mocks.RPCClient) parser := new(testParser) ctx := tests.Context(t) @@ -53,42 +58,221 @@ func TestEncodedLogCollector_ParseSingleEvent(t *testing.T) { require.NoError(t, collector.Close()) }) - slot := uint64(42) - sig := solana.Signature{2, 1, 4, 2} - blockHeight := uint64(21) + var latest atomic.Uint64 - client.EXPECT().GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized).Return(&rpc.GetLatestBlockhashResult{ - RPCContext: rpc.RPCContext{ - Context: rpc.Context{ - Slot: slot, - }, - }, - }, nil) + latest.Store(uint64(40)) - client.EXPECT().GetBlocks(mock.Anything, uint64(1), mock.MatchedBy(func(val *uint64) bool { - return val != nil && *val == slot - }), mock.Anything).Return(rpc.BlocksResult{slot}, nil) + client.EXPECT(). + GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized). + RunAndReturn(latestBlockhashReturnFunc(&latest)) - client.EXPECT().GetBlockWithOpts(mock.Anything, slot, mock.Anything).Return(&rpc.GetBlockResult{ - Transactions: []rpc.TransactionWithMeta{ - { - Meta: &rpc.TransactionMeta{ - LogMessages: messages, - }, - }, - }, - Signatures: []solana.Signature{sig}, - BlockHeight: &blockHeight, - }, nil).Twice() + client.EXPECT(). + GetBlocks( + mock.Anything, + mock.MatchedBy(getBlocksStartValMatcher), + mock.MatchedBy(getBlocksEndValMatcher(&latest)), + rpc.CommitmentFinalized, + ). + RunAndReturn(getBlocksReturnFunc(false)) + + client.EXPECT(). + GetBlockWithOpts(mock.Anything, mock.Anything, mock.Anything). + RunAndReturn(func(_ context.Context, slot uint64, _ *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) { + height := slot - 1 + + result := rpc.GetBlockResult{ + Transactions: []rpc.TransactionWithMeta{}, + Signatures: []solana.Signature{}, + BlockHeight: &height, + } + + _, _ = rand.Read(result.Blockhash[:]) + + if slot == 42 { + var sig solana.Signature + _, _ = rand.Read(sig[:]) + + result.Signatures = []solana.Signature{sig} + result.Transactions = []rpc.TransactionWithMeta{ + { + Meta: &rpc.TransactionMeta{ + LogMessages: messages, + }, + }, + } + } + + return &result, nil + }) tests.AssertEventually(t, func() bool { return parser.Called() }) +} + +func TestEncodedLogCollector_MultipleEventOrdered(t *testing.T) { + t.Parallel() + + client := new(mocks.RPCClient) + parser := new(testParser) + ctx := tests.Context(t) + + collector := logpoller.NewEncodedLogCollector(client, parser, logger.Nop()) + + require.NoError(t, collector.Start(ctx)) + t.Cleanup(func() { + require.NoError(t, collector.Close()) + }) + + var latest atomic.Uint64 + + latest.Store(uint64(40)) + + slots := []uint64{44, 43, 42, 41} + sigs := make([]solana.Signature, len(slots)) + hashes := make([]solana.Hash, len(slots)) + scrambler := &slotUnsync{ch: make(chan struct{})} + + for idx := range len(sigs) { + _, _ = rand.Read(sigs[idx][:]) + _, _ = rand.Read(hashes[idx][:]) + } + + client.EXPECT(). + GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized). + RunAndReturn(latestBlockhashReturnFunc(&latest)) + + client.EXPECT(). + GetBlocks( + mock.Anything, + mock.MatchedBy(getBlocksStartValMatcher), + mock.MatchedBy(getBlocksEndValMatcher(&latest)), + rpc.CommitmentFinalized, + ). + RunAndReturn(getBlocksReturnFunc(false)) + + client.EXPECT(). + GetBlockWithOpts(mock.Anything, mock.Anything, mock.Anything). + RunAndReturn(func(_ context.Context, slot uint64, _ *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) { + slotIdx := -1 + for idx, slt := range slots { + if slt == slot { + slotIdx = idx + + break + } + } + + // imitate loading block data out of order + // every other block must wait for the block previous + scrambler.next() + + height := slot - 1 + + if slotIdx == -1 { + var hash solana.Hash + _, _ = rand.Read(hash[:]) + + return &rpc.GetBlockResult{ + Blockhash: hash, + Transactions: []rpc.TransactionWithMeta{}, + Signatures: []solana.Signature{}, + BlockHeight: &height, + }, nil + } + + return &rpc.GetBlockResult{ + Blockhash: hashes[slotIdx], + Transactions: []rpc.TransactionWithMeta{ + { + Meta: &rpc.TransactionMeta{ + LogMessages: messages, + }, + }, + }, + Signatures: []solana.Signature{sigs[slotIdx]}, + BlockHeight: &height, + }, nil + }) + + tests.AssertEventually(t, func() bool { + return reflect.DeepEqual(parser.Events(), []logpoller.ProgramEvent{ + { + BlockData: logpoller.BlockData{ + SlotNumber: 41, + BlockHeight: 40, + BlockHash: hashes[3], + TransactionHash: sigs[3], + TransactionIndex: 0, + TransactionLogIndex: 0, + }, + Prefix: ">", + Data: "HDQnaQjSWwkNAAAASGVsbG8sIFdvcmxkISoAAAAAAAAA", + }, + { + BlockData: logpoller.BlockData{ + SlotNumber: 42, + BlockHeight: 41, + BlockHash: hashes[2], + TransactionHash: sigs[2], + TransactionIndex: 0, + TransactionLogIndex: 0, + }, + Prefix: ">", + Data: "HDQnaQjSWwkNAAAASGVsbG8sIFdvcmxkISoAAAAAAAAA", + }, + { + BlockData: logpoller.BlockData{ + SlotNumber: 43, + BlockHeight: 42, + BlockHash: hashes[1], + TransactionHash: sigs[1], + TransactionIndex: 0, + TransactionLogIndex: 0, + }, + Prefix: ">", + Data: "HDQnaQjSWwkNAAAASGVsbG8sIFdvcmxkISoAAAAAAAAA", + }, + { + BlockData: logpoller.BlockData{ + SlotNumber: 44, + BlockHeight: 43, + BlockHash: hashes[0], + TransactionHash: sigs[0], + TransactionIndex: 0, + TransactionLogIndex: 0, + }, + Prefix: ">", + Data: "HDQnaQjSWwkNAAAASGVsbG8sIFdvcmxkISoAAAAAAAAA", + }, + }) + }) client.AssertExpectations(t) } +type slotUnsync struct { + ch chan struct{} + waiting atomic.Bool +} + +func (u *slotUnsync) next() { + if u.waiting.Load() { + u.waiting.Store(false) + + <-u.ch + + return + } + + u.waiting.Store(true) + + u.ch <- struct{}{} +} + func TestEncodedLogCollector_BackfillForAddress(t *testing.T) { + t.Parallel() + client := new(mocks.RPCClient) parser := new(testParser) ctx := tests.Context(t) @@ -103,65 +287,91 @@ func TestEncodedLogCollector_BackfillForAddress(t *testing.T) { pubKey := solana.PublicKey{2, 1, 4, 2} slots := []uint64{44, 43, 42} sigs := make([]solana.Signature, len(slots)*2) - blockHeights := []uint64{21, 22, 23, 50} for idx := range len(sigs) { _, _ = rand.Read(sigs[idx][:]) } + var latest atomic.Uint64 + + latest.Store(uint64(40)) + // GetLatestBlockhash might be called at start-up; make it take some time because the result isn't needed for this test - client.EXPECT().GetLatestBlockhash(mock.Anything, mock.Anything).Return(&rpc.GetLatestBlockhashResult{ - RPCContext: rpc.RPCContext{ - Context: rpc.Context{ - Slot: slots[0], - }, - }, - Value: &rpc.LatestBlockhashResult{ - LastValidBlockHeight: 42, - }, - }, nil).After(2 * time.Second).Maybe() + client.EXPECT(). + GetLatestBlockhash(mock.Anything, rpc.CommitmentFinalized). + RunAndReturn(latestBlockhashReturnFunc(&latest)). + After(2 * time.Second). + Maybe() client.EXPECT(). - GetSignaturesForAddressWithOpts(mock.Anything, pubKey, mock.MatchedBy(func(opts *rpc.GetSignaturesForAddressOpts) bool { - return opts != nil && opts.Before.String() == solana.Signature{}.String() - })). - Return([]*rpc.TransactionSignature{ - {Slot: slots[0], Signature: sigs[0]}, - {Slot: slots[0], Signature: sigs[1]}, - {Slot: slots[1], Signature: sigs[2]}, - {Slot: slots[1], Signature: sigs[3]}, - {Slot: slots[2], Signature: sigs[4]}, - {Slot: slots[2], Signature: sigs[5]}, - }, nil) - - client.EXPECT().GetSignaturesForAddressWithOpts(mock.Anything, pubKey, mock.Anything).Return([]*rpc.TransactionSignature{}, nil) - - for idx := range len(slots) { - client.EXPECT().GetBlockWithOpts(mock.Anything, slots[idx], mock.Anything).Return(&rpc.GetBlockResult{ - Transactions: []rpc.TransactionWithMeta{ - { - Meta: &rpc.TransactionMeta{ - LogMessages: messages, + GetBlocks( + mock.Anything, + mock.MatchedBy(getBlocksStartValMatcher), + mock.MatchedBy(getBlocksEndValMatcher(&latest)), + rpc.CommitmentFinalized, + ). + RunAndReturn(getBlocksReturnFunc(true)) + + client.EXPECT(). + GetSignaturesForAddressWithOpts(mock.Anything, pubKey, mock.Anything). + RunAndReturn(func(_ context.Context, pk solana.PublicKey, opts *rpc.GetSignaturesForAddressOpts) ([]*rpc.TransactionSignature, error) { + ret := []*rpc.TransactionSignature{} + + if opts != nil && opts.Before.String() == (solana.Signature{}).String() { + for idx := range slots { + ret = append(ret, &rpc.TransactionSignature{Slot: slots[idx], Signature: sigs[idx*2]}) + ret = append(ret, &rpc.TransactionSignature{Slot: slots[idx], Signature: sigs[(idx*2)+1]}) + } + } + + return ret, nil + }) + + client.EXPECT(). + GetBlockWithOpts(mock.Anything, mock.Anything, mock.Anything). + RunAndReturn(func(_ context.Context, slot uint64, _ *rpc.GetBlockOpts) (*rpc.GetBlockResult, error) { + idx := -1 + for sIdx, slt := range slots { + if slt == slot { + idx = sIdx + + break + } + } + + height := slot - 1 + + if idx == -1 { + return &rpc.GetBlockResult{ + Transactions: []rpc.TransactionWithMeta{}, + Signatures: []solana.Signature{}, + BlockHeight: &height, + }, nil + } + + return &rpc.GetBlockResult{ + Transactions: []rpc.TransactionWithMeta{ + { + Meta: &rpc.TransactionMeta{ + LogMessages: messages, + }, }, - }, - { - Meta: &rpc.TransactionMeta{ - LogMessages: messages, + { + Meta: &rpc.TransactionMeta{ + LogMessages: messages, + }, }, }, - }, - Signatures: []solana.Signature{sigs[idx*2], sigs[(idx*2)+1]}, - BlockHeight: &blockHeights[idx], - }, nil).Twice() - } + Signatures: []solana.Signature{sigs[idx*2], sigs[(idx*2)+1]}, + BlockHeight: &height, + }, nil + }) assert.NoError(t, collector.BackfillForAddress(ctx, pubKey.String(), 42)) tests.AssertEventually(t, func() bool { return parser.Count() == 6 }) - - client.AssertExpectations(t) } func BenchmarkEncodedLogCollector(b *testing.B) { @@ -347,12 +557,16 @@ func (p *testBlockProducer) GetTransaction(_ context.Context, sig solana.Signatu type testParser struct { called atomic.Bool - count atomic.Uint64 + mu sync.Mutex + events []logpoller.ProgramEvent } func (p *testParser) Process(event logpoller.ProgramEvent) error { p.called.Store(true) - p.count.Store(p.count.Load() + 1) + + p.mu.Lock() + p.events = append(p.events, event) + p.mu.Unlock() return nil } @@ -362,5 +576,59 @@ func (p *testParser) Called() bool { } func (p *testParser) Count() uint64 { - return p.count.Load() + p.mu.Lock() + defer p.mu.Unlock() + + return uint64(len(p.events)) +} + +func (p *testParser) Events() []logpoller.ProgramEvent { + p.mu.Lock() + defer p.mu.Unlock() + + return p.events +} + +func latestBlockhashReturnFunc(latest *atomic.Uint64) func(context.Context, rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { + return func(ctx context.Context, ct rpc.CommitmentType) (*rpc.GetLatestBlockhashResult, error) { + defer func() { + latest.Store(latest.Load() + 2) + }() + + return &rpc.GetLatestBlockhashResult{ + RPCContext: rpc.RPCContext{ + Context: rpc.Context{ + Slot: latest.Load(), + }, + }, + Value: &rpc.LatestBlockhashResult{ + LastValidBlockHeight: latest.Load() - 1, + }, + }, nil + } +} + +func getBlocksReturnFunc(empty bool) func(context.Context, uint64, *uint64, rpc.CommitmentType) (rpc.BlocksResult, error) { + return func(_ context.Context, u1 uint64, u2 *uint64, _ rpc.CommitmentType) (rpc.BlocksResult, error) { + blocks := []uint64{} + + if !empty { + blocks = make([]uint64, *u2-u1+1) + for idx := range blocks { + blocks[idx] = u1 + uint64(idx) + } + } + + return rpc.BlocksResult(blocks), nil + } +} + +func getBlocksStartValMatcher(val uint64) bool { + return val > uint64(0) +} + +func getBlocksEndValMatcher(latest *atomic.Uint64) func(*uint64) bool { + return func(val *uint64) bool { + return val != nil && *val <= latest.Load() + } } diff --git a/pkg/solana/logpoller/log_data_parser.go b/pkg/solana/logpoller/log_data_parser.go index 4cfd04470..4080a09e2 100644 --- a/pkg/solana/logpoller/log_data_parser.go +++ b/pkg/solana/logpoller/log_data_parser.go @@ -16,7 +16,8 @@ var ( ) type BlockData struct { - BlockNumber uint64 + SlotNumber uint64 + BlockHeight uint64 BlockHash solana.Hash TransactionHash solana.Signature TransactionIndex int From 7cc92fbde82df100ae43cce1daccfa611986805a Mon Sep 17 00:00:00 2001 From: ilija42 <57732589+ilija42@users.noreply.github.com> Date: Thu, 19 Dec 2024 19:50:51 +0100 Subject: [PATCH 3/4] Codec interface tests (#967) * Connect codec interface tests and refactor codec to interface like EVM one * progress * Fully implement Codec interface tests * Run codec tests in loop * Prettify codec and codec tests * Refactor codec nil encoding handling * Revert accidental changes to testIDL.json * Add sonar exclusion for codec test utils * Add sq exclusion for duplications in testutils, add decoder unit tests * Add encoder unit test * Fix lint and rename codec to solanacodec to avoid types name collision * Solana codec entry improvements * Fix Solana codec field casing * minor err messages improvements * Code improvements * Fix encoder unit tests * Fix sonar exclusions * lint * Reorder methods in Solana codec * Fix CR integration tests config * Revert TestNewIDLCodec_WithModifiers deletion * Add comments for codec entry includeDiscriminator option * Add discriminator value check in codec entry Decode * Reuse utils from interface tests for Solana codec interface tests * Fix comment * Fix comment * [Non-EVM-1062] Solana Codec events support, Hookup Fuzz tests and cleanup Codec init (#987) * Add events IDL parsing to codec * temp * Add a basic codec test for event IDL parsing * Cleanup Solana Codec init * Hookup Codec fuzz tests * delete an unnecessary comment * lint --------- Co-authored-by: Jonghyeon Park --- go.mod | 4 +- go.sum | 8 +- integration-tests/go.mod | 4 +- integration-tests/go.sum | 8 +- .../relayinterface/chain_components_test.go | 6 +- pkg/solana/codec/codec_entry.go | 192 ++++++ pkg/solana/codec/codec_test.go | 158 +++++ pkg/solana/codec/decoder.go | 35 ++ pkg/solana/codec/decoder_test.go | 90 +++ pkg/solana/codec/discriminator.go | 16 +- pkg/solana/codec/encoder.go | 35 ++ pkg/solana/codec/encoder_test.go | 103 ++++ pkg/solana/codec/parsed_types.go | 48 ++ pkg/solana/codec/solana.go | 241 +++++--- .../codec/testutils/eventItemTypeIDL.json | 73 +++ .../codec/testutils/itemArray1TypeIDL.json | 92 +++ .../codec/testutils/itemArray2TypeIDL.json | 92 +++ pkg/solana/codec/testutils/itemIDL.json | 77 +++ .../codec/testutils/itemSliceTypeIDL.json | 89 +++ pkg/solana/codec/testutils/nilTypeIDL.json | 12 + .../codec/testutils/sizeItemTypeIDL.json | 38 ++ pkg/solana/codec/testutils/types.go | 561 +++++++++++++++++- pkg/solana/codec/types.go | 24 + pkg/solana/config/chain_reader.go | 4 +- sonar-project.properties | 5 +- 25 files changed, 1918 insertions(+), 97 deletions(-) create mode 100644 pkg/solana/codec/codec_entry.go create mode 100644 pkg/solana/codec/codec_test.go create mode 100644 pkg/solana/codec/decoder.go create mode 100644 pkg/solana/codec/decoder_test.go create mode 100644 pkg/solana/codec/encoder.go create mode 100644 pkg/solana/codec/encoder_test.go create mode 100644 pkg/solana/codec/parsed_types.go create mode 100644 pkg/solana/codec/testutils/eventItemTypeIDL.json create mode 100644 pkg/solana/codec/testutils/itemArray1TypeIDL.json create mode 100644 pkg/solana/codec/testutils/itemArray2TypeIDL.json create mode 100644 pkg/solana/codec/testutils/itemIDL.json create mode 100644 pkg/solana/codec/testutils/itemSliceTypeIDL.json create mode 100644 pkg/solana/codec/testutils/nilTypeIDL.json create mode 100644 pkg/solana/codec/testutils/sizeItemTypeIDL.json create mode 100644 pkg/solana/codec/types.go diff --git a/go.mod b/go.mod index 22ff1dd69..8e6423815 100644 --- a/go.mod +++ b/go.mod @@ -20,8 +20,8 @@ require ( github.com/stretchr/testify v1.9.0 go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 - golang.org/x/sync v0.8.0 - golang.org/x/text v0.18.0 + golang.org/x/sync v0.10.0 + golang.org/x/text v0.21.0 ) require ( diff --git a/go.sum b/go.sum index a44777df8..a2d9c0b39 100644 --- a/go.sum +++ b/go.sum @@ -682,8 +682,8 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -746,8 +746,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/integration-tests/go.mod b/integration-tests/go.mod index bde3745bc..78f6d6394 100644 --- a/integration-tests/go.mod +++ b/integration-tests/go.mod @@ -25,8 +25,8 @@ require ( github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.34.0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c - golang.org/x/sync v0.8.0 - golang.org/x/text v0.19.0 + golang.org/x/sync v0.10.0 + golang.org/x/text v0.21.0 gopkg.in/guregu/null.v4 v4.0.0 ) diff --git a/integration-tests/go.sum b/integration-tests/go.sum index cc45a70d9..19dd0b8a2 100644 --- a/integration-tests/go.sum +++ b/integration-tests/go.sum @@ -1692,8 +1692,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1804,8 +1804,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/integration-tests/relayinterface/chain_components_test.go b/integration-tests/relayinterface/chain_components_test.go index 6af9c3444..9b7be2f0a 100644 --- a/integration-tests/relayinterface/chain_components_test.go +++ b/integration-tests/relayinterface/chain_components_test.go @@ -132,7 +132,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) { Procedure: config.ChainReaderProcedure{ IDLAccount: "DataAccount", OutputModifications: codec.ModifiersConfig{ - &codec.PropertyExtractorConfig{FieldName: "U64value"}, + &codec.PropertyExtractorConfig{FieldName: "U64Value"}, }, }, }, @@ -142,7 +142,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) { Procedure: config.ChainReaderProcedure{ IDLAccount: "DataAccount", OutputModifications: codec.ModifiersConfig{ - &codec.PropertyExtractorConfig{FieldName: "U64slice"}, + &codec.PropertyExtractorConfig{FieldName: "U64Slice"}, }, }, }, @@ -156,7 +156,7 @@ func (it *SolanaChainComponentsInterfaceTester[T]) Setup(t T) { Procedure: config.ChainReaderProcedure{ IDLAccount: "DataAccount", OutputModifications: codec.ModifiersConfig{ - &codec.PropertyExtractorConfig{FieldName: "U64value"}, + &codec.PropertyExtractorConfig{FieldName: "U64Value"}, }, }, }, diff --git a/pkg/solana/codec/codec_entry.go b/pkg/solana/codec/codec_entry.go new file mode 100644 index 000000000..bc42ae968 --- /dev/null +++ b/pkg/solana/codec/codec_entry.go @@ -0,0 +1,192 @@ +package codec + +import ( + "bytes" + "fmt" + "reflect" + + "github.com/smartcontractkit/chainlink-common/pkg/codec" + commonencodings "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +type Entry interface { + Encode(value any, into []byte) ([]byte, error) + Decode(encoded []byte) (any, []byte, error) + GetCodecType() commonencodings.TypeCodec + GetType() reflect.Type + Modifier() codec.Modifier + Size(numItems int) (int, error) + FixedSize() (int, error) +} + +type entry struct { + // TODO this might not be needed in the end, it was handy to make tests simpler + offchainName string + onchainName string + reflectType reflect.Type + typeCodec commonencodings.TypeCodec + mod codec.Modifier + // includeDiscriminator during Encode adds a discriminator to the encoded bytes under an assumption that the provided value didn't have a discriminator. + // During Decode includeDiscriminator removes discriminator from bytes under an assumption that the provided struct doesn't need a discriminator. + includeDiscriminator bool + discriminator Discriminator +} + +func NewAccountEntry(offchainName string, idlAccount IdlTypeDef, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { + _, accCodec, err := createCodecType(idlAccount, createRefs(idlTypes, builder), false) + if err != nil { + return nil, err + } + + return newEntry( + offchainName, + idlAccount.Name, + accCodec, + includeDiscriminator, + mod, + ), nil +} + +func NewInstructionArgsEntry(offChainName string, instructions IdlInstruction, idlTypes IdlTypeDefSlice, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { + _, instructionCodecArgs, err := asStruct(instructions.Args, createRefs(idlTypes, builder), instructions.Name, false, true) + if err != nil { + return nil, err + } + + return newEntry( + offChainName, + instructions.Name, + instructionCodecArgs, + // Instruction arguments don't need a discriminator by default + false, + mod, + ), nil +} + +func NewEventArgsEntry(offChainName string, event IdlEvent, idlTypes IdlTypeDefSlice, includeDiscriminator bool, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { + _, eventCodec, err := asStruct(eventFieldsToFields(event.Fields), createRefs(idlTypes, builder), event.Name, false, false) + if err != nil { + return nil, err + } + + return newEntry( + offChainName, + event.Name, + eventCodec, + includeDiscriminator, + mod, + ), nil +} + +func newEntry( + offchainName, onchainName string, + typeCodec commonencodings.TypeCodec, + includeDiscriminator bool, + mod codec.Modifier, +) Entry { + return &entry{ + offchainName: offchainName, + onchainName: onchainName, + reflectType: typeCodec.GetType(), + typeCodec: typeCodec, + mod: ensureModifier(mod), + includeDiscriminator: includeDiscriminator, + discriminator: *NewDiscriminator(onchainName), + } +} + +func createRefs(idlTypes IdlTypeDefSlice, builder commonencodings.Builder) *codecRefs { + return &codecRefs{ + builder: builder, + codecs: make(map[string]commonencodings.TypeCodec), + typeDefs: idlTypes, + dependencies: make(map[string][]string), + } +} + +func (e *entry) Encode(value any, into []byte) ([]byte, error) { + // Special handling for encoding a nil pointer to an empty struct. + t := e.reflectType + if value == nil { + if t.Kind() == reflect.Pointer { + elem := t.Elem() + if elem.Kind() == reflect.Struct && elem.NumField() == 0 { + return []byte{}, nil + } + } + return nil, fmt.Errorf("%w: cannot encode nil value for offchainName: %q, onchainName: %q", + commontypes.ErrInvalidType, e.offchainName, e.onchainName) + } + + encodedVal, err := e.typeCodec.Encode(value, into) + if err != nil { + return nil, err + } + + if e.includeDiscriminator { + var byt []byte + encodedDisc, err := e.discriminator.Encode(&e.discriminator.hashPrefix, byt) + if err != nil { + return nil, err + } + return append(encodedDisc, encodedVal...), nil + } + + return encodedVal, nil +} + +func (e *entry) Decode(encoded []byte) (any, []byte, error) { + if e.includeDiscriminator { + if len(encoded) < discriminatorLength { + return nil, nil, fmt.Errorf("%w: encoded data too short to contain discriminator for offchainName: %q, onchainName: %q", + commontypes.ErrInvalidType, e.offchainName, e.onchainName) + } + + if !bytes.Equal(e.discriminator.hashPrefix, encoded[:discriminatorLength]) { + return nil, nil, fmt.Errorf("%w: encoded data has a bad discriminator %v for offchainName: %q, onchainName: %q", + commontypes.ErrInvalidType, encoded[:discriminatorLength], e.offchainName, e.onchainName) + } + + encoded = encoded[discriminatorLength:] + } + return e.typeCodec.Decode(encoded) +} + +func (e *entry) GetCodecType() commonencodings.TypeCodec { + return e.typeCodec +} + +func (e *entry) GetType() reflect.Type { + return e.reflectType +} + +func (e *entry) Modifier() codec.Modifier { + return e.mod +} + +func (e *entry) Size(numItems int) (int, error) { + return e.typeCodec.Size(numItems) +} + +func (e *entry) FixedSize() (int, error) { + return e.typeCodec.FixedSize() +} + +func ensureModifier(mod codec.Modifier) codec.Modifier { + if mod == nil { + return codec.MultiModifier{} + } + return mod +} + +func eventFieldsToFields(evFields []IdlEventField) []IdlField { + var idlFields []IdlField + for _, evField := range evFields { + idlFields = append(idlFields, IdlField{ + Name: evField.Name, + Type: evField.Type, + }) + } + return idlFields +} diff --git a/pkg/solana/codec/codec_test.go b/pkg/solana/codec/codec_test.go new file mode 100644 index 000000000..8a9dedb45 --- /dev/null +++ b/pkg/solana/codec/codec_test.go @@ -0,0 +1,158 @@ +package codec_test + +import ( + "bytes" + _ "embed" + "slices" + "testing" + + bin "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + ocr2types "github.com/smartcontractkit/libocr/offchainreporting2plus/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" + looptestutils "github.com/smartcontractkit/chainlink-common/pkg/loop/testutils" + clcommontypes "github.com/smartcontractkit/chainlink-common/pkg/types" + . "github.com/smartcontractkit/chainlink-common/pkg/types/interfacetests" //nolint common practice to import test mods with . + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec/testutils" +) + +const anyExtraValue = 3 + +func TestCodec(t *testing.T) { + tester := &codecInterfaceTester{} + RunCodecInterfaceTests(t, tester) + RunCodecInterfaceTests(t, looptestutils.WrapCodecTesterForLoop(tester)) + + t.Run("Events are encode-able and decode-able for a single item", func(t *testing.T) { + ctx := tests.Context(t) + item := CreateTestStruct[*testing.T](0, tester) + req := &EncodeRequest{TestStructs: []TestStruct{item}, TestOn: testutils.TestEventItem} + resp := tester.EncodeFields(t, req) + + codec := tester.GetCodec(t) + actualEncoding, err := codec.Encode(ctx, item, testutils.TestEventItem) + require.NoError(t, err) + assert.Equal(t, resp, actualEncoding) + + into := TestStruct{} + require.NoError(t, codec.Decode(ctx, actualEncoding, &into, testutils.TestEventItem)) + assert.Equal(t, item, into) + }) +} + +func FuzzCodec(f *testing.F) { + tester := &codecInterfaceTester{} + RunCodecInterfaceFuzzTests(f, tester) +} + +type codecInterfaceTester struct { + TestSelectionSupport +} + +func (it *codecInterfaceTester) Setup(_ *testing.T) {} + +func (it *codecInterfaceTester) GetAccountBytes(_ int) []byte { + // TODO solana base58 string can be of variable length, this value is always 44, but it should be able to handle any length 32-44 + pk := solana.PublicKeyFromBytes([]byte{220, 108, 195, 188, 166, 6, 163, 39, 197, 131, 44, 38, 154, 177, 232, 80, 141, 50, 7, 65, 28, 65, 182, 165, 57, 5, 176, 68, 46, 181, 58, 245}) + return pk.Bytes() +} + +func (it *codecInterfaceTester) GetAccountString(i int) string { + return solana.PublicKeyFromBytes(it.GetAccountBytes(i)).String() +} + +func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte { + if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem { + return encodeFieldsOnItem(t, request) + } + + return encodeFieldsOnSliceOrArray(t, request) +} + +func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report { + buf := new(bytes.Buffer) + // The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded. + if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil { + require.NoError(t, err) + } + return buf.Bytes() +} + +func encodeFieldsOnSliceOrArray(t *testing.T, request *EncodeRequest) []byte { + var toEncode interface{} + buf := new(bytes.Buffer) + switch request.TestOn { + case TestItemArray1Type: + toEncode = [1]testutils.TestItemAsArgs{testutils.EncodeRequestToTestItemAsArgs(request.TestStructs[0])} + case TestItemArray2Type: + toEncode = [2]testutils.TestItemAsArgs{testutils.EncodeRequestToTestItemAsArgs(request.TestStructs[0]), testutils.EncodeRequestToTestItemAsArgs(request.TestStructs[1])} + default: + // encode TestItemSliceType as instruction args (similar to accounts, but no discriminator) because accounts can't be just a vector + var itemSliceType []testutils.TestItemAsArgs + for _, req := range request.TestStructs { + itemSliceType = append(itemSliceType, testutils.EncodeRequestToTestItemAsArgs(req)) + } + toEncode = itemSliceType + } + + if err := bin.NewBorshEncoder(buf).Encode(toEncode); err != nil { + require.NoError(t, err) + } + return buf.Bytes() +} + +func (it *codecInterfaceTester) GetCodec(t *testing.T) clcommontypes.Codec { + codecConfig := codec.Config{Configs: map[string]codec.ChainConfig{}} + TestItem := CreateTestStruct[*testing.T](0, it) + for offChainName, v := range testutils.CodecDefs { + codecEntryCfg := codecConfig.Configs[offChainName] + codecEntryCfg.IDL = v.IDL + codecEntryCfg.Type = v.ItemType + codecEntryCfg.OnChainName = v.IDLTypeName + + if offChainName != NilType { + codecEntryCfg.ModifierConfigs = commoncodec.ModifiersConfig{ + &commoncodec.RenameModifierConfig{Fields: map[string]string{"NestedDynamicStruct.Inner.IntVal": "I"}}, + &commoncodec.RenameModifierConfig{Fields: map[string]string{"NestedStaticStruct.Inner.IntVal": "I"}}, + } + } + + if slices.Contains([]string{TestItemType, TestItemSliceType, TestItemArray1Type, TestItemArray2Type, testutils.TestItemWithConfigExtraType, testutils.TestEventItem}, offChainName) { + addressByteModifier := &commoncodec.AddressBytesToStringModifierConfig{ + Fields: []string{"AccountStruct.AccountStr"}, + Modifier: codec.SolanaAddressModifier{}, + } + codecEntryCfg.ModifierConfigs = append(codecEntryCfg.ModifierConfigs, addressByteModifier) + } + + if offChainName == testutils.TestItemWithConfigExtraType { + hardCode := &commoncodec.HardCodeModifierConfig{ + OnChainValues: map[string]any{ + "BigField": TestItem.BigField.String(), + "AccountStruct.Account": solana.PublicKeyFromBytes(TestItem.AccountStruct.Account), + }, + OffChainValues: map[string]any{"ExtraField": anyExtraValue}, + } + codecEntryCfg.ModifierConfigs = append(codecEntryCfg.ModifierConfigs, hardCode) + } + codecConfig.Configs[offChainName] = codecEntryCfg + } + + c, err := codec.NewCodec(codecConfig) + require.NoError(t, err) + + return c +} + +func (it *codecInterfaceTester) IncludeArrayEncodingSizeEnforcement() bool { + return true +} +func (it *codecInterfaceTester) Name() string { + return "Solana" +} diff --git a/pkg/solana/codec/decoder.go b/pkg/solana/codec/decoder.go new file mode 100644 index 000000000..242dbc44f --- /dev/null +++ b/pkg/solana/codec/decoder.go @@ -0,0 +1,35 @@ +package codec + +import ( + "context" + "fmt" + + "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +type Decoder struct { + definitions map[string]Entry + codecFromTypeCodec encodings.CodecFromTypeCodec +} + +var _ commontypes.Decoder = &Decoder{} + +func (d *Decoder) Decode(ctx context.Context, raw []byte, into any, itemType string) (err error) { + if d.codecFromTypeCodec == nil { + d.codecFromTypeCodec = make(encodings.CodecFromTypeCodec) + for k, v := range d.definitions { + d.codecFromTypeCodec[k] = v + } + } + + return d.codecFromTypeCodec.Decode(ctx, raw, into, itemType) +} + +func (d *Decoder) GetMaxDecodingSize(_ context.Context, n int, itemType string) (int, error) { + codecEntry, ok := d.definitions[itemType] + if !ok { + return 0, fmt.Errorf("%w: nil entry", commontypes.ErrInvalidType) + } + return codecEntry.GetCodecType().Size(n) +} diff --git a/pkg/solana/codec/decoder_test.go b/pkg/solana/codec/decoder_test.go new file mode 100644 index 000000000..ceea9644f --- /dev/null +++ b/pkg/solana/codec/decoder_test.go @@ -0,0 +1,90 @@ +package codec + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + commonencodings "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" +) + +type testErrDecodeEntry struct { + entry +} + +func (t *testErrDecodeEntry) Decode(_ []byte) (interface{}, []byte, error) { + return nil, nil, fmt.Errorf("decode error") +} + +type testErrDecodeRemainingBytes struct { + entry +} + +func (t *testErrDecodeRemainingBytes) Decode(_ []byte) (interface{}, []byte, error) { + return nil, []byte{1}, nil +} + +func TestDecoder_Decode_Errors(t *testing.T) { + var into interface{} + someType := "some-type" + t.Run("error when item type not found", func(t *testing.T) { + d := &Decoder{definitions: map[string]Entry{}} + d.definitions[someType] = &entry{} + + nonExistentType := "non-existent" + err := d.Decode(tests.Context(t), []byte{}, &into, nonExistentType) + require.ErrorIs(t, err, fmt.Errorf("%w: cannot find type %s", commontypes.ErrInvalidType, nonExistentType)) + }) + + t.Run("error when underlying entry decode fails", func(t *testing.T) { + d := &Decoder{definitions: map[string]Entry{}} + d.definitions[someType] = &testErrDecodeEntry{} + require.Error(t, d.Decode(tests.Context(t), []byte{}, &into, someType)) + }) + + t.Run("error when remaining bytes exist after decode", func(t *testing.T) { + d := &Decoder{definitions: map[string]Entry{}} + d.definitions[someType] = &testErrDecodeRemainingBytes{} + require.Error(t, d.Decode(tests.Context(t), []byte{}, &into, someType)) + }) +} + +type testErrGetMaxDecodingSize struct { + entry +} + +type testErrGetMaxDecodingSizeCodecType struct { + commonencodings.Empty +} + +func (t testErrGetMaxDecodingSizeCodecType) Size(_ int) (int, error) { + return 0, fmt.Errorf("error") +} + +func (t *testErrGetMaxDecodingSize) GetCodecType() commonencodings.TypeCodec { + return testErrGetMaxDecodingSizeCodecType{} +} + +func TestDecoder_GetMaxDecodingSize_Errors(t *testing.T) { + someType := "some-type" + + t.Run("error when entry for item type is missing", func(t *testing.T) { + d := &Decoder{definitions: map[string]Entry{}} + d.definitions[someType] = &entry{} + + nonExistentType := "non-existent" + _, err := d.GetMaxDecodingSize(tests.Context(t), 0, nonExistentType) + require.ErrorIs(t, err, fmt.Errorf("%w: cannot find type %s", commontypes.ErrInvalidType, nonExistentType)) + }) + + t.Run("error when underlying entry decode fails", func(t *testing.T) { + d := &Decoder{definitions: map[string]Entry{}} + d.definitions[someType] = &testErrGetMaxDecodingSize{} + + _, err := d.GetMaxDecodingSize(tests.Context(t), 0, someType) + require.Error(t, err) + }) +} diff --git a/pkg/solana/codec/discriminator.go b/pkg/solana/codec/discriminator.go index f712a3f68..9bc363ae7 100644 --- a/pkg/solana/codec/discriminator.go +++ b/pkg/solana/codec/discriminator.go @@ -12,16 +12,16 @@ import ( const discriminatorLength = 8 -func NewDiscriminator(name string) encodings.TypeCodec { +func NewDiscriminator(name string) *Discriminator { sum := sha256.Sum256([]byte("account:" + name)) - return &discriminator{hashPrefix: sum[:discriminatorLength]} + return &Discriminator{hashPrefix: sum[:discriminatorLength]} } -type discriminator struct { +type Discriminator struct { hashPrefix []byte } -func (d discriminator) Encode(value any, into []byte) ([]byte, error) { +func (d Discriminator) Encode(value any, into []byte) ([]byte, error) { if value == nil { return append(into, d.hashPrefix...), nil } @@ -44,7 +44,7 @@ func (d discriminator) Encode(value any, into []byte) ([]byte, error) { return append(into, *raw...), nil } -func (d discriminator) Decode(encoded []byte) (any, []byte, error) { +func (d Discriminator) Decode(encoded []byte) (any, []byte, error) { raw, remaining, err := encodings.SafeDecode(encoded, discriminatorLength, func(raw []byte) []byte { return raw }) if err != nil { return nil, nil, err @@ -57,15 +57,15 @@ func (d discriminator) Decode(encoded []byte) (any, []byte, error) { return &raw, remaining, nil } -func (d discriminator) GetType() reflect.Type { +func (d Discriminator) GetType() reflect.Type { // Pointer type so that nil can inject values and so that the NamedCodec won't wrap with no-nil pointer. return reflect.TypeOf(&[]byte{}) } -func (d discriminator) Size(_ int) (int, error) { +func (d Discriminator) Size(_ int) (int, error) { return discriminatorLength, nil } -func (d discriminator) FixedSize() (int, error) { +func (d Discriminator) FixedSize() (int, error) { return discriminatorLength, nil } diff --git a/pkg/solana/codec/encoder.go b/pkg/solana/codec/encoder.go new file mode 100644 index 000000000..409fb0013 --- /dev/null +++ b/pkg/solana/codec/encoder.go @@ -0,0 +1,35 @@ +package codec + +import ( + "context" + "fmt" + + "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +type Encoder struct { + definitions map[string]Entry + codecFromTypeCodec encodings.CodecFromTypeCodec +} + +var _ commontypes.Encoder = &Encoder{} + +func (e *Encoder) Encode(ctx context.Context, item any, itemType string) (res []byte, err error) { + if e.codecFromTypeCodec == nil { + e.codecFromTypeCodec = make(encodings.CodecFromTypeCodec) + for k, v := range e.definitions { + e.codecFromTypeCodec[k] = v + } + } + + return e.codecFromTypeCodec.Encode(ctx, item, itemType) +} + +func (e *Encoder) GetMaxEncodingSize(_ context.Context, n int, itemType string) (int, error) { + entry, ok := e.definitions[itemType] + if !ok { + return 0, fmt.Errorf("%w: nil entry", commontypes.ErrInvalidType) + } + return entry.GetCodecType().Size(n) +} diff --git a/pkg/solana/codec/encoder_test.go b/pkg/solana/codec/encoder_test.go new file mode 100644 index 000000000..fb098d884 --- /dev/null +++ b/pkg/solana/codec/encoder_test.go @@ -0,0 +1,103 @@ +package codec + +import ( + "fmt" + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + commonencodings "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/utils/tests" +) + +type testErrEncodeEntry struct { + entry + codecType commonencodings.TypeCodec +} + +func (t *testErrEncodeEntry) Encode(_ interface{}, _ []byte) ([]byte, error) { + return nil, fmt.Errorf("encode error") +} + +func (t *testErrEncodeEntry) GetType() reflect.Type { + return commonencodings.Empty{}.GetType() +} + +type testErrEncodeTypeEntry struct { + entry + tCodec commonencodings.TypeCodec +} + +func (e *testErrEncodeTypeEntry) GetCodecType() commonencodings.TypeCodec { + return e.tCodec +} + +func TestEncoder_Encode_Errors(t *testing.T) { + someType := "some-type" + + t.Run("error when item type not found", func(t *testing.T) { + e := &Encoder{definitions: map[string]Entry{}} + _, err := e.Encode(tests.Context(t), nil, "non-existent-type") + require.Error(t, err) + require.ErrorIs(t, err, commontypes.ErrInvalidType) + require.Contains(t, err.Error(), "cannot find type non-existent-type") + }) + + t.Run("error when convert fails because of unexpected type", func(t *testing.T) { + e := &Encoder{ + definitions: map[string]Entry{ + someType: &testErrEncodeEntry{}, + }, + } + _, err := e.Encode(tests.Context(t), nil, someType) + require.Error(t, err) + }) + + t.Run("error when entry encode fails", func(t *testing.T) { + e := &Encoder{ + definitions: map[string]Entry{ + someType: &testErrEncodeEntry{codecType: commonencodings.Empty{}}, + }, + } + _, err := e.Encode(tests.Context(t), make(map[string]interface{}), someType) + require.ErrorContains(t, err, "encode error") + }) +} + +type testErrGetSize struct { + commonencodings.Empty + retType reflect.Type +} + +func (t testErrGetSize) GetType() reflect.Type { + return t.retType +} + +func (t testErrGetSize) Size(_ int) (int, error) { + return 0, fmt.Errorf("size error") +} + +func TestEncoder_GetMaxEncodingSize_Errors(t *testing.T) { + t.Run("error when entry for item type is missing", func(t *testing.T) { + e := &Encoder{definitions: map[string]Entry{}} + _, err := e.GetMaxEncodingSize(tests.Context(t), 10, "no-entry-type") + require.Error(t, err) + require.ErrorIs(t, err, commontypes.ErrInvalidType) + require.Contains(t, err.Error(), "nil entry") + }) + + t.Run("error when size calculation fails", func(t *testing.T) { + someType := "some-type" + e := &Encoder{ + definitions: map[string]Entry{ + someType: &testErrEncodeTypeEntry{tCodec: testErrGetSize{}}, + }, + } + + _, err := e.GetMaxEncodingSize(tests.Context(t), 0, someType) + require.Error(t, err) + require.Contains(t, err.Error(), "size error") + }) +} diff --git a/pkg/solana/codec/parsed_types.go b/pkg/solana/codec/parsed_types.go new file mode 100644 index 000000000..3144f6dd0 --- /dev/null +++ b/pkg/solana/codec/parsed_types.go @@ -0,0 +1,48 @@ +package codec + +import ( + "fmt" + "reflect" + + commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +type ParsedTypes struct { + EncoderDefs map[string]Entry + DecoderDefs map[string]Entry +} + +func (parsed *ParsedTypes) ToCodec() (commontypes.RemoteCodec, error) { + modByTypeName := map[string]commoncodec.Modifier{} + if err := AddEntries(parsed.EncoderDefs, modByTypeName); err != nil { + return nil, err + } + if err := AddEntries(parsed.DecoderDefs, modByTypeName); err != nil { + return nil, err + } + + mod, err := commoncodec.NewByItemTypeModifier(modByTypeName) + if err != nil { + return nil, err + } + underlying := &solanaCodec{ + Encoder: &Encoder{definitions: parsed.EncoderDefs}, + Decoder: &Decoder{definitions: parsed.DecoderDefs}, + ParsedTypes: parsed, + } + return commoncodec.NewModifierCodec(underlying, mod, DecoderHooks...) +} + +// AddEntries extracts the mods from entry and adds them to modByTypeName use with codec.NewByItemTypeModifier +// Since each input/output can have its own modifications, we need to keep track of them by type name +func AddEntries(defs map[string]Entry, modByTypeName map[string]commoncodec.Modifier) error { + for k, def := range defs { + modByTypeName[k] = def.Modifier() + _, err := def.Modifier().RetypeToOffChain(reflect.PointerTo(def.GetType()), k) + if err != nil { + return fmt.Errorf("%w: cannot retype %v: %w", commontypes.ErrInvalidConfig, k, err) + } + } + return nil +} diff --git a/pkg/solana/codec/solana.go b/pkg/solana/codec/solana.go index 71e2f7f06..19fe40d3e 100644 --- a/pkg/solana/codec/solana.go +++ b/pkg/solana/codec/solana.go @@ -20,58 +20,130 @@ Modifiers can be provided to assist in modifying property names, adding properti package codec import ( + "encoding/json" "fmt" "math" + "reflect" "github.com/go-viper/mapstructure/v2" "golang.org/x/text/cases" "golang.org/x/text/language" - "github.com/smartcontractkit/chainlink-common/pkg/codec" - "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" - "github.com/smartcontractkit/chainlink-common/pkg/types" + commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" + commonencodings "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" + "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary" + commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" ) const ( DefaultHashBitLength = 32 - unknownIDLFormat = "%w: unknown IDL type def %s" + unknownIDLFormat = "%w: unknown IDL type def %q" ) +// DecoderHooks +// // BigIntHook allows *big.Int to be represented as any integer type or a string and to go back to them. // Useful for config, or if when a model may use a go type that isn't a *big.Int when Pack expects one. // Eg: int32 in a go struct from a plugin could require a *big.Int in Pack for int24, if it fits, we shouldn't care. // SliceToArrayVerifySizeHook verifies that slices have the correct size when converting to an array // EpochToTimeHook allows multiple conversions: time.Time -> int64; int64 -> time.Time; *big.Int -> time.Time; and more -var DecoderHooks = []mapstructure.DecodeHookFunc{codec.EpochToTimeHook, codec.BigIntHook, codec.SliceToArrayVerifySizeHook} +var DecoderHooks = []mapstructure.DecodeHookFunc{commoncodec.EpochToTimeHook, commoncodec.BigIntHook, commoncodec.SliceToArrayVerifySizeHook} -func NewNamedModifierCodec(original types.RemoteCodec, itemType string, modifier codec.Modifier) (types.RemoteCodec, error) { - mod, err := codec.NewByItemTypeModifier(map[string]codec.Modifier{itemType: modifier}) - if err != nil { - return nil, err +type solanaCodec struct { + *Encoder + *Decoder + *ParsedTypes +} + +// NewCodec creates a new [commontypes.RemoteCodec] for Solana. +func NewCodec(conf Config) (commontypes.RemoteCodec, error) { + parsed := &ParsedTypes{ + EncoderDefs: map[string]Entry{}, + DecoderDefs: map[string]Entry{}, } - modCodec, err := codec.NewModifierCodec(original, mod, DecoderHooks...) - if err != nil { - return nil, err + for offChainName, cfg := range conf.Configs { + var idl IDL + if err := json.Unmarshal([]byte(cfg.IDL), &idl); err != nil { + return nil, err + } + + mod, err := cfg.ModifierConfigs.ToModifier(DecoderHooks...) + if err != nil { + return nil, err + } + + definition, err := findDefinitionFromIDL(cfg.Type, cfg.OnChainName, idl) + if err != nil { + return nil, err + } + + var cEntry Entry + switch v := definition.(type) { + case IdlTypeDef: + cEntry, err = NewAccountEntry(offChainName, v, idl.Types, true, mod, binary.LittleEndian()) + case IdlInstruction: + cEntry, err = NewInstructionArgsEntry(offChainName, v, idl.Types, mod, binary.LittleEndian()) + case IdlEvent: + cEntry, err = NewEventArgsEntry(offChainName, v, idl.Types, true, mod, binary.LittleEndian()) + } + if err != nil { + return nil, fmt.Errorf("failed to create %q codec entry: %w", offChainName, err) + } + + parsed.EncoderDefs[offChainName] = cEntry + parsed.DecoderDefs[offChainName] = cEntry } - _, err = modCodec.CreateType(itemType, true) + return parsed.ToCodec() +} - return modCodec, err +func findDefinitionFromIDL(cfgType ChainConfigType, onChainName string, idl IDL) (interface{}, error) { + // not the most efficient way to do this, but these slices should always be very, very small + switch cfgType { + case ChainConfigTypeAccountDef: + for i := range idl.Accounts { + if idl.Accounts[i].Name == onChainName { + return idl.Accounts[i], nil + } + } + return nil, fmt.Errorf("failed to find account %q in IDL", onChainName) + + case ChainConfigTypeInstructionDef: + for i := range idl.Instructions { + if idl.Instructions[i].Name == onChainName { + return idl.Instructions[i], nil + } + } + return nil, fmt.Errorf("failed to find instruction %q in IDL", onChainName) + + case ChainConfigTypeEventDef: + for i := range idl.Events { + if idl.Events[i].Name == onChainName { + return idl.Events[i], nil + } + } + return nil, fmt.Errorf("failed to find event %q in IDL", onChainName) + } + return nil, fmt.Errorf("unknown type: %q", cfgType) +} + +// NewIDLAccountCodec is for Anchor custom types +func NewIDLAccountCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { + return newIDLCoded(idl, builder, idl.Accounts, true) } -func NewIDLInstructionsCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { - typeCodecs := make(encodings.LenientCodecFromTypeCodec) - caser := cases.Title(language.English) +func NewIDLInstructionsCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { + typeCodecs := make(commonencodings.LenientCodecFromTypeCodec) refs := &codecRefs{ builder: builder, - codecs: make(map[string]encodings.TypeCodec), + codecs: make(map[string]commonencodings.TypeCodec), typeDefs: idl.Types, dependencies: make(map[string][]string), } for _, instruction := range idl.Instructions { - name, instCodec, err := asStruct(instruction.Args, refs, instruction.Name, caser, false) + name, instCodec, err := asStruct(instruction.Args, refs, instruction.Name, false, false) if err != nil { return nil, err } @@ -82,22 +154,54 @@ func NewIDLInstructionsCodec(idl IDL, builder encodings.Builder) (types.RemoteCo return typeCodecs, nil } -// NewIDLAccountCodec is for Anchor custom types -func NewIDLAccountCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { - return newIDLCoded(idl, builder, idl.Accounts, true) +func NewNamedModifierCodec(original commontypes.RemoteCodec, itemType string, modifier commoncodec.Modifier) (commontypes.RemoteCodec, error) { + mod, err := commoncodec.NewByItemTypeModifier(map[string]commoncodec.Modifier{itemType: modifier}) + if err != nil { + return nil, err + } + + modCodec, err := commoncodec.NewModifierCodec(original, mod, DecoderHooks...) + if err != nil { + return nil, err + } + + _, err = modCodec.CreateType(itemType, true) + + return modCodec, err } -func NewIDLDefinedTypesCodec(idl IDL, builder encodings.Builder) (types.RemoteCodec, error) { +func NewIDLDefinedTypesCodec(idl IDL, builder commonencodings.Builder) (commontypes.RemoteCodec, error) { return newIDLCoded(idl, builder, idl.Types, false) } +func (s solanaCodec) CreateType(itemType string, forEncoding bool) (any, error) { + var itemTypes map[string]Entry + if forEncoding { + itemTypes = s.EncoderDefs + } else { + itemTypes = s.DecoderDefs + } + + def, ok := itemTypes[itemType] + if !ok { + return nil, fmt.Errorf("%w: cannot find type name %q", commontypes.ErrInvalidType, itemType) + } + + // we don't need double pointers, and they can also mess up reflection variable creation and mapstruct decode + if def.GetType().Kind() == reflect.Pointer { + return reflect.New(def.GetCodecType().GetType().Elem()).Interface(), nil + } + + return reflect.New(def.GetType()).Interface(), nil +} + func newIDLCoded( - idl IDL, builder encodings.Builder, from IdlTypeDefSlice, includeDiscriminator bool) (types.RemoteCodec, error) { - typeCodecs := make(encodings.LenientCodecFromTypeCodec) + idl IDL, builder commonencodings.Builder, from IdlTypeDefSlice, includeDiscriminator bool) (commontypes.RemoteCodec, error) { + typeCodecs := make(commonencodings.LenientCodecFromTypeCodec) refs := &codecRefs{ builder: builder, - codecs: make(map[string]encodings.TypeCodec), + codecs: make(map[string]commonencodings.TypeCodec), typeDefs: idl.Types, dependencies: make(map[string][]string), } @@ -105,11 +209,11 @@ func newIDLCoded( for _, def := range from { var ( name string - accCodec encodings.TypeCodec + accCodec commonencodings.TypeCodec err error ) - name, accCodec, err = createNamedCodec(def, refs, includeDiscriminator) + name, accCodec, err = createCodecType(def, refs, includeDiscriminator) if err != nil { return nil, err } @@ -121,32 +225,29 @@ func newIDLCoded( } type codecRefs struct { - builder encodings.Builder - codecs map[string]encodings.TypeCodec + builder commonencodings.Builder + codecs map[string]commonencodings.TypeCodec typeDefs IdlTypeDefSlice dependencies map[string][]string } -func createNamedCodec( +func createCodecType( def IdlTypeDef, refs *codecRefs, includeDiscriminator bool, -) (string, encodings.TypeCodec, error) { - caser := cases.Title(language.English) +) (string, commonencodings.TypeCodec, error) { name := def.Name - switch def.Type.Kind { case IdlTypeDefTyKindStruct: - return asStruct(*def.Type.Fields, refs, name, caser, includeDiscriminator) + return asStruct(*def.Type.Fields, refs, name, includeDiscriminator, false) case IdlTypeDefTyKindEnum: variants := def.Type.Variants if !variants.IsAllUint8() { - return name, nil, fmt.Errorf("%w: variants are not supported", types.ErrInvalidConfig) + return name, nil, fmt.Errorf("%w: variants are not supported", commontypes.ErrInvalidConfig) } - return name, refs.builder.Uint8(), nil default: - return name, nil, fmt.Errorf(unknownIDLFormat, types.ErrInvalidConfig, def.Type.Kind) + return name, nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, def.Type.Kind) } } @@ -154,17 +255,18 @@ func asStruct( fields []IdlField, refs *codecRefs, name string, // name is the struct name and can be used in dependency checks - caser cases.Caser, includeDiscriminator bool, -) (string, encodings.TypeCodec, error) { + isInstructionArgs bool, +) (string, commonencodings.TypeCodec, error) { desLen := 0 if includeDiscriminator { desLen = 1 } - named := make([]encodings.NamedTypeCodec, len(fields)+desLen) + + named := make([]commonencodings.NamedTypeCodec, len(fields)+desLen) if includeDiscriminator { - named[0] = encodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)} + named[0] = commonencodings.NamedTypeCodec{Name: "Discriminator" + name, Codec: NewDiscriminator(name)} } for idx, field := range fields { @@ -175,10 +277,15 @@ func asStruct( return name, nil, err } - named[idx+desLen] = encodings.NamedTypeCodec{Name: caser.String(fieldName), Codec: typedCodec} + named[idx+desLen] = commonencodings.NamedTypeCodec{Name: cases.Title(language.English, cases.NoLower).String(fieldName), Codec: typedCodec} + } + + // accounts have to be in a struct, instruction args don't + if len(named) == 1 && isInstructionArgs { + return name, named[0].Codec, nil } - structCodec, err := encodings.NewStructCodec(named) + structCodec, err := commonencodings.NewStructCodec(named) if err != nil { return name, nil, err } @@ -186,7 +293,7 @@ func asStruct( return name, structCodec, nil } -func processFieldType(parentTypeName string, idlType IdlType, refs *codecRefs) (encodings.TypeCodec, error) { +func processFieldType(parentTypeName string, idlType IdlType, refs *codecRefs) (commonencodings.TypeCodec, error) { switch true { case idlType.IsString(): return getCodecByStringType(idlType.GetString(), refs.builder) @@ -201,13 +308,13 @@ func processFieldType(parentTypeName string, idlType IdlType, refs *codecRefs) ( case idlType.IsIdlTypeVec(): return asVec(parentTypeName, idlType.GetIdlTypeVec(), refs) default: - return nil, fmt.Errorf("%w: unknown IDL type def", types.ErrInvalidConfig) + return nil, fmt.Errorf("%w: unknown IDL type def", commontypes.ErrInvalidConfig) } } -func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRefs) (encodings.TypeCodec, error) { +func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRefs) (commonencodings.TypeCodec, error) { if definedName == nil { - return nil, fmt.Errorf("%w: defined type name should not be nil", types.ErrInvalidConfig) + return nil, fmt.Errorf("%w: defined type name should not be nil", commontypes.ErrInvalidConfig) } // already exists as a type in the typed codecs @@ -217,19 +324,19 @@ func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRe // nextDef should not have a dependency on definedName if !validDependency(refs, parentTypeName, definedName.Defined) { - return nil, fmt.Errorf("%w: circular dependency detected on %s -> %s relation", types.ErrInvalidConfig, parentTypeName, definedName.Defined) + return nil, fmt.Errorf("%w: circular dependency detected on %q -> %q relation", commontypes.ErrInvalidConfig, parentTypeName, definedName.Defined) } // codec by defined type doesn't exist // process it using the provided typeDefs nextDef := refs.typeDefs.GetByName(definedName.Defined) if nextDef == nil { - return nil, fmt.Errorf("%w: IDL type does not exist for name %s", types.ErrInvalidConfig, definedName.Defined) + return nil, fmt.Errorf("%w: IDL type does not exist for name %q", commontypes.ErrInvalidConfig, definedName.Defined) } saveDependency(refs, parentTypeName, definedName.Defined) - newTypeName, newTypeCodec, err := createNamedCodec(*nextDef, refs, false) + newTypeName, newTypeCodec, err := createCodecType(*nextDef, refs, false) if err != nil { return nil, err } @@ -240,16 +347,16 @@ func asDefined(parentTypeName string, definedName *IdlTypeDefined, refs *codecRe return newTypeCodec, nil } -func asArray(parentTypeName string, idlArray *IdlTypeArray, refs *codecRefs) (encodings.TypeCodec, error) { +func asArray(parentTypeName string, idlArray *IdlTypeArray, refs *codecRefs) (commonencodings.TypeCodec, error) { codec, err := processFieldType(parentTypeName, idlArray.Thing, refs) if err != nil { return nil, err } - return encodings.NewArray(idlArray.Num, codec) + return commonencodings.NewArray(idlArray.Num, codec) } -func asVec(parentTypeName string, idlVec *IdlTypeVec, refs *codecRefs) (encodings.TypeCodec, error) { +func asVec(parentTypeName string, idlVec *IdlTypeVec, refs *codecRefs) (commonencodings.TypeCodec, error) { codec, err := processFieldType(parentTypeName, idlVec.Vec, refs) if err != nil { return nil, err @@ -260,10 +367,10 @@ func asVec(parentTypeName string, idlVec *IdlTypeVec, refs *codecRefs) (encoding return nil, err } - return encodings.NewSlice(codec, b) + return commonencodings.NewSlice(codec, b) } -func getCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) (encodings.TypeCodec, error) { +func getCodecByStringType(curType IdlTypeAsString, builder commonencodings.Builder) (commonencodings.TypeCodec, error) { switch curType { case IdlTypeBool: return builder.Bool(), nil @@ -278,11 +385,11 @@ func getCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) (e case IdlTypeBytes, IdlTypePublicKey, IdlTypeHash: return getByteCodecByStringType(curType, builder) default: - return nil, fmt.Errorf(unknownIDLFormat, types.ErrInvalidConfig, curType) + return nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, curType) } } -func getIntCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) (encodings.TypeCodec, error) { +func getIntCodecByStringType(curType IdlTypeAsString, builder commonencodings.Builder) (commonencodings.TypeCodec, error) { switch curType { case IdlTypeI8: return builder.Int8(), nil @@ -295,11 +402,11 @@ func getIntCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) case IdlTypeI128: return builder.BigInt(16, true) default: - return nil, fmt.Errorf(unknownIDLFormat, types.ErrInvalidConfig, curType) + return nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, curType) } } -func getUIntCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) (encodings.TypeCodec, error) { +func getUIntCodecByStringType(curType IdlTypeAsString, builder commonencodings.Builder) (commonencodings.TypeCodec, error) { switch curType { case IdlTypeU8: return builder.Uint8(), nil @@ -312,22 +419,22 @@ func getUIntCodecByStringType(curType IdlTypeAsString, builder encodings.Builder case IdlTypeU128: return builder.BigInt(16, true) default: - return nil, fmt.Errorf(unknownIDLFormat, types.ErrInvalidConfig, curType) + return nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, curType) } } -func getTimeCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) (encodings.TypeCodec, error) { +func getTimeCodecByStringType(curType IdlTypeAsString, builder commonencodings.Builder) (commonencodings.TypeCodec, error) { switch curType { case IdlTypeUnixTimestamp: return builder.Int64(), nil case IdlTypeDuration: return NewDuration(builder), nil default: - return nil, fmt.Errorf(unknownIDLFormat, types.ErrInvalidConfig, curType) + return nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, curType) } } -func getByteCodecByStringType(curType IdlTypeAsString, builder encodings.Builder) (encodings.TypeCodec, error) { +func getByteCodecByStringType(curType IdlTypeAsString, builder commonencodings.Builder) (commonencodings.TypeCodec, error) { switch curType { case IdlTypeBytes: b, err := builder.Int(4) @@ -335,11 +442,11 @@ func getByteCodecByStringType(curType IdlTypeAsString, builder encodings.Builder return nil, err } - return encodings.NewSlice(builder.Uint8(), b) + return commonencodings.NewSlice(builder.Uint8(), b) case IdlTypePublicKey, IdlTypeHash: - return encodings.NewArray(DefaultHashBitLength, builder.Uint8()) + return commonencodings.NewArray(DefaultHashBitLength, builder.Uint8()) default: - return nil, fmt.Errorf(unknownIDLFormat, types.ErrInvalidConfig, curType) + return nil, fmt.Errorf(unknownIDLFormat, commontypes.ErrInvalidConfig, curType) } } diff --git a/pkg/solana/codec/testutils/eventItemTypeIDL.json b/pkg/solana/codec/testutils/eventItemTypeIDL.json new file mode 100644 index 000000000..f98f27671 --- /dev/null +++ b/pkg/solana/codec/testutils/eventItemTypeIDL.json @@ -0,0 +1,73 @@ +{ + "version": "0.1.0", + "name": "test_item_event_type", + "instructions": [], + "events": [ + { + "name": "TestItem", + "fields": [ + { "name": "Field", "type": "i32" }, + { "name": "OracleId", "type": "u8" }, + { "name": "OracleIds", "type": { "array": ["u8", 32] } }, + { "name": "AccountStruct", "type": { "defined": "AccountStruct" } }, + { "name": "Accounts", "type": { "vec": "publicKey" } }, + { "name": "DifferentField", "type": "string" }, + { "name": "BigField", "type": "i128" }, + { "name": "NestedDynamicStruct", "type": { "defined": "NestedDynamic" } }, + { "name": "NestedStaticStruct", "type": { "defined": "NestedStatic" } } + ] + } + ], + "types": [ + { + "name": "AccountStruct", + "type": { + "kind": "struct", + "fields": [ + { "name": "Account", "type": "publicKey" }, + { "name": "AccountStr", "type": "publicKey" } + ] + } + }, + { + "name": "InnerDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "S", "type": "string" } + ] + } + }, + { + "name": "NestedDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerDynamic" } } + ] + } + }, + { + "name": "InnerStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "A", "type": "publicKey" } + ] + } + }, + { + "name": "NestedStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerStatic" } } + ] + } + } + ] +} diff --git a/pkg/solana/codec/testutils/itemArray1TypeIDL.json b/pkg/solana/codec/testutils/itemArray1TypeIDL.json new file mode 100644 index 000000000..8d061aed2 --- /dev/null +++ b/pkg/solana/codec/testutils/itemArray1TypeIDL.json @@ -0,0 +1,92 @@ +{ + "version": "0.1.0", + "name": "test_item_array1", + "instructions": [ + { + "name": "TestItemArray1Type", + "accounts": [], + "args": [ + { + "name": "TestItemArray1Type", + "type": { + "array": [ + { + "defined": "TestItem" + }, + 1 + ] + } + } + ] + } + ], + "types": [ + { + "name": "TestItem", + "type": { + "kind": "struct", + "fields": [ + { "name": "Field", "type": "i32" }, + { "name": "OracleId", "type": "u8" }, + { "name": "OracleIds", "type": { "array": ["u8", 32] } }, + { "name": "AccountStruct", "type": { "defined": "AccountStruct" } }, + { "name": "Accounts", "type": { "vec": "publicKey" } }, + { "name": "DifferentField", "type": "string" }, + { "name": "BigField", "type": "i128" }, + { "name": "NestedDynamicStruct", "type": { "defined": "NestedDynamic" } }, + { "name": "NestedStaticStruct", "type": { "defined": "NestedStatic" } } + ] + } + }, + { + "name": "AccountStruct", + "type": { + "kind": "struct", + "fields": [ + { "name": "Account", "type": "publicKey" }, + { "name": "AccountStr", "type": "publicKey" } + ] + } + }, + { + "name": "InnerDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "S", "type": "string" } + ] + } + }, + { + "name": "NestedDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerDynamic" } } + ] + } + }, + { + "name": "InnerStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "A", "type": "publicKey" } + ] + } + }, + { + "name": "NestedStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerStatic" } } + ] + } + } + ] +} \ No newline at end of file diff --git a/pkg/solana/codec/testutils/itemArray2TypeIDL.json b/pkg/solana/codec/testutils/itemArray2TypeIDL.json new file mode 100644 index 000000000..c20c55785 --- /dev/null +++ b/pkg/solana/codec/testutils/itemArray2TypeIDL.json @@ -0,0 +1,92 @@ +{ + "version": "0.1.0", + "name": "test_item_array2", + "instructions": [ + { + "name": "TestItemArray2Type", + "accounts": [], + "args": [ + { + "name": "TestItemArray2Type", + "type": { + "array": [ + { + "defined": "TestItem" + }, + 2 + ] + } + } + ] + } + ], + "types": [ + { + "name": "TestItem", + "type": { + "kind": "struct", + "fields": [ + { "name": "Field", "type": "i32" }, + { "name": "OracleId", "type": "u8" }, + { "name": "OracleIds", "type": { "array": ["u8", 32] } }, + { "name": "AccountStruct", "type": { "defined": "AccountStruct" } }, + { "name": "Accounts", "type": { "vec": "publicKey" } }, + { "name": "DifferentField", "type": "string" }, + { "name": "BigField", "type": "i128" }, + { "name": "NestedDynamicStruct", "type": { "defined": "NestedDynamic" } }, + { "name": "NestedStaticStruct", "type": { "defined": "NestedStatic" } } + ] + } + }, + { + "name": "AccountStruct", + "type": { + "kind": "struct", + "fields": [ + { "name": "Account", "type": "publicKey" }, + { "name": "AccountStr", "type": "publicKey" } + ] + } + }, + { + "name": "InnerDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "S", "type": "string" } + ] + } + }, + { + "name": "NestedDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerDynamic" } } + ] + } + }, + { + "name": "InnerStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "A", "type": "publicKey" } + ] + } + }, + { + "name": "NestedStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerStatic" } } + ] + } + } + ] +} diff --git a/pkg/solana/codec/testutils/itemIDL.json b/pkg/solana/codec/testutils/itemIDL.json new file mode 100644 index 000000000..ee2a719cc --- /dev/null +++ b/pkg/solana/codec/testutils/itemIDL.json @@ -0,0 +1,77 @@ +{ + "version": "0.1.0", + "name": "test_item_type", + "instructions": [ + ], + "accounts": [ + { + "name": "TestItem", + "type": { + "kind": "struct", + "fields": [ + { "name": "Field", "type": "i32" }, + { "name": "OracleId", "type": "u8" }, + { "name": "OracleIds", "type": { "array": ["u8", 32] } }, + { "name": "AccountStruct", "type": { "defined": "AccountStruct" } }, + { "name": "Accounts", "type": { "vec": "publicKey" } }, + { "name": "DifferentField", "type": "string" }, + { "name": "BigField", "type": "i128" }, + { "name": "NestedDynamicStruct", "type": { "defined": "NestedDynamic" } }, + { "name": "NestedStaticStruct", "type": { "defined": "NestedStatic" } } + ] + } + } + ], + "types": [ + { + "name": "AccountStruct", + "type": { + "kind": "struct", + "fields": [ + { "name": "Account", "type": "publicKey" }, + { "name": "AccountStr", "type": "publicKey" } + ] + } + }, + { + "name": "InnerDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "S", "type": "string" } + ] + } + }, + { + "name": "NestedDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerDynamic" } } + ] + } + }, + { + "name": "InnerStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "A", "type": "publicKey" } + ] + } + }, + { + "name": "NestedStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerStatic" } } + ] + } + } + ] +} diff --git a/pkg/solana/codec/testutils/itemSliceTypeIDL.json b/pkg/solana/codec/testutils/itemSliceTypeIDL.json new file mode 100644 index 000000000..491d3e12f --- /dev/null +++ b/pkg/solana/codec/testutils/itemSliceTypeIDL.json @@ -0,0 +1,89 @@ +{ + "version": "0.1.0", + "name": "test_item_slice_type", + "instructions": [ + { + "name": "TestItemSliceType", + "accounts": [], + "args": [ + { + "name": "TestItemSliceType", + "type": { + "vec": { + "defined": "TestItem" + } + } + } + ] + } + ], + "types": [ + { + "name": "TestItem", + "type": { + "kind": "struct", + "fields": [ + { "name": "Field", "type": "i32" }, + { "name": "OracleId", "type": "u8" }, + { "name": "OracleIds", "type": { "array": ["u8", 32] } }, + { "name": "AccountStruct", "type": { "defined": "AccountStruct" } }, + { "name": "Accounts", "type": { "vec": "publicKey" } }, + { "name": "DifferentField", "type": "string" }, + { "name": "BigField", "type": "i128" }, + { "name": "NestedDynamicStruct", "type": { "defined": "NestedDynamic" } }, + { "name": "NestedStaticStruct", "type": { "defined": "NestedStatic" } } + ] + } + }, + { + "name": "AccountStruct", + "type": { + "kind": "struct", + "fields": [ + { "name": "Account", "type": "publicKey" }, + { "name": "AccountStr", "type": "publicKey" } + ] + } + }, + { + "name": "InnerDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "S", "type": "string" } + ] + } + }, + { + "name": "NestedDynamic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerDynamic" } } + ] + } + }, + { + "name": "InnerStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "IntVal", "type": "i64" }, + { "name": "A", "type": "publicKey" } + ] + } + }, + { + "name": "NestedStatic", + "type": { + "kind": "struct", + "fields": [ + { "name": "FixedBytes", "type": { "array": ["u8", 2] } }, + { "name": "Inner", "type": { "defined": "InnerStatic" } } + ] + } + } + ] +} diff --git a/pkg/solana/codec/testutils/nilTypeIDL.json b/pkg/solana/codec/testutils/nilTypeIDL.json new file mode 100644 index 000000000..47b169428 --- /dev/null +++ b/pkg/solana/codec/testutils/nilTypeIDL.json @@ -0,0 +1,12 @@ +{ + "name": "NilType", + "accounts": [ + { + "name": "NilType", + "type": { + "kind": "struct", + "fields": [] + } + } + ] +} \ No newline at end of file diff --git a/pkg/solana/codec/testutils/sizeItemTypeIDL.json b/pkg/solana/codec/testutils/sizeItemTypeIDL.json new file mode 100644 index 000000000..fdb73115e --- /dev/null +++ b/pkg/solana/codec/testutils/sizeItemTypeIDL.json @@ -0,0 +1,38 @@ +{ + "version": "0.1.0", + "name": "item_for_size", + "instructions": [ + { + "name": "ProcessItemForSize", + "accounts": [ + { + "name": "ItemForSize", + "isMut": true, + "isSigner": false + } + ], + "args": [] + } + ], + "accounts": [ + { + "name": "ItemForSize", + "type": { + "kind": "struct", + "fields": [ + { + "name": "Stuff", + "type": { + "vec": "i128" + } + }, + { + "name": "OtherStuff", + "type": "i128" + } + ] + } + } + ], + "types": [] +} diff --git a/pkg/solana/codec/testutils/types.go b/pkg/solana/codec/testutils/types.go index 533e88b0b..3c52adb0f 100644 --- a/pkg/solana/codec/testutils/types.go +++ b/pkg/solana/codec/testutils/types.go @@ -2,10 +2,16 @@ package testutils import ( _ "embed" + "fmt" "math/big" "time" - ag_solana "github.com/gagliardetto/solana-go" + agbinary "github.com/gagliardetto/binary" + "github.com/gagliardetto/solana-go" + + "github.com/smartcontractkit/chainlink-common/pkg/types/interfacetests" + + "github.com/smartcontractkit/chainlink-solana/pkg/solana/codec" ) var ( @@ -45,9 +51,11 @@ var ( BasicVector: []string{"some string", "another string"}, TimeVal: 683_100_000, DurationVal: 42 * time.Second, - PublicKey: ag_solana.NewWallet().PublicKey(), + PublicKey: solana.NewWallet().PublicKey(), EnumVal: 0, } + TestItemWithConfigExtraType = "TestItemWithConfigExtra" + TestEventItem = "TestEventItem" ) type StructWithNestedStruct struct { @@ -59,7 +67,7 @@ type StructWithNestedStruct struct { BasicVector []string TimeVal int64 DurationVal time.Duration - PublicKey ag_solana.PublicKey + PublicKey solana.PublicKey EnumVal uint8 } @@ -86,3 +94,550 @@ var JSONIDLWithAllTypes string //go:embed circularDepIDL.json var CircularDepIDL string + +//go:embed itemIDL.json +var itemTypeJSONIDL string + +//go:embed eventItemTypeIDL.json +var eventItemTypeJSONIDL string + +//go:embed itemSliceTypeIDL.json +var itemSliceTypeJSONIDL string + +//go:embed itemArray1TypeIDL.json +var itemArray1TypeJSONIDL string + +//go:embed itemArray2TypeIDL.json +var itemArray2TypeJSONIDL string + +//go:embed nilTypeIDL.json +var nilTypeJSONIDL string + +type CodecDef struct { + IDL string + IDLTypeName string + ItemType codec.ChainConfigType +} + +// CodecDefs key is codec offchain type name +var CodecDefs = map[string]CodecDef{ + interfacetests.TestItemType: { + IDL: itemTypeJSONIDL, + IDLTypeName: interfacetests.TestItemType, + ItemType: codec.ChainConfigTypeAccountDef, + }, + interfacetests.TestItemSliceType: { + IDL: itemSliceTypeJSONIDL, + IDLTypeName: interfacetests.TestItemSliceType, + ItemType: codec.ChainConfigTypeInstructionDef, + }, + interfacetests.TestItemArray1Type: { + IDL: itemArray1TypeJSONIDL, + IDLTypeName: interfacetests.TestItemArray1Type, + ItemType: codec.ChainConfigTypeInstructionDef, + }, + interfacetests.TestItemArray2Type: { + IDL: itemArray2TypeJSONIDL, + IDLTypeName: interfacetests.TestItemArray2Type, + ItemType: codec.ChainConfigTypeInstructionDef, + }, + TestItemWithConfigExtraType: { + IDL: itemTypeJSONIDL, + IDLTypeName: interfacetests.TestItemType, + ItemType: codec.ChainConfigTypeAccountDef, + }, + interfacetests.NilType: { + IDL: nilTypeJSONIDL, + IDLTypeName: interfacetests.NilType, + ItemType: codec.ChainConfigTypeAccountDef, + }, + TestEventItem: { + IDL: eventItemTypeJSONIDL, + IDLTypeName: interfacetests.TestItemType, + ItemType: codec.ChainConfigTypeEventDef, + }, +} + +type TestItemAsAccount struct { + Field int32 + OracleID uint8 + OracleIDs [32]uint8 + AccountStruct AccountStruct + Accounts []solana.PublicKey + DifferentField string + BigField agbinary.Int128 + NestedDynamicStruct NestedDynamic + NestedStaticStruct NestedStatic +} + +var TestItemDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149} + +func (obj TestItemAsAccount) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Write account discriminator: + err = encoder.WriteBytes(TestItemDiscriminator[:], false) + if err != nil { + return err + } + // Serialize `Field` param: + err = encoder.Encode(obj.Field) + if err != nil { + return err + } + // Serialize `OracleID` param: + err = encoder.Encode(obj.OracleID) + if err != nil { + return err + } + // Serialize `OracleIDs` param: + err = encoder.Encode(obj.OracleIDs) + if err != nil { + return err + } + // Serialize `AccountStruct` param: + err = encoder.Encode(obj.AccountStruct) + if err != nil { + return err + } + // Serialize `Accounts` param: + err = encoder.Encode(obj.Accounts) + if err != nil { + return err + } + // Serialize `DifferentField` param: + err = encoder.Encode(obj.DifferentField) + if err != nil { + return err + } + // Serialize `BigField` param: + err = encoder.Encode(obj.BigField) + if err != nil { + return err + } + // Serialize `NestedDynamicStruct` param: + err = encoder.Encode(obj.NestedDynamicStruct) + if err != nil { + return err + } + // Serialize `NestedStaticStruct` param: + err = encoder.Encode(obj.NestedStaticStruct) + if err != nil { + return err + } + return nil +} + +func (obj *TestItemAsAccount) UnmarshalWithDecoder(decoder *agbinary.Decoder) error { + // Read and check account discriminator: + { + discriminator, err := decoder.ReadTypeID() + if err != nil { + return err + } + if !discriminator.Equal(TestItemDiscriminator[:]) { + return fmt.Errorf( + "wrong discriminator: wanted %s, got %s", + "[148 105 105 155 26 167 212 149]", + fmt.Sprint(discriminator[:])) + } + } + // Deserialize `Field`: + err := decoder.Decode(&obj.Field) + if err != nil { + return err + } + // Deserialize `OracleID`: + err = decoder.Decode(&obj.OracleID) + if err != nil { + return err + } + // Deserialize `OracleIDs`: + err = decoder.Decode(&obj.OracleIDs) + if err != nil { + return err + } + // Deserialize `AccountStruct`: + err = decoder.Decode(&obj.AccountStruct) + if err != nil { + return err + } + // Deserialize `Accounts`: + err = decoder.Decode(&obj.Accounts) + if err != nil { + return err + } + // Deserialize `DifferentField`: + err = decoder.Decode(&obj.DifferentField) + if err != nil { + return err + } + // Deserialize `BigField`: + err = decoder.Decode(&obj.BigField) + if err != nil { + return err + } + // Deserialize `NestedDynamicStruct`: + err = decoder.Decode(&obj.NestedDynamicStruct) + if err != nil { + return err + } + // Deserialize `NestedStaticStruct`: + err = decoder.Decode(&obj.NestedStaticStruct) + if err != nil { + return err + } + return nil +} + +type TestItemAsArgs struct { + Field int32 + OracleID uint8 + OracleIDs [32]uint8 + AccountStruct AccountStruct + Accounts []solana.PublicKey + DifferentField string + BigField agbinary.Int128 + NestedDynamicStruct NestedDynamic + NestedStaticStruct NestedStatic +} + +func (obj TestItemAsArgs) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Serialize `Field` param: + err = encoder.Encode(obj.Field) + if err != nil { + return err + } + // Serialize `OracleID` param: + err = encoder.Encode(obj.OracleID) + if err != nil { + return err + } + // Serialize `OracleIDs` param: + err = encoder.Encode(obj.OracleIDs) + if err != nil { + return err + } + // Serialize `AccountStruct` param: + err = encoder.Encode(obj.AccountStruct) + if err != nil { + return err + } + // Serialize `Accounts` param: + err = encoder.Encode(obj.Accounts) + if err != nil { + return err + } + // Serialize `DifferentField` param: + err = encoder.Encode(obj.DifferentField) + if err != nil { + return err + } + // Serialize `BigField` param: + err = encoder.Encode(obj.BigField) + if err != nil { + return err + } + // Serialize `NestedDynamicStruct` param: + err = encoder.Encode(obj.NestedDynamicStruct) + if err != nil { + return err + } + // Serialize `NestedStaticStruct` param: + err = encoder.Encode(obj.NestedStaticStruct) + if err != nil { + return err + } + return nil +} + +func (obj *TestItemAsArgs) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err error) { + // Deserialize `Field`: + err = decoder.Decode(&obj.Field) + if err != nil { + return err + } + // Deserialize `OracleID`: + err = decoder.Decode(&obj.OracleID) + if err != nil { + return err + } + // Deserialize `OracleIDs`: + err = decoder.Decode(&obj.OracleIDs) + if err != nil { + return err + } + // Deserialize `AccountStruct`: + err = decoder.Decode(&obj.AccountStruct) + if err != nil { + return err + } + // Deserialize `Accounts`: + err = decoder.Decode(&obj.Accounts) + if err != nil { + return err + } + // Deserialize `DifferentField`: + err = decoder.Decode(&obj.DifferentField) + if err != nil { + return err + } + // Deserialize `BigField`: + err = decoder.Decode(&obj.BigField) + if err != nil { + return err + } + // Deserialize `NestedDynamicStruct`: + err = decoder.Decode(&obj.NestedDynamicStruct) + if err != nil { + return err + } + // Deserialize `NestedStaticStruct`: + err = decoder.Decode(&obj.NestedStaticStruct) + if err != nil { + return err + } + return nil +} + +type AccountStruct struct { + Account solana.PublicKey + AccountStr solana.PublicKey +} + +func (obj AccountStruct) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Serialize `Account` param: + err = encoder.Encode(obj.Account) + if err != nil { + return err + } + // Serialize `AccountStr` param: + err = encoder.Encode(obj.AccountStr) + if err != nil { + return err + } + return nil +} + +func (obj *AccountStruct) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err error) { + // Deserialize `Account`: + err = decoder.Decode(&obj.Account) + if err != nil { + return err + } + // Deserialize `AccountStr`: + err = decoder.Decode(&obj.AccountStr) + if err != nil { + return err + } + return nil +} + +type InnerDynamic struct { + IntVal int64 + S string +} + +func (obj InnerDynamic) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Serialize `IntVal` param: + err = encoder.Encode(obj.IntVal) + if err != nil { + return err + } + // Serialize `S` param: + err = encoder.Encode(obj.S) + if err != nil { + return err + } + return nil +} + +func (obj *InnerDynamic) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err error) { + // Deserialize `IntVal`: + err = decoder.Decode(&obj.IntVal) + if err != nil { + return err + } + // Deserialize `S`: + err = decoder.Decode(&obj.S) + if err != nil { + return err + } + return nil +} + +type NestedDynamic struct { + FixedBytes [2]uint8 + Inner InnerDynamic +} + +func (obj NestedDynamic) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Serialize `FixedBytes` param: + err = encoder.Encode(obj.FixedBytes) + if err != nil { + return err + } + // Serialize `Inner` param: + err = encoder.Encode(obj.Inner) + if err != nil { + return err + } + return nil +} + +func (obj *NestedDynamic) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err error) { + // Deserialize `FixedBytes`: + err = decoder.Decode(&obj.FixedBytes) + if err != nil { + return err + } + // Deserialize `Inner`: + err = decoder.Decode(&obj.Inner) + if err != nil { + return err + } + return nil +} + +type InnerStatic struct { + IntVal int64 + A solana.PublicKey +} + +func (obj InnerStatic) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Serialize `IntVal` param: + err = encoder.Encode(obj.IntVal) + if err != nil { + return err + } + // Serialize `A` param: + err = encoder.Encode(obj.A) + if err != nil { + return err + } + return nil +} + +func (obj *InnerStatic) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err error) { + // Deserialize `IntVal`: + err = decoder.Decode(&obj.IntVal) + if err != nil { + return err + } + // Deserialize `A`: + err = decoder.Decode(&obj.A) + if err != nil { + return err + } + return nil +} + +type NestedStatic struct { + FixedBytes [2]uint8 + Inner InnerStatic +} + +func (obj NestedStatic) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) { + // Serialize `FixedBytes` param: + err = encoder.Encode(obj.FixedBytes) + if err != nil { + return err + } + // Serialize `Inner` param: + err = encoder.Encode(obj.Inner) + if err != nil { + return err + } + return nil +} + +func (obj *NestedStatic) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err error) { + // Deserialize `FixedBytes`: + err = decoder.Decode(&obj.FixedBytes) + if err != nil { + return err + } + // Deserialize `Inner`: + err = decoder.Decode(&obj.Inner) + if err != nil { + return err + } + return nil +} + +func EncodeRequestToTestItemAsAccount(testStruct interfacetests.TestStruct) TestItemAsAccount { + return TestItemAsAccount{ + Field: *testStruct.Field, + OracleID: uint8(testStruct.OracleID), + OracleIDs: getOracleIDs(testStruct), + AccountStruct: getAccountStruct(testStruct), + Accounts: getAccounts(testStruct), + DifferentField: testStruct.DifferentField, + BigField: bigIntToBinInt128(testStruct.BigField), + NestedDynamicStruct: getNestedDynamic(testStruct), + NestedStaticStruct: getNestedStatic(testStruct), + } +} + +func EncodeRequestToTestItemAsArgs(testStruct interfacetests.TestStruct) TestItemAsArgs { + return TestItemAsArgs{ + Field: *testStruct.Field, + OracleID: uint8(testStruct.OracleID), + OracleIDs: getOracleIDs(testStruct), + AccountStruct: getAccountStruct(testStruct), + Accounts: getAccounts(testStruct), + DifferentField: testStruct.DifferentField, + BigField: bigIntToBinInt128(testStruct.BigField), + NestedDynamicStruct: getNestedDynamic(testStruct), + NestedStaticStruct: getNestedStatic(testStruct), + } +} + +func getOracleIDs(testStruct interfacetests.TestStruct) [32]byte { + var oracleIDs [32]byte + for i, v := range testStruct.OracleIDs { + oracleIDs[i] = byte(v) + } + return oracleIDs +} + +func getAccountStruct(testStruct interfacetests.TestStruct) AccountStruct { + k, _ := solana.PublicKeyFromBase58(testStruct.AccountStruct.AccountStr) + return AccountStruct{ + Account: solana.PublicKeyFromBytes(testStruct.AccountStruct.Account), + AccountStr: k, + } +} + +func getAccounts(testStruct interfacetests.TestStruct) []solana.PublicKey { + accs := make([]solana.PublicKey, len(testStruct.Accounts)) + for i, v := range testStruct.Accounts { + accs[i] = solana.PublicKeyFromBytes(v) + } + return accs +} + +func getNestedDynamic(testStruct interfacetests.TestStruct) NestedDynamic { + return NestedDynamic{ + FixedBytes: testStruct.NestedDynamicStruct.FixedBytes, + Inner: InnerDynamic{ + IntVal: int64(testStruct.NestedDynamicStruct.Inner.I), + S: testStruct.NestedDynamicStruct.Inner.S, + }, + } +} + +func getNestedStatic(testStruct interfacetests.TestStruct) NestedStatic { + return NestedStatic{ + FixedBytes: testStruct.NestedStaticStruct.FixedBytes, + Inner: InnerStatic{ + IntVal: int64(testStruct.NestedStaticStruct.Inner.I), + A: solana.PublicKeyFromBytes(testStruct.NestedStaticStruct.Inner.A), + }, + } +} + +func bigIntToBinInt128(val *big.Int) agbinary.Int128 { + return agbinary.Int128{ + Lo: val.Uint64(), + Hi: new(big.Int).Rsh(val, 64).Uint64(), + } +} diff --git a/pkg/solana/codec/types.go b/pkg/solana/codec/types.go new file mode 100644 index 000000000..e047b36ae --- /dev/null +++ b/pkg/solana/codec/types.go @@ -0,0 +1,24 @@ +package codec + +import commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" + +type ChainConfigType string + +const ( + ChainConfigTypeAccountDef ChainConfigType = "account" + ChainConfigTypeInstructionDef ChainConfigType = "instruction" + ChainConfigTypeEventDef ChainConfigType = "event" +) + +type Config struct { + // Configs key is the type's offChainName for the codec + Configs map[string]ChainConfig `json:"configs" toml:"configs"` +} + +type ChainConfig struct { + IDL string `json:"typeIdl" toml:"typeIdl"` + OnChainName string `json:"onChainName" toml:"onChainName"` + // Type can be Solana Account, Instruction args, or TODO Event + Type ChainConfigType `json:"type" toml:"type"` + ModifierConfigs commoncodec.ModifiersConfig `json:"modifierConfigs,omitempty" toml:"modifierConfigs,omitempty"` +} diff --git a/pkg/solana/config/chain_reader.go b/pkg/solana/config/chain_reader.go index dbe9ef4ab..4251624fe 100644 --- a/pkg/solana/config/chain_reader.go +++ b/pkg/solana/config/chain_reader.go @@ -7,7 +7,7 @@ import ( "github.com/gagliardetto/solana-go" "github.com/gagliardetto/solana-go/rpc" - "github.com/smartcontractkit/chainlink-common/pkg/codec" + commoncodec "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings" "github.com/smartcontractkit/chainlink-common/pkg/codec/encodings/binary" "github.com/smartcontractkit/chainlink-common/pkg/types" @@ -82,7 +82,7 @@ type chainDataProcedureFields struct { IDLAccount string `json:"idlAccount,omitempty"` // OutputModifications provides modifiers to convert chain data format to custom // output formats. - OutputModifications codec.ModifiersConfig `json:"outputModifications,omitempty"` + OutputModifications commoncodec.ModifiersConfig `json:"outputModifications,omitempty"` // RPCOpts provides optional configurations for commitment, encoding, and data // slice offsets. RPCOpts *RPCOpts `json:"rpcOpts,omitempty"` diff --git a/sonar-project.properties b/sonar-project.properties index 0434465b5..bec9533fd 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -5,9 +5,10 @@ sonar.sources=. # Full exclusions from the static analysis sonar.exclusions=**/node_modules/**/*, **/contracts/artifacts/**/*, **/generated/**/*, **/docs/**/*, **/*.config.ts, **/*.config.js, **/*.txt, pkg/solana/codec/anchoridl.go # Coverage exclusions -sonar.coverage.exclusions=**/*.test.ts, **/*_test.go, **/contracts/tests/**/*, **/integration-tests/**/* +sonar.coverage.exclusions=**/*.test.ts, **/*_test.go, **/contracts/tests/**/*, **/integration-tests/**/*, **/pkg/solana/codec/testutils/**/* # Tests' root folder, inclusions (tests to check and count) and exclusions sonar.tests=. sonar.test.inclusions=**/*_test.go, **/contracts/tests/**/* -sonar.test.exclusions=**/integration-tests/*, **/gauntlet/* \ No newline at end of file +sonar.test.exclusions=**/integration-tests/*, **/gauntlet/* +sonar.cpd.exclusions=**/pkg/solana/codec/testutils/**/* \ No newline at end of file From 7e180167704810f3b915e37b9a8a81b503a8199e Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 19 Dec 2024 12:28:16 -0700 Subject: [PATCH 4/4] add a pop method to arrayvec (#963) * add a pop method to arrayvec * rustfmt --- contracts/Cargo.lock | 6 ++-- contracts/crates/arrayvec/Cargo.toml | 2 +- contracts/crates/arrayvec/src/lib.rs | 43 ++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/contracts/Cargo.lock b/contracts/Cargo.lock index f265cd8ce..0209f1dbb 100644 --- a/contracts/Cargo.lock +++ b/contracts/Cargo.lock @@ -7,7 +7,7 @@ name = "access-controller" version = "1.0.1" dependencies = [ "anchor-lang", - "arrayvec 1.0.0", + "arrayvec 1.0.1", "bytemuck", "static_assertions", ] @@ -379,7 +379,7 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrayvec" -version = "1.0.0" +version = "1.0.1" [[package]] name = "assert_matches" @@ -1374,7 +1374,7 @@ dependencies = [ "anchor-lang", "anchor-spl", "arrayref", - "arrayvec 1.0.0", + "arrayvec 1.0.1", "bytemuck", "solana-program", "static_assertions", diff --git a/contracts/crates/arrayvec/Cargo.toml b/contracts/crates/arrayvec/Cargo.toml index 4e9ad51cc..5570a5636 100644 --- a/contracts/crates/arrayvec/Cargo.toml +++ b/contracts/crates/arrayvec/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "arrayvec" -version = "1.0.0" +version = "1.0.1" edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/contracts/crates/arrayvec/src/lib.rs b/contracts/crates/arrayvec/src/lib.rs index dfc8a12e3..9312c2338 100644 --- a/contracts/crates/arrayvec/src/lib.rs +++ b/contracts/crates/arrayvec/src/lib.rs @@ -90,6 +90,17 @@ macro_rules! arrayvec { self.xs[offset..offset + len].copy_from_slice(&data); self.len += len as $capacity_ty; } + + /// Removes the last element from the array and returns it. + /// Returns `None` if the array is empty. + pub fn pop(&mut self) -> Option<$ty> { + if self.len > 0 { + self.len -= 1; + Some(self.xs[self.len as usize]) + } else { + None + } + } } impl std::ops::Deref for $name { @@ -126,6 +137,38 @@ mod tests { } arrayvec!(ArrayVec, u8, u32); + #[test] + fn push_pop() { + let mut vec = ArrayVec::new(); + assert!(vec.is_empty()); + assert_eq!(vec.pop(), None); + + vec.push(10); + assert_eq!(vec.len(), 1); + assert_eq!(vec.as_slice(), &[10]); + + vec.push(20); + vec.push(30); + assert_eq!(vec.len(), 3); + assert_eq!(vec.as_slice(), &[10, 20, 30]); + + // Popping elements + assert_eq!(vec.pop(), Some(30)); + assert_eq!(vec.len(), 2); + assert_eq!(vec.as_slice(), &[10, 20]); + + assert_eq!(vec.pop(), Some(20)); + assert_eq!(vec.len(), 1); + assert_eq!(vec.as_slice(), &[10]); + + assert_eq!(vec.pop(), Some(10)); + assert_eq!(vec.len(), 0); + assert!(vec.is_empty()); + + // Popping from empty vec + assert_eq!(vec.pop(), None); + } + #[test] fn remove() { let mut vec = ArrayVec::new();