From d80f78fc7aa38c65164a51c198d8bb07711ee54f Mon Sep 17 00:00:00 2001
From: Mergen Nachin
Date: Wed, 11 Sep 2024 12:30:02 -0700
Subject: [PATCH] Read SpinQuant checkpoints (#5259)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5259

Read SpinQuant checkpoints that are exported with scales/weights.

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: iseeyuan, helunwencser

Differential Revision: D62403094

fbshipit-source-id: 283ae18a1d2053306677086b9edd5cb5f38120ee
---
 examples/models/llama2/export_llama_lib.py    | 35 ++++---
 examples/models/llama2/model.py               | 47 +++++++++-
 .../source_transformation/spin_quant.py       | 93 +++++++++++++++++++
 examples/models/llama2/tests/TARGETS          | 13 +++
 .../llama2/tests/test_spinquant_transforms.py | 89 ++++++++++++++++++
 pytest.ini                                    |  2 +
 6 files changed, 262 insertions(+), 17 deletions(-)
 create mode 100644 examples/models/llama2/tests/test_spinquant_transforms.py

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 97228bb5c5..2a03c0cebd 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -695,6 +695,7 @@ def _load_llama_model(
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
         enable_dynamic_shape=enable_dynamic_shape,
+        args=args,
     )
     state_dict = model.state_dict()
     dtype = state_dict[next(iter(state_dict))].dtype
@@ -747,9 +748,26 @@ def _get_source_transforms(
     transforms = []
     if args.quantization_mode:
         modelname = f"{modelname}_q"
-        transforms.append(
-            get_quant_weight_transform(args, dtype_override, verbose_export())
-        )
+        if args.use_spin_quant is None:
+            transforms.append(
+                get_quant_weight_transform(args, dtype_override, verbose_export())
+            )
+        # For SpinQuant, the checkpoints are already quantized,
+        # i.e. the weights already come with their corresponding scales,
+        # so we don't need to apply the quantization transform
+        # here. However, we still need to apply the
+        # transformations that change the model structure to
+        # match the checkpoint format;
+        # transform_for_spinquant() will apply these transformations
+        # later, in model.py.
+        elif args.use_spin_quant == "cuda":
+            from .source_transformation.spin_quant import (
+                inject_fast_hadamard_transform_cuda_for_spin_quant,
+            )
+
+            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
+        elif args.use_spin_quant == "native":
+            raise NotImplementedError("native SpinQuant is not implemented yet.")
 
     if args.embedding_quantize:
         modelname = f"{modelname}_e"
@@ -783,15 +801,4 @@ def _get_source_transforms(
         transforms.append(replace_sdpa_with_simple_sdpa)
         transforms.append(replace_causal_mask)
 
-    if args.use_spin_quant:
-        if args.use_spin_quant == "cuda":
-            from .source_transformation.spin_quant import (
-                inject_fast_hadamard_transform_cuda_for_spin_quant,
-            )
-
-            transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
-
-        elif args.use_spin_quant == "native":
-            raise NotImplementedError("native SpinQuant is not implemented yet.")
-
     return transforms
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index f58a2a2def..174f562f93 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -65,6 +65,7 @@ def __init__(self, **kwargs):
 
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.max_seq_len = kwargs.get("max_seq_len", 128)
+        self.args = kwargs.get("args", None)
         # The example is using a dummy small model with random weights for demo purpose only.
         # Follow the instruction in https://github.com/facebookresearch/llama to download the model
         device = "cpu"
@@ -126,7 +127,8 @@ def __init__(self, **kwargs):
             # get checkpoint dtype
             self.dtype = None
             if len(checkpoint) > 0:
-                first = checkpoint[next(iter(checkpoint))]
+                first_key = next(iter(checkpoint))
+                first = checkpoint[first_key]
                 self.dtype = first.dtype
                 mismatched_dtypes = [
                     (key, value.dtype)
@@ -135,7 +137,7 @@ def __init__(self, **kwargs):
                 ]
                 if len(mismatched_dtypes) > 0:
                     print(
-                        f"Mixed dtype model. Dtype of {first.key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
+                        f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
                     )
         with open(params_path, "r") as f:
             params = json.loads(f.read())
@@ -179,15 +181,54 @@ def __init__(self, **kwargs):
             self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
                 self.model_
             )
+        elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
+            print("Using SPIN quantization.")
+            assert hasattr(self.args, "group_size"), "group_size must be specified"
+            assert hasattr(
+                self.args, "quantization_mode"
+            ), "quantization_mode must be specified"
+            assert hasattr(
+                self.args, "dtype_override"
+            ), "dtype_override must be specified"
+            from .source_transformation.spin_quant import (
+                sanitize_checkpoint_from_spinquant,
+                transform_for_spinquant,
+            )
+
+            mapping = {
+                "fp32": torch.float32,
+                "fp16": torch.float16,
+                "bf16": torch.bfloat16,
+            }
+
+            self.model_ = transform_for_spinquant(
+                self.model_,
+                checkpoint,
+                self.args.group_size,
+                self.args.quantization_mode,
+                mapping[self.args.dtype_override],
+            )
+
+            sanitize_checkpoint_from_spinquant(
+                checkpoint,
+                self.args.group_size,
+            )
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
-        self.model_.load_state_dict(
+        missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
             assign=True,
         )  # self.model_ = Transformer(gptconf)
+        if kwargs.get("verbose", False):
+            print("============= missing keys ================")
+            print(missing)
+            print("============= /missing ================")
+            print("============= unexpected keys ================")
+            print(unexpected)
+            print("============= /unexpected ================")
 
     def get_eager_model(self):
         if self.dtype:
diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py
index 7b38312c18..a45db190f4 100644
--- a/examples/models/llama2/source_transformation/spin_quant.py
+++ b/examples/models/llama2/source_transformation/spin_quant.py
@@ -9,12 +9,16 @@
 # Helper functions for tranforming the model to be able to run SpinQuant.
 # See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.
 
+from typing import Any
+
 import torch
 
 import torch.nn.functional as F
 
 from executorch.examples.models.llama2.llama_transformer import FeedForward
 from torch import nn
+from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
 
 def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
@@ -53,3 +57,92 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
 ) -> torch.nn.Module:
     _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
     return module
+
+
+def _replace_linear_with_linear_8da4w_for_spin_quant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    group_size: int,
+    precision: torch.dtype,
+    scales_precision: torch.dtype,
+):
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        # Only replace linear layers where the checkpoint contains explicit scales
+        scales_key = f"{cur_fqn}.scale"
+        if isinstance(child, nn.Linear) and scales_key in checkpoint:
+            assert _check_linear_int4_k(child.in_features, group_size)
+            assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
+            assert checkpoint[scales_key].dtype == scales_precision
+            return True
+        return False
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_linear = Int8DynActInt4WeightLinear(
+            child.in_features,
+            child.out_features,
+            bias=False,
+            device=child.weight.device,
+            groupsize=group_size,
+            precision=precision,
+            scales_precision=scales_precision,
+        )
+        return new_linear
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
+
+def transform_for_spinquant(
+    module: torch.nn.Module,
+    checkpoint: Any,
+    group_size: int,
+    quantization_mode: str,
+    dtype: torch.dtype,
+) -> torch.nn.Module:
+    """
+    Transform the model to be able to load SpinQuant checkpoints that
+    are quantized with the given group size and quantization mode.
+    """
+
+    if group_size not in [32, 64, 128, 256]:
+        raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
+    if quantization_mode not in ["8da4w"]:
+        raise ValueError(
+            f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
+        )
+    _replace_linear_with_linear_8da4w_for_spin_quant(
+        module,
+        checkpoint,
+        group_size,
+        dtype,
+        dtype,
+    )
+    return module
+
+
+def sanitize_checkpoint_from_spinquant(
+    checkpoint: Any,
+    group_size: int,
+):
+    """
+    Sanitize the SpinQuant checkpoint.
+    - Renames 'scale' to 'scales'
+    - Groups scales
+    - Removes 'o_weight'
+    - Converts all tensors to contiguous format
+    """
+    keys_to_rename = []
+    keys_to_remove = []
+    for k, _ in checkpoint.items():
+        if k.endswith(".scale"):
+            new_key = k + "s"
+            keys_to_rename.append((k, new_key))
+        if k.endswith(".o_weight"):
+            keys_to_remove.append(k)
+
+    for old_key, new_key in keys_to_rename:
+        old_val = checkpoint.pop(old_key)
+        checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
+    for k in keys_to_remove:
+        checkpoint.pop(k)
+    for k, v in checkpoint.items():
+        checkpoint[k] = v.contiguous()
diff --git a/examples/models/llama2/tests/TARGETS b/examples/models/llama2/tests/TARGETS
index 3d2aef6209..76981d8f31 100644
--- a/examples/models/llama2/tests/TARGETS
+++ b/examples/models/llama2/tests/TARGETS
@@ -13,3 +13,16 @@ python_unittest(
         "//executorch/examples/models/llama2:llama_transformer",
     ],
 )
+
+python_unittest(
+    name = "test_spinquant_transforms",
+    srcs = [
+        "test_spinquant_transforms.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:export_library",
+        "//executorch/examples/models/llama2:llama_transformer",
+        "//pytorch/ao:torchao",
+    ],
+)
diff --git a/examples/models/llama2/tests/test_spinquant_transforms.py b/examples/models/llama2/tests/test_spinquant_transforms.py
new file mode 100644
index 0000000000..bd56632c5f
--- /dev/null
+++ b/examples/models/llama2/tests/test_spinquant_transforms.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
+from executorch.examples.models.llama2.source_transformation.spin_quant import (
+    sanitize_checkpoint_from_spinquant,
+    transform_for_spinquant,
+)
+from torchao.quantization.utils import group_quantize_tensor_symmetric
+
+
+class SpinQuantTests(unittest.TestCase):
+    def test_transforms_for_spinquant(self):
+
+        # Step 1: Create llama class with dummy weights
+        params = {
+            "dim": 768,
+            "multiple_of": 32,
+            "n_heads": 12,
+            "n_layers": 12,
+            "norm_eps": 1e-05,
+            "vocab_size": 32000,
+        }
+
+        model_args = ModelArgs(
+            max_seq_len=2048,
+            max_batch_size=1,
+            use_kv_cache=False,
+            use_sdpa_with_kv_cache_op=False,
+            generate_full_logits=False,
+            enable_dynamic_shape=True,
+            **params,
+        )
+
+        model = Transformer(model_args)
+        checkpoint = model.state_dict()
+
+        # Step 2:
+        # Do group-wise quantization and amend the checkpoints with
+        # int8 weight and fp32 scales
+        group_size = 32
+        n_bit = 4
+        scales_precision = torch.float32
+        for fqn, mod in model.named_modules():
+            # Quantize everything except the last layer
+            if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
+                weight = mod.weight.data
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(torch.float32), n_bit, group_size, scales_precision
+                )
+                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
+                checkpoint[f"{fqn}.scale"] = scales.to("cpu")
+
+        # Step 3:
+        # Transform the model so that it is compatible with the new checkpoint
+        transform_for_spinquant(
+            model,
+            checkpoint,
+            32,
+            "8da4w",
+            torch.float32,
+        )
+        sanitize_checkpoint_from_spinquant(
+            checkpoint,
+            -1,
+        )
+
+        model.load_state_dict(
+            checkpoint,
+            strict=False,
+            assign=True,
+        )
+
+        new_checkpoint = model.state_dict()
+
+        for k, v in checkpoint.items():
+            # new_checkpoint also contains extra "zeros" tensors that are not in
+            # the sanitized checkpoint, so iterate over checkpoint's keys instead.
+            self.assertTrue(torch.allclose(new_checkpoint[k], v))
diff --git a/pytest.ini b/pytest.ini
index 7298773255..701c0187ec 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -38,6 +38,8 @@ addopts =
     test/end2end/test_end2end.py
     --ignore=backends/xnnpack/test/ops/linear.py
     --ignore=backends/xnnpack/test/models/llama2_et_example.py
+    # T200992559: Add torchao to ET as core dependency
+    --ignore=examples/models/llama2/tests/test_spinquant_transforms.py
    --ignore=exir/backend/test/demos
    --ignore=exir/backend/test/test_backends.py
    --ignore=exir/backend/test/test_backends_lifted.py
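For context on the checkpoint format this patch reads: an "8da4w" SpinQuant export stores each linear layer's weight as int8 values plus one scale per group of group_size input channels. The snippet below is a plain-PyTorch sketch of symmetric 4-bit group-wise quantization, only to illustrate that int8-weight/per-group-scale layout; the real checkpoints come from SpinQuant (and, in the unit test above, from torchao's group_quantize_tensor_symmetric), so treat this as an illustration rather than the actual export path.

import torch


def quantize_per_group(weight: torch.Tensor, group_size: int, n_bit: int = 4):
    # Split each output row into groups of `group_size` input channels.
    out_features, in_features = weight.shape
    grouped = weight.reshape(out_features, in_features // group_size, group_size)
    # Symmetric quantization: one scale per group, zero-point fixed at 0.
    qmax = 2 ** (n_bit - 1) - 1
    scales = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
    q = torch.clamp(torch.round(grouped / scales), -qmax - 1, qmax).to(torch.int8)
    return q.reshape(out_features, in_features), scales.squeeze(-1)


w = torch.randn(8, 64)
w_int8, scales = quantize_per_group(w, group_size=32)
print(w_int8.dtype, scales.shape)  # torch.int8 torch.Size([8, 2])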
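transform_for_spinquant() relies on torchao's _replace_with_custom_fn_if_matches_filter to swap nn.Linear modules for Int8DynActInt4WeightLinear wherever the checkpoint has a matching ".scale" entry. Conceptually that helper is a filtered module-tree rewrite; the sketch below is a simplified stand-in (not torchao's actual code) showing the filter/replacement pattern the patch depends on.

from torch import nn


def replace_if_matches(module: nn.Module, replacement_fn, filter_fn, prefix: str = ""):
    # Recursively walk named children; replace a child on its parent whenever
    # filter_fn(child, fully_qualified_name) returns True.
    for name, child in module.named_children():
        fqn = f"{prefix}.{name}" if prefix else name
        if filter_fn(child, fqn):
            setattr(module, name, replacement_fn(child))
        else:
            replace_if_matches(child, replacement_fn, filter_fn, fqn)


# Toy usage: swap every nn.Linear for nn.Identity (illustration only).
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
replace_if_matches(
    model,
    replacement_fn=lambda child: nn.Identity(),
    filter_fn=lambda child, fqn: isinstance(child, nn.Linear),
)
print(model)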
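After the model structure has been transformed, sanitize_checkpoint_from_spinquant() adjusts the raw SpinQuant state dict so that load_state_dict() accepts it. A minimal usage sketch with a made-up toy checkpoint (key names and tensor shapes are illustrative, not from a real SpinQuant export), runnable once this patch is applied:

import torch

from executorch.examples.models.llama2.source_transformation.spin_quant import (
    sanitize_checkpoint_from_spinquant,
)

# Toy checkpoint with one quantized linear: int8 weight plus a '.scale' tensor.
checkpoint = {
    "layers.0.attention.wq.weight": torch.zeros(64, 64, dtype=torch.int8),
    "layers.0.attention.wq.scale": torch.ones(64, 2, dtype=torch.float32),
    "layers.0.attention.o_weight": torch.eye(64),  # dropped by the sanitizer
}

# group_size=-1 keeps the scales tensor as-is; a positive group size keeps one
# scale per group by slicing old_val[:, ::group_size].
sanitize_checkpoint_from_spinquant(checkpoint, group_size=-1)

assert "layers.0.attention.wq.scales" in checkpoint  # '.scale' renamed to '.scales'
assert "layers.0.attention.wq.scale" not in checkpoint
assert "layers.0.attention.o_weight" not in checkpoint  # '.o_weight' removed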