Skip to content

Commit

Permalink
refactor: simplify intrinsics builder logic
Browse files Browse the repository at this point in the history
test: disable branch coverage for now
  • Loading branch information
achidlow committed Apr 15, 2024
1 parent c8fda82 commit 1eb3bee
Show file tree
Hide file tree
Showing 6 changed files with 905 additions and 1,687 deletions.
1 change: 0 additions & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[run]
branch = True
source = src/puya
omit = src/puya/_vendor/*

Expand Down
152 changes: 79 additions & 73 deletions scripts/generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import attrs
from puya import log
from puya.awst import wtypes
from puya.awst_build.intrinsic_models import ArgMapping, FunctionOpMapping
from puya.awst_build.intrinsic_models import (
FunctionOpMapping,
ImmediateArgMapping,
StackArgMapping,
)
from puya.awst_build.utils import snake_case

from scripts.transform_lang_spec import (
Expand All @@ -30,8 +34,8 @@
VCS_ROOT = Path(__file__).parent.parent


def _get_imported_name(typ: wtypes.WType | str) -> str:
return typ.stub_name.rsplit(".")[-1] if isinstance(typ, wtypes.WType) else typ
def _get_imported_name(typ: wtypes.WType) -> str:
return typ.stub_name.rsplit(".")[-1]


WTYPE_REFERENCES = [
Expand All @@ -42,13 +46,35 @@ def _get_imported_name(typ: wtypes.WType | str) -> str:
wtypes.bytes_wtype,
wtypes.uint64_wtype,
]
CLS_MAPPING: dict[str, wtypes.WType | type] = {
"int": int,
"bytes": bytes,
"str": str,
"bool": wtypes.bool_wtype,
**{_get_imported_name(wtype): wtype for wtype in WTYPE_REFERENCES},
WTYPE_TO_LITERAL: dict[wtypes.WType, type[int | str | bytes] | None] = {
wtypes.bytes_wtype: bytes,
wtypes.uint64_wtype: int,
wtypes.account_wtype: None, # could be str
wtypes.biguint_wtype: None, # could be int
wtypes.bool_wtype: None, # already a Python type
# below are covered transitively with respect to STACK_TYPE_MAPPING
wtypes.application_wtype: None, # maybe could be int? but also covered transitively anyway
wtypes.asset_wtype: None, # same as above
}
STACK_TYPE_MAPPING: dict[StackType, Sequence[wtypes.WType]] = {
StackType.address_or_index: [wtypes.account_wtype, wtypes.uint64_wtype],
StackType.application: [wtypes.application_wtype, wtypes.uint64_wtype],
StackType.asset: [wtypes.asset_wtype, wtypes.uint64_wtype],
StackType.bytes: [wtypes.bytes_wtype],
StackType.bytes_8: [wtypes.bytes_wtype],
StackType.bytes_32: [wtypes.bytes_wtype],
StackType.bytes_33: [wtypes.bytes_wtype],
StackType.bytes_64: [wtypes.bytes_wtype],
StackType.bytes_80: [wtypes.bytes_wtype],
StackType.bool: [wtypes.bool_wtype, wtypes.uint64_wtype],
StackType.uint64: [wtypes.uint64_wtype],
StackType.any: [wtypes.bytes_wtype, wtypes.uint64_wtype],
StackType.box_name: [wtypes.bytes_wtype],
StackType.address: [wtypes.account_wtype],
StackType.bigint: [wtypes.biguint_wtype],
StackType.state_key: [wtypes.bytes_wtype],
}

BYTES_LITERAL = "bytes"
UINT64_LITERAL = "int"
STUB_NAMESPACE = "op"
Expand Down Expand Up @@ -389,7 +415,7 @@ def main() -> None:
case RenamedOpCode() as aliased:
function_defs.extend(build_aliased_ops(lang_spec, aliased))
case _:
raise Exception("Unexpected op code group")
raise TypeError("Unexpected op code group")
function_defs.sort(key=lambda x: x.name)
class_defs.sort(key=lambda x: x.name)

Expand All @@ -398,40 +424,16 @@ def main() -> None:
output_awst_data(lang_spec, enum_names, function_defs, class_defs)


def sub_types(type_name: StackType, *, covariant: bool) -> list[str]:
uint64: Sequence[str | wtypes.WType] = [wtypes.uint64_wtype, UINT64_LITERAL]
bytes_: Sequence[str | wtypes.WType] = [wtypes.bytes_wtype, BYTES_LITERAL]
account: Sequence[str | wtypes.WType] = [wtypes.account_wtype]
stack_type_mapping: dict[StackType, Sequence[str | wtypes.WType]] = {
StackType.address_or_index: [*account, *uint64],
StackType.application: [wtypes.application_wtype, *uint64],
StackType.asset: [wtypes.asset_wtype, *uint64],
StackType.bytes: bytes_,
StackType.bytes_8: bytes_,
StackType.bytes_32: bytes_,
StackType.bytes_33: bytes_,
StackType.bytes_64: bytes_,
StackType.bytes_80: bytes_,
StackType.bool: [wtypes.bool_wtype, *uint64],
StackType.uint64: uint64,
StackType.any: [*bytes_, *uint64],
StackType.box_name: bytes_,
StackType.address: account,
StackType.bigint: [wtypes.biguint_wtype],
StackType.state_key: bytes_,
}

def sub_types(type_name: StackType, *, covariant: bool) -> Sequence[wtypes.WType]:
try:
last_index = None if covariant else 1
return list(map(_get_imported_name, stack_type_mapping[type_name][:last_index]))
wtypes_ = STACK_TYPE_MAPPING[type_name]
except KeyError as ex:
raise NotImplementedError(
f"Could not map stack type {type_name} to an algopy type"
) from ex


def sub_type(type_name: StackType, *, covariant: bool) -> str:
return " | ".join(sub_types(type_name, covariant=covariant))
else:
last_index = None if covariant else 1
return wtypes_[:last_index]


def is_simple_op(op: Op) -> bool:
Expand All @@ -446,16 +448,14 @@ def is_simple_op(op: Op) -> bool:
return True


def immediate_kind_to_type(kind: ImmediateKind) -> type:
def immediate_kind_to_type(kind: ImmediateKind) -> type[int | str]:
match kind:
case ImmediateKind.uint8 | ImmediateKind.int8 | ImmediateKind.uint64:
return int
case ImmediateKind.bytes:
return bytes
case ImmediateKind.arg_enum:
return str
case _:
raise Exception(f"Unexpected ImmediateKind: {kind}")
raise ValueError(f"Unexpected ImmediateKind: {kind}")


def get_python_type(
Expand All @@ -465,7 +465,16 @@ def get_python_type(
case StackType() as stack_type:
if any_as and stack_type == StackType.any:
return any_as
return sub_type(stack_type, covariant=covariant)
wtypes_ = sub_types(stack_type, covariant=covariant)
names = [_get_imported_name(wt) for wt in wtypes_]
if covariant:
for wt in wtypes_:
lit_t = WTYPE_TO_LITERAL[wt]
if lit_t is not None:
lit_name = lit_t.__name__
if lit_name not in names:
names.append(lit_name)
return " | ".join(names)
case ImmediateKind() as immediate_kind:
return immediate_kind_to_type(immediate_kind).__name__
case _:
Expand Down Expand Up @@ -723,16 +732,6 @@ def build_enum(spec: LanguageSpec, arg_enum: str) -> Iterable[str]:
yield ""


def get_wtype_or_type(typ: str) -> wtypes.WType | type:
return CLS_MAPPING[typ]


def get_wtype(typ: str) -> wtypes.WType:
wtype = get_wtype_or_type(typ)
assert isinstance(wtype, wtypes.WType), f"{wtype} is not a WType"
return wtype


def build_function_op_mapping(
op: Op,
arg_map: list[str],
Expand All @@ -751,32 +750,29 @@ def build_function_op_mapping(
immediates=[
const_immediate_value[1].name
if const_immediate_value and const_immediate_value[0] == arg
else ArgMapping(
else ImmediateArgMapping(
arg_name=arg_name_map[arg.name],
allowed_types=[immediate_kind_to_type(arg.immediate_type)],
literal_type=immediate_kind_to_type(arg.immediate_type),
)
for arg in op.immediate_args
],
stack_inputs=[
ArgMapping(
StackArgMapping(
arg_name=arg_name_map[arg.name],
allowed_types=[
get_wtype_or_type(typ)
for typ in sub_types(
allowed_types=list(
sub_types(
any_as if arg.stack_type == StackType.any and any_as else arg.stack_type,
covariant=True,
)
],
),
)
for arg in op.stack_inputs
],
stack_outputs=[
get_wtype(
sub_type(
any_as if o.stack_type == StackType.any and any_as else o.stack_type,
covariant=False,
)
)
sub_types(
any_as if o.stack_type == StackType.any and any_as else o.stack_type,
covariant=False,
)[0]
for o in op.stack_outputs
],
)
Expand Down Expand Up @@ -948,8 +944,8 @@ def build_wtype(wtype: wtypes.WType) -> str:
raise ValueError("Unexpected wtype")


def build_arg_mapping(arg_mapping: ArgMapping) -> Iterable[str]:
yield "ArgMapping("
def build_stack_arg_mapping(arg_mapping: StackArgMapping) -> Iterable[str]:
yield "StackArgMapping("
yield f' arg_name="{arg_mapping.arg_name}",'
yield " allowed_types=["
for allowed_type in arg_mapping.allowed_types:
Expand All @@ -962,6 +958,13 @@ def build_arg_mapping(arg_mapping: ArgMapping) -> Iterable[str]:
yield ")"


def build_immediate_arg_mapping(arg_mapping: ImmediateArgMapping) -> Iterable[str]:
yield "ImmediateArgMapping("
yield f' arg_name="{arg_mapping.arg_name}",'
yield f" literal_type={arg_mapping.literal_type.__name__},"
yield ")"


def build_op_specification_body(name_suffix: str, function: FunctionDef) -> Iterable[str]:
yield f' "algopy.{STUB_NAMESPACE}.{name_suffix}": ['
for op_mapping in function.op_mappings:
Expand All @@ -973,12 +976,12 @@ def build_op_specification_body(name_suffix: str, function: FunctionDef) -> Iter
if isinstance(immediate, str):
yield f' "{immediate}",'
else:
yield from build_arg_mapping(immediate)
yield from build_immediate_arg_mapping(immediate)
yield ","
yield " ],"
yield " stack_inputs=["
for stack_input in op_mapping.stack_inputs:
yield from build_arg_mapping(stack_input)
yield from build_stack_arg_mapping(stack_input)
yield ","
yield " ],"
yield " stack_outputs=["
Expand All @@ -997,7 +1000,10 @@ def build_awst_data(
class_ops: list[ClassDef],
) -> Iterable[str]:
yield "from puya.awst import wtypes"
yield "from puya.awst_build.intrinsic_models import ArgMapping, FunctionOpMapping"
yield (
"from puya.awst_build.intrinsic_models import"
" FunctionOpMapping, ImmediateArgMapping, StackArgMapping"
)
yield ""
yield "ENUM_CLASSES = {"
for enum_name in enums:
Expand Down Expand Up @@ -1040,7 +1046,7 @@ def output_stub(
stub.extend(build_method_stub(function, any_input_as="_T", any_output_as="_T"))
elif function.has_any_return:
# functions with any returns should have already been transformed
raise Exception(f"Unexpected function {function.name} with any return")
raise ValueError(f"Unexpected function {function.name} with any return")
else:
stub.extend(build_method_stub(function))

Expand Down
Loading

0 comments on commit 1eb3bee

Please sign in to comment.