Skip to content

Commit

Permalink
Upsample support NHWC (#10824)
Browse files Browse the repository at this point in the history
This patch implements bilinear interpolation for Upsample/Resize on 4-D input whose
outermost and innermost scales (usually the batch and channel dimensions of NHWC) are 1.
The computation is parallelized over output_height * output_width instead of over a single dimension.

In addition, this patch reverts HandleResize to the original implementation so that the
TransposeOptimizerTests.TestResize* tests pass.

Finally, I add microbenchmark BM_NhwcUpsampleBilinear.
  • Loading branch information
yihonglyu authored Apr 11, 2022
1 parent 269be2f commit 749c0dd
Show file tree
Hide file tree
Showing 8 changed files with 1,116 additions and 349 deletions.
1 change: 1 addition & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
${BENCHMARK_DIR}/main.cc
${BENCHMARK_DIR}/modeltest.cc
${BENCHMARK_DIR}/pooling.cc
${BENCHMARK_DIR}/resize.cc
${BENCHMARK_DIR}/batchnorm.cc
${BENCHMARK_DIR}/batchnorm2.cc
${BENCHMARK_DIR}/tptest.cc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -967,41 +967,35 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con
node.SetInput(i, gather_output);
}

// static bool HandleResize(HandlerArgs& args) {
// auto inputs = args.node.Inputs();
// int64_t rank_int = gsl::narrow_cast<int64_t>(args.perm.size());
//
// auto p = ChannelFirstToLastPerm(rank_int);
// auto& perm = p == args.perm ? args.perm : args.perm_inv;
// auto& perm_inv = p == args.perm ? args.perm_inv : args.perm;
//
// if (args.ctx.opset < 11) {
// PermuteInput(args.ctx.graph, args.node, 1, perm);
// } else {
// if (inputs[1] != "") {
// std::vector<int64_t> double_perm_inv = perm;
// double_perm_inv.reserve(2 * args.perm.size());
// for (int64_t p1 : perm) {
// double_perm_inv.push_back(p1 + rank_int);
// }
// PermuteInput(args.ctx.graph, args.node, 1, double_perm_inv);
// }
// for (size_t i = 2; i < inputs.size(); ++i) {
// if (inputs[i] != "") {
// PermuteInput(args.ctx.graph, args.node, i, perm);
// }
// }
// }
//
// TransposeFirstInput(args.ctx, args.node, perm);
// TransposeOutputs(args.ctx, args.node, perm_inv);
//
// SwapNodeOpTypeAndDomain(args.ctx.graph, args.node, args.node.OpType(), "com.microsoft.nhwc");
//
// return true;
// }
// Rewrites a Resize node to operate on transposed data. The per-axis
// auxiliary inputs (scales/sizes, and roi for opset 11+) are permuted to
// match, the data input is transposed with perm_inv, and the outputs are
// transposed back with perm.
static bool HandleResize(HandlerArgs& args) {
  const auto node_inputs = args.node.Inputs();
  const int64_t rank = gsl::narrow_cast<int64_t>(args.perm.size());

  if (args.ctx.opset >= 11) {
    // Input 1 carries 2*rank entries (presumably roi start/end pairs per the
    // ONNX Resize spec), so the inverse permutation is applied to each half.
    if (node_inputs[1] != "") {
      std::vector<int64_t> paired_perm_inv;
      paired_perm_inv.reserve(2 * args.perm_inv.size());
      for (int64_t axis : args.perm_inv) {
        paired_perm_inv.push_back(axis);
      }
      for (int64_t axis : args.perm_inv) {
        paired_perm_inv.push_back(axis + rank);
      }
      PermuteInput(args.ctx.graph, args.node, 1, paired_perm_inv);
    }
    // Remaining optional inputs have one entry per axis.
    for (size_t idx = 2; idx < node_inputs.size(); ++idx) {
      if (node_inputs[idx] != "") {
        PermuteInput(args.ctx.graph, args.node, idx, args.perm_inv);
      }
    }
  } else {
    // Opset < 11: input 1 has one entry per axis.
    PermuteInput(args.ctx.graph, args.node, 1, args.perm_inv);
  }

  TransposeFirstInput(args.ctx, args.node, args.perm_inv);
  TransposeOutputs(args.ctx, args.node, args.perm);

  return true;
}

// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize};
// Handler registration for Resize: transpose the first (data) input, then
// let HandleResize fix up the remaining per-axis inputs and the outputs.
constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize};

