Skip to content

Commit

Permalink
add default parameter value to runtime::reference::fake_quantize
Browse files Browse the repository at this point in the history
  • Loading branch information
pelszkow committed Jul 13, 2021
1 parent 2e9c6fd commit 9d2c00d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,106 +195,72 @@ namespace ngraph
}

} // namespace fake_quantize_details
namespace v1

template <typename T>
void
fake_quantize(const T* const arg,
const T* const in_low,
const T* const in_high,
const T* const out_low,
const T* const out_high,
T* const out,
const Shape& arg_shape,
const Shape& in_low_shape,
const Shape& in_high_shape,
const Shape& out_low_shape,
const Shape& out_high_shape,
size_t levels,
const op::AutoBroadcastSpec& broadcast = op::AutoBroadcastType::NUMPY)
{
template <typename T>
void fake_quantize(const T* const arg,
const T* const in_low,
const T* const in_high,
const T* const out_low,
const T* const out_high,
T* const out,
const Shape& arg_shape,
const Shape& in_low_shape,
const Shape& in_high_shape,
const Shape& out_low_shape,
const Shape& out_high_shape,
size_t levels,
const op::AutoBroadcastSpec& broadcast)
{
using namespace fake_quantize_details;
using namespace fake_quantize_details;

const FeRound round_mode(FE_TONEAREST);
const FeRound round_mode(FE_TONEAREST);

if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
{
const size_t arg_size = shape_size(arg_shape);
const auto q = [=](const T& a) {
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
};
for (size_t i = 0; i < arg_size; ++i)
{
const size_t arg_size = shape_size(arg_shape);
const auto q = [=](const T& a) {
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
};
for (size_t i = 0; i < arg_size; ++i)
{
out[i] = q(arg[i]);
}
out[i] = q(arg[i]);
}
else
{
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
in_high_shape.size() <= arg_shape.size() &&
out_low_shape.size() <= arg_shape.size() &&
out_high_shape.size() <= arg_shape.size(),
"Tensors with inout\\output ranges should have rank less or "
"equal to data tensor rank equal to ",
arg_shape.size());
}
else
{
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
in_high_shape.size() <= arg_shape.size() &&
out_low_shape.size() <= arg_shape.size() &&
out_high_shape.size() <= arg_shape.size(),
"Tensors with inout\\output ranges should have rank less or "
"equal to data tensor rank equal to ",
arg_shape.size());

const QuantizationBound<T> in_low_bound(
in_low, in_low_shape, arg_shape, broadcast);
const QuantizationBound<T> in_high_bound(
in_high, in_high_shape, arg_shape, broadcast);
const QuantizationBound<T> out_low_bound(
out_low, out_low_shape, arg_shape, broadcast);
const QuantizationBound<T> out_high_bound(
out_high, out_high_shape, arg_shape, broadcast);
const QuantizationBound<T> in_low_bound(
in_low, in_low_shape, arg_shape, broadcast);
const QuantizationBound<T> in_high_bound(
in_high, in_high_shape, arg_shape, broadcast);
const QuantizationBound<T> out_low_bound(
out_low, out_low_shape, arg_shape, broadcast);
const QuantizationBound<T> out_high_bound(
out_high, out_high_shape, arg_shape, broadcast);

std::vector<size_t> current_dim(arg_shape.size(), 0);
const auto arg_shape_size = shape_size(arg_shape);
for (size_t index = 0; index < arg_shape_size; ++index)
{
const T in_low_val = in_low_bound.get_value(current_dim, index);
const T in_high_val = in_high_bound.get_value(current_dim, index);
const T out_low_val = out_low_bound.get_value(current_dim, index);
const T out_high_val = out_high_bound.get_value(current_dim, index);
std::vector<size_t> current_dim(arg_shape.size(), 0);
const auto arg_shape_size = shape_size(arg_shape);
for (size_t index = 0; index < arg_shape_size; ++index)
{
const T in_low_val = in_low_bound.get_value(current_dim, index);
const T in_high_val = in_high_bound.get_value(current_dim, index);
const T out_low_val = out_low_bound.get_value(current_dim, index);
const T out_high_val = out_high_bound.get_value(current_dim, index);

out[index] = quantize(arg[index],
in_low_val,
in_high_val,
out_low_val,
out_high_val,
levels);
increment_current_dim(current_dim, arg_shape);
}
out[index] = quantize(
arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels);
increment_current_dim(current_dim, arg_shape);
}
}
} // namespace v1

// Overload of fake_quantize for callers that do not pass a broadcast mode.
// Every argument is handed to v1::fake_quantize unchanged and in the same
// order; the trailing broadcast argument is always
// op::AutoBroadcastType::NUMPY. Nothing else is computed here.
template <typename T>
void fake_quantize(const T* const arg,
                   const T* const in_low,
                   const T* const in_high,
                   const T* const out_low,
                   const T* const out_high,
                   T* const out,
                   const Shape& arg_shape,
                   const Shape& in_low_shape,
                   const Shape& in_high_shape,
                   const Shape& out_low_shape,
                   const Shape& out_high_shape,
                   size_t levels)
{
    // Broadcast mode is fixed for this overload.
    const auto numpy_mode = op::AutoBroadcastType::NUMPY;

    // Data pointers first, then the output buffer, then the shapes, then
    // levels and the broadcast mode -- same positional order as the callee.
    v1::fake_quantize(arg, in_low, in_high, out_low, out_high,
                      out,
                      arg_shape, in_low_shape, in_high_shape,
                      out_low_shape, out_high_shape,
                      levels,
                      numpy_mode);
}
} // namespace reference
} // namespace runtime
Expand Down
28 changes: 14 additions & 14 deletions ngraph/test/runtime/interpreter/evaluates_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,7 @@ namespace
info.selected_outputs_shape,
selected_indices.data(),
info.selected_indices_shape,
valid_outputs.data());
valid_outputs.data());

void* pscores = nullptr;
void* pselected_num = nullptr;
Expand Down Expand Up @@ -2383,19 +2383,19 @@ namespace
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::v1::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
inputs[1]->get_data_ptr<const T>(),
inputs[2]->get_data_ptr<const T>(),
inputs[3]->get_data_ptr<const T>(),
inputs[4]->get_data_ptr<const T>(),
outputs[0]->get_data_ptr<T>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_input_shape(2),
op->get_input_shape(3),
op->get_input_shape(4),
op->get_levels(),
op->get_auto_broadcast());
runtime::reference::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
inputs[1]->get_data_ptr<const T>(),
inputs[2]->get_data_ptr<const T>(),
inputs[3]->get_data_ptr<const T>(),
inputs[4]->get_data_ptr<const T>(),
outputs[0]->get_data_ptr<T>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_input_shape(2),
op->get_input_shape(3),
op->get_input_shape(4),
op->get_levels(),
op->get_auto_broadcast());
return true;
}

Expand Down

0 comments on commit 9d2c00d

Please sign in to comment.