Skip to content

Commit

Permalink
[fx][acc_ops] add acc_ops.gather and acc_ops.index_select and shape inference (#30)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/30

as title

Reviewed By: 842974287

Differential Revision: D34874487

fbshipit-source-id: 86039d0f1269d879983977c65fe8fbe0a8bc1421
  • Loading branch information
alexbeloi authored and Wei Wei committed Jun 4, 2022
1 parent 6165038 commit 730645d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
74 changes: 74 additions & 0 deletions test/tracer/test_acc_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2102,6 +2102,78 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
gm_retrace = acc_tracer.trace(gm, [a])
self.assertTrue(torch.equal(m(a), gm_retrace(a)))

def test_index_select(self):
    """Trace torch.index_select and verify it maps to acc_ops.index_select
    with the expected input/dim/index kwargs and graph structure."""

    class TestModule(nn.Module):
        def __init__(self, dim, index):
            super().__init__()
            self._dim = dim
            self._index = index

        def forward(self, a: torch.Tensor) -> torch.Tensor:
            return torch.index_select(a, self._dim, self._index)

    dim = 0
    index = torch.tensor([1, 0])
    m = TestModule(dim, index)
    _input = [torch.randn(2, 3), torch.randn(2, 3)]
    traced = acc_tracer.trace(m, _input)

    # `index` is rebound here to track the get_attr node carrying the
    # index tensor constant once it is seen during graph traversal.
    ph = index = index_select = None

    for node in traced.graph.nodes:
        if node.op == "placeholder":
            self.assertEqual(str(node.target), "a")
            ph = node
        elif node.op == "call_function" and node.target == acc_ops.index_select:
            # assertEqual produces a useful diff on failure,
            # unlike assertTrue(a == b).
            self.assertEqual(node.kwargs["input"], ph)
            self.assertEqual(node.kwargs["index"], index)
            self.assertEqual(node.kwargs["dim"], dim)
            index_select = node
        elif node.op == "output":
            self.assertEqual(index_select, node.args[0])
        elif node.op == "get_attr":
            # There should be exactly one const (get_attr) node:
            # the index tensor.
            self.assertIsNone(index)
            index = node
        else:
            self.fail(f"Unexpected node: {node.format_node()}")

def test_gather(self):
    """Trace torch.gather and verify it maps to acc_ops.gather
    with the expected input/dim/index kwargs and graph structure."""

    class TestModule(nn.Module):
        def __init__(self, dim, index):
            super().__init__()
            self._dim = dim
            self._index = index

        def forward(self, a: torch.Tensor) -> torch.Tensor:
            return torch.gather(a, self._dim, self._index)

    dim = 0
    index = torch.tensor([[1, 0], [0, 1]])
    m = TestModule(dim, index)
    _input = [torch.randn(2, 3), torch.randn(2, 3)]
    traced = acc_tracer.trace(m, _input)

    # `index` is rebound here to track the get_attr node carrying the
    # index tensor constant once it is seen during graph traversal.
    ph = index = gather = None

    for node in traced.graph.nodes:
        if node.op == "placeholder":
            self.assertEqual(str(node.target), "a")
            ph = node
        elif node.op == "call_function" and node.target == acc_ops.gather:
            # assertEqual produces a useful diff on failure,
            # unlike assertTrue(a == b).
            self.assertEqual(node.kwargs["input"], ph)
            self.assertEqual(node.kwargs["index"], index)
            self.assertEqual(node.kwargs["dim"], dim)
            gather = node
        elif node.op == "output":
            self.assertEqual(gather, node.args[0])
        elif node.op == "get_attr":
            # There should be exactly one const (get_attr) node:
            # the index tensor.
            self.assertIsNone(index)
            index = node
        else:
            self.fail(f"Unexpected node: {node.format_node()}")

def test_all_acc_ops_registered(self):
self.assertEqual(
acc_normalizer._acc_ops,
Expand Down Expand Up @@ -2203,5 +2275,7 @@ def test_all_acc_ops_registered(self):
acc_ops.eq,
acc_ops.gt,
acc_ops.le,
acc_ops.gather,
acc_ops.index_select,
},
)
31 changes: 30 additions & 1 deletion tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ def adaptive_avg_pool2d(*, input, output_size):

@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool1d))
@register_acc_op
def avg_pool1d(*, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
def avg_pool1d(
*,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
):
return nn.functional.avg_pool1d(
input=input,
kernel_size=kernel_size,
Expand Down Expand Up @@ -2163,3 +2171,24 @@ def cumsum(*, input, dim, dtype=None):
@register_acc_op
def chunk(*, input, chunks, dim=0):
    """Split `input` into `chunks` pieces along dimension `dim` (default 0)."""
    return torch.chunk(input, chunks, dim)


@register_acc_op_mapping(
    op_and_target=("call_function", torch.gather),
    arg_replacement_tuples=[
        ("input", "input"),
        ("dim", "dim"),
        ("index", "index"),
        ("sparse_grad", "sparse_grad", this_arg_is_optional),
    ],
)
@register_acc_op
def gather(*, input, dim, index, sparse_grad=False):
    """Gather values along axis `dim` of `input` at the positions in `index`.

    Mirrors torch.gather; `sparse_grad` is optional and defaults to False,
    matching the torch API.
    """
    return torch.gather(input=input, dim=dim, index=index, sparse_grad=sparse_grad)


@register_acc_op_mapping(
    op_and_target=("call_function", torch.index_select),
)
@register_acc_op
def index_select(*, input, dim, index):
    """Select entries of `input` along dimension `dim` at the positions in `index`.

    Mirrors torch.index_select. Keyword-argument call form matches the
    sibling acc ops (e.g. `gather`) for consistency.
    """
    return torch.index_select(input=input, dim=dim, index=index)

0 comments on commit 730645d

Please sign in to comment.