Integrate vector functions #81

Merged
merged 8 commits into from
Jul 7, 2022
59 changes: 54 additions & 5 deletions doc/source/how_to.rst
@@ -62,10 +62,26 @@ using the ``tf.function`` decorator.

Constructing the integrand
^^^^^^^^^^^^^^^^^^^^^^^^^^
Note that the ``example_integrand`` contained only ``TensorFlow`` operations.
All ``VegasFlow`` integrands should, in principle, depend only on python primitives
and ``TensorFlow`` operations; otherwise the code cannot be compiled and as a result cannot
run on GPU or other ``TensorFlow``-supported hardware accelerators.
Constructing an integrand for ``VegasFlow`` is similar to constructing an integrand for any other algorithm, with one small difference:
the output of the integrand should be a tensor of results instead of a single number.
While most integration algorithms take a function and then evaluate said function ``n`` times (to calculate ``n`` events),
``VegasFlow`` takes the approach of evaluating as many events as possible at once.
As such, the input random array (``xarr``) is a tensor of shape (``(n_events, n_dim)``) instead of the usual (``(n_dim,)``)
and, accordingly, the output result is not a scalar but rather a tensor of shape (``(n_events,)``).

Note that the ``example_integrand`` contains only ``TensorFlow`` functions and methods and operations between ``TensorFlow`` variables:

.. code-block:: python

def example_integrand(xarr, weight=None):
s = tf.reduce_sum(xarr, axis=1)
result = tf.pow(0.1/s, 2)
return result


By making the ``VegasFlow`` integrand depend only on python and ``TensorFlow`` primitives, the code can be understood by
``TensorFlow`` and compiled to run on CPU, GPU or other hardware accelerators,
as well as benefit from optimizations based on `XLA <https://www.tensorflow.org/api_docs/python/tf/function>`_.

It is possible, however (and often useful when prototyping) to integrate functions not
based on ``TensorFlow``, by passing the ``compilable`` flag at compile time.
@@ -85,13 +101,46 @@ the integration algorithm).

.. note:: Integrands must always accept as first argument the random number (``xarr``)
and can also accept the keyword argument ``weight``. The ``compile`` method of the integration
will try to find the most adequate signature in each situation.


It is also possible to completely avoid compilation,
by leveraging ``TensorFlow``'s `eager execution <https://www.tensorflow.org/guide/eager>`_ as
explained at :ref:`eager-label`.
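As an illustrative sketch (not part of this PR), a non-compilable integrand may use arbitrary python or ``numpy`` code; ``numpy_integrand`` below is a hypothetical name mirroring the ``example_integrand`` above:

```python
import numpy as np

def numpy_integrand(xarr, weight=None):
    # Arbitrary python/numpy code: usable in eager mode or when compiled
    # with compilable=False, but it cannot be traced by tf.function
    s = np.sum(np.asarray(xarr), axis=1)
    return (0.1 / s) ** 2

# Hypothetical usage, following the docs above:
#   vegas.compile(numpy_integrand, compilable=False)
```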

Integrating vector functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

It is also possible to integrate vector-valued functions with most algorithms included in ``VegasFlow`` by simply modifying
the integrand to return a vector of values per event instead of a scalar (in other words, the output shape of the result
should be (``(n_events, n_outputs)``)).

.. code-block:: python

@tf.function
def test_function(xarr):
res = tf.square((xarr - 1.0) ** 2)
return tf.exp(-res)


For adaptive algorithms, however, only one of the dimensions is taken into account to adapt the grid
(by default it will be the first output).
In ``VegasFlow`` it is possible to modify this behaviour with the ``main_dimension`` keyword argument.


.. code-block:: python

vegas = VegasFlow(dim, ncalls, main_dimension=1)


``VegasFlow`` will automatically try to discover (by evaluating the integrand with a small number of events) whether
the function is vector-valued, and will check a) whether the algorithm can integrate vector-valued integrands
and b) whether the ``main_dimension`` index is contained within the dimensionality of the output.


.. note:: Remember that python lists and arrays are 0-indexed; as such, for an output with 2 components the index of the last dimension is 1 and not 2!
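The shape conventions above can be checked with a quick ``numpy`` stand-in (illustrative only; real integrands operate on ``TensorFlow`` tensors):

```python
import numpy as np

n_events, n_dim = 100, 3
xarr = np.random.rand(n_events, n_dim)

# Scalar integrand: one value per event -> shape (n_events,)
scalar_out = np.power(0.1 / np.sum(xarr, axis=1), 2)
assert scalar_out.shape == (n_events,)

# Vector integrand: several values per event -> shape (n_events, n_outputs)
# (here n_outputs happens to equal n_dim, as in the test_function above)
vector_out = np.exp(-np.square((xarr - 1.0) ** 2))
assert vector_out.shape == (n_events, n_dim)
```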


Choosing the correct types
^^^^^^^^^^^^^^^^^^^^^^^^^^

40 changes: 40 additions & 0 deletions examples/multidimensional_integral.py
@@ -0,0 +1,40 @@
"""
Example: basic multidimensional integral

The output of the function is not a scalar but a vector v.
In this case it is necessary to tell Vegas which is the main dimension
(i.e., the output dimension the grid should adapt to).

Note that the integrand output should have the same shape as the tensor of random numbers:
both are of shape (nevents, ndim).
"""
from vegasflow import VegasFlow, run_eager

run_eager()
import tensorflow as tf

# MC integration setup
dim = 3
ncalls = int(1e4)
n_iter = 5


@tf.function
def test_function(xarr):
res = tf.square((xarr - 1.0) ** 2)
return tf.exp(-res)


if __name__ == "__main__":
print("Testing a multidimensional integration")
vegas = VegasFlow(dim, ncalls, main_dimension=1)
vegas.compile(test_function)
all_results, all_err = vegas.run_integration(2)
try:
for result, error in zip(all_results, all_err):
print(f"{result = :.5} +- {error:.5}")
except TypeError:
# So that the example works also if the integrand is made scalar
result = all_results
error = all_err
print(f"{result = :.5} +- {error:.5}")
4 changes: 4 additions & 0 deletions examples/multiple_integrals.py
@@ -5,6 +5,10 @@

The example integrands are variations of the Genz functions defined in
Novak et al., 1999 (J. of Comp. and Applied Maths, 112 (1999) 215-228) and implemented from http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.123.8452&rep=rep1&type=pdf

Note: when possible, the features shown in ``multidimensional_integral.py`` should be used,
as then the error computation is automatically taken care of by the algorithms
in VegasFlow instead of having to be implemented by hand.
"""
import vegasflow
from vegasflow.configflow import DTYPE
2 changes: 1 addition & 1 deletion src/vegasflow/__init__.py
@@ -1,6 +1,6 @@
"""Monte Carlo integration with Tensorflow"""

from vegasflow.configflow import int_me, float_me, run_eager
from vegasflow.configflow import int_me, float_me, run_eager, DTYPE, DTYPEINT

# Expose the main interfaces
from vegasflow.vflow import VegasFlow, vegas_wrapper, vegas_sampler
66 changes: 59 additions & 7 deletions src/vegasflow/monte_carlo.py
@@ -61,9 +61,9 @@ def print_iteration(it, res, error, extra="", threshold=0.1):
# note: actually, the flag 'g' does this automatically
# but I prefer to choose the precision myself...
if res < threshold:
logger.info(f"Result for iteration {it}: {res:.3e} +/- {error:.3e}" + extra)
return f"Result for iteration {it}: {res:.3e} +/- {error:.3e}" + extra
else:
logger.info(f"Result for iteration {it}: {res:.4f} +/- {error:.4f}" + extra)
return f"Result for iteration {it}: {res:.4f} +/- {error:.4f}" + extra


def _accumulate(accumulators):
@@ -105,6 +105,8 @@ class MonteCarloFlow(ABC):
`list_devices`: list of device type to use (use `None` to do the tensorflow default)
"""

_CAN_RUN_VECTORIAL = False

def __init__(
self,
n_dim,
@@ -126,6 +128,7 @@ def __init__(
self._events_limit = events_limit
self._events_per_run = min(events_limit, n_events)
self._compilation_arguments = None
self._vectorial = False
self.distribute = False
# If any of the pass variables below is set to true
# the integrand will be expecting them so the integrator
@@ -255,6 +258,11 @@ def _run_event(self, integrand, ncalls=None):
result = self.event()
return result, pow(result, 2)

def _can_run_vectorial(self, expected_shape=None):
"""Accepting vectorial integrands depends on the algorithm;
if an algorithm can run with vectorial integrands it should implement this method and return True"""
return self._CAN_RUN_VECTORIAL

#### Integration management
def set_seed(self, seed):
"""Sets the integration seed"""
@@ -322,6 +330,7 @@ def set_distribute(self, queue_object):
import dask.distributed # pylint: disable=import-error
except ImportError as e:
raise ImportError("Install dask and distributed to use `set_distribute`") from e

if self.devices is not None:
logger.warning("`set_distribute` overrides any previous device configuration")
self.list_devices = None
@@ -421,6 +430,7 @@ def run_event(self, tensorize_events=False, **kwargs):
for ncalls, pc in zip(events_to_do, percentages):
res = self.device_run(ncalls, sent_pc=pc, **kwargs)
accumulators.append(res)

return _accumulate(accumulators)

def trace(self, n_events=50):
@@ -435,7 +445,7 @@ def trace(self, n_events=50):
self.n_events = true_events
self._verbose = true_verbosity

def compile(self, integrand, compilable=True, signature=None, trace=False):
def compile(self, integrand, compilable=True, signature=None, trace=False, check=True):
"""Receives an integrand, prepares it for integration
and tries to compile unless told otherwise.

@@ -477,7 +487,9 @@ def compile(self, integrand, compilable=True, signature=None, trace=False):
is not passed through `tf.function`
`signature`: (default: True)
whether to autodiscover the signature of the integrand

`check`: (default: True)
whether to check that the integrand produces results of the expected shape and whether it is vectorial;
note that with ``check=False`` vectorial output will not work
"""
kwargs = {"compilable": compilable, "signature": signature, "trace": trace}
self._compilation_arguments = (integrand, kwargs)
@@ -522,7 +534,7 @@ def compile(self, integrand, compilable=True, signature=None, trace=False):
tf_integrand = integrand

# The algorithms will always call the function with
# (xrand, weight=)
# (xrand, weight)
# therefore create a wrapper with whatever integrand so that it conforms to this
# n_dim was an option there during development and needs to be left because of legacy code

@@ -548,6 +560,33 @@ def batch_events(**kwargs):
if trace:
self.trace()

if check:
event_size = 23
test_array = tf.random.uniform((event_size, self.n_dim), dtype=DTYPE)
wgt = tf.random.uniform((event_size,), dtype=DTYPE)
res_tmp = new_integrand(test_array, weight=wgt).numpy()
res_shape = res_tmp.shape

expected_shape = (event_size,)

if len(res_shape) == 2:
self._vectorial = True
expected_shape = res_tmp.reshape(event_size, -1).shape
if not self._can_run_vectorial(expected_shape):
raise NotImplementedError(
f"""The {self.__class__.__name__} algorithm does not support vectorial integrands
if you believe this to be a bug please open an issue in https://github.com/N3PDF/vegasflow/issues"""
)

if res_shape != expected_shape:
error_str = "the shape of the integrand output should be: (n_events,"
if self._vectorial:
error_str += " output_dim,"
logger.error(f"Wrong integrand output shape, {error_str})")
raise ValueError(
f"The integrand is not returning a value per event, expected shape: {expected_shape}, found: {res_shape}"
)

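The check above can be summarized by the following ``numpy`` sketch (``check_output_shape`` is a hypothetical helper for illustration, not part of the PR):

```python
import numpy as np

def check_output_shape(res, event_size):
    """Return True if the output is vectorial, mirroring the check above."""
    expected_shape = (event_size,)
    vectorial = False
    if res.ndim == 2:
        # A 2D output (n_events, n_outputs) marks a vectorial integrand
        vectorial = True
        expected_shape = res.reshape(event_size, -1).shape
    if res.shape != expected_shape:
        raise ValueError(f"Expected shape {expected_shape}, found {res.shape}")
    return vectorial
```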
def _recompile(self):
"""Forces recompilation with the same arguments that have
previously been used for compilation"""
@@ -613,7 +652,14 @@ def run_integration(self, n_iter, log_time=True, histograms=None):
else:
time_str = ""
if self._verbose:
print_iteration(i, res, error, extra=time_str)
all_info = []
if self._vectorial:
for d, (rr, ee) in enumerate(zip(res, error)):
append_str = f" [dimension {d}] {time_str}"
all_info.append(print_iteration(i, rr, ee, extra=append_str))
else:
all_info = [print_iteration(i, res, error, extra=time_str)]
logger.info("\n ".join(all_info))

# Once all iterations are finished, print out
aux_res = 0.0
@@ -637,7 +683,13 @@ def run_integration(self, n_iter, log_time=True, histograms=None):
final_result = aux_res / weight_sum
sigma = np.sqrt(1.0 / weight_sum)
if self._verbose:
logger.info(f" > Final results: {final_result.numpy():g} +/- {sigma:g}")
final_results = []
if self._vectorial:
for dim, (rr, ee) in enumerate(zip(final_result.numpy(), sigma)):
final_results.append(f"Final results [{dim = }]: {rr:g} +/- {ee:g}")
else:
final_results = [f" > Final results: {final_result.numpy():g} +/- {sigma:g}"]
logger.info("\n ".join(final_results))
return final_result.numpy(), sigma


11 changes: 8 additions & 3 deletions src/vegasflow/plain.py
@@ -12,6 +12,8 @@ class PlainFlow(MonteCarloFlow):
Simple Monte Carlo integrator.
"""

_CAN_RUN_VECTORIAL = True

def _run_event(self, integrand, ncalls=None):
if ncalls is None:
n_events = self.n_events
@@ -20,12 +22,15 @@ def _run_event(self, integrand, ncalls=None):

# Generate all random number for this iteration
rnds, _, xjac = self._generate_random_array(n_events)

# Compute the integrand
tmp = integrand(rnds, weight=xjac) * xjac
tmp2 = tf.square(tmp)
# Accumulate the current result
res = tf.reduce_sum(tmp)
res2 = tf.reduce_sum(tmp2)

# Accommodate multidimensional output by ensuring that only the event axis is accumulated
res = tf.reduce_sum(tmp, axis=0)
res2 = tf.reduce_sum(tmp2, axis=0)

return res, res2

def _run_iteration(self):
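The switch to ``axis=0`` can be illustrated with a ``numpy`` sketch (illustrative only, not part of the PR): summing over the event axis alone leaves one accumulator per output dimension, while remaining equivalent to the previous behaviour for scalar integrands.

```python
import numpy as np

# (n_events=4, n_outputs=3): one row per event, one column per output
tmp = np.arange(12.0).reshape(4, 3)

res_total = np.sum(tmp)               # collapses events AND outputs into one number
res_per_output = np.sum(tmp, axis=0)  # one accumulated result per output

assert res_per_output.shape == (3,)
assert res_total == res_per_output.sum()
```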