PyTorch trainer and model for Sequence Classification

Overview

PyTorch-trainer-and-model-for-Sequence-Classification

After cloning the repository, modify your training data so that the training data is a .csv file and it has 2 columns: Text and Label

In the below example, we will assume that our training data has 3 labels, the name of our training data file is train_data.csv

Example Usage

Import dependencies

import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer, AutoConfig

from EarlyStopping import *
from modelling import *
from utils import *

Specify arguments

args.pretrained_path will be the path of our pretrained language model

class args:
    fold = 0
    pretrained_path = 'bert-base-uncased'
    max_length = 400
    train_batch_size = 16
    val_batch_size = 64
    epochs = 5
    learning_rate = 1e-5
    accumulation_steps = 2
    num_splits = 5

Create train and validation data

In this example we will train the model using cross-validation. We will split our training data into args.num_splits folds.

df = pd.read_csv('./train_data.csv')
df = create_k_folds(df, args.num_splits)

df_train = df[df['kfold'] == args.fold].reset_index(drop = True)
df_valid = df[df['kfold'] == args.fold].reset_index(drop = True)

Load the language model and its tokenizer

config = AutoConfig.from_pretrained(args.path)
tokenizer = AutoTokenizer.from_pretrained(args.path)
model_transformer = AutoModel.from_pretrained(args.path)

Prepare train and validation dataloaders

features = []
for i in range(len(df_train)):
    features.append(prepare_features(tokenizer, df_train.iloc[i, :].to_dict(), args.max_length))
    
train_dataset = CreateDataset(features)
train_dataloader = create_dataloader(train_dataset, args.train_batch_size, 'train')

features = []
for i in range(len(df_valid)):
    features.append(prepare_features(tokenizer, df_valid.iloc[i, :].to_dict(), args.max_length))
    
val_dataset = CreateDataset(features)
val_dataloader = create_dataloader(val_dataset, args.val_batch_size, 'val')

Use EarlyStopping and customize the score function

NOTE: The customized score function should have 2 parameters: the logits, and the actual label

def accuracy(logits, labels):
    logits = logits.detach().cpu().numpy()
    labels = labels.detach().cpu().numpy()
    pred_classes = np.argmax(logits * (1 / np.sum(logits, axis = -1)).reshape(logits.shape[0], 1), axis = -1)
    pred_classes = pred_classes.reshape(labels.shape)
    
    return np.sum(pred_classes == labels) / labels.shape[0]

es = EarlyStopping(mode = 'max', patience = 3, monitor = 'val_acc', out_path = 'model.bin')
es.monitor_score_function = accuracy

Create and train the model

Calling the fit method, the training process will begin

model = Model(config, model_transformer, num_labels = 3)
model.to('cuda')
num_train_steps = int(len(train_dataset) / args.train_batch_size * args.epochs)
model.fit(args.epochs, args.learning_rate, num_train_steps, args.accumulation_steps, 
          train_dataloader, val_dataloader, es)

NOTE: To complete the cross-validation training process, run the code above again with args.fold equals 1, 2, ..., args.num_splits - 1

Owner
NhanTieu
NhanTieu
eXPeditious Data Transfer

xpdt: eXPeditious Data Transfer About xpdt is (yet another) language for defining data-types and generating code for serializing and deserializing the

Gianni Tedesco 3 Jan 06, 2022
DeepMind's software stack for physics-based simulation and Reinforcement Learning environments, using MuJoCo.

dm_control: DeepMind Infrastructure for Physics-Based Simulation. DeepMind's software stack for physics-based simulation and Reinforcement Learning en

DeepMind 3k Dec 31, 2022
[3DV 2020] PeeledHuman: Robust Shape Representation for Textured 3D Human Body Reconstruction

PeeledHuman: Robust Shape Representation for Textured 3D Human Body Reconstruction International Conference on 3D Vision, 2020 Sai Sagar Jinka1, Rohan

Rohan Chacko 39 Oct 12, 2022
Implementation of the paper "Self-Promoted Prototype Refinement for Few-Shot Class-Incremental Learning"

Self-Promoted Prototype Refinement for Few-Shot Class-Incremental Learning This is the implementation of the paper "Self-Promoted Prototype Refinement

