Skip to content

Commit

Permalink
fix: evaluate class bodies at module evaluation time, so that any ref…
Browse files Browse the repository at this point in the history
…erenced constants in e.g. decorators receive the correct value if it's later updated
  • Loading branch information
achidlow committed Sep 3, 2024
1 parent f9521b5 commit 9aea78c
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 39 deletions.
104 changes: 66 additions & 38 deletions src/puyapy/awst_build/contract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections.abc import Iterator, Mapping
import enum
import typing
from collections.abc import Callable, Iterator, Mapping

import mypy.nodes
import mypy.types
Expand All @@ -25,6 +27,15 @@
logger = log.get_logger(__name__)


DeferredContractMethod: typing.TypeAlias = Callable[[ASTConversionModuleContext], ContractMethod]


class SpecialMethod(enum.Enum):
init = enum.auto()
approval_program = enum.auto()
clear_state_program = enum.auto()


class ContractASTConverter(BaseMyPyStatementVisitor[None]):
def __init__(
self,
Expand All @@ -37,12 +48,11 @@ def __init__(
self.cref = qualified_class_name(class_def.info)
self._is_arc4 = class_def.info.has_base(constants.ARC4_CONTRACT_BASE)
self._is_abstract = _check_class_abstractness(context, class_def)
self._approval_program: ContractMethod | None = None
self._clear_program: ContractMethod | None = None
self._init_method: ContractMethod | None = None
self._subroutines = list[ContractMethod]()
inherited_and_direct_storage = _gather_app_storage_recursive(context, class_def)
self.context.set_state_defs(self.cref, inherited_and_direct_storage)
self._methods = list[tuple[DeferredContractMethod, SourceLocation, SpecialMethod | None]]()
self.class_options: typing.Final = class_options
self.bases: typing.Final = _gather_bases(context, class_def)
self.docstring: typing.Final = class_def.docstring
self.source_location: typing.Final = self._location(class_def)

# if the class has an __init__ method, we need to visit it first, so any storage
# fields cane be resolved to a (static) key
Expand All @@ -60,38 +70,51 @@ def __init__(
stmt.accept(self)
# TODO: validation for state proxies being non-conditional

def build(self, context: ASTConversionModuleContext) -> ContractFragment:
inherited_and_direct_storage = _gather_app_storage_recursive(context, self.class_def)
self.context.set_state_defs(self.cref, inherited_and_direct_storage)
approval_program: ContractMethod | None = None
clear_program: ContractMethod | None = None
init_method: ContractMethod | None = None
subroutines = list[ContractMethod]()
for method_builder, method_loc, special_kind in self._methods:
with context.log_exceptions(fallback_location=method_loc):
sub = method_builder(context)
match special_kind:
case SpecialMethod.init:
init_method = sub
case SpecialMethod.approval_program:
approval_program = sub
case SpecialMethod.clear_state_program:
clear_program = sub
case None:
subroutines.append(sub)
case invalid:
typing.assert_never(invalid)

app_state = {
name: state_decl.definition
for name, state_decl in context.state_defs(self.cref).items()
if state_decl.defined_in == self.cref
}

self.result_ = ContractFragment(
result = ContractFragment(
module_name=self.cref.module_name,
name=self.cref.class_name,
name_override=class_options.name_override,
name_override=self.class_options.name_override,
is_arc4=self._is_arc4,
is_abstract=self._is_abstract,
bases=_gather_bases(context, class_def),
init=self._init_method,
approval_program=self._approval_program,
clear_program=self._clear_program,
subroutines=self._subroutines,
bases=self.bases,
init=init_method,
approval_program=approval_program,
clear_program=clear_program,
subroutines=subroutines,
app_state=app_state,
docstring=class_def.docstring,
source_location=self._location(class_def),
reserved_scratch_space=class_options.scratch_slot_reservations,
state_totals=class_options.state_totals,
docstring=self.docstring,
source_location=self.source_location,
reserved_scratch_space=self.class_options.scratch_slot_reservations,
state_totals=self.class_options.state_totals,
)

@classmethod
def convert(
cls,
context: ASTConversionModuleContext,
class_def: mypy.nodes.ClassDef,
class_options: ContractClassOptions,
) -> ContractFragment:
return cls(context, class_def, class_options).result_
return result

def empty_statement(self, _stmt: mypy.nodes.Statement) -> None:
return None
Expand Down Expand Up @@ -139,7 +162,7 @@ def visit_function(
source_location=source_location,
)
if sub is not None:
self._init_method = sub
self._methods.append((sub, source_location, SpecialMethod.init))
elif func_def.name.startswith("__") and func_def.name.endswith("__"):
self._error(
"methods starting and ending with a double underscore"
Expand All @@ -158,10 +181,12 @@ def visit_function(
source_location=source_location,
)
if sub is not None:
if is_approval:
self._approval_program = sub
else:
self._clear_program = sub
kind = (
SpecialMethod.approval_program
if is_approval
else SpecialMethod.clear_state_program
)
self._methods.append((sub, source_location, kind))
elif not self._is_arc4:
for arc4_only_dec_name in (
constants.ABIMETHOD_DECORATOR,
Expand All @@ -182,7 +207,7 @@ def visit_function(
source_location=source_location,
)
if sub is not None:
self._subroutines.append(sub)
self._methods.append((sub, source_location, None))
else:
subroutine_dec = dec_by_fullname.pop(constants.SUBROUTINE_HINT, None)
abimethod_dec = dec_by_fullname.pop(constants.ABIMETHOD_DECORATOR, None)
Expand Down Expand Up @@ -218,15 +243,15 @@ def visit_function(
source_location=source_location,
)
if sub is not None:
self._subroutines.append(sub)
self._methods.append((sub, source_location, None))

def _handle_method(
self,
func_def: mypy.nodes.FuncDef,
extra_decorators: Mapping[str, mypy.nodes.Expression],
arc4_method_config: ARC4MethodConfig | None,
source_location: SourceLocation,
) -> ContractMethod | None:
) -> DeferredContractMethod | None:
func_loc = self._location(func_def)
self._precondition(
not (func_def.is_static or func_def.is_class),
Expand All @@ -241,8 +266,8 @@ def _handle_method(
self._error(f"{func_def.name} should take a self parameter", func_loc)
match func_def.abstract_status:
case mypy.nodes.NOT_ABSTRACT:
return FunctionASTConverter.convert(
self.context,
return lambda ctx: FunctionASTConverter.convert(
ctx,
func_def=func_def,
source_location=source_location,
contract_method_info=ContractMethodInfo(
Expand Down Expand Up @@ -323,6 +348,9 @@ def visit_del_stmt(self, stmt: mypy.nodes.DelStmt) -> None:
def visit_match_stmt(self, stmt: mypy.nodes.MatchStmt) -> None:
self._unsupported_stmt("match", stmt)

def visit_type_alias_stmt(self, stmt: mypy.nodes.TypeAliasStmt) -> None:
self._unsupported_stmt("type", stmt)


def _gather_bases(
context: ASTConversionModuleContext, class_def: mypy.nodes.ClassDef
Expand Down
3 changes: 2 additions & 1 deletion src/puyapy/awst_build/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def visit_class_def(self, cdef: mypy.nodes.ClassDef) -> StatementResult:
return []

class_options = _process_contract_class_options(self.context, self, cdef)
return [lambda ctx: ContractASTConverter.convert(ctx, cdef, class_options)]
converter = ContractASTConverter(self.context, cdef, class_options)
return [converter.build]

def visit_operator_assignment_stmt(
self, stmt: mypy.nodes.OperatorAssignmentStmt
Expand Down

0 comments on commit 9aea78c

Please sign in to comment.