Implementation of Feedback Transformer in Pytorch

Overview

Feedback Transformer - Pytorch

Simple implementation of Feedback Transformer in Pytorch. It improves on Transformer-XL by giving each token access to the representations of all previous layers through time. This is achieved by aggregating the outputs of all layers into a shared memory, which each token, at every layer, can attend to at each time step.
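
As a rough illustration (not the repository's actual code), the shared memory can be thought of as a learned softmax-weighted sum of every layer's output at each time step, appended to a buffer that all layers attend to. The names below are hypothetical.

import torch
import torch.nn.functional as F

# illustrative sketch: collapse the per-layer outputs at one time step into a
# single shared memory vector via a learned softmax-weighted sum
batch, depth, dim = 2, 6, 512
layer_outputs = torch.randn(depth + 1, batch, dim)           # embedding + each layer's output
layer_weights = torch.nn.Parameter(torch.zeros(depth + 1))   # learned per-layer mixing weights

weights = F.softmax(layer_weights, dim = 0).view(-1, 1, 1)   # normalize across layers
memory_vector = (layer_outputs * weights).sum(dim = 0)       # (batch, dim), appended to the memory buffer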

The main drawback is longer training time, due to its non-parallel nature. But I thought I'd build it anyway, to further exploration and research into this line of work.

Yannic Kilcher video

I also took the liberty of adding various enhancements, including pre-normalization, GLU-gated feedforwards, as well as simplified T5 relative positional embeddings.
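
For reference, a GLU-gated feedforward block can look roughly like the following. This is a generic sketch; the repository's module may differ in details such as the gating activation.

import torch
from torch import nn

class GLUFeedForward(nn.Module):
    # generic GLU-gated feedforward: half of the projection gates the other half
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        self.proj_in = nn.Linear(dim, dim * mult * 2)
        self.proj_out = nn.Linear(dim * mult, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x, gate = self.proj_in(x).chunk(2, dim = -1)
        x = x * torch.sigmoid(gate)            # GLU gating
        return self.proj_out(self.dropout(x))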

Install

$ pip install feedback-transformer-pytorch

Usage

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,           # number of tokens
    dim = 512,                    # dimension
    depth = 6,                    # depth
    seq_len = 2,                  # the sequence length of each segment or window
    mem_len = 256,                # length of the memory buffer
    dim_head = 64,                # dimension of each head
    heads = 8,                    # number of heads
    attn_dropout = 0.1,           # attention dropout
    ff_dropout = 0.1              # feedforward dropout
).cuda()

x = torch.randint(0, 20000, (2, 64)).cuda()
model(x)  # (2, 64, 20000)
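
The output is a tensor of logits over the vocabulary, so it can be plugged into a standard language-modeling loss. The snippet below is only a hypothetical illustration, with randomly generated targets.

import torch.nn.functional as F

logits = model(x)                                    # (2, 64, 20000)
targets = torch.randint(0, 20000, (2, 64)).cuda()    # placeholder next-token targets
loss = F.cross_entropy(logits.transpose(1, 2), targets)
loss.backward()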

If you would like finer control over the memory (when to detach it, etc.), you can do so with some extra keyword arguments on .forward.

import torch
from feedback_transformer_pytorch import FeedbackTransformer

model = FeedbackTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 32,
    mem_len = 256
).cuda()

x1 = torch.randint(0, 20000, (2, 32)).cuda()
x2 = torch.randint(0, 20000, (2, 32)).cuda()
x3 = torch.randint(0, 20000, (2, 32)).cuda()

out1, mem1 = model(x1, return_memory = True)
out2, mem2 = model(x2, memory = mem1, return_memory = True)
out3, mem3 = model(x3, memory = mem2, return_memory = True)  # (2, 32, 20000)
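
For example, to stop gradients from flowing across a segment boundary you could detach the returned memory before passing it back in. This is only a sketch and assumes the returned memory behaves like a tuple of tensors, which may differ between versions.

# assumption: the returned memory is an iterable of tensors; detach each one so
# backpropagation is truncated at the segment boundary
detached_mem = tuple(t.detach() for t in mem1)
out2, mem2 = model(x2, memory = detached_mem, return_memory = True)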

Citations

@misc{fan2021addressing,
    title   = {Addressing Some Limitations of Transformers with Feedback Memory}, 
    author  = {Angela Fan and Thibaut Lavril and Edouard Grave and Armand Joulin and Sainbayar Sukhbaatar},
    year    = {2021},
    eprint  = {2002.09402},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • Should it really be using lower layers output for keys and values?

    Could you explain the logic of how the key-value pairs are formed at these lines and whether it is necessary?

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/d7d8939910d1491f01a3d93ce81d4663925fb389/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L146-L151

    It looks to me that line 146 transforms the output of the layer below (x) into keys and values, and the following lines combine these keys and values with the memory. I thought that x should only be used for forming the query here, and that only the existing memory should be used for keys and values.

    opened by tarvaina 6
  • In place operation with gradient

    https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L173 I think this is an error.

    opened by hadaev8 4
  • Bug in weighted sum

    Bug in https://github.com/lucidrains/feedback-transformer-pytorch/blob/main/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L264

    Should be layer_weight = rearrange(layer_weight, 'd -> d () () ()')

    opened by Victor0118 1
  • Input/Output dimensions

    Hey @lucidrains

    Can I check the dimensions of the input and output, is it (seq_len, dim) -> (?, dim, tokens)?

    model = FeedbackTransformer(
        num_tokens = 20000,           # number of tokens
        dim = 512,                    # dimension
        depth = 6,                    # depth
        seq_len = 2,                  # the sequence length of each segment or window
        mem_len = 256,                # length of the memory buffer
        dim_head = 64,                # dimension of each head
        heads = 8,                    # number of heads
        attn_dropout = 0.1,           # attention dropout
        ff_dropout = 0.1              # feedforward dropout
    ).cuda()
    
    x = torch.randint(0, 256, (2, 512)).cuda()
    model(x)  # (1, 512, 20000)
    
    opened by iiSeymour 1
  • Non intuitive memory usage with cross attention

    Given a simple 256-dim, 512-length tensor and a memory length of 16, the feedback transformer uses 3.6 GB of memory after a forward pass. With cross attention on a 100-length tensor, usage grows to 14 GB.

    Meanwhile, the parallel version uses 3.1 GB and 3.5 GB respectively.

    Notebooks for testing https://colab.research.google.com/drive/1dRImydFn3WthOXdLYIvdf5bsqjXcmhC5?usp=sharing https://colab.research.google.com/drive/1n653j4Pz9_U7OukhTlUbomAHMvpPXwx0?usp=sharing

    opened by hadaev8 0
  • I think mask padding value should be False

    Here https://github.com/lucidrains/feedback-transformer-pytorch/blob/with-cross-attention/feedback_transformer_pytorch/feedback_transformer_pytorch.py#L181

    opened by hadaev8 0
  • ETA for the enwiki8 example

    Hey @lucidrains,

    Any ETA on the auto-regressive enwiki8 example? I and others would really appreciate it, as always :)

    Also, if you can provide an example for training on custom line-by-line TXT datasets, it would be absolutely fantastic.

    Thank you.

    opened by asigalov61 0
Owner
Phil Wang
Working with Attention. It's all we need.