[ARM] Fix group_norm compute error when compared with Paddle #5683

Merged: 3 commits, Mar 12, 2021
Changes from all commits
22 changes: 16 additions & 6 deletions lite/kernels/arm/group_norm_compute.cc
@@ -35,15 +35,23 @@ void GroupNormCompute::Run() {
float epsilon = param.epsilon;
int groups = param.groups;
int channels = param.channels;
int n = param.x->dims()[0];
int c = param.x->dims()[1];
auto x_dims = param.x->dims();
int n = x_dims[0];
int c = x_dims[1];
if (channels == -1) {
CHECK_EQ(param.data_layout_str, "NCHW")
<< "it only support NCHW layout!, but recived layout is "
<< param.data_layout_str;
channels = c;
}
int height = x_dims[2];
int width = x_dims[3];
int ch_per_group = channels / groups;
int height = param.x->dims()[2];
int width = param.x->dims()[3];
int spatial_size = ch_per_group * height * width;
int ngroup = n * groups;
int cnt = spatial_size >> 4;
int remain = spatial_size % 16;
float* std_vec = new float[param.saved_variance->numel()];
// compute saved_mean and saved_variance
#pragma omp parallel for
for (int n = 0; n < ngroup; ++n) {
@@ -103,7 +111,8 @@ void GroupNormCompute::Run() {
float variance = (summ - mean * mean * spatial_size) / spatial_size;
float std = 1.f / sqrtf(variance + epsilon);
saved_mean[n] = mean;
saved_variance[n] = std;
saved_variance[n] = variance;
std_vec[n] = std;
}
int in_size = height * width;
cnt = in_size >> 4;
@@ -117,7 +126,7 @@ void GroupNormCompute::Run() {
numc *= ch_per_group;
for (int c = 0; c < ch_per_group; c++) {
int chin = numc + c;
const float sstd_val = scale[chin] * saved_variance[i];
const float sstd_val = scale[chin] * std_vec[i];
const float bias_val = bias[chin];
const float mean_val = saved_mean[i];
const float32x4_t vsstd = vdupq_n_f32(sstd_val);
@@ -158,6 +167,7 @@ void GroupNormCompute::Run() {
}
}
}
delete[] std_vec;
}

} // namespace arm
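Review note: before this patch the kernel wrote the inverse standard deviation, 1 / sqrt(variance + epsilon), into saved_variance and then reused that tensor during normalization, while Paddle's reference group_norm reports the plain variance, so the two outputs could not be compared. With this change the inverse std lives only in the scratch buffer std_vec (freed at the end of Run) and saved_variance receives the variance itself. A minimal scalar sketch of the per-group math, assuming a single scale/bias pair (the real kernel applies per-channel scale and bias and uses NEON/OpenMP fast paths; the function name is illustrative):

#include <cmath>

// Normalize one group of `spatial_size` contiguous floats:
// y = scale * (x - mean) / sqrt(variance + eps) + bias.
void GroupNormRefOneGroup(const float* x, float* y, int spatial_size,
                          float scale, float bias, float eps,
                          float* saved_mean, float* saved_variance) {
  float sum = 0.f, sqsum = 0.f;
  for (int i = 0; i < spatial_size; ++i) {
    sum += x[i];
    sqsum += x[i] * x[i];
  }
  float mean = sum / spatial_size;
  float variance = sqsum / spatial_size - mean * mean;  // E[x^2] - E[x]^2
  float inv_std = 1.f / std::sqrt(variance + eps);      // scratch value only
  *saved_mean = mean;
  *saved_variance = variance;  // the fix: store the variance, not inv_std
  for (int i = 0; i < spatial_size; ++i) {
    y[i] = scale * (x[i] - mean) * inv_std + bias;
  }
}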
27 changes: 19 additions & 8 deletions lite/operators/group_norm_op.cc
@@ -34,27 +34,35 @@ bool GroupNormOp::CheckShape() const {
auto scale_dims = param_.scale->dims();
auto bias_dims = param_.bias->dims();
if (param_.channels == -1) {
param_.channels = x_dims[1];
param_.channels = (param_.data_layout_str == "NCHW")
? x_dims[1]
: x_dims[x_dims.size() - 1];
}
// only support NCHW
CHECK_EQ(param_.data_layout_str, "NCHW") << "data_layout must be NCHW";
CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
<< "Input X must have 2 to 5 dimensions.";
CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimension.";
CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimension.";
CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
CHECK_EQ(param_.channels, x_dims[1])
<< "Input channels must equal input_shape[1]";
CHECK_EQ(param_.channels % param_.groups, 0)
<< "channels must be divisible by groups";
CHECK_LE(param_.groups, param_.channels)
<< "groups should be no greater than channels";
CHECK_GE(param_.groups, 1) << "groups should be at least 1";
CHECK_EQ(param_.channels, scale_dims[0])
<< "The Input(Scale)'s first dimension size of Op(group_norm) must be "
"equal to the number of channels";
CHECK_EQ(param_.channels, bias_dims[0])
<< "The Input(Bias)'s first dimension size of Op(group_norm) must be "
"equal to the number of channels";
return true;
}

bool GroupNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0];
int64_t num = param_.channels / param_.groups;
param_.saved_mean->Resize({batch_size * num});
param_.saved_variance->Resize({batch_size * num});
param_.saved_mean->Resize({batch_size, param_.groups});
param_.saved_variance->Resize({batch_size, param_.groups});
param_.out->Resize(x_dims);
return true;
}
@@ -82,6 +90,9 @@ bool GroupNormOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
}
param_.out =
scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
if (op_desc.HasAttr("data_layout")) {
param_.data_layout_str = op_desc.GetAttr<std::string>("data_layout");
}
param_.epsilon = op_desc.GetAttr<float>("epsilon");
param_.groups = op_desc.GetAttr<int>("groups");
if (op_desc.HasAttr("channels")) {
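Review note: InferShapeImpl now sizes the Mean and Variance outputs as {batch_size, groups}, one statistic per (batch, group) pair, which matches what Paddle produces; the old code allocated batch_size * (channels / groups) flat entries. For example, with N = 4, C = 32 and groups = 2 the outputs shrink from 64 elements to a {4, 2} tensor of 8 elements. CheckShape also resolves channels == -1 from the layout-dependent channel axis before validating it against the Scale and Bias lengths, and AttachImpl picks up the optional data_layout attribute (defaulting to "NCHW" in GroupNormParam).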
1 change: 1 addition & 0 deletions lite/operators/op_params.h
@@ -1797,6 +1797,7 @@ struct GroupNormParam : ParamBase {
lite::Tensor* scale{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_variance{};
std::string data_layout_str{"NCHW"};
float epsilon;
int groups;
int channels;
131 changes: 82 additions & 49 deletions lite/tests/kernels/group_norm_compute_test.cc
@@ -34,20 +34,20 @@ class GroupNormComputeTest : public arena::TestCase {
DDim dims_{{4, 5, 19, 19}};
float epsilon_ = 1e-5f;
int groups_ = 1;
int channels_ = dims_[1];
std::string data_layout_str_ = "NCHW";

public:
GroupNormComputeTest(const Place& place,
const std::string& alias,
DDim dims,
float epsilon,
int groups,
int channels)
std::string data_layout_str)
: TestCase(place, alias),
dims_(dims),
epsilon_(epsilon),
groups_(groups),
channels_(channels) {}
data_layout_str_(data_layout_str) {}

void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(x_);
@@ -59,7 +59,7 @@ class GroupNormComputeTest : public arena::TestCase {
CHECK(y);
CHECK(saved_mean);
CHECK(saved_variance);
DDim saved_dim({dims_[0] * groups_});
DDim saved_dim({dims_[0], groups_});
y->Resize(dims_);
saved_mean->Resize(saved_dim);
saved_variance->Resize(saved_dim);
@@ -68,49 +68,82 @@ class GroupNormComputeTest : public arena::TestCase {
auto scale_data = scale->data<float>();
auto bias_data = bias->data<float>();
auto y_data = y->mutable_data<float>();
auto saved_mean_data = saved_mean->mutable_data<float>();
auto saved_variance_data = saved_variance->mutable_data<float>();

int n = x->dims()[0];
int ch_per_group = channels_ / groups_;
CHECK_EQ(x->dims()[1], channels_);
int spatial_size = ch_per_group * x->dims()[2] * x->dims()[3];
// compute mean
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum += x_ptr[j];
}
saved_mean_data[i] = sum / spatial_size;
}
// compute variance
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum +=
(x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
}
saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
}
int in_size = x->dims()[2] * x->dims()[3];
// compute out
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float* y_ptr = y_data + i * spatial_size;
int c_num = i % groups_;
for (int c = 0; c < ch_per_group; c++) {
int chin = c_num * ch_per_group + c;
float scale_val = scale_data[chin];
float bias_val = bias_data[chin];
const float* x_ch_ptr = x_ptr + c * in_size;
float* y_ch_ptr = y_ptr + c * in_size;
for (int j = 0; j < in_size; j++) {
y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
saved_variance_data[i] +
bias_val;
auto mean_data = saved_mean->mutable_data<float>();
auto var_data = saved_variance->mutable_data<float>();

auto x_dims = x->dims();
int groups = groups_;
int channels =
(data_layout_str_ == "NCHW") ? x_dims[1] : x_dims[x_dims.size() - 1];
int group_size = (channels - 1) / groups + 1;
int imsize = (data_layout_str_ == "NCHW") ? (x_dims[2] * x_dims[3])
: (x_dims[1] * x_dims[2]);

auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
float x_mean = 0;
float x_var = 0;
int number =
std::min(group_size, static_cast<int>(channels - gid * group_size));
auto* tmp_x = iter_x_data;
auto* x_src_data = iter_x_data;
auto* tmp_y = iter_y_data;
auto* y_src_data = iter_y_data;

if (data_layout_str_ == "NCHW") {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = tmp_x + cid;
for (int imid = 0; imid < imsize; imid++, iter_x_data += channels) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
iter_x_data = tmp_x + group_size;
}

x_mean /= number * imsize;
x_var /= number * imsize;
x_var = x_var - x_mean * x_mean;
float var_inv = 1.0 / std::sqrt(x_var + epsilon_);
mean_data[bid * groups + gid] = x_mean;
var_data[bid * groups + gid] = x_var;

if (data_layout_str_ == "NCHW") {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) {
float val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
} else {
for (int cid = 0; cid < number; cid++) {
tmp_x = x_src_data + cid;
iter_y_data = y_src_data + cid;
for (int imid = 0; imid < imsize;
imid++, tmp_x += channels, iter_y_data += channels) {
float val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
iter_y_data = tmp_y + group_size;
}
}
if (data_layout_str_ == "NCHW") {
iter_x_data = x_data + (bid + 1) * channels * imsize;
iter_y_data = y_data + (bid + 1) * channels * imsize;
}
}
}
@@ -125,7 +158,7 @@ class GroupNormComputeTest : public arena::TestCase {
op_desc->SetOutput("Variance", {saved_variance_});
op_desc->SetAttr("epsilon", epsilon_);
op_desc->SetAttr("groups", groups_);
op_desc->SetAttr("channels", channels_);
op_desc->SetAttr("data_layout", data_layout_str_);
}

void PrepareData() override {
@@ -148,7 +181,7 @@ void TestGroupNorm(Place place,
float abs_error = 6e-5,
std::vector<std::string> ignored_outs = {}) {
for (auto& n : {1, 3, 16}) {
for (auto& c : {1}) {
for (auto& c : {1, 2}) {
for (auto& h : {1, 16, 33, 56}) {
for (auto& w : {1, 17, 55}) {
for (auto& groups : {1, 2, 4}) {
@@ -158,7 +191,7 @@ void TestGroupNorm(Place place,
DDim dim_in({n, c, h, w});
float epsilon = 1e-5f;
std::unique_ptr<arena::TestCase> tester(new GroupNormComputeTest(
place, "def", dim_in, epsilon, groups, c));
place, "def", dim_in, epsilon, groups, "NCHW"));
#ifdef LITE_WITH_ARM
if (place == TARGET(kARM)) {
auto& ctx = tester->context()->As<ARMContext>();
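Review note: the rewritten baseline mirrors Paddle's reference group_norm. It accumulates per-(batch, group) mean and variance for both NCHW and NHWC strides, writes the raw variance to the Variance output (matching the kernel change above), and derives the channels per group with a ceiling division, group_size = (channels - 1) / groups + 1, where number = min(group_size, channels - gid * group_size) lets the last group be narrower when channels is not a multiple of groups. The test now passes a data_layout string instead of a channels attribute and also exercises c = 2. A small sketch of the group bookkeeping, with illustrative names:

#include <algorithm>
#include <utility>

// Channel range [begin, begin + count) covered by group `gid` when `channels`
// need not divide evenly into `groups`.
inline std::pair<int, int> GroupChannelRange(int channels, int groups, int gid) {
  int group_size = (channels - 1) / groups + 1;        // ceil(channels / groups)
  int begin = gid * group_size;
  int count = std::min(group_size, channels - begin);  // tail group may be smaller
  return {begin, count};
}

// Example: channels = 5, groups = 2 gives group_size = 3, so group 0 covers
// channels [0, 3) and group 1 covers channels [3, 5).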