Use torch.inference_mode() for lower memory usage during calibration (#20)
mgoin committed Jul 18, 2024
1 parent 0eac983 commit e6c2225
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions auto_fp8/quantize.py
@@ -407,6 +407,7 @@ def quantize_activations(
cleanup_memory()

# Pass through calibration data to measure activation scales
with torch.inference_mode():
@@ -415,14 +416,27 @@ def quantize_activations(
    with tqdm.tqdm(
        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
    ) as pbar:
        for row_idx in range(calibration_tokens.shape[0]):
            model(calibration_tokens[row_idx].reshape(1, -1))
            cleanup_memory()
            pbar.update(1)

# Replace dynamic quantizer observer with StaticLinear for export
for name, quantizer in model.named_modules():
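For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this commit applies: running the calibration forward passes under torch.inference_mode() so no autograd state is recorded, which lowers peak memory. The calibrate helper, toy model, and token tensor below are hypothetical stand-ins for illustration, not the AutoFP8 API.

import torch
import torch.nn as nn
import tqdm

def calibrate(model: nn.Module, calibration_tokens: torch.Tensor) -> None:
    # torch.inference_mode() disables gradient tracking for everything created
    # inside the block, so activations from these forward passes are not
    # retained for a backward pass, reducing peak memory during calibration.
    with torch.inference_mode():
        with tqdm.tqdm(
            total=calibration_tokens.shape[0], desc="Calibrating activation scales"
        ) as pbar:
            for row_idx in range(calibration_tokens.shape[0]):
                # One calibration sample per forward pass (batch size 1).
                model(calibration_tokens[row_idx].reshape(1, -1))
                pbar.update(1)

# Toy usage with a hypothetical model and random token ids, only to exercise
# the pattern; the real code measures scales via observers attached to the model.
toy_model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16))
toy_tokens = torch.randint(0, 100, (4, 8))
calibrate(toy_model, toy_tokens)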
