diff --git a/blockchain/validate.go b/blockchain/validate.go index d0dbf6b4dc..438f455428 100644 --- a/blockchain/validate.go +++ b/blockchain/validate.go @@ -5,6 +5,7 @@ package blockchain import ( + "bytes" "encoding/binary" "fmt" "math" @@ -41,6 +42,10 @@ const ( // baseSubsidy is the starting subsidy amount for mined blocks. This // value is halved every SubsidyHalvingInterval blocks. baseSubsidy = 50 * btcutil.SatoshiPerBitcoin + + // coinbaseHeightAllocSize is the amount of bytes that the + // ScriptBuilder will allocate when validating the coinbase height. + coinbaseHeightAllocSize = 5 ) var ( @@ -610,16 +615,25 @@ func ExtractCoinbaseHeight(coinbaseTx *btcutil.Tx) (int32, error) { return 0, ruleError(ErrMissingCoinbaseHeight, str) } - serializedHeightBytes := make([]byte, 8) - copy(serializedHeightBytes, sigScript[1:serializedLen+1]) - serializedHeight := binary.LittleEndian.Uint64(serializedHeightBytes) + // We use 4 bytes here since it saves us allocations. We use a stack + // allocation rather than a heap allocation here. + var serializedHeightBytes [4]byte + copy(serializedHeightBytes[:], sigScript[1:serializedLen+1]) + + serializedHeight := int32( + binary.LittleEndian.Uint32(serializedHeightBytes[:]), + ) + + if err := compareScript(serializedHeight, sigScript); err != nil { + return 0, err + } - return int32(serializedHeight), nil + return serializedHeight, nil } -// checkSerializedHeight checks if the signature script in the passed +// CheckSerializedHeight checks if the signature script in the passed // transaction starts with the serialized block height of wantHeight. -func checkSerializedHeight(coinbaseTx *btcutil.Tx, wantHeight int32) error { +func CheckSerializedHeight(coinbaseTx *btcutil.Tx, wantHeight int32) error { serializedHeight, err := ExtractCoinbaseHeight(coinbaseTx) if err != nil { return err @@ -634,6 +648,26 @@ func checkSerializedHeight(coinbaseTx *btcutil.Tx, wantHeight int32) error { return nil } +func compareScript(height int32, script []byte) error { + scriptBuilder := txscript.NewScriptBuilder( + txscript.WithScriptAllocSize(coinbaseHeightAllocSize), + ) + scriptHeight, err := scriptBuilder.AddInt64( + int64(height), + ).Script() + if err != nil { + return err + } + + if !bytes.HasPrefix(script, scriptHeight) { + str := fmt.Sprintf("the coinbase signature script does not "+ + "minimally encode the height %d", height) + return ruleError(ErrBadCoinbaseHeight, str) + } + + return nil +} + // CheckBlockHeaderContext performs several validation checks on the block header // which depend on its position within the block chain. // @@ -787,7 +821,7 @@ func (b *BlockChain) checkBlockContext(block *btcutil.Block, prevNode *blockNode blockHeight >= b.chainParams.BIP0034Height { coinbaseTx := block.Transactions()[0] - err := checkSerializedHeight(coinbaseTx, blockHeight) + err := CheckSerializedHeight(coinbaseTx, blockHeight) if err != nil { return err } diff --git a/blockchain/validate_test.go b/blockchain/validate_test.go index 1963a41590..ddd59130c1 100644 --- a/blockchain/validate_test.go +++ b/blockchain/validate_test.go @@ -169,7 +169,7 @@ func TestCheckBlockSanity(t *testing.T) { } } -// TestCheckSerializedHeight tests the checkSerializedHeight function with +// TestCheckSerializedHeight tests the CheckSerializedHeight function with // various serialized heights and also does negative tests to ensure errors // and handled properly. func TestCheckSerializedHeight(t *testing.T) { @@ -215,9 +215,9 @@ func TestCheckSerializedHeight(t *testing.T) { msgTx.TxIn[0].SignatureScript = test.sigScript tx := btcutil.NewTx(msgTx) - err := checkSerializedHeight(tx, test.wantHeight) + err := CheckSerializedHeight(tx, test.wantHeight) if reflect.TypeOf(err) != reflect.TypeOf(test.err) { - t.Errorf("checkSerializedHeight #%d wrong error type "+ + t.Errorf("CheckSerializedHeight #%d wrong error type "+ "got: %v <%T>, want: %T", i, err, err, test.err) continue } @@ -225,7 +225,7 @@ func TestCheckSerializedHeight(t *testing.T) { if rerr, ok := err.(RuleError); ok { trerr := test.err.(RuleError) if rerr.ErrorCode != trerr.ErrorCode { - t.Errorf("checkSerializedHeight #%d wrong "+ + t.Errorf("CheckSerializedHeight #%d wrong "+ "error code got: %v, want: %v", i, rerr.ErrorCode, trerr.ErrorCode) continue