diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 7cf856c767fa..817f6cf8db3c 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -2265,3 +2265,100 @@ def convert_sum(node, **kwargs): name=name ) return [node] + +@mx_op.register("broadcast_logical_and") +def convert_logical_and(node, **kwargs): + """Map MXNet's logical and operator attributes to onnx's Add operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + and_node = onnx.helper.make_node( + "And", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [and_node] + +@mx_op.register("broadcast_logical_or") +def convert_logical_or(node, **kwargs): + """Map MXNet's logical or operator attributes to onnx's Or operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + or_node = onnx.helper.make_node( + "Or", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [or_node] + +@mx_op.register("broadcast_logical_xor") +def convert_logical_xor(node, **kwargs): + """Map MXNet's logical xor operator attributes to onnx's Xor operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_a_id = kwargs["index_lookup"][inputs[0][0]] + input_node_b_id = kwargs["index_lookup"][inputs[1][0]] + + input_node_a = proc_nodes[input_node_a_id].name + input_node_b = proc_nodes[input_node_b_id].name + + xor_node = onnx.helper.make_node( + "Xor", + [input_node_a, input_node_b], + [name], + name=name, + ) + + return [xor_node] + +@mx_op.register("logical_not") +def convert_logical_not(node, **kwargs): + """Map MXNet's logical not operator attributes to onnx's Not operator + and return the created node. + """ + onnx = import_onnx_modules() + name = node["name"] + proc_nodes = kwargs["proc_nodes"] + inputs = node["inputs"] + + input_node_id = kwargs["index_lookup"][inputs[0][0]] + input_node = proc_nodes[input_node_id].name + + node = onnx.helper.make_node( + "Not", + [input_node], + [name], + name=name + ) + + return [node] diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py index 9f91369d667e..1e560c98cfe6 100644 --- a/tests/python-pytest/onnx/export/mxnet_export_test.py +++ b/tests/python-pytest/onnx/export/mxnet_export_test.py @@ -56,6 +56,19 @@ 'https://s3.amazonaws.com/onnx-mxnet/model-zoo/inception_v2.tar.gz' } +def get_int_inputs(interval, shape): + """Helper to get integer input of given shape and range""" + assert len(interval) == len(shape) + inputs = [] + input_tensors = [] + for idx in range(len(interval)): + low, high = interval[idx] + inputs.append(np.random.randint(low, high, size=shape[idx]).astype("float32")) + input_tensors.append(helper.make_tensor_value_info("input"+str(idx+1), + TensorProto.FLOAT, shape=shape[idx])) + + return inputs, input_tensors + def get_test_files(name): """Extract tar file and returns model path and input, output data""" tar_name = download(URLS.get(name), dirname=CURR_PATH.__str__()) @@ -238,6 +251,70 @@ def test_square(): npt.assert_almost_equal(result, numpy_op) +@with_seed() +def test_logical_and(): + """Test for logical and in onnx operators.""" + inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)]) + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))] + nodes = [helper.make_node("And", ["input1", "input2"], ["output"])] + graph = helper.make_graph(nodes, + "and_test", + input_tensor, + outputs) + model = helper.make_model(graph) + bkd_rep = backend.prepare(model) + output = bkd_rep.run([inputs[0], inputs[1]]) + numpy_op = np.logical_and(inputs[0], inputs[1]).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + +@with_seed() +def test_logical_or(): + """Test for logical or in onnx operators.""" + inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)]) + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))] + nodes = [helper.make_node("Or", ["input1", "input2"], ["output"])] + graph = helper.make_graph(nodes, + "or_test", + input_tensor, + outputs) + model = helper.make_model(graph) + bkd_rep = backend.prepare(model) + output = bkd_rep.run([inputs[0], inputs[1]]) + numpy_op = np.logical_or(inputs[0], inputs[1]).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + +@with_seed() +def test_logical_not(): + """Test for logical not in onnx operators.""" + inputs, input_tensor = get_int_inputs([(0, 2)], [(3, 4, 5)]) + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))] + nodes = [helper.make_node("Not", ["input1"], ["output"])] + graph = helper.make_graph(nodes, + "not_test", + input_tensor, + outputs) + model = helper.make_model(graph) + bkd_rep = backend.prepare(model) + output = bkd_rep.run([inputs[0]]) + numpy_op = np.logical_not(inputs[0]).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + +@with_seed() +def test_logical_xor(): + """Test for logical xor in onnx operators.""" + inputs, input_tensor = get_int_inputs([(0, 2), (0, 2)], [(3, 4, 5), (3, 4, 5)]) + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, shape=np.shape(inputs[0]))] + nodes = [helper.make_node("Xor", ["input1", "input2"], ["output"])] + graph = helper.make_graph(nodes, + "xor_test", + input_tensor, + outputs) + model = helper.make_model(graph) + bkd_rep = backend.prepare(model) + output = bkd_rep.run([inputs[0], inputs[1]]) + numpy_op = np.logical_xor(inputs[0], inputs[1]).astype(np.float32) + npt.assert_almost_equal(output[0], numpy_op) + if __name__ == '__main__': test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000)) test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000)) diff --git a/tests/python-pytest/onnx/import/test_cases.py b/tests/python-pytest/onnx/import/test_cases.py index 7f34247c94e2..ad0d23d2b9da 100644 --- a/tests/python-pytest/onnx/import/test_cases.py +++ b/tests/python-pytest/onnx/import/test_cases.py @@ -55,7 +55,6 @@ 'test_argmax', 'test_argmin', 'test_min', - 'test_logical_', # enabling partial test cases for matmul 'test_matmul_3d', 'test_matmul_4d',