Skip to content

Commit

Permalink
improve a bit how the tests are written
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Jul 7, 2022
1 parent 067628f commit e30be6d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 45 deletions.
49 changes: 29 additions & 20 deletions src/vegasflow/tests/test_algs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,20 @@ def check_is_one(result, sigmas=3):
np.testing.assert_allclose(res, 1.0, atol=err)


def test_VegasFlow():
@pytest.mark.parametrize("mode", range(4))
def test_VegasFlow(mode):
"""Test VegasFlow class, importance sampling algorithm"""
for mode in range(4):
vegas_instance = instance_and_compile(VegasFlow, mode)
_ = vegas_instance.run_integration(n_iter)
vegas_instance.freeze_grid()
result = vegas_instance.run_integration(n_iter)
check_is_one(result)
vegas_instance = instance_and_compile(VegasFlow, mode)
_ = vegas_instance.run_integration(n_iter)
vegas_instance.freeze_grid()
result = vegas_instance.run_integration(n_iter)
check_is_one(result)


def test_VegasFlow_grid_management():
vegas_instance = instance_and_compile(VegasFlow, 1)
_ = vegas_instance.run_integration(n_iter)
vegas_instance.freeze_grid()

# Change the number of events
vegas_instance.n_events = 2 * ncalls
Expand Down Expand Up @@ -140,15 +146,18 @@ def test_VegasFlow_load_grid():
vegas_instance.load_grid(file_name=tmp_filename)


def test_PlainFlow():
# We could use hypothesis here instead of this loop
for mode in range(4):
plain_instance = instance_and_compile(PlainFlow, mode)
result = plain_instance.run_integration(n_iter)
check_is_one(result)
@pytest.mark.parametrize("mode", range(4))
def test_PlainFlow(mode):
plain_instance = instance_and_compile(PlainFlow, mode)
result = plain_instance.run_integration(n_iter)
check_is_one(result)


def test_PlainFlow_change_nevents():
plain_instance = instance_and_compile(PlainFlow, 0)
result = plain_instance.run_integration(n_iter)
check_is_one(result)

# Use the last instance to check that changing the number of events
# don't change the result
plain_instance.n_events = 2 * ncalls
new_result = plain_instance.run_integration(n_iter)
check_is_one(new_result)
Expand Down Expand Up @@ -178,12 +187,12 @@ def test_rng_generation(n_events=100):
_ = helper_rng_tester(v, n_events)


def test_VegasFlowPlus_ADAPTIVE_SAMPLING():
@pytest.mark.parametrize("mode", range(4))
def test_VegasFlowPlus_ADAPTIVE_SAMPLING(mode):
"""Test Vegasflow with Adaptive Sampling on (the default)"""
for mode in range(4):
vflowplus_instance = instance_and_compile(VegasFlowPlus, mode)
result = vflowplus_instance.run_integration(n_iter)
check_is_one(result)
vflowplus_instance = instance_and_compile(VegasFlowPlus, mode)
result = vflowplus_instance.run_integration(n_iter)
check_is_one(result)


def test_VegasFlowPlus_NOT_ADAPTIVE_SAMPLING():
Expand Down
19 changes: 8 additions & 11 deletions src/vegasflow/tests/test_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
Tests the gradients of the different algorithms
"""

import numpy as np
from pytest import mark

from vegasflow import float_me, run_eager
from vegasflow import VegasFlow, VegasFlowPlus, PlainFlow
import tensorflow as tf
import numpy as np


def generate_integrand(variable):
Expand Down Expand Up @@ -70,17 +72,12 @@ def wrapper_test(iclass, x_point=5.0, alpha=10, integrator_kwargs=None):
np.testing.assert_allclose(grad_1 * alpha, grad_2, rtol=1e-2)


def test_gradient_Vegasflow():
""" "Test one can compile and generate gradients with VegasFlow"""
wrapper_test(VegasFlow)
@mark.parametrize("algorithm", [VegasFlowPlus, VegasFlow, PlainFlow])
def test_gradient(algorithm):
""" "Test one can compile and generate gradients with the different algorithms"""
wrapper_test(algorithm)


def test_gradient_VegasflowPlus():
def test_gradient_VegasflowPlus_adaptive():
""" "Test one can compile and generate gradients with VegasFlowPlus"""
wrapper_test(VegasFlowPlus)
wrapper_test(VegasFlowPlus, integrator_kwargs={"adaptive": True})


def test_gradient_PlainFlow():
""" "Test one can compile and generate gradients with PlainFlow"""
wrapper_test(PlainFlow)
29 changes: 15 additions & 14 deletions src/vegasflow/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Miscellaneous tests that don't really fit anywhere else
"""
import pytest

from vegasflow import VegasFlow, VegasFlowPlus, PlainFlow
import tensorflow as tf

Expand All @@ -23,20 +24,20 @@ def _wrong_vector_integrand(xarr):
return tf.transpose(xarr)


def test_working_vectorial():
@pytest.mark.parametrize("mode", range(4))
@pytest.mark.parametrize("alg", [VegasFlow, PlainFlow])
def test_working_vectorial(alg, mode):
"""Check that the algorithms that accept integrating vectorial functions can really do so"""
for alg in [VegasFlow, PlainFlow]:
for mode in range(4):
inst = instance_and_compile(alg, mode=mode, integrand_function=_vector_integrand)
result = inst.run_integration(2)
check_is_one(result)
inst = instance_and_compile(alg, mode=mode, integrand_function=_vector_integrand)
result = inst.run_integration(2)
check_is_one(result, sigmas=4)


def test_notworking_vectorial():
@pytest.mark.parametrize("alg", [VegasFlowPlus])
def test_notworking_vectorial(alg):
"""Check that the algorithms that do not accept vectorial functions fail appropriately"""
for alg in [VegasFlowPlus]:
with pytest.raises(NotImplementedError):
_ = instance_and_compile(alg, integrand_function=_vector_integrand)
with pytest.raises(NotImplementedError):
_ = instance_and_compile(alg, integrand_function=_vector_integrand)


def test_check_wrong_main_dimension():
Expand All @@ -47,8 +48,8 @@ def test_check_wrong_main_dimension():
inst.compile(_vector_integrand)


def test_wrong_shape():
@pytest.mark.parametrize("wrong_fun", [_wrong_vector_integrand, _wrong_integrand])
def test_wrong_shape(wrong_fun):
"""Check that an error is raised by the compilation if the integrand has the wrong shape"""
for wrong_fun in [_wrong_vector_integrand, _wrong_integrand]:
with pytest.raises(ValueError):
_ = instance_and_compile(PlainFlow, integrand_function=wrong_fun)
with pytest.raises(ValueError):
_ = instance_and_compile(PlainFlow, integrand_function=wrong_fun)

0 comments on commit e30be6d

Please sign in to comment.