Prototypical Networks for Few-shot Learning in PyTorch

A simple alternative implementation of Prototypical Networks for Few-shot Learning (paper, code) in PyTorch.

Prototypical Networks

As shown in the reference paper, Prototypical Networks are trained to embed sample features in a vector space. At each episode (iteration), a number of samples for a subset of classes are selected and sent through the model. For each class c in the subset, the features of n_support samples are used to estimate the prototype for that class (their barycentre in the vector space). The distances between the remaining n_query samples and their class barycentre can then be minimized.
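For reference, the quantities described above can be written as follows. This is a summary in the notation of the reference paper, not something taken from this repo: S_k is the support set of class k, f_φ is the embedding model, and d is a distance function (squared Euclidean distance in the paper).

```latex
\begin{align}
c_k &= \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\phi(x_i) \\
p_\phi(y = k \mid x) &= \frac{\exp\bigl(-d(f_\phi(x), c_k)\bigr)}{\sum_{k'} \exp\bigl(-d(f_\phi(x), c_{k'})\bigr)} \\
J(\phi) &= -\log p_\phi(y = k \mid x)
\end{align}
```

The first line is the class prototype (barycentre), the second is the probability that a query sample x belongs to class k, and the third is the loss for a query sample whose true class is k.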

t-SNE

After training, you can compute the t-SNE for the features generated by the model (not done in this repo; more info about t-SNE here). This is a sample, as shown in the paper:

[Figure: t-SNE visualization from the reference paper]

Omniglot Dataset

Kudos to @ludc for his contribution: https://github.com/pytorch/vision/pull/46. We will use the official dataset once it is added to torchvision, provided that doesn't require big changes to the code.

Dataset splits

We implemented the Vinyals splitting method as in [Matching Networks for One Shot Learning]. That should be the same method used in the paper (in fact, I downloaded the split files from the "official" repo). We then apply the same rotations described there. In this way, we should be able to compare the results obtained by running this code with the results described in the reference paper.

Prototypical Batch Sampler

As described in its PyDoc, this class is used to generate the indexes of each batch for a prototypical training algorithm.

In particular, the object is instantiated by passing the list of labels for the dataset. The sampler then infers the total number of classes and creates a set of indexes for each class in the dataset. At each episode, the sampler selects n_classes random classes and returns (n_support + n_query) sample indexes for each of the selected classes.
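The sampling procedure described above can be sketched in plain Python. This is a minimal illustration, not the repo's class: the name EpisodeSampler, its constructor arguments, and the toy labels are all hypothetical, and it uses only the standard library rather than PyTorch.

```python
import random
from collections import defaultdict


class EpisodeSampler:
    """Yield one list of dataset indexes per episode (hypothetical helper)."""

    def __init__(self, labels, n_classes, n_samples, iterations, seed=None):
        self.n_classes = n_classes      # classes drawn per episode
        self.n_samples = n_samples      # n_support + n_query per class
        self.iterations = iterations    # episodes per epoch
        self.rng = random.Random(seed)
        # Group the dataset indexes by class label.
        self.indexes_by_class = defaultdict(list)
        for index, label in enumerate(labels):
            self.indexes_by_class[label].append(index)

    def __len__(self):
        return self.iterations

    def __iter__(self):
        classes = sorted(self.indexes_by_class)
        for _ in range(self.iterations):
            batch = []
            # Pick the episode classes, then sample indexes within each class.
            for c in self.rng.sample(classes, self.n_classes):
                batch.extend(self.rng.sample(self.indexes_by_class[c], self.n_samples))
            self.rng.shuffle(batch)
            yield batch


# Toy dataset: 10 classes with 20 samples each.
labels = [i // 20 for i in range(200)]
sampler = EpisodeSampler(labels, n_classes=5, n_samples=5 + 5, iterations=3, seed=7)
episodes = list(sampler)
```

Each element of episodes is a list of 50 indexes (5 classes, 10 samples per class). An object of this shape, one that yields a list of indexes per iteration, is the kind of thing a PyTorch DataLoader accepts through its batch_sampler argument.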

Prototypical Loss

Compute the loss as in the cited paper, mostly inspired by this code by one of its authors.

In prototypical_loss.py, both a loss function and a PyTorch-style loss class are implemented.

The function takes as input the batch output from the model, the samples' ground truths, and the number n_support of samples to be used as support samples. The episode classes are inferred from the target list, and n_support samples are randomly extracted for each class. Their class barycentres are then computed, followed by the distance of each remaining sample's embedding from each class barycentre, and finally the probability of each sample belonging to each episode class. The loss is then computed from the probabilities of the wrong predictions (for the query samples), as usual in classification problems.
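The computation described above can be sketched in plain Python, without PyTorch. This is an illustration under stated assumptions, not the code in prototypical_loss.py: the function and helper names are hypothetical, embeddings are plain lists of floats, squared Euclidean distance is assumed (as in the reference paper), and for simplicity the first n_support samples of each class are used as support instead of a random selection.

```python
import math


def squared_euclidean(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def prototypical_loss(embeddings, targets, n_support):
    """Return (mean loss, accuracy) over the query samples of one episode."""
    classes = sorted(set(targets))  # episode classes inferred from the targets
    prototypes, queries = {}, []
    for c in classes:
        members = [e for e, t in zip(embeddings, targets) if t == c]
        # Simplification: the first n_support samples act as support.
        support, query = members[:n_support], members[n_support:]
        dim = len(support[0])
        # Prototype = barycentre of the support embeddings.
        prototypes[c] = [sum(e[d] for e in support) / len(support) for d in range(dim)]
        queries.extend((q, c) for q in query)

    total_loss, correct = 0.0, 0
    for q, true_class in queries:
        # Logits are negative distances to each prototype.
        logits = {c: -squared_euclidean(q, prototypes[c]) for c in classes}
        # Log-softmax, subtracting the maximum for numerical stability.
        m = max(logits.values())
        log_norm = m + math.log(sum(math.exp(v - m) for v in logits.values()))
        total_loss += -(logits[true_class] - log_norm)
        correct += max(logits, key=logits.get) == true_class
    return total_loss / len(queries), correct / len(queries)


# Toy episode: two well-separated classes, three samples each.
embeddings = [[0.0, 0.0], [0.2, 0.0], [0.1, 0.1], [4.0, 4.0], [4.2, 4.0], [4.1, 3.9]]
targets = [0, 0, 0, 1, 1, 1]
loss, accuracy = prototypical_loss(embeddings, targets, n_support=2)
```

With well-separated classes like these, every query sample sits next to its own prototype, so the accuracy is 1.0 and the loss is close to zero.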

Training

Please note that the training code is here just for demonstration purposes.

To train the Protonet on this task, cd into this repo's src root folder and execute:

$ python train.py

The script takes the following command line options:

  • dataset_root: the root directory where the dataset is stored, defaults to '../dataset'

  • nepochs: number of epochs to train for, defaults to 100

  • learning_rate: learning rate for the model, defaults to 0.001

  • lr_scheduler_step: StepLR learning rate scheduler step, defaults to 20

  • lr_scheduler_gamma: StepLR learning rate scheduler gamma, defaults to 0.5

  • iterations: number of episodes per epoch, defaults to 100

  • classes_per_it_tr: number of random classes per episode for training, defaults to 60

  • num_support_tr: number of samples per class to use as support for training, defaults to 5

  • num_query_tr: number of samples per class to use as query for training, defaults to 5

  • classes_per_it_val: number of random classes per episode for validation, defaults to 5

  • num_support_val: number of samples per class to use as support for validation, defaults to 5

  • num_query_val: number of samples per class to use as query for validation, defaults to 15

  • manual_seed: input for the manual seed initializations, defaults to 7

  • cuda: enables CUDA (store_true flag)

Running the command without arguments will train the model with the default hyperparameter values (producing the default-parameter result reported under Performances).
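As a rough worked example, here is what the default values imply for the learning rate schedule and the episode sizes. It assumes the scheduler is stepped once per epoch and follows the usual step decay rule (initial rate times gamma raised to the number of completed steps); the helper names are hypothetical and the snippet does not use the repo's code.

```python
# Defaults copied from the option list above.
learning_rate = 0.001
lr_scheduler_step = 20
lr_scheduler_gamma = 0.5
nepochs = 100


def step_lr(epoch):
    """Learning rate during `epoch` (0-based) under a step decay schedule."""
    return learning_rate * lr_scheduler_gamma ** (epoch // lr_scheduler_step)


# Learning rate at the start of each 20-epoch block.
schedule = {epoch: step_lr(epoch) for epoch in range(0, nepochs, lr_scheduler_step)}


def episode_size(classes_per_it, num_support, num_query):
    """Number of samples in one episode."""
    return classes_per_it * (num_support + num_query)


train_episode = episode_size(60, 5, 5)   # 60 classes x (5 support + 5 query)
val_episode = episode_size(5, 5, 15)     # 5 classes x (5 support + 15 query)
```

Under these assumptions the learning rate halves every 20 epochs, from 0.001 down to 0.0000625 for the last 20 epochs. A training episode contains 600 samples and a validation episode contains 100.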

Performances

We are trying to reproduce the reference paper's performance; we'll update our best results here.

Model            | 1-shot (5-way Acc.) | 5-shot (5-way Acc.) | 1-shot (20-way Acc.) | 5-shot (20-way Acc.)
Reference Paper  | 98.8%               | 99.7%               | 96.0%                | 98.9%
This repo        | 98.5%**             | 99.6%*              | 95.1%°               | 98.6%°°

* achieved using default parameters (using --cuda option)

** achieved running python train.py --cuda -nsTr 1 -nsVa 1

° achieved running python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20

°° achieved running python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20

Helpful links

.bib citation

Cite the paper as follows (copied from arXiv for you):

@article{DBLP:journals/corr/SnellSZ17,
  author    = {Jake Snell and
               Kevin Swersky and
               Richard S. Zemel},
  title     = {Prototypical Networks for Few-shot Learning},
  journal   = {CoRR},
  volume    = {abs/1703.05175},
  year      = {2017},
  url       = {http://arxiv.org/abs/1703.05175},
  archivePrefix = {arXiv},
  eprint    = {1703.05175},
  timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
  biburl    = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
  bibsource = {dblp computer science bibliography, http://dblp.org}
}

License

This project is licensed under the MIT License

Copyright (c) 2018 Daniele E. Ciriello, Orobix Srl (www.orobix.com).
