Implementation of Deformable Attention in PyTorch from the paper "Vision Transformer with Deformable Attention"

Overview

Deformable Attention

Implementation of Deformable Attention from the paper "Vision Transformer with Deformable Attention" in PyTorch, which appears to be an improvement on what was proposed in Deformable DETR. The relative positional embedding has also been modified for better extrapolation, using the Continuous Positional Embedding proposed in SwinV2.
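As a rough illustration of the continuous positional embedding idea: SwinV2 maps relative coordinates to log-spaced values, sign(d) * log(1 + |d|), before a small MLP turns them into an attention bias. The sketch below shows only that coordinate transform in plain Python. The name `log_spaced` is hypothetical and not part of this package's API, and how the package wires the transform into attention is not shown here.

```python
import math

def log_spaced(delta: float) -> float:
    """Map a relative offset to log-spaced coordinates: sign(d) * log(1 + |d|)."""
    return math.copysign(math.log1p(abs(delta)), delta)

# Large offsets are compressed much more strongly than small ones, which is
# what helps the bias extrapolate to relative distances unseen in training.
for d in (-64, -8, -1, 0, 1, 8, 64):
    print(d, round(log_spaced(d), 3))
```

The transform is odd-symmetric and keeps zero at zero, so the sign of a relative position is preserved while its magnitude grows only logarithmically.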

Install

$ pip install deformable-attention

Usage

import torch
from deformable_attention import DeformableAttention

attn = DeformableAttention(
    dim = 512,                   # feature dimensions
    dim_head = 64,               # dimension per head
    heads = 8,                   # attention heads
    dropout = 0.,                # dropout
    downsample_factor = 4,       # downsample factor (r in paper)
    offset_scale = 4,            # scale of offset, maximum offset
    offset_groups = None,        # number of offset groups, should be multiple of heads
    offset_kernel_size = 6,      # offset kernel size
)

x = torch.randn(1, 512, 64, 64)
attn(x) # (1, 512, 64, 64)
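
To make `downsample_factor` and `offset_scale` more concrete: in the paper, keys and values are sampled at a uniform grid of reference points (one per r x r patch) shifted by learned offsets, the offsets are bounded with s * tanh(.), and features at the resulting fractional positions are read with bilinear interpolation. The plain-Python sketch below walks through those steps on a tiny single-channel feature map. All names (`bilinear_sample`, `sample_deformed`) are hypothetical, the raw offsets are hard-coded rather than predicted by a network, and placing reference points at patch centers is a choice made for this sketch; the package's actual implementation may differ in details such as coordinate normalization.

```python
import math

def bilinear_sample(fmap, y, x):
    """Bilinearly interpolate a 2D list at fractional (y, x), clamping to the border."""
    h, w = len(fmap), len(fmap[0])
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = int(math.floor(y)), int(math.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0
    top = fmap[y0][x0] * (1 - wx) + fmap[y0][x1] * wx
    bottom = fmap[y1][x0] * (1 - wx) + fmap[y1][x1] * wx
    return top * (1 - wy) + bottom * wy

def sample_deformed(fmap, raw_offsets, r, s):
    """Sample one value per r x r patch at (patch center + s * tanh(raw offset))."""
    h, w = len(fmap), len(fmap[0])
    out = []
    for i in range(h // r):
        row = []
        for j in range(w // r):
            ref_y = i * r + (r - 1) / 2  # reference point: patch center (assumed)
            ref_x = j * r + (r - 1) / 2
            dy, dx = raw_offsets[i][j]
            # tanh keeps each offset within [-s, s] of the reference point
            row.append(bilinear_sample(fmap,
                                       ref_y + s * math.tanh(dy),
                                       ref_x + s * math.tanh(dx)))
        out.append(row)
    return out

# 4x4 feature map whose value equals its column index, downsampled by r = 2
fmap = [[float(x) for x in range(4)] for _ in range(4)]
zero = [[(0.0, 0.0)] * 2 for _ in range(2)]
print(sample_deformed(fmap, zero, r=2, s=2))  # [[0.5, 2.5], [0.5, 2.5]]
```

With zero offsets the samples sit at the patch centers; non-zero offsets move each sample by at most `s` positions along each axis.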

3D deformable attention

import torch
from deformable_attention import DeformableAttention3D

attn = DeformableAttention3D(
    dim = 512,                          # feature dimensions
    dim_head = 64,                      # dimension per head
    heads = 8,                          # attention heads
    dropout = 0.,                       # dropout
    downsample_factor = (2, 8, 8),      # downsample factor (r in paper)
    offset_scale = (2, 8, 8),           # scale of offset, maximum offset
    offset_kernel_size = (4, 10, 10),   # offset kernel size
)

x = torch.randn(1, 512, 10, 32, 32) # (batch, dimension, frames, height, width)
attn(x) # (1, 512, 10, 32, 32)
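
The downsample factor controls how many key/value positions each query attends to. As a back-of-the-envelope check, the sketch below applies the standard strided-convolution output-size formula per axis, assuming a stride equal to `downsample_factor` and a padding of `(kernel - stride) // 2`. That padding rule is an assumption made for illustration, not something the usage above states, and `conv_out` and `num_sampled_points` are hypothetical helper names.

```python
from math import prod

def conv_out(size: int, kernel: int, stride: int) -> int:
    """Output length of a strided convolution with padding (kernel - stride) // 2 (assumed)."""
    padding = (kernel - stride) // 2
    return (size + 2 * padding - kernel) // stride + 1

def num_sampled_points(input_shape, kernel_sizes, downsample_factors) -> int:
    """Number of key/value positions on the downsampled offset grid."""
    return prod(conv_out(n, k, r)
                for n, k, r in zip(input_shape, kernel_sizes, downsample_factors))

frames_h_w = (10, 32, 32)  # (frames, height, width) from the 3D example above
kv = num_sampled_points(frames_h_w, (4, 10, 10), (2, 8, 8))
queries = prod(frames_h_w)
print(queries, kv)  # 10240 80
```

Under these assumptions, 10240 query positions attend to only 80 sampled key/value positions, which is where the savings over dense attention come from.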

1D deformable attention for good measure

import torch
from deformable_attention import DeformableAttention1D

attn = DeformableAttention1D(
    dim = 128,
    downsample_factor = 4,
    offset_scale = 2,
    offset_kernel_size = 6
)

x = torch.randn(1, 128, 512)
attn(x) # (1, 128, 512)

Citation

@misc{xia2022vision,
    title   = {Vision Transformer with Deformable Attention}, 
    author  = {Zhuofan Xia and Xuran Pan and Shiji Song and Li Erran Li and Gao Huang},
    year    = {2022},
    eprint  = {2201.00520},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}

Comments
  • The relationship between 'dim' and 'inner_dim'

    Hi, I have a question about DeformableAttention module,

    I calculated the output volumes of the forward pass step by step. According to my calculation, the code works only when 'dim' and 'inner_dim' are the same.

    Is there any reason why you implement it this way?

    Best regards, Hankyu

    opened by hanq0212 4
  • TypeError: meshgrid() got an unexpected keyword argument 'indexing'

    @lucidrains while trying to run

    import torch
    from deformable_attention import DeformableAttention

    attn = DeformableAttention(
        dim = 512,                   # feature dimensions
        dim_head = 64,               # dimension per head
        heads = 8,                   # attention heads
        dropout = 0.,                # dropout
        downsample_factor = 4,       # downsample factor (r in paper)
        offset_scale = 4,            # scale of offset, maximum offset
        offset_groups = None,        # number of offset groups, should be multiple of heads
        offset_kernel_size = 6,      # offset kernel size
    )

    x = torch.randn(1, 512, 64, 64)
    attn(x)

    I got the error above from the line linked below. Kindly help.

    https://github.com/lucidrains/deformable-attention/blob/9f3c0ae35652ce54687e0db409921018bfca3310/deformable_attention/deformable_attention_2d.py#L26

    opened by ChidanandKumarKS 1
Owner
Phil Wang
Working with Attention. It's all we need