diff --git a/python/tvm/driver/tvmc/frontends.py b/python/tvm/driver/tvmc/frontends.py
index 16e6c8eb966e..0488223c782f 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -54,7 +54,7 @@ def suffixes():
         """File suffixes (extensions) used by this frontend"""
 
     @abstractmethod
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         """Load a model from a given path.
 
         Parameters
@@ -101,7 +101,7 @@ def name():
     def suffixes():
         return ["h5"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0103
         tf, keras = import_keras()
 
@@ -130,7 +130,8 @@ def load(self, path, shape_dict=None):
         input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
         if shape_dict is not None:
             input_shapes.update(shape_dict)
-        return relay.frontend.from_keras(model, input_shapes, layout="NHWC")
+        kwargs.setdefault("layout", "NHWC")
+        return relay.frontend.from_keras(model, input_shapes, **kwargs)
 
     def is_sequential_p(self, model):
         _, keras = import_keras()
@@ -158,14 +159,14 @@ def name():
     def suffixes():
         return ["onnx"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import onnx
 
         # pylint: disable=E1101
         model = onnx.load(path)
 
-        return relay.frontend.from_onnx(model, shape=shape_dict)
+        return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)
 
 
 class TensorflowFrontend(Frontend):
@@ -179,7 +180,7 @@ def name():
     def suffixes():
         return ["pb"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import tensorflow as tf
         import tvm.relay.testing.tf as tf_testing
@@ -192,7 +193,7 @@ def load(self, path, shape_dict=None):
             graph_def = tf_testing.ProcessGraphDefParam(graph_def)
 
         logger.debug("parse TensorFlow model and convert into Relay computation graph")
-        return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+        return relay.frontend.from_tensorflow(graph_def, shape=shape_dict, **kwargs)
 
 
 class TFLiteFrontend(Frontend):
@@ -206,7 +207,7 @@ def name():
     def suffixes():
         return ["tflite"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import tflite.Model as model
 
@@ -229,7 +230,7 @@ def load(self, path, shape_dict=None):
             raise TVMCException("input file not tflite version 3")
 
         logger.debug("parse TFLite model and convert into Relay computation graph")
-        mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
+        mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
         return mod, params
 
 
@@ -245,7 +246,7 @@ def suffixes():
         # Torch Script is a zip file, but can be named pth
         return ["pth", "zip"]
 
-    def load(self, path, shape_dict=None):
+    def load(self, path, shape_dict=None, **kwargs):
         # pylint: disable=C0415
         import torch
 
@@ -259,7 +260,7 @@ def load(self, path, shape_dict=None):
         input_shapes = list(shape_dict.items())
 
         logger.debug("parse Torch model and convert into Relay computation graph")
-        return relay.frontend.from_pytorch(traced_model, input_shapes)
+        return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs)
 
 
 ALL_FRONTENDS = [
@@ -339,7 +340,7 @@ def guess_frontend(path):
     raise TVMCException("failed to infer the model format. Please specify --model-format")
 
 
-def load_model(path, model_format=None, shape_dict=None):
+def load_model(path, model_format=None, shape_dict=None, **kwargs):
     """Load a model from a supported framework and convert it into
     an equivalent relay representation.
 
@@ -367,6 +368,6 @@ def load_model(path, model_format=None, shape_dict=None):
     else:
         frontend = guess_frontend(path)
 
-    mod, params = frontend.load(path, shape_dict)
+    mod, params = frontend.load(path, shape_dict, **kwargs)
 
     return mod, params
diff --git a/tests/python/driver/tvmc/test_frontends.py b/tests/python/driver/tvmc/test_frontends.py
index b41f4c4dff2d..5a63c5c47933 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -115,26 +115,34 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
     assert "_param_1" in params.keys()
 
 
-def test_load_model__keras(keras_resnet50):
+@pytest.mark.parametrize("load_model_kwargs", [{}, {"layout": "NCHW"}])
+def test_load_model__keras(keras_resnet50, load_model_kwargs):
     # some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
     pytest.importorskip("tensorflow")
 
-    mod, params = tvmc.frontends.load_model(keras_resnet50)
+    mod, params = tvmc.frontends.load_model(keras_resnet50, **load_model_kwargs)
     assert type(mod) is IRModule
     assert type(params) is dict
     ## check whether one known value is part of the params dict
     assert "_param_1" in params.keys()
 
 
+def verify_load_model__onnx(model, **kwargs):
+    mod, params = tvmc.frontends.load_model(model, **kwargs)
+    assert type(mod) is IRModule
+    assert type(params) is dict
+    return mod, params
+
+
 def test_load_model__onnx(onnx_resnet50):
     # some CI environments wont offer onnx, so skip in case it is not present
     pytest.importorskip("onnx")
-
-    mod, params = tvmc.frontends.load_model(onnx_resnet50)
-    assert type(mod) is IRModule
-    assert type(params) is dict
-    ## check whether one known value is part of the params dict
+    mod, params = verify_load_model__onnx(onnx_resnet50)
+    # check whether one known value is part of the params dict
    assert "resnetv24_batchnorm0_gamma" in params.keys()
 
+    mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
+    # check that the parameter dict is empty, implying that they have been folded into constants
+    assert params == {}
 
 
 def test_load_model__pb(pb_mobilenet_v1_1_quant):
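
For reference, a minimal usage sketch of the `**kwargs` plumbing this patch introduces. The filename `resnet50.onnx` is hypothetical; the keyword arguments mirror the ones exercised by the new tests above.

from tvm.driver import tvmc

# Extra keyword arguments to load_model() are now forwarded to the matching
# relay converter (relay.frontend.from_onnx for an .onnx file).
mod, params = tvmc.frontends.load_model("resnet50.onnx", freeze_params=True)

# With freeze_params=True, from_onnx folds the model weights into the graph
# as constants, so the returned params dict is empty (see the new ONNX test).
assert params == {}

Note the design choice in the Keras frontend: `kwargs.setdefault("layout", "NHWC")` preserves the old default layout while letting callers override it, e.g. `tvmc.frontends.load_model("resnet50.h5", layout="NCHW")` as in the parametrized Keras test.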