vsimkus/vae-conditional-sampling
[TMLR] Research code for the paper "Conditional Sampling of Variational Autoencoders via Iterated Approximate Ancestral Sampling".
Conditional Sampling of Variational Autoencoders via Iterated Approximate Ancestral Sampling
This repository contains the research code for
Vaidotas Simkus, Michael U. Gutmann. Conditional Sampling of Variational Autoencoders via Iterated Approximate Ancestral Sampling. Transactions on Machine Learning Research, 2023.
The paper can be found here: https://openreview.net/forum?id=I5sJ6PU6JN.
The code is shared for reproducibility purposes and is not intended for production use. It should also serve as a reference implementation for anyone wanting to use LAIR or AC-MWG for conditional sampling of VAEs (e.g. for missing data imputation with pre-trained VAEs).
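The simplest way to impute with a pre-trained VAE ignores the observed values altogether. The repository provides this as an unconditional baseline in ./helpers/create_marginal_vae_imputations.py. The minimal sketch below shows that idea: draw z from the prior, decode, and keep the observed entries. It assumes a hypothetical toy VAE with a scalar latent z ~ N(0, 1) and a linear-Gaussian decoder with independent dimensions. The names W, SIGMA and marginal_impute are illustrative and are not taken from the helper script.

```python
import random

# Hypothetical toy VAE (illustration only, not the repository's models):
# scalar latent z ~ N(0, 1) and a decoder with independent Gaussian dimensions
# x_d | z ~ N(W[d] * z, SIGMA^2).
W = (2.0, -1.0)
SIGMA = 0.5


def marginal_impute(x, observed, rng):
    """Unconditional imputation: sample z ~ p(z) and x_hat ~ p(x | z), then keep
    the observed entries of x and fill the missing ones from x_hat."""
    z = rng.gauss(0.0, 1.0)  # latent drawn from the prior, ignoring the observed values
    x_hat = [rng.gauss(w_d * z, SIGMA) for w_d in W]
    return [x_d if obs else xh_d for x_d, xh_d, obs in zip(x, x_hat, observed)]


rng = random.Random(0)
# x_1 = 2.0 is observed and x_2 is missing (None is only a placeholder).
imputations = [marginal_impute([2.0, None], [True, False], rng) for _ in range(20000)]
x2_mean = sum(x[1] for x in imputations) / len(imputations)
print(x2_mean)
```

Because z never depends on x_1, the imputed x_2 follows the marginal p(x_2), whose mean is 0 in this toy model. The exact conditional p(x_2 | x_1 = 2) has mean -2·2/4.25 ≈ -0.94. This mismatch is why conditional samplers such as MWG, LAIR and AC-MWG are needed.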
Abstract
Conditional sampling of variational autoencoders (VAEs) is needed in various applications, such as missing data imputation, but is computationally intractable. A principled choice for asymptotically exact conditional sampling is Metropolis-within-Gibbs (MWG). However, we observe that the tendency of VAEs to learn a structured latent space, a commonly desired property, can cause the MWG sampler to get “stuck” far from the target distribution. This paper mitigates the limitations of MWG: we systematically outline the pitfalls in the context of VAEs, propose two original methods that address these pitfalls, and demonstrate an improved performance of the proposed methods on a set of sampling tasks.
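In the form commonly used for VAEs, the Metropolis-within-Gibbs (MWG) sampler mentioned in the abstract alternates two steps. First, it performs an independence Metropolis-Hastings update of the latent z, using the encoder as the proposal. Second, it draws the missing values exactly from the decoder. The self-contained sketch below implements that scheme in plain Python. It assumes a hypothetical toy VAE with a scalar latent and a linear-Gaussian decoder. Its encoder returns the exact posterior mean with an inflated standard deviation, mimicking an imperfect inference network. The names W, SIGMA, encoder and mwg_impute are illustrative and do not come from the repository. Because the toy model is linear-Gaussian, the exact conditional p(x_2 | x_1) is known: for x_1 = 2, its mean is -2·2/4.25 ≈ -0.94 and its variance is 1.25 - 4/4.25 ≈ 0.31. That makes the output easy to check.

```python
import math
import random

# Hypothetical toy VAE (illustration only): scalar latent z ~ N(0, 1) and a
# decoder with independent dimensions x_d | z ~ N(W[d] * z, SIGMA^2).
W = (2.0, -1.0)
SIGMA = 0.5


def log_normal(x, mean, std):
    """Log-density of N(mean, std^2) at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def log_joint(z, x):
    """log p(z) + log p(x | z) for a fully filled-in x."""
    return log_normal(z, 0.0, 1.0) + sum(log_normal(x_d, w_d * z, SIGMA) for x_d, w_d in zip(x, W))


def encoder(x):
    """Hypothetical encoder q(z | x): exact posterior mean of the toy model with a
    standard deviation inflated by 1.5, mimicking an imperfect inference network."""
    precision = 1.0 + sum(w_d * w_d for w_d in W) / SIGMA ** 2
    mean = sum(w_d * x_d for w_d, x_d in zip(W, x)) / (SIGMA ** 2 * precision)
    return mean, 1.5 / math.sqrt(precision)


def mwg_impute(x_obs, num_iters, rng):
    """MWG with x[0] = x_obs observed and x[1] missing; returns the chain of x[1] values."""
    z = rng.gauss(0.0, 1.0)  # arbitrary initial state
    x_mis = rng.gauss(0.0, 1.0)
    chain = []
    for _ in range(num_iters):
        x = (x_obs, x_mis)
        # Independence Metropolis-Hastings update of z targeting p(z | x_obs, x_mis),
        # with the encoder q(z | x_obs, x_mis) as the proposal.
        q_mean, q_std = encoder(x)
        z_prop = rng.gauss(q_mean, q_std)
        log_ratio = (log_joint(z_prop, x) - log_joint(z, x)
                     + log_normal(z, q_mean, q_std) - log_normal(z_prop, q_mean, q_std))
        if rng.random() < math.exp(min(0.0, log_ratio)):
            z = z_prop
        # Exact Gibbs update of the missing value: x_mis ~ p(x_mis | z).
        x_mis = rng.gauss(W[1] * z, SIGMA)
        chain.append(x_mis)
    return chain


rng = random.Random(0)
samples = mwg_impute(2.0, 20000, rng)[1000:]  # discard 1000 burn-in draws
print(sum(samples) / len(samples))
```

In this well-behaved toy model the chain mixes easily. The abstract's point is that the structured latent spaces of real VAEs can make the same scheme get stuck.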
Dependencies
Install the Python dependencies with conda, and install the irwg project package, with:
conda env create -f environment.yml
conda activate irwg
python setup.py develop
If the dependencies in environment.yml change, update the environment with:
conda env update --file environment.yml
Organisation of the code
- ./irwg/data/ contains data loaders and missingness generators.
- ./irwg/models/ contains the neural network model implementations.
- ./irwg/sampling/ contains the code related to VAE sampling.
  - test_step_vae_sampling.py contains the implementations of the methods in the paper. (Note: some method names differ from those in the paper.)
    - LAIR is implemented in a class called TestVAELatentAdaptiveImportanceResampling.
    - AC-MWG is implemented in a class called TestVAEAdaptiveCollapsedMetropolisWithinGibbs.
- ./configs/ contains the YAML configuration files with all the information about each experiment.
- ./helpers/ contains various helper scripts for the analysis of the imputations.
  - compute_mnist_mog_posterior_probs.py computes the metrics on MNIST-GMM data.
  - eval_large_uci_joint_imputed_dataset_divergences.py computes the metrics on UCI data and stores them in a file.
  - eval_omniglot_joint_imputed_dataset_fids.py computes the metrics on Omniglot data and stores them in a file.
  - create_marginal_vae_imputations.py creates imputations by sampling the marginal of the VAE (i.e. an unconditional imputation baseline).
  - Configs for the helper scripts are also located in the ./configs/ directory.
- ./notebooks/ contains analysis notebooks that produce the figures in the paper, using the outputs from the helper scripts.
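LAIR, implemented in TestVAELatentAdaptiveImportanceResampling, performs importance resampling in the latent space rather than Metropolis-Hastings updates. The rough sketch below illustrates that idea as we read it from the method's name and the paper. It is not the repository's code and leaves out its implementation details and options. Like the other sketches here, it assumes a hypothetical toy VAE: a scalar latent, a linear-Gaussian decoder that factorises over dimensions, and an encoder with inflated variance. The names W, SIGMA, encoder and lair_sketch are illustrative. Each iteration proceeds as follows:
- Build an adaptive mixture proposal with one encoder component q(z | x_obs, x_mis^(k)) per particle.
- Draw one latent from each component.
- Weight each latent by p(z) p(x_obs | z) divided by the mixture density. This is a deterministic-mixture importance sampling estimator.
- Resample the latents.
- Draw fresh missing values from the decoder.

```python
import math
import random

# Hypothetical toy VAE (illustration only): scalar latent z ~ N(0, 1) and a
# decoder with independent dimensions x_d | z ~ N(W[d] * z, SIGMA^2).
W = (2.0, -1.0)
SIGMA = 0.5


def log_normal(x, mean, std):
    """Log-density of N(mean, std^2) at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def encoder(x):
    """Hypothetical encoder q(z | x): exact posterior mean with inflated std."""
    precision = 1.0 + sum(w_d * w_d for w_d in W) / SIGMA ** 2
    mean = sum(w_d * x_d for w_d, x_d in zip(W, x)) / (SIGMA ** 2 * precision)
    return mean, 1.5 / math.sqrt(precision)


def log_mean_exp(values):
    """Numerically stable log((1/n) * sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def lair_sketch(x_obs, num_particles, num_iters, rng):
    """Simplified latent importance resampling with x[0] = x_obs observed and
    x[1] missing. Returns one list of imputed x[1] values per iteration."""
    particles = [rng.gauss(0.0, 1.0) for _ in range(num_particles)]  # initial guesses of x_mis
    history = []
    for _ in range(num_iters):
        # Adaptive mixture proposal: one encoder component per current particle.
        components = [encoder((x_obs, x_mis)) for x_mis in particles]
        # Deterministic-mixture sampling: draw one latent from each component.
        proposals = [rng.gauss(m, s) for m, s in components]
        log_weights = []
        for z in proposals:
            # Target p(z) p(x_obs | z): the missing dimension integrates out
            # because the toy decoder factorises over dimensions.
            log_target = log_normal(z, 0.0, 1.0) + log_normal(x_obs, W[0] * z, SIGMA)
            log_proposal = log_mean_exp([log_normal(z, m, s) for m, s in components])
            log_weights.append(log_target - log_proposal)
        # Resample latents in proportion to their (self-normalised) importance weights.
        max_lw = max(log_weights)
        weights = [math.exp(lw - max_lw) for lw in log_weights]
        latents = rng.choices(proposals, weights=weights, k=num_particles)
        # Ancestral step: draw new missing values from the decoder, x_mis ~ p(x_mis | z).
        particles = [rng.gauss(W[1] * z, SIGMA) for z in latents]
        history.append(particles)
    return history


rng = random.Random(0)
history = lair_sketch(2.0, 100, 60, rng)
pooled = [x_mis for particles in history[10:] for x_mis in particles]  # drop 10 warm-up iterations
print(sum(pooled) / len(pooled))
```

After a short warm-up, the pooled imputations should approximate the exact conditional p(x_2 | x_1 = 2) of the toy model, which has mean ≈ -0.94 and variance ≈ 0.31. Any remaining discrepancy comes from the finite number of particles.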
Running the code
Activate the conda environment
conda activate irwg
VAE training
To train the VAE that is later used for sampling, run, e.g.,
python train.py --config=configs/mnist_gmm/vae_convresnet3.yaml
VAE sampling
Then, to sample from the trained VAE using one of the methods, run, e.g.,
python test.py --config=configs/mnist_gmm/samples/vae_convresnet3_k4_irwg_i1_dmis_gr_mult_replenish1_finalresample.yaml
Analysis helper scripts
For MNIST-GMM data, use ./helpers/compute_mnist_mog_posterior_probs.py to compute the metrics and store them in a file, and then plot them in a notebook.
Similarly, for UCI data use ./helpers/eval_large_uci_joint_imputed_dataset_divergences.py to compute the metrics, and then plot them in a notebook.