A PyTorch Implementation of "Watch Your Step: Learning Node Embeddings via Graph Attention" (NeurIPS 2018).

Attention Walk

Abstract

Graph embedding methods represent nodes in a continuous vector space, preserving different types of relational information from the graph. There are many hyper-parameters to these methods (e.g. the length of a random walk) which have to be manually tuned for every graph. In this paper, we replace previously fixed hyper-parameters with trainable ones that we automatically learn via backpropagation. In particular, we propose a novel attention model on the power series of the transition matrix, which guides the random walk to optimize an upstream objective. Unlike previous approaches to attention models, the method that we propose utilizes attention parameters exclusively on the data itself (e.g. on the random walk), and these parameters are not used by the model for inference. We experiment on link prediction tasks, as we aim to produce embeddings that best-preserve the graph structure, generalizing to unseen information. We improve state-of-the-art results on a comprehensive suite of real-world graph datasets including social, collaboration, and biological networks, where we observe that our graph attention model can reduce the error by up to 20%-40%. We show that our automatically-learned attention parameters can vary significantly per graph, and correspond to the optimal choice of hyper-parameter if we manually tune existing methods.
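As we read the paper, the "attention model on the power series of the transition matrix" replaces a fixed context distribution with learned weights. With $A$ the adjacency matrix, $T = D^{-1}A$ its row-normalized transition matrix, $C$ the window size, and $r$ the number of walks started from every node, the expected co-occurrence matrix takes roughly the following form (check the paper for the exact statement):

```latex
\mathbb{E}\big[D;\, q_1,\ldots,q_C\big]
  = \tilde{P}^{(0)} \sum_{k=1}^{C} Q_k \, T^{k},
\qquad (Q_1,\ldots,Q_C) = \operatorname{softmax}(q_1,\ldots,q_C),
\qquad \tilde{P}^{(0)} = \operatorname{diag}(r,\ldots,r).
```

For instance, with $C = 2$ and $q = (0, 0)$ the softmax is uniform, $Q = (\tfrac12, \tfrac12)$, and the expectation reduces to $\tfrac{r}{2}\,(T + T^2)$. Learning $q$ shifts this mass toward whichever walk lengths best preserve the graph structure.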

This repository provides an implementation of Attention Walk as described in the paper:

Watch Your Step: Learning Node Embeddings via Graph Attention. Sami Abu-El-Haija, Bryan Perozzi, Rami Al-Rfou, Alexander A. Alemi. NIPS, 2018. [Paper]

The original TensorFlow implementation is available [here].

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx          2.4
tqdm              4.28.1
numpy             1.15.4
pandas            0.23.4
texttable         1.5.0
scipy             1.1.0
argparse          1.1.0
torch             1.1.0
torchvision       0.3.0

Datasets

The code takes an input graph as a CSV file. Every row describes an edge as two node IDs separated by a comma, and the first row is a header. Nodes should be indexed starting from 0. Sample graphs for the `Twitch Brasilians` and `Wikipedia Chameleons` datasets are included in the `input/` directory.
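A minimal sketch for sanity-checking an input file before training, assuming the two-column layout described above (the header's column names are ignored). `read_edge_csv` and `check_node_ids` are hypothetical helpers, not part of the repository, and requiring the IDs to run contiguously from 0 is a stricter reading of the indexing rule than the README states.

```python
import csv


def read_edge_csv(path):
    """Read a two-column edge list CSV whose first row is a header."""
    with open(path, newline="") as handle:
        reader = csv.reader(handle)
        next(reader)  # skip the header row
        return [(int(row[0]), int(row[1])) for row in reader if row]


def check_node_ids(edges):
    """Return the node count if the IDs are exactly 0..n-1, otherwise raise ValueError."""
    nodes = {node for edge in edges for node in edge}
    expected = set(range(len(nodes)))
    if nodes != expected:
        missing = sorted(expected - nodes)[:5]
        raise ValueError(f"node IDs are not contiguous from 0; missing e.g. {missing}")
    return len(nodes)
```

Usage: `edges = read_edge_csv("input/chameleon_edges.csv")` followed by `check_node_ids(edges)` either returns the number of nodes or points at the first gaps in the ID range.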

Options

The embedding is learned by the `src/main.py` script, which provides the following command-line arguments.

Input and output options

  --edge-path         STR   Input graph path.     Default is `input/chameleon_edges.csv`.
  --embedding-path    STR   Embedding path.       Default is `output/chameleon_AW_embedding.csv`.
  --attention-path    STR   Attention path.       Default is `output/chameleon_AW_attention.csv`.

Model options

  --dimensions           INT       Number of embedding dimensions.       Default is 128.
  --epochs               INT       Number of training epochs.            Default is 200.
  --window-size          INT       Skip-gram window size.                Default is 5.
  --learning-rate        FLOAT     Learning rate value.                  Default is 0.01.
  --beta                 FLOAT     Attention regularization parameter.   Default is 0.5.
  --gamma                FLOAT     Embedding regularization parameter.   Default is 0.5.
  --num-of-walks         INT       Number of walks per source node.      Default is 80.
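The `--beta`, `--gamma`, `--window-size` and `--num-of-walks` options all enter the training objective. Below is a minimal NumPy sketch of a forward pass of that objective as we read it from the paper; it is not the repository's PyTorch code. Assumptions: a small dense adjacency matrix, the window size taken as the length of the attention vector `q`, and the `--gamma` term written as a squared L2 penalty on the left and right factors (the repository's exact form may differ). All function names are hypothetical. Computing log σ through `np.logaddexp` rather than taking `np.log` of a sigmoid keeps the loss finite when scores saturate, since a naive `log(0)` yields `-inf`, a common source of NaN losses in this kind of model.

```python
import numpy as np


def transition_matrix(adjacency):
    """Row-normalize a dense adjacency matrix: T = D^-1 A."""
    degrees = adjacency.sum(axis=1, keepdims=True)
    degrees[degrees == 0] = 1.0  # isolated nodes keep an all-zero row
    return adjacency / degrees


def expected_cooccurrence(adjacency, q, num_of_walks):
    """r * sum_k softmax(q)_k * T^k for k = 1..C, where C = len(q) is the window size."""
    transition = transition_matrix(adjacency)
    weights = np.exp(q - q.max())
    weights /= weights.sum()  # softmax over the C window positions
    power = np.eye(adjacency.shape[0])
    expected = np.zeros_like(transition)
    for weight in weights:
        power = power @ transition  # advances to the next power T^k
        expected += weight * power
    return num_of_walks * expected  # P^(0) = diag(r, ..., r)


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def attention_walk_loss(left, right, q, adjacency, num_of_walks, beta, gamma):
    """Forward pass of the objective; gradients (e.g. via autograd) are not shown."""
    scores = left @ right.T
    target = expected_cooccurrence(adjacency, q, num_of_walks)
    non_edges = (adjacency == 0).astype(float)
    positive = -target * log_sigmoid(scores)
    negative = -non_edges * log_sigmoid(-scores)  # log(1 - sigmoid(x)) = log_sigmoid(-x)
    data_term = np.sum(positive + negative)  # all entries are >= 0, so this is the L1 norm
    attention_penalty = beta * np.sum(q ** 2)
    # Assumed form of the --gamma term: squared L2 norm of both factor matrices.
    embedding_penalty = gamma * (np.sum(left ** 2) + np.sum(right ** 2))
    return data_term + attention_penalty + embedding_penalty
```

With `q` initialized to zeros, the attention starts out uniform over the window positions; minimizing this loss with respect to `left`, `right` and `q` (done with PyTorch in the repository) is what lets the weighting over walk lengths be learned.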

Examples

The following commands learn a graph embedding and write it to disk. The node representations are ordered by node ID.

Creating an Attention Walk embedding of the default dataset with the standard hyperparameter settings and saving it at the default path.

python src/main.py

Creating an Attention Walk embedding of the default dataset with 256 dimensions.

python src/main.py --dimensions 256

Creating an Attention Walk embedding of the default dataset with a higher window size.

python src/main.py --window-size 20

Creating an embedding of another dataset, Twitch Brasilians, and saving the outputs under custom file names.

python src/main.py --edge-path input/ptbr_edges.csv --embedding-path output/ptbr_AW_embedding.csv --attention-path output/ptbr_AW_attention.csv
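To use the result downstream, the embedding file can be read back in. This minimal sketch assumes the file has a header row followed by one row per node, with the node ID in the first column and the coordinates in the remaining columns; verify this against your own output before relying on it. `load_embedding` and `cosine_similarity` are hypothetical helpers, not part of the repository.

```python
import csv


def load_embedding(path):
    """Return {node_id: [floats]} from an embedding CSV (header row, node ID in column 0)."""
    embedding = {}
    with open(path, newline="") as handle:
        reader = csv.reader(handle)
        next(reader)  # skip the header row
        for row in reader:
            if row:
                # IDs may have been written as floats such as "3.0"
                embedding[int(float(row[0]))] = [float(value) for value in row[1:]]
    return embedding


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors; 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = sum(a * a for a in u) ** 0.5
    norm_v = sum(b * b for b in v) ** 0.5
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0
```

Usage: `embedding = load_embedding("output/chameleon_AW_embedding.csv")`, then `cosine_similarity(embedding[0], embedding[1])` compares two nodes.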

License


Comments
  • NaN parameters

    Thanks for your PyTorch code. I found that my parameters become NaN during training: all entries of model.left_factors, model.right_factors, and model.attention turn into NaN, and so does the loss. I'm trying to find the reason and would appreciate any help or hints.

    opened by kkkkk001 9
  • Memory Error

    I'm getting OOM errors even with small files. The attached file link_network.txt throws the following error:

    Adjacency matrix powers: 100%|███████████████████████████████████████████████████████| 4/4 [00:00<00:00, 108.39it/s]
    Traceback (most recent call last):
      File "src\main.py", line 79, in <module>
        main()
      File "src\main.py", line 74, in main
        model = AttentionWalkTrainer(args)
      File "E:\AttentionWalk\src\attentionwalk.py", line 70, in __init__
        self.initialize_model_and_features()
      File "E:\AttentionWalk\src\attentionwalk.py", line 76, in initialize_model_and_features
        self.target_tensor = feature_calculator(self.args, self.graph)
      File "E:\AttentionWalk\src\utils.py", line 53, in feature_calculator
        target_matrices = np.array(target_matrices)
    MemoryError
    

    I guess this is due to the large indices of the nodes. Any workarounds for this?

    opened by davidlenz 2
  • modified normalized_adjacency_matrix calculation

    As mentioned in this issue: https://github.com/benedekrozemberczki/AttentionWalk/issues/9

    Adds normalization to the calculation, which prevents an unbalanced loss and keeps loss_on_mat from becoming extremely large when the graph has many nodes.

    opened by neilctwu 1
  • miscalculations of normalized adjacency matrix

    Thanks for sharing this awesome repo.

    The issue is that loss_on_target becomes extremely large when training with the original code, and I think this is due to a miscalculation of normalized_adjacency_matrix.

    In your original code, normalized_adjacency_matrix is calculated as:

    normalized_adjacency_matrix = degs.dot(adjacency_matrix)
    

    However, this does not normalize the matrix; it simply multiplies it by the node degrees. I think normalized_adjacency_matrix should follow its original definition:

      normalized_adjacency_matrix = degs.power(-1/2)\
                                        .dot(adjacency_matrix)\
                                        .dot(degs.power(-1/2))
    

    With this change, the loss becomes much more reasonable.

    Do I understand this correctly?

    opened by neilctwu 1
  • problem with being killed

    Hi, I tried to train the model on a new dataset with about 60,000 nodes, but the process is suddenly terminated with "Killed". Do you have any idea why? Thanks :)

    opened by amy-hyunji 1
  • Directed weighted graphs

    Is it possible to use the code with directed and weighted graphs? The paper presents the Attention Walk framework for unweighted graphs only, but I'd like to use it for such networks. Thank you for your attention.

    opened by federicoairoldi 1
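The MemoryError and "Killed" reports above are both consistent with running out of RAM on dense matrices: the traceback fails at `np.array(target_matrices)` in `feature_calculator`, right after the adjacency matrix powers are computed. The back-of-the-envelope sketch below assumes one dense n × n float64 matrix per window position, which is our assumption about the layout rather than a reading of the code; the helper names are hypothetical.

```python
def dense_target_bytes(num_nodes, window_size, itemsize=8):
    """Memory for window_size dense num_nodes x num_nodes matrices (float64 by default)."""
    return window_size * num_nodes * num_nodes * itemsize


def human_readable(num_bytes):
    """Format a byte count with decimal (SI) units."""
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if num_bytes < 1000:
            return f"{num_bytes:.1f} {unit}"
        num_bytes /= 1000
    return f"{num_bytes:.1f} PB"


for nodes in (2_000, 20_000, 60_000):
    # prints 160.0 MB, 16.0 GB and 144.0 GB respectively
    print(nodes, human_readable(dense_target_bytes(nodes, window_size=5)))
```

Under that assumption, the roughly 60,000-node graph with the default window size of 5 would need about 144 GB, and switching to float32 would only halve that, which would explain the process being killed.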
Releases(v_00001)
Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.