Deep Reinforcement Learning with pytorch & visdom


  • Sample tests of trained agents (DQN on Breakout, A3C on Pong, DoubleDQN on CartPole, continuous A3C on InvertedPendulum (MuJoCo))
  • Sample online plotting while training an A3C agent on Pong (with 16 learner processes)

  • Sample log output while training a DQN agent on CartPole (the logging level is currently set to WARNING to suppress the INFO printouts from visdom):

[WARNING ] (MainProcess) <===================================>
[WARNING ] (MainProcess) bash$: python -m visdom.server
[WARNING ] (MainProcess) http://localhost:8097/env/daim_17040900
[WARNING ] (MainProcess) <===================================> DQN
[WARNING ] (MainProcess) <-----------------------------------> Env
[WARNING ] (MainProcess) Creating {gym | CartPole-v0} w/ Seed: 123
[INFO    ] (MainProcess) Making new env: CartPole-v0
[WARNING ] (MainProcess) Action Space: [0, 1]
[WARNING ] (MainProcess) State  Space: 4
[WARNING ] (MainProcess) <-----------------------------------> Model
[WARNING ] (MainProcess) MlpModel (
  (fc1): Linear (4 -> 16)
  (rl1): ReLU ()
  (fc2): Linear (16 -> 16)
  (rl2): ReLU ()
  (fc3): Linear (16 -> 16)
  (rl3): ReLU ()
  (fc4): Linear (16 -> 2)
)
[WARNING ] (MainProcess) No Pretrained Model. Will Train From Scratch.
[WARNING ] (MainProcess) <===================================> Training ...
[WARNING ] (MainProcess) Validation Data @ Step: 501
[WARNING ] (MainProcess) Start  Training @ Step: 501
[WARNING ] (MainProcess) Reporting       @ Step: 2500 | Elapsed Time: 5.32397913933
[WARNING ] (MainProcess) Training Stats:   epsilon:          0.972
[WARNING ] (MainProcess) Training Stats:   total_reward:     2500.0
[WARNING ] (MainProcess) Training Stats:   avg_reward:       21.7391304348
[WARNING ] (MainProcess) Training Stats:   nepisodes:        115
[WARNING ] (MainProcess) Training Stats:   nepisodes_solved: 114
[WARNING ] (MainProcess) Training Stats:   repisodes_solved: 0.991304347826
[WARNING ] (MainProcess) Evaluating      @ Step: 2500
[WARNING ] (MainProcess) Iteration: 2500; v_avg: 1.73136949539
[WARNING ] (MainProcess) Iteration: 2500; tderr_avg: 0.0964358523488
[WARNING ] (MainProcess) Iteration: 2500; steps_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; steps_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; reward_avg: 9.34579439252
[WARNING ] (MainProcess) Iteration: 2500; reward_std: 0.798395631184
[WARNING ] (MainProcess) Iteration: 2500; nepisodes: 107
[WARNING ] (MainProcess) Iteration: 2500; nepisodes_solved: 106
[WARNING ] (MainProcess) Iteration: 2500; repisodes_solved: 0.990654205607
[WARNING ] (MainProcess) Saving Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth ...
[WARNING ] (MainProcess) Saved  Model    @ Step: 2500: /home/zhang/ws/17_ws/pytorch-rl/models/daim_17040900.pth.
[WARNING ] (MainProcess) Resume Training @ Step: 2500
...
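
The WARNING-level logging behind the sample output can be sketched with Python's standard logging module. This is only an illustration, not the repo's actual logger setup: the helper name make_logger, the logger name "demo", and the format string are assumptions chosen to resemble the lines shown in the sample.

```python
import io
import logging


def make_logger(stream, level=logging.WARNING):
    """Return a logger that drops records below `level` (hypothetical helper)."""
    logger = logging.getLogger("demo")  # assumed name, not taken from the repo
    logger.handlers.clear()             # avoid duplicate handlers if called twice
    logger.setLevel(level)
    logger.propagate = False            # keep records out of the root logger
    handler = logging.StreamHandler(stream)
    # Assumed format, chosen to resemble "[WARNING ] (MainProcess) ..."
    handler.setFormatter(
        logging.Formatter("[%(levelname)-8s] (%(processName)s) %(message)s")
    )
    logger.addHandler(handler)
    return logger


buffer = io.StringIO()
logger = make_logger(buffer)
logger.info("this INFO record is below WARNING and is dropped")
logger.warning("Action Space: [0, 1]")  # at WARNING level, so it is written
output = buffer.getvalue()
print(output, end="")
```

Raising the level filters only the records that pass through this logger; other libraries may still print through their own loggers.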

What is included?

This repo currently contains the following agents:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Dueling network DQN (Dueling DQN) [4]
  • Asynchronous Advantage Actor-Critic (A3C) (w/ both discrete/continuous action space support) [5], [6]
  • Sample Efficient Actor-Critic with Experience Replay (ACER) (currently w/ discrete action space support: Truncated Importance Sampling, 1st Order TRPO) [7], [8]

Work in progress:

  • Testing ACER

Future Plans:

  • Deep Deterministic Policy Gradient (DDPG) [9], [10]
  • Continuous DQN (CDQN or NAF) [11]

Code structure & Naming conventions:

NOTE: we follow exactly the same code structure as pytorch-dnc, so that code can be ported easily between the two repos.

  • ./utils/factory.py

We suggest that users refer to ./utils/factory.py, where all the integrated Env, Model, Memory, and Agent classes are listed in Dicts. All four of these core classes are implemented in ./core/. The factory pattern in ./utils/factory.py keeps the code clean: whichever type of Agent you want to train, and whichever type of Env you want to train on, you only need to modify some parameters in ./utils/options.py, and ./main.py will do the rest (NOTE: ./main.py never needs to be modified).

  • naming conventions

To make the code cleaner and more readable, we name variables using the following pattern (mainly in inherited Agents):
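
The factory pattern described here can be sketched as follows. This is a minimal illustration, not the contents of ./utils/factory.py: apart from MlpModel, which appears in the sample log, every class name, dict name, and key below is a hypothetical placeholder, and build_agent is an assumed helper.

```python
# Placeholder classes standing in for the real ones under ./core/
# (all names except MlpModel are hypothetical).
class GymEnv:
    def __init__(self, game):
        self.game = game


class MlpModel:
    pass


class SequentialMemory:
    pass


class DQNAgent:
    def __init__(self, env, model, memory):
        self.env, self.model, self.memory = env, model, memory


# Factory dicts: string keys from the configuration map to classes.
EnvDict = {"gym": GymEnv}
ModelDict = {"mlp": MlpModel}
MemoryDict = {"sequential": SequentialMemory}
AgentDict = {"dqn": DQNAgent}


def build_agent(agent_type, env_type, game, model_type, memory_type):
    """Look up each component by its key and wire the pieces together."""
    env = EnvDict[env_type](game)
    model = ModelDict[model_type]()
    memory = MemoryDict[memory_type]()
    return AgentDict[agent_type](env, model, memory)


agent = build_agent("dqn", "gym", "CartPole-v0", "mlp", "sequential")
print(type(agent).__name__, agent.env.game)
```

With this layout, supporting a new agent or environment means adding one class and one dict entry, while the entry point that calls build_agent stays unchanged.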

  • *_vb: torch.autograd.Variable's or a list of such objects
  • *_ts: torch.Tensor's or a list of such objects
  • otherwise: normal python datatypes

Dependencies


How to run:

You only need to modify some parameters in ./utils/options.py to train a new configuration.

  • Configure your training in ./utils/options.py:
      • line 14: add an entry to CONFIGS to define your training (agent_type, env_type, game, model_type, memory_type)
      • line 33: choose the entry you just added
      • lines 29-30: fill in your machine/cluster ID (MACHINE) and timestamp (TIMESTAMP) to define your training signature (MACHINE_TIMESTAMP). The model file and the log file of this training are saved under this signature (./models/MACHINE_TIMESTAMP.pth and ./logs/MACHINE_TIMESTAMP.log, respectively). The visdom visualization is also displayed under this signature: first start the visdom server by typing python -m visdom.server & in bash, then open http://localhost:8097/env/MACHINE_TIMESTAMP in your browser.
      • line 32: to train a model, set mode=1 (the training visualization will be at http://localhost:8097/env/MACHINE_TIMESTAMP); to test the model from the current training, set mode=2 (the testing visualization will be at http://localhost:8097/env/MACHINE_TIMESTAMP_test).
  • Run:

python main.py
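
How the training signature maps to the saved files and visdom addresses can be sketched as below. The function names training_signature and artifacts are hypothetical and not part of the repo; the path and URL templates follow the description above, and the example values daim and 17040900 are taken from the sample log.

```python
def training_signature(machine, timestamp):
    """Join MACHINE and TIMESTAMP into the signature MACHINE_TIMESTAMP."""
    return f"{machine}_{timestamp}"


def artifacts(machine, timestamp, mode, port=8097):
    """Return the model path, log path, and visdom URL for a run.

    mode=1 is training and mode=2 is testing, as described above.
    """
    if mode not in (1, 2):
        raise ValueError("mode must be 1 (train) or 2 (test)")
    signature = training_signature(machine, timestamp)
    # Testing is visualized under a separate visdom env with a "_test" suffix.
    visdom_env = signature if mode == 1 else signature + "_test"
    return {
        "model": f"./models/{signature}.pth",
        "log": f"./logs/{signature}.log",
        "visdom": f"http://localhost:{port}/env/{visdom_env}",
    }


train = artifacts("daim", "17040900", mode=1)
test = artifacts("daim", "17040900", mode=2)
for name, value in train.items():
    print(f"{name}: {value}")
```

Both modes share the same model and log paths, which is why testing with mode=2 picks up the model saved by the training run with the same signature.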


Bonus Scripts :)

We also provide 2 additional scripts for quickly evaluating your results after training. (Dependencies: lmj-plot)

  • plot.sh (e.g., plot from log file: logs/machine1_17080801.log)
      • ./plot.sh machine1 17080801
      • the generated figures will be saved into figs/machine1_17080801/
  • plot_compare.sh (e.g., compare log files: logs/machine1_17080801.log, logs/machine2_17080802.log)
      • ./plot_compare.sh 00 machine1 17080801 machine2 17080802
      • the generated figures will be saved into figs/compare_00/
      • the color coding will be in the order of: red green blue magenta yellow cyan

Repos we referred to during the development of this repo:


Citation

If you find this library useful and would like to cite it, the following would be appropriate:

@misc{pytorch-rl,
  author = {Zhang, Jingwei and Tai, Lei},
  title = {jingweiz/pytorch-rl},
  url = {https://github.com/jingweiz/pytorch-rl},
  year = {2017}
}