Read SpinQuant checkpoints (pytorch#5259)
Summary:
Pull Request resolved: pytorch#5259

Read SpinQuant checkpoints that are exported with scales/weights.

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: iseeyuan, helunwencser

Differential Revision: D62403094

fbshipit-source-id: 283ae18a1d2053306677086b9edd5cb5f38120ee
mergennachin authored and facebook-github-bot committed Sep 11, 2024
1 parent 41b463e commit d80f78f
Showing 6 changed files with 262 additions and 17 deletions.
35 changes: 21 additions & 14 deletions examples/models/llama2/export_llama_lib.py
@@ -695,6 +695,7 @@ def _load_llama_model(
fairseq2=weight_type == WeightType.FAIRSEQ2,
max_seq_len=max_seq_len,
enable_dynamic_shape=enable_dynamic_shape,
args=args,
)
state_dict = model.state_dict()
dtype = state_dict[next(iter(state_dict))].dtype
@@ -747,9 +748,26 @@ def _get_source_transforms(
transforms = []
if args.quantization_mode:
modelname = f"{modelname}_q"
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)
if args.use_spin_quant is None:
transforms.append(
get_quant_weight_transform(args, dtype_override, verbose_export())
)
# For SpinQuant, the checkpoints are already quantized,
# i.e. the weights come with their corresponding scales,
# so we don't need to apply the quantization transform here.
# However, we still need to apply the transformations that
# change the model structure to match the checkpoint format.
# transform_for_spinquant() applies these transformations
# later, in model.py.
elif args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)
elif args.use_spin_quant == "native":
raise NotImplementedError("native SpinQuant is not implemented yet.")

if args.embedding_quantize:
modelname = f"{modelname}_e"
@@ -783,15 +801,4 @@ def _get_source_transforms(
transforms.append(replace_sdpa_with_simple_sdpa)
transforms.append(replace_causal_mask)

if args.use_spin_quant:
if args.use_spin_quant == "cuda":
from .source_transformation.spin_quant import (
inject_fast_hadamard_transform_cuda_for_spin_quant,
)

transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant)

elif args.use_spin_quant == "native":
raise NotImplementedError("native SpinQuant is not implemented yet.")

return transforms
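The list returned above is later folded over the eager model during export: each entry is a callable that takes an nn.Module and returns the (possibly rewritten) module. A minimal sketch of that composition, assuming the transforms are plain callables as appended here (the apply_transforms helper name is illustrative, not the actual export_llama_lib API):

from typing import Callable, List

import torch.nn as nn


def apply_transforms(
    model: nn.Module, transforms: List[Callable[[nn.Module], nn.Module]]
) -> nn.Module:
    # Fold each source transform over the model. For SpinQuant checkpoints the
    # quantization transform is skipped, but structural ones (e.g. the CUDA
    # fast-Hadamard injection) still run here.
    for transform in transforms:
        model = transform(model)
    return model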
47 changes: 44 additions & 3 deletions examples/models/llama2/model.py
@@ -65,6 +65,7 @@ def __init__(self, **kwargs):
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)

self.max_seq_len = kwargs.get("max_seq_len", 128)
self.args = kwargs.get("args", None)
# The example uses a dummy small model with random weights for demo purposes only.
# Follow the instructions in https://github.com/facebookresearch/llama to download the model
device = "cpu"
@@ -126,7 +127,8 @@ def __init__(self, **kwargs):
# get checkpoint dtype
self.dtype = None
if len(checkpoint) > 0:
first = checkpoint[next(iter(checkpoint))]
first_key = next(iter(checkpoint))
first = checkpoint[first_key]
self.dtype = first.dtype
mismatched_dtypes = [
(key, value.dtype)
@@ -135,7 +137,7 @@
]
if len(mismatched_dtypes) > 0:
print(
f"Mixed dtype model. Dtype of {first.key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
)
with open(params_path, "r") as f:
params = json.loads(f.read())
@@ -179,15 +181,54 @@ def __init__(self, **kwargs):
self.model_ = Int8DynActInt4WeightQuantizer()._convert_for_runtime(
self.model_
)
elif hasattr(self.args, "use_spin_quant") and self.args.use_spin_quant:
print("Using SPIN quantization.")
assert hasattr(self.args, "group_size"), "group_size must be specified"
assert hasattr(
self.args, "quantization_mode"
), "quantization_mode must be specified"
assert hasattr(
self.args, "dtype_override"
), "dtype_override must be specified"
from .source_transformation.spin_quant import (
sanitize_checkpoint_from_spinquant,
transform_for_spinquant,
)

mapping = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}

self.model_ = transform_for_spinquant(
self.model_,
checkpoint,
self.args.group_size,
self.args.quantization_mode,
mapping[self.args.dtype_override],
)

sanitize_checkpoint_from_spinquant(
checkpoint,
self.args.group_size,
)

# assign=True: load params/buffers by assignment instead of performing an in-place copy.
# Because we are using device="meta", tensors do not have memory associated with them
# and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
self.model_.load_state_dict(
missing, unexpected = self.model_.load_state_dict(
checkpoint,
strict=False,
assign=True,
) # self.model_ = Transformer(gptconf)
if kwargs.get("verbose", False):
print("============= missing keys ================")
print(missing)
print("============= /missing ================")
print("============= unexpected keys ================")
print(unexpected)
print("============= /unexpected ================")

def get_eager_model(self):
if self.dtype:
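The assign=True comment above is the key detail of this change: the model here is built on the meta device, where parameters have no storage, so copying checkpoint values in place would silently do nothing. A small, self-contained sketch of the pattern, using a plain nn.Linear rather than this Transformer:

import torch
import torch.nn as nn

# Build the module skeleton without allocating real weight storage.
with torch.device("meta"):
    model = nn.Linear(8, 4, bias=False)

checkpoint = {"weight": torch.randn(4, 8)}

# assign=True swaps the meta parameters for the checkpoint tensors instead of
# copying into them; missing/unexpected keys are returned for inspection,
# mirroring the verbose logging added in this diff.
missing, unexpected = model.load_state_dict(checkpoint, strict=False, assign=True)
print(missing, unexpected)  # [] [] when the checkpoint matches the model exactly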
93 changes: 93 additions & 0 deletions examples/models/llama2/source_transformation/spin_quant.py
@@ -9,12 +9,16 @@
# Helper functions for transforming the model to be able to run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.

from typing import Any

import torch

import torch.nn.functional as F

from executorch.examples.models.llama2.llama_transformer import FeedForward
from torch import nn
from torchao.quantization.GPTQ import _check_linear_int4_k, Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
@@ -53,3 +57,92 @@ def inject_fast_hadamard_transform_cuda_for_spin_quant(
) -> torch.nn.Module:
_inject_fast_hadamard_transform_cuda_for_spin_quant(module)
return module


def _replace_linear_with_linear_8da4w_for_spin_quant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
precision: torch.dtype,
scales_precision: torch.dtype,
):
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
# Only replace linear layers where the checkpoint contains explicit scales
scales_key = f"{cur_fqn}.scale"
if isinstance(child, nn.Linear) and scales_key in checkpoint:
assert _check_linear_int4_k(child.in_features, group_size)
assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
assert checkpoint[scales_key].dtype == scales_precision
return True
return False

def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_linear = Int8DynActInt4WeightLinear(
child.in_features,
child.out_features,
bias=False,
device=child.weight.device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
return new_linear

_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_for_spinquant(
module: torch.nn.Module,
checkpoint: Any,
group_size: int,
quantization_mode: str,
dtype: torch.dtype,
) -> torch.nn.Module:
"""
Transform the model to be able to load SpinQuant checkpoints that
are quantized with the given group size and quantization mode.
"""

if group_size not in [32, 64, 128, 256]:
raise ValueError(f"Group size {group_size} is not supported for SpinQuant.")
if quantization_mode not in ["8da4w"]:
raise ValueError(
f"Quantization mode {quantization_mode} is not compatible with SpinQuant."
)
_replace_linear_with_linear_8da4w_for_spin_quant(
module,
checkpoint,
group_size,
dtype,
dtype,
)
return module


def sanitize_checkpoint_from_spinquant(
checkpoint: Any,
group_size: int,
):
"""
Sanitize the SpinQuant checkpoint.
- Renames 'scale' to 'scales'
- Groups scales (keeps one scale per group of group_size columns)
- Removes 'o_weight'
- Converts all tensors to contiguous format
"""
keys_to_rename = []
keys_to_remove = []
for k, _ in checkpoint.items():
if k.endswith(".scale"):
new_key = k + "s"
keys_to_rename.append((k, new_key))
if k.endswith(".o_weight"):
keys_to_remove.append(k)

for old_key, new_key in keys_to_rename:
old_val = checkpoint.pop(old_key)
checkpoint[new_key] = old_val if group_size == -1 else old_val[:, ::group_size]
for k in keys_to_remove:
checkpoint.pop(k)
for k, v in checkpoint.items():
checkpoint[k] = v.contiguous()
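As a concrete, hedged illustration of the sanitize step above, the snippet below builds a toy checkpoint whose '.scale' tensor carries one value per weight column; that layout is only an assumption made here so the [:, ::group_size] subsampling has something to act on. Sanitizing renames '.scale' to '.scales', keeps one scale per group of group_size columns, drops the '.o_weight' entry, and makes the sliced tensors contiguous again:

import torch

from executorch.examples.models.llama2.source_transformation.spin_quant import (
    sanitize_checkpoint_from_spinquant,
)

out_features, in_features, group_size = 4, 64, 32
checkpoint = {
    "layers.0.attention.wq.weight": torch.randint(
        -8, 8, (out_features, in_features), dtype=torch.int8
    ),
    # Illustrative layout: one scale per weight column before sanitizing.
    "layers.0.attention.wq.scale": torch.rand(out_features, in_features),
    # Rotation weight that the runtime model does not need.
    "layers.0.attention.o_weight": torch.rand(out_features, out_features),
}

sanitize_checkpoint_from_spinquant(checkpoint, group_size)

print(sorted(checkpoint))
# ['layers.0.attention.wq.scales', 'layers.0.attention.wq.weight']
print(checkpoint["layers.0.attention.wq.scales"].shape)
# torch.Size([4, 2])  -> one scale per group of 32 columns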
13 changes: 13 additions & 0 deletions examples/models/llama2/tests/TARGETS
@@ -13,3 +13,16 @@ python_unittest(
"//executorch/examples/models/llama2:llama_transformer",
],
)

python_unittest(
name = "test_spinquant_transforms",
srcs = [
"test_spinquant_transforms.py",
],
deps = [
"//caffe2:torch",
"//executorch/examples/models/llama2:export_library",
"//executorch/examples/models/llama2:llama_transformer",
"//pytorch/ao:torchao",
],
)
89 changes: 89 additions & 0 deletions examples/models/llama2/tests/test_spinquant_transforms.py
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama2.source_transformation.spin_quant import (
sanitize_checkpoint_from_spinquant,
transform_for_spinquant,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric


class SpinQuantTests(unittest.TestCase):
def test_transforms_for_spinquant(self):

# Step 1: Create llama class with dummy weights
params = {
"dim": 768,
"multiple_of": 32,
"n_heads": 12,
"n_layers": 12,
"norm_eps": 1e-05,
"vocab_size": 32000,
}

model_args = ModelArgs(
max_seq_len=2048,
max_batch_size=1,
use_kv_cache=False,
use_sdpa_with_kv_cache_op=False,
generate_full_logits=False,
enable_dynamic_shape=True,
**params,
)

model = Transformer(model_args)
checkpoint = model.state_dict()

# Step 2:
# Do group-wise quantization and amend the checkpoints with
# int8 weight and fp32 scales
group_size = 32
n_bit = 4
scales_precision = torch.float32
for fqn, mod in model.named_modules():
# Quantize everything except the last layer
if isinstance(mod, torch.nn.Linear) and ("output" not in fqn):
weight = mod.weight.data
(
weight_int8,
scales,
zeros,
) = group_quantize_tensor_symmetric(
weight.to(torch.float32), n_bit, group_size, scales_precision
)
checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
checkpoint[f"{fqn}.scale"] = scales.to("cpu")

# Step 3:
# Transform the model so that it is compatible with the new checkpoint
transform_for_spinquant(
model,
checkpoint,
32,
"8da4w",
torch.float32,
)
sanitize_checkpoint_from_spinquant(
checkpoint,
-1,
)

model.load_state_dict(
checkpoint,
strict=False,
assign=True,
)

new_checkpoint = model.state_dict()

for k, v in checkpoint.items():
# new_checkpoint also contains 'zeros' entries that the sanitized
# checkpoint does not, so iterate over the checkpoint's keys.
self.assertTrue(torch.allclose(new_checkpoint[k], v))
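One detail worth noting in the test: sanitize_checkpoint_from_spinquant is called with group_size=-1 because the scales produced by group_quantize_tensor_symmetric in Step 2 are already one per group, so only the '.scale' to '.scales' rename should apply and the column subsampling must be skipped.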
2 changes: 2 additions & 0 deletions pytest.ini
@@ -38,6 +38,8 @@ addopts =
test/end2end/test_end2end.py
--ignore=backends/xnnpack/test/ops/linear.py
--ignore=backends/xnnpack/test/models/llama2_et_example.py
# T200992559: Add torchao to ET as core dependency
--ignore=examples/models/llama2/tests/test_spinquant_transforms.py
--ignore=exir/backend/test/demos
--ignore=exir/backend/test/test_backends.py
--ignore=exir/backend/test/test_backends_lifted.py
