Skip to content

Commit

Permalink
feat: Template variables
Browse files Browse the repository at this point in the history
refactoring template-var feature
  • Loading branch information
tristanmenzel committed Mar 11, 2024
1 parent b860e5a commit 9d93fee
Show file tree
Hide file tree
Showing 42 changed files with 2,022 additions and 14 deletions.
1 change: 1 addition & 0 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ string_ops 157 152 152
stubs/BigUInt 172 112 112
stubs/Bytes 1769 258 258
stubs/Uint64 371 8 8
template_variables/TemplateVariables 168 155 155
too_many_permutations 108 107 107
transaction/Transaction 914 864 864
tuple_support/TupleSupport 442 294 294
Expand Down
10 changes: 9 additions & 1 deletion scripts/compile_all_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,20 @@ class CompilationResult:

def get_program_size(path: Path) -> int:
try:
program = algokit_utils.Program(path.read_text("utf-8"), ALGOD_CLIENT)
program = algokit_utils.Program(replace_templates(path.read_text("utf-8")), ALGOD_CLIENT)
return len(program.raw_binary)
except Exception as e:
raise Exception(f"Error compiling teal application: {path}") from e


def replace_templates(
teal: str, *, uint_replacement: int = 0, bytes_replacement: bytes = b""
) -> str:
teal = re.sub(r"int TMPL_\w+", f"int {uint_replacement}", teal)
teal = re.sub(r"byte TMPL_\w+", f"byte 0x{bytes_replacement.hex()}", teal)
return teal


