Skip to content

Commit

Permalink
feat: add primitive UTF-8 String type
Browse files Browse the repository at this point in the history
  • Loading branch information
achidlow committed Mar 18, 2024
1 parent 6186dab commit 14d35c6
Show file tree
Hide file tree
Showing 35 changed files with 1,349 additions and 11 deletions.
1 change: 1 addition & 0 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
string_ops 157 152 5 152 0
stubs/BigUInt 172 112 60 112 0
stubs/Bytes 1769 258 1511 258 0
stubs/String 203 141 62 141 0
stubs/Uint64 371 8 363 8 0
template_variables/TemplateVariables 168 155 13 155 0
too_many_permutations 108 106 2 106 0
Expand Down
3 changes: 3 additions & 0 deletions src/puya/awst/function_traverser.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ def visit_bool_constant(self, expr: awst_nodes.BoolConstant) -> None:
def visit_bytes_constant(self, expr: awst_nodes.BytesConstant) -> None:
    """No-op: this traverser does nothing when it reaches a ``BytesConstant``."""

def visit_string_constant(self, expr: awst_nodes.StringConstant) -> None:
    """No-op: this traverser does nothing when it reaches a ``StringConstant``."""

def visit_arc4_decode(self, expr: awst_nodes.ARC4Decode) -> None:
    """Continue the traversal into the value held by an ``ARC4Decode`` node."""
    inner = expr.value
    inner.accept(self)

Expand Down
53 changes: 47 additions & 6 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_bytes_constant(self)


# Expression node carrying a `str` value whose wtype is always
# wtypes.string_wtype.
@attrs.frozen
class StringConstant(Expression):
# init=False: the wtype is fixed by its default and is not a constructor argument.
wtype: WType = attrs.field(default=wtypes.string_wtype, init=False)
# Validated on construction by literal_validator(wtypes.string_wtype).
value: str = attrs.field(validator=[literal_validator(wtypes.string_wtype)])

# Visitor entry point: passes this node to visitor.visit_string_constant
# and returns whatever that call returns.
def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_string_constant(self)


@attrs.frozen
class TemplateVar(Expression):
wtype: WType
Expand Down Expand Up @@ -1112,7 +1121,9 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:
return visitor.visit_numeric_comparison_expression(self)


# Accepts expressions whose wtype is bytes, account or string. The result of
# expression_has_wtype is used elsewhere in this file as an attrs validator.
bytes_comparable = expression_has_wtype(
    wtypes.bytes_wtype, wtypes.account_wtype, wtypes.string_wtype
)


@attrs.frozen
Expand Down Expand Up @@ -1278,10 +1289,27 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T:

@attrs.frozen
class BytesBinaryOperation(Expression):
    """Binary operation whose two operands are both bytes or both string.

    The result type is not passed in by the caller: it defaults to the wtype
    of the left operand, and the right operand must have that same wtype.
    """

    left: Expression = attrs.field(
        validator=[expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype)]
    )
    op: BytesBinaryOperator
    right: Expression = attrs.field(
        validator=[expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype)]
    )
    # init=False: derived from the left operand, see _wtype_factory.
    wtype: WType = attrs.field(init=False)

    @right.validator
    def _check_right(self, _attribute: object, right: Expression) -> None:
        """Raise ``CodeError`` when the operands have differing wtypes."""
        if right.wtype != self.left.wtype:
            raise CodeError(
                "Bytes operation on differing types,"
                f" lhs is {self.left.wtype}, rhs is {right.wtype}",
                self.source_location,
            )

    @wtype.default
    def _wtype_factory(self) -> wtypes.WType:
        """Use the left operand's wtype as the wtype of the whole operation."""
        return self.left.wtype

    def accept(self, visitor: ExpressionVisitor[T]) -> T:
        """Pass this node to ``visitor.visit_bytes_binary_operation``."""
        return visitor.visit_bytes_binary_operation(self)
Expand Down Expand Up @@ -1370,9 +1398,22 @@ def accept(self, visitor: StatementVisitor[T]) -> T:

