Commit db407bf
[AMP] Allow to switch whether to use promote strategy to choose kernel for O2 training. (#53742)

* Allow to switch whether to use promote strategy to choose kernel for O2 training.

* Fix comparing error and add unittest.
Xreki authored May 16, 2023
1 parent 2a94b81 commit db407bf
Showing 11 changed files with 234 additions and 34 deletions.
37 changes: 27 additions & 10 deletions paddle/fluid/eager/amp_utils.h
@@ -120,18 +120,35 @@ inline phi::DataType GetAmpDestDtype(
egr::Controller::Instance().GetCurrentTracer()->GetAmpPhiDtype();
auto dst_type = amp_setting_dtype;

if (paddle::imperative::AmpOperators::Instance().GetMutableAllowOps()->count(
op_name)) {
dst_type = amp_setting_dtype;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
} else {
if (amp_level == paddle::imperative::AmpLevel::OD) {
bool use_promote = true;
if (amp_level == paddle::imperative::AmpLevel::O2) {
use_promote =
egr::Controller::Instance().GetCurrentTracer()->GetUsePromote();
}

if (use_promote) {
if (paddle::imperative::AmpOperators::Instance()
.GetMutableAllowOps()
->count(op_name)) {
dst_type = amp_setting_dtype;
} else if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type = GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
if (amp_level == paddle::imperative::AmpLevel::OD) {
dst_type = phi::DataType::FLOAT32;
} else {
dst_type =
GetPromoteType(op_name, amp_tensors_vector, amp_setting_dtype);
}
}
} else {
// use_promote can be set to false only for O2 training.
if (paddle::imperative::AmpOperators::Instance()
.GetMutableBlockOps()
->count(op_name)) {
dst_type = phi::DataType::FLOAT32;
}
}

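Read as pseudocode, the new kernel-dtype selection above behaves as follows. This is a hedged Python sketch for illustration only, not the literal implementation; `in_allow_list`, `in_block_list`, and `promote_type` stand in for the AmpOperators lookups and GetPromoteType.

    def choose_dst_dtype(op_name, amp_level, amp_setting_dtype, use_promote,
                         in_allow_list, in_block_list, promote_type):
        # use_promote can only be disabled for O2 training; other levels always promote.
        if amp_level != "O2":
            use_promote = True
        if use_promote:
            if in_allow_list(op_name):       # white-listed ops keep the low-precision dtype
                return amp_setting_dtype
            if in_block_list(op_name):       # black-listed ops always run in float32
                return "float32"
            if amp_level == "OD":            # OD: everything outside the white list stays float32
                return "float32"
            return promote_type(op_name)     # otherwise promote based on the input dtypes
        # Promote disabled: only the black list forces float32; every other op keeps
        # the configured low-precision dtype.
        return "float32" if in_block_list(op_name) else amp_setting_dtype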
3 changes: 3 additions & 0 deletions paddle/fluid/eager/api/utils/global_utils.h
@@ -58,6 +58,9 @@ class Controller {
return tracer_->GetAmpLevel();
}

void SetUsePromote(bool use_promote) { tracer_->SetUsePromote(use_promote); }
bool GetUsePromote() const { return tracer_->GetUsePromote(); }

bool UseLayoutAutoTune() {
bool use_autotune = false;
#if defined(PADDLE_WITH_CUDA)
2 changes: 2 additions & 0 deletions paddle/fluid/imperative/tracer.cc
@@ -44,6 +44,8 @@ thread_local bool Tracer::enable_program_desc_tracing_ = false;

thread_local bool Tracer::has_grad_ = true;

thread_local bool Tracer::use_promote_ = true;

thread_local bool Tracer::use_layout_autotune_ = false;

thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;
8 changes: 8 additions & 0 deletions paddle/fluid/imperative/tracer.h
@@ -156,6 +156,13 @@ class Tracer {

void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }

void SetUsePromote(bool use_promote) {
VLOG(4) << "set use_promote to " << use_promote;
use_promote_ = use_promote;
}

bool GetUsePromote() const { return use_promote_; }

void SetAmpLevel(AmpLevel level) {
VLOG(4) << "set amp_level to " << static_cast<unsigned int>(level);
amp_level_ = level;
@@ -220,6 +227,7 @@ class Tracer {
static thread_local bool enable_program_desc_tracing_;
static thread_local bool use_layout_autotune_;
static thread_local bool has_grad_;
static thread_local bool use_promote_;
static thread_local AmpLevel amp_level_;
static thread_local phi::DataType amp_dtype_;
};
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/imperative.cc
@@ -2156,6 +2156,9 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_use_promote",
&imperative::Tracer::GetUsePromote,
&imperative::Tracer::SetUsePromote)
.def_property("_amp_level",
&imperative::Tracer::GetAmpLevel,
&imperative::Tracer::SetAmpLevel)
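With this binding in place, the flag can also be inspected or toggled directly from Python through the current dygraph tracer. This is a hedged sketch; `_dygraph_tracer()` is Paddle's internal accessor, and in practice the flag is set through `paddle.amp.auto_cast(..., use_promote=...)` rather than by hand.

    from paddle.fluid.framework import _dygraph_tracer

    tracer = _dygraph_tracer()       # the thread-local imperative tracer
    print(tracer._use_promote)       # True by default (see tracer.cc above)
    tracer._use_promote = False      # disable promote-based kernel selection for O2
    tracer._use_promote = True       # restore the default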
14 changes: 13 additions & 1 deletion python/paddle/amp/auto_cast.py
@@ -274,6 +274,7 @@ def amp_guard(
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
@@ -438,6 +439,11 @@ def master_grad_hook():
original_amp_dtype = tracer._amp_dtype
tracer._amp_dtype = amp_dtype

# switch promote
if amp_level == AMP_LEVEL.O2:
original_use_promote = tracer._use_promote
tracer._use_promote = use_promote

# restore status
try:
yield
@@ -448,6 +454,8 @@ def master_grad_hook():
tracer._set_amp_op_list(original_white_list, original_black_list)
# set_flags(original_flags)
tracer._amp_dtype = original_amp_dtype
if amp_level == AMP_LEVEL.O2:
tracer._use_promote = original_use_promote
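The switch/restore logic above follows the usual guard pattern: save the tracer's current setting, apply the requested one, and restore the original in a finally block so it survives exceptions. A minimal standalone sketch of the same idea, assuming only an object with a `_use_promote` attribute:

    import contextlib

    @contextlib.contextmanager
    def promote_guard(tracer, use_promote=True):
        original = tracer._use_promote      # remember the caller's setting
        tracer._use_promote = use_promote
        try:
            yield
        finally:
            tracer._use_promote = original  # always restore, even on error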


class StateDictHook:
@@ -641,6 +649,7 @@ def auto_cast(
custom_black_list=None,
level='O1',
dtype='float16',
use_promote=True,
):
"""
Create a context which enables auto-mixed-precision(AMP) of operators executed in dynamic graph mode.
@@ -663,6 +672,7 @@ def auto_cast(
will be converted to float16/bfloat16, and that have any float32 input will be converted to float32. For the OD level, operators in
default white list will compute in float16/bfloat16, and the others will compute in float32. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
use_promote(bool, optional): Whether to promotes to fp32 when op has any float32 inputs. It is only supported when amp level is O2. Default is True.
Examples:
@@ -696,7 +706,9 @@ def auto_cast(
print(d.dtype) # paddle.float16
"""
return amp_guard(enable, custom_white_list, custom_black_list, level, dtype)
return amp_guard(
enable, custom_white_list, custom_black_list, level, dtype, use_promote
)


def decorate(
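An end-to-end usage sketch of the new argument follows. It is illustrative only: the model and data are made up, and running it in float16 requires a device with half-precision support.

    import paddle

    model = paddle.nn.Linear(4, 4)
    optimizer = paddle.optimizer.SGD(parameters=model.parameters())
    model, optimizer = paddle.amp.decorate(
        models=model, optimizers=optimizer, level='O2', dtype='float16'
    )

    x = paddle.rand([2, 4])
    # With use_promote=False, ops outside the black list keep float16 kernels even if
    # some inputs are float32; with the default use_promote=True, such ops are promoted
    # to float32.
    with paddle.amp.auto_cast(level='O2', dtype='float16', use_promote=False):
        out = model(x)
    print(out.dtype)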
69 changes: 57 additions & 12 deletions test/amp/amp_base_models.py
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np

import paddle
from paddle import nn
from paddle.fluid import core
from paddle.fluid.framework import _non_static_mode

_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")

@@ -30,20 +32,27 @@ def _build_optimizer(
amp_lists=None,
use_grad_clip=False,
use_promote=False,
model=None,
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
else:
grad_clip = None
if _non_static_mode():
assert model is not None
parameters = model.parameters()
else:
parameters = None
optimizer = paddle.optimizer.AdamW(
learning_rate=0.01,
parameters=parameters,
grad_clip=grad_clip,
beta1=0.78,
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
)
if use_amp:
if not _non_static_mode() and use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
amp_lists,
@@ -118,7 +127,7 @@ def __init__(self):

def forward(self, x):
out = self.conv(x)
out = nn.functional.relu(out)
out = nn.functional.relu(out.cast("float32"))
out = out.flatten(start_axis=1, stop_axis=3)
out = self.linear(out)
out = nn.functional.softmax(out)
@@ -128,6 +137,22 @@ def forward(self, x):
def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
if _non_static_mode():
model = SimpleConvNet()
optimizer = _build_optimizer(use_amp=False, model=model)
if use_amp and amp_dtype == "float16":
scaler = paddle.amp.GradScaler()
else:
scaler = None
if use_amp and amp_level == "O2":
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level=amp_level,
dtype=amp_dtype,
)
return model, optimizer, scaler

main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
@@ -237,19 +262,36 @@ def setUp(self):
self.amp_level = None

def _check_op_calls(
self, op_stats_dict, expected_bf16_calls={}, expected_fp16_calls={}
self,
op_stats_dict,
expected_bf16_calls={},
expected_fp16_calls={},
debug_info=None,
):
for op_type, value in expected_bf16_calls.items():
def _extract_op_call(op_calls_str, pos):
return int(copy.copy(op_calls_str).split(",")[pos])

for op_type, expected_value in expected_bf16_calls.items():
# print(f"[BF16] op_type={op_type}, value={value}")
if isinstance(op_stats_dict[op_type], str):
actual_value = _extract_op_call(op_stats_dict[op_type], 1)
else:
actual_value = op_stats_dict[op_type].bf16_calls
self.assertEqual(
op_stats_dict[op_type].bf16_calls,
value,
f"The number of bf16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].bf16_calls}.",
actual_value,
expected_value,
f"[{debug_info}] The number of bf16 calls of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.",
)
for op_type, value in expected_fp16_calls.items():
for op_type, expected_value in expected_fp16_calls.items():
# print(f"[FP16] op_type={op_type}, value={value}")
if isinstance(op_stats_dict[op_type], str):
actual_value = _extract_op_call(op_stats_dict[op_type], 0)
else:
actual_value = op_stats_dict[op_type].fp16_calls
self.assertEqual(
op_stats_dict[op_type].fp16_calls,
value,
f"The number of fp16 calls of operator < {op_type} > is expected to be {value}, but recieved {op_stats_dict[op_type].fp16_calls}.",
actual_value,
expected_value,
f"[debug_info] The number of fp16 calls of operator < {op_type} > is expected to be {expected_value}, but recieved {actual_value}.",
)
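For reference, the string form handled above packs the per-dtype call counts into comma-separated fields, with the fp16 count at position 0 and the bf16 count at position 1, as used by the two loops. The concrete value below is hypothetical and not taken from the PR.

    op_calls_str = "12,0,3,40"                      # hypothetical: fp16, bf16, ... call counts
    assert int(op_calls_str.split(",")[0]) == 12    # fp16 calls (pos=0)
    assert int(op_calls_str.split(",")[1]) == 0     # bf16 calls (pos=1)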

def run_program(
@@ -263,6 +305,7 @@ def run_program(
exe,
x_np,
max_iters,
dtype,
level,
):
losses = []
@@ -277,6 +320,8 @@ def run_program(
feed={feed_vars[0].name: x_np},
fetch_list=fetch_vars,
)
print(f"-- [BF16 {level}] iter={iter_id}, loss={results[0]}")
print(
f"-- [AMP {dtype} {level}] iter={iter_id}, loss={results[0]}"
)
losses.append(results[0])
return losses
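For the dynamic-graph branch added to build_conv_model above, a typical training step with the returned model/optimizer/scaler looks roughly like the sketch below. It mirrors Paddle's documented GradScaler flow and is not code from this PR.

    import paddle

    def dygraph_train_step(model, optimizer, scaler, x, level="O2", dtype="float16"):
        # Forward pass under AMP; use_promote only has an effect at level O2.
        with paddle.amp.auto_cast(level=level, dtype=dtype, use_promote=True):
            loss = model(x).mean()
        if scaler is not None:          # float16 path: scale the loss to avoid underflow
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:                           # bfloat16 path: no loss scaling needed
            loss.backward()
            optimizer.step()
        optimizer.clear_grad()
        return loss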
12 changes: 7 additions & 5 deletions test/amp/test_amp_api.py
@@ -20,15 +20,17 @@


class TestAutoCast(AmpTestBase):
def test_amp_OD_level(self):
conv = paddle.nn.Conv2D(
def setUp(self):
self._conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
linear = paddle.nn.Linear(in_features=4, out_features=4)
self._linear = paddle.nn.Linear(in_features=4, out_features=4)

def test_amp_OD_level(self):
with paddle.amp.auto_cast(level='OD'):
out1 = conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out1 = self._conv(paddle.rand(shape=[1, 1, 6, 6], dtype='float32'))
out2 = out1 + paddle.rand(shape=out1.shape, dtype='float16')
out3 = linear(out2)
out3 = self._linear(out2)

self.assertEqual(out1.dtype, paddle.float16)
self.assertEqual(out2.dtype, paddle.float32)
1 change: 1 addition & 0 deletions test/amp/test_amp_o2_embedding_model.py
@@ -131,6 +131,7 @@ def _run(place, exe, x_np, max_iters, level):
exe,
x_np,
max_iters,
"float16",
level,
)
return losses
(Diffs for the remaining 2 of the 11 changed files are not shown here.)
