TyXe: Pyro-based BNNs for Pytorch users

Related tags

Deep LearningTyXe
Overview

TyXe: Pyro-based BNNs for Pytorch users

TyXe aims to simplify the process of turning Pytorch neural networks into Bayesian neural networks by leveraging the model definition and inference capabilities of Pyro. Our core design principle is to cleanly separate the construction of neural architecture, prior, inference distribution and likelihood, enabling a flexible workflow where each component can be exchanged independently. Defining a BNN in TyXe takes as little as 5 lines of code:

net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
likelihood = tyxe.likelihoods.HomoskedasticGaussian(scale=0.1)
inference = tyxe.guides.AutoNormal
bnn = tyxe.VariationalBNN(net, prior, likelihood, inference)

In the following, we assume that you (roughly) know what a BNN is mathematically.

Motivating example

Standard neural networks give us a single function that fits the data, but many different ones are typically plausible. With only a single fit, we don't know for what inputs the model is 'certain' (because there is training data nearby) and where it is uncertain.

ML Samples
Maximum likelihood fit Posterior samples

Implementing the former can be achieved easily in a few lines of Pytorch code, but training a BNN that gives a distribution over different fits is typically more complicated and is specifically what we aim to simplify.

Training

Constructing a BNN object has been shown in the example above. For fitting the posterior approximation, we provide a high-level .fit method similar to libraries such as scikit-learn or keras:

optim = pyro.optim.Adam({"lr": 1e-3})
bnn.fit(data_loader, optim, num_epochs)

Prediction & evaluation

Further we provide .predict and .evaluation methods, which make predictions based on multiple samples from the approximate posterior, average them based on the observation model, and return log likelihoods and an error measure:

predictions = bnn.predict(x_test, num_samples)
error, log_likelihood = bnn.evaluate(x_test, y_test, num_samples)

Local reparameterization

We implement local reparameterization for factorized Gaussians as a poutine, which reduces gradient noise during training. This means it can be enabled or disabled at both during training and prediction with a context manager:

with tyxe.poutine.local_reparameterization():
    bnn.fit(data_loader, optim, num_epochs)
    bnn.predict(x_test, num_predictions)

At the moment, this poutine does not work with the AutoNormal and AutoDiagonalNormal guides in pyro, since those draw the weights from a Delta distribution, so you need to use tyxe.guides.ParameterwiseDiagonalNormal as your guide.

MCMC

We provide a unified interface to pyro's MCMC implementations, simply use the tyxe.MCMC_BNN class instead and provide a kernel instead of the guide:

kernel = pyro.infer.mcmcm.NUTS
bnn = tyxe.MCMC_BNN(net, prior, likelihood, kernel)

Any parameters that pyro's MCMC class accepts can be passed through the keyword arguments of the .fit method.

Continual learning

Due to our design that cleanly separates the prior from guide, architecture and likelihood, it is easy to update it in a continual setting. For example, you can construct a tyxe.priors.DictPrior by extracting the distributions over all weights and biases from a ParameterwiseDiagonalNormal instance using the get_detached_distributions method and pass it to bnn.update_prior to implement Variational Continual Learning in a few lines of code. See examples/vcl.py for a basic example on split-MNIST and split-CIFAR.

Network architectures

We don't implement any layer classes. You construct your network in Pytorch and then turn it into a BNN, which makes it easy to apply the same prior and inference strategies to different neural networks.

Inference

For inference, we mainly provide an equivalent to pyro's AutoDiagonalNormal that is compatible with local reparameterization in tyxe.guides. This module also contains a few helper functions for initialization of Gaussian mean parameters, e.g. to the values of a pre-trained network. It should be possible to use any of pyro's autoguides for variational inference. See examples/resnet.py for a few options as well as initializing to pre-trained weights.

Priors

The priors can be found in tyxe.priors. We currently only support placing priors on the parameters. Through the expose and hide arguments in the init method you can specify layers, types of layers and specific parameters over which you want to place a prior. This helps, for example in learning the parameters of BatchNorm layers deterministically.

Likelihoods

tyxe.observation_models contains classes that wrap the most common torch.distributions for specifying noise models of data to

Installation

We recommend installing TyXe using conda with the provided environment.yml, which also installs all the dependencies for the examples except for Pytorch3d, which needs to be added manually. The environment assumes that you are using CUDA11.0, if this is not the case, simply change the cudatoolkit and dgl-cuda versions before running:

conda env create -f environment.yml
conda activate tyxe
pip install -e .

Citation

If you use TyXe, please consider citing:

