Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization

Fishr: Invariant Gradient Variances for Out-of-distribution Generalization

Alexandre Ramé, Corentin Dancette, Matthieu Cord

Abstract

Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been growing interest in learning simultaneously from multiple training domains while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.

In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, is closely related to the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.

Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than Empirical Risk Minimization.
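The regularization described above can be written as follows. This is a sketch in our own notation that reflects our reading of the paper, which remains the reference for the exact formulation: $\mathcal{E}$ is the set of training domains, $\mathcal{R}_e$ the empirical risk on domain $e$, $\boldsymbol{g}^e_i$ the gradient of the loss on sample $i$ of domain $e$, $\boldsymbol{v}_e$ the element-wise (diagonal) variance of these gradients, and $\lambda$ the regularization strength.

```latex
\boldsymbol{v}_e = \operatorname{Var}_{i \in e}\!\left[\boldsymbol{g}^e_i\right],
\qquad
\boldsymbol{v} = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} \boldsymbol{v}_e,
\qquad
\mathcal{L}(\theta) = \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} \mathcal{R}_e(\theta)
  + \lambda \, \frac{1}{|\mathcal{E}|} \sum_{e \in \mathcal{E}} \left\lVert \boldsymbol{v}_e - \boldsymbol{v} \right\rVert_2^2
```

The penalty is zero exactly when every domain has the same gradient variances, which is the invariance the abstract refers to.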

Installation

Requirements overview

Our implementation relies on the BackPACK package for PyTorch to easily compute gradient variances.

  • python == 3.7.10
  • torch == 1.8.1
  • torchvision == 0.9.1
  • backpack-for-pytorch == 1.3.0
  • numpy == 1.20.2
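To make "gradient variances" concrete without depending on BackPACK, the NumPy sketch below computes per-sample gradients of a linear softmax classifier in closed form and takes their element-wise variance over a batch. BackPACK can return such per-sample quantities for PyTorch models during `backward()` through extensions such as `BatchGrad` (individual gradients) and `Variance` (their variance). The helper names here (`per_sample_grads`, `gradient_variance`) are hypothetical and do not come from this repository.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def per_sample_grads(W, X, y):
    """Per-sample gradients of the cross-entropy loss w.r.t. W.

    W: weights of a linear classifier, shape [num_classes, num_features],
       with logits = X @ W.T.
    For softmax cross-entropy, d loss_i / d W = (p_i - onehot(y_i)) outer x_i,
    so this single layer needs no autograd.
    Returns an array of shape [batch, num_classes * num_features].
    """
    probs = softmax(X @ W.T)                      # [batch, classes]
    residual = probs.copy()
    residual[np.arange(len(y)), y] -= 1.0         # p_i - onehot(y_i)
    grads = residual[:, :, None] * X[:, None, :]  # [batch, classes, features]
    return grads.reshape(len(y), -1)


def gradient_variance(W, X, y):
    # Element-wise (diagonal) variance of the per-sample gradients over the batch.
    return per_sample_grads(W, X, y).var(axis=0)
```

Computing one such variance vector per training domain gives the quantities that Fishr matches across domains.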

Procedure

  1. Clone the repo:
$ git clone https://github.com/alexrame/fishr.git
  2. Create a conda environment and install the dependencies using pip:
$ conda create --name fishr python=3.7.10
$ conda activate fishr
$ cd fishr
$ pip install -r requirements.txt

With this, you can edit the Fishr code on the fly.

Overview

This repository enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by IRM and (2) on the DomainBed benchmark.

Colored MNIST in the IRM setup

We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST.

Main results (Table 2 in Section 6.A)

To reproduce the results from Table 2, call python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm, where $algorithm selects one of the methods compared in the table below.

Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal IRM implementation.

    Method | Train acc. | Test acc.  | Gray test acc.
   --------|------------|------------|----------------
    ERM    | 86.4 ± 0.2 | 14.0 ± 0.7 |   71.0 ± 0.7
    IRM    | 71.0 ± 0.5 | 65.6 ± 1.8 |   66.1 ± 0.2
    V-REx  | 71.7 ± 1.5 | 67.2 ± 1.5 |   68.6 ± 2.2
    Fishr  | 71.0 ± 0.9 | 69.5 ± 1.0 |   70.2 ± 1.1

Without label flipping (Table 5 in Appendix C.2.3)

The script coloredmnist/train_coloredmnist.py also accepts the argument --label_flipping_prob, which sets the label flipping probability. It defaults to 0.25; to reproduce the results from Table 5, set --label_flipping_prob 0.
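For context on what --label_flipping_prob controls, here is a minimal NumPy sketch of the Colored MNIST construction as described in the IRM paper: a binary label is derived from whether the digit is below 5, flipped with the label flipping probability, and then a color is assigned that agrees with the noisy label except with a per-environment color flipping probability (0.1 and 0.2 for the training environments, 0.9 for the test environment in that setup). The helper `make_environment` and its signature are hypothetical; coloredmnist/train_coloredmnist.py may differ in details such as image downsampling.

```python
import numpy as np


def make_environment(digits, images, label_flipping_prob, color_flipping_prob, rng):
    """Build one Colored-MNIST-style environment (hypothetical helper).

    digits: int array of MNIST digit labels in [0, 9], shape [n]
    images: float array of grayscale images, shape [n, H, W]
    rng:    a numpy.random.Generator
    """
    n = len(digits)
    # Binary label: whether the digit is below 5.
    labels = (digits < 5).astype(int)
    # Label noise: flip each label with probability label_flipping_prob.
    labels ^= (rng.random(n) < label_flipping_prob).astype(int)
    # The color agrees with the noisy label except with probability color_flipping_prob.
    colors = labels ^ (rng.random(n) < color_flipping_prob).astype(int)
    # Two color channels; zero out the channel that does not match the color.
    colored = np.stack([images, images], axis=1)  # [n, 2, H, W]
    colored[np.arange(n), 1 - colors] = 0.0
    return colored, labels, colors
```

With label_flipping_prob set to 0, the label becomes a deterministic function of the digit shape, so shape is a perfectly predictive feature in every environment.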

Fishr variants (Table 6 in Appendix C.2.4)

This table considers two additional Fishr variants, reproduced with algorithm set to:

  • fishr_offdiagonal for Fishr but on the full covariance rather than only the diagonal
  • fishr_notcentered for Fishr but without centering the gradient variances
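The three variants differ only in which statistic of the per-sample gradients is matched across domains. The NumPy sketch below makes this concrete, given per-sample gradients for each domain (however they are obtained). The function names and variant strings are hypothetical, and the penalty shape (mean squared distance between each domain's statistic and the cross-domain average) reflects our reading of the paper rather than the repository's code.

```python
import numpy as np


def domain_statistic(grads, variant="fishr"):
    """Summary statistic of per-sample gradients for one domain.

    grads: array [num_samples, num_params] of per-sample gradients.
    - "fishr":       diagonal of the centered covariance (per-parameter variance)
    - "notcentered": diagonal of the uncentered second moment E[g * g]
    - "offdiagonal": full centered covariance matrix, flattened
    """
    if variant == "fishr":
        return grads.var(axis=0)
    if variant == "notcentered":
        return (grads ** 2).mean(axis=0)
    if variant == "offdiagonal":
        centered = grads - grads.mean(axis=0, keepdims=True)
        return (centered.T @ centered / len(grads)).ravel()
    raise ValueError(f"unknown variant: {variant}")


def fishr_penalty(grads_per_domain, variant="fishr"):
    # Mean over domains of the squared distance to the average statistic.
    stats = np.stack([domain_statistic(g, variant) for g in grads_per_domain])
    return float(((stats - stats.mean(axis=0)) ** 2).sum(axis=1).mean())
```

A full covariance over d parameters has d² entries, which is why matching only the diagonal is far cheaper for large networks.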

DomainBed

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in "In Search of Lost Domain Generalization". The instructions below are copied and adapted from the official DomainBed GitHub repository.

Algorithms and hyperparameter grids

We added Fishr as a new algorithm in DomainBed and defined its hyperparameter grids as described in Table 7 in Appendix D.

Datasets

We ran Fishr on the following datasets: Colored MNIST, Rotated MNIST, VLCS, PACS, OfficeHome, TerraIncognita and DomainNet.

Launch training

Download the datasets:

python3 -m domainbed.scripts.download\
       --data_dir=/my/data/dir

Train a model for debugging:

python3 -m domainbed.scripts.train\
       --data_dir=/my/data/dir/\
       --algorithm Fishr\
       --dataset ColoredMNIST\
       --test_env 2

Launch a sweep for hyperparameter search:

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/data/dir/\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher\
       --datasets ColoredMNIST\
       --algorithms Fishr

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py.
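MyLauncher above is a placeholder. In upstream DomainBed, to our knowledge, a command launcher is a plain Python function that receives the list of shell commands generated by the sweep, and command_launchers.py maps launcher names to these functions in a `REGISTRY` dictionary; that name is what --command_launcher expects. The sketch below illustrates this shape with a hypothetical `script_launcher` that only writes one shell script per command. Actually submitting the scripts depends on your cluster's scheduler and is left out.

```python
import os
import tempfile


def script_launcher(commands, script_dir=None):
    """Hypothetical launcher: write each command to its own shell script.

    A real cluster launcher would then submit each script to the scheduler;
    here we only write the files and return their paths.
    """
    script_dir = script_dir or tempfile.mkdtemp(prefix="fishr_sweep_")
    paths = []
    for i, cmd in enumerate(commands):
        path = os.path.join(script_dir, f"job_{i:04d}.sh")
        with open(path, "w") as f:
            f.write("#!/bin/bash\n" + cmd + "\n")
        paths.append(path)
    return paths


# Mirrors the name -> function mapping that, to our knowledge, DomainBed's
# command_launchers.py uses; the key is what you would pass to --command_launcher.
REGISTRY = {"script": script_launcher}
```

If your copy follows the upstream layout, you would add such a function to command_launchers.py, register it, and pass its key (here `script`) instead of MyLauncher.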

Performance inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G)

To view the results of your sweep:

python -m domainbed.scripts.collect_results\
       --input_dir=/my/sweep/output/path

We inspect performance using the following model selection criteria, which differ in the data used to choose the best hyperparameters for a given model:

  • OracleSelectionMethod (Oracle): A random subset from the data of the test domain.
  • IIDAccuracySelectionMethod (Training): A random subset from the data of the training domains.
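The two criteria differ only in which validation score ranks the runs of a hyperparameter search. The deliberately simplified sketch below picks the best run under each criterion and returns its test-domain accuracy; the record keys are hypothetical and the numbers are toy values, not results. DomainBed's actual selection classes operate on its own result records and handle more detail (for example, several random trials per hyperparameter setting).

```python
def select(records, criterion):
    """Pick the run with the best selection score and return its test accuracy.

    Each record is a dict with (hypothetical) keys:
      'train_val_acc': accuracy on held-out data from the training domains
      'test_val_acc':  accuracy on a held-out subset of the test domain
      'test_acc':      final accuracy reported on the test domain
    criterion: "training" (IIDAccuracySelectionMethod-like) or "oracle"
    """
    key = {"training": "train_val_acc", "oracle": "test_val_acc"}[criterion]
    best = max(records, key=lambda r: r[key])
    return best["test_acc"]


# Toy runs from a two-point search over a regularization strength.
runs = [
    {"lambda": 10, "train_val_acc": 0.90, "test_val_acc": 0.50, "test_acc": 0.52},
    {"lambda": 1000, "train_val_acc": 0.80, "test_val_acc": 0.70, "test_acc": 0.69},
]
```

Here the training-domain criterion prefers the first run and the oracle prefers the second, which is why oracle numbers are typically higher.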

Critically, Fishr performs better than Empirical Risk Minimization on average under both selection criteria (70.8 vs. 68.7 with Oracle, 67.1 vs. 66.6 with Training selection).

    Model selection | Algorithm | Colored MNIST | Rotated MNIST | VLCS       | PACS       | OfficeHome | TerraIncognita | DomainNet  | Avg
   -----------------|-----------|---------------|---------------|------------|------------|------------|----------------|------------|-----
    Oracle          | ERM       | 57.8 ± 0.2    | 97.8 ± 0.1    | 77.6 ± 0.3 | 86.7 ± 0.3 | 66.4 ± 0.5 | 53.0 ± 0.3     | 41.3 ± 0.1 | 68.7
    Oracle          | Fishr     | 68.8 ± 1.4    | 97.8 ± 0.1    | 78.2 ± 0.2 | 86.9 ± 0.2 | 68.2 ± 0.2 | 53.6 ± 0.4     | 41.8 ± 0.2 | 70.8
    Training        | ERM       | 51.5 ± 0.1    | 98.0 ± 0.0    | 77.5 ± 0.4 | 85.5 ± 0.2 | 66.5 ± 0.3 | 46.1 ± 1.8     | 40.9 ± 0.1 | 66.6
    Training        | Fishr     | 52.0 ± 0.2    | 97.8 ± 0.0    | 77.8 ± 0.1 | 85.5 ± 0.4 | 67.8 ± 0.1 | 47.4 ± 1.6     | 41.7 ± 0.0 | 67.1
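The Avg column appears to be the unweighted mean of the seven per-dataset accuracies. The short check below recomputes it from the numbers in the table; it reproduces 68.7, 70.8, 66.6 and 67.1.

```python
from statistics import mean

# Per-dataset accuracies copied from the table above, in column order:
# Colored MNIST, Rotated MNIST, VLCS, PACS, OfficeHome, TerraIncognita, DomainNet.
rows = {
    ("Oracle", "ERM"): [57.8, 97.8, 77.6, 86.7, 66.4, 53.0, 41.3],
    ("Oracle", "Fishr"): [68.8, 97.8, 78.2, 86.9, 68.2, 53.6, 41.8],
    ("Training", "ERM"): [51.5, 98.0, 77.5, 85.5, 66.5, 46.1, 40.9],
    ("Training", "Fishr"): [52.0, 97.8, 77.8, 85.5, 67.8, 47.4, 41.7],
}

for (selection, algorithm), accs in rows.items():
    # Unweighted mean over datasets, rounded to one decimal like the table.
    print(f"{selection:8} {algorithm:5} {mean(accs):.1f}")
```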

Conclusion

We addressed out-of-distribution generalization for computer vision classification tasks. We derived a new and simple regularization, Fishr, that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performance on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization could consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help using Fishr, please open an issue or contact the authors.

Citation

If you find this code useful for your research, please consider citing our work (under review):

@article{rame2021fishr,
    title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
    author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
    year={2021},
    journal={arXiv preprint arXiv:2109.02934}
}