Wrap model's buffers and params to MultiTensor & update the results (#16)

* wrap model's buffers and params to `MultiTensor` and update the results

Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 authored Sep 3, 2024
1 parent 21686f1 commit 96f745d
Showing 7 changed files with 199 additions and 41 deletions.
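At a high level, this change wraps every parameter and buffer of the target model in a `MultiTensor` during preparation and unwraps them once calibration has run, so that modules that update buffers in place during forward are handled correctly. A minimal sketch of the resulting flow, distilled from the new test below; the toy `Block` module, the hyperparameter values, and the exact import paths are assumptions based on the files touched in this commit, not a drop-in script:

import torch
from torchao.prototype.autoround.core import (
    apply_auto_round,
    prepare_model_for_applying_auto_round_,
)
from torchao.prototype.autoround.multi_tensor import MultiTensor
from torchao.quantization import quantize_

DIM = 128

class Block(torch.nn.Module):
    # Stand-in for a decoder block; any module matched by `is_target` works.
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(DIM, DIM)

    def forward(self, x):
        return self.lin(x)

model = torch.nn.Sequential(Block(), Block()).eval()
is_target = lambda mod, fqn: isinstance(mod, Block)

# Step 1: prepare. Parameters and buffers are wrapped as `MultiTensor`
# and optimization hooks are registered on the target modules.
prepare_model_for_applying_auto_round_(
    model, is_target_module=is_target, bits=4, group_size=128, iters=20
)

# Step 2: calibrate. Feed the calibration batches as a single `MultiTensor`;
# the hooks run the auto-round optimization and unwrap the parameters/buffers afterwards.
calib_batches = MultiTensor([torch.randn(4, DIM) for _ in range(10)])
model(calib_batches)

# Step 3: quantize the optimized weights in place.
quantize_(model, apply_auto_round(), is_target)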
78 changes: 78 additions & 0 deletions test/prototype/test_autoround.py
@@ -53,6 +53,37 @@ def _is_two_linear(mod, fqn):
return isinstance(mod, TwoLinear)


class ModelWithInplaceOp(torch.nn.Module):
def __init__(self, DIM=128):
super().__init__()
self.lin = torch.nn.Linear(DIM, DIM)
self.register_buffer("other", torch.zeros(DIM, DIM))

def forward(self, x, idx):
x = x + self.lin(x)
# update buffer
self.other[idx] = x
return x


class M2(torch.nn.Module):
def __init__(self, DIM=128):
super().__init__()
self.m1 = ModelWithInplaceOp(DIM)
self.m2 = ModelWithInplaceOp(DIM)

def forward(self, x, idx):
x = self.m1(x, idx)
x = self.m2(x, idx)
return x


def _check_params_and_buffers_type(module, check_fun):
return [check_fun(p) for p in module.parameters()] + [
check_fun(b) for b in module.buffers()
]


class TestAutoRound(TestCase):

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later")
@@ -73,6 +104,9 @@ def test_auto_round(self, device: str):
iters=20,
device=device,
)
assert all(
_check_params_and_buffers_type(m, lambda x: isinstance(x, MultiTensor))
), "Expected all parameters and buffers to be `MultiTensor`."
input1 = []
input2 = []
for _ in range(10):
@@ -82,13 +116,57 @@ def test_auto_round(self, device: str):
mt_input1 = MultiTensor(input1)
mt_input2 = MultiTensor(input2)
out = m(mt_input1, mt_input2)
assert isinstance(out, MultiTensor), f"Expected MultiTensor, got {type(out)}"
assert all(
_check_params_and_buffers_type(m, lambda x: not isinstance(x, MultiTensor))
), "Expected all parameters and buffers have been converted back to tensor."
quantize_(m, apply_auto_round(), _is_two_linear, device=device)
for l in m.modules():
if isinstance(l, torch.nn.Linear):
assert isinstance(l.weight, AffineQuantizedTensor)
after_quant = m(*example_inputs)
assert after_quant is not None, "Quantized model forward pass failed"

@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Requires torch 2.5 or later")
@parametrize("device", _AVAILABLE_DEVICES)
@torch.no_grad()
def test_wrap_model_with_multi_tensor(self, device: str):

_is_model_with_inplace_op = lambda mod, fqn: isinstance(mod, ModelWithInplaceOp)

DIM = 128
m = M2(DIM).eval().to(device)
prepare_model_for_applying_auto_round_(
m,
is_target_module=_is_model_with_inplace_op,
bits=7,
group_size=32,
iters=20,
device=device,
)
assert all(
_check_params_and_buffers_type(m, lambda x: isinstance(x, MultiTensor))
), "Expected all parameters and buffers to be `MultiTensor`."
input1 = []
input2 = []
for _ in range(2):
input1.append(torch.randn(DIM, DIM).to(device))
input2.append(torch.randint(0, DIM, (DIM,), dtype=torch.long).to(device))

mt_input1 = MultiTensor(input1)
mt_input2 = MultiTensor(input2)
out = m(mt_input1, mt_input2)
assert isinstance(out, MultiTensor), f"Expected MultiTensor, got {type(out)}"
assert all(
_check_params_and_buffers_type(m, lambda x: not isinstance(x, MultiTensor))
), "Expected all parameters and buffers have been converted back to tensor."
quantize_(m, apply_auto_round(), _is_model_with_inplace_op, device=device)
for l in m.modules():
if isinstance(l, torch.nn.Linear):
assert isinstance(l.weight, AffineQuantizedTensor)
after_quant = m(input1[0], input2[0])
assert after_quant is not None, "Quantized model forward pass failed"


instantiate_parametrized_tests(TestAutoRound)

15 changes: 8 additions & 7 deletions torchao/prototype/autoround/README.md
@@ -74,27 +74,28 @@ quantize_(model, apply_auto_round(), is_target_module)
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.7080 | 0.6783 | 0.8003 | 0.7403 | 0.5910 | 0.7303 |
- | auto-round-4bit | 0.6989 | 0.6566 | 0.7943 | 0.7285 | 0.5856 | 0.7295 |
+ | auto-round-4bit | 0.6988 | 0.6533 | 0.7949 | 0.7372 | 0.5837 | 0.7250 |
| torchao-int4wo | 0.6883 | 0.6363 | 0.7938 | 0.7348 | 0.5784 | 0.6980 |

### [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.6881 | 0.6389 | 0.7840 | 0.7222 | 0.5772 | 0.7184 |
- | auto-round-4bit | 0.6811 | 0.6218 | 0.7758 | 0.7285 | 0.5694 | 0.7101 |
+ | auto-round-4bit | 0.6818 | 0.6232 | 0.7862 | 0.7230 | 0.5661 | 0.7105 |
| torchao-int4wo | 0.6728 | 0.5939 | 0.7737 | 0.7222 | 0.5612 | 0.7132 |


### [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
| | Avg. | Mmlu | Piqa | Winogrande | Hellaswag | Lambada_openai |
| -------------- | ------- | ------ | ------ | ---------- | --------- | -------------- |
| bf16 | 0.6347 | 0.4647 | 0.7644 | 0.6606 | 0.577 | 0.7070 |
- | auto-round-4bit | 0.6335 | 0.4533 | 0.7661 | 0.6685 | 0.5705 | 0.7091 |
+ | auto-round-4bit | 0.6327 | 0.4534 | 0.7590 | 0.6661 | 0.5706 | 0.7143 |
| torchao-int4wo | 0.6252 | 0.4427 | 0.7617 | 0.6654 | 0.5674 | 0.6889 |

> [!NOTE]
- > - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, `use_optimized_layer_output=True` and `quant_lm_head=False`. <br>
+ > - `auto-round-4bit` represents the following configuration: `bits=4`, `iters=200`, `seqlen=2048`, `train_bs=8`, `group_size=128`, and `quant_lm_head=False`. <br>
> - `torchao-int4wo` represents `int4_weight_only(group_size=128)` and `quant_lm_head=False`.
> - If the model includes operations without a deterministic implementation (such as Flash Attention), the results may differ slightly.
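
For reference, the two quantized rows above roughly correspond to the following calls. This is a sketch only: `model` and `is_target_module` are the objects from the snippet earlier in this README, `apply_auto_round` is assumed to be imported as shown there, and the auto-round schedule (`iters`, `seqlen`, `train_bs`) is set during the preparation/calibration step rather than here.

from torchao.quantization import int4_weight_only, quantize_

# auto-round-4bit: weights tuned with auto-round (bits=4, group_size=128), then packed.
quantize_(model, apply_auto_round(), is_target_module)

# torchao-int4wo: plain int4 weight-only quantization with group_size=128.
quantize_(model, int4_weight_only(group_size=128))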

## Credits
14 changes: 12 additions & 2 deletions torchao/prototype/autoround/autoround_llm.py
@@ -1,4 +1,5 @@
import argparse
import logging

import torch

@@ -27,6 +28,7 @@ def quantize_model_with_autoround_(
dataset_name: str = "NeelNanda/pile-10k",
bs: int = 8,
nsamples: int = 128,
use_optimized_layer_output: bool = False,
):
# Step 1. Prepare the model for applying auto-round

@@ -39,6 +41,7 @@ def quantize_model_with_autoround_(
bits,
group_size,
iters,
use_optimized_layer_output,
device=device,
)

@@ -79,7 +82,7 @@ def main(args):
model, tokenizer, decoder_cls = ar_utils.get_float_model_info(
model_name_or_path, torch_dtype=torch.bfloat16
)
- # Disable the `use_cache` for calibration process, which cause the OOM.
+ # Disable the `use_cache` for the calibration stage.
model.config.use_cache = False
ar_utils.gen_text(model, tokenizer, "Float model", max_length=50)

@@ -103,8 +106,9 @@ def main(args):
dataset_name=args.dataset_name,
bs=args.train_bs,
nsamples=args.nsamples,
use_optimized_layer_output=args.use_optimized_layer_output,
)
- # Revert the `use_cache`
+ # Revert the `use_cache` for the generation stage.
model.config.use_cache = True

# Generate text using the quantized model
@@ -158,6 +162,12 @@ def main(args):
action="store_true",
help="Quantize the `lm_head` or not",
)
parser.add_argument(
"--use_optimized_layer_output",
default=False,
action="store_true",
help="Use the optimized layer output for next layer or not",
)
parser.add_argument(
"-d",
"--model_device",
75 changes: 65 additions & 10 deletions torchao/prototype/autoround/core.py
@@ -8,11 +8,7 @@
import torchao.prototype.autoround.utils as ar_utils
import torchao.quantization as ao_quant
from torchao.dtypes import TensorCoreTiledLayoutType, to_affine_quantized_intx_static
- from torchao.prototype.autoround.multi_tensor import (
-     _multi_tensor_config,
-     _MultiTensorConfig,
-     MultiTensor,
- )
+ from torchao.prototype.autoround.multi_tensor import _multi_tensor_config, MultiTensor
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.utils import find_multiple

@@ -23,7 +19,7 @@ class _AutoRoundConfig:
bits: int = 4
group_size: int = 128
iters: int = 200
- use_optimized_layer_output: bool = True
+ use_optimized_layer_output: bool = False


_auto_round_config = _AutoRoundConfig()
@@ -43,14 +39,49 @@ def reset(self):
_optimization_tracker = _OptimizationTracker()


def _replace_model_buffers_and_params(model, replacement_fn):
model = replacement_fn(model)
for name, child in model.named_children():
new_child = _replace_model_buffers_and_params(child, replacement_fn)
if new_child is not child:
setattr(model, name, new_child)
return model


def _tensor_to_multi_tensor(model):
for name, buf in model.named_buffers(recurse=False):
setattr(model, name, MultiTensor([buf]))
for name, param in model.named_parameters(recurse=False):
setattr(model, name, torch.nn.Parameter(MultiTensor([param]), False))
return model


def _multi_tensor_to_tensor(model):
for name, buf in model.named_buffers(recurse=False):
if isinstance(buf, MultiTensor):
assert (
len(buf.values) == 1
), f"The buffer should only have one tensor, but got {buf.count}."
model.register_buffer(name, buf.values[0])
for name, param in model.named_parameters(recurse=False):
if isinstance(param, MultiTensor):
assert (
len(param.values) == 1
), f"The parameter should only have one tensor, but got {param.count}."
setattr(
model, name, torch.nn.Parameter(param.values[0], requires_grad=False)
)
return model


@torch.no_grad()
def prepare_model_for_applying_auto_round_(
model: torch.nn.Module,
is_target_module: Callable[[torch.nn.Module, str], bool],
bits: int = 4,
group_size: int = 128,
iters: int = 200,
- use_optimized_layer_output: bool = True,
+ use_optimized_layer_output: bool = False,
device: Optional[torch.types.Device] = None,
):
"""Prepares the model for applying auto round optimization.
@@ -62,7 +93,7 @@ def prepare_model_for_applying_auto_round_(
bits (int, optional): The number of bits for quantization. Defaults to 4, options are 1 to 8.
group_size (int, optional): The group size for quantization. Defaults to 128.
iters (int, optional): The number of iterations for optimization. Defaults to 200.
- use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to True.
+ use_optimized_layer_output (bool, optional): Whether to use optimized layer output. Defaults to False.
device (Optional[torch.types.Device], optional): The device to use for accelerating optimization and calibration.
Defaults to None.
"""
@@ -75,7 +106,27 @@ def prepare_model_for_applying_auto_round_(
_auto_round_config.iters = iters
_auto_round_config.use_optimized_layer_output = use_optimized_layer_output

- def forward_hook(
logging.warning(f"config {_auto_round_config}")

# Wrap the model buffers and parameters with `MultiTensor`
model = _replace_model_buffers_and_params(model, _tensor_to_multi_tensor)

def _revert_buffers_and_params_fn(
module,
input: Tuple[MultiTensor],
output: Tuple[MultiTensor],
):
module._forward_hook_handle_for_revert_buffers_and_params.remove()
_replace_model_buffers_and_params(module, _multi_tensor_to_tensor)
return output

# Register forward hook for reverting the replacement of buffers and parameters
model._forward_hook_handle_for_revert_buffers_and_params = (
model.register_forward_hook(_revert_buffers_and_params_fn)
)

# Register forward hook for applying auto-round optimization
def auto_round_optimization_hook(
module,
args: Tuple[MultiTensor],
kwargs: Dict[str, MultiTensor],
@@ -88,7 +139,7 @@ def forward_hook(

def _register_forward_hook(module: torch.nn.Module):
forward_hook_handle = module.register_forward_hook(
- forward_hook, with_kwargs=True
+ auto_round_optimization_hook, with_kwargs=True
)
module._forward_hook_handle_for_auto_round = forward_hook_handle
_optimization_tracker.num_layers += 1
@@ -286,12 +337,16 @@ def apply_auto_round_optimization(
):
# Remove the hook to avoid recursive calls
module._forward_hook_handle_for_auto_round.remove()
# Revert the model to the original state for applying auto-round optimization
module = _replace_model_buffers_and_params(module, _multi_tensor_to_tensor)

block_inputs = MultiTensor.revert_to_tensor_pairs(args, kwargs)
block_outputs = MultiTensor.revert_to_tensor_pairs(output)

_apply_auto_round_optimization(module, block_inputs, block_outputs, config)
# Get the new output of the optimized model
if config.use_optimized_layer_output:
# Re-replace the model buffers and parameters with `MultiTensor`
_replace_model_buffers_and_params(module, _tensor_to_multi_tensor)
output = module(*args, **kwargs)
return output
15 changes: 12 additions & 3 deletions torchao/prototype/autoround/eval_autoround.py
@@ -1,14 +1,16 @@
import argparse

- import torchao.prototype.autoround.utils as ar_utils
-
- ar_utils.freeze_random(42)
import torch

torch.use_deterministic_algorithms(True, warn_only=True)
import torchao
+ import torchao.prototype.autoround.utils as ar_utils

import torchao.quantization
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

+ ar_utils.freeze_random(42)


@ar_utils.dump_elapsed_time()
def run_evaluation(model, tokenizer, tasks, compile=False, batch_size=4):
@@ -118,6 +120,7 @@ def main(args):
seqlen=args.seqlen,
bs=args.train_bs,
nsamples=args.nsamples,
use_optimized_layer_output=args.use_optimized_layer_output,
)
quantized_layer_cnt = ar_utils.count_tensor_of_type(
model, torchao.dtypes.AffineQuantizedTensor
@@ -175,6 +178,12 @@ def main(args):
action="store_true",
help="Quantize the `lm_head` or not",
)
parser.add_argument(
"--use_optimized_layer_output",
default=False,
action="store_true",
help="Use the optimized layer output for next layer or not",
)
parser.add_argument(
"-d",
"--model_device",
