Requirements can be installed by running:
pip install -r requirements.txt
The training script is `train_models.py`, which loads a jsonlines dataset, splits it 80/10/10 into train, validation, and test sets, and trains for up to 10 epochs with early stopping. It supports training on multiple GPUs using PyTorch's `DataParallel` wrapper and mini-batch training with gradient accumulation. It outputs a trained model and a JSON file containing the indexes of the validation set used and the held-out test set.
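For orientation, the overall loop looks roughly like the sketch below. This is illustrative only: names such as `train_sketch`, `accumulation_steps`, and `splits.json` are placeholders rather than the script's actual identifiers, the loss function is assumed, and early stopping is omitted.

```python
import json
import torch
from torch.utils.data import DataLoader, random_split

def train_sketch(dataset, model, lr, batch_size, max_grad_norm,
                 accumulation_steps=4, num_epochs=10):
    """Illustrative sketch: 80/10/10 split, multi-GPU training with
    gradient accumulation, and saving of the held-out indexes."""
    n = len(dataset)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    train_set, val_set, test_set = random_split(
        dataset, [n_train, n_val, n - n_train - n_val])

    model = torch.nn.DataParallel(model).cuda()     # replicate across GPUs
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):                 # early stopping omitted
        for step, (inputs, labels) in enumerate(loader):
            loss = torch.nn.functional.cross_entropy(model(inputs), labels)
            (loss / accumulation_steps).backward()  # accumulate gradients
            if (step + 1) % accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

    # Record which examples were held out so evaluation can reuse them.
    with open("splits.json", "w") as f:
        json.dump({"val_indexes": list(val_set.indices),
                   "test_indexes": list(test_set.indices)}, f)
```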
Training parameters can be passed either via command-line arguments or via a JSON file. If command-line arguments are used, the script automatically saves the training parameters to a JSON file.
Current training parameters include the following:
- `datasets` - The list of datasets to train on.
- `base_model` - The pre-trained model to use as the base. Our code currently supports `bert-base-uncased` and `roberta-base`.
- `output_layers` - The number of layers to use in the output head (can be 1, 2, or 3).
- `weighted_loss` - Whether to use a weighted loss function.
- `lr` - The Adam learning rate to use.
- `max_grad_norm` - The threshold used for gradient clipping.
- `weight_decay` - The weight decay value to use.
- `dropout_rate` - The dropout value to use before the linear layer.
- `batch_size` - The batch size to use between optimizer steps.
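An example config file covering these parameters might look like the following (the file name and all values are purely illustrative):

```json
{
  "datasets": ["dataset1.jsonl", "dataset2.jsonl"],
  "base_model": "bert-base-uncased",
  "output_layers": 2,
  "weighted_loss": true,
  "lr": 2e-5,
  "max_grad_norm": 1.0,
  "weight_decay": 0.01,
  "dropout_rate": 0.1,
  "batch_size": 16
}
```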
To load training parameters from a JSON file, run
python train_model.py --config [JSON FILE PATH]
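Command-line usage might look something like the following; the flag names here are assumed to mirror the parameter names above and may differ from the script's actual interface:
python train_model.py --datasets dataset1.jsonl dataset2.jsonl --base_model bert-base-uncased --lr 2e-5 --batch_size 16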
We provide two options for evaluating models: one on GPUs and one on Google Colab TPUs. Both methods can evaluate a model against multiple datasets. If a dataset was used to train the model, the model is evaluated against the test set held out during training; otherwise, it is evaluated against the entire dataset.
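Conceptually, the split handling during evaluation can be sketched as follows. This is a rough illustration that reuses the hypothetical `splits.json` file and key names from the training sketch above, not the scripts' actual file layout:

```python
import json
from torch.utils.data import Subset

def build_eval_set(dataset, used_for_training, splits_path="splits.json"):
    """Return the held-out test subset if the dataset was seen during
    training; otherwise evaluate on the full dataset."""
    if not used_for_training:
        return dataset
    with open(splits_path) as f:
        splits = json.load(f)
    return Subset(dataset, splits["test_indexes"])
```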
To test models on GPUs, use the `test_model.py` script as follows:
test_model.py --model model_name --datasets dataset1.jsonl dataset2.jsonl
To test models on TPUs, import the `test_xla_mp` function as follows:
import os
from glob import glob

# Change to the repository's location in the Colab filesystem so the
# test module can be imported.
os.chdir("/content/deep learning")
print(os.getcwd())

from test_model_xla import test_xla_mp
We provide a full evaluation notebook here.