Skip to content

Commit

Permalink
feat: expand handling of literal expressions to allow combining them …
Browse files Browse the repository at this point in the history
…with binary boolean operators, and improve error messaging when handling of type unions in nested bool contexts

---------

also:
- allow dummy values to be constructed for literal-only pytypes
- single spot where "unexpected type" error message is defined, which also gives a special message in case of a union, just in case the function actually does support a union argument as is the case with some intrinsics etc
- prevent parsing function bodies when the return type is invalid, which can cause secondary errors which might be confusing and overshadow the root problem
- supress "expression result is ignored" warning when using `typing.assert_type`
  • Loading branch information
achidlow committed Sep 17, 2024
1 parent 98186fb commit b4e0c30
Show file tree
Hide file tree
Showing 75 changed files with 8,625 additions and 2,656 deletions.
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
arc4_types/Arc4StringTypes 472 35 437 35 0
arc4_types/Arc4StructsFromAnotherModule 73 12 61 12 0
arc4_types/Arc4StructsType 318 239 79 239 0
arc4_types/Arc4TuplesType 882 138 744 138 0
arc4_types/Arc4TuplesType 865 138 727 138 0
arc_28/EventEmitter 191 133 58 133 0
asset/Reference 269 261 8 261 0
auction/Auction 601 522 79 522 0
augmented_assignment/Augmented 159 156 3 156 0
avm_types_in_abi/Test 423 351 72 351 0
biguint_binary_ops/BiguintBinaryOps 189 77 112 77 0
boolean_binary_ops/BooleanBinaryOps 345 280 65 280 0
boolean_binary_ops/BooleanBinaryOps 1154 471 683 471 0
box_storage/Box 1860 1435 425 1435 0
bytes_ops/BiguintBinaryOps 139 139 0 139 0
calculator 349 317 32 315 2
Expand Down
14 changes: 2 additions & 12 deletions src/puyapy/awst_build/arc4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,19 +382,9 @@ def require_arg_name(arg: pytypes.FuncArg) -> str:
)
return arg.name

def require_single_type(arg: pytypes.FuncArg) -> pytypes.PyType:
try:
(typ,) = arg.types
except ValueError:
raise CodeError(
"union types are not supported as method arguments", location
) from None
else:
return typ

