
[TVMC] Allow optional arguments to be passed to importers #7674

Merged · Mar 18, 2021 · 3 commits (diff below shows changes from 2 commits)
30 changes: 17 additions & 13 deletions python/tvm/driver/tvmc/frontends.py
@@ -54,7 +54,7 @@ def suffixes():
"""File suffixes (extensions) used by this frontend"""

@abstractmethod
-def load(self, path, shape_dict=None):
+def load(self, path, shape_dict=None, **kwargs):
"""Load a model from a given path.

Parameters
@@ -101,7 +101,7 @@ def name():
def suffixes():
return ["h5"]

-def load(self, path, shape_dict=None):
+def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0103
tf, keras = import_keras()

@@ -130,7 +130,9 @@ def load(self, path, shape_dict=None):
input_shapes = {name: x.shape for (name, x) in zip(model.input_names, inputs)}
if shape_dict is not None:
input_shapes.update(shape_dict)
-return relay.frontend.from_keras(model, input_shapes, layout="NHWC")
+layout = kwargs.get("layout", "NHWC")
+kwargs["layout"] = layout
+return relay.frontend.from_keras(model, input_shapes, **kwargs)

def is_sequential_p(self, model):
_, keras = import_keras()
@@ -158,14 +160,14 @@ def name():
def suffixes():
return ["onnx"]

-def load(self, path, shape_dict=None):
+def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import onnx

# pylint: disable=E1101
model = onnx.load(path)

-return relay.frontend.from_onnx(model, shape=shape_dict)
+return relay.frontend.from_onnx(model, shape=shape_dict, **kwargs)


class TensorflowFrontend(Frontend):
@@ -179,7 +181,7 @@ def name():
def suffixes():
return ["pb"]

-def load(self, path, shape_dict=None):
+def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tensorflow as tf
import tvm.relay.testing.tf as tf_testing
@@ -192,7 +194,9 @@ def load(self, path, shape_dict=None):
graph_def = tf_testing.ProcessGraphDefParam(graph_def)

logger.debug("parse TensorFlow model and convert into Relay computation graph")
-return relay.frontend.from_tensorflow(graph_def, shape=shape_dict)
+return relay.frontend.from_tensorflow(
+graph_def, shape=shape_dict, **kwargs
+)


class TFLiteFrontend(Frontend):
@@ -206,7 +210,7 @@ def name():
def suffixes():
return ["tflite"]

-def load(self, path, shape_dict=None):
+def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import tflite.Model as model

@@ -229,7 +233,7 @@ def load(self, path, shape_dict=None):
raise TVMCException("input file not tflite version 3")

logger.debug("parse TFLite model and convert into Relay computation graph")
-mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict)
+mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, **kwargs)
return mod, params


@@ -245,7 +249,7 @@ def suffixes():
# Torch Script is a zip file, but can be named pth
return ["pth", "zip"]

-def load(self, path, shape_dict=None):
+def load(self, path, shape_dict=None, **kwargs):
# pylint: disable=C0415
import torch

@@ -259,7 +263,7 @@ def load(self, path, shape_dict=None):
input_shapes = list(shape_dict.items())

logger.debug("parse Torch model and convert into Relay computation graph")
-return relay.frontend.from_pytorch(traced_model, input_shapes)
+return relay.frontend.from_pytorch(traced_model, input_shapes, **kwargs)


ALL_FRONTENDS = [
@@ -339,7 +343,7 @@ def guess_frontend(path):
raise TVMCException("failed to infer the model format. Please specify --model-format")


-def load_model(path, model_format=None, shape_dict=None):
+def load_model(path, model_format=None, shape_dict=None, **kwargs):
"""Load a model from a supported framework and convert it
into an equivalent relay representation.

@@ -367,6 +371,6 @@ def load_model(path, model_format=None, shape_dict=None):
else:
frontend = guess_frontend(path)

-mod, params = frontend.load(path, shape_dict)
+mod, params = frontend.load(path, shape_dict, **kwargs)

return mod, params
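With this change, frontend-specific options can be forwarded through the Python API down to the underlying relay.frontend converter. A minimal usage sketch (the model file names here are hypothetical; freeze_params and layout are existing options of relay.frontend.from_onnx and relay.frontend.from_keras):

from tvm.driver import tvmc

# ONNX: forward freeze_params so that weights are folded into the graph
# as constants, leaving the returned params dict empty (see tests below).
mod, params = tvmc.frontends.load_model("resnet50.onnx", freeze_params=True)

# Keras: override the "NHWC" default layout applied in KerasFrontend.load.
mod, params = tvmc.frontends.load_model("resnet50.h5", layout="NCHW")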
26 changes: 19 additions & 7 deletions tests/python/driver/tvmc/test_frontends.py
@@ -115,26 +115,38 @@ def test_load_model__tflite(tflite_mobilenet_v1_1_quant):
assert "_param_1" in params.keys()


-def test_load_model__keras(keras_resnet50):
+def verify_load_model__keras(model, **kwargs):
# some CI environments wont offer TensorFlow/Keras, so skip in case it is not present
pytest.importorskip("tensorflow")

-mod, params = tvmc.frontends.load_model(keras_resnet50)
+mod, params = tvmc.frontends.load_model(model, **kwargs)
assert type(mod) is IRModule
assert type(params) is dict
## check whether one known value is part of the params dict
assert "_param_1" in params.keys()


-def test_load_model__onnx(onnx_resnet50):
-# some CI environments wont offer onnx, so skip in case it is not present
-pytest.importorskip("onnx")
+def test_load_model__keras(keras_resnet50):
+verify_load_model__keras(keras_resnet50)
+verify_load_model__keras(keras_resnet50, layout="NCHW")


-mod, params = tvmc.frontends.load_model(onnx_resnet50)
+def verify_load_model__onnx(model, **kwargs):
Contributor: As in the comment above, I would probably suggest using @pytest.mark.parametrize here as well and not having the separate verify_load_model__onnx helper function.

Contributor Author: The onnx test is a little bit more complex because of the things being tested, so I think I'm going to leave it as is and try to not break things.
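For reference, the suggested parametrize approach could look roughly like this for the keras test (a sketch of the reviewer's idea, not code from this PR; it assumes the imports already present in test_frontends.py):

@pytest.mark.parametrize("load_kwargs", [{}, {"layout": "NCHW"}])
def test_load_model__keras_parametrized(keras_resnet50, load_kwargs):
    # skip when TensorFlow/Keras is unavailable, as in the existing tests
    pytest.importorskip("tensorflow")
    mod, params = tvmc.frontends.load_model(keras_resnet50, **load_kwargs)
    assert type(mod) is IRModule
    assert type(params) is dict
    # check whether one known value is part of the params dict
    assert "_param_1" in params.keys()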

+mod, params = tvmc.frontends.load_model(model, **kwargs)
assert type(mod) is IRModule
assert type(params) is dict
## check whether one known value is part of the params dict
+return mod, params


+def test_load_model__onnx(onnx_resnet50):
+# some CI environments wont offer onnx, so skip in case it is not present
+pytest.importorskip("onnx")
+mod, params = verify_load_model__onnx(onnx_resnet50)
# check whether one known value is part of the params dict
assert "resnetv24_batchnorm0_gamma" in params.keys()
+mod, params = verify_load_model__onnx(onnx_resnet50, freeze_params=True)
+# check that the parameter dict is empty, implying that they have been folded into constants
+assert params == {}


def test_load_model__pb(pb_mobilenet_v1_1_quant):