Kai Zhu 78 Dec 02, 2022
E2EC: An End-to-End Contour-based Method for High-Quality High-Speed Instance Segmentation

E2EC: An End-to-End Contour-based Method for High-Quality High-Speed Instance Segmentation E2EC: An End-to-End Contour-based Method for High-Quality H

zhangtao 146 Dec 29, 2022
UniLM AI - Large-scale Self-supervised Pre-training across Tasks, Languages, and Modalities

Pre-trained (foundation) models across tasks (understanding, generation and translation), languages (100+ languages), and modalities (language, image, audio, vision + language, audio + language, etc.

Microsoft 7.6k Jan 01, 2023
Training and Evaluation Code for Neural Volumes

Neural Volumes This repository contains training and evaluation code for the paper Neural Volumes. The method learns a 3D volumetric representation of

Meta Research 370 Dec 08, 2022
code and data for paper "GIANT: Scalable Creation of a Web-scale Ontology"

GIANT Code and data for paper "GIANT: Scalable Creation of a Web-scale Ontology" https://arxiv.org/pdf/2004.02118.pdf Please cite our paper if this pr

Excalibur 39 Dec 29, 2022
A Tensorfflow implementation of Attend, Infer, Repeat

Attend, Infer, Repeat: Fast Scene Understanding with Generative Models This is an unofficial Tensorflow implementation of Attend, Infear, Repeat (AIR)

Adam Kosiorek 82 May 27, 2022
Efficient 3D Backbone Network for Temporal Modeling

VoV3D is an efficient and effective 3D backbone network for temporal modeling implemented on top of PySlowFast. Diverse Temporal Aggregation and

102 Dec 06, 2022
Probabilistic Programming and Statistical Inference in PyTorch

PtStat Probabilistic Programming and Statistical Inference in PyTorch. Introduction This project is being developed during my time at Cogent Labs. The

Stefano Peluchetti 109 Nov 26, 2022
Weakly supervised medical named entity classification

Trove Trove is a research framework for building weakly supervised (bio)medical named entity recognition (NER) and other entity attribute classifiers

60 Nov 18, 2022
HairCLIP: Design Your Hair by Text and Reference Image

Overview This repository hosts the official PyTorch implementation of the paper: "HairCLIP: Design Your Hair by Text and Reference Image". Our single

322 Jan 06, 2023
TensorFlow implementation of Barlow Twins (Barlow Twins: Self-Supervised Learning via Redundancy Reduction)

Barlow-Twins-TF This repository implements Barlow Twins (Barlow Twins: Self-Supervised Learning via Redundancy Reduction) in TensorFlow and demonstrat

Sayak Paul 36 Sep 14, 2022
LoFTR:Detector-Free Local Feature Matching with Transformers CVPR 2021

LoFTR-with-train-script LoFTR:Detector-Free Local Feature Matching with Transformers CVPR 2021 (with train script --- unofficial ---). About Megadepth

Nan Xiaohu 15 Nov 04, 2022
MBPO (paper: When to trust your model: Model-based policy optimization) in offline RL settings

offline-MBPO This repository contains the code of a version of model-based RL algorithm MBPO, which is modified to perform in offline RL settings Pape

LxzGordon 1 Oct 24, 2021
Direct Multi-view Multi-person 3D Human Pose Estimation

Implementation of NeurIPS-2021 paper: Direct Multi-view Multi-person 3D Human Pose Estimation [paper] [video-YouTube, video-Bilibili] [slides] This is

Sea AI Lab 251 Dec 30, 2022
Unofficial PyTorch code for BasicVSR

Dependencies and Installation The code is based on BasicSR, Please install the BasicSR framework first. Pytorch=1.51 Training cd ./code CUDA_VISIBLE_

Long 59 Dec 06, 2022
Simple node deletion tool for onnx.

snd4onnx Simple node deletion tool for onnx. I only test very miscellaneous and limited patterns as a hobby. There are probably a large number of bugs

Katsuya Hyodo 6 May 15, 2022
OCRA (Object-Centric Recurrent Attention) source code

OCRA (Object-Centric Recurrent Attention) source code Hossein Adeli and Seoyoung Ahn Please cite this article if you find this repository useful: For

Hossein Adeli 2 Jun 18, 2022