PyTorch implementation of Algorithm 1 of "On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models"

Overview

Code for On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models

This repository will reproduce the main results from our paper:

On the Anatomy of MCMC-Based Maximum Likelihood Learning of Energy-Based Models
Erik Nijkamp*, Mitch Hill*, Tian Han, Song-Chun Zhu, and Ying Nian Wu (*equal contributions)
https://arxiv.org/abs/1903.12370
AAAI 2020.

The files train_data.py and train_toy.py are PyTorch-based implementations of Algorithm 1 for image datasets and toy 2D distributions respectively. Both files will measure and plot the diagnostic values $d_{s_t}$ and $r_t$ described in Section 3 during training. The file eval.py will sample from a saved checkpoint using either unadjusted Langevin dynamics or Metropolis-Hastings adjusted Langevin dynamics. We provide an appendix ebm-anatomy-appendix.pdf that contains further practical considerations and empirical observations.

Config Files

The folder config_locker has several JSON files that reproduce different convergent and non-convergent learning outcomes for image datasets and toy distributions. Config files for evaluation of pre-trained networks are also included. The files data_config.json, toy_config.json, and eval_config.json fully explain the parameters for train_data.py, train_toy.py, and eval.py respectively.

Executable Files

To run an experiment with train_data.py, train_toy.py, or eval.py, just specify a name for the experiment folder and the location of the JSON config file:

# directory for experiment results
EXP_DIR = './name_of/new_folder/'
# json file with experiment config
CONFIG_FILE = './path_to/config.json'

before execution.

Other Files

Network structures are located in nets.py. A download function for Oxford Flowers 102 data, plotting functions, and a toy dataset class can be found in utils.py.

Diagnostics

Energy Difference and Langevin Gradient Magnitude: Both image and toy experiments will plot $d_{s_t}$ and $r_t$ (see Section 3) over training along with correlation plots as in Figure 4 (with ACF rather than PACF).

Landscape Plots: Toy experiments will plot the density and log-density (negative energy) for ground-truth, learned energy, and short-run models. Kernel density estimation is used to obtain the short-run density.

Short-Run MCMC Samples: Image data experiments will periodically visualize the short-run MCMC samples. A batch of persistent MCMC samples will also be saved for implementations that use persistent initialization for short-run sampling.

Long-Run MCMC Samples: Image data experiments have the option to obtain long-run MCMC samples during training. When log_longrun is set to true in a data config file, the training implementation will generate long-run MCMC samples at a frequency determined by log_longrun_freq. The appearance of long-run MCMC samples indicates whether the energy function assigns probability mass in realistic regions of the image space.

Pre-trained Networks

A convergent pre-trained network and non-convergent pre-trained network for the Oxford Flowers 102 dataset are available in the Releases section of the repository. The config files eval_flowers_convergent.json and eval_flowers_convergent_mh.json are set up to evaluate flowers_convergent_net.pth. The config file eval_flowers_nonconvergent.json is set up to evaluate flowers_nonconvergent_net.pth.

Contact

Please contact Mitch Hill ([email protected]) or Erik Nijkamp ([email protected]) for any questions.

You might also like...
Re-implementation of the Noise Contrastive Estimation algorithm for pyTorch, following "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." (Gutmann and Hyvarinen, AISTATS 2010)

Noise Contrastive Estimation for pyTorch Overview This repository contains a re-implementation of the Noise Contrastive Estimation algorithm, implemen

ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch
ppo_pytorch_cpp - an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch

PPO Pytorch C++ This is an implementation of the proximal policy optimization algorithm for the C++ API of Pytorch. It uses a simple TestEnvironment t

PyTorch implementation of DreamerV2 model-based RL algorithm

PyDreamer Reimplementation of DreamerV2 model-based RL algorithm in PyTorch. The official DreamerV2 implementation can be found here. Features ... Run

PyTorch implementation of the implicit Q-learning algorithm (IQL)
PyTorch implementation of the implicit Q-learning algorithm (IQL)

Implicit-Q-Learning (IQL) PyTorch implementation of the implicit Q-learning algorithm IQL (Paper) Currently only implemented for online learning. Offl

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning"

PyTorch Implementation of the SuRP algorithm by the authors of the AISTATS 2022 paper "An Information-Theoretic Justification for Model Pruning".

A pytorch reprelication of the model-based reinforcement learning algorithm MBPO
A pytorch reprelication of the model-based reinforcement learning algorithm MBPO

Overview This is a re-implementation of the model-based RL algorithm MBPO in pytorch as described in the following paper: When to Trust Your Model: Mo

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.
An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

An algorithm that handles large-scale aerial photo co-registration, based on SURF, RANSAC and PyTorch autograd.

Implements pytorch code for the Accelerated SGD algorithm.

AccSGD This is the code associated with Accelerated SGD algorithm used in the paper On the insufficiency of existing momentum schemes for Stochastic O

PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).
PyGAD, a Python 3 library for building the genetic algorithm and training machine learning algorithms (Keras & PyTorch).

PyGAD: Genetic Algorithm in Python PyGAD is an open-source easy-to-use Python 3 library for building the genetic algorithm and optimizing machine lear

Comments
  • Step size in Langevin Dynamics

    Step size in Langevin Dynamics

    Hi, in your code, when you do the langevin dynamics, you run x_s_t.data += - f_prime + config['epsilon'] * t.randn_like(x_s_t) However, does this mean that the step size for the gradient f_prim is 1? Should we run x_s_t.data += - 0.5*config['epsilon']**2*f_prime + config['epsilon'] * t.randn_like(x_s_t) instead?

    opened by XavierXiao 1
