[AMP] Master grad in static graph #53362
@@ -277,6 +277,11 @@ def __init__(
        self._auxiliary_vars = {}
        self._already_create_accumulater = set()

        # master gradients
        self._already_create_master_grad = set()
        self._master_grads = {}
        self._master_grad = False
Review comment: How about wrapping this in a helper function?
Author reply: Done.

    def _set_auxiliary_var(self, key, val):
        self._auxiliary_vars[key] = val

@@ -671,6 +676,24 @@ def _create_master_weight(self, param):
            self._master_weights[param.name] = var
        return var

    def _create_master_grad(self, grad):
        if grad.name in self._master_grads:
            var = self._master_grads[grad.name]
        else:
            var_name = grad.name + "_fp32_master"
Review comment: Should there be a check for … here?
Reply: The grad dtype is already checked at the call site of this function; does it need to be checked again here as well?
Review comment: Add a ….
            var_name = unique_name.generate(var_name)
            var = grad.block.create_var(
                name=var_name,
                shape=grad.shape,
                value=0,
                dtype='float32',
                lod_level=grad.lod_level,
                persistable=grad.persistable,
                is_data=grad.is_data,
            )
            self._master_grads[grad.name] = var
        return var

    def _create_accumulators(self, block, parameters):
        """Create all accumulators needed by the parameters

@@ -1139,6 +1162,59 @@ def backward(
            self._append_dgc_ops(params_grads)
        return params_grads

    def _append_cast_to_master_grad_op(self, param_grads):
        """
        Add ops that cast each low-precision gradient to a master (float32) gradient.

        Args:
            param_grads (list(tuple(Tensor, Tensor))):
Review comment: Even though this function does not feed the auto-generated docs, the parameter and functionality descriptions do not follow the usual docstring format.
Author reply: Updated; please check whether it is correct now.
                A list of (parameter, gradient) pairs to update.

        Returns:
            params_master_grads:
                A list of (parameter, master_gradient) pairs.
                In the subsequent grad clip and optimizer steps, parameters can be updated by the master gradients.
                The cast ops are appended to main_prog before the grad clip ops.
        """

        if not self._master_grad:
            return param_grads

        global_block = framework.default_main_program().global_block()
        target_block = global_block
        current_block = framework.default_main_program().current_block()
        if current_block.idx != global_block.idx:
            target_block = framework.default_main_program().blocks[
                current_block.backward_block_idx
            ]

        start = len(target_block.ops)

        params_master_grads = []

        assert isinstance(target_block, framework.Block)
        # create master gradients and cast ops for fp16/bf16 gradients
        for p, g in param_grads:
            if g.name not in self._already_create_master_grad:
Review comment: Use … here.
Author reply: Now using ….
                if self._is_dtype_fp16_or_bf16(g.dtype):
                    master_g = self._create_master_grad(g)
                    params_master_grads.append((p, master_g))
                    self._already_create_master_grad.add(g.name)
                    target_block.append_op(
                        type="cast",
                        inputs={"X": [g]},
                        outputs={"Out": [master_g]},
                        attrs={
                            "in_dtype": g.dtype,
                            "out_dtype": master_g.dtype,
                        },
                    )
                else:
                    params_master_grads.append((p, g))

        return params_master_grads

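As the docstring notes, the cast ops run before gradient clipping. Purely as an illustration (not part of this PR), the sketch below shows why clipping against the float32 master gradient is safer than clipping the raw float16 gradient: squaring a moderately large fp16 value for a global-norm computation already overflows.

    import numpy as np

    g_fp16 = np.float16(300.0)
    print(g_fp16 * g_fp16)                          # inf: float16 overflows above ~65504
    print(np.float32(g_fp16) * np.float32(g_fp16))  # 90000.0: the float32 master copy stays finite
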
    def apply_gradients(self, params_grads):
        """
        Second part of `minimize`, appending optimization operators for

@@ -1170,9 +1246,10 @@ def apply_gradients(self, params_grads):

        # 'optimizer(grad_clip)' or 'set_gradient_clip'
        if self._grad_clip is not None:
            # create master gradients
            params_grads = self._append_cast_to_master_grad_op(params_grads)
Review comment: If there is no …, as I understand it, it is not only … that fails to take effect. This is written in the wrong place and should be moved to ….
            params_grads = self._grad_clip(params_grads)
        else:
            params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)

        # Add regularization if any

@@ -97,6 +97,7 @@ def __init__(
        incr_ratio,
        decr_ratio,
        use_amp_guard=None,
        use_master_grad=False,
        use_promote=False,
    ):
        self._optimizer = optimizer

@@ -105,6 +106,7 @@ def __init__(
        self._train_program = None

        self._is_distributed = False
        self._use_master_grad = False
        self._scaled_loss = None
        self._loss_scaling = None
        self._init_loss_scaling = init_loss_scaling

@@ -123,6 +125,9 @@ def __init__(
        self._learning_rate = optimizer._learning_rate
        self._learning_rate_map = optimizer._learning_rate_map
        self._use_pure_fp16 = level == "O2"
        if self._use_pure_fp16 and (dtype == "bfloat16" or dtype == "float16"):
            self._use_master_grad = use_master_grad
            self._optimizer._master_grad = use_master_grad
        self._use_fp16_guard = use_amp_guard
        self._to_fp16_var_names = None
        if self._use_dynamic_loss_scaling:

@@ -657,6 +662,7 @@ def decorate(
    use_pure_fp16=False,
    use_fp16_guard=None,
    use_bf16=False,
    use_master_grad=False,
Review comment: No need to add it to this function; it will be deprecated later.
Author reply: Removed.
    use_promote=False,
):
    """

@@ -770,6 +776,7 @@ def run_example_code():
        incr_ratio=incr_ratio,
        decr_ratio=decr_ratio,
        use_amp_guard=use_fp16_guard,
        use_master_grad=use_master_grad,
        use_promote=use_promote,
    )

@@ -791,6 +798,7 @@ def decorate(
    use_dynamic_loss_scaling=None,
    use_amp_guard=False,
    use_promote=False,
    use_master_grad=False,
Review comment: Add it after line 792, with the parameter in the form ….
Author reply: Done.
):
    """
    Decorate the given optimizer to adapt to the mixed-precision training.

@@ -904,6 +912,7 @@ def forward(self, x):
        decr_ratio=decr_ratio,
        use_amp_guard=use_amp_guard,
        use_promote=use_promote,
        use_master_grad=use_master_grad,
    )

    return mp_optimizer
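
For orientation, here is a minimal, hypothetical sketch of how the new flag could be enabled from user code at this point in the PR. It mirrors the test helper further below; the network here is a placeholder, and the parameter name/position may still change per the review comment about adding it after line 792.

    import paddle

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
        loss = paddle.mean(paddle.static.nn.fc(x, size=4))

        optimizer = paddle.optimizer.AdamW(learning_rate=0.01, multi_precision=True)
        # use_master_grad only takes effect for level="O2" with float16/bfloat16,
        # per the __init__ change above.
        optimizer = paddle.static.amp.decorate(
            optimizer,
            level="O2",
            dtype="float16",
            use_master_grad=True,
            use_promote=True,
        )
        optimizer.minimize(loss)
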
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import struct
import unittest

import numpy as np

@@ -20,6 +21,34 @@
from paddle import nn
from paddle.fluid import core


def copy_bits_from_float_to_uint16(f):
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16


def convert_float_to_uint16(in_list):
    if in_list.dtype == np.float32:
        new_output = []
        for x in np.nditer(in_list):
            new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
        new_output = np.reshape(new_output, in_list.shape).view(np.uint16)
        return new_output
    else:
        return in_list


def convert_uint16_to_float(in_list):
    if in_list.dtype == np.uint16:
        in_list = np.asarray(in_list)
        out = np.vectorize(
            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
            otypes=[np.float32],
        )(in_list.flat)
        return np.reshape(out, in_list.shape)
    else:
        return in_list

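A quick, hypothetical round-trip check for the two helpers above (not part of the test file): the bfloat16 conversion simply truncates the low 16 bits of the float32 mantissa, so the decoded values agree only to about two or three significant digits.

    import numpy as np

    x = np.random.random(size=[4, 4]).astype("float32")
    x_bf16 = convert_float_to_uint16(x)       # uint16 view of the bfloat16 bit pattern
    x_back = convert_uint16_to_float(x_bf16)  # decoded back to float32
    # Truncation bounds the relative error by roughly 2**-7, so this passes.
    np.testing.assert_allclose(x, x_back, rtol=1e-2)
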

_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")

@@ -30,6 +59,7 @@ def _build_optimizer(
    amp_lists=None,
    use_grad_clip=False,
    use_promote=False,
    use_master_grad=False,
):
    if use_grad_clip:
        grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)

@@ -42,14 +72,18 @@ def _build_optimizer(
        beta2=0.836,
        epsilon=1e-4,
        weight_decay=0.01,
        multi_precision=True,
Review comment: Don't add this here.
Author reply: Removed.
    )
    if use_amp:
        optimizer = paddle.static.amp.decorate(
            optimizer,
            amp_lists,
            level=amp_level,
            dtype=amp_dtype,
            use_master_grad=use_master_grad,
            use_promote=use_promote,
            master_weight=True,
            init_loss_scaling=1,
Review comment: …
Author reply: Removed.
        )
    return optimizer

@@ -71,6 +105,15 @@ def forward(self, x):
        return x + self.weight


def cast_add_param(amp_dtype):
    global _fixed_add_param
    if amp_dtype == "bfloat16":
        _fixed_add_param_bf16 = convert_float_to_uint16(_fixed_add_param)
        _fixed_add_param = convert_uint16_to_float(_fixed_add_param_bf16)
    else:
        pass


def build_add_model(
    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):

@@ -84,6 +127,7 @@ def build_add_model(
            x_dtype = "uint16"
        elif amp_dtype == "float16":
            x_dtype = "float16"
        cast_add_param(amp_dtype)
        model = SimpleAddNet(x_dtype)
        x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
        out = model(x)

@@ -152,8 +196,6 @@ def __init__(self):
        super().__init__()
        self.vocab_size = 128
        self.hidden_size = 16
        self.vocab_size = 128
        self.hidden_size = 16
        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
        self.linear = nn.Linear(in_features=16, out_features=10)

@@ -167,7 +209,11 @@ def forward(self, x):


def build_embedding_model(
    use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
    use_amp,
    amp_dtype="float16",
    amp_level="O1",
    use_promote=False,
    use_master_grad=False,
):
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()

@@ -177,16 +223,88 @@ def build_embedding_model(
            x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
            out = model(x)
            loss = paddle.mean(out)
            if use_amp:
                amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
                    custom_white_list=["elementwise_mul"],
                    custom_black_list=["reduce_mean"],
                    dtype=amp_dtype,
                )
            else:
                amp_lists = None
            optimizer = _build_optimizer(
                use_amp,
                amp_dtype,
                amp_level,
                None,
                amp_lists,
                True,
                use_promote=use_promote,
                use_master_grad=use_master_grad,
            )
            optimizer.minimize(loss)
    return main_program, startup_program

    feed_vars = [x]
    fetch_vars = [loss]
    return main_program, startup_program, optimizer, feed_vars, fetch_vars


class SimpleMLPNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear0 = paddle.nn.Linear(16, 10)
        self.linear1 = paddle.nn.Linear(10, 32)

    def forward(self, x):
        out = self.linear0(x)
        out = nn.functional.relu(out)
        out = self.linear1(out)
        out = nn.functional.dropout(out, p=0.2)
        return out


def build_MLP_model(
    use_amp,
    amp_dtype="float16",
    amp_level="O1",
    use_promote=False,
    use_master_grad=False,
):
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.utils.unique_name.guard():
        with paddle.static.program_guard(main_program, startup_program):
            model = SimpleMLPNet()
            x_dtype = "float32"
            if use_amp and amp_level == "O2":
                if amp_dtype == "bfloat16":
                    x_dtype = "uint16"
                elif amp_dtype == "float16":
                    x_dtype = "float16"
            x = paddle.static.data(name='x', shape=[None, 16], dtype=x_dtype)
            out = model(x)
            loss = paddle.mean(out)

            if use_amp:
                amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
                    custom_black_list=["reduce_mean"],
                    dtype=amp_dtype,
                )
            else:
                amp_lists = None

            optimizer = _build_optimizer(
                use_amp,
                amp_dtype,
                amp_level,
                amp_lists,
                True,
Review comment: A … is needed here.
Author reply: Added.
                use_promote=use_promote,
                use_master_grad=use_master_grad,
            )
            optimizer.minimize(loss)

    feed_vars = [x]
    fetch_vars = [loss]
    return main_program, startup_program, optimizer, feed_vars, fetch_vars


class SimpleWhileNet(nn.Layer):

Review comment: Are these lines necessary? As far as I can tell, they are already set in the base class.
Reply: adamw.__init__() does not call super().__init__(); is there a particular reason it was left out?