[XPU] add bfloat16 support for gaussian and uniform #58662

Merged
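This PR registers bfloat16 for the XPU gaussian and uniform kernels and rewrites the XPU uniform kernel to sample on-device via xpu::random. A minimal sketch of what the change enables, assuming an XPU build of Paddle; the APIs are the public paddle.uniform and paddle.tensor.random.gaussian exercised by the tests below:

import paddle

# Assumes an XPU build of Paddle; 'xpu' selects the first XPU device.
paddle.set_device('xpu')
paddle.seed(10)

# With this PR, both ops accept bfloat16 on XPU in addition to float32/float16.
g = paddle.tensor.random.gaussian([2, 3], mean=0.0, std=1.0, dtype='bfloat16')
u = paddle.uniform([2, 3], dtype='bfloat16', min=-1.0, max=1.0)
print(g.dtype, u.dtype)  # expected: paddle.bfloat16 paddle.bfloat16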
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so")
set(XPU_XPTI_LIB_NAME "libxpti.so")

if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231101")
set(XPU_BASE_DATE "20231103")
endif()
set(XPU_XCCL_BASE_VERSION "1.0.53.6")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
9 changes: 7 additions & 2 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -446,7 +446,9 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64,
phi::DataType::BOOL})},
{"gaussian_random",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"gelu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -977,7 +979,10 @@ XPUOpMap& get_kl2_ops() {
{"update_loss_scaling",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unbind", XPUKernelSet({phi::DataType::FLOAT32})},
{"uniform_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"uniform_random",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unique",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
9 changes: 7 additions & 2 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -416,7 +416,9 @@ XPUOpMap& get_kl3_ops() {
phi::DataType::INT64,
phi::DataType::BOOL})},
{"gaussian_random",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"gelu_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
@@ -942,7 +944,10 @@ XPUOpMap& get_kl3_ops() {
{"update_loss_scaling",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"unbind", XPUKernelSet({phi::DataType::FLOAT32})},
{"uniform_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"uniform_random",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"unique",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
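With the op-list entries above, the dtype coverage that the XPU test framework discovers for these ops should now include bfloat16. A small sketch mirroring how the tests below query it; get_xpu_op_support_types comes from the test utility get_test_cover_info used later in this PR:

from get_test_cover_info import get_xpu_op_support_types

# Expected to now list bfloat16 alongside float32/float16 for the running device.
print(get_xpu_op_support_types('gaussian_random'))
print(get_xpu_op_support_types('uniform_random'))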
3 changes: 2 additions & 1 deletion paddle/phi/kernels/xpu/gaussian_kernel.cc
@@ -50,4 +50,5 @@ PD_REGISTER_KERNEL(gaussian,
ALL_LAYOUT,
phi::GaussianKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
65 changes: 22 additions & 43 deletions paddle/phi/kernels/xpu/uniform_kernel.cc
@@ -14,12 +14,9 @@ limitations under the License. */

#include "paddle/phi/kernels/uniform_kernel.h"

#include <string>

#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/uniform_real_distribution.h"

namespace phi {

@@ -31,49 +28,31 @@ void UniformKernel(const Context &dev_ctx,
const Scalar &max,
int seed,
DenseTensor *out) {
int diag_num = 0;
int diag_step = 0;
float diag_val = 0.0f;
out->Resize(phi::make_ddim(shape.GetData()));
T *data = dev_ctx.template Alloc<T>(out);
int64_t size = out->numel();

std::unique_ptr<T[]> data_cpu(new T[size]);

std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetGenerator()->GetCPUEngine();
}
UniformRealDistribution<T>(
data_cpu.get(), size, min.to<float>(), max.to<float>(), engine);
if (diag_num > 0) {
PADDLE_ENFORCE_GT(
size,
(diag_num - 1) * (diag_step + 1),
phi::errors::InvalidArgument(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num,
diag_step,
(diag_num - 1) * (diag_step + 1),
size));
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data_cpu[pos] = diag_val;
}
if (out->numel() == 0) {
return;
}

memory_utils::Copy(dev_ctx.GetPlace(),
data,
phi::CPUPlace(),
reinterpret_cast<void *>(data_cpu.get()),
size * sizeof(T));
using XPUType = typename XPUTypeTrait<T>::Type;
int64_t real_seed = seed != 0 ? seed : dev_ctx.GetGenerator()->Random64();

// int random(Context* ctx, T* x, int64_t len, T min, T max, int64_t seed);
Contributor (review comment):
Does the CPU branch need to be kept? One use case is to initialize with the same seed and then compare the XPU output against the CPU output; can the result of xpu::random be made to match the CPU result?

Contributor Author @houj04 (Nov 6, 2023):
Could that scenario be handled with export XPU_BLACK_LIST="uniform_random,gaussian_random"? The current gaussian_random kernel likewise no longer keeps the "copy from CPU" path. (A sketch of this workaround follows this file's diff.)
int r = xpu::random<XPUType>(dev_ctx.x_context(),
reinterpret_cast<XPUType *>(data),
out->numel(),
static_cast<XPUType>(min.to<float>()),
static_cast<XPUType>(max.to<float>()),
real_seed);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "random");
}

} // namespace phi

PD_REGISTER_KERNEL(uniform, XPU, ALL_LAYOUT, phi::UniformKernel, float) {}
PD_REGISTER_KERNEL(uniform,
XPU,
ALL_LAYOUT,
phi::UniformKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16) {}
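As noted in the review thread above, the CPU-side sampling path is removed, so a same-seed CPU/XPU comparison can no longer rely on this kernel. The author's suggestion is to black-list the ops on XPU so they fall back to the CPU kernels. A minimal sketch of that workaround; setting the variable from Python before any op runs is assumed to be equivalent to exporting it in the shell:

import os

# Assumption: XPU_BLACK_LIST must be set before the ops are dispatched;
# exporting it in the shell before launching the process also works.
os.environ['XPU_BLACK_LIST'] = 'uniform_random,gaussian_random'

import paddle

paddle.set_device('xpu')
paddle.seed(10)
# With the ops black-listed on XPU they fall back to the CPU kernels, so the
# output can be compared against a pure-CPU run that uses the same seed.
x = paddle.uniform([4, 4], min=-1.0, max=1.0)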
34 changes: 34 additions & 0 deletions test/xpu/test_gaussian_random_op_xpu.py
@@ -26,8 +26,23 @@
from paddle import base

paddle.enable_static()
from paddle.base import core
from paddle.tensor import random

typeid_dict = {
'int32': int(core.VarDesc.VarType.INT32),
'int64': int(core.VarDesc.VarType.INT64),
'float32': int(core.VarDesc.VarType.FP32),
'float16': int(core.VarDesc.VarType.FP16),
'bfloat16': int(core.VarDesc.VarType.BF16),
'bool': int(core.VarDesc.VarType.BOOL),
'int8': int(core.VarDesc.VarType.INT8),
'uint8': int(core.VarDesc.VarType.UINT8),
'float64': int(core.VarDesc.VarType.FP64),
}

from op_test import convert_uint16_to_float


class XPUTestGaussianRandomOp(XPUOpTestWrapper):
def __init__(self):
@@ -52,6 +67,7 @@ def setUp(self):
"std": self.std,
"seed": 10,
"use_mkldnn": self.use_mkldnn,
"dtype": typeid_dict[self.in_type_str],
}
paddle.seed(10)

@@ -67,6 +83,10 @@ def test_check_output(self):
)

def verify_output(self, outs):
# special for bf16
if self.in_type_str == "bfloat16":
outs = convert_uint16_to_float(outs)

self.assertEqual(outs[0].shape, (123, 92))
hist, _ = np.histogram(outs[0], range=(-3, 5))
hist = hist.astype("float32")
@@ -100,6 +120,7 @@ def setUp(self):
'std': self.std,
'seed': self.seed,
'use_mkldnn': self.use_mkldnn,
"dtype": typeid_dict[self.in_type_str],
}

self.inputs = {"ShapeTensorList": shape_tensor_list}
@@ -165,6 +186,7 @@ def setUp(self):
'std': self.std,
'seed': self.seed,
'use_mkldnn': self.use_mkldnn,
"dtype": typeid_dict[self.in_type_str],
}
self.outputs = {'Out': np.zeros((123, 92), dtype=self.dtype)}

@@ -265,6 +287,11 @@ def test_default_fp16():
out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16)

def test_default_bf16():
paddle.framework.set_default_dtype('bfloat16')
out = paddle.tensor.random.gaussian([2, 3])
self.assertEqual(out.dtype, base.core.VarDesc.VarType.BF16)

def test_default_fp32():
paddle.framework.set_default_dtype('float32')
out = paddle.tensor.random.gaussian([2, 3])
@@ -278,6 +305,7 @@ def test_default_fp64():
test_default_fp64()
test_default_fp32()
test_default_fp16()
test_default_bf16()

paddle.enable_static()

@@ -291,6 +319,11 @@ def test_default_fp16():
out = paddle.tensor.random.standard_normal([2, 3])
self.assertEqual(out.dtype, base.core.VarDesc.VarType.FP16)

def test_default_bf16():
paddle.framework.set_default_dtype('bfloat16')
out = paddle.tensor.random.standard_normal([2, 3])
self.assertEqual(out.dtype, base.core.VarDesc.VarType.BF16)

def test_default_fp32():
paddle.framework.set_default_dtype('float32')
out = paddle.tensor.random.standard_normal([2, 3])
@@ -304,6 +337,7 @@ def test_default_fp64():
test_default_fp64()
test_default_fp32()
test_default_fp16()
test_default_bf16()

paddle.enable_static()

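For bfloat16 the op returns raw uint16 payloads, which is why verify_output converts them with convert_uint16_to_float from op_test before building the histogram. A rough sketch of the idea behind that helper (not the exact implementation), assuming the standard bfloat16 layout as the top 16 bits of a float32:

import numpy as np

def convert_uint16_to_float(arr):
    # bfloat16 stores the upper 16 bits of a float32, so shifting each
    # value into the high half of a uint32 and reinterpreting the bytes
    # recovers the float32 value.
    arr = np.ascontiguousarray(arr, dtype=np.uint16)
    return (arr.astype(np.uint32) << 16).view(np.float32)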
99 changes: 82 additions & 17 deletions test/xpu/test_uniform_random_op_xpu.py
@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,32 +16,97 @@
import unittest

import numpy as np
from test_uniform_random_op import (
TestUniformRandomOp,
TestUniformRandomOpSelectedRows,
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest

import paddle

paddle.enable_static()
from paddle.base import core

typeid_dict = {
'int32': int(core.VarDesc.VarType.INT32),
'int64': int(core.VarDesc.VarType.INT64),
'float32': int(core.VarDesc.VarType.FP32),
'float16': int(core.VarDesc.VarType.FP16),
'bfloat16': int(core.VarDesc.VarType.BF16),
'bool': int(core.VarDesc.VarType.BOOL),
'int8': int(core.VarDesc.VarType.INT8),
'uint8': int(core.VarDesc.VarType.UINT8),
'float64': int(core.VarDesc.VarType.FP64),
}

class TestXPUUniformRandomOp(TestUniformRandomOp):
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
outs = self.calc_output(place)
outs = [np.array(out) for out in outs]
outs.sort(key=len)
self.verify_output(outs)

def output_hist(out):
if out.dtype == np.uint16:
out = convert_uint16_to_float(out)
hist, _ = np.histogram(out, range=(-5, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones(10)
return hist, prob

class TestXPUUniformRandomOpSelectedRows(TestUniformRandomOpSelectedRows):
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_with_place(place)

from op_test import convert_uint16_to_float


class XPUTestUniformRandomOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'uniform_random'
self.use_dynamic_create_class = False

class TestUniformRandomOp(XPUOpTest):
def init(self):
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.op_type = "uniform_random"
self.python_api = paddle.uniform

def setUp(self):
self.init()
self.inputs = {}
self.use_mkldnn = False
self.set_attrs()
paddle.seed(10)

self.outputs = {"Out": np.zeros((1000, 784), dtype=self.dtype)}

def set_attrs(self):
self.attrs = {
"shape": [1000, 784],
"min": -5.0,
"max": 10.0,
"dtype": typeid_dict[self.in_type_str],
}
self.output_hist = output_hist

def test_check_output(self):
self.check_output_with_place_customized(
self.verify_output, self.place
)

def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
np.testing.assert_allclose(hist, prob, rtol=0, atol=0.01)

class TestMaxMinAreInt(TestUniformRandomOp):
def set_attrs(self):
self.attrs = {
"shape": [1000, 784],
"min": -5,
"max": 10,
"dtype": typeid_dict[self.in_type_str],
}
self.output_hist = output_hist


support_types = get_xpu_op_support_types('uniform_random')
for stype in support_types:
create_test_class(globals(), XPUTestUniformRandomOp, stype)

if __name__ == "__main__":
unittest.main()
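The histogram check above can also be reproduced outside the OpTest harness in dynamic mode. A standalone sketch of the same statistical check, assuming an XPU build of Paddle:

import numpy as np
import paddle

paddle.set_device('xpu')
paddle.seed(10)

x = paddle.uniform([1000, 784], dtype='bfloat16', min=-5.0, max=10.0)
vals = x.astype('float32').numpy()

# As in output_hist: each of 10 equal-width bins over [-5, 10) should hold
# roughly 10% of the samples.
hist, _ = np.histogram(vals, range=(-5, 10))
hist = hist.astype('float32') / vals.size
np.testing.assert_allclose(hist, 0.1 * np.ones(10), rtol=0, atol=0.01)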