From 57c31bb271cfd16dda7bf8de8acf20bf54217ac3 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 18 Jul 2024 17:05:54 -0400
Subject: [PATCH] Use `torch.inference_mode()` for lower memory usage during calibration (#20)

---
 auto_fp8/quantize.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index b75529a..277dd0d 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -272,13 +272,14 @@ def quantize_activations(
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
-    with tqdm.tqdm(
-        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
-    ) as pbar:
-        for row_idx in range(calibration_tokens.shape[0]):
-            model(calibration_tokens[row_idx].reshape(1, -1))
-            cleanup_memory()
-            pbar.update(1)
+    with torch.inference_mode():
+        with tqdm.tqdm(
+            total=calibration_tokens.shape[0], desc="Calibrating activation scales"
+        ) as pbar:
+            for row_idx in range(calibration_tokens.shape[0]):
+                model(calibration_tokens[row_idx].reshape(1, -1))
+                cleanup_memory()
+                pbar.update(1)
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
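
As a standalone illustration (not part of the patch itself), the pattern being applied looks roughly like the sketch below. The calibrate() wrapper and its arguments are placeholders for this example, not AutoFP8 API, and the repository's cleanup_memory() helper is omitted for brevity. torch.inference_mode() disables autograd tracking, so the calibration forward passes retain no gradient state between sequences.

    # Minimal sketch of the calibration pattern, assuming `model` accepts a
    # [1, seq_len] tensor of token ids and `calibration_tokens` is a 2-D tensor
    # with one calibration sequence per row.
    import torch
    import tqdm

    def calibrate(model, calibration_tokens):
        # inference_mode() turns off gradient tracking entirely, so each
        # forward pass keeps only the activations needed by the observers.
        with torch.inference_mode():
            with tqdm.tqdm(
                total=calibration_tokens.shape[0], desc="Calibrating activation scales"
            ) as pbar:
                for row_idx in range(calibration_tokens.shape[0]):
                    # One sequence per forward pass; quantizer observers record
                    # the activation scales as a side effect.
                    model(calibration_tokens[row_idx].reshape(1, -1))
                    pbar.update(1)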