# Scheduler for training / re-training a model using quantization-aware training, with a linear, range-based quantizer
#
# The setting here is 8-bit weights and activations. For vision models, this is usually applied to the entire model,
# without exceptions. Hence, this scheduler isn't model-specific as-is. It doesn't define any name-based overrides.
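#
# For reference, specific layers can be excluded with a name-based 'overrides' section under the quantizer.
# A minimal sketch (commented out), assuming torchvision ResNet-18 layer names 'conv1' and 'fc';
# setting the bit widths to null skips quantization for the matched layers - adapt the names to your model:
#
#    overrides:
#      conv1:
#        bits_weights: null
#        bits_activations: null
#      fc:
#        bits_weights: null
#        bits_activations: null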
#
# At the moment this quantizer will:
# * Quantize weights and biases for all convolution and FC layers
# * Quantize all ReLU activations
#
# Here's an example run for fine-tuning the ResNet-18 model from torchvision:
#
# python compress_classifier.py -a resnet18 -p 50 -b 256 ~/datasets/imagenet --epochs 10 --compress=../quantization/quant_aware_train/quant_aware_train_linear_quant.yaml --pretrained -j 22 --lr 0.0001 --vs 0 --gpu 0
#
# After 6 epochs we get:
#
# 2018-11-22 20:41:03,662 - --- validate (epoch=6)-----------
# 2018-11-22 20:41:03,663 - 50000 samples (256 per mini-batch)
# 2018-11-22 20:41:23,507 - Epoch: [6][ 50/ 195] Loss 0.896985 Top1 76.320312 Top5 93.460938
# 2018-11-22 20:41:33,633 - Epoch: [6][ 100/ 195] Loss 1.026040 Top1 74.007812 Top5 91.984375
# 2018-11-22 20:41:44,142 - Epoch: [6][ 150/ 195] Loss 1.168643 Top1 71.197917 Top5 90.041667
# 2018-11-22 20:41:51,505 - ==> Top1: 70.188 Top5: 89.376 Loss: 1.223
#
# This is an improvement compared to the pre-trained torchvision model:
# 2018-11-07 15:45:53,435 - ==> Top1: 69.758 Top5: 89.078 Loss: 1.251
#
# (Note that the command line above does not use --deterministic, so results may vary slightly)
quantizers:
  linear_quantizer:
    class: QuantAwareTrainRangeLinearQuantizer
    bits_activations: 8
    bits_weights: 8
    mode: 'SYMMETRIC'   # Can try 'ASYMMETRIC_UNSIGNED' as well
    ema_decay: 0.999    # Decay value for exponential moving average tracking of activation ranges
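    # (Roughly: tracked_stat = ema_decay * tracked_stat + (1 - ema_decay) * batch_stat,
    # updated each mini-batch from the observed min/max of the activations)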
    per_channel_wts: False   # If True, weights are quantized with a separate scale factor per output channel

policies:
    - quantizer:
        instance_name: linear_quantizer
        # For now putting a large range here, which should cover both training from scratch and resuming from some
        # pre-trained checkpoint at some unknown epoch
        starting_epoch: 0
        ending_epoch: 300
        frequency: 1
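        # (frequency: 1 applies the policy every epoch within the configured range, so with the
        # 10-epoch fine-tuning command above, quantization-aware training stays active for the whole run)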