Unofficial JAX implementations of Deep Learning models

Overview

JAX Models


Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open-source JAX/Flax implementations of research papers that were originally published without code, or whose code was written in frameworks other than JAX. The goal of this project is to build a collection of the models, layers, activations and other utilities most commonly used in research. All papers and derived or translated code are cited in either the README or the docstrings. If you think a citation is missing, please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT : Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2022)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)

Available layers for out-of-the-box integration:
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2016)
  2. Squeeze-and-Excitation Layer (Jie Hu et al., 2019)
  3. Depthwise Convolution (François Chollet, 2017)
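The three layers listed above are small enough to sketch as plain JAX functions. The block below is a minimal illustration under our own assumptions (NHWC inputs, weights and PRNG key passed in explicitly, SE biases omitted for brevity); the names `drop_path`, `squeeze_excite` and `depthwise_conv2d` are hypothetical and may differ from the package's own layer implementations and signatures.

```python
import jax
import jax.numpy as jnp


def drop_path(x, rate, key, deterministic=False):
    """Stochastic depth: zero out a whole residual branch per example.

    x has shape (batch, ...); rate is the probability of dropping the path.
    """
    if deterministic or rate == 0.0:
        return x
    keep_prob = 1.0 - rate
    # One Bernoulli draw per example, broadcast over all remaining axes.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep = jax.random.bernoulli(key, p=keep_prob, shape=mask_shape)
    # Rescale surviving paths so the expected output equals x.
    return jnp.where(keep, x / keep_prob, 0.0)


def squeeze_excite(x, w_reduce, w_expand):
    """Squeeze-and-Excitation on NHWC input (biases omitted).

    w_reduce: (C, C // r) and w_expand: (C // r, C) for a reduction ratio r.
    """
    s = jnp.mean(x, axis=(1, 2))       # squeeze: global average pool -> (B, C)
    s = jax.nn.relu(s @ w_reduce)      # bottleneck
    s = jax.nn.sigmoid(s @ w_expand)   # per-channel gates in (0, 1)
    return x * s[:, None, None, :]     # excite: rescale each channel


def depthwise_conv2d(x, kernel):
    """Depthwise convolution on NHWC input: one spatial filter per channel.

    kernel: (kh, kw, 1, C). feature_group_count=C keeps channels independent.
    """
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=x.shape[-1],
    )
```

In a Flax model these would normally live inside modules so that parameters and PRNG streams are handled by `init`/`apply`; keeping them as pure functions here only makes the arithmetic visible. Setting `feature_group_count` equal to the channel count is what turns an ordinary convolution into a depthwise one.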

Getting Started

Prerequisites

The prerequisites can be installed from the requirements.txt file in the main directory using:

pip install -r requirements.txt

Using a virtual environment is highly recommended to avoid version incompatibilities.

Installation

This project targets Python 3 and the latest JAX/Flax versions, and can be installed directly via pip:

pip install jax-models

To use the latest version, you can also clone the repository directly:

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models.models.model_registry import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models.models.model_registry import load_model
load_model('mpvit-base', attach_head=True, num_classes=1000, dropout=0.1)

Contributing

Please raise an issue if any implementation gives incorrect results, crashes unexpectedly during training/inference or if any citation is missing.

You can also contribute to jax_models by supporting me with compute resources, or by using your own compute to provide pretrained weights.

If you wish to donate to this initiative, please send me an email (see Contact).

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out with any issues or requests related to these implementations.

Darshan Deshpande - Email | Twitter | LinkedIn

Comments
  • Missing Axis Swap in ExtractPatches and MergePatches

    In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the reshapes, resulting in the extracted patches becoming horizontal stripes. For example, if we follow the code in ExtractPatches:

    >>> import jax.numpy as jnp
    >>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
    >>> inputs[0, :, :, 0]
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11],
                 [12, 13, 14, 15]], dtype=int32)
    >>> patch_size = 2
    >>> batch, height, width, channels = inputs.shape
    >>> height, width = height // patch_size, width // patch_size
    >>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    >>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    >>> x[0, 0, :]
    DeviceArray([0, 1, 2, 3], dtype=int32)

    We see that the first patch extracted is not the patch containing [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3]. To fix this problem, we should add an axis swap. For ExtractPatches, this should be:

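    import jax.numpy as jnp

    def extract_patches(inputs, patch_size):
        batch, height, width, channels = inputs.shape
        height, width = height // patch_size, width // patch_size
        x = jnp.reshape(
            inputs, (batch, height, patch_size, width, patch_size, channels)
        )
        # Move the patch-column axis in front of the intra-patch row axis so
        # each patch's pixels are contiguous before flattening.
        x = jnp.swapaxes(x, 2, 3)
        return jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))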

    For MergePatches, this should be:

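    import jax.numpy as jnp

    def merge_patches(inputs, patch_size):
        batch, length, _ = inputs.shape
        height = width = int(length**0.5)  # assumes a square grid of patches
        x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
        # Undo the swap from ExtractPatches before restoring the spatial layout.
        x = jnp.swapaxes(x, 2, 3)
        return jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))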
    Label: bug
    Opened by young-geng
  • fix convnext to make it work with jax.jit

    Hey, first of all, thanks for the nice codebase. When doing inference using the convnext model, I noticed the following issue:

    Calling x.item() calls float(x), which breaks the jit tracer. Removing the unnecessary conversion in the list comprehension makes jax.jit work. Without jax.jit, the model is very slow for me, running at only ~30% GPU utilization (RTX 3090).

    This issue may apply to other models as well; perhaps it would be a good idea to add a test that applies jax.jit to each model?

    Opened by maxidl
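One way to gain confidence in the axis-swap fix from the first comment is to compare the vectorised version against a slow reference that slices each patch out explicitly, and to check that merging undoes extraction. The sketch below uses standalone functions with hypothetical names (not the ExtractPatches/MergePatches module API) and, unlike the MergePatches snippet in the comment, takes the patch grid shape explicitly so that non-square images also round-trip.

```python
import jax.numpy as jnp


def extract_patches(inputs, patch_size):
    """(B, H, W, C) -> (B, num_patches, patch_size**2 * C), row-major patch order."""
    b, h, w, c = inputs.shape
    gh, gw = h // patch_size, w // patch_size
    x = jnp.reshape(inputs, (b, gh, patch_size, gw, patch_size, c))
    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))  # equivalent to swapaxes(x, 2, 3)
    return jnp.reshape(x, (b, gh * gw, patch_size * patch_size * c))


def merge_patches(patches, patch_size, grid_hw):
    """Inverse of extract_patches; grid_hw = (patch rows, patch columns)."""
    b = patches.shape[0]
    gh, gw = grid_hw
    x = jnp.reshape(patches, (b, gh, gw, patch_size, patch_size, -1))
    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))  # back to (B, gh, p, gw, p, C)
    return jnp.reshape(x, (b, gh * patch_size, gw * patch_size, -1))


def extract_patches_reference(inputs, patch_size):
    """Slow but obviously correct: slice every patch out explicitly."""
    b, h, w, _ = inputs.shape
    patches = [
        inputs[:, i:i + patch_size, j:j + patch_size, :].reshape(b, -1)
        for i in range(0, h, patch_size)
        for j in range(0, w, patch_size)
    ]
    return jnp.stack(patches, axis=1)


demo = jnp.arange(16).reshape(1, 4, 4, 1)
print(extract_patches(demo, 2)[0, 0])  # first patch holds 0, 1, 4, 5
```

On the 4x4 example from the comment, the first patch is now the 2x2 square [0, 1, 4, 5] rather than the stripe [0, 1, 2, 3]. Passing the grid shape explicitly avoids relying on `int(length**0.5)`, which only works for square grids.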
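The failure mode described in the ConvNeXt comment can be reproduced in a few lines. The sketch below is a hypothetical reduction, not the actual ConvNeXt code: calling `.item()` on a value that is an abstract tracer under `jax.jit` raises `jax.errors.ConcretizationTypeError`, while keeping the value as a JAX array traces normally. The `jits_cleanly` helper is one possible shape for the per-model jit smoke test the comment suggests; its name and structure are our own.

```python
import jax
import jax.numpy as jnp


def scale_with_item(x, gamma):
    # .item() needs a concrete Python scalar. Under jax.jit, gamma is an
    # abstract tracer, so tracing this function fails.
    return x * gamma.item()


def scale_traceable(x, gamma):
    # Keeping gamma as a JAX array lets jit trace and compile the multiply.
    return x * gamma


def jits_cleanly(fn, *args):
    """Hypothetical smoke test: True if fn traces and runs under jax.jit."""
    try:
        jax.jit(fn)(*args).block_until_ready()
        return True
    except jax.errors.ConcretizationTypeError:
        return False


x, gamma = jnp.ones(3), jnp.array(0.5)
print(scale_with_item(x, gamma))                # works eagerly: gamma is concrete
print(jits_cleanly(scale_traceable, x, gamma))  # True
print(jits_cleanly(scale_with_item, x, gamma))  # False
```

Values that genuinely need to be Python floats, such as a list of per-block rates, can instead be computed with NumPy when the model is constructed, outside any traced function.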
Releases: v0.5-van