use dequantize per channel group for embedding (#2374)
Summary:
Pull Request resolved: #2374

- add type hints to quantized_ops
- use dequantize per channel group in embedding byte decomposition

bypass-github-export-checks

Reviewed By: mikekgfb

Differential Revision: D54813256

fbshipit-source-id: 79b8f39d820378faa90d908e3ea56d4201a61598
manuelcandales authored and facebook-github-bot committed Mar 15, 2024
1 parent e76cd88 commit 39c93aa
Showing 2 changed files with 169 additions and 57 deletions.
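The core of the change is the second summary bullet: the embedding decomposition now dequantizes the packed weight per channel group before the lookup, inferring the group size from the shape of the scales. The snippet below is a minimal pure-PyTorch sketch of that computation under illustrative shapes; the helper name is hypothetical and it emulates, rather than calls, the registered quantized_decomposed.dequantize_per_channel_group kernel.

```python
import torch

def dequantize_per_channel_group_reference(
    weight: torch.Tensor,       # (num_rows, num_cols) quantized values
    scales: torch.Tensor,       # (num_rows, num_groups) groupwise scales
    zero_points: torch.Tensor,  # (num_rows, num_groups) groupwise zero points
    group_size: int,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    # Split each row into contiguous groups of `group_size` columns and
    # apply (q - zero_point) * scale within each group.
    num_rows, num_cols = weight.shape
    w = weight.to(output_dtype).reshape(num_rows, -1, group_size)
    zp = zero_points.to(output_dtype).unsqueeze(-1)
    s = scales.to(output_dtype).unsqueeze(-1)
    return ((w - zp) * s).reshape(num_rows, num_cols)

# Illustrative shapes: an 8-row table, 16 columns, quantized in groups of 8.
weight = torch.randint(0, 256, (8, 16), dtype=torch.uint8)
scales = torch.rand(8, 2)
zero_points = torch.full((8, 2), 128.0)
indices = torch.tensor([0, 3, 3, 7])

# group_size is inferred the same way the patched op infers it:
# row width divided by the number of scale groups per row.
group_size = weight.size(1) // scales.size(1)

dequant = dequantize_per_channel_group_reference(weight, scales, zero_points, group_size)
out = torch.ops.aten.embedding.default(dequant, indices)
print(out.shape)  # torch.Size([4, 16])
```

When weight_scales is 1-D (the old per-channel layout), the same formula degenerates to a single group spanning the whole row, which is why one decomposition can now cover both layouts.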
94 changes: 51 additions & 43 deletions examples/models/llama2/ops/quantized_ops.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from torch.library import impl, impl_abstract

@@ -62,43 +64,45 @@ def embedding_byte_weight_checks(weight, weight_scales, weight_zero_points):
assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
0
), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
if not weight_zero_points:
weight_zero_points = torch.zeros(weight.size(0))


@impl(quantized_lib, "embedding_byte", "CompositeExplicitAutograd")
def embedding_byte_meta(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
):
def embedding_byte(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
) -> torch.Tensor:
embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
group_size = weight.size(1) // (
weight_scales.size(1) if weight_scales.dim() == 2 else 1
)
weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
weight,
weight_scales,
weight_zero_points,
0,
weight_quant_min,
weight_quant_max,
weight.dtype,
group_size,
weight_scales.dtype,
)
return torch.ops.aten.embedding.default(weight, indices)


@impl_abstract("llama_quantized::embedding_byte.out")
def embedding_byte_out_meta(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
out,
):
return embedding_byte_meta(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
out: torch.Tensor,
) -> torch.Tensor:
return embedding_byte(
weight,
weight_scales,
weight_zero_points,
@@ -109,42 +113,46 @@ def embedding_byte_out_meta(


@impl(quantized_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
def embedding_byte_dtype_meta(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
def embedding_byte_dtype(
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
*,
dtype,
):
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
embedding_byte_weight_checks(weight, weight_scales, weight_zero_points)
weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
group_size = weight.size(1) // (
weight_scales.size(1) if weight_scales.dim() == 2 else 1
)
weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
weight,
weight_scales,
weight_zero_points,
0,
weight_quant_min,
weight_quant_max,
weight.dtype,
group_size,
dtype,
)
return torch.ops.aten.embedding.default(weight, indices).to(dtype)
return torch.ops.aten.embedding.default(weight, indices)


@impl_abstract("llama_quantized::embedding_byte.dtype_out")
def embedding_byte_dtype_out_meta(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
weight: torch.Tensor,
weight_scales: torch.Tensor,
weight_zero_points: Optional[torch.Tensor],
weight_quant_min: int,
weight_quant_max: int,
indices: torch.Tensor,
*,
dtype,
out,
):
return embedding_byte_dtype_meta(
dtype: Optional[torch.dtype] = None,
out: torch.Tensor,
) -> torch.Tensor:
return embedding_byte_dtype(
weight,
weight_scales,
weight_zero_points,
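As a rough usage sketch, and assuming the llama_quantized registrations above (plus the quantized_decomposed kernels they decompose to) have been imported, a groupwise call to the decomposed op would look like the following. Shapes and values are illustrative, and the underlying dequantize kernel may impose stricter dtype requirements than shown.

```python
import torch
# Assumption: the module above has been imported so the llama_quantized ops
# are registered, along with the quantized_decomposed dequantize kernels.

vocab_size, embedding_dim, group_size = 32, 64, 16

weight = torch.randint(0, 256, (vocab_size, embedding_dim), dtype=torch.uint8)
# One scale / zero point per (row, group): embedding_dim // group_size groups per row.
weight_scales = torch.rand(vocab_size, embedding_dim // group_size)
# Zero points per group; depending on the kernel these may need an integer dtype.
weight_zero_points = torch.zeros(vocab_size, embedding_dim // group_size)
indices = torch.tensor([0, 5, 5, 31])

out = torch.ops.llama_quantized.embedding_byte.default(
    weight,
    weight_scales,
    weight_zero_points,
    0,    # weight_quant_min
    255,  # weight_quant_max
    indices,
)
print(out.shape)  # torch.Size([4, 64]); rows are dequantized to weight_scales.dtype
```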
132 changes: 118 additions & 14 deletions exir/passes/_quant_patterns_and_replacements.py
@@ -9,7 +9,6 @@

import torch
from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops

from executorch.exir.passes.replace_aten_with_edge_pass import (
aten_to_edge,
should_lower_to_edge,
@@ -487,6 +486,50 @@ def replacement(
)
return out

@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
def pattern_groupwise(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
group_size,
):
weight = (
torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
weight.dtype,
group_size,
weight_scales.dtype,
)
)
out = torch.ops.aten.embedding.default(weight, indices)
return out

def replacement_groupwise(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
group_size,
):
out = torch.ops.quantized_decomposed.embedding_byte.default(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
)
return out

@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
def pattern_with_padding_idx(
weight,
@@ -528,35 +571,86 @@ def replacement_with_padding_idx(
)
return out

@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
def pattern_with_dtype(
@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte")
def pattern_with_padding_idx_groupwise(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indicies,
dtype,
indices,
group_size,
padding_idx,
):
weight = torch.ops.quantized_decomposed.dequantize_per_channel.default(
weight = (
torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
weight.dtype,
group_size,
weight_scales.dtype,
)
)
out = torch.ops.aten.embedding.default(weight, indices, padding_idx)
return out

def replacement_with_padding_idx_groupwise(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
group_size,
_, # padding_idx only matters for training and not when running op for inference
):
out = torch.ops.quantized_decomposed.embedding_byte.default(
weight,
weight_scales,
weight_zero_points,
0,
weight_quant_min,
weight_quant_max,
torch.uint8,
indices,
)
out = torch.ops.aten.embedding.default(weight, indicies).to(dtype)
return out

def replacement_with_dtype(
@bind_pattern_to_op(quantized_decomposed_lib, "embedding_byte.dtype")
def pattern_with_dtype_groupwise(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indicies,
indices,
group_size,
dtype,
):
weight = (
torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
weight.dtype,
group_size,
dtype,
)
)
out = torch.ops.aten.embedding.default(weight, indices)
return out

def replacement_with_dtype_groupwise(
weight,
weight_scales,
weight_zero_points,
weight_quant_min,
weight_quant_max,
indices,
group_size,
dtype,
):
out = torch.ops.quantized_decomposed.embedding_byte.dtype(
@@ -565,7 +659,7 @@ def replacement_with_dtype(
weight_zero_points,
weight_quant_min,
weight_quant_max,
indicies,
indices,
dtype=dtype,
)
return out
@@ -576,14 +670,24 @@ def replacement_with_dtype(
_trace_and_lower_to_edge_ops(replacement),
[],
),
(
_trace_and_lower_to_edge_ops(pattern_groupwise),
_trace_and_lower_to_edge_ops(replacement_groupwise),
[],
),
(
_trace_and_lower_to_edge_ops(pattern_with_padding_idx),
_trace_and_lower_to_edge_ops(replacement_with_padding_idx),
[],
),
(
_trace_and_lower_to_edge_ops(pattern_with_dtype),
_trace_and_lower_to_edge_ops(replacement_with_dtype),
_trace_and_lower_to_edge_ops(pattern_with_padding_idx_groupwise),
_trace_and_lower_to_edge_ops(replacement_with_padding_idx_groupwise),
[],
),
(
_trace_and_lower_to_edge_ops(pattern_with_dtype_groupwise),
_trace_and_lower_to_edge_ops(replacement_with_dtype_groupwise),
[],
),
]
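Each (pattern, replacement) pair registered above is traced and handed to the graph rewriter so that a traced dequantize-then-embed subgraph collapses back into a single quantized embedding op. The sketch below illustrates the same matching idea with the generic torch.fx subgraph rewriter and a hypothetical fused-op stand-in; the actual pass lowers patterns to edge dialect via _trace_and_lower_to_edge_ops, so treat this as an analogy, not the ExecuTorch pass itself.

```python
import torch
import torch.fx as fx
from torch.fx.subgraph_rewriter import replace_pattern

# Hypothetical fused op; wrapped so FX keeps it as a single call_function node.
def fused_quantized_embedding(weight, scales, indices):
    return torch.nn.functional.embedding(indices, weight * scales)

fx.wrap("fused_quantized_embedding")

class DecomposedEmbedding(torch.nn.Module):
    # The "decomposed" form we want to match: scale the weight, then look it up.
    def forward(self, weight, scales, indices):
        dequant = weight * scales
        return torch.nn.functional.embedding(indices, dequant)

def pattern(weight, scales, indices):
    dequant = weight * scales
    return torch.nn.functional.embedding(indices, dequant)

def replacement(weight, scales, indices):
    return fused_quantized_embedding(weight, scales, indices)

gm = fx.symbolic_trace(DecomposedEmbedding())
matches = replace_pattern(gm, pattern, replacement)  # rewrites matching subgraphs in place
print(len(matches))  # 1
print(gm.graph)      # the graph now calls fused_quantized_embedding directly
```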
