Skip to content

Commit

Permalink
feat: support 1d ITensor offsets for embedding_bag converter (#2677)
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 authored May 1, 2024
1 parent ff8c872 commit de81be2
Show file tree
Hide file tree
Showing 6 changed files with 708 additions and 128 deletions.
42 changes: 8 additions & 34 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,26 +231,7 @@ def aten_ops_cat(
)


def embedding_param_validator(embedding_node: Node) -> bool:
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
sparse = args_bounds_check(embedding_node.args, 4)

if scale_grad_by_freq is not None:
_LOGGER.debug(
f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}."
)
return False

if sparse is not None:
_LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.")
return False

return True


@dynamo_tensorrt_converter(
torch.ops.aten.embedding.default, capability_validator=embedding_param_validator
)
@dynamo_tensorrt_converter(torch.ops.aten.embedding.default)
def aten_ops_embedding(
ctx: ConversionContext,
target: Target,
Expand All @@ -265,22 +246,19 @@ def aten_ops_embedding(
name,
input=args[1],
weight=args[0],
# args[2] is the padding index, which is useful for training only
scale_grad_by_freq=args_bounds_check(args, 3),
sparse=args_bounds_check(args, 4),
)


def embedding_bag_validator(node: Node) -> bool:
mode = args_bounds_check(node.args, 4, 0)
indices = node.args[1].meta.get("tensor_meta")
if not one_user_validator(node):
return False
meta = node.args[1].meta
indices = meta.get("tensor_meta")
if indices is None:
indices = meta.get("val")
if indices is None:
return False
return (
bool(node.args[2].op == "get_attr")
and (mode == 0 or mode == 1 or mode == 2)
and len(indices.shape) == 1
)
return len(indices.shape) == 1 # currently only support 1D indices


@dynamo_tensorrt_converter(
Expand All @@ -293,7 +271,6 @@ def embedding_bag_validator(node: Node) -> bool:
{
0: (TRTTensor,),
1: (TRTTensor,),
2: (np.ndarray, torch.Tensor),
}
)
def aten_ops_embedding_bag(
Expand All @@ -311,12 +288,9 @@ def aten_ops_embedding_bag(
weight=args[0],
indices=args[1],
offsets=args[2],
scale_grad_by_freq=args_bounds_check(args, 3, False),
mode=args_bounds_check(args, 4, 0),
sparse=args_bounds_check(args, 5, False),
per_sample_weights=args_bounds_check(args, 6, None),
include_last_offset=args_bounds_check(args, 7, False),
# padding index is useful for training only
)


Expand Down
109 changes: 109 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch import SymBool, SymFloat, SymInt
from torch.fx.node import Argument, Target
from torch_tensorrt import _enums
Expand Down Expand Up @@ -530,3 +531,111 @@ def flatten_dims(
new_shape = tuple(shape[:start_dim]) + (num_elements,) + tuple(shape[end_dim + 1 :])

return new_shape


def append(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
original_tensor: TRTTensor,
new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
dim: int = 0,
) -> TRTTensor:
"""
Append a new value to the last of the original tensor along the specified dimension (default 0).
For example, if the original tensor is [1, 2, 3], the new value is 4, and the dim is 0,
the new tensor will be [1, 2, 3, 4].
Args:
ctx (ConversionContext): A ConversionContext containing the TensorRT network
target (Target): Target of calling node
source_ir (Optional[SourceIR]): SourceIR of calling converter
name (str): Name of the calling layer
original_tensor (TRTTensor): A TRTTensor to append the new value to
new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to append
dim (int, optional): Dimention to append the new value. Defaults to 0.
Returns:
TRTTensor: A new TRTTensor that is the result of appending the new value to the original tensor
"""
if isinstance(new_value, (int, float)):
new_value = np.array([new_value])
new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)

return impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_concat",
[original_tensor, new_value],
get_positive_dim(dim, len(original_tensor.shape)),
)


def set_item(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
original_tensor: TRTTensor,
index: int,
new_value: Union[TRTTensor, int, float, torch.Tensor, np.ndarray],
) -> TRTTensor:
"""
Set a new value to the original tensor at the specified index. For example,
if the original tensor is [1, 2, 3], the new value is 4, and the index is 1,
the new tensor will be [1, 4, 3].
If the index is out of bound, the new value will be appended to the end.
Args:
ctx (ConversionContext): A ConversionContext containing the TensorRT network
target (Target): Target of calling node
source_ir (Optional[SourceIR]): SourceIR of calling converter
name (str): Name of the calling layer
original_tensor (TRTTensor): A TRTTensor to set the new value to
index (int): The index to set the new value
new_value (Union[TRTTensor, int, float, torch.Tensor, np.ndarray]): A new value to set
Returns:
TRTTensor: A new TRTTensor that is the result of setting the new value to the original tensor
"""
if isinstance(new_value, (int, float)):
new_value = np.array([new_value])
new_value = get_trt_tensor(ctx, new_value, name, original_tensor.dtype)

len_original_tensor = original_tensor.shape[0]
index = get_positive_dim(index, len_original_tensor)

front_tensor = impl.slice.slice_op(
ctx,
target,
source_ir,
f"{name}_slice_front",
original_tensor,
dim=0,
start=0,
stop=index,
step=1,
)
rear_tensor = impl.slice.slice_op(
ctx,
target,
source_ir,
f"{name}_slice_rear",
original_tensor,
dim=0,
start=index + 1,
stop=len_original_tensor,
step=1,
)

ans = impl.cat.cat(
ctx,
target,
source_ir,
f"{name}_concat",
[front_tensor, new_value, rear_tensor],
0,
)
return ans
Loading

0 comments on commit de81be2

Please sign in to comment.