Add the compile time metric retrieved from the subprocess
Summary: Add the compile_time metric to measure the Triton compilation time.

Reviewed By: chenyang78

Differential Revision: D55943230

fbshipit-source-id: 5a1aa0a11d840a88642ec46b9c7d9598c7113534
xuzhao9 authored and facebook-github-bot committed Apr 11, 2024
1 parent abbbd49 commit d978fcc
Showing 2 changed files with 212 additions and 10 deletions.
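
For orientation before the diffs: the metric times the benchmarked function once in a fresh subprocess, where the Triton kernel still has to be compiled, and then subtracts the steady-state latency the parent process has already measured. A minimal sketch of that arithmetic, using made-up numbers (only the variable names mirror compile_time() below):

# Made-up numbers; only the variable names come from the diff below.
latency_with_compile = 1250.0    # ms: first call in a clean subprocess (compile + run)
latency_without_compile = 2.5    # ms: median steady-state latency from the parent's runs
compile_time = latency_with_compile - latency_without_compile  # ~1247.5 ms of compilation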
182 changes: 182 additions & 0 deletions torchbenchmark/operators/op_task.py
@@ -0,0 +1,182 @@
from torchbenchmark import Worker
from torchbenchmark._components._impl.tasks import base as base_task
from torchbenchmark._components._impl.workers import subprocess_worker
import threading
import os
import torch
import dataclasses
from pathlib import Path
import gc

from typing import Optional, Dict, Any, List


@dataclasses.dataclass(frozen=True)
class OpDetails:
    """Static description of what a particular TritonBench operator supports.
    When parameterizing tests, we only want to generate sensible ones.
    (e.g. Those where an operator can be imported and supports the feature to be
    tested or benchmarked.) This requires us to import the operator; however many
    of the operators are EXTREMELY stateful, and even importing them consumes
    significant system resources. As a result, we only want one (or a few)
    alive at any given time.
    Note that affinity cannot be solved by simply calling `torch.set_num_threads`
    in the child process; this will cause PyTorch to use all of the cores but
    at a much lower efficiency.
    This class describes what a particular operator does and does not support, so
    that we can release the underlying subprocess but retain any pertinent
    metadata.
    """

    name: str
    exists: bool
    metadata: Dict[str, Any]


class OpTask(base_task.TaskBase):

    # The worker may (and often does) consume significant system resources.
    # In order to ensure that runs do not interfere with each other, we only
    # allow a single OpTask to exist at a time.
    _lock = threading.Lock()

    def __init__(
        self,
        name: str,
        timeout: Optional[float] = None,
        extra_env: Optional[Dict[str, str]] = None,
        save_output_dir: Optional[Path] = None,
    ) -> None:
        gc.collect()  # Make sure previous task has a chance to release the lock
        assert self._lock.acquire(blocking=False), "Failed to acquire lock."

        self._op_name = name
        self._worker = Worker(
            timeout=timeout, extra_env=extra_env, save_output_dir=save_output_dir
        )

        self.worker.run("import torch")
        self._details: OpDetails = OpDetails(
            **self._maybe_import_operator(
                package=__name__,
                op_name=name,
            )
        )

    # =========================================================================
    # == Import Operator in the child process ================================
    # =========================================================================

    @property
    def worker(self) -> subprocess_worker.SubprocessWorker:
        return self._worker

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def _maybe_import_operator(package: str, op_name: str) -> Dict[str, Any]:
        import importlib
        import os
        import traceback
        from torchbenchmark.operators import load_opbench_by_name

        Operator = load_opbench_by_name(op_name)

        # Populate global namespace so subsequent calls to worker.run can access `Operator`
        globals()["Operator"] = Operator

        # This will be used to populate a `OpDetails` instance in the parent.
        return {
            "name": op_name,
            "exists": Operator is not None,
            "metadata": {},
        }

    # =========================================================================
    # == Instantiate a concrete `op` instance ================================
    # =========================================================================

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def make_operator_instance(
        mode: str,
        device: str,
        extra_args: Optional[List[str]] = None,
    ) -> None:
        Operator = globals()["Operator"]
        op = Operator(
            mode=mode,
            device=device,
            extra_args=extra_args,
        )

        import gc
        gc.collect()

        if device == "cuda":
            torch.cuda.empty_cache()
            maybe_sync = torch.cuda.synchronize
        else:
            maybe_sync = lambda: None

        globals().update(
            {
                "op": op,
                "maybe_sync": maybe_sync,
            }
        )

    # =========================================================================
    # == Forward calls to `op` from parent to worker =========================
    # =========================================================================

    def run(self) -> None:
        self.worker.run(
            """
            op.run()
            maybe_sync()
            """
        )

    # =========================================================================
    # == Get Operator attribute in the child process =========================
    # =========================================================================

    @base_task.run_in_worker(scoped=True)
    @staticmethod
    def get_attribute(
        attr: str,
        field: Optional[str] = None,
        classattr: bool = False
    ) -> Any:
        if classattr:
            op = globals()["Operator"]
        else:
            op = globals()["op"]
        if hasattr(op, attr):
            if field:
                op_attr = getattr(op, attr)
                return getattr(op_attr, field)
            else:
                return getattr(op, attr)
        else:
            return None

    def del_op_instance(self):
        self.worker.run(
            """
            del op
            del maybe_sync
            """
        )
        self.gc_collect()

    def gc_collect(self) -> None:
        self.worker.run(
            """
            import gc
            gc.collect()
            """
        )

    def __del__(self) -> None:
        self._lock.release()
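
For orientation, this is roughly how the parent process drives OpTask when it collects the hidden _compile_time_in_task metric; it mirrors compile_time() in the triton_op.py diff below. The operator name and argument values are illustrative placeholders, not part of this commit.

from torchbenchmark.operators.op_task import OpTask

task = OpTask(name="softmax")  # "softmax" is a placeholder operator name
task.make_operator_instance(
    mode="fwd",
    device="cuda",
    extra_args=["--only", "triton_softmax", "--batch-id", "0",
                "--metrics", "_compile_time_in_task"],
)
task.run()  # executes op.run() and maybe_sync() inside the worker subprocess
latency_with_compile = task.get_attribute("_latency_with_compile_in_task")
del task  # __del__ releases the class-level lock so a new OpTask can be created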
40 changes: 30 additions & 10 deletions torchbenchmark/util/triton_op.py
@@ -191,7 +191,6 @@ def decorator(function):
 
         def _inner(self, *args, **kwargs):
             return function(self, *args, **kwargs)
-
         return _inner
 
     return decorator
@@ -264,7 +263,7 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
         self.dargs, unprocessed_args = parse_decoration_args(self, extra_args)
         # This will be changed by the time we apply the decoration args
         self.dtype = PRECISION_DTYPE_MAPPING.get(self.dargs.precision, None)
-        if self.dargs.num_batch == None:
+        if self.dargs.num_batch is None:
             self.dargs.num_batch = self.DEFAULT_NUM_BATCH
         self.DEFAULT_METRICS.extend(REGISTERED_METRICS.get(self.name, []))
         self.DEFAULT_METRICS = list(set(self.DEFAULT_METRICS))
@@ -286,19 +285,24 @@ def _get_bm_func(self, bm_func_name: str):
         else:
             fwd_fn = fwd_fn_lambda(*self.example_inputs)
         if self.mode == Mode.FWD:
+            setattr(fwd_fn, "_name", bm_func_name)
             return fwd_fn
         elif self.mode == Mode.BWD:
-            return self.get_bwd_fn(fwd_fn)
+            bwd_fn = self.get_bwd_fn(fwd_fn)
+            setattr(bwd_fn, "_name", bm_func_name)
+            return bwd_fn
         elif self.mode == Mode.FWD_BWD:
             bwd_fn = self.get_bwd_fn(fwd_fn)
-            return lambda: (fwd_fn(), bwd_fn())
+            fwd_bwd_fn = lambda: (fwd_fn(), bwd_fn())
+            setattr(fwd_bwd_fn, "_name", bm_func_name)
+            return fwd_bwd_fn
 
     def run(
         self, warmup=DEFAULT_WARMUP, rep=DEFAULT_RUN_ITERS, quantiles=DEFAULT_QUANTILES
     ) -> BenchmarkOperatorResult:
         """Benchmarking the operator and returning its metrics."""
         metrics = []
-        if self._batch_id:
+        if self._batch_id is not None:
             # Run only the user-specific batch id
             batch_range = range(self._batch_id + 1)
         else:
Expand All @@ -307,7 +311,7 @@ def run(
if self._batch_id and batch_id < self._batch_id:
continue
self.example_inputs = self.get_example_inputs()
if self.example_inputs == None:
if self.example_inputs is None:
warnings.warn(
UserWarning(
f"The input generator get_input_iter() has depleted. Maximum input batches {_dp}."
@@ -457,7 +461,7 @@ def enable_channels_last(self):
         )
 
     def get_example_inputs(self):
-        if self._input_iter == None:
+        if self._input_iter is None:
             self._input_iter = self.get_input_iter()
         try:
             return next(self._input_iter)
@@ -551,8 +555,9 @@ def _do_bench(
             # run the hidden metric "_compile_time_in_task"
             # to get the compile time in parent process
             if "_compile_time_in_task" in self.required_metrics:
-                assert self.required_metrics == ["_compile_time_in_task"] and self._only and self._batch_id, \
-                    "_compile_time_in_task must be measured by itself."
+                assert self.required_metrics == ["_compile_time_in_task"] and self._only and (self._batch_id is not None), \
+                    "_compile_time_in_task must be measured by itself. " \
+                    f"required_metrics: {self.required_metrics}, _only: {self._only}, _batch_id: {self._batch_id}"
                 extra_metrics["_compile_time_in_task"] = self._compile_time_in_task(fn)
             # generate customized metrics
             if self.name in REGISTERED_METRICS:
@@ -580,13 +585,27 @@ def _do_bench(
 
     @register_metric()
     def compile_time(self, batch_id: int, fn_name: str, metrics: BenchmarkOperatorMetrics) -> float:
+        # We need to spawn a subprocess when user wants to measure the compile time
+        # of multiple batches and backends.
+        def _find_loc(l, key: str) -> int:
+            try:
+                return l.index(key)
+            except ValueError:
+                return -1
+        def _remove_element(l, loc):
+            if loc == -1:
+                return l
+            return l[:loc] + l[loc+2:]
         from torchbenchmark.operators.op_task import OpTask
-        op_task = OpTask(name=self.name)
         op_task_args = copy.deepcopy(self._raw_extra_args)
+        for override_option in ["--only", "--batch-id", "--metrics"]:
+            op_task_args = _remove_element(op_task_args, _find_loc(op_task_args, override_option))
         op_task_args.extend(["--only", fn_name, "--batch-id", str(batch_id), "--metrics", "_compile_time_in_task"])
+        op_task = OpTask(name=self.name)
         op_task.make_operator_instance(mode=self.mode.value, device=self.device, extra_args=op_task_args)
         op_task.run()
         latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")
         del op_task
         latency_without_compile = numpy.median(metrics.latency)
         return latency_with_compile - latency_without_compile

Expand All @@ -604,6 +623,7 @@ def _compile_time_in_task(
end_event.record()
torch.cuda.synchronize() # Wait for the events to be recorded!
latency_with_compile = start_event.elapsed_time(end_event)
self._latency_with_compile_in_task = latency_with_compile
return latency_with_compile

@register_metric()
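
To make the argument handling in compile_time() concrete, here is a small standalone walk-through of the _find_loc/_remove_element helpers; the starting argument list is invented for illustration.

# Same logic as the helpers defined inside compile_time() above.
def _find_loc(l, key):
    try:
        return l.index(key)
    except ValueError:
        return -1

def _remove_element(l, loc):
    if loc == -1:
        return l
    return l[:loc] + l[loc + 2:]

# Invented user-supplied arguments, standing in for self._raw_extra_args.
op_task_args = ["--metrics", "latency,tflops", "--only", "triton_softmax"]
for override_option in ["--only", "--batch-id", "--metrics"]:
    op_task_args = _remove_element(op_task_args, _find_loc(op_task_args, override_option))
print(op_task_args)  # [] -- each flag is dropped together with the value that follows it
op_task_args.extend(
    ["--only", "triton_softmax", "--batch-id", "0", "--metrics", "_compile_time_in_task"]
)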
