Skip to content

Commit

Permalink
Clean up torch integration.
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh Fromm committed Jul 27, 2022
1 parent c4cf317 commit 5b8dce3
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 40 deletions.
31 changes: 2 additions & 29 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,31 +319,6 @@ def square(self, inputs, input_types):
(dtype,) = input_types
return _op.power(inputs[0], _expr.const(2, dtype))

def tril(self, inputs, input_types):
    """Convert PyTorch ``aten::tril`` to Relay.

    Keeps elements on and below the ``k``-th diagonal and zeroes out the
    strictly-upper band by overwriting diagonals ``k+1 .. ncols`` with zeros
    via ``matrix_set_diag``.

    Parameters
    ----------
    inputs : list
        ``inputs[0]`` is the input tensor; ``inputs[1]``, when present, is the
        diagonal offset ``k`` (PyTorch's ``diagonal`` argument, defaults to 0).
    input_types : list of str
        ``input_types[0]`` is the dtype of the input tensor.
    """
    data = inputs[0]
    # Optional second argument is the diagonal offset.
    k_value = inputs[1] if len(inputs) == 2 else 0
    input_shape = self.infer_shape(data)
    # Zero out the band of diagonals strictly above offset k_value.
    # NOTE(review): upper bound is ncols, one past the largest valid offset
    # ncols-1 — presumably matrix_set_diag tolerates/clamps this; confirm.
    k1 = k_value + 1
    k2 = input_shape[-1]
    diag_input = _op.zeros(input_shape, dtype=input_types[0])
    return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

def triu(self, inputs, input_types):
    """Convert PyTorch ``aten::triu`` to Relay.

    Keeps elements on and above the ``k``-th diagonal and zeroes out the
    strictly-lower band by overwriting diagonals ``-nrows-1 .. k-1`` with
    zeros via ``matrix_set_diag``.

    Parameters
    ----------
    inputs : list
        ``inputs[0]`` is the input tensor; ``inputs[1]``, when present, is the
        diagonal offset ``k`` (PyTorch's ``diagonal`` argument, defaults to 0).
    input_types : list of str
        ``input_types[0]`` is the dtype of the input tensor.
    """
    data = inputs[0]
    # Optional second argument is the diagonal offset.
    k_value = inputs[1] if len(inputs) == 2 else 0
    input_shape = self.infer_shape(data)
    # Zero out the band of diagonals strictly below offset k_value.
    # NOTE(review): lower bound is -nrows-1, one past the smallest valid
    # offset -(nrows-1) — presumably matrix_set_diag tolerates this; confirm.
    k1 = -input_shape[-2] - 1
    k2 = k_value - 1
    diag_input = _op.zeros(input_shape, dtype=input_types[0])
    return _op.matrix_set_diag(data, diag_input, k=(k1, k2))

def arange(self, inputs, input_types):
def _get_value(val, dtype):
# dtype is a tvm dtype
Expand Down Expand Up @@ -3552,8 +3527,8 @@ def create_convert_map(self):
"aten::sqrt": self.make_unary("sqrt"),
"aten::rsqrt": self.make_unary("rsqrt"),
"aten::square": self.square,
"aten::tril": self.tril,
"aten::triu": self.triu,
"aten::tril": functools.partial(self.trilu, mode="tril"),
"aten::triu": functools.partial(self.trilu, mode="triu"),
"aten::ceil": self.make_unary("ceil"),
"aten::floor": self.make_unary("floor"),
"aten::round": self.make_unary("round"),
Expand Down Expand Up @@ -3646,8 +3621,6 @@ def create_convert_map(self):
"aten::dot": self.dot,
"aten::mv": self.mv,
"aten::grid_sampler": self.grid_sampler,
"aten::triu": functools.partial(self.trilu, mode="triu"),
"aten::tril": functools.partial(self.trilu, mode="tril"),
"aten::__ior__": self.make_elemwise("bitwise_or"),
"aten::__iand__": self.make_elemwise("bitwise_and"),
"aten::__ixor__": self.make_elemwise("bitwise_xor"),
Expand Down
11 changes: 0 additions & 11 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4584,16 +4584,5 @@ def forward(self, x):
tvm.testing.assert_allclose(out, output_torch, rtol=1e-5, atol=1e-5)


@tvm.testing.uses_gpu
def test_trilu():
    """Verify torch.triu / torch.tril conversion across diagonal offsets."""

    def _bind(op, diagonal):
        # Freeze op and diagonal so verify_model sees a single-input callable.
        return lambda inp: op(inp, diagonal)

    cases = [
        (0, [3, 3]),
        (1, [6, 6]),
        (-2, [6, 6]),
    ]
    for op in (torch.triu, torch.tril):
        for diagonal, shape in cases:
            verify_model(_bind(op, diagonal), [torch.rand(size=shape)])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 5b8dce3

Please sign in to comment.