From e6c2225820956831f088daa8b3e8cbf5c12bd50a Mon Sep 17 00:00:00 2001
From: mgoin
Date: Thu, 18 Jul 2024 17:05:54 -0400
Subject: [PATCH] Use `torch.inference_mode()` for lower memory usage during
 calibration (#20)

---
 auto_fp8/quantize.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/auto_fp8/quantize.py b/auto_fp8/quantize.py
index 8c0f017..ab61674 100644
--- a/auto_fp8/quantize.py
+++ b/auto_fp8/quantize.py
@@ -407,13 +407,14 @@ def quantize_activations(
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
-    with tqdm.tqdm(
-        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
-    ) as pbar:
-        for row_idx in range(calibration_tokens.shape[0]):
-            model(calibration_tokens[row_idx].reshape(1, -1))
-            cleanup_memory()
-            pbar.update(1)
+    with torch.inference_mode():
+        with tqdm.tqdm(
+            total=calibration_tokens.shape[0], desc="Calibrating activation scales"
+        ) as pbar:
+            for row_idx in range(calibration_tokens.shape[0]):
+                model(calibration_tokens[row_idx].reshape(1, -1))
+                cleanup_memory()
+                pbar.update(1)
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
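
Note: as a minimal standalone sketch of why this change helps (using toy
stand-ins for the real model and calibration tensors, not AutoFP8's actual
code): the calibration forward passes only observe activation scales and
never call backward(), so running them under torch.inference_mode() skips
autograd graph construction entirely and lowers peak memory.

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the quantize_activations() inputs.
    model = nn.Linear(8, 8)
    calibration_tokens = torch.randn(4, 8)

    with torch.inference_mode():
        for row_idx in range(calibration_tokens.shape[0]):
            out = model(calibration_tokens[row_idx].reshape(1, -1))
            assert not out.requires_grad  # no autograd state was recorded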