diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b58cd0d8749..e7f7a1fe351e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,21 +80,27 @@ jobs: runs-on: ubuntu-latest if: github.ref == 'refs/heads/dev-upgrade' && !startsWith(github.ref, 'refs/tags/') needs: tests + outputs: + output1: ${{ steps.docker.outputs.image_name }} steps: - uses: actions/checkout@v4 - name: Login to Docker Hub run: echo ${{ secrets.DOCKER_PASSWORD }} | docker login -u ${{ secrets.DOCKER_USERNAME }} --password-stdin - name: Build and Push Docker images + id: docker run: | git_hash=$(git rev-parse --short "$GITHUB_SHA") + image_name=xinfinorg/devnet:dev-upgrade-${git_hash} docker pull xinfinorg/devnet:latest docker tag xinfinorg/devnet:latest xinfinorg/devnet:previous docker rmi xinfinorg/devnet:latest docker build -t xinfinorg/devnet:latest -f cicd/Dockerfile . - docker tag xinfinorg/devnet:latest xinfinorg/devnet:dev-upgrade-${git_hash} - docker push xinfinorg/devnet:dev-upgrade-${git_hash} + docker tag xinfinorg/devnet:latest $image_name + docker push $image_name docker push xinfinorg/devnet:latest docker push xinfinorg/devnet:previous + echo "image_name=$image_name" + echo "image_name=$image_name" >> "$GITHUB_OUTPUT" devnet_terraform_apply: runs-on: ubuntu-latest @@ -142,6 +148,17 @@ jobs: terraform init ${{ env.tf_init_cli_options }} terraform apply ${{ env.tf_apply_cli_options }} + - name: Update RPC nodes image + uses: dawidd6/action-ansible-playbook@v2 + with: + playbook: playbooks/update-image.yaml + directory: ./cicd/ansible + key: ${{secrets.SSH_PRIVATE_KEY_DEVNET}} + options: | + --inventory inventory.yaml + --extra-vars network=ec2_rpcs + --extra-vars rpc_image=${{ needs.devnet_build_push.outputs.output1 }} + devnet_dev-upgrade_node: runs-on: ubuntu-latest if: github.ref == 'refs/heads/dev-upgrade' && !startsWith(github.ref, 'refs/tags/') diff --git a/.github/workflows/deploy_rpc_image.yml b/.github/workflows/deploy_rpc_image.yml new file mode 100644 index 000000000000..ebac65d963a5 --- /dev/null +++ b/.github/workflows/deploy_rpc_image.yml @@ -0,0 +1,41 @@ +name: Deploy RPC Image +on: + #need to make sure only authorized people can use this function + workflow_dispatch: + inputs: + network: + type: choice + description: 'devnet, testnet, or mainnet' + options: + - devnet + - testnet + - mainnet + rpc_image: + description: 'full image name' + +jobs: + ansible: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Update RPC nodes image + uses: dawidd6/action-ansible-playbook@v2 + with: + playbook: playbooks/update-image.yaml + directory: ./cicd/ansible + key: ${{secrets.SSH_PRIVATE_KEY_DEVNET}} + options: | + --inventory inventory.yaml + --extra-vars network=${{inputs.network}} + --extra-vars rpc_image=${{inputs.rpc_image}} + + devnet_send_notification: + runs-on: ubuntu-latest + needs: ansible + steps: + - uses: actions/checkout@v4 + - name: Send deployment notification + run: | + curl --location --request POST "66.94.98.186:8080/deploy?environment=${{inputs.network}}&service=xdc_rpc&version=${{inputs.rpc_image}}" + \ No newline at end of file diff --git a/XDCx/XDCx.go b/XDCx/XDCx.go index bae3d00374c0..6506929984d7 100644 --- a/XDCx/XDCx.go +++ b/XDCx/XDCx.go @@ -9,14 +9,13 @@ import ( "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" "github.com/XinFinOrg/XDPoSChain/XDCxDAO" - "github.com/XinFinOrg/XDPoSChain/consensus" - "github.com/XinFinOrg/XDPoSChain/core/types" - "github.com/XinFinOrg/XDPoSChain/p2p" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" - "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" + "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/core/state" + "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/p2p" "github.com/XinFinOrg/XDPoSChain/rpc" lru "github.com/hashicorp/golang-lru" "golang.org/x/sync/syncmap" @@ -105,7 +104,7 @@ func New(cfg *Config) *XDCX { } XDCX := &XDCX{ orderNonce: make(map[common.Address]*big.Int), - Triegc: prque.New(), + Triegc: prque.New(nil), tokenDecimalCache: tokenDecimalCache, orderCache: orderCache, } diff --git a/XDCxlending/XDCxlending.go b/XDCxlending/XDCxlending.go index 6818b375a20d..352c224d0a20 100644 --- a/XDCxlending/XDCxlending.go +++ b/XDCxlending/XDCxlending.go @@ -12,14 +12,13 @@ import ( "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" "github.com/XinFinOrg/XDPoSChain/XDCxDAO" "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" - "github.com/XinFinOrg/XDPoSChain/consensus" - "github.com/XinFinOrg/XDPoSChain/core/types" - "github.com/XinFinOrg/XDPoSChain/p2p" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" - "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" + "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/core/state" + "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/p2p" "github.com/XinFinOrg/XDPoSChain/rpc" lru "github.com/hashicorp/golang-lru" ) @@ -67,7 +66,7 @@ func New(XDCx *XDCx.XDCX) *Lending { lendingTradeCache, _ := lru.New(defaultCacheLimit) lending := &Lending{ orderNonce: make(map[common.Address]*big.Int), - Triegc: prque.New(), + Triegc: prque.New(nil), lendingItemHistory: itemCache, lendingTradeHistory: lendingTradeCache, } diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go index 204fe1634768..0df4bd7450c1 100644 --- a/accounts/abi/bind/backends/simulated.go +++ b/accounts/abi/bind/backends/simulated.go @@ -25,11 +25,9 @@ import ( "sync" "time" + "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/XDCx" "github.com/XinFinOrg/XDPoSChain/XDCxlending" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" - - "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/accounts" "github.com/XinFinOrg/XDPoSChain/accounts/abi/bind" "github.com/XinFinOrg/XDPoSChain/accounts/keystore" @@ -37,10 +35,10 @@ import ( "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS/utils" - "github.com/XinFinOrg/XDPoSChain/consensus/ethash" "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/core/bloombits" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" @@ -365,7 +363,7 @@ func (b *SimulatedBackend) callContract(ctx context.Context, call XDPoSChain.Cal from := statedb.GetOrNewStateObject(call.From) from.SetBalance(math.MaxBig256) // Execute the call. - msg := callmsg{call} + msg := callMsg{call} feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) if msg.To() != nil { if value, ok := feeCapacity[*msg.To()]; ok { @@ -388,7 +386,10 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa b.mu.Lock() defer b.mu.Unlock() - sender, err := types.Sender(types.HomesteadSigner{}, tx) + // Check transaction validity. + block := b.blockchain.CurrentBlock() + signer := types.MakeSigner(b.blockchain.Config(), block.Number()) + sender, err := types.Sender(signer, tx) if err != nil { panic(fmt.Errorf("invalid transaction: %v", err)) } @@ -397,7 +398,8 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa panic(fmt.Errorf("invalid transaction nonce: got %d, want %d", tx.Nonce(), nonce)) } - blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), b.blockchain.Engine(), b.database, 1, func(number int, block *core.BlockGen) { + // Include tx in chain. + blocks, _ := core.GenerateChain(b.config, block, b.blockchain.Engine(), b.database, 1, func(number int, block *core.BlockGen) { for _, tx := range b.pendingBlock.Transactions() { block.AddTxWithChain(b.blockchain, tx) } @@ -501,20 +503,21 @@ func (b *SimulatedBackend) GetBlockChain() *core.BlockChain { return b.blockchain } -// callmsg implements core.Message to allow passing it as a transaction simulator. -type callmsg struct { +// callMsg implements core.Message to allow passing it as a transaction simulator. +type callMsg struct { XDPoSChain.CallMsg } -func (m callmsg) From() common.Address { return m.CallMsg.From } -func (m callmsg) Nonce() uint64 { return 0 } -func (m callmsg) CheckNonce() bool { return false } -func (m callmsg) To() *common.Address { return m.CallMsg.To } -func (m callmsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callmsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callmsg) Value() *big.Int { return m.CallMsg.Value } -func (m callmsg) Data() []byte { return m.CallMsg.Data } -func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) From() common.Address { return m.CallMsg.From } +func (m callMsg) Nonce() uint64 { return 0 } +func (m callMsg) CheckNonce() bool { return false } +func (m callMsg) To() *common.Address { return m.CallMsg.To } +func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } +func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } +func (m callMsg) Value() *big.Int { return m.CallMsg.Value } +func (m callMsg) Data() []byte { return m.CallMsg.Data } +func (m callMsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } // filterBackend implements filters.Backend to support filtering for logs without // taking bloom-bits acceleration structures into account. @@ -553,7 +556,7 @@ func (fb *filterBackend) GetLogs(ctx context.Context, hash common.Hash) ([][]*ty return logs, nil } -func (fb *filterBackend) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { +func (fb *filterBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { return event.NewSubscription(func(quit <-chan struct{}) error { <-quit return nil diff --git a/accounts/accounts.go b/accounts/accounts.go index ba575779259c..157112c8c150 100644 --- a/accounts/accounts.go +++ b/accounts/accounts.go @@ -18,12 +18,14 @@ package accounts import ( + "fmt" "math/big" ethereum "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/event" + "golang.org/x/crypto/sha3" ) // Account represents an Ethereum account located at a specific location defined @@ -148,6 +150,34 @@ type Backend interface { Subscribe(sink chan<- WalletEvent) event.Subscription } +// TextHash is a helper function that calculates a hash for the given message that can be +// safely used to calculate a signature from. +// +// The hash is calulcated as +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// This gives context to the signed message and prevents signing of transactions. +func TextHash(data []byte) []byte { + hash, _ := TextAndHash(data) + return hash +} + +// TextAndHash is a helper function that calculates a hash for the given message that can be +// safely used to calculate a signature from. +// +// The hash is calulcated as +// +// keccak256("\x19Ethereum Signed Message:\n"${message length}${message}). +// +// This gives context to the signed message and prevents signing of transactions. +func TextAndHash(data []byte) ([]byte, string) { + msg := fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(data), string(data)) + hasher := sha3.NewLegacyKeccak256() + hasher.Write([]byte(msg)) + return hasher.Sum(nil), msg +} + // WalletEventType represents the different event types that can be fired by // the wallet subscription subsystem. type WalletEventType int diff --git a/accounts/keystore/keystore.go b/accounts/keystore/keystore.go index 93854794c4e4..45a3be3036d7 100644 --- a/accounts/keystore/keystore.go +++ b/accounts/keystore/keystore.go @@ -288,11 +288,9 @@ func (ks *KeyStore) SignTx(a accounts.Account, tx *types.Transaction, chainID *b if !found { return nil, ErrLocked } - // Depending on the presence of the chain ID, sign with EIP155 or homestead - if chainID != nil { - return types.SignTx(tx, types.NewEIP155Signer(chainID), unlockedKey.PrivateKey) - } - return types.SignTx(tx, types.HomesteadSigner{}, unlockedKey.PrivateKey) + // Depending on the presence of the chain ID, sign with 2718 or homestead + signer := types.LatestSignerForChainID(chainID) + return types.SignTx(tx, signer, unlockedKey.PrivateKey) } // SignHashWithPassphrase signs hash if the private key matching the given address @@ -316,11 +314,9 @@ func (ks *KeyStore) SignTxWithPassphrase(a accounts.Account, passphrase string, } defer zeroKey(key.PrivateKey) - // Depending on the presence of the chain ID, sign with EIP155 or homestead - if chainID != nil { - return types.SignTx(tx, types.NewEIP155Signer(chainID), key.PrivateKey) - } - return types.SignTx(tx, types.HomesteadSigner{}, key.PrivateKey) + // Depending on the presence of the chain ID, sign with or without replay protection. + signer := types.LatestSignerForChainID(chainID) + return types.SignTx(tx, signer, key.PrivateKey) } // Unlock unlocks the given account indefinitely. diff --git a/accounts/usbwallet/trezor.go b/accounts/usbwallet/trezor.go index 551b69e224db..04e57292bafe 100644 --- a/accounts/usbwallet/trezor.go +++ b/accounts/usbwallet/trezor.go @@ -27,13 +27,13 @@ import ( "io" "math/big" - "github.com/golang/protobuf/proto" "github.com/XinFinOrg/XDPoSChain/accounts" "github.com/XinFinOrg/XDPoSChain/accounts/usbwallet/internal/trezor" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/log" + "github.com/golang/protobuf/proto" ) // ErrTrezorPINNeeded is returned if opening the trezor requires a PIN code. In @@ -80,13 +80,13 @@ func (w *trezorDriver) Status() (string, error) { // Open implements usbwallet.driver, attempting to initialize the connection to // the Trezor hardware wallet. Initializing the Trezor is a two phase operation: -// * The first phase is to initialize the connection and read the wallet's -// features. This phase is invoked is the provided passphrase is empty. The -// device will display the pinpad as a result and will return an appropriate -// error to notify the user that a second open phase is needed. -// * The second phase is to unlock access to the Trezor, which is done by the -// user actually providing a passphrase mapping a keyboard keypad to the pin -// number of the user (shuffled according to the pinpad displayed). +// - The first phase is to initialize the connection and read the wallet's +// features. This phase is invoked is the provided passphrase is empty. The +// device will display the pinpad as a result and will return an appropriate +// error to notify the user that a second open phase is needed. +// - The second phase is to unlock access to the Trezor, which is done by the +// user actually providing a passphrase mapping a keyboard keypad to the pin +// number of the user (shuffled according to the pinpad displayed). func (w *trezorDriver) Open(device io.ReadWriter, passphrase string) error { w.device, w.failure = device, nil @@ -220,9 +220,11 @@ func (w *trezorDriver) trezorSign(derivationPath []uint32, tx *types.Transaction if chainID == nil { signer = new(types.HomesteadSigner) } else { + // Trezor backend does not support typed transactions yet. signer = types.NewEIP155Signer(chainID) signature[64] = signature[64] - byte(chainID.Uint64()*2+35) } + // Inject the final signature into the transaction and sanity check the sender signed, err := tx.WithSignature(signer, signature) if err != nil { diff --git a/cicd/ansible/inventory.yaml b/cicd/ansible/inventory.yaml new file mode 100644 index 000000000000..141604ff7c7d --- /dev/null +++ b/cicd/ansible/inventory.yaml @@ -0,0 +1,17 @@ +ec2_rpcs: + hosts: + devnet: + ansible_host: devnet.hashlabs.apothem.network + ansible_port: 22 + ansible_user: ec2-user + deploy_path: /work/XinFin-Node/devnet + testnet: + ansible_host: testnet.hashlabs.apothem.network + ansible_port: 22 + ansible_user: ec2-user + deploy_path: /work/XinFin-Node/testnet + mainnet: + ansible_host: mainnet.hashlabs.apothem.network + ansible_port: 22 + ansible_user: ec2-user + deploy_path: /work/XinFin-Node/mainnet \ No newline at end of file diff --git a/cicd/ansible/playbooks/update-image.yaml b/cicd/ansible/playbooks/update-image.yaml new file mode 100644 index 000000000000..96baac1fdc47 --- /dev/null +++ b/cicd/ansible/playbooks/update-image.yaml @@ -0,0 +1,15 @@ +--- +- name: Run Bash Script on Host + hosts: "{{ network }}" + become: true #sudo/root + + tasks: + - name: Update RPC image version + shell: | + export RPC_IMAGE={{ rpc_image }} + cd {{ deploy_path }} + ./docker-down.sh + ./docker-up-hash.sh + docker ps + register: output + - debug: var=output.stdout_lines \ No newline at end of file diff --git a/cicd/devnet/start.sh b/cicd/devnet/start.sh index d5f7e152f490..4c3a81ee5b53 100755 --- a/cicd/devnet/start.sh +++ b/cicd/devnet/start.sh @@ -81,5 +81,5 @@ XDC --ethstats ${netstats} --gcmode archive \ --rpcvhosts "*" --unlock "${wallet}" --password /work/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --debugdatadir /work/xdcchain \ ---enable-0x-prefix --ws --wsaddr=0.0.0.0 --wsport $ws_port \ +--ws --wsaddr=0.0.0.0 --wsport $ws_port \ --wsorigins "*" 2>&1 >>/work/xdcchain/xdc.log | tee -a /work/xdcchain/xdc.log diff --git a/cicd/mainnet/start.sh b/cicd/mainnet/start.sh index 35d53ff60188..35f11a5d3406 100755 --- a/cicd/mainnet/start.sh +++ b/cicd/mainnet/start.sh @@ -80,5 +80,5 @@ XDC --ethstats ${netstats} --gcmode archive \ --rpcvhosts "*" --unlock "${wallet}" --password /work/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --debugdatadir /work/xdcchain \ ---enable-0x-prefix --ws --wsaddr=0.0.0.0 --wsport $ws_port \ +--ws --wsaddr=0.0.0.0 --wsport $ws_port \ --wsorigins "*" 2>&1 >>/work/xdcchain/xdc.log | tee -a /work/xdcchain/xdc.log diff --git a/cicd/terraform/.env b/cicd/terraform/.env index 4eb6ca5a95c0..8a64c1d22446 100644 --- a/cicd/terraform/.env +++ b/cicd/terraform/.env @@ -10,4 +10,4 @@ eu_west_1_end=72 # Sydney ap_southeast_2_start=73 -ap_southeast_2_end=108 \ No newline at end of file +ap_southeast_2_end=108 diff --git a/cicd/terraform/main.tf b/cicd/terraform/main.tf index ccb6ce690e00..5a44d2238552 100644 --- a/cicd/terraform/main.tf +++ b/cicd/terraform/main.tf @@ -76,3 +76,48 @@ module "mainnet-rpc" { } } + +module "devnet_rpc" { + source = "./module/ec2_rpc" + network = "devnet" + vpc_id = local.vpc_id + aws_subnet_id = local.aws_subnet_id + ami_id = local.ami_id + instance_type = "t3.large" + ssh_key_name = local.ssh_key_name + rpc_image = local.rpc_image + + providers = { + aws = aws.ap-southeast-1 + } +} + +module "testnet_rpc" { + source = "./module/ec2_rpc" + network = "testnet" + vpc_id = local.vpc_id + aws_subnet_id = local.aws_subnet_id + ami_id = local.ami_id + instance_type = "t3.large" + ssh_key_name = local.ssh_key_name + rpc_image = local.rpc_image + + providers = { + aws = aws.ap-southeast-1 + } +} + +module "mainnet_rpc" { + source = "./module/ec2_rpc" + network = "mainnet" + vpc_id = local.vpc_id + aws_subnet_id = local.aws_subnet_id + ami_id = local.ami_id + instance_type = "t3.large" + ssh_key_name = local.ssh_key_name + rpc_image = local.rpc_image + + providers = { + aws = aws.ap-southeast-1 + } +} \ No newline at end of file diff --git a/cicd/terraform/module/ec2_rpc/main.tf b/cicd/terraform/module/ec2_rpc/main.tf new file mode 100644 index 000000000000..75535dd0ac81 --- /dev/null +++ b/cicd/terraform/module/ec2_rpc/main.tf @@ -0,0 +1,106 @@ +terraform { + required_providers { + aws = { + source = "hashicorp/aws" + version = "~> 5.13.1" + } + } +} +variable network { + type = string +} +variable vpc_id { + type = string +} +variable aws_subnet_id { + type = string +} +variable ami_id { + type = string +} +variable instance_type { + type = string +} +variable ssh_key_name { + type = string +} +variable rpc_image { + type = string +} + +resource "aws_security_group" "rpc_sg" { + name_prefix = "${var.network}_rpc_sg" + + ingress { + from_port = 22 + to_port = 22 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + + ingress { + from_port = 30303 + to_port = 30303 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + + ingress { + from_port = 8545 + to_port = 8545 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + + ingress { + from_port = 8555 + to_port = 8555 + protocol = "tcp" + cidr_blocks = ["0.0.0.0/0"] + } + + egress { + from_port = 0 + to_port = 0 + protocol = "-1" + cidr_blocks = ["0.0.0.0/0"] + } +} + +resource "aws_instance" "rpc_instance" { + instance_type = var.instance_type + ami = var.ami_id + tags = { + Name = var.network + } + key_name = var.ssh_key_name + vpc_security_group_ids = [aws_security_group.rpc_sg.id] + ebs_block_device { + device_name = "/dev/xvda" + volume_size = 500 + } + + + #below still need to remove git checkout {{branch}} after files merged to master + user_data = <<-EOF + #!/bin/bash + sudo yum update -y + sudo yum upgrade -y + sudo yum install git -y + sudo yum install docker -y + mkdir -p /root/.docker/cli-plugins + curl -SL https://github.com/docker/compose/releases/download/v2.25.0/docker-compose-linux-x86_64 -o /root/.docker/cli-plugins/docker-compose + sudo chmod +x /root/.docker/cli-plugins/docker-compose + echo checking compose version + docker compose version + sudo systemctl enable docker + sudo systemctl start docker + mkdir -p /work + cd /work + git clone https://github.com/XinFinOrg/XinFin-Node + cd /work/XinFin-Node/${var.network} + export RPC_IMAGE="${var.rpc_image}" + echo RPC_IMAGE=$RPC_IMAGE + ./docker-up-hash.sh + EOF +} \ No newline at end of file diff --git a/cicd/terraform/variables.tf b/cicd/terraform/variables.tf index 89d6945e6178..c5a1eb8970d0 100644 --- a/cicd/terraform/variables.tf +++ b/cicd/terraform/variables.tf @@ -34,3 +34,11 @@ locals { rpcTestnetNodeKeys = { "testnet-rpc1": local.predefinedNodesConfig["testnet-rpc1"]} // we hardcode the rpc to a single node for now rpcMainnetNodeKeys = { "mainnet-rpc1": local.predefinedNodesConfig["mainnet-rpc1"]} // we hardcode the rpc to a single node for now } + +locals { //ec2_rpc values + ami_id = "ami-097c4e1feeea169e5" + rpc_image = "xinfinorg/xdposchain:v2.2.0-beta1" + vpc_id = "vpc-20a06846" + aws_subnet_id = "subnet-4653ee20" + ssh_key_name = "devnetkey" +} \ No newline at end of file diff --git a/cicd/testnet/start.sh b/cicd/testnet/start.sh index 3c5b2234a560..d5f9a0f443fc 100755 --- a/cicd/testnet/start.sh +++ b/cicd/testnet/start.sh @@ -82,5 +82,5 @@ XDC --ethstats ${netstats} --gcmode archive \ --rpcvhosts "*" --unlock "${wallet}" --password /work/.pwd --mine \ --gasprice "1" --targetgaslimit "420000000" --verbosity ${log_level} \ --debugdatadir /work/xdcchain \ ---enable-0x-prefix --ws --wsaddr=0.0.0.0 --wsport $ws_port \ +--ws --wsaddr=0.0.0.0 --wsport $ws_port \ --wsorigins "*" 2>&1 >>/work/xdcchain/xdc.log | tee -a /work/xdcchain/xdc.log diff --git a/cmd/XDC/config.go b/cmd/XDC/config.go index 948e59f41756..76e03981efa5 100644 --- a/cmd/XDC/config.go +++ b/cmd/XDC/config.go @@ -164,8 +164,8 @@ func makeConfigNode(ctx *cli.Context) (*node.Node, XDCConfig) { common.TIPXDCXCancellationFee = common.TIPXDCXCancellationFeeTestnet } - if ctx.GlobalBool(utils.Enable0xPrefixFlag.Name) { - common.Enable0xPrefix = true + if ctx.GlobalBool(utils.EnableXDCPrefixFlag.Name) { + common.Enable0xPrefix = false } // Rewound diff --git a/cmd/XDC/consolecmd_test.go b/cmd/XDC/consolecmd_test.go index 5944a4482ef3..ecbd11762aee 100644 --- a/cmd/XDC/consolecmd_test.go +++ b/cmd/XDC/consolecmd_test.go @@ -38,7 +38,7 @@ const ( // Tests that a node embedded within a console can be started up properly and // then terminated by closing the input stream. func TestConsoleWelcome(t *testing.T) { - coinbase := "xdc8605cdbbdb6d264aa742e77020dcbc58fcdce182" + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" // Start a XDC console, make sure it's cleaned up and terminate the console XDC := runXDC(t, @@ -75,7 +75,7 @@ at block: 0 ({{niltime}}) // Tests that a console can be attached to a running node via various means. func TestIPCAttachWelcome(t *testing.T) { // Configure the instance for IPC attachement - coinbase := "xdc8605cdbbdb6d264aa742e77020dcbc58fcdce182" + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" var ipc string if runtime.GOOS == "windows" { ipc = `\\.\pipe\XDC` + strconv.Itoa(trulyRandInt(100000, 999999)) @@ -97,7 +97,7 @@ func TestIPCAttachWelcome(t *testing.T) { } func TestHTTPAttachWelcome(t *testing.T) { - coinbase := "xdc8605cdbbdb6d264aa742e77020dcbc58fcdce182" + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" port := strconv.Itoa(trulyRandInt(1024, 65536)) // Yeah, sometimes this will fail, sorry :P XDC := runXDC(t, "--XDCx.datadir", tmpdir(t)+"XDCx/"+time.Now().String(), @@ -112,7 +112,7 @@ func TestHTTPAttachWelcome(t *testing.T) { } func TestWSAttachWelcome(t *testing.T) { - coinbase := "xdc8605cdbbdb6d264aa742e77020dcbc58fcdce182" + coinbase := "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182" port := strconv.Itoa(trulyRandInt(1024, 65536)) // Yeah, sometimes this will fail, sorry :P XDC := runXDC(t, diff --git a/cmd/XDC/main.go b/cmd/XDC/main.go index 2117ddd5cee7..419978ba6e09 100644 --- a/cmd/XDC/main.go +++ b/cmd/XDC/main.go @@ -114,6 +114,7 @@ var ( //utils.VMEnableDebugFlag, utils.XDCTestnetFlag, utils.Enable0xPrefixFlag, + utils.EnableXDCPrefixFlag, utils.RewoundFlag, utils.NetworkIdFlag, utils.RPCCORSDomainFlag, diff --git a/cmd/evm/json_logger.go b/cmd/evm/json_logger.go index 90e44f9c4ae5..a5b8c0fea21d 100644 --- a/cmd/evm/json_logger.go +++ b/cmd/evm/json_logger.go @@ -33,15 +33,23 @@ type JSONLogger struct { } func NewJSONLogger(cfg *vm.LogConfig, writer io.Writer) *JSONLogger { - return &JSONLogger{json.NewEncoder(writer), cfg} + l := &JSONLogger{json.NewEncoder(writer), cfg} + if l.cfg == nil { + l.cfg = &vm.LogConfig{} + } + return l +} + +func (l *JSONLogger) CaptureStart(env *vm.EVM, from, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { } -func (l *JSONLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { - return nil +func (l *JSONLogger) CaptureFault(*vm.EVM, uint64, vm.OpCode, uint64, uint64, *vm.ScopeContext, int, error) { } // CaptureState outputs state information on the logger. -func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { +func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + memory := scope.Memory + stack := scope.Stack log := vm.StructLog{ Pc: pc, Op: op, @@ -63,24 +71,20 @@ func (l *JSONLogger) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cos } log.Stack = logstack } - return l.encoder.Encode(log) -} - -// CaptureFault outputs state information on the logger. -func (l *JSONLogger) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { - return nil + l.encoder.Encode(log) } // CaptureEnd is triggered at end of execution. -func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { type endLog struct { Output string `json:"output"` GasUsed math.HexOrDecimal64 `json:"gasUsed"` Time time.Duration `json:"time"` Err string `json:"error,omitempty"` } + var errMsg string if err != nil { - return l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, err.Error()}) + errMsg = err.Error() } - return l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, ""}) + l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) } diff --git a/cmd/puppeth/wizard_genesis.go b/cmd/puppeth/wizard_genesis.go index 92aa2197cbfd..aefac61cc69d 100644 --- a/cmd/puppeth/wizard_genesis.go +++ b/cmd/puppeth/wizard_genesis.go @@ -145,8 +145,8 @@ func (w *wizard) makeGenesis() { genesis.Config.XDPoS.V2.CurrentConfig.TimeoutSyncThreshold = w.readDefaultInt(3) fmt.Println() - fmt.Printf("How many v2 vote collection to generate a QC, should be two thirds of masternodes? (default = %f)\n", 0.666) - genesis.Config.XDPoS.V2.CurrentConfig.CertThreshold = w.readDefaultFloat(0.666) + fmt.Printf("Proportion of total masternodes v2 vote collection to generate a QC (float value), should be two thirds of masternodes? (default = %f)\n", 0.667) + genesis.Config.XDPoS.V2.CurrentConfig.CertThreshold = w.readDefaultFloat(0.667) genesis.Config.XDPoS.V2.AllConfigs[0] = genesis.Config.XDPoS.V2.CurrentConfig diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 62f8a4a62d51..a9fd25c4922a 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -117,7 +117,11 @@ var ( } Enable0xPrefixFlag = cli.BoolFlag{ Name: "enable-0x-prefix", - Usage: "Addres use 0x-prefix (default = false)", + Usage: "Addres use 0x-prefix (Deprecated: this is on by default, to use xdc prefix use --enable-xdc-prefix)", + } + EnableXDCPrefixFlag = cli.BoolFlag{ + Name: "enable-xdc-prefix", + Usage: "Addres use xdc-prefix (default = false)", } // General settings AnnounceTxsFlag = cli.BoolFlag{ @@ -354,6 +358,11 @@ var ( Name: "vmdebug", Usage: "Record information useful for VM and contract debugging", } + RPCGlobalGasCapFlag = cli.Uint64Flag{ + Name: "rpc.gascap", + Usage: "Sets a cap on gas that can be used in eth_call/estimateGas (0=infinite)", + Value: eth.DefaultConfig.RPCGasCap, + } // Logging and debug settings EthStatsURLFlag = cli.StringFlag{ Name: "ethstats", @@ -786,6 +795,10 @@ func setIPC(ctx *cli.Context, cfg *node.Config) { } } +func setPrefix(ctx *cli.Context, cfg *node.Config) { + checkExclusive(ctx, Enable0xPrefixFlag, EnableXDCPrefixFlag) +} + // MakeDatabaseHandles raises out the number of allowed file handles per process // for XDC and returns half of the allowance to assign to the database. func MakeDatabaseHandles() int { @@ -933,6 +946,7 @@ func SetNodeConfig(ctx *cli.Context, cfg *node.Config) { setHTTP(ctx, cfg) setWS(ctx, cfg) setNodeUserIdent(ctx, cfg) + setPrefix(ctx, cfg) switch { case ctx.GlobalIsSet(DataDirFlag.Name): @@ -1167,6 +1181,11 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) { // TODO(fjl): force-enable this in --dev mode cfg.EnablePreimageRecording = ctx.GlobalBool(VMEnableDebugFlag.Name) } + if cfg.RPCGasCap != 0 { + log.Info("Set global gas cap", "cap", cfg.RPCGasCap) + } else { + log.Info("Global gas cap disabled") + } if ctx.GlobalIsSet(StoreRewardFlag.Name) { common.StoreRewardFolder = filepath.Join(stack.DataDir(), "XDC", "rewards") if _, err := os.Stat(common.StoreRewardFolder); os.IsNotExist(err) { diff --git a/cmd/wnode/main.go b/cmd/wnode/main.go index 5441428e88d3..5fa29ab96c54 100644 --- a/cmd/wnode/main.go +++ b/cmd/wnode/main.go @@ -139,8 +139,8 @@ func processArgs() { } if *asymmetricMode && len(*argPub) > 0 { - pub = crypto.ToECDSAPub(common.FromHex(*argPub)) - if !isKeyValid(pub) { + var err error + if pub, err = crypto.UnmarshalPubkey(common.FromHex(*argPub)); err != nil { utils.Fatalf("invalid public key") } } @@ -337,9 +337,8 @@ func configureNode() { if b == nil { utils.Fatalf("Error: can not convert hexadecimal string") } - pub = crypto.ToECDSAPub(b) - if !isKeyValid(pub) { - utils.Fatalf("Error: invalid public key") + if pub, err = crypto.UnmarshalPubkey(b); err != nil { + utils.Fatalf("Error: invalid peer public key") } } } diff --git a/common/constants.go b/common/constants.go index b85cd5de2749..cbe997f115e8 100644 --- a/common/constants.go +++ b/common/constants.go @@ -45,6 +45,9 @@ var TIPXDCX = big.NewInt(38383838) var TIPXDCXLending = big.NewInt(38383838) var TIPXDCXCancellationFee = big.NewInt(38383838) var TIPXDCXCancellationFeeTestnet = big.NewInt(38383838) +var TIPXDCXMinerDisable = big.NewInt(88999999900) +var TIPXDCXReceiverDisable = big.NewInt(99999999999) +var Eip1559Block = big.NewInt(9999999999) var TIPXDCXDISABLE = big.NewInt(99999999900) var BerlinBlock = big.NewInt(76321000) // Target 19th June 2024 var LondonBlock = big.NewInt(76321000) // Target 19th June 2024 @@ -53,7 +56,7 @@ var ShanghaiBlock = big.NewInt(76321000) // Target 19th June 2024 var TIPXDCXTestnet = big.NewInt(38383838) var IsTestnet bool = false -var Enable0xPrefix bool = false +var Enable0xPrefix bool = true var StoreRewardFolder string var RollbackHash Hash var BasePrice = big.NewInt(1000000000000000000) // 1 diff --git a/common/constants/constants.go.devnet b/common/constants/constants.go.devnet index 81951a760cd8..d0a585cc9a76 100644 --- a/common/constants/constants.go.devnet +++ b/common/constants/constants.go.devnet @@ -45,15 +45,17 @@ var TIPXDCX = big.NewInt(225000) var TIPXDCXLending = big.NewInt(225000) var TIPXDCXCancellationFee = big.NewInt(225000) var TIPXDCXCancellationFeeTestnet = big.NewInt(225000) -var TIPXDCXDISABLE = big.NewInt(15894900) -var BerlinBlock = big.NewInt(9999999999) -var LondonBlock = big.NewInt(9999999999) -var MergeBlock = big.NewInt(9999999999) +var TIPXDCXMinerDisable = big.NewInt(15894900) +var TIPXDCXReceiverDisable = big.NewInt(18018000) +var BerlinBlock = big.NewInt(16832700) +var LondonBlock = big.NewInt(16832700) +var MergeBlock = big.NewInt(16832700) var ShanghaiBlock = big.NewInt(16832700) +var Eip1559Block = big.NewInt(9999999999) var TIPXDCXTestnet = big.NewInt(0) var IsTestnet bool = false -var Enable0xPrefix bool = false +var Enable0xPrefix bool = true var StoreRewardFolder string var RollbackHash Hash var BasePrice = big.NewInt(1000000000000000000) // 1 diff --git a/common/constants/constants.go.testnet b/common/constants/constants.go.testnet index a5bd410b0c2a..681f67ec2cbc 100644 --- a/common/constants/constants.go.testnet +++ b/common/constants/constants.go.testnet @@ -45,15 +45,17 @@ var TIPXDCX = big.NewInt(23779191) var TIPXDCXLending = big.NewInt(23779191) var TIPXDCXCancellationFee = big.NewInt(23779191) var TIPXDCXCancellationFeeTestnet = big.NewInt(23779191) -var TIPXDCXDISABLE = big.NewInt(61290000) // Target 31st March 2024 -var BerlinBlock = big.NewInt(9999999999) -var LondonBlock = big.NewInt(9999999999) -var MergeBlock = big.NewInt(9999999999) +var TIPXDCXMinerDisable = big.NewInt(61290000) // Target 31st March 2024 +var TIPXDCXReceiverDisable = big.NewInt(9999999999) +var BerlinBlock = big.NewInt(61290000) +var LondonBlock = big.NewInt(61290000) +var MergeBlock = big.NewInt(61290000) var ShanghaiBlock = big.NewInt(61290000) // Target 31st March 2024 +var Eip1559Block = big.NewInt(9999999999) var TIPXDCXTestnet = big.NewInt(23779191) var IsTestnet bool = false -var Enable0xPrefix bool = false +var Enable0xPrefix bool = true var StoreRewardFolder string var RollbackHash Hash var BasePrice = big.NewInt(1000000000000000000) // 1 diff --git a/consensus/XDPoS/XDPoS.go b/consensus/XDPoS/XDPoS.go index 7a213f5b6563..6d5477a1ee61 100644 --- a/consensus/XDPoS/XDPoS.go +++ b/consensus/XDPoS/XDPoS.go @@ -240,7 +240,7 @@ func (x *XDPoS) VerifyHeaders(chain consensus.ChainReader, headers []*types.Head func (x *XDPoS) VerifyUncles(chain consensus.ChainReader, block *types.Block) error { switch x.config.BlockConsensusVersion(block.Number(), block.Extra(), ExtraFieldCheck) { case params.ConsensusEngineVersion2: - return nil + return x.EngineV2.VerifyUncles(chain, block) default: // Default "v1" return x.EngineV1.VerifyUncles(chain, block) } @@ -457,7 +457,7 @@ func (x *XDPoS) GetSnapshot(chain consensus.ChainReader, header *types.Header) ( return &utils.PublicApiSnapshot{ Number: sp.Number, Hash: sp.Hash, - Signers: sp.GetMappedMasterNodes(), + Signers: sp.GetMappedCandidates(), }, err default: // Default "v1" sp, err := x.EngineV1.GetSnapshot(chain, header) diff --git a/consensus/XDPoS/engines/engine_v1/engine.go b/consensus/XDPoS/engines/engine_v1/engine.go index 7fce2e6a8676..7a862370c8e0 100644 --- a/consensus/XDPoS/engines/engine_v1/engine.go +++ b/consensus/XDPoS/engines/engine_v1/engine.go @@ -154,6 +154,7 @@ func (x *XDPoS_v1) verifyHeaderWithCache(chain consensus.ChainReader, header *ty // looking those up from the database. This is useful for concurrently verifying // a batch of new headers. func (x *XDPoS_v1) verifyHeader(chain consensus.ChainReader, header *types.Header, parents []*types.Header, fullVerify bool) error { + fullVerify = false // If we're running a engine faking, accept any block as valid if x.config.SkipV1Validation { return nil diff --git a/consensus/XDPoS/engines/engine_v2/engine.go b/consensus/XDPoS/engines/engine_v2/engine.go index fc2c46edd84f..9cc25d7fab76 100644 --- a/consensus/XDPoS/engines/engine_v2/engine.go +++ b/consensus/XDPoS/engines/engine_v2/engine.go @@ -2,6 +2,7 @@ package engine_v2 import ( "encoding/json" + "errors" "fmt" "math/big" "os" @@ -468,7 +469,7 @@ func (x *XDPoS_v2) IsAuthorisedAddress(chain consensus.ChainReader, header *type log.Error("[IsAuthorisedAddress] Can't get snapshot with at ", "number", header.Number, "hash", header.Hash().Hex(), "err", err) return false } - for _, mn := range snap.NextEpochMasterNodes { + for _, mn := range snap.NextEpochCandidates { if mn == address { return true } @@ -515,6 +516,15 @@ func (x *XDPoS_v2) UpdateMasternodes(chain consensus.ChainReader, header *types. return nil } +// VerifyUncles implements consensus.Engine, always returning an error for any +// uncles as this consensus mechanism doesn't permit uncles. +func (x *XDPoS_v2) VerifyUncles(chain consensus.ChainReader, block *types.Block) error { + if len(block.Uncles()) > 0 { + return errors.New("uncles not allowed in XDPoS_v2") + } + return nil +} + func (x *XDPoS_v2) VerifyHeader(chain consensus.ChainReader, header *types.Header, fullVerify bool) error { err := x.verifyHeader(chain, header, nil, fullVerify) if err != nil { @@ -611,9 +621,9 @@ func (x *XDPoS_v2) VerifyVoteMessage(chain consensus.ChainReader, vote *types.Vo verified, signer, err := x.verifyMsgSignature(types.VoteSigHash(&types.VoteForSign{ ProposedBlockInfo: vote.ProposedBlockInfo, GapNumber: vote.GapNumber, - }), vote.Signature, snapshot.NextEpochMasterNodes) + }), vote.Signature, snapshot.NextEpochCandidates) if err != nil { - for i, mn := range snapshot.NextEpochMasterNodes { + for i, mn := range snapshot.NextEpochCandidates { log.Warn("[VerifyVoteMessage] Master node list item", "index", i, "Master node", mn.Hex()) } log.Warn("[VerifyVoteMessage] Error while verifying vote message", "votedBlockNum", vote.ProposedBlockInfo.Number.Uint64(), "votedBlockHash", vote.ProposedBlockInfo.Hash.Hex(), "voteHash", vote.Hash(), "error", err.Error()) @@ -649,15 +659,15 @@ func (x *XDPoS_v2) VerifyTimeoutMessage(chain consensus.ChainReader, timeoutMsg log.Error("[VerifyTimeoutMessage] Fail to get snapshot when verifying timeout message!", "messageGapNumber", timeoutMsg.GapNumber, "err", err) return false, err } - if len(snap.NextEpochMasterNodes) == 0 { - log.Error("[VerifyTimeoutMessage] cannot find nextEpochMasterNodes from snapshot", "messageGapNumber", timeoutMsg.GapNumber) + if len(snap.NextEpochCandidates) == 0 { + log.Error("[VerifyTimeoutMessage] cannot find NextEpochCandidates from snapshot", "messageGapNumber", timeoutMsg.GapNumber) return false, fmt.Errorf("Empty master node lists from snapshot") } verified, signer, err := x.verifyMsgSignature(types.TimeoutSigHash(&types.TimeoutForSign{ Round: timeoutMsg.Round, GapNumber: timeoutMsg.GapNumber, - }), timeoutMsg.Signature, snap.NextEpochMasterNodes) + }), timeoutMsg.Signature, snap.NextEpochCandidates) if err != nil { log.Warn("[VerifyTimeoutMessage] cannot verify timeout signature", "err", err) @@ -1005,7 +1015,7 @@ func (x *XDPoS_v2) calcMasternodes(chain consensus.ChainReader, blockNum *big.In log.Error("[calcMasternodes] Adaptor v2 getSnapshot has error", "err", err) return nil, nil, err } - candidates := snap.NextEpochMasterNodes + candidates := snap.NextEpochCandidates if blockNum.Uint64() == x.config.V2.SwitchBlock.Uint64()+1 { log.Info("[calcMasternodes] examing first v2 block") diff --git a/consensus/XDPoS/engines/engine_v2/epochSwitch.go b/consensus/XDPoS/engines/engine_v2/epochSwitch.go index 5d726b41a9a1..981c46ff9a06 100644 --- a/consensus/XDPoS/engines/engine_v2/epochSwitch.go +++ b/consensus/XDPoS/engines/engine_v2/epochSwitch.go @@ -63,7 +63,7 @@ func (x *XDPoS_v2) getEpochSwitchInfo(chain consensus.ChainReader, header *types return nil, err } penalties := common.ExtractAddressFromBytes(h.Penalties) - candidates := snap.NextEpochMasterNodes + candidates := snap.NextEpochCandidates standbynodes := []common.Address{} if len(masternodes) != len(candidates) { standbynodes = candidates diff --git a/consensus/XDPoS/engines/engine_v2/snapshot.go b/consensus/XDPoS/engines/engine_v2/snapshot.go index 60b145c4793f..78ce0aff55c6 100644 --- a/consensus/XDPoS/engines/engine_v2/snapshot.go +++ b/consensus/XDPoS/engines/engine_v2/snapshot.go @@ -10,22 +10,22 @@ import ( ) // Snapshot is the state of the smart contract validator list -// The validator list is used on next epoch master nodes +// The validator list is used on next epoch candidates nodes // If we don't have the snapshot, then we have to trace back the gap block smart contract state which is very costly type SnapshotV2 struct { Number uint64 `json:"number"` // Block number where the snapshot was created Hash common.Hash `json:"hash"` // Block hash where the snapshot was created - // MasterNodes will get assigned on updateM1 - NextEpochMasterNodes []common.Address `json:"masterNodes"` // Set of authorized master nodes at this moment for next epoch + // candidates will get assigned on updateM1 + NextEpochCandidates []common.Address `json:"masterNodes"` // Set of authorized candidates nodes at this moment for next epoch } // create new snapshot for next epoch to use -func newSnapshot(number uint64, hash common.Hash, masternodes []common.Address) *SnapshotV2 { +func newSnapshot(number uint64, hash common.Hash, candidates []common.Address) *SnapshotV2 { snap := &SnapshotV2{ - Number: number, - Hash: hash, - NextEpochMasterNodes: masternodes, + Number: number, + Hash: hash, + NextEpochCandidates: candidates, } return snap } @@ -53,17 +53,17 @@ func storeSnapshot(s *SnapshotV2, db ethdb.Database) error { return db.Put(append([]byte("XDPoS-V2-"), s.Hash[:]...), blob) } -// retrieves master nodes list in map type -func (s *SnapshotV2) GetMappedMasterNodes() map[common.Address]struct{} { +// retrieves candidates nodes list in map type +func (s *SnapshotV2) GetMappedCandidates() map[common.Address]struct{} { ms := make(map[common.Address]struct{}) - for _, n := range s.NextEpochMasterNodes { + for _, n := range s.NextEpochCandidates { ms[n] = struct{}{} } return ms } -func (s *SnapshotV2) IsMasterNodes(address common.Address) bool { - for _, n := range s.NextEpochMasterNodes { +func (s *SnapshotV2) IsCandidates(address common.Address) bool { + for _, n := range s.NextEpochCandidates { if n.String() == address.String() { return true } diff --git a/consensus/XDPoS/engines/engine_v2/snapshot_test.go b/consensus/XDPoS/engines/engine_v2/snapshot_test.go index 6d7e27a80749..f70dcfb9f2b2 100644 --- a/consensus/XDPoS/engines/engine_v2/snapshot_test.go +++ b/consensus/XDPoS/engines/engine_v2/snapshot_test.go @@ -15,8 +15,8 @@ func TestGetMasterNodes(t *testing.T) { snap := newSnapshot(1, common.Hash{}, masterNodes) for _, address := range masterNodes { - if _, ok := snap.GetMappedMasterNodes()[address]; !ok { - t.Error("should get master node from map", address.Hex(), snap.GetMappedMasterNodes()) + if _, ok := snap.GetMappedCandidates()[address]; !ok { + t.Error("should get master node from map", address.Hex(), snap.GetMappedCandidates()) return } } diff --git a/consensus/XDPoS/engines/engine_v2/timeout.go b/consensus/XDPoS/engines/engine_v2/timeout.go index d53ed386ee16..39d8100b2402 100644 --- a/consensus/XDPoS/engines/engine_v2/timeout.go +++ b/consensus/XDPoS/engines/engine_v2/timeout.go @@ -97,7 +97,7 @@ func (x *XDPoS_v2) verifyTC(chain consensus.ChainReader, timeoutCert *types.Time log.Error("[verifyTC] Fail to get snapshot when verifying TC!", "TCGapNumber", timeoutCert.GapNumber) return fmt.Errorf("[verifyTC] Unable to get snapshot, %s", err) } - if snap == nil || len(snap.NextEpochMasterNodes) == 0 { + if snap == nil || len(snap.NextEpochCandidates) == 0 { log.Error("[verifyTC] Something wrong with the snapshot from gapNumber", "messageGapNumber", timeoutCert.GapNumber, "snapshot", snap) return fmt.Errorf("empty master node lists from snapshot") } @@ -135,7 +135,7 @@ func (x *XDPoS_v2) verifyTC(chain consensus.ChainReader, timeoutCert *types.Time for _, signature := range signatures { go func(sig types.Signature) { defer wg.Done() - verified, _, err := x.verifyMsgSignature(signedTimeoutObj, sig, snap.NextEpochMasterNodes) + verified, _, err := x.verifyMsgSignature(signedTimeoutObj, sig, snap.NextEpochCandidates) if err != nil || !verified { log.Error("[verifyTC] Error or verification failure", "Signature", sig, "Error", err) mutex.Lock() // Lock before accessing haveError diff --git a/consensus/XDPoS/engines/engine_v2/utils.go b/consensus/XDPoS/engines/engine_v2/utils.go index b15e2917ee1f..e8006951ea14 100644 --- a/consensus/XDPoS/engines/engine_v2/utils.go +++ b/consensus/XDPoS/engines/engine_v2/utils.go @@ -164,7 +164,7 @@ func (x *XDPoS_v2) GetSignersFromSnapshot(chain consensus.ChainReader, header *t if err != nil { return nil, err } - return snap.NextEpochMasterNodes, err + return snap.NextEpochCandidates, err } func (x *XDPoS_v2) CalculateMissingRounds(chain consensus.ChainReader, header *types.Header) (*utils.PublicApiMissedRoundsMetadata, error) { diff --git a/consensus/XDPoS/engines/engine_v2/vote.go b/consensus/XDPoS/engines/engine_v2/vote.go index dd4680ab0e71..1ec2d4b24e42 100644 --- a/consensus/XDPoS/engines/engine_v2/vote.go +++ b/consensus/XDPoS/engines/engine_v2/vote.go @@ -122,8 +122,20 @@ func (x *XDPoS_v2) verifyVotes(chain consensus.ChainReader, votes map[common.Has for h, vote := range votes { go func(hash common.Hash, v *types.Vote) { defer wg.Done() - if v.GetSigner() != emptySigner { - // verify before + signerAddress := v.GetSigner() + if signerAddress != emptySigner { + // verify that signer belongs to the final masternodes, we have not do so in previous steps + if len(masternodes) == 0 { + log.Error("[verifyVotes] empty masternode list detected when verifying message signatures") + } + for _, mn := range masternodes { + if mn == signerAddress { + return + } + } + // if signer does not belong to final masternodes, we remove the signer + v.SetSigner(emptySigner) + log.Debug("[verifyVotes] find a vote does not belong to final masternodes", "signer", signerAddress) return } signedVote := types.VoteSigHash(&types.VoteForSign{ diff --git a/consensus/XDPoS/utils/types.go b/consensus/XDPoS/utils/types.go index 4073fb522bd9..897e984b4811 100644 --- a/consensus/XDPoS/utils/types.go +++ b/consensus/XDPoS/utils/types.go @@ -7,11 +7,11 @@ import ( "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/consensus/clique" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) type Masternode struct { diff --git a/consensus/clique/clique.go b/consensus/clique/clique.go index 7aa20c062a01..9bae8fd7439c 100644 --- a/consensus/clique/clique.go +++ b/consensus/clique/clique.go @@ -20,6 +20,7 @@ package clique import ( "bytes" "errors" + "fmt" "math/big" "math/rand" "sync" @@ -316,6 +317,16 @@ func (c *Clique) verifyHeader(chain consensus.ChainReader, header *types.Header, return errInvalidDifficulty } } + + // Verify that the gas limit is <= 2^63-1 + if header.GasLimit > params.MaxGasLimit { + return fmt.Errorf("invalid gasLimit: have %v, max %v", header.GasLimit, params.MaxGasLimit) + } + + // Verify that the gasUsed is <= gasLimit + if header.GasUsed > header.GasLimit { + return fmt.Errorf("invalid gasUsed: have %d, gasLimit %d", header.GasUsed, header.GasLimit) + } // If all checks passed, validate any special fields for hard forks if err := misc.VerifyForkHashes(chain.Config(), header, false); err != nil { return err @@ -347,6 +358,11 @@ func (c *Clique) verifyCascadingFields(chain consensus.ChainReader, header *type if parent.Time.Uint64()+c.config.Period > header.Time.Uint64() { return ErrInvalidTimestamp } + // Verify that the gas limit remains within allowed bounds + if err := misc.VerifyGaslimit(parent.GasLimit, header.GasLimit); err != nil { + return err + } + // Retrieve the snapshot needed to verify this header and cache it snap, err := c.snapshot(chain, number-1, header.ParentHash, parents) if err != nil { diff --git a/consensus/misc/gaslimit.go b/consensus/misc/gaslimit.go new file mode 100644 index 000000000000..e164c7a2f6b6 --- /dev/null +++ b/consensus/misc/gaslimit.go @@ -0,0 +1,41 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package misc + +import ( + "fmt" + + "github.com/XinFinOrg/XDPoSChain/params" +) + +// VerifyGaslimit verifies the header gas limit according increase/decrease +// in relation to the parent gas limit. +func VerifyGaslimit(parentGasLimit, headerGasLimit uint64) error { + // Verify that the gas limit remains within allowed bounds + diff := int64(parentGasLimit) - int64(headerGasLimit) + if diff < 0 { + diff *= -1 + } + limit := parentGasLimit / params.GasLimitBoundDivisor + if uint64(diff) >= limit { + return fmt.Errorf("invalid gas limit: have %d, want %d +-= %d", headerGasLimit, parentGasLimit, limit-1) + } + if headerGasLimit < params.MinGasLimit { + return fmt.Errorf("invalid gas limit below %d", params.MinGasLimit) + } + return nil +} diff --git a/consensus/tests/engine_v1_tests/helper.go b/consensus/tests/engine_v1_tests/helper.go index 3b777e487b54..be64d0c050a2 100644 --- a/consensus/tests/engine_v1_tests/helper.go +++ b/consensus/tests/engine_v1_tests/helper.go @@ -163,7 +163,7 @@ func transferTx(t *testing.T, to common.Address, transferAmount int64) *types.Tr amount := big.NewInt(transferAmount) nonce := uint64(1) tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), voterKey) + signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), voterKey) if err != nil { t.Fatal(err) } @@ -183,7 +183,7 @@ func voteTX(gasLimit uint64, nonce uint64, addr string) (*types.Transaction, err to := common.HexToAddress(common.MasternodeVotingSMC) tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), voterKey) + signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), voterKey) if err != nil { return nil, err } @@ -411,7 +411,7 @@ func createBlockFromHeader(bc *BlockChain, customHeader *types.Header, txs []*ty // nonce := uint64(0) // to := common.HexToAddress("xdc35658f7b2a9e7701e65e7a654659eb1c481d1dc5") // tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) -// signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), acc4Key) +// signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), acc4Key) // if err != nil { // t.Fatal(err) // } diff --git a/consensus/tests/engine_v2_tests/helper.go b/consensus/tests/engine_v2_tests/helper.go index f11baf1b2d92..2856b55f9f86 100644 --- a/consensus/tests/engine_v2_tests/helper.go +++ b/consensus/tests/engine_v2_tests/helper.go @@ -108,7 +108,7 @@ func voteTX(gasLimit uint64, nonce uint64, addr string) (*types.Transaction, err to := common.HexToAddress(common.MasternodeVotingSMC) tx := types.NewTransaction(nonce, to, amount, gasLimit, gasPrice, data) - signedTX, err := types.SignTx(tx, types.NewEIP155Signer(big.NewInt(chainID)), voterKey) + signedTX, err := types.SignTx(tx, types.LatestSignerForChainID(big.NewInt(chainID)), voterKey) if err != nil { return nil, err } @@ -285,7 +285,7 @@ func getMultiCandidatesBackend(t *testing.T, chainConfig *params.ChainConfig, n func signingTxWithKey(header *types.Header, nonce uint64, privateKey *ecdsa.PrivateKey) (*types.Transaction, error) { tx := contracts.CreateTxSign(header.Number, header.Hash(), nonce, common.HexToAddress(common.BlockSigners)) - s := types.NewEIP155Signer(big.NewInt(chainID)) + s := types.LatestSignerForChainID(big.NewInt(chainID)) h := s.Hash(tx) sig, err := crypto.Sign(h[:], privateKey) if err != nil { @@ -300,7 +300,7 @@ func signingTxWithKey(header *types.Header, nonce uint64, privateKey *ecdsa.Priv func signingTxWithSignerFn(header *types.Header, nonce uint64, signer common.Address, signFn func(account accounts.Account, hash []byte) ([]byte, error)) (*types.Transaction, error) { tx := contracts.CreateTxSign(header.Number, header.Hash(), nonce, common.HexToAddress(common.BlockSigners)) - s := types.NewEIP155Signer(big.NewInt(chainID)) + s := types.LatestSignerForChainID(big.NewInt(chainID)) h := s.Hash(tx) sig, err := signFn(accounts.Account{Address: signer}, h[:]) if err != nil { diff --git a/consensus/tests/engine_v2_tests/mine_test.go b/consensus/tests/engine_v2_tests/mine_test.go index ad08fa98b3ef..9ce5957b122b 100644 --- a/consensus/tests/engine_v2_tests/mine_test.go +++ b/consensus/tests/engine_v2_tests/mine_test.go @@ -56,7 +56,7 @@ func TestYourTurnInitialV2(t *testing.T) { assert.NotNil(t, snap) masterNodes := adaptor.EngineV1.GetMasternodesFromCheckpointHeader(block900.Header()) for i := 0; i < len(masterNodes); i++ { - assert.Equal(t, masterNodes[i].Hex(), snap.NextEpochMasterNodes[i].Hex()) + assert.Equal(t, masterNodes[i].Hex(), snap.NextEpochCandidates[i].Hex()) } } @@ -136,7 +136,7 @@ func TestUpdateMasterNodes(t *testing.T) { snap, err = x.GetSnapshot(blockchain, parentBlock.Header()) assert.Nil(t, err) - assert.True(t, snap.IsMasterNodes(voterAddr)) + assert.True(t, snap.IsCandidates(voterAddr)) assert.Equal(t, int(snap.Number), 1350) } @@ -211,7 +211,7 @@ func TestPrepareHappyPath(t *testing.T) { } validators := []byte{} - for _, v := range snap.NextEpochMasterNodes { + for _, v := range snap.NextEpochCandidates { validators = append(validators, v[:]...) } assert.Equal(t, validators, header901.Validators) @@ -267,7 +267,7 @@ func TestUpdateMultipleMasterNodes(t *testing.T) { assert.Nil(t, err) assert.Equal(t, 1350, int(snap.Number)) - assert.Equal(t, 128, len(snap.NextEpochMasterNodes)) // 128 is all masternode candidates, not limited by MaxMasternodes + assert.Equal(t, 128, len(snap.NextEpochCandidates)) // 128 is all masternode candidates, not limited by MaxMasternodes } } diff --git a/consensus/tests/engine_v2_tests/verify_header_test.go b/consensus/tests/engine_v2_tests/verify_header_test.go index ab3e47fbf48a..3518c31008d5 100644 --- a/consensus/tests/engine_v2_tests/verify_header_test.go +++ b/consensus/tests/engine_v2_tests/verify_header_test.go @@ -286,10 +286,10 @@ func TestConfigSwitchOnDifferentMasternodeCount(t *testing.T) { snap, err := x.GetSnapshot(blockchain, currentBlock.Header()) assert.Nil(t, err) - assert.Equal(t, len(snap.NextEpochMasterNodes), 20) + assert.Equal(t, len(snap.NextEpochCandidates), 20) header1800.Validators = []byte{} for i := 0; i < 20; i++ { - header1800.Validators = append(header1800.Validators, snap.NextEpochMasterNodes[i].Bytes()...) + header1800.Validators = append(header1800.Validators, snap.NextEpochCandidates[i].Bytes()...) } round, err := x.GetRoundNumber(header1800) diff --git a/contracts/trc21issuer/trc21issuer_test.go b/contracts/trc21issuer/trc21issuer_test.go index e0744b26d1e6..698aa3f97d4a 100644 --- a/contracts/trc21issuer/trc21issuer_test.go +++ b/contracts/trc21issuer/trc21issuer_test.go @@ -91,10 +91,13 @@ func TestFeeTxWithTRC21Token(t *testing.T) { t.Fatal("check balance after fail transfer in tr20: ", err, "get", balance, "transfer", airDropAmount) } - //check balance fee + // check balance fee balanceIssuerFee, err = trc21Issuer.GetTokenCapacity(trc21TokenAddr) - if err != nil || balanceIssuerFee.Cmp(remainFee) != 0 { - t.Fatal("can't get balance token fee in smart contract: ", err, "got", balanceIssuerFee, "wanted", remainFee) + if err != nil { + t.Fatal("can't get balance token fee in smart contract: ", err) + } + if balanceIssuerFee.Cmp(remainFee) != 0 { + t.Fatal("check balance token fee in smart contract: got", balanceIssuerFee, "wanted", remainFee) } //check trc21 SMC balance balance, err = contractBackend.BalanceAt(nil, trc21IssuerAddr, nil) diff --git a/contracts/utils.go b/contracts/utils.go index 5d6963e08ffe..8f605a5e862e 100644 --- a/contracts/utils.go +++ b/contracts/utils.go @@ -86,7 +86,7 @@ func CreateTransactionSign(chainConfig *params.ChainConfig, pool *core.TxPool, m } // Create and send tx to smart contract for sign validate block. - nonce := pool.State().GetNonce(account.Address) + nonce := pool.Nonce(account.Address) tx := CreateTxSign(block.Number(), block.Hash(), nonce, common.HexToAddress(common.BlockSigners)) txSigned, err := wallet.SignTx(account, tx, chainConfig.ChainId) if err != nil { diff --git a/core/bench_test.go b/core/bench_test.go index 588429240282..7cfed07f45a0 100644 --- a/core/bench_test.go +++ b/core/bench_test.go @@ -85,7 +85,7 @@ func genValueTx(nbytes int) func(int, *BlockGen) { return func(i int, gen *BlockGen) { toaddr := common.Address{} data := make([]byte, nbytes) - gas, _ := IntrinsicGas(data, false, false) + gas, _ := IntrinsicGas(data, nil, false, false) tx, _ := types.SignTx(types.NewTransaction(gen.TxNonce(benchRootAddr), toaddr, big.NewInt(1), gas, nil, data), types.HomesteadSigner{}, benchRootKey) gen.AddTx(tx) } diff --git a/core/blockchain.go b/core/blockchain.go index 63e401ad906e..c1fefaa3d165 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -28,13 +28,12 @@ import ( "sync/atomic" "time" - "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" - "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" + "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" "github.com/XinFinOrg/XDPoSChain/accounts/abi/bind" - "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/mclock" + "github.com/XinFinOrg/XDPoSChain/common/prque" "github.com/XinFinOrg/XDPoSChain/common/sort" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" @@ -53,13 +52,17 @@ import ( "github.com/XinFinOrg/XDPoSChain/rlp" "github.com/XinFinOrg/XDPoSChain/trie" lru "github.com/hashicorp/golang-lru" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( blockInsertTimer = metrics.NewRegisteredTimer("chain/inserts", nil) CheckpointCh = make(chan int) ErrNoGenesis = errors.New("Genesis not found in chain") + + blockReorgMeter = metrics.NewRegisteredMeter("chain/reorg/executes", nil) + blockReorgAddMeter = metrics.NewRegisteredMeter("chain/reorg/add", nil) + blockReorgDropMeter = metrics.NewRegisteredMeter("chain/reorg/drop", nil) + blockReorgInvalidatedTx = metrics.NewRegisteredMeter("chain/reorg/invalidTx", nil) ) const ( @@ -201,7 +204,7 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par chainConfig: chainConfig, cacheConfig: cacheConfig, db: db, - triegc: prque.New(), + triegc: prque.New(nil), stateCache: state.NewDatabase(db), quit: make(chan struct{}), bodyCache: bodyCache, @@ -254,6 +257,11 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par return bc, nil } +// GetVMConfig returns the block chain VM config. +func (bc *BlockChain) GetVMConfig() *vm.Config { + return &bc.vmConfig +} + // NewBlockChainEx extend old blockchain, add order state db func NewBlockChainEx(db ethdb.Database, XDCxDb ethdb.XDCxDatabase, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) { blockchain, err := NewBlockChain(db, cacheConfig, chainConfig, engine, vmConfig) @@ -607,7 +615,7 @@ func (bc *BlockChain) repair(head **types.Block) error { if ok { tradingService := engine.GetXDCXService() lendingService := engine.GetLendingService() - if bc.Config().IsTIPXDCX((*head).Number()) && bc.chainConfig.XDPoS != nil && (*head).NumberU64() > bc.chainConfig.XDPoS.Epoch && tradingService != nil && lendingService != nil { + if bc.Config().IsTIPXDCXReceiver((*head).Number()) && bc.chainConfig.XDPoS != nil && (*head).NumberU64() > bc.chainConfig.XDPoS.Epoch && tradingService != nil && lendingService != nil { author, _ := bc.Engine().Author((*head).Header()) tradingRoot, err := tradingService.GetTradingStateRoot(*head, author) if err == nil { @@ -913,7 +921,7 @@ func (bc *BlockChain) SaveData() { if err := triedb.Commit(recent.Root(), true); err != nil { log.Error("Failed to commit recent state trie", "err", err) } - if bc.Config().IsTIPXDCX(recent.Number()) && bc.chainConfig.XDPoS != nil && recent.NumberU64() > bc.chainConfig.XDPoS.Epoch && engine != nil { + if bc.Config().IsTIPXDCXReceiver(recent.Number()) && bc.chainConfig.XDPoS != nil && recent.NumberU64() > bc.chainConfig.XDPoS.Epoch && engine != nil { author, _ := bc.Engine().Author(recent.Header()) if tradingService != nil { tradingRoot, _ := tradingService.GetTradingStateRoot(recent, author) @@ -1049,6 +1057,11 @@ func SetReceiptsData(config *params.ChainConfig, block *types.Block, receipts ty // The transaction hash can be retrieved from the transaction itself receipts[j].TxHash = transactions[j].Hash() + // block location fields + receipts[j].BlockHash = block.Hash() + receipts[j].BlockNumber = block.Number() + receipts[j].TransactionIndex = uint(j) + // The contract address can be derived from the transaction itself if transactions[j].To() == nil { // Deriving the signer is expensive, only do if it's actually needed @@ -1234,7 +1247,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. var tradingService utils.TradingService var lendingTrieDb *trie.Database var lendingService utils.LendingService - if bc.Config().IsTIPXDCX(block.Number()) && bc.chainConfig.XDPoS != nil && block.NumberU64() > bc.chainConfig.XDPoS.Epoch && engine != nil { + if bc.Config().IsTIPXDCXReceiver(block.Number()) && bc.chainConfig.XDPoS != nil && block.NumberU64() > bc.chainConfig.XDPoS.Epoch && engine != nil { tradingService = engine.GetXDCXService() if tradingService != nil { tradingTrieDb = tradingService.GetStateCache().TrieDB() @@ -1263,18 +1276,18 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. } else { // Full but not archive node, do proper garbage collection triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive - bc.triegc.Push(root, -float32(block.NumberU64())) + bc.triegc.Push(root, -int64(block.NumberU64())) if tradingTrieDb != nil { tradingTrieDb.Reference(tradingRoot, common.Hash{}) } if tradingService != nil { - tradingService.GetTriegc().Push(tradingRoot, -float32(block.NumberU64())) + tradingService.GetTriegc().Push(tradingRoot, -int64(block.NumberU64())) } if lendingTrieDb != nil { lendingTrieDb.Reference(lendingRoot, common.Hash{}) } if lendingService != nil { - lendingService.GetTriegc().Push(lendingRoot, -float32(block.NumberU64())) + lendingService.GetTriegc().Push(lendingRoot, -int64(block.NumberU64())) } if current := block.NumberU64(); current > triesInMemory { // Find the next state trie we need to commit @@ -1436,7 +1449,7 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types. // // After insertion is done, all accumulated events will be fired. func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) { - n, events, logs, err := bc.insertChain(chain) + n, events, logs, err := bc.insertChain(chain, true) bc.PostChainEvents(events, logs) return n, err } @@ -1444,7 +1457,11 @@ func (bc *BlockChain) InsertChain(chain types.Blocks) (int, error) { // insertChain will execute the actual chain insertion and event aggregation. The // only reason this method exists as a separate one is to make locking cleaner // with deferred statements. -func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*types.Log, error) { +func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, []interface{}, []*types.Log, error) { + // Sanity check that we have something meaningful to import + if len(chain) == 0 { + return 0, nil, nil, nil + } engine, _ := bc.Engine().(*XDPoS.XDPoS) // Do a sanity check that the provided chain is actually ordered and linked @@ -1480,12 +1497,15 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty for i, block := range chain { headers[i] = block.Header() - seals[i] = false + seals[i] = verifySeals bc.downloadingBlock.Add(block.Hash(), true) } abort, results := bc.engine.VerifyHeaders(bc, headers, seals) defer close(abort) + // Start a parallel signature recovery (signer will fluke on fork transition, minimal perf loss) + senderCacher.recoverFromBlocks(types.MakeSigner(bc.chainConfig, chain[0].Number()), chain) + // Iterate over the blocks and insert when the verifier permits for i, block := range chain { // If the chain is terminating, stop processing blocks @@ -1556,7 +1576,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty log.Debug("Number block need calculated again", "number", block.NumberU64(), "hash", block.Hash().Hex(), "winners", len(winner)) // Import all the pruned blocks to make the state available bc.chainmu.Unlock() - _, evs, logs, err := bc.insertChain(winner) + // During reorg, we use verifySeals=false + _, evs, logs, err := bc.insertChain(winner, false) bc.chainmu.Lock() events, coalescedLogs = evs, logs @@ -1586,7 +1607,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty var tradingService utils.TradingService var lendingService utils.LendingService isSDKNode := false - if bc.Config().IsTIPXDCX(block.Number()) && bc.chainConfig.XDPoS != nil && engine != nil && block.NumberU64() > bc.chainConfig.XDPoS.Epoch { + if bc.Config().IsTIPXDCXReceiver(block.Number()) && bc.chainConfig.XDPoS != nil && engine != nil && block.NumberU64() > bc.chainConfig.XDPoS.Epoch { author, err := bc.Engine().Author(block.Header()) // Ignore error, we're past header validation if err != nil { bc.reportBlock(block, nil, err) @@ -1843,7 +1864,8 @@ func (bc *BlockChain) getResultBlock(block *types.Block, verifiedM2 bool) (*Resu } log.Debug("Number block need calculated again", "number", block.NumberU64(), "hash", block.Hash().Hex(), "winners", len(winner)) // Import all the pruned blocks to make the state available - _, _, _, err := bc.insertChain(winner) + // During reorg, we use verifySeals=false + _, _, _, err := bc.insertChain(winner, false) if err != nil { return nil, err } @@ -1926,8 +1948,7 @@ func (bc *BlockChain) getResultBlock(block *types.Block, verifiedM2 bool) (*Resu } // liquidate / finalize open lendingTrades if block.Number().Uint64()%bc.chainConfig.XDPoS.Epoch == common.LiquidateLendingTradeBlock { - finalizedTrades := map[common.Hash]*lendingstate.LendingTrade{} - finalizedTrades, _, _, _, _, err = lendingService.ProcessLiquidationData(block.Header(), bc, statedb, tradingState, lendingState) + finalizedTrades, _, _, _, _, err := lendingService.ProcessLiquidationData(block.Header(), bc, statedb, tradingState, lendingState) if err != nil { return nil, fmt.Errorf("failed to ProcessLiquidationData. Err: %v ", err) } @@ -2048,7 +2069,7 @@ func (bc *BlockChain) insertBlock(block *types.Block) ([]interface{}, []*types.L // Only count canonical blocks for GC processing time bc.gcproc += result.proctime bc.UpdateBlocksHashCache(block) - if bc.chainConfig.IsTIPXDCX(block.Number()) && bc.chainConfig.XDPoS != nil && block.NumberU64() > bc.chainConfig.XDPoS.Epoch { + if bc.chainConfig.IsTIPXDCXReceiver(block.Number()) && bc.chainConfig.XDPoS != nil && block.NumberU64() > bc.chainConfig.XDPoS.Epoch { bc.logExchangeData(block) bc.logLendingData(block) } @@ -2234,6 +2255,9 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { } logFn("Chain split detected", "number", commonBlock.Number(), "hash", commonBlock.Hash(), "drop", len(oldChain), "dropfrom", oldChain[0].Hash(), "add", len(newChain), "addfrom", newChain[0].Hash()) + blockReorgAddMeter.Mark(int64(len(newChain))) + blockReorgDropMeter.Mark(int64(len(oldChain))) + blockReorgMeter.Mark(1) } else { log.Error("Impossible reorg, please file an issue", "oldnum", oldBlock.Number(), "oldhash", oldBlock.Hash(), "newnum", newBlock.Number(), "newhash", newBlock.Hash()) } @@ -2272,7 +2296,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error { } }() } - if bc.chainConfig.IsTIPXDCX(commonBlock.Number()) && bc.chainConfig.XDPoS != nil && commonBlock.NumberU64() > bc.chainConfig.XDPoS.Epoch { + if bc.chainConfig.IsTIPXDCXReceiver(commonBlock.Number()) && bc.chainConfig.XDPoS != nil && commonBlock.NumberU64() > bc.chainConfig.XDPoS.Epoch { bc.reorgTxMatches(deletedTxs, newChain) } return nil diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 9dc4b651175d..c19a68dcc129 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -561,7 +561,7 @@ func TestFastVsFullChains(t *testing.T) { Alloc: GenesisAlloc{address: {Balance: funds}}, } genesis = gspec.MustCommit(gendb) - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) blocks, receipts := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), gendb, 1024, func(i int, block *BlockGen) { block.SetCoinbase(common.Address{0x00}) @@ -736,7 +736,7 @@ func TestChainTxReorgs(t *testing.T) { }, } genesis = gspec.MustCommit(db) - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) // Create two transactions shared between the chains: @@ -842,7 +842,7 @@ func TestLogReorgs(t *testing.T) { code = common.Hex2Bytes("60606040525b7f24ec1d3ff24c2f6ff210738839dbc339cd45a5294d85c79361016243157aae7b60405180905060405180910390a15b600a8060416000396000f360606040526008565b00") gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000)}}} genesis = gspec.MustCommit(db) - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) @@ -889,7 +889,7 @@ func TestLogReorgs(t *testing.T) { // Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000)}}, // } // genesis = gspec.MustCommit(db) -// signer = types.NewEIP155Signer(gspec.Config.ChainId) +// signer = types.LatestSigner(gspec.Config) // ) // // blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{}) @@ -1015,7 +1015,7 @@ func TestEIP155Transition(t *testing.T) { funds = big.NewInt(1000000000) deleteAddr = common.Address{1} gspec = &Genesis{ - Config: ¶ms.ChainConfig{ChainId: big.NewInt(1), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)}, + Config: ¶ms.ChainConfig{ChainId: big.NewInt(1), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)}, Alloc: GenesisAlloc{address: {Balance: funds}, deleteAddr: {Balance: new(big.Int)}}, } genesis = gspec.MustCommit(db) @@ -1046,7 +1046,7 @@ func TestEIP155Transition(t *testing.T) { } block.AddTx(tx) - tx, err = basicTx(types.NewEIP155Signer(gspec.Config.ChainId)) + tx, err = basicTx(types.LatestSigner(gspec.Config)) if err != nil { t.Fatal(err) } @@ -1058,7 +1058,7 @@ func TestEIP155Transition(t *testing.T) { } block.AddTx(tx) - tx, err = basicTx(types.NewEIP155Signer(gspec.Config.ChainId)) + tx, err = basicTx(types.LatestSigner(gspec.Config)) if err != nil { t.Fatal(err) } @@ -1086,7 +1086,7 @@ func TestEIP155Transition(t *testing.T) { } // generate an invalid chain id transaction - config := ¶ms.ChainConfig{ChainId: big.NewInt(2), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)} + config := ¶ms.ChainConfig{ChainId: big.NewInt(2), EIP150Block: big.NewInt(0), EIP155Block: big.NewInt(2), HomesteadBlock: new(big.Int)} blocks, _ = GenerateChain(config, blocks[len(blocks)-1], ethash.NewFaker(), db, 4, func(i int, block *BlockGen) { var ( tx *types.Transaction @@ -1095,9 +1095,8 @@ func TestEIP155Transition(t *testing.T) { return types.SignTx(types.NewTransaction(block.TxNonce(address), common.Address{}, new(big.Int), 21000, new(big.Int), nil), signer, key) } ) - switch i { - case 0: - tx, err = basicTx(types.NewEIP155Signer(big.NewInt(2))) + if i == 0 { + tx, err = basicTx(types.LatestSigner(config)) if err != nil { t.Fatal(err) } @@ -1136,7 +1135,7 @@ func TestEIP161AccountRemoval(t *testing.T) { var ( tx *types.Transaction err error - signer = types.NewEIP155Signer(gspec.Config.ChainId) + signer = types.LatestSigner(gspec.Config) ) switch i { case 0: @@ -1167,7 +1166,7 @@ func TestEIP161AccountRemoval(t *testing.T) { t.Error("account should not exist") } - // account musn't be created post eip 161 + // account mustn't be created post eip 161 if _, err := blockchain.InsertChain(types.Blocks{blocks[2]}); err != nil { t.Fatal(err) } @@ -1412,3 +1411,93 @@ func TestAreTwoBlocksSamePath(t *testing.T) { }) } + +// TestEIP2718Transition tests that an EIP-2718 transaction will be accepted +// after the fork block has passed. This is verified by sending an EIP-2930 +// access list transaction, which specifies a single slot access, and then +// checking that the gas usage of a hot SLOAD and a cold SLOAD are calculated +// correctly. +func TestEIP2718Transition(t *testing.T) { + var ( + aa = common.HexToAddress("0x000000000000000000000000000000000000aaaa") + + // Generate a canonical chain to act as the main dataset + engine = ethash.NewFaker() + db = rawdb.NewMemoryDatabase() + + // A sender who makes transactions, has some funds + key, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + address = crypto.PubkeyToAddress(key.PublicKey) + funds = big.NewInt(1000000000) + gspec = &Genesis{ + Config: ¶ms.ChainConfig{ + ChainId: new(big.Int).SetBytes([]byte("eip1559")), + HomesteadBlock: big.NewInt(0), + DAOForkBlock: nil, + DAOForkSupport: true, + EIP150Block: big.NewInt(0), + EIP155Block: big.NewInt(0), + EIP158Block: big.NewInt(0), + ByzantiumBlock: big.NewInt(0), + ConstantinopleBlock: big.NewInt(0), + PetersburgBlock: big.NewInt(0), + IstanbulBlock: big.NewInt(0), + Eip1559Block: big.NewInt(0), + }, + Alloc: GenesisAlloc{ + address: {Balance: funds}, + // The address 0xAAAA sloads 0x00 and 0x01 + aa: { + Code: []byte{ + byte(vm.PC), + byte(vm.PC), + byte(vm.SLOAD), + byte(vm.SLOAD), + }, + Nonce: 0, + Balance: big.NewInt(0), + }, + }, + } + genesis = gspec.MustCommit(db) + ) + + blocks, _ := GenerateChain(gspec.Config, genesis, engine, db, 1, func(i int, b *BlockGen) { + b.SetCoinbase(common.Address{1}) + + // One transaction to 0xAAAA + signer := types.LatestSigner(gspec.Config) + tx, _ := types.SignNewTx(key, signer, &types.AccessListTx{ + ChainID: gspec.Config.ChainId, + Nonce: 0, + To: &aa, + Gas: 30000, + GasPrice: big.NewInt(1), + AccessList: types.AccessList{{ + Address: aa, + StorageKeys: []common.Hash{{0}}, + }}, + }) + b.AddTx(tx) + }) + + // Import the canonical chain + diskdb := rawdb.NewMemoryDatabase() + gspec.MustCommit(diskdb) + + chain, err := NewBlockChain(diskdb, nil, gspec.Config, engine, vm.Config{}) + if err != nil { + t.Fatalf("failed to create tester chain: %v", err) + } + if n, err := chain.InsertChain(blocks); err != nil { + t.Fatalf("block %d: failed to insert into chain: %v", n, err) + } + + block := chain.GetBlockByNumber(1) + + // Expected gas is intrinsic + 2 * pc + hot load + cold load, since only one load is in the access list + expected := params.TxGas + params.TxAccessListAddressGas + params.TxAccessListStorageKeyGas + vm.GasQuickStep*2 + vm.WarmStorageReadCostEIP2929 + vm.ColdSloadCostEIP2929 + if block.GasUsed() != expected { + t.Fatalf("incorrect amount of gas spent: expected %d, got %d", expected, block.GasUsed()) + } +} diff --git a/core/database_util.go b/core/database_util.go index 8e690898a30e..647d6afd3b59 100644 --- a/core/database_util.go +++ b/core/database_util.go @@ -259,10 +259,9 @@ func GetBlockReceipts(db DatabaseReader, hash common.Hash, number uint64) types. receipts := make(types.Receipts, len(storageReceipts)) for i, receipt := range storageReceipts { receipts[i] = (*types.Receipt)(receipt) - for _, log := range receipts[i].Logs { - // update BlockHash to fix #208 - log.BlockHash = hash - } + receipts[i].BlockHash = hash + receipts[i].BlockNumber = big.NewInt(0).SetUint64(number) + receipts[i].TransactionIndex = uint(i) } return receipts } diff --git a/core/database_util_test.go b/core/database_util_test.go index a38a68fd418e..ecd843a3e880 100644 --- a/core/database_util_test.go +++ b/core/database_util_test.go @@ -18,11 +18,11 @@ package core import ( "bytes" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "testing" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" @@ -335,6 +335,10 @@ func TestLookupStorage(t *testing.T) { func TestBlockReceiptStorage(t *testing.T) { db := rawdb.NewMemoryDatabase() + // Create a live block since we need metadata to reconstruct the receipt + tx1 := types.NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) + tx2 := types.NewTransaction(2, common.HexToAddress("0x2"), big.NewInt(2), 2, big.NewInt(2), nil) + receipt1 := &types.Receipt{ Status: types.ReceiptStatusFailed, CumulativeGasUsed: 1, @@ -342,10 +346,12 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x11})}, {Address: common.BytesToAddress([]byte{0x01, 0x11})}, }, - TxHash: common.BytesToHash([]byte{0x11, 0x11}), + TxHash: tx1.Hash(), ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), GasUsed: 111111, } + receipt1.Bloom = types.CreateBloom(types.Receipts{receipt1}) + receipt2 := &types.Receipt{ PostState: common.Hash{2}.Bytes(), CumulativeGasUsed: 2, @@ -353,10 +359,12 @@ func TestBlockReceiptStorage(t *testing.T) { {Address: common.BytesToAddress([]byte{0x22})}, {Address: common.BytesToAddress([]byte{0x02, 0x22})}, }, - TxHash: common.BytesToHash([]byte{0x22, 0x22}), + TxHash: tx2.Hash(), ContractAddress: common.BytesToAddress([]byte{0x02, 0x22, 0x22}), GasUsed: 222222, } + receipt2.Bloom = types.CreateBloom(types.Receipts{receipt2}) + receipts := []*types.Receipt{receipt1, receipt2} // Check that no receipt entries are in a pristine database diff --git a/core/error.go b/core/error.go index f6559bf06f19..6268c0dc9abc 100644 --- a/core/error.go +++ b/core/error.go @@ -16,7 +16,11 @@ package core -import "errors" +import ( + "errors" + + "github.com/XinFinOrg/XDPoSChain/core/types" +) var ( // ErrKnownBlock is returned when a block to import is already known locally. @@ -38,4 +42,8 @@ var ( ErrNotFoundM1 = errors.New("list M1 not found ") ErrStopPreparingBlock = errors.New("stop calculating a block not verified by M2") + + // ErrTxTypeNotSupported is returned if a transaction is not supported in the + // current network configuration. + ErrTxTypeNotSupported = types.ErrTxTypeNotSupported ) diff --git a/core/events.go b/core/events.go index fbdbc030ddd2..60dc8d7ddd36 100644 --- a/core/events.go +++ b/core/events.go @@ -21,8 +21,8 @@ import ( "github.com/XinFinOrg/XDPoSChain/core/types" ) -// TxPreEvent is posted when a transaction enters the transaction pool. -type TxPreEvent struct{ Tx *types.Transaction } +// NewTxsEvent is posted when a batch of transactions enter the transaction pool. +type NewTxsEvent struct{ Txs []*types.Transaction } // OrderTxPreEvent is posted when a order transaction enters the order transaction pool. type OrderTxPreEvent struct{ Tx *types.OrderTransaction } @@ -41,9 +41,6 @@ type PendingStateEvent struct{} // NewMinedBlockEvent is posted when a block has been imported. type NewMinedBlockEvent struct{ Block *types.Block } -// RemovedTransactionEvent is posted when a reorg happens -type RemovedTransactionEvent struct{ Txs types.Transactions } - // RemovedLogsEvent is posted when a reorg happens type RemovedLogsEvent struct{ Logs []*types.Log } diff --git a/core/lending_pool.go b/core/lending_pool.go index 6d5125d3a927..fc0deb18215b 100644 --- a/core/lending_pool.go +++ b/core/lending_pool.go @@ -24,18 +24,16 @@ import ( "sync" "time" - "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" - "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" - "github.com/XinFinOrg/XDPoSChain/consensus" - "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" + "github.com/XinFinOrg/XDPoSChain/consensus" + "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/event" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/params" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( @@ -273,7 +271,7 @@ func (pool *LendingPool) loop() { // reset retrieves the current state of the blockchain and ensures the content // of the transaction pool is valid with regard to the chain state. func (pool *LendingPool) reset(oldHead, newblock *types.Block) { - if !pool.chainconfig.IsTIPXDCX(pool.chain.CurrentBlock().Number()) || pool.chain.Config().XDPoS == nil || pool.chain.CurrentBlock().NumberU64() <= pool.chain.Config().XDPoS.Epoch { + if !pool.chainconfig.IsTIPXDCXReceiver(pool.chain.CurrentBlock().Number()) || pool.chain.Config().XDPoS == nil || pool.chain.CurrentBlock().NumberU64() <= pool.chain.Config().XDPoS.Epoch { return } // If we're reorging an old state, reinject all dropped transactions @@ -671,7 +669,7 @@ func (pool *LendingPool) add(tx *types.LendingTransaction, local bool) (bool, er // If the transaction fails basic validation, discard it if err := pool.validateTx(tx, local); err != nil { log.Debug("Discarding invalid lending transaction", "hash", hash, "userAddress", tx.UserAddress, "status", tx.Status, "err", err) - invalidTxCounter.Inc(1) + invalidTxMeter.Mark(1) return false, err } from, _ := types.LendingSender(pool.signer, tx) // already validated @@ -685,12 +683,12 @@ func (pool *LendingPool) add(tx *types.LendingTransaction, local bool) (bool, er if list := pool.pending[from]; list != nil && list.Overlaps(tx) { inserted, old := list.Add(tx) if !inserted { - pendingDiscardCounter.Inc(1) + pendingDiscardMeter.Mark(1) return false, ErrPendingNonceTooLow } if old != nil { delete(pool.all, old.Hash()) - pendingReplaceCounter.Inc(1) + pendingReplaceMeter.Mark(1) } pool.all[tx.Hash()] = tx pool.journalTx(from, tx) @@ -726,13 +724,13 @@ func (pool *LendingPool) enqueueTx(hash common.Hash, tx *types.LendingTransactio inserted, old := pool.queue[from].Add(tx) if !inserted { // An older transaction was better, discard this - queuedDiscardCounter.Inc(1) + pendingDiscardMeter.Mark(1) return false, ErrPendingNonceTooLow } // Discard any previous transaction and mark this if old != nil { delete(pool.all, old.Hash()) - queuedReplaceCounter.Inc(1) + queuedReplaceMeter.Mark(1) } pool.all[hash] = tx return old != nil, nil @@ -764,13 +762,13 @@ func (pool *LendingPool) promoteTx(addr common.Address, hash common.Hash, tx *ty if !inserted { // An older transaction was better, discard this delete(pool.all, hash) - pendingDiscardCounter.Inc(1) + pendingDiscardMeter.Mark(1) return } // Otherwise discard any previous transaction and mark this if old != nil { delete(pool.all, old.Hash()) - pendingReplaceCounter.Inc(1) + pendingReplaceMeter.Mark(1) } // Failsafe to work around direct pending inserts (tests) if pool.all[hash] == nil { @@ -814,7 +812,7 @@ func (pool *LendingPool) AddRemotes(txs []*types.LendingTransaction) []error { // addTx enqueues a single transaction into the pool if it is valid. func (pool *LendingPool) addTx(tx *types.LendingTransaction, local bool) error { - if !pool.chainconfig.IsTIPXDCX(pool.chain.CurrentBlock().Number()) { + if !pool.chainconfig.IsTIPXDCXReceiver(pool.chain.CurrentBlock().Number()) { return nil } tx.CacheHash() @@ -981,7 +979,7 @@ func (pool *LendingPool) promoteExecutables(accounts []common.Address) { hash := tx.Hash() delete(pool.all, hash) - queuedRateLimitCounter.Inc(1) + queuedRateLimitMeter.Mark(1) log.Trace("Removed cap-exceeding queued transaction", "hash", hash) } } @@ -998,11 +996,11 @@ func (pool *LendingPool) promoteExecutables(accounts []common.Address) { if pending > pool.config.GlobalSlots { pendingBeforeCap := pending // Assemble a spam order to penalize large transactors first - spammers := prque.New() + spammers := prque.New(nil) for addr, list := range pool.pending { // Only evict transactions from high rollers if !pool.locals.contains(addr) && uint64(list.Len()) > pool.config.AccountSlots { - spammers.Push(addr, float32(list.Len())) + spammers.Push(addr, int64(list.Len())) } } // Gradually drop transactions from offenders @@ -1057,7 +1055,7 @@ func (pool *LendingPool) promoteExecutables(accounts []common.Address) { } } } - pendingRateLimitCounter.Inc(int64(pendingBeforeCap - pending)) + pendingRateLimitMeter.Mark(int64(pendingBeforeCap - pending)) } // If we've queued more transactions than the hard limit, drop oldest ones queued := uint64(0) @@ -1066,7 +1064,7 @@ func (pool *LendingPool) promoteExecutables(accounts []common.Address) { } if queued > pool.config.GlobalQueue { // Sort all accounts with queued transactions by heartbeat - addresses := make(addresssByHeartbeat, 0, len(pool.queue)) + addresses := make(addressesByHeartbeat, 0, len(pool.queue)) for addr := range pool.queue { if !pool.locals.contains(addr) { // don't drop locals addresses = append(addresses, addressByHeartbeat{addr, pool.beats[addr]}) @@ -1087,7 +1085,7 @@ func (pool *LendingPool) promoteExecutables(accounts []common.Address) { pool.removeTx(tx.Hash()) } drop -= size - queuedRateLimitCounter.Inc(int64(size)) + queuedRateLimitMeter.Mark(int64(size)) continue } // Otherwise drop only last few transactions @@ -1095,7 +1093,7 @@ func (pool *LendingPool) promoteExecutables(accounts []common.Address) { for i := len(txs) - 1; i >= 0 && drop > 0; i-- { pool.removeTx(txs[i].Hash()) drop-- - queuedRateLimitCounter.Inc(1) + queuedRateLimitMeter.Mark(1) } } } diff --git a/core/order_pool.go b/core/order_pool.go index ed275e179204..c8708149b89c 100644 --- a/core/order_pool.go +++ b/core/order_pool.go @@ -25,16 +25,15 @@ import ( "time" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" - - "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/event" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/params" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( @@ -278,7 +277,7 @@ func (pool *OrderPool) loop() { // reset retrieves the current state of the blockchain and ensures the content // of the transaction pool is valid with regard to the chain state. func (pool *OrderPool) reset(oldHead, newblock *types.Block) { - if !pool.chainconfig.IsTIPXDCX(pool.chain.CurrentBlock().Number()) || pool.chain.Config().XDPoS == nil || pool.chain.CurrentBlock().NumberU64() <= pool.chain.Config().XDPoS.Epoch { + if !pool.chainconfig.IsTIPXDCXReceiver(pool.chain.CurrentBlock().Number()) || pool.chain.Config().XDPoS == nil || pool.chain.CurrentBlock().NumberU64() <= pool.chain.Config().XDPoS.Epoch { return } // If we're reorging an old state, reinject all dropped transactions @@ -579,7 +578,7 @@ func (pool *OrderPool) add(tx *types.OrderTransaction, local bool) (bool, error) // If the transaction fails basic validation, discard it if err := pool.validateTx(tx, local); err != nil { log.Debug("Discarding invalid order transaction", "hash", hash, "userAddress", tx.UserAddress().Hex(), "status", tx.Status, "err", err) - invalidTxCounter.Inc(1) + invalidTxMeter.Mark(1) return false, err } from, _ := types.OrderSender(pool.signer, tx) // already validated @@ -593,12 +592,12 @@ func (pool *OrderPool) add(tx *types.OrderTransaction, local bool) (bool, error) if list := pool.pending[from]; list != nil && list.Overlaps(tx) { inserted, old := list.Add(tx) if !inserted { - pendingDiscardCounter.Inc(1) + pendingDiscardMeter.Mark(1) return false, ErrPendingNonceTooLow } if old != nil { delete(pool.all, old.Hash()) - pendingReplaceCounter.Inc(1) + pendingReplaceMeter.Mark(1) } pool.all[tx.Hash()] = tx pool.journalTx(from, tx) @@ -636,13 +635,13 @@ func (pool *OrderPool) enqueueTx(hash common.Hash, tx *types.OrderTransaction) ( inserted, old := pool.queue[from].Add(tx) if !inserted { // An older transaction was better, discard this - queuedDiscardCounter.Inc(1) + queuedDiscardMeter.Mark(1) return false, ErrPendingNonceTooLow } // Discard any previous transaction and mark this if old != nil { delete(pool.all, old.Hash()) - queuedReplaceCounter.Inc(1) + queuedReplaceMeter.Mark(1) } pool.all[hash] = tx return old != nil, nil @@ -675,13 +674,13 @@ func (pool *OrderPool) promoteTx(addr common.Address, hash common.Hash, tx *type if !inserted { // An older transaction was better, discard this delete(pool.all, hash) - pendingDiscardCounter.Inc(1) + pendingDiscardMeter.Mark(1) return } // Otherwise discard any previous transaction and mark this if old != nil { delete(pool.all, old.Hash()) - pendingReplaceCounter.Inc(1) + pendingReplaceMeter.Mark(1) } // Failsafe to work around direct pending inserts (tests) if pool.all[hash] == nil { @@ -729,7 +728,7 @@ func (pool *OrderPool) AddRemotes(txs []*types.OrderTransaction) []error { // addTx enqueues a single transaction into the pool if it is valid. func (pool *OrderPool) addTx(tx *types.OrderTransaction, local bool) error { - if !pool.chainconfig.IsTIPXDCX(pool.chain.CurrentBlock().Number()) { + if !pool.chainconfig.IsTIPXDCXReceiver(pool.chain.CurrentBlock().Number()) { return nil } tx.CacheHash() @@ -896,7 +895,7 @@ func (pool *OrderPool) promoteExecutables(accounts []common.Address) { hash := tx.Hash() delete(pool.all, hash) - queuedRateLimitCounter.Inc(1) + queuedRateLimitMeter.Mark(1) log.Debug("Removed cap-exceeding queued transaction", "addr", tx.UserAddress().Hex(), "nonce", tx.Nonce(), "ohash", tx.OrderHash().Hex(), "status", tx.Status(), "orderid", tx.OrderID()) } } @@ -914,11 +913,11 @@ func (pool *OrderPool) promoteExecutables(accounts []common.Address) { if pending > pool.config.GlobalSlots { pendingBeforeCap := pending // Assemble a spam order to penalize large transactors first - spammers := prque.New() + spammers := prque.New(nil) for addr, list := range pool.pending { // Only evict transactions from high rollers if !pool.locals.contains(addr) && uint64(list.Len()) > pool.config.AccountSlots { - spammers.Push(addr, float32(list.Len())) + spammers.Push(addr, int64(list.Len())) } } // Gradually drop transactions from offenders @@ -973,7 +972,7 @@ func (pool *OrderPool) promoteExecutables(accounts []common.Address) { } } } - pendingRateLimitCounter.Inc(int64(pendingBeforeCap - pending)) + pendingRateLimitMeter.Mark(int64(pendingBeforeCap - pending)) } // If we've queued more transactions than the hard limit, drop oldest ones queued := uint64(0) @@ -982,7 +981,7 @@ func (pool *OrderPool) promoteExecutables(accounts []common.Address) { } if queued > pool.config.GlobalQueue { // Sort all accounts with queued transactions by heartbeat - addresses := make(addresssByHeartbeat, 0, len(pool.queue)) + addresses := make(addressesByHeartbeat, 0, len(pool.queue)) for addr := range pool.queue { if !pool.locals.contains(addr) { // don't drop locals addresses = append(addresses, addressByHeartbeat{addr, pool.beats[addr]}) @@ -1003,7 +1002,7 @@ func (pool *OrderPool) promoteExecutables(accounts []common.Address) { pool.removeTx(tx.Hash()) } drop -= size - queuedRateLimitCounter.Inc(int64(size)) + queuedRateLimitMeter.Mark(int64(size)) continue } // Otherwise drop only last few transactions @@ -1011,7 +1010,7 @@ func (pool *OrderPool) promoteExecutables(accounts []common.Address) { for i := len(txs) - 1; i >= 0 && drop > 0; i-- { pool.removeTx(txs[i].Hash()) drop-- - queuedRateLimitCounter.Inc(1) + queuedRateLimitMeter.Mark(1) } } } diff --git a/core/state/access_list.go b/core/state/access_list.go new file mode 100644 index 000000000000..7e638aeaba60 --- /dev/null +++ b/core/state/access_list.go @@ -0,0 +1,136 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package state + +import ( + "github.com/XinFinOrg/XDPoSChain/common" +) + +type accessList struct { + addresses map[common.Address]int + slots []map[common.Hash]struct{} +} + +// ContainsAddress returns true if the address is in the access list. +func (al *accessList) ContainsAddress(address common.Address) bool { + _, ok := al.addresses[address] + return ok +} + +// Contains checks if a slot within an account is present in the access list, returning +// separate flags for the presence of the account and the slot respectively. +func (al *accessList) Contains(address common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { + idx, ok := al.addresses[address] + if !ok { + // no such address (and hence zero slots) + return false, false + } + if idx == -1 { + // address yes, but no slots + return true, false + } + _, slotPresent = al.slots[idx][slot] + return true, slotPresent +} + +// newAccessList creates a new accessList. +func newAccessList() *accessList { + return &accessList{ + addresses: make(map[common.Address]int), + } +} + +// Copy creates an independent copy of an accessList. +func (a *accessList) Copy() *accessList { + cp := newAccessList() + for k, v := range a.addresses { + cp.addresses[k] = v + } + cp.slots = make([]map[common.Hash]struct{}, len(a.slots)) + for i, slotMap := range a.slots { + newSlotmap := make(map[common.Hash]struct{}, len(slotMap)) + for k := range slotMap { + newSlotmap[k] = struct{}{} + } + cp.slots[i] = newSlotmap + } + return cp +} + +// AddAddress adds an address to the access list, and returns 'true' if the operation +// caused a change (addr was not previously in the list). +func (al *accessList) AddAddress(address common.Address) bool { + if _, present := al.addresses[address]; present { + return false + } + al.addresses[address] = -1 + return true +} + +// AddSlot adds the specified (addr, slot) combo to the access list. +// Return values are: +// - address added +// - slot added +// For any 'true' value returned, a corresponding journal entry must be made. +func (al *accessList) AddSlot(address common.Address, slot common.Hash) (addrChange bool, slotChange bool) { + idx, addrPresent := al.addresses[address] + if !addrPresent || idx == -1 { + // Address not present, or addr present but no slots there + al.addresses[address] = len(al.slots) + slotmap := map[common.Hash]struct{}{slot: {}} + al.slots = append(al.slots, slotmap) + return !addrPresent, true + } + // There is already an (address,slot) mapping + slotmap := al.slots[idx] + if _, ok := slotmap[slot]; !ok { + slotmap[slot] = struct{}{} + // Journal add slot change + return false, true + } + // No changes required + return false, false +} + +// DeleteSlot removes an (address, slot)-tuple from the access list. +// This operation needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. +func (al *accessList) DeleteSlot(address common.Address, slot common.Hash) { + idx, addrOk := al.addresses[address] + // There are two ways this can fail + if !addrOk { + panic("reverting slot change, address not present in list") + } + slotmap := al.slots[idx] + delete(slotmap, slot) + // If that was the last (first) slot, remove it + // Since additions and rollbacks are always performed in order, + // we can delete the item without worrying about screwing up later indices + if len(slotmap) == 0 { + al.slots = al.slots[:idx] + al.addresses[address] = -1 + } +} + +// DeleteAddress removes an address from the access list. This operation +// needs to be performed in the same order as the addition happened. +// This method is meant to be used by the journal, which maintains ordering of +// operations. +func (al *accessList) DeleteAddress(address common.Address) { + delete(al.addresses, address) +} diff --git a/core/state/journal.go b/core/state/journal.go index ac6461df1366..b9348f7bf26b 100644 --- a/core/state/journal.go +++ b/core/state/journal.go @@ -28,6 +28,11 @@ type journalEntry interface { type journal []journalEntry +// length returns the current number of entries in the journal. +func (j *journal) length() int { + return len(*j) +} + type ( // Changes to the account trie. createObjectChange struct { @@ -75,6 +80,14 @@ type ( prev bool prevDirty bool } + // Changes to the access list + accessListAddAccountChange struct { + address *common.Address + } + accessListAddSlotChange struct { + address *common.Address + slot *common.Hash + } ) func (ch createObjectChange) undo(s *StateDB) { @@ -138,3 +151,20 @@ func (ch addLogChange) undo(s *StateDB) { func (ch addPreimageChange) undo(s *StateDB) { delete(s.preimages, ch.hash) } + +func (ch accessListAddAccountChange) undo(s *StateDB) { + /* + One important invariant here, is that whenever a (addr, slot) is added, if the + addr is not already present, the add causes two journal entries: + - one for the address, + - one for the (address,slot) + Therefore, when unrolling the change, we can always blindly delete the + (addr) at this point, since no storage adds can remain when come upon + a single (addr) change. + */ + s.accessList.DeleteAddress(*ch.address) +} + +func (ch accessListAddSlotChange) undo(s *StateDB) { + s.accessList.DeleteSlot(*ch.address, *ch.slot) +} diff --git a/core/state/managed_state.go b/core/state/managed_state.go deleted file mode 100644 index fbd5d2959376..000000000000 --- a/core/state/managed_state.go +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package state - -import ( - "sync" - - "github.com/XinFinOrg/XDPoSChain/common" -) - -type account struct { - stateObject *stateObject - nstart uint64 - nonces []bool -} - -type ManagedState struct { - *StateDB - - mu sync.RWMutex - - accounts map[common.Address]*account -} - -// ManagedState returns a new managed state with the statedb as it's backing layer -func ManageState(statedb *StateDB) *ManagedState { - return &ManagedState{ - StateDB: statedb.Copy(), - accounts: make(map[common.Address]*account), - } -} - -// SetState sets the backing layer of the managed state -func (ms *ManagedState) SetState(statedb *StateDB) { - ms.mu.Lock() - defer ms.mu.Unlock() - ms.StateDB = statedb -} - -// RemoveNonce removed the nonce from the managed state and all future pending nonces -func (ms *ManagedState) RemoveNonce(addr common.Address, n uint64) { - if ms.hasAccount(addr) { - ms.mu.Lock() - defer ms.mu.Unlock() - - account := ms.getAccount(addr) - if n-account.nstart <= uint64(len(account.nonces)) { - reslice := make([]bool, n-account.nstart) - copy(reslice, account.nonces[:n-account.nstart]) - account.nonces = reslice - } - } -} - -// NewNonce returns the new canonical nonce for the managed account -func (ms *ManagedState) NewNonce(addr common.Address) uint64 { - ms.mu.Lock() - defer ms.mu.Unlock() - - account := ms.getAccount(addr) - for i, nonce := range account.nonces { - if !nonce { - return account.nstart + uint64(i) - } - } - account.nonces = append(account.nonces, true) - - return uint64(len(account.nonces)-1) + account.nstart -} - -// GetNonce returns the canonical nonce for the managed or unmanaged account. -// -// Because GetNonce mutates the DB, we must take a write lock. -func (ms *ManagedState) GetNonce(addr common.Address) uint64 { - ms.mu.Lock() - defer ms.mu.Unlock() - - if ms.hasAccount(addr) { - account := ms.getAccount(addr) - return uint64(len(account.nonces)) + account.nstart - } else { - return ms.StateDB.GetNonce(addr) - } -} - -// SetNonce sets the new canonical nonce for the managed state -func (ms *ManagedState) SetNonce(addr common.Address, nonce uint64) { - ms.mu.Lock() - defer ms.mu.Unlock() - - so := ms.GetOrNewStateObject(addr) - so.SetNonce(nonce) - - ms.accounts[addr] = newAccount(so) -} - -// HasAccount returns whether the given address is managed or not -func (ms *ManagedState) HasAccount(addr common.Address) bool { - ms.mu.RLock() - defer ms.mu.RUnlock() - return ms.hasAccount(addr) -} - -func (ms *ManagedState) hasAccount(addr common.Address) bool { - _, ok := ms.accounts[addr] - return ok -} - -// populate the managed state -func (ms *ManagedState) getAccount(addr common.Address) *account { - if account, ok := ms.accounts[addr]; !ok { - so := ms.GetOrNewStateObject(addr) - ms.accounts[addr] = newAccount(so) - } else { - // Always make sure the state account nonce isn't actually higher - // than the tracked one. - so := ms.StateDB.getStateObject(addr) - if so != nil && uint64(len(account.nonces))+account.nstart < so.Nonce() { - ms.accounts[addr] = newAccount(so) - } - - } - - return ms.accounts[addr] -} - -func newAccount(so *stateObject) *account { - return &account{so, so.Nonce(), nil} -} diff --git a/core/state/managed_state_test.go b/core/state/managed_state_test.go deleted file mode 100644 index 13f35a8a51fc..000000000000 --- a/core/state/managed_state_test.go +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package state - -import ( - "github.com/XinFinOrg/XDPoSChain/core/rawdb" - "testing" - - "github.com/XinFinOrg/XDPoSChain/common" -) - -var addr = common.BytesToAddress([]byte("test")) - -func create() (*ManagedState, *account) { - db := rawdb.NewMemoryDatabase() - statedb, _ := New(common.Hash{}, NewDatabase(db)) - ms := ManageState(statedb) - ms.StateDB.SetNonce(addr, 100) - ms.accounts[addr] = newAccount(ms.StateDB.getStateObject(addr)) - return ms, ms.accounts[addr] -} - -func TestNewNonce(t *testing.T) { - ms, _ := create() - - nonce := ms.NewNonce(addr) - if nonce != 100 { - t.Error("expected nonce 100. got", nonce) - } - - nonce = ms.NewNonce(addr) - if nonce != 101 { - t.Error("expected nonce 101. got", nonce) - } -} - -func TestRemove(t *testing.T) { - ms, account := create() - - nn := make([]bool, 10) - for i := range nn { - nn[i] = true - } - account.nonces = append(account.nonces, nn...) - - i := uint64(5) - ms.RemoveNonce(addr, account.nstart+i) - if len(account.nonces) != 5 { - t.Error("expected", i, "'th index to be false") - } -} - -func TestReuse(t *testing.T) { - ms, account := create() - - nn := make([]bool, 10) - for i := range nn { - nn[i] = true - } - account.nonces = append(account.nonces, nn...) - - i := uint64(5) - ms.RemoveNonce(addr, account.nstart+i) - nonce := ms.NewNonce(addr) - if nonce != 105 { - t.Error("expected nonce to be 105. got", nonce) - } -} - -func TestRemoteNonceChange(t *testing.T) { - ms, account := create() - nn := make([]bool, 10) - for i := range nn { - nn[i] = true - } - account.nonces = append(account.nonces, nn...) - ms.NewNonce(addr) - - ms.StateDB.stateObjects[addr].data.Nonce = 200 - nonce := ms.NewNonce(addr) - if nonce != 200 { - t.Error("expected nonce after remote update to be", 200, "got", nonce) - } - ms.NewNonce(addr) - ms.NewNonce(addr) - ms.NewNonce(addr) - ms.StateDB.stateObjects[addr].data.Nonce = 200 - nonce = ms.NewNonce(addr) - if nonce != 204 { - t.Error("expected nonce after remote update to be", 204, "got", nonce) - } -} - -func TestSetNonce(t *testing.T) { - ms, _ := create() - - var addr common.Address - ms.SetNonce(addr, 10) - - if ms.GetNonce(addr) != 10 { - t.Error("Expected nonce of 10, got", ms.GetNonce(addr)) - } - - addr[0] = 1 - ms.StateDB.SetNonce(addr, 1) - - if ms.GetNonce(addr) != 1 { - t.Error("Expected nonce of 1, got", ms.GetNonce(addr)) - } -} diff --git a/core/state/state_object.go b/core/state/state_object.go index a5c7ce0bc964..99763db267e2 100644 --- a/core/state/state_object.go +++ b/core/state/state_object.go @@ -31,23 +31,23 @@ var emptyCodeHash = crypto.Keccak256(nil) type Code []byte -func (self Code) String() string { - return string(self) //strings.Join(Disassemble(self), " ") +func (c Code) String() string { + return string(c) //strings.Join(Disassemble(c), " ") } type Storage map[common.Hash]common.Hash -func (self Storage) String() (str string) { - for key, value := range self { +func (s Storage) String() (str string) { + for key, value := range s { str += fmt.Sprintf("%X : %X\n", key, value) } return } -func (self Storage) Copy() Storage { +func (s Storage) Copy() Storage { cpy := make(Storage) - for key, value := range self { + for key, value := range s { cpy[key] = value } @@ -79,6 +79,7 @@ type stateObject struct { cachedStorage Storage // Storage entry cache to avoid duplicate reads dirtyStorage Storage // Storage entries that need to be flushed to disk + fakeStorage Storage // Fake storage which constructed by caller for debugging purpose. // Cache flags. // When an object is marked suicided it will be delete from the trie @@ -124,199 +125,236 @@ func newObject(db *StateDB, address common.Address, data Account, onDirty func(a } // EncodeRLP implements rlp.Encoder. -func (c *stateObject) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, c.data) +func (s *stateObject) EncodeRLP(w io.Writer) error { + return rlp.Encode(w, s.data) } // setError remembers the first non-nil error it is called with. -func (self *stateObject) setError(err error) { - if self.dbErr == nil { - self.dbErr = err +func (s *stateObject) setError(err error) { + if s.dbErr == nil { + s.dbErr = err } } -func (self *stateObject) markSuicided() { - self.suicided = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) markSuicided() { + s.suicided = true + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (c *stateObject) touch() { - c.db.journal = append(c.db.journal, touchChange{ - account: &c.address, - prev: c.touched, - prevDirty: c.onDirty == nil, +func (s *stateObject) touch() { + s.db.journal = append(s.db.journal, touchChange{ + account: &s.address, + prev: s.touched, + prevDirty: s.onDirty == nil, }) - if c.onDirty != nil { - c.onDirty(c.Address()) - c.onDirty = nil + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } - c.touched = true + s.touched = true } -func (c *stateObject) getTrie(db Database) Trie { - if c.trie == nil { +func (s *stateObject) getTrie(db Database) Trie { + if s.trie == nil { var err error - c.trie, err = db.OpenStorageTrie(c.addrHash, c.data.Root) + s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root) if err != nil { - c.trie, _ = db.OpenStorageTrie(c.addrHash, common.Hash{}) - c.setError(fmt.Errorf("can't create storage trie: %v", err)) + s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{}) + s.setError(fmt.Errorf("can't create storage trie: %v", err)) } } - return c.trie + return s.trie } -func (self *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { +func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] + } value := common.Hash{} // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) + enc, err := s.getTrie(db).TryGet(key[:]) if err != nil { - self.setError(err) + s.setError(err) return common.Hash{} } if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { - self.setError(err) + s.setError(err) } value.SetBytes(content) } return value } -func (self *stateObject) GetState(db Database, key common.Hash) common.Hash { - value, exists := self.cachedStorage[key] +func (s *stateObject) GetState(db Database, key common.Hash) common.Hash { + // If the fake storage is set, only lookup the state here(in the debugging mode) + if s.fakeStorage != nil { + return s.fakeStorage[key] + } + value, exists := s.cachedStorage[key] if exists { return value } // Load from DB in case it is missing. - enc, err := self.getTrie(db).TryGet(key[:]) + enc, err := s.getTrie(db).TryGet(key[:]) if err != nil { - self.setError(err) + s.setError(err) return common.Hash{} } if len(enc) > 0 { _, content, _, err := rlp.Split(enc) if err != nil { - self.setError(err) + s.setError(err) } value.SetBytes(content) } if (value != common.Hash{}) { - self.cachedStorage[key] = value + s.cachedStorage[key] = value } return value } // SetState updates a value in account storage. -func (self *stateObject) SetState(db Database, key, value common.Hash) { - self.db.journal = append(self.db.journal, storageChange{ - account: &self.address, +func (s *stateObject) SetState(db Database, key, value common.Hash) { + // If the fake storage is set, put the temporary state update here. + if s.fakeStorage != nil { + s.fakeStorage[key] = value + return + } + // If the new value is the same as old, don't set + prev := s.GetState(db, key) + if prev == value { + return + } + // New value is different, update and journal the change + s.db.journal = append(s.db.journal, storageChange{ + account: &s.address, key: key, - prevalue: self.GetState(db, key), + prevalue: prev, }) - self.setState(key, value) + s.setState(key, value) +} + +// SetStorage replaces the entire state storage with the given one. +// +// After this function is called, all original state will be ignored and state +// lookup only happens in the fake state storage. +// +// Note this function should only be used for debugging purpose. +func (s *stateObject) SetStorage(storage map[common.Hash]common.Hash) { + // Allocate fake storage if it's nil. + if s.fakeStorage == nil { + s.fakeStorage = make(Storage) + } + for key, value := range storage { + s.fakeStorage[key] = value + } + // Don't bother journal since this function should only be used for + // debugging and the `fake` storage won't be committed to database. } -func (self *stateObject) setState(key, value common.Hash) { - self.cachedStorage[key] = value - self.dirtyStorage[key] = value +func (s *stateObject) setState(key, value common.Hash) { + s.cachedStorage[key] = value + s.dirtyStorage[key] = value - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } // updateTrie writes cached storage modifications into the object's storage trie. -func (self *stateObject) updateTrie(db Database) Trie { - tr := self.getTrie(db) - for key, value := range self.dirtyStorage { - delete(self.dirtyStorage, key) +func (s *stateObject) updateTrie(db Database) Trie { + tr := s.getTrie(db) + for key, value := range s.dirtyStorage { + delete(s.dirtyStorage, key) if (value == common.Hash{}) { - self.setError(tr.TryDelete(key[:])) + s.setError(tr.TryDelete(key[:])) continue } // Encoding []byte cannot fail, ok to ignore the error. v, _ := rlp.EncodeToBytes(bytes.TrimLeft(value[:], "\x00")) - self.setError(tr.TryUpdate(key[:], v)) + s.setError(tr.TryUpdate(key[:], v)) } return tr } // UpdateRoot sets the trie root to the current root hash of -func (self *stateObject) updateRoot(db Database) { - self.updateTrie(db) - self.data.Root = self.trie.Hash() +func (s *stateObject) updateRoot(db Database) { + s.updateTrie(db) + s.data.Root = s.trie.Hash() } // CommitTrie the storage trie of the object to dwb. // This updates the trie root. -func (self *stateObject) CommitTrie(db Database) error { - self.updateTrie(db) - if self.dbErr != nil { - return self.dbErr +func (s *stateObject) CommitTrie(db Database) error { + s.updateTrie(db) + if s.dbErr != nil { + return s.dbErr } - root, err := self.trie.Commit(nil) + root, err := s.trie.Commit(nil) if err == nil { - self.data.Root = root + s.data.Root = root } return err } // AddBalance removes amount from c's balance. // It is used to add funds to the destination account of a transfer. -func (c *stateObject) AddBalance(amount *big.Int) { +func (s *stateObject) AddBalance(amount *big.Int) { // EIP158: We must check emptiness for the objects such that the account // clearing (0,0,0 objects) can take effect. if amount.Sign() == 0 { - if c.empty() { - c.touch() + if s.empty() { + s.touch() } return } - c.SetBalance(new(big.Int).Add(c.Balance(), amount)) + s.SetBalance(new(big.Int).Add(s.Balance(), amount)) } // SubBalance removes amount from c's balance. // It is used to remove funds from the origin account of a transfer. -func (c *stateObject) SubBalance(amount *big.Int) { +func (s *stateObject) SubBalance(amount *big.Int) { if amount.Sign() == 0 { return } - c.SetBalance(new(big.Int).Sub(c.Balance(), amount)) + s.SetBalance(new(big.Int).Sub(s.Balance(), amount)) } -func (self *stateObject) SetBalance(amount *big.Int) { - self.db.journal = append(self.db.journal, balanceChange{ - account: &self.address, - prev: new(big.Int).Set(self.data.Balance), +func (s *stateObject) SetBalance(amount *big.Int) { + s.db.journal = append(s.db.journal, balanceChange{ + account: &s.address, + prev: new(big.Int).Set(s.data.Balance), }) - self.setBalance(amount) + s.setBalance(amount) } -func (self *stateObject) setBalance(amount *big.Int) { - self.data.Balance = amount - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) setBalance(amount *big.Int) { + s.data.Balance = amount + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { - stateObject := newObject(db, self.address, self.data, onDirty) - if self.trie != nil { - stateObject.trie = db.db.CopyTrie(self.trie) +func (s *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address)) *stateObject { + stateObject := newObject(db, s.address, s.data, onDirty) + if s.trie != nil { + stateObject.trie = db.db.CopyTrie(s.trie) } - stateObject.code = self.code - stateObject.dirtyStorage = self.dirtyStorage.Copy() - stateObject.cachedStorage = self.dirtyStorage.Copy() - stateObject.suicided = self.suicided - stateObject.dirtyCode = self.dirtyCode - stateObject.deleted = self.deleted + stateObject.code = s.code + stateObject.dirtyStorage = s.dirtyStorage.Copy() + stateObject.cachedStorage = s.dirtyStorage.Copy() + stateObject.suicided = s.suicided + stateObject.dirtyCode = s.dirtyCode + stateObject.deleted = s.deleted return stateObject } @@ -325,81 +363,81 @@ func (self *stateObject) deepCopy(db *StateDB, onDirty func(addr common.Address) // // Returns the address of the contract/account -func (c *stateObject) Address() common.Address { - return c.address +func (s *stateObject) Address() common.Address { + return s.address } // Code returns the contract code associated with this object, if any. -func (self *stateObject) Code(db Database) []byte { - if self.code != nil { - return self.code +func (s *stateObject) Code(db Database) []byte { + if s.code != nil { + return s.code } - if bytes.Equal(self.CodeHash(), emptyCodeHash) { + if bytes.Equal(s.CodeHash(), emptyCodeHash) { return nil } - code, err := db.ContractCode(self.addrHash, common.BytesToHash(self.CodeHash())) + code, err := db.ContractCode(s.addrHash, common.BytesToHash(s.CodeHash())) if err != nil { - self.setError(fmt.Errorf("can't load code hash %x: %v", self.CodeHash(), err)) + s.setError(fmt.Errorf("can't load code hash %x: %v", s.CodeHash(), err)) } - self.code = code + s.code = code return code } -func (self *stateObject) SetCode(codeHash common.Hash, code []byte) { - prevcode := self.Code(self.db.db) - self.db.journal = append(self.db.journal, codeChange{ - account: &self.address, - prevhash: self.CodeHash(), +func (s *stateObject) SetCode(codeHash common.Hash, code []byte) { + prevcode := s.Code(s.db.db) + s.db.journal = append(s.db.journal, codeChange{ + account: &s.address, + prevhash: s.CodeHash(), prevcode: prevcode, }) - self.setCode(codeHash, code) + s.setCode(codeHash, code) } -func (self *stateObject) setCode(codeHash common.Hash, code []byte) { - self.code = code - self.data.CodeHash = codeHash[:] - self.dirtyCode = true - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) setCode(codeHash common.Hash, code []byte) { + s.code = code + s.data.CodeHash = codeHash[:] + s.dirtyCode = true + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (self *stateObject) SetNonce(nonce uint64) { - self.db.journal = append(self.db.journal, nonceChange{ - account: &self.address, - prev: self.data.Nonce, +func (s *stateObject) SetNonce(nonce uint64) { + s.db.journal = append(s.db.journal, nonceChange{ + account: &s.address, + prev: s.data.Nonce, }) - self.setNonce(nonce) + s.setNonce(nonce) } -func (self *stateObject) setNonce(nonce uint64) { - self.data.Nonce = nonce - if self.onDirty != nil { - self.onDirty(self.Address()) - self.onDirty = nil +func (s *stateObject) setNonce(nonce uint64) { + s.data.Nonce = nonce + if s.onDirty != nil { + s.onDirty(s.Address()) + s.onDirty = nil } } -func (self *stateObject) CodeHash() []byte { - return self.data.CodeHash +func (s *stateObject) CodeHash() []byte { + return s.data.CodeHash } -func (self *stateObject) Balance() *big.Int { - return self.data.Balance +func (s *stateObject) Balance() *big.Int { + return s.data.Balance } -func (self *stateObject) Nonce() uint64 { - return self.data.Nonce +func (s *stateObject) Nonce() uint64 { + return s.data.Nonce } -func (self *stateObject) Root() common.Hash { - return self.data.Root +func (s *stateObject) Root() common.Hash { + return s.data.Root } // Never called, but must be present to allow stateObject to be used // as a vm.Account interface that also satisfies the vm.ContractRef // interface. Interfaces are awesome. -func (self *stateObject) Value() *big.Int { +func (s *stateObject) Value() *big.Int { panic("Value on stateObject should never be called") } diff --git a/core/state/statedb.go b/core/state/statedb.go index 4bedebe8b3c3..15ca9ca67efc 100644 --- a/core/state/statedb.go +++ b/core/state/statedb.go @@ -74,6 +74,9 @@ type StateDB struct { preimages map[common.Hash][]byte + // Per-transaction access list + accessList *accessList + // Journal of state modifications. This is the backbone of // Snapshot and RevertToSnapshot. journal journal @@ -121,6 +124,7 @@ func New(root common.Hash, db Database) (*StateDB, error) { stateObjectsDirty: make(map[common.Address]struct{}), logs: make(map[common.Hash][]*types.Log), preimages: make(map[common.Hash][]byte), + accessList: newAccessList(), }, nil } @@ -152,6 +156,7 @@ func (self *StateDB) Reset(root common.Hash) error { self.logSize = 0 self.preimages = make(map[common.Hash][]byte) self.clearJournalAndRefund() + self.accessList = newAccessList() return nil } @@ -239,6 +244,16 @@ func (self *StateDB) GetStorageRoot(addr common.Address) common.Hash { return common.Hash{} } +// TxIndex returns the current transaction index set by Prepare. +func (self *StateDB) TxIndex() int { + return self.txIndex +} + +// BlockHash returns the current block hash set by Prepare. +func (self *StateDB) BlockHash() common.Hash { + return self.bhash +} + func (self *StateDB) GetCode(addr common.Address) []byte { stateObject := self.getStateObject(addr) if stateObject != nil { @@ -372,6 +387,15 @@ func (self *StateDB) SetState(addr common.Address, key, value common.Hash) { } } +// SetStorage replaces the entire storage for the specified account with given +// storage. This function should only be used for debugging. +func (s *StateDB) SetStorage(addr common.Address, storage map[common.Hash]common.Hash) { + stateObject := s.GetOrNewStateObject(addr) + if stateObject != nil { + stateObject.SetStorage(storage) + } +} + // Suicide marks the given account as suicided. // This clears the account balance. // @@ -489,8 +513,8 @@ func (self *StateDB) createObject(addr common.Address) (newobj, prev *stateObjec // CreateAccount is called during the EVM CREATE operation. The situation might arise that // a contract does the following: // -// 1. sends funds to sha(account ++ (nonce + 1)) -// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) +// 1. sends funds to sha(account ++ (nonce + 1)) +// 2. tx_create(sha(account ++ nonce)) (note that this gets the address of 1) // // Carrying over the balance ensures that Ether doesn't disappear. func (self *StateDB) CreateAccount(addr common.Address) { @@ -544,13 +568,26 @@ func (self *StateDB) Copy() *StateDB { state.stateObjects[addr] = self.stateObjects[addr].deepCopy(state, state.MarkStateObjectDirty) state.stateObjectsDirty[addr] = struct{}{} } + + // Deep copy the logs occurred in the scope of block for hash, logs := range self.logs { - state.logs[hash] = make([]*types.Log, len(logs)) - copy(state.logs[hash], logs) + cpy := make([]*types.Log, len(logs)) + for i, l := range logs { + cpy[i] = new(types.Log) + *cpy[i] = *l + } + state.logs[hash] = cpy } + for hash, preimage := range self.preimages { state.preimages[hash] = preimage } + // Do we need to copy the access list? In practice: No. At the start of a + // transaction, the access list is empty. In practice, we only ever copy state + // _between_ transactions/blocks, never in the middle of a transaction. + // However, it doesn't cost us much to copy an empty list, so we do it anyway + // to not blow up if we ever decide copy it in the middle of a transaction + state.accessList = self.accessList.Copy() return state } @@ -618,6 +655,7 @@ func (self *StateDB) Prepare(thash, bhash common.Hash, ti int) { self.thash = thash self.bhash = bhash self.txIndex = ti + self.accessList = newAccessList() } // DeleteSuicides flags the suicided objects for deletion so that it @@ -692,6 +730,67 @@ func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) return root, err } +// PrepareAccessList handles the preparatory steps for executing a state transition with +// regards to both EIP-2929 and EIP-2930: +// +// - Add sender to access list (2929) +// - Add destination to access list (2929) +// - Add precompiles to access list (2929) +// - Add the contents of the optional tx access list (2930) +// +// This method should only be called if Yolov3/Berlin/2929+2930 is applicable at the current number. +func (s *StateDB) PrepareAccessList(sender common.Address, dst *common.Address, precompiles []common.Address, list types.AccessList) { + s.AddAddressToAccessList(sender) + if dst != nil { + s.AddAddressToAccessList(*dst) + // If it's a create-tx, the destination will be added inside evm.create + } + for _, addr := range precompiles { + s.AddAddressToAccessList(addr) + } + for _, el := range list { + s.AddAddressToAccessList(el.Address) + for _, key := range el.StorageKeys { + s.AddSlotToAccessList(el.Address, key) + } + } +} + +// AddAddressToAccessList adds the given address to the access list +func (s *StateDB) AddAddressToAccessList(addr common.Address) { + if s.accessList.AddAddress(addr) { + s.journal = append(s.journal, accessListAddAccountChange{&addr}) + } +} + +// AddSlotToAccessList adds the given (address, slot)-tuple to the access list +func (s *StateDB) AddSlotToAccessList(addr common.Address, slot common.Hash) { + addrMod, slotMod := s.accessList.AddSlot(addr, slot) + if addrMod { + // In practice, this should not happen, since there is no way to enter the + // scope of 'address' without having the 'address' become already added + // to the access list (via call-variant, create, etc). + // Better safe than sorry, though + s.journal = append(s.journal, accessListAddAccountChange{&addr}) + } + if slotMod { + s.journal = append(s.journal, accessListAddSlotChange{ + address: &addr, + slot: &slot, + }) + } +} + +// AddressInAccessList returns true if the given address is in the access list. +func (s *StateDB) AddressInAccessList(addr common.Address) bool { + return s.accessList.ContainsAddress(addr) +} + +// SlotInAccessList returns true if the given (address, slot)-tuple is in the access list. +func (s *StateDB) SlotInAccessList(addr common.Address, slot common.Hash) (addressPresent bool, slotPresent bool) { + return s.accessList.Contains(addr, slot) +} + func (s *StateDB) GetOwner(candidate common.Address) common.Address { slot := slotValidatorMapping["validatorsState"] // validatorsState[_candidate].owner; diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go index 7e7f5343c34a..15533ec5a086 100644 --- a/core/state/statedb_test.go +++ b/core/state/statedb_test.go @@ -20,7 +20,6 @@ import ( "bytes" "encoding/binary" "fmt" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math" "math/big" "math/rand" @@ -32,6 +31,7 @@ import ( check "gopkg.in/check.v1" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/types" ) @@ -283,6 +283,20 @@ func newTestAction(addr common.Address, r *rand.Rand) testAction { }, args: make([]int64, 1), }, + { + name: "AddAddressToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddAddressToAccessList(addr) + }, + }, + { + name: "AddSlotToAccessList", + fn: func(a testAction, s *StateDB) { + s.AddSlotToAccessList(addr, + common.Hash{byte(a.args[0])}) + }, + args: make([]int64, 1), + }, } action := actions[r.Intn(len(actions))] var nameargs []string @@ -427,3 +441,177 @@ func (s *StateSuite) TestTouchDelete(c *check.C) { c.Fatal("expected no dirty state object") } } + +func TestStateDBAccessList(t *testing.T) { + // Some helpers + addr := func(a string) common.Address { + return common.HexToAddress(a) + } + slot := func(a string) common.Hash { + return common.HexToHash(a) + } + + memDb := rawdb.NewMemoryDatabase() + db := NewDatabase(memDb) + state, _ := New(common.Hash{}, db) + state.accessList = newAccessList() + + verifyAddrs := func(astrings ...string) { + t.Helper() + // convert to common.Address form + var addresses []common.Address + var addressMap = make(map[common.Address]struct{}) + for _, astring := range astrings { + address := addr(astring) + addresses = append(addresses, address) + addressMap[address] = struct{}{} + } + // Check that the given addresses are in the access list + for _, address := range addresses { + if !state.AddressInAccessList(address) { + t.Fatalf("expected %x to be in access list", address) + } + } + // Check that only the expected addresses are present in the acesslist + for address := range state.accessList.addresses { + if _, exist := addressMap[address]; !exist { + t.Fatalf("extra address %x in access list", address) + } + } + } + verifySlots := func(addrString string, slotStrings ...string) { + if !state.AddressInAccessList(addr(addrString)) { + t.Fatalf("scope missing address/slots %v", addrString) + } + var address = addr(addrString) + // convert to common.Hash form + var slots []common.Hash + var slotMap = make(map[common.Hash]struct{}) + for _, slotString := range slotStrings { + s := slot(slotString) + slots = append(slots, s) + slotMap[s] = struct{}{} + } + // Check that the expected items are in the access list + for i, s := range slots { + if _, slotPresent := state.SlotInAccessList(address, s); !slotPresent { + t.Fatalf("input %d: scope missing slot %v (address %v)", i, s, addrString) + } + } + // Check that no extra elements are in the access list + index := state.accessList.addresses[address] + if index >= 0 { + stateSlots := state.accessList.slots[index] + for s := range stateSlots { + if _, slotPresent := slotMap[s]; !slotPresent { + t.Fatalf("scope has extra slot %v (address %v)", s, addrString) + } + } + } + } + + state.AddAddressToAccessList(addr("aa")) // 1 + state.AddSlotToAccessList(addr("bb"), slot("01")) // 2,3 + state.AddSlotToAccessList(addr("bb"), slot("02")) // 4 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + // Make a copy + stateCopy1 := state.Copy() + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + // same again, should cause no journal entries + state.AddSlotToAccessList(addr("bb"), slot("01")) + state.AddSlotToAccessList(addr("bb"), slot("02")) + state.AddAddressToAccessList(addr("aa")) + if exp, got := 4, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + // some new ones + state.AddSlotToAccessList(addr("bb"), slot("03")) // 5 + state.AddSlotToAccessList(addr("aa"), slot("01")) // 6 + state.AddSlotToAccessList(addr("cc"), slot("01")) // 7,8 + state.AddAddressToAccessList(addr("cc")) + if exp, got := 8, state.journal.length(); exp != got { + t.Fatalf("journal length mismatch: have %d, want %d", got, exp) + } + + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + verifySlots("cc", "01") + + // now start rolling back changes + state.journal[7].undo(state) + if _, ok := state.SlotInAccessList(addr("cc"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb", "cc") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal[6].undo(state) + if state.AddressInAccessList(addr("cc")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("aa", "01") + verifySlots("bb", "01", "02", "03") + + state.journal[5].undo(state) + if _, ok := state.SlotInAccessList(addr("aa"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02", "03") + + state.journal[4].undo(state) + if _, ok := state.SlotInAccessList(addr("bb"), slot("03")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + + state.journal[3].undo(state) + if _, ok := state.SlotInAccessList(addr("bb"), slot("02")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + verifySlots("bb", "01") + + state.journal[2].undo(state) + if _, ok := state.SlotInAccessList(addr("bb"), slot("01")); ok { + t.Fatalf("slot present, expected missing") + } + verifyAddrs("aa", "bb") + + state.journal[1].undo(state) + if state.AddressInAccessList(addr("bb")) { + t.Fatalf("addr present, expected missing") + } + verifyAddrs("aa") + + state.journal[0].undo(state) + if state.AddressInAccessList(addr("aa")) { + t.Fatalf("addr present, expected missing") + } + if got, exp := len(state.accessList.addresses), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 0; got != exp { + t.Fatalf("expected empty, got %d", got) + } + // Check the copy + // Make a copy + state = stateCopy1 + verifyAddrs("aa", "bb") + verifySlots("bb", "01", "02") + if got, exp := len(state.accessList.addresses), 2; got != exp { + t.Fatalf("expected empty, got %d", got) + } + if got, exp := len(state.accessList.slots), 1; got != exp { + t.Fatalf("expected empty, got %d", got) + } +} diff --git a/core/state_processor.go b/core/state_processor.go index acfec0c8fdcc..cc5697e069e5 100644 --- a/core/state_processor.go +++ b/core/state_processor.go @@ -218,17 +218,17 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* if tx.To() != nil && tx.To().String() == common.BlockSigners && config.IsTIPSigning(header.Number) { return ApplySignTransaction(config, statedb, header, tx, usedGas) } - if tx.To() != nil && tx.To().String() == common.TradingStateAddr && config.IsTIPXDCX(header.Number) { + if tx.To() != nil && tx.To().String() == common.TradingStateAddr && config.IsTIPXDCXReceiver(header.Number) { return ApplyEmptyTransaction(config, statedb, header, tx, usedGas) } - if tx.To() != nil && tx.To().String() == common.XDCXLendingAddress && config.IsTIPXDCX(header.Number) { + if tx.To() != nil && tx.To().String() == common.XDCXLendingAddress && config.IsTIPXDCXReceiver(header.Number) { return ApplyEmptyTransaction(config, statedb, header, tx, usedGas) } - if tx.IsTradingTransaction() && config.IsTIPXDCX(header.Number) { + if tx.IsTradingTransaction() && config.IsTIPXDCXReceiver(header.Number) { return ApplyEmptyTransaction(config, statedb, header, tx, usedGas) } - if tx.IsLendingFinalizedTradeTransaction() && config.IsTIPXDCX(header.Number) { + if tx.IsLendingFinalizedTradeTransaction() && config.IsTIPXDCXReceiver(header.Number) { return ApplyEmptyTransaction(config, statedb, header, tx, usedGas) } @@ -242,7 +242,7 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* if err != nil { return nil, 0, err, false } - // Create a new context to be used in the EVM environment + // Create a new context to be used in the EVM environment. context := NewEVMContext(msg, header, bc, author) // Create a new environment which holds all relevant information // about the transaction and calling mechanisms. @@ -408,7 +408,8 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* if err != nil { return nil, 0, err, false } - // Update the state with pending changes + + // Update the state with pending changes. var root []byte if config.IsByzantium(header.Number) { statedb.Finalise(true) @@ -417,18 +418,28 @@ func ApplyTransaction(config *params.ChainConfig, tokensFee map[common.Address]* } *usedGas += gas - // Create a new receipt for the transaction, storing the intermediate root and gas used by the tx - // based on the eip phase, we're passing wether the root touch-delete accounts. - receipt := types.NewReceipt(root, failed, *usedGas) + // Create a new receipt for the transaction, storing the intermediate root and gas used + // by the tx. + receipt := &types.Receipt{Type: tx.Type(), PostState: root, CumulativeGasUsed: *usedGas} + if failed { + receipt.Status = types.ReceiptStatusFailed + } else { + receipt.Status = types.ReceiptStatusSuccessful + } receipt.TxHash = tx.Hash() receipt.GasUsed = gas - // if the transaction created a contract, store the creation address in the receipt. + + // If the transaction created a contract, store the creation address in the receipt. if msg.To() == nil { receipt.ContractAddress = crypto.CreateAddress(vmenv.Context.Origin, tx.Nonce()) } - // Set the receipt logs and create a bloom for filtering + + // Set the receipt logs and create the bloom filter. receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) + receipt.BlockHash = statedb.BlockHash() + receipt.BlockNumber = header.Number + receipt.TransactionIndex = uint(statedb.TxIndex()) if balanceFee != nil && failed { state.PayFeeWithTRC21TxFail(statedb, msg.From(), *tx.To()) } @@ -467,6 +478,9 @@ func ApplySignTransaction(config *params.ChainConfig, statedb *state.StateDB, he statedb.AddLog(log) receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) + receipt.BlockHash = statedb.BlockHash() + receipt.BlockNumber = header.Number + receipt.TransactionIndex = uint(statedb.TxIndex()) return receipt, 0, nil, false } @@ -491,6 +505,9 @@ func ApplyEmptyTransaction(config *params.ChainConfig, statedb *state.StateDB, h statedb.AddLog(log) receipt.Logs = statedb.GetLogs(tx.Hash()) receipt.Bloom = types.CreateBloom(types.Receipts{receipt}) + receipt.BlockHash = statedb.BlockHash() + receipt.BlockNumber = header.Number + receipt.TransactionIndex = uint(statedb.TxIndex()) return receipt, 0, nil, false } @@ -512,7 +529,6 @@ func InitSignerInTransactions(config *params.ChainConfig, header *types.Header, go func(from int, to int) { for j := from; j < to; j++ { types.CacheSigner(signer, txs[j]) - txs[j].CacheHash() } wg.Done() }(from, to) diff --git a/core/state_transition.go b/core/state_transition.go index c9c4fdfefd40..48d6ebcae4ec 100644 --- a/core/state_transition.go +++ b/core/state_transition.go @@ -22,6 +22,7 @@ import ( "math/big" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/params" @@ -42,8 +43,10 @@ The state transitioning model does all all the necessary work to work out a vali 3) Create a new state object if the recipient is \0*32 4) Value transfer == If contract creation == - 4a) Attempt to run transaction data - 4b) If valid, use result as code for the new state object + + 4a) Attempt to run transaction data + 4b) If valid, use result as code for the new state object + == end == 5) Run Script section 6) Derive new state root @@ -74,13 +77,14 @@ type Message interface { CheckNonce() bool Data() []byte BalanceTokenFee() *big.Int + AccessList() types.AccessList } // IntrinsicGas computes the 'intrinsic gas' for a message with the given data. -func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) { +func IntrinsicGas(data []byte, accessList types.AccessList, isContractCreation, isHomestead bool) (uint64, error) { // Set the starting gas for the raw transaction var gas uint64 - if contractCreation && homestead { + if isContractCreation && isHomestead { gas = params.TxGasContractCreation } else { gas = params.TxGas @@ -106,6 +110,10 @@ func IntrinsicGas(data []byte, contractCreation, homestead bool) (uint64, error) } gas += z * params.TxDataZeroGas } + if accessList != nil { + gas += uint64(len(accessList)) * params.TxAccessListAddressGas + gas += uint64(accessList.StorageKeys()) * params.TxAccessListStorageKeyGas + } return gas, nil } @@ -226,7 +234,7 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG contractCreation := msg.To() == nil // Pay intrinsic gas - gas, err := IntrinsicGas(st.data, contractCreation, homestead) + gas, err := IntrinsicGas(st.data, st.msg.AccessList(), contractCreation, homestead) if err != nil { return nil, 0, false, err, nil } @@ -234,6 +242,10 @@ func (st *StateTransition) TransitionDb(owner common.Address) (ret []byte, usedG return nil, 0, false, err, nil } + if rules := st.evm.ChainConfig().Rules(st.evm.Context.BlockNumber); rules.IsEIP1559 { + st.state.PrepareAccessList(msg.From(), msg.To(), vm.ActivePrecompiles(rules), msg.AccessList()) + } + var ( evm = st.evm // vm errors do not effect consensus and are therefor diff --git a/core/token_validator.go b/core/token_validator.go index 49e6261f89e8..f604be504f35 100644 --- a/core/token_validator.go +++ b/core/token_validator.go @@ -27,6 +27,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/contracts/XDCx/contract" "github.com/XinFinOrg/XDPoSChain/core/state" + "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/log" ) @@ -37,20 +38,21 @@ const ( getDecimalFunction = "decimals" ) -// callmsg implements core.Message to allow passing it as a transaction simulator. -type callmsg struct { +// callMsg implements core.Message to allow passing it as a transaction simulator. +type callMsg struct { ethereum.CallMsg } -func (m callmsg) From() common.Address { return m.CallMsg.From } -func (m callmsg) Nonce() uint64 { return 0 } -func (m callmsg) CheckNonce() bool { return false } -func (m callmsg) To() *common.Address { return m.CallMsg.To } -func (m callmsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } -func (m callmsg) Gas() uint64 { return m.CallMsg.Gas } -func (m callmsg) Value() *big.Int { return m.CallMsg.Value } -func (m callmsg) Data() []byte { return m.CallMsg.Data } -func (m callmsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) From() common.Address { return m.CallMsg.From } +func (m callMsg) Nonce() uint64 { return 0 } +func (m callMsg) CheckNonce() bool { return false } +func (m callMsg) To() *common.Address { return m.CallMsg.To } +func (m callMsg) GasPrice() *big.Int { return m.CallMsg.GasPrice } +func (m callMsg) Gas() uint64 { return m.CallMsg.Gas } +func (m callMsg) Value() *big.Int { return m.CallMsg.Value } +func (m callMsg) Data() []byte { return m.CallMsg.Data } +func (m callMsg) BalanceTokenFee() *big.Int { return m.CallMsg.BalanceTokenFee } +func (m callMsg) AccessList() types.AccessList { return m.CallMsg.AccessList } type SimulatedBackend interface { CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) @@ -86,7 +88,7 @@ func RunContract(chain consensus.ChainContext, statedb *state.StateDB, contractA return unpackResult, nil } -//FIXME: please use copyState for this function +// FIXME: please use copyState for this function // CallContractWithState executes a contract call at the given state. func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, statedb *state.StateDB) ([]byte, error) { // Ensure message is initialized properly. @@ -99,7 +101,7 @@ func CallContractWithState(call ethereum.CallMsg, chain consensus.ChainContext, call.Value = new(big.Int) } // Execute the call. - msg := callmsg{call} + msg := callMsg{call} feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) if msg.To() != nil { if value, ok := feeCapacity[*msg.To()]; ok { @@ -124,7 +126,7 @@ func ValidateXDCXApplyTransaction(chain consensus.ChainContext, blockNumber *big if blockNumber == nil || blockNumber.Sign() <= 0 { blockNumber = chain.CurrentHeader().Number } - if !chain.Config().IsTIPXDCX(blockNumber) { + if !chain.Config().IsTIPXDCXReceiver(blockNumber) { return nil } contractABI, err := GetTokenAbi(contract.TRC21ABI) @@ -146,7 +148,7 @@ func ValidateXDCZApplyTransaction(chain consensus.ChainContext, blockNumber *big if blockNumber == nil || blockNumber.Sign() <= 0 { blockNumber = chain.CurrentHeader().Number } - if !chain.Config().IsTIPXDCX(blockNumber) { + if !chain.Config().IsTIPXDCXReceiver(blockNumber) { return nil } contractABI, err := GetTokenAbi(contract.TRC21ABI) diff --git a/core/tx_cacher.go b/core/tx_cacher.go new file mode 100644 index 000000000000..ea4ab6cc07f6 --- /dev/null +++ b/core/tx_cacher.go @@ -0,0 +1,105 @@ +// Copyright 2018 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package core + +import ( + "runtime" + + "github.com/XinFinOrg/XDPoSChain/core/types" +) + +// senderCacher is a concurrent tranaction sender recoverer anc cacher. +var senderCacher = newTxSenderCacher(runtime.NumCPU()) + +// txSenderCacherRequest is a request for recovering transaction senders with a +// specific signature scheme and caching it into the transactions themselves. +// +// The inc field defines the number of transactions to skip after each recovery, +// which is used to feed the same underlying input array to different threads but +// ensure they process the early transactions fast. +type txSenderCacherRequest struct { + signer types.Signer + txs []*types.Transaction + inc int +} + +// txSenderCacher is a helper structure to concurrently ecrecover transaction +// senders from digital signatures on background threads. +type txSenderCacher struct { + threads int + tasks chan *txSenderCacherRequest +} + +// newTxSenderCacher creates a new transaction sender background cacher and starts +// as many procesing goroutines as allowed by the GOMAXPROCS on construction. +func newTxSenderCacher(threads int) *txSenderCacher { + cacher := &txSenderCacher{ + tasks: make(chan *txSenderCacherRequest, threads), + threads: threads, + } + for i := 0; i < threads; i++ { + go cacher.cache() + } + return cacher +} + +// cache is an infinite loop, caching transaction senders from various forms of +// data structures. +func (cacher *txSenderCacher) cache() { + for task := range cacher.tasks { + for i := 0; i < len(task.txs); i += task.inc { + types.Sender(task.signer, task.txs[i]) + } + } +} + +// recover recovers the senders from a batch of transactions and caches them +// back into the same data structures. There is no validation being done, nor +// any reaction to invalid signatures. That is up to calling code later. +func (cacher *txSenderCacher) recover(signer types.Signer, txs []*types.Transaction) { + // If there's nothing to recover, abort + if len(txs) == 0 { + return + } + // Ensure we have meaningful task sizes and schedule the recoveries + tasks := cacher.threads + if len(txs) < tasks*4 { + tasks = (len(txs) + 3) / 4 + } + for i := 0; i < tasks; i++ { + cacher.tasks <- &txSenderCacherRequest{ + signer: signer, + txs: txs[i:], + inc: tasks, + } + } +} + +// recoverFromBlocks recovers the senders from a batch of blocks and caches them +// back into the same data structures. There is no validation being done, nor +// any reaction to invalid signatures. That is up to calling code later. +func (cacher *txSenderCacher) recoverFromBlocks(signer types.Signer, blocks []*types.Block) { + count := 0 + for _, block := range blocks { + count += len(block.Transactions()) + } + txs := make([]*types.Transaction, 0, count) + for _, block := range blocks { + txs = append(txs, block.Transactions()...) + } + cacher.recover(signer, txs) +} diff --git a/core/tx_journal.go b/core/tx_journal.go index a6e525012f5f..4fe5fdca365c 100644 --- a/core/tx_journal.go +++ b/core/tx_journal.go @@ -56,7 +56,7 @@ func newTxJournal(path string) *txJournal { // load parses a transaction journal dump from disk, loading its contents into // the specified pool. -func (journal *txJournal) load(add func(*types.Transaction) error) error { +func (journal *txJournal) load(add func([]*types.Transaction) []error) error { // Skip the parsing if the journal file doens't exist at all if _, err := os.Stat(journal.path); os.IsNotExist(err) { return nil @@ -76,7 +76,21 @@ func (journal *txJournal) load(add func(*types.Transaction) error) error { stream := rlp.NewStream(input, 0) total, dropped := 0, 0 - var failure error + // Create a method to load a limited batch of transactions and bump the + // appropriate progress counters. Then use this method to load all the + // journalled transactions in small-ish batches. + loadBatch := func(txs types.Transactions) { + for _, err := range add(txs) { + if err != nil { + log.Debug("Failed to add journaled transaction", "err", err) + dropped++ + } + } + } + var ( + failure error + batch types.Transactions + ) for { // Parse the next transaction and terminate on error tx := new(types.Transaction) @@ -84,14 +98,16 @@ func (journal *txJournal) load(add func(*types.Transaction) error) error { if err != io.EOF { failure = err } + if batch.Len() > 0 { + loadBatch(batch) + } break } - // Import the transaction and bump the appropriate progress counters + // New transaction parsed, queue up for later, import if threnshold is reached total++ - if err = add(tx); err != nil { - log.Debug("Failed to add journaled transaction", "err", err) - dropped++ - continue + if batch = append(batch, tx); batch.Len() > 1024 { + loadBatch(batch) + batch = batch[:0] } } log.Info("Loaded local transaction journal", "transactions", total, "dropped", dropped) diff --git a/core/tx_list.go b/core/tx_list.go index 030c4cd30012..5f623806d609 100644 --- a/core/tx_list.go +++ b/core/tx_list.go @@ -24,7 +24,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" - "github.com/XinFinOrg/XDPoSChain/log" ) // nonceHeap is a heap.Interface implementation over 64bit unsigned integers for @@ -99,7 +98,30 @@ func (m *txSortedMap) Forward(threshold uint64) types.Transactions { // Filter iterates over the list of transactions and removes all of them for which // the specified function evaluates to true. +// Filter, as opposed to 'filter', re-initialises the heap after the operation is done. +// If you want to do several consecutive filterings, it's therefore better to first +// do a .filter(func1) followed by .Filter(func2) or reheap() func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transactions { + removed := m.filter(filter) + // If transactions were removed, the heap and cache are ruined + if len(removed) > 0 { + m.reheap() + } + return removed +} + +func (m *txSortedMap) reheap() { + *m.index = make([]uint64, 0, len(m.items)) + for nonce := range m.items { + *m.index = append(*m.index, nonce) + } + heap.Init(m.index) + m.cache = nil +} + +// filter is identical to Filter, but **does not** regenerate the heap. This method +// should only be used if followed immediately by a call to Filter or reheap() +func (m *txSortedMap) filter(filter func(*types.Transaction) bool) types.Transactions { var removed types.Transactions // Collect all the transactions to filter out @@ -109,14 +131,7 @@ func (m *txSortedMap) Filter(filter func(*types.Transaction) bool) types.Transac delete(m.items, nonce) } } - // If transactions were removed, the heap and cache are ruined if len(removed) > 0 { - *m.index = make([]uint64, 0, len(m.items)) - for nonce := range m.items { - *m.index = append(*m.index, nonce) - } - heap.Init(m.index) - m.cache = nil } return removed @@ -197,10 +212,7 @@ func (m *txSortedMap) Len() int { return len(m.items) } -// Flatten creates a nonce-sorted slice of transactions based on the loosely -// sorted internal representation. The result of the sorting is cached in case -// it's requested again before any modifications are made to the contents. -func (m *txSortedMap) Flatten() types.Transactions { +func (m *txSortedMap) flatten() types.Transactions { // If the sorting was not cached yet, create and cache it if m.cache == nil { m.cache = make(types.Transactions, 0, len(m.items)) @@ -209,12 +221,27 @@ func (m *txSortedMap) Flatten() types.Transactions { } sort.Sort(types.TxByNonce(m.cache)) } + return m.cache +} + +// Flatten creates a nonce-sorted slice of transactions based on the loosely +// sorted internal representation. The result of the sorting is cached in case +// it's requested again before any modifications are made to the contents. +func (m *txSortedMap) Flatten() types.Transactions { // Copy the cache to prevent accidental modifications - txs := make(types.Transactions, len(m.cache)) - copy(txs, m.cache) + cache := m.flatten() + txs := make(types.Transactions, len(cache)) + copy(txs, cache) return txs } +// LastElement returns the last element of a flattened list, thus, the +// transaction with the highest nonce +func (m *txSortedMap) LastElement() *types.Transaction { + cache := m.flatten() + return cache[len(cache)-1] +} + // txList is a "list" of transactions belonging to an account, sorted by account // nonce. The same type can be used both for storing contiguous transactions for // the executable/pending queue; and for storing gapped transactions for the non- @@ -255,11 +282,15 @@ func (l *txList) Add(tx *types.Transaction, priceBump uint64) (bool, *types.Tran return false, nil } if old != nil { - threshold := new(big.Int).Div(new(big.Int).Mul(old.GasPrice(), big.NewInt(100+int64(priceBump))), big.NewInt(100)) + // threshold = oldGP * (100 + priceBump) / 100 + a := big.NewInt(100 + int64(priceBump)) + a = a.Mul(a, old.GasPrice()) + b := big.NewInt(100) + threshold := a.Div(a, b) // Have to ensure that the new gas price is higher than the old gas // price as well as checking the percentage threshold to ensure that // this is accurate for low (Wei-level) gas price replacements - if old.GasPrice().Cmp(tx.GasPrice()) >= 0 || threshold.Cmp(tx.GasPrice()) > 0 { + if old.GasPriceCmp(tx) >= 0 || tx.GasPriceIntCmp(threshold) < 0 { return false, nil } } @@ -303,24 +334,27 @@ func (l *txList) Filter(costLimit *big.Int, gasLimit uint64, trc21Issuers map[co maximum := costLimit if tx.To() != nil { if feeCapacity, ok := trc21Issuers[*tx.To()]; ok { - return new(big.Int).Add(costLimit, feeCapacity).Cmp(tx.TxCost(number)) < 0 || tx.Gas() > gasLimit + return tx.Gas() > gasLimit || new(big.Int).Add(costLimit, feeCapacity).Cmp(tx.TxCost(number)) < 0 } } - return tx.Cost().Cmp(maximum) > 0 || tx.Gas() > gasLimit + return tx.Gas() > gasLimit || tx.Cost().Cmp(maximum) > 0 }) - // If the list was strict, filter anything above the lowest nonce + if len(removed) == 0 { + return nil, nil + } var invalids types.Transactions - - if l.strict && len(removed) > 0 { + // If the list was strict, filter anything above the lowest nonce + if l.strict { lowest := uint64(math.MaxUint64) for _, tx := range removed { if nonce := tx.Nonce(); lowest > nonce { lowest = nonce } } - invalids = l.txs.Filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) + invalids = l.txs.filter(func(tx *types.Transaction) bool { return tx.Nonce() > lowest }) } + l.txs.reheap() return removed, invalids } @@ -374,13 +408,30 @@ func (l *txList) Flatten() types.Transactions { return l.txs.Flatten() } +// LastElement returns the last element of a flattened list, thus, the +// transaction with the highest nonce +func (l *txList) LastElement() *types.Transaction { + return l.txs.LastElement() +} + // priceHeap is a heap.Interface implementation over transactions for retrieving // price-sorted transactions to discard when the pool fills up. type priceHeap []*types.Transaction -func (h priceHeap) Len() int { return len(h) } -func (h priceHeap) Less(i, j int) bool { return h[i].GasPrice().Cmp(h[j].GasPrice()) < 0 } -func (h priceHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h priceHeap) Len() int { return len(h) } +func (h priceHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h priceHeap) Less(i, j int) bool { + // Sort primarily by price, returning the cheaper one + switch h[i].GasPriceCmp(h[j]) { + case -1: + return true + case 1: + return false + } + // If the prices match, stabilize via nonces (high nonce is worse) + return h[i].Nonce() > h[j].Nonce() +} func (h *priceHeap) Push(x interface{}) { *h = append(*h, x.(*types.Transaction)) @@ -395,125 +446,126 @@ func (h *priceHeap) Pop() interface{} { } // txPricedList is a price-sorted heap to allow operating on transactions pool -// contents in a price-incrementing way. +// contents in a price-incrementing way. It's built opon the all transactions +// in txpool but only interested in the remote part. It means only remote transactions +// will be considered for tracking, sorting, eviction, etc. type txPricedList struct { - all *map[common.Hash]*types.Transaction // Pointer to the map of all transactions - items *priceHeap // Heap of prices of all the stored transactions - stales int // Number of stale price points to (re-heap trigger) + all *txLookup // Pointer to the map of all transactions + remotes *priceHeap // Heap of prices of all the stored **remote** transactions + stales int // Number of stale price points to (re-heap trigger) } // newTxPricedList creates a new price-sorted transaction heap. -func newTxPricedList(all *map[common.Hash]*types.Transaction) *txPricedList { +func newTxPricedList(all *txLookup) *txPricedList { return &txPricedList{ - all: all, - items: new(priceHeap), + all: all, + remotes: new(priceHeap), } } // Put inserts a new transaction into the heap. -func (l *txPricedList) Put(tx *types.Transaction) { - heap.Push(l.items, tx) +func (l *txPricedList) Put(tx *types.Transaction, local bool) { + if local { + return + } + heap.Push(l.remotes, tx) } // Removed notifies the prices transaction list that an old transaction dropped // from the pool. The list will just keep a counter of stale objects and update // the heap if a large enough ratio of transactions go stale. -func (l *txPricedList) Removed() { +func (l *txPricedList) Removed(count int) { // Bump the stale counter, but exit if still too low (< 25%) - l.stales++ - if l.stales <= len(*l.items)/4 { + l.stales += count + if l.stales <= len(*l.remotes)/4 { return } // Seems we've reached a critical number of stale transactions, reheap - reheap := make(priceHeap, 0, len(*l.all)) - - l.stales, l.items = 0, &reheap - for _, tx := range *l.all { - *l.items = append(*l.items, tx) - } - heap.Init(l.items) + l.Reheap() } // Cap finds all the transactions below the given price threshold, drops them -// from the priced list and returs them for further removal from the entire pool. -func (l *txPricedList) Cap(threshold *big.Int, local *accountSet) types.Transactions { +// from the priced list and returns them for further removal from the entire pool. +// +// Note: only remote transactions will be considered for eviction. +func (l *txPricedList) Cap(threshold *big.Int) types.Transactions { drop := make(types.Transactions, 0, 128) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 { + for len(*l.remotes) > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if _, ok := (*l.all)[tx.Hash()]; !ok { + cheapest := (*l.remotes)[0] + if l.all.GetRemote(cheapest.Hash()) == nil { // Removed or migrated + heap.Pop(l.remotes) l.stales-- continue } // Stop the discards if we've reached the threshold - if tx.GasPrice().Cmp(threshold) >= 0 { - save = append(save, tx) + if cheapest.GasPriceIntCmp(threshold) >= 0 { break } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - } - } - for _, tx := range save { - heap.Push(l.items, tx) + heap.Pop(l.remotes) + drop = append(drop, cheapest) } return drop } // Underpriced checks whether a transaction is cheaper than (or as cheap as) the -// lowest priced transaction currently being tracked. -func (l *txPricedList) Underpriced(tx *types.Transaction, local *accountSet) bool { - // Local transactions cannot be underpriced - if local.containsTx(tx) { - return false - } +// lowest priced (remote) transaction currently being tracked. +func (l *txPricedList) Underpriced(tx *types.Transaction) bool { // Discard stale price points if found at the heap start - for len(*l.items) > 0 { - head := []*types.Transaction(*l.items)[0] - if _, ok := (*l.all)[head.Hash()]; !ok { + for len(*l.remotes) > 0 { + head := []*types.Transaction(*l.remotes)[0] + if l.all.GetRemote(head.Hash()) == nil { // Removed or migrated l.stales-- - heap.Pop(l.items) + heap.Pop(l.remotes) continue } break } // Check if the transaction is underpriced or not - if len(*l.items) == 0 { - log.Error("Pricing query for empty pool") // This cannot happen, print to catch programming errors - return false + if len(*l.remotes) == 0 { + return false // There is no remote transaction at all. } - cheapest := []*types.Transaction(*l.items)[0] - return cheapest.GasPrice().Cmp(tx.GasPrice()) >= 0 + // If the remote transaction is even cheaper than the + // cheapest one tracked locally, reject it. + cheapest := []*types.Transaction(*l.remotes)[0] + return cheapest.GasPriceCmp(tx) >= 0 } // Discard finds a number of most underpriced transactions, removes them from the // priced list and returns them for further removal from the entire pool. -func (l *txPricedList) Discard(count int, local *accountSet) types.Transactions { - drop := make(types.Transactions, 0, count) // Remote underpriced transactions to drop - save := make(types.Transactions, 0, 64) // Local underpriced transactions to keep - - for len(*l.items) > 0 && count > 0 { +// +// Note local transaction won't be considered for eviction. +func (l *txPricedList) Discard(slots int, force bool) (types.Transactions, bool) { + drop := make(types.Transactions, 0, slots) // Remote underpriced transactions to drop + for len(*l.remotes) > 0 && slots > 0 { // Discard stale transactions if found during cleanup - tx := heap.Pop(l.items).(*types.Transaction) - if _, ok := (*l.all)[tx.Hash()]; !ok { + tx := heap.Pop(l.remotes).(*types.Transaction) + if l.all.GetRemote(tx.Hash()) == nil { // Removed or migrated l.stales-- continue } - // Non stale transaction found, discard unless local - if local.containsTx(tx) { - save = append(save, tx) - } else { - drop = append(drop, tx) - count-- + // Non stale transaction found, discard it + drop = append(drop, tx) + slots -= numSlots(tx) + } + // If we still can't make enough room for the new transaction + if slots > 0 && !force { + for _, tx := range drop { + heap.Push(l.remotes, tx) } + return nil, false } - for _, tx := range save { - heap.Push(l.items, tx) - } - return drop + return drop, true +} + +// Reheap forcibly rebuilds the heap based on the current remote transaction set. +func (l *txPricedList) Reheap() { + reheap := make(priceHeap, 0, l.all.RemoteCount()) + + l.stales, l.remotes = 0, &reheap + l.all.Range(func(hash common.Hash, tx *types.Transaction, local bool) bool { + *l.remotes = append(*l.remotes, tx) + return true + }, false, true) // Only iterate remotes + heap.Init(l.remotes) } diff --git a/core/tx_list_test.go b/core/tx_list_test.go index f0ec8eb8b4c7..36a0196f1eb3 100644 --- a/core/tx_list_test.go +++ b/core/tx_list_test.go @@ -17,6 +17,7 @@ package core import ( + "math/big" "math/rand" "testing" @@ -49,3 +50,21 @@ func TestStrictTxListAdd(t *testing.T) { } } } + +func BenchmarkTxListAdd(t *testing.B) { + // Generate a list of transactions to insert + key, _ := crypto.GenerateKey() + + txs := make(types.Transactions, 100000) + for i := 0; i < len(txs); i++ { + txs[i] = transaction(uint64(i), 0, key) + } + // Insert the transactions in a random order + list := newTxList(true) + priceLimit := big.NewInt(int64(DefaultTxPoolConfig.PriceLimit)) + t.ResetTimer() + for _, v := range rand.Perm(len(txs)) { + list.Add(txs[v], DefaultTxPoolConfig.PriceBump) + list.Filter(priceLimit, DefaultTxPoolConfig.PriceBump, nil, nil) + } +} diff --git a/core/tx_noncer.go b/core/tx_noncer.go new file mode 100644 index 000000000000..cbadc39354a3 --- /dev/null +++ b/core/tx_noncer.go @@ -0,0 +1,79 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package core + +import ( + "sync" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/state" +) + +// txNoncer is a tiny virtual state database to manage the executable nonces of +// accounts in the pool, falling back to reading from a real state database if +// an account is unknown. +type txNoncer struct { + fallback *state.StateDB + nonces map[common.Address]uint64 + lock sync.Mutex +} + +// newTxNoncer creates a new virtual state database to track the pool nonces. +func newTxNoncer(statedb *state.StateDB) *txNoncer { + return &txNoncer{ + fallback: statedb.Copy(), + nonces: make(map[common.Address]uint64), + } +} + +// get returns the current nonce of an account, falling back to a real state +// database if the account is unknown. +func (txn *txNoncer) get(addr common.Address) uint64 { + // We use mutex for get operation is the underlying + // state will mutate db even for read access. + txn.lock.Lock() + defer txn.lock.Unlock() + + if _, ok := txn.nonces[addr]; !ok { + txn.nonces[addr] = txn.fallback.GetNonce(addr) + } + return txn.nonces[addr] +} + +// set inserts a new virtual nonce into the virtual state database to be returned +// whenever the pool requests it instead of reaching into the real state database. +func (txn *txNoncer) set(addr common.Address, nonce uint64) { + txn.lock.Lock() + defer txn.lock.Unlock() + + txn.nonces[addr] = nonce +} + +// setIfLower updates a new virtual nonce into the virtual state database if the +// the new one is lower. +func (txn *txNoncer) setIfLower(addr common.Address, nonce uint64) { + txn.lock.Lock() + defer txn.lock.Unlock() + + if _, ok := txn.nonces[addr]; !ok { + txn.nonces[addr] = txn.fallback.GetNonce(addr) + } + if txn.nonces[addr] <= nonce { + return + } + txn.nonces[addr] = nonce +} diff --git a/core/tx_pool.go b/core/tx_pool.go index 2258ee4198c2..64359a972bbc 100644 --- a/core/tx_pool.go +++ b/core/tx_pool.go @@ -25,26 +25,39 @@ import ( "sync" "time" - "github.com/XinFinOrg/XDPoSChain/consensus" - "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" + "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/event" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/metrics" "github.com/XinFinOrg/XDPoSChain/params" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( // chainHeadChanSize is the size of channel listening to ChainHeadEvent. chainHeadChanSize = 10 - // rmTxChanSize is the size of channel listening to RemovedTransactionEvent. - rmTxChanSize = 10 + + // txSlotSize is used to calculate how many data slots a single transaction + // takes up based on its size. The slots are used as DoS protection, ensuring + // that validating a new transaction remains a constant operation (in reality + // O(maxslots), where max slots are 4 currently). + txSlotSize = 32 * 1024 + + // txMaxSize is the maximum size a single transaction can have. This field has + // non-trivial consequences: larger transactions are significantly harder and + // more expensive to propagate; larger transactions also take more resources + // to validate whether they fit into the pool or not. + txMaxSize = 2 * txSlotSize // 64KB, don't bump without EIP-2464 support ) var ( + // ErrAlreadyKnown is returned if the transactions is already contained + // within the pool. + ErrAlreadyKnown = errors.New("already known") + // ErrInvalidSender is returned if the transaction contains an invalid signature. ErrInvalidSender = errors.New("invalid sender") @@ -56,6 +69,10 @@ var ( // configured for the transaction pool. ErrUnderpriced = errors.New("transaction underpriced") + // ErrTxPoolOverflow is returned if the transaction pool is full and can't accpet + // another remote transaction. + ErrTxPoolOverflow = errors.New("txpool is full") + // ErrReplaceUnderpriced is returned if a transaction is attempted to be replaced // with a different one without the required price bump. ErrReplaceUnderpriced = errors.New("replacement transaction underpriced") @@ -72,7 +89,7 @@ var ( // maximum allowance of the current block. ErrGasLimit = errors.New("exceeds block gas limit") - // ErrNegativeValue is a sanity error to ensure noone is able to specify a + // ErrNegativeValue is a sanity error to ensure no one is able to specify a // transaction with a negative value. ErrNegativeValue = errors.New("negative value") @@ -97,20 +114,29 @@ var ( var ( // Metrics for the pending pool - pendingDiscardCounter = metrics.NewRegisteredCounter("txpool/pending/discard", nil) - pendingReplaceCounter = metrics.NewRegisteredCounter("txpool/pending/replace", nil) - pendingRateLimitCounter = metrics.NewRegisteredCounter("txpool/pending/ratelimit", nil) // Dropped due to rate limiting - pendingNofundsCounter = metrics.NewRegisteredCounter("txpool/pending/nofunds", nil) // Dropped due to out-of-funds + pendingDiscardMeter = metrics.NewRegisteredMeter("txpool/pending/discard", nil) + pendingReplaceMeter = metrics.NewRegisteredMeter("txpool/pending/replace", nil) + pendingRateLimitMeter = metrics.NewRegisteredMeter("txpool/pending/ratelimit", nil) // Dropped due to rate limiting + pendingNofundsMeter = metrics.NewRegisteredMeter("txpool/pending/nofunds", nil) // Dropped due to out-of-funds // Metrics for the queued pool - queuedDiscardCounter = metrics.NewRegisteredCounter("txpool/queued/discard", nil) - queuedReplaceCounter = metrics.NewRegisteredCounter("txpool/queued/replace", nil) - queuedRateLimitCounter = metrics.NewRegisteredCounter("txpool/queued/ratelimit", nil) // Dropped due to rate limiting - queuedNofundsCounter = metrics.NewRegisteredCounter("txpool/queued/nofunds", nil) // Dropped due to out-of-funds + queuedDiscardMeter = metrics.NewRegisteredMeter("txpool/queued/discard", nil) + queuedReplaceMeter = metrics.NewRegisteredMeter("txpool/queued/replace", nil) + queuedRateLimitMeter = metrics.NewRegisteredMeter("txpool/queued/ratelimit", nil) // Dropped due to rate limiting + queuedNofundsMeter = metrics.NewRegisteredMeter("txpool/queued/nofunds", nil) // Dropped due to out-of-funds + queuedEvictionMeter = metrics.NewRegisteredMeter("txpool/queued/eviction", nil) // Dropped due to lifetime // General tx metrics - invalidTxCounter = metrics.NewRegisteredCounter("txpool/invalid", nil) - underpricedTxCounter = metrics.NewRegisteredCounter("txpool/underpriced", nil) + knownTxMeter = metrics.NewRegisteredMeter("txpool/known", nil) + validTxMeter = metrics.NewRegisteredMeter("txpool/valid", nil) + invalidTxMeter = metrics.NewRegisteredMeter("txpool/invalid", nil) + underpricedTxMeter = metrics.NewRegisteredMeter("txpool/underpriced", nil) + overflowedTxMeter = metrics.NewRegisteredMeter("txpool/overflowed", nil) + + pendingGauge = metrics.NewRegisteredGauge("txpool/pending", nil) + queuedGauge = metrics.NewRegisteredGauge("txpool/queued", nil) + localGauge = metrics.NewRegisteredGauge("txpool/local", nil) + slotsGauge = metrics.NewRegisteredGauge("txpool/slots", nil) ) // TxStatus is the current status of a transaction as seen by the pool. @@ -146,14 +172,15 @@ type blockChain interface { // TxPoolConfig are the configuration parameters of the transaction pool. type TxPoolConfig struct { - NoLocals bool // Whether local transaction handling should be disabled - Journal string // Journal of local transactions to survive node restarts - Rejournal time.Duration // Time interval to regenerate the local transaction journal + Locals []common.Address // Addresses that should be treated by default as local + NoLocals bool // Whether local transaction handling should be disabled + Journal string // Journal of local transactions to survive node restarts + Rejournal time.Duration // Time interval to regenerate the local transaction journal PriceLimit uint64 // Minimum gas price to enforce for acceptance into the pool PriceBump uint64 // Minimum price bump percentage to replace an already existing transaction (nonce) - AccountSlots uint64 // Minimum number of executable transaction slots guaranteed per account + AccountSlots uint64 // Number of executable transaction slots guaranteed per account GlobalSlots uint64 // Maximum number of executable transaction slots for all accounts AccountQueue uint64 // Maximum number of non-executable transaction slots permitted per account GlobalQueue uint64 // Maximum number of non-executable transaction slots for all accounts @@ -194,6 +221,26 @@ func (config *TxPoolConfig) sanitize() TxPoolConfig { log.Warn("Sanitizing invalid txpool price bump", "provided", conf.PriceBump, "updated", DefaultTxPoolConfig.PriceBump) conf.PriceBump = DefaultTxPoolConfig.PriceBump } + if conf.AccountSlots < 1 { + log.Warn("Sanitizing invalid txpool account slots", "provided", conf.AccountSlots, "updated", DefaultTxPoolConfig.AccountSlots) + conf.AccountSlots = DefaultTxPoolConfig.AccountSlots + } + if conf.GlobalSlots < 1 { + log.Warn("Sanitizing invalid txpool global slots", "provided", conf.GlobalSlots, "updated", DefaultTxPoolConfig.GlobalSlots) + conf.GlobalSlots = DefaultTxPoolConfig.GlobalSlots + } + if conf.AccountQueue < 1 { + log.Warn("Sanitizing invalid txpool account queue", "provided", conf.AccountQueue, "updated", DefaultTxPoolConfig.AccountQueue) + conf.AccountQueue = DefaultTxPoolConfig.AccountQueue + } + if conf.GlobalQueue < 1 { + log.Warn("Sanitizing invalid txpool global queue", "provided", conf.GlobalQueue, "updated", DefaultTxPoolConfig.GlobalQueue) + conf.GlobalQueue = DefaultTxPoolConfig.GlobalQueue + } + if conf.Lifetime < 1 { + log.Warn("Sanitizing invalid txpool lifetime", "provided", conf.Lifetime, "updated", DefaultTxPoolConfig.Lifetime) + conf.Lifetime = DefaultTxPoolConfig.Lifetime + } return conf } @@ -205,37 +252,46 @@ func (config *TxPoolConfig) sanitize() TxPoolConfig { // current state) and future transactions. Transactions move between those // two states over time as they are received and processed. type TxPool struct { - config TxPoolConfig - chainconfig *params.ChainConfig - chain blockChain - gasPrice *big.Int - txFeed event.Feed - scope event.SubscriptionScope - chainHeadCh chan ChainHeadEvent - chainHeadSub event.Subscription - signer types.Signer - mu sync.RWMutex - - currentState *state.StateDB // Current state in the blockchain head - pendingState *state.ManagedState // Pending state tracking virtual nonces - currentMaxGas uint64 // Current gas limit for transaction caps + config TxPoolConfig + chainconfig *params.ChainConfig + chain blockChain + gasPrice *big.Int + txFeed event.Feed + scope event.SubscriptionScope + signer types.Signer + mu sync.RWMutex + + currentState *state.StateDB // Current state in the blockchain head + pendingNonces *txNoncer // Pending state tracking virtual nonces + currentMaxGas uint64 // Current gas limit for transaction caps locals *accountSet // Set of local transaction to exempt from eviction rules journal *txJournal // Journal of local transaction to back up to disk - pending map[common.Address]*txList // All currently processable transactions - queue map[common.Address]*txList // Queued but non-processable transactions - beats map[common.Address]time.Time // Last heartbeat from each known account - all map[common.Hash]*types.Transaction // All transactions to allow lookups - priced *txPricedList // All transactions sorted by price - - wg sync.WaitGroup // for shutdown sync - - homestead bool + pending map[common.Address]*txList // All currently processable transactions + queue map[common.Address]*txList // Queued but non-processable transactions + beats map[common.Address]time.Time // Last heartbeat from each known account + all *txLookup // All transactions to allow lookups + priced *txPricedList // All transactions sorted by price + + chainHeadCh chan ChainHeadEvent + chainHeadSub event.Subscription + reqResetCh chan *txpoolResetRequest + reqPromoteCh chan *accountSet + queueTxEventCh chan *types.Transaction + reorgDoneCh chan chan struct{} + reorgShutdownCh chan struct{} // requests shutdown of scheduleReorgLoop + wg sync.WaitGroup // tracks loop, scheduleReorgLoop + + eip2718 bool // Fork indicator whether we are using EIP-2718 type transactions. IsSigner func(address common.Address) bool trc21FeeCapacity map[common.Address]*big.Int } +type txpoolResetRequest struct { + oldHead, newHead *types.Header +} + // NewTxPool creates a new transaction pool to gather, sort and filter inbound // transactions from the network. func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain blockChain) *TxPool { @@ -247,34 +303,46 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block config: config, chainconfig: chainconfig, chain: chain, - signer: types.NewEIP155Signer(chainconfig.ChainId), + signer: types.LatestSigner(chainconfig), pending: make(map[common.Address]*txList), queue: make(map[common.Address]*txList), beats: make(map[common.Address]time.Time), - all: make(map[common.Hash]*types.Transaction), + all: newTxLookup(), chainHeadCh: make(chan ChainHeadEvent, chainHeadChanSize), + reqResetCh: make(chan *txpoolResetRequest), + reqPromoteCh: make(chan *accountSet), + queueTxEventCh: make(chan *types.Transaction), + reorgDoneCh: make(chan chan struct{}), + reorgShutdownCh: make(chan struct{}), gasPrice: new(big.Int).SetUint64(config.PriceLimit), trc21FeeCapacity: map[common.Address]*big.Int{}, } pool.locals = newAccountSet(pool.signer) - pool.priced = newTxPricedList(&pool.all) + for _, addr := range config.Locals { + log.Info("Setting new local account", "address", addr) + pool.locals.add(addr) + } + pool.priced = newTxPricedList(pool.all) pool.reset(nil, chain.CurrentBlock().Header()) + // Start the reorg loop early so it can handle requests generated during journal loading. + pool.wg.Add(1) + go pool.scheduleReorgLoop() + // If local transactions and journaling is enabled, load from disk if !config.NoLocals && config.Journal != "" { pool.journal = newTxJournal(config.Journal) - if err := pool.journal.load(pool.AddLocal); err != nil { + if err := pool.journal.load(pool.AddLocals); err != nil { log.Warn("Failed to load transaction journal", "err", err) } if err := pool.journal.rotate(pool.local()); err != nil { log.Warn("Failed to rotate transaction journal", "err", err) } } - // Subscribe events from blockchain - pool.chainHeadSub = pool.chain.SubscribeChainHeadEvent(pool.chainHeadCh) - // Start the event loop and return + // Subscribe events from blockchain and start the main event loop. + pool.chainHeadSub = pool.chain.SubscribeChainHeadEvent(pool.chainHeadCh) pool.wg.Add(1) go pool.loop() @@ -287,41 +355,34 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block func (pool *TxPool) loop() { defer pool.wg.Done() - // Start the stats reporting and transaction eviction tickers - var prevPending, prevQueued, prevStales int - - report := time.NewTicker(statsReportInterval) + var ( + prevPending, prevQueued, prevStales int + // Start the stats reporting and transaction eviction tickers + report = time.NewTicker(statsReportInterval) + evict = time.NewTicker(evictionInterval) + journal = time.NewTicker(pool.config.Rejournal) + // Track the previous head headers for transaction reorgs + head = pool.chain.CurrentBlock() + ) defer report.Stop() - - evict := time.NewTicker(evictionInterval) defer evict.Stop() - - journal := time.NewTicker(pool.config.Rejournal) defer journal.Stop() - // Track the previous head headers for transaction reorgs - head := pool.chain.CurrentBlock() - - // Keep waiting for and reacting to the various events for { select { // Handle ChainHeadEvent case ev := <-pool.chainHeadCh: if ev.Block != nil { - pool.mu.Lock() - if pool.chainconfig.IsHomestead(ev.Block.Number()) { - pool.homestead = true - } - pool.reset(head.Header(), ev.Block.Header()) + pool.requestReset(head.Header(), ev.Block.Header()) head = ev.Block - - pool.mu.Unlock() } - // Be unsubscribed due to system stopped + + // System shutdown. case <-pool.chainHeadSub.Err(): + close(pool.reorgShutdownCh) return - // Handle stats reporting ticks + // Handle stats reporting ticks case <-report.C: pool.mu.RLock() pending, queued := pool.stats() @@ -333,7 +394,7 @@ func (pool *TxPool) loop() { prevPending, prevQueued, prevStales = pending, queued, stales } - // Handle inactive account transaction eviction + // Handle inactive account transaction eviction case <-evict.C: pool.mu.Lock() for addr := range pool.queue { @@ -343,14 +404,16 @@ func (pool *TxPool) loop() { } // Any non-locals old enough should be removed if time.Since(pool.beats[addr]) > pool.config.Lifetime { - for _, tx := range pool.queue[addr].Flatten() { - pool.removeTx(tx.Hash()) + list := pool.queue[addr].Flatten() + for _, tx := range list { + pool.removeTx(tx.Hash(), true) } + queuedEvictionMeter.Mark(int64(len(list))) } } pool.mu.Unlock() - // Handle local transaction journal rotation + // Handle local transaction journal rotation case <-journal.C: if pool.journal != nil { pool.mu.Lock() @@ -363,99 +426,6 @@ func (pool *TxPool) loop() { } } -// lockedReset is a wrapper around reset to allow calling it in a thread safe -// manner. This method is only ever used in the tester! -func (pool *TxPool) lockedReset(oldHead, newHead *types.Header) { - pool.mu.Lock() - defer pool.mu.Unlock() - - pool.reset(oldHead, newHead) -} - -// reset retrieves the current state of the blockchain and ensures the content -// of the transaction pool is valid with regard to the chain state. -func (pool *TxPool) reset(oldHead, newHead *types.Header) { - // If we're reorging an old state, reinject all dropped transactions - var reinject types.Transactions - - if oldHead != nil && oldHead.Hash() != newHead.ParentHash { - // If the reorg is too deep, avoid doing it (will happen during fast sync) - oldNum := oldHead.Number.Uint64() - newNum := newHead.Number.Uint64() - - if depth := uint64(math.Abs(float64(oldNum) - float64(newNum))); depth > 64 { - log.Debug("Skipping deep transaction reorg", "depth", depth) - } else { - // Reorg seems shallow enough to pull in all transactions into memory - var discarded, included types.Transactions - - var ( - rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number.Uint64()) - add = pool.chain.GetBlock(newHead.Hash(), newHead.Number.Uint64()) - ) - for rem.NumberU64() > add.NumberU64() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) - return - } - } - for add.NumberU64() > rem.NumberU64() { - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return - } - } - for rem.Hash() != add.Hash() { - discarded = append(discarded, rem.Transactions()...) - if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { - log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) - return - } - included = append(included, add.Transactions()...) - if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { - log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) - return - } - } - reinject = types.TxDifference(discarded, included) - } - } - // Initialize the internal state to the current head - if newHead == nil { - newHead = pool.chain.CurrentBlock().Header() // Special case during testing - } - statedb, err := pool.chain.StateAt(newHead.Root) - if err != nil { - log.Error("Failed to reset txpool state", "err", err) - return - } - pool.currentState = statedb - pool.trc21FeeCapacity = state.GetTRC21FeeCapacityFromStateWithCache(newHead.Root, statedb) - pool.pendingState = state.ManageState(statedb) - pool.currentMaxGas = newHead.GasLimit - - // Inject any transactions discarded due to reorgs - log.Debug("Reinjecting stale transactions", "count", len(reinject)) - pool.addTxsLocked(reinject, false) - - // validate the pool of pending transactions, this will remove - // any transactions that have been included in the block or - // have been invalidated because of another transaction (e.g. - // higher gas price) - pool.demoteUnexecutables() - - // Update all accounts to the latest known pending nonce - for addr, list := range pool.pending { - txs := list.Flatten() // Heavy but will be cached and is needed by the miner anyway - pool.pendingState.SetNonce(addr, txs[len(txs)-1].Nonce()+1) - } - // Check the queue and move transactions over to the pending if possible - // or remove those that have become invalid - pool.promoteExecutables(nil) -} - // Stop terminates the transaction pool. func (pool *TxPool) Stop() { // Unsubscribe all subscriptions registered from txpool @@ -471,9 +441,9 @@ func (pool *TxPool) Stop() { log.Info("Transaction pool stopped") } -// SubscribeTxPreEvent registers a subscription of TxPreEvent and +// SubscribeNewTxsEvent registers a subscription of NewTxsEvent and // starts sending event to the given channel. -func (pool *TxPool) SubscribeTxPreEvent(ch chan<- TxPreEvent) event.Subscription { +func (pool *TxPool) SubscribeNewTxsEvent(ch chan<- NewTxsEvent) event.Subscription { return pool.scope.Track(pool.txFeed.Subscribe(ch)) } @@ -492,18 +462,19 @@ func (pool *TxPool) SetGasPrice(price *big.Int) { defer pool.mu.Unlock() pool.gasPrice = price - for _, tx := range pool.priced.Cap(price, pool.locals) { - pool.removeTx(tx.Hash()) + for _, tx := range pool.priced.Cap(price) { + pool.removeTx(tx.Hash(), false) } log.Info("Transaction pool price threshold updated", "price", price) } -// State returns the virtual managed state of the transaction pool. -func (pool *TxPool) State() *state.ManagedState { +// Nonce returns the next nonce of an account, with all transactions executable +// by the pool already applied on top. +func (pool *TxPool) Nonce(addr common.Address) uint64 { pool.mu.RLock() defer pool.mu.RUnlock() - return pool.pendingState + return pool.pendingNonces.get(addr) } // Stats retrieves the current pool stats, namely the number of pending and the @@ -546,7 +517,7 @@ func (pool *TxPool) Content() (map[common.Address]types.Transactions, map[common return pending, queued } -// Pending retrieves all currently processable transactions, groupped by origin +// Pending retrieves all currently processable transactions, grouped by origin // account and sorted by nonce. The returned transaction set is a copy and can be // freely modified by calling code. func (pool *TxPool) Pending() (map[common.Address]types.Transactions, error) { @@ -560,7 +531,15 @@ func (pool *TxPool) Pending() (map[common.Address]types.Transactions, error) { return pending, nil } -// local retrieves all currently known local transactions, groupped by origin +// Locals retrieves the accounts currently considered local by the pool. +func (pool *TxPool) Locals() []common.Address { + pool.mu.Lock() + defer pool.mu.Unlock() + + return pool.locals.flatten() +} + +// local retrieves all currently known local transactions, grouped by origin // account and sorted by nonce. The returned transaction set is a copy and can be // freely modified by calling code. func (pool *TxPool) local() map[common.Address]types.Transactions { @@ -587,6 +566,14 @@ func (pool *TxPool) GetSender(tx *types.Transaction) (common.Address, error) { // validateTx checks whether a transaction is valid according to the consensus // rules and adheres to some heuristic limits of the local node (price and size). func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { + // Accept only legacy transactions until EIP-2718/2930 activates. + if !pool.eip2718 && tx.Type() != types.LegacyTxType { + return ErrTxTypeNotSupported + } + // Reject transactions over defined size to prevent DOS attacks + if uint64(tx.Size()) > txMaxSize { + return ErrOversizedData + } // check if sender is in black list if tx.From() != nil && common.Blacklist[*tx.From()] { return fmt.Errorf("Reject transaction with sender in black-list: %v", tx.From().Hex()) @@ -595,11 +582,6 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if tx.To() != nil && common.Blacklist[*tx.To()] { return fmt.Errorf("Reject transaction with receiver in black-list: %v", tx.To().Hex()) } - - // Heuristic limit, reject transactions over 32KB to prevent DOS attacks - if tx.Size() > 32*1024 { - return ErrOversizedData - } // Transactions can't be negative. This may never happen using RLP decoded // transactions but may occur if you create a transaction using the RPC. if tx.Value().Sign() < 0 { @@ -609,14 +591,13 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if pool.currentMaxGas < tx.Gas() { return ErrGasLimit } - // Make sure the transaction is signed properly + // Make sure the transaction is signed properly. from, err := types.Sender(pool.signer, tx) if err != nil { return ErrInvalidSender } // Drop non-local transactions under our own minimal accepted gas price - local = local || pool.locals.contains(from) // account may be local even if the transaction arrived from the network - if !local && pool.gasPrice.Cmp(tx.GasPrice()) > 0 { + if !local && tx.GasPriceIntCmp(pool.gasPrice) < 0 { if !tx.IsSpecialTransaction() || (pool.IsSigner != nil && !pool.IsSigner(from)) { return ErrUnderpriced } @@ -625,7 +606,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if pool.currentState.GetNonce(from) > tx.Nonce() { return ErrNonceTooLow } - if pool.pendingState.GetNonce(from)+common.LimitThresholdNonceInQueue < tx.Nonce() { + if pool.pendingNonces.get(from)+common.LimitThresholdNonceInQueue < tx.Nonce() { return ErrNonceTooHigh } // Transactor should have enough funds to cover the costs @@ -642,7 +623,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { if tx.To() != nil { if value, ok := pool.trc21FeeCapacity[*tx.To()]; ok { feeCapacity = value - if !state.ValidateTRC21Tx(pool.pendingState.StateDB, from, *tx.To(), tx.Data()) { + if !state.ValidateTRC21Tx(pool.currentState, from, *tx.To(), tx.Data()) { return ErrInsufficientFunds } cost = tx.TxCost(number) @@ -653,7 +634,8 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { } if tx.To() == nil || (tx.To() != nil && !tx.IsSpecialTransaction()) { - intrGas, err := IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead) + // Ensure the transaction has more gas than the basic tx fee. + intrGas, err := IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, true) if err != nil { return err } @@ -663,7 +645,7 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { } // Check zero gas price. - if tx.GasPrice().Cmp(new(big.Int).SetInt64(0)) == 0 { + if tx.GasPrice().Sign() == 0 { return ErrZeroGasPrice } @@ -694,93 +676,110 @@ func (pool *TxPool) validateTx(tx *types.Transaction, local bool) error { return nil } -// add validates a transaction and inserts it into the non-executable queue for -// later pending promotion and execution. If the transaction is a replacement for -// an already pending or queued one, it overwrites the previous and returns this -// so outer code doesn't uselessly call promote. +// add validates a transaction and inserts it into the non-executable queue for later +// pending promotion and execution. If the transaction is a replacement for an already +// pending or queued one, it overwrites the previous transaction if its price is higher. // // If a newly added transaction is marked as local, its sending account will be -// whitelisted, preventing any associated transaction from being dropped out of -// the pool due to pricing constraints. -func (pool *TxPool) add(tx *types.Transaction, local bool) (bool, error) { +// whitelisted, preventing any associated transaction from being dropped out of the pool +// due to pricing constraints. +func (pool *TxPool) add(tx *types.Transaction, local bool) (replaced bool, err error) { // If the transaction is already known, discard it hash := tx.Hash() - if pool.all[hash] != nil { + if pool.all.Get(hash) != nil { log.Trace("Discarding already known transaction", "hash", hash) - return false, fmt.Errorf("known transaction: %x", hash) + knownTxMeter.Mark(1) + return false, ErrAlreadyKnown } + // Make the local flag. If it's from local source or it's from the network but + // the sender is marked as local previously, treat it as the local transaction. + isLocal := local || pool.locals.containsTx(tx) // If the transaction fails basic validation, discard it - if err := pool.validateTx(tx, local); err != nil { + if err := pool.validateTx(tx, isLocal); err != nil { log.Trace("Discarding invalid transaction", "hash", hash, "err", err) - invalidTxCounter.Inc(1) + invalidTxMeter.Mark(1) return false, err } from, _ := types.Sender(pool.signer, tx) // already validated - if tx.IsSpecialTransaction() && pool.IsSigner != nil && pool.IsSigner(from) && pool.pendingState.GetNonce(from) == tx.Nonce() { - return pool.promoteSpecialTx(from, tx) + if tx.IsSpecialTransaction() && pool.IsSigner != nil && pool.IsSigner(from) && pool.pendingNonces.get(from) == tx.Nonce() { + return pool.promoteSpecialTx(from, tx, isLocal) } // If the transaction pool is full, discard underpriced transactions - if uint64(len(pool.all)) >= pool.config.GlobalSlots+pool.config.GlobalQueue { + if uint64(pool.all.Count()) >= pool.config.GlobalSlots+pool.config.GlobalQueue { log.Debug("Add transaction to pool full", "hash", hash, "nonce", tx.Nonce()) // If the new transaction is underpriced, don't accept it - if pool.priced.Underpriced(tx, pool.locals) { + if !isLocal && pool.priced.Underpriced(tx) { log.Trace("Discarding underpriced transaction", "hash", hash, "price", tx.GasPrice()) - underpricedTxCounter.Inc(1) + underpricedTxMeter.Mark(1) return false, ErrUnderpriced } - // New transaction is better than our worse ones, make room for it - drop := pool.priced.Discard(len(pool.all)-int(pool.config.GlobalSlots+pool.config.GlobalQueue-1), pool.locals) + // New transaction is better than our worse ones, make room for it. + // If it's a local transaction, forcibly discard all available transactions. + // Otherwise if we can't make enough room for new one, abort the operation. + drop, success := pool.priced.Discard(pool.all.Slots()-int(pool.config.GlobalSlots+pool.config.GlobalQueue)+numSlots(tx), isLocal) + + // Special case, we still can't make the room for the new remote one. + if !isLocal && !success { + log.Trace("Discarding overflown transaction", "hash", hash) + overflowedTxMeter.Mark(1) + return false, ErrTxPoolOverflow + } + // Kick out the underpriced remote transactions. for _, tx := range drop { log.Trace("Discarding freshly underpriced transaction", "hash", tx.Hash(), "price", tx.GasPrice()) - underpricedTxCounter.Inc(1) - pool.removeTx(tx.Hash()) + underpricedTxMeter.Mark(1) + pool.removeTx(tx.Hash(), false) } } - // If the transaction is replacing an already pending one, do directly + // Try to replace an existing transaction in the pending pool if list := pool.pending[from]; list != nil && list.Overlaps(tx) { // Nonce already pending, check if required price bump is met inserted, old := list.Add(tx, pool.config.PriceBump) if !inserted { - pendingDiscardCounter.Inc(1) + pendingDiscardMeter.Mark(1) return false, ErrReplaceUnderpriced } // New transaction is better, replace old one if old != nil { - delete(pool.all, old.Hash()) - pool.priced.Removed() - pendingReplaceCounter.Inc(1) + pool.all.Remove(old.Hash()) + pool.priced.Removed(1) + pendingReplaceMeter.Mark(1) } - pool.all[tx.Hash()] = tx - pool.priced.Put(tx) + pool.all.Add(tx, isLocal) + pool.priced.Put(tx, isLocal) pool.journalTx(from, tx) - + pool.queueTxEvent(tx) log.Trace("Pooled new executable transaction", "hash", hash, "from", from, "to", tx.To()) - // We've directly injected a replacement transaction, notify subsystems - go pool.txFeed.Send(TxPreEvent{tx}) - + // Successful promotion, bump the heartbeat + pool.beats[from] = time.Now() return old != nil, nil } // New transaction isn't replacing a pending one, push into queue - replace, err := pool.enqueueTx(hash, tx) + replaced, err = pool.enqueueTx(hash, tx, isLocal, true) if err != nil { return false, err } // Mark local addresses and journal local transactions - if local { + if local && !pool.locals.contains(from) { + log.Info("Setting new local account", "address", from) pool.locals.add(from) + pool.priced.Removed(pool.all.RemoteToLocals(pool.locals)) // Migrate the remotes if it's marked as local first time. + } + if isLocal { + localGauge.Inc(1) } pool.journalTx(from, tx) log.Trace("Pooled new future transaction", "hash", hash, "from", from, "to", tx.To()) - return replace, nil + return replaced, nil } // enqueueTx inserts a new transaction into the non-executable transaction queue. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, error) { +func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction, local bool, addAll bool) (bool, error) { // Try to insert the transaction into the future queue from, _ := types.Sender(pool.signer, tx) // already validated if pool.queue[from] == nil { @@ -789,17 +788,31 @@ func (pool *TxPool) enqueueTx(hash common.Hash, tx *types.Transaction) (bool, er inserted, old := pool.queue[from].Add(tx, pool.config.PriceBump) if !inserted { // An older transaction was better, discard this - queuedDiscardCounter.Inc(1) + queuedDiscardMeter.Mark(1) return false, ErrReplaceUnderpriced } // Discard any previous transaction and mark this if old != nil { - delete(pool.all, old.Hash()) - pool.priced.Removed() - queuedReplaceCounter.Inc(1) + pool.all.Remove(old.Hash()) + pool.priced.Removed(1) + queuedReplaceMeter.Mark(1) + } else { + // Nothing was replaced, bump the queued counter + queuedGauge.Inc(1) + } + // If the transaction isn't in lookup set but it's expected to be there, + // show the error log. + if pool.all.Get(hash) == nil && !addAll { + log.Error("Missing transaction in lookup set, please report the issue", "hash", hash) + } + if addAll { + pool.all.Add(tx, local) + pool.priced.Put(tx, local) + } + // If we never record the heartbeat, do it right now. + if _, exist := pool.beats[from]; !exist { + pool.beats[from] = time.Now() } - pool.all[hash] = tx - pool.priced.Put(tx) return old != nil, nil } @@ -815,10 +828,11 @@ func (pool *TxPool) journalTx(from common.Address, tx *types.Transaction) { } } -// promoteTx adds a transaction to the pending (processable) list of transactions. +// promoteTx adds a transaction to the pending (processable) list of transactions +// and returns whether it was inserted or an older was better. // // Note, this method assumes the pool lock is held! -func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.Transaction) { +func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.Transaction) bool { // Try to insert the transaction into the pending queue if pool.pending[addr] == nil { pool.pending[addr] = newTxList(true) @@ -828,46 +842,48 @@ func (pool *TxPool) promoteTx(addr common.Address, hash common.Hash, tx *types.T inserted, old := list.Add(tx, pool.config.PriceBump) if !inserted { // An older transaction was better, discard this - delete(pool.all, hash) - pool.priced.Removed() + pool.all.Remove(hash) + pool.priced.Removed(1) - pendingDiscardCounter.Inc(1) - return + pendingDiscardMeter.Mark(1) + return false } // Otherwise discard any previous transaction and mark this if old != nil { - delete(pool.all, old.Hash()) - pool.priced.Removed() - - pendingReplaceCounter.Inc(1) - } - // Failsafe to work around direct pending inserts (tests) - if pool.all[hash] == nil { - pool.all[hash] = tx - pool.priced.Put(tx) + pool.all.Remove(old.Hash()) + pool.priced.Removed(1) + pendingReplaceMeter.Mark(1) + } else { + // Nothing was replaced, bump the pending counter + pendingGauge.Inc(1) } // Set the potentially new pending nonce and notify any subsystems of the new tx - pool.beats[addr] = time.Now() - pool.pendingState.SetNonce(addr, tx.Nonce()+1) + pool.pendingNonces.set(addr, tx.Nonce()+1) - go pool.txFeed.Send(TxPreEvent{tx}) + // Successful promotion, bump the heartbeat + pool.beats[addr] = time.Now() + return true } -func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction) (bool, error) { +func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction, isLocal bool) (bool, error) { // Try to insert the transaction into the pending queue if pool.pending[addr] == nil { pool.pending[addr] = newTxList(true) } list := pool.pending[addr] + old := list.txs.Get(tx.Nonce()) if old != nil && old.IsSpecialTransaction() { return false, ErrDuplicateSpecialTransaction } // Otherwise discard any previous transaction and mark this if old != nil { - delete(pool.all, old.Hash()) - pool.priced.Removed() - pendingReplaceCounter.Inc(1) + pool.all.Remove(old.Hash()) + pool.priced.Removed(1) + pendingReplaceMeter.Mark(1) + } else { + // Nothing was replaced, bump the pending counter + pendingGauge.Inc(1) } list.txs.Put(tx) if cost := tx.Cost(); list.costcap.Cmp(cost) < 0 { @@ -877,186 +893,482 @@ func (pool *TxPool) promoteSpecialTx(addr common.Address, tx *types.Transaction) list.gascap = gas } // Failsafe to work around direct pending inserts (tests) - if pool.all[tx.Hash()] == nil { - pool.all[tx.Hash()] = tx + if pool.all.Get(tx.Hash()) == nil { + pool.all.Add(tx, isLocal) } // Set the potentially new pending nonce and notify any subsystems of the new tx pool.beats[addr] = time.Now() - pool.pendingState.SetNonce(addr, tx.Nonce()+1) - go pool.txFeed.Send(TxPreEvent{tx}) + pool.pendingNonces.set(addr, tx.Nonce()+1) + go pool.txFeed.Send(NewTxsEvent{types.Transactions{tx}}) return true, nil } -// AddLocal enqueues a single transaction into the pool if it is valid, marking -// the sender as a local one in the mean time, ensuring it goes around the local -// pricing constraints. +// AddLocals enqueues a batch of transactions into the pool if they are valid, marking the +// senders as a local ones, ensuring they go around the local pricing constraints. +// +// This method is used to add transactions from the RPC API and performs synchronous pool +// reorganization and event propagation. +func (pool *TxPool) AddLocals(txs []*types.Transaction) []error { + return pool.addTxs(txs, !pool.config.NoLocals, true) +} + +// AddLocal enqueues a single local transaction into the pool if it is valid. This is +// a convenience wrapper aroundd AddLocals. func (pool *TxPool) AddLocal(tx *types.Transaction) error { - return pool.addTx(tx, !pool.config.NoLocals) + errs := pool.AddLocals([]*types.Transaction{tx}) + return errs[0] } -// AddRemote enqueues a single transaction into the pool if it is valid. If the -// sender is not among the locally tracked ones, full pricing constraints will -// apply. -func (pool *TxPool) AddRemote(tx *types.Transaction) error { - return pool.addTx(tx, false) +// AddRemotes enqueues a batch of transactions into the pool if they are valid. If the +// senders are not among the locally tracked ones, full pricing constraints will apply. +// +// This method is used to add transactions from the p2p network and does not wait for pool +// reorganization and internal event propagation. +func (pool *TxPool) AddRemotes(txs []*types.Transaction) []error { + return pool.addTxs(txs, false, false) } -// AddLocals enqueues a batch of transactions into the pool if they are valid, -// marking the senders as a local ones in the mean time, ensuring they go around -// the local pricing constraints. -func (pool *TxPool) AddLocals(txs []*types.Transaction) []error { - return pool.addTxs(txs, !pool.config.NoLocals) +// This is like AddRemotes, but waits for pool reorganization. Tests use this method. +func (pool *TxPool) AddRemotesSync(txs []*types.Transaction) []error { + return pool.addTxs(txs, false, true) } -// AddRemotes enqueues a batch of transactions into the pool if they are valid. -// If the senders are not among the locally tracked ones, full pricing constraints -// will apply. -func (pool *TxPool) AddRemotes(txs []*types.Transaction) []error { - return pool.addTxs(txs, false) +// This is like AddRemotes with a single transaction, but waits for pool reorganization. Tests use this method. +func (pool *TxPool) addRemoteSync(tx *types.Transaction) error { + errs := pool.AddRemotesSync([]*types.Transaction{tx}) + return errs[0] } -// addTx enqueues a single transaction into the pool if it is valid. -func (pool *TxPool) addTx(tx *types.Transaction, local bool) error { - tx.CacheHash() - types.CacheSigner(pool.signer, tx) - pool.mu.Lock() - defer pool.mu.Unlock() +// AddRemote enqueues a single transaction into the pool if it is valid. This is a convenience +// wrapper around AddRemotes. +// +// Deprecated: use AddRemotes +func (pool *TxPool) AddRemote(tx *types.Transaction) error { + errs := pool.AddRemotes([]*types.Transaction{tx}) + return errs[0] +} - // Try to inject the transaction and update any state - replace, err := pool.add(tx, local) - if err != nil { - return err +// addTxs attempts to queue a batch of transactions if they are valid. +func (pool *TxPool) addTxs(txs []*types.Transaction, local, sync bool) []error { + // Filter out known ones without obtaining the pool lock or recovering signatures + var ( + errs = make([]error, len(txs)) + news = make([]*types.Transaction, 0, len(txs)) + ) + for i, tx := range txs { + // If the transaction is known, pre-set the error slot + if pool.all.Get(tx.Hash()) != nil { + errs[i] = ErrAlreadyKnown + knownTxMeter.Mark(1) + continue + } + // Exclude transactions with invalid signatures as soon as + // possible and cache senders in transactions before + // obtaining lock + _, err := types.Sender(pool.signer, tx) + if err != nil { + errs[i] = ErrInvalidSender + invalidTxMeter.Mark(1) + continue + } + // Accumulate all unknown transactions for deeper processing + news = append(news, tx) } - // If we added a new transaction, run promotion checks and return - if !replace { - from, _ := types.Sender(pool.signer, tx) // already validated - pool.promoteExecutables([]common.Address{from}) + if len(news) == 0 { + return errs } - return nil -} -// addTxs attempts to queue a batch of transactions if they are valid. -func (pool *TxPool) addTxs(txs []*types.Transaction, local bool) []error { + // Process all the new transaction and merge any errors into the original slice pool.mu.Lock() - defer pool.mu.Unlock() + newErrs, dirtyAddrs := pool.addTxsLocked(news, local) + pool.mu.Unlock() - return pool.addTxsLocked(txs, local) + var nilSlot = 0 + for _, err := range newErrs { + for errs[nilSlot] != nil { + nilSlot++ + } + errs[nilSlot] = err + nilSlot++ + } + // Reorg the pool internals if needed and return + done := pool.requestPromoteExecutables(dirtyAddrs) + if sync { + <-done + } + return errs } -// addTxsLocked attempts to queue a batch of transactions if they are valid, -// whilst assuming the transaction pool lock is already held. -func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) []error { - // Add the batch of transaction, tracking the accepted ones - dirty := make(map[common.Address]struct{}) +// addTxsLocked attempts to queue a batch of transactions if they are valid. +// The transaction pool lock must be held. +func (pool *TxPool) addTxsLocked(txs []*types.Transaction, local bool) ([]error, *accountSet) { + dirty := newAccountSet(pool.signer) errs := make([]error, len(txs)) - for i, tx := range txs { - var replace bool - if replace, errs[i] = pool.add(tx, local); errs[i] == nil { - if !replace { - from, _ := types.Sender(pool.signer, tx) // already validated - dirty[from] = struct{}{} - } - } - } - // Only reprocess the internal state if something was actually added - if len(dirty) > 0 { - addrs := make([]common.Address, 0, len(dirty)) - for addr := range dirty { - addrs = append(addrs, addr) + replaced, err := pool.add(tx, local) + errs[i] = err + if err == nil && !replaced { + dirty.addTx(tx) } - pool.promoteExecutables(addrs) } - return errs + validTxMeter.Mark(int64(len(dirty.accounts))) + return errs, dirty } // Status returns the status (unknown/pending/queued) of a batch of transactions // identified by their hashes. func (pool *TxPool) Status(hashes []common.Hash) []TxStatus { - pool.mu.RLock() - defer pool.mu.RUnlock() - status := make([]TxStatus, len(hashes)) for i, hash := range hashes { - if tx := pool.all[hash]; tx != nil { - from, _ := types.Sender(pool.signer, tx) // already validated - if pool.pending[from] != nil && pool.pending[from].txs.items[tx.Nonce()] != nil { - status[i] = TxStatusPending - } else { - status[i] = TxStatusQueued - } + tx := pool.Get(hash) + if tx == nil { + continue } + from, _ := types.Sender(pool.signer, tx) // already validated + pool.mu.RLock() + if txList := pool.pending[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusPending + } else if txList := pool.queue[from]; txList != nil && txList.txs.items[tx.Nonce()] != nil { + status[i] = TxStatusQueued + } + // implicit else: the tx may have been included into a block between + // checking pool.Get and obtaining the lock. In that case, TxStatusUnknown is correct + pool.mu.RUnlock() } return status } -// Get returns a transaction if it is contained in the pool -// and nil otherwise. +// Get returns a transaction if it is contained in the pool and nil otherwise. func (pool *TxPool) Get(hash common.Hash) *types.Transaction { - pool.mu.RLock() - defer pool.mu.RUnlock() + return pool.all.Get(hash) +} - return pool.all[hash] +// Has returns an indicator whether txpool has a transaction cached with the +// given hash. +func (pool *TxPool) Has(hash common.Hash) bool { + return pool.all.Get(hash) != nil } // removeTx removes a single transaction from the queue, moving all subsequent // transactions back to the future queue. -func (pool *TxPool) removeTx(hash common.Hash) { +func (pool *TxPool) removeTx(hash common.Hash, outofbound bool) { // Fetch the transaction we wish to delete - tx, ok := pool.all[hash] - if !ok { + tx := pool.all.Get(hash) + if tx == nil { return } addr, _ := types.Sender(pool.signer, tx) // already validated during insertion // Remove it from the list of known transactions - delete(pool.all, hash) - pool.priced.Removed() - + pool.all.Remove(hash) + if outofbound { + pool.priced.Removed(1) + } + if pool.locals.contains(addr) { + localGauge.Dec(1) + } // Remove the transaction from the pending lists and reset the account nonce if pending := pool.pending[addr]; pending != nil { if removed, invalids := pending.Remove(tx); removed { // If no more pending transactions are left, remove the list if pending.Empty() { delete(pool.pending, addr) - delete(pool.beats, addr) } // Postpone any invalidated transactions for _, tx := range invalids { - pool.enqueueTx(tx.Hash(), tx) + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(tx.Hash(), tx, false, false) } // Update the account nonce if needed - if nonce := tx.Nonce(); pool.pendingState.GetNonce(addr) > nonce { - pool.pendingState.SetNonce(addr, nonce) - } + pool.pendingNonces.setIfLower(addr, tx.Nonce()) + // Reduce the pending counter + pendingGauge.Dec(int64(1 + len(invalids))) return } } // Transaction is in the future queue if future := pool.queue[addr]; future != nil { - future.Remove(tx) + if removed, _ := future.Remove(tx); removed { + // Reduce the queued counter + queuedGauge.Dec(1) + } if future.Empty() { delete(pool.queue, addr) + delete(pool.beats, addr) + } + } +} + +// requestReset requests a pool reset to the new head block. +// The returned channel is closed when the reset has occurred. +func (pool *TxPool) requestReset(oldHead *types.Header, newHead *types.Header) chan struct{} { + select { + case pool.reqResetCh <- &txpoolResetRequest{oldHead, newHead}: + return <-pool.reorgDoneCh + case <-pool.reorgShutdownCh: + return pool.reorgShutdownCh + } +} + +// requestPromoteExecutables requests transaction promotion checks for the given addresses. +// The returned channel is closed when the promotion checks have occurred. +func (pool *TxPool) requestPromoteExecutables(set *accountSet) chan struct{} { + select { + case pool.reqPromoteCh <- set: + return <-pool.reorgDoneCh + case <-pool.reorgShutdownCh: + return pool.reorgShutdownCh + } +} + +// queueTxEvent enqueues a transaction event to be sent in the next reorg run. +func (pool *TxPool) queueTxEvent(tx *types.Transaction) { + select { + case pool.queueTxEventCh <- tx: + case <-pool.reorgShutdownCh: + } +} + +// scheduleReorgLoop schedules runs of reset and promoteExecutables. Code above should not +// call those methods directly, but request them being run using requestReset and +// requestPromoteExecutables instead. +func (pool *TxPool) scheduleReorgLoop() { + defer pool.wg.Done() + + var ( + curDone chan struct{} // non-nil while runReorg is active + nextDone = make(chan struct{}) + launchNextRun bool + reset *txpoolResetRequest + dirtyAccounts *accountSet + queuedEvents = make(map[common.Address]*txSortedMap) + ) + for { + // Launch next background reorg if needed + if curDone == nil && launchNextRun { + // Run the background reorg and announcements + go pool.runReorg(nextDone, reset, dirtyAccounts, queuedEvents) + + // Prepare everything for the next round of reorg + curDone, nextDone = nextDone, make(chan struct{}) + launchNextRun = false + + reset, dirtyAccounts = nil, nil + queuedEvents = make(map[common.Address]*txSortedMap) + } + + select { + case req := <-pool.reqResetCh: + // Reset request: update head if request is already pending. + if reset == nil { + reset = req + } else { + reset.newHead = req.newHead + } + launchNextRun = true + pool.reorgDoneCh <- nextDone + + case req := <-pool.reqPromoteCh: + // Promote request: update address set if request is already pending. + if dirtyAccounts == nil { + dirtyAccounts = req + } else { + dirtyAccounts.merge(req) + } + launchNextRun = true + pool.reorgDoneCh <- nextDone + + case tx := <-pool.queueTxEventCh: + // Queue up the event, but don't schedule a reorg. It's up to the caller to + // request one later if they want the events sent. + addr, _ := types.Sender(pool.signer, tx) + if _, ok := queuedEvents[addr]; !ok { + queuedEvents[addr] = newTxSortedMap() + } + queuedEvents[addr].Put(tx) + + case <-curDone: + curDone = nil + + case <-pool.reorgShutdownCh: + // Wait for current run to finish. + if curDone != nil { + <-curDone + } + close(nextDone) + return } } } +// runReorg runs reset and promoteExecutables on behalf of scheduleReorgLoop. +func (pool *TxPool) runReorg(done chan struct{}, reset *txpoolResetRequest, dirtyAccounts *accountSet, events map[common.Address]*txSortedMap) { + defer close(done) + + var promoteAddrs []common.Address + if dirtyAccounts != nil && reset == nil { + // Only dirty accounts need to be promoted, unless we're resetting. + // For resets, all addresses in the tx queue will be promoted and + // the flatten operation can be avoided. + promoteAddrs = dirtyAccounts.flatten() + } + pool.mu.Lock() + if reset != nil { + // Reset from the old head to the new, rescheduling any reorged transactions + pool.reset(reset.oldHead, reset.newHead) + + // Nonces were reset, discard any events that became stale + for addr := range events { + events[addr].Forward(pool.pendingNonces.get(addr)) + if events[addr].Len() == 0 { + delete(events, addr) + } + } + // Reset needs promote for all addresses + promoteAddrs = make([]common.Address, 0, len(pool.queue)) + for addr := range pool.queue { + promoteAddrs = append(promoteAddrs, addr) + } + } + // Check for pending transactions for every account that sent new ones + promoted := pool.promoteExecutables(promoteAddrs) + + // If a new block appeared, validate the pool of pending transactions. This will + // remove any transaction that has been included in the block or was invalidated + // because of another transaction (e.g. higher gas price). + if reset != nil { + pool.demoteUnexecutables() + } + // Ensure pool.queue and pool.pending sizes stay within the configured limits. + pool.truncatePending() + pool.truncateQueue() + + // Update all accounts to the latest known pending nonce + for addr, list := range pool.pending { + highestPending := list.LastElement() + pool.pendingNonces.set(addr, highestPending.Nonce()+1) + } + pool.mu.Unlock() + + // Notify subsystems for newly added transactions + for _, tx := range promoted { + addr, _ := types.Sender(pool.signer, tx) + if _, ok := events[addr]; !ok { + events[addr] = newTxSortedMap() + } + events[addr].Put(tx) + } + if len(events) > 0 { + var txs []*types.Transaction + for _, set := range events { + txs = append(txs, set.Flatten()...) + } + pool.txFeed.Send(NewTxsEvent{txs}) + } +} + +// reset retrieves the current state of the blockchain and ensures the content +// of the transaction pool is valid with regard to the chain state. +func (pool *TxPool) reset(oldHead, newHead *types.Header) { + // If we're reorging an old state, reinject all dropped transactions + var reinject types.Transactions + + if oldHead != nil && oldHead.Hash() != newHead.ParentHash { + // If the reorg is too deep, avoid doing it (will happen during fast sync) + oldNum := oldHead.Number.Uint64() + newNum := newHead.Number.Uint64() + + if depth := uint64(math.Abs(float64(oldNum) - float64(newNum))); depth > 64 { + log.Debug("Skipping deep transaction reorg", "depth", depth) + } else { + // Reorg seems shallow enough to pull in all transactions into memory + var discarded, included types.Transactions + var ( + rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number.Uint64()) + add = pool.chain.GetBlock(newHead.Hash(), newHead.Number.Uint64()) + ) + if rem == nil { + // This can happen if a setHead is performed, where we simply discard the old + // head from the chain. + // If that is the case, we don't have the lost transactions any more, and + // there's nothing to add + if newNum >= oldNum { + // If we reorged to a same or higher number, then it's not a case of setHead + log.Warn("Transaction pool reset with missing oldhead", + "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) + return + } + // If the reorg ended up on a lower number, it's indicative of setHead being the cause + log.Debug("Skipping transaction reset caused by setHead", + "old", oldHead.Hash(), "oldnum", oldNum, "new", newHead.Hash(), "newnum", newNum) + // We still need to update the current state s.th. the lost transactions can be readded by the user + } else { + for rem.NumberU64() > add.NumberU64() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } + } + for add.NumberU64() > rem.NumberU64() { + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } + } + for rem.Hash() != add.Hash() { + discarded = append(discarded, rem.Transactions()...) + if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil { + log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash()) + return + } + included = append(included, add.Transactions()...) + if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil { + log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash()) + return + } + } + reinject = types.TxDifference(discarded, included) + } + } + } + // Initialize the internal state to the current head + if newHead == nil { + newHead = pool.chain.CurrentBlock().Header() // Special case during testing + } + statedb, err := pool.chain.StateAt(newHead.Root) + if err != nil { + log.Error("Failed to reset txpool state", "err", err) + return + } + pool.currentState = statedb + pool.trc21FeeCapacity = state.GetTRC21FeeCapacityFromStateWithCache(newHead.Root, statedb) + pool.pendingNonces = newTxNoncer(statedb) + pool.currentMaxGas = newHead.GasLimit + + // Inject any transactions discarded due to reorgs + log.Debug("Reinjecting stale transactions", "count", len(reinject)) + senderCacher.recover(pool.signer, reinject) + pool.addTxsLocked(reinject, false) + + // Update all fork indicator by next pending block number. + next := new(big.Int).Add(newHead.Number, big.NewInt(1)) + pool.eip2718 = pool.chainconfig.IsEIP1559(next) +} + // promoteExecutables moves transactions that have become processable from the // future queue to the set of pending transactions. During this process, all // invalidated transactions (low nonce, low balance) are deleted. -func (pool *TxPool) promoteExecutables(accounts []common.Address) { +func (pool *TxPool) promoteExecutables(accounts []common.Address) []*types.Transaction { log.Debug("start promoteExecutables") defer func(start time.Time) { log.Debug("end promoteExecutables", "time", common.PrettyDuration(time.Since(start))) }(time.Now()) - // Gather all the accounts potentially needing updates - if accounts == nil { - accounts = make([]common.Address, 0, len(pool.queue)) - for addr := range pool.queue { - accounts = append(accounts, addr) - } - } + // Track the promoted transactions to broadcast them at once + var promoted []*types.Transaction + // Iterate over all accounts and promote any executable transactions for _, addr := range accounts { list := pool.queue[addr] @@ -1064,12 +1376,12 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { continue // Just in case someone calls with a non existing account } // Drop all transactions that are deemed too old (low nonce) - for _, tx := range list.Forward(pool.currentState.GetNonce(addr)) { + forwards := list.Forward(pool.currentState.GetNonce(addr)) + for _, tx := range forwards { hash := tx.Hash() - log.Trace("Removed old queued transaction", "hash", hash) - delete(pool.all, hash) - pool.priced.Removed() + pool.all.Remove(hash) } + log.Trace("Removed old queued transactions", "count", len(forwards)) // Drop all transactions that are too costly (low balance or out of gas) var number *big.Int = nil if pool.chain.CurrentHeader() != nil { @@ -1078,141 +1390,176 @@ func (pool *TxPool) promoteExecutables(accounts []common.Address) { drops, _ := list.Filter(pool.currentState.GetBalance(addr), pool.currentMaxGas, pool.trc21FeeCapacity, number) for _, tx := range drops { hash := tx.Hash() - log.Trace("Removed unpayable queued transaction", "hash", hash) - delete(pool.all, hash) - pool.priced.Removed() - queuedNofundsCounter.Inc(1) + pool.all.Remove(hash) } + log.Trace("Removed unpayable queued transactions", "count", len(drops)) + queuedNofundsMeter.Mark(int64(len(drops))) + // Gather all executable transactions and promote them - for _, tx := range list.Ready(pool.pendingState.GetNonce(addr)) { + readies := list.Ready(pool.pendingNonces.get(addr)) + for _, tx := range readies { hash := tx.Hash() - log.Trace("Promoting queued transaction", "hash", hash) - pool.promoteTx(addr, hash, tx) + if pool.promoteTx(addr, hash, tx) { + promoted = append(promoted, tx) + } } + log.Trace("Promoted queued transactions", "count", len(promoted)) + queuedGauge.Dec(int64(len(readies))) + // Drop all transactions over the allowed limit + var caps types.Transactions if !pool.locals.contains(addr) { - for _, tx := range list.Cap(int(pool.config.AccountQueue)) { + caps = list.Cap(int(pool.config.AccountQueue)) + for _, tx := range caps { hash := tx.Hash() - delete(pool.all, hash) - pool.priced.Removed() - queuedRateLimitCounter.Inc(1) + pool.all.Remove(hash) log.Trace("Removed cap-exceeding queued transaction", "hash", hash) } + queuedRateLimitMeter.Mark(int64(len(caps))) + } + // Mark all the items dropped as removed + pool.priced.Removed(len(forwards) + len(drops) + len(caps)) + queuedGauge.Dec(int64(len(forwards) + len(drops) + len(caps))) + if pool.locals.contains(addr) { + localGauge.Dec(int64(len(forwards) + len(drops) + len(caps))) } // Delete the entire queue entry if it became empty. if list.Empty() { delete(pool.queue, addr) + delete(pool.beats, addr) } } - // If the pending limit is overflown, start equalizing allowances + return promoted +} + +// truncatePending removes transactions from the pending queue if the pool is above the +// pending limit. The algorithm tries to reduce transaction counts by an approximately +// equal number for all for accounts with many pending transactions. +func (pool *TxPool) truncatePending() { pending := uint64(0) for _, list := range pool.pending { pending += uint64(list.Len()) } - if pending > pool.config.GlobalSlots { - pendingBeforeCap := pending - // Assemble a spam order to penalize large transactors first - spammers := prque.New() - for addr, list := range pool.pending { - // Only evict transactions from high rollers - if !pool.locals.contains(addr) && uint64(list.Len()) > pool.config.AccountSlots { - spammers.Push(addr, float32(list.Len())) - } - } - // Gradually drop transactions from offenders - offenders := []common.Address{} - for pending > pool.config.GlobalSlots && !spammers.Empty() { - // Retrieve the next offender if not local address - offender, _ := spammers.Pop() - offenders = append(offenders, offender.(common.Address)) - - // Equalize balances until all the same or below threshold - if len(offenders) > 1 { - // Calculate the equalization threshold for all current offenders - threshold := pool.pending[offender.(common.Address)].Len() - - // Iteratively reduce all offenders until below limit or threshold reached - for pending > pool.config.GlobalSlots && pool.pending[offenders[len(offenders)-2]].Len() > threshold { - for i := 0; i < len(offenders)-1; i++ { - list := pool.pending[offenders[i]] - for _, tx := range list.Cap(list.Len() - 1) { - // Drop the transaction from the global pools too - hash := tx.Hash() - delete(pool.all, hash) - pool.priced.Removed() - - // Update the account nonce to the dropped transaction - if nonce := tx.Nonce(); pool.pendingState.GetNonce(offenders[i]) > nonce { - pool.pendingState.SetNonce(offenders[i], nonce) - } - log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) - } - pending-- - } - } - } + if pending <= pool.config.GlobalSlots { + return + } + + pendingBeforeCap := pending + // Assemble a spam order to penalize large transactors first + spammers := prque.New(nil) + for addr, list := range pool.pending { + // Only evict transactions from high rollers + if !pool.locals.contains(addr) && uint64(list.Len()) > pool.config.AccountSlots { + spammers.Push(addr, int64(list.Len())) } - // If still above threshold, reduce to limit or min allowance - if pending > pool.config.GlobalSlots && len(offenders) > 0 { - for pending > pool.config.GlobalSlots && uint64(pool.pending[offenders[len(offenders)-1]].Len()) > pool.config.AccountSlots { - for _, addr := range offenders { - list := pool.pending[addr] - for _, tx := range list.Cap(list.Len() - 1) { + } + // Gradually drop transactions from offenders + offenders := []common.Address{} + for pending > pool.config.GlobalSlots && !spammers.Empty() { + // Retrieve the next offender if not local address + offender, _ := spammers.Pop() + offenders = append(offenders, offender.(common.Address)) + + // Equalize balances until all the same or below threshold + if len(offenders) > 1 { + // Calculate the equalization threshold for all current offenders + threshold := pool.pending[offender.(common.Address)].Len() + + // Iteratively reduce all offenders until below limit or threshold reached + for pending > pool.config.GlobalSlots && pool.pending[offenders[len(offenders)-2]].Len() > threshold { + for i := 0; i < len(offenders)-1; i++ { + list := pool.pending[offenders[i]] + + caps := list.Cap(list.Len() - 1) + for _, tx := range caps { // Drop the transaction from the global pools too hash := tx.Hash() - delete(pool.all, hash) - pool.priced.Removed() + pool.all.Remove(hash) // Update the account nonce to the dropped transaction - if nonce := tx.Nonce(); pool.pendingState.GetNonce(addr) > nonce { - pool.pendingState.SetNonce(addr, nonce) - } + pool.pendingNonces.setIfLower(offenders[i], tx.Nonce()) log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) } + pool.priced.Removed(len(caps)) + pendingGauge.Dec(int64(len(caps))) + if pool.locals.contains(offenders[i]) { + localGauge.Dec(int64(len(caps))) + } pending-- } } } - pendingRateLimitCounter.Inc(int64(pendingBeforeCap - pending)) } - // If we've queued more transactions than the hard limit, drop oldest ones + + // If still above threshold, reduce to limit or min allowance + if pending > pool.config.GlobalSlots && len(offenders) > 0 { + for pending > pool.config.GlobalSlots && uint64(pool.pending[offenders[len(offenders)-1]].Len()) > pool.config.AccountSlots { + for _, addr := range offenders { + list := pool.pending[addr] + + caps := list.Cap(list.Len() - 1) + for _, tx := range caps { + // Drop the transaction from the global pools too + hash := tx.Hash() + pool.all.Remove(hash) + + // Update the account nonce to the dropped transaction + pool.pendingNonces.setIfLower(addr, tx.Nonce()) + log.Trace("Removed fairness-exceeding pending transaction", "hash", hash) + } + pool.priced.Removed(len(caps)) + pendingGauge.Dec(int64(len(caps))) + if pool.locals.contains(addr) { + localGauge.Dec(int64(len(caps))) + } + pending-- + } + } + } + pendingRateLimitMeter.Mark(int64(pendingBeforeCap - pending)) +} + +// truncateQueue drops the oldes transactions in the queue if the pool is above the global queue limit. +func (pool *TxPool) truncateQueue() { queued := uint64(0) for _, list := range pool.queue { queued += uint64(list.Len()) } - if queued > pool.config.GlobalQueue { - // Sort all accounts with queued transactions by heartbeat - addresses := make(addresssByHeartbeat, 0, len(pool.queue)) - for addr := range pool.queue { - if !pool.locals.contains(addr) { // don't drop locals - addresses = append(addresses, addressByHeartbeat{addr, pool.beats[addr]}) - } + if queued <= pool.config.GlobalQueue { + return + } + + // Sort all accounts with queued transactions by heartbeat + addresses := make(addressesByHeartbeat, 0, len(pool.queue)) + for addr := range pool.queue { + if !pool.locals.contains(addr) { // don't drop locals + addresses = append(addresses, addressByHeartbeat{addr, pool.beats[addr]}) } - sort.Sort(addresses) + } + sort.Sort(addresses) - // Drop transactions until the total is below the limit or only locals remain - for drop := queued - pool.config.GlobalQueue; drop > 0 && len(addresses) > 0; { - addr := addresses[len(addresses)-1] - list := pool.queue[addr.address] + // Drop transactions until the total is below the limit or only locals remain + for drop := queued - pool.config.GlobalQueue; drop > 0 && len(addresses) > 0; { + addr := addresses[len(addresses)-1] + list := pool.queue[addr.address] - addresses = addresses[:len(addresses)-1] + addresses = addresses[:len(addresses)-1] - // Drop all transactions if they are less than the overflow - if size := uint64(list.Len()); size <= drop { - for _, tx := range list.Flatten() { - pool.removeTx(tx.Hash()) - } - drop -= size - queuedRateLimitCounter.Inc(int64(size)) - continue - } - // Otherwise drop only last few transactions - txs := list.Flatten() - for i := len(txs) - 1; i >= 0 && drop > 0; i-- { - pool.removeTx(txs[i].Hash()) - drop-- - queuedRateLimitCounter.Inc(1) + // Drop all transactions if they are less than the overflow + if size := uint64(list.Len()); size <= drop { + for _, tx := range list.Flatten() { + pool.removeTx(tx.Hash(), true) } + drop -= size + queuedRateLimitMeter.Mark(int64(size)) + continue + } + // Otherwise drop only last few transactions + txs := list.Flatten() + for i := len(txs) - 1; i >= 0 && drop > 0; i-- { + pool.removeTx(txs[i].Hash(), true) + drop-- + queuedRateLimitMeter.Mark(1) } } } @@ -1226,11 +1573,11 @@ func (pool *TxPool) demoteUnexecutables() { nonce := pool.currentState.GetNonce(addr) // Drop all transactions that are deemed too old (low nonce) - for _, tx := range list.Forward(nonce) { + olds := list.Forward(nonce) + for _, tx := range olds { hash := tx.Hash() + pool.all.Remove(hash) log.Trace("Removed old pending transaction", "hash", hash) - delete(pool.all, hash) - pool.priced.Removed() } // Drop all transactions that are too costly (low balance or out of gas), and queue any invalids back for later var number *big.Int = nil @@ -1241,27 +1588,39 @@ func (pool *TxPool) demoteUnexecutables() { for _, tx := range drops { hash := tx.Hash() log.Trace("Removed unpayable pending transaction", "hash", hash) - delete(pool.all, hash) - pool.priced.Removed() - pendingNofundsCounter.Inc(1) + pool.all.Remove(hash) } + pool.priced.Removed(len(olds) + len(drops)) + pendingNofundsMeter.Mark(int64(len(drops))) + for _, tx := range invalids { hash := tx.Hash() log.Trace("Demoting pending transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } - // If there's a gap in front, warn (should never happen) and postpone all transactions + pendingGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) + if pool.locals.contains(addr) { + localGauge.Dec(int64(len(olds) + len(drops) + len(invalids))) + } + // If there's a gap in front, alert (should never happen) and postpone all transactions if list.Len() > 0 && list.txs.Get(nonce) == nil { - for _, tx := range list.Cap(0) { + gapped := list.Cap(0) + for _, tx := range gapped { hash := tx.Hash() log.Warn("Demoting invalidated transaction", "hash", hash) - pool.enqueueTx(hash, tx) + + // Internal shuffle shouldn't touch the lookup set. + pool.enqueueTx(hash, tx, false, false) } + pendingGauge.Dec(int64(len(gapped))) + // This might happen in a reorg, so log it to the metering + blockReorgInvalidatedTx.Mark(int64(len(gapped))) } - // Delete the entire queue entry if it became empty. + // Delete the entire pending entry if it became empty. if list.Empty() { delete(pool.pending, addr) - delete(pool.beats, addr) } } } @@ -1272,26 +1631,31 @@ type addressByHeartbeat struct { heartbeat time.Time } -type addresssByHeartbeat []addressByHeartbeat +type addressesByHeartbeat []addressByHeartbeat -func (a addresssByHeartbeat) Len() int { return len(a) } -func (a addresssByHeartbeat) Less(i, j int) bool { return a[i].heartbeat.Before(a[j].heartbeat) } -func (a addresssByHeartbeat) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a addressesByHeartbeat) Len() int { return len(a) } +func (a addressesByHeartbeat) Less(i, j int) bool { return a[i].heartbeat.Before(a[j].heartbeat) } +func (a addressesByHeartbeat) Swap(i, j int) { a[i], a[j] = a[j], a[i] } // accountSet is simply a set of addresses to check for existence, and a signer // capable of deriving addresses from transactions. type accountSet struct { accounts map[common.Address]struct{} signer types.Signer + cache *[]common.Address } // newAccountSet creates a new address set with an associated signer for sender // derivations. -func newAccountSet(signer types.Signer) *accountSet { - return &accountSet{ +func newAccountSet(signer types.Signer, addrs ...common.Address) *accountSet { + as := &accountSet{ accounts: make(map[common.Address]struct{}), signer: signer, } + for _, addr := range addrs { + as.add(addr) + } + return as } // contains checks if a given address is contained within the set. @@ -1300,6 +1664,10 @@ func (as *accountSet) contains(addr common.Address) bool { return exist } +func (as *accountSet) empty() bool { + return len(as.accounts) == 0 +} + // containsTx checks if the sender of a given tx is within the set. If the sender // cannot be derived, this method returns false. func (as *accountSet) containsTx(tx *types.Transaction) bool { @@ -1312,4 +1680,199 @@ func (as *accountSet) containsTx(tx *types.Transaction) bool { // add inserts a new address into the set to track. func (as *accountSet) add(addr common.Address) { as.accounts[addr] = struct{}{} + as.cache = nil +} + +// addTx adds the sender of tx into the set. +func (as *accountSet) addTx(tx *types.Transaction) { + if addr, err := types.Sender(as.signer, tx); err == nil { + as.add(addr) + } +} + +// flatten returns the list of addresses within this set, also caching it for later +// reuse. The returned slice should not be changed! +func (as *accountSet) flatten() []common.Address { + if as.cache == nil { + accounts := make([]common.Address, 0, len(as.accounts)) + for account := range as.accounts { + accounts = append(accounts, account) + } + as.cache = &accounts + } + return *as.cache +} + +// merge adds all addresses from the 'other' set into 'as'. +func (as *accountSet) merge(other *accountSet) { + for addr := range other.accounts { + as.accounts[addr] = struct{}{} + } + as.cache = nil +} + +// txLookup is used internally by TxPool to track transactions while allowing +// lookup without mutex contention. +// +// Note, although this type is properly protected against concurrent access, it +// is **not** a type that should ever be mutated or even exposed outside of the +// transaction pool, since its internal state is tightly coupled with the pools +// internal mechanisms. The sole purpose of the type is to permit out-of-bound +// peeking into the pool in TxPool.Get without having to acquire the widely scoped +// TxPool.mu mutex. +// +// This lookup set combines the notion of "local transactions", which is useful +// to build upper-level structure. +type txLookup struct { + slots int + lock sync.RWMutex + locals map[common.Hash]*types.Transaction + remotes map[common.Hash]*types.Transaction +} + +// newTxLookup returns a new txLookup structure. +func newTxLookup() *txLookup { + return &txLookup{ + locals: make(map[common.Hash]*types.Transaction), + remotes: make(map[common.Hash]*types.Transaction), + } +} + +// Range calls f on each key and value present in the map. The callback passed +// should return the indicator whether the iteration needs to be continued. +// Callers need to specify which set (or both) to be iterated. +func (t *txLookup) Range(f func(hash common.Hash, tx *types.Transaction, local bool) bool, local bool, remote bool) { + t.lock.RLock() + defer t.lock.RUnlock() + + if local { + for key, value := range t.locals { + if !f(key, value, true) { + return + } + } + } + if remote { + for key, value := range t.remotes { + if !f(key, value, false) { + return + } + } + } +} + +// Get returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) Get(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + if tx := t.locals[hash]; tx != nil { + return tx + } + return t.remotes[hash] +} + +// GetLocal returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetLocal(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.locals[hash] +} + +// GetRemote returns a transaction if it exists in the lookup, or nil if not found. +func (t *txLookup) GetRemote(hash common.Hash) *types.Transaction { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.remotes[hash] +} + +// Count returns the current number of transactions in the lookup. +func (t *txLookup) Count() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.locals) + len(t.remotes) +} + +// LocalCount returns the current number of local transactions in the lookup. +func (t *txLookup) LocalCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.locals) +} + +// RemoteCount returns the current number of remote transactions in the lookup. +func (t *txLookup) RemoteCount() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return len(t.remotes) +} + +// Slots returns the current number of slots used in the lookup. +func (t *txLookup) Slots() int { + t.lock.RLock() + defer t.lock.RUnlock() + + return t.slots +} + +// Add adds a transaction to the lookup. +func (t *txLookup) Add(tx *types.Transaction, local bool) { + t.lock.Lock() + defer t.lock.Unlock() + + t.slots += numSlots(tx) + slotsGauge.Update(int64(t.slots)) + + if local { + t.locals[tx.Hash()] = tx + } else { + t.remotes[tx.Hash()] = tx + } +} + +// Remove removes a transaction from the lookup. +func (t *txLookup) Remove(hash common.Hash) { + t.lock.Lock() + defer t.lock.Unlock() + + tx, ok := t.locals[hash] + if !ok { + tx, ok = t.remotes[hash] + } + if !ok { + log.Error("No transaction found to be deleted", "hash", hash) + return + } + t.slots -= numSlots(tx) + slotsGauge.Update(int64(t.slots)) + + delete(t.locals, hash) + delete(t.remotes, hash) +} + +// RemoteToLocals migrates the transactions belongs to the given locals to locals +// set. The assumption is held the locals set is thread-safe to be used. +func (t *txLookup) RemoteToLocals(locals *accountSet) int { + t.lock.Lock() + defer t.lock.Unlock() + + var migrated int + for hash, tx := range t.remotes { + if locals.containsTx(tx) { + t.locals[hash] = tx + delete(t.remotes, hash) + migrated += 1 + } + } + return migrated +} + +// numSlots calculates the number of slots needed for a single transaction. +func numSlots(tx *types.Transaction) int { + return int((tx.Size() + txSlotSize - 1) / txSlotSize) } diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go index 3477da7863e9..dbbdd24baa83 100644 --- a/core/tx_pool_test.go +++ b/core/tx_pool_test.go @@ -94,10 +94,18 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec return tx } +func pricedDataTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ecdsa.PrivateKey, bytes uint64) *types.Transaction { + data := make([]byte, bytes) + rand.Read(data) + + tx, _ := types.SignTx(types.NewTransaction(nonce, common.Address{}, big.NewInt(0), gaslimit, gasprice, data), types.HomesteadSigner{}, key) + return tx +} + func setupTxPool() (*TxPool, *ecdsa.PrivateKey) { diskdb := rawdb.NewMemoryDatabase() statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb)) - blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + blockchain := &testBlockChain{statedb, 10000000, new(event.Feed)} key, _ := crypto.GenerateKey() pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) @@ -112,11 +120,13 @@ func validateTxPoolInternals(pool *TxPool) error { // Ensure the total transaction set is consistent with pending + queued pending, queued := pool.stats() - if total := len(pool.all); total != pending+queued { + if total := pool.all.Count(); total != pending+queued { return fmt.Errorf("total transaction count %d != %d pending + %d queued", total, pending, queued) } - if priced := pool.priced.items.Len() - pool.priced.stales; priced != pending+queued { - return fmt.Errorf("total priced transaction count %d != %d pending + %d queued", priced, pending, queued) + pool.priced.Reheap() + priced, remote := pool.priced.remotes.Len(), pool.all.RemoteCount() + if priced != remote { + return fmt.Errorf("total priced transaction count %d != %d", priced, remote) } // Ensure the next nonce to assign is the correct one for addr, txs := range pool.pending { @@ -127,7 +137,7 @@ func validateTxPoolInternals(pool *TxPool) error { last = nonce } } - if nonce := pool.pendingState.GetNonce(addr); nonce != last+1 { + if nonce := pool.Nonce(addr); nonce != last+1 { return fmt.Errorf("pending nonce mismatch: have %v, want %v", nonce, last+1) } } @@ -136,21 +146,27 @@ func validateTxPoolInternals(pool *TxPool) error { // validateEvents checks that the correct number of transaction addition events // were fired on the pool's event feed. -func validateEvents(events chan TxPreEvent, count int) error { - for i := 0; i < count; i++ { +func validateEvents(events chan NewTxsEvent, count int) error { + var received []*types.Transaction + + for len(received) < count { select { - case <-events: + case ev := <-events: + received = append(received, ev.Txs...) case <-time.After(time.Second): - return fmt.Errorf("event #%d not fired", i) + return fmt.Errorf("event #%d not fired", len(received)) } } + if len(received) > count { + return fmt.Errorf("more than %d events fired: %v", count, received[count:]) + } select { - case tx := <-events: - return fmt.Errorf("more than %d events fired: %v", count, tx.Tx) + case ev := <-events: + return fmt.Errorf("more than %d events fired: %v", count, ev.Txs) case <-time.After(50 * time.Millisecond): // This branch should be "default", but it's a data race between goroutines, - // reading the event channel and pushng into it, so better wait a bit ensuring + // reading the event channel and pushing into it, so better wait a bit ensuring // really nothing gets injected. } return nil @@ -209,33 +225,27 @@ func TestStateChangeDuringTransactionPoolReset(t *testing.T) { pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) defer pool.Stop() - nonce := pool.State().GetNonce(address) + nonce := pool.Nonce(address) if nonce != 0 { t.Fatalf("Invalid nonce, want 0, got %d", nonce) } - pool.AddRemotes(types.Transactions{tx0, tx1}) + pool.AddRemotesSync([]*types.Transaction{tx0, tx1}) - nonce = pool.State().GetNonce(address) + nonce = pool.Nonce(address) if nonce != 2 { t.Fatalf("Invalid nonce, want 2, got %d", nonce) } // trigger state change in the background trigger = true + <-pool.requestReset(nil, nil) - pool.lockedReset(nil, nil) - - pendingTx, err := pool.Pending() + _, err := pool.Pending() if err != nil { t.Fatalf("Could not fetch pending transactions: %v", err) } - - for addr, txs := range pendingTx { - t.Logf("%0x: %d\n", addr, len(txs)) - } - - nonce = pool.State().GetNonce(address) + nonce = pool.Nonce(address) if nonce != 2 { t.Fatalf("Invalid nonce, want 2, got %d", nonce) } @@ -287,10 +297,10 @@ func TestTransactionQueue(t *testing.T) { tx := transaction(0, 100, key) from, _ := deriveSender(tx) pool.currentState.AddBalance(from, big.NewInt(1000)) - pool.lockedReset(nil, nil) - pool.enqueueTx(tx.Hash(), tx) + <-pool.requestReset(nil, nil) - pool.promoteExecutables([]common.Address{from}) + pool.enqueueTx(tx.Hash(), tx, false, true) + <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if len(pool.pending) != 1 { t.Error("expected valid txs to be 1 is", len(pool.pending)) } @@ -298,8 +308,9 @@ func TestTransactionQueue(t *testing.T) { tx = transaction(1, 100, key) from, _ = deriveSender(tx) pool.currentState.SetNonce(from, 2) - pool.enqueueTx(tx.Hash(), tx) - pool.promoteExecutables([]common.Address{from}) + pool.enqueueTx(tx.Hash(), tx, false, true) + + <-pool.requestPromoteExecutables(newAccountSet(pool.signer, from)) if _, ok := pool.pending[from].txs.items[tx.Nonce()]; ok { t.Error("expected transaction to be in tx pool") } @@ -307,25 +318,28 @@ func TestTransactionQueue(t *testing.T) { if len(pool.queue) > 0 { t.Error("expected transaction queue to be empty. is", len(pool.queue)) } +} + +func TestTransactionQueue2(t *testing.T) { + t.Parallel() - pool, key = setupTxPool() + pool, key := setupTxPool() defer pool.Stop() tx1 := transaction(0, 100, key) tx2 := transaction(10, 100, key) tx3 := transaction(11, 100, key) - from, _ = deriveSender(tx1) + from, _ := deriveSender(tx1) pool.currentState.AddBalance(from, big.NewInt(1000)) - pool.lockedReset(nil, nil) + pool.reset(nil, nil) - pool.enqueueTx(tx1.Hash(), tx1) - pool.enqueueTx(tx2.Hash(), tx2) - pool.enqueueTx(tx3.Hash(), tx3) + pool.enqueueTx(tx1.Hash(), tx1, false, true) + pool.enqueueTx(tx2.Hash(), tx2, false, true) + pool.enqueueTx(tx3.Hash(), tx3, false, true) pool.promoteExecutables([]common.Address{from}) - if len(pool.pending) != 1 { - t.Error("expected tx pool to be 1, got", len(pool.pending)) + t.Error("expected pending length to be 1, got", len(pool.pending)) } if pool.queue[from].Len() != 2 { t.Error("expected len(queue) == 2, got", pool.queue[from].Len()) @@ -359,7 +373,7 @@ func TestTransactionChainFork(t *testing.T) { statedb.AddBalance(addr, big.NewInt(100000000000000)) pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)} - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) } resetState() @@ -367,7 +381,7 @@ func TestTransactionChainFork(t *testing.T) { if _, err := pool.add(tx, false); err != nil { t.Error("didn't expect error", err) } - pool.removeTx(tx.Hash()) + pool.removeTx(tx.Hash(), true) // reset the pool's internal state resetState() @@ -389,7 +403,7 @@ func TestTransactionDoubleNonce(t *testing.T) { statedb.AddBalance(addr, big.NewInt(100000000000000)) pool.chain = &testBlockChain{statedb, 1000000, new(event.Feed)} - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) } resetState() @@ -405,16 +419,17 @@ func TestTransactionDoubleNonce(t *testing.T) { if replace, err := pool.add(tx2, false); err != nil || !replace { t.Errorf("second transaction insert failed (%v) or not reported replacement (%v)", err, replace) } - pool.promoteExecutables([]common.Address{addr}) + <-pool.requestPromoteExecutables(newAccountSet(signer, addr)) if pool.pending[addr].Len() != 1 { t.Error("expected 1 pending transactions, got", pool.pending[addr].Len()) } if tx := pool.pending[addr].txs.items[0]; tx.Hash() != tx2.Hash() { t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), tx2.Hash()) } + // Add the third transaction and ensure it's not saved (smaller price) pool.add(tx3, false) - pool.promoteExecutables([]common.Address{addr}) + <-pool.requestPromoteExecutables(newAccountSet(signer, addr)) if pool.pending[addr].Len() != 1 { t.Error("expected 1 pending transactions, got", pool.pending[addr].Len()) } @@ -422,8 +437,8 @@ func TestTransactionDoubleNonce(t *testing.T) { t.Errorf("transaction mismatch: have %x, want %x", tx.Hash(), tx2.Hash()) } // Ensure the total transaction count is correct - if len(pool.all) != 1 { - t.Error("expected 1 total transactions, got", len(pool.all)) + if pool.all.Count() != 1 { + t.Error("expected 1 total transactions, got", pool.all.Count()) } } @@ -445,8 +460,8 @@ func TestTransactionMissingNonce(t *testing.T) { if pool.queue[addr].Len() != 1 { t.Error("expected 1 queued transaction, got", pool.queue[addr].Len()) } - if len(pool.all) != 1 { - t.Error("expected 1 total transactions, got", len(pool.all)) + if pool.all.Count() != 1 { + t.Error("expected 1 total transactions, got", pool.all.Count()) } } @@ -460,7 +475,7 @@ func TestTransactionNonceRecovery(t *testing.T) { addr := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.SetNonce(addr, n) pool.currentState.AddBalance(addr, big.NewInt(100000000000000)) - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) tx := transaction(n, 100000, key) if err := pool.AddRemote(tx); err != nil { @@ -468,8 +483,8 @@ func TestTransactionNonceRecovery(t *testing.T) { } // simulate some weird re-order of transactions and missing nonce(s) pool.currentState.SetNonce(addr, n-1) - pool.lockedReset(nil, nil) - if fn := pool.pendingState.GetNonce(addr); fn != n-1 { + <-pool.requestReset(nil, nil) + if fn := pool.Nonce(addr); fn != n-1 { t.Errorf("expected nonce to be %d, got %d", n-1, fn) } } @@ -483,7 +498,7 @@ func TestTransactionDropping(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000)) // Add some pending and some queued transactions @@ -495,12 +510,21 @@ func TestTransactionDropping(t *testing.T) { tx11 = transaction(11, 200, key) tx12 = transaction(12, 300, key) ) + pool.all.Add(tx0, false) + pool.priced.Put(tx0, false) pool.promoteTx(account, tx0.Hash(), tx0) + + pool.all.Add(tx1, false) + pool.priced.Put(tx1, false) pool.promoteTx(account, tx1.Hash(), tx1) + + pool.all.Add(tx2, false) + pool.priced.Put(tx2, false) pool.promoteTx(account, tx2.Hash(), tx2) - pool.enqueueTx(tx10.Hash(), tx10) - pool.enqueueTx(tx11.Hash(), tx11) - pool.enqueueTx(tx12.Hash(), tx12) + + pool.enqueueTx(tx10.Hash(), tx10, false, true) + pool.enqueueTx(tx11.Hash(), tx11, false, true) + pool.enqueueTx(tx12.Hash(), tx12, false, true) // Check that pre and post validations leave the pool as is if pool.pending[account].Len() != 3 { @@ -509,22 +533,22 @@ func TestTransactionDropping(t *testing.T) { if pool.queue[account].Len() != 3 { t.Errorf("queued transaction mismatch: have %d, want %d", pool.queue[account].Len(), 3) } - if len(pool.all) != 6 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 6) + if pool.all.Count() != 6 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 6) } - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) if pool.pending[account].Len() != 3 { t.Errorf("pending transaction mismatch: have %d, want %d", pool.pending[account].Len(), 3) } if pool.queue[account].Len() != 3 { t.Errorf("queued transaction mismatch: have %d, want %d", pool.queue[account].Len(), 3) } - if len(pool.all) != 6 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 6) + if pool.all.Count() != 6 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 6) } // Reduce the balance of the account, and check that invalidated transactions are dropped pool.currentState.AddBalance(account, big.NewInt(-650)) - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) if _, ok := pool.pending[account].txs.items[tx0.Nonce()]; !ok { t.Errorf("funded pending transaction missing: %v", tx0) @@ -544,12 +568,12 @@ func TestTransactionDropping(t *testing.T) { if _, ok := pool.queue[account].txs.items[tx12.Nonce()]; ok { t.Errorf("out-of-fund queued transaction present: %v", tx11) } - if len(pool.all) != 4 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 4) + if pool.all.Count() != 4 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 4) } // Reduce the block gas limit, check that invalidated transactions are dropped pool.chain.(*testBlockChain).gasLimit = 100 - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) if _, ok := pool.pending[account].txs.items[tx0.Nonce()]; !ok { t.Errorf("funded pending transaction missing: %v", tx0) @@ -563,8 +587,8 @@ func TestTransactionDropping(t *testing.T) { if _, ok := pool.queue[account].txs.items[tx11.Nonce()]; ok { t.Errorf("over-gased queued transaction present: %v", tx11) } - if len(pool.all) != 2 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), 2) + if pool.all.Count() != 2 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), 2) } } @@ -606,7 +630,7 @@ func TestTransactionPostponing(t *testing.T) { txs = append(txs, tx) } } - for i, err := range pool.AddRemotes(txs) { + for i, err := range pool.AddRemotesSync(txs) { if err != nil { t.Fatalf("tx %d: failed to add transactions: %v", i, err) } @@ -618,24 +642,24 @@ func TestTransactionPostponing(t *testing.T) { if len(pool.queue) != 0 { t.Errorf("queued accounts mismatch: have %d, want %d", len(pool.queue), 0) } - if len(pool.all) != len(txs) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)) + if pool.all.Count() != len(txs) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)) } - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) if pending := pool.pending[accs[0]].Len() + pool.pending[accs[1]].Len(); pending != len(txs) { t.Errorf("pending transaction mismatch: have %d, want %d", pending, len(txs)) } if len(pool.queue) != 0 { t.Errorf("queued accounts mismatch: have %d, want %d", len(pool.queue), 0) } - if len(pool.all) != len(txs) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)) + if pool.all.Count() != len(txs) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)) } // Reduce the balance of the account, and check that transactions are reorganised for _, addr := range accs { pool.currentState.AddBalance(addr, big.NewInt(-1)) } - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) // The first account's first transaction remains valid, check that subsequent // ones are either filtered out, or queued up for later. @@ -678,8 +702,8 @@ func TestTransactionPostponing(t *testing.T) { } } } - if len(pool.all) != len(txs)/2 { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), len(txs)/2) + if pool.all.Count() != len(txs)/2 { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), len(txs)/2) } } @@ -693,21 +717,19 @@ func TestTransactionGapFilling(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) // Keep track of transaction events to ensure all executables get announced - events := make(chan TxPreEvent, testTxPoolConfig.AccountQueue+5) + events := make(chan NewTxsEvent, testTxPoolConfig.AccountQueue+5) sub := pool.txFeed.Subscribe(events) defer sub.Unsubscribe() // Create a pending and a queued transaction with a nonce-gap in between - if err := pool.AddRemote(transaction(0, 100000, key)); err != nil { - t.Fatalf("failed to add pending transaction: %v", err) - } - if err := pool.AddRemote(transaction(2, 100000, key)); err != nil { - t.Fatalf("failed to add queued transaction: %v", err) - } + pool.AddRemotesSync([]*types.Transaction{ + transaction(0, 100000, key), + transaction(2, 100000, key), + }) pending, queued := pool.Stats() if pending != 1 { t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 1) @@ -722,7 +744,7 @@ func TestTransactionGapFilling(t *testing.T) { t.Fatalf("pool internal state corrupted: %v", err) } // Fill the nonce gap and ensure all transactions become pending - if err := pool.AddRemote(transaction(1, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(1, 100000, key)); err != nil { t.Fatalf("failed to add gapped transaction: %v", err) } pending, queued = pool.Stats() @@ -749,12 +771,12 @@ func TestTransactionQueueAccountLimiting(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) testTxPoolConfig.AccountQueue = 10 // Keep queuing up transactions and make sure all above a limit are dropped for i := uint64(1); i <= testTxPoolConfig.AccountQueue; i++ { - if err := pool.AddRemote(transaction(i, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(i, 100000, key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if len(pool.pending) != 0 { @@ -770,8 +792,8 @@ func TestTransactionQueueAccountLimiting(t *testing.T) { } } } - if len(pool.all) != int(testTxPoolConfig.AccountQueue) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), testTxPoolConfig.AccountQueue) + if pool.all.Count() != int(testTxPoolConfig.AccountQueue) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), testTxPoolConfig.AccountQueue) } } @@ -823,7 +845,7 @@ func testTransactionQueueGlobalLimiting(t *testing.T, nolocals bool) { nonces[addr]++ } // Import the batch and verify that limits have been enforced - pool.AddRemotes(txs) + pool.AddRemotesSync(txs) queued := 0 for addr, list := range pool.queue { @@ -881,7 +903,7 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { common.MinGasPrice = big.NewInt(0) // Reduce the eviction interval to a testable amount defer func(old time.Duration) { evictionInterval = old }(evictionInterval) - evictionInterval = time.Second + evictionInterval = time.Millisecond * 100 // Create the pool to test the non-expiration enforcement db := rawdb.NewMemoryDatabase() @@ -919,6 +941,22 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } + + // Allow the eviction interval to run + time.Sleep(2 * evictionInterval) + + // Transactions should not be evicted from the queue yet since lifetime duration has not passed + pending, queued = pool.Stats() + if pending != 0 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 0) + } + if queued != 2 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 2) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + // Wait a bit for eviction to run and clean up any leftovers, and ensure only the local remains time.Sleep(2 * config.Lifetime) @@ -938,6 +976,72 @@ func testTransactionQueueTimeLimiting(t *testing.T, nolocals bool) { if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } + + // remove current transactions and increase nonce to prepare for a reset and cleanup + statedb.SetNonce(crypto.PubkeyToAddress(remote.PublicKey), 2) + statedb.SetNonce(crypto.PubkeyToAddress(local.PublicKey), 2) + <-pool.requestReset(nil, nil) + + // make sure queue, pending are cleared + pending, queued = pool.Stats() + if pending != 0 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 0) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + + // Queue gapped transactions + if err := pool.AddLocal(pricedTransaction(4, 100000, big.NewInt(1), local)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + if err := pool.addRemoteSync(pricedTransaction(4, 100000, big.NewInt(1), remote)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + time.Sleep(5 * evictionInterval) // A half lifetime pass + + // Queue executable transactions, the life cycle should be restarted. + if err := pool.AddLocal(pricedTransaction(2, 100000, big.NewInt(1), local)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + if err := pool.addRemoteSync(pricedTransaction(2, 100000, big.NewInt(1), remote)); err != nil { + t.Fatalf("failed to add remote transaction: %v", err) + } + time.Sleep(6 * evictionInterval) + + // All gapped transactions shouldn't be kicked out + pending, queued = pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if queued != 2 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 3) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + + // The whole life time pass after last promotion, kick out stale transactions + time.Sleep(2 * config.Lifetime) + pending, queued = pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if nolocals { + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + } else { + if queued != 1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 1) + } + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } } // Tests that even if the transaction count belonging to a single account goes @@ -950,17 +1054,17 @@ func TestTransactionPendingLimiting(t *testing.T) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) testTxPoolConfig.AccountQueue = 10 // Keep track of transaction events to ensure all executables get announced - events := make(chan TxPreEvent, testTxPoolConfig.AccountQueue) + events := make(chan NewTxsEvent, testTxPoolConfig.AccountQueue) sub := pool.txFeed.Subscribe(events) defer sub.Unsubscribe() // Keep queuing up transactions and make sure all above a limit are dropped for i := uint64(0); i < testTxPoolConfig.AccountQueue; i++ { - if err := pool.AddRemote(transaction(i, 100000, key)); err != nil { + if err := pool.addRemoteSync(transaction(i, 100000, key)); err != nil { t.Fatalf("tx %d: failed to add transaction: %v", i, err) } if pool.pending[account].Len() != int(i)+1 { @@ -970,8 +1074,8 @@ func TestTransactionPendingLimiting(t *testing.T) { t.Errorf("tx %d: queue size mismatch: have %d, want %d", i, pool.queue[account].Len(), 0) } } - if len(pool.all) != int(testTxPoolConfig.AccountQueue) { - t.Errorf("total transaction mismatch: have %d, want %d", len(pool.all), testTxPoolConfig.AccountQueue+5) + if pool.all.Count() != int(testTxPoolConfig.AccountQueue) { + t.Errorf("total transaction mismatch: have %d, want %d", pool.all.Count(), testTxPoolConfig.AccountQueue+5) } if err := validateEvents(events, int(testTxPoolConfig.AccountQueue)); err != nil { t.Fatalf("event firing failed: %v", err) @@ -981,59 +1085,6 @@ func TestTransactionPendingLimiting(t *testing.T) { } } -// Tests that the transaction limits are enforced the same way irrelevant whether -// the transactions are added one by one or in batches. -func TestTransactionQueueLimitingEquivalency(t *testing.T) { testTransactionLimitingEquivalency(t, 1) } -func TestTransactionPendingLimitingEquivalency(t *testing.T) { - testTransactionLimitingEquivalency(t, 0) -} - -func testTransactionLimitingEquivalency(t *testing.T, origin uint64) { - t.Parallel() - - // Add a batch of transactions to a pool one by one - pool1, key1 := setupTxPool() - defer pool1.Stop() - - account1, _ := deriveSender(transaction(0, 0, key1)) - pool1.currentState.AddBalance(account1, big.NewInt(1000000)) - testTxPoolConfig.AccountQueue = 10 - for i := uint64(0); i < testTxPoolConfig.AccountQueue; i++ { - if err := pool1.AddRemote(transaction(origin+i, 100000, key1)); err != nil { - t.Fatalf("tx %d: failed to add transaction: %v", i, err) - } - } - // Add a batch of transactions to a pool in one big batch - pool2, key2 := setupTxPool() - defer pool2.Stop() - - account2, _ := deriveSender(transaction(0, 0, key2)) - pool2.currentState.AddBalance(account2, big.NewInt(1000000)) - - txs := []*types.Transaction{} - for i := uint64(0); i < testTxPoolConfig.AccountQueue; i++ { - txs = append(txs, transaction(origin+i, 100000, key2)) - } - pool2.AddRemotes(txs) - - // Ensure the batch optimization honors the same pool mechanics - if len(pool1.pending) != len(pool2.pending) { - t.Errorf("pending transaction count mismatch: one-by-one algo: %d, batch algo: %d", len(pool1.pending), len(pool2.pending)) - } - if len(pool1.queue) != len(pool2.queue) { - t.Errorf("queued transaction count mismatch: one-by-one algo: %d, batch algo: %d", len(pool1.queue), len(pool2.queue)) - } - if len(pool1.all) != len(pool2.all) { - t.Errorf("total transaction count mismatch: one-by-one algo %d, batch algo %d", len(pool1.all), len(pool2.all)) - } - if err := validateTxPoolInternals(pool1); err != nil { - t.Errorf("pool 1 internal state corrupted: %v", err) - } - if err := validateTxPoolInternals(pool2); err != nil { - t.Errorf("pool 2 internal state corrupted: %v", err) - } -} - // Tests that if the transaction count belonging to multiple accounts go above // some hard threshold, the higher transactions are dropped to prevent DOS // attacks. @@ -1069,7 +1120,7 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { } } // Import the batch and verify that limits have been enforced - pool.AddRemotes(txs) + pool.AddRemotesSync(txs) pending := 0 for _, list := range pool.pending { @@ -1083,6 +1134,62 @@ func TestTransactionPendingGlobalLimiting(t *testing.T) { } } +// Test the limit on transaction size is enforced correctly. +// This test verifies every transaction having allowed size +// is added to the pool, and longer transactions are rejected. +func TestTransactionAllowedTxSize(t *testing.T) { + t.Parallel() + + // Create a test account and fund it + pool, key := setupTxPool() + defer pool.Stop() + + account := crypto.PubkeyToAddress(key.PublicKey) + pool.currentState.AddBalance(account, big.NewInt(1000000000)) + + // Compute maximal data size for transactions (lower bound). + // + // It is assumed the fields in the transaction (except of the data) are: + // - nonce <= 32 bytes + // - gasPrice <= 32 bytes + // - gasLimit <= 32 bytes + // - recipient == 20 bytes + // - value <= 32 bytes + // - signature == 65 bytes + // All those fields are summed up to at most 213 bytes. + baseSize := uint64(213) + dataSize := txMaxSize - baseSize + + // Try adding a transaction with maximal allowed size + tx := pricedDataTransaction(0, pool.currentMaxGas, big.NewInt(1), key, dataSize) + if err := pool.addRemoteSync(tx); err != nil { + t.Fatalf("failed to add transaction of size %d, close to maximal: %v", int(tx.Size()), err) + } + // Try adding a transaction with random allowed size + if err := pool.addRemoteSync(pricedDataTransaction(1, pool.currentMaxGas, big.NewInt(1), key, uint64(rand.Intn(int(dataSize))))); err != nil { + t.Fatalf("failed to add transaction of random allowed size: %v", err) + } + // Try adding a transaction of minimal not allowed size + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, txMaxSize)); err == nil { + t.Fatalf("expected rejection on slightly oversize transaction") + } + // Try adding a transaction of random not allowed size + if err := pool.addRemoteSync(pricedDataTransaction(2, pool.currentMaxGas, big.NewInt(1), key, dataSize+1+uint64(rand.Intn(int(10*txMaxSize))))); err == nil { + t.Fatalf("expected rejection on oversize transaction") + } + // Run some sanity checks on the pool internals + pending, queued := pool.Stats() + if pending != 2 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that if transactions start being capped, transactions are also removed from 'all' func TestTransactionCapClearsFromAll(t *testing.T) { t.Parallel() @@ -1128,9 +1235,8 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} config := testTxPoolConfig - config.AccountSlots = 10 - config.GlobalSlots = 0 config.AccountSlots = 5 + config.GlobalSlots = 1 pool := NewTxPool(config, params.TestChainConfig, blockchain) defer pool.Stop() @@ -1152,7 +1258,7 @@ func TestTransactionPendingMinimumAllowance(t *testing.T) { } } // Import the batch and verify that limits have been enforced - pool.AddRemotes(txs) + pool.AddRemotesSync(txs) for addr, list := range pool.pending { if list.Len() != int(config.AccountSlots) { @@ -1181,7 +1287,7 @@ func TestTransactionPoolRepricing(t *testing.T) { defer pool.Stop() // Keep track of transaction events to ensure all executables get announced - events := make(chan TxPreEvent, 32) + events := make(chan NewTxsEvent, 32) sub := pool.txFeed.Subscribe(events) defer sub.Unsubscribe() @@ -1209,7 +1315,7 @@ func TestTransactionPoolRepricing(t *testing.T) { ltx := pricedTransaction(0, 100000, big.NewInt(1), keys[3]) // Import the batch and that both pending and queued transactions match up - pool.AddRemotes(txs) + pool.AddRemotesSync(txs) pool.AddLocal(ltx) pending, queued := pool.Stats() @@ -1370,12 +1476,12 @@ func TestTransactionPoolUnderpricing(t *testing.T) { defer pool.Stop() // Keep track of transaction events to ensure all executables get announced - events := make(chan TxPreEvent, 32) + events := make(chan NewTxsEvent, 32) sub := pool.txFeed.Subscribe(events) defer sub.Unsubscribe() // Create a number of test accounts and fund them - keys := make([]*ecdsa.PrivateKey, 3) + keys := make([]*ecdsa.PrivateKey, 4) for i := 0; i < len(keys); i++ { keys[i], _ = crypto.GenerateKey() pool.currentState.AddBalance(crypto.PubkeyToAddress(keys[i].PublicKey), big.NewInt(1000000)) @@ -1412,13 +1518,13 @@ func TestTransactionPoolUnderpricing(t *testing.T) { t.Fatalf("adding underpriced pending transaction error mismatch: have %v, want %v", err, ErrUnderpriced) } // Ensure that adding high priced transactions drops cheap ones, but not own - if err := pool.AddRemote(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { + if err := pool.AddRemote(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { // +K1:0 => -K1:1 => Pend K0:0, K0:1, K1:0, K2:0; Que - t.Fatalf("failed to add well priced transaction: %v", err) } - if err := pool.AddRemote(pricedTransaction(2, 100000, big.NewInt(4), keys[1])); err != nil { + if err := pool.AddRemote(pricedTransaction(2, 100000, big.NewInt(4), keys[1])); err != nil { // +K1:2 => -K0:0 => Pend K1:0, K2:0; Que K0:1 K1:2 t.Fatalf("failed to add well priced transaction: %v", err) } - if err := pool.AddRemote(pricedTransaction(3, 100000, big.NewInt(5), keys[1])); err != nil { + if err := pool.AddRemote(pricedTransaction(3, 100000, big.NewInt(5), keys[1])); err != nil { // +K1:3 => -K0:1 => Pend K1:0, K2:0; Que K1:2 K1:3 t.Fatalf("failed to add well priced transaction: %v", err) } pending, queued = pool.Stats() @@ -1428,25 +1534,29 @@ func TestTransactionPoolUnderpricing(t *testing.T) { if queued != 2 { t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 2) } - if err := validateEvents(events, 2); err != nil { + if err := validateEvents(events, 1); err != nil { t.Fatalf("additional event firing failed: %v", err) } if err := validateTxPoolInternals(pool); err != nil { t.Fatalf("pool internal state corrupted: %v", err) } // Ensure that adding local transactions can push out even higher priced ones - tx := pricedTransaction(1, 100000, big.NewInt(1), keys[2]) - if err := pool.AddLocal(tx); err != nil { - t.Fatalf("failed to add underpriced local transaction: %v", err) + ltx = pricedTransaction(1, 100000, big.NewInt(1), keys[2]) + if err := pool.AddLocal(ltx); err != nil { + t.Fatalf("failed to append underpriced local transaction: %v", err) + } + ltx = pricedTransaction(0, 100000, big.NewInt(1), keys[3]) + if err := pool.AddLocal(ltx); err != nil { + t.Fatalf("failed to add new underpriced local transaction: %v", err) } pending, queued = pool.Stats() - if pending != 2 { - t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 2) + if pending != 3 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 3) } - if queued != 2 { - t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 2) + if queued != 1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 1) } - if err := validateEvents(events, 1); err != nil { + if err := validateEvents(events, 2); err != nil { t.Fatalf("local event firing failed: %v", err) } if err := validateTxPoolInternals(pool); err != nil { @@ -1454,6 +1564,140 @@ func TestTransactionPoolUnderpricing(t *testing.T) { } } +// Tests that more expensive transactions push out cheap ones from the pool, but +// without producing instability by creating gaps that start jumping transactions +// back and forth between queued/pending. +func TestTransactionPoolStableUnderpricing(t *testing.T) { + t.Parallel() + + // Create the pool to test the pricing enforcement with + db := rawdb.NewMemoryDatabase() + statedb, _ := state.New(common.Hash{}, state.NewDatabase(db)) + blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + + config := testTxPoolConfig + config.GlobalSlots = common.LimitThresholdNonceInQueue + config.GlobalQueue = 0 + config.AccountSlots = config.GlobalSlots - 1 + + pool := NewTxPool(config, params.TestChainConfig, blockchain) + defer pool.Stop() + + // Keep track of transaction events to ensure all executables get announced + events := make(chan NewTxsEvent, 32) + sub := pool.txFeed.Subscribe(events) + defer sub.Unsubscribe() + + // Create a number of test accounts and fund them + keys := make([]*ecdsa.PrivateKey, 2) + for i := 0; i < len(keys); i++ { + keys[i], _ = crypto.GenerateKey() + pool.currentState.AddBalance(crypto.PubkeyToAddress(keys[i].PublicKey), big.NewInt(1000000)) + } + // Fill up the entire queue with the same transaction price points + txs := types.Transactions{} + for i := uint64(0); i < config.GlobalSlots; i++ { + txs = append(txs, pricedTransaction(i, 100000, big.NewInt(1), keys[0])) + } + pool.AddRemotesSync(txs) + + pending, queued := pool.Stats() + if pending != int(config.GlobalSlots) { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, config.GlobalSlots) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateEvents(events, int(config.GlobalSlots)); err != nil { + t.Fatalf("original event firing failed: %v", err) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } + // Ensure that adding high priced transactions drops a cheap, but doesn't produce a gap + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(3), keys[1])); err != nil { + t.Fatalf("failed to add well priced transaction: %v", err) + } + pending, queued = pool.Stats() + if pending != int(config.GlobalSlots) { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, config.GlobalSlots) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateEvents(events, 1); err != nil { + t.Fatalf("additional event firing failed: %v", err) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + +// Tests that the pool rejects duplicate transactions. +func TestTransactionDeduplication(t *testing.T) { + t.Parallel() + + // Create the pool to test the pricing enforcement with + statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)} + + pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain) + defer pool.Stop() + + // Create a test account to add transactions with + key, _ := crypto.GenerateKey() + pool.currentState.AddBalance(crypto.PubkeyToAddress(key.PublicKey), big.NewInt(1000000000)) + + // Create a batch of transactions and add a few of them + txs := make([]*types.Transaction, common.LimitThresholdNonceInQueue) + for i := 0; i < len(txs); i++ { + txs[i] = pricedTransaction(uint64(i), 100000, big.NewInt(1), key) + } + var firsts []*types.Transaction + for i := 0; i < len(txs); i += 2 { + firsts = append(firsts, txs[i]) + } + errs := pool.AddRemotesSync(firsts) + if len(errs) != len(firsts) { + t.Fatalf("first add mismatching result count: have %d, want %d", len(errs), len(firsts)) + } + for i, err := range errs { + if err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued := pool.Stats() + if pending != 1 { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 1) + } + if queued != len(txs)/2-1 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, len(txs)/2-1) + } + // Try to add all of them now and ensure previous ones error out as knowns + errs = pool.AddRemotesSync(txs) + if len(errs) != len(txs) { + t.Fatalf("all add mismatching result count: have %d, want %d", len(errs), len(txs)) + } + for i, err := range errs { + if i%2 == 0 && err == nil { + t.Errorf("add %d succeeded, should have failed as known", i) + } + if i%2 == 1 && err != nil { + t.Errorf("add %d failed: %v", i, err) + } + } + pending, queued = pool.Stats() + if pending != len(txs) { + t.Fatalf("pending transactions mismatched: have %d, want %d", pending, len(txs)) + } + if queued != 0 { + t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0) + } + if err := validateTxPoolInternals(pool); err != nil { + t.Fatalf("pool internal state corrupted: %v", err) + } +} + // Tests that the pool rejects replacement transactions that don't meet the minimum // price bump required. func TestTransactionReplacement(t *testing.T) { @@ -1468,7 +1712,7 @@ func TestTransactionReplacement(t *testing.T) { defer pool.Stop() // Keep track of transaction events to ensure all executables get announced - events := make(chan TxPreEvent, 32) + events := make(chan NewTxsEvent, 32) sub := pool.txFeed.Subscribe(events) defer sub.Unsubscribe() @@ -1480,7 +1724,7 @@ func TestTransactionReplacement(t *testing.T) { price := int64(100) threshold := (price * (100 + int64(testTxPoolConfig.PriceBump))) / 100 - if err := pool.AddRemote(pricedTransaction(0, 100000, big.NewInt(1), key)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), key)); err != nil { t.Fatalf("failed to add original cheap pending transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(0, 100001, big.NewInt(1), key)); err != ErrReplaceUnderpriced { @@ -1493,7 +1737,7 @@ func TestTransactionReplacement(t *testing.T) { t.Fatalf("cheap replacement event firing failed: %v", err) } - if err := pool.AddRemote(pricedTransaction(0, 100000, big.NewInt(price), key)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(price), key)); err != nil { t.Fatalf("failed to add original proper pending transaction: %v", err) } if err := pool.AddRemote(pricedTransaction(0, 100001, big.NewInt(threshold-1), key)); err != ErrReplaceUnderpriced { @@ -1505,6 +1749,7 @@ func TestTransactionReplacement(t *testing.T) { if err := validateEvents(events, 2); err != nil { t.Fatalf("proper replacement event firing failed: %v", err) } + // Add queued transactions, ensuring the minimum price bump is enforced for replacement (for ultra low prices too) if err := pool.AddRemote(pricedTransaction(2, 100000, big.NewInt(1), key)); err != nil { t.Fatalf("failed to add original cheap queued transaction: %v", err) @@ -1583,7 +1828,7 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { if err := pool.AddLocal(pricedTransaction(2, 100000, big.NewInt(1), local)); err != nil { t.Fatalf("failed to add local transaction: %v", err) } - if err := pool.AddRemote(pricedTransaction(0, 100000, big.NewInt(1), remote)); err != nil { + if err := pool.addRemoteSync(pricedTransaction(0, 100000, big.NewInt(1), remote)); err != nil { t.Fatalf("failed to add remote transaction: %v", err) } pending, queued := pool.Stats() @@ -1621,7 +1866,7 @@ func testTransactionJournaling(t *testing.T, nolocals bool) { } // Bump the nonce temporarily and ensure the newly invalidated transaction is removed statedb.SetNonce(crypto.PubkeyToAddress(local.PublicKey), 2) - pool.lockedReset(nil, nil) + <-pool.requestReset(nil, nil) time.Sleep(2 * config.Rejournal) pool.Stop() @@ -1676,7 +1921,7 @@ func TestTransactionStatusCheck(t *testing.T) { txs = append(txs, pricedTransaction(2, 100000, big.NewInt(1), keys[2])) // Queued only // Import the transaction and ensure they are correctly added - pool.AddRemotes(txs) + pool.AddRemotesSync(txs) pending, queued := pool.Stats() if pending != 2 { @@ -1705,6 +1950,24 @@ func TestTransactionStatusCheck(t *testing.T) { } } +// Test the transaction slots consumption is computed correctly +func TestTransactionSlotCount(t *testing.T) { + t.Parallel() + + key, _ := crypto.GenerateKey() + + // Check that an empty transaction consumes a single slot + smallTx := pricedDataTransaction(0, 0, big.NewInt(0), key, 0) + if slots := numSlots(smallTx); slots != 1 { + t.Fatalf("small transactions slot count mismatch: have %d want %d", slots, 1) + } + // Check that a large transaction consumes the correct number of slots + bigTx := pricedDataTransaction(0, 0, big.NewInt(0), key, uint64(10*txSlotSize)) + if slots := numSlots(bigTx); slots != 11 { + t.Fatalf("big transactions slot count mismatch: have %d want %d", slots, 11) + } +} + // Benchmarks the speed of validating the contents of the pending queue of the // transaction pool. func BenchmarkPendingDemotion100(b *testing.B) { benchmarkPendingDemotion(b, 100) } @@ -1716,7 +1979,7 @@ func benchmarkPendingDemotion(b *testing.B, size int) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) for i := 0; i < size; i++ { @@ -1741,12 +2004,12 @@ func benchmarkFuturePromotion(b *testing.B, size int) { pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) for i := 0; i < size; i++ { tx := transaction(uint64(1+i), 100000, key) - pool.enqueueTx(tx.Hash(), tx) + pool.enqueueTx(tx.Hash(), tx, false, true) } // Benchmark the speed of pool validation b.ResetTimer() @@ -1755,37 +2018,21 @@ func benchmarkFuturePromotion(b *testing.B, size int) { } } -// Benchmarks the speed of iterative transaction insertion. -func BenchmarkPoolInsert(b *testing.B) { - // Generate a batch of transactions to enqueue into the pool - pool, key := setupTxPool() - defer pool.Stop() - - account, _ := deriveSender(transaction(0, 0, key)) - pool.currentState.AddBalance(account, big.NewInt(1000000)) - - txs := make(types.Transactions, b.N) - for i := 0; i < b.N; i++ { - txs[i] = transaction(uint64(i), 100000, key) - } - // Benchmark importing the transactions into the queue - b.ResetTimer() - for _, tx := range txs { - pool.AddRemote(tx) - } -} - // Benchmarks the speed of batched transaction insertion. -func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100) } -func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000) } -func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000) } +func BenchmarkPoolBatchInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, false) } +func BenchmarkPoolBatchInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, false) } +func BenchmarkPoolBatchInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, false) } + +func BenchmarkPoolBatchLocalInsert100(b *testing.B) { benchmarkPoolBatchInsert(b, 100, true) } +func BenchmarkPoolBatchLocalInsert1000(b *testing.B) { benchmarkPoolBatchInsert(b, 1000, true) } +func BenchmarkPoolBatchLocalInsert10000(b *testing.B) { benchmarkPoolBatchInsert(b, 10000, true) } -func benchmarkPoolBatchInsert(b *testing.B, size int) { +func benchmarkPoolBatchInsert(b *testing.B, size int, local bool) { // Generate a batch of transactions to enqueue into the pool pool, key := setupTxPool() defer pool.Stop() - account, _ := deriveSender(transaction(0, 0, key)) + account := crypto.PubkeyToAddress(key.PublicKey) pool.currentState.AddBalance(account, big.NewInt(1000000)) batches := make([]types.Transactions, b.N) @@ -1798,6 +2045,45 @@ func benchmarkPoolBatchInsert(b *testing.B, size int) { // Benchmark importing the transactions into the queue b.ResetTimer() for _, batch := range batches { - pool.AddRemotes(batch) + if local { + pool.AddLocals(batch) + } else { + pool.AddRemotes(batch) + } + } +} + +func BenchmarkInsertRemoteWithAllLocals(b *testing.B) { + // Allocate keys for testing + key, _ := crypto.GenerateKey() + account := crypto.PubkeyToAddress(key.PublicKey) + + remoteKey, _ := crypto.GenerateKey() + remoteAddr := crypto.PubkeyToAddress(remoteKey.PublicKey) + + locals := make([]*types.Transaction, 4096+1024) // Occupy all slots + for i := 0; i < len(locals); i++ { + locals[i] = transaction(uint64(i), 100000, key) + } + remotes := make([]*types.Transaction, 1000) + for i := 0; i < len(remotes); i++ { + remotes[i] = pricedTransaction(uint64(i), 100000, big.NewInt(2), remoteKey) // Higher gasprice + } + // Benchmark importing the transactions into the queue + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + pool, _ := setupTxPool() + pool.currentState.AddBalance(account, big.NewInt(100000000)) + for _, local := range locals { + pool.AddLocal(local) + } + b.StartTimer() + // Assign a high enough balance for testing + pool.currentState.AddBalance(remoteAddr, big.NewInt(100000000)) + for i := 0; i < len(remotes); i++ { + pool.AddRemotes([]*types.Transaction{remotes[i]}) + } + pool.Stop() } } diff --git a/core/types/access_list_tx.go b/core/types/access_list_tx.go new file mode 100644 index 000000000000..f80044e108fa --- /dev/null +++ b/core/types/access_list_tx.go @@ -0,0 +1,115 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "math/big" + + "github.com/XinFinOrg/XDPoSChain/common" +) + +//go:generate gencodec -type AccessTuple -out gen_access_tuple.go + +// AccessList is an EIP-2930 access list. +type AccessList []AccessTuple + +// AccessTuple is the element type of an access list. +type AccessTuple struct { + Address common.Address `json:"address" gencodec:"required"` + StorageKeys []common.Hash `json:"storageKeys" gencodec:"required"` +} + +// StorageKeys returns the total number of storage keys in the access list. +func (al AccessList) StorageKeys() int { + sum := 0 + for _, tuple := range al { + sum += len(tuple.StorageKeys) + } + return sum +} + +// AccessListTx is the data of EIP-2930 access list transactions. +type AccessListTx struct { + ChainID *big.Int // destination chain ID + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input data + AccessList AccessList // EIP-2930 access list + V, R, S *big.Int // signature values +} + +// copy creates a deep copy of the transaction data and initializes all fields. +func (tx *AccessListTx) copy() TxData { + cpy := &AccessListTx{ + Nonce: tx.Nonce, + To: tx.To, // TODO: copy pointed-to address + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are copied below. + AccessList: make(AccessList, len(tx.AccessList)), + Value: new(big.Int), + ChainID: new(big.Int), + GasPrice: new(big.Int), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + copy(cpy.AccessList, tx.AccessList) + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.ChainID != nil { + cpy.ChainID.Set(tx.ChainID) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +// accessors for innerTx. + +func (tx *AccessListTx) txType() byte { return AccessListTxType } +func (tx *AccessListTx) chainID() *big.Int { return tx.ChainID } +func (tx *AccessListTx) protected() bool { return true } +func (tx *AccessListTx) accessList() AccessList { return tx.AccessList } +func (tx *AccessListTx) data() []byte { return tx.Data } +func (tx *AccessListTx) gas() uint64 { return tx.Gas } +func (tx *AccessListTx) gasPrice() *big.Int { return tx.GasPrice } +func (tx *AccessListTx) value() *big.Int { return tx.Value } +func (tx *AccessListTx) nonce() uint64 { return tx.Nonce } +func (tx *AccessListTx) to() *common.Address { return tx.To } + +func (tx *AccessListTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *AccessListTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.ChainID, tx.V, tx.R, tx.S = chainID, v, r, s +} diff --git a/core/types/block.go b/core/types/block.go index cbeb5653732a..071666a801fa 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -29,7 +29,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" - "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" ) @@ -155,13 +154,6 @@ func (h *Header) Size() common.StorageSize { return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8) } -func rlpHash(x interface{}) (h common.Hash) { - hw := sha3.NewKeccak256() - rlp.Encode(hw, x) - hw.Sum(h[:0]) - return h -} - // Body is a simple (mutable, non-safe) data container for storing and moving // a block's data contents (transactions and uncles) together. type Body struct { diff --git a/core/types/block_test.go b/core/types/block_test.go index c95aaae71771..3cb50180b9f3 100644 --- a/core/types/block_test.go +++ b/core/types/block_test.go @@ -17,13 +17,18 @@ package types import ( + "bytes" + "hash" "math/big" + "reflect" "testing" - "bytes" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/XinFinOrg/XDPoSChain/crypto" + "github.com/XinFinOrg/XDPoSChain/params" "github.com/XinFinOrg/XDPoSChain/rlp" - "reflect" + "golang.org/x/crypto/sha3" ) // from bcValidBlockTest.json, "SimpleTx" @@ -59,3 +64,152 @@ func TestBlockEncoding(t *testing.T) { t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc) } } + +func TestEIP2718BlockEncoding(t *testing.T) { + blockEnc := common.FromHex("f9031cf90214a00000000000000000000000000000000000000000000000000000000000000000a01dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347948888f1f195afa192cfee860698584c030f4c9db1a0ef1552a40b7165c3cd773806b9e0c165b75356e0314bf0706f279c729f51e017a0e6e49996c7ec59f7a23d22b83239a60151512c65613bf84a0d7da336399ebc4aa0cafe75574d59780665a97fbfd11365c7545aa8f1abf4e5e12e8243334ef7286bb901000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000083020000820200832fefd882a410845506eb0796636f6f6c65737420626c6f636b206f6e20636861696ea0bd4472abb6659ebe3ee06ee4d7b72a00a9f4d001caca51342001075469aff49888a13a5a8c8f2bb1c4808080f90101f85f800a82c35094095e7baea6a6c7c4c2dfeb977efac326af552d870a801ba09bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094fa08a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b1b89e01f89b01800a8301e24194095e7baea6a6c7c4c2dfeb977efac326af552d878080f838f7940000000000000000000000000000000000000001e1a0000000000000000000000000000000000000000000000000000000000000000001a03dbacc8d0259f2508625e97fdfc57cd85fdd16e5821bc2c10bdd1a52649e8335a0476e10695b183a87b0aa292a7f4b78ef0c3fbe62aa2c42c84e1d9c3da159ef14c0") + var block Block + if err := rlp.DecodeBytes(blockEnc, &block); err != nil { + t.Fatal("decode error: ", err) + } + + check := func(f string, got, want interface{}) { + if !reflect.DeepEqual(got, want) { + t.Errorf("%s mismatch: got %v, want %v", f, got, want) + } + } + check("Difficulty", block.Difficulty(), big.NewInt(131072)) + check("GasLimit", block.GasLimit(), uint64(3141592)) + check("GasUsed", block.GasUsed(), uint64(42000)) + check("Coinbase", block.Coinbase(), common.HexToAddress("8888f1f195afa192cfee860698584c030f4c9db1")) + check("MixDigest", block.MixDigest(), common.HexToHash("bd4472abb6659ebe3ee06ee4d7b72a00a9f4d001caca51342001075469aff498")) + check("Root", block.Root(), common.HexToHash("ef1552a40b7165c3cd773806b9e0c165b75356e0314bf0706f279c729f51e017")) + check("Nonce", block.Nonce(), uint64(0xa13a5a8c8f2bb1c4)) + check("Time", block.Time().Uint64(), uint64(1426516743)) + check("Size", block.Size(), common.StorageSize(len(blockEnc))) + + // Create legacy tx. + to := common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") + tx1 := NewTx(&LegacyTx{ + Nonce: 0, + To: &to, + Value: big.NewInt(10), + Gas: 50000, + GasPrice: big.NewInt(10), + }) + sig := common.Hex2Bytes("9bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094f8a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b100") + tx1, _ = tx1.WithSignature(HomesteadSigner{}, sig) + + // Create ACL tx. + addr := common.HexToAddress("0x0000000000000000000000000000000000000001") + tx2 := NewTx(&AccessListTx{ + ChainID: big.NewInt(1), + Nonce: 0, + To: &to, + Gas: 123457, + GasPrice: big.NewInt(10), + AccessList: AccessList{{Address: addr, StorageKeys: []common.Hash{{0}}}}, + }) + sig2 := common.Hex2Bytes("3dbacc8d0259f2508625e97fdfc57cd85fdd16e5821bc2c10bdd1a52649e8335476e10695b183a87b0aa292a7f4b78ef0c3fbe62aa2c42c84e1d9c3da159ef1401") + tx2, _ = tx2.WithSignature(NewEIP2930Signer(big.NewInt(1)), sig2) + + check("len(Transactions)", len(block.Transactions()), 2) + check("Transactions[0].Hash", block.Transactions()[0].Hash(), tx1.Hash()) + check("Transactions[1].Hash", block.Transactions()[1].Hash(), tx2.Hash()) + check("Transactions[1].Type()", block.Transactions()[1].Type(), uint8(AccessListTxType)) + + ourBlockEnc, err := rlp.EncodeToBytes(&block) + if err != nil { + t.Fatal("encode error: ", err) + } + if !bytes.Equal(ourBlockEnc, blockEnc) { + t.Errorf("encoded block mismatch:\ngot: %x\nwant: %x", ourBlockEnc, blockEnc) + } +} + +func TestUncleHash(t *testing.T) { + uncles := make([]*Header, 0) + h := CalcUncleHash(uncles) + exp := common.HexToHash("1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") + if h != exp { + t.Fatalf("empty uncle hash is wrong, got %x != %x", h, exp) + } +} + +var benchBuffer = bytes.NewBuffer(make([]byte, 0, 32000)) + +func BenchmarkEncodeBlock(b *testing.B) { + block := makeBenchBlock() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchBuffer.Reset() + if err := rlp.Encode(benchBuffer, block); err != nil { + b.Fatal(err) + } + } +} + +// testHasher is the helper tool for transaction/receipt list hashing. +// The original hasher is trie, in order to get rid of import cycle, +// use the testing hasher instead. +type testHasher struct { + hasher hash.Hash +} + +func newHasher() *testHasher { + return &testHasher{hasher: sha3.NewLegacyKeccak256()} +} + +func (h *testHasher) Reset() { + h.hasher.Reset() +} + +func (h *testHasher) Update(key, val []byte) { + h.hasher.Write(key) + h.hasher.Write(val) +} + +func (h *testHasher) Hash() common.Hash { + return common.BytesToHash(h.hasher.Sum(nil)) +} + +func makeBenchBlock() *Block { + var ( + key, _ = crypto.GenerateKey() + txs = make([]*Transaction, 70) + receipts = make([]*Receipt, len(txs)) + signer = LatestSigner(params.TestChainConfig) + uncles = make([]*Header, 3) + ) + header := &Header{ + Difficulty: math.BigPow(11, 11), + Number: math.BigPow(2, 9), + GasLimit: 12345678, + GasUsed: 1476322, + Time: big.NewInt(9876543), + Extra: []byte("coolest block on chain"), + } + for i := range txs { + amount := math.BigPow(2, int64(i)) + price := big.NewInt(300000) + data := make([]byte, 100) + tx := NewTransaction(uint64(i), common.Address{}, amount, 123457, price, data) + signedTx, err := SignTx(tx, signer, key) + if err != nil { + panic(err) + } + txs[i] = signedTx + receipts[i] = NewReceipt(make([]byte, 32), false, tx.Gas()) + } + for i := range uncles { + uncles[i] = &Header{ + Difficulty: math.BigPow(11, 11), + Number: math.BigPow(2, 9), + GasLimit: 12345678, + GasUsed: 1476322, + Time: big.NewInt(9876543), + Extra: []byte("benchmark uncle"), + } + } + return NewBlock(header, txs, uncles, receipts) +} diff --git a/core/types/gen_access_tuple.go b/core/types/gen_access_tuple.go new file mode 100644 index 000000000000..d23b32c00925 --- /dev/null +++ b/core/types/gen_access_tuple.go @@ -0,0 +1,43 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package types + +import ( + "encoding/json" + "errors" + + "github.com/XinFinOrg/XDPoSChain/common" +) + +// MarshalJSON marshals as JSON. +func (a AccessTuple) MarshalJSON() ([]byte, error) { + type AccessTuple struct { + Address common.Address `json:"address" gencodec:"required"` + StorageKeys []common.Hash `json:"storageKeys" gencodec:"required"` + } + var enc AccessTuple + enc.Address = a.Address + enc.StorageKeys = a.StorageKeys + return json.Marshal(&enc) +} + +// UnmarshalJSON unmarshals from JSON. +func (a *AccessTuple) UnmarshalJSON(input []byte) error { + type AccessTuple struct { + Address *common.Address `json:"address" gencodec:"required"` + StorageKeys []common.Hash `json:"storageKeys" gencodec:"required"` + } + var dec AccessTuple + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + if dec.Address == nil { + return errors.New("missing required field 'address' for AccessTuple") + } + a.Address = *dec.Address + if dec.StorageKeys == nil { + return errors.New("missing required field 'storageKeys' for AccessTuple") + } + a.StorageKeys = dec.StorageKeys + return nil +} diff --git a/core/types/gen_receipt_json.go b/core/types/gen_receipt_json.go index 5fce768227c0..b72ef0270d3d 100644 --- a/core/types/gen_receipt_json.go +++ b/core/types/gen_receipt_json.go @@ -5,6 +5,7 @@ package types import ( "encoding/json" "errors" + "math/big" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" @@ -12,8 +13,10 @@ import ( var _ = (*receiptMarshaling)(nil) +// MarshalJSON marshals as JSON. func (r Receipt) MarshalJSON() ([]byte, error) { type Receipt struct { + Type hexutil.Uint64 `json:"type,omitempty"` PostState hexutil.Bytes `json:"root"` Status hexutil.Uint `json:"status"` CumulativeGasUsed hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` @@ -22,8 +25,12 @@ func (r Receipt) MarshalJSON() ([]byte, error) { TxHash common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress common.Address `json:"contractAddress"` GasUsed hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + BlockHash common.Hash `json:"blockHash,omitempty"` + BlockNumber *hexutil.Big `json:"blockNumber,omitempty"` + TransactionIndex hexutil.Uint `json:"transactionIndex"` } var enc Receipt + enc.Type = hexutil.Uint64(r.Type) enc.PostState = r.PostState enc.Status = hexutil.Uint(r.Status) enc.CumulativeGasUsed = hexutil.Uint64(r.CumulativeGasUsed) @@ -32,11 +39,16 @@ func (r Receipt) MarshalJSON() ([]byte, error) { enc.TxHash = r.TxHash enc.ContractAddress = r.ContractAddress enc.GasUsed = hexutil.Uint64(r.GasUsed) + enc.BlockHash = r.BlockHash + enc.BlockNumber = (*hexutil.Big)(r.BlockNumber) + enc.TransactionIndex = hexutil.Uint(r.TransactionIndex) return json.Marshal(&enc) } +// UnmarshalJSON unmarshals from JSON. func (r *Receipt) UnmarshalJSON(input []byte) error { type Receipt struct { + Type *hexutil.Uint64 `json:"type,omitempty"` PostState *hexutil.Bytes `json:"root"` Status *hexutil.Uint `json:"status"` CumulativeGasUsed *hexutil.Uint64 `json:"cumulativeGasUsed" gencodec:"required"` @@ -45,11 +57,17 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { TxHash *common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress *common.Address `json:"contractAddress"` GasUsed *hexutil.Uint64 `json:"gasUsed" gencodec:"required"` + BlockHash *common.Hash `json:"blockHash,omitempty"` + BlockNumber *hexutil.Big `json:"blockNumber,omitempty"` + TransactionIndex *hexutil.Uint `json:"transactionIndex"` } var dec Receipt if err := json.Unmarshal(input, &dec); err != nil { return err } + if dec.Type != nil { + r.Type = uint8(*dec.Type) + } if dec.PostState != nil { r.PostState = *dec.PostState } @@ -79,5 +97,14 @@ func (r *Receipt) UnmarshalJSON(input []byte) error { return errors.New("missing required field 'gasUsed' for Receipt") } r.GasUsed = uint64(*dec.GasUsed) + if dec.BlockHash != nil { + r.BlockHash = *dec.BlockHash + } + if dec.BlockNumber != nil { + r.BlockNumber = (*big.Int)(dec.BlockNumber) + } + if dec.TransactionIndex != nil { + r.TransactionIndex = uint(*dec.TransactionIndex) + } return nil } diff --git a/core/types/gen_tx_json.go b/core/types/gen_tx_json.go deleted file mode 100644 index 11b6f8ab4151..000000000000 --- a/core/types/gen_tx_json.go +++ /dev/null @@ -1,99 +0,0 @@ -// Code generated by github.com/fjl/gencodec. DO NOT EDIT. - -package types - -import ( - "encoding/json" - "errors" - "math/big" - - "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/common/hexutil" -) - -var _ = (*txdataMarshaling)(nil) - -func (t txdata) MarshalJSON() ([]byte, error) { - type txdata struct { - AccountNonce hexutil.Uint64 `json:"nonce" gencodec:"required"` - Price *hexutil.Big `json:"gasPrice" gencodec:"required"` - GasLimit hexutil.Uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` - Amount *hexutil.Big `json:"value" gencodec:"required"` - Payload hexutil.Bytes `json:"input" gencodec:"required"` - V *hexutil.Big `json:"v" gencodec:"required"` - R *hexutil.Big `json:"r" gencodec:"required"` - S *hexutil.Big `json:"s" gencodec:"required"` - Hash *common.Hash `json:"hash" rlp:"-"` - } - var enc txdata - enc.AccountNonce = hexutil.Uint64(t.AccountNonce) - enc.Price = (*hexutil.Big)(t.Price) - enc.GasLimit = hexutil.Uint64(t.GasLimit) - enc.Recipient = t.Recipient - enc.Amount = (*hexutil.Big)(t.Amount) - enc.Payload = t.Payload - enc.V = (*hexutil.Big)(t.V) - enc.R = (*hexutil.Big)(t.R) - enc.S = (*hexutil.Big)(t.S) - enc.Hash = t.Hash - return json.Marshal(&enc) -} - -func (t *txdata) UnmarshalJSON(input []byte) error { - type txdata struct { - AccountNonce *hexutil.Uint64 `json:"nonce" gencodec:"required"` - Price *hexutil.Big `json:"gasPrice" gencodec:"required"` - GasLimit *hexutil.Uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` - Amount *hexutil.Big `json:"value" gencodec:"required"` - Payload *hexutil.Bytes `json:"input" gencodec:"required"` - V *hexutil.Big `json:"v" gencodec:"required"` - R *hexutil.Big `json:"r" gencodec:"required"` - S *hexutil.Big `json:"s" gencodec:"required"` - Hash *common.Hash `json:"hash" rlp:"-"` - } - var dec txdata - if err := json.Unmarshal(input, &dec); err != nil { - return err - } - if dec.AccountNonce == nil { - return errors.New("missing required field 'nonce' for txdata") - } - t.AccountNonce = uint64(*dec.AccountNonce) - if dec.Price == nil { - return errors.New("missing required field 'gasPrice' for txdata") - } - t.Price = (*big.Int)(dec.Price) - if dec.GasLimit == nil { - return errors.New("missing required field 'gas' for txdata") - } - t.GasLimit = uint64(*dec.GasLimit) - if dec.Recipient != nil { - t.Recipient = dec.Recipient - } - if dec.Amount == nil { - return errors.New("missing required field 'value' for txdata") - } - t.Amount = (*big.Int)(dec.Amount) - if dec.Payload == nil { - return errors.New("missing required field 'input' for txdata") - } - t.Payload = *dec.Payload - if dec.V == nil { - return errors.New("missing required field 'v' for txdata") - } - t.V = (*big.Int)(dec.V) - if dec.R == nil { - return errors.New("missing required field 'r' for txdata") - } - t.R = (*big.Int)(dec.R) - if dec.S == nil { - return errors.New("missing required field 's' for txdata") - } - t.S = (*big.Int)(dec.S) - if dec.Hash != nil { - t.Hash = dec.Hash - } - return nil -} diff --git a/core/types/hashing.go b/core/types/hashing.go new file mode 100644 index 000000000000..73a5cfe53e06 --- /dev/null +++ b/core/types/hashing.go @@ -0,0 +1,59 @@ +// Copyright 2014 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "bytes" + "sync" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/crypto" + "github.com/XinFinOrg/XDPoSChain/rlp" + "golang.org/x/crypto/sha3" +) + +// hasherPool holds LegacyKeccak256 hashers for rlpHash. +var hasherPool = sync.Pool{ + New: func() interface{} { return sha3.NewLegacyKeccak256() }, +} + +// deriveBufferPool holds temporary encoder buffers for DeriveSha and TX encoding. +var encodeBufferPool = sync.Pool{ + New: func() interface{} { return new(bytes.Buffer) }, +} + +// rlpHash encodes x and hashes the encoded bytes. +func rlpHash(x interface{}) (h common.Hash) { + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + rlp.Encode(sha, x) + sha.Read(h[:]) + return h +} + +// prefixedRlpHash writes the prefix into the hasher before rlp-encoding x. +// It's used for typed transactions. +func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) { + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + sha.Write([]byte{prefix}) + rlp.Encode(sha, x) + sha.Read(h[:]) + return h +} diff --git a/core/types/legacy_tx.go b/core/types/legacy_tx.go new file mode 100644 index 000000000000..146a1e2877e9 --- /dev/null +++ b/core/types/legacy_tx.go @@ -0,0 +1,111 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "math/big" + + "github.com/XinFinOrg/XDPoSChain/common" +) + +// LegacyTx is the transaction data of regular Ethereum transactions. +type LegacyTx struct { + Nonce uint64 // nonce of sender account + GasPrice *big.Int // wei per gas + Gas uint64 // gas limit + To *common.Address `rlp:"nil"` // nil means contract creation + Value *big.Int // wei amount + Data []byte // contract invocation input data + V, R, S *big.Int // signature values +} + +// NewTransaction creates an unsigned legacy transaction. +// Deprecated: use NewTx instead. +func NewTransaction(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { + return NewTx(&LegacyTx{ + Nonce: nonce, + To: &to, + Value: amount, + Gas: gasLimit, + GasPrice: gasPrice, + Data: data, + }) +} + +// NewContractCreation creates an unsigned legacy transaction. +// Deprecated: use NewTx instead. +func NewContractCreation(nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { + return NewTx(&LegacyTx{ + Nonce: nonce, + Value: amount, + Gas: gasLimit, + GasPrice: gasPrice, + Data: data, + }) +} + +// copy creates a deep copy of the transaction data and initializes all fields. +func (tx *LegacyTx) copy() TxData { + cpy := &LegacyTx{ + Nonce: tx.Nonce, + To: tx.To, // TODO: copy pointed-to address + Data: common.CopyBytes(tx.Data), + Gas: tx.Gas, + // These are initialized below. + Value: new(big.Int), + GasPrice: new(big.Int), + V: new(big.Int), + R: new(big.Int), + S: new(big.Int), + } + if tx.Value != nil { + cpy.Value.Set(tx.Value) + } + if tx.GasPrice != nil { + cpy.GasPrice.Set(tx.GasPrice) + } + if tx.V != nil { + cpy.V.Set(tx.V) + } + if tx.R != nil { + cpy.R.Set(tx.R) + } + if tx.S != nil { + cpy.S.Set(tx.S) + } + return cpy +} + +// accessors for innerTx. + +func (tx *LegacyTx) txType() byte { return LegacyTxType } +func (tx *LegacyTx) chainID() *big.Int { return deriveChainId(tx.V) } +func (tx *LegacyTx) accessList() AccessList { return nil } +func (tx *LegacyTx) data() []byte { return tx.Data } +func (tx *LegacyTx) gas() uint64 { return tx.Gas } +func (tx *LegacyTx) gasPrice() *big.Int { return tx.GasPrice } +func (tx *LegacyTx) value() *big.Int { return tx.Value } +func (tx *LegacyTx) nonce() uint64 { return tx.Nonce } +func (tx *LegacyTx) to() *common.Address { return tx.To } + +func (tx *LegacyTx) rawSignatureValues() (v, r, s *big.Int) { + return tx.V, tx.R, tx.S +} + +func (tx *LegacyTx) setSignatureValues(chainID, v, r, s *big.Int) { + tx.V, tx.R, tx.S = v, r, s +} diff --git a/core/types/receipt.go b/core/types/receipt.go index f49d6f1c87f9..d1b3cc46bf37 100644 --- a/core/types/receipt.go +++ b/core/types/receipt.go @@ -18,8 +18,10 @@ package types import ( "bytes" + "errors" "fmt" "io" + "math/big" "unsafe" "github.com/XinFinOrg/XDPoSChain/common" @@ -34,6 +36,9 @@ var ( receiptStatusSuccessfulRLP = []byte{0x01} ) +// This error is returned when a typed receipt is decoded, but the string is empty. +var errEmptyTypedReceipt = errors.New("empty typed receipt bytes") + const ( // ReceiptStatusFailed is the status code of a transaction if execution failed. ReceiptStatusFailed = uint(0) @@ -44,24 +49,35 @@ const ( // Receipt represents the results of a transaction. type Receipt struct { - // Consensus fields + // Consensus fields: These fields are defined by the Yellow Paper + Type uint8 `json:"type,omitempty"` PostState []byte `json:"root"` Status uint `json:"status"` CumulativeGasUsed uint64 `json:"cumulativeGasUsed" gencodec:"required"` Bloom Bloom `json:"logsBloom" gencodec:"required"` Logs []*Log `json:"logs" gencodec:"required"` - // Implementation fields (don't reorder!) + // Implementation fields: These fields are added by geth when processing a transaction. + // They are stored in the chain database. TxHash common.Hash `json:"transactionHash" gencodec:"required"` ContractAddress common.Address `json:"contractAddress"` GasUsed uint64 `json:"gasUsed" gencodec:"required"` + + // Inclusion information: These fields provide information about the inclusion of the + // transaction corresponding to this receipt. + BlockHash common.Hash `json:"blockHash,omitempty"` + BlockNumber *big.Int `json:"blockNumber,omitempty"` + TransactionIndex uint `json:"transactionIndex"` } type receiptMarshaling struct { + Type hexutil.Uint64 PostState hexutil.Bytes Status hexutil.Uint CumulativeGasUsed hexutil.Uint64 GasUsed hexutil.Uint64 + BlockNumber *hexutil.Big + TransactionIndex hexutil.Uint } // receiptRLP is the consensus encoding of a receipt. @@ -72,7 +88,18 @@ type receiptRLP struct { Logs []*Log } -type receiptStorageRLP struct { +// v4StoredReceiptRLP is the storage encoding of a receipt used in database version 4. +type v4StoredReceiptRLP struct { + PostStateOrStatus []byte + CumulativeGasUsed uint64 + TxHash common.Hash + ContractAddress common.Address + Logs []*LogForStorage + GasUsed uint64 +} + +// v3StoredReceiptRLP is the original storage encoding of a receipt including some unnecessary fields. +type v3StoredReceiptRLP struct { PostStateOrStatus []byte CumulativeGasUsed uint64 Bloom Bloom @@ -83,8 +110,13 @@ type receiptStorageRLP struct { } // NewReceipt creates a barebone transaction receipt, copying the init fields. +// Deprecated: create receipts using a struct literal instead. func NewReceipt(root []byte, failed bool, cumulativeGasUsed uint64) *Receipt { - r := &Receipt{PostState: common.CopyBytes(root), CumulativeGasUsed: cumulativeGasUsed} + r := &Receipt{ + Type: LegacyTxType, + PostState: common.CopyBytes(root), + CumulativeGasUsed: cumulativeGasUsed, + } if failed { r.Status = ReceiptStatusFailed } else { @@ -96,21 +128,65 @@ func NewReceipt(root []byte, failed bool, cumulativeGasUsed uint64) *Receipt { // EncodeRLP implements rlp.Encoder, and flattens the consensus fields of a receipt // into an RLP stream. If no post state is present, byzantium fork is assumed. func (r *Receipt) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs}) + data := &receiptRLP{r.statusEncoding(), r.CumulativeGasUsed, r.Bloom, r.Logs} + if r.Type == LegacyTxType { + return rlp.Encode(w, data) + } + // It's an EIP-2718 typed TX receipt. + if r.Type != AccessListTxType { + return ErrTxTypeNotSupported + } + buf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(buf) + buf.Reset() + buf.WriteByte(r.Type) + if err := rlp.Encode(buf, data); err != nil { + return err + } + return rlp.Encode(w, buf.Bytes()) } // DecodeRLP implements rlp.Decoder, and loads the consensus fields of a receipt // from an RLP stream. func (r *Receipt) DecodeRLP(s *rlp.Stream) error { - var dec receiptRLP - if err := s.Decode(&dec); err != nil { - return err - } - if err := r.setStatus(dec.PostStateOrStatus); err != nil { + kind, _, err := s.Kind() + switch { + case err != nil: return err + case kind == rlp.List: + // It's a legacy receipt. + var dec receiptRLP + if err := s.Decode(&dec); err != nil { + return err + } + r.Type = LegacyTxType + return r.setFromRLP(dec) + case kind == rlp.String: + // It's an EIP-2718 typed tx receipt. + b, err := s.Bytes() + if err != nil { + return err + } + if len(b) == 0 { + return errEmptyTypedReceipt + } + r.Type = b[0] + if r.Type == AccessListTxType { + var dec receiptRLP + if err := rlp.DecodeBytes(b[1:], &dec); err != nil { + return err + } + return r.setFromRLP(dec) + } + return ErrTxTypeNotSupported + default: + return rlp.ErrExpectedList } - r.CumulativeGasUsed, r.Bloom, r.Logs = dec.CumulativeGasUsed, dec.Bloom, dec.Logs - return nil +} + +func (r *Receipt) setFromRLP(data receiptRLP) error { + r.CumulativeGasUsed, r.Bloom, r.Logs = data.CumulativeGasUsed, data.Bloom, data.Logs + return r.setStatus(data.PostStateOrStatus) } func (r *Receipt) setStatus(postStateOrStatus []byte) error { @@ -141,7 +217,6 @@ func (r *Receipt) statusEncoding() []byte { // to approximate and limit the memory consumption of various caches. func (r *Receipt) Size() common.StorageSize { size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState)) - size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{})) for _, log := range r.Logs { size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data)) @@ -164,7 +239,7 @@ type ReceiptForStorage Receipt // EncodeRLP implements rlp.Encoder, and flattens all content fields of a receipt // into an RLP stream. func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error { - enc := &receiptStorageRLP{ + enc := &v3StoredReceiptRLP{ PostStateOrStatus: (*Receipt)(r).statusEncoding(), CumulativeGasUsed: r.CumulativeGasUsed, Bloom: r.Bloom, @@ -182,25 +257,64 @@ func (r *ReceiptForStorage) EncodeRLP(w io.Writer) error { // DecodeRLP implements rlp.Decoder, and loads both consensus and implementation // fields of a receipt from an RLP stream. func (r *ReceiptForStorage) DecodeRLP(s *rlp.Stream) error { - var dec receiptStorageRLP - if err := s.Decode(&dec); err != nil { + // Retrieve the entire receipt blob as we need to try multiple decoders + blob, err := s.Raw() + if err != nil { + return err + } + // Try decoding from the newest format for future proofness, then the older one + // for old nodes that just upgraded. V4 was an intermediate unreleased format so + // we do need to decode it, but it's not common (try last). + if err := decodeV3StoredReceiptRLP(r, blob); err == nil { + return nil + } + return decodeV4StoredReceiptRLP(r, blob) +} + +func decodeV3StoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { + var stored v3StoredReceiptRLP + if err := rlp.DecodeBytes(blob, &stored); err != nil { return err } - if err := (*Receipt)(r).setStatus(dec.PostStateOrStatus); err != nil { + if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { return err } // Assign the consensus fields - r.CumulativeGasUsed, r.Bloom = dec.CumulativeGasUsed, dec.Bloom - r.Logs = make([]*Log, len(dec.Logs)) - for i, log := range dec.Logs { + r.CumulativeGasUsed = stored.CumulativeGasUsed + r.Bloom = stored.Bloom + r.Logs = make([]*Log, len(stored.Logs)) + for i, log := range stored.Logs { r.Logs[i] = (*Log)(log) } // Assign the implementation fields - r.TxHash, r.ContractAddress, r.GasUsed = dec.TxHash, dec.ContractAddress, dec.GasUsed + r.TxHash = stored.TxHash + r.ContractAddress = stored.ContractAddress + r.GasUsed = stored.GasUsed + return nil +} + +func decodeV4StoredReceiptRLP(r *ReceiptForStorage, blob []byte) error { + var stored v4StoredReceiptRLP + if err := rlp.DecodeBytes(blob, &stored); err != nil { + return err + } + if err := (*Receipt)(r).setStatus(stored.PostStateOrStatus); err != nil { + return err + } + r.CumulativeGasUsed = stored.CumulativeGasUsed + r.TxHash = stored.TxHash + r.ContractAddress = stored.ContractAddress + r.GasUsed = stored.GasUsed + r.Logs = make([]*Log, len(stored.Logs)) + for i, log := range stored.Logs { + r.Logs[i] = (*Log)(log) + } + r.Bloom = CreateBloom(Receipts{(*Receipt)(r)}) + return nil } -// Receipts is a wrapper around a Receipt array to implement DerivableList. +// Receipts implements DerivableList for receipts. type Receipts []*Receipt // Len returns the number of receipts in this list. diff --git a/core/types/receipt_test.go b/core/types/receipt_test.go new file mode 100644 index 000000000000..82fec06c9667 --- /dev/null +++ b/core/types/receipt_test.go @@ -0,0 +1,170 @@ +// Copyright 2019 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package types + +import ( + "bytes" + "math/big" + "reflect" + "testing" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/rlp" +) + +func TestDecodeEmptyTypedReceipt(t *testing.T) { + input := []byte{0x80} + var r Receipt + err := rlp.DecodeBytes(input, &r) + if err != errEmptyTypedReceipt { + t.Fatal("wrong error:", err) + } +} + +func TestLegacyReceiptDecoding(t *testing.T) { + tests := []struct { + name string + encode func(*Receipt) ([]byte, error) + }{ + { + "V4StoredReceiptRLP", + encodeAsV4StoredReceiptRLP, + }, + { + "V3StoredReceiptRLP", + encodeAsV3StoredReceiptRLP, + }, + } + + tx := NewTransaction(1, common.HexToAddress("0x1"), big.NewInt(1), 1, big.NewInt(1), nil) + receipt := &Receipt{ + Status: ReceiptStatusFailed, + CumulativeGasUsed: 1, + Logs: []*Log{ + { + Address: common.BytesToAddress([]byte{0x11}), + Topics: []common.Hash{common.HexToHash("dead"), common.HexToHash("beef")}, + Data: []byte{0x01, 0x00, 0xff}, + }, + { + Address: common.BytesToAddress([]byte{0x01, 0x11}), + Topics: []common.Hash{common.HexToHash("dead"), common.HexToHash("beef")}, + Data: []byte{0x01, 0x00, 0xff}, + }, + }, + TxHash: tx.Hash(), + ContractAddress: common.BytesToAddress([]byte{0x01, 0x11, 0x11}), + GasUsed: 111111, + } + receipt.Bloom = CreateBloom(Receipts{receipt}) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + enc, err := tc.encode(receipt) + if err != nil { + t.Fatalf("Error encoding receipt: %v", err) + } + var dec ReceiptForStorage + if err := rlp.DecodeBytes(enc, &dec); err != nil { + t.Fatalf("Error decoding RLP receipt: %v", err) + } + // Check whether all consensus fields are correct. + if dec.Status != receipt.Status { + t.Fatalf("Receipt status mismatch, want %v, have %v", receipt.Status, dec.Status) + } + if dec.CumulativeGasUsed != receipt.CumulativeGasUsed { + t.Fatalf("Receipt CumulativeGasUsed mismatch, want %v, have %v", receipt.CumulativeGasUsed, dec.CumulativeGasUsed) + } + if dec.Bloom != receipt.Bloom { + t.Fatalf("Bloom data mismatch, want %v, have %v", receipt.Bloom, dec.Bloom) + } + if len(dec.Logs) != len(receipt.Logs) { + t.Fatalf("Receipt log number mismatch, want %v, have %v", len(receipt.Logs), len(dec.Logs)) + } + for i := 0; i < len(dec.Logs); i++ { + if dec.Logs[i].Address != receipt.Logs[i].Address { + t.Fatalf("Receipt log %d address mismatch, want %v, have %v", i, receipt.Logs[i].Address, dec.Logs[i].Address) + } + if !reflect.DeepEqual(dec.Logs[i].Topics, receipt.Logs[i].Topics) { + t.Fatalf("Receipt log %d topics mismatch, want %v, have %v", i, receipt.Logs[i].Topics, dec.Logs[i].Topics) + } + if !bytes.Equal(dec.Logs[i].Data, receipt.Logs[i].Data) { + t.Fatalf("Receipt log %d data mismatch, want %v, have %v", i, receipt.Logs[i].Data, dec.Logs[i].Data) + } + } + }) + } +} + +func encodeAsV4StoredReceiptRLP(want *Receipt) ([]byte, error) { + stored := &v4StoredReceiptRLP{ + PostStateOrStatus: want.statusEncoding(), + CumulativeGasUsed: want.CumulativeGasUsed, + TxHash: want.TxHash, + ContractAddress: want.ContractAddress, + Logs: make([]*LogForStorage, len(want.Logs)), + GasUsed: want.GasUsed, + } + for i, log := range want.Logs { + stored.Logs[i] = (*LogForStorage)(log) + } + return rlp.EncodeToBytes(stored) +} + +func encodeAsV3StoredReceiptRLP(want *Receipt) ([]byte, error) { + stored := &v3StoredReceiptRLP{ + PostStateOrStatus: want.statusEncoding(), + CumulativeGasUsed: want.CumulativeGasUsed, + Bloom: want.Bloom, + TxHash: want.TxHash, + ContractAddress: want.ContractAddress, + Logs: make([]*LogForStorage, len(want.Logs)), + GasUsed: want.GasUsed, + } + for i, log := range want.Logs { + stored.Logs[i] = (*LogForStorage)(log) + } + return rlp.EncodeToBytes(stored) +} + +// TestTypedReceiptEncodingDecoding reproduces a flaw that existed in the receipt +// rlp decoder, which failed due to a shadowing error. +func TestTypedReceiptEncodingDecoding(t *testing.T) { + var payload = common.FromHex("f9043eb9010c01f90108018262d4b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0b9010c01f901080182cd14b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0b9010d01f901090183013754b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0b9010d01f90109018301a194b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000c0") + check := func(bundle []*Receipt) { + t.Helper() + for i, receipt := range bundle { + if got, want := receipt.Type, uint8(1); got != want { + t.Fatalf("bundle %d: got %x, want %x", i, got, want) + } + } + } + { + var bundle []*Receipt + rlp.DecodeBytes(payload, &bundle) + check(bundle) + } + { + var bundle []*Receipt + r := bytes.NewReader(payload) + s := rlp.NewStream(r, uint64(len(payload))) + if err := s.Decode(&bundle); err != nil { + t.Fatal(err) + } + check(bundle) + } +} diff --git a/core/types/transaction.go b/core/types/transaction.go index 03f1bbe39ae2..983162943c52 100644 --- a/core/types/transaction.go +++ b/core/types/transaction.go @@ -24,9 +24,9 @@ import ( "io" "math/big" "sync/atomic" + "time" "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" ) @@ -35,6 +35,15 @@ import ( var ( ErrInvalidSig = errors.New("invalid transaction v, r, s values") + ErrUnexpectedProtection = errors.New("transaction type does not supported EIP-155 protected signatures") + ErrInvalidTxType = errors.New("transaction type not valid in this context") + ErrTxTypeNotSupported = errors.New("transaction type not supported") + ErrGasFeeCapTooLow = errors.New("fee cap less than base fee") + errShortTypedTx = errors.New("typed transaction too short") + errInvalidYParity = errors.New("'yParity' field must be 0 or 1") + errVYParityMismatch = errors.New("'v' and 'yParity' fields do not match") + errVYParityMissing = errors.New("missing 'yParity' or 'v' field in transaction") + errEmptyTypedTx = errors.New("empty typed transaction bytes") errNoSigner = errors.New("missing signing methods") skipNonceDestinationAddress = map[string]bool{ common.XDCXAddr: true, @@ -44,221 +53,318 @@ var ( } ) -// deriveSigner makes a *best* guess about which signer to use. -func deriveSigner(V *big.Int) Signer { - if V.Sign() != 0 && isProtectedV(V) { - return NewEIP155Signer(deriveChainId(V)) - } else { - return HomesteadSigner{} - } -} +// Transaction types. +const ( + LegacyTxType = iota + AccessListTxType +) +// Transaction is an Ethereum transaction. type Transaction struct { - data txdata + inner TxData // Consensus contents of a transaction + time time.Time // Time first seen locally (spam avoidance) + // caches hash atomic.Value size atomic.Value from atomic.Value } -type txdata struct { - AccountNonce uint64 `json:"nonce" gencodec:"required"` - Price *big.Int `json:"gasPrice" gencodec:"required"` - GasLimit uint64 `json:"gas" gencodec:"required"` - Recipient *common.Address `json:"to" rlp:"nil"` // nil means contract creation - Amount *big.Int `json:"value" gencodec:"required"` - Payload []byte `json:"input" gencodec:"required"` +// NewTx creates a new transaction. +func NewTx(inner TxData) *Transaction { + tx := new(Transaction) + tx.setDecoded(inner.copy(), 0) + return tx +} - // Signature values - V *big.Int `json:"v" gencodec:"required"` - R *big.Int `json:"r" gencodec:"required"` - S *big.Int `json:"s" gencodec:"required"` +// TxData is the underlying data of a transaction. +// +// This is implemented by LegacyTx and AccessListTx. +type TxData interface { + txType() byte // returns the type ID + copy() TxData // creates a deep copy and initializes all fields - // This is only used when marshaling to JSON. - Hash *common.Hash `json:"hash" rlp:"-"` -} + chainID() *big.Int + accessList() AccessList + data() []byte + gas() uint64 + gasPrice() *big.Int + value() *big.Int + nonce() uint64 + to() *common.Address -type txdataMarshaling struct { - AccountNonce hexutil.Uint64 - Price *hexutil.Big - GasLimit hexutil.Uint64 - Amount *hexutil.Big - Payload hexutil.Bytes - V *hexutil.Big - R *hexutil.Big - S *hexutil.Big + rawSignatureValues() (v, r, s *big.Int) + setSignatureValues(chainID, v, r, s *big.Int) } -func NewTransaction(nonce uint64, to common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - return newTransaction(nonce, &to, amount, gasLimit, gasPrice, data) +// EncodeRLP implements rlp.Encoder +func (tx *Transaction) EncodeRLP(w io.Writer) error { + if tx.Type() == LegacyTxType { + return rlp.Encode(w, tx.inner) + } + // It's an EIP-2718 typed TX envelope. + buf := encodeBufferPool.Get().(*bytes.Buffer) + defer encodeBufferPool.Put(buf) + buf.Reset() + if err := tx.encodeTyped(buf); err != nil { + return err + } + return rlp.Encode(w, buf.Bytes()) } -func NewContractCreation(nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - return newTransaction(nonce, nil, amount, gasLimit, gasPrice, data) +// encodeTyped writes the canonical encoding of a typed transaction to w. +func (tx *Transaction) encodeTyped(w *bytes.Buffer) error { + w.WriteByte(tx.Type()) + return rlp.Encode(w, tx.inner) } -func newTransaction(nonce uint64, to *common.Address, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte) *Transaction { - if len(data) > 0 { - data = common.CopyBytes(data) +// MarshalBinary returns the canonical encoding of the transaction. +// For legacy transactions, it returns the RLP encoding. For EIP-2718 typed +// transactions, it returns the type and payload. +func (tx *Transaction) MarshalBinary() ([]byte, error) { + if tx.Type() == LegacyTxType { + return rlp.EncodeToBytes(tx.inner) } - d := txdata{ - AccountNonce: nonce, - Recipient: to, - Payload: data, - Amount: new(big.Int), - GasLimit: gasLimit, - Price: new(big.Int), - V: new(big.Int), - R: new(big.Int), - S: new(big.Int), + var buf bytes.Buffer + err := tx.encodeTyped(&buf) + return buf.Bytes(), err +} + +// DecodeRLP implements rlp.Decoder +func (tx *Transaction) DecodeRLP(s *rlp.Stream) error { + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == rlp.List: + // It's a legacy transaction. + var inner LegacyTx + err := s.Decode(&inner) + if err == nil { + tx.setDecoded(&inner, int(rlp.ListSize(size))) + } + return err + case kind == rlp.String: + // It's an EIP-2718 typed TX envelope. + var b []byte + if b, err = s.Bytes(); err != nil { + return err + } + inner, err := tx.decodeTyped(b) + if err == nil { + tx.setDecoded(inner, len(b)) + } + return err + default: + return rlp.ErrExpectedList } - if amount != nil { - d.Amount.Set(amount) +} + +// UnmarshalBinary decodes the canonical encoding of transactions. +// It supports legacy RLP transactions and EIP2718 typed transactions. +func (tx *Transaction) UnmarshalBinary(b []byte) error { + if len(b) > 0 && b[0] > 0x7f { + // It's a legacy transaction. + var data LegacyTx + err := rlp.DecodeBytes(b, &data) + if err != nil { + return err + } + tx.setDecoded(&data, len(b)) + return nil } - if gasPrice != nil { - d.Price.Set(gasPrice) + // It's an EIP2718 typed transaction envelope. + inner, err := tx.decodeTyped(b) + if err != nil { + return err } + tx.setDecoded(inner, len(b)) + return nil +} - return &Transaction{data: d} +// decodeTyped decodes a typed transaction from the canonical format. +func (tx *Transaction) decodeTyped(b []byte) (TxData, error) { + if len(b) == 0 { + return nil, errEmptyTypedTx + } + switch b[0] { + case AccessListTxType: + var inner AccessListTx + err := rlp.DecodeBytes(b[1:], &inner) + return &inner, err + default: + return nil, ErrTxTypeNotSupported + } } -// ChainId returns which chain id this transaction was signed for (if at all) -func (tx *Transaction) ChainId() *big.Int { - return deriveChainId(tx.data.V) +// setDecoded sets the inner transaction and size after decoding. +func (tx *Transaction) setDecoded(inner TxData, size int) { + tx.inner = inner + tx.time = time.Now() + if size > 0 { + tx.size.Store(common.StorageSize(size)) + } } -// Protected returns whether the transaction is protected from replay protection. -func (tx *Transaction) Protected() bool { - return isProtectedV(tx.data.V) +func sanityCheckSignature(v *big.Int, r *big.Int, s *big.Int, maybeProtected bool) error { + if isProtectedV(v) && !maybeProtected { + return ErrUnexpectedProtection + } + + var plainV byte + if isProtectedV(v) { + chainID := deriveChainId(v).Uint64() + plainV = byte(v.Uint64() - 35 - 2*chainID) + } else if maybeProtected { + // Only EIP-155 signatures can be optionally protected. Since + // we determined this v value is not protected, it must be a + // raw 27 or 28. + plainV = byte(v.Uint64() - 27) + } else { + // If the signature is not optionally protected, we assume it + // must already be equal to the recovery id. + plainV = byte(v.Uint64()) + } + if !crypto.ValidateSignatureValues(plainV, r, s, false) { + return ErrInvalidSig + } + + return nil } func isProtectedV(V *big.Int) bool { if V.BitLen() <= 8 { v := V.Uint64() - return v != 27 && v != 28 + return v != 27 && v != 28 && v != 1 && v != 0 } - // anything not 27 or 28 are considered unprotected + // anything not 27 or 28 is considered protected return true } -// EncodeRLP implements rlp.Encoder -func (tx *Transaction) EncodeRLP(w io.Writer) error { - return rlp.Encode(w, &tx.data) -} - -// DecodeRLP implements rlp.Decoder -func (tx *Transaction) DecodeRLP(s *rlp.Stream) error { - _, size, _ := s.Kind() - err := s.Decode(&tx.data) - if err == nil { - tx.size.Store(common.StorageSize(rlp.ListSize(size))) +// Protected says whether the transaction is replay-protected. +func (tx *Transaction) Protected() bool { + switch tx := tx.inner.(type) { + case *LegacyTx: + return tx.V != nil && isProtectedV(tx.V) + default: + return true } - - return err } -// MarshalJSON encodes the web3 RPC transaction format. -func (tx *Transaction) MarshalJSON() ([]byte, error) { - hash := tx.Hash() - data := tx.data - data.Hash = &hash - return data.MarshalJSON() +// Type returns the transaction type. +func (tx *Transaction) Type() uint8 { + return tx.inner.txType() } -// UnmarshalJSON decodes the web3 RPC transaction format. -func (tx *Transaction) UnmarshalJSON(input []byte) error { - var dec txdata - if err := dec.UnmarshalJSON(input); err != nil { - return err - } - var V byte - if isProtectedV(dec.V) { - chainID := deriveChainId(dec.V).Uint64() - V = byte(dec.V.Uint64() - 35 - 2*chainID) - } else { - V = byte(dec.V.Uint64() - 27) - } - if !crypto.ValidateSignatureValues(V, dec.R, dec.S, false) { - return ErrInvalidSig - } - *tx = Transaction{data: dec} - return nil +// ChainId returns the EIP155 chain ID of the transaction. The return value will always be +// non-nil. For legacy transactions which are not replay-protected, the return value is +// zero. +func (tx *Transaction) ChainId() *big.Int { + return tx.inner.chainID() } -func (tx *Transaction) Data() []byte { return common.CopyBytes(tx.data.Payload) } -func (tx *Transaction) Gas() uint64 { return tx.data.GasLimit } -func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.data.Price) } -func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.data.Amount) } -func (tx *Transaction) Nonce() uint64 { return tx.data.AccountNonce } -func (tx *Transaction) CheckNonce() bool { return true } +// Data returns the input data of the transaction. +func (tx *Transaction) Data() []byte { return tx.inner.data() } + +// AccessList returns the access list of the transaction. +func (tx *Transaction) AccessList() AccessList { return tx.inner.accessList() } + +// Gas returns the gas limit of the transaction. +func (tx *Transaction) Gas() uint64 { return tx.inner.gas() } + +// GasPrice returns the gas price of the transaction. +func (tx *Transaction) GasPrice() *big.Int { return new(big.Int).Set(tx.inner.gasPrice()) } + +// Value returns the ether amount of the transaction. +func (tx *Transaction) Value() *big.Int { return new(big.Int).Set(tx.inner.value()) } + +// Nonce returns the sender account nonce of the transaction. +func (tx *Transaction) Nonce() uint64 { return tx.inner.nonce() } // To returns the recipient address of the transaction. -// It returns nil if the transaction is a contract creation. +// For contract-creation transactions, To returns nil. func (tx *Transaction) To() *common.Address { - if tx.data.Recipient == nil { + // Copy the pointed-to address. + ito := tx.inner.to() + if ito == nil { return nil } - to := *tx.data.Recipient - return &to + cpy := *ito + return &cpy } func (tx *Transaction) From() *common.Address { - if tx.data.V != nil { - signer := deriveSigner(tx.data.V) - if f, err := Sender(signer, tx); err != nil { - return nil - } else { - return &f - } + var signer Signer + if tx.Protected() { + signer = LatestSignerForChainID(tx.ChainId()) } else { + signer = HomesteadSigner{} + } + from, err := Sender(signer, tx) + if err != nil { return nil } + return &from } -// Hash hashes the RLP encoding of tx. -// It uniquely identifies the transaction. +// RawSignatureValues returns the V, R, S signature values of the transaction. +// The return values should not be modified by the caller. +func (tx *Transaction) RawSignatureValues() (v, r, s *big.Int) { + return tx.inner.rawSignatureValues() +} + +// GasPriceCmp compares the gas prices of two transactions. +func (tx *Transaction) GasPriceCmp(other *Transaction) int { + return tx.inner.gasPrice().Cmp(other.inner.gasPrice()) +} + +// GasPriceIntCmp compares the gas price of the transaction against the given price. +func (tx *Transaction) GasPriceIntCmp(other *big.Int) int { + return tx.inner.gasPrice().Cmp(other) +} + +// Hash returns the transaction hash. func (tx *Transaction) Hash() common.Hash { if hash := tx.hash.Load(); hash != nil { return hash.(common.Hash) } - v := rlpHash(tx) - tx.hash.Store(v) - return v -} -func (tx *Transaction) CacheHash() { - v := rlpHash(tx) - tx.hash.Store(v) + var h common.Hash + if tx.Type() == LegacyTxType { + h = rlpHash(tx.inner) + } else { + h = prefixedRlpHash(tx.Type(), tx.inner) + } + tx.hash.Store(h) + return h } // Size returns the true RLP encoded storage size of the transaction, either by -// encoding and returning it, or returning a previsouly cached value. +// encoding and returning it, or returning a previously cached value. func (tx *Transaction) Size() common.StorageSize { if size := tx.size.Load(); size != nil { return size.(common.StorageSize) } c := writeCounter(0) - rlp.Encode(&c, &tx.data) + rlp.Encode(&c, &tx.inner) tx.size.Store(common.StorageSize(c)) return common.StorageSize(c) } // AsMessage returns the transaction as a core.Message. -// -// AsMessage requires a signer to derive the sender. -// -// XXX Rename message to something less arbitrary? func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) (Message, error) { msg := Message{ - nonce: tx.data.AccountNonce, - gasLimit: tx.data.GasLimit, - gasPrice: new(big.Int).Set(tx.data.Price), - to: tx.data.Recipient, - amount: tx.data.Amount, - data: tx.data.Payload, + nonce: tx.Nonce(), + gasLimit: tx.Gas(), + gasPrice: new(big.Int).Set(tx.GasPrice()), + to: tx.To(), + amount: tx.Value(), + data: tx.Data(), + accessList: tx.AccessList(), checkNonce: true, balanceTokenFee: balanceFee, } + var err error msg.from, err = Sender(s, tx) if balanceFee != nil { @@ -274,35 +380,31 @@ func (tx *Transaction) AsMessage(s Signer, balanceFee *big.Int, number *big.Int) } // WithSignature returns a new transaction with the given signature. -// This signature needs to be formatted as described in the yellow paper (v+27). +// This signature needs to be in the [R || S || V] format where V is 0 or 1. func (tx *Transaction) WithSignature(signer Signer, sig []byte) (*Transaction, error) { r, s, v, err := signer.SignatureValues(tx, sig) if err != nil { return nil, err } - cpy := &Transaction{data: tx.data} - cpy.data.R, cpy.data.S, cpy.data.V = r, s, v - return cpy, nil + cpy := tx.inner.copy() + cpy.setSignatureValues(signer.ChainID(), v, r, s) + return &Transaction{inner: cpy, time: tx.time}, nil } -// Cost returns amount + gasprice * gaslimit. +// Cost returns gas * gasPrice + value. func (tx *Transaction) Cost() *big.Int { - total := new(big.Int).Mul(tx.data.Price, new(big.Int).SetUint64(tx.data.GasLimit)) - total.Add(total, tx.data.Amount) + total := new(big.Int).Mul(tx.GasPrice(), new(big.Int).SetUint64(tx.Gas())) + total.Add(total, tx.Value()) return total } -// Cost returns amount + gasprice * gaslimit. +// TxCost returns gas * gasPrice + value. func (tx *Transaction) TxCost(number *big.Int) *big.Int { - total := new(big.Int).Mul(common.GetGasPrice(number), new(big.Int).SetUint64(tx.data.GasLimit)) - total.Add(total, tx.data.Amount) + total := new(big.Int).Mul(common.GetGasPrice(number), new(big.Int).SetUint64(tx.Gas())) + total.Add(total, tx.Value()) return total } -func (tx *Transaction) RawSignatureValues() (*big.Int, *big.Int, *big.Int) { - return tx.data.V, tx.data.R, tx.data.S -} - func (tx *Transaction) IsSpecialTransaction() bool { if tx.To() == nil { return false @@ -473,25 +575,24 @@ func (tx *Transaction) IsXDCZApplyTransaction() bool { func (tx *Transaction) String() string { var from, to string - if tx.data.V != nil { - // make a best guess about the signer and use that to derive - // the sender. - signer := deriveSigner(tx.data.V) - if f, err := Sender(signer, tx); err != nil { // derive but don't cache - from = "[invalid sender: invalid sig]" - } else { - from = fmt.Sprintf("%x", f[:]) - } + + sender := tx.From() + if sender != nil { + from = fmt.Sprintf("%x", sender[:]) } else { - from = "[invalid sender: nil V field]" + from = "[invalid sender]" } - if tx.data.Recipient == nil { + receiver := tx.To() + if receiver == nil { to = "[contract creation]" } else { - to = fmt.Sprintf("%x", tx.data.Recipient[:]) + to = fmt.Sprintf("%x", receiver[:]) } - enc, _ := rlp.EncodeToBytes(&tx.data) + + enc, _ := rlp.EncodeToBytes(tx.Data()) + v, r, s := tx.RawSignatureValues() + return fmt.Sprintf(` TX(%x) Contract: %v @@ -508,17 +609,17 @@ func (tx *Transaction) String() string { Hex: %x `, tx.Hash(), - tx.data.Recipient == nil, + receiver == nil, from, to, - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Amount, - tx.data.Payload, - tx.data.V, - tx.data.R, - tx.data.S, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.Value(), + tx.Data(), + v, + r, + s, enc, ) } @@ -562,40 +663,47 @@ func TxDifference(a, b Transactions) (keep Transactions) { type TxByNonce Transactions func (s TxByNonce) Len() int { return len(s) } -func (s TxByNonce) Less(i, j int) bool { return s[i].data.AccountNonce < s[j].data.AccountNonce } +func (s TxByNonce) Less(i, j int) bool { return s[i].Nonce() < s[j].Nonce() } func (s TxByNonce) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -// TxByPrice implements both the sort and the heap interface, making it useful +// TxByPriceAndTime implements both the sort and the heap interface, making it useful // for all at once sorting as well as individually adding and removing elements. -type TxByPrice struct { +type TxByPriceAndTime struct { txs Transactions payersSwap map[common.Address]*big.Int } -func (s TxByPrice) Len() int { return len(s.txs) } -func (s TxByPrice) Less(i, j int) bool { - i_price := s.txs[i].data.Price +func (s TxByPriceAndTime) Len() int { return len(s.txs) } +func (s TxByPriceAndTime) Less(i, j int) bool { + i_price := s.txs[i].GasPrice() if s.txs[i].To() != nil { if _, ok := s.payersSwap[*s.txs[i].To()]; ok { i_price = common.TRC21GasPrice } } - j_price := s.txs[j].data.Price + j_price := s.txs[j].GasPrice() if s.txs[j].To() != nil { if _, ok := s.payersSwap[*s.txs[j].To()]; ok { j_price = common.TRC21GasPrice } } - return i_price.Cmp(j_price) > 0 + + // If the prices are equal, use the time the transaction was first seen for + // deterministic sorting + cmp := i_price.Cmp(j_price) + if cmp == 0 { + return s.txs[i].time.Before(s.txs[j].time) + } + return cmp > 0 } -func (s TxByPrice) Swap(i, j int) { s.txs[i], s.txs[j] = s.txs[j], s.txs[i] } +func (s TxByPriceAndTime) Swap(i, j int) { s.txs[i], s.txs[j] = s.txs[j], s.txs[i] } -func (s *TxByPrice) Push(x interface{}) { +func (s *TxByPriceAndTime) Push(x interface{}) { s.txs = append(s.txs, x.(*Transaction)) } -func (s *TxByPrice) Pop() interface{} { +func (s *TxByPriceAndTime) Pop() interface{} { old := s.txs n := len(old) x := old[n-1] @@ -608,7 +716,7 @@ func (s *TxByPrice) Pop() interface{} { // entire batches of transactions for non-executable accounts. type TransactionsByPriceAndNonce struct { txs map[common.Address]Transactions // Per account nonce-sorted list of transactions - heads TxByPrice // Next transaction for each unique account (price heap) + heads TxByPriceAndTime // Next transaction for each unique account (price heap) signer Signer // Signer for the set of transactions } @@ -617,11 +725,11 @@ type TransactionsByPriceAndNonce struct { // // Note, the input map is reowned so the caller should not interact any more with // if after providing it to the constructor. - +// // It also classifies special txs and normal txs func NewTransactionsByPriceAndNonce(signer Signer, txs map[common.Address]Transactions, signers map[common.Address]struct{}, payersSwap map[common.Address]*big.Int) (*TransactionsByPriceAndNonce, Transactions) { - // Initialize a price based heap with the head transactions - heads := TxByPrice{} + // Initialize a price and received time based heap with the head transactions + heads := TxByPriceAndTime{} heads.payersSwap = payersSwap specialTxs := Transactions{} for _, accTxs := range txs { @@ -698,11 +806,12 @@ type Message struct { gasLimit uint64 gasPrice *big.Int data []byte + accessList AccessList checkNonce bool balanceTokenFee *big.Int } -func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, checkNonce bool, balanceTokenFee *big.Int, number *big.Int) Message { +func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *big.Int, gasLimit uint64, gasPrice *big.Int, data []byte, accessList AccessList, checkNonce bool, balanceTokenFee *big.Int, number *big.Int) Message { if balanceTokenFee != nil { gasPrice = common.GetGasPrice(number) } @@ -714,6 +823,7 @@ func NewMessage(from common.Address, to *common.Address, nonce uint64, amount *b gasLimit: gasLimit, gasPrice: gasPrice, data: data, + accessList: accessList, checkNonce: checkNonce, balanceTokenFee: balanceTokenFee, } @@ -728,5 +838,6 @@ func (m Message) Gas() uint64 { return m.gasLimit } func (m Message) Nonce() uint64 { return m.nonce } func (m Message) Data() []byte { return m.data } func (m Message) CheckNonce() bool { return m.checkNonce } +func (m Message) AccessList() AccessList { return m.accessList } func (m *Message) SetNonce(nonce uint64) { m.nonce = nonce } diff --git a/core/types/transaction_marshalling.go b/core/types/transaction_marshalling.go new file mode 100644 index 000000000000..91403994bf7e --- /dev/null +++ b/core/types/transaction_marshalling.go @@ -0,0 +1,187 @@ +package types + +import ( + "encoding/json" + "errors" + "math/big" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" +) + +// txJSON is the JSON representation of transactions. +type txJSON struct { + Type hexutil.Uint64 `json:"type"` + + // Common transaction fields: + Nonce *hexutil.Uint64 `json:"nonce"` + GasPrice *hexutil.Big `json:"gasPrice"` + Gas *hexutil.Uint64 `json:"gas"` + Value *hexutil.Big `json:"value"` + Data *hexutil.Bytes `json:"input"` + V *hexutil.Big `json:"v"` + R *hexutil.Big `json:"r"` + S *hexutil.Big `json:"s"` + To *common.Address `json:"to"` + + // Access list transaction fields: + ChainID *hexutil.Big `json:"chainId,omitempty"` + AccessList *AccessList `json:"accessList,omitempty"` + + // Only used for encoding: + Hash common.Hash `json:"hash"` +} + +// MarshalJSON marshals as JSON with a hash. +func (t *Transaction) MarshalJSON() ([]byte, error) { + var enc txJSON + // These are set for all tx types. + enc.Hash = t.Hash() + enc.Type = hexutil.Uint64(t.Type()) + + // Other fields are set conditionally depending on tx type. + switch tx := t.inner.(type) { + case *LegacyTx: + enc.Nonce = (*hexutil.Uint64)(&tx.Nonce) + enc.Gas = (*hexutil.Uint64)(&tx.Gas) + enc.GasPrice = (*hexutil.Big)(tx.GasPrice) + enc.Value = (*hexutil.Big)(tx.Value) + enc.Data = (*hexutil.Bytes)(&tx.Data) + enc.To = t.To() + enc.V = (*hexutil.Big)(tx.V) + enc.R = (*hexutil.Big)(tx.R) + enc.S = (*hexutil.Big)(tx.S) + case *AccessListTx: + enc.ChainID = (*hexutil.Big)(tx.ChainID) + enc.AccessList = &tx.AccessList + enc.Nonce = (*hexutil.Uint64)(&tx.Nonce) + enc.Gas = (*hexutil.Uint64)(&tx.Gas) + enc.GasPrice = (*hexutil.Big)(tx.GasPrice) + enc.Value = (*hexutil.Big)(tx.Value) + enc.Data = (*hexutil.Bytes)(&tx.Data) + enc.To = t.To() + enc.V = (*hexutil.Big)(tx.V) + enc.R = (*hexutil.Big)(tx.R) + enc.S = (*hexutil.Big)(tx.S) + } + return json.Marshal(&enc) +} + +// UnmarshalJSON unmarshals from JSON. +func (t *Transaction) UnmarshalJSON(input []byte) error { + var dec txJSON + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + + // Decode / verify fields according to transaction type. + var inner TxData + switch dec.Type { + case LegacyTxType: + var itx LegacyTx + inner = &itx + if dec.To != nil { + itx.To = dec.To + } + if dec.Nonce == nil { + return errors.New("missing required field 'nonce' in transaction") + } + itx.Nonce = uint64(*dec.Nonce) + if dec.GasPrice == nil { + return errors.New("missing required field 'gasPrice' in transaction") + } + itx.GasPrice = (*big.Int)(dec.GasPrice) + if dec.Gas == nil { + return errors.New("missing required field 'gas' in transaction") + } + itx.Gas = uint64(*dec.Gas) + if dec.Value == nil { + return errors.New("missing required field 'value' in transaction") + } + itx.Value = (*big.Int)(dec.Value) + if dec.Data == nil { + return errors.New("missing required field 'input' in transaction") + } + itx.Data = *dec.Data + if dec.V == nil { + return errors.New("missing required field 'v' in transaction") + } + itx.V = (*big.Int)(dec.V) + if dec.R == nil { + return errors.New("missing required field 'r' in transaction") + } + itx.R = (*big.Int)(dec.R) + if dec.S == nil { + return errors.New("missing required field 's' in transaction") + } + itx.S = (*big.Int)(dec.S) + withSignature := itx.V.Sign() != 0 || itx.R.Sign() != 0 || itx.S.Sign() != 0 + if withSignature { + if err := sanityCheckSignature(itx.V, itx.R, itx.S, true); err != nil { + return err + } + } + + case AccessListTxType: + var itx AccessListTx + inner = &itx + // Access list is optional for now. + if dec.AccessList != nil { + itx.AccessList = *dec.AccessList + } + if dec.ChainID == nil { + return errors.New("missing required field 'chainId' in transaction") + } + itx.ChainID = (*big.Int)(dec.ChainID) + if dec.To != nil { + itx.To = dec.To + } + if dec.Nonce == nil { + return errors.New("missing required field 'nonce' in transaction") + } + itx.Nonce = uint64(*dec.Nonce) + if dec.GasPrice == nil { + return errors.New("missing required field 'gasPrice' in transaction") + } + itx.GasPrice = (*big.Int)(dec.GasPrice) + if dec.Gas == nil { + return errors.New("missing required field 'gas' in transaction") + } + itx.Gas = uint64(*dec.Gas) + if dec.Value == nil { + return errors.New("missing required field 'value' in transaction") + } + itx.Value = (*big.Int)(dec.Value) + if dec.Data == nil { + return errors.New("missing required field 'input' in transaction") + } + itx.Data = *dec.Data + if dec.V == nil { + return errors.New("missing required field 'v' in transaction") + } + itx.V = (*big.Int)(dec.V) + if dec.R == nil { + return errors.New("missing required field 'r' in transaction") + } + itx.R = (*big.Int)(dec.R) + if dec.S == nil { + return errors.New("missing required field 's' in transaction") + } + itx.S = (*big.Int)(dec.S) + withSignature := itx.V.Sign() != 0 || itx.R.Sign() != 0 || itx.S.Sign() != 0 + if withSignature { + if err := sanityCheckSignature(itx.V, itx.R, itx.S, false); err != nil { + return err + } + } + + default: + return ErrTxTypeNotSupported + } + + // Now set the inner transaction. + t.setDecoded(inner, 0) + + // TODO: check hash here? + return nil +} diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 3658ed43b9b3..f4174dae4858 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -27,9 +27,8 @@ import ( "github.com/XinFinOrg/XDPoSChain/params" ) -var ( - ErrInvalidChainId = errors.New("invalid chain id for signer") -) +var ErrInvalidChainId = errors.New("invalid chain id for signer") +var ErrInvalidNilTx = errors.New("invalid nil tx") // sigCache is used to cache the derived sender and contains // the signer used to derive it. @@ -42,6 +41,8 @@ type sigCache struct { func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { var signer Signer switch { + case config.IsEIP1559(blockNumber): + signer = NewEIP2930Signer(config.ChainId) case config.IsEIP155(blockNumber): signer = NewEIP155Signer(config.ChainId) case config.IsHomestead(blockNumber): @@ -52,7 +53,40 @@ func MakeSigner(config *params.ChainConfig, blockNumber *big.Int) Signer { return signer } -// SignTx signs the transaction using the given signer and private key +// LatestSigner returns the 'most permissive' Signer available for the given chain +// configuration. Specifically, this enables support of EIP-155 replay protection and +// EIP-2930 access list transactions when their respective forks are scheduled to occur at +// any block number in the chain config. +// +// Use this in transaction-handling code where the current block number is unknown. If you +// have the current block number available, use MakeSigner instead. +func LatestSigner(config *params.ChainConfig) Signer { + if config.ChainId != nil { + if common.Eip1559Block.Uint64() != 9999999999 || config.Eip1559Block != nil { + return NewEIP2930Signer(config.ChainId) + } + if config.EIP155Block != nil { + return NewEIP155Signer(config.ChainId) + } + } + return HomesteadSigner{} +} + +// LatestSignerForChainID returns the 'most permissive' Signer available. Specifically, +// this enables support for EIP-155 replay protection and all implemented EIP-2718 +// transaction types if chainID is non-nil. +// +// Use this in transaction-handling code where the current block number and fork +// configuration are unknown. If you have a ChainConfig, use LatestSigner instead. +// If you have a ChainConfig and know the current block number, use MakeSigner instead. +func LatestSignerForChainID(chainID *big.Int) Signer { + if chainID == nil { + return HomesteadSigner{} + } + return NewEIP2930Signer(chainID) +} + +// SignTx signs the transaction using the given signer and private key. func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, error) { h := s.Hash(tx) sig, err := crypto.Sign(h[:], prv) @@ -62,6 +96,27 @@ func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, err return tx.WithSignature(s, sig) } +// SignNewTx creates a transaction and signs it. +func SignNewTx(prv *ecdsa.PrivateKey, s Signer, txdata TxData) (*Transaction, error) { + tx := NewTx(txdata) + h := s.Hash(tx) + sig, err := crypto.Sign(h[:], prv) + if err != nil { + return nil, err + } + return tx.WithSignature(s, sig) +} + +// MustSignNewTx creates a transaction and signs it. +// This panics if the transaction cannot be signed. +func MustSignNewTx(prv *ecdsa.PrivateKey, s Signer, txdata TxData) *Transaction { + tx, err := SignNewTx(prv, s, txdata) + if err != nil { + panic(err) + } + return tx +} + // Sender returns the address derived from the signature (V, R, S) using secp256k1 // elliptic curve and an error if it failed deriving or upon an incorrect // signature. @@ -70,6 +125,10 @@ func SignTx(tx *Transaction, s Signer, prv *ecdsa.PrivateKey) (*Transaction, err // signing method. The cache is invalidated if the cached signer does // not match the signer used in the current call. func Sender(signer Signer, tx *Transaction) (common.Address, error) { + if tx == nil { + return common.Address{}, ErrInvalidNilTx + } + if sc := tx.from.Load(); sc != nil { sigCache := sc.(sigCache) // If the signer used to derive from in a previous @@ -88,21 +147,112 @@ func Sender(signer Signer, tx *Transaction) (common.Address, error) { return addr, nil } -// Signer encapsulates transaction signature handling. Note that this interface is not a -// stable API and may change at any time to accommodate new protocol rules. +// Signer encapsulates transaction signature handling. The name of this type is slightly +// misleading because Signers don't actually sign, they're just for validating and +// processing of signatures. +// +// Note that this interface is not a stable API and may change at any time to accommodate +// new protocol rules. type Signer interface { // Sender returns the sender address of the transaction. Sender(tx *Transaction) (common.Address, error) + // SignatureValues returns the raw R, S, V values corresponding to the // given signature. SignatureValues(tx *Transaction, sig []byte) (r, s, v *big.Int, err error) - // Hash returns the hash to be signed. + ChainID() *big.Int + + // Hash returns 'signature hash', i.e. the transaction hash that is signed by the + // private key. This hash does not uniquely identify the transaction. Hash(tx *Transaction) common.Hash + // Equal returns true if the given signer is the same as the receiver. Equal(Signer) bool } -// EIP155Transaction implements Signer using the EIP155 rules. +type eip2930Signer struct{ EIP155Signer } + +// NewEIP2930Signer returns a signer that accepts EIP-2930 access list transactions, +// EIP-155 replay protected transactions, and legacy Homestead transactions. +func NewEIP2930Signer(chainId *big.Int) Signer { + return eip2930Signer{NewEIP155Signer(chainId)} +} + +func (s eip2930Signer) ChainID() *big.Int { + return s.chainId +} + +func (s eip2930Signer) Equal(s2 Signer) bool { + x, ok := s2.(eip2930Signer) + return ok && x.chainId.Cmp(s.chainId) == 0 +} + +func (s eip2930Signer) Sender(tx *Transaction) (common.Address, error) { + V, R, S := tx.RawSignatureValues() + switch tx.Type() { + case LegacyTxType: + return s.EIP155Signer.Sender(tx) + case AccessListTxType: + // ACL txs are defined to use 0 and 1 as their recovery id, add + // 27 to become equivalent to unprotected Homestead signatures. + V = new(big.Int).Add(V, big.NewInt(27)) + default: + return common.Address{}, ErrTxTypeNotSupported + } + if tx.ChainId().Cmp(s.chainId) != 0 { + return common.Address{}, ErrInvalidChainId + } + return recoverPlain(s.Hash(tx), R, S, V, true) +} + +func (s eip2930Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { + switch txdata := tx.inner.(type) { + case *LegacyTx: + return s.EIP155Signer.SignatureValues(tx, sig) + case *AccessListTx: + // Check that chain ID of tx matches the signer. We also accept ID zero here, + // because it indicates that the chain ID was not specified in the tx. + if txdata.ChainID.Sign() != 0 && txdata.ChainID.Cmp(s.chainId) != 0 { + return nil, nil, nil, ErrInvalidChainId + } + R, S, _ = decodeSignature(sig) + V = big.NewInt(int64(sig[64])) + default: + return nil, nil, nil, ErrTxTypeNotSupported + } + return R, S, V, nil +} + +// Hash returns the hash to be signed by the sender. +// It does not uniquely identify the transaction. +func (s eip2930Signer) Hash(tx *Transaction) common.Hash { + switch tx.Type() { + case LegacyTxType: + return s.EIP155Signer.Hash(tx) + case AccessListTxType: + return prefixedRlpHash( + tx.Type(), + []interface{}{ + s.chainId, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), + tx.AccessList(), + }) + default: + // This _should_ not happen, but in case someone sends in a bad + // json struct via RPC, it's probably more prudent to return an + // empty hash instead of killing the node with a panic + //panic("Unsupported transaction type: %d", tx.typ) + return common.Hash{} + } +} + +// EIP155Signer implements Signer using the EIP-155 rules. This accepts transactions which +// are replay-protected as well as unprotected homestead transactions. type EIP155Signer struct { chainId, chainIdMul *big.Int } @@ -117,6 +267,10 @@ func NewEIP155Signer(chainId *big.Int) EIP155Signer { } } +func (s EIP155Signer) ChainID() *big.Int { + return s.chainId +} + func (s EIP155Signer) Equal(s2 Signer) bool { eip155, ok := s2.(EIP155Signer) return ok && eip155.chainId.Cmp(s.chainId) == 0 @@ -125,24 +279,28 @@ func (s EIP155Signer) Equal(s2 Signer) bool { var big8 = big.NewInt(8) func (s EIP155Signer) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } if !tx.Protected() { return HomesteadSigner{}.Sender(tx) } if tx.ChainId().Cmp(s.chainId) != 0 { return common.Address{}, ErrInvalidChainId } - V := new(big.Int).Sub(tx.data.V, s.chainIdMul) + V, R, S := tx.RawSignatureValues() + V = new(big.Int).Sub(V, s.chainIdMul) V.Sub(V, big8) - return recoverPlain(s.Hash(tx), tx.data.R, tx.data.S, V, true) + return recoverPlain(s.Hash(tx), R, S, V, true) } -// WithSignature returns a new transaction with the given signature. This signature +// SignatureValues returns signature values. This signature // needs to be in the [R || S || V] format where V is 0 or 1. func (s EIP155Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big.Int, err error) { - R, S, V, err = HomesteadSigner{}.SignatureValues(tx, sig) - if err != nil { - return nil, nil, nil, err + if tx.Type() != LegacyTxType { + return nil, nil, nil, ErrTxTypeNotSupported } + R, S, V = decodeSignature(sig) if s.chainId.Sign() != 0 { V = big.NewInt(int64(sig[64] + 35)) V.Add(V, s.chainIdMul) @@ -154,12 +312,12 @@ func (s EIP155Signer) SignatureValues(tx *Transaction, sig []byte) (R, S, V *big // It does not uniquely identify the transaction. func (s EIP155Signer) Hash(tx *Transaction) common.Hash { return rlpHash([]interface{}{ - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Recipient, - tx.data.Amount, - tx.data.Payload, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), s.chainId, uint(0), uint(0), }) } @@ -168,6 +326,10 @@ func (s EIP155Signer) Hash(tx *Transaction) common.Hash { // homestead rules. type HomesteadSigner struct{ FrontierSigner } +func (s HomesteadSigner) ChainID() *big.Int { + return nil +} + func (s HomesteadSigner) Equal(s2 Signer) bool { _, ok := s2.(HomesteadSigner) return ok @@ -180,25 +342,39 @@ func (hs HomesteadSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v } func (hs HomesteadSigner) Sender(tx *Transaction) (common.Address, error) { - return recoverPlain(hs.Hash(tx), tx.data.R, tx.data.S, tx.data.V, true) + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } + v, r, s := tx.RawSignatureValues() + return recoverPlain(hs.Hash(tx), r, s, v, true) } type FrontierSigner struct{} +func (s FrontierSigner) ChainID() *big.Int { + return nil +} + func (s FrontierSigner) Equal(s2 Signer) bool { _, ok := s2.(FrontierSigner) return ok } +func (fs FrontierSigner) Sender(tx *Transaction) (common.Address, error) { + if tx.Type() != LegacyTxType { + return common.Address{}, ErrTxTypeNotSupported + } + v, r, s := tx.RawSignatureValues() + return recoverPlain(fs.Hash(tx), r, s, v, false) +} + // SignatureValues returns signature values. This signature // needs to be in the [R || S || V] format where V is 0 or 1. func (fs FrontierSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v *big.Int, err error) { - if len(sig) != 65 { - panic(fmt.Sprintf("wrong size for signature: got %d, want 65", len(sig))) + if tx.Type() != LegacyTxType { + return nil, nil, nil, ErrTxTypeNotSupported } - r = new(big.Int).SetBytes(sig[:32]) - s = new(big.Int).SetBytes(sig[32:64]) - v = new(big.Int).SetBytes([]byte{sig[64] + 27}) + r, s, v = decodeSignature(sig) return r, s, v, nil } @@ -206,17 +382,23 @@ func (fs FrontierSigner) SignatureValues(tx *Transaction, sig []byte) (r, s, v * // It does not uniquely identify the transaction. func (fs FrontierSigner) Hash(tx *Transaction) common.Hash { return rlpHash([]interface{}{ - tx.data.AccountNonce, - tx.data.Price, - tx.data.GasLimit, - tx.data.Recipient, - tx.data.Amount, - tx.data.Payload, + tx.Nonce(), + tx.GasPrice(), + tx.Gas(), + tx.To(), + tx.Value(), + tx.Data(), }) } -func (fs FrontierSigner) Sender(tx *Transaction) (common.Address, error) { - return recoverPlain(fs.Hash(tx), tx.data.R, tx.data.S, tx.data.V, false) +func decodeSignature(sig []byte) (r, s, v *big.Int) { + if len(sig) != crypto.SignatureLength { + panic(fmt.Sprintf("wrong size for signature: got %d, want %d", len(sig), crypto.SignatureLength)) + } + r = new(big.Int).SetBytes(sig[:32]) + s = new(big.Int).SetBytes(sig[32:64]) + v = new(big.Int).SetBytes([]byte{sig[64] + 27}) + return r, s, v } func recoverPlain(sighash common.Hash, R, S, Vb *big.Int, homestead bool) (common.Address, error) { diff --git a/core/types/transaction_test.go b/core/types/transaction_test.go index 234ef322c513..500dec7227f9 100644 --- a/core/types/transaction_test.go +++ b/core/types/transaction_test.go @@ -20,8 +20,12 @@ import ( "bytes" "crypto/ecdsa" "encoding/json" + "errors" + "fmt" "math/big" + "reflect" "testing" + "time" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/crypto" @@ -31,6 +35,8 @@ import ( // The values in those tests are from the Transaction Tests // at github.com/ethereum/tests. var ( + testAddr = common.HexToAddress("b94f5374fce5edbc8e2a8697c15331677e6ebf0b") + emptyTx = NewTransaction( 0, common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87"), @@ -40,7 +46,7 @@ var ( rightvrsTx, _ = NewTransaction( 3, - common.HexToAddress("b94f5374fce5edbc8e2a8697c15331677e6ebf0b"), + testAddr, big.NewInt(10), 2000, big.NewInt(1), @@ -49,8 +55,32 @@ var ( HomesteadSigner{}, common.Hex2Bytes("98ff921201554726367d2be8c804a7ff89ccf285ebc57dff8ae4c44b9c19ac4a8887321be575c8095f789dd4c743dfe42c1820f9231f98a962b210e3ac2452a301"), ) + + emptyEip2718Tx = NewTx(&AccessListTx{ + ChainID: big.NewInt(1), + Nonce: 3, + To: &testAddr, + Value: big.NewInt(10), + Gas: 25000, + GasPrice: big.NewInt(1), + Data: common.FromHex("5544"), + }) + + signedEip2718Tx, _ = emptyEip2718Tx.WithSignature( + NewEIP2930Signer(big.NewInt(1)), + common.Hex2Bytes("c9519f4f2b30335884581971573fadf60c6204f59a911df35ee8a540456b266032f1e8e2c5dd761f9e4f88f41c8310aeaba26a8bfcdacfedfa12ec3862d3752101"), + ) ) +func TestDecodeEmptyTypedTx(t *testing.T) { + input := []byte{0x80} + var tx Transaction + err := rlp.DecodeBytes(input, &tx) + if err != errEmptyTypedTx { + t.Fatal("wrong error:", err) + } +} + func TestTransactionSigHash(t *testing.T) { var homestead HomesteadSigner if homestead.Hash(emptyTx) != common.HexToHash("c775b99e7ad12f50d819fcd602390467e28141316969f4b57f0626f74fe3b386") { @@ -72,6 +102,117 @@ func TestTransactionEncode(t *testing.T) { } } +func TestEIP2718TransactionSigHash(t *testing.T) { + s := NewEIP2930Signer(big.NewInt(1)) + if s.Hash(emptyEip2718Tx) != common.HexToHash("49b486f0ec0a60dfbbca2d30cb07c9e8ffb2a2ff41f29a1ab6737475f6ff69f3") { + t.Errorf("empty EIP-2718 transaction hash mismatch, got %x", s.Hash(emptyEip2718Tx)) + } + if s.Hash(signedEip2718Tx) != common.HexToHash("49b486f0ec0a60dfbbca2d30cb07c9e8ffb2a2ff41f29a1ab6737475f6ff69f3") { + t.Errorf("signed EIP-2718 transaction hash mismatch, got %x", s.Hash(signedEip2718Tx)) + } +} + +// This test checks signature operations on access list transactions. +func TestEIP2930Signer(t *testing.T) { + var ( + key, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + keyAddr = crypto.PubkeyToAddress(key.PublicKey) + signer1 = NewEIP2930Signer(big.NewInt(1)) + signer2 = NewEIP2930Signer(big.NewInt(2)) + tx0 = NewTx(&AccessListTx{Nonce: 1}) + tx1 = NewTx(&AccessListTx{ChainID: big.NewInt(1), Nonce: 1}) + tx2, _ = SignNewTx(key, signer2, &AccessListTx{ChainID: big.NewInt(2), Nonce: 1}) + ) + + tests := []struct { + tx *Transaction + signer Signer + wantSignerHash common.Hash + wantSenderErr error + wantSignErr error + wantHash common.Hash // after signing + }{ + { + tx: tx0, + signer: signer1, + wantSignerHash: common.HexToHash("846ad7672f2a3a40c1f959cd4a8ad21786d620077084d84c8d7c077714caa139"), + wantSenderErr: ErrInvalidChainId, + wantHash: common.HexToHash("1ccd12d8bbdb96ea391af49a35ab641e219b2dd638dea375f2bc94dd290f2549"), + }, + { + tx: tx1, + signer: signer1, + wantSenderErr: ErrInvalidSig, + wantSignerHash: common.HexToHash("846ad7672f2a3a40c1f959cd4a8ad21786d620077084d84c8d7c077714caa139"), + wantHash: common.HexToHash("1ccd12d8bbdb96ea391af49a35ab641e219b2dd638dea375f2bc94dd290f2549"), + }, + { + // This checks what happens when trying to sign an unsigned tx for the wrong chain. + tx: tx1, + signer: signer2, + wantSenderErr: ErrInvalidChainId, + wantSignerHash: common.HexToHash("367967247499343401261d718ed5aa4c9486583e4d89251afce47f4a33c33362"), + wantSignErr: ErrInvalidChainId, + }, + { + // This checks what happens when trying to re-sign a signed tx for the wrong chain. + tx: tx2, + signer: signer1, + wantSenderErr: ErrInvalidChainId, + wantSignerHash: common.HexToHash("846ad7672f2a3a40c1f959cd4a8ad21786d620077084d84c8d7c077714caa139"), + wantSignErr: ErrInvalidChainId, + }, + } + + for i, test := range tests { + sigHash := test.signer.Hash(test.tx) + if sigHash != test.wantSignerHash { + t.Errorf("test %d: wrong sig hash: got %x, want %x", i, sigHash, test.wantSignerHash) + } + sender, err := Sender(test.signer, test.tx) + if !errors.Is(err, test.wantSenderErr) { + t.Errorf("test %d: wrong Sender error %q", i, err) + } + if err == nil && sender != keyAddr { + t.Errorf("test %d: wrong sender address %x", i, sender) + } + signedTx, err := SignTx(test.tx, test.signer, key) + if !errors.Is(err, test.wantSignErr) { + t.Fatalf("test %d: wrong SignTx error %q", i, err) + } + if signedTx != nil { + if signedTx.Hash() != test.wantHash { + t.Errorf("test %d: wrong tx hash after signing: got %x, want %x", i, signedTx.Hash(), test.wantHash) + } + } + } +} + +func TestEIP2718TransactionEncode(t *testing.T) { + // RLP representation + { + have, err := rlp.EncodeToBytes(signedEip2718Tx) + if err != nil { + t.Fatalf("encode error: %v", err) + } + want := common.FromHex("b86601f8630103018261a894b94f5374fce5edbc8e2a8697c15331677e6ebf0b0a825544c001a0c9519f4f2b30335884581971573fadf60c6204f59a911df35ee8a540456b2660a032f1e8e2c5dd761f9e4f88f41c8310aeaba26a8bfcdacfedfa12ec3862d37521") + if !bytes.Equal(have, want) { + t.Errorf("encoded RLP mismatch, got %x", have) + } + } + // Binary representation + { + have, err := signedEip2718Tx.MarshalBinary() + if err != nil { + t.Fatalf("encode error: %v", err) + } + want := common.FromHex("01f8630103018261a894b94f5374fce5edbc8e2a8697c15331677e6ebf0b0a825544c001a0c9519f4f2b30335884581971573fadf60c6204f59a911df35ee8a540456b2660a032f1e8e2c5dd761f9e4f88f41c8310aeaba26a8bfcdacfedfa12ec3862d37521") + if !bytes.Equal(have, want) { + t.Errorf("encoded RLP mismatch, got %x", have) + } + } +} + func decodeTx(data []byte) (*Transaction, error) { var tx Transaction t, err := &tx, rlp.Decode(bytes.NewReader(data), &tx) @@ -233,3 +374,174 @@ func TestTransactionJSON(t *testing.T) { } } } + +// Tests that if multiple transactions have the same price, the ones seen earlier +// are prioritized to avoid network spam attacks aiming for a specific ordering. +func TestTransactionTimeSort(t *testing.T) { + // Generate a batch of accounts to start with + keys := make([]*ecdsa.PrivateKey, 5) + for i := 0; i < len(keys); i++ { + keys[i], _ = crypto.GenerateKey() + } + signer := HomesteadSigner{} + + // Generate a batch of transactions with overlapping prices, but different creation times + groups := map[common.Address]Transactions{} + for start, key := range keys { + addr := crypto.PubkeyToAddress(key.PublicKey) + + tx, _ := SignTx(NewTransaction(0, common.Address{}, big.NewInt(100), 100, big.NewInt(1), nil), signer, key) + tx.time = time.Unix(0, int64(len(keys)-start)) + + groups[addr] = append(groups[addr], tx) + } + // Sort the transactions and cross check the nonce ordering + txset, _ := NewTransactionsByPriceAndNonce(signer, groups, nil, map[common.Address]*big.Int{}) + + txs := Transactions{} + for tx := txset.Peek(); tx != nil; tx = txset.Peek() { + txs = append(txs, tx) + txset.Shift() + } + if len(txs) != len(keys) { + t.Errorf("expected %d transactions, found %d", len(keys), len(txs)) + } + for i, txi := range txs { + fromi, _ := Sender(signer, txi) + if i+1 < len(txs) { + next := txs[i+1] + fromNext, _ := Sender(signer, next) + + if txi.GasPrice().Cmp(next.GasPrice()) < 0 { + t.Errorf("invalid gasprice ordering: tx #%d (A=%x P=%v) < tx #%d (A=%x P=%v)", i, fromi[:4], txi.GasPrice(), i+1, fromNext[:4], next.GasPrice()) + } + // Make sure time order is ascending if the txs have the same gas price + if txi.GasPrice().Cmp(next.GasPrice()) == 0 && txi.time.After(next.time) { + t.Errorf("invalid received time ordering: tx #%d (A=%x T=%v) > tx #%d (A=%x T=%v)", i, fromi[:4], txi.time, i+1, fromNext[:4], next.time) + } + } + } +} + +// TestTransactionCoding tests serializing/de-serializing to/from rlp and JSON. +func TestTransactionCoding(t *testing.T) { + key, err := crypto.GenerateKey() + if err != nil { + t.Fatalf("could not generate key: %v", err) + } + var ( + signer = NewEIP2930Signer(common.Big1) + addr = common.HexToAddress("0x0000000000000000000000000000000000000001") + recipient = common.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87") + accesses = AccessList{{Address: addr, StorageKeys: []common.Hash{{0}}}} + ) + for i := uint64(0); i < 500; i++ { + var txdata TxData + switch i % 5 { + case 0: + // Legacy tx. + txdata = &LegacyTx{ + Nonce: i, + To: &recipient, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 1: + // Legacy tx contract creation. + txdata = &LegacyTx{ + Nonce: i, + Gas: 1, + GasPrice: big.NewInt(2), + Data: []byte("abcdef"), + } + case 2: + // Tx with non-zero access list. + txdata = &AccessListTx{ + ChainID: big.NewInt(1), + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + AccessList: accesses, + Data: []byte("abcdef"), + } + case 3: + // Tx with empty access list. + txdata = &AccessListTx{ + ChainID: big.NewInt(1), + Nonce: i, + To: &recipient, + Gas: 123457, + GasPrice: big.NewInt(10), + Data: []byte("abcdef"), + } + case 4: + // Contract creation with access list. + txdata = &AccessListTx{ + ChainID: big.NewInt(1), + Nonce: i, + Gas: 123457, + GasPrice: big.NewInt(10), + AccessList: accesses, + } + } + tx, err := SignNewTx(key, signer, txdata) + if err != nil { + t.Fatalf("could not sign transaction: %v", err) + } + // RLP + parsedTx, err := encodeDecodeBinary(tx) + if err != nil { + t.Fatal(err) + } + assertEqual(parsedTx, tx) + + // JSON + parsedTx, err = encodeDecodeJSON(tx) + if err != nil { + t.Fatal(err) + } + assertEqual(parsedTx, tx) + } +} + +func encodeDecodeJSON(tx *Transaction) (*Transaction, error) { + data, err := json.Marshal(tx) + if err != nil { + return nil, fmt.Errorf("json encoding failed: %v", err) + } + var parsedTx = &Transaction{} + if err := json.Unmarshal(data, &parsedTx); err != nil { + return nil, fmt.Errorf("json decoding failed: %v", err) + } + return parsedTx, nil +} + +func encodeDecodeBinary(tx *Transaction) (*Transaction, error) { + data, err := tx.MarshalBinary() + if err != nil { + return nil, fmt.Errorf("rlp encoding failed: %v", err) + } + var parsedTx = &Transaction{} + if err := parsedTx.UnmarshalBinary(data); err != nil { + return nil, fmt.Errorf("rlp decoding failed: %v", err) + } + return parsedTx, nil +} + +func assertEqual(orig *Transaction, cpy *Transaction) error { + // compare nonce, price, gaslimit, recipient, amount, payload, V, R, S + if want, got := orig.Hash(), cpy.Hash(); want != got { + return fmt.Errorf("parsed tx differs from original tx, want %v, got %v", want, got) + } + if want, got := orig.ChainId(), cpy.ChainId(); want.Cmp(got) != 0 { + return fmt.Errorf("invalid chain id, want %d, got %d", want, got) + } + if orig.AccessList() != nil { + if !reflect.DeepEqual(orig.AccessList(), cpy.AccessList()) { + return fmt.Errorf("access list wrong!") + } + } + return nil +} diff --git a/core/vm/access_list_tracer.go b/core/vm/access_list_tracer.go new file mode 100644 index 000000000000..1093af7c4d76 --- /dev/null +++ b/core/vm/access_list_tracer.go @@ -0,0 +1,177 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package vm + +import ( + "math/big" + "time" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/types" +) + +// accessList is an accumulator for the set of accounts and storage slots an EVM +// contract execution touches. +type accessList map[common.Address]accessListSlots + +// accessListSlots is an accumulator for the set of storage slots within a single +// contract that an EVM contract execution touches. +type accessListSlots map[common.Hash]struct{} + +// newAccessList creates a new accessList. +func newAccessList() accessList { + return make(map[common.Address]accessListSlots) +} + +// addAddress adds an address to the accesslist. +func (al accessList) addAddress(address common.Address) { + // Set address if not previously present + if _, present := al[address]; !present { + al[address] = make(map[common.Hash]struct{}) + } +} + +// addSlot adds a storage slot to the accesslist. +func (al accessList) addSlot(address common.Address, slot common.Hash) { + // Set address if not previously present + al.addAddress(address) + + // Set the slot on the surely existent storage set + al[address][slot] = struct{}{} +} + +// equal checks if the content of the current access list is the same as the +// content of the other one. +func (al accessList) equal(other accessList) bool { + // Cross reference the accounts first + if len(al) != len(other) { + return false + } + for addr := range al { + if _, ok := other[addr]; !ok { + return false + } + } + for addr := range other { + if _, ok := al[addr]; !ok { + return false + } + } + // Accounts match, cross reference the storage slots too + for addr, slots := range al { + otherslots := other[addr] + + if len(slots) != len(otherslots) { + return false + } + for hash := range slots { + if _, ok := otherslots[hash]; !ok { + return false + } + } + for hash := range otherslots { + if _, ok := slots[hash]; !ok { + return false + } + } + } + return true +} + +// accesslist converts the accesslist to a types.AccessList. +func (al accessList) accessList() types.AccessList { + acl := make(types.AccessList, 0, len(al)) + for addr, slots := range al { + tuple := types.AccessTuple{Address: addr} + for slot := range slots { + tuple.StorageKeys = append(tuple.StorageKeys, slot) + } + acl = append(acl, tuple) + } + return acl +} + +// AccessListTracer is a tracer that accumulates touched accounts and storage +// slots into an internal set. +type AccessListTracer struct { + excl map[common.Address]struct{} // Set of account to exclude from the list + list accessList // Set of accounts and storage slots touched +} + +// NewAccessListTracer creates a new tracer that can generate AccessLists. +// An optional AccessList can be specified to occupy slots and addresses in +// the resulting accesslist. +func NewAccessListTracer(acl types.AccessList, from, to common.Address, precompiles []common.Address) *AccessListTracer { + excl := map[common.Address]struct{}{ + from: {}, to: {}, + } + for _, addr := range precompiles { + excl[addr] = struct{}{} + } + list := newAccessList() + for _, al := range acl { + if _, ok := excl[al.Address]; !ok { + list.addAddress(al.Address) + } + for _, slot := range al.StorageKeys { + list.addSlot(al.Address, slot) + } + } + return &AccessListTracer{ + excl: excl, + list: list, + } +} + +func (a *AccessListTracer) CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { +} + +// CaptureState captures all opcodes that touch storage or addresses and adds them to the accesslist. +func (a *AccessListTracer) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + stack := scope.Stack + if (op == SLOAD || op == SSTORE) && stack.len() >= 1 { + slot := common.Hash(stack.data[stack.len()-1].Bytes32()) + a.list.addSlot(scope.Contract.Address(), slot) + } + if (op == EXTCODECOPY || op == EXTCODEHASH || op == EXTCODESIZE || op == BALANCE || op == SELFDESTRUCT) && stack.len() >= 1 { + addr := common.Address(stack.data[stack.len()-1].Bytes20()) + if _, ok := a.excl[addr]; !ok { + a.list.addAddress(addr) + } + } + if (op == DELEGATECALL || op == CALL || op == STATICCALL || op == CALLCODE) && stack.len() >= 5 { + addr := common.Address(stack.data[stack.len()-2].Bytes20()) + if _, ok := a.excl[addr]; !ok { + a.list.addAddress(addr) + } + } +} + +func (*AccessListTracer) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) { +} + +func (*AccessListTracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) {} + +// AccessList returns the current accesslist maintained by the tracer. +func (a *AccessListTracer) AccessList() types.AccessList { + return a.list.accessList() +} + +// Equal returns if the content of two access list traces are equal. +func (a *AccessListTracer) Equal(other *AccessListTracer) bool { + return a.list.equal(other.list) +} diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 067bc7018ab4..175db51a87aa 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -22,10 +22,9 @@ import ( "errors" "math/big" - "github.com/XinFinOrg/XDPoSChain/core/vm/privacy" - "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/XinFinOrg/XDPoSChain/core/vm/privacy" "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/crypto/blake2b" "github.com/XinFinOrg/XDPoSChain/crypto/bn256" @@ -87,6 +86,54 @@ var PrecompiledContractsIstanbul = map[common.Address]PrecompiledContract{ common.BytesToAddress([]byte{42}): &XDCxEpochPrice{}, } +var PrecompiledContractsXDCv2 = map[common.Address]PrecompiledContract{ + common.BytesToAddress([]byte{1}): &ecrecover{}, + common.BytesToAddress([]byte{2}): &sha256hash{}, + common.BytesToAddress([]byte{3}): &ripemd160hash{}, + common.BytesToAddress([]byte{4}): &dataCopy{}, + common.BytesToAddress([]byte{5}): &bigModExp{}, + common.BytesToAddress([]byte{6}): &bn256AddIstanbul{}, + common.BytesToAddress([]byte{7}): &bn256ScalarMulIstanbul{}, + common.BytesToAddress([]byte{8}): &bn256PairingIstanbul{}, + common.BytesToAddress([]byte{9}): &blake2F{}, +} + +var ( + PrecompiledAddressesXDCv2 []common.Address + PrecompiledAddressesIstanbul []common.Address + PrecompiledAddressesByzantium []common.Address + PrecompiledAddressesHomestead []common.Address +) + +func init() { + for k := range PrecompiledContractsHomestead { + PrecompiledAddressesHomestead = append(PrecompiledAddressesHomestead, k) + } + for k := range PrecompiledContractsByzantium { + PrecompiledAddressesHomestead = append(PrecompiledAddressesByzantium, k) + } + for k := range PrecompiledContractsIstanbul { + PrecompiledAddressesIstanbul = append(PrecompiledAddressesIstanbul, k) + } + for k := range PrecompiledContractsXDCv2 { + PrecompiledAddressesXDCv2 = append(PrecompiledAddressesXDCv2, k) + } +} + +// ActivePrecompiles returns the precompiles enabled with the current configuration. +func ActivePrecompiles(rules params.Rules) []common.Address { + switch { + case rules.IsXDCxDisable: + return PrecompiledAddressesXDCv2 + case rules.IsIstanbul: + return PrecompiledAddressesIstanbul + case rules.IsByzantium: + return PrecompiledAddressesByzantium + default: + return PrecompiledAddressesHomestead + } +} + // RunPrecompiledContract runs and evaluates the output of a precompiled contract. func RunPrecompiledContract(p PrecompiledContract, input []byte, contract *Contract) (ret []byte, err error) { gas := p.RequiredGas(input) @@ -488,7 +535,6 @@ func (c *bn256PairingByzantium) Run(input []byte) ([]byte, error) { return runBn256Pairing(input) } - type blake2F struct{} func (c *blake2F) RequiredGas(input []byte) uint64 { diff --git a/core/vm/contracts_test.go b/core/vm/contracts_test.go index 497943d174fd..4d8a4c762fc1 100644 --- a/core/vm/contracts_test.go +++ b/core/vm/contracts_test.go @@ -357,9 +357,15 @@ var bn256PairingTests = []precompiledTest{ } var XDCxLastPriceTests = []precompiledTest{ + // { + // input: common.Bytes2Hex(append(common.Hex2BytesFixed(BTCAddress, 32), common.Hex2BytesFixed(USDTAddress, 32)...)), + // expected: common.Bytes2Hex(common.LeftPadBytes(BTCUSDTLastPrice.Bytes(), XDCXPriceNumberOfBytesReturn)), + // name: "BTCUSDT", + // }, + // since we diable XDCx precompiles, the test now returns 0 { input: common.Bytes2Hex(append(common.Hex2BytesFixed(BTCAddress, 32), common.Hex2BytesFixed(USDTAddress, 32)...)), - expected: common.Bytes2Hex(common.LeftPadBytes(BTCUSDTLastPrice.Bytes(), XDCXPriceNumberOfBytesReturn)), + expected: common.Bytes2Hex(common.LeftPadBytes(common.Big0.Bytes(), XDCXPriceNumberOfBytesReturn)), name: "BTCUSDT", }, { @@ -375,9 +381,15 @@ var XDCxLastPriceTests = []precompiledTest{ } var XDCxEpochPriceTests = []precompiledTest{ + // { + // input: common.Bytes2Hex(append(common.Hex2BytesFixed(BTCAddress, 32), common.Hex2BytesFixed(USDTAddress, 32)...)), + // expected: common.Bytes2Hex(common.LeftPadBytes(BTCUSDTEpochPrice.Bytes(), XDCXPriceNumberOfBytesReturn)), + // name: "BTCUSDT", + // }, + // since we diable XDCx precompiles, the test now returns 0 { input: common.Bytes2Hex(append(common.Hex2BytesFixed(BTCAddress, 32), common.Hex2BytesFixed(USDTAddress, 32)...)), - expected: common.Bytes2Hex(common.LeftPadBytes(BTCUSDTEpochPrice.Bytes(), XDCXPriceNumberOfBytesReturn)), + expected: common.Bytes2Hex(common.LeftPadBytes(common.Big0.Bytes(), XDCXPriceNumberOfBytesReturn)), name: "BTCUSDT", }, { diff --git a/core/vm/eips.go b/core/vm/eips.go index 48ffb0d6917e..69cca508b3af 100644 --- a/core/vm/eips.go +++ b/core/vm/eips.go @@ -33,6 +33,8 @@ func EnableEIP(eipNum int, jt *JumpTable) error { enable3855(jt) case 3198: enable3198(jt) + case 2929: + enable2929(jt) case 2200: enable2200(jt) case 1884: @@ -65,9 +67,9 @@ func enable1884(jt *JumpTable) { } } -func opSelfBalance(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - balance, _ := uint256.FromBig(interpreter.evm.StateDB.GetBalance(callContext.contract.Address())) - callContext.stack.push(balance) +func opSelfBalance(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + balance, _ := uint256.FromBig(interpreter.evm.StateDB.GetBalance(scope.Contract.Address())) + scope.Stack.push(balance) return nil, nil } @@ -84,9 +86,9 @@ func enable1344(jt *JumpTable) { } // opChainID implements CHAINID opcode -func opChainID(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opChainID(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { chainId, _ := uint256.FromBig(interpreter.evm.chainConfig.ChainId) - callContext.stack.push(chainId) + scope.Stack.push(chainId) return nil, nil } @@ -96,6 +98,44 @@ func enable2200(jt *JumpTable) { jt[SSTORE].dynamicGas = gasSStoreEIP2200 } +// enable2929 enables "EIP-2929: Gas cost increases for state access opcodes" +// https://eips.ethereum.org/EIPS/eip-2929 +func enable2929(jt *JumpTable) { + jt[SSTORE].dynamicGas = gasSStoreEIP2929 + + jt[SLOAD].constantGas = 0 + jt[SLOAD].dynamicGas = gasSLoadEIP2929 + + jt[EXTCODECOPY].constantGas = WarmStorageReadCostEIP2929 + jt[EXTCODECOPY].dynamicGas = gasExtCodeCopyEIP2929 + + jt[EXTCODESIZE].constantGas = WarmStorageReadCostEIP2929 + jt[EXTCODESIZE].dynamicGas = gasEip2929AccountCheck + + jt[EXTCODEHASH].constantGas = WarmStorageReadCostEIP2929 + jt[EXTCODEHASH].dynamicGas = gasEip2929AccountCheck + + jt[BALANCE].constantGas = WarmStorageReadCostEIP2929 + jt[BALANCE].dynamicGas = gasEip2929AccountCheck + + jt[CALL].constantGas = WarmStorageReadCostEIP2929 + jt[CALL].dynamicGas = gasCallEIP2929 + + jt[CALLCODE].constantGas = WarmStorageReadCostEIP2929 + jt[CALLCODE].dynamicGas = gasCallCodeEIP2929 + + jt[STATICCALL].constantGas = WarmStorageReadCostEIP2929 + jt[STATICCALL].dynamicGas = gasStaticCallEIP2929 + + jt[DELEGATECALL].constantGas = WarmStorageReadCostEIP2929 + jt[DELEGATECALL].dynamicGas = gasDelegateCallEIP2929 + + // This was previously part of the dynamic cost, but we're using it as a constantGas + // factor here + jt[SELFDESTRUCT].constantGas = params.SelfdestructGasEIP150 + jt[SELFDESTRUCT].dynamicGas = gasSelfdestructEIP2929 +} + // enable3198 applies EIP-3198 (BASEFEE Opcode) // - Adds an opcode that returns the current block's base fee. func enable3198(jt *JumpTable) { @@ -109,9 +149,9 @@ func enable3198(jt *JumpTable) { } // opBaseFee implements BASEFEE opcode -func opBaseFee(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opBaseFee(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error) { baseFee, _ := uint256.FromBig(common.MinGasPrice50x) - callContext.stack.push(baseFee) + callContext.Stack.push(baseFee) return nil, nil } @@ -127,7 +167,7 @@ func enable3855(jt *JumpTable) { } // opPush0 implements the PUSH0 opcode -func opPush0(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int)) +func opPush0(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error) { + callContext.Stack.push(new(uint256.Int)) return nil, nil } diff --git a/core/vm/evm.go b/core/vm/evm.go index 227d38412c86..cdc0f9a48aeb 100644 --- a/core/vm/evm.go +++ b/core/vm/evm.go @@ -23,10 +23,9 @@ import ( "time" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" - "github.com/XinFinOrg/XDPoSChain/params" - "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/crypto" + "github.com/XinFinOrg/XDPoSChain/params" ) // emptyCodeHash is used by create to ensure deployment is disallowed to already @@ -43,24 +42,49 @@ type ( GetHashFunc func(uint64) common.Hash ) +func (evm *EVM) precompile(addr common.Address) (PrecompiledContract, bool) { + var precompiles map[common.Address]PrecompiledContract + switch { + case evm.chainRules.IsXDCxDisable: + precompiles = PrecompiledContractsXDCv2 + case evm.chainRules.IsIstanbul: + precompiles = PrecompiledContractsIstanbul + case evm.chainRules.IsByzantium: + precompiles = PrecompiledContractsByzantium + default: + precompiles = PrecompiledContractsHomestead + } + p, ok := precompiles[addr] + return p, ok +} + +func (evm *EVM) precompile2(addr common.Address) (PrecompiledContract, bool) { + var precompiles map[common.Address]PrecompiledContract + switch { + case evm.chainRules.IsXDCxDisable: + precompiles = PrecompiledContractsXDCv2 + case evm.chainRules.IsIstanbul && evm.ChainConfig().IsTIPXDCXCancellationFee(evm.BlockNumber): + precompiles = PrecompiledContractsIstanbul + case evm.chainRules.IsByzantium: + precompiles = PrecompiledContractsByzantium + default: + precompiles = PrecompiledContractsHomestead + } + p, ok := precompiles[addr] + return p, ok +} + // run runs the given contract and takes care of running precompiles with a fallback to the byte code interpreter. func run(evm *EVM, contract *Contract, input []byte, readOnly bool) ([]byte, error) { if contract.CodeAddr != nil { - var precompiles map[common.Address]PrecompiledContract - switch { - case evm.chainRules.IsIstanbul: - precompiles = PrecompiledContractsIstanbul - case evm.chainRules.IsByzantium: - precompiles = PrecompiledContractsByzantium - default: - precompiles = PrecompiledContractsHomestead - } - if p := precompiles[*contract.CodeAddr]; p != nil { - switch p.(type) { - case *XDCxEpochPrice: - p.(*XDCxEpochPrice).SetTradingState(evm.tradingStateDB) - case *XDCxLastPrice: - p.(*XDCxLastPrice).SetTradingState(evm.tradingStateDB) + if p, isPrecompile := evm.precompile(*contract.CodeAddr); isPrecompile { + if evm.chainConfig.IsTIPXDCXReceiver(evm.BlockNumber) { + switch p := p.(type) { + case *XDCxEpochPrice: + p.SetTradingState(evm.tradingStateDB) + case *XDCxLastPrice: + p.SetTradingState(evm.tradingStateDB) + } } return RunPrecompiledContract(p, input, contract) } @@ -205,19 +229,11 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas snapshot = evm.StateDB.Snapshot() ) if !evm.StateDB.Exist(addr) { - precompiles := PrecompiledContractsHomestead - if evm.chainRules.IsByzantium { - precompiles = PrecompiledContractsByzantium - } - if evm.ChainConfig().IsTIPXDCXCancellationFee(evm.BlockNumber) { - if evm.chainRules.IsIstanbul { - precompiles = PrecompiledContractsIstanbul - } - } - if precompiles[addr] == nil && evm.chainRules.IsEIP158 && value.Sign() == 0 { + _, isPrecompile := evm.precompile2(addr) + if !isPrecompile && evm.chainRules.IsEIP158 && value.Sign() == 0 { // Calling a non existing account, don't do anything, but ping the tracer if evm.vmConfig.Debug && evm.depth == 0 { - evm.vmConfig.Tracer.CaptureStart(caller.Address(), addr, false, input, gas, value) + evm.vmConfig.Tracer.CaptureStart(evm, caller.Address(), addr, false, input, gas, value) evm.vmConfig.Tracer.CaptureEnd(ret, 0, 0, nil) } return nil, gas, nil @@ -235,7 +251,7 @@ func (evm *EVM) Call(caller ContractRef, addr common.Address, input []byte, gas // Capture the tracer start/end events in debug mode if evm.vmConfig.Debug && evm.depth == 0 { - evm.vmConfig.Tracer.CaptureStart(caller.Address(), addr, false, input, gas, value) + evm.vmConfig.Tracer.CaptureStart(evm, caller.Address(), addr, false, input, gas, value) defer func() { // Lazy evaluation of the parameters evm.vmConfig.Tracer.CaptureEnd(ret, gas-contract.Gas, time.Since(start), err) @@ -384,7 +400,11 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, } nonce := evm.StateDB.GetNonce(caller.Address()) evm.StateDB.SetNonce(caller.Address(), nonce+1) - + // We add this to the access list _before_ taking a snapshot. Even if the creation fails, + // the access-list change should not be rolled back + if evm.chainRules.IsEIP1559 { + evm.StateDB.AddAddressToAccessList(address) + } // Ensure there's no existing contract already at the designated address contractHash := evm.StateDB.GetCodeHash(address) if evm.StateDB.GetNonce(address) != 0 || (contractHash != (common.Hash{}) && contractHash != emptyCodeHash) { @@ -404,7 +424,7 @@ func (evm *EVM) create(caller ContractRef, codeAndHash *codeAndHash, gas uint64, contract.SetCodeOptionalHash(&address, codeAndHash) if evm.vmConfig.Debug && evm.depth == 0 { - evm.vmConfig.Tracer.CaptureStart(caller.Address(), address, true, codeAndHash.code, gas, value) + evm.vmConfig.Tracer.CaptureStart(evm, caller.Address(), address, true, codeAndHash.code, gas, value) } start := time.Now() diff --git a/core/vm/instructions.go b/core/vm/instructions.go index a020654440c6..39f0032abe84 100644 --- a/core/vm/instructions.go +++ b/core/vm/instructions.go @@ -26,68 +26,68 @@ import ( "golang.org/x/crypto/sha3" ) -func opAdd(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opAdd(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Add(&x, y) return nil, nil } -func opSub(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSub(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Sub(&x, y) return nil, nil } -func opMul(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opMul(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Mul(&x, y) return nil, nil } -func opDiv(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opDiv(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Div(&x, y) return nil, nil } -func opSdiv(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSdiv(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.SDiv(&x, y) return nil, nil } -func opMod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opMod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Mod(&x, y) return nil, nil } -func opSmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSmod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.SMod(&x, y) return nil, nil } -func opExp(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - base, exponent := callContext.stack.pop(), callContext.stack.peek() +func opExp(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + base, exponent := scope.Stack.pop(), scope.Stack.peek() exponent.Exp(&base, exponent) return nil, nil } -func opSignExtend(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - back, num := callContext.stack.pop(), callContext.stack.peek() +func opSignExtend(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + back, num := scope.Stack.pop(), scope.Stack.peek() num.ExtendSign(num, &back) return nil, nil } -func opNot(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x := callContext.stack.peek() +func opNot(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x := scope.Stack.peek() x.Not(x) return nil, nil } -func opLt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opLt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Lt(y) { y.SetOne() } else { @@ -96,8 +96,8 @@ func opLt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte return nil, nil } -func opGt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opGt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Gt(y) { y.SetOne() } else { @@ -106,8 +106,8 @@ func opGt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte return nil, nil } -func opSlt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSlt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Slt(y) { y.SetOne() } else { @@ -116,8 +116,8 @@ func opSlt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt return nil, nil } -func opSgt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opSgt(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Sgt(y) { y.SetOne() } else { @@ -126,8 +126,8 @@ func opSgt(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt return nil, nil } -func opEq(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opEq(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() if x.Eq(y) { y.SetOne() } else { @@ -136,8 +136,8 @@ func opEq(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte return nil, nil } -func opIszero(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x := callContext.stack.peek() +func opIszero(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x := scope.Stack.peek() if x.IsZero() { x.SetOne() } else { @@ -146,32 +146,32 @@ func opIszero(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] return nil, nil } -func opAnd(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opAnd(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.And(&x, y) return nil, nil } -func opOr(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opOr(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Or(&x, y) return nil, nil } -func opXor(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y := callContext.stack.pop(), callContext.stack.peek() +func opXor(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y := scope.Stack.pop(), scope.Stack.peek() y.Xor(&x, y) return nil, nil } -func opByte(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - th, val := callContext.stack.pop(), callContext.stack.peek() +func opByte(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + th, val := scope.Stack.pop(), scope.Stack.peek() val.Byte(&th) return nil, nil } -func opAddmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y, z := callContext.stack.pop(), callContext.stack.pop(), callContext.stack.peek() +func opAddmod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y, z := scope.Stack.pop(), scope.Stack.pop(), scope.Stack.peek() if z.IsZero() { z.Clear() } else { @@ -180,8 +180,8 @@ func opAddmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] return nil, nil } -func opMulmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x, y, z := callContext.stack.pop(), callContext.stack.pop(), callContext.stack.peek() +func opMulmod(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x, y, z := scope.Stack.pop(), scope.Stack.pop(), scope.Stack.peek() z.MulMod(&x, &y, z) return nil, nil } @@ -189,9 +189,9 @@ func opMulmod(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] // opSHL implements Shift Left // The SHL instruction (shift left) pops 2 values from the stack, first arg1 and then arg2, // and pushes on the stack arg2 shifted to the left by arg1 number of bits. -func opSHL(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSHL(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Note, second operand is left in the stack; accumulate result into it, and no need to push it afterwards - shift, value := callContext.stack.pop(), callContext.stack.peek() + shift, value := scope.Stack.pop(), scope.Stack.peek() if shift.LtUint64(256) { value.Lsh(value, uint(shift.Uint64())) } else { @@ -203,9 +203,9 @@ func opSHL(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt // opSHR implements Logical Shift Right // The SHR instruction (logical shift right) pops 2 values from the stack, first arg1 and then arg2, // and pushes on the stack arg2 shifted to the right by arg1 number of bits with zero fill. -func opSHR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSHR(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Note, second operand is left in the stack; accumulate result into it, and no need to push it afterwards - shift, value := callContext.stack.pop(), callContext.stack.peek() + shift, value := scope.Stack.pop(), scope.Stack.peek() if shift.LtUint64(256) { value.Rsh(value, uint(shift.Uint64())) } else { @@ -217,8 +217,8 @@ func opSHR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt // opSAR implements Arithmetic Shift Right // The SAR instruction (arithmetic shift right) pops 2 values from the stack, first arg1 and then arg2, // and pushes on the stack arg2 shifted to the right by arg1 number of bits with sign extension. -func opSAR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - shift, value := callContext.stack.pop(), callContext.stack.peek() +func opSAR(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + shift, value := scope.Stack.pop(), scope.Stack.peek() if shift.GtUint64(256) { if value.Sign() >= 0 { value.Clear() @@ -233,9 +233,9 @@ func opSAR(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byt return nil, nil } -func opKeccak256(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - offset, size := callContext.stack.pop(), callContext.stack.peek() - data := callContext.memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) +func opKeccak256(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + offset, size := scope.Stack.pop(), scope.Stack.peek() + data := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) if interpreter.hasher == nil { interpreter.hasher = sha3.NewLegacyKeccak256().(keccakState) @@ -253,37 +253,37 @@ func opKeccak256(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) size.SetBytes(interpreter.hasherBuf[:]) return nil, nil } -func opAddress(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(callContext.contract.Address().Bytes())) +func opAddress(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(scope.Contract.Address().Bytes())) return nil, nil } -func opBalance(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - slot := callContext.stack.peek() +func opBalance(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + slot := scope.Stack.peek() address := common.Address(slot.Bytes20()) slot.SetFromBig(interpreter.evm.StateDB.GetBalance(address)) return nil, nil } -func opOrigin(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(interpreter.evm.Origin.Bytes())) +func opOrigin(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(interpreter.evm.Origin.Bytes())) return nil, nil } -func opCaller(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(callContext.contract.Caller().Bytes())) +func opCaller(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(scope.Contract.Caller().Bytes())) return nil, nil } -func opCallValue(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v, _ := uint256.FromBig(callContext.contract.value) - callContext.stack.push(v) +func opCallValue(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + v, _ := uint256.FromBig(scope.Contract.value) + scope.Stack.push(v) return nil, nil } -func opCallDataLoad(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - x := callContext.stack.peek() +func opCallDataLoad(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + x := scope.Stack.peek() if offset, overflow := x.Uint64WithOverflow(); !overflow { - data := getData(callContext.contract.Input, offset, 32) + data := getData(scope.Contract.Input, offset, 32) x.SetBytes(data) } else { x.Clear() @@ -291,16 +291,16 @@ func opCallDataLoad(pc *uint64, interpreter *EVMInterpreter, callContext *callCt return nil, nil } -func opCallDataSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(uint64(len(callContext.contract.Input)))) +func opCallDataSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(uint64(len(scope.Contract.Input)))) return nil, nil } -func opCallDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCallDataCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - memOffset = callContext.stack.pop() - dataOffset = callContext.stack.pop() - length = callContext.stack.pop() + memOffset = scope.Stack.pop() + dataOffset = scope.Stack.pop() + length = scope.Stack.pop() ) dataOffset64, overflow := dataOffset.Uint64WithOverflow() if overflow { @@ -309,21 +309,21 @@ func opCallDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCt // These values are checked for overflow during gas cost calculation memOffset64 := memOffset.Uint64() length64 := length.Uint64() - callContext.memory.Set(memOffset64, length64, getData(callContext.contract.Input, dataOffset64, length64)) + scope.Memory.Set(memOffset64, length64, getData(scope.Contract.Input, dataOffset64, length64)) return nil, nil } -func opReturnDataSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(uint64(len(interpreter.returnData)))) +func opReturnDataSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(uint64(len(interpreter.returnData)))) return nil, nil } -func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - memOffset = callContext.stack.pop() - dataOffset = callContext.stack.pop() - length = callContext.stack.pop() + memOffset = scope.Stack.pop() + dataOffset = scope.Stack.pop() + length = scope.Stack.pop() ) offset64, overflow := dataOffset.Uint64WithOverflow() @@ -337,42 +337,42 @@ func opReturnDataCopy(pc *uint64, interpreter *EVMInterpreter, callContext *call if overflow || uint64(len(interpreter.returnData)) < end64 { return nil, ErrReturnDataOutOfBounds } - callContext.memory.Set(memOffset.Uint64(), length.Uint64(), interpreter.returnData[offset64:end64]) + scope.Memory.Set(memOffset.Uint64(), length.Uint64(), interpreter.returnData[offset64:end64]) return nil, nil } -func opExtCodeSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - slot := callContext.stack.peek() +func opExtCodeSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + slot := scope.Stack.peek() slot.SetUint64(uint64(interpreter.evm.StateDB.GetCodeSize(common.Address(slot.Bytes20())))) return nil, nil } -func opCodeSize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCodeSize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { l := new(uint256.Int) - l.SetUint64(uint64(len(callContext.contract.Code))) - callContext.stack.push(l) + l.SetUint64(uint64(len(scope.Contract.Code))) + scope.Stack.push(l) return nil, nil } -func opCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCodeCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - memOffset = callContext.stack.pop() - codeOffset = callContext.stack.pop() - length = callContext.stack.pop() + memOffset = scope.Stack.pop() + codeOffset = scope.Stack.pop() + length = scope.Stack.pop() ) uint64CodeOffset, overflow := codeOffset.Uint64WithOverflow() if overflow { uint64CodeOffset = 0xffffffffffffffff } - codeCopy := getData(callContext.contract.Code, uint64CodeOffset, length.Uint64()) - callContext.memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) + codeCopy := getData(scope.Contract.Code, uint64CodeOffset, length.Uint64()) + scope.Memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) return nil, nil } -func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - stack = callContext.stack + stack = scope.Stack a = stack.pop() memOffset = stack.pop() codeOffset = stack.pop() @@ -384,7 +384,7 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx } addr := common.Address(a.Bytes20()) codeCopy := getData(interpreter.evm.StateDB.GetCode(addr), uint64CodeOffset, length.Uint64()) - callContext.memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) + scope.Memory.Set(memOffset.Uint64(), length.Uint64(), codeCopy) return nil, nil } @@ -392,16 +392,21 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // opExtCodeHash returns the code hash of a specified account. // There are several cases when the function is called, while we can relay everything // to `state.GetCodeHash` function to ensure the correctness. -// (1) Caller tries to get the code hash of a normal contract account, state +// +// (1) Caller tries to get the code hash of a normal contract account, state +// // should return the relative code hash and set it as the result. // -// (2) Caller tries to get the code hash of a non-existent account, state should +// (2) Caller tries to get the code hash of a non-existent account, state should +// // return common.Hash{} and zero will be set as the result. // -// (3) Caller tries to get the code hash for an account without contract code, +// (3) Caller tries to get the code hash for an account without contract code, +// // state should return emptyCodeHash(0xc5d246...) as the result. // -// (4) Caller tries to get the code hash of a precompiled account, the result +// (4) Caller tries to get the code hash of a precompiled account, the result +// // should be zero or emptyCodeHash. // // It is worth noting that in order to avoid unnecessary create and clean, @@ -410,13 +415,15 @@ func opExtCodeCopy(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx // If the precompile account is not transferred any amount on a private or // customized chain, the return value will be zero. // -// (5) Caller tries to get the code hash for an account which is marked as suicided +// (5) Caller tries to get the code hash for an account which is marked as suicided +// // in the current transaction, the code hash of this account should be returned. // -// (6) Caller tries to get the code hash for an account which is marked as deleted, +// (6) Caller tries to get the code hash for an account which is marked as deleted, +// // this account should be regarded as a non-existent account and zero should be returned. -func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - slot := callContext.stack.peek() +func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + slot := scope.Stack.peek() address := common.Address(slot.Bytes20()) if interpreter.evm.StateDB.Empty(address) { slot.Clear() @@ -426,14 +433,14 @@ func opExtCodeHash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx return nil, nil } -func opGasprice(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opGasprice(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.GasPrice) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opBlockhash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - num := callContext.stack.peek() +func opBlockhash(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + num := scope.Stack.peek() num64, overflow := num.Uint64WithOverflow() if overflow { num.Clear() @@ -454,108 +461,108 @@ func opBlockhash(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) return nil, nil } -func opCoinbase(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetBytes(interpreter.evm.Coinbase.Bytes())) +func opCoinbase(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetBytes(interpreter.evm.Coinbase.Bytes())) return nil, nil } -func opTimestamp(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opTimestamp(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.Time) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opNumber(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opNumber(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.BlockNumber) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opDifficulty(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opDifficulty(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { v, _ := uint256.FromBig(interpreter.evm.Difficulty) - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opRandom(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opRandom(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var v *uint256.Int if interpreter.evm.Context.Random != nil { v = new(uint256.Int).SetBytes((interpreter.evm.Context.Random.Bytes())) } else { // if context random is not set, use emptyCodeHash as default v = new(uint256.Int).SetBytes(emptyCodeHash.Bytes()) } - callContext.stack.push(v) + scope.Stack.push(v) return nil, nil } -func opGasLimit(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(interpreter.evm.GasLimit)) +func opGasLimit(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(interpreter.evm.GasLimit)) return nil, nil } -func opPop(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.pop() +func opPop(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.pop() return nil, nil } -func opMload(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - v := callContext.stack.peek() +func opMload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + v := scope.Stack.peek() offset := int64(v.Uint64()) - v.SetBytes(callContext.memory.GetPtr(offset, 32)) + v.SetBytes(scope.Memory.GetPtr(offset, 32)) return nil, nil } -func opMstore(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opMstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // pop value of the stack - mStart, val := callContext.stack.pop(), callContext.stack.pop() - callContext.memory.Set32(mStart.Uint64(), &val) + mStart, val := scope.Stack.pop(), scope.Stack.pop() + scope.Memory.Set32(mStart.Uint64(), &val) return nil, nil } -func opMstore8(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - off, val := callContext.stack.pop(), callContext.stack.pop() - callContext.memory.store[off.Uint64()] = byte(val.Uint64()) +func opMstore8(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + off, val := scope.Stack.pop(), scope.Stack.pop() + scope.Memory.store[off.Uint64()] = byte(val.Uint64()) return nil, nil } -func opSload(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - loc := callContext.stack.peek() +func opSload(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + loc := scope.Stack.peek() hash := common.Hash(loc.Bytes32()) - val := interpreter.evm.StateDB.GetState(callContext.contract.Address(), hash) + val := interpreter.evm.StateDB.GetState(scope.Contract.Address(), hash) loc.SetBytes(val.Bytes()) return nil, nil } -func opSstore(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSstore(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } - loc := callContext.stack.pop() - val := callContext.stack.pop() - interpreter.evm.StateDB.SetState(callContext.contract.Address(), + loc := scope.Stack.pop() + val := scope.Stack.pop() + interpreter.evm.StateDB.SetState(scope.Contract.Address(), common.Hash(loc.Bytes32()), common.Hash(val.Bytes32())) return nil, nil } -func opJump(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opJump(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if atomic.LoadInt32(&interpreter.evm.abort) != 0 { return nil, errStopToken } - pos := callContext.stack.pop() - if !callContext.contract.validJumpdest(&pos) { + pos := scope.Stack.pop() + if !scope.Contract.validJumpdest(&pos) { return nil, ErrInvalidJump } *pc = pos.Uint64() - 1 // pc will be increased by the interpreter loop return nil, nil } -func opJumpi(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opJumpi(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if atomic.LoadInt32(&interpreter.evm.abort) != 0 { return nil, errStopToken } - pos, cond := callContext.stack.pop(), callContext.stack.pop() + pos, cond := scope.Stack.pop(), scope.Stack.pop() if !cond.IsZero() { - if !callContext.contract.validJumpdest(&pos) { + if !scope.Contract.validJumpdest(&pos) { return nil, ErrInvalidJump } *pc = pos.Uint64() - 1 // pc will be increased by the interpreter loop @@ -563,34 +570,34 @@ func opJumpi(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]b return nil, nil } -func opJumpdest(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opJumpdest(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { return nil, nil } -func opPc(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(*pc)) +func opPc(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(*pc)) return nil, nil } -func opMsize(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(uint64(callContext.memory.Len()))) +func opMsize(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(uint64(scope.Memory.Len()))) return nil, nil } -func opGas(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.push(new(uint256.Int).SetUint64(callContext.contract.Gas)) +func opGas(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.push(new(uint256.Int).SetUint64(scope.Contract.Gas)) return nil, nil } -func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCreate(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } var ( - value = callContext.stack.pop() - offset, size = callContext.stack.pop(), callContext.stack.pop() - input = callContext.memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) - gas = callContext.contract.Gas + value = scope.Stack.pop() + offset, size = scope.Stack.pop(), scope.Stack.pop() + input = scope.Memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) + gas = scope.Contract.Gas ) if interpreter.evm.chainRules.IsEIP150 { gas -= gas / 64 @@ -598,8 +605,8 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] // reuse size int for stackvalue stackvalue := size - callContext.contract.UseGas(gas) - res, addr, returnGas, suberr := interpreter.evm.Create(callContext.contract, input, gas, value.ToBig()) + scope.Contract.UseGas(gas) + res, addr, returnGas, suberr := interpreter.evm.Create(scope.Contract, input, gas, value.ToBig()) // Push item on the stack based on the returned error. If the ruleset is // homestead we must check for CodeStoreOutOfGasError (homestead only // rule) and treat as an error, if the ruleset is frontier we must @@ -611,8 +618,8 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] } else { stackvalue.SetBytes(addr.Bytes()) } - callContext.stack.push(&stackvalue) - callContext.contract.Gas += returnGas + scope.Stack.push(&stackvalue) + scope.Contract.Gas += returnGas if suberr == ErrExecutionReverted { interpreter.returnData = res // set REVERT data to return data buffer @@ -622,24 +629,24 @@ func opCreate(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([] return nil, nil } -func opCreate2(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCreate2(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } var ( - endowment = callContext.stack.pop() - offset, size = callContext.stack.pop(), callContext.stack.pop() - salt = callContext.stack.pop() - input = callContext.memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) - gas = callContext.contract.Gas + endowment = scope.Stack.pop() + offset, size = scope.Stack.pop(), scope.Stack.pop() + salt = scope.Stack.pop() + input = scope.Memory.GetCopy(int64(offset.Uint64()), int64(size.Uint64())) + gas = scope.Contract.Gas ) // Apply EIP150 gas -= gas / 64 - callContext.contract.UseGas(gas) + scope.Contract.UseGas(gas) // reuse size int for stackvalue stackvalue := size - res, addr, returnGas, suberr := interpreter.evm.Create2(callContext.contract, input, gas, + res, addr, returnGas, suberr := interpreter.evm.Create2(scope.Contract, input, gas, endowment.ToBig(), salt.ToBig()) // Push item on the stack based on the returned error. if suberr != nil { @@ -647,8 +654,8 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([ } else { stackvalue.SetBytes(addr.Bytes()) } - callContext.stack.push(&stackvalue) - callContext.contract.Gas += returnGas + scope.Stack.push(&stackvalue) + scope.Contract.Gas += returnGas if suberr == ErrExecutionReverted { interpreter.returnData = res // set REVERT data to return data buffer @@ -658,8 +665,8 @@ func opCreate2(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([ return nil, nil } -func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - stack := callContext.stack +func opCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + stack := scope.Stack // Pop gas. The actual gas in interpreter.evm.callGasTemp. // We can use this as a temporary value temp := stack.pop() @@ -668,7 +675,7 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by addr, value, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get the arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) if interpreter.readOnly && !value.IsZero() { return nil, ErrWriteProtection @@ -676,7 +683,7 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by if !value.IsZero() { gas += params.CallStipend } - ret, returnGas, err := interpreter.evm.Call(callContext.contract, toAddr, args, gas, value.ToBig()) + ret, returnGas, err := interpreter.evm.Call(scope.Contract, toAddr, args, gas, value.ToBig()) if err != nil { temp.Clear() } else { @@ -685,17 +692,17 @@ func opCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]by stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opCallCode(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opCallCode(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Pop gas. The actual gas is in interpreter.evm.callGasTemp. - stack := callContext.stack + stack := scope.Stack // We use it as a temporary value temp := stack.pop() gas := interpreter.evm.callGasTemp @@ -703,12 +710,12 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ( addr, value, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) if !value.IsZero() { gas += params.CallStipend } - ret, returnGas, err := interpreter.evm.CallCode(callContext.contract, toAddr, args, gas, value.ToBig()) + ret, returnGas, err := interpreter.evm.CallCode(scope.Contract, toAddr, args, gas, value.ToBig()) if err != nil { temp.Clear() } else { @@ -717,16 +724,16 @@ func opCallCode(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ( stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - stack := callContext.stack +func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + stack := scope.Stack // Pop gas. The actual gas is in interpreter.evm.callGasTemp. // We use it as a temporary value temp := stack.pop() @@ -735,9 +742,9 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCt addr, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) - ret, returnGas, err := interpreter.evm.DelegateCall(callContext.contract, toAddr, args, gas) + ret, returnGas, err := interpreter.evm.DelegateCall(scope.Contract, toAddr, args, gas) if err != nil { temp.Clear() } else { @@ -746,17 +753,17 @@ func opDelegateCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCt stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opStaticCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opStaticCall(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { // Pop gas. The actual gas is in interpreter.evm.callGasTemp. - stack := callContext.stack + stack := scope.Stack // We use it as a temporary value temp := stack.pop() gas := interpreter.evm.callGasTemp @@ -764,9 +771,9 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) addr, inOffset, inSize, retOffset, retSize := stack.pop(), stack.pop(), stack.pop(), stack.pop(), stack.pop() toAddr := common.Address(addr.Bytes20()) // Get arguments from the memory. - args := callContext.memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) + args := scope.Memory.GetPtr(int64(inOffset.Uint64()), int64(inSize.Uint64())) - ret, returnGas, err := interpreter.evm.StaticCall(callContext.contract, toAddr, args, gas) + ret, returnGas, err := interpreter.evm.StaticCall(scope.Contract, toAddr, args, gas) if err != nil { temp.Clear() } else { @@ -775,45 +782,45 @@ func opStaticCall(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) stack.push(&temp) if err == nil || err == ErrExecutionReverted { ret = common.CopyBytes(ret) - callContext.memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) + scope.Memory.Set(retOffset.Uint64(), retSize.Uint64(), ret) } - callContext.contract.Gas += returnGas + scope.Contract.Gas += returnGas interpreter.returnData = ret return ret, nil } -func opReturn(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - offset, size := callContext.stack.pop(), callContext.stack.pop() - ret := callContext.memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) +func opReturn(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + offset, size := scope.Stack.pop(), scope.Stack.pop() + ret := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) return ret, errStopToken } -func opRevert(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - offset, size := callContext.stack.pop(), callContext.stack.pop() - ret := callContext.memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) +func opRevert(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + offset, size := scope.Stack.pop(), scope.Stack.pop() + ret := scope.Memory.GetPtr(int64(offset.Uint64()), int64(size.Uint64())) interpreter.returnData = ret return ret, ErrExecutionReverted } -func opUndefined(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - return nil, &ErrInvalidOpCode{opcode: OpCode(callContext.contract.Code[*pc])} +func opUndefined(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + return nil, &ErrInvalidOpCode{opcode: OpCode(scope.Contract.Code[*pc])} } -func opStop(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opStop(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { return nil, errStopToken } -func opSelfdestruct(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opSelfdestruct(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } - beneficiary := callContext.stack.pop() - balance := interpreter.evm.StateDB.GetBalance(callContext.contract.Address()) + beneficiary := scope.Stack.pop() + balance := interpreter.evm.StateDB.GetBalance(scope.Contract.Address()) interpreter.evm.StateDB.AddBalance(common.Address(beneficiary.Bytes20()), balance) - interpreter.evm.StateDB.Suicide(callContext.contract.Address()) + interpreter.evm.StateDB.Suicide(scope.Contract.Address()) return nil, errStopToken } @@ -821,21 +828,21 @@ func opSelfdestruct(pc *uint64, interpreter *EVMInterpreter, callContext *callCt // make log instruction function func makeLog(size int) executionFunc { - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { if interpreter.readOnly { return nil, ErrWriteProtection } topics := make([]common.Hash, size) - stack := callContext.stack + stack := scope.Stack mStart, mSize := stack.pop(), stack.pop() for i := 0; i < size; i++ { addr := stack.pop() topics[i] = common.Hash(addr.Bytes32()) } - d := callContext.memory.GetCopy(int64(mStart.Uint64()), int64(mSize.Uint64())) + d := scope.Memory.GetCopy(int64(mStart.Uint64()), int64(mSize.Uint64())) interpreter.evm.StateDB.AddLog(&types.Log{ - Address: callContext.contract.Address(), + Address: scope.Contract.Address(), Topics: topics, Data: d, // This is a non-consensus field, but assigned here because @@ -848,24 +855,24 @@ func makeLog(size int) executionFunc { } // opPush1 is a specialized version of pushN -func opPush1(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { +func opPush1(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { var ( - codeLen = uint64(len(callContext.contract.Code)) + codeLen = uint64(len(scope.Contract.Code)) integer = new(uint256.Int) ) *pc += 1 if *pc < codeLen { - callContext.stack.push(integer.SetUint64(uint64(callContext.contract.Code[*pc]))) + scope.Stack.push(integer.SetUint64(uint64(scope.Contract.Code[*pc]))) } else { - callContext.stack.push(integer.Clear()) + scope.Stack.push(integer.Clear()) } return nil, nil } // make push instruction function func makePush(size uint64, pushByteSize int) executionFunc { - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - codeLen := len(callContext.contract.Code) + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + codeLen := len(scope.Contract.Code) startMin := codeLen if int(*pc+1) < startMin { @@ -878,8 +885,8 @@ func makePush(size uint64, pushByteSize int) executionFunc { } integer := new(uint256.Int) - callContext.stack.push(integer.SetBytes(common.RightPadBytes( - callContext.contract.Code[startMin:endMin], pushByteSize))) + scope.Stack.push(integer.SetBytes(common.RightPadBytes( + scope.Contract.Code[startMin:endMin], pushByteSize))) *pc += size return nil, nil @@ -888,8 +895,8 @@ func makePush(size uint64, pushByteSize int) executionFunc { // make dup instruction function func makeDup(size int64) executionFunc { - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.dup(int(size)) + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.dup(int(size)) return nil, nil } } @@ -898,8 +905,8 @@ func makeDup(size int64) executionFunc { func makeSwap(size int64) executionFunc { // switch n + 1 otherwise n would be swapped with n size++ - return func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) { - callContext.stack.swap(int(size)) + return func(pc *uint64, interpreter *EVMInterpreter, scope *ScopeContext) ([]byte, error) { + scope.Stack.swap(int(size)) return nil, nil } } diff --git a/core/vm/instructions_test.go b/core/vm/instructions_test.go index 200aa1477235..580a008b4e72 100644 --- a/core/vm/instructions_test.go +++ b/core/vm/instructions_test.go @@ -107,7 +107,7 @@ func testTwoOperandOp(t *testing.T, tests []TwoOperandTestcase, opFn executionFu expected := new(uint256.Int).SetBytes(common.Hex2Bytes(test.Expected)) stack.push(x) stack.push(y) - opFn(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + opFn(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) if len(stack.data) != 1 { t.Errorf("Expected one item on stack after %v, got %d: ", name, len(stack.data)) } @@ -222,7 +222,7 @@ func TestAddMod(t *testing.T) { stack.push(z) stack.push(y) stack.push(x) - opAddmod(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + opAddmod(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) actual := stack.pop() if actual.Cmp(expected) != 0 { t.Errorf("Testcase %d, expected %x, got %x", i, expected, actual) @@ -244,7 +244,7 @@ func getResult(args []*twoOperandParams, opFn executionFunc) []TwoOperandTestcas y := new(uint256.Int).SetBytes(common.Hex2Bytes(param.y)) stack.push(x) stack.push(y) - _, err := opFn(&pc, interpreter, &callCtx{nil, stack, nil}) + _, err := opFn(&pc, interpreter, &ScopeContext{nil, stack, nil}) if err != nil { log.Fatalln(err) } @@ -308,7 +308,7 @@ func opBenchmark(bench *testing.B, op executionFunc, args ...string) { a.SetBytes(arg) stack.push(a) } - op(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + op(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) stack.pop() } } @@ -535,13 +535,13 @@ func TestOpMstore(t *testing.T) { v := "abcdef00000000000000abba000000000deaf000000c0de00100000000133700" stack.push(new(uint256.Int).SetBytes(common.Hex2Bytes(v))) stack.push(new(uint256.Int)) - opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) if got := common.Bytes2Hex(mem.GetCopy(0, 32)); got != v { t.Fatalf("Mstore fail, got %v, expected %v", got, v) } stack.push(new(uint256.Int).SetUint64(0x1)) stack.push(new(uint256.Int)) - opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) if common.Bytes2Hex(mem.GetCopy(0, 32)) != "0000000000000000000000000000000000000000000000000000000000000001" { t.Fatalf("Mstore failed to overwrite previous value") } @@ -565,7 +565,7 @@ func BenchmarkOpMstore(bench *testing.B) { for i := 0; i < bench.N; i++ { stack.push(value) stack.push(memStart) - opMstore(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opMstore(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) } } @@ -585,7 +585,7 @@ func BenchmarkOpKeccak256(bench *testing.B) { for i := 0; i < bench.N; i++ { stack.push(uint256.NewInt(32)) stack.push(start) - opKeccak256(&pc, evmInterpreter, &callCtx{mem, stack, nil}) + opKeccak256(&pc, evmInterpreter, &ScopeContext{mem, stack, nil}) } } @@ -681,7 +681,7 @@ func TestRandom(t *testing.T) { pc = uint64(0) evmInterpreter = NewEVMInterpreter(env, env.vmConfig) ) - opRandom(&pc, evmInterpreter, &callCtx{nil, stack, nil}) + opRandom(&pc, evmInterpreter, &ScopeContext{nil, stack, nil}) if len(stack.data) != 1 { t.Errorf("Expected one item on stack after %v, got %d: ", tt.name, len(stack.data)) } diff --git a/core/vm/interface.go b/core/vm/interface.go index 81549a9206e0..903a957a2017 100644 --- a/core/vm/interface.go +++ b/core/vm/interface.go @@ -57,6 +57,16 @@ type StateDB interface { // is defined according to EIP161 (balance = nonce = code = 0). Empty(common.Address) bool + PrepareAccessList(sender common.Address, dest *common.Address, precompiles []common.Address, txAccesses types.AccessList) + AddressInAccessList(addr common.Address) bool + SlotInAccessList(addr common.Address, slot common.Hash) (addressOk bool, slotOk bool) + // AddAddressToAccessList adds the given address to the access list. This operation is safe to perform + // even if the feature/fork is not active yet + AddAddressToAccessList(addr common.Address) + // AddSlotToAccessList adds the given (address,slot) to the access list. This operation is safe to perform + // even if the feature/fork is not active yet + AddSlotToAccessList(addr common.Address, slot common.Hash) + RevertToSnapshot(int) Snapshot() int diff --git a/core/vm/interpreter.go b/core/vm/interpreter.go index 23cc5a03fe3d..9d16be1a664c 100644 --- a/core/vm/interpreter.go +++ b/core/vm/interpreter.go @@ -60,12 +60,12 @@ type Interpreter interface { CanRun([]byte) bool } -// callCtx contains the things that are per-call, such as stack and memory, +// ScopeContext contains the things that are per-call, such as stack and memory, // but not transients like pc and gas -type callCtx struct { - memory *Memory - stack *Stack - contract *Contract +type ScopeContext struct { + Memory *Memory + Stack *Stack + Contract *Contract } // keccakState wraps sha3.state. In addition to the usual hash methods, it also supports @@ -93,6 +93,8 @@ func NewEVMInterpreter(evm *EVM, cfg Config) *EVMInterpreter { // If jump table was not initialised we set the default one. if cfg.JumpTable == nil { switch { + case evm.chainRules.IsEIP1559: + cfg.JumpTable = &eip1559InstructionSet case evm.chainRules.IsShanghai: cfg.JumpTable = &shanghaiInstructionSet case evm.chainRules.IsMerge: @@ -168,10 +170,10 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( op OpCode // current opcode mem = NewMemory() // bound memory stack = newstack() // local stack - callContext = &callCtx{ - memory: mem, - stack: stack, - contract: contract, + callContext = &ScopeContext{ + Memory: mem, + Stack: stack, + Contract: contract, } // For optimisation reason we're using uint64 as the program counter. // It's theoretically possible to go above 2^64. The YP defines the PC @@ -190,9 +192,9 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( defer func() { if err != nil { if !logged { - in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureState(in.evm, pcCopy, op, gasCopy, cost, callContext, in.returnData, in.evm.depth, err) } else { - in.cfg.Tracer.CaptureFault(in.evm, pcCopy, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureFault(in.evm, pcCopy, op, gasCopy, cost, callContext, in.evm.depth, err) } } }() @@ -255,7 +257,7 @@ func (in *EVMInterpreter) Run(contract *Contract, input []byte, readOnly bool) ( } if in.cfg.Debug { - in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, mem, stack, contract, in.evm.depth, err) + in.cfg.Tracer.CaptureState(in.evm, pc, op, gasCopy, cost, callContext, in.returnData, in.evm.depth, err) logged = true } diff --git a/core/vm/jump_table.go b/core/vm/jump_table.go index 2302a439df6e..4ae7c1e8ce1e 100644 --- a/core/vm/jump_table.go +++ b/core/vm/jump_table.go @@ -21,7 +21,7 @@ import ( ) type ( - executionFunc func(pc *uint64, interpreter *EVMInterpreter, callContext *callCtx) ([]byte, error) + executionFunc func(pc *uint64, interpreter *EVMInterpreter, callContext *ScopeContext) ([]byte, error) gasFunc func(*EVM, *Contract, *Stack, *Memory, uint64) (uint64, error) // last parameter is the requested memory size as a uint64 // memorySizeFunc returns the required size, and whether the operation overflowed a uint64 memorySizeFunc func(*Stack) (size uint64, overflow bool) @@ -54,11 +54,18 @@ var ( londonInstructionSet = newLondonInstructionSet() mergeInstructionSet = newMergeInstructionSet() shanghaiInstructionSet = newShanghaiInstructionSet() + eip1559InstructionSet = newEip1559InstructionSet() ) // JumpTable contains the EVM opcodes supported at a given fork. type JumpTable [256]*operation +func newEip1559InstructionSet() JumpTable { + instructionSet := newShanghaiInstructionSet() + enable2929(&instructionSet) // Gas cost increases for state access opcodes https://eips.ethereum.org/EIPS/eip-2929 + return instructionSet +} + func newShanghaiInstructionSet() JumpTable { instructionSet := newMergeInstructionSet() enable3855(&instructionSet) // PUSH0 instruction diff --git a/core/vm/logger.go b/core/vm/logger.go index 6cdb7a7c5bed..a6a995eef942 100644 --- a/core/vm/logger.go +++ b/core/vm/logger.go @@ -18,20 +18,19 @@ package vm import ( "encoding/hex" - "errors" "fmt" "io" "math/big" + "strings" "time" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/params" ) -var errTraceLimitReached = errors.New("the number of logs reached the specified limit") - // Storage represents a contract's storage. type Storage map[common.Hash]common.Hash @@ -52,6 +51,9 @@ type LogConfig struct { DisableStorage bool // disable storage capture Debug bool // print output during capture end Limit int // maximum length of output, but zero means unlimited + + // Chain overrides, can be used to execute a trace using future fork rules + Overrides *params.ChainConfig `json:"overrides,omitempty"` } //go:generate gencodec -type StructLog -field-override structLogMarshaling -out gen_structlog.go @@ -101,10 +103,10 @@ func (s *StructLog) ErrorString() string { // Note that reference types are actual VM data structures; make copies // if you need to retain them beyond the current call. type Tracer interface { - CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error - CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error - CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error - CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error + CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) + CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) + CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) + CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) } // StructLogger is an EVM state logger and implements Tracer. @@ -133,17 +135,19 @@ func NewStructLogger(cfg *LogConfig) *StructLogger { } // CaptureStart implements the Tracer interface to initialize the tracing operation. -func (l *StructLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { - return nil +func (l *StructLogger) CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { } // CaptureState logs a new structured log message and pushes it out to the environment // -// CaptureState also tracks SSTORE ops to track dirty values. -func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { +// CaptureState also tracks SLOAD/SSTORE ops to track storage change. +func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + memory := scope.Memory + stack := scope.Stack + contract := scope.Contract // check if already accumulated the specified number of logs if l.cfg.Limit != 0 && l.cfg.Limit <= len(l.logs) { - return errTraceLimitReached + return } // initialise new changed values storage container for this contract @@ -184,17 +188,15 @@ func (l *StructLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost ui log := StructLog{pc, op, gas, cost, mem, memory.Len(), stck, storage, depth, env.StateDB.GetRefund(), err} l.logs = append(l.logs, log) - return nil } // CaptureFault implements the Tracer interface to trace an execution fault // while running an opcode. -func (l *StructLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { - return nil +func (l *StructLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) { } // CaptureEnd is called after the call finishes to finalize the tracing. -func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { l.output = output l.err = err if l.cfg.Debug { @@ -203,7 +205,6 @@ func (l *StructLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration fmt.Printf(" error: %v\n", err) } } - return nil } // StructLogs returns the captured log entries. @@ -257,3 +258,65 @@ func WriteLogs(writer io.Writer, logs []*types.Log) { fmt.Fprintln(writer) } } + +type mdLogger struct { + out io.Writer + cfg *LogConfig +} + +// NewMarkdownLogger creates a logger which outputs information in a format adapted +// for human readability, and is also a valid markdown table +func NewMarkdownLogger(cfg *LogConfig, writer io.Writer) *mdLogger { + l := &mdLogger{writer, cfg} + if l.cfg == nil { + l.cfg = &LogConfig{} + } + return l +} + +func (t *mdLogger) CaptureStart(env *EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { + if !create { + fmt.Fprintf(t.out, "From: `%v`\nTo: `%v`\nData: `0x%x`\nGas: `%d`\nValue `%v` wei\n", + from.String(), to.String(), + input, gas, value) + } else { + fmt.Fprintf(t.out, "From: `%v`\nCreate at: `%v`\nData: `0x%x`\nGas: `%d`\nValue `%v` wei\n", + from.String(), to.String(), + input, gas, value) + } + + fmt.Fprintf(t.out, ` +| Pc | Op | Cost | Stack | RStack | Refund | +|-------|-------------|------|-----------|-----------|---------| +`) +} + +// CaptureState also tracks SLOAD/SSTORE ops to track storage change. +func (t *mdLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + stack := scope.Stack + fmt.Fprintf(t.out, "| %4d | %10v | %3d |", pc, op, cost) + + if !t.cfg.DisableStack { + // format stack + var a []string + for _, elem := range stack.data { + a = append(a, fmt.Sprintf("%v", elem.String())) + } + b := fmt.Sprintf("[%v]", strings.Join(a, ",")) + fmt.Fprintf(t.out, "%10v |", b) + } + fmt.Fprintf(t.out, "%10v |", env.StateDB.GetRefund()) + fmt.Fprintln(t.out, "") + if err != nil { + fmt.Fprintf(t.out, "Error: %v\n", err) + } +} + +func (t *mdLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, depth int, err error) { + fmt.Fprintf(t.out, "\nError: at pc=%d, op=%v: %v\n", pc, op, err) +} + +func (t *mdLogger) CaptureEnd(output []byte, gasUsed uint64, tm time.Duration, err error) { + fmt.Fprintf(t.out, "\nOutput: `0x%x`\nConsumed gas: `%d`\nError: `%v`\n", + output, gasUsed, err) +} diff --git a/core/vm/logger_json.go b/core/vm/logger_json.go index 4a63120a3df4..64bb3a9fe9e1 100644 --- a/core/vm/logger_json.go +++ b/core/vm/logger_json.go @@ -41,12 +41,16 @@ func NewJSONLogger(cfg *LogConfig, writer io.Writer) *JSONLogger { return l } -func (l *JSONLogger) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { - return nil +func (l *JSONLogger) CaptureStart(env *EVM, from, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { } +func (l *JSONLogger) CaptureFault(*EVM, uint64, OpCode, uint64, uint64, *ScopeContext, int, error) {} + // CaptureState outputs state information on the logger. -func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { +func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint64, scope *ScopeContext, rData []byte, depth int, err error) { + memory := scope.Memory + stack := scope.Stack + log := StructLog{ Pc: pc, Op: op, @@ -69,16 +73,11 @@ func (l *JSONLogger) CaptureState(env *EVM, pc uint64, op OpCode, gas, cost uint } log.Stack = logstack } - return l.encoder.Encode(log) -} - -// CaptureFault outputs state information on the logger. -func (l *JSONLogger) CaptureFault(env *EVM, pc uint64, op OpCode, gas, cost uint64, memory *Memory, stack *Stack, contract *Contract, depth int, err error) error { - return l.CaptureState(env, pc, op, gas, cost, memory, stack, contract, depth, err) + l.encoder.Encode(log) } // CaptureEnd is triggered at end of execution. -func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { type endLog struct { Output string `json:"output"` GasUsed math.HexOrDecimal64 `json:"gasUsed"` @@ -89,5 +88,5 @@ func (l *JSONLogger) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, if err != nil { errMsg = err.Error() } - return l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) + l.encoder.Encode(endLog{common.Bytes2Hex(output), math.HexOrDecimal64(gasUsed), t, errMsg}) } diff --git a/core/vm/logger_test.go b/core/vm/logger_test.go index f6c15d9cbb0f..ef5fb125100c 100644 --- a/core/vm/logger_test.go +++ b/core/vm/logger_test.go @@ -52,14 +52,17 @@ func TestStoreCapture(t *testing.T) { var ( env = NewEVM(Context{}, &dummyStatedb{}, nil, params.TestChainConfig, Config{}) logger = NewStructLogger(nil) - mem = NewMemory() - stack = newstack() contract = NewContract(&dummyContractRef{}, &dummyContractRef{}, new(big.Int), 0) + scope = &ScopeContext{ + Memory: NewMemory(), + Stack: newstack(), + Contract: contract, + } ) - stack.push(uint256.NewInt(1)) - stack.push(new(uint256.Int)) + scope.Stack.push(uint256.NewInt(1)) + scope.Stack.push(new(uint256.Int)) var index common.Hash - logger.CaptureState(env, 0, SSTORE, 0, 0, mem, stack, contract, 0, nil) + logger.CaptureState(env, 0, SSTORE, 0, 0, scope, nil, 0, nil) if len(logger.changedValues[contract.Address()]) == 0 { t.Fatalf("expected exactly 1 changed value on address %x, got %d", contract.Address(), len(logger.changedValues[contract.Address()])) } diff --git a/core/vm/operations_acl.go b/core/vm/operations_acl.go new file mode 100644 index 000000000000..eb3c0f43dd3a --- /dev/null +++ b/core/vm/operations_acl.go @@ -0,0 +1,221 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package vm + +import ( + "errors" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/XinFinOrg/XDPoSChain/params" +) + +const ( + ColdAccountAccessCostEIP2929 = uint64(2600) // COLD_ACCOUNT_ACCESS_COST + ColdSloadCostEIP2929 = uint64(2100) // COLD_SLOAD_COST + WarmStorageReadCostEIP2929 = uint64(100) // WARM_STORAGE_READ_COST +) + +// gasSStoreEIP2929 implements gas cost for SSTORE according to EIP-2929" +// +// When calling SSTORE, check if the (address, storage_key) pair is in accessed_storage_keys. +// If it is not, charge an additional COLD_SLOAD_COST gas, and add the pair to accessed_storage_keys. +// Additionally, modify the parameters defined in EIP 2200 as follows: +// +// Parameter Old value New value +// SLOAD_GAS 800 = WARM_STORAGE_READ_COST +// SSTORE_RESET_GAS 5000 5000 - COLD_SLOAD_COST +// +// The other parameters defined in EIP 2200 are unchanged. +// see gasSStoreEIP2200(...) in core/vm/gas_table.go for more info about how EIP 2200 is specified +func gasSStoreEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + // If we fail the minimum gas availability invariant, fail (0) + if contract.Gas <= params.SstoreSentryGasEIP2200 { + return 0, errors.New("not enough gas for reentrancy sentry") + } + // Gas sentry honoured, do the actual gas calculation based on the stored value + var ( + y, x = stack.Back(1), stack.peek() + slot = common.Hash(x.Bytes32()) + current = evm.StateDB.GetState(contract.Address(), slot) + cost = uint64(0) + ) + // Check slot presence in the access list + if addrPresent, slotPresent := evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent { + cost = ColdSloadCostEIP2929 + // If the caller cannot afford the cost, this change will be rolled back + evm.StateDB.AddSlotToAccessList(contract.Address(), slot) + if !addrPresent { + // Once we're done with YOLOv2 and schedule this for mainnet, might + // be good to remove this panic here, which is just really a + // canary to have during testing + panic("impossible case: address was not present in access list during sstore op") + } + } + value := common.Hash(y.Bytes32()) + + if current == value { // noop (1) + // EIP 2200 original clause: + // return params.SloadGasEIP2200, nil + return cost + WarmStorageReadCostEIP2929, nil // SLOAD_GAS + } + original := evm.StateDB.GetCommittedState(contract.Address(), common.Hash(x.Bytes32())) + if original == current { + if original == (common.Hash{}) { // create slot (2.1.1) + return cost + params.SstoreSetGasEIP2200, nil + } + if value == (common.Hash{}) { // delete slot (2.1.2b) + evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) + } + // EIP-2200 original clause: + // return params.SstoreResetGasEIP2200, nil // write existing slot (2.1.2) + return cost + (params.SstoreResetGasEIP2200 - ColdSloadCostEIP2929), nil // write existing slot (2.1.2) + } + if original != (common.Hash{}) { + if current == (common.Hash{}) { // recreate slot (2.2.1.1) + evm.StateDB.SubRefund(params.SstoreClearsScheduleRefundEIP2200) + } else if value == (common.Hash{}) { // delete slot (2.2.1.2) + evm.StateDB.AddRefund(params.SstoreClearsScheduleRefundEIP2200) + } + } + if original == value { + if original == (common.Hash{}) { // reset to original inexistent slot (2.2.2.1) + // EIP 2200 Original clause: + //evm.StateDB.AddRefund(params.SstoreSetGasEIP2200 - params.SloadGasEIP2200) + evm.StateDB.AddRefund(params.SstoreSetGasEIP2200 - WarmStorageReadCostEIP2929) + } else { // reset to original existing slot (2.2.2.2) + // EIP 2200 Original clause: + // evm.StateDB.AddRefund(params.SstoreResetGasEIP2200 - params.SloadGasEIP2200) + // - SSTORE_RESET_GAS redefined as (5000 - COLD_SLOAD_COST) + // - SLOAD_GAS redefined as WARM_STORAGE_READ_COST + // Final: (5000 - COLD_SLOAD_COST) - WARM_STORAGE_READ_COST + evm.StateDB.AddRefund((params.SstoreResetGasEIP2200 - ColdSloadCostEIP2929) - WarmStorageReadCostEIP2929) + } + } + // EIP-2200 original clause: + //return params.SloadGasEIP2200, nil // dirty update (2.2) + return cost + WarmStorageReadCostEIP2929, nil // dirty update (2.2) +} + +// gasSLoadEIP2929 calculates dynamic gas for SLOAD according to EIP-2929 +// For SLOAD, if the (address, storage_key) pair (where address is the address of the contract +// whose storage is being read) is not yet in accessed_storage_keys, +// charge 2100 gas and add the pair to accessed_storage_keys. +// If the pair is already in accessed_storage_keys, charge 100 gas. +func gasSLoadEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + loc := stack.peek() + slot := common.Hash(loc.Bytes32()) + // Check slot presence in the access list + if _, slotPresent := evm.StateDB.SlotInAccessList(contract.Address(), slot); !slotPresent { + // If the caller cannot afford the cost, this change will be rolled back + // If he does afford it, we can skip checking the same thing later on, during execution + evm.StateDB.AddSlotToAccessList(contract.Address(), slot) + return ColdSloadCostEIP2929, nil + } + return WarmStorageReadCostEIP2929, nil +} + +// gasExtCodeCopyEIP2929 implements extcodecopy according to EIP-2929 +// EIP spec: +// > If the target is not in accessed_addresses, +// > charge COLD_ACCOUNT_ACCESS_COST gas, and add the address to accessed_addresses. +// > Otherwise, charge WARM_STORAGE_READ_COST gas. +func gasExtCodeCopyEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + // memory expansion first (dynamic part of pre-2929 implementation) + gas, err := gasExtCodeCopy(evm, contract, stack, mem, memorySize) + if err != nil { + return 0, err + } + addr := common.Address(stack.peek().Bytes20()) + // Check slot presence in the access list + if !evm.StateDB.AddressInAccessList(addr) { + evm.StateDB.AddAddressToAccessList(addr) + var overflow bool + // We charge (cold-warm), since 'warm' is already charged as constantGas + if gas, overflow = math.SafeAdd(gas, ColdAccountAccessCostEIP2929-WarmStorageReadCostEIP2929); overflow { + return 0, ErrGasUintOverflow + } + return gas, nil + } + return gas, nil +} + +// gasEip2929AccountCheck checks whether the first stack item (as address) is present in the access list. +// If it is, this method returns '0', otherwise 'cold-warm' gas, presuming that the opcode using it +// is also using 'warm' as constant factor. +// This method is used by: +// - extcodehash, +// - extcodesize, +// - (ext) balance +func gasEip2929AccountCheck(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + addr := common.Address(stack.peek().Bytes20()) + // Check slot presence in the access list + if !evm.StateDB.AddressInAccessList(addr) { + // If the caller cannot afford the cost, this change will be rolled back + evm.StateDB.AddAddressToAccessList(addr) + // The warm storage read cost is already charged as constantGas + return ColdAccountAccessCostEIP2929 - WarmStorageReadCostEIP2929, nil + } + return 0, nil +} + +func makeCallVariantGasCallEIP2929(oldCalculator gasFunc) gasFunc { + return func(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + addr := common.Address(stack.Back(1).Bytes20()) + // Check slot presence in the access list + if !evm.StateDB.AddressInAccessList(addr) { + evm.StateDB.AddAddressToAccessList(addr) + // The WarmStorageReadCostEIP2929 (100) is already deducted in the form of a constant cost + if !contract.UseGas(ColdAccountAccessCostEIP2929 - WarmStorageReadCostEIP2929) { + return 0, ErrOutOfGas + } + } + // Now call the old calculator, which takes into account + // - create new account + // - transfer value + // - memory expansion + // - 63/64ths rule + return oldCalculator(evm, contract, stack, mem, memorySize) + } +} + +var ( + gasCallEIP2929 = makeCallVariantGasCallEIP2929(gasCall) + gasDelegateCallEIP2929 = makeCallVariantGasCallEIP2929(gasDelegateCall) + gasStaticCallEIP2929 = makeCallVariantGasCallEIP2929(gasStaticCall) + gasCallCodeEIP2929 = makeCallVariantGasCallEIP2929(gasCallCode) +) + +func gasSelfdestructEIP2929(evm *EVM, contract *Contract, stack *Stack, mem *Memory, memorySize uint64) (uint64, error) { + var ( + gas uint64 + address = common.Address(stack.peek().Bytes20()) + ) + if !evm.StateDB.AddressInAccessList(address) { + // If the caller cannot afford the cost, this change will be rolled back + evm.StateDB.AddAddressToAccessList(address) + gas = ColdAccountAccessCostEIP2929 + } + // if empty and transfers value + if evm.StateDB.Empty(address) && evm.StateDB.GetBalance(contract.Address()).Sign() != 0 { + gas += params.CreateBySelfdestructGas + } + if !evm.StateDB.HasSuicided(contract.Address()) { + evm.StateDB.AddRefund(params.SelfdestructRefundGas) + } + return gas, nil +} diff --git a/core/vm/runtime/runtime.go b/core/vm/runtime/runtime.go index 63b2d6d9c2fa..a8741fa12dee 100644 --- a/core/vm/runtime/runtime.go +++ b/core/vm/runtime/runtime.go @@ -17,12 +17,12 @@ package runtime import ( - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math" "math/big" "time" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/crypto" @@ -107,6 +107,9 @@ func Execute(code, input []byte, cfg *Config) ([]byte, *state.StateDB, error) { vmenv = NewEnv(cfg) sender = vm.AccountRef(cfg.Origin) ) + if rules := cfg.ChainConfig.Rules(vmenv.Context.BlockNumber); rules.IsEIP1559 { + cfg.State.PrepareAccessList(cfg.Origin, &address, vm.ActivePrecompiles(rules), nil) + } cfg.State.CreateAccount(address) // set the receiver's (the executing contract) code for execution. cfg.State.SetCode(address, code) @@ -137,7 +140,9 @@ func Create(input []byte, cfg *Config) ([]byte, common.Address, uint64, error) { vmenv = NewEnv(cfg) sender = vm.AccountRef(cfg.Origin) ) - + if rules := cfg.ChainConfig.Rules(vmenv.Context.BlockNumber); rules.IsEIP1559 { + cfg.State.PrepareAccessList(cfg.Origin, nil, vm.ActivePrecompiles(rules), nil) + } // Call the code with the given configuration. code, address, leftOverGas, err := vmenv.Create( sender, @@ -159,6 +164,11 @@ func Call(address common.Address, input []byte, cfg *Config) ([]byte, uint64, er vmenv := NewEnv(cfg) sender := cfg.State.GetOrNewStateObject(cfg.Origin) + statedb := cfg.State + if rules := cfg.ChainConfig.Rules(vmenv.Context.BlockNumber); rules.IsEIP1559 { + statedb.PrepareAccessList(cfg.Origin, &address, vm.ActivePrecompiles(rules), nil) + } + // Call the code with the given configuration. ret, leftOverGas, err := vmenv.Call( sender, @@ -167,6 +177,5 @@ func Call(address common.Address, input []byte, cfg *Config) ([]byte, uint64, er cfg.GasLimit, cfg.Value, ) - return ret, leftOverGas, err } diff --git a/core/vm/runtime/runtime_test.go b/core/vm/runtime/runtime_test.go index 89e6d4b0beb0..826b431f48ad 100644 --- a/core/vm/runtime/runtime_test.go +++ b/core/vm/runtime/runtime_test.go @@ -17,8 +17,9 @@ package runtime import ( - "github.com/XinFinOrg/XDPoSChain/core/rawdb" + "fmt" "math/big" + "os" "strings" "testing" @@ -26,6 +27,8 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/core" + "github.com/XinFinOrg/XDPoSChain/core/asm" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" @@ -327,30 +330,271 @@ func TestBlockhash(t *testing.T) { } } -// BenchmarkSimpleLoop test a pretty simple loop which loops -// 1M (1 048 575) times. -// Takes about 200 ms +// benchmarkNonModifyingCode benchmarks code, but if the code modifies the +// state, this should not be used, since it does not reset the state between runs. +func benchmarkNonModifyingCode(gas uint64, code []byte, name string, b *testing.B) { + cfg := new(Config) + setDefaults(cfg) + cfg.State, _ = state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase())) + cfg.GasLimit = gas + var ( + destination = common.BytesToAddress([]byte("contract")) + vmenv = NewEnv(cfg) + sender = vm.AccountRef(cfg.Origin) + ) + cfg.State.CreateAccount(destination) + eoa := common.HexToAddress("E0") + { + cfg.State.CreateAccount(eoa) + cfg.State.SetNonce(eoa, 100) + } + reverting := common.HexToAddress("EE") + { + cfg.State.CreateAccount(reverting) + cfg.State.SetCode(reverting, []byte{ + byte(vm.PUSH1), 0x00, + byte(vm.PUSH1), 0x00, + byte(vm.REVERT), + }) + } + + //cfg.State.CreateAccount(cfg.Origin) + // set the receiver's (the executing contract) code for execution. + cfg.State.SetCode(destination, code) + vmenv.Call(sender, destination, nil, gas, cfg.Value) + + b.Run(name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + vmenv.Call(sender, destination, nil, gas, cfg.Value) + } + }) +} + +// BenchmarkSimpleLoop test a pretty simple loop which loops until OOG +// 55 ms func BenchmarkSimpleLoop(b *testing.B) { - // 0xfffff = 1048575 loops - code := []byte{ - byte(vm.PUSH3), 0x0f, 0xff, 0xff, + + staticCallIdentity := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.PUSH1), 0x4, // address of identity + byte(vm.GAS), // gas + byte(vm.STATICCALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + callIdentity := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.DUP1), // value + byte(vm.PUSH1), 0x4, // address of identity + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + callInexistant := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.DUP1), // value + byte(vm.PUSH1), 0xff, // address of existing contract + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + callEOA := []byte{ byte(vm.JUMPDEST), // [ count ] - byte(vm.PUSH1), 1, // [count, 1] - byte(vm.SWAP1), // [1, count] - byte(vm.SUB), // [ count -1 ] - byte(vm.DUP1), // [ count -1 , count-1] - byte(vm.PUSH1), 4, // [count-1, count -1, label] - byte(vm.JUMPI), // [ 0 ] - byte(vm.STOP), + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.DUP1), // value + byte(vm.PUSH1), 0xE0, // address of EOA + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), } + + loopingCode := []byte{ + byte(vm.JUMPDEST), // [ count ] + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.DUP1), // out insize + byte(vm.DUP1), // in offset + byte(vm.PUSH1), 0x4, // address of identity + byte(vm.GAS), // gas + + byte(vm.POP), byte(vm.POP), byte(vm.POP), byte(vm.POP), byte(vm.POP), byte(vm.POP), + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + + calllRevertingContractWithInput := []byte{ + byte(vm.JUMPDEST), // + // push args for the call + byte(vm.PUSH1), 0, // out size + byte(vm.DUP1), // out offset + byte(vm.PUSH1), 0x20, // in size + byte(vm.PUSH1), 0x00, // in offset + byte(vm.PUSH1), 0x00, // value + byte(vm.PUSH1), 0xEE, // address of reverting contract + byte(vm.GAS), // gas + byte(vm.CALL), + byte(vm.POP), // pop return value + byte(vm.PUSH1), 0, // jumpdestination + byte(vm.JUMP), + } + //tracer := vm.NewJSONLogger(nil, os.Stdout) - //Execute(code, nil, &Config{ + //Execute(loopingCode, nil, &Config{ // EVMConfig: vm.Config{ // Debug: true, // Tracer: tracer, // }}) + // 100M gas + benchmarkNonModifyingCode(100000000, staticCallIdentity, "staticcall-identity-100M", b) + benchmarkNonModifyingCode(100000000, callIdentity, "call-identity-100M", b) + benchmarkNonModifyingCode(100000000, loopingCode, "loop-100M", b) + benchmarkNonModifyingCode(100000000, callInexistant, "call-nonexist-100M", b) + benchmarkNonModifyingCode(100000000, callEOA, "call-EOA-100M", b) + benchmarkNonModifyingCode(100000000, calllRevertingContractWithInput, "call-reverting-100M", b) + + //benchmarkNonModifyingCode(10000000, staticCallIdentity, "staticcall-identity-10M", b) + //benchmarkNonModifyingCode(10000000, loopingCode, "loop-10M", b) +} - for i := 0; i < b.N; i++ { - Execute(code, nil, nil) +// TestEip2929Cases contains various testcases that are used for +// EIP-2929 about gas repricings +func TestEip2929Cases(t *testing.T) { + + id := 1 + prettyPrint := func(comment string, code []byte) { + + instrs := make([]string, 0) + it := asm.NewInstructionIterator(code) + for it.Next() { + if it.Arg() != nil && 0 < len(it.Arg()) { + instrs = append(instrs, fmt.Sprintf("%v 0x%x", it.Op(), it.Arg())) + } else { + instrs = append(instrs, fmt.Sprintf("%v", it.Op())) + } + } + ops := strings.Join(instrs, ", ") + fmt.Printf("### Case %d\n\n", id) + id++ + fmt.Printf("%v\n\nBytecode: \n```\n0x%x\n```\nOperations: \n```\n%v\n```\n\n", + comment, + code, ops) + Execute(code, nil, &Config{ + EVMConfig: vm.Config{ + Debug: true, + Tracer: vm.NewMarkdownLogger(nil, os.Stdout), + ExtraEips: []int{2929}, + }, + }) + } + + { // First eip testcase + code := []byte{ + // Three checks against a precompile + byte(vm.PUSH1), 1, byte(vm.EXTCODEHASH), byte(vm.POP), + byte(vm.PUSH1), 2, byte(vm.EXTCODESIZE), byte(vm.POP), + byte(vm.PUSH1), 3, byte(vm.BALANCE), byte(vm.POP), + // Three checks against a non-precompile + byte(vm.PUSH1), 0xf1, byte(vm.EXTCODEHASH), byte(vm.POP), + byte(vm.PUSH1), 0xf2, byte(vm.EXTCODESIZE), byte(vm.POP), + byte(vm.PUSH1), 0xf3, byte(vm.BALANCE), byte(vm.POP), + // Same three checks (should be cheaper) + byte(vm.PUSH1), 0xf2, byte(vm.EXTCODEHASH), byte(vm.POP), + byte(vm.PUSH1), 0xf3, byte(vm.EXTCODESIZE), byte(vm.POP), + byte(vm.PUSH1), 0xf1, byte(vm.BALANCE), byte(vm.POP), + // Check the origin, and the 'this' + byte(vm.ORIGIN), byte(vm.BALANCE), byte(vm.POP), + byte(vm.ADDRESS), byte(vm.BALANCE), byte(vm.POP), + + byte(vm.STOP), + } + prettyPrint("This checks `EXT`(codehash,codesize,balance) of precompiles, which should be `100`, "+ + "and later checks the same operations twice against some non-precompiles. "+ + "Those are cheaper second time they are accessed. Lastly, it checks the `BALANCE` of `origin` and `this`.", code) + } + + { // EXTCODECOPY + code := []byte{ + // extcodecopy( 0xff,0,0,0,0) + byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, //length, codeoffset, memoffset + byte(vm.PUSH1), 0xff, byte(vm.EXTCODECOPY), + // extcodecopy( 0xff,0,0,0,0) + byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, //length, codeoffset, memoffset + byte(vm.PUSH1), 0xff, byte(vm.EXTCODECOPY), + // extcodecopy( this,0,0,0,0) + byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, byte(vm.PUSH1), 0x00, //length, codeoffset, memoffset + byte(vm.ADDRESS), byte(vm.EXTCODECOPY), + + byte(vm.STOP), + } + prettyPrint("This checks `extcodecopy( 0xff,0,0,0,0)` twice, (should be expensive first time), "+ + "and then does `extcodecopy( this,0,0,0,0)`.", code) + } + + { // SLOAD + SSTORE + code := []byte{ + + // Add slot `0x1` to access list + byte(vm.PUSH1), 0x01, byte(vm.SLOAD), byte(vm.POP), // SLOAD( 0x1) (add to access list) + // Write to `0x1` which is already in access list + byte(vm.PUSH1), 0x11, byte(vm.PUSH1), 0x01, byte(vm.SSTORE), // SSTORE( loc: 0x01, val: 0x11) + // Write to `0x2` which is not in access list + byte(vm.PUSH1), 0x11, byte(vm.PUSH1), 0x02, byte(vm.SSTORE), // SSTORE( loc: 0x02, val: 0x11) + // Write again to `0x2` + byte(vm.PUSH1), 0x11, byte(vm.PUSH1), 0x02, byte(vm.SSTORE), // SSTORE( loc: 0x02, val: 0x11) + // Read slot in access list (0x2) + byte(vm.PUSH1), 0x02, byte(vm.SLOAD), // SLOAD( 0x2) + // Read slot in access list (0x1) + byte(vm.PUSH1), 0x01, byte(vm.SLOAD), // SLOAD( 0x1) + } + prettyPrint("This checks `sload( 0x1)` followed by `sstore(loc: 0x01, val:0x11)`, then 'naked' sstore:"+ + "`sstore(loc: 0x02, val:0x11)` twice, and `sload(0x2)`, `sload(0x1)`. ", code) + } + { // Call variants + code := []byte{ + // identity precompile + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), + byte(vm.PUSH1), 0x04, byte(vm.PUSH1), 0x0, byte(vm.CALL), byte(vm.POP), + + // random account - call 1 + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), + byte(vm.PUSH1), 0xff, byte(vm.PUSH1), 0x0, byte(vm.CALL), byte(vm.POP), + + // random account - call 2 + byte(vm.PUSH1), 0x0, byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), byte(vm.DUP1), + byte(vm.PUSH1), 0xff, byte(vm.PUSH1), 0x0, byte(vm.STATICCALL), byte(vm.POP), + } + prettyPrint("This calls the `identity`-precompile (cheap), then calls an account (expensive) and `staticcall`s the same"+ + "account (cheap)", code) } } diff --git a/crypto/crypto.go b/crypto/crypto.go index 2213bf0c1451..2872bb098bfe 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -17,57 +17,94 @@ package crypto import ( + "bufio" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "encoding/hex" "errors" "fmt" + "hash" "io" + "io/ioutil" "math/big" "os" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" - "github.com/XinFinOrg/XDPoSChain/crypto/sha3" "github.com/XinFinOrg/XDPoSChain/rlp" + "golang.org/x/crypto/sha3" ) +// SignatureLength indicates the byte length required to carry a signature with recovery id. +const SignatureLength = 64 + 1 // 64 bytes ECDSA signature + 1 byte recovery id + +// RecoveryIDOffset points to the byte offset within the signature that contains the recovery id. +const RecoveryIDOffset = 64 + +// DigestLength sets the signature digest exact length +const DigestLength = 32 + var ( - secp256k1_N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) - secp256k1_halfN = new(big.Int).Div(secp256k1_N, big.NewInt(2)) + secp256k1N, _ = new(big.Int).SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) + secp256k1halfN = new(big.Int).Div(secp256k1N, big.NewInt(2)) ) +var errInvalidPubkey = errors.New("invalid secp256k1 public key") + +// KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports +// Read to get a variable amount of data from the hash state. Read is faster than Sum +// because it doesn't copy the internal state, but also modifies the internal state. +type KeccakState interface { + hash.Hash + Read([]byte) (int, error) +} + +// NewKeccakState creates a new KeccakState +func NewKeccakState() KeccakState { + return sha3.NewLegacyKeccak256().(KeccakState) +} + +// HashData hashes the provided data using the KeccakState and returns a 32 byte hash +func HashData(kh KeccakState, data []byte) (h common.Hash) { + kh.Reset() + kh.Write(data) + kh.Read(h[:]) + return h +} + // Keccak256 calculates and returns the Keccak256 hash of the input data. func Keccak256(data ...[]byte) []byte { - d := sha3.NewKeccak256() + b := make([]byte, 32) + d := NewKeccakState() for _, b := range data { d.Write(b) } - return d.Sum(nil) + d.Read(b) + return b } // Keccak256Hash calculates and returns the Keccak256 hash of the input data, // converting it to an internal Hash data structure. func Keccak256Hash(data ...[]byte) (h common.Hash) { - d := sha3.NewKeccak256() + d := NewKeccakState() for _, b := range data { d.Write(b) } - d.Sum(h[:0]) + d.Read(h[:]) return h } // Keccak512 calculates and returns the Keccak512 hash of the input data. func Keccak512(data ...[]byte) []byte { - d := sha3.NewKeccak512() + d := sha3.NewLegacyKeccak512() for _, b := range data { d.Write(b) } return d.Sum(nil) } -// Creates an ethereum address given the bytes and the nonce +// CreateAddress creates an ethereum address given the bytes and the nonce func CreateAddress(b common.Address, nonce uint64) common.Address { data, _ := rlp.EncodeToBytes([]interface{}{b, nonce}) return common.BytesToAddress(Keccak256(data)[12:]) @@ -104,7 +141,7 @@ func toECDSA(d []byte, strict bool) (*ecdsa.PrivateKey, error) { priv.D = new(big.Int).SetBytes(d) // The priv.D must < N - if priv.D.Cmp(secp256k1_N) >= 0 { + if priv.D.Cmp(secp256k1N) >= 0 { return nil, fmt.Errorf("invalid private key, >=N") } // The priv.D must not be zero or negative. @@ -127,12 +164,13 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte { return math.PaddedBigBytes(priv.D, priv.Params().BitSize/8) } -func ToECDSAPub(pub []byte) *ecdsa.PublicKey { - if len(pub) == 0 { - return nil - } +// UnmarshalPubkey converts bytes to a secp256k1 public key. +func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) { x, y := elliptic.Unmarshal(S256(), pub) - return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y} + if x == nil { + return nil, errInvalidPubkey + } + return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil } func FromECDSAPub(pub *ecdsa.PublicKey) []byte { @@ -145,38 +183,77 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte { // HexToECDSA parses a secp256k1 private key. func HexToECDSA(hexkey string) (*ecdsa.PrivateKey, error) { b, err := hex.DecodeString(hexkey) - if err != nil { - return nil, errors.New("invalid hex string") + if byteErr, ok := err.(hex.InvalidByteError); ok { + return nil, fmt.Errorf("invalid hex character %q in private key", byte(byteErr)) + } else if err != nil { + return nil, errors.New("invalid hex data for private key") } return ToECDSA(b) } // LoadECDSA loads a secp256k1 private key from the given file. func LoadECDSA(file string) (*ecdsa.PrivateKey, error) { - buf := make([]byte, 64) fd, err := os.Open(file) if err != nil { return nil, err } defer fd.Close() - if _, err := io.ReadFull(fd, buf); err != nil { - return nil, err - } - key, err := hex.DecodeString(string(buf)) + r := bufio.NewReader(fd) + buf := make([]byte, 64) + n, err := readASCII(buf, r) if err != nil { return nil, err + } else if n != len(buf) { + return nil, fmt.Errorf("key file too short, want 64 hex characters") + } + if err := checkKeyFileEnd(r); err != nil { + return nil, err + } + + return HexToECDSA(string(buf)) +} + +// readASCII reads into 'buf', stopping when the buffer is full or +// when a non-printable control character is encountered. +func readASCII(buf []byte, r *bufio.Reader) (n int, err error) { + for ; n < len(buf); n++ { + buf[n], err = r.ReadByte() + switch { + case err == io.EOF || buf[n] < '!': + return n, nil + case err != nil: + return n, err + } + } + return n, nil +} + +// checkKeyFileEnd skips over additional newlines at the end of a key file. +func checkKeyFileEnd(r *bufio.Reader) error { + for i := 0; ; i++ { + b, err := r.ReadByte() + switch { + case err == io.EOF: + return nil + case err != nil: + return err + case b != '\n' && b != '\r': + return fmt.Errorf("invalid character %q at end of key file", b) + case i >= 2: + return errors.New("key file too long, want 64 hex characters") + } } - return ToECDSA(key) } // SaveECDSA saves a secp256k1 private key to the given file with // restrictive permissions. The key data is saved hex-encoded. func SaveECDSA(file string, key *ecdsa.PrivateKey) error { k := hex.EncodeToString(FromECDSA(key)) - return os.WriteFile(file, []byte(k), 0600) + return ioutil.WriteFile(file, []byte(k), 0600) } +// GenerateKey generates a new private key. func GenerateKey() (*ecdsa.PrivateKey, error) { return ecdsa.GenerateKey(S256(), rand.Reader) } @@ -189,11 +266,11 @@ func ValidateSignatureValues(v byte, r, s *big.Int, homestead bool) bool { } // reject upper range of s values (ECDSA malleability) // see discussion in secp256k1/libsecp256k1/include/secp256k1.h - if homestead && s.Cmp(secp256k1_halfN) > 0 { + if homestead && s.Cmp(secp256k1halfN) > 0 { return false } // Frontier: allow s to be in full N range - return r.Cmp(secp256k1_N) < 0 && s.Cmp(secp256k1_N) < 0 && (v == 0 || v == 1) + return r.Cmp(secp256k1N) < 0 && s.Cmp(secp256k1N) < 0 && (v == 0 || v == 1) } func PubkeyToAddress(p ecdsa.PublicKey) common.Address { diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index f1910e5c26e9..9e1bb2639b4e 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -20,11 +20,14 @@ import ( "bytes" "crypto/ecdsa" "encoding/hex" + "io/ioutil" "math/big" "os" + "reflect" "testing" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" ) var testAddrHex = "970e8128ab834e8eac17ab8e3812f010678cf791" @@ -55,6 +58,33 @@ func BenchmarkSha3(b *testing.B) { } } +func TestUnmarshalPubkey(t *testing.T) { + key, err := UnmarshalPubkey(nil) + if err != errInvalidPubkey || key != nil { + t.Fatalf("expected error, got %v, %v", err, key) + } + key, err = UnmarshalPubkey([]byte{1, 2, 3}) + if err != errInvalidPubkey || key != nil { + t.Fatalf("expected error, got %v, %v", err, key) + } + + var ( + enc, _ = hex.DecodeString("04760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1b01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d") + dec = &ecdsa.PublicKey{ + Curve: S256(), + X: hexutil.MustDecodeBig("0x760c4460e5336ac9bbd87952a3c7ec4363fc0a97bd31c86430806e287b437fd1"), + Y: hexutil.MustDecodeBig("0xb01abc6e1db640cf3106b520344af1d58b00b57823db3e1407cbc433e1b6d04d"), + } + ) + key, err = UnmarshalPubkey(enc) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !reflect.DeepEqual(key, dec) { + t.Fatal("wrong result") + } +} + func TestSign(t *testing.T) { key, _ := HexToECDSA(testPrivHex) addr := common.HexToAddress(testAddrHex) @@ -68,7 +98,7 @@ func TestSign(t *testing.T) { if err != nil { t.Errorf("ECRecover error: %s", err) } - pubKey := ToECDSAPub(recoveredPub) + pubKey, _ := UnmarshalPubkey(recoveredPub) recoveredAddr := PubkeyToAddress(*pubKey) if addr != recoveredAddr { t.Errorf("Address mismatch: want: %x have: %x", addr, recoveredAddr) @@ -109,39 +139,82 @@ func TestNewContractAddress(t *testing.T) { checkAddr(t, common.HexToAddress("c9ddedf451bc62ce88bf9292afb13df35b670699"), caddr2) } -func TestLoadECDSAFile(t *testing.T) { - keyBytes := common.FromHex(testPrivHex) - fileName0 := "test_key0" - fileName1 := "test_key1" - checkKey := func(k *ecdsa.PrivateKey) { - checkAddr(t, PubkeyToAddress(k.PublicKey), common.HexToAddress(testAddrHex)) - loadedKeyBytes := FromECDSA(k) - if !bytes.Equal(loadedKeyBytes, keyBytes) { - t.Fatalf("private key mismatch: want: %x have: %x", keyBytes, loadedKeyBytes) - } +func TestLoadECDSA(t *testing.T) { + tests := []struct { + input string + err string + }{ + // good + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\r"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\r\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n"}, + {input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\r"}, + // bad + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde", + err: "key file too short, want 64 hex characters", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\n", + err: "key file too short, want 64 hex characters", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeX", + err: "invalid hex character 'X' in private key", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefX", + err: "invalid character 'X' at end of key file", + }, + { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\n\n\n", + err: "key file too long, want 64 hex characters", + }, } - os.WriteFile(fileName0, []byte(testPrivHex), 0600) - defer os.Remove(fileName0) + for _, test := range tests { + f, err := ioutil.TempFile("", "loadecdsa_test.*.txt") + if err != nil { + t.Fatal(err) + } + filename := f.Name() + f.WriteString(test.input) + f.Close() - key0, err := LoadECDSA(fileName0) - if err != nil { - t.Fatal(err) + _, err = LoadECDSA(filename) + switch { + case err != nil && test.err == "": + t.Fatalf("unexpected error for input %q:\n %v", test.input, err) + case err != nil && err.Error() != test.err: + t.Fatalf("wrong error for input %q:\n %v", test.input, err) + case err == nil && test.err != "": + t.Fatalf("LoadECDSA did not return error for input %q", test.input) + } } - checkKey(key0) +} - // again, this time with SaveECDSA instead of manual save: - err = SaveECDSA(fileName1, key0) +func TestSaveECDSA(t *testing.T) { + f, err := ioutil.TempFile("", "saveecdsa_test.*.txt") if err != nil { t.Fatal(err) } - defer os.Remove(fileName1) + file := f.Name() + f.Close() + defer os.Remove(file) - key1, err := LoadECDSA(fileName1) + key, _ := HexToECDSA(testPrivHex) + if err := SaveECDSA(file, key); err != nil { + t.Fatal(err) + } + loaded, err := LoadECDSA(file) if err != nil { t.Fatal(err) } - checkKey(key1) + if !reflect.DeepEqual(key, loaded) { + t.Fatal("loaded key not equal to saved key") + } } func TestValidateSignatureValues(t *testing.T) { @@ -153,7 +226,7 @@ func TestValidateSignatureValues(t *testing.T) { minusOne := big.NewInt(-1) one := common.Big1 zero := common.Big0 - secp256k1nMinus1 := new(big.Int).Sub(secp256k1_N, common.Big1) + secp256k1nMinus1 := new(big.Int).Sub(secp256k1N, common.Big1) // correct v,r,s check(true, 0, one, one) @@ -180,9 +253,9 @@ func TestValidateSignatureValues(t *testing.T) { // correct sig with max r,s check(true, 0, secp256k1nMinus1, secp256k1nMinus1) // correct v, combinations of incorrect r,s at upper limit - check(false, 0, secp256k1_N, secp256k1nMinus1) - check(false, 0, secp256k1nMinus1, secp256k1_N) - check(false, 0, secp256k1_N, secp256k1_N) + check(false, 0, secp256k1N, secp256k1nMinus1) + check(false, 0, secp256k1nMinus1, secp256k1N) + check(false, 0, secp256k1N, secp256k1N) // current callers ensures r,s cannot be negative, but let's test for that too // as crypto package could be used stand-alone diff --git a/crypto/secp256k1/dummy.go b/crypto/secp256k1/dummy.go index 65a75080f60a..c52cb87bb5de 100644 --- a/crypto/secp256k1/dummy.go +++ b/crypto/secp256k1/dummy.go @@ -15,7 +15,8 @@ package secp256k1 import ( - _ "github.com/ethereum/go-ethereum/crypto/secp256k1/libsecp256k1/include" - _ "github.com/ethereum/go-ethereum/crypto/secp256k1/libsecp256k1/src" - _ "github.com/ethereum/go-ethereum/crypto/secp256k1/libsecp256k1/src/modules/recovery" + _ "github.com/XinFinOrg/XDPoSChain/crypto/secp256k1/libsecp256k1/include" + _ "github.com/XinFinOrg/XDPoSChain/crypto/secp256k1/libsecp256k1/src" + _ "github.com/XinFinOrg/XDPoSChain/crypto/secp256k1/libsecp256k1/src/modules/recovery" ) + diff --git a/eth/api_backend.go b/eth/api_backend.go index f0afde38f17d..c087ec2d3804 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -245,12 +245,14 @@ func (b *EthApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { - state.SetBalance(msg.From(), math.MaxBig256) +func (b *EthApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) { vmError := func() error { return nil } - + if vmConfig == nil { + vmConfig = b.eth.blockchain.GetVMConfig() + } + state.SetBalance(msg.From(), math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.BlockChain(), nil) - return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, vmCfg), vmError, nil + return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, *vmConfig), vmError, nil } func (b *EthApiBackend) SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription { @@ -304,7 +306,7 @@ func (b *EthApiBackend) GetPoolTransaction(hash common.Hash) *types.Transaction } func (b *EthApiBackend) GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) { - return b.eth.txPool.State().GetNonce(addr), nil + return b.eth.txPool.Nonce(addr), nil } func (b *EthApiBackend) Stats() (pending int, queued int) { @@ -322,8 +324,8 @@ func (b *EthApiBackend) OrderStats() (pending int, queued int) { return b.eth.txPool.Stats() } -func (b *EthApiBackend) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { - return b.eth.TxPool().SubscribeTxPreEvent(ch) +func (b *EthApiBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { + return b.eth.TxPool().SubscribeNewTxsEvent(ch) } func (b *EthApiBackend) Downloader() *downloader.Downloader { @@ -346,6 +348,10 @@ func (b *EthApiBackend) EventMux() *event.TypeMux { return b.eth.EventMux() } +func (b *EthApiBackend) RPCGasCap() uint64 { + return b.eth.config.RPCGasCap +} + func (b *EthApiBackend) AccountManager() *accounts.Manager { return b.eth.AccountManager() } @@ -375,6 +381,10 @@ func (b *EthApiBackend) GetEngine() consensus.Engine { return b.eth.engine } +func (b *EthApiBackend) StateAtBlock(ctx context.Context, block *types.Block, reexec uint64, base *state.StateDB, checkLive bool) (*state.StateDB, error) { + return b.eth.stateAtBlock(block, reexec, base, checkLive) +} + func (s *EthApiBackend) GetRewardByHash(hash common.Hash) map[string]map[string]map[string]*big.Int { header := s.eth.blockchain.GetHeaderByHash(hash) if header != nil { diff --git a/eth/api_tracer.go b/eth/api_tracer.go index 2a0f916b5fef..78788e41bb1f 100644 --- a/eth/api_tracer.go +++ b/eth/api_tracer.go @@ -62,6 +62,20 @@ type TraceConfig struct { Reexec *uint64 } +// txTraceContext is the contextual infos about a transaction before it gets run. +type txTraceContext struct { + index int // Index of the transaction within the block + hash common.Hash // Hash of the transaction + block common.Hash // Hash of the block containing the transaction +} + +// TraceCallConfig is the config for traceCall API. It holds one more +// field to override the state for tracing. +type TraceCallConfig struct { + TraceConfig + StateOverrides *ethapi.StateOverride +} + // txTraceResult is the result of a single transaction trace. type txTraceResult struct { Result interface{} `json:"result,omitempty"` // Trace results produced by the tracer @@ -209,9 +223,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl } } msg, _ := tx.AsMessage(signer, balacne, task.block.Number()) + txctx := &txTraceContext{ + index: i, + hash: tx.Hash(), + block: task.block.Hash(), + } vmctx := core.NewEVMContext(msg, task.block.Header(), api.eth.blockchain, nil) - res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) + res, err := api.traceTx(ctx, msg, txctx, vmctx, task.statedb, config) if err != nil { task.results[i] = &txTraceResult{Error: err.Error()} log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err) @@ -434,6 +453,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, if threads > len(txs) { threads = len(txs) } + blockHash := block.Hash() for th := 0; th < threads; th++ { pend.Add(1) go func() { @@ -449,9 +469,14 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, } } msg, _ := txs[task.index].AsMessage(signer, balacne, block.Number()) + txctx := &txTraceContext{ + index: task.index, + hash: txs[task.index].Hash(), + block: blockHash, + } vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) - res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config) + res, err := api.traceTx(ctx, msg, txctx, vmctx, task.statedb, config) if err != nil { results[task.index] = &txTraceResult{Error: err.Error()} continue @@ -478,6 +503,7 @@ func (api *PrivateDebugAPI) traceBlock(ctx context.Context, block *types.Block, } // Generate the next state snapshot fast without tracing msg, _ := tx.AsMessage(signer, balacne, block.Number()) + statedb.Prepare(tx.Hash(), block.Hash(), i) vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) vmenv := vm.NewEVM(vmctx, statedb, XDCxState, api.config, vm.Config{}) @@ -594,14 +620,72 @@ func (api *PrivateDebugAPI) TraceTransaction(ctx context.Context, hash common.Ha if err != nil { return nil, err } - // Trace the transaction and return - return api.traceTx(ctx, msg, vmctx, statedb, config) + + txctx := &txTraceContext{ + index: int(index), + hash: hash, + block: blockHash, + } + return api.traceTx(ctx, msg, txctx, vmctx, statedb, config) +} + +// TraceCall lets you trace a given eth_call. It collects the structured logs +// created during the execution of EVM if the given transaction was added on +// top of the provided block and returns them as a JSON object. +// You can provide -2 as a block number to trace on top of the pending block. +func (api *PrivateDebugAPI) TraceCall(ctx context.Context, args ethapi.CallArgs, blockNrOrHash rpc.BlockNumberOrHash, config *TraceCallConfig) (interface{}, error) { + // Try to retrieve the specified block + var ( + err error + block *types.Block + ) + if hash, ok := blockNrOrHash.Hash(); ok { + block, err = api.eth.ApiBackend.BlockByHash(ctx, hash) + } else if number, ok := blockNrOrHash.Number(); ok { + if number == rpc.PendingBlockNumber { + // We don't have access to the miner here. For tracing 'future' transactions, + // it can be done with block- and state-overrides instead, which offers + // more flexibility and stability than trying to trace on 'pending', since + // the contents of 'pending' is unstable and probably not a true representation + // of what the next actual block is likely to contain. + return nil, errors.New("tracing on top of pending is not supported") + } + block, err = api.eth.ApiBackend.BlockByNumber(ctx, number) + } else { + return nil, errors.New("invalid arguments; neither block nor hash specified") + } + if err != nil { + return nil, err + } + // try to recompute the state + reexec := defaultTraceReexec + if config != nil && config.Reexec != nil { + reexec = *config.Reexec + } + statedb, err := api.eth.ApiBackend.StateAtBlock(ctx, block, reexec, nil, true) + if err != nil { + return nil, err + } + // Apply the customized state rules if required. + if config != nil { + if err := config.StateOverrides.Apply(statedb); err != nil { + return nil, err + } + } + // Execute the trace + msg := args.ToMessage(api.eth.ApiBackend, block.Number(), api.eth.ApiBackend.RPCGasCap()) + vmctx := core.NewEVMContext(msg, block.Header(), api.eth.blockchain, nil) + var traceConfig *TraceConfig + if config != nil { + traceConfig = &config.TraceConfig + } + return api.traceTx(ctx, msg, new(txTraceContext), vmctx, statedb, traceConfig) } // traceTx configures a new tracer according to the provided configuration, and // executes the given message in the provided environment. The return value will // be tracer dependent. -func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { +func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, txctx *txTraceContext, vmctx vm.Context, statedb *state.StateDB, config *TraceConfig) (interface{}, error) { // Assemble the structured logger or the JavaScript tracer var ( tracer vm.Tracer @@ -637,6 +721,9 @@ func (api *PrivateDebugAPI) traceTx(ctx context.Context, message core.Message, v // Run the transaction with tracing enabled. vmenv := vm.NewEVM(vmctx, statedb, nil, api.config, vm.Config{Debug: true, Tracer: tracer}) + // Call Prepare to clear out the statedb access list + statedb.Prepare(txctx.hash, txctx.block, txctx.index) + owner := common.Address{} ret, gas, failed, err, _ := core.ApplyMessage(vmenv, message, new(core.GasPool).AddGas(message.Gas()), owner) if err != nil { diff --git a/eth/config.go b/eth/config.go index c932a7ae4418..73a95aa6756f 100644 --- a/eth/config.go +++ b/eth/config.go @@ -50,7 +50,8 @@ var DefaultConfig = Config{ TrieTimeout: 5 * time.Minute, GasPrice: big.NewInt(0.25 * params.Shannon), - TxPool: core.DefaultTxPoolConfig, + TxPool: core.DefaultTxPoolConfig, + RPCGasCap: 25000000, GPO: gasprice.Config{ Blocks: 20, Percentile: 60, @@ -114,6 +115,9 @@ type Config struct { // Miscellaneous options DocRoot string `toml:"-"` + + // RPCGasCap is the global gas cap for eth-call variants. + RPCGasCap uint64 } type configMarshaling struct { diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 5e32e6b6d849..1a500142c839 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -980,22 +980,22 @@ func (d *Downloader) fetchReceipts(from uint64) error { // various callbacks to handle the slight differences between processing them. // // The instrumentation parameters: -// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) -// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) -// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) -// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) -// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) -// - pending: task callback for the number of requests still needing download (detect completion/non-completability) -// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) -// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) -// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) -// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) -// - fetch: network callback to actually send a particular download request to a physical remote peer -// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) -// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) -// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks -// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) -// - kind: textual label of the type being downloaded to display in log mesages +// - errCancel: error type to return if the fetch operation is cancelled (mostly makes logging nicer) +// - deliveryCh: channel from which to retrieve downloaded data packets (merged from all concurrent peers) +// - deliver: processing callback to deliver data packets into type specific download queues (usually within `queue`) +// - wakeCh: notification channel for waking the fetcher when new tasks are available (or sync completed) +// - expire: task callback method to abort requests that took too long and return the faulty peers (traffic shaping) +// - pending: task callback for the number of requests still needing download (detect completion/non-completability) +// - inFlight: task callback for the number of in-progress requests (wait for all active downloads to finish) +// - throttle: task callback to check if the processing queue is full and activate throttling (bound memory use) +// - reserve: task callback to reserve new download tasks to a particular peer (also signals partial completions) +// - fetchHook: tester callback to notify of new tasks being initiated (allows testing the scheduling logic) +// - fetch: network callback to actually send a particular download request to a physical remote peer +// - cancel: task callback to abort an in-flight download request and allow rescheduling it (in case of lost peer) +// - capacity: network callback to retrieve the estimated type-specific bandwidth capacity of a peer (traffic shaping) +// - idle: network callback to retrieve the currently (type specific) idle peers that can be assigned tasks +// - setIdle: network callback to set a peer back to idle and update its estimated capacity (traffic shaping) +// - kind: textual label of the type being downloaded to display in log mesages func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliver func(dataPack) (int, error), wakeCh chan bool, expire func() map[string]int, pending func() int, inFlight func() bool, throttle func() bool, reserve func(*peerConnection, int) (*fetchRequest, bool, error), fetchHook func([]*types.Header), fetch func(*peerConnection, *fetchRequest) error, cancel func(*fetchRequest), capacity func(*peerConnection) int, @@ -1036,7 +1036,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv case err == nil: peer.log.Trace("Delivered new batch of data", "type", kind, "count", packet.Stats()) default: - peer.log.Trace("Failed to deliver retrieved data", "type", kind, "err", err) + peer.log.Debug("Failed to deliver retrieved data", "type", kind, "err", err) } } // Blocks assembled, try to update the progress diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go index a67f12c47765..b9f16e673f94 100644 --- a/eth/downloader/queue.go +++ b/eth/downloader/queue.go @@ -26,10 +26,10 @@ import ( "time" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/metrics" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) var ( @@ -105,11 +105,11 @@ func newQueue() *queue { headerPendPool: make(map[string]*fetchRequest), headerContCh: make(chan bool), blockTaskPool: make(map[common.Hash]*types.Header), - blockTaskQueue: prque.New(), + blockTaskQueue: prque.New(nil), blockPendPool: make(map[string]*fetchRequest), blockDonePool: make(map[common.Hash]struct{}), receiptTaskPool: make(map[common.Hash]*types.Header), - receiptTaskQueue: prque.New(), + receiptTaskQueue: prque.New(nil), receiptPendPool: make(map[string]*fetchRequest), receiptDonePool: make(map[common.Hash]struct{}), resultCache: make([]*fetchResult, blockCacheItems), @@ -278,7 +278,7 @@ func (q *queue) ScheduleSkeleton(from uint64, skeleton []*types.Header) { } // Shedule all the header retrieval tasks for the skeleton assembly q.headerTaskPool = make(map[uint64]*types.Header) - q.headerTaskQueue = prque.New() + q.headerTaskQueue = prque.New(nil) q.headerPeerMiss = make(map[string]map[uint64]struct{}) // Reset availability to correct invalid chains q.headerResults = make([]*types.Header, len(skeleton)*MaxHeaderFetch) q.headerProced = 0 @@ -289,7 +289,7 @@ func (q *queue) ScheduleSkeleton(from uint64, skeleton []*types.Header) { index := from + uint64(i*MaxHeaderFetch) q.headerTaskPool[index] = header - q.headerTaskQueue.Push(index, -float32(index)) + q.headerTaskQueue.Push(index, -int64(index)) } } @@ -335,11 +335,11 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header { } // Queue the header for content retrieval q.blockTaskPool[hash] = header - q.blockTaskQueue.Push(header, -float32(header.Number.Uint64())) + q.blockTaskQueue.Push(header, -int64(header.Number.Uint64())) if q.mode == FastSync { q.receiptTaskPool[hash] = header - q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64())) + q.receiptTaskQueue.Push(header, -int64(header.Number.Uint64())) } inserts = append(inserts, header) q.headerHead = hash @@ -437,7 +437,7 @@ func (q *queue) ReserveHeaders(p *peerConnection, count int) *fetchRequest { } // Merge all the skipped batches back for _, from := range skip { - q.headerTaskQueue.Push(from, -float32(from)) + q.headerTaskQueue.Push(from, -int64(from)) } // Assemble and return the block download request if send == 0 { @@ -544,7 +544,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common } // Merge all the skipped headers back for _, header := range skip { - taskQueue.Push(header, -float32(header.Number.Uint64())) + taskQueue.Push(header, -int64(header.Number.Uint64())) } if progress { // Wake WaitResults, resultCache was modified @@ -587,10 +587,10 @@ func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool m defer q.lock.Unlock() if request.From > 0 { - taskQueue.Push(request.From, -float32(request.From)) + taskQueue.Push(request.From, -int64(request.From)) } for _, header := range request.Headers { - taskQueue.Push(header, -float32(header.Number.Uint64())) + taskQueue.Push(header, -int64(header.Number.Uint64())) } delete(pendPool, request.Peer.id) } @@ -604,13 +604,13 @@ func (q *queue) Revoke(peerId string) { if request, ok := q.blockPendPool[peerId]; ok { for _, header := range request.Headers { - q.blockTaskQueue.Push(header, -float32(header.Number.Uint64())) + q.blockTaskQueue.Push(header, -int64(header.Number.Uint64())) } delete(q.blockPendPool, peerId) } if request, ok := q.receiptPendPool[peerId]; ok { for _, header := range request.Headers { - q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64())) + q.receiptTaskQueue.Push(header, -int64(header.Number.Uint64())) } delete(q.receiptPendPool, peerId) } @@ -659,10 +659,10 @@ func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest, // Return any non satisfied requests to the pool if request.From > 0 { - taskQueue.Push(request.From, -float32(request.From)) + taskQueue.Push(request.From, -int64(request.From)) } for _, header := range request.Headers { - taskQueue.Push(header, -float32(header.Number.Uint64())) + taskQueue.Push(header, -int64(header.Number.Uint64())) } // Add the peer to the expiry report along the the number of failed requests expiries[id] = len(request.Headers) @@ -733,7 +733,7 @@ func (q *queue) DeliverHeaders(id string, headers []*types.Header, headerProcCh } miss[request.From] = struct{}{} - q.headerTaskQueue.Push(request.From, -float32(request.From)) + q.headerTaskQueue.Push(request.From, -int64(request.From)) return 0, errors.New("delivery not accepted") } // Clean up a successful fetch and try to deliver any sub-results @@ -856,7 +856,7 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ // Return all failed or missing fetches to the queue for _, header := range request.Headers { if header != nil { - taskQueue.Push(header, -float32(header.Number.Uint64())) + taskQueue.Push(header, -int64(header.Number.Uint64())) } } // Wake up WaitResults diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index c31e05961d9f..7d1e15fd4dea 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -25,10 +25,10 @@ import ( lru "github.com/hashicorp/golang-lru" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/prque" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/log" - "gopkg.in/karalabe/cookiejar.v2/collections/prque" ) const ( @@ -171,7 +171,7 @@ func New(getBlock blockRetrievalFn, verifyHeader headerVerifierFn, handlePropose fetching: make(map[common.Hash]*announce), fetched: make(map[common.Hash][]*announce), completing: make(map[common.Hash]*announce), - queue: prque.New(), + queue: prque.New(nil), queues: make(map[string]int), queued: make(map[common.Hash]*inject), knowns: knownBlocks, @@ -312,7 +312,7 @@ func (f *Fetcher) loop() { // If too high up the chain or phase, continue later number := op.block.NumberU64() if number > height+1 { - f.queue.Push(op, -float32(op.block.NumberU64())) + f.queue.Push(op, -int64(op.block.NumberU64())) if f.queueChangeHook != nil { f.queueChangeHook(op.block.Hash(), true) } @@ -642,7 +642,7 @@ func (f *Fetcher) enqueue(peer string, block *types.Block) { f.queues[peer] = count f.queued[hash] = op f.knowns.Add(hash, true) - f.queue.Push(op, -float32(block.NumberU64())) + f.queue.Push(op, -int64(block.NumberU64())) if f.queueChangeHook != nil { f.queueChangeHook(op.block.Hash(), true) } diff --git a/eth/filters/api.go b/eth/filters/api.go index ed36c3a6476a..952598046d2c 100644 --- a/eth/filters/api.go +++ b/eth/filters/api.go @@ -28,7 +28,6 @@ import ( ethereum "github.com/XinFinOrg/XDPoSChain" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" - "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/ethdb" "github.com/XinFinOrg/XDPoSChain/event" @@ -112,8 +111,8 @@ func (api *PublicFilterAPI) timeoutLoop() { // https://github.com/ethereum/wiki/wiki/JSON-RPC#eth_newpendingtransactionfilter func (api *PublicFilterAPI) NewPendingTransactionFilter() rpc.ID { var ( - pendingTxs = make(chan common.Hash) - pendingTxSub = api.events.SubscribePendingTxEvents(pendingTxs) + pendingTxs = make(chan []common.Hash) + pendingTxSub = api.events.SubscribePendingTxs(pendingTxs) ) api.filtersMu.Lock() @@ -126,7 +125,7 @@ func (api *PublicFilterAPI) NewPendingTransactionFilter() rpc.ID { case ph := <-pendingTxs: api.filtersMu.Lock() if f, found := api.filters[pendingTxSub.ID]; found { - f.hashes = append(f.hashes, ph) + f.hashes = append(f.hashes, ph...) } api.filtersMu.Unlock() case <-pendingTxSub.Err(): @@ -152,13 +151,17 @@ func (api *PublicFilterAPI) NewPendingTransactions(ctx context.Context) (*rpc.Su rpcSub := notifier.CreateSubscription() go func() { - txHashes := make(chan common.Hash) - pendingTxSub := api.events.SubscribePendingTxEvents(txHashes) + txHashes := make(chan []common.Hash, 128) + pendingTxSub := api.events.SubscribePendingTxs(txHashes) for { select { - case h := <-txHashes: - notifier.Notify(rpcSub.ID, h) + case hashes := <-txHashes: + // To keep the original behaviour, send a single tx hash in one notification. + // TODO(rjl493456442) Send a batch of tx hashes in one notification + for _, h := range hashes { + notifier.Notify(rpcSub.ID, h) + } case <-rpcSub.Err(): pendingTxSub.Unsubscribe() return @@ -439,10 +442,6 @@ func (api *PublicFilterAPI) GetFilterChanges(id rpc.ID) (interface{}, error) { case LogsSubscription: logs := f.logs f.logs = nil - for _, log := range logs { - // update BlockHash to fix #208 - log.BlockHash = core.GetCanonicalHash(api.chainDb, log.BlockNumber) - } return returnLogs(logs), nil } } diff --git a/eth/filters/filter.go b/eth/filters/filter.go index dcd872fc4c48..4d47bafc0fc7 100644 --- a/eth/filters/filter.go +++ b/eth/filters/filter.go @@ -38,7 +38,7 @@ type Backend interface { GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetLogs(ctx context.Context, blockHash common.Hash) ([][]*types.Log, error) - SubscribeTxPreEvent(chan<- core.TxPreEvent) event.Subscription + SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription SubscribeLogsEvent(ch chan<- []*types.Log) event.Subscription diff --git a/eth/filters/filter_system.go b/eth/filters/filter_system.go index 6dafd9610421..2d91b771ef9a 100644 --- a/eth/filters/filter_system.go +++ b/eth/filters/filter_system.go @@ -57,8 +57,7 @@ const ( ) const ( - - // txChanSize is the size of channel listening to TxPreEvent. + // txChanSize is the size of channel listening to NewTxsEvent. // The number is referenced from the size of tx pool. txChanSize = 4096 // rmLogsChanSize is the size of channel listening to RemovedLogsEvent. @@ -79,7 +78,7 @@ type subscription struct { created time.Time logsCrit ethereum.FilterQuery logs chan []*types.Log - hashes chan common.Hash + hashes chan []common.Hash headers chan *types.Header installed chan struct{} // closed when the filter is installed err chan error // closed when the filter is uninstalled @@ -94,7 +93,7 @@ type EventSystem struct { lastHead *types.Header // Subscriptions - txSub event.Subscription // Subscription for new transaction event + txsSub event.Subscription // Subscription for new transaction event logsSub event.Subscription // Subscription for new log event rmLogsSub event.Subscription // Subscription for removed log event chainSub event.Subscription // Subscription for new chain event @@ -103,7 +102,7 @@ type EventSystem struct { // Channels install chan *subscription // install filter for event notification uninstall chan *subscription // remove filter for event notification - txCh chan core.TxPreEvent // Channel to receive new transaction event + txsCh chan core.NewTxsEvent // Channel to receive new transactions event logsCh chan []*types.Log // Channel to receive new log event rmLogsCh chan core.RemovedLogsEvent // Channel to receive removed log event chainCh chan core.ChainEvent // Channel to receive new chain event @@ -122,14 +121,14 @@ func NewEventSystem(mux *event.TypeMux, backend Backend, lightMode bool) *EventS lightMode: lightMode, install: make(chan *subscription), uninstall: make(chan *subscription), - txCh: make(chan core.TxPreEvent, txChanSize), + txsCh: make(chan core.NewTxsEvent, txChanSize), logsCh: make(chan []*types.Log, logsChanSize), rmLogsCh: make(chan core.RemovedLogsEvent, rmLogsChanSize), chainCh: make(chan core.ChainEvent, chainEvChanSize), } // Subscribe events - m.txSub = m.backend.SubscribeTxPreEvent(m.txCh) + m.txsSub = m.backend.SubscribeNewTxsEvent(m.txsCh) m.logsSub = m.backend.SubscribeLogsEvent(m.logsCh) m.rmLogsSub = m.backend.SubscribeRemovedLogsEvent(m.rmLogsCh) m.chainSub = m.backend.SubscribeChainEvent(m.chainCh) @@ -137,7 +136,7 @@ func NewEventSystem(mux *event.TypeMux, backend Backend, lightMode bool) *EventS m.pendingLogSub = m.mux.Subscribe(core.PendingLogsEvent{}) // Make sure none of the subscriptions are empty - if m.txSub == nil || m.logsSub == nil || m.rmLogsSub == nil || m.chainSub == nil || + if m.txsSub == nil || m.logsSub == nil || m.rmLogsSub == nil || m.chainSub == nil || m.pendingLogSub.Closed() { log.Crit("Subscribe for event system failed") } @@ -240,7 +239,7 @@ func (es *EventSystem) subscribeMinedPendingLogs(crit ethereum.FilterQuery, logs logsCrit: crit, created: time.Now(), logs: logs, - hashes: make(chan common.Hash), + hashes: make(chan []common.Hash), headers: make(chan *types.Header), installed: make(chan struct{}), err: make(chan error), @@ -257,7 +256,7 @@ func (es *EventSystem) subscribeLogs(crit ethereum.FilterQuery, logs chan []*typ logsCrit: crit, created: time.Now(), logs: logs, - hashes: make(chan common.Hash), + hashes: make(chan []common.Hash), headers: make(chan *types.Header), installed: make(chan struct{}), err: make(chan error), @@ -274,7 +273,7 @@ func (es *EventSystem) subscribePendingLogs(crit ethereum.FilterQuery, logs chan logsCrit: crit, created: time.Now(), logs: logs, - hashes: make(chan common.Hash), + hashes: make(chan []common.Hash), headers: make(chan *types.Header), installed: make(chan struct{}), err: make(chan error), @@ -290,7 +289,7 @@ func (es *EventSystem) SubscribeNewHeads(headers chan *types.Header) *Subscripti typ: BlocksSubscription, created: time.Now(), logs: make(chan []*types.Log), - hashes: make(chan common.Hash), + hashes: make(chan []common.Hash), headers: headers, installed: make(chan struct{}), err: make(chan error), @@ -298,9 +297,9 @@ func (es *EventSystem) SubscribeNewHeads(headers chan *types.Header) *Subscripti return es.subscribe(sub) } -// SubscribePendingTxEvents creates a subscription that writes transaction hashes for +// SubscribePendingTxs creates a subscription that writes transaction hashes for // transactions that enter the transaction pool. -func (es *EventSystem) SubscribePendingTxEvents(hashes chan common.Hash) *Subscription { +func (es *EventSystem) SubscribePendingTxs(hashes chan []common.Hash) *Subscription { sub := &subscription{ id: rpc.NewID(), typ: PendingTransactionsSubscription, @@ -347,9 +346,13 @@ func (es *EventSystem) broadcast(filters filterIndex, ev interface{}) { } } } - case core.TxPreEvent: + case core.NewTxsEvent: + hashes := make([]common.Hash, 0, len(e.Txs)) + for _, tx := range e.Txs { + hashes = append(hashes, tx.Hash()) + } for _, f := range filters[PendingTransactionsSubscription] { - f.hashes <- e.Tx.Hash() + f.hashes <- hashes } case core.ChainEvent: for _, f := range filters[BlocksSubscription] { @@ -445,7 +448,7 @@ func (es *EventSystem) eventLoop() { // Ensure all subscriptions get cleaned up defer func() { es.pendingLogSub.Unsubscribe() - es.txSub.Unsubscribe() + es.txsSub.Unsubscribe() es.logsSub.Unsubscribe() es.rmLogsSub.Unsubscribe() es.chainSub.Unsubscribe() @@ -459,7 +462,7 @@ func (es *EventSystem) eventLoop() { for { select { // Handle subscribed events - case ev := <-es.txCh: + case ev := <-es.txsCh: es.broadcast(index, ev) case ev := <-es.logsCh: es.broadcast(index, ev) @@ -494,7 +497,7 @@ func (es *EventSystem) eventLoop() { close(f.err) // System stopped - case <-es.txSub.Err(): + case <-es.txsSub.Err(): return case <-es.logsSub.Err(): return diff --git a/eth/filters/filter_system_test.go b/eth/filters/filter_system_test.go index 6d7129a1d1ae..fb4f7e7b8bc2 100644 --- a/eth/filters/filter_system_test.go +++ b/eth/filters/filter_system_test.go @@ -90,7 +90,7 @@ func (b *testBackend) GetLogs(ctx context.Context, blockHash common.Hash) ([][]* return logs, nil } -func (b *testBackend) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { +func (b *testBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { return b.txFeed.Subscribe(ch) } @@ -226,10 +226,7 @@ func TestPendingTxFilter(t *testing.T) { fid0 := api.NewPendingTransactionFilter() time.Sleep(1 * time.Second) - for _, tx := range transactions { - ev := core.TxPreEvent{Tx: tx} - txFeed.Send(ev) - } + txFeed.Send(core.NewTxsEvent{Txs: transactions}) timeout := time.Now().Add(1 * time.Second) for { diff --git a/eth/gasprice/gasprice.go b/eth/gasprice/gasprice.go index 8dfaeec43a45..54492790371b 100644 --- a/eth/gasprice/gasprice.go +++ b/eth/gasprice/gasprice.go @@ -162,7 +162,7 @@ type transactionsByGasPrice []*types.Transaction func (t transactionsByGasPrice) Len() int { return len(t) } func (t transactionsByGasPrice) Swap(i, j int) { t[i], t[j] = t[j], t[i] } -func (t transactionsByGasPrice) Less(i, j int) bool { return t[i].GasPrice().Cmp(t[j].GasPrice()) < 0 } +func (t transactionsByGasPrice) Less(i, j int) bool { return t[i].GasPriceCmp(t[j]) < 0 } // getBlockPrices calculates the lowest transaction gas price in a given block // and sends it to the result channel. If the block is empty, price is nil. diff --git a/eth/gen_config.go b/eth/gen_config.go index e58954b7b224..1959ee559d53 100644 --- a/eth/gen_config.go +++ b/eth/gen_config.go @@ -4,46 +4,52 @@ package eth import ( "math/big" + "time" "github.com/XinFinOrg/XDPoSChain/common" - "github.com/XinFinOrg/XDPoSChain/common/hexutil" "github.com/XinFinOrg/XDPoSChain/consensus/ethash" "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/eth/downloader" "github.com/XinFinOrg/XDPoSChain/eth/gasprice" ) -var _ = (*configMarshaling)(nil) - +// MarshalTOML marshals as TOML. func (c Config) MarshalTOML() (interface{}, error) { type Config struct { Genesis *core.Genesis `toml:",omitempty"` NetworkId uint64 SyncMode downloader.SyncMode + NoPruning bool LightServ int `toml:",omitempty"` LightPeers int `toml:",omitempty"` SkipBcVersionCheck bool `toml:"-"` DatabaseHandles int `toml:"-"` DatabaseCache int + TrieCache int + TrieTimeout time.Duration Etherbase common.Address `toml:",omitempty"` MinerThreads int `toml:",omitempty"` - ExtraData hexutil.Bytes `toml:",omitempty"` + ExtraData []byte `toml:",omitempty"` GasPrice *big.Int Ethash ethash.Config TxPool core.TxPoolConfig GPO gasprice.Config EnablePreimageRecording bool DocRoot string `toml:"-"` + RPCGasCap uint64 } var enc Config enc.Genesis = c.Genesis enc.NetworkId = c.NetworkId enc.SyncMode = c.SyncMode + enc.NoPruning = c.NoPruning enc.LightServ = c.LightServ enc.LightPeers = c.LightPeers enc.SkipBcVersionCheck = c.SkipBcVersionCheck enc.DatabaseHandles = c.DatabaseHandles enc.DatabaseCache = c.DatabaseCache + enc.TrieCache = c.TrieCache + enc.TrieTimeout = c.TrieTimeout enc.Etherbase = c.Etherbase enc.MinerThreads = c.MinerThreads enc.ExtraData = c.ExtraData @@ -53,28 +59,34 @@ func (c Config) MarshalTOML() (interface{}, error) { enc.GPO = c.GPO enc.EnablePreimageRecording = c.EnablePreimageRecording enc.DocRoot = c.DocRoot + enc.RPCGasCap = c.RPCGasCap return &enc, nil } +// UnmarshalTOML unmarshals from TOML. func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { type Config struct { Genesis *core.Genesis `toml:",omitempty"` NetworkId *uint64 SyncMode *downloader.SyncMode + NoPruning *bool LightServ *int `toml:",omitempty"` LightPeers *int `toml:",omitempty"` SkipBcVersionCheck *bool `toml:"-"` DatabaseHandles *int `toml:"-"` DatabaseCache *int + TrieCache *int + TrieTimeout *time.Duration Etherbase *common.Address `toml:",omitempty"` MinerThreads *int `toml:",omitempty"` - ExtraData *hexutil.Bytes `toml:",omitempty"` + ExtraData []byte `toml:",omitempty"` GasPrice *big.Int Ethash *ethash.Config TxPool *core.TxPoolConfig GPO *gasprice.Config EnablePreimageRecording *bool DocRoot *string `toml:"-"` + RPCGasCap *uint64 } var dec Config if err := unmarshal(&dec); err != nil { @@ -89,6 +101,9 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.SyncMode != nil { c.SyncMode = *dec.SyncMode } + if dec.NoPruning != nil { + c.NoPruning = *dec.NoPruning + } if dec.LightServ != nil { c.LightServ = *dec.LightServ } @@ -104,6 +119,12 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.DatabaseCache != nil { c.DatabaseCache = *dec.DatabaseCache } + if dec.TrieCache != nil { + c.TrieCache = *dec.TrieCache + } + if dec.TrieTimeout != nil { + c.TrieTimeout = *dec.TrieTimeout + } if dec.Etherbase != nil { c.Etherbase = *dec.Etherbase } @@ -111,7 +132,7 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { c.MinerThreads = *dec.MinerThreads } if dec.ExtraData != nil { - c.ExtraData = *dec.ExtraData + c.ExtraData = dec.ExtraData } if dec.GasPrice != nil { c.GasPrice = dec.GasPrice @@ -131,5 +152,8 @@ func (c *Config) UnmarshalTOML(unmarshal func(interface{}) error) error { if dec.DocRoot != nil { c.DocRoot = *dec.DocRoot } + if dec.RPCGasCap != nil { + c.RPCGasCap = *dec.RPCGasCap + } return nil } diff --git a/eth/handler.go b/eth/handler.go index 055486b0b942..76733fb8d863 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -49,7 +49,7 @@ const ( softResponseLimit = 2 * 1024 * 1024 // Target maximum size of returned blocks, headers or node data. estHeaderRlpSize = 500 // Approximate size of an RLP encoded block header - // txChanSize is the size of channel listening to TxPreEvent. + // txChanSize is the size of channel listening to NewTxsEvent. // The number is referenced from the size of tx pool. txChanSize = 4096 ) @@ -87,10 +87,10 @@ type ProtocolManager struct { SubProtocols []p2p.Protocol eventMux *event.TypeMux - txCh chan core.TxPreEvent + txsCh chan core.NewTxsEvent orderTxCh chan core.OrderTxPreEvent lendingTxCh chan core.LendingTxPreEvent - txSub event.Subscription + txsSub event.Subscription orderTxSub event.Subscription lendingTxSub event.Subscription minedBlockSub *event.TypeMuxSubscription @@ -292,8 +292,8 @@ func (pm *ProtocolManager) Start(maxPeers int) { pm.maxPeers = maxPeers // broadcast transactions - pm.txCh = make(chan core.TxPreEvent, txChanSize) - pm.txSub = pm.txpool.SubscribeTxPreEvent(pm.txCh) + pm.txsCh = make(chan core.NewTxsEvent, txChanSize) + pm.txsSub = pm.txpool.SubscribeNewTxsEvent(pm.txsCh) pm.orderTxCh = make(chan core.OrderTxPreEvent, txChanSize) if pm.orderpool != nil { pm.orderTxSub = pm.orderpool.SubscribeTxPreEvent(pm.orderTxCh) @@ -317,7 +317,7 @@ func (pm *ProtocolManager) Start(maxPeers int) { func (pm *ProtocolManager) Stop() { log.Info("Stopping Ethereum protocol") - pm.txSub.Unsubscribe() // quits txBroadcastLoop + pm.txsSub.Unsubscribe() // quits txBroadcastLoop if pm.orderTxSub != nil { pm.orderTxSub.Unsubscribe() } @@ -941,16 +941,23 @@ func (pm *ProtocolManager) BroadcastBlock(block *types.Block, propagate bool) { } } -// BroadcastTx will propagate a transaction to all peers which are not known to +// BroadcastTxs will propagate a batch of transactions to all peers which are not known to // already have the given transaction. -func (pm *ProtocolManager) BroadcastTx(hash common.Hash, tx *types.Transaction) { - // Broadcast transaction to a batch of peers not knowing about it - peers := pm.peers.PeersWithoutTx(hash) - //FIXME include this again: peers = peers[:int(math.Sqrt(float64(len(peers))))] - for _, peer := range peers { - peer.SendTransactions(types.Transactions{tx}) +func (pm *ProtocolManager) BroadcastTxs(txs types.Transactions) { + var txset = make(map[*peer]types.Transactions) + + // Broadcast transactions to a batch of peers not knowing about it + for _, tx := range txs { + peers := pm.peers.PeersWithoutTx(tx.Hash()) + for _, peer := range peers { + txset[peer] = append(txset[peer], tx) + } + log.Trace("Broadcast transaction", "hash", tx.Hash(), "recipients", len(peers)) + } + // FIXME include this again: peers = peers[:int(math.Sqrt(float64(len(peers))))] + for peer, txs := range txset { + peer.SendTransactions(txs) } - log.Trace("Broadcast transaction", "hash", hash, "recipients", len(peers)) } // BroadcastVote will propagate a Vote to all peers which are not known to @@ -1041,14 +1048,14 @@ func (self *ProtocolManager) minedBroadcastLoop() { } } -func (self *ProtocolManager) txBroadcastLoop() { +func (pm *ProtocolManager) txBroadcastLoop() { for { select { - case event := <-self.txCh: - self.BroadcastTx(event.Tx.Hash(), event.Tx) + case event := <-pm.txsCh: + pm.BroadcastTxs(event.Txs) - // Err() channel will be closed when unsubscribing. - case <-self.txSub.Err(): + // Err() channel will be closed when unsubscribing. + case <-pm.txsSub.Err(): return } } diff --git a/eth/helper_test.go b/eth/helper_test.go index b3e489bd8b33..07c6d92ee166 100644 --- a/eth/helper_test.go +++ b/eth/helper_test.go @@ -126,7 +126,7 @@ func (p *testTxPool) Pending() (map[common.Address]types.Transactions, error) { return batches, nil } -func (p *testTxPool) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { +func (p *testTxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { return p.txFeed.Subscribe(ch) } diff --git a/eth/protocol.go b/eth/protocol.go index 64b488368433..eb7297a28c10 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -110,9 +110,9 @@ type txPool interface { // The slice should be modifiable by the caller. Pending() (map[common.Address]types.Transactions, error) - // SubscribeTxPreEvent should return an event subscription of - // TxPreEvent and send events to the given channel. - SubscribeTxPreEvent(chan<- core.TxPreEvent) event.Subscription + // SubscribeNewTxsEvent should return an event subscription of + // NewTxsEvent and send events to the given channel. + SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription } type orderPool interface { diff --git a/eth/protocol_test.go b/eth/protocol_test.go index 73cefba82f15..91788594f210 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -116,7 +116,7 @@ func testRecvTransactions(t *testing.T, protocol int) { t.Errorf("added wrong tx hash: got %v, want %v", added[0].Hash(), tx.Hash()) } case <-time.After(2 * time.Second): - t.Errorf("no TxPreEvent received within 2 seconds") + t.Errorf("no NewTxsEvent received within 2 seconds") } } diff --git a/eth/state_accessor.go b/eth/state_accessor.go new file mode 100644 index 000000000000..140830c5ea21 --- /dev/null +++ b/eth/state_accessor.go @@ -0,0 +1,137 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "errors" + "fmt" + "time" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/core/state" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/core/vm" + "github.com/XinFinOrg/XDPoSChain/log" + "github.com/XinFinOrg/XDPoSChain/trie" +) + +// stateAtBlock retrieves the state database associated with a certain block. +// If no state is locally available for the given block, a number of blocks +// are attempted to be reexecuted to generate the desired state. The optional +// base layer statedb can be passed then it's regarded as the statedb of the +// parent block. +func (eth *Ethereum) stateAtBlock(block *types.Block, reexec uint64, base *state.StateDB, checkLive bool) (statedb *state.StateDB, err error) { + var ( + current *types.Block + database state.Database + report = true + origin = block.NumberU64() + ) + // Check the live database first if we have the state fully available, use that. + if checkLive { + statedb, err = eth.blockchain.StateAt(block.Root()) + if err == nil { + return statedb, nil + } + } + if base != nil { + // The optional base statedb is given, mark the start point as parent block + statedb, database, report = base, base.Database(), false + current = eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1) + } else { + // Otherwise try to reexec blocks until we find a state or reach our limit + current = block + + // Create an ephemeral trie.Database for isolating the live one. Otherwise + // the internal junks created by tracing will be persisted into the disk. + database = state.NewDatabaseWithCache(eth.chainDb, 16) + // If we didn't check the dirty database, do check the clean one, otherwise + // we would rewind past a persisted block (specific corner case is chain + // tracing from the genesis). + if !checkLive { + statedb, err = state.New(current.Root(), database) + if err == nil { + return statedb, nil + } + } + // Database does not have the state for the given block, try to regenerate + for i := uint64(0); i < reexec; i++ { + if current.NumberU64() == 0 { + return nil, errors.New("genesis state is missing") + } + parent := eth.blockchain.GetBlock(current.ParentHash(), current.NumberU64()-1) + if parent == nil { + return nil, fmt.Errorf("missing block %v %d", current.ParentHash(), current.NumberU64()-1) + } + current = parent + + statedb, err = state.New(current.Root(), database) + if err == nil { + break + } + } + if err != nil { + switch err.(type) { + case *trie.MissingNodeError: + return nil, fmt.Errorf("required historical state unavailable (reexec=%d)", reexec) + default: + return nil, err + } + } + } + // State was available at historical point, regenerate + var ( + start = time.Now() + logged time.Time + parent common.Hash + ) + for current.NumberU64() < origin { + // Print progress logs if long enough time elapsed + if time.Since(logged) > 8*time.Second && report { + log.Info("Regenerating historical state", "block", current.NumberU64()+1, "target", origin, "remaining", origin-current.NumberU64()-1, "elapsed", time.Since(start)) + logged = time.Now() + } + // Retrieve the next block to regenerate and process it + next := current.NumberU64() + 1 + if current = eth.blockchain.GetBlockByNumber(next); current == nil { + return nil, fmt.Errorf("block #%d not found", next) + } + _, _, _, err := eth.blockchain.Processor().Process(current, statedb, nil, vm.Config{}, nil) + if err != nil { + return nil, fmt.Errorf("processing block %d failed: %v", current.NumberU64(), err) + } + // Finalize the state so any modifications are written to the trie + root, err := statedb.Commit(eth.blockchain.Config().IsEIP158(current.Number())) + if err != nil { + return nil, err + } + statedb, err = state.New(root, database) + if err != nil { + return nil, fmt.Errorf("state reset after block %d failed: %v", current.NumberU64(), err) + } + database.TrieDB().Reference(root, common.Hash{}) + if parent != (common.Hash{}) { + database.TrieDB().Dereference(parent) + } + parent = root + } + if report { + nodes, imgs := database.TrieDB().Size() + log.Info("Historical state regenerated", "block", current.NumberU64(), "elapsed", time.Since(start), "nodes", nodes, "preimages", imgs) + } + return statedb, nil +} diff --git a/eth/tracers/tracer.go b/eth/tracers/tracer.go index 235f752bac0a..6216bc48a8a7 100644 --- a/eth/tracers/tracer.go +++ b/eth/tracers/tracer.go @@ -27,6 +27,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/hexutil" + "github.com/XinFinOrg/XDPoSChain/core" "github.com/XinFinOrg/XDPoSChain/core/vm" "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/log" @@ -286,8 +287,6 @@ func (cw *contractWrapper) pushObject(vm *duktape.Context) { // Tracer provides an implementation of Tracer that evaluates a Javascript // function for each VM execution step. type Tracer struct { - inited bool // Flag whether the context was already inited from the EVM - vm *duktape.Context // Javascript VM instance tracerObject int // Stack index of the tracer JavaScript object @@ -526,7 +525,7 @@ func wrapError(context string, err error) error { } // CaptureStart implements the Tracer interface to initialize the tracing operation. -func (jst *Tracer) CaptureStart(from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) error { +func (jst *Tracer) CaptureStart(env *vm.EVM, from common.Address, to common.Address, create bool, input []byte, gas uint64, value *big.Int) { jst.ctx["type"] = "CALL" if create { jst.ctx["type"] = "CREATE" @@ -537,73 +536,75 @@ func (jst *Tracer) CaptureStart(from common.Address, to common.Address, create b jst.ctx["gas"] = gas jst.ctx["value"] = value - return nil + // Initialize the context + jst.ctx["block"] = env.Context.BlockNumber.Uint64() + jst.dbWrapper.db = env.StateDB + // Compute intrinsic gas + isHomestead := env.ChainConfig().IsHomestead(env.Context.BlockNumber) + intrinsicGas, err := core.IntrinsicGas(input, nil, jst.ctx["type"] == "CREATE", isHomestead) + if err != nil { + return + } + jst.ctx["intrinsicGas"] = intrinsicGas } + // CaptureState implements the Tracer interface to trace a single step of VM execution. -func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { - if jst.err == nil { - // Initialize the context if it wasn't done yet - if !jst.inited { - jst.ctx["block"] = env.BlockNumber.Uint64() - jst.inited = true - } - // If tracing was interrupted, set the error and stop - if atomic.LoadUint32(&jst.interrupt) > 0 { - jst.err = jst.reason - return nil - } - jst.opWrapper.op = op - jst.stackWrapper.stack = stack - jst.memoryWrapper.memory = memory - jst.contractWrapper.contract = contract - jst.dbWrapper.db = env.StateDB - - *jst.pcValue = uint(pc) - *jst.gasValue = uint(gas) - *jst.costValue = uint(cost) - *jst.depthValue = uint(depth) - *jst.refundValue = uint(env.StateDB.GetRefund()) - - jst.errorValue = nil - if err != nil { - jst.errorValue = new(string) - *jst.errorValue = err.Error() - } - _, err := jst.call("step", "log", "db") - if err != nil { - jst.err = wrapError("step", err) - } +func (jst *Tracer) CaptureState(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, rData []byte, depth int, err error) { + if jst.err != nil { + return + } + // If tracing was interrupted, set the error and stop + if atomic.LoadUint32(&jst.interrupt) > 0 { + jst.err = jst.reason + return + } + jst.opWrapper.op = op + jst.stackWrapper.stack = scope.Stack + jst.memoryWrapper.memory = scope.Memory + jst.contractWrapper.contract = scope.Contract + + *jst.pcValue = uint(pc) + *jst.gasValue = uint(gas) + *jst.costValue = uint(cost) + *jst.depthValue = uint(depth) + *jst.refundValue = uint(env.StateDB.GetRefund()) + + jst.errorValue = nil + if err != nil { + jst.errorValue = new(string) + *jst.errorValue = err.Error() + } + + if _, err := jst.call("step", "log", "db"); err != nil { + jst.err = wrapError("step", err) } - return nil } + // CaptureFault implements the Tracer interface to trace an execution fault -// while running an opcode. -func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, memory *vm.Memory, stack *vm.Stack, contract *vm.Contract, depth int, err error) error { - if jst.err == nil { - // Apart from the error, everything matches the previous invocation - jst.errorValue = new(string) - *jst.errorValue = err.Error() +func (jst *Tracer) CaptureFault(env *vm.EVM, pc uint64, op vm.OpCode, gas, cost uint64, scope *vm.ScopeContext, depth int, err error) { + if jst.err != nil { + return + } + // Apart from the error, everything matches the previous invocation + jst.errorValue = new(string) + *jst.errorValue = err.Error() - _, err := jst.call("fault", "log", "db") - if err != nil { - jst.err = wrapError("fault", err) - } + if _, err := jst.call("fault", "log", "db"); err != nil { + jst.err = wrapError("fault", err) } - return nil } // CaptureEnd is called after the call finishes to finalize the tracing. -func (jst *Tracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) error { +func (jst *Tracer) CaptureEnd(output []byte, gasUsed uint64, t time.Duration, err error) { jst.ctx["output"] = output - jst.ctx["gasUsed"] = gasUsed jst.ctx["time"] = t.String() + jst.ctx["gasUsed"] = gasUsed if err != nil { jst.ctx["error"] = err.Error() } - return nil } // GetResult calls the Javascript 'result' function and returns its value, or any accumulated error diff --git a/eth/tracers/tracer_test.go b/eth/tracers/tracer_test.go index 577fd1e576c3..e648973481e2 100644 --- a/eth/tracers/tracer_test.go +++ b/eth/tracers/tracer_test.go @@ -47,21 +47,150 @@ type dummyStatedb struct { state.StateDB } -func (*dummyStatedb) GetRefund() uint64 { return 1337 } +func (*dummyStatedb) GetRefund() uint64 { return 1337 } +func (*dummyStatedb) GetBalance(addr common.Address) *big.Int { return new(big.Int) } func runTrace(tracer *Tracer) (json.RawMessage, error) { env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) - + var ( + startGas uint64 = 10000 + value = big.NewInt(0) + ) contract := vm.NewContract(account{}, account{}, big.NewInt(0), 10000) contract.Code = []byte{byte(vm.PUSH1), 0x1, byte(vm.PUSH1), 0x1, 0x0} - _, err := env.Interpreter().Run(contract, []byte{}, false) + tracer.CaptureStart(env, contract.Caller(), contract.Address(), false, []byte{}, startGas, value) + ret, err := env.Interpreter().Run(contract, []byte{}, false) + tracer.CaptureEnd(ret, startGas-contract.Gas, 1, err) if err != nil { return nil, err } return tracer.GetResult() } +func TestTracer(t *testing.T) { + execTracer := func(code string) []byte { + t.Helper() + tracer, err := New(code) + if err != nil { + t.Fatal(err) + } + ret, err := runTrace(tracer) + if err != nil { + t.Fatal(err) + } + return ret + } + for i, tt := range []struct { + code string + want string + }{ + { // tests that we don't panic on bad arguments to memory access + code: "{depths: [], step: function(log) { this.depths.push(log.memory.slice(-1,-2)); }, fault: function() {}, result: function() { return this.depths; }}", + want: `[{},{},{}]`, + }, { // tests that we don't panic on bad arguments to stack peeks + code: "{depths: [], step: function(log) { this.depths.push(log.stack.peek(-1)); }, fault: function() {}, result: function() { return this.depths; }}", + want: `["0","0","0"]`, + }, { // tests that we don't panic on bad arguments to memory getUint + code: "{ depths: [], step: function(log, db) { this.depths.push(log.memory.getUint(-64));}, fault: function() {}, result: function() { return this.depths; }}", + want: `["0","0","0"]`, + }, { // tests some general counting + code: "{count: 0, step: function() { this.count += 1; }, fault: function() {}, result: function() { return this.count; }}", + want: `3`, + }, { // tests that depth is reported correctly + code: "{depths: [], step: function(log) { this.depths.push(log.stack.length()); }, fault: function() {}, result: function() { return this.depths; }}", + want: `[0,1,2]`, + }, { // tests to-string of opcodes + code: "{opcodes: [], step: function(log) { this.opcodes.push(log.op.toString()); }, fault: function() {}, result: function() { return this.opcodes; }}", + want: `["PUSH1","PUSH1","STOP"]`, + }, { // tests intrinsic gas + code: "{depths: [], step: function() {}, fault: function() {}, result: function(ctx) { return ctx.gasUsed+'.'+ctx.intrinsicGas; }}", + want: `"6.21000"`, + }, + } { + if have := execTracer(tt.code); tt.want != string(have) { + t.Errorf("testcase %d: expected return value to be %s got %s\n\tcode: %v", i, tt.want, string(have), tt.code) + } + } +} + +func TestHalt(t *testing.T) { + t.Skip("duktape doesn't support abortion") + + timeout := errors.New("stahp") + tracer, err := New("{step: function() { while(1); }, result: function() { return null; }}") + if err != nil { + t.Fatal(err) + } + + go func() { + time.Sleep(1 * time.Second) + tracer.Stop(timeout) + }() + + if _, err = runTrace(tracer); err.Error() != "stahp in server-side tracer function 'step'" { + t.Errorf("Expected timeout error, got %v", err) + } +} + +func TestHaltBetweenSteps(t *testing.T) { + tracer, err := New("{step: function() {}, fault: function() {}, result: function() { return null; }}") + if err != nil { + t.Fatal(err) + } + env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) + scope := &vm.ScopeContext{ + Contract: vm.NewContract(&account{}, &account{}, big.NewInt(0), 0), + } + + tracer.CaptureState(env, 0, 0, 0, 0, scope, nil, 0, nil) + timeout := errors.New("stahp") + tracer.Stop(timeout) + tracer.CaptureState(env, 0, 0, 0, 0, scope, nil, 0, nil) + + if _, err := tracer.GetResult(); err.Error() != timeout.Error() { + t.Errorf("Expected timeout error, got %v", err) + } +} + +// TestNoStepExec tests a regular value transfer (no exec), and accessing the statedb +// in 'result' +func TestNoStepExec(t *testing.T) { + runEmptyTrace := func(tracer *Tracer) (json.RawMessage, error) { + env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) + startGas := uint64(10000) + contract := vm.NewContract(account{}, account{}, big.NewInt(0), startGas) + tracer.CaptureStart(env, contract.Caller(), contract.Address(), false, []byte{}, startGas, big.NewInt(0)) + tracer.CaptureEnd(nil, startGas-contract.Gas, 1, nil) + return tracer.GetResult() + } + execTracer := func(code string) []byte { + t.Helper() + tracer, err := New(code) + if err != nil { + t.Fatal(err) + } + ret, err := runEmptyTrace(tracer) + if err != nil { + t.Fatal(err) + } + return ret + } + for i, tt := range []struct { + code string + want string + }{ + { // tests that we don't panic on accessing the db methods + code: "{depths: [], step: function() {}, fault: function() {}, result: function(ctx, db){ return db.getBalance(ctx.to)} }", + want: `"0"`, + }, + } { + if have := execTracer(tt.code); tt.want != string(have) { + t.Errorf("testcase %d: expected return value to be %s got %s\n\tcode: %v", i, tt.want, string(have), tt.code) + } + } +} + // TestRegressionPanicSlice tests that we don't panic on bad arguments to memory access func TestRegressionPanicSlice(t *testing.T) { tracer, err := New("{depths: [], step: function(log) { this.depths.push(log.memory.slice(-1,-2)); }, fault: function() {}, result: function() { return this.depths; }}") @@ -139,40 +268,3 @@ func TestOpcodes(t *testing.T) { t.Errorf("Expected return value to be [\"PUSH1\",\"PUSH1\",\"STOP\"], got %s", string(ret)) } } - -func TestHalt(t *testing.T) { - t.Skip("duktape doesn't support abortion") - - timeout := errors.New("stahp") - tracer, err := New("{step: function() { while(1); }, result: function() { return null; }}") - if err != nil { - t.Fatal(err) - } - - go func() { - time.Sleep(1 * time.Second) - tracer.Stop(timeout) - }() - - if _, err = runTrace(tracer); err.Error() != "stahp in server-side tracer function 'step'" { - t.Errorf("Expected timeout error, got %v", err) - } -} - -func TestHaltBetweenSteps(t *testing.T) { - tracer, err := New("{step: function() {}, fault: function() {}, result: function() { return null; }}") - if err != nil { - t.Fatal(err) - } - env := vm.NewEVM(vm.Context{BlockNumber: big.NewInt(1)}, &dummyStatedb{}, nil, params.TestChainConfig, vm.Config{Debug: true, Tracer: tracer}) - contract := vm.NewContract(&account{}, &account{}, big.NewInt(0), 0) - - tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, contract, 0, nil) - timeout := errors.New("stahp") - tracer.Stop(timeout) - tracer.CaptureState(env, 0, 0, 0, 0, nil, nil, contract, 0, nil) - - if _, err := tracer.GetResult(); err.Error() != timeout.Error() { - t.Errorf("Expected timeout error, got %v", err) - } -} diff --git a/ethclient/ethclient.go b/ethclient/ethclient.go index 745f5656a669..22dc3a0fd852 100644 --- a/ethclient/ethclient.go +++ b/ethclient/ethclient.go @@ -70,6 +70,16 @@ func (ec *Client) BlockByNumber(ctx context.Context, number *big.Int) (*types.Bl return ec.getBlock(ctx, "eth_getBlockByNumber", toBlockNumArg(number), true) } +// BlockReceipts returns the receipts of a given block number or hash +func (ec *Client) BlockReceipts(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) ([]*types.Receipt, error) { + var r []*types.Receipt + err := ec.c.CallContext(ctx, &r, "eth_getBlockReceipts", blockNrOrHash) + if err == nil && r == nil { + return nil, ethereum.NotFound + } + return r, err +} + type rpcBlock struct { Hash common.Hash `json:"hash"` Transactions []rpcTransaction `json:"transactions"` @@ -479,7 +489,7 @@ func (ec *Client) EstimateGas(ctx context.Context, msg ethereum.CallMsg) (uint64 // If the transaction was a contract creation use the TransactionReceipt method to get the // contract address after the transaction has been mined. func (ec *Client) SendTransaction(ctx context.Context, tx *types.Transaction) error { - data, err := rlp.EncodeToBytes(tx) + data, err := tx.MarshalBinary() if err != nil { return err } diff --git a/ethclient/signer.go b/ethclient/signer.go index d41bd379f4bc..1db03a9c9595 100644 --- a/ethclient/signer.go +++ b/ethclient/signer.go @@ -51,9 +51,14 @@ func (s *senderFromServer) Sender(tx *types.Transaction) (common.Address, error) return s.addr, nil } +func (s *senderFromServer) ChainID() *big.Int { + panic("can't sign with senderFromServer") +} + func (s *senderFromServer) Hash(tx *types.Transaction) common.Hash { panic("can't sign with senderFromServer") } + func (s *senderFromServer) SignatureValues(tx *types.Transaction, sig []byte) (R, S, V *big.Int, err error) { panic("can't sign with senderFromServer") } diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index 08d741017151..57f26188d228 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -51,7 +51,7 @@ const ( // history request. historyUpdateRange = 50 - // txChanSize is the size of channel listening to TxPreEvent. + // txChanSize is the size of channel listening to NewTxsEvent. // The number is referenced from the size of tx pool. txChanSize = 4096 // chainHeadChanSize is the size of channel listening to ChainHeadEvent. @@ -65,9 +65,9 @@ type consensusEngine interface { } type txPool interface { - // SubscribeTxPreEvent should return an event subscription of - // TxPreEvent and send events to the given channel. - SubscribeTxPreEvent(chan<- core.TxPreEvent) event.Subscription + // SubscribeNewTxsEvent should return an event subscription of + // NewTxsEvent and send events to the given channel. + SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription } type blockChain interface { @@ -208,8 +208,8 @@ func (s *Service) loop() { headSub := blockchain.SubscribeChainHeadEvent(chainHeadCh) defer headSub.Unsubscribe() - txEventCh := make(chan core.TxPreEvent, txChanSize) - txSub := txpool.SubscribeTxPreEvent(txEventCh) + txEventCh := make(chan core.NewTxsEvent, txChanSize) + txSub := txpool.SubscribeNewTxsEvent(txEventCh) defer txSub.Unsubscribe() // Forensics events diff --git a/go.mod b/go.mod index 8d19bc4956fc..cbea7de42b43 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/davecgh/go-spew v1.1.1 github.com/docker/docker v1.4.2-0.20180625184442-8e610b2b55bf github.com/edsrzf/mmap-go v1.0.0 - github.com/ethereum/go-ethereum v1.9.11 github.com/fatih/color v1.13.0 github.com/gizak/termui v2.2.0+incompatible github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 @@ -44,7 +43,6 @@ require ( golang.org/x/sys v0.14.0 golang.org/x/tools v0.14.0 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c - gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951 gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce gopkg.in/olebedev/go-duktape.v3 v3.0.0-20190213234257-ec84240a7772 gopkg.in/urfave/cli.v1 v1.20.0 @@ -60,6 +58,7 @@ require ( github.com/elastic/gosigar v0.8.1-0.20180330100440-37f05ff46ffa // indirect github.com/go-ole/go-ole v1.2.5 // indirect github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/google/pprof v0.0.0-20230207041349-798e818bf904 // indirect github.com/google/uuid v1.3.0 // indirect github.com/kr/pretty v0.3.1 // indirect diff --git a/go.sum b/go.sum index 19df00ede131..81851047d213 100644 --- a/go.sum +++ b/go.sum @@ -74,8 +74,6 @@ github.com/edsrzf/mmap-go v1.0.0 h1:CEBF7HpRnUCSJgGUb5h1Gm7e3VkmVDrR8lvWVLtrOFw= github.com/edsrzf/mmap-go v1.0.0/go.mod h1:YO35OhQPt3KJa3ryjFM5Bs14WD66h8eGKpfaBNrHW5M= github.com/elastic/gosigar v0.8.1-0.20180330100440-37f05ff46ffa h1:XKAhUk/dtp+CV0VO6mhG2V7jA9vbcGcnYF/Ay9NjZrY= github.com/elastic/gosigar v0.8.1-0.20180330100440-37f05ff46ffa/go.mod h1:cdorVVzy1fhmEqmtgqkoE3bYtCfSCkVyjTyCIo22xvs= -github.com/ethereum/go-ethereum v1.9.11 h1:Z0jugPDfuI5qsPY1XgBGVwikpdFK/ANqP7MrYvkmk+A= -github.com/ethereum/go-ethereum v1.9.11/go.mod h1:7oC0Ni6dosMv5pxMigm6s0hN8g4haJMBnqmmo0D9YfQ= github.com/ethereum/go-ethereum v1.13.5 h1:U6TCRciCqZRe4FPXmy1sMGxTfuk8P7u2UoinF3VbaFk= github.com/ethereum/go-ethereum v1.13.5/go.mod h1:yMTu38GSuyxaYzQMViqNmQ1s3cE84abZexQmTgenWk0= github.com/fatih/color v1.3.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -371,8 +369,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= -gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951 h1:DMTcQRFbEH62YPRWwOI647s2e5mHda3oBPMHfrLs2bw= -gopkg.in/karalabe/cookiejar.v2 v2.0.0-20150724131613-8dcd6a7f4951/go.mod h1:owOxCRGGeAx1uugABik6K9oeNu1cgxP/R9ItzLDxNWA= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce h1:+JknDZhAj8YMt7GC73Ei8pv4MzjDUNPHgQWJdtMAaDU= gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce/go.mod h1:5AcXVHNjg+BDxry382+8OKon8SEWiKktQR07RKPsv1c= gopkg.in/olebedev/go-duktape.v3 v3.0.0-20190213234257-ec84240a7772 h1:hhsSf/5z74Ck/DJYc+R8zpq8KGm7uJvpdLRQED/IedA= diff --git a/interfaces.go b/interfaces.go index 88cf2fcb0c37..dfd96b217a1a 100644 --- a/interfaces.go +++ b/interfaces.go @@ -120,6 +120,7 @@ type CallMsg struct { Value *big.Int // amount of wei sent along with the call Data []byte // input data, usually an ABI-encoded contract method invocation BalanceTokenFee *big.Int + AccessList types.AccessList // EIP-2930 access list. } // A ContractCaller provides contract calls, essentially transactions that are executed by diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go index d0b1f15e4315..2d93455f662f 100644 --- a/internal/ethapi/api.go +++ b/internal/ethapi/api.go @@ -393,7 +393,7 @@ func (s *PrivateAccountAPI) SendTransaction(ctx context.Context, args SendTxArgs if err != nil { return common.Hash{}, err } - return submitTransaction(ctx, s.b, signed) + return SubmitTransaction(ctx, s.b, signed) } // SignTransaction will create a transaction from the given arguments and @@ -414,9 +414,10 @@ func (s *PrivateAccountAPI) SignTransaction(ctx context.Context, args SendTxArgs } signed, err := s.signTransaction(ctx, args, passwd) if err != nil { + log.Warn("Failed transaction sign attempt", "from", args.From, "to", args.To, "value", args.Value.ToInt(), "err", err) return nil, err } - data, err := rlp.EncodeToBytes(signed) + data, err := signed.MarshalBinary() if err != nil { return nil, err } @@ -473,21 +474,19 @@ func (s *PrivateAccountAPI) Sign(ctx context.Context, data hexutil.Bytes, addr c // // https://github.com/XinFinOrg/XDPoSChain/wiki/Management-APIs#personal_ecRecover func (s *PrivateAccountAPI) EcRecover(ctx context.Context, data, sig hexutil.Bytes) (common.Address, error) { - if len(sig) != 65 { - return common.Address{}, fmt.Errorf("signature must be 65 bytes long") + if len(sig) != crypto.SignatureLength { + return common.Address{}, fmt.Errorf("signature must be %d bytes long", crypto.SignatureLength) } - if sig[64] != 27 && sig[64] != 28 { + if sig[crypto.RecoveryIDOffset] != 27 && sig[crypto.RecoveryIDOffset] != 28 { return common.Address{}, fmt.Errorf("invalid Ethereum signature (V is not 27 or 28)") } - sig[64] -= 27 // Transform yellow paper V from 27/28 to 0/1 + sig[crypto.RecoveryIDOffset] -= 27 // Transform yellow paper V from 27/28 to 0/1 - rpk, err := crypto.Ecrecover(signHash(data), sig) + rpk, err := crypto.SigToPub(accounts.TextHash(data), sig) if err != nil { return common.Address{}, err } - pubKey := crypto.ToECDSAPub(rpk) - recoveredAddr := crypto.PubkeyToAddress(*pubKey) - return recoveredAddr, nil + return crypto.PubkeyToAddress(*rpk), nil } // SignAndSendTransaction was renamed to SendTransaction. This method is deprecated @@ -533,6 +532,49 @@ func (s *PublicBlockChainAPI) GetBalance(ctx context.Context, address common.Add return (*hexutil.Big)(state.GetBalance(address)), state.Error() } +// GetTransactionAndReceiptProof returns the Trie transaction and receipt proof of the given transaction hash. +func (s *PublicBlockChainAPI) GetTransactionAndReceiptProof(ctx context.Context, hash common.Hash) (map[string]interface{}, error) { + tx, blockHash, _, index := core.GetTransaction(s.b.ChainDb(), hash) + if tx == nil { + return nil, nil + } + block, err := s.b.GetBlock(ctx, blockHash) + if err != nil { + return nil, err + } + tx_tr := deriveTrie(block.Transactions()) + + keybuf := new(bytes.Buffer) + rlp.Encode(keybuf, uint(index)) + var tx_proof proofPairList + if err := tx_tr.Prove(keybuf.Bytes(), 0, &tx_proof); err != nil { + return nil, err + } + receipts, err := s.b.GetReceipts(ctx, blockHash) + if err != nil { + return nil, err + } + if len(receipts) <= int(index) { + return nil, nil + } + receipt_tr := deriveTrie(receipts) + var receipt_proof proofPairList + if err := receipt_tr.Prove(keybuf.Bytes(), 0, &receipt_proof); err != nil { + return nil, err + } + fields := map[string]interface{}{ + "blockHash": blockHash, + "txRoot": tx_tr.Hash(), + "receiptRoot": receipt_tr.Hash(), + "key": hexutil.Encode(keybuf.Bytes()), + "txProofKeys": tx_proof.keys, + "txProofValues": tx_proof.values, + "receiptProofKeys": receipt_proof.keys, + "receiptProofValues": receipt_proof.values, + } + return fields, nil +} + // GetBlockByNumber returns the requested block. When blockNr is -1 the chain head is returned. When fullTx is true all // transactions in the block are returned in full detail, otherwise only the transaction hash is returned. func (s *PublicBlockChainAPI) GetBlockByNumber(ctx context.Context, blockNr rpc.BlockNumber, fullTx bool) (map[string]interface{}, error) { @@ -653,6 +695,86 @@ func (s *PublicBlockChainAPI) GetStorageAt(ctx context.Context, address common.A return res[:], state.Error() } +// GetBlockReceipts returns the block receipts for the given block hash or number or tag. +func (s *PublicBlockChainAPI) GetBlockReceipts(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) ([]map[string]interface{}, error) { + block, err := s.b.BlockByNumberOrHash(ctx, blockNrOrHash) + if block == nil || err != nil { + // When the block doesn't exist, the RPC method should return JSON null + // as per specification. + return nil, nil + } + receipts, err := s.b.GetReceipts(ctx, block.Hash()) + if err != nil { + return nil, err + } + txs := block.Transactions() + if len(txs) != len(receipts) { + return nil, fmt.Errorf("receipts length mismatch: %d vs %d", len(txs), len(receipts)) + } + + // Derive the sender. + signer := types.MakeSigner(s.b.ChainConfig(), block.Number()) + + result := make([]map[string]interface{}, len(receipts)) + for i, receipt := range receipts { + result[i] = marshalReceipt(receipt, block.Hash(), block.NumberU64(), signer, txs[i], i) + } + + return result, nil +} + +// OverrideAccount indicates the overriding fields of account during the execution +// of a message call. +// Note, state and stateDiff can't be specified at the same time. If state is +// set, message execution will only use the data in the given state. Otherwise +// if statDiff is set, all diff will be applied first and then execute the call +// message. +type OverrideAccount struct { + Nonce *hexutil.Uint64 `json:"nonce"` + Code *hexutil.Bytes `json:"code"` + Balance **hexutil.Big `json:"balance"` + State *map[common.Hash]common.Hash `json:"state"` + StateDiff *map[common.Hash]common.Hash `json:"stateDiff"` +} + +// StateOverride is the collection of overridden accounts. +type StateOverride map[common.Address]OverrideAccount + +// Apply overrides the fields of specified accounts into the given state. +func (diff *StateOverride) Apply(state *state.StateDB) error { + if diff == nil { + return nil + } + for addr, account := range *diff { + // Override account nonce. + if account.Nonce != nil { + state.SetNonce(addr, uint64(*account.Nonce)) + } + // Override account(contract) code. + if account.Code != nil { + state.SetCode(addr, *account.Code) + } + // Override account balance. + if account.Balance != nil { + state.SetBalance(addr, (*big.Int)(*account.Balance)) + } + if account.State != nil && account.StateDiff != nil { + return fmt.Errorf("account %s has both 'state' and 'stateDiff'", addr.Hex()) + } + // Replace entire state if caller requires. + if account.State != nil { + state.SetStorage(addr, *account.State) + } + // Apply state diff into specified accounts. + if account.StateDiff != nil { + for key, value := range *account.StateDiff { + state.SetState(addr, key, value) + } + } + } + return nil +} + func (s *PublicBlockChainAPI) GetBlockSignersByHash(ctx context.Context, blockHash common.Hash) ([]common.Address, error) { block, err := s.b.GetBlock(ctx, blockHash) if err != nil || block == nil { @@ -1105,42 +1227,85 @@ func (s *PublicBlockChainAPI) getCandidatesFromSmartContract() ([]utils.Masterno // CallArgs represents the arguments for a call. type CallArgs struct { - From common.Address `json:"from"` - To *common.Address `json:"to"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice hexutil.Big `json:"gasPrice"` - Value hexutil.Big `json:"value"` - Data hexutil.Bytes `json:"data"` -} - -func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, vmCfg vm.Config, timeout time.Duration) ([]byte, uint64, bool, error, error) { - defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) - - statedb, header, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) - if statedb == nil || err != nil { - return nil, 0, false, err, nil - } + From *common.Address `json:"from"` + To *common.Address `json:"to"` + Gas *hexutil.Uint64 `json:"gas"` + GasPrice *hexutil.Big `json:"gasPrice"` + Value *hexutil.Big `json:"value"` + Data *hexutil.Bytes `json:"data"` + AccessList *types.AccessList `json:"accessList"` +} + +// ToMessage converts CallArgs to the Message type used by the core evm +// TODO: set balanceTokenFee +func (args *CallArgs) ToMessage(b Backend, number *big.Int, globalGasCap uint64) types.Message { // Set sender address or use a default if none specified - addr := args.From - if addr == (common.Address{}) { - if wallets := s.b.AccountManager().Wallets(); len(wallets) > 0 { + var addr common.Address + if args.From == nil || *args.From == (common.Address{}) { + if wallets := b.AccountManager().Wallets(); len(wallets) > 0 { if accounts := wallets[0].Accounts(); len(accounts) > 0 { addr = accounts[0].Address } } + } else { + addr = *args.From } + // Set default gas & gas price if none were set - gas, gasPrice := uint64(args.Gas), args.GasPrice.ToInt() + gas := globalGasCap if gas == 0 { - gas = math.MaxUint64 / 2 + gas = uint64(math.MaxUint64 / 2) + } + if args.Gas != nil { + gas = uint64(*args.Gas) } - if gasPrice.Sign() == 0 { + if globalGasCap != 0 && globalGasCap < gas { + log.Warn("Caller gas above allowance, capping", "requested", gas, "cap", globalGasCap) + gas = globalGasCap + } + gasPrice := new(big.Int) + if args.GasPrice != nil { + gasPrice = args.GasPrice.ToInt() + } + if gasPrice.Sign() <= 0 { gasPrice = new(big.Int).SetUint64(defaultGasPrice) } + + value := new(big.Int) + if args.Value != nil { + value = args.Value.ToInt() + } + + var data []byte + if args.Data != nil { + data = *args.Data + } + + var accessList types.AccessList + if args.AccessList != nil { + accessList = *args.AccessList + } + balanceTokenFee := big.NewInt(0).SetUint64(gas) balanceTokenFee = balanceTokenFee.Mul(balanceTokenFee, gasPrice) + // Create new call message - msg := types.NewMessage(addr, args.To, 0, args.Value.ToInt(), gas, gasPrice, args.Data, false, balanceTokenFee, header.Number) + msg := types.NewMessage(addr, args.To, 0, value, gas, gasPrice, data, accessList, false, balanceTokenFee, number) + return msg +} + +func DoCall(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, vmCfg vm.Config, timeout time.Duration, globalGasCap uint64) ([]byte, uint64, bool, error, error) { + defer func(start time.Time) { log.Debug("Executing EVM call finished", "runtime", time.Since(start)) }(time.Now()) + + statedb, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if statedb == nil || err != nil { + return nil, 0, false, err, nil + } + if err := overrides.Apply(statedb); err != nil { + return nil, 0, false, err, nil + } + + msg := args.ToMessage(b, header.Number, globalGasCap) // Setup context so it may be cancelled the call has completed // or, in case of unmetered gas, setup a context with a timeout. @@ -1154,20 +1319,21 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr // this makes sure resources are cleaned up. defer cancel() - block, err := s.b.BlockByNumberOrHash(ctx, blockNrOrHash) + block, err := b.BlockByHash(ctx, header.Hash()) if err != nil { return nil, 0, false, err, nil } - author, err := s.b.GetEngine().Author(block.Header()) + author, err := b.GetEngine().Author(block.Header()) if err != nil { return nil, 0, false, err, nil } - XDCxState, err := s.b.XDCxService().GetTradingState(block, author) + XDCxState, err := b.XDCxService().GetTradingState(block, author) if err != nil { return nil, 0, false, err, nil } + // Get a new instance of the EVM. - evm, vmError, err := s.b.GetEVM(ctx, msg, statedb, XDCxState, header, vmCfg) + evm, vmError, err := b.GetEVM(ctx, msg, statedb, XDCxState, header, &vmCfg) if err != nil { return nil, 0, false, err, nil } @@ -1178,8 +1344,7 @@ func (s *PublicBlockChainAPI) doCall(ctx context.Context, args CallArgs, blockNr evm.Cancel() }() - // Setup the gas pool (also for unmetered requests) - // and apply the message. + // Execute the message. gp := new(core.GasPool).AddGas(math.MaxUint64) owner := common.Address{} res, gas, failed, err, vmErr := core.ApplyMessage(evm, msg, gp, owner) @@ -1229,12 +1394,12 @@ func (e *revertError) ErrorData() interface{} { // Call executes the given transaction on the state for the given block number. // It doesn't make and changes in the state/blockchain and is useful to execute and retrieve values. -func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash) (hexutil.Bytes, error) { +func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Bytes, error) { if blockNrOrHash == nil { latest := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) blockNrOrHash = &latest } - result, _, failed, err, vmErr := s.doCall(ctx, args, *blockNrOrHash, vm.Config{}, 5*time.Second) + result, _, failed, err, vmErr := DoCall(ctx, s.b, args, *blockNrOrHash, overrides, vm.Config{}, 5*time.Second, s.b.RPCGasCap()) if err != nil { return nil, err } @@ -1246,36 +1411,51 @@ func (s *PublicBlockChainAPI) Call(ctx context.Context, args CallArgs, blockNrOr return (hexutil.Bytes)(result), vmErr } -func (s *PublicBlockChainAPI) doEstimateGas(ctx context.Context, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Uint64, error) { +func DoEstimateGas(ctx context.Context, b Backend, args CallArgs, blockNrOrHash rpc.BlockNumberOrHash, overrides *StateOverride, gasCap uint64) (hexutil.Uint64, error) { // Retrieve the base state and mutate it with any overrides - state, _, err := s.b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + state, _, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) if state == nil || err != nil { return 0, err } - + if err = overrides.Apply(state); err != nil { + return 0, err + } // Binary search the gas requirement, as it may be higher than the amount used var ( lo uint64 = params.TxGas - 1 hi uint64 cap uint64 ) - if uint64(args.Gas) >= params.TxGas { - hi = uint64(args.Gas) + // Use zero address if sender unspecified. + if args.From == nil { + args.From = new(common.Address) + } + // Determine the highest gas limit can be used during the estimation. + if args.Gas != nil && uint64(*args.Gas) >= params.TxGas { + hi = uint64(*args.Gas) } else { // Retrieve the current pending block to act as the gas ceiling - block, err := s.b.BlockByNumber(ctx, rpc.LatestBlockNumber) + block, err := b.BlockByNumberOrHash(ctx, blockNrOrHash) if err != nil { return 0, err } + if block == nil { + return 0, errors.New("block not found") + } hi = block.GasLimit() } + // Recap the highest gas allowance with specified gascap. + if gasCap != 0 && hi > gasCap { + log.Warn("Caller gas above allowance, capping", "requested", hi, "cap", gasCap) + hi = gasCap + } cap = hi // Create a helper to check if a gas allowance results in an executable transaction executable := func(gas uint64) (bool, []byte, error, error) { - args.Gas = hexutil.Uint64(gas) + args.Gas = (*hexutil.Uint64)(&gas) - res, _, failed, err, vmErr := s.doCall(ctx, args, blockNrOrHash, vm.Config{}, 0) + res, _, failed, err, vmErr := DoCall(ctx, b, args, blockNrOrHash, nil, vm.Config{}, 0, gasCap) if err != nil { if errors.Is(err, vm.ErrOutOfGas) || errors.Is(err, core.ErrIntrinsicGas) { return false, nil, nil, nil // Special case, raise gas limit @@ -1293,13 +1473,8 @@ func (s *PublicBlockChainAPI) doEstimateGas(ctx context.Context, args CallArgs, // directly try 21000. Returning 21000 without any execution is dangerous as // some tx field combos might bump the price up even for plain transfers (e.g. // unused access list items). Ever so slightly wasteful, but safer overall. - if len(args.Data) == 0 && args.To != nil { - statedb, _, err := s.b.StateAndHeaderByNumber(ctx, rpc.LatestBlockNumber) - if statedb == nil || err != nil { - return 0, err - } - - if statedb.GetCodeSize(*args.To) == 0 { + if args.Data == nil || len(*args.Data) == 0 { + if args.To != nil && state.GetCodeSize(*args.To) == 0 { ok, _, err, _ := executable(params.TxGas) if ok && err == nil { return hexutil.Uint64(params.TxGas), nil @@ -1350,12 +1525,12 @@ func (s *PublicBlockChainAPI) doEstimateGas(ctx context.Context, args CallArgs, // EstimateGas returns an estimate of the amount of gas needed to execute the // given transaction against the current pending block. -func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash) (hexutil.Uint64, error) { +func (s *PublicBlockChainAPI) EstimateGas(ctx context.Context, args CallArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) { bNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.LatestBlockNumber) if blockNrOrHash != nil { bNrOrHash = *blockNrOrHash } - return s.doEstimateGas(ctx, args, bNrOrHash) + return DoEstimateGas(ctx, s.b, args, bNrOrHash, overrides, s.b.RPCGasCap()) } // ExecutionResult groups all structured logs emitted by the EVM @@ -1629,33 +1804,43 @@ func (s *PublicBlockChainAPI) rpcOutputBlockSigners(b *types.Block, ctx context. // RPCTransaction represents a transaction that will serialize to the RPC representation of a transaction type RPCTransaction struct { - BlockHash common.Hash `json:"blockHash"` - BlockNumber *hexutil.Big `json:"blockNumber"` - From common.Address `json:"from"` - Gas hexutil.Uint64 `json:"gas"` - GasPrice *hexutil.Big `json:"gasPrice"` - Hash common.Hash `json:"hash"` - Input hexutil.Bytes `json:"input"` - Nonce hexutil.Uint64 `json:"nonce"` - To *common.Address `json:"to"` - TransactionIndex hexutil.Uint `json:"transactionIndex"` - Value *hexutil.Big `json:"value"` - V *hexutil.Big `json:"v"` - R *hexutil.Big `json:"r"` - S *hexutil.Big `json:"s"` + BlockHash *common.Hash `json:"blockHash"` + BlockNumber *hexutil.Big `json:"blockNumber"` + From common.Address `json:"from"` + Gas hexutil.Uint64 `json:"gas"` + GasPrice *hexutil.Big `json:"gasPrice"` + Hash common.Hash `json:"hash"` + Input hexutil.Bytes `json:"input"` + Nonce hexutil.Uint64 `json:"nonce"` + To *common.Address `json:"to"` + TransactionIndex *hexutil.Uint64 `json:"transactionIndex"` + Value *hexutil.Big `json:"value"` + Type hexutil.Uint64 `json:"type"` + Accesses *types.AccessList `json:"accessList,omitempty"` + ChainID *hexutil.Big `json:"chainId,omitempty"` + V *hexutil.Big `json:"v"` + R *hexutil.Big `json:"r"` + S *hexutil.Big `json:"s"` } // newRPCTransaction returns a transaction that will serialize to the RPC // representation, with the given location metadata set (if available). func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber uint64, index uint64) *RPCTransaction { - var signer types.Signer = types.FrontierSigner{} + // Determine the signer. For replay-protected transactions, use the most permissive + // signer, because we assume that signers are backwards-compatible with old + // transactions. For non-protected transactions, the homestead signer signer is used + // because the return value of ChainId is zero for those transactions. + var signer types.Signer if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) + signer = types.LatestSignerForChainID(tx.ChainId()) + } else { + signer = types.HomesteadSigner{} } + from, _ := types.Sender(signer, tx) v, r, s := tx.RawSignatureValues() - result := &RPCTransaction{ + Type: hexutil.Uint64(tx.Type()), From: from, Gas: hexutil.Uint64(tx.Gas()), GasPrice: (*hexutil.Big)(tx.GasPrice()), @@ -1669,9 +1854,14 @@ func newRPCTransaction(tx *types.Transaction, blockHash common.Hash, blockNumber S: (*hexutil.Big)(s), } if blockHash != (common.Hash{}) { - result.BlockHash = blockHash + result.BlockHash = &blockHash result.BlockNumber = (*hexutil.Big)(new(big.Int).SetUint64(blockNumber)) - result.TransactionIndex = hexutil.Uint(index) + result.TransactionIndex = (*hexutil.Uint64)(&index) + } + if tx.Type() == types.AccessListTxType { + al := tx.AccessList() + result.Accesses = &al + result.ChainID = (*hexutil.Big)(tx.ChainId()) } return result } @@ -1696,7 +1886,7 @@ func newRPCRawTransactionFromBlockIndex(b *types.Block, index uint64) hexutil.By if index >= uint64(len(txs)) { return nil } - blob, _ := rlp.EncodeToBytes(txs[index]) + blob, _ := txs[index].MarshalBinary() return blob } @@ -1710,10 +1900,131 @@ func newRPCTransactionFromBlockHash(b *types.Block, hash common.Hash) *RPCTransa return nil } +// accessListResult returns an optional accesslist +// Its the result of the `debug_createAccessList` RPC call. +// It contains an error if the transaction itself failed. +type accessListResult struct { + Accesslist *types.AccessList `json:"accessList"` + Error string `json:"error,omitempty"` + GasUsed hexutil.Uint64 `json:"gasUsed"` +} + +// CreateAccessList creates a EIP-2930 type AccessList for the given transaction. +// Reexec and BlockNrOrHash can be specified to create the accessList on top of a certain state. +func (s *PublicBlockChainAPI) CreateAccessList(ctx context.Context, args SendTxArgs, blockNrOrHash *rpc.BlockNumberOrHash) (*accessListResult, error) { + bNrOrHash := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + if blockNrOrHash != nil { + bNrOrHash = *blockNrOrHash + } + acl, gasUsed, vmerr, err := AccessList(ctx, s.b, bNrOrHash, args) + if err != nil { + return nil, err + } + result := &accessListResult{Accesslist: &acl, GasUsed: hexutil.Uint64(gasUsed)} + if vmerr != nil { + result.Error = vmerr.Error() + } + return result, nil +} + +// AccessList creates an access list for the given transaction. +// If the accesslist creation fails an error is returned. +// If the transaction itself fails, an vmErr is returned. +func AccessList(ctx context.Context, b Backend, blockNrOrHash rpc.BlockNumberOrHash, args SendTxArgs) (acl types.AccessList, gasUsed uint64, vmErr error, err error) { + // Retrieve the execution context + db, header, err := b.StateAndHeaderByNumberOrHash(ctx, blockNrOrHash) + if db == nil || err != nil { + return nil, 0, nil, err + } + block, err := b.BlockByHash(ctx, header.Hash()) + if err != nil { + return nil, 0, nil, err + } + author, err := b.GetEngine().Author(block.Header()) + if err != nil { + return nil, 0, nil, err + } + XDCxState, err := b.XDCxService().GetTradingState(block, author) + if err != nil { + return nil, 0, nil, err + } + owner := common.Address{} + + // If the gas amount is not set, extract this as it will depend on access + // lists and we'll need to reestimate every time + nogas := args.Gas == nil + + // Ensure any missing fields are filled, extract the recipient and input data + if err := args.setDefaults(ctx, b); err != nil { + return nil, 0, nil, err + } + var to common.Address + if args.To != nil { + to = *args.To + } else { + to = crypto.CreateAddress(args.From, uint64(*args.Nonce)) + } + var input []byte + if args.Input != nil { + input = *args.Input + } else if args.Data != nil { + input = *args.Data + } + // Retrieve the precompiles since they don't need to be added to the access list + precompiles := vm.ActivePrecompiles(b.ChainConfig().Rules(header.Number)) + + // Create an initial tracer + prevTracer := vm.NewAccessListTracer(nil, args.From, to, precompiles) + if args.AccessList != nil { + prevTracer = vm.NewAccessListTracer(*args.AccessList, args.From, to, precompiles) + } + for { + // Retrieve the current access list to expand + accessList := prevTracer.AccessList() + log.Trace("Creating access list", "input", accessList) + + // If no gas amount was specified, each unique access list needs it's own + // gas calculation. This is quite expensive, but we need to be accurate + // and it's convered by the sender only anyway. + if nogas { + args.Gas = nil + if err := args.setDefaults(ctx, b); err != nil { + return nil, 0, nil, err // shouldn't happen, just in case + } + } + // Copy the original db so we don't modify it + statedb := db.Copy() + feeCapacity := state.GetTRC21FeeCapacityFromState(statedb) + var balanceTokenFee *big.Int + if value, ok := feeCapacity[to]; ok { + balanceTokenFee = value + } + msg := types.NewMessage(args.From, args.To, uint64(*args.Nonce), args.Value.ToInt(), uint64(*args.Gas), args.GasPrice.ToInt(), input, accessList, false, balanceTokenFee, header.Number) + + // Apply the transaction with the access list tracer + tracer := vm.NewAccessListTracer(accessList, args.From, to, precompiles) + config := vm.Config{Tracer: tracer, Debug: true} + vmenv, _, err := b.GetEVM(ctx, msg, statedb, XDCxState, header, &config) + if err != nil { + return nil, 0, nil, err + } + // TODO: determine the value of owner + _, UsedGas, _, err, vmErr := core.ApplyMessage(vmenv, msg, new(core.GasPool).AddGas(msg.Gas()), owner) + if err != nil { + return nil, 0, nil, fmt.Errorf("failed to apply transaction: %v err: %v", args.toTransaction().Hash(), err) + } + if tracer.Equal(prevTracer) { + return accessList, UsedGas, vmErr, nil + } + prevTracer = tracer + } +} + // PublicTransactionPoolAPI exposes methods for the RPC interface type PublicTransactionPoolAPI struct { b Backend nonceLock *AddrLocker + signer types.Signer } // PublicTransactionPoolAPI exposes methods for the RPC interface @@ -1724,7 +2035,10 @@ type PublicXDCXTransactionPoolAPI struct { // NewPublicTransactionPoolAPI creates a new RPC service with methods specific for the transaction pool. func NewPublicTransactionPoolAPI(b Backend, nonceLock *AddrLocker) *PublicTransactionPoolAPI { - return &PublicTransactionPoolAPI{b, nonceLock} + // The signer used by the API should always be the 'latest' known one because we expect + // signers to be backwards-compatible with old transactions. + signer := types.LatestSigner(b.ChainConfig()) + return &PublicTransactionPoolAPI{b, nonceLock, signer} } // NewPublicTransactionPoolAPI creates a new RPC service with methods specific for the transaction pool. @@ -1817,17 +2131,16 @@ func (s *PublicTransactionPoolAPI) GetTransactionByHash(ctx context.Context, has // GetRawTransactionByHash returns the bytes of the transaction for the given hash. func (s *PublicTransactionPoolAPI) GetRawTransactionByHash(ctx context.Context, hash common.Hash) (hexutil.Bytes, error) { - var tx *types.Transaction - // Retrieve a finalized transaction, or a pooled otherwise - if tx, _, _, _ = core.GetTransaction(s.b.ChainDb(), hash); tx == nil { + tx, _, _, _ := core.GetTransaction(s.b.ChainDb(), hash) + if tx == nil { if tx = s.b.GetPoolTransaction(hash); tx == nil { // Transaction not found anywhere, abort return nil, nil } } // Serialize to RLP and return - return rlp.EncodeToBytes(tx) + return tx.MarshalBinary() } // GetTransactionReceipt returns the transaction receipt for the given transaction hash. @@ -1845,10 +2158,9 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha } receipt := receipts[index] - var signer types.Signer = types.FrontierSigner{} - if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) - } + // Derive the sender. + bigblock := new(big.Int).SetUint64(blockNumber) + signer := types.MakeSigner(s.b.ChainConfig(), bigblock) from, _ := types.Sender(signer, tx) fields := map[string]interface{}{ @@ -1863,6 +2175,7 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha "contractAddress": nil, "logs": receipt.Logs, "logsBloom": receipt.Bloom, + "type": hexutil.Uint(tx.Type()), } // Assign receipt status or post state. @@ -1881,6 +2194,44 @@ func (s *PublicTransactionPoolAPI) GetTransactionReceipt(ctx context.Context, ha return fields, nil } +// marshalReceipt marshals a transaction receipt into a JSON object. +func marshalReceipt(receipt *types.Receipt, blockHash common.Hash, blockNumber uint64, signer types.Signer, tx *types.Transaction, txIndex int) map[string]interface{} { + from, _ := types.Sender(signer, tx) + + fields := map[string]interface{}{ + "blockHash": blockHash, + "blockNumber": hexutil.Uint64(blockNumber), + "transactionHash": tx.Hash(), + "transactionIndex": hexutil.Uint64(txIndex), + "from": from, + "to": tx.To(), + "gasUsed": hexutil.Uint64(receipt.GasUsed), + "cumulativeGasUsed": hexutil.Uint64(receipt.CumulativeGasUsed), + "contractAddress": nil, + "logs": receipt.Logs, + "logsBloom": receipt.Bloom, + "type": hexutil.Uint(tx.Type()), + // uncomment below line after EIP-1559 + // TODO: "effectiveGasPrice": (*hexutil.Big)(receipt.EffectiveGasPrice), + } + + // Assign receipt status or post state. + if len(receipt.PostState) > 0 { + fields["root"] = hexutil.Bytes(receipt.PostState) + } else { + fields["status"] = hexutil.Uint(receipt.Status) + } + if receipt.Logs == nil { + fields["logs"] = []*types.Log{} + } + + // If the ContractAddress is 20 0x0 bytes, assume it is not a contract creation + if receipt.ContractAddress != (common.Address{}) { + fields["contractAddress"] = receipt.ContractAddress + } + return fields +} + // sign is a helper function that signs a transaction with the private key of the given address. func (s *PublicTransactionPoolAPI) sign(addr common.Address, tx *types.Transaction) (*types.Transaction, error) { // Look up the wallet containing the requested signer @@ -1910,14 +2261,14 @@ type SendTxArgs struct { // newer name and should be preferred by clients. Data *hexutil.Bytes `json:"data"` Input *hexutil.Bytes `json:"input"` + + // For non-legacy transactions + AccessList *types.AccessList `json:"accessList,omitempty"` + ChainID *hexutil.Big `json:"chainId,omitempty"` } -// setDefaults is a helper function that fills in default values for unspecified tx fields. +// setDefaults fills in default values for unspecified tx fields. func (args *SendTxArgs) setDefaults(ctx context.Context, b Backend) error { - if args.Gas == nil { - args.Gas = new(hexutil.Uint64) - *(*uint64)(args.Gas) = 90000 - } if args.GasPrice == nil { price, err := b.SuggestPrice(ctx) if err != nil { @@ -1950,45 +2301,98 @@ func (args *SendTxArgs) setDefaults(ctx context.Context, b Backend) error { return errors.New(`contract creation without any data provided`) } } + // Estimate the gas usage if necessary. + if args.Gas == nil { + // For backwards-compatibility reason, we try both input and data + // but input is preferred. + input := args.Input + if input == nil { + input = args.Data + } + callArgs := CallArgs{ + From: &args.From, // From shouldn't be nil + To: args.To, + GasPrice: args.GasPrice, + Value: args.Value, + Data: input, + AccessList: args.AccessList, + } + pendingBlockNr := rpc.BlockNumberOrHashWithNumber(rpc.PendingBlockNumber) + estimated, err := DoEstimateGas(ctx, b, callArgs, pendingBlockNr, nil, b.RPCGasCap()) + if err != nil { + return err + } + args.Gas = &estimated + log.Trace("Estimate gas usage automatically", "gas", args.Gas) + + } + if args.ChainID == nil { + id := (*hexutil.Big)(b.ChainConfig().ChainId) + args.ChainID = id + } return nil } +// toTransaction converts the arguments to a transaction. +// This assumes that setDefaults has been called. func (args *SendTxArgs) toTransaction() *types.Transaction { var input []byte - if args.Data != nil { - input = *args.Data - } else if args.Input != nil { + if args.Input != nil { input = *args.Input + } else if args.Data != nil { + input = *args.Data } - if args.To == nil { - return types.NewContractCreation(uint64(*args.Nonce), (*big.Int)(args.Value), uint64(*args.Gas), (*big.Int)(args.GasPrice), input) + var data types.TxData + if args.AccessList == nil { + data = &types.LegacyTx{ + To: args.To, + Nonce: uint64(*args.Nonce), + Gas: uint64(*args.Gas), + GasPrice: (*big.Int)(args.GasPrice), + Value: (*big.Int)(args.Value), + Data: input, + } + } else { + data = &types.AccessListTx{ + To: args.To, + ChainID: (*big.Int)(args.ChainID), + Nonce: uint64(*args.Nonce), + Gas: uint64(*args.Gas), + GasPrice: (*big.Int)(args.GasPrice), + Value: (*big.Int)(args.Value), + Data: input, + AccessList: *args.AccessList, + } } - return types.NewTransaction(uint64(*args.Nonce), *args.To, (*big.Int)(args.Value), uint64(*args.Gas), (*big.Int)(args.GasPrice), input) + return types.NewTx(data) } -// submitTransaction is a helper function that submits tx to txPool and logs a message. -func submitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (common.Hash, error) { +// SubmitTransaction is a helper function that submits tx to txPool and logs a message. +func SubmitTransaction(ctx context.Context, b Backend, tx *types.Transaction) (common.Hash, error) { if tx.To() != nil && tx.IsSpecialTransaction() { return common.Hash{}, errors.New("Dont allow transaction sent to BlockSigners & RandomizeSMC smart contract via API") } if err := b.SendTx(ctx, tx); err != nil { return common.Hash{}, err } + + // Print a log with full tx details for manual investigations and interventions + signer := types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number()) + from, err := types.Sender(signer, tx) + if err != nil { + return common.Hash{}, err + } + if tx.To() == nil { - signer := types.MakeSigner(b.ChainConfig(), b.CurrentBlock().Number()) - from, err := types.Sender(signer, tx) - if err != nil { - return common.Hash{}, err - } addr := crypto.CreateAddress(from, tx.Nonce()) - log.Trace("Submitted contract creation", "fullhash", tx.Hash().Hex(), "contract", addr.Hex()) + log.Info("Submitted contract creation", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "contract", addr.Hex(), "value", tx.Value()) } else { - log.Trace("Submitted transaction", "fullhash", tx.Hash().Hex(), "recipient", tx.To()) + log.Info("Submitted transaction", "hash", tx.Hash().Hex(), "from", from, "nonce", tx.Nonce(), "recipient", tx.To(), "value", tx.Value()) } return tx.Hash(), nil } -// submitTransaction is a helper function that submits tx to txPool and logs a message. +// SubmitTransaction is a helper function that submits tx to txPool and logs a message. func submitOrderTransaction(ctx context.Context, b Backend, tx *types.OrderTransaction) (common.Hash, error) { if err := b.SendOrderTx(ctx, tx); err != nil { @@ -2040,17 +2444,33 @@ func (s *PublicTransactionPoolAPI) SendTransaction(ctx context.Context, args Sen if err != nil { return common.Hash{}, err } - return submitTransaction(ctx, s.b, signed) + return SubmitTransaction(ctx, s.b, signed) +} + +// FillTransaction fills the defaults (nonce, gas, gasPrice) on a given unsigned transaction, +// and returns it to the caller for further processing (signing + broadcast) +func (s *PublicTransactionPoolAPI) FillTransaction(ctx context.Context, args SendTxArgs) (*SignTransactionResult, error) { + // Set some sanity defaults and terminate on failure + if err := args.setDefaults(ctx, s.b); err != nil { + return nil, err + } + // Assemble the transaction and obtain rlp + tx := args.toTransaction() + data, err := tx.MarshalBinary() + if err != nil { + return nil, err + } + return &SignTransactionResult{data, tx}, nil } // SendRawTransaction will add the signed transaction to the transaction pool. // The sender is responsible for signing the transaction and using the correct nonce. -func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, encodedTx hexutil.Bytes) (common.Hash, error) { +func (s *PublicTransactionPoolAPI) SendRawTransaction(ctx context.Context, input hexutil.Bytes) (common.Hash, error) { tx := new(types.Transaction) - if err := rlp.DecodeBytes(encodedTx, tx); err != nil { + if err := tx.UnmarshalBinary(input); err != nil { return common.Hash{}, err } - return submitTransaction(ctx, s.b, tx) + return SubmitTransaction(ctx, s.b, tx) } // SendOrderRawTransaction will add the signed transaction to the transaction pool. @@ -2943,29 +3363,30 @@ func (s *PublicTransactionPoolAPI) SignTransaction(ctx context.Context, args Sen if err != nil { return nil, err } - data, err := rlp.EncodeToBytes(tx) + data, err := tx.MarshalBinary() if err != nil { return nil, err } return &SignTransactionResult{data, tx}, nil } -// PendingTransactions returns the transactions that are in the transaction pool and have a from address that is one of -// the accounts this node manages. +// PendingTransactions returns the transactions that are in the transaction pool +// and have a from address that is one of the accounts this node manages. func (s *PublicTransactionPoolAPI) PendingTransactions() ([]*RPCTransaction, error) { pending, err := s.b.GetPoolTransactions() if err != nil { return nil, err } - + accounts := make(map[common.Address]struct{}) + for _, wallet := range s.b.AccountManager().Wallets() { + for _, account := range wallet.Accounts() { + accounts[account.Address] = struct{}{} + } + } transactions := make([]*RPCTransaction, 0, len(pending)) for _, tx := range pending { - var signer types.Signer = types.HomesteadSigner{} - if tx.Protected() { - signer = types.NewEIP155Signer(tx.ChainId()) - } - from, _ := types.Sender(signer, tx) - if _, err := s.b.AccountManager().Find(accounts.Account{Address: from}); err == nil { + from, _ := types.Sender(s.signer, tx) + if _, exists := accounts[from]; exists { transactions = append(transactions, newRPCPendingTransaction(tx)) } } @@ -2982,19 +3403,16 @@ func (s *PublicTransactionPoolAPI) Resend(ctx context.Context, sendArgs SendTxAr return common.Hash{}, err } matchTx := sendArgs.toTransaction() + + // Iterate the pending list for replacement pending, err := s.b.GetPoolTransactions() if err != nil { return common.Hash{}, err } - for _, p := range pending { - var signer types.Signer = types.HomesteadSigner{} - if p.Protected() { - signer = types.NewEIP155Signer(p.ChainId()) - } - wantSigHash := signer.Hash(matchTx) - - if pFrom, err := types.Sender(signer, p); err == nil && pFrom == sendArgs.From && signer.Hash(p) == wantSigHash { + wantSigHash := s.signer.Hash(matchTx) + pFrom, err := types.Sender(s.signer, p) + if err == nil && pFrom == sendArgs.From && s.signer.Hash(p) == wantSigHash { // Match. Re-sign and send the transaction. if gasPrice != nil && (*big.Int)(gasPrice).Sign() != 0 { sendArgs.GasPrice = gasPrice @@ -3012,8 +3430,7 @@ func (s *PublicTransactionPoolAPI) Resend(ctx context.Context, sendArgs SendTxAr return signedTx.Hash(), nil } } - - return common.Hash{}, fmt.Errorf("Transaction %#x not found", matchTx.Hash()) + return common.Hash{}, fmt.Errorf("transaction %#x not found", matchTx.Hash()) } // PublicDebugAPI is the collection of Ethereum APIs exposed over the public @@ -3236,3 +3653,18 @@ func (s *PublicBlockChainAPI) GetStakerROIMasternode(masternode common.Address) return 100.0 / float64(totalCap.Div(totalCap, voterRewardAYear).Uint64()) } + +// checkTxFee is an internal function used to check whether the fee of +// the given transaction is _reasonable_(under the cap). +func checkTxFee(gasPrice *big.Int, gas uint64, cap float64) error { + // Short circuit if there is no cap for transaction fee at all. + if cap == 0 { + return nil + } + feeEth := new(big.Float).Quo(new(big.Float).SetInt(new(big.Int).Mul(gasPrice, new(big.Int).SetUint64(gas))), new(big.Float).SetInt(big.NewInt(params.Ether))) + feeFloat, _ := feeEth.Float64() + if feeFloat > cap { + return fmt.Errorf("tx fee (%.2f ether) exceeds the configured cap (%.2f ether)", feeFloat, cap) + } + return nil +} diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 2304490787cc..00a9aea8acf9 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -51,6 +51,7 @@ type Backend interface { ChainDb() ethdb.Database EventMux() *event.TypeMux AccountManager() *accounts.Manager + RPCGasCap() uint64 // global gas cap for eth_call over rpc: DoS protection XDCxService() *XDCx.XDCX LendingService() *XDCxlending.Lending @@ -67,7 +68,7 @@ type Backend interface { GetBlock(ctx context.Context, blockHash common.Hash) (*types.Block, error) GetReceipts(ctx context.Context, blockHash common.Hash) (types.Receipts, error) GetTd(blockHash common.Hash) *big.Int - GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) + GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription SubscribeChainSideEvent(ch chan<- core.ChainSideEvent) event.Subscription @@ -79,7 +80,7 @@ type Backend interface { GetPoolNonce(ctx context.Context, addr common.Address) (uint64, error) Stats() (pending int, queued int) TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) - SubscribeTxPreEvent(chan<- core.TxPreEvent) event.Subscription + SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription // Order Pool Transaction SendOrderTx(ctx context.Context, signedTx *types.OrderTransaction) error diff --git a/internal/ethapi/trie_proof.go b/internal/ethapi/trie_proof.go new file mode 100644 index 000000000000..4a7747ab8f5f --- /dev/null +++ b/internal/ethapi/trie_proof.go @@ -0,0 +1,39 @@ +package ethapi + +import ( + "bytes" + + "github.com/XinFinOrg/XDPoSChain/common/hexutil" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/rlp" + "github.com/XinFinOrg/XDPoSChain/trie" +) + +// proofPairList implements ethdb.KeyValueWriter and collects the proofs as +// hex-strings of key and value for delivery to rpc-caller. +type proofPairList struct { + keys []string + values []string +} + +func (n *proofPairList) Put(key []byte, value []byte) error { + n.keys = append(n.keys, hexutil.Encode(key)) + n.values = append(n.values, hexutil.Encode(value)) + return nil +} + +func (n *proofPairList) Delete(key []byte) error { + panic("not supported") +} + +// modified from core/types/derive_sha.go +func deriveTrie(list types.DerivableList) *trie.Trie { + keybuf := new(bytes.Buffer) + trie := new(trie.Trie) + for i := 0; i < list.Len(); i++ { + keybuf.Reset() + rlp.Encode(keybuf, uint(i)) + trie.Update(keybuf.Bytes(), list.GetRlp(i)) + } + return trie +} diff --git a/internal/ethapi/trie_proof_test.go b/internal/ethapi/trie_proof_test.go new file mode 100644 index 000000000000..10dd9988dbc5 --- /dev/null +++ b/internal/ethapi/trie_proof_test.go @@ -0,0 +1,106 @@ +package ethapi + +import ( + "bytes" + "fmt" + "math/big" + "reflect" + "testing" + + "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/common/hexutil" + "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/rlp" + "github.com/XinFinOrg/XDPoSChain/trie" +) + +// implement interface only for testing verifyProof +func (n *proofPairList) Has(key []byte) (bool, error) { + key_hex := hexutil.Encode(key) + for _, k := range n.keys { + if k == key_hex { + return true, nil + } + } + return false, nil +} + +func (n *proofPairList) Get(key []byte) ([]byte, error) { + key_hex := hexutil.Encode(key) + for i, k := range n.keys { + if k == key_hex { + b, err := hexutil.Decode(n.values[i]) + if err != nil { + return nil, err + } + return b, nil + } + } + return nil, fmt.Errorf("key not found") +} + +func TestTransactionProof(t *testing.T) { + to1 := common.HexToAddress("0x01") + to2 := common.HexToAddress("0x02") + t1 := types.NewTransaction(1, to1, big.NewInt(1), 1, big.NewInt(1), []byte{}) + t2 := types.NewTransaction(2, to2, big.NewInt(2), 2, big.NewInt(2), []byte{}) + t3 := types.NewTransaction(3, to2, big.NewInt(3), 3, big.NewInt(3), []byte{}) + t4 := types.NewTransaction(4, to1, big.NewInt(4), 4, big.NewInt(4), []byte{}) + transactions := types.Transactions([]*types.Transaction{t1, t2, t3, t4}) + tr := deriveTrie(transactions) + // for verifying the proof + root := types.DeriveSha(transactions) + for i := 0; i < transactions.Len(); i++ { + var proof proofPairList + keybuf := new(bytes.Buffer) + rlp.Encode(keybuf, uint(i)) + if err := tr.Prove(keybuf.Bytes(), 0, &proof); err != nil { + t.Fatal("Prove err:", err) + } + // verify the proof + value, err := trie.VerifyProof(root, keybuf.Bytes(), &proof) + if err != nil { + t.Fatal("verify proof error") + } + encodedTransaction, err := rlp.EncodeToBytes(transactions[i]) + if err != nil { + t.Fatal("encode receipt error") + } + if !reflect.DeepEqual(encodedTransaction, value) { + t.Fatal("verify does not return the receipt we want") + } + } +} + +func TestReceiptProof(t *testing.T) { + root1 := []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + root2 := []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + r1 := types.NewReceipt(root1, false, 1) + r2 := types.NewReceipt(root2, true, 2) + r3 := types.NewReceipt(root2, false, 3) + r4 := types.NewReceipt(root1, true, 4) + receipts := types.Receipts([]*types.Receipt{r1, r2, r3, r4}) + tr := deriveTrie(receipts) + // for verifying the proof + root := types.DeriveSha(receipts) + for i := 0; i < receipts.Len(); i++ { + var proof proofPairList + keybuf := new(bytes.Buffer) + rlp.Encode(keybuf, uint(i)) + if err := tr.Prove(keybuf.Bytes(), 0, &proof); err != nil { + t.Fatal("Prove err:", err) + } + // verify the proof + value, err := trie.VerifyProof(root, keybuf.Bytes(), &proof) + if err != nil { + t.Fatal("verify proof error") + } + encodedReceipt, err := rlp.EncodeToBytes(receipts[i]) + if err != nil { + t.Fatal("encode receipt error") + } + if !reflect.DeepEqual(encodedReceipt, value) { + t.Fatal("verify does not return the receipt we want") + } + } +} diff --git a/internal/guide/guide_test.go b/internal/guide/guide_test.go index 9077ae52a4d8..64b9e162fae2 100644 --- a/internal/guide/guide_test.go +++ b/internal/guide/guide_test.go @@ -30,6 +30,7 @@ import ( "time" "github.com/XinFinOrg/XDPoSChain/accounts/keystore" + "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" ) @@ -74,7 +75,8 @@ func TestAccountManagement(t *testing.T) { if err != nil { t.Fatalf("Failed to create signer account: %v", err) } - tx, chain := new(types.Transaction), big.NewInt(1) + tx := types.NewTransaction(0, common.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) + chain := big.NewInt(1) // Sign a transaction with a single authorization if _, err := ks.SignTxWithPassphrase(signer, "Signer password", tx, chain); err != nil { diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go index c07cf978c8ad..dbb3f41df4ad 100644 --- a/internal/web3ext/web3ext.go +++ b/internal/web3ext/web3ext.go @@ -420,6 +420,12 @@ web3._extend({ params: 2, inputFormatter: [null, null] }), + new web3._extend.Method({ + name: 'traceCall', + call: 'debug_traceCall', + params: 3, + inputFormatter: [null, null, null] + }), new web3._extend.Method({ name: 'preimage', call: 'debug_preimage', @@ -496,6 +502,11 @@ web3._extend({ call: 'eth_getRewardByHash', params: 1 }), + new web3._extend.Method({ + name: 'getTransactionAndReceiptProof', + call: 'eth_getTransactionAndReceiptProof', + params: 1 + }), new web3._extend.Method({ name: 'getRawTransactionFromBlock', call: function(args) { @@ -510,6 +521,17 @@ web3._extend({ params: 2, inputFormatter: [web3._extend.formatters.inputAddressFormatter, web3._extend.formatters.inputBlockNumberFormatter] }), + new web3._extend.Method({ + name: 'createAccessList', + call: 'eth_createAccessList', + params: 2, + inputFormatter: [null, web3._extend.formatters.inputBlockNumberFormatter], + }), + new web3._extend.Method({ + name: 'getBlockReceipts', + call: 'eth_getBlockReceipts', + params: 1, + }), ], properties: [ new web3._extend.Property({ diff --git a/les/api_backend.go b/les/api_backend.go index 152db5e488c0..abdd952a6ebf 100644 --- a/les/api_backend.go +++ b/les/api_backend.go @@ -171,10 +171,13 @@ func (b *LesApiBackend) GetTd(blockHash common.Hash) *big.Int { return b.eth.blockchain.GetTdByHash(blockHash) } -func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmCfg vm.Config) (*vm.EVM, func() error, error) { +func (b *LesApiBackend) GetEVM(ctx context.Context, msg core.Message, state *state.StateDB, XDCxState *tradingstate.TradingStateDB, header *types.Header, vmConfig *vm.Config) (*vm.EVM, func() error, error) { + if vmConfig == nil { + vmConfig = new(vm.Config) + } state.SetBalance(msg.From(), math.MaxBig256) context := core.NewEVMContext(msg, header, b.eth.blockchain, nil) - return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, vmCfg), state.Error, nil + return vm.NewEVM(context, state, XDCxState, b.eth.chainConfig, *vmConfig), state.Error, nil } func (b *LesApiBackend) SendTx(ctx context.Context, signedTx *types.Transaction) error { @@ -218,8 +221,8 @@ func (b *LesApiBackend) OrderStats() (pending int, queued int) { return 0, 0 } -func (b *LesApiBackend) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { - return b.eth.txPool.SubscribeTxPreEvent(ch) +func (b *LesApiBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { + return b.eth.txPool.SubscribeNewTxsEvent(ch) } func (b *LesApiBackend) SubscribeChainEvent(ch chan<- core.ChainEvent) event.Subscription { @@ -266,6 +269,10 @@ func (b *LesApiBackend) AccountManager() *accounts.Manager { return b.eth.accountManager } +func (b *LesApiBackend) RPCGasCap() uint64 { + return b.eth.config.RPCGasCap +} + func (b *LesApiBackend) BloomStatus() (uint64, uint64) { if b.eth.bloomIndexer == nil { return 0, 0 diff --git a/les/handler.go b/les/handler.go index 812c769de54c..6a4ba688ea3b 100644 --- a/les/handler.go +++ b/les/handler.go @@ -91,6 +91,7 @@ type BlockChain interface { type txPool interface { AddRemotes(txs []*types.Transaction) []error + AddRemotesSync(txs []*types.Transaction) []error Status(hashes []common.Hash) []core.TxStatus } diff --git a/les/odr_test.go b/les/odr_test.go index bbaa7c0dfafe..1234bd28536b 100644 --- a/les/odr_test.go +++ b/les/odr_test.go @@ -19,7 +19,6 @@ package les import ( "bytes" "context" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "testing" "time" @@ -27,6 +26,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/core" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" "github.com/XinFinOrg/XDPoSChain/core/state" "github.com/XinFinOrg/XDPoSChain/core/types" "github.com/XinFinOrg/XDPoSChain/core/vm" @@ -133,7 +133,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee, header.Number)} + msg := callmsg{types.NewMessage(from.Address(), &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, nil, false, balanceTokenFee, header.Number)} context := core.NewEVMContext(msg, header, bc, nil) vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{}) @@ -153,7 +153,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, config *params.Chai if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, false, balanceTokenFee, header.Number)} + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 100000, new(big.Int), data, nil, false, balanceTokenFee, header.Number)} context := core.NewEVMContext(msg, header, lc, nil) vmenv := vm.NewEVM(context, statedb, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) diff --git a/light/odr_test.go b/light/odr_test.go index 4905260adef7..83e1c807caf6 100644 --- a/light/odr_test.go +++ b/light/odr_test.go @@ -20,12 +20,13 @@ import ( "bytes" "context" "errors" - "github.com/XinFinOrg/XDPoSChain/consensus" - "github.com/XinFinOrg/XDPoSChain/core/rawdb" "math/big" "testing" "time" + "github.com/XinFinOrg/XDPoSChain/consensus" + "github.com/XinFinOrg/XDPoSChain/core/rawdb" + "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/common/math" "github.com/XinFinOrg/XDPoSChain/consensus/ethash" @@ -183,7 +184,7 @@ func odrContractCall(ctx context.Context, db ethdb.Database, bc *core.BlockChain if value, ok := feeCapacity[testContractAddr]; ok { balanceTokenFee = value } - msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, false, balanceTokenFee, header.Number)} + msg := callmsg{types.NewMessage(testBankAddress, &testContractAddr, 0, new(big.Int), 1000000, new(big.Int), data, nil, false, balanceTokenFee, header.Number)} context := core.NewEVMContext(msg, header, chain, nil) vmenv := vm.NewEVM(context, st, nil, config, vm.Config{}) gp := new(core.GasPool).AddGas(math.MaxUint64) diff --git a/light/txpool.go b/light/txpool.go index 27b749bf4d8a..292e91b92dab 100644 --- a/light/txpool.go +++ b/light/txpool.go @@ -19,6 +19,7 @@ package light import ( "context" "fmt" + "math/big" "sync" "time" @@ -30,7 +31,6 @@ import ( "github.com/XinFinOrg/XDPoSChain/event" "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/params" - "github.com/XinFinOrg/XDPoSChain/rlp" ) const ( @@ -67,6 +67,7 @@ type TxPool struct { clearIdx uint64 // earliest block nr that can contain mined tx info homestead bool + eip2718 bool // Fork indicator whether we are in the eip2718 stage. } // TxRelayBackend provides an interface to the mechanism that forwards transacions @@ -74,10 +75,13 @@ type TxPool struct { // // Send instructs backend to forward new transactions // NewHead notifies backend about a new head after processed by the tx pool, -// including mined and rolled back transactions since the last event +// +// including mined and rolled back transactions since the last event +// // Discard notifies backend about transactions that should be discarded either -// because they have been replaced by a re-send or because they have been mined -// long ago and no rollback is expected +// +// because they have been replaced by a re-send or because they have been mined +// long ago and no rollback is expected type TxRelayBackend interface { Send(txs types.Transactions) NewHead(head common.Hash, mined []common.Hash, rollback []common.Hash) @@ -88,7 +92,7 @@ type TxRelayBackend interface { func NewTxPool(config *params.ChainConfig, chain *LightChain, relay TxRelayBackend) *TxPool { pool := &TxPool{ config: config, - signer: types.NewEIP155Signer(config.ChainId), + signer: types.LatestSigner(config), nonce: make(map[common.Address]uint64), pending: make(map[common.Hash]*types.Transaction), mined: make(map[common.Hash][]*types.Transaction), @@ -307,8 +311,11 @@ func (pool *TxPool) setNewHead(head *types.Header) { txc, _ := pool.reorgOnNewHead(ctx, head) m, r := txc.getLists() pool.relay.NewHead(pool.head, m, r) + + // Update fork indicator by next pending block number + next := new(big.Int).Add(head.Number, big.NewInt(1)) pool.homestead = pool.config.IsHomestead(head.Number) - pool.signer = types.MakeSigner(pool.config, head.Number) + pool.eip2718 = pool.config.IsEIP1559(next) } // Stop stops the light transaction pool @@ -321,9 +328,9 @@ func (pool *TxPool) Stop() { log.Info("Transaction pool stopped") } -// SubscribeTxPreEvent registers a subscription of core.TxPreEvent and +// SubscribeNewTxsEvent registers a subscription of core.NewTxsEvent and // starts sending event to the given channel. -func (pool *TxPool) SubscribeTxPreEvent(ch chan<- core.TxPreEvent) event.Subscription { +func (pool *TxPool) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { return pool.scope.Track(pool.txFeed.Subscribe(ch)) } @@ -400,7 +407,7 @@ func (pool *TxPool) validateTx(ctx context.Context, tx *types.Transaction) error } // Should supply enough intrinsic gas - gas, err := core.IntrinsicGas(tx.Data(), tx.To() == nil, pool.homestead) + gas, err := core.IntrinsicGas(tx.Data(), tx.AccessList(), tx.To() == nil, pool.homestead) if err != nil { return err } @@ -436,7 +443,7 @@ func (self *TxPool) add(ctx context.Context, tx *types.Transaction) error { // Notify the subscribers. This event is posted in a goroutine // because it's possible that somewhere during the post "Remove transaction" // gets called which will then wait for the global tx pool lock and deadlock. - go self.txFeed.Send(core.TxPreEvent{Tx: tx}) + go self.txFeed.Send(core.NewTxsEvent{Txs: types.Transactions{tx}}) } // Print a log message if low enough level is set @@ -449,8 +456,7 @@ func (self *TxPool) add(ctx context.Context, tx *types.Transaction) error { func (self *TxPool) Add(ctx context.Context, tx *types.Transaction) error { self.mu.Lock() defer self.mu.Unlock() - - data, err := rlp.EncodeToBytes(tx) + data, err := tx.MarshalBinary() if err != nil { return err } diff --git a/metrics/gauge.go b/metrics/gauge.go index 0fbfdb86033b..b6b2758b0d13 100644 --- a/metrics/gauge.go +++ b/metrics/gauge.go @@ -6,6 +6,8 @@ import "sync/atomic" type Gauge interface { Snapshot() Gauge Update(int64) + Dec(int64) + Inc(int64) Value() int64 } @@ -65,6 +67,16 @@ func (GaugeSnapshot) Update(int64) { panic("Update called on a GaugeSnapshot") } +// Dec panics. +func (GaugeSnapshot) Dec(int64) { + panic("Dec called on a GaugeSnapshot") +} + +// Inc panics. +func (GaugeSnapshot) Inc(int64) { + panic("Inc called on a GaugeSnapshot") +} + // Value returns the value at the time the snapshot was taken. func (g GaugeSnapshot) Value() int64 { return int64(g) } @@ -77,6 +89,12 @@ func (NilGauge) Snapshot() Gauge { return NilGauge{} } // Update is a no-op. func (NilGauge) Update(v int64) {} +// Dec is a no-op. +func (NilGauge) Dec(i int64) {} + +// Inc is a no-op. +func (NilGauge) Inc(i int64) {} + // Value is a no-op. func (NilGauge) Value() int64 { return 0 } @@ -101,6 +119,16 @@ func (g *StandardGauge) Value() int64 { return atomic.LoadInt64(&g.value) } +// Dec decrements the gauge's current value by the given amount. +func (g *StandardGauge) Dec(i int64) { + atomic.AddInt64(&g.value, -i) +} + +// Inc increments the gauge's current value by the given amount. +func (g *StandardGauge) Inc(i int64) { + atomic.AddInt64(&g.value, i) +} + // FunctionalGauge returns value from given function type FunctionalGauge struct { value func() int64 @@ -118,3 +146,13 @@ func (g FunctionalGauge) Snapshot() Gauge { return GaugeSnapshot(g.Value()) } func (FunctionalGauge) Update(int64) { panic("Update called on a FunctionalGauge") } + +// Dec panics. +func (FunctionalGauge) Dec(int64) { + panic("Dec called on a FunctionalGauge") +} + +// Inc panics. +func (FunctionalGauge) Inc(int64) { + panic("Inc called on a FunctionalGauge") +} diff --git a/miner/worker.go b/miner/worker.go index 6f7ee10fce79..3de0328bf8d8 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -19,18 +19,15 @@ package miner import ( "bytes" "encoding/binary" - "fmt" - - "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" - "github.com/XinFinOrg/XDPoSChain/accounts" - + "errors" "math/big" "sync" "sync/atomic" "time" "github.com/XinFinOrg/XDPoSChain/XDCx/tradingstate" - + "github.com/XinFinOrg/XDPoSChain/XDCxlending/lendingstate" + "github.com/XinFinOrg/XDPoSChain/accounts" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/consensus" "github.com/XinFinOrg/XDPoSChain/consensus/XDPoS" @@ -51,7 +48,7 @@ const ( resultQueueSize = 10 miningLogAtDepth = 5 - // txChanSize is the size of channel listening to TxPreEvent. + // txChanSize is the size of channel listening to NewTxsEvent. // The number is referenced from the size of tx pool. txChanSize = 4096 // chainHeadChanSize is the size of channel listening to ChainHeadEvent. @@ -109,8 +106,8 @@ type worker struct { // update loop mux *event.TypeMux - txCh chan core.TxPreEvent - txSub event.Subscription + txsCh chan core.NewTxsEvent + txsSub event.Subscription chainHeadCh chan core.ChainHeadEvent chainHeadSub event.Subscription chainSideCh chan core.ChainSideEvent @@ -149,7 +146,7 @@ func newWorker(config *params.ChainConfig, engine consensus.Engine, coinbase com engine: engine, eth: eth, mux: mux, - txCh: make(chan core.TxPreEvent, txChanSize), + txsCh: make(chan core.NewTxsEvent, txChanSize), chainHeadCh: make(chan core.ChainHeadEvent, chainHeadChanSize), chainSideCh: make(chan core.ChainSideEvent, chainSideChanSize), chainDb: eth.ChainDb(), @@ -163,8 +160,8 @@ func newWorker(config *params.ChainConfig, engine consensus.Engine, coinbase com announceTxs: announceTxs, } if worker.announceTxs { - // Subscribe TxPreEvent for tx pool - worker.txSub = eth.TxPool().SubscribeTxPreEvent(worker.txCh) + // Subscribe NewTxsEvent for tx pool + worker.txsSub = eth.TxPool().SubscribeNewTxsEvent(worker.txsCh) } // Subscribe events for blockchain worker.chainHeadSub = eth.BlockChain().SubscribeChainHeadEvent(worker.chainHeadCh) @@ -261,7 +258,7 @@ func (self *worker) unregister(agent Agent) { func (self *worker) update() { if self.announceTxs { - defer self.txSub.Unsubscribe() + defer self.txsSub.Unsubscribe() } defer self.chainHeadSub.Unsubscribe() defer self.chainSideSub.Unsubscribe() @@ -307,20 +304,22 @@ func (self *worker) update() { timeout.Reset(time.Duration(minePeriod) * time.Second) // Handle ChainSideEvent - case ev := <-self.chainSideCh: - if self.config.XDPoS == nil { - self.uncleMu.Lock() - self.possibleUncles[ev.Block.Hash()] = ev.Block - self.uncleMu.Unlock() - } - - // Handle TxPreEvent - case ev := <-self.txCh: - // Apply transaction to the pending state if we're not mining + case <-self.chainSideCh: + + // Handle NewTxsEvent + case ev := <-self.txsCh: + // Apply transactions to the pending state if we're not mining. + // + // Note all transactions received may not be continuous with transactions + // already included in the current mining block. These transactions will + // be automatically eliminated. if atomic.LoadInt32(&self.mining) == 0 { self.currentMu.Lock() - acc, _ := types.Sender(self.current.signer, ev.Tx) - txs := map[common.Address]types.Transactions{acc: {ev.Tx}} + txs := make(map[common.Address]types.Transactions) + for _, tx := range ev.Txs { + acc, _ := types.Sender(self.current.signer, tx) + txs[acc] = append(txs[acc], tx) + } feeCapacity := state.GetTRC21FeeCapacityFromState(self.current.state) txset, specialTxs := types.NewTransactionsByPriceAndNonce(self.current.signer, txs, nil, feeCapacity) self.current.commitTransactions(self.mux, feeCapacity, txset, specialTxs, self.chain, self.coinbase) @@ -357,18 +356,26 @@ func (self *worker) wait() { } work := result.Work + // Different block could share same sealhash, deep copy here to prevent write-write conflict. + hash := block.Hash() + receipts := make([]*types.Receipt, len(work.receipts)) + for i, receipt := range work.receipts { + // add block location fields + receipt.BlockHash = hash + receipt.BlockNumber = block.Number() + receipt.TransactionIndex = uint(i) + + receipts[i] = new(types.Receipt) + *receipts[i] = *receipt + } // Update the block hash in all logs since it is now available and not when the // receipt/log of individual transactions were created. - for _, r := range work.receipts { - for _, l := range r.Logs { - l.BlockHash = block.Hash() - } - } for _, log := range work.state.Logs() { - log.BlockHash = block.Hash() + log.BlockHash = hash } + // Commit block and state to database. self.currentMu.Lock() - stat, err := self.chain.WriteBlockWithState(block, work.receipts, work.state, work.tradingState, work.lendingState) + stat, err := self.chain.WriteBlockWithState(block, receipts, work.state, work.tradingState, work.lendingState) self.currentMu.Unlock() if err != nil { log.Error("Failed writing block to chain", "err", err) @@ -457,10 +464,13 @@ func (self *worker) push(work *Work) { // makeCurrent creates a new environment for the current cycle. func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error { + // Retrieve the parent state to execute on top and start a prefetcher for + // the miner to speed block sealing up a bit state, err := self.chain.StateAt(parent.Root()) if err != nil { return err } + author, _ := self.chain.Engine().Author(parent.Header()) var XDCxState *tradingstate.TradingStateDB var lendingState *lendingstate.LendingStateDB @@ -481,7 +491,7 @@ func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error work := &Work{ config: self.config, - signer: types.NewEIP155Signer(self.config.ChainId), + signer: types.MakeSigner(self.config, header.Number), state: state, parentState: state.Copy(), tradingState: XDCxState, @@ -493,17 +503,6 @@ func (self *worker) makeCurrent(parent *types.Block, header *types.Header) error createdAt: time.Now(), } - if self.config.XDPoS == nil { - // when 08 is processed ancestors contain 07 (quick block) - for _, ancestor := range self.chain.GetBlocksFromHash(parent.Hash(), 7) { - for _, uncle := range ancestor.Uncles() { - work.family.Add(uncle.Hash()) - } - work.family.Add(ancestor.Hash()) - work.ancestors.Add(ancestor.Hash()) - } - } - // Keep track of transactions which return errors so they can be removed work.tcount = 0 self.current = work @@ -785,33 +784,15 @@ func (self *worker) commitNewWork() { work.commitTransactions(self.mux, feeCapacity, txs, specialTxs, self.chain, self.coinbase) // compute uncles for the new block. var ( - uncles []*types.Header - badUncles []common.Hash + uncles []*types.Header ) - if self.config.XDPoS == nil { - for hash, uncle := range self.possibleUncles { - if len(uncles) == 2 { - break - } - if err := self.commitUncle(work, uncle.Header()); err != nil { - log.Trace("Bad uncle found and will be removed", "hash", hash) - log.Trace(fmt.Sprint(uncle)) - badUncles = append(badUncles, hash) - } else { - log.Debug("Committing new uncle to block", "hash", hash) - uncles = append(uncles, uncle.Header()) - } - } - for _, hash := range badUncles { - delete(self.possibleUncles, hash) - } - } // Create the new block to seal with the consensus engine if work.Block, err = self.engine.Finalize(self.chain, header, work.state, work.parentState, work.txs, uncles, work.receipts); err != nil { log.Error("Failed to finalize block for sealing", "err", err) return } + if atomic.LoadInt32(&self.mining) == 1 { log.Info("Committing new block", "number", work.Block.Number(), "txs", work.tcount, "special-txs", len(specialTxs), "uncles", len(uncles), "elapsed", common.PrettyDuration(time.Since(tstart))) self.unconfirmed.Shift(work.Block.NumberU64() - 1) @@ -820,21 +801,6 @@ func (self *worker) commitNewWork() { self.push(work) } -func (self *worker) commitUncle(work *Work, uncle *types.Header) error { - hash := uncle.Hash() - if work.uncles.Contains(hash) { - return fmt.Errorf("uncle not unique") - } - if !work.ancestors.Contains(uncle.ParentHash) { - return fmt.Errorf("uncle's parent unknown (%x)", uncle.ParentHash[0:4]) - } - if work.family.Contains(hash) { - return fmt.Errorf("uncle already in family (%x)", hash) - } - work.uncles.Add(uncle.Hash()) - return nil -} - func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Address]*big.Int, txs *types.TransactionsByPriceAndNonce, specialTxs types.Transactions, bc *core.BlockChain, coinbase common.Address) { gp := new(core.GasPool).AddGas(env.header.GasLimit) balanceUpdated := map[common.Address]*big.Int{} @@ -1016,28 +982,33 @@ func (env *Work) commitTransactions(mux *event.TypeMux, balanceFee map[common.Ad continue } err, logs, tokenFeeUsed, gas := env.commitTransaction(balanceFee, tx, bc, coinbase, gp) - switch err { - case core.ErrGasLimitReached: + switch { + case errors.Is(err, core.ErrGasLimitReached): // Pop the current out-of-gas transaction without shifting in the next from the account log.Trace("Gas limit exceeded for current block", "sender", from) txs.Pop() - case core.ErrNonceTooLow: + case errors.Is(err, core.ErrNonceTooLow): // New head notification data race between the transaction pool and miner, shift log.Trace("Skipping transaction with low nonce", "sender", from, "nonce", tx.Nonce()) txs.Shift() - case core.ErrNonceTooHigh: + case errors.Is(err, core.ErrNonceTooHigh): // Reorg notification data race between the transaction pool and miner, skip account = log.Trace("Skipping account with high nonce", "sender", from, "nonce", tx.Nonce()) txs.Pop() - case nil: + case errors.Is(err, nil): // Everything ok, collect the logs and shift in the next transaction from the same account coalescedLogs = append(coalescedLogs, logs...) env.tcount++ txs.Shift() + case errors.Is(err, core.ErrTxTypeNotSupported): + // Pop the unsupported transaction without shifting in the next from the account + log.Trace("Skipping unsupported transaction type", "sender", from, "type", tx.Type()) + txs.Pop() + default: // Strange error, discard the transaction and get the next in line (note, the // nonce-too-high clause will prevent us from executing in vain). diff --git a/p2p/rlpx.go b/p2p/rlpx.go index 66efd4199713..ea26b2f2ec8c 100644 --- a/p2p/rlpx.go +++ b/p2p/rlpx.go @@ -528,9 +528,9 @@ func importPublicKey(pubKey []byte) (*ecies.PublicKey, error) { return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey)) } // TODO: fewer pointless conversions - pub := crypto.ToECDSAPub(pubKey65) - if pub.X == nil { - return nil, fmt.Errorf("invalid public key") + pub, err := crypto.UnmarshalPubkey(pubKey65) + if err != nil { + return nil, err } return ecies.ImportECDSAPublic(pub), nil } diff --git a/params/config.go b/params/config.go index 933d5f2980a9..ff2de6a7ebda 100644 --- a/params/config.go +++ b/params/config.go @@ -363,6 +363,14 @@ type ChainConfig struct { ByzantiumBlock *big.Int `json:"byzantiumBlock,omitempty"` // Byzantium switch block (nil = no fork, 0 = already on byzantium) ConstantinopleBlock *big.Int `json:"constantinopleBlock,omitempty"` // Constantinople switch block (nil = no fork, 0 = already activated) + PetersburgBlock *big.Int `json:"petersburgBlock,omitempty"` + IstanbulBlock *big.Int `json:"istanbulBlock,omitempty"` + BerlinBlock *big.Int `json:"berlinBlock,omitempty"` + LondonBlock *big.Int `json:"londonBlock,omitempty"` + MergeBlock *big.Int `json:"mergeBlock,omitempty"` + ShanghaiBlock *big.Int `json:"shanghaiBlock,omitempty"` + Eip1559Block *big.Int `json:"eip1559Block,omitempty"` + // Various consensus engines Ethash *EthashConfig `json:"ethash,omitempty"` Clique *CliqueConfig `json:"clique,omitempty"` @@ -498,7 +506,7 @@ func (c *ChainConfig) String() string { default: engine = "unknown" } - return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Istanbul: %v BerlinBlock: %v LondonBlock: %v MergeBlock: %v ShanghaiBlock: %v Engine: %v}", + return fmt.Sprintf("{ChainID: %v Homestead: %v DAO: %v DAOSupport: %v EIP150: %v EIP155: %v EIP158: %v Byzantium: %v Constantinople: %v Istanbul: %v BerlinBlock: %v LondonBlock: %v MergeBlock: %v ShanghaiBlock: %v Eip1559Block: %v Engine: %v}", c.ChainId, c.HomesteadBlock, c.DAOForkBlock, @@ -513,6 +521,7 @@ func (c *ChainConfig) String() string { common.LondonBlock, common.MergeBlock, common.ShanghaiBlock, + common.Eip1559Block, engine, ) } @@ -551,33 +560,37 @@ func (c *ChainConfig) IsConstantinople(num *big.Int) bool { // - equal to or greater than the PetersburgBlock fork block, // - OR is nil, and Constantinople is active func (c *ChainConfig) IsPetersburg(num *big.Int) bool { - return isForked(common.TIPXDCXCancellationFee, num) + return isForked(common.TIPXDCXCancellationFee, num) || isForked(c.PetersburgBlock, num) } // IsIstanbul returns whether num is either equal to the Istanbul fork block or greater. func (c *ChainConfig) IsIstanbul(num *big.Int) bool { - return isForked(common.TIPXDCXCancellationFee, num) + return isForked(common.TIPXDCXCancellationFee, num) || isForked(c.IstanbulBlock, num) } // IsBerlin returns whether num is either equal to the Berlin fork block or greater. func (c *ChainConfig) IsBerlin(num *big.Int) bool { - return isForked(common.BerlinBlock, num) + return isForked(common.BerlinBlock, num) || isForked(c.BerlinBlock, num) } // IsLondon returns whether num is either equal to the London fork block or greater. func (c *ChainConfig) IsLondon(num *big.Int) bool { - return isForked(common.LondonBlock, num) + return isForked(common.LondonBlock, num) || isForked(c.LondonBlock, num) } // IsMerge returns whether num is either equal to the Merge fork block or greater. // Different from Geth which uses `block.difficulty != nil` func (c *ChainConfig) IsMerge(num *big.Int) bool { - return isForked(common.MergeBlock, num) + return isForked(common.MergeBlock, num) || isForked(c.MergeBlock, num) } // IsShanghai returns whether num is either equal to the Shanghai fork block or greater. func (c *ChainConfig) IsShanghai(num *big.Int) bool { - return isForked(common.ShanghaiBlock, num) + return isForked(common.ShanghaiBlock, num) || isForked(c.ShanghaiBlock, num) +} + +func (c *ChainConfig) IsEIP1559(num *big.Int) bool { + return isForked(common.Eip1559Block, num) || isForked(c.Eip1559Block, num) } func (c *ChainConfig) IsTIP2019(num *big.Int) bool { @@ -606,7 +619,11 @@ func (c *ChainConfig) IsTIPXDCX(num *big.Int) bool { return isForked(common.TIPXDCX, num) } func (c *ChainConfig) IsTIPXDCXMiner(num *big.Int) bool { - return isForked(common.TIPXDCX, num) && !isForked(common.TIPXDCXDISABLE, num) + return isForked(common.TIPXDCX, num) && !isForked(common.TIPXDCXMinerDisable, num) +} + +func (c *ChainConfig) IsTIPXDCXReceiver(num *big.Int) bool { + return isForked(common.TIPXDCX, num) && !isForked(common.TIPXDCXReceiverDisable, num) } func (c *ChainConfig) IsTIPXDCXLending(num *big.Int) bool { @@ -749,6 +766,8 @@ type Rules struct { IsByzantium, IsConstantinople, IsPetersburg, IsIstanbul bool IsBerlin, IsLondon bool IsMerge, IsShanghai bool + IsXDCxDisable bool + IsEIP1559 bool } func (c *ChainConfig) Rules(num *big.Int) Rules { @@ -770,5 +789,7 @@ func (c *ChainConfig) Rules(num *big.Int) Rules { IsLondon: c.IsLondon(num), IsMerge: c.IsMerge(num), IsShanghai: c.IsShanghai(num), + IsXDCxDisable: c.IsTIPXDCXReceiver(num), + IsEIP1559: c.IsEIP1559(num), } } diff --git a/params/protocol_params.go b/params/protocol_params.go index f15530649750..b8cd64f72346 100644 --- a/params/protocol_params.go +++ b/params/protocol_params.go @@ -25,6 +25,7 @@ var ( const ( GasLimitBoundDivisor uint64 = 1024 // The bound divisor of the gas limit, used in update calculations. MinGasLimit uint64 = 5000 // Minimum the gas limit may ever be. + MaxGasLimit uint64 = 0x7fffffffffffffff // Maximum the gas limit (2^63-1). GenesisGasLimit uint64 = 4712388 // Gas limit of the Genesis block. XDCGenesisGasLimit uint64 = 84000000 @@ -60,6 +61,10 @@ const ( CreateGas uint64 = 32000 // Once per CREATE operation & contract-creation transaction. SuicideRefundGas uint64 = 24000 // Refunded following a suicide operation. MemoryGas uint64 = 3 // Times the address of the (highest referenced byte in memory + 1). NOTE: referencing happens on read, write and in instructions such as RETURN and CALL. + + TxAccessListAddressGas uint64 = 2400 // Per address specified in EIP 2930 access list + TxAccessListStorageKeyGas uint64 = 1900 // Per storage key specified in EIP 2930 access list + TxDataNonZeroGas uint64 = 68 // Per byte of data attached to a transaction that is not equal to zero. NOTE: Not payable on data of calls between transactions. MaxCodeSize = 24576 // Maximum bytecode to permit for a contract diff --git a/params/version.go b/params/version.go index ef1a1ab45a20..6e9d48d38dd7 100644 --- a/params/version.go +++ b/params/version.go @@ -21,10 +21,10 @@ import ( ) const ( - VersionMajor = 1 // Major version component of the current release - VersionMinor = 6 // Minor version component of the current release + VersionMajor = 2 // Major version component of the current release + VersionMinor = 3 // Minor version component of the current release VersionPatch = 0 // Patch version component of the current release - VersionMeta = "" // Version metadata to append to the version string + VersionMeta = "beta1" // Version metadata to append to the version string ) // Version holds the textual version string. diff --git a/rlp/decode.go b/rlp/decode.go index 60d9dab2b5c4..1d32ff8523d4 100644 --- a/rlp/decode.go +++ b/rlp/decode.go @@ -26,103 +26,80 @@ import ( "math/big" "reflect" "strings" + "sync" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" + "github.com/holiman/uint256" ) +//lint:ignore ST1012 EOL is not an error. + +// EOL is returned when the end of the current list +// has been reached during streaming. +var EOL = errors.New("rlp: end of list") + var ( + ErrExpectedString = errors.New("rlp: expected String or Byte") + ErrExpectedList = errors.New("rlp: expected List") + ErrCanonInt = errors.New("rlp: non-canonical integer format") + ErrCanonSize = errors.New("rlp: non-canonical size information") + ErrElemTooLarge = errors.New("rlp: element is larger than containing list") + ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") + ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") + + // internal errors + errNotInList = errors.New("rlp: call of ListEnd outside of any list") + errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") + errUintOverflow = errors.New("rlp: uint overflow") errNoPointer = errors.New("rlp: interface given to Decode must be a pointer") errDecodeIntoNil = errors.New("rlp: pointer given to Decode must not be nil") + errUint256Large = errors.New("rlp: value too large for uint256") + + streamPool = sync.Pool{ + New: func() interface{} { return new(Stream) }, + } ) -// Decoder is implemented by types that require custom RLP -// decoding rules or need to decode into private fields. +// Decoder is implemented by types that require custom RLP decoding rules or need to decode +// into private fields. // -// The DecodeRLP method should read one value from the given -// Stream. It is not forbidden to read less or more, but it might -// be confusing. +// The DecodeRLP method should read one value from the given Stream. It is not forbidden to +// read less or more, but it might be confusing. type Decoder interface { DecodeRLP(*Stream) error } -// Decode parses RLP-encoded data from r and stores the result in the -// value pointed to by val. Val must be a non-nil pointer. If r does -// not implement ByteReader, Decode will do its own buffering. -// -// Decode uses the following type-dependent decoding rules: -// -// If the type implements the Decoder interface, decode calls -// DecodeRLP. -// -// To decode into a pointer, Decode will decode into the value pointed -// to. If the pointer is nil, a new value of the pointer's element -// type is allocated. If the pointer is non-nil, the existing value -// will be reused. -// -// To decode into a struct, Decode expects the input to be an RLP -// list. The decoded elements of the list are assigned to each public -// field in the order given by the struct's definition. The input list -// must contain an element for each decoded field. Decode returns an -// error if there are too few or too many elements. +// Decode parses RLP-encoded data from r and stores the result in the value pointed to by +// val. Please see package-level documentation for the decoding rules. Val must be a +// non-nil pointer. // -// The decoding of struct fields honours certain struct tags, "tail", -// "nil" and "-". +// If r does not implement ByteReader, Decode will do its own buffering. // -// The "-" tag ignores fields. +// Note that Decode does not set an input limit for all readers and may be vulnerable to +// panics cause by huge value sizes. If you need an input limit, use // -// For an explanation of "tail", see the example. -// -// The "nil" tag applies to pointer-typed fields and changes the decoding -// rules for the field such that input values of size zero decode as a nil -// pointer. This tag can be useful when decoding recursive types. -// -// type StructWithEmptyOK struct { -// Foo *[20]byte `rlp:"nil"` -// } -// -// To decode into a slice, the input must be a list and the resulting -// slice will contain the input elements in order. For byte slices, -// the input must be an RLP string. Array types decode similarly, with -// the additional restriction that the number of input elements (or -// bytes) must match the array's length. -// -// To decode into a Go string, the input must be an RLP string. The -// input bytes are taken as-is and will not necessarily be valid UTF-8. -// -// To decode into an unsigned integer type, the input must also be an RLP -// string. The bytes are interpreted as a big endian representation of -// the integer. If the RLP string is larger than the bit size of the -// type, Decode will return an error. Decode also supports *big.Int. -// There is no size limit for big integers. -// -// To decode into an interface value, Decode stores one of these -// in the value: -// -// []interface{}, for RLP lists -// []byte, for RLP strings -// -// Non-empty interface types are not supported, nor are booleans, -// signed integers, floating point numbers, maps, channels and -// functions. -// -// Note that Decode does not set an input limit for all readers -// and may be vulnerable to panics cause by huge value sizes. If -// you need an input limit, use -// -// NewStream(r, limit).Decode(val) +// NewStream(r, limit).Decode(val) func Decode(r io.Reader, val interface{}) error { - // TODO: this could use a Stream from a pool. - return NewStream(r, 0).Decode(val) + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, 0) + return stream.Decode(val) } -// DecodeBytes parses RLP data from b into val. -// Please see the documentation of Decode for the decoding rules. -// The input must contain exactly one value and no trailing data. +// DecodeBytes parses RLP data from b into val. Please see package-level documentation for +// the decoding rules. The input must contain exactly one value and no trailing data. func DecodeBytes(b []byte, val interface{}) error { - // TODO: this could use a Stream from a pool. - r := bytes.NewReader(b) - if err := NewStream(r, uint64(len(b))).Decode(val); err != nil { + r := (*sliceReader)(&b) + + stream := streamPool.Get().(*Stream) + defer streamPool.Put(stream) + + stream.Reset(r, uint64(len(b))) + if err := stream.Decode(val); err != nil { return err } - if r.Len() > 0 { + if len(b) > 0 { return ErrMoreThanOneValue } return nil @@ -173,21 +150,26 @@ func addErrorContext(err error, ctx string) error { var ( decoderInterface = reflect.TypeOf(new(Decoder)).Elem() bigInt = reflect.TypeOf(big.Int{}) + u256Int = reflect.TypeOf(uint256.Int{}) ) -func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { +func makeDecoder(typ reflect.Type, tags rlpstruct.Tags) (dec decoder, err error) { kind := typ.Kind() switch { case typ == rawValueType: return decodeRawValue, nil - case typ.Implements(decoderInterface): - return decodeDecoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(decoderInterface): - return decodeDecoderNoPtr, nil - case typ.AssignableTo(reflect.PtrTo(bigInt)): + case typ.AssignableTo(reflect.PointerTo(bigInt)): return decodeBigInt, nil case typ.AssignableTo(bigInt): return decodeBigIntNoPtr, nil + case typ == reflect.PointerTo(u256Int): + return decodeU256, nil + case typ == u256Int: + return decodeU256NoPtr, nil + case kind == reflect.Ptr: + return makePtrDecoder(typ, tags) + case reflect.PointerTo(typ).Implements(decoderInterface): + return decodeDecoder, nil case isUint(kind): return decodeUint, nil case kind == reflect.Bool: @@ -198,11 +180,6 @@ func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) { return makeListDecoder(typ, tags) case kind == reflect.Struct: return makeStructDecoder(typ) - case kind == reflect.Ptr: - if tags.nilOK { - return makeOptionalPtrDecoder(typ) - } - return makePtrDecoder(typ) case kind == reflect.Interface: return decodeInterface, nil default: @@ -252,35 +229,48 @@ func decodeBigIntNoPtr(s *Stream, val reflect.Value) error { } func decodeBigInt(s *Stream, val reflect.Value) error { - b, err := s.Bytes() + i := val.Interface().(*big.Int) + if i == nil { + i = new(big.Int) + val.Set(reflect.ValueOf(i)) + } + + err := s.decodeBigInt(i) if err != nil { return wrapStreamError(err, val.Type()) } - i := val.Interface().(*big.Int) + return nil +} + +func decodeU256NoPtr(s *Stream, val reflect.Value) error { + return decodeU256(s, val.Addr()) +} + +func decodeU256(s *Stream, val reflect.Value) error { + i := val.Interface().(*uint256.Int) if i == nil { - i = new(big.Int) + i = new(uint256.Int) val.Set(reflect.ValueOf(i)) } - // Reject leading zero bytes - if len(b) > 0 && b[0] == 0 { - return wrapStreamError(ErrCanonInt, val.Type()) + + err := s.ReadUint256(i) + if err != nil { + return wrapStreamError(err, val.Type()) } - i.SetBytes(b) return nil } -func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { +func makeListDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() - if etype.Kind() == reflect.Uint8 && !reflect.PtrTo(etype).Implements(decoderInterface) { + if etype.Kind() == reflect.Uint8 && !reflect.PointerTo(etype).Implements(decoderInterface) { if typ.Kind() == reflect.Array { return decodeByteArray, nil - } else { - return decodeByteSlice, nil } + return decodeByteSlice, nil } - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + if etypeinfo.decoderErr != nil { + return nil, etypeinfo.decoderErr } var dec decoder switch { @@ -288,7 +278,7 @@ func makeListDecoder(typ reflect.Type, tag tags) (decoder, error) { dec = func(s *Stream, val reflect.Value) error { return decodeListArray(s, val, etypeinfo.decoder) } - case tag.tail: + case tag.Tail: // A slice with "tail" tag can occur as the last field // of a struct and is supposed to swallow all remaining // list elements. The struct decoder already called s.List, @@ -381,25 +371,23 @@ func decodeByteArray(s *Stream, val reflect.Value) error { if err != nil { return err } - vlen := val.Len() + slice := byteArrayBytes(val, val.Len()) switch kind { case Byte: - if vlen == 0 { + if len(slice) == 0 { return &decodeError{msg: "input string too long", typ: val.Type()} - } - if vlen > 1 { + } else if len(slice) > 1 { return &decodeError{msg: "input string too short", typ: val.Type()} } - bv, _ := s.Uint() - val.Index(0).SetUint(bv) + slice[0] = s.byteval + s.kind = -1 case String: - if uint64(vlen) < size { + if uint64(len(slice)) < size { return &decodeError{msg: "input string too long", typ: val.Type()} } - if uint64(vlen) > size { + if uint64(len(slice)) > size { return &decodeError{msg: "input string too short", typ: val.Type()} } - slice := val.Slice(0, vlen).Interface().([]byte) if err := s.readFull(slice); err != nil { return err } @@ -418,13 +406,25 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { if err != nil { return nil, err } + for _, f := range fields { + if f.info.decoderErr != nil { + return nil, structFieldError{typ, f.index, f.info.decoderErr} + } + } dec := func(s *Stream, val reflect.Value) (err error) { if _, err := s.List(); err != nil { return wrapStreamError(err, typ) } - for _, f := range fields { + for i, f := range fields { err := f.info.decoder(s, val.Field(f.index)) if err == EOL { + if f.optional { + // The field is optional, so reaching the end of the list before + // reaching the last field is acceptable. All remaining undecoded + // fields are zeroed. + zeroFields(val, fields[i:]) + break + } return &decodeError{msg: "too few elements", typ: typ} } else if err != nil { return addErrorContext(err, "."+typ.Field(f.index).Name) @@ -435,15 +435,29 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) { return dec, nil } -// makePtrDecoder creates a decoder that decodes into -// the pointer's element type. -func makePtrDecoder(typ reflect.Type) (decoder, error) { +func zeroFields(structval reflect.Value, fields []field) { + for _, f := range fields { + fv := structval.Field(f.index) + fv.Set(reflect.Zero(fv.Type())) + } +} + +// makePtrDecoder creates a decoder that decodes into the pointer's element type. +func makePtrDecoder(typ reflect.Type, tag rlpstruct.Tags) (decoder, error) { etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err + etypeinfo := theTC.infoWhileGenerating(etype, rlpstruct.Tags{}) + switch { + case etypeinfo.decoderErr != nil: + return nil, etypeinfo.decoderErr + case !tag.NilOK: + return makeSimplePtrDecoder(etype, etypeinfo), nil + default: + return makeNilPtrDecoder(etype, etypeinfo, tag), nil } - dec := func(s *Stream, val reflect.Value) (err error) { +} + +func makeSimplePtrDecoder(etype reflect.Type, etypeinfo *typeinfo) decoder { + return func(s *Stream, val reflect.Value) (err error) { newval := val if val.IsNil() { newval = reflect.New(etype) @@ -453,30 +467,39 @@ func makePtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } -// makeOptionalPtrDecoder creates a decoder that decodes empty values -// as nil. Non-empty values are decoded into a value of the element type, -// just like makePtrDecoder does. +// makeNilPtrDecoder creates a decoder that decodes empty values as nil. Non-empty +// values are decoded into a value of the element type, just like makePtrDecoder does. // // This decoder is used for pointer-typed struct fields with struct tag "nil". -func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { - etype := typ.Elem() - etypeinfo, err := cachedTypeInfo1(etype, tags{}) - if err != nil { - return nil, err - } - dec := func(s *Stream, val reflect.Value) (err error) { +func makeNilPtrDecoder(etype reflect.Type, etypeinfo *typeinfo, ts rlpstruct.Tags) decoder { + typ := reflect.PointerTo(etype) + nilPtr := reflect.Zero(typ) + + // Determine the value kind that results in nil pointer. + nilKind := typeNilKind(etype, ts) + + return func(s *Stream, val reflect.Value) (err error) { kind, size, err := s.Kind() - if err != nil || size == 0 && kind != Byte { + if err != nil { + val.Set(nilPtr) + return wrapStreamError(err, typ) + } + // Handle empty values as a nil pointer. + if kind != Byte && size == 0 { + if kind != nilKind { + return &decodeError{ + msg: fmt.Sprintf("wrong kind of empty value (got %v, want %v)", kind, nilKind), + typ: typ, + } + } // rearm s.Kind. This is important because the input // position must advance to the next value even though // we don't read anything. s.kind = -1 - // set the pointer to nil. - val.Set(reflect.Zero(typ)) - return err + val.Set(nilPtr) + return nil } newval := val if val.IsNil() { @@ -487,7 +510,6 @@ func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) { } return err } - return dec, nil } var ifsliceType = reflect.TypeOf([]interface{}{}) @@ -516,25 +538,12 @@ func decodeInterface(s *Stream, val reflect.Value) error { return nil } -// This decoder is used for non-pointer values of types -// that implement the Decoder interface using a pointer receiver. -func decodeDecoderNoPtr(s *Stream, val reflect.Value) error { - return val.Addr().Interface().(Decoder).DecodeRLP(s) -} - func decodeDecoder(s *Stream, val reflect.Value) error { - // Decoder instances are not handled using the pointer rule if the type - // implements Decoder with pointer receiver (i.e. always) - // because it might handle empty values specially. - // We need to allocate one here in this case, like makePtrDecoder does. - if val.Kind() == reflect.Ptr && val.IsNil() { - val.Set(reflect.New(val.Type().Elem())) - } - return val.Interface().(Decoder).DecodeRLP(s) + return val.Addr().Interface().(Decoder).DecodeRLP(s) } // Kind represents the kind of value contained in an RLP stream. -type Kind int +type Kind int8 const ( Byte Kind = iota @@ -555,29 +564,6 @@ func (k Kind) String() string { } } -var ( - // EOL is returned when the end of the current list - // has been reached during streaming. - EOL = errors.New("rlp: end of list") - - // Actual Errors - ErrExpectedString = errors.New("rlp: expected String or Byte") - ErrExpectedList = errors.New("rlp: expected List") - ErrCanonInt = errors.New("rlp: non-canonical integer format") - ErrCanonSize = errors.New("rlp: non-canonical size information") - ErrElemTooLarge = errors.New("rlp: element is larger than containing list") - ErrValueTooLarge = errors.New("rlp: value size exceeds available input length") - - // This error is reported by DecodeBytes if the slice contains - // additional data after the first RLP value. - ErrMoreThanOneValue = errors.New("rlp: input contains more than one value") - - // internal errors - errNotInList = errors.New("rlp: call of ListEnd outside of any list") - errNotAtEOL = errors.New("rlp: call of ListEnd not positioned at EOL") - errUintOverflow = errors.New("rlp: uint overflow") -) - // ByteReader must be implemented by any input reader for a Stream. It // is implemented by e.g. bufio.Reader and bytes.Reader. type ByteReader interface { @@ -600,22 +586,16 @@ type ByteReader interface { type Stream struct { r ByteReader - // number of bytes remaining to be read from r. - remaining uint64 - limited bool - - // auxiliary buffer for integer decoding - uintbuf []byte - - kind Kind // kind of value ahead - size uint64 // size of value ahead - byteval byte // value of single byte in type tag - kinderr error // error from last readKind - stack []listpos + remaining uint64 // number of bytes remaining to be read from r + size uint64 // size of value ahead + kinderr error // error from last readKind + stack []uint64 // list sizes + uintbuf [32]byte // auxiliary buffer for integer decoding + kind Kind // kind of value ahead + byteval byte // value of single byte in type tag + limited bool // true if input limit is in effect } -type listpos struct{ pos, size uint64 } - // NewStream creates a new decoding stream reading from r. // // If r implements the ByteReader interface, Stream will @@ -675,6 +655,37 @@ func (s *Stream) Bytes() ([]byte, error) { } } +// ReadBytes decodes the next RLP value and stores the result in b. +// The value size must match len(b) exactly. +func (s *Stream) ReadBytes(b []byte) error { + kind, size, err := s.Kind() + if err != nil { + return err + } + switch kind { + case Byte: + if len(b) != 1 { + return fmt.Errorf("input value has wrong size 1, want %d", len(b)) + } + b[0] = s.byteval + s.kind = -1 // rearm Kind + return nil + case String: + if uint64(len(b)) != size { + return fmt.Errorf("input value has wrong size %d, want %d", size, len(b)) + } + if err = s.readFull(b); err != nil { + return err + } + if size == 1 && b[0] < 128 { + return ErrCanonSize + } + return nil + default: + return ErrExpectedString + } +} + // Raw reads a raw encoded value including RLP type information. func (s *Stream) Raw() ([]byte, error) { kind, size, err := s.Kind() @@ -685,8 +696,8 @@ func (s *Stream) Raw() ([]byte, error) { s.kind = -1 // rearm Kind return []byte{s.byteval}, nil } - // the original header has already been read and is no longer - // available. read content and put a new header in front of it. + // The original header has already been read and is no longer + // available. Read content and put a new header in front of it. start := headsize(size) buf := make([]byte, uint64(start)+size) if err := s.readFull(buf[start:]); err != nil { @@ -703,10 +714,31 @@ func (s *Stream) Raw() ([]byte, error) { // Uint reads an RLP string of up to 8 bytes and returns its contents // as an unsigned integer. If the input does not contain an RLP string, the // returned error will be ErrExpectedString. +// +// Deprecated: use s.Uint64 instead. func (s *Stream) Uint() (uint64, error) { return s.uint(64) } +func (s *Stream) Uint64() (uint64, error) { + return s.uint(64) +} + +func (s *Stream) Uint32() (uint32, error) { + i, err := s.uint(32) + return uint32(i), err +} + +func (s *Stream) Uint16() (uint16, error) { + i, err := s.uint(16) + return uint16(i), err +} + +func (s *Stream) Uint8() (uint8, error) { + i, err := s.uint(8) + return uint8(i), err +} + func (s *Stream) uint(maxbits int) (uint64, error) { kind, size, err := s.Kind() if err != nil { @@ -769,7 +801,14 @@ func (s *Stream) List() (size uint64, err error) { if kind != List { return 0, ErrExpectedList } - s.stack = append(s.stack, listpos{0, size}) + + // Remove size of inner list from outer list before pushing the new size + // onto the stack. This ensures that the remaining outer list size will + // be correct after the matching call to ListEnd. + if inList, limit := s.listLimit(); inList { + s.stack[len(s.stack)-1] = limit - size + } + s.stack = append(s.stack, size) s.kind = -1 s.size = 0 return size, nil @@ -778,22 +817,116 @@ func (s *Stream) List() (size uint64, err error) { // ListEnd returns to the enclosing list. // The input reader must be positioned at the end of a list. func (s *Stream) ListEnd() error { - if len(s.stack) == 0 { + // Ensure that no more data is remaining in the current list. + if inList, listLimit := s.listLimit(); !inList { return errNotInList - } - tos := s.stack[len(s.stack)-1] - if tos.pos != tos.size { + } else if listLimit > 0 { return errNotAtEOL } s.stack = s.stack[:len(s.stack)-1] // pop - if len(s.stack) > 0 { - s.stack[len(s.stack)-1].pos += tos.size - } s.kind = -1 s.size = 0 return nil } +// MoreDataInList reports whether the current list context contains +// more data to be read. +func (s *Stream) MoreDataInList() bool { + _, listLimit := s.listLimit() + return listLimit > 0 +} + +// BigInt decodes an arbitrary-size integer value. +func (s *Stream) BigInt() (*big.Int, error) { + i := new(big.Int) + if err := s.decodeBigInt(i); err != nil { + return nil, err + } + return i, nil +} + +func (s *Stream) decodeBigInt(dst *big.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // For integers smaller than s.uintbuf, allocating a buffer + // can be avoided. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + // For large integers, a temporary buffer is needed. + buffer = make([]byte, size) + if err := s.readFull(buffer); err != nil { + return err + } + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + +// ReadUint256 decodes the next value as a uint256. +func (s *Stream) ReadUint256(dst *uint256.Int) error { + var buffer []byte + kind, size, err := s.Kind() + switch { + case err != nil: + return err + case kind == List: + return ErrExpectedString + case kind == Byte: + buffer = s.uintbuf[:1] + buffer[0] = s.byteval + s.kind = -1 // re-arm Kind + case size == 0: + // Avoid zero-length read. + s.kind = -1 + case size <= uint64(len(s.uintbuf)): + // All possible uint256 values fit into s.uintbuf. + buffer = s.uintbuf[:size] + if err := s.readFull(buffer); err != nil { + return err + } + // Reject inputs where single byte encoding should have been used. + if size == 1 && buffer[0] < 128 { + return ErrCanonSize + } + default: + return errUint256Large + } + + // Reject leading zero bytes. + if len(buffer) > 0 && buffer[0] == 0 { + return ErrCanonInt + } + // Set the integer bytes. + dst.SetBytes(buffer) + return nil +} + // Decode decodes a value and stores the result in the value pointed // to by val. Please see the documentation for the Decode function // to learn about the decoding rules. @@ -809,14 +942,14 @@ func (s *Stream) Decode(val interface{}) error { if rval.IsNil() { return errDecodeIntoNil } - info, err := cachedTypeInfo(rtyp.Elem(), tags{}) + decoder, err := cachedDecoder(rtyp.Elem()) if err != nil { return err } - err = info.decoder(s, rval.Elem()) + err = decoder(s, rval.Elem()) if decErr, ok := err.(*decodeError); ok && len(decErr.ctx) > 0 { - // add decode target type to error so context has more meaning + // Add decode target type to error so context has more meaning. decErr.ctx = append(decErr.ctx, fmt.Sprint("(", rtyp.Elem(), ")")) } return err @@ -839,6 +972,9 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { case *bytes.Reader: s.remaining = uint64(br.Len()) s.limited = true + case *bytes.Buffer: + s.remaining = uint64(br.Len()) + s.limited = true case *strings.Reader: s.remaining = uint64(br.Len()) s.limited = true @@ -857,9 +993,8 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { s.size = 0 s.kind = -1 s.kinderr = nil - if s.uintbuf == nil { - s.uintbuf = make([]byte, 8) - } + s.byteval = 0 + s.uintbuf = [32]byte{} } // Kind returns the kind and size of the next value in the @@ -874,35 +1009,29 @@ func (s *Stream) Reset(r io.Reader, inputLimit uint64) { // the value. Subsequent calls to Kind (until the value is decoded) // will not advance the input reader and return cached information. func (s *Stream) Kind() (kind Kind, size uint64, err error) { - var tos *listpos - if len(s.stack) > 0 { - tos = &s.stack[len(s.stack)-1] - } - if s.kind < 0 { - s.kinderr = nil - // Don't read further if we're at the end of the - // innermost list. - if tos != nil && tos.pos == tos.size { - return 0, 0, EOL - } - s.kind, s.size, s.kinderr = s.readKind() - if s.kinderr == nil { - if tos == nil { - // At toplevel, check that the value is smaller - // than the remaining input length. - if s.limited && s.size > s.remaining { - s.kinderr = ErrValueTooLarge - } - } else { - // Inside a list, check that the value doesn't overflow the list. - if s.size > tos.size-tos.pos { - s.kinderr = ErrElemTooLarge - } - } + if s.kind >= 0 { + return s.kind, s.size, s.kinderr + } + + // Check for end of list. This needs to be done here because readKind + // checks against the list size, and would return the wrong error. + inList, listLimit := s.listLimit() + if inList && listLimit == 0 { + return 0, 0, EOL + } + // Read the actual size tag. + s.kind, s.size, s.kinderr = s.readKind() + if s.kinderr == nil { + // Check the data size of the value ahead against input limits. This + // is done here because many decoders require allocating an input + // buffer matching the value size. Checking it here protects those + // decoders from inputs declaring very large value size. + if inList && s.size > listLimit { + s.kinderr = ErrElemTooLarge + } else if s.limited && s.size > s.remaining { + s.kinderr = ErrValueTooLarge } } - // Note: this might return a sticky error generated - // by an earlier call to readKind. return s.kind, s.size, s.kinderr } @@ -929,37 +1058,35 @@ func (s *Stream) readKind() (kind Kind, size uint64, err error) { s.byteval = b return Byte, 0, nil case b < 0xB8: - // Otherwise, if a string is 0-55 bytes long, - // the RLP encoding consists of a single byte with value 0x80 plus the - // length of the string followed by the string. The range of the first - // byte is thus [0x80, 0xB7]. + // Otherwise, if a string is 0-55 bytes long, the RLP encoding consists + // of a single byte with value 0x80 plus the length of the string + // followed by the string. The range of the first byte is thus [0x80, 0xB7]. return String, uint64(b - 0x80), nil case b < 0xC0: - // If a string is more than 55 bytes long, the - // RLP encoding consists of a single byte with value 0xB7 plus the length - // of the length of the string in binary form, followed by the length of - // the string, followed by the string. For example, a length-1024 string - // would be encoded as 0xB90400 followed by the string. The range of - // the first byte is thus [0xB8, 0xBF]. + // If a string is more than 55 bytes long, the RLP encoding consists of a + // single byte with value 0xB7 plus the length of the length of the + // string in binary form, followed by the length of the string, followed + // by the string. For example, a length-1024 string would be encoded as + // 0xB90400 followed by the string. The range of the first byte is thus + // [0xB8, 0xBF]. size, err = s.readUint(b - 0xB7) if err == nil && size < 56 { err = ErrCanonSize } return String, size, err case b < 0xF8: - // If the total payload of a list - // (i.e. the combined length of all its items) is 0-55 bytes long, the - // RLP encoding consists of a single byte with value 0xC0 plus the length - // of the list followed by the concatenation of the RLP encodings of the - // items. The range of the first byte is thus [0xC0, 0xF7]. + // If the total payload of a list (i.e. the combined length of all its + // items) is 0-55 bytes long, the RLP encoding consists of a single byte + // with value 0xC0 plus the length of the list followed by the + // concatenation of the RLP encodings of the items. The range of the + // first byte is thus [0xC0, 0xF7]. return List, uint64(b - 0xC0), nil default: - // If the total payload of a list is more than 55 bytes long, - // the RLP encoding consists of a single byte with value 0xF7 - // plus the length of the length of the payload in binary - // form, followed by the length of the payload, followed by - // the concatenation of the RLP encodings of the items. The - // range of the first byte is thus [0xF8, 0xFF]. + // If the total payload of a list is more than 55 bytes long, the RLP + // encoding consists of a single byte with value 0xF7 plus the length of + // the length of the payload in binary form, followed by the length of + // the payload, followed by the concatenation of the RLP encodings of + // the items. The range of the first byte is thus [0xF8, 0xFF]. size, err = s.readUint(b - 0xF7) if err == nil && size < 56 { err = ErrCanonSize @@ -977,23 +1104,22 @@ func (s *Stream) readUint(size byte) (uint64, error) { b, err := s.readByte() return uint64(b), err default: + buffer := s.uintbuf[:8] + clear(buffer) start := int(8 - size) - for i := 0; i < start; i++ { - s.uintbuf[i] = 0 - } - if err := s.readFull(s.uintbuf[start:]); err != nil { + if err := s.readFull(buffer[start:]); err != nil { return 0, err } - if s.uintbuf[start] == 0 { - // Note: readUint is also used to decode integer - // values. The error needs to be adjusted to become - // ErrCanonInt in this case. + if buffer[start] == 0 { + // Note: readUint is also used to decode integer values. + // The error needs to be adjusted to become ErrCanonInt in this case. return 0, ErrCanonSize } - return binary.BigEndian.Uint64(s.uintbuf), nil + return binary.BigEndian.Uint64(buffer[:]), nil } } +// readFull reads into buf from the underlying stream. func (s *Stream) readFull(buf []byte) (err error) { if err := s.willRead(uint64(len(buf))); err != nil { return err @@ -1004,11 +1130,18 @@ func (s *Stream) readFull(buf []byte) (err error) { n += nn } if err == io.EOF { - err = io.ErrUnexpectedEOF + if n < len(buf) { + err = io.ErrUnexpectedEOF + } else { + // Readers are allowed to give EOF even though the read succeeded. + // In such cases, we discard the EOF, like io.ReadFull() does. + err = nil + } } return err } +// readByte reads a single byte from the underlying stream. func (s *Stream) readByte() (byte, error) { if err := s.willRead(1); err != nil { return 0, err @@ -1020,16 +1153,16 @@ func (s *Stream) readByte() (byte, error) { return b, err } +// willRead is called before any read from the underlying stream. It checks +// n against size limits, and updates the limits if n doesn't overflow them. func (s *Stream) willRead(n uint64) error { s.kind = -1 // rearm Kind - if len(s.stack) > 0 { - // check list overflow - tos := s.stack[len(s.stack)-1] - if n > tos.size-tos.pos { + if inList, limit := s.listLimit(); inList { + if n > limit { return ErrElemTooLarge } - s.stack[len(s.stack)-1].pos += n + s.stack[len(s.stack)-1] = limit - n } if s.limited { if n > s.remaining { @@ -1039,3 +1172,31 @@ func (s *Stream) willRead(n uint64) error { } return nil } + +// listLimit returns the amount of data remaining in the innermost list. +func (s *Stream) listLimit() (inList bool, limit uint64) { + if len(s.stack) == 0 { + return false, 0 + } + return true, s.stack[len(s.stack)-1] +} + +type sliceReader []byte + +func (sr *sliceReader) Read(b []byte) (int, error) { + if len(*sr) == 0 { + return 0, io.EOF + } + n := copy(b, *sr) + *sr = (*sr)[n:] + return n, nil +} + +func (sr *sliceReader) ReadByte() (byte, error) { + if len(*sr) == 0 { + return 0, io.EOF + } + b := (*sr)[0] + *sr = (*sr)[1:] + return b, nil +} diff --git a/rlp/decode_test.go b/rlp/decode_test.go index 4d8abd001281..38cca38aa548 100644 --- a/rlp/decode_test.go +++ b/rlp/decode_test.go @@ -26,6 +26,9 @@ import ( "reflect" "strings" "testing" + + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/holiman/uint256" ) func TestStreamKind(t *testing.T) { @@ -284,6 +287,47 @@ func TestStreamRaw(t *testing.T) { } } +func TestStreamReadBytes(t *testing.T) { + tests := []struct { + input string + size int + err string + }{ + // kind List + {input: "C0", size: 1, err: "rlp: expected String or Byte"}, + // kind Byte + {input: "04", size: 0, err: "input value has wrong size 1, want 0"}, + {input: "04", size: 1}, + {input: "04", size: 2, err: "input value has wrong size 1, want 2"}, + // kind String + {input: "820102", size: 0, err: "input value has wrong size 2, want 0"}, + {input: "820102", size: 1, err: "input value has wrong size 2, want 1"}, + {input: "820102", size: 2}, + {input: "820102", size: 3, err: "input value has wrong size 2, want 3"}, + } + + for _, test := range tests { + test := test + name := fmt.Sprintf("input_%s/size_%d", test.input, test.size) + t.Run(name, func(t *testing.T) { + s := NewStream(bytes.NewReader(unhex(test.input)), 0) + b := make([]byte, test.size) + err := s.ReadBytes(b) + if test.err == "" { + if err != nil { + t.Errorf("unexpected error %q", err) + } + } else { + if err == nil { + t.Errorf("expected error, got nil") + } else if err.Error() != test.err { + t.Errorf("wrong error %q", err) + } + } + }) + } +} + func TestDecodeErrors(t *testing.T) { r := bytes.NewReader(nil) @@ -327,6 +371,15 @@ type recstruct struct { Child *recstruct `rlp:"nil"` } +type bigIntStruct struct { + I *big.Int + B string +} + +type invalidNilTag struct { + X []byte `rlp:"nil"` +} + type invalidTail1 struct { A uint `rlp:"tail"` B string @@ -347,19 +400,79 @@ type tailUint struct { Tail []uint `rlp:"tail"` } -var ( - veryBigInt = big.NewInt(0).Add( - big.NewInt(0).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), - big.NewInt(0xFFFF), - ) -) +type tailPrivateFields struct { + A uint + Tail []uint `rlp:"tail"` + x, y bool //lint:ignore U1000 unused fields required for testing purposes. +} + +type nilListUint struct { + X *uint `rlp:"nilList"` +} + +type nilStringSlice struct { + X *[]uint `rlp:"nilString"` +} + +type intField struct { + X int +} + +type optionalFields struct { + A uint + B uint `rlp:"optional"` + C uint `rlp:"optional"` +} + +type optionalAndTailField struct { + A uint + B uint `rlp:"optional"` + Tail []uint `rlp:"tail"` +} + +type optionalBigIntField struct { + A uint + B *big.Int `rlp:"optional"` +} + +type optionalPtrField struct { + A uint + B *[3]byte `rlp:"optional"` +} + +type nonOptionalPtrField struct { + A uint + B *[3]byte +} -type hasIgnoredField struct { +type multipleOptionalFields struct { + A *[3]byte `rlp:"optional"` + B *[3]byte `rlp:"optional"` +} + +type optionalPtrFieldNil struct { + A uint + B *[3]byte `rlp:"optional,nil"` +} + +type ignoredField struct { A uint B uint `rlp:"-"` C uint } +var ( + veryBigInt = new(big.Int).Add( + new(big.Int).Lsh(big.NewInt(0xFFFFFFFFFFFFFF), 16), + big.NewInt(0xFFFF), + ) + veryVeryBigInt = new(big.Int).Exp(veryBigInt, big.NewInt(8), nil) +) + +var ( + veryBigInt256, _ = uint256.FromBig(veryBigInt) +) + var decodeTests = []decodeTest{ // booleans {input: "01", ptr: new(bool), value: true}, @@ -428,12 +541,31 @@ var decodeTests = []decodeTest{ {input: "C0", ptr: new(string), error: "rlp: expected input string or byte for string"}, // big ints + {input: "80", ptr: new(*big.Int), value: big.NewInt(0)}, {input: "01", ptr: new(*big.Int), value: big.NewInt(1)}, {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*big.Int), value: veryBigInt}, + {input: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", ptr: new(*big.Int), value: veryVeryBigInt}, {input: "10", ptr: new(big.Int), value: *big.NewInt(16)}, // non-pointer also works + + // big int errors {input: "C0", ptr: new(*big.Int), error: "rlp: expected input string or byte for *big.Int"}, - {input: "820001", ptr: new(big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, - {input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"}, + {input: "00", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "820001", ptr: new(*big.Int), error: "rlp: non-canonical integer (leading zero bytes) for *big.Int"}, + {input: "8105", ptr: new(*big.Int), error: "rlp: non-canonical size information for *big.Int"}, + + // uint256 + {input: "80", ptr: new(*uint256.Int), value: uint256.NewInt(0)}, + {input: "01", ptr: new(*uint256.Int), value: uint256.NewInt(1)}, + {input: "88FFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: uint256.NewInt(math.MaxUint64)}, + {input: "89FFFFFFFFFFFFFFFFFF", ptr: new(*uint256.Int), value: veryBigInt256}, + {input: "10", ptr: new(uint256.Int), value: *uint256.NewInt(16)}, // non-pointer also works + + // uint256 errors + {input: "C0", ptr: new(*uint256.Int), error: "rlp: expected input string or byte for *uint256.Int"}, + {input: "00", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "820001", ptr: new(*uint256.Int), error: "rlp: non-canonical integer (leading zero bytes) for *uint256.Int"}, + {input: "8105", ptr: new(*uint256.Int), error: "rlp: non-canonical size information for *uint256.Int"}, + {input: "A1FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00", ptr: new(*uint256.Int), error: "rlp: value too large for uint256"}, // structs { @@ -446,6 +578,13 @@ var decodeTests = []decodeTest{ ptr: new(recstruct), value: recstruct{1, &recstruct{2, &recstruct{3, nil}}}, }, + { + // This checks that empty big.Int works correctly in struct context. It's easy to + // miss the update of s.kind for this case, so it needs its own test. + input: "C58083343434", + ptr: new(bigIntStruct), + value: bigIntStruct{new(big.Int), "444"}, + }, // struct errors { @@ -479,20 +618,20 @@ var decodeTests = []decodeTest{ error: "rlp: expected input string or byte for uint, decoding into (rlp.recstruct).Child.I", }, { - input: "C0", - ptr: new(invalidTail1), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail1.A (must be on last field)", - }, - { - input: "C0", - ptr: new(invalidTail2), - error: "rlp: invalid struct tag \"tail\" for rlp.invalidTail2.B (field type is not slice)", + input: "C103", + ptr: new(intField), + error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)", }, { input: "C50102C20102", ptr: new(tailUint), error: "rlp: expected input string or byte for uint, decoding into (rlp.tailUint).Tail[1]", }, + { + input: "C0", + ptr: new(invalidNilTag), + error: `rlp: invalid struct tag "nil" for rlp.invalidNilTag.X (field is not a pointer)`, + }, // struct tag "tail" { @@ -510,12 +649,192 @@ var decodeTests = []decodeTest{ ptr: new(tailRaw), value: tailRaw{A: 1, Tail: []RawValue{}}, }, + { + input: "C3010203", + ptr: new(tailPrivateFields), + value: tailPrivateFields{A: 1, Tail: []uint{2, 3}}, + }, + { + input: "C0", + ptr: new(invalidTail1), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail1.A (must be on last field)`, + }, + { + input: "C0", + ptr: new(invalidTail2), + error: `rlp: invalid struct tag "tail" for rlp.invalidTail2.B (field type is not slice)`, + }, // struct tag "-" { input: "C20102", - ptr: new(hasIgnoredField), - value: hasIgnoredField{A: 1, C: 2}, + ptr: new(ignoredField), + value: ignoredField{A: 1, C: 2}, + }, + + // struct tag "nilList" + { + input: "C180", + ptr: new(nilListUint), + error: "rlp: wrong kind of empty value (got String, want List) for *uint, decoding into (rlp.nilListUint).X", + }, + { + input: "C1C0", + ptr: new(nilListUint), + value: nilListUint{}, + }, + { + input: "C103", + ptr: new(nilListUint), + value: func() interface{} { + v := uint(3) + return nilListUint{X: &v} + }(), + }, + + // struct tag "nilString" + { + input: "C1C0", + ptr: new(nilStringSlice), + error: "rlp: wrong kind of empty value (got List, want String) for *[]uint, decoding into (rlp.nilStringSlice).X", + }, + { + input: "C180", + ptr: new(nilStringSlice), + value: nilStringSlice{}, + }, + { + input: "C2C103", + ptr: new(nilStringSlice), + value: nilStringSlice{X: &[]uint{3}}, + }, + + // struct tag "optional" + { + input: "C101", + ptr: new(optionalFields), + value: optionalFields{1, 0, 0}, + }, + { + input: "C20102", + ptr: new(optionalFields), + value: optionalFields{1, 2, 0}, + }, + { + input: "C3010203", + ptr: new(optionalFields), + value: optionalFields{1, 2, 3}, + }, + { + input: "C401020304", + ptr: new(optionalFields), + error: "rlp: input list has too many elements for rlp.optionalFields", + }, + { + input: "C101", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C401020304", + ptr: new(optionalAndTailField), + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{3, 4}}, + }, + { + input: "C101", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: nil}, + }, + { + input: "C20102", + ptr: new(optionalBigIntField), + value: optionalBigIntField{A: 1, B: big.NewInt(2)}, + }, + { + input: "C101", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1}, + }, + { + input: "C20180", // not accepted because "optional" doesn't enable "nil" + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C20102", + ptr: new(optionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrField).B", + }, + { + input: "C50183010203", + ptr: new(optionalPtrField), + value: optionalPtrField{A: 1, B: &[3]byte{1, 2, 3}}, + }, + { + // all optional fields nil + input: "C0", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: nil, B: nil}, + }, + { + // all optional fields set + input: "C88301020383010203", + ptr: new(multipleOptionalFields), + value: multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, + }, + { + // nil optional field appears before a non-nil one + input: "C58083010203", + ptr: new(multipleOptionalFields), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.multipleOptionalFields).A", + }, + { + // decode a nil ptr into a ptr that is not nil or not optional + input: "C20180", + ptr: new(nonOptionalPtrField), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.nonOptionalPtrField).B", + }, + { + input: "C101", + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20180", // accepted because "nil" tag allows empty input + ptr: new(optionalPtrFieldNil), + value: optionalPtrFieldNil{A: 1}, + }, + { + input: "C20102", + ptr: new(optionalPtrFieldNil), + error: "rlp: input string too short for [3]uint8, decoding into (rlp.optionalPtrFieldNil).B", + }, + + // struct tag "optional" field clearing + { + input: "C101", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 0, C: 0}, + }, + { + input: "C20102", + ptr: &optionalFields{A: 9, B: 8, C: 7}, + value: optionalFields{A: 1, B: 2, C: 0}, + }, + { + input: "C20102", + ptr: &optionalAndTailField{A: 9, B: 8, Tail: []uint{7, 6, 5}}, + value: optionalAndTailField{A: 1, B: 2, Tail: []uint{}}, + }, + { + input: "C101", + ptr: &optionalPtrField{A: 9, B: &[3]byte{8, 7, 6}}, + value: optionalPtrField{A: 1}, }, // RawValue @@ -591,6 +910,26 @@ func TestDecodeWithByteReader(t *testing.T) { }) } +func testDecodeWithEncReader(t *testing.T, n int) { + s := strings.Repeat("0", n) + _, r, _ := EncodeToReader(s) + var decoded string + err := Decode(r, &decoded) + if err != nil { + t.Errorf("Unexpected decode error with n=%v: %v", n, err) + } + if decoded != s { + t.Errorf("Decode mismatch with n=%v", n) + } +} + +// This is a regression test checking that decoding from encReader +// works for RLP values of size 8192 bytes or more. +func TestDecodeWithEncReader(t *testing.T) { + testDecodeWithEncReader(t, 8188) // length with header is 8191 + testDecodeWithEncReader(t, 8189) // length with header is 8192 +} + // plainReader reads from a byte slice but does not // implement ReadByte. It is also not recognized by the // size validation. This is useful to test how the decoder @@ -661,6 +1000,22 @@ func TestDecodeDecoder(t *testing.T) { } } +func TestDecodeDecoderNilPointer(t *testing.T) { + var s struct { + T1 *testDecoder `rlp:"nil"` + T2 *testDecoder + } + if err := Decode(bytes.NewReader(unhex("C2C002")), &s); err != nil { + t.Fatalf("Decode error: %v", err) + } + if s.T1 != nil { + t.Errorf("decoder T1 allocated for empty input (called: %v)", s.T1.called) + } + if s.T2 == nil || !s.T2.called { + t.Errorf("decoder T2 not allocated/called") + } +} + type byteDecoder byte func (bd *byteDecoder) DecodeRLP(s *Stream) error { @@ -691,13 +1046,66 @@ func TestDecoderInByteSlice(t *testing.T) { } } +type unencodableDecoder func() + +func (f *unencodableDecoder) DecodeRLP(s *Stream) error { + if _, err := s.List(); err != nil { + return err + } + if err := s.ListEnd(); err != nil { + return err + } + *f = func() {} + return nil +} + +func TestDecoderFunc(t *testing.T) { + var x func() + if err := DecodeBytes([]byte{0xC0}, (*unencodableDecoder)(&x)); err != nil { + t.Fatal(err) + } + x() +} + +// This tests the validity checks for fields with struct tag "optional". +func TestInvalidOptionalField(t *testing.T) { + type ( + invalid1 struct { + A uint `rlp:"optional"` + B uint + } + invalid2 struct { + T []uint `rlp:"tail,optional"` + } + invalid3 struct { + T []uint `rlp:"optional,tail"` + } + ) + + tests := []struct { + v interface{} + err string + }{ + {v: new(invalid1), err: `rlp: invalid struct tag "" for rlp.invalid1.B (must be optional because preceding field "A" is optional)`}, + {v: new(invalid2), err: `rlp: invalid struct tag "optional" for rlp.invalid2.T (also has "tail" tag)`}, + {v: new(invalid3), err: `rlp: invalid struct tag "tail" for rlp.invalid3.T (also has "optional" tag)`}, + } + for _, test := range tests { + err := DecodeBytes(unhex("C20102"), test.v) + if err == nil { + t.Errorf("no error for %T", test.v) + } else if err.Error() != test.err { + t.Errorf("wrong error for %T: %v", test.v, err.Error()) + } + } +} + func ExampleDecode() { input, _ := hex.DecodeString("C90A1486666F6F626172") type example struct { - A, B uint - private uint // private fields are ignored - String string + A, B uint + String string } var s example @@ -708,7 +1116,7 @@ func ExampleDecode() { fmt.Printf("Decoded value: %#v\n", s) } // Output: - // Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"} + // Decoded value: rlp.example{A:0xa, B:0x14, String:"foobar"} } func ExampleDecode_structTagNil() { @@ -768,7 +1176,7 @@ func ExampleStream() { // [102 111 111 98 97 114] } -func BenchmarkDecode(b *testing.B) { +func BenchmarkDecodeUints(b *testing.B) { enc := encodeTestSlice(90000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -783,7 +1191,7 @@ func BenchmarkDecode(b *testing.B) { } } -func BenchmarkDecodeIntSliceReuse(b *testing.B) { +func BenchmarkDecodeUintsReused(b *testing.B) { enc := encodeTestSlice(100000) b.SetBytes(int64(len(enc))) b.ReportAllocs() @@ -798,6 +1206,65 @@ func BenchmarkDecodeIntSliceReuse(b *testing.B) { } } +func BenchmarkDecodeByteArrayStruct(b *testing.B) { + enc, err := EncodeToBytes(&byteArrayStruct{}) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out byteArrayStruct + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*big.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + enc, err := EncodeToBytes(ints) + if err != nil { + b.Fatal(err) + } + b.SetBytes(int64(len(enc))) + b.ReportAllocs() + b.ResetTimer() + + var out []*uint256.Int + for i := 0; i < b.N; i++ { + if err := DecodeBytes(enc, &out); err != nil { + b.Fatal(err) + } + } +} + func encodeTestSlice(n uint) []byte { s := make([]uint, n) for i := uint(0); i < n; i++ { @@ -811,7 +1278,7 @@ func encodeTestSlice(n uint) []byte { } func unhex(str string) []byte { - b, err := hex.DecodeString(strings.Replace(str, " ", "", -1)) + b, err := hex.DecodeString(strings.ReplaceAll(str, " ", "")) if err != nil { panic(fmt.Sprintf("invalid hex string: %q", str)) } diff --git a/rlp/doc.go b/rlp/doc.go index b3a81fe2326f..eeeee9a43a0c 100644 --- a/rlp/doc.go +++ b/rlp/doc.go @@ -17,17 +17,142 @@ /* Package rlp implements the RLP serialization format. -The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily -nested arrays of binary data, and RLP is the main encoding method used -to serialize objects in Ethereum. The only purpose of RLP is to encode -structure; encoding specific atomic data types (eg. strings, ints, -floats) is left up to higher-order protocols; in Ethereum integers -must be represented in big endian binary form with no leading zeroes -(thus making the integer value zero equivalent to the empty byte -array). - -RLP values are distinguished by a type tag. The type tag precedes the -value in the input stream and defines the size and kind of the bytes -that follow. +The purpose of RLP (Recursive Linear Prefix) is to encode arbitrarily nested arrays of +binary data, and RLP is the main encoding method used to serialize objects in Ethereum. +The only purpose of RLP is to encode structure; encoding specific atomic data types (eg. +strings, ints, floats) is left up to higher-order protocols. In Ethereum integers must be +represented in big endian binary form with no leading zeroes (thus making the integer +value zero equivalent to the empty string). + +RLP values are distinguished by a type tag. The type tag precedes the value in the input +stream and defines the size and kind of the bytes that follow. + +# Encoding Rules + +Package rlp uses reflection and encodes RLP based on the Go type of the value. + +If the type implements the Encoder interface, Encode calls EncodeRLP. It does not +call EncodeRLP on nil pointer values. + +To encode a pointer, the value being pointed to is encoded. A nil pointer to a struct +type, slice or array always encodes as an empty RLP list unless the slice or array has +element type byte. A nil pointer to any other value encodes as the empty string. + +Struct values are encoded as an RLP list of all their encoded public fields. Recursive +struct types are supported. + +To encode slices and arrays, the elements are encoded as an RLP list of the value's +elements. Note that arrays and slices with element type uint8 or byte are always encoded +as an RLP string. + +A Go string is encoded as an RLP string. + +An unsigned integer value is encoded as an RLP string. Zero always encodes as an empty RLP +string. big.Int values are treated as integers. Signed integers (int, int8, int16, ...) +are not supported and will return an error when encoding. + +Boolean values are encoded as the unsigned integers zero (false) and one (true). + +An interface value encodes as the value contained in the interface. + +Floating point numbers, maps, channels and functions are not supported. + +# Decoding Rules + +Decoding uses the following type-dependent rules: + +If the type implements the Decoder interface, DecodeRLP is called. + +To decode into a pointer, the value will be decoded as the element type of the pointer. If +the pointer is nil, a new value of the pointer's element type is allocated. If the pointer +is non-nil, the existing value will be reused. Note that package rlp never leaves a +pointer-type struct field as nil unless one of the "nil" struct tags is present. + +To decode into a struct, decoding expects the input to be an RLP list. The decoded +elements of the list are assigned to each public field in the order given by the struct's +definition. The input list must contain an element for each decoded field. Decoding +returns an error if there are too few or too many elements for the struct. + +To decode into a slice, the input must be a list and the resulting slice will contain the +input elements in order. For byte slices, the input must be an RLP string. Array types +decode similarly, with the additional restriction that the number of input elements (or +bytes) must match the array's defined length. + +To decode into a Go string, the input must be an RLP string. The input bytes are taken +as-is and will not necessarily be valid UTF-8. + +To decode into an unsigned integer type, the input must also be an RLP string. The bytes +are interpreted as a big endian representation of the integer. If the RLP string is larger +than the bit size of the type, decoding will return an error. Decode also supports +*big.Int. There is no size limit for big integers. + +To decode into a boolean, the input must contain an unsigned integer of value zero (false) +or one (true). + +To decode into an interface value, one of these types is stored in the value: + + []interface{}, for RLP lists + []byte, for RLP strings + +Non-empty interface types are not supported when decoding. +Signed integers, floating point numbers, maps, channels and functions cannot be decoded into. + +# Struct Tags + +As with other encoding packages, the "-" tag ignores fields. + + type StructWithIgnoredField struct{ + Ignored uint `rlp:"-"` + Field uint + } + +Go struct values encode/decode as RLP lists. There are two ways of influencing the mapping +of fields to list elements. The "tail" tag, which may only be used on the last exported +struct field, allows slurping up any excess list elements into a slice. + + type StructWithTail struct{ + Field uint + Tail []string `rlp:"tail"` + } + +The "optional" tag says that the field may be omitted if it is zero-valued. If this tag is +used on a struct field, all subsequent public fields must also be declared optional. + +When encoding a struct with optional fields, the output RLP list contains all values up to +the last non-zero optional field. + +When decoding into a struct, optional fields may be omitted from the end of the input +list. For the example below, this means input lists of one, two, or three elements are +accepted. + + type StructWithOptionalFields struct{ + Required uint + Optional1 uint `rlp:"optional"` + Optional2 uint `rlp:"optional"` + } + +The "nil", "nilList" and "nilString" tags apply to pointer-typed fields only, and change +the decoding rules for the field type. For regular pointer fields without the "nil" tag, +input values must always match the required input length exactly and the decoder does not +produce nil values. When the "nil" tag is set, input values of size zero decode as a nil +pointer. This is especially useful for recursive types. + + type StructWithNilField struct { + Field *[3]byte `rlp:"nil"` + } + +In the example above, Field allows two possible input sizes. For input 0xC180 (a list +containing an empty string) Field is set to nil after decoding. For input 0xC483000000 (a +list containing a 3-byte string), Field is set to a non-nil array pointer. + +RLP supports two kinds of empty values: empty lists and empty strings. When using the +"nil" tag, the kind of empty value allowed for a type is chosen automatically. A field +whose Go type is a pointer to an unsigned integer, string, boolean or byte array/slice +expects an empty RLP string. Any other pointer field type encodes/decodes as an empty RLP +list. + +The choice of null value can be made explicit with the "nilList" and "nilString" struct +tags. Using these tags encodes/decodes a Go nil pointer value as the empty RLP value kind +defined by the tag. */ package rlp diff --git a/rlp/encbuffer.go b/rlp/encbuffer.go new file mode 100644 index 000000000000..8d3a3b2293a5 --- /dev/null +++ b/rlp/encbuffer.go @@ -0,0 +1,423 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "encoding/binary" + "io" + "math/big" + "reflect" + "sync" + + "github.com/holiman/uint256" +) + +type encBuffer struct { + str []byte // string data, contains everything except list headers + lheads []listhead // all list headers + lhsize int // sum of sizes of all encoded list headers + sizebuf [9]byte // auxiliary buffer for uint encoding +} + +// The global encBuffer pool. +var encBufferPool = sync.Pool{ + New: func() interface{} { return new(encBuffer) }, +} + +func getEncBuffer() *encBuffer { + buf := encBufferPool.Get().(*encBuffer) + buf.reset() + return buf +} + +func (buf *encBuffer) reset() { + buf.lhsize = 0 + buf.str = buf.str[:0] + buf.lheads = buf.lheads[:0] +} + +// size returns the length of the encoded data. +func (buf *encBuffer) size() int { + return len(buf.str) + buf.lhsize +} + +// makeBytes creates the encoder output. +func (buf *encBuffer) makeBytes() []byte { + out := make([]byte, buf.size()) + buf.copyTo(out) + return out +} + +func (buf *encBuffer) copyTo(dst []byte) { + strpos := 0 + pos := 0 + for _, head := range buf.lheads { + // write string data before header + n := copy(dst[pos:], buf.str[strpos:head.offset]) + pos += n + strpos += n + // write the header + enc := head.encode(dst[pos:]) + pos += len(enc) + } + // copy string data after the last list header + copy(dst[pos:], buf.str[strpos:]) +} + +// writeTo writes the encoder output to w. +func (buf *encBuffer) writeTo(w io.Writer) (err error) { + strpos := 0 + for _, head := range buf.lheads { + // write string data before header + if head.offset-strpos > 0 { + n, err := w.Write(buf.str[strpos:head.offset]) + strpos += n + if err != nil { + return err + } + } + // write the header + enc := head.encode(buf.sizebuf[:]) + if _, err = w.Write(enc); err != nil { + return err + } + } + if strpos < len(buf.str) { + // write string data after the last list header + _, err = w.Write(buf.str[strpos:]) + } + return err +} + +// Write implements io.Writer and appends b directly to the output. +func (buf *encBuffer) Write(b []byte) (int, error) { + buf.str = append(buf.str, b...) + return len(b), nil +} + +// writeBool writes b as the integer 0 (false) or 1 (true). +func (buf *encBuffer) writeBool(b bool) { + if b { + buf.str = append(buf.str, 0x01) + } else { + buf.str = append(buf.str, 0x80) + } +} + +func (buf *encBuffer) writeUint64(i uint64) { + if i == 0 { + buf.str = append(buf.str, 0x80) + } else if i < 128 { + // fits single byte + buf.str = append(buf.str, byte(i)) + } else { + s := putint(buf.sizebuf[1:], i) + buf.sizebuf[0] = 0x80 + byte(s) + buf.str = append(buf.str, buf.sizebuf[:s+1]...) + } +} + +func (buf *encBuffer) writeBytes(b []byte) { + if len(b) == 1 && b[0] <= 0x7F { + // fits single byte, no string header + buf.str = append(buf.str, b[0]) + } else { + buf.encodeStringHeader(len(b)) + buf.str = append(buf.str, b...) + } +} + +func (buf *encBuffer) writeString(s string) { + buf.writeBytes([]byte(s)) +} + +// wordBytes is the number of bytes in a big.Word +const wordBytes = (32 << (uint64(^big.Word(0)) >> 63)) / 8 + +// writeBigInt writes i as an integer. +func (buf *encBuffer) writeBigInt(i *big.Int) { + bitlen := i.BitLen() + if bitlen <= 64 { + buf.writeUint64(i.Uint64()) + return + } + // Integer is larger than 64 bits, encode from i.Bits(). + // The minimal byte length is bitlen rounded up to the next + // multiple of 8, divided by 8. + length := ((bitlen + 7) & -8) >> 3 + buf.encodeStringHeader(length) + buf.str = append(buf.str, make([]byte, length)...) + index := length + bytesBuf := buf.str[len(buf.str)-length:] + for _, d := range i.Bits() { + for j := 0; j < wordBytes && index > 0; j++ { + index-- + bytesBuf[index] = byte(d) + d >>= 8 + } + } +} + +// writeUint256 writes z as an integer. +func (buf *encBuffer) writeUint256(z *uint256.Int) { + bitlen := z.BitLen() + if bitlen <= 64 { + buf.writeUint64(z.Uint64()) + return + } + nBytes := byte((bitlen + 7) / 8) + var b [33]byte + binary.BigEndian.PutUint64(b[1:9], z[3]) + binary.BigEndian.PutUint64(b[9:17], z[2]) + binary.BigEndian.PutUint64(b[17:25], z[1]) + binary.BigEndian.PutUint64(b[25:33], z[0]) + b[32-nBytes] = 0x80 + nBytes + buf.str = append(buf.str, b[32-nBytes:]...) +} + +// list adds a new list header to the header stack. It returns the index of the header. +// Call listEnd with this index after encoding the content of the list. +func (buf *encBuffer) list() int { + buf.lheads = append(buf.lheads, listhead{offset: len(buf.str), size: buf.lhsize}) + return len(buf.lheads) - 1 +} + +func (buf *encBuffer) listEnd(index int) { + lh := &buf.lheads[index] + lh.size = buf.size() - lh.offset - lh.size + if lh.size < 56 { + buf.lhsize++ // length encoded into kind tag + } else { + buf.lhsize += 1 + intsize(uint64(lh.size)) + } +} + +func (buf *encBuffer) encode(val interface{}) error { + rval := reflect.ValueOf(val) + writer, err := cachedWriter(rval.Type()) + if err != nil { + return err + } + return writer(rval, buf) +} + +func (buf *encBuffer) encodeStringHeader(size int) { + if size < 56 { + buf.str = append(buf.str, 0x80+byte(size)) + } else { + sizesize := putint(buf.sizebuf[1:], uint64(size)) + buf.sizebuf[0] = 0xB7 + byte(sizesize) + buf.str = append(buf.str, buf.sizebuf[:sizesize+1]...) + } +} + +// encReader is the io.Reader returned by EncodeToReader. +// It releases its encbuf at EOF. +type encReader struct { + buf *encBuffer // the buffer we're reading from. this is nil when we're at EOF. + lhpos int // index of list header that we're reading + strpos int // current position in string buffer + piece []byte // next piece to be read +} + +func (r *encReader) Read(b []byte) (n int, err error) { + for { + if r.piece = r.next(); r.piece == nil { + // Put the encode buffer back into the pool at EOF when it + // is first encountered. Subsequent calls still return EOF + // as the error but the buffer is no longer valid. + if r.buf != nil { + encBufferPool.Put(r.buf) + r.buf = nil + } + return n, io.EOF + } + nn := copy(b[n:], r.piece) + n += nn + if nn < len(r.piece) { + // piece didn't fit, see you next time. + r.piece = r.piece[nn:] + return n, nil + } + r.piece = nil + } +} + +// next returns the next piece of data to be read. +// it returns nil at EOF. +func (r *encReader) next() []byte { + switch { + case r.buf == nil: + return nil + + case r.piece != nil: + // There is still data available for reading. + return r.piece + + case r.lhpos < len(r.buf.lheads): + // We're before the last list header. + head := r.buf.lheads[r.lhpos] + sizebefore := head.offset - r.strpos + if sizebefore > 0 { + // String data before header. + p := r.buf.str[r.strpos:head.offset] + r.strpos += sizebefore + return p + } + r.lhpos++ + return head.encode(r.buf.sizebuf[:]) + + case r.strpos < len(r.buf.str): + // String data at the end, after all list headers. + p := r.buf.str[r.strpos:] + r.strpos = len(r.buf.str) + return p + + default: + return nil + } +} + +func encBufferFromWriter(w io.Writer) *encBuffer { + switch w := w.(type) { + case EncoderBuffer: + return w.buf + case *EncoderBuffer: + return w.buf + case *encBuffer: + return w + default: + return nil + } +} + +// EncoderBuffer is a buffer for incremental encoding. +// +// The zero value is NOT ready for use. To get a usable buffer, +// create it using NewEncoderBuffer or call Reset. +type EncoderBuffer struct { + buf *encBuffer + dst io.Writer + + ownBuffer bool +} + +// NewEncoderBuffer creates an encoder buffer. +func NewEncoderBuffer(dst io.Writer) EncoderBuffer { + var w EncoderBuffer + w.Reset(dst) + return w +} + +// Reset truncates the buffer and sets the output destination. +func (w *EncoderBuffer) Reset(dst io.Writer) { + if w.buf != nil && !w.ownBuffer { + panic("can't Reset derived EncoderBuffer") + } + + // If the destination writer has an *encBuffer, use it. + // Note that w.ownBuffer is left false here. + if dst != nil { + if outer := encBufferFromWriter(dst); outer != nil { + *w = EncoderBuffer{outer, nil, false} + return + } + } + + // Get a fresh buffer. + if w.buf == nil { + w.buf = encBufferPool.Get().(*encBuffer) + w.ownBuffer = true + } + w.buf.reset() + w.dst = dst +} + +// Flush writes encoded RLP data to the output writer. This can only be called once. +// If you want to re-use the buffer after Flush, you must call Reset. +func (w *EncoderBuffer) Flush() error { + var err error + if w.dst != nil { + err = w.buf.writeTo(w.dst) + } + // Release the internal buffer. + if w.ownBuffer { + encBufferPool.Put(w.buf) + } + *w = EncoderBuffer{} + return err +} + +// ToBytes returns the encoded bytes. +func (w *EncoderBuffer) ToBytes() []byte { + return w.buf.makeBytes() +} + +// AppendToBytes appends the encoded bytes to dst. +func (w *EncoderBuffer) AppendToBytes(dst []byte) []byte { + size := w.buf.size() + out := append(dst, make([]byte, size)...) + w.buf.copyTo(out[len(dst):]) + return out +} + +// Write appends b directly to the encoder output. +func (w EncoderBuffer) Write(b []byte) (int, error) { + return w.buf.Write(b) +} + +// WriteBool writes b as the integer 0 (false) or 1 (true). +func (w EncoderBuffer) WriteBool(b bool) { + w.buf.writeBool(b) +} + +// WriteUint64 encodes an unsigned integer. +func (w EncoderBuffer) WriteUint64(i uint64) { + w.buf.writeUint64(i) +} + +// WriteBigInt encodes a big.Int as an RLP string. +// Note: Unlike with Encode, the sign of i is ignored. +func (w EncoderBuffer) WriteBigInt(i *big.Int) { + w.buf.writeBigInt(i) +} + +// WriteUint256 encodes uint256.Int as an RLP string. +func (w EncoderBuffer) WriteUint256(i *uint256.Int) { + w.buf.writeUint256(i) +} + +// WriteBytes encodes b as an RLP string. +func (w EncoderBuffer) WriteBytes(b []byte) { + w.buf.writeBytes(b) +} + +// WriteString encodes s as an RLP string. +func (w EncoderBuffer) WriteString(s string) { + w.buf.writeString(s) +} + +// List starts a list. It returns an internal index. Call EndList with +// this index after encoding the content to finish the list. +func (w EncoderBuffer) List() int { + return w.buf.list() +} + +// ListEnd finishes the given list. +func (w EncoderBuffer) ListEnd(index int) { + w.buf.listEnd(index) +} diff --git a/rlp/encbuffer_example_test.go b/rlp/encbuffer_example_test.go new file mode 100644 index 000000000000..f737f3e40c72 --- /dev/null +++ b/rlp/encbuffer_example_test.go @@ -0,0 +1,45 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp_test + +import ( + "bytes" + "fmt" + + "github.com/XinFinOrg/XDPoSChain/rlp" +) + +func ExampleEncoderBuffer() { + var w bytes.Buffer + + // Encode [4, [5, 6]] to w. + buf := rlp.NewEncoderBuffer(&w) + l1 := buf.List() + buf.WriteUint64(4) + l2 := buf.List() + buf.WriteUint64(5) + buf.WriteUint64(6) + buf.ListEnd(l2) + buf.ListEnd(l1) + + if err := buf.Flush(); err != nil { + panic(err) + } + fmt.Printf("%X\n", w.Bytes()) + // Output: + // C404C20506 +} diff --git a/rlp/encode.go b/rlp/encode.go index 44592c2f53ed..9435cfc22c6b 100644 --- a/rlp/encode.go +++ b/rlp/encode.go @@ -17,20 +17,28 @@ package rlp import ( + "errors" "fmt" "io" "math/big" "reflect" - "sync" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" + "github.com/holiman/uint256" ) var ( // Common encoded values. // These are useful when implementing EncodeRLP. + + // EmptyString is the encoding of an empty string. EmptyString = []byte{0x80} - EmptyList = []byte{0xC0} + // EmptyList is the encoding of an empty list. + EmptyList = []byte{0xC0} ) +var ErrNegativeBigInt = errors.New("rlp: cannot encode negative big.Int") + // Encoder is implemented by types that require custom // encoding rules or want to encode private fields. type Encoder interface { @@ -49,80 +57,48 @@ type Encoder interface { // perform many small writes in some cases. Consider making w // buffered. // -// Encode uses the following type-dependent encoding rules: -// -// If the type implements the Encoder interface, Encode calls -// EncodeRLP. This is true even for nil pointers, please see the -// documentation for Encoder. -// -// To encode a pointer, the value being pointed to is encoded. For nil -// pointers, Encode will encode the zero value of the type. A nil -// pointer to a struct type always encodes as an empty RLP list. -// A nil pointer to an array encodes as an empty list (or empty string -// if the array has element type byte). -// -// Struct values are encoded as an RLP list of all their encoded -// public fields. Recursive struct types are supported. -// -// To encode slices and arrays, the elements are encoded as an RLP -// list of the value's elements. Note that arrays and slices with -// element type uint8 or byte are always encoded as an RLP string. -// -// A Go string is encoded as an RLP string. -// -// An unsigned integer value is encoded as an RLP string. Zero always -// encodes as an empty RLP string. Encode also supports *big.Int. -// -// An interface value encodes as the value contained in the interface. -// -// Boolean values are not supported, nor are signed integers, floating -// point numbers, maps, channels and functions. +// Please see package-level documentation of encoding rules. func Encode(w io.Writer, val interface{}) error { - if outer, ok := w.(*encbuf); ok { - // Encode was called by some type's EncodeRLP. - // Avoid copying by writing to the outer encbuf directly. - return outer.encode(val) + // Optimization: reuse *encBuffer when called by EncodeRLP. + if buf := encBufferFromWriter(w); buf != nil { + return buf.encode(val) } - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + + buf := getEncBuffer() + defer encBufferPool.Put(buf) + if err := buf.encode(val); err != nil { return err } - return eb.toWriter(w) + return buf.writeTo(w) } -// EncodeBytes returns the RLP encoding of val. -// Please see the documentation of Encode for the encoding rules. +// EncodeToBytes returns the RLP encoding of val. +// Please see package-level documentation for the encoding rules. func EncodeToBytes(val interface{}) ([]byte, error) { - eb := encbufPool.Get().(*encbuf) - defer encbufPool.Put(eb) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + defer encBufferPool.Put(buf) + + if err := buf.encode(val); err != nil { return nil, err } - return eb.toBytes(), nil + return buf.makeBytes(), nil } -// EncodeReader returns a reader from which the RLP encoding of val +// EncodeToReader returns a reader from which the RLP encoding of val // can be read. The returned size is the total size of the encoded // data. // // Please see the documentation of Encode for the encoding rules. func EncodeToReader(val interface{}) (size int, r io.Reader, err error) { - eb := encbufPool.Get().(*encbuf) - eb.reset() - if err := eb.encode(val); err != nil { + buf := getEncBuffer() + if err := buf.encode(val); err != nil { + encBufferPool.Put(buf) return 0, nil, err } - return eb.size(), &encReader{buf: eb}, nil -} - -type encbuf struct { - str []byte // string data, contains everything except list headers - lheads []*listhead // all list headers - lhsize int // sum of sizes of all encoded list headers - sizebuf []byte // 9-byte auxiliary buffer for uint encoding + // Note: can't put the reader back into the pool here + // because it is held by encReader. The reader puts it + // back when it has been fully consumed. + return buf.size(), &encReader{buf: buf}, nil } type listhead struct { @@ -151,214 +127,32 @@ func puthead(buf []byte, smalltag, largetag byte, size uint64) int { if size < 56 { buf[0] = smalltag + byte(size) return 1 - } else { - sizesize := putint(buf[1:], size) - buf[0] = largetag + byte(sizesize) - return sizesize + 1 - } -} - -// encbufs are pooled. -var encbufPool = sync.Pool{ - New: func() interface{} { return &encbuf{sizebuf: make([]byte, 9)} }, -} - -func (w *encbuf) reset() { - w.lhsize = 0 - if w.str != nil { - w.str = w.str[:0] - } - if w.lheads != nil { - w.lheads = w.lheads[:0] - } -} - -// encbuf implements io.Writer so it can be passed it into EncodeRLP. -func (w *encbuf) Write(b []byte) (int, error) { - w.str = append(w.str, b...) - return len(b), nil -} - -func (w *encbuf) encode(val interface{}) error { - rval := reflect.ValueOf(val) - ti, err := cachedTypeInfo(rval.Type(), tags{}) - if err != nil { - return err - } - return ti.writer(rval, w) -} - -func (w *encbuf) encodeStringHeader(size int) { - if size < 56 { - w.str = append(w.str, 0x80+byte(size)) - } else { - // TODO: encode to w.str directly - sizesize := putint(w.sizebuf[1:], uint64(size)) - w.sizebuf[0] = 0xB7 + byte(sizesize) - w.str = append(w.str, w.sizebuf[:sizesize+1]...) - } -} - -func (w *encbuf) encodeString(b []byte) { - if len(b) == 1 && b[0] <= 0x7F { - // fits single byte, no string header - w.str = append(w.str, b[0]) - } else { - w.encodeStringHeader(len(b)) - w.str = append(w.str, b...) - } -} - -func (w *encbuf) list() *listhead { - lh := &listhead{offset: len(w.str), size: w.lhsize} - w.lheads = append(w.lheads, lh) - return lh -} - -func (w *encbuf) listEnd(lh *listhead) { - lh.size = w.size() - lh.offset - lh.size - if lh.size < 56 { - w.lhsize += 1 // length encoded into kind tag - } else { - w.lhsize += 1 + intsize(uint64(lh.size)) - } -} - -func (w *encbuf) size() int { - return len(w.str) + w.lhsize -} - -func (w *encbuf) toBytes() []byte { - out := make([]byte, w.size()) - strpos := 0 - pos := 0 - for _, head := range w.lheads { - // write string data before header - n := copy(out[pos:], w.str[strpos:head.offset]) - pos += n - strpos += n - // write the header - enc := head.encode(out[pos:]) - pos += len(enc) - } - // copy string data after the last list header - copy(out[pos:], w.str[strpos:]) - return out -} - -func (w *encbuf) toWriter(out io.Writer) (err error) { - strpos := 0 - for _, head := range w.lheads { - // write string data before header - if head.offset-strpos > 0 { - n, err := out.Write(w.str[strpos:head.offset]) - strpos += n - if err != nil { - return err - } - } - // write the header - enc := head.encode(w.sizebuf) - if _, err = out.Write(enc); err != nil { - return err - } - } - if strpos < len(w.str) { - // write string data after the last list header - _, err = out.Write(w.str[strpos:]) - } - return err -} - -// encReader is the io.Reader returned by EncodeToReader. -// It releases its encbuf at EOF. -type encReader struct { - buf *encbuf // the buffer we're reading from. this is nil when we're at EOF. - lhpos int // index of list header that we're reading - strpos int // current position in string buffer - piece []byte // next piece to be read -} - -func (r *encReader) Read(b []byte) (n int, err error) { - for { - if r.piece = r.next(); r.piece == nil { - // Put the encode buffer back into the pool at EOF when it - // is first encountered. Subsequent calls still return EOF - // as the error but the buffer is no longer valid. - if r.buf != nil { - encbufPool.Put(r.buf) - r.buf = nil - } - return n, io.EOF - } - nn := copy(b[n:], r.piece) - n += nn - if nn < len(r.piece) { - // piece didn't fit, see you next time. - r.piece = r.piece[nn:] - return n, nil - } - r.piece = nil - } -} - -// next returns the next piece of data to be read. -// it returns nil at EOF. -func (r *encReader) next() []byte { - switch { - case r.buf == nil: - return nil - - case r.piece != nil: - // There is still data available for reading. - return r.piece - - case r.lhpos < len(r.buf.lheads): - // We're before the last list header. - head := r.buf.lheads[r.lhpos] - sizebefore := head.offset - r.strpos - if sizebefore > 0 { - // String data before header. - p := r.buf.str[r.strpos:head.offset] - r.strpos += sizebefore - return p - } else { - r.lhpos++ - return head.encode(r.buf.sizebuf) - } - - case r.strpos < len(r.buf.str): - // String data at the end, after all list headers. - p := r.buf.str[r.strpos:] - r.strpos = len(r.buf.str) - return p - - default: - return nil } + sizesize := putint(buf[1:], size) + buf[0] = largetag + byte(sizesize) + return sizesize + 1 } -var ( - encoderInterface = reflect.TypeOf(new(Encoder)).Elem() - big0 = big.NewInt(0) -) +var encoderInterface = reflect.TypeOf(new(Encoder)).Elem() // makeWriter creates a writer function for the given type. -func makeWriter(typ reflect.Type, ts tags) (writer, error) { +func makeWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { kind := typ.Kind() switch { case typ == rawValueType: return writeRawValue, nil - case typ.Implements(encoderInterface): - return writeEncoder, nil - case kind != reflect.Ptr && reflect.PtrTo(typ).Implements(encoderInterface): - return writeEncoderNoPtr, nil - case kind == reflect.Interface: - return writeInterface, nil - case typ.AssignableTo(reflect.PtrTo(bigInt)): + case typ.AssignableTo(reflect.PointerTo(bigInt)): return writeBigIntPtr, nil case typ.AssignableTo(bigInt): return writeBigIntNoPtr, nil + case typ == reflect.PointerTo(u256Int): + return writeU256IntPtr, nil + case typ == u256Int: + return writeU256IntNoPtr, nil + case kind == reflect.Ptr: + return makePtrWriter(typ, ts) + case reflect.PointerTo(typ).Implements(encoderInterface): + return makeEncoderWriter(typ), nil case isUint(kind): return writeUint, nil case kind == reflect.Bool: @@ -368,97 +162,116 @@ func makeWriter(typ reflect.Type, ts tags) (writer, error) { case kind == reflect.Slice && isByte(typ.Elem()): return writeBytes, nil case kind == reflect.Array && isByte(typ.Elem()): - return writeByteArray, nil + return makeByteArrayWriter(typ), nil case kind == reflect.Slice || kind == reflect.Array: return makeSliceWriter(typ, ts) case kind == reflect.Struct: return makeStructWriter(typ) - case kind == reflect.Ptr: - return makePtrWriter(typ) + case kind == reflect.Interface: + return writeInterface, nil default: return nil, fmt.Errorf("rlp: type %v is not RLP-serializable", typ) } } -func isByte(typ reflect.Type) bool { - return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) -} - -func writeRawValue(val reflect.Value, w *encbuf) error { +func writeRawValue(val reflect.Value, w *encBuffer) error { w.str = append(w.str, val.Bytes()...) return nil } -func writeUint(val reflect.Value, w *encbuf) error { - i := val.Uint() - if i == 0 { - w.str = append(w.str, 0x80) - } else if i < 128 { - // fits single byte - w.str = append(w.str, byte(i)) - } else { - // TODO: encode int to w.str directly - s := putint(w.sizebuf[1:], i) - w.sizebuf[0] = 0x80 + byte(s) - w.str = append(w.str, w.sizebuf[:s+1]...) - } +func writeUint(val reflect.Value, w *encBuffer) error { + w.writeUint64(val.Uint()) return nil } -func writeBool(val reflect.Value, w *encbuf) error { - if val.Bool() { - w.str = append(w.str, 0x01) - } else { - w.str = append(w.str, 0x80) - } +func writeBool(val reflect.Value, w *encBuffer) error { + w.writeBool(val.Bool()) return nil } -func writeBigIntPtr(val reflect.Value, w *encbuf) error { +func writeBigIntPtr(val reflect.Value, w *encBuffer) error { ptr := val.Interface().(*big.Int) if ptr == nil { w.str = append(w.str, 0x80) return nil } - return writeBigInt(ptr, w) + if ptr.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(ptr) + return nil } -func writeBigIntNoPtr(val reflect.Value, w *encbuf) error { +func writeBigIntNoPtr(val reflect.Value, w *encBuffer) error { i := val.Interface().(big.Int) - return writeBigInt(&i, w) + if i.Sign() == -1 { + return ErrNegativeBigInt + } + w.writeBigInt(&i) + return nil } -func writeBigInt(i *big.Int, w *encbuf) error { - if cmp := i.Cmp(big0); cmp == -1 { - return fmt.Errorf("rlp: cannot encode negative *big.Int") - } else if cmp == 0 { +func writeU256IntPtr(val reflect.Value, w *encBuffer) error { + ptr := val.Interface().(*uint256.Int) + if ptr == nil { w.str = append(w.str, 0x80) - } else { - w.encodeString(i.Bytes()) + return nil } + w.writeUint256(ptr) + return nil +} + +func writeU256IntNoPtr(val reflect.Value, w *encBuffer) error { + i := val.Interface().(uint256.Int) + w.writeUint256(&i) return nil } -func writeBytes(val reflect.Value, w *encbuf) error { - w.encodeString(val.Bytes()) +func writeBytes(val reflect.Value, w *encBuffer) error { + w.writeBytes(val.Bytes()) return nil } -func writeByteArray(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // Slice requires the value to be addressable. - // Make it addressable by copying. - copy := reflect.New(val.Type()).Elem() - copy.Set(val) - val = copy +func makeByteArrayWriter(typ reflect.Type) writer { + switch typ.Len() { + case 0: + return writeLengthZeroByteArray + case 1: + return writeLengthOneByteArray + default: + length := typ.Len() + return func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // Getting the byte slice of val requires it to be addressable. Make it + // addressable by copying. + copy := reflect.New(val.Type()).Elem() + copy.Set(val) + val = copy + } + slice := byteArrayBytes(val, length) + w.encodeStringHeader(len(slice)) + w.str = append(w.str, slice...) + return nil + } } - size := val.Len() - slice := val.Slice(0, size).Bytes() - w.encodeString(slice) +} + +func writeLengthZeroByteArray(val reflect.Value, w *encBuffer) error { + w.str = append(w.str, 0x80) return nil } -func writeString(val reflect.Value, w *encbuf) error { +func writeLengthOneByteArray(val reflect.Value, w *encBuffer) error { + b := byte(val.Index(0).Uint()) + if b <= 0x7f { + w.str = append(w.str, b) + } else { + w.str = append(w.str, 0x81, b) + } + return nil +} + +func writeString(val reflect.Value, w *encBuffer) error { s := val.String() if len(s) == 1 && s[0] <= 0x7f { // fits single byte, no string header @@ -470,27 +283,7 @@ func writeString(val reflect.Value, w *encbuf) error { return nil } -func writeEncoder(val reflect.Value, w *encbuf) error { - return val.Interface().(Encoder).EncodeRLP(w) -} - -// writeEncoderNoPtr handles non-pointer values that implement Encoder -// with a pointer receiver. -func writeEncoderNoPtr(val reflect.Value, w *encbuf) error { - if !val.CanAddr() { - // We can't get the address. It would be possible to make the - // value addressable by creating a shallow copy, but this - // creates other problems so we're not doing it (yet). - // - // package json simply doesn't call MarshalJSON for cases like - // this, but encodes the value as if it didn't implement the - // interface. We don't want to handle it that way. - return fmt.Errorf("rlp: game over: unadressable value of type %v, EncodeRLP is pointer method", val.Type()) - } - return val.Addr().Interface().(Encoder).EncodeRLP(w) -} - -func writeInterface(val reflect.Value, w *encbuf) error { +func writeInterface(val reflect.Value, w *encBuffer) error { if val.IsNil() { // Write empty list. This is consistent with the previous RLP // encoder that we had and should therefore avoid any @@ -499,31 +292,51 @@ func writeInterface(val reflect.Value, w *encbuf) error { return nil } eval := val.Elem() - ti, err := cachedTypeInfo(eval.Type(), tags{}) + writer, err := cachedWriter(eval.Type()) if err != nil { return err } - return ti.writer(eval, w) + return writer(eval, w) } -func makeSliceWriter(typ reflect.Type, ts tags) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makeSliceWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr } - writer := func(val reflect.Value, w *encbuf) error { - if !ts.tail { - defer w.listEnd(w.list()) + + var wfn writer + if ts.Tail { + // This is for struct tail slices. + // w.list is not called for them. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } + } + return nil } - vlen := val.Len() - for i := 0; i < vlen; i++ { - if err := etypeinfo.writer(val.Index(i), w); err != nil { - return err + } else { + // This is for regular slices and arrays. + wfn = func(val reflect.Value, w *encBuffer) error { + vlen := val.Len() + if vlen == 0 { + w.str = append(w.str, 0xC0) + return nil + } + listOffset := w.list() + for i := 0; i < vlen; i++ { + if err := etypeinfo.writer(val.Index(i), w); err != nil { + return err + } } + w.listEnd(listOffset) + return nil } - return nil } - return writer, nil + return wfn, nil } func makeStructWriter(typ reflect.Type) (writer, error) { @@ -531,56 +344,86 @@ func makeStructWriter(typ reflect.Type) (writer, error) { if err != nil { return nil, err } - writer := func(val reflect.Value, w *encbuf) error { - lh := w.list() - for _, f := range fields { - if err := f.info.writer(val.Field(f.index), w); err != nil { - return err + for _, f := range fields { + if f.info.writerErr != nil { + return nil, structFieldError{typ, f.index, f.info.writerErr} + } + } + + var writer writer + firstOptionalField := firstOptionalField(fields) + if firstOptionalField == len(fields) { + // This is the writer function for structs without any optional fields. + writer = func(val reflect.Value, w *encBuffer) error { + lh := w.list() + for _, f := range fields { + if err := f.info.writer(val.Field(f.index), w); err != nil { + return err + } } + w.listEnd(lh) + return nil + } + } else { + // If there are any "optional" fields, the writer needs to perform additional + // checks to determine the output list length. + writer = func(val reflect.Value, w *encBuffer) error { + lastField := len(fields) - 1 + for ; lastField >= firstOptionalField; lastField-- { + if !val.Field(fields[lastField].index).IsZero() { + break + } + } + lh := w.list() + for i := 0; i <= lastField; i++ { + if err := fields[i].info.writer(val.Field(fields[i].index), w); err != nil { + return err + } + } + w.listEnd(lh) + return nil } - w.listEnd(lh) - return nil } return writer, nil } -func makePtrWriter(typ reflect.Type) (writer, error) { - etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{}) - if err != nil { - return nil, err +func makePtrWriter(typ reflect.Type, ts rlpstruct.Tags) (writer, error) { + nilEncoding := byte(0xC0) + if typeNilKind(typ.Elem(), ts) == String { + nilEncoding = 0x80 } - // determine nil pointer handler - var nilfunc func(*encbuf) error - kind := typ.Elem().Kind() - switch { - case kind == reflect.Array && isByte(typ.Elem().Elem()): - nilfunc = func(w *encbuf) error { - w.str = append(w.str, 0x80) - return nil - } - case kind == reflect.Struct || kind == reflect.Array: - nilfunc = func(w *encbuf) error { - // encoding the zero value of a struct/array could trigger - // infinite recursion, avoid that. - w.listEnd(w.list()) - return nil - } - default: - zero := reflect.Zero(typ.Elem()) - nilfunc = func(w *encbuf) error { - return etypeinfo.writer(zero, w) + etypeinfo := theTC.infoWhileGenerating(typ.Elem(), rlpstruct.Tags{}) + if etypeinfo.writerErr != nil { + return nil, etypeinfo.writerErr + } + + writer := func(val reflect.Value, w *encBuffer) error { + if ev := val.Elem(); ev.IsValid() { + return etypeinfo.writer(ev, w) } + w.str = append(w.str, nilEncoding) + return nil } + return writer, nil +} - writer := func(val reflect.Value, w *encbuf) error { - if val.IsNil() { - return nilfunc(w) - } else { - return etypeinfo.writer(val.Elem(), w) +func makeEncoderWriter(typ reflect.Type) writer { + if typ.Implements(encoderInterface) { + return func(val reflect.Value, w *encBuffer) error { + return val.Interface().(Encoder).EncodeRLP(w) + } + } + w := func(val reflect.Value, w *encBuffer) error { + if !val.CanAddr() { + // package json simply doesn't call MarshalJSON for this case, but encodes the + // value as if it didn't implement the interface. We don't want to handle it that + // way. + return fmt.Errorf("rlp: unaddressable value of type %v, EncodeRLP is pointer method", val.Type()) } + return val.Addr().Interface().(Encoder).EncodeRLP(w) } - return writer, err + return w } // putint writes i to the beginning of b in big endian byte diff --git a/rlp/encode_test.go b/rlp/encode_test.go index 12e8f275551b..5fc2d116efda 100644 --- a/rlp/encode_test.go +++ b/rlp/encode_test.go @@ -22,8 +22,12 @@ import ( "fmt" "io" "math/big" + "runtime" "sync" "testing" + + "github.com/XinFinOrg/XDPoSChain/common/math" + "github.com/holiman/uint256" ) type testEncoder struct { @@ -32,12 +36,19 @@ type testEncoder struct { func (e *testEncoder) EncodeRLP(w io.Writer) error { if e == nil { - w.Write([]byte{0, 0, 0, 0}) - } else if e.err != nil { + panic("EncodeRLP called on nil value") + } + if e.err != nil { return e.err - } else { - w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) } + w.Write([]byte{0, 1, 0, 1, 0, 1, 0, 1, 0, 1}) + return nil +} + +type testEncoderValueMethod struct{} + +func (e testEncoderValueMethod) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xFA, 0xFE, 0xF0}) return nil } @@ -48,6 +59,13 @@ func (e byteEncoder) EncodeRLP(w io.Writer) error { return nil } +type undecodableEncoder func() + +func (f undecodableEncoder) EncodeRLP(w io.Writer) error { + w.Write([]byte{0xF5, 0xF5, 0xF5}) + return nil +} + type encodableReader struct { A, B uint } @@ -102,35 +120,95 @@ var encTests = []encTest{ {val: big.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, {val: big.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, { - val: big.NewInt(0).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + val: new(big.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), output: "8F102030405060708090A0B0C0D0E0F2", }, { - val: big.NewInt(0).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + val: new(big.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), output: "9C0100020003000400050006000700080009000A000B000C000D000E01", }, { - val: big.NewInt(0).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), + val: new(big.Int).SetBytes(unhex("010000000000000000000000000000000000000000000000000000000000000000")), output: "A1010000000000000000000000000000000000000000000000000000000000000000", }, + { + val: veryBigInt, + output: "89FFFFFFFFFFFFFFFFFF", + }, + { + val: veryVeryBigInt, + output: "B848FFFFFFFFFFFFFFFFF800000000000000001BFFFFFFFFFFFFFFFFC8000000000000000045FFFFFFFFFFFFFFFFC800000000000000001BFFFFFFFFFFFFFFFFF8000000000000000001", + }, // non-pointer big.Int {val: *big.NewInt(0), output: "80"}, {val: *big.NewInt(0xFFFFFF), output: "83FFFFFF"}, // negative ints are not supported - {val: big.NewInt(-1), error: "rlp: cannot encode negative *big.Int"}, - - // byte slices, strings + {val: big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + {val: *big.NewInt(-1), error: "rlp: cannot encode negative big.Int"}, + + // uint256 + {val: uint256.NewInt(0), output: "80"}, + {val: uint256.NewInt(1), output: "01"}, + {val: uint256.NewInt(127), output: "7F"}, + {val: uint256.NewInt(128), output: "8180"}, + {val: uint256.NewInt(256), output: "820100"}, + {val: uint256.NewInt(1024), output: "820400"}, + {val: uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFF), output: "84FFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFF), output: "85FFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFF), output: "86FFFFFFFFFFFF"}, + {val: uint256.NewInt(0xFFFFFFFFFFFFFF), output: "87FFFFFFFFFFFFFF"}, + { + val: new(uint256.Int).SetBytes(unhex("102030405060708090A0B0C0D0E0F2")), + output: "8F102030405060708090A0B0C0D0E0F2", + }, + { + val: new(uint256.Int).SetBytes(unhex("0100020003000400050006000700080009000A000B000C000D000E01")), + output: "9C0100020003000400050006000700080009000A000B000C000D000E01", + }, + // non-pointer uint256.Int + {val: *uint256.NewInt(0), output: "80"}, + {val: *uint256.NewInt(0xFFFFFF), output: "83FFFFFF"}, + + // byte arrays + {val: [0]byte{}, output: "80"}, + {val: [1]byte{0}, output: "00"}, + {val: [1]byte{1}, output: "01"}, + {val: [1]byte{0x7F}, output: "7F"}, + {val: [1]byte{0x80}, output: "8180"}, + {val: [1]byte{0xFF}, output: "81FF"}, + {val: [3]byte{1, 2, 3}, output: "83010203"}, + {val: [57]byte{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // named byte type arrays + {val: [0]namedByteType{}, output: "80"}, + {val: [1]namedByteType{0}, output: "00"}, + {val: [1]namedByteType{1}, output: "01"}, + {val: [1]namedByteType{0x7F}, output: "7F"}, + {val: [1]namedByteType{0x80}, output: "8180"}, + {val: [1]namedByteType{0xFF}, output: "81FF"}, + {val: [3]namedByteType{1, 2, 3}, output: "83010203"}, + {val: [57]namedByteType{1, 2, 3}, output: "B839010203000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + + // byte slices {val: []byte{}, output: "80"}, + {val: []byte{0}, output: "00"}, {val: []byte{0x7E}, output: "7E"}, {val: []byte{0x7F}, output: "7F"}, {val: []byte{0x80}, output: "8180"}, {val: []byte{1, 2, 3}, output: "83010203"}, + // named byte type slices + {val: []namedByteType{}, output: "80"}, + {val: []namedByteType{0}, output: "00"}, + {val: []namedByteType{0x7E}, output: "7E"}, + {val: []namedByteType{0x7F}, output: "7F"}, + {val: []namedByteType{0x80}, output: "8180"}, {val: []namedByteType{1, 2, 3}, output: "83010203"}, - {val: [...]namedByteType{1, 2, 3}, output: "83010203"}, + // strings {val: "", output: "80"}, {val: "\x7E", output: "7E"}, {val: "\x7F", output: "7F"}, @@ -203,6 +281,12 @@ var encTests = []encTest{ output: "F90200CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376CF84617364668471776572847A786376", }, + // Non-byte arrays are encoded as lists. + // Note that it is important to test [4]uint64 specifically, + // because that's the underlying type of uint256.Int. + {val: [4]uint32{1, 2, 3, 4}, output: "C401020304"}, + {val: [4]uint64{1, 2, 3, 4}, output: "C401020304"}, + // RawValue {val: RawValue(unhex("01")), output: "01"}, {val: RawValue(unhex("82FFFF")), output: "82FFFF"}, @@ -213,11 +297,34 @@ var encTests = []encTest{ {val: simplestruct{A: 3, B: "foo"}, output: "C50383666F6F"}, {val: &recstruct{5, nil}, output: "C205C0"}, {val: &recstruct{5, &recstruct{4, &recstruct{3, nil}}}, output: "C605C404C203C0"}, + {val: &intField{X: 3}, error: "rlp: type int is not RLP-serializable (struct field rlp.intField.X)"}, + + // struct tag "-" + {val: &ignoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "tail" {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02"), unhex("03")}}, output: "C3010203"}, {val: &tailRaw{A: 1, Tail: []RawValue{unhex("02")}}, output: "C20102"}, {val: &tailRaw{A: 1, Tail: []RawValue{}}, output: "C101"}, {val: &tailRaw{A: 1, Tail: nil}, output: "C101"}, - {val: &hasIgnoredField{A: 1, B: 2, C: 3}, output: "C20103"}, + + // struct tag "optional" + {val: &optionalFields{}, output: "C180"}, + {val: &optionalFields{A: 1}, output: "C101"}, + {val: &optionalFields{A: 1, B: 2}, output: "C20102"}, + {val: &optionalFields{A: 1, B: 2, C: 3}, output: "C3010203"}, + {val: &optionalFields{A: 1, B: 0, C: 3}, output: "C3018003"}, + {val: &optionalAndTailField{A: 1}, output: "C101"}, + {val: &optionalAndTailField{A: 1, B: 2}, output: "C20102"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalAndTailField{A: 1, Tail: []uint{5, 6}}, output: "C401800506"}, + {val: &optionalBigIntField{A: 1}, output: "C101"}, + {val: &optionalPtrField{A: 1}, output: "C101"}, + {val: &optionalPtrFieldNil{A: 1}, output: "C101"}, + {val: &multipleOptionalFields{A: nil, B: nil}, output: "C0"}, + {val: &multipleOptionalFields{A: &[3]byte{1, 2, 3}, B: &[3]byte{1, 2, 3}}, output: "C88301020383010203"}, + {val: &multipleOptionalFields{A: nil, B: &[3]byte{1, 2, 3}}, output: "C58083010203"}, // encodes without error but decode will fail + {val: &nonOptionalPtrField{A: 1}, output: "C20180"}, // encodes without error but decode will fail // nil {val: (*uint)(nil), output: "80"}, @@ -225,26 +332,73 @@ var encTests = []encTest{ {val: (*[]byte)(nil), output: "80"}, {val: (*[10]byte)(nil), output: "80"}, {val: (*big.Int)(nil), output: "80"}, + {val: (*uint256.Int)(nil), output: "80"}, {val: (*[]string)(nil), output: "C0"}, {val: (*[10]string)(nil), output: "C0"}, {val: (*[]interface{})(nil), output: "C0"}, {val: (*[]struct{ uint })(nil), output: "C0"}, {val: (*interface{})(nil), output: "C0"}, + // nil struct fields + { + val: struct { + X *[]byte + }{}, + output: "C180", + }, + { + val: struct { + X *[2]byte + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 + }{}, + output: "C180", + }, + { + val: struct { + X *uint64 `rlp:"nilList"` + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 + }{}, + output: "C1C0", + }, + { + val: struct { + X *[]uint64 `rlp:"nilString"` + }{}, + output: "C180", + }, + // interfaces {val: []io.Reader{reader}, output: "C3C20102"}, // the contained value is a struct // Encoder - {val: (*testEncoder)(nil), output: "00000000"}, + {val: (*testEncoder)(nil), output: "C0"}, {val: &testEncoder{}, output: "00010001000100010001"}, {val: &testEncoder{errors.New("test error")}, error: "test error"}, - // verify that pointer method testEncoder.EncodeRLP is called for + {val: struct{ E testEncoderValueMethod }{}, output: "C3FAFEF0"}, + {val: struct{ E *testEncoderValueMethod }{}, output: "C1C0"}, + + // Verify that the Encoder interface works for unsupported types like func(). + {val: undecodableEncoder(func() {}), output: "F5F5F5"}, + + // Verify that pointer method testEncoder.EncodeRLP is called for // addressable non-pointer values. {val: &struct{ TE testEncoder }{testEncoder{}}, output: "CA00010001000100010001"}, {val: &struct{ TE testEncoder }{testEncoder{errors.New("test error")}}, error: "test error"}, - // verify the error for non-addressable non-pointer Encoder - {val: testEncoder{}, error: "rlp: game over: unadressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, - // verify the special case for []byte + + // Verify the error for non-addressable non-pointer Encoder. + {val: testEncoder{}, error: "rlp: unaddressable value of type rlp.testEncoder, EncodeRLP is pointer method"}, + + // Verify Encoder takes precedence over []byte. {val: []byteEncoder{0, 1, 2, 3, 4}, output: "C5C0C0C0C0C0"}, } @@ -280,6 +434,21 @@ func TestEncodeToBytes(t *testing.T) { runEncTests(t, EncodeToBytes) } +func TestEncodeAppendToBytes(t *testing.T) { + buffer := make([]byte, 20) + runEncTests(t, func(val interface{}) ([]byte, error) { + w := NewEncoderBuffer(nil) + defer w.Flush() + + err := Encode(w, val) + if err != nil { + return nil, err + } + output := w.AppendToBytes(buffer[:0]) + return output, nil + }) +} + func TestEncodeToReader(t *testing.T) { runEncTests(t, func(val interface{}) ([]byte, error) { _, r, err := EncodeToReader(val) @@ -338,3 +507,132 @@ func TestEncodeToReaderReturnToPool(t *testing.T) { } wg.Wait() } + +var sink interface{} + +func BenchmarkIntsize(b *testing.B) { + for i := 0; i < b.N; i++ { + sink = intsize(0x12345678) + } +} + +func BenchmarkPutint(b *testing.B) { + buf := make([]byte, 8) + for i := 0; i < b.N; i++ { + putint(buf, 0x12345678) + sink = buf + } +} + +func BenchmarkEncodeBigInts(b *testing.B) { + ints := make([]*big.Int, 200) + for i := range ints { + ints[i] = math.BigPow(2, int64(i)) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeU256Ints(b *testing.B) { + ints := make([]*uint256.Int, 200) + for i := range ints { + ints[i], _ = uint256.FromBig(math.BigPow(2, int64(i))) + } + out := bytes.NewBuffer(make([]byte, 0, 4096)) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(out, ints); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkEncodeConcurrentInterface(b *testing.B) { + type struct1 struct { + A string + B *big.Int + C [20]byte + } + value := []interface{}{ + uint(999), + &struct1{A: "hello", B: big.NewInt(0xFFFFFFFF)}, + [10]byte{1, 2, 3, 4, 5, 6}, + []string{"yeah", "yeah", "yeah"}, + } + + var wg sync.WaitGroup + for cpu := 0; cpu < runtime.NumCPU(); cpu++ { + wg.Add(1) + go func() { + defer wg.Done() + + var buffer bytes.Buffer + for i := 0; i < b.N; i++ { + buffer.Reset() + err := Encode(&buffer, value) + if err != nil { + panic(err) + } + } + }() + } + wg.Wait() +} + +type byteArrayStruct struct { + A [20]byte + B [32]byte + C [32]byte +} + +func BenchmarkEncodeByteArrayStruct(b *testing.B) { + var out bytes.Buffer + var value byteArrayStruct + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} + +type structSliceElem struct { + X uint64 + Y uint64 + Z uint64 +} + +type structPtrSlice []*structSliceElem + +func BenchmarkEncodeStructPtrSlice(b *testing.B) { + var out bytes.Buffer + var value = structPtrSlice{ + &structSliceElem{1, 1, 1}, + &structSliceElem{2, 2, 2}, + &structSliceElem{3, 3, 3}, + &structSliceElem{5, 5, 5}, + &structSliceElem{6, 6, 6}, + &structSliceElem{7, 7, 7}, + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + out.Reset() + if err := Encode(&out, &value); err != nil { + b.Fatal(err) + } + } +} diff --git a/rlp/encoder_example_test.go b/rlp/encoder_example_test.go index 1cffa241c259..da0465a53312 100644 --- a/rlp/encoder_example_test.go +++ b/rlp/encoder_example_test.go @@ -14,11 +14,13 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -package rlp +package rlp_test import ( "fmt" "io" + + "github.com/XinFinOrg/XDPoSChain/rlp" ) type MyCoolType struct { @@ -28,27 +30,19 @@ type MyCoolType struct { // EncodeRLP writes x as RLP list [a, b] that omits the Name field. func (x *MyCoolType) EncodeRLP(w io.Writer) (err error) { - // Note: the receiver can be a nil pointer. This allows you to - // control the encoding of nil, but it also means that you have to - // check for a nil receiver. - if x == nil { - err = Encode(w, []uint{0, 0}) - } else { - err = Encode(w, []uint{x.a, x.b}) - } - return err + return rlp.Encode(w, []uint{x.a, x.b}) } func ExampleEncoder() { var t *MyCoolType // t is nil pointer to MyCoolType - bytes, _ := EncodeToBytes(t) + bytes, _ := rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) t = &MyCoolType{Name: "foobar", a: 5, b: 6} - bytes, _ = EncodeToBytes(t) + bytes, _ = rlp.EncodeToBytes(t) fmt.Printf("%v → %X\n", t, bytes) // Output: - // → C28080 + // → C0 // &{foobar 5 6} → C20506 } diff --git a/rlp/internal/rlpstruct/rlpstruct.go b/rlp/internal/rlpstruct/rlpstruct.go new file mode 100644 index 000000000000..2e3eeb688193 --- /dev/null +++ b/rlp/internal/rlpstruct/rlpstruct.go @@ -0,0 +1,213 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package rlpstruct implements struct processing for RLP encoding/decoding. +// +// In particular, this package handles all rules around field filtering, +// struct tags and nil value determination. +package rlpstruct + +import ( + "fmt" + "reflect" + "strings" +) + +// Field represents a struct field. +type Field struct { + Name string + Index int + Exported bool + Type Type + Tag string +} + +// Type represents the attributes of a Go type. +type Type struct { + Name string + Kind reflect.Kind + IsEncoder bool // whether type implements rlp.Encoder + IsDecoder bool // whether type implements rlp.Decoder + Elem *Type // non-nil for Kind values of Ptr, Slice, Array +} + +// DefaultNilValue determines whether a nil pointer to t encodes/decodes +// as an empty string or empty list. +func (t Type) DefaultNilValue() NilKind { + k := t.Kind + if isUint(k) || k == reflect.String || k == reflect.Bool || isByteArray(t) { + return NilKindString + } + return NilKindList +} + +// NilKind is the RLP value encoded in place of nil pointers. +type NilKind uint8 + +const ( + NilKindString NilKind = 0x80 + NilKindList NilKind = 0xC0 +) + +// Tags represents struct tags. +type Tags struct { + // rlp:"nil" controls whether empty input results in a nil pointer. + // nilKind is the kind of empty value allowed for the field. + NilKind NilKind + NilOK bool + + // rlp:"optional" allows for a field to be missing in the input list. + // If this is set, all subsequent fields must also be optional. + Optional bool + + // rlp:"tail" controls whether this field swallows additional list elements. It can + // only be set for the last field, which must be of slice type. + Tail bool + + // rlp:"-" ignores fields. + Ignored bool +} + +// TagError is raised for invalid struct tags. +type TagError struct { + StructType string + + // These are set by this package. + Field string + Tag string + Err string +} + +func (e TagError) Error() string { + field := "field " + e.Field + if e.StructType != "" { + field = e.StructType + "." + e.Field + } + return fmt.Sprintf("rlp: invalid struct tag %q for %s (%s)", e.Tag, field, e.Err) +} + +// ProcessFields filters the given struct fields, returning only fields +// that should be considered for encoding/decoding. +func ProcessFields(allFields []Field) ([]Field, []Tags, error) { + lastPublic := lastPublicField(allFields) + + // Gather all exported fields and their tags. + var fields []Field + var tags []Tags + for _, field := range allFields { + if !field.Exported { + continue + } + ts, err := parseTag(field, lastPublic) + if err != nil { + return nil, nil, err + } + if ts.Ignored { + continue + } + fields = append(fields, field) + tags = append(tags, ts) + } + + // Verify optional field consistency. If any optional field exists, + // all fields after it must also be optional. Note: optional + tail + // is supported. + var anyOptional bool + var firstOptionalName string + for i, ts := range tags { + name := fields[i].Name + if ts.Optional || ts.Tail { + if !anyOptional { + firstOptionalName = name + } + anyOptional = true + } else { + if anyOptional { + msg := fmt.Sprintf("must be optional because preceding field %q is optional", firstOptionalName) + return nil, nil, TagError{Field: name, Err: msg} + } + } + } + return fields, tags, nil +} + +func parseTag(field Field, lastPublic int) (Tags, error) { + name := field.Name + tag := reflect.StructTag(field.Tag) + var ts Tags + for _, t := range strings.Split(tag.Get("rlp"), ",") { + switch t = strings.TrimSpace(t); t { + case "": + // empty tag is allowed for some reason + case "-": + ts.Ignored = true + case "nil", "nilString", "nilList": + ts.NilOK = true + if field.Type.Kind != reflect.Ptr { + return ts, TagError{Field: name, Tag: t, Err: "field is not a pointer"} + } + switch t { + case "nil": + ts.NilKind = field.Type.Elem.DefaultNilValue() + case "nilString": + ts.NilKind = NilKindString + case "nilList": + ts.NilKind = NilKindList + } + case "optional": + ts.Optional = true + if ts.Tail { + return ts, TagError{Field: name, Tag: t, Err: `also has "tail" tag`} + } + case "tail": + ts.Tail = true + if field.Index != lastPublic { + return ts, TagError{Field: name, Tag: t, Err: "must be on last field"} + } + if ts.Optional { + return ts, TagError{Field: name, Tag: t, Err: `also has "optional" tag`} + } + if field.Type.Kind != reflect.Slice { + return ts, TagError{Field: name, Tag: t, Err: "field type is not slice"} + } + default: + return ts, TagError{Field: name, Tag: t, Err: "unknown tag"} + } + } + return ts, nil +} + +func lastPublicField(fields []Field) int { + last := 0 + for _, f := range fields { + if f.Exported { + last = f.Index + } + } + return last +} + +func isUint(k reflect.Kind) bool { + return k >= reflect.Uint && k <= reflect.Uintptr +} + +func isByte(typ Type) bool { + return typ.Kind == reflect.Uint8 && !typ.IsEncoder +} + +func isByteArray(typ Type) bool { + return (typ.Kind == reflect.Slice || typ.Kind == reflect.Array) && isByte(*typ.Elem) +} diff --git a/rlp/iterator.go b/rlp/iterator.go new file mode 100644 index 000000000000..95bd3f258208 --- /dev/null +++ b/rlp/iterator.go @@ -0,0 +1,59 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +type listIterator struct { + data []byte + next []byte + err error +} + +// NewListIterator creates an iterator for the (list) represented by data +func NewListIterator(data RawValue) (*listIterator, error) { + k, t, c, err := readKind(data) + if err != nil { + return nil, err + } + if k != List { + return nil, ErrExpectedList + } + it := &listIterator{ + data: data[t : t+c], + } + return it, nil +} + +// Next forwards the iterator one step, returns true if it was not at end yet +func (it *listIterator) Next() bool { + if len(it.data) == 0 { + return false + } + _, t, c, err := readKind(it.data) + it.next = it.data[:t+c] + it.data = it.data[t+c:] + it.err = err + return true +} + +// Value returns the current value +func (it *listIterator) Value() []byte { + return it.next +} + +func (it *listIterator) Err() error { + return it.err +} diff --git a/rlp/iterator_test.go b/rlp/iterator_test.go new file mode 100644 index 000000000000..82ac7bfa6eb7 --- /dev/null +++ b/rlp/iterator_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package rlp + +import ( + "testing" + + "github.com/XinFinOrg/XDPoSChain/common/hexutil" +) + +// TestIterator tests some basic things about the ListIterator. A more +// comprehensive test can be found in core/rlp_test.go, where we can +// use both types and rlp without dependency cycles +func TestIterator(t *testing.T) { + bodyRlpHex := "0xf902cbf8d6f869800182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ba01025c66fad28b4ce3370222624d952c35529e602af7cbe04f667371f61b0e3b3a00ab8813514d1217059748fd903288ace1b4001a4bc5fbde2790debdc8167de2ff869010182c35094000000000000000000000000000000000000aaaa808a000000000000000000001ca05ac4cf1d19be06f3742c21df6c49a7e929ceb3dbaf6a09f3cfb56ff6828bd9a7a06875970133a35e63ac06d360aa166d228cc013e9b96e0a2cae7f55b22e1ee2e8f901f0f901eda0c75448377c0e426b8017b23c5f77379ecf69abc1d5c224284ad3ba1c46c59adaa00000000000000000000000000000000000000000000000000000000000000000940000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000a00000000000000000000000000000000000000000000000000000000000000000b9010000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000808080808080a00000000000000000000000000000000000000000000000000000000000000000880000000000000000" + bodyRlp := hexutil.MustDecode(bodyRlpHex) + + it, err := NewListIterator(bodyRlp) + if err != nil { + t.Fatal(err) + } + // Check that txs exist + if !it.Next() { + t.Fatal("expected two elems, got zero") + } + txs := it.Value() + // Check that uncles exist + if !it.Next() { + t.Fatal("expected two elems, got one") + } + txit, err := NewListIterator(txs) + if err != nil { + t.Fatal(err) + } + var i = 0 + for txit.Next() { + if txit.err != nil { + t.Fatal(txit.err) + } + i++ + } + if exp := 2; i != exp { + t.Errorf("count wrong, expected %d got %d", i, exp) + } +} diff --git a/rlp/raw.go b/rlp/raw.go index 2b3f328f6618..773aa7e614e8 100644 --- a/rlp/raw.go +++ b/rlp/raw.go @@ -28,12 +28,53 @@ type RawValue []byte var rawValueType = reflect.TypeOf(RawValue{}) +// StringSize returns the encoded size of a string. +func StringSize(s string) uint64 { + switch { + case len(s) == 0: + return 1 + case len(s) == 1: + if s[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(s))) + len(s)) + } +} + +// BytesSize returns the encoded size of a byte slice. +func BytesSize(b []byte) uint64 { + switch { + case len(b) == 0: + return 1 + case len(b) == 1: + if b[0] <= 0x7f { + return 1 + } else { + return 2 + } + default: + return uint64(headsize(uint64(len(b))) + len(b)) + } +} + // ListSize returns the encoded size of an RLP list with the given // content size. func ListSize(contentSize uint64) uint64 { return uint64(headsize(contentSize)) + contentSize } +// IntSize returns the encoded size of the integer x. Note: The return type of this +// function is 'int' for backwards-compatibility reasons. The result is always positive. +func IntSize(x uint64) int { + if x < 0x80 { + return 1 + } + return 1 + intsize(x) +} + // Split returns the content of first RLP value and any // bytes after the value as subslices of b. func Split(b []byte) (k Kind, content, rest []byte, err error) { @@ -57,6 +98,32 @@ func SplitString(b []byte) (content, rest []byte, err error) { return content, rest, nil } +// SplitUint64 decodes an integer at the beginning of b. +// It also returns the remaining data after the integer in 'rest'. +func SplitUint64(b []byte) (x uint64, rest []byte, err error) { + content, rest, err := SplitString(b) + if err != nil { + return 0, b, err + } + switch { + case len(content) == 0: + return 0, rest, nil + case len(content) == 1: + if content[0] == 0 { + return 0, b, ErrCanonInt + } + return uint64(content[0]), rest, nil + case len(content) > 8: + return 0, b, errUintOverflow + default: + x, err = readSize(content, byte(len(content))) + if err != nil { + return 0, b, ErrCanonInt + } + return x, rest, nil + } +} + // SplitList splits b into the content of a list and any remaining // bytes after the list. func SplitList(b []byte) (content, rest []byte, err error) { @@ -154,3 +221,74 @@ func readSize(b []byte, slen byte) (uint64, error) { } return s, nil } + +// AppendUint64 appends the RLP encoding of i to b, and returns the resulting slice. +func AppendUint64(b []byte, i uint64) []byte { + if i == 0 { + return append(b, 0x80) + } else if i < 128 { + return append(b, byte(i)) + } + switch { + case i < (1 << 8): + return append(b, 0x81, byte(i)) + case i < (1 << 16): + return append(b, 0x82, + byte(i>>8), + byte(i), + ) + case i < (1 << 24): + return append(b, 0x83, + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 32): + return append(b, 0x84, + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 40): + return append(b, 0x85, + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + case i < (1 << 48): + return append(b, 0x86, + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + case i < (1 << 56): + return append(b, 0x87, + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + + default: + return append(b, 0x88, + byte(i>>56), + byte(i>>48), + byte(i>>40), + byte(i>>32), + byte(i>>24), + byte(i>>16), + byte(i>>8), + byte(i), + ) + } +} diff --git a/rlp/raw_test.go b/rlp/raw_test.go index 2aad042100ea..7b3255eca36b 100644 --- a/rlp/raw_test.go +++ b/rlp/raw_test.go @@ -18,9 +18,10 @@ package rlp import ( "bytes" + "errors" "io" - "reflect" "testing" + "testing/quick" ) func TestCountValues(t *testing.T) { @@ -53,21 +54,84 @@ func TestCountValues(t *testing.T) { if count != test.count { t.Errorf("test %d: count mismatch, got %d want %d\ninput: %s", i, count, test.count, test.input) } - if !reflect.DeepEqual(err, test.err) { + if !errors.Is(err, test.err) { t.Errorf("test %d: err mismatch, got %q want %q\ninput: %s", i, err, test.err, test.input) } } } -func TestSplitTypes(t *testing.T) { - if _, _, err := SplitString(unhex("C100")); err != ErrExpectedString { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedString) +func TestSplitString(t *testing.T) { + for i, test := range []string{ + "C0", + "C100", + "C3010203", + "C88363617483646F67", + "F8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitString(unhex(test)); !errors.Is(err, ErrExpectedString) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedString) + } + } +} + +func TestSplitList(t *testing.T) { + for i, test := range []string{ + "80", + "00", + "01", + "8180", + "81FF", + "820400", + "83636174", + "83646F67", + "B8384C6F72656D20697073756D20646F6C6F722073697420616D65742C20636F6E7365637465747572206164697069736963696E6720656C6974", + } { + if _, _, err := SplitList(unhex(test)); !errors.Is(err, ErrExpectedList) { + t.Errorf("test %d: error mismatch: have %q, want %q", i, err, ErrExpectedList) + } } - if _, _, err := SplitList(unhex("01")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) +} + +func TestSplitUint64(t *testing.T) { + tests := []struct { + input string + val uint64 + rest string + err error + }{ + {"01", 1, "", nil}, + {"7FFF", 0x7F, "FF", nil}, + {"80FF", 0, "FF", nil}, + {"81FAFF", 0xFA, "FF", nil}, + {"82FAFAFF", 0xFAFA, "FF", nil}, + {"83FAFAFAFF", 0xFAFAFA, "FF", nil}, + {"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil}, + {"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil}, + {"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil}, + {"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil}, + {"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil}, + + // errors + {"", 0, "", io.ErrUnexpectedEOF}, + {"00", 0, "00", ErrCanonInt}, + {"81", 0, "81", ErrValueTooLarge}, + {"8100", 0, "8100", ErrCanonSize}, + {"8200FF", 0, "8200FF", ErrCanonInt}, + {"8103FF", 0, "8103FF", ErrCanonSize}, + {"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow}, } - if _, _, err := SplitList(unhex("81FF")); err != ErrExpectedList { - t.Errorf("SplitString returned %q, want %q", err, ErrExpectedList) + + for i, test := range tests { + val, rest, err := SplitUint64(unhex(test.input)) + if val != test.val { + t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input) + } + if !bytes.Equal(rest, unhex(test.rest)) { + t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input) + } + if err != test.err { + t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err) + } } } @@ -78,7 +142,9 @@ func TestSplit(t *testing.T) { val, rest string err error }{ + {input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"}, {input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"}, + {input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"}, {input: "80FFFF", kind: String, val: "", rest: "FFFF"}, {input: "C3010203", kind: List, val: "010203"}, @@ -194,3 +260,79 @@ func TestReadSize(t *testing.T) { } } } + +func TestAppendUint64(t *testing.T) { + tests := []struct { + input uint64 + slice []byte + output string + }{ + {0, nil, "80"}, + {1, nil, "01"}, + {2, nil, "02"}, + {127, nil, "7F"}, + {128, nil, "8180"}, + {129, nil, "8181"}, + {0xFFFFFF, nil, "83FFFFFF"}, + {127, []byte{1, 2, 3}, "0102037F"}, + {0xFFFFFF, []byte{1, 2, 3}, "01020383FFFFFF"}, + } + + for _, test := range tests { + x := AppendUint64(test.slice, test.input) + if !bytes.Equal(x, unhex(test.output)) { + t.Errorf("AppendUint64(%v, %d): got %x, want %s", test.slice, test.input, x, test.output) + } + + // Check that IntSize returns the appended size. + length := len(x) - len(test.slice) + if s := IntSize(test.input); s != length { + t.Errorf("IntSize(%d): got %d, want %d", test.input, s, length) + } + } +} + +func TestAppendUint64Random(t *testing.T) { + fn := func(i uint64) bool { + enc, _ := EncodeToBytes(i) + encAppend := AppendUint64(nil, i) + return bytes.Equal(enc, encAppend) + } + config := quick.Config{MaxCountScale: 50} + if err := quick.Check(fn, &config); err != nil { + t.Fatal(err) + } +} + +func TestBytesSize(t *testing.T) { + tests := []struct { + v []byte + size uint64 + }{ + {v: []byte{}, size: 1}, + {v: []byte{0x1}, size: 1}, + {v: []byte{0x7E}, size: 1}, + {v: []byte{0x7F}, size: 1}, + {v: []byte{0x80}, size: 2}, + {v: []byte{0xFF}, size: 2}, + {v: []byte{0xFF, 0xF0}, size: 3}, + {v: make([]byte, 55), size: 56}, + {v: make([]byte, 56), size: 58}, + } + + for _, test := range tests { + s := BytesSize(test.v) + if s != test.size { + t.Errorf("BytesSize(%#x) -> %d, want %d", test.v, s, test.size) + } + s = StringSize(string(test.v)) + if s != test.size { + t.Errorf("StringSize(%#x) -> %d, want %d", test.v, s, test.size) + } + // Sanity check: + enc, _ := EncodeToBytes(test.v) + if uint64(len(enc)) != test.size { + t.Errorf("len(EncodeToBytes(%#x)) -> %d, test says %d", test.v, len(enc), test.size) + } + } +} diff --git a/rlp/rlpgen/gen.go b/rlp/rlpgen/gen.go new file mode 100644 index 000000000000..ed502c09a7e3 --- /dev/null +++ b/rlp/rlpgen/gen.go @@ -0,0 +1,800 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/format" + "go/types" + "sort" + + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" +) + +// buildContext keeps the data needed for make*Op. +type buildContext struct { + topType *types.Named // the type we're creating methods for + + encoderIface *types.Interface + decoderIface *types.Interface + rawValueType *types.Named + + typeToStructCache map[types.Type]*rlpstruct.Type +} + +func newBuildContext(packageRLP *types.Package) *buildContext { + enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying() + dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying() + rawv := packageRLP.Scope().Lookup("RawValue").Type() + return &buildContext{ + typeToStructCache: make(map[types.Type]*rlpstruct.Type), + encoderIface: enc.(*types.Interface), + decoderIface: dec.(*types.Interface), + rawValueType: rawv.(*types.Named), + } +} + +func (bctx *buildContext) isEncoder(typ types.Type) bool { + return types.Implements(typ, bctx.encoderIface) +} + +func (bctx *buildContext) isDecoder(typ types.Type) bool { + return types.Implements(typ, bctx.decoderIface) +} + +// typeToStructType converts typ to rlpstruct.Type. +func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { + if prev := bctx.typeToStructCache[typ]; prev != nil { + return prev // short-circuit for recursive types. + } + + // Resolve named types to their underlying type, but keep the name. + name := types.TypeString(typ, nil) + for { + utype := typ.Underlying() + if utype == typ { + break + } + typ = utype + } + + // Create the type and store it in cache. + t := &rlpstruct.Type{ + Name: name, + Kind: typeReflectKind(typ), + IsEncoder: bctx.isEncoder(typ), + IsDecoder: bctx.isDecoder(typ), + } + bctx.typeToStructCache[typ] = t + + // Assign element type. + switch typ.(type) { + case *types.Array, *types.Slice, *types.Pointer: + etype := typ.(interface{ Elem() types.Type }).Elem() + t.Elem = bctx.typeToStructType(etype) + } + return t +} + +// genContext is passed to the gen* methods of op when generating +// the output code. It tracks packages to be imported by the output +// file and assigns unique names of temporary variables. +type genContext struct { + inPackage *types.Package + imports map[string]struct{} + tempCounter int +} + +func newGenContext(inPackage *types.Package) *genContext { + return &genContext{ + inPackage: inPackage, + imports: make(map[string]struct{}), + } +} + +func (ctx *genContext) temp() string { + v := fmt.Sprintf("_tmp%d", ctx.tempCounter) + ctx.tempCounter++ + return v +} + +func (ctx *genContext) resetTemp() { + ctx.tempCounter = 0 +} + +func (ctx *genContext) addImport(path string) { + if path == ctx.inPackage.Path() { + return // avoid importing the package that we're generating in. + } + // TODO: renaming? + ctx.imports[path] = struct{}{} +} + +// importsList returns all packages that need to be imported. +func (ctx *genContext) importsList() []string { + imp := make([]string, 0, len(ctx.imports)) + for k := range ctx.imports { + imp = append(imp, k) + } + sort.Strings(imp) + return imp +} + +// qualify is the types.Qualifier used for printing types. +func (ctx *genContext) qualify(pkg *types.Package) string { + if pkg.Path() == ctx.inPackage.Path() { + return "" + } + ctx.addImport(pkg.Path()) + // TODO: renaming? + return pkg.Name() +} + +type op interface { + // genWrite creates the encoder. The generated code should write v, + // which is any Go expression, to the rlp.EncoderBuffer 'w'. + genWrite(ctx *genContext, v string) string + + // genDecode creates the decoder. The generated code should read + // a value from the rlp.Stream 'dec' and store it to dst. + genDecode(ctx *genContext) (string, string) +} + +// basicOp handles basic types bool, uint*, string. +type basicOp struct { + typ types.Type + writeMethod string // EncoderBuffer writer method name + writeArgType types.Type // parameter type of writeMethod + decMethod string + decResultType types.Type // return type of decMethod + decUseBitSize bool // if true, result bit size is appended to decMethod +} + +func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) { + op := basicOp{typ: typ} + kind := typ.Kind() + switch { + case kind == types.Bool: + op.writeMethod = "WriteBool" + op.writeArgType = types.Typ[types.Bool] + op.decMethod = "Bool" + op.decResultType = types.Typ[types.Bool] + case kind >= types.Uint8 && kind <= types.Uint64: + op.writeMethod = "WriteUint64" + op.writeArgType = types.Typ[types.Uint64] + op.decMethod = "Uint" + op.decResultType = typ + op.decUseBitSize = true + case kind == types.String: + op.writeMethod = "WriteString" + op.writeArgType = types.Typ[types.String] + op.decMethod = "String" + op.decResultType = types.Typ[types.String] + default: + return nil, fmt.Errorf("unhandled basic type: %v", typ) + } + return op, nil +} + +func (*buildContext) makeByteSliceOp(typ *types.Slice) op { + if !isByte(typ.Elem()) { + panic("non-byte slice type in makeByteSliceOp") + } + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: typ, + writeMethod: "WriteBytes", + writeArgType: bslice, + decMethod: "Bytes", + decResultType: bslice, + } +} + +func (bctx *buildContext) makeRawValueOp() op { + bslice := types.NewSlice(types.Typ[types.Uint8]) + return basicOp{ + typ: bctx.rawValueType, + writeMethod: "Write", + writeArgType: bslice, + decMethod: "Raw", + decResultType: bslice, + } +} + +func (op basicOp) writeNeedsConversion() bool { + return !types.AssignableTo(op.typ, op.writeArgType) +} + +func (op basicOp) decodeNeedsConversion() bool { + return !types.AssignableTo(op.decResultType, op.typ) +} + +func (op basicOp) genWrite(ctx *genContext, v string) string { + if op.writeNeedsConversion() { + v = fmt.Sprintf("%s(%s)", op.writeArgType, v) + } + return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v) +} + +func (op basicOp) genDecode(ctx *genContext) (string, string) { + var ( + resultV = ctx.temp() + result = resultV + method = op.decMethod + ) + if op.decUseBitSize { + // Note: For now, this only works for platform-independent integer + // sizes. makeBasicOp forbids the platform-dependent types. + var sizes types.StdSizes + method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8) + } + + // Call the decoder method. + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method) + fmt.Fprintf(&b, "if err != nil { return err }\n") + if op.decodeNeedsConversion() { + conv := ctx.temp() + fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV) + result = conv + } + return result, b.String() +} + +// byteArrayOp handles [...]byte. +type byteArrayOp struct { + typ types.Type + name types.Type // name != typ for named byte array types (e.g. common.Address) +} + +func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp { + nt := types.Type(name) + if name == nil { + nt = typ + } + return byteArrayOp{typ, nt} +} + +func (op byteArrayOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("w.WriteBytes(%s[:])\n", v) +} + +func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify)) + fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// bigIntOp handles big.Int. +// This exists because big.Int has it's own decoder operation on rlp.Stream, +// but the decode method returns *big.Int, so it needs to be dereferenced. +type bigIntOp struct { + pointer bool +} + +func (op bigIntOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v) + fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n") + fmt.Fprintf(&b, "}\n") + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op bigIntOp) genDecode(ctx *genContext) (string, string) { + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV) + fmt.Fprintf(&b, "if err != nil { return err }\n") + + result := resultV + if !op.pointer { + result = "(*" + resultV + ")" + } + return result, b.String() +} + +// uint256Op handles "github.com/holiman/uint256".Int +type uint256Op struct { + pointer bool +} + +func (op uint256Op) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + + dst := v + if !op.pointer { + dst = "&" + v + } + fmt.Fprintf(&b, "w.WriteUint256(%s)\n", dst) + + // Wrap with nil check. + if op.pointer { + code := b.String() + b.Reset() + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") + fmt.Fprintf(&b, "} else {\n") + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "}\n") + } + + return b.String() +} + +func (op uint256Op) genDecode(ctx *genContext) (string, string) { + ctx.addImport("github.com/holiman/uint256") + + var b bytes.Buffer + resultV := ctx.temp() + fmt.Fprintf(&b, "var %s uint256.Int\n", resultV) + fmt.Fprintf(&b, "if err := dec.ReadUint256(&%s); err != nil { return err }\n", resultV) + + result := resultV + if op.pointer { + result = "&" + resultV + } + return result, b.String() +} + +// encoderDecoderOp handles rlp.Encoder and rlp.Decoder. +// In order to be used with this, the type must implement both interfaces. +// This restriction may be lifted in the future by creating separate ops for +// encoding and decoding. +type encoderDecoderOp struct { + typ types.Type +} + +func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string { + return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v) +} + +func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) { + // DecodeRLP must have pointer receiver, and this is verified in makeOp. + etyp := op.typ.(*types.Pointer).Elem() + var resultV = ctx.temp() + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify)) + fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV) + return resultV, b.String() +} + +// ptrOp handles pointer types. +type ptrOp struct { + elemTyp types.Type + elem op + nilOK bool + nilValue rlpstruct.NilKind +} + +func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) { + elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + op := ptrOp{elemTyp: elemTyp, elem: elemOp} + + // Determine nil value. + if tags.NilOK { + op.nilOK = true + op.nilValue = tags.NilKind + } else { + styp := bctx.typeToStructType(elemTyp) + op.nilValue = styp.DefaultNilValue() + } + return op, nil +} + +func (op ptrOp) genWrite(ctx *genContext, v string) string { + // Note: in writer functions, accesses to v are read-only, i.e. v is any Go + // expression. To make all accesses work through the pointer, we substitute + // v with (*v). This is required for most accesses including `v`, `call(v)`, + // and `v[index]` on slices. + // + // For `v.field` and `v[:]` on arrays, the dereference operation is not required. + var vv string + _, isStruct := op.elem.(structOp) + _, isByteArray := op.elem.(byteArrayOp) + if isStruct || isByteArray { + vv = v + } else { + vv = fmt.Sprintf("(*%s)", v) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "if %s == nil {\n", v) + fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue) + fmt.Fprintf(&b, "} else {\n") + fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv)) + fmt.Fprintf(&b, "}\n") + return b.String() +} + +func (op ptrOp) genDecode(ctx *genContext) (string, string) { + result, code := op.elem.genDecode(ctx) + if !op.nilOK { + // If nil pointers are not allowed, we can just decode the element. + return "&" + result, code + } + + // nil is allowed, so check the kind and size first. + // If size is zero and kind matches the nilKind of the type, + // the value decodes as a nil pointer. + var ( + resultV = ctx.temp() + kindV = ctx.temp() + sizeV = ctx.temp() + wantKind string + ) + if op.nilValue == rlpstruct.NilKindList { + wantKind = "rlp.List" + } else { + wantKind = "rlp.String" + } + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify)) + fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV) + fmt.Fprintf(&b, " return err\n") + fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " %s = &%s\n", resultV, result) + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +// structOp handles struct types. +type structOp struct { + named *types.Named + typ *types.Struct + fields []*structField + optionalFields []*structField +} + +type structField struct { + name string + typ types.Type + elem op +} + +func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) { + // Convert fields to []rlpstruct.Field. + var allStructFields []rlpstruct.Field + for i := 0; i < typ.NumFields(); i++ { + f := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: f.Name(), + Exported: f.Exported(), + Index: i, + Tag: typ.Tag(i), + Type: *bctx.typeToStructType(f.Type()), + }) + } + + // Filter/validate fields. + fields, tags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + return nil, err + } + + // Create field ops. + var op = structOp{named: named, typ: typ} + for i, field := range fields { + // Advanced struct tags are not supported yet. + tag := tags[i] + if err := checkUnsupportedTags(field.Name, tag); err != nil { + return nil, err + } + typ := typ.Field(field.Index).Type() + elem, err := bctx.makeOp(nil, typ, tags[i]) + if err != nil { + return nil, fmt.Errorf("field %s: %v", field.Name, err) + } + f := &structField{name: field.Name, typ: typ, elem: elem} + if tag.Optional { + op.optionalFields = append(op.optionalFields, f) + } else { + op.fields = append(op.fields, f) + } + } + return op, nil +} + +func checkUnsupportedTags(field string, tag rlpstruct.Tags) error { + if tag.Tail { + return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field) + } + return nil +} + +func (op structOp) genWrite(ctx *genContext, v string) string { + var b bytes.Buffer + var listMarker = ctx.temp() + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + for _, field := range op.fields { + selector := v + "." + field.name + fmt.Fprint(&b, field.elem.genWrite(ctx, selector)) + } + op.writeOptionalFields(&b, ctx, v) + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) { + if len(op.optionalFields) == 0 { + return + } + // First check zero-ness of all optional fields. + var zeroV = make([]string, len(op.optionalFields)) + for i, field := range op.optionalFields { + selector := v + "." + field.name + zeroV[i] = ctx.temp() + fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify)) + } + // Now write the fields. + for i, field := range op.optionalFields { + selector := v + "." + field.name + cond := "" + for j := i; j < len(op.optionalFields); j++ { + if j > i { + cond += " || " + } + cond += zeroV[j] + } + fmt.Fprintf(b, "if %s {\n", cond) + fmt.Fprint(b, field.elem.genWrite(ctx, selector)) + fmt.Fprintf(b, "}\n") + } +} + +func (op structOp) genDecode(ctx *genContext) (string, string) { + // Get the string representation of the type. + // Here, named types are handled separately because the output + // would contain a copy of the struct definition otherwise. + var typeName string + if op.named != nil { + typeName = types.TypeString(op.named, ctx.qualify) + } else { + typeName = types.TypeString(op.typ, ctx.qualify) + } + + // Create struct object. + var resultV = ctx.temp() + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", resultV, typeName) + + // Decode fields. + fmt.Fprintf(&b, "{\n") + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + for _, field := range op.fields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(&b, "// %s:\n", field.name) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result) + } + op.decodeOptionalFields(&b, ctx, resultV) + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + fmt.Fprintf(&b, "}\n") + return resultV, b.String() +} + +func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) { + var suffix bytes.Buffer + for _, field := range op.optionalFields { + result, code := field.elem.genDecode(ctx) + fmt.Fprintf(b, "// %s:\n", field.name) + fmt.Fprintf(b, "if dec.MoreDataInList() {\n") + fmt.Fprint(b, code) + fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result) + fmt.Fprintf(&suffix, "}\n") + } + suffix.WriteTo(b) +} + +// sliceOp handles slice types. +type sliceOp struct { + typ *types.Slice + elemOp op +} + +func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) { + elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{}) + if err != nil { + return nil, err + } + return sliceOp{typ: typ, elemOp: elemOp}, nil +} + +func (op sliceOp) genWrite(ctx *genContext, v string) string { + var ( + listMarker = ctx.temp() // holds return value of w.List() + iterElemV = ctx.temp() // iteration variable + elemCode = op.elemOp.genWrite(ctx, iterElemV) + ) + + var b bytes.Buffer + fmt.Fprintf(&b, "%s := w.List()\n", listMarker) + fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v) + fmt.Fprint(&b, elemCode) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) + return b.String() +} + +func (op sliceOp) genDecode(ctx *genContext) (string, string) { + var sliceV = ctx.temp() // holds the output slice + elemResult, elemCode := op.elemOp.genDecode(ctx) + + var b bytes.Buffer + fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify)) + fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") + fmt.Fprintf(&b, "for dec.MoreDataInList() {\n") + fmt.Fprintf(&b, " %s", elemCode) + fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult) + fmt.Fprintf(&b, "}\n") + fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") + return sliceV, b.String() +} + +func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) { + switch typ := typ.(type) { + case *types.Named: + if isBigInt(typ) { + return bigIntOp{}, nil + } + if isUint256(typ) { + return uint256Op{}, nil + } + if typ == bctx.rawValueType { + return bctx.makeRawValueOp(), nil + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ) + } + // TODO: same check for encoder? + return bctx.makeOp(typ, typ.Underlying(), tags) + case *types.Pointer: + if isBigInt(typ.Elem()) { + return bigIntOp{pointer: true}, nil + } + if isUint256(typ.Elem()) { + return uint256Op{pointer: true}, nil + } + // Encoder/Decoder interfaces. + if bctx.isEncoder(typ) { + if bctx.isDecoder(typ) { + return encoderDecoderOp{typ}, nil + } + return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ) + } + if bctx.isDecoder(typ) { + return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ) + } + // Default pointer handling. + return bctx.makePtrOp(typ.Elem(), tags) + case *types.Basic: + return bctx.makeBasicOp(typ) + case *types.Struct: + return bctx.makeStructOp(name, typ) + case *types.Slice: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteSliceOp(typ), nil + } + return bctx.makeSliceOp(typ) + case *types.Array: + etyp := typ.Elem() + if isByte(etyp) && !bctx.isEncoder(etyp) { + return bctx.makeByteArrayOp(name, typ), nil + } + return nil, fmt.Errorf("unhandled array type: %v", typ) + default: + return nil, fmt.Errorf("unhandled type: %v", typ) + } +} + +// generateDecoder generates the DecodeRLP method on 'typ'. +func generateDecoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport(pathOfPackageRLP) + + result, code := op.genDecode(ctx) + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ) + fmt.Fprint(&b, code) + fmt.Fprintf(&b, " *obj = %s\n", result) + fmt.Fprintf(&b, " return nil\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +// generateEncoder generates the EncodeRLP method on 'typ'. +func generateEncoder(ctx *genContext, typ string, op op) []byte { + ctx.resetTemp() + ctx.addImport("io") + ctx.addImport(pathOfPackageRLP) + + var b bytes.Buffer + fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) + fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n") + fmt.Fprint(&b, op.genWrite(ctx, "obj")) + fmt.Fprintf(&b, " return w.Flush()\n") + fmt.Fprintf(&b, "}\n") + return b.Bytes() +} + +func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) { + bctx.topType = typ + + pkg := typ.Obj().Pkg() + op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{}) + if err != nil { + return nil, err + } + + var ( + ctx = newGenContext(pkg) + encSource []byte + decSource []byte + ) + if encoder { + encSource = generateEncoder(ctx, typ.Obj().Name(), op) + } + if decoder { + decSource = generateDecoder(ctx, typ.Obj().Name(), op) + } + + var b bytes.Buffer + fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) + for _, imp := range ctx.importsList() { + fmt.Fprintf(&b, "import %q\n", imp) + } + if encoder { + fmt.Fprintln(&b) + b.Write(encSource) + } + if decoder { + fmt.Fprintln(&b) + b.Write(decSource) + } + + source := b.Bytes() + // fmt.Println(string(source)) + return format.Source(source) +} diff --git a/rlp/rlpgen/gen_test.go b/rlp/rlpgen/gen_test.go new file mode 100644 index 000000000000..3b4f5df28765 --- /dev/null +++ b/rlp/rlpgen/gen_test.go @@ -0,0 +1,107 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/importer" + "go/parser" + "go/token" + "go/types" + "os" + "path/filepath" + "testing" +) + +// Package RLP is loaded only once and reused for all tests. +var ( + testFset = token.NewFileSet() + testImporter = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom) + testPackageRLP *types.Package +) + +func init() { + cwd, err := os.Getwd() + if err != nil { + panic(err) + } + testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0) + if err != nil { + panic(fmt.Errorf("can't load package RLP: %v", err)) + } +} + +var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint", "uint256"} + +func TestOutput(t *testing.T) { + for _, test := range tests { + test := test + t.Run(test, func(t *testing.T) { + inputFile := filepath.Join("testdata", test+".in.txt") + outputFile := filepath.Join("testdata", test+".out.txt") + bctx, typ, err := loadTestSource(inputFile, "Test") + if err != nil { + t.Fatal("error loading test source:", err) + } + output, err := bctx.generate(typ, true, true) + if err != nil { + t.Fatal("error in generate:", err) + } + + // Set this environment variable to regenerate the test outputs. + if os.Getenv("WRITE_TEST_FILES") != "" { + os.WriteFile(outputFile, output, 0644) + } + + // Check if output matches. + wantOutput, err := os.ReadFile(outputFile) + if err != nil { + t.Fatal("error loading expected test output:", err) + } + if !bytes.Equal(output, wantOutput) { + t.Fatalf("output mismatch, want: %v got %v", string(wantOutput), string(output)) + } + }) + } +} + +func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) { + // Load the test input. + content, err := os.ReadFile(file) + if err != nil { + return nil, nil, err + } + f, err := parser.ParseFile(testFset, file, content, 0) + if err != nil { + return nil, nil, err + } + conf := types.Config{Importer: testImporter} + pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil) + if err != nil { + return nil, nil, err + } + + // Find the test struct. + bctx := newBuildContext(testPackageRLP) + typ, err := lookupStructType(pkg.Scope(), typeName) + if err != nil { + return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err) + } + return bctx, typ, nil +} diff --git a/rlp/rlpgen/main.go b/rlp/rlpgen/main.go new file mode 100644 index 000000000000..727188230606 --- /dev/null +++ b/rlp/rlpgen/main.go @@ -0,0 +1,144 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/types" + "os" + + "golang.org/x/tools/go/packages" +) + +const pathOfPackageRLP = "github.com/XinFinOrg/XDPoSChain/rlp" + +func main() { + var ( + pkgdir = flag.String("dir", ".", "input package") + output = flag.String("out", "-", "output file (default is stdout)") + genEncoder = flag.Bool("encoder", true, "generate EncodeRLP?") + genDecoder = flag.Bool("decoder", false, "generate DecodeRLP?") + typename = flag.String("type", "", "type to generate methods for") + ) + flag.Parse() + + cfg := Config{ + Dir: *pkgdir, + Type: *typename, + GenerateEncoder: *genEncoder, + GenerateDecoder: *genDecoder, + } + code, err := cfg.process() + if err != nil { + fatal(err) + } + if *output == "-" { + os.Stdout.Write(code) + } else if err := os.WriteFile(*output, code, 0600); err != nil { + fatal(err) + } +} + +func fatal(args ...interface{}) { + fmt.Fprintln(os.Stderr, args...) + os.Exit(1) +} + +type Config struct { + Dir string // input package directory + Type string + + GenerateEncoder bool + GenerateDecoder bool +} + +// process generates the Go code. +func (cfg *Config) process() (code []byte, err error) { + // Load packages. + pcfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes, + Dir: cfg.Dir, + } + ps, err := packages.Load(pcfg, pathOfPackageRLP, ".") + if err != nil { + return nil, err + } + if len(ps) == 0 { + return nil, fmt.Errorf("no Go package found in %s", cfg.Dir) + } + packages.PrintErrors(ps) + + // Find the packages that were loaded. + var ( + pkg *types.Package + packageRLP *types.Package + ) + for _, p := range ps { + if len(p.Errors) > 0 { + return nil, fmt.Errorf("package %s has errors", p.PkgPath) + } + if p.PkgPath == pathOfPackageRLP { + packageRLP = p.Types + } else { + pkg = p.Types + } + } + bctx := newBuildContext(packageRLP) + + // Find the type and generate. + typ, err := lookupStructType(pkg.Scope(), cfg.Type) + if err != nil { + return nil, fmt.Errorf("can't find %s in %s: %v", cfg.Type, pkg, err) + } + code, err = bctx.generate(typ, cfg.GenerateEncoder, cfg.GenerateDecoder) + if err != nil { + return nil, err + } + + // Add build comments. + // This is done here to avoid processing these lines with gofmt. + var header bytes.Buffer + fmt.Fprint(&header, "// Code generated by rlpgen. DO NOT EDIT.\n\n") + return append(header.Bytes(), code...), nil +} + +func lookupStructType(scope *types.Scope, name string) (*types.Named, error) { + typ, err := lookupType(scope, name) + if err != nil { + return nil, err + } + _, ok := typ.Underlying().(*types.Struct) + if !ok { + return nil, errors.New("not a struct type") + } + return typ, nil +} + +func lookupType(scope *types.Scope, name string) (*types.Named, error) { + obj := scope.Lookup(name) + if obj == nil { + return nil, errors.New("no such identifier") + } + typ, ok := obj.(*types.TypeName) + if !ok { + return nil, errors.New("not a type") + } + return typ.Type().(*types.Named), nil +} diff --git a/rlp/rlpgen/testdata/bigint.in.txt b/rlp/rlpgen/testdata/bigint.in.txt new file mode 100644 index 000000000000..d23d84a28763 --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "math/big" + +type Test struct { + Int *big.Int + IntNoPtr big.Int +} diff --git a/rlp/rlpgen/testdata/bigint.out.txt b/rlp/rlpgen/testdata/bigint.out.txt new file mode 100644 index 000000000000..faab1bed461c --- /dev/null +++ b/rlp/rlpgen/testdata/bigint.out.txt @@ -0,0 +1,49 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + if obj.Int.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(obj.Int) + } + if obj.IntNoPtr.Sign() == -1 { + return rlp.ErrNegativeBigInt + } + w.WriteBigInt(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + _tmp1, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.Int = _tmp1 + // IntNoPtr: + _tmp2, err := dec.BigInt() + if err != nil { + return err + } + _tmp0.IntNoPtr = (*_tmp2) + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/nil.in.txt b/rlp/rlpgen/testdata/nil.in.txt new file mode 100644 index 000000000000..a28ff344874d --- /dev/null +++ b/rlp/rlpgen/testdata/nil.in.txt @@ -0,0 +1,30 @@ +// -*- mode: go -*- + +package test + +type Aux struct{ + A uint32 +} + +type Test struct{ + Uint8 *byte `rlp:"nil"` + Uint8List *byte `rlp:"nilList"` + + Uint32 *uint32 `rlp:"nil"` + Uint32List *uint32 `rlp:"nilList"` + + Uint64 *uint64 `rlp:"nil"` + Uint64List *uint64 `rlp:"nilList"` + + String *string `rlp:"nil"` + StringList *string `rlp:"nilList"` + + ByteArray *[3]byte `rlp:"nil"` + ByteArrayList *[3]byte `rlp:"nilList"` + + ByteSlice *[]byte `rlp:"nil"` + ByteSliceList *[]byte `rlp:"nilList"` + + Struct *Aux `rlp:"nil"` + StructString *Aux `rlp:"nilString"` +} diff --git a/rlp/rlpgen/testdata/nil.out.txt b/rlp/rlpgen/testdata/nil.out.txt new file mode 100644 index 000000000000..7f3459682b4b --- /dev/null +++ b/rlp/rlpgen/testdata/nil.out.txt @@ -0,0 +1,289 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Uint8 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint8))) + } + if obj.Uint8List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint8List))) + } + if obj.Uint32 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64(uint64((*obj.Uint32))) + } + if obj.Uint32List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64(uint64((*obj.Uint32List))) + } + if obj.Uint64 == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Uint64)) + } + if obj.Uint64List == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteUint64((*obj.Uint64List)) + } + if obj.String == nil { + w.Write([]byte{0x80}) + } else { + w.WriteString((*obj.String)) + } + if obj.StringList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteString((*obj.StringList)) + } + if obj.ByteArray == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes(obj.ByteArray[:]) + } + if obj.ByteArrayList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes(obj.ByteArrayList[:]) + } + if obj.ByteSlice == nil { + w.Write([]byte{0x80}) + } else { + w.WriteBytes((*obj.ByteSlice)) + } + if obj.ByteSliceList == nil { + w.Write([]byte{0xC0}) + } else { + w.WriteBytes((*obj.ByteSliceList)) + } + if obj.Struct == nil { + w.Write([]byte{0xC0}) + } else { + _tmp1 := w.List() + w.WriteUint64(uint64(obj.Struct.A)) + w.ListEnd(_tmp1) + } + if obj.StructString == nil { + w.Write([]byte{0x80}) + } else { + _tmp2 := w.List() + w.WriteUint64(uint64(obj.StructString.A)) + w.ListEnd(_tmp2) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint8: + var _tmp2 *byte + if _tmp3, _tmp4, err := dec.Kind(); err != nil { + return err + } else if _tmp4 != 0 || _tmp3 != rlp.String { + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp2 = &_tmp1 + } + _tmp0.Uint8 = _tmp2 + // Uint8List: + var _tmp6 *byte + if _tmp7, _tmp8, err := dec.Kind(); err != nil { + return err + } else if _tmp8 != 0 || _tmp7 != rlp.List { + _tmp5, err := dec.Uint8() + if err != nil { + return err + } + _tmp6 = &_tmp5 + } + _tmp0.Uint8List = _tmp6 + // Uint32: + var _tmp10 *uint32 + if _tmp11, _tmp12, err := dec.Kind(); err != nil { + return err + } else if _tmp12 != 0 || _tmp11 != rlp.String { + _tmp9, err := dec.Uint32() + if err != nil { + return err + } + _tmp10 = &_tmp9 + } + _tmp0.Uint32 = _tmp10 + // Uint32List: + var _tmp14 *uint32 + if _tmp15, _tmp16, err := dec.Kind(); err != nil { + return err + } else if _tmp16 != 0 || _tmp15 != rlp.List { + _tmp13, err := dec.Uint32() + if err != nil { + return err + } + _tmp14 = &_tmp13 + } + _tmp0.Uint32List = _tmp14 + // Uint64: + var _tmp18 *uint64 + if _tmp19, _tmp20, err := dec.Kind(); err != nil { + return err + } else if _tmp20 != 0 || _tmp19 != rlp.String { + _tmp17, err := dec.Uint64() + if err != nil { + return err + } + _tmp18 = &_tmp17 + } + _tmp0.Uint64 = _tmp18 + // Uint64List: + var _tmp22 *uint64 + if _tmp23, _tmp24, err := dec.Kind(); err != nil { + return err + } else if _tmp24 != 0 || _tmp23 != rlp.List { + _tmp21, err := dec.Uint64() + if err != nil { + return err + } + _tmp22 = &_tmp21 + } + _tmp0.Uint64List = _tmp22 + // String: + var _tmp26 *string + if _tmp27, _tmp28, err := dec.Kind(); err != nil { + return err + } else if _tmp28 != 0 || _tmp27 != rlp.String { + _tmp25, err := dec.String() + if err != nil { + return err + } + _tmp26 = &_tmp25 + } + _tmp0.String = _tmp26 + // StringList: + var _tmp30 *string + if _tmp31, _tmp32, err := dec.Kind(); err != nil { + return err + } else if _tmp32 != 0 || _tmp31 != rlp.List { + _tmp29, err := dec.String() + if err != nil { + return err + } + _tmp30 = &_tmp29 + } + _tmp0.StringList = _tmp30 + // ByteArray: + var _tmp34 *[3]byte + if _tmp35, _tmp36, err := dec.Kind(); err != nil { + return err + } else if _tmp36 != 0 || _tmp35 != rlp.String { + var _tmp33 [3]byte + if err := dec.ReadBytes(_tmp33[:]); err != nil { + return err + } + _tmp34 = &_tmp33 + } + _tmp0.ByteArray = _tmp34 + // ByteArrayList: + var _tmp38 *[3]byte + if _tmp39, _tmp40, err := dec.Kind(); err != nil { + return err + } else if _tmp40 != 0 || _tmp39 != rlp.List { + var _tmp37 [3]byte + if err := dec.ReadBytes(_tmp37[:]); err != nil { + return err + } + _tmp38 = &_tmp37 + } + _tmp0.ByteArrayList = _tmp38 + // ByteSlice: + var _tmp42 *[]byte + if _tmp43, _tmp44, err := dec.Kind(); err != nil { + return err + } else if _tmp44 != 0 || _tmp43 != rlp.String { + _tmp41, err := dec.Bytes() + if err != nil { + return err + } + _tmp42 = &_tmp41 + } + _tmp0.ByteSlice = _tmp42 + // ByteSliceList: + var _tmp46 *[]byte + if _tmp47, _tmp48, err := dec.Kind(); err != nil { + return err + } else if _tmp48 != 0 || _tmp47 != rlp.List { + _tmp45, err := dec.Bytes() + if err != nil { + return err + } + _tmp46 = &_tmp45 + } + _tmp0.ByteSliceList = _tmp46 + // Struct: + var _tmp51 *Aux + if _tmp52, _tmp53, err := dec.Kind(); err != nil { + return err + } else if _tmp53 != 0 || _tmp52 != rlp.List { + var _tmp49 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp50, err := dec.Uint32() + if err != nil { + return err + } + _tmp49.A = _tmp50 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp51 = &_tmp49 + } + _tmp0.Struct = _tmp51 + // StructString: + var _tmp56 *Aux + if _tmp57, _tmp58, err := dec.Kind(); err != nil { + return err + } else if _tmp58 != 0 || _tmp57 != rlp.String { + var _tmp54 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp55, err := dec.Uint32() + if err != nil { + return err + } + _tmp54.A = _tmp55 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp56 = &_tmp54 + } + _tmp0.StructString = _tmp56 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/optional.in.txt b/rlp/rlpgen/testdata/optional.in.txt new file mode 100644 index 000000000000..f1ac9f7899d1 --- /dev/null +++ b/rlp/rlpgen/testdata/optional.in.txt @@ -0,0 +1,17 @@ +// -*- mode: go -*- + +package test + +type Aux struct { + A uint64 +} + +type Test struct { + Uint64 uint64 `rlp:"optional"` + Pointer *uint64 `rlp:"optional"` + String string `rlp:"optional"` + Slice []uint64 `rlp:"optional"` + Array [3]byte `rlp:"optional"` + NamedStruct Aux `rlp:"optional"` + AnonStruct struct{ A string } `rlp:"optional"` +} diff --git a/rlp/rlpgen/testdata/optional.out.txt b/rlp/rlpgen/testdata/optional.out.txt new file mode 100644 index 000000000000..8b4cfa18171c --- /dev/null +++ b/rlp/rlpgen/testdata/optional.out.txt @@ -0,0 +1,153 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + _tmp1 := obj.Uint64 != 0 + _tmp2 := obj.Pointer != nil + _tmp3 := obj.String != "" + _tmp4 := len(obj.Slice) > 0 + _tmp5 := obj.Array != ([3]byte{}) + _tmp6 := obj.NamedStruct != (Aux{}) + _tmp7 := obj.AnonStruct != (struct{ A string }{}) + if _tmp1 || _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteUint64(obj.Uint64) + } + if _tmp2 || _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + if obj.Pointer == nil { + w.Write([]byte{0x80}) + } else { + w.WriteUint64((*obj.Pointer)) + } + } + if _tmp3 || _tmp4 || _tmp5 || _tmp6 || _tmp7 { + w.WriteString(obj.String) + } + if _tmp4 || _tmp5 || _tmp6 || _tmp7 { + _tmp8 := w.List() + for _, _tmp9 := range obj.Slice { + w.WriteUint64(_tmp9) + } + w.ListEnd(_tmp8) + } + if _tmp5 || _tmp6 || _tmp7 { + w.WriteBytes(obj.Array[:]) + } + if _tmp6 || _tmp7 { + _tmp10 := w.List() + w.WriteUint64(obj.NamedStruct.A) + w.ListEnd(_tmp10) + } + if _tmp7 { + _tmp11 := w.List() + w.WriteString(obj.AnonStruct.A) + w.ListEnd(_tmp11) + } + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Uint64: + if dec.MoreDataInList() { + _tmp1, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Uint64 = _tmp1 + // Pointer: + if dec.MoreDataInList() { + _tmp2, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.Pointer = &_tmp2 + // String: + if dec.MoreDataInList() { + _tmp3, err := dec.String() + if err != nil { + return err + } + _tmp0.String = _tmp3 + // Slice: + if dec.MoreDataInList() { + var _tmp4 []uint64 + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp5, err := dec.Uint64() + if err != nil { + return err + } + _tmp4 = append(_tmp4, _tmp5) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.Slice = _tmp4 + // Array: + if dec.MoreDataInList() { + var _tmp6 [3]byte + if err := dec.ReadBytes(_tmp6[:]); err != nil { + return err + } + _tmp0.Array = _tmp6 + // NamedStruct: + if dec.MoreDataInList() { + var _tmp7 Aux + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp8, err := dec.Uint64() + if err != nil { + return err + } + _tmp7.A = _tmp8 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.NamedStruct = _tmp7 + // AnonStruct: + if dec.MoreDataInList() { + var _tmp9 struct{ A string } + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp10, err := dec.String() + if err != nil { + return err + } + _tmp9.A = _tmp10 + if err := dec.ListEnd(); err != nil { + return err + } + } + _tmp0.AnonStruct = _tmp9 + } + } + } + } + } + } + } + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/rawvalue.in.txt b/rlp/rlpgen/testdata/rawvalue.in.txt new file mode 100644 index 000000000000..daa050c3f65e --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.in.txt @@ -0,0 +1,11 @@ +// -*- mode: go -*- + +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" + +type Test struct { + RawValue rlp.RawValue + PointerToRawValue *rlp.RawValue + SliceOfRawValue []rlp.RawValue +} diff --git a/rlp/rlpgen/testdata/rawvalue.out.txt b/rlp/rlpgen/testdata/rawvalue.out.txt new file mode 100644 index 000000000000..35bf145dcc71 --- /dev/null +++ b/rlp/rlpgen/testdata/rawvalue.out.txt @@ -0,0 +1,64 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.Write(obj.RawValue) + if obj.PointerToRawValue == nil { + w.Write([]byte{0x80}) + } else { + w.Write((*obj.PointerToRawValue)) + } + _tmp1 := w.List() + for _, _tmp2 := range obj.SliceOfRawValue { + w.Write(_tmp2) + } + w.ListEnd(_tmp1) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // RawValue: + _tmp1, err := dec.Raw() + if err != nil { + return err + } + _tmp0.RawValue = _tmp1 + // PointerToRawValue: + _tmp2, err := dec.Raw() + if err != nil { + return err + } + _tmp0.PointerToRawValue = &_tmp2 + // SliceOfRawValue: + var _tmp3 []rlp.RawValue + if _, err := dec.List(); err != nil { + return err + } + for dec.MoreDataInList() { + _tmp4, err := dec.Raw() + if err != nil { + return err + } + _tmp3 = append(_tmp3, _tmp4) + } + if err := dec.ListEnd(); err != nil { + return err + } + _tmp0.SliceOfRawValue = _tmp3 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uint256.in.txt b/rlp/rlpgen/testdata/uint256.in.txt new file mode 100644 index 000000000000..ed16e0a7882f --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +import "github.com/holiman/uint256" + +type Test struct { + Int *uint256.Int + IntNoPtr uint256.Int +} diff --git a/rlp/rlpgen/testdata/uint256.out.txt b/rlp/rlpgen/testdata/uint256.out.txt new file mode 100644 index 000000000000..b560aa0e42fe --- /dev/null +++ b/rlp/rlpgen/testdata/uint256.out.txt @@ -0,0 +1,44 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "github.com/holiman/uint256" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + if obj.Int == nil { + w.Write(rlp.EmptyString) + } else { + w.WriteUint256(obj.Int) + } + w.WriteUint256(&obj.IntNoPtr) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // Int: + var _tmp1 uint256.Int + if err := dec.ReadUint256(&_tmp1); err != nil { + return err + } + _tmp0.Int = &_tmp1 + // IntNoPtr: + var _tmp2 uint256.Int + if err := dec.ReadUint256(&_tmp2); err != nil { + return err + } + _tmp0.IntNoPtr = _tmp2 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/testdata/uints.in.txt b/rlp/rlpgen/testdata/uints.in.txt new file mode 100644 index 000000000000..8095da997d96 --- /dev/null +++ b/rlp/rlpgen/testdata/uints.in.txt @@ -0,0 +1,10 @@ +// -*- mode: go -*- + +package test + +type Test struct{ + A uint8 + B uint16 + C uint32 + D uint64 +} diff --git a/rlp/rlpgen/testdata/uints.out.txt b/rlp/rlpgen/testdata/uints.out.txt new file mode 100644 index 000000000000..cf973ec9a43b --- /dev/null +++ b/rlp/rlpgen/testdata/uints.out.txt @@ -0,0 +1,53 @@ +package test + +import "github.com/XinFinOrg/XDPoSChain/rlp" +import "io" + +func (obj *Test) EncodeRLP(_w io.Writer) error { + w := rlp.NewEncoderBuffer(_w) + _tmp0 := w.List() + w.WriteUint64(uint64(obj.A)) + w.WriteUint64(uint64(obj.B)) + w.WriteUint64(uint64(obj.C)) + w.WriteUint64(obj.D) + w.ListEnd(_tmp0) + return w.Flush() +} + +func (obj *Test) DecodeRLP(dec *rlp.Stream) error { + var _tmp0 Test + { + if _, err := dec.List(); err != nil { + return err + } + // A: + _tmp1, err := dec.Uint8() + if err != nil { + return err + } + _tmp0.A = _tmp1 + // B: + _tmp2, err := dec.Uint16() + if err != nil { + return err + } + _tmp0.B = _tmp2 + // C: + _tmp3, err := dec.Uint32() + if err != nil { + return err + } + _tmp0.C = _tmp3 + // D: + _tmp4, err := dec.Uint64() + if err != nil { + return err + } + _tmp0.D = _tmp4 + if err := dec.ListEnd(); err != nil { + return err + } + } + *obj = _tmp0 + return nil +} diff --git a/rlp/rlpgen/types.go b/rlp/rlpgen/types.go new file mode 100644 index 000000000000..ea7dc96d8813 --- /dev/null +++ b/rlp/rlpgen/types.go @@ -0,0 +1,124 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package main + +import ( + "fmt" + "go/types" + "reflect" +) + +// typeReflectKind gives the reflect.Kind that represents typ. +func typeReflectKind(typ types.Type) reflect.Kind { + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + if k >= types.Bool && k <= types.Complex128 { + // value order matches for Bool..Complex128 + return reflect.Bool + reflect.Kind(k-types.Bool) + } + if k == types.String { + return reflect.String + } + if k == types.UnsafePointer { + return reflect.UnsafePointer + } + panic(fmt.Errorf("unhandled BasicKind %v", k)) + case *types.Array: + return reflect.Array + case *types.Chan: + return reflect.Chan + case *types.Interface: + return reflect.Interface + case *types.Map: + return reflect.Map + case *types.Pointer: + return reflect.Ptr + case *types.Signature: + return reflect.Func + case *types.Slice: + return reflect.Slice + case *types.Struct: + return reflect.Struct + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// nonZeroCheck returns the expression that checks whether 'v' is a non-zero value of type 'vtyp'. +func nonZeroCheck(v string, vtyp types.Type, qualify types.Qualifier) string { + // Resolve type name. + typ := resolveUnderlying(vtyp) + switch typ := typ.(type) { + case *types.Basic: + k := typ.Kind() + switch { + case k == types.Bool: + return v + case k >= types.Uint && k <= types.Complex128: + return fmt.Sprintf("%s != 0", v) + case k == types.String: + return fmt.Sprintf(`%s != ""`, v) + default: + panic(fmt.Errorf("unhandled BasicKind %v", k)) + } + case *types.Array, *types.Struct: + return fmt.Sprintf("%s != (%s{})", v, types.TypeString(vtyp, qualify)) + case *types.Interface, *types.Pointer, *types.Signature: + return fmt.Sprintf("%s != nil", v) + case *types.Slice, *types.Map: + return fmt.Sprintf("len(%s) > 0", v) + default: + panic(fmt.Errorf("unhandled type %T", typ)) + } +} + +// isBigInt checks whether 'typ' is "math/big".Int. +func isBigInt(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "math/big" && name.Name() == "Int" +} + +// isUint256 checks whether 'typ' is "github.com/holiman/uint256".Int. +func isUint256(typ types.Type) bool { + named, ok := typ.(*types.Named) + if !ok { + return false + } + name := named.Obj() + return name.Pkg().Path() == "github.com/holiman/uint256" && name.Name() == "Int" +} + +// isByte checks whether the underlying type of 'typ' is uint8. +func isByte(typ types.Type) bool { + basic, ok := resolveUnderlying(typ).(*types.Basic) + return ok && basic.Kind() == types.Uint8 +} + +func resolveUnderlying(typ types.Type) types.Type { + for { + t := typ.Underlying() + if t == typ { + return t + } + typ = t + } +} diff --git a/rlp/safe.go b/rlp/safe.go new file mode 100644 index 000000000000..3c910337b6a2 --- /dev/null +++ b/rlp/safe.go @@ -0,0 +1,27 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build nacl || js || !cgo +// +build nacl js !cgo + +package rlp + +import "reflect" + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + return v.Slice(0, length).Bytes() +} diff --git a/rlp/typecache.go b/rlp/typecache.go index 3df799e1ecd5..57e4f2e46a06 100644 --- a/rlp/typecache.go +++ b/rlp/typecache.go @@ -18,139 +18,221 @@ package rlp import ( "fmt" + "maps" "reflect" - "strings" "sync" -) + "sync/atomic" -var ( - typeCacheMutex sync.RWMutex - typeCache = make(map[typekey]*typeinfo) + "github.com/XinFinOrg/XDPoSChain/rlp/internal/rlpstruct" ) +// typeinfo is an entry in the type cache. type typeinfo struct { - decoder - writer -} - -// represents struct tags -type tags struct { - // rlp:"nil" controls whether empty input results in a nil pointer. - nilOK bool - // rlp:"tail" controls whether this field swallows additional list - // elements. It can only be set for the last field, which must be - // of slice type. - tail bool - // rlp:"-" ignores fields. - ignored bool + decoder decoder + decoderErr error // error from makeDecoder + writer writer + writerErr error // error from makeWriter } +// typekey is the key of a type in typeCache. It includes the struct tags because +// they might generate a different decoder. type typekey struct { reflect.Type - // the key must include the struct tags because they - // might generate a different decoder. - tags + rlpstruct.Tags } type decoder func(*Stream, reflect.Value) error -type writer func(reflect.Value, *encbuf) error +type writer func(reflect.Value, *encBuffer) error + +var theTC = newTypeCache() + +type typeCache struct { + cur atomic.Value + + // This lock synchronizes writers. + mu sync.Mutex + next map[typekey]*typeinfo +} + +func newTypeCache() *typeCache { + c := new(typeCache) + c.cur.Store(make(map[typekey]*typeinfo)) + return c +} + +func cachedDecoder(typ reflect.Type) (decoder, error) { + info := theTC.info(typ) + return info.decoder, info.decoderErr +} + +func cachedWriter(typ reflect.Type) (writer, error) { + info := theTC.info(typ) + return info.writer, info.writerErr +} -func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) { - typeCacheMutex.RLock() - info := typeCache[typekey{typ, tags}] - typeCacheMutex.RUnlock() - if info != nil { - return info, nil +func (c *typeCache) info(typ reflect.Type) *typeinfo { + key := typekey{Type: typ} + if info := c.cur.Load().(map[typekey]*typeinfo)[key]; info != nil { + return info } - // not in the cache, need to generate info for this type. - typeCacheMutex.Lock() - defer typeCacheMutex.Unlock() - return cachedTypeInfo1(typ, tags) + + // Not in the cache, need to generate info for this type. + return c.generate(typ, rlpstruct.Tags{}) +} + +func (c *typeCache) generate(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { + c.mu.Lock() + defer c.mu.Unlock() + + cur := c.cur.Load().(map[typekey]*typeinfo) + if info := cur[typekey{typ, tags}]; info != nil { + return info + } + + // Copy cur to next. + c.next = maps.Clone(cur) + + // Generate. + info := c.infoWhileGenerating(typ, tags) + + // next -> cur + c.cur.Store(c.next) + c.next = nil + return info } -func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) { +func (c *typeCache) infoWhileGenerating(typ reflect.Type, tags rlpstruct.Tags) *typeinfo { key := typekey{typ, tags} - info := typeCache[key] - if info != nil { - // another goroutine got the write lock first - return info, nil + if info := c.next[key]; info != nil { + return info } - // put a dummmy value into the cache before generating. - // if the generator tries to lookup itself, it will get + // Put a dummy value into the cache before generating. + // If the generator tries to lookup itself, it will get // the dummy value and won't call itself recursively. - typeCache[key] = new(typeinfo) - info, err := genTypeInfo(typ, tags) - if err != nil { - // remove the dummy value if the generator fails - delete(typeCache, key) - return nil, err - } - *typeCache[key] = *info - return typeCache[key], err + info := new(typeinfo) + c.next[key] = info + info.generate(typ, tags) + return info } type field struct { - index int - info *typeinfo + index int + info *typeinfo + optional bool } +// structFields resolves the typeinfo of all public fields in a struct type. func structFields(typ reflect.Type) (fields []field, err error) { + // Convert fields to rlpstruct.Field. + var allStructFields []rlpstruct.Field for i := 0; i < typ.NumField(); i++ { - if f := typ.Field(i); f.PkgPath == "" { // exported - tags, err := parseStructTag(typ, i) - if err != nil { - return nil, err - } - if tags.ignored { - continue - } - info, err := cachedTypeInfo1(f.Type, tags) - if err != nil { - return nil, err - } - fields = append(fields, field{i, info}) + rf := typ.Field(i) + allStructFields = append(allStructFields, rlpstruct.Field{ + Name: rf.Name, + Index: i, + Exported: rf.PkgPath == "", + Tag: string(rf.Tag), + Type: *rtypeToStructType(rf.Type, nil), + }) + } + + // Filter/validate fields. + structFields, structTags, err := rlpstruct.ProcessFields(allStructFields) + if err != nil { + if tagErr, ok := err.(rlpstruct.TagError); ok { + tagErr.StructType = typ.String() + return nil, tagErr } + return nil, err + } + + // Resolve typeinfo. + for i, sf := range structFields { + typ := typ.Field(sf.Index).Type + tags := structTags[i] + info := theTC.infoWhileGenerating(typ, tags) + fields = append(fields, field{sf.Index, info, tags.Optional}) } return fields, nil } -func parseStructTag(typ reflect.Type, fi int) (tags, error) { - f := typ.Field(fi) - var ts tags - for _, t := range strings.Split(f.Tag.Get("rlp"), ",") { - switch t = strings.TrimSpace(t); t { - case "": - case "-": - ts.ignored = true - case "nil": - ts.nilOK = true - case "tail": - ts.tail = true - if fi != typ.NumField()-1 { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (must be on last field)`, typ, f.Name) - } - if f.Type.Kind() != reflect.Slice { - return ts, fmt.Errorf(`rlp: invalid struct tag "tail" for %v.%s (field type is not slice)`, typ, f.Name) - } - default: - return ts, fmt.Errorf("rlp: unknown struct tag %q on %v.%s", t, typ, f.Name) +// firstOptionalField returns the index of the first field with "optional" tag. +func firstOptionalField(fields []field) int { + for i, f := range fields { + if f.optional { + return i } } - return ts, nil + return len(fields) } -func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) { - info = new(typeinfo) - if info.decoder, err = makeDecoder(typ, tags); err != nil { - return nil, err +type structFieldError struct { + typ reflect.Type + field int + err error +} + +func (e structFieldError) Error() string { + return fmt.Sprintf("%v (struct field %v.%s)", e.err, e.typ, e.typ.Field(e.field).Name) +} + +func (i *typeinfo) generate(typ reflect.Type, tags rlpstruct.Tags) { + i.decoder, i.decoderErr = makeDecoder(typ, tags) + i.writer, i.writerErr = makeWriter(typ, tags) +} + +// rtypeToStructType converts typ to rlpstruct.Type. +func rtypeToStructType(typ reflect.Type, rec map[reflect.Type]*rlpstruct.Type) *rlpstruct.Type { + k := typ.Kind() + if k == reflect.Invalid { + panic("invalid kind") } - if info.writer, err = makeWriter(typ, tags); err != nil { - return nil, err + + if prev := rec[typ]; prev != nil { + return prev // short-circuit for recursive types + } + if rec == nil { + rec = make(map[reflect.Type]*rlpstruct.Type) + } + + t := &rlpstruct.Type{ + Name: typ.String(), + Kind: k, + IsEncoder: typ.Implements(encoderInterface), + IsDecoder: typ.Implements(decoderInterface), + } + rec[typ] = t + if k == reflect.Array || k == reflect.Slice || k == reflect.Ptr { + t.Elem = rtypeToStructType(typ.Elem(), rec) + } + return t +} + +// typeNilKind gives the RLP value kind for nil pointers to 'typ'. +func typeNilKind(typ reflect.Type, tags rlpstruct.Tags) Kind { + styp := rtypeToStructType(typ, nil) + + var nk rlpstruct.NilKind + if tags.NilOK { + nk = tags.NilKind + } else { + nk = styp.DefaultNilValue() + } + switch nk { + case rlpstruct.NilKindString: + return String + case rlpstruct.NilKindList: + return List + default: + panic("invalid nil kind value") } - return info, nil } func isUint(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uintptr } + +func isByte(typ reflect.Type) bool { + return typ.Kind() == reflect.Uint8 && !typ.Implements(encoderInterface) +} diff --git a/rlp/unsafe.go b/rlp/unsafe.go new file mode 100644 index 000000000000..10868caaf287 --- /dev/null +++ b/rlp/unsafe.go @@ -0,0 +1,30 @@ +// Copyright 2021 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +//go:build !nacl && !js && cgo +// +build !nacl,!js,cgo + +package rlp + +import ( + "reflect" + "unsafe" +) + +// byteArrayBytes returns a slice of the byte array v. +func byteArrayBytes(v reflect.Value, length int) []byte { + return unsafe.Slice((*byte)(unsafe.Pointer(v.UnsafeAddr())), length) +} diff --git a/rpc/metrics.go b/rpc/metrics.go index 7fb6fc0a17f9..ebb407fa3dad 100644 --- a/rpc/metrics.go +++ b/rpc/metrics.go @@ -19,7 +19,7 @@ package rpc import ( "fmt" - "github.com/ethereum/go-ethereum/metrics" + "github.com/XinFinOrg/XDPoSChain/metrics" ) var ( diff --git a/swarm/services/swap/swap.go b/swarm/services/swap/swap.go index be595b710b70..153f058968a0 100644 --- a/swarm/services/swap/swap.go +++ b/swarm/services/swap/swap.go @@ -80,7 +80,7 @@ type PayProfile struct { lock sync.RWMutex } -//create params with default values +// create params with default values func NewDefaultSwapParams() *SwapParams { return &SwapParams{ PayProfile: &PayProfile{}, @@ -102,8 +102,8 @@ func NewDefaultSwapParams() *SwapParams { } } -//this can only finally be set after all config options (file, cmd line, env vars) -//have been evaluated +// this can only finally be set after all config options (file, cmd line, env vars) +// have been evaluated func (self *SwapParams) Init(contract common.Address, prvkey *ecdsa.PrivateKey) { pubkey := &prvkey.PublicKey @@ -141,8 +141,12 @@ func NewSwap(local *SwapParams, remote *SwapProfile, backend chequebook.Backend, if !ok { log.Info(fmt.Sprintf("invalid contract %v for peer %v: %v)", remote.Contract.Hex()[:8], proto, err)) } else { + pub, err := crypto.UnmarshalPubkey(common.FromHex(remote.PublicKey)) + if err != nil { + return nil, err + } // remote contract valid, create inbox - in, err = chequebook.NewInbox(local.privateKey, remote.Contract, local.Beneficiary, crypto.ToECDSAPub(common.FromHex(remote.PublicKey)), backend) + in, err = chequebook.NewInbox(local.privateKey, remote.Contract, local.Beneficiary, pub, backend) if err != nil { log.Warn(fmt.Sprintf("unable to set up inbox for chequebook contract %v for peer %v: %v)", remote.Contract.Hex()[:8], proto, err)) } diff --git a/tests/state_test_util.go b/tests/state_test_util.go index f0db5bd40eae..742d2c595544 100644 --- a/tests/state_test_util.go +++ b/tests/state_test_util.go @@ -135,13 +135,16 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD if err != nil { return nil, err } + + // Prepare the EVM. context := core.NewEVMContext(msg, block.Header(), nil, &t.json.Env.Coinbase) context.GetHash = vmTestBlockHash evm := vm.NewEVM(context, statedb, nil, config, vmconfig) + // Execute the message. + snapshot := statedb.Snapshot() gaspool := new(core.GasPool) gaspool.AddGas(block.GasLimit()) - snapshot := statedb.Snapshot() coinbase := &t.json.Env.Coinbase if _, _, _, err, _ := core.ApplyMessage(evm, msg, gaspool, *coinbase); err != nil { @@ -150,6 +153,8 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) { return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs) } + + // Commit block root, _ := statedb.Commit(config.IsEIP158(block.Number())) if root != common.Hash(post.Root) { return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root) @@ -235,7 +240,7 @@ func (tx *stTransaction) toMessage(ps stPostState, number *big.Int) (core.Messag if err != nil { return nil, fmt.Errorf("invalid tx data %q", dataHex) } - msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, true, nil, number) + msg := types.NewMessage(from, to, tx.Nonce, value, gasLimit, tx.GasPrice, data, nil, true, nil, number) return msg, nil } diff --git a/trie/committer.go b/trie/committer.go index 9db314d98045..435da5198165 100644 --- a/trie/committer.go +++ b/trie/committer.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/XinFinOrg/XDPoSChain/common" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" "golang.org/x/crypto/sha3" ) @@ -46,7 +47,7 @@ type leaf struct { // processed sequentially - onleaf will never be called in parallel or out of order. type committer struct { tmp sliceBuffer - sha keccakState + sha crypto.KeccakState onleaf LeafCallback leafCh chan *leaf @@ -57,7 +58,7 @@ var committerPool = sync.Pool{ New: func() interface{} { return &committer{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(crypto.KeccakState), } }, } diff --git a/trie/hasher.go b/trie/hasher.go index a2b385ad5b47..1dc9aae689fe 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -17,21 +17,13 @@ package trie import ( - "hash" "sync" + "github.com/XinFinOrg/XDPoSChain/crypto" "github.com/XinFinOrg/XDPoSChain/rlp" "golang.org/x/crypto/sha3" ) -// keccakState wraps sha3.state. In addition to the usual hash methods, it also supports -// Read to get a variable amount of data from the hash state. Read is faster than Sum -// because it doesn't copy the internal state, but also modifies the internal state. -type keccakState interface { - hash.Hash - Read([]byte) (int, error) -} - type sliceBuffer []byte func (b *sliceBuffer) Write(data []byte) (n int, err error) { @@ -46,7 +38,7 @@ func (b *sliceBuffer) Reset() { // hasher is a type used for the trie Hash operation. A hasher has some // internal preallocated temp space type hasher struct { - sha keccakState + sha crypto.KeccakState tmp sliceBuffer parallel bool // Whether to use paralallel threads when hashing } @@ -56,7 +48,7 @@ var hasherPool = sync.Pool{ New: func() interface{} { return &hasher{ tmp: make(sliceBuffer, 0, 550), // cap is as large as a full FullNode. - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(crypto.KeccakState), } }, } diff --git a/whisper/whisperv5/api.go b/whisper/whisperv5/api.go index 89e9c2860b85..37c04e70aada 100644 --- a/whisper/whisperv5/api.go +++ b/whisper/whisperv5/api.go @@ -256,8 +256,7 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // Set asymmetric key that is used to encrypt the message if pubKeyGiven { - params.Dst = crypto.ToECDSAPub(req.PublicKey) - if !ValidatePublicKey(params.Dst) { + if params.Dst, err = crypto.UnmarshalPubkey(req.PublicKey); err != nil { return false, ErrInvalidPublicKey } } @@ -333,8 +332,7 @@ func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc. } if len(crit.Sig) > 0 { - filter.Src = crypto.ToECDSAPub(crit.Sig) - if !ValidatePublicKey(filter.Src) { + if filter.Src, err = crypto.UnmarshalPubkey(crit.Sig); err != nil { return nil, ErrInvalidSigningPubKey } } @@ -517,8 +515,7 @@ func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { } if len(req.Sig) > 0 { - src = crypto.ToECDSAPub(req.Sig) - if !ValidatePublicKey(src) { + if src, err = crypto.UnmarshalPubkey(req.Sig); err != nil { return "", ErrInvalidSigningPubKey } } diff --git a/whisper/whisperv6/api.go b/whisper/whisperv6/api.go index 95106ee16713..0ea7e0fc524b 100644 --- a/whisper/whisperv6/api.go +++ b/whisper/whisperv6/api.go @@ -275,8 +275,7 @@ func (api *PublicWhisperAPI) Post(ctx context.Context, req NewMessage) (bool, er // Set asymmetric key that is used to encrypt the message if pubKeyGiven { - params.Dst = crypto.ToECDSAPub(req.PublicKey) - if !ValidatePublicKey(params.Dst) { + if params.Dst, err = crypto.UnmarshalPubkey(req.PublicKey); err != nil { return false, ErrInvalidPublicKey } } @@ -352,8 +351,7 @@ func (api *PublicWhisperAPI) Messages(ctx context.Context, crit Criteria) (*rpc. } if len(crit.Sig) > 0 { - filter.Src = crypto.ToECDSAPub(crit.Sig) - if !ValidatePublicKey(filter.Src) { + if filter.Src, err = crypto.UnmarshalPubkey(crit.Sig); err != nil { return nil, ErrInvalidSigningPubKey } } @@ -536,8 +534,7 @@ func (api *PublicWhisperAPI) NewMessageFilter(req Criteria) (string, error) { } if len(req.Sig) > 0 { - src = crypto.ToECDSAPub(req.Sig) - if !ValidatePublicKey(src) { + if src, err = crypto.UnmarshalPubkey(req.Sig); err != nil { return "", ErrInvalidSigningPubKey } }