def _stabilise_logs(stdout: str) -> list[str]:
return [
line.replace("\\", "/").replace(str(GIT_ROOT).replace("\\", "/"), "<git root>")
Expand Down
3 changes: 3 additions & 0 deletions src/puya/awst/function_traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def visit_expression_statement(self, statement: awst_nodes.ExpressionStatement)
def visit_assert_statement(self, statement: awst_nodes.AssertStatement) -> None:
statement.condition.accept(self)

def visit_template_var(self, statement: awst_nodes.TemplateVar) -> None:
pass

def visit_uint64_augmented_assignment(
self, statement: awst_nodes.UInt64AugmentedAssignment
) -> None:
Expand Down
9 changes: 9 additions & 0 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_bytes_constant(self)


@attrs.frozen
class TemplateVar(Expression):
wtype: WType
name: str

def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_template_var(self)


@attrs.frozen
class MethodConstant(Expression):
wtype: WType = attrs.field(default=wtypes.bytes_wtype, init=False)
Expand Down
3 changes: 3 additions & 0 deletions src/puya/awst/to_code_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,9 @@ def visit_state_get(self, expr: nodes.StateGet) -> str:
def visit_state_exists(self, expr: nodes.StateExists) -> str:
return f"STATE_EXISTS({expr.field.accept(self)})"

def visit_template_var(self, expr: nodes.TemplateVar) -> str:
return f"TemplateVar[{expr.wtype}]({expr.name})"


def _indent(lines: t.Iterable[str], indent_size: str = " ") -> t.Iterator[str]:
yield from (f"{indent_size}{line}" for line in lines)
Expand Down
4 changes: 4 additions & 0 deletions src/puya/awst/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,7 @@ def visit_state_get_ex(self, expr: puya.awst.nodes.StateGetEx) -> T:
@abstractmethod
def visit_state_exists(self, expr: puya.awst.nodes.StateExists) -> T:
...

@abstractmethod
def visit_template_var(self, expr: puya.awst.nodes.TemplateVar) -> T:
...
1 change: 1 addition & 0 deletions src/puya/awst_build/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
CLS_ARC4_STATIC_ARRAY = "puyapy.arc4.StaticArray"
CLS_ARC4_TUPLE = "puyapy.arc4.Tuple"
CLS_ARC4_STRUCT = "puyapy.arc4.Struct"
CLS_TEMPLATE_VAR_METHOD = f"{PUYAPY_PREFIX}_template_variables.TemplateVar"

CONTRACT_STUB_TYPES = [
CONTRACT_BASE,
Expand Down
86 changes: 86 additions & 0 deletions src/puya/awst_build/eb/template_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from collections.abc import Sequence

import mypy.nodes

from puya.awst import wtypes
from puya.awst.nodes import Literal, TemplateVar
from puya.awst_build.eb.base import (
ExpressionBuilder,
IntermediateExpressionBuilder,
TypeClassExpressionBuilder,
)
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import get_arg_mapping
from puya.errors import CodeError
from puya.parse import SourceLocation


class GenericTemplateVariableExpressionBuilder(IntermediateExpressionBuilder):
def index_multiple(
self, index: Sequence[ExpressionBuilder | Literal], location: SourceLocation
) -> ExpressionBuilder:
match index:
case [TypeClassExpressionBuilder() as eb]:
wtype = eb.produces()
case _:
raise CodeError("Invalid/unhandled arguments", location)
return TemplateVariableExpressionBuilder(location=location, wtype=wtype)

def index(
self, index: ExpressionBuilder | Literal, location: SourceLocation
) -> ExpressionBuilder:
return self.index_multiple([index], location)


class TemplateVariableExpressionBuilder(TypeClassExpressionBuilder):
def __init__(self, location: SourceLocation, wtype: wtypes.WType):
super().__init__(location)
self.wtype = wtype

def produces(self) -> wtypes.WType:
return self.wtype

def call(
self,
args: Sequence[ExpressionBuilder | Literal],
arg_kinds: list[mypy.nodes.ArgKind],
arg_names: list[str | None],
location: SourceLocation,
original_expr: mypy.nodes.CallExpr,
) -> ExpressionBuilder:
var_name_arg_name = "variable_name"
arg_mapping = get_arg_mapping(
positional_arg_names=[var_name_arg_name],
args=zip(arg_names, args, strict=True),
location=location,
)

try:
var_name = arg_mapping.pop(var_name_arg_name)
except KeyError as ex:
raise CodeError("Required positional argument missing", location) from ex

prefix_arg = arg_mapping.pop("prefix", None)
if arg_mapping:
raise CodeError(
f"Unrecognised keyword argument(s): {", ".join(arg_mapping)}", location
)
match prefix_arg:
case Literal(value=str(prefix_value)):
pass
case None:
prefix_value = "TMPL_"
case _:
raise CodeError("Invalid value for prefix argument", location)

match var_name:
case Literal(value=str(str_value)):
return var_expression(
TemplateVar(
name=prefix_value + str_value, source_location=location, wtype=self.wtype
)
)
case _:
raise CodeError(
"TemplateVars must be declared using a string literal for the variable name"
)
4 changes: 4 additions & 0 deletions src/puya/awst_build/eb/type_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
log,
named_int_constants,
struct,
template_variables,
transaction,
tuple as tuple_,
uint64,
Expand Down Expand Up @@ -65,6 +66,9 @@
constants.CLS_BIGUINT: biguint.BigUIntClassExpressionBuilder,
constants.CLS_BYTES: bytes_.BytesClassExpressionBuilder,
constants.CLS_UINT64: uint64.UInt64ClassExpressionBuilder,
constants.CLS_TEMPLATE_VAR_METHOD: (
template_variables.GenericTemplateVariableExpressionBuilder
),
constants.SUBMIT_TXNS: transaction.SubmitInnerTransactionExpressionBuilder,
constants.CLS_TRANSACTION_BASE: functools.partial(
transaction.GroupTransactionClassExpressionBuilder,
Expand Down
6 changes: 6 additions & 0 deletions src/puya/ir/builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ProgramExit,
Subroutine,
SubroutineReturn,
TemplateVar,
UInt64Constant,
Value,
ValueProvider,
Expand Down Expand Up @@ -717,6 +718,11 @@ def visit_return_statement(self, statement: awst_nodes.ReturnStatement) -> TStat
)
)

def visit_template_var(self, expr: puya.awst.nodes.TemplateVar) -> TExpression:
atype = wtype_to_avm_type(expr.wtype)
typing.assert_type(atype, typing.Literal[AVMType.uint64, AVMType.bytes])
return TemplateVar(name=expr.name, atype=atype, source_location=expr.source_location)

def visit_continue_statement(self, statement: awst_nodes.ContinueStatement) -> TStatement:
self.context.block_builder.loop_continue(statement.source_location)

Expand Down
8 changes: 1 addition & 7 deletions src/puya/ir/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,6 @@ class FoldedContract:
arc4_methods: list[ARC4Method] = attrs.field(factory=list)


def wtype_to_storage_type(wtype: wtypes.WType) -> typing.Literal[AVMType.uint64, AVMType.bytes]:
atype = wtype_to_avm_type(wtype)
assert atype is not AVMType.any
return atype


def fold_state_and_special_methods(
ctx: IRBuildContext, contract: awst_nodes.ContractFragment
) -> FoldedContract:
Expand All @@ -298,7 +292,7 @@ def fold_state_and_special_methods(
name=state.member_name,
source_location=state.source_location,
key=state.key,
storage_type=wtype_to_storage_type(state.storage_wtype),
storage_type=wtype_to_avm_type(state.storage_wtype),
description=state.description,
)
if state.kind == awst_nodes.AppStateKind.app_global:
Expand Down
9 changes: 9 additions & 0 deletions src/puya/ir/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,15 @@ def accept(self, visitor: IRVisitor[T]) -> T:
return visitor.visit_biguint_constant(self)


@attrs.frozen
class TemplateVar(Value):
name: str
atype: AVMType

def accept(self, visitor: IRVisitor[T]) -> T:
return visitor.visit_template_var(self)


@attrs.frozen
class BytesConstant(Constant):
"""Constant for types that are logically bytes"""
Expand Down
3 changes: 3 additions & 0 deletions src/puya/ir/to_text_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def visit_subroutine_return(self, op: models.SubroutineReturn) -> str:
results = " ".join(r.accept(self) for r in op.result)
return f"return {results}"

def visit_template_var(self, deploy_var: models.TemplateVar) -> str:
return f"TemplateVar[{deploy_var.atype}]({deploy_var.name})"

def visit_program_exit(self, op: models.ProgramExit) -> str:
return f"exit {op.result.accept(self)}"

Expand Down
3 changes: 2 additions & 1 deletion src/puya/ir/types_.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import enum
import typing

from puya.avm_type import AVMType
from puya.awst import (
Expand Down Expand Up @@ -30,7 +31,7 @@ def bytes_enc_to_avm_bytes_enc(bytes_encoding: BytesEncoding) -> AVMBytesEncodin
def wtype_to_avm_type(
expr_or_wtype: wtypes.WType | awst_nodes.Expression,
source_location: SourceLocation | None = None,
) -> AVMType:
) -> typing.Literal[AVMType.bytes, AVMType.uint64]:
if isinstance(expr_or_wtype, awst_nodes.Expression):
return wtype_to_avm_type(
expr_or_wtype.wtype, source_location=source_location or expr_or_wtype.source_location
Expand Down
10 changes: 10 additions & 0 deletions src/puya/ir/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def visit_program_exit(self, exit_: puya.ir.models.ProgramExit) -> T:
def visit_fail(self, fail: puya.ir.models.Fail) -> T:
...

@abstractmethod
def visit_template_var(self, deploy_var: puya.ir.models.TemplateVar) -> T:
...


class IRTraverser(IRVisitor[None]):
active_block: puya.ir.models.BasicBlock
Expand Down Expand Up @@ -124,6 +128,9 @@ def visit_bytes_constant(self, const: puya.ir.models.BytesConstant) -> None:
def visit_address_constant(self, const: puya.ir.models.AddressConstant) -> None:
pass

def visit_template_var(self, deploy_var: puya.ir.models.TemplateVar) -> None:
pass

def visit_method_constant(self, const: puya.ir.models.MethodConstant) -> None:
pass

Expand Down Expand Up @@ -199,6 +206,9 @@ def visit_method_constant(self, const: puya.ir.models.MethodConstant) -> T | Non
def visit_phi(self, phi: puya.ir.models.Phi) -> T | None:
return None

def visit_template_var(self, deploy_var: puya.ir.models.TemplateVar) -> T | None:
return None

def visit_phi_argument(self, arg: puya.ir.models.PhiArgument) -> T | None:
return None

Expand Down
4 changes: 4 additions & 0 deletions src/puya/ir/visitor_mutator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Register,
SubroutineReturn,
Switch,
TemplateVar,
UInt64Constant,
ValueTuple,
)
Expand Down Expand Up @@ -75,6 +76,9 @@ def visit_assignment(self, ass: Assignment) -> Assignment | None:
def visit_register(self, reg: Register) -> Register:
return reg

def visit_template_var(self, deploy_var: TemplateVar) -> TemplateVar:
return deploy_var

def visit_uint64_constant(self, const: UInt64Constant) -> UInt64Constant:
return const

Expand Down
9 changes: 9 additions & 0 deletions src/puya/mir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def visit_register(self, reg: ir.Register) -> None:
)
)

