Skip to content

Commit

Permalink
load aten ops on-demand
Browse files Browse the repository at this point in the history
  • Loading branch information
FindHao committed Oct 11, 2024
1 parent bd71eee commit 766027f
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions torchbenchmark/operator_loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,19 @@

from .operator_inp_utils import aten, OperatorInputsLoader, to_channels_last

timm_loader = OperatorInputsLoader.get_timm_loader()
huggingface_loader = OperatorInputsLoader.get_huggingface_loader()
torchbench_loader = OperatorInputsLoader.get_torchbench_loader()
# Module-level loader singletons. Kept as None at import time so that the
# (expensive) OperatorInputsLoader construction is deferred until a caller
# actually needs operator inputs — see maybe_load_operator_inputs_loader().
timm_loader = None
huggingface_loader = None
torchbench_loader = None


def maybe_load_operator_inputs_loader():
    """Lazily construct the module-level operator-input loaders.

    Each loader is only built the first time it is needed; any loader that
    is already populated (not None) is left untouched, so repeated calls
    are cheap no-ops after the first successful initialization.
    """
    global timm_loader, huggingface_loader, torchbench_loader
    module_state = globals()
    # Map each singleton's name to the factory that creates it, then fill
    # in only the ones that have not been created yet.
    factories = {
        "timm_loader": OperatorInputsLoader.get_timm_loader,
        "huggingface_loader": OperatorInputsLoader.get_huggingface_loader,
        "torchbench_loader": OperatorInputsLoader.get_torchbench_loader,
    }
    for loader_name, make_loader in factories.items():
        if module_state[loader_name] is None:
            module_state[loader_name] = make_loader()


def parse_args(extra_args: Optional[List[str]] = None):
Expand All @@ -36,6 +46,7 @@ def list_operators() -> List[str]:
"""In the original operator benchmark design, all operators are registered in the
operator loader. We need to collect them here.
"""
maybe_load_operator_inputs_loader()
all_ops = (
list(timm_loader.get_all_ops())
+ list(huggingface_loader.get_all_ops())
Expand Down Expand Up @@ -130,6 +141,7 @@ def dynamically_create_aten_op_class(op_eval: OpOverload):
"""
To keep same with custom operators, we dynamically create aten operator classes here.
"""
maybe_load_operator_inputs_loader()
class_name = f"aten_{str(op_eval).replace('.', '_')}"
module_name = f"torchbenchmark.operator_loader.{class_name}"
# create a new module for each operator
Expand Down

0 comments on commit 766027f

Please sign in to comment.