The following instructions are for running the examples presented in "Generalization Bounds for Meta-Learning via PAC-Bayes and Uniform Stability".


This code uses the following:

If you are using Anaconda, you can run the following commands to install all of the necessary packages.

conda create -n pacbus
conda activate pacbus
conda install pytorch=1.4.0 -c pytorch
pip install learn2learn cvxpy Mosek scikit-learn h5py
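To confirm the environment is set up, the short Python sketch below checks that each package from the pip line can be imported, and asks CVXPY whether the MOSEK solver is registered. It assumes the standard import names for these packages (torch, learn2learn, cvxpy, mosek, sklearn, h5py); the script is an illustration, not part of this repository. MOSEK additionally needs a valid license file before it can solve problems.

```python
import importlib
import importlib.util

# Import names for the packages installed above (scikit-learn is imported as "sklearn").
MODULES = ["torch", "learn2learn", "cvxpy", "mosek", "sklearn", "h5py"]


def check_modules(names):
    """Return a dict mapping each top-level module name to True if it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


def mosek_registered():
    """Return True if CVXPY is installed and lists MOSEK among its available solvers."""
    if importlib.util.find_spec("cvxpy") is None:
        return False
    cvxpy = importlib.import_module("cvxpy")
    return "MOSEK" in cvxpy.installed_solvers()


if __name__ == "__main__":
    for name, found in check_modules(MODULES).items():
        print(f"{name:12s} {'ok' if found else 'MISSING'}")
    print("MOSEK solver registered with CVXPY:", mosek_registered())
```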

Circle Classification Example:

Run the following to produce the results from Example 1 in the paper. Note that --num_val denotes the number of times we run the resulting policy on test data. We require a large number of evaluations to produce a tight upper bound (see Appendix A.4 for more information). For testing purposes, you may want to reduce --num_val so the program takes less time to finish.

python --method maml    --prior train --trials full --verbose True 
python --method mlap    --prior train --trials full --verbose True
python --method mr_maml --prior train --trials full --num_val 20000 --verbose True
python --method pac_bus --prior train --trials full --num_val 20000 --verbose True
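To illustrate why a large --num_val tightens the bound, suppose the true error rate p of a fixed model is estimated from n i.i.d. test evaluations with empirical error rate q_hat. The sketch below is an illustration under that assumption and is not taken from Appendix A.4: it compares Hoeffding's bound, p <= q_hat + sqrt(ln(1/delta) / (2n)), with the Chernoff test-set bound obtained by inverting the binary KL divergence, p <= sup{p : kl(q_hat || p) <= ln(1/delta) / n}. Both gaps shrink roughly in proportion to 1/sqrt(n), and by Pinsker's inequality the KL bound is never looser than Hoeffding's. The values q_hat = 0.05 and delta = 0.01 are arbitrary examples, not results from the paper.

```python
import math


def binary_kl(q, p):
    """KL divergence between Bernoulli(q) and Bernoulli(p), using 0 * log(0) = 0."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)  # keep the logs finite at the endpoints
    total = 0.0
    if q > 0:
        total += q * math.log(q / p)
    if q < 1:
        total += (1 - q) * math.log((1 - q) / (1 - p))
    return total


def kl_inverse_upper(q, c, iters=100):
    """Largest p in [q, 1] with kl(q || p) <= c, found by bisection (kl grows with p > q)."""
    lo, hi = q, 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if binary_kl(q, mid) > c:
            hi = mid
        else:
            lo = mid
    return lo


def hoeffding_upper(q, n, delta):
    """Hoeffding upper bound on the true error, holding with probability >= 1 - delta."""
    return q + math.sqrt(math.log(1 / delta) / (2 * n))


def kl_upper(q, n, delta):
    """Chernoff (KL-inverse) test-set upper bound, holding with probability >= 1 - delta."""
    return kl_inverse_upper(q, math.log(1 / delta) / n)


if __name__ == "__main__":
    q_hat, delta = 0.05, 0.01  # illustrative values only
    for n in [100, 1000, 20000]:
        print(f"n={n:6d}  Hoeffding: {hoeffding_upper(q_hat, n, delta):.4f}"
              f"  KL: {kl_upper(q_hat, n, delta):.4f}")
```

Reducing --num_val therefore speeds up a test run at the cost of a looser upper bound.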

Mini-Wiki Example:


Run the following to generate the dataset:

python data_generators/

Run the following to produce the results from Example 2 in the paper.

python --method maml      --prior train --trials full --verbose True
python --method fli_batch --prior train --trials full --verbose True
python --method mr_maml   --prior train --trials full --num_val 20000 --verbose True
python --method pac_bus   --prior train --trials full --num_val 20000 --verbose True

NME Omniglot Example:

Run the following to produce the results from Example 3 in the paper, repeating each command with --seed set to each value from 1 through 5 (the commands below show --seed 1). This will automatically download the Omniglot dataset if you do not have it. A GPU is recommended, but you may specify the option --gpu -1 to use the CPU for all computations.

python --method maml       --k_spt 1 --k_qry 4 --batch 16 --nme True --epochsm 100000 --lrm 0.005 --lrb 0.1 --seed 1
python --method maml       --k_spt 5 --k_qry 5 --batch 16 --nme True --epochsm 100000 --lrm 0.005 --lrb 0.1 --seed 1

python --method fli_online --k_spt 1 --k_qry 4 --batch 16 --nme True --epochsm 100000 --lrm 0.001 --seed 1  
python --method fli_online --k_spt 5 --k_qry 5 --batch 16 --nme True --epochsm 100000 --lrm 0.001 --seed 1

python --method mr_maml_w  --k_spt 1 --k_qry 4 --batch 16 --nme True --epochsm 100000 --lrm 0.001 --lrb 0.5 --regscale 2e-7 --seed 1
python --method mr_maml_w  --k_spt 5 --k_qry 5 --batch 16 --nme True --epochsm 100000 --lrm 0.001 --lrb 0.5 --regscale 2e-7 --seed 1

python --method pac_bus_h  --k_spt 1 --k_qry 4 --batch 16 --nme True --epochsm 100000 --lrm 0.001 --lrb 0.5 --regscale 1e-3 --regscale2 10.0 --seed 1
python --method pac_bus_h  --k_spt 5 --k_qry 5 --batch 16 --nme True --epochsm 100000 --lrm 0.001 --lrb 0.5 --regscale 1e-4 --regscale2 10.0 --seed 1
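Because each configuration is run with --seed 1 through 5, a small driver can generate the runs instead of editing the seed by hand. The Python sketch below is hypothetical: SCRIPT is a placeholder, since the script name does not appear in the commands above, and only the two MAML configurations are listed in CONFIGS (the other rows can be added by copying their flags). By default it only prints the commands; pass dry_run=False to run_all to execute them.

```python
import shlex
import subprocess

# Hypothetical placeholder: replace with the actual training script for this example.
SCRIPT = "SCRIPT.py"

# Flags copied from the commands above, without --seed (it is appended per run).
CONFIGS = {
    "maml_1shot": "--method maml --k_spt 1 --k_qry 4 --batch 16 --nme True "
                  "--epochsm 100000 --lrm 0.005 --lrb 0.1",
    "maml_5shot": "--method maml --k_spt 5 --k_qry 5 --batch 16 --nme True "
                  "--epochsm 100000 --lrm 0.005 --lrb 0.1",
}
SEEDS = range(1, 6)  # --seed 1 through 5


def build_commands(script, configs, seeds):
    """Return one argument list per (config, seed) pair, config-major order."""
    commands = []
    for flags in configs.values():
        for seed in seeds:
            commands.append(["python", script, *shlex.split(flags), "--seed", str(seed)])
    return commands


def run_all(commands, dry_run=True):
    """Print each command; execute it (stopping on failure) only when dry_run is False."""
    for cmd in commands:
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(build_commands(SCRIPT, CONFIGS, SEEDS))
```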


