[NeurIPS 2021] “Improving Contrastive Learning on Imbalanced Data via Open-World Sampling”

Introduction

Contrastive learning approaches have achieved great success in learning visual representations with few labels. This implies a tantalizing possibility: scaling them up beyond a curated target benchmark by incorporating more unlabeled images from internet-scale external sources to enhance their performance. In practice, however, a larger amount of unlabeled data requires more compute resources, for bigger models and longer training. Moreover, open-world unlabeled data follow an implicit long-tail distribution over various class attributes, many of which are out of distribution, which can lead to data imbalance issues. This motivates us to seek a principled approach for selecting a subset of unlabeled data from an external source that is relevant for learning better and more diverse representations. In this work, we propose an open-world unlabeled data sampling strategy called Model-Aware K-center (MAK), which follows three simple principles: (1) tailness, which encourages sampling examples from tail classes by sorting samples according to their empirical contrastive loss expectation (ECLE) over random data augmentations; (2) proximity, which rejects out-of-distribution outliers that might distract training; and (3) diversity, which ensures diversity in the set of sampled examples. Empirically, using ImageNet-100-LT (without labels) as the target dataset and two “noisy” external data sources, we demonstrate that MAK can consistently improve both the overall representation quality and the class balancedness of the learned features, as evaluated by linear classifier evaluation in full-shot and few-shot settings.
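To make the tailness principle concrete, here is a minimal pure-Python sketch of ECLE scoring: each sample's contrastive loss is averaged over several random augmentation draws, and samples are ranked by that average, highest first. It assumes a simplified InfoNCE loss (negatives taken only from the other view, temperature `tau`) and a hypothetical `encode_view(x, rng)` callable standing in for "augment, then encode with the pre-trained model"; neither is taken from the MAK code base. The augmented inference pass with `--inference_repeat_time 10` appears to play the role of the repeated draws, but that mapping is our assumption.

```python
import math
import random
from typing import Callable, List, Sequence

Vector = List[float]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + 1e-12)


def info_nce(anchor: Vector, positive: Vector, negatives: List[Vector], tau: float = 0.2) -> float:
    """Simplified InfoNCE: -log of the softmax weight given to the positive pair."""
    logits = [cosine(anchor, positive) / tau] + [cosine(anchor, n) / tau for n in negatives]
    top = max(logits)  # subtract the max before exp() for numerical stability
    log_sum = top + math.log(sum(math.exp(v - top) for v in logits))
    return log_sum - logits[0]


def ecle_scores(
    samples: List[Vector],
    encode_view: Callable[[Vector, random.Random], Vector],  # hypothetical: augment + encode
    repeats: int = 10,
    tau: float = 0.2,
    seed: int = 0,
) -> List[float]:
    """Empirical contrastive loss expectation: mean loss over `repeats` augmentation draws."""
    rng = random.Random(seed)
    totals = [0.0] * len(samples)
    for _ in range(repeats):
        # Two independently augmented views of every sample in this draw.
        view_a = [encode_view(x, rng) for x in samples]
        view_b = [encode_view(x, rng) for x in samples]
        for i in range(len(samples)):
            negatives = [view_b[j] for j in range(len(samples)) if j != i]
            totals[i] += info_nce(view_a[i], view_b[i], negatives, tau)
    return [total / repeats for total in totals]


def rank_by_tailness(scores: List[float]) -> List[int]:
    """Indices sorted by descending ECLE, i.e. the most 'tail-like' samples first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


def toy_encode_view(x: Vector, rng: random.Random) -> Vector:
    """Toy stand-in for a real encoder: one sample gets inconsistent views."""
    if x[0] == 0.0 and x[1] == -1.0:  # the "tail" sample: a random direction every draw
        angle = rng.uniform(0.0, 2.0 * math.pi)
        return [math.cos(angle), math.sin(angle)]
    return [v + rng.gauss(0.0, 0.05) for v in x]  # a well-learned sample: small jitter only


samples = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
scores = ecle_scores(samples, toy_encode_view, repeats=10)
print(rank_by_tailness(scores))  # index of the highest-ECLE sample comes first
```

In the toy run, the sample whose views disagree across augmentations receives the largest average loss, which is the intuition behind using ECLE as a proxy for tail membership.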

Method

[Figure: MAK pipeline]

Environment

Requirements:

pytorch 1.7.1 
opencv-python
kmeans-pytorch 0.3
scikit-learn

Recommended installation commands (Linux):

conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.2 -c pytorch # change cuda version according to hardware
pip install opencv-python
pip install kmeans-pytorch==0.3
conda install -c conda-forge matplotlib scikit-learn
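Before running the scripts, it may help to confirm that the pinned versions are the ones actually installed. The sketch below uses only the standard-library `importlib.metadata` (Python 3.8+). The distribution names (`torch`, `opencv-python`, `kmeans-pytorch`, `scikit-learn`) are our assumption of how the listed requirements appear on PyPI, not something the repository specifies.

```python
from importlib import metadata
from typing import Dict, Optional

# Requirements from the list above; None means "any version is fine".
# Distribution names are assumed PyPI names, not taken from the repository.
REQUIREMENTS: Dict[str, Optional[str]] = {
    "torch": "1.7.1",
    "opencv-python": None,
    "kmeans-pytorch": "0.3",
    "scikit-learn": None,
}


def check_requirements(reqs: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Return a problem description for every requirement that is missing or mismatched."""
    problems: Dict[str, str] = {}
    for name, wanted in reqs.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # CUDA builds of torch may report e.g. "1.7.1+cu102"; compare only the public part.
        public = installed.split("+", 1)[0]
        if wanted is not None and public != wanted:
            problems[name] = f"found {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_requirements(REQUIREMENTS)
    for pkg, issue in issues.items():
        print(f"{pkg}: {issue}")
    print("OK" if not issues else f"{len(issues)} problem(s) found")
```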

Sampling

Prepare

Make the shell scripts executable:

chmod +x cmds/shell_scrips/*

Get a model pre-trained on the long-tailed (LT) target dataset, without additional data:

bash ./cmds/shell_scrips/imagenet-100-add-data.sh -g 2 -p 4866 -w 10 --seed 10 --additional_dataset None

Sampling on ImageNet-900

Inference

Inference on the sampling dataset (ImageNet-900), without augmentation:

bash ./cmds/shell_scrips/imagenet-100-inference.sh -p 5555 --workers 10 --pretrain_seed 10 \
--epochs 1000 --batch_size 256 --inference_dataset imagenet-900 --inference_dataset_split ImageNet_900_train \
--inference_repeat_time 1 --inference_noAug True

Inference on the target dataset (ImageNet-100-LT), without augmentation:

bash ./cmds/shell_scrips/imagenet-100-inference.sh -p 5555 --workers 10 --pretrain_seed 10 \
--epochs 1000 --batch_size 256 --inference_dataset imagenet-100 --inference_dataset_split imageNet_100_LT_train \
--inference_repeat_time 1 --inference_noAug True
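The two no-augmentation passes produce features for both the external pool and the target set. The README does not say exactly how they are consumed; assuming they feed the proximity principle, one simple way to reject out-of-distribution candidates is to score each external feature by its mean cosine similarity to its `k` most similar target features and keep only the top fraction. `proximity_scores`, `reject_outliers`, `k` and `keep_ratio` are hypothetical names and settings for illustration, not the repository's.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + 1e-12)


def proximity_scores(external: List[Vector], target: List[Vector], k: int = 5) -> List[float]:
    """Mean cosine similarity of each external feature to its k most similar target features.

    Assumes a non-empty target set and k >= 1.
    """
    scores = []
    for feat in external:
        sims = sorted((cosine(feat, t) for t in target), reverse=True)
        nearest = sims[:k]
        scores.append(sum(nearest) / len(nearest))
    return scores


def reject_outliers(scores: List[float], keep_ratio: float = 0.8) -> List[int]:
    """Indices (ascending) of the keep_ratio fraction of samples with the highest proximity."""
    n_keep = max(1, int(round(keep_ratio * len(scores))))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:n_keep])


target = [[1.0, 0.0], [0.9, 0.1], [0.8, 0.2]]       # toy target-set features
external = [[1.0, 0.05], [0.95, 0.1], [-1.0, 0.0]]  # toy external pool; the last one is an outlier
print(reject_outliers(proximity_scores(external, target, k=2), keep_ratio=2 / 3))  # → [0, 1]
```

A brute-force scan like this is quadratic in the number of features; for ImageNet-scale pools, a batched matrix product on GPU or an approximate nearest-neighbor index would be the practical choice.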

Inference on the sampling dataset (ImageNet-900), with augmentation (10 repeats):

bash ./cmds/shell_scrips/imagenet-100-inference.sh -p 5555 --workers 10 --pretrain_seed 10 \
--epochs 1000 --batch_size 256 --inference_dataset imagenet-900 --inference_dataset_split ImageNet_900_train \
--inference_repeat_time 10

Sample 10K images from ImageNet-900:

bash ./cmds/shell_scrips/sampling.sh --pretrain_seed 10
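The final step selects a fixed budget (10K here) from the external pool. MAK's diversity principle is named after K-center; below is a minimal sketch of the classic greedy K-center selection (farthest-first traversal), assuming Euclidean distance on feature vectors and, optionally, target-set features as pre-existing centers so that new picks cover regions the target set does not. This is the textbook greedy algorithm, not the logic of `sampling.sh`, which may also weight candidates by ECLE and proximity.

```python
import math
from typing import List, Optional, Sequence

Vector = Sequence[float]


def k_center_greedy(
    pool: List[Vector],
    budget: int,
    initial_centers: Optional[List[Vector]] = None,
) -> List[int]:
    """Farthest-first traversal: repeatedly pick the pool point farthest from all centers.

    Returns indices into `pool`, in the order they were selected.
    """
    if budget <= 0 or not pool:
        return []
    # min_dist[i] = distance from pool[i] to its nearest current center.
    min_dist = [math.inf] * len(pool)
    for center in initial_centers or []:
        for i, point in enumerate(pool):
            min_dist[i] = min(min_dist[i], math.dist(point, center))
    selected: List[int] = []
    for _ in range(min(budget, len(pool))):
        # max() returns the first maximal index, so ties resolve to the lowest index.
        nxt = max(range(len(pool)), key=lambda i: min_dist[i])
        selected.append(nxt)
        for i, point in enumerate(pool):
            min_dist[i] = min(min_dist[i], math.dist(point, pool[nxt]))
        min_dist[nxt] = -math.inf  # never pick the same index twice, even with duplicates
    return selected


pool = [[0.0, 0.5], [10.0, 0.0], [10.0, 0.5], [5.0, 5.0]]
print(k_center_greedy(pool, 2, initial_centers=[[0.0, 0.0]]))  # → [2, 3]
```

Each pick costs one pass over the pool, so selecting K points from N candidates is O(K·N) distance evaluations; at the 10K-from-ImageNet-900 scale this loop would normally be vectorized on GPU rather than run in pure Python.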

Citation

@inproceedings{jiang2021improving,
title={Improving Contrastive Learning on Imbalanced Data via Open-World Sampling},
author={Jiang, Ziyu and Chen, Tianlong and Chen, Ting and Wang, Zhangyang},
booktitle={Advances in Neural Information Processing Systems 34},
year={2021}
}
Owner: VITA (Visual Informatics Group @ University of Texas at Austin)