Deep Reinforcement Learning for Keras.

Join the chat at https://gitter.im/keras-rl/Lobby

What is it?

keras-rl implements some state-of-the-art deep reinforcement learning algorithms in Python and seamlessly integrates with the deep learning library Keras.

Furthermore, keras-rl works with OpenAI Gym out of the box. This means that evaluating and playing around with different algorithms is easy.
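keras-rl drives environments through the classic OpenAI Gym interface: reset() returns an initial observation, and step(action) returns an (observation, reward, done, info) tuple. The following is a minimal, self-contained sketch of that contract. CoinFlipEnv and run_episode are hypothetical names invented for this illustration, and the sketch does not need Gym installed.

```python
import random


class CoinFlipEnv:
    """Hypothetical toy environment following the classic Gym contract.

    reset() returns an initial observation; step(action) returns
    (observation, reward, done, info). The agent guesses a coin flip
    (action 0 or 1) and receives reward 1.0 for a correct guess.
    """

    def __init__(self, episode_length=10, seed=None):
        self.episode_length = episode_length
        self.rng = random.Random(seed)  # seeded for reproducible episodes
        self.t = 0

    def reset(self):
        self.t = 0
        return 0  # constant dummy observation

    def step(self, action):
        coin = self.rng.randint(0, 1)
        reward = 1.0 if action == coin else 0.0
        self.t += 1
        done = self.t >= self.episode_length
        return 0, reward, done, {"coin": coin}


def run_episode(env, policy):
    """Roll out one episode and return the total (undiscounted) reward."""
    observation = env.reset()
    total_reward, done = 0.0, False
    while not done:
        action = policy(observation)
        observation, reward, done, info = env.step(action)
        total_reward += reward
    return total_reward


if __name__ == "__main__":
    def always_heads(observation):
        return 0

    print(run_episode(CoinFlipEnv(episode_length=10, seed=0), always_heads))
```

Newer Gym and Gymnasium releases split done into terminated and truncated and return five values from step(), so check which API your installed version uses.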

Of course, you can extend keras-rl to suit your own needs. You can use built-in Keras callbacks and metrics or define your own. You can also implement your own environments and even algorithms by extending a few simple abstract classes. Documentation is available online.
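Custom callbacks are passed to the agent through the callbacks argument of fit(). The following is a minimal sketch that assumes keras-rl's callback hook on_episode_end(episode, logs) and an 'episode_reward' entry in its logs. EpisodeRewardTracker is a hypothetical name. In keras-rl the class would subclass rl.callbacks.Callback; here it stands alone so that it runs without Keras installed.

```python
class EpisodeRewardTracker:
    """Collects per-episode rewards through an on_episode_end hook.

    Hypothetical example: in keras-rl this would subclass
    rl.callbacks.Callback; it is standalone here so it runs anywhere.
    """

    def __init__(self, window=3):
        self.window = window
        self.rewards = []

    def on_episode_end(self, episode, logs=None):
        # Assumes the logs dict carries the episode's total reward
        # under the 'episode_reward' key.
        logs = logs or {}
        self.rewards.append(logs.get("episode_reward", 0.0))

    def moving_average(self):
        """Mean reward over the most recent `window` episodes."""
        recent = self.rewards[-self.window:]
        return sum(recent) / len(recent) if recent else 0.0


# Simulate four finished episodes, as a training loop would report them.
tracker = EpisodeRewardTracker(window=3)
for episode, reward in enumerate([10.0, 20.0, 30.0, 40.0]):
    tracker.on_episode_end(episode, {"episode_reward": reward})
print(tracker.moving_average())  # 30.0
```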

What is included?

As of today, the following algorithms have been implemented:

  • Deep Q Learning (DQN) [1], [2]
  • Double DQN [3]
  • Deep Deterministic Policy Gradient (DDPG) [4]
  • Continuous DQN (CDQN or NAF) [6]
  • Cross-Entropy Method (CEM) [7], [8]
  • Dueling network DQN (Dueling DQN) [9]
  • Deep SARSA [10]
  • Asynchronous Advantage Actor-Critic (A3C) [5]
  • Proximal Policy Optimization Algorithms (PPO) [11]

You can find more information on each agent in the documentation.
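Several of the value-based agents listed above differ mainly in how they form the bootstrap target for the Q-value update. The following is a plain-Python illustration of the textbook one-step targets for DQN with a target network [2], Double DQN [3] and SARSA [10]. It is not keras-rl's implementation, which operates on batches of Keras model outputs. The function and parameter names are hypothetical.

```python
def argmax(values):
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def dqn_target(reward, done, q_next_target, gamma=0.99):
    """DQN: bootstrap from the max of the target network's Q-values."""
    if done:
        return reward  # no bootstrapping past a terminal state
    return reward + gamma * max(q_next_target)


def double_dqn_target(reward, done, q_next_online, q_next_target, gamma=0.99):
    """Double DQN: the online net picks the action, the target net scores it."""
    if done:
        return reward
    best_action = argmax(q_next_online)
    return reward + gamma * q_next_target[best_action]


def sarsa_target(reward, done, q_next, next_action, gamma=0.99):
    """SARSA: bootstrap from the action the policy actually takes next."""
    if done:
        return reward
    return reward + gamma * q_next[next_action]


q_online = [1.0, 3.0, 2.0]  # Q(s', .) from the online network
q_target = [4.0, 0.5, 1.0]  # Q(s', .) from the target network
print(dqn_target(1.0, False, q_target, gamma=0.5))                   # 3.0
print(double_dqn_target(1.0, False, q_online, q_target, gamma=0.5))  # 1.25
print(sarsa_target(1.0, False, q_online, next_action=2, gamma=0.5))  # 2.0
```

The toy values show why Double DQN was introduced. Plain DQN takes the max of the target network's estimates and picks 4.0. Double DQN decouples action selection from evaluation, which reduces that optimistic bias.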

Installation

  • Install keras-rl from PyPI (recommended):
pip install keras-rl
  • Install from the GitHub source:
git clone https://github.com/keras-rl/keras-rl.git
cd keras-rl
python setup.py install

Examples

If you want to run the examples, you'll also have to install OpenAI Gym (pip install gym).

For the Atari example you will also need:

  • Pillow: pip install Pillow
  • gym[atari]: the Atari module for Gym. Use pip install gym[atari]

Once you have installed everything, you can try out a simple example:

python examples/dqn_cartpole.py

This is a very simple example and it should converge relatively quickly, so it's a great way to get started! It also visualizes the game during training, so you can watch it learn. How cool is that?

Some sample weights are available on keras-rl-weights.

If you have questions or problems, please file an issue or, even better, fix the problem yourself and submit a pull request!

External Projects

You're using Keras-RL on a project? Open a PR and share it!

Visualizing Training Metrics

To see graphs of your training progress and compare across runs, run pip install wandb and add the WandbLogger callback to your agent's fit() call:

from rl.callbacks import WandbLogger

...

agent.fit(env, nb_steps=50000, callbacks=[WandbLogger()])

For more info and options, see the W&B docs.

Citing

If you use keras-rl in your research, you can cite it as follows:

@misc{plappert2016kerasrl,
    author = {Matthias Plappert},
    title = {keras-rl},
    year = {2016},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/keras-rl/keras-rl}},
}

References

  1. Playing Atari with Deep Reinforcement Learning, Mnih et al., 2013
  2. Human-level control through deep reinforcement learning, Mnih et al., 2015
  3. Deep Reinforcement Learning with Double Q-learning, van Hasselt et al., 2015
  4. Continuous control with deep reinforcement learning, Lillicrap et al., 2015
  5. Asynchronous Methods for Deep Reinforcement Learning, Mnih et al., 2016
  6. Continuous Deep Q-Learning with Model-based Acceleration, Gu et al., 2016
  7. Learning Tetris Using the Noisy Cross-Entropy Method, Szita et al., 2006
  8. Deep Reinforcement Learning (MLSS lecture notes), Schulman, 2016
  9. Dueling Network Architectures for Deep Reinforcement Learning, Wang et al., 2016
  10. Reinforcement learning: An introduction, Sutton and Barto, 2011
  11. Proximal Policy Optimization Algorithms, Schulman et al., 2017