Geometric Algebra package for JAX

Overview

JAXGA - JAX Geometric Algebra

Build status PyPI

GitHub | Docs

JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing only the non-zero basis blade coefficients. It makes use of JAX's just-in-time (JIT) compilation by first precomputing blade indices and signs and then JITting the function doing the actual calculations.

Installation

Install using pip: pip install jaxga

Requirements:

Usage

Unlike most other Geometric Algebra packages, it is not necessary to pre-specify an algebra. JAXGA can either be used with the MultiVector class or by using lower-level functions which is useful for example when using JAX's jit or automatic differentaition.

The MultiVector class provides operator overloading and is constructed with an array of values and their corresponding basis blades. The basis blades are encoded as tuples, for example the multivector 2 e_1 + 4 e_23 would have the values [2, 4] and the basis blade tuple ((1,), (2, 3)).

MultiVector example

import jax.numpy as jnp
from jaxga.mv import MultiVector

a = MultiVector(
    values=2 * jnp.ones([1], dtype=jnp.float32),
    indices=((1,),)
)
# Alternative: 2 * MultiVector.e(1)

b = MultiVector(
    values=4 * jnp.ones([2], dtype=jnp.float32),
    indices=((2, 3),)
)
# Alternative: 4 * MultiVector.e(2, 3)

c = a * b
print(c)

Output: Multivector(8.0 e_{1, 2, 3})

The lower-level functions also deal with values and blades. Functions are provided that take the blades and return a function that does the actual calculation. The returned function is JITted and can also be automatically differentiated with JAX. Furthermore, some operations like the geometric product take a signature function that takes a basis vector index and returns their square.

Lower-level function example

import jax.numpy as jnp
from jaxga.signatures import positive_signature
from jaxga.ops.multiply import get_mv_multiply

a_values = 2 * jnp.ones([1], dtype=jnp.float32)
a_indices = ((1,),)

b_values = 4 * jnp.ones([1], dtype=jnp.float32)
b_indices = ((2, 3),)

mv_multiply, c_indices = get_mv_multiply(a_indices, b_indices, positive_signature)
c_values = mv_multiply(a_values, b_values)
print("C indices:", c_indices, "C values:", c_values)

Output: C indices: ((1, 2, 3),) C values: [8.]

Some notes

  • Both the MultiVector and lower-level function approaches support batches: the axes after the first one (which indexes the basis blades) are treated as batch indices.
  • The MultiVector class can also take a signature in its constructor (default is square to 1 for all basis vectors). Doing operations with MultiVectors with different signatures is undefined.
  • The jaxga.signatures submodule contains a few predefined signature functions.
  • get_mv_multiply and similar functions cache their result by their inputs.
  • The flaxmodules submodule provides flax (a popular neural network library for jax) modules with Geometric Algebra operations.
  • Because we don't deal with a specific algebra, the dual needs an input that specifies the dimensionality of the space in which we want to find the dual element.

Benchmarks

N-d vector * N-d vector, batch size 100, N=range(1, 10), CPU

JaxGA stores only the non-zero basis blade coefficients. TFGA and Clifford on the other hand store all GA elements as full multivectors including all zeros. As a result, JaxGA does better than these for high dimensional algebras.

Below is a benchmark of the geometric product of two vectors with increasing dimensionality from 1 to 9. 100 vectors are multiplied at a time.

JAXGA (CPU) tfga (CPU) clifford
benchmark-results benchmark-results benchmark-results

N-d vector * N-d vector, batch size 100, N=range(1, 50, 5), CPU

Below is a benchmark for higher dimensions that TFGA and Clifford could not handle. Note that the X axis isn't sorted naturally.

benchmark-results

Owner
Robin Kahlow
Software / Machine Learning Engineer
Robin Kahlow
Official implementation of CVPR2020 paper "Deep Generative Model for Robust Imbalance Classification"

Deep Generative Model for Robust Imbalance Classification Deep Generative Model for Robust Imbalance Classification Xinyue Wang, Yilin Lyu, Liping Jin

9 Nov 01, 2022
SatelliteSfM - A library for solving the satellite structure from motion problem

Satellite Structure from Motion Maintained by Kai Zhang. Overview This is a libr

Kai Zhang 190 Dec 08, 2022
Gated-Shape CNN for Semantic Segmentation (ICCV 2019)

GSCNN This is the official code for: Gated-SCNN: Gated Shape CNNs for Semantic Segmentation Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler

859 Dec 26, 2022
CPF: Learning a Contact Potential Field to Model the Hand-object Interaction

