[Hackathon No.46] Add float16 data type support to the Paddle gumbel_softmax operator #50923

Merged (5 commits, Mar 17, 2023)
Changes from 3 commits
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/gumbel_softmax_grad_kernel.cu
@@ -21,5 +21,6 @@ PD_REGISTER_KERNEL(gumbel_softmax_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::GumbelSoftmaxGradKernel,
+                   phi::dtype::float16,
                    float,
                    double) {}
32 changes: 20 additions & 12 deletions paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
@@ -13,7 +13,7 @@
 // limitations under the License.

 #include "paddle/phi/kernels/gumbel_softmax_kernel.h"
-
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/axis_utils.h"
 #include "paddle/phi/kernels/impl/gumbel_softmax_kernel_impl.h"
@@ -116,17 +116,18 @@ struct OneHotGenerator<GPUContext, T> {
   }
 };

-template <typename T>
+template <typename T, typename MPType>
 __global__ void AddGumbelNoiseCUDAKernel(const T* input_data,
                                          T* output_data,
-                                         T* noise,
+                                         MPType* noise,
                                          const float temperature,
                                          int64_t n) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   int step = blockDim.x * gridDim.x;
   for (int64_t i = index; i < n; i += step) {
-    T gumbel_noise = -log(-log(noise[i]));
-    output_data[i] = (gumbel_noise + input_data[i]) / temperature;
+    MPType gumbel_noise = -log(-log(noise[i]));
+    output_data[i] = static_cast<T>(
+        (gumbel_noise + static_cast<MPType>(input_data[i])) / temperature);
   }
 }

@@ -141,7 +142,8 @@ struct GumbleNoiseGenerator<GPUContext, T> {
     DenseTensor random_tensor;
     int64_t size = size_to_axis * size_from_axis;
     random_tensor.Resize(make_ddim({size}));
-    T* random_data = ctx.template Alloc<T>(&random_tensor);
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+    MPType* random_data = ctx.template Alloc<MPType>(&random_tensor);

     // generate gumbel noise
     int device_id = ctx.GetPlace().GetDeviceId();
@@ -152,10 +154,11 @@ struct GumbleNoiseGenerator<GPUContext, T> {
     uint64_t offset = seed_offset.second;

     thrust::counting_iterator<int64_t> index_sequence_begin(0);
-    thrust::transform(index_sequence_begin,
-                      index_sequence_begin + size,
-                      thrust::device_ptr<T>(random_data),
-                      UniformCUDAGenerator<T>(0.00001, 1, seed, size * offset));
+    thrust::transform(
+        index_sequence_begin,
+        index_sequence_begin + size,
+        thrust::device_ptr<MPType>(random_data),
+        UniformCUDAGenerator<MPType>(0.00001, 1, seed, size * offset));

     // add gumbel noise to X
     const int thread_size = 512;
@@ -168,5 +171,10 @@ struct GumbleNoiseGenerator<GPUContext, T> {
 }  // namespace phi
 #endif

-PD_REGISTER_KERNEL(
-    gumbel_softmax, GPU, ALL_LAYOUT, phi::GumbelSoftmaxKernel, float, double) {}
+PD_REGISTER_KERNEL(gumbel_softmax,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::GumbelSoftmaxKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
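
For reference, the mixed-precision scheme used above (generate the uniform noise and evaluate -log(-log(u)) in the wider MPType, which is float when T is float16, then cast the result back to T) can be mimicked with NumPy as a rough sanity check. This is illustrative only and not part of the PR:

import numpy as np

x = np.random.uniform(0.1, 1.0, size=(20, 10)).astype(np.float16)     # fp16 logits
u = np.random.uniform(0.00001, 1.0, size=x.shape).astype(np.float32)  # noise in the wider type
gumbel = -np.log(-np.log(u))                                          # evaluated in float32
temperature = 1.0
out = ((gumbel + x.astype(np.float32)) / temperature).astype(np.float16)  # cast back to fp16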
51 changes: 51 additions & 0 deletions python/paddle/fluid/tests/unittests/test_gumbel_softmax_op.py
@@ -103,6 +103,57 @@ def init_attrs(self):
        self.dtype = "float64"


class TestGumbelSoftmax_ZeroDim_FP16OP(OpTest):
    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.dtype = np.float16
        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
        out = np.array(1.0).astype(self.dtype)

        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {"hard": True, "axis": -1}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")
Contributor:

The FP16 unit tests should inherit from TestGumbelSoftmaxOp; in practice you only need to override init_attrs for the fp16 cases, which avoids redundant code.
TestGumbelSoftmax_ZeroDim_FP16OP -> TestGumbelSoftmaxFP16OP
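
A minimal sketch of the suggested refactor, assuming TestGumbelSoftmaxOp's setUp reads shape, attrs, count_expected, and dtype from init_attrs as the existing float64 tests do (the class name and values here are illustrative):

class TestGumbelSoftmaxFP16OP(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 20
        self.dtype = np.float16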

Contributor Author:

Hello, this follows the original structure of the existing unit tests: the ZeroDim case inherits from OpTest on its own, while the other tests inherit from TestGumbelSoftmaxOp and override init_attrs(). I handled ZeroDim separately here in the same way, which is why it inherits from OpTest directly. The four subsequent tests all inherit from TestGumbelSoftmaxOp and override init_attrs().

Contributor:

I don't think the original structure is optimal. TestGumbelSoftmax_ZeroDim could also override init_attrs, couldn't it?

Contributor Author:

Hello, I tried having TestGumbelSoftmax_ZeroDim inherit directly from the TestGumbelSoftmaxOp base class, but the base class overrides check_output with check_output_customized for the multi-dimensional cases, which does not apply to ZeroDim. So I added an init_attrs method to TestGumbelSoftmax_ZeroDim and had TestGumbelSoftmax_ZeroDimFP16 inherit from it and override that method.
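
A rough illustration of the arrangement described above, in which the ZeroDim base test gains its own init_attrs and the FP16 variant overrides only the dtype (a hypothetical sketch; the actual classes in the final revision may differ):

class TestGumbelSoftmax_ZeroDim(OpTest):
    def init_attrs(self):
        self.dtype = "float64"

    def setUp(self):
        self.op_type = "gumbel_softmax"
        self.python_api = F.gumbel_softmax
        self.init_attrs()
        x = np.random.uniform(0.1, 1, []).astype(self.dtype)
        out = np.array(1.0).astype(self.dtype)
        self.inputs = {'X': x}
        self.outputs = {'Out': out}
        self.attrs = {"hard": True, "axis": -1}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


class TestGumbelSoftmax_ZeroDimFP16(TestGumbelSoftmax_ZeroDim):
    def init_attrs(self):
        self.dtype = np.float16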



class TestGumbelSoftmaxFP16OP2(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10]
        self.attrs = {"hard": True, "axis": 0}
        self.count_expected = 10
        self.dtype = np.float16


class TestGumbelSoftmaxFP16OP3(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [100]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 1
        self.dtype = np.float16


class TestGumbelSoftmaxFP16OP4(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": -1}
        self.count_expected = 200
        self.dtype = np.float16


class TestGumbelSoftmaxFP16OP5(TestGumbelSoftmaxOp):
    def init_attrs(self):
        self.shape = [20, 10, 5]
        self.attrs = {"hard": True, "axis": 1}
        self.count_expected = 100
        self.dtype = np.float16

Contributor:

These four unit tests should inherit from TestGumbelSoftmaxFP16OP.

Contributor Author:

Hello, since TestGumbelSoftmax_ZeroDim_FP16OP above targets the ZeroDim case, it has no init_attrs() method, so it cannot simply be renamed to TestGumbelSoftmaxFP16OP. That is why these four tests inherit from TestGumbelSoftmaxOp directly.


Contributor Author:

done

class TestGumbelSoftmaxOpSampleDistribution(OpTest):
    def softmax(self, x):
        x_row_max = x.max(axis=-1)
6 changes: 4 additions & 2 deletions python/paddle/nn/functional/activation.py
@@ -1664,7 +1664,7 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
     Parameters:
         x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
             of independent distributions and the last dimension represents
-            a vector of probabilities with datatype float32, float64.
+            a vector of probabilities with datatype float16, float32, float64.
         temperature (float, optional): non-negative scalar temperature.
             Default is 1.0.
         hard (bool, optional): if True, the returned samples will be discretized as
@@ -1705,7 +1705,9 @@ def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
         )

     helper = LayerHelper("gumbel_softmax", **locals())
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
+    check_variable_and_dtype(
+        x, 'x', ['float16', 'float32', 'float64'], 'gumbel_softmax'
+    )
     out = helper.create_variable_for_type_inference(x.dtype)
     helper.append_op(
         type='gumbel_softmax',
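
With this change, float16 inputs can be passed to the functional API directly. A quick usage sketch (not part of the PR; it assumes a CUDA build of Paddle, since the float16 kernels are registered for GPU only):

import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')
logits = paddle.rand([4, 6]).astype('float16')              # fp16 logits
out = F.gumbel_softmax(logits, temperature=0.5, hard=True)  # one-hot samples, fp16
print(out.dtype)  # paddle.float16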