
Commit

wip
wushirong committed May 5, 2023
1 parent 25db257 commit 1b99e4c
Showing 24 changed files with 1,173 additions and 202 deletions.
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -691,10 +691,13 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
eps_field = trt.PluginField(
"eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
)
+ normalized_shape = kwargs["normalized_shape"]
try:
- normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
+ normalized_shape = np.array(normalized_shape, dtype=np.int32)
except TypeError:
- _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
+ _LOGGER.error(
+ f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []"
+ )
normalized_shape = np.array([], dtype=np.int32)

normalized_shape_filed = trt.PluginField(
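For context, the converter change above reads `normalized_shape` out of `kwargs` once, logs its value when it cannot be turned into an int32 array, and falls back to an empty shape. A minimal standalone sketch of that fallback (NumPy and logging only; the function name `build_normalized_shape` is illustrative, not part of the converter):

```python
import logging

import numpy as np

_LOGGER = logging.getLogger(__name__)


def build_normalized_shape(normalized_shape):
    # Mirror the converter's behavior: try int32 conversion, otherwise log
    # the offending value and fall back to an empty shape.
    try:
        return np.array(normalized_shape, dtype=np.int32)
    except TypeError:
        _LOGGER.error(
            f"Unable to convert normalized_shape with value {normalized_shape} "
            "to a field, fall back to []"
        )
        return np.array([], dtype=np.int32)


print(build_normalized_shape((512,)))  # [512]
print(build_normalized_shape(None))    # [] (TypeError path)
```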
8 changes: 8 additions & 0 deletions py/torch_tensorrt/fx/diagnostics.py
@@ -87,12 +87,14 @@ class DiagnosticsWriter:

def __init__(self):
self._root_dir = tempfile.mkdtemp(prefix="fx2trt.")
self._data = ""
_LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}")

def write(self, file_name: str, data: WriteObj):
"""
TODO: Can be disabled by regex on file_name
"""
self._data = data
# Only write if we are inside a collect_when() context.
if not _IS_IN_COLLECT_CONTEXT.get(False):
return
@@ -117,6 +119,9 @@ def write(self, file_name: str, data: WriteObj):
def root_dir(self) -> str:
return self._root_dir

def data(self) -> WriteObj:
return self._data

def _write(self, file_name: str, to_write: bytes):
# ms granularity - no naming collash, otherwise file will be
# overwritten.
@@ -271,6 +276,9 @@ def collect(self) -> str:
finally:
os.remove(fp)

def data(self) -> WriteObj:
return self._write.data()


def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]:
if isinstance(data, (str, bytes)):
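The diagnostics change caches the most recent payload passed to `write()` and exposes it through a new `data()` accessor, so callers (e.g. tests) can inspect what was written without reading files back. A self-contained sketch of just that caching pattern (a stand-in class, not the real `DiagnosticsWriter`, which also persists files under a temp directory and honors `collect_when()` contexts):

```python
import tempfile
import typing as t

WriteObj = t.Union[str, bytes]


class MiniDiagnosticsWriter:
    """Illustrative stand-in: remember the last written payload."""

    def __init__(self) -> None:
        self._root_dir = tempfile.mkdtemp(prefix="fx2trt.")
        self._data: WriteObj = ""

    def write(self, file_name: str, data: WriteObj) -> None:
        # Cache the payload even if nothing is persisted to disk.
        self._data = data

    def data(self) -> WriteObj:
        return self._data


writer = MiniDiagnosticsWriter()
writer.write("lowering_diagnostics.txt", "layer_norm fallback triggered")
assert writer.data() == "layer_norm fallback triggered"
```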
73 changes: 62 additions & 11 deletions py/torch_tensorrt/fx/input_tensor_spec.py
@@ -1,4 +1,4 @@
- from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
+ from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple

import torch

@@ -18,6 +18,12 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
# is the dynamic batch dimension. Otherwise, we use the additional
# inputs to determine the batch dimension.
if additional_inputs is None:
batch_dims = None
if not isinstance(inputs, torch.Tensor) and len(inputs) > 1:
bs = inputs[0].size(0)
batch_dims = None
if not all(x.size(0) == bs for x in inputs):
batch_dims = InputTensorSpec.find_batch_size_dim(inputs)
return InputTensorSpec.from_tensors_with_dynamic_batch_size(
inputs,
(
@@ -26,6 +32,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
lower_setting.max_batch_size,
),
lower_setting.opt_profile_replica,
batch_dims,
)
else:
batch_dims = []
@@ -147,25 +154,69 @@ from_tensors_with_dynamic_batch_size(
A list of InputTensorSpec named tuples with dynamic ranges.
"""
if batch_dims is None:
- batch_dims = [0] * len(tensors)
+ batch_dims = cls.find_batch_size_dim(tensors)

input_specs = []
batch_size = tensors[0].size(batch_dims[0])

for i, tensor in enumerate(tensors):
batch_dim = batch_dims[i]
- assert batch_size == tensor.size(
- batch_dim
- ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
- shape = list(tensor.shape)
- shape[batch_dim] = -1
- shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item]
- input_specs.append(
- cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
- )
+ if batch_dim == -1:
+ input_specs.append(cls.from_tensor(tensor))
+ else:
+ shape = list(tensor.shape)
+ assert batch_size == tensor.size(
+ batch_dim
+ ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
+ shape[batch_dim] = -1
+ shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item]
+ input_specs.append(
+ cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
+ )

return input_specs

@classmethod
# pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any`
def find_batch_size_dim(cls, inputs: Any) -> []:
if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
return [0]
shapes = [i.shape for i in inputs]
frequency_map = {}
first_dims = set()
for shape in shapes:
if len(shape) < 2:
# By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
continue
# Dedup shape value for single tensor
first_dims.add(shape[0])
shape = set(shape)
for i in shape:
frequency_map[i] = frequency_map.get(i, 0) + 1

if len(first_dims) == 1:
# first dim is the same in every input: we use it as batch_size
batch_size = first_dims.pop()
elif frequency_map:
# first dims are different: we use the most frequent dim as batch_size
sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1])
batch_size = sorted_frequency[0][0]
else:
# no dims to sort: no batch_size
batch_size = -1

bs_dim = []
for i in inputs:
# Default batch size dim = -1, indicate no batch_size
dim = -1
for index, val in enumerate(i.shape):
if val == batch_size:
dim = index
break
bs_dim.append(dim)

return bs_dim

def to_random_tensor(self, id=1):
shape = tuple(self.shape)
if len(get_dynamic_dims(shape)):
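The core of this file's change is the `find_batch_size_dim` heuristic: if every rank-2+ input shares the same first dimension, that value is taken as the batch size; otherwise the most frequent dimension value across all shapes wins; tensors that do not contain the chosen value at all get batch dim -1. A standalone sketch of the same heuristic over plain shape tuples (simplified from the diff, which operates on tensors):

```python
from typing import Dict, List, Sequence, Tuple


def find_batch_size_dim(shapes: Sequence[Tuple[int, ...]]) -> List[int]:
    """Return, per input shape, the index of the batch dimension (-1 if none)."""
    if len(shapes) <= 1:
        return [0]

    frequency_map: Dict[int, int] = {}
    first_dims = set()
    for shape in shapes:
        if len(shape) < 2:
            continue  # rank-1 tensors are assumed to carry no batch-size info
        first_dims.add(shape[0])
        for val in set(shape):  # dedup values within a single shape
            frequency_map[val] = frequency_map.get(val, 0) + 1

    if len(first_dims) == 1:
        batch_size = first_dims.pop()  # every first dim agrees
    elif frequency_map:
        batch_size = max(frequency_map, key=frequency_map.get)  # most frequent value
    else:
        batch_size = -1  # nothing to infer from

    bs_dim = []
    for shape in shapes:
        dim = -1  # default: no batch dimension found
        for index, val in enumerate(shape):
            if val == batch_size:
                dim = index
                break
        bs_dim.append(dim)
    return bs_dim


# Dim 0 == 8 is shared by the rank-2+ inputs; the rank-1 tensor gets -1.
print(find_batch_size_dim([(8, 128), (8, 64, 64), (3,)]))  # [0, 0, -1]
```

In `from_tensors_with_dynamic_batch_size`, a -1 entry now routes that tensor through `from_tensor()` (static shape) instead of building a dynamic-batch shape range for it.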
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/lower.py
@@ -41,6 +41,8 @@ def compile(
dynamic_batch=True,
is_aten=False,
use_experimental_fx_rt=False,
correctness_atol=1e-1,
correctness_rtol=1e-1,
) -> nn.Module:
"""
Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -81,6 +83,8 @@ def compile(
dynamic_batch=dynamic_batch,
is_aten=is_aten,
use_experimental_rt=use_experimental_fx_rt,
correctness_atol=correctness_atol,
correctness_rtol=correctness_rtol,
)
lowerer = Lowerer.create(lower_setting=lower_setting)
return lowerer(module, input)
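`compile()` now forwards `correctness_atol` / `correctness_rtol` into the `LowerSetting`, making the accuracy check applied during lowering configurable. A hedged usage sketch (the model and inputs are placeholders, and running it requires a CUDA + TensorRT environment with this fx path installed):

```python
import torch
import torch.nn as nn
import torch_tensorrt.fx.lower as fx_lower


class TinyModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0


model = TinyModel().cuda().eval()
inputs = [torch.randn(8, 16, device="cuda")]

# Loosen the lowering-time correctness tolerances introduced by this commit.
trt_model = fx_lower.compile(
    model,
    inputs,
    max_batch_size=8,
    correctness_atol=1e-1,
    correctness_rtol=1e-1,
)
print(trt_model(*inputs).shape)
```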
46 changes: 41 additions & 5 deletions py/torch_tensorrt/fx/passes/lower_basic_pass.py
@@ -54,10 +54,14 @@ def fill_with_mul_zero_and_add(*args):


def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule:
- # Now we do constant folding on traced module. We want to skip pattern like
- # weights -> quant -> dequant -> op during constant folding when the model is
- # a quantized int8 model.
- def skip_folding_quant_dequant(node: torch.fx.Node):
+ def skip_folding_ops(node: torch.fx.Node):
+ # dtype op
+ if node.target == acc_ops.dtype:
+ return True
+ # Now we do constant folding on traced module. We want to skip pattern like
+ # weights -> quant -> dequant -> op during constant folding when the model is
+ # a quantized int8 model.
+ # quant_dequant
if node.target != acc_ops.quantize_per_tensor:
return False
# If quantize_per_node -> dequantize, then skip folding.
@@ -66,7 +70,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
return True
return False

- const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
+ const_split_mod = split_const_subgraphs(traced_mod, skip_folding_ops)
const_split_mod.run_folding()
return const_split_mod

@@ -630,3 +634,35 @@ def fix_clamp_numerical_limits_to_fp16(

mod.recompile()
return mod


@log_before_after
@validate_inference(atol=1e-3, rtol=1e-2)
def remove_dtype_and_to_pattern(
mod: torch.fx.GraphModule, input: Input
) -> torch.fx.GraphModule:
"""
Remove this pattern since it is unnecessary to cast to dtype
%dtype : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.dtype](args = (), kwargs = {input: %_attention_layers_0__uva})
%to_18 : [#users=2] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.to_dtype](args = (), kwargs = {input: %x})
"""
for node in mod.graph.nodes:
if node.op == "call_function" and node.target == acc_ops.dtype:
# find its first user
next_node = next(iter(node.users))
# acc_op or pt op is treated differently
input = (
next_node.kwargs["input"]
if "input" in next_node.kwargs
else next_node.args[0]
)
if len(node.users) == 1 and (
next_node.target == acc_ops.to_dtype or next_node.target == "to"
):
next_node.replace_all_uses_with(input)
mod.graph.erase_node(next_node)
mod.graph.erase_node(node)

mod.graph.eliminate_dead_code()
mod.recompile()
return mod
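
The new `remove_dtype_and_to_pattern` pass erases `acc_ops.dtype` nodes whose single user is a `to_dtype`/`to` cast, rewiring the cast's consumers to the original input. The sketch below reproduces the same rewrite on a plain `torch.fx` trace, where the redundant cast shows up as `getattr(x, "dtype")` feeding a `to` call_method, so it does not depend on acc_tracer; the module and function names here are illustrative:

```python
import torch
import torch.fx


class NoOpCast(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(x.dtype) + 1.0  # the cast is a no-op we want to strip


def remove_noop_dtype_cast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Erase getattr(x, "dtype") -> x.to(dtype) pairs, like the acc_ops pass."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is getattr and node.args[1] == "dtype":
            if len(node.users) != 1:
                continue
            next_node = next(iter(node.users))
            if next_node.op == "call_method" and next_node.target == "to":
                inp = next_node.args[0]
                next_node.replace_all_uses_with(inp)  # consumers read the uncast input
                gm.graph.erase_node(next_node)
                gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm


gm = remove_noop_dtype_cast(torch.fx.symbolic_trace(NoOpCast()))
x = torch.randn(4)
assert torch.equal(gm(x), x + 1.0)
print(gm.graph)  # the dtype/to pair is gone
```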
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -8,6 +8,7 @@
from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion
from torch_tensorrt.fx.utils import LowerPrecision

from ..input_tensor_spec import generate_input_specs
@@ -229,10 +230,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
submod = getattr(split_result.split_module, submod_name)

LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
_LOGGER.info(f"Now lowering submodule {submod_name}")
_LOGGER.info(f"ACC submodule graph: {submod.graph}")
lowering_start_time = datetime.datetime.now()

self.lower_setting.additional_inputs = (
@@ -251,6 +251,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
)
else:
_LOGGER.info(f"GPU submodule graph: {submod.graph}")
apply_bfloat_float_conversion(submod, submod_inputs, submod_name)

return split_result.split_module

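With this change, only acc submodules are lowered to TensorRT; non-acc (GPU) submodules now get their graph logged and are handed to `apply_bfloat_float_conversion` instead of being left untouched. A rough sketch of that dispatch using stand-in submodules; the `_run_on_gpu_` prefix and the module names are assumptions for illustration, and the real code iterates the splitter's submodule inputs and runs the TRT interpreter on the acc side:

```python
import logging

import torch
import torch.fx
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)


class Relu(nn.Module):
    def forward(self, x):
        return torch.relu(x)


class Sin(nn.Module):
    def forward(self, x):
        return torch.sin(x)


# Stand-in for a splitter result: names encode which backend runs each piece.
parent = nn.Module()
parent.add_module("_run_on_acc_0", torch.fx.symbolic_trace(Relu()))
parent.add_module("_run_on_gpu_1", torch.fx.symbolic_trace(Sin()))
non_acc_prefix = "_run_on_gpu_"  # analogous to non_acc_submodule_prefix

for name, submod in parent.named_children():
    if not name.startswith(non_acc_prefix):
        _LOGGER.info(f"ACC submodule graph: {submod.graph}")
        # ...input specs are built here and the TRT interpreter lowers this piece...
    else:
        _LOGGER.info(f"GPU submodule graph: {submod.graph}")
        # ...apply_bfloat_float_conversion(submod, inputs, name) is called here...
```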