[CVPR 2022 Oral] Balanced MSE for Imbalanced Visual Regression https://arxiv.org/abs/2203.16427

Overview

Balanced MSE

Code for the paper:

Balanced MSE for Imbalanced Visual Regression
Jiawei Ren, Mingyuan Zhang, Cunjun Yu, Ziwei Liu

CVPR 2022 (Oral)

News

Live Demo

Check out our live demo in the Hugging Face 🤗 space!

Tutorial

We provide a minimal working example of Balanced MSE using the BMC implementation on a small-scale dataset, Boston Housing dataset.

Open In Colab

The notebook is developed on top of Deep Imbalanced Regression (DIR) Tutorial, we thank the authors for their amazing tutorial!

Quick Preview

A code snippet of the Balanced MSE loss is shown below. We use the BMC implementation for demonstration, BMC does not require any label prior beforehand.

One-dimensional Balanced MSE

def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `targets`.
    Args:
      pred: A float tensor of size [batch, 1].
      target: A float tensor of size [batch, 1].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)   # logit size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))     # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale, 'detach' when noise is learnable 

    return loss

noise_var is a one-dimensional hyper-parameter. noise_var can be optionally optimized in training:

class BMCLoss(_Loss):
    def __init__(self, init_noise_sigma):
        super(BMCLoss, self).__init__()
        self.noise_sigma = torch.nn.Parameter(torch.tensor(init_noise_sigma))

    def forward(self, pred, target):
        noise_var = self.noise_sigma ** 2
        return bmc_loss(pred, target, noise_var)

criterion = BMCLoss(init_noise_sigma)
optimizer.add_param_group({'params': criterion.noise_sigma, 'lr': sigma_lr, 'name': 'noise_sigma'})

Multi-dimensional Balanced MSE

The multi-dimensional implementation is compatible with the 1-D version.

from torch.distributions import MultivariateNormal as MVN

