Generative Query Network (GQN) in PyTorch as described in "Neural Scene Representation and Rendering"

Overview

Update 2019/06/24: A model trained on 10% of the Shepard-Metzler dataset has been added, the following notebook explains the main features of this model: nbviewer

Generative Query Network

This is a PyTorch implementation of the Generative Query Network (GQN) described in the DeepMind paper "Neural scene representation and rendering" by Eslami et al. For an introduction to the model and problem described in the paper look at the article by DeepMind.

The current implementation generalises to any of the datasets described in the paper. However, currently, only the Shepard-Metzler dataset has been implemented. To use this dataset you can use the provided script in

sh scripts/data.sh data-dir batch-size

The model can be trained in full by in accordance to the paper by running the file run-gqn.py or by using the provided training script

sh scripts/gpu.sh data-dir

Implementation

The implementation shown in this repository consists of all of the representation architectures described in the paper along with the generative model that is similar to the one described in "Towards conceptual compression" by Gregor et al.

Additionally, this repository also contains implementations of the DRAW model and the ConvolutionalDRAW model both described by Gregor et al.

Comments
  • Training time and testing demo

    Training time and testing demo

    Hi Jesper,

    Thank you for your great code of gqn in real image, I am a little curious about the following issues: How many epochs it use to train a model on real image? How many training data do you use (percentage of full training dataset)? Can you show a testing demo?

    Thank you very much!

    Best wishes, Mingjia Chen

    opened by mjchen611 22
  • ConvLSTM did not concat hidden from last round

    ConvLSTM did not concat hidden from last round

    In the structure presented in the paper, the hidden from last round is concat with input and then proceed for other operation. But it seems your LSTM did not use the hidden information from previous round.

    opened by Tom-the-Cat 7
  • Bad images in training

    Bad images in training

    While playing around with the sm5 dataset, I noticed some of them are badly rendered. individualimage Not sure if this will pose any problem for training, just wanted to point this out.

    opened by versatran01 7
  • Question about generator

    Question about generator

    In the top docstring of generator.py, you mentioned that

    The inference-generator architecture is conceptually
    similar to the encoder-decoder pair seen in variational
    autoencoders.
    

    I don't quite understand this part and I would really appreciate if you could explain a bit or point me at some related aritcles. For the generator I can see how it is similar to a decoder, where it takes latent z, query viewpoint v, and aggregated representation r and eventually output the image x_mu.

    But I'm a bit confused by the inference being the conterpart of encoder.

    opened by versatran01 7
  • Loss Change

    Loss Change

    Dear wohlert,

    May I consult you several questions?

    1. I tried to train this network on Mazes Data from https://github.com/deepmind/gqn-datasets. Actually it just contains 5% data, which is around 110000, instead of the full data. Is it right?

    2. I trained 30000 steps, but the elbo loss only converged to 6800 which has a big difference compared to around 7 in the supplementary. So may I ask what is the approximate value do you achieve on the data you used?

    3. From the visualisation based on Question 2, the reconstruction seems to be reasonable. But the sampling results is quite bad. Do you meet the same problem?

    Many thanks, Bing

    opened by BingCS 5
  • Questions on data preparing

    Questions on data preparing

    Hi, Wohlert:

    After the data conversion with your scripts, I visualize some of the images in the *.pt found pictures like this Figure_1-1

    What's wrong with that Also I'm confused about your batch operation , say if you batch the sequences as you convert them, does it mean that you won't batch them again when use dataloader?

    Thanks

    opened by Kyridiculous2 5
  • Training crashes at the same spot for both Shepard Metzler datasets

    Training crashes at the same spot for both Shepard Metzler datasets

    Some context:

    • I downloaded and converted the datasets via data.sh and set batch size to 12. Note that I am using TensorFlow 1.14 for reading the tfrecord files and converting them.
    • I use gpu.sh to run the training script. I set the batch size to either of [1,12,36,72] and DataParallel to True to use 4 GPUs

    But after a shrot time I get the following errors if I use any batch size higher than 1. This happens on iterations 40, 13 and 6 with batch sizes 12, 36 and 72. This happens for both Shepard Metzler datasets. Why I am getting these errors? Does batch size 1 on the training code mean reading one of the .pt.gz files? If so, setting batch size to 1 in the training script should actually mean 12. Would that be correct?

    Here's what I get for the data set with 5 parts when I set batch size to 36 for instance:

    Epoch [1/200]: [13/1856]   1%|▊                                                                                                                       , elbo=-2.1e+4, kl=827, mu=5e-6, sigma=2 [00:21<52:34]Current run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Engine run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Traceback (most recent call last):
      File "../run-gqn.py", line 183, in <module>
        trainer.run(train_loader, args.n_epochs)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 850, in run
        return self._internal_run()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 952, in _internal_run
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 937, in _internal_run
        hours, mins, secs = self._run_once_on_dataset()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 705, in _run_once_on_dataset
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 655, in _run_once_on_dataset
        batch = next(self._dataloader_iter)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 801, in __next__
        return self._process_data(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
        data.reraise()
      File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 385, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    
    opened by Amir-Arsalan 4
  • AttributeError: 'int' object has no attribute 'size'

    AttributeError: 'int' object has no attribute 'size'

    in draw.py, I get this error at the line 118 (batch_size = z.size(0)) Sorry if this is obvious, thanks for help anyway.

    ~ % pip show torch :( Name: torch Version: 1.0.1.post2

    opened by DRM-Free 4
  • Increase dimension of viewpoint and representation

    Increase dimension of viewpoint and representation

    Thanks for this implementation. One question I have is when increasing the dimension of viewpoint and representation, you use torch.repeat. Is there any reason for this? Can one possibly use interpolate?

    In the original paper it says "when concatenating viewpoint v to an image or feature map, its values are ‘broadcast’ in the spatial dimensions to obtain the correct size. "

    The word 'broadcast' is not precisely defined, hence the question.

    opened by versatran01 4
  • Learning rate change

    Learning rate change

    Regarding line 113 of run-gqn.py. Does this change the learning rate of the Adam optimizer? This post shows something different

    https://stackoverflow.com/questions/48324152/pytorch-how-to-change-the-learning-rate-of-an-optimizer-at-any-given-moment-no

    opened by david-bernstein 4
  • Using the rooms data?

    Using the rooms data?

    I wanted to try your code on the rooms data but during conversion, I get these errors. What could I be doing wrong? Note that for the rooms data with moving camera I set the number of camera parameters to 7:

    Traceback (most recent call last):
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
        return list(map(*args))
      File "tfrecord-converter.py", line 66, in convert
        for i, batch in enumerate(batch_process(record)):
      File "tfrecord-converter.py", line 29, in chunk
        for first in iterator:
      File "tfrecord-converter.py", line 40, in process
        'cameras': tf.FixedLenFeature(shape=SEQ_DIM * POSE_DIM, dtype=tf.float32)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1019, in parse_single_example
        serialized, features, example_names, name
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1063, in parse_single_example_v2_unoptimized
        return parse_single_example_v2(serialized, features, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2089, in parse_single_example_v2
        dense_defaults, dense_shapes, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2206, in _parse_single_example_v2_raw
        name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1164, in parse_single_example
        ctx=_ctx)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1260, in parse_single_example_eager_fallback
        attrs=_attrs, ctx=_ctx, name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 67, in quick_execute
        six.raise_from(core._status_to_exception(e.code, message), None)
      File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "tfrecord-converter.py", line 98, in <module>
        pool.map(f, records)
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 266, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
        raise self._value
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    
    opened by Amir-Arsalan 3
Releases(0.1)
Owner
Jesper Wohlert
Jesper Wohlert
VLGrammar: Grounded Grammar Induction of Vision and Language

VLGrammar: Grounded Grammar Induction of Vision and Language

Yining Hong 27 Dec 23, 2022
This is the code for ACL2021 paper A Unified Generative Framework for Aspect-Based Sentiment Analysis

This is the code for ACL2021 paper A Unified Generative Framework for Aspect-Based Sentiment Analysis Install the package in the requirements.txt, the

108 Dec 23, 2022
Application of the L2HMC algorithm to simulations in lattice QCD.

l2hmc-qcd 📊 Slides Recent talk on Training Topological Samplers for Lattice Gauge Theory from the Machine Learning for High Energy Physics, on and of

Sam Foreman 37 Dec 14, 2022
Neural Turing Machine (NTM) & Differentiable Neural Computer (DNC) with pytorch & visdom

Neural Turing Machine (NTM) & Differentiable Neural Computer (DNC) with pytorch & visdom Sample on-line plotting while training(avg loss)/testing(writ

Jingwei Zhang 269 Nov 15, 2022
Dataloader tools for language modelling

Installation: pip install lm_dataloader Design Philosophy A library to unify lm dataloading at large scale Simple interface, any tokenizer can be inte

5 Mar 25, 2022
A python script to convert images to animated sus among us crewmate twerk jifs as seen on r/196

img_sussifier A python script to convert images to animated sus among us crewmate twerk jifs as seen on r/196 Examples How to use install python pip i

41 Sep 30, 2022
[CVPRW 2022] Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network

Attention Helps CNN See Better: Hybrid Image Quality Assessment Network [CVPRW 2022] Code for Hybrid Image Quality Assessment Network [paper] [code] T

IIGROUP 49 Dec 11, 2022
Gluon CV Toolkit

Gluon CV Toolkit | Installation | Documentation | Tutorials | GluonCV provides implementations of the state-of-the-art (SOTA) deep learning models in

Distributed (Deep) Machine Learning Community 5.4k Jan 06, 2023
Official pytorch code for SSAT: A Symmetric Semantic-Aware Transformer Network for Makeup Transfer and Removal

SSAT: A Symmetric Semantic-Aware Transformer Network for Makeup Transfer and Removal This is the official pytorch code for SSAT: A Symmetric Semantic-

ForeverPupil 57 Dec 13, 2022
Pytorch library for end-to-end transformer models training and serving

Pytorch library for end-to-end transformer models training and serving

Mikhail Grankin 768 Jan 01, 2023
Code for "My(o) Armband Leaks Passwords: An EMG and IMU Based Keylogging Side-Channel Attack" paper

Myo Keylogging This is the source code for our paper My(o) Armband Leaks Passwords: An EMG and IMU Based Keylogging Side-Channel Attack by Matthias Ga

Secure Mobile Networking Lab 7 Jan 03, 2023
Prometheus Exporter for data scraped from datenplattform.darmstadt.de

darmstadt-opendata-exporter Scrapes data from https://datenplattform.darmstadt.de and presents it in the Prometheus Exposition format. Pull requests w

Martin Weinelt 2 Apr 12, 2022
FLAVR is a fast, flow-free frame interpolation method capable of single shot multi-frame prediction

FLAVR is a fast, flow-free frame interpolation method capable of single shot multi-frame prediction. It uses a customized encoder decoder architecture with spatio-temporal convolutions and channel ga

Tarun K 280 Dec 23, 2022
Tracking Pipeline helps you to solve the tracking problem more easily

Tracking_Pipeline Tracking_Pipeline helps you to solve the tracking problem more easily I integrate detection algorithms like: Yolov5, Yolov4, YoloX,

VNOpenAI 32 Dec 21, 2022
This repository contains code for the paper "Decoupling Representation and Classifier for Long-Tailed Recognition", published at ICLR 2020

Classifier-Balancing This repository contains code for the paper: Decoupling Representation and Classifier for Long-Tailed Recognition Bingyi Kang, Sa

Facebook Research 820 Dec 26, 2022
A geometric deep learning pipeline for predicting protein interface contacts.

A geometric deep learning pipeline for predicting protein interface contacts.

44 Dec 30, 2022
Pytorch Lightning code guideline for conferences

Deep learning project seed Use this seed to start new deep learning / ML projects. Built in setup.py Built in requirements Examples with MNIST Badges

Pytorch Lightning 1k Jan 02, 2023
Data and extra materials for the food safety publications classifier

Data and extra materials for the food safety publications classifier The subdirectories contain detailed descriptions of their contents in the README.

1 Jan 20, 2022
Generative Art Using Neural Visual Grammars and Dual Encoders

Generative Art Using Neural Visual Grammars and Dual Encoders Arnheim 1 The original algorithm from the paper Generative Art Using Neural Visual Gramm

DeepMind 231 Jan 05, 2023
Playable Video Generation

Playable Video Generation Playable Video Generation Willi Menapace, Stéphane Lathuilière, Sergey Tulyakov, Aliaksandr Siarohin, Elisa Ricci Paper: ArX

Willi Menapace 136 Dec 31, 2022