diff --git a/Makefile b/Makefile index cfdb71adf4..230d47f200 100644 --- a/Makefile +++ b/Makefile @@ -120,7 +120,7 @@ test-short: ## test-race: Run unit tests in race mode. test-race: @echo "--> Running tests in race mode" - @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestQGBRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestQGBCLI|TestUpgrade|TestMaliciousTestNode|TestMaxTotalBlobSizeSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext" + @go test ./... -v -race -skip "TestPrepareProposalConsistency|TestIntegrationTestSuite|TestQGBRPCQueries|TestSquareSizeIntegrationTest|TestStandardSDKIntegrationTestSuite|TestTxsimCommandFlags|TestTxsimCommandEnvVar|TestMintIntegrationTestSuite|TestQGBCLI|TestUpgrade|TestMaliciousTestNode|TestMaxTotalBlobSizeSuite|TestQGBIntegrationSuite|TestSignerTestSuite|TestPriorityTestSuite|TestTimeInPrepareProposalContext|TestConcurrentTxSubmission" .PHONY: test-race ## test-bench: Run unit tests in bench mode. diff --git a/app/errors/nonce_mismatch.go b/app/errors/nonce_mismatch.go index 2726d61060..8209aac8b7 100644 --- a/app/errors/nonce_mismatch.go +++ b/app/errors/nonce_mismatch.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strconv" + "strings" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" ) @@ -13,6 +14,11 @@ func IsNonceMismatch(err error) bool { return errors.Is(err, sdkerrors.ErrWrongSequence) } +// IsNonceMismatch checks if the error code matches the sequence mismatch. +func IsNonceMismatchCode(code uint32) bool { + return code == sdkerrors.ErrWrongSequence.ABCICode() +} + // ParseNonceMismatch extracts the expected sequence number from the // ErrWrongSequence error. func ParseNonceMismatch(err error) (uint64, error) { @@ -20,9 +26,19 @@ func ParseNonceMismatch(err error) (uint64, error) { return 0, errors.New("error is not a sequence mismatch") } - numbers := regexpInt.FindAllString(err.Error(), -1) + return ParseExpectedSequence(err.Error()) +} + +// ParseExpectedSequence extracts the expected sequence number from the +// ErrWrongSequence error. +func ParseExpectedSequence(str string) (uint64, error) { + if !strings.HasPrefix(str, "account sequence mismatch") { + return 0, fmt.Errorf("unexpected wrong sequence error: %s", str) + } + + numbers := regexpInt.FindAllString(str, -1) if len(numbers) != 2 { - return 0, fmt.Errorf("unexpected wrong sequence error: %w", err) + return 0, fmt.Errorf("expected two numbers in string, got %d", len(numbers)) } // the first number is the expected sequence number diff --git a/app/test/priority_test.go b/app/test/priority_test.go index 6605cb564d..87639ef180 100644 --- a/app/test/priority_test.go +++ b/app/test/priority_test.go @@ -3,6 +3,7 @@ package app_test import ( "encoding/hex" "sort" + "sync" "testing" "time" @@ -70,43 +71,47 @@ func (s *PriorityTestSuite) TestPriorityByGasPrice() { t := s.T() // quickly submit blobs with a random fee - hashes := make([]string, 0, len(s.signers)) + + hashes := make(chan string, len(s.signers)) + blobSize := uint32(100) + gasLimit := blobtypes.DefaultEstimateGas([]uint32{blobSize}) + wg := &sync.WaitGroup{} for _, signer := range s.signers { - blobSize := uint32(100) - gasLimit := blobtypes.DefaultEstimateGas([]uint32{blobSize}) - gasPrice := s.rand.Float64() - btx, err := signer.CreatePayForBlob( - blobfactory.ManyBlobs( - t, - s.rand, - []namespace.Namespace{namespace.RandomBlobNamespace()}, - []int{100}), - user.SetGasLimitAndFee(gasLimit, gasPrice), - ) - require.NoError(t, err) - resp, err := signer.BroadcastTx(s.cctx.GoContext(), btx) - require.NoError(t, err) - require.Equal(t, abci.CodeTypeOK, resp.Code) - hashes = append(hashes, resp.TxHash) + wg.Add(1) + go func() { + defer wg.Done() + gasPrice := float64(s.rand.Intn(1000)+1) / 1000 + resp, err := signer.SubmitPayForBlob( + s.cctx.GoContext(), + blobfactory.ManyBlobs( + t, + s.rand, + []namespace.Namespace{namespace.RandomBlobNamespace()}, + []int{100}), + user.SetGasLimitAndFee(gasLimit, gasPrice), + ) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code, resp.RawLog) + hashes <- resp.TxHash + }() } + wg.Wait() + close(hashes) + err := s.cctx.WaitForNextBlock() require.NoError(t, err) // get the responses for each tx for analysis and sort by height // note: use rpc types because they contain the tx index heightMap := make(map[int64][]*rpctypes.ResultTx) - for _, hash := range hashes { - resp, err := s.signers[0].ConfirmTx(s.cctx.GoContext(), hash) - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, abci.CodeTypeOK, resp.Code) + for hash := range hashes { // use the core rpc type because it contains the tx index hash, err := hex.DecodeString(hash) require.NoError(t, err) coreRes, err := s.cctx.Client.Tx(s.cctx.GoContext(), hash, false) require.NoError(t, err) - heightMap[resp.Height] = append(heightMap[resp.Height], coreRes) + heightMap[coreRes.Height] = append(heightMap[coreRes.Height], coreRes) } require.GreaterOrEqual(t, len(heightMap), 1) @@ -123,7 +128,7 @@ func (s *PriorityTestSuite) TestPriorityByGasPrice() { // check that there was at least one block with more than three transactions // in it. This is more of a sanity check than a test. - require.True(t, highestNumOfTxsPerBlock > 3) + require.Greater(t, highestNumOfTxsPerBlock, 3) } func sortByIndex(txs []*rpctypes.ResultTx) []*rpctypes.ResultTx { @@ -135,14 +140,14 @@ func sortByIndex(txs []*rpctypes.ResultTx) []*rpctypes.ResultTx { func isSortedByFee(t *testing.T, ecfg encoding.Config, responses []*rpctypes.ResultTx) bool { for i := 0; i < len(responses)-1; i++ { - if gasPrice(t, ecfg, responses[i]) <= gasPrice(t, ecfg, responses[i+1]) { + if getGasPrice(t, ecfg, responses[i]) <= getGasPrice(t, ecfg, responses[i+1]) { return false } } return true } -func gasPrice(t *testing.T, ecfg encoding.Config, resp *rpctypes.ResultTx) float64 { +func getGasPrice(t *testing.T, ecfg encoding.Config, resp *rpctypes.ResultTx) float64 { sdkTx, err := ecfg.TxConfig.TxDecoder()(resp.Tx) require.NoError(t, err) feeTx := sdkTx.(sdk.FeeTx) diff --git a/pkg/user/e2e_test.go b/pkg/user/e2e_test.go new file mode 100644 index 0000000000..b195d27ed8 --- /dev/null +++ b/pkg/user/e2e_test.go @@ -0,0 +1,85 @@ +package user_test + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/celestiaorg/celestia-app/app" + "github.com/celestiaorg/celestia-app/app/encoding" + "github.com/celestiaorg/celestia-app/pkg/appconsts" + "github.com/celestiaorg/celestia-app/pkg/user" + "github.com/celestiaorg/celestia-app/test/util/blobfactory" + "github.com/celestiaorg/celestia-app/test/util/testnode" + "github.com/stretchr/testify/require" + tmrand "github.com/tendermint/tendermint/libs/rand" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" +) + +func TestConcurrentTxSubmission(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + + // Setup network + tmConfig := testnode.DefaultTendermintConfig() + tmConfig.Consensus.TimeoutCommit = 10 * time.Second + ctx, _, _ := testnode.NewNetwork(t, testnode.DefaultConfig().WithTendermintConfig(tmConfig)) + _, err := ctx.WaitForHeight(1) + require.NoError(t, err) + + // Setup signer + signer, err := newSingleSignerFromContext(ctx) + require.NoError(t, err) + + // Pregenerate all the blobs + numTxs := 10 + blobs := blobfactory.ManyRandBlobs(t, tmrand.NewRand(), blobfactory.Repeat(2048, numTxs)...) + + // Prepare transactions + var ( + wg sync.WaitGroup + errCh = make(chan error) + ) + + subCtx, cancel := context.WithCancel(ctx.GoContext()) + defer cancel() + time.AfterFunc(time.Minute, cancel) + for i := 0; i < numTxs; i++ { + wg.Add(1) + go func(b *tmproto.Blob) { + defer wg.Done() + _, err := signer.SubmitPayForBlob(subCtx, []*tmproto.Blob{b}, user.SetGasLimitAndFee(500_000, appconsts.DefaultMinGasPrice)) + if err != nil && !errors.Is(err, context.Canceled) { + // only catch the first error + select { + case errCh <- err: + cancel() + default: + } + } + }(blobs[i]) + } + wg.Wait() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } +} + +func newSingleSignerFromContext(ctx testnode.Context) (*user.Signer, error) { + encCfg := encoding.MakeConfig(app.ModuleEncodingRegisters...) + record, err := ctx.Keyring.Key("validator") + if err != nil { + return nil, err + } + address, err := record.GetAddress() + if err != nil { + return nil, err + } + return user.SetupSigner(ctx.GoContext(), ctx.Keyring, ctx.GRPCClient, address, encCfg) +} diff --git a/pkg/user/signer.go b/pkg/user/signer.go index 6f3918d18f..0f84556bed 100644 --- a/pkg/user/signer.go +++ b/pkg/user/signer.go @@ -9,16 +9,18 @@ import ( "time" "github.com/celestiaorg/celestia-app/app/encoding" + apperrors "github.com/celestiaorg/celestia-app/app/errors" blob "github.com/celestiaorg/celestia-app/x/blob/types" "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/grpc/tmservice" "github.com/cosmos/cosmos-sdk/crypto/keyring" cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdktypes "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/types/tx" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + abci "github.com/tendermint/tendermint/abci/types" tmproto "github.com/tendermint/tendermint/proto/tendermint/types" tmtypes "github.com/tendermint/tendermint/types" "google.golang.org/grpc" @@ -35,11 +37,18 @@ type Signer struct { pk cryptotypes.PubKey chainID string accountNumber uint64 - pollTime time.Duration - mtx sync.RWMutex - lastSignedSequence uint64 - lastConfirmedSequence uint64 + mtx sync.RWMutex + // how often to poll the network for confirmation of a transaction + pollTime time.Duration + // the signers local view of the sequence number + localSequence uint64 + // the chains last known sequence number + networkSequence uint64 + // lookup map of all pending and yet to be confirmed outbound transactions + outboundSequences map[uint64]struct{} + // a reverse map for confirming which sequence numbers have been committed + reverseTxHashSequenceMap map[string]uint64 } // NewSigner returns a new signer using the provided keyring @@ -64,16 +73,18 @@ func NewSigner( } return &Signer{ - keys: keys, - address: address, - grpc: conn, - enc: enc, - pk: pk, - chainID: chainID, - accountNumber: accountNumber, - lastSignedSequence: sequence, - lastConfirmedSequence: sequence, - pollTime: DefaultPollTime, + keys: keys, + address: address, + grpc: conn, + enc: enc, + pk: pk, + chainID: chainID, + accountNumber: accountNumber, + localSequence: sequence, + networkSequence: sequence, + pollTime: DefaultPollTime, + outboundSequences: make(map[uint64]struct{}), + reverseTxHashSequenceMap: make(map[string]uint64), }, nil } @@ -125,17 +136,14 @@ func SetupSigner( // SubmitTx forms a transaction from the provided messages, signs it, and submits it to the chain. TxOptions // may be provided to set the fee and gas limit. func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { - txBytes, err := s.CreateTx(msgs, opts...) + tx, err := s.CreateTx(msgs, opts...) if err != nil { return nil, err } - resp, err := s.BroadcastTx(ctx, txBytes) + resp, err := s.BroadcastTx(ctx, tx) if err != nil { - return nil, err - } - if resp.Code != 0 { - return resp, fmt.Errorf("tx failed with code %d: %s", resp.Code, resp.RawLog) + return resp, err } return s.ConfirmTx(ctx, resp.TxHash) @@ -144,25 +152,35 @@ func (s *Signer) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOp // SubmitPayForBlob forms a transaction from the provided blobs, signs it, and submits it to the chain. // TxOptions may be provided to set the fee and gas limit. func (s *Signer) SubmitPayForBlob(ctx context.Context, blobs []*tmproto.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { - txBytes, err := s.CreatePayForBlob(blobs, opts...) + resp, err := s.broadcastPayForBlob(ctx, blobs, opts...) if err != nil { - return nil, err + return resp, err } - resp, err := s.BroadcastTx(ctx, txBytes) + return s.ConfirmTx(ctx, resp.TxHash) +} + +func (s *Signer) broadcastPayForBlob(ctx context.Context, blobs []*blob.Blob, opts ...TxOption) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, seqNum, err := s.createPayForBlobs(blobs, opts...) if err != nil { return nil, err } - if resp.Code != 0 { - return resp, fmt.Errorf("tx failed with code %d: %s", resp.Code, resp.RawLog) - } - return s.ConfirmTx(ctx, resp.TxHash) + return s.broadcastTx(ctx, txBytes, seqNum) } // CreateTx forms a transaction from the provided messages and signs it. TxOptions may be optionally // used to set the gas limit and fee. -func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) { +func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + + return s.createTx(msgs, opts...) +} + +func (s *Signer) createTx(msgs []sdktypes.Msg, opts ...TxOption) (authsigning.Tx, error) { txBuilder := s.txBuilder(opts...) if err := txBuilder.SetMsgs(msgs...); err != nil { return nil, err @@ -172,71 +190,200 @@ func (s *Signer) CreateTx(msgs []sdktypes.Msg, opts ...TxOption) ([]byte, error) return nil, err } - return s.enc.TxEncoder()(txBuilder.GetTx()) + return txBuilder.GetTx(), nil } func (s *Signer) CreatePayForBlob(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + blobTx, _, err := s.createPayForBlobs(blobs, opts...) + return blobTx, err +} + +func (s *Signer) createPayForBlobs(blobs []*tmproto.Blob, opts ...TxOption) ([]byte, uint64, error) { msg, err := blob.NewMsgPayForBlobs(s.address.String(), blobs...) if err != nil { - return nil, err + return nil, 0, err } - txBytes, err := s.CreateTx([]sdktypes.Msg{msg}, opts...) + tx, err := s.createTx([]sdktypes.Msg{msg}, opts...) if err != nil { - return nil, err + return nil, 0, err + } + + seqNum, err := getSequenceNumber(tx) + if err != nil { + panic(err) + } + + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, 0, err } - return tmtypes.MarshalBlobTx(txBytes, blobs...) + blobTx, err := tmtypes.MarshalBlobTx(txBytes, blobs...) + return blobTx, seqNum, err +} + +func (s *Signer) EncodeTx(tx sdktypes.Tx) ([]byte, error) { + return s.enc.TxEncoder()(tx) +} + +func (s *Signer) DecodeTx(txBytes []byte) (authsigning.Tx, error) { + tx, err := s.enc.TxDecoder()(txBytes) + if err != nil { + return nil, err + } + authTx, ok := tx.(authsigning.Tx) + if !ok { + return nil, errors.New("not an authsigning transaction") + } + return authTx, nil } // BroadcastTx submits the provided transaction bytes to the chain and returns the response. -func (s *Signer) BroadcastTx(ctx context.Context, txBytes []byte) (*sdktypes.TxResponse, error) { - txClient := tx.NewServiceClient(s.grpc) +func (s *Signer) BroadcastTx(ctx context.Context, tx authsigning.Tx) (*sdktypes.TxResponse, error) { + s.mtx.Lock() + defer s.mtx.Unlock() + txBytes, err := s.EncodeTx(tx) + if err != nil { + return nil, err + } + sequence, err := getSequenceNumber(tx) + if err != nil { + return nil, err + } + return s.broadcastTx(ctx, txBytes, sequence) +} - // TODO (@cmwaters): handle nonce mismatch errors +// CONTRACT: assumes the caller has the lock +func (s *Signer) broadcastTx(ctx context.Context, txBytes []byte, sequence uint64) (*sdktypes.TxResponse, error) { + if _, exists := s.outboundSequences[sequence]; exists { + return s.retryBroadcastingTx(ctx, txBytes, sequence+1) + } + + if sequence < s.networkSequence { + s.localSequence = s.networkSequence + return s.retryBroadcastingTx(ctx, txBytes, s.localSequence) + } + + txClient := sdktx.NewServiceClient(s.grpc) resp, err := txClient.BroadcastTx( ctx, - &tx.BroadcastTxRequest{ - Mode: tx.BroadcastMode_BROADCAST_MODE_SYNC, + &sdktx.BroadcastTxRequest{ + Mode: sdktx.BroadcastMode_BROADCAST_MODE_SYNC, TxBytes: txBytes, }, ) if err != nil { return nil, err } - return resp.TxResponse, nil + if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { + // extract what the lastCommittedNonce on chain is + nextSequence, err := apperrors.ParseExpectedSequence(resp.TxResponse.RawLog) + if err != nil { + return nil, fmt.Errorf("parsing nonce mismatch upon retry: %w", err) + } + s.networkSequence = nextSequence + s.localSequence = nextSequence + // FIXME: We can't actually resign the transaction. A malicious node + // may manipulate us into signing the same transaction several times + // and then executing them. We need some proof of what the last network + // sequence is rather than relying on an error provided by the node + // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + // Ref: https://github.com/celestiaorg/celestia-app/issues/3256 + // return s.retryBroadcastingTx(ctx, txBytes, nextSequence) + } else if resp.TxResponse.Code == abci.CodeTypeOK { + s.outboundSequences[sequence] = struct{}{} + s.reverseTxHashSequenceMap[resp.TxResponse.TxHash] = sequence + return resp.TxResponse, nil + } + return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) +} + +// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the +// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction +func (s *Signer) retryBroadcastingTx(ctx context.Context, txBytes []byte, newSequenceNumber uint64) (*sdktypes.TxResponse, error) { + blobTx, isBlobTx := tmtypes.UnmarshalBlobTx(txBytes) + if isBlobTx { + txBytes = blobTx.Tx + } + tx, err := s.DecodeTx(txBytes) + if err != nil { + return nil, err + } + txBuilder := s.txBuilder() + if err := txBuilder.SetMsgs(tx.GetMsgs()...); err != nil { + return nil, err + } + if granter := tx.FeeGranter(); granter != nil { + txBuilder.SetFeeGranter(granter) + } + if payer := tx.FeePayer(); payer != nil { + txBuilder.SetFeePayer(payer) + } + if memo := tx.GetMemo(); memo != "" { + txBuilder.SetMemo(memo) + } + if fee := tx.GetFee(); fee != nil { + txBuilder.SetFeeAmount(fee) + } + if gas := tx.GetGas(); gas > 0 { + txBuilder.SetGasLimit(gas) + } + + if err := s.signTransaction(txBuilder, newSequenceNumber); err != nil { + return nil, fmt.Errorf("resigning transaction: %w", err) + } + + newTxBytes, err := s.EncodeTx(txBuilder.GetTx()) + if err != nil { + return nil, err + } + + // rewrap the blob tx if it was originally a blob tx + if isBlobTx { + newTxBytes, err = tmtypes.MarshalBlobTx(newTxBytes, blobTx.Blobs...) + if err != nil { + return nil, err + } + } + + return s.broadcastTx(ctx, newTxBytes, newSequenceNumber) } // ConfirmTx periodically pings the provided node for the commitment of a transaction by its // hash. It will continually loop until the context is cancelled, the tx is found or an error // is encountered. func (s *Signer) ConfirmTx(ctx context.Context, txHash string) (*sdktypes.TxResponse, error) { - txClient := tx.NewServiceClient(s.grpc) + txClient := sdktx.NewServiceClient(s.grpc) + + pollTime := s.getPollTime() timer := time.NewTimer(0) defer timer.Stop() + for { select { case <-ctx.Done(): return &sdktypes.TxResponse{}, ctx.Err() case <-timer.C: - resp, err := txClient.GetTx( - ctx, - &tx.GetTxRequest{ - Hash: txHash, - }, - ) + resp, err := txClient.GetTx(ctx, &sdktx.GetTxRequest{Hash: txHash}) if err == nil { if resp.TxResponse.Code != 0 { - return resp.TxResponse, fmt.Errorf("tx failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) + s.updateNetworkSequence(txHash, false) + return resp.TxResponse, fmt.Errorf("tx was included but failed with code %d: %s", resp.TxResponse.Code, resp.TxResponse.RawLog) } + s.updateNetworkSequence(txHash, true) return resp.TxResponse, nil } - + // FIXME: this is a relatively brittle of working out whether to retry or not. The tx might be not found for other + // reasons. It may have been removed from the mempool at a later point. We should build an endpoint that gives the + // signer more information on the status of their transaction and then update the logic here if !strings.Contains(err.Error(), "not found") { return &sdktypes.TxResponse{}, err } - timer.Reset(s.pollTime) + timer.Reset(pollTime) } } } @@ -247,7 +394,7 @@ func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...T return 0, err } - if err := s.signTransaction(txBuilder, s.Sequence()); err != nil { + if err := s.signTransaction(txBuilder, s.LocalSequence()); err != nil { return 0, err } @@ -256,7 +403,7 @@ func (s *Signer) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...T return 0, err } - resp, err := tx.NewServiceClient(s.grpc).Simulate(ctx, &tx.SimulateRequest{ + resp, err := sdktx.NewServiceClient(s.grpc).Simulate(ctx, &sdktx.SimulateRequest{ TxBytes: txBytes, }) if err != nil { @@ -288,37 +435,67 @@ func (s *Signer) SetPollTime(pollTime time.Duration) { s.pollTime = pollTime } +func (s *Signer) getPollTime() time.Duration { + s.mtx.Lock() + defer s.mtx.Unlock() + return s.pollTime +} + // PubKey returns the public key of the signer func (s *Signer) PubKey() cryptotypes.PubKey { return s.pk } -func (s *Signer) Sequence() uint64 { - s.mtx.Lock() - defer s.mtx.Unlock() - return s.lastSignedSequence -} - // DEPRECATED: use Sequence instead func (s *Signer) GetSequence() uint64 { return s.getAndIncrementSequence() } -// getAndIncrementSequence gets the lastest signed sequnce and increments the local sequence number +// LocalSequence returns the next sequence number of the signers +// locally saved +func (s *Signer) LocalSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.localSequence +} + +func (s *Signer) NetworkSequence() uint64 { + s.mtx.RLock() + defer s.mtx.RUnlock() + return s.networkSequence +} + +// getAndIncrementSequence gets the latest signed sequence and increments the +// local sequence number func (s *Signer) getAndIncrementSequence() uint64 { + defer func() { s.localSequence++ }() + return s.localSequence +} + +// ForceSetSequence manually overrides the current local and network level +// sequence number. Be careful when invoking this as it may cause the +// transactions to reject the sequence if it doesn't match the one in state +func (s *Signer) ForceSetSequence(seq uint64) { s.mtx.Lock() defer s.mtx.Unlock() - defer func() { s.lastSignedSequence++ }() - return s.lastSignedSequence + s.localSequence = seq + s.networkSequence = seq } -// ForceSetSequence manually overrides the current sequence number. Be careful when -// invoking this as it may cause the transactions to reject the sequence if -// it doesn't match the one in state -func (s *Signer) ForceSetSequence(seq uint64) { +// updateNetworkSequence is called once a transaction is confirmed +// and updates the chains last known sequence number +func (s *Signer) updateNetworkSequence(txHash string, success bool) { s.mtx.Lock() defer s.mtx.Unlock() - s.lastSignedSequence = seq + sequence, exists := s.reverseTxHashSequenceMap[txHash] + if !exists { + return + } + if success && sequence >= s.networkSequence { + s.networkSequence = sequence + 1 + } + delete(s.outboundSequences, sequence) + delete(s.reverseTxHashSequenceMap, txHash) } // Keyring exposes the signers underlying keyring @@ -426,3 +603,15 @@ func (s *Signer) getSignatureV2(sequence uint64, signature []byte) signing.Signa } return sigV2 } + +func getSequenceNumber(tx authsigning.Tx) (uint64, error) { + sigs, err := tx.GetSignaturesV2() + if err != nil { + return 0, err + } + if len(sigs) > 1 { + return 0, fmt.Errorf("only a signle signature is supported, got %d", len(sigs)) + } + + return sigs[0].Sequence, nil +} diff --git a/pkg/user/signer_test.go b/pkg/user/signer_test.go index ddea206ea7..ade88b3a00 100644 --- a/pkg/user/signer_test.go +++ b/pkg/user/signer_test.go @@ -14,7 +14,6 @@ import ( "github.com/celestiaorg/celestia-app/test/util/testnode" sdk "github.com/cosmos/cosmos-sdk/types" bank "github.com/cosmos/cosmos-sdk/x/bank/types" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" abci "github.com/tendermint/tendermint/abci/types" @@ -78,28 +77,30 @@ func (s *SignerTestSuite) TestConfirmTx() { gas := user.SetGasLimit(1e6) t.Run("deadline exceeded when the context times out", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), time.Second) defer cancel() _, err := s.signer.ConfirmTx(ctx, "E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") - assert.Error(t, err) - assert.Contains(t, err.Error(), context.DeadlineExceeded.Error()) + require.Error(t, err) + require.Contains(t, err.Error(), context.DeadlineExceeded.Error()) }) t.Run("should error when tx is not found", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 5*time.Second) defer cancel() _, err := s.signer.ConfirmTx(ctx, "not found tx") - assert.Error(t, err) + require.Error(t, err) }) t.Run("should success when tx is found immediately", func(t *testing.T) { msg := bank.NewMsgSend(s.signer.Address(), testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) resp, err := s.submitTxWithoutConfirm([]sdk.Msg{msg}, fee, gas) - assert.NoError(t, err) - assert.NotNil(t, resp) - resp, err = s.signer.ConfirmTx(s.ctx.GoContext(), resp.TxHash) - assert.NoError(t, err) - assert.Equal(t, abci.CodeTypeOK, resp.Code) + require.NoError(t, err) + require.NotNil(t, resp) + ctx, cancel := context.WithTimeout(s.ctx.GoContext(), 30*time.Second) + defer cancel() + resp, err = s.signer.ConfirmTx(ctx, resp.TxHash) + require.NoError(t, err) + require.Equal(t, abci.CodeTypeOK, resp.Code) }) t.Run("should error when tx is found with a non-zero error code", func(t *testing.T) { @@ -107,17 +108,17 @@ func (s *SignerTestSuite) TestConfirmTx() { // Create a msg send with out of balance, ensure this tx fails msg := bank.NewMsgSend(s.signer.Address(), testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 1+balance))) resp, err := s.submitTxWithoutConfirm([]sdk.Msg{msg}, fee, gas) - assert.NoError(t, err) - assert.NotNil(t, resp) + require.NoError(t, err) + require.NotNil(t, resp) resp, err = s.signer.ConfirmTx(s.ctx.GoContext(), resp.TxHash) - assert.Error(t, err) - assert.NotEqual(t, abci.CodeTypeOK, resp.Code) + require.Error(t, err) + require.NotEqual(t, abci.CodeTypeOK, resp.Code) }) } func (s *SignerTestSuite) TestGasEstimation() { msg := bank.NewMsgSend(s.signer.Address(), testfactory.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) - gas, err := s.signer.EstimateGas(context.Background(), []sdk.Msg{msg}) + gas, err := s.signer.EstimateGas(s.ctx.GoContext(), []sdk.Msg{msg}) require.NoError(s.T(), err) require.Greater(s.T(), gas, uint64(0)) } @@ -148,13 +149,13 @@ func (s *SignerTestSuite) TestGasConsumption() { // verify that the amount deducted depends on the fee set in the tx. amountDeducted := balanceBefore - balanceAfter - utiaToSend - assert.Equal(t, int64(fee), amountDeducted) + require.Equal(t, int64(fee), amountDeducted) // verify that the amount deducted does not depend on the actual gas used. gasUsedBasedDeduction := resp.GasUsed * gasPrice - assert.NotEqual(t, gasUsedBasedDeduction, amountDeducted) + require.NotEqual(t, gasUsedBasedDeduction, amountDeducted) // The gas used based deduction should be less than the fee because the fee is 1 TIA. - assert.Less(t, gasUsedBasedDeduction, int64(fee)) + require.Less(t, gasUsedBasedDeduction, int64(fee)) } func (s *SignerTestSuite) queryCurrentBalance(t *testing.T) int64 { diff --git a/test/util/direct_tx_gen.go b/test/util/direct_tx_gen.go index 6340ce52f6..0e6d4dcb17 100644 --- a/test/util/direct_tx_gen.go +++ b/test/util/direct_tx_gen.go @@ -111,7 +111,7 @@ func DirectQueryAccount(app *app.App, addr sdk.AccAddress) authtypes.AccountI { // provided configuration. One blob transaction is generated per account // provided. The sequence and account numbers are set manually using the provided values. func RandBlobTxsWithManualSequence( - _ *testing.T, + t *testing.T, _ sdk.TxEncoder, kr keyring.Keyring, size int, @@ -172,25 +172,19 @@ func RandBlobTxsWithManualSequence( } if invalidSignature { invalidSig, err := builder.GetTx().GetSignaturesV2() - if err != nil { - panic(err) - } + require.NoError(t, err) invalidSig[0].Data.(*signing.SingleSignatureData).Signature = []byte("invalid signature") - if err := builder.SetSignatures(invalidSig...); err != nil { - panic(err) - } + err = builder.SetSignatures(invalidSig...) + require.NoError(t, err) stx = builder.GetTx() } rawTx, err := signer.EncodeTx(stx) - if err != nil { - panic(err) - } + require.NoError(t, err) + cTx, err := coretypes.MarshalBlobTx(rawTx, blobs...) - if err != nil { - panic(err) - } + require.NoError(t, err) txs[i] = cTx }