From d2c1b776a5172ec2657e79c6a7eb72c4a34d8449 Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Fri, 12 Feb 2021 12:53:37 -0800
Subject: [PATCH] [Torch] Add index_put operator

---
 python/tvm/relay/frontend/pytorch.py          | 25 +++++++++++
 tests/python/frontend/pytorch/test_forward.py | 44 +++++++++++++++++++
 2 files changed, 69 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 246ed97b14e9a..af6b3d86c66af 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2010,6 +2010,29 @@ def scatter(self, inputs, input_types):
         src = inputs[3]
         return _op.transform.scatter(data, index, src, axis)
 
+    def index_put(self, inputs, input_types):
+        in_tensor = inputs[0]
+        indices = inputs[1]
+        values = inputs[2]
+        accumulate = inputs[3]
+        # The accumulate parameter is ignored.
+        # The torch.index_put accumulate default is False, but Relay scatter_nd accumulates values.
+        # We assume there are no duplicate indices in the torch.index_put input.
+        if not accumulate:
+            logging.warning("torch.index_put accumulate parameter is False. "
+                            "TVM uses the tvm.relay.scatter_nd operator, which accumulates values. "
+                            "Make sure there are no duplicate indices in the torch.index_put input.")
+        # Relay scatter_nd does not support an input tensor.
+        # We assume that torch.index_put is used with an empty zero-values input tensor;
+        # scatter_nd will create an empty zero-values tensor with the given shape.
+        out_shape = self.infer_shape(in_tensor)
+        logging.warning("The tvm.relay.scatter_nd operator does not support an input tensor parameter. "
+                        "TVM assumes that torch.index_put is used with an empty zero-values input tensor.")
+        # Combine the array of index tensors into one index tensor with shape (N, _)
+        indices_expdim = [self.unsqueeze((x, 0), None) for x in indices]
+        indices_concat = self.concatenate((indices_expdim, 0), None)
+        return _op.transform.scatter_nd(values, indices_concat, out_shape)
+
     def scalar_tensor(self, inputs, input_types):
         data = inputs[0]
         cast_map = {
@@ -2326,6 +2349,8 @@ def create_convert_map(self):
             "aten::nonzero": self.nonzero,
             "aten::nonzero_numpy": self.nonzero_numpy,
             "aten::scatter": self.scatter,
+            "aten::index_put": self.index_put,
+            "aten::index_put_": self.index_put,
             "aten::scalar_tensor": self.scalar_tensor,
             "aten::__interpolate": self.interpolate,
             "aten::IntImplicit": self.identity,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8d968e9760c9b..85c16827d4969 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -232,6 +232,20 @@ def test_forward_pixel_shuffle():
     verify_model(torch.nn.PixelShuffle(3).float().eval(), input_data=input_data)
     verify_model(torch.nn.PixelShuffle(4).float().eval(), input_data=input_data)
 
+@tvm.testing.uses_gpu
+def test_forward_input_put():
+    torch.set_grad_enabled(False)
+    input_shape = [3, 3]
+
+    class Zeros1(Module):
+        def forward(self, *args):
+            hs = torch.tensor([0, 1, 2, 2])
+            ws = torch.tensor([0, 1, 1, 2])
+            vs = torch.tensor([2.0, 4.0, 7.0, 9.0])
+            return torch.index_put_(args[0], indices=[hs, ws], values=vs)
+
+    input_data = torch.zeros(input_shape, dtype=torch.float)
+    verify_model(Zeros1(), input_data=input_data)
 
 @tvm.testing.uses_gpu
 def test_forward_add():
@@ -3327,6 +3341,36 @@ def test_fn_scatter_add(dim):
     verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)
 
 
+def test_forward_index_put():
+    # torch.index_put for 2D tensor and default accumulate (False)
+    def test_fn_index_put2():
+        return lambda data, xidx, yidx, values: \
+            torch.index_put(data, indices=[xidx, yidx], values=values)
+
+    # torch.index_put for 3D tensor and accumulate=True
+    def test_fn_index_put3a():
+        return lambda data, xidx, yidx, zidx, values: \
+            torch.index_put(data, indices=[xidx, yidx, zidx], values=values, accumulate=True)
+
+    shape = (3, 5)
+    in_data = torch.zeros(shape)
+    xidx = torch.tensor([0, 1, 2, 2])
+    yidx = torch.tensor([0, 1, 3, 4])
+    values = torch.tensor([2.0, 4.0, 7.0, 9.0])
+
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn_index_put2(), [in_data, xidx, yidx, values], targets)
+
+    shape = (3, 5, 3)
+    in_data = torch.zeros(shape)
+    xidx = torch.tensor([0, 1, 2, 2, 0])
+    yidx = torch.tensor([0, 1, 3, 4, 0])
+    zidx = torch.tensor([0, 1, 1, 2, 0])
+    values = torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0])
+
+    verify_trace_model(test_fn_index_put3a(), [in_data, xidx, yidx, zidx, values], targets)
+
+
 def test_numel():
     class Numel(Module):
         def forward(self, data):
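
Reviewer note (not part of the patch): a minimal sketch of how the new aten::index_put_ conversion can be exercised end to end through relay.frontend.from_pytorch, assuming a traced module that writes into an all-zeros tensor. The module name IndexPutZeros and the Relay input name "data" are illustrative only and do not appear in the patch.

# Minimal sketch, not part of the patch: exercising the new index_put conversion.
# IndexPutZeros and the input name "data" are illustrative assumptions.
import torch
from tvm import relay


class IndexPutZeros(torch.nn.Module):
    def forward(self, data):
        hs = torch.tensor([0, 1, 2, 2])
        ws = torch.tensor([0, 1, 1, 2])
        vs = torch.tensor([2.0, 4.0, 7.0, 9.0])
        # In-place write into an all-zeros tensor, matching the converter's
        # assumption that the input tensor holds only zeros.
        return torch.index_put_(data, indices=[hs, ws], values=vs)


input_data = torch.zeros((3, 3), dtype=torch.float)
scripted = torch.jit.trace(IndexPutZeros().eval(), input_data)

# aten::index_put_ should be lowered to relay scatter_nd by the new converter.
mod, params = relay.frontend.from_pytorch(scripted, [("data", (3, 3))])
print(mod["main"])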