Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[FX] Changes done internally at Facebook #1299

Merged
merged 1 commit into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion py/torch_tensorrt/fx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@
from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
from .lower_setting import LowerSetting # noqa
from .trt_module import TRTModule # noqa
from .lower import compile

logging.basicConfig(level=logging.INFO)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def compile(
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
save_timing_cache: Update timing cache with current timing cache data if set to True.
cuda_graph_batch_size: Cuda graph batch size, default to be -1.

dynamic_batch: batch dimension (dim=0) is dynamic.
Returns:
A torch.nn.Module lowered by TensorRT.
"""
Expand Down
47 changes: 47 additions & 0 deletions py/torch_tensorrt/fx/passes/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,53 @@ def bounded_method(*args, **kwargs):
return dec_for_method


def log_perf_before_after(pass_: PassFunc) -> PassFunc:
    """
    Wraps a pass function to log the perf of the module before and after the pass.

    Args:
        pass_: a pass function taking (module, input) and returning the
            transformed module.

    Returns:
        The wrapped pass function; behavior is unchanged except that eager-mode
        inference latency is measured and logged before and after the pass runs.
    """

    @wraps(pass_)
    def check_perf_with_before_after_log(
        module: fx.GraphModule, input: Input
    ) -> fx.GraphModule:
        def benchmark_torch_function(iters: int, f, *args) -> float:
            """Estimates the average time duration for a single inference call in seconds.

            If the input is batched, then the estimation is for the batched inference call.

            Args:
                iters: number of inference iterations to run
                f: a function to perform a single inference call

            Returns:
                estimated average time duration in seconds for a single inference call
            """
            # Warmup call so one-time costs (lazy init, autotuning, caching)
            # don't pollute the timed loop.
            with torch.inference_mode():
                f(*args)
            torch.cuda.synchronize()
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            with torch.inference_mode():
                start_event.record()
                for _ in range(iters):
                    f(*args)
                end_event.record()
            torch.cuda.synchronize()
            # elapsed_time() reports milliseconds; convert to seconds and
            # average over the iterations.
            return (start_event.elapsed_time(end_event) * 1.0e-3) / iters

        time_before = benchmark_torch_function(100, lambda: module(*input))
        _LOGGER.info(f"[{pass_}] Perf Before(eager mode): {time_before}")

        module = pass_(module, input)
        time_after = benchmark_torch_function(100, lambda: module(*input))
        _LOGGER.info(f"[{pass_}] Perf After(eager mode): {time_after}")
        return module

    return check_perf_with_before_after_log


def log_before_after(pass_: PassFunc) -> PassFunc:
"""
Wraps a pass function to log the module graph before and after the pass
Expand Down
4 changes: 3 additions & 1 deletion py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,9 @@ def dropout_mapper(node: torch.fx.Node, mod: nn.Module):

assert callable(stochastic_depth)
except Exception as e:
warnings.warn(f"Unable to import torchvision related libraries.: {e}")
warnings.warn(
f"Unable to import torchvision related libraries.: {e}. Please install torchvision lib in order to lower stochastic_depth"
)
else:

@register_custom_acc_mapper_fn(
Expand Down