Paddle pit - Rethinking Spatial Dimensions of Vision Transformers

Overview

基于Paddle实现PiT ——Rethinking Spatial Dimensions of Vision Transformers,arxiv

  • 官方原版代码(基于PyTorch)pit.

  • 本项目基于 PaddleViT 实现,在其基础上与原版代码实现了更进一步的对齐,并通过完整训练与测试完成对pit_ti模型的复现.

1. 简介

从CNN的成功设计原理出发,作者研究了空间尺寸转换的作用及其在基于Transformer的体系结构上的有效性。

具体来说,类似于CNN的降维原则(随着深度的增加,传统的CNN会增加通道尺寸并减小空间尺寸),作者用实验表明了这同样有利于Transformer的性能提升,并提出了基于池化的Vision Transformer,即PiT(模型示意图如下)。

drawing

PiT 模型示意图

2. 数据集和复现精度

数据集

原文使用的为ImageNet-1k 2012(ILSVRC2012),共1000类,训练集/测试集图片分布:1281167/50000,数据集大小为144GB。

本项目使用的为官方推荐的图片压缩过的更轻量的Light_ILSVRC2012,数据集大小为65GB。其在AI Studio上的地址为:Light_ILSVRC2012_part_0.tarLight_ILSVRC2012_part_1.tar

复现精度

Model 目标精度[email protected] 实现精度[email protected] Image Size batch_size Crop_pct epoch #Params
pit_ti 73.0 73.01 224 256*4GPUs 0.9 300
(+10 COOLDOWN)
4.8M

【注】上表中的实现精度在原版ILSVRC2012验证集上测试得到。 值得一提的是,本项目在Light_ILSVRC2012的验证集上的Validation [email protected]达到了73.17

本项目训练得到的最佳模型参数与训练日志log均存放于output文件夹下。

日志文件说明

本项目通过AI Studio的脚本任务运行,中途中断了4次,因此共有5个日志文件。为了方便检阅,本人手动将log命名为log_开始epoch-结束epoch.txt格式。具体来说:

  • output/log_1-76.txt:epoch1~epoch76。这一版代码定义每10个epoch保存一次模型权重,每2个epoch验证一次,同时若验证精度高于历史精度,则保存为Best_PiT.pdparams,因此在epoch76训练结束但还未验证的时候中断,下一次的训练只能从验证精度最高的epoch74继续训练。

  • output/log_75-142.txt:epoch75~epoch142。从这一版代码开始,新增了每次训练之后都保存一下模型参数为PiT-Latest.pdparams,这样无论哪个epoch训练中断都可以继续训练啦。

  • output/log_143-225.txt:epoch143~epoch225。

  • output/log_226-303.txt:epoch226~epoch303。

  • output/log_304-310.txt:epoch304~epoch310。

  • output/log_eval.txt:使用训练得到的最好模型(epoch308)在原版ILSVRC2012验证集上验证日志。

3. 准备环境

推荐环境配置:

本人环境配置:

  • 硬件:Tesla V100 * 4(由衷感谢百度飞桨平台提供高性能算力支持)

  • PaddlePaddle==2.2.1

  • Python==3.7

4. 快速开始

本项目现已通过脚本任务形式部署到AI Studio上,您可以选择fork下来直接运行sh run.sh,数据集处理等脚本均已部署好。链接:paddle_pit

或者您也可以git本repo在本地运行:

第一步:克隆本项目

git clone https://github.com/hatimwen/paddle_pit.git
cd paddle_pit

第二步:修改参数

请根据实际情况,修改scripts路径下的脚本内容(如:gpu,数据集路径data_path,batch_size等)。

第三步:验证模型

多卡请运行:

sh scripts/run_eval_multi.sh

单卡请运行:

sh scripts/run_eval.sh

第四步:训练模型

多卡请运行:

sh scripts/run_train_multi.sh

单卡请运行:

sh scripts/run_train.sh

第五步:验证预测

python predict.py \
-pretrained='output/Best_PiT' \
-img_path='images/ILSVRC2012_val_00004506.JPEG'

验证图片(类别:藏獒, id: 244)

输出结果为:

class_id: 244, prob: 9.12291145324707

对照ImageNet类别id(ImageNet数据集编号对应的类别内容),可知244为藏獒,预测结果正确。

5.代码结构

|-- paddle_pit
    |-- output              # 日志及模型文件
    |-- configs             # 参数
        |-- pit_ti.yaml
    |-- datasets
        |-- ImageNet1K      # 数据集路径
    |-- scripts             # 运行脚本
        |-- run_train.sh
        |-- run_train_multi.sh
        |-- run_eval.sh
        |-- run_eval_multi.sh
    |-- augment.py          # 数据增强
    |-- config.py           # 最底层配置文件
    |-- datasets.py         # dataset与dataloader
    |-- droppath.py         # droppath定义
    |-- losses.py           # loss定义
    |-- main_multi_gpu.py   # 多卡训练测试代码
    |-- main_single_gpu.py  # 单卡训练测试代码
    |-- mixup.py            # mixup定义
    |-- model_ema.py        # EMA定义
    |-- pit.py              # pit模型结构定义
    |-- random_erasing.py   # random_erasing定义
    |-- regnet.py           # 教师模型定义(本项目并未对此验证,仅作保留)
    |-- transforms.py       # RandomHorizontalFlip定义
    |-- utils.py            # CosineLRScheduler及AverageMeter定义
    |-- README.md
    |-- requirements.txt

6. 参考及引用

