Implementation of Transformer in Transformer, pixel-level attention paired with patch-level attention for image classification, in PyTorch

Overview

Transformer in Transformer


Install

$ pip install transformer-in-transformer

Usage

import torch
from transformer_in_transformer import TNT

tnt = TNT(
    image_size = 256,       # size of image
    patch_dim = 512,        # dimension of patch token
    pixel_dim = 24,         # dimension of pixel token
    patch_size = 16,        # patch size
    pixel_size = 4,         # pixel size
    depth = 6,              # depth
    num_classes = 1000,     # output number of classes
    attn_dropout = 0.1,     # attention dropout
    ff_dropout = 0.1        # feedforward dropout
)

img = torch.randn(2, 3, 256, 256)
logits = tnt(img) # (2, 1000)
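As a quick sanity check on these arguments, the token counts can be worked out by hand. This is an assumption-based sketch, not code from the package: it assumes the image is cut into non-overlapping patch_size × patch_size patches and each patch into non-overlapping pixel_size × pixel_size pixels, and it adds one extra patch slot to mirror the `num_patch_tokens + 1` parameter shapes quoted in the comments (presumably a class token). The helper name `token_counts` is hypothetical.

```python
from typing import Tuple


def token_counts(image_size: int, patch_size: int, pixel_size: int) -> Tuple[int, int]:
    """Hypothetical helper: (patch sequence length, pixel tokens per patch)."""
    # Both divisions must be exact for a non-overlapping tiling.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    if patch_size % pixel_size != 0:
        raise ValueError("patch_size must be divisible by pixel_size")

    num_patches = (image_size // patch_size) ** 2       # patches per image
    pixels_per_patch = (patch_size // pixel_size) ** 2  # pixels per patch
    return num_patches + 1, pixels_per_patch            # +1 extra patch slot (assumed class token)


print(token_counts(256, 16, 4))  # README settings
print(token_counts(128, 16, 2))  # settings from the first comment below
```

Note that (patch_size // pixel_size) ** 2 equals patch_size only when pixel_size ** 2 == patch_size, which happens to hold for the README settings (4 ** 2 == 16).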

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Comments
  • Only works if pixel_size**2 == patch_size?

    Hi, is this only supposed to work if pixel_size**2 == patch_size? When I set patch_size to any value that doesn't satisfy this equation, the following error occurs:

    --> 146         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        147 
        148         for pixel_attn, pixel_ff, pixel_to_patch_residual, patch_attn, patch_ff in self.layers:
    
    RuntimeError: The size of tensor a (4) must match the size of tensor b (64) at non-singleton dimension 1
    

    The error occurred when running:

    import torch
    from transformer_in_transformer import TNT

    tnt = TNT(
        image_size = 128,       # size of image
        patch_dim = 256,        # dimension of patch token
        pixel_dim = 24,         # dimension of pixel token
        patch_size = 16,        # patch size
        pixel_size = 2,         # pixel size
        depth = 6,              # depth
        heads = 1,              # number of attention heads
        num_classes = 2,        # output number of classes
        attn_dropout = 0.1,     # attention dropout
        ff_dropout = 0.1        # feedforward dropout
    )
    img = torch.randn(2, 3, 128, 128)
    logits = tnt(img)
    

    Since I'm completely new to einops, it's quite hard for me to debug :D Thanks!

    opened by PhilippMarquardt 1
  • Not sure what is wrong!

    RuntimeError                              Traceback (most recent call last)
    in
         14
         15 img = torch.randn(1, 3, 256, 256)
    ---> 16 logits = tnt(img) # (2, 1000)

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1109                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1110             return forward_call(*input, **kwargs)
       1111         # Do not call functions when jit is used
       1112         full_backward_hooks, non_full_backward_hooks = [], []

    ~/opt/anaconda3/envs/ml/lib/python3.8/site-packages/transformer_in_transformer/tnt.py in forward(self, x)
        159         patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
        160
    --> 161         patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
        162         pixels += rearrange(self.pixel_pos_emb, 'n d -> () n d')
        163

    RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.

    opened by RisabBiswas 0
  • patch_tokens vs patch_pos_emb

    Hi!

    I'm trying to understand your TNT implementation, and one thing that confuses me is why there are two parameters, patch_tokens and patch_pos_emb, which seem to serve the same purpose: encoding patch position. Isn't one of them redundant?

    self.patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    self.patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
    ...
    patches = repeat(self.patch_tokens[:(n + 1)], 'n d -> b n d', b = b)
    patches += rearrange(self.patch_pos_emb[:(n + 1)], 'n d -> () n d')
    
    opened by stas-sl 0
  • Inconsistent model params with MindSpore src code

    There's no function or README description of the TNT-S/TNT-B models in this codebase, something like:

    def tnt_b(num_class):
        return TNT(img_size=384,
                   patch_size=16,
                   num_channels=3,
                   embedding_dim=640,
                   num_heads=10,
                   num_layers=12,
                   hidden_dim=640*4,
                   stride=4,
                   num_class=num_class)
    

    Also, the number of heads in the inner block should be 4: https://github.com/lucidrains/transformer-in-transformer/blob/main/transformer_in_transformer/tnt.py#L135

    Has anyone reproduced the results reported in the paper with this codebase?

    opened by WongChen 0
  • Why does the loss become NaN?

    It is a great project. I am very interested in the Transformer in Transformer model. I have used your model to train on the Vehicle-1M dataset, a fine-grained visual classification dataset. When I use this model, the loss becomes NaN after some batch iterations. I have decreased the learning rate of the Adam optimizer and clipped the gradients with torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2), but the loss still sometimes becomes NaN. The gradients do not seem large, but they point in the same direction for many iterations. How can I solve this?

    opened by yt7589 3
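Regarding the report that TNT only works if pixel_size**2 == patch_size: the two sizes in that error, 4 and 64, can both be reproduced by counting the non-overlapping pixel tokens in a single 16 × 16 patch. The minimal sketch below uses plain `torch.nn.Unfold`, not the package's own code. A kernel and stride of 2 (pixel_size) gives (16 // 2) ** 2 = 64 tokens, while a kernel and stride of 8 (patch_size // pixel_size) gives 4. That the mismatch comes from the second choice is only a hypothesis, not a confirmed diagnosis.

```python
import torch
from torch import nn

# One 16x16 RGB patch (batch of 1), matching patch_size = 16 in the report.
patch = torch.randn(1, 3, 16, 16)


def num_pixel_tokens(kernel: int) -> int:
    # Non-overlapping unfold: output shape is (batch, channels * kernel**2, tokens).
    unfold = nn.Unfold(kernel_size=kernel, stride=kernel)
    return unfold(patch).shape[-1]


print(num_pixel_tokens(2))  # kernel = pixel_size = 2
print(num_pixel_tokens(8))  # kernel = patch_size // pixel_size = 8
```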
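Regarding the "Not sure what is wrong!" traceback: the failing line adds the position embedding in place (`patches += ...`) to a tensor built from `self.patch_tokens`. Assuming `repeat` hands back a view of that parameter here, which is what the error message suggests, autograd refuses the in-place update. The sketch below shows the usual PyTorch workaround of writing the addition out of place. It swaps the einops calls for `unsqueeze`/`expand` so it runs with PyTorch alone, uses arbitrary sizes, and is not a patch taken from the package.

```python
import torch
from torch import nn

num_patch_tokens, patch_dim = 4, 8
b, n = 1, num_patch_tokens  # batch of 1, as in the report

# Stand-ins for the two learned parameters in the traceback.
patch_tokens = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))
patch_pos_emb = nn.Parameter(torch.randn(num_patch_tokens + 1, patch_dim))

# Slicing and expand both return views of the leaf parameter.
patches = patch_tokens[:(n + 1)].unsqueeze(0).expand(b, -1, -1)
try:
    patches += patch_pos_emb[:(n + 1)].unsqueeze(0)  # in-place: rejected by autograd
    inplace_failed = False
except RuntimeError:
    inplace_failed = True

# Out-of-place: allocates a new tensor, so gradients still reach both parameters.
patches = patch_tokens[:(n + 1)].unsqueeze(0).expand(b, -1, -1)
patches = patches + patch_pos_emb[:(n + 1)].unsqueeze(0)
patches.sum().backward()

print(inplace_failed, tuple(patches.shape))
print(patch_tokens.grad is not None, patch_pos_emb.grad is not None)
```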
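Regarding the NaN-loss report: the sketch below wraps the same `clip_grad_norm_` call in a training step that skips the update when the loss or the gradient norm is not finite, and returns the pre-clipping gradient norm so it can be logged over time. This is a generic PyTorch debugging pattern, not a known fix for TNT. The `train_step` helper is hypothetical, and a small `nn.Linear` stands in for the model so the sketch runs without the package.

```python
import torch
from torch import nn


def train_step(model, optimizer, loss_fn, images, labels, max_norm=2.0):
    """Hypothetical helper. Returns (loss, grad_norm, applied)."""
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(images), labels)
    if not torch.isfinite(loss):
        # Skip the batch instead of pushing NaN/Inf into the weights.
        return loss.item(), float("nan"), False
    loss.backward()
    # Same clipping call as in the report; returns the total norm before clipping.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=max_norm, norm_type=2
    )
    if not torch.isfinite(grad_norm):
        # Discard non-finite gradients rather than stepping with them.
        optimizer.zero_grad(set_to_none=True)
        return loss.item(), grad_norm.item(), False
    optimizer.step()
    return loss.item(), grad_norm.item(), True


# Stand-in model and data (assumption: replace with TNT and real batches).
model = nn.Linear(16, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
print(train_step(model, optimizer, loss_fn, x, y))
```

Logging `grad_norm` per step shows whether the norm drifts upward before the first skipped batch, which is the kind of evidence the report is missing.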
Owner
Phil Wang