This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] INT8 Quantization for BERT Sentence Classification and Question Answering #1080

Merged · 38 commits · Feb 3, 2020
Commits (38) — the file changes below are from 18 of the 38 commits.
0e2f2f6
init support for int8 bert classification
xinyu-intel Nov 18, 2019
710c6ce
support squad calibration
xinyu-intel Nov 18, 2019
9ef98b9
add headers
xinyu-intel Nov 18, 2019
ee2126c
enhance quantization and add readme
xinyu-intel Nov 18, 2019
f69f087
rename layernorm
xinyu-intel Nov 19, 2019
d165481
init support for int8 bert classification
xinyu-intel Nov 18, 2019
c178f25
support squad calibration
xinyu-intel Nov 18, 2019
96a0b80
add headers
xinyu-intel Nov 18, 2019
e7191ff
enhance quantization and add readme
xinyu-intel Nov 18, 2019
9b4505d
rename layernorm
xinyu-intel Nov 19, 2019
71b161a
Merge branch 'bert_int8' of ssh://gitlab.devtools.intel.com:29418/che…
xinyu-intel Dec 24, 2019
08859e3
fix lint
xinyu-intel Dec 24, 2019
8f9f531
fix lint
xinyu-intel Dec 24, 2019
4ccc979
fix ut and lint
xinyu-intel Dec 25, 2019
412ff73
fix ut
xinyu-intel Dec 25, 2019
53ea62e
update python env
xinyu-intel Dec 30, 2019
2b994ea
rebase from master
xinyu-intel Jan 6, 2020
26c6cd6
Merge remote-tracking branch 'upstream/master' into bert_int8
xinyu-intel Jan 7, 2020
3a915f7
rebase from master
xinyu-intel Jan 16, 2020
720098f
add quantization doc to website
xinyu-intel Jan 16, 2020
a3a8913
move doc
xinyu-intel Jan 16, 2020
af44a8e
pin to 1231 gpu and remove doc temp
xinyu-intel Jan 16, 2020
bd2ecc5
pin to official 20200115 mxnet
xinyu-intel Jan 16, 2020
9364726
rebase
xinyu-intel Jan 18, 2020
11b9dfd
enable channel-wise quantization and smart mode
xinyu-intel Jan 19, 2020
7b2275f
rebase and add version check
xinyu-intel Jan 23, 2020
69b01aa
fix ut
xinyu-intel Jan 23, 2020
5b500bd
fix ut
xinyu-intel Jan 23, 2020
85fc23e
trigger
xinyu-intel Jan 23, 2020
a2fa345
trigger
xinyu-intel Jan 23, 2020
d400b85
add quantization toturial for mrpc
xinyu-intel Jan 27, 2020
d7ea196
Merge remote-tracking branch 'upstream/master' into bert_int8
xinyu-intel Jan 27, 2020
6c08caf
rebase and add deployment part to tutorial
xinyu-intel Jan 30, 2020
3b4eee4
fix lint
xinyu-intel Jan 30, 2020
54a6e3a
add accuracy to modelzoo
xinyu-intel Feb 1, 2020
1dbfcf7
resolve conflict
xinyu-intel Feb 2, 2020
99ab77b
add SST int8
xinyu-intel Feb 2, 2020
05918ec
address comments
xinyu-intel Feb 2, 2020
8 changes: 4 additions & 4 deletions docs/examples/sentence_embedding/bert.md
@@ -223,8 +223,8 @@ print('%s token id = %s'%(vocabulary.padding_token, vocabulary[vocabulary.paddin
 print('%s token id = %s'%(vocabulary.cls_token, vocabulary[vocabulary.cls_token]))
 print('%s token id = %s'%(vocabulary.sep_token, vocabulary[vocabulary.sep_token]))
 print('token ids = \n%s'%data_train[sample_id][0])
-print('valid length = \n%s'%data_train[sample_id][1])
-print('segment ids = \n%s'%data_train[sample_id][2])
+print('segment ids = \n%s'%data_train[sample_id][1])
+print('valid length = \n%s'%data_train[sample_id][2])
 print('label = \n%s'%data_train[sample_id][3])
 ```

@@ -241,7 +241,7 @@ batch_size = 32
 lr = 5e-6

 # The FixedBucketSampler and the DataLoader for making the mini-batches
-train_sampler = nlp.data.FixedBucketSampler(lengths=[int(item[1]) for item in data_train],
+train_sampler = nlp.data.FixedBucketSampler(lengths=[int(item[2]) for item in data_train],
                                             batch_size=batch_size,
                                             shuffle=True)
 bert_dataloader = mx.gluon.data.DataLoader(data_train, batch_sampler=train_sampler)
@@ -261,7 +261,7 @@ num_epochs = 3
 for epoch_id in range(num_epochs):
     metric.reset()
     step_loss = 0
-    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(bert_dataloader):
+    for batch_id, (token_ids, segment_ids, valid_length, label) in enumerate(bert_dataloader):
         with mx.autograd.record():

             # Load the data to the GPU
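The recurring change in this PR is a single field-order convention: transforms, batchify functions, and training loops now produce and unpack `(token_ids, segment_ids, valid_length, label)` instead of `(token_ids, valid_length, segment_ids, label)`, matching the order in which the BERT model's forward consumes its inputs. A minimal sketch of the convention with dummy arrays (shapes and id values are illustrative only):

```
import mxnet as mx

# A dummy mini-batch in the new field order used throughout this PR:
# (token_ids, segment_ids, valid_length, label).
token_ids = mx.nd.array([[2, 1037, 2518, 3]])  # (batch_size, seq_len)
segment_ids = mx.nd.zeros((1, 4))              # token-type ids, single segment
valid_length = mx.nd.array([4])                # number of non-padding tokens
label = mx.nd.array([[1]])

# BERTClassifier takes (inputs, token_types, valid_length), so a batch in the
# new order passes straight through positionally:
#   out = model(token_ids, segment_ids, valid_length.astype('float32'))
print(token_ids.shape, segment_ids.shape, valid_length.shape, label.shape)
```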
2 changes: 1 addition & 1 deletion env/cpu/py3-master.yml
@@ -21,7 +21,7 @@ dependencies:
   - flake8==3.7.9
   - mock<3
   - sphinx==2.2.1
-  - https://apache-mxnet.s3-us-west-2.amazonaws.com/dist/2019-12-16/dist/mxnet-1.6.0b20191216-py2.py3-none-manylinux1_x86_64.whl
+  - https://apache-mxnet.s3-us-west-2.amazonaws.com/dist/2019-12-29/dist/mxnet-1.6.0b20191229-py2.py3-none-manylinux1_x86_64.whl
   - scipy==1.3.2
   - regex==2019.11.1
   - nltk==3.4.5
2 changes: 1 addition & 1 deletion env/gpu/py3-master.yml
@@ -33,7 +33,7 @@ dependencies:
   - flaky==3.6.1
   - flake8==3.7.9
   - mock<3
-  - https://apache-mxnet.s3-us-west-2.amazonaws.com/dist/2019-12-16/dist/mxnet_cu100-1.6.0b20191216-py2.py3-none-manylinux1_x86_64.whl
+  - https://apache-mxnet.s3-us-west-2.amazonaws.com/dist/2019-12-29/dist/mxnet_cu100-1.6.0b20191229-py2.py3-none-manylinux1_x86_64.whl
   - scipy==1.3.2
   - regex==2019.11.1
   - nltk==3.4.5
12 changes: 9 additions & 3 deletions scripts/bert/data/qa.py
@@ -52,7 +52,7 @@ def _worker_fn(example, transform):
     return feature


-def preprocess_dataset(dataset, transform, num_workers=8):
+def preprocess_dataset(dataset, transform, num_workers=8, for_calibration=False):
     """Use multiprocessing to perform transform for dataset.

     Parameters
@@ -77,8 +77,14 @@ def preprocess_dataset(dataset, transform, num_workers=8):
         dataset_transform.append(_data[:-1])
         dataset_len.append(_data[-1])

-    dataset = SimpleDataset(dataset_transform).transform(
-        lambda x: (x[0], x[1], x[2], x[3], x[4], x[5]))
+    if for_calibration:
+        # The Gluon calibration API expects inputs followed by one label per data entry.
+        dataset = SimpleDataset(dataset_transform).transform(
+            lambda x: (x[1], x[2], x[3], x[4]))
+    else:
+        dataset = SimpleDataset(dataset_transform).transform(
+            lambda x: (x[0], x[1], x[2], x[3], x[4], x[5]))
+
     end = time.time()
     pool.close()
     print('Done! Transform dataset costs %.2f seconds.' % (end-start))
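The `for_calibration` switch only reshapes each transformed entry: the leading example id is dropped and the entry is trimmed to the model inputs plus one trailing "label" field, which is the shape the Gluon calibration API expects. A self-contained sketch of the two layouts (the six-field entry and its field names are illustrative, not the exact SQuAD transform output):

```
# Hypothetical six-field transformed SQuAD entry:
# (example_id, input_ids, segment_ids, valid_length, start, end)
entry = (0, [2, 1037, 3], [0, 0, 0], 3, 1, 2)

full = tuple(entry[i] for i in range(6))      # training/evaluation layout
calib = tuple(entry[i] for i in range(1, 5))  # inputs ..., one trailing label

print(len(full), len(calib))  # -> 6 4
```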
5 changes: 3 additions & 2 deletions scripts/bert/data/transform.py
@@ -124,6 +124,7 @@ def __call__(self, line):
             if self.class_labels:
                 label = self._label_map[label]
             label = np.array([label], dtype=self._label_dtype)
-            return input_ids, valid_length, segment_ids, label
+            return input_ids, segment_ids, valid_length, label
         else:
-            return self._bert_xform(line)
+            input_ids, valid_length, segment_ids = self._bert_xform(line)
+            return input_ids, segment_ids, valid_length
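With this change both the labeled and unlabeled paths of `BERTDatasetTransform` return fields in one consistent order, `(input_ids, segment_ids, valid_length[, label])`, the same convention the bert.md tutorial, the batchify `Tuple`s, and the train/evaluate/test loops adopt elsewhere in this PR (see the sketch after the bert.md diff above).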
83 changes: 75 additions & 8 deletions scripts/bert/finetune_classifier.py
@@ -43,9 +43,11 @@
 import mxnet as mx
 from mxnet import gluon
 from mxnet.contrib.amp import amp
+from mxnet.contrib.quantization import quantize_net_v2
 import gluonnlp as nlp
 from gluonnlp.data import BERTTokenizer
 from gluonnlp.model import BERTClassifier, RoBERTaClassifier
+from gluonnlp.calibration import BertLayerCollector

 from data.classification import MRPCTask, QQPTask, RTETask, STSBTask, SSTTask
 from data.classification import QNLITask, CoLATask, MNLITask, WNLITask, XNLITask
@@ -178,6 +180,21 @@
                     default=None,
                     help='Whether to perform early stopping based on the metric on dev set. '
                          'The provided value is the patience. ')
+parser.add_argument('--deploy', action='store_true',
+                    help='whether to load a static model for deployment')
+parser.add_argument('--model_prefix', type=str, required=False,
+                    help='load static model as hybridblock.')
+parser.add_argument('--only_calibration', action='store_true',
+                    help='quantize model')
+parser.add_argument('--num_calib_batches', type=int, default=5,
+                    help='number of batches for calibration')
+parser.add_argument('--quantized_dtype', type=str, default='auto',
+                    choices=['auto', 'int8', 'uint8'],
+                    help='quantization destination data type for input data')
+parser.add_argument('--calib_mode', type=str, default='customize',
+                    choices=['none', 'naive', 'entropy', 'customize'],
+                    help='calibration mode used for generating calibration table '
+                         'for the quantized symbol.')

 args = parser.parse_args()
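Together these flags give the script three modes: regular fine-tuning (the default), post-training calibration via `--only_calibration`, which quantizes and exports the model instead of training it, and symbolic deployment via `--deploy --model_prefix <prefix>`, which loads a previously exported model. A hypothetical calibration invocation, reusing flags the script already has, might look like `python finetune_classifier.py --task_name MRPC --model_parameters <fine-tuned .params> --only_calibration --calib_mode customize`.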

@@ -215,6 +232,11 @@
 dataset = args.bert_dataset
 pretrained_bert_parameters = args.pretrained_bert_parameters
 model_parameters = args.model_parameters
+
+# load symbolic model
+deploy = args.deploy
+model_prefix = args.model_prefix
+
 if only_inference and not model_parameters:
     warnings.warn('model_parameters is not set. '
                   'Randomly initialized model will be used for inference.')
@@ -272,13 +294,26 @@
 model.hybridize(static_alloc=True)
 loss_function.hybridize(static_alloc=True)

+if deploy:
+    logging.info('load symbol file directly as SymbolBlock for model deployment')
+    model = mx.gluon.SymbolBlock.imports('{}-symbol.json'.format(args.model_prefix),
+                                         ['data0', 'data1', 'data2'],
+                                         '{}-0000.params'.format(args.model_prefix))
+    model.hybridize(static_alloc=True, static_shape=True)
+
 # data processing
 do_lower_case = 'uncased' in dataset
 if use_roberta:
     bert_tokenizer = nlp.data.GPT2BPETokenizer()
 else:
     bert_tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
+
+# calibration config
+only_calibration = args.only_calibration
+num_calib_batches = args.num_calib_batches
+quantized_dtype = args.quantized_dtype
+calib_mode = args.calib_mode

 def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab, pad=False):
     """Train/eval Data preparation function."""
     # transformation for data train and dev
@@ -295,13 +330,13 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab,
     train_tsv = task.dataset_train()[1]
     data_train = mx.gluon.data.SimpleDataset(list(map(trans, train_tsv)))
     data_train_len = data_train.transform(
-        lambda input_id, length, segment_id, label_id: length, lazy=False)
+        lambda input_id, segment_id, length, label_id: length, lazy=False)
     # bucket sampler for training
     pad_val = vocabulary[vocabulary.padding_token]
     batchify_fn = nlp.data.batchify.Tuple(
         nlp.data.batchify.Pad(axis=0, pad_val=pad_val),  # input
-        nlp.data.batchify.Stack(),  # length
         nlp.data.batchify.Pad(axis=0, pad_val=0),  # segment
+        nlp.data.batchify.Stack(),  # length
         nlp.data.batchify.Stack(label_dtype))  # label
     batch_sampler = nlp.data.sampler.FixedBucketSampler(
         data_train_len,
@@ -332,8 +367,9 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab,

     # batchify for data test
     test_batchify_fn = nlp.data.batchify.Tuple(
-        nlp.data.batchify.Pad(axis=0, pad_val=pad_val), nlp.data.batchify.Stack(),
-        nlp.data.batchify.Pad(axis=0, pad_val=0))
+        nlp.data.batchify.Pad(axis=0, pad_val=pad_val),
+        nlp.data.batchify.Pad(axis=0, pad_val=0),
+        nlp.data.batchify.Stack())
     # transform for data test
     test_trans = BERTDatasetTransform(tokenizer, max_len,
                                       vocab=vocab,
@@ -362,6 +398,30 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len, vocab,
 train_data, dev_data_list, test_data_list, num_train_examples = preprocess_data(
     bert_tokenizer, task, batch_size, dev_batch_size, args.max_len, vocabulary, args.pad)

+def calibration(net, dev_data_list, num_calib_batches, quantized_dtype, calib_mode):
+    """Calibration function on the dev dataset."""
+    assert len(dev_data_list) == 1, \
+        'Currently, MNLI is not supported.'
+    assert ctx == mx.cpu(), \
+        'Currently only supports CPU with MKL-DNN backend.'
+    logging.info('Now we are doing calibration on dev with %s.', ctx)
+    for _, dev_data in dev_data_list:
+        collector = BertLayerCollector(clip_min=-50, clip_max=10, logger=logging)
+        net = quantize_net_v2(net, quantized_dtype=quantized_dtype,
+                              exclude_layers=[],
+                              exclude_layers_match=['elemwise_add'],
+                              calib_data=dev_data,
+                              calib_mode=calib_mode,
+                              num_calib_examples=dev_batch_size * num_calib_batches,
+                              ctx=ctx,
+                              LayerOutputCollector=collector,
+                              logger=logging)
+    # save params
+    ckpt_name = 'model_bert_{0}_quantized_{1}'.format(task_name, calib_mode)
+    params_saved = os.path.join(output_dir, ckpt_name)
+    net.export(params_saved, epoch=0)
+    logging.info('Saving quantized model at %s', output_dir)
+

 def test(loader_test, segment):
     """Inference function on the test dataset."""
@@ -370,7 +430,7 @@ def test(loader_test, segment):
     tic = time.time()
     results = []
     for _, seqs in enumerate(loader_test):
-        input_ids, valid_length, segment_ids = seqs
+        input_ids, segment_ids, valid_length = seqs
         input_ids = input_ids.as_in_context(ctx)
         valid_length = valid_length.as_in_context(ctx).astype('float32')
         if use_roberta:
@@ -483,7 +543,7 @@ def train(metric):

         # forward and backward
         with mx.autograd.record():
-            input_ids, valid_length, segment_ids, label = seqs
+            input_ids, segment_ids, valid_length, label = seqs
             input_ids = input_ids.as_in_context(ctx)
             valid_length = valid_length.as_in_context(ctx).astype('float32')
             label = label.as_in_context(ctx)
@@ -563,7 +623,7 @@ def evaluate(loader_dev, metric, segment):
     label_list = []
     out_list = []
     for batch_id, seqs in enumerate(loader_dev):
-        input_ids, valid_length, segment_ids, label = seqs
+        input_ids, segment_ids, valid_length, label = seqs
         input_ids = input_ids.as_in_context(ctx)
         valid_length = valid_length.as_in_context(ctx).astype('float32')
         label = label.as_in_context(ctx)
@@ -598,4 +658,11 @@ def evaluate(loader_dev, metric, segment):


 if __name__ == '__main__':
-    train(task.metrics)
+    if only_calibration:
+        calibration(model,
+                    dev_data_list,
+                    num_calib_batches,
+                    quantized_dtype,
+                    calib_mode)
+    else:
+        train(task.metrics)
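For reference, a minimal sketch of loading the exported INT8 model outside this script, mirroring the `--deploy` branch above (the prefix is a hypothetical output of `calibration()` for MRPC with the default `customize` mode):

```
import mxnet as mx

# Hypothetical prefix written by calibration() above.
prefix = './output_dir/model_bert_MRPC_quantized_customize'

# Same import pattern as the --deploy branch: three inputs
# (token ids, segment ids, valid length) bound as data0/data1/data2.
net = mx.gluon.SymbolBlock.imports('{}-symbol.json'.format(prefix),
                                   ['data0', 'data1', 'data2'],
                                   '{}-0000.params'.format(prefix))
net.hybridize(static_alloc=True, static_shape=True)

# Inference then follows the evaluate() loop:
#   out = net(token_ids, segment_ids, valid_length.astype('float32'))
```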