From 39c93aaa0577eb92f28129b5cac064d197093247 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Fri, 15 Mar 2024 16:08:29 -0700
Subject: [PATCH] use dequantize per channel group for embedding (#2374)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/2374

- add type hints to quantized_ops
- use dequantize per channel group in embedding byte decomposition

bypass-github-export-checks

Reviewed By: mikekgfb

Differential Revision: D54813256

fbshipit-source-id: 79b8f39d820378faa90d908e3ea56d4201a61598
---
 examples/models/llama2/ops/quantized_ops.py   |  94 +++++++------
 .../_quant_patterns_and_replacements.py       | 132 ++++++++++++++++--
 2 files changed, 169 insertions(+), 57 deletions(-)

diff --git a/examples/models/llama2/ops/quantized_ops.py b/examples/models/llama2/ops/quantized_ops.py
index 2ab8df3080..5d13856442 100644
--- a/examples/models/llama2/ops/quantized_ops.py
+++ b/examples/models/llama2/ops/quantized_ops.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import torch
 from torch.library import impl, impl_abstract
 
@@ -62,43 +64,45 @@ def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
         0
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
-    if not weight_zero_points:
-        weight_zero_points = torch.zeros(weight.size(0))
 
 
 @impl(quantized_lib, "embedding_byte", "CompositeExplicitAutograd")
-def embedding_byte_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
-):
+def embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
     embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
         weight,
         weight_scales,
         weight_zero_points,
-        0,
         weight_quant_min,
         weight_quant_max,
         weight.dtype,
+        group_size,
+        weight_scales.dtype,
     )
     return torch.ops.aten.embedding.default(weight, indices)
 
 
 @impl_abstract("llama_quantized::embedding_byte.out")
 def embedding_byte_out_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
-    out,
-):
-    return embedding_byte_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte(
         weight,
         weight_scales,
         weight_zero_points,
@@ -109,42 +113,46 @@ def embedding_byte_out_meta(
 
 
 @impl(quantized_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
-def embedding_byte_dtype_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
+def embedding_byte_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
     *,
-    dtype,
-):
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
     embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
-    weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
         weight,
         weight_scales,
         weight_zero_points,
-        0,
         weight_quant_min,
         weight_quant_max,
         weight.dtype,
+        group_size,
+        dtype,
     )
-    return torch.ops.aten.embedding.default(weight, indices).to(dtype)
+    return torch.ops.aten.embedding.default(weight, indices)
 
 
 @impl_abstract("llama_quantized::embedding_byte.dtype_out")
 def embedding_byte_dtype_out_meta(
-    weight,
-    weight_scales,
-    weight_zero_points,
-    weight_quant_min,
-    weight_quant_max,
-    indices,
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
     *,
-    dtype,
-    out,
-):
-    return embedding_byte_dtype_meta(
+    dtype: Optional[torch.dtype] = None,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte_dtype(
         weight,
         weight_scales,
         weight_zero_points,
diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index ae1e63ebb8..334e739893 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -9,7 +9,6 @@
 
 import torch
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
-
 from executorch.exir.passes.replace_aten_with_edge_pass import (
     aten_to_edge,
     should_lower_to_edge,
@@ -487,6 +486,50 @@ def replacement(
         )
         return out
 
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
+    def pattern_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
+    ):
+        weight = (
+            torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+                weight,
+                weight_scales,
+                weight_zero_points,
+                weight_quant_min,
+                weight_quant_max,
+                weight.dtype,
+                group_size,
+                weight_scales.dtype,
+            )
+        )
+        out = torch.ops.aten.embedding.default(weight, indices)
+        return out
+
+    def replacement_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
+    ):
+        out = torch.ops.quantized_decomposed.embedding_byte.default(
+            weight,
+            weight_scales,
+            weight_zero_points,
+            weight_quant_min,
+            weight_quant_max,
+            indices,
+        )
+        return out
+
     @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
     def pattern_with_padding_idx(
         weight,
@@ -528,35 +571,86 @@ def replacement_with_padding_idx(
         )
         return out
 
-    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
-    def pattern_with_dtype(
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
+    def pattern_with_padding_idx_groupwise(
         weight,
         weight_scales,
         weight_zero_points,
         weight_quant_min,
         weight_quant_max,
-        indicies,
-        dtype,
+        indices,
+        group_size,
+        padding_idx,
     ):
-        weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
+        weight = (
+            torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+                weight,
+                weight_scales,
+                weight_zero_points,
+                weight_quant_min,
+                weight_quant_max,
+                weight.dtype,
+                group_size,
+                weight_scales.dtype,
+            )
+        )
+        out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
+        return out
+
+    def replacement_with_padding_idx_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
+        _,  # padding_idx only matters for training and not when running op for inference
+    ):
+        out = torch.ops.quantized_decomposed.embedding_byte.default(
             weight,
             weight_scales,
             weight_zero_points,
-            0,
             weight_quant_min,
             weight_quant_max,
-            torch.uint8,
+            indices,
         )
-        out = torch.ops.aten.embedding.default(weight, indicies).to(dtype)
         return out
 
-    def replacement_with_dtype(
+    @bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
+    def pattern_with_dtype_groupwise(
         weight,
         weight_scales,
         weight_zero_points,
         weight_quant_min,
         weight_quant_max,
-        indicies,
+        indices,
+        group_size,
         dtype,
     ):
+        weight = (
+            torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+                weight,
+                weight_scales,
+                weight_zero_points,
+                weight_quant_min,
+                weight_quant_max,
+                weight.dtype,
+                group_size,
+                dtype,
+            )
+        )
+        out = torch.ops.aten.embedding.default(weight, indices)
+        return out
+
+    def replacement_with_dtype_groupwise(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        group_size,
         dtype,
     ):
         out = torch.ops.quantized_decomposed.embedding_byte.dtype(
@@ -565,7 +659,7 @@ def replacement_with_dtype(
             weight_zero_points,
             weight_quant_min,
             weight_quant_max,
-            indicies,
+            indices,
             dtype=dtype,
         )
         return out
@@ -576,14 +670,24 @@ def replacement_with_dtype(
             _trace_and_lower_to_edge_ops(replacement),
             [],
         ),
+        (
+            _trace_and_lower_to_edge_ops(pattern_groupwise),
+            _trace_and_lower_to_edge_ops(replacement_groupwise),
+            [],
+        ),
         (
             _trace_and_lower_to_edge_ops(pattern_with_padding_idx),
             _trace_and_lower_to_edge_ops(replacement_with_padding_idx),
             [],
         ),
         (
-            _trace_and_lower_to_edge_ops(pattern_with_dtype),
-            _trace_and_lower_to_edge_ops(replacement_with_dtype),
+            _trace_and_lower_to_edge_ops(pattern_with_padding_idx_groupwise),
+            _trace_and_lower_to_edge_ops(replacement_with_padding_idx_groupwise),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(pattern_with_dtype_groupwise),
+            _trace_and_lower_to_edge_ops(replacement_with_dtype_groupwise),
             [],
         ),
     ]
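
For reference, below is a minimal sketch (outside the patch, in plain PyTorch) of the groupwise dequantize-then-embed semantics that the new decomposition relies on: dequantize_per_channel_group followed by aten.embedding, with group_size derived as weight.size(1) // weight_scales.size(1) when the scales are 2-D. The helper name embedding_byte_reference and the example shapes are hypothetical, not part of the ExecuTorch API; the actual registered ops may differ in details.

from typing import Optional

import torch


def embedding_byte_reference(
    weight: torch.Tensor,          # (num_embeddings, embedding_dim), int8/uint8
    weight_scales: torch.Tensor,   # (num_embeddings, num_groups) or (num_embeddings,)
    weight_zero_points: Optional[torch.Tensor],
    indices: torch.Tensor,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    num_rows, num_cols = weight.shape
    # group_size is derived as in the patch: weight.size(1) // weight_scales.size(1)
    # when the scales are 2-D, otherwise the whole row is a single group.
    num_groups = weight_scales.size(1) if weight_scales.dim() == 2 else 1
    group_size = num_cols // num_groups

    scales = weight_scales.reshape(num_rows, num_groups, 1).to(output_dtype)
    if weight_zero_points is None:
        # None zero points are treated as zeros, matching the Optional argument
        # accepted by the decomposition.
        zero_points = torch.zeros(num_rows, num_groups, 1, dtype=output_dtype)
    else:
        zero_points = weight_zero_points.reshape(num_rows, num_groups, 1).to(output_dtype)

    # Dequantize each group independently: (q - zero_point) * scale.
    grouped = weight.reshape(num_rows, num_groups, group_size).to(output_dtype)
    dequant = ((grouped - zero_points) * scales).reshape(num_rows, num_cols)

    # Plain embedding lookup on the dequantized table.
    return torch.nn.functional.embedding(indices, dequant)


# Hypothetical usage: 10 rows of 8 int8 values, 2 groups of 4 values per row.
weight = torch.randint(-128, 127, (10, 8), dtype=torch.int8)
scales = torch.rand(10, 2)
indices = torch.tensor([0, 3, 7])
print(embedding_byte_reference(weight, scales, None, indices).shape)  # torch.Size([3, 8])

Per-channel quantization is the degenerate case of one group per row, which is why the decomposition can treat 1-D scales as group_size == embedding_dim instead of keeping a separate per-channel code path.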