Second Order Optimization and Curvature Estimation with K-FAC in JAX.

Overview

KFAC-JAX - Second Order Optimization with Approximate Curvature in JAX



KFAC-JAX is a library built on top of JAX for second-order optimization of neural networks and for computing scalable curvature approximations. The main goal of the library is to provide researchers with an easy-to-use implementation of the K-FAC optimizer and curvature estimator.

Installation

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install the latest version of KFAC-JAX directly from GitHub using pip:

$ pip install git+https://github.com/deepmind/kfac-jax

Alternatively, you can install via PyPI:

$ pip install -U kfac-jax

Our examples rely on additional libraries, all of which you can install using:

$ pip install -r requirements_examples.txt

Quickstart

Let's take a look at a simple example of training a neural network, defined using Haiku, with the K-FAC optimizer:

import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax

# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 100


def make_dataset_iterator(batch_size):
  # Dummy dataset; in practice this should be your dataset pipeline
  for _ in range(NUM_BATCHES):
    yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype="int32") 


def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
  """Softmax cross entropy loss."""
  # We assume integer labels
  assert logits.ndim == targets.ndim + 1
  
  # Tell KFAC-JAX this model represents a classifier
  # See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
  kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
  log_p = jax.nn.log_softmax(logits, axis=-1)
  return - jax.vmap(lambda x, y: x[y])(log_p, targets)


def model_fn(x):
  """A Haiku MLP model function - three hidden layer network with tanh."""
  return hk.nets.MLP(
    output_sizes=(50, 50, 50, NUM_CLASSES),
    with_bias=True,
    activation=jax.nn.tanh,
  )(x)


# The Haiku transformed model
hk_model = hk.without_apply_rng(hk.transform(model_fn))


def loss_fn(model_params, model_batch):
  """The loss function to optimize."""
  x, y = model_batch
  logits = hk_model.apply(model_params, x)
  loss = jnp.mean(softmax_cross_entropy(logits, y))
  
  # The optimizer assumes that the function you provide has already added
  # the L2 regularizer to its gradients.
  return loss + L2_REG * kfac_jax.utils.inner_product(model_params, model_params) / 2.0


# Create the optimizer
optimizer = kfac_jax.Optimizer(
  value_and_grad_func=jax.value_and_grad(loss_fn),
  l2_reg=L2_REG,
  value_func_has_aux=False,
  value_func_has_state=False,
  value_func_has_rng=False,
  use_adaptive_learning_rate=True,
  use_adaptive_momentum=True,
  use_adaptive_damping=True,
  initial_damping=1.0,
  multi_device=False,
)

input_dataset = make_dataset_iterator(128)
rng = jax.random.PRNGKey(42)
dummy_images, dummy_labels = next(input_dataset)
rng, key = jax.random.split(rng)
params = hk_model.init(key, dummy_images)
rng, key = jax.random.split(rng)
opt_state = optimizer.init(params, key, (dummy_images, dummy_labels))

# Training loop
for i, batch in enumerate(input_dataset):
  rng, key = jax.random.split(rng)
  params, opt_state, stats = optimizer.step(
      params, opt_state, key, batch=batch, global_step_int=i)
  print(i, stats)

Do not stage (jit or pmap) the optimizer

You should not apply jax.jit or jax.pmap to the call to Optimizer.step. This is already done for you automatically by the optimizer class. To control the staging behaviour of the optimizer set the flag multi_device to True for pmap and to False for jit.

Do not stage (jit or pmap) the loss function

The value_and_grad_func argument provided to the optimizer should compute the loss function value and its gradients. Since the optimizer already stages its step function internally, applying jax.jit to value_and_grad_func is NOT recommended. Importantly, applying jax.pmap is WRONG and most likely will lead to errors.
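
As a minimal sketch of the calling convention described above, the snippet below builds an unstaged value-and-gradient function with plain JAX. The names toy_loss, params and batch are illustrative assumptions and not part of KFAC-JAX; a real loss would also need a loss registration call, which is omitted here.

```python
import jax
import jax.numpy as jnp


def toy_loss(params, batch):
  """Mean squared error of a linear model (hypothetical example loss)."""
  x, y = batch
  predictions = x @ params["w"] + params["b"]
  return jnp.mean((predictions - y) ** 2)


# Left un-jitted and un-pmapped on purpose: the optimizer stages its own step.
value_and_grad_func = jax.value_and_grad(toy_loss)

params = {"w": jnp.ones([3]), "b": jnp.zeros([])}
batch = (jnp.ones([4, 3]), jnp.zeros([4]))

# Returns the scalar loss and gradients with the same structure as params.
loss, grads = value_and_grad_func(params, batch)
```

The gradients come back as a dictionary with the same keys and shapes as params, which is the structure an optimizer step consumes.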

Registering the model loss function

To approximate the curvature matrix of the model correctly, KFAC-JAX needs to know the precise loss function that you want to optimize. You tell it by calling one of the registration functions provided by the library. In the example above this is the call to kfac_jax.register_softmax_cross_entropy_loss, which tells the optimizer that the loss is the standard softmax cross-entropy. If you don't do this, you will get an error when you try to call the optimizer. For all supported loss functions please read the documentation.

Important: The optimizer assumes that the loss is averaged over examples in the minibatch. It is crucial that you follow this convention.
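
To illustrate the averaging convention, here is a small sketch in plain JAX with made-up per-example loss values. It only shows the difference between a mean and a sum reduction; it does not call KFAC-JAX.

```python
import jax.numpy as jnp

# Hypothetical per-example losses for a minibatch of four examples.
per_example_loss = jnp.array([0.5, 1.5, 1.0, 3.0])
batch_size = per_example_loss.shape[0]

# The convention the optimizer assumes: average over the minibatch.
mean_loss = jnp.mean(per_example_loss)

# A summed loss is larger by a factor of batch_size, so it would have to be
# divided by the batch size before being returned from the loss function.
sum_loss = jnp.sum(per_example_loss)
rescaled_loss = sum_loss / batch_size
```

Here rescaled_loss equals mean_loss, while sum_loss alone would change scale whenever the batch size changes.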

Other model function options

Often, you will want to return auxiliary statistics or metrics in addition to the loss value. This can be done directly in value_and_grad_func, in which case we follow the JAX convention and expect the output to be (loss, aux), grads. Similarly, the loss function can take an additional function state (batch norm layers usually have this) or a PRNG key (used in stochastic layers). Each of these must be declared explicitly to the optimizer via its arguments value_func_has_aux, value_func_has_state and value_func_has_rng.
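
The (loss, aux), grads convention can be sketched with plain JAX as follows. The function loss_with_aux and the max_error metric are illustrative assumptions; with a loss shaped like this, the optimizer would presumably be constructed with value_func_has_aux=True.

```python
import jax
import jax.numpy as jnp


def loss_with_aux(params, batch):
  """Returns (loss, aux), where aux is a dict of extra metrics."""
  x, y = batch
  per_example = (x @ params["w"] - y) ** 2
  loss = jnp.mean(per_example)
  aux = {"max_error": jnp.max(per_example)}  # hypothetical metric
  return loss, aux


# has_aux=True tells JAX to differentiate only the first output.
value_and_grad_func = jax.value_and_grad(loss_with_aux, has_aux=True)

params = {"w": jnp.array([1.0, 2.0])}
batch = (jnp.array([[1.0, 0.0], [0.0, 1.0]]), jnp.array([0.0, 0.0]))

(loss, aux), grads = value_and_grad_func(params, batch)
```

Only the loss is differentiated; the auxiliary dictionary is passed through unchanged alongside it.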

Verify optimizer registrations

We strongly encourage you to pay attention to the logging messages produced by the automatic registration system, to ensure that it has correctly understood your model. For the example above, the output looks like this:

==================================================
Graph parameter registrations:
{'mlp/~/linear_0': {'b': 'Auto[dense_with_bias_3]',
                    'w': 'Auto[dense_with_bias_3]'},
 'mlp/~/linear_1': {'b': 'Auto[dense_with_bias_2]',
                    'w': 'Auto[dense_with_bias_2]'},
 'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]',
                    'w': 'Auto[dense_with_bias_1]'},
 'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
                    'w': 'Auto[dense_with_bias_0]'}}
==================================================

As can be seen from this message, the library has correctly detected all parameters of the model to be part of dense layers.

Further reading

For a high level overview of the optimizer, the different curvature approximations, and the supported layers, please see the documentation.

Citing KFAC-JAX

To cite this repository:

@software{kfac-jax2022github,
  author = {Aleksandar Botev and James Martens},
  title = {{KFAC-JAX}},
  url = {http://github.com/deepmind/kfac-jax},
  version = {0.0.1},
  year = {2022},
}

In this bibtex entry, the version number is intended to be from kfac_jax/__init__.py, and the year corresponds to the project's open-source release.

Comments
  • Unpack Error when using KFAC with block-diagonal for Dense networks

    Hi,

    I was trying to get the example code in the readme working with the BlockDiagonal approximation. The default simply uses the normal diagonal. However, when I try to define my optimizer like this:

    opt = kfac_jax.Optimizer(
        value_and_grad_func=jax.value_and_grad(partial(expected_model_likelihood, l2=0.001)),
        l2_reg=0.001,
        use_adaptive_learning_rate=True,
        use_adaptive_damping=True,
        use_adaptive_momentum=True,
        initial_damping=1.0,
        min_damping= 0.0001,
        layer_tag_to_block_ctor={'generic_tag': kfac_jax.DenseTwoKroneckerFactored},  # Specify the approximation type here
        estimation_mode='ggn_curvature_prop',
        multi_device=False
    )
    

    then when I try to use this optimizer I get the following ValueError:

    del pmap_axis_name
    x, = estimation_data["inputs"]
    dy, = estimation_data["outputs_tangent"]
    assert utils.first_dim_is_size(batch_size, x, dy)
    
    ValueError: not enough values to unpack (expected 1, got 0)
    

    Corresponding to the curvature update method in class DenseTwoKroneckerFactored (line 1165) of _src.curvature_blocks.py. The estimation data dictionary is filled with the parameters and parameters-tangents, but I do not understand the codebase sufficiently to grasp why the inputs and outputs_tangent keys are not filled.

    In this way I cannot get the actual KFAC of this repo working... Are there perhaps some examples that make use of the DenseTwoKroneckerFactored? As far as I can tell all provided examples simply make use of the diagonal Fisher for optimization, not KFAC. But I may be wrong of course.

    opened by joeryjoery 4
  • TypeError: 'ShapedArray' object is not iterable

    Hi,

    I tried to run the example code, but the code stops at primal_output = self.bind(*arg_values, **kwargs), and returns the error "TypeError: 'ShapedArray' object is not iterable". Could you please help me to solve this problem? Thanks.

    opened by ltz0120 4
  • How to use kfac to train two probabilistic models jointly?

    In my application, I need to jointly optimize two probabilistic models. They contribute to two different terms in the final loss function.

    I am wondering what would be the recommended pattern of using kfac ?
    More specifically, does it make sense to invoke kfac_jax.register_normal_predictive_distribution twice (for the two probabilistic models respectively) ?

    Thanks in advance!

    opened by wangleiphy 3
  • Correct return type annotation for BlockDiagonalCurvature.params_vector_to_blocks_vectors.

    jax recently added annotations for jax.tree_util and tree_leaves returns a list rather than a tuple.

    opened by copybara-service[bot] 1
  • Correct buffer donation of Optimizer._step.

    Buffers can only be donated if they match the shape and type of the output, which is not true for the rng state or the batch item.

    opened by copybara-service[bot] 1
  • Modularizing the utilities file into a separate sub-package.

    • Modularizing the utilities file into a separate sub-package.
    • Bumping the version of the ci-actions, to remove some deprecation warnings.
    • Bumping chex version.
    opened by copybara-service[bot] 0
  • Improving docstring for optimizer. In particular regarding the damping parameter and LR/momentum/damping adaptation methods.

    • Improving docstring for optimizer. In particular regarding the damping parameter and LR/momentum/damping adaptation methods.
    • Fixing bug in default value of normalization_mode in examples classifier loss.
    opened by copybara-service[bot] 0
  • Adding normalization modes feature to classifier loss.

    • Adding normalization modes feature to classifier loss.
    • Removing unused/pointless return values for registration functions.
    • Improvements to clarity and correctness of docstrings for registration functions.
    • Simplifying batch_size_extractor.
    • Adding white space for improved readability.
    • Fixing _update_cache to account for state_dependent_scale (which is currently unused in the open source release).
    opened by copybara-service[bot] 0
  • Making the estimator finalize itself automatically.

    • Making the estimator finalize itself automatically.
    • Making the optimizer call finalize at the end of init.
    • Removing the need for fake_batch in the optimizer.
    opened by copybara-service[bot] 0
  • Using jnp.int64 for data_seen and step counters to avoid overflow

    • Using jnp.int64 for data_seen and step counters to avoid overflow
    • Using float for epochs instead of int
    • Adding extra arguments to cosine schedule in examples
    opened by copybara-service[bot] 0
  • Correct buffer donation.

    Buffer donation is only valid if the shape and type of an input buffer matches an output. Buffer donation only works with positional arguments, not keyword arguments.

    opened by copybara-service[bot] 1
Releases(v0.0.3)
  • v0.0.3(Sep 23, 2022)

    What's Changed

    • Changing the version in the citation text in the README. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/29
    • Adding attributes for the number of training and evaluation devices. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/31
    • Adding some methods to ImplicitExactCurvature by @copybara-service in https://github.com/deepmind/kfac-jax/pull/32
    • Adding "put_stop_grad_on_loss_factor" argument to 'multiply_fisher_factor'. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/36
    • Making ScaleAndShift blocks being capable of having parameters that are broadcast by construction, e.g. batch norm with scale parameters [1, 1, 1, d]. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/33
    • Changing jax.tree_map -> jax.tree_util.tree_map and related due to recent deprecation. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/37
    • Removed unused precedence argument from GraphPattern. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/38
    • Fix a small bug where we don't check in the jaxpr constvars. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/39
    • Adding an estimator attribute to the optimizer. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/34
    • Updating the docs to correctly refer to update_cache. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/40
    • Compare with slightly less numerical precision. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/41
    • Revamping the graph matching code to be able to detect layers and register tag in arbitrary higher-order Jax primitives. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/42
    • Revising docstring for optimizer class. Now contains missing details about value_and_grad_func. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/43
    • Internal change. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/44
    • Make LossTag to return only the parameter dependent arrays. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/46
    • Improving LossTags to be able to deal correctly with None arguments, by passing in argument names. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/47
    • Minor fix to a bug introduced on previous commit. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/48
    • Correcting issues with docstring for optimizer. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/45
    • Fixing a bug in the graph matcher introduced in a recent CL. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/49
    • Removing unneeded jax.jit in get_mean and get_sum. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/50
    • Adding per-parameter norm stats to optimizer by @copybara-service in https://github.com/deepmind/kfac-jax/pull/51
    • Allowing the pi-adjusted psd inverse to accept diagonal factors. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/55
    • Fixing wrong type annotation of pmap_axis_name. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/56
    • Adding optional offloading of eigh computation to the host because of a bug in CUDA 11.7.0 cuSOLVER library. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/57

    Full Changelog: https://github.com/deepmind/kfac-jax/compare/v0.0.2...v0.0.3

  • v0.0.2(Jun 7, 2022)

    What's Changed

    • Moving .github to top-level directory for CI. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/1
    • Updated documentation for state classes. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/2
    • Changing the name on PyPi to kfac-jax. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/3
    • Making the tracer test in float64. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/4
    • Allowing graph patterns with multiple broadcast to be merged without dangling equations. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/5
    • Adding README for the examples. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/7
    • Changing deprecated tree_multimap to tree_map. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/8
    • Fixing small error introduced due to updates to chex. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/11
    • Fixing typo "drop_reminder" by @copybara-service in https://github.com/deepmind/kfac-jax/pull/13
    • Adding an argument to set the reduction ratio thresholds for automatic damping adjustment. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/12
    • Adding "modifiable_attribute_exceptions" argument to optimizer by @copybara-service in https://github.com/deepmind/kfac-jax/pull/14
    • Changing Imagenet dataset in examples to use a seed for file shuffling to achieve determinism. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/17
    • Small fix to a doc reference bug. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/16
    • Making WeightedMovingAverage to work with arbitrary structures. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/19
    • Minor typos. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/20
    • Correct buffer donation of Optimizer._step. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/21
    • Replacing yield from with direct iteration. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/24
    • Adding stepwise schedule option to examples. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/18
    • Publishing a new version to PyPi. by @copybara-service in https://github.com/deepmind/kfac-jax/pull/28

    New Contributors

    • @copybara-service made their first contribution in https://github.com/deepmind/kfac-jax/pull/1

    Full Changelog: https://github.com/deepmind/kfac-jax/commits/v0.0.2

Owner
DeepMind