Skip to content

Commit

Permalink
fix[codegen]: fix double evals in sqrt, slice, blueprint (#3976)
Browse files Browse the repository at this point in the history
cache respective args, and add new tests for side-effect evaluations for
the respective builtins

---------

Co-authored-by: Charles Cooper <cooper.charles.m@gmail.com>
  • Loading branch information
cyberthirst and charles-cooper authored May 16, 2024
1 parent 98370f5 commit 0453f63
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 64 deletions.
39 changes: 39 additions & 0 deletions tests/functional/builtins/codegen/test_create_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,45 @@ def test(target: address):
assert test.foo() == 12


def test_blueprint_evals_once_side_effects(get_contract, deploy_blueprint_for, env):
# test msize allocator does not get trampled by salt= kwarg
code = """
foo: public(uint256)
"""

deployer_code = """
created_address: public(address)
deployed: public(uint256)
@external
def get() -> Bytes[32]:
self.deployed += 1
return b''
@external
def create_(target: address):
self.created_address = create_from_blueprint(
target,
raw_call(self, method_id("get()"), max_outsize=32),
raw_args=True, code_offset=3
)
"""

foo_contract = get_contract(code)
expected_runtime_code = env.get_code(foo_contract.address)

f, FooContract = deploy_blueprint_for(code)

d = get_contract(deployer_code)

d.create_(f.address)

test = FooContract(d.created_address())
assert env.get_code(test.address) == expected_runtime_code
assert test.foo() == 0
assert d.deployed() == 1


def test_create_copy_of_complex_kwargs(get_contract, env):
# test msize allocator does not get trampled by salt= kwarg
complex_salt = """
Expand Down
26 changes: 26 additions & 0 deletions tests/functional/builtins/codegen/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,29 @@ def test_slice_buffer_oob_reverts(bad_code, get_contract, tx_failed):
c = get_contract(bad_code)
with tx_failed():
c.do_slice()


# tests all 3 adhoc locations: `msg.data`, `self.code`, `<address>.code`
@pytest.mark.parametrize("adhoc_loc", ["msg.data", "self.code", "a.code"])
def test_slice_start_eval_once(get_contract, adhoc_loc):
code = f"""
counter: uint256
@internal
def bar() -> uint256:
self.counter += 1
return 1
@external
def foo(cs: String[64]) -> uint256:
s: Bytes[64] = b""
# use `a` to exercise the path with `<address>.code`
a: address = self
s = slice({adhoc_loc}, self.bar(), 3)
return self.counter
"""

arg = "a" * 64
c = get_contract(code)
# ensure that counter was incremented only once
assert c.foo(arg) == 1
19 changes: 19 additions & 0 deletions tests/functional/codegen/types/numbers/test_sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,22 @@ def test_sqrt_valid_range(sqrt_contract, value):
def test_sqrt_invalid_range(tx_failed, sqrt_contract, value):
with tx_failed():
sqrt_contract.test(decimal_to_int(value))


def test_sqrt_eval_once(get_contract):
code = """
c: uint256
@internal
def some_decimal() -> decimal:
self.c += 1
return 1.0
@external
def foo() -> uint256:
k: decimal = sqrt(self.some_decimal())
return self.c
"""

c = get_contract(code)
assert c.foo() == 1
142 changes: 78 additions & 64 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,44 +252,46 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
# allocate a buffer for the return value
buf = context.new_internal_variable(dst_typ)

# `msg.data` by `calldatacopy`
if sub.value == "~calldata":
node = [
"seq",
_make_slice_bounds_check(start, length, "calldatasize"),
["mstore", buf, length],
["calldatacopy", add_ofst(buf, 32), start, length],
buf,
]

# `self.code` by `codecopy`
elif sub.value == "~selfcode":
node = [
"seq",
_make_slice_bounds_check(start, length, "codesize"),
["mstore", buf, length],
["codecopy", add_ofst(buf, 32), start, length],
buf,
]
with scope_multi((start, length), ("start", "length")) as (b1, (start, length)):
# `msg.data` by `calldatacopy`
if sub.value == "~calldata":
node = [
"seq",
_make_slice_bounds_check(start, length, "calldatasize"),
["mstore", buf, length],
["calldatacopy", add_ofst(buf, 32), start, length],
buf,
]

# `<address>.code` by `extcodecopy`
else:
assert sub.value == "~extcode" and len(sub.args) == 1
node = [
"with",
"_extcode_address",
sub.args[0],
[
# `self.code` by `codecopy`
elif sub.value == "~selfcode":
node = [
"seq",
_make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
_make_slice_bounds_check(start, length, "codesize"),
["mstore", buf, length],
["extcodecopy", "_extcode_address", add_ofst(buf, 32), start, length],
["codecopy", add_ofst(buf, 32), start, length],
buf,
],
]
]

assert isinstance(length.value, int) # mypy hint
return IRnode.from_list(node, typ=BytesT(length.value), location=MEMORY)
# `<address>.code` by `extcodecopy`
else:
assert sub.value == "~extcode" and len(sub.args) == 1
node = [
"with",
"_extcode_address",
sub.args[0],
[
"seq",
_make_slice_bounds_check(start, length, ["extcodesize", "_extcode_address"]),
["mstore", buf, length],
["extcodecopy", "_extcode_address", add_ofst(buf, 32), start, length],
buf,
],
]

assert isinstance(length.value, int) # mypy hint
ret = IRnode.from_list(node, typ=BytesT(length.value), location=MEMORY)
return b1.resolve(ret)


# note: this and a lot of other builtins could be refactored to accept any uint type
Expand Down Expand Up @@ -1816,9 +1818,15 @@ def _build_create_IR(
if len(ctor_args) != 1 or not isinstance(ctor_args[0].typ, BytesT):
raise StructureException("raw_args must be used with exactly 1 bytes argument")

argbuf = bytes_data_ptr(ctor_args[0])
argslen = get_bytearray_length(ctor_args[0])
bufsz = ctor_args[0].typ.maxlen
with ctor_args[0].cache_when_complex("arg") as (b1, arg):
argbuf = bytes_data_ptr(arg)
argslen = get_bytearray_length(arg)
bufsz = arg.typ.maxlen
return b1.resolve(
self._helper(
argbuf, bufsz, target, value, salt, argslen, code_offset, revert_on_failure
)
)
else:
# encode the varargs
to_encode = ir_tuple_from_args(ctor_args)
Expand All @@ -1831,7 +1839,11 @@ def _build_create_IR(
# return a complex expression which writes to memory and returns
# the length of the encoded data
argslen = abi_encode(argbuf, to_encode, context, bufsz=bufsz, returns_len=True)
return self._helper(
argbuf, bufsz, target, value, salt, argslen, code_offset, revert_on_failure
)

def _helper(self, argbuf, bufsz, target, value, salt, argslen, code_offset, revert_on_failure):
# NOTE: we need to invoke the abi encoder before evaluating MSIZE,
# then copy the abi encoded buffer to past-the-end of the initcode
# (since the abi encoder could write to fresh memory).
Expand Down Expand Up @@ -2118,7 +2130,8 @@ def build_IR(self, expr, args, kwargs, context):

arg = args[0]
# TODO: reify decimal and integer sqrt paths (see isqrt)
sqrt_code = """
with arg.cache_when_complex("x") as (b1, arg):
sqrt_code = """
assert x >= 0.0
z: decimal = 0.0
Expand All @@ -2133,33 +2146,34 @@ def build_IR(self, expr, args, kwargs, context):
break
y = z
z = (x / z + z) / 2.0
"""

x_type = DecimalT()
placeholder_copy = ["pass"]
# Steal current position if variable is already allocated.
if arg.value == "mload":
new_var_pos = arg.args[0]
# Other locations need to be copied.
else:
new_var_pos = context.new_internal_variable(x_type)
placeholder_copy = ["mstore", new_var_pos, arg]
# Create input variables.
variables = {"x": VariableRecord(name="x", pos=new_var_pos, typ=x_type, mutable=False)}
# Dictionary to update new (i.e. typecheck) namespace
variables_2 = {"x": VarInfo(DecimalT())}
# Generate inline IR.
new_ctx, sqrt_ir = generate_inline_function(
code=sqrt_code,
variables=variables,
variables_2=variables_2,
memory_allocator=context.memory_allocator,
)
return IRnode.from_list(
["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable
typ=DecimalT(),
location=MEMORY,
)
"""

x_type = DecimalT()
placeholder_copy = ["pass"]
# Steal current position if variable is already allocated.
if arg.value == "mload":
new_var_pos = arg.args[0]
# Other locations need to be copied.
else:
new_var_pos = context.new_internal_variable(x_type)
placeholder_copy = ["mstore", new_var_pos, arg]
# Create input variables.
variables = {"x": VariableRecord(name="x", pos=new_var_pos, typ=x_type, mutable=False)}
# Dictionary to update new (i.e. typecheck) namespace
variables_2 = {"x": VarInfo(DecimalT())}
# Generate inline IR.
new_ctx, sqrt_ir = generate_inline_function(
code=sqrt_code,
variables=variables,
variables_2=variables_2,
memory_allocator=context.memory_allocator,
)
ret = IRnode.from_list(
["seq", placeholder_copy, sqrt_ir, new_ctx.vars["z"].pos], # load x variable
typ=DecimalT(),
location=MEMORY,
)
return b1.resolve(ret)


class ISqrt(BuiltinFunctionT):
Expand Down

0 comments on commit 0453f63

Please sign in to comment.