Neural Turing Machines (NTM) - PyTorch Implementation

Overview

PyTorch Neural Turing Machine (NTM)

PyTorch implementation of Neural Turing Machines (NTM).

An NTM is a memory augumented neural network (attached to external memory) where the interactions with the external memory (address, read, write) are done using differentiable transformations. Overall, the network is end-to-end differentiable and thus trainable by a gradient based optimizer.

The NTM is processing input in sequences, much like an LSTM, but with additional benfits: (1) The external memory allows the network to learn algorithmic tasks easier (2) Having larger capacity, without increasing the network's trainable parameters.

The external memory allows the NTM to learn algorithmic tasks, that are much harder for LSTM to learn, and to maintain an internal state much longer than traditional LSTMs.

A PyTorch Implementation

This repository implements a vanilla NTM in a straight forward way. The following architecture is used:

NTM Architecture

Features

  • Batch learning support
  • Numerically stable
  • Flexible head configuration - use X read heads and Y write heads and specify the order of operation
  • copy and repeat-copy experiments agree with the paper

Copy Task

The Copy task tests the NTM's ability to store and recall a long sequence of arbitrary information. The input to the network is a random sequence of bits, ending with a delimiter. The sequence lengths are randomised between 1 to 20.

Training

Training convergence for the copy task using 4 different seeds (see the notebook for details)

NTM Convergence

The following plot shows the cost per sequence length during training. The network was trained with seed=10 and shows fast convergence. Other seeds may not perform as well but should converge in less than 30K iterations.

NTM Convergence

Evaluation

Here is an animated GIF that shows how the model generalize. The model was evaluated after every 500 training samples, using the target sequence shown in the upper part of the image. The bottom part shows the network output at any given training stage.

Copy Task

The following is the same, but with sequence length = 80. Note that the network was trained with sequences of lengths 1 to 20.

Copy Task


Repeat Copy Task

The Repeat Copy task tests whether the NTM can learn a simple nested function, and invoke it by learning to execute a for loop. The input to the network is a random sequence of bits, followed by a delimiter and a scalar value that represents the number of repetitions to output. The number of repetitions, was normalized to have zero mean and variance of one (as in the paper). Both the length of the sequence and the number of repetitions are randomised between 1 to 10.

Training

Training convergence for the repeat-copy task using 4 different seeds (see the notebook for details)

NTM Convergence

Evaluation

The following image shows the input presented to the network, a sequence of bits + delimiter + num-reps scalar. Specifically the sequence length here is eight and the number of repetitions is five.

Repeat Copy Task

And here's the output the network had predicted:

Repeat Copy Task

Here's an animated GIF that shows how the network learns to predict the targets. Specifically, the network was evaluated in each checkpoint saved during training with the same input sequence.

Repeat Copy Task

Installation

The NTM can be used as a reusable module, currently not packaged though.

  1. Clone repository
  2. Install PyTorch
  3. pip install -r requirements.txt

Usage

Execute ./train.py

usage: train.py [-h] [--seed SEED] [--task {copy,repeat-copy}] [-p PARAM]
                [--checkpoint-interval CHECKPOINT_INTERVAL]
                [--checkpoint-path CHECKPOINT_PATH]
                [--report-interval REPORT_INTERVAL]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED           Seed value for RNGs
  --task {copy,repeat-copy}
                        Choose the task to train (default: copy)
  -p PARAM, --param PARAM
                        Override model params. Example: "-pbatch_size=4
                        -pnum_heads=2"
  --checkpoint-interval CHECKPOINT_INTERVAL
                        Checkpoint interval (default: 1000). Use 0 to disable
                        checkpointing
  --checkpoint-path CHECKPOINT_PATH
                        Path for saving checkpoint data (default: './')
  --report-interval REPORT_INTERVAL
                        Reporting interval
Owner
Guy Zana
I make things, author of Curated Papers
Guy Zana
Learning multiple gaits of quadruped robot using hierarchical reinforcement learning

Learning multiple gaits of quadruped robot using hierarchical reinforcement learning We propose a method to learn multiple gaits of quadruped robot us

Yunho Kim 17 Dec 11, 2022
Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics.

