Changes done internally at Facebook (#1208)
6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.prod
c822345d6d673e1653c2208435e34ab400bada3d Jason Park <jasonjk@fb.com> Add support for generic torch ops to be used in training.
e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for repeat interleave
c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for reduce ops
863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_op.convolution
68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao <xiaoruichao@fb.com> [fbcode][GPU][DHEN]fuse split squeeze cat as reshape
f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao <xiaoruichao@fb.com> [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat
5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or <andrewor@fb.com> [Quant][fx] Rename convert_to_reference to convert_to_reference_fx
996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_op.expand
084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_op.to_dtype
b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for std
a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_op.tile
3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for squeeze
09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_op.topk
65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li <huaminli@fb.com> temporarily skip gelu tests
d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu <shirong@fb.com> Suppress accuracy check for remove_reshape_with_batch_size_change
6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla <ankursingla@fb.com> [GPULowering] Suppress accuracy check for fuse_unsqueeze_cat_sum
e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang <qxy11@fb.com> Lower xrayvideo2022 to fx2trt
433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 <generatedunixname89002005287564@fb.com> [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2
66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok <swolchok@fb.com> [fx2ait] Minor Python cleanup in acc_ops_getitem
188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 <generatedunixname89002005324833@fb.com> [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT`
4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei <wwei6@fb.com> [fx2trt] support sub
064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.interpolate
9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for unary_ops
39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for unsqueeze
2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.split
64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu <shirong@fb.com> Group LN trt plugin
438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.avgpool
df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops masked fill
44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.pad
4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei <wwei6@fb.com> [fx2trt] torch.max dynamic shape test
bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati <shreyanshp@fb.com> Change the name of the test from full_reduce to dim_reduce
1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.type_as
33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati <shreyanshp@fb.com> Test dynamic shape support for acc_ops.min
f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei <wwei6@fb.com> [fx2trt] plugin for grid_sample
57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269 <generatedunixname89002005367269@fb.com> [AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK`
eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati <shreyanshp@fb.com> Updated the dynamic shape support for narrow op
521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati <shreyanshp@fb.com> Removing the comment for 4 dims dynamic shape support after analysis
e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati <shreyanshp@fb.com> Updated the pad test for dynamic batch for analysis
3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov <khabinov@fb.com> [trt_bc] Some improvements
dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati <shreyanshp@fb.com> Updated the test for as_strided op for analysis
11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang <bangsheng@fb.com> [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm
932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei <wwei6@fb.com> [fx2trt] bridge the dynamic batch and fixed shape
f911463393d8a671cfee6de6d1b5ef4d4f3991a6 Shirong Wu <shirong@fb.com> group swish LN plugin
ea65970f23dd7a468e5bc43240f2a9bfa07c9b3b Shirong Wu <shirong@fb.com> Create backend specific lower pass
38183e4a724e5514db2be7193cf4897b59759252 Alex Beloi <alexbeloi@fb.com> [fx] run acc_linter.lint in acc_tracer.trace
088abb6a790a62ca9f8515298a54117cc7fa31d4 Alex Beloi <alexbeloi@fb.com> [fx] re-add pointwise property to acc_ops.clamp
9905c34f2bd28e9b64f10336f9ac326cc39eb60d Oleg Khabinov <khabinov@fb.com> [trt] Comment out torch.ops.fbgemm dependency in TRT converters
8252e779476d2ff22ad78185af97a526b2f70fe3 Alex Beloi <alexbeloi@fb.com> [fx] add operator test suite to test_acc_tracer.py
7b93a89c903bc0b6c59efb73a510c3dce8ef793a Shirong Wu <shirong@fb.com> Add option for lower and trt_splitter
e08dabcbcd8c3e8ae92484e14cf07bb26993a8d6 Wei Wei <wwei6@fb.com> [fx2trt] convert print to logging
3d61dc169b8a7dd1aecad35891a628e44e2c5a02 Shreyansh Prajapati <shreyanshp@fb.com> Readme.md file for dynamic shape support
Wei authored Jul 27, 2022
1 parent 2f896b3 commit 515b9b9
Showing 19 changed files with 276 additions and 41 deletions.
137 changes: 137 additions & 0 deletions py/torch_tensorrt/fx/Dynamic_Shape_Support.md
@@ -0,0 +1,137 @@
# PyTorch Operations Dynamic Shape Support Summary



| Operation | Test Method | Supports Dynamic Shape | Shape | Num of dynamic dimensions | Reason |
| --- | --- | --- | --- | --- | --- |
| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims. |
| any | | no | | | The converter calls torch.zeros(tuple(\[*input_t.shape\])), which raises: RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 |
| avg_pool | avg_pool2d | yes | (-1, -1, -1, -1) | 4 | |
| | avg_pool1d | partially | (-1, 3, 3) | 1 | |
| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." |
| binary_ops | | yes | (-1, -1, -1, -1) | 4 | |
| cat | | yes | (-1, -1, -1, -1) | 4 | |
| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! |
| clamp | | yes | (-1, -1, -1, -1) | 4 | |
| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. |
| | conv1d | partially | (-1, 3, 3) | 1 | |
| | conv3d | partially | (-1, 3, -1, -1, -1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. |
| dequantize | | yes | (-1, -1, -1, -1) | 4 | |
| einsum | | yes | (-1, -1, -1, -1) | 4 | |
| elu | | yes | (-1, -1, -1, -1) | 4 | |
| embedding | | yes | (-1, -1, -1, -1) | 4 | |
| eq | SimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperatorConstant | partially | (3,-1) | 1 | |
| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| expand | | no | | | Dynamic shape is not suitable for the expand operation. |
| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | |
| gelu | | yes | (-1, -1, -1, -1) | 4 | |
| getitem | | yes | (-1, -1, -1, -1) | 4 | |
| gt | EqOperatorSimpleConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| hardsigmoid | | yes | (-1, -1, -1, -1) | 4 | |
| hardtanh | | yes | (-1, -1, -1, -1) | 4 | |
| interpolate | | yes | (-1, -1, -1, -1) | 4 | |
| isinf | | yes | (-1, -1, -1, -1) | 4 | |
| leaky_relu | | yes | (-1, -1, -1, -1) | 4 | |
| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim. |
| logical_and | | yes | (-1, -1, -1, -1) | 4 | |
| logical_or | | yes | (-1, -1, -1, -1) | 4 | |
| logical_xor | | yes | (-1, -1, -1, -1) | 4 | |
| lt | | yes | (-1, -1, -1, -1) | 4 | |
| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] |
| mat_mul | | yes | batch dim | | |
| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | |
| maximum | | yes | (-1, -1, -1, -1) | 4 | |
| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | The shape is not set to (-1, -1, -1) because a reshape dimension with more than one -1 wildcard is not allowed when adding the unsqueeze layer. |
| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | |
| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | |
| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | |
| | MinMethod | yes | (-1, -1, -1, -1) | 4 | |
| minimum | | yes | (-1, -1, -1, -1) | 4 | |
| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! |
| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | |
| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | |
| | NeOperatorConstantConverter | partially | (3, -1) | 1 | |
| new_ones | | yes | (-1, -1, -1, -1) | 4 | |
| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. |
| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. |
| permute | | yes | (-1, -1, -1, -1) | 4 | |
| prod | | yes | (-1, -1, -1, -1) | 4 | |
| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | |
| reduce op | | yes | (-1, -1, -1, -1) | 4 | |
| relu | | yes | (-1, -1, -1, -1) | 4 | |
| repeat interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
| reshape | | yes | (-1, -1, -1, -1) | 4 | |
| selu | | yes | (-1, -1, -1, -1) | 4 | |
| sigmoid | | yes | (-1, -1, -1, -1) | 4 | |
| silu | | yes | (-1, -1, -1, -1) | 4 | |
| size | | yes | (-1, -1, -1, -1) | 4 | |
| softmax | | yes | (-1, -1, -1, -1) | 4 | |
| softsign | | yes | (-1, -1, -1, -1) | 4 | |
| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! |
| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. |
| std | | yes | (-1, -1, -1, -1) | 4 | |
| tanh | | yes | (-1, -1, -1, -1) | 4 | |
| tile | | yes | (-1, -1, -1, -1) | 4 | |
| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | |
| | float | yes | (-1, -1, -1, -1) | 4 | |
| topk | | yes | (-1, -1, -1, -1) | 4 | |
| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | |
| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | |
| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} |
| unary ops | | yes | (-1, -1, -1, -1) | 4 | |
| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
| where | | no | limitation in converter | | torch.broadcast_shape can not handle -1 dimension in shape \[-1, 2, 2\] |



Binary Ops include the following operations:
|Binary Ops |
|----------|
|add |
|sub |
|div |
|mul |
|floor_div |
|fmod |
|floor_divide|
|pow |


Unary Ops include the following operations:
|Unary Ops |
|----------|
|rsqrt |
|sin |
|cos |
|tan |
|sinh |
|cosh |
|asin |
|acos |
|atan |
|abs |
|neg |
|reciprocal|
|sqrt |
|log |
|exp |
|floor |
|ceil |
|sign |

Note: For more information about the test method, please refer to the operation test files. Additionally, test files include information about errors encountered during dynamic shape testing.
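As context for the table above (not part of this diff): a dynamic dimension is marked as -1 in the input spec, and explicit (min, opt, max) shape ranges are supplied for the optimization profile. A minimal sketch of building such a spec with the fx path's `InputTensorSpec`, assuming the API at the time of this commit (exact fields and entry points may differ across versions):

```python
import torch
from torch_tensorrt.fx import InputTensorSpec

# Dynamic batch and spatial dims, matching the "(-1, 3, -1, -1)" convention above;
# the concrete min/opt/max values here are illustrative only.
input_spec = InputTensorSpec(
    shape=(-1, 3, -1, -1),
    dtype=torch.float32,
    shape_ranges=[((1, 3, 128, 128), (4, 3, 224, 224), (8, 3, 512, 512))],
)

# The spec is then consumed by the fx lowering path, e.g. via
# LowerSetting.input_specs or TRTInterpreter(mod, input_specs=[input_spec],
# explicit_batch_dimension=True).
```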
10 changes: 8 additions & 2 deletions py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -1,4 +1,5 @@
# flake8: noqa
import logging
import math
import operator
import warnings
@@ -22,6 +23,9 @@
from .converter_utils import * # noqa: F403


_LOGGER: logging.Logger = logging.getLogger(__name__)


@tensorrt_converter(acc_ops.conv1d)
def acc_ops_conv1d(
network: TRTNetwork,
@@ -641,7 +645,7 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
try:
normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32)
except TypeError:
print("Unable to convert normalized_shape to a field, fall back to []")
_LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
normalized_shape = np.array([], dtype=np.int32)

normalized_shape_filed = trt.PluginField(
@@ -657,7 +661,9 @@ def acc_ops_layer_norm(network, target, args, kwargs, name):
else:
plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
except AssertionError:
print("Unable to find layer norm plugin, fall back to TensorRT implementation.")
_LOGGER.error(
"Unable to find layer norm plugin, fall back to TensorRT implementation."
)
return layer_norm(network, target, args, kwargs, name)
layer = network.add_plugin_v2([input_val], plugin)
layer.name = name
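Because these messages now go through the standard `logging` module instead of `print`, they only appear when logging is configured. A minimal, assumed setup for surfacing them (standard library only; the logger name follows the module hierarchy of the files changed here):

```python
import logging

# Emit INFO-level records and enable the fx2trt loggers created with
# logging.getLogger(__name__) in the modules touched by this commit.
logging.basicConfig(level=logging.INFO)
logging.getLogger("torch_tensorrt.fx").setLevel(logging.INFO)
```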
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/lower.py
@@ -197,6 +197,7 @@ def create(
cls,
lower_setting: LowerSetting,
interpreter_builder: Callable = create_lower_trt_interpreter,
split_func: Callable = default_split_function,
) -> "Lowerer":
"""Instantiate a `Lowerer` instance."""

@@ -209,7 +210,7 @@ def create(
ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list,
leaf_module_list=lower_setting.leaf_module_list,
),
split_func=default_split_function,
split_func=split_func,
lower_func=default_lower_pass(interpreter_builder),
)
)
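The new `split_func` parameter makes the splitter pluggable when constructing a `Lowerer`. A hypothetical sketch of supplying a backend-specific split function; the callable's signature is assumed to mirror `default_split_function`, and `LowerSetting()` is used with its defaults:

```python
from torch_tensorrt.fx.lower import Lowerer
from torch_tensorrt.fx.lower_setting import LowerSetting


def my_split_func(model, inputs, lower_setting):
    # Backend-specific splitting logic would go here; signature assumed to
    # match default_split_function(model, inputs, lower_setting).
    ...


lowerer = Lowerer.create(
    lower_setting=LowerSetting(),
    split_func=my_split_func,  # omit to fall back to default_split_function
)
```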
13 changes: 9 additions & 4 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -1,4 +1,5 @@
import datetime
import logging
from functools import partial, wraps
from typing import Any, Callable, Optional, Sequence

@@ -17,6 +18,10 @@

from .lower_basic_pass import run_const_fold


_LOGGER: logging.Logger = logging.getLogger(__name__)


Input = Sequence[Any]


@@ -143,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
print("Now lowering submodule", submod_name)
_LOGGER.info("Now lowering submodule", submod_name)
lowering_start_time = datetime.datetime.now()

self.lower_setting.input_specs = generate_input_specs(
@@ -160,7 +165,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
LOWER_SPLIT_POST_OBSERVER.observe(
submod_name, lowered_module, submod_inputs
)
print(
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time",
datetime.datetime.now() - lowering_start_time,
)
@@ -179,7 +184,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
print("Now lowering submodule", submod_name)
_LOGGER.info("Now lowering submodule", submod_name)
lowering_start_time = datetime.datetime.now()

lowered_module = self._lower_func(
@@ -189,7 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
LOWER_SPLIT_POST_OBSERVER.observe(
submod_name, lowered_module, submod_inputs
)
print(
_LOGGER.info(
f"Lowering submodule {submod_name} elapsed time",
datetime.datetime.now() - lowering_start_time,
)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/passes/pass_utils.py
@@ -116,7 +116,7 @@ def pass_with_before_after_log(
encoding="utf-8",
delete=False,
) as f:
print(f"== Log pass {pass_} before/after graph to {f.name}")
_LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}")
print(f"[{pass_}] Before:\n{module.graph}", file=f)
module = pass_(module, input)
print(f"[{pass_}] After:\n{module.graph}", file=f)
12 changes: 8 additions & 4 deletions py/torch_tensorrt/fx/test/passes/test_graph_opts.py
@@ -1,3 +1,4 @@
import logging
import unittest
from collections import Counter
from typing import Callable, Dict, List
@@ -8,13 +9,16 @@
from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination


_LOGGER: logging.Logger = logging.getLogger(__name__)


def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None:
"""
Helper func to print model's graph in plain and tabular format, also print code.
"""
print(mod_graph.graph)
_LOGGER.info(mod_graph.graph)
mod_graph.graph.print_tabular()
print(mod_graph.code)
_LOGGER.info(mod_graph.code)


@torch.fx.wrap
@@ -46,7 +50,7 @@ def _test_opt_with_module(
before_results = module(*inputs)
mod_traced = acc_tracer.trace(module, inputs)
before_node_list = list(mod_traced.graph.nodes)
print("Model before opt.")
_LOGGER.info("Model before opt.")
debug_print_graph_module(mod_traced)

# Apply Opt
@@ -55,7 +59,7 @@ def _test_opt_with_module(
# After Opt
after_results = mod_traced(*inputs)
after_node_list = list(mod_traced.graph.nodes)
print("Model after opt.")
_LOGGER.info("Model after opt.")
mod_traced.recompile()
debug_print_graph_module(mod_traced)

8 changes: 5 additions & 3 deletions py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
@@ -1,5 +1,5 @@
# Owner(s): ["oncall: fx"]

import logging
import unittest
from typing import Callable, List

@@ -16,6 +16,8 @@

torch.manual_seed(0)

_LOGGER: logging.Logger = logging.getLogger(__name__)


class AccTracerTest(unittest.TestCase):
def _make_model_unit_test(
@@ -258,7 +260,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8
)
traced = acc_tracer.trace(m, [input])
print(traced.graph)
_LOGGER.info(traced.graph)
ph = weight_attr = bias_attr = conv = None
for node in traced.graph.nodes:
if node.op == "placeholder":
@@ -626,7 +628,7 @@ def run_embedding_bag_test(is_4bit, use_weights):
)

traced = acc_tracer.trace(m, inputs)
print(traced.graph)
_LOGGER.info(traced.graph)

expected_target = (
acc_ops.embedding_bag_4bit_rowwise_offsets
10 changes: 7 additions & 3 deletions py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: gpu_enablement"]
import functools
import glob
import logging
import os
import shutil
import tempfile
@@ -10,6 +11,9 @@
import torch_tensorrt.fx.diagnostics as diag


_LOGGER: logging.Logger = logging.getLogger(__name__)


def reset_diag(fn):
@functools.wraps(fn)
def reset(*a, **kw):
@@ -53,7 +57,7 @@ def boom() -> str:
zip_fn = collector._last_zip_path_for_test
assert os.path.exists(zip_fn)
with tempfile.TemporaryDirectory() as tempdir:
print(f"Unpacking into {tempdir}")
_LOGGER.info(f"Unpacking into {tempdir}")
shutil.unpack_archive(zip_fn, tempdir)
_check_file(tempdir, "aaa", "hello")
_check_file(tempdir, "bbb", "world")
@@ -78,7 +82,7 @@ def test_condition_func_name(self):
zip_fn = collector._last_zip_path_for_test
assert os.path.exists(zip_fn)
with tempfile.TemporaryDirectory() as tempdir:
print(f"Unpacking into {tempdir}")
_LOGGER.info(f"Unpacking into {tempdir}")
shutil.unpack_archive(zip_fn, tempdir)
_check_file(tempdir, "aaa", "hello")

@@ -160,7 +164,7 @@ def _test_cond(
if should_collect:
assert os.path.exists(zip_fn)
with tempfile.TemporaryDirectory() as tempdir:
print(f"Unpacking into {tempdir}")
_LOGGER.info(f"Unpacking into {tempdir}")
shutil.unpack_archive(zip_fn, tempdir)
_check_file(tempdir, "aaa", "hello")
else: