Hierarchical probabilistic 3D U-Net, with attention mechanisms (Attention U-Net, SEResNet) and a nested decoder structure with deep supervision (UNet++).

Overview

Clinically Significant Prostate Cancer Detection in bpMRI

Note: This repo will be updated as the work advances, and we welcome open-source contributions! Currently, it shares the TensorFlow 2.5 version of the Hierarchical Probabilistic 3D U-Net (with attention mechanisms, nested decoder structure and deep supervision), titled M1, as explored in the publication(s) listed below. The source code used for training this model, as per our original setup, carries a large number of dependencies on internal datasets, tooling, infrastructure and hardware, and its release is currently not feasible. However, an equivalent minimal adaptation has been made available. We encourage users to test out M1, identify potential areas for significant improvement and propose PRs for inclusion in this repo.

Pre-Trained Model using 1950 bpMRI with PI-RADS v2 Annotations [Training:Validation Ratio - 80:20]:
To infer lesion predictions on testing samples using the pre-trained variant (architecture in commit 58b784f) of this algorithm, please visit https://grand-challenge.org/algorithms/prostate-mri-cad-cspca/

Main Scripts
● Preprocessing Functions: tf2.5/scripts/preprocess.py
● Tensor-Based Augmentations: tf2.5/scripts/model/augmentations.py
● Training Script Template: tf2.5/scripts/train_model.py
● Basic Callbacks (e.g. LR Schedules): tf2.5/scripts/callbacks.py
● Loss Functions: tf2.5/scripts/model/losses.py
● Network Architecture: tf2.5/scripts/model/unets/networks.py

Requirements
● Complete Docker Container: anindox8/m1:latest
● Key Python Packages: tf2.5/requirements.txt

Figure: Train-time schematic for the Bayesian/hierarchical probabilistic configuration of M1. L_S denotes the segmentation loss between prediction p and ground-truth Y. Additionally, L_KL, denoting the Kullback–Leibler divergence loss between prior distribution P and posterior distribution Q, is used at train-time (refer to arXiv:1905.13077). For each execution of the model, latent samples z_i ∈ Q (train-time) or z_i ∈ P (test-time) are successively drawn at increasing scales of the model to predict one segmentation mask p.
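
As a rough illustration of how L_S and L_KL combine at train-time, the sketch below computes KL(Q || P) in closed form for diagonal Gaussians and adds a weighted sum of the per-scale terms to the segmentation loss. It is not the repo's implementation. The Gaussian form, the KL direction, the `beta` weight and all function names are assumptions made for this example.

```python
import math

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(Q || P) for two diagonal Gaussians.

    Each argument is a sequence with one entry per latent dimension
    (means and standard deviations). Assumed form, not taken from the repo.
    """
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, sigma_q, mu_p, sigma_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
    return total

def total_loss(seg_loss, kl_per_scale, beta=1.0):
    """L_S plus a beta-weighted sum of per-scale KL terms (beta is hypothetical)."""
    return seg_loss + beta * sum(kl_per_scale)

# One KL term per latent scale of the hierarchy (values are illustrative only).
kl_scales = [
    # Posterior equals prior: no penalty.
    kl_diag_gaussians([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
    # Posterior mean shifted by one prior standard deviation.
    kl_diag_gaussians([1.0], [1.0], [0.0], [1.0]),
]
loss = total_loss(seg_loss=0.30, kl_per_scale=kl_scales, beta=1.0)
print(kl_scales, loss)
```

At test-time no KL term is needed, since latent samples are drawn from the prior P alone.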

Figure: Architecture schematic of M1, with attention mechanisms and a nested decoder structure with deep supervision.
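
To get a feel for the anisotropic downsampling in this architecture, the sketch below traces feature-map sizes through the encoder for the input size and strides used in the minimal setup in this README. It assumes that each stage reduces every axis by its stride with "same"-style padding (ceiling division). The repo does not confirm this, and `encoder_shapes` is a hypothetical helper.

```python
import math

def encoder_shapes(input_dims, strides):
    """Return the spatial size after each encoder stage.

    Assumes output = ceil(input / stride) per axis ("same"-style padding).
    """
    shapes = []
    current = tuple(input_dims)
    for stride in strides:
        current = tuple(math.ceil(d / s) for d, s in zip(current, stride))
        shapes.append(current)
    return shapes

shapes = encoder_shapes(
    (20, 160, 160),
    ((1, 1, 1), (1, 2, 2), (1, 2, 2), (2, 2, 2), (2, 2, 2)),
)
for stage, shape in enumerate(shapes):
    print(f"stage {stage}: {shape}")
```

Under that assumption, the through-plane axis (20 slices) is only halved in the last two stages, while the in-plane axes are halved four times.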

Minimal Example of Model Setup in TensorFlow 2.5:
(More Details: Training CNNs in TF2: Walkthrough; TF2 Datasets: Best Practices; TensorFlow Probability)

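import numpy as np
import tensorflow as tf
# `unets` and `losses` are the repo's own modules under tf2.5/scripts/model/ (see Main Scripts);
# import them according to your project layout.

# U-Net Definition (Note: Hyperparameters are Data-Centric -> Require Adequate Tuning for Optimal Performance)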
unet_model = unets.networks.M1(\
                        input_spatial_dims =  (20,160,160),            
                        input_channels     =   3,
                        num_classes        =   2,                       
                        filters            =  (32,64,128,256,512),   
                        strides            = ((1,1,1),(1,2,2),(1,2,2),(2,2,2),(2,2,2)),  
                        kernel_sizes       = ((1,3,3),(1,3,3),(3,3,3),(3,3,3),(3,3,3)),
                        prob_latent_dims   =  (3,2,1,0),
                        dropout_rate       =   0.50,       
                        dropout_mode       =  'monte-carlo',
                        se_reduction       =  (8,8,8,8,8),
                        att_sub_samp       = ((1,1,1),(1,1,1),(1,1,1),(1,1,1)),
                        kernel_initializer =   tf.keras.initializers.Orthogonal(gain=1), 
                        bias_initializer   =   tf.keras.initializers.TruncatedNormal(mean=0, stddev=1e-3),
                        kernel_regularizer =   tf.keras.regularizers.l2(1e-4),
                        bias_regularizer   =   tf.keras.regularizers.l2(1e-4),     
                        cascaded           =   False,
                        probabilistic      =   True,
                        deep_supervision   =   True,
                        summary            =   True)  

# Schedule Cosine Annealing Learning Rate with Warm Restarts
LR_SCHEDULE = (tf.keras.optimizers.schedules.CosineDecayRestarts(\
                        initial_learning_rate=1e-3, t_mul=2.00, m_mul=1.00, alpha=1e-3,
                        first_decay_steps=int(np.ceil(((TRAIN_SAMPLES)/BATCH_SIZE)))*10))
                                                  
# Compile Model w/ Optimizer and Loss Function(s)
unet_model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=LR_SCHEDULE, amsgrad=True), 
                   loss      = losses.Focal(alpha=[0.75, 0.25], gamma=2.00).loss)

# Train Model
unet_model.fit(...)
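
The learning-rate schedule configured above restarts the cosine decay after `first_decay_steps` steps, and each following cycle is `t_mul` times longer. The pure-Python sketch below mirrors the schedule formula as we understand it from the TensorFlow documentation, so the restart points can be inspected without building a model. The `TRAIN_SAMPLES` and `BATCH_SIZE` values are hypothetical placeholders, and the function is an illustration, not TensorFlow's own code.

```python
import math

def cosine_decay_restarts(step, initial_lr, first_decay_steps,
                          t_mul=2.0, m_mul=1.0, alpha=0.0):
    """Learning rate at `step` for cosine annealing with warm restarts."""
    completed = step / first_decay_steps
    if t_mul == 1.0:
        # All cycles have equal length.
        i_restart = math.floor(completed)
        completed -= i_restart
    else:
        # Cycle lengths grow geometrically by t_mul.
        i_restart = math.floor(math.log(1.0 - completed * (1.0 - t_mul)) / math.log(t_mul))
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        completed = (completed - sum_r) / t_mul ** i_restart
    m_fac = m_mul ** i_restart  # peak scaling applied at each restart
    cosine = 0.5 * m_fac * (1.0 + math.cos(math.pi * completed))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)

# Hypothetical dataset and batch size, for illustration only.
TRAIN_SAMPLES, BATCH_SIZE = 1000, 8
first_decay_steps = int(math.ceil(TRAIN_SAMPLES / BATCH_SIZE)) * 10  # 10 epochs

kwargs = dict(initial_lr=1e-3, first_decay_steps=first_decay_steps,
              t_mul=2.0, m_mul=1.0, alpha=1e-3)
lr_start = cosine_decay_restarts(0, **kwargs)
lr_mid = cosine_decay_restarts(first_decay_steps // 2, **kwargs)
lr_restart = cosine_decay_restarts(first_decay_steps, **kwargs)
lr_late = cosine_decay_restarts(3 * first_decay_steps - 1, **kwargs)
print(lr_start, lr_mid, lr_restart, lr_late)
```

With `t_mul=2.00` and `m_mul=1.00`, the first cycle spans 10 epochs and the second 20, and every restart returns to the full initial learning rate. `alpha=1e-3` sets the floor of each cycle to 0.1% of that rate.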

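The focal loss passed to `compile` above down-weights voxels that are already classified well, so training concentrates on hard examples. The sketch below shows the standard per-voxel form for a single prediction. It assumes that `alpha[c]` weights class `c` and that the loss is computed on softmax probabilities. The repo's `losses.Focal` may differ in both respects, and `focal_loss` here is a hypothetical stand-in.

```python
import math

def focal_loss(probs, true_class, alpha=(0.75, 0.25), gamma=2.0, eps=1e-7):
    """Focal loss for one voxel: -alpha_c * (1 - p_c)^gamma * log(p_c).

    probs: predicted class probabilities (sums to 1).
    true_class: index of the ground-truth class.
    """
    # Clip to avoid log(0).
    p_t = min(max(probs[true_class], eps), 1.0 - eps)
    return -alpha[true_class] * (1.0 - p_t) ** gamma * math.log(p_t)

easy = focal_loss((0.9, 0.1), true_class=0)   # confident and correct
hard = focal_loss((0.1, 0.9), true_class=0)   # confident and wrong
plain = focal_loss((0.5, 0.5), true_class=1, gamma=0.0)  # gamma=0: alpha-weighted cross-entropy
print(easy, hard, plain)
```

Setting `gamma=0` recovers class-weighted cross-entropy, which makes the effect of the focusing term easy to check in isolation.
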
If you use this repo or some part of its codebase, please cite the following articles (see bibtex):

● A. Saha, J. Bosma, J. Linmans, M. Hosseinzadeh, H. Huisman (2021), "Anatomical and Diagnostic Bayesian Segmentation in Prostate MRI - Should Different Clinical Objectives Mandate Different Loss Functions?", Medical Imaging Meets NeurIPS Workshop – 35th Conference on Neural Information Processing Systems (NeurIPS), Sydney, Australia. (architecture in commit 914ec9d)

● A. Saha, M. Hosseinzadeh, H. Huisman (2021), "End-to-End Prostate Cancer Detection in bpMRI via 3D CNNs: Effect of Attention Mechanisms, Clinical Priori and Decoupled False Positive Reduction", Medical Image Analysis:102155. (architecture in commit 58b784f)

● A. Saha, M. Hosseinzadeh, H. Huisman (2020), "Encoding Clinical Priori in 3D Convolutional Neural Networks for Prostate Cancer Detection in bpMRI", Medical Imaging Meets NeurIPS Workshop – 34th Conference on Neural Information Processing Systems (NeurIPS), Vancouver, Canada. (architecture in commit 58b784f)

Contact: [email protected]; [email protected]

Related U-Net Architectures:
● nnU-Net: https://github.com/MIC-DKFZ/nnUNet
● Attention U-Net: https://github.com/ozan-oktay/Attention-Gated-Networks
● UNet++: https://github.com/MrGiovanni/UNetPlusPlus
● Hierarchical Probabilistic U-Net: https://github.com/deepmind/deepmind-research/tree/master/hierarchical_probabilistic_unet

Owner
Diagnostic Image Analysis Group