From cd4b887d1093b8b67f0ee2f0a5c0a1beccc19a20 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 17 Sep 2020 11:51:53 -0700 Subject: [PATCH] Changes for TF/PT Rn50 (#3) * Changes for TF/PT Rn50 * Refactoring * Comments --- python/tvm/hago/_op_attrs.py | 5 + python/tvm/hago/base.py | 2 + python/tvm/hago/hardware.py | 11 ++ python/tvm/relay/frontend/__init__.py | 1 + src/hago/quantize.cc | 2 +- .../nightly/quantization/common_hago.py | 98 +++++++++++ .../nightly/quantization/test_mxnet_hago.py | 95 ++++++++++ .../nightly/quantization/test_pytorch_hago.py | 144 ++++++++++++++++ .../nightly/quantization/test_tf_hago.py | 163 ++++++++++++++++++ 9 files changed, 520 insertions(+), 1 deletion(-) create mode 100644 tests/python/nightly/quantization/common_hago.py create mode 100644 tests/python/nightly/quantization/test_mxnet_hago.py create mode 100644 tests/python/nightly/quantization/test_pytorch_hago.py create mode 100644 tests/python/nightly/quantization/test_tf_hago.py diff --git a/python/tvm/hago/_op_attrs.py b/python/tvm/hago/_op_attrs.py index 22c3ba1ce6ec..b4028ce2e180 100644 --- a/python/tvm/hago/_op_attrs.py +++ b/python/tvm/hago/_op_attrs.py @@ -223,10 +223,15 @@ def identity_scale(input_scales): return scale0 register_infer_scale("add", identity_scale) +register_infer_scale("mean", identity_scale) +register_infer_scale("nn.softmax", identity_scale) +register_infer_scale("layout_transform", identity_scale) +register_infer_scale("nn.pad", identity_scale) register_infer_scale("nn.relu", identity_scale) register_infer_scale("nn.max_pool2d", identity_scale) register_infer_scale("nn.avg_pool2d", identity_scale) register_infer_scale("nn.global_avg_pool2d", identity_scale) +register_infer_scale("nn.adaptive_avg_pool2d", identity_scale) register_infer_scale("nn.batch_flatten", identity_scale) # threshold rectify function registered for ops diff --git a/python/tvm/hago/base.py b/python/tvm/hago/base.py index c76e8dfc0ff8..8eac46e23231 100644 --- a/python/tvm/hago/base.py +++ b/python/tvm/hago/base.py @@ -312,6 +312,8 @@ def evaluate(func, dataset, ctx=tvm.cpu(), target='llvm'): runtime.run() for i in range(num_outputs): output = runtime.get_output(i).asnumpy() + if len(output.shape) == 0: + output = np.array([output]) outputs[i].append(output) return outputs diff --git a/python/tvm/hago/hardware.py b/python/tvm/hago/hardware.py index 34dda7045375..f44aece456d9 100644 --- a/python/tvm/hago/hardware.py +++ b/python/tvm/hago/hardware.py @@ -154,4 +154,15 @@ def create_accelerator_description(): hardware.add_op_desc('nn.batch_flatten', OpDesc(in_dtypes='float32', out_dtypes='float32')) hardware.add_op_desc('nn.dense', OpDesc(in_dtypes='float32', out_dtypes='float32')) hardware.add_op_desc('nn.global_avg_pool2d', OpDesc(in_dtypes='float32', out_dtypes='float32')) + + + hardware.add_op_desc('nn.pad', OpDesc(in_dtypes='float32', out_dtypes='float32')) + hardware.add_op_desc('nn.pad', OpDesc(in_dtypes='int8', out_dtypes='int8')) + hardware.add_op_desc('layout_transform', OpDesc(in_dtypes='float32', out_dtypes='float32')) + hardware.add_op_desc('layout_transform', OpDesc(in_dtypes='int8', out_dtypes='int8')) + hardware.add_op_desc('multiply', OpDesc(in_dtypes='float32', out_dtypes='float32')) + hardware.add_op_desc('subtract', OpDesc(in_dtypes='float32', out_dtypes='float32')) + hardware.add_op_desc('nn.adaptive_avg_pool2d', OpDesc(in_dtypes='float32', out_dtypes='float32')) + hardware.add_op_desc('mean', OpDesc(in_dtypes='float32', out_dtypes='float32')) + hardware.add_op_desc('nn.softmax', OpDesc(in_dtypes='float32', out_dtypes='float32')) return hardware diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py index aba9eea494be..9bc2d9a0d62c 100644 --- a/python/tvm/relay/frontend/__init__.py +++ b/python/tvm/relay/frontend/__init__.py @@ -33,3 +33,4 @@ from .tensorflow import from_tensorflow from .darknet import from_darknet from .pytorch import from_pytorch +from .tensorflow_parser import TFParser diff --git a/src/hago/quantize.cc b/src/hago/quantize.cc index 91ac8d4d52c5..dea50dc03373 100644 --- a/src/hago/quantize.cc +++ b/src/hago/quantize.cc @@ -50,7 +50,7 @@ bool SimulatedQuantizeRel(const Array& types, const auto* data = types[0].as(); CHECK(data != nullptr); - CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; + // CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; reporter->Assign(types[1], TensorType({}, DataType::Float(32))); // in_scale reporter->Assign(types[2], TensorType({}, DataType::Float(32))); // out_scale diff --git a/tests/python/nightly/quantization/common_hago.py b/tests/python/nightly/quantization/common_hago.py new file mode 100644 index 000000000000..85da6ac68393 --- /dev/null +++ b/tests/python/nightly/quantization/common_hago.py @@ -0,0 +1,98 @@ +import mxnet as mx +import tvm +from tvm import relay +from tvm import hago +from mxnet import gluon + +import logging +logging.basicConfig(level=logging.DEBUG) +def get_calibration_dataset(dataset, batch_fn, var_name, num_samples=100): + dataset.reset() + batches = [] + for i, batch in enumerate(dataset): + if i * dataset.batch_size > num_samples: + break + data, label = batch_fn(batch, [mx.cpu(0)]) + batches.append({var_name: tvm.nd.array(data[0].asnumpy()), + 'label': tvm.nd.array(label[0].asnumpy())}) + return hago.CalibrationDataset(batches) + + +################## +# Evaluation infra +################## +def eval_acc(func, dataset, batch_fn, args, var_name, target='cuda', ctx=tvm.gpu(), postprocess=None, log_interval=100): + with relay.build_config(opt_level=3): + graph, lib, params = relay.build(func, target) + # create runtime module + m = tvm.contrib.graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + + # setup evaluaiton metric + dataset.reset() + batch_size = dataset.batch_size + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + acc_top1.reset() + acc_top5.reset() + # Execute + + if args.soundness_check: + exit_at_batch = (100 + batch_size - 1)//batch_size + else: + exit_at_batch = -1 + + for i, batch in enumerate(dataset): + data, label = batch_fn(batch, [mx.cpu(0)]) + m.set_input(var_name, data[0].asnumpy()) + m.run() + out_arr = m.get_output(0).asnumpy() + if postprocess is not None: + out_arr = postprocess(out_arr) + acc_top1.update(label, [mx.nd.array(out_arr)]) + acc_top5.update(label, [mx.nd.array(out_arr)]) + + if not (i + 1) % log_interval or i == exit_at_batch: + _, top1 = acc_top1.get() + _, top5 = acc_top5.get() + nsamples = (i + 1) * batch_size + logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5) + + if i == exit_at_batch: + break + logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5) + return top1 + + +################# +# Quantize helper +################# +def quantize_hago(mod, params, calib_dataset): + qconfig = hago.qconfig(skip_conv_layers=[0], + log_file='temp.log') + + with qconfig: + graph = hago.prerequisite_optimize(mod['main'], params=params) + logging.debug('current quantize config') + logging.debug(hago.current_qconfig()) + hardware = hago.create_accelerator_description() + space = hago.generate_search_space(graph, hardware) + # tuner = hago.BatchedGreedySearchTuner(space, 'accuracy') + tuner = hago.DefaultSetting(space, 'accuracy') + ctx = tvm.cpu() + strategy, result = hago.search_quantize_strategy(graph, hardware, calib_dataset, tuner, ctx, + target='llvm') + + quantizer = hago.create_quantizer(graph, hardware, strategy) + simulated_graph = quantizer.simulate() + quantized_graph = quantizer.quantize() + logging.debug('simulated graph') + logging.debug(simulated_graph.astext(show_meta_data=False)) + logging.debug('quantize graph') + logging.debug(quantized_graph.astext(show_meta_data=False)) + # hago.inspect_graph_statistic(graph, hardware, strategy, dataset, ctx, target='llvm') + return tvm.IRModule.from_expr(quantized_graph) + + + + diff --git a/tests/python/nightly/quantization/test_mxnet_hago.py b/tests/python/nightly/quantization/test_mxnet_hago.py new file mode 100644 index 000000000000..4283df465335 --- /dev/null +++ b/tests/python/nightly/quantization/test_mxnet_hago.py @@ -0,0 +1,95 @@ +import tvm +from tvm import relay + +import numpy as np +import argparse +import os + +import mxnet as mx +from tvm import hago +from mxnet import gluon + +from common_hago import * + + +parser = argparse.ArgumentParser() +parser.add_argument("--model", default="resnet50_v1", help="model to quantize") +parser.add_argument("--soundness_check", default=False, action='store_true') +parser.add_argument("--skip_fp32", default=False, action='store_true') +parser.add_argument("--run_all", default=False, action='store_true') +args = parser.parse_args() + +batch_size = 32 +target = 'llvm -mcpu=core-avx2' +ctx = tvm.context(target) + +##################### +# Dataset prepartions +##################### + +def get_val_data(img_size, + rec_val, + batch_size, + num_workers=4): + rec_val = os.path.expanduser(rec_val) + mean_rgb = [123.68, 116.779, 103.939] + std_rgb = [58.393, 57.12, 57.375] + def batch_fn(batch, ctx): + data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) + label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) + return data, label + + val_data = mx.io.ImageRecordIter( + path_imgrec = rec_val, + preprocess_threads = num_workers, + shuffle = True, + seed = 0, + batch_size = batch_size, + resize = 256, + data_shape = (3, img_size, img_size), + mean_r = mean_rgb[0], + mean_g = mean_rgb[1], + mean_b = mean_rgb[2], + std_r = std_rgb[0], + std_g = std_rgb[1], + std_b = std_rgb[2], + ) + return val_data, batch_fn + +############################################################################### +# Load the model +# ---------------- +def get_model(model_name): + gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True) + img_size = 299 if model_name == 'inceptionv3' else 224 + data_shape = (batch_size, 3, img_size, img_size) + mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) + return mod, params + +def main(): + val_path = '/home/ubuntu/tensorflow_datasets/downloads/manual/imagenet2012/val.rec' + if args.run_all: + models = ['resnet50_v1', 'inceptionv3', 'mobilenetv2_1.0', 'mobilenet1.0', 'resnet18_v1', + 'densenet161', 'vgg16'] + else: + models = [args.model] + for model_name in models: + img_size = 299 if model_name == 'inceptionv3' else 224 + val_data, batch_fn = get_val_data(img_size, val_path, batch_size) + + if not args.skip_fp32: + fp32_mod, params = get_model(model_name) + func = hago.prerequisite_optimize(fp32_mod['main'], params=params) + acc = eval_acc(func, val_data, batch_fn, args, var_name='data', target=target, ctx=ctx) + print("fp32_accuracy", model_name, acc, sep=',') + + # Quantize + calib_dataset = get_calibration_dataset(val_data, batch_fn, var_name='data') + fp32_mod, params = get_model(model_name) + quantized_func = quantize_hago(fp32_mod, params, calib_dataset) + acc = eval_acc(quantized_func, val_data, batch_fn, args, var_name='data', target=target, ctx=ctx) + print("quantized_accuracy", model_name, acc, sep=',') + + +if __name__ == '__main__': + main() diff --git a/tests/python/nightly/quantization/test_pytorch_hago.py b/tests/python/nightly/quantization/test_pytorch_hago.py new file mode 100644 index 000000000000..c6023b7ba667 --- /dev/null +++ b/tests/python/nightly/quantization/test_pytorch_hago.py @@ -0,0 +1,144 @@ +import tvm +from tvm import relay + +import numpy as np +import argparse + +import torch +from torch.nn import Module +import torchvision +from torchvision import transforms +import os + +import mxnet as mx +from tvm import hago +from mxnet import gluon + +from common_hago import * + +parser = argparse.ArgumentParser() +parser.add_argument("--model", default="resnet50_v1", help="model to quantize") +parser.add_argument("--soundness_check", default=False, action='store_true') +parser.add_argument("--skip_fp32", default=False, action='store_true') +parser.add_argument("--run_all", default=False, action='store_true') +args = parser.parse_args() + +batch_size = 32 +target = 'llvm -mcpu=core-avx2' +ctx = tvm.context(target) + +##################### +# Dataset prepartions +##################### + +def get_val_data(img_size, + rec_val, + batch_size, + num_workers=4): + rec_val = os.path.expanduser(rec_val) + mean_rgb = [255 * x for x in [0.485, 0.456, 0.406]] + std_rgb = [255 * x for x in [0.229, 0.224, 0.225]] + def batch_fn(batch, ctx): + data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) + label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) + return data, label + + val_data = mx.io.ImageRecordIter( + path_imgrec = rec_val, + preprocess_threads = num_workers, + shuffle = True, + seed = 0, + batch_size = batch_size, + resize = 256, + data_shape = (3, img_size, img_size), + mean_r = mean_rgb[0], + mean_g = mean_rgb[1], + mean_b = mean_rgb[2], + std_r = std_rgb[0], + std_g = std_rgb[1], + std_b = std_rgb[2], + ) + return val_data, batch_fn + +############################################################################### +# Load the model from torchvision +# ---------------- +def load_model(model_name): + """Given a model name, returns a model as well as an example input.""" + if hasattr(torchvision.models, model_name): + with torch.no_grad(): + if model_name.startswith("inception"): + height = width = 299 + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + else: + height = width = 224 + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + input_shape = [batch_size, 3, height, width] + input_data = torch.randn(input_shape).float() + for channel in range(3): + input_data[:, channel] -= mean[channel] + input_data[:, channel] /= std[channel] + model = getattr(torchvision.models, model_name)(pretrained=True) + model = model.float().eval() + return model, [input_data] + try: + import pretrainedmodels + if hasattr(pretrainedmodels, model_name): + return load_pretrainedmodels(model_name) + except ModuleNotFoundError: + raise ModuleNotFoundError("Please install pretrainedmodels.pytorch") + raise RuntimeError("Model not supported") + +def get_model(model_name): + torch.set_grad_enabled(False) + baseline_model, baseline_input = load_model(model_name) + + trace = torch.jit.trace(baseline_model, baseline_input) + if isinstance(baseline_model, torch.nn.Module): + trace = trace.float().eval() + trace = trace.cpu() + + global input_names + input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] + input_shapes = list(zip(input_names, + [inp.shape for inp in baseline_input])) + mod, params = relay.frontend.from_pytorch(trace, input_shapes) + return mod, params + + +############# +# Test models +############# +def main(): + val_path = '/home/ubuntu/tensorflow_datasets/downloads/manual/imagenet2012/val.rec' + if args.run_all: + models = ['resnet50', 'inception_v3', 'mobilenet_v2', 'resnet18', + 'densenet161', 'googlenet', 'vgg16'] + else: + models = [args.model] + for model_name in models: + height = 224 + if model_name.startswith("inception"): + height = 299 + + val_data, batch_fn = get_val_data(height, val_path, batch_size) + + # Original + if not args.skip_fp32: + fp32_mod, params = get_model(model_name) + func = hago.prerequisite_optimize(fp32_mod['main'], params=params) + acc = eval_acc(func, val_data, batch_fn, args, var_name=input_names[0], target=target, ctx=ctx) + print("fp32_accuracy", model_name, acc, sep=',') + + # Quantize + calib_dataset = get_calibration_dataset(val_data, batch_fn, var_name=input_names[0]) + fp32_mod, params = get_model(model_name) + quantized_func = quantize_hago(fp32_mod, params, calib_dataset) + acc = eval_acc(quantized_func, val_data, batch_fn, args, var_name=input_names[0], target=target, ctx=ctx) + print("quantized_accuracy", model_name, acc, sep=',') + + +if __name__ == '__main__': + main() diff --git a/tests/python/nightly/quantization/test_tf_hago.py b/tests/python/nightly/quantization/test_tf_hago.py new file mode 100644 index 000000000000..0b7db220abcb --- /dev/null +++ b/tests/python/nightly/quantization/test_tf_hago.py @@ -0,0 +1,163 @@ +import tvm +from tvm import relay + +import numpy as np +import argparse +import os + +import mxnet as mx +from tvm import hago +from mxnet import gluon + +from common_hago import * + +try: + # %tensorflow_version only exists in Colab. + import tensorflow.compat.v2 as tf +except Exception: + pass +# tf.enable_v2_behavior() +import tensorflow_hub as hub + + +parser = argparse.ArgumentParser() +parser.add_argument("--model", default="resnet50", help="model to quantize") +parser.add_argument("--soundness_check", default=False, action='store_true') +parser.add_argument("--skip_fp32", default=False, action='store_true') +parser.add_argument("--run_all", default=False, action='store_true') +args = parser.parse_args() + +batch_size = 32 +target = 'llvm -mcpu=core-avx2' +ctx = tvm.context(target) + + +############################## +# Original FP32 TF/Keras model +############################## +tf_hub_links = { + "resnet50" : "https://tfhub.dev/tensorflow/resnet_50/classification/1", + "resnet_v2_50" : "https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4", + "mobilenet_v1" : "https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/4", + "mobilenet_v2" : "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4", + "inception_v1" : "https://tfhub.dev/google/imagenet/inception_v1/classification/4", + "inception_v2" : "https://tfhub.dev/google/imagenet/inception_v2/classification/4", + "inception_v3" : "https://tfhub.dev/google/imagenet/inception_v3/classification/4", +} + + +##################### +# Dataset prepartions +##################### + +def get_val_data(img_size, + rec_val, + batch_size, + num_workers=4): + rec_val = os.path.expanduser(rec_val) + mean_rgb = [123.68, 116.779, 103.939] + std_rgb = [58.393, 57.12, 57.375] + def batch_fn(batch, ctx): + data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) + data0 = data[0].asnumpy() + data0 = np.transpose(data0, axes=[0, 2, 3, 1]) + data = [mx.nd.array(data0)] + label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) + return data, label + + val_data = mx.io.ImageRecordIter( + path_imgrec = rec_val, + preprocess_threads = num_workers, + shuffle = True, + seed = 0, + batch_size = batch_size, + resize = 256, + data_shape = (3, img_size, img_size), + scale = 1.0/255.0, + # mean_r = mean_rgb[0], + # mean_g = mean_rgb[1], + # mean_b = mean_rgb[2], + # std_r = std_rgb[0], + # std_g = std_rgb[1], + # std_b = std_rgb[2], + ) + return val_data, batch_fn + +############################################################################### +# Load the model +# ---------------- +def get_model(model_name): + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 + model = tf.keras.Sequential([ + hub.KerasLayer(tf_hub_links[model_name], output_shape=[1001]) + ]) + img_size = 299 if model_name == 'inceptionv3' else 224 + np_image = np.random.rand(batch_size, img_size, img_size, 3).astype('float32') + model._set_inputs(np_image) + + + # Convert Keras model to ConcreteFunction + full_model = tf.function(lambda x: model(x)) + full_model = full_model.get_concrete_function( + tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype, name="data")) + + frozen_func = convert_variables_to_constants_v2(full_model) + frozen_func.graph.as_graph_def() + + tf.io.write_graph(graph_or_graph_def=frozen_func.graph, + logdir="./.tf_saved_model/" + model_name, + name="frozen_graph.pb", + as_text=False) + + parser = tvm.relay.frontend.TFParser("./.tf_saved_model/" + + model_name + "/frozen_graph.pb") + graph_def = parser.parse() + mod, params = relay.frontend.from_tensorflow(graph_def, + shape={"data": (batch_size, img_size, img_size, 3)}) + + # We assume our model's heavily-layout sensitive operators only consist of nn.conv2d + desired_layouts = {'nn.conv2d': ['NCHW', 'default']} + + # Convert the layout to NCHW + # RemoveUnunsedFunctions is used to clean up the graph. + seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(), + relay.transform.ConvertLayout(desired_layouts)]) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + return mod, params + +def ignore_first(tensor): + if tensor.shape[1] == 1001: + tensor = tensor[:, 1:] + return tensor + +def main(): + val_path = '/home/ubuntu/tensorflow_datasets/downloads/manual/imagenet2012/val.rec' + if args.run_all: + models = tf_hub_links.keys() + else: + models = [args.model] + for model_name in models: + img_size = 299 if model_name == 'inceptionv3' else 224 + postprocess = ignore_first if 'resnet' not in model_name else None + val_data, batch_fn = get_val_data(img_size, val_path, batch_size) + + # Original + if not args.skip_fp32: + fp32_mod, params = get_model(model_name) + func = hago.prerequisite_optimize(fp32_mod['main'], params=params) + acc = eval_acc(func, val_data, batch_fn, args, var_name='data', target=target, ctx=ctx, + postprocess=postprocess) + print("fp32_accuracy", model_name, acc, sep=',') + + # Quantize + calib_dataset = get_calibration_dataset(val_data, batch_fn, var_name='data') + fp32_mod, params = get_model(model_name) + quantized_func = quantize_hago(fp32_mod, params, calib_dataset) + acc = eval_acc(quantized_func, val_data, batch_fn, args, var_name='data', target=target, + ctx=ctx, postprocess=postprocess) + print("quantized_accuracy", model_name, acc, sep=',') + + +if __name__ == '__main__': + main()