Skip to content

Commit

Permalink
Merge pull request #1020 from rhayes777/feature/jax_assert_within_limits
Browse files Browse the repository at this point in the history
Update `assert_within_limits` on the `Prior` class
  • Loading branch information
CKrawczyk authored Jun 12, 2024
2 parents 5a4bef7 + 2f19a61 commit 052eb7c
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions autofit/mapper/prior/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,29 @@ def factor(self):

def assert_within_limits(self, value):

if jax_wrapper.use_jax:
return

if not (self.lower_limit <= value <= self.upper_limit):
def exception_message():
raise exc.PriorLimitException(
"The physical value {} for a prior "
"was not within its limits {}, {}".format(
value, self.lower_limit, self.upper_limit
)
)

if jax_wrapper.use_jax:
import jax
jax.lax.cond(
jax.numpy.logical_or(
value < self.lower_limit,
value > self.upper_limit
),
lambda _: jax.debug.callback(exception_message),
lambda _: None,
None
)

elif not (self.lower_limit <= value <= self.upper_limit):
exception_message()

@staticmethod
def for_class_and_attribute_name(cls, attribute_name):
prior_dict = conf.instance.prior_config.for_class_and_suffix_path(
Expand Down

0 comments on commit 052eb7c

Please sign in to comment.