Official code for HH-VAEM

Overview

HH-VAEM

This repository contains the official PyTorch implementation of the Hierarchical Hamiltonian VAE for Mixed-type Data (HH-VAEM) model and the sampling-based feature acquisition technique presented in the paper Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo. HH-VAEM is a hierarchical VAE for mixed-type incomplete data that uses Hamiltonian Monte Carlo with automatic hyper-parameter tuning for improved approximate inference. The repository includes both the model implementation and the experiments reported in the paper.

If you use this code, please cite the preprint:

@article{peis2022missing,
  title={Missing Data Imputation and Acquisition with Deep Hierarchical Models and Hamiltonian Monte Carlo},
  author={Peis, Ignacio and Ma, Chao and Hern{\'a}ndez-Lobato, Jos{\'e} Miguel},
  journal={arXiv preprint arXiv:2202.04599},
  year={2022}
}

Installation

To install, run the following command, which creates a conda virtual environment named HH-VAEM from the provided environment.yml file:

conda env create -f environment.yml

Usage

Training

The project is built on the PyTorch Lightning research framework. The HH-VAEM model is implemented as a LightningModule and trained with a Trainer. To train a model, run:

# Example for training HH-VAEM on Boston dataset
python train.py --model HHVAEM --dataset boston --split 0

This will automatically download the boston dataset, divide it into 10 train/test splits, and train HH-VAEM on training split 0. Two folders will be created: data/ for storing the datasets and logs/ for model checkpoints and TensorBoard logs. To change the directory where these folders are created, modify the variable LOGDIR in src/configs.py (this can be useful for avoiding overload on network file systems).
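To make the role of the --split index concrete, here is a minimal sketch of one way to build 10 reproducible train/test splits. It is not the repository's splitting code: the function name make_splits, the 10% test fraction, and the seeding scheme are all assumptions chosen for illustration.

```python
import random


def make_splits(n_samples, n_splits=10, test_fraction=0.1, seed=0):
    """Return a list of (train_indices, test_indices) pairs, one per split.

    Hypothetical helper: the test fraction and seeding are assumptions,
    not values taken from the repository.
    """
    n_test = max(1, int(round(n_samples * test_fraction)))
    splits = []
    for split_id in range(n_splits):
        # One seeded generator per split, so a given split index always
        # yields the same partition regardless of the other splits.
        rng = random.Random(seed + split_id)
        indices = list(range(n_samples))
        rng.shuffle(indices)
        test_idx = sorted(indices[:n_test])
        train_idx = sorted(indices[n_test:])
        splits.append((train_idx, test_idx))
    return splits


splits = make_splits(n_samples=506)  # the Boston housing data has 506 rows
train_idx, test_idx = splits[0]      # plays the role of --split 0
print(len(train_idx), len(test_idx))
```

Fixing the seed per split index means that training and testing, run as separate commands, would see the same partition for a given split number.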

The following datasets are available:

  • A total of 10 UCI datasets: avocado, boston, energy, wine, diabetes, concrete, naval, yatch, bank and insurance.
  • The MNIST datasets: mnist or fashion_mnist.
  • More datasets can be easily added to src/datasets.py.

For each dataset, the corresponding parameter configuration must be added to src/configs.py.

The following models are also available (implemented in src/models/):

  • HHVAEM: the proposed model in the paper.
  • VAEM: the VAEM strategy presented in (Ma et al., 2020) with Gaussian encoder (without including the Partial VAE).
  • HVAEM: A Hierarchical VAEM with two layers of latent variables and a Gaussian encoder.
  • HMCVAEM: A VAEM that includes a tuned HMC sampler for the true posterior.
  • For the MNIST datasets (non-heterogeneous data), use HHVAE, VAE, HVAE and HMCVAE.
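Several of the models above rely on a Hamiltonian Monte Carlo (HMC) sampler. As a reminder of what one HMC transition involves, here is a minimal sketch on a toy one-dimensional standard normal target. It is not the sampler in src/models/: the target, the function names, and the hand-picked step size and number of leapfrog steps are assumptions, whereas the paper tunes the sampler's hyper-parameters automatically and targets the posterior over latent variables.

```python
import math
import random


def log_prob(q):
    """Log-density of the toy target, a standard normal (up to a constant)."""
    return -0.5 * q * q


def grad_log_prob(q):
    """Gradient of log_prob with respect to q."""
    return -q


def hmc_step(q, step_size, n_leapfrog, rng):
    """One HMC transition: draw a momentum, run leapfrog, accept or reject."""
    p = rng.gauss(0.0, 1.0)
    current_h = -log_prob(q) + 0.5 * p * p
    q_new, p_new = q, p
    # Leapfrog integrator: half momentum step, full position steps
    # interleaved with full momentum steps, then a final half momentum step.
    p_new += 0.5 * step_size * grad_log_prob(q_new)
    for i in range(n_leapfrog):
        q_new += step_size * p_new
        if i < n_leapfrog - 1:
            p_new += step_size * grad_log_prob(q_new)
    p_new += 0.5 * step_size * grad_log_prob(q_new)
    proposed_h = -log_prob(q_new) + 0.5 * p_new * p_new
    # Metropolis correction for the integrator's discretization error.
    accept_prob = math.exp(min(0.0, current_h - proposed_h))
    if rng.random() < accept_prob:
        return q_new, True
    return q, False


rng = random.Random(0)
q, samples, accepted = 3.0, [], 0
for _ in range(2000):
    q, ok = hmc_step(q, step_size=0.2, n_leapfrog=10, rng=rng)
    samples.append(q)
    accepted += ok

mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"acceptance={accepted / len(samples):.2f} mean={mean:.2f} var={var:.2f}")
```

The acceptance rate depends strongly on the step size and trajectory length, which is why tuning those hyper-parameters matters in practice.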

By default, the test stage runs at the end of the training stage. To skip it, pass --test 0 when training; the test can then be run manually using:

# Example for testing HH-VAEM on Boston dataset
python test.py --model HHVAEM --dataset boston --split 0

This loads the trained model and evaluates it on test split 0 of the boston dataset. Once all the splits have been tested, the average results can be obtained using the script in the run/ folder:

# Example for obtaining the average test results with HH-VAEM on Boston dataset
python test_splits.py --model HHVAEM --dataset boston
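The idea of averaging results over splits can be sketched as follows. This is only an illustration, not the logic of test_splits.py: the helper summarize_splits, the metric names, and the numbers are all made up.

```python
from statistics import mean, stdev


def summarize_splits(results):
    """Aggregate per-split metric dicts into {metric: (mean, std)}.

    Hypothetical helper; `results` holds one dict per split, all with
    the same keys.
    """
    summary = {}
    for name in results[0]:
        values = [r[name] for r in results]
        # The sample standard deviation needs at least two splits.
        spread = stdev(values) if len(values) > 1 else 0.0
        summary[name] = (mean(values), spread)
    return summary


# Placeholder values for three splits, for illustration only.
per_split = [
    {"imputation_error": 0.50, "log_likelihood": -1.20},
    {"imputation_error": 0.60, "log_likelihood": -1.00},
    {"imputation_error": 0.70, "log_likelihood": -1.10},
]

summary = summarize_splits(per_split)
for name, (avg, std) in summary.items():
    print(f"{name}: {avg:.3f} +/- {std:.3f}")
```

Reporting the spread alongside the mean shows how sensitive a model is to the particular train/test partition.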

Experiments

The experiments in the paper can be executed using:

# Example for running the SAIA experiment with HH-VAEM on Boston dataset
python active_learning.py --model HHVAEM --dataset boston --method mi --split 0

# Example for running the OoD experiment using MNIST and Fashion-MNIST as OoD:
python ood.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist --split 0

Once this is executed on all the splits, you can plot the SAIA error curves or obtain the average OoD metrics using the scripts in the run/ folder:

# Example for plotting the SAIA error curves of VAEM and HH-VAEM on Boston dataset
python active_learning_plots.py --models VAEM HHVAEM --dataset boston

# Example for obtaining the average OoD metrics using MNIST and Fashion-MNIST as OoD:
python ood_splits.py --model HHVAEM --dataset mnist --dataset_ood fashion_mnist
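For readers new to sequential feature acquisition, the loop behind an experiment like SAIA can be sketched as a greedy selection: at each step, score every unobserved feature, acquire the best one, and rescore. The code below is a generic illustration, not the implementation in active_learning.py. The function greedy_acquisition and the hand-written toy_reward are assumptions; in the paper the reward is estimated from model samples rather than looked up in a table.

```python
def greedy_acquisition(n_features, observed, reward_fn, budget):
    """Greedily acquire up to `budget` unobserved features.

    reward_fn(candidate, observed) -> float is a placeholder for a
    model-based score; it is re-evaluated after every acquisition because
    the value of a feature can change once others are known.
    """
    observed = set(observed)
    order = []
    for _ in range(budget):
        candidates = [d for d in range(n_features) if d not in observed]
        if not candidates:
            break
        best = max(candidates, key=lambda d: reward_fn(d, observed))
        observed.add(best)
        order.append(best)
    return order


# Toy reward (hypothetical): a fixed usefulness per feature, heavily
# discounted when a redundant partner feature has already been observed.
BASE = [0.1, 0.9, 0.8, 0.3]
REDUNDANT = {1: 2, 2: 1}


def toy_reward(candidate, observed):
    reward = BASE[candidate]
    if REDUNDANT.get(candidate) in observed:
        reward *= 0.1
    return reward


order = greedy_acquisition(n_features=4, observed=[], reward_fn=toy_reward, budget=3)
print(order)
```

In this toy setup feature 2 starts with the second-highest reward but is never acquired within the budget, because observing its redundant partner (feature 1) first makes it nearly worthless.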


Help

Use the --help option for documentation on the usage of any of the mentioned scripts.

Contributors

Ignacio Peis
Chao Ma
José Miguel Hernández-Lobato

Contact

For further information: [email protected]

Owner
Ignacio Peis
PhD student at UC3M \\ Visitor at the Machine Learning Group, CBL, University of Cambridge