[FEATURE] INT8 Quantization for BERT Sentence Classification and Question Answering (#1080)
xinyu-intel authored Feb 3, 2020
1 parent 8d43da9 commit 645fe30
Showing 8 changed files with 414 additions and 28 deletions.
101 changes: 97 additions & 4 deletions docs/examples/sentence_embedding/bert.md
@@ -41,6 +41,7 @@ import random
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from gluonnlp.calibration import BertLayerCollector
# this notebook assumes that all required scripts are already
# downloaded from the corresponding tutorial webpage on http://gluon-nlp.mxnet.io
from bert import data
@@ -223,8 +224,8 @@ print('%s token id = %s'%(vocabulary.padding_token, vocabulary[vocabulary.paddin
print('%s token id = %s'%(vocabulary.cls_token, vocabulary[vocabulary.cls_token]))
print('%s token id = %s'%(vocabulary.sep_token, vocabulary[vocabulary.sep_token]))
print('token ids = \n%s'%data_train[sample_id][0])
-print('valid length = \n%s'%data_train[sample_id][1])
-print('segment ids = \n%s'%data_train[sample_id][2])
+print('segment ids = \n%s'%data_train[sample_id][1])
+print('valid length = \n%s'%data_train[sample_id][2])
print('label = \n%s'%data_train[sample_id][3])
```

@@ -241,7 +242,7 @@ batch_size = 32
lr = 5e-6
# The FixedBucketSampler and the DataLoader for making the mini-batches
-train_sampler = nlp.data.FixedBucketSampler(lengths=[int(item[1]) for item in data_train],
+train_sampler = nlp.data.FixedBucketSampler(lengths=[int(item[2]) for item in data_train],
batch_size=batch_size,
shuffle=True)
bert_dataloader = mx.gluon.data.DataLoader(data_train, batch_sampler=train_sampler)
@@ -261,7 +262,7 @@ num_epochs = 3
for epoch_id in range(num_epochs):
metric.reset()
step_loss = 0
-    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(bert_dataloader):
+    for batch_id, (token_ids, segment_ids, valid_length, label) in enumerate(bert_dataloader):
with mx.autograd.record():
# Load the data to the GPU
@@ -294,6 +295,98 @@ for epoch_id in range(num_epochs):
step_loss = 0
```

## Quantize the model

GluonNLP also provides INT8 quantization to improve performance and reduce deployment costs for natural language inference tasks. In real production settings, lower precision (INT8) brings two main benefits. First, the computation can be accelerated by low-precision instructions such as the Intel Vector Neural Network Instructions (VNNI). Second, lower-precision data types save memory bandwidth, allow for better cache locality, and save power. On the latest [AWS EC2 C5 instances](https://aws.amazon.com/blogs/aws/now-available-new-c5-instance-sizes-and-bare-metal-instances/) with [Intel Deep Learning Boost (VNNI)](https://www.intel.ai/intel-deep-learning-boost/) enabled hardware, this feature delivers up to 4X speedup with less than 0.5% accuracy drop; a quick way to measure the speedup on your own hardware is sketched after the calibration step below.

We now have a model fine-tuned on the MRPC training dataset. In this section, we will quantize it into the INT8 data type, using a subset of the MRPC validation dataset for calibration.

```{.python .input}
# The hyperparameters
dev_batch_size = 32
num_calib_batches = 5
quantized_dtype = 'auto'
calib_mode = 'customize'
# sampler for evaluation
pad_val = vocabulary[vocabulary.padding_token]
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Pad(axis=0, pad_val=pad_val), # input
nlp.data.batchify.Pad(axis=0, pad_val=0), # segment
nlp.data.batchify.Stack(), # length
nlp.data.batchify.Stack('int32')) # label
dev_dataloader = mx.gluon.data.DataLoader(data_train, batch_size=dev_batch_size, num_workers=4,
shuffle=False, batchify_fn=batchify_fn)
# Calibration function
def calibration(net, dev_data, num_calib_batches, quantized_dtype, calib_mode):
"""calibration function on the dev dataset."""
print('Now we are doing calibration on dev with cpu.')
collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=None)
num_calib_examples = dev_batch_size * num_calib_batches
quantized_net = mx.contrib.quantization.quantize_net_v2(net, quantized_dtype=quantized_dtype,
exclude_layers=[],
quantize_mode='smart',
quantize_granularity='channel-wise',
calib_data=dev_data,
calib_mode=calib_mode,
num_calib_examples=num_calib_examples,
ctx=mx.cpu(),
LayerOutputCollector=collector,
logger=None)
print('Calibration done with success.')
return quantized_net
# This fallback can be removed once MXNet 1.7 is released.
import warnings
try:
quantized_net = calibration(bert_classifier,
dev_dataloader,
num_calib_batches,
quantized_dtype,
calib_mode)
except AttributeError:
nlp.utils.version.check_version('1.7.0', warning_only=True, library=mx)
    warnings.warn('INT8 quantization for BERT needs mxnet-mkl >= 1.6.0b20200115')
```
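
The actual speedup is hardware dependent, so it is worth measuring on your own machine. The cell below is a minimal timing sketch, not part of the original tutorial: it assumes that both the FP32 classifier and the quantized network live on the CPU context, and it times repeated forward passes over a single dev batch.

```{.python .input}
import time

# One batch from the dev loader; the batchify order defined above is
# (token_ids, segment_ids, valid_length, label).
token_ids, segment_ids, valid_length, label = next(iter(dev_dataloader))
token_ids = token_ids.as_in_context(mx.cpu())
segment_ids = segment_ids.as_in_context(mx.cpu())
valid_length = valid_length.as_in_context(mx.cpu()).astype('float32')

def avg_latency(net, repeat=10):
    """Average seconds per forward pass after one warm-up call."""
    net(token_ids, segment_ids, valid_length).wait_to_read()
    tic = time.time()
    for _ in range(repeat):
        net(token_ids, segment_ids, valid_length).wait_to_read()
    return (time.time() - tic) / repeat

# If the FP32 classifier was fine-tuned on GPU, move it to CPU first:
# bert_classifier.collect_params().reset_ctx(mx.cpu())
print('FP32 latency per batch: %.1f ms' % (1000 * avg_latency(bert_classifier)))
print('INT8 latency per batch: %.1f ms' % (1000 * avg_latency(quantized_net)))
```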

## Deployment

After quantization, we can also export the quantized model for inference deployment.

```{.python .input}
prefix = './model_bert_squad_quantized'
def deployment(net, prefix, dataloader):
net.export(prefix, epoch=0)
    print('Saving quantized model at', prefix)
    print('Loading the symbol file directly as a SymbolBlock for model deployment.')
static_net = mx.gluon.SymbolBlock.imports('{}-symbol.json'.format(prefix),
['data0', 'data1', 'data2'],
'{}-0000.params'.format(prefix))
static_net.hybridize(static_alloc=True, static_shape=True)
    metric.reset()
    for batch_id, (token_ids, segment_ids, valid_length, label) in enumerate(dataloader):
token_ids = token_ids.as_in_context(mx.cpu())
valid_length = valid_length.as_in_context(mx.cpu())
segment_ids = segment_ids.as_in_context(mx.cpu())
label = label.as_in_context(mx.cpu())
out = static_net(token_ids, segment_ids, valid_length.astype('float32'))
metric.update([label], [out])
# Printing vital information
if (batch_id + 1) % (log_interval) == 0:
print('[Batch {}/{}], acc={:.3f}'
                  .format(batch_id + 1, len(dataloader),
metric.get()[1]))
return metric
# This fallback can be removed once MXNet 1.7 is released.
try:
eval_metric = deployment(quantized_net, prefix, dev_dataloader)
except NameError:
nlp.utils.version.check_version('1.7.0', warning_only=True, library=mx)
    warnings.warn('INT8 quantization for BERT needs mxnet-mkl >= 1.6.0b20200115')
```
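
Importing the exported `-symbol.json` and `-0000.params` files as a `SymbolBlock` decouples inference from the Python model definition: a serving process needs only these two files and an MKL-DNN enabled MXNet build for the INT8 kernels, not the GluonNLP model code or the fine-tuning script.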

## Conclusion

In this tutorial, we showed how to fine-tune a sentence pair classification model with pre-trained BERT parameters, and how to quantize and export it for low-precision deployment.
5 changes: 3 additions & 2 deletions scripts/bert/data/transform.py
@@ -124,6 +124,7 @@ def __call__(self, line):
if self.class_labels:
label = self._label_map[label]
label = np.array([label], dtype=self._label_dtype)
-            return input_ids, valid_length, segment_ids, label
+            return input_ids, segment_ids, valid_length, label
else:
-            return self._bert_xform(line)
+            input_ids, valid_length, segment_ids = self._bert_xform(line)
+            return input_ids, segment_ids, valid_length
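
Note that the transform now returns `(input_ids, segment_ids, valid_length)`: this ordering matches the BERT model's call signature (token ids, token types, valid length), which the exported symbolic model consumes as `data0`, `data1`, `data2`.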
90 changes: 80 additions & 10 deletions scripts/bert/finetune_classifier.py
@@ -47,6 +47,7 @@
import gluonnlp as nlp
from gluonnlp.data import BERTTokenizer
from gluonnlp.model import BERTClassifier, RoBERTaClassifier
from gluonnlp.calibration import BertLayerCollector
from data.classification import MRPCTask, QQPTask, RTETask, STSBTask, SSTTask
from data.classification import QNLITask, CoLATask, MNLITask, WNLITask, XNLITask
from data.classification import LCQMCTask, ChnSentiCorpTask
@@ -179,6 +180,21 @@
default=None,
help='Whether to perform early stopping based on the metric on dev set. '
'The provided value is the patience. ')
parser.add_argument('--deploy', action='store_true',
help='whether load static model for deployment')
parser.add_argument('--model_prefix', type=str, required=False,
help='load static model as hybridblock.')
parser.add_argument('--only_calibration', action='store_true',
help='quantize model')
parser.add_argument('--num_calib_batches', type=int, default=5,
help='number of batches for calibration')
parser.add_argument('--quantized_dtype', type=str, default='auto',
choices=['auto', 'int8', 'uint8'],
help='quantization destination data type for input data')
parser.add_argument('--calib_mode', type=str, default='customize',
choices=['none', 'naive', 'entropy', 'customize'],
help='calibration mode used for generating calibration table '
'for the quantized symbol.')

args = parser.parse_args()

@@ -227,6 +243,11 @@
dataset = args.bert_dataset
pretrained_bert_parameters = args.pretrained_bert_parameters
model_parameters = args.model_parameters

# load symbolic model
deploy = args.deploy
model_prefix = args.model_prefix

if only_inference and not model_parameters:
warnings.warn('model_parameters is not set. '
'Randomly initialized model will be used for inference.')
@@ -283,13 +304,25 @@
model.hybridize(static_alloc=True)
loss_function.hybridize(static_alloc=True)

if deploy:
logging.info('load symbol file directly as SymbolBlock for model deployment')
model = mx.gluon.SymbolBlock.imports('{}-symbol.json'.format(args.model_prefix),
['data0', 'data1', 'data2'],
'{}-0000.params'.format(args.model_prefix))
model.hybridize(static_alloc=True, static_shape=True)

# data processing
do_lower_case = 'uncased' in dataset
if use_roberta:
bert_tokenizer = nlp.data.GPT2BPETokenizer()
else:
bert_tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)

# calibration config
only_calibration = args.only_calibration
num_calib_batches = args.num_calib_batches
quantized_dtype = args.quantized_dtype
calib_mode = args.calib_mode

def convert_examples_to_features(example, tokenizer=None, truncate_length=512, cls_token=None,
sep_token=None, class_labels=None, label_alias=None, vocab=None,
@@ -322,9 +355,9 @@ def convert_examples_to_features(example, tokenizer=None, truncate_length=512, c
input_ids = vocab[tokens]
valid_length = len(input_ids)
if not is_test:
-        return input_ids, valid_length, segment_ids, label
+        return input_ids, segment_ids, valid_length, label
else:
-        return input_ids, valid_length, segment_ids
+        return input_ids, segment_ids, valid_length


def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab):
@@ -341,14 +374,14 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab)
# task.dataset_train returns (segment_name, dataset)
train_tsv = task.dataset_train()[1]
data_train = mx.gluon.data.SimpleDataset(list(map(trans, train_tsv)))
-    data_train_len = data_train.transform(lambda _, valid_length, segment_ids, label: valid_length,
+    data_train_len = data_train.transform(lambda _, segment_ids, valid_length, label: valid_length,
lazy=False)
# bucket sampler for training
pad_val = vocabulary[vocabulary.padding_token]
batchify_fn = nlp.data.batchify.Tuple(
nlp.data.batchify.Pad(axis=0, pad_val=pad_val), # input
-        nlp.data.batchify.Stack(),  # length
        nlp.data.batchify.Pad(axis=0, pad_val=0),  # segment
+        nlp.data.batchify.Stack(),  # length
nlp.data.batchify.Stack(label_dtype)) # label
batch_sampler = nlp.data.sampler.FixedBucketSampler(data_train_len, batch_size=batch_size,
num_buckets=10, ratio=0, shuffle=True)
@@ -368,8 +401,8 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab)

# batchify for data test
test_batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(axis=0, pad_val=pad_val),
-                                               nlp.data.batchify.Stack(),
-                                               nlp.data.batchify.Pad(axis=0, pad_val=0))
+                                               nlp.data.batchify.Pad(axis=0, pad_val=0),
+                                               nlp.data.batchify.Stack())
# transform for data test
test_trans = partial(convert_examples_to_features, tokenizer=tokenizer, truncate_length=max_len,
cls_token=vocab.cls_token if not use_roberta else vocab.bos_token,
@@ -393,6 +426,32 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab)
train_data, dev_data_list, test_data_list, num_train_examples = preprocess_data(
bert_tokenizer, task, batch_size, dev_batch_size, args.max_len, vocabulary)

def calibration(net, dev_data_list, num_calib_batches, quantized_dtype, calib_mode):
"""calibration function on the dev dataset."""
assert len(dev_data_list) == 1, \
        'Currently, MNLI is not supported.'
assert ctx == mx.cpu(), \
'Currently only supports CPU with MKL-DNN backend.'
logging.info('Now we are doing calibration on dev with %s.', ctx)
for _, dev_data in dev_data_list:
collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=logging)
num_calib_examples = dev_batch_size * num_calib_batches
net = mx.contrib.quantization.quantize_net_v2(net, quantized_dtype=quantized_dtype,
exclude_layers=[],
quantize_mode='smart',
quantize_granularity='channel-wise',
calib_data=dev_data,
calib_mode=calib_mode,
num_calib_examples=num_calib_examples,
ctx=ctx,
LayerOutputCollector=collector,
logger=logging)
# save params
ckpt_name = 'model_bert_{0}_quantized_{1}'.format(task_name, calib_mode)
params_saved = os.path.join(output_dir, ckpt_name)
net.export(params_saved, epoch=0)
logging.info('Saving quantized model at %s', output_dir)


def test(loader_test, segment):
"""Inference function on the test dataset."""
Expand All @@ -401,7 +460,7 @@ def test(loader_test, segment):
tic = time.time()
results = []
for _, seqs in enumerate(loader_test):
-        input_ids, valid_length, segment_ids = seqs
+        input_ids, segment_ids, valid_length = seqs
input_ids = input_ids.as_in_context(ctx)
valid_length = valid_length.as_in_context(ctx).astype('float32')
if use_roberta:
@@ -522,7 +581,7 @@ def train(metric):

# forward and backward
with mx.autograd.record():
-            input_ids, valid_length, segment_ids, label = seqs
+            input_ids, segment_ids, valid_length, label = seqs
input_ids = input_ids.as_in_context(ctx)
valid_length = valid_length.as_in_context(ctx).astype('float32')
label = label.as_in_context(ctx)
@@ -607,7 +666,7 @@ def evaluate(loader_dev, metric, segment):
step_loss = 0
tic = time.time()
for batch_id, seqs in enumerate(loader_dev):
-        input_ids, valid_length, segment_ids, label = seqs
+        input_ids, segment_ids, valid_length, label = seqs
input_ids = input_ids.as_in_context(ctx)
valid_length = valid_length.as_in_context(ctx).astype('float32')
label = label.as_in_context(ctx)
@@ -639,4 +698,15 @@ def evaluate(loader_dev, metric, segment):


if __name__ == '__main__':
-    train(task.metrics)
+    if only_calibration:
+        try:
+            calibration(model,
+                        dev_data_list,
+                        num_calib_batches,
+                        quantized_dtype,
+                        calib_mode)
+        except AttributeError:
+            nlp.utils.version.check_version('1.7.0', warning_only=True, library=mx)
+            warnings.warn('INT8 quantization for BERT needs mxnet-mkl >= 1.6.0b20200115')
+    else:
+        train(task.metrics)
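
With these changes, calibration becomes its own entry point. Assuming the script's usual task flags, an invocation along the lines of `python finetune_classifier.py --task_name MRPC --only_calibration` should write the calibrated INT8 symbol and parameter files to the output directory; afterwards, `--deploy --model_prefix <path>` loads them for inference. The exact command here is an illustration rather than something taken from the commit.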
(Diffs for the remaining five changed files were not loaded.)