def visit_template_var(self, deploy_var: ir.TemplateVar) -> None:
self._add_op(
models.PushTemplateVar(
name=deploy_var.name,
atype=deploy_var.atype,
source_location=deploy_var.source_location,
)
)

def visit_value_tuple(self, tup: ir.ValueTuple) -> None:
raise InternalError(
"Encountered ValueTuple during codegen - should have been eliminated in prior stages",
Expand Down
27 changes: 26 additions & 1 deletion src/puya/mir/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

import attrs

from puya.avm_type import AVMType
from puya.errors import InternalError
from puya.ir.utils import format_bytes

if t.TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence

from puya.avm_type import AVMType
from puya.ir.types_ import AVMBytesEncoding
from puya.mir.visitor import MIRVisitor
from puya.parse import SourceLocation
Expand Down Expand Up @@ -53,6 +53,31 @@ def __str__(self) -> str:
return f"byte {format_bytes(self.value, self.encoding)}"


@attrs.frozen(eq=False)
class PushTemplateVar(BaseOp):
name: str
atype: AVMType = attrs.field()
op_code: str = attrs.field(init=False)

@op_code.default
def _default_opcode(self) -> str:
match self.atype:
case AVMType.bytes:
return "byte"
case AVMType.uint64:
return "int"
case _:
raise InternalError(
f"Unsupported atype for PushTemplateVar: {self.atype}", self.source_location
)

def accept(self, visitor: MIRVisitor[_T]) -> _T:
return visitor.visit_push_deploy_var(self)

def __str__(self) -> str:
return f"{self.op_code} {self.name}"


@attrs.frozen(eq=False)
class PushAddress(BaseOp):
value: str
Expand Down
10 changes: 10 additions & 0 deletions src/puya/mir/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ def visit_push_bytes(self, push: models.PushBytes) -> list[teal.TealOp]:
self.l_stack.append(format_bytes(push.value, push.encoding))
return [teal.PushBytes(push.value, push.encoding, source_location=push.source_location)]

def visit_push_deploy_var(self, deploy_var: models.PushTemplateVar) -> list[teal.TealOp]:
self.l_stack.append(deploy_var.name)
return [
teal.PushTemplateVar(
name=deploy_var.name,
op_code=deploy_var.op_code,
source_location=deploy_var.source_location,
)
]

def visit_push_address(self, addr: models.PushAddress) -> list[teal.TealOp]:
self.l_stack.append(addr.value)
return [teal.PushAddress(addr.value, source_location=addr.source_location)]
Expand Down
Loading

0 comments on commit 9d93fee

Please sign in to comment.