EGNN - Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch

Overview

Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May eventually be used for Alphafold2 replication. This technique relies on simple invariant features, and ended up beating all previous methods (including SE3 Transformer and Lie Conv) in both accuracy and performance. It is SOTA in dynamical system models, molecular activity prediction tasks, etc.

Install

$ pip install egnn-pytorch

Usage

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)

With edges

import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)

Citations

@misc{satorras2021en,
    title 	= {E(n) Equivariant Graph Neural Networks}, 
    author 	= {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year 	= {2021},
    eprint 	= {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
Comments
  • training batch size

    Dear authors,

    Thanks for your great work! I saw your example, which is easy to understand. However, I notice that during training each iteration supports batch size > 1, but all the graphs in the batch share the same adj_mat. Do you have a better solution for graphs with different adjacency matrices? Thanks.

    opened by futianfan 6
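One common way to batch graphs with different adjacency matrices is to pad each matrix to the largest graph in the batch and carry a node mask alongside it. The sketch below shows only that padding step in plain Python; `pad_graph_batch` is a hypothetical helper, not part of egnn-pytorch, and whether the model accepts a per-graph adjacency matrix and mask is an assumption to check against the installed version.

```python
def pad_graph_batch(adj_mats):
    """Pad square 0/1 adjacency matrices (nested lists) to a common size.

    Returns (batched_adj, mask): batched_adj[b][i][j] is the padded adjacency
    of graph b, and mask[b][i] is True for real nodes, False for padding.
    """
    max_nodes = max(len(adj) for adj in adj_mats)
    batched_adj, mask = [], []
    for adj in adj_mats:
        n = len(adj)
        # Start from an all-False matrix so padded nodes have no edges.
        padded = [[False] * max_nodes for _ in range(max_nodes)]
        for i in range(n):
            for j in range(n):
                padded[i][j] = bool(adj[i][j])
        batched_adj.append(padded)
        mask.append([i < n for i in range(max_nodes)])
    return batched_adj, mask


# Two graphs of different sizes: a single edge and a 3-node path.
graph_a = [[0, 1], [1, 0]]
graph_b = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
batched_adj, mask = pad_graph_batch([graph_a, graph_b])
```

The nested lists can then be converted with `torch.tensor(batched_adj)` and `torch.tensor(mask)` into boolean tensors of shape (batch, n, n) and (batch, n).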
  • Import Error when torch_geometric is not available

    https://github.com/lucidrains/egnn-pytorch/blob/e35510e1be94ee9f540bf2ffea49cd63578fe473/egnn_pytorch/egnn_pytorch.py#L413

    A small problem: Tensor is not defined at this line when torch_geometric is not available.

    Thanks for your work.

    opened by zrt 4
  • About aggregations in EGNN_sparse

    Hi, thanks for your great work!

    I have a question on how aggregations are computed for the node embedding and the coordinate embedding. In the paper, the aggregation for the node embedding is computed over a node's neighbors, while the aggregation for the coordinate embedding is computed over all other nodes. However, in EGNN_sparse, I didn't notice such a difference in aggregations.

    I guess it is because computing all-pair messages for coordinate embedding makes 'sparse' meaningless, but I would like to double-check to see if I get this correctly. So anyway, did you do this intentionally? Or did I miss something?

    My appreciation.

    opened by simon1727 4
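For reference, the layer update this question refers to can be written as below. This is a sketch reconstructed from the description in the question (neighbor sum for node messages, all-pairs sum for coordinates) and from the squared-distance input mentioned in another issue in this list. The symbols follow the usual notation (h for node features, x for coordinates, a_ij for edge attributes, N(i) for the neighbors of node i); any normalization constant on the coordinate sum is omitted, and the exact form should be checked against the version of the paper being read.

```latex
\begin{aligned}
m_{ij} &= \phi_e\left(h_i^l,\, h_j^l,\, \lVert x_i^l - x_j^l \rVert^2,\, a_{ij}\right) \\
x_i^{l+1} &= x_i^l + \sum_{j \neq i} \left(x_i^l - x_j^l\right) \phi_x(m_{ij}) \\
m_i &= \sum_{j \in \mathcal{N}(i)} m_{ij} \\
h_i^{l+1} &= \phi_h\left(h_i^l,\, m_i\right)
\end{aligned}
```

In a sparse layer that only materializes messages along edge_index, both sums would run over the same edge set, which is the behavior the question describes.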
  • Few queries on the implementation

    Hi - fast work coding these things up, as usual! Looking at the paper and your code, you're not using squared distance for the edge weighting. Is that intentional? Also, it looks like you are adding the old feature vectors to the new ones rather than taking the new vectors directly from the fully connected net - is that also an intentional change from the paper?

    opened by denjots 3
  • Fix PyG problems. add example for point cloud denoising

    • Fixed some tiny errors in data flows for the PyG layers (dimensions and slices mainly)
    • fixed the EGNN_Sparse_Network so now it works
    • provides example for point cloud denoising (from gaussian masked coordinates), and showcases potential issues:
      • unstable (could be due to nature of data, not sure, but gvp does well on it)
      • not able to beat baseline (in contrast, gvp gets to 0.8 RMSD while this gets to the baseline 1 RMSD but not below it)
    opened by hypnopump 2
  • EGNN_sparse incorrect positional encoding output

    Hi, many thanks for the implementation!

    I was quickly checking the code for the pytorch geometric implementation of the EGNN_sparse layer, and I noticed that it expects the first 3 columns in the features to be the coordinates. However, in the update method, features and coordinates are passed in the wrong order.

    https://github.com/lucidrains/egnn-pytorch/blob/375d686c749a685886874baba8c9e0752db5f5be/egnn_pytorch/egnn_pytorch.py#L192

    This may cause problems during learning (think of concatenating several of these layers), as they expect coordinate and feature order to be consistent.

    One can reproduce this behaviour in the following snippet:

    layer = EGNN_sparse(feats_dim=1, pos_dim=3, m_dim=16, fourier_features=0)
    
    R = rot(*torch.rand(3))
    T = torch.randn(1, 1, 3)
    
    feats = torch.randn(16, 1)
    coors = torch.randn(16, 3)
    x1 = torch.cat([coors, feats], dim=-1)
    x2 = torch.cat([(coors @ R + T).squeeze() , feats], dim=-1)
    edge_idxs = (torch.rand(2, 20) * 16).long()
    
    out1 = layer(x=x1, edge_index=edge_idxs)
    out2 = layer(x=x2, edge_index=edge_idxs)
    

    After fixing the order of these arguments in the update method, the layer behaves as expected (output features are invariant, and output coordinates are equivariant under SE(3) transformations).

    opened by josejimenezluna 2
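The snippet in this issue builds out1 and out2 but leaves the comparison implicit. The sketch below spells out what an equivariance check compares, using a toy coordinate update in plain Python rather than the library layer. The functions `coord_update` and `rotate_z_and_translate` are hypothetical, and the fixed weight function stands in for the learned coordinate MLP.

```python
import math
import random


def coord_update(coors):
    """Toy EGNN-style update: x_i + mean_j (x_i - x_j) * w(||x_i - x_j||^2).

    w is a fixed function of the squared distance (an invariant), standing in
    for the learned coordinate MLP. This is an illustration only.
    """
    n = len(coors)
    out = []
    for i, xi in enumerate(coors):
        delta = [0.0, 0.0, 0.0]
        for j, xj in enumerate(coors):
            if i == j:
                continue
            rel = [a - b for a, b in zip(xi, xj)]
            sq_dist = sum(r * r for r in rel)
            weight = 1.0 / (1.0 + sq_dist)
            delta = [d + weight * r for d, r in zip(delta, rel)]
        out.append([x + d / (n - 1) for x, d in zip(xi, delta)])
    return out


def rotate_z_and_translate(coors, angle, shift):
    """Rotate points about the z axis by `angle`, then add `shift`."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2]]
            for x, y, z in coors]


random.seed(0)
coors = [[random.gauss(0, 1) for _ in range(3)] for _ in range(16)]
angle, shift = 0.7, [1.0, -2.0, 0.5]

# Equivariance: transforming the input must transform the output the same way.
update_then_transform = rotate_z_and_translate(coord_update(coors), angle, shift)
transform_then_update = coord_update(rotate_z_and_translate(coors, angle, shift))

max_error = max(abs(a - b)
                for p, q in zip(update_then_transform, transform_then_update)
                for a, b in zip(p, q))
```

For the real layer, the same comparison would be made between the coordinate columns of out2 and the transformed coordinate columns of out1, while the feature columns should match directly.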
  • Nan Values after stacking multiple layers

    Hi Lucid!!

    I find that when stacking multiple layers, the output from the model rapidly goes to NaN. I suspect it may be related to the weights used for initialization.

    Here is a minimal working example:

    Make some data:

        import numpy as np
        import torch
        from egnn_pytorch import EGNN
        
        torch.set_default_dtype(torch.double)
    
        zline = np.arange(0, 2, 0.05)
        xline = np.sin(zline * 2 * np.pi) 
        yline = np.cos(zline * 2 * np.pi)
        points = np.array([xline, yline, zline])
        geom = torch.tensor(points.transpose())[None,:]
        feat = torch.randint(0, 20, (1, geom.shape[1],1))
    

    Make a model:

        class ResEGNN(torch.nn.Module):
            def __init__(self, depth = 2, dims_in = 1):
                super().__init__()
                self.layers = torch.nn.ModuleList([EGNN(dim = dims_in) for i in range(depth)])
            
            def forward(self, geom, feat):
                for layer in self.layers:
                    feat, geom = layer(feat, geom)
                return geom
    

    Run model for varying depths:

        for i in range(10):
            model = ResEGNN(depth = i)
            pred = model(geom, feat)
            mean_absolute_value  = torch.abs(pred).mean()
            print("Order of predictions {:.2f}".format(np.log(mean_absolute_value.detach().numpy())))
    

    Output:

        Order of predictions -0.29
        Order of predictions 0.05
        Order of predictions 6.65
        Order of predictions 21.38
        Order of predictions 78.25
        Order of predictions 302.71
        Order of predictions 277.38
        Order of predictions nan
        Order of predictions nan
        Order of predictions nan

    opened by brennanaba 2
  • Edge features thrown out

    Hi, thanks for this implementation!

    I was wondering if the pytorch-geometric implementation of this architecture is throwing the edge features out by mistake, as seen here

    https://github.com/lucidrains/egnn-pytorch/blob/1b8320ade1a89748e4042ae448626652f1c659a1/egnn_pytorch/egnn_pytorch.py#L148-L151

    Or maybe my understanding is wrong? Cheers,

    opened by josejimenezluna 2
  • solve ij -> i bottleneck in sparse version

    I don't recommend normalizing the weights nor the coords.

    • The weights are the coefficient that multiplies the delta in the i->j direction
    • the coords are the deltas in the i->j direction

    I can't see the advantage of normalizing them beyond a naive stabilization, which might affect convergence by requiring more layers, since each layer would be limited in the transformation it can perform.

    It works fine for denoising without normalization (the instability might come from huge outliers, but then tuning the learning rate or clipping the gradients might help).

    opened by hypnopump 0
  • Questions about the EGNN code

    Recently, I've been reading the EGNN paper and studying your EGNN code. I had a hard time understanding both the paper and the code because my major is not computer science. While studying your code, I realized that the shape of hidden_out and the shape of kwargs["x"] must be the same to perform the add operation (because of the residual connection) in the forward method of the EGNN_sparse class. How can I increase or decrease the hidden dimension size of x?

    I would like to get some advice.

    Thanks for your consideration in this regard.

    opened by Byun-jinyoung 0
  • Wrong edge_index size hint in class EGNN_Sparse of pyg version

    Hi, I found there may be a small mistake. In the input hint of class EGNN_Sparse of the pyg version, the size of edge_index is given as (n_edges, 2). However, it should be (2, n_edges); otherwise, the distance calculation will not be correct.

        """ Inputs:
            * x: (n_points, d) where d is pos_dims + feat_dims
            * edge_index: (n_edges, 2)
            * edge_attr: tensor (n_edges, n_feats) excluding basic distance feats.
            * batch: (n_points,) long tensor. specifies xloud belonging for each point
            * angle_data: list of tensors (levels, n_edges_i, n_length_path) long tensor.
            * size: None
        """

    opened by Layne-Huang 2
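To make the two layouts concrete, the sketch below converts a list of (source, target) pairs, which corresponds to the (n_edges, 2) shape, into the (2, n_edges) layout the issue says the layer needs. It uses plain Python lists, and `to_coo_layout` is a hypothetical helper name.

```python
def to_coo_layout(edge_pairs):
    """Convert [(src, dst), ...] pairs into [[src, ...], [dst, ...]] rows."""
    if not edge_pairs:
        return [[], []]
    sources, targets = zip(*edge_pairs)
    return [list(sources), list(targets)]


# Four directed edges given as pairs, i.e. the (n_edges, 2) layout.
edge_pairs = [(0, 1), (1, 2), (2, 0), (0, 2)]

# Row 0 holds the source nodes and row 1 the target nodes: (2, n_edges).
edge_index = to_coo_layout(edge_pairs)
```

With a torch tensor of shape (n_edges, 2), the equivalent conversion is `pairs.t().contiguous()`.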
  • Exploding Gradients With 4 Layers

    I'm using EGNN with 4 layers (where I also do global attention after each layer), and I'm seeing exploding gradients after 90 epochs or so. I'm using techniques discussed earlier (sparse attention matrix, coor_weights_clamp_value, norm_coors), but I'm not sure if there's anything else I should be doing. I'm also not updating the coordinates, so the fix in the pull request doesn't apply.

    opened by cutecows 0
  • Added optional tanh to coors_mlp

    This removes the NaN bug completely (must also use norm_coors otherwise performance dies)

    The NaN bug comes from the coors_mlp exploding, so forcing values between -1 and 1 prevents this. If coordinates are normalised then performance should not be adversely affected.

    opened by jscant 1
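The reasoning in this pull request can be illustrated with a small sketch: once the relative coordinates are normalized to unit length and the per-edge weight is passed through tanh, each neighbor can move a node by at most one unit, so the total step is bounded by the number of neighbors. The code below is plain Python with hypothetical names (`coordinate_step`, `normalise`); the large raw weights are made-up stand-ins for an exploding coors_mlp output, not values from the library.

```python
import math


def normalise(vec, eps=1e-8):
    """Scale a vector to unit length (eps guards against division by zero)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def coordinate_step(rel_coors, raw_weights, use_tanh):
    """Sum of weight * relative coordinate over the neighbours of one node."""
    step = [0.0, 0.0, 0.0]
    for rel, w in zip(rel_coors, raw_weights):
        if use_tanh:
            w = math.tanh(w)  # squashes the weight into [-1, 1]
        step = [s + w * r for s, r in zip(step, rel)]
    return step


# Two neighbours, with relative coordinates normalised to unit length.
rel_coors = [normalise(v) for v in ([3.0, 0.0, 0.0], [0.0, -4.0, 0.0])]
raw_weights = [250.0, -900.0]  # illustrative exploding weights

unbounded = coordinate_step(rel_coors, raw_weights, use_tanh=False)
bounded = coordinate_step(rel_coors, raw_weights, use_tanh=True)
```

Without normalized coordinates the tanh bound alone would still let the step scale with the raw distances, which matches the note that norm_coors must be used together with it.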
Releases(0.2.6)
Owner
Phil Wang
Working with Attention. It's all we need.