@attrs.frozen
class BytesAugmentedAssignment(Statement):
    """Augmented assignment whose target and value are both bytes or both string.

    The assigned value must have exactly the same wtype as the target.
    """

    target: Lvalue = attrs.field(
        validator=[expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype)]
    )
    op: BytesBinaryOperator
    value: Expression = attrs.field(
        validator=[expression_has_wtype(wtypes.bytes_wtype, wtypes.string_wtype)]
    )

    @value.validator
    def _check_value(self, _attribute: object, value: Expression) -> None:
        """Raise ``CodeError`` when the value's wtype differs from the target's."""
        if value.wtype != self.target.wtype:
            raise CodeError(
                "Augmented assignment of differing types,"
                f" expected {self.target.wtype}, got {value.wtype}",
                value.source_location,
            )

    def accept(self, visitor: StatementVisitor[T]) -> T:
        """Pass this node to ``visitor.visit_bytes_augmented_assignment``."""
        return visitor.visit_bytes_augmented_assignment(self)
Expand Down
3 changes: 3 additions & 0 deletions src/puya/awst/to_code_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ def visit_bytes_constant(self, expr: nodes.BytesConstant) -> str:
case nodes.BytesEncoding.base16 | nodes.BytesEncoding.unknown:
return f'hex<"{expr.value.hex().upper()}">'

def visit_string_constant(self, expr: nodes.StringConstant) -> str:
    """Render a string constant as a quoted, escaped literal.

    The value is passed through ``repr`` so that an empty string remains
    visible and embedded quotes, whitespace or newlines cannot be confused
    with the surrounding rendered code.
    """
    return repr(expr.value)

def visit_method_constant(self, expr: nodes.MethodConstant) -> str:
    """Render a method constant as ``Method("<value>")``."""
    method_value = expr.value
    return f'Method("{method_value}")'

Expand Down
4 changes: 4 additions & 0 deletions src/puya/awst/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def visit_bool_constant(self, expr: puya.awst.nodes.BoolConstant) -> T:
def visit_bytes_constant(self, expr: puya.awst.nodes.BytesConstant) -> T:
...

@abstractmethod
def visit_string_constant(self, expr: puya.awst.nodes.StringConstant) -> T:
    """Handle a ``StringConstant`` node; abstract, with no behaviour defined here."""

@abstractmethod
def visit_address_constant(self, expr: puya.awst.nodes.AddressConstant) -> T:
    """Handle an ``AddressConstant`` node; abstract, with no behaviour defined here."""
Expand Down
6 changes: 5 additions & 1 deletion src/puya/awst/wtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def is_valid_utf8_literal(value: object) -> typing.TypeGuard[str]:
stub_name=constants.CLS_BYTES_ALIAS,
is_valid_literal=is_valid_bytes_literal,
)