Releases(v1.0)
Owner
Mitch Hill
Assistant Professor of Statistics and Data Science at UCF
Mitch Hill
Source code and data from the RecSys 2020 article "Carousel Personalization in Music Streaming Apps with Contextual Bandits" by W. Bendada, G. Salha and T. Bontempelli

Carousel Personalization in Music Streaming Apps with Contextual Bandits - RecSys 2020 This repository provides Python code and data to reproduce expe

Deezer 48 Jan 02, 2023
Geometric Sensitivity Decomposition

Geometric Sensitivity Decomposition This repo is the official implementation of A Geometric Perspective towards Neural Calibration via Sensitivity Dec

16 Dec 26, 2022
Official implementation for paper: A Latent Transformer for Disentangled Face Editing in Images and Videos.

A Latent Transformer for Disentangled Face Editing in Images and Videos Official implementation for paper: A Latent Transformer for Disentangled Face

InterDigital 108 Dec 09, 2022
Efficient neural networks for analog audio effect modeling

micro-TCN Efficient neural networks for audio effect modeling

Christian Steinmetz 94 Dec 29, 2022
keyframes-CNN-RNN(action recognition)

keyframes-CNN-RNN(action recognition) Environment: python=3.7 pytorch=1.2 Datasets: Following the format of UCF101 action recognition. Run steps: Mo

4 Feb 09, 2022
PyTorch implementation of CVPR'18 - Perturbative Neural Networks

This is an attempt to reproduce results in Perturbative Neural Networks paper. See original repo for details.

Michael Klachko 57 May 14, 2021
Repo for "TableParser: Automatic Table Parsing with Weak Supervision from Spreadsheets" at [email protected]

TableParser Repo for "TableParser: Automatic Table Parsing with Weak Supervision from Spreadsheets" at DS3 Lab 11 Dec 13, 2022

Code accompanying the paper Say As You Wish: Fine-grained Control of Image Caption Generation with Abstract Scene Graphs (Chen et al., CVPR 2020, Oral).

Say As You Wish: Fine-grained Control of Image Caption Generation with Abstract Scene Graphs This repository contains PyTorch implementation of our pa

Shizhe Chen 178 Dec 29, 2022
This repository contains the code for our fast polygonal building extraction from overhead images pipeline.

Polygonal Building Segmentation by Frame Field Learning We add a frame field output to an image segmentation neural network to improve segmentation qu

Nicolas Girard 186 Jan 04, 2023
GBIM(Gesture-Based Interaction map)

手势交互地图 GBIM(Gesture-Based Interaction map),基于视觉深度神经网络的交互地图,通过电脑摄像头观察使用者的手势变化,进而控制地图进行简单的交互。网络使用PaddleX提供的轻量级模型PPYOLO Tiny以及MobileNet V3 small,使得整个模型大小约10MB左右,即使在CPU下也能快速定位和识别手势。

8 Feb 10, 2022
The implementation of FOLD-R++ algorithm

FOLD-R-PP The implementation of FOLD-R++ algorithm. The target of FOLD-R++ algorithm is to learn an answer set program for a classification task. Inst

13 Dec 23, 2022
Code for Reciprocal Adversarial Learning for Brain Tumor Segmentation: A Solution to BraTS Challenge 2021 Segmentation Task

BRATS 2021 Solution For Segmentation Task This repo contains the supported pytorch code and configuration files to reproduce 3D medical image segmenta

Himashi Amanda Peiris 6 Sep 15, 2022
A texturizer that I just made. Nothing special here.

texturizer This is a little project that I did with an hour's time. It texturizes an image given a image and a texture to texturize it with. There is

1 Nov 11, 2021
A commany has recently introduced a new type of bidding, the average bidding, as an alternative to the bid given to the current maximum bidding

Business Problem A commany has recently introduced a new type of bidding, the average bidding, as an alternative to the bid given to the current maxim

Kübra Bilinmiş 1 Jan 15, 2022
Links to works on deep learning algorithms for physics problems, TUM-I15 and beyond

Links to works on deep learning algorithms for physics problems, TUM-I15 and beyond

Nils Thuerey 1.3k Jan 08, 2023
SparseInst: Sparse Instance Activation for Real-Time Instance Segmentation, CVPR 2022

SparseInst 🚀 A simple framework for real-time instance segmentation, CVPR 2022 by Tianheng Cheng, Xinggang Wang†, Shaoyu Chen, Wenqiang Zhang, Qian Z

Hust Visual Learning Team 458 Jan 05, 2023
Object detection, 3D detection, and pose estimation using center point detection:

Objects as Points Object detection, 3D detection, and pose estimation using center point detection: Objects as Points, Xingyi Zhou, Dequan Wang, Phili

Xingyi Zhou 6.7k Jan 03, 2023
This is an open solution to the Home Credit Default Risk challenge 🏡

Home Credit Default Risk: Open Solution This is an open solution to the Home Credit Default Risk challenge 🏡 . More competitions 🎇 Check collection

minerva.ml 427 Dec 27, 2022
Official implementation of "Motif-based Graph Self-Supervised Learning forMolecular Property Prediction"

Motif-based Graph Self-Supervised Learning for Molecular Property Prediction Official Pytorch implementation of NeurIPS'21 paper "Motif-based Graph Se

zaixi 71 Dec 20, 2022
Use tensorflow to implement a Deep Neural Network for real time lane detection

LaneNet-Lane-Detection Use tensorflow to implement a Deep Neural Network for real time lane detection mainly based on the IEEE IV conference paper "To

MaybeShewill-CV 1.9k Jan 08, 2023