Skip to content

Commit

Permalink
enable direct call to fx.compile() (#1344)
Browse files Browse the repository at this point in the history
* enable direct call to fx.compile()

* Update lower_example.py

* Update _compile.py
  • Loading branch information
Wei committed Sep 10, 2022
1 parent bfbaebe commit 3ce97fd
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/fx/lower_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch
import torchvision
from torch_tensorrt.fx.lower import compile
from torch_tensorrt.fx import compile
from torch_tensorrt.fx.utils import LowerPrecision


Expand Down
3 changes: 1 addition & 2 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from enum import Enum

import torch_tensorrt.fx
import torch_tensorrt.fx.lower
from torch_tensorrt.fx.utils import LowerPrecision


Expand Down Expand Up @@ -140,7 +139,7 @@ def compile(
else:
raise ValueError(f"Precision {enabled_precisions} not supported on FX")

return torch_tensorrt.fx.lower.compile(
return torch_tensorrt.fx.compile(
module,
inputs,
lower_precision=lower_precision,
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
from .lower_setting import LowerSetting # noqa
from .trt_module import TRTModule # noqa
from .lower import compile # usort: skip #noqa

logging.basicConfig(level=logging.INFO)

0 comments on commit 3ce97fd

Please sign in to comment.