Skip to content

Commit eadead5

Browse files
authored by [author name missing from page extraction]
Minor fix on TAO op to support lowering
Differential Revision: D82492826 Pull Request resolved: #3031
1 parent be4203e commit eadead5

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

torchao/dtypes/floatx/cutlass_semi_sparse_layout.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
106106
)
107107
elif func is aten.to.dtype_layout:
108108
dense, scale, _ = args[0].get_plain()
109-
dense = dense.to(
109+
product = dense.to(scale.dtype) * scale
110+
return product.to(
110111
*args[1:],
111112
dtype=kwargs.get("dtype", dense.dtype),
112113
device=kwargs.get("device", dense.device),
113114
)
114-
return scale * dense
115115

116116
raise NotImplementedError(
117117
f"CutlassSemiSparseTensorImpl dispatch: attempting to run {func}, this is not supported"
@@ -135,11 +135,12 @@ def get_plain(self):
135135
# semi-structured format, so multiplying with identity matrix,
136136
# and using identity scale factors, for the conversion.
137137
cols = self.shape[1]
138-
input = torch.eye(cols, dtype=self.sparse.dtype, device=self.sparse.device)
139-
input_scale = torch.ones(
140-
(cols,), dtype=self.scale.dtype, device=self.sparse.device
141-
)
138+
plain_input = torch.eye(cols, device=self.sparse.device)
139+
input = plain_input.to(dtype=self.sparse.dtype)
140+
plain_input_scale = torch.ones((cols,), device=self.sparse.device)
141+
input_scale = plain_input_scale.to(dtype=self.scale.dtype)
142142
sparse_scale = torch.ones_like(self.scale)
143+
143144
out_dtype = torch.bfloat16
144145
dense = (
145146
rowwise_scaled_linear_sparse_cutlass_f8f8(

torchao/quantization/linear_activation_quantized_tensor.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,14 @@ def _same_metadata(
133133

134134
@implements([torch.nn.functional.linear, aten.linear.default])
135135
def _(func, types, args, kwargs):
136-
input_tensor, weight_tensor, bias = (
137-
args[0],
138-
args[1],
139-
args[2] if len(args) > 2 else None,
140-
)
136+
137+
input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
138+
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
139+
bias = kwargs.get("bias", args[2] if len(args) > 2 else None)
140+
141+
assert input_tensor is not None, "input tensor must not be None"
142+
assert weight_tensor is not None, "weight tensor must not be None"
143+
141144
if isinstance(weight_tensor, LinearActivationQuantizedTensor):
142145
return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
143146

@@ -216,6 +219,11 @@ def _(func, types, args, kwargs):
216219
for tensor_name in self_tensors:
217220
getattr(self, tensor_name).copy_(getattr(src, tensor_name))
218221
return
222+
elif type(self) is torch.Tensor and type(src) is LinearActivationQuantizedTensor:
223+
new_src = src.to(dtype=self.dtype, device=self.device)
224+
self.copy_(new_src)
225+
return
226+
219227
raise ValueError(
220228
f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
221229
)

0 commit comments

Comments (0)