Contact Potential Field This repo contains model, demo, and test codes of our paper: CPF: Learning a Contact Potential Field to Model the Hand-object

Lixin YANG 99 Dec 26, 2022
This repository contains the code for EMNLP-2021 paper "Word-Level Coreference Resolution"

Word-Level Coreference Resolution This is a repository with the code to reproduce the experiments described in the paper of the same name, which was a

79 Dec 27, 2022
Code for Robust Contrastive Learning against Noisy Views

Robust Contrastive Learning against Noisy Views This repository provides a PyTorch implementation of the Robust InfoNCE loss proposed in paper Robust

Ching-Yao Chuang 53 Jan 08, 2023
Tensorflow-Project-Template - A best practice for tensorflow project template architecture.

Tensorflow Project Template A simple and well designed structure is essential for any Deep Learning project, so after a lot of practice and contributi

Mahmoud G. Salem 3.6k Dec 22, 2022
Official Pytorch implementation of paper "Reverse Engineering of Generative Models: Inferring Model Hyperparameters from Generated Images"

Reverse_Engineering_GMs Official Pytorch implementation of paper "Reverse Engineering of Generative Models: Inferring Model Hyperparameters from Gener

100 Dec 18, 2022
Everything's Talkin': Pareidolia Face Reenactment (CVPR2021)

Everything's Talkin': Pareidolia Face Reenactment (CVPR2021) Linsen Song, Wayne Wu, Chaoyou Fu, Chen Qian, Chen Change Loy, and Ran He [Paper], [Video

71 Dec 21, 2022
Python with OpenCV - MediaPip Framework Hand Detection

Python HandDetection Python with OpenCV - MediaPip Framework Hand Detection Explore the docs » Contact Me About The Project It is a Computer vision pa

2 Jan 07, 2022
On Generating Extended Summaries of Long Documents

ExtendedSumm This repository contains the implementation details and datasets used in On Generating Extended Summaries of Long Documents paper at the

Georgetown Information Retrieval Lab 76 Sep 05, 2022
Pytorch implementation for the paper: Contrastive Learning for Cold-start Recommendation

Contrastive Learning for Cold-start Recommendation This is our Pytorch implementation for the paper: Yinwei Wei, Xiang Wang, Qi Li, Liqiang Nie, Yan L

45 Dec 13, 2022
Custom studies about block sparse attention.

Block Sparse Attention 研究总结 本人近半年来对Block Sparse Attention(块稀疏注意力)的研究总结(持续更新中)。按时间顺序,主要分为如下三部分: PyTorch 自定义 CUDA 算子——以矩阵乘法为例 基于 Triton 的 Block Sparse A

Chen Kai 2 Jan 09, 2022
ACL'2021: LM-BFF: Better Few-shot Fine-tuning of Language Models

LM-BFF (Better Few-shot Fine-tuning of Language Models) This is the implementation of the paper Making Pre-trained Language Models Better Few-shot Lea

Princeton Natural Language Processing 607 Jan 07, 2023
A Real-Time-Strategy game for Deep Learning research

Description DeepRTS is a high-performance Real-TIme strategy game for Reinforcement Learning research. It is written in C++ for performance, but provi

Centre for Artificial Intelligence Research (CAIR) 156 Dec 19, 2022
Unified tracking framework with a single appearance model

Paper: Do different tracking tasks require different appearance model? [ArXiv] (comming soon) [Project Page] (comming soon) UniTrack is a simple and U

ZhongdaoWang 300 Dec 24, 2022
Official pytorch implementation of paper Dual-Level Collaborative Transformer for Image Captioning (AAAI 2021).

Dual-Level Collaborative Transformer for Image Captioning This repository contains the reference code for the paper Dual-Level Collaborative Transform

lyricpoem 160 Dec 11, 2022
My Body is a Cage: the Role of Morphology in Graph-Based Incompatible Control

My Body is a Cage: the Role of Morphology in Graph-Based Incompatible Control

yobi byte 29 Oct 09, 2022
an implementation of Revisiting Adaptive Convolutions for Video Frame Interpolation using PyTorch

revisiting-sepconv This is a reference implementation of Revisiting Adaptive Convolutions for Video Frame Interpolation [1] using PyTorch. Given two f

Simon Niklaus 59 Dec 22, 2022
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"

Optimal Model Design for Reinforcement Learning This repository contains JAX code for the paper Control-Oriented Model-Based Reinforcement Learning wi

Evgenii Nikishin 43 Sep 28, 2022