PyTorch-LIT


PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices.

With the rapid growth of deep learning research, models keep growing in parameter count and complexity, which makes them difficult to run on currently available end devices. For example, GPT-J, with 6B parameters, needs 24 GB of RAM just to be loaded in full-precision mode, which is out of reach for most systems; even a powerful GPU like the RTX 2060, with 6 GB of memory, cannot hold GPT-J in half-precision mode (about 12 GB), making direct inference impossible.
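These figures follow from multiplying the parameter count by the size of one element. Here is a small sketch that uses the rounded 6B figure from the text and decimal gigabytes (10^9 bytes). It counts the weights only and ignores activations, buffers, and framework overhead:

```python
# Bytes needed to store one parameter in each precision.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def weight_memory_gb(num_params, dtype):
    """Memory needed for the weights alone, in decimal gigabytes (1 GB = 10**9 bytes)."""
    return num_params * BYTES_PER_ELEMENT[dtype] / 10**9


gptj_params = 6_000_000_000  # "6B", rounded as in the text
full = weight_memory_gb(gptj_params, "float32")
half = weight_memory_gb(gptj_params, "float16")
gpu_gb = 6  # RTX 2060 memory, as stated in the text

print(f"float32: {full:.0f} GB, float16: {half:.0f} GB, fits on a {gpu_gb} GB GPU: {half <= gpu_gb}")
# prints: float32: 24 GB, float16: 12 GB, fits on a 6 GB GPU: False
```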

To address this issue when training large models, libraries such as DeepSpeed use offload techniques (e.g., ZeRO) that handle the parameters by dividing the weights between devices, which makes training possible. For inference, however, no comparable library or framework is directly available.

PyTorch-LIT enables inference of large models by loading weights on demand from a specified secondary storage, which can be disk, CPU memory, or GPU memory. This makes it possible to run models that do not even fit in the system's main memory, at the cost of slower execution.

Quick Start

  1. Install the library:
pip install pytorch-lit
  2. Save the model's weights in a format the toolkit can use:
from pytorch_lit.export import prepare_params

weights = {}  # your model's parameters (state_dict)
# change the directory where the weights are saved and specify the data type
prepare_params(weights, ".models/my-model", dtype="float32")
  3. After preparing the weights, run inference with your model:
from pytorch_lit import LitModule

# pass your model's construction as a closure,
# and specify the weights path and the inference device
# (MyModel is your model class; args and kwargs are its usual forward inputs)
model = LitModule.from_params(".models/my-model",
                              lambda: MyModel(),
                              device="cuda")
result = model(*args, **kwargs)
  4. Have fun running inference of a large model on a lower-memory device :)

Examples

The repository's examples directory currently contains two GPT-J examples: one for text generation and one for extracting hidden states as feature representations.

Development

This is a work in progress that will require further development before it can be considered a stable inference toolkit. Here is a list of potential future developments:

  • Caching and batch-loading as many weights as memory allows, replacing them in parallel with upcoming ones (following the order of the execution graph)
  • A C++ extension for PyTorch JIT, so the solution applies to most production end devices
  • Functions that make it easier to export large models to ONNX or trace them with JIT
  • A better and faster storage format than numpy memmap
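The first roadmap item could work roughly as follows. This is a hypothetical sketch rather than planned PyTorch-LIT code: `PrefetchingWeightCache` and `load_fn` are invented names, the cache is sized in entries rather than bytes, and the execution order is assumed to be known in advance (for example, from a traced run). While one module's weights are in use, a background thread loads the next ones:

```python
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor


class PrefetchingWeightCache:
    """Hypothetical sketch: keep at most `capacity` modules' weights resident and
    load the modules that run next in a background thread."""

    def __init__(self, execution_order, load_fn, capacity=2):
        self.order = list(execution_order)  # module names in the order they execute
        self.load_fn = load_fn              # name -> weights, e.g. read from a memmap
        self.capacity = max(1, capacity)
        self.entries = OrderedDict()        # name -> Future holding that module's weights
        self.pool = ThreadPoolExecutor(max_workers=1)

    def _schedule(self, name):
        if name not in self.entries:
            self.entries[name] = self.pool.submit(self.load_fn, name)
        self.entries.move_to_end(name)
        while len(self.entries) > self.capacity:
            self.entries.popitem(last=False)  # evict the least recently scheduled entry

    def get(self, name):
        """Return `name`'s weights, prefetching the next modules in execution order."""
        self._schedule(name)
        position = self.order.index(name)
        # At most capacity - 1 upcoming modules, so `name` itself is never evicted here.
        for upcoming in self.order[position + 1 : position + self.capacity]:
            self._schedule(upcoming)
        return self.entries[name].result()

    def close(self):
        self.pool.shutdown(wait=True)


# Usage with a toy loader that records which modules were loaded.
loaded = []


def load(name):
    loaded.append(name)
    return f"weights of {name}"


order = ["embed", "block0", "head"]
cache = PrefetchingWeightCache(order, load, capacity=2)
results = [cache.get(name) for name in order]  # "block0" starts loading while "embed" is used
cache.close()
```

Sizing the cache by entries keeps the sketch short; a real cache would track bytes so that it holds as many weights as memory allows.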

Contributions are welcome; to discuss your idea further, open an issue with the discussion tag. Finally, you can submit a pull request to merge your fork.

How does it work?

This implementation was made possible primarily by two ideas:

  • The first issue was that PyTorch initializes a model's parameters when constructing it, so construction fails when the model cannot fit into memory. To address this, we temporarily hijack the __new__ method of PyTorch's Parameter class during model construction, which lets us replace each parameter's tensor with a view into a shared global tensor immediately after it is created. As a result, all parameters use the same large shared tensor as their primary storage, so the model can be built and run on test inputs to follow and trace its execution graph.
  • The second issue was the large size of the model parameters. In the preparation step, we build a numpy memmap (np.memmap) and save metadata recording the location of each key in the memmap, so parameters can be read from the memmap as needed. We then use PyTorch hooks (forward and pre_forward) to load a module's parameters before it executes and unload them afterwards.
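The first idea can be illustrated with a minimal sketch. It mirrors the description above rather than reproducing PyTorch-LIT's code: `shared_parameter_storage` and `max_numel` are hypothetical names, and the sketch assumes float32 parameters whose element counts never exceed the shared buffer. Each layer still briefly allocates its own tensor before the swap, and that tensor is then dropped:

```python
import contextlib

import torch
import torch.nn as nn


@contextlib.contextmanager
def shared_parameter_storage(max_numel):
    """Temporarily turn every new nn.Parameter into a view of one shared tensor.

    Hypothetical helper: assumes float32 parameters with at most `max_numel` elements each.
    """
    shared = torch.zeros(max_numel)
    original_new = nn.Parameter.__dict__["__new__"]  # the staticmethod defined on Parameter

    def patched_new(cls, data=None, requires_grad=True):
        if data is None:
            data = torch.empty(0)
        # View the start of the shared buffer with the requested shape instead of
        # keeping `data`; raises a RuntimeError if data.numel() > max_numel.
        view = shared[: data.numel()].view(data.shape)
        return original_new.__func__(cls, view, requires_grad)

    nn.Parameter.__new__ = staticmethod(patched_new)
    try:
        yield shared
    finally:
        nn.Parameter.__new__ = original_new  # always restore the real constructor


# Usage: the layer's weight (32x64) and bias both alias the start of `shared`.
with shared_parameter_storage(max_numel=4096) as shared:
    layer = nn.Linear(64, 32)
```

Patching inside a context manager keeps the hijack scoped to model construction, so parameters created afterwards get their own storage again.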

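The second idea can be sketched end to end under stated assumptions. `export_memmap` and `attach_lazy_loading` are hypothetical helpers rather than PyTorch-LIT's API, the `index.json` metadata layout is invented for the example (only the `model.bin` file name comes from the toolkit's export code), and all parameters are assumed to be float32 tensors on the CPU. The weights are packed into one flat `np.memmap`. A forward pre-hook copies a module's parameters out of it just before the module runs, and a forward hook replaces them with empty tensors afterwards:

```python
import json
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


def export_memmap(state_dict, save_dir, dtype="float32"):
    """Pack every tensor into one flat memmap file plus a JSON index of offsets and shapes."""
    os.makedirs(save_dir, exist_ok=True)
    total = sum(t.numel() for t in state_dict.values())
    flat = np.memmap(os.path.join(save_dir, "model.bin"), dtype=dtype, mode="w+", shape=(total,))
    index, offset = {}, 0
    for key, tensor in state_dict.items():
        n = tensor.numel()
        flat[offset : offset + n] = tensor.detach().cpu().numpy().ravel()
        index[key] = {"offset": offset, "shape": list(tensor.shape)}
        offset += n
    flat.flush()
    with open(os.path.join(save_dir, "index.json"), "w") as f:
        json.dump({"dtype": dtype, "total": total, "params": index}, f)


def attach_lazy_loading(model, save_dir):
    """Load each module's own parameters right before its forward and drop them afterwards."""
    with open(os.path.join(save_dir, "index.json")) as f:
        meta = json.load(f)
    flat = np.memmap(os.path.join(save_dir, "model.bin"), dtype=meta["dtype"],
                     mode="r", shape=(meta["total"],))

    def load(module, prefix):
        for name, param in module.named_parameters(recurse=False):
            entry = meta["params"][f"{prefix}.{name}" if prefix else name]
            n = int(np.prod(entry["shape"]))
            chunk = np.array(flat[entry["offset"] : entry["offset"] + n])  # copy out of the memmap
            param.data = torch.from_numpy(chunk).reshape(entry["shape"])

    def unload(module, args, output):
        for _, param in module.named_parameters(recurse=False):
            param.data = torch.empty(0)  # release the weights until the next call

    for prefix, module in model.named_modules():
        own = list(module.named_parameters(recurse=False))
        if not own:
            continue
        module.register_forward_pre_hook(lambda mod, args, prefix=prefix: load(mod, prefix))
        module.register_forward_hook(unload)
        for _, param in own:
            param.data = torch.empty(0)  # start unloaded; nothing stays resident


# Usage: a tiny model whose weights live on disk between calls.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
x = torch.randn(5, 4)
with torch.no_grad():
    expected = model(x)

save_dir = tempfile.mkdtemp()
export_memmap(model.state_dict(), save_dir)
attach_lazy_loading(model, save_dir)
with torch.no_grad():
    out = model(x)  # each Linear's weights are read from model.bin on demand
```

Every forward pass rereads the weights from the memmap, which is the time-for-memory trade-off described earlier.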
Citation

Please cite PyTorch-LIT if it helps your research. You can use the following BibTeX entry:

@misc{pytorch_lit,
	title = {PyTorch-LIT},
	author = {Rezaei, Amin},
	howpublished = {\url{github.com/AminRezaei0x443/PyTorch-LIT}},
	year = {2021}
}

Comments
  • RuntimeError: OrderedDict mutated during iteration

    Hi, there is a new problem. When the model runs its forward pass, it raises "RuntimeError: OrderedDict mutated during iteration". Details below:

    Traceback (most recent call last):
      File "nlp/rct-FPM-rhino/big_model/predict.py", line 24, in <module>
        result = model(**tokens)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 34, in __call__
        return self.forward(*args, **kwargs)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/inference.py", line 31, in forward
        return self.module(*args, **kwargs)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1057, in _call_impl
        for hook in itertools.chain(
    RuntimeError: OrderedDict mutated during iteration

    Environment:

    GPU: NVIDIA GeForce 3090, CUDA version 11.4

    pip list: certifi 2021.10.8, charset-normalizer 2.0.8, click 8.0.3, filelock 3.4.0, huggingface-hub 0.2.0, idna 3.3, joblib 1.1.0, numpy 1.21.4, packaging 21.3, Pillow 8.4.0, pip 21.2.4, pyparsing 3.0.6, pytorch-lit 0.1.7, PyYAML 6.0, regex 2021.11.10, requests 2.26.0, sacremoses 0.0.46, setuptools 58.0.4, six 1.16.0, tokenizer 3.3.2, tokenizers 0.10.3, torch 1.9.1+cu111, torchaudio 0.8.1, torchvision 0.9.1+cu111, tqdm 4.62.3, transformers 4.12.5, typing_extensions 4.0.1, urllib3 1.26.7

    I think this problem is caused by the PyTorch hooks (forward and pre_forward) that load and unload a module's parameters before and after execution: when the parameters are loaded and unloaded, the OrderedDict is mutated.

    opened by changleilei
  • TypeError: <lambda>() missing 1 required positional argument: 'k'

    Hello, when I use pytorch-lit to prepare a model, I get the TypeError in the title. Details below:

      File "nlp/rct-FPM-rhino/big_model/prepare_model.py", line 16, in prepare_model
        prepare_params(model, args.save_path, dtype='float32')
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 19, in prepare_params
        _params_to_memmap(parameters, path.join(save_dir, "model.bin"),
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 52, in _params_to_memmap
        param = get_param(k)
      File "miniconda3/envs/rhino/lib/python3.8/site-packages/pytorch_lit/export.py", line 50, in <lambda>
        get_param = lambda key: params"get"
    TypeError: <lambda>() missing 1 required positional argument: 'k'

    Package list: certifi 2021.10.8, numpy 1.21.4, pip 21.2.4, pytorch-lit 0.1.6, setuptools 58.0.4, torch 1.10.0, tqdm 4.62.3, typing_extensions 4.0.1, wheel 0.37.0

    model: gpt-j-6B

    Any suggestions? Thanks.

    opened by changleilei
  • gpt-j generation speed very low

    GPT-J output is very slow: generating 200 tokens takes about 20 minutes, and generating 2048 tokens takes more than an hour, which significantly limits any experimentation with the model.

    I checked GPU utilization during inference: it is only about 1 to 4 percent, and GPU memory usage stays below 4 GB. My system has 8 GB of GPU memory; if the full GPU were utilized, inference might be significantly faster.

    Are there simple hacks to speed up inference?

    opened by usama-ahmedkhan
  • Weights file format is changed, function partial_loader fails

    Hi, thanks for your effort in making it easy to load and run inference on large models. I tried your code on a GPT-J model with a different file format: its weights are split across several .pt files rather than the single .bin file that your partial_loader() function expects. Does the code work with multiple weight files, and if not, how can I change it?

    opened by usama-ahmedkhan
Releases: 0.1.7

Owner: Amin Rezaei, Computer Science BSc, Neural Networks Enthusiast