Skip to content

Commit

Permalink
[MLU]fix sync_batch_norm and concat_grad op (#44586)
Browse files Browse the repository at this point in the history
  • Loading branch information
qipengh authored Jul 27, 2022
1 parent 84d595f commit f49b0cb
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 10 deletions.
9 changes: 7 additions & 2 deletions paddle/fluid/operators/concat_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,23 @@ class ConcatGradMLUKernel : public framework::OpKernel<T> {
out_grad->dims().size()));
// get output tensors whose names are not kEmptyVarName
std::vector<void*> outputs_vec;
std::vector<Tensor> tmp_outputs_vec;
std::vector<MLUCnnlTensorDesc> output_descs;
std::vector<cnnlTensorDescriptor_t> descs_vec;
for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
output_descs.emplace_back(MLUCnnlTensorDesc(*outs[j]));
descs_vec.push_back(output_descs.back().get());
outputs_vec.push_back(GetBasePtr(outs[j]));
} else {
outputs_vec.push_back(nullptr);
Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(ins[j]->dims(), ctx.GetPlace());
tmp_outputs_vec.push_back(tmp_tensor);
output_descs.emplace_back(MLUCnnlTensorDesc(*ins[j]));
outputs_vec.push_back(GetBasePtr(&(tmp_outputs_vec.back())));
}
descs_vec.push_back(output_descs.back().get());
}

MLUCnnlTensorDesc out_grad_desc(*out_grad);
Expand Down
9 changes: 8 additions & 1 deletion paddle/fluid/operators/sync_batch_norm_op_mlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define NO_USE_CNCL 0
#define GET_LAYOUT_OFFSET 2

using Tensor = framework::Tensor;
static std::vector<cnnlTensorLayout_t> supported_input_layout = {
CNNL_LAYOUT_NC, CNNL_LAYOUT_NLC, CNNL_LAYOUT_NHWC, CNNL_LAYOUT_NDHWC};
Expand Down Expand Up @@ -165,6 +167,7 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
Tensor mean_all(mean->dtype());
Tensor invstd_all(variance->dtype());

#ifdef PADDLE_WITH_CNCL
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
Expand Down Expand Up @@ -205,7 +208,9 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
cncl_dtype,
comm,
stream));

#else
if (NO_USE_CNCL) {
#endif
} else {
count_all = input_count;
mean_all.ShareDataWith(local_mean);
Expand Down Expand Up @@ -404,6 +409,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
FillMLUTensorWithHostValue<int32_t>(
ctx, static_cast<int32_t>(x->numel() / C), &numel_count);

#ifdef PADDLE_WITH_CNCL
auto &dev_ctx =
ctx.template device_context<paddle::platform::MLUDeviceContext>();
auto stream = dev_ctx.stream();
Expand Down Expand Up @@ -440,6 +446,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
comm,
stream));
}
#endif

if (d_x) {
MLUCnnlTensorDesc desc_count(numel_count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import paddle.fluid.layers as layers
from functools import reduce
from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from op_test import OpTest, _set_use_system_allocator

from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor
from test_sync_batch_norm_op import create_or_get_tensor

_set_use_system_allocator(False)
paddle.enable_static()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
from six import string_types
import paddle

from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from op_test import OpTest, _set_use_system_allocator

from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor
from test_sync_batch_norm_op import create_or_get_tensor

_set_use_system_allocator(False)
paddle.enable_static()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Abort immediately if any command below fails, so CI reports the failure.
set -e

# Run the SyncBatchNorm baseline unit test as a 2-card distributed job on
# MLU devices 0 and 1. MLU_VISIBLE_DEVICES is set only for this command;
# paddle.distributed.launch spawns one worker process per visible device.
MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch test_sync_batch_norm_op_mlu_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sys

sys.path.append("..")
from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from op_test import OpTest, _set_use_system_allocator

from test_sync_batch_norm_base_mlu import TestDistBase

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
import paddle.nn as nn
from paddle.fluid import Program, program_guard

from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator
from paddle.fluid.tests.unittests.test_dist_base import TestDistBase
sys.path.append("..")
from op_test import OpTest, _set_use_system_allocator
from test_dist_base import TestDistBase

paddle.enable_static()

Expand Down

0 comments on commit f49b0cb

Please sign in to comment.