forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add number count op (PaddlePaddle#39224)
* add expert count op add ut for expert_count
* update UT only for cuda
* fix for rocm
* update ut
* add moe module
* add expert count op add ut for expert_count
* update UT only for cuda
* update ut
* add moe module
* make expert count private
* rename expert count op

Co-authored-by: hlygit66666 <2570058140@qq.com>
- Loading branch information
1 parent
85a9ee3
commit e84947b
Showing
7 changed files
with
372 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/number_count_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class NumberCountOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx", | ||
"NumberCount"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count", | ||
"NumberCount"); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
// the dtype of the gate_idx should be same as int64 | ||
auto gate_idx_dtype = | ||
OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx"); | ||
|
||
PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64, | ||
platform::errors::InvalidArgument( | ||
"The dtype of the gate_idx_dtype should be int64")); | ||
return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
// Declares the interface of the `number_count` op: one input ("gate_idx"),
// one output ("Out") and one int attribute ("upper_range").
class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Input slot holding the gate indices to be counted.
    AddInput("gate_idx", "(Tensor) The input gate index tensor.");
    // Output slot receiving the per-expert counts.
    AddOutput("Out", "(Tensor) The output expert count tensor.");
    // Number of experts.
    AddAttr<int>("upper_range", "(int), The number of experts.");

    AddComment(R"DOC(number_count Operator.count gate indices.)DOC");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CPU kernel for int and int64_t element types.
REGISTER_OP_CPU_KERNEL(number_count, ops::NumberCountOpCPUKernel<int>,
                       ops::NumberCountOpCPUKernel<int64_t>);

// Register the operator and its maker; no gradient op is registered.
REGISTER_OP_WITHOUT_GRADIENT(number_count, ops::NumberCountOp,
                             ops::NumberCountOpMaker);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/operators/number_count_op.h" | ||
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" | ||
#include "paddle/fluid/platform/float16.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
// Integer division of _x_ by _y_ rounded up, computed as (x - 1) / y + 1.
// For x >= 1 this is ceil(x / y).
#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1)
// Number of consecutive expert ids covered by one block of the NumberCount
// kernel (also the size of its per-thread tally array).
#define PERTHREAD_EXPERTS 256
// Upper bound of the shuffle offsets and the modulus used to pick the thread
// that publishes a reduced value.
#define WARP_SIZE 32

// Threads per block used when launching the zero-initialisation kernel.
const int CUDA_NUM_THREADS = 512;

