Fix dequantize_per_channel for single dimension input tensor (#3867)
Summary:
Pull Request resolved: #3867

This diff has two fixes:
- When the input dimension is 1, we need to handle that separately (a simplified sketch of this path follows the op_dequantize.cpp diff below)
- Fixing the logic for how dim_list is generated (see the sketch right after this list)
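
As a minimal illustration of the dim_list fix: the array should hold every dimension index except `axis`, so indices at or past the axis must be shifted up by one (`i + 1`), not down (`i - 1`). The sketch below is self-contained and only mirrors the intent of the loop; the kTensorDimensionLimit value and the build_dim_list helper are placeholders for illustration, not the ExecuTorch API.

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Placeholder for ExecuTorch's kTensorDimensionLimit (value assumed for illustration).
constexpr size_t kTensorDimensionLimit = 16;

// Collect every dimension index of a rank-`rank` tensor except `axis`.
void build_dim_list(int64_t rank, int64_t axis, int64_t* dims) {
  for (int64_t i = 0; i < rank - 1; i++) {
    if (i < axis) {
      dims[i] = i;      // dimensions before the axis keep their index
    } else {
      dims[i] = i + 1;  // dimensions at/after the axis skip over it
    }
  }
}

int main() {
  int64_t dims[kTensorDimensionLimit];
  build_dim_list(/*rank=*/4, /*axis=*/1, dims);
  // Expected output for rank 4, axis 1: "0 2 3"
  for (int64_t i = 0; i < 3; i++) {
    printf("%lld ", static_cast<long long>(dims[i]));
  }
  printf("\n");
  return 0;
}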

Reviewed By: larryliu0820

Differential Revision: D58221156

fbshipit-source-id: aee6b66952271e2724c27cb4efd04bc4f54434b1
tarun292 authored and facebook-github-bot committed Jun 18, 2024
1 parent 7028a71 commit c6fb9da
Showing 2 changed files with 56 additions and 9 deletions.
38 changes: 32 additions & 6 deletions kernels/quantized/cpu/op_dequantize.cpp
@@ -196,8 +196,8 @@ Tensor& dequantize_per_channel_out(
"Failed to resize out Tensor in dequantize_per_channel_out");

ET_CHECK_MSG(
- scale.scalar_type() == ScalarType::Double,
- "scale.scalar_type() %" PRId8 " is not double type",
+ scale.scalar_type() == ScalarType::Float,
+ "scale.scalar_type() %" PRId8 " is not float type",
static_cast<int8_t>(scale.scalar_type()));

ET_CHECK_MSG(
@@ -224,15 +224,15 @@ Tensor& dequantize_per_channel_out(
input, quant_min, quant_max, dtype, out_dtype, out);

// a list contains all dimensions except axis
- int64_t dims[input.dim() - 1];
+ int64_t dims[kTensorDimensionLimit];
for (int64_t i = 0; i < input.dim() - 1; i++) {
if (i < axis) {
dims[i] = i;
} else {
- dims[i] = i - 1;
+ dims[i] = i + 1;
}
}
- const double* scale_data = scale.const_data_ptr<double>();
+ const float* scale_data = scale.const_data_ptr<float>();
const int64_t* zero_point_data;
if (opt_zero_points.has_value()) {
zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
@@ -253,8 +253,34 @@ Tensor& dequantize_per_channel_out(
// in other words you are dequantizing in_data[in_ix]
#define DEQUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype) \
case ScalarType::out_dtype: \
+ if (input.dim() == 1) { \
+ auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>(); \
+ const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>(); \
+ ET_CHECK_MSG( \
+ axis == 0, "Axis must be 0 for a single dimensional tensor"); \
+ const optional<int64_t> dim; \
+ apply_over_dim( \
+ [input_data_ptr, out_data_ptr, scale_data, zero_point_data]( \
+ size_t numel, size_t stride, size_t base_ix) { \
+ for (size_t i = 0; i < numel; i++) { \
+ size_t current_ix = base_ix * stride + i; \
+ float _scale = scale_data[current_ix]; \
+ int64_t zero_point = 0; \
+ if (zero_point_data != nullptr) { \
+ zero_point = zero_point_data[current_ix]; \
+ } \
+ out_data_ptr[current_ix] = \
+ static_cast<CTYPE_OUT>( \
+ input_data_ptr[current_ix] - zero_point) * \
+ _scale; \
+ } \
+ }, \
+ input, \
+ dim); \
+ break; \
+ } \
for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
- double _scale = scale_data[channel_ix]; \
+ float _scale = scale_data[channel_ix]; \
int64_t _zero_point = 0; \
if (zero_point_data != nullptr) { \
_zero_point = zero_point_data[channel_ix]; \
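
For readability, here is a minimal non-macro sketch of the single-dimension branch added above. The function name, the uint8_t input type, and the float output type are simplified placeholders, not the kernel's actual signature: for a rank-1 input, axis must be 0, so every element carries its own scale and zero-point entry and is dequantized as (input[i] - zero_point[i]) * scale[i].

#include <cstddef>
#include <cstdint>

// Simplified stand-in for the 1-D per-channel dequantize path.
// zero_points may be null, in which case a zero point of 0 is used.
void dequantize_per_channel_1d(
    const uint8_t* input,
    const float* scale,
    const int64_t* zero_points,
    float* out,
    size_t numel) {
  for (size_t i = 0; i < numel; i++) {
    const int64_t zero_point = (zero_points != nullptr) ? zero_points[i] : 0;
    out[i] = static_cast<float>(input[i] - zero_point) * scale[i];
  }
}

// For input {100, 100, 100}, scale {0.5, 0.75, 1}, and zero points {30, 50, 60},
// this yields {35, 37.5, 40}, matching the new test case added below.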
27 changes: 24 additions & 3 deletions kernels/quantized/test/op_dequantize_test.cpp
@@ -116,11 +116,11 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {

TEST(OpDequantizeOutTest, DequantizePerChannel) {
TensorFactory<ScalarType::Byte> tf_byte;
- TensorFactory<ScalarType::Double> tf_double;
+ TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_byte.full({3, 2}, 100);
- Tensor scale = tf_double.make({2}, {0.5, 1});
+ Tensor scale = tf_float.make({2}, {0.5, 1});
Tensor zero_point = tf_long.make({2}, {30, 60});
int64_t quant_min = 0;
int64_t quant_max = 255;
@@ -145,7 +145,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {

// Test with a different axis
out = tfo.zeros({3, 2});
- scale = tf_double.make({3}, {0.5, 0.75, 1});
+ scale = tf_float.make({3}, {0.5, 0.75, 1});
zero_point = tf_long.make({3}, {30, 50, 60});
// (100 - 30) * 0.5
// (100 - 50) * 0.75
@@ -163,4 +163,25 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
out);

EXPECT_TENSOR_EQ(out, expected);
+
+ // Test with a single dimension input tensor
+ out = tfo.zeros({3});
+ input = tf_byte.make({3}, {100, 100, 100});
+ scale = tf_float.make({3}, {0.5, 0.75, 1});
+ zero_point = tf_long.make({3}, {30, 50, 60});
+ // (100 - 30) * 0.5
+ // (100 - 50) * 0.75
+ // (100 - 60) * 1
+ expected = tfo.make({3}, {35, 37.5, 40});
+ dequantize_per_channel_out(
+ input,
+ scale,
+ zero_point,
+ /*axis=*/0,
+ quant_min,
+ quant_max,
+ ScalarType::Byte,
+ optional<ScalarType>(),
+ out);
+ EXPECT_TENSOR_EQ(out, expected);
}
