[bf16] Refine BF16 amp-o1 logic (#39815)
* refine bf16 amp-o1 logic

* refine amp GLOG

* refine unittest

* refine unittest
zhangbo9674 authored Feb 28, 2022
1 parent d1595c2 commit 18ee051
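
The "refine amp GLOG" item refers to the reworked VLOG(5) messages in tracer.cc below. Paddle's verbose logs are controlled by the GLOG_v environment variable; a minimal sketch of turning on the level-5 messages touched here (the os.environ route is an assumption that the variable is read before the first paddle import; exporting GLOG_v=5 in the shell works as well):

import os
os.environ['GLOG_v'] = '5'  # enable VLOG levels up to 5, which covers the AMP messages below
import paddle  # import after setting the variable so logging picks it up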
Showing 4 changed files with 45 additions and 12 deletions.
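
For context, the user-facing entry point exercised by this change is paddle.amp.auto_cast with level='O1' and dtype='bfloat16'. A minimal usage sketch, assuming a Paddle build and device with BF16 support (the shapes and Conv2D layer are taken from the updated unittest below; the resulting dtype is an expectation, not a guarantee on unsupported hardware):

import paddle

paddle.seed(100)
x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = paddle.nn.Conv2D(4, 6, (3, 3))

# Under AMP-O1, white-listed ops (conv2d, and now matmul_v2) run in bfloat16,
# black-listed ops stay in float32, and remaining ops are promoted from their inputs.
with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
    y = conv(x)
print(y.dtype)  # expected: bfloat16 when the conv2d bf16 kernel is available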
30 changes: 26 additions & 4 deletions paddle/fluid/imperative/amp_auto_cast.cc
@@ -273,8 +273,9 @@ static inline std::shared_ptr<VarType> CastToBF16(

 template <typename VarType>
 static inline framework::proto::VarType::Type GetPromoteType(
-    const std::string& op_type, const NameVarMap<VarType>& ins) {
-  auto dst_type = framework::proto::VarType::FP16;
+    const std::string& op_type, const NameVarMap<VarType>& ins,
+    const framework::proto::VarType::Type amp_dtype) {
+  auto dst_type = amp_dtype;
   for (const auto& pair : ins) {
     for (const auto& var : pair.second) {
       if (GetDataType<VarType>(var) == framework::proto::VarType::FP32) {
@@ -337,7 +338,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
     }
     return new_ins;
   } else {
-    auto dst_type = GetPromoteType<VarType>(op_type, ins);
+    auto dst_type =
+        GetPromoteType<VarType>(op_type, ins, framework::proto::VarType::FP16);

     // NOTE(zhiqiu): if the op does not have an fp16 kernel, fall back to fp32.
     if (dst_type == framework::proto::VarType::FP16 &&
@@ -435,7 +437,7 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
       }
     }
     return new_ins;
-  } else {
+  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
     for (auto& pair : new_ins) {
       VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
               << GetDtypeStr(*pair.second.cbegin()) << " to float";
@@ -444,6 +446,26 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
       }
     }
     return new_ins;
+  } else {
+    auto dst_type =
+        GetPromoteType<VarType>(op_type, ins, framework::proto::VarType::BF16);
+    // NOTE(zhangbo): if the op does not have a bf16 kernel, fall back to fp32.
+    if (dst_type == framework::proto::VarType::BF16 &&
+        AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(
+            op_type)) {
+      dst_type = framework::proto::VarType::FP32;
+    }
+    for (auto& pair : new_ins) {
+      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
+              << GetDtypeStr(*pair.second.cbegin()) << " to "
+              << framework::DataTypeToString(dst_type);
+      for (auto& var : pair.second) {
+        var = (dst_type == framework::proto::VarType::FP32
+                   ? CastToFP32<VarType>(var)
+                   : CastToBF16<VarType>(var));
+      }
+    }
+    return new_ins;
   }
   return new_ins;
 }
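
The new else branch above gives BF16 AMP-O1 the same three-way policy the FP16 path already had. Below is an illustrative Python sketch of that decision flow; the helper and list names are hypothetical stand-ins for the op sets kept in AmpOperators, and the promotion rule mirrors GetPromoteType as shown above:

# Hypothetical sketch of the O1 BF16 casting policy implemented above.
def bf16_o1_dst_dtype(op_type, input_dtypes, white_list, black_list,
                      unsupported_bf16_ops):
    if op_type in white_list:    # allow list: cast inputs to bf16
        return 'bfloat16'
    if op_type in black_list:    # block list: cast inputs to float32
        return 'float32'
    # Gray zone: start from the amp dtype (bf16 here) and promote to float32
    # if any input is already float32.
    dst = 'float32' if 'float32' in input_dtypes else 'bfloat16'
    # Ops without a bf16 kernel also fall back to float32
    # (GetMutableUnsupportedBf16Ops in the C++ code).
    if dst == 'bfloat16' and op_type in unsupported_bf16_ops:
        dst = 'float32'
    return dst

Threading amp_dtype through GetPromoteType lets the FP16 and BF16 O1 paths share one promotion routine instead of duplicating it.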
6 changes: 4 additions & 2 deletions paddle/fluid/imperative/tracer.cc
@@ -205,17 +205,19 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,

   NameVarMap<VarType> new_ins = ins;
   if (amp_level_ == AmpLevel::O1) {
-    VLOG(5) << "Auto mixed precision run operator: " << type;
     if (amp_dtype_ == phi::DataType::FLOAT16) {
+      VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
       new_ins = AutoCastInputs<VarType>(type, ins);
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
+      VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
       new_ins = AutoCastBF16Inputs<VarType>(type, ins);
     }
   } else if (amp_level_ == AmpLevel::O2) {
-    VLOG(5) << "Pure fp16 run operator: " << type;
     if (amp_dtype_ == phi::DataType::FLOAT16) {
+      VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
       new_ins = CastPureFp16Inputs<VarType>(type, ins);
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
+      VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
       new_ins = CastPureBf16Inputs<VarType>(type, ins);
     }
   }
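
The reworked VLOG messages make the four (level, dtype) branches explicit. For reference, a hypothetical Python summary of which cast routine each branch now calls (routine names are the C++ functions shown above):

# Hypothetical summary of the dispatch in Tracer::TraceOp.
CAST_ROUTINES = {
    ('O1', 'float16'): 'AutoCastInputs',        # FP16 AMP-O1
    ('O1', 'bfloat16'): 'AutoCastBF16Inputs',   # BF16 AMP-O1, refined in this commit
    ('O2', 'float16'): 'CastPureFp16Inputs',    # pure FP16
    ('O2', 'bfloat16'): 'CastPureBf16Inputs',   # pure BF16
}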
2 changes: 1 addition & 1 deletion python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -75,7 +75,7 @@
     'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
 }

-BF16_WHITE_LIST = {'conv2d'}
+BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
 BF16_BLACK_LIST = {' '}

 _g_amp_state_ = None
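
With matmul_v2 added to BF16_WHITE_LIST, a matmul executed under O1 bfloat16 autocast should now have its inputs cast to bf16. A minimal sketch, assuming hardware and kernels with BF16 support (the resulting dtype is an expectation under that assumption):

import paddle

a = paddle.rand([4, 8], dtype='float32')
b = paddle.rand([8, 16], dtype='float32')
with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
    c = paddle.matmul(a, b)  # dispatches to matmul_v2, which is now white-listed
print(c.dtype)  # expected: bfloat16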
@@ -1131,20 +1131,29 @@ class TestBf16(unittest.TestCase):
     test amp for BF16
     '''

-    def train(self, enable_amp=True):
+    def train(self, enable_amp=True, amp_level='O1'):
         paddle.seed(100)
         input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
         conv = paddle.nn.Conv2D(4, 6, (3, 3))
         with paddle.amp.auto_cast(
-                enable=enable_amp, level='O2', dtype='bfloat16'):
+                enable=enable_amp, level=amp_level, dtype='bfloat16'):
             output = conv(input)
         output = output.cast('float32')
         return output.numpy()

     def test_bf16(self):
-        out_fp32 = self.train(enable_amp=False)
-        out_bf16 = self.train(enable_amp=True)
-        self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-1))
+        if fluid.core.is_compiled_with_cuda():
+            cudnn_version = paddle.device.get_cudnn_version()
+            if cudnn_version is not None and cudnn_version >= 8100:
+                out_fp32 = self.train(enable_amp=False)
+                out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+                out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+                self.assertTrue(
+                    np.allclose(
+                        out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+                self.assertTrue(
+                    np.allclose(
+                        out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))


 class TestPyLayerWithAmp(unittest.TestCase):