Bonnet: An Open-Source Training and Deployment Framework for Semantic Segmentation in Robotics. By Andres Milioto @ University of Bonn. (for the new P

Photogrammetry & Robotics Bonn 314 Dec 30, 2022
Editing a classifier by rewriting its prediction rules

This repository contains the code and data for our paper: Editing a classifier by rewriting its prediction rules Shibani Santurkar*, Dimitris Tsipras*

Madry Lab 86 Dec 27, 2022
Implementation for Homogeneous Unbalanced Regularized Optimal Transport

HUROT: An Homogeneous formulation of Unbalanced Regularized Optimal Transport. This repository provides code related to this preprint. This is an alph

Théo Lacombe 1 Feb 17, 2022
CVNets: A library for training computer vision networks

CVNets: A library for training computer vision networks This repository contains the source code for training computer vision models. Specifically, it

Apple 1.1k Jan 03, 2023
This repository contains the source code for the paper "DONeRF: Towards Real-Time Rendering of Compact Neural Radiance Fields using Depth Oracle Networks",

DONeRF: Towards Real-Time Rendering of Compact Neural Radiance Fields using Depth Oracle Networks Project Page | Video | Presentation | Paper | Data L

Facebook Research 281 Dec 22, 2022
(ICCV'21) Official PyTorch implementation of Relational Embedding for Few-Shot Classification

Relational Embedding for Few-Shot Classification (ICCV 2021) Dahyun Kang, Heeseung Kwon, Juhong Min, Minsu Cho [paper], [project hompage] We propose t

Dahyun Kang 82 Dec 24, 2022
FS-Mol: A Few-Shot Learning Dataset of Molecules

FS-Mol is A Few-Shot Learning Dataset of Molecules, containing molecular compounds with measurements of activity against a variety of protein targets. The dataset is presented with a model evaluation

Microsoft 114 Dec 15, 2022
Fast Differentiable Matrix Sqrt Root

Fast Differentiable Matrix Sqrt Root Geometric Interpretation of Matrix Square Root and Inverse Square Root This repository constains the official Pyt

YueSong 42 Dec 30, 2022
Image morphing without reference points by applying warp maps and optimizing over them.

Differentiable Morphing Image morphing without reference points by applying warp maps and optimizing over them. Differentiable Morphing is machine lea

Alex K 380 Dec 19, 2022
FaceOcc: A Diverse, High-quality Face Occlusion Dataset for Human Face Extraction

FaceExtraction FaceOcc: A Diverse, High-quality Face Occlusion Dataset for Human Face Extraction Occlusions often occur in face images in the wild, tr

16 Dec 14, 2022
An official repository for Paper "Uformer: A General U-Shaped Transformer for Image Restoration".

Uformer: A General U-Shaped Transformer for Image Restoration Zhendong Wang, Xiaodong Cun, Jianmin Bao and Jianzhuang Liu Paper: https://arxiv.org/abs

Zhendong Wang 497 Dec 22, 2022
Code for Temporally Abstract Partial Models

Code for Temporally Abstract Partial Models Accompanies the code for the experimental section of the paper: Temporally Abstract Partial Models, Khetar

DeepMind 19 Jul 13, 2022
Model of an AI powered sign language interpreter.

TEXT AND SPEECH TO SIGN LANGUAGE. A web application which takes in text or live audio speech recording as input, converts and displays the relevant Si

Mark Gatere 4 Mar 30, 2022
Code to go with the paper "Decentralized Bayesian Learning with Metropolis-Adjusted Hamiltonian Monte Carlo"

dblmahmc Code to go with the paper "Decentralized Bayesian Learning with Metropolis-Adjusted Hamiltonian Monte Carlo" Requirements: https://github.com

1 Dec 17, 2021
curl-impersonate: A special compilation of curl that makes it impersonate Chrome & Firefox

curl-impersonate A special compilation of curl that makes it impersonate real browsers. It can impersonate the four major browsers: Chrome, Edge, Safa

lwthiker 1.9k Jan 03, 2023
Repository for the Bias Benchmark for QA dataset.

BBQ Repository for the Bias Benchmark for QA dataset. Authors: Alicia Parrish, Angelica Chen, Nikita Nangia, Vishakh Padmakumar, Jason Phang, Jana Tho

ML² AT CILVR 18 Nov 18, 2022
PyGCL: Graph Contrastive Learning Library for PyTorch

PyGCL: Graph Contrastive Learning for PyTorch PyGCL is an open-source library for graph contrastive learning (GCL), which features modularized GCL com

GCL: Graph Contrastive Learning Library for PyTorch 594 Jan 08, 2023
Safe Bayesian Optimization

SafeOpt - Safe Bayesian Optimization This code implements an adapted version of the safe, Bayesian optimization algorithm, SafeOpt [1], [2]. It also p

Felix Berkenkamp 111 Dec 11, 2022
This code is a near-infrared spectrum modeling method based on PCA and pls

Nirs-Pls-Corn This code is a near-infrared spectrum modeling method based on PCA and pls 近红外光谱分析技术属于交叉领域,需要化学、计算机科学、生物科学等多领域的合作。为此,在(北邮邮电大学杨辉华老师团队)指导下

Fu Pengyou 6 Dec 17, 2022