Skip to content

Commit

Permalink
Review comments, get rid of env.evm calls
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielSchiavini committed Apr 15, 2024
1 parent 00d1b72 commit 59a78dd
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 107 deletions.
2 changes: 1 addition & 1 deletion boa/contracts/abi/abi_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(
self._abi = abi
self._functions = functions

self._bytecode = self.env.evm.get_code(address)
self._bytecode = self.env.get_code(address)
if not self._bytecode:
warn(
f"Requested {self} but there is no bytecode at that address!",
Expand Down
27 changes: 0 additions & 27 deletions boa/contracts/vyper/decoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,6 @@
)
from vyper.utils import unsigned_to_signed

from boa.vm.utils import ceil32, floor32


# wrap storage in something which looks like memory
class ByteAddressableStorage:
def __init__(self, db, address, key):
self.db = db
self.address = address
self.key = key

def __getitem__(self, subscript):
if isinstance(subscript, slice):
ret = b""
start = subscript.start or 0
stop = subscript.stop
i = self.key + start // 32
while i < self.key + ceil32(stop) // 32:
ret += self.db.get_storage(self.address, i).to_bytes(32, "big")
i += 1

start_ofst = floor32(start)
start -= start_ofst
stop -= start_ofst
return memoryview(ret[start:stop])
else:
raise Exception("Must slice {self}")


class _Struct(dict):
def __init__(self, name, *args, **kwargs):
Expand Down
16 changes: 6 additions & 10 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@
generate_bytecode_for_arbitrary_stmt,
generate_bytecode_for_internal_fn,
)
from boa.contracts.vyper.decoder_utils import (
ByteAddressableStorage,
decode_vyper_object,
)
from boa.contracts.vyper.decoder_utils import decode_vyper_object
from boa.contracts.vyper.event import Event, RawEvent
from boa.contracts.vyper.ir_executor import executor_from_ir
from boa.environment import Env
Expand Down Expand Up @@ -120,7 +117,7 @@ def at(self, address: Any) -> "VyperContract":
address = Address(address)

ret = self.deploy(override_address=address, skip_initcode=True)
bytecode = ret.env.evm.get_code(address)
bytecode = ret.env.get_code(address)

ret._set_bytecode(bytecode)

Expand Down Expand Up @@ -365,8 +362,7 @@ def setpath(lens, path, val):
class StorageVar:
def __init__(self, contract, slot, typ):
self.contract = contract
self.addr = self.contract._address.canonical_address
self.accountdb = contract.env.evm.get_account_db()
self.addr = self.contract._address
self.slot = slot
self.typ = typ

Expand All @@ -375,7 +371,7 @@ def _decode(self, slot, typ, truncate_limit=None):
if truncate_limit is not None and n > truncate_limit:
return None # indicate failure to caller

fakemem = ByteAddressableStorage(self.accountdb, self.addr, slot)
fakemem = self.contract.env.get_storage_slot(self.addr, slot)
return decode_vyper_object(fakemem, typ)

def _dealias(self, maybe_address):
Expand All @@ -387,8 +383,8 @@ def _dealias(self, maybe_address):
def get(self, truncate_limit=None):
if isinstance(self.typ, HashMapT):
ret = {}
for k in self.contract.env.evm.sstore_trace.get(self.addr, {}):
path = unwrap_storage_key(self.contract.env.evm.sha3_trace, k)
for k in self.contract.env.sstore_trace.get(self.addr, {}):
path = unwrap_storage_key(self.contract.env.sha3_trace, k)
if to_int(path[0]) != self.slot:
continue

Expand Down
40 changes: 28 additions & 12 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def __init__(self):
self._contracts = {}
self._code_registry = {}

self.sha3_trace: dict = {}
self.sstore_trace: dict = {}

