Skip to content

Commit

Permalink
Changes done internally at Facebook (#1299)
Browse files Browse the repository at this point in the history
bd46e8f292bf68fe6b87d2d5d206c89fda79a746 Shirong Wu <shirong@fb.com> Disable group ln fuse pass
6ce1d3bc19d75b266e99355c96daeff7054dcbf8 Wei Wei <wwei6@fb.com> [fx2trt] set logging level to INFO at fx root
9d552dc3f69db9e4a249f80ef00803a9413e5d38 Wei Wei <wwei6@fb.com> [fx2trt] change OSS method lower_to_trt() to compile()
6c4bdb8ac5823d161e4afc7c9d295f961aeeb0bf Mor Tzur <mortzur@fb.com> fix engine holder test binary to fix contbuild_pytorch_fx2trt_build
636d0ab2a3d0f09267e25b8b8e7eedd4d91d791d Yinghai Lu <yinghai@fb.com> [easy] remove random prints
5a97668307c26e69a89a4e02a535e319eaf3ce3d Wei Wei <wwei6@fb.com> [ads] sequential linear fuse
508338ab343e407ee49605919508210b62ad9a52 Wei Wei <wwei6@fb.com> [fx2trt] minor literal fix
  • Loading branch information
Wei authored Aug 22, 2022
1 parent d7fd691 commit 6e467f2
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
1 change: 0 additions & 1 deletion py/torch_tensorrt/fx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@
from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
from .lower_setting import LowerSetting # noqa
from .trt_module import TRTModule # noqa
from .lower import compile

logging.basicConfig(level=logging.INFO)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def compile(
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
save_timing_cache: Update timing cache with current timing cache data if set to True.
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
dynamic_batch: batch dimension (dim=0) is dynamic.
Returns:
A torch.nn.Module lowered by TensorRT.
"""
Expand Down
47 changes: 47 additions & 0 deletions py/torch_tensorrt/fx/passes/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,53 @@ def bounded_method(*args, **kwargs):
return dec_for_method


def log_perf_before_after(pass_: PassFunc) -> PassFunc:
    """
    Wraps a pass function to log perf of the module before and after the pass

    Runs the module 100 times in eager mode before applying the pass, applies
    the pass, runs it 100 more times, and logs both average per-call timings.
    """

    @wraps(pass_)
    def check_perf_with_before_after_log(
        module: fx.GraphModule, input: Input
    ) -> fx.GraphModule:
        def benchmark_torch_function(iters: int, f, *args) -> float:
            """Estimates the average time duration for a single inference call in seconds

            If the input is batched, then the estimation is for the batches
            inference call.

            Args:
                iters: number of inference iterations to run
                f: a function to perform a single inference call
                *args: positional arguments forwarded to ``f``

            Returns:
                estimated average time duration in seconds for a single
                inference call
            """
            with torch.inference_mode():
                # Warm-up call so one-time costs (kernel compilation, cudnn
                # autotuning, lazy init) are excluded from the measurement.
                f(*args)
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            with torch.inference_mode():
                start_event.record()
                for _ in range(iters):
                    f(*args)
                end_event.record()
            torch.cuda.synchronize()
            # elapsed_time() reports milliseconds; convert to seconds/iteration.
            return (start_event.elapsed_time(end_event) * 1.0e-3) / iters

        time_before = benchmark_torch_function(100, lambda: module(*input))
        _LOGGER.info(f"[{pass_}] Perf Before(eager mode): {time_before}")

        module = pass_(module, input)
        time_after = benchmark_torch_function(100, lambda: module(*input))
        _LOGGER.info(f"[{pass_}] Perf After(eager mode): {time_after}")
        return module

    return check_perf_with_before_after_log


def log_before_after(pass_: PassFunc) -> PassFunc:
"""
Wraps a pass function to log the module graph before and after the pass
Expand Down
4 changes: 3 additions & 1 deletion py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,9 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):

assert callable(stochastic_depth)
except Exception as e:
warnings.warn(f"Unable to import torchvision related libraries.: {e}")
warnings.warn(
f"Unable to import torchvision related libraries.: {e}. Please install torchvision lib in order to lower stochastic_depth"
)
else:

@register_custom_acc_mapper_fn(
Expand Down

0 comments on commit 6e467f2

Please sign in to comment.