diff --git a/torchbenchmark/util/backends/torchdynamo.py b/torchbenchmark/util/backends/torchdynamo.py
index 82bfe35b82..8fa8127d00 100644
--- a/torchbenchmark/util/backends/torchdynamo.py
+++ b/torchbenchmark/util/backends/torchdynamo.py
@@ -50,6 +50,11 @@ def parse_torchdynamo_args(dynamo_args: List[str]) -> argparse.Namespace:
         action="store_true",
         help="Measure metrics with TorchInductor",
     )
+    parser.add_argument(
+        "--cold-start",
+        action="store_true",
+        help="Use a fresh inductor and triton cachedir when running each model, to force cold-start compile.",
+    )
     parser.add_argument(
         "--inductor-compile-mode",
         default=None,
@@ -221,6 +226,10 @@ def apply_torchdynamo_args(
             "--dynamo_disable_optimizer_step is set to True, but the optimizer could not be found on this model"
         )

+    if args.cold_start:
+        from torch._inductor.utils import fresh_inductor_cache
+        fresh_inductor_context = lambda: fresh_inductor_cache()
+        model.run_contexts.append(fresh_inductor_context)
     if model.test == "train":
         if is_staged_train_test(model):
             model.forward = optimize_ctx(model.forward)
diff --git a/torchbenchmark/util/model.py b/torchbenchmark/util/model.py
index e21871398f..50faefe92a 100644
--- a/torchbenchmark/util/model.py
+++ b/torchbenchmark/util/model.py
@@ -382,15 +382,16 @@ def _invoke_staged_train_test(self, num_batch: int) -> None:
         self.example_inputs = next(input_generator)
         # cast inputs if needed
         apply_decoration_args(self, self.dargs)
-        if optimizer is not None:
-            optimizer.zero_grad()
-        with nested(*self.forward_contexts):
-            losses = self.forward()
-        with nested(*self.backward_contexts):
-            self.backward(losses)
-        if optimizer is not None:
-            with nested(*self.optimizer_contexts):
-                self.optimizer_step()
+        with nested(*self.run_contexts):
+            if optimizer is not None:
+                optimizer.zero_grad()
+            with nested(*self.forward_contexts):
+                losses = self.forward()
+            with nested(*self.backward_contexts):
+                self.backward(losses)
+            if optimizer is not None:
+                with nested(*self.optimizer_contexts):
+                    self.optimizer_step()
         return None

     def invoke(self) -> Optional[Tuple[torch.Tensor]]: