Rename AQT#2 LayoutType -> Layout #1049

Merged 3 commits on Oct 10, 2024
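This PR is the second step of the AQT (AffineQuantizedTensor) rename: the layout classes drop their "Type" suffix (LayoutType -> Layout, TensorCoreTiledLayoutType -> TensorCoreTiledLayout, and so on) and the layout_type= keyword argument becomes layout= at every call site. A minimal before/after sketch of the user-facing change, reusing the quantization calls updated in the diffs below (the toy model itself is hypothetical):

import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only
from torchao.dtypes import TensorCoreTiledLayout  # previously TensorCoreTiledLayoutType

# hypothetical toy model, purely for illustration
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(device="cuda", dtype=torch.bfloat16)

# before this PR: int4_weight_only(layout_type=TensorCoreTiledLayoutType(inner_k_tiles=4))
quantize_(model, int4_weight_only(group_size=32, layout=TensorCoreTiledLayout(inner_k_tiles=4)))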
4 changes: 2 additions & 2 deletions benchmarks/benchmark_fp6.py
@@ -2,14 +2,14 @@
 import pandas as pd
 import torch.nn.functional as F
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayoutType
+from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
     float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
-    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayoutType(3, 2))
+    fp6_weight = to_affine_quantized_fpx(float_data, FloatxTensorCoreLayout(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
4 changes: 2 additions & 2 deletions test/dtypes/test_affine_quantized.py
@@ -10,7 +10,7 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
-from torchao.dtypes import SemiSparseLayoutType
+from torchao.dtypes import SemiSparseLayout
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -31,7 +31,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))

     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()))
+        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))

     if is_cuda_8_9:
         base_functions.append(float8_weight_only())
6 changes: 3 additions & 3 deletions test/dtypes/test_floatx.py
@@ -10,7 +10,7 @@
 )
 from torchao.dtypes.floatx import (
     FloatxTensorCoreAQTTensorImpl,
-    FloatxTensorCoreLayoutType,
+    FloatxTensorCoreLayout,
     to_scaled_tc_floatx,
     from_scaled_tc_floatx,
 )

@@ -81,8 +81,8 @@ def test_to_copy_device(self, ebits, mbits):
         x = torch.randn(256, 64)
         scale = choose_qparams_affine_floatx(x, ebits, mbits)
         x = quantize_affine_floatx(x, scale, ebits, mbits)
-        layout_type = FloatxTensorCoreLayoutType(ebits, mbits)
-        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, layout_type).cuda()
+        _layout = FloatxTensorCoreLayout(ebits, mbits)
+        floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(x, scale, None, _layout).cuda()
         assert floatx_tensor_impl.device.type == "cuda"
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"
4 changes: 2 additions & 2 deletions test/hqq/test_hqq_affine.py
@@ -4,9 +4,9 @@
     to_affine_quantized_intx,
     ZeroPointDomain,
     PlainAQTTensorImpl,
-    PlainLayoutType,
+    PlainLayout,
     TensorCoreTiledAQTTensorImpl,
-    TensorCoreTiledLayoutType,
+    TensorCoreTiledLayout,
     MappingType,
 )
6 changes: 3 additions & 3 deletions test/integration/test_integration.py
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayoutType
+from torchao.dtypes import TensorCoreTiledLayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,

@@ -876,7 +876,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}
+                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}

                     def api(mod):
                         kwargs_copy = kwargs.copy()

@@ -888,7 +888,7 @@ def api(mod):
                                 unwrap_tensor_subclass(mod)
                         else:
                             kwargs_copy["inner_k_tiles"] = inner_k_tiles
-                            del kwargs_copy["layout_type"]
+                            del kwargs_copy["layout"]
                             change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

                     self._test_lin_weight_subclass_api_impl(
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -13,7 +13,7 @@
 import torch
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torchao.dtypes import (
-    TensorCoreTiledLayoutType,
+    TensorCoreTiledLayout,
 )
 from torchao.quantization.prototype.qat.api import (
     ComposableQATQuantizer,
6 changes: 3 additions & 3 deletions test/sparsity/test_marlin.py
@@ -5,7 +5,7 @@
 from torch import nn
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
-from torchao.dtypes import MarlinSparseLayoutType
+from torchao.dtypes import MarlinSparseLayout
 from torchao.sparsity.sparse_api import apply_fake_sparsity
 from torchao.quantization.quant_api import int4_weight_only, quantize_
 from torchao.sparsity.marlin import (

@@ -50,7 +50,7 @@ def test_quant_sparse_marlin_layout_eager(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         sparse_result = self.model(self.input)

         assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"

@@ -67,7 +67,7 @@ def test_quant_sparse_marlin_layout_compile(self):
         dense_result = model_copy(self.input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(self.model, int4_weight_only(layout=MarlinSparseLayout()))
         self.model.forward = torch.compile(self.model.forward, fullgraph=True)
         sparse_result = self.model(self.input)
10 changes: 5 additions & 5 deletions test/sparsity/test_sparse_api.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 from torch.testing._internal import common_utils
-from torchao.dtypes import MarlinSparseLayoutType, SemiSparseLayoutType
+from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_dynamic_activation_int8_weight,

@@ -74,7 +74,7 @@ def test_quant_semi_sparse(self, compile):

         quantize_(
             model,
-            int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
         )
         if compile:
             model = torch.compile(model)

@@ -108,7 +108,7 @@ def test_sparse_marlin(self, compile):
         dense_result = model_copy(input.bfloat16()).half()

         # Sparse + quantized
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if compile:
             model = torch.compile(model)
         sparse_result = model(input)

@@ -185,12 +185,12 @@ def test_sparse(self, compile):
         quantize_(model_copy, int8_dynamic_activation_int8_weight())
         reference = model_copy(input)

-        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayoutType
+        from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout

         quantize_(
             model,
             int8_dynamic_activation_int8_weight(
-                layout_type=BlockSparseLayoutType(blocksize=64)
+                layout=BlockSparseLayout(blocksize=64)
             ),
         )
         if compile:
6 changes: 3 additions & 3 deletions torchao/_models/llama/eval.py
@@ -97,8 +97,8 @@ def run_evaluation(
             group_size = int(_quant_args[2])
             quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayoutType
-            quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+            from torchao.dtypes import MarlinSparseLayout
+            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
         if "int4wo" in quantization and "gptq" in quantization:
             # avoid circular imports
             from torchao._models._eval import InputRecorder

@@ -255,4 +255,4 @@ def run_evaluation(
         args.calibration_limit,
         args.calibration_seq_length,
         args.pad_calibration_inputs,
-    )
\ No newline at end of file
+    )
4 changes: 2 additions & 2 deletions torchao/_models/llama/generate.py
@@ -230,8 +230,8 @@ def main(
         assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
         quantize_(model, int4_weight_only(group_size=groupsize))
     if "marlin" in quantization:
-        from torchao.dtypes import MarlinSparseLayoutType
-        quantize_(model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
+        from torchao.dtypes import MarlinSparseLayout
+        quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
     if "fp6" in quantization:
         quantize_(model, fpx_weight_only(3, 2))
     if quantization.startswith("awq"):
8 changes: 4 additions & 4 deletions torchao/_models/sam/eval_combo.py
@@ -11,7 +11,7 @@

 from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, int4_weight_only
 from torchao.sparsity import sparsify_, apply_fake_sparsity, semi_sparse_weight
-from torchao.dtypes import SemiSparseLayoutType, MarlinSparseLayoutType
+from torchao.dtypes import SemiSparseLayout, MarlinSparseLayout
 from torchao.utils import unwrap_tensor_subclass
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -315,7 +315,7 @@ def mlp_only(mod, name):
                   int8_dynamic_activation_int8_weight(),
                   attn_only)
         quantize_(predictor.model.image_encoder,
-                  int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()),
+                  int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()),
                   mlp_lin1_only)
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),

@@ -326,11 +326,11 @@ def mlp_only(mod, name):
         # apply sparsify first to set qparams
         apply_fake_sparsity(predictor.model.image_encoder,
                             filter_fn=mlp_only)
-        from torchao.dtypes import MarlinSparseLayoutType
+        from torchao.dtypes import MarlinSparseLayout
         quantize_(predictor.model.image_encoder,
                   int8_dynamic_activation_int8_weight(),
                   attn_only)
-        quantize_(predictor.model.image_encoder, int4_weight_only(layout_type=MarlinSparseLayoutType()), mlp_lin1_only)
+        quantize_(predictor.model.image_encoder, int4_weight_only(layout=MarlinSparseLayout()), mlp_lin1_only)
         sparsify_(predictor.model.image_encoder,
                   semi_sparse_weight(),
                   mlp_lin2_only)
24 changes: 12 additions & 12 deletions torchao/dtypes/__init__.py
@@ -9,13 +9,13 @@
     to_affine_quantized_fpx,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
-    LayoutType,
-    PlainLayoutType,
-    SemiSparseLayoutType,
-    TensorCoreTiledLayoutType,
-    Float8LayoutType,
+    Layout,
+    PlainLayout,
+    SemiSparseLayout,
+    TensorCoreTiledLayout,
+    Float8Layout,
     Float8AQTTensorImpl,
-    MarlinSparseLayoutType,
+    MarlinSparseLayout,
 )

 __all__ = [

@@ -28,11 +28,11 @@
     "to_affine_quantized_fpx",
     "to_affine_quantized_floatx",
     "to_affine_quantized_floatx_static",
-    "LayoutType",
-    "PlainLayoutType",
-    "SemiSparseLayoutType",
-    "TensorCoreTiledLayoutType",
-    "Float8LayoutType",
+    "Layout",
+    "PlainLayout",
+    "SemiSparseLayout",
+    "TensorCoreTiledLayout",
+    "Float8Layout",
     Float8AQTTensorImpl",
-    "MarlinSparseLayoutType",
+    "MarlinSparseLayout",
 ]
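With the public exports renamed, downstream imports pick up the new names directly; a minimal sketch (the trailing comments mapping old to new names are just a reading aid):

from torchao.dtypes import (
    Layout,                 # was LayoutType
    PlainLayout,            # was PlainLayoutType
    SemiSparseLayout,       # was SemiSparseLayoutType
    TensorCoreTiledLayout,  # was TensorCoreTiledLayoutType
    Float8Layout,           # was Float8LayoutType
    MarlinSparseLayout,     # was MarlinSparseLayoutType
)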