Understanding the Generalization Benefit of Model Invariance from a Data Perspective

Overview

Understanding the Generalization Benefit of Model Invariance from a Data Perspective

This is the code for our NeurIPS2021 paper "Understanding the Generalization Benefit of Model Invariance from a Data Perspective". There are two major parts in our code: sample covering number estimation and generalization benefit evaluation.

Requirments

  • Python 3.8
  • PyTorch
  • torchvision
  • scikit-learn-extra
  • scipy
  • robustness package (already included in our code)

Our code is based on robustness package.

Dataset

  • CIFAR-10 Download and extract the data into /data/cifar10
  • R2N2 Download the ShapeNet rendered images and put the data into /data/r2n2

The randomly sampled R2N2 images used for computing sample covering numbers and indices of examples for different sample sizes could be found here.

Estimation of sample covering numbers

To estimate the sample covering numbers of different data transformations, run the following script in /scn.

CUDA_VISIBLE_DEVICES=0 python run_scn.py  --epsilon 3 --transformation crop --cover_number_method fast --data-path /path/to/dataset 

Note that the input is a N x C x H x W tensor where N is sample size.

Evaluation of generalization benefit

To train the model with data augmentation method, run the following script in /learn_invariance for R2N2 dataset

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset r2n2 \
    --data ../data/2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --transforms view  \
    --inv-method aug \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name view

or the following script for CIFAR-10 dataset

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset cifar \
    --data ../data/cifar10 \
    --n-per-class all \
    --transforms crop  \
    --inv-method aug \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name crop 

By setting --transforms to be one of {none, flip, crop, rotate, view}, the specific transformation will be considered.

To train the model with regularization method, run the following script. Currently, the code only support 3d-view transformation on R2N2 dataset.

CUDA_VISIBLE_DEVICES=0 python main.py \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --transforms view  \
    --inv-method reg \
    --inv-method-beta 1 \
    --out-dir /path/to/out_dir \
    --arch resnet18 --epoch 110 --lr 1e-2 --step-lr 50 \
    --workers 30 --batch-size 128 --exp-name reg_view 

To evaluate the model with invariance loss and worst-case consistency accuracy, run the following script.

CUDA_VISIBLE_DEVICES=0 python main.py  \
    --dataset r2n2 \
    --data ../data/r2n2/ShapeNetRendering \
    --metainfo-path ../data/r2n2/metainfo_all.json \
    --inv-method reg \
    --arch resnet18 \
    --resume /path/to/checkpoint.pt.best \
    --eval-only 1 \
    --transforms view  \
    --adv-eval 0 \
    --batch-size 2  \
    --no-store 

Note that to have the worst-case consistency accuracy we need to load 24 view images in R2N2RenderingsTorch class in dataset_3d.py.

Owner
PhD student at University of Maryland
Custom studies about block sparse attention.

Block Sparse Attention 研究总结 本人近半年来对Block Sparse Attention(块稀疏注意力)的研究总结(持续更新中)。按时间顺序,主要分为如下三部分: PyTorch 自定义 CUDA 算子——以矩阵乘法为例 基于 Triton 的 Block Sparse A

Chen Kai 2 Jan 09, 2022
Deep motion transfer

animation-with-keypoint-mask Paper The right most square is the final result. Softmax mask (circles): \ Heatmap mask: \ conda env create -f environmen

9 Nov 01, 2022
Semi-SDP Semi-supervised parser for semantic dependency parsing.

Semi-SDP Semi-supervised parser for semantic dependency parsing. This repo contains the code used for the semi-supervised semantic dependency parser i

12 Sep 17, 2021
(ICCV 2021) ProHMR - Probabilistic Modeling for Human Mesh Recovery

ProHMR - Probabilistic Modeling for Human Mesh Recovery Code repository for the paper: Probabilistic Modeling for Human Mesh Recovery Nikos Kolotouros

Nikos Kolotouros 209 Dec 13, 2022
This is the code for our paper "Iconary: A Pictionary-Based Game for Testing Multimodal Communication with Drawings and Text"

Iconary This is the code for our paper "Iconary: A Pictionary-Based Game for Testing Multimodal Communication with Drawings and Text". It includes the

AI2 6 May 24, 2022
[Open Source]. The improved version of AnimeGAN. Landscape photos/videos to anime

