Add assertion for input scale
Co-authored-by: Toshiki Maekawa <toshiki.maekawa-aoha-renesas@aoha.co.jp>
maekawatoshiki and Toshiki Maekawa committed May 12, 2023
1 parent fda6052 commit 6541465
Showing 1 changed file with 14 additions and 10 deletions.
src/relay/qnn/op/softmax.cc
@@ -86,23 +86,28 @@ Expr QnnSoftmaxCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   // Expected: input, scale, zero_point, output_scale, output_zero_point
   ICHECK_EQ(new_args.size(), 5);
 
-  const Expr input_scale = new_args[1];
+  const auto const_i32 = [&](int32_t val) { return MakeConstantScalar(DataType::Int(32), val); };
+  const auto const_f32 = [&](float val) { return MakeConstantScalar(DataType::Float(32), val); };
+
+  const auto const_input_scale = new_args[1].as<ConstantNode>();
+  ICHECK(const_input_scale) << "Input scale should be constant.";
+  ICHECK(const_input_scale->is_scalar()) << "Input scale should be scalar.";
+  const float input_scale = static_cast<float*>(const_input_scale->data->data)[0];
+  ICHECK(input_scale <= 1.f) << "Input scale should be less than or equal to 1.";
+
   const Expr input_zero_point = new_args[2];
   const Expr output_scale = new_args[3];
   const Expr output_zero_point = new_args[4];
   const int axis = attrs.as<SoftmaxAttrs>()->axis;
 
   // Refer to the Algorithm 1 in https://arxiv.org/pdf/2207.01405.pdf
 
-  const Expr quantized_data =
-      Subtract(Cast(new_args[0], DataType::Int(32)), Cast(input_zero_point, DataType::Int(32)));
+  const Expr quantized_data = Subtract(Cast(new_args[0], DataType::Int(32)), input_zero_point);
 
-  const Expr x_0 = ConvertDtype(
-      Round(Divide(MakeConstantScalar(DataType::Float(32), 1.f), input_scale)), DataType::Int(32));
+  const Expr x_0 = ConvertDtype(const_f32(std::round(1.f / input_scale)), DataType::Int(32));
   const Expr max = Max(quantized_data, {axis}, true, false);
   const Expr x = Subtract(quantized_data, max);
 
-  const auto const_i32 = [&](int32_t val) { return MakeConstantScalar(DataType::Int(32), val); };
   const int n = 8;
   const int m = 30;
   const int bits = 8;
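
Why the new assertion requires input_scale <= 1: x_0 = round(1/S) is the fixed-point stand-in for the value 1.0 at input scale S, and, per Algorithm 1 of arXiv:2207.01405, it serves as the divisor when the exponent is split into integer and fractional parts (that arithmetic sits in the collapsed lines of this file). If S exceeded 1, std::round(1.f / S) could round to 0 and the later integer division would be undefined; the check rules that out. A minimal standalone sketch of the failure mode, with illustrative scale values, not part of the patch:

#include <cmath>
#include <cstdio>

int main() {
  const float scales[] = {0.05f, 1.0f, 4.0f};  // illustrative input scales
  for (float scale : scales) {
    // Mirrors the patched line: x_0 = round(1/S), cast to an integer.
    const int x_0 = static_cast<int>(std::round(1.f / scale));
    std::printf("scale = %.2f -> x_0 = %d%s\n", scale, x_0,
                x_0 >= 1 ? "" : "  <- downstream division by zero");
  }
  return 0;
}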
@@ -114,10 +119,9 @@ Expr QnnSoftmaxCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
   const Expr sums = Sum(exps, {axis}, true, false);
   const Expr output =
       RightShift(Multiply(Divide(const_i32(1 << m), sums), exps), const_i32(m - (bits - 1)));
-  const Expr requantized =
-      Requantize(output, arg_types[0].as<TensorTypeNode>()->shape,
-                 MakeConstantScalar(DataType::Float(32), 1.f / (1 << (bits - 1))), const_i32(0),
-                 output_scale, output_zero_point, DataType::Int(bits), 0);
+  const Expr requantized = Requantize(output, arg_types[0].as<TensorTypeNode>()->shape,
+                                      const_f32(1.f / (1 << (bits - 1))), const_i32(0),
+                                      output_scale, output_zero_point, DataType::Int(bits), 0);
 
   return requantized;
 }
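
The normalization in the second hunk, in plain integers: with m = 30 and bits = 8, out_i = ((2^m / sum) * exp_i) >> (m - (bits - 1)) expresses the softmax probabilities with scale 1 / 2^(bits-1) = 1/128, which is exactly the fixed input scale handed to Requantize. A toy standalone sketch with made-up integer exp values, not part of the patch:

#include <cstdint>
#include <cstdio>

int main() {
  const int m = 30;
  const int bits = 8;
  const int64_t exps[] = {16, 64, 48};  // toy integer exp() outputs
  int64_t sum = 0;
  for (int64_t e : exps) sum += e;  // 128 here
  for (int64_t e : exps) {
    // Divide first, then multiply, then shift -- same order as the Relay ops.
    const int64_t out = ((1LL << m) / sum * e) >> (m - (bits - 1));
    std::printf("%2lld/128 = %.3f\n", static_cast<long long>(out), out / 128.0);
  }
  return 0;  // prints 0.125, 0.500, 0.375 -- a normalized distribution
}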
