This repo contains the code for the ACL 2020 paper `Dice Loss for Data-imbalanced NLP Tasks`.

Overview

Dice Loss for NLP Tasks

This repository contains the code for the paper Dice Loss for Data-imbalanced NLP Tasks, published at ACL 2020.

Setup

  • Install Package Dependencies

The code was tested with Python 3.6.9+ and PyTorch 1.7.1. If you are working on an Ubuntu GPU machine with CUDA 10.1, run the following commands to set up the environment.

$ virtualenv -p /usr/bin/python3.6 venv
$ source venv/bin/activate
$ pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
$ pip install -r requirements.txt
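
After installation, it can help to confirm that the interpreter and PyTorch build match the tested versions. The following is a minimal sketch, not part of the repository: `split_torch_version` and `check_environment` are hypothetical helpers, and the check assumes a plain release string such as `1.7.1+cu101` (development builds with other version formats are not handled).

```python
import sys
from typing import Optional, Tuple


def split_torch_version(version: str) -> Tuple[Tuple[int, ...], Optional[str]]:
    """Split '1.7.1+cu101' into ((1, 7, 1), 'cu101'); the build tag may be absent."""
    release, _, local = version.partition("+")
    numbers = tuple(int(part) for part in release.split("."))
    return numbers, (local or None)


def check_environment(torch_version: str) -> bool:
    """Return True if Python is >= 3.6.9 and PyTorch is release 1.7.1 (any build tag)."""
    numbers, _ = split_torch_version(torch_version)
    return sys.version_info[:3] >= (3, 6, 9) and numbers == (1, 7, 1)
```

Inside the activated virtualenv, this could be called as `check_environment(torch.__version__)` after `import torch`.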
  • Download BERT Model Checkpoints

Before running the repo, download the BERT-Base and BERT-Large checkpoints from here and unzip them into a directory $BERT_DIR. Then convert the original TensorFlow BERT checkpoints to a PyTorch saved file by running bash scripts/prepare_ckpt.sh <path-to-unzip-tf-bert-checkpoints>.

Apply Dice-Loss to NLP Tasks

In this repository, we apply dice loss to four NLP tasks:

  1. machine reading comprehension
  2. paraphrase identification
  3. named entity recognition
  4. text classification

1. Machine Reading Comprehension

Datasets

We take SQuAD 1.1 as an example. Before training, download a copy of the data from here
and move the SQuAD 1.1 train file train-v1.1.json and dev file dev-v1.1.json to the directory $DATA_DIR.

Train

We choose BERT as the backbone. During training, the task trainer BertForQA automatically evaluates on the dev set every $val_check_interval epoch and saves the dev predictions to $OUTPUT_DIR/predictions_<train-epoch>_<total-train-step>.json and $OUTPUT_DIR/nbest_predictions_<train-epoch>_<total-train-step>.json.
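
Because the prediction file names encode the training epoch and step, the most recent file can be picked out programmatically. The sketch below is an illustration only: it assumes both placeholders are non-negative integers (the exact format is not stated here), and `latest_prediction_file` is a hypothetical helper, not a function of this repository.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Assumed pattern: predictions_<train-epoch>_<total-train-step>.json with integer fields.
PRED_PATTERN = re.compile(r"^predictions_(\d+)_(\d+)\.json$")


def latest_prediction_file(filenames: Iterable[str]) -> Optional[str]:
    """Return the name with the largest (epoch, step) pair, or None if nothing matches."""
    best_key, best_name = None, None
    for name in filenames:
        match = PRED_PATTERN.match(Path(name).name)
        if match is None:
            continue  # skips nbest_predictions_* and unrelated files
        key = (int(match.group(1)), int(match.group(2)))
        if best_key is None or key > best_key:
            best_key, best_name = key, name
    return best_name
```

For example, passing the result of `os.listdir(output_dir)` returns the file to hand to the evaluation step, while ignoring the n-best files.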

Run scripts/squad1/bert_<model-scale>_<loss-type>.sh to reproduce our experimental results.
The variable <model-scale> must be one of [base, large].
The variable <loss-type> must be one of [bce, focal, dice], which denote fine-tuning BERT with binary cross entropy loss, focal loss, and dice loss, respectively.

  • Run bash scripts/squad1/bert_base_focal.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR for focal loss.

  • Run bash scripts/squad1/bert_base_dice.sh to start training. After training, run bash scripts/squad1/eval_pred_file.sh $DATA_DIR $OUTPUT_DIR for dice loss.
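
To make the three loss types concrete, here is a minimal pure-Python sketch of their textbook binary forms. It is an illustration under stated assumptions, not the repository's PyTorch implementation: the dice loss shown is the standard soft dice over a batch with a smoothing constant, not necessarily the exact variant used in the paper, and the function names and default hyperparameters (`gamma=2.0`, `smooth=1.0`) are chosen here for illustration.

```python
import math
from typing import Sequence


def bce_loss(probs: Sequence[float], labels: Sequence[int], eps: float = 1e-12) -> float:
    """Mean binary cross entropy; probs are predicted P(positive)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)


def focal_loss(probs: Sequence[float], labels: Sequence[int],
               gamma: float = 2.0, eps: float = 1e-12) -> float:
    """Mean focal loss: cross entropy down-weighted by (1 - p_t) ** gamma."""
    total = 0.0
    for p, y in zip(probs, labels):
        p_t = p if y == 1 else 1.0 - p  # probability assigned to the true class
        p_t = min(max(p_t, eps), 1.0)
        total += -((1.0 - p_t) ** gamma) * math.log(p_t)
    return total / len(probs)


def dice_loss(probs: Sequence[float], labels: Sequence[int], smooth: float = 1.0) -> float:
    """Soft dice loss: 1 - (2 * sum(p * y) + smooth) / (sum(p) + sum(y) + smooth)."""
    intersection = sum(p * y for p, y in zip(probs, labels))
    denominator = sum(probs) + sum(labels)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)
```

With hard 0/1 predictions and `smooth=0`, the dice coefficient 2TP / (2TP + FP + FN) coincides with the F1 score, which is the usual motivation for optimizing it on imbalanced data.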

Evaluate

To evaluate a model checkpoint, please run

python3 tasks/squad/evaluate_models.py \
--gpus="1" \
--path_to_model_checkpoint  $OUTPUT_DIR/epoch=2.ckpt \
--eval_batch_size <evaluate-batch-size>

After evaluation, the prediction results predictions_dev.json and nbest_predictions_dev.json can be found in $OUTPUT_DIR.

To evaluate saved predictions, please run

python3 tasks/squad/evaluate_predictions.py <path-to-dev-v1.1.json> <directory-to-prediction-files>
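
For reference, SQuAD 1.1 predictions are conventionally scored with exact match (EM) and token-level F1 after answer normalization. The sketch below follows that standard recipe; whether tasks/squad/evaluate_predictions.py matches it line for line is an assumption, and the function names are chosen here for illustration.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, strip punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def token_f1(prediction: str, ground_truth: str) -> float:
    """Harmonic mean of token precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
```

When a question has several reference answers, the usual convention is to take the maximum score over the references and then average over all questions.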

2. Paraphrase Identification Task

Datasets

We use MRPC (GLUE Version) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Run bash scripts/prepare_mrpc_data.sh $DATA_DIR to download and process the datasets for the MRPC (GLUE Version) task.

Train

Please run scripts/glue_mrpc/bert_<model-scale>_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. After training, the task trainer evaluates on the test set using the checkpoint that achieved the highest F1 score on the dev set.
The variable <model-scale> must be one of [base, large].
The variable <loss-type> must be one of [focal, dice], which denote fine-tuning BERT with focal loss and dice loss, respectively.

  • Run bash scripts/glue_mrpc/bert_large_focal.sh for focal loss.

  • Run bash scripts/glue_mrpc/bert_large_dice.sh for dice loss.

The evaluation results on the dev and test sets are saved to $OUTPUT_DIR/eval_result_log.txt.
At most $max_keep_ckpt intermediate model checkpoints are kept.
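
The model selection rule described in this section (test with the checkpoint that has the highest dev F1) can be sketched as follows. This is an illustration only: `binary_f1` and `best_checkpoint` are hypothetical helpers, and the mapping from checkpoint name to dev score is an assumed data layout, not the format of eval_result_log.txt.

```python
from typing import Dict, Sequence


def binary_f1(preds: Sequence[int], labels: Sequence[int]) -> float:
    """F1 of the positive class (label 1) for binary predictions."""
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_checkpoint(dev_f1_by_ckpt: Dict[str, float]) -> str:
    """Return the checkpoint name with the highest dev F1."""
    return max(dev_f1_by_ckpt, key=dev_f1_by_ckpt.get)
```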

Evaluate

To evaluate a model checkpoint on the test set, please run

bash scripts/glue_mrpc/eval.sh \
$OUTPUT_DIR \
epoch=*.ckpt

3. Named Entity Recognition

For NER, we use the MRC-NER model as the backbone.
Processed datasets and model architecture can be found here.

Train

Please run scripts/<ner-dataset-name>/bert_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. After training, the task trainer evaluates on the test set with the best checkpoint.
The variable <ner-dataset-name> must be one of [ner_enontonotes5, ner_zhmsra, ner_zhonto4].
The variable <loss-type> must be one of [focal, dice], which denote fine-tuning BERT with focal loss and dice loss, respectively.
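
The naming scheme above can be expanded into concrete script paths with a small helper. This is a convenience sketch, not part of the repository; `ner_script_path` is a hypothetical name, and the allowed values are copied from the lists above.

```python
NER_DATASETS = ("ner_enontonotes5", "ner_zhmsra", "ner_zhonto4")
LOSS_TYPES = ("focal", "dice")


def ner_script_path(dataset: str, loss_type: str) -> str:
    """Build scripts/<ner-dataset-name>/bert_<loss-type>.sh after validating both names."""
    if dataset not in NER_DATASETS:
        raise ValueError(f"unknown NER dataset: {dataset!r}")
    if loss_type not in LOSS_TYPES:
        raise ValueError(f"unknown loss type: {loss_type!r}")
    return f"scripts/{dataset}/bert_{loss_type}.sh"
```

Looping over `NER_DATASETS` and `LOSS_TYPES` yields the six scripts listed below.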

For Chinese MSRA,

  • Run scripts/ner_zhmsra/bert_focal.sh for focal loss.

  • Run scripts/ner_zhmsra/bert_dice.sh for dice loss.

For Chinese OntoNotes4,

  • Run scripts/ner_zhonto4/bert_focal.sh for focal loss.

  • Run scripts/ner_zhonto4/bert_dice.sh for dice loss.

For English OntoNotes5,

  • Run scripts/ner_enontonotes5/bert_focal.sh. After training, you will get 91.12 Span-F1 on the test set.

  • Run scripts/ner_enontonotes5/bert_dice.sh. After training, you will get 92.01 Span-F1 on the test set.
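
Span-F1, as commonly defined for NER, counts a predicted entity as correct only when its boundaries and type both match a gold entity exactly. The sketch below shows that definition for one set of spans; the `(start, end, label)` tuple layout and the function name are assumptions made for illustration, and the repository's own evaluation may aggregate counts differently.

```python
from typing import Iterable, Set, Tuple

Span = Tuple[int, int, str]  # (start, end, label); the indexing convention is an assumption


def span_f1(pred_spans: Iterable[Span], gold_spans: Iterable[Span]) -> float:
    """F1 over exact (start, end, label) matches between predicted and gold spans."""
    pred: Set[Span] = set(pred_spans)
    gold: Set[Span] = set(gold_spans)
    true_positives = len(pred & gold)
    if true_positives == 0:
        return 0.0
    precision = true_positives / len(pred)
    recall = true_positives / len(gold)
    return 2 * precision * recall / (precision + recall)
```

A corpus-level score is typically obtained by pooling true positives, predictions, and gold spans over all sentences before computing precision and recall.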

Evaluate

To evaluate a model checkpoint, please run

CUDA_VISIBLE_DEVICES=0 python3 ${REPO_PATH}/tasks/mrc_ner/evaluate.py \
--gpus="1" \
--path_to_model_checkpoint $OUTPUT_DIR/epoch=2.ckpt

4. Text Classification

Datasets

We use TNews (Chinese Text Classification) as an example. Before running experiments, you should download and save the processed dataset files to $DATA_DIR.

Train

We choose BERT as the backbone.
Please run scripts/tnews/bert_<loss-type>.sh to train and evaluate on the dev set every $val_check_interval epoch. The variable <loss-type> must be one of [focal, dice], which denote fine-tuning BERT with focal loss and dice loss, respectively.

  • Run bash scripts/tnews/bert_focal.sh for focal loss.

  • Run bash scripts/tnews/bert_dice.sh for dice loss.

At most $max_keep_ckpt intermediate model checkpoints are kept.

Citation

If you find this repository useful, please cite the following:

@article{li2019dice,
  title={Dice loss for data-imbalanced NLP tasks},
  author={Li, Xiaoya and Sun, Xiaofei and Meng, Yuxian and Liang, Junjun and Wu, Fei and Li, Jiwei},
  journal={arXiv preprint arXiv:1911.02855},
  year={2019}
}

Contact

xiaoyalixy AT gmail.com OR xiaoya_li AT shannonai.com

Any discussions, suggestions and questions are welcome!
