Skip to content

Commit

Permalink
Fix: correct quantization name filtering (#196)
Browse files Browse the repository at this point in the history
fix: correct quantization name filtering

The quantization filter based on layer names did not work, because
modules walk is done with the Module.apply method, that resolves names
locally, so the "absolute" naming does not work. The fix just prepares a
list out of the names before entering the loop, so the correct reference
is captured.
  • Loading branch information
tengomucho authored Oct 23, 2024
1 parent 664c124 commit 02927c9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion jetstream_pt/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,18 @@

def quantize_model(float_model, config: QuantizationConfig):
"""Apply quantization to linear layers."""
exclude_mods = None
if config.exclude_layers:
exclude_mods = [
module
for name, module in float_model.named_modules()
if name in config.exclude_layers
]

def quantize_nn_mod(float_model):
for name, mod in float_model.named_modules():
new_mod = None
if config.exclude_layers and name in config.exclude_layers:
if config.exclude_layers and mod in exclude_mods:
continue
if hasattr(mod, "get_quantized_version"):
new_mod = mod.get_quantized_version()
Expand Down

0 comments on commit 02927c9

Please sign in to comment.