Skip to content

Commit

Permalink
Merge pull request #1037 from rhayes777/feature/jax_pytree
Browse files Browse the repository at this point in the history
fix setting attribute rather than calling in constructor. could cause issues later
  • Loading branch information
Jammy2211 authored Aug 20, 2024
2 parents bed5607 + 718ce6f commit edbd5d1
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions autofit/mapper/prior_model/prior_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,17 @@ def instance_flatten(self, instance):
"""
Flatten an instance of this model as a PyTree.
"""
attribute_names = [
name
for name in self.direct_argument_names
if hasattr(instance, name) and name not in self.constructor_argument_names
]
return (
[getattr(instance, name) for name in self.direct_argument_names],
None,
(
[getattr(instance, name) for name in self.constructor_argument_names],
[getattr(instance, name) for name in attribute_names],
),
(attribute_names,),
)

def instance_unflatten(self, aux_data, children):
Expand All @@ -263,7 +271,12 @@ def instance_unflatten(self, aux_data, children):
-------
An instance of this model.
"""
return self.cls(**dict(zip(self.direct_argument_names, children)))
constructor_arguments, other_arguments = children
attribute_names = aux_data[0]
instance = self.cls(*constructor_arguments)
for name, value in zip(attribute_names, other_arguments):
setattr(instance, name, value)
return instance

def tree_flatten(self):
"""
Expand Down

0 comments on commit edbd5d1

Please sign in to comment.