Skip to content

Commit

Permalink
[Torch] Add an option to make imported models compatible with the Rel…
Browse files Browse the repository at this point in the history
…ay text parser (#9015)

* [Torch] Add an option to make imported models compatible with the
Relay text parser

* py format
  • Loading branch information
masahi authored Sep 15, 2021
1 parent 354019d commit f350ea6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 188 deletions.
36 changes: 29 additions & 7 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3400,8 +3400,8 @@ def _getattr_attr_name(node):
return attr_name


def _getattr_full_name(getattrs):
return ".".join([_getattr_attr_name(node) for node in getattrs])
def _getattr_full_name(getattrs, sep="."):
return sep.join([_getattr_attr_name(node) for node in getattrs])


def _get_pytorch_value_type(typ, default_dtype="float32"):
Expand Down Expand Up @@ -3657,7 +3657,7 @@ def terminate(users):
return get_use_chains(root_getattr_node, terminate)


def convert_params(graph, state_dict):
def convert_params(graph, state_dict, use_parser_friendly_name=False):
"""
Return Relay vars and TVM NDArrays for input parameters
A chain of prim::GetAttr nodes is processed one at a time
Expand All @@ -3668,6 +3668,7 @@ def convert_params(graph, state_dict):
packed_param_map = {}
vars_by_name = {}
seen = set()
attr_name_sep = "_" if use_parser_friendly_name else "."

for node in getattr_nodes:
if _get_output_name(node) in seen:
Expand All @@ -3676,7 +3677,7 @@ def convert_params(graph, state_dict):
for getattrs in get_attr_chains(node):
seen.update(map(_get_output_name, getattrs))

full_attr = _getattr_full_name(getattrs)
full_attr = _getattr_full_name(getattrs, attr_name_sep)
full_attr_node_name = _get_output_name(getattrs[-1])

if full_attr.endswith("_packed_params"): # for quantized models
Expand Down Expand Up @@ -3706,7 +3707,13 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)


def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dtype="float32"):
def from_pytorch(
script_module,
input_infos,
custom_convert_map=None,
default_dtype="float32",
use_parser_friendly_name=False,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
Expand All @@ -3729,6 +3736,15 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
custom_convert_map : Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
default_type : str
The default dtype to use when type information is not provided by PyTorch.
use_parser_friendly_name : bool
When True, replace '.' with `_' in a original parameter name.
The Relay text parser treats a variable name followed by a period as a tuple element access,
so a variable name like "dense.weight" cannot be parsed correctly.
Use this option when you want to run the AnnotateSpans pass on the imported module.
Returns
-------
mod : tvm.IRModule
Expand Down Expand Up @@ -3758,7 +3774,13 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
outputs = _get_relay_input_vars(
graph, input_infos, prelude, default_dtype=default_dtype, is_module=is_module
)
param_vars, tensors, packed_param_map = convert_params(graph, params)

if use_parser_friendly_name:
new_names = [key.replace(".", "_") for key in params.keys()]
params = dict(zip(new_names, params.values()))

param_vars, tensors, packed_param_map = convert_params(graph, params, use_parser_friendly_name)

tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

outputs.update(param_vars)
Expand All @@ -3778,7 +3800,7 @@ def from_pytorch(script_module, input_infos, custom_convert_map=None, default_dt
# ListConstruct kept original python list. Convert to tuple.
ret = _expr.Tuple(ret)

# Separate data inputs and parameters to make sure data inputs are always in the beginning.
# Separate data inputs and parameters to make sure data inputs come first.
func_args = []
data_inputs = []
for arg in _analysis.free_vars(ret):
Expand Down
192 changes: 11 additions & 181 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3937,185 +3937,15 @@ def forward(self, x):
verify_model(Flip(axis=-1), input_data=input)


def test_annotate_span():
model = torchvision.models.resnet18().eval()
inp = torch.randn([1, 3, 224, 224])
trace = torch.jit.trace(model, inp).eval()
mod, params = relay.frontend.from_pytorch(
trace, [("input", inp.shape)], use_parser_friendly_name=True
)
relay.transform.AnnotateSpans()(mod)


if __name__ == "__main__":
# some structural tests
test_forward_traced_function()
test_forward_dtypes()
test_weight_names()
test_duplicate_weight_use()

# Single operator tests
test_forward_pixel_shuffle()
test_forward_add()
test_forward_subtract()
test_forward_multiply()
test_forward_matmul()
test_forward_rsub()
test_forward_onehot()
test_forward_embedding()
test_forward_reshape()
test_forward_reciprocal()
test_forward_repeat()
test_forward_repeat_interleave()
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()
test_forward_reduce_sum()
test_forward_reduce_prod()
test_forward_argmin()
test_forward_argmax()
test_forward_norm()
test_forward_frobenius_norm()
test_forward_std()
test_forward_variance()
test_forward_relu()
test_forward_prelu()
test_forward_leakyrelu()
test_forward_elu()
test_forward_celu()
test_forward_gelu()
test_forward_selu()
test_forward_log_sigmoid()
test_forward_adaptiveavgpool()
test_forward_maxpool2d()
test_forward_maxpool1d()
test_forward_maxpool3d()
test_forward_hardtanh()
test_forward_conv()
test_forward_conv_transpose()
test_forward_threshold()
test_forward_contiguous()
test_forward_batchnorm()
test_forward_instancenorm()
test_forward_layernorm()
test_forward_groupnorm()
test_forward_transpose()
test_forward_size()
test_forward_view()
test_forward_select()
test_forward_take()
test_forward_topk()
test_forward_where()
test_forward_addcdiv()
test_forward_addcmul()
test_forward_true_divide()
test_forward_is_floating_point()
test_forward_clone()
test_forward_softplus()
test_forward_softsign()
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
test_forward_linear()
test_forward_avgpool1d()
test_forward_avgpool2d()
test_forward_avgpool3d()
test_forward_dropout()
test_forward_slice()
test_forward_narrow()
test_forward_mean()
test_forward_expand()
test_forward_pow()
test_forward_unary()
test_forward_clamp()
test_forward_clamp_()
test_forward_logical_not()
test_forward_bitwise_not()
test_forward_bitwise_xor()
test_forward_logical_xor()
test_forward_isfinite()
test_forward_isnan()
test_forward_isinf()
test_forward_ones()
test_forward_ones_like()
test_forward_zeros()
test_forward_zeros_like()
test_forward_full()
test_forward_full_like()
test_forward_linspace()
test_forward_arange()
test_forward_mesh_grid()
test_forward_chunk()
test_forward_split()
test_forward_gather()
test_upsample()
test_forward_upsample3d()
test_forward_nms()
test_forward_roi_align()
test_to()
test_flatten()
test_type_as()
test_forward_functional_pad()
test_forward_zero_pad2d()
test_forward_constant_pad1d()
test_forward_constant_pad2d()
test_forward_constant_pad3d()
test_forward_reflection_pad1d()
test_forward_reflection_pad2d()
test_forward_replication_pad1d()
test_forward_replication_pad2d()
test_forward_replication_pad3d()
test_adaptive_pool3d()
test_conv3d()
test_conv3d_transpose()
test_forward_index()
test_min_max()
test_logsumexp()
test_stack()
test_stack_dynamic()
test_forward_unbind()
test_forward_nonzero()
test_forward_scatter()
test_forward_index_put()
test_numel()
test_bincount()
test_cumsum()
test_masked_fill()
test_transformer()
test_sort()
test_argsort()
test_logical_and()
test_masked_select()
test_unique()
test_hard_swish()
test_hard_sigmoid()
test_forward_nll_loss()
test_forward_flip()

# Model tests
test_resnet18()
test_squeezenet1_0()
test_squeezenet1_1()
test_densenet121()
# disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug
# See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756
# test_inception_v3()
test_googlenet()
test_mnasnet0_5()
test_mobilenet_v2()

test_custom_conversion_map()

test_segmentation_models()
test_3d_models()

# Quantization test
from qnn_test import test_quantized_imagenet, test_quantized_modules

test_quantized_modules()
test_quantized_imagenet()

# Test simple conditionals and loop
test_control_flow()
test_simple_rnn()

# More complex recurrent models
from test_lstm import test_custom_lstm

test_custom_lstm()

# Test bert model
test_forward_pretrained_bert_base_uncased()

# Test convert torch script(jit) with specific inputs' types
test_convert_torch_script_with_input_types()
pytest.main([__file__])

0 comments on commit f350ea6

Please sign in to comment.