// Returns the number of CUDA_NUM_THREADS-sized blocks needed to cover N items.
static inline int GET_BLOCKS(const int N) {
  const int rounded_up = N + (CUDA_NUM_THREADS - 1);
  return rounded_up / CUDA_NUM_THREADS;
}
|
||
using LoDTensor = framework::LoDTensor; | ||
using Tensor = framework::Tensor; | ||
|
||
template <typename T> | ||
__global__ void initialize_zero_kernel(T* data, const int length) { | ||
CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast<T>(0); } | ||
} | ||
|
||
// Counts how many entries of `gate_idx` equal each expert id and adds the
// totals into `number_count`.
//
//   gate_idx      input indices; `batch_size` elements are read.
//   number_count  output counters indexed by expert id. Totals are *added*
//                 with CudaAtomicAdd, so the buffer must already hold the
//                 starting values (the launcher in this file zeroes it first).
//   batch_size    number of elements in gate_idx.
//   upper_range   number of experts; ids outside [0, upper_range) are ignored.
//
// Each block owns the PERTHREAD_EXPERTS consecutive expert ids that start at
// blockIdx.x * PERTHREAD_EXPERTS.
template <typename T>
__global__ void NumberCount(const T* gate_idx, T* number_count,
                            int64_t batch_size, int upper_range) {
  // Per-thread tally, one slot for every expert id owned by this block.
  int res_tmp[PERTHREAD_EXPERTS] = {0};
  int expert_min = blockIdx.x * PERTHREAD_EXPERTS;
  int expert_max = expert_min + PERTHREAD_EXPERTS;
  // The last block may own fewer than PERTHREAD_EXPERTS experts.
  if (expert_max > upper_range) {
    expert_max = upper_range;
  }
  // Every thread of the block walks the whole input with stride blockDim.x.
  // NOTE(review): `i` is int while batch_size is int64_t - confirm inputs
  // never exceed INT_MAX elements.
  for (int i = threadIdx.x; i < batch_size; i += blockDim.x) {
    T idx = gate_idx[i];
    // -1 entries are skipped explicitly (presumably a "no expert" marker -
    // verify against the caller).
    if (idx == -1) {
      continue;
    }
    // Ids owned by other blocks (or out of range) are not counted here.
    if (idx < expert_min || idx >= expert_max) {
      continue;
    }
    res_tmp[idx - expert_min] += 1;
  }
  for (int i = expert_min; i < expert_max; ++i) {
    int x = res_tmp[i - expert_min];
// Combine the per-thread tallies with shuffle-down at offsets
// 1, 2, 4, ..., WARP_SIZE / 2.
#pragma unroll
    for (int j = 1; j < WARP_SIZE; j <<= 1) {
#ifdef __HIPCC__
      x = x + __shfl_down(x, j);
#else
      // -1u is passed as the participation mask (all bits set).
      x = x + __shfl_down_sync(-1u, x, j);
#endif
    }
    // Only threads whose index is a multiple of WARP_SIZE publish their
    // reduced value; the atomic add merges the contributions of all of them.
    if (threadIdx.x % WARP_SIZE == 0) {
      platform::CudaAtomicAdd(number_count + i, x);
    }
  }
}
|
||
template <typename T> | ||
class NumberCountOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto gate_idx = context.Input<LoDTensor>("gate_idx"); | ||
auto upper_range = context.Attr<int>("upper_range"); | ||
auto number_count = context.Output<LoDTensor>("Out"); | ||
|
||
int64_t batch_size = gate_idx->numel(); | ||
auto place = context.GetPlace(); | ||
const auto& dev_ctx = | ||
context.template device_context<platform::CUDADeviceContext>(); | ||
|
||
framework::DDim out_dims = phi::make_ddim({upper_range}); | ||
auto out_data = number_count->mutable_data<T>(out_dims, place); | ||
const T* gate_data = gate_idx->data<T>(); | ||
|
||
initialize_zero_kernel< | ||
T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>( | ||
out_data, upper_range); | ||
|
||
NumberCount< | ||
T><<<CEIL(upper_range, PERTHREAD_EXPERTS), 256, 0, dev_ctx.stream()>>>( | ||
gate_data, out_data, batch_size, upper_range); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the CUDA kernel; only the int64_t instantiation is provided.
REGISTER_OP_CUDA_KERNEL(number_count, ops::NumberCountOpCUDAKernel<int64_t>);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
#if defined(PADDLE_WITH_GLOO) | ||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
// CPU placeholder for the number_count op: Compute unconditionally raises
// platform::errors::Unavailable, so the op cannot be executed on CPU.
template <typename T>
class NumberCountOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // NOTE(review): the message still says "expert count op" although the op
    // is registered as number_count - consider aligning the wording.
    PADDLE_THROW(platform::errors::Unavailable(
        "Do not support expert count op for cpu kernel now."));
  }
};
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.fluid import core | ||
from paddle.fluid.layer_helper import LayerHelper | ||
from paddle.fluid.framework import in_dygraph_mode | ||
|
||
|
||
def _number_count(gate_idx, upper_range):
    """
    Count how many times each expert index occurs in the gate index tensor.

    Args:
        gate_idx (Tensor): The input gate index tensor. The ``number_count``
            op requires its data type to be int64.
        upper_range (int): The number of the experts.

    Returns:
        out (Tensor): The output expert count.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            from paddle.distributed.models.moe import utils

            gate_idx = [
                [0, 2],
                [0, 2]
            ]
            upper_range = 6
            gate_idx = paddle.to_tensor(gate_idx, dtype="int64")
            number_count = utils._number_count(gate_idx, upper_range)
            print(number_count) # the result: [2, 0, 2, 0, 0, 0]
    """
    if in_dygraph_mode():
        # Dygraph: call the op directly, passing the attribute as a
        # name/value pair.
        return core.ops.number_count(gate_idx, 'upper_range', upper_range)
    else:
        op_type = 'number_count'

        # locals() at this point holds gate_idx, upper_range and op_type; they
        # are forwarded to LayerHelper as keyword arguments, so do not add or
        # rename locals above this line.
        helper = LayerHelper(op_type, **locals())
        # The output variable is created with the same dtype as the input.
        out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype)

        helper.append_op(
            type=op_type,
            inputs={'gate_idx': gate_idx},
            outputs={'Out': out},
            attrs={'upper_range': upper_range})
        return out
