diff --git a/lite/kernels/arm/group_norm_compute.cc b/lite/kernels/arm/group_norm_compute.cc
index 961361079b4..044db291815 100644
--- a/lite/kernels/arm/group_norm_compute.cc
+++ b/lite/kernels/arm/group_norm_compute.cc
@@ -35,15 +35,23 @@ void GroupNormCompute::Run() {
   float epsilon = param.epsilon;
   int groups = param.groups;
   int channels = param.channels;
-  int n = param.x->dims()[0];
-  int c = param.x->dims()[1];
+  auto x_dims = param.x->dims();
+  int n = x_dims[0];
+  int c = x_dims[1];
+  if (channels == -1) {
+    CHECK_EQ(param.data_layout_str, "NCHW")
+        << "it only supports NCHW layout, but the received layout is "
+        << param.data_layout_str;
+    channels = c;
+  }
+  int height = x_dims[2];
+  int width = x_dims[3];
   int ch_per_group = channels / groups;
-  int height = param.x->dims()[2];
-  int width = param.x->dims()[3];
   int spatial_size = ch_per_group * height * width;
   int ngroup = n * groups;
   int cnt = spatial_size >> 4;
   int remain = spatial_size % 16;
+  float* std_vec = new float[param.saved_variance->numel()];
   // compute saved_mean and saved_variance
 #pragma omp parallel for
   for (int n = 0; n < ngroup; ++n) {
@@ -103,7 +111,8 @@ void GroupNormCompute::Run() {
     float variance = (summ - mean * mean * spatial_size) / spatial_size;
     float std = 1.f / sqrtf(variance + epsilon);
     saved_mean[n] = mean;
-    saved_variance[n] = std;
+    saved_variance[n] = variance;
+    std_vec[n] = std;
   }
   int in_size = height * width;
   cnt = in_size >> 4;
@@ -117,7 +126,7 @@ void GroupNormCompute::Run() {
     numc *= ch_per_group;
     for (int c = 0; c < ch_per_group; c++) {
       int chin = numc + c;
-      const float sstd_val = scale[chin] * saved_variance[i];
+      const float sstd_val = scale[chin] * std_vec[i];
       const float bias_val = bias[chin];
       const float mean_val = saved_mean[i];
       const float32x4_t vsstd = vdupq_n_f32(sstd_val);
@@ -158,6 +167,7 @@ void GroupNormCompute::Run() {
       }
     }
   }
+  delete[] std_vec;
 }

 }  // namespace arm
diff --git a/lite/operators/group_norm_op.cc b/lite/operators/group_norm_op.cc
index 458fc5a837f..b1a5b944f7e 100644
--- a/lite/operators/group_norm_op.cc
+++ b/lite/operators/group_norm_op.cc
@@ -34,27 +34,35 @@ bool GroupNormOp::CheckShape() const {
   auto scale_dims = param_.scale->dims();
   auto bias_dims = param_.bias->dims();
   if (param_.channels == -1) {
-    param_.channels = x_dims[1];
+    param_.channels = (param_.data_layout_str == "NCHW")
+                          ? x_dims[1]
+                          : x_dims[x_dims.size() - 1];
   }
+  // only support NCHW
+  CHECK_EQ(param_.data_layout_str, "NCHW") << "data_layout must be NCHW";
   CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
       << "Input X must have 2 to 5 dimensions.";
   CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimension.";
   CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimension.";
   CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
   CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
-  CHECK_EQ(param_.channels, x_dims[1])
-      << "Input channels must be equal input_shape[1]";
-  CHECK_EQ(param_.channels % param_.groups, 0)
-      << "channels must be divide groups";
+  CHECK_LE(param_.groups, param_.channels)
+      << "groups should not be greater than channels";
+  CHECK_GE(param_.groups, 1) << "groups should be no less than 1";
+  CHECK_EQ(param_.channels, scale_dims[0])
+      << "The Input(Scale)'s first dimension size of Op(group_norm) must be "
+         "equal to the number of channels";
+  CHECK_EQ(param_.channels, bias_dims[0])
+      << "The Input(Bias)'s first dimension size of Op(group_norm) must be "
+         "equal to the number of channels";
   return true;
 }

 bool GroupNormOp::InferShapeImpl() const {
   auto x_dims = param_.x->dims();
   int64_t batch_size = x_dims[0];
-  int64_t num = param_.channels / param_.groups;
-  param_.saved_mean->Resize({batch_size * num});
-  param_.saved_variance->Resize({batch_size * num});
+  param_.saved_mean->Resize({batch_size, param_.groups});
+  param_.saved_variance->Resize({batch_size, param_.groups});
   param_.out->Resize(x_dims);
   return true;
 }
@@ -82,6 +90,9 @@ bool GroupNormOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   }
   param_.out =
       scope->FindVar(op_desc.Output("Y").front())->GetMutable<lite::Tensor>();
+  if (op_desc.HasAttr("data_layout")) {
+    param_.data_layout_str = op_desc.GetAttr<std::string>("data_layout");
+  }
   param_.epsilon = op_desc.GetAttr<float>("epsilon");
   param_.groups = op_desc.GetAttr<int>("groups");
   if (op_desc.HasAttr("channels")) {
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 96f76015913..167d4fd65d1 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1797,6 +1797,7 @@ struct GroupNormParam : ParamBase {
   lite::Tensor* scale{};
   lite::Tensor* saved_mean{};
   lite::Tensor* saved_variance{};
+  std::string data_layout_str{"NCHW"};
   float epsilon;
   int groups;
   int channels;
diff --git a/lite/tests/kernels/group_norm_compute_test.cc b/lite/tests/kernels/group_norm_compute_test.cc
index 43b1fefe1d3..b5db9aa60b3 100644
--- a/lite/tests/kernels/group_norm_compute_test.cc
+++ b/lite/tests/kernels/group_norm_compute_test.cc
@@ -34,7 +34,7 @@ class GroupNormComputeTest : public arena::TestCase {
   DDim dims_{{4, 5, 19, 19}};
   float epsilon_ = 1e-5f;
   int groups_ = 1;
-  int channels_ = dims_[1];
+  std::string data_layout_str_ = "NCHW";

  public:
   GroupNormComputeTest(const Place& place,
@@ -42,24 +42,24 @@ class GroupNormComputeTest : public arena::TestCase {
                        DDim dims,
                        float epsilon,
                        int groups,
-                       int channels)
+                       std::string data_layout_str)
       : TestCase(place, alias),
         dims_(dims),
         epsilon_(epsilon),
         groups_(groups),
-        channels_(channels) {}
+        data_layout_str_(data_layout_str) {}

   void RunBaseline(Scope* scope) override {
     auto x = scope->FindTensor(x_);
     auto scale = scope->FindTensor(scale_);
     auto bias = scope->FindTensor(bias_);
     auto y = scope->FindMutableTensor(output_);
     auto saved_mean = scope->FindMutableTensor(saved_mean_);
     auto saved_variance = scope->FindMutableTensor(saved_variance_);
     CHECK(y);
     CHECK(saved_mean);
     CHECK(saved_variance);
-    DDim saved_dim({dims_[0] * groups_});
+    DDim saved_dim({dims_[0], groups_});
     y->Resize(dims_);
     saved_mean->Resize(saved_dim);
     saved_variance->Resize(saved_dim);
@@ -68,49 +68,82 @@ class GroupNormComputeTest : public arena::TestCase {
     auto scale_data = scale->data<float>();
     auto bias_data = bias->data<float>();
     auto y_data = y->mutable_data<float>();
-    auto saved_mean_data = saved_mean->mutable_data<float>();
-    auto saved_variance_data = saved_variance->mutable_data<float>();
-
-    int n = x->dims()[0];
-    int ch_per_group = channels_ / groups_;
-    CHECK_EQ(x->dims()[1], channels_);
-    int spatial_size = ch_per_group * x->dims()[2] * x->dims()[3];
-    // compute mean
-    for (int i = 0; i < n * groups_; ++i) {
-      const float* x_ptr = x_data + i * spatial_size;
-      float sum = 0.f;
-      for (int j = 0; j < spatial_size; ++j) {
-        sum += x_ptr[j];
-      }
-      saved_mean_data[i] = sum / spatial_size;
-    }
-    // compute variance
-    for (int i = 0; i < n * groups_; ++i) {
-      const float* x_ptr = x_data + i * spatial_size;
-      float sum = 0.f;
-      for (int j = 0; j < spatial_size; ++j) {
-        sum +=
-            (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
-      }
-      saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
-    }
-    int in_size = x->dims()[2] * x->dims()[3];
-    // compute out
-    for (int i = 0; i < n * groups_; ++i) {
-      const float* x_ptr = x_data + i * spatial_size;
-      float* y_ptr = y_data + i * spatial_size;
-      int c_num = i % groups_;
-      for (int c = 0; c < ch_per_group; c++) {
-        int chin = c_num * ch_per_group + c;
-        float scale_val = scale_data[chin];
-        float bias_val = bias_data[chin];
-        const float* x_ch_ptr = x_ptr + c * in_size;
-        float* y_ch_ptr = y_ptr + c * in_size;
-        for (int j = 0; j < in_size; j++) {
-          y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
-                            saved_variance_data[i] +
-                        bias_val;
+    auto mean_data = saved_mean->mutable_data<float>();
+    auto var_data = saved_variance->mutable_data<float>();
+
+    auto x_dims = x->dims();
+    int groups = groups_;
+    int channels =
+        (data_layout_str_ == "NCHW") ? x_dims[1] : x_dims[x_dims.size() - 1];
+    int group_size = (channels - 1) / groups + 1;
+    int imsize = (data_layout_str_ == "NCHW") ? (x_dims[2] * x_dims[3])
+                                              : (x_dims[1] * x_dims[2]);
+
+    auto* iter_x_data = x_data;
+    auto* iter_y_data = y_data;
+    for (int bid = 0; bid < x_dims[0]; bid++) {
+      for (int gid = 0; gid < groups; gid++) {
+        float x_mean = 0;
+        float x_var = 0;
+        int number =
+            std::min(group_size, static_cast<int>(channels - gid * group_size));
+        auto* tmp_x = iter_x_data;
+        auto* x_src_data = iter_x_data;
+        auto* tmp_y = iter_y_data;
+        auto* y_src_data = iter_y_data;
+
+        if (data_layout_str_ == "NCHW") {
+          for (int cid = 0; cid < number; cid++) {
+            for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
+              x_mean += iter_x_data[0];
+              x_var += iter_x_data[0] * iter_x_data[0];
+            }
+          }
+        } else {
+          for (int cid = 0; cid < number; cid++) {
+            iter_x_data = tmp_x + cid;
+            for (int imid = 0; imid < imsize; imid++, iter_x_data += channels) {
+              x_mean += iter_x_data[0];
+              x_var += iter_x_data[0] * iter_x_data[0];
+            }
+          }
+          iter_x_data = tmp_x + group_size;
         }
+
+        x_mean /= number * imsize;
+        x_var /= number * imsize;
+        x_var = x_var - x_mean * x_mean;
+        float var_inv = 1.0 / std::sqrt(x_var + epsilon_);
+        mean_data[bid * groups + gid] = x_mean;
+        var_data[bid * groups + gid] = x_var;
+
+        if (data_layout_str_ == "NCHW") {
+          for (int cid = 0; cid < number; cid++) {
+            for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) {
+              float val = (tmp_x[0] - x_mean) * var_inv;
+              if (scale_data) val *= scale_data[gid * group_size + cid];
+              if (bias_data) val += bias_data[gid * group_size + cid];
+              iter_y_data[0] = val;
+            }
+          }
+        } else {
+          for (int cid = 0; cid < number; cid++) {
+            tmp_x = x_src_data + cid;
+            iter_y_data = y_src_data + cid;
+            for (int imid = 0; imid < imsize;
+                 imid++, tmp_x += channels, iter_y_data += channels) {
+              float val = (tmp_x[0] - x_mean) * var_inv;
+              if (scale_data) val *= scale_data[gid * group_size + cid];
+              if (bias_data) val += bias_data[gid * group_size + cid];
+              iter_y_data[0] = val;
+            }
+          }
+          iter_y_data = tmp_y + group_size;
+        }
+      }
+      if (data_layout_str_ == "NCHW") {
+        iter_x_data = x_data + (bid + 1) * channels * imsize;
+        iter_y_data = y_data + (bid + 1) * channels * imsize;
       }
     }
   }
@@ -125,7 +158,7 @@ class GroupNormComputeTest : public arena::TestCase {
     op_desc->SetOutput("Variance", {saved_variance_});
     op_desc->SetAttr("epsilon", epsilon_);
     op_desc->SetAttr("groups", groups_);
-    op_desc->SetAttr("channels", channels_);
+    op_desc->SetAttr("data_layout", data_layout_str_);
   }

   void PrepareData() override {
@@ -148,7 +181,7 @@ void TestGroupNorm(Place place,
                    float abs_error = 6e-5,
                    std::vector<std::string> ignored_outs = {}) {
   for (auto& n : {1, 3, 16}) {
-    for (auto& c : {1}) {
+    for (auto& c : {1, 2}) {
       for (auto& h : {1, 16, 33, 56}) {
         for (auto& w : {1, 17, 55}) {
           for (auto& groups : {1, 2, 4}) {
@@ -158,7 +191,7 @@ void TestGroupNorm(Place place,
             DDim dim_in({n, c, h, w});
             float epsilon = 1e-5f;
             std::unique_ptr<arena::TestCase> tester(new GroupNormComputeTest(
-                place, "def", dim_in, epsilon, groups, c));
+                place, "def", dim_in, epsilon, groups, "NCHW"));
 #ifdef LITE_WITH_ARM
             if (place == TARGET(kARM)) {
               auto& ctx = tester->context()->As<ARMContext>();
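Both the patched ARM kernel and the rewritten test baseline implement plain group normalization: for each (batch, group) slice, y = scale * (x - mean) / sqrt(var + epsilon) + bias, with the Mean/Variance outputs now shaped {batch_size, groups} and Variance holding the raw per-group variance (the reciprocal std 1/sqrt(var + epsilon) only lives in a kernel-local scratch buffer). The following is a minimal standalone NCHW sketch of that contract, assuming channels % groups == 0; the function and variable names are illustrative only and are not part of the patch.

#include <cmath>

// Reference group norm, NCHW float (illustrative; not part of the patch).
// saved_mean/saved_variance are indexed by batch * groups + group, matching
// the {batch_size, groups} shapes the op now infers.
void group_norm_ref(const float* x, const float* scale, const float* bias,
                    int n, int c, int h, int w, int groups, float epsilon,
                    float* y, float* saved_mean, float* saved_variance) {
  const int ch_per_group = c / groups;       // assumes c % groups == 0
  const int spatial = ch_per_group * h * w;  // elements per (batch, group)
  for (int idx = 0; idx < n * groups; ++idx) {
    const float* xg = x + idx * spatial;
    float* yg = y + idx * spatial;
    // Pass 1: mean and variance from the sum and sum of squares.
    float sum = 0.f;
    float sqsum = 0.f;
    for (int i = 0; i < spatial; ++i) {
      sum += xg[i];
      sqsum += xg[i] * xg[i];
    }
    const float mean = sum / spatial;
    const float var = sqsum / spatial - mean * mean;
    const float rstd = 1.f / std::sqrt(var + epsilon);
    saved_mean[idx] = mean;
    saved_variance[idx] = var;  // raw variance; 1/std stays local
    // Pass 2: normalize, then apply the per-channel affine transform.
    const int g = idx % groups;
    for (int ch = 0; ch < ch_per_group; ++ch) {
      const int chin = g * ch_per_group + ch;
      for (int i = 0; i < h * w; ++i) {
        const int off = ch * h * w + i;
        yg[off] = scale[chin] * (xg[off] - mean) * rstd + bias[chin];
      }
    }
  }
}

Saving the raw variance rather than the reciprocal std is the behavioral point of the kernel change: it makes the op's Variance output match the reference semantics the new RunBaseline checks against, while the NEON path keeps its precomputed 1/std in std_vec for speed.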