diff --git a/torchbenchmark/operators/op_task.py b/torchbenchmark/operators/op_task.py
new file mode 100644
index 0000000000..6b8d1a1ad3
--- /dev/null
+++ b/torchbenchmark/operators/op_task.py
@@ -0,0 +1,182 @@
+from torchbenchmark import Worker
+from torchbenchmark._components._impl.tasks import base as base_task
+from torchbenchmark._components._impl.workers import subprocess_worker
+import threading
+import os
+import torch
+import dataclasses
+from pathlib import Path
+import gc
+
+from typing import Optional, Dict, Any, List
+
+
+@dataclasses.dataclass(frozen=True)
+class OpDetails:
+    """Static description of what a particular TritonBench operator supports.
+
+    When parameterizing tests, we only want to generate sensible ones
+    (e.g., those where an operator can be imported and supports the feature to
+    be tested or benchmarked). This requires us to import the operator; however,
+    many of the operators are EXTREMELY stateful, and even importing them
+    consumes significant system resources. As a result, we only want one
+    (or a few) alive at any given time.
+
+    Note that affinity cannot be solved by simply calling `torch.set_num_threads`
+    in the child process; this will cause PyTorch to use all of the cores but
+    at a much lower efficiency.
+
+    This class describes what a particular operator does and does not support, so
+    that we can release the underlying subprocess but retain any pertinent
+    metadata.
+    """
+
+    name: str
+    exists: bool
+    metadata: Dict[str, Any]
+
+
+class OpTask(base_task.TaskBase):
+
+    # The worker may (and often does) consume significant system resources.
+    # In order to ensure that runs do not interfere with each other, we only
+    # allow a single OpTask to exist at a time.
+    _lock = threading.Lock()
+
+    def __init__(
+        self,
+        name: str,
+        timeout: Optional[float] = None,
+        extra_env: Optional[Dict[str, str]] = None,
+        save_output_dir: Optional[Path] = None,
+    ) -> None:
+        gc.collect()  # Make sure the previous task has a chance to release the lock
+        assert self._lock.acquire(blocking=False), "Failed to acquire lock."
+
+        self._op_name = name
+        self._worker = Worker(
+            timeout=timeout, extra_env=extra_env, save_output_dir=save_output_dir
+        )
+
+        self.worker.run("import torch")
+        self._details: OpDetails = OpDetails(
+            **self._maybe_import_operator(
+                package=__name__,
+                op_name=name,
+            )
+        )
+
+    # =========================================================================
+    # == Import Operator in the child process ================================
+    # =========================================================================
+
+    @property
+    def worker(self) -> subprocess_worker.SubprocessWorker:
+        return self._worker
+
+    @base_task.run_in_worker(scoped=True)
+    @staticmethod
+    def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
+        import importlib
+        import os
+        import traceback
+
+        from torchbenchmark.operators import load_opbench_by_name
+
+        Operator = load_opbench_by_name(op_name)
+
+        # Populate the global namespace so subsequent calls to worker.run
+        # can access `Operator`.
+        globals()["Operator"] = Operator
+
+        # This will be used to populate an `OpDetails` instance in the parent.
+        return {
+            "name": op_name,
+            "exists": Operator is not None,
+            "metadata": {},
+        }
+
+    # =========================================================================
+    # == Instantiate a concrete `op` instance ================================
+    # =========================================================================
+
+    @base_task.run_in_worker(scoped=True)
+    @staticmethod
+    def make_operator_instance(
+        mode: str,
+        device: str,
+        extra_args: Optional[List[str]] = None,
+    ) -> None:
+        Operator = globals()["Operator"]
+        op = Operator(
+            mode=mode,
+            device=device,
+            extra_args=extra_args,
+        )
+
+        import gc
+
+        gc.collect()
+
+        if device == "cuda":
+            torch.cuda.empty_cache()
+            maybe_sync = torch.cuda.synchronize
+        else:
+            maybe_sync = lambda: None
+
+        globals().update(
+            {
+                "op": op,
+                "maybe_sync": maybe_sync,
+            }
+        )
+
+    # =========================================================================
+    # == Forward calls to `op` from parent to worker =========================
+    # =========================================================================
+
+    def run(self) -> None:
+        self.worker.run(
+            """
+            op.run()
+            maybe_sync()
+            """
+        )
+
+    # =========================================================================
+    # == Get Operator attribute in the child process =========================
+    # =========================================================================
+
+    @base_task.run_in_worker(scoped=True)
+    @staticmethod
+    def get_attribute(
+        attr: str,
+        field: Optional[str] = None,
+        classattr: bool = False
+    ) -> Any:
+        if classattr:
+            op = globals()["Operator"]
+        else:
+            op = globals()["op"]
+        if hasattr(op, attr):
+            if field:
+                op_attr = getattr(op, attr)
+                return getattr(op_attr, field)
+            else:
+                return getattr(op, attr)
+        else:
+            return None
+
+    def del_op_instance(self):
+        self.worker.run(
+            """
+            del op
+            del maybe_sync
+            """
+        )
+        self.gc_collect()
+
+    def gc_collect(self) -> None:
+        self.worker.run(
+            """
+            import gc
+            gc.collect()
+            """
+        )
+
+    def __del__(self) -> None:
+        self._lock.release()
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 55d3f8bf35..27aa1c6cd1 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -191,7 +191,6 @@ def decorator(function):
 
         def _inner(self, *args, **kwargs):
             return function(self, *args, **kwargs)
-
         return _inner
 
     return decorator
@@ -264,7 +263,7 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
         self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
         # This will be changed by the time we apply the decoration args
         self.dtype = PRECISION_DTYPE_MAPPING.get(self.dargs.precision, None)
-        if self.dargs.num_batch == None:
+        if self.dargs.num_batch is None:
             self.dargs.num_batch = self.DEFAULT_NUM_BATCH
         self.DEFAULT_METRICS.extend(REGISTERED_METRICS.get(self.name, []))
         self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
@@ -286,19 +285,24 @@ def _get_bm_func(self, bm_func_name: str):
         else:
             fwd_fn = fwd_fn_lambda(*self.example_inputs)
         if self.mode == Mode.FWD:
+            setattr(fwd_fn, "_name", bm_func_name)
             return fwd_fn
         elif self.mode == Mode.BWD:
-            return self.get_bwd_fn(fwd_fn)
+            bwd_fn = self.get_bwd_fn(fwd_fn)
+            setattr(bwd_fn, "_name", bm_func_name)
+            return bwd_fn
         elif self.mode == Mode.FWD_BWD:
             bwd_fn = self.get_bwd_fn(fwd_fn)
-            return lambda: (fwd_fn(), bwd_fn())
+            fwd_bwd_fn = lambda: (fwd_fn(), bwd_fn())
+            setattr(fwd_bwd_fn, "_name", bm_func_name)
+            return fwd_bwd_fn
 
     def run(
         self, warmup=DEFAULT_WARMUP, rep=DEFAULT_RUN_ITERS,
         quantiles=DEFAULT_QUANTILES
     ) -> BenchmarkOperatorResult:
         """Benchmarking the operator and returning its metrics."""
         metrics = []
-        if self._batch_id:
+        if self._batch_id is not None:
             # Run only the user-specific batch id
             batch_range = range(self._batch_id + 1)
         else:
@@ -307,7 +311,7 @@ def run(
             if self._batch_id and batch_id < self._batch_id:
                 continue
             self.example_inputs = self.get_example_inputs()
-            if self.example_inputs == None:
+            if self.example_inputs is None:
                 warnings.warn(
                     UserWarning(
                         f"The input generator get_input_iter() has depleted. Maximum input batches {_dp}."
@@ -457,7 +461,7 @@ def enable_channels_last(self):
         )
 
     def get_example_inputs(self):
-        if self._input_iter == None:
+        if self._input_iter is None:
             self._input_iter = self.get_input_iter()
         try:
             return next(self._input_iter)
@@ -551,8 +555,9 @@ def _do_bench(
         # run the hidden metric "_compile_time_in_task"
         # to get the compile time in parent process
         if "_compile_time_in_task" in self.required_metrics:
-            assert self.required_metrics == ["_compile_time_in_task"] and self._only and self._batch_id, \
-                "_compile_time_in_task must be measured by itself."
+            assert self.required_metrics == ["_compile_time_in_task"] and self._only and (self._batch_id is not None), \
+                "_compile_time_in_task must be measured by itself. " \
+                f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
             extra_metrics["_compile_time_in_task"] = self._compile_time_in_task(fn)
         # generate customized metrics
         if self.name in REGISTERED_METRICS:
@@ -580,13 +585,27 @@ def _do_bench(
     @register_metric()
     def compile_time(
         self, batch_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics
     ) -> float:
+        # We need to spawn a subprocess when the user wants to measure the
+        # compile time of multiple batches and backends.
+        def _find_loc(l, key: str) -> int:
+            try:
+                return l.index(key)
+            except ValueError:
+                return -1
+        def _remove_element(l, loc):
+            if loc == -1:
+                return l
+            return l[:loc] + l[loc+2:]
         from torchbenchmark.operators.op_task import OpTask
-        op_task = OpTask(name=self.name)
         op_task_args = copy.deepcopy(self._raw_extra_args)
+        for override_option in ["--only", "--batch-id", "--metrics"]:
+            op_task_args = _remove_element(op_task_args, _find_loc(op_task_args, override_option))
         op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_compile_time_in_task"])
+        op_task = OpTask(name=self.name)
         op_task.make_operator_instance(mode=self.mode.value, device=self.device, extra_args=op_task_args)
         op_task.run()
         latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")
+        del op_task
         latency_without_compile = numpy.median(metrics.latency)
         return latency_with_compile - latency_without_compile
@@ -604,6 +623,7 @@ def _compile_time_in_task(
             end_event.record()
             torch.cuda.synchronize()  # Wait for the events to be recorded!
             latency_with_compile = start_event.elapsed_time(end_event)
+            self._latency_with_compile_in_task = latency_with_compile
             return latency_with_compile
 
     @register_metric()
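
For context, here is a minimal sketch of how a parent process drives the new
`OpTask` API, mirroring what `compile_time()` above does. The operator name
"softmax" and the function name "triton_softmax" are illustrative assumptions,
not part of the patch:

    from torchbenchmark.operators.op_task import OpTask

    # Only one OpTask may be alive at a time: the constructor asserts on the
    # class-level OpTask._lock, which is released in __del__.
    op_task = OpTask(name="softmax")  # hypothetical operator name

    # Build the operator inside the subprocess. The extra args pin the run to
    # a single function and batch and request only the hidden metric, matching
    # the assertion enforced in _do_bench().
    op_task.make_operator_instance(
        mode="fwd",
        device="cuda",
        extra_args=[
            "--only", "triton_softmax",  # hypothetical benchmark function
            "--batch-id", "0",
            "--metrics", "_compile_time_in_task",
        ],
    )

    op_task.run()  # executes `op.run(); maybe_sync()` in the worker

    # _compile_time_in_task() stored this attribute on the operator instance.
    latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")

    del op_task  # triggers __del__, releasing the lock for the next task

Note that the explicit `del op_task` in `compile_time()` matters: it releases
`OpTask._lock` before the next metric computation constructs another task.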