Switch from output_scale to kv_scale
mgoin committed Jun 18, 2024
1 parent bb15a62 commit 8556c86
Showing 3 changed files with 40 additions and 16 deletions.
10 changes: 5 additions & 5 deletions auto_fp8/modeling.py
@@ -28,7 +28,7 @@ def __init__(
         )
 
         if quantize_config.kv_cache_quant_targets:
-            kv_cache_quant_layers = get_kv_cache_quant_layer(
+            kv_cache_quant_layers = get_kv_cache_quant_layers(
                 self.model, quantize_config.kv_cache_quant_targets
             )
             if len(kv_cache_quant_layers) == 0:
@@ -159,15 +159,15 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
     return list(ignored_layers)
 
 
-def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-    kv_cache_quant_layers = set()
+def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
+    kv_cache_quant_layers = []
 
     for name, linear in model.named_modules():
         if not isinstance(linear, torch.nn.Linear):
             continue
 
         for output_quant_target in kv_cache_quant_targets:
             if name.endswith(output_quant_target):
-                kv_cache_quant_layers.add(name)
+                kv_cache_quant_layers.append(name)
 
-    return list(kv_cache_quant_layers)
+    return kv_cache_quant_layers
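Switching from a set to a list matters beyond de-duplication: the kv_scale post-processing added in quantize.py below assumes the returned names preserve model order, with each layer's k_proj immediately followed by its v_proj. As a rough sketch, for a hypothetical Llama-style model quantized with kv_cache_quant_targets=("k_proj", "v_proj"), the function would be expected to return something like the following (module paths are illustrative, not taken from this diff):

# Illustrative only: the ordered output expected from get_kv_cache_quant_layers
# on a hypothetical Llama-style model; actual names depend on the architecture.
kv_cache_quant_layers = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.v_proj",
    # ... one (k_proj, v_proj) pair per decoder layer
]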
32 changes: 28 additions & 4 deletions auto_fp8/quantize.py
Expand Up @@ -152,9 +152,9 @@ def __init__(
def forward(self, x):
qinput, x_input_scale = per_tensor_quantize(x)
if self.input_scale is None:
self.input_scale = torch.nn.Parameter(x_input_scale)
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
elif x_input_scale > self.input_scale:
self.input_scale = torch.nn.Parameter(x_input_scale)
self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
output = fp8_gemm(
A=qinput,
A_scale=self.input_scale,
@@ -168,9 +168,9 @@ def forward(self, x):
         if self.quantize_output:
             qoutput, output_scale = per_tensor_quantize(output)
             if self.output_scale is None:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             elif output_scale > self.output_scale:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             output = qoutput.to(output.dtype) * output_scale
 
         return output
@@ -297,6 +297,30 @@ def quantize_activations(
         del quantizer
         cleanup_memory()
 
+    # Post-process step for kv cache scales to take the k/v module
+    # `output_scale` parameters, take the max of them, and store them in
+    # the parent attention module as `kv_scale`
+    # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+    if hasattr(quantize_config, "kv_cache_quant_layers"):
+        # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
+        # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
+        kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2)
+        for k_proj_name, v_proj_name in kv_proj_pairs:
+            parent_module_name = ".".join(k_proj_name.split(".")[:-1])
+            assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
+            parent_module = dict(model.named_modules())[parent_module_name]
+
+            k_proj = dict(model.named_modules())[k_proj_name]
+            v_proj = dict(model.named_modules())[v_proj_name]
+
+            kv_scale = max(k_proj.output_scale, v_proj.output_scale)
+            parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+
+            # Remove output_scale from k_proj and v_proj
+            k_proj.output_scale = None
+            v_proj.output_scale = None
+    cleanup_memory()
+
 
 
 def save_quantized_model(
     model: AutoModelForCausalLM,
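Two details of the new post-processing block are worth spelling out: zip(*[iter(...)]*2) walks the ordered layer list two names at a time, and the scales are wrapped in torch.nn.Parameter with requires_grad=False (the default is True), since they are calibration statistics rather than trainable weights. A standalone sketch of the pairing and max-reduction, using made-up module names and scale values:

import torch

# Hypothetical ordered list, as produced by get_kv_cache_quant_layers
names = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.v_proj",
]

# zip(*[iter(names)] * 2) reuses a single iterator twice, so it yields
# consecutive pairs: (k_proj, v_proj) for layer 0, then layer 1, and so on.
pairs = list(zip(*[iter(names)] * 2))
assert pairs[0] == ("model.layers.0.self_attn.k_proj", "model.layers.0.self_attn.v_proj")

# For each pair the commit keeps a single scale: the max of the two output
# scales, stored as a frozen parameter on the parent attention module.
k_scale = torch.tensor(0.021)  # made-up calibration scale
v_scale = torch.tensor(0.034)  # made-up calibration scale
kv_scale = torch.nn.Parameter(torch.max(k_scale, v_scale), requires_grad=False)
print(float(kv_scale))  # ~0.034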
14 changes: 7 additions & 7 deletions tests/test_auto_fp8.py
@@ -30,7 +30,7 @@ def test_dynamic_quantization(model_id, target_size):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be a certain size
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size

@@ -55,7 +55,7 @@ def test_static_quantization(model_id, target_size):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be < 160MB
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size

@@ -81,18 +81,18 @@ def test_kv_cache_static_quantization(model_id, target_size):
 
     tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
     proj_linear_count = 0
-    output_scale_count = 0
+    kv_scale_count = 0
     for name, _ in tensors.items():
         if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
             proj_linear_count += 1
-        if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
-            output_scale_count += 1
-    assert proj_linear_count == output_scale_count
+        if name.endswith("kv_scale"):
+            kv_scale_count += 1
+    assert proj_linear_count // 2 == kv_scale_count
 
     # Measure checkpoint size and cleanup
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be < 160MB
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size

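With the kv_scale representation, the test above expects one kv_scale tensor per attention module (half the number of k/v projection weights) instead of one output_scale per projection. A minimal way to inspect a saved checkpoint for these entries, assuming a checkpoint was written by save_quantized_model (the path and printed names are placeholders; actual names depend on the model):

import safetensors.torch

# Placeholder path: wherever save_quantized_model wrote the checkpoint
tensors = safetensors.torch.load_file("quantized_model_dir/model.safetensors")

kv_scale_names = [name for name in tensors if name.endswith("kv_scale")]
print(kv_scale_names)
# Expected shape of the names for a Llama-style model (illustrative):
# ['model.layers.0.self_attn.kv_scale', 'model.layers.1.self_attn.kv_scale', ...]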