@InProceedings{Yuan_2021_ICCV,
    author    = {Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Jiang, Zi-Hang and Tay, Francis E.H. and Feng, Jiashi and Yan, Shuicheng},
    title     = {Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {558-567}
}

最后,非常感谢百度举办的飞桨论文复现挑战赛(第五期)让本人对Paddle理解更加深刻。 同时也非常感谢朱欤老师团队用Paddle实现的PaddleViT,本项目大部分代码都是从中copy来的,而仅仅实现了其与原版代码训练步骤的进一步对齐与完整的训练过程,但本人也同样受益匪浅! ♥️

Contact

Owner
Hongtao Wen
Hongtao Wen
Learning to Communicate with Deep Multi-Agent Reinforcement Learning in PyTorch

Learning to Communicate with Deep Multi-Agent Reinforcement Learning This is a PyTorch implementation of the original Lua code release. Overview This

Minqi 297 Dec 12, 2022
A curated list of programmatic weak supervision papers and resources

A curated list of programmatic weak supervision papers and resources

Jieyu Zhang 118 Jan 02, 2023
Count the MACs / FLOPs of your PyTorch model.

THOP: PyTorch-OpCounter How to install pip install thop (now continously intergrated on Github actions) OR pip install --upgrade git+https://github.co

Ligeng Zhu 3.9k Dec 29, 2022
This repo will contain code to reproduce and build upon understanding transfer learning

What is being transferred in transfer learning? This repo contains the code for the following paper: Behnam Neyshabur*, Hanie Sedghi*, Chiyuan Zhang*.

4 Jun 16, 2021
pip install python-office

🍬 python for office 👉 http://www.python4office.cn/ 👈 🌎 English Documentation 📚 简介 Python-office 是一个 Python 自动化办公第三方库,能解决大部分自动化办公的问题。而且每个功能只需一行代码,

程序员晚枫 272 Dec 29, 2022
Convert Pytorch model to onnx or tflite, and the converted model can be visualized by Netron

Convert Pytorch model to onnx or tflite, and the converted model can be visualized by Netron

Roxbili 5 Nov 19, 2022
This repository includes code of my study about Asynchronous in Frequency domain of GAN images.

Exploring the Asynchronous of the Frequency Spectra of GAN-generated Facial Images Binh M. Le & Simon S. Woo, "Exploring the Asynchronous of the Frequ

4 Aug 06, 2022
A transformer-based method for Healthcare Image Captioning in Vietnamese

vieCap4H Challenge 2021: A transformer-based method for Healthcare Image Captioning in Vietnamese This repo GitHub contains our solution for vieCap4H

Doanh B C 4 May 05, 2022
GyroSPD: Vector-valued Distance and Gyrocalculus on the Space of Symmetric Positive Definite Matrices

GyroSPD Code for the paper "Vector-valued Distance and Gyrocalculus on the Space of Symmetric Positive Definite Matrices" accepted at NeurIPS 2021. Re

Federico Lopez 12 Dec 12, 2022
RANZCR-CLiP 7th Place Solution

RANZCR-CLiP 7th Place Solution This repository is WIP. (18 Mar 2021) Installation git clone https://github.com/analokmaus/kaggle-ranzcr-clip-public.gi

Hiroshechka Y 21 Oct 22, 2022
Chinese Mandarin tts text-to-speech 中文 (普通话) 语音 合成 , by fastspeech 2 , implemented in pytorch, using waveglow as vocoder,

Chinese mandarin text to speech based on Fastspeech2 and Unet This is a modification and adpation of fastspeech2 to mandrin(普通话). Many modifications t

291 Jan 02, 2023
Some methods for comparing network representations in deep learning and neuroscience.

Generalized Shape Metrics on Neural Representations In neuroscience and in deep learning, quantifying the (dis)similarity of neural representations ac

Alex Williams 45 Dec 27, 2022
[CVPR 2021] Scan2Cap: Context-aware Dense Captioning in RGB-D Scans

Scan2Cap: Context-aware Dense Captioning in RGB-D Scans Introduction We introduce the task of dense captioning in 3D scans from commodity RGB-D sensor

Dave Z. Chen 79 Nov 07, 2022
DziriBERT: a Pre-trained Language Model for the Algerian Dialect

DziriBERT DziriBERT is the first Transformer-based Language Model that has been pre-trained specifically for the Algerian Dialect. It handles Algerian

117 Jan 07, 2023
Multi-Stage Progressive Image Restoration

Multi-Stage Progressive Image Restoration Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Sh

Syed Waqas Zamir 859 Dec 22, 2022
Intrusion Test Tool with Python

P3ntsT00L Uma ferramenta escrita em Python, feita para Teste de intrusão. Requisitos ter o python 3.9.8 instalado em sua máquina. ter a git instalada

josh washington 2 Dec 27, 2021
Anagram Generator in Python

Anagrams Generator This is a program for computing multiword anagrams. It makes no effort to come up with sentences that make sense; it only finds ana

Day Fundora 5 Nov 17, 2022
YuNetのPythonでのONNX、TensorFlow-Lite推論サンプル

YuNet-ONNX-TFLite-Sample YuNetのPythonでのONNX、TensorFlow-Lite推論サンプルです。 TensorFlow-LiteモデルはPINTO0309/PINTO_model_zoo/144_YuNetのものを使用しています。 Requirement Op

KazuhitoTakahashi 8 Nov 17, 2021
DLL: Direct Lidar Localization

DLL: Direct Lidar Localization Summary This package presents DLL, a direct map-based localization technique using 3D LIDAR for its application to aeri

Service Robotics Lab 127 Dec 16, 2022
95.47% on CIFAR10 with PyTorch

Train CIFAR10 with PyTorch I'm playing with PyTorch on the CIFAR10 dataset. Prerequisites Python 3.6+ PyTorch 1.0+ Training # Start training with: py

5k Dec 30, 2022