Skip to content

Commit

Permalink
feat: support cumsum dynamo converter (#2403)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored and gs-olive committed Dec 5, 2023
1 parent 5b0e5fc commit 5770e00
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 1 deletion.
103 changes: 102 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import math
from typing import Optional

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_positive_dim,
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
Expand Down Expand Up @@ -96,3 +102,98 @@ def expand(
layer = ctx.net.add_slice(input_t, start=start, shape=shape, stride=stride)
set_layer_name(layer, target, name, source_ir)
return layer.get_output(0)


def chunk(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
chunks: int,
dim: int,
) -> TRTTensor:
if chunks <= 0:
raise RuntimeError(
f"chunk expects `chunks` to be greater than 0, got: {chunks}"
)

shape = input.shape
dim = get_positive_dim(dim, len(shape))

if dim >= len(shape):
raise RuntimeError(
f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
)

dynamic_shape = has_dynamic_shape(input.shape)
if dynamic_shape > 0:
# Check whether slice target dim is dynamic shape dim
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"

size_dim = shape[dim]
chunk_size = math.ceil(size_dim / chunks)
result = []
start = 0
end = min(start + chunk_size, size_dim)
cnt = 0

while start < end:
result.append(
slice_op(
ctx,
target,
source_ir,
f"{name}_slice_{cnt}",
input,
dim,
start,
end,
1,
)
)
start = end
end = min(start + chunk_size, size_dim)
cnt += 1

return result


def cumsum(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dim: int,
) -> TRTTensor:
input_shape = input.shape
dim = get_positive_dim(dim, len(input_shape))
loop = ctx.net.add_loop()
axis = np.array(input_shape[dim])
trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
iterator = loop.add_iterator(input, dim, reverse=False)
data = iterator.get_output(0)
new_dims = tuple(data.shape)
zeros = np.zeros(new_dims)
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")

running_sum = loop.add_recurrence(zero_trttensor)
set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
running_sum_tensor = running_sum.get_output(0)

current_sum = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_elementwise_add",
data,
running_sum_tensor,
)
running_sum.set_input(1, current_sum)

loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
loop_output.set_input(1, trip_limit)
return loop_output.get_output(0)
69 changes: 69 additions & 0 deletions tests/py/dynamo/conversion/test_cumsum_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestCumsumConverter(DispatchTestCase):
@parameterized.expand(
[
((1,), 0),
((2,), 0),
((3,), -1),
]
)
def test_cumsum_1D(self, shape, dim):
class Cumsum(nn.Module):
def forward(self, x):
return torch.ops.aten.cumsum.default(x, dim)

inputs = [torch.randn(shape)]
self.run_test(
Cumsum(),
inputs,
)

@parameterized.expand(
[
((3, 1), 0),
((3, 1), 1),
((2, 3), -1),
((2, 3), -2),
]
)
def test_cumsum_2D(self, shape, dims):
class Cumsum(nn.Module):
def forward(self, x):
return torch.ops.aten.cumsum.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Cumsum(),
inputs,
)

@parameterized.expand(
[
((4, 2, 3), 0),
((4, 2, 3), 1),
((1, 2, 3), 2),
((1, 2, 3), -1),
((1, 2, 3), -2),
]
)
def test_cumsum_3D(self, shape, dims):
class Cumsum(nn.Module):
def forward(self, x):
return torch.ops.aten.cumsum.default(x, dims)

inputs = [torch.randn(shape)]
self.run_test(
Cumsum(),
inputs,
)


if __name__ == "__main__":
run_tests()

0 comments on commit 5770e00

Please sign in to comment.