static bool HandlePad(HandlerArgs& args) {
size_t rank = args.perm.size();
Expand Down Expand Up @@ -1697,9 +1691,7 @@ static const std::unordered_map<std::string_view, const HandlerInfo&> handler_ma
{"Split", split_handler},
{"Shape", shape_handler},
{"Pad", pad_handler},
// TODO: re-enable the resize handler after adding NHWC support in the upsample op on CPU
// https://github.com/microsoft/onnxruntime/issues/9857
// {"Resize", resize_handler},
{"Resize", resize_handler},
{"ReduceSum", reduce_sum_handler},

{"ReduceLogSum", reduce_op_handler},
Expand Down
189 changes: 88 additions & 101 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,39 +397,24 @@ static Status UpsampleLinear(const T* input,
}
*/

// Precomputed per-row / per-column interpolation data shared by the bilinear
// upsampling kernels. All raw pointers below point into the single buffer
// owned by idx_scale_data_buffer_holder (presumably carved out by
// SetupUpsampleBilinear — confirm against the setup code).
struct BilinearParams {
  // Input-space coordinate of each output column (x) and row (y); used to
  // detect out-of-range coordinates when extrapolation is enabled.
  std::vector<float> x_original;
  std::vector<float> y_original;

  // Owns the allocation backing the index/weight arrays below.
  BufferUniquePtr idx_scale_data_buffer_holder;

  // For each output row: offsets of the two contributing input rows,
  // pre-multiplied by the input width (i.e. input_width * in_y1/in_y2).
  int64_t* input_width_mul_y1;
  int64_t* input_width_mul_y2;

  // For each output column: indices of the two contributing input columns.
  int64_t* in_x1;
  int64_t* in_x2;

  // Horizontal interpolation weights per output column.
  float* dx1;
  float* dx2;

  // Vertical interpolation weights per output row.
  float* dy1;
  float* dy2;
};

// The following method supports a 4-D input in 'Linear mode'
// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
// the scale values for the outermost 2 dimensions are 1.
// 1. the scale values for the outermost 2 dimensions are 1 or
// 2. the scale values for the outermost and innermost dimensions are 1
// This is the common use-case where the 4-D input (batched multi-channel images)
// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
static BilinearParams SetupUpsampleBilinear(int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width,
float height_scale,
float width_scale,
const std::vector<float>& roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate) {
// is usually of shapes:
// - [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
// - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0]
BilinearParams SetupUpsampleBilinear(const int64_t input_height,
const int64_t input_width,
const int64_t output_height,
const int64_t output_width,
const float height_scale,
const float width_scale,
const std::vector<float>& roi,
AllocatorPtr& alloc,
const GetOriginalCoordinateFunc& get_original_coordinate,
bool is_nchw) {
BilinearParams p;

p.x_original.reserve(output_width);
Expand Down Expand Up @@ -471,8 +456,9 @@ static BilinearParams SetupUpsampleBilinear(int64_t input_height,
p.dx2 = p.dx1 + output_width;

// Start processing
auto roi_y_start = roi.size() / 2 - 2;
auto roi_y_end = roi.size() - 2;
const size_t height_rindex = is_nchw ? 1 : 2;
auto roi_y_start = roi.size() / 2 - (height_rindex + 1);
auto roi_y_end = roi.size() - (height_rindex + 1);
for (int64_t y = 0; y < output_height; ++y) {
float in_y = height_scale == 1 ? static_cast<float>(y)
: get_original_coordinate(static_cast<float>(y), height_scale,
Expand All @@ -496,8 +482,9 @@ static BilinearParams SetupUpsampleBilinear(int64_t input_height,
p.input_width_mul_y2[y] = input_width * in_y2;
}

auto roi_x_start = roi.size() / 2 - 1;
auto roi_x_end = roi.size() - 1;
const size_t width_rindex = is_nchw ? 0 : 1;
auto roi_x_start = roi.size() / 2 - (width_rindex + 1);
auto roi_x_end = roi.size() - (width_rindex + 1);
for (int64_t x = 0; x < output_width; ++x) {
float in_x = width_scale == 1 ? static_cast<float>(x)
: get_original_coordinate(static_cast<float>(x),
Expand All @@ -522,59 +509,6 @@ static BilinearParams SetupUpsampleBilinear(int64_t input_height,
return p;
}

// Bilinear upsampling/resizing of a 4-D [N, C, H, W] tensor (or a single 2-D
// plane when batch_size == num_channels == 1). Each (batch, channel) plane is
// resized independently; work is parallelized across channels on `tp`.
// When use_extrapolation is set, outputs whose original coordinate falls
// outside the input extent are set to extrapolation_value.
template <typename T>
void UpsampleBilinear(int64_t batch_size,
                      int64_t num_channels,
                      int64_t input_height,
                      int64_t input_width,
                      int64_t output_height,
                      int64_t output_width,
                      float height_scale,
                      float width_scale,
                      const std::vector<float>& roi,
                      bool use_extrapolation,
                      float extrapolation_value,
                      const T* XdataBase,
                      T* YdataBase,
                      AllocatorPtr& alloc,
                      const GetOriginalCoordinateFunc& get_original_coordinate,
                      concurrency::ThreadPool* tp) {
  // Precompute per-row/per-column source indices and interpolation weights.
  BilinearParams p = SetupUpsampleBilinear(input_height, input_width, output_height, output_width,
                                           height_scale, width_scale, roi,
                                           alloc, get_original_coordinate);

  for (int64_t n = 0; n < batch_size; ++n) {
    concurrency::ThreadPool::TrySimpleParallelFor(
        tp, num_channels,
        [&](std::ptrdiff_t c) {
          const T* src_plane = XdataBase + (n * num_channels + c) * (input_height * input_width);
          T* dst_plane = YdataBase + (n * num_channels + c) * (output_height * output_width);
          for (int64_t y = 0; y < output_height; ++y) {
            T* dst_row = dst_plane + output_width * y;
            const bool y_outside = p.y_original[y] < 0 ||
                                   p.y_original[y] > static_cast<float>(input_height - 1);
            for (int64_t x = 0; x < output_width; ++x) {
              // If the original x or y coordinate is outside the input range
              // and extrapolation is requested, emit the extrapolation value.
              if (use_extrapolation &&
                  (y_outside ||
                   (p.x_original[x] < 0 || p.x_original[x] > static_cast<float>(input_width - 1)))) {
                dst_row[x] = static_cast<T>(extrapolation_value);
                continue;
              }

              // The four neighboring input samples surrounding (y, x).
              T top_left = src_plane[p.input_width_mul_y1[y] + p.in_x1[x]];
              T top_right = src_plane[p.input_width_mul_y1[y] + p.in_x2[x]];
              T bottom_left = src_plane[p.input_width_mul_y2[y] + p.in_x1[x]];
              T bottom_right = src_plane[p.input_width_mul_y2[y] + p.in_x2[x]];

              // Weighted blend; term order kept stable for float determinism.
              dst_row[x] = static_cast<T>(p.dx2[x] * p.dy2[y] * top_left +
                                          p.dx1[x] * p.dy2[y] * top_right +
                                          p.dx2[x] * p.dy1[y] * bottom_left +
                                          p.dx1[x] * p.dy1[y] * bottom_right);
            }
          }
        });
  }
}

struct TrilinearParams {
std::vector<float> x_original;
std::vector<float> y_original;
Expand Down Expand Up @@ -1065,25 +999,78 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context,
case UpsampleMode::LINEAR: {
// Supports 'bilinear' and 'trilinear' sampling only

//'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1
//'bilinear' == 2-D input or 4-D input with outermost 2 scales as 1 or
// 4-D input with outermost and innermost scales as 1
if (dims.size() == 2 || dims.size() == 4) {
bool is_2D = dims.size() == 2;

const int64_t batch_size = is_2D ? 1 : dims[0];
const int64_t num_channels = is_2D ? 1 : dims[1];
const int64_t input_height = is_2D ? dims[0] : dims[2];
const int64_t input_width = is_2D ? dims[1] : dims[3];

const int64_t output_height = is_2D ? output_dims[0] : output_dims[2];
const int64_t output_width = is_2D ? output_dims[1] : output_dims[3];
bool is_nchw = true;

int64_t batch_size;
int64_t num_channels;
int64_t input_height;
int64_t input_width;

int64_t output_height;
int64_t output_width;

float height_scale;
float width_scale;

if (is_2D) {
batch_size = 1;
num_channels = 1;
input_height = dims[0];
input_width = dims[1];

output_height = output_dims[0];
output_width = output_dims[1];

height_scale = scales[0];
width_scale = scales[1];
} else {
if (scales[1] == 1.0f) {
batch_size = dims[0];
num_channels = dims[1];
input_height = dims[2];
input_width = dims[3];

output_height = output_dims[2];
output_width = output_dims[3];

height_scale = scales[2];
width_scale = scales[3];
} else {
ORT_ENFORCE(scales[3] == 1.0f, "4-D input with innermost scale (usually channel of NHWC) as 1.");
is_nchw = false;

batch_size = dims[0];
num_channels = dims[3];
input_height = dims[1];
input_width = dims[2];

output_height = output_dims[1];
output_width = output_dims[2];

height_scale = scales[1];
width_scale = scales[2];
}
}

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
UpsampleBilinear(batch_size, num_channels, input_height, input_width, output_height, output_width,
is_2D ? scales[0] : scales[2], is_2D ? scales[1] : scales[3], roi,
use_extrapolation_, extrapolation_value_, X->Data<T>(),
Y->MutableData<T>(), alloc, get_original_coordinate_,
output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr);
if (is_nchw) {
UpsampleBilinear(batch_size, num_channels, input_height, input_width, output_height, output_width,
height_scale, width_scale, roi,
use_extrapolation_, extrapolation_value_, X->Data<T>(),
Y->MutableData<T>(), alloc, get_original_coordinate_,
output_height * output_width > 64 ? context->GetOperatorThreadPool() : nullptr);
} else {
NhwcUpsampleBilinear(batch_size, num_channels, input_height, input_width, output_height, output_width,
height_scale, width_scale, roi,
use_extrapolation_, extrapolation_value_, X->Data<T>(),
Y->MutableData<T>(), alloc, get_original_coordinate_,
output_height * output_width * num_channels > 64 ? context->GetOperatorThreadPool() : nullptr);
}
return Status::OK();
} else if (dims.size() == 3 || dims.size() == 5) {
//'trilinear' == 3-D input or 5-D input with outermost 2 scales as 1
Expand Down
Loading

0 comments on commit 749c0dd

Please sign in to comment.