80 changes: 80 additions & 0 deletions
80
python/paddle/fluid/tests/unittests/test_number_count_op.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import op_test | ||
import numpy as np | ||
import unittest | ||
import paddle | ||
import paddle.fluid.core as core | ||
from paddle.fluid.op import Operator | ||
import paddle.fluid as fluid | ||
from paddle.fluid import compiler, Program, program_guard | ||
from paddle.fluid.backward import append_backward | ||
from paddle.distributed.models.moe import utils | ||
|
||
|
||
def count(x, upper_range):
    """NumPy reference for the number_count op.

    Args:
        x (np.ndarray): Integer array of expert indices, any shape.
        upper_range (int): Number of experts.

    Returns:
        np.ndarray: 1-D integer array of length ``upper_range`` whose i-th
        entry is the number of elements of ``x`` equal to ``i``. Values that
        are negative or not smaller than ``upper_range`` are ignored.
    """
    flat = x.reshape(-1)
    # Drop out-of-range ids first: np.bincount rejects negative values and
    # would otherwise grow the result beyond upper_range.
    in_range = flat[(flat >= 0) & (flat < upper_range)]
    # Vectorised histogram instead of a per-element Python loop; astype(int)
    # keeps the result dtype of the original implementation.
    return np.bincount(in_range, minlength=upper_range).astype(int)
|
||
|
||
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestExpertCountOpInt64(op_test.OpTest):
    # Operator-level check of number_count on int64 input, run on GPU only.

    def setUp(self):
        self.op_type = "number_count"
        num_experts = 16
        # Random ids in [-1, num_experts); -1 entries must be ignored.
        gate_idx = np.random.randint(
            -1, num_experts, size=(1000, 2)).astype('int64')
        self.inputs = {'gate_idx': gate_idx}
        self.attrs = {"upper_range": num_experts}
        self.outputs = {'Out': count(gate_idx, num_experts)}

    def test_forward(self):
        gpu_place = paddle.CUDAPlace(0)
        self.check_output_with_place(gpu_place)
|
||
|
||
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestExpertCountAPI(unittest.TestCase):
    # Checks utils._number_count against the NumPy reference `count` in both
    # static-graph and dygraph mode, on GPU only.

    def setUp(self):
        # 320 experts span more than one PERTHREAD_EXPERTS (256) group of the
        # CUDA kernel.
        self.upper_range = 320
        self.x = np.random.randint(
            -1, self.upper_range, size=(6000, 200)).astype('int64')
        self.out = count(self.x, self.upper_range)
        self.place = paddle.CUDAPlace(0)

    def test_api_static(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.fluid.data('x', self.x.shape, dtype="int64")
            out = utils._number_count(x, self.upper_range)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'x': self.x}, fetch_list=[out])
            # assertTrue instead of a bare assert, so the check still runs
            # when Python is started with -O.
            self.assertTrue(np.allclose(res, self.out))

    def test_api_dygraph(self):
        paddle.disable_static()
        x = paddle.to_tensor(self.x)
        out = utils._number_count(x, self.upper_range)
        self.assertTrue(np.allclose(out.numpy(), self.out))
|
||
|
||
if __name__ == '__main__':
    # Switch to static-graph mode before running the tests as a script.
    paddle.enable_static()
    unittest.main()