[AMP] Master grad in static graph #53362

Merged
29 commits merged on May 18, 2023

Changes from 21 commits

Commits (29)
3e71022
add master gradients on static graph
shaojiewang Apr 25, 2023
82207bd
merge remote develop
shaojiewang Apr 25, 2023
6659cca
make test program runnable
shaojiewang Apr 26, 2023
9af1239
add unit test for bf16 master grad static graph
shaojiewang May 8, 2023
12ab19a
remove some seeds in python unittest
shaojiewang May 8, 2023
f07d594
use float16 as v100 test dtype
shaojiewang May 8, 2023
ec319e0
only skip GPU which do not support bf16
shaojiewang May 8, 2023
4fc205a
1. add doc in api 2. remove non-used code in adamw
shaojiewang May 8, 2023
835e4c5
add _ before new func name
shaojiewang May 8, 2023
1799c74
Merge branch 'develop' into master_grad_in_static_graph
shaojiewang May 8, 2023
a00f347
use linear layer to test master grad
shaojiewang May 9, 2023
63b542e
refine modification when test not passed
shaojiewang May 9, 2023
2470da4
Merge branch 'develop' into master_grad_in_static_graph
shaojiewang May 9, 2023
d3d79fe
Merge branch 'develop' into master_grad_in_static_graph
shaojiewang May 10, 2023
8502900
remove non-used var
shaojiewang May 10, 2023
884fc7d
Merge branch 'develop' into master_grad_in_static_graph
shaojiewang May 10, 2023
ba24079
fix bug
shaojiewang May 10, 2023
171444b
remove unused code
shaojiewang May 11, 2023
ae96e97
add layernorm into test model
shaojiewang May 11, 2023
b9f7310
use fp16 as test type for master grad, because v100 do not run bf16 k…
shaojiewang May 11, 2023
c67eab1
add print in test run
shaojiewang May 11, 2023
e7623dd
1.push master grad creation before all optimizer ops; 2.remove useles…
shaojiewang May 15, 2023
d6a31b0
remove master weight in test
shaojiewang May 15, 2023
ac03317
fix unit test under cuda10-2
shaojiewang May 15, 2023
9a0e92d
merge remote develop and resolve conflict
shaojiewang May 16, 2023
9dcff9a
fix run_program caller error
shaojiewang May 16, 2023
aceffe0
merge remote develop and resolve conflict
shaojiewang May 16, 2023
7330555
Merge branch 'develop' into master_grad_in_static_graph
shaojiewang May 16, 2023
44004f5
fix amp api test failure
shaojiewang May 16, 2023
5 changes: 5 additions & 0 deletions python/paddle/optimizer/adamw.py
@@ -285,6 +285,11 @@ def __init__(
self._auxiliary_vars = {}
self._already_create_accumulater = set()

# master gradients
self._already_create_master_grad = set()
self._master_grads = {}
self._master_grad = False
Contributor:
Are these lines necessary? I see they are already set up in the base class.

Contributor Author:
adamw.__init__() does not call super().__init__(); was that left out deliberately for some reason?


def _set_auxiliary_var(self, key, val):
self._auxiliary_vars[key] = val

79 changes: 78 additions & 1 deletion python/paddle/optimizer/optimizer.py
@@ -277,6 +277,11 @@ def __init__(
self._auxiliary_vars = {}
self._already_create_accumulater = set()

# master gradients
self._already_create_master_grad = set()
self._master_grads = {}
self._master_grad = False
Contributor:
Let's define a function for this, e.g. create_master_grad_states.

Contributor Author:
Done.
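
As a rough illustration of the suggestion, the three fields could be grouped into a helper (a sketch only; the name create_master_grad_states follows the reviewer's wording and may not match the merged code):

def _create_master_grad_states(self):
    # Bookkeeping for fp32 master gradients, used when model gradients are
    # fp16/bf16: a gradient-name -> Variable map plus an on/off flag.
    self._already_create_master_grad = set()
    self._master_grads = {}
    self._master_grad = False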


def _set_auxiliary_var(self, key, val):
self._auxiliary_vars[key] = val

@@ -671,6 +676,24 @@ def _create_master_weight(self, param):
self._master_weights[param.name] = var
return var

def _create_master_grad(self, grad):
if grad.name in self._master_grads:
var = self._master_grads[grad.name]
else:
var_name = grad.name + "_fp32_master"
Contributor:
Should we check the dtype of grad here, or add an assert?

Contributor Author:
The dtype of grad is already checked at the call site; do we need to check it again here?

Contributor Author:
Will add an assert.

var_name = unique_name.generate(var_name)
var = grad.block.create_var(
name=var_name,
shape=grad.shape,
value=0,
dtype='float32',
lod_level=grad.lod_level,
persistable=grad.persistable,
is_data=grad.is_data,
)
self._master_grads[grad.name] = var
return var
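
Following the assert discussion above, the dtype guard might look roughly like this (a sketch; it reuses _is_dtype_fp16_or_bf16, which this class already provides, and may not match the merged code exactly):

def _create_master_grad(self, grad):
    # Only fp16/bf16 gradients should ever need an fp32 master copy.
    assert self._is_dtype_fp16_or_bf16(grad.dtype), (
        "Expected a float16 or bfloat16 gradient, got dtype %s for %s."
        % (grad.dtype, grad.name)
    )
    if grad.name in self._master_grads:
        return self._master_grads[grad.name]
    # ... otherwise create the fp32 master variable as shown above ...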

def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters

@@ -1139,6 +1162,59 @@ def backward(
self._append_dgc_ops(params_grads)
return params_grads

def _append_cast_to_master_grad_op(self, param_grads):
"""
Add ops to cast gradient to master gradient

Args:
param_grads(list(tuple(Tensor, Tensor))):
Contributor:
Even though this function's docstring is not auto-generated into the docs, the format of the argument and functionality descriptions doesn't follow the usual convention.

Contributor Author:
Fixed; please check whether the change is correct.

A list of (parameter, gradient) pair to update.

Returns:
params_master_grads:
A list of (parameter, master_gradient) pair.
In the following grad clip step and optimizer step, params can be updated by master gradient.
main_prog will also append cast ops before grad clip ops.

"""

if not self._master_grad:
return param_grads

global_block = framework.default_main_program().global_block()
target_block = global_block
current_block = framework.default_main_program().current_block()
if current_block.idx != global_block.idx:
target_block = framework.default_main_program().blocks[
current_block.backward_block_idx
]

start = len(target_block.ops)

params_master_grads = []

assert isinstance(target_block, framework.Block)
# create
for p, g in param_grads:
if g.name not in self._already_create_master_grad:
Contributor:
You could just check if g.name not in self._master_grads.keys() here; there is no need to keep a separate self._already_create_master_grad.

Contributor Author:
Changed to check with if g.name not in self._master_grads.keys().

if self._is_dtype_fp16_or_bf16(g.dtype):
master_g = self._create_master_grad(g)
params_master_grads.append((p, master_g))
self._already_create_master_grad.add(g.name)
target_block.append_op(
type="cast",
inputs={"X": [g]},
outputs={"Out": [master_g]},
attrs={
"in_dtype": g.dtype,
"out_dtype": master_g.dtype,
},
)
else:
params_master_grads.append((p, g))

return params_master_grads
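
To make the effect concrete, here is a schematic of what the function above does to one low-precision pair (illustrative variable names, not taken from a real program):

# Before: params_grads = [(linear_0.w_0, linear_0.w_0@GRAD)]  # grad is fp16/bf16
# One op is appended to the target block:
#   cast(X=linear_0.w_0@GRAD) -> linear_0.w_0@GRAD_fp32_master_0  (out_dtype=float32)
# After:  params_grads = [(linear_0.w_0, linear_0.w_0@GRAD_fp32_master_0)]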

def apply_gradients(self, params_grads):
"""
Second part of `minimize`, appending optimization operators for
@@ -1170,9 +1246,10 @@ def apply_gradients(self, params_grads):

# 'optimizer(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
# create master gradients
params_grads = self._append_cast_to_master_grad_op(params_grads)
Contributor:
If _grad_clip is not set, does master_grad still take effect? And are the params here master_weight, i.e. is the param used in the grad_clip computation the master_weight?

My understanding is that master_grad should be used not only inside grad_clip, but everywhere the grad is needed after backward.

Contributor Author:
"If _grad_clip is not set, does master_grad still take effect?"

No, it does not. The placement here is wrong; the call should be moved outside the if self._grad_clip is not None check. Will fix.

"Are the params here master_weight, i.e. is the param used in the grad_clip computation the master_weight?"

params is not master_weight, and grad_clip does not use the params argument. Should we change it to pass in tuples of (master_weight, master_grad)?

params_grads = self._grad_clip(params_grads)
else:

params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)

# Add regularization if any
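
Per the discussion above, hoisting the cast out of the grad-clip branch might look like this (a sketch of the reviewer's suggestion, not necessarily what was merged):

# Create fp32 master gradients unconditionally, so grad clip, regularization,
# and the optimizer ops below all consume fp32 gradients.
params_grads = self._append_cast_to_master_grad_op(params_grads)

if self._grad_clip is not None:
    params_grads = self._grad_clip(params_grads)
else:
    params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)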
9 changes: 9 additions & 0 deletions python/paddle/static/amp/decorator.py
@@ -97,6 +97,7 @@ def __init__(
incr_ratio,
decr_ratio,
use_amp_guard=None,
use_master_grad=False,
use_promote=False,
):
self._optimizer = optimizer
@@ -105,6 +106,7 @@ def __init__(
self._train_program = None

self._is_distributed = False
self._use_master_grad = False
self._scaled_loss = None
self._loss_scaling = None
self._init_loss_scaling = init_loss_scaling
@@ -123,6 +125,9 @@ def __init__(
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = level == "O2"
if self._use_pure_fp16 and (dtype == "bfloat16" or dtype == "float16"):
self._use_master_grad = use_master_grad
self._optimizer._master_grad = use_master_grad
self._use_fp16_guard = use_amp_guard
self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling:
@@ -657,6 +662,7 @@ def decorate(
use_pure_fp16=False,
use_fp16_guard=None,
use_bf16=False,
use_master_grad=False,
Contributor:
No need to add it to this function; it will be deprecated later.

Contributor Author:
Removed.

use_promote=False,
):
"""
@@ -770,6 +776,7 @@ def run_example_code():
incr_ratio=incr_ratio,
decr_ratio=decr_ratio,
use_amp_guard=use_fp16_guard,
use_master_grad=use_master_grad,
use_promote=use_promote,
)

@@ -791,6 +798,7 @@ def decorate(
use_dynamic_loss_scaling=None,
use_amp_guard=False,
use_promote=False,
use_master_grad=False,
Contributor:
Add it after line 792 instead, as master_grad=False, and add documentation for the parameter.

Contributor Author:
Done.

):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
@@ -904,6 +912,7 @@ def forward(self, x):
decr_ratio=decr_ratio,
use_amp_guard=use_amp_guard,
use_promote=use_promote,
use_master_grad=use_master_grad,
)

return mp_optimizer
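
For context, a minimal static-graph script exercising the new flag might look like the following (a sketch using the parameter names in this revision; the reviewer asked for use_master_grad to become master_grad, so the final spelling may differ):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    out = paddle.static.nn.fc(x, size=10)
    loss = paddle.mean(out)
    optimizer = paddle.optimizer.AdamW(learning_rate=1e-3)
    # Decorate for O2 bf16 training; use_master_grad keeps an fp32 copy of
    # every low-precision gradient for grad clip and the optimizer update.
    optimizer = paddle.static.amp.decorate(
        optimizer,
        level='O2',
        dtype='bfloat16',
        master_weight=True,
        use_master_grad=True,
    )
    optimizer.minimize(loss)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)
optimizer.amp_init(place)  # cast fp32 parameters for O2 execution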
128 changes: 123 additions & 5 deletions test/amp/amp_base_models.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import struct
import unittest

import numpy as np
@@ -20,6 +21,34 @@
from paddle import nn
from paddle.fluid import core


def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16


def convert_float_to_uint16(in_list):
if in_list.dtype == np.float32:
new_output = []
for x in np.nditer(in_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, in_list.shape).view(np.uint16)
return new_output
else:
return in_list


def convert_uint16_to_float(in_list):
if in_list.dtype == np.uint16:
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
otypes=[np.float32],
)(in_list.flat)
return np.reshape(out, in_list.shape)
else:
return in_list
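
These two helpers round-trip fp32 values through the bf16 bit pattern (the upper 16 bits of an fp32), which is how the tests emulate bf16 precision. A quick illustration, not part of the diff:

x = np.array([1.2345678, 3.1415927], dtype=np.float32)
x_bf16 = convert_float_to_uint16(x)       # keep only the upper 16 bits of each fp32
x_back = convert_uint16_to_float(x_bf16)  # expand back to fp32
# Truncating the mantissa to 7 bits loses precision:
# 1.2345678 -> 1.234375, 3.1415927 -> 3.140625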


_fixed_add_param = np.random.random(size=[16, 16]).astype("float32")


@@ -30,6 +59,7 @@ def _build_optimizer(
amp_lists=None,
use_grad_clip=False,
use_promote=False,
use_master_grad=False,
):
if use_grad_clip:
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
@@ -42,14 +72,18 @@
beta2=0.836,
epsilon=1e-4,
weight_decay=0.01,
multi_precision=True,
Contributor:
Don't add the multi_precision argument here; decorate already supports setting master_weight, and O2 training sets it to True automatically.

Contributor Author:
Removed.

)
if use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
amp_lists,
level=amp_level,
dtype=amp_dtype,
use_master_grad=use_master_grad,
use_promote=use_promote,
master_weight=True,
init_loss_scaling=1,
Contributor:
There is no need to set init_loss_scaling either; bfloat16 training sets it to 1 automatically.

Contributor Author:
Removed.

)
return optimizer

@@ -71,6 +105,15 @@ def forward(self, x):
return x + self.weight


def cast_add_param(amp_dtype):
global _fixed_add_param
if amp_dtype == "bfloat16":
_fixed_add_param_bf16 = convert_float_to_uint16(_fixed_add_param)
_fixed_add_param = convert_uint16_to_float(_fixed_add_param_bf16)
else:
pass


def build_add_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
Expand All @@ -84,6 +127,7 @@ def build_add_model(
x_dtype = "uint16"
elif amp_dtype == "float16":
x_dtype = "float16"
cast_add_param(amp_dtype)
model = SimpleAddNet(x_dtype)
x = paddle.static.data(name='input', shape=[16, 16], dtype=x_dtype)
out = model(x)
@@ -152,8 +196,6 @@ def __init__(self):
super().__init__()
self.vocab_size = 128
self.hidden_size = 16
self.vocab_size = 128
self.hidden_size = 16
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)
self.linear = nn.Linear(in_features=16, out_features=10)

@@ -167,7 +209,11 @@ def forward(self, x):


def build_embedding_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
use_amp,
amp_dtype="float16",
amp_level="O1",
use_promote=False,
use_master_grad=False,
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
@@ -177,16 +223,88 @@
x = paddle.static.data(name='x', shape=[None, 32], dtype='int64')
out = model(x)
loss = paddle.mean(out)
if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=["elementwise_mul"],
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
else:
amp_lists = None
optimizer = _build_optimizer(
use_amp,
amp_dtype,
amp_level,
None,
amp_lists,
True,
use_promote=use_promote,
use_master_grad=use_master_grad,
)
optimizer.minimize(loss)
return main_program, startup_program

feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars


class SimpleMLPNet(nn.Layer):
def __init__(self):
super().__init__()
self.linear0 = paddle.nn.Linear(16, 10)
self.linear1 = paddle.nn.Linear(10, 32)

def forward(self, x):
out = self.linear0(x)
out = nn.functional.relu(out)
out = self.linear1(out)
out = nn.functional.dropout(out, p=0.2)
return out


def build_MLP_model(
use_amp,
amp_dtype="float16",
amp_level="O1",
use_promote=False,
use_master_grad=False,
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.utils.unique_name.guard():
with paddle.static.program_guard(main_program, startup_program):
model = SimpleMLPNet()
x_dtype = "float32"
if use_amp and amp_level == "O2":
if amp_dtype == "bfloat16":
x_dtype = "uint16"
elif amp_dtype == "float16":
x_dtype = "float16"
x = paddle.static.data(name='x', shape=[None, 16], dtype=x_dtype)
out = model(x)
loss = paddle.mean(out)

if use_amp:
amp_lists = paddle.static.amp.AutoMixedPrecisionLists(
custom_black_list=["reduce_mean"],
dtype=amp_dtype,
)
else:
amp_lists = None

optimizer = _build_optimizer(
use_amp,
amp_dtype,
amp_level,
amp_lists,
True,
Contributor:
We need a unit test with grad_clip set to False.

Contributor Author:
Added.

use_promote=use_promote,
use_master_grad=use_master_grad,
)
optimizer.minimize(loss)

feed_vars = [x]
fetch_vars = [loss]
return main_program, startup_program, optimizer, feed_vars, fetch_vars
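
The grad_clip-free unit test requested in the review might drive this model roughly as follows (a hypothetical sketch: it assumes build_MLP_model grows a use_grad_clip flag and that master gradients are created even without grad clip, per the earlier discussion; it is not the test added by the PR):

class TestMasterGradWithoutClip(unittest.TestCase):
    def test_bf16_master_grad_no_clip(self):
        if not paddle.is_compiled_with_cuda():
            self.skipTest("requires CUDA")
        (
            main_program,
            startup_program,
            optimizer,
            feed_vars,
            fetch_vars,
        ) = build_MLP_model(
            use_amp=True,
            amp_dtype="bfloat16",
            amp_level="O2",
            use_master_grad=True,
            use_grad_clip=False,  # hypothetical flag, not in the helper above
        )
        # Every fp16/bf16 gradient should have an fp32 master copy whose name
        # carries the "_fp32_master" suffix generated by _create_master_grad.
        master_grads = optimizer._optimizer._master_grads
        self.assertTrue(len(master_grads) > 0)
        for master in master_grads.values():
            self.assertIn("_fp32_master", master.name)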


class SimpleWhileNet(nn.Layer):