def bmc_loss_md(pred, target, noise_var):
    """Compute the Multidimensional Balanced MSE Loss (BMC) between `pred` and the ground truth `targets`.
    Args:
      pred: A float tensor of size [batch, d].
      target: A float tensor of size [batch, d].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    I = torch.eye(pred.shape[-1])
    logits = MVN(pred.unsqueeze(1), noise_var*I).log_prob(target.unsqueeze(0))  # logit size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))     # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale, 'detach' when noise is learnable 
    
    return loss

noise_var is still a one-dimensional hyper-parameter and can be optionally learned in training.

Run Experiments

Please go into the sub-folder to run experiments.

Citation

@inproceedings{ren2021bmse,
  title={Balanced MSE for Imbalanced Visual Regression},
  author={Ren, Jiawei and Zhang, Mingyuan and Yu, Cunjun and Liu, Ziwei},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2022}
}

Acknowledgment

This work is supported by NTU NAP, MOE AcRF Tier 2 (T2EP20221-0033), the National Research Foundation, Singapore under its AI Singapore Programme, and under the RIE2020 Industry Alignment Fund – Industry Collabo- ration Projects (IAF-ICP) Funding Initiative, as well as cash and in-kind contribution from the industry partner(s).

The code is developed on top of Delving into Deep Imbalanced Regression.

PyTorch implementation of the Deep SLDA method from our CVPRW-2020 paper "Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis"

Lifelong Machine Learning with Deep Streaming Linear Discriminant Analysis This is a PyTorch implementation of the Deep Streaming Linear Discriminant

Tyler Hayes 41 Dec 25, 2022
A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch

A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch The official pytorch implementation of the paper "Towards Faster and Stabilize

Bingchen Liu 455 Jan 08, 2023
GEP (GDB Enhanced Prompt) - a GDB plug-in for GDB command prompt with fzf history search, fish-like autosuggestions, auto-completion with floating window, partial string matching in history, and more!

GEP (GDB Enhanced Prompt) GEP (GDB Enhanced Prompt) is a GDB plug-in which make your GDB command prompt more convenient and flexibility. Why I need th

Alan Li 23 Dec 21, 2022
Model-free Vehicle Tracking and State Estimation in Point Cloud Sequences

Model-free Vehicle Tracking and State Estimation in Point Cloud Sequences 1. Introduction This project is for paper Model-free Vehicle Tracking and St

TuSimple 92 Jan 03, 2023
Linear Variational State Space Filters

Linear Variational State Space Filters To set up the environment, use the provided scripts in the docker/ folder to build and run the codebase inside

0 Dec 13, 2021
Advantage Actor Critic (A2C): jax + flax implementation

Advantage Actor Critic (A2C): jax + flax implementation Current version supports only environments with continious action spaces and was tested on muj

Andrey 3 Jan 23, 2022
Receptive Field Block Net for Accurate and Fast Object Detection, ECCV 2018

Receptive Field Block Net for Accurate and Fast Object Detection By Songtao Liu, Di Huang, Yunhong Wang Updatas (2021/07/23): YOLOX is here!, stronger

Liu Songtao 1.4k Dec 21, 2022
Pyserini is a Python toolkit for reproducible information retrieval research with sparse and dense representations.

Pyserini Pyserini is a Python toolkit for reproducible information retrieval research with sparse and dense representations. Retrieval using sparse re

Castorini 706 Dec 29, 2022
PyTorch implementation of "Optimization Planning for 3D ConvNets"

Optimization-Planning-for-3D-ConvNets Code for the ICML 2021 paper: Optimization Planning for 3D ConvNets. Authors: Zhaofan Qiu, Ting Yao, Chong-Wah N

Zhaofan Qiu 2 Jan 12, 2022
Training, generation, and analysis code for Learning Particle Physics by Example: Location-Aware Generative Adversarial Networks for Physics

Location-Aware Generative Adversarial Networks (LAGAN) for Physics Synthesis This repository contains all the code used in L. de Oliveira (@lukedeo),

Deep Learning for HEP 57 Oct 22, 2022
Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics.

Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics. By Andres Milioto @ University of Bonn. (for the new P

Photogrammetry & Robotics Bonn 314 Dec 30, 2022
Official code implementation for "Personalized Federated Learning using Hypernetworks"

Personalized Federated Learning using Hypernetworks This is an official implementation of Personalized Federated Learning using Hypernetworks paper. [

Aviv Shamsian 121 Dec 25, 2022
My implementation of Fully Convolutional Neural Networks in Keras

Keras-FCN This repository contains my implementation of Fully Convolutional Networks in Keras (Tensorflow backend). Currently, semantic segmentation c

The Duy Nguyen 15 Jan 13, 2020
Training code and evaluation benchmarks for the "Self-Supervised Policy Adaptation during Deployment" paper.

Self-Supervised Policy Adaptation during Deployment PyTorch implementation of PAD and evaluation benchmarks from Self-Supervised Policy Adaptation dur

Nicklas Hansen 101 Nov 01, 2022
Anti-UAV base on PaddleDetection

Paddle-Anti-UAV Anti-UAV base on PaddleDetection Background UAVs are very popular and we can see them in many public spaces, such as parks and playgro

Qingzhong Wang 2 Apr 20, 2022
A parallel framework for population-based multi-agent reinforcement learning.

MALib: A parallel framework for population-based multi-agent reinforcement learning MALib is a parallel framework of population-based learning nested

MARL @ SJTU 348 Jan 08, 2023
Frequency Spectrum Augmentation Consistency for Domain Adaptive Object Detection

Frequency Spectrum Augmentation Consistency for Domain Adaptive Object Detection Main requirements torch = 1.0 torchvision = 0.2.0 Python 3 Environm

15 Apr 04, 2022
Code for CVPR 2018 paper --- Texture Mapping for 3D Reconstruction with RGB-D Sensor

G2LTex This repository contains the implementation of "Texture Mapping for 3D Reconstruction with RGB-D Sensor (CVPR2018)" based on mvs-texturing. Due

Fu Yanping(付燕平) 129 Dec 30, 2022
🤗 Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX.

English | 简体中文 | 繁體中文 State-of-the-art Natural Language Processing for Jax, PyTorch and TensorFlow 🤗 Transformers provides thousands of pretrained mo

Hugging Face 77.2k Jan 02, 2023
Code for Understanding Pooling in Graph Neural Networks

Select, Reduce, Connect This repository contains the code used for the experiments of: "Understanding Pooling in Graph Neural Networks" Setup Install

Daniele Grattarola 37 Dec 13, 2022