Implementation of 🦩 Flamingo, the state-of-the-art few-shot visual question answering attention net out of DeepMind, in PyTorch

Overview

🦩 Flamingo - Pytorch

Implementation of Flamingo, state-of-the-art few-shot visual question answering attention net, in PyTorch. It will include the perceiver resampler (including the scheme where the learned queries contribute keys / values to be attended to, in addition to media embeddings), the specialized masked cross attention blocks, and finally the tanh gating at the ends of the cross attention + corresponding feedforward blocks.
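The perceiver resampler scheme mentioned above, where the learned latents also contribute keys / values, can be sketched in plain PyTorch as follows. This is a minimal single-layer sketch under stated assumptions, not the library's code. `PerceiverAttentionSketch` is a hypothetical name, the media input is assumed to be already flattened to (batch, tokens, dim) with no time dimension, and the separate LayerNorms on media and latents are a design choice of this sketch.

```python
import torch
from torch import nn


class PerceiverAttentionSketch(nn.Module):
    """Hypothetical sketch: latents (queries) attend over [media; latents]."""

    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner, bias=False)
        self.to_kv = nn.Linear(dim, inner * 2, bias=False)
        self.to_out = nn.Linear(inner, dim, bias=False)

    def forward(self, media: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        # media: (batch, n_media, dim), latents: (batch, n_latents, dim)
        media = self.norm_media(media)
        latents = self.norm_latents(latents)
        q = self.to_q(latents)  # queries come from the latents only

        # the learned latents contribute keys / values alongside the media tokens
        kv_input = torch.cat((media, latents), dim=1)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        b, h = q.shape[0], self.heads
        q, k, v = (t.reshape(b, t.shape[1], h, -1).transpose(1, 2) for t in (q, k, v))

        weights = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(b, latents.shape[1], -1)
        return self.to_out(out)  # (batch, n_latents, dim)


media = torch.randn(2, 256, 512)
latents = torch.randn(2, 64, 512)
perceiver_attn = PerceiverAttentionSketch(dim=512)
out = perceiver_attn(media, latents)
print(out.shape)  # torch.Size([2, 64, 512])
```

Because the latents appear among the keys / values, each latent can attend to the other latents as well as to the media tokens, at the cost of `n_latents` extra keys per attention call.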

Install

$ pip install flamingo-pytorch

Usage

import torch
from flamingo_pytorch import PerceiverResampler

perceive = PerceiverResampler(
    dim = 1024,
    depth = 2,
    dim_head = 64,
    heads = 8,
    num_latents = 64,    # the number of latents to shrink your media sequence to, perceiver style
    num_time_embeds = 4  # say you have 4 images maximum in your dialogue
)

medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
perceived = perceive(medias) # (1, 2, 64, 1024) - (batch, time, num latents, dimension)

Then insert the GatedCrossAttentionBlock at intervals throughout your large language model, so that your text attends to the perceived media from above.
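The tanh gating is what makes inserting new blocks into a pretrained language model safe: in the Flamingo paper the gate parameter is initialized to zero, so tanh(0) = 0 and each new block initially passes the text through unchanged. A minimal sketch of such a gated residual follows; `TanhGatedResidual` is a hypothetical name, and a plain feedforward stands in for the cross attention or feedforward sublayer.

```python
import torch
from torch import nn


class TanhGatedResidual(nn.Module):
    """Hypothetical wrapper: x + tanh(alpha) * sublayer(x), with alpha initialized to 0."""

    def __init__(self, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        # alpha starts at zero, so tanh(alpha) = 0 and the wrapper is an identity at init
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x + self.sublayer(x, *args, **kwargs) * self.alpha.tanh()


ff = nn.Sequential(nn.LayerNorm(16), nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
gated_ff = TanhGatedResidual(ff)
x = torch.randn(2, 10, 16)
print(torch.equal(gated_ff(x), x))  # True at initialization
```

During training the gate opens gradually as `alpha` moves away from zero, letting the frozen language model's behavior change smoothly.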

The recommended way to derive the media_locations boolean tensor is to allocate a special token id to the media and then, at the start of your large language model, compute media_locations = text_id == media_token_id
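A minimal sketch of that recipe, assuming media token id 3 (the value used in the FlamingoPaLM example further down) and a made-up sequence of token ids. The running count from `cumsum`, called `text_time` here, is one way to tell which media item each text token comes after; that step is an illustration, not a documented API.

```python
import torch

media_token_id = 3  # assumption: the id reserved for the [media] / [image] token
text_ids = torch.tensor([[5, 3, 8, 9, 3, 7, 6]])  # made-up token ids, batch of 1

# True wherever a media token sits in the text
media_locations = text_ids == media_token_id
print(media_locations.tolist())  # [[False, True, False, False, True, False, False]]

# running count of media seen so far, including the media token itself
text_time = media_locations.cumsum(dim=-1)
print(text_time.tolist())  # [[0, 1, 1, 1, 2, 2, 2]]
```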

import torch
from flamingo_pytorch import GatedCrossAttentionBlock

cross_attn = GatedCrossAttentionBlock(
    dim = 1024,
    dim_head = 64,
    heads = 8
)

text = torch.randn(1, 512, 1024)
perceived = torch.randn(1, 2, 64, 1024)

media_locations = torch.randint(0, 2, (1, 512)).bool()

text = cross_attn(
    text,
    perceived,
    media_locations = media_locations
)

That's it!

Attention is all you need.

Full working example with Flamingo + PaLM 🌴 🦩 🌴

Integration with PaLM

First install vit-pytorch for the vision encoder

$ pip install vit-pytorch

Then

from vit_pytorch.vit import ViT
from vit_pytorch.extractor import Extractor

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit = Extractor(vit, return_embeddings_only = True)

# above, a trained image encoder is wrapped in an adapter (Extractor) that returns the image embeddings
# here we use the ViT from the vit-pytorch library
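The `perceiver_num_latents` comment in the FlamingoPaLM configuration says the latent count should be smaller than the number of image tokens. A quick arithmetic check for the ViT configured here, assuming vit-pytorch's ViT prepends one CLS token and the Extractor returns it along with the patch embeddings; `vit_num_tokens` is a hypothetical helper.

```python
def vit_num_tokens(image_size: int, patch_size: int, has_cls_token: bool = True) -> int:
    """Hypothetical helper: tokens a ViT emits per image.

    Assumes square images, non-overlapping square patches and, by default,
    one prepended CLS token.
    """
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    num_patches = (image_size // patch_size) ** 2  # (256 // 32) ** 2 = 8 * 8 = 64
    return num_patches + int(has_cls_token)


num_image_tokens = vit_num_tokens(image_size=256, patch_size=32)
perceiver_num_latents = 64
print(num_image_tokens)  # 65
assert perceiver_num_latents < num_image_tokens
```

With this configuration the margin is a single token, so a larger patch size (fewer patches) would push the image sequence below the latent count.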

import torch
from flamingo_pytorch import FlamingoPaLM

# a PaLM-style language model; Google's original PaLM has 540 billion parameters, while this configuration is far smaller

flamingo_palm = FlamingoPaLM(
    num_tokens = 20000,          # number of tokens
    dim = 1024,                  # dimensions
    depth = 12,                  # depth
    heads = 8,                   # attention heads
    dim_head = 64,               # dimension per attention head
    img_encoder = vit,           # plug in your image encoder (optional if you pass in the image embeddings separately, but you probably want to train end to end given the perceiver resampler)
    media_token_id = 3,          # the token id representing the [media] or [image]
    cross_attn_every = 3,        # how often to cross attend
    perceiver_num_latents = 64,  # perceiver number of latents, should be smaller than the sequence length of the image tokens
    perceiver_depth = 2          # perceiver resampler depth
)

# train your PaLM as usual

text = torch.randint(0, 20000, (2, 512))

palm_logits = flamingo_palm(text)

# after much training on the regular PaLM logits,
# you are ready to train Flamingo + PaLM:
# when images are passed in, everything except the perceiver and cross attention blocks is automatically frozen, as in the paper

dialogue = torch.randint(0, 20000, (4, 512))
images = torch.randn(4, 2, 3, 256, 256)

flamingo_logits = flamingo_palm(dialogue, images)

# do your usual cross entropy loss
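The final comment leaves the loss to the reader. A minimal sketch of the usual autoregressive cross entropy, assuming the logits returned by `flamingo_palm(...)` have shape (batch, sequence length, num_tokens) and that position i predicts token i + 1; `next_token_loss` is a hypothetical helper, not part of flamingo-pytorch.

```python
import math

import torch
import torch.nn.functional as F


def next_token_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: autoregressive cross entropy.

    logits: (batch, seq_len, num_tokens), where position i scores the token at i + 1
    tokens: (batch, seq_len) integer token ids that were fed to the model
    """
    logits = logits[:, :-1]  # the last position has no next token to predict
    targets = tokens[:, 1:]  # the first token is never a prediction target
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))


# small stand-in tensors; in practice use `dialogue` and `flamingo_logits` from above
num_tokens = 100
tokens = torch.randint(0, num_tokens, (2, 16))
logits = torch.zeros(2, 16, num_tokens)  # uniform predictions

loss = next_token_loss(logits, tokens)
expected = math.log(num_tokens)  # uniform logits give a loss of log(num_tokens)
print(loss.item(), expected)
```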

It is quite evident where this is all headed if you think beyond images to other kinds of media.

Inception

For factual correctness, imagine where this system would stand if one were to use a state-of-the-art retrieval language model as the base.

Citations

@article{Alayrac2022Flamingo,
    title   = {Flamingo: a Visual Language Model for Few-Shot Learning},
    author  = {Jean-Baptiste Alayrac and others},
    year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
    title   = {PaLM: Scaling Language Modeling with Pathways},
    author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
    year    = {2022}
}
Comments
  • PerceiverResampler missing some LayerNorms?

    Hey, it feels like PerceiverResampler is missing some LayerNorms? It seems to me we should layer-norm x before sending it to the attention loop, and maybe add a layer-norm to ff(latents) + latents?

    opened by inspirit 7
  • Missing flatten op in PerceiverResampler?

    Hi, it seems that Flamingo did "x_f = flatten(x_f) # [T, S, d] -> [T * S, d]" (batch size == 1) before passing x_f to attention.

    So, it should be like:

    medias = torch.randn(1, 2, 256, 1024) # (batch, time, sequence length, dimension)
    perceived = perceive(medias) # (1, 64, 1024) - (batch, num latents, dimension)

    ??

    opened by zengyan-97 6
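The flattening the issue describes, from (batch, time, sequence length, dimension) to (batch, time * sequence length, dimension), is a single reshape. A sketch of that step alone, which does not by itself settle whether the resampler should flatten:

```python
import torch

medias = torch.randn(1, 2, 256, 1024)  # (batch, time, sequence length, dimension)
b, t, s, d = medias.shape

# merge the time and sequence axes, as in the paper's flatten(x_f)
flattened = medias.reshape(b, t * s, d)  # (batch, time * sequence length, dimension)
print(flattened.shape)  # torch.Size([1, 512, 1024])
```

After this reshape the latents would attend over all 512 media tokens at once, so the output would carry no separate time axis.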
  • wrong attention masks?

    https://github.com/lucidrains/flamingo-pytorch/blob/44920f4191ba3c280ff84c6ebc76025656d1dab5/flamingo_pytorch/flamingo_pytorch.py#L159

    In the flamingo paper, the language features in the gated cross-attention layers only attend to the visual features from the immediate preceding image. I believe your attention masks are created in such a way that they attend to the visual features from all preceding images. Can you confirm? If so, a fix would be to simply change the '>=' to '=='.

    opened by dhansmair 4
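The difference between the two comparisons the issue mentions can be seen on a toy example. This is an illustrative sketch, not the linked library code; `text_time` holds the running media count per text token (as from `media_locations.cumsum(dim=-1)`) and media items are numbered from 1.

```python
import torch

text_time = torch.tensor([[0, 1, 1, 1, 2, 2, 2]])  # media seen so far, per text token
num_media = 2
media_time = torch.arange(num_media) + 1  # media items numbered 1..num_media

# (batch, text_len, num_media): may text token i attend to media item j?
mask_all_preceding = text_time[..., None] >= media_time  # every earlier media item
mask_immediate = text_time[..., None] == media_time      # only the most recent one

print(mask_all_preceding[0].tolist()[-1])  # [True, True]
print(mask_immediate[0].tolist()[-1])      # [False, True]
```

Tokens before the first media item (text_time 0) attend to nothing under either rule, which is why their cross attention output needs to be zeroed separately.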
  • zeroing out attention not working

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_pytorch.py#L179

    You are not using the in-place version of the function: https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch.Tensor.masked_fill_

    so I think this line has no effect.

    Best, David

    opened by dhansmair 2
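The behavior the issue points to is easy to reproduce: `Tensor.masked_fill` returns a new tensor and leaves its input untouched, while `Tensor.masked_fill_` modifies the tensor in place. The tensors below are made-up stand-ins for attention output and a "no preceding media" mask.

```python
import torch

scores = torch.tensor([[0.25, 0.75], [0.5, 0.5]])
no_media = torch.tensor([[True], [False]])  # row 0: token has no preceding media

result = scores.masked_fill(no_media, 0.0)  # out-of-place: returns a new tensor
unchanged = scores.tolist()                 # scores itself is untouched
scores.masked_fill_(no_media, 0.0)          # in-place: modifies scores

print(unchanged)        # [[0.25, 0.75], [0.5, 0.5]]
print(scores.tolist())  # [[0.0, 0.0], [0.5, 0.5]]
```

Either calling the in-place variant or reassigning the out-of-place result (`x = x.masked_fill(mask, 0.)`) makes the zeroing take effect.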
  • Applying parallel attn with ff to existing pretrained model?

    Hi - awesome work! I am trying to understand ? I couldn't find a paper, only a reference to https://github.com/kingoflolz/mesh-transformer-jax. Is this right? Am I understanding correctly that it is basically applying the qkv and ff operations at once? Is it possible to use this trick to modify an existing pretrained model?

    https://github.com/lucidrains/flamingo-pytorch/blob/749f8244794002371913d2fc4e7411afd5eddc67/flamingo_pytorch/flamingo_palm.py#L90

    Many thanks in advance!

    Huu

    opened by ontocord 1
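For readers with the same question, the parallel formulation can be sketched as follows: attention and feedforward both read one normalized input, and their outputs are added to the residual together rather than one after the other. This is a hypothetical minimal sketch built on `torch.nn.MultiheadAttention`, not the code in flamingo_palm.py. Implementations of this idea often also fuse the q/k/v and feedforward input projections into one matrix multiply for speed, which the sketch does not do.

```python
import torch
from torch import nn


class ParallelBlockSketch(nn.Module):
    """Hypothetical parallel block: x + attn(norm(x)) + ff(norm(x))."""

    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)  # one shared normalization feeds both branches
        n = x.shape[1]
        # causal mask for nn.MultiheadAttention: True marks positions that may NOT be attended
        causal = torch.ones(n, n, dtype=torch.bool, device=x.device).triu(1)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        # both branches read the same input and are summed, instead of the sequential
        # x = x + attn(norm1(x)); x = x + ff(norm2(x))
        return x + attn_out + self.ff(h)


block = ParallelBlockSketch(dim=32, heads=4)
x = torch.randn(2, 10, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 10, 32])
```

Since the parallel and sequential forms compute different functions, swapping one for the other in an already-trained model changes its outputs.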
  • How to use Flamingo for VQA task?

    Hi, thanks for sharing this awesome implementation. I am very interested in using the Flamingo model for my use case. How can I use this implementation to run inference on my dataset for a VQA task? I have images of products and want to extract information from each product image by asking questions about it. How can I do that?

    Please help.

    thanks

    opened by karndeepsingh 0
  • Fine-tuning of a model

    Hi, thank you for this great work. How can I fine-tune this model on my dataset for a downstream task such as image captioning or image classification? If possible, could you also share the code?

    opened by ans92 0
  • Need a sample ipython notebook

    Hello, @lucidrains,

    Thank you for providing this.

    For demo purposes, could you please provide a sample demo in Jupyter notebook?🫠

    Best, LITDataScience

    opened by LITDataScience 0
Releases (0.1.2)
Owner
Phil Wang
Working with Attention. It's all we need