self._profiled_contracts = {}
self._cached_call_profiles = {}
self._cached_line_profiles = {}
Expand Down Expand Up @@ -69,9 +72,14 @@ def fork_rpc(self, rpc: RPC, reset_traces=True, block_identifier="safe", **kwarg
:param block_identifier: Block identifier to fork from
:param kwargs: Additional arguments for the RPC
"""
# we usually want to reset the trace data structures
# but sometimes don't, give caller the option.
if reset_traces:
self.sha3_trace = {}
self.sstore_trace = {}

self.evm.fork_rpc(
rpc,
reset_traces,
fast_mode_enabled=self._fast_mode_enabled,
block_identifier=block_identifier,
**kwargs,
Expand Down Expand Up @@ -155,12 +163,7 @@ def reset_gas_used(self):
# to the snapshot on exiting the with statement
@contextlib.contextmanager
def anchor(self):
snapshot_id = self.evm.snapshot()
try:
with self.evm.anchor():
yield
finally:
self.evm.revert(snapshot_id)
return self.evm.anchor()

@contextlib.contextmanager
def sender(self, address):
Expand Down Expand Up @@ -210,11 +213,10 @@ def deploy_code(
) -> tuple[Address, bytes]:
sender = self._get_sender(sender)

target_address = (
self.evm.generate_contract_address(sender)
if override_address is None
else Address(override_address)
)
if override_address is None:
target_address = self.evm.generate_create_address(sender)
else:
target_address = Address(override_address)

prefetch_state = self._fork_mode and self._fork_try_prefetch_state
origin = sender # XXX: consider making this parameterizable
Expand Down Expand Up @@ -320,6 +322,20 @@ def _hook_trace_computation(self, computation, contract=None):
child_contract = self._lookup_contract_fast(child.msg.code_address)
self._hook_trace_computation(child, child_contract)

def get_code(self, address):
return self.evm.get_code(Address(address))

def get_storage_slot(self, address, slot):
return self.evm.get_storage_slot(address, slot)

@property
def block_number(self):
return self.evm.block_number

@property
def timestamp(self):
return self.evm.timestamp

# function to time travel
def time_travel(
self,
Expand Down
110 changes: 64 additions & 46 deletions boa/vm/py_evm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import sys
import warnings
from typing import Any, Iterator, Optional, Tuple, Type
from typing import Any, Iterator, Optional, Type

import eth.constants as constants
import eth.tools.builder.chain as chain
Expand All @@ -17,13 +17,11 @@
from eth.db.account import AccountDB
from eth.db.atomic import AtomicDB
from eth.exceptions import Halt
from eth.typing import JournalDBCheckpoint
from eth.vm.code_stream import CodeStream
from eth.vm.gas_meter import allow_negative_refund_strategy
from eth.vm.message import Message
from eth.vm.opcode_values import STOP
from eth.vm.transaction_context import BaseTransactionContext
from eth_typing import Hash32
from eth_utils import setup_DEBUG2_logging

from boa.rpc import RPC
Expand All @@ -32,7 +30,7 @@
from boa.vm.fast_accountdb import patch_pyevm_state_object, unpatch_pyevm_state_object
from boa.vm.fork import AccountDBFork
from boa.vm.gas_meters import GasMeter
from boa.vm.utils import to_bytes, to_int
from boa.vm.utils import ceil32, floor32, to_bytes, to_int


def enable_pyevm_verbose_logging():
Expand Down Expand Up @@ -204,8 +202,8 @@ class Sha3PreimageTracer:

# trace preimages of sha3

def __init__(self, sha3_op, evm):
self.evm = evm
def __init__(self, sha3_op, env):
self.env = env
self.sha3 = sha3_op

def __call__(self, computation):
Expand All @@ -221,21 +219,26 @@ def __call__(self, computation):

image = _stackitem_to_bytes(computation._stack.values[-1])

self.evm._trace_sha3_preimage(preimage, image)
self.env.sha3_trace[preimage] = image


class SstoreTracer:
mnemonic = "SSTORE"

def __init__(self, sstore_op, evm):
self.evm = evm
def __init__(self, sstore_op, env):
self.env = env
self.sstore = sstore_op

def __call__(self, computation):
value, slot = [_stackitem_to_int(t) for t in computation._stack.values[-2:]]
account = computation.msg.storage_address

self.evm._trace_sstore(account, slot)
# we don't want to deal with snapshots/commits/reverts, so just
# register that the slot was touched and downstream can filter
# zero entries.
self.env.sstore_trace[account] = self.env.sstore_trace.get(account, set()) & {
slot
}

# dispatch into py-evm
self.sstore(computation)
Expand Down Expand Up @@ -368,20 +371,10 @@ def __init__(
class PyEVM:
def __init__(self, env, fast_mode_enabled: bool):
self.chain = _make_chain()
self.sha3_trace: dict = {}
self.sstore_trace: dict = {}
self.env = env
self._init_vm(
env, AccountDB, reset_traces=True, fast_mode_enabled=fast_mode_enabled
)
self._init_vm(env, AccountDB, fast_mode_enabled=fast_mode_enabled)

def _init_vm(
self,
env,
account_db_class: Type[AccountDB],
reset_traces: bool,
fast_mode_enabled: bool,
):
def _init_vm(self, env, account_db_class: Type[AccountDB], fast_mode_enabled: bool):
self.vm = self.chain.get_vm()
self.vm.__class__._state_class.account_db_class = account_db_class

Expand All @@ -398,18 +391,9 @@ def _init_vm(

self.vm.state.computation_class = c

# we usually want to reset the trace data structures
# but sometimes don't, give caller the option.
if reset_traces:
self.sha3_trace = {}
self.sstore_trace = {}

# patch in tracing opcodes
c.opcodes[0x20] = Sha3PreimageTracer(c.opcodes[0x20], self)
c.opcodes[0x55] = SstoreTracer(c.opcodes[0x55], self)

def _trace_sha3_preimage(self, preimage, image):
self.sha3_trace[image] = preimage
c.opcodes[0x20] = Sha3PreimageTracer(c.opcodes[0x20], env)
c.opcodes[0x55] = SstoreTracer(c.opcodes[0x55], env)

def _trace_sstore(self, account, slot):
self.sstore_trace.setdefault(account, set())
Expand All @@ -425,15 +409,10 @@ def enable_fast_mode(self, flag: bool = True):
unpatch_pyevm_state_object(self.vm.state)

def fork_rpc(
self,
rpc: RPC,
reset_traces: bool,
fast_mode_enabled: bool,
block_identifier: str,
**kwargs,
self, rpc: RPC, fast_mode_enabled: bool, block_identifier: str, **kwargs
):
account_db_class = AccountDBFork.class_from_rpc(rpc, block_identifier, **kwargs)
self._init_vm(self.env, account_db_class, reset_traces, fast_mode_enabled)
self._init_vm(self.env, account_db_class, fast_mode_enabled)
block_info = self.vm.state._account_db._block_info

self.vm.patch.timestamp = int(block_info["timestamp"], 16)
Expand Down Expand Up @@ -471,16 +450,21 @@ def get_gas_limit(self):
def reset_access_counters(self):
self.vm.state._account_db._reset_access_counters()

def snapshot(self) -> Tuple[Hash32, JournalDBCheckpoint]:
def snapshot(self) -> Any:
return self.vm.state.snapshot()

def anchor(self):
return self.vm.patch.anchor()
snapshot_id = self.snapshot()
try:
with self.vm.patch.anchor():
yield
finally:
self.revert(snapshot_id)

def revert(self, snapshot_id: Tuple[Hash32, JournalDBCheckpoint]) -> None:
def revert(self, snapshot_id: Any) -> None:
self.vm.state.revert(snapshot_id)

def generate_contract_address(self, sender: Address):
def generate_create_address(self, sender: Address):
nonce = self.vm.state.get_nonce(sender.canonical_address)
self.vm.state.increment_nonce(sender.canonical_address)
return Address(generate_contract_address(sender.canonical_address, nonce))
Expand Down Expand Up @@ -559,12 +543,46 @@ def execute_code(
def block_id(self):
return self.vm.state._account_db._block_id

@property
def block_number(self):
return self.vm.state.block_number

@property
def timestamp(self):
return self.vm.state.timestamp

def get_storage_slot(self, addr: Address, slot: int) -> "ByteAddressableStorage":
account_db = self.vm.state._account_db
return ByteAddressableStorage(account_db, addr, slot)

def time_travel(self, add_seconds: int, add_blocks: int):
self.vm.patch.timestamp += add_seconds
self.vm.patch.block_number += add_blocks

def get_account_db(self):
return self.vm.state._account_db

# wrap storage in something which looks like memory
class ByteAddressableStorage:
def __init__(self, db: AccountDB, address: Address, key: int):
self.db = db
self.address = address.canonical_address
self.key = key

def __getitem__(self, subscript):
if isinstance(subscript, slice):
ret = b""
start = subscript.start or 0
stop = subscript.stop
i = self.key + start // 32
while i < self.key + ceil32(stop) // 32:
ret += self.db.get_storage(self.address, i).to_bytes(32, "big")
i += 1

start_ofst = floor32(start)
start -= start_ofst
stop -= start_ofst
return memoryview(ret[start:stop])
else:
raise Exception("Must slice {self}")


GENESIS_PARAMS = {"difficulty": constants.GENESIS_DIFFICULTY, "gas_limit": int(1e8)}
Expand Down
2 changes: 1 addition & 1 deletion tests/unitary/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_create2_address():
child_contract_address = factory.create_child(blueprint.address, salt)

# TODO: make a util function on boa.env to get code
blueprint_bytecode = boa.env.evm.get_code(blueprint.address)
blueprint_bytecode = boa.env.get_code(blueprint.address)
assert child_contract_address == get_create2_address(
blueprint_bytecode, factory.address, salt
)
Loading

0 comments on commit 59a78dd

Please sign in to comment.