A new plug-and-play method of controlling an existing generative model with conditioning attributes and their compositions.

Controllable and Compositional Generation with Latent-Space Energy-Based Models

Python 3.8, PyTorch 1.7.1, torchdiffeq 0.2.1

Official PyTorch implementation of the NeurIPS 2021 paper:
Controllable and Compositional Generation with Latent-Space Energy-Based Models
Weili Nie, Arash Vahdat, Anima Anandkumar
https://nvlabs.github.io/LACE

Abstract: Controllable generation is one of the key requirements for successful adoption of deep generative models in real-world applications, but it still remains as a great challenge. In particular, the compositional ability to generate novel concept combinations is out of reach for most current models. In this work, we use energy-based models (EBMs) to handle compositional generation over a set of attributes. To make them scalable to high-resolution image generation, we introduce an EBM in the latent space of a pre-trained generative model such as StyleGAN. We propose a novel EBM formulation representing the joint distribution of data and attributes together, and we show how sampling from it is formulated as solving an ordinary differential equation (ODE). Given a pre-trained generator, all we need for controllable generation is to train an attribute classifier. Sampling with ODEs is done efficiently in the latent space and is robust to hyperparameters. Thus, our method is simple, fast to train, and efficient to sample. Experimental results show that our method outperforms the state-of-the-art in both conditional sampling and sequential editing. In compositional generation, our method excels at zero-shot generation of unseen attribute combinations. Also, by composing energy functions with logical operators, this work is the first to achieve such compositionality in generating photo-realistic images of resolution 1024x1024.

Requirements

  • Linux and Windows are supported, but we recommend Linux for performance and compatibility reasons.
  • 1 high-end NVIDIA GPU with at least 24 GB of memory. We have done all testing and development using a single NVIDIA V100 GPU with memory size 32 GB.
  • 64-bit Python 3.8.
  • CUDA 10.0 and Docker must be installed first.
  • Installation of the required library dependencies with Docker:
    docker build -f lace-cuda-10p0.Dockerfile --tag=lace-cuda-10-0:0.0.1 .
    docker run -it -d --gpus 0 --name lace --shm-size 8G -v $(pwd):/workspace -p 5001:6006 lace-cuda-10-0:0.0.1
    docker exec -it lace bash
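Before building the image, it can help to confirm that the host meets the basic requirements listed above. The following is a minimal sketch, not part of the repository: the helper name `check_environment`, the choice of tools to look for, and the use of `>=` for the Python version check are all assumptions made for illustration.

```python
import shutil
import sys


def check_environment(required_python=(3, 8), tools=("docker", "nvidia-smi")):
    """Return a dict mapping each check name to True (found) or False (missing).

    Hypothetical helper: it only checks that the interpreter is at least
    `required_python` and that each tool is on the PATH. It does not verify
    the CUDA version or the amount of GPU memory.
    """
    results = {"python": sys.version_info[:2] >= tuple(required_python)}
    for tool in tools:
        results[tool] = shutil.which(tool) is not None
    return results


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'ok' if ok else 'missing'}")
```

A missing `nvidia-smi` usually means the NVIDIA driver is not installed on the host, in which case the `--gpus` option of `docker run` is unlikely to work.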

Experiments on CIFAR-10

The CIFAR10 folder contains the code for reproducing the main results on the CIFAR-10 dataset; its scripts folder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., the latent code and label pairs) from here and unzip it to the CIFAR10 folder. Or you can go to the folder CIFAR10/prepare_data and follow the instructions to generate the data.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

In the script run_clf.sh, the variable x can be set to w or z to train the latent classifier in the w-space or z-space of StyleGAN, respectively.
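To illustrate what "training a classifier on latent code and label pairs" means, here is a toy sketch in plain Python. It is not the repository's classifier: the real one is a PyTorch model trained on StyleGAN latents, whereas this example fits a logistic regression on synthetic 2-D "latents". All names (`train_latent_classifier`, `predict`, `make_toy_pairs`) and hyperparameters are hypothetical.

```python
import math
import random


def sigmoid(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)


def train_latent_classifier(latents, labels, epochs=200, lr=0.5):
    """Fit binary logistic regression with full-batch gradient descent."""
    dim, n = len(latents[0]), len(latents)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for z, y in zip(latents, labels):
            # err is the derivative of the cross-entropy loss w.r.t. the logit.
            err = sigmoid(sum(wi * zi for wi, zi in zip(w, z)) + b) - y
            grad_w = [g + err * zi for g, zi in zip(grad_w, z)]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict(w, b, z):
    """Return 1 if the logit is positive, else 0."""
    return 1 if sum(wi * zi for wi, zi in zip(w, z)) + b > 0 else 0


def make_toy_pairs(n=200, dim=2, seed=0):
    """Synthetic stand-in for (latent, label) pairs: label = sign of z[0]."""
    rng = random.Random(seed)
    latents = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
    labels = [1 if z[0] > 0 else 0 for z in latents]
    return latents, labels


if __name__ == "__main__":
    latents, labels = make_toy_pairs()
    w, b = train_latent_classifier(latents, labels)
    acc = sum(predict(w, b, z) == y for z, y in zip(latents, labels)) / len(labels)
    print(f"training accuracy: {acc:.2f}")
```

The point of the sketch is only that the classifier sees latent codes, never images, which is why training is cheap compared with training a generator.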

Sampling

To get the conditional sampling results with the ODE or Langevin dynamics (LD) sampler, you can run:

# ODE
bash scripts/run_cond_ode_sample.sh

# LD
bash scripts/run_cond_ld_sample.sh

By default, x is set to w, which selects the w-space classifier, because we find that our method works best in w-space. For comparison, you can change x to z or i to use the classifier in z-space or pixel space, respectively.
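The difference between the two samplers can be sketched on a one-dimensional toy energy. This is an illustration under stated assumptions, not the repository's implementation: the energy `E(z) = 0.5*z**2 - log sigmoid(scale*target*z)` stands in for a Gaussian latent prior plus a latent classifier, the ODE is a plain gradient flow `dz/dt = -dE/dz` integrated with fixed-step Euler (the paper's ODE is time-dependent, and the repository lists torchdiffeq as a dependency), and the names `grad_energy`, `langevin_sample`, and `euler_ode_sample` are hypothetical.

```python
import math
import random


def grad_energy(z, target=1.0, scale=4.0):
    """Gradient of the toy energy E(z) = 0.5*z**2 - log sigmoid(scale*target*z).

    `target` is +1 or -1 and selects which side of zero the toy
    "attribute classifier" prefers.
    """
    a = scale * target
    s = 1.0 / (1.0 + math.exp(-a * z))
    return z - a * (1.0 - s)


def langevin_sample(z0, n_steps=200, step=0.01, target=1.0, rng=None):
    """Unadjusted Langevin dynamics: gradient step plus Gaussian noise."""
    if rng is None:
        rng = random.Random(0)
    z = z0
    for _ in range(n_steps):
        noise = rng.gauss(0.0, 1.0)
        z = z - step * grad_energy(z, target) + math.sqrt(2.0 * step) * noise
    return z


def euler_ode_sample(z0, n_steps=200, dt=0.01, target=1.0):
    """Deterministic gradient-flow ODE dz/dt = -dE/dz, fixed-step Euler."""
    z = z0
    for _ in range(n_steps):
        z = z - dt * grad_energy(z, target)
    return z


if __name__ == "__main__":
    rng = random.Random(0)
    starts = [rng.gauss(0.0, 1.0) for _ in range(5)]
    print("ODE:", [round(euler_ode_sample(z0), 3) for z0 in starts])
    print("LD: ", [round(langevin_sample(z0, rng=rng), 3) for z0 in starts])
```

In this toy setting the ODE sampler is deterministic given its starting point, while the LD sampler injects fresh noise at every step, so its results depend on the step size and the random seed.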

To compute the conditional accuracy (ACC) and FID scores in conditional sampling with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_score.sh

# LD
bash scripts/run_cond_ld_score.sh

Note that:

  1. For the ACC evaluation, you need a pre-trained image classifier, which can be downloaded as instructed here;

  2. For the FID evaluation, you need to have the FID reference statistics computed beforehand. You can go to the folder CIFAR10/prepare_data and follow the instructions to compute the FID reference statistics with real images sampled from CIFAR-10.
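As a rough guide to what the two scores measure, the sketch below assumes that ACC is the fraction of generated images whose label, as predicted by the pre-trained image classifier, matches the label used for conditioning. It also shows the Fréchet distance between two one-dimensional Gaussians, the scalar special case of the formula behind FID. The real FID compares multivariate Gaussian statistics (mean and covariance) of Inception features, which is why the reference statistics must be computed beforehand. Both function names are hypothetical and are not taken from the repository.

```python
def conditional_accuracy(predicted_labels, target_labels):
    """Fraction of samples whose predicted label equals the conditioning label."""
    if len(predicted_labels) != len(target_labels):
        raise ValueError("predicted_labels and target_labels differ in length")
    if not predicted_labels:
        raise ValueError("need at least one sample")
    matches = sum(p == t for p, t in zip(predicted_labels, target_labels))
    return matches / len(predicted_labels)


def frechet_distance_1d(mu_a, sigma_a, mu_b, sigma_b):
    """Squared Frechet distance between N(mu_a, sigma_a^2) and N(mu_b, sigma_b^2).

    In one dimension the general formula
    |mu_a - mu_b|^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2))
    reduces to (mu_a - mu_b)^2 + (sigma_a - sigma_b)^2.
    """
    return (mu_a - mu_b) ** 2 + (sigma_a - sigma_b) ** 2


if __name__ == "__main__":
    print(conditional_accuracy([0, 1, 2, 2], [0, 1, 1, 2]))
    print(frechet_distance_1d(1.0, 2.0, 0.0, 1.0))
```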

Experiments on FFHQ

The FFHQ folder contains the code for reproducing the main results on the FFHQ dataset; its scripts folder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here (originally from StyleFlow) and unzip it to the FFHQ folder.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., glasses) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

First, you have to get the pre-trained StyleGAN2 (config-f) by following the instructions in Convert StyleGAN2 weight from official checkpoints.

Conditional sampling

To get the conditional sampling results with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_sample.sh

# LD
bash scripts/run_cond_ld_sample.sh

To compute the conditional accuracy (ACC) and FID scores in conditional sampling with the ODE or LD sampler, you can run:

# ODE
bash scripts/run_cond_ode_score.sh

# LD
bash scripts/run_cond_ld_score.sh

Note that:

  1. For the ACC evaluation, you need to train an FFHQ image classifier, as instructed here;

  2. For the FID evaluation, you need to have the FID reference statistics computed beforehand. You can go to the folder FFHQ/prepare_models_data and follow the instructions to compute the FID reference statistics with the StyleGAN generated FFHQ images.

Sequential editing

To get the qualitative and quantitative results of sequential editing, you can run:

# User-specified sampling
bash scripts/run_seq_edit_sample.sh

# ACC and FID
bash scripts/run_seq_edit_score.sh

Note that:

  • As with conditional sampling, you first need to train an FFHQ image classifier (for ACC) and compute the FID reference statistics (for FID) by following the respective instructions.

  • To get the face identity preservation (ID) score, you first need to download the pre-trained ArcFace network, which is publicly available here, to the folder FFHQ/pretrained/metrics.

Compositional generation

To get the results of zero-shot generation on novel attribute combinations, you can run:

bash scripts/run_zero_shot.sh

To get the results of composing energy functions with logical operators, you can run:

bash scripts/run_combine_energy.sh
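The idea of combining energy functions with logical operators can be sketched with the composition rules commonly used for energy-based models: a product of densities for AND (energies add), a mixture of densities for OR (a negative log-sum-exp of negated energies), and an inverted density for NOT (a negated, optionally weighted energy). The sketch assumes these standard rules. The exact form and weighting used by run_combine_energy.sh may differ, and the function names are hypothetical.

```python
import math


def energy_and(*energies):
    """Conjunction: p(z) proportional to the product of exp(-E_i), so energies add."""
    return sum(energies)


def energy_or(*energies):
    """Disjunction: p(z) proportional to the sum of exp(-E_i).

    Returns -log(sum_i exp(-E_i)), shifted by the minimum energy for
    numerical stability.
    """
    m = min(energies)
    return m - math.log(sum(math.exp(-(e - m)) for e in energies))


def energy_not(energy, weight=1.0):
    """Negation: invert the density by flipping the sign of the energy."""
    return -weight * energy


if __name__ == "__main__":
    # Hypothetical per-attribute energies of a single latent code.
    e_glasses, e_smile, e_beard = 0.2, 1.5, 3.0
    # glasses AND (smile OR NOT beard)
    combined = energy_and(e_glasses, energy_or(e_smile, energy_not(e_beard)))
    print(round(combined, 4))
```

A low combined energy means the latent code satisfies the logical expression well, so a sampler that follows the gradient of the combined energy is pushed toward latents that satisfy it.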

Experiments on MetFaces

The MetFaces folder contains the code for reproducing the main results on the MetFaces dataset; its scripts folder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here and unzip it to the MetFaces folder. Or you can go to the folder MetFaces/prepare_data and follow the instructions to generate the data.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., yaw) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

To get the conditional sampling and sequential editing results, you can run:

# conditional sampling
bash scripts/run_cond_sample.sh

# sequential editing
bash scripts/run_seq_edit_sample.sh

Experiments on AFHQ-Cats

The AFHQ folder contains the code for reproducing the main results on the AFHQ-Cats dataset; its scripts folder contains the bash scripts needed to run the code.

Data preparation

Before running the code, you have to download the data (i.e., 10k pairs of latent variables and labels) from here and unzip it to the AFHQ folder. Or you can go to the folder AFHQ/prepare_data and follow the instructions to generate the data.

Training

To train the latent classifier, you can run:

bash scripts/run_clf.sh

Note that each att_name (e.g., breeds) in run_clf.sh corresponds to a separate attribute classifier.

Sampling

To get the conditional sampling and sequential editing results, you can run:

# conditional sampling
bash scripts/run_cond_sample.sh

# sequential editing
bash scripts/run_seq_edit_sample.sh

License

Please check the LICENSE file. This work may be used non-commercially, meaning for research or evaluation purposes only. For business inquiries, please contact [email protected].

Citation

Please cite our paper if you use this codebase:

@inproceedings{nie2021controllable,
  title={Controllable and compositional generation with latent-space energy-based models},
  author={Nie, Weili and Vahdat, Arash and Anandkumar, Anima},
  booktitle={Neural Information Processing Systems (NeurIPS)},
  year={2021}
}
Owner
NVIDIA Research Projects