[Open Source]. The improved version of AnimeGAN. Landscape photos/videos to anime

CC 4.4k Dec 27, 2022
Neural Re-rendering for Full-frame Video Stabilization

NeRViS: Neural Re-rendering for Full-frame Video Stabilization Project Page | Video | Paper | Google Colab Setup Setup environment for [Yu and Ramamoo

Yu-Lun Liu 9 Jun 17, 2022
Implementation for paper LadderNet: Multi-path networks based on U-Net for medical image segmentation

Implementation for paper LadderNet: Multi-path networks based on U-Net for medical image segmentation This implementation is based on orobix implement

Juntang Zhuang 116 Sep 06, 2022
A Rao-Blackwellized Particle Filter for 6D Object Pose Tracking

PoseRBPF: A Rao-Blackwellized Particle Filter for 6D Object Pose Tracking PoseRBPF Paper Self-supervision Paper Pose Estimation Video Robot Manipulati

NVIDIA Research Projects 107 Dec 25, 2022
An Efficient Implementation of Analytic Mesh Algorithm for 3D Iso-surface Extraction from Neural Networks

AnalyticMesh Analytic Marching is an exact meshing solution from neural networks. Compared to standard methods, it completely avoids geometric and top

Karbo 45 Dec 21, 2022
PyTorch implementation of UPFlow (unsupervised optical flow learning)

UPFlow: Upsampling Pyramid for Unsupervised Optical Flow Learning By Kunming Luo, Chuan Wang, Shuaicheng Liu, Haoqiang Fan, Jue Wang, Jian Sun Megvii

kunming luo 87 Dec 20, 2022
PICK: Processing Key Information Extraction from Documents using Improved Graph Learning-Convolutional Networks

Code for the paper "PICK: Processing Key Information Extraction from Documents using Improved Graph Learning-Convolutional Networks" (ICPR 2020)

Wenwen Yu 498 Dec 24, 2022
EMNLP 2021 Adapting Language Models for Zero-shot Learning by Meta-tuning on Dataset and Prompt Collections

Adapting Language Models for Zero-shot Learning by Meta-tuning on Dataset and Prompt Collections Ruiqi Zhong, Kristy Lee*, Zheng Zhang*, Dan Klein EMN

Ruiqi Zhong 42 Nov 03, 2022
Code for our NeurIPS 2021 paper Mining the Benefits of Two-stage and One-stage HOI Detection

CDN Code for our NeurIPS 2021 paper "Mining the Benefits of Two-stage and One-stage HOI Detection". Contributed by Aixi Zhang*, Yue Liao*, Si Liu, Mia

71 Dec 14, 2022
State-Relabeling Adversarial Active Learning

State-Relabeling Adversarial Active Learning Code for SRAAL [2020 CVPR Oral] Requirements torch = 1.6.0 numpy = 1.19.1 tqdm = 4.31.1 AL Results The

10 Jul 14, 2022
A Pytorch implement of paper "Anomaly detection in dynamic graphs via transformer" (TADDY).

TADDY: Anomaly detection in dynamic graphs via transformer This repo covers an reference implementation for the paper "Anomaly detection in dynamic gr

Yue Tan 21 Nov 24, 2022
PyTorch implementation of EigenGAN

PyTorch Implementation of EigenGAN Train python train.py [image_folder_path] --name [experiment name] Test python test.py [ckpt path] --traverse FFH

62 Nov 12, 2022
Semi Supervised Learning for Medical Image Segmentation, a collection of literature reviews and code implementations.

Semi-supervised-learning-for-medical-image-segmentation. Recently, semi-supervised image segmentation has become a hot topic in medical image computin

Healthcare Intelligence Laboratory 1.3k Jan 03, 2023
PyTorch Implementation of Region Similarity Representation Learning (ReSim)

ReSim This repository provides the PyTorch implementation of Region Similarity Representation Learning (ReSim) described in this paper: @Article{xiao2

Tete Xiao 74 Jan 03, 2023
TensorFlow implementation of "Attention is all you need (Transformer)"

[TensorFlow 2] Attention is all you need (Transformer) TensorFlow implementation of "Attention is all you need (Transformer)" Dataset The MNIST datase

YeongHyeon Park 4 Jan 05, 2022