# WType for the primitive string type: named "string", tied to the
# CLS_STRING_ALIAS stub name, with literals checked by is_valid_utf8_literal.
string_wtype: typing.Final = WType(
name="string",
stub_name=constants.CLS_STRING_ALIAS,
is_valid_literal=is_valid_utf8_literal,
)
asset_wtype: typing.Final = WType(
name="asset",
stub_name=constants.CLS_ASSET_ALIAS,
Expand Down
2 changes: 2 additions & 0 deletions src/puya/awst_build/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
CLS_UINT64_ALIAS = f"{PUYAPY_PREFIX}UInt64"
CLS_BYTES = f"{PUYAPY_PREFIX}_primitives.Bytes"
CLS_BYTES_ALIAS = f"{PUYAPY_PREFIX}Bytes"
# String follows the same pattern as the neighbouring primitives: the full
# name lives under `_primitives`, the alias sits directly under PUYAPY_PREFIX.
CLS_STRING = f"{PUYAPY_PREFIX}_primitives.String"
CLS_STRING_ALIAS = f"{PUYAPY_PREFIX}String"
CLS_BIGUINT = f"{PUYAPY_PREFIX}_primitives.BigUInt"
CLS_BIGUINT_ALIAS = f"{PUYAPY_PREFIX}BigUInt"
CLS_TRANSACTION_BASE = f"{PUYAPY_PREFIX}gtxn.TransactionBase"
Expand Down
4 changes: 2 additions & 2 deletions src/puya/awst_build/eb/arc4/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def expect_string_or_bytes(
),
source_location,
)
case Literal(value=bytes(bytes_literal)):
case Literal(value=bytes(bytes_literal)): # TODO: yeet this?
return arc4_encode_bytes(
BytesConstant(
value=bytes_literal,
Expand Down Expand Up @@ -108,7 +108,7 @@ def augmented_assignment(
)
)
case _:
return super().augmented_assignment(op, rhs, location)
return super().augmented_assignment(op, rhs, location) # TODO: bad error message

def binary_op(
self,
Expand Down
162 changes: 162 additions & 0 deletions src/puya/awst_build/eb/string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

import typing

import structlog

from puya.awst import wtypes
from puya.awst.nodes import (
BytesAugmentedAssignment,
BytesBinaryOperation,
BytesBinaryOperator,
BytesComparisonExpression,
CallArg,
EqualityComparison,
Expression,
FreeSubroutineTarget,
Literal,
ReinterpretCast,
Statement,
StringConstant,
SubroutineCallExpression,
)
from puya.awst_build import intrinsic_factory
from puya.awst_build.eb.base import (
BuilderBinaryOp,
BuilderComparisonOp,
ExpressionBuilder,
ValueExpressionBuilder,
)
from puya.awst_build.eb.bytes_backed import BytesBackedClassExpressionBuilder
from puya.awst_build.eb.var_factory import var_expression
from puya.awst_build.utils import expect_operand_wtype
from puya.errors import CodeError

if typing.TYPE_CHECKING:
from collections.abc import Sequence

import mypy.nodes

from puya.parse import SourceLocation

logger: structlog.types.FilteringBoundLogger = structlog.get_logger(__name__)


class StringClassExpressionBuilder(BytesBackedClassExpressionBuilder):
def produces(self) -> wtypes.WType:
return wtypes.string_wtype

def call(
self,
args: Sequence[ExpressionBuilder | Literal],
arg_kinds: list[mypy.nodes.ArgKind],
arg_names: list[str | None],
location: SourceLocation,
) -> ExpressionBuilder:
match args:
case []:
value = ""
case [Literal(value=str(value))]:
pass
case _:
raise CodeError("Invalid/unhandled arguments", location)
str_const = StringConstant(value=value, source_location=location)
return var_expression(str_const)


class StringExpressionBuilder(ValueExpressionBuilder):
wtype = wtypes.string_wtype

def member_access(self, name: str, location: SourceLocation) -> ExpressionBuilder:
match name:
case "bytes":
return _get_bytes_expr_builder(self.expr)
case _:
raise CodeError(f"Unrecognised member of {self.wtype}: {name}", location)

def augmented_assignment(
self, op: BuilderBinaryOp, rhs: ExpressionBuilder | Literal, location: SourceLocation
) -> Statement:
match op:
case BuilderBinaryOp.add:
return BytesAugmentedAssignment(
target=self.lvalue(),
op=BytesBinaryOperator.add,
value=expect_operand_wtype(rhs, self.wtype),
source_location=location,
)
case _:
raise CodeError(
f"Unsupported augmented assignment operation on {self.wtype}: {op.value}=",
location,
)

def binary_op(
self,
other: ExpressionBuilder | Literal,
op: BuilderBinaryOp,
location: SourceLocation,
*,
reverse: bool,
) -> ExpressionBuilder:
match op:
case BuilderBinaryOp.add:
lhs = self.expr
rhs = expect_operand_wtype(other, self.wtype)
if reverse:
(lhs, rhs) = (rhs, lhs)
return var_expression(
BytesBinaryOperation(
left=lhs,
op=BytesBinaryOperator.add,
right=rhs,
source_location=location,
)
)
case _:
return NotImplemented

def compare(
self, other: ExpressionBuilder | Literal, op: BuilderComparisonOp, location: SourceLocation
) -> ExpressionBuilder:
other_expr = expect_operand_wtype(other, self.wtype)

cmp = BytesComparisonExpression(
source_location=location,
lhs=self.expr,
operator=EqualityComparison(op.value),
rhs=other_expr,
)
return var_expression(cmp)

def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> ExpressionBuilder:
bytes_expr = _get_bytes_expr(self.expr)
len_expr = intrinsic_factory.bytes_len(bytes_expr, location)
len_builder = var_expression(len_expr)
return len_builder.bool_eval(location, negate=negate)

def contains(
self, item: ExpressionBuilder | Literal, location: SourceLocation
) -> ExpressionBuilder:
item_expr = _get_bytes_expr(expect_operand_wtype(item, wtypes.string_wtype))
this_expr = _get_bytes_expr(self.expr)
is_substring_expr = SubroutineCallExpression(
target=FreeSubroutineTarget(module_name="puyapy_lib_bytes", name="is_substring"),
args=[
CallArg(value=item_expr, name="item"),
CallArg(value=this_expr, name="sequence"),
],
wtype=wtypes.bool_wtype,
source_location=location,
)
return var_expression(is_substring_expr)


def _get_bytes_expr(expr: Expression) -> ReinterpretCast:
    """Wrap ``expr`` in a ``ReinterpretCast`` to bytes, reusing its source location."""
    location = expr.source_location
    return ReinterpretCast(expr=expr, wtype=wtypes.bytes_wtype, source_location=location)


def _get_bytes_expr_builder(expr: Expression) -> ExpressionBuilder:
    """Return an expression builder over ``expr`` reinterpreted as bytes."""
    as_bytes = _get_bytes_expr(expr)
    return var_expression(as_bytes)
3 changes: 3 additions & 0 deletions src/puya/awst_build/eb/type_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
intrinsics,
log,
named_int_constants,
string,
struct,
template_variables,
transaction,
Expand Down Expand Up @@ -69,6 +70,7 @@
constants.CLS_APPLICATION: application.ApplicationClassExpressionBuilder,
constants.CLS_BIGUINT: biguint.BigUIntClassExpressionBuilder,
constants.CLS_BYTES: bytes_.BytesClassExpressionBuilder,
constants.CLS_STRING: string.StringClassExpressionBuilder,
constants.CLS_UINT64: uint64.UInt64ClassExpressionBuilder,
constants.CLS_TEMPLATE_VAR_METHOD: (
template_variables.GenericTemplateVariableExpressionBuilder
Expand Down Expand Up @@ -133,6 +135,7 @@
wtypes.biguint_wtype: biguint.BigUIntExpressionBuilder,
wtypes.bool_wtype: bool_.BoolExpressionBuilder,
wtypes.bytes_wtype: bytes_.BytesExpressionBuilder,
wtypes.string_wtype: string.StringExpressionBuilder,
wtypes.uint64_wtype: uint64.UInt64ExpressionBuilder,
wtypes.void_wtype: void.VoidExpressionBuilder,
wtypes.WGroupTransaction: transaction.GroupTransactionExpressionBuilder,
Expand Down
2 changes: 2 additions & 0 deletions src/puya/awst_build/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def visit_match_stmt(self, stmt: mypy.nodes.MatchStmt) -> Switch | None:
constants.CLS_UINT64,
constants.CLS_BIGUINT,
constants.CLS_BYTES,
constants.CLS_STRING,
constants.CLS_ACCOUNT,
):
case_value_builder_or_literal = inner_literal_expr.accept(self)
Expand All @@ -554,6 +555,7 @@ def visit_match_stmt(self, stmt: mypy.nodes.MatchStmt) -> Switch | None:
constants.CLS_UINT64_ALIAS,
constants.CLS_BIGUINT_ALIAS,
constants.CLS_BYTES_ALIAS,
constants.CLS_STRING_ALIAS,
constants.CLS_ACCOUNT_ALIAS,
)
)
Expand Down
3 changes: 3 additions & 0 deletions src/puya/awst_build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
IntegerConstant,
Literal,
ReinterpretCast,
StringConstant,
UInt64Constant,
)
from puya.awst_build import constants
Expand Down Expand Up @@ -249,6 +250,8 @@ def convert_literal(
return BytesConstant(
value=str_value.encode("utf8"), encoding=BytesEncoding.utf8, source_location=loc
)
case str(str_value), wtypes.string_wtype:
return StringConstant(value=str_value, source_location=loc)
case str(str_value), wtypes.account_wtype:
return AddressConstant(value=str_value, source_location=loc)
case int(int_value), wtypes.asset_wtype | wtypes.application_wtype:
Expand Down
7 changes: 7 additions & 0 deletions src/puya/ir/builder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,13 @@ def visit_bytes_constant(self, expr: awst_nodes.BytesConstant) -> BytesConstant:
source_location=expr.source_location,
)

def visit_string_constant(self, expr: awst_nodes.StringConstant) -> BytesConstant:
    """Convert a string constant into a ``BytesConstant`` holding its UTF-8 bytes."""
    utf8_value = expr.value.encode("utf8")
    return BytesConstant(
        value=utf8_value,
        encoding=AVMBytesEncoding.utf8,
        source_location=expr.source_location,
    )

def visit_address_constant(self, expr: awst_nodes.AddressConstant) -> TExpression:
return AddressConstant(
value=expr.value,
Expand Down
Loading

0 comments on commit 14d35c6

Please sign in to comment.