if not (
func_type.args
and set(require_single_type(func_type.args[0]).mro).intersection(
and set(func_type.args[0].type.mro).intersection(
(pytypes.ARC4ContractBaseType, pytypes.ARC4ClientBaseType)
)
):
Expand All @@ -403,7 +393,7 @@ def require_single_type(arg: pytypes.FuncArg) -> pytypes.PyType:
f" instance methods of classes derived from {pytypes.ARC4ContractBaseType}",
location,
)
result = {require_arg_name(arg): require_single_type(arg) for arg in func_type.args[1:]}
result = {require_arg_name(arg): arg.type for arg in func_type.args[1:]}
if "output" in result:
# https://github.com/algorandfoundation/ARCs/blob/main/assets/arc-0032/application.schema.json
raise CodeError(
Expand Down
38 changes: 16 additions & 22 deletions src/puyapy/awst_build/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from puya.errors import CodeError, InternalError, log_exceptions
from puya.models import ContractReference
from puya.parse import SourceLocation
from puya.utils import attrs_extend
from puya.utils import attrs_extend, unique

from puyapy.awst_build import pytypes
from puyapy.awst_build.contract_data import AppStorageDeclaration
from puyapy.awst_build.exceptions import TypeUnionError
from puyapy.parse import ParseResult, source_location_from_mypy

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -197,7 +196,7 @@ def type_to_pytype(
registry: Mapping[str, pytypes.PyType],
mypy_type: mypy.types.Type,
*,
source_location: SourceLocation | None,
source_location: SourceLocation,
in_type_args: bool = False,
in_func_sig: bool = False,
) -> pytypes.PyType:
Expand Down Expand Up @@ -269,13 +268,13 @@ def type_to_pytype(
our_literal_value = literal_value
return pytypes.TypingLiteralType(value=our_literal_value, source_location=loc)
case mypy.types.UnionType(items=items):
types = [recurse(it) for it in items]
types = unique(recurse(it) for it in items)
if not types:
raise CodeError("Cannot resolve empty type", loc)
if len(types) == 1:
return pytypes.NeverType
elif len(types) == 1:
return types[0]
else:
raise TypeUnionError(types, loc)
return pytypes.UnionType(types, loc)
case mypy.types.NoneType() | mypy.types.PartialType(type=None):
return pytypes.NoneType
case mypy.types.UninhabitedType():
Expand All @@ -301,19 +300,14 @@ def type_to_pytype(
for at, name, kind in zip(
func_like.arg_types, func_like.arg_names, func_like.arg_kinds, strict=True
):
try:
pt = type_to_pytype(
registry,
at,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=True,
)
except TypeUnionError as union:
pts = union.types
else:
pts = [pt]
func_args.append(pytypes.FuncArg(types=pts, kind=kind, name=name))
arg_pytype = type_to_pytype(
registry,
at,
source_location=loc,
in_type_args=in_type_args,
in_func_sig=True,
)
func_args.append(pytypes.FuncArg(type=arg_pytype, kind=kind, name=name))
if None in func_like.bound_args:
logger.debug(
"None contained in bound args for function reference", location=loc
Expand Down Expand Up @@ -347,7 +341,7 @@ def _maybe_parameterise_pytype(
registry: Mapping[str, pytypes.PyType],
maybe_generic: pytypes.PyType,
mypy_type_args: Sequence[mypy.types.Type],
loc: SourceLocation | None,
loc: SourceLocation,
) -> pytypes.PyType:
if not mypy_type_args:
return maybe_generic
Expand All @@ -361,7 +355,7 @@ def _maybe_parameterise_pytype(
return result


def _type_of_any_to_error_message(type_of_any: int, source_location: SourceLocation | None) -> str:
def _type_of_any_to_error_message(type_of_any: int, source_location: SourceLocation) -> str:
from mypy.types import TypeOfAny

match type_of_any:
Expand Down
7 changes: 7 additions & 0 deletions src/puyapy/awst_build/eb/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing_extensions
from puya import log
from puya.awst.nodes import (
BinaryBooleanOperator,
CompileTimeConstantExpression,
Expression,
FieldExpression,
Expand Down Expand Up @@ -152,6 +153,12 @@ def binary_op(
) -> InstanceBuilder:
return NotImplemented

@typing.override
def bool_binary_op(
self, other: InstanceBuilder, op: BinaryBooleanOperator, location: SourceLocation
) -> InstanceBuilder:
return super().bool_binary_op(other, op, location)

@typing.override
def augmented_assignment(
self, op: BuilderBinaryOp, rhs: InstanceBuilder, location: SourceLocation
Expand Down
39 changes: 19 additions & 20 deletions src/puyapy/awst_build/eb/_expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def at_most_one_arg_of_type(
logger.error(f"expected at most 1 argument, got {len(args)}", location=location)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of_any=valid_types):
return first
logger.error("unexpected argument type", location=first.source_location)
return None
return not_this_type(first, default=default_none)


def default_raise(msg: str, location: SourceLocation) -> typing.Never:
Expand All @@ -60,14 +59,24 @@ def defaulter(msg: str, location: SourceLocation) -> _T: # noqa: ARG001
def default_dummy_value(
pytype: pytypes.PyType,
) -> Callable[[str, SourceLocation], InstanceBuilder]:
assert not isinstance(pytype, pytypes.LiteralOnlyType)

def defaulter(msg: str, location: SourceLocation) -> InstanceBuilder: # noqa: ARG001
return dummy_value(pytype, location)

return defaulter


def not_this_type(node: NodeBuilder, default: Callable[[str, SourceLocation], _T]) -> _T:
"""Provide consistent error messages for unexpected types."""
if isinstance(node.pytype, pytypes.UnionType):
msg = "type unions are unsupported at this location"
else:
msg = "unexpected argument type"
result = default(msg, node.source_location)
logger.error(msg, location=node.source_location)
return result


def at_least_one_arg(
args: Sequence[_TBuilder],
location: SourceLocation,
Expand Down Expand Up @@ -120,10 +129,7 @@ def exactly_one_arg_of_type(
first = maybe_resolve_literal(first, pytype)
if isinstance(first, InstanceBuilder) and is_type_or_subtype(first.pytype, of=pytype):
return first
msg = "unexpected argument type"
result = default(msg, first.source_location)
logger.error(msg, location=first.source_location)
return result
return not_this_type(first, default=default)


def exactly_one_arg_of_type_else_dummy(
Expand All @@ -133,8 +139,6 @@ def exactly_one_arg_of_type_else_dummy(
*,
resolve_literal: bool = False,
) -> InstanceBuilder:
assert not isinstance(pytype, pytypes.LiteralOnlyType)

return exactly_one_arg_of_type(
args,
pytype,
Expand All @@ -152,8 +156,6 @@ def no_args(args: Sequence[NodeBuilder], location: SourceLocation) -> None:
def exactly_n_args_of_type_else_dummy(
args: Sequence[NodeBuilder], pytype: pytypes.PyType, location: SourceLocation, num_args: int
) -> Sequence[InstanceBuilder]:
assert not isinstance(pytype, pytypes.LiteralOnlyType)

if not exactly_n_args(args, location, num_args):
dummy_args = [dummy_value(pytype, location)] * num_args
args = [arg or default for arg, default in zip_longest(args, dummy_args)]
Expand Down Expand Up @@ -185,10 +187,7 @@ def argument_of_type(
builder.pytype, of_any=(target_type, *additional_types)
):
return builder
msg = "unexpected argument type"
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
return not_this_type(builder, default=default)


def argument_of_type_else_dummy(
Expand Down Expand Up @@ -218,11 +217,11 @@ def simple_string_literal(
return value
case InstanceBuilder(pytype=pytypes.StrLiteralType):
msg = "argument must be a simple str literal"
case _:
msg = "unexpected argument type"
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
result = default(msg, builder.source_location)
logger.error(msg, location=builder.source_location)
return result
case other:
return not_this_type(other, default=default)


def instance_builder(
Expand Down
10 changes: 10 additions & 0 deletions src/puyapy/awst_build/eb/_literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from puya import log
from puya.awst.nodes import (
BinaryBooleanOperator,
BoolConstant,
BytesConstant,
BytesEncoding,
Expand Down Expand Up @@ -121,6 +122,15 @@ def binary_op(
folded = fold_binary_expr(location, op.value, lhs, rhs)
return LiteralBuilderImpl(value=folded, source_location=location)

@typing.override
def bool_binary_op(
self, other: InstanceBuilder, op: BinaryBooleanOperator, location: SourceLocation
) -> InstanceBuilder:
if not isinstance(other, LiteralBuilder):
return super().bool_binary_op(other, op, location)
folded = fold_binary_expr(location, op.value, self.value, other.value)
return LiteralBuilderImpl(value=folded, source_location=location)

@typing.override
def augmented_assignment(
self, op: BuilderBinaryOp, rhs: InstanceBuilder, location: SourceLocation
Expand Down
10 changes: 7 additions & 3 deletions src/puyapy/awst_build/eb/_type_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Callable

from puya.awst.nodes import Expression
from puya.errors import InternalError
from puya.errors import CodeError, InternalError
from puya.parse import SourceLocation

from puyapy.awst_build import constants, intrinsic_data, pytypes
Expand Down Expand Up @@ -247,7 +247,9 @@ def builder_for_instance(pytyp: pytypes.PyType, expr: Expression) -> InstanceBui
for base in pytyp.mro:
if eb_base := PYTYPE_BASE_TO_BUILDER.get(base):
return eb_base(expr, pytyp)
raise InternalError(f"No builder for instance: {pytyp}", expr.source_location)
if isinstance(pytyp, pytypes.UnionType):
raise CodeError("type unions are unsupported at this location", expr.source_location)
raise InternalError(f"no builder for instance: {pytyp}", expr.source_location)


def builder_for_type(pytyp: pytypes.PyType, expr_loc: SourceLocation) -> CallableBuilder:
Expand All @@ -258,4 +260,6 @@ def builder_for_type(pytyp: pytypes.PyType, expr_loc: SourceLocation) -> Callabl
for base in pytyp.mro:
if tb_base := PYTYPE_BASE_TO_TYPE_BUILDER.get(base):
return tb_base(pytyp, expr_loc)
raise InternalError(f"No builder for type: {pytyp}", expr_loc)
if isinstance(pytyp, pytypes.UnionType):
raise CodeError("type unions are unsupported at this location", expr_loc)
raise InternalError(f"no builder for type: {pytyp}", expr_loc)
4 changes: 4 additions & 0 deletions src/puyapy/awst_build/eb/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@


def dummy_value(pytype: pytypes.PyType, location: SourceLocation) -> InstanceBuilder:
if isinstance(pytype, pytypes.LiteralOnlyType):
from puyapy.awst_build.eb._literals import LiteralBuilderImpl

return LiteralBuilderImpl(pytype.python_type(), location)
expr = VarExpression(name="", wtype=pytype.wtype, source_location=location)
return builder_for_instance(pytype, expr)

Expand Down
4 changes: 2 additions & 2 deletions src/puyapy/awst_build/eb/arc4/abi_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def call(
kind="update" if is_update else "create",
location=method_or_type.source_location,
)
case _:
raise CodeError("unexpected argument type", method_or_type.source_location)
case other:
expect.not_this_type(other, default=expect.default_raise)
if compiled is None:
compiled = CompiledContractExpressionBuilder(
CompiledContract(
Expand Down
8 changes: 4 additions & 4 deletions src/puyapy/awst_build/eb/arc4/tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def call(
ARC4Encode(value=arg.resolve(), wtype=wtype, source_location=location), typ
)
case _:
# don't know expected type
raise CodeError("unexpected argument type", arg.source_location)
# don't know expected type, so raise
expect.not_this_type(arg, default=expect.default_raise)


class ARC4TupleTypeBuilder(ARC4TypeBuilder[pytypes.TupleType]):
Expand Down Expand Up @@ -91,8 +91,8 @@ def index(self, index: InstanceBuilder, location: SourceLocation) -> InstanceBui
pass
case InstanceBuilder(pytype=pytypes.IntLiteralType):
raise CodeError("tuple index must be a simple int literal", index.source_location)
case _:
raise CodeError("unexpected argument type", index.source_location)
case other:
expect.not_this_type(other, default=expect.default_raise)
try:
item_typ = self.pytype.items[index_value]
except IndexError:
Expand Down
Loading

0 comments on commit b4e0c30

Please sign in to comment.