Sync fb internal change to OSS #1892

Merged: 1 commit, May 17, 2023
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -691,10 +691,13 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
eps_field = trt.PluginField(
"eps", np.array([kwargs["eps"]], dtype=np.float32), trt.PluginFieldType.FLOAT32
)
normalized_shape = kwargs["normalized_shape"]
try:
normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
normalized_shape = np.array(normalized_shape, dtype=np.int32)
except TypeError:
_LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
_LOGGER.error(
f"Unable to convert normalized_shape with value {normalized_shape} to a field, fall back to []"
)
normalized_shape = np.array([], dtype=np.int32)

normalized_shape_filed = trt.PluginField(
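For context, here is a minimal standalone sketch of the fallback this hunk introduces (plain NumPy, no TensorRT needed). It is not the converter itself, just the coercion logic: a normalized_shape that cannot be turned into an int32 array, such as None, hits the TypeError branch and falls back to an empty array.

```python
import numpy as np

def normalized_shape_to_int32(normalized_shape):
    # Mirrors the converter's fallback: values that cannot be coerced to an
    # int32 array (e.g. None) raise TypeError and fall back to [].
    try:
        return np.array(normalized_shape, dtype=np.int32)
    except TypeError:
        print(f"Unable to convert normalized_shape with value {normalized_shape}, falling back to []")
        return np.array([], dtype=np.int32)

print(normalized_shape_to_int32([768]))  # [768]
print(normalized_shape_to_int32(None))   # falls back to []
```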
8 changes: 8 additions & 0 deletions py/torch_tensorrt/fx/diagnostics.py
@@ -87,12 +87,14 @@ class DiagnosticsWriter:

def __init__(self):
self._root_dir = tempfile.mkdtemp(prefix="fx2trt.")
self._data = ""
_LOGGER.info(f"Initializing DiagnosticsWriter with root_dir: {self._root_dir}")

def write(self, file_name: str, data: WriteObj):
"""
TODO: Can be disabled by regex on file_name
"""
self._data = data
# Only write if we are inside a collect_when() context.
if not _IS_IN_COLLECT_CONTEXT.get(False):
return
@@ -117,6 +119,9 @@ def write(self, file_name: str, data: WriteObj):
def root_dir(self) -> str:
return self._root_dir

def data(self) -> WriteObj:
return self._data

def _write(self, file_name: str, to_write: bytes):
# ms granularity - no naming collision, otherwise the file will be
# overwritten.
@@ -271,6 +276,9 @@ def collect(self) -> str:
finally:
os.remove(fp)

def data(self) -> WriteObj:
return self._write.data()


def _res_or_err(data: WriteObj) -> t.Tuple[TWrite, str]:
if isinstance(data, (str, bytes)):
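A small usage sketch of the new accessor, based only on the methods visible in this hunk: write() now stashes the payload on the instance before the collect_when() check, so data() can return the last written object even when no collection context is active (in which case nothing is written to disk).

```python
from torch_tensorrt.fx.diagnostics import DiagnosticsWriter

writer = DiagnosticsWriter()
# Outside a collect_when() context, write() returns early after caching the
# payload, and data() exposes that cached payload for tests to assert on.
writer.write("lowering_diagnostics", "serialized graph dump")
print(writer.data())  # -> "serialized graph dump"
```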
6 changes: 6 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -1,4 +1,5 @@
import logging
import os
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
@@ -211,6 +212,11 @@ def run(
builder_config = self.builder.create_builder_config()
builder_config.max_workspace_size = max_workspace_size

# Speed up TRT build time in the test environment
if trt.__version__ >= "8.6" and os.environ.get("TRT_TEST_ENV", "0") == "1":
_LOGGER.info("Set TRT optimization level to 0")
builder_config.builder_optimization_level = 0

cache = None
if timing_cache:
cache_file = numpy.array(timing_cache)
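A hedged sketch of the new test-only knob outside of run(), using standard TensorRT APIs: with TensorRT >= 8.6 and TRT_TEST_ENV=1, the builder optimization level is dropped to 0, which shortens engine build time at the cost of runtime performance. The environment-variable name comes from this diff; everything else is ordinary TensorRT usage.

```python
import os
import tensorrt as trt

os.environ["TRT_TEST_ENV"] = "1"  # opt in to the fast-build path for tests

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
config = builder.create_builder_config()

# Same gate as in run(): only TensorRT 8.6+ exposes builder_optimization_level.
if trt.__version__ >= "8.6" and os.environ.get("TRT_TEST_ENV", "0") == "1":
    config.builder_optimization_level = 0
```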
73 changes: 62 additions & 11 deletions py/torch_tensorrt/fx/input_tensor_spec.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple

import torch

@@ -18,6 +18,12 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
# is the dynamic batch dimension. Otherwise, we use the additional
# inputs to determine the batch dimension.
if additional_inputs is None:
batch_dims = None
if not isinstance(inputs, torch.Tensor) and len(inputs) > 1:
bs = inputs[0].size(0)
batch_dims = None
if not all(x.size(0) == bs for x in inputs):
batch_dims = InputTensorSpec.find_batch_size_dim(inputs)
return InputTensorSpec.from_tensors_with_dynamic_batch_size(
inputs,
(
@@ -26,6 +32,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
lower_setting.max_batch_size,
),
lower_setting.opt_profile_replica,
batch_dims,
)
else:
batch_dims = []
@@ -147,25 +154,69 @@ def from_tensors_with_dynamic_batch_size(
A list of InputTensorSpec named tuples with dynamic ranges.
"""
if batch_dims is None:
batch_dims = [0] * len(tensors)
batch_dims = cls.find_batch_size_dim(tensors)

input_specs = []
batch_size = tensors[0].size(batch_dims[0])

for i, tensor in enumerate(tensors):
batch_dim = batch_dims[i]
assert batch_size == tensor.size(
batch_dim
), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
shape = list(tensor.shape)
shape[batch_dim] = -1
shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item]
input_specs.append(
cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
)
if batch_dim == -1:
input_specs.append(cls.from_tensor(tensor))
else:
shape = list(tensor.shape)
assert batch_size == tensor.size(
batch_dim
), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
shape[batch_dim] = -1
shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica # type: ignore[list-item]
input_specs.append(
cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
)

return input_specs

@classmethod
# pyre-ignore [2]: Parameter `sample_input` must have a type other than `Any`
def find_batch_size_dim(cls, inputs: Any) -> []:
if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
return [0]
shapes = [i.shape for i in inputs]
frequency_map = {}
first_dims = set()
for shape in shapes:
if len(shape) < 2:
# Bypass rank-1 tensors. The MRS model has rank-1 tensors that carry no batch_size info
continue
# Dedup shape values within a single tensor
first_dims.add(shape[0])
shape = set(shape)
for i in shape:
frequency_map[i] = frequency_map.get(i, 0) + 1

if len(first_dims) == 1:
# first dim is the same in every input: we use it as batch_size
batch_size = first_dims.pop()
elif frequency_map:
# first dims are different: we use the most frequent dim as batch_size
sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1])
batch_size = sorted_frequency[0][0]
else:
# no dims to sort: no batch_size
batch_size = -1

bs_dim = []
for i in inputs:
# Default batch size dim = -1, indicating no batch_size
dim = -1
for index, val in enumerate(i.shape):
if val == batch_size:
dim = index
break
bs_dim.append(dim)

return bs_dim

def to_random_tensor(self, id=1):
shape = tuple(self.shape)
if len(get_dynamic_dims(shape)):
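To make the heuristic concrete, here is a small hedged example of how find_batch_size_dim is expected to behave on mixed inputs, assuming the class is importable as shown in this file: a first dimension shared by every rank >= 2 tensor becomes the batch size, and any tensor with no matching dimension (such as a rank-1 tensor) reports -1, meaning no batch dimension.

```python
import torch
from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

inputs = [
    torch.randn(8, 16),           # batch dim 0
    torch.randn(8, 3, 224, 224),  # batch dim 0
    torch.randn(5),               # rank-1: carries no batch size info
]
# All rank >= 2 tensors share a first dim of 8, so 8 is chosen as the batch
# size; the rank-1 tensor has no dim equal to 8 and reports -1.
print(InputTensorSpec.find_batch_size_dim(inputs))  # expected: [0, 0, -1]
```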
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/lower.py
@@ -41,6 +41,8 @@ def compile(
dynamic_batch=True,
is_aten=False,
use_experimental_fx_rt=False,
correctness_atol=1e-1,
correctness_rtol=1e-1,
) -> nn.Module:
"""
Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -81,6 +83,8 @@
dynamic_batch=dynamic_batch,
is_aten=is_aten,
use_experimental_rt=use_experimental_fx_rt,
correctness_atol=correctness_atol,
correctness_rtol=correctness_rtol,
)
lowerer = Lowerer.create(lower_setting=lower_setting)
return lowerer(module, input)
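A minimal usage sketch of the new tolerance arguments, assuming compile() is imported from this module and run on a CUDA machine with TensorRT installed; the model and input here are placeholders. The new keywords are simply threaded through to LowerSetting and relax the accuracy check applied when the lowered module is validated against the original one.

```python
import torch
from torch_tensorrt.fx.lower import compile as fx_compile

model = torch.nn.Sequential(torch.nn.Linear(16, 4), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(2, 16).cuda()]

# Loosen the absolute/relative tolerance used for the correctness check of the
# lowered TensorRT module against the original PyTorch module.
trt_model = fx_compile(
    model,
    inputs,
    correctness_atol=1e-1,
    correctness_rtol=1e-1,
)
```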
46 changes: 41 additions & 5 deletions py/torch_tensorrt/fx/passes/lower_basic_pass.py
@@ -54,10 +54,14 @@ def fill_with_mul_zero_and_add(*args):


def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule:
# Now we do constant folding on traced module. We want to skip pattern like
# weights -> quant -> dequant -> op during constant folding when the model is
# a quantized int8 model.
def skip_folding_quant_dequant(node: torch.fx.Node):
def skip_folding_ops(node: torch.fx.Node):
# dtype op
if node.target == acc_ops.dtype:
return True
# Now we do constant folding on traced module. We want to skip pattern like
# weights -> quant -> dequant -> op during constant folding when the model is
# a quantized int8 model.
# quant_dequant
if node.target != acc_ops.quantize_per_tensor:
return False
# If quantize_per_node -> dequantize, then skip folding.
@@ -66,7 +70,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
return True
return False

const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
const_split_mod = split_const_subgraphs(traced_mod, skip_folding_ops)
const_split_mod.run_folding()
return const_split_mod

@@ -630,3 +634,35 @@ def fix_clamp_numerical_limits_to_fp16(

mod.recompile()
return mod


@log_before_after
@validate_inference(atol=1e-3, rtol=1e-2)
def remove_dtype_and_to_pattern(
mod: torch.fx.GraphModule, input: Input
) -> torch.fx.GraphModule:
"""
Remove this pattern since it is unnecessary to cast to dtype
%dtype : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.dtype](args = (), kwargs = {input: %_attention_layers_0__uva})
%to_18 : [#users=2] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.to_dtype](args = (), kwargs = {input: %x})
"""
for node in mod.graph.nodes:
if node.op == "call_function" and node.target == acc_ops.dtype:
# find its first user
next_node = next(iter(node.users))
# acc ops carry the input in kwargs; pt ops pass it positionally
input = (
next_node.kwargs["input"]
if "input" in next_node.kwargs
else next_node.args[0]
)
if len(node.users) == 1 and (
next_node.target == acc_ops.to_dtype or next_node.target == "to"
):
next_node.replace_all_uses_with(input)
mod.graph.erase_node(next_node)
mod.graph.erase_node(node)

mod.graph.eliminate_dead_code()
mod.recompile()
return mod
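This is not the acc_ops pass itself, but a simplified torch.fx analogue of the same rewrite, to show the mechanics: find a dtype query whose only user is a cast back to that dtype, splice the pair out, and let dead-code elimination clean up.

```python
import torch
import torch.fx

class RedundantCast(torch.nn.Module):
    def forward(self, x):
        # Query the dtype only to cast right back to it: a no-op pair that is
        # the plain-PyTorch analogue of the acc_ops dtype -> to_dtype pattern.
        return torch.relu(x.to(x.dtype))

gm = torch.fx.symbolic_trace(RedundantCast())

for node in list(gm.graph.nodes):
    # `x.dtype` traces as call_function(getattr, (x, "dtype")); `.to(...)` is a
    # call_method node that consumes it.
    if node.op == "call_function" and node.target is getattr and node.args[1] == "dtype":
        users = list(node.users)
        if len(users) == 1 and users[0].op == "call_method" and users[0].target == "to":
            to_node = users[0]
            to_node.replace_all_uses_with(to_node.args[0])  # reuse the original input
            gm.graph.erase_node(to_node)
            gm.graph.erase_node(node)

gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.graph)  # the getattr/to pair is gone; relu consumes x directly
```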
7 changes: 5 additions & 2 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -8,6 +8,7 @@
from torch.fx.passes.pass_manager import inplace_wrapper, PassManager
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion
from torch_tensorrt.fx.utils import LowerPrecision

from ..input_tensor_spec import generate_input_specs
@@ -229,10 +230,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
submod = getattr(split_result.split_module, submod_name)

LOWER_SPLIT_PRE_OBSERVER.observe(submod_name, submod, submod_inputs)

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
_LOGGER.info(f"Now lowering submodule {submod_name}")
_LOGGER.info(f"ACC submodule graph: {submod.graph}")
lowering_start_time = datetime.datetime.now()

self.lower_setting.additional_inputs = (
Expand All @@ -251,6 +251,9 @@ def lower_func(split_result: SplitResult) -> nn.Module:
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
)
else:
_LOGGER.info(f"GPU submodule graph: {submod.graph}")
apply_bfloat_float_conversion(submod, submod_inputs, submod_name)

return split_result.split_module