@article{ritter2021tyxe,
  author    = {Hippolyt Ritter and
               Theofanis Karaletsos
               },
  title     = {TyXe: Pyro-based Bayesian neural nets for Pytorch},
  journal   = {International Conference on Probabilistic Programming (ProbProg)},
  volume    = {},
  pages     = {},
  year      = {2020},
  url       = {https://arxiv.org/abs/2110.00276}
}
Collection of generative models, e.g. GAN, VAE in Pytorch and Tensorflow.

Generative Models Collection of generative models, e.g. GAN, VAE in Pytorch and Tensorflow. Also present here are RBM and Helmholtz Machine. Note: Gen

Agustinus Kristiadi 7k Jan 02, 2023
YOLTv4 builds upon YOLT and SIMRDWN, and updates these frameworks to use the most performant version of YOLO, YOLOv4

YOLTv4 builds upon YOLT and SIMRDWN, and updates these frameworks to use the most performant version of YOLO, YOLOv4. YOLTv4 is designed to detect objects in aerial or satellite imagery in arbitraril

Adam Van Etten 161 Jan 06, 2023
Location-Sensitive Visual Recognition with Cross-IOU Loss

The trained models are temporarily unavailable, but you can train the code using reasonable computational resource. Location-Sensitive Visual Recognit

Kaiwen Duan 146 Dec 25, 2022
Optimizing DR with hard negatives and achieving SOTA first-stage retrieval performance on TREC DL Track (SIGIR 2021 Full Paper).

Optimizing Dense Retrieval Model Training with Hard Negatives Jingtao Zhan, Jiaxin Mao, Yiqun Liu, Jiafeng Guo, Min Zhang, Shaoping Ma This repo provi

Jingtao Zhan 99 Dec 27, 2022
Supervised & unsupervised machine-learning techniques are applied to the database of weighted P4s which admit Calabi-Yau hypersurfaces.

Weighted Projective Spaces ML Description: The database of 5-vectors describing 4d weighted projective spaces which admit Calabi-Yau hypersurfaces are

Ed Hirst 3 Sep 08, 2022
HINet: Half Instance Normalization Network for Image Restoration

HINet: Half Instance Normalization Network for Image Restoration Liangyu Chen, Xin Lu, Jie Zhang, Xiaojie Chu, Chengpeng Chen Paper: https://arxiv.org

303 Dec 31, 2022
Using some basic methods to show linkages and transformations of robotic arms

roboticArmVisualizer Python GUI application to create custom linkages and adjust joint angles. In the future, I plan to add 2d inverse kinematics solv

Sandesh Banskota 1 Nov 19, 2021
Related resources for our EMNLP 2021 paper

Plan-then-Generate: Controlled Data-to-Text Generation via Planning Authors: Yixuan Su, David Vandyke, Sihui Wang, Yimai Fang, and Nigel Collier Code

Yixuan Su 61 Jan 03, 2023
CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

[ICCV2021] TransReID: Transformer-based Object Re-Identification [pdf] The official repository for TransReID: Transformer-based Object Re-Identificati

DamoCV 569 Dec 30, 2022
Code for the paper: "On the Bottleneck of Graph Neural Networks and Its Practical Implications"

On the Bottleneck of Graph Neural Networks and its Practical Implications This is the official implementation of the paper: On the Bottleneck of Graph

75 Dec 22, 2022
Self-Supervised Document-to-Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference

Self-Supervised Document Similarity Ranking (SDR) via Contextualized Language Models and Hierarchical Inference This repo is the implementation for SD

Microsoft 36 Nov 28, 2022
A deep learning network built with TensorFlow and Keras to classify gender and estimate age.

Convolutional Neural Network (CNN). This repository contains a source code of a deep learning network built with TensorFlow and Keras to classify gend

Pawel Dziemiach 1 Dec 18, 2021
PyTorch code accompanying our paper on Maximum Entropy Generators for Energy-Based Models

Maximum Entropy Generators for Energy-Based Models All experiments have tensorboard visualizations for samples / density / train curves etc. To run th

Rithesh Kumar 135 Oct 27, 2022
[3DV 2021] Channel-Wise Attention-Based Network for Self-Supervised Monocular Depth Estimation

Channel-Wise Attention-Based Network for Self-Supervised Monocular Depth Estimation This is the official implementation for the method described in Ch

Jiaxing Yan 27 Dec 30, 2022
Face recognition project by matching the features extracted using SIFT.

MV_FaceDetectionWithSIFT Face recognition project by matching the features extracted using SIFT. By : Aria Radmehr Professor : Ali Amiri Dependencies

Aria Radmehr 4 May 31, 2022
GrailQA: Strongly Generalizable Question Answering

GrailQA is a new large-scale, high-quality KBQA dataset with 64,331 questions annotated with both answers and corresponding logical forms in different syntax (i.e., SPARQL, S-expression, etc.). It ca

OSU DKI Lab 76 Dec 21, 2022
Code & Data for the Paper "Time Masking for Temporal Language Models", WSDM 2022

Time Masking for Temporal Language Models This repository provides a reference implementation of the paper: Time Masking for Temporal Language Models

Guy Rosin 12 Jan 06, 2023
Make your own game in a font!

Project structure. Included is a suite of tools to create font games. Tutorial: For a quick tutorial about how to make your own game go here For devel

Michael Mulet 125 Dec 04, 2022
Official Implement of CVPR 2021 paper “Cross-Modal Collaborative Representation Learning and a Large-Scale RGBT Benchmark for Crowd Counting”

RGBT Crowd Counting Lingbo Liu, Jiaqi Chen, Hefeng Wu, Guanbin Li, Chenglong Li, Liang Lin. "Cross-Modal Collaborative Representation Learning and a L

37 Dec 08, 2022
Code release for NeX: Real-time View Synthesis with Neural Basis Expansion

NeX: Real-time View Synthesis with Neural Basis Expansion Project Page | Video | Paper | COLAB | Shiny Dataset We present NeX, a new approach to novel

538 Jan 09, 2023