Code for visualizing the loss landscape of neural nets

Overview

Visualizing the Loss Landscape of Neural Nets

This repository contains the PyTorch code for the paper

Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer and Tom Goldstein. Visualizing the Loss Landscape of Neural Nets. NIPS, 2018.

An interactive 3D visualizer for loss surfaces has been provided by telesens.

Given a network architecture and its pre-trained parameters, this tool calculates and visualizes the loss surface along random direction(s) near the optimal parameters. The calculation can be done in parallel with multiple GPUs per node, and multiple nodes. The random direction(s) and loss surface values are stored in HDF5 (.h5) files after they are produced.

Setup

Environment: One or more multi-GPU node(s) with the following software/libraries installed:

Pre-trained models: The code accepts pre-trained PyTorch models for the CIFAR-10 dataset. To load the pre-trained model correctly, the model file should contain state_dict, which is saved from the state_dict() method. The default path for pre-trained networks is cifar10/trained_nets. Some of the pre-trained models and plotted figures can be downloaded here:

Data preprocessing: The data pre-processing method used for visualization should be consistent with the one used for model training. No data augmentation (random cropping or horizontal flipping) is used in calculating the loss values.

Visualizing 1D loss curve

Creating 1D linear interpolations

The 1D linear interpolation method [1] evaluates the loss values along the direction between two minimizers of the same network loss function. This method has been used to compare the flatness of minimizers trained with different batch sizes [2]. A 1D linear interpolation plot is produced using the plot_surface.py method.

mpirun -n 4 python plot_surface.py --mpi --cuda --model vgg9 --x=-0.5:1.5:401 --dir_type states \
--model_file cifar10/trained_nets/vgg9_sgd_lr=0.1_bs=128_wd=0.0_save_epoch=1/model_300.t7 \
--model_file2 cifar10/trained_nets/vgg9_sgd_lr=0.1_bs=8192_wd=0.0_save_epoch=1/model_300.t7 --plot
  • --x=-0.5:1.5:401 sets the range and resolution for the plot. The x-coordinates in the plot will run from -0.5 to 1.5 (the minimizers are located at 0 and 1), and the loss value will be evaluated at 401 locations along this line.
  • --dir_type states indicates the direction contains dimensions for all parameters as well as the statistics of the BN layers (running_mean and running_var). Note that ignoring running_mean and running_var cannot produce correct loss values when plotting two solutions togeather in the same figure.
  • The two model files contain network parameters describing the two distinct minimizers of the loss function. The plot will interpolate between these two minima.

VGG-9 SGD, WD=0

Producing plots along random normalized directions

A random direction with the same dimension as the model parameters is created and "filter normalized." Then we can sample loss values along this direction.

mpirun -n 4 python plot_surface.py --mpi --cuda --model vgg9 --x=-1:1:51 \
--model_file cifar10/trained_nets/vgg9_sgd_lr=0.1_bs=128_wd=0.0_save_epoch=1/model_300.t7 \
--dir_type weights --xnorm filter --xignore biasbn --plot
  • --dir_type weights indicates the direction has the same dimensions as the learned parameters, including bias and parameters in the BN layers.
  • --xnorm filter normalizes the random direction at the filter level. Here, a "filter" refers to the parameters that produce a single feature map. For fully connected layers, a "filter" contains the weights that contribute to a single neuron.
  • --xignore biasbn ignores the direction corresponding to bias and BN parameters (fill the corresponding entries in the random vector with zeros).

VGG-9 SGD, WD=0

We can also customize the appearance of the 1D plots by calling plot_1D.py once the surface file is available.

Visualizing 2D loss contours

To plot the loss contours, we choose two random directions and normalize them in the same way as the 1D plotting.

mpirun -n 4 python plot_surface.py --mpi --cuda --model resnet56 --x=-1:1:51 --y=-1:1:51 \
--model_file cifar10/trained_nets/resnet56_sgd_lr=0.1_bs=128_wd=0.0005/model_300.t7 \
--dir_type weights --xnorm filter --xignore biasbn --ynorm filter --yignore biasbn  --plot

ResNet-56

Once a surface is generated and stored in a .h5 file, we can produce and customize a contour plot using the script plot_2D.py.

python plot_2D.py --surf_file path_to_surf_file --surf_name train_loss
  • --surf_name specifies the type of surface. The default choice is train_loss,
  • --vmin and --vmax sets the range of values to be plotted.
  • --vlevel sets the step of the contours.

Visualizing 3D loss surface

plot_2D.py can make a basic 3D loss surface plot with matplotlib. If you want a more detailed rendering that uses lighting to display details, you can render the loss surface with ParaView.

ResNet-56-noshort ResNet-56

To do this, you must

  1. Convert the surface .h5 file to a .vtp file.
python h52vtp.py --surf_file path_to_surf_file --surf_name train_loss --zmax  10 --log

This will generate a VTK file containing the loss surface with max value 10 in the log scale.

  1. Open the .vtp file with ParaView. In ParaView, open the .vtp file with the VTK reader. Click the eye icon in the Pipeline Browser to make the figure show up. You can drag the surface around, and change the colors in the Properties window.

  2. If the surface appears extremely skinny and needle-like, you may need to adjust the "transforming" parameters in the left control panel. Enter numbers larger than 1 in the "scale" fields to widen the plot.

  3. Select Save screenshot in the File menu to save the image.

Reference

[1] Ian J Goodfellow, Oriol Vinyals, and Andrew M Saxe. Qualitatively characterizing neural network optimization problems. ICLR, 2015.

