VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries

VACA

Code repository for the paper "VACA: Designing Variational Graph Autoencoders for Interventional and Counterfactual Queries" (arXiv). The implementation is based on PyTorch, PyTorch Geometric and PyTorch Lightning. The repository contains the resources needed to run the experiments of the paper. Follow the instructions below to download the German dataset.

Installation

Create a conda environment and activate it:

conda create --name vaca python=3.9 --no-default-packages
conda activate vaca 

Option 1: Import the conda environment

conda env create -f environment.yml

Option 2: Commands

conda install pip
pip install torch torchvision torchaudio
pip install pytorch-lightning
pip install -U scikit-learn
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cpu.html
pip install matplotlib
pip install seaborn
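After either installation option, a quick way to confirm that the main dependencies can be imported is a short check like the one below. It is a sketch and not part of the repository; the module names are the import names we assume for the packages installed above (for example, scikit-learn imports as `sklearn` and PyTorch Geometric as `torch_geometric`).

```python
import importlib.util

# Import names assumed for the packages installed above (the pip name differs for some).
REQUIRED_MODULES = [
    "torch",
    "torch_geometric",
    "pytorch_lightning",
    "sklearn",
    "matplotlib",
    "seaborn",
]


def check_modules(modules):
    """Map each top-level module name to True if it can be found in the active environment."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_modules(REQUIRED_MODULES).items():
        print(f"{name:20s} {'ok' if found else 'MISSING'}")
```

Run it inside the activated vaca environment; any MISSING entry points to a package that still needs to be installed.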

Note: The German dataset is not included in this repository. The first time you train on the German dataset, you will get an error with instructions on how to download and store it. Follow these instructions so that the code runs smoothly.

Datasets

This repository contains 7 different SCMs:

  • ColliderSCM
  • MGraphSCM
  • ChainSCM
  • TriangleSCM
  • LoanSCM
  • AdultSCM
  • GermanSCM

Additionally, we provide the first five SCMs with three different types of structural equations: linear (LIN), non-linear (NLIN) and non-additive (NADD). The implementation of all datasets is in the datasets folder. To create all datasets at once, run python _create_data_toy.py (this is optional, since the datasets are created on the fly as needed).
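To make the difference between the three equation types concrete, here is a sketch for a single child node x2 with parent x1 and exogenous noise u2. The specific functions (the 0.5 coefficient, `tanh`, the product term) are illustrative choices of ours, not the equations implemented in the datasets folder.

```python
import numpy as np


def x2_linear(x1, u2):
    # LIN: the child is a linear function of its parent plus additive noise.
    return 0.5 * x1 + u2


def x2_nonlinear(x1, u2):
    # NLIN: a non-linear function of the parent, but the noise still enters additively.
    return np.tanh(x1) + u2


def x2_nonadditive(x1, u2):
    # NADD: parent and noise interact, so x2 cannot be written as f(x1) + u2.
    return x1 * (1.0 + u2)


rng = np.random.default_rng(0)
x1 = rng.normal(size=1000)
u2 = rng.normal(size=1000)
samples = {
    "linear": x2_linear(x1, u2),
    "non-linear": x2_nonlinear(x1, u2),
    "non-additive": x2_nonadditive(x1, u2),
}
```

In the two additive cases, the residual x2 - f(x1) is exactly the noise u2; in the non-additive case it is not, which is what distinguishes NADD from the other two.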

How to create your custom Toy Datasets

We also provide a function to create custom ToySCM datasets. Here is an example of an SCM with 2 nodes:

from datasets.toy import create_toy_dataset
from utils.distributions import *  # provides Normal, used for the noise distributions below
dataset = create_toy_dataset(root_dir='./my_custom_datasets',
                             name='2graph',
                             eq_type='linear',
                             nodes_to_intervene=['x1'],
                             # x1 := u1 and x2 := u2 + x1
                             structural_eq={'x1': lambda u1: u1,
                                            'x2': lambda u2, x1: u2 + x1},
                             # exogenous noise distribution of each node
                             noises_distr={'x1': Normal(0,1),
                                           'x2': Normal(0,1)},
                             # causal graph x1 -> x2 (each node lists its children)
                             adj_edges={'x1': ['x2'],
                                        'x2': []},
                             split='train',
                             num_samples=5000,
                             likelihood_names='d_d',
                             lambda_=0.05)
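The structural_eq and noises_distr arguments above describe how samples are generated: noise is drawn for every node, and each node is then computed from its noise and its parents in topological order. The standalone NumPy sketch below reproduces that idea for the same 2-node graph, together with an intervention do(x1 = a) and a single counterfactual computed by abduction, action and prediction. It does not call the repository code; `sample_scm` and `counterfactual_x2` are hypothetical helpers of ours.

```python
import numpy as np

# Same 2-node SCM as in the example above: x1 := u1, x2 := u2 + x1, with N(0, 1) noise.
structural_eq = {
    "x1": lambda u1: u1,
    "x2": lambda u2, x1: u2 + x1,
}
# Parents of each node (the inverse of adj_edges={'x1': ['x2'], 'x2': []}).
parents = {"x1": [], "x2": ["x1"]}
topological_order = ["x1", "x2"]


def sample_scm(num_samples, intervention=None, seed=0):
    """Ancestral sampling from the SCM; `intervention` maps node names to fixed values."""
    rng = np.random.default_rng(seed)
    intervention = intervention or {}
    values = {}
    for node in topological_order:
        u = rng.normal(0.0, 1.0, size=num_samples)  # exogenous noise for this node
        if node in intervention:
            # do(node = a): replace the structural equation by a constant.
            values[node] = np.full(num_samples, float(intervention[node]))
        else:
            values[node] = structural_eq[node](u, *[values[p] for p in parents[node]])
    return values


def counterfactual_x2(x1_obs, x2_obs, x1_cf):
    """Counterfactual x2 for one observed (x1, x2), had x1 been x1_cf."""
    u2 = x2_obs - x1_obs  # abduction: invert x2 := u2 + x1 to recover the noise
    return structural_eq["x2"](u2, x1_cf)  # action + prediction with the same noise


observational = sample_scm(5000)
interventional = sample_scm(5000, intervention={"x1": 2.0})
print(interventional["x2"].mean())  # approximately 2.0, since E[x2 | do(x1 = 2)] = 2 + E[u2]
print(counterfactual_x2(1.0, 1.5, 3.0))  # 3.5
```

Note that the counterfactual reuses the noise inferred from one specific observation, whereas the interventional samples draw fresh noise; this is the distinction between the two query types that VACA is designed to answer.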

Training

To train a model, run the script main.py. It is configured through three configuration files (the examples below pass only the first two):

  • dataset_file: specifies the dataset and its parameters. You can overwrite the dataset parameters with -d.
  • model_file: specifies the model and its parameters, as well as the optimizer. You can overwrite the model parameters with -m and the optimizer parameters with -o.
  • trainer_file: specifies the training parameters of the Trainer object from PyTorch Lightning.

For plotting results use --plots 1. For more information, run python main.py --help.
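The -d option takes key=value pairs such as -d equations_type=non-linear that overwrite entries loaded from the dataset YAML file (we assume -m and -o work the same way). The sketch below shows one way such overrides can be merged into a parameter dictionary; `apply_overrides`, the value-parsing rules and the example keys are our assumptions, not the parser used in main.py.

```python
import ast


def parse_value(text):
    """Interpret numbers, booleans and lists when possible; otherwise keep the raw string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_overrides(params, overrides):
    """Return a copy of `params` with each 'key=value' string in `overrides` applied."""
    updated = dict(params)
    for item in overrides:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        updated[key.strip()] = parse_value(value.strip())
    return updated


# Hypothetical contents of a dataset YAML file after loading.
dataset_params = {"name": "triangle", "equations_type": "linear", "num_samples_tr": 5000}
print(apply_overrides(dataset_params, ["equations_type=non-linear", "num_samples_tr=10000"]))
# {'name': 'triangle', 'equations_type': 'non-linear', 'num_samples_tr': 10000}
```

Returning a copy keeps the values loaded from the YAML file untouched, so the original configuration can still be saved next to the overridden one.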

Examples

To train our VACA algorithm on each of the synthetic graphs with linear structural equations (default value in dataset_ ):

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_vaca.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml

You can also select a different SEM with the -d option:

  • for linear (LIN) equations, -d equations_type=linear,
  • for non-linear (NLIN) equations, -d equations_type=non-linear,
  • for non-additive (NADD) equations, -d equations_type=non-additive.

For example, to train on the triangle graph with a non-linear SEM:

python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_vaca.yaml -d equations_type=non-linear

To train VACA on the German dataset:

python main.py --dataset_file _params/dataset_german.yaml --model_file _params/model_vaca.yaml

To run the CAREFL model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_carefl.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_carefl.yaml

To run the MultiCVAE model:

python main.py --dataset_file _params/dataset_adult.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_loan.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_chain.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_collider.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_mgraph.yaml --model_file _params/model_mcvae.yaml
python main.py --dataset_file _params/dataset_triangle.yaml --model_file _params/model_mcvae.yaml
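Since the commands above differ only in the dataset file, the model file and, optionally, the -d equations_type override, the whole grid can also be generated from Python. This is a convenience sketch, not a script shipped with the repository: by default it only prints the commands, and with `dry_run=False` it runs each one with `subprocess.run` from the current directory (assumed to be the repository root). Following the Datasets section, the equation-type sweep is restricted to the five SCMs that come in LIN, NLIN and NADD versions.

```python
import shlex
import subprocess

MODELS = ["vaca", "carefl", "mcvae"]
DATASETS = ["adult", "loan", "chain", "collider", "mgraph", "triangle"]
# Per the Datasets section, only these SCMs have LIN/NLIN/NADD variants.
SEM_DATASETS = ["collider", "mgraph", "chain", "triangle", "loan"]
EQUATION_TYPES = ["linear", "non-linear", "non-additive"]


def build_command(dataset, model, equations_type=None):
    """Build the main.py command line for one dataset/model pair."""
    cmd = [
        "python", "main.py",
        "--dataset_file", f"_params/dataset_{dataset}.yaml",
        "--model_file", f"_params/model_{model}.yaml",
    ]
    if equations_type is not None:
        cmd += ["-d", f"equations_type={equations_type}"]
    return cmd


def experiment_grid(sweep_equations=False):
    """Yield one command per model and dataset, optionally sweeping the SEM type."""
    for model in MODELS:
        for dataset in DATASETS:
            if sweep_equations and dataset in SEM_DATASETS:
                for eq in EQUATION_TYPES:
                    yield build_command(dataset, model, eq)
            else:
                yield build_command(dataset, model)


def run_all(sweep_equations=False, dry_run=True):
    for cmd in experiment_grid(sweep_equations):
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing the command as a list avoids shell quoting issues; `check=True` makes a failed run raise instead of silently moving on to the next configuration.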

How to load a trained model?

To load a trained model:

  • set the training flag to -i 0;
  • select the configuration file of the trained model, i.e., hparams_full.yaml.

python main.py --yaml_file=PATH/hparams_full.yaml -i 0
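With several runs, the saved configurations can be located and the reload commands built automatically. The sketch below assumes only that each run folder under the experiment root (for example exper_test) contains a file named hparams_full.yaml; the exact folder layout is not documented here, so it searches recursively. `reload_commands` is a hypothetical helper.

```python
from pathlib import Path


def reload_commands(exper_root, extra_flags=()):
    """Return one 'main.py ... -i 0' command per hparams_full.yaml found under exper_root."""
    commands = []
    # rglob searches all subfolders; a missing root simply yields no matches.
    for hparams in sorted(Path(exper_root).rglob("hparams_full.yaml")):
        commands.append(["python", "main.py", f"--yaml_file={hparams}", "-i", "0", *extra_flags])
    return commands


if __name__ == "__main__":
    for cmd in reload_commands("exper_test"):
        print(" ".join(cmd))
```

Passing extra_flags=("--eval_fair", "--show_results") produces the counterfactual fairness variant of the command shown in the next section.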

Load a model and train/evaluate counterfactual fairness

Load your model and add the flag --eval_fair. For example:

python main.py --yaml_file=PATH/hparams_full.yaml -i 0 --eval_fair --show_results
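Counterfactual fairness asks whether a classifier's prediction for an individual would change had the sensitive attribute taken a different value, with its descendants recomputed through the SCM. As a rough illustration only (not necessarily the metric reported by --eval_fair), one can compare a classifier's predictions on factual inputs with those on their counterfactual versions, which a trained model such as VACA would generate. `counterfactual_unfairness`, the toy classifier and the data are all hypothetical.

```python
import numpy as np

# Hypothetical weights of a toy logistic classifier; column 0 is the sensitive attribute.
WEIGHTS = np.array([0.0, 1.5])


def toy_predict_proba(x):
    """Toy classifier that ignores the sensitive attribute (zero weight on column 0)."""
    return 1.0 / (1.0 + np.exp(-(x @ WEIGHTS)))


def counterfactual_unfairness(predict_proba, x_factual, x_counterfactual):
    """Mean absolute change in predicted probability between factual and counterfactual inputs.

    A value of 0 means the predictions are unaffected by the counterfactual change.
    """
    p_f = predict_proba(x_factual)
    p_cf = predict_proba(x_counterfactual)
    return float(np.mean(np.abs(p_f - p_cf)))


# Factual samples and their counterfactuals: the sensitive attribute flips from 0 to 1,
# and its descendant (column 1) shifts accordingly, as it would when recomputed through an SCM.
x_f = np.array([[0.0, 1.0], [0.0, -0.5]])
x_cf = np.array([[1.0, 2.0], [1.0, 0.5]])
print(counterfactual_unfairness(toy_predict_proba, x_f, x_cf))  # > 0 although WEIGHTS[0] == 0
```

The score is positive even though the classifier never looks at the sensitive attribute, because it still uses a descendant of it; this is why counterfactual fairness is evaluated with counterfactual samples rather than by simply dropping the sensitive column.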

TensorBoard visualization

You can track different metrics during (and after) training using TensorBoard. For example, if the root folder of the experiments is exper_test, run the following command in a terminal

tensorboard --logdir exper_test/

to display the logs of all experiments contained in that folder. Then open http://localhost:6006/ in your browser to visualize the results.

Owner
Pablo Sánchez-Martín
Ph.D. student at the Max Planck Institute for Intelligent Systems