Second-Order Neural ODE Optimizer, NeurIPS 2021 spotlight


Second-order Neural ODE Optimizer
(NeurIPS 2021 Spotlight) [arXiv]

✔️ faster convergence in wall-clock time | ✔️ O(1) memory cost |
✔️ better test-time performance | ✔️ architecture co-optimization

This repo provides the PyTorch implementation of the Second-order Neural ODE Optimizer (SNOpt), a second-order optimizer for training Neural ODEs that retains O(1) memory cost while achieving faster convergence and better test-time performance.

[Figure: SNOpt results]

Installation

This code is developed with Python 3. It requires PyTorch >= 1.7 (we recommend 1.8.1) and torchdiffeq >= 0.2.0.

  1. Install the dependencies with Anaconda and activate the environment snopt with
    conda env create --file requirements.yaml python=3
    conda activate snopt
  2. [Optional] This repo provides a modification (just 15 lines!) of torchdiffeq that allows SNOpt to collect 2nd-order information during adjoint-based training. If you wish to use torchdiffeq at a different commit, copy its folder into this directory and then apply the provided snopt_integration.patch.
    cp -r <path_to_your_torchdiffeq_folder> .
    git apply snopt_integration.patch
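Before running anything, it can help to confirm the version requirements above. The sketch below is a hypothetical helper (not part of this repo) that uses only the standard library's importlib.metadata (Python 3.8+). Its version parsing is deliberately simple and ignores pre-release tags, so packaging.version is the better choice for strict checks.

```python
import itertools
from importlib.metadata import PackageNotFoundError, version

# Minimum versions stated in this README
MINIMUM = {"torch": "1.7", "torchdiffeq": "0.2.0"}


def parse_release(v):
    """Leading numeric release components, e.g. "1.8.1+cu111" -> (1, 8, 1).

    Local (+...) and pre-release suffixes are ignored.
    """
    parts = []
    for piece in v.split("+")[0].split("."):
        digits = "".join(itertools.takewhile(str.isdigit, piece))
        if not digits:
            break
        parts.append(int(digits))
        if digits != piece:  # e.g. "0rc1": keep the numeric prefix, then stop
            break
    return tuple(parts)


def meets_minimum(installed, required):
    """Compare release tuples after padding them to the same length with zeros."""
    a, b = parse_release(installed), parse_release(required)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements():
    """Map each required package to True if installed at or above its minimum."""
    report = {}
    for pkg, minimum in MINIMUM.items():
        try:
            report[pkg] = meets_minimum(version(pkg), minimum)
        except PackageNotFoundError:
            report[pkg] = False
    return report


if __name__ == "__main__":
    print(check_requirements())
```

Running it inside the activated snopt environment should report True for both packages.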

Run the code

We provide example code for 8 datasets across image classification (main_img_clf.py), time-series prediction (main_time_series.py), and continuous normalizing flows (main_cnf.py). The commands that reproduce results similar to those in our paper are in the scripts folder. Datasets are downloaded automatically to the data folder on the first run, and all results are saved to the result folder.

bash scripts/run_img_clf.sh     <dataset> # dataset can be {mnist, svhn, cifar10}
bash scripts/run_time_series.sh <dataset> # dataset can be {char-traj, art-wr, spo-ad}
bash scripts/run_cnf.sh         <dataset> # dataset can be {miniboone, gas}
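If you want to sweep all eight datasets, a small driver like the following can generate the same bash commands. This is a sketch, not a script shipped with the repo; it assumes it is run from the repository root and only prints the commands unless dry_run=False is passed.

```python
import shlex
import subprocess

# Script -> dataset choices, as listed in this README
EXPERIMENTS = {
    "scripts/run_img_clf.sh": ["mnist", "svhn", "cifar10"],
    "scripts/run_time_series.sh": ["char-traj", "art-wr", "spo-ad"],
    "scripts/run_cnf.sh": ["miniboone", "gas"],
}


def build_commands():
    """One `bash <script> <dataset>` argument list per dataset."""
    return [["bash", script, dataset]
            for script, datasets in EXPERIMENTS.items()
            for dataset in datasets]


def run_all(dry_run=True):
    """Print every command; execute them sequentially only when dry_run=False."""
    for cmd in build_commands():
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
```

Calling run_all() first is a cheap way to check the list before launching long training jobs.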

For architecture (specifically integration time) co-optimization, run

bash scripts/run_img_clf.sh cifar10-t1-optimize

Integration with your workflow

snopt integrates easily into existing training workflows. Below we provide a handy checklist and example code to help with the integration. For more complex examples, please refer to main_*.py in this repo.

  • Import a torchdiffeq that has been patched with the snopt integration, or simply use the torchdiffeq in this repo.
  • Subclass snopt.ODEFuncBase for your vector field, and implement its forward pass in F rather than forward.
  • Build your Neural ODE from ODE layer(s) created with snopt.ODEBlock, and implement the properties odes and ode_mods.
  • Initialize snopt.SNOpt as a preconditioner; call its train_itr_setup() before the standard optim.zero_grad(), and its step() before optim.step() (see the code below).
  • That's it 🤓 ! Enjoy your second-order training 🚂 🚅 !
import torch
from torchdiffeq import odeint_adjoint as odesolve  # the snopt-patched torchdiffeq
from snopt import SNOpt, ODEFuncBase, ODEBlock
from easydict import EasyDict as edict  # avoid shadowing the built-in dict

input_dim = 2  # example state dimension; set this to your data's dimension

class ODEFunc(ODEFuncBase):
    def __init__(self, opt):
        super(ODEFunc, self).__init__(opt)
        self.linear = torch.nn.Linear(input_dim, input_dim)

    def F(self, t, z):  # vector field dz/dt = F(t, z); implement F, not forward
        return self.linear(z)

class NeuralODE(torch.nn.Module):
    def __init__(self, ode):
        super(NeuralODE, self).__init__()
        self.ode = ode

    def forward(self, z):
        return self.ode(z)

    @property
    def odes(self): # in case we have multiple odes, collect them in a list
        return [self.ode]

    @property
    def ode_mods(self): # modules of all ode(s)
        return [mod for mod in self.ode.odefunc.modules()]

# Create Neural ODE
opt = edict(
    optimizer='SNOpt', tol=1e-3, ode_solver='dopri5', use_adaptive_t1=False, snopt_step_size=0.01)
odefunc = ODEFunc(opt)
integration_time = torch.tensor([0.0, 1.0]).float()
ode = ODEBlock(opt, odefunc, odesolve, integration_time)
net = NeuralODE(ode)

# Create SNOpt optimizer (pass the whole network, not net.parameters())
precond = SNOpt(net, eps=0.05, update_freq=100)
optim = torch.optim.SGD(net.parameters(), lr=0.001)

# Training loop: training_loader and loss_function are your own data loader and loss
for (x, y) in training_loader:
    precond.train_itr_setup() # <--- additional step for precond
    optim.zero_grad()

    loss = loss_function(net(x), y)
    loss.backward()

    # Run SNOpt optimizer
    precond.step()            # <--- additional step for precond
    optim.step()

What the library actually contains

This snopt library implements the following objects for efficient 2nd-order adjoint-based training of Neural ODEs.

  • ODEFuncBase: Defines the vector field of the Neural ODE (inherits torch.nn.Module).
  • CNFFuncBase: Serves the same purpose as ODEFuncBase, but for CNF applications.
  • ODEBlock: A Neural ODE module (torch.nn.Module) that solves the initial value problem (given the vector field, integration time, and an ODE solver) and handles integration-time co-optimization with the feedback policy.
  • SNOpt: Our primary 2nd-order optimizer (torch.optim.Optimizer), implemented as a "preconditioner" (see example code above). It takes the following arguments.
    • net is the Neural ODE. Note that the entire network (rather than net.parameters()) is required.
    • eps is the regularization that stabilizes preconditioning. We recommend a value in [0.05, 0.1].
    • update_freq is how often the 2nd-order information is refreshed. We recommend a value between 100 and 200.
    • alpha sets the running average of the eigenvalues. We recommend fixing it to 0.75.
    • full_precond determines whether layers outside the Neural ODE(s) are also preconditioned.
  • SNOptAdjointCollector: A helper to collect information from torchdiffeq to construct 2nd-order matrices.
  • IntegrationTimeOptimizer: Our 2nd-order method that co-optimizes the integration time (i.e., t1). This is done by calling t1_train_itr_setup(train_it) and update_t1() together with optim.zero_grad() and optim.step() (see trainer.py).
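To give a feel for what eps, update_freq, and alpha control, here is a minimal NumPy sketch of eigenvalue-corrected Kronecker-factored (EKFAC-style) preconditioning for a single linear layer. It is a generic illustration under stated assumptions, not SNOpt's implementation: SNOpt builds its 2nd-order matrices from adjoint information collected along the ODE integration, whereas this sketch uses ordinary layer inputs and output gradients, and the class name is hypothetical.

```python
import numpy as np


class EKFACStyleLayerPreconditioner:
    """Hypothetical EKFAC-style preconditioner for ONE linear layer (sketch only).

    eps, update_freq and alpha mirror SNOpt's argument names, but this is a
    generic illustration, not SNOpt's code.
    """

    def __init__(self, eps=0.05, update_freq=100, alpha=0.75):
        self.eps = eps                  # damping added to every eigenvalue
        self.update_freq = update_freq  # steps between eigenbasis refreshes
        self.alpha = alpha              # running-average weight of old eigenvalues
        self.n_steps = 0
        self.Q_A = self.Q_G = self.scale = None

    def _refresh_basis(self, acts, grads_out):
        # Kronecker factors A = E[a a^T] (inputs) and G = E[g g^T] (output grads)
        n = acts.shape[0]
        _, self.Q_A = np.linalg.eigh(acts.T @ acts / n)
        _, self.Q_G = np.linalg.eigh(grads_out.T @ grads_out / n)
        self.scale = None  # old eigenvalue estimates belong to the old basis

    def step(self, grad_W, acts, grads_out):
        """Precondition grad_W (out, in) given inputs acts (batch, in)
        and output gradients grads_out (batch, out)."""
        if self.n_steps % self.update_freq == 0:
            self._refresh_basis(acts, grads_out)
        self.n_steps += 1

        # Second moment of per-example gradients g_i a_i^T in the eigenbasis
        a_proj = acts @ self.Q_A       # (batch, in)
        g_proj = grads_out @ self.Q_G  # (batch, out)
        new_scale = (g_proj ** 2).T @ (a_proj ** 2) / acts.shape[0]  # (out, in)
        if self.scale is None:
            self.scale = new_scale
        else:
            self.scale = self.alpha * self.scale + (1 - self.alpha) * new_scale

        # Rotate into the eigenbasis, divide by damped eigenvalues, rotate back
        V = self.Q_G.T @ grad_W @ self.Q_A
        return self.Q_G @ (V / (self.scale + self.eps)) @ self.Q_A.T
```

Here a larger eps damps the preconditioning toward a scaled gradient step, and alpha close to 1 makes the eigenvalue estimates change slowly between refreshes.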

The options are passed in as opt and contain the following fields (see options.py for full descriptions).

  • optimizer is the training method. Use "SNOpt" to enable our method.
  • ode_solver specifies the ODE solver (default is "dopri5") with the absolute/relative tolerance tol.
  • For CNF applications, use divergence_type to specify how divergence should be computed.
  • snopt_step_size sets the step size at which SNOpt samples points along the integration to compute the 2nd-order matrices. We recommend 0.01 for the integration time [0, 1], which yields around 100 sampled points.
  • For integration time (t1) co-optimization, enable the flag use_adaptive_t1 and set up the following options.
    • adaptive_t1 specifies the t1 optimization method. Choices are "baseline" and "feedback" (ours).
    • t1_lr is the learning rate. We recommend a value in [0.05, 0.1].
    • t1_reg is the coefficient of the quadratic penalty imposed on t1. Performance is quite sensitive to this value. We recommend a value in [1e-4, 1e-3].
    • t1_update_freq is how often t1 is updated. We recommend a value between 50 and 100.
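The arithmetic behind the snopt_step_size recommendation, and the roles of t1_lr and t1_reg, can be sketched as follows. The t1 update here is a plain first-order gradient step on the end time with an assumed penalty of the form t1_reg * t1**2; it is not the paper's 2nd-order "feedback" method, and both function names are hypothetical.

```python
def expected_sample_points(t0, t1, step_size):
    """Approximate number of points sampled on [t0, t1] at a fixed step size."""
    return round((t1 - t0) / step_size)


def first_order_t1_step(t1, dL_dzT, f_at_t1, t1_lr=0.05, t1_reg=1e-4):
    """One plain gradient-descent step on the end time t1 (hypothetical helper).

    By the chain rule, dL/dt1 = dL/dz(t1) . f(t1, z(t1)), where f is the vector
    field. The assumed penalty t1_reg * t1**2 adds 2 * t1_reg * t1 to it.
    """
    grad = sum(g * f for g, f in zip(dL_dzT, f_at_t1)) + 2.0 * t1_reg * t1
    return t1 - t1_lr * grad
```

For example, expected_sample_points(0.0, 1.0, 0.01) gives 100, matching the recommendation above; doubling t1 to 2.0 at the same step size doubles the number of sampled points.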

Remarks & Citation

The current library supports only adjoint-based training, though it can be extended to the standard odeint method (stay tuned!). The pre-processing of the tabular and UEA datasets is adapted from ffjord and NeuralCDE, and the eigenvalue-regularized preconditioning is adapted from EKFAC-pytorch.
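Adjoint-based training is what keeps memory at O(1) in the number of solver steps: instead of storing the forward trajectory, the state is re-solved backward together with the adjoint. The NumPy sketch below illustrates this standard continuous adjoint method on a linear ODE dz/dt = A z with loss 0.5 * ||z(T)||^2, using fixed-step RK4. It is not torchdiffeq's or SNOpt's code, and all names are hypothetical.

```python
import numpy as np


def rk4(f, y, t0, t1, n_steps):
    """Classic fixed-step Runge-Kutta 4; integrates forward or backward in time."""
    h = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h / 2 * k1)
        k3 = f(t + h / 2, y + h / 2 * k2)
        k4 = f(t + h, y + h * k3)
        y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return y  # only the current state is ever kept


def loss_and_grad_adjoint(A, z0, T=1.0, n_steps=1000):
    """Loss 0.5 * ||z(T)||^2 for dz/dt = A z, and dL/dA via the adjoint method."""
    d = z0.shape[0]
    zT = rk4(lambda t, z: A @ z, z0, 0.0, T, n_steps)  # forward solve, no trajectory stored
    loss = 0.5 * zT @ zT

    def augmented(t, y):
        z, a = y[:d], y[d:2 * d]
        dz = A @ z            # re-solve the state backward in time
        da = -A.T @ a         # adjoint ODE: da/dt = -(df/dz)^T a
        dg = -np.outer(a, z)  # integrated from T down to 0, gives int_0^T a z^T dt
        return np.concatenate([dz, da, dg.ravel()])

    yT = np.concatenate([zT, zT, np.zeros(d * d)])  # a(T) = dL/dz(T) = z(T)
    y0 = rk4(augmented, yT, T, 0.0, n_steps)
    return loss, y0[2 * d:].reshape(d, d)
```

The returned gradient should agree with central finite differences of the forward loss to high precision, while the memory used by the backward pass does not depend on n_steps.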

If you find this library useful, please cite it ⬇️. Contact me if you have any questions!

@inproceedings{liu2021second,
  title={Second-order Neural ODE Optimizer},
  author={Liu, Guan-Horng and Chen, Tianrong and Theodorou, Evangelos A},
  booktitle={Advances in Neural Information Processing Systems},
  year={2021},
}
Owner
Guan-Horng Liu
CMU RI → Uber ATG → GaTech ML