[2] Nitish Shirish Keskar, Dheevatsa Mudigere, Jorge Nocedal, Mikhail Smelyanskiy, and Ping Tak Peter Tang. On large-batch training for deep learning: Generalization gap and sharp minima. ICLR, 2017.

Citation

If you find this code useful in your research, please cite:

@inproceedings{visualloss,
  title={Visualizing the Loss Landscape of Neural Nets},
  author={Li, Hao and Xu, Zheng and Taylor, Gavin and Studer, Christoph and Goldstein, Tom},
  booktitle={Neural Information Processing Systems},
  year={2018}
}
Owner
Tom Goldstein
Tom Goldstein
Multi Camera Calibration

Multi Camera Calibration 'modules/camera_calibration/app/camera_calibration.cpp' is for calculating extrinsic parameter of each individual cameras. 'm

7 Dec 01, 2022
Lipschitz-constrained Unsupervised Skill Discovery

Lipschitz-constrained Unsupervised Skill Discovery This repository is the official implementation of Seohong Park, Jongwook Choi*, Jaekyeom Kim*, Hong

Seohong Park 17 Dec 18, 2022
Doubly Robust Off-Policy Evaluation for Ranking Policies under the Cascade Behavior Model

Doubly Robust Off-Policy Evaluation for Ranking Policies under the Cascade Behavior Model About This repository contains the code to replicate the syn

Haruka Kiyohara 12 Dec 07, 2022
Auto-Encoding Score Distribution Regression for Action Quality Assessment

DAE-AQA It is an open source program reference to paper Auto-Encoding Score Distribution Regression for Action Quality Assessment. 1.Introduction DAE

13 Nov 16, 2022
Simple API for UCI Machine Learning Dataset Repository (search, download, analyze)

A simple API for working with University of California, Irvine (UCI) Machine Learning (ML) repository Table of Contents Introduction About Page of the

Tirthajyoti Sarkar 223 Dec 05, 2022
Bayesian-Torch is a library of neural network layers and utilities extending the core of PyTorch to enable the user to perform stochastic variational inference in Bayesian deep neural networks

Bayesian-Torch is a library of neural network layers and utilities extending the core of PyTorch to enable the user to perform stochastic variational inference in Bayesian deep neural networks. Bayes

Intel Labs 210 Jan 04, 2023
simple artificial intelligence utilities

Simple AI Project home: http://github.com/simpleai-team/simpleai This lib implements many of the artificial intelligence algorithms described on the b

921 Dec 08, 2022
Awesome-google-colab - Google Colaboratory Notebooks and Repositories

Unofficial Google Colaboratory Notebook and Repository Gallery Please contact me to take over and revamp this repo (it gets around 30k views and 200k

Derek Snow 1.2k Jan 03, 2023
Get started with Machine Learning with Python - An introduction with Python programming examples

Machine Learning With Python Get started with Machine Learning with Python An engaging introduction to Machine Learning with Python TL;DR Download all

Learn Python with Rune 130 Jan 02, 2023
A PyTorch Implementation of ViT (Vision Transformer)

ViT - Vision Transformer This is an implementation of ViT - Vision Transformer by Google Research Team through the paper "An Image is Worth 16x16 Word

Quan Nguyen 7 May 11, 2022
A Multi-attribute Controllable Generative Model for Histopathology Image Synthesis

A Multi-attribute Controllable Generative Model for Histopathology Image Synthesis This is the pytorch implementation for our MICCAI 2021 paper. A Mul

Jiarong Ye 7 Apr 04, 2022
Automated Evidence Collection for Fake News Detection

Automated Evidence Collection for Fake News Detection This is the code repo for the Automated Evidence Collection for Fake News Detection paper accept

Mrinal Rawat 2 Apr 12, 2022
Multiview 3D object detection on MultiviewC dataset through moft3d.

Voxelized 3D Feature Aggregation for Multiview Detection [arXiv] Multiview 3D object detection on MultiviewC dataset through VFA. Introduction We prop

Jiahao Ma 20 Dec 21, 2022
Only a Matter of Style: Age Transformation Using a Style-Based Regression Model

Only a Matter of Style: Age Transformation Using a Style-Based Regression Model The task of age transformation illustrates the change of an individual

444 Dec 30, 2022
The implementation for "Comprehensive Knowledge Distillation with Causal Intervention".

Comprehensive Knowledge Distillation with Causal Intervention This repository is a PyTorch implementation of "Comprehensive Knowledge Distillation wit

Xiang Deng 10 Nov 03, 2022
Re-TACRED: Addressing Shortcomings of the TACRED Dataset

Re-TACRED Re-TACRED: Addressing Shortcomings of the TACRED Dataset

George Stoica 40 Dec 10, 2022
CvT2DistilGPT2 is an encoder-to-decoder model that was developed for chest X-ray report generation.

CvT2DistilGPT2 Improving Chest X-Ray Report Generation by Leveraging Warm-Starting This repository houses the implementation of CvT2DistilGPT2 from [1

The Australian e-Health Research Centre 21 Dec 28, 2022
Unofficial pytorch implementation of 'Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization'

pytorch-AdaIN This is an unofficial pytorch implementation of a paper, Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Hua

Naoto Inoue 873 Jan 06, 2023
Code for Contrastive-Geometry Networks for Generalized 3D Pose Transfer

Code for Contrastive-Geometry Networks for Generalized 3D Pose Transfer

18 Jun 28, 2022
Tools for robust generative diffeomorphic slice to volume reconstruction

RGDSVR Tools for Robust Generative Diffeomorphic Slice to Volume Reconstructions (RGDSVR) This repository provides tools to implement the methods in t

Lucilio Cordero-Grande 0 Oct 29, 2021