[FX] Changes done internally at Facebook #1221

Merged (1 commit) on Aug 4, 2022
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/fx2trt.py
@@ -163,6 +163,7 @@ def run(
         algorithm_selector=None,
         timing_cache=None,
         profiling_verbosity=None,
+        tactic_sources=None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -245,6 +246,9 @@ def run(
             builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
             builder_config.algorithm_selector = algorithm_selector

+        if tactic_sources is not None:
+            builder_config.set_tactic_sources(tactic_sources=tactic_sources)
+
         engine = self.builder.build_engine(self.network, builder_config)
         assert engine

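Note (not part of the diff): the new tactic_sources argument is the integer bitmask that TensorRT's IBuilderConfig.set_tactic_sources expects, with one bit per trt.TacticSource flag. A minimal sketch of building such a mask directly against the TensorRT Python API; the CUBLAS/CUDNN flag names assume TensorRT 8.x:

import tensorrt as trt

# Each TacticSource enum value names a bit position in the mask; OR together
# the bits for the kernel libraries TensorRT may pick tactics from.
tactic_sources = (1 << int(trt.TacticSource.CUBLAS)) | (1 << int(trt.TacticSource.CUDNN))

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
builder_config = builder.create_builder_config()
builder_config.set_tactic_sources(tactic_sources=tactic_sources)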
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/lower.py
@@ -120,6 +120,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             profiling_verbosity=trt.ProfilingVerbosity.DETAILED
             if self.lower_setting.verbose_profile
             else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            tactic_sources=self.lower_setting.tactic_sources,
         )

         # Update timing cache file if needed
5 changes: 4 additions & 1 deletion py/torch_tensorrt/fx/lower_setting.py
@@ -1,5 +1,5 @@
 import dataclasses as dc
-from typing import List, Optional, Sequence, Set, Type
+from typing import List, Optional, Set, Type

 from torch import nn
 from torch.fx.passes.pass_manager import PassManager
@@ -68,6 +68,8 @@ class LowerSetting(LowerSettingBasic):
     opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is
     only used by explicit batch dim with dynamic shape mode.
     dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
+    tactic_sources: tactic sources for TensorRT kernel selection. Default to None,
+    meaning all possible tactic sources.
     """

     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -87,3 +89,4 @@ class LowerSetting(LowerSettingBasic):
     preset_lowerer: str = ""
     opt_profile_replica: int = 1
     dynamic_batch: bool = True
+    tactic_sources: Optional[int] = None
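Note (not part of the diff): a hedged usage sketch of the new field through the FX lowering settings, assuming the remaining LowerSetting fields keep the defaults shown above and TensorRT 8.x TacticSource names:

import tensorrt as trt
from torch_tensorrt.fx.lower_setting import LowerSetting

# None (the default) leaves every tactic source enabled; an int bitmask
# restricts TensorRT kernel selection to the listed libraries.
setting = LowerSetting(
    tactic_sources=(1 << int(trt.TacticSource.CUBLAS)) | (1 << int(trt.TacticSource.CUDNN)),
)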
10 changes: 4 additions & 6 deletions py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -148,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                    _LOGGER.info("Now lowering submodule", submod_name)
+                    _LOGGER.info(f"Now lowering submodule {submod_name}")
                     lowering_start_time = datetime.datetime.now()

                     self.lower_setting.input_specs = generate_input_specs(
@@ -166,8 +166,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                         submod_name, lowered_module, submod_inputs
                     )
                     _LOGGER.info(
-                        f"Lowering submodule {submod_name} elapsed time",
-                        datetime.datetime.now() - lowering_start_time,
+                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                     )

             return split_result.split_module
@@ -184,7 +183,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:

                 # Only acc submodules will be lowered.
                 if not submod_name.startswith(split_result.non_acc_submodule_prefix):
-                    _LOGGER.info("Now lowering submodule", submod_name)
+                    _LOGGER.info(f"Now lowering submodule {submod_name}")
                     lowering_start_time = datetime.datetime.now()

                     lowered_module = self._lower_func(
@@ -195,8 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                         submod_name, lowered_module, submod_inputs
                     )
                     _LOGGER.info(
-                        f"Lowering submodule {submod_name} elapsed time",
-                        datetime.datetime.now() - lowering_start_time,
+                        f"Lowering submodule {submod_name} elapsed time {datetime.datetime.now() - lowering_start_time}"
                     )

             return split_result.split_module
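Note (not part of the diff): the logger changes above fix calls that passed the submodule name as an extra positional argument. The logging module treats such arguments as %-format parameters, and since the message has no placeholder the record fails to format at emit time. A small standalone illustration; the submodule name here is made up:

import logging

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)
submod_name = "_run_on_acc_0"

# Old form: submod_name is interpreted as a %-format argument with no matching
# placeholder, so the handler prints "--- Logging error ---" instead of the name.
_LOGGER.info("Now lowering submodule", submod_name)

# Fixed form: interpolate first, then log.
_LOGGER.info(f"Now lowering submodule {submod_name}")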
2 changes: 1 addition & 1 deletion py/torch_tensorrt/fx/tools/trt_splitter.py
@@ -49,7 +49,7 @@ def __init__(self):
         # During split, we'll split out the operators that
         # don't support the batch dim.
         self.use_implicit_batch_dim: bool = True
-        self.exclude_support_node_name: set = set()
+        self.exclude_support_node_name: set = set(self.op_lowering_disallow_list)


 class TRTSplitter(splitter_base._SplitterBase):