diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 246ed97b14e9a..1c1f8de995873 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2010,6 +2010,33 @@ def scatter(self, inputs, input_types):
         src = inputs[3]
         return _op.transform.scatter(data, index, src, axis)
 
+    def index_put(self, inputs, input_types):
+        in_tensor = inputs[0]
+        indices = inputs[1]
+        values = inputs[2]
+        accumulate = inputs[3]
+        # The accumulate parameter is ignored.
+        # torch.index_put defaults to False, but relay.scatter_nd always accumulates values.
+        # We assume there are no duplicate indices in the torch.index_put input.
+        if not accumulate:
+            logging.warning(
+                "torch.index_put accumulate parameter is False. "
+                "TVM uses tvm.relay.scatter_nd operator which accumulates values. "
+                "Make sure there are no duplicate indices in the torch.index_put input."
+            )
+        # Relay scatter_nd does not support an input tensor.
+        # We assume that torch.index_put is used with an empty, zero-valued input tensor;
+        # scatter_nd will create an empty, zero-valued tensor with the given shape.
+        out_shape = self.infer_shape(in_tensor)
+        logging.warning(
+            "tvm.relay.scatter_nd operator does not support input tensor parameter. "
+            "TVM assumes that torch.index_put is used with an empty, zero-valued input tensor."
+        )
+        # Combine the list of index tensors into one index tensor with shape (N, _)
+        indices_expdim = [self.unsqueeze((x, 0), None) for x in indices]
+        indices_concat = self.concatenate((indices_expdim, 0), None)
+        return _op.transform.scatter_nd(values, indices_concat, out_shape)
+
     def scalar_tensor(self, inputs, input_types):
         data = inputs[0]
         cast_map = {
@@ -2326,6 +2353,8 @@ def create_convert_map(self):
             "aten::nonzero": self.nonzero,
             "aten::nonzero_numpy": self.nonzero_numpy,
             "aten::scatter": self.scatter,
+            "aten::index_put": self.index_put,
+            "aten::index_put_": self.index_put,
             "aten::scalar_tensor": self.scalar_tensor,
             "aten::__interpolate": self.interpolate,
             "aten::IntImplicit": self.identity,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 8d968e9760c9b..aa42b0fb84e45 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3327,6 +3327,38 @@ def test_fn_scatter_add(dim):
     verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)
 
 
+def test_forward_index_put():
+    # torch.index_put for a 2D tensor with the default accumulate (False)
+    def test_fn_index_put2():
+        return lambda data, xidx, yidx, values: torch.index_put(
+            data, indices=[xidx, yidx], values=values
+        )
+
+    # torch.index_put for a 3D tensor with accumulate=True
+    def test_fn_index_put3a():
+        return lambda data, xidx, yidx, zidx, values: torch.index_put(
+            data, indices=[xidx, yidx, zidx], values=values, accumulate=True
+        )
+
+    shape = (3, 5)
+    in_data = torch.zeros(shape)
+    xidx = torch.tensor([0, 1, 2, 2])
+    yidx = torch.tensor([0, 1, 3, 4])
+    values = torch.tensor([2.0, 4.0, 7.0, 9.0])
+
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn_index_put2(), [in_data, xidx, yidx, values], targets)
+
+    shape = (3, 5, 3)
+    in_data = torch.zeros(shape)
+    xidx = torch.tensor([0, 1, 2, 2, 0])
+    yidx = torch.tensor([0, 1, 3, 4, 0])
+    zidx = torch.tensor([0, 1, 1, 2, 0])
+    values = torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0])
+
+    verify_trace_model(test_fn_index_put3a(), [in_data, xidx, yidx, zidx, values], targets)
+
+
 def test_numel():
     class Numel(Module):
         def forward(self, data):
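
Not part of the patch: a minimal standalone sketch of the torch.index_put behavior the converter relies on, reusing the 2D test values above. With unique indices and an all-zeros input, accumulate=False and accumulate=True produce the same result, which is why mapping both onto relay.scatter_nd (which always accumulates) is assumed to be safe here.

```python
import torch

# Same 2D case as test_fn_index_put2 above: all-zeros input, unique (x, y) index pairs.
data = torch.zeros(3, 5)
xidx = torch.tensor([0, 1, 2, 2])
yidx = torch.tensor([0, 1, 3, 4])
values = torch.tensor([2.0, 4.0, 7.0, 9.0])

out = torch.index_put(data, indices=[xidx, yidx], values=values)
acc = torch.index_put(data, indices=[xidx, yidx], values=values, accumulate=True)

# Because the indices are unique and the input is all zeros, overwrite and
# accumulate agree: out[0, 0] == 2.0, out[1, 1] == 4.0, out[2, 3] == 7.0, out[2, 4] == 9.0
assert torch.equal(out, acc)
```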