Repository for paper "Non-intrusive speech intelligibility prediction from discrete latent representations"


Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations

Official repository for paper "Non-Intrusive Speech Intelligibility Prediction from Discrete Latent Representations".

This public repository is a work in progress! Results here bear no resemblance to results in the paper!

We predict the intelligibility of binaural speech signals by first extracting latent representations from raw audio and then training a lightweight predictor on these representations. This improves performance over predicting from spectral features of the audio, even though the feature extractor is not explicitly trained for this task. In certain cases, a single layer is sufficient for strong correlations between the predictions and the ground-truth scores.
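The idea of a discrete latent representation can be shown with a minimal pure-Python sketch. This is not the repository's implementation. Each continuous per-frame feature vector is mapped to its nearest entry in a codebook (vector quantisation, the "VQ" in VQ-CPC). A hypothetical single-layer predictor then scores each quantised frame. The function names, the sigmoid output, and averaging over frames are illustrative assumptions.

```python
import math

def quantise(frames, codebook):
    """Map each frame vector to the index of its nearest codeword (squared Euclidean distance)."""
    indices = []
    for frame in frames:
        dists = [sum((f - c) ** 2 for f, c in zip(frame, code)) for code in codebook]
        indices.append(min(range(len(codebook)), key=dists.__getitem__))
    return indices

def predict_intelligibility(indices, codebook, weights, bias):
    """Hypothetical single-layer predictor: a linear score per quantised frame,
    squashed to [0, 1] with a sigmoid, then averaged over frames (assumption)."""
    per_frame = []
    for i in indices:
        z = sum(w * x for w, x in zip(weights, codebook[i])) + bias
        per_frame.append(1.0 / (1.0 + math.exp(-z)))
    return sum(per_frame) / len(per_frame), per_frame

# Toy example: a 2-entry codebook and three 2-dimensional frames.
codebook = [[0.0, 0.0], [1.0, 1.0]]
frames = [[0.1, -0.2], [0.9, 1.2], [0.6, 0.7]]
indices = quantise(frames, codebook)  # [0, 1, 1]
score, per_frame = predict_intelligibility(indices, codebook, weights=[1.0, 1.0], bias=0.0)
```

In the real models, the encoder, codebook, and predictor weights are learned. Only the nearest-codeword lookup is the defining step of vector quantisation.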

This repository contains:

  • vqcpc/ - Module for VQCPC model in PyTorch
  • stoi/ - Module for Small and SeqPool predictor models in PyTorch
  • data.py - File containing various PyTorch custom datasets
  • main-vqcpc.py - Script for VQCPC training
  • create-latents.py - Script for generating latent dataset from trained VQCPC
  • plot-latents.py - Script for visualizing extracted latent representations
  • main-stoi.py - Script for STOI predictor training
  • main-test.py - Script for evaluating models
  • compute-correlations.py - Script for computing metrics for many models
  • checkpoints/ - Trained checkpoints of VQCPC and STOI predictor models
  • config/ - Directory containing various configuration files for experiments
  • results/ - Directory containing official results from experiments
  • dataset/ - Directory containing metadata files for the dataset
  • data-generator/ - Directory containing dataset generation scripts (MATLAB)
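The SeqPool predictor in stoi/ is not documented here. Assuming it follows the attention-based sequence pooling introduced with Compact Convolutional Transformers, the pooling step can be sketched in pure Python as follows. A learned projection scores each frame, a softmax over time turns the scores into weights, and the output is the weighted sum of frames. The names `seq_pool`, `w`, and `b` are hypothetical.

```python
import math

def seq_pool(frames, w, b=0.0):
    """Attention-weighted sequence pooling (assumed SeqPool variant).

    Returns one pooled vector for the whole sequence plus the per-frame weights.
    """
    scores = [sum(wi * xi for wi, xi in zip(w, x)) + b for x in frames]
    m = max(scores)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(frames[0])
    pooled = [sum(a * x[d] for a, x in zip(weights, frames)) for d in range(dim)]
    return pooled, weights

pooled, weights = seq_pool([[1.0, 0.0], [0.0, 1.0]], w=[0.0, 0.0])  # equal weights
```

Because pooling collapses the time axis into a single vector, a predictor built this way yields one score per utterance rather than one per frame. This is consistent with plot-latents.py requiring a per-frame predictor for its per-frame plot.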

All models are implemented in PyTorch. The training scripts are implemented using ptpt - a lightweight framework around PyTorch.

Figure: Visualisation of binaural waveform, predicted per-frame STOI, and latent representation.

Usage

VQ-CPC Training

Begin VQ-CPC training using the configuration defined in config.toml:

python main-vqcpc.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)

Latent Dataset Generation

Begin latent dataset generation using the pre-trained VQCPC checkpoint model-checkpoint.pt on the audio dataset wav-dataset, writing the latents to latent-dataset, with the configuration defined in config.toml:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml

As above, but distributed across n processes, where this process has rank r:

python create-latents.py model-checkpoint.pt wav-dataset latent-dataset --cfg-path config.toml --array-size n --array-rank r
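How create-latents.py splits the dataset internally is not described here. One common way to implement an `--array-size`/`--array-rank` scheme is a strided split over a sorted file list, shown below as a hypothetical sketch. Every process sorts the same list, so all ranks agree on the order without coordinating, and each file is handled by exactly one rank.

```python
def shard(items, array_size, array_rank):
    """Hypothetical strided split: rank r takes items r, r + n, r + 2n, ... of the sorted list."""
    if not 0 <= array_rank < array_size:
        raise ValueError("array_rank must be in [0, array_size)")
    return sorted(items)[array_rank::array_size]

files = [f"{i:03d}.wav" for i in range(10)]
shards = [shard(files, array_size=3, array_rank=r) for r in range(3)]
# shards[0] == ['000.wav', '003.wav', '006.wav', '009.wav']
```

With n = 4, running the command above with `-r 0`, `-r 1`, `-r 2`, and `-r 3` (for example, as a job array on a cluster) would cover the whole dataset.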

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--detect-anomaly    # detect autograd anomalies and terminate if encountered
-n                  # alias for `--array-size`
-r                  # alias for `--array-rank`

Latent Plotting

Launch the interactive VQCPC latent visualisation script with the pre-trained model model-checkpoint.pt on dataset wav-dataset, using the configuration defined in config.toml:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml

If you also have a pre-trained per-frame STOI score predictor (not a SeqPool predictor), you can pass its checkpoint stoi-checkpoint.pt and its configuration stoi-config.toml to plot per-frame scores alongside the waveform and latent features:

python plot-latents.py model-checkpoint.pt wav-dataset --cfg-path config.toml --stoi stoi-checkpoint.pt --stoi-cfg stoi-config.toml

Other useful arguments:

--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--cmap              # define matplotlib colourmap
--style             # define matplotlib style

STOI Predictor Training

Begin intelligibility score predictor training using the configuration defined in config.toml:

python main-stoi.py --cfg-path config.toml

Other useful arguments:

--resume            # resume from specified checkpoint
--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--seed              # random seed (default: 12345)

Predictor Evaluation

Begin evaluation of a pre-trained STOI score predictor using checkpoint stoi-checkpoint.pt on dataset dataset-root, with the configuration defined in stoi-config.toml:

python main-test.py stoi-checkpoint.pt dataset-root --cfg-path stoi-config.toml

Other useful arguments:

--no-save           # do not save training progress (useful for debugging)
--no-cuda           # do not try to access CUDA device (very slow)
--no-amp            # disable automatic mixed precision (if you encounter NaN)
--no-tqdm           # disable progress bars
--nb-workers        # number of workers for data loading (default: 8)
--detect-anomaly    # detect autograd anomalies and terminate if encountered
--batch-size        # control dataloader batch size
--seed              # random seed (default: 12345)

Overall Evaluation

Compare the results files produced by main-test.py for several models against the dataset ground truth:

python compute-correlations.py ground-truth.csv pred-1.csv ... pred-n.csv --names pred-1 ... pred-n

Configuration

Example configurations for all experiments can be found in config/.

We use toml files to define configurations. Each one consists of three sections:

  • [trainer]: configuration options for ptpt.TrainerConfig.
  • [data]: configuration options for the dataset.
  • [vqcpc] or [stoi]: configuration options for the VQCPC and predictor models respectively.

Checkpoints

Pretrained checkpoints for all models can be found in checkpoints/.

Citation

TODO: add citation once paper published / arXiv-ed :)

Owner
Alex McKinney
Final-year student at Durham University. Interested in generative models and unsupervised representation learning.