Support torch.int32 as a dtype for quantize and dequantize (#289)
Summary:
Pull Request resolved: #289

Ops like `quantized_decomposed.quantize_per_tensor.default` did not support
an int32 quantized type. Add support for it to the portable and ATen runtimes.

This is important for Turing, which uses int32 to represent uint16 (since uint16
is not a valid PyTorch dtype).

Reviewed By: kimishpatel

Differential Revision: D49202048

fbshipit-source-id: 0faa89ce1d34b60ece443fb02fa14f02abf2d376
dulinriley authored and facebook-github-bot committed Sep 13, 2023
1 parent fbbec00 commit 63bb3b5
Showing 2 changed files with 9 additions and 1 deletion.
4 changes: 3 additions & 1 deletion kernels/quantized/cpu/op_dequantize.cpp
@@ -37,7 +37,9 @@ void check_dequantize_per_tensor_args(
     Tensor& out) {
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
-          input.scalar_type() == ScalarType::Char,
+          input.scalar_type() == ScalarType::Char ||
+          input.scalar_type() == ScalarType::Short ||
+          input.scalar_type() == ScalarType::Int,
       "input.scalar_type() %hdd is not supported:",
       input.scalar_type());

6 changes: 6 additions & 0 deletions kernels/quantized/cpu/op_quantize.cpp
@@ -56,6 +56,12 @@ void check_quantize_per_tensor_args(
         static_cast<int32_t>(std::numeric_limits<int8_t>::min());
     quant_max_upper_bound =
         static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+  } else if (dtype == ScalarType::Short) {
+    quant_min_lower_bound = std::numeric_limits<int16_t>::min();
+    quant_max_upper_bound = std::numeric_limits<int16_t>::max();
+  } else if (dtype == ScalarType::Int) {
+    quant_min_lower_bound = std::numeric_limits<int32_t>::min();
+    quant_max_upper_bound = std::numeric_limits<int32_t>::max();
   } else {
     ET_CHECK_MSG(false, "Unsupported dtype: %hdd", out_dtype);
   }
