-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add number count op #39224
add number count op #39224
Changes from all commits
7cf82bc
ac9e1e4
6a08778
9dc2f8e
38c5d51
d92b3cd
f95ae1f
b5ca7d8
f656f7e
4c8627f
ad01da4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/operators/number_count_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class NumberCountOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("gate_idx"), "Input", "gate_idx", | ||
"NumberCount"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "number_count", | ||
"NumberCount"); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
// the dtype of the gate_idx should be same as int64 | ||
auto gate_idx_dtype = | ||
OperatorWithKernel::IndicateVarDataType(ctx, "gate_idx"); | ||
|
||
PADDLE_ENFORCE_EQ(gate_idx_dtype, framework::proto::VarType::INT64, | ||
platform::errors::InvalidArgument( | ||
"The dtype of the gate_idx_dtype should be int64")); | ||
return framework::OpKernelType(gate_idx_dtype, ctx.GetPlace()); | ||
} | ||
}; | ||
|
||
class NumberCountOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("gate_idx", "(Tensor) The input gate index tensor."); | ||
AddOutput("Out", "(Tensor) The output expert count tensor."); | ||
AddAttr<int>("upper_range", "(int), The number of experts."); | ||
|
||
AddComment(R"DOC(number_count Operator.count gate indices.)DOC"); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
|
||
REGISTER_OP_CPU_KERNEL(number_count, ops::NumberCountOpCPUKernel<int>, | ||
ops::NumberCountOpCPUKernel<int64_t>); | ||
|
||
REGISTER_OP_WITHOUT_GRADIENT(number_count, ops::NumberCountOp, | ||
ops::NumberCountOpMaker); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/operators/number_count_op.h" | ||
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" | ||
#include "paddle/fluid/platform/float16.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
#define CEIL(_x_, _y_) (((_x_)-1) / (_y_) + 1) | ||
#define PERTHREAD_EXPERTS 256 | ||
#define WARP_SIZE 32 | ||
|
||
const int CUDA_NUM_THREADS = 512; | ||
static inline int GET_BLOCKS(const int N) { | ||
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; | ||
} | ||
|
||
using LoDTensor = framework::LoDTensor; | ||
using Tensor = framework::Tensor; | ||
|
||
template <typename T> | ||
__global__ void initialize_zero_kernel(T* data, const int length) { | ||
CUDA_KERNEL_LOOP(idx, length) { data[idx] = static_cast<T>(0); } | ||
} | ||
|
||
template <typename T> | ||
__global__ void NumberCount(const T* gate_idx, T* number_count, | ||
int64_t batch_size, int upper_range) { | ||
int res_tmp[PERTHREAD_EXPERTS] = {0}; | ||
int expert_min = blockIdx.x * PERTHREAD_EXPERTS; | ||
int expert_max = expert_min + PERTHREAD_EXPERTS; | ||
if (expert_max > upper_range) { | ||
expert_max = upper_range; | ||
} | ||
for (int i = threadIdx.x; i < batch_size; i += blockDim.x) { | ||
T idx = gate_idx[i]; | ||
if (idx == -1) { | ||
continue; | ||
} | ||
if (idx < expert_min || idx >= expert_max) { | ||
continue; | ||
} | ||
res_tmp[idx - expert_min] += 1; | ||
} | ||
for (int i = expert_min; i < expert_max; ++i) { | ||
int x = res_tmp[i - expert_min]; | ||
#pragma unroll | ||
for (int j = 1; j < WARP_SIZE; j <<= 1) { | ||
#ifdef __HIPCC__ | ||
x = x + __shfl_down(x, j); | ||
#else | ||
x = x + __shfl_down_sync(-1u, x, j); | ||
#endif | ||
} | ||
if (threadIdx.x % WARP_SIZE == 0) { | ||
platform::CudaAtomicAdd(number_count + i, x); | ||
} | ||
} | ||
} | ||
|
||
template <typename T> | ||
class NumberCountOpCUDAKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto gate_idx = context.Input<LoDTensor>("gate_idx"); | ||
auto upper_range = context.Attr<int>("upper_range"); | ||
auto number_count = context.Output<LoDTensor>("Out"); | ||
|
||
int64_t batch_size = gate_idx->numel(); | ||
auto place = context.GetPlace(); | ||
const auto& dev_ctx = | ||
context.template device_context<platform::CUDADeviceContext>(); | ||
|
||
framework::DDim out_dims = phi::make_ddim({upper_range}); | ||
auto out_data = number_count->mutable_data<T>(out_dims, place); | ||
const T* gate_data = gate_idx->data<T>(); | ||
|
||
initialize_zero_kernel< | ||
T><<<GET_BLOCKS(upper_range), CUDA_NUM_THREADS, 0, dev_ctx.stream()>>>( | ||
out_data, upper_range); | ||
|
||
NumberCount< | ||
T><<<CEIL(upper_range, PERTHREAD_EXPERTS), 256, 0, dev_ctx.stream()>>>( | ||
gate_data, out_data, batch_size, upper_range); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
namespace plat = paddle::platform; | ||
|
||
REGISTER_OP_CUDA_KERNEL(number_count, ops::NumberCountOpCUDAKernel<int64_t>); |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include "paddle/fluid/framework/data_type.h" | ||
#include "paddle/fluid/framework/lod_tensor.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
|
||
#if defined(PADDLE_WITH_GLOO) | ||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h" | ||
#endif | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class NumberCountOpCPUKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
PADDLE_THROW(platform::errors::Unavailable( | ||
"Do not support expert count op for cpu kernel now.")); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.fluid import core | ||
from paddle.fluid.layer_helper import LayerHelper | ||
from paddle.fluid.framework import in_dygraph_mode | ||
|
||
|
||
def _number_count(gate_idx, upper_range): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OP定义与PR描述不符,另外PR描述过于简单,请说清该PR的作用与目的 |
||
""" | ||
calculate the expert count according to the gate index. | ||
Args: | ||
gate_idx (Tensor): Tensor. The input gate index whose data type should be int32 or int64. | ||
upper_range (int): The number of the experts. | ||
Returns: | ||
out (Tensor): The output expert count. | ||
Examples: | ||
.. code-block:: python | ||
# required: distributed | ||
import paddle | ||
|
||
gate_idx = [ | ||
[0, 2], | ||
[0, 2] | ||
] | ||
upper_range = 6 | ||
gate_idx = paddle.to_tensor(gate_idx, dtype="int32") | ||
number_count = paddle.distributed.utils.number_count(gate_idx, upper_range) | ||
print(number_count) # the result: [2, 0, 2, 0, 0, 0] | ||
""" | ||
if in_dygraph_mode(): | ||
return core.ops.number_count(gate_idx, 'upper_range', upper_range) | ||
else: | ||
op_type = 'number_count' | ||
|
||
helper = LayerHelper(op_type, **locals()) | ||
out = helper.create_variable_for_type_inference(dtype=gate_idx.dtype) | ||
|
||
helper.append_op( | ||
type=op_type, | ||
inputs={'gate_idx': gate_idx}, | ||
outputs={'Out': out}, | ||
attrs={'upper_range': upper_range}) | ||
return out |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK |
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
|
||
import op_test | ||
import numpy as np | ||
import unittest | ||
import paddle | ||
import paddle.fluid.core as core | ||
from paddle.fluid.op import Operator | ||
import paddle.fluid as fluid | ||
from paddle.fluid import compiler, Program, program_guard | ||
from paddle.fluid.backward import append_backward | ||
from paddle.distributed.models.moe import utils | ||
|
||
|
||
def count(x, upper_range): | ||
res = np.zeros((upper_range, )).astype(int) | ||
for i in x.reshape(-1): | ||
if i >= 0 and i < len(res): | ||
res[i] += 1 | ||
return res | ||
|
||
|
||
@unittest.skipIf(not core.is_compiled_with_cuda(), | ||
"core is not compiled with CUDA") | ||
class TestExpertCountOpInt64(op_test.OpTest): | ||
def setUp(self): | ||
expert_num = 16 | ||
self.op_type = "number_count" | ||
x = np.random.randint(-1, expert_num, size=(1000, 2)).astype('int64') | ||
self.inputs = {'gate_idx': x} | ||
self.outputs = {'Out': count(x, expert_num)} | ||
self.attrs = {"upper_range": expert_num} | ||
|
||
def test_forward(self): | ||
self.check_output_with_place(paddle.CUDAPlace(0)) | ||
|
||
|
||
@unittest.skipIf(not core.is_compiled_with_cuda(), | ||
"core is not compiled with CUDA") | ||
class TestExpertCountAPI(unittest.TestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 单测名称没有修改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK,下个PR修改,MOE模块会有多个算子的PR需要提交 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 以及再仔细检查下其他地方 |
||
def setUp(self): | ||
self.upper_range = 320 | ||
self.x = np.random.randint( | ||
-1, self.upper_range, size=(6000, 200)).astype('int64') | ||
self.out = count(self.x, self.upper_range) | ||
self.place = paddle.CUDAPlace(0) | ||
|
||
def test_api_static(self): | ||
paddle.enable_static() | ||
with paddle.static.program_guard(paddle.static.Program()): | ||
x = paddle.fluid.data('x', self.x.shape, dtype="int64") | ||
out = utils._number_count(x, self.upper_range) | ||
exe = paddle.static.Executor(self.place) | ||
res = exe.run(feed={'x': self.x}, fetch_list=[out]) | ||
assert np.allclose(res, self.out) | ||
|
||
def test_api_dygraph(self): | ||
paddle.disable_static() | ||
x = paddle.to_tensor(self.x) | ||
out = utils._number_count(x, self.upper_range) | ||
assert np.allclose(out.numpy(), self.out) | ||
|
||
|
||
if __name__ == '__main__': | ||
paddle.enable_static() | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.