Cold Brew: Distilling Graph Node Representations with Incomplete or Missing Neighborhoods

Introduction

Graph Neural Networks (GNNs) have demonstrated superior performance on node classification and regression tasks, and have emerged as the state of the art in several applications. However, (inductive) GNNs require the edge connectivity structure of nodes to be known beforehand to work well. This is often not the case in practical applications where node degrees follow power-law distributions and nodes with few connections may have noisy edges. An extreme case is the strict cold start (SCS) problem, where no neighborhood information is available, forcing prediction models to rely entirely on node features. To study the viability of using inductive GNNs to solve the SCS problem, we introduce the feature-contribution ratio (FCR), a metric that quantifies the contribution of a node's features and that of its neighborhood in predicting node labels, and we use this new metric as a model selection reward. We then propose Cold Brew, a distillation-based method that generalizes GNNs better in the SCS setting than pointwise and graph-based models. We show experimentally how FCR allows us to disentangle the contributions of various components of graph datasets, and demonstrate the superior performance of Cold Brew on several public benchmarks.

Motivation

Long-tail degree distributions are ubiquitous in large-scale graph mining tasks. In some applications, cold-start nodes have few or no neighbors in the graph, which makes graph-based methods sub-optimal because there are not enough high-quality edges to perform message passing.
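The long-tail setting can be made concrete by bucketing nodes by degree into head, tail, and isolated (strict cold start) groups. The sketch below is a hypothetical illustration: the function name `split_by_degree` and the `tail_max_degree` threshold are assumptions, not the repo's `--want_headtail` logic. It assumes an undirected edge list in which each edge appears once.

```python
from collections import Counter

def split_by_degree(num_nodes, edges, tail_max_degree=2):
    """Bucket nodes into head, tail and isolated groups by degree.

    Hypothetical helper: `edges` is an undirected list of (u, v) pairs with
    each edge listed once; `tail_max_degree` is an illustrative threshold.
    """
    degree = Counter()  # missing nodes default to degree 0
    for u, v in edges:
        if u == v:
            continue  # ignore self-loops so they do not hide isolated nodes
        degree[u] += 1
        degree[v] += 1
    head, tail, isolated = [], [], []
    for node in range(num_nodes):
        d = degree[node]
        if d == 0:
            isolated.append(node)   # strict cold start: no neighbors at all
        elif d <= tail_max_degree:
            tail.append(node)       # few, possibly noisy, neighbors
        else:
            head.append(node)
    return head, tail, isolated

if __name__ == "__main__":
    edges = [(0, 1), (0, 2), (0, 3), (1, 2), (4, 4)]
    print(split_by_degree(6, edges))
```

Node 4 only has a self-loop and node 5 has no edges, so both land in the isolated group, where message passing has nothing to aggregate.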

Method

We improve the teacher GNN with Structural Embedding, and propose a student MLP model with a latent neighborhood discovery step. We also propose a metric called FCR to quantify the difficulty of cold-start generalization.
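One plausible reading of the latent neighborhood discovery step: for a cold-start node, rank the embeddings of existing graph nodes by cosine similarity to the node's own MLP embedding, keep the top K (compare `--SEMLP_topK_2_replace` below), and average them as a substitute for the missing neighborhood message. This is an assumption-laden sketch in plain Python, not Cold Brew's actual implementation; `cosine` and `discover_virtual_neighbors` are hypothetical names.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def discover_virtual_neighbors(query, candidates, k):
    """Pick the k candidate embeddings closest to `query` and average them.

    Hypothetical stand-in for latent neighborhood discovery: the averaged
    vector plays the role of the neighborhood message an isolated node lacks.
    """
    if k <= 0 or not candidates:
        raise ValueError("need k >= 1 and at least one candidate")
    ranked = sorted(range(len(candidates)),
                    key=lambda i: cosine(query, candidates[i]),
                    reverse=True)
    top = ranked[:k]  # indices of the k most similar "virtual neighbors"
    dim = len(query)
    mean = [sum(candidates[i][j] for i in top) / len(top) for j in range(dim)]
    return top, mean

if __name__ == "__main__":
    query = [1.0, 0.0]
    candidates = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]]
    print(discover_virtual_neighbors(query, candidates, k=2))
```

A brute-force ranking is used here for clarity; on large graphs an approximate nearest-neighbor index would be the natural replacement.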

Installation Guide

The following commands install the key dependencies; the others can be installed directly via pip or conda. A full (superset) dependency list is in requirements.txt.

pip install dgl
pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
pip install torch-geometric
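To confirm that the packages above resolved, you can query their installed versions. A minimal sketch using only the standard library (`importlib.metadata`, Python 3.8+); the script and the package list are our own assumption, not part of the repo.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip commands above.
PACKAGES = ["dgl", "torch", "torchvision", "torchaudio",
            "torch-scatter", "torch-sparse", "torch-geometric"]

def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:16s} {ver or 'NOT INSTALLED'}")
```

With the CUDA wheels above, torch would be expected to report a version string such as `1.9.0+cu111`.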

Training Guide

options/base_options.py lists all usable arguments, together with their default values and candidate choices.

Comparison between a traditional GCN (optimized with Initial/Jumping/Dense/PairNorm/NodeNorm/GroupNorm/Dropouts) and Cold Brew's GNN (optimized with Structural Embedding):

Training the optimized traditional GNN:

python main.py --dataset='Cora' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1
Result: 84.15

python main.py --dataset='Citeseer' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1
Result: 71.00

python main.py --dataset='Pubmed' --train_which='TeacherGNN' --whetherHasSE='000' --want_headtail=1 --num_layers=2 --use_special_split=1
Result: 78.2

Training Cold Brew's Teacher GNN:

python main.py --dataset='Cora' --train_which='TeacherGNN' --whetherHasSE='100' --se_reg=32 --want_headtail=1 --num_layers=2 --use_special_split=1
Result: 85.10

python main.py --dataset='Citeseer' --train_which='TeacherGNN' --whetherHasSE='100' --se_reg=0.5 --want_headtail=1 --num_layers=2 --use_special_split=1
Result: 71.40

python main.py --dataset='Pubmed' --train_which='TeacherGNN' --whetherHasSE='111' --se_reg=0.5 --want_headtail=1 --num_layers=2 --use_special_split=1
Result: 78.2

Comparison between MLP models:

Training a naive MLP:

python main.py --dataset='Cora' --train_which='StudentBaseMLP'
Result on isolation split: 63.92

Training GraphMLP:

python main.py --dataset='Cora' --train_which='GraphMLP'
Result on isolation split: 68.63

Training Cold Brew's MLP:

python main.py --dataset='Cora' --train_which="SEMLP" --SEMLP_topK_2_replace=3 --SEMLP_part1_arch="2layer" --dropout_MLP=0.5 --studentMLP__opt_lr='torch.optim.Adam&0.005'
Result on isolation split: 69.57
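The `--studentMLP__opt_lr` value in the command above packs an optimizer class path and a learning rate into one string. Assuming the format is `<dotted optimizer class>&<learning rate>` (inferred from the example, not confirmed by the repo), a sketch of how such a value can be parsed and resolved follows; `parse_opt_lr` and `build_optimizer` are hypothetical names, and the repo's own parsing may differ.

```python
import importlib

def parse_opt_lr(spec):
    """Split 'module.Class&lr' into (module path, class name, learning rate)."""
    path, sep, lr = spec.partition("&")
    if not sep:
        raise ValueError(f"expected 'module.Class&lr', got {spec!r}")
    module_path, _, class_name = path.rpartition(".")
    return module_path, class_name, float(lr)

def build_optimizer(spec, params):
    """Import the optimizer class lazily and instantiate it (needs the module, e.g. torch)."""
    module_path, class_name, lr = parse_opt_lr(spec)
    opt_cls = getattr(importlib.import_module(module_path), class_name)
    return opt_cls(params, lr=lr)

if __name__ == "__main__":
    print(parse_opt_lr("torch.optim.Adam&0.005"))
```

Importing the optimizer module only inside `build_optimizer` keeps the parser usable without torch installed.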

Hyperparameter meanings

--whetherHasSE: whether Cold Brew's TeacherGNN uses structural embedding, given as a three-digit string. The first ‘1’ means structural embedding is added to the first layer; the second ‘1’ means it is added to every middle layer; the third ‘1’ means it is added to the last layer.

--se_reg: regularization coefficient for the structural embedding of Cold Brew's teacher model.

--SEMLP_topK_2_replace: the number K of best virtual neighbor nodes to use (top-K).

--manual_assign_GPU: the ID of the GPU to train on. The default, -9999, dynamically chooses the GPU with the most remaining memory.
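The three digits of `--whetherHasSE` map onto layers as described above. A minimal sketch of that mapping, assuming a layer stack of `--num_layers` layers; `se_layers` is a hypothetical helper, and treating a single-layer model as "first layer" is our own assumption.

```python
def se_layers(flags, num_layers):
    """Return, per GNN layer, whether it receives a structural embedding.

    flags[0] -> first layer, flags[1] -> every middle layer, flags[2] -> last layer.
    A single-layer model is treated as having only a first layer (an assumption).
    """
    if len(flags) != 3 or set(flags) - {"0", "1"}:
        raise ValueError(f"expected three '0'/'1' characters, got {flags!r}")
    first, middle, last = (c == "1" for c in flags)
    layers = []
    for i in range(num_layers):
        if i == 0:
            layers.append(first)
        elif i == num_layers - 1:
            layers.append(last)
        else:
            layers.append(middle)
    return layers

if __name__ == "__main__":
    print(se_layers("100", 2))  # Cora/Citeseer teacher setting above
    print(se_layers("010", 4))
```

Under this reading, a model with `--num_layers=2` has no middle layers, so the middle digit of a value such as `'111'` would have no effect.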

Adaptation Guide

How to use this repo to train on other datasets:

In trainer.py, add any new node-classification graph dataset to load_data() and return it.

What to load: return a namespace-like dataset object called 'data' with the following fields.

data.x: 2D tensor on CPU; shape = [N_nodes, dim_feature].

data.y: 1D tensor on CPU; shape = [N_nodes]; integer values giving the class of each node.

data.edge_index: tensor on CPU; shape = [2, N_edges]; the edges must include self-loops.

data.train_mask: bool tensor; shape = [N_nodes]; marks the training node set.

Template class for 'data':

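from torch_geometric.data import Data

# Template container; set data.x, data.y, data.edge_index and data.train_mask after construction.
class MyDataset(Data):
    def __init__(self):
        super().__init__()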
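Before converting to tensors, the fields listed above can be assembled and sanity-checked in plain Python. This is a sketch under stated assumptions: `build_node_classification_fields` is a hypothetical helper, the input edge list is treated as undirected (both directions are emitted), and one self-loop per node is added to satisfy the self-loop requirement.

```python
def build_node_classification_fields(num_nodes, features, labels, edges, train_nodes):
    """Assemble plain-list versions of data.x, data.y, data.edge_index and data.train_mask.

    Hypothetical helper: edges are treated as undirected and every node gets a
    self-loop. Convert each field to a tensor before attaching it to 'data'.
    """
    if len(features) != num_nodes or len(labels) != num_nodes:
        raise ValueError("features and labels need one entry per node")
    if len({len(row) for row in features}) != 1:
        raise ValueError("every feature row must have the same dimension")
    pairs = set()
    for u, v in edges:
        pairs.add((u, v))
        pairs.add((v, u))  # assume undirected input
    pairs.update((n, n) for n in range(num_nodes))  # required self-loops
    ordered = sorted(pairs)
    edge_index = [[u for u, _ in ordered], [v for _, v in ordered]]  # shape [2, N_edges]
    train = set(train_nodes)
    train_mask = [n in train for n in range(num_nodes)]  # shape [N_nodes], bool
    return {"x": features, "y": [int(c) for c in labels],
            "edge_index": edge_index, "train_mask": train_mask}

if __name__ == "__main__":
    fields = build_node_classification_fields(
        3, [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0, 1, 0], [(0, 1)], [0, 2])
    print(fields["edge_index"], fields["train_mask"])
```

The lists can then be attached to the template, e.g. `data.x = torch.tensor(fields["x"], dtype=torch.float)` and `data.edge_index = torch.tensor(fields["edge_index"], dtype=torch.long)`; PyG's `torch_geometric.utils.add_self_loops` is an alternative for adding self-loops to an existing tensor.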

Citation

Coming soon.