Merge pull request #53 from N3PDF/update_docs
Update docs
scarlehoff authored Jul 29, 2020
2 parents 8e02416 + 48ee69f commit f6a3e30
Showing 7 changed files with 248 additions and 7 deletions.
128 changes: 128 additions & 0 deletions doc/source/examples.rst
@@ -0,0 +1,128 @@
.. _examples-label:

==========
Examples
==========

In the ``VegasFlow`` repository you can find `several examples <https://github.com/N3PDF/vegasflow/tree/master/examples>`_
of integrands which can hopefully help you get your project started quickly.

On this page we explain some of these examples in more detail.
You can find the full code in the repository, alongside more complicated versions.


.. contents::
   :local:
   :depth: 1


Basic Integral
==============

The most general usage of ``VegasFlow`` is the integration of a TensorFlow-based
integrand.

.. code-block:: python

    import tensorflow as tf

    @tf.function
    def my_integrand(xarr, **kwargs):
        return tf.reduce_sum(xarr, axis=1)

    from vegasflow.vflow import vegas_wrapper

    n_dim = 10
    n_events = int(1e6)
    n_iter = 5
    result = vegas_wrapper(my_integrand, n_dim, n_iter, n_events)

You can find a `runnable version of this basic example in the repository <https://github.com/N3PDF/vegasflow/blob/master/examples/simgauss_tf.py>`_.


Interfacing C code: CFFI
========================

A popular way of interfacing python and C code is to use the
`CFFI library <https://cffi.readthedocs.io/en/latest/>`_.

Imagine you have a C-file with some integrand:

.. code-block:: C

    // integrand.c
    void integrand(double *xarr, int ndim, int nevents, double *out) {
        for (int i = 0; i < nevents; i++) {
            out[i] = 0.0;
            for (int j = 0; j < ndim; j++) {
                out[i] += 2.0 * xarr[j + i * ndim];
            }
        }
    }

You can compile this code and integrate it (no pun intended) with ``vegasflow``
using the CFFI library, as shown in `this example <https://github.com/N3PDF/vegasflow/blob/master/examples/simgauss_cffi.py>`_.

.. code-block:: python

    from vegasflow.configflow import DTYPE
    import numpy as np
    from vegasflow.vflow import vegas_wrapper
    from cffi import FFI

    ffibuilder = FFI()
    ffibuilder.cdef("void integrand(double*, int, int, double*);")
    with open("integrand.c", "r") as f:
        ffibuilder.set_source("_integrand_cffi", f.read())
    ffibuilder.compile()

    # Now you can import the compiled C code as a python library
    from _integrand_cffi import ffi, lib

    def integrand(xarr, n_dim, **kwargs):
        n_events = xarr.shape[0]
        result = np.empty(n_events, dtype=DTYPE.as_numpy_dtype)
        x_flat = xarr.numpy().flatten()
        p_input = ffi.cast("double*", ffi.from_buffer(x_flat))
        p_output = ffi.cast("double*", ffi.from_buffer(result))
        lib.integrand(p_input, n_dim, n_events, p_output)
        return result

    vegas_wrapper(integrand, 5, 10, int(1e5), compilable=False)

Note the usage of the ``compilable=False`` flag.
This signals ``VegasFlow`` that the integrand is not pure TensorFlow and
that a graph of the full computation cannot be compiled.
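
If you want to keep graph mode for the rest of the computation, one possible
workaround (a sketch under assumptions, not part of the repository example) is to
wrap the external call in ``tf.py_function``, which embeds an eagerly-executed
node inside an otherwise compiled graph:

.. code-block:: python

    import tensorflow as tf
    from vegasflow.configflow import DTYPE

    @tf.function
    def tf_integrand(xarr, n_dim=5, **kwargs):
        # tf.py_function runs the CFFI-backed `integrand` defined above
        # eagerly, while the surrounding computation can still be compiled
        return tf.py_function(func=lambda x: integrand(x, n_dim), inp=[xarr], Tout=DTYPE)

Whether this pays off depends on how much of the computation outside the
external call can actually be compiled.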


Create your own TF-compilable operators
=======================================

TensorFlow does its best to compile your integrand into something that can
be evaluated quickly on GPU.
It has no information, however, about specific situations that would allow
for non-trivial optimizations.

In these cases you may want to write your own C++ or CUDA code while still
allowing TensorFlow to create a full graph of the computation.

Creating new operations in TF is an advanced feature and, when possible,
it is recommended to build your integrand as a composition of TF operators.
If you still want to go ahead we have prepared a `simple example <https://github.com/N3PDF/vegasflow/tree/master/examples/cuda>`_
in the repository that can be used as a template for C++ or Cuda custom
operators.

The example includes a `makefile <https://github.com/N3PDF/vegasflow/blob/master/examples/cuda/makefile>`_ that you might need to modify for your particular needs.

Note that in order to run the code on both GPU and CPU you will need to
provide both GPU- and CPU-capable kernels.
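
Once the shared library is built, loading it into TensorFlow takes a single
call. Below is a minimal sketch; the library name ``integrand_op.so`` and the
op name ``MyIntegrand`` are illustrative assumptions that depend on your own
sources and makefile:

.. code-block:: python

    import tensorflow as tf

    # Assumption: the makefile produced "integrand_op.so" containing
    # REGISTER_OP("MyIntegrand"); TensorFlow exposes it in snake_case
    integrand_module = tf.load_op_library("./integrand_op.so")

    @tf.function
    def my_integrand(xarr, **kwargs):
        # The custom kernel is now a regular graph node,
        # so the full computation can still be compiled
        return integrand_module.my_integrand(xarr)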




67 changes: 64 additions & 3 deletions doc/source/how_to.rst
@@ -14,7 +14,8 @@ How to use
A first VegasFlow integration
=============================

-Prototyping in VegasFlow is easy, the best results are obtained when the interands are written using TensorFlow primitives.
+Prototyping in ``VegasFlow`` is easy, the best results are obtained when the
+integrands are written using TensorFlow primitives.
Below we show one example where we create a TF constant (using ``tf.constant``) and then we use the sum and power functions.

.. code-block:: python
@@ -44,15 +45,75 @@ We also provide a convenience wrapper ``vegas_wrapper`` that allows to run the w
    result = vegas_wrapper(example_integrand, dimensions, n_iter, ncalls)

Global configuration
====================

Verbosity
---------

TensorFlow is very verbose by default.
When ``vegasflow`` is imported the environment variable ``TF_CPP_MIN_LOG_LEVEL``
is set to 1, hiding most warnings.
If you want to recover the usual TensorFlow logging level you can
set the variable yourself with ``export TF_CPP_MIN_LOG_LEVEL=0``.
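
Since ``vegasflow`` only sets a default value (see the ``configflow.py`` change
at the bottom of this commit), you can equivalently choose the level from
Python, as long as you do it before the first import; a minimal sketch:

.. code-block:: python

    import os

    # Must run before vegasflow (and thus TensorFlow) is first imported
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

    from vegasflow.vflow import vegas_wrapper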

Choosing integration device
---------------------------

The ``CUDA_VISIBLE_DEVICES`` environment variable tells TensorFlow
(and thus VegasFlow) on which device it should run.
If the variable is not set, it defaults to using all (and only) the GPUs available.
In order to use the CPU you can hide the GPUs by setting
``export CUDA_VISIBLE_DEVICES=""``.

If you have a set-up with more than one GPU you can select which one to use
for the integration by setting the environment variable to the
right device, e.g., ``export CUDA_VISIBLE_DEVICES=0``.
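
As with the logging level, the variable can also be set from Python, provided
it happens before TensorFlow initializes the GPU; a minimal sketch:

.. code-block:: python

    import os

    # Hide all GPUs so the integration runs on CPU;
    # set to e.g. "0" to pick the first GPU instead
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    from vegasflow.vflow import vegas_wrapper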


Eager vs graph mode
-------------------

When performing computationally expensive tasks, TensorFlow's graph mode is preferred.
When compiling you will notice that the first iteration of the integration takes a bit longer;
this is normal and is due to the creation of the graph.
Subsequent iterations will be faster.

Graph mode, however, is not debugger-friendly, as the code is only read once, when the graph is compiled.
You can, however, enable TensorFlow's `eager execution <https://www.tensorflow.org/guide/eager>`_.
In eager mode the code runs sequentially, as you would expect from normal Python code,
which allows you to throw in instances of ``pdb.set_trace()``.
In order to enable eager mode, include these lines at the top of your program:

.. code-block:: python

    import tensorflow as tf
    tf.config.run_functions_eagerly(True)

or, if you are using a version of TensorFlow older than 2.3:

.. code-block:: python

    import tensorflow as tf
    tf.config.experimental_run_functions_eagerly(True)

Eager mode also enables the usage of the library as a `standard` Python library,
allowing you to integrate non-TensorFlow integrands.
Since these integrands are not understood by TensorFlow, they will not run on
GPU kernels, while the rest of ``VegasFlow`` will still run on GPU if possible.
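
A minimal sketch in the spirit of the ``examples/example_eager.py`` file added
by this commit (the toy integrand and the call parameters here are illustrative):

.. code-block:: python

    import numpy as np
    import tensorflow as tf

    # Enable eager execution before anything is compiled
    tf.config.run_functions_eagerly(True)
    from vegasflow.vflow import vegas_wrapper

    @tf.function
    def numpy_integrand(xarr, **kwargs):
        # Plain numpy: this only works in eager mode,
        # where xarr is an eager tensor and .numpy() is available
        return np.prod(1.0 + xarr.numpy(), axis=1)

    result = vegas_wrapper(numpy_integrand, 3, 5, int(1e5))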


Histograms
==========

A commonly used feature in Monte Carlo calculations is the generation of histograms.
In order to generate them while at the same time keeping all the features of ``vegasflow``,
such as GPU computing, it is necessary to ensure the histogram generation is also wrapped with the ``@tf.function`` directive.

Below we show one such example (how the histogram is actually generated and saved is up to the user).
The first step is to create a ``Variable`` tensor which will be used to fill the histograms.
-This is a crucial step (and the only fixed step) as this tensor will be accumulated internally by ``VegasFlow''.
+This is a crucial step (and the only fixed step) as this tensor will be accumulated internally by ``VegasFlow``.


.. code-block:: python
@@ -93,7 +154,7 @@ This is a crucial step (and the only fixed step) as this tensor will be accumula
    histogram_collector(final_result * weight, histogram_values)
    return final_result

-Finally we can normally call ``vegasflow'', remembering to pass down the accumulator tensor, which will be filled in with the histograms.
+Finally we can normally call ``vegasflow``, remembering to pass down the accumulator tensor, which will be filled in with the histograms.
Note that here we are only filling one histogram and so the histogram tuple contains only one element, but any number of histograms can be filled.
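
As a minimal sketch of what such a collector can look like (the number of bins
``HISTO_BINS``, the accumulator name and the choice of binning the first
variable in [0, 1] are illustrative assumptions, not the repository's exact code):

.. code-block:: python

    import tensorflow as tf
    from vegasflow.configflow import DTYPE

    HISTO_BINS = 10  # hypothetical number of bins

    # The accumulator Variable the text above refers to
    accumulator_tensor = tf.Variable(tf.zeros(HISTO_BINS, dtype=DTYPE))

    @tf.function
    def histogram_collector(results, variables):
        # Bin the first integration variable in [0, 1] and
        # accumulate the weighted results into the accumulator
        indices = tf.histogram_fixed_width_bins(
            variables[0], tf.constant([0.0, 1.0], dtype=DTYPE), nbins=HISTO_BINS
        )
        partial_hist = tf.scatter_nd(tf.expand_dims(indices, -1), results, (HISTO_BINS,))
        accumulator_tensor.assign_add(partial_hist)

After the integration ``accumulator_tensor`` holds the filled histogram; how it
is normalized and saved is up to the user.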


1 change: 1 addition & 0 deletions doc/source/index.rst
@@ -99,6 +99,7 @@ Indices and tables
    VegasFlow<self>
    how_to
    intalg
+   examples
    apisrc/vegasflow


46 changes: 46 additions & 0 deletions examples/example_eager.py
@@ -0,0 +1,46 @@
"""
Example: eager mode integrand
Running a non-tensorflow integrand using VegasFlow
"""

from vegasflow.configflow import DTYPE, DTYPEINT
import time
import numpy as np
from scipy.special import expit
import tensorflow as tf

tf.config.run_functions_eagerly(True)
from vegasflow.vflow import vegas_wrapper


# MC integration setup
dim = 4
ncalls = np.int32(1e5)
n_iter = 5


@tf.function
def symgauss_sigmoid(xarr, n_dim=None, **kwargs):
    """symgauss test function"""
    if n_dim is None:
        n_dim = xarr.shape[-1]
    a = 0.1
    pref = pow(1.0 / a / np.sqrt(np.pi), n_dim)
    coef = np.sum(np.arange(1, 101))
    # Tensorflow variables will be cast down by numpy;
    # you can directly access their numpy representation with .numpy()
    xarr_sq = np.square((xarr - 1.0 / 2.0) / a)
    coef += np.sum(xarr_sq, axis=1)
    coef -= 100.0 * 101.0 / 2.0
    return expit(xarr[:, 0].numpy()) * (pref * np.exp(-coef))


if __name__ == "__main__":
    """Testing several different integrations"""
    print(f"VEGAS MC, ncalls={ncalls}:")
    start = time.time()
    ncalls = 10 * ncalls
    r = vegas_wrapper(symgauss_sigmoid, dim, n_iter, ncalls, compilable=True)
    end = time.time()
    print(f"Vegas took: time (s): {end-start}")
9 changes: 7 additions & 2 deletions examples/simgauss_tf.py
@@ -1,4 +1,9 @@
-# Place your function here
+"""
+Example: basic integration
+Basic example using the vegas_wrapper helper
+"""

from vegasflow.configflow import DTYPE, DTYPEINT
import time
import numpy as np
@@ -32,7 +37,7 @@ def symgauss(xarr, n_dim=None, **kwargs):
print(f"VEGAS MC, ncalls={ncalls}:")
start = time.time()
ncalls = 10*ncalls
r = vegas_wrapper(symgauss, dim, n_iter, ncalls, compilable=True)
r = vegas_wrapper(symgauss, dim, n_iter, ncalls)
end = time.time()
print(f"Vegas took: time (s): {end-start}")

2 changes: 1 addition & 1 deletion src/vegasflow/__init__.py
@@ -1,3 +1,3 @@
"""Monte Carlo integration with Tensorflow"""

__version__ = "1.0.2"
__version__ = "1.1.0"
2 changes: 1 addition & 1 deletion src/vegasflow/configflow.py
@@ -3,7 +3,7 @@
"""
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "1")
# Most of this can be moved to a yaml file without loss of generality
import tensorflow as tf

