
Transport meets Variational Inference: Controlled Monte Carlo Diffusions

Official code repository for our ICLR 2024 paper [OpenReview] [arXiv].

We provide code to run the experiments in the paper on a wide variety of target distributions, all implemented in `model_handler.py`. The code is written in JAX, and we use wandb for logging and visualisation.

To run different methods and targets, follow the template below (a sketch that sweeps over all of the methods listed is given after the list):

```bash
python main.py --config.model log_ionosphere --config.boundmode MCD_ULA
```

• ULA uses `MCD_ULA`
• MCD uses `MCD_ULA_sn`
• UHA uses `UHA`
• LDVI uses `MCD_U_a-lp-sn`
• CMCD uses `MCD_CAIS_sn`
• 2nd order CMCD uses `MCD_CAIS_UHA_sn`
• CMCD + VarGrad loss uses `MCD_CAIS_var_sn`
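
To benchmark several methods on the same target, you can loop over the boundmodes above. Below is a minimal sketch using Python's `subprocess` module; the sweep script itself is our illustration and not part of the repo, but the flags match the template above.

```python
# sweep_boundmodes.py -- illustrative sketch, not part of this repo.
# Runs main.py once per method from the mapping above.
import subprocess

BOUNDMODES = {
    "ULA": "MCD_ULA",
    "MCD": "MCD_ULA_sn",
    "UHA": "UHA",
    "LDVI": "MCD_U_a-lp-sn",
    "CMCD": "MCD_CAIS_sn",
    "2nd order CMCD": "MCD_CAIS_UHA_sn",
    "CMCD + VarGrad": "MCD_CAIS_var_sn",
}

for method, boundmode in BOUNDMODES.items():
    print(f"Running {method} ({boundmode}) ...")
    subprocess.run(
        ["python", "main.py",
         "--config.model", "log_ionosphere",
         "--config.boundmode", boundmode],
        check=True,
    )
```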

Below, we provide commands that replicate the exact hyperparameter settings used in the paper, along with wandb links to the corresponding experiments.

40-GMM Experiments

To enable comparisons with DDS/PIS, we support the same network architecture with time embeddings as the DDS repo. To run our method with the DDS architecture, set `--config.nn_arch dds` on the command line.
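
For intuition, the time embeddings in question are the sinusoidal features standard in diffusion-style score networks; a minimal JAX sketch is below. The exact dimensions and frequency scaling used by the DDS architecture may differ, so treat this as illustrative only.

```python
# Sketch of sinusoidal time embeddings of the kind used by DDS/PIS-style
# score networks; exact frequencies/dimensions in the DDS repo may differ.
import jax.numpy as jnp

def timestep_embedding(t, emb_dim, max_period=10_000.0):
    """Map scalar times t (shape [batch]) to features of shape [batch, emb_dim]."""
    half = emb_dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)
```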

```bash
python main.py --config.model many_gmm --config.boundmode MCD_CAIS_sn --config.N 2000 --config.nbridges 256 --noconfig.pretrain_mfvi --config.init_sigma 60 --config.grad_clipping --config.init_eps 1 --config.eps_schedule cos_sq --config.lr 0.001 --noconfig.train_eps --noconfig.train_vi --config.wandb.name "kl 40gmm pis net eps=1, cos_sq" --config.nn_arch dds
python main.py --config.model many_gmm --config.boundmode MCD_CAIS_var_sn --config.N 2000 --config.nbridges 256 --noconfig.pretrain_mfvi --config.init_sigma 15 --config.grad_clipping --config.init_eps 0.65 --config.emb_dim 130 --config.lr 0.005 --noconfig.train_eps --noconfig.train_vi --config.wandb.name "logvar 40gmm"
python main.py --config.model many_gmm --config.boundmode MCD_CAIS_sn --config.N 2000 --config.nbridges 256 --noconfig.pretrain_mfvi --config.init_sigma 15 --config.grad_clipping --config.init_eps 0.1 --config.emb_dim 130 --config.lr 0.005 --noconfig.train_eps --noconfig.train_vi --config.wandb.name "kl 40gmm"
```

• [Old KL Wandb experiment eps=0.65]
• [Old KL Wandb experiment eps=0.1]
• [Old logvar Wandb experiment eps=0.65]
• [Old logvar Wandb experiment eps=0.1]
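
For reference, the `many_gmm` target is a mixture of 40 Gaussians; a minimal JAX sketch of such a log-density is below. The component means, scales, and layout here are assumptions for illustration, and the repo's actual target is defined in `model_handler.py`.

```python
# Minimal sketch of a 40-component GMM log-density in JAX. Component means
# and scales are illustrative assumptions; see model_handler.py for the
# repo's actual many_gmm target.
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

key = jax.random.PRNGKey(0)
means = jax.random.uniform(key, (40, 2), minval=-40.0, maxval=40.0)  # assumed
scale = 1.0  # assumed per-component standard deviation

def log_prob_gmm(x):
    # log of (1/40) * sum_k N(x; mu_k, scale^2 I)
    diffs = (x[None, :] - means) / scale                  # [40, 2]
    log_comp = (-0.5 * jnp.sum(diffs**2, axis=-1)
                - x.shape[-1] * (jnp.log(scale) + 0.5 * jnp.log(2 * jnp.pi)))
    return logsumexp(log_comp) - jnp.log(40.0)
```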

Funnel Experiments

```bash
python main.py --config.boundmode MCD_CAIS_sn --config.model funnel --config.N 300 --config.alpha 0.05 --config.emb_dim 48 --config.init_eps 0.1 --config.init_sigma 1 --config.iters 11000 --noconfig.pretrain_mfvi --config.train_vi --noconfig.train_eps --config.wandb.name "funnel replicate w/ cos_sq" --config.lr 0.01 --config.n_samples 2000 --config.eps_schedule cos_sq
```

[Old wandb experiment with paper numbers] [Replicated Wandb experiment at main]

The paper numbers differ in one respect: they were produced with Geffner's manual Adam implementation.
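
For reference, the `funnel` target is Neal's funnel distribution, where a "neck" coordinate sets the variance of the remaining ones; a minimal JAX sketch of its log-density is below. The dimensionality and exact parameterisation are assumptions; see `model_handler.py` for the repo's definition.

```python
# Sketch of Neal's funnel log-density in JAX: v ~ N(0, 9) and, given v,
# each remaining coordinate ~ N(0, exp(v)). Parameterisation is assumed;
# the repo's exact target lives in model_handler.py.
import jax.numpy as jnp

def log_prob_funnel(x):
    v, rest = x[0], x[1:]
    log_p_v = -0.5 * (v**2 / 9.0 + jnp.log(2 * jnp.pi * 9.0))
    log_p_rest = -0.5 * jnp.sum(rest**2 / jnp.exp(v) + jnp.log(2 * jnp.pi) + v)
    return log_p_v + log_p_rest
```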

LGCP Experiments

```bash
python main.py --config.boundmode MCD_CAIS_sn --config.model lgcp --config.N 20 --config.alpha 0.05 --config.emb_dim 20 --config.init_eps 0.00001 --config.init_sigma 1 --config.iters 37500 --config.pretrain_mfvi --config.train_vi --config.train_eps --config.wandb.name "lgcp replicate" --config.lr 0.0001 --config.n_samples 500 --config.mfvi_iters 20000
```

[Old wandb experiment with paper numbers] [Wandb experiment at main]

Differences from the paper experiments: (1) the new run is about 10 minutes slower due to extra logging; (2) 20,000 steps of MFVI are enough, versus 150k in the paper.
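
For reference, the `lgcp` target is a log-Gaussian Cox process posterior: a latent Gaussian field with a Poisson likelihood on gridded counts. A minimal JAX sketch of the unnormalised log-posterior is below; the grid, kernel, and cell area are assumptions, and the repo's actual target is in `model_handler.py`.

```python
# Sketch of an unnormalised LGCP log-posterior in JAX: Gaussian prior on a
# latent field x plus a Poisson likelihood with rate area * exp(x) per cell.
# Grid, kernel, and cell area are assumptions; see model_handler.py.
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def log_prob_lgcp(x, y, mu, chol_K, area):
    """x: latent field [d]; y: counts [d]; chol_K: Cholesky factor of the prior cov."""
    z = solve_triangular(chol_K, x - mu, lower=True)
    log_prior = (-0.5 * z @ z
                 - jnp.sum(jnp.log(jnp.diag(chol_K)))
                 - 0.5 * x.size * jnp.log(2 * jnp.pi))
    log_lik = jnp.sum(y * x - area * jnp.exp(x))
    return log_prior + log_lik
```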

2-GMM Experiments

```bash
python main.py --config.boundmode MCD_CAIS_sn --config.model gmm --config.N 300 --config.alpha 0.05 --config.emb_dim 20 --config.init_eps 0.01 --config.init_sigma 1 --config.iters 11000 --noconfig.pretrain_mfvi --config.train_vi --noconfig.train_eps --config.wandb.name "gmm replicate" --config.lr 0.001 --config.n_samples 500
```

[Old wandb experiment with rebuttal paper numbers] [Wandb experiment]

[Original paper numbers]

Differences: the new run has better $\ln Z$ estimates overall.
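
For context, $\ln Z$ is typically estimated from the sampler's log importance weights with a log-mean-exp; a minimal sketch is below (the function name is ours, not the repo's).

```python
# Standard importance-sampling estimate of ln Z from per-sample log-weights:
# ln Z ~ logsumexp(log_w) - log N. Function name is illustrative.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def ln_z_estimate(log_w):
    """log_w: [N] array of log importance weights."""
    return logsumexp(log_w) - jnp.log(log_w.shape[0])
```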

If you use any of our code or ideas, please cite our work using the following BibTeX entry:

```bibtex
@inproceedings{vargas2024transport,
  title={Transport meets Variational Inference: Controlled Monte Carlo Diffusions},
  author={Francisco Vargas and Shreyas Padhy and Denis Blessing and Nikolas N{\"u}sken},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024},
  url={https://openreview.net/forum?id=PP1rudnxiW}
}
```
